In [1]:
import importlib
import data_preprocessing
import data_set_and_loader
import models.multihead_pitch_prediction
from layers.lstm import LSTM
importlib.reload(data_preprocessing)
importlib.reload(data_set_and_loader)
importlib.reload(models.multihead_pitch_prediction)
import torch
from torch import nn
import torch.optim as optim
import os
import time
from models.multihead_pitch_prediction import MultiHeadLSTM
from data_preprocessing import load_data, sort_n_group, build_seqs, encode_and_scale, compute_feature_medians, split_data, sort_n_group_pitcher
from data_set_and_loader import create_dataloaders, AtBatDataset
from config import DATA_PATH

In [2]:
df = load_data(DATA_PATH)

In [3]:
df_pitcher = sort_n_group_pitcher(df, 668881)

In [5]:
median_values = compute_feature_medians(df_pitcher)

In [6]:
X_sequences , Y_sequences = build_seqs(df_pitcher, median_values)

Finished sequence building. Dropped 11 at-bats with NaN values.
Dropped 279 at-bats due to outliers.
Dropped 14 at-bats due to invalid results: {'strikeout_double_play', 'walk', 'truncated_pa', 'strikeout'}


In [13]:
num_indices = [8, 9, 16, 17, 18, 19, 20]
cat_indices = [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 21, 22]

processed_X, processed_Y, label_encoders_X, y_type_encoder, y_desc_encoder, y_event_encoder = encode_and_scale(X_sequences, Y_sequences)

In [15]:
X_train, X_test, Y_train, Y_test = split_data(processed_X, processed_Y, test_size=0.2, random_state=42)

In [16]:
train_loader, test_loader = create_dataloaders(X_train, Y_train, X_test, Y_test, 64) #remember to adjust batch-size here too

In [17]:
x_num_indices = [8, 9, 16, 17, 18, 19, 20]
x_low_card_cats_indices = [0, 1, 2, 3, 4, 5, 7, 10, 11, 14]
x_inning_index = 6
x_pitcher_index = 12
x_batter_index = 13
x_prev_pitch_index = 15
x_prev_desc_index = 22
x_prev_event_index = 21

In [19]:
EPOCHS = 30
BATCH_SIZE = 64
LEARNING_RATE = 0.002
MODEL_SAVE_PATH = r"C:\Users\Richard\Documents\SEG4300\Project\SEG4300-Project\Code\saved_models\best_model_pitcher.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize model (update parameters based on your dataset)
model = MultiHeadLSTM(
    num_numeric_features=7,
    num_pitchers=2, num_batters=351,
    num_prev_descriptions=11, num_prev_events=3, num_prev_pitch_types=7,
    num_low_card_cats=10, 
    num_innings=9, inning_emb_dim=4,
    pitcher_emb_dim=2, batter_emb_dim=14,
    prev_description_emb_dim=4, prev_event_emb_dim=2, prev_pitch_emb_dim=3,
    hidden_dim=256,
    num_pitch_type_classes=6, num_description_classes=12, num_event_classes=20,
    cont_dim=5, lstm_layers=2, dropout=0.2
).to(device)

# Loss functions
criterion_classification = nn.CrossEntropyLoss(ignore_index=0)
criterion_continuous = nn.MSELoss(reduction="none")

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

# Learning rate scheduler: Reduce LR by 50% every 5 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)

# Track best validation loss
best_val_loss = float("inf")


