In [4]:
%matplotlib widget

from pathlib import Path
from typing import Tuple
import pandas as pd
import ipywidgets as widgets
import numpy as np
import plotly.graph_objects as go
from IPython.display import clear_output
from ipywidgets import Layout
from ipywidgets import interactive
import pandas as pd

from viv1t import data
from viv1t.utils import utils

# Presumably font sizes for figure text. NOTE(review): not referenced in the
# cells shown here.
TICK_FONTSIZE = 10
LABEL_FONTSIZE = 11
TITLE_FONTSIZE = 12

# Presumably the recording frame rate; not referenced in the cells shown here.
FPS = 30

DATA_DIR = Path("../data/sensorium")
MOUSE_IDS = data.SENSORIUM_OLD

# Per-mouse neuron coordinates kept as NumPy arrays (to_tensor=False); later
# cells read columns 0, 1 and 2 as the x, y and z coordinates (μm).
NEURON_COORDINATES = {
    mouse_id: data.get_neuron_coordinates(
        data_dir=DATA_DIR, mouse_id=mouse_id, to_tensor=False
    )
    for mouse_id in MOUSE_IDS
}

# Run directory whose feedbackRF results are visualised below.
OUTPUT_DIR = Path("../runs") / "vivit" / "172_viv1t_causal"

# Tuning table; the cells below use its "mouse", "depth", "neuron",
# "classic_tuned" and "inverse_tuned" columns.
DF = pd.read_parquet(OUTPUT_DIR / "feedbackRF" / "response_amplitudes.parquet")

# Marker colour per trace index (traces are created in the order
# Classical, Inverse, Both).
COLORMAPS = {
    0: "black",  # classical
    1: "red",  # inverse
    2: "green",  # both
}

def get_ticks(coors, num: int = 5, padding: int = 20) -> np.ndarray:
    """Return ``num`` evenly spaced integer ticks covering ``coors`` with a margin.

    The data extent is rounded outward to the nearest multiple of 10 and then
    widened by ``padding`` on each side. Callers use the first and last tick as
    the plot-axis range.

    Args:
        coors: array-like of coordinates along a single axis.
        num: number of ticks to return (default 5).
        padding: margin added below the minimum and above the maximum
            (default 20).

    Returns:
        Integer numpy array of length ``num`` (values are truncated to int).
    """
    low = np.floor(np.min(coors) * 0.1) * 10 - padding
    high = np.ceil(np.max(coors) * 0.1) * 10 + padding
    return np.linspace(low, high, num, dtype=int)


# Axis ticks per mouse: column 0 -> X_TICKS, 1 -> Y_TICKS, 2 -> Z_TICKS.
X_TICKS, Y_TICKS, Z_TICKS = (
    {
        mouse_id: get_ticks(NEURON_COORDINATES[mouse_id][:, axis])
        for mouse_id in MOUSE_IDS
    }
    for axis in range(3)
)

In [5]:
def get_neurons_within_range(
    neuron_coordinates: np.ndarray,
    x_range: Tuple[int, int],
    y_range: Tuple[int, int],
    z_range: Tuple[int, int],
):
    """Return the indices of neurons that lie inside an axis-aligned box.

    Args:
        neuron_coordinates: array with one row per neuron whose columns 0, 1
            and 2 hold the x, y and z coordinates.
        x_range, y_range, z_range: inclusive ``(low, high)`` bounds per axis.

    Returns:
        Sorted 1-D array of the row indices that satisfy all three bounds.
    """
    inside = np.ones(len(neuron_coordinates), dtype=bool)
    for axis, bounds in enumerate((x_range, y_range, z_range)):
        low, high = bounds[0], bounds[1]
        column = neuron_coordinates[:, axis]
        inside &= (column >= low) & (column <= high)
    return np.flatnonzero(inside)

In [6]:
clear_output(wait=True)

# Shared styling for every control in the panel.
margin = "0.1em 1em 0.1em"
style = {"description_width": "initial", "indent": False}
layout = Layout(width="20em", margin=margin)
coordinate_layout = Layout(width="25em", margin=margin)

mouse_id_dropdown = widgets.Dropdown(
    options=MOUSE_IDS,
    value="A",
    description="Mouse ID",
    style=style,
    layout=layout,
)

opacity_slider = widgets.FloatSlider(
    min=0.0,
    max=1.0,
    step=0.01,
    value=0.8,
    description="Marker opacity",
    style=style,
    layout=layout,
)

marker_size_slider = widgets.IntSlider(
    min=1,
    max=10,
    step=1,
    value=5,
    description="Marker size",
    style=style,
    layout=layout,
)

reset_button = widgets.Button(description="Reset", style=style, layout=layout)


