In [None]:
%matplotlib widget

from pathlib import Path

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from IPython.display import clear_output
from ipywidgets import Layout
from ipywidgets import interactive
from matplotlib import cm

from matplotlib import colormaps
from matplotlib.animation import FuncAnimation
from matplotlib.colors import LinearSegmentedColormap
from viv1t import data
from viv1t.data import get_neuron_coordinates
from viv1t.utils import plot
from viv1t.utils import utils

plot.set_font()
plt.style.use("seaborn-v0_8-deep")
plt.rcParams["animation.html"] = "jshtml"
# maximum size (MB) of an animation embedded as jshtml
plt.rcParams["animation.embed_limit"] = 512

TICK_FONTSIZE = 10
LABEL_FONTSIZE = 11
TITLE_FONTSIZE = 12

DATA_DIR = Path("../data")
PLOT_DIR = Path("figures/response_3d")

# One white-to-colour colormap per direction (degrees), registered under the
# direction's name so it can be looked up by string (e.g. cmap="45").
# force=True makes this cell re-runnable: registering an already-registered
# name raises ValueError otherwise.
white = "#ffffff"
DIRECTION_COLORS = {
    "0": "#1d4e31",
    "45": "#5dc685",
    "90": "#e21d23",
    "135": "#ef682e",
    "180": "#ffc533",
    "225": "#9d926c",
    "270": "#3b5fa5",
    "315": "#c2d3ea",
}
for _name, _color in DIRECTION_COLORS.items():
    colormaps.register(
        LinearSegmentedColormap.from_list(_name, [white, _color]), force=True
    )
del _name, _color

In [None]:
def get_tier_ids(mouse_id: str):
    """Return the tier label of every trial of ``mouse_id`` (looked up in DATA_DIR).

    The result is indexed by trial id: callers select trials via
    ``np.where(tiers == tier)``.
    """
    tier_ids = data.get_tier_ids(data_dir=DATA_DIR, mouse_id=mouse_id)
    return tier_ids


def get_data(mouse_id: str, trial_id: int, mean_response_over_repeat: bool = False):
    """Load one trial of ``mouse_id`` and attach the neuron coordinates.

    Args:
        mouse_id: key into ``data.MOUSE_IDS`` selecting the mouse directory.
        trial_id: trial to load.
        mean_response_over_repeat: if True, every trial showing the same video
            as ``trial_id`` is loaded and the returned sample's "response" is
            replaced by the element-wise mean over those repeats.

    Returns:
        The trial dict from ``data.load_trial`` with an extra
        "neuron_coordinates" entry.
    """
    mouse_dir = DATA_DIR / data.MOUSE_IDS[mouse_id]
    if not mean_response_over_repeat:
        sample = data.load_trial(mouse_dir, trial_id=trial_id)
    else:
        video_ids = data.get_video_ids(mouse_id)
        # every trial that presented the same video as the requested trial
        repeat_ids = np.where(video_ids == video_ids[trial_id])[0]
        repeats = {rid: data.load_trial(mouse_dir, trial_id=rid) for rid in repeat_ids}
        sample = repeats[trial_id]
        responses = [repeats[rid]["response"] for rid in repeat_ids]
        sample["response"] = np.stack(responses).mean(axis=0)
        # drop the other repeats early; they can be large
        del repeats, responses
    sample["neuron_coordinates"] = get_neuron_coordinates(mouse_id=mouse_id)
    return sample


def normalize(response: np.ndarray):
    """Min-max normalize ``response`` along axis 0 into [0, 1].

    Each slice along the remaining axes (e.g. each frame column of a
    (neurons, frames) array) is scaled independently. Slices whose values are
    all equal (zero range) map to 0 instead of producing 0/0 = NaN.
    NaNs already present in the input still propagate.
    """
    r_min = np.min(response, axis=0)
    r_range = np.max(response, axis=0) - r_min
    # avoid dividing by zero; (response - r_min) is 0 wherever r_range == 0
    safe_range = np.where(r_range > 0, r_range, 1)
    return (response - r_min) / safe_range


