In [11]:
import os
import networkx as nx
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import from_networkx
from torch_geometric.utils import subgraph
from torch_geometric.data import Data

In [12]:

dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]


folder_path = 'client_subgraphs'
sub_data_list = []

for filename in os.listdir(folder_path):
    if filename.endswith('.gml'):
      
        file_path = os.path.join(folder_path, filename)
        g = nx.read_gml(file_path)

        subgraph_nodes = list(g.nodes)
        subgraph_nodes = [int(node) for node in subgraph_nodes]  # Convert to integer if they are not

        sub_edge_index, _ = subgraph(subgraph_nodes, data.edge_index, relabel_nodes=True)

        sub_data = Data(x=data.x[subgraph_nodes], edge_index=sub_edge_index, y=data.y[subgraph_nodes])
        sub_data_list.append(sub_data)

# for filename in os.listdir(folder_path):
#     if filename.endswith('.gml'):
#         # Read GML file using networkx
#         file_path = os.path.join(folder_path, filename)
#         g = nx.read_gml(file_path)

#         # Convert networkx graph to PyTorch Geometric data
#         sub_data = from_networkx(g)
#         # Ensure node features and labels are set (this will depend on how data is stored in the GML file)

#         # Example: Set dummy features and labels if not present
#         if sub_data.x is None:
#             num_nodes = sub_data.num_nodes
#             sub_data.x = torch.randn((num_nodes, data.num_node_features))  # Replace with actual node features
#         if sub_data.y is None:
#             sub_data.y = torch.randint(0, dataset.num_classes, (sub_data.num_nodes,))  # Replace with actual labels
        
#         sub_data_list.append(sub_data)

In [14]:
print(sub_data_list[0])

Data(x=[18, 1433], edge_index=[2, 58], y=[18])


In [20]:

class GCN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        # First convolutional layer
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        # Second convolutional layer
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)



for i in range(0, 100):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GCN(sub_data.num_node_features, dataset.num_classes).to(device)
    sub_data_train = sub_data_list[i].to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    model.train()
    for epoch in range(200):
        optimizer.zero_grad()
        out = model(sub_data_train)
        loss = F.nll_loss(out, sub_data_train.y)
        loss.backward()
        optimizer.step()
    # print out metrics
    print(loss, " for subgraph ", i)
        






tensor(0.0007, grad_fn=<NllLossBackward0>)  for subgraph  0
tensor(0.0390, grad_fn=<NllLossBackward0>)  for subgraph  1
tensor(0.3194, grad_fn=<NllLossBackward0>)  for subgraph  2
tensor(0.0120, grad_fn=<NllLossBackward0>)  for subgraph  3
tensor(0.2557, grad_fn=<NllLossBackward0>)  for subgraph  4
tensor(0.0633, grad_fn=<NllLossBackward0>)  for subgraph  5
tensor(0.0091, grad_fn=<NllLossBackward0>)  for subgraph  6
tensor(0.1304, grad_fn=<NllLossBackward0>)  for subgraph  7
tensor(0.0790, grad_fn=<NllLossBackward0>)  for subgraph  8
tensor(0.0376, grad_fn=<NllLossBackward0>)  for subgraph  9
tensor(1.0921e-05, grad_fn=<NllLossBackward0>)  for subgraph  10
tensor(4.7087e-06, grad_fn=<NllLossBackward0>)  for subgraph  11
tensor(3.1689e-05, grad_fn=<NllLossBackward0>)  for subgraph  12
tensor(0.2011, grad_fn=<NllLossBackward0>)  for subgraph  13
tensor(0.0498, grad_fn=<NllLossBackward0>)  for subgraph  14
tensor(0.1226, grad_fn=<NllLossBackward0>)  for subgraph  15
tensor(5.4056e-05, gra