In [1]:
import os
import torch 

# Split the installed PyTorch version string on "+" (e.g. "2.4.0+cu124"
# becomes ["2.4.0", "cu124"]) and show both the parts and the raw string.
torch_version = torch.__version__.split("+")
print(torch_version, torch.__version__, sep="\n")

['2.4.0', 'cu124']
2.4.0+cu124


In [2]:
# Export the PyTorch release and its build tag as environment variables.
# A version string without a "+<tag>" suffix splits into a single element,
# so fall back to "cpu" instead of raising IndexError on such builds.
os.environ['TORCH'] = torch_version[0]
os.environ['CUDA'] = torch_version[1] if len(torch_version) > 1 else 'cpu'

In [3]:
# General imports
import os
import json
import collections

# Data science imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly
import scipy.sparse as sp

# Import Weights & Biases for Experiment Tracking
import wandb

# Graph imports
import torch
from torch import Tensor
import torch.nn.functional as F

import torch_geometric
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_networkx

import networkx as nx
from networkx.algorithms import community

from tqdm.auto import trange

In [4]:
from visualize import GraphVisualization 

In [5]:
# Pick the CUDA device when PyTorch can see one, otherwise use the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

In [6]:
# Experiment-tracking settings. Every W&B call in this notebook is guarded by
# `use_wandb`, so setting it to False skips them.
use_wandb = True
wandb_project = "intro_to_pyg"
wandb_run_name = "upload_and_analyze_dataset"

