
# Lap-to-Lap Race State Transformer

This notebook prepares the FastF1-derived lap dataset, tensorizes each race into driver/lap sequences, and trains a multi-head transformer that predicts the next-lap state (lap time, pit flag, tyre compound, retirements) for every car simultaneously. The model follows the architecture outlined in the design notes: per-car tokens enriched with driver/team/track embeddings, autoregressive supervision with teacher forcing, and multi-task losses.



## Environment setup
Install the Python packages required for data processing and model training. PyTorch provides the transformer layers; pandas/numpy handle the preprocessing workflow.


In [41]:

%pip install --quiet torch pandas numpy scikit-learn tqdm


Note: you may need to restart the kernel to use updated packages.



## Imports and configuration
The configuration block centralises hyperparameters so it is easy to tweak sequence length, batch size, train/validation splits, or debugging limits.


In [42]:

import math
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import StandardScaler
from tqdm.auto import tqdm

TORCH_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@dataclass
class TrainerConfig:
    dataset_path: Path
    seq_len: int = 8
    max_drivers: int = 20
    min_laps_per_session: int = 10
    train_years: Tuple[int, ...] = (2018, 2019, 2020, 2021)
    val_years: Tuple[int, ...] = (2023,)
    test_years: Tuple[int, ...] = (2024, 2025)
    batch_size: int = 2
    num_epochs: int = 1
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    #max_train_steps_per_epoch: Optional[int] = 50
    #max_val_steps: Optional[int] = 25
    max_train_steps_per_epoch: Optional[int] = None
    max_val_steps: Optional[int] = None
    debug_num_sessions: Optional[int] = None
    seed: int = 42

CONFIG = TrainerConfig(
    dataset_path=Path('fastf1_lap_dataset2.csv'),
)

np.random.seed(CONFIG.seed)
torch.manual_seed(CONFIG.seed)


<torch._C.Generator at 0x7fb2704eac90>


## Data loading helpers
We load the lap-level dataset, clean missing values, derive pit flags, and compute per-track scalers for lap times and gaps. Weather features use global scalers.


In [43]:

def load_lap_dataframe(csv_path: Path) -> pd.DataFrame:
    df = pd.read_csv(csv_path)
    df = df.dropna(subset=['session_key', 'driver_id', 'lap_number'])
    df['lap_number'] = df['lap_number'].astype(int)
    df['current_position'] = df['current_position'].fillna(df['lap_number'])
    df['grid_position'] = df['grid_position'].fillna(df.groupby(['session_key', 'driver_id'])['grid_position'].transform('first'))
    df['grid_position'] = df['grid_position'].fillna(df['current_position'])
    df['laps_on_current_tyre'] = df['laps_on_current_tyre'].fillna(1).clip(lower=1)
    # simple pit detection based on tyre age reset
    df = df.sort_values(['session_key', 'driver_id', 'lap_number'])
    df['pit_flag'] = False
    for (session_key, driver_id), group in df.groupby(['session_key', 'driver_id']):
        laps = group['laps_on_current_tyre'].values
        pit_flags = np.zeros_like(laps, dtype=bool)
        prev = laps[0]
        for idx in range(1, len(laps)):
            pit_flags[idx] = laps[idx] <= prev - 1
            prev = laps[idx]
        df.loc[group.index, 'pit_flag'] = pit_flags
    # fill lap_time gaps per session
    df['lap_time_s'] = df.groupby('session_key')['lap_time_s'].transform(lambda s: s.fillna(s.median()))
    df['lap_time_s'] = df['lap_time_s'].fillna(df['lap_time_s'].median())
    df['gap_to_leader_s'] = df['gap_to_leader_s'].fillna(0)
    df['gap_to_ahead_s'] = df['gap_to_ahead_s'].fillna(0)
    df['tyre_compound'] = df['tyre_compound'].fillna('UNKNOWN')
    df['track_temperature'] = df['track_temperature'].fillna(df['track_temperature'].median())
    df['air_temperature'] = df['air_temperature'].fillna(df['air_temperature'].median())
    df['humidity'] = df['humidity'].fillna(df['humidity'].median())
    df['wind_speed'] = df['wind_speed'].fillna(df['wind_speed'].median())
    df['wind_direction'] = df['wind_direction'].fillna(0)
    df['rainfall'] = df['rainfall'].fillna(False)
    df['has_rain'] = df['has_rain'].fillna(False)
    df['safety_car_this_lap'] = df['safety_car_this_lap'].fillna(False)
    df['virtual_sc_this_lap'] = df['virtual_sc_this_lap'].fillna(False)
    df['drs_enabled'] = df['drs_enabled'].fillna(False)
    df['current_position'] = df['current_position'].clip(lower=1, upper=CONFIG.max_drivers)
    df = df.reset_index(drop=True)
    return df

