In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Linear, Parameter
from torch_geometric.datasets import WebKB, Planetoid, WikipediaNetwork
from torch_geometric.nn import GCNConv, VGAE
#from torch_geometric.utils import train_test_split_edges
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import negative_sampling
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
dataset = WebKB(root="/home/siddy/META/data", name="Cornell")
data = dataset[0].to(device)

In [14]:
from collections import defaultdict, deque
def neg_index(num_nodes, edge_index):
    adj_list = defaultdict(list)
    for i in range(edge_index.size(1)):
        u, v = int(edge_index[0, i]), int(edge_index[1, i])
        adj_list[u].append(v)
        adj_list[v].append(u)

    neg_indices = []
    for i in range(num_nodes):
        visited = set()
        queue = deque([(i, 0)])
        visited.add(i)
        count_two_away = 0
        while queue:
            node, distance = queue.popleft()
            if distance == 2:
                count_two_away += 1
                continue
            for neighbor in adj_list[node]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, distance + 1))
        sum_adj_i = len(adj_list[i])
        neg_index = count_two_away / sum_adj_i if sum_adj_i > 0 else 0
        neg_indices.append(neg_index)

    return adj_list, neg_indices

In [15]:
import torch
from collections import defaultdict, deque

def pos_index(num_nodes, edge_index):
    adj_list = defaultdict(list)
    for i in range(edge_index.size(1)):
        u, v = int(edge_index[0, i]), int(edge_index[1, i])
        adj_list[u].append(v)
        adj_list[v].append(u)

    pos_indices = []
    for i in range(edge_index.size(1)):
        u, v = int(edge_index[0, i]), int(edge_index[1, i])
        common_neighbors = 0
        visited = set()
        queue = deque(adj_list[u])
        visited.add(u)

        while queue:
            node = queue.popleft()
            if node == v:
                common_neighbors += 1
            visited.add(node)
            for neighbor in adj_list[node]:
                if neighbor not in visited:
                    queue.append(neighbor)
        
        sum_adj_u = len(adj_list[u])
        sum_adj_v = len(adj_list[v])
        pos_index = common_neighbors / (sum_adj_u + sum_adj_v - common_neighbors)
        pos_indices.append(pos_index)

    return pos_indices


In [16]:
num_nodes = data.x.size(0)
edge_index = data.edge_index

In [17]:
adj_list, neg_indices = neg_index(num_nodes, edge_index)
pos_indices = pos_index(num_nodes, edge_index)

In [8]:
def neg_index_bitwise(num_nodes, edge_index):
    adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.long)
    for i in range(edge_index.size(1)):
        u, v = int(edge_index[0, i]), int(edge_index[1, i])
        adj_matrix[u, v] = 1
        adj_matrix[v, u] = 1

    neg_indices = []
    for i in range(edge_index.size(1)):
        u, v = int(edge_index[0, i]), int(edge_index[1, i])
        adj_squared = torch.mm(adj_matrix, adj_matrix)
        adj_row_u = adj_matrix[u]
        adj_row_v = adj_matrix[v]
        bitwise_and_result = torch.bitwise_and(adj_squared[u], adj_row_v[v])
        sum_bitwise_and_result = torch.sum(bitwise_and_result)
        sum_adj_u = torch.sum(adj_row_u)
        sum_adj_v = torch.sum(adj_row_v)
        neg_index = sum_bitwise_and_result.float() / (sum_adj_u + sum_adj_v)
        neg_indices.append(neg_index.item())

    return neg_indices

In [11]:
neg_indices_2 = neg_index_bitwise(num_nodes, edge_index)

In [10]:
def pos_index_bitwise(num_nodes, edge_index):
    adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.long)
    for i in range(edge_index.size(1)):
        u, v = int(edge_index[0, i]), int(edge_index[1, i])
        adj_matrix[u, v] = 1
        adj_matrix[v, u] = 1

    pos_indices = []
    for i in range(edge_index.size(1)):
        u, v = int(edge_index[0, i]), int(edge_index[1, i])
        adj_row_u = adj_matrix[u]
        adj_row_v = adj_matrix[v]
        bitwise_and_result = torch.bitwise_and(adj_row_u, adj_row_v)
        sum_bitwise_and_result = torch.sum(bitwise_and_result)
        sum_adj_u = torch.sum(adj_row_u)
        sum_adj_v = torch.sum(adj_row_v)
        pos_index = sum_bitwise_and_result.float() / (sum_adj_u + sum_adj_v)
        pos_indices.append(pos_index.item())

    return pos_indices

In [11]:
pos_indices_2 = pos_index_bitwise(num_nodes, edge_index)

In [16]:
pos_indices == pos_indices_2

False

In [18]:
neg_indices==neg_indices_2

False