In [10]:
import torch
import pickle
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [33]:
# Read the pre-built pitch-level tensors back from the pickle on disk.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file,
# so only load pickles that were produced by this project.
with open('storage/pitch_level_tensors.pkl', mode='rb') as pkl_file:
    pitch_level_tensors = pickle.load(pkl_file)

## Model Architecture

Both models use a __CNN-(bi)LSTM__ architecture:
- The CNN layers progressively extract salient local features from the pitch sequence
- The bi-LSTM looks for temporal patterns that may distinguish injured from non-injured pitchers

In [34]:
""" LOSS FUNCTIONS """
def pitch_level_loss(
        logits_step: torch.Tensor,
        y_step: torch.Tensor,
        mask: torch.Tensor,
        pos_weight: "bool | torch.Tensor | None" = False
) -> torch.Tensor:
    """
    Compute the masked pitch-level BCE-with-logits loss.

    Valid for binary or smoothed (e.g., sigmoid) outcome labels. Only the
    positions where ``mask`` is True contribute to the (mean-reduced) loss.

    Args:
        logits_step (torch.Tensor): Logits from the model of shape [B, T].
        y_step (torch.Tensor): Ground truth labels of shape [B, T]. Should be in [0, 1].
        mask (torch.Tensor): Mask indicating valid time steps of shape [B, T].
        pos_weight (bool | torch.Tensor | None, optional): Positive-class weighting.
            - ``False`` / ``None`` (default): unweighted BCE.
            - ``True``: weight positives by ``neg / pos`` computed from the
              valid steps of this batch (``pos`` is clamped to at least 1).
            - tensor / float: used directly as the ``pos_weight`` of the BCE loss
              (e.g. a value precomputed over the whole training set).

    Returns:
        torch.Tensor: Scalar loss. If the mask selects no time step, a zero that
        is still attached to ``logits_step``'s graph is returned instead of NaN.
    """
    mask = mask.bool()  # boolean indexing below requires a bool mask
    logits = logits_step[mask]
    targets = y_step[mask].to(logits.dtype)

    # mean-reduced BCE over an empty selection is NaN; return a graph-attached 0
    if logits.numel() == 0:
        return logits_step.sum() * 0.0

    if pos_weight is None or (isinstance(pos_weight, bool) and not pos_weight):
        weight = None
    elif isinstance(pos_weight, bool):
        # pos_weight=True: derive the weight from the valid steps of this batch
        pos = targets.sum()
        neg = targets.numel() - pos
        weight = neg / pos.clamp(min=1.0)
    else:
        # caller supplied an explicit weight
        weight = torch.as_tensor(pos_weight, dtype=logits.dtype, device=logits.device)

    return F.binary_cross_entropy_with_logits(logits, targets, pos_weight=weight)

@torch.no_grad()
def compute_pos_weight(
    train_loader: DataLoader, 
    device: torch.device = torch.device('cpu')
) -> torch.Tensor:
    """
    Compute a BCE ``pos_weight`` (negatives / positives) from a training loader.

    Args:
        train_loader (DataLoader): Iterable yielding 4-tuples; only the second
            (labels) and third (mask) items are used. The mask is used to index
            the labels, so it must be a boolean (or index) tensor.
        device (torch.device): Device the labels/mask are moved to and on which
            the result is created. Defaults to CPU.

    Returns:
        torch.Tensor: 0-dim float32 tensor ``neg / pos`` where both counts are
        clamped to at least 1.0.
    """
    n_positive, n_valid = 0.0, 0.0

    # accumulate label mass and number of valid positions over every batch
    for _features, labels, valid, _lengths in train_loader:
        labels = labels.to(device)
        valid = valid.to(device)
        n_positive += labels[valid].sum().item()
        n_valid += valid.sum().item()

    # clamp both counts so the ratio is always finite and non-zero
    n_negative = max(n_valid - n_positive, 1.0)
    n_positive = max(n_positive, 1.0)

    return torch.tensor(n_negative / n_positive, dtype=torch.float32, device=device)

