# Graph-level regression

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
sys.path.append('..')

In [None]:
import torch
import torch.nn as nn
from torch_geometric import datasets
from torch_geometric.loader import DataLoader

from utils import (
    GCNModel,
    train_graph_level,
    plot_training_curves
)

In [None]:
# Fix torch's global RNG so parameter initialization and any torch-based
# shuffling are reproducible across runs.
# (Assigning the returned Generator to `_` keeps the notebook from echoing it.)
SEED = 12345
_ = torch.manual_seed(SEED)

## Dataset

In [None]:
# load QM9: molecular graphs with node features (x) and per-graph regression targets (y)
data_set = datasets.QM9(root='../data')

# QM9 is stored in a fixed, non-random order (roughly smaller molecules first), so a
# plain head/tail slice would give train and validation sets drawn from different
# parts of that ordering. Shuffle once with torch's global RNG (seeded earlier in the
# notebook) before splitting; `data_set` itself keeps its original order.
shuffled_set = data_set.shuffle()

num_train = 100000  # graphs in the training split; the remainder is used for validation
batch_size = 32

# create dataloaders
train_loader = DataLoader(
    shuffled_set[:num_train],
    batch_size=batch_size,
    shuffle=True  # reshuffle every epoch for stochastic gradient descent
)

# evaluation does not benefit from shuffling; a fixed order keeps validation
# batches (including the one inspected below) reproducible
val_loader = DataLoader(
    shuffled_set[num_train:],
    batch_size=batch_size,
    shuffle=False
)

In [None]:
# Summarize the dataset and the layout of one stored graph.
# Expected shapes: x is (num_nodes, num_features); y is (1, num_targets).
summary_lines = [
    f'Number of data points: {len(data_set)}',
    f'Number of node features: {data_set.num_features}',
    f'Number of graph targets: {data_set.num_classes}',
    f'\nTensor shapes (single graph):\n{data_set[0]}',
    f'Features shape: {data_set[0].x.shape}',
    f'Targets shape: {data_set[0].y.shape}',
]
print('\n'.join(summary_lines))

In [None]:
# Take one validation mini-batch. PyG batches graphs by merging them into a single
# larger graph made of disconnected subgraphs; `batch.batch` records which graph
# each node came from.
batch = next(iter(val_loader))

# Report the merged-graph layout in one print call.
batch_report = [
    f'Tensor shapes (batch):\n{batch}',
    f'Number of graphs: {batch.num_graphs}',
    f'Number of nodes: {batch.num_nodes}',
    f'Number of edges: {batch.num_edges}',
]
print('\n'.join(batch_report))

## Model

In [None]:
# Build a graph-level GCN whose input width is the node-feature count and whose
# output width is the number of per-graph targets.
# NOTE(review): `num_channels` presumably lists layer widths (input, hidden...) and
# `num_features` the output size of the final head -- confirm against utils.GCNModel.
num_node_features = data_set.num_features
num_targets = data_set.num_classes

model = GCNModel(
    num_channels=[num_node_features, 128, 16],
    num_features=num_targets,
    graph_level=True,
)

In [None]:
# check output shape on one batch.
# This forward pass is inspection only, so run it under no_grad: otherwise an
# autograd graph is built and kept alive by `y` for the rest of the notebook.
# Expected shape is presumably (num_graphs, num_targets) -- confirm.
with torch.no_grad():
    y = model(batch.x, batch.edge_index, batch.batch)

print(f'Prediction shape: {y.shape}')

## Training

In [None]:
# Loss: squared error averaged over every element of the prediction/target tensors.
# NOTE(review): if the QM9 target columns differ widely in scale (worth checking
# the per-column mean/std of y), an unnormalized MSE over all of them will be
# dominated by the largest-magnitude targets; consider standardizing y.
criterion = nn.MSELoss(reduction='mean')

# Optimizer: Adam over all model parameters. Adam's `weight_decay` is a classic
# L2 term added to the gradient (not the decoupled decay of AdamW).
adam_hparams = dict(lr=0.01, weight_decay=0.01)
optimizer = torch.optim.Adam(model.parameters(), **adam_hparams)

In [None]:
# Train the model with the loaders, loss and optimizer defined above.
# NOTE(review): `log_every=1` presumably means "log every epoch", and `history`
# is presumably the per-epoch metric record that plot_training_curves consumes --
# confirm against utils.train_graph_level.
history = train_graph_level(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=2,
    log_every=1,
)

In [None]:
# Plot the learning curves recorded in `history` as a (width, height) = (5, 4) figure.
curve_figsize = (5, 4)
fig, axes = plot_training_curves(history, figsize=curve_figsize)