<a href="https://colab.research.google.com/github/neel26desai/cmpe297_timegpt_tabula_gnn/blob/main/GAN_with_RelBench.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Overview

RelBench is a comprehensive benchmark designed to facilitate research in deep learning on relational databases. It aims to standardize and streamline the development and evaluation of models that work with data spread across multiple interconnected tables, which is a common structure in relational databases. This benchmark is particularly focused on relational deep learning (RDL), an emerging approach that leverages graph neural networks (GNNs) to model relationships between entities in such databases.

In [None]:
# Install required packages.
!pip install torch==2.4.0
!pip install torch-geometric torch-sparse torch-scatter torch-cluster torch-spline-conv pyg-lib -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
!pip install pytorch_frame
!pip install relbench

In [None]:
!pip install torchvision==0.19.0
!pip install -U sentence-transformers

In [1]:
import os
import torch
import relbench

relbench.__version__

'1.1.0'

In [2]:
import numpy as np

from torch.nn import BCEWithLogitsLoss, L1Loss
from relbench.datasets import get_dataset
from relbench.tasks import get_task


Loading the `rel-f1` dataset and its `driver-position` task from RelBench. This is a regression task, so we use an L1 loss and tune on MAE, where lower is better. In later cells the relational data is converted into a graph in which nodes represent entities (rows in tables) and edges represent relationships (foreign-key links).

In [3]:
dataset = get_dataset("rel-f1", download=True)
task = get_task("rel-f1", "driver-position", download=True)

train_table = task.get_table("train")
val_table = task.get_table("val")
test_table = task.get_table("test")

out_channels = 1
loss_fn = L1Loss()
tune_metric = "mae"
higher_is_better = False
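Since the task is tuned on mean absolute error, it may help to see what that quantity is. The following is a minimal plain-Python sketch, not the PyTorch implementation; the helper name `mean_absolute_error` and the sample values are illustrative assumptions.

```python
from typing import Sequence


def mean_absolute_error(pred: Sequence[float], target: Sequence[float]) -> float:
    """Average of |pred - target| over all examples."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


# Hypothetical predicted vs. actual finishing positions for three drivers.
example_mae = mean_absolute_error([10.0, 12.0, 15.0], [10.75, 12.0, 13.0])
print(example_mae)
```

An MAE of about 0.92 here would mean the predictions are off by a little under one position on average.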

In [4]:
train_table

Table(df=
           date  driverId  position
0    2004-07-05        10     10.75
1    2004-07-05        47     12.00
2    2004-03-07         7     15.00
3    2004-01-07        10      9.00
4    2003-09-09        52     13.00
...         ...       ...       ...
7448 1995-08-22        96     15.75
7449 1975-06-08       228      8.00
7450 1965-05-31       418     16.00
7451 1961-08-20       467     37.00
7452 1954-05-29       677     30.00

[7453 rows x 3 columns],
  fkey_col_to_pkey_table={'driverId': 'drivers'},
  pkey_col=None,
  time_col=date)

In [5]:
import math
import numpy as np
from tqdm import tqdm

import torch
import torch_geometric
import torch_frame

# Some book keeping
from torch_geometric.seed import seed_everything

seed_everything(42)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)  # check that it's cuda if you want it to run in reasonable time!
root_dir = "./data"

cuda


Loading the full database (all tables related to the driver data) and proposing a semantic type (stype) for every column

In [6]:
from relbench.modeling.utils import get_stype_proposal

db = dataset.get_db()
col_to_stype_dict = get_stype_proposal(db)
col_to_stype_dict

Loading Database object from /root/.cache/relbench/rel-f1/db...
Done in 0.04 seconds.


{'circuits': {'circuitId': <stype.numerical: 'numerical'>,
  'circuitRef': <stype.text_embedded: 'text_embedded'>,
  'name': <stype.text_embedded: 'text_embedded'>,
  'location': <stype.text_embedded: 'text_embedded'>,
  'country': <stype.text_embedded: 'text_embedded'>,
  'lat': <stype.numerical: 'numerical'>,
  'lng': <stype.numerical: 'numerical'>,
  'alt': <stype.numerical: 'numerical'>},
 'races': {'raceId': <stype.numerical: 'numerical'>,
  'year': <stype.categorical: 'categorical'>,
  'round': <stype.numerical: 'numerical'>,
  'circuitId': <stype.numerical: 'numerical'>,
  'name': <stype.text_embedded: 'text_embedded'>,
  'date': <stype.timestamp: 'timestamp'>,
  'time': <stype.timestamp: 'timestamp'>},
 'results': {'resultId': <stype.numerical: 'numerical'>,
  'raceId': <stype.numerical: 'numerical'>,
  'driverId': <stype.numerical: 'numerical'>,
  'constructorId': <stype.numerical: 'numerical'>,
  'number': <stype.numerical: 'numerical'>,
  'grid': <stype.numerical: 'numerical

Using sentence-transformers/average_word_embeddings_glove.6B.300d as the embedding model for the text columns

In [7]:
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from torch import Tensor


class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device] = None):
        self.model = SentenceTransformer(
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return torch.from_numpy(self.model.encode(sentences))
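As the model name suggests, this embedder represents a sentence by averaging the vectors of its words. Below is a minimal sketch of that idea under stated assumptions: the 3-dimensional vectors in `TOY_VECTORS` and the helper `average_word_embedding` are made up for illustration, and the real model uses 300-dimensional GloVe vectors with its own tokenizer.

```python
from typing import Dict, List

# Hypothetical 3-dimensional word vectors (assumption, not real GloVe values).
TOY_VECTORS: Dict[str, List[float]] = {
    "monaco": [1.0, 0.0, 2.0],
    "grand": [0.0, 2.0, 0.0],
    "prix": [2.0, 1.0, 1.0],
}


def average_word_embedding(sentence: str, dim: int = 3) -> List[float]:
    """Embed a sentence as the mean of the vectors of its known words."""
    vectors = [TOY_VECTORS[w] for w in sentence.lower().split() if w in TOY_VECTORS]
    if not vectors:
        return [0.0] * dim  # no known words: fall back to a zero vector
    return [sum(column) / len(vectors) for column in zip(*vectors)]


print(average_word_embedding("Monaco Grand Prix"))
```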



Representing the tabular data as graph data

In [8]:
from torch_frame.config.text_embedder import TextEmbedderConfig
from relbench.modeling.graph import make_pkey_fkey_graph

text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=256
)

data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # specified column types
    text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder
    cache_dir=os.path.join(
        root_dir, "rel-f1_materialized_cache"
    ),  # store materialized graph for convenience
)


In [9]:
data

HeteroData(
  circuits={ tf=TensorFrame([77, 7]) },
  races={
    tf=TensorFrame([820, 5]),
    time=[820],
  },
  results={
    tf=TensorFrame([20323, 11]),
    time=[20323],
  },
  constructor_standings={
    tf=TensorFrame([10170, 4]),
    time=[10170],
  },
  drivers={ tf=TensorFrame([857, 6]) },
  constructor_results={
    tf=TensorFrame([9408, 2]),
    time=[9408],
  },
  constructors={ tf=TensorFrame([211, 3]) },
  standings={
    tf=TensorFrame([28115, 4]),
    time=[28115],
  },
  qualifying={
    tf=TensorFrame([4082, 3]),
    time=[4082],
  },
  (races, f2p_circuitId, circuits)={ edge_index=[2, 820] },
  (circuits, rev_f2p_circuitId, races)={ edge_index=[2, 820] },
  (results, f2p_raceId, races)={ edge_index=[2, 20323] },
  (races, rev_f2p_raceId, results)={ edge_index=[2, 20323] },
  (results, f2p_driverId, drivers)={ edge_index=[2, 20323] },
  (drivers, rev_f2p_driverId, results)={ edge_index=[2, 20323] },
  (results, f2p_constructorId, constructors)={ edge_index=[2, 20323

In [10]:
data["races"].tf

TensorFrame(
  num_cols=5,
  num_rows=820,
  categorical (1): ['year'],
  numerical (1): ['round'],
  timestamp (2): ['date', 'time'],
  embedding (1): ['name'],
  has_target=False,
  device='cpu',
)

In [11]:
list(data["races"].keys())

['tf', 'time']

In [12]:
data["races"].tf[10]

TensorFrame(
  num_cols=5,
  num_rows=1,
  categorical (1): ['year'],
  numerical (1): ['round'],
  timestamp (2): ['date', 'time'],
  embedding (1): ['name'],
  has_target=False,
  device='cpu',
)

In [13]:
data["races"].tf[10:20]

TensorFrame(
  num_cols=5,
  num_rows=10,
  categorical (1): ['year'],
  numerical (1): ['round'],
  timestamp (2): ['date', 'time'],
  embedding (1): ['name'],
  has_target=False,
  device='cpu',
)

In [14]:
data[("races", "f2p_circuitId", "circuits")]

{'edge_index': tensor([[  0,   1,   2,  ..., 817, 818, 819],
        [  8,   5,  18,  ...,  21,  17,  23]])}
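The `edge_index` above pairs each race row (first row of the tensor) with the circuit row it points to (second row). A rough sketch of how such an index could be derived from a foreign-key column is shown below. The miniature tables and the helper `build_fkey_edges` are hypothetical and are not RelBench's implementation.

```python
from typing import Dict, List, Tuple

# Hypothetical miniature tables: each race row stores the primary key of its circuit.
circuits = [{"circuitId": 10}, {"circuitId": 20}]
races = [
    {"raceId": 1, "circuitId": 20},
    {"raceId": 2, "circuitId": 10},
    {"raceId": 3, "circuitId": 20},
]


def build_fkey_edges(
    child_rows: List[dict], parent_rows: List[dict], fkey: str, pkey: str
) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (forward, reverse) edge indices, using row positions as node ids."""
    pkey_to_row: Dict[int, int] = {row[pkey]: i for i, row in enumerate(parent_rows)}
    src = list(range(len(child_rows)))
    dst = [pkey_to_row[row[fkey]] for row in child_rows]
    return [src, dst], [dst, src]


forward, reverse = build_fkey_edges(races, circuits, "circuitId", "circuitId")
print(forward)  # races -> circuits, analogous to the f2p_circuitId edge type
print(reverse)  # circuits -> races, analogous to the rev_f2p_circuitId edge type
```

Every child row contributes exactly one edge, which is why the race-to-circuit edge count (820) equals the number of races.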

Building neighbor-sampling data loaders for the train, validation, and test splits

In [15]:
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader

loader_dict = {}

for split, table in [
    ("train", train_table),
    ("val", val_table),
    ("test", test_table),
]:
    table_input = get_node_train_table_input(
        table=table,
        task=task,
    )
    entity_table = table_input.nodes[0]
    loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=[
            128 for i in range(2)
        ],  # we sample subgraphs of depth 2, 128 neighbors per node.
        time_attr="time",
        input_nodes=table_input.nodes,
        input_time=table_input.time,
        transform=table_input.transform,
        batch_size=512,
        temporal_strategy="uniform",
        shuffle=split == "train",
        num_workers=0,
        persistent_workers=False,
    )
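The loader samples a bounded number of neighbors per node and uses `time_attr` so that a seed only sees neighbors that are not from its future. The sketch below illustrates one hop of that idea in plain Python. The function `sample_temporal_neighbors` and the toy history are assumptions for illustration, not the PyG sampler.

```python
import random
from typing import List, Tuple


def sample_temporal_neighbors(
    neighbors: List[Tuple[int, int]],  # (neighbor_id, timestamp) pairs
    seed_time: int,
    num_neighbors: int,
    rng: random.Random,
) -> List[int]:
    """Uniformly sample up to num_neighbors neighbors with timestamp <= seed_time."""
    eligible = [nid for nid, t in neighbors if t <= seed_time]
    if len(eligible) <= num_neighbors:
        return eligible  # fewer candidates than the budget: keep them all
    return rng.sample(eligible, num_neighbors)


rng = random.Random(42)
history = [(0, 1990), (1, 1995), (2, 2001), (3, 2004), (4, 2010)]
sampled = sample_temporal_neighbors(history, seed_time=2004, num_neighbors=2, rng=rng)
print(sampled)
```

Neighbor 4 (timestamp 2010) can never be sampled for a seed at 2004, which is what prevents information from the future leaking into a prediction.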

Specifying the Model architecture using PyTorch

In [16]:
from torch.nn import BCEWithLogitsLoss
import copy
from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder


class Model(torch.nn.Module):

    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        # List of node types to add shallow embeddings to input
        shallow_list: List[NodeType] = [],
        # ID awareness
        id_awareness: bool = False,
    ):
        super().__init__()

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
        )
        self.gnn = HeteroGraphSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )

        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self):
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )

        return self.head(x_dict[entity_table][: seed_time.size(0)])

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)
        # Add ID-awareness to the root node
        x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        return self.head(x_dict[dst_table])
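In `forward`, the temporal encoder receives the seed times, the per-node times, and the batch assignment of each sampled node. One plausible reading is that it encodes how far in the past each node lies relative to its seed. The sketch below computes only that relative time. The function name, the use of Unix seconds, and the scaling to days are assumptions, and the learned encoding applied afterwards is omitted.

```python
from typing import List

SECONDS_PER_DAY = 60 * 60 * 24


def relative_time_in_days(
    seed_time: List[int], node_time: List[int], batch: List[int]
) -> List[float]:
    """Gap in days between each sampled node's time and the time of its seed."""
    return [(seed_time[b] - t) / SECONDS_PER_DAY for t, b in zip(node_time, batch)]


# Two seeds; three sampled nodes, the first two belonging to seed 0 (toy values).
seed_time = [10 * SECONDS_PER_DAY, 20 * SECONDS_PER_DAY]
node_time = [7 * SECONDS_PER_DAY, 10 * SECONDS_PER_DAY, 5 * SECONDS_PER_DAY]
batch = [0, 0, 1]
print(relative_time_in_days(seed_time, node_time, batch))
```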


In [17]:
model = Model(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)


# if you try out different RelBench tasks you will need to change these
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
epochs = 10
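The model is created with `aggr="sum"`, meaning each node combines its incoming neighbor messages by summing them. A minimal sketch of sum aggregation over an edge list follows. The helper `sum_aggregate` and the toy features are hypothetical, and a real GraphSAGE layer also applies learned linear transforms that are left out here.

```python
from typing import List, Tuple


def sum_aggregate(
    features: List[List[float]], edges: List[Tuple[int, int]]
) -> List[List[float]]:
    """For every node, sum the feature vectors of its incoming neighbors."""
    dim = len(features[0])
    out = [[0.0] * dim for _ in features]
    for src, dst in edges:
        for k in range(dim):
            out[dst][k] += features[src][k]
    return out


# Three nodes with 2-d features; nodes 0 and 1 both send a message to node 2.
toy_features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
toy_edges = [(0, 2), (1, 2)]
print(sum_aggregate(toy_features, toy_edges))
```

Unlike a mean, a sum grows with the number of neighbors, so it keeps information about how many related rows an entity has.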

Defining functions for training and testing

In [18]:
def train() -> float:
    model.train()

    loss_accum = count_accum = 0
    for batch in tqdm(loader_dict["train"]):
        batch = batch.to(device)

        optimizer.zero_grad()
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred

        loss = loss_fn(pred.float(), batch[task.entity_table].y.float())
        loss.backward()
        optimizer.step()

        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    return loss_accum / count_accum


@torch.no_grad()
def test(loader: NeighborLoader) -> np.ndarray:
    model.eval()

    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0).numpy()
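`train()` multiplies each batch loss by the batch size before accumulating, so the returned value is a per-example mean even when the last batch is smaller. A small sketch of that bookkeeping, with a hypothetical helper `epoch_mean_loss` and made-up batch values:

```python
from typing import List, Tuple


def epoch_mean_loss(batches: List[Tuple[float, int]]) -> float:
    """Combine per-batch mean losses into a per-example mean, weighted by batch size."""
    loss_accum = 0.0
    count_accum = 0
    for batch_loss, batch_size in batches:
        loss_accum += batch_loss * batch_size
        count_accum += batch_size
    return loss_accum / count_accum


# Hypothetical (mean loss, batch size) pairs; the last batch is smaller.
toy_batches = [(4.0, 512), (6.0, 512), (10.0, 256)]
print(epoch_mean_loss(toy_batches))
```

A plain average of the three batch losses would be about 6.67 and would overweight the small final batch; the weighted version gives 6.0.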

Training the model

In [19]:
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
for epoch in range(1, epochs + 1):
    train_loss = train()
    val_pred = test(loader_dict["val"])
    val_metrics = task.evaluate(val_pred, val_table)
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(model.state_dict())


model.load_state_dict(state_dict)
val_pred = test(loader_dict["val"])
val_metrics = task.evaluate(val_pred, val_table)
print(f"Best Val metrics: {val_metrics}")

test_pred = test(loader_dict["test"])
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

Epoch: 01, Train loss: 8.842738806761528, Val metrics: {'r2': -0.16285193110959417, 'mae': 4.21649472332829, 'rmse': 4.999337987375898}
Epoch: 02, Train loss: 5.829961511347577, Val metrics: {'r2': -0.46259189800641787, 'mae': 4.46263147735086, 'rmse': 5.606758773250842}
Epoch: 03, Train loss: 5.559539924287867, Val metrics: {'r2': 0.024492429874922506, 'mae': 3.766962812802118, 'rmse': 4.578946946036398}
Epoch: 04, Train loss: 5.423333757022523, Val metrics: {'r2': 0.04527153593335975, 'mae': 3.732058451951307, 'rmse': 4.529916794727641}
Epoch: 05, Train loss: 5.383147912170685, Val metrics: {'r2': 0.11079381091903195, 'mae': 3.6064125540102014, 'rmse': 4.371711842056379}
Epoch: 06, Train loss: 5.348593301490284, Val metrics: {'r2': 0.2240576011634916, 'mae': 3.2824803013441635, 'rmse': 4.083805325918277}
Epoch: 07, Train loss: 5.066744343028138, Val metrics: {'r2': 0.2667887794175344, 'mae': 3.143865940295304, 'rmse': 3.9697653883274264}
Epoch: 08, Train loss: 4.899520094564549, Val metrics: {'r2': 0.2406571586653936, 'mae': 3.2108977917280686, 'rmse': 4.039887218971643}
Epoch: 09, Train loss: 4.842159068873948, Val metrics: {'r2': 0.258886559297386, 'mae': 3.2254589577714996, 'rmse': 3.9911002292655793}
Epoch: 10, Train loss: 4.836300826306217, Val metrics: {'r2': 0.2547718835159639, 'mae': 3.2035254703972766, 'rmse': 4.002164225465218}

Best Val metrics: {'r2': 0.26737896397866745, 'mae': 3.143606203368447, 'rmse': 3.9681673727524314}
Best test metrics: {'r2': -0.030533308877957133, 'mae': 4.345074954032898, 'rmse': 5.289332548352804}
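The training loop keeps the weights from the epoch with the best validation value of the tuning metric and reloads them before the final evaluation. The selection rule can be sketched on its own as follows. The helper `select_best_epoch` is hypothetical, and the MAE values are rounded from the training log above.

```python
import math
from typing import List, Tuple


def select_best_epoch(
    metric_per_epoch: List[float], higher_is_better: bool
) -> Tuple[int, float]:
    """Return the (1-indexed epoch, value) with the best validation metric."""
    best_epoch = 0
    best_value = -math.inf if higher_is_better else math.inf
    for epoch, value in enumerate(metric_per_epoch, start=1):
        improved = value > best_value if higher_is_better else value < best_value
        if improved:
            best_epoch, best_value = epoch, value
    return best_epoch, best_value


# Validation MAE per epoch, rounded to three decimals.
val_mae = [4.216, 4.463, 3.767, 3.732, 3.606, 3.282, 3.144, 3.211, 3.225, 3.204]
print(select_best_epoch(val_mae, higher_is_better=False))
```

For this run the rule picks epoch 7, so the reported "best" metrics come from that checkpoint rather than from the last epoch.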




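Evaluating the model with the restored best checkpoint. The metrics can differ slightly from the previous cell, likely because neighbor sampling is random.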

In [20]:
val_pred = test(loader_dict["val"])
val_metrics = task.evaluate(val_pred, val_table)
print(f"Best Val metrics: {val_metrics}")

test_pred = test(loader_dict["test"])
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")



Best Val metrics: {'r2': 0.2695335420460905, 'mae': 3.1378418349073023, 'rmse': 3.9623280498017035}
Best test metrics: {'r2': -0.03023447884377206, 'mae': 4.344400579552902, 'rmse': 5.288565602726629}


