# Exploring the Impacts of Architecture and Scale on GNN Performance on Relational Data
By: Joseph Guman, Atindra Jha, and Christopher Pondoc

## Introduction
Welcome back to Relbench! In this tutorial, we'll dive a bit deeper into the benchmark + Relational Deep Learning and explore several choices around architecture, scale, and generalizability. In particular, we'll look to answer the following questions:

1. Can we train our Relational Deep Learning on one entity classification task and expect strong zero-shot performance on another entity classification task? What happens if we finetune the model?
2. How does our choice of using embedding models to generate expressive node features impact our performance on node classification tasks?
3. How can we alter and/or extend the architecture of our existing Relational Deep Learning model to improve performance on different tasks?

This notebook already assumes you've looked through the tutorials on [loading in data](https://github.com/snap-stanford/relbench/blob/main/tutorials/load_data.ipynb) and [training a model](https://github.com/snap-stanford/relbench/blob/main/tutorials/train_model.ipynb), as our walkthrough uses those guides as a launchpad to explore deeper questions. If you haven't had a chance to look through those notebooks, we suggest starting there first.

With all that being said, let's get started!

## Question 1: Can we generalize?
Let's take a look at our first question, which involves looking at whether our Relational Deep Learning model can generalize to other tasks with/without finetuning.

Let's first start by looking setting up Relbench. As with the other tutorials, we're taking a look at the `rel-f1` dataset and focusing on node classification tasks. We'll begin by training a model on the `driver-dnf` task, which predicts whether a driver will not finish a race in the next month.

In [1]:
from src.tasks.tasks import initialize_task, db_to_graph
import torch
from torch.nn import BCEWithLogitsLoss
from torch_geometric.seed import seed_everything

# Set up dataset and task, define metrics and loss
dataset, task, train_table, val_table, test_table = initialize_task(
    "rel-f1", "driver-dnf"
)
loss_fn = BCEWithLogitsLoss()

# Set up device
seed_everything(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


We can then preprocess all of our Relbench data.

In [2]:
import os
from relbench.modeling.graph import make_pkey_fkey_graph
from torch_frame.config.text_embedder import TextEmbedderConfig
from src.embeddings.glove import GloveTextEmbedding

# Preprocess the database data and set up our text embedder
db, col_to_stype_dict = db_to_graph(dataset)
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=128
)

# Load in data used to train model
root_dir = "./data"
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,
    text_embedder_cfg=text_embedder_cfg,
    cache_dir=os.path.join(root_dir, f"rel-f1_materialized_cache"),
)

Loading Database object from /home/cpondoc/.cache/relbench/rel-f1/db...
Done in 0.02 seconds.


  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)


Next, let's load in the data and have our model set up.

In [3]:
from src.models.loader import get_loader
from src.models.rdl.graph_sage import RDLModel

# Set up data loader and model
loader_dict, entity_table = get_loader(train_table, val_table, test_table, task, data)
model = RDLModel(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)

# if you try out different RelBench tasks you will need to change these
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
epochs = 10

Finalize, let's initialize our training run, and evaluate our model!

In [4]:
from src.models.training import eval_model, training_run

# Get model after a training run
state_dict = training_run(
    model, device, optimizer, task, loader_dict, val_table, loss_fn, entity_table
)
model.load_state_dict(state_dict)

# Evaluate on val and test set
eval_model(model, loader_dict, "val", task, device, val_table)
eval_model(model, loader_dict, "test", task, device, None)

100%|██████████| 23/23 [00:02<00:00,  7.83it/s]


Epoch: 01, Train loss: 0.3705387778567783, Val metrics: {'average_precision': np.float64(0.8445602886699217), 'accuracy': 0.7791519434628975, 'f1': np.float64(0.8758689175769613), 'roc_auc': np.float64(0.61340589569161)}


100%|██████████| 23/23 [00:02<00:00,  8.81it/s]


Epoch: 02, Train loss: 0.339241144625481, Val metrics: {'average_precision': np.float64(0.8824246844093944), 'accuracy': 0.734982332155477, 'f1': np.float64(0.8407643312101911), 'roc_auc': np.float64(0.6783673469387754)}


100%|██████████| 23/23 [00:02<00:00,  8.78it/s]


Epoch: 03, Train loss: 0.3132456022974758, Val metrics: {'average_precision': np.float64(0.8877369162685687), 'accuracy': 0.6431095406360424, 'f1': np.float64(0.7548543689320388), 'roc_auc': np.float64(0.6638004535147392)}


