In [None]:
import pandas as pd
import numpy as np
import math as math
import gc
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from typing import Dict, Optional, Tuple, Any, List, Union
import copy
from torch.utils.checkpoint import checkpoint
import logging
import torch.nn.functional as F
import matplotlib.pyplot as plt
from dataclasses import dataclass, field
import os
import time
import warnings
from collections import defaultdict
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
from sklearn.preprocessing import StandardScaler
from scipy.stats import spearmanr
import sys as sys
import pickle as pkl
from datetime import datetime
import json
import collections

# Colab-only setup: mount Google Drive at /content/drive.
# This requires the google.colab runtime and fails outside Colab.
from google.colab import drive
drive.mount('/content/drive')

# !pip install torchviz
# from torchviz import make_dot

# !pip install memory_profiler
# %load_ext memory_profiler
# !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.13-cp311-cp311-linux_x86_64.whl


### Helper functions

In [None]:

def get_device(prefer_device: Optional[str] = None) -> torch.device:
    """
    Detect and return an available compute device.

    Backends are probed in the order TPU > CUDA GPU > Apple MPS > CPU. If
    ``prefer_device`` is given, that backend is probed first. ``'gpu'`` is
    accepted as an alias for ``'cuda'``, and matching is case-insensitive.

    Args:
        prefer_device: Optional preferred device ('tpu', 'gpu'/'cuda', 'mps', 'cpu').

    Returns:
        torch.device: The first available device in probe order. CPU is the fallback.
    """
    aliases = {'gpu': 'cuda'}

    device_order = []
    if prefer_device:
        preferred = prefer_device.lower()
        device_order.append(aliases.get(preferred, preferred))
    device_order += ['tpu', 'cuda', 'mps', 'cpu']

    # dict.fromkeys removes duplicates (e.g. a preferred 'cpu') while keeping probe order.
    for device_type in dict.fromkeys(device_order):
        try:
            if device_type == 'tpu':
                import torch_xla
                import torch_xla.core.xla_model as xm

                device = xm.xla_device()
                print(f"Using TPU: {device}")
                return device

            if device_type == 'cuda' and torch.cuda.is_available():
                device = torch.device("cuda")
                print(f"Using GPU: {torch.cuda.get_device_name(device)}")
                return device

            # torch.backends.mps is absent on older torch builds.
            mps_backend = getattr(torch.backends, 'mps', None)
            if device_type == 'mps' and mps_backend is not None and mps_backend.is_available():
                device = torch.device("mps")
                print("Using Apple MPS")
                return device

            if device_type == 'cpu':
                device = torch.device("cpu")
                print("Using CPU")
                return device

        except (ImportError, RuntimeError):
            # torch_xla is not installed, or it is installed but no TPU is attached.
            continue

    return torch.device("cpu")

