In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

In [None]:
# Lookup table from venue name to host country.
# The same ground appears under several spellings (with and without a city
# suffix, with and without punctuation), so each variant has its own key.
# preprocess_data() applies this table with Series.replace, which means a venue
# that is missing here keeps its raw venue string instead of a country name.
venue_country_dict = {
    'Harare Sports Club': 'Zimbabwe',
    'Green Park': 'India',
    'Vidarbha Cricket Association Stadium, Jamtha': 'India',
    'M Chinnaswamy Stadium': 'India',
    'Central Broward Regional Park Stadium Turf Ground': 'USA',
    'Sabina Park, Kingston': 'Jamaica',
    'R.Premadasa Stadium, Khettarama': 'Sri Lanka',
    'JSCA International Stadium Complex': 'India',
    'Barsapara Cricket Stadium': 'India',
    'Old Trafford': 'England',
    'Sophia Gardens': 'England',
    'County Ground': 'England',
    'Arun Jaitley Stadium': 'India',
    'Saurashtra Cricket Association Stadium': 'India',
    'Greenfield International Stadium': 'India',
    'The Wanderers Stadium': 'South Africa',
    'SuperSport Park': 'South Africa',
    'Newlands': 'South Africa',
    'Barabati Stadium': 'India',
    'Holkar Cricket Stadium': 'India',
    'Wankhede Stadium': 'India',
    'The Village, Malahide': 'Ireland',
    'Brisbane Cricket Ground, Woolloongabba': 'Australia',
    'Melbourne Cricket Ground': 'Australia',
    'Sydney Cricket Ground': 'Australia',
    'Westpac Stadium': 'New Zealand',
    'Eden Park': 'New Zealand',
    'Seddon Park': 'New Zealand',
    'Eden Gardens': 'India',
    'Bharat Ratna Shri Atal Bihari Vajpayee Ekana Cricket Stadium': 'India',
    'MA Chidambaram Stadium, Chepauk': 'India',
    'Dr. Y.S. Rajasekhara Reddy ACA-VDCA Cricket Stadium': 'India',
    'M.Chinnaswamy Stadium': 'India',
    'Punjab Cricket Association IS Bindra Stadium, Mohali': 'India',
    'Rajiv Gandhi International Stadium, Uppal': 'India',
    'Bay Oval': 'New Zealand',
    'Providence Stadium, Guyana': 'Guyana',
    'Maharashtra Cricket Association Stadium': 'India',
    'Manuka Oval': 'Australia',
    'Narendra Modi Stadium': 'India',
    'R Premadasa Stadium, Colombo': 'Sri Lanka',
    'Dubai International Cricket Stadium': 'UAE',
    'The Rose Bowl, Southampton': 'England',
    'Edgbaston, Birmingham': 'England',
    'Trent Bridge, Nottingham': 'England',
    'Sawai Mansingh Stadium, Jaipur': 'India',
    'JSCA International Stadium Complex, Ranchi': 'India',
    'Eden Gardens, Kolkata': 'India',
    'Bharat Ratna Shri Atal Bihari Vajpayee Ekana Cricket Stadium, Lucknow': 'India',
    'Himachal Pradesh Cricket Association Stadium, Dharamsala': 'India',
    'Arun Jaitley Stadium, Delhi': 'India',
    'Barabati Stadium, Cuttack': 'India',
    'Dr. Y.S. Rajasekhara Reddy ACA-VDCA Cricket Stadium, Visakhapatnam': 'India',
    'Saurashtra Cricket Association Stadium, Rajkot': 'India',
    'M Chinnaswamy Stadium, Bangalore': 'India',
    'Perth Stadium': 'Australia',
    'Adelaide Oval': 'Australia',
    'The Village, Malahide, Dublin': 'Ireland',
    'Brian Lara Stadium, Tarouba, Trinidad': 'Trinidad and Tobago',
    'Warner Park, Basseterre, St Kitts': 'Saint Kitts and Nevis',
    'Central Broward Regional Park Stadium Turf Ground, Lauderhill': 'USA',
    'Bay Oval, Mount Maunganui': 'New Zealand',
    'McLean Park, Napier': 'New Zealand',
    'Punjab Cricket Association IS Bindra Stadium, Mohali, Chandigarh': 'India',
    'Vidarbha Cricket Association Stadium, Jamtha, Nagpur': 'India',
    'Rajiv Gandhi International Stadium, Uppal, Hyderabad': 'India',
    'Greenfield International Stadium, Thiruvananthapuram': 'India',
    'Barsapara Cricket Stadium, Guwahati': 'India',
    'Holkar Cricket Stadium, Indore': 'India',
    'Wankhede Stadium, Mumbai': 'India',
    'Maharashtra Cricket Association Stadium, Pune': 'India',
    'Narendra Modi Stadium, Ahmedabad': 'India',
    'Malahide, Dublin': 'Ireland',
    "St George's Park, Gqeberha": 'South Africa',
    'The Wanderers Stadium, Johannesburg': 'South Africa',
    'Shaheed Veer Narayan Singh International Stadium, Raipur': 'India',
    'M Chinnaswamy Stadium, Bengaluru': 'India',
    'Zhejiang University of Technology Cricket Field': 'China',
    'Nassau County International Cricket Stadium, New York': 'USA',
    'Sir Vivian Richards Stadium, North Sound, Antigua': 'Antigua and Barbuda',
    'Daren Sammy National Cricket Stadium, Gros Islet, St Lucia': 'Saint Lucia',
    'Kensington Oval, Bridgetown, Barbados': 'Barbados',
    'Shrimant Madhavrao Scindia Cricket Stadium, Gwalior': 'India',
    'Pallekele International Cricket Stadium': 'Sri Lanka',
    'Kingsmead, Durban': 'South Africa',
    'SuperSport Park, Centurion': 'South Africa',
    'New Wanderers Stadium': 'South Africa',
    'Kingsmead': 'South Africa',
    'Brabourne Stadium': 'India',
    'Trent Bridge': 'England',
    "Lord's": 'England',
    'AMI Stadium': 'New Zealand',
    'R Premadasa Stadium': 'Sri Lanka',
    'Beausejour Stadium, Gros Islet': 'Saint Lucia',
    'Kensington Oval, Bridgetown': 'Barbados',
    'Punjab Cricket Association Stadium, Mohali': 'India',
    'Moses Mabhida Stadium': 'South Africa',
    "Queen's Park Oval, Port of Spain": 'Trinidad and Tobago',
    'Stadium Australia': 'Australia',
    'Subrata Roy Sahara Stadium': 'India',
    'Sardar Patel Stadium, Motera': 'India',
    'Edgbaston': 'England',
    'Shere Bangla National Stadium': 'Bangladesh',
    'Himachal Pradesh Cricket Association Stadium': 'India'
}

