In [65]:
import sys
sys.path.append('../')

from Datasets.BaseballDataset import BaseballDataset

import torch
import torch.nn as nn
import torch.optim as optim
import math
import torch.nn.functional as F
from torch.utils.data import DataLoader
import json
import pandas as pd
import os
import matplotlib.pyplot as plt
import numpy as np
import pickle
from sklearn.preprocessing import StandardScaler

In [66]:
data_config_path = "../data/config.json"
data_path = "../data/mini_train.csv"
sequence_length = 200
data = pd.read_csv(data_path)


In [67]:
mini_dataset = BaseballDataset(data,data_config_path,sequence_length)

In [68]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        x = x + self.pe[:x.size(1)].transpose(0, 1)
        return self.dropout(x)

class TransformerModel(nn.Module):
    def __init__(self, input_dim, num_heads, num_encoder_layers, hidden_dim, output_dim, sequence_length, dropout=0.1):
        super(TransformerModel, self).__init__()
        
        self.input_dim = input_dim
        self.sequence_length = sequence_length
        
        self.embedding = nn.Linear(input_dim, hidden_dim)
        self.positional_encoding = PositionalEncoding(hidden_dim, dropout)
        
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, dim_feedforward=hidden_dim, dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)
        
        self.fc_layers = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        x = self.embedding(x)
        x = self.positional_encoding(x)
        x = self.transformer_encoder(x)
        x = x[:, -1, :]  # Use the output of the last pitch in the sequence
        x = self.fc_layers(x)
        return x


def load_model(model_path, config_path):
    with open(config_path, 'r') as file:
        config = json.load(file)

    model = TransformerModel(
        input_dim=config['input_dim'],
        num_heads=config['num_heads'],
        num_encoder_layers=config['num_encoder_layers'],
        hidden_dim=config['hidden_dim'],
        output_dim=config['output_dim'],
        sequence_length=config['sequence_length'],
        dropout=config.get('dropout', 0.1)  # Optional: provide a default value for dropout if not in config
    )

    model.load_state_dict(torch.load(model_path))
    model.eval()  # Set the model to evaluation mode
    return model



def make_preds(model, dataset, scaler_path, device, batch_size):

    #get column names in correct order
    flat_cat_names = []
    for names in dataset.categorical_label_names:
        flat_cat_names = flat_cat_names + names 
    col_names = dataset.continuous_label_names + flat_cat_names

    #create dataloader for dataset
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    model.eval()

    preds_array = [] #keep trask of preds for each batch
    true_array = [] #keep track of true values
    with torch.no_grad():
        idx = 0
        for sequence_tensor, cont_target_tensor, cat_target_tensor in loader:
            idx += 1
            if idx % 10 == 0:
                print(f"Starting Batch: {idx}")
                
            sequence_tensor, cont_target_tensor = sequence_tensor.to(device), cont_target_tensor.to(device)
            cat_targets = [t.to(device) for t in cat_target_tensor]
            output = model(sequence_tensor)

            #first k logits correspond to continuous outputs, k = cont_target.size(1)
            cont_output = output[:, :cont_target_tensor.size(1)].cpu().squeeze(0).detach().numpy()
            cont_targets = cont_target_tensor.cpu().squeeze(0).detach().numpy()

            #can have multiple kinds of categorical outputs. If cat_targets is (batch_size, 2, 10), there are 2 kinds of cateogorical outputs, each with 10 values.
            #The first 10 logits after the continuous logits will correspond to first categorical output, second 10 to the second, so this requires multiple softmaxes
            cat_probs = []
            cat_target_probs = []
            start_idx = cont_target_tensor.size(1)
            for cat_target in cat_targets:
                end_idx = start_idx + cat_target.size(1)
                cat_probs.append(nn.functional.softmax(output[:, start_idx:end_idx],dim=1).cpu().squeeze(0).detach().numpy())
                cat_target_probs.append(cat_target.cpu().squeeze(0).detach().numpy())
                start_idx = end_idx
    
            #cat continuous and categorical outputs together
            preds = cont_output
            for probs in cat_probs:
                preds = np.concatenate((preds, probs),axis=1)
            
            preds_array.append(preds)

            true = cont_targets
            for probs in cat_target_probs:
                true = np.concatenate((true, probs),axis=1)
            
            true_array.append(true)

    #make single preds pd     
    preds_array = np.vstack(preds_array)
    preds_pd = pd.DataFrame(preds_array, columns=col_names)

    true_array = np.vstack(true_array)
    true_pd = pd.DataFrame(true_array, columns=col_names)

    #scale continuous outputs back to real values
    with open(scaler_path, "rb") as file:
        scalers = pickle.load(file)

    for column, scaler in scalers.items():
        if column in preds_pd:
            preds_pd[column] = (preds_pd[column] * scaler.scale_) + scaler.mean_
            true_pd[column] = (true_pd[column] * scaler.scale_) + scaler.mean_
    

    return preds_pd, true_pd

