In [None]:
import os
import sys

# Append the parent of the current working directory to the import path.
# NOTE(review): presumably this is what lets the `vi_rnn` and `evaluation`
# imports below resolve — confirm. It depends on the kernel's working
# directory, not on the location of this notebook file.
file_dir = os.getcwd()
sys.path.append(file_dir + "/../")
import torch
import numpy as np
from vi_rnn.saving import load_model
from scipy.signal import welch
from evaluation.calc_stats import calc_isi_stats, calculate_correlation
from vi_rnn.saving import load_model
from vi_rnn.utils import get_orth_proj_latents
from vi_rnn.generate import generate
from vi_rnn.inference import filtering_posterior
from scipy.signal import welch
from scipy.signal.windows import hann
from scipy.stats import zscore
from scipy import signal, ndimage
from sklearn.linear_model import Ridge, LinearRegression
import matplotlib.pyplot as plt
import matplotlib as mpl

%matplotlib inline

# we are using data from: https://crcns.org/data-sets/hc/hc-11/about-hc-11
# Grosmark, A.D., and Buzsáki, G. (2016). Diversity in neural firing dynamics supports both rigid and learned hippocampal sequences. Science 351, 1440–1443.
# Chen, Z., Grosmark, A.D., Penagos, H., and Wilson, M.A. (2016). Uncovering representations of sleep-associated hippocampal ensemble spike activity. Sci. Rep. 6, 32193.

In [None]:
# Reproducibility: numpy and torch share one seed.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Sampling frequency: passed to welch() as `fs` and used to turn per-bin
# spike means into rates (and as dt = 1 / fs for the ISI statistics).
fs = 40

In [None]:
# Restore the trained model together with its training / task parameters.
filename = "../models/hpc11/_CNN_causal_PLRNN_Z_Date_122024_06_24_T_17_54_22"
vae, training_params, task_params = load_model(filename, backward_compat=True)

# NOTE(review): assumes `.eval()` returns the model itself (torch.nn.Module
# convention), as the re-assignment relies on — confirm.
vae = vae.eval()

# `dim_z` is reused below as the latent rank; print it as a sanity check.
rank = vae.dim_z
print(rank)

In [None]:
# load spiking data
# The shape unpacking below implies both arrays are 2-D; axis 0 is taken as
# `dim_x` and axis 1 as the number of time bins (`T` / `T_train`).
train_data = np.load("../data_untracked/train_hpc_11_i25.npy")
test_data = np.load("../data_untracked/test_hpc_11_i25.npy")
dim_x, T = test_data.shape
_, T_train = train_data.shape
print("Spiking train data shape: ", train_data.shape)
print("Spiking test data shape: ", test_data.shape)
# load locations (used further down as regression targets for the latents)
test_locs = np.load("../data_untracked/testloc.npy")
train_locs = np.load("../data_untracked/trainloc.npy")
print("Locations test shape: ", test_locs.shape)
print("Locations train shape: ", train_locs.shape)

# load lfp
# Both LFP arrays are transposed on load; downstream code z-scores along
# axis 1 and averages over axis 0.
# NOTE(review): presumably stored as (time, channel) on disk — confirm.
lfp_run = np.load("../data_untracked/run_maze.npy").T
print("LFP run shape: ", lfp_run.shape)
lfp_norun = np.load("../data_untracked/norun_maze.npy").T
print("LFP norun shape: ", lfp_norun.shape)

In [None]:
# Generate a latent trajectory and spikes from the model. generate() is handed
# the first 1000 test bins as `x` and is called with cut_off=1000.
# NOTE(review): presumably the 1000 conditioning steps are dropped from the
# returned trajectory — confirm against vi_rnn.generate.
x_context = torch.from_numpy(test_data)[:, :1000].float().unsqueeze(0)
prior_Z, _, spikes_pred, lam = generate(
    vae,
    x=x_context,
    dur=T,
    cut_off=1000,
    initial_state="posterior_sample",
)
spikes_pred = spikes_pred[0, :, :, 0].T.detach().numpy()

