# Notebook to train CNN models
#### as described in [A systematic comparison of predictive models on the retina](https://www.biorxiv.org/content/10.1101/2024.03.06.583740v2.full)

In [1]:
import os

import hydra
import lightning
import wandb
from torch.utils import data

from openretina.data_io.base import compute_data_info
from openretina.data_io.cyclers import LongCycler, ShortCycler
from openretina.models.core_readout import UnifiedCoreReadout


## Loading configs

To train on natural movies use `config_name = "vystrcilova_2024_nm_cnn.yaml"`

To train on white noise use `config_name = "vystrcilova_2024_wn_cnn.yaml"`

#### Hyper-parameters relevant for the different CNN model configurations:

The parameters for the CNN model configurations can be set in `configs/model/core_gaussian_readout.yaml` (path relative to the repository root).

Relevant parameters with example settings for
**CNN 3** on natural movies:
* `hidden_channels: [16, 32, 64]`
* `core`:
    * `spatial_kernel_size: [21, 5, 5]`
    * `temporal_kernel_size: [35, 3, 3]`
    * `gamma_input: 48.0`
    * `gamma_temporal: 0.08`
* `readout`
    * `gamma: 0.023`
    * `init_mu_range: 0.3`
    * `init_sigma: 0.15`




In [2]:
# Name of the Hydra config to train with.
# Use vystrcilova_2024_nm_cnn.yaml for the natural movie dataset.
config_name = "vystrcilova_2024_wn_cnn.yaml"

# Compose the full configuration from the `../configs` directory.
with hydra.initialize(version_base="1.3", config_path="../configs"):
    cfg = hydra.compose(config_name=config_name)

## Downloading datasets
The dataset is downloaded automatically, and only once: the first time you attempt to instantiate the dataloader, it is stored at the location specified by `cfg.paths.cache_dir`.


In [3]:
# Build the stimulus and response dictionaries by calling the targets
# configured under `cfg.data_io`.
movies_dict = hydra.utils.call(cfg.data_io.stimuli)
neuron_data_dict = hydra.utils.call(cfg.data_io.responses)

# Optional consistency check, enabled from the config: every session's
# responses are checked against the stimulus stored under the same session key.
if cfg.check_stimuli_responses_match:
    for session in neuron_data_dict:
        neuron_data_dict[session].check_matching_stimulus(movies_dict[session])

In [10]:
# Instantiate the dataloaders described by `cfg.dataloader`, feeding in the
# responses and stimuli loaded above. The result is indexed by split name
# below ("train", "validation").
dataloaders = hydra.utils.instantiate(
    cfg.dataloader,
    neuron_data_dictionary=neuron_data_dict,
    movies_dictionary=movies_dict,
)

# Summarise the data for the model; any `data_info` entry found in the
# config is forwarded as partial information.
partial_data_info = cfg.data_io.get("data_info")
data_info = compute_data_info(neuron_data_dict, movies_dict, partial_data_info=partial_data_info)

# Training data: cycle over the train dataloaders with shuffling enabled.
# `batch_size=None` turns off the DataLoader's own automatic batching
# (presumably because the cycler already yields complete batches; verify).
train_cycler = LongCycler(dataloaders["train"], shuffle=True)
train_loader = data.DataLoader(train_cycler, batch_size=None, num_workers=0, pin_memory=True)

# Validation data: cycle over the validation dataloaders.
valid_loader = ShortCycler(dataloaders["validation"])

Random seed 1000 has been set.
train idx: [2 6 5 1 4 9 0 8]
val idx: [7 3]


## Model

In [11]:
# Copy `n_neurons_dict` from the computed data info into the model config so
# that it is passed on when the model is built from `cfg.model`
# (presumably the number of neurons per session; verify against compute_data_info).
cfg.model.n_neurons_dict = data_info["n_neurons_dict"]


In [12]:
# Build the core + readout model from the model config and cast its
# parameters to float32.
model = UnifiedCoreReadout(**cfg.model).float()

## Logging and callbacks

In [13]:
# Directory for this experiment's logs, created under the configured log dir if missing.
log_save_path = os.path.join(cfg.paths.log_dir, "cnns_wn/")
os.makedirs(log_save_path, exist_ok=True)

# Weights & Biases logger that stores its files in the directory above.
logger = lightning.pytorch.loggers.WandbLogger(save_dir=log_save_path, name="")

pl_callbacks = lightning.pytorch.callbacks

# Stop early when "val_correlation" (higher is better) has not improved by
# at least `min_delta` for `patience` consecutive validation checks.
early_stopping = pl_callbacks.EarlyStopping(
    monitor="val_correlation",
    mode="max",
    min_delta=0.001,
    patience=10,
    verbose=False,
)

# Record the learning rate once per epoch.
lr_monitor = pl_callbacks.LearningRateMonitor(logging_interval="epoch")

# Checkpoint the model with the best "val_correlation"; full training state
# is saved, not only the weights.
model_checkpoint = pl_callbacks.ModelCheckpoint(
    monitor="val_correlation",
    mode="max",
    save_weights_only=False,
)


## Trainer

In [14]:
# Assemble the Lightning trainer.
# - `logger=logger` attaches the Weights & Biases logger created in the
#   logging cell, so metrics and the learning rate tracked by `lr_monitor`
#   are recorded. With `logger=None` that logger was never used.
# - Training runs for at most 100 epochs; `early_stopping` may end it sooner.
# - Metrics are logged every 10 training steps.
trainer = lightning.Trainer(
    max_epochs=100,
    logger=logger,
    callbacks=[early_stopping, lr_monitor, model_checkpoint],
    accelerator="gpu",
    log_every_n_steps=10,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Run the training loop, validating on `valid_loader`.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type                        | Params | Mode 
------------------------------------------------------------------------
0 | core            | SimpleCoreWrapper           | 7.2 K  | train
1 | readout         | MultiSampledGaussianReadout | 8.8 K  | train
2 | loss            | PoissonLoss3d               | 0      | train
3 | validation_loss | CorrelationLoss3d           | 0      | train
------------------------------------------------------------------------
16.0 K    Trainable params
0         Non-trainable params
16.0 K    Total params
0.064     Total estimated model params size (MB)
21        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/mnt/vast-nhr/projects/nim00012/michaela/open-retina/.venv/lib64/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

In [None]:
# Mark the current Weights & Biases run as finished so remaining data is
# uploaded and a later run in this kernel starts fresh.
wandb.finish()