In [1]:
import torch
import pandas as pd
from utils.batching.batch_races import batch_races

In [2]:
df_train = pd.read_pickle('data/processed/2025/03/model_ready_train.pkl')
print(f"✅ Loaded dataset: {df_train.shape}")

✅ Loaded dataset: (2024, 109)


In [3]:
# ✅ STEP 2: Define input feature columns.

# ⚙️ Continuous features (numerical signals).
float_cols = [
    "distance_f", "field_size", "class_num", "draw", "age", "or", "rpr", "ts", "lbs",
    "trainer_ovr_runs", "trainer_ovr_wins", "trainer_ovr_win_pct", "trainer_ovr_profit",
    "trainer_last_14_runs", "trainer_last_14_wins", "trainer_last_14_win_pct", "trainer_last_14_profit",
    "jockey_ovr_runs", "jockey_ovr_wins", "jockey_ovr_win_pct", "jockey_ovr_profit",
    "jockey_last_14_runs", "jockey_last_14_wins", "jockey_last_14_win_pct", "jockey_last_14_profit",
    "rpr_rank", "or_rank", "rpr_zscore", "or_zscore"
]

# ➕ NLP-derived boolean flags.
nlp_flags = [c for c in df_train.columns if c.startswith("mentions_")]
float_cols += nlp_flags

# 🔢 Categorical embeddings (integers).
idx_cols = [
    "country_idx", "going_idx", "sex_idx", "type_idx",
    "class_label_idx", "headgear_idx", "race_class_idx", "venue_idx"
]

# 💬 Text embeddings (vector arrays).
nlp_cols = ["comment_vector", "spotlight_vector"]

print(f"📊 Float features: {len(float_cols)}")
print(f"🔢 Embedding indices: {idx_cols}")
print(f"🧠 NLP fields: {nlp_cols}")


📊 Float features: 63
🔢 Embedding indices: ['country_idx', 'going_idx', 'sex_idx', 'type_idx', 'class_label_idx', 'headgear_idx', 'race_class_idx', 'venue_idx']
🧠 NLP fields: ['comment_vector', 'spotlight_vector']


In [4]:
batches = batch_races(
    df_train,
    float_cols=float_cols,
    idx_cols=idx_cols,
    nlp_cols=nlp_cols,
    exclude_non_runners=True,
    label_col="winner_flag",
    min_runners=5
)
print(f"📦 Batches created: {len(batches)}")

print(batches[0]["winner_flag"])  # 🎯 Now included!


# Peek at shape of batch 0
batch = batches[0]
print(f"📦 float_features shape: {batch['float_features'].shape}")
print(f"🎯 winner_flag shape: {batch['winner_flag'].shape}")

