# Train the models! 👟

This notebook imports the training routines and model definitions from the `src/` package and uses them to train and evaluate the models.

In [2]:
# Configure the output directory; it is passed to `run_train` as `output_path`.
import os

output_dir = "output/tcn_test"

# makedirs also creates a missing "output/" parent (os.mkdir would raise
# FileNotFoundError on a fresh checkout), and exist_ok replaces the racy
# os.path.exists + os.mkdir check while keeping re-runs of this cell harmless.
os.makedirs(output_dir, exist_ok=True)

In [2]:
# Show which torch device `src.utils` exposes for training.
from src.utils import device

print(str(device))

cuda


In [4]:
from src.train import run_train
from src.models import TCN

# TCN from src.models with channel sizes [32, 8].
model = TCN(channels=[32, 8])

# weight_decay is deliberately not passed here (the RNN/LSTM sweep passes 1e-5).
# NOTE(review): the saved output of this cell ends with "Serialization of
# parametrized modules is only supported through state_dict()" -- presumably
# raised when save_return=True saves the whole model object; confirm in src.train.
train_kwargs = dict(
    sequence_len=10,
    epochs=30,
    output_path=output_dir,
    save_return=True,
)
run_train(model=model, **train_kwargs)

[2024-07-24 20:06:24] INFO     : src.train - Loading data...
[2024-07-24 20:06:27] INFO     : src.train - Beginning to train the network...
[2024-07-24 20:06:30] INFO     : src.train - EPOCH: 1/30
[2024-07-24 20:06:30] INFO     : src.train - Train loss: 4.7455, Train accuracy: 0.5175
[2024-07-24 20:06:30] INFO     : src.train - Val loss: 0.9294, Val accuracy: 0.5520

[2024-07-24 20:06:32] INFO     : src.train - EPOCH: 2/30
[2024-07-24 20:06:32] INFO     : src.train - Train loss: 0.8089, Train accuracy: 0.5431
[2024-07-24 20:06:32] INFO     : src.train - Val loss: 0.7181, Val accuracy: 0.5637

[2024-07-24 20:06:35] INFO     : src.train - EPOCH: 3/30
[2024-07-24 20:06:35] INFO     : src.train - Train loss: 0.7025, Train accuracy: 0.5690
[2024-07-24 20:06:35] INFO     : src.train - Val loss: 0.6937, Val accuracy: 0.5723

[2024-07-24 20:06:37] INFO     : src.train - EPOCH: 4/30
[2024-07-24 20:06:37] INFO     : src.train - Train loss: 0.6809, Train accuracy: 0.5869
[2024-07-24 20:06:37] INF

              precision    recall  f1-score   support

    AWAY_WIN       0.65      0.44      0.52       163
    HOME_WIN       0.61      0.79      0.69       184

    accuracy                           0.63       347
   macro avg       0.63      0.61      0.61       347
weighted avg       0.63      0.63      0.61       347



[2024-07-24 20:07:39] INFO     : src.train - Accuracy from next season: 0.6061


              precision    recall  f1-score   support

    AWAY_WIN       0.55      0.40      0.46       445
    HOME_WIN       0.63      0.76      0.69       606

    accuracy                           0.61      1051
   macro avg       0.59      0.58      0.58      1051
weighted avg       0.60      0.61      0.59      1051



[2024-07-24 20:07:40] INFO     : src.train - Accuracy on short streaks (training): 0.5138
[2024-07-24 20:07:40] INFO     : src.train - Accuracy on long streaks (training): 0.2761
[2024-07-24 20:07:40] INFO     : src.train - Accuracy on short streaks (evaluation): 0.4770
[2024-07-24 20:07:40] INFO     : src.train - Accuracy on long streaks (evaluation): 0.1515


RuntimeError: Serialization of parametrized modules is only supported through state_dict(). See:
https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training

In [11]:
# Hyperparameters.
# NOTE(review): hidden_size is not referenced by the RNN/LSTM sweep below,
# which passes hidden_size=256 explicitly.
hidden_size: int = 128

In [3]:
from src.models.rnn import RNN, LSTM
from src.train import run_train

Hyperparameters and settings still to test:
- learning rate
- input normalization
- hidden size
- regularization via weight decay, e.g. `optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)`
- dropout

In [4]:
# Train each recurrent model class at every listed sequence length, using
# dropout 0.1 and weight decay 1e-5.
# GRU was part of the original sweep list but is not imported in this notebook.
sequence_lengths = [10]

for model_class in (RNN, LSTM):
    for seq_len in sequence_lengths:
        model = model_class(hidden_size=256, dropout=0.1)
        run_train(
            model=model,
            sequence_len=seq_len,
            epochs=30,
            output_path=output_dir,
            save_return=True,
            weight_decay=1e-5,
        )

