# Graph Classification

## Graph Classes

We use the following graph classes:
- Chemical Graphs (Molecules)
- Random
- Small World
- Scale Free

## Setup

We use PyG (PyTorch Geometric) to build and train the model.
The model is a GCN with 2 layers and 64 hidden units (`hidden_channels=64` in the training cell).

In [1]:
import torch_geometric
import torch

# Report the library / backend versions this notebook was run with, one per
# line: torch, its CUDA build, MPS availability, torch_geometric.
for environment_fact in (
    torch.__version__,
    torch.version.cuda,
    torch.backends.mps.is_available(),
    torch_geometric.__version__,
):
    print(environment_fact)

1.13.1+cu117
11.7
False
2.2.0


In [2]:
# General imports
import os
import json
import collections

# Data science imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly
import scipy.sparse as sp

# Import Weights & Biases for Experiment Tracking
import wandb

# Graph imports
import torch
from torch import Tensor
import torch.nn.functional as F

import torch_geometric
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_networkx

import networkx as nx
from networkx.algorithms import community

from tqdm.auto import trange

In [3]:
# Experiment-tracking configuration. The trailing `# @param` markers are
# Colab form annotations and must stay on the assignment lines.
# `use_wandb` gates every W&B call in the cells below.
use_wandb = True  # @param {type:"boolean"}
wandb_project = "bt-hostettler"  # @param {type:"string"}
wandb_run_name = "upload_and_analyze_dataset"  # @param {type:"string"}

