In [1]:
import os

# The local modules imported below (lightning_module, models, dataset) are
# expected to live one directory above this notebook, so we chdir there.
# A bare os.chdir('..') is not idempotent: re-running this cell would keep
# climbing the tree. Remember the starting directory on first execution and
# always move to *its* parent instead.
if '_NOTEBOOK_DIR' not in globals():
    _NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_DIR, os.pardir))

import json
from lightning_module import PreTrainLightning
from models import TSMVAE

from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from dataset import create_dataloaders

In [2]:
SEED = 42  # Global seed for Python `random`, NumPy and torch (via Lightning)

seq_len = 8  # Number of time stamps per series
in_chans = 2  # Dimensions (channels) of each time series

embed_dim = 64  # Embedding dimension
latent_dim = 32  # Latent dimension
num_heads = 4  # Number of attention heads
num_layers = 2  # NOTE(review): not passed to TSMVAE in any visible cell (`depth` is)
dropout = 0.1  # Dropout rate
depth = 2  # Depth of the model

decoder_embed_dim = latent_dim  # Decoder embedding dimension
decoder_num_heads = 4  # Number of attention heads in the decoder
decoder_depth = 2  # Depth of the decoder
z_type = 'vanilla'  # Type of latent variable
beta = 0.0001  # Passed to TSMVAE as `lambda_`; presumably the KL-term weight — confirm
noise_factor = 0.0  # Passed to TSMVAE as `noise_factor`

mask_ratio = 0.15  # Passed to TSMVAE as `mask_ratio`
LR = 1e-4  # Learning rate given to PreTrainLightning

# Seed every RNG *before* the model is constructed so that weight
# initialisation, dropout and other torch/NumPy/random draws are reproducible
# under Restart & Run All. workers=True also asks Lightning to seed dataloader
# worker processes.
seed_everything(SEED, workers=True)

In [3]:
# Build the TSMVAE from the hyper-parameters in the config cell.
# `decoder_embed_dim` is passed by name rather than hard-wiring `latent_dim`,
# so changing `decoder_embed_dim` in the config actually takes effect.
model = TSMVAE(
    seq_len=seq_len,
    in_chans=in_chans,
    embed_dim=embed_dim,
    num_heads=num_heads,
    depth=depth,
    decoder_embed_dim=decoder_embed_dim,
    decoder_num_heads=decoder_num_heads,
    decoder_depth=decoder_depth,
    z_type=z_type,
    lambda_=beta,
    mask_ratio=mask_ratio,
    dropout=dropout,
    noise_factor=noise_factor,
)

In [4]:
# Directory where ModelCheckpoint writes checkpoints. Set the TSMVAE_CKPT_DIR
# environment variable to override this absolute, machine-specific default
# instead of editing the notebook.
dirpath = os.environ.get('TSMVAE_CKPT_DIR', '/home/jp4474/latent-abc-smc/notebooks_bcell')

In [5]:
# LightningModule wrapping the TSMVAE with the configured learning rate.
lightning_module = PreTrainLightning(
    model=model,
    lr=LR,
    prog_bar=True,
)

# Retain the 10 checkpoints with the lowest validation loss.
# NOTE(review): assumes PreTrainLightning logs a metric named 'val_loss'.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=10,
    dirpath=dirpath,
    filename='TSMVAE-{epoch:02d}-{val_loss:.4f}',
)

# Record the learning rate once per epoch.
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Halt training after 10 consecutive validation checks with no decrease in val_loss.
early_stop_callback = EarlyStopping(
    mode="min",
    monitor="val_loss",
    patience=10,
)

In [6]:
# Single-device run ('auto' picks the GPU when one is available) in full fp32.
# The captured output suggests torch.set_float32_matmul_precision for speed;
# that would trade away exact fp32 matmuls, so it is deliberately not set here.
trainer = Trainer(
    accelerator='auto',
    devices=1,
    precision="32-true",
    max_epochs=100,  # upper bound; EarlyStopping may finish sooner
    callbacks=[checkpoint_callback, lr_monitor, early_stop_callback],
    log_every_n_steps=10,
    # NOTE(review): the progress bar is off even though PreTrainLightning was
    # built with prog_bar=True — confirm which one is intended.
    enable_progress_bar=False,
    fast_dev_run=False,  # flip to True for a quick smoke test
)

GPU available: True (cuda), used: True


TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [7]:
# Training / validation loaders for the data under data/Lotka (presumably
# Lotka-Volterra simulations — confirm). Set the TSMVAE_DATA_DIR environment
# variable to override the absolute, machine-specific default path.
train_dataloader, val_dataloader = create_dataloaders(
    data_dir=os.environ.get('TSMVAE_DATA_DIR', '/home/jp4474/latent-abc-smc/data/Lotka'),
    batch_size=32,
)

In [8]:
# Train, validating each epoch; checkpoints and early stopping are driven by val_loss.
trainer.fit(
    model=lightning_module,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

You are using a CUDA device ('NVIDIA RTX 5000 Ada Generation') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2025-08-05 21:41:18.968359: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-05 21:41:18.977029: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754430078.985812 4182059 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for 