### Libraries


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import numpy as np
import networkx as nx

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
import plotly.graph_objects as go

### Graph Convolutional Network

In [None]:
class GCNModel(nn.Module):
    """Two-layer graph convolutional network built from ``GCNConv`` layers.

    ``forward`` exposes both the hidden activations and the final-layer
    activations so callers can inspect the intermediate representation.

    NOTE(review): ReLU is applied to the final layer as well, so the second
    returned tensor is non-negative. If it is meant to be class logits, that
    activation should probably be dropped -- confirm intent.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, data):
        """Run both convolutions over ``data.x`` / ``data.edge_index``.

        Returns:
            Tuple ``(hidden, output)``: the ReLU-activated outputs of
            ``conv1`` and ``conv2`` respectively.
        """
        node_feats = data.x
        edge_index = data.edge_index
        hidden = self.conv1(node_feats, edge_index).relu()
        output = self.conv2(hidden, edge_index).relu()
        return hidden, output



### Graph Isomorphism Network

In [None]:
import torch.optim as optim
from torch_geometric.nn import global_add_pool


In [None]:
class MLP(nn.Module):
    """Two-layer multi-layer perceptron used as the GIN update function.

    Computes ``Linear -> BatchNorm1d -> ReLU -> Linear``; neither linear
    layer has a bias term.

    Args:
        input_dim: size of the incoming features.
        hidden_dim: size of the intermediate representation (also the number
            of features normalised by ``batch_norm``).
        output_dim: size of the produced features.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        first = nn.Linear(input_dim, hidden_dim, bias=False)
        second = nn.Linear(hidden_dim, output_dim, bias=False)
        # Stored as a ModuleList named ``linears`` so state_dict keys stay
        # ``linears.0.*`` / ``linears.1.*``.
        self.linears = nn.ModuleList([first, second])
        self.batch_norm = nn.BatchNorm1d(hidden_dim)

    def forward(self, x):
        """Map ``x`` (expected 2-D, ``(N, input_dim)``) to ``(N, output_dim)``."""
        first, second = self.linears
        hidden = self.batch_norm(first(x))
        return second(F.relu(hidden))


class GIN(nn.Module):
    """Graph Isomorphism Network producing graph-level scores.

    Stacks ``num_layers - 1`` ``GINConv`` layers (each wrapping an ``MLP``
    followed by BatchNorm and ReLU). The node representations from the input
    and from every GIN layer are each sum-pooled, mapped to ``output_dim``
    by a per-layer linear head, passed through dropout, and summed.

    NOTE(review): ``GINConv`` and ``SumPooling`` are not imported anywhere in
    this notebook, so constructing this class raises ``NameError`` as
    written. The calls ``GINConv(mlp, learn_eps=...)``, ``layer(g, h)`` and
    ``self.pool(g, h)`` match DGL's ``dgl.nn.pytorch`` API rather than
    PyTorch Geometric's -- confirm and add the matching imports.

    Args:
        input_dim: size of the input node features.
        hidden_dim: size of the hidden node representations.
        output_dim: size of the returned score vector.
        num_layers: number of layers counting the input layer; the network
            has ``num_layers - 1`` GIN layers. Defaults to 2 (the value that
            used to be hard-coded). Must be >= 1.
        dropout: dropout probability applied to each per-layer score.
        learn_eps: forwarded to ``GINConv``; whether epsilon is learned.

    Raises:
        ValueError: if ``num_layers`` < 1.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2,
                 dropout=0.5, learn_eps=False):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.ginlayers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        # num_layers counts the input layer, so only num_layers - 1 GIN layers
        # are built; each uses a two-layer MLP as its update function.
        for layer in range(num_layers - 1):
            in_dim = input_dim if layer == 0 else hidden_dim
            mlp = MLP(in_dim, hidden_dim, hidden_dim)
            self.ginlayers.append(GINConv(mlp, learn_eps=learn_eps))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        # One linear head per pooled representation: index 0 reads the raw
        # input features, the rest read the GIN layer outputs.
        self.linear_prediction = nn.ModuleList(
            nn.Linear(input_dim if layer == 0 else hidden_dim, output_dim)
            for layer in range(num_layers)
        )
        self.drop = nn.Dropout(dropout)
        self.pool = SumPooling()

    def forward(self, g, h):
        """Return the sum of per-layer pooled scores.

        Args:
            g: graph object passed to every ``GINConv`` and to ``SumPooling``
                (presumably a batched DGL graph -- confirm).
            h: node feature tensor for ``g`` (presumably
                ``(num_nodes, input_dim)`` -- confirm).

        Returns:
            Tensor of per-graph scores with last dimension ``output_dim``.
        """
        hidden_rep = [h]
        for conv, norm in zip(self.ginlayers, self.batch_norms):
            h = F.relu(norm(conv(g, h)))
            hidden_rep.append(h)
        score_over_layer = 0
        # Sum-pool every layer's node representation and add up the scores.
        for rep, head in zip(hidden_rep, self.linear_prediction):
            pooled_h = self.pool(g, rep)
            score_over_layer += self.drop(head(pooled_h))
        return score_over_layer



def evaluate(dataloader, device, model):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode and left there, and the loop runs
    under ``torch.no_grad()`` so no autograd graph is built.

    Args:
        dataloader: iterable of ``(batched_graph, labels)`` pairs. Each graph
            must support ``.to(device)`` and hold node features under
            ``ndata["attr"]``; that entry is popped from the graph.
        device: device the graph and labels are moved to.
        model: called as ``model(graph, features)``; must return scores with
            classes along dimension 1.

    Returns:
        Fraction (float in [0, 1]) of samples whose arg-max prediction
        equals the label.

    Raises:
        ZeroDivisionError: if the dataloader yields no samples.
    """
    model.eval()
    total = 0
    total_correct = 0
    # Inference only: skip gradient tracking to save memory and time.
    with torch.no_grad():
        for batched_graph, labels in dataloader:
            batched_graph = batched_graph.to(device)
            labels = labels.to(device)
            feat = batched_graph.ndata.pop("attr")
            logits = model(batched_graph, feat)
            predicted = logits.argmax(dim=1)
            total += len(labels)
            total_correct += (predicted == labels).sum().item()
    return total_correct / total