100%|██████████| 23/23 [00:02<00:00,  8.71it/s]


Epoch: 04, Train loss: 0.3137629664954457, Val metrics: {'average_precision': np.float64(0.8837337509565886), 'accuracy': 0.6501766784452296, 'f1': np.float64(0.7602905569007264), 'roc_auc': np.float64(0.6612789115646259)}


100%|██████████| 23/23 [00:02<00:00,  8.70it/s]


Epoch: 05, Train loss: 0.3080086231978641, Val metrics: {'average_precision': np.float64(0.8871259518097254), 'accuracy': 0.7279151943462897, 'f1': np.float64(0.8354700854700855), 'roc_auc': np.float64(0.6681360544217687)}


100%|██████████| 23/23 [00:02<00:00,  8.81it/s]


Epoch: 06, Train loss: 0.3053495470402107, Val metrics: {'average_precision': np.float64(0.8891425442587507), 'accuracy': 0.6590106007067138, 'f1': np.float64(0.7721369539551358), 'roc_auc': np.float64(0.6683900226757371)}


100%|██████████| 23/23 [00:02<00:00,  8.85it/s]


Epoch: 07, Train loss: 0.30645355049707795, Val metrics: {'average_precision': np.float64(0.8894602472798984), 'accuracy': 0.6519434628975265, 'f1': np.float64(0.7646356033452808), 'roc_auc': np.float64(0.6683537414965987)}


100%|██████████| 23/23 [00:02<00:00,  8.81it/s]


Epoch: 08, Train loss: 0.3030857765207616, Val metrics: {'average_precision': np.float64(0.8869164475195142), 'accuracy': 0.6378091872791519, 'f1': np.float64(0.7427854454203262), 'roc_auc': np.float64(0.6694603174603174)}


100%|██████████| 23/23 [00:02<00:00,  8.78it/s]


Epoch: 09, Train loss: 0.30097401697110593, Val metrics: {'average_precision': np.float64(0.8869307529447668), 'accuracy': 0.6784452296819788, 'f1': np.float64(0.789838337182448), 'roc_auc': np.float64(0.6650340136054422)}


100%|██████████| 23/23 [00:02<00:00,  8.64it/s]


Epoch: 10, Train loss: 0.29329782412802446, Val metrics: {'average_precision': np.float64(0.8873963155089515), 'accuracy': 0.7120141342756183, 'f1': np.float64(0.812428078250863), 'roc_auc': np.float64(0.6679183673469389)}
Best val metrics: {'average_precision': np.float64(0.8824005571542506), 'accuracy': 0.734982332155477, 'f1': np.float64(0.8407643312101911), 'roc_auc': np.float64(0.6782947845804989)}
Best test metrics: {'average_precision': np.float64(0.8373712236678277), 'accuracy': 0.7264957264957265, 'f1': np.float64(0.8254545454545454), 'roc_auc': np.float64(0.6952617967110721)}


As we can see, we are able to roughly replicate the results from the [core Relbench paper](https://huggingface.co/spaces/relbench/leaderboard). However, do the results generalize? To do so, let's load in the data for the other entity classification task within `rel-f1` -- `driver-top3` -- and see how we do. 

In [5]:
# Reuse functions to set up `driver-top3 task`
dataset, task, train_table, val_table, test_table = initialize_task(
    "rel-f1", "driver-top3"
)
db, col_to_stype_dict = db_to_graph(dataset)
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,
    text_embedder_cfg=text_embedder_cfg,
    cache_dir=os.path.join(root_dir, f"rel-f1_materialized_cache"),
)

loader_dict, entity_table = get_loader(train_table, val_table, test_table, task, data)
model = RDLModel(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)
model.load_state_dict(state_dict)
eval_model(model, loader_dict, "test", task, device, None)

  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)


Best test metrics: {'average_precision': np.float64(0.11380641939650467), 'accuracy': 0.19696969696969696, 'f1': np.float64(0.22781456953642384), 'roc_auc': np.float64(0.23915003135451504)}


Unfortunately, trying out our model zero-shot does not yield amazing results. However, what happens if we use this model as a starting point for finetuning on the task? Let's experiment on fine-tuning this model with fewer epochs on the `driver-top3` task and checking its performance.

In [6]:
# Get model after a training run
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
state_dict = training_run(
    model,
    device,
    optimizer,
    task,
    loader_dict,
    val_table,
    loss_fn,
    entity_table,
    epochs=5,
    state_dict=state_dict,
)
model.load_state_dict(state_dict)

