In [None]:
import argparse
import os
import pickle
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from einops import rearrange
from einops import repeat
from matplotlib import cm
from matplotlib.animation import ArtistAnimation
from torch.utils.data import DataLoader
from tqdm import tqdm

from viv1t.utils import plot
from viv1t.utils import utils

In [None]:
# Raise matplotlib's size cap for embedded animations.
matplotlib.rcParams["animation.embed_limit"] = 2**64

plot.set_font()
utils.set_random_seed(1234)

# --- figure typography and rendering ---
COLORMAP = "turbo"
TICK_FONTSIZE = 8
LABEL_FONTSIZE = 9
TITLE_FONTSIZE = 10
DPI = 180
FPS = 30

# --- overlay and frame bookkeeping ---
ALPHA = 0.4  # heatmap overlay alpha value
SKIP = 50  # number of frames to skip for metric calculation
MAX_FRAMES = 300

# Colormap objects plus their 256-entry RGB lookup tables (alpha dropped),
# used to colour attention values by integer index.
TURBO = matplotlib.colormaps["turbo"]
TURBO_COLOR = TURBO(np.arange(256))[..., :3]

HOT = matplotlib.colormaps["hot"]
HOT_COLOR = HOT(np.arange(256))[..., :3]

In [None]:
# Load the attention rollout results of the "best_vivit" run. The cells below
# index it as results[mouse_id][key][trial].
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load files produced by a trusted run.
with open("../runs/best_vivit/attention_rollout.pkl", "rb") as file:
    results = pickle.load(file)

In [None]:
def animate_spatial_attention(
    sample: Dict[str, np.ndarray], filename: Optional[str] = None
):
    """Animate the spatial attention map overlaid on every frame of a video.

    One animation frame is produced per video frame: the attention map is
    coloured with the turbo lookup table and alpha-blended (weight ALPHA) onto
    the frame. A colorbar spanning 0 to 1 is drawn to the right of the image.

    Args:
        sample: dictionary that provides
            - "video": array whose shape unpacks to four dimensions, read here
              as (channel, time, height, width); only channel 0 is drawn.
              Presumably pixel values are in [0, 255] since they are blended
              with 0-255 colours and cast to uint8 - TODO confirm.
            - "spatial_attention": array indexed by frame; assumed to hold
              values in [0, 1] because they are scaled by 255 to index the
              256-entry colour table - TODO confirm.
        filename: destination of the animation. Its parent directory is
            created if needed. Nothing is written when None.
    """
    figure_width, figure_height = 4, 2.3
    figure = plt.figure(
        figsize=(figure_width, figure_height), dpi=DPI, facecolor="white"
    )
    _, t, h, w = sample["video"].shape

    # Axes height (as a fraction of the figure) that preserves the h / w
    # aspect ratio of the video for the chosen axes width.
    width = 0.95
    height = width * (h / w) * (figure_width / figure_height)
    ax = figure.add_axes(rect=(0.002, 0.006, width, height))
    pos = ax.get_position()

    # colorbar axes, vertically centred next to the image axes
    cbar_width, cbar_height = 0.01, 0.1
    cbar_ax = figure.add_axes(
        rect=(
            pos.x1 + 0.01,
            pos.y0 + ((pos.y1 - pos.y0) / 2) - (cbar_height / 2),
            cbar_width,
            cbar_height,
        )
    )

    # One list of artists per animation frame; ArtistAnimation shows exactly
    # the artists of the current frame and hides those of all other frames.
    snapshots = []
    for i in range(t):
        image = sample["video"][0, i]
        attention = sample["spatial_attention"][i]
        overlay = TURBO_COLOR[np.uint8(255.0 * attention)] * 255.0
        overlay = ALPHA * overlay + (1 - ALPHA) * image[..., None]
        # No cmap is passed: the overlay is already RGB, for which imshow
        # ignores the colormap.
        snapshots.append([ax.imshow(overlay.astype(np.uint8), interpolation=None)])

    # Frame-independent decorations are drawn once instead of on every frame.
    ax.set_title('spatial "attention"', pad=3, fontsize=LABEL_FONTSIZE)
    ax.grid(linewidth=0)
    ax.set_xticks([])
    ax.set_yticks([])

    cbar = figure.colorbar(cm.ScalarMappable(cmap=COLORMAP), cax=cbar_ax)
    cbar.mappable.set_clim(0, 1)
    cbar_yticks = np.linspace(0, 1, 2, dtype=int)
    plot.set_yticks(
        axis=cbar_ax,
        ticks=cbar_yticks,
        tick_labels=cbar_yticks,
        tick_fontsize=TICK_FONTSIZE,
    )
    plot.set_ticks_params(cbar_ax, length=1.5, pad=1)

    animation = ArtistAnimation(figure, snapshots)
    if filename is not None:
        # dirname is "" for a bare file name, which os.makedirs rejects.
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)
        animation.save(filename, fps=FPS, dpi=DPI, savefig_kwargs={"pad_inches": 0})
    plt.close(figure)

