In [1]:
import sys
sys.path.append('../')

from Datasets.BaseballDataset import BaseballDataset

import torch
import torch.nn as nn
import torch.optim as optim
import math
import torch.nn.functional as F
from torch.utils.data import DataLoader
import json
import pandas as pd
import os
import matplotlib.pyplot as plt
import numpy as np
import pickle
from sklearn.preprocessing import StandardScaler

In [43]:
data_config_path = "../data/config.json"
train_data_path = "../data/mini_train.csv"
test_data_path = "../data/mini_test.csv"
sequence_length = 200
train_data = pd.read_csv(train_data_path)
test_data = pd.read_csv(test_data_path)

In [45]:
train_dataset = BaseballDataset(train_data,data_config_path,sequence_length)
test_dataset = BaseballDataset(test_data,data_config_path,sequence_length)

In [46]:
from torch.utils.data import DataLoader
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import mean_squared_error, log_loss
import numpy as np


class BaselineModel:
    def __init__(self, dataset, scalers_path, max_iters=3000):
        self.dataset = dataset
        # Initialize models for each continuous and categorical target
        self.continuous_models = [LinearRegression() for _ in range(dataset[0][1].shape[0])]
        self.categorical_models = [LogisticRegression(max_iter=max_iters) for _ in range(len(dataset[0][2]))]
        self.continuous_label_names = dataset.continuous_label_names
        self.categorical_label_names = dataset.categorical_label_names

        with open(scalers_path, "rb") as file:
            self.scalers = pickle.load(file)
    
    def train(self, batch_size=32):
        # Prepare data loader
        dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)

        X_cont = []
        Y_cont = []
        X_cat = []
        Y_cat = [[] for _ in range(len(self.categorical_models))]

        for sequences, cont_target_tensor, cat_target_tensors in dataloader:
            # Use only the last pitch in the sequence as input
            last_pitches = sequences[:, -1, :].numpy()

            X_cont.append(last_pitches)
            Y_cont.append(cont_target_tensor.numpy())

            for i, cat_target_tensor in enumerate(cat_target_tensors):
                Y_cat[i].append(cat_target_tensor.numpy())
        
        X_cont = np.concatenate(X_cont)
        Y_cont = np.concatenate(Y_cont)

        X_cat = X_cont.copy()  # Same input data for categorical targets

        # Train each model on the corresponding target
        for i, model in enumerate(self.continuous_models):
            model.fit(X_cont, Y_cont[:, i])
        
        for i, model in enumerate(self.categorical_models):
            Y_cat_combined = np.concatenate(Y_cat[i])
            model.fit(X_cat, Y_cat_combined.argmax(axis=1))  # Train on the class index, not the one-hot encoding
    
    def predict(self, sequences, scale=False):
        # Extract the last pitch in each sequence
        last_pitches = sequences[:, -1, :].numpy()

        # Predict continuous targets
        cont_preds = [model.predict(last_pitches) for model in self.continuous_models]
        cont_preds = np.stack(cont_preds, axis=-1)

        # Predict categorical targets
        cat_preds = [model.predict_proba(last_pitches) for model in self.categorical_models]
        cat_preds = np.concatenate(cat_preds, axis=-1)

        # Combine continuous and categorical predictions
        all_preds = np.concatenate([cont_preds, cat_preds], axis=-1)

        # Convert predictions to a pandas DataFrame
        flat_cat_names = [name for sublist in self.categorical_label_names for name in sublist]
        col_names = self.continuous_label_names + flat_cat_names
        preds_df = pd.DataFrame(all_preds, columns=col_names)

        if scale:
            # Re-scale continuous predictions
            for i, column in enumerate(self.continuous_label_names):
                if column in self.scalers:
                    scaler = self.scalers[column]
                    preds_df[column] = (preds_df[column] * scaler.scale_) + scaler.mean_

        return preds_df
    

In [38]:
scalers_path = "../data/statcast_2023-2024_cleaned_scalers.pkl"
baseline_model = BaselineModel(dataset,scalers_path)
baseline_model.train()

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [41]:
dataloader = DataLoader(dataset, batch_size=3000, shuffle=True)

for sequences, cont_target_tensor, cat_target_tensors in dataloader:
    preds = baseline_model.predict(sequences, scale=True)

preds

Unnamed: 0,launch_speed,hc_x,hc_y,launch_angle,events_B,events_S,events_double,events_field_out,events_hit_by_pitch,events_home_run,...,hit_location_0.0,hit_location_1.0,hit_location_2.0,hit_location_3.0,hit_location_4.0,hit_location_5.0,hit_location_6.0,hit_location_7.0,hit_location_8.0,hit_location_9.0
0,53.960797,53.967503,53.932946,53.966821,0.300625,0.629927,0.004669,0.054791,0.000479,0.004367,...,0.943762,0.005257,0.001012,0.003759,0.009192,0.002434,0.008543,0.004016,0.010498,0.011526
1,54.052873,54.012064,54.034799,54.061399,0.246985,0.610503,0.004220,0.097898,0.000632,0.010694,...,0.850951,0.005090,0.001154,0.001495,0.007642,0.035638,0.038119,0.006930,0.045047,0.007934
2,54.229637,54.102094,54.094364,54.213041,0.339901,0.285499,0.060193,0.129775,0.002319,0.035479,...,0.742680,0.010675,0.064940,0.006178,0.010999,0.025167,0.025471,0.031341,0.032034,0.050515
3,53.922375,53.836550,53.847008,53.985464,0.373159,0.294867,0.004108,0.072393,0.000336,0.007604,...,0.701049,0.001547,0.210996,0.002173,0.017980,0.010622,0.007591,0.024763,0.005548,0.017731
4,54.176324,54.062990,54.096183,54.206675,0.242802,0.551021,0.011926,0.162359,0.004900,0.009149,...,0.799334,0.000728,0.001260,0.016893,0.035866,0.032341,0.023735,0.065808,0.004516,0.019520
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2781,54.056738,54.005698,53.987055,54.019339,0.329116,0.554708,0.003787,0.069176,0.005085,0.006908,...,0.906225,0.000389,0.001199,0.003592,0.008141,0.012960,0.018531,0.011614,0.023295,0.014053
2782,54.153362,54.015701,53.997514,54.139607,0.282777,0.609307,0.008813,0.077601,0.002338,0.005874,...,0.886347,0.005208,0.001720,0.009530,0.037942,0.006629,0.010066,0.007919,0.003059,0.031579
2783,54.020021,53.882020,53.856102,54.076176,0.314654,0.296601,0.004491,0.063189,0.000139,0.015779,...,0.660975,0.000703,0.266958,0.002201,0.007998,0.006750,0.007735,0.016190,0.009632,0.020860
2784,54.336946,54.130740,54.120737,54.309892,0.011256,0.389451,0.010984,0.131739,0.000335,0.027921,...,0.604083,0.009161,0.239026,0.004958,0.022316,0.036116,0.025415,0.021926,0.017126,0.019873
