In [15]:
import pandas as pd

In [16]:
mini = pd.read_csv("statcast_mini_cleaned.csv")
mini

Unnamed: 0,launch_speed,game_date,release_speed,release_pos_x,release_pos_z,batter,pitcher,balls,strikes,pfx_x,...,hit_location_2.0,hit_location_3.0,hit_location_4.0,hit_location_5.0,hit_location_6.0,hit_location_7.0,hit_location_8.0,hit_location_9.0,outs_when_up_1,outs_when_up_2
0,-0.627173,2024-06-30,-1.056521,-1.498172,-2.802440,680869,623149,0.121914,1.339073,1.591409,...,True,False,False,False,False,False,False,False,False,True
1,1.645207,2024-06-30,-1.089609,-1.535090,-2.936805,680869,623149,0.121914,1.339073,1.568781,...,False,False,False,False,False,False,False,False,False,True
2,1.081636,2024-06-30,0.333173,-1.455979,-2.572100,680869,623149,0.121914,1.339073,-0.773192,...,False,False,False,False,False,False,False,False,False,True
3,-0.627173,2024-06-30,-0.891082,-1.503446,-2.840830,680869,623149,-0.912177,1.339073,1.783745,...,False,False,False,False,False,False,False,False,False,True
4,1.668473,2024-06-30,0.184277,-1.445430,-2.744855,680869,623149,-0.912177,0.129561,-0.626111,...,False,False,False,False,False,False,False,False,False,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99995,1.409955,2024-06-05,1.110740,-0.169095,0.691058,643396,544150,-0.912177,1.339073,-0.569542,...,False,False,False,False,False,False,False,False,False,True
99996,2.035570,2024-06-05,-0.527114,-0.438075,0.115207,643396,544150,-0.912177,0.129561,0.290313,...,False,False,False,False,False,False,False,False,False,True
99997,1.099732,2024-06-05,0.994932,-0.332592,0.287962,643396,544150,-0.912177,-1.079951,-0.535600,...,False,False,False,False,False,False,False,False,False,True
99998,1.877874,2024-06-05,0.961844,-0.422252,0.230377,672386,544150,2.190096,1.339073,-0.614797,...,False,False,False,False,False,False,True,False,True,False


In [45]:
#Note - currently ignoring instances where a batter does not have at least sequence_length appearances in the data.
# This will potentially result in a loss of training examples, and could be addressed by padding so sequences would always be the same length.

import torch
from torch.utils.data import Dataset
from torch.masked import masked_tensor
import json
import pandas
import numpy as np


