In [162]:
%pip install matplotlib pandas numpy torch torchvision torchaudio scikit-learn


Note: you may need to restart the kernel to use updated packages.


In [163]:

import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '0'

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple


In [174]:

DEVICE = torch.device('cpu')
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    x = torch.ones(1, device=DEVICE)
    print(x)
else:
    print("MPS device not found.")

@dataclass
class TrainerConfig:
    dataset_path: Path = Path('fastf1_lap_dataset.csv')
    max_drivers: int = 20
    min_laps_per_session: int = 5
    train_years: Tuple[int, ...] = (2018, 2019, 2020, 2021)
    val_years: Tuple[int, ...] = (2023,)
    test_years: Tuple[int, ...] = (2024, 2025)
    batch_size: int = 2
    num_epochs: int = 50
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    max_train_steps_per_epoch: Optional[int] = None
    max_val_steps: Optional[int] = None
    seed: int = 42
    debug_sessions: Optional[int] = None

CONFIG = TrainerConfig()


tensor([1.], device='mps:0')


In [175]:
df.columns

Index(['driver_id', 'team_id', 'circuit_id', 'total_race_laps', 'year',
       'session_name', 'current_position', 'gap_to_leader_s', 'gap_to_ahead_s',
       'lap_time_s', 'laps_on_current_tyre', 'tyre_compound',
       'safety_car_this_lap', 'lap_number', 'drs_enabled', 'track_temperature',
       'air_temperature', 'has_rain', 'session_key', 'race_name', 'team_name',
       'virtual_sc_this_lap', 'humidity', 'pressure', 'rainfall', 'wind_speed',
       'wind_direction', 'circuit_tensor', 'driver_tensor', 'laptime_tensor'],
      dtype='object')

In [176]:

df = pd.read_csv(CONFIG.dataset_path)
df = df.drop("grid_position", axis=1)

df['current_position'] = df['current_position'].fillna(method='ffill')
df['lap_time_s'] = df['lap_time_s'].fillna(method='ffill').fillna(method='bfill')
df['tyre_compound'] = df['tyre_compound'].fillna(method='ffill')
df['laps_on_current_tyre'] = df['laps_on_current_tyre'].fillna(df['laps_on_current_tyre'].median())

print(df.isna().sum())
df[df.isna().any(axis=1)]


# Per-race lap-time normalization helper

def compute_race_stats(df):
    stats = {}
    for sess, g in df.groupby('session_key'):
        mean = g['lap_time_s'].mean()
        std = g['lap_time_s'].std()
        if std is None or std <= 1e-6:
            std = 1.0
        stats[sess] = (float(mean), float(std))
    global_mean = df['lap_time_s'].mean()
    global_std = df['lap_time_s'].std() or 1.0
    return stats, (float(global_mean), float(global_std))

def zscore_lap(time_s: float, mean: float, std: float) -> float:
    return float((time_s - mean) / (std if std > 1e-6 else 1.0))


# compute per-session lap time stats
race_stats, global_stats = compute_race_stats(df)


driver_id               0
team_id                 0
circuit_id              0
total_race_laps         0
year                    0
session_name            0
current_position        0
gap_to_leader_s         0
gap_to_ahead_s          0
lap_time_s              0
laps_on_current_tyre    0
tyre_compound           0
safety_car_this_lap     0
lap_number              0
drs_enabled             0
track_temperature       0
air_temperature         0
has_rain                0
session_key             0
race_name               0
team_name               0
virtual_sc_this_lap     0
humidity                0
pressure                0
rainfall                0
wind_speed              0
wind_direction          0
dtype: int64


  df['current_position'] = df['current_position'].fillna(method='ffill')
  df['lap_time_s'] = df['lap_time_s'].fillna(method='ffill').fillna(method='bfill')
  df['tyre_compound'] = df['tyre_compound'].fillna(method='ffill')


# Create tensors

In [177]:
# Circuit one-hot encoding