def get_colormap(si: str, direction: int | None):
    """Return a ScalarMappable coloring neurons for a selectivity index.

    Args:
        si: "OSI", "DSI" or "SSI".
        direction: None for a single colormap per index; otherwise a
            direction in degrees (0/45/90/135 for OSI, multiples of 45 up to
            315 for DSI) selecting one of the per-direction colormaps
            registered at the top of the notebook.

    Raises:
        ValueError: if (si, direction) has no associated colormap.
    """
    cmap_names = {
        ("OSI", None): "Oranges",
        ("DSI", None): "Greens",
        ("SSI", None): "GnBu",
        # OSI orientations reuse the colormap registered for twice the angle
        ("OSI", 0): "0",
        ("OSI", 45): "90",
        ("OSI", 90): "180",
        ("OSI", 135): "270",
        ("DSI", 0): "0",
        ("DSI", 45): "45",
        ("DSI", 90): "90",
        ("DSI", 135): "135",
        ("DSI", 180): "180",
        ("DSI", 225): "225",
        ("DSI", 270): "270",
        ("DSI", 315): "315",
    }
    try:
        cmap = cmap_names[(si, direction)]
    except (KeyError, TypeError):  # TypeError: unhashable direction
        raise ValueError(f"Unknown SI: {si}, direction: {direction}.") from None
    return cm.ScalarMappable(cmap=cmap)