# Evaluate on val and test set
eval_model(model, loader_dict, "val", task, device, val_table)
eval_model(model, loader_dict, "test", task, device, None)

100%|██████████| 3/3 [00:00<00:00,  7.06it/s]


Epoch: 01, Train loss: 1.3057963123871323, Val metrics: {'average_precision': np.float64(0.14596141016259206), 'accuracy': 0.7670068027210885, 'f1': np.float64(0.014388489208633094), 'roc_auc': np.float64(0.3403450932611851)}


100%|██████████| 3/3 [00:00<00:00,  7.28it/s]


Epoch: 02, Train loss: 0.492364590069319, Val metrics: {'average_precision': np.float64(0.147660166536003), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.3480317500134382)}


100%|██████████| 3/3 [00:00<00:00,  7.32it/s]


Epoch: 03, Train loss: 0.4756998234295616, Val metrics: {'average_precision': np.float64(0.18617978634488963), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.5037358226872839)}


100%|██████████| 3/3 [00:00<00:00,  7.45it/s]


Epoch: 04, Train loss: 0.457907124307538, Val metrics: {'average_precision': np.float64(0.24818470383417085), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6028202325706402)}


100%|██████████| 3/3 [00:00<00:00,  7.46it/s]


Epoch: 05, Train loss: 0.449112638555803, Val metrics: {'average_precision': np.float64(0.2582261868388357), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6171722420311408)}
Best val metrics: {'average_precision': np.float64(0.25844071115123446), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.617208077260755)}
Best test metrics: {'average_precision': np.float64(0.2731415142474485), 'accuracy': 0.8236914600550964, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.7204157086120402)}


Nice! It looks like after we finetune even after just one epoch. we're able to practically replicate the Relbench results. Finally, let's compare this approach to simply training on the task from scratch.

In [7]:
# Define a new model, don't load in old weights.
base_model = RDLModel(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)
base_optimizer = torch.optim.Adam(base_model.parameters(), lr=0.005)
base_state_dict = training_run(
    base_model,
    device,
    base_optimizer,
    task,
    loader_dict,
    val_table,
    loss_fn,
    entity_table,
    epochs=10,
)
base_model.load_state_dict(base_state_dict)

# Evaluate on val and test set
eval_model(base_model, loader_dict, "val", task, device, val_table)
eval_model(base_model, loader_dict, "test", task, device, None)

100%|██████████| 3/3 [00:00<00:00,  7.01it/s]


Epoch: 01, Train loss: 0.6766921718386659, Val metrics: {'average_precision': np.float64(0.22167779641641927), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.588683234487825)}


100%|██████████| 3/3 [00:00<00:00,  7.32it/s]


Epoch: 02, Train loss: 0.4552399510819209, Val metrics: {'average_precision': np.float64(0.22607865741856337), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.5989141925426887)}


100%|██████████| 3/3 [00:00<00:00,  7.14it/s]


Epoch: 03, Train loss: 0.4486066358052442, Val metrics: {'average_precision': np.float64(0.22778145495906807), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.5990575334611457)}


100%|██████████| 3/3 [00:00<00:00,  7.34it/s]


Epoch: 04, Train loss: 0.4420464193917871, Val metrics: {'average_precision': np.float64(0.23657871893378135), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6117252871297773)}


100%|██████████| 3/3 [00:00<00:00,  7.31it/s]


Epoch: 05, Train loss: 0.43975554136814055, Val metrics: {'average_precision': np.float64(0.28826219978742496), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6088584687606386)}


100%|██████████| 3/3 [00:00<00:00,  7.26it/s]


Epoch: 06, Train loss: 0.43636006828334184, Val metrics: {'average_precision': np.float64(0.31479855414086527), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.629159126337102)}


100%|██████████| 3/3 [00:00<00:00,  7.25it/s]


Epoch: 07, Train loss: 0.43033227553564973, Val metrics: {'average_precision': np.float64(0.331351335742997), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.6592248839834441)}


100%|██████████| 3/3 [00:00<00:00,  7.39it/s]


Epoch: 08, Train loss: 0.4114072441570158, Val metrics: {'average_precision': np.float64(0.3600801816641911), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.711920589131175)}


100%|██████████| 3/3 [00:00<00:00,  7.37it/s]


Epoch: 09, Train loss: 0.3698024127497465, Val metrics: {'average_precision': np.float64(0.4190964469924286), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.755173711275555)}