unique_circuits = df['circuit_id'].unique().tolist()
circ2idx = {c: i for i, c in enumerate(unique_circuits)}
idx2circ = {i: c for c, i in circ2idx.items()}

circ_idx_full = torch.tensor(
    df['circuit_id'].map(circ2idx).values,
    dtype=torch.long
)
df['circuit_tensor'] = list(torch.nn.functional.one_hot(
    circ_idx_full,
    num_classes=len(unique_circuits)
).float())

# Driver one-hot encoding

unique_drivers = df['driver_id'].unique().tolist()
drv2idx = {d: i for i, d in enumerate(unique_drivers)}
idx2drv = {i: d for d, i in drv2idx.items()}

driv_idx_full = torch.tensor(
    df['driver_id'].map(drv2idx).values,
    dtype=torch.long
)
df['driver_tensor'] = list(torch.nn.functional.one_hot(
    driv_idx_full,
    num_classes=len(unique_drivers)
).float())


# Lap time encoding

df['laptime_tensor'] = df['lap_time_s'].map(
    lambda x: torch.from_numpy(lt_embed(x)).float()
)

print(len(unique_circuits), len(unique_drivers))

tuple(df.groupby('session_key'))


35 43


(('2018_abu_dhabi_grand_prix_race',
      driver_id  team_id  circuit_id  total_race_laps  year session_name  \
  0         ALO  mclaren  yas_marina               55  2018         Race   
  1         ALO  mclaren  yas_marina               55  2018         Race   
  2         ALO  mclaren  yas_marina               55  2018         Race   
  3         ALO  mclaren  yas_marina               55  2018         Race   
  4         ALO  mclaren  yas_marina               55  2018         Race   
  ..        ...      ...         ...              ...   ...          ...   
  938       VET  ferrari  yas_marina               55  2018         Race   
  939       VET  ferrari  yas_marina               55  2018         Race   
  940       VET  ferrari  yas_marina               55  2018         Race   
  941       VET  ferrari  yas_marina               55  2018         Race   
  942       VET  ferrari  yas_marina               55  2018         Race   
  
       current_position  gap_to_leader_s  gap_to_

In [178]:
# Dimensions from encoded features
circuit_dim = df['circuit_tensor'].iloc[0].shape[0]
driver_dim  = df['driver_tensor'].iloc[0].shape[0]
lap_dim     = df['laptime_tensor'].iloc[0].shape[0]

# How many drivers per session? (should be 20 everywhere in F1)
drivers_per_session = df.groupby('session_key')['driver_id'].nunique()
print("Drivers per session value counts:")
print(drivers_per_session.value_counts())

# We assume all sessions have the same number of drivers
num_drivers_used = int(drivers_per_session.max())
#assert drivers_per_session.min() == num_drivers_used, (
#    f"Not all sessions have the same number of drivers: "
#    f"min={drivers_per_session.min()}, max={num_drivers_used}"
#)

# +1 feature for safety_car_this_lap of the TARGET lap
safety_car_dim = 1

input_dim  = circuit_dim + safety_car_dim + num_drivers_used * (driver_dim + lap_dim)
output_dim = num_drivers_used * lap_dim

print("num_drivers_used:", num_drivers_used)
print("input_dim:", input_dim, "output_dim:", output_dim)

Drivers per session value counts:
driver_id
20    160
19      8
Name: count, dtype: int64
num_drivers_used: 20
input_dim: 1536 output_dim: 640


# Create Dataset

In [179]:
import torch
import numpy as np

# ------------------------------------------------------------------
# 1) Baseline lap time tables (computed once)
# ------------------------------------------------------------------

driver_circuit_mean = (
    df.groupby(['circuit_id', 'driver_id'])['lap_time_s']
      .mean()
      .to_dict()
)

circuit_mean = (
    df.groupby('circuit_id')['lap_time_s']
      .mean()
      .to_dict()
)

global_mean = float(df['lap_time_s'].mean())


def get_baseline_laptime(circuit_id, driver_id):
    """
    Returns a reasonable baseline lap time in seconds
    for a given circuit and driver.
    """
    key = (circuit_id, driver_id)
    if key in driver_circuit_mean:
        return float(driver_circuit_mean[key])
    if circuit_id in circuit_mean:
        return float(circuit_mean[circuit_id])
    return global_mean


# ------------------------------------------------------------------
# 2) Precompute driver one-hot and baseline embeddings
# ------------------------------------------------------------------

# driver_id -> driver_tensor (take the first occurrence)
driver_tensor_by_id = (
    df.groupby('driver_id')['driver_tensor']
      .first()
      .to_dict()
)

# cache for baseline embeddings: (circuit_id, driver_id) -> torch.Tensor
baseline_embed_cache = {}

def get_baseline_embed(circuit_id, driver_id):
    key = (circuit_id, driver_id)
    if key not in baseline_embed_cache:
        base_time = get_baseline_laptime(circuit_id, driver_id)
        baseline_embed_cache[key] = torch.from_numpy(lt_embed(base_time)).float()
    return baseline_embed_cache[key]


# ------------------------------------------------------------------
# 3) Build input/target for a single step (lap t -> lap t+1)
# ------------------------------------------------------------------

def build_lap_input(group_lap_prev,
                    group_lap_target,
                    drivers_in_race_sess,
                    circuit_dim,
                    driver_dim,
                    lap_dim):
    """
    Build x_t for the pair (lap_t, lap_{t+1}).

    - group_lap_prev   is df slice for lap t
    - group_lap_target is df slice for lap t+1
    - Input includes:
        * circuit vector of lap t
        * safety_car_this_lap of lap t+1
        * for each driver: driver one hot and lap_t embedding
    """
    circuit_vec = group_lap_prev['circuit_tensor'].iloc[0]
    circuit_id  = group_lap_prev['circuit_id'].iloc[0]

    sc_flag_target = float(group_lap_target['safety_car_this_lap'].iloc[0])
    sc_tensor = torch.tensor([sc_flag_target], dtype=torch.float32)

    # index previous lap by driver_id once
    group_prev_idx = group_lap_prev.set_index('driver_id')

    driver_chunks = []

    for d in drivers_in_race_sess:
        if d in group_prev_idx.index:
            row_prev = group_prev_idx.loc[d]
            drv_one_hot = row_prev['driver_tensor']
            lap_embed   = row_prev['laptime_tensor']   # lap t embedding
        else:
            # e.g. driver missing on lap t, use baseline
            drv_one_hot = driver_tensor_by_id.get(d, torch.zeros(driver_dim))
            lap_embed   = get_baseline_embed(circuit_id, d)

        driver_chunks.append(torch.cat([drv_one_hot, lap_embed], dim=0))

    driver_part = torch.cat(driver_chunks, dim=0)
    # [circuit_vec, safety_car_flag_for_target_lap, per-driver blocks]
    x_t = torch.cat([circuit_vec, sc_tensor, driver_part], dim=0)
    return x_t


def build_lap_target(group_lap_next, drivers_in_race_sess, lap_dim):
    """
    Target is lap_{t+1} embeddings for the same session drivers.
    """
    circuit_id = group_lap_next['circuit_id'].iloc[0]
    group_idx = group_lap_next.set_index('driver_id')

    target_chunks = []

    for d in drivers_in_race_sess:
        if d in group_idx.index:
            row = group_idx.loc[d]
            lap_embed = row['laptime_tensor']          # lap t+1 embedding
        else:
            lap_embed = get_baseline_embed(circuit_id, d)

        target_chunks.append(lap_embed)

    y_t1 = torch.cat(target_chunks, dim=0)
    return y_t1


# ------------------------------------------------------------------
# 4) Build all_X, all_Y over all sessions
# ------------------------------------------------------------------

all_X = []
all_Y = []

# for testing: store first and last session driver lists
session_keys_sorted = sorted(df['session_key'].unique())
first_session_key = session_keys_sorted[0]
last_session_key  = session_keys_sorted[-1]

drivers_first_session = None
drivers_last_session  = None

for session_key, df_sess in df.groupby('session_key'):
    df_sess = df_sess.sort_values(['lap_number', 'driver_id'])

    drivers_in_race_sess = sorted(df_sess['driver_id'].unique())
    if len(drivers_in_race_sess) != num_drivers_used:
        print(f"Skipping session {session_key}: "
              f"{len(drivers_in_race_sess)} drivers (expected {num_drivers_used})")
        continue

    if session_key == first_session_key:
        drivers_first_session = drivers_in_race_sess
    if session_key == last_session_key:
        drivers_last_session = drivers_in_race_sess

    laps = sorted(df_sess['lap_number'].unique())
    if len(laps) < 2:
        continue

    # pre-group by lap once for speed
    lap_groups = {lap: g for lap, g in df_sess.groupby('lap_number')}

    seq_X = []
    seq_Y = []

    # build (x_t, y_{t+1}) pairs
    for i in range(len(laps) - 1):
        lap_t   = laps[i]
        lap_tp1 = laps[i + 1]

        group_lap_t   = lap_groups[lap_t]
        group_lap_tp1 = lap_groups[lap_tp1]

        x_t  = build_lap_input(
            group_lap_t,
            group_lap_tp1,
            drivers_in_race_sess,
            circuit_dim,
            driver_dim,
            lap_dim
        )
        y_tp1 = build_lap_target(
            group_lap_tp1,
            drivers_in_race_sess,
            lap_dim
        )

        seq_X.append(x_t)
        seq_Y.append(y_tp1)

    X_seq = torch.stack(seq_X, dim=0)   # [T, input_dim]
    Y_seq = torch.stack(seq_Y, dim=0)   # [T, output_dim]

    all_X.append(X_seq)
    all_Y.append(Y_seq)

print("Example shapes from first built sequence:")
print(all_X[0].shape, all_Y[0].shape)

# ------------------------------------------------------------------
# 5) Test: show drivers for first and last race
# ------------------------------------------------------------------

print("\nFirst session key:", first_session_key)
print("Drivers in first race (alphabetical):")
print(drivers_first_session)

print("\nLast session key:", last_session_key)
print("Drivers in last race (alphabetical):")
print(drivers_last_session)


Skipping session 2021_abu_dhabi_grand_prix_race: 19 drivers (expected 20)
Skipping session 2023_azerbaijan_grand_prix_sprint: 19 drivers (expected 20)
Skipping session 2023_qatar_grand_prix_race: 19 drivers (expected 20)
Skipping session 2023_singapore_grand_prix_race: 19 drivers (expected 20)
Skipping session 2024_australian_grand_prix_race: 19 drivers (expected 20)
Skipping session 2024_são_paulo_grand_prix_race: 19 drivers (expected 20)
Skipping session 2025_miami_grand_prix_sprint: 19 drivers (expected 20)
Skipping session 2025_spanish_grand_prix_race: 19 drivers (expected 20)
Example shapes from first built sequence:
torch.Size([54, 1536]) torch.Size([54, 640])

First session key: 2018_abu_dhabi_grand_prix_race
Drivers in first race (alphabetical):
['ALO', 'BOT', 'ERI', 'GAS', 'GRO', 'HAM', 'HAR', 'HUL', 'LEC', 'MAG', 'OCO', 'PER', 'RAI', 'RIC', 'SAI', 'SIR', 'STR', 'VAN', 'VER', 'VET']

Last session key: 2025_united_states_grand_prix_sprint
Drivers in last race (alphabetical):
['

In [180]:
class LapTimeLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, dropout=0.0):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout
        )
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # x: [batch_size, T, input_dim]
        out, _ = self.lstm(x)     # out: [batch_size, T, hidden_dim]
        out = self.fc(out)        # out: [batch_size, T, output_dim]
        return out

In [187]:
import torch
import torch.nn as nn
import random

hidden_dim = 1200

model = LapTimeLSTM(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
    num_layers=2,
    dropout=0.2
).to(DEVICE)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
)