In [None]:
class EnhancedFeatureExtractor:
    """Derive momentum, venue and innings-state features for ball-by-ball data.

    The wrapped frame must provide ``match_id``, ``innings``, ``venue``,
    ``striker_id``, ``total_run``, ``is_wicket`` and ``total_balls``.  Rolling
    and cumulative features follow row order inside each (match_id, innings)
    group, so rows should already be in delivery order within an innings.
    """

    def __init__(self, df):
        # Kept by reference: extract_advanced_features() adds columns in place.
        self.df = df

    def extract_advanced_features(self):
        """Add the derived feature columns to ``self.df`` and return it.

        Columns added:
            runs_last_5_balls     -- mean of ``total_run`` over the last 5 rows
                                     of the same innings (fewer at the start).
            wickets_last_20_balls -- sum of ``is_wicket`` over the last 20 rows
                                     of the same innings.
            venue_score_impact    -- venue mean of ``total_run`` divided by the
                                     row-weighted mean of those venue means.
            partnership_runs      -- cumulative ``total_run`` per
                                     (match, innings, striker).  Despite the
                                     name this tracks one striker, not a pair.
            current_run_rate      -- cumulative runs / (``total_balls`` / 6).
            wickets_remaining     -- 10 minus cumulative ``is_wicket``.

        Returns:
            The same DataFrame object that was passed to the constructor.
        """
        df = self.df
        innings_keys = ['match_id', 'innings']
        runs_by_innings = df.groupby(innings_keys)['total_run']
        wickets_by_innings = df.groupby(innings_keys)['is_wicket']

        # transform() hands back values aligned to the original row index.
        # The previous groupby(...).rolling(...).values form produced values in
        # group-sorted order and assigned them positionally, which attached
        # them to the wrong rows whenever the frame was not ordered by
        # (match_id, innings) -- e.g. after sorting by start_date first.
        df['runs_last_5_balls'] = runs_by_innings.transform(
            lambda runs: runs.rolling(window=5, min_periods=1).mean()
        )
        df['wickets_last_20_balls'] = wickets_by_innings.transform(
            lambda wickets: wickets.rolling(window=20, min_periods=1).sum()
        )

        # Venue scoring tendency relative to the overall level.
        venue_avg_score = df.groupby('venue')['total_run'].transform('mean')
        df['venue_score_impact'] = venue_avg_score / venue_avg_score.mean()

        # Cumulative runs while a given striker is on strike in this innings.
        df['partnership_runs'] = (
            df.groupby(['match_id', 'innings', 'striker_id'])['total_run'].cumsum()
        )

        # Runs per over so far.  total_balls == 0 yields inf/NaN here; the
        # values are left as-is for the caller to clean.
        df['current_run_rate'] = runs_by_innings.cumsum() / (df['total_balls'] / 6)

        # Wickets in hand, counting the current delivery's dismissal.
        df['wickets_remaining'] = 10 - wickets_by_innings.cumsum()

        return df

In [None]:
def parse_season(season_str):
    """Convert a season label into a single integer year.

    Accepted inputs:
        * ``"2015"``      -> 2015
        * ``"2007/08"``   -> 2008 (split seasons map to the closing year;
                             a short suffix is interpreted as 20xx)
        * ``"2007/2008"`` -> 2008 (a full closing year is used as-is)
        * numeric values  -> ``int(value)``; a CSV column that pandas parsed
                             as numbers is already a year.

    Returns:
        The year as ``int``, or ``None`` when the value is missing or cannot
        be parsed.
    """
    if isinstance(season_str, bool):
        return None

    if not isinstance(season_str, str):
        # None -> TypeError, NaN -> ValueError, inf -> OverflowError.
        try:
            return int(season_str)
        except (TypeError, ValueError, OverflowError):
            return None

    parts = season_str.split('/')
    try:
        if len(parts) == 2:
            closing_year = int(parts[1])
            return closing_year if closing_year >= 1000 else 2000 + closing_year
        return int(parts[0])
    except ValueError:
        # Non-numeric text in either half of the label.
        return None

In [None]:
def create_wicket_flag(df: pd.DataFrame):
    """Add an integer ``is_wicket`` column (1 = a dismissal on this delivery).

    A row is flagged when ``player_dismissed`` or ``other_player_dismissed``
    holds a non-null value.  Either column may be absent; with neither present
    every row is flagged 0.  The frame is modified in place and returned.
    """
    df['is_wicket'] = 0

    dismissal_columns = [
        name
        for name in ('player_dismissed', 'other_player_dismissed')
        if name in df.columns
    ]
    if dismissal_columns:
        someone_out = df[dismissal_columns].notna().any(axis=1)
        df.loc[someone_out, 'is_wicket'] = 1

    return df

