
# Lap Context Transformer

Train a transformer that ingests every lap of a race up to lap `t` and predicts the next lap (`t+1`) lap time and position for all drivers. Each batch may contain races of different lengths; we pad shorter contexts and mask padded tokens during attention.


In [78]:
%pip install pandas numpy scikit-learn tqdm
%pip install torch

Python(44682) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Note: you may need to restart the kernel to use updated packages.


Python(44685) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Note: you may need to restart the kernel to use updated packages.



## Imports & configuration
Define the training configuration, seed everything, and set up helper functions for statistics.


In [None]:
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '0'
import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
import pandas as pd
import torch
print(torch.__version__)
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
#DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE = torch.device('cpu')
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    x = torch.ones(1, device=DEVICE)
    print (x)
else:
    print ("MPS device not found.")
def ensure_pos_std(series: pd.Series) -> float:
    std = float(series.std())
    return std if std > 1e-6 else 1.0
@dataclass
class TrainerConfig:
    dataset_path: Path = Path('fastf1_lap_dataset.csv')
    max_drivers: int = 20
    min_laps_per_session: int = 5
    train_years: Tuple[int, ...] = (2018, 2019, 2020, 2021)
    val_years: Tuple[int, ...] = (2023,)
    test_years: Tuple[int, ...] = (2024, 2025)
    batch_size: int = 2
    num_epochs: int = 1
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    max_train_steps_per_epoch: Optional[int] = None
    max_val_steps: Optional[int] = None
    seed: int = 42
    debug_sessions: Optional[int] = None
CONFIG = TrainerConfig()
np.random.seed(CONFIG.seed)
torch.manual_seed(CONFIG.seed)


2.9.1
tensor([1.], device='mps:0')


<torch._C.Generator at 0x113606250>


## Load and preprocess the lap dataset
Clean missing values, fill lagged lap times, and derive helper columns for gaps, tyre age, etc.


In [None]:


def load_lap_dataframe(csv_path: Path) -> pd.DataFrame:
    df = pd.read_csv(csv_path)
    required = {'session_key', 'driver_id', 'lap_number'}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f'Missing columns: {missing}')
    df = df.dropna(subset=list(required)).copy()
    df['lap_number'] = df['lap_number'].astype(int)
    df = df.sort_values(['session_key', 'driver_id', 'lap_number'])

    # Fill various gaps
    df['gap_to_leader_s'] = df['gap_to_leader_s'].fillna(0)
    df['gap_to_ahead_s'] = df['gap_to_ahead_s'].fillna(0)
    # Avoid using lap_number as a proxy for position; carry last known or leave NaN
    df['current_position'] = df['current_position'].fillna(method='ffill')
    df['current_position'] = df['current_position'].fillna(method='bfill')
    df['grid_position'] = df.groupby(['session_key', 'driver_id'])['grid_position'].transform(lambda s: s.fillna(s.iloc[0]))
    df['grid_position'] = df['grid_position'].fillna(df['current_position'])
    df['grid_position'] = df['grid_position'].fillna(df['lap_number'])
    df['laps_on_current_tyre'] = df['laps_on_current_tyre'].fillna(1)
    df['lap_time_s'] = df.groupby('session_key')['lap_time_s'].transform(lambda s: s.fillna(s.median()))
    df['lap_time_s'] = df['lap_time_s'].fillna(df['lap_time_s'].median())

    for col in ['track_temperature', 'air_temperature', 'humidity', 'wind_speed', 'pressure']:
        if col in df.columns:
            df[col] = df[col].fillna(df[col].median())
    for col in ['drs_enabled', 'safety_car_this_lap', 'virtual_sc_this_lap', 'rainfall', 'has_rain']:
        if col in df.columns:
            df[col] = df[col].fillna(False)
    df['tyre_compound'] = df['tyre_compound'].fillna('UNKNOWN')

    df['lap_time_prev1'] = df.groupby(['session_key', 'driver_id'])['lap_time_s'].shift(1)
    df['lap_time_prev2'] = df.groupby(['session_key', 'driver_id'])['lap_time_s'].shift(2)
    df['lap_time_prev1'] = df['lap_time_prev1'].fillna(df['lap_time_s'])
    df['lap_time_prev2'] = df['lap_time_prev2'].fillna(df['lap_time_s'])
    df['pit_flag'] = False
    for (session_key, driver_id), group in df.groupby(['session_key', 'driver_id']):
        order = group.sort_values('lap_number')
        tyre_diff = order['laps_on_current_tyre'].diff()
        compound_change = order['tyre_compound'].ne(order['tyre_compound'].shift(1))
        pits = (tyre_diff < 0) | compound_change.fillna(False)
        df.loc[order.index, 'pit_flag'] = pits.fillna(False)
    return df

raw_df = load_lap_dataframe(CONFIG.dataset_path)
raw_df.head()



## Track/weather scalers & vocabularies
Compute per-track mean/std for lap times and gaps, global weather scalers, and integer vocabularies for drivers/teams/circuits/compounds.


In [53]:


