In [None]:
# Notebook setup: make the repository root importable, import project helpers,
# and apply the shared matplotlib style.
# NOTE(review): os, copy, Dataset and DataLoader appear unused in this notebook.
import sys, os

# vi_rnn / evaluation packages are imported from one directory up.
sys.path.append("../")
from vi_rnn.saving import load_model
from vi_rnn.datasets import NLBDataset, load_nlb_dataset
# NOTE(review): star imports hide provenance; names used below such as
# get_orth_proj_latents, calc_isi_stats_per_trial and calculate_correlation
# presumably come from these two modules — confirm and import them explicitly.
from vi_rnn.utils import *
from vi_rnn.inference import predict_NLB
from evaluation.calc_stats import *
import copy
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import h5py
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.linear_model import Ridge

from pathlib import Path

# Shared figure style read from a local matplotlibrc file.
mpl.rc_file("matplotlibrc")

%matplotlib inline

In [None]:
# Absolute locations of processed data and trained checkpoints, relative to
# the parent of the notebook's working directory.
DATA_ROOT = Path("../").absolute().joinpath("data_untracked", "processed")
CHKPT_DIR = Path("../").absolute().joinpath("models")

In [None]:
# Load model
# The checkpoint directory holds several .pkl files sharing one prefix
# (e.g. "<prefix>_vae_params.pkl"); strip the known suffixes from one file name
# to recover that prefix, which is the path handed to load_model.

chkpt = "maze_nocue_nc_5d/"
chkpt_path = CHKPT_DIR / chkpt

pkl_files = sorted(chkpt_path.glob("*.pkl"))
if not pkl_files:
    # Fail with a clear message instead of an IndexError on pkl_files[0].
    raise FileNotFoundError(f"No .pkl checkpoint files found in {chkpt_path}")
model_save_name = pkl_files[0].stem
suffixes = [
    "_state_dict_enc",
    "_state_dict_rnn",
    "_task_params",
    "_training_params",
    "_vae_params",
]
for suffix in suffixes:
    model_save_name = model_save_name.replace(suffix, "")

model_save_name = str(chkpt_path / model_save_name)
vae, training_params, task_params = load_model(model_save_name, backward_compat=True)

In [None]:
# Load dataset
# Train/eval splits of the NLB "mc_maze_input" data, each wrapped in an
# NLBDataset together with its inputs.

nlb_load_kwargs = dict(
    data_root=DATA_ROOT,
    name="mc_maze_input",
    phase="pos",
    bin_size=20,
    t_forward=0,
    input_field="input",
    u=2,
    cosmooth=False,
)
train_data, eval_data, train_inputs, eval_inputs = load_nlb_dataset(**nlb_load_kwargs)
train_dataset = NLBDataset(train_data, {"name": "mc_maze_input"}, inputs=train_inputs)
eval_dataset = NLBDataset(eval_data, {"name": "mc_maze_input"}, inputs=eval_inputs)

In [None]:
# Model sizes as stored in vae.vae_params (latent, observation, input).
dim_z, dim_x, n_inp = (vae.vae_params[key] for key in ("dim_z", "dim_x", "dim_u"))
print(f"dim_z: {dim_z}, n_inp: {n_inp}")

In [None]:
# Latent basis P used for all latent plots, as a host-side numpy array.
projection_matrix = get_orth_proj_latents(vae).cpu().numpy()

In [None]:
# NOTE added for reproducibility
# ------------------------------------------
# If you want to plot latents Z in an orthogonalised basis,
# you have to use P@Z or Z@P.T.
# The plots here are not in this orthogonalised basis (Z@P is used instead).
# As detailed in the paper, any full-rank transformation P is valid;
# since this is what we used for the figures, it is left like this here.

# Flip row i of P whenever the sign of its first entry differs from
# sign_constraints[i], so the plotted orientation is fixed.
sign_constraints = [-1, -1, -1, -1, 1]
for i, s in enumerate(sign_constraints):
    # Compare by value: `is not` on ints only worked via CPython's small-int cache.
    if int(np.sign(projection_matrix[i, 0])) != s:
        projection_matrix[i] *= -1

In [None]:
# Move the model to the GPU via the vi_rnn helper (presumably moves all submodules — confirm).
vae.to_device("cuda")

In [None]:
# Get posterior latents
# Run inference on all training trials in mini-batches, repeated n_repeats
# times.  Each output is averaged over its last axis (presumably the k samples),
# transposed with .transpose(0, 2, 1), and finally averaged over repeats.

n_repeats = 10
batch_size = 64
k = 100
device = "cuda:0"

