# Analyzing results for Head Direction

## Imports and configuration

In [None]:
import json
import matplotlib as mpl
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import numpy as np
import os
import pathlib
import scipy.stats as stats
from sklearn.decomposition import PCA
import sys
import torch

sys.path.append("..")

from models import rnn
from tasks import spatial_navigation
import utils

In [None]:
# Typeset every figure in Arial; register the font from a local TTF if it is not installed yet.
if "Arial" not in fm.get_font_names():
    arial_ttf = pathlib.Path.home().joinpath("fonts", "arial.ttf")  # Set to correct path
    fm.fontManager.addfont(arial_ttf)

plt.rcParams.update(
    {
        "font.family": "Arial",
        "font.size": 10,
        "axes.linewidth": 1.2,
        "xtick.major.width": 1.2,
        "ytick.major.width": 1.2,
    }
)

In [None]:
# Accent colours for the two network states.
active = "#2ca7c5"
quiescent = "#ee3233"

# Palettes run dark -> white -> dark, so both ends of the range share a colour and the map
# wraps seamlessly at +/-pi.
active_colors = [
    "#0d323b",
    "#1e7489",
    "#2ca7c5",
    "#80cadc",
    "#d4edf3",
    "#ffffff",
    "#d4edf3",
    "#80cadc",
    "#2ca7c5",
    "#1e7489",
    "#0d323b",
]
quiescent_colors = [
    "#470f0f",
    "#a62323",
    "#ee3233",
    "#f48484",
    "#fbd6d6",
    "#ffffff",
    "#fbd6d6",
    "#f48484",
    "#ee3233",
    "#a62323",
    "#470f0f",
]


def _bearing_mappable(name, colors):
    """Build a ScalarMappable that colours bearings with a colormap made from *colors*.

    The norm spans [-pi, pi], which matches the bearing colorbar ticks and the polar histogram
    bins used later. Decoded bearings are assumed to lie in that range (TODO confirm against
    ``decode_hd``). The previous norm, [0, 1.5], clipped every negative bearing to the end
    colour and hid the +/-pi colorbar ticks.
    """
    cmap = mpl.colors.LinearSegmentedColormap.from_list(name, colors)
    mappable = mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=-np.pi, vmax=np.pi), cmap=cmap)
    mappable.set_array([])  # no data of its own; only used for to_rgba() and colorbars
    return mappable


lcmap_active = _bearing_mappable("lcmap_active", active_colors)
lcmap_quiescent = _bearing_mappable("lcmap_quiescent", quiescent_colors)
# Keep the module-level `norm` name that the original cell defined.
norm = lcmap_quiescent.norm

## Load results

In [None]:
# Per-seed run directories for this noise configuration, plus the shared figure output folder.
results = pathlib.Path("../results")
noise = 0.0001
config_name = f"noisy_unbiased_{noise}"
hd_root = results / "head_direction"
seed_dirs = [hd_root / f"{config_name}_{seed}" for seed in range(5)]
figures = hd_root / "figures"
figures.mkdir(exist_ok=True)

## Visualizations

### Loss

In [None]:
# One dict per seed, keyed by the batch number as a string; each entry holds "loss" and "pos_mse".
# read_text() opens and closes the file, unlike the bare open() this replaces, which leaked handles.
test_metrics = [json.loads((d / "test_metrics.json").read_text()) for d in seed_dirs]
start, end, skip = 100, 20001, 100
checkpoints = [str(i) for i in range(start, end, skip)]
# Arrays of shape (n_seeds, n_checkpoints).
test_losses = np.array([[f[c]["loss"] for c in checkpoints] for f in test_metrics])
test_posmse = np.array([[f[c]["pos_mse"] for c in checkpoints] for f in test_metrics])

In [None]:
# Mean and standard deviation across seeds (axis 0) at every evaluated checkpoint.
losses_mean = np.mean(test_losses, axis=0)
losses_std = np.std(test_losses, axis=0)
posmse_mean = np.mean(test_posmse, axis=0)
posmse_std = np.std(test_posmse, axis=0)