def compute_track_scalers(df: pd.DataFrame) -> Dict[str, Dict[str, Tuple[float, float]]]:
    stats = {}
    global_defaults = {
        'lap_time_s': (df['lap_time_s'].mean(), df['lap_time_s'].std() or 1.0),
        'gap_to_leader_s': (df['gap_to_leader_s'].mean(), df['gap_to_leader_s'].std() or 1.0),
        'gap_to_ahead_s': (df['gap_to_ahead_s'].mean(), df['gap_to_ahead_s'].std() or 1.0),
    }
    for circuit_id, grp in df.groupby('circuit_id'):
        stats[circuit_id] = {}
        for col in ['lap_time_s', 'gap_to_leader_s', 'gap_to_ahead_s']:
            series = grp[col].dropna()
            if series.empty:
                stats[circuit_id][col] = global_defaults[col]
            else:
                mean = float(series.mean())
                std = float(series.std()) or 1.0
                stats[circuit_id][col] = (mean, std)
    return stats

def compute_weather_scaler(df: pd.DataFrame) -> Dict[str, Tuple[float, float]]:
    scalers = {}
    for col in ['track_temperature', 'air_temperature', 'humidity', 'wind_speed']:
        series = df[col].astype(float)
        scalers[col] = (float(series.mean()), float(series.std() or 1.0))
    return scalers



## Vocabulary builders
We map every categorical field to contiguous integer IDs so the model can use embeddings.


In [44]:

@dataclass
class LapVocabulary:
    driver_to_idx: Dict[str, int]
    team_to_idx: Dict[str, int]
    compound_to_idx: Dict[str, int]
    track_to_idx: Dict[str, int]

    @property
    def num_drivers(self) -> int:
        return len(self.driver_to_idx) + 1  # include padding

    @property
    def num_teams(self) -> int:
        return len(self.team_to_idx) + 1

    @property
    def num_compounds(self) -> int:
        return len(self.compound_to_idx) + 1

    @property
    def num_tracks(self) -> int:
        return len(self.track_to_idx) + 1


def build_vocab(df: pd.DataFrame) -> LapVocabulary:
    drivers = sorted(df['driver_id'].dropna().unique().tolist())
    teams = sorted(df['team_id'].dropna().unique().tolist())
    compounds = sorted(df['tyre_compound'].dropna().unique().tolist())
    tracks = sorted(df['circuit_id'].dropna().unique().tolist())
    driver_to_idx = {drv: idx + 1 for idx, drv in enumerate(drivers)}
    team_to_idx = {team: idx + 1 for idx, team in enumerate(teams)}
    compound_to_idx = {comp: idx + 1 for idx, comp in enumerate(compounds)}
    track_to_idx = {trk: idx + 1 for idx, trk in enumerate(tracks)}
    compound_to_idx.setdefault('UNKNOWN', len(compound_to_idx) + 1)
    return LapVocabulary(driver_to_idx, team_to_idx, compound_to_idx, track_to_idx)



## Session tensor builder
Convert each race (session) into padded tensors. Each session stores dynamic features for shape `[laps, max_drivers, F_dyn]`, categorical tokens, and targets aligned for next-step prediction.


In [45]:

@dataclass
class SessionTensor:
    session_key: str
    year: int
    track_token: int
    dynamic_numeric: np.ndarray  # [T, max_drivers, F_dyn]
    global_numeric: np.ndarray   # [T, F_global]
    rank_tokens: np.ndarray      # [T, max_drivers]
    compound_tokens: np.ndarray  # [T, max_drivers]
    alive_mask: np.ndarray       # [T, max_drivers]
    pit_flags: np.ndarray        # [T, max_drivers]
    lap_time_norm: np.ndarray    # [T, max_drivers]
    driver_tokens: np.ndarray    # [max_drivers]
    team_tokens: np.ndarray      # [max_drivers]
    static_numeric: np.ndarray   # [max_drivers, F_static]


