In [1]:
import json
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
from sklearn.metrics import r2_score
from torch.nn.functional import one_hot
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx

from sklearn.model_selection import train_test_split

from gcn_model import GCN, GCNNet

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
def load_dataset(data_path, n_graphs=400):
    dataset=[]
    for i in range(0, n_graphs):
        graph = torch.load(data_path+f"/graph{i}.pt")
        x_oh = one_hot(graph.x.flatten().type(torch.LongTensor), num_classes=2).type(torch.cuda.FloatTensor)
        graph.x = x_oh
        dataset.append(graph)
    return dataset

In [4]:
dataset = load_dataset("./data/bapst_graphs", n_graphs=400)

In [5]:
train_dataset, test_dataset = train_test_split(dataset, test_size=40, random_state=42)
train_dataset, val_dataset = train_test_split(train_dataset, test_size=40, random_state=43)

In [6]:
print(len(train_dataset), 'training graphs')
print(len(val_dataset), 'validation graphs')
print(len(test_dataset), 'test graphs')

320 training graphs
40 validation graphs
40 test graphs


In [7]:
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [8]:
model = GCN([2, 32, 32, 1]).to(device)

In [9]:
def train(model, train_loader, optimizer, loss):
    model.train()
    loss_acc = 0
    total_nodes = 0

    total_preds = []
    labels = []

    for graph_batch in train_loader:
        graph_batch = graph_batch.to(device)
        optimizer.zero_grad()
        preds = model(graph_batch.x, graph_batch.edge_index)
        loss_val = loss(preds.squeeze(), graph_batch.y.squeeze())
        loss_acc += loss_val.item()
        total_nodes += graph_batch.num_nodes
        loss_val.backward()
        optimizer.step()

        total_preds.extend(preds.cpu().detach().numpy())
        labels.extend(graph_batch.y.cpu().detach().numpy())
        
    loss_acc /= total_nodes
    r2 = r2_score(labels, total_preds)

    return loss_acc, r2

In [10]:
def validate(model, valid_loader, loss):
    model.eval()
    loss_acc = 0
    total_nodes = 0
    total_preds = []
    labels = []
    with torch.no_grad():
        for graph_batch in valid_loader:
            graph_batch = graph_batch.to(device)
            preds = model(graph_batch.x, graph_batch.edge_index)
            loss_val = loss(preds.squeeze(), graph_batch.y.squeeze())
            loss_acc += loss_val.item()
            total_nodes += graph_batch.num_nodes
            total_preds.extend(preds.cpu().numpy())
            labels.extend(graph_batch.y.cpu().numpy())

    r2 = r2_score(labels, total_preds)            
    loss_acc /= total_nodes
    return loss_acc, r2

In [14]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss = nn.MSELoss()

In [15]:
train_loss = []
val_loss = []
train_r2 = []
val_r2 = []

for epoch in range(30):
    print('EPOCH:', epoch+1)
    print('Training...')
    loss_value, r2_value = train(model, train_dataloader, optimizer, loss)
    train_loss.append(loss_value)
    train_r2.append(r2_value)

    print('Validating..')
    loss_value, r2_value = validate(model, val_dataloader, loss)
    val_loss.append(loss_value)
    val_r2.append(r2_value)


    print('Training Loss:', train_loss[-1])
    print('Training R2:', train_r2[-1])
    print('Validation Loss:', val_loss[-1])
    print('Validation R2:', val_r2[-1])

EPOCH: 1
Training...
Validating..
Training Loss: 7.230719006656727e-07
Training R2: -0.0016535938304942377
Validation Loss: 1.232898466696497e-06
Validation R2: -0.0006644309315535502
EPOCH: 2
Training...
Validating..
Training Loss: 7.23352462728144e-07
Training R2: -0.0020422455961428554
Validation Loss: 1.238509503309615e-06
Validation R2: -0.006400273314388638
EPOCH: 3
Training...
Validating..
Training Loss: 7.230109360989445e-07
Training R2: -0.0015691545679539232
Validation Loss: 1.232790691574337e-06
Validation R2: -0.0009047525924501532
EPOCH: 4
Training...
Validating..
Training Loss: 7.22629891924953e-07
Training R2: -0.0010413343017054988
Validation Loss: 1.2327214790275321e-06
Validation R2: -0.000875347558418671
EPOCH: 5
Training...
Validating..
Training Loss: 7.223870682082633e-07
Training R2: -0.0007049498817812694
Validation Loss: 1.233758212038083e-06
Validation R2: -0.0020289938090833903
EPOCH: 6
Training...
Validating..
Training Loss: 7.223423949653807e-07
Training R2:

In [None]:
for batch_graph in train_dataloader:
    batch_graph = batch_graph.to(device)
    output = model(batch_graph.x, batch_graph.edge_index)
    print(output)
    print(output.shape)
    print(batch_graph.y.shape)
    break

tensor([[-0.0177],
        [-0.0191],
        [-0.0180],
        ...,
        [-0.0193],
        [-0.0199],
        [-0.0191]], device='cuda:0', grad_fn=<AddBackward0>)
torch.Size([131072, 1])
torch.Size([131072, 1])
