In [1]:
import argparse
import warnings
from pathlib import Path
from typing import Dict

import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import seaborn as sns
import torch
from einops import rearrange
from tqdm import tqdm

from viv1t import data
from viv1t.utils import plot

# Figure/animation appearance shared by every cell below.
FONTSIZE = 11  # font size for titles, tick labels and overlay text
FPS = 30  # frame rate used for animation playback and saving
DPI = 180  # resolution for both static figures and animations
COLORMAP = "turbo"  # colormap of the neural-response heatmap

# Input data and output figure locations (relative to the working directory).
DATA_DIR = Path("../data/sensorium")
PLOT_DIR = Path("figures/raw_data")

# Global matplotlib styling: seaborn "deep" style, then the project font, and an
# embed limit (in MB) so large that HTML-embedded animations are never truncated.
plt.style.use("seaborn-v0_8-deep")
plot.set_font()
matplotlib.rcParams["animation.embed_limit"] = 2**64

In [2]:
# Options consumed by viv1t.data when building the dataloaders.
args = argparse.Namespace()
args.data_dir = DATA_DIR  # reuse the configured data root instead of a duplicate literal
args.mouse_ids = list(data.SENSORIUM)
args.ds_mode = 2
# NOTE(review): 0 presumably disables input/output transforms; animate_trial
# divides the video by 255, i.e. it expects raw intensities — confirm.
args.transform_input = 0
args.transform_output = 0
args.batch_size = 1
args.crop_frame = -1
args.limit_data = None
args.device = torch.device("cpu")
args.num_workers = 2
args.verbose = 2

# Only the second value returned by get_training_ds (the validation set) is kept.
_, val_ds, _ = data.get_training_ds(
    args,
    data_dir=args.data_dir,
    mouse_ids=args.mouse_ids,
    batch_size=args.batch_size,
    device=args.device,
)

_, test_ds = data.get_submission_ds(
    args,
    data_dir=args.data_dir,
    mouse_ids=args.mouse_ids,
    batch_size=args.batch_size,
    device=args.device,
)

# Alias of the configured output directory, kept for cells that use `plot_dir`.
plot_dir = PLOT_DIR

In [3]:
def normalize(x: np.ndarray):
    """Min-max scale ``x`` into [0, 1] using its global min and max.

    Args:
        x: array of any shape; the min and max are taken over all elements.

    Returns:
        Array with the same shape as ``x``. A constant input (max == min) maps
        to all zeros instead of dividing by zero and producing NaNs.
    """
    x_min = x.min()
    value_range = x.max() - x_min
    # Guard against constant input: (x - x_min) is all zeros there anyway.
    scale = value_range if value_range != 0 else 1
    return (x - x_min) / scale


