In [1]:
import os
import hydra
import lightning

from openretina.data_io.cyclers import LongCycler
from openretina.data_io.sridhar_2025.dataloader_utils import download_wn_dataset, download_nm_dataset




In [2]:
# Top-level experiment config (white noise dataset).
# Switch to "vystrcilova_2024_nm_ln.yaml" to work with the natural movie dataset instead.
config_name = "vystrcilova_2024_wn_ln.yaml"

with hydra.initialize(config_path="../configs", version_base="1.3"):
    cfg = hydra.compose(config_name=config_name)

## Download the `sridhar_2025` dataset (first run only)
If this is your first time working with the `sridhar_2025` dataset, run the cell below to download it.
- `download_nm_dataset` downloads the natural movie part of the dataset.
- `download_wn_dataset` downloads the white noise part of the dataset.

In [None]:
# Pair each top-level config with the downloader for the dataset part it trains on:
# the white noise config needs the white noise data, the natural movie config the natural movie data.
downloader_by_config = {
    "vystrcilova_2024_wn_ln.yaml": download_wn_dataset,  # white noise
    "vystrcilova_2024_nm_ln.yaml": download_nm_dataset,  # natural movie
}
if config_name not in downloader_by_config:
    raise ValueError(
        f"No dataset downloader registered for config {config_name!r}; "
        f"expected one of {sorted(downloader_by_config)}."
    )

# Hugging Face's conventional variable is HF_TOKEN. Environment variable names are case-sensitive
# on Linux/macOS, so the 'Hf_TOKEN' spelling is still accepted as a fallback.
hf_token = os.getenv("HF_TOKEN") or os.getenv("Hf_TOKEN")

# NOTE(review): the same cache_dir is passed for both path arguments — confirm that is intended.
downloader_by_config[config_name](cfg.paths.cache_dir, cfg.paths.cache_dir, hf_token)


In [4]:
# NOTE(review): this dataloader config is hard-coded to the white noise setup and does not follow
# `config_name`; when switching to the natural movie config, select the matching dataloader config here.
dataloader_config_name = "sridhar_wn_2025_ln.yaml"
with hydra.initialize(config_path="../configs/dataloader/", version_base="1.3"):
    dataloader_cfg = hydra.compose(config_name=dataloader_config_name)

# The instantiated object is indexed by split ("train" / "validation") and then by retina index below.
dataloader = hydra.utils.instantiate(dataloader_cfg)

Random seed 1000 has been set.
train idx: [2 6 5 1 4 9 0 8]
val idx: [7 3]
chunk size: 254, frames in trial: 25500
chunk size: 254, frames in trial: 25500


In [6]:
# Compose the model config (single_cell_lnp.yaml); its input shape is filled in from the data below.
model_config_name = "single_cell_lnp.yaml"
with hydra.initialize(config_path="../configs/model/", version_base="1.3"):
    model_cfg = hydra.compose(config_name=model_config_name)


In [7]:
# Training loader for the retina selected in the dataloader config.
retina_index = dataloader_cfg["retina_index"]
retina_train_loader = dataloader["train"][retina_index]

# Peek at one batch to infer the input shape. Element 0 of the batch is presumably the stimulus
# (TODO confirm); [1:] drops the leading batch axis.
first_batch = next(iter(retina_train_loader))
input_shape = first_batch[0].shape[1:]

# Axis 1 of the model input is set to `num_of_frames` (presumably the time axis — TODO confirm);
# the remaining axes come straight from the data.
model_cfg["in_shape"] = (input_shape[0], dataloader_cfg["num_of_frames"], *input_shape[2:])

location = retina_train_loader.dataset.locations
model = hydra.utils.instantiate(model_cfg)
# Attach the first entry of the dataset's `locations` to the model.
model.location = location[0]

In [8]:
# Ensure the configured log directory exists.
os.makedirs(cfg.paths.log_dir, exist_ok=True)

# Stop when validation correlation has not improved by at least 0.001 for 10 validation checks.
early_stopping = lightning.pytorch.callbacks.EarlyStopping(
    monitor="val_correlation",
    mode="max",
    patience=10,
    min_delta=0.001,
    verbose=False,
)

# Record the learning rate once per epoch.
lr_monitor = lightning.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch")

# Keep the checkpoint with the highest validation correlation (full state, not weights only).
model_checkpoint = lightning.pytorch.callbacks.ModelCheckpoint(
    monitor="val_correlation",
    mode="max",
    save_weights_only=False,
)

In [9]:
trainer = lightning.Trainer(
    # Only a safety cap: training is expected to end through the early-stopping callback.
    max_epochs=10000,
    # Use Lightning's default logger explicitly; LearningRateMonitor requires a logger.
    logger=True,
    # Write logs and checkpoints under the configured log dir instead of ./lightning_logs
    # in the notebook's working directory.
    default_root_dir=cfg.paths.log_dir,
    callbacks=[early_stopping, lr_monitor, model_checkpoint],
    # "auto" selects the GPU when one is available and falls back to CPU otherwise,
    # instead of failing on machines without a GPU.
    accelerator="auto",
    log_every_n_steps=10,
)

# Wrap each split's per-retina loaders with LongCycler so they can be passed to `trainer.fit`.
train_loader = LongCycler(dataloader["train"])
val_loader = LongCycler(dataloader["validation"])

/user/vystrcilova/u14647/.conda/envs/venv_or/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /user/vystrcilova/u14647/.conda/envs/venv_or/lib/pyt ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Train; the monitored `val_correlation` is presumably computed on `val_loader` during validation.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)


You are using a CUDA device ('NVIDIA A100-SXM4-80GB MIG 1g.20gb') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/user/vystrcilova/u14647/.conda/envs/venv_or/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:658: Checkpoint directory /scratch-grete/projects/nim00010/open-retina/notebooks/lightning_logs/version_11777876/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type              | Params | Mode 
------------------------------------------------------------------
0 | loss                | PoissonLoss3d     | 0      | train
1 | correlation_loss    | CorrelationLoss3d | 0      | train
2 | _smooth_reg_fn_spat | LaplaceL2norm     | 0      | train


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]