def masked_softmax(
        scores: torch.Tensor, 
        mask: torch.Tensor, 
        dim: int = -1
) -> torch.Tensor:
    """
    Compute a softmax over ``scores`` that ignores masked-out positions.

    Args:
        scores (torch.Tensor): Floating-point scores.
        mask (torch.Tensor): Same shape as ``scores``; True (non-zero) marks
            positions that take part in the softmax.
        dim (int): Dimension to normalize over. Defaults to -1.

    Returns:
        torch.Tensor: Weights that are 0 at masked positions and sum to 1 along
        ``dim`` over the valid ones. A slice with no valid position yields all
        zeros (not NaN).
    """
    mask = mask.bool()

    # Use the most negative *finite* value rather than -inf: with -inf a fully
    # masked slice becomes softmax(-inf, ..., -inf) = NaN, which survives the
    # multiplication by the mask and the clamp below and poisons the gradients.
    fill_value = torch.finfo(scores.dtype).min
    scores = scores.masked_fill(~mask, fill_value)

    # apply softmax to scores along the specified dimension
    attn = torch.softmax(scores, dim=dim)
    attn = attn * mask.float()  # make sure padded positions are exactly 0

    # renormalize over the valid positions; clamp guards the all-masked case
    denom = attn.sum(dim=dim, keepdim=True).clamp_min(1e-8)

    return attn / denom



In [39]:
""" MODEL ARCHITECTURES """
# CNN block for local patterns
class CNNBlock(nn.Module):
    """ 
    Depthwise-separable 1D convolutional block over time with residual. 

    The block applies a depthwise conv (one filter per channel), a 1x1
    pointwise conv, batch norm, GELU and dropout, then adds the input back.
    Channel count and sequence length are preserved.

    Args:
        num_channels (int): Number of input (and output) channels.
        kernel (int): Size of the depthwise convolutional kernel. Defaults to 7.
            Both odd and even sizes are supported.
        dropout (float): Dropout rate. Defaults to 0.1.

    **Note**: CNN expects tensor with shape [B, C, T]. Batch-norm statistics are
    computed over every time step that is passed in, including padded ones.
    """
    def __init__(
            self, 
            num_channels: int, 
            kernel: int = 7, 
            dropout: float = 0.1
    ) -> None:
        super().__init__()

        # depthwise convolution; padding="same" keeps the length T for any
        # kernel size (kernel // 2 only did so for odd kernels, which made the
        # residual addition fail with a shape mismatch for even kernels)
        self.dw = nn.Conv1d(
            num_channels,
            num_channels,
            kernel_size=kernel,
            padding="same",
            groups=num_channels,
        )
        # pointwise (1x1) convolution mixes information across channels
        self.pw = nn.Conv1d(num_channels, num_channels, kernel_size=1)

        # batch normalization, activation, and dropout
        self.bn = nn.BatchNorm1d(num_channels)
        self.act = nn.GELU()
        self.drop = nn.Dropout(dropout)

    def forward(
            self, 
            x: torch.Tensor
    ) -> torch.Tensor:  # x: [B,C,T]
        """Apply the block to ``x`` ([B, C, T]) and return a tensor of the same shape."""
        residual = x

        # depthwise then pointwise convolution
        x = self.dw(x)
        x = self.pw(x)

        # batch normalization, activation, and dropout
        x = self.bn(x)
        x = self.act(x)
        x = self.drop(x)

        return x + residual