def animate_trial(
    sample: Dict[str, np.ndarray], frames: np.ndarray = None, filename: Path = None
) -> animation.FuncAnimation:
    """Animate the visual stimulus of one trial together with a behavior read-out.

    Each animation frame shows one video frame in grayscale, the frame index in
    the top-right corner, and, below the image, the pupil size, locomotion speed
    and pupil center at that frame.

    Args:
        sample: one trial with keys "video", "behavior" and "pupil_center".
            "video" must be 4-D with frames along axis 1; only index 0 of the
            first axis (presumably the channel) is drawn. It is divided by 255,
            so it is assumed to hold raw 0-255 intensities. "behavior" row 0 is
            shown as pupil size and row 1 as speed; "pupil_center" rows 0/1 are
            shown as the (x, y) pupil center. Both are indexed by frame on axis 1.
        frames: frame indices to animate. Defaults to every frame of the video.
        filename: if given (``str`` or ``Path``), the animation is written there
            at ``FPS`` frames per second and ``DPI`` resolution; missing parent
            directories are created.

    Returns:
        The ``FuncAnimation``. The figure is closed before returning, but the
        animation can still be saved or rendered afterwards (e.g. with
        ``IPython.display.HTML(ani.to_jshtml())``), which makes the function
        useful even when ``filename`` is None.
    """
    if frames is None:
        frames = np.arange(sample["video"].shape[1])
    video = sample["video"] / 255.0
    behavior = sample["behavior"]
    pupil_center = sample["pupil_center"]
    _, _, h, w = video.shape

    figure_width, figure_height = 3.5, 2.28
    figure = plt.figure(
        figsize=(figure_width, figure_height), dpi=DPI, facecolor="white"
    )

    def get_height(width_fraction: float) -> float:
        # Axes height (figure fraction) that preserves the video's h/w aspect
        # ratio for an axes of the given width inside this non-square figure.
        return width_fraction * (h / w) * (figure_width / figure_height)

    # Video panel; the strip below y=0.135 is left free for the read-out text.
    width = 0.997
    ax = figure.add_axes(rect=(0.001, 0.135, width, get_height(width)))

    # Shared placement of the three read-outs drawn just below the image.
    text_kwargs = {
        "y": -0.08,
        "va": "center",
        "fontsize": FONTSIZE,
        "transform": ax.transAxes,
        "linespacing": 0.85,
    }

    def update(frame: int):
        """Redraw the video panel and read-outs for one frame index."""
        ax.cla()
        ax.imshow(video[0, frame], cmap="gray", aspect="equal", vmin=0, vmax=1)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.text(
            x=0.98,
            y=0.97,
            s=f"Frame {frame:03d}",
            ha="right",
            va="top",
            color="orangered",
            alpha=0.8,
            fontsize=FONTSIZE,
            transform=ax.transAxes,
        )
        ax.text(
            x=0, s=f"{behavior[0, frame]:.1e}\npupil size", ha="left", **text_kwargs
        )
        ax.text(x=0.4, s=f"{behavior[1, frame]:.1e}\nspeed", ha="center", **text_kwargs)
        ax.text(
            x=1.0,
            s=f"({pupil_center[0, frame]:.0f}, {pupil_center[1, frame]:.0f})\npupil center",
            ha="right",
            **text_kwargs,
        )
        for spine in ax.spines.values():
            spine.set_linewidth(2)

    ani = animation.FuncAnimation(figure, update, frames=frames, interval=1000 / FPS)
    if filename is not None:
        filename = Path(filename)  # accept plain strings as well as Paths
        filename.parent.mkdir(parents=True, exist_ok=True)
        ani.save(filename, fps=FPS, dpi=DPI, savefig_kwargs={"pad_inches": 0.1})
    # Close so the notebook does not also display the static figure.
    plt.close(figure)
    return ani

In [4]:
# Frame window shown in every animation.
FRAMES = np.arange(50, 300)
# Validation trial indices to animate, per mouse.
trials = {"F": [0, 1, 2], "G": [1, 4], "J": [40]}
for mouse_id, indexes in trials.items():
    mouse_dir = plot_dir / f"mouse{mouse_id}" / "validation"
    for index in tqdm(indexes, desc=f"mouse {mouse_id}"):
        # __getitem__ is called explicitly because [] cannot pass to_tensor.
        sample = val_ds[mouse_id].dataset.__getitem__(index, to_tensor=False)
        output_path = mouse_dir / f"mouse{mouse_id}_input{index:03d}.mp4"
        animate_trial(sample, frames=FRAMES, filename=output_path)

mouse F:   0%|          | 0/3 [00:00<?, ?it/s]


TypeError: clamp() received an invalid combination of arguments - got (numpy.ndarray, min=int), but expected one of:
 * (Tensor input, Tensor min = None, Tensor max = None, *, Tensor out = None)
 * (Tensor input, Number min = None, Number max = None, *, Tensor out = None)