class BaseballDataset(Dataset):
    def __init__(self, data, config_path, sequence_length, encode_pos=False, masked_tensor=False, seed=42):
        self.seed = seed
        self.set_seed()
        self.encode_pos = encode_pos
        self.masked_tensor = masked_tensor
        self.config = self.load_config(config_path)

        self.data = data
        if not self.masked_tensor:
            self.data = self.add_mask_dimensions(data)
        
        
        self.sequence_length = sequence_length
        self.label_columns = self.get_label_columns()
        self.metadata_columns = self.get_metadata_columns()
        self.categorical_columns = self.get_categorical_columns()
        self.mean_values = self.get_mean_values()
        self.processed_pitches = []
        self.sequences = []
        self.process_all_pitches()
        self.continuous_label_indices, self.categorical_label_indices = self.get_label_indices()

        if self.masked_tensor:
            self.mask = self.create_mask()

        self.prepare_sequences()
    
    def set_seed(self):
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)

    
    def load_config(self, config_path):
        with open(config_path, 'r') as file:
            config = json.load(file)
        return config
    
    def get_label_columns(self):
        return {column for column, settings in self.config.items() if settings.get('label', False)}
    
    def get_metadata_columns(self):
        return {column for column, settings in self.config.items() if settings.get('metadata', False)}
    
    def get_categorical_columns(self):
        return {column for column, settings in self.config.items() if settings.get('categorical', False)}
    
    def get_mean_values(self):
        continuous_label_columns = list(self.label_columns - self.categorical_columns)
        return self.data[continuous_label_columns].mean().to_dict()
    
    def add_mask_dimensions(self, data):
        config = self.config
        for column in config:
            if config[column].get('label', False) and config[column].get('categorical', False):
                mask_column = f"{column}_mask"
                data[mask_column] = 0
        return data
    
    def get_label_indices(self):
        sample_pitch = self.processed_pitches[0][0]
        categorical_label_indices = []
        continuous_label_indices = []

        for key in self.label_columns:
            if key in self.categorical_columns:
                for idx, col in enumerate(sample_pitch):
                    if col.startswith(key):
                        categorical_label_indices.append(idx)
            else:
                for idx, col in enumerate(sample_pitch):
                    if col == key:
                        continuous_label_indices.append(idx)
                        
        return continuous_label_indices, categorical_label_indices
    
    def create_mask(self):

        pitch_dim = len(self.processed_pitches[0][0])
        sequence_mask = torch.ones((self.sequence_length,pitch_dim))

        for i in range(pitch_dim):
            if i in self.continuous_label_indices or i in self.categorical_label_indices:
                sequence_mask[-1][i] = 0
        
        return sequence_mask == 1
            
    
    def process_all_pitches(self):
        for index, row in self.data.iterrows():
            pitch_data, pitch_metadata = self.process_pitch(row)
            self.processed_pitches.append((pitch_data, pitch_metadata))
    
    def prepare_sequences(self):
        grouped = self.data.groupby('batter')
        
        for batter, group in grouped:
            group = group.sort_values(by=['game_date', 'at_bat_number'])
            indices = group.index.tolist()
            
            for i in range(len(indices) - self.sequence_length):
                sequence_indices = indices[i:i + self.sequence_length]
                self.sequences.append(sequence_indices)

 
    
    def process_pitch(self, pitch):
        pitch_data = {}
        pitch_metadata = {}
        
        for key, value in pitch.items():
            if key in self.metadata_columns:
                pitch_metadata[key] = value
            else:
                pitch_data[key] = value
        
        return pitch_data, pitch_metadata
    
    def mask_values(self, pitch):

        if not self.masked_tensor: #fill w mean/mode

            for key in self.label_columns:
                if key in self.categorical_columns:
                    for col in pitch:
                        if col.startswith(key):
                            pitch[col] = 0
                    pitch[f'{key}_mask'] = 1
                elif key not in self.metadata_columns:
                    pitch[key] = self.mean_values[key]
            return pitch



    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        sequence_indices = self.sequences[idx]


        sequence = []
        metadata = []
        for i in sequence_indices:
            sequence.append(self.processed_pitches[i][0].copy())
            metadata.append(self.processed_pitches[i][1])

        if self.encode_pos:
            # Add positional encoding
            for i, pitch in enumerate(sequence):
                pitch['pos'] = i / self.sequence_length

        #target is unmasked last pitch in sequence
        target = sequence[-1].copy()
        
        if not self.masked_tensor:
            # Mask the last pitch in the sequence
            sequence[-1] = self.mask_values(sequence[-1])

        # Convert to tensor
        sequence_tensor = self.sequence_to_tensor(sequence)

        if self.masked_tensor:
            sequence_tensor = masked_tensor(sequence_tensor,self.mask)

        target = self.pitch_to_tensor(target)
        cont_target_tensor = torch.index_select(target,0,torch.LongTensor(self.continuous_label_indices))
        cat_target_tensor = torch.index_select(target,0,torch.LongTensor(self.categorical_label_indices))
        
        return sequence_tensor, cont_target_tensor, cat_target_tensor, metadata
    
    def sequence_to_tensor(self, sequence):
        # Convert the list of pitch dictionaries to a tensor
        sequence_tensor = torch.stack([self.pitch_to_tensor(pitch) for pitch in sequence])
        return sequence_tensor
    
    def pitch_to_tensor(self, pitch):
        # Convert a single pitch dictionary to a tensor, excluding metadata columns
        return torch.tensor(list(pitch.values()), dtype=torch.float)



In [48]:
config_path = 'config.json'
dataset = BaseballDataset(mini,config_path,10)

In [32]:
dataset.processed_pitches[0][0]

