# Train the models! 👟

This notebook imports training methods and models from the `src/` directory in order to train the models.

In [1]:
# setup imports
from src.utils import device, check_dir
from src.train import run_train
from src import models
from torchinfo import summary
import torch


print(device)  # check the device

torch.manual_seed(0)  # make training reproducable

cuda


<torch._C.Generator at 0x21bc3505150>

## RNN - Recurrent Neural Network

In [3]:
rnn_model = models.RNN(
    hidden_size=2048,
)

summary(rnn_model)

Layer (type:depth-idx)                   Param #
RNN                                      --
├─RNN: 1-1                               4,435,968
├─Linear: 1-2                            2,049
├─Sigmoid: 1-3                           --
Total params: 4,438,017
Trainable params: 4,438,017
Non-trainable params: 0

In [None]:
_m, _h = run_train(
    # fixed
    model=rnn_model,
    sequence_len=8,
    data_as_sequence=True,
    output_path=check_dir("rnn"),
    # adjust
    epochs=30,
    init_learning_rate=1e-3,
    weight_decay=1e-5,
    # data_use_all=True,
)

# LSTM - Long Short-Term Memory

In [2]:
lstm_model = models.LSTM(
    hidden_size=1024,
)

summary(lstm_model)

Layer (type:depth-idx)                   Param #
LSTM                                     --
├─LSTM: 1-1                              4,677,632
├─Linear: 1-2                            1,025
├─Sigmoid: 1-3                           --
Total params: 4,678,657
Trainable params: 4,678,657
Non-trainable params: 0

In [3]:
_m, _h = run_train(
    # fixed
    model=lstm_model,
    sequence_len=8,
    data_as_sequence=True,
    output_path=check_dir("lstm"),
    # adjust
    epochs=30,
    init_learning_rate=1e-3,
    weight_decay=1e-5,
    # data_use_all=True,
)

[2024-08-11 17:03:21] INFO     : src.train - Loading data...
[2024-08-11 17:03:26] INFO     : src.train - Beginning to train the network...
[2024-08-11 17:03:29] INFO     : src.train - EPOCH: 1/30
[2024-08-11 17:03:29] INFO     : src.train - Train loss: 0.6831, Train accuracy: 0.5747
[2024-08-11 17:03:29] INFO     : src.train - Val loss: 0.6745, Val accuracy: 0.5801

[2024-08-11 17:03:33] INFO     : src.train - EPOCH: 2/30
[2024-08-11 17:03:33] INFO     : src.train - Train loss: 0.6768, Train accuracy: 0.5771
[2024-08-11 17:03:33] INFO     : src.train - Val loss: 0.6672, Val accuracy: 0.6003

[2024-08-11 17:03:36] INFO     : src.train - EPOCH: 3/30
[2024-08-11 17:03:36] INFO     : src.train - Train loss: 0.6723, Train accuracy: 0.5866
[2024-08-11 17:03:36] INFO     : src.train - Val loss: 0.6661, Val accuracy: 0.5952

[2024-08-11 17:03:39] INFO     : src.train - EPOCH: 4/30
[2024-08-11 17:03:39] INFO     : src.train - Train loss: 0.6770, Train accuracy: 0.5783
[2024-08-11 17:03:39] INF

              precision    recall  f1-score   support

    AWAY_WIN       0.68      0.17      0.27       163
    HOME_WIN       0.56      0.93      0.70       184

    accuracy                           0.57       347
   macro avg       0.62      0.55      0.48       347
weighted avg       0.61      0.57      0.49       347



[2024-08-11 17:05:01] INFO     : src.train - Accuracy from next season: 0.5747


              precision    recall  f1-score   support

    AWAY_WIN       0.49      0.10      0.16       445
    HOME_WIN       0.58      0.92      0.71       606

    accuracy                           0.57      1051
   macro avg       0.54      0.51      0.44      1051
weighted avg       0.54      0.57      0.48      1051