In [None]:
def animate_temporal_attention(
    sample: Dict[str, np.ndarray], filename: Optional[str] = None
):
    """Animate the temporal attention as a scatter plot that grows over time.

    Animation frame i shows the attention values of movie frames 0..i
    (inclusive), so the last animation frame contains every point. The x-axis
    is labelled with movie frame numbers ending at MAX_FRAMES.

    Args:
        sample: dictionary that provides
            - "video": array whose shape unpacks to four dimensions; only the
              second one (number of frames, t) is used.
            - "temporal_attention": one value per frame; plotted on a y-axis
              fixed to [0, 1] and coloured with the turbo colormap.
              Presumably has length t - TODO confirm.
        filename: destination of the animation. Its parent directory is
            created if needed. Nothing is written when None.
    """
    figure_width, figure_height = 4, 0.6
    figure = plt.figure(
        figsize=(figure_width, figure_height), dpi=DPI, facecolor="white"
    )
    _, t, _, _ = sample["video"].shape
    # movie frame numbers of the t frames, ending at MAX_FRAMES - 1
    frames = np.arange(start=MAX_FRAMES - t, stop=MAX_FRAMES, step=1)
    frame_xticks = np.linspace(frames[0], frames[-1], 4)
    temporal_attention_color = TURBO(sample["temporal_attention"])

    # temporal attention
    width = 0.94
    ax = figure.add_axes(rect=(0.035, 0.39, width, 0.4))

    # One list of artists per animation frame; ArtistAnimation shows exactly
    # the artists of the current frame and hides those of all other frames.
    snapshots = []
    for i in range(t):
        # include frame i itself so that the final point is also drawn
        stop = i + 1
        points = ax.scatter(
            frames[:stop],
            sample["temporal_attention"][:stop],
            s=8,
            label="temporal attention",
            clip_on=False,
            edgecolor="none",
            c=temporal_attention_color[:stop],
        )
        snapshots.append([points])

    # Frame-independent axis styling is applied once, after plotting, so the
    # limits are not affected by autoscaling.
    ax.set_xlim(frames[0], frames[-1])
    ax.set_ylim(0, 1)
    plot.set_yticks(ax, ticks=[0, 1], tick_labels=[0, 1], tick_fontsize=TICK_FONTSIZE)
    ax.set_title('temporal "attention"', pad=2, fontsize=LABEL_FONTSIZE)
    plot.set_xticks(
        ax,
        ticks=frame_xticks,
        tick_labels=frame_xticks.astype(int),
        label="movie frame",
        tick_fontsize=TICK_FONTSIZE,
        label_fontsize=LABEL_FONTSIZE,
        label_pad=-2,
    )
    ax.grid(visible=False, which="major")
    sns.despine(ax=ax)
    plot.set_ticks_params(ax, length=2)

    animation = ArtistAnimation(figure, snapshots)
    if filename is not None:
        # dirname is "" for a bare file name, which os.makedirs rejects.
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)
        animation.save(filename, fps=FPS, dpi=DPI, savefig_kwargs={"pad_inches": 0})
    plt.close(figure)

In [None]:
# Mouse G, trial 1: collect the trial's entries from the rollout results and
# render both attention animations.
mouse_id, trial = "G", 1
sample = {
    sample_key: results[mouse_id][results_key][trial]
    for sample_key, results_key in (
        ("video", "videos"),
        ("behavior", "behaviors"),
        ("pupil_center", "pupil_centers"),
        ("spatial_attention", "spatial_attentions"),
        ("temporal_attention", "temporal_attentions"),
        ("correlation", "correlation"),
    )
}
animate_spatial_attention(
    sample=sample,
    filename=f"figures/NeurIPS2023/attention/mouse{mouse_id}_spatial{trial:03d}.gif",
)
animate_temporal_attention(
    sample=sample,
    filename=f"figures/NeurIPS2023/attention/mouse{mouse_id}_temporal{trial:03d}.gif",
)

