This notebook is a gentle introduction to `openretina` through a few visualization examples. The only prerequisite for running it is an installed copy of the package, obtained through one of the following options.

Recommended:
```
git clone git@github.com:open-retina/open-retina.git
cd open-retina
pip install -e .
```

Alternative:

```
pip install openretina
```
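
A quick way to confirm that either installation route worked is to check that the interpreter running this notebook can find the package. The sketch below uses only the standard library; the helper name `is_installed` is hypothetical and not part of `openretina`.

```python
import importlib.util


def is_installed(package_name: str) -> bool:
    """Return True if `package_name` can be found in the current environment."""
    return importlib.util.find_spec(package_name) is not None


# Should report True after a successful install.
print("openretina installed:", is_installed("openretina"))
```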


# Imports and data setup

In [None]:
import logging
import os
from pathlib import Path

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from IPython.display import clear_output, display
from moviepy import VideoFileClip

from openretina.data_io.hoefling_2024.constants import BADEN_TYPE_BOUNDARIES, RGC_GROUP_GROUP_ID_TO_CLASS_NAME
from openretina.data_io.hoefling_2024.stimuli import movies_from_pickle
from openretina.models.core_readout import load_core_readout_from_remote
from openretina.utils.file_utils import get_cache_directory, get_local_file_path, optionally_download_from_url
from openretina.utils.misc import CustomPrettyPrinter
from openretina.utils.plotting import (
    create_roi_animation,
    display_video,
    numpy_to_mp4_video,
    prepare_video_for_display,
    stitch_videos,
)

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)  # to display logs in jupyter notebooks

%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

pp = CustomPrettyPrinter(indent=4, max_lines=40)

In this example, we are going to visualize the stimuli and model predictions for a simple "Core + Readout" model trained on data from Hoefling et al., 2024: ["A chromatic feature detector in the retina signals visual context changes"](https://elifesciences.org/articles/86860).


First, we set the directory to which the stimulus data are downloaded. Many functions within `openretina` download to a shared cache directory by default, which can be changed via the `OPENRETINA_CACHE_DIRECTORY` environment variable.

In [None]:
# The default directory for downloads will be ~/openretina_cache
# To change this, uncomment the following line and change its path
# os.environ["OPENRETINA_CACHE_DIRECTORY"] = "/Data/"

# You can then check if that directory has been correctly set by running:
get_cache_directory()
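
The override described above follows a common pattern: read an environment variable and fall back to a default path when it is unset. The sketch below illustrates that pattern only. `resolve_cache_directory` is a hypothetical helper, not the implementation of `get_cache_directory`, and the default simply mirrors the path quoted in the cell above. A throwaway variable name is used so that the real setting is left untouched.

```python
import os
from pathlib import Path


def resolve_cache_directory(env_var: str = "OPENRETINA_CACHE_DIRECTORY") -> Path:
    """Hypothetical helper: environment variable first, home-directory default otherwise."""
    override = os.environ.get(env_var)
    if override:
        return Path(override).expanduser()
    return Path.home() / "openretina_cache"


# Demonstrate with a throwaway variable name (an assumption for this example).
os.environ.pop("EXAMPLE_CACHE_DIRECTORY", None)
print(resolve_cache_directory("EXAMPLE_CACHE_DIRECTORY"))  # falls back to the default

os.environ["EXAMPLE_CACHE_DIRECTORY"] = "/tmp/my_cache"
print(resolve_cache_directory("EXAMPLE_CACHE_DIRECTORY"))  # now uses the override
```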

The following download can take a while:

In [None]:
movie_stimulus_path = get_local_file_path(
    "https://huggingface.co/datasets/open-retina/open-retina/blob/main/euler_lab/hoefling_2024/stimuli/rgc_natstim_72x64_joint_normalized_2024-10-11.zip"
)

movie_stimuli = movies_from_pickle(movie_stimulus_path)

Let's visualize the test movie, which we are going to use in our visualizations in the rest of the notebook.

In [None]:
numpy_to_mp4_video(movie_stimuli.test, fps=30)

The video stimuli used are crops of natural movies recorded in the green and UV channels from a "mouse cam" (see [Qiu et al., 2021](https://www.sciencedirect.com/science/article/pii/S096098222100676X)).

# Loading a model

Now we are going to load a retina model that was trained on neural responses to this data. This is as easy as running:

In [None]:
model = load_core_readout_from_remote(
    "hoefling_2024_base_high_res", device="cuda" if torch.cuda.is_available() else "cpu"
)

# Visualizing model predictions

We will now visualize the responses that the model gives to the video data. To do so, we need to specify which recording session we want to predict. By default, we pick the first one in the readout.

In [None]:
# First, put the stimuli in a torch tensor, which is what the model expects.
stim = torch.Tensor(movie_stimuli.test).to(model.device)

# Second, we need to select one of the many experimental sessions the model was trained on to visualize a response.
example_session = model.readout.sessions[0]  # Can pick any number as long as it is in range

with torch.no_grad():
    predicted_response = model.forward(stim.unsqueeze(0), data_key=example_session)
predicted_response_numpy = predicted_response.squeeze().cpu().numpy()
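
The shape bookkeeping in the cell above is easy to lose track of, so the sketch below replays it with dummy NumPy arrays instead of the real model. All sizes (150 frames, 72 x 64 pixels, 10 neurons) are assumptions chosen for illustration. The 30-frame shortening of the response is borrowed from the frame skipping used further down in this notebook.

```python
import numpy as np

# Dummy stimulus laid out as (channels, time, height, width); sizes are illustrative.
stim = np.zeros((2, 150, 72, 64), dtype=np.float32)

# Equivalent of `stim.unsqueeze(0)`: add a leading batch dimension.
batched = np.expand_dims(stim, axis=0)
print(batched.shape)  # (1, 2, 150, 72, 64)

# Stand-in for the model output: (batch, time, neurons), assumed 30 frames shorter.
n_neurons, lag = 10, 30
response = np.zeros((1, stim.shape[1] - lag, n_neurons), dtype=np.float32)

# Equivalent of `.squeeze()`: drop the batch dimension, leaving (time, neurons).
response_2d = response.squeeze()
print(response_2d.shape)  # (120, 10)
```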

Let's visualize the predicted response of example neurons with an interactive plot:

In [None]:
# Create a dropdown for neuron selection
neuron_selector = widgets.Dropdown(
    options=list(range(predicted_response_numpy.shape[1])),
    value=0,
    description="Neuron:",
)


# Define the plotting function
def plot_response(neuron_idx):
    plt.figure(figsize=(12, 6))
    plt.plot(predicted_response_numpy[:, neuron_idx])
    plt.xlabel("Time [frames]")
    plt.ylabel("Response [a.u.]")
    sns.despine()
    plt.show()


# Create an interactive widget
widgets.interactive(plot_response, neuron_idx=neuron_selector)

To make things more interesting, we can also plot this from an "ROI view".
ROI stands for Region Of Interest. In our case, an ROI represents a retinal neuron that was imaged and segmented during data collection. Each ROI corresponds to a spatially localized neuron whose activity was recorded over time using two-photon (2P) calcium imaging.

To extract the ROI mask, we access the `data_info` field within the model, which is a dictionary containing various kinds of information about the data that was used to train the model.

In [None]:
model.data_info.keys()

In [None]:
# Let's see what extra information we have about the sessions.
pp.pprint(model.data_info["sessions_kwargs"])

In [None]:
# What we need are the roi_mask and the roi_ids. Optionally, we can also pass the cell type identity.

roi_mask = model.data_info["sessions_kwargs"][example_session]["roi_mask"]
roi_ids = model.data_info["sessions_kwargs"][example_session]["roi_ids"]
cell_types = model.data_info["sessions_kwargs"][example_session]["group_assignment"]

roi_animation = create_roi_animation(
    roi_mask=roi_mask, activity=predicted_response_numpy.T, roi_ids=roi_ids, max_activity=5, visualize_ids=True
)
numpy_to_mp4_video(roi_animation, fps=30)
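
To build some intuition for what an ROI animation involves, the sketch below paints per-ROI activity onto a small label mask, one frame at a time. It is not the implementation of `create_roi_animation`: the mask convention (0 for background, positive integers for ROI IDs), the helper name `paint_roi_frames`, and all values are assumptions made for illustration.

```python
import numpy as np


def paint_roi_frames(roi_mask: np.ndarray, activity: np.ndarray, roi_ids: np.ndarray) -> np.ndarray:
    """Return frames of shape (time, height, width) with each ROI filled by its activity.

    roi_mask: (height, width) integer labels, 0 meaning background (assumed convention).
    activity: (n_rois, time) array, one row per ROI, in the same order as `roi_ids`.
    """
    n_frames = activity.shape[1]
    frames = np.zeros((n_frames, *roi_mask.shape), dtype=float)
    for row, roi_id in enumerate(roi_ids):
        pixels = roi_mask == roi_id  # boolean (height, width) footprint of this ROI
        frames[:, pixels] = activity[row][:, None]  # broadcast each time point over the footprint
    return frames


# Toy 4x4 field of view with two ROIs.
mask = np.array(
    [
        [1, 1, 0, 0],
        [1, 1, 0, 0],
        [0, 0, 2, 2],
        [0, 0, 2, 2],
    ]
)
toy_activity = np.array([[0.0, 1.0, 2.0], [5.0, 4.0, 3.0]])  # (n_rois=2, time=3)
frames = paint_roi_frames(mask, toy_activity, roi_ids=np.array([1, 2]))
print(frames[1])  # second frame: ROI 1 filled with 1.0, ROI 2 with 4.0
```

The `(n_rois, time)` layout of `toy_activity` matches the transposed response passed as `activity` in the cell above.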

We can also pass cell type information to visualize the cells colour-coded by their type:

In [None]:
video = create_roi_animation(
    roi_mask=roi_mask,
    activity=predicted_response_numpy.T,
    roi_ids=roi_ids,
    cell_types=cell_types,  # array of cell type IDs
    type_boundaries=BADEN_TYPE_BOUNDARIES,
    max_activity=5,
    visualize_ids=False,
)

numpy_to_mp4_video(video, fps=30)
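
The `type_boundaries` argument groups fine-grained cell type IDs into broad classes. A minimal sketch of that idea, using only the standard library, is shown below. The boundary values and the helper `broad_class` are illustrative assumptions that follow the group ranges commonly quoted for Baden et al., 2016. They are not read from `BADEN_TYPE_BOUNDARIES`, whose exact format may differ.

```python
from bisect import bisect_right

# First group ID of each broad class (assumed values, for illustration only).
CLASS_STARTS = [1, 10, 15, 21, 29, 33]
CLASS_NAMES = ["Off", "On-Off", "Fast On", "Slow On", "Uncertain RGCs", "ACs"]


def broad_class(group_id: int) -> str:
    """Map a fine-grained group ID to its broad class via sorted boundaries."""
    if group_id < CLASS_STARTS[0]:
        raise ValueError(f"group_id must be >= {CLASS_STARTS[0]}, got {group_id}")
    # bisect_right counts how many class starts are <= group_id.
    return CLASS_NAMES[bisect_right(CLASS_STARTS, group_id) - 1]


for gid in (3, 12, 17, 25, 30, 40):
    print(gid, "->", broad_class(gid))
```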

# Play with model predictions

Now that we have gone through some basics, let's play with a more engaging example. 

We will take the ROI response view from above a step further by:
1. Showing the stimulus and the response side by side.
2. Adding the ability to visualize different "broad" cell types as defined in [Baden et al., 2016](https://www.nature.com/articles/nature16468) (Slow On, Fast On, Off, On-Off, uncertain RGCs, ACs).

Do not worry too much about the visualization code. To change which session's activity is visualized, you can change the example session in the following cell.

In [None]:
# Choose the folder where the videos will be saved. This can be deleted later.
videos_cache_folder = Path(get_cache_directory()).joinpath("./videos_cache").resolve()
videos_cache_folder.mkdir(parents=True, exist_ok=True)  # also create missing parent folders
print(f"Videos will be saved in {videos_cache_folder}")

example_session = model.readout.sessions[2]  # Can pick any number as long as it is in range

# Get predictions
with torch.no_grad():
    predicted_response = model.forward(stim.unsqueeze(0), data_key=example_session)
predicted_response_numpy = predicted_response.squeeze().cpu().numpy()

# Extract metadata again
roi_mask = model.data_info["sessions_kwargs"][example_session]["roi_mask"]
roi_ids = model.data_info["sessions_kwargs"][example_session]["roi_ids"]
cell_types = model.data_info["sessions_kwargs"][example_session]["group_assignment"]

# Get cell-type groups.
baden_groups = np.array([RGC_GROUP_GROUP_ID_TO_CLASS_NAME[cell_type] for cell_type in cell_types])
baden_unique_groups = np.unique(baden_groups)

The first time a new type group is selected, the activity video is rendered, which might take around 40-50 seconds. After that, the video is saved in `videos_cache_folder` and is shown again much faster.
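
The render-once, reuse-afterwards behaviour described above is a simple file cache. The sketch below shows the pattern with a stand-in renderer that writes a few bytes instead of a real video. `get_or_render`, `fake_render`, and the temporary folder are hypothetical and only mirror the control flow of the callback in the next cell.

```python
import tempfile
from pathlib import Path
from typing import Callable


def get_or_render(cache_folder: Path, name: str, render: Callable[[Path], None]) -> tuple[Path, bool]:
    """Return (path, was_cached); call `render(path)` only if the file is missing."""
    path = cache_folder / f"{name}.mp4"
    if path.exists():
        return path, True
    render(path)  # slow step, executed at most once per name
    return path, False


def fake_render(path: Path) -> None:
    """Stand-in for the expensive video rendering."""
    path.write_bytes(b"not a real video")


cache = Path(tempfile.mkdtemp())
_, first_was_cached = get_or_render(cache, "session Off", fake_render)
_, second_was_cached = get_or_render(cache, "session Off", fake_render)
print(first_was_cached, second_was_cached)  # False True
```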

In [None]:
box_layout = widgets.Layout(
    display="flex",
    flex_flow="column",
    border="solid",
    width="100%",
    align_items="center",
    justify_content="center",
)
style = {"description_width": "initial"}

# Map each broad cell-type group to a boolean mask over the neurons in this session
video_dict = {group: baden_groups == group for group in baden_unique_groups}

# An all-True mask selects every neuron
video_dict["All cell types"] = np.ones(baden_groups.shape, dtype=bool)

video_dropdown = widgets.Dropdown(
    options=list(video_dict.keys()),
    value="All cell types",
    description="Select Video: ",
    layout=widgets.Layout(width="100%", max_width="600px", min_width="300px"),
    style=style,
)

video_output = widgets.Output()

loading = widgets.Label(value="🔄 Loading...", layout=widgets.Layout(visibility="hidden"))
empty = widgets.Label(value="")


def on_video_change(change):
    """Callback for dropdown selection change."""

    loading.layout.visibility = "visible"
    with video_output:
        clear_output(wait=True)
        video_save_path = os.path.join(videos_cache_folder, f"{example_session} {change['new']}.mp4")
        if os.path.exists(video_save_path):
            display_video(video_array=None, video_save_path=video_save_path)
        else:
            group_mask = video_dict[change["new"]]
            stim_video = prepare_video_for_display(
                movie_stimuli.test[:, 30:, ...]
            )  # Skip the first 30 frames, to match response length
            response_video = create_roi_animation(
                roi_mask=roi_mask,
                activity=predicted_response_numpy.T[group_mask],
                roi_ids=roi_ids[group_mask],
                cell_types=cell_types[group_mask],  # array of cell type IDs
                type_boundaries=BADEN_TYPE_BOUNDARIES,  # boundaries between broad types
                max_activity=5,
                visualize_ids=False,
            )

            type_video = stitch_videos(stim_video, response_video)

            # Before displaying, clear video area of all previous content
            clear_output(wait=True)

            display_video(type_video, video_save_path=video_save_path, fps=30)

    loading.layout.visibility = "hidden"


# Attach the callback to the dropdown
video_dropdown.observe(on_video_change, names="value")

# Display the widgets
display(widgets.VBox([video_dropdown, loading, video_output, empty], layout=box_layout))

# Initial video display
on_video_change({"new": video_dropdown.value, "old": None, "owner": video_dropdown, "type": "change"})

# Un-scientific bonus: Showing any video to a retina model

This final section is intended more as an exploratory, fun visualization than as a rigorous analysis.

While the model can generate predicted retinal responses to any video, these should be interpreted with caution. The model in question was trained on UV/green videos captured with a specialized camera, whereas arbitrary videos are typically in RGB and recorded under more diverse and less controlled conditions. This creates two potential sources of distribution shift: differences in spectral content and overall image statistics. While the model will still produce responses (if we manipulate the input videos to roughly match the ones it was trained on), they may not accurately reflect real retinal activity. 

In [None]:
cute_dog_video_path = optionally_download_from_url(
    "https://videos.pexels.com", "video-files/4411457/4411457-hd_1920_1080_25fps.mp4", cache_folder=videos_cache_folder
)

In [None]:
display_video(video_array=None, video_save_path=cute_dog_video_path)

To "show" an arbitrary video to our retina model, we need to make sure the input size and statistics match the ones used to train the model. Let's first fetch them from `data_info`, and then use them to rescale and normalize this video.

In [None]:
model.data_info["input_shape"]

In [None]:
num_channels, target_height, target_width = model.data_info["input_shape"]

In [None]:
# Use moviepy to load the video into a numpy array
clip_object = VideoFileClip(cute_dog_video_path)
dog_clip_array = np.array(list(clip_object.iter_frames()))
dog_clip_array.shape

In [None]:
import cv2

resized_dog_clip = np.stack(
    [
        cv2.resize(
            frame,
            (target_width, target_height),
            interpolation=cv2.INTER_CUBIC,
        )
        for frame in dog_clip_array
    ],
    axis=0,
)

In [None]:
resized_dog_clip.shape

In [None]:
numpy_to_mp4_video(resized_dog_clip, fps=25)

Now that the video has the appropriate size, two steps remain: reducing it to two channels (the mouse retina model we loaded was trained on videos in the UV and green channels), and normalizing the input range for the model.

In [None]:
model.data_info["movie_norm_dict"]

In [None]:
# We create a dummy UV channel by averaging the red and blue channels.
dog_clip_two_channels = np.stack(
    [
        resized_dog_clip[:, :, :, 1],
        (0.5 * resized_dog_clip[:, :, :, 0] + 0.5 * resized_dog_clip[:, :, :, 2]),
    ],
    axis=-1,
)

dog_clip_normalised = (
    dog_clip_two_channels - model.data_info["movie_norm_dict"]["default"]["norm_mean"]
) / model.data_info["movie_norm_dict"]["default"]["norm_std"]
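
The two steps above, collapsing RGB into two channels and then standardizing, can be checked by hand on a tiny synthetic clip. The mean and standard deviation below are made-up numbers, not the values stored in `movie_norm_dict`.

```python
import numpy as np

# Synthetic RGB clip: (time, height, width, channels) with constant colour planes.
rgb_clip = np.zeros((4, 2, 2, 3), dtype=np.uint8)
rgb_clip[..., 0] = 100  # red
rgb_clip[..., 1] = 50  # green
rgb_clip[..., 2] = 200  # blue

# Channel 0 keeps green; channel 1 is a pseudo-UV channel, the mean of red and blue.
# Multiplying by 0.5 before adding promotes uint8 to float and avoids integer overflow.
two_channels = np.stack(
    [rgb_clip[..., 1], 0.5 * rgb_clip[..., 0] + 0.5 * rgb_clip[..., 2]],
    axis=-1,
)

# Standardize with assumed statistics (illustrative values only).
norm_mean, norm_std = 40.0, 20.0
normalised = (two_channels - norm_mean) / norm_std

print(two_channels[0, 0, 0])  # [ 50. 150.]
print(normalised[0, 0, 0])  # [0.5 5.5]
```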

In [None]:
numpy_to_mp4_video(dog_clip_normalised, fps=25)

Finally, we are ready to show the video to our retina model.

In [None]:
from einops import rearrange

example_session = model.readout.sessions[2]

# Put channel dimension first, as the model expects that.
dog_clip_normalised = rearrange(dog_clip_normalised, "t h w c -> c t h w")
dog_clip_tensor = torch.Tensor(dog_clip_normalised).to(model.device)

with torch.no_grad():
    predicted_dog_response = model.forward(dog_clip_tensor.unsqueeze(0), data_key=example_session)
predicted_dog_response_numpy = predicted_dog_response.squeeze().cpu().numpy()

# Extract metadata again
roi_mask = model.data_info["sessions_kwargs"][example_session]["roi_mask"]
roi_ids = model.data_info["sessions_kwargs"][example_session]["roi_ids"]
cell_types = model.data_info["sessions_kwargs"][example_session]["group_assignment"]

stim_video = prepare_video_for_display(dog_clip_normalised[:, 30:, ...])

response_video = create_roi_animation(
    roi_mask=roi_mask,
    activity=predicted_dog_response_numpy.T,
    roi_ids=roi_ids,
    cell_types=cell_types,
    type_boundaries=BADEN_TYPE_BOUNDARIES,
    max_activity=5,
    visualize_ids=False,
)

type_video = stitch_videos(stim_video, response_video)

numpy_to_mp4_video(type_video, fps=25)
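
For readers unfamiliar with `einops`, the `rearrange` call in the cell above only reorders axes. The sketch below shows the same reordering with plain NumPy on a dummy array; the sizes are arbitrary.

```python
import numpy as np

# Dummy clip in (time, height, width, channels) order.
clip_thwc = np.arange(5 * 4 * 3 * 2).reshape(5, 4, 3, 2)

# "t h w c -> c t h w": the old axis 3 (channels) moves to the front.
clip_cthw = np.transpose(clip_thwc, (3, 0, 1, 2))
print(clip_cthw.shape)  # (2, 5, 4, 3)

# The same pixel is reachable under both layouts.
t, h, w, c = 1, 2, 0, 1
assert clip_cthw[c, t, h, w] == clip_thwc[t, h, w, c]
```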

And that's a wrap! We hope this notebook gave you some ideas on how to use a pre-trained retina model. For a more in-depth look at training and the other analyses possible within `openretina`, have a look at the other notebooks.

In [None]:
# Optionally, delete the video cache folder once you are done, to free up space.
# import shutil

# shutil.rmtree(videos_cache_folder)