In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib qt

import random
from pathlib import Path

from tqdm import tqdm
import numpy as np
import scipy as sp
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
# import pandas as pd
# pd.options.display.width = 1000

from vrAnalysis import analysis
from vrAnalysis import helpers
from vrAnalysis import database
from vrAnalysis import tracking
from vrAnalysis import session
from vrAnalysis import registration
from vrAnalysis import fileManagement as fm

import faststats as fs

# from vrAnalysis.uiDatabase import addEntryGUI
from vrAnalysis.redgui import redCellGUI as rgui

from dimilibi import CrossCompare
from dimilibi import SVCANet, HurdleNet, BetaVAE
from dimilibi import Population
from dimilibi import SVCA
from dimilibi import PCA
from dimilibi import RidgeRegression, ReducedRankRegression
from dimilibi import LocalSimilarity, FlexibleFilter, EmptyRegularizer, BetaVAE_KLDiv

sessiondb = database.vrDatabase('vrSessions')
mousedb = database.vrDatabase('vrMice')

# pd.set_option('display.max_rows', 100)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [21]:
mouse_name = "ATL058"
track = tracking.tracker(mouse_name)
pcm = analysis.placeCellMultiSession(track, autoload=False, keep_planes=[1, 2, 3, 4], speedThreshold=1)
env_stats = pcm.env_stats()
print(env_stats)

envs = list(env_stats.keys())
first_session = [env_stats[env][0] for env in envs]
idx_first_session = np.argsort(first_session)

# use environment that was introduced second
use_environment = envs[idx_first_session[1]]
idx_ses = env_stats[use_environment][: min(12, len(env_stats[use_environment]))]

if len(idx_ses) < 2:
    # Attempt to use first environment if not enough sessions in second
    use_environment = envs[idx_first_session[0]]
    idx_ses = env_stats[use_environment][: min(12, len(env_stats[use_environment]))]

if len(idx_ses) < 2:
    print(f"Skipping {mouse_name} due to not enough sessions!")

print(use_environment, idx_ses)

{1: [0, 1, 2, 3, 4, 5, 6, 7, 8], 3: [1, 2, 3, 4, 5, 6, 7, 8], 4: [6, 7, 8]}
3 [1, 2, 3, 4, 5, 6, 7, 8]


In [22]:
envnum = use_environment
max_diff = 4
relcor_cutoff = 0.5
smooth = 10

bins_cor = np.linspace(-1, 1, 21)
bins_mse = np.linspace(-4, 1, 21)
centers_cor = helpers.edge2center(bins_cor)
centers_mse = helpers.edge2center(bins_mse)

def make_histograms(pcm, idx_ses):
    pcm.load_pcss_data(idx_ses=idx_ses)
    ctl_relcor = []
    red_relcor = []
    ctl_relmse = []
    red_relmse = []
    for idx in tqdm(idx_ses, desc="Measuring reliability...", unit="session"):
        relmse, relcor = map(lambda x: x[0], pcm.pcss[idx].get_reliability_values(envnum=envnum))
        idx_red = pcm.pcss[idx].vrexp.getRedIdx(keep_planes=pcm.keep_planes)
        ctl_relcor.append(helpers.fractional_histogram(relcor[~idx_red], bins_cor)[0])
        red_relcor.append(helpers.fractional_histogram(relcor[idx_red], bins_cor)[0])
        ctl_relmse.append(helpers.fractional_histogram(relmse[~idx_red], bins_mse)[0])
        red_relmse.append(helpers.fractional_histogram(relmse[idx_red], bins_mse)[0])

    return ctl_relcor, red_relcor, ctl_relmse, red_relmse

ctl_relcor, red_relcor, ctl_relmse, red_relmse = make_histograms(pcm, idx_ses)

100%|██████████| 8/8 [00:53<00:00,  6.64s/it]
Measuring reliability...: 100%|██████████| 8/8 [00:00<00:00, 319.99session/s]


In [23]:
fig, ax = plt.subplots(2, len(idx_ses), figsize=(15, 7), layout="constrained")
for idx, (cc, rc, cm, rm) in enumerate(zip(ctl_relcor, red_relcor, ctl_relmse, red_relmse)):
    ax[0, idx].plot(centers_cor, cc, linewidth=1.2, color="k", label="Ctl")
    ax[0, idx].plot(centers_cor, rc, linewidth=1.2, color="r", label="Red")
    ax[1, idx].plot(centers_mse, cm, linewidth=1.2, color="k", label="Ctl")
    ax[1, idx].plot(centers_mse, rm, linewidth=1.2, color="r", label="Red")
    ax[0, idx].set_title(f"Session {idx_ses[idx]}")
    ax[0, idx].set_xlabel("Correlation")
    ax[1, idx].set_xlabel("MSE")
    ax[0, idx].set_ylabel("Fraction")
    ax[1, idx].set_ylabel("Fraction")
plt.show()

In [2]:
mouse_name = "ATL060"
track = tracking.tracker(mouse_name)
pcm = analysis.placeCellMultiSession(track, autoload=False, keep_planes=[1, 2, 3, 4], speedThreshold=1)
env_stats = pcm.env_stats()
print(env_stats)

{1: [4, 5, 6], 3: [1, 2, 3, 4, 5, 6], 4: [0, 1, 2, 3, 4, 5, 6]}


In [7]:
envs = list(env_stats.keys())
first_session = [env_stats[env][0] for env in envs]
idx_first_session = np.argsort(first_session)

# use environment that was introduced second
use_environment = envs[idx_first_session[1]]
idx_ses = env_stats[use_environment][: min(12, len(env_stats[use_environment]))]

if len(idx_ses) < 2:
    # Attempt to use first environment if not enough sessions in second
    use_environment = envs[idx_first_session[0]]
    idx_ses = env_stats[use_environment][: min(12, len(env_stats[use_environment]))]

if len(idx_ses) < 2:
    print(f"Skipping {mouse_name} due to not enough sessions!")

print(envs, first_session, use_environment)

[1, 3, 4] [4, 1, 0] 3


In [9]:
envnum = use_environment
max_diff = 4
relcor_cutoff = 0.5
smooth = 10

def make_corrs(pcm, idx_ses, max_diff=None):
    pcm.load_pcss_data(idx_ses=idx_ses)
    idx_pairs = helpers.all_pairs(idx_ses)
    if max_diff is not None:
        idx_pairs = idx_pairs[np.abs(idx_pairs[:, 1] - idx_pairs[:, 0]) <= max_diff]
    ses_diffs = idx_pairs[:, 1] - idx_pairs[:, 0]
    ctlcorrs = []
    redcorrs = []
    for idx in tqdm(idx_pairs, desc="computing correlations", leave=False):
        spkmaps, _, relcor, _, _, idx_red, _ = pcm.get_spkmaps(envnum, trials="full", pop_nan=True, smooth=smooth, average=True, idx_ses=idx)
        any_rel = np.any(np.stack(relcor) > relcor_cutoff, axis=0)
        any_red = np.any(np.stack(idx_red), axis=0)
        print(idx, np.sum(any_rel), np.sum(any_red), np.sum(any_rel & any_red))
        ctlmaps = [s[~any_red & any_rel] for s in spkmaps]
        redmaps = [s[any_red & any_rel] for s in spkmaps]
        ctlcorrs.append(helpers.vectorCorrelation(ctlmaps[0], ctlmaps[1], axis=1))
        redcorrs.append(helpers.vectorCorrelation(redmaps[0], redmaps[1], axis=1))
    return ctlcorrs, redcorrs, ses_diffs

ctlcorrs, redcorrs, ses_diffs = make_corrs(pcm, idx_ses, max_diff=max_diff)

computing correlations:   7%|▋         | 1/14 [00:01<00:23,  1.77s/it]

[1 2] 378 53 10


computing correlations:  14%|█▍        | 2/14 [00:03<00:21,  1.76s/it]

[1 3] 347 48 11


computing correlations:  21%|██▏       | 3/14 [00:04<00:17,  1.58s/it]

[1 4] 178 30 5


computing correlations:  29%|██▊       | 4/14 [00:06<00:15,  1.56s/it]

[1 5] 161 30 5


computing correlations:  36%|███▌      | 5/14 [00:08<00:14,  1.64s/it]

[2 3] 332 46 6


computing correlations:  43%|████▎     | 6/14 [00:09<00:12,  1.59s/it]

[2 4] 179 34 1


computing correlations:  50%|█████     | 7/14 [00:11<00:10,  1.57s/it]

[2 5] 172 32 2


computing correlations:  57%|█████▋    | 8/14 [00:12<00:09,  1.53s/it]

[2 6] 127 31 1


computing correlations:  64%|██████▍   | 9/14 [00:14<00:07,  1.51s/it]

[3 4] 169 36 2


computing correlations:  71%|███████▏  | 10/14 [00:15<00:06,  1.51s/it]

[3 5] 152 34 3


computing correlations:  79%|███████▊  | 11/14 [00:16<00:04,  1.46s/it]

[3 6] 130 31 4


computing correlations:  86%|████████▌ | 12/14 [00:18<00:02,  1.43s/it]

[4 5] 243 50 2


computing correlations:  93%|█████████▎| 13/14 [00:19<00:01,  1.37s/it]

[4 6] 208 48 3


                                                                       

[5 6] 148 41 3




In [11]:
bins = np.linspace(-1, 1, 11)
centers = helpers.edge2center(bins)

num_diffs = len(np.unique(ses_diffs))
ctl_counts = np.zeros((num_diffs, len(centers)))
red_counts = np.zeros((num_diffs, len(centers)))
for idiff in range(num_diffs):
    idx = ses_diffs == (idiff + 1)
    c_ctlcorrs = np.concatenate([c for i, c in enumerate(ctlcorrs) if idx[i]])
    c_redcorrs = np.concatenate([c for i, c in enumerate(redcorrs) if idx[i]])
    ctl_counts[idiff] = helpers.fractional_histogram(c_ctlcorrs, bins=bins)[0]
    red_counts[idiff] = helpers.fractional_histogram(c_redcorrs, bins=bins)[0]

fig, ax = plt.subplots(2, num_diffs, figsize=(12, 4), layout="constrained")
for i in range(num_diffs):
    ax[0, i].plot(centers, ctl_counts[i], color="k", lw=1)
    ax[0, i].plot(centers, red_counts[i], color="r", lw=1)
    ax[0, i].set_title(f"$\Delta$ Session: {i + 1}")
    ax[1, i].plot(centers, red_counts[i] - ctl_counts[i], color="k", lw=1)
plt.show()

In [271]:
envnum = 3
idx_ses = [1, 2, 3, 4]
num_ses = len(idx_ses)
spkmaps, relmse, relcor, pfloc, pfidx, idx_red, roi_idx = pcm.get_spkmaps(envnum, trials="full", average=True, idx_ses=idx_ses)
any_red = np.any(np.stack(idx_red), axis=0)
print([s.shape for s in spkmaps])
print([i.shape for i in idx_red])
print([np.sum(i) for i in idx_red], np.sum(any_red))

ctlmaps = [s[~any_red] for s in spkmaps]
ctlidx = [np.argsort(p[~any_red]) for p in pfloc]
redmaps = [s[any_red] for s in spkmaps]
redidx = [np.argsort(p[any_red]) for p in pfloc]