In [None]:
# Test loss (left axis, grey) and bearing decoding error (right axis, black) over training.
batches = np.arange(start, end, skip)

fig, ax = plt.subplots(figsize=(1.5, 1.5))

ax.spines[["right", "top"]].set_visible(False)
ax.set_box_aspect(1)
ax.errorbar(batches, losses_mean, yerr=losses_std, fmt="-", c="gray")
ax.set_xlabel("batches")
ax.set_ylabel("test loss", color="grey")
for which in ("major", "minor"):
    ax.tick_params(axis="both", which=which, labelsize=8)
ax.yaxis.get_offset_text().set_fontsize(8)

# Twin y-axis that shares the batch x-axis.
ax2 = ax.twinx()
ax2.spines[["left", "top"]].set_visible(False)
ax2.set_box_aspect(1)
ax2.errorbar(batches, posmse_mean, yerr=posmse_std, fmt="-", c="black")
ax2.set_ylabel("bearing decoding error", color="black")
for which in ("major", "minor"):
    ax2.tick_params(axis="both", which=which, labelsize=8)

plt.show()

In [None]:
loss_pdf = figures / f"{config_name}-loss.pdf"
fig.savefig(loss_pdf, bbox_inches="tight", pad_inches=0)

### Population activity dimensionality (PCs vs cumulative explained variance)

In [None]:
# Per-seed explained-variance ratios, filled by the PCA loop.
ev_active, ev_quiescent = [], []

In [None]:
# For every seed, run the trained model twice:
#   - "active": the task's test batch;
#   - "quiescent": zero input, with the recurrent noise scaled up by sqrt(2).
# Fit a PCA to the flattened activity of each run and keep its explained-variance ratios.
for seed in range(5):
    utils.set_random_seeds(seed)

    epoch = 20000
    checkpoint_path = seed_dirs[seed] / f"model_{epoch}.pt"
    model = torch.load(checkpoint_path, map_location="cpu")
    model.set_device("cpu")
    model.eval()

    task = model.task
    task.batch_size = 1000

    test_data = task.get_test_batch()
    init_state = test_data["init_state"]
    t_quiescent = 200
    quiescent_inputs = torch.zeros(task.batch_size, t_quiescent, 1)

    # The model returns a pair; only the first element is used, flattened to (-1, n_rec).
    h_active, _ = model(test_data["data"], init_state)
    h_active = h_active.cpu().detach().numpy().reshape(-1, model.n_rec)
    model.sigma_rec *= np.sqrt(2)
    h_quiescent, _ = model(quiescent_inputs, init_state)
    h_quiescent = h_quiescent.cpu().detach().numpy().reshape(-1, model.n_rec)

    pca = PCA()
    for activity, store in ((h_active, ev_active), (h_quiescent, ev_quiescent)):
        pca.fit(activity)
        store.append(pca.explained_variance_ratio_)

In [None]:
# Stack the per-seed ratios into (n_seeds, n_components) arrays.
ev_active, ev_quiescent = np.array(ev_active), np.array(ev_quiescent)

In [None]:
# Cumulative explained variance of the leading PCs: mean +/- s.d. across seeds.
n_pcs = 10
pc_index = np.arange(1, n_pcs + 1)

fig, ax = plt.subplots(figsize=(1.5, 1.5))
ax.spines[["right", "top"]].set_visible(False)
ax.set_box_aspect(1)

for ev, colour, state in ((ev_active, active, "active"), (ev_quiescent, quiescent, "quiescent")):
    cumulative = np.cumsum(ev, axis=-1)
    ax.errorbar(
        pc_index,
        cumulative.mean(axis=0)[:n_pcs],
        yerr=cumulative.std(axis=0)[:n_pcs],
        fmt="-",
        c=colour,
        label=state,
    )