class SessionTensorBuilder:
    def __init__(self, cfg: TrainerConfig, vocab: LapVocabulary, track_scalers, weather_scaler):
        self.cfg = cfg
        self.vocab = vocab
        self.track_scalers = track_scalers
        self.weather_scaler = weather_scaler

    def _zscore(self, value, mean, std):
        if pd.isna(value):
            return 0.0
        return float((value - mean) / std)

    def build(self, df: pd.DataFrame) -> List[SessionTensor]:
        sessions: List[SessionTensor] = []
        grouped = df.groupby('session_key', sort=False)
        for session_idx, (session_key, sdf) in enumerate(tqdm(grouped, desc='Sessions')):
            if self.cfg.debug_num_sessions and session_idx >= self.cfg.debug_num_sessions:
                break
            sdf = sdf.sort_values(['lap_number', 'driver_id']).copy()
            max_lap = int(sdf['lap_number'].max())
            if max_lap < self.cfg.min_laps_per_session:
                continue
            drivers = sorted(sdf['driver_id'].unique().tolist())
            if len(drivers) > self.cfg.max_drivers:
                drivers = drivers[: self.cfg.max_drivers]
            driver_to_slot = {drv: idx for idx, drv in enumerate(drivers)}
            slots = self.cfg.max_drivers
            dyn_features = 12
            dynamic = np.zeros((max_lap, slots, dyn_features), dtype=np.float32)
            rank_tokens = np.zeros((max_lap, slots), dtype=np.int64)
            compound_tokens = np.zeros((max_lap, slots), dtype=np.int64)
            alive = np.zeros((max_lap, slots), dtype=np.float32)
            pit_flags = np.zeros((max_lap, slots), dtype=np.float32)
            lap_time_norm = np.zeros((max_lap, slots), dtype=np.float32)
            global_feats = np.zeros((max_lap, 8), dtype=np.float32)
            driver_tokens = np.zeros((slots,), dtype=np.int64)
            team_tokens = np.zeros((slots,), dtype=np.int64)
            static_numeric = np.zeros((slots, 2), dtype=np.float32)
            circuit_id = sdf['circuit_id'].iloc[0]
            track_stats = self.track_scalers.get(circuit_id, self.track_scalers[next(iter(self.track_scalers))])
            track_token = self.vocab.track_to_idx.get(circuit_id, 0)
            total_laps = int(sdf['total_race_laps'].dropna().max() or max_lap)
            total_laps = max(total_laps, max_lap)
            for driver_id, driver_rows in sdf.groupby('driver_id'):
                if driver_id not in driver_to_slot:
                    continue
                slot = driver_to_slot[driver_id]
                driver_tokens[slot] = self.vocab.driver_to_idx.get(driver_id, 0)
                team_name = driver_rows['team_id'].iloc[0]
                team_tokens[slot] = self.vocab.team_to_idx.get(team_name, 0)
                static_numeric[slot, 0] = float(driver_rows['grid_position'].iloc[0] / self.cfg.max_drivers)
                static_numeric[slot, 1] = float(driver_rows['year'].iloc[0] - 2018) / 10.0
                for _, row in driver_rows.iterrows():
                    lap_idx = int(row['lap_number']) - 1
                    dynamic[lap_idx, slot, 0] = (float(row['current_position']) - 1) / (self.cfg.max_drivers - 1)
                    dynamic[lap_idx, slot, 1] = float(row['grid_position']) / self.cfg.max_drivers
                    lap_mean, lap_std = track_stats['lap_time_s']
                    lap_norm = self._zscore(row['lap_time_s'], lap_mean, lap_std)
                    dynamic[lap_idx, slot, 2] = lap_norm
                    dynamic[lap_idx, slot, 3] = self._zscore(row.get('lap_time_prev1', row['lap_time_s']), lap_mean, lap_std)
                    dynamic[lap_idx, slot, 4] = self._zscore(row.get('lap_time_prev2', row['lap_time_s']), lap_mean, lap_std)
                    gap_mean, gap_std = track_stats['gap_to_leader_s']
                    ahead_mean, ahead_std = track_stats['gap_to_ahead_s']
                    dynamic[lap_idx, slot, 5] = self._zscore(np.log1p(row['gap_to_leader_s']), np.log1p(gap_mean), gap_std or 1.0)
                    dynamic[lap_idx, slot, 6] = self._zscore(np.log1p(row['gap_to_ahead_s']), np.log1p(ahead_mean), ahead_std or 1.0)
                    dynamic[lap_idx, slot, 7] = float(row['laps_on_current_tyre']) / 50.0
                    dynamic[lap_idx, slot, 8] = float(row['pit_flag'])
                    dynamic[lap_idx, slot, 9] = float(row['drs_enabled'])
                    dynamic[lap_idx, slot,10] = float(row['safety_car_this_lap'])
                    dynamic[lap_idx, slot,11] = float(row['virtual_sc_this_lap'])
                    compound_tokens[lap_idx, slot] = self.vocab.compound_to_idx.get(row['tyre_compound'], 0)
                    rank_tokens[lap_idx, slot] = int(row['current_position']) - 1
                    lap_time_norm[lap_idx, slot] = lap_norm
                    pit_flags[lap_idx, slot] = float(row['pit_flag'])
                    alive[lap_idx, slot] = 1.0
            for lap_idx in range(max_lap):
                lap_no = lap_idx + 1
                frac = lap_no / total_laps
                remaining = (total_laps - lap_no) / total_laps
                global_feats[lap_idx, 0] = frac
                global_feats[lap_idx, 1] = remaining
                track_temp_mu, track_temp_std = self.weather_scaler['track_temperature']
                air_mu, air_std = self.weather_scaler['air_temperature']
                hum_mu, hum_std = self.weather_scaler['humidity']
                wind_mu, wind_std = self.weather_scaler['wind_speed']
                rows = sdf[sdf['lap_number'] == lap_no]
                if rows.empty:
                    continue
                row = rows.iloc[0]
                global_feats[lap_idx, 2] = self._zscore(row['track_temperature'], track_temp_mu, track_temp_std)
                global_feats[lap_idx, 3] = self._zscore(row['air_temperature'], air_mu, air_std)
                global_feats[lap_idx, 4] = self._zscore(row['humidity'], hum_mu, hum_std)
                global_feats[lap_idx, 5] = self._zscore(row['wind_speed'], wind_mu, wind_std)
                global_feats[lap_idx, 6] = float(row['has_rain'])
                global_feats[lap_idx, 7] = float(row['pressure'] if not pd.isna(row['pressure']) else 0.0)
            sessions.append(SessionTensor(
                session_key=session_key,
                year=int(sdf['year'].iloc[0]),
                track_token=track_token,
                dynamic_numeric=dynamic,
                global_numeric=global_feats,
                rank_tokens=rank_tokens,
                compound_tokens=compound_tokens,
                alive_mask=alive,
                pit_flags=pit_flags,
                lap_time_norm=lap_time_norm,
                driver_tokens=driver_tokens,
                team_tokens=team_tokens,
                static_numeric=static_numeric,
            ))
        return sessions