# how much more important SC steps are compared to normal ones
SC_WEIGHT = 5.0  # you can tune this: 3, 5, 10, ...

# ----------------------------------------------------
# Train / test split at session level
# ----------------------------------------------------
num_sessions = len(all_X)
train_ratio = 0.8
train_size = int(num_sessions * train_ratio)

indices = list(range(num_sessions))
random.seed(42)          # for reproducibility
random.shuffle(indices)

train_indices = indices[:train_size]
test_indices  = indices[train_size:]

train_X = [all_X[i] for i in train_indices]
train_Y = [all_Y[i] for i in train_indices]
test_X  = [all_X[i] for i in test_indices]
test_Y  = [all_Y[i] for i in test_indices]

print(f"Num sessions: {num_sessions}, train: {len(train_X)}, test: {len(test_X)}")

for epoch in range(CONFIG.num_epochs):
    # ============================
    # TRAIN
    # ============================
    model.train()
    total_loss = 0.0
    total_steps = 0

    total_sc_mse = torch.tensor(0.0, device=DEVICE)
    total_sc_steps = torch.tensor(0.0, device=DEVICE)

    for X_seq, Y_seq in zip(train_X, train_Y):
        X_batch = X_seq.unsqueeze(0).to(DEVICE)     # [1, T, input_dim]
        Y_batch = Y_seq.unsqueeze(0).to(DEVICE)     # [1, T, output_dim]

        optimizer.zero_grad()

        y_hat = model(X_batch)                      # [B, T, output_dim]
        batch_size, T, _ = X_batch.shape

        # per step MSE: [B, T]
        per_step_mse = ((y_hat - Y_batch) ** 2).mean(dim=-1)

        # SC flag at index circuit_dim in input
        sc_flags = X_batch[..., circuit_dim]        # [B, T]
        sc_mask = sc_flags > 0.5                    # [B, T]

        # weights: SC_WEIGHT for SC steps, 1.0 otherwise
        weights = torch.ones_like(per_step_mse, device=DEVICE)
        weights[sc_mask] = SC_WEIGHT

        loss = (per_step_mse * weights).mean()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch_size
        total_steps += batch_size

        if sc_mask.any():
            total_sc_mse += per_step_mse[sc_mask].sum()
            total_sc_steps += sc_mask.sum()

    train_avg_loss = total_loss / max(total_steps, 1)
    if total_sc_steps.item() > 0:
        train_avg_sc_mse = (total_sc_mse / total_sc_steps).item()
        train_sc_steps_int = int(total_sc_steps.item())
    else:
        train_avg_sc_mse = float("nan")
        train_sc_steps_int = 0

    # ============================
    # TEST / VALIDATION
    # ============================
    model.eval()
    val_total_loss = 0.0
    val_total_steps = 0

    val_total_sc_mse = torch.tensor(0.0, device=DEVICE)
    val_total_sc_steps = torch.tensor(0.0, device=DEVICE)

    with torch.no_grad():
        for X_seq, Y_seq in zip(test_X, test_Y):
            X_batch = X_seq.unsqueeze(0).to(DEVICE)
            Y_batch = Y_seq.unsqueeze(0).to(DEVICE)

            y_hat = model(X_batch)
            batch_size, T, _ = X_batch.shape

            per_step_mse = ((y_hat - Y_batch) ** 2).mean(dim=-1)

            sc_flags = X_batch[..., circuit_dim]
            sc_mask = sc_flags > 0.5

            weights = torch.ones_like(per_step_mse, device=DEVICE)
            weights[sc_mask] = SC_WEIGHT

            val_loss = (per_step_mse * weights).mean()

            val_total_loss += val_loss.item() * batch_size
            val_total_steps += batch_size

            if sc_mask.any():
                val_total_sc_mse += per_step_mse[sc_mask].sum()
                val_total_sc_steps += sc_mask.sum()

    val_avg_loss = val_total_loss / max(val_total_steps, 1)
    if val_total_sc_steps.item() > 0:
        val_avg_sc_mse = (val_total_sc_mse / val_total_sc_steps).item()
        val_sc_steps_int = int(val_total_sc_steps.item())
    else:
        val_avg_sc_mse = float("nan")
        val_sc_steps_int = 0

    print(
        f"Epoch {epoch+1:03d}  "
        f"train_loss={train_avg_loss:.6f}  "
        f"train_sc_mse={train_avg_sc_mse:.6f}  "
        f"train_sc_steps={train_sc_steps_int}  |  "
        f"val_loss={val_avg_loss:.6f}  "
        f"val_sc_mse={val_avg_sc_mse:.6f}  "
        f"val_sc_steps={val_sc_steps_int}"
    )