# project latents on orthogonalised connectivity and normalise (z-score along axis 1)
projection_matrix = get_orth_proj_latents(vae)
prior_Z = zscore(
    (projection_matrix @ prior_Z[0, :, :, 0]).detach().numpy(),
    axis=1,
)

In [None]:
# Filtering posterior over the latents, conditioned on the test spikes
k = 128  # number of particles
x_test = torch.from_numpy(test_data).float().unsqueeze(0)
# `u` has size 0 along dim 1 (NOTE(review): presumably "no external input" — confirm)
u_test = torch.zeros(1, 0, test_data.shape[1])
_, post_Z, _ = filtering_posterior(
    vae,
    x=x_test,
    u=u_test,
    k=k,
    resample="systematic",
)

# average over particles, then project on the orthogonalised connectivity
post_Z = projection_matrix @ post_Z[0].mean(axis=-1)
# normalise (z-score along axis 1)
post_Z = zscore(post_Z.detach().numpy(), axis=1)

In [None]:
# Filtering posterior over the latents, conditioned on the training spikes
k = 128  # number of particles
x_train = torch.from_numpy(train_data).float().unsqueeze(0)
# `u` has size 0 along dim 1 (NOTE(review): presumably "no external input" — confirm)
u_train = torch.zeros(1, 0, train_data.shape[1])
_, post_Z_train, _ = filtering_posterior(
    vae,
    x=x_train,
    u=u_train,
    k=k,
    resample="systematic",
)

# average over particles, then project on the orthogonalised connectivity
post_Z_train = projection_matrix @ post_Z_train[0].mean(axis=-1)
# normalise (z-score along axis 1)
post_Z_train = zscore(post_Z_train.detach().numpy(), axis=1)

In [None]:
# Power spectral density of the LFP in the run / no-run recordings.
nperseg = 1024

# z-score each row along axis 1, then average the rows into a single trace
lfp_run_pr, lfp_norun_pr = (
    zscore(lfp, axis=1).mean(axis=0) for lfp in (lfp_run, lfp_norun)
)

# Welch estimate for each averaged trace
(f_lfp_run, psd_lfp_run), (f_lfp_norun, psd_lfp_norun) = (
    welch(trace, fs=fs, nperseg=nperseg) for trace in (lfp_run_pr, lfp_norun_pr)
)

In [None]:
# Welch power spectral density of the first four prior latents
nperseg = 400

(
    (f_pZ1, psd_pZ1),
    (f_pZ2, psd_pZ2),
    (f_pZ3, psd_pZ3),
    (f_pZ4, psd_pZ4),
) = (welch(prior_Z[row], fs=fs, nperseg=nperseg) for row in range(4))

In [None]:
# Ridge regression from smoothed latents to location

# Smoothing kernel: 301-sample Hann window.
# NOTE: the window is not normalised to unit sum, so the smoothed latents are
# also rescaled by the window's sum.
window_l = 301
window = hann(window_l)

# Transpose each latent array and smooth it along axis 0
X, X_test, X_prior = (
    ndimage.convolve1d(latents.T, window, axis=0)
    for latents in (post_Z_train, post_Z, prior_Z)
)

# Fit on the training posterior latents against the training locations
y = train_locs
model = Ridge()
model.fit(X, y)

# Decode location from the train / test posterior and from the generated latents
y_pred_train = model.predict(X)
y_pred = model.predict(X_test)
y_prior = model.predict(X_prior)

# Held-out R^2 against the test locations
r2 = model.score(X_test, test_locs)

print("R2 score:", r2)

In [None]:
# Time axis in seconds, one entry per sample of `y_prior`. Uses the shared
# sampling frequency `fs` from the config cell rather than a second hard-coded
# 40, so the axis cannot drift out of sync with the PSD / rate computations.
t_pr = np.arange(len(y_prior)) / fs

