In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtune
from model.model import SeqFTTransformer
from model.loader import get_loaders
from utils import get_pck_key_size, get_pck_value_size
import math
import multiprocessing as mp
import matplotlib.pyplot as plt
import pickle
# Use the 'spawn' start method for worker processes. `force=True` keeps this
# cell re-runnable: without it, executing the cell a second time in the same
# kernel raises "RuntimeError: context has already been set".
mp.set_start_method('spawn', force=True)

In [None]:
# Number of continuous features (passed to both the loaders and the model).
# TODO: should be set automatically
num_cont = 25

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Settings shared by the data loaders and the model configuration below.
batch_size = 16
kv_caching = False
num_workers = 1

# Columns handed to `get_loaders` as `columns_to_shift`.
# NOTE(review): what "shifting" does is defined in model.loader, not here —
# presumably these per-pitch outcome/measurement columns are offset by one
# step so a row does not see its own result; confirm against the loader.
shift_cols = [
    'events', 'description', 'hit_location', 'bb_type',
    'pfx_x', 'pfx_z',
    'hc_x', 'hc_y',
    'vx0', 'vy0', 'vz0',
    'ax', 'ay', 'az',
    'hit_distance_sc', 'launch_speed', 'launch_angle',
    'release_speed', 'release_spin_rate', 'release_extension',
    'release_pos_x', 'release_pos_y', 'release_pos_z',
    'spin_axis',
]

# Build the train / validation / test loaders in one call.
# The five positional arguments are file paths: a parquet data file, three
# pickles named for the train/val/test splits, and the `pn_to_pn` mapping
# pickle. NOTE(review): their exact roles are defined by
# `model.loader.get_loaders`, which is not visible here — keep the order as is.
train_loader, val_loader, test_loader = get_loaders(
    'full_multi/data/1623_full_cleandata_bin.parquet',
    'full_multi/data/train_full_1623.pkl',
    'full_multi/data/val_full_1623.pkl',
    'full_multi/data/test_full_1623.pkl',
    'full_multi/mappings/pn_to_pn.pkl',
    num_cont=num_cont,
    batch_size=batch_size,
    num_workers=num_workers,
    columns_to_shift=shift_cols
    )

# Mapping pickles whose sizes feed `comb_category_sizes` in the model config
# (those categories share the single `comb_category_emb_dim`).
comb_map_files = [
    'full_multi/mappings/stand.pkl',
    'full_multi/mappings/p_throws.pkl',
    'full_multi/mappings/game_year.pkl',
    'full_multi/mappings/balls.pkl',
    'full_multi/mappings/strikes.pkl',
    'full_multi/mappings/on_3b.pkl',
    'full_multi/mappings/on_2b.pkl',
    'full_multi/mappings/on_1b.pkl',
    'full_multi/mappings/outs_when_up.pkl',
    'full_multi/mappings/inning_topbot.pkl',
    'full_multi/mappings/events.pkl',
    'full_multi/mappings/description.pkl',
    'full_multi/mappings/home_team.pkl',
    'full_multi/mappings/away_team.pkl',
    'full_multi/mappings/hit_location.pkl',
    'full_multi/mappings/bb_type.pkl',
]

# Mapping pickles whose sizes feed `sep_category_sizes`; each of these gets
# its own embedding width (see `sep_cat_emb_dims`).
sep_map_files = [
    'full_multi/mappings/batter.pkl',
    'full_multi/mappings/pitcher.pkl',
    'full_multi/mappings/fielder_2.pkl',
]

# One size per mapping file, in the same order as the file lists above.
comb_cat_sizes = tuple(get_pck_key_size(path) for path in comb_map_files)
sep_cat_sizes = tuple(get_pck_key_size(path) for path in sep_map_files)

# Every separately-embedded category uses an embedding width of 4.
sep_cat_emb_dims = (4,) * len(sep_cat_sizes)

# Sizes taken from the pn_to_pn mapping: its key count is passed to the model
# as `tgt_categories`, its value count as `out_categories`.
tgt_cat_size = get_pck_key_size('full_multi/mappings/pn_to_pn.pkl')
out_cat_size = get_pck_value_size('full_multi/mappings/pn_to_pn.pkl')