Num sessions: 160, train: 128, test: 32
Epoch 001  train_loss=0.088371  train_sc_mse=0.101757  train_sc_steps=164  |  val_loss=0.078278  val_sc_mse=0.100909  val_sc_steps=29
Epoch 002  train_loss=0.077982  train_sc_mse=0.095171  train_sc_steps=164  |  val_loss=0.075400  val_sc_mse=0.095062  val_sc_steps=29
Epoch 003  train_loss=0.075135  train_sc_mse=0.091366  train_sc_steps=164  |  val_loss=0.074131  val_sc_mse=0.095995  val_sc_steps=29
Epoch 004  train_loss=0.073359  train_sc_mse=0.088175  train_sc_steps=164  |  val_loss=0.073469  val_sc_mse=0.098065  val_sc_steps=29
Epoch 005  train_loss=0.071977  train_sc_mse=0.085540  train_sc_steps=164  |  val_loss=0.072982  val_sc_mse=0.095691  val_sc_steps=29
Epoch 006  train_loss=0.070692  train_sc_mse=0.083294  train_sc_steps=164  |  val_loss=0.073885  val_sc_mse=0.097856  val_sc_steps=29
Epoch 007  train_loss=0.069813  train_sc_mse=0.082748  train_sc_steps=164  |  val_loss=0.072690  val_sc_mse=0.096772  val_sc_steps=29
Epoch 008  train_loss=