def _make_range_slider(description):
    """Build one coordinate range slider; its bounds are refitted per mouse by set_mouse_id."""
    return widgets.IntRangeSlider(
        value=[0, 10],
        step=10,
        description=description,
        disabled=False,
        style=style,
        layout=coordinate_layout,
    )


x_coordinates_slider, y_coordinates_slider, z_coordinates_slider = (
    _make_range_slider(description)
    for description in ("x range (μm)", "y range (μm)", "z range (μm)")
)


def set_slider(slider, min_value, max_value):
    """Move ``slider``'s bounds to ``[min_value, max_value]``.

    The two assignments are ordered so the slider never passes through a state
    with ``min > max``. If the new range lies entirely above the current maximum,
    the upper bound is moved first; otherwise the lower bound is moved first.
    """
    new_bounds = {"min": min_value, "max": max_value}
    range_moves_up = slider.max < min_value
    order = ("max", "min") if range_moves_up else ("min", "max")
    for bound in order:
        setattr(slider, bound, new_bounds[bound])


def set_mouse_id(mouse_id: str):
    """Refit the x/y/z range sliders to ``mouse_id`` and select each full range.

    The sliders are integer IntRangeSliders while the coordinates are floats.
    Bounds are therefore rounded outward (floor for the minimum, ceil for the
    maximum) and passed as Python ints. Relying on implicit int conversion would
    truncate toward zero and pull a fractional extreme inward, hiding the
    outermost neurons from the default selection.
    """
    # The z slider stops at a fixed depth rather than the deepest neuron; this
    # matches the upper end of the 200-300 depth window used in animate_sample.
    # NOTE(review): confirm 300 is still the intended cap.
    max_depth = 300
    coordinates = NEURON_COORDINATES[mouse_id]
    sliders = (x_coordinates_slider, y_coordinates_slider, z_coordinates_slider)
    for axis, slider in enumerate(sliders):
        low = int(np.floor(coordinates[:, axis].min()))
        if axis == 2:
            high = max_depth
        else:
            high = int(np.ceil(coordinates[:, axis].max()))
        set_slider(slider, low, high)
        slider.value = [slider.min, slider.max]


def reset(b):
    """Restore the marker controls to their defaults and select every axis' full range.

    ``b`` is the argument supplied by ``reset_button.on_click``; it is unused.
    """
    for control, default in ((opacity_slider, 0.8), (marker_size_slider, 5)):
        control.value = default
    for range_slider in (x_coordinates_slider, y_coordinates_slider, z_coordinates_slider):
        range_slider.value = [range_slider.min, range_slider.max]


# Persistent plotly widget: animate_sample adds its traces on the first call
# and updates them in place on every later call.
figure = go.FigureWidget()


def _add_tuning_traces(neuron_coordinates: np.ndarray, min_depth: int, max_depth: int):
    """Create the Classical/Inverse/Both scatter traces plus an invisible depth-anchor trace."""
    for i, cell_type in enumerate(("Classical", "Inverse", "Both")):
        figure.add_trace(
            go.Scatter3d(
                name=cell_type,
                mode="markers",
                marker={"color": COLORMAPS[i]},
            )
        )

    # add a dummy trace to keep depth (z-axis) consistent
    depths = np.unique(neuron_coordinates[:, 2])
    depths = depths[(depths >= min_depth) & (depths <= max_depth)]
    figure.add_trace(
        go.Scatter3d(
            mode="markers",
            x=[neuron_coordinates[0, 0]] * len(depths),
            y=[neuron_coordinates[0, 1]] * len(depths),
            z=depths,
            marker={"size": 1, "opacity": 0},
            showlegend=False,
        )
    )
    # change default camera angle
    figure.update_layout(scene_camera=dict(center=dict(x=0, y=0, z=-0.2)))


def _scene_layout(mouse_id: str) -> go.Layout:
    """Build the figure layout, with each axis range spanning ``mouse_id``'s ticks."""
    ticks_per_axis = {
        "x": X_TICKS[mouse_id],
        "y": Y_TICKS[mouse_id],
        "z": Z_TICKS[mouse_id],
    }
    axes = {
        f"{axis}axis": {
            "title": f"{axis} coordinate (μm)",
            "nticks": 5,
            "range": [ticks[0], ticks[-1]],
        }
        for axis, ticks in ticks_per_axis.items()
    }
    return go.Layout(
        margin={"l": 0, "r": 0, "b": 0, "t": 0},
        height=850,
        width=1400,
        scene={
            "aspectmode": "manual",
            "aspectratio": {"x": 1, "y": 1, "z": 0.6},
            **axes,
        },
        legend={
            "itemsizing": "constant",
            "xanchor": "auto",
            "x": 0,
            "yanchor": "auto",
            "y": 1,
            "font": {"size": 16},
        },
    )