100%|██████████| 3/3 [00:00<00:00,  7.37it/s]


Epoch: 10, Train loss: 0.36252130260752996, Val metrics: {'average_precision': np.float64(0.37695138254729166), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.74582071634624)}
Best val metrics: {'average_precision': np.float64(0.4188869101747401), 'accuracy': 0.7976190476190477, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.7550841232015195)}
Best test metrics: {'average_precision': np.float64(0.33543511977785767), 'accuracy': 0.8236914600550964, 'f1': np.float64(0.0), 'roc_auc': np.float64(0.787860576923077)}


Ultimately, we don't see much of a difference from starting from random weights to using a model pre-initialized from another entity classification task.

### Challenge
Does this trend necessarily work on larger and more diverse datasets? Depending on your compute availability, try out using different datasets, like `rel-amazon`, as well as across different types of tasks!

## Question 2: Different expressiveness of node features?
Next, let's take a look at using different embedding models for node features.

The embedding models are used to help turn the tabular data into usable node features. In the Relbench tutorial, the team uses GloVe embeddings, but the paper also mentions utilizing BERT-style embeddings. In traditional NLP, BERT embeddings are much more popular given that they are contextual -- the vector representation depends on the surrounding words, compared to static embeddings used by GloVe -- and can handle words outside of their vocabulary. In addition, their embedding size is $768$ compared to GloVe's $300$, which introduces an opportunity for more expressiveness.

As an investigation, let's switch out our GloVe embedding model with BERT and retrain a new model from scratch on the `driver-dnf` task.

In [8]:
from src.embeddings.bert import BertTextEmbedding

dataset, task, train_table, val_table, test_table = initialize_task(
    "rel-f1", "driver-dnf"
)

# Preprocess the database data and set up our text embedder
db, col_to_stype_dict = db_to_graph(dataset)
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=BertTextEmbedding(device=device), batch_size=128
)

# Load in data used to train model
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,
    text_embedder_cfg=text_embedder_cfg,
    cache_dir=os.path.join(root_dir, f"rel-f1_materialized_cache"),
)
loader_dict, entity_table = get_loader(train_table, val_table, test_table, task, data)

# Initialize new, untrained model using BERT embeddings
bert_model = RDLModel(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)
bert_optimizer = torch.optim.Adam(bert_model.parameters(), lr=0.005)
bert_state_dict = training_run(
    bert_model,
    device,
    bert_optimizer,
    task,
    loader_dict,
    val_table,
    loss_fn,
    entity_table,
    epochs=10,
)
bert_model.load_state_dict(bert_state_dict)

# Evaluate on val and test set
eval_model(bert_model, loader_dict, "val", task, device, val_table)
eval_model(bert_model, loader_dict, "test", task, device, None)

  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
100%|██████████| 23/23 [00:02<00:00,  8.62it/s]


Epoch: 01, Train loss: 0.38312724910414653, Val metrics: {'average_precision': np.float64(0.8345221417962831), 'accuracy': 0.7791519434628975, 'f1': np.float64(0.8758689175769613), 'roc_auc': np.float64(0.5930340136054422)}


100%|██████████| 23/23 [00:02<00:00,  8.63it/s]


Epoch: 02, Train loss: 0.3540534934752266, Val metrics: {'average_precision': np.float64(0.8473015448671046), 'accuracy': 0.7791519434628975, 'f1': np.float64(0.8758689175769613), 'roc_auc': np.float64(0.6246530612244898)}


100%|██████████| 23/23 [00:02<00:00,  8.65it/s]


Epoch: 03, Train loss: 0.3357846071938254, Val metrics: {'average_precision': np.float64(0.8905183265671703), 'accuracy': 0.7491166077738516, 'f1': np.float64(0.8476394849785408), 'roc_auc': np.float64(0.6827029478458049)}


100%|██████████| 23/23 [00:02<00:00,  8.56it/s]


Epoch: 04, Train loss: 0.3116109115942032, Val metrics: {'average_precision': np.float64(0.8965321535022579), 'accuracy': 0.6607773851590106, 'f1': np.float64(0.7664233576642335), 'roc_auc': np.float64(0.687201814058957)}


100%|██████████| 23/23 [00:02<00:00,  8.46it/s]


Epoch: 05, Train loss: 0.30605150242752743, Val metrics: {'average_precision': np.float64(0.8985108863522302), 'accuracy': 0.696113074204947, 'f1': np.float64(0.7971698113207547), 'roc_auc': np.float64(0.697342403628118)}