[2024-08-11 17:05:02] INFO     : src.train - Accuracy on short streaks (training): 0.5247
[2024-08-11 17:05:02] INFO     : src.train - Accuracy on long streaks (training): 0.5233
[2024-08-11 17:05:02] INFO     : src.train - Accuracy on short streaks (evaluation): 0.4728
[2024-08-11 17:05:02] INFO     : src.train - Accuracy on long streaks (evaluation): 0.4394


# GRU - Gated Recurrent Unit

In [3]:
gru_model = models.GRU(
    hidden_size=1024,
)

summary(gru_model)

Layer (type:depth-idx)                   Param #
GRU                                      --
├─GRU: 1-1                               3,508,224
├─Linear: 1-2                            1,025
├─Sigmoid: 1-3                           --
Total params: 3,509,249
Trainable params: 3,509,249
Non-trainable params: 0

In [4]:
_m, _h = run_train(
    # fixed
    model=rnn_model,
    sequence_len=8,
    data_as_sequence=True,
    output_path=check_dir("gru"),
    # adjust
    epochs=30,
    init_learning_rate=1e-3,
    weight_decay=1e-5,
    # data_use_all=True,
)

[2024-08-11 16:41:32] INFO     : src.train - Loading data...
[2024-08-11 16:41:36] INFO     : src.train - Beginning to train the network...
[2024-08-11 16:41:41] INFO     : src.train - EPOCH: 1/30
[2024-08-11 16:41:41] INFO     : src.train - Train loss: 0.6771, Train accuracy: 0.5806
[2024-08-11 16:41:41] INFO     : src.train - Val loss: 0.6321, Val accuracy: 0.6383

[2024-08-11 16:41:45] INFO     : src.train - EPOCH: 2/30
[2024-08-11 16:41:45] INFO     : src.train - Train loss: 0.6487, Train accuracy: 0.6237
[2024-08-11 16:41:45] INFO     : src.train - Val loss: 0.6272, Val accuracy: 0.6457

[2024-08-11 16:41:49] INFO     : src.train - EPOCH: 3/30
[2024-08-11 16:41:49] INFO     : src.train - Train loss: 0.6396, Train accuracy: 0.6332
[2024-08-11 16:41:49] INFO     : src.train - Val loss: 0.6168, Val accuracy: 0.6539

[2024-08-11 16:41:52] INFO     : src.train - EPOCH: 4/30
[2024-08-11 16:41:52] INFO     : src.train - Train loss: 0.6376, Train accuracy: 0.6372
[2024-08-11 16:41:52] INF

              precision    recall  f1-score   support

    AWAY_WIN       0.71      0.43      0.53       163
    HOME_WIN       0.62      0.84      0.72       184

    accuracy                           0.65       347
   macro avg       0.67      0.64      0.63       347
weighted avg       0.66      0.65      0.63       347



[2024-08-11 16:43:33] INFO     : src.train - Accuracy from next season: 0.6051


              precision    recall  f1-score   support

    AWAY_WIN       0.55      0.37      0.44       445
    HOME_WIN       0.63      0.78      0.70       606

    accuracy                           0.61      1051
   macro avg       0.59      0.57      0.57      1051
weighted avg       0.59      0.61      0.59      1051



[2024-08-11 16:43:34] INFO     : src.train - Accuracy on short streaks (training): 0.4959
[2024-08-11 16:43:34] INFO     : src.train - Accuracy on long streaks (training): 0.3098
[2024-08-11 16:43:34] INFO     : src.train - Accuracy on short streaks (evaluation): 0.4686
[2024-08-11 16:43:35] INFO     : src.train - Accuracy on long streaks (evaluation): 0.1667


# TCN - Temporal Convolutional Network

In [None]:
tcn_model = models.TCN(
    channels=[32, 16, 4],
)

summary(tcn_model)