[2024-07-24 12:28:22] INFO     : src.train - Loading data...
[2024-07-24 12:28:28] INFO     : src.train - Beginning to train the network...
[2024-07-24 12:28:30] INFO     : src.train - EPOCH: 1/30
[2024-07-24 12:28:30] INFO     : src.train - Train loss: 0.6860, Train accuracy: 0.5654
[2024-07-24 12:28:30] INFO     : src.train - Val loss: 0.6700, Val accuracy: 0.5947

[2024-07-24 12:28:32] INFO     : src.train - EPOCH: 2/30
[2024-07-24 12:28:32] INFO     : src.train - Train loss: 0.6809, Train accuracy: 0.5720
[2024-07-24 12:28:32] INFO     : src.train - Val loss: 0.6695, Val accuracy: 0.5965

[2024-07-24 12:28:35] INFO     : src.train - EPOCH: 3/30
[2024-07-24 12:28:35] INFO     : src.train - Train loss: 0.6776, Train accuracy: 0.5739
[2024-07-24 12:28:35] INFO     : src.train - Val loss: 0.6708, Val accuracy: 0.5973

[2024-07-24 12:28:37] INFO     : src.train - EPOCH: 4/30
[2024-07-24 12:28:37] INFO     : src.train - Train loss: 0.6785, Train accuracy: 0.5775
[2024-07-24 12:28:37] INF

              precision    recall  f1-score   support

    AWAY_WIN       0.00      0.00      0.00       163
    HOME_WIN       0.53      1.00      0.69       184

    accuracy                           0.53       347
   macro avg       0.27      0.50      0.35       347
weighted avg       0.28      0.53      0.37       347



[2024-07-24 12:29:33] INFO     : src.train - Accuracy from next season: 0.5766


              precision    recall  f1-score   support

    AWAY_WIN       0.00      0.00      0.00       445
    HOME_WIN       0.58      1.00      0.73       606

    accuracy                           0.58      1051
   macro avg       0.29      0.50      0.37      1051
weighted avg       0.33      0.58      0.42      1051



[2024-07-24 12:29:35] INFO     : src.train - Accuracy on short streaks (training): 0.6103


              precision    recall  f1-score   support

    AWAY_WIN       0.00      0.00      0.00       719
    HOME_WIN       0.61      1.00      0.76      1126

    accuracy                           0.61      1845
   macro avg       0.31      0.50      0.38      1845
weighted avg       0.37      0.61      0.46      1845



[2024-07-24 12:29:35] INFO     : src.train - Accuracy on long streaks (training): 0.6212
[2024-07-24 12:29:36] INFO     : src.train - Accuracy on short streaks (evaluation): 0.5649


              precision    recall  f1-score   support

    AWAY_WIN       0.00      0.00      0.00       236
    HOME_WIN       0.62      1.00      0.77       387

    accuracy                           0.62       623
   macro avg       0.31      0.50      0.38       623
weighted avg       0.39      0.62      0.48       623

              precision    recall  f1-score   support

    AWAY_WIN       0.00      0.00      0.00       104
    HOME_WIN       0.56      1.00      0.72       135

    accuracy                           0.56       239
   macro avg       0.28      0.50      0.36       239
weighted avg       0.32      0.56      0.41       239



[2024-07-24 12:29:36] INFO     : src.train - Accuracy on long streaks (evaluation): 0.5758


              precision    recall  f1-score   support

    AWAY_WIN       0.00      0.00      0.00        28
    HOME_WIN       0.58      1.00      0.73        38

    accuracy                           0.58        66
   macro avg       0.29      0.50      0.37        66
weighted avg       0.33      0.58      0.42        66



[2024-07-24 12:29:36] INFO     : src.train - Loading data...
[2024-07-24 12:29:40] INFO     : src.train - Beginning to train the network...
[2024-07-24 12:29:43] INFO     : src.train - EPOCH: 1/30
[2024-07-24 12:29:43] INFO     : src.train - Train loss: 0.6732, Train accuracy: 0.5852
[2024-07-24 12:29:43] INFO     : src.train - Val loss: 0.6608, Val accuracy: 0.6003

[2024-07-24 12:29:46] INFO     : src.train - EPOCH: 2/30
[2024-07-24 12:29:46] INFO     : src.train - Train loss: 0.6665, Train accuracy: 0.5978
[2024-07-24 12:29:46] INFO     : src.train - Val loss: 0.6893, Val accuracy: 0.5270

[2024-07-24 12:29:49] INFO     : src.train - EPOCH: 3/30
[2024-07-24 12:29:49] INFO     : src.train - Train loss: 0.6675, Train accuracy: 0.5946
[2024-07-24 12:29:49] INFO     : src.train - Val loss: 0.6533, Val accuracy: 0.6180