In [185]:
import torch

# We simulate one specific session
template_session = df['session_key'].unique()[0]
df_sess = df[df['session_key'] == template_session].sort_values(['lap_number', 'driver_id'])

laps_in_session = sorted(df_sess['lap_number'].unique())
T_sim = len(laps_in_session)

# Drivers in this race (alphabetical)
drivers_in_race = sorted(df_sess['driver_id'].unique())
num_drivers_used = len(drivers_in_race)

# Same dims as training (+1 for safety car flag)
safety_car_dim = 1
input_dim  = circuit_dim + safety_car_dim + num_drivers_used * (driver_dim + lap_dim)
output_dim = num_drivers_used * lap_dim


# 1) Safety car schedule: lap 24 - 28 inclusive -> SC = 1, else 0
#    sc_flags[lap_idx] corresponds to safety_car_this_lap for lap (lap_idx + 1)
sc_flags = [0.0] * T_sim
for lap in range(24, 29):  # 24,25,26,27,28
    if 1 <= lap <= T_sim:
        sc_flags[lap - 1] = 1.0


# 2) Build input for one step (lap_t as previous, SC flag for target lap)
def build_input_from_group(group_lap_prev, sc_flag_target):
    # move circuit tensor to DEVICE
    circuit_vec = group_lap_prev['circuit_tensor'].iloc[0].to(DEVICE)  # [circuit_dim]
    circuit_id  = group_lap_prev['circuit_id'].iloc[0]

    sc_tensor = torch.tensor([sc_flag_target], dtype=torch.float32, device=DEVICE)

    driver_chunks = []
    for d in drivers_in_race:
        row = group_lap_prev[group_lap_prev['driver_id'] == d]

        if len(row) == 0:
            # driver not present on this lap: use baseline
            row_any = df[df['driver_id'] == d]
            if len(row_any) == 0:
                drv_one_hot = torch.zeros(driver_dim, device=DEVICE)
            else:
                drv_one_hot = row_any['driver_tensor'].iloc[0].to(DEVICE)

            base_time = get_baseline_laptime(circuit_id, d)
            lap_embed = torch.from_numpy(lt_embed(base_time)).float().to(DEVICE)
        else:
            drv_one_hot = row['driver_tensor'].iloc[0].to(DEVICE)   # [driver_dim]
            lap_embed   = row['laptime_tensor'].iloc[0].to(DEVICE)  # [lap_dim]

        driver_chunks.append(torch.cat([drv_one_hot, lap_embed], dim=0))

    driver_part = torch.cat(driver_chunks, dim=0)  # [num_drivers * (driver_dim + lap_dim)]
    x_t = torch.cat([circuit_vec, sc_tensor, driver_part], dim=0)  # [input_dim]
    return x_t


