In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib qt

import random
from pathlib import Path

from tqdm import tqdm
import numpy as np
import scipy as sp
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
# import pandas as pd
# pd.options.display.width = 1000

import os, sys
sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))

from vrAnalysis import analysis
from vrAnalysis import helpers
from vrAnalysis import database
from vrAnalysis import tracking
from vrAnalysis import session
from vrAnalysis import registration
from vrAnalysis import fileManagement as fm
from vrAnalysis import faststats as fs

sessiondb = database.vrDatabase('vrSessions')
mousedb = database.vrDatabase('vrMice')

# pd.set_option('display.max_rows', 100)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [17]:
mousedb = database.vrDatabase("vrMice")
df = mousedb.getTable(trackerExists=True)
mouse_names = df["mouseName"].unique()
keep_planes = [1]

# Get data for a single mouse
mouse_name = "CR_Hippocannula6"
track = tracking.tracker(mouse_name)  # get tracker object for mouse
pcm = analysis.placeCellMultiSession(track, autoload=False, keep_planes=keep_planes)

In [24]:
envnum, idx_ses = pcm.env_idx_ses_selector(envmethod="second", sesmethod=6)
envnum_first = pcm.env_selector(envmethod="first")
idx_ses_first = pcm.env_stats()[envnum_first]

idx_ses = sorted(list(set(idx_ses_first) & set(idx_ses)))
spkmaps, extras = pcm.get_spkmaps(envnum=envnum, idx_ses=idx_ses, trials="full", average=False, tracked=True)
_, extras_first = pcm.get_spkmaps(envnum=envnum_first, idx_ses=idx_ses, trials="full", average=True, tracked=True)

avgmaps = [np.nanmean(s, axis=1) for s in spkmaps]
avgcenters = [np.nanargmax(s, axis=1) for s in avgmaps]
avgmax = [np.nanmax(s, axis=1) for s in avgmaps]
centers = np.stack(avgcenters, axis=0)
maxes = np.stack(avgmax, axis=0)

In [25]:
idx_red = np.logical_or(np.any(np.stack(extras["idx_red"]), axis=0), np.any(np.stack(extras_first["idx_red"]), axis=0))
ctl_reliability = [np.nanmean(relcor[~idx_red]) for relcor in extras["relcor"]]
ctl_reliability_first = [np.nanmean(relcor[~idx_red]) for relcor in extras_first["relcor"]]
red_reliability = [np.nanmean(relcor[idx_red]) for relcor in extras["relcor"]]
red_reliability_first = [np.nanmean(relcor[idx_red]) for relcor in extras_first["relcor"]]

fig, ax = plt.subplots(2, 1, figsize=(8, 7), layout="constrained")
ax[0].plot(range(len(ctl_reliability)), ctl_reliability, color="k", linewidth=1)
ax[0].plot(range(len(red_reliability)), red_reliability, color="r", linewidth=1)
ax[0].set_xlabel("Environment")
ax[0].set_ylabel("Reliability")
ax[0].set_title("Novel Environment")
ax[1].plot(range(len(ctl_reliability_first)), ctl_reliability_first, color="k", linewidth=1)
ax[1].plot(range(len(red_reliability_first)), red_reliability_first, color="r", linewidth=1)
ax[1].set_xlabel("Environment")
ax[1].set_ylabel("Reliability")
ax[1].set_title("Familiar Environment")
plt.show()

In [107]:
xvals = np.tile(np.arange(centers.shape[0]).reshape(-1, 1), (1, centers.shape[1]))
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.scatter(xvals, (centers - centers[0]), c=centers, alpha=maxes / np.nanmax(maxes), cmap="Spectral")
plt.show()

In [104]:
xvals.shape

(1, 3576)

In [88]:
idx_target_ses = len(idx_ses) - 1
reliability = [rel for rel in extras["relcor"]]

def make_roi_trajectory(spkmaps, roi_idx, dead_trials=5):
    roi_activity = [s[roi_idx] for s in spkmaps]
    dead_space = [np.full((dead_trials, roi_activity[0].shape[1]), np.nan) for _ in range(len(roi_activity) - 1)]
    dead_space.append(None)
    interleaved = [item for pair in zip(roi_activity, dead_space) for item in pair if item is not None]

    trial_env = [ises*np.ones(r.shape[0]) for ises, r in enumerate(roi_activity)]
    dead_trial_env = [np.nan*np.ones(dead_trials) for _ in range(len(roi_activity) - 1)]
    dead_trial_env.append(None)
    env_trialnum = [item for pair in zip(trial_env, dead_trial_env) for item in pair if item is not None]
    return np.concatenate(interleaved, axis=0), np.concatenate(env_trialnum)
    
# Choose an ROI with the top Xth percentile reliability
def choose_roi(reliability, percentile=90):
    threshold = np.percentile(reliability, percentile)
    roi_idx = np.random.choice(np.where(reliability[idx_target_ses] > threshold)[0])
    return roi_idx

