In [2]:
import os
import networkx as nx
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import from_networkx
from torch_geometric.utils import subgraph
from torch_geometric.data import Data

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]


folder_path = 'client_subgraphs'
sub_data_list = []

for filename in os.listdir(folder_path):
    if filename.endswith('.gml'):
      
        file_path = os.path.join(folder_path, filename)
        g = nx.read_gml(file_path)

        subgraph_nodes = list(g.nodes)
        subgraph_nodes = [int(node) for node in subgraph_nodes]  # Convert to integer if they are not

        sub_edge_index, _ = subgraph(subgraph_nodes, data.edge_index, relabel_nodes=True)

        sub_data = Data(x=data.x[subgraph_nodes], edge_index=sub_edge_index, y=data.y[subgraph_nodes])
        sub_data_list.append(sub_data)

# for filename in os.listdir(folder_path):
#     if filename.endswith('.gml'):
#         # Read GML file using networkx
#         file_path = os.path.join(folder_path, filename)
#         g = nx.read_gml(file_path)

#         # Convert networkx graph to PyTorch Geometric data
#         sub_data = from_networkx(g)
#         # Ensure node features and labels are set (this will depend on how data is stored in the GML file)

#         # Example: Set dummy features and labels if not present
#         if sub_data.x is None:
#             num_nodes = sub_data.num_nodes
#             sub_data.x = torch.randn((num_nodes, data.num_node_features))  # Replace with actual node features
#         if sub_data.y is None:
#             sub_data.y = torch.randint(0, dataset.num_classes, (sub_data.num_nodes,))  # Replace with actual labels
        
#         sub_data_list.append(sub_data)

In [4]:
print(sub_data_list[0])

Data(x=[17, 1433], edge_index=[2, 58], y=[17])


In [5]:
# #plot subgraph 0 to make sure it looks right
# import matplotlib.pyplot as plt
# import networkx as nx
# from torch_geometric.utils import to_networkx

# G = to_networkx(sub_data_list[82], to_undirected=False)
# plt.figure(figsize=(20,20))
# nx.draw(G, with_labels=True, node_size=15, node_color='g', edge_color='b')
# plt.show()


In [6]:

class GCN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        # First convolutional layer
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        # Second convolutional layer
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


In [8]:
#split into train and validation
from torch_geometric.data import DataLoader
from sklearn.model_selection import train_test_split


In [16]:
#keep a  list of training and validation loss per epoch for each subgraph
train_losses = []
val_losses = []

In [39]:
import numpy as np
def train_val_split(data):
        #handle labels with only one sample
    # `data.x` contains the node features and `data.y` contains the labels

    # Count the occurrences of each label in the dataset
    label_counts = torch.bincount(data.y)

    # Find the labels that appear only once (single occurrence)
    single_occurrence_labels = torch.nonzero(label_counts == 1).flatten()

    # Initialize empty lists to store samples
    single_sample_label = []
    other_label = []

    # Separate samples based on labels
    for i, label in enumerate(data.y):
        if label in single_occurrence_labels:
            single_sample_label.append(i)
        else:
            other_label.append(i)

    # Convert the lists of sample indices into tensors
    single_sample_label = torch.tensor(single_sample_label)
    other_label = torch.tensor(other_label)

    # Extract the corresponding node features and labels
    single_sample_x = data.x[single_sample_label]
    single_sample_y = data.y[single_sample_label]
    other_x = data.x[other_label]
    other_y = data.y[other_label]
    
    other_x_train, other_x_test, other_y_train, other_y_test = train_test_split(other_x, other_y, test_size=0.2, random_state=42)

    # Add single_sample_x and single_sample_y to the training set
    combined_x_train = torch.cat((other_x_train, single_sample_x), dim=0)
    combined_y_train = torch.cat((other_y_train, single_sample_y), dim=0)
    
    return combined_x_train, combined_y_train, other_x_test, other_y_test
                    
 
    

In [41]:
for i in range(0, 100):
    
    sub_data = sub_data_list[i]
    #split sub_data.x and sub_data.y into train and validation, but keep the same edge_index
    train_x, train_y, val_x, val_y = train_val_split(sub_data)
    continue
    sub_data_train = Data(x=train_x, edge_index=sub_data.edge_index, y=train_y)
    sub_data_val = Data(x=val_x, edge_index=sub_data.edge_index, y=val_y)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GCN(sub_data.num_node_features, dataset.num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    
    client_train_losses = []
    client_val_losses = []
    
    for epoch in range(200):
        model = model.train()
        optimizer.zero_grad()
        out = model(sub_data_train)
        loss = F.nll_loss(out, sub_data_train.y)
        loss.backward()
        optimizer.step()
        
        client_train_losses.append(loss.item())
        
        #calculate validation loss
        model = model.eval()
        out_val = model(sub_data_val)
        loss = F.nll_loss(out_val, sub_data_val.y)
        client_val_losses.append(loss.item())
        
        
    train_losses.append(client_train_losses)
    val_losses.append(client_val_losses)
    # print out metrics
    #calculate final training accuracy
    true_labels = sub_data_train.y
    _, pred = out.max(1)
    train_correct = pred.eq(true_labels).sum().item()
    print("final training loss: ", loss.item(), " for subgraph ", i)
    print("final training accuracy: ", train_correct / len(true_labels), " for subgraph ", i)
    print()
    #calculate validation accuracy and loss
    true_labels = sub_data_val.y
    _, pred = out_val.max(1)
    val_correct = pred.eq(true_labels).sum().item()
    
    
    print("final validation loss: ", loss.item(), " for subgraph ", i)
    print("final validation accuracy: ", val_correct / len(true_labels), " for subgraph ", i)
    print("------------------------------------------------------------")

IndexError: tensors used as indices must be long, byte or bool tensors