def compute_track_scalers(df: pd.DataFrame) -> Dict[str, Dict[str, Tuple[float, float]]]:
    stats = {}
    for circuit_id, group in df.groupby('circuit_id'):
        lap_series = group['lap_time_s']
        gap_leader = np.log1p(group['gap_to_leader_s'])
        gap_ahead = np.log1p(group['gap_to_ahead_s'])
        stats[circuit_id] = {
            'lap_time_s': (float(lap_series.mean()), ensure_pos_std(lap_series)),
            'gap_to_leader_s': (float(gap_leader.mean()), ensure_pos_std(gap_leader)),
            'gap_to_ahead_s': (float(gap_ahead.mean()), ensure_pos_std(gap_ahead)),
        }
    if not stats:
        raise ValueError('No circuits found for scalers')
    return stats


def compute_weather_scaler(df: pd.DataFrame) -> Dict[str, Tuple[float, float]]:
    scalers = {}
    for col in ['track_temperature', 'air_temperature', 'humidity', 'wind_speed', 'pressure']:
        series = df[col].astype(float)
        scalers[col] = (float(series.mean()), ensure_pos_std(series))
    return scalers

track_scalers = compute_track_scalers(raw_df)
weather_scaler = compute_weather_scaler(raw_df)


@dataclass
class LapVocabulary:
    driver_to_idx: Dict[str, int]
    team_to_idx: Dict[str, int]
    compound_to_idx: Dict[str, int]
    track_to_idx: Dict[str, int]

    @property
    def num_drivers(self) -> int:
        return len(self.driver_to_idx) + 1

    @property
    def num_teams(self) -> int:
        return len(self.team_to_idx) + 1

    @property
    def num_compounds(self) -> int:
        return len(self.compound_to_idx) + 1

    @property
    def num_tracks(self) -> int:
        return len(self.track_to_idx) + 1


def build_vocab(df: pd.DataFrame) -> LapVocabulary:
    return LapVocabulary(
        driver_to_idx={drv: idx + 1 for idx, drv in enumerate(sorted(df['driver_id'].unique()))},
        team_to_idx={team: idx + 1 for idx, team in enumerate(sorted(df['team_id'].unique()))},
        compound_to_idx={comp: idx + 1 for idx, comp in enumerate(sorted(df['tyre_compound'].unique()))},
        track_to_idx={trk: idx + 1 for idx, trk in enumerate(sorted(df['circuit_id'].unique()))},
    )

vocab = build_vocab(raw_df)
len(vocab.driver_to_idx), len(vocab.track_to_idx)


(43, 35)


## Session tensor builder
Convert each race (session) into dense tensors: `[laps, max_drivers, dynamics]`, global lap features, categorical tokens, and normalized targets. Only races with at least `min_laps_per_session` laps are kept.


In [54]:
@dataclass
class SessionTensor:
    session_key: str
    year: int
    track_token: int
    dynamic_numeric: np.ndarray
    global_numeric: np.ndarray
    rank_tokens: np.ndarray
    compound_tokens: np.ndarray
    alive_mask: np.ndarray
    lap_time_norm: np.ndarray
    position_norm: np.ndarray
    driver_tokens: np.ndarray
    team_tokens: np.ndarray
    static_numeric: np.ndarray
