In [1]:
from sympy.printing.pytorch import torch
import networkx as nx
import numpy as np
from torch_geometric.datasets import KarateClub

# Load dataset
dataset = KarateClub()
data = dataset[0]

# Load the original NetworkX karate club dataset (has 2 labels)
G = nx.karate_club_graph()

# Ground truth: 0 = "Mr. Hi", 1 = "Officer"
true_labels = np.array([0 if G.nodes[i]['club'] == 'Mr. Hi' else 1 for i in G.nodes()])

# Get adjacency matrix
A = nx.to_numpy_array(G)

print("Adjacency matrix shape:", A.shape)
print("Number of nodes:", len(G.nodes))
print("Ground truth labels:", true_labels)



# For graph neural network method

# Get indices of each class
class0_idx = (data.y == 0).nonzero(as_tuple=True)[0]
class1_idx = (data.y == 1).nonzero(as_tuple=True)[0]

# Shuffle indices
class0_idx = class0_idx[torch.randperm(len(class0_idx))]
class1_idx = class1_idx[torch.randperm(len(class1_idx))]

# Split 50% for train, 50% for test
train_class0 = class0_idx[:len(class0_idx)//2]
test_class0  = class0_idx[len(class0_idx)//2:]
train_class1 = class1_idx[:len(class1_idx)//2]
test_class1  = class1_idx[len(class1_idx)//2:]

# Combine
train_idx = torch.cat([train_class0, train_class1])
test_idx  = torch.cat([test_class0, test_class1])

# Make boolean masks
train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
train_mask[train_idx] = True
test_mask[test_idx] = True

# Save into data
data.train_mask = train_mask
data.test_mask = test_mask

print("Train nodes:", train_idx.tolist())
print("Test nodes:", test_idx.tolist())
print("Class balance in train set:", data.y[train_mask].bincount().tolist())
print("Class balance in test set:", data.y[test_mask].bincount().tolist())


Adjacency matrix shape: (34, 34)
Number of nodes: 34
Ground truth labels: [0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 0 0 1 0 1 0 1 1 1 1 1 1 1 1 1 1 1 1]
Train nodes: [20, 23, 27, 30, 14, 15, 3, 12, 19, 0, 17, 9]
Test nodes: [8, 18, 32, 26, 29, 22, 33, 13, 11, 21, 7, 1, 2]
Class balance in train set: [6, 6]
Class balance in test set: [7, 6]