# Open the run that records the dataset statistics, table and artifact.
if use_wandb:
    wandb.init(name=wandb_run_name, project=wandb_project)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mdeepakpokkalla[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
from torch_geometric.datasets import TUDataset

# Root folder handed to TUDataset for its on-disk files.
dataset_path = "../datasets/TUDataset"

# Constructing the dataset already downloads MUTAG when the raw files are
# missing and then processes it, so no explicit `dataset.download()` call is
# needed; calling it again re-fetches the archive on every run.
dataset = TUDataset(root=dataset_path, name='MUTAG')

Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip


In [8]:
# Dataset-level statistics to report, read straight off the dataset object.
_detail_names = (
    "num_node_features",
    "num_edge_features",
    "num_classes",
    "num_node_labels",
    "num_edge_labels",
)
data_details = {name: getattr(dataset, name) for name in _detail_names}

# Without W&B, pretty-print the statistics; otherwise log them to the run.
if not use_wandb:
    print(json.dumps(data_details, sort_keys=True, indent=4))
else:
    wandb.log(data_details)

In [9]:
def create_graph(graph):
    """Build a figure for one PyG graph.

    The graph is converted with ``to_networkx``, positioned with
    ``nx.spring_layout`` and handed to ``GraphVisualization``.

    Args:
        graph: Object accepted by ``torch_geometric.utils.to_networkx``.

    Returns:
        Whatever ``GraphVisualization.create_figure()`` returns.
    """
    nx_graph = to_networkx(graph)
    layout = nx.spring_layout(nx_graph)
    visualizer = GraphVisualization(
        nx_graph,
        layout,
        node_text_position='top left',
        node_size=20,
    )
    return visualizer.create_figure()

# Preview the first graph of the dataset.
first_graph = dataset[0]
fig = create_graph(first_graph)
fig.show()

In [10]:
# Echo the current value of the W&B toggle as the cell output.
use_wandb

True

In [None]:
# for graph in dataset:
#     print(graph)

In [11]:
# Log one table row per graph: rendered figure, node count, edge count, label.
if use_wandb:
    table = wandb.Table(columns=["Graph", "Number of nodes", "Number of edges", "Label"])
    for graph in dataset:
        graph_html = wandb.Html(plotly.io.to_html(create_graph(graph)))
        table.add_data(graph_html, graph.num_nodes, graph.num_edges, graph.y.item())
    wandb.log({'data': table})

In [12]:
# Register the dataset folder as a W&B artifact (with the statistics attached
# as metadata), then close the dataset-upload run.
if use_wandb:
    dataset_artifact = wandb.Artifact(
        name='MUTAG',
        type="dataset",
        metadata=data_details,
    )
    dataset_artifact.add_dir(dataset_path)
    wandb.log_artifact(dataset_artifact)

    wandb.finish()

[34m[1mwandb[0m: Adding directory to artifact (.\..\datasets\TUDataset)... Done. 0.0s


VBox(children=(Label(value='52.111 MB of 541.191 MB uploaded\r'), FloatProgress(value=0.09628927740095047, max…

0,1
num_classes,▁
num_edge_features,▁
num_edge_labels,▁
num_node_features,▁
num_node_labels,▁

0,1
num_classes,2
num_edge_features,4
num_edge_labels,4
num_node_features,7
num_node_labels,7


## Training the model

In [13]:
# Seed PyTorch's RNG before shuffling the dataset.
torch.manual_seed(12345)
dataset = dataset.shuffle()

# First 150 shuffled graphs form the training set; the rest are held out.
train_dataset, test_dataset = dataset[:150], dataset[150:]

print(f'# training graphs: {len(train_dataset)}')
print(f'# testing graphs: {len(test_dataset)}')

# training graphs: 150
# testing graphs: 38


In [14]:
from torch_geometric.loader import DataLoader

# Mini-batch loaders: training graphs are served in shuffled batches of 64,
# test graphs one at a time in a fixed order.
# `shuffle` is a boolean flag; pass True explicitly rather than a truthy int.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Peek at the batches produced by one pass over the training loader.
for step, data in enumerate(train_loader):
    print(f'step:{step+1}, # graphs:{data.num_graphs}')
    print(data)

step:1, # graphs:64
DataBatch(edge_index=[2, 2636], x=[1188, 7], edge_attr=[2636, 4], y=[64], batch=[1188], ptr=[65])
step:2, # graphs:64
DataBatch(edge_index=[2, 2506], x=[1139, 7], edge_attr=[2506, 4], y=[64], batch=[1139], ptr=[65])
step:3, # graphs:22
DataBatch(edge_index=[2, 852], x=[387, 7], edge_attr=[852, 4], y=[22], batch=[387], ptr=[23])


In [15]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

class GCN(torch.nn.Module):
    """Graph-level classifier: three GCN layers, mean pooling, linear head."""

    def __init__(self, hidden_channels, in_channels=None, out_channels=None):
        """Create the layers.

        Args:
            hidden_channels: Width of every GCN layer and of the pooled
                graph embedding.
            in_channels: Number of input node features. Defaults to
                ``dataset.num_node_features`` of the notebook-level dataset.
            out_channels: Number of output classes. Defaults to
                ``dataset.num_classes`` of the notebook-level dataset.
        """
        super().__init__()
        # Fall back to the notebook-global dataset only when the sizes are
        # not given, so the class can also be built without that global.
        if in_channels is None:
            in_channels = dataset.num_node_features
        if out_channels is None:
            out_channels = dataset.num_classes

        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Compute one row of class scores per graph.

        Args:
            x: Node feature matrix.
            edge_index: Edge connectivity passed to each GCN layer.
            batch: Per-node graph assignment used by ``global_mean_pool``.

        Returns:
            Output of the final linear layer (unnormalised class scores).
        """
        # Node embeddings: ReLU after the first two layers, none after the third.
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # Collapse node embeddings into one embedding per graph.
        x = global_mean_pool(x, batch)

        # Dropout is active only in training mode.
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x
    
# Instantiate the network with 64 hidden channels and print its layers.
model = GCN(hidden_channels=64)
print(model)


GCN(
  (conv1): GCNConv(7, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=True)
)


In [16]:
wandb_project = "intro_to_pyg" #@param {type:"string"}
wandb_run_name = "upload_and_analyze_dataset" #@param {type:"string"}
# NOTE(review): `wandb_run_name` is not passed to `wandb.init` below, and its
# value still describes the dataset-upload step. Confirm whether the training
# run should get its own name.

# Initialize W&B run for training
if use_wandb:
    # Use the configured project name instead of repeating the literal.
    wandb.init(project=wandb_project)
    # Declare the MUTAG dataset artifact as an input of this run.
    wandb.use_artifact("deepakpokkalla/intro_to_pyg/MUTAG:v0")

In [17]:
# Fresh model, plus the loss and optimiser shared by train() and test().
model = GCN(hidden_channels=64)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

def train():
    """Run one optimisation pass over ``train_loader``.

    Works on the notebook-level ``model``, ``criterion`` and ``optimizer``.
    Each mini-batch gets a forward pass, a backward pass, a parameter update
    and a gradient reset. Returns nothing.
    """
    # Training mode, so the dropout in GCN.forward is applied.
    model.train()

    for batch in train_loader:
        logits = model(batch.x, batch.edge_index, batch.batch)
        criterion(logits, batch.y).backward()
        optimizer.step()
        optimizer.zero_grad()

In [18]:
def test(loader, create_table=False):
    """Evaluate the notebook-level ``model`` on every batch of ``loader``.

    Args:
        loader: Iterable of PyG batches exposing ``x``, ``edge_index``,
            ``batch``, ``y`` and ``num_graphs``. ``loader.dataset`` must
            support ``len()``.
        create_table: When true and ``use_wandb`` is set, one row per graph
            (rendered figure, ground truth, prediction) is added to the table.

    Returns:
        Tuple ``(accuracy, mean_loss, table)``. Accuracy and loss are averaged
        over ``len(loader.dataset)``. ``table`` is a ``wandb.Table`` when
        ``use_wandb`` is set (empty unless ``create_table``), otherwise None.
    """
    model.eval()
    table = wandb.Table(columns=['graph', 'ground_truth', 'prediction']) if use_wandb else None
    correct = 0
    loss_sum = 0.0

    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for data in loader:
            out = model(data.x, data.edge_index, data.batch)
            # `criterion` (CrossEntropyLoss with its default reduction) returns
            # the batch mean, so weight it by the batch size before summing.
            # Otherwise dividing by the dataset size under-reports the loss
            # for any batch size above 1.
            loss_sum += criterion(out, data.y).item() * data.num_graphs
            pred = out.argmax(dim=1)

            if create_table and use_wandb:
                # Split the batch back into its graphs so this also works for
                # batch sizes above 1, where `.item()` would raise.
                rows = zip(data.to_data_list(), data.y.tolist(), pred.tolist())
                for graph, truth, guess in rows:
                    table.add_data(
                        wandb.Html(plotly.io.to_html(create_graph(graph))),
                        truth,
                        guess,
                    )

            correct += int((pred == data.y).sum())

    n_graphs = len(loader.dataset)
    return correct / n_graphs, loss_sum / n_graphs, table

In [19]:
# Number of full passes over the training set.
NUM_EPOCHS = 170

# `trange` (imported from tqdm.auto at the top of the notebook) iterates over
# the same epoch numbers as `range` and shows a progress bar for the long run.
for epoch in trange(1, NUM_EPOCHS + 1, desc="epoch"):
    train()
    train_acc, train_loss, _ = test(train_loader)
    test_acc, test_loss, test_table = test(test_loader, create_table=True)

    # Log metrics to W&B
    if use_wandb:
        wandb.log({
            "train/loss": train_loss,
            "train/acc": train_acc,
            "test/acc": test_acc,
            "test/loss": test_loss,
            "test/table": test_table
        })

    # NOTE(review): the checkpoint is written and uploaded on every epoch,
    # and the prediction table is rebuilt every epoch as well. Confirm this
    # is intended rather than a once-after-training step.
    torch.save(model, "graph_classification_model.pt")

    # Log model checkpoint as an artifact to W&B
    if use_wandb:
        artifact = wandb.Artifact(name="graph_classification_model", type="model")
        artifact.add_file("graph_classification_model.pt")
        wandb.log_artifact(artifact)


# Finish the W&B run
if use_wandb:
    wandb.finish()

: 