def set_seed(seed: int, deterministic: bool = False) -> None:
    """
    Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU, plus all CUDA
    devices when available).

    Args:
        seed: Random seed value.
        deterministic: If True, force deterministic cuDNN kernels and disable
            cuDNN benchmarking. This may reduce performance.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if deterministic:
        # Setting these flags is harmless on CPU-only machines.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print(f"Seeds initialized to {seed} with {'deterministic' if deterministic else 'normal'} mode")

def save_training_checkpoint(
    session_dir: str,
    batch_idx: int,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    stats: Dict[str, Any],
):
    """
    Save a training checkpoint, the statistics as JSON, and plots of the statistics.

    Files written inside ``session_dir``:
      * ``checkpoints/checkpoint_{batch_idx}.pt``: model and optimizer state dicts,
        ``stats``, ``batch_idx`` and the model class name.
      * ``training_stats.json``: ``stats``, overwritten on every call.
      * ``plots/training_plots_{batch_idx}.png``: smoothed metric curves.

    Both sub-directories are created if they do not exist yet.

    Args:
        session_dir: Root directory of the training session.
        batch_idx: Current batch index, used in the file names.
        model: Model to save.
        optimizer: Optimizer to save.
        stats: Dictionary of training statistics. It must be JSON-serializable
            because it is written with ``json.dump``.
    """
    checkpoint_dir = os.path.join(session_dir, "checkpoints")
    plots_dir = os.path.join(session_dir, "plots")
    # torch.save does not create parent directories, so make sure both exist.
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(plots_dir, exist_ok=True)

    # 1. Save model checkpoint
    checkpoint_data = {
        'batch_idx': batch_idx,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'stats': stats,
        'model_type': type(model).__name__,
    }
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{batch_idx}.pt")
    torch.save(checkpoint_data, checkpoint_path)

    # 2. Save statistics to JSON
    stats_path = os.path.join(session_dir, "training_stats.json")
    with open(stats_path, 'w', encoding='utf-8') as f:
        json.dump(stats, f, indent=4)

    # 3. Generate and save plots as (stats key, plot label) pairs
    metrics_to_plot = [
        ('loss', 'Training Loss'),
        ('predicted_earnings', 'Predicted Earnings (Probabilistic)'),
        ('predicted_earnings_greedy', 'Predicted Earnings (Greedy)'),
        ('predicted_earnings_value', 'Predicted Earnings (Value)'),
        ('roi_kelly', 'ROI Kelly Betting Method'),
        ('accuracy', 'Accuracy'),
        ('hrn', 'HRN'),
        ('spearman', 'Spearman Correlation'),
        ('batch_time', 'Batch Time'),
        ('lr', 'Learning Rate'),
    ]

    plot_path = os.path.join(plots_dir, f"training_plots_{batch_idx}.png")
    plot_training_stats(training_stats=stats,
                        window_size=50,
                        config=None,
                        plots_dir=plot_path,
                        metrics=metrics_to_plot)
    print('Successfully saved model')


def load_training_checkpoint(
    checkpoint_path: str,
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer],
) -> Tuple[int, Dict[str, Any]]:
    """
    Load a training checkpoint from a specific file path.

    The checkpoint is mapped onto the device of the model's first parameter
    (CPU if the model has no parameters). On torch versions that provide
    ``torch.serialization.safe_globals``, loading runs inside that context so
    ``collections.defaultdict`` and ``list`` are allowed. Older versions load
    without it.

    Args:
        checkpoint_path: Full path to the checkpoint file.
        model: Model instance to load weights into.
        optimizer: Optimizer instance to load state into. Pass None to skip
            restoring the optimizer.

    Returns:
        Tuple ``(batch_idx, stats)`` as stored in the checkpoint.

    Raises:
        FileNotFoundError: If the specified checkpoint doesn't exist.
        TypeError: If the saved model type doesn't match the current model type.
    """
    import contextlib

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    # Determine device to load on
    try:
        device = next(model.parameters()).device
    except StopIteration:  # Model has no parameters
        device = torch.device("cpu")

    # safe_globals only exists on newer torch releases.
    safe_globals = getattr(torch.serialization, "safe_globals", None)
    load_context = (safe_globals([collections.defaultdict, list])
                    if safe_globals is not None else contextlib.nullcontext())
    with load_context:
        checkpoint_data = torch.load(checkpoint_path, map_location=device)

    # Verify model compatibility
    saved_model_type = checkpoint_data.get("model_type")
    current_model_type = type(model).__name__
    if saved_model_type != current_model_type:
        raise TypeError(
            f"Model type mismatch: Saved model '{saved_model_type}', "
            f"Current model '{current_model_type}'"
        )

    # Load states
    model.load_state_dict(checkpoint_data["model_state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint_data["optimizer_state_dict"])

    return checkpoint_data["batch_idx"], checkpoint_data["stats"]

In [None]:
from dataclasses import dataclass, field

_DEFAULT_PLOT_COLORS = (
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
    '#9467bd', '#8c564b', '#e377c2', '#7f7f7f',
)


@dataclass
class PlotConfig:
    """Visual settings shared by the training-statistics plots.

    ``colors`` is built by a factory, so every instance owns a fresh list and
    editing one config's palette never affects another instance.
    """
    figsize: tuple = (10, 6)
    linewidth: float = 2.0
    fontsize: int = 12
    dpi: int = 100
    colors: List[str] = field(default_factory=lambda: list(_DEFAULT_PLOT_COLORS))

def smooth_data(data: np.ndarray, window_size: int = 20) -> np.ndarray:
    """
    Smooth data with a NaN-aware simple moving average ('valid' mode).

    Each output value is the mean of the non-NaN samples in its window, so NaNs
    do not drag the average towards zero. Windows that contain only NaNs are
    masked in the returned masked array.

    Args:
        data: 1-D sequence of numbers (list, tuple or array). NaNs are allowed.
        window_size: Size of the moving average window. Must be >= 1.

    Returns:
        ``np.ma.MaskedArray`` of length ``len(data) - window_size + 1``.

    Raises:
        ValueError: If ``window_size`` is < 1 or larger than the data length.
    """
    # float64 keeps precision for large values and needs no ad-hoc rescaling.
    values = np.asarray(data, dtype=np.float64)

    if window_size < 1:
        raise ValueError("Window size must be at least 1")
    if window_size > len(values):
        raise ValueError("Window size cannot be larger than data length")

    is_nan = np.isnan(values)
    kernel = np.ones(window_size)

    # Per-window sum of the valid samples and per-window count of valid samples.
    window_sums = np.convolve(np.where(is_nan, 0.0, values), kernel, mode='valid')
    window_counts = np.convolve((~is_nan).astype(np.float64), kernel, mode='valid')

    empty = window_counts == 0
    means = np.divide(window_sums, window_counts,
                      out=np.zeros_like(window_sums), where=~empty)
    return np.ma.masked_array(means, mask=empty)

def plot_metric(ax: plt.Axes, x: np.ndarray, y: np.ndarray,
                label: str, color: str, config: PlotConfig) -> None:
    """
    Draw one metric curve on ``ax`` and decorate the axis.

    Args:
        ax: Target matplotlib axes.
        x: Values for the horizontal axis.
        y: Metric values.
        label: Name used for the legend, y-label and title.
        color: Line colour.
        config: Supplies the line width and base font size.
    """
    base_size = config.fontsize
    ax.plot(x, y, label=label, color=color, linewidth=config.linewidth)

    # Axis labels use the base font size; the title is slightly larger.
    ax.set_xlabel('Training Step', fontsize=base_size)
    ax.set_ylabel(label, fontsize=base_size)
    title = f'{label} Over Training Steps'
    ax.set_title(title, fontsize=base_size + 2)

    # The legend is slightly smaller, and a faint grid aids reading values.
    ax.legend(fontsize=base_size - 2)
    ax.grid(True, alpha=0.3)

def plot_training_stats(training_stats: Dict[str, List[float]],
                       window_size: int = 50,
                       config: Optional[PlotConfig] = None,
                       plots_dir: Optional[str] = None,
                       metrics: Optional[List[Tuple[str, str]]] = None) -> None:
    """
    Plot smoothed training statistics in a grid of subplots, three per row.

    Args:
        training_stats: Mapping from metric name to its per-step values.
        window_size: Moving-average window. It is shrunk to the series length
            for short series, so early checkpoints do not fail.
        config: Plot styling. Defaults to ``PlotConfig()``.
        plots_dir: Output *file* path for the figure. If falsy, the figure is
            shown interactively instead of being saved.
        metrics: Sequence of ``(metric_key, display_label)`` pairs to plot.
            Defaults to every key of ``training_stats``, labelled by its name.
            Keys that are missing from ``training_stats`` or have no values
            are skipped.
    """
    if config is None:
        config = PlotConfig()
    if metrics is None:
        metrics = [(key, key) for key in training_stats]

    # Smooth only the metrics that will actually be plotted. smooth_data
    # rejects a window longer than the series, so clamp it (to at least 1).
    smoothed_stats = {}
    for metric, _ in metrics:
        values = training_stats.get(metric)
        if values is None or len(values) == 0:
            continue
        effective_window = max(1, min(window_size, len(values)))
        smoothed_stats[metric] = smooth_data(values, effective_window)

    n_rows = max(1, math.ceil(len(metrics) / 3))
    fig, axes = plt.subplots(n_rows, 3, figsize=(16, 22), dpi=config.dpi)
    axes = np.asarray(axes).flatten()

    slot = 0
    for color_idx, (metric, label) in enumerate(metrics):
        if metric not in smoothed_stats:
            continue
        series = smoothed_stats[metric]
        plot_metric(axes[slot], np.arange(len(series)), series, label,
                    config.colors[color_idx % len(config.colors)], config)
        slot += 1

    # Hide grid cells that received no metric.
    for ax in axes[slot:]:
        ax.set_visible(False)

    fig.tight_layout()
    if plots_dir:
        fig.savefig(plots_dir)
        plt.close(fig)
    else:
        plt.show()


### Dataset

In [None]:
class HorseRacingDataset(Dataset):
    """Dataset class for horse racing prediction tasks.

    Loads three semicolon-separated CSV files from ``config.data_folder``
    (races, results and race-horse inputs) and keeps only rows with
    ``crid >= config.start_crid``. Feature columns are z-score standardized,
    and each index serves one race.

    Per-race row lookups use positional indices from ``groupby(...).indices``.
    Every lookup into the frames must therefore be positional (``.iloc``),
    because the ``start_crid`` filter leaves row labels that differ from row
    positions.
    """

    def __init__(self, data_config):
        """
        Initialize the dataset from its configuration.

        Args:
            data_config: Data configuration. This class reads the attributes
                ``data_folder``, ``cols_races_to_drop``, ``cols_results_to_drop``,
                ``cols_horses_to_drop``, ``input_solution`` and ``start_crid``.
        """
        self.config = data_config
        self.embedding_dict: Dict = {}

        # Load and preprocess data
        self._load_data()
        self._preprocess_data()
        self._validate_data()

        logger.info(f"Dataset initialized with {len(self)} races")

    def __len__(self) -> int:
        """Total number of races in dataset"""
        return len(self.df_races)

    def __getitem__(self, idx: int) -> Tuple:
        """Return the sample for ``idx``, moving to the next race on a data error.

        Up to 100 consecutive indices are tried, wrapping around the end.

        Raises:
            RuntimeError: If none of the attempts produced a valid sample.
        """
        start_idx = idx
        for _ in range(100):  # Max 100 attempts to find valid sample
            try:
                return self._get_item(idx)
            except (KeyError, IndexError, ValueError) as e:
                logger.warning(f"Error processing index {idx}: {str(e)}")
                idx = (idx + 1) % len(self)

        raise RuntimeError(f"Failed to find valid sample after 100 attempts starting from index {start_idx}")

    def _load_data(self) -> None:
        """Load the three raw CSV files and drop the configured columns."""
        try:
            self.df_races = self._load_csv("df_races_input_w_datetime.csv", sort_by = ['crid']).drop(columns = self.config.cols_races_to_drop)
            self.df_results = self._load_csv("df_results.csv", sort_by = ['crid', 'hid']).drop(columns = self.config.cols_results_to_drop)
            self.df_race_horse = self._load_csv("df_race_horse_input.csv", sort_by = ['crid', 'hid']).drop(columns = self.config.cols_horses_to_drop)
            # Smaller 10k-row variants for quick experiments:
            # self.df_races = self._load_csv("df_races_input_w_datetime_10000.csv", sort_by = ['crid']).drop(columns = self.config.cols_races_to_drop)
            # self.df_results = self._load_csv("df_results_10000.csv", sort_by = ['crid', 'hid']).drop(columns = self.config.cols_results_to_drop)
            # self.df_race_horse = self._load_csv("df_race_horse_input_10000.csv", sort_by = ['crid', 'hid']).drop(columns = self.config.cols_horses_to_drop)

        except FileNotFoundError as e:
            logger.error(f"Data loading failed: {str(e)}")
            raise

    def _load_csv(self, filename: str, sort_by: list) -> pd.DataFrame:
        """Read a ';'-separated CSV, sort it by ``sort_by`` and reset to a 0..n-1 index."""
        filepath = os.path.join(self.config.data_folder, filename)
        df = pd.read_csv(filepath, sep=";",
                        dtype={'crid': 'int32', 'hid': 'int32'})
        return df.sort_values(sort_by).reset_index(drop=True)

    def _preprocess_data(self) -> None:
        """Filter by ``start_crid``, standardize features and build per-race row lookups."""
        if self.config.input_solution:
            # Merge result columns (including positions) into the horse inputs,
            # to test that the model works when the solution is given as input.
            # The results' hrn is shifted by -1 for the join and restored afterwards
            # (presumably an off-by-one between the two files; verify).
            cols_to_merge = ['hrn', 'hid', 'res_win'] + [f"position_{i}" for i in range(1, 41)]
            self.df_results['hrn'] = self.df_results['hrn'] - 1
            self.df_race_horse = self.df_race_horse.merge(
                self.df_results[cols_to_merge],
                on=['hrn', 'hid'],
                how='left',
                validate='one_to_one'
            ).fillna(0)
            self.df_results['hrn'] = self.df_results['hrn'] + 1

        # Drop the crid that are lower than start_crid
        self.df_races = self.df_races[self.df_races['crid'] >= self.config.start_crid]
        self.df_race_horse = self.df_race_horse[self.df_race_horse['crid'] >= self.config.start_crid]
        self.df_results = self.df_results[self.df_results['crid'] >= self.config.start_crid]

        # Keep an unstandardized copy of the results for the targets.
        self.df_target = self.df_results.copy()

        # Prepare feature columns
        self._setup_feature_columns()
        self._standardize_features()

        # Create lookup indices (positional row numbers per crid)
        self.crid_to_race_horse = self._create_crid_groups(self.df_race_horse)
        self.crid_to_results = self._create_crid_groups(self.df_results)

    def _setup_feature_columns(self) -> None:
        """Identify feature columns (everything except id/date columns) for each dataframe."""
        id_columns = ["rid", "hid", "crid", 'date']

        self.features = {
            'races': [c for c in self.df_races if c not in id_columns],
            'race_horse': [c for c in self.df_race_horse if c not in id_columns],
            'results': [c for c in self.df_results if c not in id_columns]
        }

        logger.info(f"Feature counts - Races: {len(self.features['races'])}, "
                   f"Race Horses: {len(self.features['race_horse'])}, "
                   f"Results: {len(self.features['results'])}")

    def _standardize_features(self) -> None:
        """Apply z-score standardization to feature columns"""
        self.df_races = self._zscore_standardize(self.df_races, self.features['races'])
        self.df_race_horse = self._zscore_standardize(self.df_race_horse, self.features['race_horse'])
        self.df_results = self._zscore_standardize(self.df_results, self.features['results'])

    @staticmethod
    def _zscore_standardize(df: pd.DataFrame, columns: List[str]) -> pd.DataFrame:
        """Return a copy with ``columns`` z-scored (population std) and all NaNs filled with 0."""
        df = df.copy()
        for col in columns:
            try:
                mean = df[col].mean()
                std = df[col].std(ddof=0)
                df[col] = (df[col] - mean) / (std + 1e-8)
            except TypeError:
                logger.error(f"Non-numeric data in column {col}")
                raise
        return df.fillna(0)

    def _create_crid_groups(self, df: pd.DataFrame) -> Dict:
        """Map each crid to the *positional* row indices of its rows in ``df``."""
        return df.groupby('crid', sort=False).indices

    def _validate_data(self) -> None:
        """Validate dataset consistency"""
        if len(self.df_races) == 0:
            raise ValueError("No races loaded in dataset")

        if not all(c in self.df_race_horse for c in ['hrn', 'hid']):
            raise ValueError("Missing required columns in race horse data")

    def _get_item(self, idx: int) -> Tuple:
        """Build the raw sample tuple for race position ``idx`` (no error recovery)."""
        race_row = self.df_races.iloc[idx]
        crid = race_row['crid']

        # Get race features
        race_features = race_row[self.features['races']].values.astype(np.float32)
        # Get horse data
        horse_data, results_data = self._get_horse_data(crid)

        # crid_to_results holds positional indices, and after the start_crid
        # filter the labels of df_target no longer match positions, so .iloc is required.
        target_data = self.df_target.iloc[self.crid_to_results.get(crid, [])]

        targets = {
            'position': target_data['position'].values.astype(np.float32),
            'decimalPrice': target_data['decimalPrice'].values.astype(np.float32),
            'hids': target_data['hid'].values.astype(np.int32),
            'hrn': target_data['hrn'].values.astype(np.int32),
            'crid': target_data['crid'].values.astype(np.int32),
            'date': race_row['date'].astype(np.int32)
        }

        return (race_features, horse_data, results_data) + tuple(targets.values())

    def _get_horse_data(self, crid: int) -> Tuple:
        """Return (horse features, result features) for ``crid``, or two empty arrays on a length mismatch."""
        horse_indices = self.crid_to_race_horse.get(crid, np.array([], dtype=int))
        results_indices = self.crid_to_results.get(crid, np.array([], dtype=int))

        if len(horse_indices) != len(results_indices):
            logger.warning(f"Mismatched data lengths for CRID {crid}: "
                          f"{len(horse_indices)} horses vs {len(results_indices)} results")
            return np.empty((0, len(self.features['race_horse']))), np.empty((0, len(self.features['results'])))

        return (
            self.df_race_horse.iloc[horse_indices][self.features['race_horse']].values.astype(np.float32),
            self.df_results.iloc[results_indices][self.features['results']].values.astype(np.float32)
        )

def collate_fn(batch: List, dataset: "HorseRacingDataset") -> Dict:
    """
    Collate variable-size race samples into padded float32 tensors.

    Each sample is the 9-tuple produced by ``HorseRacingDataset.__getitem__``:
    ``(race, horses, results, positions, prices, hids, hrn, crid, date)``.

    Per-horse arrays are padded along their first axis to the largest horse
    count in the batch. Feature arrays are padded with 0 and target arrays with
    -1. Target arrays are first truncated to the number of horse-feature rows,
    so they always line up with ``horse_data``. A sample whose horse features
    are empty therefore has all of its targets set to -1.

    NOTE(review): ids (hids, crid) and date are stored as float32, which is only
    exact for integers up to 2**24. Verify the value ranges.

    Args:
        batch: List of samples.
        dataset: Unused. Kept for interface compatibility.

    Returns:
        Dict of float32 tensors: race_data / horse_data / results_data of shape
        (B, max_horses, F), positions / prices / hids / hrn / crid of shape
        (B, max_horses), and date of shape (B,).

    Raises:
        ValueError: If ``batch`` is empty.
    """

    def pad_array(arr: np.ndarray, target_length: int, pad_value: float = 0) -> np.ndarray:
        """Pad ``arr`` along axis 0 up to ``target_length`` with ``pad_value``."""
        pad_width = (0, target_length - len(arr))
        return np.pad(arr, (pad_width, (0, 0)) if arr.ndim == 2 else pad_width,
                      constant_values=pad_value)

    if not batch:
        raise ValueError("Cannot collate an empty batch")

    max_horses = max(len(item[1]) for item in batch)

    target_keys = ('positions', 'prices', 'hids', 'hrn', 'crid')
    batch_dict = {k: [] for k in ('race_data', 'horse_data', 'results_data') + target_keys + ('date',)}

    for race, horses, results, pos, price, hids, hrn, crid, date in batch:
        num_horses = len(horses)

        # Features: the race vector is repeated once per horse, then everything is zero-padded.
        batch_dict['race_data'].append(pad_array(np.tile(race, (num_horses, 1)), max_horses))
        batch_dict['horse_data'].append(pad_array(horses, max_horses))
        batch_dict['results_data'].append(pad_array(results, max_horses))

        # Targets: keep only rows that have horse features, then pad with -1.
        for key, values in zip(target_keys, (pos, price, hids, hrn, crid)):
            batch_dict[key].append(pad_array(np.asarray(values)[:num_horses], max_horses, -1))
        batch_dict['date'].append(date)

    # Convert to tensors
    tensor_batch = {
        k: torch.tensor(np.stack(v), dtype=torch.float32)
        for k, v in batch_dict.items()
    }
    tensors_with_nan = [name for name, tensor in tensor_batch.items() if torch.isnan(tensor).any()]
    if tensors_with_nan:
        print(f"Tensors with NaNs in {tensors_with_nan}")
    return tensor_batch


### Loss functions

In [None]:
def spearman_rank_correlation(logits: torch.Tensor,
                             positions_arrival: torch.Tensor) -> torch.Tensor:
    """
    Mean Spearman rank correlation between predicted scores and finishing positions.

    For each race, entries with position -1 or 40 are ignored, and races with
    fewer than 2 remaining entries are skipped. Higher scores map to better
    (smaller) predicted ranks, so a value of +1 means the predicted order
    matches the finishing order exactly. Races whose correlation is undefined
    (NaN) or cannot be computed count as 0.

    Args:
        logits: Tensor of shape [batch_size, num_horses] with prediction scores.
        positions_arrival: Tensor of shape [batch_size, num_horses] with actual positions.

    Returns:
        Mean Spearman's rho over the races that were not skipped, as a tensor on
        ``logits.device``. Returns 0.0 if every race was skipped.
    """
    device = logits.device
    correlations = []

    for race_scores, race_finish in zip(logits, positions_arrival):
        # Filter valid entries for this race
        mask = (race_finish != -1.0) & (race_finish != 40.0)
        race_logits = race_scores[mask].detach().cpu().numpy()
        race_positions = race_finish[mask].cpu().numpy()

        # Skip races with <2 valid participants
        if len(race_logits) < 2:
            continue

        try:
            # Double argsort turns scores into 0-based ranks (highest score -> rank 0).
            pred_ranks = (-race_logits).argsort().argsort()
            rho, _ = spearmanr(pred_ranks, race_positions)
        except Exception:
            # Best-effort metric: never let it interrupt training. Unlike the bare
            # `except:`, this lets KeyboardInterrupt/SystemExit propagate.
            rho = 0.0

        if np.isnan(rho):
            rho = 0.0
        correlations.append(rho)

    # Return average across valid races
    if not correlations:
        return torch.tensor(0.0, device=device)
    return torch.tensor(np.mean(correlations), device=device)

def loss_function_classificationV2(
    logits: torch.Tensor,
    batch: dict
    ) -> tuple:
    """
    Cumulative-position binary cross-entropy loss plus betting and ranking metrics.

    For every horse and every rank column k (0-based), the BCE target is 1 when
    the horse's position is k + 1 or better. Column k therefore predicts
    "finishes in the top k + 1". Horses with position -1 or 40 are excluded:
    their logits are masked and their loss weight is 0.

    Args:
        logits: Model outputs of shape (batch_size, num_tokens, n_rankings).
        batch: Collated batch dict. Reads ``batch['prices']`` and
            ``batch['positions']``, both of shape (batch_size, num_tokens).

    Returns:
        Tuple ``(loss, accuracy, predicted_earnings / batch_size,
        greedy_earnings / batch_size, total_betted / batch_size,
        winnings / batch_size, spearman)``.

    NOTE(review): ``accuracy`` is NaN when no race in the batch has a valid
    winner (mean over an empty selection).
    """
    decimal_prices = batch['prices']
    positions_arrival = batch['positions']
    device = logits.device
    batch_size, num_tokens, n_rankings = logits.shape

    # Mask invalid positions (-1 and 40). Their logits become -1e6, so sigmoid gives ~0.
    valid_mask = (positions_arrival != -1.0) & (positions_arrival != 40.0)
    logits = logits.masked_fill(~valid_mask.unsqueeze(-1), -1e6)
    output = torch.sigmoid(logits)

    # 0-based finishing position, clamped into [0, n_rankings - 1]
    positions = positions_arrival.clamp(min=1, max=n_rankings).long() - 1

    positions_expanded = positions.unsqueeze(-1)  # (batch_size, num_tokens, 1)
    rankings_range = torch.arange(n_rankings, device=device).view(1, 1, -1)  # (1, 1, n_rankings)
    # Target[b, t, k] = 1 if the horse finished at position k + 1 or better
    indices = (rankings_range >= positions_expanded).float()  # (batch_size, num_tokens, n_rankings)
    # Invalid horses get all-ones targets. Their weight below is 0, so this only keeps the targets well-defined.
    indices[valid_mask == False] = 1.0  # (batch_size, num_tokens, n_rankings)

    # Weights descend from n_rankings to 1 across the rank columns and are zeroed for invalid
    # horses. They are then rescaled so each race's weights sum to ~n_rankings**2 / 2,
    # giving every race similar total weight regardless of how many horses ran.
    weights_by_position = torch.arange(n_rankings, 0, -1, device=device).float()
    weights_position = valid_mask.unsqueeze(-1) * weights_by_position.view(1, 1, -1)
    weights_position = (n_rankings**2 / 2) * weights_position / (weights_position.sum(dim=(1,2), keepdim=True) + 1e-5)

    # Calculate loss (weighted BCE, mean-reduced over all elements)
    loss = torch.nn.functional.binary_cross_entropy(
        output, indices, weight=weights_position
    )
    with torch.no_grad():

        # Accuracy: is the valid horse with the highest column-0 output a winner?
        valid_positions = valid_mask & (positions_arrival == 1.0)
        race_output_max = output[..., 0].masked_fill(~valid_mask, -float('inf'))
        predicted_winners = race_output_max.argmax(dim=1)

        # Create mask of correct targets (same mask as valid_positions)
        correct_target_mask = (positions_arrival == 1.0) & valid_mask

        # Check if predictions match any correct target
        batch_indices = torch.arange(batch_size, device=device)
        correct_predictions = correct_target_mask[batch_indices, predicted_winners]

        # Only races that contain a valid winner count towards accuracy
        valid_batches = valid_positions.any(dim=1)
        accuracy = correct_predictions[valid_batches].float().mean()

        # Winnings: sum over winning horses of column-0 output / decimal price
        winnings = (output[valid_positions][:,0] / decimal_prices[valid_positions]).sum()

        # Total betted amount: column-0 output summed over all valid horses
        total_betted = output[valid_mask][:, 0].sum()

        # Predicted earnings
        predicted_earnings = winnings - total_betted

        # Greedy: one unit on the argmax horse of each race; a hit returns 1 / decimal price
        greedy_winnings = ((positions_arrival.gather(1, predicted_winners.unsqueeze(1))==1) / decimal_prices.gather(1, predicted_winners.unsqueeze(1))).sum()
        greedy_earnings = greedy_winnings - batch_size  # Subtract total bets

        # NOTE(review): spearmanr_kpi is not defined in this part of the file. The visible
        # helper is spearman_rank_correlation, which expects 2-D scores, but `output` is 3-D
        # here. Confirm the intended function. If it returns a tensor, the torch.tensor(...)
        # wrap below will emit a copy-construct warning.
        spearman_results = spearmanr_kpi(output.detach().clone(), positions_arrival)

    return (
        loss,
        accuracy.detach(),
        predicted_earnings / batch_size,
        greedy_earnings / batch_size,
        total_betted / batch_size,
        winnings / batch_size,
        torch.tensor(spearman_results, device=device)
    )

def loss_function_first_horse_classification(
    logits: torch.Tensor,
    batch: dict
) -> torch.Tensor:
    """
    Softmax cross-entropy on which horse wins each race.

    Horses with position -1 (padding) are excluded from the softmax. All other
    horses, including position 40, stay in the softmax as competitors, and only
    the horse with position 1 is the target. Races that do not have exactly one
    horse with position 1 are ignored.

    Each race's loss is divided by log(number of participants), which is the
    loss of a uniform guess, so races of different sizes contribute on a
    comparable scale.

    Args:
        logits: Winner scores of shape (batch_size, num_horses, 1). A 2-D
            (batch_size, num_horses) tensor is also accepted.
        batch: Collated batch dict. Only ``batch['positions']``
            (batch_size, num_horses) is read.

    Returns:
        Scalar loss tensor. It is 0.0 when no race in the batch is usable.
    """
    positions_arrival = batch['positions']
    if logits.dim() == 3:
        logits = logits.squeeze(-1)

    # Padding slots never receive probability mass.
    padding_mask = positions_arrival == -1.0
    logits = logits.masked_fill(padding_mask, -1e9)

    winner_mask = positions_arrival == 1.0
    valid_races = winner_mask.sum(dim=1) == 1  # races with exactly one winner

    if not valid_races.any():
        return torch.tensor(0.0, device=logits.device)

    winner_indices = winner_mask[valid_races].float().argmax(dim=1)
    per_race_loss = F.cross_entropy(logits[valid_races], winner_indices, reduction='none')

    n_participants = (~padding_mask[valid_races]).sum(dim=1).float()
    # log(1) == 0 for a single-runner race would give 0/0 = NaN. Such a race's
    # loss is already ~0, so clamping the normaliser to log(2) keeps it at 0.
    normaliser = torch.log(n_participants.clamp(min=2.0))
    return (per_race_loss / normaliser).mean()


def loss_function_plackett_luce(
    logits: torch.Tensor,
    decimal_prices: torch.Tensor,
    positions_arrival: torch.Tensor,
    crids: torch.Tensor,
    penalty_weight: float = 1,
    avg_bet_per_race: float = 0.5,
    printit: bool = False
) -> tuple:
    """
    Plackett-Luce ranking loss plus betting and ranking metrics.

    The loss is the negative Plackett-Luce log-likelihood of the observed
    finishing order. It is divided by the number of ranked participants per
    race and then averaged over races. Horses with position -1 or 40 are left
    out of the ranking.

    For the metrics, padding slots (position -1) are excluded from the predicted
    softmax, so they can never be chosen as the predicted winner.

    Args:
        logits: Scores of shape (batch_size, num_horses, 1).
        decimal_prices: Prices aligned with the horses (batch_size, num_horses).
        positions_arrival: Finishing positions (batch_size, num_horses).
        crids: Race identifiers (currently unused).
        penalty_weight, avg_bet_per_race, printit: Currently unused. Kept for
            interface compatibility.

    Returns:
        Tuple ``(loss, accuracy, predicted_earnings_prob / batch_size,
        predicted_earnings_greedy / batch_size, 1.0, winnings_prob / batch_size,
        spearman)``.
    """
    device = logits.device
    batch_size, num_horses, n_ranks = logits.shape
    logits = logits.squeeze(-1)
    # Mask and prepare valid rankings
    valid_mask = (positions_arrival != -1.0) & (positions_arrival != 40.0)
    positions = positions_arrival.clamp(min=1, max=40).long()

    # ====================== Plackett-Luce Loss Core ========================
    # Sort logits by actual positions; invalid horses get 41 and sort last.
    adjusted_positions = torch.where(valid_mask, positions, torch.full_like(positions, 41))
    sorted_indices = adjusted_positions.argsort(dim=1)

    # Prepare sorted tensors with valid masking
    sorted_logits = logits.gather(1, sorted_indices)
    sorted_valid = valid_mask.gather(1, sorted_indices)
    sorted_logits_masked = sorted_logits.masked_fill(~sorted_valid, -float('inf'))

    # Reverse cumulative logsumexp: log-sum of exp(scores) over horses still in contention
    reversed_logits = torch.flip(sorted_logits_masked, dims=[1])
    reverse_cumsum = torch.logcumsumexp(reversed_logits, dim=1)
    cum_logsumexp = torch.flip(reverse_cumsum, dims=[1])

    # Per-position log probabilities. Invalid slots produce NaN (-inf - -inf); nansum drops them.
    log_probs = sorted_logits_masked - cum_logsumexp
    valid_log_probs = log_probs * sorted_valid.float()

    # Normalize by number of participants per race
    participants_per_race = valid_mask.sum(dim=1, dtype=torch.float)  # (batch_size,)
    race_log_likelihood = torch.nansum(valid_log_probs, dim=1) / participants_per_race

    # Filter valid races (handle potential 0/0 from empty races)
    valid_races = (valid_mask.any(dim=1)) & (participants_per_race > 0)
    loss = -race_log_likelihood[valid_races].mean() if valid_races.any() else torch.tensor(0.0, device=device)
    # ========================================================================

    with torch.no_grad():
        # Padding slots are not real horses, so keep them out of the predicted distribution.
        pred_logits = logits.masked_fill(positions_arrival == -1.0, torch.finfo(logits.dtype).min)
        probs = torch.softmax(pred_logits, dim=1)
        winner_mask = (positions_arrival == 1.0) & valid_mask
        race_idx = torch.arange(batch_size, device=device)
        pred_winners = probs.argmax(dim=1)

        valid_winner_races = winner_mask.any(dim=1)
        if valid_winner_races.any():
            # Accuracy over races that have a valid winner
            correct = winner_mask[race_idx, pred_winners]
            accuracy = correct[valid_winner_races].float().mean()

            # Probabilistic betting: stake = predicted probability on each horse
            selected_probs = probs[winner_mask]
            selected_decimalprices = decimal_prices[winner_mask]
            winnings_prob = (selected_probs / selected_decimalprices).sum()
            total_betted_prob = batch_size

            # Greedy betting: one unit on the most probable horse of each race
            greedy_hit = positions_arrival[race_idx, pred_winners] == 1.0
            greedy_price = decimal_prices[race_idx, pred_winners]
            winnings_greedy = (greedy_hit.float() / greedy_price).sum()
            total_betted_greedy = batch_size
        else:
            accuracy = torch.tensor(0.0, device=device)
            winnings_prob = torch.tensor(0.0, device=device)
            total_betted_prob = torch.tensor(0.0, device=device)
            winnings_greedy = torch.tensor(0.0, device=device)
            total_betted_greedy = torch.tensor(0.0, device=device)

        predicted_earnings_prob = winnings_prob - total_betted_prob
        predicted_earnings_greedy = winnings_greedy - total_betted_greedy

        # Ranking correlation (the helper masks -1/40 itself)
        spearman = spearman_rank_correlation(logits, positions_arrival)

    return (
        loss,
        accuracy.detach(),
        predicted_earnings_prob / batch_size,
        predicted_earnings_greedy / batch_size,
        torch.tensor(1.0, device = device),
        winnings_prob / batch_size,
        spearman.to(device)
    )

### Embedding Manager

In [None]:
import torch
import torch.nn as nn
from typing import Dict, Tuple, Optional, List
from dataclasses import dataclass
from collections import defaultdict, deque

@dataclass
class EmbeddingState:
    """Rolling history of one horse's embeddings and their node counts.

    Both deques are created with the same ``maxlen`` by ``CustomEmbeddingManager``
    and are always appended/popped together, so index ``k`` of ``node_counts``
    describes index ``k`` of ``embeddings``.
    """
    embeddings: deque
    node_counts: deque

class CustomEmbeddingManager(nn.Module):
    """Keep a bounded per-horse history of embeddings for sequence modelling.

    Embeddings are stored as given (not detached), so they keep their autograd
    history while they stay in the buffer. Each stored embedding is paired with
    a node count that ``get_detach_flags`` uses for a node budget.

    The histories live in a plain dict attribute, not in parameters or buffers,
    so they are not part of ``state_dict()`` and are not moved by ``.to()``.
    """

    def __init__(self,
                 embedding_dim: int = 128,
                 max_nodes: int = 100,
                 max_depth_nodes: int = 10,
                 max_sequence_length: int = 5):
        super().__init__()

        # horse id -> rolling history (both deques bounded by max_sequence_length)
        self.embedding_states: Dict[int, EmbeddingState] = defaultdict(
            lambda: EmbeddingState(embeddings=deque(maxlen=max_sequence_length),
                                   node_counts=deque(maxlen=max_sequence_length)))

        # Configuration
        self.embedding_dim = embedding_dim
        self.max_nodes = max_nodes
        self.max_depth_nodes = max_depth_nodes
        self.max_sequence_length = max_sequence_length

    def update_embeddings(self,
                          horse_ids: torch.Tensor,
                          new_embeddings: torch.Tensor,
                          additional_nodes: int) -> None:
        """Append one new embedding to the history of every valid horse.

        Args:
            horse_ids: (batch_size, num_tokens) tensor of horse IDs; -1 marks padding.
            new_embeddings: (batch_size, num_tokens, embedding_dim) tensor aligned
                with ``horse_ids``. Stored as-is, i.e. with its autograd graph.
            additional_nodes: scalar node count (annotated ``int``); every entry
                stored by this call records ``additional_nodes + 1``.
        """
        self.device = new_embeddings.device

        mask = horse_ids != -1
        valid_ids = horse_ids[mask].long().tolist()
        valid_embeddings = new_embeddings[mask]

        node_count = additional_nodes + 1
        # Plain Python loop over horses (one deque append per horse).
        for hid, emb in zip(valid_ids, valid_embeddings):
            state = self.embedding_states[hid]
            if len(state.embeddings) == self.max_sequence_length:
                # Evict the oldest entry; dropping the reference lets its graph be freed.
                state.embeddings.popleft()
                state.node_counts.popleft()
            state.embeddings.append(emb)
            state.node_counts.append(node_count)

    def get_detach_flags(self, horse_ids: torch.Tensor) -> Tuple[List[bool], int, List[int]]:
        """Decide which stored embeddings may keep their autograd history.

        Every stored entry of every valid horse in ``horse_ids`` is a candidate.
        Candidates are taken greedily in ascending node-count order while
            1. the entry's own node count is <= ``max_depth_nodes`` and
            2. the running total stays <= ``max_nodes``.
        Taken candidates are flagged ``False`` (keep graph), the rest ``True``.

        Returns:
            detach_flags: one flag per stored entry, in visiting order
                (row-major over ``horse_ids``, oldest entry first).
            total: sum of the node counts of the kept entries.
            hid_location: owning horse id of each flag, same order.
        """
        detach_flags: List[bool] = []
        node_counts: List[int] = []
        hid_location: List[int] = []

        # Boolean-mask indexing yields ids in row-major order.
        for hid in horse_ids[horse_ids != -1].tolist():
            state = self.embedding_states.get(hid)
            if state is None:
                continue
            for count in state.node_counts:
                node_counts.append(count)
                detach_flags.append(True)
                hid_location.append(hid)

        total = 0
        for index in sorted(range(len(node_counts)), key=node_counts.__getitem__):
            count = node_counts[index]
            # Counts are visited in ascending order, so once one candidate fails
            # either bound, every later candidate fails it too.
            if total + count > self.max_nodes or count > self.max_depth_nodes:
                break
            total += count
            detach_flags[index] = False

        return detach_flags, total, hid_location

    def get_embeddings(self,
                       horse_ids: torch.Tensor) -> Tuple[Optional[torch.Tensor],
                                                         Optional[torch.Tensor],
                                                         List[int],
                                                         int]:
        """Gather the stored histories of the horses in ``horse_ids``.

        Args:
            horse_ids: (batch_size, num_tokens) tensor of horse IDs; -1 marks padding.

        Returns:
            embeddings: (num_horses_with_history, max_len, embedding_dim) tensor,
                zero-padded along the history axis, or None if no horse has history.
            lengths: (num_horses_with_history,) long tensor of history lengths on
                the same device as ``embeddings``, or None.
            integrated_ids: horse IDs matching the rows of ``embeddings``.
            total_nodes: node total from ``get_detach_flags`` (0 when empty).
        """
        _detach_flags, total_nodes, hid_location = self.get_detach_flags(horse_ids)
        valid_ids = horse_ids[horse_ids != -1].tolist()

        sequences: List[torch.Tensor] = []
        integrated_ids: List[int] = []
        flag_pos = 0  # walks hid_location in the same order get_detach_flags produced it
        for hid in valid_ids:
            state = self.embedding_states.get(hid)
            if not state or len(state.embeddings) == 0:
                continue
            for _ in state.embeddings:
                # Detach flags are computed but not applied: stored embeddings keep
                # their graphs. hid_location only cross-checks the visiting order.
                if hid_location[flag_pos] != hid:
                    logger.warning(f"In get_embeddings crid sequence {hid} doesn't coincide with detach_flags sequence {hid_location[flag_pos]} ")
                flag_pos += 1
            sequences.append(torch.stack(list(state.embeddings)))
            integrated_ids.append(hid)

        if not sequences:
            return None, None, [], 0

        # Put lengths next to the data instead of relying on the last update's device.
        lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.long,
                               device=sequences[0].device)
        padded = torch.nn.utils.rnn.pad_sequence(sequences,
                                                 batch_first=True,
                                                 padding_value=0.0)

        return padded, lengths, integrated_ids, total_nodes

    def reset_embeddings(self) -> None:
        """Drop every stored embedding so the autograd graphs they reference can be freed."""
        for state in self.embedding_states.values():
            state.embeddings.clear()
            state.node_counts.clear()
        self.embedding_states.clear()

        # Return cached CUDA blocks now that the tensors are unreferenced.
        if torch.cuda.is_initialized():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

        gc.collect()



### Models

In [None]:

def token_ids_to_adjacency(token_ids: torch.Tensor, self_loop: bool = False) -> torch.Tensor:
    """
    Build, per row, a fully connected adjacency matrix over the valid tokens.

    Args:
        token_ids: (B, N) tensor; entries equal to -1 are padding and get no edges.
        self_loop: when False the diagonal is zeroed, so no node links to itself.

    Returns:
        (B, N, N) float32 tensor holding 1.0 where both endpoints are valid.
    """
    _, n_nodes = token_ids.shape
    is_valid = (token_ids != -1).type(torch.float32)

    # Outer product of the validity vector with itself: 1 iff row and column are valid.
    adjacency = is_valid[:, :, None] * is_valid[:, None, :]

    if self_loop:
        return adjacency

    eye = torch.eye(n_nodes, device=token_ids.device).unsqueeze(0)
    return adjacency * (1 - eye)

def batch_normalize_adjacency(adj: torch.Tensor) -> torch.Tensor:
    """Symmetrically normalise a batch of adjacency matrices: D^(-1/2) A D^(-1/2).

    Degrees are the row sums of ``adj`` clamped to at least 1, so rows without
    edges (e.g. padding) stay zero instead of dividing by zero.
    """
    inv_sqrt_deg = adj.sum(dim=-1).clamp(min=1).pow(-0.5)
    row_scale = inv_sqrt_deg[..., :, None]
    col_scale = inv_sqrt_deg[..., None, :]
    return row_scale * adj * col_scale

class GraphConvolution_GCNII2(nn.Module):
    """Unbatched GCNII-style layer operating on a single graph.

    Mixes neighbour aggregation ``adj @ input`` with an initial residual towards
    ``h0`` (weight ``alpha``) and an identity mapping whose strength depends on
    depth, ``theta = log(lamda / l + 1)``. Aggregation uses ``torch.spmm``, so
    ``adj`` is presumably a sparse (N, N) matrix — TODO confirm at call sites.

    Args:
        in_features: feature size of ``input`` and ``h0``.
        out_features: output feature size; must match ``in_features`` because
            ``(1 - theta) * r`` is added to the projection.
        residual: add the layer input to the output.
        variant: project the concatenation ``[adj @ input, h0]`` instead of
            their alpha-mix.
    """

    def __init__(self, in_features, out_features, residual=False, variant=False):
        # Zero-argument super(): this class does not inherit from GraphConvolution_GCNII.
        super().__init__()
        self.variant = variant
        if self.variant:
            self.in_features = 2 * in_features
        else:
            self.in_features = in_features

        self.out_features = out_features
        self.residual = residual
        self.weight = nn.Parameter(torch.FloatTensor(self.in_features, self.out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weights uniformly from [-1/sqrt(out_features), 1/sqrt(out_features)]."""
        stdv = 1. / math.sqrt(self.out_features)
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, input, adj, h0, lamda, alpha, l):
        """
        Args:
            input: (N, in_features) node features.
            adj: (N, N) adjacency accepted by ``torch.spmm``.
            h0: (N, in_features) initial representation.
            lamda, alpha: GCNII hyper-parameters.
            l: 1-based layer index.

        Returns:
            (N, out_features) tensor.
        """
        theta = math.log(lamda / l + 1)
        hi = torch.spmm(adj, input)

        if self.variant:
            support = torch.cat([hi, h0], 1)
            r = (1 - alpha) * hi + alpha * h0
        else:
            support = (1 - alpha) * hi + alpha * h0
            r = support
        output = theta * torch.mm(support, self.weight) + (1 - theta) * r
        if self.residual:
            output = output + input
        return output