100%|██████████| 23/23 [00:02<00:00,  8.58it/s]


Epoch: 06, Train loss: 0.3026287253466489, Val metrics: {'average_precision': np.float64(0.9018481327312416), 'accuracy': 0.7491166077738516, 'f1': np.float64(0.8517745302713987), 'roc_auc': np.float64(0.6974149659863945)}


100%|██████████| 23/23 [00:02<00:00,  8.64it/s]


Epoch: 07, Train loss: 0.2972710859361847, Val metrics: {'average_precision': np.float64(0.900533309282381), 'accuracy': 0.7402826855123675, 'f1': np.float64(0.8368479467258602), 'roc_auc': np.float64(0.7018231292517008)}


100%|██████████| 23/23 [00:02<00:00,  8.56it/s]


Epoch: 08, Train loss: 0.2959926546941494, Val metrics: {'average_precision': np.float64(0.9012056339689314), 'accuracy': 0.7614840989399293, 'f1': np.float64(0.8524590163934426), 'roc_auc': np.float64(0.7092607709750567)}


100%|██████████| 23/23 [00:02<00:00,  8.62it/s]


Epoch: 09, Train loss: 0.2905985877021002, Val metrics: {'average_precision': np.float64(0.9032973199191973), 'accuracy': 0.7526501766784452, 'f1': np.float64(0.8409090909090909), 'roc_auc': np.float64(0.7175328798185941)}


100%|██████████| 23/23 [00:02<00:00,  8.61it/s]


Epoch: 10, Train loss: 0.2868953991673829, Val metrics: {'average_precision': np.float64(0.9080292417308474), 'accuracy': 0.7067137809187279, 'f1': np.float64(0.7995169082125604), 'roc_auc': np.float64(0.7248616780045352)}
Best val metrics: {'average_precision': np.float64(0.9080472407878784), 'accuracy': 0.7067137809187279, 'f1': np.float64(0.7995169082125604), 'roc_auc': np.float64(0.7248979591836734)}
Best test metrics: {'average_precision': np.float64(0.84692732763889), 'accuracy': 0.7222222222222222, 'f1': np.float64(0.7940865892291447), 'roc_auc': np.float64(0.731849899965842)}


We ultimately don't see that drastic of a difference between using BERT embeddings and GloVe embeddings. Despite being trained differently, the fact that the models are close in size and perform similarly on [general embedding benchmarks](https://huggingface.co/spaces/mteb/leaderboard) may suggest that the results will not be that drastic. 

### Challenge
We encourage you to try larger models with even larger embedding dimensions -- to do so, use our `CustomTextEmbedding` class! To use this class, import it as below, and then specify the name of a model as used on HuggingFace:

```python
from src.embeddings.custom import CustomTextEmbedding
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=CustomTextEmbedding(model_name=<INSERT_HUGGINGFACE_MODEL_HERE>, device=device), batch_size=128
)
```

## Question 3: Different RDL model architectures?
Finally, we experiment with different RDL model architectures. In particular, we investigate what happens as we add or subtract GNN layers from our model.

First, we double the number of GNN layers in our RDL pipeline, moving from `num_layers=2` to `num_layers=4`. The idea is that by adding more layers, we can create a more expressive network that can understand more complex relationships.

In [9]:
# Define a new model, don't load in old weights.
deep_model = RDLModel(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=4,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)
deep_optimizer = torch.optim.Adam(deep_model.parameters(), lr=0.005)
deep_state_dict = training_run(
    deep_model,
    device,
    deep_optimizer,
    task,
    loader_dict,
    val_table,
    loss_fn,
    entity_table,
    epochs=10,
)
deep_model.load_state_dict(deep_state_dict)

# Evaluate on val and test set
eval_model(deep_model, loader_dict, "val", task, device, val_table)
eval_model(deep_model, loader_dict, "test", task, device, None)

100%|██████████| 23/23 [00:03<00:00,  6.44it/s]


Epoch: 01, Train loss: 0.37630944748345185, Val metrics: {'average_precision': np.float64(0.8354548486707016), 'accuracy': 0.7791519434628975, 'f1': np.float64(0.8758689175769613), 'roc_auc': np.float64(0.584453514739229)}


100%|██████████| 23/23 [00:03<00:00,  6.49it/s]