### VISUALS

In [None]:
# Untrained GCN used only to demonstrate the 3-D visualisation below.
model = GCNModel(input_dim=2, hidden_dim=1, output_dim=3)

# Placeholder inputs -- swap in real node features and edges.
# 300 nodes, each with 2 random features drawn from [0, 1).
x = torch.rand(300, 2)
# Edges 0<->1 and 1<->2, each listed in both directions; every other node
# is isolated.
edge_index = torch.tensor(
    [[0, 1, 1, 2],
     [1, 0, 2, 1]],
    dtype=torch.long,
)

# Bundle features and connectivity into a PyTorch Geometric graph.
data = Data(edge_index=edge_index, x=x)

# Visualize the model in 3D
def visualize_model_3d(model, data, seed=None):
    """Render the graph in ``data`` as an interactive 3-D Plotly figure.

    The model is run once in eval mode without gradients, as a check that it
    accepts ``data``; its output is not used in the plot. Nodes are placed
    with a 3-D spring layout and coloured by node index (matching the
    "Node ID" colour bar). Edges are drawn as grey line segments.

    Args:
        model: module callable as ``model(data)``; left in eval mode.
        data: graph object accepted by ``torch_geometric.utils.to_networkx``.
        seed: optional seed for ``nx.spring_layout`` to make the layout
            reproducible; ``None`` keeps the previous random layout.
    """
    model.eval()
    with torch.no_grad():
        # NOTE(review): the forward pass only validates the model; its
        # output could be used for colouring/hover text if that is intended.
        model(data)

    G = to_networkx(data)
    pos_3d = nx.spring_layout(G, dim=3, seed=seed)

    nodes = list(G.nodes)
    pos_x = [pos_3d[node][0] for node in nodes]
    pos_y = [pos_3d[node][1] for node in nodes]
    pos_z = [pos_3d[node][2] for node in nodes]

    # Edge segments; the trailing None breaks the line between edges.
    edge_x, edge_y, edge_z = [], [], []
    for u, v in G.edges:
        for coords, axis in ((edge_x, 0), (edge_y, 1), (edge_z, 2)):
            coords.extend((pos_3d[u][axis], pos_3d[v][axis], None))

    edge_trace = go.Scatter3d(
        x=edge_x,
        y=edge_y,
        z=edge_z,
        mode='lines',
        line=dict(color='rgb(125, 125, 125)', width=1),
        hoverinfo='none',
    )

    node_trace = go.Scatter3d(
        x=pos_x,
        y=pos_y,
        z=pos_z,
        mode='markers',
        text=[f'Node {node}' for node in nodes],
        hoverinfo='text',
        marker=dict(
            size=12,
            # A numeric colour array is required for colorscale/colorbar to
            # take effect; colour by node id to match the colour-bar title.
            color=nodes,
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(
                title='Node ID',
            ),
            line=dict(color='rgb(140, 140, 170)', width=0.5),
            opacity=0.9,
        ),
    )

    layout = go.Layout(
        showlegend=False,
        scene=dict(
            xaxis=dict(title='X'),
            yaxis=dict(title='Y'),
            zaxis=dict(title='Z'),
        ),
    )

    fig = go.Figure(data=[edge_trace, node_trace], layout=layout)
    fig.show()

# Render the dummy graph defined above with the untrained model.
visualize_model_3d(model=model, data=data)