[2024-07-24 12:29:51] INFO     : src.train - EPOCH: 4/30
[2024-07-24 12:29:51] INFO     : src.train - Train loss: 0.6626, Train accuracy: 0.6044
[2024-07-24 12:29:51] INF

              precision    recall  f1-score   support

    AWAY_WIN       0.69      0.12      0.21       163
    HOME_WIN       0.55      0.95      0.70       184

    accuracy                           0.56       347
   macro avg       0.62      0.54      0.45       347
weighted avg       0.62      0.56      0.47       347



[2024-07-24 12:30:56] INFO     : src.train - Accuracy from next season: 0.5595


              precision    recall  f1-score   support

    AWAY_WIN       0.45      0.20      0.27       445
    HOME_WIN       0.58      0.83      0.68       606

    accuracy                           0.56      1051
   macro avg       0.52      0.51      0.48      1051
weighted avg       0.53      0.56      0.51      1051



[2024-07-24 12:30:57] INFO     : src.train - Accuracy on short streaks (training): 0.4591


              precision    recall  f1-score   support

    AWAY_WIN       0.10      0.05      0.06       719
    HOME_WIN       0.54      0.72      0.62      1126

    accuracy                           0.46      1845
   macro avg       0.32      0.38      0.34      1845
weighted avg       0.37      0.46      0.40      1845



[2024-07-24 12:30:57] INFO     : src.train - Accuracy on long streaks (training): 0.4944
[2024-07-24 12:30:58] INFO     : src.train - Accuracy on short streaks (evaluation): 0.4519


              precision    recall  f1-score   support

    AWAY_WIN       0.11      0.05      0.07       236
    HOME_WIN       0.57      0.77      0.65       387

    accuracy                           0.49       623
   macro avg       0.34      0.41      0.36       623
weighted avg       0.39      0.49      0.43       623

              precision    recall  f1-score   support

    AWAY_WIN       0.14      0.05      0.07       104
    HOME_WIN       0.51      0.76      0.61       135

    accuracy                           0.45       239
   macro avg       0.32      0.41      0.34       239
weighted avg       0.35      0.45      0.38       239



[2024-07-24 12:30:58] INFO     : src.train - Accuracy on long streaks (evaluation): 0.4545


              precision    recall  f1-score   support

    AWAY_WIN       0.10      0.04      0.05        28
    HOME_WIN       0.52      0.76      0.62        38

    accuracy                           0.45        66
   macro avg       0.31      0.40      0.33        66
weighted avg       0.34      0.45      0.38        66



In [None]:
# utility function to work out current streak from list of results
def current_streak(game_results: list[int]) -> int:
    """Return the length of the run of identical results at the start of ``game_results``.

    Counting starts at index 0 and stops at the first result that differs
    from ``game_results[0]``. NOTE(review): this presumably assumes the list
    is ordered most-recent-first -- confirm against the data source.

    Args:
        game_results: encoded game results (e.g. 0/1).

    Returns:
        The number of leading elements equal to the first element. Returns 0
        for an empty list (previously 1, which claimed a streak with no games).
    """
    if not game_results:
        return 0

    first = game_results[0]
    streak = 0
    for result in game_results:
        if result != first:
            break
        streak += 1
    return streak

In [None]:
current_streak([0] * 6 + [1])  # six leading 0s, then a 1 -> streak of 6

In [None]:
import pandas as pd

# Short-streak evaluation set, stored as parquet.
streaks_path = "data/parquet/evaluation_streaks_short_df.parquet"
df = pd.read_parquet(streaks_path)

In [None]:
# Peek at the first row only: print its "info" value, then each element of "data".
# head(1) yields nothing for an empty frame, just like the original loop + break.
for _row_idx, row in df.head(1).iterrows():
    print(row["info"])
    for item in row["data"]:
        print(item)

In [43]:
from src.predict import load_record_from_csv
import torch

# Load two single-game records and stack them into one batch.
csv_paths = [
    "data/predict_csv/lac_home_win_vs_lal_2023_04_05.csv",
    "data/predict_csv/ind_home_loss_vs_nyk_2023_04_05.csv",
]
data_one, data_two = (load_record_from_csv(file_path=path) for path in csv_paths)

for record in (data_one, data_two):
    print(record.shape)

data = torch.cat((data_one, data_two))

print(data.shape)

torch.Size([1, 10, 116])
torch.Size([1, 10, 116])
torch.Size([2, 10, 116])


In [8]:
from pytorch_tcn import TCN as _TCN
import torch.nn as nn
import torch


