In [15]:
import torch
from torch_geometric.loader import DataLoader
from pathlib import Path
from tqdm import tqdm
import polars as pl
from gnn_example.graphdataset import GraphDataset
import numpy as np
import time

In [16]:
# Parquet splits all live in the package's data folder.
DATA_DIR = Path("gnn_example", "data")
TRAIN_PARQUET_FILE = DATA_DIR.joinpath("train_data.parquet")
VAL_PARQUET_FILE = DATA_DIR.joinpath("val_data.parquet")
TEST_PARQUET_FILE = DATA_DIR.joinpath("test_sequences.parquet")

# Every split path, in train / val / test order.
files = [TRAIN_PARQUET_FILE, VAL_PARQUET_FILE, TEST_PARQUET_FILE]

# Forwarded to GraphDataset as its `edge_distance` argument when the splits are loaded.
EDGE_DISTANCE = 1

In [17]:
def _load_split(message, parquet_file):
    """Print `message`, then build the GraphDataset for `parquet_file`."""
    print(message)
    return GraphDataset(parquet_file, edge_distance=EDGE_DISTANCE)


# Train first, then validation, each announced before its load starts.
train_dataset = _load_split("Loading dataset train...", TRAIN_PARQUET_FILE)
val_dataset = _load_split("Loading dataset val...", VAL_PARQUET_FILE)

Loading dataset train...
Loading dataset val...


In [18]:
import torch.nn.functional as F

def loss_fn(output, target):
    """Return the mean squared error of `output` against `target` clamped to [0, 1].

    Target values outside the unit interval are pulled to the nearest bound
    before the error is computed; `output` is used as given.
    """
    bounded_target = target.clamp(min=0, max=1)
    return F.mse_loss(output, bounded_target, reduction='mean')

def mae_fn(output, target):
    """Return the mean absolute error of `output` against `target` clamped to [0, 1].

    Target values outside the unit interval are pulled to the nearest bound
    before the error is computed; `output` is used as given.
    """
    bounded_target = target.clamp(min=0, max=1)
    return F.l1_loss(output, bounded_target, reduction='mean')

In [19]:
from torch_geometric.nn.models import EdgeCNN

# EdgeCNN sized from the dataset's feature count, with a single output channel.
model = EdgeCNN(
    in_channels=train_dataset.num_features,
    hidden_channels=128,
    num_layers=2,
    out_channels=1,
    act="sigmoid",
)

In [20]:
# Sanity check on the untrained model: run the same graph through it ten times in
# eval mode and print the output mean/std each time.
# NOTE(review): presumably this verifies that eval-mode outputs are stable across
# calls — confirm intent.
model.eval()

# Fetch the first training graph once instead of re-indexing the dataset for every
# attribute access (the original indexed it twice per loop iteration).
# NOTE(review): assumes train_dataset[0] returns the same graph on every access.
sample = train_dataset[0]
print(sample.x)
print(sample.edge_index)

# Nothing is back-propagated here, so skip autograd bookkeeping for the forward passes.
with torch.no_grad():
    for i in range(10):
        out = model(sample.x, sample.edge_index)
        print(i, torch.mean(out), torch.std(out))

