In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.nn.functional import one_hot
import torch_geometric as pyg
import networkx as nx
import torch.nn.functional as F
from torch_geometric.datasets import ZINC
from torch_geometric.utils import to_networkx
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from gcn_model import GCNNet, GCN

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
train_ds = ZINC(root='./data/zinc/', subset=True, split='train')
test_ds = ZINC(root='./data/zinc/', subset=True, split='test')

In [5]:
len(train_ds)

10000

In [6]:
len(test_ds)

1000

In [7]:
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

In [8]:
model = GCNNet([21, 16, 16]).to(device)

In [9]:
for params in model.parameters():
    print(params.shape)

torch.Size([16])
torch.Size([16, 21])
torch.Size([16])
torch.Size([16, 16])
torch.Size([8, 16])
torch.Size([8])
torch.Size([1, 8])
torch.Size([1])


In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=5e-4)

In [11]:
loss = nn.MSELoss()

In [22]:
def train(model, train_loader, optimizer, loss):
    model.train()
    loss_acc = 0
    total_graphs = 0
    for graph_batch in train_loader:
        graph_batch = graph_batch.to(device)
        optimizer.zero_grad()
        
        x_oh = one_hot(graph_batch.x.flatten(), num_classes=21).type(torch.cuda.FloatTensor)
        preds = model(x_oh, graph_batch.edge_index, graph_batch.batch).squeeze()
        loss_val = loss(preds, graph_batch.y)
        loss_acc += loss_val.item()
        total_graphs += graph_batch.num_graphs
        loss_val.backward()
        optimizer.step()
        
    loss_acc /= total_graphs
    return loss_acc

In [13]:
def validate(model, valid_loader, loss):
    model.eval()
    loss_acc = 0
    total_graphs = 0
    for graph_batch in valid_loader:
        graph_batch = graph_batch.to(device)
        x_oh = one_hot(graph_batch.x.flatten(), num_classes=21).type(torch.cuda.FloatTensor)
        preds = model(x_oh, graph_batch.edge_index, graph_batch.batch).squeeze()
        loss_val = loss(preds, graph_batch.y)
        loss_acc += loss_val.item()
        total_graphs += graph_batch.num_graphs
        
    loss_acc /= total_graphs
    return loss_acc

In [27]:
for batch in train_loader:
    batch1 = batch
    break

In [28]:
batch1.x.shape

torch.Size([773, 1])

In [29]:
batch1.batch.shape

torch.Size([773])

In [30]:
batch1.batch

tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,  3,
         3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
         3,  3,  3,  3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,  4,  4,  4,
         4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,  6,  6,
         6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  7,
         7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
         8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  8,  8,  8, 

In [16]:
train_loss = []
test_loss = []
for epoch in range(50):
    print('EPOCH:', epoch+1)
    print('Training...')
    loss_value = train(model, train_loader, optimizer, loss)
    train_loss.append(loss_value)
    print('Training Loss:', loss_value)

    print('Validating')
    loss_value = validate(model, test_loader, loss)
    test_loss.append(loss_value)
    print('Validation Loss:', loss_value)


EPOCH: 1
Training...
Training Loss: 0.1277686879992485
Validating
Validation Loss: 0.13161374521255492
EPOCH: 2
Training...


KeyboardInterrupt: 