In [1]:
import os

import hydra
import lightning
import wandb

from openretina.data_io.cyclers import LongCycler
from openretina.utils.model_utils import get_core_output_based_on_dimensions
from openretina.models.core_readout import BaseCoreReadout




## Initialize config

In [2]:
# Experiment config to load from ../configs.
# Switch to "vystrcilova_2024_nm_cnn.yaml" to train on the natural movie dataset instead.
config_name = "vystrcilova_2024_wn_cnn.yaml"

with hydra.initialize(version_base="1.3", config_path="../configs"):
    cfg = hydra.compose(config_name=config_name)

## Downloading datasets
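Only necessary the first time; afterwards the data is stored locally in `cfg.paths.cache_dir`. The download needs network access to the Hugging Face Hub (an access token is read from the environment), so on machines without internet access skip this cell once the data has been cached.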


In [3]:
from openretina.data_io.sridhar_2025.dataloader_utils import download_wn_dataset, download_nm_dataset

# Hugging Face access token used for the download. Environment variable names are
# case-sensitive on POSIX, so prefer the standard `HF_TOKEN` and fall back to the
# legacy `Hf_TOKEN` spelling for setups that already export it.
hf_token = os.getenv("HF_TOKEN") or os.getenv("Hf_TOKEN")

if config_name == "vystrcilova_2024_wn_cnn.yaml":
    download_wn_dataset(cfg.paths.cache_dir, cfg.paths.cache_dir, hf_token)
elif config_name == "vystrcilova_2024_nm_cnn.yaml":
    download_nm_dataset(cfg.paths.cache_dir, cfg.paths.cache_dir, hf_token)
else:
    # Nothing to download for other configs; say so instead of silently doing nothing.
    print(f"No download helper for {config_name!r}; expecting data to already be in {cfg.paths.cache_dir}")

Using the latest cached version of the module from /user/vystrcilova/u14647/.cache/huggingface/modules/datasets_modules/datasets/open-retina--wn_marmoset_data/3c691108d8582898c99dde1bb246b40885b2f16299445fcf588a1bc6315fefa5 (last modified on Mon Nov 17 22:53:09 2025) since it couldn't be found locally at open-retina/wn_marmoset_data, or remotely on the Hugging Face Hub.


ConnectionError: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/datasets/m-vys/wn_marmoset_data/revision/main (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x148cff676150>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 73b63cae-6535-4c52-8e8e-74b5a606065d)')

In [6]:
print(cfg)

# Derive the dataloader's frame and layer settings from the core's temporal kernel sizes.
temporal_kernel_sizes = cfg.model.core.temporal_kernel_sizes
cfg.dataloader.num_of_frames = temporal_kernel_sizes[0]
cfg.dataloader.num_of_layers = len(temporal_kernel_sizes)

dataloaders = hydra.utils.instantiate(cfg.dataloader)

# One entry per retina key in the training split, mapped to its neuron count.
retina_indices = list(dataloaders["train"].keys())
n_neurons_dict = {
    retina_key: train_loader_for_key.dataset.n_neurons
    for retina_key, train_loader_for_key in dataloaders["train"].items()
}


{'data_io': {'basepath_wn': '/projects/extern/nhr/nhr_ni/nim00010/dir.project/wn_marmoset_data/', 'cache_dir': '/projects/extern/nhr/nhr_ni/nim00010/dir.project/', 'response_path': 'responses', 'image_path': 'stimuli_padded_4/', 'stas_path': 'stas', 'sta_file': 'cell_data_01_WN_stas_cell'}, 'dataloader': {'_target_': 'openretina.data_io.sridhar_2025.dataloaders.white_noise_loader', '_convert_': 'object', 'basepath': '/projects/extern/nhr/nhr_ni/nim00010/dir.project/wn_marmoset_data/', 'files': {'01': 'responses/cell_responses_01_wn.pkl', '02': 'responses/cell_responses_02_wn.pkl', '04': 'responses/cell_responses_04_wn.pkl'}, 'big_crops': {'01': [35, 35, 55, 55], '02': [35, 35, 55, 55], '04': [35, 35, 55, 55]}, 'train_image_path': {'01': 'non_repeating_stimuli_1', '02': 'non_repeating_stimuli_1', '04': 'non_repeating_stimuli_2'}, 'test_image_path': {'01': 'repeating_stimuli_1', '02': 'repeating_stimuli_1', '04': 'repeating_stimuli_2'}, 'excluded_cells': {'01': [], '02': [], '03': [], '0

In [7]:
# Show which retina keys are present in the training dataloaders.
print(retina_indices)

['04']


In [8]:
# Pull one training batch from the first retina to read off the input dimensions.
# NOTE: despite its name, `input_shape` holds the whole batch; element 0 is the stimulus tensor.
input_shape = next(iter(dataloaders["train"][retina_indices[0]]))
stimulus_dims = input_shape[0].shape  # presumably (batch, channels, frames, height, width) — confirm

model_cfg = cfg.model
model_cfg["in_shape"] = list(stimulus_dims[1:])  # per-sample shape, batch axis dropped
model_cfg["n_neurons_dict"] = n_neurons_dict
# Axis 1 feeds the core as its input channel count, followed by the hidden layers.
model_cfg["core"]["channels"] = [stimulus_dims[1]] + model_cfg.hidden_channels
model_cfg["readout"]["n_neurons_dict"] = n_neurons_dict
# Must run last: the readout input shape is computed from the core settings filled in above.
model_cfg["readout"]["in_shape"] = get_core_output_based_on_dimensions(model_cfg)
print(model_cfg.in_shape)

factorized core output shape: (16, 100, 60, 70)
[1, 134, 80, 90]


In [9]:
# Instantiate the configured sub-modules (accessed later via the "core" and "readout" keys).
model = hydra.utils.instantiate(cfg.model)

# Sanity-check one training batch: element 0 is the stimulus, element 1 the responses.
batch = next(iter(dataloaders["train"][retina_indices[0]]))
for shape_label, element_position in (("input img shape", 0), ("response shape", 1)):
    print(shape_label, batch[element_position].shape)



input img shape torch.Size([16, 1, 134, 80, 90])
response shape torch.Size([16, 100, 110])


In [10]:
# Assemble the trainable module from the instantiated "core" and "readout" sub-modules.
# The previous cell binds `model` to the sub-module container and this cell rebinds it
# to the wrapped BaseCoreReadout. Keep a separate handle to the container so that
# re-running this cell wraps the sub-modules again instead of indexing the already
# wrapped model.
# Note: re-running after training reuses the same (trained) sub-modules; re-run the
# previous cell to start from freshly instantiated ones.
if not isinstance(model, BaseCoreReadout):
    model_components = model
model = BaseCoreReadout(
    core=model_components["core"],
    readout=model_components["readout"],
    learning_rate=cfg.model.learning_rate,
)
model = model.float()

In [11]:
# Keep white-noise and natural-movie runs in separate log directories; the
# white-noise path stays "cnns_wn/".
dataset_tag = "nm" if "_nm_" in config_name else "wn"
log_save_path = os.path.join(cfg.paths.log_dir, f"cnns_{dataset_tag}/")
os.makedirs(log_save_path, exist_ok=True)

# Weights & Biases logger. On nodes without network access, set WANDB_MODE=offline
# in the environment so runs are stored locally and can be synced later.
logger = lightning.pytorch.loggers.WandbLogger(
    name="",
    save_dir=log_save_path,
)

# Stop when validation correlation has not improved by at least 0.001 for 10 epochs.
early_stopping = lightning.pytorch.callbacks.EarlyStopping(
    monitor="val_correlation",
    patience=10,
    mode="max",
    verbose=False,
    min_delta=0.001,
)

lr_monitor = lightning.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch")

# Keep the checkpoint with the highest validation correlation (full state, not weights only).
model_checkpoint = lightning.pytorch.callbacks.ModelCheckpoint(
    monitor="val_correlation", mode="max", save_weights_only=False
)

trainer = lightning.Trainer(
    max_epochs=100,
    # Attach the WandbLogger created above. With `logger=None` it was constructed but
    # never used, so nothing reached W&B despite the `wandb.finish()` at the end.
    logger=logger,
    callbacks=[early_stopping, lr_monitor, model_checkpoint],
    # "auto" still selects the GPU when one is available, but also runs on CPU-only hosts.
    accelerator="auto",
    log_every_n_steps=10,
)

/user/vystrcilova/u14647/.conda/envs/venv_or/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /user/vystrcilova/u14647/.conda/envs/venv_or/lib/pyt ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [12]:
# Combine the per-retina dataloaders of each split into a single iterable for Lightning
# (see openretina.data_io.cyclers.LongCycler for how batches are interleaved).
train_loader, val_loader = (LongCycler(dataloaders[split]) for split in ("train", "validation"))

In [None]:
# Train; early stopping and checkpointing are driven by the callbacks configured above.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

You are using a CUDA device ('NVIDIA A100-SXM4-80GB MIG 1g.20gb') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                               | Params | Mode 
--------------------------------------------------------------------------------
0 | core             | SimpleCoreWrapper                  | 16.1 K | train
1 | readout          | MultiSampledGaussianReadoutWrapper | 2.5 K  | train
2 | loss             | PoissonLoss3d                      | 0      | train
3 | correlation_loss | CorrelationLoss3d                  | 0      | train
--------------------------------------------------------------------------------
18.6 K    Trainable params
0         Non-train

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
# Finish the active Weights & Biases run, if any, so remaining data is uploaded/flushed.
wandb.finish()