In [None]:
import torch
import numpy as np
import torch.nn as nn
from pathlib import Path
from tqdm import tqdm
from sklearn.metrics import r2_score, mean_absolute_error, root_mean_squared_error

    
def val_step(
    logits: torch.Tensor,
    labels: torch.Tensor,
    criterion: torch.nn.MSELoss
) -> torch.Tensor:
    loss = criterion(logits, labels)
    return loss


def train_step(
    model,
    logits: torch.Tensor,
    labels: torch.Tensor,
    criterion: torch.nn.MSELoss,
    optimizer: torch.optim.Optimizer
) -> torch.Tensor:
    optimizer.zero_grad()
    loss = criterion(logits, labels)
    loss.backward()
    max_norm = 1.0 # Define the maximum allowed gradient norm
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return loss


def train(
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    test_dataloader: torch.utils.data.DataLoader,
    timestamps: int = 36,
    epochs: int = 20,
    device: str = "cuda"
) -> None:
    
    criterion = nn.HuberLoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-2)

    for epoch in range(epochs):
        
        losses = []
        r2s = []
        mses = []
        rmses = []
        
        model.train()
        
        train_tqdm = tqdm(train_dataloader, total=len(train_dataloader))
        
        for inputs, labels in train_tqdm:
            
            inputs = inputs.to("cuda")
            labels = labels.to("cuda")
            
            logits: torch.Tensor = model(inputs).squeeze(1)
            
            loss = train_step(model, logits, labels, criterion, optimizer)
            
            r2 = r2_score(logits.detach().cpu().numpy(), labels.detach().cpu().numpy())
            mse = mean_absolute_error(logits.detach().cpu().numpy(), labels.detach().cpu().numpy())
            rmse = root_mean_squared_error(logits.detach().cpu().numpy(), labels.detach().cpu().numpy())
            
            r2s.append(r2)
            mses.append(mse)
            rmses.append(rmse)
            losses.append(loss.item())
            
            train_tqdm.set_description(
                f"Train Epoch {epoch}, Loss - {np.mean(losses):0.4f}, R^2 - {np.mean(r2s):0.4f}, MAE - {np.mean(mses):0.4f}, RMSE - {np.mean(rmses):0.4f}"
            )
            
        r2s.clear()
        mses.clear()
        rmses.clear()
        losses.clear()
        
        with torch.no_grad():
            
            model.eval()
            
            test_tqdm = tqdm(test_dataloader, total=len(test_dataloader))
            
            for full_batch in test_tqdm:
                
                full_batch = full_batch.to(device)
                
                inputs = full_batch[:, :-timestamps, :]
                labels = full_batch[:, -timestamps:, :][::240]
                
                logits = model(inputs)
            
                eval_labels = labels.reshape(test_dataloader.batch_size, -1)
                eval_logits = logits.reshape(test_dataloader.batch_size, -1)

                loss = val_step(eval_logits, eval_labels, criterion)
            
                r2 = r2_score(eval_logits.detach().cpu().numpy(), eval_labels.detach().cpu().numpy())
                mse = mean_absolute_error(eval_logits.detach().cpu().numpy(), eval_labels.detach().cpu().numpy())
                rmse = root_mean_squared_error(eval_logits.detach().cpu().numpy(), eval_labels.detach().cpu().numpy())
                
                r2s.append(r2)
                mses.append(mse)
                rmses.append(rmse)
                losses.append(loss.item())
                
                test_tqdm.set_description(
                    f"Test, Loss - {np.mean(losses):0.4f}, R^2 - {np.mean(r2s):0.4f}, MAE - {np.mean(mses):0.4f}, RMSE - {np.mean(rmses):0.4f}"
                )
                
        print()

    torch.save(model.state_dict(), "lstm_model.pt")

In [6]:
from src.dataset import TimestampsDataset
from src.lstm import LSTMForecaster

INPUT_SEQUENCE = 239
OUTPUT_SEQUENCE = 1
DEVICE = "cuda"

