# Exploring the Impacts of Architecture and Scale on GNN Performance on Relational Data
By: Joseph Guman, Atindra Jha, and Christopher Pondoc

## Introduction
This notebook adapts the existing RelBench model-training tutorial so that the user can explore different architectures. In particular, we'll look at:
- Focusing on entity classification instead of link prediction tasks.
- Using different embedding models and GNN architectures.
- Applying transfer learning from one task to another.

## Training a Model
We'll first set up the necessary packages and load in our data.

In [1]:
# Import required packages.
import os
import torch
import numpy as np

from torch.nn import BCEWithLogitsLoss, L1Loss
from relbench.datasets import get_dataset
from relbench.tasks import get_task
from src.tasks.tasks import initialize_task

# Resolve the dataset/task pair together with its train/val/test tables.
(
    dataset,
    task,
    train_table,
    val_table,
    test_table,
) = initialize_task("rel-f1", "driver-dnf")

# Single-logit output trained with a BCE-with-logits objective.
loss_fn = BCEWithLogitsLoss()
out_channels = 1

# Model selection: a higher validation ROC-AUC counts as an improvement.
higher_is_better = True
tune_metric = "roc_auc"

We can also do some more bookkeeping.

In [2]:
import os
import math
import numpy as np
from tqdm import tqdm

import torch
import torch_geometric
import torch_frame

# Some book keeping
from torch_geometric.seed import seed_everything

# Seed the random number generators up front so runs are comparable.
seed_everything(42)

# Prefer the GPU when one is visible; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)  # check that it's cuda if you want it to run in reasonable time!

# Base directory under which the materialized graph cache is written.
root_dir = "./data"

cuda


  from .autonotebook import tqdm as notebook_tqdm


We can first build a graph out of the database.

In [3]:
from relbench.modeling.utils import get_stype_proposal

# Load the relational database object that backs the dataset.
db = dataset.get_db()
# Ask RelBench for a proposed semantic type (stype) for every column.
col_to_stype_dict = get_stype_proposal(db)
# Show the proposal as the cell output: table name -> column name -> stype.
col_to_stype_dict

Loading Database object from /home/cpondoc/.cache/relbench/rel-f1/db...
Done in 0.02 seconds.


{'standings': {'driverStandingsId': <stype.numerical: 'numerical'>,
  'raceId': <stype.numerical: 'numerical'>,
  'driverId': <stype.numerical: 'numerical'>,
  'points': <stype.numerical: 'numerical'>,
  'position': <stype.numerical: 'numerical'>,
  'wins': <stype.numerical: 'numerical'>,
  'date': <stype.timestamp: 'timestamp'>},
 'constructors': {'constructorId': <stype.numerical: 'numerical'>,
  'constructorRef': <stype.text_embedded: 'text_embedded'>,
  'name': <stype.text_embedded: 'text_embedded'>,
  'nationality': <stype.text_embedded: 'text_embedded'>},
 'circuits': {'circuitId': <stype.numerical: 'numerical'>,
  'circuitRef': <stype.text_embedded: 'text_embedded'>,
  'name': <stype.text_embedded: 'text_embedded'>,
  'location': <stype.text_embedded: 'text_embedded'>,
  'country': <stype.text_embedded: 'text_embedded'>,
  'lat': <stype.numerical: 'numerical'>,
  'lng': <stype.numerical: 'numerical'>,
  'alt': <stype.numerical: 'numerical'>},
 'qualifying': {'qualifyId': <stype.

Let's also define our text encoding model.

In [4]:
from src.embeddings.glove import GloveTextEmbedding
from src.embeddings.bert import BertTextEmbedding


We can now make our primary key and foreign key graph.

In [5]:
from torch_frame.config.text_embedder import TextEmbedderConfig
from relbench.modeling.graph import make_pkey_fkey_graph

# Instantiate the text encoder once so that its class name can key the cache.
text_embedder = BertTextEmbedding(device=device)
text_embedder_cfg = TextEmbedderConfig(text_embedder=text_embedder, batch_size=128)

# The materialized graph contains the embedded text columns, and an existing
# cache is loaded as-is. Keying the directory on the encoder class prevents
# embeddings produced by one encoder (e.g. GloVe) from being silently reused
# after switching to another one (e.g. BERT).
# NOTE: a cache written under the old encoder-agnostic directory name is not
# picked up; it is rebuilt once under the new name.
cache_dir = os.path.join(
    root_dir, f"rel-f1_{type(text_embedder).__name__}_materialized_cache"
)

# Build the primary-key/foreign-key heterogeneous graph plus column statistics.
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # specified column types
    text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder
    cache_dir=cache_dir,  # store materialized graph for convenience
)

No sentence-transformers model found with name google-bert/bert-base-uncased. Creating a new one with mean pooling.
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)


Set up our data loaders.

In [6]:
from relbench.modeling.graph import get_node_train_table_input, make_pkey_fkey_graph
from torch_geometric.loader import NeighborLoader

# One NeighborLoader per split, keyed by the split name.
loader_dict = {}

split_tables = {"train": train_table, "val": val_table, "test": test_table}
for split, table in split_tables.items():
    table_input = get_node_train_table_input(table=table, task=task)
    # Kept as a notebook-level variable; later cells may read it.
    entity_table = table_input.nodes[0]
    loader_dict[split] = NeighborLoader(
        data,
        # we sample subgraphs of depth 2, 128 neighbors per node.
        num_neighbors=[128] * 2,
        time_attr="time",
        input_nodes=table_input.nodes,
        input_time=table_input.time,
        transform=table_input.transform,
        batch_size=512,
        temporal_strategy="uniform",
        shuffle=(split == "train"),  # only the training split is shuffled
        num_workers=0,
        persistent_workers=False,
    )

Define our model.

In [None]:
from src.models.rdl import RDLModel
import copy

# Build the model on top of the materialized graph and move it to the device.
model = RDLModel(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    # Reuse the task-level setting from the configuration cell rather than a
    # second hard-coded literal, so the two cannot drift apart.
    out_channels=out_channels,
    aggr="sum",
    norm="batch_norm",
).to(device)


# if you try out different RelBench tasks you will need to change these
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
epochs = 10

Train/Test Loops.

In [8]:
def train() -> float:
    """Run one optimization pass over the training loader.

    Reads the notebook-level ``model``, ``optimizer``, ``loss_fn``, ``task``,
    ``device`` and ``loader_dict``.

    Returns:
        The training loss averaged per prediction over the whole epoch.
    """
    model.train()

    # Use the task's entity table for both the forward pass and the label
    # lookup. The label lookup previously went through a loop variable leaked
    # from the loader-construction cell, which only worked through hidden
    # notebook state.
    target_table = task.entity_table

    loss_accum = count_accum = 0
    for batch in tqdm(loader_dict["train"]):
        batch = batch.to(device)

        optimizer.zero_grad()
        pred = model(batch, target_table)
        # Flatten single-column outputs so they line up with the label vector.
        pred = pred.view(-1) if pred.size(1) == 1 else pred

        loss = loss_fn(pred.float(), batch[target_table].y.float())
        loss.backward()
        optimizer.step()

        # Weight each batch loss by its size so the epoch mean is per example
        # rather than per batch.
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    return loss_accum / count_accum


@torch.no_grad()
def test(loader: NeighborLoader) -> np.ndarray:
    """Collect model predictions for every batch produced by ``loader``.

    Runs the notebook-level ``model`` in eval mode with gradients disabled.

    Args:
        loader: Loader yielding the batches to score.

    Returns:
        The per-batch predictions concatenated along dim 0, as a NumPy array.
    """
    model.eval()

    outputs = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(batch, task.entity_table)
        # Flatten single-column outputs to a 1-D vector.
        if pred.size(1) == 1:
            pred = pred.view(-1)
        outputs.append(pred.detach().cpu())
    return torch.cat(outputs, dim=0).numpy()

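Run the training loop, keeping the weights from the epoch with the best validation metric, then evaluate on the validation and test splits.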

In [9]:
# Weights of the best epoch so far (by `tune_metric` on the validation split).
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
for epoch in range(1, epochs + 1):
    train_loss = train()
    val_pred = test(loader_dict["val"])
    val_metrics = task.evaluate(val_pred, val_table)
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    val_score = val_metrics[tune_metric]
    improved = (
        val_score > best_val_metric
        if higher_is_better
        else val_score < best_val_metric
    )
    if improved:
        best_val_metric = val_score
        # Deep-copy: state_dict() returns references to the live parameters,
        # which later epochs would otherwise overwrite.
        state_dict = copy.deepcopy(model.state_dict())


# `state_dict` is still None when no epoch improved on the initial bound
# (e.g. `epochs == 0` or a NaN metric). In that case keep the current weights
# instead of failing inside `load_state_dict(None)`.
if state_dict is not None:
    model.load_state_dict(state_dict)
val_pred = test(loader_dict["val"])
val_metrics = task.evaluate(val_pred, val_table)
print(f"Best Val metrics: {val_metrics}")

# No table is passed for the test split.
# NOTE(review): presumably the task scores against its own held-out test
# labels; confirm against the RelBench task API.
test_pred = test(loader_dict["test"])
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

100%|██████████| 23/23 [00:02<00:00,  8.04it/s]


Epoch: 01, Train loss: 0.3801887325714644, Val metrics: {'average_precision': np.float64(0.8355929555342954), 'accuracy': 0.7791519434628975, 'f1': np.float64(0.8758689175769613), 'roc_auc': np.float64(0.5977687074829933)}


100%|██████████| 23/23 [00:02<00:00,  8.80it/s]


Epoch: 02, Train loss: 0.3479877723844993, Val metrics: {'average_precision': np.float64(0.8732252620759755), 'accuracy': 0.7314487632508834, 'f1': np.float64(0.8376068376068376), 'roc_auc': np.float64(0.6647619047619048)}


100%|██████████| 23/23 [00:02<00:00,  8.97it/s]


Epoch: 03, Train loss: 0.31301642970411586, Val metrics: {'average_precision': np.float64(0.8874774948153648), 'accuracy': 0.7614840989399293, 'f1': np.float64(0.8612538540596094), 'roc_auc': np.float64(0.679891156462585)}


100%|██████████| 23/23 [00:02<00:00,  8.92it/s]


Epoch: 04, Train loss: 0.30730570314742445, Val metrics: {'average_precision': np.float64(0.8916255989833356), 'accuracy': 0.7385159010600707, 'f1': np.float64(0.8448637316561844), 'roc_auc': np.float64(0.6831564625850339)}


100%|██████████| 23/23 [00:02<00:00,  8.92it/s]


Epoch: 05, Train loss: 0.3008378823208209, Val metrics: {'average_precision': np.float64(0.8938010756908914), 'accuracy': 0.7491166077738516, 'f1': np.float64(0.8526970954356846), 'roc_auc': np.float64(0.6840090702947845)}


100%|██████████| 23/23 [00:02<00:00,  8.89it/s]


Epoch: 06, Train loss: 0.2986423067625143, Val metrics: {'average_precision': np.float64(0.8913826747810455), 'accuracy': 0.7173144876325088, 'f1': np.float64(0.8257080610021786), 'roc_auc': np.float64(0.6806530612244898)}


100%|██████████| 23/23 [00:02<00:00,  9.01it/s]


Epoch: 07, Train loss: 0.2915483071842491, Val metrics: {'average_precision': np.float64(0.8926329149547951), 'accuracy': 0.7720848056537103, 'f1': np.float64(0.8695652173913043), 'roc_auc': np.float64(0.6908662131519275)}


100%|██████████| 23/23 [00:02<00:00,  9.07it/s]


Epoch: 08, Train loss: 0.28820174648586033, Val metrics: {'average_precision': np.float64(0.894604479593628), 'accuracy': 0.7314487632508834, 'f1': np.float64(0.8264840182648402), 'roc_auc': np.float64(0.691827664399093)}


100%|██████████| 23/23 [00:02<00:00,  8.96it/s]


Epoch: 09, Train loss: 0.28438007274232713, Val metrics: {'average_precision': np.float64(0.8997137168141898), 'accuracy': 0.7667844522968198, 'f1': np.float64(0.8565217391304348), 'roc_auc': np.float64(0.7076462585034013)}


100%|██████████| 23/23 [00:02<00:00,  8.81it/s]


Epoch: 10, Train loss: 0.28583193557120967, Val metrics: {'average_precision': np.float64(0.8982983179086812), 'accuracy': 0.657243816254417, 'f1': np.float64(0.7610837438423645), 'roc_auc': np.float64(0.699374149659864)}
Best Val metrics: {'average_precision': np.float64(0.8995958975528964), 'accuracy': 0.7667844522968198, 'f1': np.float64(0.8565217391304348), 'roc_auc': np.float64(0.7071564625850341)}
Best test metrics: {'average_precision': np.float64(0.8561460655762598), 'accuracy': 0.7236467236467237, 'f1': np.float64(0.8255395683453237), 'roc_auc': np.float64(0.7260332796564679)}