class GraphConvolution_GCNII(nn.Module):
    """Batched GCNII-style graph convolution.

    Combines neighbour aggregation ``adj @ x`` with an initial residual towards
    ``h0`` (weight ``alpha``) and an identity mapping whose strength depends on
    the layer index: ``theta = log(lamda / l + 1)``.

    Args:
        in_features: feature size of ``x`` and ``h0``.
        out_features: output feature size; must match ``in_features`` because
            ``(1 - theta) * r`` is added to the projection.
        residual: add the layer input ``x`` to the output.
        variant: project the concatenation ``[adj @ x, h0]`` instead of their mix.
    """

    def __init__(self, in_features, out_features, residual=False, variant=False):
        super().__init__()
        self.variant = variant
        # The variant concatenates along the feature axis, doubling the input width.
        self.in_features = 2 * in_features if variant else in_features
        self.out_features = out_features
        self.residual = residual
        self.weight = nn.Parameter(torch.FloatTensor(self.in_features, self.out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weights uniformly from [-1/sqrt(out_features), 1/sqrt(out_features)]."""
        bound = 1. / math.sqrt(self.out_features)
        self.weight.data.uniform_(-bound, bound)

    def forward(self, x, adj, h0, lamda, alpha, l):
        """
        Args:
            x: [B, N, in_features] node features.
            adj: [B, N, N] adjacency (multiplied with ``torch.bmm``).
            h0: [B, N, in_features] initial representation.
            lamda, alpha: GCNII hyper-parameters.
            l: 1-based layer index.

        Returns:
            [B, N, out_features] tensor.
        """
        theta = math.log(lamda / l + 1)

        aggregated = torch.bmm(adj, x)
        mixed = (1 - alpha) * aggregated + alpha * h0
        support = torch.cat([aggregated, h0], dim=2) if self.variant else mixed

        out = theta * (support @ self.weight) + (1 - theta) * mixed
        return out + x if self.residual else out

class GCNII(nn.Module):
    """Stack of batched GCNII layers over a fully connected graph of valid tokens.

    The graph is rebuilt on every call from ``hids``: every token whose id is
    not -1 is linked to every other valid token, and the adjacency is
    symmetrically normalised. The stack's input is reused as ``h0`` by every
    layer for the initial-residual term.

    Args:
        nlayers: number of ``GraphConvolution_GCNII`` layers.
        emb_dim: node feature size (input and output of every layer).
        dropout: dropout probability applied before each layer and once at the end.
        lamda, alpha: GCNII hyper-parameters forwarded to every layer.
        variant, residual: per-layer options forwarded to ``GraphConvolution_GCNII``.
        self_loop: keep self-edges in the adjacency.
    """

    def __init__(self, nlayers, emb_dim, dropout, lamda, alpha, variant, residual, self_loop):
        super().__init__()
        self.convs = nn.ModuleList(
            GraphConvolution_GCNII(emb_dim, emb_dim, variant=variant, residual=residual)
            for _ in range(nlayers)
        )
        self.act_fn = nn.GELU()
        self.dropout = dropout
        self.alpha = alpha
        self.lamda = lamda
        self.residual = residual
        self.self_loop = self_loop

    def forward(self, input):
        """Run every layer.

        Args:
            input: ``(x, hids)`` with ``x`` of shape (B, N, emb_dim) and ``hids``
                of shape (B, N); -1 in ``hids`` marks padding.

        Returns:
            ``(x_out, hids)`` — the same tuple form as the input.
        """
        x, hids = input
        norm_adj = batch_normalize_adjacency(token_ids_to_adjacency(hids, self.self_loop))

        h0 = x
        for depth, layer in enumerate(self.convs, start=1):
            x = F.dropout(x, self.dropout, training=self.training)
            x = self.act_fn(layer(x, norm_adj, h0, self.lamda, self.alpha, depth))
        return F.dropout(x, self.dropout, training=self.training), hids


class FeedForward(nn.Module):
    """Two-layer MLP applied independently at every position.

    Layout of ``self.net``: Linear(emb_dim -> hidden_layer_dim), GELU,
    Dropout(dropout), Linear(hidden_layer_dim -> emb_dim).
    """

    def __init__(self, emb_dim: int, hidden_layer_dim: int, dropout: float):
        super().__init__()
        layers = [
            nn.Linear(emb_dim, hidden_layer_dim),   # expand
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_layer_dim, emb_dim),   # project back to emb_dim
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to the last dimension of ``x``."""
        for layer in self.net:
            x = layer(x)
        return x


class GraphConvolution(nn.Module):
    """Mean-aggregation graph convolution over all valid tokens of a row.

    Each node i receives ``GELU(W_self x_i + mean_j W_neigh x_j)`` where the
    mean runs over every token whose id is not -1 (node i included). Padded
    positions still receive an output but contribute nothing to the mean.

    Args:
        emb_dim: node embedding size (input and output).
        dropout: dropout probability applied to the input features.
    """

    def __init__(self, emb_dim: int, dropout: float = 0.0):
        super().__init__()
        self.linear_self = nn.Linear(emb_dim, emb_dim)
        self.linear_neigh = nn.Linear(emb_dim, emb_dim)
        self.eps = 1e-6  # keeps the division finite when a row has no valid token
        self.dropout = dropout

    def forward(self, x: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, N, D) node features.
            token_ids: (B, N) token ids; -1 marks padding.

        Returns:
            (B, N, D) updated node features.
        """
        _batch, _nodes, _dim = x.shape  # enforces a 3-D input

        x = F.dropout(x, self.dropout, training=self.training)

        valid = (token_ids != -1)[..., None]                       # (B, N, 1)
        neigh_sum = (self.linear_neigh(x) * valid).sum(dim=-2)     # (B, D)
        n_valid = valid.sum(dim=-2)                                # (B, 1)
        neigh_mean = neigh_sum / (n_valid + self.eps)              # (B, D)

        # Broadcast the row-wide mean onto every node.
        return F.gelu(self.linear_self(x) + neigh_mean.unsqueeze(1))


class GCNBlock(nn.Module):
    """Pre-norm block of parallel ``GraphConvolution`` heads with a residual connection.

    Args:
        emb_dim: node embedding size.
        n_heads: number of ``GraphConvolution`` heads. With more than one head,
            their outputs are concatenated and projected back to ``emb_dim`` by
            ``output_proj``.
        dropout: dropout probability passed to every head and the feed-forward.
        ff_layer: if True, the head output goes through a ``FeedForward``
            before the residual is added.
        multiple_ff: feed-forward hidden size as a multiple of ``emb_dim``;
            fractional values are truncated to an int.
    """

    def __init__(self, emb_dim, n_heads, dropout: float = 0.0, ff_layer: bool = False, multiple_ff: float = 4):
        super().__init__()

        self.heads = nn.ModuleList([
            GraphConvolution(emb_dim, dropout=dropout)
            for _ in range(n_heads)
        ])
        self.norm = nn.LayerNorm(emb_dim)
        # Registered unconditionally, but only used in forward when n_heads > 1.
        self.output_proj = nn.Linear(emb_dim * n_heads, emb_dim)
        # Not used in forward; kept so the module structure is unchanged.
        self.dropout = nn.Dropout(dropout)
        self.ff_layer = ff_layer
        self.multiple_ff = multiple_ff
        if ff_layer:
            self.feed_forward = FeedForward(
                emb_dim=emb_dim,
                # nn.Linear needs an integer width; emb_dim * 2.5 would be a float.
                hidden_layer_dim=int(emb_dim * multiple_ff),
                dropout=dropout
            )

    def forward(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            data: ``(x, token_ids)`` with ``x`` of shape (B, N, D) and
                ``token_ids`` of shape (B, N); -1 marks padding.

        Returns:
            ``(x_out, token_ids)`` with ``x_out`` of shape (B, N, D), so blocks
            can be chained inside ``nn.Sequential``.
        """
        x, token_ids = data

        residual = x
        normed = self.norm(x)

        head_outputs = [head(normed, token_ids) for head in self.heads]
        if len(self.heads) > 1:
            # (B, N, D*H) -> (B, N, D)
            out = self.output_proj(torch.cat(head_outputs, dim=-1))
        else:
            out = head_outputs[0]

        if self.ff_layer:
            out = self.feed_forward(out)
        return out + residual, token_ids

class GNN_Transformer(nn.Module):
    """Transformer encoder used as a fully connected "graph" layer over tokens.

    Every valid token attends to every other valid token of its row; tokens
    whose id is -1 are excluded through ``src_key_padding_mask``. No positional
    encoding is added here.

    Args:
        emb_dim: model width (``d_model``).
        nhead: number of attention heads.
        num_layers: number of stacked encoder layers.
        multiple_ff: feed-forward width as a multiple of ``emb_dim``
            (truncated to an int).
        dropout: dropout probability inside each encoder layer.
    """

    def __init__(self, emb_dim, nhead, num_layers, multiple_ff, dropout):
        super().__init__()
        self.emb_dim = emb_dim

        encoder_layer = TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=nhead,
            # nn.Linear needs an integer width; allow fractional multipliers.
            dim_feedforward=int(self.emb_dim * multiple_ff),
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )

    def forward(self, input: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode the tokens of every row.

        Args:
            input: ``(x, token_ids)`` with ``x`` of shape (B, N, emb_dim) and
                ``token_ids`` of shape (B, N); -1 marks padding.

        Returns:
            ``(output, token_ids)`` with ``output`` of shape (B, N, emb_dim).
        """
        x, token_ids = input

        # True marks positions the encoder must ignore as keys.
        padding_mask = self._create_padding_mask(token_ids)

        output = self.transformer_encoder(
            x,
            src_key_padding_mask=padding_mask
        )

        return output, token_ids

    def _create_padding_mask(self,
                             token_ids: torch.Tensor,
                             ) -> torch.Tensor:
        """Return a bool mask that is True where ``token_ids == -1``."""
        return token_ids == -1

    def _validate_inputs(self,
                         x: torch.Tensor,
                         token_ids: torch.Tensor) -> None:
        """Check that ``x`` is (B, N, D) and ``token_ids`` is the matching (B, N).

        Not called by ``forward``; available as an explicit debugging check.
        """
        if x.dim() != 3:
            raise ValueError(f"Input sequences must be 3D tensor (batch, seq, features), "
                             f"got {x.dim()}D")

        if token_ids.dim() != 2:
            raise ValueError(f"Token ids must be 2D tensor (batch, seq), got {token_ids.dim()}D")

        if tuple(token_ids.shape) != tuple(x.shape[:2]):
            raise ValueError(f"Shape mismatch between sequences {tuple(x.shape)} "
                             f"and token ids {tuple(token_ids.shape)}")


class LSTM_Model(nn.Module):
    """Unidirectional, batch-first LSTM over variable-length padded sequences.

    Args:
        input_size: feature size of each sequence element.
        hidden_size: LSTM hidden size, i.e. the feature size of the output.
        num_layers: number of stacked LSTM layers.
        dropout: dropout between stacked layers. ``nn.LSTM`` applies it only
            between layers, so it is passed as 0 when ``num_layers == 1``.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout):
        super().__init__()

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            # With a single layer the value has no effect and nn.LSTM warns about it.
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
            bidirectional=False
        )

    def forward(self,
                input: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """
        Run the LSTM over packed sequences and re-pad the result.

        Args:
            input: ``(padded_sequences, sequence_lengths)`` with
                padded_sequences of shape (batch_size, max_seq_len, input_size)
                and sequence_lengths of shape (batch_size,), each in
                [1, max_seq_len].

        Returns:
            (batch_size, max_seq_len, hidden_size) tensor; positions beyond a
            sequence's length are zero.

        Raises:
            ValueError: on malformed shapes or out-of-range lengths.
        """
        padded_sequences, sequence_lengths = input

        self._validate_inputs(padded_sequences, sequence_lengths)

        packed_input = nn.utils.rnn.pack_padded_sequence(
            input=padded_sequences,
            lengths=sequence_lengths.cpu(),  # pack_padded_sequence requires CPU lengths
            batch_first=True,
            enforce_sorted=False
        )

        packed_output, _ = self.lstm(packed_input)

        output_sequences, _ = nn.utils.rnn.pad_packed_sequence(
            packed_output,
            batch_first=True,
            total_length=padded_sequences.size(1))

        return output_sequences

    def _validate_inputs(self,
                         sequences: torch.Tensor,
                         lengths: torch.Tensor) -> None:
        """Validate input dimensions and lengths (lengths must lie in [1, seq_len])."""
        if sequences.dim() != 3:
            raise ValueError(f"Input sequences must be 3D tensor (batch, seq, features), got {sequences.shape}")

        if lengths.dim() != 1:
            raise ValueError(f"Sequence lengths must be 1D tensor, got {lengths.shape}")

        if sequences.size(0) != lengths.size(0):
            raise ValueError(f"Batch size mismatch between sequences ({sequences.size(0)}) and lengths ({lengths.size(0)})")

        # Zero-length sequences cannot be packed, so reject them here with a clear message.
        if (lengths < 1).any() or (lengths > sequences.size(1)).any():
            raise ValueError(
                f"Invalid sequence lengths detected: expected values in [1, {sequences.size(1)}], "
                f"got min={int(lengths.min())}, max={int(lengths.max())}")

class TransformerModel(nn.Module):
    """Transformer encoder over padded sequences with learned positional embeddings.

    Args:
        emb_dim: model width (``d_model``) and input feature size.
        nhead: number of attention heads.
        num_layers: number of stacked encoder layers.
        multiple_ff: feed-forward width as a multiple of ``emb_dim``
            (truncated to an int).
        dropout: dropout probability inside each encoder layer.
        max_seq_len: number of learned positions, i.e. the longest sequence
            the model accepts.
    """

    def __init__(self, emb_dim, nhead, num_layers, multiple_ff, dropout, max_seq_len):
        super().__init__()
        self.emb_dim = emb_dim
        self.max_seq_len = max_seq_len

        # One learned vector per position 0 .. max_seq_len - 1
        self.position_embedding = nn.Embedding(
            num_embeddings=self.max_seq_len,
            embedding_dim=self.emb_dim
        )

        encoder_layer = TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=nhead,
            # nn.Linear needs an integer width; allow fractional multipliers.
            dim_feedforward=int(self.emb_dim * multiple_ff),
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )

    def forward(self, input: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """
        Encode variable-length sequences.

        Args:
            input: ``(padded_sequences, sequence_lengths)`` with
                padded_sequences of shape (batch_size, seq_len, emb_dim),
                seq_len <= max_seq_len, and sequence_lengths of shape (batch_size,).

        Returns:
            (batch_size, seq_len, emb_dim) tensor.

        Raises:
            ValueError: on malformed shapes or when seq_len > max_seq_len.
        """
        padded_sequences, sequence_lengths = input

        self._validate_inputs(padded_sequences, sequence_lengths)

        x = self._add_positional_encoding(padded_sequences)

        # True marks padding positions (index >= length)
        padding_mask = self._create_padding_mask(sequence_lengths,
                                                 padded_sequences.size(1))
        output_sequences = self.transformer_encoder(
            x,
            src_key_padding_mask=padding_mask
        )

        return output_sequences

    def _add_positional_encoding(self, x: torch.Tensor) -> torch.Tensor:
        """Add the learned embedding of each position to ``x``."""
        batch_size, seq_len, emb_dim = x.size()

        positions = torch.arange(seq_len, device=x.device)\
                         .expand(batch_size, seq_len)

        # (batch_size, seq_len, emb_dim)
        pos_embeddings = self.position_embedding(positions)

        return x + pos_embeddings

    def _create_padding_mask(self,
                             lengths: torch.Tensor,
                             max_len: int) -> torch.Tensor:
        """Return a (batch, max_len) bool mask, True where position >= length."""
        batch_size = lengths.size(0)
        mask = torch.arange(max_len, device=lengths.device)\
                    .expand(batch_size, max_len) >= lengths.unsqueeze(1)
        return mask

    def _validate_inputs(self,
                         padded_sequences: torch.Tensor,
                         sequence_lengths: torch.Tensor):
        """Validate input dimensions and that the sequence fits the positional table."""
        if padded_sequences.dim() != 3:
            raise ValueError(f"Input sequences must be 3D tensor (batch, seq, features), "
                             f"got {padded_sequences.dim()}D")

        if sequence_lengths.dim() != 1:
            raise ValueError(f"Lengths must be 1D tensor, got {sequence_lengths.dim()}D")

        if padded_sequences.size(0) != sequence_lengths.size(0):
            raise ValueError("Batch size mismatch between sequences and lengths")

        # Otherwise nn.Embedding fails with an IndexError (or a device-side assert on CUDA).
        if padded_sequences.size(1) > self.max_seq_len:
            raise ValueError(f"Sequence length {padded_sequences.size(1)} exceeds "
                             f"max_seq_len={self.max_seq_len}")

In [None]:

class HorseRacingModel(nn.Module):
    def __init__(self, cfg):
        """Build every sub-network from ``cfg``.

        ``cfg`` is read by attribute (``sequence_model``, ``graph_conv_model``,
        ``input_features``, ``logits_size``, ``emb_dim``); the two model
        sub-configs and ``input_features`` are indexed like dicts.

        Raises:
            ValueError: if ``graph_conv_model['layers_type']`` is not one of
                'GCNII', 'GCN', 'Transformer', or if
                ``sequence_model['sequence_model_type']`` is not one of
                'Transformer', 'LSTM'.
        """
        super().__init__()
        self.sequence_model_cfg = cfg.sequence_model
        self.graph_conv_model_cfg = cfg.graph_conv_model
        self.input_features = cfg.input_features
        self.input_embedding = self.sequence_model_cfg['input_embedding']
        self.logits_size = cfg.logits_size
        self.emb_dim = cfg.emb_dim

        seq_cfg = self.sequence_model_cfg
        graph_cfg = self.graph_conv_model_cfg

        # Per-horse history of embeddings from earlier races
        self.horse_embeddings = CustomEmbeddingManager(
            embedding_dim=seq_cfg["emb_dim"],
            max_nodes=seq_cfg["max_nodes"],
            max_depth_nodes=seq_cfg["max_depth_nodes"],
            max_sequence_length=seq_cfg["max_seq_depth"]
        )

        # Embedding for horses that have no previous races
        self.first_horse_state = nn.Parameter(torch.zeros(self.emb_dim))

        n_race = len(self.input_features["races"])
        n_race_horse = len(self.input_features["race_horse"])

        # Input projection for the logit branch: race + race-horse features + horse history
        self.logit_projection = nn.Linear(
            n_race + n_race_horse + self.emb_dim,
            self.emb_dim
        )

        # Input projection for the embedding branch: features + results (+ history if enabled)
        n_results = len(self.input_features["results"])
        embed_in_dim = n_race + n_race_horse + n_results
        if self.input_embedding:
            embed_in_dim += self.emb_dim
        self.embedding_projection = nn.Linear(embed_in_dim, graph_cfg['emb_dim'])

        # Graph neural networks for the logit branch and the embedding branch
        layers_type = graph_cfg['layers_type']
        if layers_type == 'GCNII':
            gcnii_kwargs = dict(
                nlayers=graph_cfg['n_layers'],
                emb_dim=graph_cfg['emb_dim'],
                dropout=graph_cfg['dropout'],
                lamda=graph_cfg['GCNII_config']['lamda'],
                alpha=graph_cfg['GCNII_config']['alpha'],
                variant=graph_cfg['GCNII_config']['variant'],
                residual=graph_cfg['GCNII_config']['residual'],
                self_loop=graph_cfg['GCNII_config']['self_loop']
            )
            self.logit_graph_nn = GCNII(**gcnii_kwargs)
            self.embedding_graph_nn = GCNII(**gcnii_kwargs)

        elif layers_type == 'GCN':
            self.logit_graph_nn = nn.Sequential(*[
                GCNBlock(graph_cfg['emb_dim'],
                         1,
                         graph_cfg['dropout'],
                         graph_cfg['GCN_config']['ff_layer'],
                         graph_cfg['GCN_config']['multiple_ff'])
                for _ in range(graph_cfg['n_layers'])
            ])
            # The embedding branch is built without the optional feed-forward sub-layer.
            self.embedding_graph_nn = nn.Sequential(*[
                GCNBlock(graph_cfg['emb_dim'],
                         1,
                         graph_cfg['dropout'])
                for _ in range(graph_cfg['n_layers'])
            ])

        elif layers_type == 'Transformer':
            # NOTE(review): width is self.emb_dim here while embedding_projection
            # outputs graph_cfg['emb_dim']; the two presumably must be equal.
            transformer_kwargs = dict(
                emb_dim=self.emb_dim,
                nhead=graph_cfg['GNN_Transformer_config']['n_heads'],
                num_layers=graph_cfg['GNN_Transformer_config']['n_layers'],
                multiple_ff=graph_cfg['GNN_Transformer_config']['multiple_ff'],
                dropout=graph_cfg['GNN_Transformer_config']['dropout'],
            )
            self.logit_graph_nn = GNN_Transformer(**transformer_kwargs)
            self.embedding_graph_nn = GNN_Transformer(**transformer_kwargs)

        else:
            raise ValueError(f"Unsupported Graph Neural Network layers type: {layers_type}")

        # Output layers
        self.final_norm = nn.LayerNorm(self.emb_dim)
        self.final_projection = nn.Linear(self.emb_dim, self.logits_size)

        # Sequence modeling over each horse's history
        self.sequence_norm = nn.LayerNorm(self.emb_dim)
        sequence_type = seq_cfg['sequence_model_type']
        if sequence_type == 'Transformer':
            self.sequence_model = TransformerModel(
                emb_dim=self.emb_dim,
                nhead=seq_cfg['Transformer_config']['n_heads'],
                num_layers=seq_cfg['Transformer_config']['n_layers'],
                multiple_ff=seq_cfg['Transformer_config']['multiple_ff'],
                dropout=seq_cfg['Transformer_config']['dropout'],
                max_seq_len=seq_cfg['max_seq_depth']
            )

        elif sequence_type == 'LSTM':
            self.sequence_model = LSTM_Model(
                input_size=self.emb_dim,
                hidden_size=seq_cfg['LSTM_config']['hidden_dim'],
                num_layers=seq_cfg['LSTM_config']['n_layers'],
                dropout=seq_cfg['LSTM_config']['dropout'])
        else:
            raise ValueError(f"Unsupported Sequence Neural Network layers type: {sequence_type}")

    def forward(self, inputs):
        """Score the horses of a batch of races and refresh the stored horse embeddings.

        Args:
            inputs: collated batch dict; this method reads 'race_data',
                'horse_data', 'results_data', 'hids' and 'crid' ('crid' is
                unpacked but not used below).

        Returns:
            Tuple ``(logits, node_counts)``: ``logits`` is the output of
            ``final_projection`` applied after ``final_norm``; ``node_counts`` is
            passed through unchanged from ``horse_embeddings.get_embeddings``.

        Side effects:
            Writes the embedding-branch output back through
            ``self.horse_embeddings.update_embeddings`` (state that persists
            across batches).
        """

        race_features = inputs['race_data']
        horse_features = inputs['horse_data']
        results = inputs['results_data']
        hids = inputs['hids']
        crids = inputs['crid']

        if self.training:
          # NOTE(review): presumably forces the raw inputs to require grad so the
          # reentrant checkpoint used by _process_layer sees an input that
          # requires grad. Setting requires_grad like this only works on
          # floating-point leaf tensors.
          race_features.requires_grad = True
          horse_features.requires_grad = True
          results.requires_grad = True

        # Retrieve horse embeddings: (embeddings, seq_lengths, hid_order, node_counts).
        # `embeddings` may be None; that case is handled just below.
        embeddings, seq_lengths, hid_order, node_counts = self.horse_embeddings.get_embeddings(hids)

        # Process through sequence neural network
        if embeddings is not None:
            lstm_output = self._process_sequence(self.sequence_model, embeddings, seq_lengths)
        else:
            lstm_output = None

        # Dense per-slot horse embeddings of shape (batch_size, num_horses, emb_dim)
        batch_size, num_horses = hids.shape
        horse_embeddings = self._create_horse_embeddings_matrix(
            hids, lstm_output, hid_order, seq_lengths, batch_size, num_horses
        )

        # Prepare Graph Neural Network inputs (features concatenated on the last dim).
        # Logit branch: race + horse features + horse embeddings.
        logit_input = torch.cat([race_features, horse_features, horse_embeddings], dim=-1)
        # Embedding branch also sees the results; horse embeddings are included
        # only when self.input_embedding is set.
        if self.input_embedding:
            embed_input = torch.cat([race_features, horse_features, horse_embeddings, results], dim=-1)
        else:
            embed_input = torch.cat([race_features, horse_features, results], dim=-1)

        # Project each branch, then run its graph network.
        logit_input = self._process_layer(self.logit_projection, logit_input)
        embed_input = self._process_layer(self.embedding_projection, embed_input)

        logit_output = self._process_GNN(self.logit_graph_nn, logit_input, hids)
        embed_output = self._process_GNN(self.embedding_graph_nn, embed_input, hids)

        # Generate final predictions
        logit_output = self.final_norm(logit_output)
        logits = self.final_projection(logit_output)

        # Persist the embedding-branch output for the horses in this batch.
        self.horse_embeddings.update_embeddings(hids, embed_output, node_counts)

        return logits, node_counts




    def _create_horse_embeddings_matrix(self, hids, lstm_output, hid_order, seq_lengths,
                                      batch_size, num_horses):
        """Scatter each horse's latest sequence-model state into a dense tensor.

        Args:
            hids: horse-id tensor for the batch; ``-1`` marks an empty slot.
                Expected to have shape ``(batch_size, num_horses)``.
            lstm_output: sequence-model output indexed as ``lstm_output[row, step]``,
                where ``row`` is the position of a horse id in ``hid_order``.
                Only indexed when at least one hid is found in ``hid_order``.
            hid_order: horse ids in the row order of ``lstm_output`` / ``seq_lengths``.
            seq_lengths: valid steps per row; the state at ``length - 1`` is used.
            batch_size, num_horses: output shape (normally ``hids.shape``).

        Returns:
            Tensor ``(batch_size, num_horses, self.emb_dim)``. It holds:
            - the last sequence state for horses present in ``hid_order``;
            - ``self.first_horse_state`` for valid horses without history;
            - zeros for ``-1`` slots.
        """
        device = hids.device
        embeddings = torch.zeros((batch_size, num_horses, self.emb_dim), device=device)

        valid_hids = hids != -1
        hid_lookup = {hid: idx for idx, hid in enumerate(hid_order)}

        # Resolve every hid to its row in hid_order with a single host transfer.
        # The previous element-wise `.item()` loop forced one device sync per slot.
        flat_rows = [
            hid_lookup.get(hid, -1) if hid != -1 else -1
            for hid in hids.reshape(-1).tolist()
        ]
        hid_indices = torch.tensor(flat_rows, dtype=torch.long, device=device).view(hids.shape)

        valid_mask = hid_indices != -1
        if valid_mask.any():
            rows = hid_indices[valid_mask]
            # Tensor.to is not in-place; keep the result so the index lookup
            # below happens on the same device as `rows`.
            seq_lengths = seq_lengths.to(device)
            last_steps = seq_lengths[rows] - 1
            embeddings[valid_mask] = lstm_output[rows, last_steps]

        # Valid horses with no stored history get the learned initial state.
        first_time_mask = valid_hids & ~valid_mask
        embeddings[first_time_mask] = self.first_horse_state

        return embeddings

    def _process_layer(self, layer, x):
        """Apply ``layer`` to ``x``, using activation checkpointing while training.

        Uses the non-reentrant checkpoint, like ``_process_sequence`` and
        ``_process_GNN``. The reentrant variant only propagates gradients to
        ``layer``'s parameters when ``x`` itself requires grad, and PyTorch
        recommends ``use_reentrant=False``.
        """
        if not self.training:
            return layer(x)
        return checkpoint(layer, x, use_reentrant=False)

    def _process_sequence(self, sequence_model, embeddings, lengths):
        """Run the sequence model on the packed ``(embeddings, lengths)`` tuple.

        The model receives one tuple argument. During training the call is
        wrapped in non-reentrant activation checkpointing.
        """
        packed = (embeddings, lengths)
        if not self.training:
            return sequence_model(packed)
        return checkpoint(sequence_model, packed, use_reentrant=False)

    def _process_GNN(self, transformer, x, hids):
        """Run a graph network on ``(x, hids)`` and return only its feature output.

        The network gets a single ``(x, hids)`` tuple and must return a pair;
        the second element is discarded. While training, the call goes through
        non-reentrant activation checkpointing.
        """
        graph_inputs = (x, hids)
        if self.training:
            result = checkpoint(transformer, graph_inputs, use_reentrant=False)
        else:
            result = transformer(graph_inputs)
        features, _ = result
        return features


### Training helper functions

In [None]:
class GradientMonitor:
    """Keeps the latest per-parameter gradient magnitudes and prints a summary."""

    def __init__(self):
        # parameter name -> largest |grad| entry at the last update
        self.max_gradients = {}
        # parameter name -> mean |grad| at the last update
        self.avg_gradients = {}

    def update(self, model: nn.Module):
        """Record max and mean absolute gradient for every parameter that has a gradient."""
        with_grads = (
            (param_name, p.grad.detach())
            for param_name, p in model.named_parameters()
            if p.grad is not None
        )
        for param_name, g in with_grads:
            magnitude = torch.abs(g)
            self.max_gradients[param_name] = torch.max(magnitude).item()
            self.avg_gradients[param_name] = torch.mean(magnitude).item()

    def report(self, frequency: int = 100):
        """Print overall and per-parameter statistics; does nothing before the first update.

        ``frequency`` is accepted for API compatibility but not used.
        """
        if len(self.max_gradients) == 0:
            return

        overall_max = max(self.max_gradients.values())
        overall_avg = sum(self.avg_gradients.values()) / len(self.avg_gradients)
        print("\nGradient Summary:")
        print(f"Max Gradient: {overall_max:.4e}, Max Gradients per layer: {self.max_gradients}")
        print(f"Avg Gradient: {overall_avg:.4e}, Avg Gradient per layer: {self.avg_gradients}")


class TrainingConfig:
    """Top-level training configuration container (no validation is performed).

    All keys of ``config`` are first copied onto the instance. The
    'optimizer', 'data', 'model' and 'training_follow' sections are then
    wrapped in their config classes. The last one is exposed as
    ``train_follow``; the raw 'training_follow' dict stays available under
    its original key.
    """
    def __init__(self, config: Dict[str, Any]):
        vars(self).update(config)

        # (attribute name, wrapper class, config key), applied in this order.
        for attribute, section_cls, key in (
            ('optimizer', OptimConfig, 'optimizer'),
            ('data', DataConfig, 'data'),
            ('model', ModelConfig, 'model'),
            ('train_follow', TrainingIllustration, 'training_follow'),
        ):
            setattr(self, attribute, section_cls(config[key]))

class DataConfig:
    """Data settings container that also builds the training DataLoader.

    Every key of ``config`` becomes an attribute. ``batch_size``, ``shuffle``,
    ``num_workers`` and ``drop_last`` are read here; the whole object is
    handed to ``HorseRacingDataset``.
    """
    def __init__(self, config: Dict[str, Any]):
        vars(self).update(config)

    def get_dataloader(self):
        """Build the dataset and wrap it in a DataLoader.

        Side effect: stores ``dataset.features`` on ``self.features`` so that
        ``get_datafeatures`` can return it afterwards.
        """
        dataset = HorseRacingDataset(self)
        self.features = dataset.features

        def collate_with_dataset(samples):
            # The project collate function takes the dataset as its second argument.
            return collate_fn(samples, dataset)

        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            collate_fn=collate_with_dataset,
            num_workers=self.num_workers,
            drop_last=self.drop_last,
        )

    def get_datafeatures(self):
        """Return the features recorded by ``get_dataloader`` (AttributeError if it has not run)."""
        return self.features

