# Node classification

In [None]:
# Notebook setup: auto-reload edited modules before executing code, and render
# matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Make modules in the parent directory importable (presumably where the local
# `utils` module imported in the next cell lives — confirm).
import sys
sys.path.append('..')

In [None]:
import torch
import torch.nn as nn
from torch_geometric import datasets, transforms

from utils import (
    GCNModel,
    train_node_level,
    plot_training_curves
)

In [None]:
# Seed PyTorch's global random number generator so that randomness drawn from
# it is reproducible across runs. The returned Generator is bound to `_` only
# to stop the notebook from echoing it.
SEED = 12345
_ = torch.manual_seed(SEED)

## Dataset

In [None]:
# Load the Planetoid "Cora" dataset (node features `x`, node labels `y`) from
# ../data; PyG downloads and processes the raw files there if they are missing.
# Trailing commas keep the call valid when the optional transform line below is
# toggled on or off.
data_set = datasets.Planetoid(
    root='../data',
    name='Cora',
    # transform=transforms.NormalizeFeatures(attrs=['x']),  # normalize rows
)

In [None]:
# Print dataset summaries. All per-node tensors and split masks are read from
# the graph at index 0 (the only graph in a Planetoid dataset), rather than
# from the dataset object, whose attribute forwarding varies across PyG versions.
graph = data_set[0]

print(f'Number of graphs: {len(data_set)}')
print(f'Number of node features: {data_set.num_features}')
print(f'Number of node classes: {data_set.num_classes}')

# int(...) so the counts print as plain numbers instead of tensor reprs
print(f'\nNumber of train nodes: {int(graph.train_mask.sum())}')
print(f'Number of val. nodes: {int(graph.val_mask.sum())}')
print(f'Number of test nodes: {int(graph.test_mask.sum())}')

print(f'\nTensor shapes:\n{graph}')
print(f'Features shape: {graph.x.shape}')  # (num_nodes, num_features)
print(f'Targets shape: {graph.y.shape}')  # (num_nodes,)

## Model

In [None]:
# Build the GCN model. `num_channels` presumably lists the layer widths of the
# graph-convolution stack (input features -> 128 -> 16).
# NOTE(review): `num_features` is set to the number of classes and appears to
# be the output width (the shape check expects (num_nodes, num_classes));
# confirm against GCNModel in utils.
# Alternative config: num_channels=[data_set.num_features, 128, data_set.num_classes]
gcn_channels = [data_set.num_features, 128, 16]
model = GCNModel(
    num_channels=gcn_channels,
    num_features=data_set.num_classes,
)

In [None]:
# Sanity-check the model's output shape with a single forward pass. The pass is
# run under no_grad because its result is only inspected, never backpropagated,
# so no autograd graph needs to be recorded. The model stays in its current
# (training) mode, which does not affect the output shape.
with torch.no_grad():
    y = model(data_set[0].x, data_set[0].edge_index)

print(f'Node features shape: {data_set[0].x.shape}')  # (num_nodes, num_features)
print(f'Prediction shape: {y.shape}')  # (num_nodes, num_classes)

## Training

In [None]:
# Pick the first CUDA GPU when one is available, otherwise the CPU, and move the
# model's parameters there.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

model = model.to(device)

In [None]:
# Classification loss: cross-entropy, averaged over the elements it is given.
criterion = nn.CrossEntropyLoss(reduction='mean')

# Adam over all model parameters; weight_decay adds an L2 penalty to the
# gradients.
adam_settings = {'lr': 0.01, 'weight_decay': 0.01}
optimizer = torch.optim.Adam(model.parameters(), **adam_settings)

In [None]:
# Train the model for 200 epochs (log_every=1 presumably logs every epoch).
# The graph is moved to the same device as the model so CPU and GPU tensors are
# never mixed. This is a no-op on CPU, and harmless if train_node_level also
# transfers the data itself. `history` is what gets plotted afterwards.
history = train_node_level(
    data=data_set[0].to(device),
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=200,
    log_every=1,
)

In [None]:
# Visualise the training history returned by train_node_level.
curve_figsize = (9, 3.5)
fig, axes = plot_training_curves(history, figsize=curve_figsize)