In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib qt

import random
from pathlib import Path

from tqdm import tqdm
import numpy as np
import scipy as sp
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
# import pandas as pd
# pd.options.display.width = 1000

import os, sys
sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))

from vrAnalysis import analysis
from vrAnalysis import helpers
from vrAnalysis import database
from vrAnalysis import tracking
from vrAnalysis import session
from vrAnalysis import registration
from vrAnalysis import fileManagement as fm
from vrAnalysis import faststats as fs

sessiondb = database.vrDatabase('vrSessions')
mousedb = database.vrDatabase('vrMice')

# pd.set_option('display.max_rows', 100)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [2]:
# I'd like a plot of the number of red cells (including number reliable)
# for each grouping of sessions. 
mouse_name = "CR_Hippocannula6"
track = tracking.tracker(mouse_name)
pcm = analysis.placeCellMultiSession(track, autoload=False, keep_planes=[1, 2, 3, 4])

In [5]:
envnum, idx_ses = pcm.env_idx_ses_selector(envmethod="second", sesmethod=4)
spkmaps, extras = pcm.get_spkmaps(envnum=envnum, idx_ses=idx_ses, trials="full", average=False, tracked=True)

In [9]:

print(spkmaps[0].shape)
print(extras.keys())

relmethod = "relloo"
top_rel = [np.where(rel > np.nanpercentile(rel, 90))[0] for rel in extras[relmethod]]

idx_ses_plot = 2

fig, ax = plt.subplots()

(1954, 75, 243)
dict_keys(['relmse', 'relcor', 'relloo', 'idx_red', 'pfloc', 'pfidx', 'roi_idx'])


In [46]:
from syd.interactive_viewer import InteractiveViewer
from syd.notebook_deploy import NotebookDeployment