ax.set_xlabel("# of PCs")
ax.set_ylabel("explained variance")
ax.set_xticks(pc_index)
ax.set_yticks(np.arange(0.4, 1.0, 0.1))
for which in ("major", "minor"):
    ax.tick_params(axis="both", which=which, labelsize=8)

leg = ax.legend(loc="best", frameon=False)

# Colour each legend label like its line, then hide the line handles themselves.
for line_handle, label_text in zip(leg.legend_handles, leg.get_texts()):
    label_text.set_color(line_handle.get_color())
    line_handle.set_visible(False)

plt.show()

In [None]:
ev_pdf = figures / f"{config_name}-explained_variance.pdf"
fig.savefig(ev_pdf, bbox_inches="tight", pad_inches=0)

### Neural activity PCs

In [None]:
# How the recurrent noise is set for the quiescent (zero-input) run. One of:
#   - "scaled":   quiescent_noise = np.sqrt(2) * active_noise
#   - "same":     quiescent_noise = active_noise
#   - "absolute": quiescent_noise = np.sqrt(2 * noise), where `noise` is defined earlier (e.g. 0.0001)
quiescence = "scaled"

In [None]:
# Run one seed's trained model on a test batch (active) and on zero input (quiescent), then
# project both runs onto the PCs of the active activity.
_QUIESCENCE_MODES = ("scaled", "same", "absolute")
if quiescence not in _QUIESCENCE_MODES:
    # Previously an unknown mode silently fell through and behaved like "same".
    raise ValueError(f"Unknown quiescence mode {quiescence!r}; expected one of {_QUIESCENCE_MODES}")

seed = 0
utils.set_random_seeds(seed)

epoch = 20000
model = torch.load(seed_dirs[seed] / f"model_{epoch}.pt", map_location="cpu")
model.set_device("cpu")
model.eval()

task = model.task
task.batch_size = 200

test_data = task.get_test_batch()
t_quiescent = 1000
# Zero input for the quiescent run (input dimension 1 assumed — TODO confirm against the task).
quiescent_inputs = torch.zeros(task.batch_size, t_quiescent, 1)

# Model output is a tuple: element 0 is flattened to (-1, n_rec) below; element 1 is decoded
# into bearings in a later cell.
h_active = model(test_data["data"], test_data["init_state"])

# Adjust the recurrent noise for the quiescent run according to `quiescence`.
if quiescence == "scaled":
    model.sigma_rec *= np.sqrt(2)
elif quiescence == "absolute":
    model.sigma_rec = np.sqrt(2 * noise)
# "same": leave sigma_rec unchanged.

h_quiescent = model(quiescent_inputs, test_data["init_state"])

h_active_ = h_active[0].cpu().detach().numpy().reshape(-1, model.n_rec)
h_quiescent_ = h_quiescent[0].cpu().detach().numpy().reshape(-1, model.n_rec)

# Fit PCs on active activity only, so both runs share the same projection.
pca = PCA()
h_active_pca = pca.fit_transform(h_active_)
h_quiescent_pca = pca.transform(h_quiescent_)

In [None]:
def _as_array(tensor):
    """Squeeze a tensor and return it as a detached CPU NumPy array."""
    return tensor.squeeze().cpu().detach().numpy()


# Decode bearings from the second model output of each run.
active_xy = task.hd_cells.decode_hd(h_active[1])
active_d = _as_array(active_xy)
quiescent_xy = task.hd_cells.decode_hd(h_quiescent[1])
quiescent_d = _as_array(quiescent_xy)

In [None]:
fig, ax = plt.subplots(figsize=(7.5, 1.5))
ax.set_box_aspect(1)