# Every constructor argument of the model, collected in one dict so the exact
# configuration can be pickled next to the trained weights.
model_config = dict(
    comb_category_sizes=comb_cat_sizes,
    comb_category_emb_dim=8,
    sep_category_sizes=sep_cat_sizes,
    sep_category_emb_dims=sep_cat_emb_dims,
    pad_idx=0,
    num_continuous=num_cont,
    dim=32,
    depth=1,
    num_heads=8,
    num_kv_heads=4,
    tgt_categories=tgt_cat_size,
    max_seq_len=128,
    max_batch_size=batch_size,
    out_categories=out_cat_size,
    kv_caching=kv_caching,
    attn_dropout=0.1,
    hidden_mult=2,
)

# Persist the configuration (the target directory must already exist).
with open('trained_models/model_config.pkl', 'wb') as config_file:
    pickle.dump(model_config, config_file)

# Instantiate the model straight from `model_config` (its keys are exactly the
# constructor's keyword arguments) and move it to the selected device.
model = SeqFTTransformer(**model_config).to(device)

In [None]:
#Define training
def train(model, train_loader, criterion, optimizer, scheduler, device):
    """Run one pass over ``train_loader`` and return the mean per-batch loss.

    Each batch is a 7-tuple of tensors. The first three are stacked along a
    new last axis and passed as the model's second input, the fourth and fifth
    are passed through as the first and third inputs, the sixth is shifted one
    step to the right (zero-filled at position 0) to form the fourth input,
    and the seventh holds the labels for ``criterion``.

    Both ``optimizer.step()`` and ``scheduler.step()`` are called once per
    batch, so the scheduler must be configured in units of batches.

    Returns:
        Sum of the batch losses divided by ``len(train_loader)``.
    """
    model.train()
    running_loss = 0

    for batch in train_loader:
        optimizer.zero_grad()

        # Move all seven tensors of the batch to the target device.
        ids_a, ids_b, ids_c, comb_feats, cont_feats, tgt_tokens, labels = (
            tensor.to(device) for tensor in batch
        )

        # Input at step t is the target at step t-1; step 0 sees 0.
        shifted_tgt = F.pad(tgt_tokens[:, :-1], (1, 0), value=0)
        sep_feats = torch.cat(
            [ids.unsqueeze(2) for ids in (ids_a, ids_b, ids_c)], dim=2
        )

        logits = model(comb_feats, sep_feats, cont_feats, shifted_tgt)

        # Swap the last two axes of the logits before the loss
        # (assumes model output is (batch, seq, classes) — TODO confirm).
        batch_loss = criterion(logits.transpose(2, 1), labels.to(dtype=torch.long))
        batch_loss.backward()
        optimizer.step()
        scheduler.step()
        running_loss += batch_loss.item()

    return running_loss / len(train_loader)