In [None]:
# Decoded location (from test posterior latents) against the recorded location
pr = "#9BB5DE"
rat = "#2B3073"
dur = 20 * 40
legend_labels = ["predicted", "rat"]
legend_colors = [pr, rat]

fig, ax = plt.subplots(1, 1, figsize=(1.5, 1))
ax.plot(t_pr, y_pred, label="predicted", color=pr, alpha=0.7)
ax.plot(t_pr, test_locs, label="rat", color=rat, alpha=0.7)

# Text-only legend: handles are collapsed to zero length and each entry is
# coloured like the line it names.
legend = ax.legend(
    legend_labels,
    loc="upper right",
    bbox_to_anchor=(1.12, 0.6),
    fancybox=True,
    handlelength=0,
    handletextpad=0,
)
for text, color in zip(legend.get_texts(), legend_colors):
    text.set_color(color)

ax.set_xlabel("time (s)")
ax.set_yticks([])

In [None]:
# Mark "running" samples: magnitude of the Hann-smoothed gradient of the
# generated location above a fixed threshold.
smoothed_gradient = ndimage.convolve1d(np.gradient(y_prior), window)
inds = np.abs(smoothed_gradient) > 1  # 0.75
# NOTE: the printed value is a fraction in [0, 1], not a percentage.
print("percentage running", np.sum(inds) / len(inds))

In [None]:
# PSD of the generated latents (running vs. stationary samples) next to the LFP PSDs

nperseg = 400
color1 = "teal"
color2 = "#A860AF"

# For every latent, estimate the PSD separately on the samples flagged by
# `inds` and on the remaining ones, then average across latents.
mean_psd_run = []
mean_psd_stat = []
for i in range(len(prior_Z)):
    for mask, collected in ((inds, mean_psd_run), (~inds, mean_psd_stat)):
        f_pZ, psd_pZ = welch(prior_Z[i, mask], fs=fs, nperseg=nperseg)
        collected.append(psd_pZ)
mean_psd_run = np.mean(mean_psd_run, axis=0)
mean_psd_stat = np.mean(mean_psd_stat, axis=0)

fig, ax = plt.subplots(1, 1, figsize=(3, 3))
line1 = ax.semilogy(
    f_pZ, mean_psd_run, color=color1, label="running", alpha=0.9, zorder=0
)[0]
line2 = ax.semilogy(
    f_pZ, mean_psd_stat, color=color2, label="stationary", alpha=0.9, zorder=0
)[0]
line3 = ax.semilogy(
    f_lfp_run, psd_lfp_run, color="black", label="LFP", alpha=0.6, zorder=0
)[0]
line4 = ax.semilogy(
    f_lfp_norun, psd_lfp_norun, color="grey", label="LFP", alpha=0.6, zorder=0
)[0]

ax.set(
    xlim=[0, 20],
    ylim=[10**-4, 1],
    title="psd",
    xlabel="frequency (hz)",
)
ax.tick_params(axis="y", which="both", width=1)
ax.set_yticks([0.001, 0.1])
ax.set_yticklabels(["0.001", "0.1"])
ax.set_xticks([0.2, 10, 20])

# custom legend handles. note: legends were manually adjusted using illustrator afterwards to include the LFP signal
legend_labels = ["run", "stat", "$LFP run$", "$LFP stat$"]
legend_colors = [color1, color2, "black", "grey"]
legend = ax.legend(
    legend_labels,
    loc="upper right",
    bbox_to_anchor=(1, 1),
    fancybox=True,
    handlelength=0,
    handletextpad=0,
)
for text, color in zip(legend.get_texts(), legend_colors):
    text.set_color(color)

In [None]:
# Mean firing rate per unit: average count per bin times the sampling frequency.
# Test / train arrays are averaged over axis 1, the generated spikes (already
# transposed above) over axis 0.
fr_test = test_data.mean(axis=1) * fs
fr_train = train_data.mean(axis=1) * fs
fr_gen = spikes_pred.mean(axis=0) * fs

