In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib qt

import random
from pathlib import Path

from tqdm import tqdm
import numpy as np
import scipy as sp
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
# import pandas as pd
# pd.options.display.width = 1000

import os, sys
sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))

from vrAnalysis import analysis
from vrAnalysis import helpers
from vrAnalysis import database
from vrAnalysis import tracking
from vrAnalysis import session
from vrAnalysis import registration
from vrAnalysis import fileManagement as fm
from vrAnalysis import faststats as fs

# Handles from `database.vrDatabase` for 'vrSessions' and 'vrMice', created in that order.
sessiondb, mousedb = (database.vrDatabase(name) for name in ('vrSessions', 'vrMice'))

# Use the GPU when torch can see one in this kernel, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda


In [2]:
# NOTE(review): mid-notebook import — consider moving it to the imports cell at the top.
from vrAnalysis.syd_viewers.placecell_across_session_rel import PlaceFieldLoader, get_viewer
# Build the data loader for the interactive viewer ("CR_Hippocannula6" is presumably a mouse name — confirm),
# restricted to planes 1-4 and, per sesmethod="all", presumably all sessions.
loader = PlaceFieldLoader("CR_Hippocannula6", keep_planes=[1, 2, 3, 4], sesmethod="all")

  gini_coefficient = (n + 1 - 2 * cum_vals / np.sum(x, axis=axis)) / n
  activity = x / np.sum(x, axis=axis, keepdims=True)
                                                                                     

In [3]:
# Wrap `loader` in the interactive viewer and deploy it (suppress_warnings=True presumably silences warnings raised during deploy).
viewer = get_viewer(loader).deploy(suppress_warnings=True)