In [None]:
# Mouse G, trial 4: collect the trial's entries from the rollout results and
# render both attention animations.
mouse_id, trial = "G", 4
sample = {
    sample_key: results[mouse_id][results_key][trial]
    for sample_key, results_key in (
        ("video", "videos"),
        ("behavior", "behaviors"),
        ("pupil_center", "pupil_centers"),
        ("spatial_attention", "spatial_attentions"),
        ("temporal_attention", "temporal_attentions"),
        ("correlation", "correlation"),
    )
}
animate_spatial_attention(
    sample=sample,
    filename=f"figures/NeurIPS2023/attention/mouse{mouse_id}_spatial{trial:03d}.gif",
)
animate_temporal_attention(
    sample=sample,
    filename=f"figures/NeurIPS2023/attention/mouse{mouse_id}_temporal{trial:03d}.gif",
)

In [None]:
# Mouse J, trial 33: collect the trial's entries from the rollout results and
# render both attention animations.
mouse_id, trial = "J", 33
sample = {
    sample_key: results[mouse_id][results_key][trial]
    for sample_key, results_key in (
        ("video", "videos"),
        ("behavior", "behaviors"),
        ("pupil_center", "pupil_centers"),
        ("spatial_attention", "spatial_attentions"),
        ("temporal_attention", "temporal_attentions"),
        ("correlation", "correlation"),
    )
}
animate_spatial_attention(
    sample=sample,
    filename=f"figures/NeurIPS2023/attention/mouse{mouse_id}_spatial{trial:03d}.gif",
)
animate_temporal_attention(
    sample=sample,
    filename=f"figures/NeurIPS2023/attention/mouse{mouse_id}_temporal{trial:03d}.gif",
)

Attention plots for the poster: static spatial-attention frames and a temporal-attention overview.

In [None]:
# Load the attention rollout results used for the poster figures. This reads
# the same file as the `results` load near the top of the notebook.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load files produced by a trusted run.
with open("../runs/best_vivit/attention_rollout.pkl", "rb") as file:
    vivit_attention = pickle.load(file)

In [None]:
# Larger fonts for the poster figures. NOTE: this rebinds the notebook-level
# constants set in the configuration cell (8 and 9), so any cell executed after
# this one, including the animation functions, picks up these values.
TICK_FONTSIZE, LABEL_FONTSIZE = 10, 11


def plot_spatial_attention_frame(
    sample: Dict[str, np.ndarray], frame: int, filename: str, dpi: int = 1200
):
    """Save one video frame with its spatial attention map overlaid.

    The attention map is coloured with the turbo lookup table and
    alpha-blended (weight ALPHA) onto the frame. The movie frame number is
    written in the top-right corner of the image.

    Args:
        sample: dictionary that provides
            - "video": array whose shape unpacks to four dimensions, read here
              as (channel, time, height, width); only channel 0 is drawn.
              Presumably pixel values are in [0, 255] since they are blended
              with 0-255 colours and cast to uint8 - TODO confirm.
            - "spatial_attention": array indexed by frame; assumed to hold
              values in [0, 1] because they are scaled by 255 to index the
              256-entry colour table - TODO confirm.
        frame: index into the time dimension of the sample.
        filename: destination of the figure. Its parent directory is created
            if needed.
        dpi: resolution used to create and save the figure.
    """
    figure_width, figure_height = 3, 2
    figure = plt.figure(
        figsize=(figure_width, figure_height), dpi=dpi, facecolor="white"
    )
    _, t, h, w = sample["video"].shape

    # Axes height (as a fraction of the figure) that preserves the h / w
    # aspect ratio of the video for the chosen axes width.
    width = 0.95
    height = width * (h / w) * (figure_width / figure_height)
    ax = figure.add_axes(rect=(0.002, 0.006, width, height))

    # plot spatial attention map overlay on frame
    image = sample["video"][0, frame]
    attention = sample["spatial_attention"][frame]
    overlay = TURBO_COLOR[np.uint8(255.0 * attention)] * 255.0
    overlay = ALPHA * overlay + (1 - ALPHA) * image[..., None]
    # No cmap is passed: the overlay is already RGB, for which imshow ignores
    # the colormap.
    ax.imshow(overlay.astype(np.uint8), interpolation=None)
    ax.set_title('spatial "attention"', pad=4, fontsize=LABEL_FONTSIZE)
    ax.grid(linewidth=0)
    ax.set_xticks([])
    ax.set_yticks([])

    # The t frames of the sample are the last t of MAX_FRAMES movie frames
    # (same convention as the x-axis of animate_temporal_attention), so the
    # label offset follows from the sample instead of a hard-coded 50.
    movie_frame = frame + (MAX_FRAMES - t)
    ax.text(
        x=0.98,
        y=0.97,
        s=f"Frame {movie_frame:03d}",
        ha="right",
        va="top",
        color="orangered",
        alpha=0.8,
        fontsize=LABEL_FONTSIZE,
        transform=ax.transAxes,
    )

    # dirname is "" for a bare file name, which os.makedirs rejects.
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)
    plot.save_figure(figure, filename=filename, dpi=dpi)
    plt.close(figure)