class TrainingStats:
    """Collects per-batch training metrics and prints rolling summaries.

    Each call to ``update`` appends one value per metric key to the lists in
    ``self.data`` (a ``defaultdict(list)``). The lists therefore stay aligned
    by batch and can be averaged or plotted over a trailing window.
    """

    # Metric keys produced by _betting_metrics, in the order they are appended.
    _BETTING_KEYS = ('accuracy', 'predicted_earnings', 'predicted_earnings_greedy',
                     'predicted_earnings_value', 'roi_kelly')

    def __init__(self):
        self.data = defaultdict(list)
        # Nominal metric names; nothing in this class reads this list.
        self.metrics = [
            'loss', 'total_winnings', 'total_betted',
            'predicted_earnings', 'predicted_earnings_greedy',
            'hrn', 'accuracy', 'spearman', 'batch_time', 'n_nodes'
        ]

    def update(self, epoch: int, batch_idx: int, model: nn.Module, outputs: torch.Tensor,
                batch_time: float, nodes_depth: int, lr: float, batch: Dict, loss: torch.Tensor):
        """Compute the metrics of one batch and append them to ``self.data``.

        Args:
            epoch, batch_idx: training position, stored as-is.
            model: forwarded to ``number_nodes_computation_graph``.
            outputs: logits of shape ``(batch, horses, n_ranks)``; only the
                first rank column is scored when ``n_ranks > 1``.
            batch_time, nodes_depth, lr: stored as-is.
            batch: collated batch. Reads three keys:
                - ``'positions'``: 1 = winner, -1 = slot excluded from the softmax;
                - ``'prices'``;
                - ``'hrn'``: -1 entries are ignored.
            loss: scalar loss tensor of the batch.
        """
        with torch.no_grad():
            device = outputs.device
            positions_arrival = batch['positions']
            decimal_prices = batch['prices']
            hrn = batch['hrn']

            batch_size, num_horses, n_ranks = outputs.shape
            if n_ranks > 1:
                outputs = outputs[:, :, 0]
            outputs = outputs.squeeze(-1)

            # Horses with position -1 are excluded from the softmax.
            mask_out = positions_arrival == -1.0
            outputs = outputs.masked_fill(mask_out, -1e9)
            probs = torch.softmax(outputs, dim=1)

            winner_mask = positions_arrival == 1.0
            valid_winner_races = winner_mask.any(dim=1)

            if valid_winner_races.any():
                metrics = self._betting_metrics(probs, positions_arrival, decimal_prices,
                                                winner_mask, valid_winner_races)
            else:
                metrics = dict.fromkeys(self._BETTING_KEYS, torch.tensor(0.0, device=device))

            # Ranking correlation
            spearman = self.spearman_rank_correlation(outputs.detach().clone(), positions_arrival)
            n_nodes_computation_graph = self.number_nodes_computation_graph(model, loss)

            self.data['epoch'].append(epoch)
            self.data['batch'].append(batch_idx)
            self.data['batch_time'].append(batch_time)
            self.data['computation_depth'].append(nodes_depth)
            self.data['lr'].append(lr)
            self.data['hrn'].append(hrn[hrn != -1].float().mean().item())
            self.data['n_nodes'].append(n_nodes_computation_graph)

            for key in self._BETTING_KEYS:
                self.data[key].append(metrics[key].item())
            self.data['spearman'].append(spearman.item())
            # Participants are the slots NOT excluded by position -1. The old
            # code counted the excluded slots instead.
            self.data['avg_race_participants'].append((~mask_out).sum().item() / batch_size)

            self.data['loss'].append(loss.item())

    @staticmethod
    def _betting_metrics(probs, positions_arrival, decimal_prices, winner_mask, valid_winner_races):
        """Accuracy and simulated betting returns for a batch with at least one winner.

        Convention (unchanged from the original code): a stake ``s`` on a horse
        that wins pays back ``s / price``. A "value" bet is placed when the
        model probability exceeds ``price``. In other words, ``price`` is
        treated like the market's implied win probability.

        Returns:
            Dict of 0-d tensors keyed by ``TrainingStats._BETTING_KEYS``. All
            earnings/ROI values are net profit divided by total stake.
        """
        batch_size = probs.size(0)
        rows = torch.arange(batch_size, device=probs.device)
        pred_winners = probs.argmax(dim=1)
        correct = winner_mask[rows, pred_winners]

        # Accuracy over races that have a winner.
        accuracy = correct[valid_winner_races].float().mean()

        # Probabilistic staking: every race stakes its whole probability mass (1 unit).
        winnings_prob = (probs[winner_mask] / decimal_prices[winner_mask]).sum()
        total_betted_prob = batch_size

        # Value betting: 1 unit on every valid horse whose probability beats its price.
        valid_horses_mask = (positions_arrival != -1.0) & (positions_arrival != 40.0)
        value_bets_mask = (probs > decimal_prices) & valid_horses_mask
        won_value_bets = value_bets_mask & winner_mask
        # where() keeps zero prices on non-winning slots from turning the sum into NaN.
        winnings_value = torch.where(won_value_bets, 1.0 / decimal_prices, 0.0).sum()
        total_betted_value = value_bets_mask.sum().float()

        # Greedy betting: 1 unit on the arg-max horse of every race.
        greedy_prices = decimal_prices[rows, pred_winners]
        winnings_greedy = (correct.float() / greedy_prices).sum()
        total_betted_greedy = batch_size

        # Kelly staking on value bets. Net odds b = payout - 1 = 1/price - 1.
        # The old code used the gross payout 1/price here, contradicting its comment.
        # f = (b*p - q)/b = (p*(b+1) - 1)/b, which reduces to (p - price)/(1 - price).
        net_odds = 1.0 / decimal_prices - 1.0
        kelly_fractions = (probs * (net_odds + 1.0) - 1.0) / net_odds
        kelly_fractions = torch.where(value_bets_mask, kelly_fractions, 0.0).clamp(min=0.0, max=1.0)
        kelly_won = winner_mask & value_bets_mask
        total_return_kelly = torch.where(kelly_won, kelly_fractions / decimal_prices, 0.0).sum()
        total_betted_kelly = kelly_fractions.sum()
        if total_betted_kelly > 0:
            roi_kelly = (total_return_kelly - total_betted_kelly) / total_betted_kelly
        else:
            # Must stay a tensor: the caller calls .item(). The old code returned
            # a float here and crashed whenever no Kelly bet was placed.
            roi_kelly = torch.zeros((), device=probs.device)

        return {
            'accuracy': accuracy,
            'predicted_earnings': (winnings_prob - total_betted_prob) / (total_betted_prob + 1e-8),
            'predicted_earnings_greedy': (winnings_greedy - total_betted_greedy) / (total_betted_greedy + 1e-8),
            'predicted_earnings_value': (winnings_value - total_betted_value) / (total_betted_value + 1e-8),
            'roi_kelly': roi_kelly,
        }

    def number_nodes_computation_graph(self, model: nn.Module, loss: float) -> int:
        """Size of the autograd graph; currently disabled and always returns 0.

        The torchviz-based count (length of ``make_dot(...).source``) was
        disabled together with the torchviz import at the top of the notebook.
        """
        return 0

    def spearman_rank_correlation(self, logits: torch.Tensor,
                             positions_arrival: torch.Tensor) -> torch.Tensor:
        """
        Computes Spearman's rank correlation between predicted logits and actual positions,
        ignoring invalid entries (-1 or 40). Handles variable participant counts per race.

        Args:
            logits: Tensor of shape [batch_size, num_horses] with prediction scores
            positions_arrival: Tensor of shape [batch_size, num_horses] with actual positions

        Returns:
            Mean Spearman's rho across races with at least two valid horses, as a
            0-d tensor on ``logits.device`` (0.0 when no race qualifies).
        """
        device = logits.device
        correlations = []

        for i in range(logits.size(0)):
            # Filter valid entries for this race
            mask = (positions_arrival[i] != -1.0) & (positions_arrival[i] != 40.0)
            race_logits = logits[i][mask].detach().cpu().numpy()
            race_positions = positions_arrival[i][mask].cpu().numpy()

            # Skip races with <2 valid participants
            if len(race_logits) < 2:
                continue

            try:
                # Higher logit = better (smaller) rank; double argsort yields ranks.
                pred_ranks = (-race_logits).argsort().argsort()
                rho, _ = spearmanr(pred_ranks, race_positions)
                if np.isnan(rho):
                    rho = 0.0
            except Exception:
                # Best effort: a failing race counts as no correlation, but
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                rho = 0.0

            correlations.append(rho)

        if not correlations:
            return torch.tensor(0.0, device=device)
        return torch.tensor(np.mean(correlations), device=device)

    def report(self, window_size: int = 100):
        """Print trailing-window means of loss, betting returns and HRN (needs ``window_size`` entries)."""
        if len(self.data['loss']) < window_size:
            return
        print(f"Epoch: {self.data['epoch'][-1]}, Batch: {self.data['batch'][-1]}, Loss:{round(np.mean(self.data['loss'][-window_size:]),2)}, Betting:{round(np.mean(self.data['predicted_earnings'][-window_size:]),2)}, Greedy Betting:{round(np.mean(self.data['predicted_earnings_greedy'][-window_size:]),2)}, HRN:{round(np.mean(self.data['hrn'][-window_size:]),2)} ")