def animate_sample(
    figure: plt.Figure,
    mouse_id: str,
    trial_id: int,
    top_neuron: float = 0.1,
    mean_response_over_repeat: bool = False,
    fps: int = 30,
    si: str = "None",
    si_threshold: float = 0.3,
    si_by_direction: bool = False,
    save: bool = False,
    dpi: int = 120,
):
    assert si in ("None", "OSI", "DSI", "SSI")

    sample = get_data(
        mouse_id=mouse_id,
        trial_id=trial_id,
        mean_response_over_repeat=mean_response_over_repeat,
    )
    video = sample["video"]
    pupil_center = sample["pupil_center"]
    behavior = sample["behavior"]
    response = sample["response"]

    _, t, h, w = video.shape

    # normalize response per frame
    # response = np.maximum(response / np.max(response, axis=0), 0)
    response = normalize(response)

    total_neurons = response.shape[0]  # total number of neurons in the population
    neuron_coordinates = sample["neuron_coordinates"]
    num_neurons = int(top_neuron * total_neurons)  # number of neurons to plot

    si_neurons, preferred_directions, directions = None, None, None
    if si != "None":
        tuning = utils.load_tuning(data.METADATA_DIR, mouse_id=mouse_id)
        match si:
            case "OSI" | "DSI" | "SSI":
                si_values = tuning[si]
            case _:
                raise ValueError(f"Unknown selectivity index: {si}")
        si_neurons = np.where(si_values >= si_threshold)[0]
        if si_by_direction:
            directions = np.array(
                list(tuning["tuning_curves"][0].keys()), dtype=np.int32
            )
            tuning_curves = np.array(
                [
                    np.array(list(tuning_curve.values()))
                    for tuning_curve in tuning["tuning_curves"].values()
                ],
                dtype=np.float32,
            )
            if si == "OSI":
                # combine opposite directions for a single orientation value
                directions = directions[:4]
                tuning_curves = tuning_curves[:, :4] + tuning_curves[:, 4:]
            preferred_directions = np.argmax(tuning_curves, axis=-1)
            del tuning_curves
        del tuning

    if figure is None:
        f_h, f_w = 6, 10
        figure = plt.figure(figsize=(f_w, f_h), dpi=dpi, facecolor="white")
        axes, cbar_axes = [], []
        get_width = lambda height: height * (w / h) * (f_h / f_w)
        height = 0.3
        width = get_width(height)
        # axes[0] video
        axes.append(figure.add_axes(rect=[0.04, 0.61, width, height]))
        text_top = 0.61 + height
        top, height = 0.5, 0.05
        width = 0.8 * width
        gap = height + 0.06
        left = 0.1
        # axes[1] pupil size
        axes.append(figure.add_axes(rect=[left, top, width, height]))
        # axes[2] motor speed
        axes.append(figure.add_axes(rect=[left, top - gap, width, height]))
        # axes[3] pupil center x
        axes.append(figure.add_axes(rect=[left, top - 2 * gap, width, height]))
        # axes[4] pupil center y
        axes.append(figure.add_axes(rect=[left, top - 3 * gap, width, height]))
        bottom = 0.03
        # axes[5] response
        axes.append(figure.add_axes(rect=[0.24, bottom, 0.85, 0.85], projection="3d"))
        # axes[6] area on the top-right of the figure to show text
        axes.append(figure.add_axes(rect=[0.4, text_top, 0.05, 0.05]))

        bottom = 0.1
        # color bars
        cbar_left = 0.98
        for i in range(9):
            cbar_left -= 0.045
            cbar_axes.append(
                figure.add_axes(
                    rect=[
                        cbar_left,
                        # bottom,
                        # text_top - 0.04,
                        text_top - 0.138,
                        0.008,
                        0.06,
                    ]
                )
            )
        del f_h, f_w, get_width, height, width, top, gap, left, cbar_left
    else:
        axes = figure.get_axes()
        axes, cbar_axes = axes[:7], axes[7:]

    pupil_size_yticks = np.linspace(
        np.nanmin(behavior[0, :]), np.nanmax(behavior[0, :]), 2
    )
    speed_yticks = np.linspace(np.nanmin(behavior[1, :]), np.nanmax(behavior[1, :]), 2)
    pupil_center_x_yticks = np.linspace(
        np.nanmin(pupil_center[0, :]), np.nanmax(pupil_center[0, :]), 2
    )
    pupil_center_y_yticks = np.linspace(
        np.nanmin(pupil_center[1, :]), np.nanmax(pupil_center[1, :]), 2
    )

    frame_ticks = np.linspace(0, t, 4, dtype=int)

    get_ticks = lambda coors: np.linspace(
        np.floor(np.min(coors) * 0.1) * 10 - 20,
        np.ceil(np.max(coors) * 0.1) * 10 + 20,
        5,
        dtype=int,
    )
    x_ticks = get_ticks(neuron_coordinates[:, 0])
    y_ticks = get_ticks(neuron_coordinates[:, 1])
    z_ticks = get_ticks(neuron_coordinates[:, 2])
    response_title = "Normalized response"
    if mean_response_over_repeat:
        response_title += " (average over repeats)"

    cbar_mappable = cm.ScalarMappable(cmap="Greys")
    cbar_mappable.set_clim(0, 1)
    plt.colorbar(cbar_mappable, cax=cbar_axes[0], shrink=0.5)
    cbar_yticks = np.linspace(0, 1, 2, dtype=int)
    plot.set_yticks(
        axis=cbar_axes[0],
        ticks=cbar_yticks,
        tick_labels=cbar_yticks,
        tick_fontsize=TICK_FONTSIZE,
    )

    for cbar_ax in cbar_axes[1:]:
        cbar_ax.cla()
        cbar_ax.set_axis_off()

    if si_neurons is None:
        cbar_axes[1].set_axis_off()
        si_mappable = None
    elif not si_by_direction:
        cbar_axes[1].set_axis_on()
        si_mappable = get_colormap(si=si, direction=None)
        si_mappable.set_clim(0, 1)
        plt.colorbar(si_mappable, cax=cbar_axes[1], shrink=0.5)
        cbar_axes[1].set_yticks([])
        cbar_axes[1].text(
            x=1.4,
            y=0,
            s=si,
            ha="left",
            va="center",
            fontsize=TICK_FONTSIZE,
            transform=cbar_axes[1].transAxes,
        )
    else:
        si_mappable = {}
        for i, direction in enumerate(directions[::-1]):
            si_mappable[direction] = get_colormap(si=si, direction=direction)
            si_mappable[direction].set_clim(0, 1)

            cbar_ax = cbar_axes[i + 1]
            cbar_ax.set_axis_on()
            plt.colorbar(si_mappable[direction], cax=cbar_ax, shrink=0.5)
            cbar_ax.set_yticks([])
            cbar_ax.text(
                x=1.4,
                y=0,
                s=f"{direction}°",
                ha="left",
                va="center",
                fontsize=TICK_FONTSIZE,
                transform=cbar_ax.transAxes,
            )

    color_value = 1.0 if si_by_direction else 0.8
    scatter_kwargs = {
        # "vmin": 0,
        # "vmax": 1,
        "s": 10,
        # "alpha": 0.9,
        "depthshade": False,
    }
    line_kwargs = {
        "color": "black",
        "linewidth": 1.5,
        "clip_on": False,
    }

    pos1 = axes[0].get_position()
    pos5 = axes[5].get_position()

    def animate(frame: int):
        for i in range(7):
            axes[i].cla()
        axes[6].set_axis_off()
        # plot movie frame
        axes[0].imshow(video[0, frame, :, :], cmap="gray", aspect="equal")
        axes[0].text(
            x=0,
            y=pos1.y1 + 0.13,
            s=f"Frame: {frame:03d}",
            ha="left",
            va="center",
            fontsize=LABEL_FONTSIZE,
            transform=axes[0].transAxes,
        )
        axes[0].set_xticks([])
        axes[0].set_yticks([])

        # behavior - pupil size
        axes[1].plot(behavior[0, :frame], **line_kwargs)
        axes[1].set_ylim(pupil_size_yticks[0], pupil_size_yticks[1])
        plot.set_yticks(
            axes[1],
            ticks=pupil_size_yticks,
            tick_labels=pupil_size_yticks.astype(int),
            tick_fontsize=TICK_FONTSIZE,
        )
        axes[1].set_title("Pupil size", pad=0, fontsize=LABEL_FONTSIZE)

        # behavior - speed
        axes[2].plot(behavior[1, :frame], **line_kwargs)
        axes[2].set_ylim(speed_yticks[0], speed_yticks[1])
        plot.set_yticks(
            axes[2],
            ticks=speed_yticks,
            tick_labels=[f"{value:.0e}" for value in speed_yticks],
            tick_fontsize=TICK_FONTSIZE,
        )
        axes[2].set_title("Locomotion speed", pad=0, fontsize=LABEL_FONTSIZE)

        # pupil center - horizontal
        axes[3].plot(pupil_center[0, :frame], **line_kwargs)
        axes[3].set_ylim(pupil_center_x_yticks[0], pupil_center_x_yticks[1])
        plot.set_yticks(
            axes[3],
            ticks=pupil_center_x_yticks,
            tick_labels=[f"{value:.01e}" for value in pupil_center_x_yticks],
            tick_fontsize=TICK_FONTSIZE,
        )
        axes[3].set_title("Pupil center (x)", pad=0, fontsize=LABEL_FONTSIZE)

        # pupil center - horizontal
        axes[4].plot(pupil_center[1, :frame], **line_kwargs)
        axes[4].set_ylim(pupil_center_y_yticks[0], pupil_center_y_yticks[1])
        plot.set_yticks(
            axes[4],
            ticks=pupil_center_y_yticks,
            tick_labels=[f"{value:.01e}" for value in pupil_center_y_yticks],
            tick_fontsize=TICK_FONTSIZE,
        )
        axes[4].set_title("Pupil center (y)", pad=0, fontsize=LABEL_FONTSIZE)

        # color and alpha (RBGA) values for all neurons
        colors = np.zeros((total_neurons, 4), dtype=np.float32)

        # compute top neurons to plot per-frame
        neurons = np.argsort(response[:, frame])[-num_neurons:]

        # set color values for neurons to plot
        colors[neurons] = cbar_mappable.to_rgba(color_value)

        if si_neurons is not None:
            # within the top neurons to plot, plot the top SI neurons
            _si_neurons = np.intersect1d(neurons, si_neurons, assume_unique=True)
            # show the percentage of top neurons that are SI neurons
            axes[6].text(
                x=0,
                y=0,
                s=f"SI neurons: {100 * len(_si_neurons)  / len(neurons):.0f}%\n"
                f"SI threshold: {si_threshold:.2f}",
                ha="left",
                va="center",
                fontsize=TICK_FONTSIZE,
                transform=axes[6].transAxes,
            )
            if si_by_direction:
                for i, direction in enumerate(directions):
                    direction_neurons = np.where(preferred_directions == i)[0]
                    direction_neurons = np.intersect1d(
                        _si_neurons, direction_neurons, assume_unique=True
                    )
                    colors[direction_neurons] = si_mappable[direction].to_rgba(
                        color_value
                    )
            else:
                colors[_si_neurons] = si_mappable.to_rgba(color_value)

        # set alpha values based on response values
        colors[neurons, -1] = np.maximum(response[neurons, frame], 0.025)
        # colors[neurons, -1] = response[neurons, frame]

        axes[5].scatter(
            neuron_coordinates[neurons, 0],
            neuron_coordinates[neurons, 1],
            neuron_coordinates[neurons, 2],
            c=colors[neurons],
            **scatter_kwargs,
        )
        axes[5].set_xlabel("x coordinate (μm)", fontsize=TICK_FONTSIZE)
        axes[5].set_ylabel("y coordinate (μm)", fontsize=TICK_FONTSIZE)
        axes[5].set_zlabel("z coordinate (μm)", fontsize=TICK_FONTSIZE)
        axes[5].set_title(response_title, fontsize=LABEL_FONTSIZE, y=pos5.y1 + 0.05)
        axes[5].set_xlim(x_ticks[0], x_ticks[-1])
        axes[5].set_xticks(x_ticks)
        axes[5].set_ylim(y_ticks[0], y_ticks[-1])
        axes[5].set_yticks(y_ticks)
        axes[5].set_zlim(z_ticks[0], z_ticks[-1])
        axes[5].set_zticks(z_ticks)
        axes[5].invert_zaxis()
        for axis in (axes[5].xaxis, axes[5].yaxis, axes[5].zaxis):
            axis.set_pane_color((0.0, 0.0, 0.0, 0.05))
            axis._axinfo["grid"]["color"] = (0.0, 0.0, 0.0, 0.3)

        for i in [1, 2, 3, 4]:
            axes[i].set_xlim(frame_ticks[0], frame_ticks[-1])
            sns.despine(ax=axes[i])
            if i == 4:
                plot.set_xticks(
                    axes[i],
                    ticks=frame_ticks,
                    tick_labels=frame_ticks,
                    label="Movie frame",
                    tick_fontsize=TICK_FONTSIZE,
                    label_fontsize=LABEL_FONTSIZE,
                )
            else:
                axes[i].set_xticks([])
        for ax in axes + cbar_axes:
            plot.set_ticks_params(ax, length=2.5)

    anim = FuncAnimation(figure, animate, frames=t, interval=int(1000 / fps))
    if save:
        PLOT_DIR.mkdir(exist_ok=True, parents=True)
        anim.save(
            PLOT_DIR / f"mouse{mouse_id}_trial{trial_id}.gif",
            fps=fps,
            dpi=dpi,
            savefig_kwargs={"pad_inches": 0},
        )
    return figure, anim