class CNNbiLSTM(nn.Module):
    """
    CNN + BiLSTM model for time series data with a per-pitch head and an
    attention-pooled sequence head.

    Args:
        k_in (int): Number of input features (K).
        stem (int): Number of channels in the stem layer. Defaults to 64.
        c (int): Number of channels after projection. Defaults to 96.
        kernel (int): Size of the convolutional kernel. Defaults to 7.
        lstm_hidden (int): Hidden size for the LSTM layer. Defaults to 128.
        dropout (float): Dropout rate. Defaults to 0.1.
        bidir (bool): Whether to use a bidirectional LSTM. Defaults to True.

    **Note**: ``forward`` takes the input as [B, T, K]; it is transposed to
    [B, K, T] internally for the convolutional layers.
    """
    def __init__(
            self, 
            k_in: int, 
            stem: int = 64, 
            c: int = 96, 
            kernel: int = 7, 
            lstm_hidden: int = 128, 
            dropout: float = 0.1, 
            bidir: bool = True
    ) -> None:
        super().__init__()
        
        # 1x1 stem to mix K features into 'stem' channels
        self.stem = nn.Sequential(
            nn.Conv1d(k_in, stem, kernel_size=1),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        
        # progressive temporal convs
        self.conv1 = CNNBlock(stem, kernel=kernel, dropout=dropout)
        self.proj1 = nn.Conv1d(stem, c, kernel_size=1)                    # project to C channels
        self.conv2 = CNNBlock(c, kernel=kernel, dropout=dropout)

        # BiLSTM over time
        self.lstm = nn.LSTM(input_size=c, hidden_size=lstm_hidden, batch_first=True, bidirectional=bidir)
        hdim = lstm_hidden * (2 if bidir else 1)   # forward + backward states are concatenated

        # per-pitch head
        self.head_step = nn.Sequential(
            nn.LayerNorm(hdim),
            nn.Linear(hdim, 1)
        )

        # attention scorer for pooling + sequence head
        self.attn_scorer = nn.Linear(hdim, 1)
        self.head_seq    = nn.Linear(hdim, 1)

    def forward(
            self, 
            x: torch.Tensor, 
            lengths: torch.Tensor,
            mask: torch.Tensor
    ) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
        """
        Run the model on a padded batch.

        Args:
            x (torch.Tensor): Padded input of shape [B, T, K].
            lengths (torch.Tensor): Number of valid steps per sequence, shape [B].
                Passed to ``pack_padded_sequence`` (after ``.cpu()``).
            mask (torch.Tensor): Valid-step mask of shape [B, T], passed to
                ``masked_softmax`` for the attention pooling.

        Returns:
            tuple: ``(logits_step, logit_seq, attn)`` where
                - logits_step: per-pitch logits, shape [B, T]
                - logit_seq: one logit per sequence, shape [B]
                - attn: attention weights over time, shape [B, T]
        """
        # CNN expects [B,K,T]
        x = x.transpose(1, 2)                  # [B,K,T]
        x = self.stem(x)                       # [B,stem,T]
        x = self.conv1(x)                      # [B,stem,T]
        x = self.proj1(x)                      # [B,C,T]
        x = self.conv2(x)                      # [B,C,T]
        x = x.transpose(1, 2)                  # [B,T,C] for LSTM

        # pack to ignore padding inside LSTM
        packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        # total_length restores the full T so the output lines up with `mask`
        out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))    # [B,T,H]

        # apply head to get logits for each pitch
        logits_step = self.head_step(out).squeeze(-1)  # [B,T]

        # attention pooling (mask‑aware)
        scores = self.attn_scorer(out).squeeze(-1)     # [B,T]
        attn   = masked_softmax(scores, mask, dim=1)   # [B,T]
        ctx    = (attn.unsqueeze(-1) * out).sum(dim=1) # [B,H]
        logit_seq = self.head_seq(ctx).squeeze(-1)     # [B]
        
        return logits_step, logit_seq, attn 


## Model Development (Beta)

In [40]:
from tqdm import tqdm

In [41]:
def compute_masked_scalers(x, mask):
    """
    Per-feature mean and standard deviation over the valid (unmasked) steps.

    Args:
        x: Padded sequences of shape [B, T, K].
        mask: Valid-step mask of shape [B, T].

    Returns:
        (mean, std): two tensors of shape [K]. ``std`` is the population
        standard deviation, floored at 1e-6.
    """
    valid = mask.unsqueeze(-1)                          # [B,T,1], broadcasts over K
    count = valid.sum(dim=(0, 1)).clamp(min=1)          # [1]; at least 1 to avoid 0-division

    # masked first moment
    mean = torch.sum(x * valid, dim=(0, 1)) / count

    # masked second central moment (padded steps are zeroed before squaring)
    centered = (x - mean) * valid
    var = torch.sum(centered.pow(2), dim=(0, 1)) / count
    std = torch.sqrt(var).clamp(min=1e-6)

    return mean, std

# standardize a split with previously computed scalers (returns a new tensor)
def apply_scalers(x, mean, std): 
    """Return ``(x - mean) / std``; ``mean`` and ``std`` broadcast against ``x``."""
    centered = x - mean
    return centered / std


In [42]:
# set number of epochs, batch size, and random seed
NUM_EPOCHS = 5      # must be an int for now; TODO: None should mean "run until early stopping" (not implemented)
BATCH_SIZE = 32
SEED = 42           # fixed so that Restart & Run All reproduces the same init and batch order