In [None]:
# Correlation matrices for test, generated and training spikes
# (test / train arrays are transposed before the call; spikes_pred is passed as is)
test_correlation, gen_correlation, train_correlation = (
    calculate_correlation(spikes)
    for spikes in (test_data.T, spikes_pred, train_data.T)
)

# Keep only the strict upper triangle (k=1 skips the diagonal)
i_upper = np.triu_indices(dim_x, k=1)
test_corr_values, gen_corr_values, train_corr_values = (
    corr[i_upper]
    for corr in (test_correlation, gen_correlation, train_correlation)
)

In [None]:
# Inter-spike-interval statistics (three return values per data set, unpacked
# as CV, mean and std) with a bin width of 1 / fs
(
    (CVs_isi_test, Means_isi_test, Std_isi_test),
    (CVs_isi_gen, Means_isi_gen, Std_isi_gen),
    (CVs_isi_train, Means_isi_train, Std_isi_train),
) = (
    calc_isi_stats(spikes, dt=1 / fs)
    for spikes in (test_data.T, spikes_pred, train_data.T)
)

In [None]:
# Make plots
#
# Builds the 2x4 summary figure: decoded location, generated location, example
# latents (top row) and mean rates, ISI CVs, pairwise correlations and PSDs
# (bottom row), then saves it as PNG and PDF.


def _scatter_test_vs_gen_train(
    ax, test_vals, gen_vals, train_vals, gen_color, train_color, train_label=None
):
    """Scatter generated and training statistics against the test statistics.

    For every entry of ``test_vals`` two dots are drawn on ``ax``: one against
    the matching entry of ``gen_vals`` (in ``gen_color``) and one against the
    matching entry of ``train_vals`` (in ``train_color``). Each dot receives its
    own zorder taken from a shuffled range, so neither colour systematically
    covers the other. The shuffle uses numpy's global random state.

    The number of dots is always ``len(test_vals)``; ``train_label``, when
    given, is attached to every training dot.
    """
    n_dots = len(test_vals)
    zorders = np.arange(n_dots * 2)
    np.random.shuffle(zorders)
    train_kwargs = {} if train_label is None else {"label": train_label}
    for i in range(n_dots):
        ax.scatter(
            test_vals[i],
            gen_vals[i],
            s=10,
            alpha=0.7,
            color=gen_color,
            zorder=zorders[i],
        )
        ax.scatter(
            test_vals[i],
            train_vals[i],
            s=10,
            alpha=0.7,
            color=train_color,
            zorder=zorders[i + n_dots],
            **train_kwargs,
        )


color1 = "#7B46C1"
color2 = "#A860AF"
color3 = "#7C277D"
color4 = "#8A44A4"
# four-colour cycle repeated four times (up to 16 latents)
colors = [color1, color2, color3, color4] * 4

tg = "teal"  # test vs. generated
tr = "firebrick"  # test vs. train
cmap = plt.get_cmap("tab20b")
cmap2 = plt.get_cmap("Dark2")

