In [1]:
import os
from pathlib import Path

import rootutils

# Find the project root (the directory containing `.project-root`, searched
# from the current working directory) and put it on the Python path so the
# `src.*` imports below resolve. This call must stay ahead of those imports.
rootutils.setup_root(Path.cwd(), indicator=".project-root", pythonpath=True)

from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint

from src.data.dataset import create_dataloaders
from src.models.components.models import TSMVAE
from src.models.lightning_module import PreTrainLightning

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Reproducibility: seed the Python, NumPy and torch RNGs (and dataloader
# workers) before the model and dataloaders are created in later cells, so
# weight init, dropout, masking and shuffling repeat on Restart & Run All.
SEED = 42
seed_everything(SEED, workers=True)

# --- Input shape --------------------------------------------------------------
seq_len = 8  # Number of time steps per sequence
in_chans = 2  # Number of channels (variables) per time step

# --- Encoder ------------------------------------------------------------------
embed_dim = 64  # Encoder embedding dimension
latent_dim = 32  # Latent dimension; also the default decoder width below
num_heads = 4  # Number of attention heads (encoder)
num_layers = 2  # NOTE(review): not passed to TSMVAE in this notebook (it takes `depth`)
dropout = 0.1  # Dropout rate
depth = 2  # Encoder depth passed to TSMVAE

# --- Decoder ------------------------------------------------------------------
decoder_embed_dim = latent_dim  # Decoder embedding dimension
decoder_num_heads = 4  # Number of attention heads in the decoder
decoder_depth = 2  # Depth of the decoder

# --- Latent / objective -------------------------------------------------------
z_type = 'vae'  # Latent type passed to TSMVAE
beta = 0.0001  # Passed to TSMVAE as `lambda_`; presumably the KL-term weight — confirm
noise_factor = 0.0  # Passed to TSMVAE; presumably input-noise scale (0.0 = off) — confirm

# --- Pre-training -------------------------------------------------------------
mask_ratio = 0.15  # Passed to TSMVAE; presumably the fraction of inputs masked — confirm
LR = 1e-4  # Learning rate passed to PreTrainLightning

In [3]:
# Build the TSMVAE model from the hyperparameters in the configuration cell.
model = TSMVAE(
    seq_len=seq_len,
    in_chans=in_chans,
    embed_dim=embed_dim,
    num_heads=num_heads,
    depth=depth,
    # Pass the dedicated config value (not `latent_dim`) so that editing
    # `decoder_embed_dim` in the config cell actually takes effect.
    decoder_embed_dim=decoder_embed_dim,
    decoder_num_heads=decoder_num_heads,
    decoder_depth=decoder_depth,
    z_type=z_type,
    lambda_=beta,  # the config cell calls TSMVAE's `lambda_` "beta"
    mask_ratio=mask_ratio,
    dropout=dropout,
    noise_factor=noise_factor,
)

In [4]:
# Output directory for model checkpoints (used by ModelCheckpoint below).
# Set the VIAABC_TUTORIAL_DIR environment variable to run on another machine
# instead of editing this machine-specific default path.
dirpath = os.environ.get("VIAABC_TUTORIAL_DIR", "/home/jp4474/viaABC/tutorial")

In [5]:
# Wrap the TSMVAE in the pre-training LightningModule.
lightning_module = PreTrainLightning(model=model, lr=LR, prog_bar=True)

# Keep only the single best checkpoint, ranked by lowest `train_loss`.
# NOTE(review): a validation loader is passed to `trainer.fit`; if
# PreTrainLightning logs a validation loss, monitoring it here and in
# EarlyStopping would select on held-out data — confirm the logged key first.
checkpoint_callback = ModelCheckpoint(
    dirpath=dirpath,
    filename='TSMVAE-{epoch:02d}-{train_loss:.4f}',
    monitor='train_loss',
    mode='min',
    save_top_k=1,
)

# Record the learning rate once per epoch.
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Stop once `train_loss` has not improved for 10 consecutive checks.
early_stop_callback = EarlyStopping(
    monitor="train_loss",
    mode="min",
    patience=10,
)

In [6]:
# Single-device, full fp32 training for at most 100 epochs; the early-stopping
# callback may end the run sooner. No `logger` is given, so Lightning falls
# back to its default logger.
trainer = Trainer(
    accelerator='auto',
    devices=1,
    precision="32-true",
    max_epochs=100,
    fast_dev_run=False,
    callbacks=[checkpoint_callback, lr_monitor, early_stop_callback],
    log_every_n_steps=10,
    enable_progress_bar=False,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/jp4474/viaABC/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [7]:
# Build the train/validation loaders. Reuse `dirpath` from the paths cell
# rather than repeating the literal, so the tutorial directory is set in one place.
train_dataloader, val_dataloader = create_dataloaders(
    data_dir=dirpath,
    batch_size=32,
)

In [8]:
# Run training; `val_dataloader` is used for the validation loop.
trainer.fit(
    lightning_module,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

You are using a CUDA device ('NVIDIA RTX 5000 Ada Generation') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/jp4474/viaABC/.venv/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:658: Checkpoint directory /home/jp4474/viaABC/tutorial exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | TSMVAE | 131 K  | train
-----------------------------------------
131 K     Trainable params
0         Non-trainable params
131 K     Total params
0.524     Total estimated model params size (MB)
101       Modules in train mode
0         Modules in eval mode
/home/jp4474/viaABC/.venv/lib/python3.10/site-packages