In [None]:
def load_and_add_recency(csv_path: str):
    """Load the ball-by-ball CSV and attach recency weights and derived features.

    Steps:
        1. Read the CSV, normalise ``season`` with ``parse_season`` and parse
           ``start_date`` (day-month-year; unparseable dates become NaT).
        2. Sort by date, match, innings and ball, then renumber the index.
        3. Compute ``days_since`` (days after the earliest match),
           ``recency_norm`` (scaled so the latest match is 1) and
           ``recency_weight`` = 1 + alpha * recency_norm, i.e. 1.0 .. 2.0.
        4. Add ``is_wicket`` and the EnhancedFeatureExtractor columns.

    Args:
        csv_path: Path of the ball-by-ball CSV file.

    Returns:
        The enriched DataFrame.
    """
    df = pd.read_csv(csv_path)
    df['season'] = df['season'].apply(parse_season)
    df['start_date'] = pd.to_datetime(df['start_date'], format='%d-%m-%Y', errors='coerce')
    df = (
        df.sort_values(by=['start_date', 'match_id', 'innings', 'total_balls'])
          .reset_index(drop=True)
    )

    earliest_date = df['start_date'].min()
    df['days_since'] = (df['start_date'] - earliest_date).dt.days
    match_min_days = df.groupby('match_id')['days_since'].transform('min')

    # max_days is 0 when every match shares one date and NaN when no date
    # could be parsed (or the file is empty); dividing by either would turn
    # every weight into NaN, which then propagates into the weighted loss.
    max_days = match_min_days.max()
    if pd.isna(max_days) or max_days <= 0:
        max_days = 1

    # Matches without a usable date fall back to the baseline weight (norm 0).
    df['recency_norm'] = (match_min_days / max_days).fillna(0.0)
    alpha = 1.0
    df['recency_weight'] = 1.0 + alpha * df['recency_norm']

    df = create_wicket_flag(df)
    feature_extractor = EnhancedFeatureExtractor(df)
    df = feature_extractor.extract_advanced_features()
    return df


In [None]:
def encode_categorical_columns(df: pd.DataFrame, cat_cols: list):
    """Replace each listed categorical column with integer label codes.

    Columns named in ``cat_cols`` but absent from ``df`` are skipped.  Missing
    values become the literal category "NaNPlaceholder" and every value is
    cast to ``str`` before a fresh ``LabelEncoder`` is fitted on the column.

    Returns:
        ``(df, encoders)`` -- the frame (modified in place) and a dict mapping
        column name to its fitted ``LabelEncoder``.
    """
    fitted_encoders = {}
    for column in cat_cols:
        if column not in df.columns:
            continue
        df[column] = df[column].fillna("NaNPlaceholder")
        encoder = LabelEncoder()
        as_text = df[column].astype(str)
        df[column] = encoder.fit_transform(as_text)
        fitted_encoders[column] = encoder
    return df, fitted_encoders


In [None]:
def build_feature_columns(df: pd.DataFrame):
    """Pick the model feature columns that actually exist in ``df``.

    Returns:
        ``(numeric_cols, cat_cols)`` -- two lists, each in the fixed candidate
        order below and restricted to names present in ``df.columns``.
    """
    numeric_candidates = (
        'season', 'innings', 'total_balls', 'runs_off_bat', 'extras',
        'wickets_fallen', 'current_score', 'runs_conceded', 'strike_rate',
        'run_rate', 'economy', 'home/away', 'current_run_rate', 'wickets_remaining',
        'bowler_id', 'striker_id', 'non_striker_id',
    )
    categorical_candidates = (
        'batting_team', 'bowling_team', 'venue', 'toss_winner',
        'toss_decision', 'day/night',
    )

    def _present(candidates):
        # Keep candidate order; drop names the frame does not have.
        return [name for name in candidates if name in df.columns]

    return _present(numeric_candidates), _present(categorical_candidates)

