# Graph Neural Network (GNN)

In [None]:
import torch
import torch.optim as optim

import networkx as nx

import torchvision
import torchvision.transforms as transforms
import torch_geometric.transforms as T

import matplotlib.pyplot as plt

from loguru import logger
from src.model.GNN import GNN

from torch_geometric.utils import to_networkx
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import MNISTSuperpixels

## Data

### Load Data

Download and load the training dataset

In [None]:
# Download (on first use) and load the MNISTSuperpixels training split into ./data.
# train=True is spelled out only to make the chosen split visible.
dataset = MNISTSuperpixels(root="./data", train=True)

Information about the dataset

In [None]:
# Summary of the dataset as a whole.
print(f"Dataset type: {type(dataset)}")
print(f"Dataset features: {dataset.num_features}")
print(f"Dataset target: {dataset.num_classes}")
# Use len(dataset): `dataset.len` is a method, and formatting it prints the
# bound-method repr instead of the number of graphs.
print(f"Dataset length: {len(dataset)}")

In [None]:
sample = dataset[0]

# Show the first graph together with its node and edge counts, one per line.
print(
    f"Dataset sample: {sample}",
    f"Sample nodes: {sample.num_nodes}",
    f"Sample edges: {sample.num_edges}",
    sep="\n",
)

In [None]:
# Display the node feature matrix of the first sample
# (presumably shape (num_nodes, num_features) per PyG convention — confirm).
sample.x

In [None]:
# Display the edge index transposed. Assuming PyG's (2, num_edges) layout,
# this shows one (source, target) pair per row — confirm for this dataset.
sample.edge_index.t()

## Graph Visualization

In [None]:
# The graph to visualize: the first entry of the dataset (the same one inspected above).
sample_graph = dataset[0]

In [None]:
# Convert the PyG graph to networkx; build an undirected networkx graph only
# when the PyG edge_index is itself undirected.
G = to_networkx(sample_graph, to_undirected=sample_graph.is_undirected())

In [None]:
# Display the target label of the graph being plotted.
sample_graph.y

In [None]:
# Draw the graph with networkx's default layout (no `pos` is passed).
# NOTE(review): MNISTSuperpixels samples likely carry superpixel coordinates in
# `sample_graph.pos`; passing them as `pos=` would show the digit's shape — confirm.
nx.draw_networkx(G)

## Train GNN

In [None]:
def evaluate(
        loader,
        model: "GNN",
        is_validation=False):
    """Return the classification accuracy of ``model`` over every batch in ``loader``.

    Accuracy is the number of correct argmax predictions divided by the number
    of labels that were scored. This works for graph-level targets (one label
    per graph) and for node-level targets (one label per node), without needing
    a task enum.

    Args:
        loader: Iterable of batches, each exposing a ``y`` label tensor
            (e.g. a ``torch_geometric.loader.DataLoader``).
        model: Module called as ``model(batch)``. It may return the logits
            directly, or an ``(embedding, logits)`` tuple.
        is_validation: When a batch carries node masks, score only the nodes in
            ``val_mask`` (True) or ``test_mask`` (False). Batches without that
            mask are scored on all of their labels.

    Returns:
        Accuracy as a float in ``[0, 1]``.

    Raises:
        ValueError: If ``loader`` yields no labels to score.
    """
    model.eval()

    mask_name = "val_mask" if is_validation else "test_mask"
    correct = 0
    total = 0

    with torch.no_grad():
        for data in loader:
            out = model(data)
            # The training loop treats model(batch) as logits; also accept the
            # (embedding, logits) convention some GNN implementations use.
            logits = out[-1] if isinstance(out, tuple) else out
            pred = logits.argmax(dim=1)
            label = data.y

            # Node-level datasets may restrict scoring to a subset of nodes.
            mask = getattr(data, mask_name, None)
            if mask is not None:
                pred = pred[mask]
                label = label[mask]

            correct += pred.eq(label).sum().item()
            total += label.numel()

    if total == 0:
        raise ValueError("evaluate() found no labels to score in the loader")
    return correct / total

In [None]:
def train(dataset, model, opt, epochs=200, train_ratio=0.8, batch_size=64):
    """Fit ``model`` on the leading ``train_ratio`` fraction of ``dataset``.

    The remaining tail of ``dataset`` is never touched here, so it can serve as
    held-out data. After each epoch the mean per-batch training loss is logged.

    Args:
        dataset: Sliceable dataset of graphs (e.g. a PyG dataset).
        model: Module called as ``model(batch)``; it must also expose
            ``model.loss(pred, target)``.
        opt: Optimizer over ``model``'s parameters.
        epochs: Number of passes over the training split.
        train_ratio: Fraction of ``dataset``, taken from the front, used for training.
        batch_size: Mini-batch size of the training ``DataLoader``.

    Raises:
        ValueError: If the training split is empty.
    """
    n_train = int(len(dataset) * train_ratio)
    if n_train == 0:
        raise ValueError("train(): the training split is empty")
    loader = DataLoader(dataset[:n_train], batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        # Re-enable training mode each epoch in case the model was put in eval mode elsewhere.
        model.train()

        running_loss = 0.0
        num_batches = 0
        for batch in loader:
            opt.zero_grad()

            pred = model(batch)
            loss = model.loss(pred, batch.y)
            loss.backward()
            opt.step()

            # Accumulate plain floats so the autograd graph is not kept alive.
            running_loss += loss.item()
            num_batches += 1

        # Report the epoch mean rather than only the last (possibly small) batch.
        logger.info(f"Epoch: {epoch} | Train Loss {running_loss / num_batches:.3f}")

In [None]:
# Model dimensions: dataset feature width in, 32 hidden units, one output per class.
hidden_dim = 32
input_dim = max(dataset.num_features, 1)  # never below 1, even if the dataset reports 0 features
output_dim = dataset.num_classes

model = GNN(input_dim, hidden_dim, output_dim)
opt = optim.Adam(params=model.parameters(), lr=0.01)

In [None]:
# Run 200 training epochs with the model and optimizer built above.
train(dataset, model, opt, epochs=200)

## Predict

In [None]:
# Held-out split: the last 20% of the dataset, which train() never sees.
data_size = len(dataset)
test_loader = DataLoader(dataset[int(data_size * 0.8):], batch_size=64, shuffle=True)

test_batch = next(iter(test_loader))

# train() leaves the model in training mode; switch to inference mode so any
# dropout / batch-norm layers in the GNN behave deterministically.
model.eval()
with torch.no_grad():
    pred = model(test_batch)
    pred = torch.argmax(pred, dim=1)

# True label vs. predicted class for the first graph in the batch.
print(test_batch.y[0])
print(pred[0])