In [None]:
# When this cell is re-run, close the figure left over from the previous run
# and clear the previously displayed widgets.
# globals().get(...) is needed because `"figure" in (globals(), locals())`
# would compare the string against the two namespace dicts (always False).
if globals().get("figure") is not None:
    plt.close(figure)
    clear_output(wait=True)


figure, anim = None, None

# shared look for every control below
style = {"description_width": "initial", "indent": False}
layout = Layout(width="20em", margin="0.5em 1em 0.5em")

mouse_id_dropdown = widgets.Dropdown(
    options=data.SENSORIUM,
    value="A",
    description="Mouse ID",
    style=style,
    layout=layout,
)

tier_dropdown = widgets.Dropdown(
    options=[
        ("Train", "train"),
        ("Validation", "validation"),
        ("Live main", "live_main"),
        ("Live bonus", "live_bonus"),
        ("Final main", "final_main"),
        ("Final bonus", "final_bonus"),
    ],
    value="validation",
    description="Dataset",
    style=style,
    layout=layout,
)

# placeholder options; replaced by set_mouse_id / set_tier
trial_id_dropdown = widgets.Dropdown(
    options=["None"],
    value="None",
    description="Trial ID",
    style=style,
    layout=layout,
)

repeat_checkbox = widgets.Checkbox(
    value=False,
    description="Average response over repeats",
    disabled=False,
    style=style,
    layout=layout,
)

