In [None]:
from Data_Prep import Data_Prep, Player_IO
from sklearn.model_selection import train_test_split # type: ignore
import torch
import Prep_Map
import Output_Map
from Player_Dataset import Create_Test_Train_Datasets
import Model_Train
from Player_Model import RNN_Model, LayerArch
from Player_Model import DEFAULT_PT_ARCH_P, DEFAULT_STATS_ARCH_P, DEFAULT_WARCLASS_ARCH_P, DEFAULT_POS_ARCH_P, DEFAULT_LVL_ARCH_P, DEFAULT_PA_ARCH_P, DEFAULT_VALUE_ARCH_P
from tqdm.notebook import tqdm
from Constants import device
import seaborn as sns
from pylab import savefig
from matplotlib import pyplot as plt
from Constants import DEFAULT_HIDDEN_SIZE_PITCHER, DEFAULT_NUM_LAYERS_PITCHER
from pathlib import Path

# --- Experiment configuration -------------------------------------------------
# Index into Model_Train.ELEMENT_LIST selecting which output head's
# architecture is swept by this experiment.
output_idx = 6
element_name = Model_Train.ELEMENT_LIST[output_idx]
filename = element_name + "Arch"

# Pick the first "<filename>_P_<n>.png" under Experiments/ that does not exist
# yet, so results of earlier runs are never overwritten.
outfile_idx = 0
while True:
    outfile = f'Experiments/{filename}_P_{outfile_idx}.png'
    outfile_path = Path(outfile)
    if not outfile_path.exists():
        break
    outfile_idx += 1

batch_size = 800
ys = range(2, 5)         # swept number of layers -> one heatmap row per value
xs = range(20, 201, 10)  # swept layer size       -> one heatmap column per value
data = []                # filled by the sweep: data[row][col] = loss for (ys[row], xs[col])
title = element_name + " Loss vs Output Layer Architecture"

# Axis labels must follow the grid layout above: columns (x) are layer sizes,
# rows (y) are layer counts.
x_label = "Layer Size"
y_label = "Num Layers"

# --- Data ---------------------------------------------------------------------
data_prep = Data_Prep(Prep_Map.base_prep_map, Output_Map.base_output_map)
pitcher_io_list = data_prep.Generate_IO_Pitchers("WHERE lastMLBSeason<? AND signingYear<? AND isHitter=?", (2025,2015,0), use_cutoff=True)
# NOTE(review): 0.25 and 0 are presumably the test fraction and the split seed -
# confirm against Create_Test_Train_Datasets.
train_dataset, test_dataset = Create_Test_Train_Datasets(pitcher_io_list, 0.25, 0)

# Recurrent-core settings; these stay at the project defaults for the sweep.
num_layers = DEFAULT_NUM_LAYERS_PITCHER
hidden_size = DEFAULT_HIDDEN_SIZE_PITCHER
pitching_mutators = data_prep.Generate_Pitching_Mutators(batch_size, Player_IO.GetMaxLength(pitcher_io_list))

# Grid sweep over the architecture of the output head selected by `output_idx`.
# One model is trained per (layer count, layer size) pair; the resulting loss is
# stored so that `data` has one row per value in `ys` and one column per value
# in `xs`.
with tqdm(total=len(xs) * len(ys), desc="Total Training Runs") as pbar:
    # The loop variable is deliberately NOT called `num_layers`: that name holds
    # the recurrent depth (DEFAULT_NUM_LAYERS_PITCHER) passed to RNN_Model below.
    # Reusing it here would change the recurrent depth together with the head
    # depth and confound the experiment.
    for arch_num_layers in ys:
        row_losses = []
        for layer_size in xs:
            test_arch = LayerArch(layer_size=layer_size, num_layers=arch_num_layers)

            # Only the head under test gets the candidate architecture; every
            # other head keeps its default.
            warclass_arch = test_arch if output_idx == 0 else DEFAULT_WARCLASS_ARCH_P
            lvl_arch = test_arch if output_idx == 1 else DEFAULT_LVL_ARCH_P
            pa_arch = test_arch if output_idx == 2 else DEFAULT_PA_ARCH_P
            stats_arch = test_arch if output_idx == 3 else DEFAULT_STATS_ARCH_P
            pos_arch = test_arch if output_idx == 4 else DEFAULT_POS_ARCH_P
            val_arch = test_arch if output_idx == 5 else DEFAULT_VALUE_ARCH_P
            pt_arch = test_arch if output_idx == 6 else DEFAULT_PT_ARCH_P

            network = RNN_Model(train_dataset.get_input_size(),
                                num_layers,
                                hidden_size,
                                pitching_mutators,
                                data_prep=data_prep,
                                is_hitter=False,
                                pt_arch_p=pt_arch,
                                stats_arch_p=stats_arch,
                                warclass_arch_p=warclass_arch,
                                pos_arch_p=pos_arch,
                                lvl_arch_p=lvl_arch,
                                pa_arch_p=pa_arch,
                                val_arch_p=val_arch)
            network = network.to(device)

            # trainAndGraph is handed the datasets and batch size directly, so no
            # DataLoader is constructed here.
            best_loss = Model_Train.trainAndGraph(network,
                                                    train_dataset,
                                                    test_dataset,
                                                    batch_size=batch_size,
                                                    num_epochs=41,
                                                    logging_interval=10000,
                                                    early_stopping_cutoff=40,
                                                    should_output=False,
                                                    model_name="Models/exp",
                                                    save_last=False,
                                                    elements_to_save=[output_idx])
            # NOTE(review): index 0 is presumably the loss of the single element
            # requested via elements_to_save - confirm against trainAndGraph.
            row_losses.append(best_loss[0])
            pbar.update(1)
        data.append(row_losses)

In [None]:
# Draw the collected losses as an annotated heatmap (rows labelled with `ys`,
# columns labelled with `xs`) and write the figure to `outfile`.
fig_width = 1 * len(xs)
fig_height = .75 * len(ys) + 2
fig, ax = plt.subplots(figsize=(fig_width, fig_height))

heatmap = sns.heatmap(
    data,
    xticklabels=xs,
    yticklabels=ys,
    annot=True,
    fmt=".3f",
    ax=ax,
)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.set_title(title)

fig.tight_layout(pad=0.25)
fig.savefig(outfile, dpi=400)