In [None]:
# Setup: imports, experiment configuration, data module and model construction.
import os
import lightning.pytorch as L
# Import the logger from the same `lightning.pytorch` namespace as the Trainer and
# callbacks used below. The standalone `pytorch_lightning` package is a separate
# install whose classes are not interchangeable with `lightning.pytorch` ones.
from lightning.pytorch.loggers import CometLogger
from lightning.pytorch.callbacks import ModelCheckpoint, RichProgressBar
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme

from utils.configs import ExperimentConfigs
from utils.callbacks import MetricLoggerCallback
from utils.dataloader import DataloaderModule
from data.particle_clouds.jets import JetDataModule
from multimodal_bridge_matching import MultiModalBridgeMatching

# Path to the experiment YAML config. Set the MMB_CONFIG environment variable to
# point at another config (or machine) instead of editing this absolute default.
CONFIG_PATH = os.environ.get(
    'MMB_CONFIG',
    '/home/df630/Multimodal-Bridges/experiments/configs/aoj_config.yaml',
)

config = ExperimentConfigs(CONFIG_PATH)
jets = JetDataModule(config=config, preprocess=True)
dataloader = DataloaderModule(config=config, datamodule=jets)
model = MultiModalBridgeMatching(config)

In [None]:
#...Loggers

# Comet logging is optional: it is enabled only when the config provides a
# non-None `experiment.comet_logger` section. Checking with `hasattr` alone would
# accept a section that is present but None, which then fails on `.to_dict()`.
# NOTE(review): re-running this cell creates a new Comet experiment each time.
comet_logger_config = getattr(config.experiment, 'comet_logger', None)

if comet_logger_config is not None:
    logger = CometLogger(**comet_logger_config.to_dict())
    # Store the Comet experiment key in the config *before* logging the config,
    # so the logged parameters include it (assuming `to_dict()` serialises
    # nested sections — confirm in ExperimentConfigs).
    comet_logger_config.experiment_key = logger.experiment.get_key()
    logger.experiment.log_parameters(parameters=config.to_dict())
else:
    logger = None

#...Callbacks (always after loading model and data)

# Theme and checkpoint settings are taken verbatim from the config sections.
progress_bar = RichProgressBar(theme=RichProgressBarTheme(**config.experiment.progress_bar.to_dict()))
checkpoints = ModelCheckpoint(**config.experiment.checkpoints.to_dict())
# sync_dist=True: metrics are intended to be synchronised across DDP processes.
metrics = MetricLoggerCallback(sync_dist=True)

In [None]:
#...Train

# Hardware / optimisation settings, named here instead of inlined in the Trainer call.
GPU_DEVICES = [0, 3]      # CUDA device indices on this node; DDP runs one process per device
GRADIENT_CLIP_VAL = 1.0   # gradient clipping threshold passed to the Trainer

# `ddp_notebook` is Lightning's fork-based DDP strategy for interactive sessions.
# NOTE(review): fork-based DDP requires that CUDA was not initialised in this
# kernel before `fit` is called — confirm the data/model setup stays on CPU.
trainer = L.Trainer(
    max_epochs=config.train.max_epochs,
    accelerator="gpu",
    strategy='ddp_notebook',
    num_nodes=1,
    devices=GPU_DEVICES,
    sync_batchnorm=True,
    gradient_clip_val=GRADIENT_CLIP_VAL,
    callbacks=[progress_bar, checkpoints, metrics],
    logger=logger,  # CometLogger, or None when Comet is not configured
)

trainer.fit(
    model,
    train_dataloaders=dataloader.train,
    val_dataloaders=dataloader.valid,
)