In [1]:
import json
import os

import torch

from training import train_model, hyperparameter_tuning
from model import AdvancedLSTM, LSTMEncoderOnly

# TODO: persist the best hyperparameters found by tuning so they can be reused after training.

In [2]:
# Load the pre-split train/validation tensors. Paths are resolved against the current
# working directory, so the notebook must be run from the project root.
file_path = os.path.join(os.getcwd(), "final_data")
X_train, y_train = torch.load(os.path.join(file_path, "train_sequences.pt"), weights_only=True)
X_val, y_val = torch.load(os.path.join(file_path, "val_sequences.pt"), weights_only=True)

# Sanity checks: every input sequence has a target, and both splits share the same
# last dimension (the model's input_dim must match it for both train and val).
assert X_train.shape[0] == y_train.shape[0], "train inputs/targets length mismatch"
assert X_val.shape[0] == y_val.shape[0], "val inputs/targets length mismatch"
assert X_train.shape[-1] == X_val.shape[-1], "train/val feature dimension differs"

print(f"Train sequences: {X_train.shape}, Targets: {y_train.shape}")
print(f"Val sequences: {X_val.shape}, Targets: {y_val.shape}")

Train sequences: torch.Size([32472, 5, 26]), Targets: torch.Size([32472, 1])


In [None]:
# Cheap sanity check on the regression target range before the long search.
print("train max:", y_train.max().item(), "train min:", y_train.min().item())

# Hyperparameter tuning: 25 trials of 30 epochs each.
best_params = hyperparameter_tuning(X_train, y_train, X_val, y_val, epochs=30, n_trials=25)

# Persist the winning configuration so it survives the kernel and can be reused later.
best_params_path = os.path.join(file_path, "best_hyperparams.json")
with open(best_params_path, "w") as fh:
    # .item() unwraps numpy/torch scalars, which json cannot serialise directly.
    json.dump(
        {key: (val.item() if hasattr(val, "item") else val) for key, val in best_params.items()},
        fh,
        indent=2,
    )
print(f"Saved best hyperparameters to {best_params_path}")

# Full training with best hyperparameters: apply the tuned architecture AND the tuned
# optimiser settings / batch size. NOTE(review): assumes best_params carries the same
# keys printed per trial (learning_rate, weight_decay, batch_size, ...) — confirm.
print("\nTraining final model with best hyperparameters...")

FINAL_EPOCHS = 50

adv_model = AdvancedLSTM(
    input_dim=X_train.shape[-1],  # last dimension of X_train (26 for the current data)
    hidden_dim=best_params['hidden_dim'],
    output_dim=1,
    num_layers=best_params['num_layers'],
    dropout=best_params['dropout'],
    num_fc_layers=best_params['num_fc_layers'],
)

# Return value kept for inspection; its type is not visible from this notebook.
adv_train_result = train_model(
    model=adv_model,
    X_train=X_train,
    y_train=y_train,
    X_val=X_val,
    y_val=y_val,
    epochs=FINAL_EPOCHS,
    batch_size=best_params['batch_size'],
    learning_rate=best_params['learning_rate'],
    weight_decay=best_params['weight_decay'],
    verbose=2,
)

Train sequences: torch.Size([32472, 5, 26]), Targets: torch.Size([32472, 1])
Running random search with 25 trials...

Trial 1/25
Params: {'learning_rate': 0.005, 'hidden_dim': 64, 'weight_decay': 0.001, 'num_layers': 4, 'dropout': 0.1, 'num_fc_layers': 2, 'batch_size': 64}


KeyboardInterrupt: 

In [3]:
# Hand-picked hyperparameters for a single training run (not taken from best_params).
# Seed first so weight initialisation — and therefore the run — is reproducible.
SEED = 42
torch.manual_seed(SEED)

model = AdvancedLSTM(
    input_dim=X_train.shape[-1],  # last dimension of X_train (26 for the current data)
    hidden_dim=64,
    output_dim=1,
    num_layers=2,
    dropout=0.4,
    num_fc_layers=2,
)

In [5]:
# Quick inspection: display the hidden size the model was constructed with.
model_hidden_dim = model.hidden_dim
model_hidden_dim