neuron_quantile_slider = widgets.FloatSlider(
    min=0.01,
    max=1.0,
    step=0.01,
    value=1.0,
    description="Activity quantile",
    style=style,
    layout=layout,
)

si_dropdown = widgets.Dropdown(
    options=["None", "DSI", "OSI", "SSI"],
    value="None",
    disabled=True,
    description="SI",
    style=style,
    layout=layout,
)

si_threshold_slider = widgets.FloatSlider(
    min=0.0,
    max=1.0,
    step=0.01,
    description="SI threshold",
    disabled=True,
    style=style,
    layout=layout,
    continuous_update=True,
)

si_by_direction = widgets.Checkbox(
    value=False,
    description="color by preference",
    disabled=True,
    style=style,
    layout=layout,
)

fps_slider = widgets.IntSlider(
    min=1, max=30, step=1, value=30, description="FPS", style=style, layout=layout
)

animate_button = widgets.Button(description="Animate", style=style, layout=layout)

save_checkbox = widgets.Checkbox(
    value=False,
    description="Save animation",
    disabled=False,
    style=style,
    layout=layout,
)


def _set_trial_ids(tiers: np.ndarray, tier: str):
    """Fill the trial dropdown with the indices of trials labelled ``tier``.

    Averaging over repeats is switched off and locked for the "train" tier,
    and unlocked for every other tier.
    """
    is_train = tier == "train"
    trial_id_dropdown.options = np.where(tiers == tier)[0]
    if is_train:
        repeat_checkbox.value = False
    repeat_checkbox.disabled = is_train