In [5]:
def standardize(response: np.ndarray):
    """Z-score every row of ``response`` independently along axis 1.

    Each row (e.g. one neuron's trace over time) is shifted to zero mean and
    divided by its standard deviation. float32 machine epsilon is added to the
    denominator so constant rows become zeros instead of NaN.
    """
    eps = np.finfo(np.float32).eps
    row_mean = np.mean(response, axis=1, keepdims=True)
    row_std = np.std(response, axis=1, keepdims=True)
    centered = response - row_mean
    return centered / (row_std + eps)


def plot_trial(
    sample: Dict[str, np.ndarray],
    mouse_id: str,
    random_neuron: bool = False,
    filename: Path = None,
):
    """Plot the stimulus, behavior and population response of one trial.

    Layout: the last video frame (top left), four behavior traces below it
    (pupil size, locomotion speed, pupil center x and y) and, on the right, a
    heatmap of the row-standardized response with a small colorbar.

    Args:
        sample: one trial with keys "video" (4-D, frames on axis 1, height and
            width on axes 2/3; index 0 of axis 0 is shown), "behavior" (row 0
            pupil size, row 1 speed), "pupil_center" (row 0 x, row 1 y) and
            "response" (neurons on axis 0, frames on axis 1).
        mouse_id: key into ``data.MOUSE_IDS``; used to load that mouse's
            ``cell_motor_coordinates.npy`` under ``DATA_DIR``. Its rows are
            assumed to align with the response rows.
        random_neuron: if True, shuffle neuron rows with the (unseeded) global
            ``np.random`` state; otherwise sort rows by depth (column 2 of the
            coordinates) and draw a dashed line at the first row of each depth.
        filename: if given (``str`` or ``Path``), save the figure there via
            ``plot.save_figure`` (the figure is left open).

    Returns:
        The matplotlib figure; the caller is responsible for closing it.
    """
    figure = plt.figure(figsize=(8, 6), dpi=DPI, facecolor="white")

    # Left column: video frame on top, then four equally sized behavior panels.
    width = 0.25
    ax1 = figure.add_axes(rect=[0.01, 0.6, width, 0.3])  # video
    pos1 = ax1.get_position()
    top, height, gap = pos1.y0 - 0.08, 0.08, 0.06
    ax2 = figure.add_axes(rect=[0.01, top, width, height])  # pupil dilation
    ax3 = figure.add_axes(rect=[0.01, top - (height + gap), width, height])  # speed
    ax4 = figure.add_axes(
        rect=[0.01, top - 2 * (height + gap), width, height]
    )  # pupil center x
    ax5 = figure.add_axes(
        rect=[0.01, top - 3 * (height + gap), width, height]
    )  # pupil center y

    # Right column: response heatmap, bottom-aligned with the last behavior panel.
    pos5 = ax5.get_position()
    width, height = 0.3, 3 * (height + gap) + 0.3
    ax6 = figure.add_axes(rect=[pos1.x1 + 0.1, pos5.y0, width, height])  # response

    h, w = sample["video"].shape[2], sample["video"].shape[3]
    t = sample["video"].shape[1]
    n = sample["response"].shape[0]
    # Integer tick positions so the integer labels sit exactly on their ticks.
    movie_xticks = np.linspace(0, w - 1, 2, dtype=int)
    movie_yticks = np.linspace(0, h - 1, 2, dtype=int)
    response_xticks = np.linspace(0, t - 1, 4, dtype=int)
    response_yticks = np.linspace(0, n - 1, 10, dtype=int)

    # Video panel: last frame of the clip.
    ax1.imshow(sample["video"][0, t - 1, :, :], cmap="gray", aspect="equal")
    ax1.grid(linewidth=0)
    plot.set_xticks(
        ax1,
        ticks=movie_xticks,
        tick_labels=movie_xticks,
        tick_fontsize=FONTSIZE,
    )
    plot.set_yticks(
        ax1,
        ticks=movie_yticks,
        tick_labels=movie_yticks,
        tick_fontsize=FONTSIZE,
    )

    # Behavior panels: (axes, trace, title, y tick-label format).
    behavior_panels = [
        (ax2, sample["behavior"][0], "Pupil size", ".01e"),
        (ax3, sample["behavior"][1], "Locomotion speed", ".0e"),
        (ax4, sample["pupil_center"][0], "Pupil center (x)", ".01e"),
        (ax5, sample["pupil_center"][1], "Pupil center (y)", ".01e"),
    ]
    for ax, trace, title, label_format in behavior_panels:
        ax.plot(trace, color="black", linewidth=1.5)
        # Only the (NaN-ignoring) min and max of each trace are labelled.
        yticks = np.linspace(np.nanmin(trace), np.nanmax(trace), 2)
        plot.set_yticks(
            ax,
            ticks=yticks,
            tick_labels=[f"{value:{label_format}}" for value in yticks],
            tick_fontsize=FONTSIZE,
        )
        ax.set_title(title, pad=1, fontsize=FONTSIZE)
        plot.set_xticks(
            ax,
            ticks=response_xticks,
            tick_labels=response_xticks,
            label="Movie frame" if ax is ax5 else None,
            tick_fontsize=FONTSIZE,
            label_fontsize=FONTSIZE,
        )
        ax.grid(visible=False, which="major")
        sns.despine(ax=ax, trim=True)

    # Response heatmap.
    response = standardize(sample["response"])

    neuron_coordinates = np.load(
        DATA_DIR
        / data.MOUSE_IDS[mouse_id]
        / "meta"
        / "neurons"
        / "cell_motor_coordinates.npy"
    )
    z_values = neuron_coordinates[:, 2]
    depths, counts = np.unique(z_values, return_counts=True)  # depths ascending

    if random_neuron:
        response = response[np.random.permutation(response.shape[0]), :]
    else:
        print(f"Mouse {mouse_id} unique z values: {dict(zip(depths, counts))}")
        # Stable sort keeps the original neuron order within each depth.
        response = response[np.argsort(z_values, kind="stable"), :]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        # NOTE(review): interpolation=None means "use the rcParams default",
        # not "none"; confirm which was intended.
        image = ax6.imshow(response, cmap=COLORMAP, aspect="auto", interpolation=None)
    ax6.grid(linewidth=0)
    plot.set_xticks(
        ax6,
        ticks=response_xticks,
        tick_labels=response_xticks,
        label="Movie frame",
        tick_fontsize=FONTSIZE,
        label_fontsize=FONTSIZE,
    )
    plot.set_yticks(
        ax6,
        ticks=response_yticks,
        tick_labels=response_yticks,
        tick_fontsize=FONTSIZE,
    )
    ax6.set_ylabel(
        "Neuron (random)" if random_neuron else "Neuron", fontsize=FONTSIZE, labelpad=0
    )
    ax6.set_title("Response", pad=2, fontsize=FONTSIZE)

    if not random_neuron:
        # After the ascending sort, each depth is a contiguous block of rows whose
        # first row is the total count of all shallower depths. (Indices into the
        # unsorted z_values would not match the sorted heatmap rows.)
        starts = np.cumsum(counts) - counts
        for depth, start, count in zip(depths, starts, counts):
            stop = start + count - 1
            print(f"depth: {depth:.01f} start {start} stop {stop}")
            ax6.axhline(
                y=start, color="orangered", linewidth=1, alpha=0.6, linestyle="--"
            )

    # Small colorbar vertically centred to the right of the heatmap.
    pos6 = ax6.get_position()
    width, height = 0.008, (pos6.y1 - pos6.y0) * 0.15
    cbar_ax6 = figure.add_axes(
        rect=[
            pos6.x1 + 0.01,
            ((pos6.y1 - pos6.y0) / 2 + pos6.y0) - (height / 2),
            width,
            height,
        ]
    )
    # Drive the colorbar from the heatmap itself so both share one norm
    # (imshow's autoscaling also ignores NaN values).
    figure.colorbar(image, cax=cbar_ax6)
    vmin, vmax = image.get_clim()
    cbar_ticks = np.linspace(vmin, vmax, 3)
    plot.set_yticks(
        axis=cbar_ax6,
        ticks=cbar_ticks,
        tick_labels=[f"{value:.1f}" for value in cbar_ticks],
        tick_fontsize=FONTSIZE,
    )

    for ax in [ax1, ax2, ax3, ax4, ax5, ax6, cbar_ax6]:
        plot.set_ticks_params(ax, length=2)

    if filename is not None:
        plot.save_figure(figure, filename=Path(filename), dpi=DPI, close=False)
    return figure