print([s.shape for s in ctlmaps])
print([s.shape for s in redmaps])
print([s.shape for s in ctlidx])
print([s.shape for s in redidx])

100%|██████████| 4/4 [00:15<00:00,  3.82s/it]


[(1093, 194), (1093, 194), (1093, 194), (1093, 194)]
[(1093,), (1093,), (1093,), (1093,)]
[15, 18, 17, 15] 22
[(1071, 194), (1071, 194), (1071, 194), (1071, 194)]
[(22, 194), (22, 194), (22, 194), (22, 194)]
[(1071,), (1071,), (1071,), (1071,)]
[(22,), (22,), (22,), (22,)]


In [272]:
iref = 3

ctl_corr_to_ref = [helpers.vectorCorrelation(cm, ctlmaps[iref], axis=1) for cm in ctlmaps]
red_corr_to_ref = [helpers.vectorCorrelation(rm, redmaps[iref], axis=1) for rm in redmaps]

max_pos = ctlmaps[0].shape[1]
num_ctl = ctlmaps[0].shape[0]
num_red = redmaps[0].shape[0]
extents = [[0, max_pos, 0, num] for num in [num_ctl, num_red]]

# Change default font size
plt.rcParams.update({'font.size': 18})
fig, ax = plt.subplots(2, num_ses, figsize=(12, 6), layout="constrained")
for ises in range(num_ses):
    ax[0, ises].imshow(ctlmaps[ises][ctlidx[iref]], extent=extents[0], aspect="auto", vmin=0, vmax=2, cmap="hot", interpolation="none")
    ax[1, ises].imshow(redmaps[ises][redidx[iref]], extent=extents[1], aspect="auto", vmin=0, vmax=2, cmap="hot", interpolation="none")
    ax[0, ises].set_yticks([])
    ax[1, ises].set_yticks([])
    ax[1, ises].set_xlabel(f"Virtual Position (cm)")
    if ises==0:
        ax[1, ises].set_ylabel("Red Cells")
        ax[0, ises].set_ylabel("Control Cells")
    ax[0, ises].set_title(("REFERENCE\n" if ises==iref else "") + f"Session {idx_ses[ises]}")
plt.show()

In [42]:
bins = np.linspace(-1, 1, 9)
centers = helpers.edge2center(bins)

ctl_corr_hist = [helpers.fractional_histogram(c, bins=bins)[0] for c in ctl_corr_to_ref]
red_corr_hist = [helpers.fractional_histogram(c, bins=bins)[0] for c in red_corr_to_ref]

fig, ax = plt.subplots(1, num_ses-1, figsize=(6, 4), layout="constrained")
iplot = 0
for ises in range(num_ses):
    if ises == iref:
        continue
    ax[iplot].plot(centers, ctl_corr_hist[ises], color="k", lw=1)
    ax[iplot].plot(centers, red_corr_hist[ises], color="r", lw=1)
    iplot += 1
plt.show()

In [27]:
import umap

mouse_name = "ATL027"
ses = random.choice(sessiondb.iterSessions(mouseName=mouse_name, experimentID=3, imaging=True))
print(ses.sessionPrint())

# Load the session
pcss = analysis.placeCellSingleSession(ses, keep_planes=[1, 2])

ATL027/2023-08-08/701


In [273]:
average = True
smooth = None
spkmaps = pcss.get_spkmap(average=average, smooth=smooth, trials="full")
num_pos = spkmaps[0].shape[-1]
if not average:
    env_trials = [s.shape[1] for s in spkmaps]
    min_trials = min(env_trials)
    idx_use_trials = [np.random.permutation(etr)[:min_trials] for etr in env_trials]
    spkmaps = [s[:, idx_use_trials[i]] for i, s in enumerate(spkmaps)]
    # each row is a neuron with each trial concatenated along columns
    spkmaps = [s.reshape(s.shape[0], -1) for s in spkmaps]

reliable_only = True
if reliable_only:
    idx_reliable = pcss.get_reliable(cutoffs=(0.3, 0.6))
    any_reliable = np.any(np.stack(idx_reliable, axis=0), axis=0)
    spkmaps = [s[any_reliable] for s in spkmaps]

pos_colormaps = ["coolwarm", "coolwarm", "coolwarm"] #, "spring", "cool"]
env_colors = ["k", "r", "b"]

pos_colors = [mpl.colormaps[cm](np.linspace(0, 1, num_pos)) for cm in pos_colormaps]
env_colors = [np.tile(np.array(mpl.colors.to_rgba(c)).reshape(1, -1), (num_pos, 1)) for c in env_colors]
if not average:
    pos_colors = [np.tile(pc, (min_trials, 1)) for pc in pos_colors]
    env_colors = [np.tile(ec, (min_trials, 1)) for ec in env_colors]
print([s.shape for s in spkmaps])
print([s.shape for s in env_colors], [s.shape for s in pos_colors])

NameError: name 'pcss' is not defined

In [146]:
test_env = 2

train_data = np.concatenate([spkmaps[envidx] for envidx in range(len(spkmaps)) if envidx != test_env], axis=1)
test_data = spkmaps[test_env]

train_colors_pos = np.concatenate([pos_colors[envidx] for envidx in range(len(spkmaps)) if envidx != test_env], axis=0)
test_colors_pos = pos_colors[test_env]
train_colors_env = np.concatenate([env_colors[envidx] for envidx in range(len(spkmaps)) if envidx != test_env], axis=0)
test_colors_env = env_colors[test_env]

reducer = umap.UMAP(n_neighbors=5, n_components=2).fit(train_data.T)
train_embedding = reducer.transform(train_data.T)
test_embedding = reducer.transform(test_data.T)

reducer_full = umap.UMAP(n_neighbors=5, n_components=2).fit(np.concatenate(spkmaps, axis=1).T)
full_embedding = reducer_full.transform(np.concatenate(spkmaps, axis=1).T)

print(train_embedding.shape, test_embedding.shape, full_embedding.shape)

(390, 2) (195, 2) (585, 2)


In [147]:
fig, ax = plt.subplots(2, 2, figsize=(8, 8))
ax[0, 0].scatter(train_embedding[:, 0], train_embedding[:, 1], c=train_colors_pos, s=10)
ax[0, 0].scatter(test_embedding[:, 0], test_embedding[:, 1], c=test_colors_pos, s=10)
ax[0, 1].scatter(train_embedding[:, 0], train_embedding[:, 1], c=train_colors_env, s=10)
ax[0, 1].scatter(test_embedding[:, 0], test_embedding[:, 1], c=test_colors_env, s=10)
ax[1, 0].scatter(full_embedding[:, 0], full_embedding[:, 1], c=np.concatenate(pos_colors, axis=0), s=10)
ax[1, 1].scatter(full_embedding[:, 0], full_embedding[:, 1], c=np.concatenate(env_colors, axis=0), s=10)

for a in ax.flatten():
    a.set_xticks([])
    a.set_yticks([])
for a in ax[1, :]:
    a.set_xlabel("UMAP 1")
for a in ax[:, 0]:
    a.set_ylabel("UMAP 2")
ax[0, 0].set_title("Color by Position")
ax[0, 1].set_title("Color by Environment")
ax[0, 0].set_ylabel("Train(BlackRed) vs Test(BLUE)\n\n\nUMAP 2")
ax[1, 0].set_ylabel("Full Embedding\n\n\nUMAP 2")
plt.show()

In [124]:
import matplotlib.gridspec as gridspec

spkmaps = pcss.get_spkmap(average=True, trials="full")
idx_reliable = pcss.get_reliable(cutoffs=(0.3, 0.6))
any_reliable = np.any(np.stack(idx_reliable, axis=0), axis=0)

def get_kernels(data):
    dot_product = np.dot(data.T, data)
    norms = np.linalg.norm(data, axis=0)
    cosine_angle = dot_product / np.outer(norms, norms)
    kernel = np.corrcoef(data.T)
    return cosine_angle, kernel

angle, kernel = get_kernels(np.concatenate(spkmaps, axis=1))
rel_angle, rel_kernel = get_kernels(np.concatenate(spkmaps, axis=1)[any_reliable])

edges = [num_pos * i for i in range(1, len(spkmaps))]
extent = [0, angle.shape[0], 0, angle.shape[1]]
ticks = [num_pos/2 + num_pos * i for i in range(len(spkmaps))]
labels = [f"Env {i}" for i in range(len(spkmaps))]

cmap = "bwr"

# Create figure and gridspec
fig = plt.figure(figsize=(8, 8))
gs = gridspec.GridSpec(2, 3, width_ratios=[1, 1, 0.2])

# Create the main plots
ax00 = fig.add_subplot(gs[0, 0])
ax01 = fig.add_subplot(gs[0, 1])
ax10 = fig.add_subplot(gs[1, 0])
ax11 = fig.add_subplot(gs[1, 1])

# Create colorbar axis that spans both rows
cax = fig.add_subplot(gs[:, 2])  # This spans both rows

# Create the plots
im = ax00.imshow(angle, extent=extent, cmap=cmap, vmin=-1, vmax=1)
ax01.imshow(kernel, extent=extent, cmap=cmap, vmin=-1, vmax=1)
ax10.imshow(rel_angle, extent=extent, cmap=cmap, vmin=-1, vmax=1)
ax11.imshow(rel_kernel, extent=extent, cmap=cmap, vmin=-1, vmax=1)

# Add grid lines
for edge in edges:
    for ax in [ax00, ax01, ax10, ax11]:
        ax.axhline(edge, color="k", lw=0.5)
        ax.axvline(edge, color="k", lw=0.5)
for ax in [ax00, ax01, ax10, ax11]:
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels)
    ax.set_yticks([])
for ax in [ax00, ax10]:
    ax.set_yticks(ticks)
    ax.set_yticklabels(reversed(labels))
ax00.set_title("Cosine Angle")
ax01.set_title("Correlation")
ax00.set_ylabel("All cells")
ax10.set_ylabel("Reliable cells")

# Create colorbar
plt.colorbar(im, cax=cax)

plt.show()

In [3]:
%reload_ext autoreload
%autoreload 2

from vrAnalysis2.external.pettit2022 import find_pettit_harvey_sessions, data_path
sessions = find_pettit_harvey_sessions(data_path / "dataFolder")

behavior = sessions[0].behavior
spks = sessions[0].spks

In [None]:
from pathlib import Path
from vrAnalysis2 import files
from vrAnalysis import database
sessiondb = database.vrDatabase('vrSessions')
from typing import Union

def find_experiment_options(root_dir: Union[str, Path]) -> list[Path]:
    """
    Find all vrExperimentOptions.json files in the given directory and its subdirectories.
    
    Parameters
    ----------
    root_dir : str or Path
        The root directory to start the search from
    
    Returns
    -------
    list[Path]
        List of paths to all matching files
    """
    def make_identifier(pth: Path) -> list[str]:
        return "_".join(list(reversed([p.stem for p in list(pth.parents)[:3]])))
    
    root_path = Path(root_dir)
    all_paths = list(root_path.rglob("vrExperimentOptions.json"))
    session_identifier = [make_identifier(pth) for pth in all_paths]
    return all_paths, session_identifier

