# Phase 1

In [7]:
import pandas as pd
from datetime import datetime, timedelta
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, Dataset
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass, field
from sklearn.preprocessing import StandardScaler
import os

# Define the RaceFeatures dataclass
@dataclass
class RaceFeatures:
    """Data structure for race features"""
    static_features: List[str] = field(default_factory=lambda: [
        'driver_overall_skill', 'driver_circuit_skill', 'driver_consistency',
        'driver_reliability', 'driver_aggression', 'driver_risk_taking',
        'fp1_median_time', 'fp2_median_time', 'fp3_median_time', 'quali_time'
    ])
    
    dynamic_features: List[str] = field(default_factory=lambda: [
        'tire_age', 'fuel_load', 'track_position', 'track_temp',
        'air_temp', 'humidity', 'tire_compound', 'TrackStatus', 'is_pit_lap'
    ])
    
    target: str = 'milliseconds'

# Define the F1Dataset class
class F1Dataset(Dataset):
    def __init__(self, sequences, static_features, targets):
        self.sequences = torch.FloatTensor(sequences)
        self.static_features = torch.FloatTensor(static_features)
        self.targets = torch.FloatTensor(targets)
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        return {
            'sequence': self.sequences[idx],
            'static': self.static_features[idx],
            'target': self.targets[idx]
        }