📦 Batches created: 203
[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
📦 float_features shape: (22, 63)
🎯 winner_flag shape: (22,)


In [5]:
# This sanity check ensures the winner is placed in different indices.
for i in range(10):
    winner_vector = batches[i]["winner_flag"]
    print(f"🏇 Batch {i}, #horses: {len(winner_vector)} → Winner at index: {int(winner_vector.argmax())}")


🏇 Batch 0, #horses: 22 → Winner at index: 2
🏇 Batch 1, #horses: 21 → Winner at index: 3
🏇 Batch 2, #horses: 9 → Winner at index: 7
🏇 Batch 3, #horses: 9 → Winner at index: 2
🏇 Batch 4, #horses: 17 → Winner at index: 7
🏇 Batch 5, #horses: 18 → Winner at index: 5
🏇 Batch 6, #horses: 10 → Winner at index: 6
🏇 Batch 7, #horses: 16 → Winner at index: 6
🏇 Batch 8, #horses: 19 → Winner at index: 3
🏇 Batch 9, #horses: 6 → Winner at index: 3


In [6]:
# Import our custom PyTorch dataset that wraps batches of races.
# This turns the batches into a PyTorch-friendly object.
from utils.training.dataloader_utils import RaceDataset

# Brings in PyTorches DataLoader, which handles:
# - Mini-batching.
# - Shuffling.
# - Efficient iteration over the dataset.
from torch.utils.data import DataLoader

# Wraps the list of race batches in the RaceDataset class.
# 'include_target=true' tells it to return 'winner_flag' for training.
train_dataset = RaceDataset(batches, include_target=True)

# Creates a PyTorch DataLoader that:
# - Loads batches ONE RACE AT A TIME (batch_size=1).
# Shuffles the order of races (not the horses within them).
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# Peek at a batch.

# Pull the first batch from the dataset.
batch = next(iter(train_loader))

# Print the first batch details.
print(batch["float_feats"].shape)   # [B, R, F]
print(batch["targets"].shape)      # [B, R]


torch.Size([1, 12, 63])
torch.Size([1, 12])


In [7]:
# Load label_encoders.
import joblib

label_encoders = joblib.load("data/processed/2025/03/embedding_encoders_march-2025.pkl")

# For model setup:
#idx_vocab_sizes = [len(le.classes_) for le in label_encoders.values()]


In [8]:
from modeling.transformer_model import RaceTransformer

model = RaceTransformer(
    label_encoders=label_encoders,
    float_dim=63,
    embedding_dim=32,
    nlp_dim=384,
    hidden_dim=128,
    nhead=4,
    num_layers=2
)

model.train()



RaceTransformer(
  (heads): EmbeddingHeads(
    (embeddings): ModuleList(
      (0): Embedding(9, 32)
      (1): Embedding(10, 32)
      (2): Embedding(5, 32)
      (3): Embedding(4, 32)
      (4): Embedding(2, 32)
      (5): Embedding(18, 32)
      (6): Embedding(7, 32)
      (7): Embedding(33, 32)
    )
    (proj_float): Linear(in_features=63, out_features=32, bias=True)
    (proj_comment): Linear(in_features=384, out_features=32, bias=True)
    (proj_spotlight): Linear(in_features=384, out_features=32, bias=True)
  )
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=352, out_features=352, bias=True)
        )
        (linear1): Linear(in_features=352, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=352, bias=True)
        (norm1): LayerNorm(

In [9]:
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import time
import math
from datetime import datetime
import os

# 🕒 Create a timestamped directory for this run's checkpoints
timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
checkpoint_dir = f"checkpoints/transformer_{timestamp}"
os.makedirs(checkpoint_dir, exist_ok=True)

print(f"📂 Saving checkpoints to: {checkpoint_dir}")

# Create checkpoint folder if missing
os.makedirs("checkpoints", exist_ok=True)

# 🎯 Loss + Optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Number of epochs and checkpoints allowed
n_epochs = 50
max_checkpoints = 10

# Calculate which epochs to save
save_epochs = set(
    [1, n_epochs] +
    [math.ceil(i * n_epochs / max_checkpoints) for i in range(1, max_checkpoints)]
)
print(f"💾 Checkpoints will be saved at epochs: {sorted(save_epochs)}")

for epoch in range(n_epochs):
    model.train()
    total_loss = 0.0
    start_time = time.time()

    # 🔁 Loop through batches with tqdm progress bar
    for batch in tqdm(train_loader, desc=f"🧠 Epoch {epoch+1}/{n_epochs}"):
        logits = model(
            batch["float_feats"],
            batch["idx_feats"],
            batch["comment_vecs"],
            batch["spotlight_vecs"],
            batch["mask"]
        )
        targets = batch["targets"]
        loss = criterion(logits, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # ⏱️ Epoch Summary
    avg_loss = total_loss / len(train_loader)
    duration = time.time() - start_time
    print(f"📉 Epoch {epoch+1} — Loss: {avg_loss:.4f} — ⏱️ {duration:.1f}s")

    # 💾 Save Checkpoint
    if (epoch + 1) in save_epochs:
        ckpt_path = f"{checkpoint_dir}/epoch_{epoch+1}.pt"
        torch.save(model.state_dict(), ckpt_path)
        print(f"💾 Saved model to: {ckpt_path}")

📂 Saving checkpoints to: checkpoints/transformer_2025-04-04T19-08-35
💾 Checkpoints will be saved at epochs: [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]


🧠 Epoch 1/50: 100%|█████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 37.16it/s]


📉 Epoch 1 — Loss: 0.3672 — ⏱️ 5.5s
💾 Saved model to: checkpoints/transformer_2025-04-04T19-08-35/epoch_1.pt


🧠 Epoch 2/50: 100%|█████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 35.38it/s]


