In [1]:
import argparse
import json
import pytorch_lightning as pl
from argparse import Namespace
from model import DTSModel
from datamodules.csvdatamodule import CsvDataModule
from datamodules.hivedatamodule import HiveDataModule
from datamodules.sqldatamodule import SqlDataModule
from datamodules.s3datamodule import S3DataModule
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from dataclasses import dataclass
import torch

In [2]:
@dataclass
class Config:
    """Experiment configuration split into three plain-dict sections.

    Each section is an ordinary ``dict``; the dataclass only groups them
    under attribute names and does not validate their contents.
    """

    # Settings for the data module (source type and its parameters).
    data: dict
    # Settings for the model (name, constructor parameters, loss).
    model: dict
    # Settings forwarded to the Lightning Trainer.
    train: dict
        
# Data section: the source type plus the parameter dict handed to the data module.
data_section = {
    "data": "csv",
    "data_params": {
        "num_feature": 5,
        "seq_len": 20,
        "tgt_len": 1,
        "batch_size": 64,
        "train_path": "data/AAPL_train.csv",
        "val_path": "data/AAPL_val.csv",
        "test_path": "data/AAPL_test.csv",
    },
}

# Model section: architecture name, its constructor parameters (the learning
# rate lives here too) and the loss selection.
model_section = {
    "model_name": "custom_rnn",
    "model_params": {
        "input_size": 5,
        "hidden_size": 32,
        "output_size": 1,
        "num_layers": 2,
        "lr": 2e-5,
    },
    "loss_fn_type": "mse",
    "loss_params": {},
}

# Trainer section: hardware selection and training length.
train_section = {
    "accelerator": "gpu",
    "devices": 1,
    "strategy": "ddp",
    "max_epochs": 100,
}

config = Config(data=data_section, model=model_section, train=train_section)

In [3]:
# Fix the global random seed up front so runs are repeatable.
pl.seed_everything(42, workers=True)
# TensorBoard logs go under 'logs/', in a sub-folder named after the model.
logger = TensorBoardLogger(save_dir='logs/', name=config.model['model_name'])

Global seed set to 42


In [4]:
# Build the CSV data module from the 'data_params' sub-dict of the config.
# NOTE(review): config.data['data'] is not consulted here; the CSV module is
# chosen unconditionally — confirm that is intended.
data_module = CsvDataModule(config.data["data_params"])


In [5]:
# Debugging aid: uncomment to pull a single batch from the training dataloader.
# next(iter(data_module.train_dataloader()))

In [6]:
# Instantiate the LightningModule on the device implied by the trainer config:
# 'gpu' selects CUDA, any other accelerator value falls back to CPU.
# NOTE(review): CUDA availability is not checked here — confirm the target
# machine actually has a GPU when accelerator is 'gpu'.
device = torch.device('cuda' if config.train['accelerator'] == 'gpu' else 'cpu')
model = DTSModel(config.model, device)
# Left as the last expression so the notebook renders the module structure.
model.to(device)

DTSModel(
  (model): CustomRNN(
    (rnn): LSTM(5, 32, num_layers=2, batch_first=True)
    (fc): Linear(in_features=32, out_features=1, bias=True)
  )
  (loss_fn): MSE()
)

In [7]:
# Checkpointing: monitor 'val_loss' (lower is better), keep the three best
# checkpoints plus the most recent one, with a one-epoch checkpoint interval.
checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoints/',
    filename='model-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    save_last=True,
    every_n_epochs=1,
)

# Trainer wiring. Two options remain switched off:
#   * strategy=config.train['strategy'] is not forwarded to the Trainer
#   * EarlyStopping(monitor='val_loss') is not registered as a callback
trainer = pl.Trainer(
    logger=logger,
    accelerator=config.train['accelerator'],
    devices=config.train['devices'],
    max_epochs=config.train['max_epochs'],
    callbacks=[
        LearningRateMonitor(logging_interval='step'),
        checkpoint_callback,
    ],
)

# Run training; the data module is handed to fit() as the data source.
trainer.fit(model, datamodule=data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type      | Params
--------------------------------------
0 | model   | CustomRNN | 13.5 K
1 | loss_fn | MSE       | 0     
--------------------------------------
13.5 K    Trainable params
0         Non-trainable params
13.5 K    Total params
0.054     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [8]:
# Evaluate on the test split with the best checkpoint from the preceding fit.
# ckpt_path='best' states that choice explicitly instead of relying on the
# fallback used when neither a model nor a checkpoint path is given.
# Despite the name, `preds` holds the list of metric dicts returned by
# trainer.test (e.g. [{'test_loss': ...}]), not model predictions.
preds = trainer.test(datamodule=data_module, ckpt_path='best')

  rank_zero_warn(
Restoring states from the checkpoint path at /hd1/dl/deep_time_series_framework/checkpoints/model-epoch=99-val_loss=12.15.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /hd1/dl/deep_time_series_framework/checkpoints/model-epoch=99-val_loss=12.15.ckpt


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           16.093013763427734
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [9]:
# Display the value returned by trainer.test(...).
preds

[{'test_loss': 16.093013763427734}]

In [10]:
# NOTE(review): numpy is only referenced by the commented-out concatenate line
# at the bottom of this cell; the import is kept in case later cells rely on it.
import numpy as np

# Collect (label, prediction) pairs for the first test batch only, with both
# values passed through the data module's target_denormalize.
test_loader = data_module.test_dataloader()
predictions = []
model.to(device)
model.eval()  # put dropout / normalisation layers (if any) into inference mode
with torch.no_grad():  # inference only: skip building the autograd graph
    for batch_idx, (inputs, labels) in enumerate(test_loader):
        outputs = model(inputs.to(device))
        outputs = data_module.target_denormalize(outputs.detach().cpu().numpy().squeeze())
        labels = data_module.target_denormalize(labels.detach().cpu().numpy().squeeze())

        predictions.append(list(zip(labels, outputs)))
        # Only the first batch is inspected; stop here instead of iterating
        # over (and discarding) every remaining batch in the loader.
        break
#predictions = np.concatenate(predictions, axis=0)

In [11]:
# Display the collected (label, prediction) pairs, both denormalized.
predictions

[[(173.80438232421875, 55.2401),
  (174.2510986328125, 55.237064),
  (170.93540954589844, 55.26767),
  (171.8090057373047, 55.274727),
  (168.56285095214844, 55.263584),
  (165.01882934570312, 55.26278),
  (163.31137084960938, 55.246277),
  (161.22666931152344, 55.22379),
  (160.44244384765625, 55.22152),
  (158.61582946777344, 55.167694),
  (158.52651977539062, 55.14164),
  (158.05992126464844, 55.13809),
  (169.08897399902344, 55.126873),
  (173.50656127929688, 55.137775),
  (173.3378143310547, 55.179012),
  (174.55882263183594, 55.196182),
  (171.64024353027344, 55.2235),
  (171.3520050048828, 55.255203),
  (170.62640380859375, 55.2667),
  (173.77732849121094, 55.26066),
  (175.21856689453125, 55.254673),
  (171.08363342285156, 55.264618),
  (167.6245880126953, 55.275803),
  (167.86314392089844, 55.268425),
  (171.74957275390625, 55.24743),
  (171.5110321044922, 55.230583),
  (167.86314392089844, 55.21786),
  (166.2926483154297, 55.221546),
  (163.3306121826172, 55.23271),
  (159.10