# CS 224W - Final Project

# Loading Data Tutorial

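## Installation
Set up `relbench` (e.g. `pip install relbench`). The training section below additionally imports `torch`, `torch_geometric`, `torch_frame`, and `sentence_transformers`.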

## Loading in the Dataset

In [1]:
from relbench.datasets import get_dataset

# Fetch the rel-trial dataset (cached locally after the first download).
dataset = get_dataset(download=True, name="rel-trial")

In [2]:
# Timestamps at which the validation and test splits begin.
(dataset.val_timestamp, dataset.test_timestamp)

(Timestamp('2020-01-01 00:00:00'), Timestamp('2021-01-01 00:00:00'))

In [3]:
# Load the relational database (a dict of tables) behind rel-trial.
db = dataset.get_db()

Loading Database object from /home/cpondoc/.cache/relbench/rel-trial/db...
Done in 3.68 seconds.


In [4]:
# Names of all tables in the database.
db.table_dict.keys()

dict_keys(['studies', 'reported_event_totals', 'drop_withdrawals', 'facilities', 'sponsors', 'conditions', 'interventions_studies', 'designs', 'sponsors_studies', 'outcome_analyses', 'conditions_studies', 'interventions', 'outcomes', 'facilities_studies', 'eligibilities'])

## Loading Tasks

In [5]:
from relbench.tasks import get_task_names, get_task

# List the prediction tasks RelBench defines on rel-trial.
get_task_names("rel-trial")

['study-outcome',
 'study-adverse',
 'site-success',
 'condition-sponsor-run',
 'site-sponsor-run']

In [6]:
# Load the "study-outcome" task; its split tables are inspected below.
task = get_task("rel-trial", "study-outcome", download=True)

In [7]:
# Training split: (timestamp, nct_id) rows with an `outcome` label (0/1 in the rows shown).
task.get_table("train")

Table(df=
       timestamp  nct_id  outcome
0     2003-01-05    4678        1
1     2004-01-05    1702        0
2     2004-01-05    7156        0
3     2004-01-05    3665        1
4     2004-01-05    3039        1
...          ...     ...      ...
11989 2019-01-01  185803        1
11990 2019-01-01  194455        1
11991 2019-01-01  197165        1
11992 2019-01-01  206285        1
11993 2019-01-01  206513        0

[11994 rows x 3 columns],
  fkey_col_to_pkey_table={'nct_id': 'studies'},
  pkey_col=None,
  time_col=timestamp)

In [8]:
# Test split: same keys as train, but no `outcome` column is exposed.
task.get_table("test")

Table(df=
     timestamp  nct_id
0   2021-01-01   53166
1   2021-01-01   61792
2   2021-01-01  208935
3   2021-01-01  209101
4   2021-01-01  214190
..         ...     ...
820 2021-01-01  199533
821 2021-01-01  202875
822 2021-01-01  203010
823 2021-01-01  203086
824 2021-01-01  204336

[825 rows x 2 columns],
  fkey_col_to_pkey_table={'nct_id': 'studies'},
  pkey_col=None,
  time_col=timestamp)

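# Train Model

This section switches from `rel-trial` to the `rel-f1` dataset and trains a GNN on its `driver-position` regression task (L1 loss, model selection on validation MAE).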

In [5]:
# Core imports for the training section (nothing is installed in this cell).
import os
import torch

In [6]:
import numpy as np

from torch.nn import BCEWithLogitsLoss, L1Loss
from relbench.datasets import get_dataset
from relbench.tasks import get_task

# NOTE: this rebinds `dataset` and `task` (which held rel-trial objects above)
# to the rel-f1 dataset and its driver-position task.
dataset = get_dataset("rel-f1", download=True)
task = get_task("rel-f1", "driver-position", download=True)

# Fetch every split table up front (train, then val, then test).
train_table, val_table, test_table = (
    task.get_table(split_name) for split_name in ("train", "val", "test")
)

# Regression setup: one scalar output trained with L1 loss; the best epoch is
# selected on validation MAE, where lower is better.
out_channels = 1
loss_fn = L1Loss()
tune_metric = "mae"
higher_is_better = False

In [7]:
# Peek at the rel-f1 training table (columns: date, driverId, position).
train_table

Table(df=
           date  driverId  position
0    2004-07-05        10     10.75
1    2004-07-05        47     12.00
2    2004-03-07         7     15.00
3    2004-01-07        10      9.00
4    2003-09-09        52     13.00
...         ...       ...       ...
7448 1995-08-22        96     15.75
7449 1975-06-08       228      8.00
7450 1965-05-31       418     16.00
7451 1961-08-20       467     37.00
7452 1954-05-29       677     30.00

[7453 rows x 3 columns],
  fkey_col_to_pkey_table={'driverId': 'drivers'},
  pkey_col=None,
  time_col=date)

In [8]:
import os
import math
import numpy as np
from tqdm import tqdm

import torch
import torch_geometric
import torch_frame

# Reproducibility: seed the random number generators before building anything.
from torch_geometric.seed import seed_everything

SEED = 42
seed_everything(SEED)

# Prefer the GPU; CPU works but training will be much slower.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)
root_dir = "./data"  # parent directory for the materialized-graph cache

cuda


In [9]:
from relbench.modeling.utils import get_stype_proposal

# NOTE: rebinds `db` to the rel-f1 database (it previously held rel-trial's).
db = dataset.get_db()
# Per-table mapping of column name -> torch_frame semantic type
# (numerical / text_embedded / timestamp ...), as proposed by relbench.
col_to_stype_dict = get_stype_proposal(db)
col_to_stype_dict

Loading Database object from /home/cpondoc/.cache/relbench/rel-f1/db...
Done in 0.02 seconds.


{'standings': {'driverStandingsId': <stype.numerical: 'numerical'>,
  'raceId': <stype.numerical: 'numerical'>,
  'driverId': <stype.numerical: 'numerical'>,
  'points': <stype.numerical: 'numerical'>,
  'position': <stype.numerical: 'numerical'>,
  'wins': <stype.numerical: 'numerical'>,
  'date': <stype.timestamp: 'timestamp'>},
 'constructors': {'constructorId': <stype.numerical: 'numerical'>,
  'constructorRef': <stype.text_embedded: 'text_embedded'>,
  'name': <stype.text_embedded: 'text_embedded'>,
  'nationality': <stype.text_embedded: 'text_embedded'>},
 'circuits': {'circuitId': <stype.numerical: 'numerical'>,
  'circuitRef': <stype.text_embedded: 'text_embedded'>,
  'name': <stype.text_embedded: 'text_embedded'>,
  'location': <stype.text_embedded: 'text_embedded'>,
  'country': <stype.text_embedded: 'text_embedded'>,
  'lat': <stype.numerical: 'numerical'>,
  'lng': <stype.numerical: 'numerical'>,
  'alt': <stype.numerical: 'numerical'>},
 'qualifying': {'qualifyId': <stype.

In [11]:
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from torch import Tensor


class GloveTextEmbedding:
    """Callable text embedder backed by averaged 300-d GloVe word vectors.

    Passed to ``TextEmbedderConfig`` so torch_frame can turn ``text_embedded``
    columns into dense numeric features.
    """

    def __init__(self, device: Optional[torch.device] = None):
        model_name = "sentence-transformers/average_word_embeddings_glove.6B.300d"
        self.model = SentenceTransformer(model_name, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        """Embed a batch of strings and return the embeddings as a torch tensor."""
        embeddings = self.model.encode(sentences)
        return torch.from_numpy(embeddings)



  from tqdm.autonotebook import tqdm, trange


In [12]:
from torch_frame.config.text_embedder import TextEmbedderConfig
from relbench.modeling.graph import make_pkey_fkey_graph

text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=128
)

# Build the heterogeneous graph: one node type per table and one edge type per
# primary-key/foreign-key link (plus its reverse direction).
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # specified column types
    text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder
    cache_dir=os.path.join(
        root_dir, "rel-f1_materialized_cache"
    ),  # store materialized graph for convenience
)

Embedding raw data in mini-batch: 100%|██████████| 2/2 [00:00<00:00, 28.44it/s]
Embedding raw data in mini-batch: 100%|██████████| 2/2 [00:00<00:00, 439.24it/s]
Embedding raw data in mini-batch: 100%|██████████| 2/2 [00:00<00:00, 463.13it/s]
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 382.06it/s]
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 340.42it/s]
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 459.80it/s]
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 427.90it/s]
Embedding raw data in mini-batch: 100%|██████████| 7/7 [00:00<00:00, 401.31it/s]
Embedding raw data in mini-batch: 100%|██████████| 7/7 [00:00<00:00, 401.69it/s]
Embedding raw data in mini-batch: 100%|██████████| 7/7 [00:00<00:00, 383.75it/s]
Embedding raw data in mini-batch: 100%|██████████| 7/7 [00:00<00:00, 416.51it/s]
Embedding raw data in mini-batch: 100%|██████████| 7/7 [00:00<00:00, 403.28it/s]
  ser = pd.to_datetime(ser, f

In [16]:
# Inspect the graph: node types with TensorFrame features (and `time` where
# present) plus the pkey/fkey edge types and their reverses.
data

HeteroData(
  standings={
    tf=TensorFrame([28115, 4]),
    time=[28115],
  },
  constructors={ tf=TensorFrame([211, 3]) },
  circuits={ tf=TensorFrame([77, 7]) },
  qualifying={
    tf=TensorFrame([4082, 3]),
    time=[4082],
  },
  constructor_standings={
    tf=TensorFrame([10170, 4]),
    time=[10170],
  },
  drivers={ tf=TensorFrame([857, 6]) },
  constructor_results={
    tf=TensorFrame([9408, 2]),
    time=[9408],
  },
  races={
    tf=TensorFrame([820, 5]),
    time=[820],
  },
  results={
    tf=TensorFrame([20323, 11]),
    time=[20323],
  },
  (standings, f2p_raceId, races)={ edge_index=[2, 28115] },
  (races, rev_f2p_raceId, standings)={ edge_index=[2, 28115] },
  (standings, f2p_driverId, drivers)={ edge_index=[2, 28115] },
  (drivers, rev_f2p_driverId, standings)={ edge_index=[2, 28115] },
  (qualifying, f2p_raceId, races)={ edge_index=[2, 4082] },
  (races, rev_f2p_raceId, qualifying)={ edge_index=[2, 4082] },
  (qualifying, f2p_driverId, drivers)={ edge_index=[2, 4082

In [13]:
from relbench.modeling.graph import get_node_train_table_input, make_pkey_fkey_graph
from torch_geometric.loader import NeighborLoader

# One NeighborLoader per split. `time_attr` + `input_time` turn on PyG's
# temporal sampling, so each seed's subgraph respects that seed's timestamp.
loader_dict = {}

for split, table in (("train", train_table), ("val", val_table), ("test", test_table)):
    table_input = get_node_train_table_input(table=table, task=task)
    # NOTE: `entity_table` leaks to module scope; later cells may read it.
    entity_table = table_input.nodes[0]
    loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=[128] * 2,  # depth-2 subgraphs, 128 neighbours per node per hop
        time_attr="time",
        input_nodes=table_input.nodes,
        input_time=table_input.time,
        transform=table_input.transform,
        batch_size=512,
        temporal_strategy="uniform",
        shuffle=(split == "train"),  # only the training split is shuffled
        num_workers=0,
        persistent_workers=False,
    )

In [14]:
from torch.nn import BCEWithLogitsLoss
import copy
from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder


class Model(torch.nn.Module):
    """GNN for node-level prediction on a temporal heterogeneous graph.

    Pipeline: per-table ``HeteroEncoder`` over TensorFrame features ->
    relative-time encoding for node types that carry ``time`` -> optional
    shallow (learned per-node) embeddings -> ``HeteroGraphSAGE`` message
    passing -> ``MLP`` head.

    Args:
        data: Full graph; read only for node/edge types, column names and
            node counts.
        col_stats_dict: Per-table, per-column statistics (as returned by
            ``make_pkey_fkey_graph``).
        num_layers: Number of GraphSAGE layers.
        channels: Hidden dimensionality.
        out_channels: Output dimensionality of the head.
        aggr: Neighbour aggregation passed to ``HeteroGraphSAGE``.
        norm: Normalization passed to the ``MLP`` head.
        shallow_list: Node types that get an extra learned ``Embedding``
            indexed by ``batch[node_type].n_id``.
        id_awareness: If True, allocate the learned vector that
            ``forward_dst_readout`` adds to the root (seed) nodes.
    """

    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        # List of node types to add shallow embeddings to input.
        # Only iterated (never mutated), so the shared default is safe.
        shallow_list: List[NodeType] = [],
        # ID awareness
        id_awareness: bool = False,
    ):
        super().__init__()

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )
        # Only node types that carry a `time` attribute get a temporal encoding.
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
        )
        self.gnn = HeteroGraphSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )

        self.id_awareness_emb = Embedding(1, channels) if id_awareness else None
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every submodule (shallow embeddings ~ N(0, 0.1^2))."""
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def _embed_inputs(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        add_id_awareness: bool,
    ) -> Dict[NodeType, Tensor]:
        """Shared input pipeline: column encoding, optional root ID flag,
        relative-time encoding and shallow embeddings (before message passing)."""
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)

        if add_id_awareness:
            # The first `num_seeds` rows of the entity table are the root nodes.
            # Built out-of-place so the encoder output (which autograd may have
            # saved for backward) is not overwritten in place.
            num_seeds = seed_time.size(0)
            x_root = x_dict[entity_table]
            x_dict[entity_table] = torch.cat(
                [x_root[:num_seeds] + self.id_awareness_emb.weight, x_root[num_seeds:]],
                dim=0,
            )

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )
        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        return x_dict

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        """Return one prediction row per seed node of ``entity_table``."""
        x_dict = self._embed_inputs(batch, entity_table, add_id_awareness=False)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )

        # Seed nodes occupy the first `num_seeds` rows of the entity table.
        num_seeds = batch[entity_table].seed_time.size(0)
        return self.head(x_dict[entity_table][:num_seeds])

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        """Mark the root nodes with the ID-awareness vector and return a
        prediction for every node of ``dst_table`` in the batch.

        Raises:
            RuntimeError: if the model was built with ``id_awareness=False``.
        """
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )
        x_dict = self._embed_inputs(batch, entity_table, add_id_awareness=True)

        # NOTE(review): unlike forward(), the sampled node/edge counts are not
        # passed to the GNN here (kept as in the original) — confirm intended.
        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        return self.head(x_dict[dst_table])


# Build the model. `out_channels` comes from the task-configuration cell
# instead of being hard-coded, so switching tasks only needs one edit there.
model = Model(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=out_channels,
    aggr="sum",
    norm="batch_norm",
).to(device)


# if you try out different RelBench tasks you will need to change these
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
epochs = 10

In [16]:
def train() -> float:
    """Run one epoch over ``loader_dict["train"]`` and return the mean per-example loss.

    Uses notebook globals: ``model``, ``optimizer``, ``loss_fn``,
    ``loader_dict``, ``task`` and ``device``.

    Raises:
        RuntimeError: if the training loader yields no examples.
    """
    model.train()
    # Read the entity table from the task rather than relying on a variable
    # leaked from the loader-construction loop.
    entity_table = task.entity_table

    loss_accum = 0.0
    count_accum = 0
    for batch in tqdm(loader_dict["train"]):
        batch = batch.to(device)

        optimizer.zero_grad()
        pred = model(batch, entity_table)
        # Flatten (N, 1) predictions to (N,) to line up with the target.
        pred = pred.view(-1) if pred.size(1) == 1 else pred

        loss = loss_fn(pred.float(), batch[entity_table].y.float())
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch mean is per example, not per batch.
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    if count_accum == 0:
        raise RuntimeError("training loader produced no examples")
    return loss_accum / count_accum


@torch.no_grad()
def test(loader: NeighborLoader) -> np.ndarray:
    """Predict every seed node in ``loader`` (eval mode, no grad) and return a NumPy array."""
    model.eval()
    entity_table = task.entity_table

    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(batch, entity_table)
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0).numpy()

In [17]:
# Train for `epochs` epochs, keeping a copy of the weights from the epoch with
# the best validation `tune_metric`.
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
for epoch in range(1, epochs + 1):
    train_loss = train()
    val_pred = test(loader_dict["val"])
    val_metrics = task.evaluate(val_pred, val_table)
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    current = val_metrics[tune_metric]
    improved = current > best_val_metric if higher_is_better else current < best_val_metric
    if improved:
        best_val_metric = current
        state_dict = copy.deepcopy(model.state_dict())

# If no epoch ever improved (e.g. the metric was NaN every time, or epochs == 0),
# `state_dict` is still None: keep the current weights rather than crashing
# inside load_state_dict(None).
if state_dict is None:
    print(f"Warning: no epoch improved '{tune_metric}'; using the final-epoch weights.")
else:
    model.load_state_dict(state_dict)

val_pred = test(loader_dict["val"])
val_metrics = task.evaluate(val_pred, val_table)
print(f"Best Val metrics: {val_metrics}")

test_pred = test(loader_dict["test"])
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

100%|██████████| 15/15 [00:01<00:00,  7.54it/s]


Epoch: 01, Train loss: 9.05065899581121, Val metrics: {'r2': -0.23120657793923027, 'mae': np.float64(4.330512936909994), 'rmse': np.float64(5.144175221892629)}


100%|██████████| 15/15 [00:01<00:00,  9.28it/s]


Epoch: 02, Train loss: 5.87314571368843, Val metrics: {'r2': -0.3844000746079419, 'mae': np.float64(4.353121052157824), 'rmse': np.float64(5.454828446452905)}


100%|██████████| 15/15 [00:01<00:00,  9.21it/s]


Epoch: 03, Train loss: 5.557879151905142, Val metrics: {'r2': 0.03255471843204849, 'mae': np.float64(3.806097188502371), 'rmse': np.float64(4.559985850194199)}


100%|██████████| 15/15 [00:01<00:00,  9.37it/s]


Epoch: 04, Train loss: 5.424738339061115, Val metrics: {'r2': 0.056552086012577885, 'mae': np.float64(3.7126938369485964), 'rmse': np.float64(4.503075763043278)}


100%|██████████| 15/15 [00:01<00:00,  9.32it/s]


Epoch: 05, Train loss: 5.360174358410914, Val metrics: {'r2': 0.1317851915241477, 'mae': np.float64(3.5614308015139167), 'rmse': np.float64(4.319802427020088)}


100%|██████████| 15/15 [00:01<00:00,  9.35it/s]


Epoch: 06, Train loss: 5.048339258276119, Val metrics: {'r2': 0.23940086590732568, 'mae': np.float64(3.215481931397177), 'rmse': np.float64(4.043227728237567)}


100%|██████████| 15/15 [00:01<00:00,  9.39it/s]


Epoch: 07, Train loss: 4.919287130590511, Val metrics: {'r2': 0.25037403982148254, 'mae': np.float64(3.1837538516593122), 'rmse': np.float64(4.013955918133126)}


100%|██████████| 15/15 [00:01<00:00,  9.36it/s]


Epoch: 08, Train loss: 4.876277381424038, Val metrics: {'r2': 0.23921857537919244, 'mae': np.float64(3.3032900533758967), 'rmse': np.float64(4.043712213381598)}


100%|██████████| 15/15 [00:01<00:00,  9.36it/s]


Epoch: 09, Train loss: 4.794042118696435, Val metrics: {'r2': 0.2802659337567127, 'mae': np.float64(3.1642936826946744), 'rmse': np.float64(3.9331120501026913)}


100%|██████████| 15/15 [00:01<00:00,  9.42it/s]


Epoch: 10, Train loss: 4.7091667632616465, Val metrics: {'r2': 0.2812687618263451, 'mae': np.float64(3.150532161162229), 'rmse': np.float64(3.9303710307355626)}
Best Val metrics: {'r2': 0.28143406655995884, 'mae': np.float64(3.150130301152537), 'rmse': np.float64(3.929919021511871)}
Best test metrics: {'r2': 0.09401481007128154, 'mae': np.float64(4.100206426402979), 'rmse': np.float64(4.9594144720824165)}


