In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import random

# Reproducibility: seed every RNG this notebook draws from. The networkx generator
# gets an explicit seed; torch drives the node features below and, later, the
# random labels and the layers' weight initialisation.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Create a synthetic undirected graph
num_nodes = 20
num_edges = 40
graph = nx.gnm_random_graph(num_nodes, num_edges, seed=SEED)

# Convert the graph to PyTorch Geometric's COO `edge_index` of shape (2, num_directed_edges).
# networkx lists each undirected edge once, but GCNConv treats edge_index as directed
# (messages flow source -> target), so every edge is added in both directions.
# reshape(-1, 2) keeps the (E, 2) layout even if the graph has no edges.
edges = torch.tensor(list(graph.edges), dtype=torch.long).reshape(-1, 2)
edge_index = torch.cat([edges, edges.flip(1)], dim=0).t().contiguous()
x = torch.randn(num_nodes, 16)  # Random node features (16 must match input_dim below)

# Define a simple Graph Neural Network (GNN) model
class GNN(nn.Module):
    """Plain stack of GCN layers with a ReLU after every layer except the last.

    Args:
        input_dim: number of input features per node.
        hidden_dim: width of every hidden GCN layer.
        output_dim: number of output channels per node (e.g. class logits).
        num_layers: total number of GCNConv layers; must be >= 1. With 1 layer the
            model is a single GCNConv(input_dim, output_dim).

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.num_layers = num_layers
        # Channel sizes along the stack: input -> hidden (num_layers - 1 times) -> output.
        # Building from this list guarantees exactly `num_layers` layers, which the
        # previous first/middle/last construction did not when num_layers == 1.
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.conv_layers = nn.ModuleList(
            GCNConv(dims[i], dims[i + 1]) for i in range(num_layers)
        )

    def forward(self, x, edge_index):
        """Return per-node outputs of shape (num_nodes, output_dim)."""
        for conv in self.conv_layers[:-1]:
            x = F.relu(conv(x, edge_index))
        # No activation on the final layer: its output is used as raw logits.
        return self.conv_layers[-1](x, edge_index)

# Modify the GNN model to include skip connections with linear projection
class GNNWithSkipConnections(nn.Module):
    """GCN stack with a linear-projection residual around every hidden layer.

    Hidden layer ``i`` computes ``relu(conv_i(h) + skip_i(h))``, where ``h`` is that
    layer's input and ``skip_i`` projects it to ``hidden_dim``. The final layer is a
    plain GCNConv producing ``output_dim`` channels with no activation.

    Args:
        input_dim: number of input features per node.
        hidden_dim: width of every hidden GCN layer.
        output_dim: number of output channels per node.
        num_layers: total number of GCNConv layers; must be >= 1.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.num_layers = num_layers
        # Channel sizes along the stack: input -> hidden (num_layers - 1 times) -> output.
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.conv_layers = nn.ModuleList(
            GCNConv(dims[i], dims[i + 1]) for i in range(num_layers)
        )
        # One projection per hidden layer, mapping that layer's INPUT width
        # (input_dim for the first, hidden_dim afterwards) to hidden_dim.
        self.skip_connections = nn.ModuleList(
            nn.Linear(dims[i], hidden_dim) for i in range(num_layers - 1)
        )

    def forward(self, x, edge_index):
        """Return per-node outputs of shape (num_nodes, output_dim)."""
        h = x
        for conv, skip in zip(self.conv_layers[:-1], self.skip_connections):
            # The skip must project the current `h`, not the original `x`: skip_i for
            # i >= 1 expects hidden_dim inputs, so feeding `x` fails whenever
            # input_dim != hidden_dim.
            h = F.relu(conv(h, edge_index) + skip(h))
        return self.conv_layers[-1](h, edge_index)