{'launch_speed': -0.627173396565644,
 'release_speed': -1.0565214549009387,
 'release_pos_x': -1.498171506813079,
 'release_pos_z': -2.802440118206849,
 'balls': 0.1219138180002686,
 'strikes': 1.3390725358479865,
 'pfx_x': 1.5914091999059314,
 'pfx_z': -0.0072048435585131,
 'plate_x': -0.3428798837008528,
 'plate_z': -0.3272326885736035,
 'hc_x': -0.4302601880736678,
 'hc_y': -0.4262360093225854,
 'vy0': 1.0505266932171395,
 'vz0': 1.1445950008515329,
 'ax': 1.2917317789396796,
 'ay': -1.101446080193421,
 'az': -0.2937865271321216,
 'sz_top': 0.3025393953569042,
 'sz_bot': -0.6798670794382925,
 'launch_angle': -0.6065380417109826,
 'release_spin_rate': 0.6924009058330866,
 'release_extension': 0.4403710672229072,
 'release_pos_y': -0.5358001046128373,
 'events_S': False,
 'events_double': False,
 'events_field_out': False,
 'events_hit_by_pitch': False,
 'events_home_run': False,
 'events_single': False,
 'events_strikeout': True,
 'events_triple': False,
 'events_walk': False,
 'pitc

In [50]:
dataset[0]

(tensor([[ 1.2781,  1.4747, -0.4328,  0.4031, -0.9122,  0.1296, -1.4407, -0.2483,
          -0.7815, -0.2232,  1.8923,  2.5658, -1.4625, -0.7981, -1.6891,  1.9647,
          -0.1214,  0.6992,  1.1819,  1.2917, -1.1412,  0.8265, -0.9923,  0.0000,
           0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.0000,  0.0000,
           0.0000],
         [-0.6272,  1.6898, -0.4275,  0.3647, -0.9122, -1.0800, -0.7958,  0.4616,
          -0.8393,  0.3072, -0.4303, -0.4262, -1.6779, -0.8118, -0.9601,  1.8140,
           0.7238,  0.6000,  0.8581, -0.6065, -0.3970,  0.8265, -0.9705,  1.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.00

In [39]:
dataset.sequences[0]

[98180, 98181, 98080, 98081, 98082, 98083, 98018, 98019, 98020, 95537]

In [44]:
seq,tar,m = dataset[82950]

In [45]:
seq[0].dtype

torch.float32

In [72]:
input

MaskedTensor(
  [
    [  1.2781,   1.4747,  -0.4328,   0.4031,  -0.9122,   0.1296,  -1.4407,  -0.2483,  -0.7815,  -0.2232,   1.8923,   2.5658,  -1.4625,  -0.7981,  -1.6891,   1.9647,  -0.1214,   0.6992,   1.1819,   1.2917,  -1.1412,   0.8265,  -0.9923,   0.0000,   0.0000,   1.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   1.0000,   0.0000,   0.0000,   0.0000,   0.0000,   1.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   1.0000,   0.0000,   0.0000,   0.0000,   0.0000,   1.0000],
    [ -0.6272,   1.6898,  -0.4275,   0.3647,  -0.9122,  -1.0800,  -0.7958,   0.4616,  -0.8393,   0.3072,  -0.4303,  -0.4262,  -1.6779,  -0.8118,  -0.9601,   1.8140,   0.7238,   0.6000,   0.8581,  -0.6065,  -0.3970,   0.8265,  -0.9705,   1.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,  

In [73]:
target

tensor([ 1.7460,  1.1107, -0.5488, -0.2303,  0.1219,  1.3391, -0.8750,  1.0912,
        -0.2852,  0.2968,  0.8603,  1.7207, -1.0755, -0.6795, -1.0123,  1.8586,
         1.2850,  0.6992,  1.0200,  1.6440,  0.4553,  0.6334, -0.8401,  0.0000,
         0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  1.0000])

In [74]:
metadata

[{'game_date': '2024-06-05',
  'batter': 444482,
  'pitcher': 667755,
  'game_pk': 746221,
  'at_bat_number': 10,
  'batter_name': 'david peralta',
  'pitcher_name': 'josé soriano'},
 {'game_date': '2024-06-05',
  'batter': 444482,
  'pitcher': 667755,
  'game_pk': 746221,
  'at_bat_number': 10,
  'batter_name': 'david peralta',
  'pitcher_name': 'josé soriano'},
 {'game_date': '2024-06-05',
  'batter': 444482,
  'pitcher': 667755,
  'game_pk': 746221,
  'at_bat_number': 30,
  'batter_name': 'david peralta',
  'pitcher_name': 'josé soriano'},
 {'game_date': '2024-06-05',
  'batter': 444482,
  'pitcher': 667755,
  'game_pk': 746221,
  'at_bat_number': 30,
  'batter_name': 'david peralta',
  'pitcher_name': 'josé soriano'},
 {'game_date': '2024-06-05',
  'batter': 444482,
  'pitcher': 667755,
  'game_pk': 746221,
  'at_bat_number': 30,
  'batter_name': 'david peralta',
  'pitcher_name': 'josé soriano'},
 {'game_date': '2024-06-05',
  'batter': 444482,
  'pitcher': 667755,
  'game_pk': 74

In [43]:

from torch.utils.data import DataLoader
batch_size = 32  # Define your batch size
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

In [44]:
for b in dataloader:
    s,t,m,cont,cat = b
    break

If you would like this operator to be supported, please file an issue for a feature request at https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.
In the case that the semantics for the operator are not trivial, it would be appreciated to also include a proposal for the semantics.


TypeError: Multiple dispatch failed for 'torch._ops.aten.stack.default'; all __torch_dispatch__ handlers returned NotImplemented:

  - tensor subclass <class 'torch.masked.maskedtensor.core.MaskedTensor'>

For more information, try re-running with TORCH_LOGS=not_implemented

In [17]:
print(s.shape,t.shape,len(m))

torch.Size([32, 10, 66]) torch.Size([32, 66]) 10