def gather_idxs(reliability_values, min_percentile=90, max_percentile=100):
    min_threshold = np.percentile(reliability_values, min_percentile)
    max_threshold = np.percentile(reliability_values, max_percentile)
    return np.where((reliability_values > min_threshold) & (reliability_values < max_threshold))[0]

idxs = gather_idxs(reliability[idx_target_ses], min_percentile=90, max_percentile=100)
roi_idx = 0
true_roi_idx = idxs[roi_idx]

print(len(idxs))
print(roi_idx)
print(true_roi_idx)

# roi_idx = choose_roi(reliability, percentile=95)
dead_trials = 5
roi_trajectory, env_trialnum = make_roi_trajectory(spkmaps, true_roi_idx, dead_trials=dead_trials)

def com(data, axis=-1):
    x = np.arange(data.shape[axis])
    com = np.sum(data * x, axis=axis) / (np.sum(data, axis=axis) + 1e-10)
    com[np.any(data < 0, axis=axis)] = np.nan
    return com

idx_not_nan = ~np.any(np.isnan(roi_trajectory), axis=1)
pfmax = np.where(idx_not_nan, np.max(roi_trajectory, axis=1), np.nan)
pfcom = np.where(idx_not_nan, com(roi_trajectory, axis=1), np.nan)
pfloc = np.where(idx_not_nan, np.argmax(roi_trajectory, axis=1), np.nan)


cmap = mpl.colormaps["gray_r"]
cmap.set_bad((1, 0.8, 0.8))  # Light red color
ses_col = plt.cm.Set1(np.linspace(0, 1, len(idx_ses)))

plt.rcParams["font.size"] = 14

fig = plt.figure(1, figsize=(8, 9))
fig.clf()
ax = fig.add_subplot(131)
ax.cla()
ax.imshow(roi_trajectory, aspect="auto", interpolation="none", cmap=cmap, vmin=0, vmax=10)
ax.set_title(f"ROI Activity")

idx_trials_target = np.where(env_trialnum == idx_target_ses)[0]
min_y = np.nanmin(idx_trials_target)
max_y = np.nanmax(idx_trials_target)
plt.plot([1, 1], [min_y, max_y], color=ses_col[idx_target_ses], linestyle="-", linewidth=5)
plt.plot([roi_trajectory.shape[1]-1, roi_trajectory.shape[1]-1], [min_y, max_y], color=ses_col[idx_target_ses], linestyle="-", linewidth=5)
ax.set_xlim(0, roi_trajectory.shape[1])
ax.text(0, (min_y + max_y) / 2, f"Target environment", color=ses_col[idx_target_ses], ha="right", va="center", rotation=90)
ax.set_ylabel("Trial")
ax.set_yticks([])
ax.set_xlabel("Virtual Position")

alpha_values = np.where(~np.isnan(pfmax), pfmax / np.nanmax(pfmax), 0)
ax = fig.add_subplot(132)
ax.cla()
ax.scatter(pfcom, range(len(pfcom)), color="k", s=10, alpha=alpha_values, linewidth=2)
ax.scatter(pfloc, range(len(pfloc)), color="r", s=10, alpha=alpha_values, linewidth=2)
ax.scatter([-10], [-10], color="k", s=10, alpha=0.5, linewidth=2, label="CoM")
ax.scatter([-10], [-10], color="r", s=10, alpha=0.5, linewidth=2, label="MaxLoc")
ax.set_xlim(0, roi_trajectory.shape[1])
ax.set_ylim(roi_trajectory.shape[0], 0)
ax.legend(loc="upper center")
ax.set_yticks([])
ax.set_title("PF Location")
ax.set_xlabel("Virtual Position")

ax = fig.add_subplot(133)
ax.cla()
# ax.plot(pfmax, range(len(pfmax)), color="k", linewidth=1, label="Peak Amplitude")
ax.scatter(pfmax, range(len(pfmax)), color="k", s=10)
ax.set_ylim(roi_trajectory.shape[0], 0)
ax.set_title("PF Amplitude")
ax.set_yticks([])
ax.set_xlabel("Activity (sigma)")
plt.show()


89
0
8


In [86]:
idxs

array([  3,   9,  16,  22,  40,  51,  52,  59,  60,  61,  72,  80,  86,
        91,  94,  95, 106, 108, 109, 110, 127, 140, 156, 161, 167, 184,
       197, 199, 201, 203, 231, 243, 250, 280, 288, 289, 298, 300, 301,
       307, 315, 327, 330, 336, 337, 343, 345, 360, 378, 412, 422, 448,
       452, 465, 468, 476, 484, 486, 496, 527, 547, 549, 557, 558, 569,
       577, 620, 623, 633, 654, 662, 664, 665, 668, 696, 712, 713, 714,
       716, 719, 729, 737, 747, 761, 777, 803, 805, 861, 862], dtype=int64)

In [33]:
len(env_trialnum)
len(roi_trajectory)

264

In [31]:
min_y, max_y

(204, 248)

In [7]:
np.stack((env, env == idx_target_ses), axis=0).T

array([[ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,