# Graph convolutional network

In [None]:
# Automatically reload imported project modules (e.g. the local `gutils`)
# before each cell runs, so edits to them are picked up without a restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline below the producing cell.
%matplotlib inline

import sys
# Make modules in the parent directory importable; presumably this is where
# `gutils` lives — confirm against the repository layout.
sys.path.append('..')

In [None]:
import torch
import torch.nn as nn
from torch_geometric import datasets

from gutils import GCNNodeClassifier, train

## Dataset

In [None]:
# Zachary's Karate Club: a small node-classification dataset with
# placeholder (dummy) node features and a class label per node.
# `transform=None` states explicitly that no per-access transform is applied.
data_set = datasets.KarateClub(transform=None)

In [None]:
# Report the dataset size, the tensor layout of its first graph, and the
# feature / label dimensionality that the model below is sized from.
print('\n'.join([
    f'Number of data: {len(data_set)}',
    f'Dataset tensor shapes: {data_set[0]}',
    f'Number of features: {data_set.num_features}',
    f'Number of classes: {data_set.num_classes}',
]))

## Model

In [None]:
# Build the node classifier. `num_channels` lists the feature widths of the
# graph-convolution stack (input features -> hidden); presumably
# GCNNodeClassifier adds a final mapping from the last width to
# `num_classes` outputs — confirm in gutils.
num_hidden = 16  # hidden feature width, named so it is easy to tune
model = GCNNodeClassifier(
    num_channels=[data_set.num_features, num_hidden],
    num_classes=data_set.num_classes,
)

In [None]:
# Sanity-check the untrained model's output shape on the full graph
# (presumably (num_nodes, num_classes) — confirm against GCNNodeClassifier).
# Fetch the graph once so features and edges come from the same object, and
# disable autograd because the result is only inspected, never backpropagated.
graph = data_set[0]
with torch.no_grad():
    y = model(graph.x, graph.edge_index)

print(f'Prediction shape: {y.shape}')

## Training

In [None]:
# Adam over every trainable model parameter with a fixed learning rate,
# paired with a mean-reduced multi-class cross-entropy loss (this assumes the
# model emits unnormalized class scores — confirm in gutils).
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss(reduction='mean')

In [None]:
# Fit the model on the graph, with `train_mask` selecting the nodes used for
# training (presumably the only nodes that contribute to the loss — confirm
# in gutils.train). Fetch the graph a single time so `data` and `mask` are
# guaranteed to belong to the same object: each `data_set[0]` access may
# return a new object and reapplies any dataset transform.
graph = data_set[0]
train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=200,
    data=graph,
    mask=graph.train_mask,
    log_every=10,
)