def extract_driver_lap_embeds_from_x(x_vec):
    # skip circuit and safety car flag
    driver_part = x_vec[circuit_dim + safety_car_dim:]
    embeds = []

    for i in range(num_drivers_used):
        start = i * (driver_dim + lap_dim)
        drv_block = driver_part[start : start + driver_dim + lap_dim]
        lap_embed = drv_block[driver_dim:]
        embeds.append(lap_embed)

    return torch.stack(embeds, dim=0)  # [num_drivers, lap_dim]


def simulate_race(model, x_start, sc_flags, T_sim):
    """
    x_start encodes:
      - lap 1 times
      - safety car flag for lap 2

    sc_flags[lap_idx] is safety_car_this_lap for lap (lap_idx + 1).
    We predict laps 2..T_sim with the model.
    """
    model.eval()

    xs = [x_start.clone()]  # all tensors on DEVICE
    lap_embeds_list = [extract_driver_lap_embeds_from_x(x_start)]

    with torch.no_grad():
        for lap_idx in range(1, T_sim):
            # build input sequence so far
            X_seq = torch.stack(xs, dim=0).unsqueeze(0).to(DEVICE)  # [1, t, input_dim]
            y_hat_seq = model(X_seq)                                # [1, t, output_dim]
            y_next = y_hat_seq[0, -1]                               # [output_dim], on DEVICE

            lap_embeds_next = y_next.view(num_drivers_used, lap_dim)
            lap_embeds_list.append(lap_embeds_next)

            # prepare x_next for predicting the following lap
            x_prev = xs[-1]
            circuit_vec = x_prev[:circuit_dim]

            # SC flag for next target lap (lap_idx + 1 -> index lap_idx)
            next_lap_idx = lap_idx + 1
            if next_lap_idx < T_sim:
                sc_next = float(sc_flags[next_lap_idx])
            else:
                sc_next = 0.0
            sc_tensor = torch.tensor([sc_next], dtype=torch.float32, device=DEVICE)

            driver_blocks = []
            driver_part_prev = x_prev[circuit_dim + safety_car_dim:]
            for i in range(num_drivers_used):
                start = i * (driver_dim + lap_dim)
                drv_block_prev = driver_part_prev[start : start + driver_dim + lap_dim]
                drv_one_hot = drv_block_prev[:driver_dim]
                lap_embed_new = lap_embeds_next[i]
                driver_blocks.append(torch.cat([drv_one_hot, lap_embed_new], dim=0))

            new_driver_part = torch.cat(driver_blocks, dim=0)
            x_next = torch.cat([circuit_vec, sc_tensor, new_driver_part], dim=0)
            xs.append(x_next)

    lap_embeds_sim = torch.stack(lap_embeds_list, dim=0)  # [T_sim, num_drivers, lap_dim], on DEVICE
    return lap_embeds_sim


