In [1]:
import argparse
import json
import pytorch_lightning as pl
from argparse import Namespace
from model import DTSModel
from datamodules.csvdatamodule import CsvDataModule
from datamodules.hivedatamodule import HiveDataModule
from datamodules.sqldatamodule import SqlDataModule
from datamodules.s3datamodule import S3DataModule
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from dataclasses import dataclass



In [3]:
@dataclass
class Config:
    """Container for the three configuration sections used by this notebook.

    Attributes:
        data: Data-source selection and its parameters.
        model: Model name, hyperparameters and loss settings.
        train: Settings passed to the Lightning ``Trainer``.
    """

    # What to load and how to batch it.
    data: dict
    # What to build and how to score it.
    model: dict
    # How to run the training loop.
    train: dict
        
# Single source of truth for the run: data source, model definition and trainer settings.
config = Config(
    data={
        "data": "csv",
        "data_params": {
            # presumably must match model_params["input_size"] below — TODO confirm
            "num_feature": 5,
            "seq_len": 30,
            "tgt_len": 1,
            "batch_size": 32,
            "train_path": "data/AAPL_train.csv",
            "val_path": "data/AAPL_val.csv",
            "test_path": "data/AAPL_test.csv",
        },
    },
    model={
        "model_name": "custom_rnn",
        "model_params": {
            "input_size": 5,
            "hidden_size": 32,
            "output_size": 1,
            "num_layers": 2,
            "lr": 2e-5,
            # duplicated from data_params["batch_size"]; keep the two in sync
            "batch_size": 32,
        },
        "loss_fn_type": "mse",
        "loss_params": {},
    },
    train={
        "accelerator": "auto",
        "devices": 4,
        # Plain "ddp" launches subprocesses and is rejected by Lightning inside an
        # interactive (Jupyter) kernel; "ddp_notebook" is the notebook-compatible
        # variant. Switch back to "ddp" only when running this as a script.
        "strategy": "ddp_notebook",
        "max_epochs": 1,
    },
)

In [4]:
# Seed the global RNGs once, up front, so repeated runs are comparable.
pl.seed_everything(seed=42, workers=True)

# TensorBoard logger rooted at logs/, with the run grouped under the model name.
logger = TensorBoardLogger(save_dir='logs/', name=config.model['model_name'])

Global seed set to 42


In [5]:
# Create LightningDataModule
# Built from the `data_params` section of the config, i.e. the dict that carries the
# train/val/test CSV paths together with seq_len, tgt_len, num_feature and batch_size.
data_module = CsvDataModule(config.data['data_params'])

In [8]:
# Preview the first rows of the training data held by the data module.
# NOTE(review): in the recorded output the column labels are themselves a data row
# (2009-04-01, 3.7174..., AAPL) — presumably the CSV has no header line and is being
# read as if it had one; confirm how CsvDataModule loads the file.
data_module.train_data.head()

Unnamed: 0,2009-04-01,3.7174999713897705,3.892857074737549,3.7103569507598877,3.3038582801818848,589372000,AAPL
0,2009-04-02,3.933571,4.098214,3.920714,3.426054,812366800,AAPL
1,2009-04-03,4.078214,4.1475,4.054286,3.525758,636241200,AAPL
2,2009-04-06,4.105,4.241071,4.045714,3.600535,658064400,AAPL
3,2009-04-07,4.161786,4.166786,4.078214,3.495665,536580800,AAPL
4,2009-04-08,4.1225,4.171071,4.092143,3.53579,455630000,AAPL


In [7]:
# Pull one batch from the training dataloader as a quick smoke test of the data pipeline.
# NOTE(review): the recorded run of this cell ended in
#   TypeError: cannot do slice indexing on Index with these indexers [1] of type int
# — presumably raised by the dataset's row-slicing code (not visible here); verify that
# it slices positionally before relying on this cell or on training below.
batch = next(iter(data_module.train_dataloader()))
print(batch)

TypeError: cannot do slice indexing on Index with these indexers [1] of type int

In [None]:
# Create LightningModule
# Built from the whole `model` section of the config: model name, model
# hyperparameters, loss type and loss parameters.
model = DTSModel(config.model)



In [None]:
# Checkpointing: track val_loss (lower is better), keep the three best
# checkpoints plus the most recent one, all under ./checkpoints/.
checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoints/',
    filename='model-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    save_last=True,
    every_n_epochs=1,  # checkpoint interval is one epoch (not a step count)
)

# Trainer: hardware and schedule settings are read from config.train; the
# callbacks add early stopping on val_loss, per-step learning-rate logging
# and the checkpointing configured above.
trainer = pl.Trainer(
    max_epochs=config.train['max_epochs'],
    accelerator=config.train['accelerator'],
    devices=config.train['devices'],
    strategy=config.train['strategy'],
    logger=logger,
    callbacks=[
        EarlyStopping(monitor='val_loss'),
        LearningRateMonitor(logging_interval='step'),
        checkpoint_callback,
    ],
)

# Fit the model; the data module supplies the dataloaders.
trainer.fit(model, datamodule=data_module)