In [1]:
# Core libraries
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Data processing and analysis
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import joblib
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, asdict, field



# Progress tracking and logging
from tqdm import tqdm
import logging

# Random seed
import random

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_dtype(torch.float32)

In [2]:

@dataclass
class DataConfig:
    seq_length: int = 3
    start_season: int = 2016
    train_ratio: float = 0.7
    valid_ratio: float = 0.2
    batch_size: int = 32
    min_innings: float = 200.0
    positions: List[str] = field(default_factory=lambda: ['C', '1B', '2B', '3B', 'SS', 'LF', 'CF', 'RF', 'DH'])

@dataclass
class PositionalMetrics:
    inn: float
    drs: float
    uzr: float
    oaa: float
    is_primary: bool = False

@dataclass
class SeasonSnapshot:
    season: int
    age: int
    metrics: Dict[str, PositionalMetrics]

@dataclass
class PlayerSequence:
    player_id: str
    history: List[SeasonSnapshot]
    target: Optional[SeasonSnapshot] = None

class FieldingDataset(Dataset):
    def __init__(self, sequences: List[PlayerSequence], positions: List[str]):
        self.sequences = sequences
        self.positions = positions
        
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        
        # Create tensors with explicit dtype
        history_tensor = torch.zeros(len(sequence.history), len(self.positions), 4, dtype=torch.float32)
        position_mask = torch.zeros(len(sequence.history), len(self.positions), dtype=torch.float32)
        
        # Fill history tensor
        for t, season in enumerate(sequence.history):
            for pos_idx, pos in enumerate(self.positions):
                if pos in season.metrics:
                    metrics = season.metrics[pos]
                    history_tensor[t, pos_idx] = torch.tensor([
                        metrics.inn, metrics.drs, metrics.uzr, metrics.oaa
                    ], dtype=torch.float32)
                    position_mask[t, pos_idx] = 1
        
        # Create target tensor if exists
        if sequence.target:
            target_tensor = torch.zeros(len(self.positions), 4, dtype=torch.float32)
            target_mask = torch.zeros(len(self.positions), dtype=torch.float32)
            
            for pos_idx, pos in enumerate(self.positions):
                if pos in sequence.target.metrics:
                    metrics = sequence.target.metrics[pos]
                    target_tensor[pos_idx] = torch.tensor([
                        metrics.inn, metrics.drs, metrics.uzr, metrics.oaa
                    ], dtype=torch.float32)
                    target_mask[pos_idx] = 1
        else:
            target_tensor = torch.zeros(len(self.positions), 4, dtype=torch.float32)
            target_mask = torch.zeros(len(self.positions), dtype=torch.float32)
            
        return {
            'history': history_tensor,
            'history_mask': position_mask,
            'target': target_tensor,
            'target_mask': target_mask,
            'age': torch.tensor([s.age for s in sequence.history], dtype=torch.float32),
            'player_id': sequence.player_id
        }

def prepare_sequences(df: pd.DataFrame, config: DataConfig) -> List[PlayerSequence]:
    sequences = []
    
    # Filter required columns first
    required_columns = ['Inn', 'DRS', 'UZR/150', 'OAA', 'Pos', 'IDfg', 'Season', 'Age']
    df = df[
        df[required_columns].notna().all(axis=1) &  # Remove rows with any NaN
        (df['Inn'] > 0) &  # Remove invalid innings
        df['Inn'].notna()  # Ensure innings are valid
    ]
    
    # Replace infinities and verify no NaN values remain
    df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=['Inn', 'DRS', 'UZR/150', 'OAA'])
    
    # Log data quality
    logger.info(f"Total valid rows after cleaning: {len(df)}")
    logger.info("\nValue ranges:")
    for col in ['Inn', 'DRS', 'UZR/150', 'OAA', 'Age']:
        logger.info(f"{col}: {df[col].min():.1f} to {df[col].max():.1f}")
    
    for player_id, player_data in df.groupby('IDfg'):
        player_data = player_data.sort_values('Season')
        seasons = []
        
        # Group by season
        for season, season_data in player_data.groupby('Season'):
            # Get primary position (most innings)
            primary_pos = season_data.loc[season_data['Inn'].idxmax(), 'Pos']
            
            # Create position metrics dictionary
            position_metrics = {}
            for _, row in season_data.iterrows():
                if row['Inn'] >= config.min_innings:
                    position_metrics[row['Pos']] = PositionalMetrics(
                        inn=float(row['Inn']),  # Ensure float type
                        drs=float(row['DRS']),
                        uzr=float(row['UZR/150']),
                        oaa=float(row['OAA']),
                        is_primary=(row['Pos'] == primary_pos)
                    )
            
            if position_metrics:  # Only add season if we have valid positions
                seasons.append(SeasonSnapshot(
                    season=int(season),  # Ensure int type
                    age=int(season_data['Age'].iloc[0]),
                    metrics=position_metrics
                ))
        
        # Create sequences with sliding window
        for i in range(len(seasons) - config.seq_length):
            history = seasons[i:i+config.seq_length]
            target = seasons[i+config.seq_length]
            sequences.append(PlayerSequence(
                player_id=str(player_id),  # Ensure string type
                history=history,
                target=target
            ))
    
    return sequences
