In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib qt

import random
from pathlib import Path

from tqdm import tqdm
import numpy as np
import scipy as sp
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
# import pandas as pd
# pd.options.display.width = 1000

import os
import sys

# Put the parent of the current working directory on the import path so the
# local vrAnalysis package resolves.
sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))

from vrAnalysis import (
    analysis,
    helpers,
    database,
    tracking,
    session,
    registration,
    fileManagement as fm,
    faststats as fs,
)

# Handles to the session and mouse metadata databases.
sessiondb = database.vrDatabase('vrSessions')
mousedb = database.vrDatabase('vrMice')

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda


In [2]:
from hosting.placefield_reliability.reliability_viewer import ReliabilityViewer

In [3]:
# Build the reliability viewer; fast_mode is forwarded unchanged to the
# ReliabilityViewer constructor (its exact effect is defined in that class).
viewer = ReliabilityViewer(
    fast_mode=True,
)

['CR_Hippocannula6' 'CR_Hippocannula7']


Preparing mouse data: 100%|██████████| 2/2 [00:24<00:00, 12.32s/it]


In [5]:
# Reliability figure for a single mouse. The keyword values are passed straight
# through to viewer.get_plot; presumably min/max_session=None means "no session
# bound" — confirm against ReliabilityViewer.get_plot.
fig = viewer.get_plot(
    "CR_Hippocannula6",
    use_relcor=False,
    tracked=False,
    average=False,
    min_session=None,
    max_session=None,
)
fig.show()

In [13]:
# Print the viewer's four session-index lists for one mouse, one list per line.
# NOTE(review): the recorded output shows rel_idx_ses_first starting at -1; if that
# list is ever used to index a Python sequence, -1 silently wraps to the last
# element — confirm this is intended in ReliabilityViewer.
for _attr_name in ("idx_ses_first", "idx_ses_second", "rel_idx_ses_first", "rel_idx_ses_second"):
    print(getattr(viewer, _attr_name)["CR_Hippocannula6"])


[0, 1, 2, 4, 5, 6]
[1, 2, 3, 4, 5, 6]
[-1, 0, 1, 3, 4, 5]
[0, 1, 2, 3, 4, 5]


In [17]:
# Open the mouse database and list the mice whose rows pass trackerExists=True.
mousedb = database.vrDatabase("vrMice")
df = mousedb.getTable(trackerExists=True)
mouse_names = df["mouseName"].unique()

# Planes forwarded to placeCellMultiSession (presumably imaging-plane indices — confirm).
keep_planes = [1]

# Single-mouse analysis: tracker object, then the multi-session place-cell object.
# autoload=False is passed through unchanged; presumably it defers data loading — confirm.
mouse_name = "CR_Hippocannula6"
track = tracking.tracker(mouse_name)
pcm = analysis.placeCellMultiSession(
    track,
    autoload=False,
    keep_planes=keep_planes,
)

In [24]:
# Sessions of the "second" environment (sesmethod=6 is passed through to pcm;
# its meaning is defined there), restricted to sessions that also contain the
# "first" environment.
envnum, idx_ses = pcm.env_idx_ses_selector(envmethod="second", sesmethod=6)
envnum_first = pcm.env_selector(envmethod="first")
idx_ses_first = pcm.env_stats()[envnum_first]
idx_ses = sorted(set(idx_ses_first).intersection(idx_ses))

# Tracked, non-averaged maps for the second environment. For the first
# environment only the extras are kept (the averaged maps are discarded).
spkmaps, extras = pcm.get_spkmaps(envnum=envnum, idx_ses=idx_ses, trials="full", average=False, tracked=True)
_, extras_first = pcm.get_spkmaps(envnum=envnum_first, idx_ses=idx_ses, trials="full", average=True, tracked=True)

# Per session: average over axis 1 (assumed to be trials — TODO confirm), then take
# each row's peak position (argmax) and peak value along the remaining axis 1.
# NOTE(review): np.nanargmax raises ValueError if any row is entirely NaN.
avgmaps = [np.nanmean(spkmap, axis=1) for spkmap in spkmaps]
avgcenters = [np.nanargmax(avgmap, axis=1) for avgmap in avgmaps]
avgmax = [np.nanmax(avgmap, axis=1) for avgmap in avgmaps]
centers = np.stack(avgcenters)
maxes = np.stack(avgmax)

In [25]:
# Mean reliability ("relcor") of red-labelled vs. control ROIs for each entry of
# extras["relcor"] (presumably one per session in idx_ses), plotted for the novel
# environment (envnum) and the familiar one (envnum_first).
# An ROI counts as red if it is flagged in any entry of either environment.
idx_red = np.logical_or(
    np.any(np.stack(extras["idx_red"]), axis=0),
    np.any(np.stack(extras_first["idx_red"]), axis=0),
)
ctl_reliability = [np.nanmean(relcor[~idx_red]) for relcor in extras["relcor"]]
ctl_reliability_first = [np.nanmean(relcor[~idx_red]) for relcor in extras_first["relcor"]]
red_reliability = [np.nanmean(relcor[idx_red]) for relcor in extras["relcor"]]
red_reliability_first = [np.nanmean(relcor[idx_red]) for relcor in extras_first["relcor"]]

fig, ax = plt.subplots(2, 1, figsize=(8, 7), layout="constrained")
reliability_panels = (
    ("Novel Environment", ctl_reliability, red_reliability),
    ("Familiar Environment", ctl_reliability_first, red_reliability_first),
)
for panel_ax, (panel_title, panel_ctl, panel_red) in zip(ax, reliability_panels):
    # Label both traces so the figure is readable without the code.
    panel_ax.plot(range(len(panel_ctl)), panel_ctl, color="k", linewidth=1, label="control")
    panel_ax.plot(range(len(panel_red)), panel_red, color="r", linewidth=1, label="red")
    # x is a position in the per-session list, not an environment id.
    panel_ax.set_xlabel("Session index")
    panel_ax.set_ylabel("Reliability")
    panel_ax.set_title(panel_title)
    panel_ax.legend()
plt.show()