# 3) build x_start from real lap 1 and SC flag for lap 2
first_lap = laps_in_session[0]
group_first_lap = df_sess[df_sess['lap_number'] == first_lap]

sc_for_lap2 = sc_flags[1] if T_sim > 1 else 0.0
x_start = build_input_from_group(group_first_lap, sc_for_lap2)  # already on DEVICE

# 4) run simulation
model.to(DEVICE)
lap_embeds_sim = simulate_race(model, x_start, sc_flags, T_sim)

# 5) decode lap times (move to CPU only for decoding)
lap_times_sim = torch.zeros(T_sim, num_drivers_used, dtype=torch.float32)
for t in range(T_sim):
    for i in range(num_drivers_used):
        emb_np = lap_embeds_sim[t, i].detach().cpu().numpy()
        lap_times_sim[t, i] = float(lt_unbed(emb_np))

# 6) print lap by lap lap times
print("Lap by lap simulated lap times:")
for t in range(T_sim):
    print(f"\nLap {t + 1}: (SC={int(sc_flags[t])})")
    for i, driver_id in enumerate(drivers_in_race):
        time_sec = lap_times_sim[t, i].item()
        print(f"  Driver {driver_id}: {time_sec:.3f} s")

# 7) final classification
total_times = lap_times_sim.sum(dim=0)
order = torch.argsort(total_times)