# Load and process data
config = DataConfig()

# Read data
fielding_df = pd.read_csv('../data/mlb_fielding_data_2000_2024.csv')
batting_df = pd.read_csv('../data/mlb_batting_data_2010_2024.csv')
pitching_df = pd.read_csv('../data/mlb_pitching_data_2000_2024.csv')

# Merge age information
age_lookup = pd.concat([
    batting_df[['IDfg', 'Season', 'Age']],
    pitching_df[['IDfg', 'Season', 'Age']]
]).drop_duplicates()

df = fielding_df.merge(
    age_lookup,
    on=['IDfg', 'Season'],
    how='left'
)

# Create sequences
sequences = prepare_sequences(df, config)

# Split into train/valid/test
total_size = len(sequences)
train_size = int(total_size * config.train_ratio)
valid_size = int(total_size * config.valid_ratio)

indices = np.random.permutation(total_size)
train_indices = indices[:train_size]
valid_indices = indices[train_size:train_size + valid_size]
test_indices = indices[train_size + valid_size:]

train_sequences = [sequences[i] for i in train_indices]
valid_sequences = [sequences[i] for i in valid_indices]
test_sequences = [sequences[i] for i in test_indices]

# Create datasets
train_dataset = FieldingDataset(train_sequences, config.positions)
valid_dataset = FieldingDataset(valid_sequences, config.positions)
test_dataset = FieldingDataset(test_sequences, config.positions)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size)

# Print dataset sizes and inspect a batch
print(f"Train size: {len(train_dataset)}")
print(f"Valid size: {len(valid_dataset)}")
print(f"Test size: {len(test_dataset)}")

sample_batch = next(iter(train_loader))
print("\nBatch shapes:")
for k, v in sample_batch.items():
    if isinstance(v, torch.Tensor):
        print(f"{k}: {v.shape}")

INFO:__main__:Total valid rows after cleaning: 10453
INFO:__main__:
Value ranges:
INFO:__main__:Inn: 0.1 to 1443.1
INFO:__main__:DRS: -27.0 to 41.0
INFO:__main__:UZR/150: -259.2 to 222.6
INFO:__main__:OAA: -23.0 to 35.0
INFO:__main__:Age: 19.0 to 44.0


Train size: 677
Valid size: 193
Test size: 98

Batch shapes:
history: torch.Size([32, 3, 9, 4])
history_mask: torch.Size([32, 3, 9])
target: torch.Size([32, 9, 4])
target_mask: torch.Size([32, 9])
age: torch.Size([32, 3])


Define metrics model