class MyViewer(InteractiveViewer):
    def __init__(self, spkmaps, extras):
        self.spkmaps = spkmaps
        self.extras = extras

        self.add_integer("idx_ses_plot", min_value=0, max_value=len(self.spkmaps), default=0)
        self.add_selection("idx_roi_plot", options=list(range(self.spkmaps[0].shape[0])), default=0)
        self.add_float("vmax", min_value=0.1, max_value=30, default=10)
        self.add_integer("trial", min_value=0, max_value=self.spkmaps[0].shape[1], default=0)
        self.add_boolean("weight_as_alpha", default=True)
        self.add_integer_pair("reliability_range", min_value=0, max_value=100, default=(90, 100))
        self.add_selection("relmethod", options=["relloo", "relcor"], default="relloo")

        self.on_change("idx_ses_plot", self.update_ranges)
        self.on_change("reliability_range", self.update_ranges)
        self.on_change("relmethod", self.update_ranges)

        self.update_ranges(self.get_state())

    def update_ranges(self, state):
        idx_ses_plot = state["idx_ses_plot"]
        min_rel, max_rel = state["reliability_range"]
        relmethod = state["relmethod"]
        min_percentile = np.nanpercentile(self.extras[relmethod][idx_ses_plot], min_rel)
        max_percentile = np.nanpercentile(self.extras[relmethod][idx_ses_plot], max_rel)
        idx_roi_options = np.where(np.logical_and(self.extras[relmethod][idx_ses_plot] >= min_percentile, self.extras[relmethod][idx_ses_plot] <= max_percentile))[0]

        self.update_selection("idx_roi_plot", options=list(idx_roi_options), default=idx_roi_options[0])
        self.update_integer("trial", min_value=0, max_value=self.spkmaps[idx_ses_plot].shape[1])

    def get_trial_consistency(self, roi_activity):
        num_trials = roi_activity.shape[0]
        trial_consistency = np.full(num_trials, np.nan)
        trial_weight = np.sqrt(np.mean(roi_activity**2, axis=1))
        for trial in range(num_trials):
            all_but_trial_activity = np.delete(roi_activity, trial, axis=0)
            all_but_trial_average = np.nanmean(all_but_trial_activity, axis=0)
            trial_consistency[trial] = helpers.vectorCorrelation(all_but_trial_average, roi_activity[trial], axis=0)
        return trial_consistency, trial_weight

    def plot(self, state):
        idx_ses_plot = state["idx_ses_plot"]
        idx_roi_plot = state["idx_roi_plot"]
        vmax = state["vmax"]
        trial = state["trial"]
        weight_as_alpha = state["weight_as_alpha"]

        trial_consistency, trial_weight = self.get_trial_consistency(self.spkmaps[idx_ses_plot][idx_roi_plot])
        trial_weight = trial_weight / np.nanmax(trial_weight)

        loo_average = np.delete(self.spkmaps[idx_ses_plot][idx_roi_plot], trial, axis=0)
        loo_weights = np.sqrt(np.mean(loo_average**2, axis=1, keepdims=True))
        loo_average = np.nanmean(loo_average*loo_weights, axis=0)

        # Colormap based on trial consistency
        cmap = plt.get_cmap("viridis_r")
        colors = cmap(trial_consistency)

        fig, ax = plt.subplots(1, 3, figsize=(12, 5), layout="constrained")
        cb = ax[0].imshow(self.spkmaps[idx_ses_plot][idx_roi_plot], cmap="gray_r", vmin=0, vmax=vmax, aspect="auto", interpolation="none")
        ax[0].axhline(y=trial, color="g", linewidth=0.5)
        ax[0].set_title(f"ROI {idx_roi_plot} in Session {idx_ses_plot}")
        ax[0].set_xlabel("Virtual Position")
        ax[0].set_ylabel("Trials")
        cbar = plt.colorbar(cb, ax=ax[0])
        cbar.set_label(r"Activity Level ($\sigma$)")

        for ii in range(trial_consistency.shape[0]):
            alpha = trial_weight[ii] if weight_as_alpha else 1
            ax[1].plot(self.spkmaps[idx_ses_plot][idx_roi_plot, ii], color=colors[ii], linewidth=1, alpha=alpha)

        norm = plt.Normalize(np.min(trial_consistency), np.max(trial_consistency))
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        cbar = plt.colorbar(sm, ax=ax[1])
        cbar.set_label('Trial Consistency')

        ax[1].set_xlabel("Virtual Position")
        ax[1].set_ylabel("Activity")
        ax[1].set_title(f"All Trials - ROI LOO: {self.extras[relmethod][idx_ses_plot][idx_roi_plot]:.2f}")

        ax[2].plot(self.spkmaps[idx_ses_plot][idx_roi_plot, trial], color="k", linewidth=1, label="Selected Trial")
        ax[2].plot(loo_average, color="r", linewidth=1, label="Leave-One-Out Average")
        ax[2].legend()
        ax[2].set_xlabel("Virtual Position")
        ax[2].set_ylabel("Activity")
        ax[2].set_title(f"Trial {trial}: Consistency: {trial_consistency[trial]:.2f}, Weight: {trial_weight[trial]:.2f}")
        return fig
    

viewer = MyViewer(spkmaps, extras)
deployment = NotebookDeployment(viewer, controls_position="left")
deployment.deploy()