📉 Epoch 2 — Loss: 0.3559 — ⏱️ 5.7s


🧠 Epoch 3/50: 100%|█████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.60it/s]


📉 Epoch 3 — Loss: 0.3494 — ⏱️ 5.9s


🧠 Epoch 4/50: 100%|█████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.01it/s]


📉 Epoch 4 — Loss: 0.3485 — ⏱️ 6.0s


🧠 Epoch 5/50: 100%|█████████████████████████████████████████████████████████████████| 203/203 [00:06<00:00, 32.38it/s]


📉 Epoch 5 — Loss: 0.3452 — ⏱️ 6.3s
💾 Saved model to: checkpoints/transformer_2025-04-04T19-08-35/epoch_5.pt


🧠 Epoch 6/50: 100%|█████████████████████████████████████████████████████████████████| 203/203 [00:06<00:00, 30.83it/s]


📉 Epoch 6 — Loss: 0.3395 — ⏱️ 6.6s


🧠 Epoch 7/50: 100%|█████████████████████████████████████████████████████████████████| 203/203 [00:06<00:00, 31.43it/s]


📉 Epoch 7 — Loss: 0.3373 — ⏱️ 6.5s


🧠 Epoch 8/50: 100%|█████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.53it/s]


📉 Epoch 8 — Loss: 0.3334 — ⏱️ 5.9s


🧠 Epoch 9/50: 100%|█████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.49it/s]


📉 Epoch 9 — Loss: 0.3231 — ⏱️ 5.9s


🧠 Epoch 10/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:06<00:00, 33.53it/s]


📉 Epoch 10 — Loss: 0.3215 — ⏱️ 6.1s
💾 Saved model to: checkpoints/transformer_2025-04-04T19-08-35/epoch_10.pt


🧠 Epoch 11/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 35.08it/s]


📉 Epoch 11 — Loss: 0.3185 — ⏱️ 5.8s


🧠 Epoch 12/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.98it/s]


📉 Epoch 12 — Loss: 0.3100 — ⏱️ 5.8s


🧠 Epoch 13/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 35.23it/s]


📉 Epoch 13 — Loss: 0.3022 — ⏱️ 5.8s


🧠 Epoch 14/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.54it/s]


📉 Epoch 14 — Loss: 0.2953 — ⏱️ 5.9s


🧠 Epoch 15/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.89it/s]


📉 Epoch 15 — Loss: 0.2873 — ⏱️ 5.8s
💾 Saved model to: checkpoints/transformer_2025-04-04T19-08-35/epoch_15.pt


🧠 Epoch 16/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.90it/s]


📉 Epoch 16 — Loss: 0.2747 — ⏱️ 5.8s


🧠 Epoch 17/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.57it/s]


📉 Epoch 17 — Loss: 0.2672 — ⏱️ 5.9s


🧠 Epoch 18/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 35.68it/s]


📉 Epoch 18 — Loss: 0.2611 — ⏱️ 5.7s


🧠 Epoch 19/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 33.97it/s]


📉 Epoch 19 — Loss: 0.2416 — ⏱️ 6.0s


🧠 Epoch 20/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:06<00:00, 33.78it/s]


📉 Epoch 20 — Loss: 0.2382 — ⏱️ 6.0s
💾 Saved model to: checkpoints/transformer_2025-04-04T19-08-35/epoch_20.pt


🧠 Epoch 21/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 35.04it/s]


📉 Epoch 21 — Loss: 0.2254 — ⏱️ 5.8s


🧠 Epoch 22/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 35.42it/s]


📉 Epoch 22 — Loss: 0.2190 — ⏱️ 5.7s


🧠 Epoch 23/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.95it/s]


📉 Epoch 23 — Loss: 0.2156 — ⏱️ 5.8s


🧠 Epoch 24/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.76it/s]


📉 Epoch 24 — Loss: 0.2045 — ⏱️ 5.8s


🧠 Epoch 25/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.16it/s]


📉 Epoch 25 — Loss: 0.1950 — ⏱️ 5.9s
💾 Saved model to: checkpoints/transformer_2025-04-04T19-08-35/epoch_25.pt