HBox(children=(VBox(children=(VBox(children=(HTML(value='<b>Parameters</b>'), Dropdown(description='envoption'…

In [4]:

# NOTE(review): identical call to the one in the setup cell; this re-creation of `mousedb` is redundant.
mousedb = database.vrDatabase("vrMice")
# Rows of the mouse table with trackerExists=True (presumably mice with cross-session ROI tracking — confirm).
df = mousedb.getTable(trackerExists=True)
mouse_names = df["mouseName"].unique()  # NOTE(review): not used by any cell visible here
keep_planes = [1, 2]  # plane indices to keep (presumably imaging planes)

# Get data for a single mouse
mouse_name = "ATL027"
track = tracking.tracker(mouse_name)  # get tracker object for mouse
# autoload=False: presumably defers loading session data until requested — confirm in analysis.placeCellMultiSession.
pcm = analysis.placeCellMultiSession(track, autoload=False, keep_planes=keep_planes)

In [5]:
# Pick an environment and a set of its sessions (selection rules live in pcm.env_idx_ses_selector).
envnum, idx_ses = pcm.env_idx_ses_selector(envmethod="second", sesmethod=4)
# One spike map per selected session; average=False keeps single trials and tracked=True presumably restricts
# to ROIs tracked across sessions. Later cells treat axes as (roi, trial, position bin) — TODO confirm.
spkmaps, extras = pcm.get_spkmaps(envnum=envnum, idx_ses=idx_ses, trials="full", average=False, tracked=True)
print([s.shape for s in spkmaps])

                                                                                     

[(2160, 66, 195), (2160, 69, 195), (2160, 69, 195), (2160, 45, 195)]


In [14]:
# Reliability (helpers.reliability_loo, presumably leave-one-out) computed two ways:
# 1) all trials from all sessions pooled together (sessions concatenated along axis 1);
all_trial_reliability = helpers.reliability_loo(np.concatenate(spkmaps, axis=1), weighted=True)
# 2) each session collapsed to its NaN-ignoring trial average, stacked so every session acts as one "trial".
session_average_reliability = helpers.reliability_loo(np.stack([np.nanmean(spkmap, axis=1) for spkmap in spkmaps], axis=1), weighted=True)

In [52]:
# The last selected session is the reference ("target") session.
idx_target_ses = len(idx_ses) - 1
# ROI ordering from the target session's `pfidx` (presumably place-field position order — TODO confirm),
# applied identically to every session's spike maps and reliability values so ROI rows line up.
idx_sort = extras["pfidx"][idx_target_ses]
sorted_spkmaps = [smap[idx_sort] for smap in spkmaps]
sorted_reliability = [rel[idx_sort] for rel in extras["relcor"]]

def make_roi_trajectory(spkmaps, roi_idx, dead_trials=5):
    """Stack one ROI's trial-by-position activity from every session into one array.

    Sessions are stacked in order along the trial axis. ``dead_trials`` rows of
    NaN are inserted between consecutive sessions (not after the last one), so
    session boundaries are visible in an image plot.

    Parameters
    ----------
    spkmaps : iterable of arrays
        One array per session; ``spkmap[roi_idx]`` must be 2-D (trials, bins).
    roi_idx : int
        ROI to extract from every session.
    dead_trials : int
        Number of NaN separator rows between consecutive sessions.

    Returns
    -------
    trajectory : ndarray
        Row-wise concatenation of the per-session ROI maps and separator rows.
    env_trialnum : ndarray
        Float session index for every row of ``trajectory``; NaN on separator rows.
    """
    roi_maps = [smap[roi_idx] for smap in spkmaps]
    last_session = len(roi_maps) - 1

    rows, row_labels = [], []
    for ises, roi_map in enumerate(roi_maps):
        rows.append(roi_map)
        row_labels.append(ises * np.ones(roi_map.shape[0]))
        if ises == last_session:
            break
        # Separator width follows the first session's number of position bins.
        rows.append(np.full((dead_trials, roi_maps[0].shape[1]), np.nan))
        row_labels.append(np.nan * np.ones(dead_trials))

    return np.concatenate(rows, axis=0), np.concatenate(row_labels)
    
# Choose an ROI whose reliability in the target session is in the top Xth percentile
def choose_roi(sorted_reliability, percentile=90, target_ses=None, rng=None):
    """Randomly pick an ROI whose target-session reliability exceeds a percentile threshold.

    The threshold is computed from the target session's reliability values only,
    so ``percentile=95`` means "above the 95th percentile of ROIs in the target
    session" — the same population the candidates are drawn from. NaN
    reliabilities are ignored when computing the threshold and never selected.

    Parameters
    ----------
    sorted_reliability : sequence of 1-D arrays
        Per-session reliability values, one entry per ROI, same ROI order in every session.
    percentile : float
        Percentile (0-100) of target-session reliability that a candidate must exceed.
    target_ses : int, optional
        Index into ``sorted_reliability`` of the session to select from. Defaults to the
        notebook-level ``idx_target_ses``.
    rng : numpy.random.Generator, optional
        Generator for a reproducible choice; defaults to the global ``np.random`` state.

    Returns
    -------
    int
        Index of the chosen ROI.

    Raises
    ------
    ValueError
        If no ROI exceeds the threshold (e.g. all values are NaN, or percentile=100).
    """
    if target_ses is None:
        target_ses = idx_target_ses
    target_rel = np.asarray(sorted_reliability[target_ses], dtype=float)
    # nanpercentile: a single NaN reliability would otherwise make the threshold NaN
    # and silently empty the candidate set.
    threshold = np.nanpercentile(target_rel, percentile)
    candidates = np.flatnonzero(target_rel > threshold)
    if candidates.size == 0:
        raise ValueError(
            f"No ROI in session {target_ses} has reliability above the "
            f"{percentile}th percentile (threshold={threshold})."
        )
    chooser = np.random if rng is None else rng
    return chooser.choice(candidates)

# Pick a random ROI above the 95th reliability percentile (threshold rule defined in choose_roi)
# and stack its activity across all selected sessions.
# NOTE(review): no random seed is set in the visible cells, so a different ROI may be chosen on each run.
roi_idx = choose_roi(sorted_reliability, percentile=95)
dead_trials = 5  # NaN separator rows inserted between consecutive sessions
roi_trajectory, env_trialnum = make_roi_trajectory(sorted_spkmaps, roi_idx, dead_trials=dead_trials)

def com(data, axis=-1):
    """Center of mass of ``data`` along ``axis``, in index (bin) units.

    The index vector is reshaped so it broadcasts along ``axis`` for any axis,
    not just the last one.

    Parameters
    ----------
    data : array_like
        Weights per bin (e.g. activity per position bin); expected non-negative.
    axis : int
        Axis to take the center of mass along (default: last).

    Returns
    -------
    ndarray or numpy scalar
        ``sum(i * data_i) / (sum(data_i) + 1e-10)`` with ``axis`` reduced. Slices
        containing any negative value are NaN; an all-zero slice gives 0 (the 1e-10
        guard avoids division by zero). 1-D input returns a scalar.
    """
    data = np.asarray(data)
    n_bins = data.shape[axis]
    # Shape (1, ..., n_bins, ..., 1) so positions line up with `axis` when broadcasting.
    index_shape = [1] * data.ndim
    index_shape[axis] = n_bins
    positions = np.arange(n_bins).reshape(index_shape)
    center = np.sum(data * positions, axis=axis) / (np.sum(data, axis=axis) + 1e-10)
    has_negative = np.any(data < 0, axis=axis)
    # np.where (rather than item assignment) also works when the result is 0-d.
    return np.where(has_negative, np.nan, center)[()]

# Per-trial place-field summaries. Rows with any NaN (including the separator rows
# inserted by make_roi_trajectory) are reported as NaN.
idx_not_nan = ~np.isnan(roi_trajectory).any(axis=1)
pfmax = np.where(idx_not_nan, roi_trajectory.max(axis=1), np.nan)    # peak amplitude
pfcom = np.where(idx_not_nan, com(roi_trajectory, axis=1), np.nan)   # center of mass (bin units)
pfloc = np.where(idx_not_nan, roi_trajectory.argmax(axis=1), np.nan) # peak location (bin index)


# Three panels for the chosen ROI, one row per trial (sessions stacked top to bottom):
# trial-by-position activity, place-field location per trial, and peak amplitude per trial.
cmap = mpl.colormaps["gray_r"]  # registry lookup returns a copy, so set_bad does not alter the global map
cmap.set_bad((1, 0.8, 0.8))  # Light red color for NaN (separator) rows
ses_col = plt.cm.Set1(np.linspace(0, 1, len(idx_ses)))
alpha = 0.7  # NOTE(review): not used by this cell

n_trials, n_bins = roi_trajectory.shape
# imshow centers row i at y=i, so the image spans [-0.5, n_trials - 0.5]. Using the same
# limits on every panel keeps the line plots aligned with the image rows.
ylims = (n_trials - 0.5, -0.5)

fig = plt.figure(1, figsize=(6, 12))
fig.clf()
fig.suptitle(f"ROI {roi_idx}")

ax = fig.add_subplot(131)
ax.imshow(roi_trajectory, aspect="auto", interpolation="none", cmap=cmap, vmin=0, vmax=10)
ax.set_title("ROI Activity")
ax.set_xlabel("Position bin")
ax.set_ylabel("Trial (sessions stacked)")

# Bracket the target session's rows with thick bars at both edges of the image.
idx_trials_target = np.where(env_trialnum == idx_target_ses)[0]
min_y = np.nanmin(idx_trials_target)
max_y = np.nanmax(idx_trials_target)
ax.plot([1, 1], [min_y, max_y], color=ses_col[idx_target_ses], linestyle="-", linewidth=5)
ax.plot([n_bins - 1, n_bins - 1], [min_y, max_y], color=ses_col[idx_target_ses], linestyle="-", linewidth=5)
ax.set_xlim(0, n_bins)
ax.set_ylim(*ylims)
ax.text(0, (min_y + max_y) / 2, "Target environment", color=ses_col[idx_target_ses], ha="right", va="center", rotation=90)

ax = fig.add_subplot(132)
ax.plot(pfcom, range(n_trials), color="k", linewidth=2, label="Center of Mass")
ax.plot(pfloc, range(n_trials), color="r", linewidth=2, label="Peak Location")
ax.set_xlim(0, n_bins)  # same position axis as the image panel
ax.set_ylim(*ylims)
ax.set_xlabel("Position bin")
ax.legend()
ax.set_title("PF Location")

ax = fig.add_subplot(133)
ax.plot(pfmax, range(n_trials), color="b", linewidth=2, label="Peak Amplitude")
ax.set_ylim(*ylims)
ax.set_xlabel("Peak activity")
ax.legend()
plt.show()


In [33]:
# Every trajectory row (trials + separator rows) must carry a session label. A bare
# `len(env_trialnum)` on its own line would be discarded, because only a cell's last expression is displayed.
assert len(env_trialnum) == len(roi_trajectory)
len(roi_trajectory)

264

In [31]:
# First/last row index of the target session in the stacked trajectory.
# NOTE(review): the saved output below appears stale (it was executed before the plotting cell, In [31] < In [52],
# and looks like it excludes the dead_trials separator rows) — re-run to refresh.
min_y, max_y

(204, 248)

In [7]:
# Per-row session label next to a flag for "row belongs to the target session".
# Uses `env_trialnum` (returned by make_roi_trajectory); `env` is not defined by any cell in this notebook.
np.stack((env_trialnum, env_trialnum == idx_target_ses), axis=0).T

array([[ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,