In [1]:
import pickle
import pandas as pd

$\textbf{Epidemiology: Clinical Model Training}$

Injury risk estimation is possible from two perspectives: 
- __Season__ (i.e., post-outing injury probability; postgame)
- __Pitch-Level__ (i.e., next-pitch injury probability; within game)

This notebook uses the data structures and model architecture established in `clinical_model_setup.ipynb`. The architecture was shown to work with a manually crafted batched-training loop.

In [130]:
# download most up-to-date tensor dictionaries
    # TODO: wait until the tensors are uploaded to S3
# pitch_level_tensors = aws.s3.download_file(
#     aws.bucket_name, 
#     'epidemiology/ml/datasets/pytorch/pitch_level_tensors.pkl', 
#     'storage/pitch_level_tensors.pkl'
# )

# load tensors into memory
    # outing_ --> one outcome per pitcher
with open('storage/outing_level_tensors.pkl', 'rb') as f:
    pitch_level_tensors = pickle.load(f)

$\textbf{Model Development}$

In [9]:
import torch
from tqdm import tqdm
import torch.nn as nn
from nnet import CNNbiLSTM
import torch.nn.functional as F
from nnet.loss_functions import pitch_level_loss
from sklearn.metrics import roc_auc_score, f1_score
from services.scaling import compute_masked_scalers, apply_scalers