In [3]:
class DefensiveMetricsPredictor(nn.Module):
    def __init__(self,
                 n_positions: int = 9,
                 n_metrics: int = 4,
                 d_model: int = 128,
                 n_heads: int = 8,
                 n_layers: int = 4,
                 d_ff: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        
        # Position embeddings
        self.position_embeddings = nn.Parameter(torch.randn(n_positions, d_model))
        
        # Metric embeddings
        self.metric_embeddings = nn.Parameter(torch.randn(n_metrics, d_model))
        
        # Input projection for each metric type
        self.metric_projections = nn.ModuleList([
            nn.Linear(1, d_model) for _ in range(n_metrics)
        ])
        
        # Age embedding
        self.age_embedding = nn.Sequential(
            nn.Linear(1, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )
        
        # Position-wise feed-forward
        self.position_ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        
        # Cross-position attention layers
        self.cross_position_layers = nn.ModuleList([
            nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
            for _ in range(n_layers)
        ])
        
        # Temporal attention layers
        self.temporal_layers = nn.ModuleList([
            nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
            for _ in range(n_layers)
        ])
        
        # Layer norms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        # Output projections for each metric
        self.metric_predictions = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_model),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(d_model, 1)
            ) for _ in range(n_metrics)
        ])
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor, mask: torch.Tensor, age: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch_size, seq_len, n_positions, n_metrics]
            mask: [batch_size, seq_len, n_positions]
            age: [batch_size, seq_len]
        Returns:
            output: [batch_size, n_positions, n_metrics]
        """
        # Add shape assertions
        batch_size, seq_len, n_positions, n_metrics = x.shape
        assert seq_len == 3, f"Expected seq_length 3, got {seq_len}"
        assert n_positions == 9, f"Expected 9 positions, got {n_positions}"
        assert n_metrics == 4, f"Expected 4 metrics, got {n_metrics}"
    
        
        # Process each metric separately and combine
        metric_encodings = []
        for i in range(n_metrics):
            metric_data = x[..., i:i+1]  # [batch, seq, pos, 1]
            metric_proj = self.metric_projections[i](metric_data)  # [batch, seq, pos, d_model]
            metric_encodings.append(metric_proj)
        
        # Combine metric encodings
        position_states = torch.stack(metric_encodings, dim=-2)  # [batch, seq, pos, n_metrics, d_model]
        position_states = position_states.mean(dim=-2)  # [batch, seq, pos, d_model]
        
        # Add position embeddings
        position_states = position_states + self.position_embeddings.unsqueeze(0).unsqueeze(0)
        
        # Process age information
        age_embed = self.age_embedding(age.unsqueeze(-1))  # [batch, seq, d_model]
        
        # Cross-position attention
        for layer in self.cross_position_layers:
            # Reshape for attention
            states = position_states.view(batch_size * seq_len, n_positions, -1)
            states = layer(states, states, states, 
                         key_padding_mask=mask.view(-1, n_positions))[0]
            position_states = states.view(batch_size, seq_len, n_positions, -1)
        
        # Temporal attention with age awareness
        position_states = position_states.transpose(1, 2)  # [batch, pos, seq, d_model]
        for layer in self.temporal_layers:
            temporal_states = []
            for pos in range(n_positions):
                pos_states = position_states[:, pos]  # [batch, seq, d_model]
                pos_states = pos_states + age_embed  # Add age information
                pos_states = layer(pos_states, pos_states, pos_states)[0]
                temporal_states.append(pos_states)
            position_states = torch.stack(temporal_states, dim=1)
        
        predictions = []
        for i in range(n_metrics):
            metric_pred = self.metric_predictions[i](position_states[:, :, -1])  # [batch, n_positions]
            predictions.append(metric_pred.unsqueeze(-1))  # [batch, n_positions, 1]
        
        # Stack along last dimension to get [batch, n_positions, n_metrics]
        output = torch.cat(predictions, dim=-1)
        
        return output

Define positional predictor model

In [4]:
class PositionTransitionModel(nn.Module):
    def __init__(self,
                 n_positions: int = 9,
                 n_metrics: int = 4,
                 d_model: int = 128,
                 n_heads: int = 8,
                 n_layers: int = 4,
                 d_ff: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        
        # Position hierarchy embeddings (e.g., SS->2B more likely than SS->1B)
        self.position_hierarchy = nn.Parameter(torch.randn(n_positions, n_positions))
        
        # Position embeddings
        self.position_embeddings = nn.Parameter(torch.randn(n_positions, d_model))
        
        # Age-based transition embeddings
        self.age_transition = nn.Sequential(
            nn.Linear(1, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_positions * n_positions)
        )
        
        # Metric encoders
        self.metric_encoders = nn.ModuleList([
            nn.Sequential(
                nn.Linear(1, d_model),
                nn.ReLU(),
                nn.Dropout(dropout)
            ) for _ in range(n_metrics)
        ])
        
        # Position-specific metric attention
        self.metric_attention = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        
        # Historical position encoding
        self.history_encoder = nn.GRU(
            input_size=n_positions,
            hidden_size=d_model,
            num_layers=2,
            batch_first=True,
            dropout=dropout
        )
        
        # Position transition layers
        self.transition_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=n_heads,
                dim_feedforward=d_ff,
                dropout=dropout,
                batch_first=True
            ) for _ in range(n_layers)
        ])
        
        # Output projection
        self.position_predictor = nn.Sequential(
            nn.Linear(d_model * 3, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_positions)
        )
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, 
                defensive_metrics: torch.Tensor,  # [batch, n_positions, n_metrics]
                position_history: torch.Tensor,   # [batch, seq_len, n_positions]
                age: torch.Tensor,               # [batch, 1]
                current_position: torch.Tensor    # [batch, n_positions] (one-hot)
                ) -> torch.Tensor:
        
        batch_size = defensive_metrics.shape[0]
        
        # 1. Process defensive metrics
        metric_embeddings = []
        for i in range(defensive_metrics.shape[-1]):
            metric = defensive_metrics[..., i:i+1]
            metric_embed = self.metric_encoders[i](metric)
            metric_embeddings.append(metric_embed)
        
        metric_tensor = torch.stack(metric_embeddings, dim=-2)
        
        # 2. Process position history
        history_encoding, _ = self.history_encoder(position_history)
        last_state = history_encoding[:, -1]
        
        # 3. Age-based transition probabilities
        age_transitions = self.age_transition(age)
        age_transitions = age_transitions.view(batch_size, self.position_hierarchy.shape[0], -1)
        
        # 4. Current position context
        current_embed = torch.matmul(current_position, self.position_embeddings)
        
        # 5. Apply metric attention for each position
        position_metrics = []
        for pos in range(defensive_metrics.shape[1]):
            pos_embed = self.position_embeddings[pos:pos+1].expand(batch_size, -1)
            attended_metrics, _ = self.metric_attention(
                pos_embed.unsqueeze(1),
                metric_tensor,
                metric_tensor
            )
            position_metrics.append(attended_metrics.squeeze(1))
        position_metrics = torch.stack(position_metrics, dim=1)
        
        # 6. Combine all information
        combined_state = torch.cat([
            current_embed,
            last_state,
            position_metrics.mean(dim=1)
        ], dim=-1)
        
        # 7. Generate transition probabilities
        logits = self.position_predictor(combined_state)
        
        # 8. Apply position hierarchy constraints
        hierarchy_weights = F.softmax(self.position_hierarchy, dim=-1)
        current_pos_idx = current_position.argmax(dim=-1)
        transition_constraints = hierarchy_weights[current_pos_idx]
        
        # 9. Final position probabilities
        position_probs = F.softmax(logits * transition_constraints, dim=-1)
        
        return position_probs

Define our custom loss function for defense

In [5]:
class DefensiveMetricsLoss(nn.Module):
    def __init__(self, metric_weights: Optional[Dict[str, float]] = None):
        super().__init__()
        self.metric_weights = metric_weights or {
            'inn': 0.5,
            'drs': 1.0,
            'uzr': 1.0,
            'oaa': 1.0
        }

    def forward(self, pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: [batch_size, n_positions, n_metrics]
            target: [batch_size, n_positions, n_metrics]
            mask: [batch_size, n_positions]
        """
        # Add assertions to catch shape mismatches
        assert pred.dim() == 3, f"Expected 3D tensor, got shape {pred.shape}"
        assert pred.shape == target.shape, f"Shape mismatch: pred {pred.shape} vs target {target.shape}"
        assert mask.shape == pred.shape[:2], f"Mask shape {mask.shape} doesn't match pred {pred.shape[:2]}"
        
        # Expand mask for broadcasting
        mask = mask.unsqueeze(-1)  # [batch_size, n_positions, 1]
        
        total_loss = 0.0
        total_weight = sum(self.metric_weights.values())
        
        for i, weight in enumerate(self.metric_weights.values()):
            # Calculate squared error for this metric
            diff = (pred[..., i] - target[..., i]) ** 2  # [batch_size, n_positions]
            
            # Apply mask and normalize by number of valid positions
            masked_diff = diff * mask.squeeze(-1)
            valid_positions = mask.squeeze(-1).sum(dim=1) + 1e-8  # [batch_size]
            
            # Calculate mean loss per batch item
            metric_loss = (masked_diff.sum(dim=1) / valid_positions).mean()
            
            # Add weighted contribution
            total_loss += (weight / total_weight) * metric_loss
        
        return total_loss