class ModelConfig:
    """Plain attribute container for the model section of the config (no validation is performed)."""
    def __init__(self, config: Dict[str, Any]):
        # Every config key becomes an instance attribute.
        vars(self).update(config)

class OptimConfig:
    """Optimizer settings plus the learning-rate schedule and weight-update logic.

    Every key of ``config`` becomes an attribute. The schedule in
    ``update_learning_rate`` has two phases:
    1. a linear warm-up from 0 to ``initial_lr`` over ``n_steps_cold_start`` batches;
    2. reduce-on-plateau, driven by the loss history in ``stats.data['loss']``.
    """
    def __init__(self, config: Dict[str, Any]):
        self.__dict__.update(config)
        # Learning rate most recently applied by the schedule (0 until warm-up runs).
        self.lr = 0
        self.iterations_since_last_lr_update = 0

    def init_optimizer(self, model: nn.Module) -> torch.optim.Optimizer:
        """Build the optimizer named by ``optimizer_name`` ('Adam' or 'SGD').

        The optimizer starts at ``self.lr``; the warm-up in
        ``update_learning_rate`` sets the actual rate.

        Raises:
            ValueError: if ``optimizer_name`` is not supported.
        """
        if self.optimizer_name == 'Adam':
            return torch.optim.Adam(params=model.parameters(),
                                    lr=self.lr,
                                    betas=self.Adam_cfg['betas'],
                                    eps=self.Adam_cfg['eps'],
                                    weight_decay=self.Adam_cfg['weight_decay'],
                                    # Honour the top-level 'amsgrad' key when present.
                                    amsgrad=getattr(self, 'amsgrad', False))
        if self.optimizer_name == 'SGD':
            return torch.optim.SGD(params=model.parameters(),
                                   lr=self.lr,
                                   momentum=self.SGD_cfg['momentum'],
                                   weight_decay=self.SGD_cfg['weight_decay'])
        # Previously referenced an undefined `cfg` and raised NameError instead.
        raise ValueError(f"Unsupported optimizer: {self.optimizer_name}")

    def update_learning_rate(self, optimizer: torch.optim.Optimizer, stats: "TrainingStats", batch_idx: int):
        """Apply warm-up or reduce-on-plateau to every parameter group of ``optimizer``.

        Args:
            optimizer: optimizer whose ``param_groups`` learning rates are set.
            stats: object whose ``data['loss']`` list holds one loss per batch.
            batch_idx: current batch index; warm-up applies while
                ``batch_idx <= n_steps_cold_start``.

        After warm-up, compares two mean losses:
        - the mean over the last ``window_size`` batches;
        - the mean over the window ending ``patience`` batches earlier.

        The learning rate is multiplied by ``factor`` (floored at ``min_lr``)
        only when both hold:
        - the current mean did not improve on the earlier one;
        - at least ``patience`` calls have passed since the last change.
        """
        self.iterations_since_last_lr_update += 1

        # Warm-up ("cold start"): the lr grows linearly with batch_idx.
        if self.n_steps_cold_start >= batch_idx:
            if self.n_steps_cold_start > 0:
                self.lr = (self.initial_lr / self.n_steps_cold_start) * batch_idx
            else:
                # No warm-up configured: start at initial_lr (avoids dividing by zero).
                self.lr = self.initial_lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.lr
            self.iterations_since_last_lr_update = 0
            return

        losses = stats.data['loss']
        # Need both windows to be available.
        if len(losses) < self.window_size + self.patience:
            return

        current_loss = np.mean(losses[-self.window_size:])
        previous_loss = np.mean(losses[-(self.window_size + self.patience):-self.patience])

        if current_loss >= previous_loss and self.iterations_since_last_lr_update >= self.patience:
            self.iterations_since_last_lr_update = 0
            for param_group in optimizer.param_groups:
                param_group['lr'] = max(param_group['lr'] * self.factor, self.min_lr)
                # Record the rate actually applied (factor used to be applied twice here).
                self.lr = param_group['lr']
                print(f"Learning rate reduced to: {param_group['lr']:.6f}")

    def handle_gradients(self, model: nn.Module, optimizer: torch.optim.Optimizer, batch_idx: int):
        """Clip and apply the accumulated gradients every ``n_steps_temporal_gradient_accumulation`` batches.

        Between steps, gradients keep accumulating (summed) in ``.grad``.
        ``clip_grad_norm_`` runs with ``error_if_nonfinite=True``, so a
        non-finite total norm raises RuntimeError.
        """
        if batch_idx % self.n_steps_temporal_gradient_accumulation == 0:
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                max_norm=self.gradient_clipping_max_norm,
                error_if_nonfinite=True
            )

            optimizer.step()
            optimizer.zero_grad()

    def update(self, model: nn.Module):
        """Record the latest max/mean absolute gradient of each parameter.

        The dictionaries are created on first use; ``__init__`` does not define
        them, so the old version always raised AttributeError.
        """
        if not hasattr(self, 'max_gradients'):
            self.max_gradients = {}
        if not hasattr(self, 'avg_gradients'):
            self.avg_gradients = {}
        for name, param in model.named_parameters():
            if param.grad is not None:
                grad = param.grad.detach()
                self.max_gradients[name] = torch.max(torch.abs(grad)).item()
                self.avg_gradients[name] = torch.mean(torch.abs(grad)).item()

    def report(self, model):
        """Print max / mean / std of the absolute gradients, overall and per parameter.

        Does nothing when no parameter has a gradient. (The old version raised
        ValueError from ``max()`` of an empty dict in that case.)
        """
        max_gradients = {}
        avg_gradients = {}
        std_gradients = {}
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            grad = torch.abs(param.grad.detach())
            max_gradients[name] = torch.max(grad).item()
            avg_gradients[name] = torch.mean(grad).item()
            # std of a single element is undefined (NaN); report 0 instead.
            std_gradients[name] = torch.std(grad).item() if grad.numel() > 1 else 0.0

        if not max_gradients:
            return

        max_grad = max(max_gradients.values())
        avg_grad = sum(avg_gradients.values()) / len(avg_gradients)
        std_grad = sum(std_gradients.values()) / len(std_gradients)

        print("\nGradient Summary:")
        print(f"Max Gradient: {max_grad:.4e}, Max Gradients per layer: {max_gradients}")
        print(f"Avg Gradient: {avg_grad:.4e}, Avg Gradient per layer: {avg_gradients}")
        print(f"Std Gradient: {std_grad:.4e}, Std Gradient per layer: {std_gradients}")