class SessionTensorBuilder:
    def __init__(self, cfg: TrainerConfig, vocab: LapVocabulary, track_scalers, weather_scaler):
        self.cfg = cfg
        self.vocab = vocab
        self.track_scalers = track_scalers
        self.weather_scaler = weather_scaler
    def _zscore(self, value: float, mean: float, std: float) -> float:
        if pd.isna(value):
            return 0.0
        return float((value - mean) / (std if std > 1e-6 else 1.0))
    def build(self, df: pd.DataFrame) -> List[SessionTensor]:
        sessions = []
        for idx, (session_key, sdf) in enumerate(tqdm(df.groupby('session_key', sort=False), desc='Sessions')):
            if CONFIG.debug_sessions and idx >= CONFIG.debug_sessions:
                break
            sdf = sdf.sort_values(['lap_number', 'driver_id']).copy()
            max_lap = int(sdf['lap_number'].max())
            if max_lap < CONFIG.min_laps_per_session:
                continue
            drivers = sorted(sdf['driver_id'].unique().tolist())[: CONFIG.max_drivers]
            driver_slots = {drv: slot for slot, drv in enumerate(drivers)}
            slots = CONFIG.max_drivers
            dyn_dim = 12
            dynamic = np.zeros((max_lap, slots, dyn_dim), dtype=np.float32)
            rank_tokens = np.zeros((max_lap, slots), dtype=np.int64)
            compound_tokens = np.zeros((max_lap, slots), dtype=np.int64)
            alive = np.zeros((max_lap, slots), dtype=np.float32)
            lap_time_norm = np.zeros((max_lap, slots), dtype=np.float32)
            global_feats = np.zeros((max_lap, 8), dtype=np.float32)
            driver_tokens = np.zeros((slots,), dtype=np.int64)
            team_tokens = np.zeros((slots,), dtype=np.int64)
            static_numeric = np.zeros((slots, 2), dtype=np.float32)
            circuit_id = sdf['circuit_id'].iloc[0]
            track_stats = self.track_scalers.get(circuit_id, next(iter(self.track_scalers.values())))
            track_token = self.vocab.track_to_idx.get(circuit_id, 0)
            total_ref = max(max_lap, int(sdf['total_race_laps'].dropna().max() or max_lap))
            for driver_id, rows in sdf.groupby('driver_id'):
                if driver_id not in driver_slots:
                    continue
                slot = driver_slots[driver_id]
                driver_tokens[slot] = self.vocab.driver_to_idx.get(driver_id, 0)
                team_tokens[slot] = self.vocab.team_to_idx.get(rows['team_id'].iloc[0], 0)
                static_numeric[slot, 0] = float(rows['grid_position'].iloc[0] / CONFIG.max_drivers)
                static_numeric[slot, 1] = float(rows['year'].iloc[0] - 2018) / 10.0
                for _, row in rows.iterrows():
                    lap_idx = int(row['lap_number']) - 1
                    pos_norm = (float(row['current_position']) - 1) / max(1, CONFIG.max_drivers - 1)
                    dynamic[lap_idx, slot, 0] = pos_norm
                    dynamic[lap_idx, slot, 1] = float(row['grid_position']) / CONFIG.max_drivers
                    lap_mean, lap_std = track_stats['lap_time_s']
                    lap_norm = self._zscore(row['lap_time_s'], lap_mean, lap_std)
                    dynamic[lap_idx, slot, 2] = lap_norm
                    dynamic[lap_idx, slot, 3] = self._zscore(row['lap_time_prev1'], lap_mean, lap_std)
                    dynamic[lap_idx, slot, 4] = self._zscore(row['lap_time_prev2'], lap_mean, lap_std)
                    gap_mean, gap_std = track_stats['gap_to_leader_s']
                    ahead_mean, ahead_std = track_stats['gap_to_ahead_s']
                    dynamic[lap_idx, slot, 5] = self._zscore(np.log1p(row['gap_to_leader_s']), gap_mean, gap_std)
                    dynamic[lap_idx, slot, 6] = self._zscore(np.log1p(row['gap_to_ahead_s']), ahead_mean, ahead_std)
                    dynamic[lap_idx, slot, 7] = float(row['laps_on_current_tyre']) / 50.0
                    dynamic[lap_idx, slot, 8] = float(row['pit_flag'])
                    dynamic[lap_idx, slot, 9] = float(row['drs_enabled'])
                    dynamic[lap_idx, slot, 10] = float(row['safety_car_this_lap'])
                    dynamic[lap_idx, slot, 11] = float(row['virtual_sc_this_lap'])
                    compound_tokens[lap_idx, slot] = self.vocab.compound_to_idx.get(row['tyre_compound'], 0)
                    rank_tokens[lap_idx, slot] = int(row['current_position']) - 1
                    lap_time_norm[lap_idx, slot] = lap_norm
                    alive[lap_idx, slot] = 1.0
            for lap_idx in range(max_lap):
                lap_no = lap_idx + 1
                frac = lap_no / total_ref
                remaining = (total_ref - lap_no) / total_ref
                global_feats[lap_idx, 0] = frac
                global_feats[lap_idx, 1] = remaining
                row = sdf[sdf['lap_number'] == lap_no].iloc[0]
                for g_idx, col in enumerate(['track_temperature', 'air_temperature', 'humidity', 'wind_speed'], start=2):
                    mean, std = self.weather_scaler[col]
                    global_feats[lap_idx, g_idx] = self._zscore(row[col], mean, std)
                global_feats[lap_idx, 6] = float(row['has_rain'])
                p_mean, p_std = self.weather_scaler['pressure']
                global_feats[lap_idx, 7] = self._zscore(row['pressure'], p_mean, p_std)
            position_norm = rank_tokens.astype(np.float32)
            if CONFIG.max_drivers > 1:
                position_norm /= (CONFIG.max_drivers - 1)
            sessions.append(SessionTensor(
                session_key=session_key,
                year=int(sdf['year'].iloc[0]),
                track_token=track_token,
                dynamic_numeric=dynamic,
                global_numeric=global_feats,
                rank_tokens=rank_tokens,
                compound_tokens=compound_tokens,
                alive_mask=alive,
                lap_time_norm=lap_time_norm,
                position_norm=position_norm,
                driver_tokens=driver_tokens,
                team_tokens=team_tokens,
                static_numeric=static_numeric,
            ))
        return sessions
builder = SessionTensorBuilder(CONFIG, vocab, track_scalers, weather_scaler)
session_tensors = builder.build(raw_df)
print(f"Tensorized sessions: {len(session_tensors)}")


Sessions: 100%|██████████| 168/168 [00:04<00:00, 34.41it/s]

Tensorized sessions: 167






## Dataset: full context, next-lap target
Each sample contains all laps up to `t` and predicts lap `t+1`. We pad contexts per batch and build masks automatically.


In [55]:
class LapContextDataset(Dataset):
    def __init__(self, sessions: Sequence[SessionTensor], cfg: TrainerConfig, years: Tuple[int, ...]):
        self.sessions = [s for s in sessions if s.year in years]
        self.cfg = cfg
        self.indices: List[Tuple[int, int]] = []
        for sess_idx, sess in enumerate(self.sessions):
            T = sess.dynamic_numeric.shape[0]
            for target_idx in range(1, T):
                self.indices.append((sess_idx, target_idx))
    def __len__(self) -> int:
        return len(self.indices)
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sess_idx, target_idx = self.indices[idx]
        sess = self.sessions[sess_idx]
        ctx_slice = slice(0, target_idx)
        return {
            'dynamic': torch.from_numpy(sess.dynamic_numeric[ctx_slice]).float(),
            'global': torch.from_numpy(sess.global_numeric[ctx_slice]).float(),
            'compound_tokens': torch.from_numpy(sess.compound_tokens[ctx_slice]).long(),
            'rank_tokens': torch.from_numpy(sess.rank_tokens[ctx_slice]).long(),
            'alive_mask': torch.from_numpy(sess.alive_mask[ctx_slice]).float(),
            'context_length': torch.tensor(target_idx, dtype=torch.long),
            'driver_tokens': torch.from_numpy(sess.driver_tokens).long(),
            'team_tokens': torch.from_numpy(sess.team_tokens).long(),
            'static_numeric': torch.from_numpy(sess.static_numeric).float(),
            'track_token': torch.tensor(sess.track_token).long(),
            'target_lap_time': torch.from_numpy(sess.lap_time_norm[target_idx]).float(),
            'target_position': torch.from_numpy(sess.position_norm[target_idx]).float(),
            'target_alive': torch.from_numpy(sess.alive_mask[target_idx]).float(),
        }
