# Import libraries

In [1]:
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.utils.data import Subset

from tqdm.notebook import tqdm

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import Sequential, GCNConv

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

from model import VGAE, Encoder, Decoder
from loss import VGAELoss
from utils import adj_matrix_from_edge_index

# Config

In [2]:
config = {
    "DEVICE": "cpu",
    "EPOCHS": 10,
    "LR": 1e-3,
    "BATCH_SIZE": 2,
    "SHUFFLE": True,
    "TEST_SIZE": 0.2,
}

# Import our dataset

In [3]:
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)

train_idx, eval_idx = train_test_split(list(range(len(dataset))), test_size=config["TEST_SIZE"])

train_loader = DataLoader(Subset(dataset, train_idx), batch_size=config["BATCH_SIZE"], shuffle=config["SHUFFLE"])
eval_loader = DataLoader(Subset(dataset, eval_idx), batch_size=config["BATCH_SIZE"], shuffle=config["SHUFFLE"])

# Model

In [4]:
hidden_model = Sequential('x, edge_index, edge_attr', [
    (GCNConv(dataset.num_features, 64), 'x, edge_index -> x1'),
    nn.ReLU(inplace=True),
    (GCNConv(64, 64), 'x1, edge_index -> x2'),
    nn.ReLU(inplace=True),
    (GCNConv(64, 64), 'x2, edge_index -> x3'),
    nn.ReLU(inplace=True),

])

mean_model = Sequential('x, edge_index, edge_attr', [
    (GCNConv(64, 64), 'x, edge_index -> x1'),
    nn.ReLU(inplace=True),
    (GCNConv(64, 64), 'x1, edge_index -> x2'),
    nn.ReLU(inplace=True),

])

std_model = Sequential('x, edge_index, edge_attr', [
    (GCNConv(64, 64), 'x, edge_index -> x1'),
    nn.ReLU(inplace=True),
    (GCNConv(64, 64), 'x1, edge_index -> x2'),
    nn.ReLU(inplace=True),

])

encoder = Encoder(
    hidden_model=hidden_model,
    mean_model=mean_model,
    std_model=std_model
)

decoder = Decoder()

# Training

In [5]:
def train_epoch(model, loader, loss_function, optimizer, device):
    model.train()
    model.to(device)
    loss_function.to(device)
    
    preds = []
    targets = []
    total_loss = 0.
    
    for batch in tqdm(loader):
        adj = adj_matrix_from_edge_index(batch.x, batch.edge_index)
        
        optimizer.zero_grad()
        adj_output, mu, logvar = model(batch)

        loss = loss_function(adj_output, mu, logvar, adj)
        
        total_loss += loss.item()
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        preds.append(adj_output.flatten())
        targets.append(adj.flatten())
    
    preds = torch.cat(preds, dim=0).sigmoid().detach().cpu().numpy()
    targets = torch.cat(targets, dim=0).detach().cpu().numpy()
    roc_auc = roc_auc_score(targets, preds)

    print(f"TRAIN Loss: {total_loss}, ROC AUC: {roc_auc}")
        
        
def eval_epoch(model, loader, loss_function, device):
    model.eval()
    model.to(device)
    loss_function.to(device)
    
    preds = []
    targets = []
    total_loss = 0.
    
    for batch in tqdm(loader): 
        adj = adj_matrix_from_edge_index(batch.x, batch.edge_index)
        
        with torch.no_grad():
            adj_output, mu, logvar = model(batch)
            loss = loss_function(adj_output, mu, logvar, adj)
        
        total_loss += loss.item()
        
                
        preds.append(adj_output.flatten())
        targets.append(adj.flatten())
    
    preds = torch.cat(preds, dim=0).sigmoid().detach().cpu().numpy()
    targets = torch.cat(targets, dim=0).detach().cpu().numpy()
    roc_auc = roc_auc_score(targets, preds)

    print(f"EVAL Loss: {total_loss}, ROC AUC: {roc_auc}")

In [6]:
model = VGAE(
    encoder=encoder,
    decoder=decoder
)

loss_function = VGAELoss(norm=2)

optimizer = AdamW(params=model.parameters(), lr=config["LR"])

In [7]:
for i in range(config["EPOCHS"]):
    train_epoch(
        model=model, 
        loader=train_loader, 
        loss_function=loss_function, 
        optimizer=optimizer,
        device=config["DEVICE"]
    )
    
    eval_epoch(
        model=model, 
        loader=eval_loader, 
        loss_function=loss_function, 
        device=config["DEVICE"]
    )

  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 1956400511.5910645, ROC AUC: 0.5094194370167905


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 2153421.643310547, ROC AUC: 0.5590516156507831


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 1017277.4483032227, ROC AUC: 0.5376699134056462


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 1785433.9011230469, ROC AUC: 0.565318463668138


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 1039813.8815917969, ROC AUC: 0.5420008340629882


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 1896260.0495605469, ROC AUC: 0.5631713037618202


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 3084743.5422973633, ROC AUC: 0.5724000919251624


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 8125558.408203125, ROC AUC: 0.5617781348233422


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 1014292.0114746094, ROC AUC: 0.5587126258757157


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 7348085.2841796875, ROC AUC: 0.6369905271590965


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 1041931.3036499023, ROC AUC: 0.6155953556838267


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 5463522.7392578125, ROC AUC: 0.6543351281736994


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 1008825.8999023438, ROC AUC: 0.6172037906397804


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 3640775.68359375, ROC AUC: 0.6374171065500491


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 1008460.9428710938, ROC AUC: 0.6168987419226757


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 3804952.9736328125, ROC AUC: 0.6673395776282218


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 1007312.3347167969, ROC AUC: 0.6178571176117192


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 3727371.849609375, ROC AUC: 0.6586989108474455


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 1007379.8637695312, ROC AUC: 0.6222208583774522


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 2080051.7219238281, ROC AUC: 0.6835265984248047