def plot_temporal_attention(
    sample: Dict[str, np.ndarray],
    selected: np.ndarray,
    filename: Optional[str] = None,
    dpi: int = 240,
):
    """Plot the temporal attention of every frame as a coloured scatter plot.

    The x-axis shows movie frame numbers. Besides three evenly spaced ticks,
    a tick is added for each frame in `selected`.

    Args:
        sample: dictionary that provides
            - "video": array whose shape unpacks to four dimensions; only the
              second one (number of frames, t) is used.
            - "temporal_attention": one value per frame; plotted on a y-axis
              fixed to [0, 1] and coloured with the turbo colormap. Must have
              t entries to match the x positions.
        selected: frame indices (relative to the sample) to mark with
            additional x-ticks.
        filename: destination of the figure. Its parent directory is created
            if needed. Nothing is written when None.
        dpi: resolution used to create and save the figure.
    """
    figure_width, figure_height = 7, 0.8
    # Use the dpi argument here as well; it was previously only applied when
    # saving, while the figure itself was created with the global DPI.
    figure = plt.figure(
        figsize=(figure_width, figure_height), dpi=dpi, facecolor="white"
    )
    _, t, _, _ = sample["video"].shape

    temporal_attention_color = TURBO(sample["temporal_attention"])

    # temporal attention
    width = 0.94
    ax = figure.add_axes(rect=(0.035, 0.39, width, 0.4))

    # Movie frame numbers of the t frames, ending at MAX_FRAMES - 1 (same
    # convention as animate_temporal_attention) rather than a fixed 50..299.
    frames = np.arange(MAX_FRAMES - t, MAX_FRAMES, dtype=int)

    # plot temporal attention
    ax.scatter(
        frames,
        sample["temporal_attention"],
        s=10,
        label="temporal attention",
        clip_on=False,
        edgecolor="none",
        c=temporal_attention_color,
    )
    ax.text(
        x=1.0,
        y=0.98,
        s='temporal "attention"',
        ha="right",
        va="top",
        color="black",
        alpha=0.8,
        fontsize=LABEL_FONTSIZE,
        transform=ax.transAxes,
    )
    ax.set_ylim(0, 1)
    plot.set_yticks(ax, ticks=[0, 1], tick_labels=[0, 1], tick_fontsize=TICK_FONTSIZE)
    ax.set_xlim(frames[0], frames[-1])
    frame_xticks = np.linspace(frames[0], frames[-1], 3, dtype=int)
    # shift the sample-relative indices to movie frame numbers
    frame_xticks = np.append(frame_xticks, np.asarray(selected) + frames[0])
    plot.set_xticks(
        ax,
        ticks=frame_xticks,
        tick_labels=frame_xticks,
        label="movie frame",
        tick_fontsize=TICK_FONTSIZE,
        label_fontsize=LABEL_FONTSIZE,
        label_pad=0,
    )
    ax.grid(visible=False, which="major")
    sns.despine(ax=ax)
    plot.set_ticks_params(ax, length=2)

    if filename is not None:
        # dirname is "" for a bare file name, which os.makedirs rejects.
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)
        plot.save_figure(figure, filename=filename, dpi=dpi)
    plt.close(figure)

In [None]:
# Poster figures for mouse G, trial 1: one spatial-attention image per
# selected frame, plus a temporal-attention plot with those frames marked.
mouse_id, trial = "G", 1
sample = {
    sample_key: vivit_attention[mouse_id][results_key][trial]
    for sample_key, results_key in (
        ("video", "videos"),
        ("behavior", "behaviors"),
        ("pupil_center", "pupil_centers"),
        ("spatial_attention", "spatial_attentions"),
        ("temporal_attention", "temporal_attentions"),
        ("correlation", "correlation"),
    )
}
# frame indices relative to the sample
frames = np.array([38, 48, 58, 165, 175, 185])
for frame in frames:
    plot_spatial_attention_frame(
        sample=sample,
        frame=frame,
        filename=f"figures/NeurIPS2023/poster/mouse{mouse_id}_spatial{trial:03d}_frame{frame:03d}.png",
    )
plot_temporal_attention(
    sample=sample,
    selected=frames,
    filename=f"figures/NeurIPS2023/poster/mouse{mouse_id}_temporal{trial:03d}.png",
)