class TCN(nn.Module):
    """Temporal convolutional network wrapper producing one sigmoid output per sequence.

    Thin wrapper around ``pytorch_tcn.TCN`` (imported as ``_TCN``). It keeps the
    ``hidden_size``/``init_zeroes``/``__repr__`` interface used by the recurrent
    models in this notebook so it can be passed to ``run_train``.

    NOTE(review): this notebook-local class shadows ``src.models.TCN`` imported
    in an earlier cell; any later ``TCN(...)`` call gets this version.
    """

    def __init__(
        self,
        hidden_size: int,
        input_size: int = 116,
        dropout: float = 0.0,
        num_channels: tuple[int, ...] = (64, 16, 4),
    ):
        """Build the underlying ``_TCN``.

        Args:
            hidden_size: only used by ``init_zeroes``; the convolution widths
                come from ``num_channels``.
            input_size: number of features per time step (116 for the records
                loaded in this notebook).
            dropout: dropout probability, now forwarded to ``_TCN``. It used to
                be stored but silently ignored.
            num_channels: output channels of each TCN level. Defaults to the
                previously hard-coded ``(64, 16, 4)``.
        """
        super().__init__()

        self._hidden_size = hidden_size
        self._input_size = input_size
        self._dropout = dropout
        self._output_size = 1
        self._num_layers = 1

        self.tcn = _TCN(
            num_inputs=input_size,
            num_channels=list(num_channels),
            dropout=dropout,
            output_projection=self._output_size,
            output_activation="sigmoid",
            # inputs arrive as (batch, length, channels), e.g. torch.Size([2, 10, 116])
            input_shape="NLC",
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run the TCN and keep only the output at the last time step.

        Args:
            input: tensor in "NLC" layout, i.e. (batch, seq_len, input_size).

        Returns:
            ``output[:, -1, :]``. Presumably (batch, 1), assuming ``_TCN``
            returns (batch, seq_len, output_projection) for NLC input -- TODO confirm.
        """
        output = self.tcn(input)
        return output[:, -1, :]

    def init_zeroes(self, batch_size: int):
        """Return a zero tensor of shape (1, batch_size, hidden_size) on ``device``.

        ``forward`` never uses this tensor. It is presumably kept for interface
        parity with the recurrent models. ``device`` comes from ``src.utils``
        (imported in an earlier cell).
        """
        zeroes = torch.zeros(1, batch_size, self._hidden_size)
        return zeroes.to(device)

    def __repr__(self) -> str:
        return type(self).__name__

In [51]:
# Sanity-check a forward pass on the two-record batch built above.
# NOTE(review): the saved output shows pred.shape == (2, 10, 4), which the
# notebook-local TCN.forward (it slices the last time step) cannot produce.
# `model` was likely left over from another cell -- rerun top-to-bottom to confirm.
print(data.shape)

pred = model(data)
for shown in (pred.shape, pred):
    print(shown)

torch.Size([2, 10, 116])
torch.Size([2, 10, 4])
tensor([[[5.3738e-01, 5.6585e-01, 5.0000e-01, 5.0000e-01],
         [8.0683e-01, 9.2744e-01, 5.0000e-01, 5.0000e-01],
         [1.3775e-02, 1.8514e-04, 5.0000e-01, 5.0000e-01],
         [9.9998e-01, 1.3542e-04, 5.0000e-01, 5.0000e-01],
         [4.8602e-09, 7.8826e-01, 5.0000e-01, 5.0000e-01],
         [1.1819e-07, 6.4642e-06, 5.0000e-01, 5.0000e-01],
         [7.2147e-01, 9.9929e-01, 5.0000e-01, 5.0000e-01],
         [9.9229e-01, 7.7033e-03, 5.0000e-01, 5.0000e-01],
         [3.7591e-02, 6.5476e-04, 5.0000e-01, 5.0000e-01],
         [1.0000e+00, 1.0000e+00, 5.0000e-01, 5.0000e-01]],

        [[4.6262e-01, 4.3415e-01, 5.0000e-01, 5.0000e-01],
         [1.9317e-01, 7.2559e-02, 5.0000e-01, 5.0000e-01],
         [9.8623e-01, 9.9981e-01, 5.0000e-01, 5.0000e-01],
         [1.9689e-05, 9.9986e-01, 5.0000e-01, 5.0000e-01],
         [1.0000e+00, 2.1174e-01, 5.0000e-01, 5.0000e-01],
         [1.0000e+00, 9.9999e-01, 5.0000e-01, 5.0000e-01],
      

In [10]:
from src.train import run_train

# NOTE(review): `TCN` here is the notebook-local class (it shadows
# src.models.TCN); its hidden_size is only read by init_zeroes.
# The saved output shows a CUDA "device-side assert triggered" error whose cause
# is not visible in this notebook -- rerun with CUDA_LAUNCH_BLOCKING=1 to locate it.
model = TCN(hidden_size=0)

tcn_train_kwargs = {
    "sequence_len": 10,
    "epochs": 30,
    "output_path": output_dir,
    "save_return": True,
    "weight_decay": 1e-5,
}
run_train(model=model, **tcn_train_kwargs)

[2024-07-24 18:59:22] INFO     : src.train - Loading data...


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