In [None]:
_m, _h = run_train(
    # fixed
    model=rnn_model,
    sequence_len=8,
    data_as_sequence=True,
    output_path=check_dir("tcn"),
    # adjust
    epochs=30,
    init_learning_rate=1e-3,
    weight_decay=1e-5,
    # data_use_all=True,
)

## TE - Tranformer Encoder

In [4]:
te_model = models.TE(
    hidden_size=2048,
)

summary(te_model)

Layer (type:depth-idx)                                  Param #
TE                                                      --
├─TransformerEncoderLayer: 1-1                          --
│    └─MultiheadAttention: 2-1                          2,586,336
│    │    └─NonDynamicallyQuantizableLinear: 3-1        862,112
│    └─Linear: 2-2                                      1,902,592
│    └─Dropout: 2-3                                     --
│    └─Linear: 2-4                                      1,901,472
│    └─LayerNorm: 2-5                                   1,856
│    └─LayerNorm: 2-6                                   1,856
│    └─Dropout: 2-7                                     --
│    └─Dropout: 2-8                                     --
├─Linear: 1-2                                           929
├─Sigmoid: 1-3                                          --
Total params: 7,257,153
Trainable params: 7,257,153
Non-trainable params: 0

In [5]:
_m, _h = run_train(
    # fixed
    model=rnn_model,
    sequence_len=8,
    data_as_sequence=False,
    output_path=check_dir("te"),
    # adjust
    epochs=30,
    init_learning_rate=1e-4,
    weight_decay=1e-5,
    # data_use_all=True,
)

[2024-08-11 17:21:42] INFO     : src.train - Loading data...
[2024-08-11 17:21:46] INFO     : src.train - Beginning to train the network...
[2024-08-11 17:21:48] INFO     : src.train - EPOCH: 1/30
[2024-08-11 17:21:48] INFO     : src.train - Train loss: 0.7967, Train accuracy: 0.5419
[2024-08-11 17:21:48] INFO     : src.train - Val loss: 0.6776, Val accuracy: 0.5939

[2024-08-11 17:21:51] INFO     : src.train - EPOCH: 2/30
[2024-08-11 17:21:51] INFO     : src.train - Train loss: 0.6900, Train accuracy: 0.5581
[2024-08-11 17:21:51] INFO     : src.train - Val loss: 0.6799, Val accuracy: 0.5939

[2024-08-11 17:21:53] INFO     : src.train - EPOCH: 3/30
[2024-08-11 17:21:53] INFO     : src.train - Train loss: 0.6882, Train accuracy: 0.5594
[2024-08-11 17:21:53] INFO     : src.train - Val loss: 0.6756, Val accuracy: 0.5939

[2024-08-11 17:21:56] INFO     : src.train - EPOCH: 4/30
[2024-08-11 17:21:56] INFO     : src.train - Train loss: 0.6860, Train accuracy: 0.5636
[2024-08-11 17:21:56] INF

              precision    recall  f1-score   support

    AWAY_WIN       0.00      0.00      0.00       163
    HOME_WIN       0.53      1.00      0.69       184

    accuracy                           0.53       347
   macro avg       0.27      0.50      0.35       347
weighted avg       0.28      0.53      0.37       347



[2024-08-11 17:23:01] INFO     : src.train - Accuracy from next season: 0.5766


              precision    recall  f1-score   support

    AWAY_WIN       0.00      0.00      0.00       445
    HOME_WIN       0.58      1.00      0.73       606

    accuracy                           0.58      1051
   macro avg       0.29      0.50      0.37      1051
weighted avg       0.33      0.58      0.42      1051



[2024-08-11 17:23:02] INFO     : src.train - Accuracy on short streaks (training): 0.6103
[2024-08-11 17:23:02] INFO     : src.train - Accuracy on long streaks (training): 0.6212
[2024-08-11 17:23:03] INFO     : src.train - Accuracy on short streaks (evaluation): 0.5649
[2024-08-11 17:23:03] INFO     : src.train - Accuracy on long streaks (evaluation): 0.5758
