In [8]:
# Install necessary packages (skip if already installed)
!pip install torch==2.4.0
!pip install torch-geometric
!pip install pytorch_frame
!pip install relbench

# Import required libraries
import os
import torch
from relbench.datasets import get_dataset
from relbench.tasks import get_task
from relbench.modeling.graph import make_pkey_fkey_graph
from relbench.modeling.utils import get_stype_proposal
from torch_geometric.loader import NeighborLoader
from torch_frame.config.text_embedder import TextEmbedderConfig

# Load the dataset and define the task
dataset = get_dataset("rel-f1", download=True)
task = get_task("rel-f1", "driver-position", download=True)

# Extract train, validation, and test tables
train_table = task.get_table("train")
val_table = task.get_table("val")
test_table = task.get_table("test")

# Print the tables to verify
print("Training Table:", train_table)
print("Validation Table:", val_table)
print("Test Table:", test_table)

# Define configuration for text embedder
# Define configuration for text embedder
text_embedder_config = TextEmbedderConfig(text_embedder=text_embedder)  # Pass text_embedder as argument

# Map columns to embedding configurations
col_to_text_embedder_cfg = {
    "driverRef": text_embedder_config,  # Specify columns requiring text embeddings
}

# Get column types proposal
col_to_stype_dict = get_stype_proposal(dataset.get_db())

# Build the graph with the text embedder
data, col_stats_dict = make_pkey_fkey_graph(
    db=dataset.get_db(),
    col_to_stype_dict=col_to_stype_dict,
    col_to_text_embedder_cfg=col_to_text_embedder_cfg,
    cache_dir="./cache",
)

# Create Data Loaders
loader_dict = {}
for split, table in [("train", train_table), ("val", val_table), ("test", test_table)]:
    loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=[128, 128],
        input_nodes=table,
        batch_size=512,
        shuffle=(split == "train"),
    )

# Define GNN Model
from torch_geometric.nn import MLP
from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE

class GNNModel(torch.nn.Module):
    def __init__(self, data, col_stats_dict, channels, out_channels):
        super().__init__()
        self.encoder = HeteroEncoder(channels, data, col_stats_dict)
        self.gnn = HeteroGraphSAGE(data.node_types, data.edge_types, channels, "mean", num_layers=2)
        self.head = MLP(channels, out_channels, num_layers=1)

    def forward(self, batch):
        x_dict = self.encoder(batch)
        x_dict = self.gnn(x_dict, batch.edge_index_dict)
        return self.head(x_dict["drivers"])  # Update "drivers" based on your node type

# Initialize model, optimizer, and loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GNNModel(data, col_stats_dict, channels=64, out_channels=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.L1Loss()

# Training and Evaluation Functions
from tqdm import tqdm
import copy

def train():
    model.train()
    total_loss = 0
    for batch in tqdm(loader_dict["train"]):
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = model(batch)
        loss = loss_fn(pred, batch["drivers"].y.float())  # Update "drivers" based on your target node type
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader_dict["train"])

@torch.no_grad()
def evaluate(loader):
    model.eval()
    predictions = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(batch)
        predictions.append(pred.cpu().numpy())
    return torch.cat(predictions)

# Training Loop
best_val_metric = float("inf")
for epoch in range(10):
    train_loss = train()
    val_pred = evaluate(loader_dict["val"])
    val_metrics = task.evaluate(val_pred, val_table)
    print(f"Epoch {epoch + 1}, Loss: {train_loss:.4f}, Val MAE: {val_metrics['mae']:.4f}")
    if val_metrics["mae"] < best_val_metric:
        best_val_metric = val_metrics["mae"]
        best_model = copy.deepcopy(model.state_dict())

# Test Evaluation
model.load_state_dict(best_model)
test_pred = evaluate(loader_dict["test"])
test_metrics = task.evaluate(test_pred, test_table)
print(f"Test MAE: {test_metrics['mae']:.4f}")


Training Table: Table(df=
           date  driverId  position
0    2004-07-05        10     10.75
1    2004-07-05        47     12.00
2    2004-03-07         7     15.00
3    2004-01-07        10      9.00
4    2003-09-09        52     13.00
...         ...       ...       ...
7448 1995-08-22        96     15.75
7449 1975-06-08       228      8.00
7450 1965-05-31       418     16.00
7451 1961-08-20       467     37.00
7452 1954-05-29       677     30.00

[7453 rows x 3 columns],
  fkey_col_to_pkey_table={'driverId': 'drivers'},
  pkey_col=None,
  time_col=date)
Validation Table: Table(df=
          date  driverId   position
0   2009-08-08         7   4.200000
1   2009-06-09        11  12.333333
2   2009-06-09         0  11.666667
3   2009-04-10         5  15.600000
4   2009-04-10        17   1.400000
..         ...       ...        ...
494 2005-06-30        16  13.800000
495 2005-03-02        32  12.750000
496 2008-10-12        14  11.000000
497 2007-06-20        28  22.000000
498 2006

TypeError: make_pkey_fkey_graph() got an unexpected keyword argument 'col_to_text_embedder_cfg'