pths, sids = find_experiment_options(files.local_data_path())
csesids = [sessiondb.sessionPrint(joinby="_") for sessiondb in sessiondb.iterSessions(useDefault=True)]

for sid in sids:
    if sid not in csesids:
        print("oops", sid)

for ses in sessiondb.iterSessions():
    csesid = ses.sessionPrint(joinby="_")
    print(csesid, csesid in sids)

In [None]:
sessiondb.printSessions(mouseName="ATL057")

In [183]:
def r(d, length_scale=1.0):
    """Covariance function r(x - x') for scalar distance d"""
    return np.exp(-0.5 * (d**2) / length_scale**2)

def generate_cov(x, length_scale=1.0):
    """Generate covariance matrix K(x, x') for a given x"""
    distances = np.abs(x - x[0])
    first_row = r(distances, length_scale)
    return sp.linalg.toeplitz(first_row)

L = 0.2
NP = 801
N = 1e6
sigma = 0.015
theta = 1.3
x = np.linspace(0, L, NP)  # 1D space from 0 to 10
K = generate_cov(x, length_scale=sigma)  # Generate covariance matrix
h = np.random.multivariate_normal(mean=np.zeros(len(x)), cov=K, size=int(N))  # Generate h(x) as a sample from GP[0, r(x - x')]
f = np.maximum(0, h - theta)
Ch = np.cov(h.T)
Cf = np.cov(f.T)

In [234]:
from argparse import ArgumentParser
from vrAnalysis.helpers import AttributeDict, cutoff_type, positive_float
from vrAnalysis.analysis.variance_structure import load_spectra_data

MOUSE_NAMES = [
    "CR_Hippocannula6",
    "CR_Hippocannula7",
    "ATL022",
    "ATL027",
    "ATL028",
    "ATL020",
    "ATL012",
    "ATL045",
]
CUTOFFS = (0.4, 0.7)
MAXCUTOFFS = None

def get_spectra(mouse_name, args):
    """method for analyzing and plotting spectra with cvPCA and cvFOURIER analyses"""
    # load spectra data (use temp if it matches)
    track = tracking.tracker(mouse_name)  # get tracker object for mouse
    pcm = analysis.placeCellMultiSession(track, autoload=False)  # open up place cell multi session analysis object (don't autoload!!!)

    single_args = AttributeDict(vars(args))
    single_args["mouse_name"] = mouse_name

    spectra_dictionary = load_spectra_data(pcm, single_args, save_as_temp=False, reload=False)

    # return the dictionary
    return spectra_dictionary


def handle_inputs(inputs=["--do-spectra"]):
    """method for creating and parsing input arguments"""
    parser = ArgumentParser(description="do summary plots for a mouse")
    parser.add_argument(
        "--mouse-names",
        type=str,
        nargs="*",
        default="processed",
        help="which mice to compare (list of mouse names, or like default), (default='all')",
    )
    parser.add_argument("--cutoffs", nargs="*", type=cutoff_type, default=CUTOFFS, help=f"cutoffs for reliability (default={CUTOFFS})")
    parser.add_argument("--maxcutoffs", nargs="*", type=cutoff_type, default=MAXCUTOFFS, help="maxcutoffs for reliability cells (default=None)")
    parser.add_argument("--do-spectra", default=False, action="store_true", help="create spectrum plots for mouse (default=False)")
    parser.add_argument("--dist-step", default=1, type=float, help="dist-step for creating spkmaps (default=1cm)")
    parser.add_argument("--smooth", default=0.1, type=positive_float, help="smoothing width for spkmaps (default=0.1cm)")
    parser.add_argument("--reload-spectra-data", default=False, action="store_true", help="reload spectra data (default=False)")
    args = parser.parse_args(inputs)

    # if mouse_names is "all", get all mouse names from the database
    if args.mouse_names == "all":
        # mousedb = database.vrDatabase("vrSessions")
        mousedb = database.vrDatabase("vrMice")
        df = mousedb.getTable(trackerExists=True)
        mouse_names = df["mouseName"].unique()
        args.mouse_names = mouse_names
    elif args.mouse_names == "processed":
        args.mouse_names = MOUSE_NAMES

    # return the parsed arguments
    return args

# analyze spectra and make plots
args = handle_inputs()
pcms = []
spectra_data = []
for mouse in MOUSE_NAMES:
    print(f"Getting spectra data for {mouse}")
    spectra_data.append(get_spectra(mouse, args))  # Each is a dictionary of all the spectral output data
    c_track = tracking.tracker(mouse)
    c_pcm = analysis.placeCellMultiSession(c_track, autoload=False)
    pcms.append(c_pcm)

Getting spectra data for CR_Hippocannula6
Successfully loaded temporary data for variance structure analysis.
Getting spectra data for CR_Hippocannula7
Successfully loaded temporary data for variance structure analysis.
Getting spectra data for ATL022
Successfully loaded temporary data for variance structure analysis.
Getting spectra data for ATL027
Successfully loaded temporary data for variance structure analysis.
Getting spectra data for ATL028
Successfully loaded temporary data for variance structure analysis.
Getting spectra data for ATL020
Successfully loaded temporary data for variance structure analysis.
Getting spectra data for ATL012
Successfully loaded temporary data for variance structure analysis.
Getting spectra data for ATL045
Successfully loaded temporary data for variance structure analysis.


In [244]:
for fieldName in sessiondb.tableData()[0]:
    print(fieldName)

uSessionID
mouseName
sessionDate
sessionID
experimentType
experimentID
variableGain
behavior
imaging
faceCamera
vrEnvironments
headPlateRotation
numPlanes
planeSeparation
pockelsPercentage
objectiveRotation
vrRegistration
suite2p
suite2pQC
redCellQC
sessionQC
scratchJustification
logtime
sessionNotes
suite2pDate
vrRegistrationDate
vrRegistrationError
vrRegistrationException
redCellQCDate
vrBehaviorVersion
dontTrack


In [247]:
ises = np.random.choice(sessiondb.iterSessions(imaging=True, vrRegistration=True, experimentID=1), 10)
for ses in ises: print(ses.sessionPrint())

ATL020/2023-04-05/701
ATL022/2023-04-06/701
ATL022/2023-03-27/701
ATL023/2023-04-28/702
ATL045/2024-01-26/701
ATL020/2023-03-31/701
ATL022/2023-03-27/701
ATL020/2023-04-04/701
ATL045/2024-01-24/701
ATL027/2023-07-21/701


In [248]:
pcss = [analysis.placeCellSingleSession(ses) for ses in ises]
spkmaps = [p.get_spkmap(average=True, trials="full")[0] for p in pcss]
print([s.shape for s in spkmaps])

[(10996, 200), (14021, 200), (13940, 200), (12310, 200), (11086, 200), (11902, 200), (13940, 200), (10254, 200), (14780, 200), (12839, 200)]


In [273]:
idx_nan = np.any(np.stack([np.any(np.isnan(s), axis=0) for s in spkmaps], axis=0), axis=0)
spkmaps = [s[:, ~idx_nan] for s in spkmaps]
kernels = [np.cov(s.T) for s in spkmaps]
def get_kfunc(kernel, rows):
    kfunc = []
    for r in range(rows):
        kfunc.append(kernel[r][r:])
    max_length = max([len(k) for k in kfunc])
    for r in range(rows):
        kfunc[r] = np.concatenate([kfunc[r], np.zeros(max_length - len(kfunc[r]))])
    return np.stack(kfunc)
kfuns = [get_kfunc(k, 100) for k in kernels]
avg_kfuns = np.stack([np.mean(k, axis=0) for k in kfuns])
avg_kfuns = avg_kfuns / np.max(avg_kfuns, axis=1, keepdims=True)

In [279]:
from pathlib import Path
from scipy.optimize import curve_fit

fpath = Path(r"C:\Users\Andrew\Documents\GitHub\vrAnalysis\figures\plots_for_dataclub_241111")
orange = np.array([241, 80, 15]) / 255

def lorentz(x, alpha, magnitude):
    return magnitude * alpha / (x**2 + alpha**2)

# Fit the data
popt, pcov = curve_fit(lorentz, x, Cf[0] / max(Cf[0]), p0=[1.0, 1.0], bounds=(0, np.inf))
print(popt)

rlorentz = lorentz(x, popt[0], popt[1])

xcm = x * 1000
Lcm = L * 1000
xcm_kf = np.linspace(0, max(xcm), avg_kfuns.shape[1])

plt.rcParams.update({'font.size': 14})

vmin = -1
vmax = 1

plt.close('all')
fig, ax = plt.subplots(1, 4, figsize=(18, 4), layout="constrained")
ax[0].imshow(Cf / np.max(Cf), extent=[0, Lcm, 0, Lcm], cmap="bwr", vmin=vmin, vmax=vmax)
ax[1].imshow(sp.linalg.toeplitz(rlorentz), extent=[0, Lcm, 0, Lcm], cmap="bwr", vmin=vmin, vmax=vmax)
ax[2].plot(xcm, Cf[0] / max(Cf[0]), c=orange, label="f-Covariance")
ax[2].plot(xcm, rlorentz, c='k', label="Lorentz")
ax[2].plot(xcm_kf, np.mean(avg_kfuns, axis=0), c='b', label="data")
ax[3].imshow(kernels[0] / np.max(kernels[0]), extent=[0, 100, 0, 100], cmap="bwr", vmin=vmin, vmax=vmax)

ax[2].set_xlim(0, 120)

ax[0].set_xlabel("Position")
ax[0].set_ylabel("Position")
ax[0].set_title("f(x) Kernel")
ax[1].set_xlabel("Position")
ax[1].set_title("Lorentz Kernel")
ax[2].set_xlabel("Displacement")
ax[2].set_ylabel("Correlation")
ax[2].legend(loc="upper right", fontsize=14)

ax[3].set_xlabel("Position")
ax[3].set_ylabel("Position")
ax[3].set_title("Data Kernel")
plt.show()
helpers.save_figure(fig, fpath / "lorentz_comparison_withdata.png")

# # w, v = helpers.smart_pca(Ch)
# wf, vf = helpers.smart_pca(Cf)
# w = w / np.sum(w)
# wf = wf / np.sum(wf)

# fig, ax = plt.subplots(1, 1, figsize=(7, 5), layout="constrained")
# ax.plot(range(1, len(w)+1), w, label="Gaussian Kernel", linewidth=2, color="k")
# ax.plot(range(1, len(wf)+1), wf, label="Thresholded Kernel", linewidth=2, color=orange)
# ax.set_yscale("log")
# ax.set_xlim(0, 81)
# ax.set_ylim(1e-10, 1)
# ax.legend(loc="upper right")
# ax.text(17, 3e-4, "<---leading values are linear", color=orange)
# plt.show()
# helpers.save_figure(fig, fpath / "GP_Model_1.png")

# fig, ax = plt.subplots(1, 1, figsize=(7, 5), layout="constrained")
# ax.plot(range(1, len(w)+1), w, label="Gaussian Kernel", linewidth=2, color="k")
# # ax.plot(range(1, len(wf)+1), wf, label="Thresholded Kernel", linewidth=2, color="b")
# ax.set_yscale("log")
# ax.set_xlim(0, 81)
# ax.set_ylim(1e-10, 1)
# ax.legend(loc="upper right")
# # ax.text(17, 3e-4, "<---leading values are linear", color="b")
# helpers.save_figure(fig, fpath / "GP_Model_0.png")
# plt.show()

