In this notebook we will train a Core/Readout model on data from Hoefling et al., 2024: ["A chromatic feature detector in the retina signals visual context changes"](https://elifesciences.org/articles/86860).

We will closely follow the structure of our unified training script, `openretina.cli.train.py`, including using Hydra to import and examine model config files. 

Note that running `openretina.cli.train.py` through the corresponding `openretina train` command is the recommended way to train models, since training can take a while for some configurations.


# Imports

In [1]:
import logging
import os

import hydra
import lightning
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from einops import rearrange

from openretina.data_io.base import compute_data_info
from openretina.data_io.cyclers import LongCycler, ShortCycler
from openretina.data_io.hoefling_2024.dataloaders import natmov_dataloaders_v2
from openretina.data_io.hoefling_2024.responses import filter_responses, make_final_responses
from openretina.data_io.hoefling_2024.stimuli import movies_from_pickle
from openretina.eval.metrics import correlation_numpy, feve
from openretina.models.core_readout import CoreReadout
from openretina.utils.file_utils import get_local_file_path
from openretina.utils.h5_handling import load_h5_into_dict
from openretina.utils.misc import CustomPrettyPrinter
from openretina.utils.plotting import (
    numpy_to_mp4_video,
)

matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)  # to display logs in jupyter

%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

pp = CustomPrettyPrinter(indent=4, max_lines=30)

Let's also import the global config file for this model using Hydra.

In [2]:
with hydra.initialize(config_path=os.path.join("..", "configs"), version_base="1.3"):
    cfg = hydra.compose(config_name="goldin_2022_core_readout.yaml")

# Loading data

The first step in loading data is determining where it will be fetched from and where it will be stored.

Let's see how this is handled in the configs:

In [3]:
pp.pprint(cfg.paths)

{   'cache_dir': '${oc.env:OPENRETINA_CACHE_DIRECTORY}',
    'data_dir': '/home/baptiste/Documents/LabPipelines/open-retina/notebooks/data/omarre_lab/goldin_2023/goldin_2023_data.zip',
    'log_dir': '.',
    'output_dir': '${hydra:runtime.output_dir}'}


The config contains the path from which files will be downloaded. It also requires the user to set `cache_dir`, the directory where downloaded data will be stored.

When using the training script, if `cache_dir` is not set by the user in the config files or somewhere in the script, it falls back to the `OPENRETINA_CACHE_DIRECTORY` environment variable, which by default points to `~/openretina_cache`.

If set, the `cache_dir` is also what the package will use in place of the default openretina cache folder. Let's set both here:

In [4]:
your_chosen_root_folder = "."  # Change this with your desired path.

cfg.paths.cache_dir = your_chosen_root_folder

# We will also overwrite the output directory for the logs/model to the local folder.
cfg.paths.log_dir = your_chosen_root_folder
cfg.paths.output_dir = your_chosen_root_folder

os.environ["OPENRETINA_CACHE_DIRECTORY"] = your_chosen_root_folder
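The fallback order described above can be sketched in a few lines. This is only an illustration under stated assumptions: `resolve_cache_dir` and `DEFAULT_CACHE` are hypothetical names, not part of openretina, and the package's own logic may differ in detail.

```python
import os

# Assumed default location, taken from the text above.
DEFAULT_CACHE = os.path.join("~", "openretina_cache")


def resolve_cache_dir(configured=None, environ=None):
    """Hypothetical helper: pick the cache directory in order of priority."""
    environ = os.environ if environ is None else environ
    # 1) A value set explicitly in the config or script wins.
    if configured:
        return os.path.expanduser(configured)
    # 2) Otherwise use the environment variable, 3) else the default folder.
    return os.path.expanduser(environ.get("OPENRETINA_CACHE_DIRECTORY", DEFAULT_CACHE))


explicit = resolve_cache_dir(".", environ={})
from_env = resolve_cache_dir(None, environ={"OPENRETINA_CACHE_DIRECTORY": "/tmp/cache"})
fallback = resolve_cache_dir(None, environ={})
```

Passing `environ` as an argument keeps the sketch easy to try out without touching the real environment.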

## Stimuli

Loading of the stimuli is achieved, in the training script, via:
```
movies_dict = hydra.utils.call(cfg.data_io.stimuli)
```

Let's unpack it here.

In [5]:
pp.pprint(cfg.data_io.stimuli)

{   '_convert_': 'object',
    '_target_': 'openretina.data_io.goldin_2022.stimuli.load_all_stimuli',
    'base_data_path': '${paths.data_dir}',
    'normalize_stimuli': True,
    'specie': 'axolotl and mouse',
    'stim_type': 'naturalscene'}
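To make the role of `_target_` more concrete, here is a minimal stand-in for what `hydra.utils.call` does conceptually: look up the dotted path, then call it with the remaining keys as keyword arguments. `call_target` is a hypothetical helper written only with the standard library, and the example config points at `json.dumps` purely for illustration. The real Hydra function also resolves interpolations such as `${paths.data_dir}` and honours options like `_convert_`, which this sketch simply ignores.

```python
import importlib


def call_target(config):
    """Hypothetical, simplified version of calling a `_target_` config."""
    # Split "package.module.function" into module path and attribute name.
    module_name, _, attr_name = config["_target_"].rpartition(".")
    func = getattr(importlib.import_module(module_name), attr_name)
    # Keys wrapped in underscores (e.g. _target_, _convert_) are treated as
    # Hydra directives here and are not passed on as arguments.
    kwargs = {
        key: value
        for key, value in config.items()
        if not (key.startswith("_") and key.endswith("_"))
    }
    return func(**kwargs)


example_config = {
    "_target_": "json.dumps",
    "_convert_": "object",
    "obj": {"b": 1, "a": 2},
    "sort_keys": True,
}
result = call_target(example_config)
print(result)
```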


In [8]:
cfg.paths

{'cache_dir': '.', 'data_dir': '/home/baptiste/Documents/LabPipelines/open-retina/notebooks/data/omarre_lab/goldin_2023/goldin_2023_data.zip', 'log_dir': '.', 'output_dir': '.'}

Essentially, `get_local_file_path` checks whether `file_path` is a local file; if it is not, the file is downloaded to the cache folder and read from there.

In [6]:
movies_path = get_local_file_path(file_path=cfg.paths.movies_path, cache_folder=cfg.paths.data_dir)

movies_dict = movies_from_pickle(movies_path)

ConfigAttributeError: Key 'movies_path' is not in struct
    full_key: paths.movies_path
    object_type=dict
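The local-or-download decision described above can be sketched as follows. This is not the implementation of `get_local_file_path`: `plan_local_path` is a hypothetical helper that only decides which path would be read and whether a download would be needed, without performing any network access. The URL is a placeholder.

```python
import os
import tempfile
from urllib.parse import urlparse


def plan_local_path(file_path, cache_folder):
    """Hypothetical helper: return (path_to_read, needs_download)."""
    # A path that already exists on disk is used as-is.
    if os.path.isfile(file_path):
        return file_path, False
    # Otherwise assume a remote location and target a file of the same
    # name inside the cache folder.
    file_name = os.path.basename(urlparse(file_path).path)
    return os.path.join(cache_folder, file_name), True


with tempfile.TemporaryDirectory() as tmp_dir:
    local_file = os.path.join(tmp_dir, "movies.pkl")
    open(local_file, "wb").close()  # create an empty local file
    local_plan = plan_local_path(local_file, cache_folder="cache")

remote_plan = plan_local_path("https://example.org/data/movies.pkl", cache_folder="cache")
```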

In [None]:
pp.pprint(movies_dict)

Let us also visualize a few seconds of the training video:

In [None]:
numpy_to_mp4_video(movies_dict.train[:, :300, ...])

## Responses

In the training script, responses are loaded through:

```
neuron_data_dict = hydra.utils.call(cfg.data_io.responses)
```

Let's unpack it here.

In [None]:
pp.pprint(cfg.data_io.responses)

While this may look complex, it effectively amounts to resolving a few intermediate steps in loading the data, and should be read from the inside out.

When written more simply, it is equivalent to the following:

In [None]:
responses_path = get_local_file_path(file_path=cfg.paths.responses_path, cache_folder=cfg.paths.data_dir)

responses_dict = load_h5_into_dict(file_path=responses_path)

filtered_responses_dict = filter_responses(responses_dict, **cfg.quality_checks)

final_responses = make_final_responses(filtered_responses_dict, response_type="natural")
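The "inside out" reading order can be illustrated with a small recursive resolver: nested `_target_` entries are evaluated first and their results are passed as arguments to the enclosing target. This is a sketch under simplifying assumptions, not Hydra's implementation; `resolve` and `call_order` are hypothetical names, and the targets (`json.loads`, `statistics.fmean`) are standard-library functions chosen only for illustration.

```python
import importlib

call_order = []  # records which targets were called, in order


def resolve(node):
    """Hypothetical resolver: evaluate nested `_target_` dicts depth-first."""
    if not isinstance(node, dict):
        return node
    # Resolve all arguments first, so inner targets run before outer ones.
    resolved = {key: resolve(value) for key, value in node.items() if key != "_target_"}
    if "_target_" not in node:
        return resolved
    module_name, _, attr_name = node["_target_"].rpartition(".")
    func = getattr(importlib.import_module(module_name), attr_name)
    call_order.append(node["_target_"])
    return func(**resolved)


nested_config = {
    "_target_": "statistics.fmean",
    "data": {"_target_": "json.loads", "s": "[1, 2, 3]"},
}
value = resolve(nested_config)
print(call_order, value)
```

The recorded order shows the innermost step (loading) running before the outer one, which mirrors the load, filter, finalise sequence written out in the cell above.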

And here is how the final responses will be organised:

In [None]:
pp.pprint(final_responses)

# Creating dataloaders

The corresponding code in `train.py` is:
```
dataloaders = hydra.utils.instantiate(
        cfg.dataloader,
        neuron_data_dictionary=neuron_data_dict,
        movies_dictionary=movies_dict,
    )
```

In [None]:
pp.pprint(cfg.dataloader)

In [None]:
dataloaders = natmov_dataloaders_v2(
    neuron_data_dictionary=final_responses,
    movies_dictionary=movies_dict,
    allow_over_boundaries=True,
    batch_size=128,
    train_chunk_size=50,
    validation_clip_indices=cfg.dataloader.validation_clip_indices,
)

In [None]:
pp.pprint(dataloaders)

Let's also compute `data_info`, which is used to initialise certain model components and to save important metadata about stimuli and responses within the model.

In [None]:
data_info = compute_data_info(neuron_data_dictionary=final_responses, movies_dictionary=movies_dict)

pp.pprint(data_info)

# Model initialisation

Relevant `train.py` section:
```
cfg.model.n_neurons_dict = data_info["n_neurons_dict"]

model = hydra.utils.instantiate(cfg.model, data_info=data_info)
```

The config for the model will contain all the relevant hyperparameters for it:

In [None]:
pp.pprint(cfg.model)

As you can see, the value for `n_neurons_dict` is missing, and needs to be set from `data_info`.

In [None]:
n_neurons_dict = data_info["n_neurons_dict"]

model = CoreReadout(
    in_shape=(1, 1, 108, 108),
    hidden_channels=(16, 16),
    temporal_kernel_sizes=(1, 1),
    spatial_kernel_sizes=(11, 5),
    n_neurons_dict=n_neurons_dict,
    core_gamma_hidden=0.0,
    core_gamma_in_sparse=0.0,
    core_gamma_input=0.0,
    core_gamma_temporal=40.0,
    core_hidden_padding=True,
    core_input_padding=False,
    cut_first_n_frames_in_core=0,
    downsample_input_kernel_size=None,
    dropout_rate=0.0,
    learning_rate=0.01,
    maxpool_every_n_layers=None,
    readout_bias=True,
    readout_gamma=0.4,
    readout_gaussian_masks=True,
    readout_gaussian_mean_scale=6.0,
    readout_gaussian_var_scale=4.0,
    readout_positive=True,
    readout_scale=True,
    data_info=data_info,
)

# Training

With data imported, models initialised and dataloaders set up, we can turn to training. 

```
log_folder = os.path.join(cfg.paths.output_dir, cfg.exp_name)
os.makedirs(log_folder, exist_ok=True)
logger_array = []
for _, logger_params in cfg.logger.items():
    logger = hydra.utils.instantiate(logger_params, save_dir=log_folder)
    logger_array.append(logger)

callbacks = [
    hydra.utils.instantiate(callback_params) for callback_params in cfg.get("training_callbacks", {}).values()
]

trainer = hydra.utils.instantiate(cfg.trainer, logger=logger_array, callbacks=callbacks)
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=valid_loader)
```

This section is a bit more involved in `train.py`, to leave flexibility for different logger and callback configurations. We are going to keep it simple here.

Let's first initialise a simple tensorboard logger:

In [None]:
log_save_path = os.path.join(cfg.paths.output_dir, "notebook_example")
os.makedirs(log_save_path, exist_ok=True)

logger = lightning.pytorch.loggers.TensorBoardLogger(
    name="tensorboard/",
    save_dir=log_save_path,
)

Then some training callbacks (i.e. utility functions that will be called during training):

In [None]:
early_stopping = lightning.pytorch.callbacks.EarlyStopping(
    monitor="val_correlation",
    patience=10,
    mode="max",
    verbose=False,
    min_delta=0.001,
)

lr_monitor = lightning.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch")

model_checkpoint = lightning.pytorch.callbacks.ModelCheckpoint(
    monitor="val_correlation", mode="max", save_weights_only=False
)
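To see how `patience`, `min_delta` and `mode="max"` interact, the early-stopping rule can be sketched in plain Python. This is a simplified illustration of the documented behaviour, not Lightning's code; `early_stopping_epoch` is a hypothetical helper and the scores are made-up numbers.

```python
def early_stopping_epoch(scores, patience, min_delta):
    """Return the epoch index at which training would stop, or None.

    A score counts as an improvement only if it exceeds the best score
    so far by more than `min_delta` (mode="max").
    """
    best = float("-inf")
    epochs_without_improvement = 0
    for epoch, score in enumerate(scores):
        if score > best + min_delta:
            best = score
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Made-up validation correlations: the last three changes are below min_delta.
val_correlation = [0.10, 0.20, 0.2005, 0.2003, 0.2001]
stop_epoch = early_stopping_epoch(val_correlation, patience=3, min_delta=0.001)
print(stop_epoch)
```

With the notebook's settings (`patience=10`, `min_delta=0.001`), training would therefore stop after ten consecutive validation checks without a gain of more than 0.001 in `val_correlation`.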

We can then instantiate the trainer:

In [None]:
trainer = lightning.Trainer(max_epochs=100, logger=logger, callbacks=[early_stopping, lr_monitor, model_checkpoint])

Finally, we can start training. Before doing so, though, we can initialise the TensorBoard Jupyter integration to visualize how training progresses.

Run the following cell once or twice until the TensorBoard extension UI shows up. Once it shows, note that at the beginning it will display no data (unless you have run this notebook before), because we have not started the trainer yet.

When you run the cell containing `trainer.fit` you can then come back to the tensorboard extension, reload the window *within the extension* by clicking the refresh icon in the top right, and follow the training.

In [None]:
%reload_ext tensorboard

%tensorboard --logdir {log_save_path}

The last important step before calling the trainer is to convert the dictionary of dataloaders into a unified iterator that cycles through all sessions during training and evaluation:

In [None]:
train_loader = LongCycler(dataloaders["train"])
val_loader = ShortCycler(dataloaders["validation"])
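The idea of cycling through sessions can be sketched with plain Python iterables. The sketch assumes that a "long" cycler visits sessions in round-robin order and restarts shorter loaders until the longest one is exhausted; `long_cycle` is a hypothetical function, and the exact behaviour of `LongCycler` and `ShortCycler` should be checked in `openretina.data_io.cyclers`.

```python
from itertools import cycle


def long_cycle(loaders):
    """Hypothetical round-robin over sessions, yielding (session, batch)."""
    # The longest loader determines how many rounds are made.
    max_batches = max(len(batches) for batches in loaders.values())
    # Shorter loaders are restarted when they run out.
    iterators = {session: cycle(batches) for session, batches in loaders.items()}
    for _ in range(max_batches):
        for session, iterator in iterators.items():
            yield session, next(iterator)


# Lists stand in for dataloaders; numbers stand in for batches.
toy_loaders = {"session_a": [1, 2, 3], "session_b": [10]}
visited = list(long_cycle(toy_loaders))
print(visited)
```

Yielding the session key with each batch is what lets a multi-session model pick the matching readout for that batch.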

And we are finally ready to train:

In [None]:
trainer.fit(model, train_loader, val_loader)

# Evaluation

Once the model is done training, we can turn to evaluation.

First, let's use the trainer again to report the Poisson loss and correlation on each of the dataloaders.

In [None]:
test_loader = ShortCycler(dataloaders["test"])
trainer.test(model, dataloaders=[train_loader, val_loader, test_loader], ckpt_path="best")

We can also look at further evals, like the fraction of explainable variance explained for an example session.

In [None]:
# Let's pick an example session
example_session = list(final_responses.keys())[0]

# Extract responses by trial:
responses_by_trial = final_responses[example_session].test_by_trial

responses_by_trial.shape

In [None]:
# Get the test movie for that session:
test_movie = dataloaders["test"][example_session].dataset.movies

# Pass it through the model: move to the model's device and add a batch dimension
with torch.no_grad():
    model_predictions = model.forward(test_movie.to(model.device).unsqueeze(0), data_key=example_session)

model_predictions.shape

In [None]:
help(feve)

In [None]:
# We need to reshape the predictions and responses by trial to match what the function expects

feve_score = feve(
    rearrange(responses_by_trial, "neurons time trials -> time trials neurons")[20:],
    model_predictions.squeeze(0).cpu().numpy(),
)

print(f"Average FEVe score for session {example_session}: {feve_score.mean():.2f}")
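For intuition, one common definition of the fraction of explainable variance explained is FEVE = 1 - (MSE - noise variance) / (total variance - noise variance), where the noise variance is the trial-to-trial variance averaged over time. The single-neuron sketch below follows that definition with the standard library only. It is an assumption-laden illustration: `feve_single_neuron` is a hypothetical function, the numbers are made up, and openretina's `feve` may differ in estimator details (see `help(feve)` above).

```python
from statistics import fmean, pvariance


def feve_single_neuron(trials, prediction):
    """trials: list of repeats, each a list over time; prediction: list over time."""
    n_time = len(prediction)
    # Noise variance: variance across repeats at each time point, averaged over time.
    noise_var = fmean([pvariance([trial[t] for trial in trials]) for t in range(n_time)])
    # Total variance of all single-trial responses.
    total_var = pvariance([value for trial in trials for value in trial])
    # Mean squared error between the prediction and every single trial.
    mse = fmean([(trial[t] - prediction[t]) ** 2 for trial in trials for t in range(n_time)])
    return 1 - (mse - noise_var) / (total_var - noise_var)


toy_trials = [[1.0, 2.0, 3.0, 4.0], [1.2, 1.8, 3.4, 3.6]]  # 2 repeats, 4 time points
trial_mean = [fmean(values) for values in zip(*toy_trials)]
grand_mean = [fmean([v for trial in toy_trials for v in trial])] * 4

feve_perfect = feve_single_neuron(toy_trials, trial_mean)   # best achievable prediction
feve_constant = feve_single_neuron(toy_trials, grand_mean)  # ignores the stimulus
```

Predicting the across-trial mean leaves only noise in the error, so the score reaches 1; predicting a constant explains none of the stimulus-driven variance, so the score is 0.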

Finally, we can plot an example neuron's predictions and its ground truth response.

In [None]:
neuron_idx = 4
session_idx = 0


example_session = list(final_responses.keys())[session_idx]

test_sample = next(iter(dataloaders["test"][example_session]))
responses_by_trial = final_responses[example_session].test_by_trial
mean_test_responses = final_responses[example_session].test_response

input_samples = test_sample.inputs
targets = test_sample.targets

model.eval()
model.cpu()

with torch.no_grad():
    reconstructions = model(input_samples.cpu(), example_session)
reconstructions = reconstructions.cpu().numpy().squeeze()

feve_score = feve(
    rearrange(responses_by_trial, "neurons time trials -> time trials neurons")[20:],
    model_predictions.squeeze(0).cpu().numpy(),
)

correlations = correlation_numpy(mean_test_responses.T[20:], model_predictions.squeeze(0).cpu().numpy(), axis=0)


targets = targets.cpu().numpy().squeeze()
window = 750
plt.figure(figsize=(10, 5))
plt.plot(np.arange(0, window), targets[:window, neuron_idx], label="target")
plt.plot(np.arange(20, window), reconstructions[:window, neuron_idx], label="prediction")
plt.suptitle(f"Neuron {neuron_idx} - FEVE: {feve_score[neuron_idx]:.2f} - Correlation: {correlations[neuron_idx]:.2f}")

plt.legend()
sns.despine()

---