# Import libraries

In [1]:
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.utils.data import Subset

from tqdm.notebook import tqdm

from torch_geometric.utils import scatter
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import Sequential, GCNConv

from sklearn.model_selection import train_test_split

from model import VGAE, Encoder, Decoder

# Config

In [18]:
config = {
    "DEVICE": "cpu",
    "EPOCHS": 10,
    "LR": 1e-4,
    "BATCH_SIZE": 2,
    "SHUFFLE": True,
    "TEST_SIZE": 0.2,
}

# Import our dataset

In [19]:
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)

train_idx, eval_idx = train_test_split(list(range(len(dataset))), test_size=config["TEST_SIZE"])

train_loader = DataLoader(Subset(dataset, train_idx), batch_size=config["BATCH_SIZE"], shuffle=config["SHUFFLE"])
eval_loader = DataLoader(Subset(dataset, eval_idx), batch_size=config["BATCH_SIZE"], shuffle=config["SHUFFLE"])

# Model

In [20]:
hidden_model = Sequential('x, edge_index, edge_attr', [
    (GCNConv(dataset.num_features, 64), 'x, edge_index -> x1'),
    nn.ReLU(inplace=True),
    (GCNConv(64, 64), 'x1, edge_index -> x2'),
    nn.ReLU(inplace=True),
    (GCNConv(64, 64), 'x2, edge_index -> x3'),
    nn.ReLU(inplace=True),

])

mean_model = Sequential('x, edge_index, edge_attr', [
    (GCNConv(64, 64), 'x, edge_index -> x1'),
    nn.ReLU(inplace=True),
    (GCNConv(64, 64), 'x1, edge_index -> x2'),
    nn.ReLU(inplace=True),

])

std_model = Sequential('x, edge_index, edge_attr', [
    (GCNConv(64, 64), 'x, edge_index -> x1'),
    nn.ReLU(inplace=True),
    (GCNConv(64, 64), 'x1, edge_index -> x2'),
    nn.ReLU(inplace=True),

])

encoder = Encoder(
    hidden_model=hidden_model,
    mean_model=mean_model,
    std_model=std_model
)

decoder = Decoder()

# Training

In [21]:
def train_epoch(model, loader, loss_function, optimizer, device):
    model.train()
    model.to(device)
    loss_function.to(device)
    
    total_loss = 0.
    
    for batch in tqdm(loader):
        adj = torch.zeros((batch.x.shape[0], batch.x.shape[0]))
        adj[batch.edge_index[0], batch.edge_index[1]] = 1
        
        optimizer.zero_grad()
        adj_output, mu, logvar = model(batch)

        loss = loss_function(adj_output.flatten(), adj.flatten())
        
        total_loss += loss.item()
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
    print(f"TRAIN Loss: {total_loss}")
        
        
def eval_epoch(model, loader, loss_function, device):
    model.eval()
    model.to(device)
    loss_function.to(device)
    
    total_loss = 0.
    
    for batch in tqdm(loader): 
        adj = torch.zeros((batch.x.shape[0], batch.x.shape[0]))
        adj[batch.edge_index[0], batch.edge_index[1]] = 1
        
        with torch.no_grad():
            adj_output, mu, logvar = model(batch)
            loss = loss_function(adj_output.flatten(), adj.flatten())
        
        total_loss += loss.item()
    
    print(f"EVAL Loss: {total_loss}")

In [22]:
model = VGAE(
    encoder=encoder,
    decoder=decoder
)

loss_function = CrossEntropyLoss()

optimizer = AdamW(params=model.parameters(), lr=config["LR"])

In [23]:
for i in range(config["EPOCHS"]):
    train_epoch(
        model=model, 
        loader=train_loader, 
        loss_function=loss_function, 
        optimizer=optimizer,
        device=config["DEVICE"]
    )
    
    eval_epoch(
        model=model, 
        loader=eval_loader, 
        loss_function=loss_function, 
        device=config["DEVICE"]
    )

  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 502899.2523803711


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 1160779.2059326172


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 502791.4272918701


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 755476.1560058594


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 502570.38720703125


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 682573.5141601562


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 501179.42193603516


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 695980.3703613281


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 502565.59017944336


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 689219.4141845703


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 502426.2296142578


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 730646.8529052734


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 503060.9334411621


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 662174.1577148438


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 501806.31732177734


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 611894.8481445312


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 502155.2572937012


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 564789.8994140625


  0%|          | 0/240 [00:00<?, ?it/s]

TRAIN Loss: 502164.7787475586


  0%|          | 0/60 [00:00<?, ?it/s]

EVAL Loss: 566677.2173461914
