In [None]:
%load_ext autoreload
%autoreload 2   

In [None]:
import torch
import torch.nn as nn
from torch_geometric.datasets import KarateClub, TUDataset
from torch_geometric.utils import to_networkx, scatter, to_dense_adj
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import networkx as nx
from matplotlib import pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from icecream import ic
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import add_self_loops, degree, to_dense_adj
from sklearn.metrics import jaccard_score
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA


In [None]:
# Pick the best available backend: CUDA first, then Apple-silicon MPS, else CPU.
# `getattr` guards torch builds that do not ship `torch.backends.mps` at all.
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
def compute_jaccard_similarity(edge_index, num_nodes):
    """Pairwise Jaccard similarity between the neighbourhoods of all nodes.

    Row ``i`` of the dense adjacency matrix is treated as the neighbour set of
    node ``i``. Entry ``[i, j]`` of the result is
    ``|N(i) & N(j)| / |N(i) | N(j)|``. The diagonal is left at 0 and a pair whose
    union is empty scores 0.

    Args:
        edge_index: Edge list in the format accepted by ``to_dense_adj``.
        num_nodes: Number of nodes; forwarded as ``max_num_nodes``.

    Returns:
        A float32 CPU tensor of shape ``(num_nodes, num_nodes)``.
    """
    adj = to_dense_adj(edge_index, max_num_nodes=num_nodes).squeeze(0)
    # Move to the CPU first (float64 is not available on every backend), then
    # binarise so repeated edges still describe set membership.
    neighbours = (adj.detach().cpu() > 0).to(torch.float64)

    # |N(i) & N(j)| for every pair at once: dot product of 0/1 indicator rows.
    shared = neighbours @ neighbours.t()
    sizes = neighbours.sum(dim=1)
    # Inclusion-exclusion: |A | B| = |A| + |B| - |A & B|.
    union = sizes.unsqueeze(0) + sizes.unsqueeze(1) - shared

    # clamp() only protects the division; empty unions are zeroed by where().
    jaccard = torch.where(
        union > 0, shared / union.clamp(min=1.0), torch.zeros_like(shared)
    )
    jaccard.fill_diagonal_(0.0)
    return jaccard.float()


# Load the Karate Club dataset and take its first graph.
dataset = KarateClub()
data = dataset[0]

# Regression target: the node-by-node Jaccard similarity matrix, moved to the
# same device the model will run on.
target_jaccard = compute_jaccard_similarity(
    edge_index=data.edge_index,
    num_nodes=data.num_nodes,
)
target_jaccard = target_jaccard.to(device)


In [None]:
class SimpleGCNConv(nn.Module):
    """Hand-written graph convolution with symmetric degree normalisation.

    The layer adds a self-loop to every node, applies a linear map to the node
    features, and then sums neighbour features weighted by
    ``deg^-1/2[row] * deg^-1/2[col]``.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply the convolution.

        Args:
            x: Node feature matrix; ``x.size(0)`` is taken as the node count.
            edge_index: Two-row edge list, unpacked as ``row, col``.

        Returns:
            Tensor with one row per node and ``out_channels`` columns, on the
            same device and with the same dtype as the transformed features.
        """
        num_nodes = x.size(0)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        row, col = edge_index

        # Degrees are counted over `col` after self-loops were added.
        deg = degree(col, num_nodes, dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        x = self.lin(x)

        # For edge i, add norm[i] * x[col[i]] into out[row[i]]. A single
        # index_add replaces the per-edge Python loop and stays differentiable.
        messages = norm.unsqueeze(-1) * x[col]
        return torch.zeros_like(x).index_add(0, row, messages)


class GCN(nn.Module):
    """Two ``SimpleGCNConv`` layers with ReLU and dropout between them.

    Args:
        in_channels: Input feature size of the first layer.
        hidden_channels: Output size of the first layer / input of the second.
        out_channels: Output feature size of the second layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SimpleGCNConv(in_channels, hidden_channels)
        self.conv2 = SimpleGCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Run both layers over the graph and return the second layer's output."""
        hidden = F.relu(self.conv1(x, edge_index))
        # Dropout is tied to the module's train/eval mode.
        hidden = F.dropout(hidden, training=self.training)
        return self.conv2(hidden, edge_index)


In [None]:
# Hyper-parameters, kept together so they are easy to find and tune.
SEED = 42
HIDDEN_CHANNELS = 16
LEARNING_RATE = 0.01
WEIGHT_DECAY = 5e-4
NUM_EPOCHS = 200
LOG_EVERY = 20

# Seed before building the model: weight initialisation and dropout both draw
# from torch's RNG, so without this every run produces different embeddings.
torch.manual_seed(SEED)

# One output channel per node, so the model output has the same
# (num_nodes, num_nodes) layout as the Jaccard target it is regressed against.
model = GCN(dataset.num_node_features, HIDDEN_CHANNELS, data.num_nodes).to(device)
data = data.to(device)

optimizer = torch.optim.Adam(
    model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

model.train()
for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.mse_loss(out, target_jaccard)
    loss.backward()
    optimizer.step()
    if epoch % LOG_EVERY == 0:
        print(f"Epoch: {epoch:03d}, Loss: {loss.item():.4f}")

# Inference only: switch off dropout and skip building an autograd graph.
model.eval()
with torch.no_grad():
    embeddings = model(data.x, data.edge_index).cpu().numpy()


In [None]:
# Project the learned node embeddings onto their first two principal components.
pca = PCA(n_components=2)
Z_2d = pca.fit_transform(embeddings)
labels = data.y.cpu().numpy()
num_nodes = data.num_nodes

# Visualize in 2D
fig, ax = plt.subplots(figsize=(8, 6))
# Named `points`, not `scatter`: binding `scatter` here would overwrite the
# `scatter` function imported from torch_geometric.utils for the rest of the session.
points = ax.scatter(Z_2d[:, 0], Z_2d[:, 1], c=labels, cmap="viridis", s=100)
fig.colorbar(points, ax=ax, label="Community Label")
ax.set_title("2D Visualization of Node Embeddings (Zachary Karate Club) with PCA")
ax.set_xlabel("PCA Dimension 1")
ax.set_ylabel("PCA Dimension 2")
for i in range(num_nodes):  # Optional: label nodes
    ax.text(Z_2d[i, 0], Z_2d[i, 1], str(i), fontsize=8, ha="right")
plt.show()