In [1]:
import argparse
import json
import pytorch_lightning as pl
from argparse import Namespace
from model import DTSModel
from datamodules.csvdatamodule import CsvDataModule
from datamodules.hivedatamodule import HiveDataModule
from datamodules.sqldatamodule import SqlDataModule
from datamodules.s3datamodule import S3DataModule
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from dataclasses import dataclass



In [17]:
@dataclass
class Config:
    data: dict
    model: dict
    train:dict
        
config = Config(
    data= {
        "data": "csv",
        "data_params": {
            "num_feature": 5,
            "seq_len": 30,
            "tgt_len": 1,
            "batch_size": 32,
            "train_path": "data/AAPL_train.csv",
            "val_path": "data/AAPL_val.csv",
            "test_path": "data/AAPL_test.csv"
        }
    },
    model= {
        "model_name": "custom_rnn",
        "model_params": {
            "input_size": 5,
            "hidden_size": 32,
            "output_size": 1,
            "num_layers": 2,
            "lr": 2e-5,
            "batch_size": 32
        },
        "loss_fn_type": "mse",
        "loss_params": {}
    },
    train=  {
        "accelerator": "auto",
        "devices": 4,
        "strategy": 'ddp',
        "max_epochs": 1
    }
)

In [3]:
pl.seed_everything(42, workers=True)
logger = TensorBoardLogger('logs/', name=config.model['model_name'])

Global seed set to 42


In [4]:
# Create LightningDataModule
data_module = CsvDataModule(config.data['data_params'])

In [None]:
data_module.train_data

In [None]:
data_loader = data_module.train_dataloader()
batch = next(iter(data_loader))


In [None]:
print(batch[0].shape)

In [5]:
# Create LightningModule
model = DTSModel(config.model)

In [6]:
model

DTSModel(
  (model): CustomRNN(
    (rnn): RNN(5, 32, num_layers=2, batch_first=True)
    (fc): Linear(in_features=32, out_features=1, bias=True)
  )
  (loss_fn): MSE()
)

In [19]:
# Callback to save the model checkpoint
checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath='./checkpoints/',
        filename='model-{epoch:02d}-{val_loss:.2f}',
        save_top_k=3,
        mode='min',
        save_last=True,
        every_n_epochs=1  # Save checkpoint every 100 steps
    )

# Create Trainer
trainer = pl.Trainer(
        accelerator=config.train['accelerator'],
        devices=config.train['devices'],
        strategy=config.train['strategy'],
        max_epochs=config.train['max_epochs'],
        callbacks=[
            EarlyStopping(monitor='val_loss'),
            LearningRateMonitor(logging_interval='step'),
            checkpoint_callback
        ],
        logger=logger,
    )

# Train the model
trainer.fit(model, datamodule=data_module)

MisconfigurationException: `Trainer(strategy='ddp')` is not compatible with an interactive environment. Run your code as a script, or choose one of the compatible strategies: `Fabric(strategy=None|'dp'|'ddp_notebook')`. In case you are spawning processes yourself, make sure to include the Trainer creation inside the worker function.