def val(model, val_loader, criterion, mask_const, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Each batch is a 7-tuple of tensors, consumed the same way as in training:
    the first three are stacked along a new last axis (second model input),
    the fourth and fifth are the first and third inputs, the sixth is shifted
    right by one step (zero-filled) to form the fourth input, and the seventh
    holds the labels.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        val_loader: Iterable of batches as described above.
        criterion: Loss applied to the logits (last two axes swapped) and the
            labels cast to ``torch.long``.
        mask_const: Label value excluded from the accuracy computation.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the batches that contain
        at least one label different from ``mask_const``. Batches whose labels
        are all masked are skipped: they would yield a NaN loss and a
        division by zero in the accuracy. If no batch can be scored (including
        an empty loader), ``(nan, nan)`` is returned.
    """
    model.eval()
    total_loss = 0.0
    total_acc = 0.0
    scored_batches = 0

    with torch.no_grad():
        for bs, ps, cs, cas, cos, ts, res in val_loader:
            bs = bs.to(device)
            ps = ps.to(device)
            cs = cs.to(device)
            cas = cas.to(device)
            cos = cos.to(device)
            ts = ts.to(device)
            # Input at step t is the target at step t-1; step 0 sees 0.
            tgt = F.pad(ts[:, :-1], (1, 0), value=0)
            labels = res.to(device=device, dtype=torch.long)

            # Positions that actually carry a label.
            valid = labels != mask_const
            num_valid = valid.sum().item()
            if num_valid == 0:
                # Nothing to score in this batch.
                continue

            out = model(
                cas,
                torch.cat((bs.unsqueeze(2), ps.unsqueeze(2), cs.unsqueeze(2)), dim=2),
                cos,
                tgt
            )

            loss = criterion(out.transpose(2, 1), labels)
            total_loss += loss.item()

            # Accuracy over unmasked positions only; a boolean AND avoids
            # mixing bool and float tensors and allocating a zeros tensor.
            preds = out.argmax(dim=2)
            num_correct = ((preds == labels) & valid).sum().item()
            total_acc += num_correct / num_valid
            scored_batches += 1

    if scored_batches == 0:
        return float('nan'), float('nan')
    return total_loss / scored_batches, total_acc / scored_batches

In [None]:
# Disabled cell: the whole body is a bare string literal, so none of it runs
# (Jupyter just echoes the string as the cell output). It sketches exporting
# the model graph to TensorBoard with SummaryWriter.add_graph.
# NOTE(review): it references bs, ps, cas, cos and ts, which in the visible
# cells exist only as loop variables inside train()/val(), and it calls the
# model with two inputs while train()/val() pass four — it would need updating
# before being re-enabled.
"""from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/model_visualization')
ts = ts.to(torch.int64).to(device)
writer.add_graph(model, (torch.cat((bs.unsqueeze(2),ps.unsqueeze(2),cas),dim=2).squeeze(0),
    cos.squeeze(0)))
writer.close()  """

In [None]:
# Model hyperparameters
loss_func = nn.CrossEntropyLoss
learning_rate = 2e-5
opt_func = torch.optim.AdamW
epochs = 15
# Label value excluded from the loss (as `ignore_index`) and, via val(), from
# the accuracy.
mask_const = 0


optimizer = opt_func(model.parameters(), lr=learning_rate)
criterion = loss_func(ignore_index=mask_const)

# Scheduling
# train() calls scheduler.step() once per batch, so the schedule length is
# measured in batches. len(train_loader) is already the number of batches per
# epoch (assuming get_loaders returns torch DataLoaders); dividing it by
# batch_size again made the schedule ~batch_size times too short, so the
# cosine curve ran out long before training finished.
steps_per_epoch = len(train_loader)
warmup_ratio = 0.1
total_steps = steps_per_epoch * epochs
warmup_steps = int(total_steps * warmup_ratio)
scheduler = torchtune.training.get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)

In [None]:
torch.set_float32_matmul_precision('high')

opt_train = torch.compile(train)
train_losses = []
val_losses = []

# Lowest validation loss seen so far. This must be initialised once, outside
# the epoch loop: resetting it every epoch made the "best" check always true,
# so the checkpoint simply held the last epoch.
best_val_loss = None
best_model_path = 'trained_models/new_sft_full_1623.pth'

# TRAIN LOOP
for epoch in range(epochs):
    train_loss = opt_train(model, train_loader, criterion, optimizer, scheduler, device)
    # Uncompiled alternative:
    #train_loss = train(model, train_loader, criterion, optimizer, scheduler, device)

    if kv_caching:
        model.reset_caches()
    torch.cuda.empty_cache()

    val_loss, _ = val(model, val_loader, criterion, mask_const, device)

    if kv_caching:
        model.reset_caches()
    torch.cuda.empty_cache()

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.4f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val Loss: {val_loss:.4f} |  Val. PPL: {math.exp(val_loss):7.3f}')

    # Checkpoint only on improvement. `is None` (rather than `not ...`) so a
    # legitimate best loss of 0.0 is not mistaken for "no best yet".
    if best_val_loss is None or val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)

    # Redraw the loss curves after every epoch to show progress.
    plt.figure()
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()


# test the best model
# Restore the best checkpoint first; otherwise the weights in memory are those
# of the final epoch, which is not necessarily the best one.
if best_val_loss is not None:
    model.load_state_dict(
        torch.load(best_model_path, map_location=device, weights_only=True)
    )
test_loss, test_acc = val(model, test_loader, criterion, mask_const, device)
print(f'\t Test Loss: {test_loss:.4f} |  Test Acc.: {test_acc:7.4f}')

In [None]:
# Stand-alone evaluation of the model currently in memory on the test split,
# so it can be re-run without repeating the training cell.
test_loss, test_acc = val(
    model=model,
    val_loader=test_loader,
    criterion=criterion,
    mask_const=mask_const,
    device=device,
)
print(f'\t Test Loss: {test_loss:.4f} |  Test Acc.: {test_acc:7.4f}')