In [None]:
import os
from pathlib import Path

import rootutils

rootutils.setup_root(Path.cwd(), indicator=".project-root", pythonpath=True)

from src.models.lightning_module import PreTrainLightning
from src.models.components.models import TSMVAE

from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from src.data.dataset import create_dataloaders

In [None]:
# ---- Data ----
seq_len = 8   # Sequence length (number of time points per series)
in_chans = 2  # Dimensions (channels) of the time series

# ---- Encoder ----
embed_dim = 64   # Embedding dimension
latent_dim = 32  # Latent dimension (also used as the decoder embedding dimension below)
num_heads = 4    # Number of attention heads
num_layers = 2   # NOTE(review): not passed to TSMVAE in the cells shown (`depth` is used instead) — confirm it is needed
dropout = 0.1    # Dropout rate
depth = 2        # Depth of the model (passed to TSMVAE as `depth`)

# ---- Decoder ----
decoder_embed_dim = latent_dim  # Decoder embedding dimension
decoder_num_heads = 4           # Number of attention heads in the decoder
decoder_depth = 2               # Depth of the decoder

# ---- Latent space / regularisation ----
z_type = 'vae'
beta = 0.0001       # Passed to TSMVAE as `lambda_` (presumably the KL weight when z_type='vae') — confirm
noise_factor = 0.0

# ---- Training ----
mask_ratio = 0.15
LR = 1e-4
SEED = 42  # Random seed; applied here, before the model and dataloaders are created

# Seed every RNG Lightning manages (including dataloader workers) so that a
# Restart & Run All reproduces the same weights and training run.
# Trailing ';' suppresses the returned seed from being displayed.
seed_everything(SEED, workers=True);

In [None]:
# Build the TSMVAE model from the hyperparameters defined in the config cell.
model = TSMVAE(
    seq_len=seq_len,
    in_chans=in_chans,
    embed_dim=embed_dim,
    num_heads=num_heads,
    depth=depth,
    # Use the named config constant so changing decoder_embed_dim actually takes
    # effect (the config cell sets it to latent_dim, so the value is unchanged).
    decoder_embed_dim=decoder_embed_dim,
    decoder_num_heads=decoder_num_heads,
    decoder_depth=decoder_depth,
    z_type=z_type,
    lambda_=beta,
    mask_ratio=mask_ratio,
    dropout=dropout,
    noise_factor=noise_factor,
)

In [None]:
# Directory that ModelCheckpoint writes to. The default is a user-specific
# absolute path; set VIAABC_TUTORIAL_DIR to run this notebook elsewhere
# without editing the code.
dirpath = os.environ.get('VIAABC_TUTORIAL_DIR', '/home/jp4474/viaABC/tutorial')

In [None]:
# Wrap the model in the project's Lightning module with the configured learning rate.
lightning_module = PreTrainLightning(model=model, lr=LR, prog_bar=True)

# Keep only the single best checkpoint (lowest train_loss) under `dirpath`.
# NOTE(review): a val dataloader is passed to fit(), yet checkpointing and early
# stopping both track 'train_loss' — confirm whether a validation metric was intended.
checkpoint_callback = ModelCheckpoint(
    dirpath=dirpath,
    filename='TSMVAE-{epoch:02d}-{train_loss:.4f}',
    monitor='train_loss',
    mode='min',
    save_top_k=1,
)

# Log the learning rate once per epoch.
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Stop training once train_loss stops decreasing (patience of 10).
early_stop_callback = EarlyStopping(
    monitor="train_loss",
    mode="min",
    patience=10,
)

In [None]:
# Single-device fp32 trainer for up to 100 epochs, using the callbacks above.
# NOTE(review): the module is built with prog_bar=True, but the progress bar is
# disabled here — confirm which is intended.
trainer = Trainer(
    # Hardware / numerics
    accelerator='auto',
    devices=1,
    precision="32-true",
    # Schedule
    max_epochs=100,
    fast_dev_run=False,  # set True for a quick debugging pass (see Lightning Trainer docs)
    # Monitoring
    callbacks=[checkpoint_callback, lr_monitor, early_stop_callback],
    log_every_n_steps=10,
    enable_progress_bar=False,
)

In [None]:
# Load the Lotka-Volterra ('lotka') train/validation splits. This tutorial keeps
# its data in the same folder that checkpoints are written to, so reuse `dirpath`
# instead of repeating the hard-coded absolute path.
train_dataloader, val_dataloader = create_dataloaders(
    data_dir=dirpath,
    batch_size=32,
    system_name='lotka',
)

In [None]:
# Run training; the validation loader is evaluated alongside the training loop.
trainer.fit(
    model=lightning_module,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)