# Plot every `skip_traj`-th trajectory in PC-1/PC-2 space. Each trajectory is subsampled every
# `skip_time` steps after dropping its first `shift_time` steps, and coloured by decoded bearing.
skip_traj = 2
shift_time = 7
skip_time = 20
t_task = task.sequence_length
for i in range(0, task.batch_size, skip_traj):
    rows = slice(i * t_task + shift_time, (i + 1) * t_task, skip_time)
    plt.scatter(
        h_active_pca[rows, 0],
        h_active_pca[rows, 1],
        c=lcmap_active.to_rgba(active_d[i, shift_time::skip_time]),
    )

ax.grid(visible=False)
ax.set_xlabel("PC-1")
ax.set_ylabel("PC-2")
ax.spines[["right", "top"]].set_visible(False)
for which in ("major", "minor"):
    ax.tick_params(axis="both", which=which, labelsize=8)
pc_ticks = [-10, 0, 10]
ax.set_xticks(pc_ticks)
ax.set_yticks(pc_ticks)
ax.set_xlim(-14, 14)
ax.set_ylim(-14, 14)

cb = plt.colorbar(mappable=lcmap_active, ax=ax)
cb.ax.tick_params(labelsize=8)
cb.set_ticks([-np.pi, 0, np.pi], labels=[r"$-\pi$", "0", r"$\pi$"])
cb.set_label(label="bearing", labelpad=10, fontsize=8)

In [None]:
pca_active_pdf = figures / f"{config_name}_{seed}-pca_active.pdf"
fig.savefig(pca_active_pdf, bbox_inches="tight", pad_inches=0)

In [None]:
fig, ax = plt.subplots(figsize=(7.5, 1.5))
ax.set_box_aspect(1)

# Plot every `skip_traj`-th quiescent trajectory (t_quiescent steps each) in PC-1/PC-2 space.
# Each is subsampled every `skip_time` steps after dropping the first `shift_time` steps.
skip_traj = 2
shift_time = 7
skip_time = 20
for i in range(0, task.batch_size, skip_traj):
    rows = slice(i * t_quiescent + shift_time, (i + 1) * t_quiescent, skip_time)
    plt.scatter(
        h_quiescent_pca[rows, 0],
        h_quiescent_pca[rows, 1],
        c=lcmap_quiescent.to_rgba(quiescent_d[i, shift_time::skip_time]),
    )

ax.grid(visible=False)
ax.set_xlabel("PC-1")
ax.set_ylabel("PC-2")
ax.spines[["right", "top"]].set_visible(False)
for which in ("major", "minor"):
    ax.tick_params(axis="both", which=which, labelsize=8)
pc_ticks = [-10, 0, 10]
ax.set_xticks(pc_ticks)
ax.set_yticks(pc_ticks)
ax.set_xlim(-14, 14)
ax.set_ylim(-14, 14)

cb = plt.colorbar(mappable=lcmap_quiescent, ax=ax)
cb.ax.tick_params(labelsize=8)
cb.ax.yaxis.set_ticks_position("left")
cb.set_ticks([-np.pi, 0, np.pi], labels=[r"$-\pi$", "0", r"$\pi$"])
cb.set_label(label="bearing", labelpad=10, fontsize=8)

In [None]:
pca_quiescent_pdf = figures / f"{config_name}_{seed}-pca_quiescent_{quiescence}.pdf"
fig.savefig(pca_quiescent_pdf, bbox_inches="tight", pad_inches=0)

### Histogram of decoded output bearings

In [None]:
fig, ax = plt.subplots(subplot_kw={"projection": "polar"}, figsize=(1, 1))

# Histogram options:
#   bins    - number of bins (replaced by explicit edges over [-pi, pi] when gaps is False)
#   density - draw wedges whose area is proportional to the bin's share of samples (for equal
#             widths), instead of raw counts
#   gaps    - if True, let np.histogram span only the observed data range
bins = 20
density = True
gaps = True

# Flat NumPy array of decoded active bearings. The already-detached `active_d` holds the same
# values as `active_xy`, so a possibly grad-tracking tensor never reaches NumPy, and `x.size` is
# a plain int (the previous `x.size()` call relied on tensor semantics).
x = active_d.ravel()