def pad_tensor(seq: torch.Tensor, target_len: int) -> torch.Tensor:
    if seq.shape[0] == target_len:
        return seq
    pad_shape = (target_len - seq.shape[0],) + tuple(seq.shape[1:])
    pad_tensor = torch.zeros(pad_shape, dtype=seq.dtype)
    return torch.cat([seq, pad_tensor], dim=0)
def collate_context(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    max_len = max(int(item['context_length']) for item in batch)
    out: Dict[str, torch.Tensor] = {}
    for key in ['dynamic', 'global', 'compound_tokens', 'rank_tokens', 'alive_mask']:
        out[key] = torch.stack([pad_tensor(item[key], max_len) for item in batch], dim=0)
    out['context_length'] = torch.stack([item['context_length'] for item in batch], dim=0)
    out['driver_tokens'] = torch.stack([item['driver_tokens'] for item in batch], dim=0)
    out['team_tokens'] = torch.stack([item['team_tokens'] for item in batch], dim=0)
    out['static_numeric'] = torch.stack([item['static_numeric'] for item in batch], dim=0)
    out['track_token'] = torch.stack([item['track_token'] for item in batch], dim=0)
    out['target_lap_time'] = torch.stack([item['target_lap_time'] for item in batch], dim=0)
    out['target_position'] = torch.stack([item['target_position'] for item in batch], dim=0)
    out['target_alive'] = torch.stack([item['target_alive'] for item in batch], dim=0)
    return out
train_dataset = LapContextDataset(session_tensors, CONFIG, CONFIG.train_years)
val_dataset = LapContextDataset(session_tensors, CONFIG, CONFIG.val_years)
print(f"Train samples: {len(train_dataset)} | Val: {len(val_dataset)}")
train_loader = DataLoader(train_dataset, batch_size=CONFIG.batch_size, shuffle=True, collate_fn=collate_context, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=CONFIG.batch_size, shuffle=False, collate_fn=collate_context, num_workers=0)


Train samples: 4837 | Val: 1412



## Lap context transformer
Driver tokens are concatenated with lap/slot embeddings, dynamic scalars, and global features, then passed through a transformer encoder. We gather the last-lap tokens to predict the next lap.


In [56]:
class LapContextTransformer(nn.Module):
    def __init__(self, cfg: TrainerConfig, vocab: LapVocabulary, dyn_dim: int, global_dim: int, static_dim: int):
        super().__init__()
        self.cfg = cfg
        self.driver_emb = nn.Embedding(vocab.num_drivers + 1, 32, padding_idx=0)
        self.team_emb = nn.Embedding(vocab.num_teams + 1, 16, padding_idx=0)
        self.compound_emb = nn.Embedding(vocab.num_compounds + 1, 8, padding_idx=0)
        self.rank_emb = nn.Embedding(cfg.max_drivers + 1, 8, padding_idx=0)
        self.slot_emb = nn.Embedding(cfg.max_drivers, 8)
        self.track_emb = nn.Embedding(vocab.num_tracks + 1, 16, padding_idx=0)
        self.lap_pos_emb = nn.Embedding(256, 32)
        self.static_proj = nn.Linear(static_dim + 32 + 16, 64)
        self.global_proj = nn.Linear(global_dim + 16 + 32, 64)
        self.token_proj = nn.Linear(dyn_dim + 8 + 8 + 8 + 64 + 64, 256)
        encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=1024, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.time_head = nn.Sequential(nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 1))
        self.position_head = nn.Sequential(nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 1))
    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        B, T_ctx, N, _ = batch['dynamic'].shape
        dyn = batch['dynamic'].to(DEVICE)
        if DEVICE.type == "mps" and dyn.device.type != "mps":
            raise RuntimeError("Expected MPS tensors; got CPU (fallback). Disable fallback or reduce batch size.")
        rank_tokens = batch['rank_tokens'].to(DEVICE)
        compound_tokens = batch['compound_tokens'].to(DEVICE)
        alive = batch['alive_mask'].to(DEVICE)
        global_feats = batch['global'].to(DEVICE)
        track_tokens = batch['track_token'].to(DEVICE)
        driver_tokens = batch['driver_tokens'].to(DEVICE)
        team_tokens = batch['team_tokens'].to(DEVICE)
        static_numeric = batch['static_numeric'].to(DEVICE)
        context_lengths = batch['context_length'].to(DEVICE)
        driver_emb = self.driver_emb(driver_tokens)
        team_emb = self.team_emb(team_tokens)
        static_feat = torch.cat([static_numeric, driver_emb, team_emb], dim=-1)
        static_feat = self.static_proj(static_feat).unsqueeze(1).expand(-1, T_ctx, -1, -1)
        lap_positions = torch.arange(T_ctx, device=DEVICE).view(1, T_ctx).expand(B, T_ctx)
        lap_emb = self.lap_pos_emb(torch.clamp(lap_positions, max=self.lap_pos_emb.num_embeddings - 1))
        track_emb = self.track_emb(track_tokens).unsqueeze(1).expand(-1, T_ctx, -1)
        global_cat = torch.cat([global_feats, track_emb, lap_emb], dim=-1)
        global_feat = self.global_proj(global_cat).unsqueeze(2).expand(-1, -1, N, -1)
        slot_idx = torch.arange(N, device=DEVICE).view(1, 1, N).expand(B, T_ctx, N)
        slot_emb = self.slot_emb(torch.clamp(slot_idx, max=self.cfg.max_drivers - 1))
        rank_emb = self.rank_emb(torch.clamp(rank_tokens, max=self.cfg.max_drivers))
        compound_emb = self.compound_emb(compound_tokens)
        tokens = torch.cat([dyn, rank_emb, compound_emb, slot_emb, static_feat, global_feat], dim=-1)
        tokens = self.token_proj(tokens)
        tokens = tokens.view(B, T_ctx * N, -1)
        key_padding = ~(alive.view(B, T_ctx * N).bool())
        encoded = self.encoder(tokens, src_key_padding_mask=key_padding)
        last_idx = torch.clamp(context_lengths - 1, min=0)
        base = last_idx * N
        gather = base.unsqueeze(1) + torch.arange(N, device=DEVICE)
        gather = torch.clamp(gather, max=T_ctx * N - 1)
        encoded_last = torch.gather(encoded, 1, gather.unsqueeze(-1).expand(-1, -1, encoded.shape[-1]))
        pred_time = self.time_head(encoded_last).squeeze(-1)
        pred_pos = self.position_head(encoded_last).squeeze(-1)
        return {'lap_time': pred_time, 'position': pred_pos}