# Define the F1DataPreprocessor class
class F1DataPreprocessor:
    def __init__(self):
        self.static_scaler = StandardScaler()
        self.dynamic_scaler = StandardScaler()
        self.lap_time_scaler = StandardScaler()
        
    def prepare_sequence_data(self, df: pd.DataFrame, window_size: int = 3) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Prepare sequential data with sliding window and apply scaling
        """
        sequences = []
        static_features = []
        targets = []
        
        # Instantiate RaceFeatures
        race_features = RaceFeatures()
        
        # Sort the dataframe to ensure consistent ordering
        df = df.sort_values(['raceId', 'driverId', 'lap'])
        
        # Group by race and driver
        for (race_id, driver_id), group in df.groupby(['raceId', 'driverId']):
            group = group.sort_values('lap')
            
            # Extract static features (assumed to be constant per driver per race)
            static = group[race_features.static_features].iloc[0].values
            static_features.append(static)
            
            # Extract dynamic features and target
            lap_times = group[race_features.target].values.reshape(-1, 1)  # Shape: (num_laps, 1)
            dynamic = group[race_features.dynamic_features].values  # Shape: (num_laps, num_dynamic_features)
            
            # Apply scaling
            # Note: Scalers should be fitted on the training data to prevent data leakage.
            # Here, for simplicity, we're fitting on the entire dataset. For a real-world scenario,
            # consider splitting the data first before fitting the scalers.
            dynamic_features_to_scale = [col for col in race_features.dynamic_features if col != 'tire_compound']
            tire_compounds = dynamic[:, race_features.dynamic_features.index('tire_compound')].reshape(-1, 1)
            other_dynamic = dynamic[:, [race_features.dynamic_features.index(col) for col in dynamic_features_to_scale]]
            
            lap_times_scaled = self.lap_time_scaler.fit_transform(lap_times).flatten()
            other_dynamic_scaled = self.dynamic_scaler.fit_transform(other_dynamic)
            static_scaled = self.static_scaler.fit_transform(static.reshape(1, -1)).flatten()
            
            dynamic_scaled = np.hstack((tire_compounds, other_dynamic_scaled))
            
            # Create sequences
            # Create sequences
        for i in range(len(lap_times_scaled) - window_size):
            sequence_lap_times = lap_times_scaled[i:i+window_size].reshape(-1, 1)  # Shape: (window_size, 1)
            sequence_dynamic = dynamic_scaled[i:i+window_size]  # Shape: (window_size, num_dynamic_features)
            sequence = np.hstack((sequence_lap_times, sequence_dynamic))  # Shape: (window_size, 1 + num_dynamic_features)
            sequences.append(sequence)
            static_features.append(static_scaled)
            targets.append(lap_times_scaled[i + window_size])
        
        return (np.array(sequences), 
                np.array(static_features), 
                np.array(targets))

    
    def create_train_val_loaders(
        self, 
        sequences: np.ndarray, 
        static_features: np.ndarray, 
        targets: np.ndarray,
        batch_size: int = 32,
        val_split: float = 0.2
    ) -> Tuple[DataLoader, DataLoader]:
        """
        Create train and validation dataloaders with given split ratio
        """
        dataset = F1Dataset(sequences, static_features, targets)
        
        # Calculate lengths for split
        val_size = int(len(dataset) * val_split)
        train_size = len(dataset) - val_size
        
        # Split dataset
        train_dataset, val_dataset = random_split(
            dataset, 
            [train_size, val_size],
            generator=torch.Generator().manual_seed(42)
        )
        
        # Create dataloaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True
        )
        
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False
        )
        
        return train_loader, val_loader

class F1PredictionModel(nn.Module):
    def __init__(self, 
                 sequence_dim: int,
                 static_dim: int,
                 hidden_dim: int = 64,
                 num_layers: int = 2,
                 dropout_prob: float = 0.5):
        super().__init__()
        
        # LSTM for sequential features with dropout
        self.lstm = nn.LSTM(
            input_size=sequence_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout_prob
        )
        
        # Static features processing with dropout
        self.static_network = nn.Sequential(
            nn.Linear(static_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob)
        )
        
        # Combine everything
        self.final_network = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, sequence, static):
        # Process sequence through LSTM
        lstm_out, _ = self.lstm(sequence)
        lstm_out = lstm_out[:, -1, :]  # Output of the last time step
        
        # Process static features
        static_out = self.static_network(static)
        
        # Combine LSTM output and static features
        combined = torch.cat([lstm_out, static_out], dim=1)
        
        # Final prediction
        prediction = self.final_network(combined)
        
        return prediction.squeeze()


# Define the training function
def train_model(model: nn.Module, 
                train_loader: DataLoader,
                val_loader: DataLoader,
                epochs: int = 10,
                learning_rate: float = 0.001,
                lap_time_scaler: StandardScaler = None,  # Pass the lap time scaler
                device: Optional[str] = None) -> Dict[str, List[float]]:
    """
    Train the model and return training history including MAE in milliseconds
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()
    history = {'train_loss': [], 'val_loss': [], 'train_mae': [], 'val_mae': []}
    
    for epoch in range(epochs):
        # Training
        model.train()
        train_losses = []
        train_maes = []
        for batch in train_loader:
            sequences = batch['sequence'].to(device)
            static = batch['static'].to(device)
            targets = batch['target'].to(device)
            
            optimizer.zero_grad()
            predictions = model(sequences, static)
            loss = criterion(predictions, targets)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            
            # Calculate MAE in normalized scale
            mae = torch.mean(torch.abs(predictions - targets)).item()
            train_maes.append(mae)
        
        # Validation
        model.eval()
        val_losses = []
        val_maes = []
        with torch.no_grad():
            for batch in val_loader:
                sequences = batch['sequence'].to(device)
                static = batch['static'].to(device)
                targets = batch['target'].to(device)
                
                predictions = model(sequences, static)
                loss = criterion(predictions, targets)
                val_losses.append(loss.item())
                
                # Calculate MAE in normalized scale
                mae_normalized = torch.mean(torch.abs(predictions - targets)).item()
                val_maes.append(mae_normalized)
        
        # Record metrics
        train_loss = np.mean(train_losses)
        val_loss = np.mean(val_losses)
        train_mae_normalized = np.mean(train_maes)
        val_mae_normalized = np.mean(val_maes)
        
        # Convert MAE back to milliseconds using the inverse scaler
        if lap_time_scaler:
            train_mae_ms = lap_time_scaler.inverse_transform([[train_mae_normalized]])[0][0]
            val_mae_ms = lap_time_scaler.inverse_transform([[val_mae_normalized]])[0][0]
        else:
            train_mae_ms, val_mae_ms = train_mae_normalized, val_mae_normalized
        
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_mae'].append(train_mae_ms)
        history['val_mae'].append(val_mae_ms)
        
        print(f'Epoch {epoch+1}/{epochs}:')
        print(f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train MAE: {train_mae_ms:.2f} ms, Val MAE: {val_mae_ms:.2f} ms')
    
    return history

def predict_with_uncertainty(model, inputs, n_samples=100):
    model.train()  # Enable dropout layers
    predictions = []
    with torch.no_grad():
        for _ in range(n_samples):
            prediction = model(**inputs).cpu().numpy()
            predictions.append(prediction)
    predictions = np.array(predictions)
    mean_prediction = predictions.mean(axis=0)
    std_prediction = predictions.std(axis=0)
    return mean_prediction, std_prediction

# Define a function to save the model
def save_model_with_preprocessor(model, preprocessor, sequence_dim, static_dim, path: str):
    torch.save({
        'model_state_dict': model.state_dict(),
        'lap_time_scaler': preprocessor.lap_time_scaler,
        'dynamic_scaler': preprocessor.dynamic_scaler, 
        'static_scaler': preprocessor.static_scaler,
        'sequence_dim': sequence_dim,
        'static_dim': static_dim
    }, path)
    print(f"Model and preprocessor saved to {path}")


# Now, integrate your code snippets into data preprocessing
def load_and_preprocess_data() -> pd.DataFrame:
    """
    Load data from CSV files and preprocess it to create the enhanced_laps DataFrame.
    """
    # Load data
    na_values = ['\\N', 'NaN', '']
    lap_times = pd.read_csv('../../data/raw_data/lap_times.csv', na_values=na_values)
    drivers = pd.read_csv('../../data/raw_data/drivers.csv', na_values=na_values)
    races = pd.read_csv('../../data/raw_data/races.csv', na_values=na_values)
    circuits = pd.read_csv('../../data/raw_data/circuits.csv', na_values=na_values)
    pit_stops = pd.read_csv('../../data/raw_data/pit_stops.csv', na_values=na_values)
    pit_stops.rename(columns={'milliseconds' : 'pitstop_milliseconds'}, inplace=True)
    results = pd.read_csv('../../data/raw_data/results.csv', na_values=na_values)
    results.rename(columns={'milliseconds' : 'racetime_milliseconds'}, inplace=True)

    qualifying = pd.read_csv('../../data/raw_data/qualifying.csv', na_values=na_values)
    status = pd.read_csv('../../data/raw_data/status.csv', na_values=na_values)
    weather_data = pd.read_csv('../../data/raw_data/ff1_weather.csv', na_values=na_values)
    practice_sessions = pd.read_csv('../../data/raw_data/ff1_laps.csv', na_values=na_values)
    # Load the tire data
    tire_data = pd.read_csv('../../data/raw_data/ff1_laps.csv', na_values=na_values)

    # Convert date columns to datetime
    races['date'] = pd.to_datetime(races['date'])
    results['date'] = results['raceId'].map(races.set_index('raceId')['date'])
    lap_times['date'] = lap_times['raceId'].map(races.set_index('raceId')['date'])
    
    # Merge dataframes
    laps = lap_times.merge(drivers, on='driverId', how='left')
    print(laps.shape)
    laps = laps.merge(races, on='raceId', how='left', suffixes=('', '_race'))
    laps.rename(columns={'quali_time' : 'quali_date_time'}, inplace=True)
    print(laps.shape)
    laps = laps.merge(circuits, on='circuitId', how='left')
    print(laps.shape)
    laps = laps.merge(results[['raceId', 'driverId', 'positionOrder', 'grid', 'racetime_milliseconds', 'fastestLap', 'statusId']], on=['raceId', 'driverId'], how='left')
    print(laps.shape)
    laps = laps.merge(status, on='statusId', how='left')
    print(laps.shape)
    laps = laps.merge(pit_stops[['raceId', 'driverId', 'lap', 'pitstop_milliseconds']], on=['raceId', 'driverId', 'lap'], how='left')
    print(laps.shape)
    laps['pitstop_milliseconds'].fillna(0, inplace=True)  # Assuming 0 if no pit stop
    print(laps.shape)
    
    # Add weather information
    # Filter weather data to include only the Race session
    weather_data = weather_data[weather_data['SessionName'] == 'R']
    
    # Merge weather data with races to get raceId
    weather_data = weather_data.merge(
        races[['raceId', 'year', 'name']], 
        left_on=['EventName', 'Year'],
        right_on=['name', 'year'],
        how='left'
    )
    
    # Compute cumulative time from the start of the race for each driver
    laps.sort_values(['raceId', 'driverId', 'lap'], inplace=True)
    laps['cumulative_milliseconds'] = laps.groupby(['raceId', 'driverId'])['milliseconds'].cumsum()
    laps['seconds_from_start'] = laps['cumulative_milliseconds'] / 1000
    print(laps.shape)
    
    # Use 'Time' in weather_data as 'seconds_from_start'
    weather_data['seconds_from_start'] = weather_data['Time']
    
    # Standardize text data
    tire_data['Compound'] = tire_data['Compound'].str.upper()
    tire_data['EventName'] = tire_data['EventName'].str.strip().str.upper()
    races['name'] = races['name'].str.strip().str.upper()
    
    # Filter for race sessions only
    tire_data = tire_data[tire_data['SessionName'] == 'R']
    
    # Merge with races to get raceId
    tire_data = tire_data.merge(
        races[['raceId', 'year', 'name']],
        left_on=['Year', 'EventName'],
        right_on=['year', 'name'],
        how='left'
    )
    
    # Map driver codes to driverId
    tire_data['Driver'] = tire_data['Driver'].str.strip().str.upper()
    drivers['code'] = drivers['code'].str.strip().str.upper()
    driver_code_to_id = drivers.set_index('code')['driverId'].to_dict()
    tire_data['driverId'] = tire_data['Driver'].map(driver_code_to_id)
    
    # Rename 'LapNumber' to 'lap' and ensure integer type
    tire_data.rename(columns={'LapNumber': 'lap'}, inplace=True)
    tire_data['lap'] = tire_data['lap'].astype(int)
    laps['lap'] = laps['lap'].astype(int)
    
    # Create compound mapping (ordered from hardest to softest)
    compound_mapping = {
        'UNKNOWN': 0,
        'HARD': 1,
        'MEDIUM': 2,
        'SOFT': 3,
        'INTERMEDIATE': 4,
        'WET': 5
    }
    
    # Merge tire_data with laps
    laps = laps.merge(
        tire_data[['raceId', 'driverId', 'lap', 'Compound', 'TrackStatus']],
        on=['raceId', 'driverId', 'lap'],
        how='left'
    )

    
    # Handle missing compounds and apply numeric encoding
    laps['Compound'].fillna('UNKNOWN', inplace=True)
    laps['tire_compound'] = laps['Compound'].map(compound_mapping)
    
    # Drop the original Compound column if desired
    laps.drop('Compound', axis=1, inplace=True)
    
    # Standardize names
    practice_sessions['EventName'] = practice_sessions['EventName'].str.strip().str.upper()
    races['name'] = races['name'].str.strip().str.upper()
    
    # Merge practice_sessions with races to get raceId
    practice_sessions = practice_sessions.merge(
        races[['raceId', 'year', 'name']],
        left_on=['Year', 'EventName'],
        right_on=['year', 'name'],
        how='left'
    )
    
    # Map driver codes to driverId
    practice_sessions['Driver'] = practice_sessions['Driver'].str.strip().str.upper()
    drivers['code'] = drivers['code'].str.strip().str.upper()
    driver_code_to_id = drivers.set_index('code')['driverId'].to_dict()
    practice_sessions['driverId'] = practice_sessions['Driver'].map(driver_code_to_id)
    
    # Convert LapTime to milliseconds
    practice_sessions['LapTime_ms'] = practice_sessions['LapTime'].apply(lambda x: pd.to_timedelta(x).total_seconds() * 1000)
    
    # Calculate median lap times for each driver in each session
    session_medians = practice_sessions.groupby(['raceId', 'driverId', 'SessionName'])['LapTime_ms'].median().reset_index()
    
    # Pivot the data to have sessions as columns
    session_medians_pivot = session_medians.pivot_table(
        index=['raceId', 'driverId'],
        columns='SessionName',
        values='LapTime_ms'
    ).reset_index()
    
    # Rename columns for clarity
    session_medians_pivot.rename(columns={
        'FP1': 'fp1_median_time',
        'FP2': 'fp2_median_time',
        'FP3': 'fp3_median_time',
        'Q': 'quali_time'
    }, inplace=True)
    
    laps = laps.merge(
    session_medians_pivot,
    on=['raceId', 'driverId'],
    how='left'
    )
    
    # Fill missing practice times with global median or a placeholder value
    global_median_fp1 = laps['fp1_median_time'].median()
    laps['fp1_median_time'].fillna(global_median_fp1, inplace=True)
    
    # Repeat for other sessions
    global_median_fp2 = laps['fp2_median_time'].median()
    laps['fp2_median_time'].fillna(global_median_fp2, inplace=True)
    
    global_median_fp3 = laps['fp3_median_time'].median()
    laps['fp3_median_time'].fillna(global_median_fp3, inplace=True)
    
    global_median_quali = laps['quali_time'].median()
    laps['quali_time'].fillna(global_median_quali, inplace=True)

    
    # Create a binary indicator for pit stops
    laps['is_pit_lap'] = laps['pitstop_milliseconds'].apply(lambda x: 1 if x > 0 else 0)

    
    # Define a function to match weather data to laps
    def match_weather_to_lap(race_laps, race_weather):
        """
        For each lap, find the closest weather measurement in time
        """
        race_laps = race_laps.sort_values('seconds_from_start')
        race_weather = race_weather.sort_values('seconds_from_start')
        merged = pd.merge_asof(
            race_laps,
            race_weather,
            on='seconds_from_start',
            direction='nearest'
        )
        return merged

    # Apply matching per race
    matched_laps_list = []
    for race_id in laps['raceId'].unique():
        print(f'Matching for {race_id}')
        race_laps = laps[laps['raceId'] == race_id]
        race_weather = weather_data[weather_data['raceId'] == race_id]
        
        if not race_weather.empty:
            matched = match_weather_to_lap(race_laps, race_weather)
            print(f"Matched DataFrame shape: {matched.shape}")
            matched_laps_list.append(matched)
        else:
            matched_laps_list.append(race_laps)  # No weather data for this race

    # Concatenate all matched laps
    laps = pd.concat(matched_laps_list, ignore_index=True)
    print(laps.shape)
    
    # Fill missing weather data with default values
    laps['track_temp'] = laps['TrackTemp'].fillna(25.0)
    laps['air_temp'] = laps['AirTemp'].fillna(20.0)
    laps['humidity'] = laps['Humidity'].fillna(50.0)
    
    # Calculate driver aggression and skill
    # Create driver names
    drivers['driver_name'] = drivers['forename'] + ' ' + drivers['surname']
    driver_mapping = drivers[['driverId', 'driver_name']].copy()
    driver_mapping.set_index('driverId', inplace=True)
    driver_names = driver_mapping['driver_name'].to_dict()
    
    # Map statusId to status descriptions
    status_dict = status.set_index('statusId')['status'].to_dict()
    results['status'] = results['statusId'].map(status_dict)
    
    # Calculate driver aggression and skill
    def calculate_aggression(driver_results):
        if len(driver_results) == 0:
            return 0.5  # Default aggression for new drivers
        
        # Only consider recent races for more current behavior
        recent_results = driver_results.sort_values('date', ascending=False).head(20)
        
        # Calculate overtaking metrics
        positions_gained = recent_results['grid'] - recent_results['positionOrder']
        
        # Calculate risk metrics
        dnf_rate = (recent_results['status'] != 'Finished').mean()
        incidents = (recent_results['statusId'].isin([
            4,  # Collision
            5,  # Spun off
            6,  # Accident
            20, # Collision damage
            82, # Collision with another driver
        ])).mean()
        
        # Calculate overtaking success rate (normalized between 0-1)
        positive_overtakes = (positions_gained > 0).sum()
        negative_overtakes = (positions_gained < 0).sum()
        total_overtake_attempts = positive_overtakes + negative_overtakes
        overtake_success_rate = positive_overtakes / total_overtake_attempts if total_overtake_attempts > 0 else 0.5
        
        # Normalize average positions gained (0-1)
        avg_positions_gained = positions_gained[positions_gained > 0].mean() if len(positions_gained[positions_gained > 0]) > 0 else 0
        max_possible_gain = 20  # Maximum grid positions that could be gained
        normalized_gains = np.clip(avg_positions_gained / max_possible_gain, 0, 1)
        
        # Normalize risk factors (0-1)
        normalized_dnf = np.clip(dnf_rate, 0, 1)
        normalized_incidents = np.clip(incidents, 0, 1)
        
        # Calculate component scores (each between 0-1)
        overtaking_component = (normalized_gains * 0.6 + overtake_success_rate * 0.4)
        risk_component = (normalized_dnf * 0.5 + normalized_incidents * 0.5)
        
        # Combine components with weights (ensuring sum of weights = 1)
        weights = {
            'overtaking': 0.4,  # Aggressive overtaking
            'risk': 0.5,       # Risk-taking behavior
            'baseline': 0.1    # Baseline aggression
        }
        
        aggression = (
            overtaking_component * weights['overtaking'] +
            risk_component * weights['risk'] +
            0.5 * weights['baseline']  # Baseline aggression factor
        )
        
        # Add small random variation while maintaining 0-1 bounds
        variation = np.random.normal(0, 0.02)
        aggression = np.clip(aggression + variation, 0, 1)
        
        return aggression
    
    def calculate_skill(driver_data, results_data, circuit_id):
        driver_results = results_data[
            (results_data['driverId'] == driver_data['driverId']) & 
            (results_data['circuitId'] == circuit_id)
        ].sort_values('date', ascending=False).head(10)  # Use last 10 races at circuit
        
        if len(driver_results) == 0:
            return 0.5  # Default skill
        
        # Calculate performance metrics
        avg_finish_pos = driver_results['positionOrder'].mean()
        avg_quali_pos = driver_results['grid'].mean()
        points_per_race = driver_results['points'].mean()
        fastest_laps = (driver_results['rank'] == 1).mean()  # Add fastest lap consideration
        
        # Improved normalization (exponential decay for positions)
        normalized_finish_pos = np.exp(-avg_finish_pos/5) # Better spread of values
        normalized_quali_pos = np.exp(-avg_quali_pos/5)
        
        # Points normalization with improved scaling
        max_points_per_race = 26  # Maximum possible points (25 + 1 fastest lap)
        normalized_points = points_per_race / max_points_per_race
        
        # Weighted combination with more factors
        weights = {
            'finish': 0.35,
            'quali': 0.25,
            'points': 0.25,
            'fastest_laps': 0.15
        }
        
        skill = (
            weights['finish'] * normalized_finish_pos +
            weights['quali'] * normalized_quali_pos +
            weights['points'] * normalized_points +
            weights['fastest_laps'] * fastest_laps
        )
        
        # Add random variation to prevent identical skills
        skill = np.clip(skill + np.random.normal(0, 0.05), 0.1, 1.0)
        
        return skill
    
    # First merge results with races to get circuitId
    results = results.merge(
        races[['raceId', 'circuitId']], 
        on='raceId',
        how='left'
    )

    # Now calculate driver aggression and skill
    driver_aggression = {}
    driver_skill = {}
    for driver_id in drivers['driverId'].unique():
        driver_results = results[results['driverId'] == driver_id]
        aggression = calculate_aggression(driver_results)
        driver_aggression[driver_id] = aggression
        
        # Now we have circuit_id from the merge
        recent_race = driver_results.sort_values('date', ascending=False).head(1)
        if not recent_race.empty:
            circuit_id = recent_race['circuitId'].iloc[0]
            skill = calculate_skill({'driverId': driver_id}, results, circuit_id)
            driver_skill[driver_id] = skill
        else:
            driver_skill[driver_id] = 0.5  # Default skill for new drivers
    
    # Map calculated aggression and skill back to laps DataFrame
    laps['driver_aggression'] = laps['driverId'].map(driver_aggression)
    laps['driver_overall_skill'] = laps['driverId'].map(driver_skill)
    laps['driver_circuit_skill'] = laps['driver_overall_skill']  # For simplicity, using overall skill
    laps['driver_consistency'] = 0.5  # Placeholder
    laps['driver_reliability'] = 0.5  # Placeholder
    laps['driver_risk_taking'] = laps['driver_aggression']  # Assuming similar to aggression
    
    # Dynamic features
    laps['tire_age'] = laps.groupby(['raceId', 'driverId'])['lap'].cumcount()
    laps['fuel_load'] = laps.groupby(['raceId', 'driverId'])['lap'].transform(lambda x: x.max() - x + 1)
    laps['track_position'] = laps['position']  # Assuming 'position' is available in laps data
    
    # Ensure that all required columns are present
    # Create an instance of RaceFeatures
    race_features = RaceFeatures()

    
    laps['TrackStatus'].fillna(1, inplace=True)  # 1 = regular racing status
    
    # Ensure that all required columns are present
    required_columns = race_features.static_features + race_features.dynamic_features
    # Before dropping NaN values
    print("\nNaN counts in required columns:")
    for col in required_columns:
        nan_count = laps[col].isna().sum()
        total_rows = len(laps)
        if nan_count > 0:
            print(f"{col}: {nan_count} NaN values ({(nan_count/total_rows*100):.2f}% of rows)")
    missing_columns = set(required_columns) - set(laps.columns)
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")
    
    # Drop rows with missing values in required columns
    laps.to_csv('laps_withNan.csv', index=False)
    laps = laps.dropna(subset=required_columns)
    
    print(laps.shape)
    
    return laps

# Update the main function
def main():
    # Load and preprocess data
    enhanced_laps = load_and_preprocess_data()
    
    enhanced_laps.drop(columns=['position', 'time', 'driverRef', 'number', 'R', 'S', 'code', 'forename', 'surname', 'url_x', 'url_race', 'name_x', 'circuitRef', 'name_y', 'location', 'country', 'url_y', 'positionOrder', 'fastestLap', 'cumulative_milliseconds', 'seconds_from_start', 'raceId_x', 'year_x', 'Time', 'TrackTemp', 'AirTemp', 'Humidity', 'name', 'year_y', 'raceId_y'], inplace=True)
    
    # Save the preprocessed laps DataFrame for inspection
    enhanced_laps.to_csv('enhanced_laps_before_training.csv', index=False)
    print(enhanced_laps.shape)
    
    print("Enhanced laps DataFrame saved to 'enhanced_laps_before_training.csv'")
    
    preprocessor = F1DataPreprocessor()
    sequences, static, targets = preprocessor.prepare_sequence_data(enhanced_laps, window_size=3)
    
    # Create train and validation loaders
    train_loader, val_loader = preprocessor.create_train_val_loaders(
        sequences, 
        static, 
        targets,
        batch_size=32,
        val_split=0.2
    )
    
    # Initialize model
    model = F1PredictionModel(
        sequence_dim=sequences.shape[2],
        static_dim=static.shape[1]
    )
    
    # Train the model
    history = train_model(model, train_loader, val_loader, epochs=100, learning_rate=0.001)
    
    # Save the trained model
    save_model_with_preprocessor(model, preprocessor, sequences.shape[2], static.shape[1], 'f1_prediction_model.pth')

if __name__ == "__main__":
    main()

  practice_sessions = pd.read_csv('../../data/raw_data/ff1_laps.csv', na_values=na_values)
  tire_data = pd.read_csv('../../data/raw_data/ff1_laps.csv', na_values=na_values)


(586171, 15)
(586171, 32)
(586171, 40)
(586171, 45)
(586171, 46)
(586171, 47)
(586171, 47)
(586171, 49)


The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  laps['pitstop_milliseconds'].fillna(0, inplace=True)  # Assuming 0 if no pit stop
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  laps['Compound'].fillna('UNKNOWN', inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate 

Matching for 1
Matching for 2
Matching for 3
Matching for 4
Matching for 5
Matching for 6
Matching for 7
Matching for 8
Matching for 9
Matching for 10
Matching for 11
Matching for 12
Matching for 13
Matching for 14
Matching for 15
Matching for 16
Matching for 17
Matching for 18
Matching for 19
Matching for 20
Matching for 21
Matching for 22
Matching for 23
Matching for 24
Matching for 25
Matching for 26
Matching for 27
Matching for 28
Matching for 29
Matching for 30
Matching for 31
Matching for 32
Matching for 33
Matching for 34
Matching for 35
Matching for 36
Matching for 37
Matching for 38
Matching for 39
Matching for 40
Matching for 41
Matching for 42
Matching for 43
Matching for 44
Matching for 45
Matching for 46
Matching for 47
Matching for 48
Matching for 49
Matching for 50
Matching for 51
Matching for 52
Matching for 53
Matching for 54
Matching for 55
Matching for 56
Matching for 57
Matching for 58
Matching for 59
Matching for 60
Matching for 61
Matching for 62
Matching for 63
M

The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  laps['TrackStatus'].fillna(1, inplace=True)  # 1 = regular racing status



NaN counts in required columns:
tire_age: 133323 NaN values (22.74% of rows)
fuel_load: 133323 NaN values (22.74% of rows)
tire_compound: 11101 NaN values (1.89% of rows)
(451356, 88)
(451356, 58)
Enhanced laps DataFrame saved to 'enhanced_laps_before_training.csv'
Epoch 1/100:
Train Loss: 31112518.0000, Val Loss: 340683.9375, Train MAE: 4373.87 ms, Val MAE: 583.68 ms
Epoch 2/100:
Train Loss: 31255954.0000, Val Loss: 989.1229, Train MAE: 4651.10 ms, Val MAE: 31.44 ms
Epoch 3/100:
Train Loss: 31863176.0000, Val Loss: 214527.6562, Train MAE: 4250.26 ms, Val MAE: 463.17 ms
Epoch 4/100:
Train Loss: 32338770.0000, Val Loss: 423561.0000, Train MAE: 4153.52 ms, Val MAE: 650.82 ms
Epoch 5/100:
Train Loss: 18023134.0000, Val Loss: 487934.0625, Train MAE: 3355.22 ms, Val MAE: 698.52 ms
Epoch 6/100:
Train Loss: 28255366.0000, Val Loss: 488731.0625, Train MAE: 3987.97 ms, Val MAE: 699.09 ms
Epoch 7/100:
Train Loss: 12530812.0000, Val Loss: 509798.0938, Train MAE: 3015.23 ms, Val MAE: 714.00 ms
Ep

# Phase 2

In [8]:
class RaceSimulator:
    def __init__(self, model, preprocessor):
        self.model = model
        self.preprocessor = preprocessor
        self.race_features = RaceFeatures()
        
    def simulate_lap(self, current_state):
        """
        Predict lap time and uncertainty for a single lap.
        """
        # Prepare input data
        sequence = current_state['sequence']
        static = current_state['static']
        
        # Convert to tensors
        sequence_tensor = torch.FloatTensor(sequence).unsqueeze(0)  # Add batch dimension
        static_tensor = torch.FloatTensor(static).unsqueeze(0)
        
        # Predict with uncertainty
        mean_pred, std_pred = predict_with_uncertainty(
            self.model, 
            {'sequence': sequence_tensor, 'static': static_tensor},
            n_samples=50
        )
        
        # Inverse transform the prediction
        lap_time = self.preprocessor.lap_time_scaler.inverse_transform([[mean_pred]])[0][0]
        uncertainty = self.preprocessor.lap_time_scaler.inverse_transform([[std_pred]])[0][0]
        
        return lap_time, uncertainty
    
    def simulate_full_race(self, initial_state, strategy):
        """
        Simulate the entire race lap by lap based on the provided strategy.
        """
        lap_times = []
        uncertainties = []
        current_state = initial_state.copy()
        
        total_laps = strategy.total_laps
        for lap in range(1, total_laps + 1):
            # Update dynamic features based on strategy and lap number
            current_state = self.update_state(current_state, lap, strategy)
            
            # Simulate lap
            lap_time, uncertainty = self.simulate_lap(current_state)
            lap_times.append(lap_time)
            uncertainties.append(uncertainty)
            
            # Update sequence data for the next lap
            current_state['sequence'] = self.update_sequence(
                current_state['sequence'], lap_time, current_state['dynamic']
            )
        
        return lap_times, uncertainties
    
    def update_state(self, state, lap, strategy):
        """
        Update the state for the next lap based on the strategy.
        """
        # Update tire age
        state['dynamic']['tire_age'] += 1
        
        # Check for pit stops
        if lap in strategy.pit_stop_laps:
            # Reset tire age and update tire compound
            state['dynamic']['tire_age'] = 0
            state['dynamic']['tire_compound'] = strategy.pit_stop_compounds[lap]
        
        # Update fuel load
        state['dynamic']['fuel_load'] -= strategy.fuel_consumption_per_lap
        
        # Update other dynamic features as needed
        # ...
        
        return state
    
    def update_sequence(self, sequence, new_lap_time, dynamic_features):
        """
        Update the sequence data with the latest lap information.
        """
        # Remove the oldest lap data
        sequence = sequence[1:]
        # Append the new lap data
        new_sequence_entry = np.hstack((
            [new_lap_time],
            [dynamic_features[feature] for feature in self.race_features.dynamic_features]
        ))
        sequence = np.vstack([sequence, new_sequence_entry])
        return sequence
    
    # Additional helper methods as needed


In [9]:
class RaceStrategy:
    def __init__(self, total_laps, pit_stops):
        """
        pit_stops: List of dictionaries with keys 'lap' and 'compound'
        Example: [{'lap': 15, 'compound': 'MEDIUM'}, {'lap': 30, 'compound': 'SOFT'}]
        """
        self.total_laps = total_laps
        self.pit_stops = pit_stops  # List of pit stop events
        self.pit_stop_laps = [stop['lap'] for stop in pit_stops]
        self.pit_stop_compounds = {stop['lap']: stop['compound'] for stop in pit_stops}
        self.fuel_consumption_per_lap = 1.5  # Example value, adjust as needed
    
    def evaluate(self, simulator, initial_state):
        """
        Simulate the race using this strategy and return total race time.
        """
        lap_times, uncertainties = simulator.simulate_full_race(initial_state, self)
        total_time = sum(lap_times)
        return total_time, lap_times, uncertainties


In [10]:
class RaceStrategyOptimizer:
    def __init__(self, simulator):
        self.simulator = simulator
    
    def optimize(self, initial_conditions, constraints):
        """
        Find the optimal strategy given initial conditions and constraints.
        """
        best_strategy = None
        best_time = float('inf')
        
        # Generate possible strategies within constraints
        possible_strategies = self.generate_strategies(constraints)
        
        # Evaluate each strategy
        for strategy in possible_strategies:
            total_time, _, _ = strategy.evaluate(self.simulator, initial_conditions)
            if total_time < best_time:
                best_time = total_time
                best_strategy = strategy
        
        return best_strategy
    
    def generate_strategies(self, constraints):
        """
        Generate possible strategies based on constraints.
        """
        min_pit_stops = constraints.get('min_pit_stops', 1)
        max_pit_stops = constraints.get('max_pit_stops', 3)
        available_compounds = constraints.get('available_compounds', ['SOFT', 'MEDIUM', 'HARD'])
        total_laps = constraints.get('total_laps', 50)
        
        strategies = []
        
        # Example: Generate strategies with different pit stop laps and compounds
        for num_pit_stops in range(min_pit_stops, max_pit_stops + 1):
            pit_stop_laps_options = self.get_pit_stop_lap_combinations(total_laps, num_pit_stops)
            for pit_stop_laps in pit_stop_laps_options:
                for compounds in self.get_compound_combinations(available_compounds, num_pit_stops):
                    pit_stops = [{'lap': lap, 'compound': comp} for lap, comp in zip(pit_stop_laps, compounds)]
                    strategy = RaceStrategy(total_laps, pit_stops)
                    strategies.append(strategy)
        return strategies
    
    def get_pit_stop_lap_combinations(self, total_laps, num_pit_stops):
        """
        Generate possible pit stop lap combinations.
        """
        from itertools import combinations
        lap_numbers = range(5, total_laps - 5)  # Avoid pitting too early or too late
        return combinations(lap_numbers, num_pit_stops)
    
    def get_compound_combinations(self, compounds, num_pit_stops):
        """
        Generate possible combinations of compounds for pit stops.
        """
        from itertools import product
        return product(compounds, repeat=num_pit_stops)
    
    def load_model_with_preprocessor(path: str, sequence_dim: int, static_dim: int, hidden_dim: int = 64, num_layers: int = 2):
        """
        Load the model and preprocessor from a saved file.
        """
        checkpoint = torch.load(path, map_location=torch.device('cpu'))
        
        # Recreate the model architecture
        model = F1PredictionModel(
            sequence_dim=sequence_dim,
            static_dim=static_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()  # Set model to evaluation mode
        
        # Recreate the preprocessor
        preprocessor = F1DataPreprocessor()
        preprocessor.lap_time_scaler = checkpoint['lap_time_scaler']
        preprocessor.dynamic_scaler = checkpoint['dynamic_scaler']
        preprocessor.static_scaler = checkpoint['static_scaler']
        
        print(f"Model and preprocessor loaded from {path}")
        return model, preprocessor


In [11]:
def load_model_with_preprocessor(path: str) -> Tuple[F1PredictionModel, F1DataPreprocessor]:
   """Load saved model and preprocessor"""
   checkpoint = torch.load(path, map_location=torch.device('cpu'))
   
   model = F1PredictionModel(
       sequence_dim=checkpoint['sequence_dim'], 
       static_dim=checkpoint['static_dim']
   )
   model.load_state_dict(checkpoint['model_state_dict'])
   
   preprocessor = F1DataPreprocessor()
   preprocessor.lap_time_scaler = checkpoint['lap_time_scaler']
   preprocessor.dynamic_scaler = checkpoint['dynamic_scaler']
   preprocessor.static_scaler = checkpoint['static_scaler']
   
   return model, preprocessor

In [14]:
# The model and preprocessor are now ready for simulation or further training
model, preprocessor = load_model_with_preprocessor('f1_prediction_model.pth')

# Assume 'model' and 'preprocessor' have been loaded or trained
simulator = RaceSimulator(model, preprocessor)

# Step 1: Define static features
static_features_list = [
    'driver_overall_skill', 'driver_circuit_skill', 'driver_consistency',
    'driver_reliability', 'driver_aggression', 'driver_risk_taking',
    'fp1_median_time', 'fp2_median_time', 'fp3_median_time', 'quali_time'
]

static_features = np.array([
    0.8,   # driver_overall_skill
    0.75,  # driver_circuit_skill
    0.7,   # driver_consistency
    0.9,   # driver_reliability
    0.6,   # driver_aggression
    0.5,   # driver_risk_taking
    88000, # fp1_median_time
    87500, # fp2_median_time
    87000, # fp3_median_time
    86000  # quali_time
])

# Step 2: Define initial dynamic features for previous laps
dynamic_features_list = [
    'tire_age', 'fuel_load', 'track_position', 'track_temp',
    'air_temp', 'humidity', 'TrackStatus', 'is_pit_lap'
]

# Common values
track_temp = 35.0
air_temp = 25.0
humidity = 50.0
TrackStatus = 1
is_pit_lap = 0

dynamic_features = np.array([
    [0, 100, 1, track_temp, air_temp, humidity, TrackStatus, is_pit_lap],  # Lap 1
    [1, 98.5, 1, track_temp, air_temp, humidity, TrackStatus, is_pit_lap],  # Lap 2
    [2, 97, 1, track_temp, air_temp, humidity, TrackStatus, is_pit_lap]     # Lap 3
])

# Tire compound (not scaled)
tire_compound = 0  # Example value for 'HARD' tires
tire_compound_column = np.full((dynamic_features.shape[0], 1), tire_compound)

# Lap times in milliseconds
lap_times = np.array([90000, 89500, 89200])  # Laps 1-3

# Step 3: Combine lap times, tire_compound, and dynamic features
sequence = np.hstack((
    lap_times.reshape(-1, 1),  # Lap times
    tire_compound_column,      # Tire compound
    dynamic_features           # Dynamic features
))

# Step 4: Scale the lap times
lap_times_scaled = preprocessor.lap_time_scaler.transform(lap_times.reshape(-1, 1)).flatten()

# Step 5: Scale the dynamic features (excluding 'tire_compound')
dynamic_features_to_scale = sequence[:, 2:]  # Exclude lap times and 'tire_compound'
dynamic_scaled = preprocessor.dynamic_scaler.transform(dynamic_features_to_scale)

# Step 6: Reconstruct the scaled sequence
sequence_scaled = np.hstack((
    lap_times_scaled.reshape(-1, 1),  # Scaled lap times
    tire_compound_column,             # Tire compound (not scaled)
    dynamic_scaled                    # Scaled dynamic features
))

# Step 7: Scale the static features
static_scaled = preprocessor.static_scaler.transform(static_features.reshape(1, -1)).flatten()

# Step 8: Prepare current dynamic features
current_dynamic_features = {
    'tire_age': 3,
    'fuel_load': 95.5,
    'track_position': 1,
    'track_temp': 35.0,
    'air_temp': 25.0,
    'humidity': 50.0,
    'TrackStatus': 1,
    'is_pit_lap': 0,
    'tire_compound': 0
}

# Extract and scale dynamic features (excluding 'tire_compound')
dynamic_feature_values = np.array([
    current_dynamic_features[feature] for feature in dynamic_features_list
])

dynamic_scaled_current = preprocessor.dynamic_scaler.transform(dynamic_feature_values.reshape(1, -1)).flatten()

# Step 9: Update the initial state
initial_state = {
    'sequence': sequence_scaled,   # Shape: (window_size, sequence_dim)
    'static': static_scaled,       # Shape: (static_dim,)
    'dynamic': current_dynamic_features
}


# Define constraints
constraints = {
    'min_pit_stops': 1,
    'max_pit_stops': 2,
    'available_compounds': [3, 2, 1],
    'total_laps': 50
}

# Create optimizer and find the best strategy
optimizer = RaceStrategyOptimizer(simulator)
best_strategy = optimizer.optimize(initial_state, constraints)

# Evaluate the best strategy
total_time, lap_times, uncertainties = best_strategy.evaluate(simulator, initial_state)
print(f"Optimal total race time: {total_time:.2f} seconds")


  checkpoint = torch.load(path, map_location=torch.device('cpu'))


KeyboardInterrupt: 

"""
# %% [markdown]
# # Phase 1

# %%
import pandas as pd
from datetime import datetime, timedelta
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, Dataset
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass, field
from sklearn.preprocessing import StandardScaler
import os

# Define the RaceFeatures dataclass
@dataclass
class RaceFeatures:
    """Data structure for race features"""
    static_features: List[str] = field(default_factory=lambda: [
        'driver_overall_skill', 'driver_circuit_skill', 'driver_consistency',
        'driver_reliability', 'driver_aggression', 'driver_risk_taking',
        'fp1_median_time', 'fp2_median_time', 'fp3_median_time', 'quali_time'
    ])
    
    dynamic_features: List[str] = field(default_factory=lambda: [
        'tire_age', 'fuel_load', 'track_position', 'track_temp',
        'air_temp', 'humidity', 'tire_compound', 'TrackStatus', 'is_pit_lap'
    ])
    
    target: str = 'milliseconds'

# Define the F1Dataset class
class F1Dataset(Dataset):
    def __init__(self, sequences, static_features, targets):
        self.sequences = torch.FloatTensor(sequences)
        self.static_features = torch.FloatTensor(static_features)
        self.targets = torch.FloatTensor(targets)
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        return {
            'sequence': self.sequences[idx],
            'static': self.static_features[idx],
            'target': self.targets[idx]
        }

# Define the F1DataPreprocessor class
class F1DataPreprocessor:
    def __init__(self):
        self.static_scaler = StandardScaler()
        self.dynamic_scaler = StandardScaler()
        self.lap_time_scaler = StandardScaler()
        
    def prepare_sequence_data(self, df: pd.DataFrame, window_size: int = 3) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Prepare sequential data with sliding window and apply scaling
        """
        sequences = []
        static_features = []
        targets = []
        
        # Instantiate RaceFeatures
        race_features = RaceFeatures()
        
        # Sort the dataframe to ensure consistent ordering
        df = df.sort_values(['raceId', 'driverId', 'lap'])
        
        # Group by race and driver
        for (race_id, driver_id), group in df.groupby(['raceId', 'driverId']):
            group = group.sort_values('lap')
            
            # Extract static features (assumed to be constant per driver per race)
            static = group[race_features.static_features].iloc[0].values
            static_features.append(static)
            
            # Extract dynamic features and target
            lap_times = group[race_features.target].values.reshape(-1, 1)  # Shape: (num_laps, 1)
            dynamic = group[race_features.dynamic_features].values  # Shape: (num_laps, num_dynamic_features)
            
            # Apply scaling
            # Note: Scalers should be fitted on the training data to prevent data leakage.
            # Here, for simplicity, we're fitting on the entire dataset. For a real-world scenario,
            # consider splitting the data first before fitting the scalers.
            dynamic_features_to_scale = [col for col in race_features.dynamic_features if col != 'tire_compound']
            tire_compounds = dynamic[:, race_features.dynamic_features.index('tire_compound')].reshape(-1, 1)
            other_dynamic = dynamic[:, [race_features.dynamic_features.index(col) for col in dynamic_features_to_scale]]
            
            lap_times_scaled = self.lap_time_scaler.fit_transform(lap_times).flatten()
            other_dynamic_scaled = self.dynamic_scaler.fit_transform(other_dynamic)
            static_scaled = self.static_scaler.fit_transform(static.reshape(1, -1)).flatten()
            
            dynamic_scaled = np.hstack((tire_compounds, other_dynamic_scaled))
            
            # Create sequences
            # Create sequences
        for i in range(len(lap_times_scaled) - window_size):
            sequence_lap_times = lap_times_scaled[i:i+window_size].reshape(-1, 1)  # Shape: (window_size, 1)
            sequence_dynamic = dynamic_scaled[i:i+window_size]  # Shape: (window_size, num_dynamic_features)
            sequence = np.hstack((sequence_lap_times, sequence_dynamic))  # Shape: (window_size, 1 + num_dynamic_features)
            sequences.append(sequence)
            static_features.append(static_scaled)
            targets.append(lap_times_scaled[i + window_size])
        
        return (np.array(sequences), 
                np.array(static_features), 
                np.array(targets))

    
    def create_train_val_loaders(
        self, 
        sequences: np.ndarray, 
        static_features: np.ndarray, 
        targets: np.ndarray,
        batch_size: int = 32,
        val_split: float = 0.2
    ) -> Tuple[DataLoader, DataLoader]:
        """
        Create train and validation dataloaders with given split ratio
        """
        dataset = F1Dataset(sequences, static_features, targets)
        
        # Calculate lengths for split
        val_size = int(len(dataset) * val_split)
        train_size = len(dataset) - val_size
        
        # Split dataset
        train_dataset, val_dataset = random_split(
            dataset, 
            [train_size, val_size],
            generator=torch.Generator().manual_seed(42)
        )
        
        # Create dataloaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True
        )
        
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False
        )
        
        return train_loader, val_loader

class F1PredictionModel(nn.Module):
    def __init__(self, 
                 sequence_dim: int,
                 static_dim: int,
                 hidden_dim: int = 64,
                 num_layers: int = 2,
                 dropout_prob: float = 0.5):
        super().__init__()
        
        # LSTM for sequential features with dropout
        self.lstm = nn.LSTM(
            input_size=sequence_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout_prob
        )
        
        # Static features processing with dropout
        self.static_network = nn.Sequential(
            nn.Linear(static_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob)
        )
        
        # Combine everything
        self.final_network = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, sequence, static):
        # Process sequence through LSTM
        lstm_out, _ = self.lstm(sequence)
        lstm_out = lstm_out[:, -1, :]  # Output of the last time step
        
        # Process static features
        static_out = self.static_network(static)
        
        # Combine LSTM output and static features
        combined = torch.cat([lstm_out, static_out], dim=1)
        
        # Final prediction
        prediction = self.final_network(combined)
        
        return prediction.squeeze()