if not gaps:
    bins = np.linspace(-np.pi, np.pi, num=bins + 1)

n, bins = np.histogram(x, bins=bins)
widths = np.diff(bins)

if density:
    area = n / x.size
    radius = (area / np.pi) ** 0.5
else:
    radius = n

patches = ax.bar(
    bins[:-1],
    radius,
    zorder=1,
    align="edge",
    width=widths,
    edgecolor=active,
    fill=True,
    color=f"{active}88",
    linewidth=1.2,
)

ax.set_rmax(2)
ax.set_rticks([0.5, 1, 1.5, 2])
ax.set_rlabel_position(-22.5)
ax.set_rorigin(-0.05)
ax.set_theta_offset(0)
ax.grid(True)
# Label the 8 theta ticks at multiples of pi/4. The positions are set explicitly so the labels
# cannot drift from them. Every minus sign is inside math mode so all render as a true minus.
theta_labels = [
    "0",
    r"$\frac{\pi}{4}$",
    r"$\frac{\pi}{2}$",
    r"$\frac{3\pi}{4}$",
    "$\\pi$\n$-\\pi$",  # escaped: "\p" is an invalid escape sequence in a normal string
    r"$-\frac{3\pi}{4}$",
    r"$-\frac{\pi}{2}$",
    r"$-\frac{\pi}{4}$",
]
ax.set_xticks(np.linspace(0, 2 * np.pi, len(theta_labels), endpoint=False), labels=theta_labels)
ax.set_yticks([])
ax.tick_params(axis="both", which="major", labelsize=8, pad=0)
ax.tick_params(axis="both", which="minor", labelsize=8, pad=0)
ax.tick_params(axis="x", which="major", colors="k")

In [None]:
# `config` was never defined (NameError); the configuration label is `config_name`.
fig.savefig(figures / f"{config_name}_{seed}-outputs_active.pdf", bbox_inches="tight", pad_inches=0)

In [None]:
fig, ax = plt.subplots(subplot_kw={"projection": "polar"}, figsize=(1, 1))

# Histogram options:
#   bins    - number of bins (replaced by explicit edges over [-pi, pi] when gaps is False)
#   density - draw wedges whose area is proportional to the bin's share of samples (for equal
#             widths), instead of raw counts
#   gaps    - if True, let np.histogram span only the observed data range
bins = 20
density = True
gaps = True

# Flat NumPy array of decoded quiescent bearings. The already-detached `quiescent_d` holds the
# same values as `quiescent_xy`, so a possibly grad-tracking tensor never reaches NumPy, and
# `x.size` is a plain int (the previous `x.size()` call relied on tensor semantics).
x = quiescent_d.ravel()

if not gaps:
    bins = np.linspace(-np.pi, np.pi, num=bins + 1)

n, bins = np.histogram(x, bins=bins)
widths = np.diff(bins)

if density:
    area = n / x.size
    radius = (area / np.pi) ** 0.5
else:
    radius = n

patches = ax.bar(
    bins[:-1],
    radius,
    zorder=1,
    align="edge",
    width=widths,
    edgecolor=quiescent,
    fill=True,
    color=f"{quiescent}88",
    linewidth=1.2,
)

ax.set_rmax(2)
ax.set_rticks([0.5, 1, 1.5, 2])
ax.set_rlabel_position(-22.5)
ax.set_rorigin(-0.05)
ax.set_theta_offset(0)
ax.grid(True)
# Label the 8 theta ticks at multiples of pi/4. The positions are set explicitly so the labels
# cannot drift from them. Every minus sign is inside math mode so all render as a true minus.
theta_labels = [
    "0",
    r"$\frac{\pi}{4}$",
    r"$\frac{\pi}{2}$",
    r"$\frac{3\pi}{4}$",
    "$\\pi$\n$-\\pi$",  # escaped: "\p" is an invalid escape sequence in a normal string
    r"$-\frac{3\pi}{4}$",
    r"$-\frac{\pi}{2}$",
    r"$-\frac{\pi}{4}$",
]
ax.set_xticks(np.linspace(0, 2 * np.pi, len(theta_labels), endpoint=False), labels=theta_labels)
ax.set_yticks([])
ax.tick_params(axis="both", which="major", labelsize=8, pad=0)
ax.tick_params(axis="both", which="minor", labelsize=8, pad=0)
ax.tick_params(axis="x", which="major", colors="k")