## Build session tensors
Load the CSV, create scalers/vocabulary, and construct tensors for every race.


In [46]:

raw_df = load_lap_dataframe(CONFIG.dataset_path)
# derive lagged lap times for dynamic context
raw_df['lap_time_prev1'] = raw_df.groupby(['session_key', 'driver_id'])['lap_time_s'].shift(1)
raw_df['lap_time_prev2'] = raw_df.groupby(['session_key', 'driver_id'])['lap_time_s'].shift(2)
raw_df['lap_time_prev1'] = raw_df['lap_time_prev1'].fillna(raw_df['lap_time_s'])
raw_df['lap_time_prev2'] = raw_df['lap_time_prev2'].fillna(raw_df['lap_time_s'])

track_scalers = compute_track_scalers(raw_df)
weather_scaler = compute_weather_scaler(raw_df)
vocab = build_vocab(raw_df)

builder = SessionTensorBuilder(CONFIG, vocab, track_scalers, weather_scaler)
session_tensors = builder.build(raw_df)
print(f"Total sessions tensorized: {len(session_tensors)}")


Sessions: 100%|██████████| 168/168 [00:10<00:00, 16.78it/s]

Total sessions tensorized: 167






## Dataset and DataLoader
Create sequence samples with teacher forcing. Each sample consists of `seq_len` laps and predicts lap `t+1` targets.


In [47]:

class LapSequenceDataset(Dataset):
    def __init__(self, sessions: List[SessionTensor], cfg: TrainerConfig, split_years: Tuple[int, ...]):
        self.sessions = [s for s in sessions if s.year in split_years]
        self.cfg = cfg
        self.indices = []  # (session_idx, lap_start)
        for sess_idx, sess in enumerate(self.sessions):
            T = sess.dynamic_numeric.shape[0]
            max_start = T - cfg.seq_len - 1
            if max_start < 0:
                continue
            for start in range(max_start + 1):
                self.indices.append((sess_idx, start))

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int):
        sess_idx, start = self.indices[idx]
        sess = self.sessions[sess_idx]
        seq_slice = slice(start, start + self.cfg.seq_len)
        next_slice = slice(start + 1, start + self.cfg.seq_len + 1)
        dynamic = torch.from_numpy(sess.dynamic_numeric[seq_slice]).float()
        global_feats = torch.from_numpy(sess.global_numeric[seq_slice]).float()
        compound_tokens = torch.from_numpy(sess.compound_tokens[seq_slice]).long()
        rank_tokens = torch.from_numpy(sess.rank_tokens[seq_slice]).long()
        alive_curr = torch.from_numpy(sess.alive_mask[seq_slice]).float()
        alive_next = torch.from_numpy(sess.alive_mask[next_slice]).float()
        lap_time_target = torch.from_numpy(sess.lap_time_norm[next_slice]).float()
        pit_target = torch.from_numpy(sess.pit_flags[next_slice]).float()
        compound_next = torch.from_numpy(sess.compound_tokens[next_slice]).long()
        dnf_target = ((alive_curr == 1.0) & (alive_next == 0.0)).float()
        return {
            'dynamic': dynamic,
            'global': global_feats,
            'compound_tokens': compound_tokens,
            'rank_tokens': rank_tokens,
            'alive_curr': alive_curr,
            'alive_next': alive_next,
            'driver_tokens': torch.from_numpy(sess.driver_tokens).long(),
            'team_tokens': torch.from_numpy(sess.team_tokens).long(),
            'static_numeric': torch.from_numpy(sess.static_numeric).float(),
            'track_token': torch.tensor(sess.track_token).long(),
            'targets': {
                'lap_time': lap_time_target,
                'pit': pit_target,
                'compound': compound_next,
                'dnf': dnf_target,
            },
        }

