In [15]:
import pickle
import pandas as pd
from connections import AWS

$\textbf{Epidemiology: Clinical Model Training}$

Injury risk estimation is possible from two perspectives: 
- __Season__ (i.e., post-outing injury probability; postgame)
- __Pitch-Level__ (i.e., next-pitch injury probability; within game)

This notebook uses the data structures and model architecture established in `clinical_model_setup.ipynb`. The architecture was shown to work with a manually crafted batched-training loop.

In [2]:
# set up the AWS connection
    # NOTE: `aws` is reused further down for the S3 upload of training results and is closed in the last cell
aws = AWS()
aws.connect()

[AWS]: Port 5433 is free.
[AWS]: Connected to RDS endpoint.


In [230]:
# download most up-to-date tensor dictionaries
    # TODO: wait until the tensors are uploaded to S3 (until then, the local copy below is what gets loaded)
# pitch_level_tensors = aws.s3.download_file(
#     aws.bucket_name, 
#     'epidemiology/ml/datasets/pytorch/pitch_level_tensors.pkl', 
#     'storage/pitch_level_tensors.pkl'
# )

# load tensors into memory
    # NOTE: pickle.load can execute arbitrary code --> only load files that come from a trusted source
    # NOTE: later cells read this object as pitch_level_tensors['trn'][day] and pitch_level_tensors['val'][day]
with open('storage/pitch_level_tensors.pkl', 'rb') as f:
    pitch_level_tensors = pickle.load(f)

$\textbf{Model Development}$

In [236]:
import torch
from tqdm import tqdm
import torch.nn as nn
from nnet import CNNbiLSTM
import torch.nn.functional as F
from nnet.loss_functions import pitch_level_loss
from sklearn.metrics import roc_auc_score, f1_score
from services.scalers import compute_masked_scalers, apply_scalers

