# Notebook to train LN models
#### as described in [A systematic comparison of LN mdels on the retina](https://www.biorxiv.org/content/10.1101/2024.03.06.583740v2.full)

In [None]:
import os

import hydra
import lightning
from torch.utils import data

from openretina.data_io.base import compute_data_info
from openretina.data_io.cyclers import LongCycler, ShortCycler

## Loading configs

To train on natural movies use `config_name = "vystrcilova_2024_wn_ln.yaml"`

To train on white noise use `config_name = "vystrcilova_2024_nm_ln.yaml"`

#### Hyper-parameters relevant for the different LN model configurations:

The parameters for the following LN model options can be set in `configs/model/single_cell_lnp.yaml`

* Spatial crop (**SC**): `spat_kernel_size`
  * `[15,15]` - spatial crop applied
  * `[40,40]` - spatial crop not applied
* Gaussian fit (**GF**): `fit_gaussian`
    * `True` - gaussian fit to receptive field
    * `False` - gaussian not fit to receptive field
* Parametrized nonlinearity (**PNL**): `nonlinearity`
    * `"softplus"` - non-learnable nonlinearity parameters
    * `"parametrized_softplus"` - learnable nonlinearity parameters


In [2]:
config_name = "vystrcilova_2024_wn_ln.yaml"  # use vystrcilova_2024_nm_ln.yaml for the natural movie dataset

with hydra.initialize(config_path="../configs", version_base="1.3"):
    cfg = hydra.compose(config_name=config_name)


## Dataloader

If this is your first time working with the `sridhar_2025` dataset, when calling the dataloader for the first time the dataset will be downloaded at the `cfg.paths.cache_dir` location.

The cell and retina for which the model should be trained is specified in:

`configs/dataloader/sridhar_wn_2025_ln.yaml` when training on **white noise**
`configs/dataloader/sridhar_nm_2025_ln.yaml` when training on **natural movies**

under `cell_index` and `retina`

The reliable cells indices used in the paper were:

`[0, 2, 3, 4, 8, 11, 12, 13, 15, 17, 18, 21, 23, 24, 27, 28, 36, 37, 38, 39, 40, 41, 42, 44, 45, 47, 48, 49, 50, 51, 52, 55, 56, 58, 59, 60, 61, 65, 66, 67, 68, 69, 73, 74, 75, 76, 80, 84, 85, 87, 88, 92, 93, 95, 98, 99, 100, 102, 104, 105, 107, 108, 109, 110, 111, 113, 115, 117, 118, 119, 120, 123, 124, 125, 129, 131, 132, 134, 135, 137, 138, 139, 141, 142, 147, 150, 151, 152, 153, 154, 156, 157, 158, 160, 161, 163, 164, 165, 166, 167, 168, 169, 170, 172, 173, 174, 175, 179, 180, 181, 183, 184, 186, 187, 188, 190, 192, 194, 195, 196, 198, 199, 201, 202, 203, 204, 205, 206, 207, 208, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 223, 224, 225, 226, 228, 229, 232, 233, 237, 238, 239, 240, 241, 244, 246, 248, 249, 251, 252, 253, 255, 256, 259, 260, 263, 264, 265, 266, 267, 268, 270, 272, 275, 276, 278, 279, 280, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 293, 294, 295, 297, 298, 299, 300, 302, 305, 306, 308, 309, 314, 315, 316, 319, 321, 322, 323, 324, 326, 327, 328, 329, 334, 335, 336, 337, 338, 339, 343, 345, 346, 347, 348, 349, 351, 355, 356, 359, 360, 361, 364, 366, 367, 368, 369]` for **`retina = "01"`**

`[10, 13, 14, 19, 20, 22, 23, 27, 30, 31, 33, 34, 35, 36, 37, 38, 39, 40, 41, 47, 49, 52, 54, 55, 56, 62, 63, 65, 67, 68, 71, 73, 75, 76, 79, 88, 93, 94, 97, 99, 102, 104, 105, 106, 109, 110, 114, 116, 122, 126, 132, 133, 135, 137, 139, 142, 148, 155, 157, 158, 161, 165, 167, 168, 169, 171, 177, 178, 180]` for **`retina = "02"`**