Epoch: 02, Train loss: 0.35387199274831915, Val metrics: {'average_precision': np.float64(0.8368606888184498), 'accuracy': 0.7791519434628975, 'f1': np.float64(0.8758689175769613), 'roc_auc': np.float64(0.5990204081632653)}


100%|██████████| 23/23 [00:03<00:00,  6.47it/s]


Epoch: 03, Train loss: 0.35184785197424895, Val metrics: {'average_precision': np.float64(0.8451311245963857), 'accuracy': 0.7791519434628975, 'f1': np.float64(0.8758689175769613), 'roc_auc': np.float64(0.614766439909297)}


100%|██████████| 23/23 [00:03<00:00,  6.44it/s]


Epoch: 04, Train loss: 0.3307221185266789, Val metrics: {'average_precision': np.float64(0.88688278486892), 'accuracy': 0.7279151943462897, 'f1': np.float64(0.8378947368421052), 'roc_auc': np.float64(0.6512834467120181)}


100%|██████████| 23/23 [00:03<00:00,  6.45it/s]


Epoch: 05, Train loss: 0.31274610553330184, Val metrics: {'average_precision': np.float64(0.8841466762743575), 'accuracy': 0.7120141342756183, 'f1': np.float64(0.8190899001109878), 'roc_auc': np.float64(0.6808344671201814)}


100%|██████████| 23/23 [00:03<00:00,  6.43it/s]


Epoch: 06, Train loss: 0.3114925918194866, Val metrics: {'average_precision': np.float64(0.8896977751836486), 'accuracy': 0.7049469964664311, 'f1': np.float64(0.8138238573021181), 'roc_auc': np.float64(0.6868208616780045)}


100%|██████████| 23/23 [00:03<00:00,  6.46it/s]


Epoch: 07, Train loss: 0.30794541325559543, Val metrics: {'average_precision': np.float64(0.888868560924413), 'accuracy': 0.7014134275618374, 'f1': np.float64(0.8072976054732041), 'roc_auc': np.float64(0.6837913832199546)}


100%|██████████| 23/23 [00:03<00:00,  6.42it/s]


Epoch: 08, Train loss: 0.3053583731634382, Val metrics: {'average_precision': np.float64(0.8928719671015126), 'accuracy': 0.7102473498233216, 'f1': np.float64(0.8144796380090498), 'roc_auc': np.float64(0.6915374149659864)}


100%|██████████| 23/23 [00:03<00:00,  6.45it/s]


Epoch: 09, Train loss: 0.305654157629858, Val metrics: {'average_precision': np.float64(0.8915923514845823), 'accuracy': 0.7402826855123675, 'f1': np.float64(0.842443729903537), 'roc_auc': np.float64(0.690485260770975)}


100%|██████████| 23/23 [00:03<00:00,  6.45it/s]


Epoch: 10, Train loss: 0.3016997343055749, Val metrics: {'average_precision': np.float64(0.9017659218757775), 'accuracy': 0.7420494699646644, 'f1': np.float64(0.8459915611814346), 'roc_auc': np.float64(0.7103673469387755)}
Best val metrics: {'average_precision': np.float64(0.9016887622946698), 'accuracy': 0.7420494699646644, 'f1': np.float64(0.8459915611814346), 'roc_auc': np.float64(0.7101133786848073)}
Best test metrics: {'average_precision': np.float64(0.8500868239437355), 'accuracy': 0.7236467236467237, 'f1': np.float64(0.8239564428312159), 'roc_auc': np.float64(0.7235592641389743)}


As we see, training using $4$ layers actually makes the model perform worse over time. Thus, given the simplicity of the task and the size of dataset, it is likely that we are overfitting.

Given that using double the amount of layers leads to less optimal results, we can try the opposite strategy and halve the number of layers in the network.

In [10]:
# Define a new model, don't load in old weights.
shallow_model = RDLModel(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=1,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)
shallow_optimizer = torch.optim.Adam(shallow_model.parameters(), lr=0.005)
shallow_state_dict = training_run(
    shallow_model,
    device,
    shallow_optimizer,
    task,
    loader_dict,
    val_table,
    loss_fn,
    entity_table,
    epochs=10,
)
shallow_model.load_state_dict(shallow_state_dict)

# Evaluate on val and test set
eval_model(shallow_model, loader_dict, "val", task, device, val_table)
eval_model(shallow_model, loader_dict, "test", task, device, None)

100%|██████████| 23/23 [00:02<00:00, 11.25it/s]