64

In [6]:
# Training settings for the hand-configured model, gathered in one place.
# NOTE(review): in the recorded run, validation RMSE was lowest at epoch 2 and drifted
# upward afterwards while training MSE kept falling — consider fewer epochs or stronger
# regularisation.
training_config = dict(
    epochs=50,
    batch_size=32,
    learning_rate=0.005,
    weight_decay=0.01,
    verbose=2,
)

train_model(model=model, X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **training_config)

Epoch  1/50: 100%|██████████| 1015/1015 [00:15<00:00, 65.34it/s, loss=0.7165, lr=0.000252]


Epoch 1 Training MSE (log): 0.9242
Epoch 1 validation RMSE: 2.7272
Epoch 1 validation MAE: 1.9687
Best model saved at epoch 1 with RMSE: 2.7272


Epoch  2/50: 100%|██████████| 1015/1015 [00:15<00:00, 65.70it/s, loss=0.5670, lr=0.000408]


Epoch 2 Training MSE (log): 0.6381
Epoch 2 validation RMSE: 2.7199
Epoch 2 validation MAE: 2.0160
Best model saved at epoch 2 with RMSE: 2.7199


Epoch  3/50: 100%|██████████| 1015/1015 [00:15<00:00, 66.03it/s, loss=0.6382, lr=0.000658]


Epoch 3 Training MSE (log): 0.5952
Epoch 3 validation RMSE: 2.7374
Epoch 3 validation MAE: 2.0681


Epoch  4/50: 100%|██████████| 1015/1015 [00:15<00:00, 64.75it/s, loss=0.4272, lr=0.000994]


Epoch 4 Training MSE (log): 0.5545
Epoch 4 validation RMSE: 2.7391
Epoch 4 validation MAE: 2.1543


Epoch  5/50: 100%|██████████| 1015/1015 [00:15<00:00, 65.97it/s, loss=0.5730, lr=0.0014]


Epoch 5 Training MSE (log): 0.5230
Epoch 5 validation RMSE: 2.7441
Epoch 5 validation MAE: 2.1575


Epoch  6/50: 100%|██████████| 1015/1015 [00:15<00:00, 64.42it/s, loss=0.3615, lr=0.00186]


Epoch 6 Training MSE (log): 0.5027
Epoch 6 validation RMSE: 2.7694
Epoch 6 validation MAE: 2.1872


Epoch  7/50: 100%|██████████| 1015/1015 [00:15<00:00, 64.63it/s, loss=0.4849, lr=0.00235]


Epoch 7 Training MSE (log): 0.4951
Epoch 7 validation RMSE: 2.7319
Epoch 7 validation MAE: 2.1300


Epoch  8/50: 100%|██████████| 1015/1015 [00:16<00:00, 61.21it/s, loss=0.4868, lr=0.00285]


Epoch 8 Training MSE (log): 0.4920
Epoch 8 validation RMSE: 2.7409
Epoch 8 validation MAE: 2.1727


Epoch  9/50: 100%|██████████| 1015/1015 [00:15<00:00, 64.28it/s, loss=0.6825, lr=0.00334]


Epoch 9 Training MSE (log): 0.4923
Epoch 9 validation RMSE: 2.7393
Epoch 9 validation MAE: 2.1770


Epoch 10/50: 100%|██████████| 1015/1015 [00:15<00:00, 63.59it/s, loss=0.4297, lr=0.0038]


Epoch 10 Training MSE (log): 0.4901
Epoch 10 validation RMSE: 2.7442
Epoch 10 validation MAE: 2.1625


Epoch 11/50: 100%|██████████| 1015/1015 [00:16<00:00, 61.82it/s, loss=0.3348, lr=0.00421]


Epoch 11 Training MSE (log): 0.4902
Epoch 11 validation RMSE: 2.7399
Epoch 11 validation MAE: 2.1639


Epoch 12/50: 100%|██████████| 1015/1015 [00:16<00:00, 60.77it/s, loss=0.3809, lr=0.00454]


Epoch 12 Training MSE (log): 0.4886
Epoch 12 validation RMSE: 2.7300
Epoch 12 validation MAE: 2.1412