print("\nSimulated race result (total time):")
for rank, idx in enumerate(order.tolist(), start=1):
    driver_id = drivers_in_race[idx]
    print(f"P{rank}: driver {driver_id}  total time {total_times[idx].item():.3f} s")


Lap by lap simulated lap times:

Lap 1: (SC=0)
  Driver ALO: 108.800 s
  Driver BOT: 104.400 s
  Driver ERI: 104.500 s
  Driver GAS: 106.600 s
  Driver GRO: 106.200 s
  Driver HAM: 103.100 s
  Driver HAR: 102.600 s
  Driver HUL: 104.700 s
  Driver LEC: 104.700 s
  Driver MAG: 103.700 s
  Driver OCO: 103.400 s
  Driver PER: 104.400 s
  Driver RAI: 103.400 s
  Driver RIC: 115.300 s
  Driver SAI: 103.300 s
  Driver SIR: 103.100 s
  Driver STR: 103.800 s
  Driver VAN: 104.500 s
  Driver VER: 103.900 s
  Driver VET: 103.700 s

Lap 2: (SC=0)
  Driver ALO: 108.800 s
  Driver BOT: 104.400 s
  Driver ERI: 104.500 s
  Driver GAS: 106.600 s
  Driver GRO: 106.200 s
  Driver HAM: 103.100 s
  Driver HAR: 102.600 s
  Driver HUL: 104.700 s
  Driver LEC: 104.700 s
  Driver MAG: 103.700 s
  Driver OCO: 103.400 s
  Driver PER: 104.400 s
  Driver RAI: 103.400 s
  Driver RIC: 115.300 s
  Driver SAI: 103.300 s
  Driver SIR: 103.100 s
  Driver STR: 103.800 s
  Driver VAN: 104.500 s
  Driver VER: 103.900 s
  

In [128]:
print(df['safety_car_this_lap'].value_counts())
print(df.groupby('safety_car_this_lap')['lap_time_s'].describe())


safety_car_this_lap
False    166412
True       4570
Name: count, dtype: int64
                        count        mean        std     min       25%  \
safety_car_this_lap                                                      
False                166412.0   90.655723  14.196700  55.404  80.46900   
True                   4570.0  106.369652  18.168979  56.926  93.44175   

                         50%        75%      max  
safety_car_this_lap                               
False                 89.395   99.34900  149.985  
True                 105.595  119.29875  149.897  
