In [1]:
import os
# import gae.input_data as id
import scipy.sparse as sp
import scipy.io
import torch
from torch_geometric.utils import remove_self_loops, remove_isolated_nodes , is_undirected, to_undirected


def load_data(data_source, data_dir="gae/data"):
    """
    Load a MATLAB ``.mat`` dataset file.

    Parameters:
    data_source (str): Dataset name; the file read is ``<data_dir>/<data_source>.mat``.
    data_dir (str): Directory holding the ``.mat`` files. Defaults to ``"gae/data"``,
        resolved relative to the current working directory.

    Returns:
    dict: The mapping returned by ``scipy.io.loadmat`` (MATLAB variable name -> value).
    """
    path = os.path.join(data_dir, "{}.mat".format(data_source))
    return scipy.io.loadmat(path)

def sparse_matrix_to_edge_index(sparse_matrix):
    """
    Convert a SciPy sparse matrix into an ``edge_index`` tensor.

    Every stored entry (i, j) of the matrix becomes one column ``[i, j]`` of the
    result, in the order given by the matrix's COO representation. The stored
    values themselves are discarded.

    Parameters:
    sparse_matrix (scipy.sparse matrix): Any sparse matrix supporting ``.tocoo()``.

    Returns:
    torch.Tensor: Long tensor of shape [2, nnz] holding (row, col) pairs.
    """
    coo = sparse_matrix.tocoo()
    sources = torch.tensor(coo.row, dtype=torch.long)
    targets = torch.tensor(coo.col, dtype=torch.long)
    return torch.stack((sources, targets))

def has_repeating_edges(edge_index):
    """
    Report whether any (source, target) column occurs more than once.

    Edges are compared as ordered pairs, so (u, v) and (v, u) are distinct.

    Parameters:
    edge_index (torch.Tensor): The edge_index tensor of shape [2, num_edges].

    Returns:
    bool: True if at least one directed edge is duplicated, otherwise False.
    """
    pairs = list(zip(edge_index[0].tolist(), edge_index[1].tolist()))
    # A set collapses duplicates, so any size mismatch means a repeat exists.
    return len(set(pairs)) != len(pairs)

def keep_bidirectional_edges(edge_index):
    """
    Keep only bidirectional edges in the edge_index tensor.

    An edge (u, v) is kept when its reverse (v, u) also occurs in the input.
    Columns are scanned left to right. Whenever an edge's reverse has already
    been seen, the pair ``(u, v), (v, u)`` is appended to the output. As a
    consequence:
    - a pair repeated in the input can be emitted more than once;
    - a self-loop (u, u) is emitted only if it occurs at least twice.

    Parameters:
    edge_index (torch.Tensor): The edge_index tensor of shape [2, num_edges].

    Returns:
    torch.Tensor: Long tensor of shape [2, k] with only bidirectional edges,
        or shape [2, 0] when there are none.
    """
    edges = list(zip(edge_index[0].tolist(), edge_index[1].tolist()))

    seen_edges = set()
    bidirectional_edges = []

    for u, v in edges:
        # The reverse must have appeared earlier, so each matching pair is
        # emitted when its second direction is reached.
        if (v, u) in seen_edges:
            bidirectional_edges.append((u, v))
            bidirectional_edges.append((v, u))
        seen_edges.add((u, v))

    if not bidirectional_edges:
        # torch.tensor([]).t() would yield a 1-D tensor of shape [0]; keep the
        # [2, 0] layout so callers can still use .shape[1] and row indexing.
        return torch.empty((2, 0), dtype=torch.long)

    return torch.tensor(bidirectional_edges, dtype=torch.long).t()