def set_mouse_id(mouse_id):
    """Refresh the trial and selectivity-index widgets for a newly selected mouse."""
    tiers = get_tier_ids(mouse_id)
    tier = tier_dropdown.value
    _set_trial_ids(tiers, tier)
    # the SI dropdown is only enabled for mice listed in data.SENSORIUM_OLD
    si_dropdown.disabled = mouse_id not in data.SENSORIUM_OLD
    # selectivity indices offered per mouse (hard-coded)
    si_options = ["None"]
    if mouse_id in ("B", "C", "E"):
        si_options += ["OSI", "DSI"]
    if mouse_id in ("A", "D", "E"):
        si_options += ["SSI"]
    si_dropdown.options = si_options
    si_dropdown.value = "None"


def set_tier(tier: str):
    """Re-populate the trial dropdown after the dataset tier changes."""
    mouse_id = mouse_id_dropdown.value
    _set_trial_ids(get_tier_ids(mouse_id), tier)


# collects prints/tracebacks from set_si (see output.capture below)
output = widgets.Output()


@output.capture()
def set_si(si_type: str):
    """Configure the SI threshold slider and direction checkbox for ``si_type``.

    "None" disables both controls. For "OSI"/"DSI"/"SSI" the tuning of the
    selected mouse is loaded and the slider range is set to the NaN-ignoring
    min/max of that index; colouring by preference is only allowed for
    OSI/DSI.

    Raises:
        ValueError: for an unknown ``si_type``.
    """
    if si_type == "None":
        si_threshold_slider.disabled = True
        si_threshold_slider.value = 0.0
        si_by_direction.value = False
        si_by_direction.disabled = True
        return
    if si_type not in ("OSI", "DSI", "SSI"):
        raise ValueError(f"Unknown SI: {si_type}")

    # only load tuning when an index is actually requested
    tuning = utils.load_tuning(data.METADATA_DIR, mouse_id=mouse_id_dropdown.value)
    si_values = tuning[si_type]
    si_min, si_max = np.nanmin(si_values), np.nanmax(si_values)

    if si_type == "SSI":
        si_by_direction.value = False
        si_by_direction.disabled = True
    else:
        si_by_direction.disabled = False

    # the slider rejects min > max at every intermediate step, so when the new
    # range lies entirely above the current one, raise max before min
    if si_threshold_slider.max < si_min:
        si_threshold_slider.max, si_threshold_slider.min = si_max, si_min
    else:
        si_threshold_slider.min, si_threshold_slider.max = si_min, si_max
    si_threshold_slider.disabled = False
    si_threshold_slider.value = np.round(si_min, 1)


# interactive() calls each handler now and again whenever its widget changes
interactive(set_mouse_id, mouse_id=mouse_id_dropdown)
interactive(set_tier, tier=tier_dropdown)
interactive(set_si, si_type=si_dropdown)

display(widgets.HBox([mouse_id_dropdown, tier_dropdown, trial_id_dropdown]))
display(widgets.HBox([repeat_checkbox, neuron_quantile_slider]))
display(widgets.HBox([si_dropdown, si_threshold_slider, si_by_direction]))
display(widgets.HBox([fps_slider, animate_button, save_checkbox]))
# set_si runs under `output.capture()`; display the Output widget so its
# prints and tracebacks are visible instead of silently discarded
display(output)


def on_button_clicked(b):
    """Animate-button callback: (re)build the animation from the widget values.

    ``b`` is the clicked Button (unused). The global ``figure`` is reused
    across clicks, and the previous animation's timer is stopped first.
    """
    global figure, anim
    trial_id = trial_id_dropdown.value
    # the dropdown may hold None (no trials in the selected tier) or the
    # "None" placeholder string; there is nothing to animate in either case
    if trial_id is None or isinstance(trial_id, str):
        return
    # event_source may already be None if matplotlib stopped the animation
    # (e.g. its figure was closed)
    if anim is not None and anim.event_source is not None:
        anim.event_source.stop()
    figure, anim = animate_sample(
        figure=figure,
        mouse_id=mouse_id_dropdown.value,
        trial_id=trial_id,
        top_neuron=neuron_quantile_slider.value,
        mean_response_over_repeat=repeat_checkbox.value,
        fps=fps_slider.value,
        si=si_dropdown.value,
        si_threshold=si_threshold_slider.value,
        si_by_direction=si_by_direction.value,
        save=save_checkbox.value,
    )
    anim.event_source.start()


animate_button.on_click(on_button_clicked)