x_t = train_dataset.data.to(device)
u_t = train_dataset.stim.to(device)
t_held_in = x_t.shape[2]
# torch.chunk rejects chunks=0, which len // batch_size gives for fewer than
# batch_size trials.  Chunks hold ceil(len / n_chunks) trials each.
n_chunks = max(1, len(x_t) // batch_size)

Qzs_filt_repeats = []
Qzs_sm_repeats = []
Xs_filt_repeats = []
Xs_sm_repeats = []
for i in range(n_repeats):
    with torch.no_grad():
        Qzs_filt, Qzs_sm, Xs_filt, Xs_sm = zip(
            *[
                predict_NLB(
                    vae=vae,
                    x=x_t_chunk,
                    u=u_t_chunk,
                    k=k,
                    t_held_in=t_held_in,
                    t_forward=0,
                )
                for x_t_chunk, u_t_chunk in zip(
                    torch.chunk(x_t, chunks=n_chunks, dim=0),
                    torch.chunk(u_t, chunks=n_chunks, dim=0),
                )
            ]
        )
        # Stitch the batches back together, average over the last axis and
        # swap the two trailing remaining axes.
        for repeats, outputs in (
            (Qzs_filt_repeats, Qzs_filt),
            (Qzs_sm_repeats, Qzs_sm),
            (Xs_filt_repeats, Xs_filt),
            (Xs_sm_repeats, Xs_sm),
        ):
            repeats.append(
                torch.cat(outputs, dim=0).cpu().numpy().mean(axis=-1).transpose(0, 2, 1)
            )

Qzs_filt = np.mean(Qzs_filt_repeats, axis=0)
Qzs_sm = np.mean(Qzs_sm_repeats, axis=0)
Xs_filt = np.mean(Xs_filt_repeats, axis=0)
Xs_sm = np.mean(Xs_sm_repeats, axis=0)

In [None]:
# Get reach target positions
# Target (x, y) is read at index 0 of the last stim axis (presumably the first
# time bin).  Its polar angle is mapped to [0, 1] and snapped to one of 8
# direction bins with values k / 8.
target_pos = train_dataset.stim[:, :, 0].detach().cpu().numpy()
angles = np.arctan2(target_pos[:, 1], target_pos[:, 0]) / (2 * np.pi) + 0.5
angles = np.mod(np.round(angles * 8), 8) / 8

In [None]:
# Sanity check: target positions coloured by their direction bin.
plt.scatter(*target_pos[:, :2].T, c=angles, cmap=plt.cm.hsv)

In [None]:
# Take mean of posterior
# One condition-averaged filtered latent trajectory per direction bin, in
# ascending bin order, stored as float64.
Qz_mean = np.stack(
    [Qzs_filt[angles == angle].mean(axis=0) for angle in np.unique(angles)]
).astype(np.float64)

In [None]:
# Reference (unsaved) plot of the inferred latents, projected with P, in the
# (z1, z2) plane for time bins plt_start..plt_end-1.  It shows up to 10 trials
# per direction bin plus the condition mean.  Its only lasting output is
# prep_xlim / prep_ylim, which the saved pre-movement panels reuse as limits.
Qz_mean_proj = Qz_mean @ projection_matrix
Qzs_proj_all = Qzs_filt @ projection_matrix
plt_start = 6
plt_end = 10
plt_dim1 = 0
plt_dim2 = 1
sign_dim1 = 1
sign_dim2 = -1
fig, axs = plt.subplots(1, 1, figsize=(1.0, 1.0))
for angle_i in range(len(np.sort(np.unique(angles)))):
    angle = np.sort(np.unique(angles))[angle_i]

    mask = angles == angle
    Qzs_proj_all_cond = Qzs_proj_all[mask]
    for trial in range(min(Qzs_proj_all_cond.shape[0], 10)):
        axs.plot(
            Qzs_proj_all_cond[trial, plt_start:plt_end, plt_dim1] * sign_dim1,
            Qzs_proj_all_cond[trial, plt_start:plt_end, plt_dim2] * sign_dim2,
            color=plt.cm.hsv(angle),
            alpha=0.3,
            linewidth=0.6,
        )

    axs.plot(
        Qz_mean_proj[angle_i, plt_start:plt_end, plt_dim1] * sign_dim1,
        Qz_mean_proj[angle_i, plt_start:plt_end, plt_dim2] * sign_dim2,
        color=plt.cm.hsv(angle),
        alpha=0.6,
    )
# Mark the end point of each condition mean.
axs.scatter(
    Qz_mean_proj[:, plt_end - 1, plt_dim1] * sign_dim1,
    Qz_mean_proj[:, plt_end - 1, plt_dim2] * sign_dim2,
    s=20,
    color=[plt.cm.hsv(a) for a in np.sort(np.unique(angles))],
    edgecolor="black",
    linewidth=0.8,
    zorder=2,
    alpha=0.6,
)
axs.set_xticks([])
axs.set_yticks([])
axs.set_xlabel(f"$z_{plt_dim1+1}$")
axs.set_ylabel(f"$z_{plt_dim2+1}$")

# Extend the lower y-limit by 5% of the range to leave headroom below the data.
ylim = axs.get_ylim()
yrange = ylim[1] - ylim[0]
axs.set_ylim(ylim[0] - 0.05 * yrange, ylim[1])

# Captured for reuse by the saved pre-movement panels.
prep_xlim, prep_ylim = axs.get_xlim(), axs.get_ylim()

In [None]:
def plot_with_increasing_alpha(
    ax, x, y, z=None, alpha_min=0.0, alpha_max=1.0, **kwargs
):
    """Draw a polyline segment by segment with linearly increasing opacity.

    Consecutive points are joined by separate ``ax.plot`` calls whose alpha
    ramps linearly from ``alpha_min`` (first segment) to ``alpha_max`` (last
    segment), so the direction of travel is visible.  With a single segment
    that segment gets ``alpha_min``.

    Args:
        ax: Axes to draw on (a 3-D axes when ``z`` is given).
        x, y: Point coordinates, sequences of equal length.
        z: Optional third coordinate, same length as ``x``.
        alpha_min, alpha_max: Opacity of the first / last segment.
        **kwargs: Forwarded unchanged to every ``ax.plot`` call.

    Returns:
        The same ``ax``.
    """
    n_segments = len(x) - 1
    if n_segments < 1:
        # Nothing to connect; np.linspace would reject the negative count for empty input.
        return ax
    alphas = np.linspace(alpha_min, alpha_max, n_segments)
    for i, alpha in enumerate(alphas):
        data = (x[i : i + 2], y[i : i + 2])
        if z is not None:
            data += (z[i : i + 2],)
        ax.plot(*data, alpha=alpha, **kwargs)
    return ax

In [None]:
# Make panel a (left)
# Pre-movement inferred latents (projected with P) in the (z1, z2) plane:
# up to 5 single trials (thin) and the condition mean per direction bin, with
# opacity increasing over time.  Axis limits come from prep_xlim / prep_ylim
# computed by the earlier reference plot, so inferred and generated panels match.

Qz_mean_proj = Qz_mean @ projection_matrix
Qzs_proj_all = Qzs_filt @ projection_matrix
plt_start = 3
plt_end = 11
plt_dim1 = 0
plt_dim2 = 1
sign_dim1 = 1
sign_dim2 = -1
angle_bins = np.unique(angles)  # sorted direction bins, same order as Qz_mean
fig, axs = plt.subplots(1, 1, figsize=(1.0, 1.0))
for angle_i, angle in enumerate(angle_bins):
    Qzs_proj_all_cond = Qzs_proj_all[angles == angle]
    for trial in range(min(Qzs_proj_all_cond.shape[0], 5)):
        plot_with_increasing_alpha(
            axs,
            x=Qzs_proj_all_cond[trial, plt_start:plt_end, plt_dim1] * sign_dim1,
            y=Qzs_proj_all_cond[trial, plt_start:plt_end, plt_dim2] * sign_dim2,
            alpha_max=0.3,
            color=plt.cm.hsv(angle),
            linewidth=0.6,
        )

    plot_with_increasing_alpha(
        axs,
        x=Qz_mean_proj[angle_i, plt_start:plt_end, plt_dim1] * sign_dim1,
        y=Qz_mean_proj[angle_i, plt_start:plt_end, plt_dim2] * sign_dim2,
        alpha_max=0.6,
        color=plt.cm.hsv(angle),
    )

# Mark the end point of each condition mean.
axs.scatter(
    Qz_mean_proj[:, plt_end - 1, plt_dim1] * sign_dim1,
    Qz_mean_proj[:, plt_end - 1, plt_dim2] * sign_dim2,
    s=20,
    color=[plt.cm.hsv(a) for a in angle_bins],
    edgecolor="black",
    linewidth=0.8,
    zorder=2,
    alpha=0.6,
)
axs.set_xticks([])
axs.set_yticks([])
axs.set_xlabel(f"$z_{plt_dim1+1}$")
axs.set_ylabel(f"$z_{plt_dim2+1}$")

axs.set_xlim(*prep_xlim)
axs.set_ylim(*prep_ylim)

axs.set_title("pre-movement")
plt.savefig("../figures/prep_inf_latents.svg")
plt.savefig("../figures/prep_inf_latents.pdf")

In [None]:
# Make panel a (right)
# Movement-period inferred latents (projected with P) in 3-D on (z5, z4, z3)
# for time bins plt_start..plt_end-1.  Each direction bin shows up to 5 single
# trials (thin) and the condition mean.  Circles mark the first plotted bin and
# triangles the last.  The final limits are stored in mvt_xlim / mvt_ylim /
# mvt_zlim and reused by the generated-latent 3-D panel.

fig, axs = plt.subplots(1, 1, figsize=(1.2, 1.2), subplot_kw={"projection": "3d"})

Qz_mean_proj = Qz_mean @ projection_matrix
Qzs_proj_all = Qzs_filt @ projection_matrix
plt_start = 10
plt_end = 35
plt_dim1 = 4
plt_dim2 = 3
plt_dim3 = 2
sign_dim1 = 1
sign_dim2 = 1
sign_dim3 = 1

# One hsv colour per sorted direction bin, indexed by angle_i below.
prop_cycle = [plt.cm.hsv(ang) for ang in np.sort(np.unique(angles))]

for angle_i in range(len(np.sort(np.unique(angles)))):
    angle = np.sort(np.unique(angles))[angle_i]

    mask = angles == angle
    Qzs_proj_all_cond = Qzs_proj_all[mask]
    for trial in range(min(Qzs_proj_all_cond.shape[0], 5)):
        axs.plot(
            Qzs_proj_all_cond[trial, plt_start:plt_end, plt_dim1] * sign_dim1,
            Qzs_proj_all_cond[trial, plt_start:plt_end, plt_dim2] * sign_dim2,
            Qzs_proj_all_cond[trial, plt_start:plt_end, plt_dim3] * sign_dim3,
            color=prop_cycle[angle_i],
            alpha=0.3,
            linewidth=0.6,
        )

    axs.plot(
        Qz_mean_proj[angle_i, plt_start:plt_end, plt_dim1] * sign_dim1,
        Qz_mean_proj[angle_i, plt_start:plt_end, plt_dim2] * sign_dim2,
        Qz_mean_proj[angle_i, plt_start:plt_end, plt_dim3] * sign_dim3,
        color=prop_cycle[angle_i],
        alpha=0.6,
    )

# Condition means at the first plotted time bin (circles).
axs.scatter(
    Qz_mean_proj[:, plt_start, plt_dim1] * sign_dim1,
    Qz_mean_proj[:, plt_start, plt_dim2] * sign_dim2,
    Qz_mean_proj[:, plt_start, plt_dim3] * sign_dim3,
    s=15,
    color=prop_cycle,
    edgecolor="black",
    linewidth=0.8,
    zorder=2,
    alpha=0.6,
)
# Condition means at the last plotted time bin (triangles).
axs.scatter(
    Qz_mean_proj[:, plt_end - 1, plt_dim1] * sign_dim1,
    Qz_mean_proj[:, plt_end - 1, plt_dim2] * sign_dim2,
    Qz_mean_proj[:, plt_end - 1, plt_dim3] * sign_dim3,
    s=15,
    marker="^",
    color=prop_cycle,
    edgecolor="black",
    linewidth=0.8,
    zorder=2,
    alpha=0.6,
)
axs.set_xticks([])
axs.set_yticks([])
axs.set_zticks([])
axs.set_xlabel(f"$z_{plt_dim1+1}$")
axs.set_ylabel(f"$z_{plt_dim2+1}$")
axs.set_zlabel(f"$z_{plt_dim3+1}$")
# Negative label padding moves the axis labels closer to the (tick-less) axes.
axs.xaxis.labelpad = -15
axs.yaxis.labelpad = -15
axs.zaxis.labelpad = -16
mvt_xlim, mvt_ylim, mvt_zlim = axs.get_xlim(), axs.get_ylim(), axs.get_zlim()

axs.set_title("movement period", pad=-10)
plt.savefig("../figures/movement_inf_latents.svg")
plt.savefig("../figures/movement_inf_latents.pdf")

In [None]:
# For creating input to the network
def project_to_square(theta):
    """Project a direction onto the boundary of the square [-1, 1] x [-1, 1].

    The ray from the origin at angle ``theta`` (degrees, counter-clockwise from
    +x) is intersected with the square, so the returned point satisfies
    max(|x|, |y|) == 1.  Used to turn reach directions into target inputs.

    Args:
        theta: Angle in degrees; any real value (wrapped into [-180, 180)).

    Returns:
        Tuple ``(x, y)``.
    """
    # Wrap in one step. A single +/-360 correction only handled |theta| <= 540.
    theta = (theta + 180) % 360 - 180
    rad = np.deg2rad(theta)
    if 45 <= theta < 135:  # top edge
        y = 1.0
        x = y * np.cos(rad) / np.sin(rad)
    elif theta >= 135 or theta < -135:  # left edge
        x = -1.0
        y = x * np.sin(rad) / np.cos(rad)
    elif -135 <= theta < -45:  # bottom edge
        y = -1.0
        x = y * np.cos(rad) / np.sin(rad)
    else:  # right edge, -45 <= theta < 45
        x = 1.0
        y = x * np.sin(rad) / np.cos(rad)
    return x, y

In [None]:
# Obtain prior latents (conditioned on input)
# Sample the generative RNN n_repeats times for 16 reach directions.  Each run
# starts from the trial-averaged first-time-bin posterior latent and is driven
# by the projected target position, given as input over bins [t_on, t_of).
noise_scale = 1
gen_angles = np.arange(-180, 180, 22.5)
targets = [project_to_square(ang) for ang in gen_angles]
dur = 35
t_on = 0
t_of = 35
R_z = 0.05  # NOTE(review): not used anywhere in this notebook
prop_cycle = [plt.cm.hsv(i) for i in np.arange(0, 1, 1 / len(targets))]
z0 = Qzs_filt[:, 0, :].mean(axis=0)[None, :, None]
z0 = np.tile(z0, (len(targets), 1, 1))
z0 = torch.tensor(z0, dtype=torch.float)
n_repeats = 50
z_all = np.zeros((n_repeats, len(targets), dur, dim_z))

# The input is identical for every repeat, so build it once:
# (targets, dur, n_inp) permuted to (targets, n_inp, dur) on the GPU.
# Named `inputs` to avoid shadowing the builtin `input`.
inputs = np.zeros((len(targets), dur, n_inp))
for i, target in enumerate(targets):
    inputs[i, t_on:t_of, :] = np.array(target)[None, :]
inputs = torch.tensor(inputs, dtype=torch.float).permute(0, 2, 1).cuda()

for ri in range(n_repeats):
    # Pure sampling: no autograd graph is needed.
    with torch.no_grad():
        z = (
            vae.rnn.get_latent_time_series(dur, u=inputs, z0=z0, noise_scale=noise_scale)[0]
            .cpu()
            .numpy()
        )
    z_all[ri] = z.transpose(0, 2, 1, 3).squeeze()

In [None]:
# Average the generated latents over repeats: (targets, time, dim_z).
z_mean = z_all.mean(axis=0)

In [None]:
# Project the mean prior latents with P (Z @ P, i.e. not an orthogonalised basis).
z_proj_mean = np.matmul(z_mean, projection_matrix)

In [None]:
# Make panel d (left)
# Pre-movement generated (prior) latents, projected with P, in the (z1, z2)
# plane.  Each of the 16 target directions shows up to 5 sampled repeats (thin)
# plus the mean over repeats, with opacity increasing over time.  Limits are
# shared with the inferred-latent pre-movement panel via prep_xlim / prep_ylim.

z_proj_all = z_all @ projection_matrix
plt_start = 2
plt_end = 9
plt_dim1 = 0
plt_dim2 = 1
sign_dim1 = 1
sign_dim2 = -1
fig, axs = plt.subplots(1, 1, figsize=(1.0, 1.0))
prop_cycle = [plt.cm.hsv(i) for i in np.arange(0, 1, 1 / len(targets))]
for targ_i in range(len(targets)):

    # Axis 0 of z_proj_all indexes sampled repeats, not recorded trials.
    for trial in range(min(z_proj_all.shape[0], 5)):
        plot_with_increasing_alpha(
            axs,
            x=z_proj_all[trial, targ_i, plt_start:plt_end, plt_dim1] * sign_dim1,
            y=z_proj_all[trial, targ_i, plt_start:plt_end, plt_dim2] * sign_dim2,
            alpha_min=0.0,
            alpha_max=0.3,
            linewidth=0.6,
            color=prop_cycle[targ_i],
        )
    plot_with_increasing_alpha(
        axs,
        x=z_proj_mean[targ_i, plt_start:plt_end, plt_dim1] * sign_dim1,
        y=z_proj_mean[targ_i, plt_start:plt_end, plt_dim2] * sign_dim2,
        alpha_min=0.0,
        alpha_max=0.6,
        color=prop_cycle[targ_i],
    )

# Mark the end point of each mean trajectory.
axs.scatter(
    z_proj_mean[:, plt_end - 1, plt_dim1] * sign_dim1,
    z_proj_mean[:, plt_end - 1, plt_dim2] * sign_dim2,
    s=20,
    color=prop_cycle,
    edgecolor="black",
    linewidth=0.8,
    zorder=2,
    alpha=0.6,
)
axs.set_xticks([])
axs.set_yticks([])
axs.set_xlabel(f"$z_{plt_dim1+1}$")
axs.set_ylabel(f"$z_{plt_dim2+1}$")

axs.set_xlim(*prep_xlim)
axs.set_ylim(*prep_ylim)

axs.set_title("pre-movement")
plt.savefig("../figures/prep_gen_latents.svg")
plt.savefig("../figures/prep_gen_latents.pdf")

In [None]:
# Make panel d (right)
# Movement-period generated (prior) latents in 3-D on (z5, z4, z3).  Up to 5
# sampled repeats are drawn for every target.  The mean trajectory and the
# start (circle) / end (triangle) markers are drawn only for every strd-th
# target.  Limits are copied from the inferred-latent 3-D panel (mvt_*lim).

fig, axs = plt.subplots(1, 1, figsize=(1.2, 1.2), subplot_kw={"projection": "3d"})

z_proj_all = z_all @ projection_matrix
plt_start = 10
plt_end = 35
plt_dim1 = 4
plt_dim2 = 3
plt_dim3 = 2
sign_dim1 = 1
sign_dim2 = 1
sign_dim3 = 1

strd = 2
prop_cycle = [plt.cm.hsv(i) for i in np.arange(0, 1, 1 / len(targets))]
for targ_i in range(len(targets)):

    # Axis 0 of z_proj_all indexes sampled repeats.
    for trial in range(min(z_proj_all.shape[0], 5)):
        axs.plot(
            z_proj_all[trial, targ_i, plt_start:plt_end, plt_dim1] * sign_dim1,
            z_proj_all[trial, targ_i, plt_start:plt_end, plt_dim2] * sign_dim2,
            z_proj_all[trial, targ_i, plt_start:plt_end, plt_dim3] * sign_dim3,
            color=prop_cycle[targ_i],
            alpha=0.3,
            linewidth=0.6,
        )

    if targ_i % strd == 0:
        axs.plot(
            z_proj_mean[targ_i, plt_start:plt_end, plt_dim1] * sign_dim1,
            z_proj_mean[targ_i, plt_start:plt_end, plt_dim2] * sign_dim2,
            z_proj_mean[targ_i, plt_start:plt_end, plt_dim3] * sign_dim3,
            color=prop_cycle[targ_i],
            alpha=0.6,
        )

# NOTE(review): redundant re-assignment; strd is already 2.
strd = 2
# Start points (circles) of every strd-th mean trajectory.
axs.scatter(
    z_proj_mean[::strd, plt_start, plt_dim1] * sign_dim1,
    z_proj_mean[::strd, plt_start, plt_dim2] * sign_dim2,
    z_proj_mean[::strd, plt_start, plt_dim3] * sign_dim3,
    s=15,
    color=prop_cycle[::strd],
    edgecolor="black",
    linewidth=0.8,
    zorder=2,
    alpha=0.8,
)
# End points (triangles) of every strd-th mean trajectory.
axs.scatter(
    z_proj_mean[::strd, plt_end - 1, plt_dim1] * sign_dim1,
    z_proj_mean[::strd, plt_end - 1, plt_dim2] * sign_dim2,
    z_proj_mean[::strd, plt_end - 1, plt_dim3] * sign_dim3,
    s=15,
    marker="^",
    color=prop_cycle[::strd],
    edgecolor="black",
    linewidth=0.8,
    zorder=2,
    alpha=0.8,
)
axs.set_xticks([])
axs.set_yticks([])
axs.set_zticks([])
axs.set_xlabel(f"$z_{plt_dim1+1}$")
axs.set_ylabel(f"$z_{plt_dim2+1}$")
axs.set_zlabel(f"$z_{plt_dim3+1}$")
axs.xaxis.labelpad = -15
axs.yaxis.labelpad = -15
axs.zaxis.labelpad = -16

axs.set_xlim(*mvt_xlim)
axs.set_ylim(*mvt_ylim)
axs.set_zlim(*mvt_zlim)

axs.set_title("movement period", pad=-10)
plt.savefig("../figures/movement_gen_latents.svg")
plt.savefig("../figures/movement_gen_latents.pdf")

In [None]:
# Get reach data
# Behavioural decoding targets for train and eval trials, read from the NLB
# eval-target file (20 ms bins).

data_path = DATA_ROOT.joinpath("mc_maze_input", "pos", "eval_target_20ms.h5")
with h5py.File(data_path, "r") as h5f:
    train_behavior = h5f["mc_maze_20/train_behavior"][()]
    eval_behavior = h5f["mc_maze_20/eval_behavior"][()]

In [None]:
# Decode from observations
# Ridge regression from inferred rates (Xs_filt) to behaviour, pooling all time
# bins.  Fit on the first 400 training trials; print R^2 on the remaining ones.


def flatten2d(arr):
    """Merge all leading axes: (..., n_features) -> (n_samples, n_features)."""
    return arr.reshape(-1, arr.shape[-1])


n_fit_trials = 400
rate_decoder = Ridge(alpha=1e-6)
rate_decoder.fit(
    flatten2d(Xs_filt[:n_fit_trials]),
    flatten2d(train_behavior[:n_fit_trials]),
)

print(
    rate_decoder.score(
        flatten2d(Xs_filt[n_fit_trials:]), flatten2d(train_behavior[n_fit_trials:])
    )
)

In [None]:
# Decode from latents
# Same ridge decoder, but from the filtered posterior latents.  It is fit on the
# first 400 trials and scored on the rest, matching the rate-decoder cell, so the
# two printed R^2 values are comparable.  The previous in-sample score was optimistic.

n_fit_trials = 400
latent_decoder = Ridge(alpha=1e-6)

latent_decoder.fit(
    flatten2d(Qzs_filt[:n_fit_trials]),
    flatten2d(train_behavior[:n_fit_trials]),
)

print(
    latent_decoder.score(
        flatten2d(Qzs_filt[n_fit_trials:]), flatten2d(train_behavior[n_fit_trials:])
    )
)

In [None]:
# Decode behaviour for every training trial and integrate it over time
# (20 ms bins, hence / 50) to obtain positions — presumably the decoded
# quantity is velocity.
nlb_behavior_all = rate_decoder.predict(flatten2d(Xs_filt))
nlb_behavior_all = nlb_behavior_all.reshape(*Xs_filt.shape[:-1], -1)
nlb_position_all = nlb_behavior_all.cumsum(axis=1) / 50

In [None]:
# Make panel b
# Positions decoded from the inferred rates of the training trials.  Each
# direction bin shows the condition mean and up to 10 single trials.  The axis
# limits are stored in dec_xlim / dec_ylim for the generated-reach panel.

fig, axs = plt.subplots(1, 1, figsize=(1.0, 1.0))

angle_bins = np.unique(angles)  # sorted direction bins
# One row per direction bin (len(angles) would be the number of trials).
nlb_position_mean = np.empty((len(angle_bins), *nlb_position_all.shape[1:]))
for angle_i, angle in enumerate(angle_bins):
    nlb_position_mean[angle_i] = nlb_position_all[angles == angle].mean(axis=0)

for angle_i, angle in enumerate(angle_bins):
    axs.plot(
        nlb_position_mean[angle_i, :, 0],
        nlb_position_mean[angle_i, :, 1],
        color=plt.cm.hsv(angle),
        alpha=0.6,
    )

    nlb_position_cond = nlb_position_all[angles == angle]
    for trial in range(min(nlb_position_cond.shape[0], 10)):
        axs.plot(
            nlb_position_cond[trial, :, 0],
            nlb_position_cond[trial, :, 1],
            color=plt.cm.hsv(angle),
            alpha=0.3,
            linewidth=0.6,
        )
axs.set_xticks([])
axs.set_yticks([])
axs.set_xlabel("x position")
axs.set_ylabel("y position")
axs.set_title("decoded reaches")

# Trim 10% of the autoscaled x-range from each side.
xlim = axs.get_xlim()
xrange = xlim[1] - xlim[0]
axs.set_xlim(xlim[0] + 0.1 * xrange, xlim[1] - 0.1 * xrange)

dec_xlim = axs.get_xlim()
dec_ylim = axs.get_ylim()

plt.savefig("../figures/decoded_inf_reaches.svg")
plt.savefig("../figures/decoded_inf_reaches.pdf")

In [None]:
# Map each repeat of the generated latents through the observation model and
# keep the first output of get_observation (presumably rates — TODO confirm).
device = "cuda:0"

# 182 is hard-coded: presumably the number of recorded neurons (dim_x) — confirm.
x_all = np.empty((*z_all.shape[:-1], 182))
with torch.no_grad():
    for i, z_repeat in enumerate(z_all):
        z_in = (
            torch.tensor(z_repeat, dtype=torch.float, device=device)
            .unsqueeze(-1)
            .permute(0, 2, 1, 3)
        )
        x = vae.rnn.get_observation(z=z_in)[0].cpu().numpy()
        x_all[i] = x.transpose(0, 2, 1, 3).squeeze()

In [None]:
# Decode behaviour from the generated rates with the decoder fit on inferred rates.
behavior_all = rate_decoder.predict(flatten2d(x_all))
behavior_all = behavior_all.reshape(*x_all.shape[:-1], -1)

In [None]:
# Integrate over time (axis 2: repeats, targets, time, dims) in 20 ms bins.
position_all = behavior_all.cumsum(axis=2) / 50

In [None]:
# Make panel e
# Positions decoded from generated (prior) rates for every strd-th target
# direction: up to 10 sampled repeats (thin lines) plus their mean.  Limits are
# shared with the inferred-reach panel via dec_xlim / dec_ylim.

position_mean = position_all.mean(axis=0)
# Colours defined locally (same definition as the prior-latent panels) rather
# than relying on whichever cell last assigned prop_cycle.
target_colors = [plt.cm.hsv(i) for i in np.arange(0, 1, 1 / len(targets))]
fig, axs = plt.subplots(1, 1, figsize=(1.0, 1.0))
strd = 2
for targ_i in range(0, len(targets), strd):

    for trial in range(min(position_all.shape[0], 10)):
        axs.plot(
            position_all[trial, targ_i, :, 0],
            position_all[trial, targ_i, :, 1],
            color=target_colors[targ_i],
            alpha=0.3,
            linewidth=0.6,
        )

    # The loop already steps by strd, so every visited target gets its mean.
    axs.plot(
        position_mean[targ_i, :, 0],
        position_mean[targ_i, :, 1],
        color=target_colors[targ_i],
        alpha=0.6,
    )
axs.set_xticks([])
axs.set_yticks([])
axs.set_xlabel("x position")
axs.set_ylabel("y position")
axs.set_title("decoded reaches")


axs.set_xlim(*dec_xlim)
axs.set_ylim(*dec_ylim)

plt.savefig("../figures/decoded_gen_reaches.svg")
plt.savefig("../figures/decoded_gen_reaches.pdf")

In [None]:
# Make panel c: spike statistics

In [None]:
# Simulate the generative model on the eval-set inputs, one trial at a time.
# Each simulation starts from the first-time-bin posterior latent of the
# training trial with the same index.
device = "cuda:0"
u = eval_dataset.stim.to(device)
initial_states = Qzs_filt[:, 0, :]
noise_scale = 1
np.random.seed(1)
torch.manual_seed(1)

# Initial states come from training trials but are indexed by eval-trial
# number, so there must be at least as many training trials as eval trials.
if len(initial_states) < u.shape[0]:
    raise ValueError(
        f"Need >= {u.shape[0]} initial states, got {len(initial_states)} training trials"
    )

# dur is the simulation length set in the prior-latent cell.
z_sim = np.zeros((u.shape[0], dim_z, dur))
# Pure sampling: no autograd graph is needed.
with torch.no_grad():
    for ri in range(u.shape[0]):
        trial_input = u[ri][None, :, :]  # avoid shadowing the builtin `input`
        z0 = torch.tensor(initial_states[ri], dtype=torch.float, device=device)[
            None,
            :,
            None,
        ]
        z = (
            vae.rnn.get_latent_time_series(
                dur, u=trial_input, z0=z0, noise_scale=noise_scale
            )[0]
            .cpu()
            .numpy()
        )
        z_sim[ri] = z.squeeze()

In [None]:
# Obtain sampled spikes
# Keep the second output of get_observation for each simulated latent trajectory
# (presumably spike counts sampled from the observation model — TODO confirm).

# 182 is hard-coded: presumably the number of recorded neurons (dim_x) — confirm.
x_sim = np.empty((z_sim.shape[0], 182, dur))
with torch.no_grad():
    for i, z_trial in enumerate(z_sim):
        z_in = (
            torch.tensor(z_trial, dtype=torch.float, device=device)
            .unsqueeze(-1)
            .unsqueeze(0)
        )
        x = vae.rnn.get_observation(z=z_in)[1].cpu().numpy()
        x_sim[i] = x.squeeze()
sampled_spikes = x_sim

In [None]:
# Train stats per condition
# For each direction bin and each neuron (axis 1 of spikes), compute:
#   - the mean and SD of the rate (x50 converts 20 ms bin counts to per-second);
#   - the ISI statistics returned by calc_isi_stats_per_trial.

spikes = train_dataset.data.detach().cpu().numpy()
target_pos = train_dataset.stim[:, :, 0].detach().cpu().numpy()
angles = np.arctan2(target_pos[:, 1], target_pos[:, 0]) / (2 * np.pi) + 0.5
angles = np.mod(np.round(angles * 8), 8) / 8

cond_shape = (len(np.unique(angles)), spikes.shape[1])
sr_mean_cond, sr_std_cond, isi_mean_cond, isi_std_cond, isi_cv_cond = (
    np.empty(cond_shape) for _ in range(5)
)
for i, angle in enumerate(np.unique(angles)):
    cond_spikes = spikes[angles == angle]
    sr_mean_cond[i] = cond_spikes.mean(axis=(0, 2)) * 50
    sr_std_cond[i] = cond_spikes.std(axis=(0, 2)) * 50
    isi_cv, isi_mean, isi_std = calc_isi_stats_per_trial(
        cond_spikes.transpose(0, 2, 1), dt=0.02
    )
    isi_mean_cond[i] = isi_mean.T
    isi_std_cond[i] = isi_std.T
    isi_cv_cond[i] = isi_cv.T

In [None]:
# Test stats per condition
# Same per-bin, per-neuron rate and ISI statistics for the eval (test) split.

test_spikes = eval_dataset.data.detach().cpu().numpy()
test_target_pos = eval_dataset.stim[:, :, 0].detach().cpu().numpy()
test_angles = (
    np.arctan2(test_target_pos[:, 1], test_target_pos[:, 0]) / (2 * np.pi) + 0.5
)
test_angles = np.mod(np.round(test_angles * 8), 8) / 8

test_cond_shape = (len(np.unique(test_angles)), test_spikes.shape[1])
(
    test_sr_mean_cond,
    test_sr_std_cond,
    test_isi_mean_cond,
    test_isi_std_cond,
    test_isi_cv_cond,
) = (np.empty(test_cond_shape) for _ in range(5))
for i, angle in enumerate(np.unique(test_angles)):
    cond_spikes = test_spikes[test_angles == angle]
    test_sr_mean_cond[i] = cond_spikes.mean(axis=(0, 2)) * 50
    test_sr_std_cond[i] = cond_spikes.std(axis=(0, 2)) * 50
    test_isi_cv, test_isi_mean, test_isi_std = calc_isi_stats_per_trial(
        cond_spikes.transpose(0, 2, 1), dt=0.02
    )
    test_isi_mean_cond[i] = test_isi_mean.T
    test_isi_std_cond[i] = test_isi_std.T
    test_isi_cv_cond[i] = test_isi_cv.T

In [None]:
# Sampled stats per condition
# Same statistics for the model-sampled spikes.  sampled_spikes follows eval-trial
# order, so trials are selected with test_angles while iterating over the
# training direction bins.

inf_cond_shape = (len(np.unique(angles)), spikes.shape[1])
(
    inf_sr_mean_cond,
    inf_sr_std_cond,
    inf_isi_mean_cond,
    inf_isi_std_cond,
    inf_isi_cv_cond,
) = (np.empty(inf_cond_shape) for _ in range(5))
for i, angle in enumerate(np.unique(angles)):
    cond_spikes = sampled_spikes[test_angles == angle]
    inf_sr_mean_cond[i] = cond_spikes.mean(axis=(0, 2)) * 50
    inf_sr_std_cond[i] = cond_spikes.std(axis=(0, 2)) * 50
    inf_isi_cv, inf_isi_mean, inf_isi_std = calc_isi_stats_per_trial(
        cond_spikes.transpose(0, 2, 1), dt=0.02
    )
    inf_isi_mean_cond[i] = inf_isi_mean.T
    inf_isi_std_cond[i] = inf_isi_std.T
    inf_isi_cv_cond[i] = inf_isi_cv.T

In [None]:
# Indices of the 30 neurons with the highest mean training count, highest first.
top_fr = np.argsort(spikes.mean(axis=(0, 2)))[-30:][::-1]

In [None]:
def correlation_per_condition(spikes, angles, neuron_inds):
    """Compute pairwise correlations between selected neurons, separately per condition.

    For each unique value of ``angles`` (ascending), the trials of that condition
    are pooled over axis 2.  ``calculate_correlation`` is then applied to the
    resulting (samples, neurons) matrix, and the strictly upper-triangular
    entries are kept.

    Args:
        spikes: array (trials, neurons, time); axis 2 is pooled with trials.
        angles: per-trial condition labels, length ``trials``.
        neuron_inds: indices into axis 1 of ``spikes``.

    Returns:
        Array (n_conditions, n * (n - 1) / 2) with n = len(neuron_inds).
    """
    upper = np.triu_indices(len(neuron_inds), k=1)
    per_condition = []
    for condition in np.unique(angles):
        selected = spikes[angles == condition][:, neuron_inds]
        n_selected = selected.shape[1]
        samples = np.swapaxes(selected, 1, 2).reshape(-1, n_selected)
        per_condition.append(calculate_correlation(samples)[upper])
    return np.stack(per_condition, axis=0)

In [None]:
# Pairwise correlations among the top_fr neurons per direction bin, for the
# training data, the test data and the model samples (sampled spikes use the
# eval-trial condition labels).
pwc_cond = correlation_per_condition(spikes, angles, top_fr)
test_pwc_cond = correlation_per_condition(test_spikes, test_angles, top_fr)
inf_pwc_cond = correlation_per_condition(sampled_spikes, test_angles, top_fr)

In [None]:
# Make supplementary figure with stats
# Grid layout: one row per direction bin present in the test set, one column per
# statistic.  The x-axis is the test-set value.  The y-axis is the model-sampled
# value (teal) or the training-set value (firebrick) for the same neuron or pair.

true_vals = [
    test_sr_mean_cond,
    test_sr_std_cond,
    test_isi_mean_cond,
    test_isi_std_cond,
    test_pwc_cond,
]
train_vals = [sr_mean_cond, sr_std_cond, isi_mean_cond, isi_std_cond, pwc_cond]
inf_vals = [
    inf_sr_mean_cond,
    inf_sr_std_cond,
    inf_isi_mean_cond,
    inf_isi_std_cond,
    inf_pwc_cond,
]
labels = ["mean rates", "SD rates", "mean ISI", "SD ISI", "pairwise corr."]
# Row labels in degrees, derived from the bins actually present in the test set
# (bin value a in [0, 1) <-> a * 360 - 180 degrees).  This avoids a hand-written
# list that must match the data.
angle_labels = [f"{a * 360 - 180:g}" for a in np.unique(test_angles)]
n_rows = test_sr_mean_cond.shape[0]

np.random.seed(0)
fig, axs = plt.subplots(n_rows, len(labels), figsize=(5, n_rows), squeeze=False)
for cond in range(n_rows):
    # Ascending bins are drawn bottom-to-top.
    pidx = n_rows - cond - 1
    for col, (true, train, inf, label) in enumerate(
        zip(true_vals, train_vals, inf_vals, labels)
    ):
        ax = axs[pidx][col]
        all_x = np.concatenate([true[cond], true[cond]])
        all_y = np.concatenate([inf[cond], train[cond]])
        colors = ["teal"] * len(inf[cond]) + ["firebrick"] * len(train[cond])
        # Random draw order so neither colour systematically covers the other.
        perm = np.random.permutation(len(all_x))
        colors_perm = [colors[j] for j in perm]
        # Pairwise-correlation points are drawn fainter.
        point_alpha = 0.3 if "pairwise corr." in label else 0.7
        ax.scatter(
            all_x[perm], all_y[perm], color=colors_perm, alpha=point_alpha, linewidths=0
        )
        if "mean rates" in label:
            ax.set_xscale("log")
            ax.set_yscale("log")
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        # Identity line.
        ax.plot(
            np.arange(-0.5, 50.0), np.arange(-0.5, 50.0), linestyle="--", color="gray"
        )
        if "mean rates" in label:
            lo, hi = min(xlim[0], ylim[0]), max(xlim[1], ylim[1])
            ax.set_xlim(lo, hi)
            ax.set_ylim(lo, hi)
            ax.set_xticks([0.1, 1, 10])
            ax.set_yticks([0.1, 1, 10])
            ax.set_xticklabels(["0.1", "1", "10"])
            ax.set_yticklabels(["0.1", "1", "10"])
            ax.get_xaxis().set_tick_params(which="minor", size=0)
            ax.get_xaxis().set_tick_params(which="minor", width=0)
            ax.get_yaxis().set_tick_params(which="minor", size=0)
            ax.get_yaxis().set_tick_params(which="minor", width=0)

        elif "pairwise corr." in label:
            ax.set_xlim(-0.15, 0.28)
            ax.set_ylim(-0.15, 0.28)
        elif "SD rates" in label:
            ax.set_xlim(0, 50)
            ax.set_ylim(0, 50)
        elif "mean ISI" in label:
            ax.set_xlim(0, 0.5)
            ax.set_ylim(0, 0.5)
        elif "SD ISI" in label:
            ax.set_xlim(0, 0.3)
            ax.set_ylim(0, 0.3)
        ax.set_yticklabels([])
    # "\\circ" escaped explicitly: "\c" is an invalid string escape sequence.
    axs[pidx][0].set_ylabel(f"angle$={angle_labels[cond]}^{{\\circ}}$\ngen/train")
for col, label in enumerate(labels):
    axs[0][col].set_title(label)
    axs[-1][col].set_xlabel("test")
axs[0][0].text(0.022, 25.0, "test/gen", color="teal", alpha=0.7, fontsize=6)
axs[0][0].text(0.022, 10.0, "test/train", color="firebrick", alpha=0.7, fontsize=6)
plt.tight_layout()

plt.savefig("../figures/spike_stat_cond_scatter.png", dpi=300)
plt.savefig("../figures/spike_stat_cond_scatter.svg", dpi=300)
plt.savefig("../figures/spike_stat_cond_scatter.pdf", dpi=300)

In [None]:
# Fig c top
# Condition-by-condition correlation distance (1 - Pearson r) of the
# per-neuron mean firing rates, for test data ("true") and model samples ("model").

fig, axs = plt.subplots(1, 2, figsize=(2.0, 1.0))
n_cond = sr_mean_cond.shape[0]

vmin, vmax = 0, 0.3

for ax, cond_rates, title in zip(
    axs, (test_sr_mean_cond, inf_sr_mean_cond), ("true", "model")
):
    ax.imshow(1 - np.corrcoef(cond_rates), vmin=vmin, vmax=vmax, cmap="Blues")
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("condition")
axs[0].set_ylabel("condition")

plt.savefig("../figures/cond_spike_corr.svg")
plt.savefig("../figures/cond_spike_corr.pdf")

In [None]:
# Stand-alone colour bar for the firing-rate dissimilarity matrices (same vmin/vmax).
fig, axs = plt.subplots(figsize=(0.25 / 3, 3 / 3))
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
cmap = mpl.cm.Blues

cb1 = mpl.colorbar.ColorbarBase(axs, cmap=cmap, norm=norm, orientation="vertical")
cb1.set_label("corr. distance")
cb1.set_ticks([0.0, 0.15, 0.3])

plt.savefig("../figures/fr_colorbar.svg")
plt.savefig("../figures/fr_colorbar.pdf")

In [None]:
# Fig c bottom
# Condition-by-condition correlation distance (1 - Pearson r) of the per-neuron
# mean ISI, for test data ("true") and model samples ("model").  A neuron is
# excluded if its mean ISI is NaN in any condition of the train, model or test stats.

fig, axs = plt.subplots(1, 2, figsize=(2.0, 1.0))
n_cond = isi_mean_cond.shape[0]

vmin, vmax = 0.0, 1.0

nan_mask = np.logical_or.reduce(
    [
        np.isnan(stat).any(axis=0)
        for stat in (isi_mean_cond, inf_isi_mean_cond, test_isi_mean_cond)
    ]
)

for ax, isi_means, title in zip(
    axs, (test_isi_mean_cond, inf_isi_mean_cond), ("true", "model")
):
    ax.imshow(
        1 - np.corrcoef(isi_means[:, ~nan_mask]),
        vmin=vmin,
        vmax=vmax,
        cmap="Blues",
    )
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("condition")
axs[0].set_ylabel("condition")

plt.savefig("../figures/cond_isi_corr.svg")
plt.savefig("../figures/cond_isi_corr.pdf")

In [None]:
# Median absolute difference between true and model mean-ISI correlation
# distances over all distinct condition pairs (strict upper triangle).
pair_idx = np.triu_indices(n_cond, k=1)
upper_diag_elem_test = 1 - np.corrcoef(test_isi_mean_cond[:, ~nan_mask])[pair_idx]
upper_diag_elem_inf = 1 - np.corrcoef(inf_isi_mean_cond[:, ~nan_mask])[pair_idx]
np.median(np.abs(upper_diag_elem_test - upper_diag_elem_inf))

In [None]:
# Stand-alone colour bar for the mean-ISI dissimilarity matrices (same vmin/vmax).
fig, axs = plt.subplots(figsize=(0.25 / 3, 3 / 3))
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
cmap = mpl.cm.Blues

cb1 = mpl.colorbar.ColorbarBase(axs, cmap=cmap, norm=norm, orientation="vertical")
cb1.set_label("corr. distance")
cb1.set_ticks([0.0, 0.5, 1.0])

plt.savefig("../figures/isi_colorbar.svg")
plt.savefig("../figures/isi_colorbar.pdf")