class TrainingIllustration:
    """Periodic side tasks of the training loop.

    The tasks are: computation-graph dump, gradient report, plots and
    checkpoints. Every key of ``config`` becomes an attribute; the values
    read here are ``n_steps_graph``, ``n_steps_gradient``, ``n_steps_plot``,
    ``n_steps_checkpoint`` and ``window_size``.
    """
    def __init__(self, config: Dict[str, Any]):
        vars(self).update(config)
        # (key in TrainingStats.data, plot title) pairs handed to plot_training_stats.
        # TrainingStats also records 'computation_depth' and 'n_nodes'; they are
        # left out of this list.
        self.metrics_to_plot = [
            ('loss', 'Training Loss'),
            ('predicted_earnings', 'Predicted Earnings (Probabilistic)'),
            ('predicted_earnings_greedy', 'Predicted Earnings (Greedy)'),
            ('predicted_earnings_value', 'Predicted Earnings (Value)'),
            ('roi_kelly', 'ROI Kelly Betting Method'),
            ('accuracy', 'Accuracy'),
            ('hrn', 'HRN'),
            ('spearman', 'Spearman Correlation'),
            ('batch_time', 'Batch Time'),
            ('lr', 'Learning Rate'),
            ('avg_race_participants', 'Number of participants per race'),
        ]

    def handle_batch_operations(self, batch_idx: int, model: nn.Module,
                          optimizer: torch.optim.Optimizer, stats: TrainingStats,
                          cfg_optimizer: OptimConfig, model_folder: str, loss: torch.Tensor):
        """Run whichever periodic tasks are due at ``batch_idx``.

        Tasks, in this order:
        1. computation-graph render;
        2. gradient report (``cfg_optimizer.report``);
        3. training plots, only once ``batch_idx`` exceeds ``window_size``;
        4. a checkpoint, never at batch 0.
        """
        def due(period):
            return batch_idx % period == 0

        if due(self.n_steps_graph):
            visualize_computation_graph(model, loss, model_folder, batch_idx)

        if due(self.n_steps_gradient):
            cfg_optimizer.report(model)

        if due(self.n_steps_plot) and batch_idx > self.window_size:
            plot_training_stats(training_stats=stats.data,
                                metrics=self.metrics_to_plot,
                                window_size=self.window_size,
                                plots_dir=None,
                                config=None)

        if due(self.n_steps_checkpoint) and batch_idx > 0:
            save_training_checkpoint(model_folder, batch_idx, model, optimizer, stats.data)
            gc.collect()





In [None]:
def model_forward(model: nn.Module, batch: Dict):
    """Run ``model`` on a collated batch dict and return whatever the model returns.

    The old version first unpacked ``batch`` into nine names and never used
    them. Iterating a dict yields its keys, so that line raised ValueError for
    any batch without exactly nine keys. The batch is now passed straight
    through.
    """
    return model(batch)


def handle_batch_operations(batch_idx: int, model: nn.Module,
                          optimizer: torch.optim.Optimizer, stats: TrainingStats,
                          grad_monitor: GradientMonitor, model_folder: str, loss: torch.Tensor):
    """Legacy per-batch hook with hard-coded cadences.

    ``training`` calls ``TrainingIllustration.handle_batch_operations`` instead.
    Behaviour of this legacy hook:
    - graph rendering and text reports are switched off;
    - plots run every 100 batches and checkpoints every 1000 (both skipped at batch 0);
    - ``grad_monitor`` is accepted but unused.
    """
    graph_enabled = False   # computation-graph rendering switched off
    report_enabled = False  # text report switched off

    if graph_enabled and batch_idx % 100 == 0 and batch_idx < -1:
        visualize_computation_graph(model, loss, model_folder, batch_idx)

    if report_enabled and batch_idx % 10 == 0 and batch_idx > 0:
        stats.report()

    if batch_idx > 0 and batch_idx % 100 == 0:
        plot_training_stats(stats.data)

    if batch_idx > 0 and batch_idx % 1000 == 0:
        save_training_checkpoint(model_folder, batch_idx, model, optimizer, stats.data)
        gc.collect()


def visualize_computation_graph(model: nn.Module, loss: float,
                              save_path: str, batch_idx: int):
    """Best-effort render of the autograd graph of ``loss`` to ``<save_path>/graph_<batch_idx>.png``.

    Relies on ``make_dot`` (torchviz) being in module scope. Its import is
    commented out at the top of this notebook, so unless it is imported
    elsewhere the call fails with NameError. Any exception is printed and
    swallowed; the function always returns None.
    """
    try:
        renderer = make_dot
        graph = renderer(loss, params=dict(model.named_parameters()),
                         show_attrs=False, show_saved=False)
        graph.render(os.path.join(save_path, f"graph_{batch_idx}"), format="png")
    except Exception as err:
        print(f"Failed to save computation graph: {err}")
    return None