In [None]:
def preprocess_data(df: pd.DataFrame):
    """Prepare the ball-by-ball frame for modelling.

    Adds ``is_wicket``, ``city``, ``home/away`` and ``original_innings``,
    label-encodes the categorical feature columns, then cleans and
    standardises the numeric feature columns.

    Notes:
        * ``city`` actually holds the host *country* looked up from
          ``venue_country_dict`` (Caribbean hosts are merged into
          'West Indies'); venues missing from the table keep their venue name.
        * NOTE(review): the [-500, 500] clip applies to every numeric feature,
          including 'season' and the ``*_id`` columns when present, so values
          above 500 (e.g. four-digit years) all collapse to 500.
        * NOTE(review): 'current_run_rate' and 'wickets_remaining' are
          standardised along with the other numeric features; code that needs
          them in raw units must read them before calling this function.

    Returns:
        ``(df, feature_cols, encoders)`` where ``feature_cols`` is the numeric
        columns followed by the categorical columns.
    """
    # Dismissal flag.
    df = create_wicket_flag(df)

    # Host country per venue, with the Caribbean nations merged.
    caribbean_hosts = [
        'Guyana', 'Trinidad and Tobago', 'Barbados',
        'Saint Lucia', 'Saint Kitts and Nevis',
        'Antigua and Barbuda', 'Jamaica'
    ]
    host_country = df['venue'].replace(venue_country_dict)
    df['city'] = host_country.mask(host_country.isin(caribbean_hosts), 'West Indies')

    # 1 when the match is hosted in India, otherwise 0.
    df['home/away'] = (df['city'] == 'India').astype(int)

    # Keep the untouched innings label; 'innings' itself is standardised below.
    df['original_innings'] = df['innings']

    numeric_cols, cat_cols = build_feature_columns(df)
    df, encoders = encode_categorical_columns(df, cat_cols)

    # NaN -> 0, +/-inf -> 0, then clip to [-500, 500].
    cleaned_numeric = (
        df[numeric_cols]
        .fillna(0)
        .replace([np.inf, -np.inf], 0)
        .clip(lower=-500, upper=500)
    )

    # Zero mean / unit variance, fitted on the data passed in.
    df[numeric_cols] = StandardScaler().fit_transform(cleaned_numeric)

    return df, numeric_cols + cat_cols, encoders

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch-first sequence.

    Even feature indices carry ``sin(pos * w_i)`` and odd indices carry
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Feature dimension of the embeddings (odd values supported).
        max_len: Longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even feature index: ceil(d_model / 2) of them.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd slots, so for an odd d_model
        # the last frequency has no cosine partner and must be dropped;
        # assigning the full matrix raised a shape error.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Stored as (1, max_len, d_model); a buffer moves with the module and
        # is saved in the state dict, but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encoding for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape (batch_size, seq_len, d_model).
        """
        seq_len = x.size(1)
        pe = self.pe[:, :seq_len, :].to(x.device)
        return x + pe

In [None]:
class CustomCricketLoss(nn.Module):
    """Weighted mix of percentage error, Huber loss and directional penalties.

    Args:
        alpha: Coefficient of the absolute-percentage-error term.
        beta:  Coefficient of the Smooth-L1 (Huber) term.
        gamma: Multiplier applied to the shortfall when ``pred < target``.
        delta: Multiplier applied to the excess when ``pred > target``.
    """

    def __init__(self, alpha=0.7, beta=0.1, gamma=0.5, delta=0.6):
        super().__init__()
        # reduction='none' keeps one value per sample so it can be weighted.
        self.huber_loss = nn.SmoothL1Loss(reduction='none')
        self.alpha, self.beta = alpha, beta
        self.gamma, self.delta = gamma, delta

    def forward(self, pred, target, weight=None, run_rate=None, wickets_remaining=None):
        """Return the batch mean of the combined per-sample loss.

        Args:
            pred, target: Predicted and true final scores.
            weight: Optional per-sample weight multiplied into every term.
            run_rate, wickets_remaining: Optional context.  It is used only
                when BOTH are given: the under-prediction penalty is scaled by
                ``1 + run_rate / 6`` and the over-prediction penalty by
                ``1 + (10 - wickets_remaining) / 10``.
        """
        huber_term = self.huber_loss(pred, target)
        # Absolute percentage error (x100); the epsilon guards a zero target.
        pct_term = ((pred - target) / (target + 1e-8)).abs() * 100

        # One-sided errors: each is zero on the other side of the target.
        shortfall = torch.clamp(target - pred, min=0) * self.gamma
        excess = torch.clamp(pred - target, min=0) * self.delta

        if not (run_rate is None or wickets_remaining is None):
            shortfall.mul_(1 + run_rate / 6.0)
            excess.mul_(1 + (10 - wickets_remaining) / 10.0)

        def _weighted(term):
            return term if weight is None else term * weight

        per_sample = (
            self.alpha * _weighted(pct_term) +
            self.beta * _weighted(huber_term) +
            _weighted(shortfall) +
            _weighted(excess)
        )
        return per_sample.mean()

In [None]:
class TransformerHybridModel(nn.Module):
    """BiLSTM followed by a Transformer encoder, read out through a [CLS] slot.

    Pipeline: feature projection -> bidirectional LSTM -> prepend learnable
    [CLS] token -> sinusoidal positions -> Transformer encoder -> MLP head on
    the [CLS] position, giving one scalar per sequence.

    Args:
        input_size: Number of features per time step.
        hidden_size: Projection / LSTM width; the Transformer runs at twice
            this width because the LSTM is bidirectional.
        num_lstm_layers: Stacked LSTM layers (inter-layer dropout only when > 1).
        num_transformer_layers: Number of Transformer encoder layers.
        nhead: Attention heads per encoder layer.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, input_size, hidden_size=64, num_lstm_layers=2,
                 num_transformer_layers=1, nhead=8, dropout=0.5):
        super().__init__()
        model_dim = hidden_size * 2  # forward + backward LSTM states

        # Raw features -> hidden_size.
        self.input_projection = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Local sequence modelling.
        lstm_dropout = dropout if num_lstm_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_lstm_layers,
            batch_first=True,
            dropout=lstm_dropout,
            bidirectional=True
        )

        # Learnable summary token placed in front of every sequence.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, model_dim))

        self.positional_encoding = PositionalEncoding(model_dim)

        # Self-attention over the [CLS] slot plus all time steps.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=model_dim,
                nhead=nhead,
                dim_feedforward=hidden_size * 4,
                dropout=dropout,
                batch_first=True
            ),
            num_layers=num_transformer_layers
        )

        # [CLS] representation -> scalar prediction.
        self.output_net = nn.Sequential(
            nn.Linear(model_dim, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_size) to shape (batch,)."""
        projected = self.input_projection(x)        # (batch, seq_len, hidden)
        recurrent_out, _ = self.lstm(projected)     # (batch, seq_len, 2*hidden)

        # One copy of the [CLS] token per sample, placed at position 0.
        summary_slot = self.cls_token.expand(projected.size(0), -1, -1)
        tokens = torch.cat([summary_slot, recurrent_out], dim=1)  # seq_len + 1

        tokens = self.positional_encoding(tokens)
        attended = self.transformer_encoder(tokens)

        # Position 0 is the [CLS] slot; drop the trailing singleton dimension.
        return self.output_net(attended[:, 0, :]).squeeze(-1)

In [None]:
class CricketBallDatasetWithWeight(Dataset):
    """Dataset pairing each feature sequence with its target and loss context.

    Every item is ``(features, (final_score, weight, run_rate,
    wickets_remaining))`` with all five parts converted to float32 tensors.
    The five containers are indexed with the same ``idx``.
    """

    def __init__(self, X, y_final_score, weights, run_rates, wickets_remaining):
        self.X = X
        self.y_final_score = y_final_score
        self.weights = weights
        self.run_rates = run_rates
        self.wickets_remaining = wickets_remaining

    def __len__(self):
        # Number of sequences; the other containers are assumed to match.
        return len(self.X)

    def __getitem__(self, idx):
        def _float_tensor(container):
            return torch.tensor(container[idx], dtype=torch.float32)

        features = _float_tensor(self.X)
        target_and_context = tuple(
            _float_tensor(container)
            for container in (
                self.y_final_score,
                self.weights,
                self.run_rates,
                self.wickets_remaining,
            )
        )
        return features, target_and_context

In [None]:
def create_sequences_with_weight(df: pd.DataFrame, feature_cols: list, seq_length=30, runs_col='total_run',
                                 weight_col='recency_weight', step=6):
    """Cut every innings into fixed-length windows for final-score regression.

    For each (match_id, innings) group the deliveries are ordered by
    ``total_balls`` and a window of ``seq_length`` rows is slid along with a
    stride of ``step`` rows.  Every window of an innings shares the same
    target: the sum of ``runs_col`` over the whole innings.

    Args:
        df: Frame with ``match_id``, ``innings``, ``total_balls``,
            ``original_innings``, ``current_run_rate``, ``wickets_remaining``,
            ``runs_col``, ``weight_col`` and all ``feature_cols``.
        feature_cols: Columns forming the per-ball feature vector.
        seq_length: Rows per window.
        runs_col: Column summed to obtain the innings total.
        weight_col: Per-match weight column (first row of the group is used).
        step: Stride between window starts (default 6, i.e. one over).

    Returns:
        ``(X, y, weights, run_rates, wickets_remaining, meta)``.  ``X`` has
        shape (num_windows, seq_length, num_features); the next four are 1-D
        float32 arrays; ``meta`` is a list of
        ``(match_id, original_innings, start_idx)`` tuples.  With no windows
        the arrays are empty with shape (0,).

    Note:
        ``run_rates`` and ``wickets_remaining`` are read from ``df`` as-is.
        If ``df`` has already been through ``preprocess_data`` those columns
        are standardised values, not raw run rates / wicket counts.
    """
    X_list, y_final_score_list, weight_list = [], [], []
    run_rate_list, wickets_remaining_list, meta_list = [], [], []

    for (m_id, _), group_data in df.groupby(['match_id', 'innings'], sort=False):
        # Stable sort: deliveries sharing a ball number keep their incoming order.
        group_data = group_data.sort_values('total_balls', kind='stable')

        feats = group_data[feature_cols].values
        final_score = np.sum(group_data[runs_col].values)
        group_weight = group_data[weight_col].iloc[0]
        orig_inn = group_data['original_innings'].iloc[0]
        run_rates = group_data['current_run_rate'].values
        wickets_remaining = group_data['wickets_remaining'].values

        # A window [start, start + seq_length) fits while
        # start <= n_balls - seq_length, hence the "+ 1".  Without it the last
        # full window was dropped and an innings of exactly seq_length balls
        # produced no sequence at all.
        n_balls = len(group_data)
        for start_idx in range(0, n_balls - seq_length + 1, step):
            end_idx = start_idx + seq_length

            X_list.append(feats[start_idx:end_idx])
            y_final_score_list.append(final_score)
            weight_list.append(group_weight)

            # Loss context is taken at the last ball of the window.
            run_rate_list.append(run_rates[end_idx - 1])
            wickets_remaining_list.append(wickets_remaining[end_idx - 1])

            meta_list.append((m_id, orig_inn, start_idx))

    X_array = np.array(X_list, dtype=np.float32)
    y_final_score_array = np.array(y_final_score_list, dtype=np.float32)
    weight_array = np.array(weight_list, dtype=np.float32)
    run_rate_array = np.array(run_rate_list, dtype=np.float32)
    wickets_remaining_array = np.array(wickets_remaining_list, dtype=np.float32)

    return (X_array, y_final_score_array, weight_array, run_rate_array, wickets_remaining_array, meta_list)

In [None]:
def train_advanced_model(model, train_loader, val_loader=None, epochs=50, lr=1e-3, device='cpu',
                         max_grad_norm=1.0, patience=10):
    """Train ``model`` with CustomCricketLoss, AdamW and cosine warm restarts.

    Each batch is ``(X, (y, weight, run_rate, wickets_remaining))``.  When a
    validation loader yields batches, the mean validation loss drives early
    stopping, and the weights from the best validation epoch are restored
    before returning (previously the last-epoch weights were returned, i.e.
    ``patience`` epochs past the best model).

    Args:
        model: Module mapping an ``X`` batch to one prediction per sample.
        train_loader: Iterable of training batches.
        val_loader: Optional iterable of validation batches.
        epochs: Maximum number of epochs.
        lr: Initial learning rate.
        device: Device the model and batches are moved to.
        max_grad_norm: Gradient-norm clipping threshold.
        patience: Epochs without validation improvement before stopping.

    Returns:
        ``(model, train_losses, val_losses)``, or ``(None, None, None)`` if a
        training batch produces a NaN/Inf loss.
    """
    import copy  # local import: used only to snapshot the best weights

    def _to_device(batch):
        features, (target, weight, run_rate, wickets) = batch
        return tuple(t.to(device) for t in (features, target, weight, run_rate, wickets))

    model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.1)

    # Restart after 5 epochs, then after cycles that double in length.
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=5,
        T_mult=2,
        eta_min=1e-6
    )

    criterion = CustomCricketLoss(alpha=0.7, beta=0.1, gamma=0.5, delta=0.6)

    best_val_loss = float('inf')
    best_state = None
    patience_counter = 0
    train_losses, val_losses = [], []

    for epoch in range(epochs):
        model.train()
        epoch_train_losses = []

        for batch in train_loader:
            X_batch, y_batch, w_batch, rr_batch, wr_batch = _to_device(batch)

            optimizer.zero_grad()
            pred = model(X_batch)
            loss = criterion(pred, y_batch, w_batch, rr_batch, wr_batch)

            # Abort rather than propagate a non-finite loss into the weights.
            if not torch.isfinite(loss):
                print(f"NaN/Inf loss detected at epoch {epoch+1}")
                return None, None, None

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            epoch_train_losses.append(loss.item())

        scheduler.step()

        # An empty loader yields NaN explicitly instead of a numpy warning.
        avg_train_loss = np.mean(epoch_train_losses) if epoch_train_losses else float('nan')
        train_losses.append(avg_train_loss)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}")

        # "is not None" instead of truthiness: bool(DataLoader) calls len(),
        # which is unavailable for some loaders.  A loader that yields no
        # batches is treated like no validation at all.
        if val_loader is not None:
            model.eval()
            epoch_val_losses = []
            with torch.no_grad():
                for batch in val_loader:
                    X_val, y_val, w_val, rr_val, wr_val = _to_device(batch)
                    pred = model(X_val)
                    val_loss = criterion(pred, y_val, w_val, rr_val, wr_val)
                    epoch_val_losses.append(val_loss.item())

            if epoch_val_losses:
                avg_val_loss = np.mean(epoch_val_losses)
                val_losses.append(avg_val_loss)
                print(f"Val Loss: {avg_val_loss:.4f}")

                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    patience_counter = 0
                    # Deep copy: state_dict() returns live references.
                    best_state = copy.deepcopy(model.state_dict())
                else:
                    patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    return model, train_losses, val_losses

In [None]:
def load_espn_data(espn_file_path, match_id, innings):
    """Read the ESPN forecast workbook and return one innings of it.

    Args:
        espn_file_path: Path of an Excel file with at least the columns
            ``match_id``, ``innings``, ``over`` and ``espn_prediction``.
        match_id: Match to keep.
        innings: Innings to keep.

    Returns:
        DataFrame with columns ``over`` and ``espn_prediction`` (possibly empty).
    """
    forecasts = pd.read_excel(espn_file_path)
    is_requested = (forecasts['match_id'] == match_id) & (forecasts['innings'] == innings)
    return forecasts.loc[is_requested, ['over', 'espn_prediction']]

In [None]:
def calculate_over_wise_predictions(ball_indices, predictions, seq_length=30):
    """Reduce per-ball predictions to one prediction per over.

    A ball contributes only if its 0-based index is at least
    ``seq_length - 1`` and it is either the sixth ball of its over
    (``index % 6 == 5``) or equal to the last entry of ``ball_indices``.
    Overs are numbered from 1; a later qualifying ball overwrites an earlier
    one for the same over.

    Returns:
        dict mapping over number -> prediction.
    """
    first_usable_ball = seq_length - 1
    per_over = {}
    for position, forecast in zip(ball_indices, predictions):
        if position < first_usable_ball:
            continue
        closes_over = position % 6 == 5
        if closes_over or position == ball_indices[-1]:
            per_over[position // 6 + 1] = forecast
    return per_over

In [None]:
def calculate_rmse(actual_values, predicted_values):
    """Return the root-mean-squared error between two equal-length sequences."""
    mean_sq_error = mean_squared_error(actual_values, predicted_values)
    return np.sqrt(mean_sq_error)

In [None]:
def visualize_comparison(overs, model_preds, espn_preds, actual_final_score,
                         model_rmse, espn_rmse, comparison_rmse, match_id, innings):
    """Plot model vs ESPN per-over forecasts against the actual final score.

    Draws both forecast lines, a horizontal line at the actual score, a shaded
    band of +/- absolute error around each forecast, and three RMSE labels in
    the top-left corner.  The figure is written to 'comparison_plot.png' in the
    working directory, overwriting any previous file.

    NOTE(review): the figure is left open after saving, so calling this in a
    loop accumulates figures until they are closed or displayed.
    """
    model_line = np.array(model_preds)
    espn_line = np.array(espn_preds)

    fig, ax = plt.subplots(figsize=(14, 8))

    ax.plot(overs, model_preds, 'ro-', label='Model Predictions', linewidth=2, markersize=8)
    ax.plot(overs, espn_preds, 'bs-', label='ESPN Predictions', linewidth=2, markersize=8)
    ax.axhline(y=actual_final_score, color='green', linestyle='-', linewidth=3,
               label=f'Actual Final Score: {actual_final_score:.0f}')

    # Error bands: each forecast +/- its own absolute error.
    for line, colour in ((model_line, 'red'), (espn_line, 'blue')):
        abs_error = np.abs(line - actual_final_score)
        ax.fill_between(overs, line - abs_error, line + abs_error, color=colour, alpha=0.1)

    # RMSE labels stacked downwards from the top-left corner (axes coordinates).
    labels = (
        (0.97, f"Model RMSE: {model_rmse:.2f}"),
        (0.92, f"ESPN RMSE: {espn_rmse:.2f}"),
        (0.87, f"Model vs ESPN RMSE: {comparison_rmse:.2f}"),
    )
    for y_pos, text in labels:
        ax.text(0.02, y_pos, text, transform=ax.transAxes,
                fontsize=12, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    ax.set_xlabel('Over', fontsize=14)
    ax.set_ylabel('Predicted Final Score', fontsize=14)
    ax.set_title(f'Model vs ESPN Predictions - Match {match_id}, Innings {innings}', fontsize=16)
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=12)
    fig.savefig('comparison_plot.png')

In [None]:
def compare_with_espn(ball_indices_sorted, pred_sorted, actual_final_score, selected_match, selected_inning,
                      seq_length=30, espn_file_path=r"/content/Score.xlsx"):
    """Compare the model's per-over forecasts with ESPN's for one innings.

    Overs are compared only from over ``seq_length // 6`` onwards and only
    where both sources have a value.  On success a comparison plot is drawn
    via ``visualize_comparison``.

    Args:
        ball_indices_sorted: Ball indices matching ``pred_sorted``.
        pred_sorted: Model predictions, one per ball index.
        actual_final_score: True final score of the innings.
        selected_match: Match id to look up in the ESPN workbook.
        selected_inning: Innings to look up in the ESPN workbook.
        seq_length: Window length used by the model.
        espn_file_path: Location of the ESPN workbook; defaults to the path
            that used to be hard-coded here.

    Returns:
        ``(model_rmse, espn_rmse, comparison_rmse)``: model vs actual, ESPN vs
        actual, and model vs ESPN.  ``(None, None, None)`` when the workbook is
        missing, has no rows for this innings, shares no overs with the model,
        or any other error occurs (the reason is printed).
    """
    model_over_predictions = calculate_over_wise_predictions(ball_indices_sorted, pred_sorted, seq_length)
    try:
        espn_data = load_espn_data(espn_file_path, selected_match, selected_inning)
        if espn_data.empty:
            print(f"No ESPN data found for match {selected_match}, inning {selected_inning}")
            return None, None, None

        # First ESPN value per over, built in one pass instead of
        # re-filtering the frame for every over.
        first_espn_value = {}
        for over, value in zip(espn_data['over'], espn_data['espn_prediction']):
            first_espn_value.setdefault(over, value)

        starting_over = (seq_length // 6)
        matched_overs = []
        model_preds = []
        espn_preds = []
        for over in espn_data['over']:
            if over >= starting_over and over in model_over_predictions:
                matched_overs.append(over)
                model_preds.append(model_over_predictions[over])
                espn_preds.append(first_espn_value[over])

        if not matched_overs:
            print(f"No overlapping overs found for comparison in match {selected_match}, inning {selected_inning}")
            return None, None, None

        actual_series = [actual_final_score] * len(matched_overs)
        model_rmse = calculate_rmse(actual_series, model_preds)
        espn_rmse = calculate_rmse(actual_series, espn_preds)
        comparison_rmse = calculate_rmse(espn_preds, model_preds)
        visualize_comparison(matched_overs, model_preds, espn_preds, actual_final_score,
                            model_rmse, espn_rmse, comparison_rmse, selected_match, selected_inning)
        return model_rmse, espn_rmse, comparison_rmse
    except FileNotFoundError:
        print(f"ESPN data file not found at: {espn_file_path}")
        return None, None, None
    except Exception as e:
        # Best effort: the comparison is optional, so report and carry on.
        print(f"Error comparing with ESPN data: {e}")
        return None, None, None

In [None]:
def create_detailed_rmse_analysis(ball_indices_sorted, pred_sorted, actual_final_score, selected_match,
                                 selected_inning, espn_data=None, seq_length=30):
    """Report RMSE separately for the powerplay, middle and death overs.

    For every phase the model RMSE is computed over all predictions whose ball
    index is at least ``seq_length - 1`` and whose over lies in the phase.
    When ``espn_data`` (columns ``over`` and ``espn_prediction``) is given,
    the ESPN RMSE and the model-vs-ESPN RMSE are computed on the overs present
    in both sources.  A summary is printed.

    ``selected_match`` and ``selected_inning`` are accepted for interface
    compatibility and are not used in the computation.

    Returns:
        dict: phase name -> {'model_rmse', 'espn_rmse', 'comparison_rmse'};
        a value stays ``None`` when it could not be computed.
    """
    phases = {
        'Powerplay (1-6)': (1, 6),
        'Middle (7-15)': (7, 15),
        'Death (16-20)': (16, 20)
    }
    over_indices = [(idx // 6) + 1 for idx in ball_indices_sorted]
    phase_results = {phase: {'model_rmse': None, 'espn_rmse': None, 'comparison_rmse': None}
                     for phase in phases}

    for phase_name, (start_over, end_over) in phases.items():
        phase_indices = [i for i, over in enumerate(over_indices)
                         if start_over <= over <= end_over and ball_indices_sorted[i] >= seq_length - 1]
        if not phase_indices:
            continue

        phase_balls = [ball_indices_sorted[i] for i in phase_indices]
        phase_pred = [pred_sorted[i] for i in phase_indices]
        phase_results[phase_name]['model_rmse'] = calculate_rmse(
            [actual_final_score] * len(phase_pred), phase_pred
        )

        if espn_data is None:
            continue

        model_over_pred = calculate_over_wise_predictions(phase_balls, phase_pred, seq_length)
        espn_phase = espn_data[(espn_data['over'] >= start_over) & (espn_data['over'] <= end_over)]
        model_preds = []
        espn_preds = []
        for over in espn_phase['over']:
            if over in model_over_pred:
                model_preds.append(model_over_pred[over])
                espn_preds.append(espn_phase.loc[espn_phase['over'] == over, 'espn_prediction'].values[0])
        if model_preds:
            phase_results[phase_name]['espn_rmse'] = calculate_rmse(
                [actual_final_score] * len(espn_preds), espn_preds
            )
            phase_results[phase_name]['comparison_rmse'] = calculate_rmse(espn_preds, model_preds)

    def _fmt(value):
        # Compare with None explicitly: an RMSE of exactly 0.0 is a valid
        # result and used to be printed as "N/A" by a truthiness test.
        return f"{value:.2f}" if value is not None else "N/A"

    print("\n=== Detailed RMSE Analysis by Innings Phase ===")
    for phase, results in phase_results.items():
        print(f"\n{phase}:")
        print(f"  Model RMSE: {_fmt(results['model_rmse'])}")
        print(f"  ESPN RMSE: {_fmt(results['espn_rmse'])}")
        print(f"  Model vs ESPN RMSE: {_fmt(results['comparison_rmse'])}")
    return phase_results

In [None]:
def _save_loss_plot(train_losses, val_losses, match_id):
    """Plot per-epoch training/validation loss and save it as ``loss_plot_<match_id>.png``."""
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(train_losses, label='Training Loss')
        ax.plot(val_losses, label='Validation Loss')
        ax.set_title(f'Model Loss for match {match_id}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        fig.savefig(f'loss_plot_{match_id}.png')
    finally:
        # Always release the figure: this runs once per match inside a long loop.
        plt.close(fig)


def _save_prediction_plot(ball_idxs_sorted, preds_sorted, group, actual_score, match_id, seq_length):
    """Plot predictions against the actual final score, mark wickets, and save as
    ``prediction_plot_<match_id>.png``.

    ``group`` is the innings frame already sorted by ``total_balls``; wicket
    positions are row positions within that sorted frame.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.axhline(y=actual_score, linestyle='-', label='Actual Final Score')
        ax.plot(ball_idxs_sorted, preds_sorted, marker='x', linestyle='--',
                label='Predicted Final Score')

        # Map each ball index to the first position where it occurs, so the
        # wicket lookup below is O(1) instead of a list scan per wicket.
        first_pos = {}
        for pos, ball in enumerate(ball_idxs_sorted):
            first_pos.setdefault(ball, pos)

        # Row positions (within the sorted innings) of deliveries flagged as wickets.
        wicket_rows = np.flatnonzero(group['is_wicket'].to_numpy() == 1).tolist()
        # Only wickets that fall on a ball for which a prediction exists can be drawn.
        wk_x = [row for row in wicket_rows if row >= seq_length - 1 and row in first_pos]
        wk_p = [preds_sorted[first_pos[row]] for row in wk_x]
        if wk_x:
            ax.scatter(wk_x, wk_p, marker='o', s=80, label='Wicket')

        ax.set_xlabel('Ball Index')
        ax.set_ylabel('Final Score')
        ax.set_title(f'Pred vs Actual for match {match_id}')
        ax.legend()
        ax.grid(True)
        fig.savefig(f'prediction_plot_{match_id}.png')
    finally:
        plt.close(fig)


def main(csv_path=r"/content/ball_by_ball.csv", espn_path=r"/content/Score.xlsx",
         initial_match_id=None, seq_length=30, epochs=50):
    """Train on all earlier matches and predict each match from a starting id onwards.

    For every ``match_id`` greater than or equal to the starting id (in sorted
    order) a fresh ``TransformerHybridModel`` is trained on sequences built from
    all matches with a smaller ``match_id``, then used to predict the current
    match. Loss and prediction plots are written to the working directory, RMSE
    is printed, and an ESPN comparison is attempted on a best-effort basis.

    Args:
        csv_path: Path of the ball-by-ball CSV passed to ``load_and_add_recency``.
        espn_path: Path of the ESPN workbook passed to ``load_espn_data``.
        initial_match_id: First match to process. When ``None`` the value is
            read interactively with ``input()`` (the original behaviour).
        seq_length: Sequence length passed to ``create_sequences_with_weight``.
        epochs: Number of epochs passed to ``train_advanced_model``.

    Raises:
        ValueError: If the starting match id is not present in the data, or the
            interactive input cannot be parsed as an integer.
    """
    # 1. Load and preprocess full dataset
    df = load_and_add_recency(csv_path)
    df, feature_cols, _encoders = preprocess_data(df)

    # 2. Resolve the starting match_id (prompt only when not supplied)
    if initial_match_id is None:
        initial_match_id = int(input("Enter the starting match_id: "))

    # 3. Build sorted list of all match_ids >= the starting one
    all_ids = sorted(df['match_id'].unique())
    if initial_match_id not in all_ids:
        raise ValueError(f"match_id {initial_match_id} not found in data")
    to_process = all_ids[all_ids.index(initial_match_id):]

    # Device choice does not depend on the match, so decide it once.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 4. Loop over each match_id, training on all earlier matches and predicting on the current one
    for match_id in to_process:
        print(f"\n=== Processing match_id {match_id} ===")

        # 4a. Prepare test sequences for this match (only inputs and metadata are needed)
        last_match_df = df[df['match_id'] == match_id]
        X_last, _, _, _, _, meta_last = create_sequences_with_weight(
            last_match_df, feature_cols=feature_cols, seq_length=seq_length, runs_col='total_run'
        )
        if len(meta_last) == 0:
            print(f"No sequences for match {match_id}, skipping.")
            continue

        # 4b. Prepare train+val on all matches with id < current
        df_past = df[df['match_id'] < match_id]
        X_pv, y_pv, w_pv, rr_pv, wr_pv, meta_pv = create_sequences_with_weight(
            df_past, feature_cols=feature_cols, seq_length=seq_length, runs_col='total_run'
        )
        # train_test_split(test_size=0.2) needs at least 2 samples; with a single
        # sample it raises ValueError because the train set would be empty.
        if len(X_pv) < 2:
            print(f"Not enough past data before match {match_id}, skipping.")
            continue

        # 4c. Split into train & val (metadata halves are not needed afterwards)
        X_train, X_val, y_train, y_val, w_train, w_val, rr_train, rr_val, wr_train, wr_val, _, _ = \
            train_test_split(
                X_pv, y_pv, w_pv, rr_pv, wr_pv, meta_pv,
                test_size=0.2, random_state=42
            )

        # 4d. Create DataLoaders
        train_ds = CricketBallDatasetWithWeight(X_train, y_train, w_train, rr_train, wr_train)
        val_ds = CricketBallDatasetWithWeight(X_val, y_val, w_val, rr_val, wr_val)
        train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
        val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)

        # 4e. Instantiate and train a fresh model for this match
        model = TransformerHybridModel(input_size=len(feature_cols))
        model, train_losses, val_losses = train_advanced_model(
            model, train_loader, val_loader, epochs=epochs, device=device
        )

        # 4f. Plot & save training/validation loss
        _save_loss_plot(train_losses, val_losses, match_id)

        # 4g. Predict on the current match
        # Make sure the returned model lives on the same device as the inputs
        # (no-op if train_advanced_model already placed it there).
        model.to(device)
        model.eval()
        with torch.no_grad():
            X_tensor = torch.tensor(X_last, dtype=torch.float32).to(device)
            preds = model(X_tensor).cpu().numpy()
        # The code below formats each prediction as a scalar, so collapse a
        # trailing singleton dimension, e.g. (N, 1) -> (N,), if the model emits one.
        preds = np.atleast_1d(preds)
        if preds.ndim == 2 and preds.shape[1] == 1:
            preds = preds[:, 0]

        # 4h. Sort predictions by ball index
        # meta entries are indexed as m[1] (used as the innings id) and m[2]
        # (sequence start offset); the prediction belongs to the sequence's last ball.
        ball_idxs = [m[2] + seq_length - 1 for m in meta_last]
        order = sorted(range(len(ball_idxs)), key=lambda i: ball_idxs[i])
        ball_idxs_sorted = [ball_idxs[i] for i in order]
        preds_sorted = preds[order]

        # 4i. Print per-ball predictions
        print(f"\nPredictions for match {match_id}:")
        for b, p in zip(ball_idxs_sorted, preds_sorted):
            print(f"  Ball {b+1}: {p:.2f}")

        # 4j. Plot predictions vs actual + wickets
        # NOTE(review): only the innings of the first sequence is used for the
        # actual score. If create_sequences_with_weight yields sequences for more
        # than one innings, all of them are compared against this one score - confirm.
        innings = meta_last[0][1]
        group = df[(df['match_id'] == match_id) & (df['original_innings'] == innings)].sort_values('total_balls')
        actual_score = group['total_run'].sum()
        _save_prediction_plot(ball_idxs_sorted, preds_sorted, group, actual_score, match_id, seq_length)

        # 4k. Summary RMSE
        rmse = calculate_rmse([actual_score] * len(preds_sorted), preds_sorted)
        print(f"\nActual Final Score: {actual_score}")
        print(f"Predicted Final Score (last ball): {preds_sorted[-1]:.2f}")
        print(f"Model RMSE: {rmse:.2f}")

        # 4l. (Optional) Compare with ESPN if data exists
        # Best-effort: a failure here is reported and must not abort the remaining matches.
        try:
            espn_df = load_espn_data(espn_path, match_id, innings)
            if not espn_df.empty:
                print("\nComparing vs ESPN predictions:")
                m_rmse, e_rmse, c_rmse = compare_with_espn(
                    ball_idxs_sorted, preds_sorted, actual_score, match_id, innings, seq_length
                )
                print(f"  Model RMSE: {m_rmse:.2f}, ESPN RMSE: {e_rmse:.2f}, Comparison RMSE: {c_rmse:.2f}")
                create_detailed_rmse_analysis(
                    ball_idxs_sorted, preds_sorted, actual_score,
                    match_id, innings, espn_df, seq_length
                )
            else:
                print(f"No ESPN data for match {match_id}")
        except FileNotFoundError:
            print("ESPN data file not found.")
        except Exception as e:
            print(f"ESPN comparison error: {e}")

In [None]:
# Entry point: call main() only when this code is executed as the main module,
# not when it is imported from elsewhere.
if __name__ == "__main__":
    main()