### 🚀 Training Function
def train_one_epoch(epoch):
    model.train()
    total_loss, total_pitch_type_loss, total_desc_loss, total_event_loss, total_cont_loss = 0, 0, 0, 0, 0
    correct_pitch_types, correct_desc, correct_event = 0, 0, 0
    total_pitch_types, total_desc, total_event = 0, 0, 0

    start_time = time.time()

    for batch in train_loader:
        (
            padded_X, padded_Y_type, padded_Y_cont, padded_Y_desc, padded_Y_event, lengths
        ) = [tensor.to(device) for tensor in batch]

        # Extract features based on new indices
        x_num = padded_X[:, :, x_num_indices]  # Numeric features
        x_low_card_cats = padded_X[:, :, x_low_card_cats_indices]  # Low-cardinality categorical
        x_inning = padded_X[:, :, x_inning_index].long()  # Inning (integer categories)
        x_pitcher = padded_X[:, :, x_pitcher_index].long()
        x_batter = padded_X[:, :, x_batter_index].long()
        x_prev_pitch = padded_X[:, :, x_prev_pitch_index].long()
        x_prev_desc = padded_X[:, :, x_prev_desc_index].long()
        x_prev_event = padded_X[:, :, x_prev_event_index].long()

        optimizer.zero_grad()

        # Forward pass
        pitch_type_logits, pitch_cont_values, pitch_result_desc, pitch_result_event = model(
            x_num, x_low_card_cats, x_inning, x_pitcher, x_batter, x_prev_desc, x_prev_event, x_prev_pitch, lengths
        )

        # Compute losses
        loss_pitch_type = criterion_classification(
            pitch_type_logits.view(-1, pitch_type_logits.size(-1)), padded_Y_type.view(-1)
        )
        loss_desc = criterion_classification(
            pitch_result_desc.view(-1, pitch_result_desc.size(-1)), padded_Y_desc.view(-1)
        )
        loss_event = criterion_classification(
            pitch_result_event.view(-1, pitch_result_event.size(-1)), padded_Y_event.view(-1)
        )
        #Handle padded numerical values (-999) in MSE loss
        mask = (padded_Y_cont != -999).float()  # Create a mask (1 for real values, 0 for padded)
        loss_cont = (mask * (pitch_cont_values - padded_Y_cont) ** 2).sum() / mask.sum()

        # Total loss
        loss = loss_pitch_type + loss_desc + loss_event + loss_cont
        loss.backward()
        optimizer.step()

        # Track losses
        total_loss += loss.item()
        total_pitch_type_loss += loss_pitch_type.item()
        total_desc_loss += loss_desc.item()
        total_event_loss += loss_event.item()
        total_cont_loss += loss_cont.item()

        # Compute accuracy for all categorical outputs (ignore padding)
        mask_type = (padded_Y_type != 0)  
        mask_desc = (padded_Y_desc != 0)
        mask_event = (padded_Y_event != 0)

        _, predicted_pitch_type = torch.max(pitch_type_logits, dim=-1)
        _, predicted_desc = torch.max(pitch_result_desc, dim=-1)
        _, predicted_event = torch.max(pitch_result_event, dim=-1)

        correct_pitch_types += (predicted_pitch_type[mask_type] == padded_Y_type[mask_type]).sum().item()
        correct_desc += (predicted_desc[mask_desc] == padded_Y_desc[mask_desc]).sum().item()
        correct_event += (predicted_event[mask_event] == padded_Y_event[mask_event]).sum().item()

        total_pitch_types += mask_type.sum().item()
        total_desc += mask_desc.sum().item()
        total_event += mask_event.sum().item()

    avg_loss = total_loss / len(train_loader)
    pitch_type_acc = 100 * correct_pitch_types / total_pitch_types
    desc_acc = 100 * correct_desc / total_desc
    event_acc = 100 * correct_event / total_event

    end_time = time.time()
    print(f"Epoch [{epoch+1}/{EPOCHS}] - Loss: {avg_loss:.4f} - Pitch Type Acc: {pitch_type_acc:.2f}% - Desc Acc: {desc_acc:.2f}% - Event Acc: {event_acc:.2f}% - Time: {end_time - start_time:.2f}s")

    return avg_loss