def create_directory_training_session(cfg: Any = None) -> str:
    """Create a timestamped training-session directory and save the config as JSON.

    Creates ``<model_folder>/training_session_<YYYYmmdd_HHMMSS>_<suffix>``
    with ``checkpoints/`` and ``logs/`` subdirectories. ``model_folder`` and
    ``suffix`` are ``cfg.train_follow.model_folder`` and
    ``cfg.train_follow.suffix_model_folder``.

    Args:
        cfg: Training configuration object (e.g. ``TrainingConfig``); required
            in practice, since ``cfg.train_follow`` is read first. For an
            object, ``vars(cfg)`` is copied into a NEW dict, so the caller's
            object is no longer mutated. In that dict ``loss_function`` is
            replaced by its ``__name__``. The ``data``, ``model``,
            ``optimizer`` and ``train_follow`` sub-configs are converted with
            ``vars()``.

    Returns:
        The path of the created session directory.

    Raises:
        OSError: if a directory or the JSON file cannot be created (the error
            is printed, then re-raised).
    """
    try:
        follow = cfg.train_follow
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        session_dir = os.path.join(
            follow.model_folder,
            f"training_session_{timestamp}_{follow.suffix_model_folder}",
        )

        os.makedirs(session_dir, exist_ok=True)
        for subdir in ("checkpoints", "logs"):
            os.makedirs(os.path.join(session_dir, subdir), exist_ok=True)

        if isinstance(cfg, dict):
            cfg_dict = cfg
        else:
            try:
                # Copy the attribute dict: replacing entries in vars(cfg) itself
                # would silently rewrite the caller's config object.
                cfg_dict = dict(vars(cfg))
                cfg_dict['loss_function'] = cfg_dict['loss_function'].__name__
                for section in ('data', 'model', 'optimizer', 'train_follow'):
                    cfg_dict[section] = vars(cfg_dict[section])
            except TypeError:
                cfg_dict = {k: getattr(cfg, k) for k in dir(cfg) if not k.startswith('_')}

        config_path = os.path.join(session_dir, "training_config.json")
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(cfg_dict, f, indent=4)

        return session_dir

    except OSError as e:
        print(f"Error creating directory: {e}")
        raise

def move_to_device(batch: Dict, device: torch.device) -> Dict:
    """Return a new dict with every tensor value moved to ``device``.

    Non-tensor values are passed through unchanged and key order is kept.
    The return annotation used to say ``Tuple`` although a dict is returned.
    """
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in batch.items()
    }


### Training

In [None]:
def training(training_cfg: Dict[str, Any]):
    """Train ``HorseRacingModel`` from a nested configuration dict.

    Setup steps:
    1. seed the RNGs and wrap ``training_cfg`` in ``TrainingConfig``;
    2. build the data loader, model and optimizer;
    3. optionally restore model/optimizer state from ``checkpoint_path``;
    4. create a session directory.

    Then, for every batch of every epoch: forward, loss, backward, statistics,
    learning-rate schedule, periodic tasks, optimizer step.

    The batch index passed to the schedule, periodic tasks and gradient
    stepping is a global step. It starts at the checkpoint's batch index and
    keeps counting across epochs. Previously it restarted every epoch, which
    re-ran the LR warm-up and reused checkpoint numbers.

    Args:
        training_cfg: dict with keys 'optimizer', 'data', 'model',
            'training_follow', 'loss_function',
            'continue_training_from_checkpoint' and (when resuming)
            'checkpoint_path'.
    """
    set_seed(seed = 42)

    cfg = TrainingConfig(training_cfg)
    device = get_device()

    # Data loading (also records the dataset's feature list on cfg.data).
    dataloader = cfg.data.get_dataloader()

    # Model initialization: the model is configured with the dataset's features.
    cfg.model.input_features = cfg.data.get_datafeatures()
    model = HorseRacingModel(cfg.model).to(device)

    optimizer = cfg.optimizer.init_optimizer(model)

    if cfg.continue_training_from_checkpoint:
        # NOTE(review): the statistics returned with the checkpoint are not
        # restored, so the loss history (read by the plateau schedule) restarts empty.
        batch_idx_start, _prev_stats = load_training_checkpoint(
            checkpoint_path = cfg.checkpoint_path, model = model, optimizer = optimizer)
    else:
        batch_idx_start = 0
    stats = TrainingStats()

    criterion = cfg.loss_function

    # The session directory is created from a copy of the config.
    cfg.model_folder = create_directory_training_session(copy.deepcopy(cfg))

    global_step = batch_idx_start
    for epoch in range(cfg.data.n_epochs):
        model.train()
        model.horse_embeddings.reset_embeddings()

        for batch in tqdm(dataloader, desc="Training Epoch"):
            batch_idx = global_step
            start_time = time.time()

            batch = move_to_device(batch, device)

            # Forward pass
            outputs, nodes_depth = model(batch)
            loss = criterion(outputs, batch)

            # Backward pass.
            # NOTE(review): retain_graph=True keeps each batch's graph alive;
            # presumably required because stored horse embeddings from earlier
            # batches remain attached to the graph — confirm.
            loss.backward(retain_graph = True)

            stats.update(epoch, batch_idx, model, outputs, time.time() - start_time,
                         nodes_depth, cfg.optimizer.lr, batch, loss)

            cfg.optimizer.update_learning_rate(optimizer, stats, batch_idx)

            # Periodic plots / reports / checkpoints
            cfg.train_follow.handle_batch_operations(batch_idx, model, optimizer, stats,
                                                     cfg.optimizer, cfg.model_folder, loss)

            # Clip + step every n_steps_temporal_gradient_accumulation batches
            cfg.optimizer.handle_gradients(model, optimizer, batch_idx)

            global_step += 1


# --- Run configuration --------------------------------------------------------
# Nested dicts consumed by training(): TrainingConfig wraps 'optimizer', 'data',
# 'model' and 'training_follow' in their config classes.

# Data settings (DataConfig). batch_size / shuffle / num_workers / drop_last are
# passed to the DataLoader, and n_epochs drives the epoch loop in training().
# The remaining keys are presumably read by HorseRacingDataset (not visible here).
data_config = {
    'batch_size': 8,
    'num_workers': 2,
    'shuffle': False,
    'drop_last': True,
    'data_folder': "/content/drive/My Drive/HorseRacing/Horse riding/data/clean data/",
    'start_crid': 0,
    'max_crids': 100000000,
    'input_solution': False,
    'n_epochs': 1,
    'cols_races_to_drop': [],
    'cols_results_to_drop': [],
    # NOTE(review): presumably keeps market information out of the horse features — confirm.
    'cols_horses_to_drop': ['decimalPrice', 'isFav']
}
# Sequence-model settings. The model's __init__ picks 'Transformer' or 'LSTM'
# from 'sequence_model_type' and reads the matching *_config sub-dict;
# 'max_seq_depth' becomes the Transformer's max_seq_len. Presumably this dict
# reaches the model through model_cfg['sequence_model'] — confirm.
sequence_model_cfg = {
    'emb_dim': 512,
    'max_seq_depth': 12,
    'max_nodes': np.inf,
    'max_depth_nodes': np.inf,
    'input_embedding': False,
    'sequence_model_type': 'Transformer',
    'Transformer_config':{
            'n_heads': 8,
            'n_layers': 4,
            'dropout': 0.2,
            'multiple_ff': 4
            },
    'LSTM_config':{
            'n_layers': 4,
            'dropout': 0.2,
            # NOTE(review): sequence outputs are written into an emb_dim-wide
            # buffer, so hidden_dim presumably must equal emb_dim when the
            # LSTM is selected — confirm.
            'hidden_dim': 128
            }
    }

# Graph-network settings. 'layers_type' selects the GNN family; an unsupported
# value raises ValueError in the model's __init__.
graph_conv_model_cfg = {
    'emb_dim': 512,
    'dropout': 0.2,
    'n_layers': 4,
    'layers_type': 'GCN', # options: GCN, Transformer, GCNII
    'GCNII_config':{
            'lamda':0.5,
            'alpha': 0.1,
            'variant': True,
            'residual': True,
            'self_loop': True
            },
    'GNN_Transformer_config':{
            'n_heads': 8,
            'n_layers': 4,
            'dropout': 0.2,
            'multiple_ff': 4
            },
    'GCN_config':{
            'ff_layer' : True,
            'multiple_ff': 4
            }
          }

# Model section (ModelConfig). training() adds 'input_features' from the dataset.
model_cfg = {
    'emb_dim': 512,
    'logits_size': 1,  # width of final_projection's output
    'node_embedding_input': False,
    'sequence_model': sequence_model_cfg,
    'graph_conv_model': graph_conv_model_cfg,
}


# Optimizer and learning-rate schedule (OptimConfig).
optimizer_cfg = {
    'optimizer_name': 'Adam',  # 'Adam' or 'SGD'
    'Adam_cfg':
    {
        'betas': (0.9, 0.999),
        'eps': 1e-8,
        'weight_decay': 0.0001,
    },
    'SGD_cfg':
    {
        'momentum': 0.9,
        'weight_decay': 0.0001
    },
    'initial_lr': 0.0001,        # lr reached at the end of the warm-up
    'amsgrad': False,
    'patience': 1000,            # plateau: offset between compared loss windows and min batches between reductions
    'factor': 0.5,               # plateau: lr multiplier per reduction
    'min_lr': 1e-6,              # plateau: lower bound for the lr
    'window_size': 300,          # plateau: batches averaged per loss window
    'n_steps_temporal_gradient_accumulation': 4,  # clip + optimizer.step() every N batches
    'gradient_clipping_max_norm': 2,              # max total gradient norm for clip_grad_norm_
    'n_steps_cold_start': 300    # linear lr warm-up length, in batches
    }

# Cadences (in batches) for TrainingIllustration.handle_batch_operations.
training_follow_config = {
    'n_steps_gradient': 100,         # gradient report
    'n_steps_graph': 100000000,      # computation-graph render (effectively only at batch 0)
    'n_steps_report': 100000000,     # NOTE(review): not read by handle_batch_operations
    'n_steps_plot': 100,             # training plots
    'window_size': 50,               # passed to plot_training_stats; plots start after this many batches
    'n_steps_checkpoint': 100,       # checkpoint save (skipped at batch 0)
    'model_folder': "/content/drive/My Drive/HorseRacing/Horse riding/new_models/",  # parent of the session directory
    'suffix_model_folder': 'test_gcn_big_1_continue',  # appended to the session directory name
}

# Top-level configuration passed to training().
training_config = {
    "optimizer": optimizer_cfg,
    "loss_function": loss_function_first_horse_classification,  # called as criterion(outputs, batch)
    "data":data_config,
    "model":model_cfg,
    "training_follow": training_follow_config,
    "continue_training_from_checkpoint": True,  # restore model/optimizer from checkpoint_path
    "checkpoint_path": "/content/drive/My Drive/HorseRacing/Horse riding/new_models/training_session_20250407_130333_test_gcn_big_1/checkpoints/checkpoint_4400.pt",
}

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Training is not started automatically; uncomment the call below to run it.
# with torch.autograd.set_detect_anomaly(False):
  # training(training_config)

### TESTING

In [None]:


class TestingConfig:
    """Top-level container for a testing run.

    Every key of ``config`` becomes an attribute. The ``data`` and ``model``
    sections are then replaced by ``DataConfig`` / ``ModelConfig`` wrappers, and
    the ``testing_follow_config`` section is exposed as ``test_follow`` (a
    ``TestingIllustration``).
    """
    def __init__(self, config: Dict[str, Any]):
        vars(self).update(config)

        # Wrap the nested sections in their dedicated helper objects.
        self.data = DataConfig(config['data'])
        self.model = ModelConfig(config['model'])
        follow_section = config['testing_follow_config']
        self.test_follow = TestingIllustration(follow_section)

class DataConfig:
    """Data section of the configuration; also builds the DataLoader.

    Every key of ``config`` becomes an attribute. ``get_dataloader`` reads
    ``batch_size``, ``shuffle``, ``num_workers`` and ``drop_last``.
    """
    def __init__(self, config: Dict[str, Any]):
        vars(self).update(config)

    def get_dataloader(self):
        """Create the dataset and wrap it in a ``DataLoader``.

        Side effect: stores the dataset's ``features`` on ``self.features`` so
        ``get_datafeatures`` can return them afterwards.
        """
        horse_dataset = HorseRacingDataset(self)
        self.features = horse_dataset.features

        def collate_with_dataset(samples):
            # The project collate function takes the dataset alongside the samples.
            return collate_fn(samples, horse_dataset)

        loader_options = dict(
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            collate_fn=collate_with_dataset,
            num_workers=self.num_workers,
            drop_last=self.drop_last,
        )
        return DataLoader(horse_dataset, **loader_options)

    def get_datafeatures(self):
        """Return ``self.features``.

        Raises AttributeError unless ``get_dataloader`` has run or ``config``
        supplied a ``features`` key.
        """
        return self.features

