# Exploring the Impacts of Architecture and Scale on GNN Performance on Relational Data
By: Joseph Guman, Atindra Jha, and Christopher Pondoc

## Introduction
Welcome back to RelBench! In this tutorial, we'll dive deeper into the benchmark and Relational Deep Learning, exploring several choices around architecture, scale, and generalizability. In particular, we'll look to answer the following questions:

1. Can we train our Relational Deep Learning model on one entity classification task and expect strong zero-shot performance on another entity classification task? What happens if we fine-tune the model?
2. How does our choice of using embedding models to generate expressive node features impact our performance on node classification tasks?
3. How can we alter and/or extend the architecture of our existing Relational Deep Learning model to improve performance on different tasks?

This notebook assumes you've already worked through the tutorials on [loading in data](https://github.com/snap-stanford/relbench/blob/main/tutorials/load_data.ipynb) and [training a model](https://github.com/snap-stanford/relbench/blob/main/tutorials/train_model.ipynb), as our walkthrough uses those guides as a launchpad to explore deeper questions. If you haven't had a chance to look through those notebooks, we suggest starting there first.

With all that being said, let's get started!

In [1]:
# Adding all imports here, to focus on general code below.
import copy
import os
import torch
import torch_geometric
from torch_geometric.seed import seed_everything
import torch_frame
import numpy as np
from torch.nn import BCEWithLogitsLoss, L1Loss
from src.tasks.tasks import initialize_task, db_to_graph
import math
from tqdm import tqdm


## Question 1: Can we generalize?
Let's start with our first question: can our Relational Deep Learning model generalize to other tasks, with or without fine-tuning?

Let's begin by setting up RelBench. As with the other tutorials, we'll use the `rel-f1` dataset and focus on node classification tasks. We'll first train a model on the `driver-dnf` task, which predicts whether a driver will not finish a race in the next month.

In [2]:
# Set up dataset and task, define metrics and loss
dataset, task, train_table, val_table, test_table = initialize_task("rel-f1", "driver-dnf")
out_channels = 1
loss_fn = BCEWithLogitsLoss()
tune_metric = "roc_auc"
higher_is_better = True

# Set up device
seed_everything(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
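
The cell above pairs `out_channels = 1` with `BCEWithLogitsLoss`, so the model emits a single raw logit per driver and the loss applies the sigmoid internally. As a rough illustration, here is a pure-Python sketch of the quantity that loss computes with its default mean reduction. The helper names `bce_with_logits` and `sigmoid` and the toy values are ours, not part of RelBench or PyTorch.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed directly from raw logits."""
    total = 0.0
    for x, y in zip(logits, targets):
        # Numerically stable form of
        # -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def sigmoid(x):
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


# Toy values (made up): one logit per entity, matching out_channels = 1.
logits = [2.0, -1.0, 0.0]
targets = [1.0, 0.0, 1.0]

loss = bce_with_logits(logits, targets)
probs = [sigmoid(x) for x in logits]
print(f"loss={loss:.4f}, probs={[round(p, 3) for p in probs]}")
```

Because the sigmoid lives inside the loss, the predictions we collect later are logits, not probabilities. That is fine for a ranking metric like `roc_auc`, which only depends on the ordering of the scores.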

We can then preprocess all of our RelBench data.

In [3]:
from torch_frame.config.text_embedder import TextEmbedderConfig
from relbench.modeling.graph import make_pkey_fkey_graph
from src.embeddings.glove import GloveTextEmbedding

# Preprocess the database data and set up our text embedder
db, col_to_stype_dict = db_to_graph(dataset)
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=128
)

# Load in data used to train model
root_dir = "./data"
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,
    text_embedder_cfg=text_embedder_cfg,
    cache_dir=os.path.join(
        root_dir, "rel-f1_materialized_cache"
    ),
)

Loading Database object from /home/cpondoc/.cache/relbench/rel-f1/db...
Done in 0.02 seconds.




Next, let's build the data loaders and set up our model.

In [5]:
from src.models.loader import get_loader
from src.models.rdl import RDLModel
import copy

# Set up data loader and model
loader_dict, entity_table = get_loader(train_table, val_table, test_table, task, data)
model = RDLModel(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)

# if you try out different RelBench tasks you will need to change these
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
epochs = 10

## Training a Model
We'll first set up our data loaders, building one `NeighborLoader` per split (train, validation, and test).

In [None]:
from relbench.modeling.graph import get_node_train_table_input, make_pkey_fkey_graph
from torch_geometric.loader import NeighborLoader

loader_dict = {}

for split, table in [
    ("train", train_table),
    ("val", val_table),
    ("test", test_table),
]:
    table_input = get_node_train_table_input(
        table=table,
        task=task,
    )
    entity_table = table_input.nodes[0]
    loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=[
            128 for _ in range(2)
        ],  # sample 2-hop subgraphs, with up to 128 neighbors per node at each hop
        time_attr="time",
        input_nodes=table_input.nodes,
        input_time=table_input.time,
        transform=table_input.transform,
        batch_size=512,
        temporal_strategy="uniform",
        shuffle=split == "train",
        num_workers=0,
        persistent_workers=False,
    )
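
The `time_attr`, `input_time`, and `temporal_strategy="uniform"` arguments are what keep the sampled subgraphs free of future information: for each seed node, only neighbors whose timestamp is not later than the seed time are eligible, and the loader picks uniformly among them. The toy sketch below illustrates that idea for a single hop in plain Python. It is not how PyG implements sampling internally, and the function name and data are hypothetical.

```python
import random


def sample_temporal_neighbors(neighbors, seed_time, num_neighbors, rng):
    """Uniformly sample up to `num_neighbors` neighbors whose timestamp
    is not later than `seed_time` (one hop only)."""
    valid = [node for node, time in neighbors if time <= seed_time]
    if len(valid) <= num_neighbors:
        return valid
    return rng.sample(valid, num_neighbors)


# Hypothetical toy data: (neighbor id, timestamp) pairs for one driver.
race_results = [("r1", 1), ("r2", 2), ("r3", 3), ("r4", 4), ("r5", 5)]
rng = random.Random(0)

# With seed_time=3, "r4" and "r5" are in the future and can never be sampled.
sampled = sample_temporal_neighbors(race_results, seed_time=3, num_neighbors=2, rng=rng)
all_valid = sample_temporal_neighbors(race_results, seed_time=3, num_neighbors=10, rng=rng)
print(sampled, all_valid)
```

In the real loader this filter is applied at every hop, so a prediction made at a given seed time only ever sees rows that existed at that time.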

Define our model, using the same configuration as above.

In [None]:
from src.models.rdl import RDLModel
import copy

model = RDLModel(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)


# if you try out different RelBench tasks you will need to change these
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
epochs = 10

Define the train and test loops.

In [7]:
def train() -> float:
    model.train()

    loss_accum = count_accum = 0
    for batch in tqdm(loader_dict["train"]):
        batch = batch.to(device)

        optimizer.zero_grad()
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred

        loss = loss_fn(pred.float(), batch[entity_table].y.float())
        loss.backward()
        optimizer.step()

        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    return loss_accum / count_accum


@torch.no_grad()
def test(loader) -> np.ndarray:
    model.eval()

    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0).numpy()
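
One detail in `train()` worth calling out: each batch loss is multiplied by `pred.size(0)` before being accumulated, and the total is divided by the number of examples at the end. Since the loss is already a per-batch mean, this gives the mean loss per example even when the last batch is smaller than the rest. A small sketch with made-up numbers (the function name is ours) shows how this differs from simply averaging the batch losses:

```python
def mean_loss_over_examples(batch_losses, batch_sizes):
    """Average per-batch mean losses, weighting each batch by its size."""
    loss_accum = 0.0
    count_accum = 0
    for loss, size in zip(batch_losses, batch_sizes):
        loss_accum += loss * size  # undo the per-batch mean
        count_accum += size
    return loss_accum / count_accum


# Made-up values: two full batches and one small trailing batch.
batch_losses = [0.40, 0.30, 0.10]
batch_sizes = [512, 512, 8]

weighted = mean_loss_over_examples(batch_losses, batch_sizes)
unweighted = sum(batch_losses) / len(batch_losses)
print(f"weighted={weighted:.4f}, unweighted={unweighted:.4f}")
```

The unweighted average lets the 8-example batch count as much as a 512-example batch, which would make the reported training loss depend on how the data happens to be split into batches.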

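Run the training loop, keeping the weights from the epoch with the best validation `roc_auc`, and then evaluate that checkpoint on the validation and test sets.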

In [8]:
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
for epoch in range(1, epochs + 1):
    train_loss = train()
    val_pred = test(loader_dict["val"])
    val_metrics = task.evaluate(val_pred, val_table)
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(model.state_dict())


model.load_state_dict(state_dict)
val_pred = test(loader_dict["val"])
val_metrics = task.evaluate(val_pred, val_table)
print(f"Best Val metrics: {val_metrics}")

test_pred = test(loader_dict["test"])
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

100%|██████████| 23/23 [00:03<00:00,  7.40it/s]


Epoch: 01, Train loss: 0.3705791100832219, Val metrics: {'average_precision': np.float64(0.8436037599700656), 'accuracy': 0.7791519434628975, 'f1': np.float64(0.8758689175769613), 'roc_auc': np.float64(0.6107755102040816)}


100%|██████████| 23/23 [00:02<00:00,  8.88it/s]


Epoch: 02, Train loss: 0.3423137633271477, Val metrics: {'average_precision': np.float64(0.8799940068195036), 'accuracy': 0.7720848056537103, 'f1': np.float64(0.8695652173913043), 'roc_auc': np.float64(0.6642721088435374)}


100%|██████████| 23/23 [00:02<00:00,  8.88it/s]


Epoch: 03, Train loss: 0.31372833428766483, Val metrics: {'average_precision': np.float64(0.8816025239374385), 'accuracy': 0.7402826855123675, 'f1': np.float64(0.8467153284671532), 'roc_auc': np.float64(0.6661768707482993)}


100%|██████████| 23/23 [00:02<00:00,  8.93it/s]


Epoch: 04, Train loss: 0.31049752491099863, Val metrics: {'average_precision': np.float64(0.8858926490760746), 'accuracy': 0.6855123674911661, 'f1': np.float64(0.8043956043956044), 'roc_auc': np.float64(0.6563265306122449)}


100%|██████████| 23/23 [00:02<00:00,  8.88it/s]


Epoch: 05, Train loss: 0.3096003109938047, Val metrics: {'average_precision': np.float64(0.8864993593927938), 'accuracy': 0.7791519434628975, 'f1': np.float64(0.8758689175769613), 'roc_auc': np.float64(0.6629297052154195)}


100%|██████████| 23/23 [00:02<00:00,  8.74it/s]


Epoch: 06, Train loss: 0.30432194704585736, Val metrics: {'average_precision': np.float64(0.8873364211338444), 'accuracy': 0.7084805653710248, 'f1': np.float64(0.8192771084337349), 'roc_auc': np.float64(0.666485260770975)}


100%|██████████| 23/23 [00:02<00:00,  8.81it/s]


Epoch: 07, Train loss: 0.29998061832145995, Val metrics: {'average_precision': np.float64(0.8903041190279581), 'accuracy': 0.6837455830388692, 'f1': np.float64(0.7906432748538011), 'roc_auc': np.float64(0.6700226757369614)}


100%|██████████| 23/23 [00:02<00:00,  8.91it/s]


Epoch: 08, Train loss: 0.2966564161918947, Val metrics: {'average_precision': np.float64(0.8881872986413166), 'accuracy': 0.734982332155477, 'f1': np.float64(0.8417721518987342), 'roc_auc': np.float64(0.6761179138321995)}


100%|██████████| 23/23 [00:02<00:00,  8.49it/s]


Epoch: 09, Train loss: 0.2957169677794173, Val metrics: {'average_precision': np.float64(0.8833882066348911), 'accuracy': 0.7756183745583038, 'f1': np.float64(0.8718466195761857), 'roc_auc': np.float64(0.6738503401360544)}


100%|██████████| 23/23 [00:02<00:00,  8.53it/s]


Epoch: 10, Train loss: 0.29916005275197455, Val metrics: {'average_precision': np.float64(0.8864954185695677), 'accuracy': 0.6713780918727915, 'f1': np.float64(0.7726161369193154), 'roc_auc': np.float64(0.6777505668934241)}
Best Val metrics: {'average_precision': np.float64(0.8866205997515307), 'accuracy': 0.6713780918727915, 'f1': np.float64(0.7726161369193154), 'roc_auc': np.float64(0.6782403628117913)}
Best test metrics: {'average_precision': np.float64(0.8310967181107907), 'accuracy': 0.717948717948718, 'f1': np.float64(0.7838427947598253), 'roc_auc': np.float64(0.7094032108524861)}
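
Notice that in the log above, validation accuracy swings between roughly 0.67 and 0.78 from epoch to epoch, while `roc_auc` trends upward from about 0.61 to 0.68. One plausible reading is that accuracy depends on where the decision threshold falls, whereas ROC AUC only depends on how well positives are ranked above negatives. The sketch below makes this concrete in plain Python with toy scores. The threshold of 0 on the logits is an assumption for illustration, not necessarily what `task.evaluate` uses.

```python
def roc_auc(scores, labels):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a win
    return wins / (len(pos) * len(neg))


def accuracy(scores, labels, threshold=0.0):
    """Accuracy after thresholding raw scores (threshold is an assumption)."""
    preds = [1 if s > threshold else 0 for s in scores]
    return sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)


# Toy logits and labels (made up).
labels = [1, 1, 0, 0]
scores = [2.0, 0.5, 1.0, -1.0]
shifted = [s - 0.75 for s in scores]  # same ranking, different calibration

print(roc_auc(scores, labels), accuracy(scores, labels))
print(roc_auc(shifted, labels), accuracy(shifted, labels))
```

Shifting every score by the same constant leaves the ranking, and therefore the AUC, unchanged, but it moves one positive below the threshold and lowers accuracy. This is one reason we select checkpoints on `roc_auc` and not on accuracy.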