In [6]:
class PositionTransitionLoss(nn.Module):
    def __init__(self, position_weights: Optional[torch.Tensor] = None):
        super().__init__()
        self.position_weights = position_weights
        
    def forward(self, pred_probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred_probs: [batch, n_positions] - Predicted position probabilities
            target: [batch, n_positions] - One-hot encoded target positions
        """
        # Cross entropy with position weights if provided
        if self.position_weights is not None:
            loss = F.cross_entropy(pred_probs, target.argmax(dim=-1), 
                                 weight=self.position_weights)
        else:
            loss = F.cross_entropy(pred_probs, target.argmax(dim=-1))
        
        # Add regularization to prevent overconfident predictions
        entropy_reg = -torch.mean(torch.sum(pred_probs * torch.log(pred_probs + 1e-8), dim=-1))
        
        return loss - 0.1 * entropy_reg  # Small entropy regularization weight

Define training pipeline

In [7]:
class Trainer:
    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        valid_loader: DataLoader,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler,
        device: torch.device,
        patience: int = 10,
        max_grad_norm: float = 1.0,
        checkpoint_dir: str = 'checkpoints'
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.patience = patience
        self.max_grad_norm = max_grad_norm
        self.checkpoint_dir = checkpoint_dir
        
        os.makedirs(checkpoint_dir, exist_ok=True)
        
    def train_epoch(self) -> float:
        self.model.train()
        total_loss = 0.0
        
        with tqdm(self.train_loader, desc='Training') as pbar:
            for x, mask, weights, y in pbar:
                x = x.to(self.device)
                mask = mask.to(self.device)
                weights = weights.to(self.device)
                y = y.to(self.device)
                
                self.optimizer.zero_grad()
                output = self.model(x, mask, weights)
                loss = self.criterion(output, y, weights)
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.optimizer.step()
                
                total_loss += loss.item()
                pbar.set_postfix({'loss': loss.item()})
                
        return total_loss / len(self.train_loader)
    
    def validate(self) -> float:
        self.model.eval()
        total_loss = 0.0
        
        with torch.no_grad():
            for batch in self.valid_loader:
                # Unpack position-specific data correctly
                history = batch['history'].to(self.device)
                history_mask = batch['history_mask'].to(self.device)
                target = batch['target'].to(self.device)
                target_mask = batch['target_mask'].to(self.device)
                age = batch['age'].to(self.device)
                
                output = self.model(history, history_mask, age)
                loss = self.criterion(output, target, target_mask)
                total_loss += loss.item()
                
        return total_loss / len(self.valid_loader)
    
    def train(self, num_epochs: int):
        best_valid_loss = float('inf')
        patience_counter = 0
        
        for epoch in range(num_epochs):
            train_loss = self.train_epoch()
            valid_loss = self.validate()
            self.scheduler.step()
            
            logger.info(f'Epoch {epoch+1}/{num_epochs}:')
            logger.info(f'Train Loss: {train_loss:.4f}')
            logger.info(f'Valid Loss: {valid_loss:.4f}')
            
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                patience_counter = 0
                self.save_checkpoint(epoch, valid_loss)
            else:
                patience_counter += 1
                
            if patience_counter >= self.patience:
                logger.info(f'Early stopping triggered after {epoch+1} epochs')
                break
                
    def save_checkpoint(self, epoch: int, valid_loss: float):
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'valid_loss': valid_loss
        }
        path = os.path.join(self.checkpoint_dir, f'model_epoch_{epoch}_loss_{valid_loss:.4f}.pt')
        torch.save(checkpoint, path)
        logger.info(f'Checkpoint saved: {path}')

In [8]:
class DefensiveMetricsTrainer(Trainer):
    def train_epoch(self) -> float:
        self.model.train()
        total_loss = 0.0
        
        with tqdm(self.train_loader, desc='Training Metrics') as pbar:
            for batch in pbar:
                # Unpack position-specific data
                history = batch['history'].to(self.device)
                history_mask = batch['history_mask'].to(self.device)
                target = batch['target'].to(self.device)
                target_mask = batch['target_mask'].to(self.device)
                age = batch['age'].to(self.device)
                
                self.optimizer.zero_grad()
                output = self.model(history, history_mask, age)
                loss = self.criterion(output, target, target_mask)
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.optimizer.step()
                
                total_loss += loss.item()
                pbar.set_postfix({'loss': loss.item()})
                
        return total_loss / len(self.train_loader)

class PositionTransitionTrainer(Trainer):
    def __init__(self, metrics_model: nn.Module, **kwargs):
        super().__init__(**kwargs)
        self.metrics_model = metrics_model  # Store metrics model
    def train_epoch(self) -> float:
        self.model.train()
        total_loss = 0.0
        
        with tqdm(self.train_loader, desc='Training Position') as pbar:
            for batch in pbar:
                # Get defensive metrics predictions first
                with torch.no_grad():
                    defensive_metrics = self.metrics_model(
                        batch['history'].to(self.device),
                        batch['history_mask'].to(self.device),
                        batch['age'].to(self.device)
                    )
                
                # Prepare position transition inputs
                position_history = batch['history_mask'].to(self.device)  # Use mask as one-hot position encoding
                current_position = position_history[:, -1]  # Last position in sequence
                age = batch['age'][:, -1].unsqueeze(-1).to(self.device)  # Current age
                
                self.optimizer.zero_grad()
                position_probs = self.model(
                    defensive_metrics=defensive_metrics,
                    position_history=position_history,
                    age=age,
                    current_position=current_position
                )
                
                # Target is next season's position
                target_position = batch['target_mask'].to(self.device)
                loss = self.criterion(position_probs, target_position)
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.optimizer.step()
                
                total_loss += loss.item()
                pbar.set_postfix({'loss': loss.item()})
                
        return total_loss / len(self.train_loader)

Train model

In [9]:
# Initialize models and optimizers
metrics_model = DefensiveMetricsPredictor().to(device)
transition_model = PositionTransitionModel().to(device)

# Initialize losses
metrics_loss = DefensiveMetricsLoss()
transition_loss = PositionTransitionLoss()

# Initialize optimizers with different learning rates
metrics_optimizer = optim.AdamW(metrics_model.parameters(), lr=1e-4, weight_decay=0.01)
transition_optimizer = optim.AdamW(transition_model.parameters(), lr=5e-5, weight_decay=0.01)

# Initialize schedulers
metrics_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    metrics_optimizer, 
    T_0=10, 
    T_mult=2
)
transition_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    transition_optimizer, 
    T_0=10, 
    T_mult=2
)

# Initialize trainers
metrics_trainer = DefensiveMetricsTrainer(
    model=metrics_model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    criterion=metrics_loss,
    optimizer=metrics_optimizer,
    scheduler=metrics_scheduler,
    device=device,
    patience=10,
    max_grad_norm=1.0,
    checkpoint_dir='checkpoints/metrics'
)

transition_trainer = PositionTransitionTrainer(
    model=transition_model,
    metrics_model=metrics_model,  # Pass metrics model for predictions
    train_loader=train_loader,
    valid_loader=valid_loader,
    criterion=transition_loss,
    optimizer=transition_optimizer,
    scheduler=transition_scheduler,
    device=device,
    patience=15,  # Longer patience for position model
    max_grad_norm=0.5,
    checkpoint_dir='checkpoints/transition'
)

# Training loop
def train_models(num_epochs=100):
    logger.info("Training Defensive Metrics Model...")
    metrics_trainer.train(num_epochs)
    
    logger.info("\nTraining Position Transition Model...")
    metrics_model.eval()  # Freeze metrics model
    transition_trainer.train(num_epochs)
    
    return metrics_model, transition_model

# Train both models
metrics_model, transition_model = train_models(num_epochs=100)

INFO:__main__:Training Defensive Metrics Model...
Training Metrics:   0%|          | 0/22 [00:00<?, ?it/s]


AssertionError: Expected 3D tensor, got shape torch.Size([32, 9, 1, 4])

Evaluation