# seed torch before the model is built and before any batch shuffling
torch.manual_seed(SEED)

## TODO: create compile_model() function --->

# standardize example sequence --> example training tensor sequence (x)
# NOTE(review): after standardization padded steps hold (pad - mean) / std rather than
# the original pad value, and the conv / batch-norm layers see them — confirm this is intended
example_mean, example_std = compute_masked_scalers(pitch_level_tensors['trn']['seq'], pitch_level_tensors['trn']['mask'])
x = apply_scalers(pitch_level_tensors['trn']['seq'], example_mean, example_std)

# setup (device, shapes)
device = "cuda" if torch.cuda.is_available() else "cpu"
B, T, K = x.shape
print(f'Input shape: {x.shape}, Device: {device}')

# move full tensors to device once (since we’re not using a DataLoader yet)
x               = x.float().to(device)
y_step_binary   = pitch_level_tensors['trn']['binary'].float().to(device)
mask            = pitch_level_tensors['trn']['mask'].bool().to(device)
lengths         = pitch_level_tensors['trn']['lengths'].long().to(device)

# setup model
model = CNNbiLSTM(k_in=K, stem=64, c=96, kernel=7, lstm_hidden=128, dropout=0.1, bidir=True).to(device)

# optimizer + class weights
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# compute pos_weight = (#negatives / #positives) once
#   `neg` is derived from y_step_binary.shape[0], so this assumes one label per
#   sequence (shape [B]) — NOTE(review): confirm against how 'binary' is built
#   NOTE: for probs this should be 1.0
with torch.no_grad():
    pos = y_step_binary.sum()
    neg = y_step_binary.shape[0] - pos
    pos_weight = (neg / pos.clamp(min=1)).float()

print("pos_weight =", float(pos_weight))


Input shape: torch.Size([1207, 12104, 7]), Device: cpu
pos_weight = 3.326164960861206


In [None]:
# Mini-batch training on the sequence-level logit; one progress bar per epoch
# and a one-line summary after each epoch.
for epoch in range(1, NUM_EPOCHS+1):
    # set training mode and draw a fresh random ordering of the B sequences
    model.train()
    idx = torch.randperm(B, device=device)

    # epoch accumulators (loss is sample-weighted so partial batches count correctly)
    running_loss = 0.0
    running_correct = 0
    running_total = 0

    # wrap range() with tqdm
    pbar = tqdm(range(0, B, BATCH_SIZE), desc=f"Epoch {epoch}/{NUM_EPOCHS}", leave=False)
    for start in pbar:
        end = min(start + BATCH_SIZE, B)
        bidx = idx[start:end]

        # forward pass setup
        xb = x[bidx]                        # [b,T,K]
        yb = y_step_binary[bidx]            # labels paired with logit_seq below, so expected to be [b] (0/1)
        mb = mask[bidx]                     # [b,T]
        Lb = lengths[bidx]                  # [b]

        # forward pass: get per-pitch + sequence logits
        logits_step, logit_seq, attn = model(xb, Lb, mb)

        # main sequence-level loss
        loss = F.binary_cross_entropy_with_logits(
            logit_seq, yb.float(), pos_weight=pos_weight
        )

        # back propagation (gradient norm clipped to 1.0)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # running loss, weighted by the actual batch size
        bs = xb.size(0)
        running_loss += loss.item() * bs

        # accuracy (sequence-level), computed from the pre-update logits
        with torch.no_grad():
            preds = (torch.sigmoid(logit_seq) > 0.5).float()
            running_correct += (preds == yb).sum().item()
        running_total += bs

        # running averages over the samples seen so far; dividing by
        # running_total stays correct when the last batch is smaller
        run_avg_loss = running_loss / max(running_total, 1)
        run_acc      = running_correct / max(running_total, 1)

        # show both current and running losses in the bar
        pbar.set_postfix(loss=f"{loss.item():.4f}", avg_loss=f"{run_avg_loss:.4f}", acc=f"{run_acc:.4f}")

    # epoch summary (a stray `break` used to sit here, which stopped training
    # after the first epoch and made this summary unreachable)
    epoch_loss = running_loss / max(running_total, 1)
    epoch_acc  = running_correct / max(running_total, 1)
    
    print(f"Epoch {epoch:02d} | Training Loss: {epoch_loss:.3f} | Training Accuracy: {epoch_acc:.3f}")