model = LapContextTransformer(
    CONFIG,
    vocab,
    dyn_dim=session_tensors[0].dynamic_numeric.shape[-1],
    global_dim=session_tensors[0].global_numeric.shape[-1],
    static_dim=session_tensors[0].static_numeric.shape[-1],
).to(DEVICE)



## Training & evaluation loops
Use SmoothL1 loss on lap time and position, mask retired/padded drivers, and average per batch. The evaluation computes the same metrics without gradients.


In [57]:

loss_fn = nn.SmoothL1Loss(reduction='none')
optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG.learning_rate, weight_decay=CONFIG.weight_decay)

def step(model, batch, train=True):
    preds = model(batch)
    alive = batch['target_alive'].to(DEVICE)
    time_tgt = batch['target_lap_time'].to(DEVICE)
    pos_tgt = batch['target_position'].to(DEVICE)

    time_loss = loss_fn(preds['lap_time'], time_tgt)
    pos_loss = loss_fn(preds['position'], pos_tgt)
    denom = max(alive.sum().item(), 1.0)
    time_loss = (time_loss * alive).sum() / denom
    pos_loss = (pos_loss * alive).sum() / denom
    total = time_loss + pos_loss

    if train:
        optimizer.zero_grad()
        total.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    return {'time_loss': time_loss.item(), 'position_loss': pos_loss.item(), 'total': total.item()}


def run_epoch(loader, train=True):
    model.train() if train else model.eval()
    metrics = []
    print(f"Starting {'train' if train else 'val'} epoch with {len(loader)} batches")
    iterator = tqdm(loader, desc='train' if train else 'val', leave=True, mininterval=0.2, dynamic_ncols=True)
    steps = 0
    with torch.set_grad_enabled(train):
        for batch in iterator:
            stats = step(model, batch, train=train)
            iterator.set_postfix({k: f"{v:.3f}" for k, v in stats.items()})
            metrics.append(stats)
            steps += 1
            limit = CONFIG.max_train_steps_per_epoch if train else CONFIG.max_val_steps
            if limit and steps >= limit:
                break
    if not metrics:
        return {}
    return {k: float(np.mean([m[k] for m in metrics])) for k in metrics[0]}



## Train the model
Adjust `num_epochs`, `max_train_steps_per_epoch`, or `max_val_steps` in CONFIG if you want shorter runs.


In [None]:

train_history = []
val_history = []
for epoch in range(1, CONFIG.num_epochs + 1):
    train_metrics = run_epoch(train_loader, train=True)
    val_metrics = run_epoch(val_loader, train=False)
    train_history.append(train_metrics)
    val_history.append(val_metrics)
    print(f"Epoch {epoch}: train={train_metrics} | val={val_metrics}")


Starting train epoch with 2419 batches


train:  28%|██▊       | 673/2419 [03:42<1:29:33,  3.08s/it, time_loss=0.017, position_loss=0.040, total=0.057]


## Evaluate (optional)
Compute validation MAE on the normalized lap times.


In [10]:

model.eval()
mae_vals = []
with torch.no_grad():
    for batch_idx, batch in enumerate(val_loader):
        stats = step(model, batch, train=False)
        mae_vals.append(stats['time_loss'])
        if CONFIG.max_val_steps and batch_idx + 1 >= CONFIG.max_val_steps:
            break