Epoch 13/50: 100%|██████████| 1015/1015 [00:16<00:00, 63.16it/s, loss=0.4356, lr=0.00479]


Epoch 13 Training MSE (log): 0.4845
Epoch 13 validation RMSE: 2.7707
Epoch 13 validation MAE: 2.2401


Epoch 14/50: 100%|██████████| 1015/1015 [00:15<00:00, 63.65it/s, loss=0.6322, lr=0.00495]


Epoch 14 Training MSE (log): 0.4820
Epoch 14 validation RMSE: 2.8543
Epoch 14 validation MAE: 2.2906


Epoch 15/50: 100%|██████████| 1015/1015 [00:16<00:00, 63.04it/s, loss=0.3203, lr=0.005] 


Epoch 15 Training MSE (log): 0.4797
Epoch 15 validation RMSE: 2.7481
Epoch 15 validation MAE: 2.1452


Epoch 16/50: 100%|██████████| 1015/1015 [00:15<00:00, 65.77it/s, loss=0.4307, lr=0.00499]


Epoch 16 Training MSE (log): 0.4772
Epoch 16 validation RMSE: 2.7446
Epoch 16 validation MAE: 2.1424


Epoch 17/50: 100%|██████████| 1015/1015 [00:16<00:00, 63.36it/s, loss=0.6951, lr=0.00496]


Epoch 17 Training MSE (log): 0.4706
Epoch 17 validation RMSE: 2.8452
Epoch 17 validation MAE: 2.2855


Epoch 18/50: 100%|██████████| 1015/1015 [00:15<00:00, 63.69it/s, loss=0.3570, lr=0.00491]


Epoch 18 Training MSE (log): 0.4652
Epoch 18 validation RMSE: 2.8365
Epoch 18 validation MAE: 2.2369


Epoch 19/50: 100%|██████████| 1015/1015 [00:16<00:00, 62.52it/s, loss=0.4359, lr=0.00484]


Epoch 19 Training MSE (log): 0.4534
Epoch 19 validation RMSE: 2.8488
Epoch 19 validation MAE: 2.2022


Epoch 20/50: 100%|██████████| 1015/1015 [00:15<00:00, 64.36it/s, loss=0.5708, lr=0.00475]


Epoch 20 Training MSE (log): 0.4434
Epoch 20 validation RMSE: 2.8933
Epoch 20 validation MAE: 2.2596


Epoch 21/50: 100%|██████████| 1015/1015 [00:16<00:00, 59.95it/s, loss=0.4143, lr=0.00465]


Epoch 21 Training MSE (log): 0.4274
Epoch 21 validation RMSE: 2.8357
Epoch 21 validation MAE: 2.1457


Epoch 22/50: 100%|██████████| 1015/1015 [00:16<00:00, 62.32it/s, loss=0.4041, lr=0.00452]


Epoch 22 Training MSE (log): 0.4130
Epoch 22 validation RMSE: 2.8513
Epoch 22 validation MAE: 2.1208


Epoch 23/50: 100%|██████████| 1015/1015 [00:16<00:00, 60.70it/s, loss=0.3806, lr=0.00438]


Epoch 23 Training MSE (log): 0.3982
Epoch 23 validation RMSE: 2.8559
Epoch 23 validation MAE: 2.0872


Epoch 24/50: 100%|██████████| 1015/1015 [00:16<00:00, 60.21it/s, loss=0.3016, lr=0.00423]


Epoch 24 Training MSE (log): 0.3853
Epoch 24 validation RMSE: 2.8537
Epoch 24 validation MAE: 2.0981


Epoch 25/50: 100%|██████████| 1015/1015 [00:16<00:00, 60.54it/s, loss=0.4588, lr=0.00406]


Epoch 25 Training MSE (log): 0.3738
Epoch 25 validation RMSE: 3.0047
Epoch 25 validation MAE: 2.2432


Epoch 26/50:  74%|███████▎  | 747/1015 [00:12<00:04, 60.85it/s, loss=0.2871, lr=0.00393]


KeyboardInterrupt: 