In [69]:
m_path = "tiny_data_grid_experiment/h8_e8_h72_d0.1_lp0.5_lr0.001_ep50/transformer_model.pth"
c_path = "tiny_data_grid_experiment/h8_e8_h72_d0.1_lp0.5_lr0.001_ep50/model_config.json"
model = load_model(m_path, c_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

cuda


TransformerModel(
  (embedding): Linear(in_features=75, out_features=72, bias=True)
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=72, out_features=72, bias=True)
        )
        (linear1): Linear(in_features=72, out_features=72, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=72, out_features=72, bias=True)
        (norm1): LayerNorm((72,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((72,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc_layers): Sequential(
    (0): Linear(in_features=72, out_features=72, bias=True)
    (1): ReLU()
    (2): L

In [70]:
preds, true = make_preds(model,mini_dataset,"../data/statcast_2023-2024_cleaned_scalers.pkl",device, batch_size=100)

Starting Batch: 10
Starting Batch: 20


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [71]:
preds

Unnamed: 0,launch_speed,hc_x,hc_y,launch_angle,events_B,events_S,events_double,events_field_out,events_hit_by_pitch,events_home_run,...,hit_location_0.0,hit_location_1.0,hit_location_2.0,hit_location_3.0,hit_location_4.0,hit_location_5.0,hit_location_6.0,hit_location_7.0,hit_location_8.0,hit_location_9.0
0,53.764048,53.867767,53.860563,53.784510,0.997373,0.002431,0.000002,0.000019,1.046915e-04,0.000004,...,0.999923,2.567186e-06,0.000012,0.000008,3.390405e-06,0.000007,0.000008,3.247183e-06,0.000018,0.000014
1,53.746073,53.860694,53.869212,53.812892,0.995757,0.002842,0.000007,0.000046,1.089696e-03,0.000011,...,0.999832,8.929605e-06,0.000021,0.000017,8.427814e-06,0.000014,0.000018,6.182944e-06,0.000043,0.000031
2,54.919176,55.056432,54.986871,54.766815,0.000718,0.003420,0.073950,0.630928,9.602335e-05,0.087825,...,0.071173,2.277478e-02,0.007634,0.077221,9.436519e-02,0.106601,0.161335,1.561804e-01,0.164030,0.138685
3,53.772328,53.844697,53.852248,53.811501,0.990320,0.009557,0.000001,0.000012,5.559679e-05,0.000002,...,0.999977,4.604223e-07,0.000007,0.000002,8.275414e-07,0.000002,0.000002,9.726683e-07,0.000004,0.000003
4,53.779098,53.846921,53.850949,53.810150,0.994781,0.004904,0.000002,0.000015,2.028626e-04,0.000003,...,0.999937,1.101750e-06,0.000026,0.000005,2.199938e-06,0.000004,0.000005,2.124563e-06,0.000009,0.000008
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2781,54.026020,53.818864,53.823651,54.080974,0.001054,0.998211,0.000024,0.000494,8.938545e-07,0.000068,...,0.998270,2.695666e-05,0.000004,0.000126,1.078622e-04,0.000124,0.000126,2.175103e-04,0.000738,0.000259
2782,54.604809,54.753780,54.852090,54.510408,0.000287,0.001259,0.056537,0.594656,2.565821e-04,0.051700,...,0.049152,3.870830e-02,0.157678,0.053675,6.071455e-02,0.105687,0.144597,2.046981e-01,0.076841,0.108248
2783,54.120436,53.819953,53.826137,54.186605,0.000641,0.999017,0.000014,0.000218,5.159933e-07,0.000034,...,0.998962,1.768196e-05,0.000004,0.000076,6.533314e-05,0.000072,0.000064,1.300703e-04,0.000481,0.000128
2784,54.096018,53.826697,53.814396,54.163913,0.000812,0.997946,0.000041,0.000850,1.310113e-06,0.000110,...,0.996928,4.807348e-05,0.000007,0.000216,1.882192e-04,0.000196,0.000227,4.292411e-04,0.001301,0.000460


In [85]:
preds['hc_x'].describe()

count    2786.000000
mean       54.009474
std         0.397340
min        53.741617
25%        53.821997
50%        53.850706
75%        53.879205
max        55.171426
Name: hc_x, dtype: float64

In [87]:
true['launch_angle'].describe()

count    2786.000000
mean       54.100048
std         0.501494
min        53.776220
25%        53.776220
50%        53.776220
75%        54.539128
max        55.417279
Name: launch_angle, dtype: float64