val_mae = float(np.mean(mae_vals)) if mae_vals else float('nan')
print(f"Validation lap-time MAE (normalized units): {val_mae:.4f}")


Validation lap-time MAE (normalized units): nan



## Save checkpoint
Persist the trained weights along with config and vocab metadata.


In [None]:

checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': CONFIG.__dict__,
    'vocab': {
        'driver_to_idx': vocab.driver_to_idx,
        'team_to_idx': vocab.team_to_idx,
        'compound_to_idx': vocab.compound_to_idx,
        'track_to_idx': vocab.track_to_idx,
    },
}
ckpt_path = Path('models/lap_context_transformer.ckpt')
torch.save(checkpoint, ckpt_path)
print(f'Saved checkpoint to {ckpt_path.resolve()}')


In [68]:

import pandas as pd
import numpy as np
import torch

# Simple lap-by-lap simulation using the trained model and real drivers with light dynamics.
def simulate_race(model, track_id=None, total_laps=12, num_drivers=8, drivers=None, teams=None, seed=0,
                 safety_car_laps=None, vsc_laps=None, rain_laps=None, drs_enabled_laps=None):
    rng = np.random.default_rng(seed)
    model.eval()
    sc_set = set(safety_car_laps or [])
    vsc_set = set(vsc_laps or [])
    rain_set = set(rain_laps or [])
    drs_set = set(drs_enabled_laps or [])
    track_id = track_id or next(iter(track_scalers.keys()))
    track_stats = track_scalers[track_id]
    track_token = vocab.track_to_idx.get(track_id, 0)
    slots = CONFIG.max_drivers
    num_drivers = min(num_drivers, slots)

    if drivers is None:
        top_drivers = (
            raw_df[raw_df['year'].isin(CONFIG.train_years)]
            .groupby('driver_id')
            .size()
            .sort_values(ascending=False)
        )
        drivers = top_drivers.head(num_drivers).index.tolist()
    else:
        drivers = drivers[:num_drivers]

    if teams is None:
        first_team = (
            raw_df.dropna(subset=['team_id'])
            .drop_duplicates('driver_id')
            .set_index('driver_id')['team_id']
        )
        teams = [first_team.get(drv, "TEAM_SIM") for drv in drivers]
    else:
        teams = teams[:num_drivers]

    dyn_dim = session_tensors[0].dynamic_numeric.shape[-1]
    global_dim = session_tensors[0].global_numeric.shape[-1]
    static_dim = session_tensors[0].static_numeric.shape[-1]

    dynamic = np.zeros((total_laps, slots, dyn_dim), dtype=np.float32)
    rank_tokens = np.zeros((total_laps, slots), dtype=np.int64)
    compound_tokens = np.zeros((total_laps, slots), dtype=np.int64)
    alive = np.zeros((total_laps, slots), dtype=np.float32)
    global_feats = np.zeros((total_laps, global_dim), dtype=np.float32)
    driver_tokens = np.zeros((slots,), dtype=np.int64)
    team_tokens = np.zeros((slots,), dtype=np.int64)
    static_numeric = np.zeros((slots, static_dim), dtype=np.float32)
    laps_on_tyre = np.ones((slots,), dtype=np.float32)
    gap_to_leader = np.zeros((slots,), dtype=np.float32)

    compound_idx = vocab.compound_to_idx.get('UNKNOWN', 0)
    compound_tokens[:, :num_drivers] = compound_idx
    alive[:, :num_drivers] = 1.0

    for slot in range(num_drivers):
        driver_tokens[slot] = vocab.driver_to_idx.get(drivers[slot], 0)
        team_tokens[slot] = vocab.team_to_idx.get(teams[slot], 0)
        static_numeric[slot, 0] = float(slot + 1) / CONFIG.max_drivers
        static_numeric[slot, 1] = (CONFIG.train_years[-1] - 2018) / 10.0

    for lap_idx in range(total_laps):
        lap_no = lap_idx + 1
        frac = lap_no / total_laps
        remaining = (total_laps - lap_no) / total_laps
        global_feats[lap_idx, 0] = frac
        global_feats[lap_idx, 1] = remaining
        # sample weather around scaler mean with some variability (z-score units)
        for g_idx, col in enumerate(['track_temperature', 'air_temperature', 'humidity', 'wind_speed'], start=2):
            mean, std = weather_scaler[col]
            # draw actual value then z-score
            val = rng.normal(mean, std)
            global_feats[lap_idx, g_idx] = (val - mean) / (std if std > 1e-6 else 1.0)
        # occasional rain/pressure shifts
        rain_flag = 1.0 if lap_no in rain_set else float(rng.random() < 0.05)
        global_feats[lap_idx, 6] = rain_flag
        p_mean, p_std = weather_scaler['pressure']
        p_val = rng.normal(p_mean, p_std)
        global_feats[lap_idx, 7] = (p_val - p_mean) / (p_std if p_std > 1e-6 else 1.0)

    lap_mean, lap_std = track_stats['lap_time_s']
    gap_mean, gap_std = track_stats['gap_to_leader_s']
    ahead_mean, ahead_std = track_stats['gap_to_ahead_s']

    def gap_z(val, mean, std):
        return float((val - mean) / (std if std > 1e-6 else 1.0))

    # Seed lap 1 with grid order, mean lap time, and small base gaps.
    for slot in range(num_drivers):
        grid_norm = float(slot + 1) / CONFIG.max_drivers
        pos_norm = float(slot) / max(1, CONFIG.max_drivers - 1)
        dynamic[0, slot, 0] = pos_norm
        dynamic[0, slot, 1] = grid_norm
        dynamic[0, slot, 2] = 0.0
        dynamic[0, slot, 3] = 0.0
        dynamic[0, slot, 4] = 0.0
        base_gap = rng.uniform(0.3, 1.2) * slot
        gap_to_leader[slot] = base_gap
        gap_ahead = base_gap - (gap_to_leader[slot - 1] if slot > 0 else 0.0)
        dynamic[0, slot, 5] = gap_z(np.log1p(gap_to_leader[slot]), np.log1p(gap_mean), gap_std)
        dynamic[0, slot, 6] = gap_z(np.log1p(max(gap_ahead, 1e-3)), np.log1p(ahead_mean), ahead_std)
        dynamic[0, slot, 7] = laps_on_tyre[slot] / 50.0
        dynamic[0, slot, 8] = 0.0
        dynamic[0, slot, 9] = 1.0 if 1 in drs_set else 0.0
        dynamic[0, slot, 10] = 1.0 if 1 in sc_set else 0.0
        dynamic[0, slot, 11] = 1.0 if 1 in vsc_set else 0.0
        rank_tokens[0, slot] = slot

    records = []
    for slot in range(num_drivers):
        records.append({
            'lap': 1,
            'driver_id': drivers[slot],
            'driver_slot': slot + 1,
            'pred_position': slot + 1,
            'pred_lap_time_s': lap_mean,
        })

    debug_log = []

    with torch.no_grad():
        for lap_idx in range(1, total_laps):
            context_len = lap_idx
            batch = {
                'dynamic': torch.from_numpy(dynamic[:context_len]).unsqueeze(0).float().to(DEVICE),
                'global': torch.from_numpy(global_feats[:context_len]).unsqueeze(0).float().to(DEVICE),
                'compound_tokens': torch.from_numpy(compound_tokens[:context_len]).unsqueeze(0).long().to(DEVICE),
                'rank_tokens': torch.from_numpy(rank_tokens[:context_len]).unsqueeze(0).long().to(DEVICE),
                'alive_mask': torch.from_numpy(alive[:context_len]).unsqueeze(0).float().to(DEVICE),
                'context_length': torch.tensor([context_len], device=DEVICE),
                'driver_tokens': torch.from_numpy(driver_tokens).unsqueeze(0).long().to(DEVICE),
                'team_tokens': torch.from_numpy(team_tokens).unsqueeze(0).long().to(DEVICE),
                'static_numeric': torch.from_numpy(static_numeric).unsqueeze(0).float().to(DEVICE),
                'track_token': torch.tensor([track_token], device=DEVICE).long(),
            }

            preds = model(batch)
            lap_time_norm = preds['lap_time'][0, :num_drivers].cpu().numpy()
            # add per-driver pace offsets (derived from embedding index) and small noise
            pace_offsets = rng.normal(0.0, 0.05, size=num_drivers)
            lap_time_norm = lap_time_norm + pace_offsets
            lap_time_norm = lap_time_norm + rng.normal(0.0, 0.02, size=num_drivers)
            pos_norm = preds['position'][0, :num_drivers].cpu().numpy()
            pos_norm = np.clip(pos_norm, 0.0, 1.0)

            debug_log.append({
                'lap': lap_idx + 1,
                'min_z': float(lap_time_norm.min()),
                'max_z': float(lap_time_norm.max()),
                'mean_z': float(lap_time_norm.mean()),
            })

            ranks = np.argsort(pos_norm)
            slot_rank = np.zeros(num_drivers, dtype=np.int64)
            slot_rank[ranks] = np.arange(num_drivers)
            rank_tokens[lap_idx, :num_drivers] = slot_rank

            drift = rng.normal(0.0, 0.3, size=num_drivers) + slot_rank * 0.2
            gap_to_leader[:num_drivers] += drift.astype(np.float32)
            gap_to_leader[:num_drivers] = np.maximum(gap_to_leader[:num_drivers], 0.0)
            gap_to_leader[:num_drivers] -= gap_to_leader[:num_drivers].min()
            gap_ahead = np.diff(np.concatenate([[0.0], gap_to_leader[:num_drivers]]))

            for slot in range(num_drivers):
                laps_on_tyre[slot] = min(laps_on_tyre[slot] + 1, 50)
                if lap_idx > 0 and (lap_idx % 8 == 0) and slot == rng.integers(0, num_drivers):
                    laps_on_tyre[slot] = 1
                    dynamic[lap_idx, slot, 8] = 1.0
                dynamic[lap_idx, slot, 0] = pos_norm[slot]
                dynamic[lap_idx, slot, 1] = float(slot + 1) / CONFIG.max_drivers
                dynamic[lap_idx, slot, 2] = lap_time_norm[slot] + rng.normal(0.0, 0.01)
                dynamic[lap_idx, slot, 3] = dynamic[lap_idx - 1, slot, 2]
                dynamic[lap_idx, slot, 4] = dynamic[lap_idx - 2, slot, 2] if lap_idx >= 2 else dynamic[lap_idx - 1, slot, 2]
                dynamic[lap_idx, slot, 5] = gap_z(np.log1p(gap_to_leader[slot]), np.log1p(gap_mean), gap_std)
                gap_ahead_val = max(gap_ahead[slot], 1e-3) if slot < len(gap_ahead) else gap_ahead[-1]
                dynamic[lap_idx, slot, 6] = gap_z(np.log1p(gap_ahead_val), np.log1p(ahead_mean), ahead_std)
                dynamic[lap_idx, slot, 7] = laps_on_tyre[slot] / 50.0
                dynamic[lap_idx, slot, 8] = dynamic[lap_idx, slot, 8] if dynamic[lap_idx, slot, 8] else 0.0
                dynamic[lap_idx, slot, 9] = 1.0 if (lap_idx + 1) in drs_set else 0.0
                dynamic[lap_idx, slot, 10] = 1.0 if (lap_idx + 1) in sc_set else 0.0
                dynamic[lap_idx, slot, 11] = 1.0 if (lap_idx + 1) in vsc_set else 0.0
            lap_times_s = lap_time_norm * lap_std + lap_mean
            for slot in range(num_drivers):
                records.append({
                    'lap': lap_idx + 1,
                    'driver_id': drivers[slot],
                    'driver_slot': slot + 1,
                    'pred_position': int(slot_rank[slot] + 1),
                    'pred_lap_time_s': float(lap_times_s[slot]),
                })

    df = pd.DataFrame(records)
    lap_time_table = df.pivot(index='lap', columns='driver_id', values='pred_lap_time_s')
    position_table = df.pivot(index='lap', columns='driver_id', values='pred_position')
    debug_df = pd.DataFrame(debug_log)
    return df, lap_time_table, position_table, debug_df