def collate_batch(batch):
    out = {}
    for key in ['dynamic', 'global', 'compound_tokens', 'rank_tokens', 'alive_curr', 'alive_next']:
        out[key] = torch.stack([item[key] for item in batch], dim=0)
    out['driver_tokens'] = torch.stack([item['driver_tokens'] for item in batch], dim=0)
    out['team_tokens'] = torch.stack([item['team_tokens'] for item in batch], dim=0)
    out['static_numeric'] = torch.stack([item['static_numeric'] for item in batch], dim=0)
    out['track_token'] = torch.stack([item['track_token'] for item in batch], dim=0)
    targets = {}
    for tgt_key in batch[0]['targets'].keys():
        targets[tgt_key] = torch.stack([item['targets'][tgt_key] for item in batch], dim=0)
    out['targets'] = targets
    return out

train_dataset = LapSequenceDataset(session_tensors, CONFIG, CONFIG.train_years)
val_dataset = LapSequenceDataset(session_tensors, CONFIG, CONFIG.val_years)
test_dataset = LapSequenceDataset(session_tensors, CONFIG, CONFIG.test_years)

print(f"Train samples: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=CONFIG.batch_size, shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(val_dataset, batch_size=CONFIG.batch_size, shuffle=False, collate_fn=collate_batch)


Train samples: 4256 | Val: 1216 | Test: 2525



## Model definition
The transformer tokenizes each driver per lap. Static driver/team embeddings and global lap features are concatenated with the dynamic scalars. Multi-task heads predict the next lap time, pit decision, tyre compound, and retirement probability.


In [48]:

class LapStateTransformer(nn.Module):
    def __init__(self, cfg: TrainerConfig, vocab: LapVocabulary, dyn_dim: int, global_dim: int, static_dim: int, rank_vocab: int = 20):
        super().__init__()
        self.cfg = cfg
        d_model = 256
        self.driver_emb = nn.Embedding(vocab.num_drivers + 1, 32, padding_idx=0)
        super().__init__()
        self.cfg = cfg
        d_model = 256
        self.driver_emb = nn.Embedding(vocab.num_drivers + 1, 32, padding_idx=0)
        self.team_emb = nn.Embedding(vocab.num_teams + 1, 16, padding_idx=0)
        self.compound_emb = nn.Embedding(vocab.num_compounds + 1, 8, padding_idx=0)
        self.rank_emb = nn.Embedding(rank_vocab + 1, 8, padding_idx=0)
        self.track_emb = nn.Embedding(vocab.num_tracks + 1, 16, padding_idx=0)
        self.static_proj = nn.Linear(static_dim + 32 + 16, 64)
        self.global_proj = nn.Linear(global_dim + 16, 64)
        self.token_proj = nn.Linear(dyn_dim + 8 + 8 + 64 + 64, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8, dim_feedforward=1024, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.time_head = nn.Sequential(nn.Linear(d_model, 128), nn.ReLU(), nn.Linear(128, 1))
        self.rank_delta_head = nn.Sequential(nn.Linear(d_model, 128), nn.ReLU(), nn.Linear(128, 1))

    def forward(self, batch):
        B, T, N, F_dyn = batch['dynamic'].shape
        dyn = batch['dynamic'].to(TORCH_DEVICE)
        rank_tokens = batch['rank_tokens'].to(TORCH_DEVICE)
        compound_tokens = batch['compound_tokens'].to(TORCH_DEVICE)
        driver_tokens = batch['driver_tokens'].to(TORCH_DEVICE)
        team_tokens = batch['team_tokens'].to(TORCH_DEVICE)
        static_numeric = batch['static_numeric'].to(TORCH_DEVICE)
        global_feats = batch['global'].to(TORCH_DEVICE)
        track_tokens = batch['track_token'].to(TORCH_DEVICE)
        alive_mask = batch['alive_curr'].to(TORCH_DEVICE)

        driver_emb = self.driver_emb(driver_tokens)
        team_emb = self.team_emb(team_tokens)
        static_cat = torch.cat([static_numeric, driver_emb, team_emb], dim=-1)
        static_feat = self.static_proj(static_cat)
        static_feat = static_feat.unsqueeze(1).expand(-1, T, -1, -1)

        track_emb = self.track_emb(track_tokens).unsqueeze(1).expand(-1, T, -1)
        global_cat = torch.cat([global_feats, track_emb], dim=-1)
        global_feat = self.global_proj(global_cat)
        global_feat = global_feat.unsqueeze(2).expand(-1, -1, N, -1)

        rank_emb = self.rank_emb(rank_tokens)
        compound_emb = self.compound_emb(compound_tokens)

        x = torch.cat([dyn, rank_emb, compound_emb, static_feat, global_feat], dim=-1)
        x = self.token_proj(x)
        x = x.view(B * T, N, -1)
        key_padding_mask = ~(alive_mask.view(B * T, N).bool())
        encoded = self.encoder(x, src_key_padding_mask=key_padding_mask)
        encoded = encoded.view(B, T, N, -1)
        logits_time = self.time_head(encoded).squeeze(-1)
        logits_rank = self.rank_delta_head(encoded).squeeze(-1)
        return {
            'lap_time': logits_time,
            'rank_delta': logits_rank,
        }

class LapStateTransformer(nn.Module):



## Training utilities
Multi-task losses (Huber for lap time, BCE for pit/DNF, cross entropy for compound) are masked by `alive_next` so retired or padded cars do not contribute.


In [49]:

loss_time = nn.SmoothL1Loss(reduction='none')

model = LapStateTransformer(
    cfg=CONFIG,
    vocab=vocab,
    dyn_dim=train_dataset.sessions[0].dynamic_numeric.shape[-1],
    global_dim=train_dataset.sessions[0].global_numeric.shape[-1],
    static_dim=train_dataset.sessions[0].static_numeric.shape[-1],
).to(TORCH_DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG.learning_rate, weight_decay=CONFIG.weight_decay)


def compute_losses(preds, batch):
    targets = batch['targets']
    alive_next = batch['alive_next'].to(TORCH_DEVICE)

    time_pred = preds['lap_time']
    time_tgt = targets['lap_time'].to(TORCH_DEVICE)
    time_loss = loss_time(time_pred, time_tgt)
    time_loss = (time_loss * alive_next).sum() / max(alive_next.sum().item(), 1.0)

    rank_pred = preds['rank_delta']
    rank_tgt = targets['rank_delta'].to(TORCH_DEVICE)
    rank_loss = loss_time(rank_pred, rank_tgt)
    rank_loss = (rank_loss * alive_next).sum() / max(alive_next.sum().item(), 1.0)

    total = time_loss + rank_loss
    metrics = {
        'time_loss': time_loss.item(),
        'rank_loss': rank_loss.item(),
        'total': total.item(),
    }
    return total, metrics


def run_epoch(loader, train: bool = True):
    model.train() if train else model.eval()
    epoch_metrics = []
    steps = 0
    context = tqdm(loader, desc='train' if train else 'val', leave=False)
    for batch in context:
        if train:
            optimizer.zero_grad()
        preds = model(batch)
        loss, metrics = compute_losses(preds, batch)
        if train:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
        epoch_metrics.append(metrics)
        context.set_postfix({k: f"{v:.3f}" for k, v in metrics.items()})
        steps += 1
        limit = CONFIG.max_train_steps_per_epoch if train else CONFIG.max_val_steps
        if limit and steps >= limit:
            break
    if not epoch_metrics:
        return {}
    summary = {k: np.mean([m[k] for m in epoch_metrics]) for k in epoch_metrics[0]}
    return summary



## Train for a few epochs
For demonstration we cap the number of batches per epoch. Remove `max_train_steps_per_epoch`/`max_val_steps` to train on the full dataset.


In [50]:

train_history = []
val_history = []
for epoch in range(1, CONFIG.num_epochs + 1):
    train_metrics = run_epoch(train_loader, train=True)
    val_metrics = run_epoch(val_loader, train=False)
    train_history.append(train_metrics)
    val_history.append(val_metrics)
    print(f"Epoch {epoch}: train={train_metrics} | val={val_metrics}")


                                                                                                                                             

Epoch 1: train={'time_loss': 0.25353912541404705, 'pit_loss': 0.1262992090802496, 'compound_loss': 1.5984075872185535, 'dnf_loss': 0.0015930330103201872, 'total': 1.1162111235404373} | val={'time_loss': 0.2703601635041291, 'pit_loss': 0.15496818769392312, 'compound_loss': 1.3027927010859315, 'dnf_loss': 0.0, 'total': 0.9992406088858843}





## Evaluation stub
After training longer, evaluate on the held-out sessions and compute metrics such as lap-time MAE, pit F1, or end-of-race ranking accuracy.


In [54]:

def evaluate_samples(loader, max_batches: int = 10):
    model.eval()
    all_time_errors = []
    with torch.no_grad():
        for batch_idx, batch in enumerate(loader):
            preds = model(batch)
            time_pred = preds['lap_time']
            time_tgt = batch['targets']['lap_time'].to(TORCH_DEVICE)
            alive_next = batch['alive_next'].to(TORCH_DEVICE)
            err = torch.abs(time_pred - time_tgt) * alive_next
            denom = max(alive_next.sum().item(), 1.0)
            all_time_errors.append(err.sum().item() / denom)
            if max_batches and batch_idx + 1 >= max_batches:
                break
    return float(np.mean(all_time_errors)) if all_time_errors else float('nan')

val_mae = evaluate_samples(val_loader, max_batches=CONFIG.max_val_steps)
print(f"Validation lap-time MAE (normalized units): {val_mae:.4f}")


Validation lap-time MAE (normalized units): 0.4934



## Save trained model weights
Run this cell after training to persist the model (plus optimizer, config, and vocab metadata) for later reuse.


In [57]:

from datetime import datetime

checkpoint_path = Path('models/lap_state_transformer.ckpt')
checkpoint = {
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'config': CONFIG.__dict__,
    'vocab': {
        'driver_to_idx': vocab.driver_to_idx,
        'team_to_idx': vocab.team_to_idx,
        'compound_to_idx': vocab.compound_to_idx,
        'track_to_idx': vocab.track_to_idx,
    },
    'metadata': {
        'timestamp': datetime.utcnow().isoformat() + 'Z',
        'train_history': train_history,
        'val_history': val_history,
    }
}
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(checkpoint, checkpoint_path)
print(f"Saved checkpoint to {checkpoint_path.resolve()} at {checkpoint['metadata']['timestamp']}")


Saved checkpoint to /Users/ekazuki/Documents/f1stuff/models/models/lap_state_transformer.ckpt at 2025-12-07T18:46:05.257151Z



## Generate a fictive race example
The helper below takes one of the tensorized sessions, feeds sliding windows through the model, and constructs a synthetic ranking using the predicted lap times, pit probabilities, and DNF odds. It is a teacher-forced rollout (using the session context for inputs) but the leaderboard is based entirely on model outputs, so it behaves like a toy simulator.



## Race simulation helper
This helper reuses the trained model to roll forward a window of laps, denormalizes the predicted lap times, and accumulates them into a synthetic leaderboard. The lap log records each driver's predicted lap time and running position for every generated lap.


In [58]:

inverse_driver_vocab = {idx: name for name, idx in vocab.driver_to_idx.items()}
inverse_track_vocab = {idx: name for name, idx in vocab.track_to_idx.items()}
inverse_compound_vocab = {idx: name for name, idx in vocab.compound_to_idx.items()}

def _session_window_batch(sess: SessionTensor, lap_start: int):
    seq_slice = slice(lap_start, lap_start + CONFIG.seq_len)
    return {
        'dynamic': torch.from_numpy(sess.dynamic_numeric[seq_slice]).unsqueeze(0),
        'global': torch.from_numpy(sess.global_numeric[seq_slice]).unsqueeze(0),
        'compound_tokens': torch.from_numpy(sess.compound_tokens[seq_slice]).unsqueeze(0),
        'rank_tokens': torch.from_numpy(sess.rank_tokens[seq_slice]).unsqueeze(0),
        'alive_curr': torch.from_numpy(sess.alive_mask[seq_slice]).unsqueeze(0),
        'alive_next': torch.from_numpy(sess.alive_mask[seq_slice]).unsqueeze(0),
        'driver_tokens': torch.from_numpy(sess.driver_tokens).unsqueeze(0),
        'team_tokens': torch.from_numpy(sess.team_tokens).unsqueeze(0),
        'static_numeric': torch.from_numpy(sess.static_numeric).unsqueeze(0),
        'track_token': torch.tensor([sess.track_token]),
    }


def _track_stats(track_token: int):
    track_name = inverse_track_vocab.get(int(track_token), None)
    if track_name and track_name in track_scalers:
        stats = track_scalers[track_name]['lap_time_s']
    else:
        fallback = next(iter(track_scalers.keys()))
        track_name = track_name or fallback
        stats = track_scalers[fallback]['lap_time_s']
    return track_name, stats


@torch.no_grad()
def simulate_fictive_race(sess: SessionTensor, laps_to_generate: int = 10, start_lap: int = 0):
    model.eval()
    track_name, (lap_mean, lap_std) = _track_stats(sess.track_token)
    driver_names = [inverse_driver_vocab.get(int(tok), f'Driver_{idx:02d}') for idx, tok in enumerate(sess.driver_tokens) if tok != 0]
    cumulative_times = {name: 0.0 for name in driver_names}
    running_mask = {name: True for name in driver_names}
    lap_rows = []

    for step in range(laps_to_generate):
        lap_start = start_lap + step
        next_lap_idx = lap_start + CONFIG.seq_len
        if next_lap_idx >= sess.dynamic_numeric.shape[0]:
            break
        batch = _session_window_batch(sess, lap_start)
        preds = model(batch)
        lap_pred = preds['lap_time'][0, -1].cpu().numpy()
        pit_probs = torch.sigmoid(preds['pit'][0, -1]).cpu().numpy()
        dnf_probs = torch.sigmoid(preds['dnf'][0, -1]).cpu().numpy()
        compound_logits = torch.softmax(preds['compound'][0, -1], dim=-1).cpu().numpy()
        lap_predictions = {}
        for slot, driver_token in enumerate(sess.driver_tokens):
            if driver_token == 0:
                continue
            driver_name = inverse_driver_vocab.get(int(driver_token), f'Driver_{slot:02d}')
            if not running_mask[driver_name]:
                continue
            lap_seconds = float(lap_pred[slot] * lap_std + lap_mean)
            cumulative_times[driver_name] += lap_seconds
            lap_predictions[driver_name] = {
                'pred_lap_time_s': lap_seconds,
                'pit_probability': float(pit_probs[slot]),
                'dnf_probability': float(dnf_probs[slot]),
                'suggested_compound': inverse_compound_vocab.get(int(np.argmax(compound_logits[slot])), 'UNKNOWN'),
            }
            if lap_predictions[driver_name]['dnf_probability'] > 0.8:
                running_mask[driver_name] = False
        ordering = sorted(cumulative_times.items(), key=lambda kv: kv[1])
        position_map = {name: idx + 1 for idx, (name, _) in enumerate(ordering)}
        for driver_name, info in lap_predictions.items():
            lap_rows.append({
                'lap': next_lap_idx,
                'driver': driver_name,
                'race_position': position_map[driver_name],
                'pred_lap_time_s': info['pred_lap_time_s'],
                'pit_probability': info['pit_probability'],
                'dnf_probability': info['dnf_probability'],
                'suggested_compound': info['suggested_compound'],
            })

    leaderboard = pd.DataFrame([
        {'driver': name, 'simulated_time_s': time, 'still_running': running_mask[name]}
        for name, time in cumulative_times.items()
    ]).sort_values('simulated_time_s').reset_index(drop=True)
    lap_log = pd.DataFrame(lap_rows)
    return track_name, leaderboard, lap_log


In [61]:
example_session = next(sess for sess in session_tensors if sess.year in CONFIG.val_years)
track_name, race_summary, lap_log = simulate_fictive_race(example_session, laps_to_generate=30, start_lap=0)
print(f"Fictive race simulation on {track_name}")
display(race_summary.head(20))
display(lap_log[['lap', 'driver', 'race_position', 'pred_lap_time_s']].head(60))

Fictive race simulation on yas_island


Unnamed: 0,driver,simulated_time_s,still_running
0,GAS,2806.35143,True
1,SAI,2806.35143,True
2,SAR,2806.351431,True
3,RIC,2806.351431,True
4,HUL,2806.351431,True
5,LEC,2806.351431,True
6,OCO,2806.351431,True
7,PIA,2806.351431,True
8,VER,2806.351431,True
9,ZHO,2806.351431,True


Unnamed: 0,lap,driver,race_position,pred_lap_time_s
0,8,ALB,1,93.545048
1,8,ALO,2,93.545048
2,8,BOT,3,93.545048
3,8,GAS,4,93.545048
4,8,HAM,5,93.545048
5,8,HUL,6,93.545048
6,8,LEC,7,93.545048
7,8,MAG,8,93.545048
8,8,NOR,9,93.545048
9,8,OCO,10,93.545048