[0.01009633 0.01075494]


In [185]:
w.shape, wf.shape, Ch.shape

((801,), (801,), (801, 801))

In [23]:
w, v = helpers.smart_pca(Cf) #sp.linalg.toeplitz(rlorentz))

plt.plot(w)
plt.yscale('log')
plt.show()

In [None]:
# Also get real place field data
mouse_name = "ATL027"
track = tracking.tracker(mouse_name)
pcm = analysis.placeCellMultiSession(track, autoload=False)
ises = 8
pcss = analysis.placeCellSingleSession(pcm.pcss[ises].vrexp, keep_planes=[1, 2, 3, 4], autoload=False)
split_params = dict(total_folds=2, train_folds=1)
pcss.define_train_test_split(**split_params)
pcss.load_data(new_split=False)

In [None]:
N, P, T = 3000, 200, 100

xpos = np.linspace(0, P, P)

method = "relugp"

if method == "rbf":
    pf_loc = np.linspace(0, P, N) # place field location
    pf_width = 1.0 * np.random.rand(N) + 2.5 # place field width
    pf_basis = np.exp(-(pf_loc[:, None] - xpos[None, :]) ** 2 / 2 / pf_width[:, None] ** 2) # shape of place field
    
elif method == "relugp":
    def r(d, length_scale=1.0):
        """Covariance function r(x - x') for scalar distance d"""
        return np.exp(-0.5 * (d**2) / length_scale**2)

    def generate_cov(x, length_scale=1.0):
        """Generate covariance matrix K(x, x') for a given x"""
        distances = np.abs(x - x[0])
        first_row = r(distances, length_scale)
        return sp.linalg.toeplitz(first_row)
    
    L = 200 / 1000
    fs = 0.001
    NP = int(L / fs)
    sigma = 0.015
    theta = 1.3
    x = np.linspace(0, L, NP)  # 1D space from 0 to 10
    K = generate_cov(x, length_scale=sigma)  # Generate covariance matrix
    h = np.random.multivariate_normal(mean=np.zeros(len(x)), cov=K, size=N*10)  # Generate h(x) as a sample from GP[0, r(x - x')]
    pf_basis = np.maximum(0, h - theta)
    idx_with_pf = np.where(np.any(pf_basis > 0, axis=1))[0]
    pf_basis = pf_basis[idx_with_pf]
    pf_basis = pf_basis[np.random.permutation(pf_basis.shape[0])[:N]]
    if pf_basis.shape[0] < N:
        raise ValueError("Not enough place fields")
    
    idx_sort = np.argsort(np.argmax(pf_basis, axis=1))
    pf_basis = pf_basis[idx_sort]
    pf_loc = np.argmax(pf_basis, axis=1)
    pf_basis = pf_basis / np.max(pf_basis, axis=1)[:, None]
    
    print(np.sum(np.any(pf_basis > 0, axis=1)) / N)

else:
    raise ValueError("Method not recognized")

# Generate some place field properties
beta_val = 0.1
prob_pf = np.random.beta(beta_val, beta_val, N)
# prob_pf = np.random.rand(N) ** 5.0 # probability of expressing place field
noise_value = 0.5

# Generate place field data
pf_trial = np.random.rand(N, T) < prob_pf[:, None] # place field expression per trial
pf_activity = pf_basis[:, :, None] * pf_trial[:, None, :] # place field activity
noise_activity = np.random.randn(N, P, T) * noise_value # noise activity
data = pf_activity + noise_activity

