In [1]:
import json
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
from sklearn.metrics import r2_score
from torch.nn.functional import one_hot
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx

from sklearn.model_selection import train_test_split

from gcn_model import GCN, GCNNet
from mpn_model import GraphEncoder, MessageLayer
import mpn_model

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [4]:
def load_dataset(data_path, n_graphs=400):
    """Load serialized graphs and one-hot encode their node features.

    Reads the files ``graph0.pt`` ... ``graph{n_graphs - 1}.pt`` found under
    ``data_path`` with ``torch.load`` and replaces each graph's ``x`` attribute
    by a two-class one-hot encoding stored as a float tensor.

    Args:
        data_path: Directory that contains the ``graph{i}.pt`` files.
        n_graphs: Number of consecutively numbered files to read, starting
            at index 0.

    Returns:
        list: The loaded graph objects in file-index order, each with its
        ``x`` attribute replaced by the one-hot encoding.
    """
    dataset = []
    for i in range(n_graphs):
        # NOTE(review): torch.load unpickles the file, so only point this at
        # trusted data.
        graph = torch.load(data_path+f"/graph{i}.pt")
        # one_hot needs int64 class indices; flatten gives one index per entry.
        # With num_classes=2 every value of graph.x must be 0 or 1.
        class_idx = graph.x.flatten().long()
        # Cast with .float() instead of torch.cuda.FloatTensor: the features
        # stay on whatever device the graph was loaded on, so this also works
        # on CPU-only machines. Moving data to the compute device is left to
        # the caller.
        graph.x = one_hot(class_idx, num_classes=2).float()
        dataset.append(graph)
    return dataset

In [5]:
# Read the 400 graph files stored under ./data/bapst_graphs.
dataset = load_dataset(data_path="./data/bapst_graphs", n_graphs=400)

In [6]:
# Two-stage split with fixed seeds: hold out 40 graphs for testing first,
# then take 40 validation graphs from the remainder.
trainval_dataset, test_dataset = train_test_split(
    dataset,
    test_size=40,
    random_state=42,
)
train_dataset, val_dataset = train_test_split(
    trainval_dataset,
    test_size=40,
    random_state=43,
)

In [7]:
# Report how many graphs ended up in each split.
for _split, _label in (
    (train_dataset, 'training graphs'),
    (val_dataset, 'validation graphs'),
    (test_dataset, 'test graphs'),
):
    print(len(_split), _label)

320 training graphs
40 validation graphs
40 test graphs


In [8]:
# One graph per batch everywhere. Only the training loader shuffles; the
# evaluation loaders keep the dataset order.
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=1)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=1)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=1)

In [9]:
# Build the network, then move it to the selected compute device.
# NOTE(review): the argument 2 presumably matches the width of the one-hot
# node features — confirm against mpn_model.PaperGNN.
model = mpn_model.PaperGNN(2)
model = model.to(device)

In [10]:
def train(model, train_loader, optimizer, loss):
    """Run one optimisation pass over ``train_loader``.

    Every batch is moved to the module-level ``device``, pushed through the
    model, scored with ``loss``, and used for one optimizer step. Predictions
    and targets from all batches are collected to compute an R² score for the
    whole pass.

    Args:
        model: Network called as ``model(x, edge_attr, edge_index)``.
        train_loader: Iterable of graph batches exposing ``x``, ``edge_attr``,
            ``edge_index``, ``y``, ``num_graphs`` and ``.to(device)``.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` are called once
            per batch.
        loss: Callable ``loss(preds, targets)`` returning a scalar tensor.
            Assumed to be a per-batch mean (e.g. ``nn.MSELoss`` with its
            default reduction) — confirm if a different criterion is passed.

    Returns:
        tuple: ``(mean_loss, r2)`` where ``mean_loss`` is the graph-weighted
        average of the per-batch losses and ``r2`` is
        ``r2_score(targets, predictions)`` over all batches.
    """
    model.train()
    loss_acc = 0
    total_graphs = 0

    total_preds = []
    labels = []

    for graph_batch in train_loader:
        graph_batch = graph_batch.to(device)
        optimizer.zero_grad()
        preds = model(graph_batch.x, graph_batch.edge_attr, graph_batch.edge_index)
        loss_val = loss(preds.squeeze(), graph_batch.y.squeeze())
        # Weight the batch-mean loss by the number of graphs in the batch so
        # that dividing by total_graphs below yields a proper average for any
        # batch size (with batch_size=1 this is the same value as before).
        loss_acc += loss_val.item() * graph_batch.num_graphs
        total_graphs += graph_batch.num_graphs
        loss_val.backward()
        optimizer.step()

        # Detach before converting: preds still carries autograd history here.
        total_preds.extend(preds.detach().cpu().numpy())
        labels.extend(graph_batch.y.detach().cpu().numpy())

    loss_acc /= total_graphs
    r2 = r2_score(labels, total_preds)

    return loss_acc, r2