In [6]:
mouse_id, trial_id = "F", 0
figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)
# Pass random_neuron=True (and a "_random" filename suffix) to plot the same
# trial with neurons in random order instead of sorted by depth.

TypeError: clamp() received an invalid combination of arguments - got (numpy.ndarray, min=int), but expected one of:
 * (Tensor input, Tensor min = None, Tensor max = None, *, Tensor out = None)
 * (Tensor input, Number min = None, Number max = None, *, Tensor out = None)


In [None]:
mouse_id, trial_id = "F", 10
figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)
# Pass random_neuron=True (and a "_random" filename suffix) to plot the same
# trial with neurons in random order instead of sorted by depth.

In [None]:
mouse_id, trial_id = "G", 10
figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)
# Pass random_neuron=True (and a "_random" filename suffix) to plot the same
# trial with neurons in random order instead of sorted by depth.

In [None]:
mouse_id, trial_id = "H", 10
figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)
# Pass random_neuron=True (and a "_random" filename suffix) to plot the same
# trial with neurons in random order instead of sorted by depth.

In [None]:
mouse_id, trial_id = "I", 10
figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)
# Pass random_neuron=True (and a "_random" filename suffix) to plot the same
# trial with neurons in random order instead of sorted by depth.

In [None]:
mouse_id, trial_id = "J", 10
figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)
# Pass random_neuron=True (and a "_random" filename suffix) to plot the same
# trial with neurons in random order instead of sorted by depth.