# Define the training function
def train_model(model: nn.Module, 
                train_loader: DataLoader,
                val_loader: DataLoader,
                epochs: int = 10,
                learning_rate: float = 0.001,
                lap_time_scaler: StandardScaler = None,  # Pass the lap time scaler
                device: Optional[str] = None) -> Dict[str, List[float]]:
    """
    Train the model and return training history including MAE in milliseconds
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()
    history = {'train_loss': [], 'val_loss': [], 'train_mae': [], 'val_mae': []}
    
    for epoch in range(epochs):
        # Training
        model.train()
        train_losses = []
        train_maes = []
        for batch in train_loader:
            sequences = batch['sequence'].to(device)
            static = batch['static'].to(device)
            targets = batch['target'].to(device)
            
            optimizer.zero_grad()
            predictions = model(sequences, static)
            loss = criterion(predictions, targets)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            
            # Calculate MAE in normalized scale
            mae = torch.mean(torch.abs(predictions - targets)).item()
            train_maes.append(mae)
        
        # Validation
        model.eval()
        val_losses = []
        val_maes = []
        with torch.no_grad():
            for batch in val_loader:
                sequences = batch['sequence'].to(device)
                static = batch['static'].to(device)
                targets = batch['target'].to(device)
                
                predictions = model(sequences, static)
                loss = criterion(predictions, targets)
                val_losses.append(loss.item())
                
                # Calculate MAE in normalized scale
                mae_normalized = torch.mean(torch.abs(predictions - targets)).item()
                val_maes.append(mae_normalized)
        
        # Record metrics
        train_loss = np.mean(train_losses)
        val_loss = np.mean(val_losses)
        train_mae_normalized = np.mean(train_maes)
        val_mae_normalized = np.mean(val_maes)
        
        # Convert MAE back to milliseconds using the inverse scaler
        if lap_time_scaler:
            train_mae_ms = lap_time_scaler.inverse_transform([[train_mae_normalized]])[0][0]
            val_mae_ms = lap_time_scaler.inverse_transform([[val_mae_normalized]])[0][0]
        else:
            train_mae_ms, val_mae_ms = train_mae_normalized, val_mae_normalized
        
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_mae'].append(train_mae_ms)
        history['val_mae'].append(val_mae_ms)
        
        print(f'Epoch {epoch+1}/{epochs}:')
        print(f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train MAE: {train_mae_ms:.2f} ms, Val MAE: {val_mae_ms:.2f} ms')
    
    return history

def predict_with_uncertainty(model, inputs, n_samples=100):
    model.train()  # Enable dropout layers
    predictions = []
    with torch.no_grad():
        for _ in range(n_samples):
            prediction = model(**inputs).cpu().numpy()
            predictions.append(prediction)
    predictions = np.array(predictions)
    mean_prediction = predictions.mean(axis=0)
    std_prediction = predictions.std(axis=0)
    return mean_prediction, std_prediction

# Define a function to save the model
def save_model_with_preprocessor(model, preprocessor, sequence_dim, static_dim, path: str):
    torch.save({
        'model_state_dict': model.state_dict(),
        'lap_time_scaler': preprocessor.lap_time_scaler,
        'dynamic_scaler': preprocessor.dynamic_scaler, 
        'static_scaler': preprocessor.static_scaler,
        'sequence_dim': sequence_dim,
        'static_dim': static_dim
    }, path)
    print(f"Model and preprocessor saved to {path}")


# Now, integrate your code snippets into data preprocessing
def load_and_preprocess_data() -> pd.DataFrame:
    """
    Load data from CSV files and preprocess it to create the enhanced_laps DataFrame.
    """
    # Load data
    na_values = ['\\N', 'NaN', '']
    lap_times = pd.read_csv('../../data/raw_data/lap_times.csv', na_values=na_values)
    drivers = pd.read_csv('../../data/raw_data/drivers.csv', na_values=na_values)
    races = pd.read_csv('../../data/raw_data/races.csv', na_values=na_values)
    circuits = pd.read_csv('../../data/raw_data/circuits.csv', na_values=na_values)
    pit_stops = pd.read_csv('../../data/raw_data/pit_stops.csv', na_values=na_values)
    pit_stops.rename(columns={'milliseconds' : 'pitstop_milliseconds'}, inplace=True)
    results = pd.read_csv('../../data/raw_data/results.csv', na_values=na_values)
    results.rename(columns={'milliseconds' : 'racetime_milliseconds'}, inplace=True)

    qualifying = pd.read_csv('../../data/raw_data/qualifying.csv', na_values=na_values)
    status = pd.read_csv('../../data/raw_data/status.csv', na_values=na_values)
    weather_data = pd.read_csv('../../data/raw_data/ff1_weather.csv', na_values=na_values)
    practice_sessions = pd.read_csv('../../data/raw_data/ff1_laps.csv', na_values=na_values)
    # Load the tire data
    tire_data = pd.read_csv('../../data/raw_data/ff1_laps.csv', na_values=na_values)

    # Convert date columns to datetime
    races['date'] = pd.to_datetime(races['date'])
    results['date'] = results['raceId'].map(races.set_index('raceId')['date'])
    lap_times['date'] = lap_times['raceId'].map(races.set_index('raceId')['date'])
    
    # Merge dataframes
    laps = lap_times.merge(drivers, on='driverId', how='left')
    print(laps.shape)
    laps = laps.merge(races, on='raceId', how='left', suffixes=('', '_race'))
    laps.rename(columns={'quali_time' : 'quali_date_time'}, inplace=True)
    print(laps.shape)
    laps = laps.merge(circuits, on='circuitId', how='left')
    print(laps.shape)
    laps = laps.merge(results[['raceId', 'driverId', 'positionOrder', 'grid', 'racetime_milliseconds', 'fastestLap', 'statusId']], on=['raceId', 'driverId'], how='left')
    print(laps.shape)
    laps = laps.merge(status, on='statusId', how='left')
    print(laps.shape)
    laps = laps.merge(pit_stops[['raceId', 'driverId', 'lap', 'pitstop_milliseconds']], on=['raceId', 'driverId', 'lap'], how='left')
    print(laps.shape)
    laps['pitstop_milliseconds'].fillna(0, inplace=True)  # Assuming 0 if no pit stop
    print(laps.shape)
    
    # Add weather information
    # Filter weather data to include only the Race session
    weather_data = weather_data[weather_data['SessionName'] == 'R']
    
    # Merge weather data with races to get raceId
    weather_data = weather_data.merge(
        races[['raceId', 'year', 'name']], 
        left_on=['EventName', 'Year'],
        right_on=['name', 'year'],
        how='left'
    )
    
    # Compute cumulative time from the start of the race for each driver
    laps.sort_values(['raceId', 'driverId', 'lap'], inplace=True)
    laps['cumulative_milliseconds'] = laps.groupby(['raceId', 'driverId'])['milliseconds'].cumsum()
    laps['seconds_from_start'] = laps['cumulative_milliseconds'] / 1000
    print(laps.shape)
    
    # Use 'Time' in weather_data as 'seconds_from_start'
    weather_data['seconds_from_start'] = weather_data['Time']
    
    # Standardize text data
    tire_data['Compound'] = tire_data['Compound'].str.upper()
    tire_data['EventName'] = tire_data['EventName'].str.strip().str.upper()
    races['name'] = races['name'].str.strip().str.upper()
    
    # Filter for race sessions only
    tire_data = tire_data[tire_data['SessionName'] == 'R']
    
    # Merge with races to get raceId
    tire_data = tire_data.merge(
        races[['raceId', 'year', 'name']],
        left_on=['Year', 'EventName'],
        right_on=['year', 'name'],
        how='left'
    )
    
    # Map driver codes to driverId
    tire_data['Driver'] = tire_data['Driver'].str.strip().str.upper()
    drivers['code'] = drivers['code'].str.strip().str.upper()
    driver_code_to_id = drivers.set_index('code')['driverId'].to_dict()
    tire_data['driverId'] = tire_data['Driver'].map(driver_code_to_id)
    
    # Rename 'LapNumber' to 'lap' and ensure integer type
    tire_data.rename(columns={'LapNumber': 'lap'}, inplace=True)
    tire_data['lap'] = tire_data['lap'].astype(int)
    laps['lap'] = laps['lap'].astype(int)
    
    # Create compound mapping (ordered from hardest to softest)
    compound_mapping = {
        'UNKNOWN': 0,
        'HARD': 1,
        'MEDIUM': 2,
        'SOFT': 3,
        'INTERMEDIATE': 4,
        'WET': 5
    }
    
    # Merge tire_data with laps
    laps = laps.merge(
        tire_data[['raceId', 'driverId', 'lap', 'Compound', 'TrackStatus']],
        on=['raceId', 'driverId', 'lap'],
        how='left'
    )

    
    # Handle missing compounds and apply numeric encoding
    laps['Compound'].fillna('UNKNOWN', inplace=True)
    laps['tire_compound'] = laps['Compound'].map(compound_mapping)
    
    # Drop the original Compound column if desired
    laps.drop('Compound', axis=1, inplace=True)
    
    # Standardize names
    practice_sessions['EventName'] = practice_sessions['EventName'].str.strip().str.upper()
    races['name'] = races['name'].str.strip().str.upper()
    
    # Merge practice_sessions with races to get raceId
    practice_sessions = practice_sessions.merge(
        races[['raceId', 'year', 'name']],
        left_on=['Year', 'EventName'],
        right_on=['year', 'name'],
        how='left'
    )
    
    # Map driver codes to driverId
    practice_sessions['Driver'] = practice_sessions['Driver'].str.strip().str.upper()
    drivers['code'] = drivers['code'].str.strip().str.upper()
    driver_code_to_id = drivers.set_index('code')['driverId'].to_dict()
    practice_sessions['driverId'] = practice_sessions['Driver'].map(driver_code_to_id)
    
    # Convert LapTime to milliseconds
    practice_sessions['LapTime_ms'] = practice_sessions['LapTime'].apply(lambda x: pd.to_timedelta(x).total_seconds() * 1000)
    
    # Calculate median lap times for each driver in each session
    session_medians = practice_sessions.groupby(['raceId', 'driverId', 'SessionName'])['LapTime_ms'].median().reset_index()
    
    # Pivot the data to have sessions as columns
    session_medians_pivot = session_medians.pivot_table(
        index=['raceId', 'driverId'],
        columns='SessionName',
        values='LapTime_ms'
    ).reset_index()
    
    # Rename columns for clarity
    session_medians_pivot.rename(columns={
        'FP1': 'fp1_median_time',
        'FP2': 'fp2_median_time',
        'FP3': 'fp3_median_time',
        'Q': 'quali_time'
    }, inplace=True)
    
    laps = laps.merge(
    session_medians_pivot,
    on=['raceId', 'driverId'],
    how='left'
    )
    
    # Fill missing practice times with global median or a placeholder value
    global_median_fp1 = laps['fp1_median_time'].median()
    laps['fp1_median_time'].fillna(global_median_fp1, inplace=True)
    
    # Repeat for other sessions
    global_median_fp2 = laps['fp2_median_time'].median()
    laps['fp2_median_time'].fillna(global_median_fp2, inplace=True)
    
    global_median_fp3 = laps['fp3_median_time'].median()
    laps['fp3_median_time'].fillna(global_median_fp3, inplace=True)
    
    global_median_quali = laps['quali_time'].median()
    laps['quali_time'].fillna(global_median_quali, inplace=True)

    
    # Create a binary indicator for pit stops
    laps['is_pit_lap'] = laps['pitstop_milliseconds'].apply(lambda x: 1 if x > 0 else 0)

    
    # Define a function to match weather data to laps
    def match_weather_to_lap(race_laps, race_weather):
        """
        For each lap, find the closest weather measurement in time
        """
        race_laps = race_laps.sort_values('seconds_from_start')
        race_weather = race_weather.sort_values('seconds_from_start')
        merged = pd.merge_asof(
            race_laps,
            race_weather,
            on='seconds_from_start',
            direction='nearest'
        )
        return merged

    # Apply matching per race
    matched_laps_list = []
    for race_id in laps['raceId'].unique():
        print(f'Matching for {race_id}')
        race_laps = laps[laps['raceId'] == race_id]
        race_weather = weather_data[weather_data['raceId'] == race_id]
        
        if not race_weather.empty:
            matched = match_weather_to_lap(race_laps, race_weather)
            print(f"Matched DataFrame shape: {matched.shape}")
            matched_laps_list.append(matched)
        else:
            matched_laps_list.append(race_laps)  # No weather data for this race

    # Concatenate all matched laps
    laps = pd.concat(matched_laps_list, ignore_index=True)
    print(laps.shape)
    
    # Fill missing weather data with default values
    laps['track_temp'] = laps['TrackTemp'].fillna(25.0)
    laps['air_temp'] = laps['AirTemp'].fillna(20.0)
    laps['humidity'] = laps['Humidity'].fillna(50.0)
    
    # Calculate driver aggression and skill
    # Create driver names
    drivers['driver_name'] = drivers['forename'] + ' ' + drivers['surname']
    driver_mapping = drivers[['driverId', 'driver_name']].copy()
    driver_mapping.set_index('driverId', inplace=True)
    driver_names = driver_mapping['driver_name'].to_dict()
    
    # Map statusId to status descriptions
    status_dict = status.set_index('statusId')['status'].to_dict()
    results['status'] = results['statusId'].map(status_dict)
    
    # Calculate driver aggression and skill
    def calculate_aggression(driver_results):
        if len(driver_results) == 0:
            return 0.5  # Default aggression for new drivers
        
        # Only consider recent races for more current behavior
        recent_results = driver_results.sort_values('date', ascending=False).head(20)
        
        # Calculate overtaking metrics
        positions_gained = recent_results['grid'] - recent_results['positionOrder']
        
        # Calculate risk metrics
        dnf_rate = (recent_results['status'] != 'Finished').mean()
        incidents = (recent_results['statusId'].isin([
            4,  # Collision
            5,  # Spun off
            6,  # Accident
            20, # Collision damage
            82, # Collision with another driver
        ])).mean()
        
        # Calculate overtaking success rate (normalized between 0-1)
        positive_overtakes = (positions_gained > 0).sum()
        negative_overtakes = (positions_gained < 0).sum()
        total_overtake_attempts = positive_overtakes + negative_overtakes
        overtake_success_rate = positive_overtakes / total_overtake_attempts if total_overtake_attempts > 0 else 0.5
        
        # Normalize average positions gained (0-1)
        avg_positions_gained = positions_gained[positions_gained > 0].mean() if len(positions_gained[positions_gained > 0]) > 0 else 0
        max_possible_gain = 20  # Maximum grid positions that could be gained
        normalized_gains = np.clip(avg_positions_gained / max_possible_gain, 0, 1)
        
        # Normalize risk factors (0-1)
        normalized_dnf = np.clip(dnf_rate, 0, 1)
        normalized_incidents = np.clip(incidents, 0, 1)
        
        # Calculate component scores (each between 0-1)
        overtaking_component = (normalized_gains * 0.6 + overtake_success_rate * 0.4)
        risk_component = (normalized_dnf * 0.5 + normalized_incidents * 0.5)
        
        # Combine components with weights (ensuring sum of weights = 1)
        weights = {
            'overtaking': 0.4,  # Aggressive overtaking
            'risk': 0.5,       # Risk-taking behavior
            'baseline': 0.1    # Baseline aggression
        }
        
        aggression = (
            overtaking_component * weights['overtaking'] +
            risk_component * weights['risk'] +
            0.5 * weights['baseline']  # Baseline aggression factor
        )
        
        # Add small random variation while maintaining 0-1 bounds
        variation = np.random.normal(0, 0.02)
        aggression = np.clip(aggression + variation, 0, 1)
        
        return aggression
    
    def calculate_skill(driver_data, results_data, circuit_id):
        driver_results = results_data[
            (results_data['driverId'] == driver_data['driverId']) & 
            (results_data['circuitId'] == circuit_id)
        ].sort_values('date', ascending=False).head(10)  # Use last 10 races at circuit
        
        if len(driver_results) == 0:
            return 0.5  # Default skill
        
        # Calculate performance metrics
        avg_finish_pos = driver_results['positionOrder'].mean()
        avg_quali_pos = driver_results['grid'].mean()
        points_per_race = driver_results['points'].mean()
        fastest_laps = (driver_results['rank'] == 1).mean()  # Add fastest lap consideration
        
        # Improved normalization (exponential decay for positions)
        normalized_finish_pos = np.exp(-avg_finish_pos/5) # Better spread of values
        normalized_quali_pos = np.exp(-avg_quali_pos/5)
        
        # Points normalization with improved scaling
        max_points_per_race = 26  # Maximum possible points (25 + 1 fastest lap)
        normalized_points = points_per_race / max_points_per_race
        
        # Weighted combination with more factors
        weights = {
            'finish': 0.35,
            'quali': 0.25,
            'points': 0.25,
            'fastest_laps': 0.15
        }
        
        skill = (
            weights['finish'] * normalized_finish_pos +
            weights['quali'] * normalized_quali_pos +
            weights['points'] * normalized_points +
            weights['fastest_laps'] * fastest_laps
        )
        
        # Add random variation to prevent identical skills
        skill = np.clip(skill + np.random.normal(0, 0.05), 0.1, 1.0)
        
        return skill
    
    # First merge results with races to get circuitId
    results = results.merge(
        races[['raceId', 'circuitId']], 
        on='raceId',
        how='left'
    )

    # Now calculate driver aggression and skill
    driver_aggression = {}
    driver_skill = {}
    for driver_id in drivers['driverId'].unique():
        driver_results = results[results['driverId'] == driver_id]
        aggression = calculate_aggression(driver_results)
        driver_aggression[driver_id] = aggression
        
        # Now we have circuit_id from the merge
        recent_race = driver_results.sort_values('date', ascending=False).head(1)
        if not recent_race.empty:
            circuit_id = recent_race['circuitId'].iloc[0]
            skill = calculate_skill({'driverId': driver_id}, results, circuit_id)
            driver_skill[driver_id] = skill
        else:
            driver_skill[driver_id] = 0.5  # Default skill for new drivers
    
    # Map calculated aggression and skill back to laps DataFrame
    laps['driver_aggression'] = laps['driverId'].map(driver_aggression)
    laps['driver_overall_skill'] = laps['driverId'].map(driver_skill)
    laps['driver_circuit_skill'] = laps['driver_overall_skill']  # For simplicity, using overall skill
    laps['driver_consistency'] = 0.5  # Placeholder
    laps['driver_reliability'] = 0.5  # Placeholder
    laps['driver_risk_taking'] = laps['driver_aggression']  # Assuming similar to aggression
    
    # Dynamic features
    laps['tire_age'] = laps.groupby(['raceId', 'driverId'])['lap'].cumcount()
    laps['fuel_load'] = laps.groupby(['raceId', 'driverId'])['lap'].transform(lambda x: x.max() - x + 1)
    laps['track_position'] = laps['position']  # Assuming 'position' is available in laps data
    
    # Ensure that all required columns are present
    # Create an instance of RaceFeatures
    race_features = RaceFeatures()

    
    laps['TrackStatus'].fillna(1, inplace=True)  # 1 = regular racing status
    
    # Ensure that all required columns are present
    required_columns = race_features.static_features + race_features.dynamic_features
    # Before dropping NaN values
    print("\nNaN counts in required columns:")
    for col in required_columns:
        nan_count = laps[col].isna().sum()
        total_rows = len(laps)
        if nan_count > 0:
            print(f"{col}: {nan_count} NaN values ({(nan_count/total_rows*100):.2f}% of rows)")
    missing_columns = set(required_columns) - set(laps.columns)
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")
    
    # Drop rows with missing values in required columns
    laps.to_csv('laps_withNan.csv', index=False)
    laps = laps.dropna(subset=required_columns)
    
    print(laps.shape)
    
    return laps

# Update the main function
def main():
    # Load and preprocess data
    enhanced_laps = load_and_preprocess_data()
    
    enhanced_laps.drop(columns=['position', 'time', 'driverRef', 'number', 'R', 'S', 'code', 'forename', 'surname', 'url_x', 'url_race', 'name_x', 'circuitRef', 'name_y', 'location', 'country', 'url_y', 'positionOrder', 'fastestLap', 'cumulative_milliseconds', 'seconds_from_start', 'raceId_x', 'year_x', 'Time', 'TrackTemp', 'AirTemp', 'Humidity', 'name', 'year_y', 'raceId_y'], inplace=True)
    
    # Save the preprocessed laps DataFrame for inspection
    enhanced_laps.to_csv('enhanced_laps_before_training.csv', index=False)
    print(enhanced_laps.shape)
    
    print("Enhanced laps DataFrame saved to 'enhanced_laps_before_training.csv'")
    
    preprocessor = F1DataPreprocessor()
    sequences, static, targets = preprocessor.prepare_sequence_data(enhanced_laps, window_size=3)
    
    # Create train and validation loaders
    train_loader, val_loader = preprocessor.create_train_val_loaders(
        sequences, 
        static, 
        targets,
        batch_size=32,
        val_split=0.2
    )
    
    # Initialize model
    model = F1PredictionModel(
        sequence_dim=sequences.shape[2],
        static_dim=static.shape[1]
    )
    
    # Train the model
    history = train_model(model, train_loader, val_loader, epochs=100, learning_rate=0.001)
    
    # Save the trained model
    save_model_with_preprocessor(model, preprocessor, sequences.shape[2], static.shape[1], 'f1_prediction_model.pth')

if __name__ == "__main__":
    main()

# %% [markdown]
# # Phase 2

# %%
class RaceSimulator:
    def __init__(self, model, preprocessor):
        self.model = model
        self.preprocessor = preprocessor
        self.race_features = RaceFeatures()
        
    def simulate_lap(self, current_state):
        """
        Predict lap time and uncertainty for a single lap.
        """
        # Prepare input data
        sequence = current_state['sequence']
        static = current_state['static']
        
        # Convert to tensors
        sequence_tensor = torch.FloatTensor(sequence).unsqueeze(0)  # Add batch dimension
        static_tensor = torch.FloatTensor(static).unsqueeze(0)
        
        # Predict with uncertainty
        mean_pred, std_pred = predict_with_uncertainty(
            self.model, 
            {'sequence': sequence_tensor, 'static': static_tensor},
            n_samples=50
        )
        
        # Inverse transform the prediction
        lap_time = self.preprocessor.lap_time_scaler.inverse_transform([[mean_pred]])[0][0]
        uncertainty = self.preprocessor.lap_time_scaler.inverse_transform([[std_pred]])[0][0]
        
        return lap_time, uncertainty
    
    def simulate_full_race(self, initial_state, strategy):
        """
        Simulate the entire race lap by lap based on the provided strategy.
        """
        lap_times = []
        uncertainties = []
        current_state = initial_state.copy()
        
        total_laps = strategy.total_laps
        for lap in range(1, total_laps + 1):
            # Update dynamic features based on strategy and lap number
            current_state = self.update_state(current_state, lap, strategy)
            
            # Simulate lap
            lap_time, uncertainty = self.simulate_lap(current_state)
            lap_times.append(lap_time)
            uncertainties.append(uncertainty)
            
            # Update sequence data for the next lap
            current_state['sequence'] = self.update_sequence(
                current_state['sequence'], lap_time, current_state['dynamic']
            )
        
        return lap_times, uncertainties
    
    def update_state(self, state, lap, strategy):
        """
        Update the state for the next lap based on the strategy.
        """
        # Update tire age
        state['dynamic']['tire_age'] += 1
        
        # Check for pit stops
        if lap in strategy.pit_stop_laps:
            # Reset tire age and update tire compound
            state['dynamic']['tire_age'] = 0
            state['dynamic']['tire_compound'] = strategy.pit_stop_compounds[lap]
        
        # Update fuel load
        state['dynamic']['fuel_load'] -= strategy.fuel_consumption_per_lap
        
        # Update other dynamic features as needed
        # ...
        
        return state
    
    def update_sequence(self, sequence, new_lap_time, dynamic_features):
        """
        Update the sequence data with the latest lap information.
        """
        # Remove the oldest lap data
        sequence = sequence[1:]
        # Append the new lap data
        new_sequence_entry = np.hstack((
            [new_lap_time],
            [dynamic_features[feature] for feature in self.race_features.dynamic_features]
        ))
        sequence = np.vstack([sequence, new_sequence_entry])
        return sequence
    
    # Additional helper methods as needed


# %%
class RaceStrategy:
    def __init__(self, total_laps, pit_stops):
        """
        pit_stops: List of dictionaries with keys 'lap' and 'compound'
        Example: [{'lap': 15, 'compound': 'MEDIUM'}, {'lap': 30, 'compound': 'SOFT'}]
        """
        self.total_laps = total_laps
        self.pit_stops = pit_stops  # List of pit stop events
        self.pit_stop_laps = [stop['lap'] for stop in pit_stops]
        self.pit_stop_compounds = {stop['lap']: stop['compound'] for stop in pit_stops}
        self.fuel_consumption_per_lap = 1.5  # Example value, adjust as needed
    
    def evaluate(self, simulator, initial_state):
        """
        Simulate the race using this strategy and return total race time.
        """
        lap_times, uncertainties = simulator.simulate_full_race(initial_state, self)
        total_time = sum(lap_times)
        return total_time, lap_times, uncertainties


# %%
class RaceStrategyOptimizer:
    def __init__(self, simulator):
        self.simulator = simulator
    
    def optimize(self, initial_conditions, constraints):
        """
        Find the optimal strategy given initial conditions and constraints.
        """
        best_strategy = None
        best_time = float('inf')
        
        # Generate possible strategies within constraints
        possible_strategies = self.generate_strategies(constraints)
        
        # Evaluate each strategy
        for strategy in possible_strategies:
            total_time, _, _ = strategy.evaluate(self.simulator, initial_conditions)
            if total_time < best_time:
                best_time = total_time
                best_strategy = strategy
        
        return best_strategy
    
    def generate_strategies(self, constraintas):
        """
        Generate possible strategies based on constraints.
        """
        min_pit_stops = constraints.get('min_pit_stops', 1)
        max_pit_stops = constraints.get('max_pit_stops', 3)
        available_compounds = constraints.get('available_compounds', ['SOFT', 'MEDIUM', 'HARD'])
        total_laps = constraints.get('total_laps', 50)
        
        strategies = []
        
        # Example: Generate strategies with different pit stop laps and compounds
        for num_pit_stops in range(min_pit_stops, max_pit_stops + 1):
            pit_stop_laps_options = self.get_pit_stop_lap_combinations(total_laps, num_pit_stops)
            for pit_stop_laps in pit_stop_laps_options:
                for compounds in self.get_compound_combinations(available_compounds, num_pit_stops):
                    pit_stops = [{'lap': lap, 'compound': comp} for lap, comp in zip(pit_stop_laps, compounds)]
                    strategy = RaceStrategy(total_laps, pit_stops)
                    strategies.append(strategy)
        return strategies
    
    def get_pit_stop_lap_combinations(self, total_laps, num_pit_stops):
        """
        Generate possible pit stop lap combinations.
        """
        from itertools import combinations
        lap_numbers = range(5, total_laps - 5)  # Avoid pitting too early or too late
        return combinations(lap_numbers, num_pit_stops)
    
    def get_compound_combinations(self, compounds, num_pit_stops):
        """
        Generate possible combinations of compounds for pit stops.
        """
        from itertools import product
        return product(compounds, repeat=num_pit_stops)
    
    def load_model_with_preprocessor(path: str, sequence_dim: int, static_dim: int, hidden_dim: int = 64, num_layers: int = 2):
        """
        Load the model and preprocessor from a saved file.
        """
        checkpoint = torch.load(path, map_location=torch.device('cpu'))
        
        # Recreate the model architecture
        model = F1PredictionModel(
            sequence_dim=sequence_dim,
            static_dim=static_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()  # Set model to evaluation mode
        
        # Recreate the preprocessor
        preprocessor = F1DataPreprocessor()
        preprocessor.lap_time_scaler = checkpoint['lap_time_scaler']
        preprocessor.dynamic_scaler = checkpoint['dynamic_scaler']
        preprocessor.static_scaler = checkpoint['static_scaler']
        
        print(f"Model and preprocessor loaded from {path}")
        return model, preprocessor


# %%
def load_model_with_preprocessor(path: str) -> Tuple[F1PredictionModel, F1DataPreprocessor]:
   """Load saved model and preprocessor"""
   checkpoint = torch.load(path, map_location=torch.device('cpu'))
   
   model = F1PredictionModel(
       sequence_dim=checkpoint['sequence_dim'], 
       static_dim=checkpoint['static_dim']
   )
   model.load_state_dict(checkpoint['model_state_dict'])
   
   preprocessor = F1DataPreprocessor()
   preprocessor.lap_time_scaler = checkpoint['lap_time_scaler']
   preprocessor.dynamic_scaler = checkpoint['dynamic_scaler']
   preprocessor.static_scaler = checkpoint['static_scaler']
   
   return model, preprocessor

# %%
# The model and preprocessor are now ready for simulation or further training
model, preprocessor = load_model_with_preprocessor('f1_prediction_model.pth')

# Assume 'model' and 'preprocessor' have been loaded or trained
simulator = RaceSimulator(model, preprocessor)

# Step 1: Define static features
static_features_list = [
    'driver_overall_skill', 'driver_circuit_skill', 'driver_consistency',
    'driver_reliability', 'driver_aggression', 'driver_risk_taking',
    'fp1_median_time', 'fp2_median_time', 'fp3_median_time', 'quali_time'
]

static_features = np.array([
    0.8,   # driver_overall_skill
    0.75,  # driver_circuit_skill
    0.7,   # driver_consistency
    0.9,   # driver_reliability
    0.6,   # driver_aggression
    0.5,   # driver_risk_taking
    88000, # fp1_median_time
    87500, # fp2_median_time
    87000, # fp3_median_time
    86000  # quali_time
])

# Step 2: Define initial dynamic features for previous laps
dynamic_features_list = [
    'tire_age', 'fuel_load', 'track_position', 'track_temp',
    'air_temp', 'humidity', 'TrackStatus', 'is_pit_lap'
]

# Common values
track_temp = 35.0
air_temp = 25.0
humidity = 50.0
TrackStatus = 1
is_pit_lap = 0

dynamic_features = np.array([
    [0, 100, 1, track_temp, air_temp, humidity, TrackStatus, is_pit_lap],  # Lap 1
    [1, 98.5, 1, track_temp, air_temp, humidity, TrackStatus, is_pit_lap],  # Lap 2
    [2, 97, 1, track_temp, air_temp, humidity, TrackStatus, is_pit_lap]     # Lap 3
])

# Tire compound (not scaled)
tire_compound = 0  # Example value for 'HARD' tires
tire_compound_column = np.full((dynamic_features.shape[0], 1), tire_compound)

# Lap times in milliseconds
lap_times = np.array([90000, 89500, 89200])  # Laps 1-3

# Step 3: Combine lap times, tire_compound, and dynamic features
sequence = np.hstack((
    lap_times.reshape(-1, 1),  # Lap times
    tire_compound_column,      # Tire compound
    dynamic_features           # Dynamic features
))

# Step 4: Scale the lap times
lap_times_scaled = preprocessor.lap_time_scaler.transform(lap_times.reshape(-1, 1)).flatten()

# Step 5: Scale the dynamic features (excluding 'tire_compound')
dynamic_features_to_scale = sequence[:, 2:]  # Exclude lap times and 'tire_compound'
dynamic_scaled = preprocessor.dynamic_scaler.transform(dynamic_features_to_scale)

# Step 6: Reconstruct the scaled sequence
sequence_scaled = np.hstack((
    lap_times_scaled.reshape(-1, 1),  # Scaled lap times
    tire_compound_column,             # Tire compound (not scaled)
    dynamic_scaled                    # Scaled dynamic features
))

# Step 7: Scale the static features
static_scaled = preprocessor.static_scaler.transform(static_features.reshape(1, -1)).flatten()

# Step 8: Prepare current dynamic features
current_dynamic_features = {
    'tire_age': 3,
    'fuel_load': 95.5,
    'track_position': 1,
    'track_temp': 35.0,
    'air_temp': 25.0,
    'humidity': 50.0,
    'TrackStatus': 1,
    'is_pit_lap': 0,
    'tire_compound': 0
}

# Extract and scale dynamic features (excluding 'tire_compound')
dynamic_feature_values = np.array([
    current_dynamic_features[feature] for feature in dynamic_features_list
])

dynamic_scaled_current = preprocessor.dynamic_scaler.transform(dynamic_feature_values.reshape(1, -1)).flatten()

# Step 9: Update the initial state
initial_state = {
    'sequence': sequence_scaled,   # Shape: (window_size, sequence_dim)
    'static': static_scaled,       # Shape: (static_dim,)
    'dynamic': current_dynamic_features
}


# Define constraints
constraints = {
    'min_pit_stops': 1,
    'max_pit_stops': 2,
    'available_compounds': [3, 2, 1],
    'total_laps': 50
}

# Create optimizer and find the best strategy
optimizer = RaceStrategyOptimizer(simulator)
best_strategy = optimizer.optimize(initial_state, constraints)

# Evaluate the best strategy
total_time, lap_times, uncertainties = best_strategy.evaluate(simulator, initial_state)
print(f"Optimal total race time: {total_time:.2f} seconds")
"""


First of all, I need somekind of logging or process bar during the simulations.

Second, I think if the model is working so far (we can try out other models later)

The goal now would be to not only have a tool to find out the best strategy, but also to predict whole races. I image it like this:
Inputs:
-data from before, free practice etc.
-circuitId to tell the model which circuit the race is on
-weatherforecast: in our case we could just do a train test split with the races (including all their laps) and since we know the weather data, we can just assume that that is the forecast
-which drivers will drive for which constructor through driverId (maybe each driver as an agent with (skill. agression etc.?)
-
-Pitstop, I assume optimizing to get the best strategy for each driver would take a while, for now lets just say we predefined one strategy for all of them (when to stop and what compounds)

When we introduce these new variables we probably need to add them the model as well?

Third, I think we need to introduce some like Global Race Events i.e Retirements, Safetey car etc. One could probably train models for this as well, e.g. which lap a safety car is most likely, how likely and when a driver / car will retire. But for now we maybe can also just predefined those things.
structure all the ideas I just outlined and think about whats missing and in which steps to implement