# Start the W&B run used for uploading and analysing the dataset.
if use_wandb:
    wandb.init(project=wandb_project, name=wandb_run_name)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlucahost[0m ([33mbt-hostettler[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import from_networkx

# we need to read the dataset from the pickle file
import pickle

# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load `dataset.pickle` when it comes from a trusted source.
with open('dataset.pickle', 'rb') as f:
    dataset = pickle.load(f)

# `dataset` maps each class name to a collection of networkx graphs.
for key, graphs in dataset.items():
    print(f"datasetkey: {key}, shape: {len(graphs)}")

# Convert the networkx graphs to PyTorch Geometric graphs.

# Class name -> integer label, following the dict's iteration order.
class_key = {key: index for index, key in enumerate(dataset)}
print(class_key)

# Use the same node attributes as MyOwnDataset.process so both conversion
# paths yield feature matrices of identical width.
NODE_ATTRS = ["label", "betweenness", "degree", "density"]

pyg_dataset = []
for key, graphs in dataset.items():
    for graph in graphs:
        graph_tensor = from_networkx(graph, group_node_attrs=NODE_ATTRS)
        graph_tensor.y = torch.tensor([class_key[key]])
        pyg_dataset.append(graph_tensor)

datasetkey: random, shape: 1000
datasetkey: smallworld, shape: 1000
datasetkey: scalefree, shape: 200
datasetkey: complete, shape: 200
datasetkey: line, shape: 200
datasetkey: tree, shape: 400
datasetkey: star, shape: 200
{'random': 0, 'smallworld': 1, 'scalefree': 2, 'complete': 3, 'line': 4, 'tree': 5, 'star': 6}


In [5]:
import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.utils import from_networkx
import pickle


class MyOwnDataset(InMemoryDataset):
    """In-memory PyG dataset built from a pickled collection of networkx graphs.

    The pickle at ``pickle_path`` is read as a mapping from class name to a
    collection of networkx graphs. Each graph becomes one PyG data object whose
    ``x`` stacks the node attributes listed in ``NODE_ATTRS`` and whose ``y``
    is the integer index of its class (the position of the class name in the
    mapping's iteration order).

    Args:
        root: Root directory handed to ``InMemoryDataset``; the processed
            file is looked up via ``self.processed_paths``.
        transform: Forwarded unchanged to ``InMemoryDataset``.
        pre_transform: If given, applied to every graph in ``process`` before
            the collated result is saved.
        pre_filter: If given, graphs for which it returns a falsy value are
            dropped in ``process``.
        pickle_path: Location of the source pickle. Defaults to
            ``'dataset.pickle'`` (relative to the working directory), which is
            the path this class previously had hard-coded.
    """

    # Node attributes grouped (in this order) into the feature matrix `x`.
    NODE_ATTRS = ["label", "betweenness", "degree", "density"]

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None,
                 pickle_path='dataset.pickle'):
        # Must be assigned before super().__init__(): the base constructor may
        # call process(), which reads this attribute.
        self.pickle_path = pickle_path
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Nothing is downloaded into the raw directory; the source pickle is
        # read directly from `pickle_path` in process(). (The previous
        # template placeholders, including a literal Ellipsis, were not real
        # file names.)
        return []

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        """Convert the pickled networkx graphs and save the collated result."""
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only point `pickle_path` at trusted data.
        with open(self.pickle_path, 'rb') as f:
            dataset = pickle.load(f)

        # Class name -> integer label, following the mapping's iteration order.
        class_key = {key: index for index, key in enumerate(dataset)}
        print(class_key)

        data_list = []
        for key, graphs in dataset.items():
            for graph in graphs:
                graph_tensor = from_networkx(graph, group_node_attrs=self.NODE_ATTRS)
                graph_tensor.y = torch.tensor([class_key[key]])
                data_list.append(graph_tensor)

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [6]:
import torch
from torch_geometric.datasets import TUDataset

# Build (or load the cached) dataset under data/CustomData.
pyg_dataset = MyOwnDataset(root='data/CustomData')

# --- Dataset-level summary -------------------------------------------------
print()
print(f'Dataset: {pyg_dataset}:')
print('====================')
for description, value in (
    ('Number of graphs', len(pyg_dataset)),
    ('Number of features', pyg_dataset.num_features),
    ('Number of classes', pyg_dataset.num_classes),
):
    print(f'{description}: {value}')

# --- Closer look at the first graph ----------------------------------------
data = pyg_dataset[0]

print()
print(data)
print(data.x)
print('=============================================================')

first_graph_stats = {
    'Number of nodes': data.num_nodes,
    'Number of edges': data.num_edges,
    # Pre-formatted so the shared print loop keeps the two-decimal output.
    'Average node degree': f'{data.num_edges / data.num_nodes:.2f}',
    'Has isolated nodes': data.has_isolated_nodes(),
    'Has self-loops': data.has_self_loops(),
    'Is undirected': data.is_undirected(),
}
for description, value in first_graph_stats.items():
    print(f'{description}: {value}')

# Summary reused later as metadata of the dataset artifact.
data_details = {
    "num_node_features": pyg_dataset.num_node_features,
    "num_edge_features": pyg_dataset.num_edge_features,
    "num_classes": pyg_dataset.num_classes,
}

if not use_wandb:
    print(json.dumps(data_details, sort_keys=True, indent=4))
else:
    # Log all the details about the data to W&B.
    wandb.log(data_details)  # 🪄🐝


Processing...


{'random': 0, 'smallworld': 1, 'scalefree': 2, 'complete': 3, 'line': 4, 'tree': 5, 'star': 6}

Dataset: MyOwnDataset(3200):
Number of graphs: 3200
Number of features: 4
Number of classes: 7

Data(edge_index=[2, 12], x=[5, 4], y=[1])
tensor([[0.0000, 0.0833, 0.5000, 0.6000],
        [1.0000, 0.0000, 0.5000, 0.6000],
        [2.0000, 0.0833, 0.5000, 0.6000],
        [3.0000, 0.2500, 0.7500, 0.6000],
        [4.0000, 0.2500, 0.7500, 0.6000]])
Number of nodes: 5
Number of edges: 12
Average node degree: 2.40
Has isolated nodes: False
Has self-loops: False
Is undirected: True


Done!


In [7]:
from visualize import GraphVisualization

def create_graph(graph):
    """Build a figure visualising one PyG graph.

    The graph is converted with ``to_networkx``, node positions come from
    ``nx.spring_layout`` (unseeded, so the layout differs between calls), and
    the drawing is delegated to ``GraphVisualization``.

    Returns:
        Whatever ``GraphVisualization.create_figure()`` returns — presumably a
        Plotly figure, since callers pass it to ``plotly.io.to_html``.
    """
    nx_graph = to_networkx(graph)
    layout = nx.spring_layout(nx_graph)
    return GraphVisualization(
        nx_graph, layout, node_text_position='top left', node_size=20,
    ).create_figure()


# Preview: render the first graph of the dataset and display the figure.
fig = create_graph(pyg_dataset[0])
fig.show()


In [8]:
from tqdm.notebook import tqdm
if use_wandb:
    # Log exploratory visualizations for each data point to W&B:
    # one table row per graph with its rendering, size and class label.
    table = wandb.Table(
        columns=["Graph", "Number of Nodes", "Number of Edges", "Label"])
    for graph in tqdm(pyg_dataset, total=len(pyg_dataset), desc='Generating Graphs for W&B'):
        fig = create_graph(graph)
        # By default to_html embeds the complete plotly.js bundle (several MB)
        # in every HTML string; with thousands of graphs that exhausts memory
        # and bloats the upload. Referencing the library from the CDN keeps
        # each row small.
        graph_html = plotly.io.to_html(fig, include_plotlyjs='cdn')

        table.add_data(wandb.Html(graph_html),
                       graph.num_nodes, graph.num_edges, graph.y.item())
    wandb.log({"data": table})


Generating Graphs for W&B:   0%|          | 0/3200 [00:00<?, ?it/s]

In [9]:
# Directory holding the dataset files written by MyOwnDataset.
dataset_path = "data/CustomData"

if use_wandb:
    # Bundle the dataset directory into a W&B artifact, tagged with the
    # summary collected in `data_details`.
    dataset_artifact = wandb.Artifact(
        name="BT-Data",
        type="dataset",
        metadata=data_details,
    )
    dataset_artifact.add_dir(dataset_path)
    wandb.log_artifact(dataset_artifact)

    # The upload/analysis run ends here; training starts its own run later.
    wandb.finish()

[34m[1mwandb[0m: Adding directory to artifact (.\data\CustomData)... Done. 0.6s


0,1
num_classes,▁
num_edge_features,▁
num_node_features,▁

0,1
num_classes,7
num_edge_features,0
num_node_features,4


In [10]:
# Fixed seed so the shuffle (and therefore the split) is reproducible.
torch.manual_seed(12345)
pyg_dataset = pyg_dataset.shuffle()

# Share of graphs used for training. Deriving the split index from the dataset
# size keeps the split valid if the dataset grows or shrinks
# (3200 graphs -> 2560 train / 640 test, as before).
TRAIN_FRACTION = 0.8
num_train_graphs = int(TRAIN_FRACTION * len(pyg_dataset))

train_dataset = pyg_dataset[:num_train_graphs]
test_dataset = pyg_dataset[num_train_graphs:]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

Number of training graphs: 2560
Number of test graphs: 640


In [11]:
from torch_geometric.loader import DataLoader

# Training batches hold 64 graphs and are reshuffled on every pass.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Test graphs are served one at a time: the evaluation code calls `.item()`
# on the label and prediction of each batch.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Sanity check: print the composition of every training batch.
for batch_number, batch in enumerate(train_loader, start=1):
    print(f'Step {batch_number}:')
    print('=======')
    print(f'Number of graphs in the current batch: {batch.num_graphs}')
    print(batch)
    print()

Step 1:
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 248686], x=[4916, 4], y=[64], batch=[4916], ptr=[65])

Step 2:
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 294872], x=[6873, 4], y=[64], batch=[6873], ptr=[65])

Step 3:
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 337300], x=[6702, 4], y=[64], batch=[6702], ptr=[65])