def animate_sample(
    mouse_id: str,
    opacity: float,
    marker_size: int,
    x_range: Tuple[int, int],
    y_range: Tuple[int, int],
    z_range: Tuple[int, int],
):
    """Redraw the 3D scatter of tuned neurons for ``mouse_id``.

    On the first call this creates one trace per tuning class (Classical,
    Inverse, Both) plus an invisible depth-anchor trace. Later calls only
    replace the coordinates and marker properties of those traces and refresh
    the layout.

    Args:
        mouse_id: key into NEURON_COORDINATES and value matched in DF["mouse"].
        opacity: marker opacity (the controlling slider spans 0-1).
        marker_size: marker size applied to all three tuning traces.
        x_range, y_range, z_range: inclusive coordinate bounds (μm); only
            neurons inside all three are drawn.
    """
    # Depth window for the tuning table rows and the depth-anchor trace.
    min_depth, max_depth = 200, 300
    neuron_coordinates = NEURON_COORDINATES[mouse_id]

    df = DF.loc[
        (DF["mouse"] == mouse_id)
        & (DF["depth"] >= min_depth)
        & (DF["depth"] <= max_depth)
    ]

    if not figure.data:
        _add_tuning_traces(neuron_coordinates, min_depth, max_depth)

    range_neuron = get_neurons_within_range(
        neuron_coordinates, x_range=x_range, y_range=y_range, z_range=z_range
    )

    # Explicit comparisons: only rows whose flags equal True/False are classified.
    classic = df["classic_tuned"]
    inverse = df["inverse_tuned"]
    neurons_per_trace = (
        df.loc[(classic == True) & (inverse == False)].neuron.values,  # Classical
        df.loc[(classic == False) & (inverse == True)].neuron.values,  # Inverse
        df.loc[(classic == True) & (inverse == True)].neuron.values,  # Both
    )

    with figure.batch_update():
        # zip stops after the three tuning traces, leaving the depth anchor untouched.
        for trace, neurons in zip(figure.data, neurons_per_trace):
            # Neuron ids from DF are not guaranteed unique. With
            # assume_unique=True a duplicated id would be reported as in range
            # even when it is not, so let intersect1d deduplicate.
            neurons = np.intersect1d(neurons, range_neuron)
            trace.x = neuron_coordinates[neurons, 0]
            trace.y = neuron_coordinates[neurons, 1]
            trace.z = neuron_coordinates[neurons, 2]
            trace.marker.size = marker_size
            trace.marker.opacity = opacity

        figure.update_scenes(zaxis_autorange="reversed")
        figure.update_layout(_scene_layout(mouse_id))


# Wire the controls to their callbacks. The returned interactive objects are
# not kept; the callbacks stay attached to the widgets passed in (the outputs
# below show both callbacks have already run once).
# NOTE: keep this order -- set_mouse_id is wired first so that the sliders are
# refitted to the selected mouse before animate_sample draws with them.
interactive(set_mouse_id, mouse_id=mouse_id_dropdown)
interactive(
    animate_sample,
    mouse_id=mouse_id_dropdown,
    opacity=opacity_slider,
    marker_size=marker_size_slider,
    x_range=x_coordinates_slider,
    y_range=y_coordinates_slider,
    z_range=z_coordinates_slider,
)
reset_button.on_click(reset)

# `display` is provided by IPython in the notebook namespace (not imported above).
display(widgets.HBox([mouse_id_dropdown, reset_button]))
display(widgets.HBox([opacity_slider, marker_size_slider]))
display(
    widgets.HBox([x_coordinates_slider, y_coordinates_slider, z_coordinates_slider])
)
display(figure)

HBox(children=(FloatSlider(value=0.8, description='Marker opacity', layout=Layout(margin='0.1em 1em 0.1em', wi…

HBox(children=(IntRangeSlider(value=(-1058, -443), description='x range (μm)', layout=Layout(margin='0.1em 1em…

FigureWidget({
    'data': [{'marker': {'color': 'black', 'opacity': 0.8, 'size': 5},
              'mode': 'markers',
              'name': 'Classical',
              'type': 'scatter3d',
              'uid': '9ce08170-26a0-4c1f-bc33-0f3bb3c60237',
              'x': array([-589., -541., -535., ..., -546., -572., -667.], dtype=float32),
              'y': array([-677., -635., -649., ..., -337., -558., -598.], dtype=float32),
              'z': array([275., 275., 275., ..., 225., 225., 225.], dtype=float32)},
             {'marker': {'color': 'red', 'opacity': 0.8, 'size': 5},
              'mode': 'markers',
              'name': 'Inverse',
              'type': 'scatter3d',
              'uid': 'e5dba29e-fb81-4ebe-b9d6-aa88ee571f67',
              'x': array([-795., -803., -772., ..., -745., -488., -899.], dtype=float32),
              'y': array([-694., -666., -688., ..., -289., -603., -706.], dtype=float32),
              'z': array([275., 275., 275., ..., 225., 225., 225.], dtype=