🧠 Epoch 26/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 35.11it/s]


📉 Epoch 26 — Loss: 0.1835 — ⏱️ 5.8s


🧠 Epoch 27/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.69it/s]


📉 Epoch 27 — Loss: 0.1875 — ⏱️ 5.9s


🧠 Epoch 28/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 35.04it/s]


📉 Epoch 28 — Loss: 0.1766 — ⏱️ 5.8s


🧠 Epoch 29/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.47it/s]


📉 Epoch 29 — Loss: 0.1658 — ⏱️ 5.9s


🧠 Epoch 30/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.88it/s]


📉 Epoch 30 — Loss: 0.1638 — ⏱️ 5.8s
💾 Saved model to: checkpoints/transformer_2025-04-04T19-08-35/epoch_30.pt


🧠 Epoch 31/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 35.03it/s]


📉 Epoch 31 — Loss: 0.1664 — ⏱️ 5.8s


🧠 Epoch 32/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 35.30it/s]


📉 Epoch 32 — Loss: 0.1607 — ⏱️ 5.8s


🧠 Epoch 33/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.81it/s]


📉 Epoch 33 — Loss: 0.1681 — ⏱️ 5.8s


🧠 Epoch 34/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 35.39it/s]


📉 Epoch 34 — Loss: 0.1359 — ⏱️ 5.7s


🧠 Epoch 35/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.71it/s]


📉 Epoch 35 — Loss: 0.1378 — ⏱️ 5.8s
💾 Saved model to: checkpoints/transformer_2025-04-04T19-08-35/epoch_35.pt


🧠 Epoch 36/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:06<00:00, 32.57it/s]


📉 Epoch 36 — Loss: 0.1491 — ⏱️ 6.2s


🧠 Epoch 37/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:10<00:00, 20.22it/s]


📉 Epoch 37 — Loss: 0.1224 — ⏱️ 10.0s


🧠 Epoch 38/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:06<00:00, 31.46it/s]


📉 Epoch 38 — Loss: 0.1057 — ⏱️ 6.5s


🧠 Epoch 39/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:06<00:00, 32.48it/s]


📉 Epoch 39 — Loss: 0.1166 — ⏱️ 6.3s


🧠 Epoch 40/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:06<00:00, 32.24it/s]


📉 Epoch 40 — Loss: 0.1125 — ⏱️ 6.3s
💾 Saved model to: checkpoints/transformer_2025-04-04T19-08-35/epoch_40.pt


🧠 Epoch 41/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:06<00:00, 32.98it/s]


📉 Epoch 41 — Loss: 0.1079 — ⏱️ 6.2s


🧠 Epoch 42/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.07it/s]


📉 Epoch 42 — Loss: 0.1216 — ⏱️ 6.0s


🧠 Epoch 43/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.18it/s]


📉 Epoch 43 — Loss: 0.1199 — ⏱️ 5.9s


🧠 Epoch 44/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:06<00:00, 32.96it/s]


📉 Epoch 44 — Loss: 0.1031 — ⏱️ 6.2s


🧠 Epoch 45/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 33.88it/s]


📉 Epoch 45 — Loss: 0.1284 — ⏱️ 6.0s
💾 Saved model to: checkpoints/transformer_2025-04-04T19-08-35/epoch_45.pt


🧠 Epoch 46/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:06<00:00, 33.77it/s]


📉 Epoch 46 — Loss: 0.1089 — ⏱️ 6.0s


🧠 Epoch 47/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:06<00:00, 33.78it/s]


📉 Epoch 47 — Loss: 0.0933 — ⏱️ 6.0s


🧠 Epoch 48/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 33.96it/s]


📉 Epoch 48 — Loss: 0.0696 — ⏱️ 6.0s


🧠 Epoch 49/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:05<00:00, 34.67it/s]


📉 Epoch 49 — Loss: 0.1066 — ⏱️ 5.9s


🧠 Epoch 50/50: 100%|████████████████████████████████████████████████████████████████| 203/203 [00:06<00:00, 33.25it/s]


📉 Epoch 50 — Loss: 0.0843 — ⏱️ 6.1s
💾 Saved model to: checkpoints/transformer_2025-04-04T19-08-35/epoch_50.pt