class TestingStats:
    """Collect per-batch evaluation metrics and per-horse predictions during testing.

    Scalar metrics are appended to ``self.data`` (a ``defaultdict(list)`` keyed
    by metric name). Per-horse predictions are accumulated batch by batch and
    exposed through the ``df_outputs`` DataFrame, which ``save_results_df``
    writes to CSV.

    All betting metrics treat ``batch['prices']`` as the market's implied win
    probability: it is compared directly with the model's softmax
    probabilities, and a winning unit stake returns ``1 / price``.
    """

    # Columns of the prediction table; must match the frames built in
    # ``_record_predictions``.
    _OUTPUT_COLUMNS = ['hid', 'crid', 'position', 'decimalPrice', 'prediction']

    def __init__(self):
        self.data = defaultdict(list)
        self.metrics = [
            'loss', 'total_winnings', 'total_betted',
            'predicted_earnings', 'predicted_earnings_greedy',
            'hrn', 'accuracy', 'spearman', 'batch_time', 'n_nodes'
        ]
        # One DataFrame per batch, concatenated lazily by ``df_outputs``.
        # Concatenating into one growing frame on every batch would re-copy
        # the whole table each time (quadratic in the number of batches).
        self._output_chunks: List[pd.DataFrame] = []

    @property
    def df_outputs(self) -> pd.DataFrame:
        """All predictions collected so far, one row per slot with ``hid != -1``."""
        if not self._output_chunks:
            return pd.DataFrame(columns=self._OUTPUT_COLUMNS)
        if len(self._output_chunks) > 1:
            # Merge once and cache the result so repeated reads stay cheap.
            self._output_chunks = [pd.concat(self._output_chunks, ignore_index=True)]
        return self._output_chunks[0]

    @df_outputs.setter
    def df_outputs(self, value: pd.DataFrame) -> None:
        self._output_chunks = [value]

    def update(self, epoch: int, batch_idx: int, model: nn.Module, outputs: torch.Tensor,
                batch_time: float, nodes_depth: int, batch: Dict, loss: torch.Tensor):
        """Compute the metrics of one batch and append them to the history.

        Args:
            epoch: Epoch index stored with the metrics.
            batch_idx: Batch index stored with the metrics.
            model: Evaluated model; only forwarded to ``number_nodes_computation_graph``.
            outputs: Scores of shape ``[batch, horses, n_ranks]``; only rank
                column 0 is used.
            batch_time: Seconds spent on this batch.
            nodes_depth: Computation depth reported by the model.
            batch: Dict with ``positions``, ``prices``, ``hrn``, ``hids`` and
                ``crid`` tensors laid out like ``outputs[:, :, 0]``. Position
                ``-1`` marks a slot excluded from the softmax; ``hid == -1``
                marks a slot left out of the prediction table.
            loss: Scalar loss tensor of the batch.
        """
        with torch.no_grad():
            device = outputs.device
            positions_arrival = batch['positions']
            decimal_prices = batch['prices']
            hrn = batch['hrn']

            batch_size = outputs.shape[0]
            # Keep rank column 0 -> [batch, horses]. Indexing (instead of
            # squeeze(-1)) keeps the horse axis even with a single slot.
            outputs = outputs[:, :, 0]

            # Excluded horses get a huge negative logit, i.e. ~0 probability.
            mask_out = positions_arrival == -1.0
            outputs = outputs.masked_fill(mask_out, -1e9)
            probs = torch.softmax(outputs, dim=1)

            winner_mask = positions_arrival == 1.0
            valid_winner_races = winner_mask.any(dim=1)
            # Horses eligible for value/Kelly bets: not excluded (-1) and not 40
            # (presumably a non-runner / non-finisher code -- TODO confirm).
            valid_horses_mask = (positions_arrival != -1.0) & (positions_arrival != 40.0)
            race_idx = torch.arange(batch_size, device=device)

            if valid_winner_races.any():
                # Accuracy: the most probable horse is the actual winner.
                pred_winners = probs.argmax(dim=1)
                greedy_won = winner_mask[race_idx, pred_winners]
                accuracy = greedy_won[valid_winner_races].float().mean()

                # Gross return of a winning unit stake. It is always applied
                # through torch.where so that slots which are not winning bets
                # contribute exactly 0, even when their price is 0 (the 0/0 or
                # 0*inf would otherwise turn the sums into NaN).
                payout = 1.0 / decimal_prices

                # Probabilistic betting: stake each horse's predicted probability
                # (one unit per race in total).
                winnings_prob = (probs[winner_mask] / decimal_prices[winner_mask]).sum()
                total_betted_prob = batch_size

                # Value betting: one unit on every eligible horse whose predicted
                # probability beats the market's implied probability.
                value_bets_mask = (probs > decimal_prices) & valid_horses_mask
                won_value_bets = value_bets_mask & winner_mask
                winnings_value = torch.where(won_value_bets, payout, 0.0).sum()
                total_betted_value = value_bets_mask.sum().float()

                # Greedy betting: one unit on the most probable horse of each race.
                winnings_greedy = torch.where(greedy_won, payout[race_idx, pred_winners], 0.0).sum()
                total_betted_greedy = batch_size

                # Kelly betting on the value bets. A winning unit returns
                # ``payout``, so the net odds are b = payout - 1 and the Kelly
                # fraction is f* = (b*p - q) / b = (p * payout - 1) / b.
                # For value bets price < p <= 1, hence b > 0 wherever f* is kept.
                net_odds = payout - 1.0
                kelly_fractions = (probs * payout - 1.0) / net_odds
                kelly_fractions = torch.where(value_bets_mask, kelly_fractions, 0.0)
                kelly_fractions = torch.clamp(kelly_fractions, min=0.0, max=1.0)

                returns_kelly = torch.where(won_value_bets, kelly_fractions * payout, 0.0)
                total_betted_kelly = kelly_fractions.sum()
                if total_betted_kelly > 0:
                    roi_kelly = (returns_kelly.sum() - total_betted_kelly) / total_betted_kelly
                else:
                    # Must stay a tensor: ``.item()`` is called on it below.
                    roi_kelly = torch.tensor(0.0, device=device)
            else:
                accuracy = torch.tensor(0.0, device=device)

                winnings_prob = torch.tensor(0.0, device=device)
                total_betted_prob = torch.tensor(0.0, device=device)

                winnings_greedy = torch.tensor(0.0, device=device)
                total_betted_greedy = torch.tensor(0.0, device=device)

                total_betted_value = torch.tensor(0.0, device=device)
                winnings_value = torch.tensor(0.0, device=device)

                roi_kelly = torch.tensor(0.0, device=device)

            # Net return per unit staked for each betting strategy.
            predicted_earnings_prob = (winnings_prob - total_betted_prob) / (total_betted_prob + 1e-8)
            predicted_earnings_greedy = (winnings_greedy - total_betted_greedy) / (total_betted_greedy + 1e-8)
            predicted_earnings_value = (winnings_value - total_betted_value) / (total_betted_value + 1e-8)

            # Ranking correlation
            spearman = self.spearman_rank_correlation(outputs.detach().clone(), positions_arrival)
            n_nodes_computation_graph = self.number_nodes_computation_graph(model, loss)

            # Update statistics
            self.data['epoch'].append(epoch)
            self.data['batch'].append(batch_idx)
            self.data['batch_time'].append(batch_time)
            self.data['computation_depth'].append(nodes_depth)
            self.data['hrn'].append(hrn[hrn != -1].float().mean().item())
            self.data['n_nodes'].append(n_nodes_computation_graph)

            self.data['accuracy'].append(accuracy.item())
            self.data['predicted_earnings'].append(predicted_earnings_prob.item())
            self.data['predicted_earnings_greedy'].append(predicted_earnings_greedy.item())
            self.data['predicted_earnings_value'].append(predicted_earnings_value.item())
            self.data['roi_kelly'].append(roi_kelly.item())
            self.data['spearman'].append(spearman.item())
            # Horses actually in the race: slots NOT excluded with position -1.
            self.data['avg_race_participants'].append((~mask_out).sum().item() / batch_size)

            self.data['loss'].append(loss.item())

            self._record_predictions(batch, outputs)

    def _record_predictions(self, batch: Dict, outputs: torch.Tensor) -> None:
        """Append one row per slot with ``hid != -1`` to the prediction table."""
        def flat(tensor: torch.Tensor) -> np.ndarray:
            return tensor.detach().reshape(-1).cpu().numpy()

        hids = flat(batch['hids'])
        keep = hids != -1
        self._output_chunks.append(pd.DataFrame({
            'hid': hids[keep],
            'crid': flat(batch['crid'])[keep],
            'position': flat(batch['positions'])[keep],
            'decimalPrice': flat(batch['prices'])[keep],
            'prediction': flat(outputs)[keep],
        }))

    def save_results_df(self, folder: str, filename):
        """Write the prediction table to ``<folder>/df_predictions_<filename>.csv``."""
        path = os.path.join(folder, f"df_predictions_{filename}.csv")
        self.df_outputs.to_csv(path, index=False)
        print(f"Predictions saved to {path}")

    def number_nodes_computation_graph(self, model: nn.Module, loss: float) -> int:
        """Return the size of the autograd graph behind ``loss``.

        Currently a stub that always returns 0. The former torchviz-based count
        (``len(make_dot(loss, params=dict(model.named_parameters())).source)``)
        is disabled.
        """
        return 0

    def spearman_rank_correlation(self, logits: torch.Tensor,
                                  positions_arrival: torch.Tensor) -> torch.Tensor:
        """Mean per-race Spearman correlation between predicted ranks and positions.

        Entries with position -1 or 40 are ignored, and races with fewer than
        two remaining horses are skipped. Predicted rank 0 is the highest logit
        and position 1 is the winner, so a positive rho means a good ordering.

        Args:
            logits: ``[batch_size, num_horses]`` prediction scores.
            positions_arrival: ``[batch_size, num_horses]`` actual positions.

        Returns:
            Scalar tensor (default float dtype) on ``logits.device``. It is 0.0
            when no race qualifies; NaN or failed races count as 0.0.
        """
        device = logits.device
        correlations = []

        for i in range(logits.size(0)):
            mask = (positions_arrival[i] != -1.0) & (positions_arrival[i] != 40.0)
            race_logits = logits[i][mask].detach().cpu().numpy()
            race_positions = positions_arrival[i][mask].cpu().numpy()

            if len(race_logits) < 2:
                continue

            try:
                # Double argsort turns scores into ranks (higher logit -> rank 0).
                pred_ranks = (-race_logits).argsort().argsort()
                rho, _ = spearmanr(pred_ranks, race_positions)
                if np.isnan(rho):
                    rho = 0.0
            except Exception:  # best effort: a failing race counts as rho = 0
                rho = 0.0

            correlations.append(float(rho))

        if not correlations:
            return torch.tensor(0.0, device=device)
        # Python float -> default (float32) tensor; float64 is unsupported on MPS.
        return torch.tensor(float(np.mean(correlations)), device=device)

    def report(self, window_size: int = 100):
        """Print moving averages of the main metrics over the last ``window_size`` batches."""
        if len(self.data['loss']) < window_size:
            return
        print(f"Epoch: {self.data['epoch'][-1]}, Batch: {self.data['batch'][-1]}, Loss:{round(np.mean(self.data['loss'][-window_size:]),2)}, Betting:{round(np.mean(self.data['predicted_earnings'][-window_size:]),2)}, Greedy Betting:{round(np.mean(self.data['predicted_earnings_greedy'][-window_size:]),2)}, HRN:{round(np.mean(self.data['hrn'][-window_size:]),2)} ")


class ModelConfig:
    """Model section of the configuration.

    A plain attribute bag: every key of ``config`` becomes an attribute. Callers
    may add more later (e.g. ``input_features`` is set by ``testing``).
    """

    def __init__(self, config: Dict[str, Any]):
        vars(self).update(config)


class TestingIllustration:
    """Periodically plots the statistics gathered during testing.

    Every key of ``config`` is copied onto the instance.
    ``handle_batch_operations`` reads ``n_steps_plot`` and ``window_size``.
    """
    def __init__(self, config: Dict[str, Any]):
        self.__dict__.update(config)
        # (key in stats.data, plot title) pairs handed to plot_training_stats.
        self.metrics_to_plot = [
                                ('loss', 'Training Loss'),
                                ('predicted_earnings', 'Predicted Earnings (Probabilistic)'),
                                ('predicted_earnings_greedy', 'Predicted Earnings (Greedy)'),
                                ('predicted_earnings_value', 'Predicted Earnings (Value)'),
                                ('roi_kelly', 'ROI Kelly Betting Method'),
                                ('accuracy', 'Accuracy'),
                                ('hrn', 'HRN'),
                                ('spearman', 'Spearman Correlation'),
                                # ('computation_depth', 'Computation Depth'),
                                # ('n_nodes', 'Number of Nodes'),
                                ('batch_time', 'Batch Time'),
                                ('avg_race_participants', 'Number of participants per race')
                                ]

    def handle_batch_operations(self, batch_idx: int, model: nn.Module,
                                stats: "TestingStats"):
        """Plot the running statistics every ``n_steps_plot`` batches.

        Nothing is plotted until ``batch_idx`` exceeds ``window_size``.

        Args:
            batch_idx: Index of the batch just processed.
            model: Unused here.
            stats: Statistics collector (a ``TestingStats``); its ``data`` dict
                is what gets plotted.
        """
        if batch_idx % self.n_steps_plot == 0 and batch_idx > self.window_size:
            plot_training_stats(training_stats = stats.data,
                                metrics = self.metrics_to_plot,
                                window_size = self.window_size,
                                plots_dir = None,
                                config = None)


In [None]:
def list_files_in_folder(folder_path):
    """Return the full paths of the ``.pt`` files directly inside ``folder_path``.

    Sub-directories are ignored, even if their name ends in ``.pt``. Paths are
    sorted by file name in natural order (digit runs compare numerically), so
    e.g. ``checkpoint_900.pt`` comes before ``checkpoint_4400.pt``. This
    replaces the arbitrary order of ``os.listdir``. If the folder cannot be
    listed (missing, permission denied, ...), the error is printed and an
    empty list is returned.
    """
    import re

    def natural_key(path):
        # re.split with a capturing group puts digit runs at odd indices.
        parts = re.split(r'(\d+)', os.path.basename(path))
        return [int(tok) if i % 2 else tok for i, tok in enumerate(parts)]

    try:
        entries = os.listdir(folder_path)
    except OSError as e:
        print(f"An error occurred: {e}")
        return []

    pt_files = [
        os.path.join(folder_path, f)
        for f in entries
        if f.endswith('.pt') and os.path.isfile(os.path.join(folder_path, f))
    ]
    return sorted(pt_files, key=natural_key)



def testing(training_cfg: Dict[str, Any]):
    """Evaluate every ``.pt`` checkpoint found in ``training_cfg['checkpoints_path']``.

    For each checkpoint (see ``list_files_in_folder`` for the order):
    ``load_training_checkpoint`` is called with a single shared model instance.
    The model is then run in eval mode under ``torch.no_grad`` over the whole
    dataloader, with metrics collected in a fresh ``TestingStats``. Finally,
    the per-horse predictions are written to
    ``df_predictions_<checkpoint file name>.csv`` in the checkpoints folder.

    Args:
        training_cfg: Testing configuration dict (see ``TestingConfig``). It is
            read for ``checkpoints_path`` and ``loss_function``, plus the
            ``data`` / ``model`` / ``testing_follow_config`` sections.
    """
    set_seed(seed = 42)

    cfg = TestingConfig(training_cfg)
    device = get_device()

    list_checkpoints_to_test = list_files_in_folder(cfg.checkpoints_path)
    print('list ch', list_checkpoints_to_test)
    if not list_checkpoints_to_test:
        # Nothing to evaluate: skip building the dataloader and model.
        logging.getLogger(__name__).warning(
            "No .pt checkpoints found in %s; nothing to test.", cfg.checkpoints_path)
        return

    # Data loading (also populates cfg.data.features).
    dataloader = cfg.data.get_dataloader()

    # Model initialization; each checkpoint below is loaded into this instance.
    cfg.model.input_features = cfg.data.get_datafeatures()
    model = HorseRacingModel(cfg.model).to(device)

    # Loss function initialization
    criterion = cfg.loss_function

    for file_ckp in list_checkpoints_to_test:
        print(f'Starting to test: {file_ckp}')
        # The returned (step index, saved stats) are not needed for testing.
        load_training_checkpoint(checkpoint_path=file_ckp, model=model, optimizer=None)

        model.eval()
        # Reset the horse-embedding state before each checkpoint's pass.
        model.horse_embeddings.reset_embeddings()
        stats = TestingStats()
        filename = os.path.basename(file_ckp)
        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Testing {filename}")):
                start_time = time.time()

                batch = move_to_device(batch, device)

                # Forward pass
                outputs, nodes_depth = model(batch)
                loss = criterion(outputs, batch)

                # Update statistics
                stats.update(0, batch_idx, model, outputs, time.time() - start_time,
                             nodes_depth, batch, loss)

                # Batch operations (periodic plots)
                cfg.test_follow.handle_batch_operations(batch_idx, model, stats)
        stats.save_results_df(folder=cfg.checkpoints_path, filename=filename)


# ---------------------------------------------------------------------------
# Testing-run configuration, followed by the call that evaluates every
# checkpoint in `checkpoints_path`.
# ---------------------------------------------------------------------------

# Dataset / DataLoader settings (wrapped by DataConfig).
data_config = dict(
    batch_size=8,
    num_workers=2,
    shuffle=False,
    drop_last=True,
    data_folder="/content/drive/My Drive/HorseRacing/Horse riding/data/clean data/",
    start_crid=0,
    max_crids=50000,
    input_solution=False,
    n_epochs=1,
    cols_races_to_drop=[],
    cols_results_to_drop=[],
    cols_horses_to_drop=['decimalPrice', 'isFav'],
)

# Sequence-model sub-config; provides both a Transformer and an LSTM variant,
# selected by `sequence_model_type`.
sequence_model_cfg = dict(
    emb_dim=512,
    max_seq_depth=12,
    max_nodes=np.inf,
    max_depth_nodes=np.inf,
    input_embedding=False,
    sequence_model_type='Transformer',
    Transformer_config=dict(
        n_heads=8,
        n_layers=4,
        dropout=0.2,
        multiple_ff=4,
    ),
    LSTM_config=dict(
        n_layers=4,
        dropout=0.2,
        hidden_dim=128,
    ),
)

# Graph-convolution sub-config; `layers_type` selects one of the variants below.
graph_conv_model_cfg = dict(
    emb_dim=512,
    dropout=0.2,
    n_layers=4,
    layers_type='GCN',  # GCN, Transformer, GCNII
    GCNII_config=dict(
        lamda=0.5,
        alpha=0.1,
        variant=True,
        residual=True,
        self_loop=True,
    ),
    GNN_Transformer_config=dict(
        n_heads=8,
        n_layers=4,
        dropout=0.2,
        multiple_ff=4,
    ),
    GCN_config=dict(
        ff_layer=True,
        multiple_ff=4,
    ),
)

# Model hyper-parameters (wrapped by ModelConfig).
model_cfg = dict(
    emb_dim=512,
    logits_size=1,
    node_embedding_input=False,
    sequence_model=sequence_model_cfg,
    graph_conv_model=graph_conv_model_cfg,
)

# Monitoring settings; TestingIllustration reads n_steps_plot and window_size.
testing_follow_config = dict(
    n_steps_gradient=100,
    n_steps_graph=100000000,
    n_steps_report=100000000,
    n_steps_plot=1000,
    window_size=50,
    n_steps_checkpoint=100,
    model_folder="/content/drive/My Drive/HorseRacing/Horse riding/new_models/",
    suffix_model_folder='test_gcn_big_1_continue',
)

# Top-level testing configuration. `testing` scans `checkpoints_path` for .pt
# files; `continue_training_from_checkpoint`, `start_crid_id` and
# `end_crid_id` do not appear to be read by the testing code in this cell.
testing_config = dict(
    loss_function=loss_function_first_horse_classification,
    data=data_config,
    model=model_cfg,
    testing_follow_config=testing_follow_config,
    continue_training_from_checkpoint=True,
    checkpoints_path="/content/drive/My Drive/HorseRacing/Horse riding/new_models/training_session_20250407_130333_test_gcn_big_1/checkpoints",
    start_crid_id=0,
    end_crid_id=50000,
)


# Configure logging (a no-op if the root logger is already configured).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
with torch.autograd.set_detect_anomaly(False):
    testing(testing_config)