Epoch: 01, Train loss: 0.36102279100344126, Val metrics: {'average_precision': np.float64(0.8693253390608777), 'accuracy': 0.7544169611307421, 'f1': np.float64(0.8544502617801047), 'roc_auc': np.float64(0.6637460317460318)}


100%|██████████| 23/23 [00:02<00:00, 11.23it/s]


Epoch: 02, Train loss: 0.3243081846461581, Val metrics: {'average_precision': np.float64(0.8885520335913858), 'accuracy': 0.6908127208480566, 'f1': np.float64(0.7943595769682726), 'roc_auc': np.float64(0.6762811791383221)}


100%|██████████| 23/23 [00:02<00:00, 11.21it/s]


Epoch: 03, Train loss: 0.3095765669996069, Val metrics: {'average_precision': np.float64(0.8982087752254366), 'accuracy': 0.726148409893993, 'f1': np.float64(0.8324324324324325), 'roc_auc': np.float64(0.695219954648526)}


100%|██████████| 23/23 [00:02<00:00, 11.25it/s]


Epoch: 04, Train loss: 0.30494005722550777, Val metrics: {'average_precision': np.float64(0.9013646492474375), 'accuracy': 0.7084805653710248, 'f1': np.float64(0.8096885813148789), 'roc_auc': np.float64(0.7025850340136055)}


100%|██████████| 23/23 [00:02<00:00, 11.12it/s]


Epoch: 05, Train loss: 0.3002260228185271, Val metrics: {'average_precision': np.float64(0.9037170583963824), 'accuracy': 0.7014134275618374, 'f1': np.float64(0.8037166085946573), 'roc_auc': np.float64(0.7048526077097506)}


100%|██████████| 23/23 [00:02<00:00, 11.21it/s]


Epoch: 06, Train loss: 0.29445440567727577, Val metrics: {'average_precision': np.float64(0.9013048280403471), 'accuracy': 0.7685512367491166, 'f1': np.float64(0.8659160696008188), 'roc_auc': np.float64(0.7074648526077097)}


100%|██████████| 23/23 [00:02<00:00, 11.14it/s]


Epoch: 07, Train loss: 0.29507051193476996, Val metrics: {'average_precision': np.float64(0.9016249706950684), 'accuracy': 0.6802120141342756, 'f1': np.float64(0.7757125154894672), 'roc_auc': np.float64(0.7034195011337868)}


100%|██████████| 23/23 [00:02<00:00, 11.19it/s]


Epoch: 08, Train loss: 0.290469367433499, Val metrics: {'average_precision': np.float64(0.9050141416982755), 'accuracy': 0.715547703180212, 'f1': np.float64(0.8151549942594719), 'roc_auc': np.float64(0.7138866213151928)}


100%|██████████| 23/23 [00:02<00:00, 11.08it/s]


Epoch: 09, Train loss: 0.2851247564278104, Val metrics: {'average_precision': np.float64(0.9042633733935974), 'accuracy': 0.7226148409893993, 'f1': np.float64(0.8205714285714286), 'roc_auc': np.float64(0.7125260770975057)}


100%|██████████| 23/23 [00:02<00:00, 11.20it/s]


Epoch: 10, Train loss: 0.2823277630630775, Val metrics: {'average_precision': np.float64(0.9052831608611303), 'accuracy': 0.7208480565371025, 'f1': np.float64(0.8240534521158129), 'roc_auc': np.float64(0.7118185941043083)}
Best val metrics: {'average_precision': np.float64(0.9050579385447367), 'accuracy': 0.715547703180212, 'f1': np.float64(0.8151549942594719), 'roc_auc': np.float64(0.7141224489795919)}
Best test metrics: {'average_precision': np.float64(0.8425754218954633), 'accuracy': 0.7193732193732194, 'f1': np.float64(0.8170844939647168), 'roc_auc': np.float64(0.6876982384228763)}


Interestingly enough, we see that even with half the number of layers, we do just about the same as with double the number of layers. Once again, this might be more task-specific as opposed to a general conclusion about GNNs and the RDL pipeline.

### Challenge
Does this trend necessarily hold on larger and more diverse datasets? Depending on your compute availability, try out using different datasets, like `rel-amazon`, as well as across different types of tasks!

## Question 4: What about different graph layers?
We can also just try using different graph layers, featured in PyG. Below is GCN.