### 🚀 Validation Function
def validate():
    model.eval()
    val_loss, val_pitch_type_loss, val_desc_loss, val_event_loss, val_cont_loss = 0, 0, 0, 0, 0
    correct_pitch_types, correct_desc, correct_event = 0, 0, 0
    total_pitch_types, total_desc, total_event = 0, 0, 0

    with torch.no_grad():
        for batch in test_loader:
            (
                padded_X, padded_Y_type, padded_Y_cont, padded_Y_desc, padded_Y_event, lengths
            ) = [tensor.to(device) for tensor in batch]

            # Extract features based on new indices
            x_num = padded_X[:, :, x_num_indices]  # Numeric features
            x_low_card_cats = padded_X[:, :, x_low_card_cats_indices]  # Low-cardinality categorical
            x_inning = padded_X[:, :, x_inning_index].long()  # Inning (integer categories)
            x_pitcher = padded_X[:, :, x_pitcher_index].long()
            x_batter = padded_X[:, :, x_batter_index].long()
            x_prev_pitch = padded_X[:, :, x_prev_pitch_index].long()
            x_prev_desc = padded_X[:, :, x_prev_desc_index].long()
            x_prev_event = padded_X[:, :, x_prev_event_index].long()

            pitch_type_logits, pitch_cont_values, pitch_result_desc, pitch_result_event = model(
                x_num, x_low_card_cats, x_inning, x_pitcher, x_batter, x_prev_desc, x_prev_event, x_prev_pitch, lengths
            )

            # Compute validation losses
            loss_pitch_type = criterion_classification(
                pitch_type_logits.view(-1, pitch_type_logits.size(-1)), padded_Y_type.view(-1)
            )
            loss_desc = criterion_classification(
                pitch_result_desc.view(-1, pitch_result_desc.size(-1)), padded_Y_desc.view(-1)
            )
            loss_event = criterion_classification(
                pitch_result_event.view(-1, pitch_result_event.size(-1)), padded_Y_event.view(-1)
            )
            mask = (padded_Y_cont != -999).float()
            loss_cont = (mask * (pitch_cont_values - padded_Y_cont) ** 2).sum() / mask.sum()

            loss = loss_pitch_type + loss_desc + loss_event + loss_cont
            val_loss += loss.item()

            # Compute accuracy for all categorical outputs (ignore padding)
            mask_type = (padded_Y_type != 0)  
            mask_desc = (padded_Y_desc != 0)
            mask_event = (padded_Y_event != 0)

            _, predicted_pitch_type = torch.max(pitch_type_logits, dim=-1)
            _, predicted_desc = torch.max(pitch_result_desc, dim=-1)
            _, predicted_event = torch.max(pitch_result_event, dim=-1)

            correct_pitch_types += (predicted_pitch_type[mask_type] == padded_Y_type[mask_type]).sum().item()
            correct_desc += (predicted_desc[mask_desc] == padded_Y_desc[mask_desc]).sum().item()
            correct_event += (predicted_event[mask_event] == padded_Y_event[mask_event]).sum().item()

            total_pitch_types += mask_type.sum().item()
            total_desc += mask_desc.sum().item()
            total_event += mask_event.sum().item()

    avg_val_loss = val_loss / len(test_loader)
    pitch_type_acc = 100 * correct_pitch_types / total_pitch_types
    desc_acc = 100 * correct_desc / total_desc
    event_acc = 100 * correct_event / total_event

    print(f"Validation Loss: {avg_val_loss:.4f} - Pitch Type Acc: {pitch_type_acc:.2f}% - Desc Acc: {desc_acc:.2f}% - Event Acc: {event_acc:.2f}%")

    return avg_val_loss


early_stopping_patience = 5
no_improvement_epochs = 0

### 🚀 Training Loop
for epoch in range(EPOCHS):
    train_loss = train_one_epoch(epoch)
    val_loss = validate()

    if val_loss < best_val_loss:
        print("🔥 New best model found! Saving...")
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    scheduler.step()

    if no_improvement_epochs >= early_stopping_patience:
        print("Early stopping triggered")
        break

print("Training complete! Best model saved as", MODEL_SAVE_PATH)

Using device: cpu
Epoch [1/30] - Loss: 6.6206 - Pitch Type Acc: 44.04% - Desc Acc: 25.61% - Event Acc: 29.47% - Time: 0.85s
Validation Loss: 5.3987 - Pitch Type Acc: 41.95% - Desc Acc: 28.99% - Event Acc: 38.01%
🔥 New best model found! Saving...
Epoch [2/30] - Loss: 5.3098 - Pitch Type Acc: 53.41% - Desc Acc: 30.36% - Event Acc: 39.76% - Time: 0.83s
Validation Loss: 5.1884 - Pitch Type Acc: 53.95% - Desc Acc: 28.99% - Event Acc: 38.01%
🔥 New best model found! Saving...
Epoch [3/30] - Loss: 5.2189 - Pitch Type Acc: 55.34% - Desc Acc: 33.84% - Event Acc: 39.30% - Time: 0.80s
Validation Loss: 5.1503 - Pitch Type Acc: 53.95% - Desc Acc: 33.29% - Event Acc: 38.01%
🔥 New best model found! Saving...
Epoch [4/30] - Loss: 5.1783 - Pitch Type Acc: 55.34% - Desc Acc: 32.25% - Event Acc: 39.11% - Time: 0.78s
Validation Loss: 5.1316 - Pitch Type Acc: 53.95% - Desc Acc: 33.29% - Event Acc: 42.80%
🔥 New best model found! Saving...
Epoch [5/30] - Loss: 5.1468 - Pitch Type Acc: 55.34% - Desc Acc: 34.86