tensor([[0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0.,

# Training

In [8]:
from torchinfo import summary
# Print torchinfo's per-layer parameter table for the current `model`.
# NOTE(review): the saved output below lists four EdgeConv layers, while the model
# cell constructs EdgeCNN with num_layers=2 — the output looks stale; re-run this
# cell after building the model to confirm.
summary(model)

Layer (type:depth-idx)                   Param #
EdgeCNN                                  --
├─Dropout: 1-1                           --
├─ReLU: 1-2                              --
├─ModuleList: 1-3                        --
│    └─EdgeConv: 2-1                     --
│    │    └─MaxAggregation: 3-1          --
│    │    └─MLP: 3-2                     17,664
│    └─EdgeConv: 2-2                     --
│    │    └─MaxAggregation: 3-3          --
│    │    └─MLP: 3-4                     49,408
│    └─EdgeConv: 2-3                     --
│    │    └─MaxAggregation: 3-5          --
│    │    └─MLP: 3-6                     49,408
│    └─EdgeConv: 2-4                     --
│    │    └─MaxAggregation: 3-7          --
│    │    └─MLP: 3-8                     259
├─ModuleList: 1-4                        --
│    └─Identity: 2-5                     --
│    └─Identity: 2-6                     --
│    └─Identity: 2-7                     --
│    └─Identity: 2-8                     --
├─TrimToLayer:

In [7]:
# Adam over all model parameters, with L2 regularisation via weight_decay.
optimizer = torch.optim.Adam(model.parameters(), lr=0.003, weight_decay=5e-4)

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=0)

# Validation uses only the first 2000 graphs. It is not shuffled: the training loop
# averages per-batch means, so a fixed batch composition keeps the reported
# validation metrics comparable from epoch to epoch and from run to run.
val_dataloader = DataLoader(val_dataset[:2000], batch_size=128, shuffle=False, num_workers=0)

In [None]:

n_epochs = 10

# Checkpoints and the text summary are written to a "models" folder next to the data
# folder. Create it up front so the first torch.save cannot fail on a missing
# directory after a whole epoch of training has already been spent.
models_dir = DATA_DIR.parent / "models"
models_dir.mkdir(parents=True, exist_ok=True)

for epoch in range(n_epochs):
    # ---- training pass: one optimizer step per batch ----
    train_losses = []
    train_maes = []
    model.train()
    for batch in (pbar := tqdm(train_dataloader, position=0, leave=True)):
        optimizer.zero_grad()
        # The model is built with a single output channel; drop only that trailing
        # axis so a batch containing one node does not collapse to a 0-d tensor.
        out = model(batch.x, batch.edge_index).squeeze(-1)
        loss = loss_fn(out, batch.y)
        mae = mae_fn(out, batch.y)
        loss.backward()
        optimizer.step()
        # .item() yields plain floats, detached from the autograd graph.
        train_losses.append(loss.item())
        train_maes.append(mae.item())
        # Running (unweighted) mean over the batches seen so far in this epoch.
        pbar.set_description(f"Epoch {epoch}/{n_epochs} | MSE {np.mean(train_losses):.3f} | MAE: {np.mean(train_maes):.3f}")

    # ---- validation pass: metrics only, no gradients and no optimizer calls ----
    val_losses = []
    val_maes = []
    model.eval()
    with torch.no_grad():
        for batch in (pbar := tqdm(val_dataloader, position=0, leave=True)):
            out = model(batch.x, batch.edge_index).squeeze(-1)
            loss = loss_fn(out, batch.y)
            mae = mae_fn(out, batch.y)
            val_losses.append(loss.item())
            val_maes.append(mae.item())
            pbar.set_description(f"Epoch {epoch}/{n_epochs} | MSE {np.mean(val_losses):.3f} | MAE: {np.mean(val_maes):.3f}")

    print("\n")

    # Persist a checkpoint per epoch and append two lines to the running summary:
    # the training metrics first, then the validation metrics.
    torch.save(model.state_dict(), models_dir / f"model_epoch_{epoch}.pt")
    with open(models_dir / "train_summary.txt", "a+") as f:
        f.write(f"Epoch {epoch}/{n_epochs} | MSE {np.mean(train_losses):.3f} | MAE: {np.mean(train_maes):.3f}\n")
        f.write(f"Epoch {epoch}/{n_epochs} | MSE {np.mean(val_losses):.3f} | MAE: {np.mean(val_maes):.3f}\n")

Epoch 0/10 | MSE 0.168 | MAE: 0.352: 100%|██████████| 79/79 [01:00<00:00,  1.31it/s]
Epoch 0/10 | MSE 0.105 | MAE: 0.258: 100%|██████████| 16/16 [00:07<00:00,  2.04it/s]






Epoch 1/10 | MSE 0.095 | MAE: 0.247: 100%|██████████| 79/79 [01:01<00:00,  1.29it/s]
Epoch 1/10 | MSE 0.104 | MAE: 0.251: 100%|██████████| 16/16 [00:07<00:00,  2.03it/s]






Epoch 2/10 | MSE 0.092 | MAE: 0.241: 100%|██████████| 79/79 [01:01<00:00,  1.29it/s]
Epoch 2/10 | MSE 0.106 | MAE: 0.242: 100%|██████████| 16/16 [00:07<00:00,  2.01it/s]






Epoch 3/10 | MSE 0.092 | MAE: 0.240: 100%|██████████| 79/79 [01:01<00:00,  1.28it/s]
Epoch 3/10 | MSE 0.106 | MAE: 0.246: 100%|██████████| 16/16 [00:07<00:00,  2.03it/s]






Epoch 4/10 | MSE 0.092 | MAE: 0.240: 100%|██████████| 79/79 [01:02<00:00,  1.27it/s]
Epoch 4/10 | MSE 0.106 | MAE: 0.244: 100%|██████████| 16/16 [00:07<00:00,  2.00it/s]






Epoch 5/10 | MSE 0.091 | MAE: 0.240: 100%|██████████| 79/79 [01:03<00:00,  1.25it/s]
Epoch 5/10 | MSE 0.109 | MAE: 0.242: 100%|██████████| 16/16 [00:07<00:00,  2.03it/s]






Epoch 6/10 | MSE 0.089 | MAE: 0.237: 100%|██████████| 79/79 [01:02<00:00,  1.26it/s]
Epoch 6/10 | MSE 0.108 | MAE: 0.244: 100%|██████████| 16/16 [00:07<00:00,  2.06it/s]






Epoch 7/10 | MSE 0.084 | MAE: 0.228: 100%|██████████| 79/79 [01:03<00:00,  1.24it/s]
Epoch 7/10 | MSE 0.111 | MAE: 0.255: 100%|██████████| 16/16 [00:07<00:00,  2.04it/s]






Epoch 8/10 | MSE 0.083 | MAE: 0.225: 100%|██████████| 79/79 [01:02<00:00,  1.26it/s]
Epoch 8/10 | MSE 0.110 | MAE: 0.244: 100%|██████████| 16/16 [00:07<00:00,  2.09it/s]






Epoch 9/10 | MSE 0.081 | MAE: 0.223: 100%|██████████| 79/79 [01:04<00:00,  1.22it/s]
Epoch 9/10 | MSE 0.107 | MAE: 0.256: 100%|██████████| 16/16 [00:07<00:00,  2.10it/s]