In [14]:
def validate(model, valid_loader, loss):
    """Evaluate ``model`` on ``valid_loader`` without updating its weights.

    The model is switched to evaluation mode and all forward passes run under
    ``torch.no_grad()``. Every batch is moved to the module-level ``device``
    before it is scored.

    Args:
        model: Network called as ``model(x, edge_attr, edge_index)``.
        valid_loader: Iterable of graph batches exposing ``x``, ``edge_attr``,
            ``edge_index``, ``y``, ``num_graphs`` and ``.to(device)``.
        loss: Callable ``loss(preds, targets)`` returning a scalar tensor.
            Assumed to be a per-batch mean (e.g. ``nn.MSELoss`` with its
            default reduction) — confirm if a different criterion is passed.

    Returns:
        tuple: ``(mean_loss, r2)`` where ``mean_loss`` is the graph-weighted
        average of the per-batch losses and ``r2`` is
        ``r2_score(targets, predictions)`` over all batches.
    """
    model.eval()
    loss_acc = 0
    total_graphs = 0
    total_preds = []
    labels = []
    with torch.no_grad():
        for graph_batch in valid_loader:
            graph_batch = graph_batch.to(device)
            preds = model(graph_batch.x, graph_batch.edge_attr, graph_batch.edge_index)
            loss_val = loss(preds.squeeze(), graph_batch.y.squeeze())
            # Weight the batch-mean loss by the number of graphs in the batch
            # so that dividing by total_graphs below yields a proper average
            # for any batch size (unchanged when batch_size=1).
            loss_acc += loss_val.item() * graph_batch.num_graphs
            total_graphs += graph_batch.num_graphs
            total_preds.extend(preds.cpu().numpy())
            labels.extend(graph_batch.y.cpu().numpy())

    r2 = r2_score(labels, total_preds)
    loss_acc /= total_graphs
    return loss_acc, r2

In [12]:
# Mean-squared-error criterion, optimised with Adam at a learning rate of 0.01.
loss = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

In [15]:
# Metric histories: each list gains one entry per completed epoch.
train_loss, val_loss = [], []
train_r2, val_r2 = [], []

for epoch in range(30):
    print('EPOCH:', epoch+1)

    # Optimisation pass over the training graphs.
    print('Training...')
    loss_value, r2_value = train(model, train_dataloader, optimizer, loss)
    train_loss += [loss_value]
    train_r2 += [r2_value]

    # Evaluation pass over the validation graphs.
    print('Validating..')
    loss_value, r2_value = validate(model, val_dataloader, loss)
    val_loss += [loss_value]
    val_r2 += [r2_value]

    # Report the values that were just recorded for this epoch.
    for _label, _history in (
        ('Training Loss:', train_loss),
        ('Training R2:', train_r2),
        ('Validation Loss:', val_loss),
        ('Validation R2:', val_r2),
    ):
        print(_label, _history[-1])

EPOCH: 1
Training...
Validating..
Training Loss: 0.06197560094296932
Training R2: 0.34499155730720077
Validation Loss: 0.0642177320085466
Validation R2: 0.3755986520333937
EPOCH: 2
Training...
Validating..
Training Loss: 0.058551107067614794
Training R2: 0.3811843876350005
Validation Loss: 0.06310237217694521
Validation R2: 0.3864435131630667
EPOCH: 3
Training...
Validating..
Training Loss: 0.05783640823792666
Training R2: 0.38873790098441763
Validation Loss: 0.06818232471123338
Validation R2: 0.33705016151151856
EPOCH: 4
Training...
Validating..
Training Loss: 0.05772982182679698
Training R2: 0.38986439382768623
Validation Loss: 0.06288199173286557
Validation R2: 0.3885863154100251
EPOCH: 5
Training...
Validating..
Training Loss: 0.05753598170122132
Training R2: 0.3919130537896488
Validation Loss: 0.06310829957947135
Validation R2: 0.38638587420146386
EPOCH: 6
Training...
Validating..
Training Loss: 0.05760627598501742
Training R2: 0.39117012882394375
Validation Loss: 0.06699299784377

KeyboardInterrupt: 