In [151]:
# compile model for training --> model, optimizer, loss function, and data setup
def compile_model(
        train_data: dict,
        val_data: dict,
        model_config: dict = {
            'stem': 64,
            'c': 96,
            'kernel': 7,
            'lstm_hidden': 128,
            'dropout': 0.1,
            'bidir': True
        },
        device: str = 'cpu',
        use_pos_weight: bool = False
) -> dict:
    """ 
    Compile a nnet architecture for training. Includes scaling, model instantiation, and loss function setup.
    
    Args:
        train_data (dict): Training data containing tensors and masks.
        val_data (dict): Validation data containing tensors and masks.
        model_config (dict): Configuration for the CNNbiLSTM model. Defaults to basic setup.
        device (str): Device to run the model on ('cpu' or 'cuda').
        use_pos_weight (bool): Whether to use positive class weights in loss function.
    
    Returns:
        model_setup (dict): Dictionary with all relevant information..
    """

    # standardize example sequence --> example training tensor sequence (x)
    mean, std = compute_masked_scalers(train_data['seq'], train_data['mask'])
    x = apply_scalers(train_data['seq'], mean, std)
    B, T, K = x.shape       # NOTE: B = batch size, T = time steps, K = features; K is used for model setup

    # setup (device, shapes)
        # --> move full tensors to device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x_trn, y_binary_trn, mask_trn, lengths_trn = move_tensors_to_device(x, train_data, device)

    # setup model, optimizer
    model = CNNbiLSTM(
        k_in=K, 
        stem=model_config.get('stem', 64),
        c=model_config.get('c', 96),
        kernel=model_config.get('kernel', 7),
        lstm_hidden=model_config.get('lstm_hidden', 128),
        dropout=model_config.get('dropout', 0.1),
        bidir=model_config.get('bidir', True),
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

    # optional: compute pos_weight over valid steps once
        # NOTE: for probs this should be 1.0
    if use_pos_weight:
        with torch.no_grad():
            pos = y_binary_trn.sum()
            neg = y_binary_trn.shape[0] - pos
            pos_weight = (neg / pos.clamp(min=1)).float()
    else:
        pos_weight = torch.tensor(1.0, device=device)

    # package all into dictionary
    model_setup = {
        'model': model.train(),
        'optimizer': optimizer,
        'loss_fn': pitch_level_loss,
        'pos_weight': pos_weight,
        'mean': mean,
        'std': std,
        'device': device,
        'train_setup': {
            'x': x_trn,
            'y_binary': y_binary_trn,
            'mask': mask_trn,
            'lengths': lengths_trn
        },
        'val_setup': {
            'x': apply_scalers(val_data['seq'], mean, std).to(device),
            'y_binary': val_data['binary'].float().to(device),
            'mask': val_data['mask'].bool().to(device),
            'lengths': val_data['lengths'].float().to(device)
        }
    }

    return model_setup

# move tensors to device
    # mostly a helper function for training data 
def move_tensors_to_device(
        x: torch.Tensor,
        train_data: dict,
        device: str
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Move tensors to the specified device.
    
    Args:
        x (torch.Tensor): Input tensor.
        train_data (dict): Dictionary containing training data tensors.
        device (str): Device to move tensors to ('cpu' or 'cuda').
    
    Returns:
        tuple: Tensors moved to the specified device.
    """
    return x.float().to(device), train_data['binary'].float().to(device), train_data['mask'].bool().to(device), train_data['lengths'].long().to(device)
    
# early stopping check
def check_early_stopping(
        val_loss: float,
        best_val_loss: float,
        epochs_no_improve: int,
        patience: int
) -> dict:
    """
    Check if early stopping criteria are met.
    
    Args:
        val_loss (float): Current validation loss.
        best_val_loss (float): Best validation loss observed so far.
        epochs_no_improve (int): Number of epochs since last improvement.
        patience (int): Patience for early stopping.
    
    Returns:
        dict: Dictionary indicating whether to stop, if there's a new best, updated best loss, and epochs without improvement.
    """
    if val_loss < best_val_loss:
        return {
            'should_stop': False,
            'new_best': True,
            'best_val_loss': val_loss,
            'epochs_no_improve': 0
        }
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            return {
                'should_stop': True,
                'new_best': False,
                'best_val_loss': best_val_loss,
                'epochs_no_improve': epochs_no_improve
            }
        else:
            return {
                'should_stop': False,
                'new_best': False,
                'best_val_loss': best_val_loss,
                'epochs_no_improve': epochs_no_improve
            }

# ROC/AUC calcs
def compute_roc_auc( 
        probs: torch.Tensor,
        y_true: torch.Tensor,
        mask: torch.Tensor
) -> tuple[float, float, float]:
    """
    Compute ROC AUC score.
    
    Args:
        probs (torch.Tensor): Predicted probabilities.
        y_true (torch.Tensor): True labels.
        mask (torch.Tensor): Mask indicating valid entries.
    
    Returns:
        tuple: (batch_scores (torch.Tensor), batch_labels (torch.Tensor), roc_auc (float))
    """
    # per-batch valid scores/labels for AUC (use binary labels)
    batch_scores = probs.detach().float()
    batch_labels = y_true[mask].detach().float()

    # compute AUC
    if batch_labels.sum() == 0 or batch_labels.sum() == len(batch_labels):
        roc_auc = 0.5 
    else:
        roc_auc = roc_auc_score(batch_labels.cpu().numpy(), batch_scores.cpu().numpy())

    return batch_scores, batch_labels, roc_auc

# training loop for a model
    # NOTE: sets up a dictionary for storing model loss history
def train_model(
        model_name: str,
        model_setup: dict,
        n_epochs: int = 100,
        batch_size: int = 32,
        patience: int = 10,
        verbose: bool = True,
        epoch_start: int = 1
) -> dict:
    """
    Train the model using the provided training and validation data. Manually handles batching and early stopping.
    
    Args:
        model_setup (dict): Dictionary containing model, optimizer, loss function, and data setup.
        n_epochs (int): Number of epochs to train. Defaults to 100.
        batch_size (int): Batch size for training. Defaults to 32.
        patience (int): Patience for early stopping. Defaults to 10.
        verbose (bool): Whether to print training progress. Defaults to True.
        epoch_start (int): Starting epoch number (useful for resuming training). Defaults to 1.
    
    Returns:
        dict: Dictionary containing training and validation loss history.
    """
    # unpack model setup
    model = model_setup['model']
    optimizer = model_setup['optimizer']
    pos_weight = model_setup['pos_weight']

    # extract training data
    x = model_setup['train_setup']['x']
    y_step_binary = model_setup['train_setup']['y_binary']
    mask = model_setup['train_setup']['mask']
    lengths = model_setup['train_setup']['lengths']

    # setup loss storage
    history = {
        'train_loss': [],
        'val_loss': [],
        'best_val_preds': None,
        'best_val_loss': None,
        'stopped_epoch': None
    }

    # loss counters
    running_loss = 0.0
    running_correct = 0
    running_total = 0

    # early stopping setup
    best_val_loss = float('inf')
    epochs_no_improve = 0


    # MAIN LOOP: iterate through epochs
    B = x.shape[0]
    for epoch in range(0 + epoch_start, epoch_start + n_epochs + 1):
        
        """ TRAINING """
        # set model to train mode, shuffle indices, create batches
        model.train()
        idx = torch.randperm(B, device=model_setup['device'])  # shuffle indices each epoch

        # loss counters
        running_loss = 0.0

        # wrap range() with tqdm
        pbar = tqdm(range(0, B, batch_size), desc=f"Epoch {epoch}/{epoch_start + n_epochs}", leave=False)
        for start in pbar:
            end = min(start + batch_size, B)
            bidx = idx[start:end]

            # forward pass setup
            xb = x[bidx]                        # [B,T,K]
            yb = y_step_binary[bidx]            # [B] sequence-level labels (0/1)
            mb = mask[bidx]                     # [B,T]
            Lb = lengths[bidx]                  # [B]

            # forward pass: get per-pitch + sequence logits
            _, logit_seq, _ = model(xb, Lb, mb)

            # main sequence-level loss
            loss = F.binary_cross_entropy_with_logits(
                logit_seq, yb.float(), pos_weight=pos_weight
            )

            # back propagation
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # running loss
            bs = xb.size(0)
            running_loss += loss.item() * bs
            
        # accuracy (sequence-level)
            probs  = torch.sigmoid(logit_seq)
            preds  = (probs > 0.5).float()
            correct = (preds == yb).sum().item()
            total   = yb.size(0)

            # update total
            running_correct += correct
            running_total += total

            # update accuracy
            run_avg_loss = running_loss / ((start // batch_size + 1) * bs)
            run_acc      = running_correct / running_total

            # show both current and running losses in the bar
            pbar.set_postfix(loss=f"{run_avg_loss:.4f}", acc=f"{run_acc:.4f}")

        # update epoch loss
        epoch_loss = running_loss / B
        history['train_loss'].append(epoch_loss)
        
        # print epoch update
        if verbose:
            print(f"Epoch {epoch:02d} | Training Loss: {epoch_loss:.3f}")

        """ VALIDATION """
        # set model to eval mode
        model.eval()
        with torch.no_grad():
            _, logits, _ = model(
                model_setup['val_setup']['x'].float(), 
                model_setup['val_setup']['lengths'].float(),
                model_setup['val_setup']['mask'].bool()
            )
            val_loss = F.binary_cross_entropy_with_logits(
                logits, 
                model_setup['val_setup']['y_binary'], 
                pos_weight=model_setup['pos_weight']
            ).item()
            history['val_loss'].append(val_loss)
            
            # print update
            if verbose:
                print(f"Epoch {epoch:02d} | Validation Loss: {val_loss:.3f}")

            """ EARLY STOPPING CHECK """
            early_stop_info = check_early_stopping(
                val_loss, best_val_loss, epochs_no_improve, patience
            )
            
            # update values with results
            best_val_loss = early_stop_info['best_val_loss']
            epochs_no_improve = early_stop_info['epochs_no_improve']

            # option 1: new best --> save intermediate model
            if early_stop_info['new_best']:

                # update history
                history['best_val_loss'] = best_val_loss
                history['best_val_preds'] = torch.sigmoid(logits).detach().cpu().numpy()
                
                # save model (+ state dict)
                    # ensure model is in eval mode before saving
                model.eval()
                torch.save(model, f"models/inj/{model_name}.pt")
                torch.save(model.state_dict(), f"models/inj/{model_name}_state_dict.pt")

                # print update
                if verbose:
                    print(f"New best model saved with validation loss: {best_val_loss:.3f}")

            # option 2: not best, but not yet time to stop --> continue
            if not early_stop_info['should_stop']:
                continue

            # option 3: stop training early
            if early_stop_info['should_stop']:
                if verbose:
                    print(f"Early stopping triggered after {epoch} epochs. Best validation loss: {best_val_loss:.3f}")
                
                # add to history
                history['stopped_epoch'] = epoch

                # NOTE: don't need to save model here, since we already saved the best one above
                
                return history

    # NOTE: if we get here, training completed without early stopping
    if verbose:
        print(f"Training completed after {epoch} epochs with best validation loss: {best_val_loss:.3f}")

    # save final model
        # ensure model is in eval mode before saving
    model.eval()
    torch.save(model, f"models/inj/{model_name}.pt")
    torch.save(model.state_dict(), f"models/inj/{model_name}_state_dict.pt")

    return history


In [None]:
# create model setups for outing- and pitch-level data by day windows
model_setups = {}
model_results = {}

# compile models for each day window
model_setup = compile_model(
    train_data=pitch_level_tensors['trn'],
    val_data=pitch_level_tensors['val'],
    use_pos_weight=True,
    model_config={
        'stem': 64,
        'c': 96,
        'kernel': 7,
        'lstm_hidden': 64,
        'dropout': 0.1,
        'bidir': True
    },
)
    

In [None]:
""" TRAIN MODEL """
print(f'Training pitch-level model...')

# pitch-level model loop
    # 60 epochs run so far (as of 8/25 am)
    # save in pitch_model_results
ext_results = train_model(
    model_name=f'pitch_model_test',     # NOTE: test
    model_setup=model_setup,
    n_epochs=1,
    batch_size=32,
    patience=10,
    verbose=True, 
    epoch_start=0
)

Training pitch-level model...


                                                                                   

Epoch 00 | Training Loss: 1.070
Epoch 00 | Validation Loss: 1.019
New best model saved with validation loss: 1.019


                                                                                   

Epoch 01 | Training Loss: 1.070
Epoch 01 | Validation Loss: 1.017
New best model saved with validation loss: 1.017
Training completed after 1 epochs with best validation loss: 1.017


In [None]:
# save to disk
# with open(f'storage/pitch_model_results.pkl', 'wb') as f:
#     pickle.dump(ext_results, f)

# TODO: upload to S3
# with open(f'storage/pitch_model_results.pkl', 'rb') as f:
#     content = f.read()
#     aws.upload_to_s3(content, f'epidemiology/ml/models/inj/pitch_model_results.pkl')

$\textbf{Save Preprocessed Train/Val Data}$

Mostly to save the scaled data, since saving scalers isn't finalized yet.

In [162]:
# get model setup
day_setup_train = model_setup['train_setup']
day_setup_val = model_setup['val_setup']

# save to disk
with open(f'models/data/pitch_model_trn_data.pkl', 'wb') as f:
    pickle.dump(day_setup_train, f)
with open(f'models/data/pitch_model_val_data.pkl', 'wb') as f:
    pickle.dump(day_setup_val, f)

# TODO: upload to S3
# with open(f'models/data/pitch_model_{day}d_trn_data.pkl', 'rb') as f:
#     content = f.read()
#     aws.upload_to_s3(content, f'epidemiology/ml/models/inj/pitch_model_{day}d_trn_data.pkl')
# with open(f'models/data/pitch_model_{day}d_val_data.pkl', 'rb') as f:
#     content = f.read()
#     aws.upload_to_s3(content, f'epidemiology/ml/models/inj/pitch_model_{day}d_val_data.pkl')