In [21]:
import os
import pickle
import numpy as np
import typing as t
from scipy.stats import pearsonr
from scipy.ndimage import center_of_mass

from v1t import data
from v1t.utils import utils, tensorboard

# Global seed passed to the project's RNG helper so stochastic steps are reproducible.
SEED = 1234
utils.set_random_seed(SEED)

# Project plotting helper; presumably configures fonts for the figures below — confirm.
tensorboard.set_font()

### Compute correlations between attention maps and behavior variables

In [22]:
def computer_centers(heatmaps: np.ndarray) -> np.ndarray:
    """Compute each heatmap's center of mass as an offset from the map midpoint.

    Args:
        heatmaps: stack of 2D maps with shape (num_maps, height, width); the
            last two axes are (row, column), i.e. (y, x).

    Returns:
        Array of shape (num_maps, 2) holding the (x, y) offset of each map's
        center of mass from the midpoint (width / 2, height / 2).

    Raises:
        ValueError: if `heatmaps` is not a 3D stack of maps.
    """
    heatmaps = np.asarray(heatmaps)
    if heatmaps.ndim != 3:
        raise ValueError(
            f"heatmaps must have shape (num_maps, height, width), got {heatmaps.shape}"
        )
    height, width = heatmaps.shape[1:]
    centers = np.zeros((len(heatmaps), 2), dtype=np.float32)
    for i, heatmap in enumerate(heatmaps):
        # center_of_mass returns coordinates in axis order: (row, col) == (y, x)
        y, x = center_of_mass(heatmap)
        centers[i, 0], centers[i, 1] = x, y
    # The midpoint used to be hard-coded as (64 / 2, 36 / 2); deriving it from
    # the map size keeps 36x64 results identical and centers other sizes correctly.
    mid_point = np.array([width / 2, height / 2])
    return centers - mid_point


def abs_correlation(x: np.ndarray, y: np.ndarray) -> t.Tuple[float, str]:
    """Magnitude of the Pearson correlation between x and y plus a significance label.

    Returns:
        (|r|, label), where label is "****" for p <= 1e-4, "***" for p <= 1e-3,
        "**" for p <= 1e-2, "*" for p <= 0.05 and "n.s." otherwise (a NaN
        p-value never satisfies a cutoff, so it is also labelled "n.s.").
    """
    r, p = pearsonr(x, y)
    # Checked from strictest to loosest; the first satisfied cutoff wins.
    significance_levels = (
        (0.0001, "****"),
        (0.001, "***"),
        (0.01, "**"),
        (0.05, "*"),
    )
    label = next(
        (stars for cutoff, stars in significance_levels if p <= cutoff),
        "n.s.",
    )
    return np.abs(r), label

In [23]:
# Attention-rollout results of the best V1T run; results["test"] maps each
# mouse ID to a dict with "heatmaps", "pupil_centers" and "behaviors".
RESULTS_PATH = "../runs/best_v1t_sensorium/attention_rollout_maps.pkl"

# NOTE: pickle.load can execute arbitrary code - only open result files
# produced by this project's own runs.
with open(RESULTS_PATH, "rb") as file:
    results = pickle.load(file)


def _mean_pm_std(values: t.List[float]) -> str:
    """Format values as 'mean \\pm std' (population std), 3 decimals each."""
    # "\\pm" emits a literal "\pm" without relying on an invalid escape
    # sequence (DeprecationWarning, SyntaxWarning on Python 3.12+).
    return f"{np.mean(values):.03f} \\pm {np.std(values):.03f}"


center_corrs, dilation_corrs = {"x": [], "y": []}, {"x": [], "y": []}
for mouse_id, mouse_dict in results["test"].items():
    print(f"Mouse {mouse_id}")
    heatmaps = mouse_dict["heatmaps"]

    # absolute correlation between attention-map center of mass and pupil center
    mass_centers = computer_centers(heatmaps)
    pupil_centers = mouse_dict["pupil_centers"]
    corr_x, p_x = abs_correlation(mass_centers[:, 0], pupil_centers[:, 0])
    corr_y, p_y = abs_correlation(mass_centers[:, 1], pupil_centers[:, 1])
    center_corrs["x"].append(corr_x)
    center_corrs["y"].append(corr_y)
    print(
        f"\tAbs. Corr(center of mass, pupil center)\n"
        f"\tx-axis: {corr_x:.03f} ({p_x})\n\ty-axis: {corr_y:.03f} ({p_y})"
    )

    # Std of each map's marginal profile along x (axis 1 summed) and y (axis 2
    # summed), assuming heatmaps are (num_maps, height, width) as in computer_centers.
    # NOTE(review): this is the std of the marginal *values*, not the spatial
    # spread (second moment about the center of mass) - confirm this is the
    # intended measure of attention-map width.
    spread_x = np.std(np.sum(heatmaps, axis=1), axis=1)
    spread_y = np.std(np.sum(heatmaps, axis=2), axis=1)
    dilation = mouse_dict["behaviors"][:, 0]  # column 0 is used as pupil dilation
    # absolute correlation between pupil dilation and attention-map std
    corr_x, p_x = abs_correlation(spread_x, dilation)
    corr_y, p_y = abs_correlation(spread_y, dilation)
    dilation_corrs["x"].append(corr_x)
    dilation_corrs["y"].append(corr_y)
    print(
        f"\tAbs. Corr(attention map std, pupil dilation)\n"
        f"\tx-axis: {corr_x:.03f} ({p_x})\n\ty-axis: {corr_y:.03f} ({p_y})"
    )

# Summaries across mice; previously dilation_corrs was collected but never reported.
print(
    f"\nAvg. Corr(center of mass, pupil center)\n"
    f"\tx-axis: {_mean_pm_std(center_corrs['x'])}\n"
    f"\ty-axis: {_mean_pm_std(center_corrs['y'])}\n"
)
print(
    f"Avg. Corr(attention map std, pupil dilation)\n"
    f"\tx-axis: {_mean_pm_std(dilation_corrs['x'])}\n"
    f"\ty-axis: {_mean_pm_std(dilation_corrs['y'])}\n"
)

Mouse A
	Abs. Corr(center of mass, pupil center)
	x-axis: 0.669 (****)
	y-axis: 0.602 (****)
	Abs. Corr(attention map std, pupil dilation)
	x-axis: 0.241 (****)
	y-axis: 0.254 (****)
Mouse B
	Abs. Corr(center of mass, pupil center)
	x-axis: 0.463 (****)
	y-axis: 0.487 (****)
	Abs. Corr(attention map std, pupil dilation)
	x-axis: 0.265 (****)
	y-axis: 0.222 (****)
Mouse C
	Abs. Corr(center of mass, pupil center)
	x-axis: 0.382 (****)
	y-axis: 0.458 (****)
	Abs. Corr(attention map std, pupil dilation)
	x-axis: 0.304 (****)
	y-axis: 0.180 (****)
Mouse D
	Abs. Corr(center of mass, pupil center)
	x-axis: 0.377 (****)
	y-axis: 0.330 (****)
	Abs. Corr(attention map std, pupil dilation)
	x-axis: 0.025 (n.s.)
	y-axis: 0.151 (***)
Mouse E
	Abs. Corr(center of mass, pupil center)
	x-axis: 0.425 (****)
	y-axis: 0.292 (****)
	Abs. Corr(attention map std, pupil dilation)
	x-axis: 0.259 (****)
	y-axis: 0.242 (****)

Avg. Corr(center of mass, pupil center)
	x-axis: 0.463 \pm 0.108
	y-axis: 0.434 \pm 0

### Plot attention rollout maps overlaid on the visual stimulus