Step 4:
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 267134], x=[5797, 4], y=[64], batch=[5797], ptr=[65])

Step 5:
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 284494], x=[6300, 4], y=[64], batch=[6300], ptr=[65])

Step 6:
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 298692], x=[6344, 4], y=[64], batch=[6344], ptr=[65])

Step 7:
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 254668], x=[6073, 4], y=[64], batch=[6073], ptr=[65])

Step 8:
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 338148], x=[663

In [12]:
from model.gcn import GCN
import wandb
import matplotlib.pyplot as plt
from tqdm.notebook import trange
from IPython.display import Javascript
# Colab-only helper that caps the height of this cell's output area.
display(Javascript(
    '''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

# starting new W&B run
# The project variable previously read "intro_to_pyg" but was never used,
# while wandb.init() hard-coded "bt-hostettler". The variable now holds the
# project that is actually used and is what wandb.init() receives.
wandb_project = "bt-hostettler"  # @param {type:"string"}
# NOTE: not passed to wandb.init() below, so W&B assigns the training run an
# auto-generated name (unchanged behaviour).
wandb_run_name = "upload_and_analyze_dataset"  # @param {type:"string"}

# Initialize W&B run for training
if use_wandb:
    wandb.init(project=wandb_project)
    # Declare the uploaded dataset artifact as an input of this run.
    wandb.use_artifact("bt-hostettler/bt-hostettler/BT-Data:v1")

# Model dimensions are taken from the dataset so they always match the data.
model = GCN(hidden_channels=64, num_classes=pyg_dataset.num_classes,
            num_node_features=pyg_dataset.num_node_features)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()


def train():
    """Run one optimisation pass over every batch of ``train_loader``.

    Uses the notebook-level ``model``, ``criterion`` and ``optimizer``;
    returns nothing.
    """
    model.train()

    for batch in train_loader:
        # Forward pass on the whole mini-batch of graphs.
        logits = model(batch.x, batch.edge_index, batch.batch)
        batch_loss = criterion(logits, batch.y)
        # Backpropagate, apply the update, then reset the gradients so they
        # do not accumulate into the next batch.
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()


def test(loader, create_table=False):
    """Evaluate the notebook-level ``model`` on every batch of ``loader``.

    Args:
        loader: Data loader to evaluate on.
        create_table: When true (and W&B is enabled) add one table row per
            batch with the rendered graph, its label and the prediction.
            This calls ``.item()`` on label and prediction, so it requires a
            loader that yields a single graph per batch.

    Returns:
        Tuple ``(accuracy, mean_loss, table)``: the fraction of correctly
        classified graphs, the loss averaged per graph, and the W&B table
        (``None`` when W&B is disabled; empty unless ``create_table``).
    """
    model.eval()
    table = wandb.Table(
        columns=['graph', 'ground truth', 'prediction']) if use_wandb else None
    correct = 0
    loss_ = 0
    # Evaluation needs no gradients; skipping autograd bookkeeping saves
    # memory and time on every forward pass.
    with torch.no_grad():
        for data in loader:  # Iterate in batches over the training/test dataset.
            out = model(data.x, data.edge_index, data.batch)
            loss = criterion(out, data.y)
            # `criterion` returns the batch mean; weight it by the batch size
            # so that dividing by the number of graphs below yields a true
            # per-graph average regardless of the loader's batch size.
            loss_ += loss.item() * data.num_graphs
            pred = out.argmax(dim=1)  # Use the class with highest probability.

            if create_table and use_wandb:
                # Reference plotly.js from the CDN instead of embedding the
                # multi-megabyte bundle in every row.
                graph_html = plotly.io.to_html(
                    create_graph(data), include_plotlyjs='cdn')
                table.add_data(wandb.Html(graph_html),
                               data.y.item(), pred.item())

            # Check against ground-truth labels.
            correct += int((pred == data.y).sum())
    # Derive ratio of correct predictions and the per-graph loss.
    return correct / len(loader.dataset), loss_ / len(loader.dataset), table


# Square matrix with one row/column per class. It was hard-coded to 3x3,
# which does not match this dataset; size it from the dataset instead.
confusion_matrix = torch.zeros(pyg_dataset.num_classes, pyg_dataset.num_classes)

# Number of training epochs (the loop previously ran range(1, 100)).
NUM_EPOCHS = 99

# NOTE: despite their names these lists collect per-epoch *accuracies*; the
# plotting cell below reads them under these names, so they are kept.
train_losses = []
val_losses = []
for epoch in trange(1, NUM_EPOCHS + 1):
    train()
    train_acc, train_loss, _ = test(train_loader)
    test_acc, test_loss, test_table = test(test_loader, create_table=True)

    # Log metrics to W&B
    if use_wandb:
        wandb.log({
            "train/loss": train_loss,
            "train/acc": train_acc,
            "test/acc": test_acc,
            "test/loss": test_loss,
            "test/table": test_table
        })

    # Checkpoint the whole model after every epoch (overwrites the file).
    torch.save(model, "graph_classification_model.pt")

    # Log model checkpoint as an artifact to W&B
    if use_wandb:
        artifact = wandb.Artifact(
            name="graph_classification_model", type="model")
        artifact.add_file("graph_classification_model.pt")
        wandb.log_artifact(artifact)

    train_losses.append(train_acc)
    val_losses.append(test_acc)


# Finish the W&B run
if use_wandb:
    wandb.finish()


<IPython.core.display.Javascript object>

  0%|          | 0/99 [00:00<?, ?it/s]

MemoryError: 

In [None]:
# Accuracy curves collected by the training loop (one point per epoch).
# `val_losses` / `train_losses` hold accuracies, not losses.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Training and Test Accuracy")
ax.plot(val_losses, label="val")
ax.plot(train_losses, label="train")
ax.set_xlabel("iterations")
ax.set_ylabel("Accuracy")
ax.legend()
plt.show()

In [None]:
# Print the trained model's textual representation.
print(model)

In [None]:
# Persist the trained model in two forms: the parameters only (state dict)
# and the complete pickled model object.
for target_file, payload in (
    ('gnn_model_weights.pth', model.state_dict()),
    ('gnn_model.pth', model),
):
    torch.save(payload, target_file)

In [None]:
from torchviz import make_dot
import os

# Local Windows install location of Graphviz. It is added to PATH only when
# the directory exists and is not already listed, so the cell is safe to
# re-run and does nothing on machines without this install.
GRAPHVIZ_BIN = 'C:/Program Files/Graphviz/bin/'
if os.path.isdir(GRAPHVIZ_BIN) and GRAPHVIZ_BIN not in os.environ["PATH"].split(os.pathsep):
    os.environ["PATH"] += os.pathsep + GRAPHVIZ_BIN

shuffled_dataset = pyg_dataset.shuffle()

# Only a single graph is needed to trace the computation; the previous
# `for ... break` loop processed exactly the first element as well.
data = shuffled_dataset[0]
y = model(data.x, data.edge_index, data.batch)

# Write the autograd graph of the forward pass to disk as a PNG named "torchviz".
make_dot(y, params=dict(list(model.named_parameters()))).render("torchviz", format="png")
# Detailed variant (attributes and saved tensors). As the cell's last
# expression it is also displayed inline by the notebook.
make_dot(y.mean(), params=dict(model.named_parameters()), show_attrs=True, show_saved=True)

In [None]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import random
from torch_geometric.utils import to_networkx
from model.gcn import GCN


# Integer label -> class name: the inverse of the name -> index mapping used
# when the labels were assigned (position in `dataset`'s key order).
class_key = {position: name for position, name in enumerate(dataset.keys())}


def index_to_class(index):
    """Return the class name belonging to the integer label ``index``.

    Raises:
        KeyError: If ``index`` is not a known label.
    """
    return class_key[index]

# 3x3 grid of randomly picked graphs, coloured by prediction correctness.
fig, ax = plt.subplots(3, 3, figsize=(11, 11))
fig.suptitle('GCN - Graph classification')

shuffled_dataset = pyg_dataset.shuffle()

# `ax.flat` walks the grid row by row, pairing each panel with one graph.
for panel, data in zip(ax.flat, shuffled_dataset[:9]):
    # Classify the graph; green marks a correct prediction, red a wrong one.
    out = model(data.x, data.edge_index, batch=data.batch)
    pred = out.argmax(dim=1)
    color = "green" if pred == data.y else "red"

    predicted_name = index_to_class(pred.item())
    actual_name = index_to_class(data.y.item())

    # Draw the graph into its panel.
    panel.axis('off')
    panel.set_title('Predicted: ' + predicted_name + '\nActual: ' + actual_name)
    G = to_networkx(data, to_undirected=True)
    nx.draw_networkx(
        G,
        pos=nx.spring_layout(G, seed=0),  # fixed seed -> stable layout
        with_labels=True,
        node_size=150,
        node_color=color,
        width=0.8,
        ax=panel,
    )