HBox(children=(VBox(children=(IntSlider(value=0, continuous_update=False, description='idx_ses_plot', layout=L…

reliability_range (50, 55) (50, 55)


IndexError: index 0 is out of bounds for axis 0 with size 0

reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


reliability_range (90, 100) (90, 100)


In [5]:
from hosting.placefield_singlecells.place_field_viewer import PlaceFieldViewer
viewer = PlaceFieldViewer(fast_mode=True)

['ATL022', 'ATL027']


100%|██████████| 8/8 [00:13<00:00,  1.67s/it]0:00<?, ?it/s]
100%|██████████| 8/8 [00:10<00:00,  1.30s/it]0:18<00:18, 18.61s/it]
Preparing mouse data: 100%|██████████| 2/2 [00:33<00:00, 16.76s/it]


In [15]:
fig = viewer.get_plot("ATL022", 1, min_percentile=90, max_percentile=100, idx_target_ses=0, dead_trials=5, red_cells=True)
fig.show()

In [14]:
idxs = viewer._gather_idxs("ATL022", 0, 0, 100, red_cells=False)
idxs_redcells = viewer._gather_idxs("ATL022", 0, 00, 100, red_cells=True)

print(len(idxs), len(idxs_redcells))

863 3


In [12]:
red_only = [i for i in idxs_redcells if i not in idxs]
not_red = [i for i in idxs if i not in idxs_redcells]

print(len(red_only))
print(len(not_red))


0
18


In [16]:
fig = viewer.get_plot("ATL022", 15, min_percentile=0, max_percentile=100, idx_target_ses=0, dead_trials=5, rel_loo=True, fig_number=1)
fig = viewer.get_plot("ATL022", 89, min_percentile=0, max_percentile=100, idx_target_ses=0, dead_trials=5, rel_loo=True, fig_number=2)

In [3]:
from hosting.placefield_reliability.reliability_viewer import ReliabilityViewer

In [4]:
viewer = ReliabilityViewer(fast_mode=True)

['CR_Hippocannula6' 'CR_Hippocannula7']


Preparing mouse data: 100%|██████████| 2/2 [00:28<00:00, 14.05s/it]


In [5]:
fig = viewer.get_plot("CR_Hippocannula6", use_relcor=False, tracked=False, average=False, min_session=None, max_session=None)
fig.show()

In [13]:
print(viewer.idx_ses_first["CR_Hippocannula6"])
print(viewer.idx_ses_second["CR_Hippocannula6"])
print(viewer.rel_idx_ses_first["CR_Hippocannula6"])
print(viewer.rel_idx_ses_second["CR_Hippocannula6"])


[0, 1, 2, 4, 5, 6]
[1, 2, 3, 4, 5, 6]
[-1, 0, 1, 3, 4, 5]
[0, 1, 2, 3, 4, 5]


In [2]:
mousedb = database.vrDatabase("vrMice")
df = mousedb.getTable(trackerExists=True)
mouse_names = df["mouseName"].unique()
keep_planes = [1]

# Get data for a single mouse
mouse_name = "CR_Hippocannula6"
track = tracking.tracker(mouse_name)  # get tracker object for mouse
pcm = analysis.placeCellMultiSession(track, autoload=False, keep_planes=keep_planes)

In [3]:
envnum, idx_ses = pcm.env_idx_ses_selector(envmethod="second", sesmethod=6)
envnum_first = pcm.env_selector(envmethod="first")
idx_ses_first = pcm.env_stats()[envnum_first]

idx_ses = sorted(list(set(idx_ses_first) & set(idx_ses)))
spkmaps, extras = pcm.get_spkmaps(envnum=envnum, idx_ses=idx_ses, trials="full", average=False, tracked=True)
_, extras_first = pcm.get_spkmaps(envnum=envnum_first, idx_ses=idx_ses, trials="full", average=True, tracked=True)

avgmaps = [np.nanmean(s, axis=1) for s in spkmaps]
avgcenters = [np.nanargmax(s, axis=1) for s in avgmaps]
avgmax = [np.nanmax(s, axis=1) for s in avgmaps]
centers = np.stack(avgcenters, axis=0)
maxes = np.stack(avgmax, axis=0)

100%|██████████| 5/5 [00:13<00:00,  2.80s/it]


In [4]:
idx_red = np.logical_or(np.any(np.stack(extras["idx_red"]), axis=0), np.any(np.stack(extras_first["idx_red"]), axis=0))
ctl_reliability = [np.nanmean(relcor[~idx_red]) for relcor in extras["relcor"]]
ctl_reliability_first = [np.nanmean(relcor[~idx_red]) for relcor in extras_first["relcor"]]
red_reliability = [np.nanmean(relcor[idx_red]) for relcor in extras["relcor"]]
red_reliability_first = [np.nanmean(relcor[idx_red]) for relcor in extras_first["relcor"]]

fig, ax = plt.subplots(2, 1, figsize=(8, 7), layout="constrained")
ax[0].plot(range(len(ctl_reliability)), ctl_reliability, color="k", linewidth=1)
ax[0].plot(range(len(red_reliability)), red_reliability, color="r", linewidth=1)
ax[0].set_xlabel("Environment")
ax[0].set_ylabel("Reliability")
ax[0].set_title("Novel Environment")
ax[1].plot(range(len(ctl_reliability_first)), ctl_reliability_first, color="k", linewidth=1)
ax[1].plot(range(len(red_reliability_first)), red_reliability_first, color="r", linewidth=1)
ax[1].set_xlabel("Environment")
ax[1].set_ylabel("Reliability")
ax[1].set_title("Familiar Environment")
plt.show()

In [86]:
import numba as nb

def _avg_except_one(grand_average: np.ndarray, single_sample: np.ndarray, num_samples: int) -> np.ndarray:
    if grand_average.shape != single_sample.shape:
        raise ValueError("grand_average and single_sample must have the same shape")
    return (grand_average - (single_sample / num_samples)) * (num_samples / (num_samples - 1))

def reliability(spkmap: np.ndarray, weighted: bool = True) -> np.ndarray:
    """Measure reliability using leave-one-out method
    
    Spkmap is a 3d Array of shape (num_rois, num_trials, num_positions)
    """
    num_rois, num_trials = spkmap.shape[:2]
    trial_average = np.nanmean(spkmap, axis=1)
    trial_consistency = np.full((num_rois, num_trials), np.nan)
    for trial in range(num_trials):
        average_except_trial = _avg_except_one(trial_average, spkmap[:, trial], num_trials)
        trial_consistency[:, trial] = helpers.vectorCorrelation(average_except_trial, spkmap[:, trial], axis=1)
    # Use RMS activity on each trial as weights if requested
    weights = np.sqrt(np.mean(spkmap**2, axis=2)) if weighted else None
    score = np.average(trial_consistency, axis=1, weights=weights)
    return score, trial_consistency

rel_score, xx_trial_consistency = reliability(spkmaps[0])

@nb.njit(parallel=True, fastmath=True)
def _jit_reliability(spkmap: np.ndarray) -> np.ndarray:
    """Measure reliability using leave-one-out method"""
    num_rois, num_trials, num_positions = spkmap.shape
    trial_average = np.zeros((num_rois, num_positions))
    for roi in nb.prange(num_rois):
        for position in range(num_positions):
            trial_average[roi, position] = np.mean(spkmap[roi, :, position])
    
    trial_consistency = np.full((num_rois, num_trials), np.nan)
    for trial in nb.prange(num_trials):
        # First get the average excluding the current trial
        average_except_trial = (trial_average - (spkmap[:, trial] / num_trials)) * (num_trials / (num_trials - 1))

        for roi in range(num_rois):
            average_dev = average_except_trial[roi] - np.mean(average_except_trial[roi])
            average_std = np.sqrt(np.sum(average_dev**2) / (num_trials - 1))
            trial_dev = spkmap[roi, trial] - np.mean(spkmap[roi, trial])
            trial_std = np.sqrt(np.sum(trial_dev**2) / (num_trials - 1))
            if average_std == 0 or trial_std == 0:
                trial_consistency[roi, trial] = 0
            else:
                z_average_dev_norm = average_dev/average_std
                z_trial_dev_norm = trial_dev/trial_std
                z_product = np.sum(z_average_dev_norm * z_trial_dev_norm) / (num_trials - 1)
                trial_consistency[roi, trial] = z_product
    return trial_consistency

def jit_reliability(spkmap: np.ndarray, weighted: bool = True) -> np.ndarray:
    """Measure reliability using leave-one-out method"""
    trial_consistency = _jit_reliability(spkmap)
    # Use RMS activity on each trial as weights if requested
    weights = np.sqrt(np.mean(spkmap**2, axis=2)) if weighted else None
    score = np.average(trial_consistency, axis=1, weights=weights)
    return score, trial_consistency
 
rel_score, trial_consistency = reliability(spkmaps[0])
rel_score_jit, trial_consistency_jit = jit_reliability(spkmaps[0])

print(np.allclose(rel_score, rel_score_jit))
print(np.allclose(trial_consistency, trial_consistency_jit))
%timeit rel_score, trial_consistency = reliability(spkmaps[0])
%timeit rel_score_jit, trial_consistency_jit = jit_reliability(spkmaps[0])


True
True
419 ms ± 7.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
66.8 ms ± 1.02 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [70]:
plt.scatter(trial_consistency[0], trial_consistency_jit[0] - trial_consistency[0])
plt.show()


In [46]:
%timeit x = np.allclose(trial_average, np.nanmean(spkmaps[0], axis=1))
%timeit y = jit_reliability(spkmaps[0])
%timeit z = fs.nanmean(spkmaps[0], axis=1)

52.7 ms ± 1.39 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
6.19 ms ± 705 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
31.5 ms ± 1.27 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