`[2, 4, 7, 9, 10, 31, 32, 34, 43, 46, 48, 53, 57, 58, 61, 80, 83, 85, 87, 97, 101, 104, 105, 114, 117, 125, 131, 137, 140, 145, 147, 148, 149, 152, 156, 160, 162, 182, 184, 185, 189, 191, 192, 193, 195, 211, 214, 217, 218, 220, 223, 224, 226, 228, 229, 231, 239, 247, 257, 261, 262, 265, 268, 272, 273, 274, 275, 278, 279, 283, 286, 291, 300, 307, 308, 309, 318, 321, 325, 343, 346, 347, 348, 349, 352, 368, 396, 398, 400, 418, 491]` for **`retina = "03"`** (only reliable on natural movies)

`[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 17, 18, 20, 23, 25, 26, 27, 28, 30, 31, 32, 35, 36, 37, 38, 39, 42, 43, 44, 45, 48, 50, 52, 53, 54, 56, 58, 60, 61, 63, 64, 66, 67, 68, 69, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 85, 86, 87, 88, 91, 92, 93, 95, 96, 97, 98, 100, 101, 102, 104, 105, 106, 107, 109]` for **`retina = "04"`**



In [4]:
if "matmul_precision" in cfg:
    hydra.utils.call(cfg.matmul_precision)

movies_dict = hydra.utils.call(cfg.data_io.stimuli)
neuron_data_dict = hydra.utils.call(cfg.data_io.responses)

  _C._set_float32_matmul_precision(precision)


In [5]:
if cfg.check_stimuli_responses_match:
    for session, neuron_data in neuron_data_dict.items():
        neuron_data.check_matching_stimulus(movies_dict[session])


In [10]:
dataloaders = hydra.utils.instantiate(
        cfg.dataloader,
        neuron_data_dictionary=neuron_data_dict,
        movies_dictionary=movies_dict,
    )

data_info = compute_data_info(neuron_data_dict, movies_dict, partial_data_info=cfg.data_io.get("data_info"))

train_loader = data.DataLoader(
    LongCycler(dataloaders["train"], shuffle=True),
    batch_size=None,
    num_workers=0,
    pin_memory=True,
)
valid_loader = ShortCycler(dataloaders["validation"])



Random seed 1000 has been set.
train idx: [2 6 5 1 4 9 0 8]
val idx: [7 3]


In [11]:
if cfg.seed is not None:
    lightning.pytorch.seed_everything(cfg.seed)

Seed set to 42


## Model

In [12]:

input_shape = next(iter(train_loader))[1][0].shape[1:]
cfg.model['in_shape'] = (input_shape[0], cfg.dataloader['num_of_frames'], *input_shape[2:])

# getting the receptive field location from the dataset
retina_index = cfg.dataloader.retina_index
location = train_loader.dataset.loaders[retina_index].dataset.locations

model = hydra.utils.instantiate(cfg.model)
model.location = location[0]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


## Logging and callbacks

In [13]:
log_save_path = os.path.join("logs", "lns_wn/")
os.makedirs(
        log_save_path,
        exist_ok=True,
    )

In [14]:
early_stopping = lightning.pytorch.callbacks.EarlyStopping(
    monitor="val_correlation",
    patience=10,
    mode="max",
    verbose=False,
    min_delta=0.001,
)

lr_monitor = lightning.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch")

model_checkpoint = lightning.pytorch.callbacks.ModelCheckpoint(
    monitor="val_correlation", mode="max", save_weights_only=False
)

## Trainer

In [15]:
trainer = lightning.Trainer(
    max_epochs=10000,
    logger=None,
    callbacks=[early_stopping, lr_monitor, model_checkpoint],
    accelerator="gpu",
    log_every_n_steps=10,
)


MisconfigurationException: No supported gpu backend found!

In [None]:
trainer.fit(model, train_loader, valid_loader)