In [None]:
# `config` was never defined (NameError); the configuration label is `config_name`.
fig.savefig(
    figures / f"{config_name}_{seed}-outputs_quiescent-q_{quiescence}.pdf", bbox_inches="tight", pad_inches=0
)

### Decoded output trajectories

In [None]:
fig, ax = plt.subplots(subplot_kw={"projection": "polar"}, figsize=(1, 1))

# Decoded bearing of the first active trajectory: angle = bearing, radius = time step.
n_steps = 50
x = active_d[0][:n_steps]
steps = np.arange(len(x))  # shorter than n_steps if the trajectory itself is shorter

ax.scatter(x[0], [steps[0]], c="k", label="start", s=10, zorder=3)
ax.scatter(x[-1], [steps[-1]], c="k", label="end", marker="^", s=10, zorder=3)
ax.plot(x, steps, color=f"{active}dd", linewidth=2)

# Radius encodes time, so the radial limit must reach the last step. The previous fixed rmax of 2
# hid every point beyond radius 2, including the "end" marker.
ax.set_rmax(len(x))
ax.set_rlabel_position(-22.5)
ax.set_rorigin(-0.05)
ax.set_theta_offset(0)
ax.grid(True)
# Label the 8 theta ticks at multiples of pi/4, with explicit positions and math-mode minus signs.
theta_labels = [
    "0",
    r"$\frac{\pi}{4}$",
    r"$\frac{\pi}{2}$",
    r"$\frac{3\pi}{4}$",
    "$\\pi$\n$-\\pi$",  # escaped: "\p" is an invalid escape sequence in a normal string
    r"$-\frac{3\pi}{4}$",
    r"$-\frac{\pi}{2}$",
    r"$-\frac{\pi}{4}$",
]
ax.set_xticks(np.linspace(0, 2 * np.pi, len(theta_labels), endpoint=False), labels=theta_labels)
ax.set_yticks([])  # no radial tick labels
ax.tick_params(axis="both", which="major", labelsize=8, pad=0)
ax.tick_params(axis="both", which="minor", labelsize=8, pad=0)
ax.tick_params(axis="x", which="major", colors="k")

In [None]:
# `config` was never defined (NameError); the configuration label is `config_name`.
fig.savefig(figures / f"{config_name}_{seed}-trajectory_active.pdf", bbox_inches="tight", pad_inches=0)

In [None]:
fig, ax = plt.subplots(subplot_kw={"projection": "polar"}, figsize=(1, 1))

# Decoded bearing of the first quiescent trajectory: angle = bearing, radius = time step.
n_steps = 500
x = quiescent_d[0][:n_steps]
steps = np.arange(len(x))  # shorter than n_steps if the trajectory itself is shorter

ax.scatter(x[0], [steps[0]], c="k", label="start", s=10, zorder=3)
ax.scatter(x[-1], [steps[-1]], c="k", label="end", marker="^", s=10, zorder=3)
ax.plot(x, steps, color=f"{quiescent}dd", linewidth=2)

