# Exploring the Impacts of Architecture and Scale on GNN Performance on Relational Data
By: Joseph Guman, Atindra Jha, and Christopher Pondoc

## Introduction
Welcome back to RelBench! In this tutorial, we'll dig deeper into the benchmark and into Relational Deep Learning, exploring several choices around architecture, scale, and generalizability. In particular, we'll try to answer the following questions:

1. Can we train our Relational Deep Learning model on one entity classification task and expect strong zero-shot performance on another entity classification task? What happens if we finetune the model?
2. How does our choice of using embedding models to generate expressive node features impact our performance on node classification tasks?
3. How can we alter and/or extend the architecture of our existing Relational Deep Learning model to improve performance on different tasks?

This notebook assumes you've already looked through the tutorials on [loading in data](https://github.com/snap-stanford/relbench/blob/main/tutorials/load_data.ipynb) and [training a model](https://github.com/snap-stanford/relbench/blob/main/tutorials/train_model.ipynb), since our walkthrough uses those guides as a launchpad for deeper questions. If you haven't had a chance to look through those notebooks, we suggest starting there.

With all that being said, let's get started!

## Question 1: Can we generalize?
Let's start with our first question: can our Relational Deep Learning model generalize to other tasks, with or without finetuning?

Let's start by setting up RelBench. As with the other tutorials, we'll use the `rel-f1` dataset and focus on node classification tasks. We'll begin by training a model on the `driver-dnf` task, which predicts whether a driver will not finish a race in the next month.

In [1]:
from src.tasks.tasks import initialize_task, db_to_graph
import torch
from torch.nn import BCEWithLogitsLoss
from torch_geometric.seed import seed_everything

# Set up dataset and task, define metrics and loss
dataset, task, train_table, val_table, test_table = initialize_task("rel-f1", "driver-dnf")
loss_fn = BCEWithLogitsLoss()

# Set up device
seed_everything(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

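Since `driver-dnf` is a binary task and the model will emit a single raw score per driver, it may help to see what `BCEWithLogitsLoss` computes. The following is a minimal plain-Python sketch of the same quantity (sigmoid plus binary cross-entropy, averaged over the batch), not PyTorch's actual implementation; the logits and targets are made-up values for illustration.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed directly from raw logits."""
    total = 0.0
    for x, y in zip(logits, targets):
        # Numerically stable form of
        # -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


# Hypothetical raw model outputs and 0/1 labels.
logits = [2.0, -1.0, 0.0]
targets = [1.0, 0.0, 1.0]

loss = bce_with_logits(logits, targets)

# The naive two-step version (sigmoid first, then cross-entropy) agrees
# here, but can overflow or hit log(0) for large-magnitude logits.
probs = [sigmoid(x) for x in logits]
naive = -sum(
    y * math.log(p) + (1 - y) * math.log(1 - p) for p, y in zip(probs, targets)
) / len(probs)

print(loss, naive)
```

The practical consequence is that the model's output layer should not apply a sigmoid itself; probabilities are only needed at evaluation time.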


We can then preprocess all of our RelBench data: we convert the database into a graph and set up a text embedder for its text columns.

In [2]:
import os
from relbench.modeling.graph import make_pkey_fkey_graph
from torch_frame.config.text_embedder import TextEmbedderConfig
from src.embeddings.glove import GloveTextEmbedding

# Preprocess the database data and set up our text embedder
db, col_to_stype_dict = db_to_graph(dataset)
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=128
)

# Load in data used to train model
root_dir = "./data"
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,
    text_embedder_cfg=text_embedder_cfg,
    cache_dir=os.path.join(
        root_dir, "rel-f1_materialized_cache"
    ),
)

Loading Database object from /home/cpondoc/.cache/relbench/rel-f1/db...
Done in 0.02 seconds.
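To build some intuition for what a primary-key/foreign-key graph looks like, here is a toy sketch in plain Python. It is an illustration of the idea only, not how `make_pkey_fkey_graph` is implemented: the two tables, their rows, and the helper `fkey_edges` are all hypothetical. Each row becomes a node, and each foreign-key reference becomes an edge from the referencing row to the referenced row (with a reverse edge so messages can flow both ways).

```python
# Toy tables loosely modelled on an F1-style schema (hypothetical rows).
drivers = [{"driverId": 10}, {"driverId": 11}]
results = [
    {"resultId": 0, "driverId": 10},
    {"resultId": 1, "driverId": 11},
    {"resultId": 2, "driverId": 10},
]


def fkey_edges(child_rows, fkey, parent_rows, pkey):
    """Return (child_idx, parent_idx) pairs, one per foreign-key reference."""
    # Map each primary-key value to the row position of its node.
    parent_pos = {row[pkey]: i for i, row in enumerate(parent_rows)}
    return [
        (i, parent_pos[row[fkey]])
        for i, row in enumerate(child_rows)
        if row.get(fkey) in parent_pos  # skip missing or dangling keys
    ]


# Edges from results to drivers, plus the reversed direction.
edges = fkey_edges(results, "driverId", drivers, "driverId")
reverse_edges = [(parent, child) for child, parent in edges]

print(edges)          # [(0, 0), (1, 1), (2, 0)]
print(reverse_edges)  # [(0, 0), (1, 1), (0, 2)]
```

In the real pipeline, every node additionally carries the encoded columns of its row as features, which is where the text embedder configured above comes in.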




Next, let's set up our data loaders and our model.

In [3]:
from src.models.loader import get_loader
from src.models.rdl import RDLModel
import copy

# Set up data loader and model
loader_dict, entity_table = get_loader(train_table, val_table, test_table, task, data)
model = RDLModel(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)

# If you try out different RelBench tasks, you will need to change these
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
epochs = 10
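The `aggr="sum"` setting above is worth a quick illustration. Assuming it refers to how messages arriving at a node are combined (the usual meaning in message-passing layers; we have not verified this against the `RDLModel` source), a sum aggregation can be sketched in plain Python as follows. The feature vectors and edges are made-up values.

```python
def sum_aggregate(features, edges, num_targets):
    """Sum the source-node feature vectors arriving at each target node."""
    dim = len(features[0])
    out = [[0.0] * dim for _ in range(num_targets)]
    for src, dst in edges:
        for k in range(dim):
            out[dst][k] += features[src][k]
    return out


# Hypothetical 2-dimensional features for three "result" nodes.
result_feats = [[1.0, 0.0], [0.5, 2.0], [3.0, 1.0]]
# (result_idx, driver_idx) edges: results 0 and 2 belong to driver 0.
edges = [(0, 0), (1, 1), (2, 0)]

agg = sum_aggregate(result_feats, edges, num_targets=2)
print(agg)  # [[4.0, 1.0], [0.5, 2.0]]
```

Unlike a mean, a sum keeps information about how many neighbors a node has, which can matter when the number of past races per driver is itself predictive. It also makes activations scale with node degree, which is one reason a normalization layer such as `batch_norm` is commonly paired with it.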

Finally, let's kick off our training run and evaluate our model!

In [6]:
from src.models.training import eval_model, training_run

# Get model after a training run
state_dict = training_run(model, device, optimizer, task, loader_dict, val_table, loss_fn, entity_table)
model.load_state_dict(state_dict)

# Evaluate on val and test set
eval_model(model, loader_dict, "val", task, device, val_table)
eval_model(model, loader_dict, "test", task, device, None)

Epoch: 01, Train loss: 0.2742630013000034, Val metrics: {'average_precision': np.float64(0.9091588363415968), 'accuracy': 0.7332155477031802, 'f1': np.float64(0.8262370540851554), 'roc_auc': np.float64(0.7296507936507937)}
Epoch: 02, Train loss: 0.27137437623166233, Val metrics: {'average_precision': np.float64(0.9196102712308644), 'accuracy': 0.7402826855123675, 'f1': np.float64(0.8243727598566308), 'roc_auc': np.float64(0.7690702947845806)}
Epoch: 03, Train loss: 0.26706399297763334, Val metrics: {'average_precision': np.float64(0.9194486869901948), 'accuracy': 0.7526501766784452, 'f1': np.float64(0.8360655737704918), 'roc_auc': np.float64(0.7669841269841271)}
Epoch: 04, Train loss: 0.2641804050202019, Val metrics: {'average_precision': np.float64(0.91130291019319), 'accuracy': 0.7367491166077739, 'f1': np.float64(0.8253223915592028), 'roc_auc': np.float64(0.751891156462585)}
Epoch: 05, Train loss: 0.2631346625415066, Val metrics: {'average_precision': np.float64(0.904638169380007), 'accuracy': 0.7455830388692579, 'f1': np.float64(0.8329466357308585), 'roc_auc': np.float64(0.7377233560090701)}
Epoch: 06, Train loss: 0.2680998766516816, Val metrics: {'average_precision': np.float64(0.9105063758182084), 'accuracy': 0.8021201413427562, 'f1': np.float64(0.8798283261802575), 'roc_auc': np.float64(0.7572607709750566)}
Epoch: 07, Train loss: 0.2664446948538026, Val metrics: {'average_precision': np.float64(0.9038685271605801), 'accuracy': 0.6908127208480566, 'f1': np.float64(0.7809762202753442), 'roc_auc': np.float64(0.7385034013605443)}
Epoch: 08, Train loss: 0.2609164908054301, Val metrics: {'average_precision': np.float64(0.9091801190972069), 'accuracy': 0.7279151943462897, 'f1': np.float64(0.8183962264150944), 'roc_auc': np.float64(0.7463764172335601)}
Epoch: 09, Train loss: 0.2581679197179269, Val metrics: {'average_precision': np.float64(0.9061994948069243), 'accuracy': 0.7314487632508834, 'f1': np.float64(0.8186157517899761), 'roc_auc': np.float64(0.7440907029478457)}
Epoch: 10, Train loss: 0.25288130452521956, Val metrics: {'average_precision': np.float64(0.9096831112598548), 'accuracy': 0.7332155477031802, 'f1': np.float64(0.8200238379022646), 'roc_auc': np.float64(0.7504943310657595)}
Best val metrics: {'average_precision': np.float64(0.9195556345244392), 'accuracy': 0.7402826855123675, 'f1': np.float64(0.8243727598566308), 'roc_auc': np.float64(0.7689795918367347)}
Best test metrics: {'average_precision': np.float64(0.8424542429485771), 'accuracy': 0.6766381766381766, 'f1': np.float64(0.7686034658511722), 'roc_auc': np.float64(0.7116917971990436)}
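One thing to notice in the validation log is that accuracy swings quite a bit between epochs (about 0.80 at epoch 06 versus 0.69 at epoch 07), while ROC AUC stays within roughly 0.73 to 0.77. A likely reason is that accuracy depends on a decision threshold, whereas ROC AUC only depends on how the model ranks positives against negatives. The sketch below shows the two computations side by side in plain Python; the labels, scores, and the 0.5 threshold are made-up for illustration and are not taken from the evaluation code used above.

```python
def roc_auc(labels, scores):
    """Probability that a random positive is scored above a random negative."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a win
    return wins / (len(pos) * len(neg))


def accuracy(labels, scores, threshold=0.5):
    """Fraction of examples whose thresholded score matches the label."""
    preds = [1 if s > threshold else 0 for s in scores]
    return sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)


# Hypothetical labels and predicted probabilities.
labels = [1, 0, 1, 1, 0]
scores = [0.9, 0.4, 0.35, 0.8, 0.1]

auc = roc_auc(labels, scores)
acc = accuracy(labels, scores)

# Shifting every score down changes accuracy but leaves the ranking,
# and therefore ROC AUC, untouched.
shifted = [s - 0.3 for s in scores]
print(auc, roc_auc(labels, shifted))
print(acc, accuracy(labels, shifted))
```

For that reason, a ranking metric is usually the steadier signal when comparing epochs or models on a task like this one.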