sim_df, lap_time_table, position_table, debug_df = simulate_race(
    model,
    track_id=list(track_scalers.keys())[0],
    total_laps=12,
    num_drivers=8,
)
print(sim_df.head(20))
print("Lap times (s):")
print(lap_time_table.round(3))
print("Positions (1=leader):")
print(position_table.astype(int))
print("Predicted lap-time z-score stats per lap:")
print(debug_df)


    lap driver_id  driver_slot  pred_position  pred_lap_time_s
0     1       HAM            1              1       102.686060
1     1       VET            2              2       102.686060
2     1       BOT            3              3       102.686060
3     1       SAI            4              4       102.686060
4     1       PER            5              5       102.686060
5     1       RIC            6              6       102.686060
6     1       RAI            7              7       102.686060
7     1       VER            8              8       102.686060
8     2       HAM            1              1       101.905184
9     2       VET            2              2       102.659774
10    2       BOT            3              3       101.772095
11    2       SAI            4              4       101.824601
12    2       PER            5              5       101.644281
13    2       RIC            6              6       101.654776
14    2       RAI            7              7       101

In [None]:

import torch
import pathlib

ckpt_path = pathlib.Path('models/lap_context_transformer.ckpt')
if ckpt_path.exists():
    torch.serialization.add_safe_globals([pathlib.PosixPath])
    checkpoint = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print(f"Loaded checkpoint from {ckpt_path}")
