In [1]:
!pip install torch_geometric networkx lxml

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m15.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [2]:
from google.colab import files
uploaded = files.upload()

Saving CollegeFootball.graphml to CollegeFootball.graphml


In [45]:
import networkx as nx
import torch
from torch_geometric.utils import from_networkx, train_test_split_edges
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

In [46]:
import torch_geometric.transforms as T
import torch

filename = list(uploaded.keys())[0]
G_nx = nx.read_graphml(filename)
data = from_networkx(G_nx)
num_nodes = data.num_nodes
data.x = torch.eye(num_nodes)

# Keep only essential attributes for RandomLinkSplit
keys_to_keep = ['x', 'edge_index']
all_keys = list(data.keys())
for key in all_keys:
    if key not in keys_to_keep:
        del data[key]

# --- Diagnosis: Inspect data after aggressive cleaning and before transform ---
print("Data object after aggressive cleaning and before transform:", data)
if hasattr(data, 'edge_attr') and data.edge_attr is not None:
    print("Edge attributes after aggressive cleaning and before transform:", data.edge_attr)
if hasattr(data, 'edge_index') and data.edge_index is not None:
    print("Edge index after aggressive cleaning and before transform:", data.edge_index)
# ------------------------------------------------


# Use RandomLinkSplit transform
transform = T.RandomLinkSplit(is_undirected=True)
global train_data, val_data, test_data
train_data, val_data, test_data = transform(data)

Data object after aggressive cleaning and before transform: Data(x=[115, 115], edge_index=[2, 1226])
Edge index after aggressive cleaning and before transform: tensor([[  0,   0,   0,  ..., 114, 114, 114],
        [  1,   4,   9,  ...,  88, 104, 110]])


In [47]:
class GCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 16)
        self.conv2 = GCNConv(16, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)

def decode(z, edge_index):
    return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=1)

In [48]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCNEncoder(data.x.size(-1), 16).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# --- Diagnosis: Inspect train_data, val_data, test_data ---
print("Type of train_data:", type(train_data))
print("Attributes of train_data:", dir(train_data))
if hasattr(train_data, 'pos_edge_index'):
    print("train_data.pos_edge_index exists and has shape:", train_data.pos_edge_index.shape)
else:
    print("train_data does NOT have pos_edge_index attribute")

print("Type of val_data:", type(val_data))
print("Attributes of val_data:", dir(val_data))
if hasattr(val_data, 'pos_edge_index'):
    print("val_data.pos_edge_index exists and has shape:", val_data.pos_edge_index.shape)
else:
     print("val_data does NOT have pos_edge_index attribute")

print("Type of test_data:", type(test_data))
print("Attributes of test_data:", dir(test_data))
if hasattr(test_data, 'pos_edge_index'):
    print("test_data.pos_edge_index exists and has shape:", test_data.pos_edge_index.shape)
else:
    print("test_data does NOT have pos_edge_index attribute")
# ------------------------------------------------

def get_negative_samples(edge_index, num_nodes, num_samples):
    """Randomly samples negative edges."""
    neg_edge_index = torch.randint(0, num_nodes, (2, num_samples), dtype=torch.long, device=edge_index.device)
    # Optional: Filter out existing edges if necessary, but for large graphs random sampling is usually sufficient
    return neg_edge_index

for epoch in range(101):
    model.train()
    optimizer.zero_grad()
    # Use train_data.edge_index for message passing in the GCN
    z = model(train_data.x.to(device), train_data.edge_index.to(device))

    # Use train_data.edge_index as positive examples for training loss
    pos_train_edge_index = train_data.edge_index.to(device)
    pos_pred = decode(z, pos_train_edge_index)
    pos_label = torch.ones(pos_pred.size(0), device=device)

    # Manually generate negative examples for training loss
    neg_train_edge_index = get_negative_samples(pos_train_edge_index, num_nodes, pos_train_edge_index.size(1))
    neg_pred = decode(z, neg_train_edge_index)
    neg_label = torch.zeros(neg_pred.size(0), device=device)

    loss = F.binary_cross_entropy_with_logits(
        torch.cat([pos_pred, neg_pred]),
        torch.cat([pos_label, neg_label])
    )
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print(f"Epoch {epoch:03d}, Loss: {loss:.4f}")

Type of train_data: <class 'torch_geometric.data.data.Data'>
Attributes of train_data: ['__abstractmethods__', '__annotations__', '__call__', '__cat_dim__', '__class__', '__contains__', '__copy__', '__deepcopy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__inc__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_edge_attr_cls', '_edge_to_layout', '_edges_to_layout', '_get_edge_index', '_get_tensor', '_get_tensor_size', '_multi_get_tensor', '_put_edge_index', '_put_tensor', '_remove_edge_index', '_remove_tensor', '_store', '_tensor_attr_cls', '_to_type', 'apply', 'apply_', 'batch', 'clone', 'coalesce', 'conca

In [49]:
from sklearn.metrics import roc_auc_score, accuracy_score

model.eval()
with torch.no_grad():
    # Use the full graph's edge_index and node features for final embeddings
    z = model(data.x.to(device), data.edge_index.to(device))

    # Evaluate on test set using test_data.edge_index as positive examples
    pos_test_edge_index = test_data.edge_index.to(device)
    pos_test_pred = decode(z, pos_test_edge_index)

    # Manually generate negative examples for testing
    # Use the original data's num_nodes for generating negative samples across the whole graph
    neg_test_edge_index = get_negative_samples(pos_test_edge_index, num_nodes, pos_test_edge_index.size(1))
    neg_test_pred = decode(z, neg_test_edge_index)


    # Calculate AUC
    preds = torch.cat([pos_test_pred, neg_test_pred]).cpu().numpy()
    labels = torch.cat([torch.ones(pos_test_pred.size(0)), torch.zeros(neg_test_pred.size(0))]).cpu().numpy()
    auc = roc_auc_score(labels, preds)

    # Calculate Accuracy (simple threshold at 0.5)
    predicted_labels = (preds > 0.5).astype(float)
    accuracy = accuracy_score(labels, predicted_labels)

    print(f"\nTest AUC: {auc:.4f}")
    print(f"Test Accuracy: {accuracy:.4f}")

# Optional: Predict links on the entire graph based on learned embeddings
# adj_pred = torch.sigmoid(torch.matmul(z, z.t()))
# predicted_edges = (adj_pred > 0.9).nonzero(as_tuple=False).t()
# print("\nTop predicted links (node pairs):")
# print(predicted_edges[:, :10])


Test AUC: 0.9087
Test Accuracy: 0.8335


In [51]:
print("train_data:", train_data)
print("val_data:", val_data)
print("test_data:", test_data)

train_data: Data(x=[115, 115], edge_index=[2, 860], edge_label=[860], edge_label_index=[2, 860])
val_data: Data(x=[115, 115], edge_index=[2, 860], edge_label=[122], edge_label_index=[2, 122])
test_data: Data(x=[115, 115], edge_index=[2, 982], edge_label=[244], edge_label_index=[2, 244])
