# Exploring the Impacts of Architecture and Scale on GNN Performance on Relational Data
By: Joseph Guman, Atindra Jha, and Christopher Pondoc

## Introduction
Welcome back to RelBench! In this tutorial, we'll dive deeper into the benchmark and Relational Deep Learning, exploring several choices around architecture, scale, and generalizability. In particular, we'll look to answer the following questions:

1. Can we train our Relational Deep Learning model on one entity classification task and expect strong zero-shot performance on another entity classification task? What happens if we finetune the model?
2. How does our choice of using embedding models to generate expressive node features impact our performance on node classification tasks?
3. How can we alter and/or extend the architecture of our existing Relational Deep Learning model to improve performance on different tasks?

This notebook assumes you've already looked through the tutorials on [loading in data](https://github.com/snap-stanford/relbench/blob/main/tutorials/load_data.ipynb) and [training a model](https://github.com/snap-stanford/relbench/blob/main/tutorials/train_model.ipynb), as our walkthrough uses those guides as a launchpad to explore deeper questions. If you haven't had a chance to look through those notebooks, we suggest starting there first.

With all that being said, let's get started!

## Question 1: Can we generalize?
Let's take a look at our first question: can our Relational Deep Learning model generalize to other tasks, with or without finetuning?

Let's start by setting up RelBench. As with the other tutorials, we're using the `rel-f1` dataset and focusing on node classification tasks. We'll begin by training a model on the `driver-dnf` task, which predicts whether a driver will not finish a race in the next month.

In [1]:
from src.tasks.tasks import initialize_task, db_to_graph
import torch
from torch.nn import BCEWithLogitsLoss
from torch_geometric.seed import seed_everything

# Set up dataset and task, define metrics and loss
dataset, task, train_table, val_table, test_table = initialize_task(
    "rel-f1", "driver-dnf"
)
loss_fn = BCEWithLogitsLoss()

# Set up device
seed_everything(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



We can then preprocess all of our RelBench data.

In [2]:
import os
from relbench.modeling.graph import make_pkey_fkey_graph
from torch_frame.config.text_embedder import TextEmbedderConfig
from src.embeddings.glove import GloveTextEmbedding

# Preprocess the database data and set up our text embedder
db, col_to_stype_dict = db_to_graph(dataset)
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=128
)

# Load in data used to train model
root_dir = "./data"
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,
    text_embedder_cfg=text_embedder_cfg,
    cache_dir=os.path.join(root_dir, "rel-f1_materialized_cache"),
)

Loading Database object from /home/cpondoc/.cache/relbench/rel-f1/db...
Done in 0.02 seconds.
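
For intuition about what a primary-key/foreign-key graph is, the sketch below links rows of two toy tables wherever a foreign key points at a primary key. The table contents and the `build_fkey_edges` helper are hypothetical and only illustrate the idea; they are not the real `rel-f1` schema, and `make_pkey_fkey_graph` does considerably more (for example, encoding column features).

```python
# Toy illustration only: hypothetical tables, not the real rel-f1 schema.
drivers = [
    {"driverId": 0, "name": "A"},
    {"driverId": 1, "name": "B"},
]
results = [
    {"resultId": 0, "driverId": 0, "position": 1},
    {"resultId": 1, "driverId": 1, "position": 2},
    {"resultId": 2, "driverId": 0, "position": 3},
]


def build_fkey_edges(child_rows, fkey, parent_rows, parent_pkey):
    """Return (child_index, parent_index) pairs, one per foreign-key reference."""
    parent_index = {row[parent_pkey]: i for i, row in enumerate(parent_rows)}
    edges = []
    for i, row in enumerate(child_rows):
        target = row.get(fkey)
        if target in parent_index:  # skip missing or dangling keys
            edges.append((i, parent_index[target]))
    return edges


# Each result row becomes a node connected to the driver row it references.
edges = build_fkey_edges(results, "driverId", drivers, "driverId")
# Adding the reverse direction lets information flow both ways during message passing.
reverse_edges = [(dst, src) for src, dst in edges]
print(edges)
print(reverse_edges)
```

Every table becomes a node type and every foreign-key column becomes an edge type, which is why a GNN can later aggregate a driver's race results into that driver's representation.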




Next, let's load in the data and have our model set up.

In [3]:
from src.models.loader import get_loader
from src.models.rdl import RDLModel

# Set up data loader and model
loader_dict, entity_table = get_loader(train_table, val_table, test_table, task, data)
model = RDLModel(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)

# if you try out different RelBench tasks you will need to change these
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
epochs = 10
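
Because the model is built with `out_channels=1` and trained with `BCEWithLogitsLoss`, each driver gets a single raw logit rather than a probability. As a rough sketch of what that loss computes with its default mean reduction, here is a pure-Python version on made-up logits; `bce_with_logits` and `sigmoid` are illustrative helpers, not part of the notebook's code.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed directly from raw logits."""
    total = 0.0
    for z, y in zip(logits, targets):
        # Numerically stable form: max(z, 0) - z * y + log(1 + exp(-|z|))
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


def sigmoid(z):
    """Map a logit to a probability (fine for the moderate values used here)."""
    return 1.0 / (1.0 + math.exp(-z))


# Made-up logits and labels, purely for illustration.
logits = [2.0, -1.0, 0.0]
targets = [1.0, 0.0, 1.0]

loss = bce_with_logits(logits, targets)
probs = [sigmoid(z) for z in logits]
print(f"loss={loss:.4f}", [round(p, 3) for p in probs])
```

Working on logits instead of probabilities avoids taking the log of values that have already been squashed to 0 or 1, which is why the sigmoid is folded into the loss rather than added to the model.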

Finally, let's initialize our training run and evaluate our model!

In [4]:
from src.models.training import eval_model, training_run

# Get model after a training run
state_dict = training_run(
    model, device, optimizer, task, loader_dict, val_table, loss_fn, entity_table
)
model.load_state_dict(state_dict)

# Evaluate on val and test set
eval_model(model, loader_dict, "val", task, device, val_table)
eval_model(model, loader_dict, "test", task, device, None)

100%|██████████| 23/23 [00:02<00:00,  7.95it/s]


Epoch: 01, Train loss: 0.37025071884912797, Val metrics: {'average_precision': np.float64(0.8432619527191028), 'accuracy': 0.7791519434628975, 'f1': np.float64(0.8758689175769613), 'roc_auc': np.float64(0.6106848072562359)}


100%|██████████| 23/23 [00:02<00:00,  9.05it/s]


Epoch: 02, Train loss: 0.3418975639652489, Val metrics: {'average_precision': np.float64(0.8750363898375417), 'accuracy': 0.6484098939929329, 'f1': np.float64(0.7593712212817413), 'roc_auc': np.float64(0.6606258503401361)}


100%|██████████| 23/23 [00:02<00:00,  8.80it/s]


Epoch: 03, Train loss: 0.31509140877152875, Val metrics: {'average_precision': np.float64(0.8834959595947474), 'accuracy': 0.7102473498233216, 'f1': np.float64(0.8185840707964602), 'roc_auc': np.float64(0.668172335600907)}


100%|██████████| 23/23 [00:02<00:00,  8.68it/s]


Epoch: 04, Train loss: 0.3110679211128438, Val metrics: {'average_precision': np.float64(0.8890812660558268), 'accuracy': 0.6625441696113075, 'f1': np.float64(0.7744982290436836), 'roc_auc': np.float64(0.6707664399092971)}


100%|██████████| 23/23 [00:02<00:00,  8.78it/s]


Epoch: 05, Train loss: 0.3071235752489113, Val metrics: {'average_precision': np.float64(0.8896883607978027), 'accuracy': 0.7791519434628975, 'f1': np.float64(0.8758689175769613), 'roc_auc': np.float64(0.6729977324263039)}


100%|██████████| 23/23 [00:02<00:00,  9.04it/s]


Epoch: 06, Train loss: 0.30267065225720624, Val metrics: {'average_precision': np.float64(0.8918245733325141), 'accuracy': 0.7367491166077739, 'f1': np.float64(0.8443051201671892), 'roc_auc': np.float64(0.6773333333333333)}


100%|██████████| 23/23 [00:02<00:00,  8.89it/s]


Epoch: 07, Train loss: 0.3029827275803542, Val metrics: {'average_precision': np.float64(0.8924655196986894), 'accuracy': 0.773851590106007, 'f1': np.float64(0.8712273641851107), 'roc_auc': np.float64(0.6837097505668934)}


100%|██████████| 23/23 [00:02<00:00,  9.00it/s]


Epoch: 08, Train loss: 0.2988869379309724, Val metrics: {'average_precision': np.float64(0.8872553032397764), 'accuracy': 0.6943462897526502, 'f1': np.float64(0.8071348940914158), 'roc_auc': np.float64(0.6737233560090702)}


100%|██████████| 23/23 [00:02<00:00,  8.87it/s]


Epoch: 09, Train loss: 0.2938899123959094, Val metrics: {'average_precision': np.float64(0.8948022090436293), 'accuracy': 0.6819787985865724, 'f1': np.float64(0.7841726618705036), 'roc_auc': np.float64(0.6868390022675737)}


100%|██████████| 23/23 [00:02<00:00,  8.93it/s]


Epoch: 10, Train loss: 0.2886922211630214, Val metrics: {'average_precision': np.float64(0.8937628907691045), 'accuracy': 0.7120141342756183, 'f1': np.float64(0.809801633605601), 'roc_auc': np.float64(0.7021315192743764)}
Best val metrics: {'average_precision': np.float64(0.8938356094742296), 'accuracy': 0.7137809187279152, 'f1': np.float64(0.8111888111888111), 'roc_auc': np.float64(0.7024399092970521)}
Best test metrics: {'average_precision': np.float64(0.839302596227602), 'accuracy': 0.7364672364672364, 'f1': np.float64(0.8202137998056366), 'roc_auc': np.float64(0.7271751329722346)}
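
The "Best val metrics" line suggests that the training helper keeps the checkpoint with the best validation score rather than simply the last one. The internals of `training_run` are not shown here, so the following is only a sketch of that pattern, assuming selection by validation ROC-AUC. `select_best_checkpoint` and the placeholder state dictionaries are hypothetical; the ROC-AUC values are the ones logged above, rounded to four decimals.

```python
import copy


def select_best_checkpoint(history):
    """Return (epoch, score, state) for the highest score; the earliest epoch wins ties."""
    best_epoch, best_score, best_state = None, float("-inf"), None
    for epoch, (score, state) in enumerate(history, start=1):
        if score > best_score:
            best_epoch, best_score = epoch, score
            # Snapshot the weights so later updates cannot overwrite them.
            best_state = copy.deepcopy(state)
    return best_epoch, best_score, best_state


# Validation ROC-AUC per epoch from the driver-dnf run above (rounded).
val_roc_auc = [0.6107, 0.6606, 0.6682, 0.6708, 0.6730, 0.6773, 0.6837, 0.6737, 0.6868, 0.7021]
# Placeholder "state dicts": a real run would store model.state_dict() here.
history = [(score, {"epoch": i}) for i, score in enumerate(val_roc_auc, start=1)]

best_epoch, best_score, best_state = select_best_checkpoint(history)
print(best_epoch, best_score)
```

Note that validation ROC-AUC dips at epoch 8 before recovering, which is exactly the situation where keeping the best checkpoint instead of the latest one matters.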


As we can see, we are able to roughly replicate the results from the [core RelBench paper](https://huggingface.co/spaces/relbench/leaderboard). However, do the results generalize? To find out, let's load in the data for the other entity classification task within `rel-f1`, `driver-top3`, and see how we do.

In [None]:
# Reuse functions to set up the `driver-top3` task
dataset, task, train_table, val_table, test_table = initialize_task(
    "rel-f1", "driver-top3"
)
db, col_to_stype_dict = db_to_graph(dataset)
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,
    text_embedder_cfg=text_embedder_cfg,
    cache_dir=os.path.join(root_dir, "rel-f1_materialized_cache"),
)

loader_dict, entity_table = get_loader(train_table, val_table, test_table, task, data)
model = RDLModel(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)
model.load_state_dict(state_dict)
eval_model(model, loader_dict, "test", task, device, None)



Best test metrics: {'average_precision': np.float64(0.10873607399424712), 'accuracy': 0.20110192837465565, 'f1': np.float64(0.14705882352941177), 'roc_auc': np.float64(0.19975438963210704)}
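
ROC-AUC can be read as the probability that a randomly chosen positive example is scored above a randomly chosen negative one, so random scores give about 0.5 and a value near 0.20 means the ranking is largely inverted. The sketch below computes ROC-AUC by pairwise comparison on made-up scores (not the model's outputs) and shows that negating the scores yields one minus the original value; `roc_auc` is an illustrative helper, not the metric implementation used by `eval_model`.

```python
def roc_auc(scores, labels):
    """ROC-AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Made-up example in which the positives mostly receive the lowest scores.
labels = [1, 0, 0, 1, 0]
scores = [0.2, 0.9, 0.6, 0.1, 0.15]

auc = roc_auc(scores, labels)
flipped_auc = roc_auc([-s for s in scores], labels)
print(round(auc, 4), round(flipped_auc, 4))
```

One plausible reading, which the notebook does not test, is that whatever makes a driver likely to not finish a race tends to make a top-3 finish less likely, so the `driver-dnf` scores point in the wrong direction for `driver-top3`.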


Unfortunately, trying out our model zero-shot does not work: test ROC-AUC is about 0.20 and average precision is about 0.11. However, what happens if we use this model as a starting point for finetuning on the task? Let's finetune it on the `driver-top3` task and check its performance.

In [None]:
# Get model after a training run
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
state_dict = training_run(
    model,
    device,
    optimizer,
    task,
    loader_dict,
    val_table,
    loss_fn,
    entity_table,
    epochs=5,
    state_dict=state_dict,
)
model.load_state_dict(state_dict)

# Evaluate on val and test set
eval_model(model, loader_dict, "val", task, device, val_table)
eval_model(model, loader_dict, "test", task, device, None)

100%|██████████| 3/3 [00:00<00:00,  6.96it/s]


Epoch: 01, Train loss: 1.3868744648216569, Val metrics: {'average_precision': np.float64(0.16442445711205755), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.42697676085359515)}


100%|██████████| 3/3 [00:00<00:00,  7.02it/s]


Epoch: 02, Train loss: 0.49178143234404476, Val metrics: {'average_precision': np.float64(0.1778246280585124), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.47929619609037644)}


100%|██████████| 3/3 [00:00<00:00,  7.35it/s]


Epoch: 03, Train loss: 0.47279221674114885, Val metrics: {'average_precision': np.float64(0.20232979236434118), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.5525971582662916)}


100%|██████████| 3/3 [00:00<00:00,  7.48it/s]


Epoch: 04, Train loss: 0.47120197199929315, Val metrics: {'average_precision': np.float64(0.2657992778302137), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6696708534159932)}


100%|██████████| 3/3 [00:00<00:00,  7.52it/s]


Epoch: 05, Train loss: 0.45727679278792405, Val metrics: {'average_precision': np.float64(0.2528989572724485), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6369353711633907)}


100%|██████████| 3/3 [00:00<00:00,  7.50it/s]


Epoch: 06, Train loss: 0.44516043546723155, Val metrics: {'average_precision': np.float64(0.3230162360973011), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6685062084535306)}


100%|██████████| 3/3 [00:00<00:00,  7.65it/s]


Epoch: 07, Train loss: 0.4429805654335797, Val metrics: {'average_precision': np.float64(0.34932005810440586), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6888606188744155)}


100%|██████████| 3/3 [00:00<00:00,  7.06it/s]


Epoch: 08, Train loss: 0.43081866012003717, Val metrics: {'average_precision': np.float64(0.3564335527157481), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.7169554388919747)}


100%|██████████| 3/3 [00:00<00:00,  7.49it/s]


Epoch: 09, Train loss: 0.41196173751081083, Val metrics: {'average_precision': np.float64(0.3794067219999968), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.7339234201143143)}


100%|██████████| 3/3 [00:00<00:00,  7.42it/s]


Epoch: 10, Train loss: 0.38334668537289324, Val metrics: {'average_precision': np.float64(0.40917188453781744), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.753023597498701)}
Best val metrics: {'average_precision': np.float64(0.4088931016471459), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.7529519270394726)}
Best test metrics: {'average_precision': np.float64(0.311201367679585), 'accuracy': 0.8236914600550964, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.7520249790969901)}
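
An F1 of 0.0 next to an accuracy that never changes is what one would expect if every driver were assigned the negative class at the decision threshold, even while the scores rank drivers increasingly well (hence the rising ROC-AUC). How `eval_model` thresholds its predictions is not shown, so the sketch below uses made-up probabilities and a hypothetical `f1_and_accuracy` helper to illustrate the effect and how a lower threshold changes F1.

```python
def f1_and_accuracy(probs, labels, threshold=0.5):
    """Threshold probabilities into 0/1 predictions and return (F1, accuracy)."""
    preds = [1 if p >= threshold else 0 for p in probs]
    tp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(preds, labels) if p == 0 and y == 1)
    correct = sum(1 for p, y in zip(preds, labels) if p == y)
    denom = 2 * tp + fp + fn
    f1 = 2 * tp / denom if denom > 0 else 0.0
    return f1, correct / len(labels)


# Made-up data: the positive has the highest score, but no score reaches 0.5.
labels = [1, 0, 0, 0, 0]
probs = [0.40, 0.10, 0.20, 0.30, 0.05]

f1_default, acc_default = f1_and_accuracy(probs, labels)  # threshold 0.5
f1_low, acc_low = f1_and_accuracy(probs, labels, threshold=0.35)
print(f1_default, acc_default)
print(f1_low, acc_low)
```

With an imbalanced task like this one, accuracy at a fixed threshold mostly reflects the share of negatives, so average precision and ROC-AUC are the more informative numbers in the logs above.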


Nice! After finetuning, test ROC-AUC rises from about 0.20 zero-shot to about 0.75, practically replicating the RelBench results. Finally, let's compare this approach to simply training on the task from scratch.

In [8]:
# Define a new model, don't load in old weights.
base_model = RDLModel(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)
base_optimizer = torch.optim.Adam(base_model.parameters(), lr=0.005)
base_state_dict = training_run(
    base_model,
    device,
    base_optimizer,
    task,
    loader_dict,
    val_table,
    loss_fn,
    entity_table,
    epochs=10,  # no `state_dict` here: the baseline should not reuse earlier weights
)
base_model.load_state_dict(base_state_dict)

# Evaluate on val and test set
eval_model(base_model, loader_dict, "val", task, device, val_table)
eval_model(base_model, loader_dict, "test", task, device, None)

100%|██████████| 3/3 [00:00<00:00,  4.79it/s]


Epoch: 01, Train loss: 0.602195169272814, Val metrics: {'average_precision': np.float64(0.25376087186368423), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6178083173567934)}


100%|██████████| 3/3 [00:00<00:00,  7.50it/s]


Epoch: 02, Train loss: 0.4583603228495373, Val metrics: {'average_precision': np.float64(0.25228851988473067), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6213560050886024)}


100%|██████████| 3/3 [00:00<00:00,  7.49it/s]


Epoch: 03, Train loss: 0.446312434943799, Val metrics: {'average_precision': np.float64(0.23957765559346245), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6222070917919407)}


100%|██████████| 3/3 [00:00<00:00,  7.49it/s]


Epoch: 04, Train loss: 0.4411518036966754, Val metrics: {'average_precision': np.float64(0.24985092760374591), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6372578882299188)}


100%|██████████| 3/3 [00:00<00:00,  7.52it/s]


Epoch: 05, Train loss: 0.438761846761922, Val metrics: {'average_precision': np.float64(0.2512923460194664), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6352511153715217)}


100%|██████████| 3/3 [00:00<00:00,  7.14it/s]


Epoch: 06, Train loss: 0.43432464902698775, Val metrics: {'average_precision': np.float64(0.2866928200572756), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6377774990593251)}


100%|██████████| 3/3 [00:00<00:00,  7.48it/s]


Epoch: 07, Train loss: 0.4265849729511179, Val metrics: {'average_precision': np.float64(0.29608797000399273), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6473096701367114)}


100%|██████████| 3/3 [00:00<00:00,  7.40it/s]


Epoch: 08, Train loss: 0.40803989096209464, Val metrics: {'average_precision': np.float64(0.29748833887707427), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6353048682159431)}


100%|██████████| 3/3 [00:00<00:00,  7.42it/s]


Epoch: 09, Train loss: 0.3913779138681718, Val metrics: {'average_precision': np.float64(0.34771619293922523), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6886993603411514)}


100%|██████████| 3/3 [00:00<00:00,  7.49it/s]


Epoch: 10, Train loss: 0.3704387682372345, Val metrics: {'average_precision': np.float64(0.3719217226228415), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.7199655981795703)}
Best val metrics: {'average_precision': np.float64(0.37211064089080803), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.7200551862536059)}
Best test metrics: {'average_precision': np.float64(0.33369004182560236), 'accuracy': 0.8236914600550964, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.7750901442307692)}


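Ultimately, we don't see much of a difference between starting from random weights and starting from a model pre-trained on another entity classification task: the from-scratch model reaches a test ROC-AUC of about 0.78, versus about 0.75 for the finetuned one.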
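
To make the comparison concrete, the snippet below lines up the test metrics reported by the two `driver-top3` runs above (rounded to four decimals) and prints the gap for each metric. These come from a single run with one seed, so differences of this size should be read with caution.

```python
# Test metrics copied from the two runs above, rounded to four decimals.
finetuned = {"average_precision": 0.3112, "accuracy": 0.8237, "f1": 0.0, "roc_auc": 0.7520}
scratch = {"average_precision": 0.3337, "accuracy": 0.8237, "f1": 0.0, "roc_auc": 0.7751}

# Positive delta means the from-scratch model scored higher.
deltas = {name: round(scratch[name] - finetuned[name], 4) for name in finetuned}

for name, delta in deltas.items():
    print(
        f"{name:>18}: finetuned={finetuned[name]:.4f}  "
        f"scratch={scratch[name]:.4f}  delta={delta:+.4f}"
    )
```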

## Question 2: Different expressiveness of node features?
Next, let's take a look at using different embedding models for node features.