def remove_duplicate_edges(edge_index):
    """
    Remove duplicate edges from the edge_index tensor.

    Edges are compared as ordered (source, target) pairs, so (u, v) and (v, u)
    are treated as different edges. The first occurrence of each edge is kept,
    and the original column order is preserved. This makes the output order
    deterministic, unlike iterating a ``set``, whose order is arbitrary.

    Parameters:
    edge_index (torch.Tensor): The edge_index tensor of shape [2, num_edges].

    Returns:
    torch.Tensor: Long tensor of shape [2, num_unique_edges], or shape [2, 0]
        when the input has no edges.
    """
    edges = zip(edge_index[0].tolist(), edge_index[1].tolist())
    # dict.fromkeys de-duplicates while remembering insertion order.
    unique_edges = list(dict.fromkeys(edges))

    if not unique_edges:
        # Avoid torch.tensor([]).t(), which would have shape [0] instead of [2, 0].
        return torch.empty((2, 0), dtype=torch.long)

    return torch.tensor(unique_edges, dtype=torch.long).t()


def get_undirected_edges(edge_index):
    """
    Identify and return undirected edges from the edge_index tensor.

    An edge (u, v) counts as undirected when its reverse (v, u) also occurs.
    Columns are scanned left to right. Whenever an edge's reverse has already
    been seen, the pair ``(u, v), (v, u)`` is appended to the output. As a
    consequence:
    - a pair repeated in the input can be emitted more than once;
    - a self-loop (u, u) is emitted only if it occurs at least twice.

    Parameters:
    edge_index (torch.Tensor): The edge_index tensor of shape [2, num_edges].

    Returns:
    torch.Tensor: Long tensor of shape [2, k] with only undirected edges, or
        shape [2, 0] when there are none.
    """
    edges = list(zip(edge_index[0].tolist(), edge_index[1].tolist()))

    seen_edges = set()
    undirected_edges = []

    for u, v in edges:
        # Emit the pair once its second direction shows up.
        if (v, u) in seen_edges:
            undirected_edges.append((u, v))
            undirected_edges.append((v, u))
        seen_edges.add((u, v))

    if not undirected_edges:
        # torch.tensor([]).t() would be 1-D with shape [0]; keep a [2, 0] layout.
        return torch.empty((2, 0), dtype=torch.long)

    return torch.tensor(undirected_edges, dtype=torch.long).t()

def get_directed_edges(edge_index):
    """
    Identify and return directed edges from the edge_index tensor.

    An edge (u, v) is "directed" when its reverse (v, u) does not appear
    anywhere in the input. Input order is preserved, and repeated copies of a
    directed edge are all kept. Self-loops are never returned, because a
    self-loop is its own reverse.

    Parameters:
    edge_index (torch.Tensor): The edge_index tensor of shape [2, num_edges].

    Returns:
    torch.Tensor: Long tensor of shape [2, k] with only directed edges, or
        shape [2, 0] when every edge has its reverse.
    """
    edges = list(zip(edge_index[0].tolist(), edge_index[1].tolist()))

    # Full set of edges up front: the reverse may appear before or after (u, v).
    edge_set = set(edges)
    directed_edges = [(u, v) for u, v in edges if (v, u) not in edge_set]

    if not directed_edges:
        # torch.tensor([]).t() would be 1-D with shape [0]; callers index rows
        # [0]/[1], so return an explicit [2, 0] tensor instead.
        return torch.empty((2, 0), dtype=torch.long)

    return torch.tensor(directed_edges, dtype=torch.long).t()

def check_indices_in_edge_index(indices, edge_index):
    """
    For each (i, j) pair in ``indices``, test whether that exact directed edge
    is one of the columns of ``edge_index``.

    Parameters:
    indices (iterable of 2-tuples): Pairs (i, j) to look up.
    edge_index (torch.Tensor): The edge_index tensor of shape [2, num_edges].

    Returns:
    list of bool: One entry per pair in ``indices``, in the same order.
    """
    sources, targets = edge_index[0].tolist(), edge_index[1].tolist()
    lookup = set(zip(sources, targets))

    membership = []
    for pair in indices:
        i, j = pair
        membership.append((i, j) in lookup)
    return membership