train_data = np.mean(data[:, :, :T//2], axis=2)
test_data = np.mean(data[:, :, T//2:], axis=2)

# Get real place field data
envidx = 1
train_spkmaps = pcss.get_spkmap(trials="train", average=True)
test_spkmaps = pcss.get_spkmap(trials="test", average=True)
idx_nan = np.any(np.stack([np.any(np.isnan(spkmap), axis=0) for spkmap in train_spkmaps+test_spkmaps]), axis=0)
train_spkmaps = [spkmap[:, ~idx_nan] for spkmap in train_spkmaps]
test_spkmaps = [spkmap[:, ~idx_nan] for spkmap in test_spkmaps]
train_spkmap = train_spkmaps[envidx]
test_spkmap = test_spkmaps[envidx]

# Run cvPCA Analyses
nc = 80
cvpca = helpers.cvPCA(train_data.T, test_data.T, nc=nc)
truev = helpers.cvPCA(pf_basis.T, pf_basis.T, nc=nc)

# Run on real mouse data
cvpca_mouse = helpers.cvPCA(train_spkmap.T, test_spkmap.T, nc=nc)

cvpca_v = helpers.smart_pca(train_data, centered=True)[1][:, :nc]
train_proj = cvpca_v.T @ (train_data - train_data.mean(axis=1, keepdims=True))
test_proj = cvpca_v.T @ (test_data - test_data.mean(axis=1, keepdims=True))

ineg = np.where(cvpca < 0)[0]
if len(ineg) == 0:
    ineg = [30]

norm = lambda x: x / np.sum(x)

xv = range(1, nc + 1)
fig, ax = plt.subplots(2, 2, figsize=(6, 6), layout="constrained")
ax[0, 0].imshow(train_data, aspect="auto", cmap="inferno", interpolation="none")
ax[0, 0].set_title("Train Data")
ax[0, 1].imshow(test_data, aspect="auto", cmap="inferno", interpolation="none")
ax[0, 1].set_title("Test Data")
ax[1, 0].plot(xv, norm(cvpca), c="k")
ax[1, 0].plot(xv, norm(truev), c="r")
ax[1, 0].plot(xv, norm(cvpca_mouse), c="b")
ax[1, 0].set_xlabel("Component")
ax[1, 0].set_ylabel("C-V Variance")
# ax[1, 0].set_xscale("log")
ax[1, 0].set_yscale("log")
ax[1, 1].plot(train_proj[ineg[0]], "k", label="Train")
ax[1, 1].plot(test_proj[ineg[0]], "b", label="Test")
ax[1, 1].set_xlabel("Train Projection onto Component")
plt.show()

In [36]:
mouse_name = "ATL022"
track = tracking.tracker(mouse_name)
pcm = analysis.placeCellMultiSession(track, autoload=False)
ises = 7
pcss = analysis.placeCellSingleSession(pcm.pcss[ises].vrexp, keep_planes=[1, 2, 3, 4], autoload=False)
split_params = dict(total_folds=2, train_folds=1)
pcss.define_train_test_split(**split_params)
pcss.load_data(new_split=False)

In [None]:
train_spkmaps = pcss.get_spkmap(average=True, smooth=0.1, trials="train")
idx_trials = [np.argsort(ti) for ti in pcss.idxFullTrialEachEnv]
spkmaps = [spkmap[:, itt] for spkmap, itt in zip(pcss.get_spkmap(average=False, smooth=0.1, trials="full"), idx_trials)]
spks = pcss.prepare_spks()

idx_nan = np.any(
    np.stack([np.any(np.isnan(t), axis=0) for t in train_spkmaps] + [np.any(np.isnan(t), axis=(0, 1)) for t in spkmaps]), axis=0
)
train_spkmaps = [t[:, ~idx_nan] for t in train_spkmaps]
spkmaps = [t[:, :, ~idx_nan] for t in spkmaps]

# Measure noise on test trials
noise = [te - tr[:, None, :] for tr, te in zip(train_spkmaps, spkmaps)]
print([t.shape for t in noise], [t.shape for t in spkmaps])

# Flattened (bin by bin across the session)
noise = [t.transpose((0, 1, 2)).reshape(t.shape[0], -1) for t in noise]
print([t.shape for t in noise])

noisecorr = [np.corrcoef(t.T) for t in noise]

In [11]:
num_envs = len(train_spkmaps)
fig, ax = plt.subplots(1, num_envs, figsize=(5 * num_envs, 5), layout="constrained")
for i, tnc in enumerate(noisecorr):
    ax[i].imshow(tnc, aspect="auto", cmap="bwr", vmin=-1, vmax=1)
plt.show()

In [52]:
mouse_name = "ATL027"
track = tracking.tracker(mouse_name)
pcm = analysis.placeCellMultiSession(track, autoload=False)
ises = 12
pcss = analysis.placeCellSingleSession(pcm.pcss[ises].vrexp, keep_planes=[1, 2, 3, 4], autoload=False)
split_params = dict(total_folds=2, train_folds=1)
pcss.define_train_test_split(**split_params)
pcss.load_data(new_split=False)

In [56]:
train_spkmaps = pcss.get_spkmap(average=True, smooth=0.1, trials="train")
test_spkmaps = pcss.get_spkmap(average=True, smooth=0.1, trials="test")

idx_nan = np.any(np.stack([np.any(np.isnan(t), axis=0) for t in train_spkmaps] + [np.any(np.isnan(t), axis=0) for t in test_spkmaps]), axis=0)
train_spkmaps = [t[:, ~idx_nan] for t in train_spkmaps]
test_spkmaps = [t[:, ~idx_nan] for t in test_spkmaps]

train_cov = [np.cov(t.T) for t in train_spkmaps]
test_cov = [np.cov(t.T) for t in test_spkmaps]
cv_cov = [helpers.abcov(tr.T, te.T) for tr, te in zip(train_spkmaps, test_spkmaps)]

In [54]:
freqs_cvf, basis = helpers.get_fourier_basis(train_spkmaps[0].shape[1], Fs=pcss.distStep)
num_components = basis.shape[0]

s = [helpers.cvPCA(tr.T, te.T, nc=num_components) for tr, te in zip(train_spkmaps, test_spkmaps)]
corr, cos_train, sin_train, cos_test, sin_test = helpers.named_transpose([helpers.cvFOURIER(tr, te, basis, covariance=True) for tr, te in zip(train_spkmaps, test_spkmaps)])
corrsum = [np.mean(c, axis=0) for c in corr]

In [101]:
envidx = 1

num_neurons, num_bins = train_spkmaps[0].shape

train_show = train_cov[envidx]
test_show = test_cov[envidx]
cv_show = cv_cov[envidx]

# pad with numbins//2 zeros on each side
shift_center = np.arange(num_bins)
roll_center = num_bins // 2

pad_matrix = np.full((num_bins, roll_center), np.nan)
train_show_pad = np.hstack([pad_matrix, train_show, pad_matrix])
test_show_pad = np.hstack([pad_matrix, test_show, pad_matrix])
cv_show_pad = np.hstack([pad_matrix, cv_show, pad_matrix])

# Roll each row to be centered on the peak
train_show_roll = np.array([np.roll(row, roll_center - p) for row, p in zip(train_show_pad, shift_center)])
test_show_roll = np.array([np.roll(row, roll_center-p) for row, p in zip(test_show_pad, shift_center)])
cv_show_roll = np.array([np.roll(row, roll_center-p) for row, p in zip(cv_show_pad, shift_center)])

freq, train_show_power = sp.signal.welch(train_show, axis=1, fs=1, scaling="density")
_, test_show_power = sp.signal.welch(test_show, axis=1, fs=1, scaling="density")
_, cv_show_power = sp.signal.welch(cv_show, axis=1, fs=1, scaling="density")

freq_ac, train_ac_power = sp.signal.welch(np.nanmean(train_show_roll, axis=0), fs=1, scaling="density")
_, test_ac_power = sp.signal.welch(np.nanmean(test_show_roll, axis=0), fs=1, scaling="density")
_, cv_ac_power = sp.signal.welch(np.nanmean(cv_show_roll, axis=0), fs=1, scaling="density")

middle_band = slice(100, num_bins*3-100)
freq_ac_mid, train_ac_power_mid = sp.signal.welch(np.nanmean(train_show_roll, axis=0)[middle_band], fs=1, scaling="density")
_, test_ac_power_mid = sp.signal.welch(np.nanmean(test_show_roll, axis=0)[middle_band], fs=1, scaling="density")
_, cv_ac_power_mid = sp.signal.welch(np.nanmean(cv_show_roll, axis=0)[middle_band], fs=1, scaling="density")


f_xvals = np.arange(len(freq)) + 1

cmap = mpl.colormaps["inferno"]
cmap.set_bad(color=[0.2, 0.2, 0.2])

vmin = 0
vmax = 0.1 #np.nanmax(train_show_roll)

ymax = np.nanmax(np.nanmean(train_show_roll, axis=0)) * 1.1
pmax = np.nanmax(np.nanmean(train_show_power, axis=0)) * 1.1

norm = lambda x: x / np.sum(x)

plt.rcParams.update({'font.size': 18})

# Plot stuff
xvals = np.arange(3*num_bins) - (3*num_bins)//2
fig, ax = plt.subplots(4, 3, figsize=(14, 12), sharex="row", sharey="row", layout="constrained")
# autocorrelation maps
ax[0, 0].imshow(train_show_pad, aspect='auto', cmap=cmap, vmin=0, vmax=vmax)
ax[0, 1].imshow(test_show_pad, aspect='auto', cmap=cmap, vmin=0, vmax=vmax)
ax[0, 2].imshow(cv_show_pad, aspect='auto', cmap=cmap, vmin=0, vmax=vmax)

# average autocorrelation
ax[1, 0].plot(np.nanmean(train_show_roll, axis=0), label="Train")
ax[1, 1].plot(np.nanmean(test_show_roll, axis=0), label="Test")
ax[1, 2].plot(np.nanmean(cv_show_roll, axis=0), label="CV")
ax[1, 0].set_ylim([0, ymax])

# power - train/test/cv - 
# black: full, red: over average autocorrelation, blue: over middle band
ax[2, 0].plot(freq, np.nanmean(train_show_power, axis=0), c='k', label="Train")
ax[2, 1].plot(freq, np.nanmean(test_show_power, axis=0), c='k', label="Test")
ax[2, 2].plot(freq, np.nanmean(cv_show_power, axis=0), c='k', label="CV")
ax[2, 0].plot(freq_ac, train_ac_power.T, c='r', label="Train AC")
ax[2, 1].plot(freq_ac, test_ac_power.T, c='r', label="Test AC")
ax[2, 2].plot(freq_ac, cv_ac_power.T, c='r', label="CV AC")
ax[2, 0].plot(freq_ac_mid, train_ac_power_mid.T, c='b', label="Train AC Mid")
ax[2, 1].plot(freq_ac_mid, test_ac_power_mid.T, c='b', label="Test AC Mid")
ax[2, 2].plot(freq_ac_mid, cv_ac_power_mid.T, c='b', label="CV AC Mid")
ax[2, 0].set_yscale('log')
ax[2, 1].set_yscale('log')
ax[2, 2].set_yscale('log')

# cross-validated variance
# left, green: cvpca
# left, magenta: cv-fourier average
# left, black: fourier power autocorr
# middle, green: cv-fourier cosine
# right, green: cv-fourier sine

ax[3, 0].plot(freqs_cvf, norm(s[envidx]), c='g', label="Train")
ax[3, 1].plot(freqs_cvf, norm(corr[envidx][0]), c='g', label="Correlation - Cosine")
ax[3, 2].plot(freqs_cvf, norm(corr[envidx][1]), c='g', label="Correlation - Sine")

ax[3, 0].plot(freqs_cvf, norm(corrsum[envidx]), c='m', label="Correlation - SumFourier")

ax[3, 0].plot(freq, norm(np.nanmean(cv_show_power, axis=0)), c='k', label="CV")
ax[3, 1].plot(freq, norm(np.nanmean(cv_show_power, axis=0)), c='k', label="CV")
ax[3, 2].plot(freq, norm(np.nanmean(cv_show_power, axis=0)), c='k', label="CV")
ax[3, 0].set_xlabel("Frequency")
ax[3, 1].set_xlabel("Frequency")
ax[3, 2].set_xlabel("Frequency")
ax[3, 0].set_yscale('log')
ax[3, 1].set_yscale('log')
ax[3, 2].set_yscale('log')

pmin = min([np.nanmin(np.nanmean(train_show_power, axis=0)), np.nanmin(np.nanmean(test_show_power, axis=0)), np.nanmin(np.nanmean(cv_show_power, axis=0))])

ax[0, 0].set_title("Train")
ax[0, 1].set_title("Test")
ax[0, 2].set_title("CV")
ax[0, 0].set_ylabel("Position")
ax[1, 0].set_ylabel("AutoCorr\n(average row of cov)")
ax[2, 0].set_ylabel("Power Spectrum")

ax[0, 0].set_xlabel("Position")
ax[0, 1].set_xlabel("Position")
ax[0, 2].set_xlabel("Position")

ax[1, 0].set_xlabel("Position")
ax[1, 1].set_xlabel("Position")
ax[1, 2].set_xlabel("Position")

ax[2, 0].set_xlabel("Frequency")
ax[2, 1].set_xlabel("Frequency")
ax[2, 2].set_xlabel("Frequency")

plt.show()

In [100]:
w, v = helpers.smart_pca(cv_show)
_, wf = sp.signal.welch(cv_show[cv_show.shape[0]//2], fs=1, scaling="density")

fig, ax = plt.subplots(1, 2, figsize=(10, 5), layout="constrained")
ax[0].plot(range(len(freqs_cvf)), norm(s[envidx]), c='k', label="cvPCA")
ax[0].plot(range(len(freqs_cvf)), norm(corrsum[envidx]), c='r', label="cv-Fourier")
ax[0].set_yscale('log')
ax[0].set_ylim([1e-5, 1])
ax[0].set_xlabel("Dimensions")
ax[0].set_ylabel("Relative Variance")
ax[0].legend(loc="upper right")

ax[1].plot(range(len(freqs_cvf)), norm(w[:len(freqs_cvf)]), c='k', label="cvPCA")
ax[1].plot(range(len(freqs_cvf)), norm(wf[:len(freqs_cvf)]), c='b', label="True")
ax[1].set_yscale('log')
# ax[1].set_ylim([1e-5, 1])
ax[1].set_xlabel("Dimensions")
ax[1].set_ylabel("Relative Variance")
ax[1].legend(loc="upper right")
plt.show()

In [2]:
# Analyses and work to do:

# ROICaT Figure:
# - add a "print pair data" button to the interactive viewer (and maybe even a "save figure" button?)
# - build an example figure with the ROICaT data (can be simple, just make it soon)

# LBM-s3d:
# - get started

In [3]:
# Database Management:
# I need a way to report how many sessions the mouse has experienced each environment, independent of 
# which environments are represented in imaging sessions (which is how I'm doing it now...)

# Required Updates: 
# need to update the placeCellMultiSession object to reflect changes to spkmap code
# anything that uses pcss.get_place_field (pcmm make_snake_data and make_paired_snake)

# Compare cvPCA analyses with eigenspectrum of spontaneous data unrelated to SVCA
# And I want to start with the rastermap on projected place field data

# Compare cvPCA to SVCA (do a hybrid: use cvPCA to get the spatial PCs, then apply those to the SVCA split)

# Buzsaki Data:
# https://crcns.org/data-sets/hc/hc-3 -- https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4097350/
# https://app.globus.org/file-manager?origin_id=188a6110-96db-11eb-b7a9-f57b2d55370d&origin_path=%2FVargaV%2F&two_pane=false - https://buzsakilab.com/wp/animals/?frm_search&project=67125&frm-page-14333=2

In [4]:
# Post Dataclub 240429: 
# -- need to consolidate all my figures (especially for the last few slides in a script)

In [5]:
# Post Meeting with Kenneth Plan:
# - Relate to Kernel Matrices:
#   - https://people.eecs.berkeley.edu/~jordan/kernels/0521813972c09_p291-326.pdf
#   - First, look at the kernel matrices (the position x position covariance matrices for each environment/session)
#   - Study the structure, and how it changes over time. 
#   - Compare the cross-validated kernel and the non-cv kernel matrix and compare their changes over time
#            - notes about ^^, this will tell us how much changes in eigenspectrum relate to reliability across trials vs the shape of the kernel matrix etc...
# - Studies of non-place cells:
#   - Look at the eigenspectrum from non-place cells, suppose as a function of the reliability...
#   - Do cross-validated decoding from non-place cells
# - Discussion of SVCA results
#   - SVCA dimensionality could have issues with noise estimation... the same way the trial expanded cvPCA plots did...
#   - Predict cell2 group from cell1 group, and predict cell2 group from their cross-validated place field, compare variance explained and overlap in variance explained
# - Rastermap: 
#   - need to find a way to remove expected spatial activity from full spike trace data (then maybe do rastermap again?)
# - Signal to Noise
#   - For each ROI, measure activity in center of place field, outside of place field on a linear track, and outside the track (or in other environments)
# - Measure spontaneous periods of activity.

In [120]:
# DIMILIBI GOALS:
# Compare best networks to ridge regression for a bunch of sessions.
# Ridge Regression:
#   - need to optimize ridge parameter: I can use a simple grid search on a log-space for this in two stages
#   - Note: I tested (for one session) if the best ridge parameter is the same for full-rank and low-rank, and it was. 
#   - setup a train/val/test split program and then fit the model to each session, and record the results for several ranks
# Networks: 
#   - for a subset of ranks, train a standard network and a beta-VAE network, record results for each session
#   - to validate, just train a network on lots of epochs, store the evaluation test score throughout training, and 
#     save the full trajectory across training along with the best test score and the associated epoch number.
#   - Note: for BetaVAELoss, will need to separate the loss into reconstruction and KL divergence, for proper saving.
# Analysis / summary:
#   - plot summary curves across mice for each rank, color-coded by RRR, BetaVAE, and SVCANet
#   - probably also compare for each mouse somehow? Maybe categorized dot plots separated by rank? 
#   - I also want to compare how the validation scores improve over time for the network models.

In [None]:
# ('CR_Hippocannula6', '2022-08-26', '702') # test this because performance improved for 2000 epochs

In [None]:
# choose a session randomly that has registered imaging data and a single environment
vrexp = random.choice(sessiondb.iterSessions(imaging=True, vrRegistration=True, experimentID=1))
print(vrexp.sessionPrint()) # show which session you chose

keep_planes = [1, 2]
onefile = "mpci.roiActivityDeconvolvedOasis"
ospks = vrexp.loadone(onefile)
keep_idx = vrexp.idxToPlanes(keep_planes=keep_planes)
ospks = ospks[:, keep_idx]
time_split_prms = dict(
    num_groups=3,
    relative_size=[5, 5, 1], #[5, 5, 1],
    chunks_per_group=-3, # 25
    num_buffer=3, # usually use default (which is 10)
)
npop = Population(ospks.T, generate_splits=True, time_split_prms=time_split_prms)
print(npop.size())

pcss = analysis.placeCellSingleSession(vrexp, keep_planes=keep_planes, onefile=onefile, autoload=True)
assert len(pcss.environments) == 1, "Only one environment is supported for this analysis"

train_source, train_target = npop.get_split_data(0, center=False, scale=True, pre_split=False, scale_type="preserve")
test_source, test_target = npop.get_split_data(1, center=False, scale=True, pre_split=False, scale_type="preserve")

In [20]:
# Do a comparison of cvPCA and SVCA
svca = SVCA(centered=True).fit(train_source, train_target)

# Get place fields
envnum = pcss.environments[0]
train_spkmap = pcss.get_spkmap(envnum=envnum, average=True, trials="train")[0]
source_spkmap = train_spkmap[npop.cell_split_indices[0]]
target_spkmap = train_spkmap[npop.cell_split_indices[1]]

test_spkmap = pcss.get_spkmap(envnum=envnum, average=True, trials="test")[0]
source_spkmap_test = test_spkmap[npop.cell_split_indices[0]]
target_spkmap_test = test_spkmap[npop.cell_split_indices[1]]

idx_nan = np.any(np.isnan(source_spkmap), axis=0) | np.any(np.isnan(target_spkmap), axis=0)
source_pca = PCA().fit(source_spkmap[:, ~idx_nan])
target_pca = PCA().fit(target_spkmap[:, ~idx_nan])
source_components = source_pca.get_components()
target_components = target_pca.get_components()

# Compare the PCA map of train to test trials on the source data
idx_nan = np.any(np.isnan(source_spkmap_test), axis=0)
source_pca_test = PCA().fit(source_spkmap_test[:, ~idx_nan])
traintest_map = np.dot(source_components.T, source_pca_test.get_components())

# For U, V, and components, each column is a component (so each row is a neuron)
source_map = np.dot(source_components.T, svca.U)
target_map = np.dot(target_components.T, svca.V)

vmin = min(source_map.min(), target_map.min(), traintest_map.min())
vmax = max(source_map.max(), target_map.max(), traintest_map.max())

# Take weighted average across axis 0 (for each SV, which PF PCs are it composed of?)
idx = np.arange(source_map.shape[0]).reshape(-1, 1)
source_map_avg = np.sum(np.abs(source_map) * idx, axis=0) / np.sum(np.abs(source_map), axis=0)
target_map_avg = np.sum(np.abs(target_map) * idx, axis=0) / np.sum(np.abs(target_map), axis=0)
pc_map_avg = np.sum(np.abs(traintest_map) * idx, axis=0) / np.sum(np.abs(traintest_map), axis=0)

In [5]:
# To get SV activity projections, 
u_activity = fs.zscore(np.array(svca.U.T @ npop.data[npop.cell_split_indices[0]]), axis=1)
v_activity = fs.zscore(np.array(svca.V.T @ npop.data[npop.cell_split_indices[1]]), axis=1)
u_rawspkmap = helpers.getBehaviorAndSpikeMaps(vrexp, onefile=u_activity.T)[3]
v_rawspkmap = helpers.getBehaviorAndSpikeMaps(vrexp, onefile=v_activity.T)[3]
uspkmap_train = pcss.get_spkmap(envnum=envnum, average=True, trials="train", rawspkmap=u_rawspkmap)[0]
vspkmap_train = pcss.get_spkmap(envnum=envnum, average=True, trials="train", rawspkmap=v_rawspkmap)[0]
uspkmap_test = pcss.get_spkmap(envnum=envnum, average=True, trials="test", rawspkmap=u_rawspkmap)[0]
vspkmap_test = pcss.get_spkmap(envnum=envnum, average=True, trials="test", rawspkmap=v_rawspkmap)[0]

def select_env(tup, idx):
    return list(map(lambda x: x[idx], tup))

urelmse, urelcor = select_env(pcss.get_reliability_values(envnum=envnum, rawspkmap=u_rawspkmap), 0)
vrelmse, vrelcor = select_env(pcss.get_reliability_values(envnum=envnum, rawspkmap=v_rawspkmap), 0)
relmse, relcor = select_env(pcss.get_reliability_values(envnum=envnum), 0)

u_rel_idx = urelcor > 0.6
v_rel_idx = vrelcor > 0.6
rel_idx = relcor > 0.6

uidx = pcss.get_place_field(uspkmap_train[u_rel_idx], method="max")[1]
vidx = pcss.get_place_field(vspkmap_train[v_rel_idx], method="max")[1]

In [None]:
from scripts.dimilibi.helpers import make_position_basis, filter_timepoints
frame_position, frame_environment, environments = vrexp.get_frame_behavior(speedThreshold=1)
valid_u_activity, valid_position, valid_environment = filter_timepoints(u_activity.T, frame_position, frame_environment)
position_basis = make_position_basis(valid_position, valid_environment, num_basis=10)

upospop = Population(valid_u_activity.T, time_split_prms={"num_groups": 2, "relative_size": [5, 1], "chunks_per_group": -3, "num_buffer": 3}, dtype=torch.float32)
train_valid_u = upospop.apply_split(valid_u_activity.T, 0)
test_valid_u = upospop.apply_split(valid_u_activity.T, 1)
train_pos_basis = upospop.apply_split(position_basis.T, 0)
test_pos_basis = upospop.apply_split(position_basis.T, 1)

print(train_valid_u.shape, train_pos_basis.shape, test_valid_u.shape, test_pos_basis.shape)

In [None]:
rmodel = RidgeRegression(alpha=1e3, fit_intercept=True).fit(train_valid_u.T, train_pos_basis.T)
print(rmodel.score(test_valid_u.T, test_pos_basis.T))

In [None]:
vmin = -1
vmax = 1

tspkmapidx = pcss.get_place_field(train_spkmap[rel_idx], method="max")[1]
print(vmin, vmax)

fig, ax = plt.subplots(2, 4, figsize=(12, 6), layout="constrained")
ax[0, 0].imshow(uspkmap_train[u_rel_idx][uidx], aspect="auto", cmap="bwr", vmin=vmin, vmax=vmax, interpolation="none")
ax[0, 1].imshow(vspkmap_train[v_rel_idx][vidx], aspect="auto", cmap="bwr", vmin=vmin, vmax=vmax, interpolation="none")
ax[0, 2].imshow(train_spkmap[rel_idx][tspkmapidx], aspect="auto", cmap="bwr", vmin=vmin, vmax=vmax, interpolation="none")
ax[1, 0].imshow(uspkmap_test[u_rel_idx][uidx], aspect="auto", cmap="bwr", vmin=vmin, vmax=vmax, interpolation="none")
ax[1, 1].imshow(vspkmap_test[v_rel_idx][vidx], aspect="auto", cmap="bwr", vmin=vmin, vmax=vmax, interpolation="none")
ax[1, 2].imshow(test_spkmap[rel_idx][tspkmapidx], aspect="auto", cmap="bwr", vmin=vmin, vmax=vmax, interpolation="none")
ax[0, 3].ecdf(urelmse[~np.isnan(urelmse)], label="U")
ax[0, 3].ecdf(vrelmse[~np.isnan(vrelmse)], label="V")
ax[0, 3].ecdf(relmse[~np.isnan(relmse)], label="PCA")
ax[1, 3].ecdf(urelcor[~np.isnan(urelcor)], label="U")
ax[1, 3].ecdf(vrelcor[~np.isnan(vrelcor)], label="V")
ax[1, 3].ecdf(relcor[~np.isnan(relcor)], label="PCA")
ax[0, 3].set_xlabel("Reliability (method 1)")
ax[1, 3].set_xlabel("Reliability (method 2)")
ax[0, 3].set_ylabel("Cumulative probability")
ax[1, 3].set_ylabel("Cumulative probability")
ax[0, 3].legend()
ax[1, 3].legend()
ax[0, 3].set_xlim(-2, 1)

plt.show()

In [59]:
max_components = source_components.shape[1]
fig, ax = plt.subplots(1, 4, figsize=(12, 3), layout="constrained")
ax[0].imshow(np.abs(source_map[:, :max_components]), aspect="auto", interpolation="none", cmap="bwr", vmin=vmin, vmax=vmax)
ax[1].imshow(np.abs(target_map[:, :max_components]), aspect="auto", interpolation="none", cmap="bwr", vmin=vmin, vmax=vmax)
ax[2].imshow(np.abs(traintest_map[:, :max_components]), aspect="auto", interpolation="none", cmap="bwr", vmin=vmin, vmax=vmax)
ax[3].scatter(np.abs(source_map[:, :max_components].flatten()), np.abs(target_map[:, :max_components].flatten()), s=1)

# Link axes 0 and 1
ax[1].sharex(ax[0])
ax[1].sharey(ax[0])
ax[2].sharex(ax[0])
ax[2].sharey(ax[0])
plt.show()

In [None]:
center = False
scale = True
pre_split = False
scale_type = "preserve"

train_source, train_target = npop.get_split_data(0, center=center, scale=scale, pre_split=pre_split, scale_type=scale_type)
val_source, val_target = npop.get_split_data(1, center=center, scale=scale, pre_split=pre_split, scale_type=scale_type)
test_source, test_target = npop.get_split_data(2, center=center, scale=scale, pre_split=pre_split, scale_type=scale_type)

get_whitening = False

if get_whitening:
    zca_source = PCA().fit(train_source).get_zca().to(device)
    zca_val = PCA().fit(val_source).get_zca().to(device)
    zca_target = PCA().fit(train_target).get_zca().to(device)

print(train_source.shape, train_target.shape, val_source.shape, val_target.shape, test_source.shape, test_target.shape)

In [27]:
# get eigenvalues of the full population to compare with simulated data appropriately
# npop_evals = PCA().fit(ospks.T).get_eigenvalues()

N = npop.size(0) // 2
T = 8000
Ttest = 1000
Q = torch.linalg.qr(torch.normal(0, 1, (2*N, 2*N)))[0]
D = npop_evals[:2*N]

train_scores = torch.diag(D) @ torch.normal(0, 1, (2*N, T))
val_scores = torch.diag(D) @ torch.normal(0, 1, (2*N, Ttest))
train_data = Q @ train_scores
val_data = Q @ val_scores

train_source = train_data[:N]
train_target = train_data[N:]
val_source = val_data[:N]
val_target = val_data[N:]

# zscore the data
train_source = (train_source - train_source.mean(1, keepdim=True)) / train_source.std(1, keepdim=True)
train_target = (train_target - train_target.mean(1, keepdim=True)) / train_target.std(1, keepdim=True)
val_source = (val_source - val_source.mean(1, keepdim=True)) / val_source.std(1, keepdim=True)
val_target = (val_target - val_target.mean(1, keepdim=True)) / val_target.std(1, keepdim=True)

# zca_source = PCA().fit(train_source).get_zca()

In [4]:
rrr = ReducedRankRegression(alpha=1e5, fit_intercept=True).fit(train_source.T.to('cpu'), train_target.T.to('cpu'))

In [None]:
rank = 5
print(rrr.score(train_source.T.to('cpu'), train_target.T.to('cpu'), rank=rank))
print(rrr.score(train_source.T.to('cpu'), train_target.T.to('cpu')))
print(rrr.score(val_source.T, val_target.T, rank=rank, nonnegative=False))
print(rrr.score(val_source.T, val_target.T, rank=rank, nonnegative=True))
print(rrr.score(val_source.T, val_target.T, nonnegative=False))
print(rrr.score(val_source.T, val_target.T, nonnegative=True))

In [10]:
num_neurons = train_source.size(0)
num_hidden = [400] # [hyps["best_params"]["num_hidden"]]
num_latent = 5
num_target_neurons = train_target.size(0)
num_timepoints = train_source.size(1)

# net0 = SVCANet(
#     num_neurons,
#     num_hidden,
#     num_latent,
#     num_target_neurons,
#     activation = torch.nn.ReLU(),
#     nonnegative = False, 
# ).to(device)

# net1 = BetaVAE(
#     num_neurons,
#     num_hidden,
#     num_latent,
#     num_target_neurons,
#     activation = torch.nn.ReLU(),
#     nonnegative = False,
# ).to(device)

# net2 = HurdleNet(
#     num_neurons,
#     num_hidden,
#     num_latent,
#     num_target_neurons,
#     activation = torch.nn.ReLU(),
#     nonnegative = False,
#     transparent_relu=True,
# ).to(device)

nets = [
    constructor(
        num_neurons,
        num_hidden,
        num_latent,
        num_target_neurons,
        activation = torch.nn.ReLU(),
        nonnegative = True,
        transparent_relu = True,
    ).to(device)
    for constructor in [SVCANet, SVCANet, SVCANet, SVCANet, HurdleNet, HurdleNet, HurdleNet, HurdleNet]
]

cols = 'kkkkrrrr'

# nets = [net0, net1]
betavae = [False for _ in range(len(nets))] #[False, True]

# nets = [net0, net1]
loss_functions = [torch.nn.MSELoss(reduction='sum') for _ in range(len(nets))]

# beta = [1e1, 1e2, 1e3, 1e4] #[1 for _ in range(len(nets))]
regularizers = [EmptyRegularizer() for _ in range(len(nets))]
# regularizers = [BetaVAE_KLDiv(beta=b, reduction='sum') for b in beta]

wd = 1e3 #hyps["best_params"]["weight_decay"]
lr = 1e-3 #hyps["best_params"]["lr"]
nl = 0 # hyps["best_params"]["noise_level"]
weight_decay = [wd for _ in range(len(nets))] 
opts = [torch.optim.Adam(net.parameters(), lr=lr, weight_decay=wd) for net, wd in zip(nets, weight_decay)]

net_reg_weight = [0 for _ in range(len(nets))]
noise_level = [nl for _ in range(len(nets))]

In [None]:
# train the network
batch_size = num_timepoints//10
num_epochs = 800
num_nets = len(nets)

train_loss = torch.zeros((num_nets, num_epochs))
train_reg = torch.zeros((num_nets, num_epochs))
train_score = torch.zeros((num_nets, num_epochs))
traintest_loss = torch.zeros((num_nets, num_epochs))
traintest_score = torch.zeros((num_nets, num_epochs))

train_source = train_source.to(device)
train_target = train_target.to(device)
test_source = test_source.to(device)
test_target = test_target.to(device)

for net in nets:
    net.train()
    
progress = tqdm(range(num_epochs), desc='Training Networks')
for epoch in progress:
            
    itime = torch.randperm(num_timepoints)[:batch_size]
    
    source_batch = train_source[:, itime].T
    target_batch = train_target[:, itime].T

    for opt in opts:
        opt.zero_grad()

    predictions = [net(source_batch + nl * torch.randn_like(source_batch)) for net, nl in zip(nets, noise_level)]
    mulogvar = [pred[1:] if b else (torch.tensor(0, device=device), torch.tensor(0, device=device)) for pred, b in zip(predictions, betavae)]
    predictions = [pred[0] if b else pred for pred, b in zip(predictions, betavae)]
    losses = [loss_fn(pred, target_batch) for pred, loss_fn in zip(predictions, loss_functions)]
    regs = []
    for b, mlv, reg, pred in zip(betavae, mulogvar, regularizers, predictions):
        if b:
            regs.append(reg(*mlv))
        else:
            regs.append(reg(source_batch, pred))
    full_losses = [loss + weight * reg for loss, reg, weight in zip(losses, regs, net_reg_weight)]
    for loss in full_losses:
        loss.backward()
    
    for opt in opts:
        opt.step()
    
    scores = [net.score(source_batch, target_batch) for net in nets]

    for inet in range(len(nets)):
        train_loss[inet, epoch] = losses[inet].item()
        train_reg[inet, epoch] = regs[inet].item()
        train_score[inet, epoch] = scores[inet].item()

        with torch.no_grad():
            for net in nets:
                net.eval()
            pred = nets[inet](test_source.T)
            if betavae[inet]:
                pred = pred[0]
            traintest_loss[inet, epoch] = loss_functions[inet](pred, test_target.T).item()
            traintest_score[inet, epoch] = nets[inet].score(test_source.T, test_target.T).item()
            for net in nets:
                net.train()
        
    progress.set_postfix({'Loss': losses[0].item(), 'Score': scores[0].item(), "Reg": regs[0].item()})

for net in nets:
    net.eval()
    
test_predictions = [net(test_source.T) for net in nets]
test_mulogvar = [pred[1:] if b else (torch.tensor(0, device=device), torch.tensor(0, device=device)) for pred, b in zip(test_predictions, betavae)]
test_predictions = [pred[0] if b else pred for pred, b in zip(test_predictions, betavae)]
test_losses = [loss_fn(test_prediction, test_target.T) for test_prediction, loss_fn in zip(test_predictions, loss_functions)]
test_regs = []
for b, mlv, reg, pred in zip(betavae, test_mulogvar, regularizers, test_predictions):
    if b:
        test_regs.append(reg(*mlv))
    else:
        test_regs.append(reg(test_source.T, pred))
test_scores = [net.score(test_source.T, test_target.T) for net in nets]

for inet in range(len(nets)):
    print(f"Net{inet} Test Loss: {test_losses[inet].item():.3f}, " +
          f"Test Score: {test_scores[inet].item():.3f}, " + 
          f"Test Reg: {test_regs[inet].item():.3f}" + 
          f"Maximum Test Score: {traintest_score[inet].max().item():.3f}")

# plot the training loss
fig, ax = plt.subplots(1, 3, figsize=(9, 3), layout="constrained")
for inet in range(len(nets)):
    ax[0].plot(train_loss[inet], c=cols[inet], label=f"net{inet}")
    ax[0].axhline(test_losses[inet].item(), linestyle='--', c=cols[inet])
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Loss')
ax[0].set_title('Training Loss')
ax[0].legend()
for inet in range(len(nets)):
    ax[1].plot(train_score[inet], c=cols[inet], label=f"net{inet}")
    ax[1].axhline(test_scores[inet].item(), linestyle='--', c=cols[inet])
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Score')
ax[1].set_title('Training Score')
# ax[1].set_ylim(-5, 1.0)
ax[1].legend()
for inet in range(len(nets)):
    ax[2].plot(train_reg[inet], c=cols[inet], label=f"net{inet}")
    ax[2].axhline(test_regs[inet].item(), linestyle='--', c=cols[inet])
ax[2].set_xlabel('Epoch')
ax[2].set_ylabel('Regularization')
ax[2].set_title('Training Regularization')
ax[2].legend()
plt.show()

In [13]:
# plot the training loss
fig, ax = plt.subplots(1, 3, figsize=(9, 3), layout="constrained")
for inet in range(len(nets)):
    ax[0].plot(traintest_loss[inet], c=cols[inet], label=f"net{inet}")
    ax[0].axhline(test_losses[inet].item(), linestyle='--', c=cols[inet])
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Loss')
ax[0].set_title('Training Loss')
ax[0].legend()
for inet in range(len(nets)):
    ax[1].plot(traintest_score[inet], c=cols[inet], label=f"net{inet}")
    ax[1].axhline(test_scores[inet].item(), linestyle='--', c=cols[inet])
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Score')
ax[1].set_title('Training Score')
# ax[1].set_ylim(-5.1, 0.5) #torch.min(traintest_score[:, 2:]), 1.2*torch.max(traintest_score))
ax[1].legend()
for inet in range(len(nets)):
    ax[2].plot(train_reg[inet], c=cols[inet], label=f"net{inet}")
    ax[2].axhline(test_regs[inet].item(), linestyle='--', c=cols[inet])
ax[2].set_xlabel('Epoch')
ax[2].set_ylabel('Regularization')
ax[2].set_title('Training Regularization')
ax[2].legend()
plt.show()

In [None]:
svca = SVCA().fit(train_source, train_target)
shared, total = svca.score(train_source, train_target)
print(f"{shared.sum() / total.sum() * 100:.2f}% of the variance is shared between the two groups.")

shared, total = svca.score(test_source, test_target)
print(f"{shared.sum() / total.sum() * 100:.2f}% of the variance is shared between the two groups.")


In [None]:
# The idea
# How to go from one dataset to another? 
# One way: learn the SVD of the covariance

# Question: 
# How does the SVD of gram between A and B related to a reduced rank solution of predicting B from A? 

# U S V.T = A.T @ B
# AX = B --> X = (A.T @ A)^-1 @ A.T @ B
# AX = B --> X = (A.T @ A)^-1 @ U S V.T
# AX = B --> X = Q 1/D Q.T @ U S V.T

# In general, if we have a SVD map of the gram matrix defining covariance between A and B, 
# then we can study how each mode of A/B maps onto each PC of A/B. 

# Simple:
# 1. Learn the SVD of the gram matrix between A and B
# 2. Learn the PCA of A and B
# 3. Get the OLS solution to transform SVD modes to PCA modes

In [None]:
import random

# choose a session randomly that has registered imaging data
vrexp = random.choice(sessiondb.iterSessions(imaging=True, vrRegistration=True))
print(vrexp.sessionPrint()) # show which session you chose

ospks = vrexp.loadone('mpci.roiActivityDeconvolvedOasis')
keep_idx = vrexp.idxToPlanes(keep_planes=[1])
ospks = ospks[:, keep_idx]

time_split_prms = dict(
    relative_size=[10, 1],
    chunks_per_group=25,
    num_buffer=10,
)
npop = Population(ospks.T, generate_splits=True, time_split_prms=time_split_prms)
print(npop.size())

# get eigenvalues of the full population to compare with simulated data appropriately
npop_evals = PCA().fit(ospks.T).get_eigenvalues()

train_source, train_target = npop.get_split_data(0, center=True)
test_source, test_target = npop.get_split_data(1, center=True)

print(train_source.shape, train_target.shape, test_source.shape, test_target.shape)

data_cross = CrossCompare().fit(train_source, train_target)


N = npop.size(0)//2
T = 5000
Ttest = 1000
Q = torch.linalg.qr(torch.normal(0, 1, (2*N, 2*N)))[0]
D = npop_evals[:2*N]

train_scores = torch.diag(D) @ torch.normal(0, 1, (2*N, T))
test_scores = torch.diag(D) @ torch.normal(0, 1, (2*N, Ttest))
train_data = Q @ train_scores
test_data = Q @ test_scores

train_source = train_data[:N]
train_target = train_data[N:]
test_source = test_data[:N]
test_target = test_data[N:]

sim_cross = CrossCompare().fit(train_source, train_target)

to_pca = True
d_source_com, d_target_com, d_source_entropy, d_target_entropy = data_cross.analyze(to_pca=to_pca)
source_com, target_com, source_entropy, target_entropy = sim_cross.analyze(to_pca=to_pca)

In [14]:
fig, ax = plt.subplots(2, 2, figsize=(6, 6), layout="constrained")

d_crossmap = data_cross.u_to_pc if to_pca else data_cross.pc_to_u
crossmap = sim_cross.u_to_pc if to_pca else sim_cross.pc_to_u

ax[0, 0].imshow(torch.abs(d_crossmap), aspect='auto', cmap='pink', interpolation="None")
ax[0, 0].set_xlabel("Principal Component" if to_pca else "SVD Mode")
ax[0, 0].set_ylabel("U Mode" if to_pca else "Principal Component")

ax[0, 1].imshow(torch.abs(crossmap), aspect='auto', cmap='pink', interpolation="None")
ax[0, 1].set_xlabel("Principal Component" if to_pca else "SVD Mode")
ax[0, 1].set_ylabel("U Mode" if to_pca else "Principal Component")

ax[1, 0].axline((0, 0), slope=1, color='k', linewidth=0.5, linestyle='--')
ax[1, 0].plot(source_com, color='k', label="Simulated")
ax[1, 0].plot(target_com, color='k')
ax[1, 0].plot(d_source_com, color='b', label="Data")
ax[1, 0].plot(d_target_com, color='b')
ax[1, 0].set_xlabel("Principal Component" if to_pca else "SVD Mode")
ax[1, 0].set_ylabel('Center of Mass ' + ("SVD Mode" if to_pca else "Principal Component"))
ax[1, 0].set_title('Map Dimension CoM')
ax[1, 0].legend()

ax[1, 1].plot(source_entropy, color='k', label="Source")
ax[1, 1].plot(target_entropy, color='k')
ax[1, 1].plot(d_source_entropy, color='b', label="Data")
ax[1, 1].plot(d_target_entropy, color='b')
ax[1, 1].set_xlabel("Principal Component" if to_pca else "SVD Mode")
ax[1, 1].set_ylabel('Entropy')
ax[1, 1].set_title('Entropy')
ax[1, 1].legend()

plt.show()

In [14]:
from matplotlib import pyplot as plt

sourcemap = sp.ndimage.gaussian_filter(torch.abs(data_cross.u_to_pc.T), 1)
targetmap = sp.ndimage.gaussian_filter(torch.abs(data_cross.v_to_pc.T), 1)

fig, ax = plt.subplots(1, 2, figsize=(8, 4), layout="constrained", sharex=True, sharey=True)
ax[0].imshow(sourcemap, cmap="hot", interpolation='gaussian')
ax[0].set_xlabel('U Mode')
ax[0].set_ylabel('PC Mode')
ax[0].set_title('U to PC Source')
ax[1].imshow(targetmap, cmap="hot")
ax[1].set_title('V to PC Target')
ax[1].set_xlabel('V Mode')
ax[1].set_ylabel('PC Mode')
plt.show()

In [None]:
# Database Requirements: 
# ---------------------
# GUI: db manager
# - click on entry and do things:
#                --> open file explorer to that session
#                --> do suite2p
#                --> do red cell management
# - update table data? 
# ---------------------
# Operational Commands: 
# - Automatically do suite2p 
# - Check if registration was done before a suite2p update

# Further Requirements:
# ---------------------
# ROICaT Alignment Tools 
# Track Red Cell Consistency across days 
# Now that I've refactored the database code, need to update some things in documentation and probably elsewhere too


In [None]:
# Notes from meeting with Kenneth:

# - skewness (violin plot) of Control & Red -- 
#     - all - 
#     - just reliable -- for each session - 

# Subsample control data for scatter plot
# Fisher z transformation (but label by original correlation...)
# Question:
# -- if reliable on 1 day, is it reliable on other days? 
# -- make a matrix with source and target, color by fraction of reliable on target out of those reliable on source
# -- also do this with your session kernels for control and red
# -- also do this for different reliability cutoffs

In [None]:
# Plan for attack:
# Make a suite of summary figures on a session by session basis and a multisession basis. 
# I just want to be able to look through a mouse's data and evaluate the behavior, the imaging data, and how well the tracking did.

# Inclusions:
# 1. Behavioral data (running speed and number of trials across any environments it was in -- also metadata about day in environment...)
# 2. Imaging data (example snakes from all environments, both train/test comparisons and remapping comparisons)
# 3. Red cell data (number of red cells per plane -- and some examples of red cells?)
# 4. Tracking data (number of tracked cells per combination (full matrix!), number of tracked red cells, number of tracked reliable cells per environment)

In [None]:
# ROICaT Analysis
# 1. Example data figures (from roicat_stats)
#    - show two post-alignment FOVs, highlight a few tracked neurons and a few (nearby) un-tracked neurons with colors
#    - below that, show the place field tuning in each session, color-coded the same way as the neurons ROI plot

# 2. Analysis of ROICaT agreement with functional data
#    - scatter plot of sConj & place field correlation (with "labels" pairs colored differently)
#    - mouse by mouse, session by session mean lines comparing average place field correlation of same pairs with different pairs
#        -- (imagining lines from 0->1 for mouse 1 of each sessions mean same/diff pfCorr, then also in 2->3 for mouse 2, and 4->5 for mouse 3, etc.)
#        -- can also have a supplemental plot showing distribution of same/diff pfCorr across each session pair? 

# 3. Control analysis with null model test (empirical version with subsampled null distribution)

# 4. Control analysis with bayesian model
# -- get pfCorr_withinSession (this is pairs of ROIs within a session, and should be representative of pfCorr_diff_acrossSession)
# -- normalize pfCorr_all_acrossSession by number of pairs, subtract density of pfCorr_withinSession
# -- remaining pfCorr_remain_acrossSession = pfCorr_same_acrossSession


# ---- note ----
# - should probably include target reliable pairs not represented in the source only reliable category...

In [None]:
# Check ROICaT within session matching
import pickle

mouseName = 'ATL012'
sessionDate = '2023-02-28'
sessionID = '701'

ses = session.vrExperiment(mouseName, sessionDate, sessionID)
print(ses.sessionPrint())

data = ses.loadone('mpci.roiActivityDeconvolvedOasis')

filepath = ses.sessionPath() / f"{mouseName}.within_session.ROICaT.tracking.results.pkl"
with open(filepath, 'rb') as f:
    roicat = pickle.load(f)

num_clusters = np.max(roicat['clusters']['labels'])+1

In [None]:
scc = analysis.sameCellCandidates(ses, keepPlanes=None)

In [None]:
explorer = analysis.clusterExplorerROICaT(scc, roicat['clusters']['labels'], keepPlanes=None)

In [None]:
cc_number = np.zeros(num_clusters)
cc_within = np.zeros(num_clusters)
for ic in range(num_clusters):
    cidx = roicat['clusters']['labels']==ic
    cdata = data[:, cidx]
    ccorr = np.corrcoef(cdata.T)
    idx_offdiag = np.triu(np.ones_like(ccorr, dtype=bool), k=1)
    cvals = ccorr[idx_offdiag]
    cc_within[ic] = np.mean(cvals)
    cc_number[ic] = np.sum(cidx)

plt.close('all')
plt.scatter(cc_number, cc_within)
plt.show()

In [None]:
i_clus = np.random.randint(0, num_clusters)

plt.close('all')
plt.plot(range(data.shape[0]), data[:, roicat['clusters']['labels']==i_clus])
plt.show()

In [None]:
np.sum(roicat['clusters']['labels']>-1)

In [None]:
roicat

In [None]:
# ucids in list of lists for requested sessions
ucids = [[[] for _ in range(num_ses)] for _ in range(num_planes)]
for planeidx, results in enumerate([self.results[p] for p in keepPlanes]):
    for sesidx, idx in enumerate(idx_ses):
        ucids[planeidx][sesidx] = results['clusters']['labels_bySession'][idx]

# this is the number of unique IDs per plane
num_ucids = [max([np.max(u) for u in ucid])+1 for ucid in ucids]

# this is a boolean array of size (number unique IDs x num sessions) where there is a 1 if a unique ROI is found in each session
roicat_index = [np.zeros((nucids, num_ses), dtype=bool) for nucids in num_ucids]
for planeidx, ucid in enumerate(ucids):
    for sesidx, uc in enumerate(ucid):
        cindex = uc[uc >= 0] # index of ROIs found in this session
        roicat_index[planeidx][cindex, sesidx] = True # label found ROI with True

return ucids, roicat_index

    # get ucids and 1s index for requested sessions
    ucids, roicat_index = self.prepare_tracking_idx(idx_ses=idx_ses, keepPlanes=keepPlanes)
    
    # list of UCIDs in all requested sessions (a list of the UCIDs...)
    idx_in_ses = [np.where(np.all(rindex, axis=1))[0] for rindex in roicat_index]
    
    # For each plane & session, a sorted index to the suite2p ROI to recreate the list of UCIDs
    idx_to_ucid = [[helpers.index_in_target(iis, uc)[1] for uc in ucid] for (iis, ucid) in zip(idx_in_ses, ucids)]
    
    # cumulative number of ROIs before eacg plane (in numeric order of planes using sorted(self.plane_names))
    roi_per_plane = self.roi_per_plane[keepPlanes][:, idx_ses]
    roi_plane_offset = np.cumsum(np.vstack((np.zeros((1,num_ses),dtype=int), roi_per_plane[:-1])), axis=0)

    # A straightforward numpy array of (numSessions, numROIs) containing the indices to retrieve tracked and sorted ROIs
    return np.concatenate([np.stack([offset+ucid for offset, ucid in zip(offsets, ucids)], axis=1) for offsets, ucids in zip(roi_plane_offset, idx_to_ucid)], axis=0).T

In [None]:
# Note: remove multipage tiff from: C:\Users\Andrew\Documents\localData\ATL012\2023-02-09\701\suite2p

In [None]:
idx_red = track.check_red_cell_consistency(idx_ses=idx_ses, keepPlanes=None, use_s2p=True)
idx_has_red = idx_red[:, np.any(idx_red, axis=0)]
idx_sort = np.argsort(-np.sum(idx_has_red,axis=0))
idx_plot = idx_has_red[:, idx_sort]
plt.close('all')
plt.imshow(idx_plot, aspect='auto', interpolation='none')
plt.show()

In [None]:
idx_red = track.check_red_cell_consistency(idx_ses=idx_ses, keepPlanes=None)
idx_has_red = idx_red[:, np.sum(idx_red, axis=0)>0]
idx_sort = np.argsort(-np.sum(idx_has_red,axis=0))
idx_plot = idx_has_red[:, idx_sort]
plt.close('all')
plt.imshow(idx_plot, aspect='auto', interpolation='none')
plt.show()