In [12]:
# Define a new model, don't load in old weights.
from src.models.rdl.gat import RDLGATModel
gcn_model = RDLGATModel(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)
gcn_optimizer = torch.optim.Adam(gcn_model.parameters(), lr=0.005)
gcn_state_dict = training_run(
    gcn_model,
    device,
    gcn_optimizer,
    task,
    loader_dict,
    val_table,
    loss_fn,
    entity_table,
    epochs=10,
)
gcn_model.load_state_dict(gcn_state_dict)

# Evaluate on val and test set
eval_model(gcn_model, loader_dict, "val", task, device, val_table)
eval_model(gcn_model, loader_dict, "test", task, device, None)

100%|██████████| 23/23 [00:03<00:00,  6.68it/s]


Epoch: 01, Train loss: 0.38052082677610854, Val metrics: {'average_precision': np.float64(0.8433403727188312), 'accuracy': 0.7791519434628975, 'f1': np.float64(0.8758689175769613), 'roc_auc': np.float64(0.6251519274376416)}


100%|██████████| 23/23 [00:03<00:00,  7.23it/s]


Epoch: 02, Train loss: 0.3574965990803618, Val metrics: {'average_precision': np.float64(0.8825352397919566), 'accuracy': 0.7791519434628975, 'f1': np.float64(0.8758689175769613), 'roc_auc': np.float64(0.69340589569161)}


100%|██████████| 23/23 [00:03<00:00,  7.08it/s]


Epoch: 03, Train loss: 0.32906677605978146, Val metrics: {'average_precision': np.float64(0.8893162726600771), 'accuracy': 0.7791519434628975, 'f1': np.float64(0.8758689175769613), 'roc_auc': np.float64(0.6743854875283447)}


100%|██████████| 23/23 [00:03<00:00,  7.21it/s]


Epoch: 04, Train loss: 0.3092604429848657, Val metrics: {'average_precision': np.float64(0.8976818695773845), 'accuracy': 0.7526501766784452, 'f1': np.float64(0.855072463768116), 'roc_auc': np.float64(0.7134512471655329)}


100%|██████████| 23/23 [00:03<00:00,  7.28it/s]


Epoch: 05, Train loss: 0.29846293777005406, Val metrics: {'average_precision': np.float64(0.9036445197945777), 'accuracy': 0.7544169611307421, 'f1': np.float64(0.8500539374325782), 'roc_auc': np.float64(0.7290158730158729)}


100%|██████████| 23/23 [00:03<00:00,  7.26it/s]


Epoch: 06, Train loss: 0.2910831387073713, Val metrics: {'average_precision': np.float64(0.883193775163816), 'accuracy': 0.7579505300353356, 'f1': np.float64(0.8448471121177803), 'roc_auc': np.float64(0.693451247165533)}


100%|██████████| 23/23 [00:03<00:00,  7.11it/s]


Epoch: 07, Train loss: 0.2870623269425654, Val metrics: {'average_precision': np.float64(0.8861264658904081), 'accuracy': 0.7614840989399293, 'f1': np.float64(0.8488241881298992), 'roc_auc': np.float64(0.695655328798186)}


100%|██████████| 23/23 [00:03<00:00,  6.92it/s]


Epoch: 08, Train loss: 0.28393651492160454, Val metrics: {'average_precision': np.float64(0.8972221333887909), 'accuracy': 0.7915194346289752, 'f1': np.float64(0.8739316239316239), 'roc_auc': np.float64(0.7205442176870749)}


100%|██████████| 23/23 [00:03<00:00,  7.02it/s]


Epoch: 09, Train loss: 0.28227804249002203, Val metrics: {'average_precision': np.float64(0.8912351066454262), 'accuracy': 0.765017667844523, 'f1': np.float64(0.8533627342888643), 'roc_auc': np.float64(0.7075736961451247)}


100%|██████████| 23/23 [00:03<00:00,  7.04it/s]


Epoch: 10, Train loss: 0.2759197559206045, Val metrics: {'average_precision': np.float64(0.8906311165803125), 'accuracy': 0.7809187279151943, 'f1': np.float64(0.8655097613882863), 'roc_auc': np.float64(0.704843537414966)}
Best val metrics: {'average_precision': np.float64(0.9036447759347126), 'accuracy': 0.7544169611307421, 'f1': np.float64(0.8500539374325782), 'roc_auc': np.float64(0.729015873015873)}
Best test metrics: {'average_precision': np.float64(0.8237305452344359), 'accuracy': 0.6709401709401709, 'f1': np.float64(0.780209324452902), 'roc_auc': np.float64(0.6854047723612942)}
