In [2]:
!pip install torch_geometric torch torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.2.0+cu118.html


Looking in links: https://data.pyg.org/whl/torch-2.2.0+cu118.html
Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.2.0%2Bcu118/torch_scatter-2.1.2%2Bpt22cu118-cp311-cp311-linux_x86_64.whl (10.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m76.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.2.0%2Bcu118/torch_sparse-0.6.18%2Bpt22cu118-cp311-cp311-linux_x86_64.whl (4.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.9/4.9 MB[0m [31m25.1 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecti

In [15]:
import networkx as nx
from torch_geometric.utils import from_networkx
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import community.community_louvain as community_louvain
from networkx.algorithms.community import label_propagation_communities, girvan_newman, asyn_lpa_communities
import itertools
import ipywidgets as widgets
from IPython.display import display, clear_output


In [8]:
!pip install python-louvain




In [18]:
# Upload and read the GraphML file
import io

from google.colab import files

uploaded = files.upload()
if not uploaded:
    # A cancelled upload dialog yields an empty mapping; fail with a clear message
    # instead of an IndexError on the lookup below.
    raise ValueError("No file was uploaded; please select a GraphML file.")

# Load the first uploaded file (the mapping is filename -> raw bytes)
filename = next(iter(uploaded))
G = nx.read_graphml(io.BytesIO(uploaded[filename]))

# Relabel nodes with consecutive integers; the original GraphML node id is
# kept on each node under the 'original_id' attribute.
G = nx.convert_node_labels_to_integers(G, label_attribute='original_id')

Saving cust_item_network.graphml to cust_item_network.graphml


In [19]:

# Apply Louvain community detection; the resulting community ids are the
# training targets. A fixed random_state makes the partition reproducible.
partition = community_louvain.best_partition(G, random_state=42)

node_list = list(G.nodes())
# Look each label up by node id so labels are aligned with node_list by
# construction, rather than relying on the iteration order of the dict.
louvain_labels = [partition[node] for node in node_list]

# Store the target on the graph as a node attribute.
for node, label in zip(node_list, louvain_labels):
    G.nodes[node]['y'] = label

# Convert to PyG data format
data = from_networkx(G)
# One-hot (identity) node features: row i has a single 1 in column i.
# Built directly as a tensor instead of n Python lists of length n.
data.x = torch.eye(len(node_list), dtype=torch.float)
data.y = torch.tensor(louvain_labels, dtype=torch.long)
# Every node is used for training.
data.train_mask = torch.ones(data.num_nodes, dtype=torch.bool)


In [20]:

class GCN(torch.nn.Module):
    """Two-layer graph convolutional network producing per-node class logits.

    Args:
        hidden_channels: Width of the hidden representation between the two
            GCNConv layers.
        in_channels: Number of input features per node. When omitted, it is
            read from the notebook-level ``data.num_node_features``.
        num_classes: Number of output classes. When omitted, it is the number
            of distinct values in the notebook-level ``louvain_labels``.
        dropout: Dropout probability applied after the first layer while the
            module is in training mode.
    """

    def __init__(self, hidden_channels, in_channels=None, num_classes=None, dropout=0.3):
        super().__init__()
        # Fall back to the notebook globals so GCN(hidden_channels=...) keeps working.
        if in_channels is None:
            in_channels = data.num_node_features
        if num_classes is None:
            num_classes = len(set(louvain_labels))
        self.dropout = dropout
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        """Return raw (un-normalised) class scores for every node.

        Args:
            x: Node feature matrix passed to the first convolution.
            edge_index: Graph connectivity passed to both convolutions.
        """
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Dropout is only active in training mode (self.training is True).
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x

# Train GCN
SEED = 42
NUM_EPOCHS = 101
HIDDEN_CHANNELS = 8

# Seed before building the model so weight initialisation and dropout are
# reproducible across kernel restarts.
torch.manual_seed(SEED)

model = GCN(hidden_channels=HIDDEN_CHANNELS)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    # train_mask selects every node, so this is the loss over the full graph.
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# GCN predictions: most likely community per node, as a NumPy array.
model.eval()
with torch.no_grad():  # inference only, no autograd graph needed
    pred = model(data.x, data.edge_index).argmax(dim=1).cpu().numpy()


In [21]:

# Checkbox controlling whether node labels are drawn; starts checked.
toggle = widgets.Checkbox(description='Show Node Labels', value=True, disabled=False)

def plot_graphs(show_labels):
    """Draw the Louvain partition and the GCN predictions side by side.

    Both panels share one node layout and one colour scale, so a given
    community id is drawn in the same colour on the left and on the right.

    Args:
        show_labels: Whether node ids are drawn on top of the nodes.
    """
    # Fixed seed: the layout is identical on every redraw.
    pos = nx.spring_layout(G, seed=42)
    fig, axes = plt.subplots(1, 2, figsize=(16, 7))

    predicted_labels = pred.tolist()
    # Without explicit limits each panel is normalised to its own min/max, so
    # the same id could get different colours when the GCN predicts only a
    # subset of the communities. Use one shared range for both panels.
    top_label = max([*louvain_labels, *predicted_labels], default=0)

    panels = (
        (axes[0], louvain_labels, "Louvain Communities"),
        (axes[1], predicted_labels, "GCN-Predicted Communities"),
    )
    for ax, labels, title in panels:
        nx.draw(G, pos, node_color=labels, cmap='tab10', vmin=0, vmax=top_label,
                with_labels=show_labels, ax=ax, node_size=500, font_weight='bold')
        ax.set_title(title)

    plt.tight_layout()
    plt.show()

# Bind the checkbox to plot_graphs: its current value is passed as show_labels.
widgets.interact(plot_graphs, show_labels=toggle)


interactive(children=(Checkbox(value=True, description='Show Node Labels'), Output()), _dom_classes=('widget-i…