# Radius encodes time, so the radial limit must reach the last step. The previous fixed rmax of 2
# hid every point beyond radius 2, including the "end" marker.
ax.set_rmax(len(x))
ax.set_rlabel_position(-22.5)
ax.set_rorigin(-0.05)
ax.set_theta_offset(0)
ax.grid(True)
# Label the 8 theta ticks at multiples of pi/4, with explicit positions and math-mode minus signs.
theta_labels = [
    "0",
    r"$\frac{\pi}{4}$",
    r"$\frac{\pi}{2}$",
    r"$\frac{3\pi}{4}$",
    "$\\pi$\n$-\\pi$",  # escaped: "\p" is an invalid escape sequence in a normal string
    r"$-\frac{3\pi}{4}$",
    r"$-\frac{\pi}{2}$",
    r"$-\frac{\pi}{4}$",
]
ax.set_xticks(np.linspace(0, 2 * np.pi, len(theta_labels), endpoint=False), labels=theta_labels)
ax.set_yticks([])  # no radial tick labels
ax.tick_params(axis="both", which="major", labelsize=8, pad=0)
ax.tick_params(axis="both", which="minor", labelsize=8, pad=0)
ax.tick_params(axis="x", which="major", colors="k")

In [None]:
# Use `config_name` (`config` was undefined) and the quiescence *mode* in the filename; the
# previous code interpolated `quiescent`, which is the colour string "#ee3233".
fig.savefig(
    figures / f"{config_name}_{seed}-trajectory_quiescent-q_{quiescence}.pdf", bbox_inches="tight", pad_inches=0
)

### Histogram of angular velocity

In [None]:
# Step-to-step change in decoded bearing along the time axis (axis 1), pooled over trajectories.
# The +pi shift cancels in the difference (up to rounding).
active_d_rescaled = active_d + np.pi
quiescent_d_rescaled = quiescent_d + np.pi
av_active = np.diff(active_d_rescaled, axis=1).flatten()
av_quiescent = np.diff(quiescent_d_rescaled, axis=1).flatten()

In [None]:
# Wrap each step difference into [-pi, pi), so a crossing of the +/-pi seam counts as a small
# step rather than a jump of nearly 2*pi. Exact modular wrapping also handles differences with
# magnitude between pi and 1.5*pi, which fixed +/-1.5*pi thresholds leave unwrapped.
av_active = (av_active + np.pi) % (2 * np.pi) - np.pi
av_quiescent = (av_quiescent + np.pi) % (2 * np.pi) - np.pi

In [None]:
# Distribution of per-step angular velocity in the active run, as proportions of all steps.
fig = plt.figure(figsize=(1.5, 1.5))
n_av = len(av_active)
plt.hist(av_active, bins=np.linspace(-0.5, 0.5, 20), color=active, weights=np.ones(n_av) / n_av)
plt.xlabel("Angular velocity")
plt.ylabel("Proportion")

In [None]:
# `config` was never defined (NameError); the configuration label is `config_name`.
fig.savefig(figures / f"{config_name}_{seed}-velocity_active.pdf", bbox_inches="tight", pad_inches=0)

In [None]:
# Distribution of per-step angular velocity in the quiescent run, as proportions of all steps,
# with the mean marked by a dashed line and labelled on a tick-only top axis.
fig, ax = plt.subplots(figsize=(1.5, 1.5))
n_av = len(av_quiescent)
av_mean = av_quiescent.mean()
plt.hist(
    av_quiescent,
    bins=np.linspace(-0.2, 0.2, 20),
    color=quiescent,
    weights=np.ones(n_av) / n_av,
)
mean_axis = ax.secondary_xaxis("top")
mean_axis.tick_params(axis="x", length=0)
mean_axis.set_xticks([av_mean], minor=False)
ax2 = mean_axis
plt.axvline(av_mean, c="k", ls="--")
plt.xlabel("Angular velocity")
plt.ylabel("Proportion")

In [None]:
# Use `config_name` (`config` was undefined) and add the missing ".pdf" extension so this figure
# is saved in the same format as every other one.
fig.savefig(
    figures / f"{config_name}_{seed}-velocity_quiescent-q_{quiescence}.pdf", bbox_inches="tight", pad_inches=0
)