model = LSTMForecaster(input_features=6, output_timestamps=OUTPUT_SEQUENCE, hidden_size=64, num_layers=2)
model = model.to(DEVICE)
train_dataset = TimestampsDataset(data=Path("train.csv"), lags=INPUT_SEQUENCE + OUTPUT_SEQUENCE)
test_dataset = TimestampsDataset(data=Path("test.csv"), lags=INPUT_SEQUENCE + OUTPUT_SEQUENCE)

train(
    timestamps=OUTPUT_SEQUENCE,
    model=model, 
    train_dataloader=torch.utils.data.DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        drop_last=True
    ),
    test_dataloader=torch.utils.data.DataLoader(
        test_dataset,
        batch_size=4,
        shuffle=True,
        drop_last=True
    ),
    device=DEVICE
)

Train Epoch 0, Loss - 0.1466, R^2 - -82031.6428, MAE - 0.2700, RMSE - 0.3014:   0%|          | 60/58031 [00:00<03:11, 303.43it/s] 

logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch

Train Epoch 0, Loss - 0.0748, R^2 - -70247.4817, MAE - 0.1758, RMSE - 0.2068:   0%|          | 132/58031 [00:00<02:51, 336.73it/s]

logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch

Train Epoch 0, Loss - 0.0521, R^2 - -55766.1368, MAE - 0.1457, RMSE - 0.1767:   0%|          | 204/58031 [00:00<02:46, 347.45it/s]

logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch

Train Epoch 0, Loss - 0.0415, R^2 - -42956.8364, MAE - 0.1308, RMSE - 0.1635:   0%|          | 277/58031 [00:00<02:44, 352.06it/s]

logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch

Train Epoch 0, Loss - 0.0348, R^2 - -34698.3804, MAE - 0.1194, RMSE - 0.1527:   1%|          | 348/58031 [00:01<02:46, 346.48it/s]

logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch

Train Epoch 0, Loss - 0.0302, R^2 - -29246.0025, MAE - 0.1109, RMSE - 0.1449:   1%|          | 418/58031 [00:01<02:46, 345.08it/s]

logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch

Train Epoch 0, Loss - 0.0269, R^2 - -25401.1937, MAE - 0.1043, RMSE - 0.1393:   1%|          | 454/58031 [00:01<02:46, 346.65it/s]

logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch

Train Epoch 0, Loss - 0.0244, R^2 - -22500.0992, MAE - 0.0989, RMSE - 0.1348:   1%|          | 527/58031 [00:01<02:43, 352.70it/s]

logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch

Train Epoch 0, Loss - 0.0225, R^2 - -20202.4272, MAE - 0.0945, RMSE - 0.1307:   1%|          | 600/58031 [00:01<02:41, 354.90it/s]

logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch

Train Epoch 0, Loss - 0.0209, R^2 - -18455.6988, MAE - 0.0913, RMSE - 0.1279:   1%|          | 672/58031 [00:02<02:42, 351.90it/s]

logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch

Train Epoch 0, Loss - 0.0196, R^2 - -17111.4509, MAE - 0.0884, RMSE - 0.1251:   1%|▏         | 744/58031 [00:02<02:43, 350.29it/s]

logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch

Train Epoch 0, Loss - 0.0186, R^2 - -16201.2221, MAE - 0.0861, RMSE - 0.1230:   1%|▏         | 816/58031 [00:02<02:43, 349.55it/s]

logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch

Train Epoch 0, Loss - 0.0176, R^2 - -15695.9280, MAE - 0.0841, RMSE - 0.1211:   2%|▏         | 886/58031 [00:02<02:44, 348.14it/s]

logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch

Train Epoch 0, Loss - 0.0173, R^2 - -15525.7363, MAE - 0.0833, RMSE - 0.1202:   2%|▏         | 942/58031 [00:02<02:44, 346.02it/s]


logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch.Size([16, 6])
labels:  torch.Size([16, 6])
logits:  torch

KeyboardInterrupt: 

In [4]:
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    """
    Injects positional information into the input embeddings.
    Uses sine and cosine functions of different frequencies.
    """
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Create a long enough positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # Shape: [max_len, 1]
        # Term for calculating frequencies
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Apply sine to even indices in pe
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices in pe
        pe[:, 1::2] = torch.cos(position * div_term)
        # Add a batch dimension: [1, max_len, d_model] -> becomes [max_len, 1, d_model] after unsqueeze(0) below
        pe = pe.unsqueeze(0).transpose(0, 1) # Shape: [max_len, 1, d_model]

        # Register 'pe' as a buffer that should not be considered a model parameter.
        # 'pe' will be moved to the correct device along with the module.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [sequence_length, batch_size, d_model]
               OR [batch_size, sequence_length, d_model] if batch_first=True elsewhere
        """
        # If batch_first=True was used for the input x to this module:
        # x shape: [batch_size, sequence_length, d_model]
        # self.pe shape: [max_len, 1, d_model]
        # We need to add positional encodings to x.
        # Select positional encodings up to the sequence length of x: self.pe[:x.size(1), :]
        # Transpose pe slice to match x's batch-first format if needed, or adjust x.
        # Let's assume x comes in as [batch_size, sequence_length, d_model]
        # self.pe[:x.size(1), :].transpose(0,1) gives shape [1, sequence_length, d_model]
        # This will broadcast correctly during addition.

        # If x is [sequence_length, batch_size, d_model] (PyTorch default)
        # x = x + self.pe[:x.size(0), :]
        # return self.dropout(x)

        # If x is [batch_size, sequence_length, d_model] (using batch_first=True)
        # Need pe slice shape [1, sequence_length, d_model]
        pe_slice = self.pe[:x.size(1), :].transpose(0, 1) # Shape [1, sequence_length, d_model]
        x = x + pe_slice # Broadcasting adds positional encoding to each batch element
        return self.dropout(x)


class TransformerTimeSeriesPredictor(nn.Module):
    """
    Transformer-based model for time series prediction.

    Args:
        input_dim (int): The number of features in the input time series (e.g., 6).
        d_model (int): The dimension of the transformer embeddings and hidden layers.
                       Must be divisible by nhead.
        nhead (int): The number of attention heads in the multi-head attention mechanism.
        num_encoder_layers (int): The number of stacked transformer encoder layers.
        dim_feedforward (int): The dimension of the feedforward network model in encoder layers.
        output_dim (int): The number of time steps to predict into the future.
        dropout (float): Dropout probability.
        max_len (int): Maximum sequence length for positional encoding.
    """
    def __init__(self, input_dim, d_model, nhead, num_encoder_layers, dim_feedforward, output_dim, dropout=0.1, max_len=5000):
        super(TransformerTimeSeriesPredictor, self).__init__()
        
        self.output_timestamps = output_dim
        
        self.input_dim = input_dim

        self.d_model = d_model

        # --- Input Embedding ---
        # Linear layer to project input features to d_model dimension
        self.input_embedding = nn.Linear(input_dim, d_model)

        # --- Positional Encoding ---
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)

        # --- Transformer Encoder ---
        # Define a single encoder layer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='relu', # or 'gelu'
            batch_first=True # IMPORTANT: Input/Output shape (batch, seq, feature)
        )
        # Stack multiple encoder layers
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_encoder_layers,
            norm=nn.LayerNorm(d_model) # Optional final normalization
        )

        # --- Output Layer ---
        # Linear layer to map the final transformer output to the desired prediction dimension
        self.output_layer = nn.Linear(d_model, output_dim)

        # Initialize weights (optional but often recommended)
        self._init_weights()

    def _init_weights(self):
        # Initialize weights for linear layers
        initrange = 0.1
        self.input_embedding.weight.data.uniform_(-initrange, initrange)
        self.input_embedding.bias.data.zero_()
        self.output_layer.weight.data.uniform_(-initrange, initrange)
        self.output_layer.bias.data.zero_()

    def forward(self, src):
        """
        Defines the forward pass of the model.

        Args:
            src (torch.Tensor): The input time series data with shape
                               (batch_size, sequence_length, input_dim).

        Returns:
            torch.Tensor: The predicted values with shape (batch_size, output_dim).
        """
        # 1. Input Embedding
        # src shape: [batch_size, seq_len, input_dim]
        # embedded shape: [batch_size, seq_len, d_model]
        embedded = self.input_embedding(src) * math.sqrt(self.d_model) # Scale embedding

        # 2. Positional Encoding
        # pos_encoded shape: [batch_size, seq_len, d_model]
        pos_encoded = self.pos_encoder(embedded)

        # 3. Transformer Encoder
        # encoder_output shape: [batch_size, seq_len, d_model]
        # Note: No mask is applied here, assuming we only use the final output
        # for forecasting. If using for other tasks or intermediate outputs,
        # a mask might be needed (e.g., nn.Transformer.generate_square_subsequent_mask).
        encoder_output = self.transformer_encoder(pos_encoded)

        # 4. Output Layer
        # We typically use the output corresponding to the *last* time step
        # of the input sequence for forecasting.
        # last_step_output shape: [batch_size, d_model]
        last_step_output = encoder_output[:, -1, :]

        # prediction shape: [batch_size, output_dim]
        prediction = self.output_layer(last_step_output)

        return prediction.reshape(src.shape[0], self.output_timestamps // 6, self.input_dim)

In [6]:
from src.dataset import TimestampsDataset

INPUT_SEQUENCE = 120
OUTPUT_SEQUENCE = 240
DEVICE = "cuda"

# --- Hyperparameters ---
INPUT_DIM = 6        # Number of features per time step

# --- Model Specific Hyperparameters ---
D_MODEL = 256        # Embedding dimension (must be divisible by NHEAD)
NHEAD = 4            # Number of attention heads
NUM_ENCODER_LAYERS = 3 # Number of transformer encoder layers
DIM_FEEDFORWARD = 512 # Hidden dimension in feedforward network
OUTPUT_DIM = OUTPUT_SEQUENCE * 6      # Predict next 5 time steps
DROPOUT = 0.2

# --- Model Initialization ---
model = TransformerTimeSeriesPredictor(
    input_dim=INPUT_DIM,
    d_model=D_MODEL,
    nhead=NHEAD,
    num_encoder_layers=NUM_ENCODER_LAYERS,
    dim_feedforward=DIM_FEEDFORWARD,
    output_dim=OUTPUT_DIM,
    dropout=DROPOUT
)

model = model.to(DEVICE)
train_dataset = TimestampsDataset(data=Path("train.csv"), lags=INPUT_SEQUENCE + OUTPUT_SEQUENCE)
test_dataset = TimestampsDataset(data=Path("test.csv"), lags=INPUT_SEQUENCE + OUTPUT_SEQUENCE)

train(
    model,
    timestamps=OUTPUT_SEQUENCE,
    train_dataloader=torch.utils.data.DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        drop_last=True
    ),
    test_dataloader=torch.utils.data.DataLoader(
        test_dataset,
        batch_size=16,
        shuffle=True,
        drop_last=True
    ),
    device=DEVICE,
    epochs=10
)

Train Epoch 0, Loss - 0.0226, R^2 - -29.8695, MAE - 0.0822, RMSE - 0.1218: 100%|██████████| 58039/58039 [06:22<00:00, 151.61it/s]
  return F.mse_loss(input, target, reduction=self.reduction)
  0%|          | 0/563 [00:00<?, ?it/s]


RuntimeError: The size of tensor a (1440) must match the size of tensor b (90) at non-singleton dimension 1

In [25]:
16 * 48

768