In [None]:
mouse_id, trial_id = "A", 10
figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)
# Pass random_neuron=True (and a "_random" filename suffix) to plot the same
# trial with neurons in random order instead of sorted by depth.

In [None]:
mouse_id, trial_id = "B", 10
figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)
# Pass random_neuron=True (and a "_random" filename suffix) to plot the same
# trial with neurons in random order instead of sorted by depth.

In [None]:
mouse_id, trial_id = "C", 10
figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)
# Pass random_neuron=True (and a "_random" filename suffix) to plot the same
# trial with neurons in random order instead of sorted by depth.

In [None]:
mouse_id, trial_id = "C", 11

figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)

In [None]:
mouse_id, trial_id = "C", 12
figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)

In [None]:
mouse_id, trial_id = "C", 14
figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)

In [None]:
mouse_id, trial_id = "C", 15

figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)

In [None]:
mouse_id, trial_id = "C", 16
figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)

In [None]:
mouse_id, trial_id = "D", 10
figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)
# Pass random_neuron=True (and a "_random" filename suffix) to plot the same
# trial with neurons in random order instead of sorted by depth.

In [None]:
mouse_id, trial_id = "E", 10
figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)
# Pass random_neuron=True (and a "_random" filename suffix) to plot the same
# trial with neurons in random order instead of sorted by depth.

In [None]:
mouse_id, trial_id = "E", 50
figure = plot_trial(
    val_ds[mouse_id].dataset.__getitem__(trial_id, to_tensor=False),
    mouse_id=mouse_id,
    random_neuron=False,
    filename=PLOT_DIR
    / "response_patterns"
    / f"mouse{mouse_id}"
    / f"trial{trial_id}.png",
)
plt.show()
plt.close(figure)
# Pass random_neuron=True (and a "_random" filename suffix) to plot the same
# trial with neurons in random order instead of sorted by depth.