In [3]:
from pathlib import Path

import rootutils

# Locate the repo root via the `.project-root` marker and put it on sys.path
# (pythonpath=True) so the `src.*` imports further down resolve. Must run first.
rootutils.setup_root(Path.cwd(), indicator=".project-root", pythonpath=True)

import json
import warnings

# Silence UserWarnings before the heavy imports so import-time warnings are hidden too.
# NOTE(review): this is a blanket filter and also hides Lightning's UserWarnings
# (e.g. dataloader hints) — narrow it with `module=`/`message=` if those matter.
warnings.filterwarnings('ignore', category=UserWarning)

import numpy as np
import torch
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

from src.models.lightning_module import PreTrainLightningSpatial
from src.models.components.models import MaskedAutoencoderViT3D
from src.data.dataset import create_dataloaders

# Set float32 matmul precision for better performance on CUDA devices with Tensor Cores.
torch.set_float32_matmul_precision('medium')

# Pretraining is stochastic (masks are sampled dynamically), so seed Python/NumPy/torch
# and let Lightning seed dataloader workers, making the run reproducible.
SEED = 42
seed_everything(SEED, workers=True)

In [4]:
# Observed spatial data array. NOTE(review): hard-coded absolute, user-specific path,
# and the array is not used in the cells shown here — confirm it is needed later.
observational_data = np.load(Path("/home/jp4474/viaABC/data_temp/SPATIAL/data.npy"))

In [5]:
# Model configuration. `params` is forwarded verbatim as keyword arguments to
# MaskedAutoencoderViT3D, and `mask_ratio` is also handed to the LightningModule.
# Key order is kept stable because this dict is serialized with json.dump.
config = {
    "model": {
        "name": "SpatialSIR3D",
        "params": dict(
            decoder_depth=8,
            decoder_embed_dim=104,
            decoder_num_heads=8,
            depth=10,
            embed_dim=144,
            img_size=80,
            lambda_=0.0,
            num_frames=15,
            num_heads=8,
            patch_size=8,
            pred_t_dim=15,
            t_patch_size=3,
            z_type="vanilla",
            mask_ratio=0.25,
        ),
    }
}

In [6]:
# save config as yaml
# Persist the config to config.yaml. The content is written as JSON (indent=4) despite
# the .yaml name; JSON is a subset of YAML 1.2, so YAML loaders generally accept it —
# NOTE(review): confirm against whatever reads this file.
with Path('/home/jp4474/viaABC/tutorial_sirs/config.yaml').open('w') as f:
    json.dump(config, f, indent=4)

In [7]:
# Build the 3D masked autoencoder from the hyper-parameters in the config cell.
model = MaskedAutoencoderViT3D(**config["model"]["params"])

img_size (80, 80) patch_size (8, 8) frames 15 t_patch_size 3
model initialized


In [8]:
# Output directory handed to ModelCheckpoint below.
# NOTE(review): hard-coded absolute path; it is the same directory config.yaml is written to.
dirpath = "/home/jp4474/viaABC/tutorial_sirs"

In [9]:
# Wrap the MAE in the pre-training LightningModule, reusing the configured mask ratio.
lightning_module = PreTrainLightningSpatial(
    model=model,
    lr=1e-3,
    mask_ratio=config["model"]["params"]["mask_ratio"],
)

# Record the learning rate once per epoch.
lr_monitor = LearningRateMonitor(logging_interval="epoch")

# Checkpointing and early stopping both track train_loss (lower is better). During
# pretraining the masks are regenerated dynamically, so the model cannot essentially
# overfit and train_loss is used as the selection signal.
checkpoint_callback = ModelCheckpoint(
    dirpath=dirpath,
    filename="SpatialSIR3D-{epoch:02d}-{train_loss:.4f}",
    save_top_k=1,  # keep only the single best checkpoint
    monitor="train_loss",
    mode="min",
)
# Stop after 10 consecutive checks with no train_loss improvement.
early_stop_callback = EarlyStopping(monitor="train_loss", patience=10, mode="min")

In [10]:
# Single-device trainer: full float32 precision, up to 100 epochs, no progress bar.
# Callback order is significant to Lightning, so keep the list as-is.
trainer = Trainer(
    accelerator="auto",
    devices=1,
    precision="32-true",
    max_epochs=100,
    callbacks=[checkpoint_callback, lr_monitor, early_stop_callback],
    log_every_n_steps=10,
    enable_progress_bar=False,
    fast_dev_run=False,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [11]:
# Build train/validation dataloaders for the SIRS system. The data directory is the
# same tutorial directory as `dirpath`, so derive it from there instead of repeating
# the absolute path. The trailing "/" is preserved exactly as in the original call,
# since create_dataloaders' path handling is not visible here.
train_dataloader, val_dataloader = create_dataloaders(
    data_dir=f"{dirpath}/",
    batch_size=64,
    system_name="sirs",
)

In [12]:
# Sanity-check one training batch against the model configuration before fitting.
sample_batch = next(iter(train_dataloader))
model_params = config['model']['params']
# Assumes the batch layout ends in (frames, height, width) — TODO confirm against the dataset.
expected_trailing = (model_params['num_frames'], model_params['img_size'], model_params['img_size'])
assert tuple(sample_batch.shape[-3:]) == expected_trailing, (
    f"batch shape {tuple(sample_batch.shape)} does not end with {expected_trailing}"
)
# NOTE(review): the recorded output shows a leading dim of 1 even though batch_size=64
# was requested — confirm the dataset size / batching inside create_dataloaders.
sample_batch.shape

torch.Size([1, 3, 15, 80, 80])

In [13]:
# Run pretraining; validation batches come from val_dataloader.
trainer.fit(
    model=lightning_module,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                   | Params | Mode 
---------------------------------------------------------
0 | model | MaskedAutoencoderViT3D | 3.8 M  | train
---------------------------------------------------------
3.8 M     Trainable params
0         Non-trainable params
3.8 M     Total params
15.358    Total estimated model params size (MB)
315       Modules in train mode
0         Modules in eval mode
`Trainer.fit` stopped: `max_epochs=100` reached.