In [262]:
# compile model for training --> model, optimizer, loss function, and data setup
def compile_model(
        train_data: dict,
        val_data: dict,
        model_config: dict = None,
        device: str = 'auto',
        use_pos_weight: bool = False
) -> dict:
    """ 
    Compile a nnet architecture for training. Includes scaling, model instantiation, and loss function setup.
    
    Args:
        train_data (dict): Training data. The keys 'seq', 'mask', 'probs', 'binary', and 'lengths' are read.
        val_data (dict): Validation data. The same keys as `train_data` are read.
        model_config (dict): Optional overrides for the CNNbiLSTM arguments 'stem', 'c', 'kernel',
            'lstm_hidden', 'dropout', and 'bidir'. Missing keys (or None) fall back to
            64, 96, 7, 128, 0.1, and True respectively.
        device (str): Device to run the model on. 'auto' (default) selects 'cuda' when it is available and
            'cpu' otherwise. An explicit value is honored; a 'cuda*' request falls back to 'cpu' when CUDA
            is not available.
        use_pos_weight (bool): Whether to use positive class weights in loss function.
    
    Returns:
        model_setup (dict): Dictionary with the keys 'model', 'optimizer', 'loss_fn', 'pos_weight',
            'mean', 'std', 'device', 'train_setup', and 'val_setup'. Both setups hold the tensors
            'x', 'y_step', 'y_binary', 'mask', and 'lengths' on the selected device.
    """

    # no config passed --> rely on the per-key defaults below (avoids a mutable default argument)
    model_config = model_config or {}

    # resolve the device
        # NOTE: the argument used to be overwritten unconditionally; it is now respected
    cuda_available = torch.cuda.is_available()
    if device == 'auto':
        device = 'cuda' if cuda_available else 'cpu'
    elif str(device).startswith('cuda') and not cuda_available:
        device = 'cpu'

    # standardize example sequence --> example training tensor sequence (x)
        # NOTE: scalers are fit on training data only and re-used for validation below
    mean, std = compute_masked_scalers(train_data['seq'], train_data['mask'])
    x = apply_scalers(train_data['seq'], mean, std)
    _, _, K = x.shape       # NOTE: x must be 3-D; only the last dimension (K) is needed for model setup

    # setup (device, shapes)
        # --> move full tensors to device
    x_trn, y_step_trn, y_binary_trn, mask_trn, lengths_trn = move_tensors_to_device(x, train_data, device)

    # setup model, optimizer
    model = CNNbiLSTM(
        k_in=K, 
        stem=model_config.get('stem', 64),
        c=model_config.get('c', 96),
        kernel=model_config.get('kernel', 7),
        lstm_hidden=model_config.get('lstm_hidden', 128),
        dropout=model_config.get('dropout', 0.1),
        bidir=model_config.get('bidir', True),
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # optional: compute pos_weight over valid steps once
        # NOTE: for probs this should be 1.0
    if use_pos_weight:
        with torch.no_grad():
            pos = y_binary_trn[mask_trn].sum()
            tot = mask_trn.sum()
            neg = tot - pos
            # clamp guards against division by zero when there are no positive steps
            pos_weight = (neg / pos.clamp(min=1)).float()
    else:
        pos_weight = torch.tensor(1.0, device=device)

    # package all into dictionary
        # NOTE(review): training lengths are cast to long (see move_tensors_to_device) while validation
        #               lengths are cast to float --> confirm which dtype CNNbiLSTM expects
    model_setup = {
        'model': model.train(),
        'optimizer': optimizer,
        'loss_fn': pitch_level_loss,
        'pos_weight': pos_weight,
        'mean': mean,
        'std': std,
        'device': device,
        'train_setup': {
            'x': x_trn,
            'y_step': y_step_trn,
            'y_binary': y_binary_trn,
            'mask': mask_trn,
            'lengths': lengths_trn
        },
        'val_setup': {
            'x': apply_scalers(val_data['seq'], mean, std).to(device),
            'y_step': val_data['probs'].float().to(device),
            'y_binary': val_data['binary'].float().to(device),
            'mask': val_data['mask'].bool().to(device),
            'lengths': val_data['lengths'].float().to(device)
        }
    }

    return model_setup

# move tensors to device
    # mostly a helper function for training data 
def move_tensors_to_device(
        x: torch.Tensor,
        train_data: dict,
        device: str
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Cast the training tensors to their training dtypes and move them to the specified device.
    
    Args:
        x (torch.Tensor): Input tensor (cast to float).
        train_data (dict): Dictionary containing training data tensors. The keys 'probs', 'binary',
            'mask', and 'lengths' are read.
        device (str): Device to move tensors to ('cpu' or 'cuda').
    
    Returns:
        tuple: Five tensors on `device`, in this order:
            x (float), probs (float), binary (float), mask (bool), lengths (long).
    """
    x_dev = x.float().to(device)
    y_step = train_data['probs'].float().to(device)
    y_binary = train_data['binary'].float().to(device)
    mask = train_data['mask'].bool().to(device)
    lengths = train_data['lengths'].long().to(device)

    return x_dev, y_step, y_binary, mask, lengths
    
# early stopping check
def check_early_stopping(
        val_loss: float,
        best_val_loss: float,
        epochs_no_improve: int,
        patience: int
) -> dict:
    """
    Decide whether training should stop, based on the latest validation loss.

    A strictly lower validation loss counts as a new best and resets the counter.
    Anything else increments the counter; training should stop once the counter reaches `patience`.
    
    Args:
        val_loss (float): Current validation loss.
        best_val_loss (float): Best validation loss observed so far.
        epochs_no_improve (int): Number of epochs since last improvement.
        patience (int): Patience for early stopping.
    
    Returns:
        dict: Keys 'should_stop', 'new_best', 'best_val_loss', and 'epochs_no_improve' (updated values).
    """
    # improvement --> record the new best and reset the counter
    if val_loss < best_val_loss:
        return {
            'should_stop': False,
            'new_best': True,
            'best_val_loss': val_loss,
            'epochs_no_improve': 0
        }

    # no improvement --> one more stale epoch; stop once patience is exhausted
    stale_epochs = epochs_no_improve + 1
    return {
        'should_stop': bool(stale_epochs >= patience),
        'new_best': False,
        'best_val_loss': best_val_loss,
        'epochs_no_improve': stale_epochs
    }

# ROC/AUC calcs
def compute_roc_auc( 
        probs: torch.Tensor,
        y_true: torch.Tensor,
        mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, float]:
    """
    Compute ROC AUC score over the valid (masked) entries.
    
    Args:
        probs (torch.Tensor): Predicted probabilities. Either already restricted to the valid entries
            (e.g., computed from logits[mask]) or with the same shape as `mask`, in which case the mask is applied here.
        y_true (torch.Tensor): True (binary) labels, indexable by `mask`.
        mask (torch.Tensor): Mask indicating valid entries.
    
    Returns:
        tuple: (batch_scores (torch.Tensor), batch_labels (torch.Tensor), roc_auc (float))
            roc_auc is 0.5 when only one class is present among the valid labels (AUC is undefined there).
    """
    # per-batch valid scores/labels for AUC (use binary labels)
        # NOTE: labels are always masked; scores are masked here only if they still have the full (unmasked)
        #       shape, so that scores and labels stay aligned for both call styles
    if mask.dtype == torch.bool and probs.shape == mask.shape:
        probs = probs[mask]
    batch_scores = probs.detach().float()
    batch_labels = y_true[mask].detach().float()

    # compute AUC
        # NOTE: all-negative or all-positive labels --> AUC undefined, fall back to chance level
    n_positive = batch_labels.sum().item()
    if n_positive == 0 or n_positive == batch_labels.numel():
        roc_auc = 0.5 
    else:
        roc_auc = roc_auc_score(batch_labels.cpu().numpy(), batch_scores.cpu().numpy())

    return batch_scores, batch_labels, roc_auc

# training loop for a model
    # NOTE: sets up a dictionary for storing model loss history
def train_model(
        model_name: str,
        model_setup: dict,
        n_epochs: int = 100,
        batch_size: int = 32,
        patience: int = 10,
        verbose: bool = True
) -> dict:
    """
    Train the model using the provided training and validation data. Manually handles batching and early stopping.

    Training minimizes binary cross-entropy (with logits) between the model output and the binary labels
    over the masked steps. Validation calls `model_setup['loss_fn']` as
    loss_fn(logits, y_step, mask, pos_weight) on the full validation set once per epoch.

    Side effects: whenever the validation loss improves, the full model and its state dict are written to
    `models/inj/{model_name}.pt` and `models/inj/{model_name}_state_dict.pt` (the directory is created if needed).
    
    Args:
        model_name (str): Name used to build the checkpoint file names.
        model_setup (dict): Dictionary containing model, optimizer, loss function, and data setup.
            The keys 'model', 'optimizer', 'loss_fn', 'pos_weight', 'device', 'train_setup', and 'val_setup' are read.
        n_epochs (int): Number of epochs to train. Defaults to 100.
        batch_size (int): Batch size for training. Defaults to 32.
        patience (int): Patience for early stopping. Defaults to 10.
        verbose (bool): Whether to print training progress. Defaults to True.
    
    Returns:
        dict: History with 'train_loss' and 'val_loss' (one value per epoch), 'best_val_loss'
            (None if the validation loss never improved), and 'stopped_epoch' (None unless early stopping triggered).
    """
    import os   # local import: only needed to make sure the checkpoint directory exists

    # unpack model setup
    model = model_setup['model']
    optimizer = model_setup['optimizer']
    loss_fn = model_setup['loss_fn']
    pos_weight = model_setup['pos_weight']
    device = model_setup['device']

    # extract training data
    train_setup = model_setup['train_setup']
    x = train_setup['x']
    y_step_binary = train_setup['y_binary']
    mask = train_setup['mask']
    lengths = train_setup['lengths']

    # extract validation data (casts are loop-invariant, so do them once)
    val_setup = model_setup['val_setup']
    x_val = val_setup['x'].float()
    lengths_val = val_setup['lengths'].float()
    y_step_val = val_setup['y_step']
    mask_val = val_setup['mask']

    # checkpoint setup
    os.makedirs("models/inj", exist_ok=True)
    model_path = f"models/inj/{model_name}.pt"
    state_dict_path = f"models/inj/{model_name}_state_dict.pt"

    def _save_checkpoint() -> None:
        # save model (+ state dict); ensure model is in eval mode before saving
        model.eval()
        torch.save(model, model_path)
        torch.save(model.state_dict(), state_dict_path)

    # setup loss storage
    history = {
        'train_loss': [],
        'val_loss': [],
        'best_val_loss': None,
        'stopped_epoch': None
    }

    # early stopping setup
    best_val_loss = float('inf')
    epochs_no_improve = 0
    checkpoint_saved = False
    last_epoch = 0      # stays 0 if n_epochs < 1, so the final message below is always printable

    # MAIN LOOP: iterate through epochs
    B = x.shape[0]
    for epoch in range(1, n_epochs+1):
        last_epoch = epoch

        # ---- TRAINING ----
        # set model to train mode, shuffle indices, create batches
        model.train()
        idx = torch.randperm(B, device=device)  # shuffle indices each epoch

        # loss counter (sum of per-batch mean losses, weighted by the number of sequences in the batch)
        running_loss = 0.0

        # wrap range() with tqdm
        pbar = tqdm(range(0, B, batch_size), desc=f"Epoch {epoch}/{n_epochs}", leave=False)
        for start in pbar:
            end = min(start + batch_size, B)
            bidx = idx[start:end]

            # batch tensors
            xb = x[bidx]
            yb = y_step_binary[bidx]
            mb = mask[bidx]
            Lb = lengths[bidx]

            # forward pass + loss over valid (masked) steps only
            logits = model(xb, Lb)
            loss = F.binary_cross_entropy_with_logits(
                logits[mb], yb[mb], pos_weight=pos_weight
            )

            # back propagation
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # running loss
            running_loss += loss.item() * xb.size(0)

            # running average over the sequences seen so far
                # NOTE: `end` is the exact number of sequences processed, which stays correct
                #       for a smaller final batch
            run_avg_loss = running_loss / end

            # show the running loss in the bar
            pbar.set_postfix(loss=f"{run_avg_loss:.4f}")

        # update epoch loss
        epoch_loss = running_loss / B
        history['train_loss'].append(epoch_loss)
        
        # print epoch update
        if verbose:
            print(f"Epoch {epoch:02d} | Training Loss: {epoch_loss:.3f}")

        # ---- VALIDATION ----
        # set model to eval mode
        model.eval()
        with torch.no_grad():
            val_logits = model(x_val, lengths_val)
            val_loss = loss_fn(val_logits, y_step_val, mask_val, pos_weight).item()
        history['val_loss'].append(val_loss)

        # print update
        if verbose:
            print(f"Epoch {epoch:02d} | Validation Loss: {val_loss:.3f}")

        # ---- EARLY STOPPING CHECK ----
        early_stop_info = check_early_stopping(
            val_loss, best_val_loss, epochs_no_improve, patience
        )

        # update values with results
        best_val_loss = early_stop_info['best_val_loss']
        epochs_no_improve = early_stop_info['epochs_no_improve']

        # option 1: new best --> save intermediate model
        if early_stop_info['new_best']:
            history['best_val_loss'] = best_val_loss
            _save_checkpoint()
            checkpoint_saved = True

            if verbose:
                print(f"New best model saved with validation loss: {best_val_loss:.3f}")

        # option 2: stop training early
            # NOTE: don't need to save model here, since we already saved the best one above
        if early_stop_info['should_stop']:
            if verbose:
                print(f"Early stopping triggered after {epoch} epochs. Best validation loss: {best_val_loss:.3f}")

            history['stopped_epoch'] = epoch
            return history

        # option 3: not yet time to stop --> next epoch

    # NOTE: if we get here, training completed without early stopping
    if verbose:
        print(f"Training completed after {last_epoch} epochs with best validation loss: {best_val_loss:.3f}")

    # the files on disk already hold the best checkpoint --> do NOT overwrite them with the last-epoch weights;
    # only save here if no checkpoint was ever written (validation loss never improved)
    if not checkpoint_saved:
        _save_checkpoint()

    return history


In [263]:
# containers for the compiled setups and (later) results, keyed by preceding-day window
model_setups, model_results = {}, {}

# one compiled setup (scalers, model, optimizer, tensors) per day window
for day in (7, 15, 30, 45, 90):
    model_setups[day] = compile_model(
        pitch_level_tensors['trn'][day],
        pitch_level_tensors['val'][day],
        use_pos_weight=True,
    )
    

In [265]:
""" PITCH-LEVEL MODELS """
# for day in [7, 15, 30, 45, 90]:
for day in [7]:

    print(f'Training pitch-level model for preceding day window: {day} days')

    # train the pitch-level model for this window; returns the loss history
    pitch_model_results = train_model(
        f'pitch_model_{day}d',
        model_setups[day],
        n_epochs=100,
        batch_size=32,
        patience=5,
        verbose=True,
    )

    # serialize once, then write the same bytes to disk and to S3
    content = pickle.dumps(pitch_model_results)
    with open(f'storage/pitch_model_{day}d_results.pkl', 'wb') as f:
        f.write(content)
    aws.upload_to_s3(content, f'epidemiology/ml/models/inj/pitch_model_{day}d_results.pkl')

Training pitch-level model for preceding day window: 7 days


                                                                          

Epoch 01 | Training Loss: 1.266
Epoch 01 | Validation Loss: 1.043
New best model saved with validation loss: 1.043


                                                                          

Epoch 02 | Training Loss: 1.037
Epoch 02 | Validation Loss: 1.013
New best model saved with validation loss: 1.013


                                                                          

Epoch 03 | Training Loss: 1.010
Epoch 03 | Validation Loss: 0.939
New best model saved with validation loss: 0.939


                                                                          

Epoch 04 | Training Loss: 0.941
Epoch 04 | Validation Loss: 1.050


                                                                          

Epoch 05 | Training Loss: 0.981
Epoch 05 | Validation Loss: 0.874
New best model saved with validation loss: 0.874


                                                                          

Epoch 06 | Training Loss: 0.817
Epoch 06 | Validation Loss: 0.890


                                                                          

Epoch 07 | Training Loss: 0.835
Epoch 07 | Validation Loss: 0.796
New best model saved with validation loss: 0.796


                                                                          

Epoch 08 | Training Loss: 0.825
Epoch 08 | Validation Loss: 0.790
New best model saved with validation loss: 0.790


                                                                          

Epoch 09 | Training Loss: 0.764
Epoch 09 | Validation Loss: 0.822


                                                                           

Epoch 10 | Training Loss: 0.799
Epoch 10 | Validation Loss: 0.848


                                                                           

Epoch 11 | Training Loss: 0.874
Epoch 11 | Validation Loss: 0.801


                                                                           

Epoch 12 | Training Loss: 0.792
Epoch 12 | Validation Loss: 0.786
New best model saved with validation loss: 0.786


                                                                           

Epoch 13 | Training Loss: 0.740
Epoch 13 | Validation Loss: 0.810


                                                                           

Epoch 14 | Training Loss: 0.737
Epoch 14 | Validation Loss: 0.826


                                                                           

Epoch 15 | Training Loss: 0.749
Epoch 15 | Validation Loss: 0.840


                                                                           

Epoch 16 | Training Loss: 0.706
Epoch 16 | Validation Loss: 0.804


                                                                           

Epoch 17 | Training Loss: 0.709
Epoch 17 | Validation Loss: 0.763
New best model saved with validation loss: 0.763


                                                                           

Epoch 18 | Training Loss: 0.699
Epoch 18 | Validation Loss: 0.818


                                                                           

Epoch 19 | Training Loss: 0.812
Epoch 19 | Validation Loss: 0.783


                                                                           

Epoch 20 | Training Loss: 0.714
Epoch 20 | Validation Loss: 0.790


                                                                           

Epoch 21 | Training Loss: 0.691
Epoch 21 | Validation Loss: 0.769


                                                                           

Epoch 22 | Training Loss: 0.679
Epoch 22 | Validation Loss: 0.750
New best model saved with validation loss: 0.750


                                                                           

Epoch 23 | Training Loss: 0.677
Epoch 23 | Validation Loss: 0.804


                                                                           

Epoch 24 | Training Loss: 0.677
Epoch 24 | Validation Loss: 0.739
New best model saved with validation loss: 0.739


                                                                           

Epoch 25 | Training Loss: 0.649
Epoch 25 | Validation Loss: 0.752


                                                                           

Epoch 26 | Training Loss: 0.632
Epoch 26 | Validation Loss: 0.746


                                                                           

Epoch 27 | Training Loss: 0.653
Epoch 27 | Validation Loss: 0.757


                                                                           

Epoch 28 | Training Loss: 0.642
Epoch 28 | Validation Loss: 0.769


                                                                           

Epoch 29 | Training Loss: 0.629
Epoch 29 | Validation Loss: 0.782
Early stopping triggered after 29 epochs. Best validation loss: 0.739
[AWS]: Uploaded object to s3://pitch-ml/epidemiology/ml/models/inj/pitch_model_7d_results.pkl


$\textbf{Save Preprocessed Training/Validation Data}$

The scaled training and validation tensors are saved to disk here, mostly because the workflow for saving the scalers themselves isn't finalized yet.

In [256]:
for day in [7, 15, 30, 45, 90]:
    # pull the (already scaled) train/validation tensors out of the compiled setup
    day_setup = model_setups[day]
    day_setup_train, day_setup_val = day_setup['train_setup'], day_setup['val_setup']

    # pair each output path with the object that goes into it
    outputs = (
        (f'models/data/pitch_model_{day}d_trn_data.pkl', day_setup_train),
        (f'models/data/pitch_model_{day}d_val_data.pkl', day_setup_val),
    )

    # save to disk
    for path, payload in outputs:
        with open(path, 'wb') as f:
            pickle.dump(payload, f)

    # TODO: upload to S3
    # with open(f'models/data/pitch_model_{day}d_trn_data.pkl', 'rb') as f:
    #     content = f.read()
    #     aws.upload_to_s3(content, f'epidemiology/ml/models/inj/pitch_model_{day}d_trn_data.pkl')
    # with open(f'models/data/pitch_model_{day}d_val_data.pkl', 'rb') as f:
    #     content = f.read()
    #     aws.upload_to_s3(content, f'epidemiology/ml/models/inj/pitch_model_{day}d_val_data.pkl')

$\textbf{Close AWS Connection}$

In [266]:
# close the AWS connection that was opened at the top of the notebook
aws.close()

[AWS]: No active connection to close.