with mpl.rc_context(fname="matplotlibrc"):

    # Create a figure with specified size
    fig, axes = plt.subplots(2, 4, figsize=(6, 3))
    fig.subplots_adjust(hspace=0.2, wspace=0.6)

    ax1 = axes[0, 0]
    ax2 = axes[0, 2]
    ax3 = axes[0, 1]
    ax4 = axes[1, 0]
    ax5 = axes[1, 1]
    ax6 = axes[1, 2]
    ax7 = axes[1, 3]
    ax8 = axes[0, 3]

    # extract the width gap
    gap = ax5.get_position().x0 - ax4.get_position().x1

    # manually adjust the locations of wider plots: the first three top-row
    # axes are widened and laid out left to right, each one starting after the
    # previous (already widened) one plus a scaled gap
    width_scaling_factor = 1.6
    new_positions = []

    for i, ax in enumerate(axes[0, :3]):
        pos = ax.get_position()
        if i == 0:
            new_positions.append(
                [pos.x0, pos.y0, pos.width * width_scaling_factor, pos.height]
            )
        else:
            new_positions.append(
                [
                    new_positions[i - 1][0] + new_positions[i - 1][2] + gap * 4 / 2.1,
                    pos.y0,
                    pos.width * width_scaling_factor,
                    pos.height,
                ]
            )

    for i, ax in enumerate(axes[0, :3]):
        ax.set_position(new_positions[i])

    # window of the generated latents shown in the "latents" panel
    sec = 20
    init = 750
    # number of samples in `sec` seconds; derived from the shared sampling
    # frequency instead of a second hard-coded 40
    duration = int(sec * fs)

    # data prediction
    pr = "#9BB5DE"
    rat = "#2B3073"
    ax1.plot(t_pr, y_pred, color=pr, alpha=0.7, label="predicted")
    ax1.plot(t_pr, test_locs, color=rat, alpha=0.7, label="rat")
    legend_labels = ["predicted", "rat"]
    legend_colors = [pr, rat]
    # The legend must be requested from ax1 explicitly: in a multi-axes figure
    # plt.legend() targets the current axes (not ax1), so the labels would not
    # be attached to the two lines plotted above.
    legend = ax1.legend(
        legend_labels,
        handletextpad=0,
        handlelength=0,
        fancybox=True,
        loc="upper right",
        bbox_to_anchor=(1.12, 0.6),
    )
    for text, color in zip(legend.get_texts(), legend_colors):
        text.set_color(color)
    ax1.set_xlabel("time (s)")
    ax1.set_xlim(0, t_pr[-1])
    ax1.set_yticks([])
    ax1.set_title("predicted location")

    # generated location (decoded from the generated latents)
    ax3.set_title("generated location")
    ax3.plot(t_pr, y_prior, color=pr, alpha=1)
    ax3.set_xlabel("time (s)")
    ax3.set_yticks([])
    ax3.set_xlim(0, t_pr[-1])

    # latents: each trace is shifted vertically by 3 per latent index
    t = np.linspace(0, sec, duration)
    for i in range(len(prior_Z)):
        ax2.plot(
            t,
            prior_Z[i][init : init + duration] + (i - 1) * 3,
            alpha=0.9,
            # label=f"Z{i}",
            color=colors[i],
        )

    ax2.set_xlim(0, sec)
    ax2.set_yticks([])
    ax2.set_xticks([0, 10, 20])
    ax2.set_xlabel("time (s)")
    ax2.set_title("latents")

    # coefficient of variation
    _scatter_test_vs_gen_train(
        ax5,
        CVs_isi_test,
        CVs_isi_gen,
        CVs_isi_train,
        gen_color=tg,
        train_color=tr,
        train_label="train",
    )

    ax5.plot([0, 4], [0, 4], color="gray", linestyle="--", zorder=0)
    ax5.set_title("cv ISIs")
    ax5.set_xticks([0, 4])
    ax5.set_yticks([0, 4])
    ax5.set_xlabel("test")
    ax5.set_xlim([0, 4])
    ax5.set_ylim([0, 4])

    # mean rates
    ax4.plot(
        np.linspace(0, 5, 2),
        np.linspace(0, 5, 2),
        color="gray",
        linestyle="--",
        zorder=0,
    )
    _scatter_test_vs_gen_train(
        ax4, fr_test, fr_gen, fr_train, gen_color=tg, train_color=tr
    )

    ax4.set_title("mean rates (hz)")
    ax4.set_xticks([1, 5])
    ax4.set_yticks([1, 5])
    ax4.set_xlabel("test")
    ax4.set_ylabel("gen / train")

    # custom legend labels and colors
    legend_labels = ["test/gen", "test/train"]
    legend_colors = [tg, tr]

    # add custom legend
    legend_elements = [
        plt.Line2D([0], [0], color=color, lw=0, label=label)
        for color, label in zip(legend_colors, legend_labels)
    ]
    legend = ax4.legend(
        handles=legend_elements,
        handletextpad=0,
        handlelength=0,
        fancybox=True,
        loc="upper right",
        bbox_to_anchor=(1.16, 0.45),
        fontsize=6,
    )

    for text, color in zip(legend.get_texts(), legend_colors):
        text.set_color(color)

    # pairwise correlation
    # All upper-triangle pairs are plotted. The dot count comes from
    # len(test_corr_values) inside the helper, not from the (smaller) number of
    # units left over from the mean-rates panel.
    ax6.plot([-0.1, 0.2], [-0.1, 0.2], color="gray", linestyle="--", zorder=0)
    _scatter_test_vs_gen_train(
        ax6,
        test_corr_values,
        gen_corr_values,
        train_corr_values,
        gen_color=tg,
        train_color=tr,
        train_label="train",
    )

    ax6.tick_params(axis="x", which="both", width=1)
    ax6.tick_params(axis="y", which="both", width=1)
    ax6.set_title("pairwise corr.")
    ax6.set_xlabel("test")
    ax6.set_yticks([-0.1, 0.2])
    ax6.set_yticklabels(["-0.1", "0.2"])
    ax6.set_xticks([-0.1, 0.2])
    ax6.set_xticklabels(["-0.1", "0.2"])
    ax6.set_xlim([-0.1, 0.2])
    ax6.set_ylim([-0.1, 0.2])

    # PSD (the colour names are re-bound here for this panel only)
    color1 = "#7B46C1"
    color2 = "hotpink"
    color3 = "teal"
    color4 = "lightblue"
    (line1,) = ax7.semilogy(
        f_pZ, mean_psd_run, alpha=0.9, zorder=10, label="running", color=color1
    )

    # Plot stationary PS
    (line2,) = ax7.semilogy(
        f_pZ, mean_psd_stat, alpha=0.9, zorder=0, label="stationary", color=color2
    )

    (line3,) = ax7.semilogy(
        f_lfp_run, psd_lfp_run, color=color3, alpha=1, zorder=3, label="LFP"
    )
    (line4,) = ax7.semilogy(
        f_lfp_norun, psd_lfp_norun, color=color4, alpha=1, zorder=4, label="LFP norun"
    )
    ax7.set_xlim([0, 20])
    ax7.set_ylim([10**-4, 1])
    ax7.set_title("psd")
    ax7.set_xlabel("frequency (hz)")
    ax7.tick_params(axis="y", which="both", width=1)

    ax7.set_yticks([0.001, 0.1])
    ax7.set_yticklabels(["0.001", "0.1"])
    ax7.set_xticks([0.2, 10, 20])

    legend_labels = [
        "model running",
        "model static",
        "LFP rat running",
        "LFP rat static",
    ]

    legend_colors = [color1, color2, color3, color4]

    legend = ax7.legend(
        legend_labels,
        handletextpad=0,
        handlelength=0,
        fancybox=True,
        loc="upper right",
        bbox_to_anchor=(2, 1),
    )
    for text, color in zip(legend.get_texts(), legend_colors):
        text.set_color(color)

    # the fourth top-row slot stays empty
    ax8.axis("off")
    fig.set_size_inches(5.2, 3)

    ax1.set_box_aspect(0.625)
    ax2.set_box_aspect(0.625)
    ax3.set_box_aspect(0.625)
    ax7.set_box_aspect(1)
    ax4.set_box_aspect(1)
    ax5.set_box_aspect(1)
    ax6.set_box_aspect(1)
    ax8.set_box_aspect(1)

    plt.savefig("../figures/hpc11_full.png", dpi=300)
    plt.savefig("../figures/hpc11_full.pdf", dpi=300)

    plt.show()