In [None]:
# Train the GNN model
def train_model(model, x, edge_index, num_epochs=100, lr=0.01, y=None):
    """Full-batch node-classification training with Adam and cross-entropy.

    Args:
        model: module called as ``model(x, edge_index)`` that returns per-node logits.
        x: node feature tensor, passed straight to the model.
        edge_index: graph connectivity, passed straight to the model.
        num_epochs: number of optimisation steps (one per epoch).
        lr: Adam learning rate.
        y: per-node integer class labels. If None, random binary labels are drawn
            once with ``torch.randint`` and kept fixed for the whole run.

    Returns:
        list[float]: the training loss recorded at every epoch.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    if y is None:
        # Placeholder binary task. Labels are sampled ONCE, before the loop:
        # re-sampling them every epoch would give the model a moving target it
        # can never fit.
        y = torch.randint(0, 2, (x.size(0),))
    model.train()
    losses = []
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        output = model(x, edge_index)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')
    return losses
        

# Initialize and train the GNN
input_dim = 16  # Dimensionality of node features (must match x.size(1))
hidden_dim = 32  # Hidden layer dimension
output_dim = 2  # Output dimension (binary classification task)
num_layers = 8  # Number of GNN layers
use_skip_connections = False  # True -> GNNWithSkipConnections (residual variant)

# Re-seed torch so re-running this cell (or switching variants) starts from the same
# weight initialisation and the same random labels.
torch.manual_seed(42)
model_cls = GNNWithSkipConnections if use_skip_connections else GNN
model = model_cls(input_dim, hidden_dim, output_dim, num_layers)
train_model(model, x, edge_index)

# Visualize the node representations at different layers
def visualize_node_representations(model, x, edge_index, include_output=False):
    """Collect per-layer node representations (as numpy arrays) for plotting.

    Replays the model layer by layer: each of the first ``model.num_layers - 1``
    entries of ``model.conv_layers`` is applied and followed by ReLU, mirroring the
    hidden loop of ``GNN.forward``. Skip connections are NOT replayed, so for
    ``GNNWithSkipConnections`` these hidden representations differ from what that
    model's own forward pass computes.

    The model is put in eval mode for the pass and its previous train/eval mode is
    restored afterwards.

    Args:
        model: module exposing ``num_layers`` and an indexable ``conv_layers``.
        x: input node features.
        edge_index: graph connectivity passed to every layer.
        include_output: if True, also append the output of the final layer
            (``model.conv_layers[-1]``, no ReLU).

    Returns:
        list of numpy arrays: ``[input features, hidden layer 1, ...,
        hidden layer num_layers - 1]``, plus the output layer if ``include_output``.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # detach()/cpu() so conversion works for grad-requiring or GPU tensors.
            node_reps = [x.detach().cpu().numpy()]
            h = x
            for i in range(model.num_layers - 1):
                h = F.relu(model.conv_layers[i](h, edge_index))
                node_reps.append(h.cpu().numpy())
            if include_output:
                h = model.conv_layers[-1](h, edge_index)
                node_reps.append(h.cpu().numpy())
    finally:
        # Do not leave the model stuck in eval mode as a side effect.
        model.train(was_training)
    return node_reps

# Per-layer node embeddings of the trained model, consumed by the plotting cell.
node_reps = visualize_node_representations(model=model, x=x, edge_index=edge_index)


In [None]:
# Plot node representations across layers
# node_reps[0] holds the raw input features; node_reps[k] (k >= 1) is the output of
# GCN layer k. Each panel scatters the first two feature dimensions of every node,
# coloured by node index. (After ReLU these dimensions can be all zero for some layers.)
layer_names = ['Input features'] + [f'Layer {i}' for i in range(1, len(node_reps))]
# squeeze=False keeps a 2-D array even with a single panel, so indexing always works.
fig, axs = plt.subplots(1, len(node_reps), figsize=(15, 5), squeeze=False)
axs = axs[0]
for ax, rep, name in zip(axs, node_reps, layer_names):
    ax.scatter(rep[:, 0], rep[:, 1], c=np.arange(rep.shape[0]), cmap='viridis')
    ax.set_title(name)
    ax.set_xlabel('feature 0')
    ax.set_ylabel('feature 1')
plt.tight_layout()
plt.show()