def check_anomalies_in_edge_index(edge_index, anomaly_indices):
    """
    Flag the edges whose two endpoints are BOTH anomalous nodes.

    An edge with only one anomalous endpoint is flagged False.

    Parameters:
    edge_index (torch.Tensor): The edge_index tensor of shape [2, num_edges].
    anomaly_indices (torch.Tensor): 1-D tensor of anomalous node ids.

    Returns:
    torch.Tensor: Bool tensor with one entry per edge (column of edge_index),
        True when both the source and the target are in ``anomaly_indices``.
    """
    anomalous = set(anomaly_indices.tolist())

    flags = []
    for src, dst in zip(edge_index[0].tolist(), edge_index[1].tolist()):
        flags.append(src in anomalous and dst in anomalous)

    return torch.tensor(flags, dtype=torch.bool)

In [None]:


# NOTE: I think what happened is that they added edges where the anomalies were.
# TODO: 1. Load all of the new datasets and make a scheme for the server.
# TODO: 2. See if the semi-supervised method works.

DATASET = 'Flickr'

res = load_data(DATASET)
sparse_matrix = res['Network']
edge_index = sparse_matrix_to_edge_index(sparse_matrix)
edge_index, _ = remove_self_loops(edge_index)

labels = torch.from_numpy(res['Label']).squeeze()
anomaly_indices = torch.where(labels == 1)[0]

# Edges whose reverse is missing from the graph.
directed_edges = get_directed_edges(edge_index)

# TODO: these directed edges are probably the connections between anomalies
# themselves. Removing them would make the anomalies heterophilic; keeping them
# would make them homophilic.
anomalies_directed = check_anomalies_in_edge_index(directed_edges, anomaly_indices)
# Are there other edges in the full edge_index that connect two anomalies?
normal_anomalies = check_anomalies_in_edge_index(edge_index, anomaly_indices)
anomaly_edge_ids = torch.where(normal_anomalies)[0]

print(f'directed shape: {directed_edges.shape[1]}')
# Count the flagged edges; .shape[0] would just repeat the number of directed edges.
print(f'directed edges between anomalies: {int(anomalies_directed.sum())} / {anomalies_directed.shape[0]}')
print(f'edges between anomalies (column ids): {anomaly_edge_ids}')
# Inspect the first few anomaly-anomaly edges rather than a hard-coded column
# index, which is only valid for one dataset/preprocessing.
print(edge_index[:, anomaly_edge_ids[:5]])
print(anomaly_indices)


In [None]:

# PROCESS EDGE_INDEX
#====================

# Edge count the processed graph is compared against at the end of this cell.
# NOTE(review): the origin of this number is not recorded here — presumably the
# reference undirected edge count for this dataset; confirm.
REFERENCE_NUM_EDGES = 171743

edge_index = remove_duplicate_edges(edge_index)
print('after remove_duplicate_edges')
print(edge_index.shape[1]/2)

# remove_isolated_nodes relabels the surviving nodes to 0..k-1. Pass the real
# node count (otherwise it is inferred from edge_index.max() and trailing
# isolated nodes are missed). Keep the returned mask so node labels can be
# re-aligned with the new node ids.
num_nodes = sparse_matrix.shape[0]
assert labels.numel() == num_nodes, "labels and adjacency disagree on node count"
edge_index, _, node_mask = remove_isolated_nodes(edge_index, num_nodes=num_nodes)
labels_connected = labels[node_mask]
anomaly_indices_connected = torch.where(labels_connected == 1)[0]
print('after remove_isolated_nodes')
print(edge_index.shape[1]/2)

edge_index, _ = remove_self_loops(edge_index)
print('after remove_self_loops')
print(edge_index.shape[1]/2)
print(is_undirected(edge_index))
print(edge_index.max())

print(edge_index.shape[1]/2 - REFERENCE_NUM_EDGES)