else:
    print(f"Checkpoint not found at {ckpt_path}")


In [69]:

# Example simulation: Monaco with supplied grid and event laps
monaco_track = 'monaco' if 'monaco' in track_scalers else list(track_scalers.keys())[0]
driver_ids = [
    'VER','PIA','NOR','LEC','RUS','ALO','OCO','HAM','HUL','STR',
    'BOR','BEA','SAI','TSU','ANT','ALB','HAD','LAW','GAS','COL',
]
teams = [
    'red_bull','mclaren','mclaren','ferrari','mercedes','aston_martin','haas','ferrari','sauber','aston_martin',
    'sauber','haas','williams','red_bull','mercedes','williams','racing_bulls','racing_bulls','alpine','alpine',
]
sc_laps = [12, 13, 14, 15]
vsc_laps = []
rain_laps = [22, 23, 24]
drs_laps = list(range(3, 26))

sim_df, lap_time_table, position_table, debug_df = simulate_race(
    model,
    track_id=monaco_track,
    total_laps=25,
    num_drivers=20,
    drivers=driver_ids,
    teams=teams,
    safety_car_laps=sc_laps,
    vsc_laps=vsc_laps,
    rain_laps=rain_laps,
    drs_enabled_laps=drs_laps,
)
print(sim_df.head(30))
print("Lap times (s):")
print(lap_time_table.round(3))
print("Positions (1=leader):")
print(position_table.astype(int))
print("Predicted lap-time z-score stats per lap:")
print(debug_df)


    lap driver_id  driver_slot  pred_position  pred_lap_time_s
0     1       VER            1              1        81.429972
1     1       PIA            2              2        81.429972
2     1       NOR            3              3        81.429972
3     1       LEC            4              4        81.429972
4     1       RUS            5              5        81.429972
5     1       ALO            6              6        81.429972
6     1       OCO            7              7        81.429972
7     1       HAM            8              8        81.429972
8     1       HUL            9              9        81.429972
9     1       STR           10             10        81.429972
10    1       BOR           11             11        81.429972
11    1       BEA           12             12        81.429972
12    1       SAI           13             13        81.429972
13    1       TSU           14             14        81.429972
14    1       ANT           15             15        81