In [1]:
from IPython.display import display, HTML
## Widen the notebook's main container to 95% of the browser window
display(
    HTML("<style>.container { width:95% !important; }</style>")
)

In [2]:
from pathlib import Path
import copy
from functools import partial

import numpy as np
import scipy.sparse
import matplotlib.pyplot as plt

In [3]:
## Container for the user-set parameters of this run; the 'paths' entry is filled in below
params = {}

In [8]:
## Input / output locations used by this run
params['paths'] = {
    ## Directory to save results files into
    'dir_save': r'/media/rich/bigSSD/downloads_tmp/tmp_data/mouse_0315N/day0_2/',
    ## Directory with F.npy, stat.npy etc.
    'dir_s2p_outer': r'/media/rich/bigSSD/downloads_tmp/tmp_data/mouse_0315N/statFiles/20230423/',
}

In [9]:
%load_ext autoreload
%autoreload 2
import bnpm.path_helpers, bnpm.file_helpers, bnpm.ca2p_preprocessing, bnpm.file_helpers, bnpm.path_helpers
import roicat

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
## Set to True to restrict the analysis to ROIs tracked across every session
use_multiple_sessions = False
path_roicat_classification = r'/home/rich/Desktop/mouse_0315N.ROICaT.classification_drawn.results.DAY0_2.pkl'

## The classification results are needed in both modes, so load them once.
## NOTE: unpickling can execute arbitrary code; only load results files from a trusted source.
results_roicat_classification = bnpm.file_helpers.pickle_load(str(Path(path_roicat_classification).resolve()))

if use_multiple_sessions:
    ## NOTE(review): this tracking file is named for mouse_0322R while the classification
    ## file above is named for mouse_0315N -- confirm the two belong together before enabling.
    path_roicat_tracking = r'/home/rich/Desktop/mouse_0322R.ROICaT.tracking.results.DAY0.pkl'

    results_roicat_tracking = bnpm.file_helpers.pickle_load(str(Path(path_roicat_tracking).resolve()))

    ## Mask the tracked UCIDs with the classifier predictions, then discard UCIDs
    ## that are not present in all sessions (n_sesh_thresh='all')
    ucids_iscell_fullMatch = roicat.util.discard_UCIDs_with_fewer_matches(
        ucids=roicat.util.mask_UCIDs_with_iscell(
            ucids=results_roicat_tracking['UCIDs_bySession'],
            iscell=results_roicat_classification['preds'],
        ),
        n_sesh_thresh='all',
    )
    ## ROIs whose UCID is >= 0 after the masking above are kept (boolean mask per session)
    iscell_roicatClassification_trackingMatch = [np.array(u) >= 0 for u in ucids_iscell_fullMatch]
else:
    ## Convert the classifier predictions to boolean masks so both branches produce
    ## the same dtype. Left as integers, using these arrays as an index selects
    ## elements 0 and 1 repeatedly (integer fancy indexing) instead of masking.
    iscell_roicatClassification_trackingMatch = [
        np.asarray(preds_sesh).astype(np.bool_) for preds_sesh in results_roicat_classification['preds']
    ]

In [19]:
## Find the F.npy files under the outer suite2p directory.
## NOTE(review): reMatch looks like a regular expression, in which case the '.'
## in 'F.npy' matches any character -- confirm against bnpm.path_helpers.find_paths.
filepaths_F = bnpm.path_helpers.find_paths(
    reMatch='F.npy',
    depth=3,
    dir_outer=params['paths']['dir_s2p_outer'],
)
print(f'filepaths with F.npy: {filepaths_F}')

## Each session directory is the folder that directly contains an F.npy
dirs_s2p = [str(Path(path_F).parents[0]) for path_F in filepaths_F]

filepaths with F.npy: ['/media/rich/bigSSD/downloads_tmp/tmp_data/mouse_0315N/statFiles/20230423/F.npy']


In [15]:
## == IMPORT DATA ==
## Load the suite2p outputs of every session directory
outs = [bnpm.ca2p_preprocessing.import_s2p(dir_s2p) for dir_s2p in dirs_s2p]

## Regroup the per-session outputs into one list per variable (one entry per session).
## The sessions are NOT concatenated here.
F_all, Fneu_all, iscell_s2p_all, ops_all, spks_s2p_all, stat_all = (
    [out_sesh[i_out] for out_sesh in outs]
    for i_out in range(len(outs[0]))
)

spks.npy not found in /media/rich/bigSSD/downloads_tmp/tmp_data/mouse_0315N/statFiles/20230423


In [22]:
## Quick look at the classifier mask(s).
## NOTE(review): the saved output of this cell shows the whole results dict
## ('preds', 'spatialFootprints', ...). It comes from an earlier execution than the
## loading cell, which now keeps only ['preds']; re-run to refresh the output.
iscell_roicatClassification_trackingMatch

{'preds': [array([1, 0, 1, ..., 0, 0, 0])],
 'spatialFootprints': [<3320x524288 sparse matrix of type '<class 'numpy.float32'>'
  	with 411119 stored elements in Compressed Sparse Row format>],
 'FOV_height': 512,
 'FOV_width': 1024}

In [27]:
if use_multiple_sessions:
    FOVs_colored = roicat.visualization.compute_colored_FOV(
        spatialFootprints=results_roicat_tracking['ROIs']['ROIs_aligned'], 
        FOV_height=results_roicat_tracking['ROIs']['frame_height'], 
        FOV_width=results_roicat_tracking['ROIs']['frame_width'], 
    #     labels=results_tracking['UCIDs_bySession'], 
        labels=ucids_iscell_fullMatch,
    )

    %matplotlib notebook

    roicat.visualization.display_toggle_image_stack(FOVs_colored)

In [30]:
## Single session: use the imported traces as they are.
## Multiple sessions: pass the traces through match_arrays_with_ucids with the retained UCIDs.
if not use_multiple_sessions:
    F_matched, Fneu_matched = F_all, Fneu_all
else:
    F_matched = roicat.util.match_arrays_with_ucids(F_all, ucids_iscell_fullMatch)
    Fneu_matched = roicat.util.match_arrays_with_ucids(Fneu_all, ucids_iscell_fullMatch)

In [31]:
def plot_paired_trace_reduction(traces, reduction=np.nanmean):
    """
    Scatter a per-neuron summary of the first session against the same summary
    of the second session, with the identity line drawn for reference.

    args:
        traces (ndarray or sequence of ndarray):
            shape(n_sessions, n_neurons, n_timepoints)
            Only the first two sessions are plotted. Neurons must be in the
            same order in both sessions.
        reduction (function):
            function that accepts 'axis' argument. It is called with axis=1
            on each session. e.g. np.nanmean or
            functools.partial(np.percentile, q=30)

    raises:
        ValueError: if fewer than two sessions are given.
    """
    if len(traces) < 2:
        raise ValueError(f'Need at least 2 sessions to make a paired plot, got {len(traces)}')

    red_traces = [reduction(t, axis=1) for t in traces]
    ## Shared upper axis limit: largest non-NaN reduced value over all sessions
    red_traces_max = np.nanmax([np.nanmax(r) for r in red_traces])

    ## Readable name of the reduction for the axis labels. functools.partial
    ## objects are unwrapped so the label shows the wrapped function and its
    ## fixed keyword arguments rather than an object repr with a memory address.
    fn_base = getattr(reduction, 'func', reduction)
    name_reduction = getattr(fn_base, '__name__', repr(fn_base))
    kwargs_fixed = getattr(reduction, 'keywords', None)
    if kwargs_fixed:
        name_reduction += '(' + ', '.join(f'{k}={v}' for k, v in kwargs_fixed.items()) + ')'

    plt.figure()
    plt.scatter(red_traces[0], red_traces[1], alpha=0.1)
    plt.xlim([0, red_traces_max])
    plt.ylim([0, red_traces_max])
    ## Identity line: points on it have the same value in both sessions
    plt.plot([0, red_traces_max], [0, red_traces_max], 'k')
    plt.xlabel(f'traces day 0: {name_reduction}')
    plt.ylabel(f'traces day 1: {name_reduction}')
    

In [33]:
if use_multiple_sessions:
    ## Compare the per-ROI 30th percentile between sessions, first for the
    ## ROI traces (F) and then for the neuropil traces (Fneu)
    for traces_matched in (F_matched, Fneu_matched):
        plot_paired_trace_reduction(traces_matched, reduction=partial(np.percentile, q=30))

In [34]:
## Bookkeeping on the matched traces
n_sessions = len(F_matched)                     ## number of sessions
n_rois = F_matched[0].shape[0]                  ## number of rows (ROIs) in the first session
n_frames = sum(f.shape[1] for f in F_matched)   ## columns (frames) summed over all sessions

## 'fs' entry of the first session's ops -- presumably the frame rate in Hz; confirm
Fs = ops_all[0]['fs']

In [36]:
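# # Disabled one-off fix for mouse g2FB: delete frames 43000 - 50000.
# # NOTE: `F_toUse` / `Fneu_toUse` are not defined in the visible cells of this notebook;
# # adapt the names (e.g. to the per-session arrays in `F_matched` / `Fneu_matched`) before re-enabling.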
# F_toUse = np.delete(F_toUse, range(43000,50000), axis=1)
# Fneu_toUse = np.delete(Fneu_toUse, range(43000,50000), axis=1)

In [37]:
## Settings shared by the dFoF computations
percentile_baseline = 30
neuropil_fraction = 0.7

## dFoF per session, with no rolling window (rolling_percentile_window=None)
kwargs_dFoF = dict(
    neuropil_fraction=neuropil_fraction,
    percentile_baseline=percentile_baseline,
    rolling_percentile_window=None,
    multicore_pref=True,
    verbose=True,
)
outs = [
    bnpm.ca2p_preprocessing.make_dFoF(F=f, Fneu=fneu, **kwargs_dFoF)
    for f, fneu in zip(F_matched, Fneu_matched)
]

## Regroup the per-session outputs into one list per returned variable
dFoF_matched, dF_matched, F_neuSub_matched, F_baseline_matched = (
    [outs[i_sesh][i_out] for i_sesh in range(n_sessions)]
    for i_out in range(len(outs[0]))
)

## Record of the settings used, saved with the results at the end of the notebook
dFoF_params = dict(
    channelOffset_correction=0,
    percentile_baseline=percentile_baseline,
    neuropil_fraction=neuropil_fraction,
)

Calculated dFoF. Total elapsed time: 8.57 seconds


In [38]:
## Window for the rolling-percentile baseline, in frames: 15 * 60 * 30.
## NOTE(review): this reads as 15 min * 60 s/min * 30 frames/s, i.e. it assumes a
## 30 Hz frame rate rather than using `Fs` from the suite2p ops -- confirm.
win_rolling_percentile = 15*60*30

## Rolling-baseline settings, defined once so the exact values passed to
## make_dFoF are also the ones recorded in dFoF_params below
params_rolling = {
    "rolling_percentile_window": win_rolling_percentile,
    "roll_centered": True,
    "roll_stride": 1,
    "roll_interpolation": 'linear',
}

## dFoF per session using the rolling-percentile baseline
outs = [bnpm.ca2p_preprocessing.make_dFoF(
    F=f,
    Fneu=fneu,
    neuropil_fraction=neuropil_fraction,
    percentile_baseline=percentile_baseline,
    **params_rolling,
    multicore_pref=True,
    verbose=True
) for f, fneu in zip(F_matched, Fneu_matched)]

## Regroup the per-session outputs into one list per returned variable
dFoF_roll_matched , dF_roll_matched , F_roll_neuSub_matched , F_roll_baseline_matched = ([outs[i_sesh][i_out] for i_sesh in range(n_sessions)] for i_out in range(len(outs[0])))

## Settings saved alongside the results. The rolling-baseline settings are included
## because the rolling-baseline dFoF is the one passed to the trace quality metrics.
dFoF_params = {
    "channelOffset_correction": 0,
    "percentile_baseline": percentile_baseline,
    "neuropil_fraction": neuropil_fraction,
    **params_rolling,
}

100%|███████████████████████████████████████████| 36/36 [00:29<00:00,  1.23it/s]


Calculated dFoF. Total elapsed time: 38.35 seconds


In [39]:
## Per-ROI norm (np.linalg.norm along axis 1) of the rolling-baseline dFoF, first vs second session
if use_multiple_sessions:
    plot_paired_trace_reduction(dFoF_roll_matched, reduction=np.linalg.norm)

In [45]:
if use_multiple_sessions:
    ## Norm of each ROI's dFoF trace (along axis 1), stacked with one row per session
    norms = np.stack([np.linalg.norm(dfof, axis=1) for dfof in dFoF_roll_matched])
    # folds = (norms[0] - norms[1]) / ((norms[0] + norms[1])/2)
    ## Fano factor (variance / mean) of the norm across sessions, one value per ROI
    fanos = np.var(norms, axis=0) / np.mean(norms, axis=0)

    plt.figure()
    plt.hist(fanos, bins=bnpm.math_functions.bounded_logspace(1e-6, 1000, 200));
    plt.xscale('log')
    ## The plotted quantity is the Fano factor, not the (disabled) fractional change `folds`
    plt.xlabel('Fano factor (var / mean) across sessions of the norm of dFoF')
    plt.ylabel('number of ROIs')

In [111]:
## Sanity check: shape of the first session's rolling baseline restricted to classifier-positive ROIs
F_roll_baseline_matched[0][
    iscell_roicatClassification_trackingMatch[0].astype(np.bool_)
].shape

(1653, 108000)

In [126]:
%matplotlib notebook

## Thresholds handed to trace_quality_metrics, one (first, second) pair per metric
## -- presumably the accepted (min, max) range; confirm against bnpm.
thresh = dict(
    var_ratio__Fneu_over_F=(0, 0.6),
    EV__F_by_Fneu=(0, 0.75),
    base_FneuSub=(100, 2000),
    base_F=(200, 3500),
    nsr_autoregressive=(0, 6),
    noise_derivMAD=(0, 0.02),
    max_dFoF=(0.75, 10),
    baseline_var=(0, 0.03),
)
## Alternative, looser thresholds (disabled):
#     'var_ratio__Fneu_over_F': (0, np.inf),
#     'EV__F_by_Fneu': (0, 0.9),
#     'base_FneuSub': (75, 2000),
#     'base_F': (200, 3500),
#     'nsr_autoregressive': (0, 25),
#     'noise_derivMAD': (0, 0.1),
#     'max_dFoF': (0, 10),
#     'baseline_var': (0, 0.1),
## Alternative with infinite bounds (disabled):
# thresh = {
#     'var_ratio__Fneu_over_F': np.inf,
#     'EV__F_by_Fneu': np.inf,
#     'base_FneuSub': -np.inf,
#     'base_F': -np.inf,
#     'nsr_autoregressive': np.inf,
#     'noise_derivMAD': np.inf,
#     'max_dFoF': np.inf,
#     'baseline_var': np.inf,
# }

## Compute the metrics for every session on ALL ROIs. The classifier mask is not
## applied here; to restrict to classifier-positive ROIs, index each array below with
## iscell_roicatClassification_trackingMatch[0].astype(np.bool_).
## Note: dFoF and the baseline come from the rolling-baseline run, F_neuSub from the first run.
outs = [
    bnpm.ca2p_preprocessing.trace_quality_metrics(
        F=F_sesh,
        Fneu=Fneu_sesh,
        dFoF=dFoF_sesh,
        F_neuSub=F_neuSub_sesh,
        F_baseline_roll=F_baseline_sesh,
        percentile_baseline=percentile_baseline,
        Fs=Fs,
        plot_pref=True,
        thresh=thresh,
        device='cpu',
    )
    for F_sesh, Fneu_sesh, dFoF_sesh, F_neuSub_sesh, F_baseline_sesh in zip(
        F_matched,
        Fneu_matched,
        dFoF_roll_matched,
        F_neuSub_matched,
        F_roll_baseline_matched,
    )
]

## Regroup into one list per returned variable: the metrics and the pass/fail mask
tqm, iscell_tqm = (
    [out_sesh[i_out] for out_sesh in outs]
    for i_out in range(len(outs[0]))
)

## Indices of the ROIs that pass / fail the thresholds, per session
idxROI_tqm_toInclude = [np.nonzero(ic)[0] for ic in iscell_tqm]
idxROI_tqm_toExclude = [np.nonzero(~ic)[0] for ic in iscell_tqm]

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

ROIs excluded: 1218 / 3320
ROIs included: 2102 / 3320


In [129]:
import rastermap

## One-component Rastermap fit on the first session's dFoF, restricted to ROIs that
## pass both the trace-quality thresholds and the classifier
rmap = rastermap.Rastermap(n_components=1)
rmap.fit(
    dFoF_roll_matched[0][
        iscell_tqm[0] * iscell_roicatClassification_trackingMatch[0].astype(np.bool_)
    ]
)

nmin 200
0.20216727256774902
10.071236371994019
10.517228126525879
10.519766092300415
(38, 40)
(70,)
1.0
time; iteration;  explained PC variance
0.00s     0        0.0663      2
0.06s    10        0.2083      4
0.10s    20        0.2301      8
0.23s    30        0.2917      18
0.34s    40        0.3243      28
0.41s    50        0.3577      38
0.51s    60        0.3581      38
0.65s   final      0.3581
0.65s upsampled    0.3581


<rastermap.mapping.Rastermap at 0x7f4b08040be0>

In [131]:
## Raster of the first session's dFoF for the ROIs used in the Rastermap fit,
## with rows reordered by rmap.isort
plt.figure()
plt.imshow(
    dFoF_roll_matched[0][
        iscell_tqm[0] * iscell_roicatClassification_trackingMatch[0].astype(np.bool_)
    ][rmap.isort],
    aspect='auto',
    vmin=-0.1,
    vmax=2,
)

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f4c2d339f40>

In [138]:
## Per-session conjunction: an ROI is kept only if the classifier / tracking mask
## AND the trace-quality mask both accept it
iscell_conj = [
    (mask_classifier * mask_tqm).astype(np.bool_)
    for mask_classifier, mask_tqm in zip(iscell_roicatClassification_trackingMatch, iscell_tqm)
]

In [142]:
if use_multiple_sessions:
    ## Redo the UCID filtering with the combined mask (classifier AND trace quality):
    ## first mask the tracked UCIDs, then discard UCIDs not present in all sessions
    ucids_masked_tqm = roicat.util.mask_UCIDs_with_iscell(
        ucids=results_roicat_tracking['UCIDs_bySession'],
        iscell=iscell_conj,
    )
    ucids_iscell_fullMatch_tqm = roicat.util.discard_UCIDs_with_fewer_matches(
        ucids=ucids_masked_tqm,
        n_sesh_thresh='all',
    )

In [145]:
if use_multiple_sessions:
    FOVs_colored = roicat.visualization.compute_colored_FOV(
        spatialFootprints=results_roicat_tracking['ROIs']['ROIs_aligned'], 
        FOV_height=results_roicat_tracking['ROIs']['frame_height'], 
        FOV_width=results_roicat_tracking['ROIs']['frame_width'], 
    #     labels=results_tracking['UCIDs_bySession'], 
        labels=ucids_iscell_fullMatch_tqm,
    )

    %matplotlib notebook

    roicat.visualization.display_toggle_image_stack(FOVs_colored)

In [147]:
## Plot the last (up to) 10 ROIs per session that the classifier accepted AND that
## passed the trace-quality thresholds, each trace offset vertically by 4.
plt.figure()
for ii, dfof_sesh in enumerate(dFoF_roll_matched):
    ## Use a boolean mask: an integer 0/1 array would be treated as indices, not as a mask
    is_classifier_cell = np.asarray(iscell_roicatClassification_trackingMatch[ii]).astype(np.bool_)
    ## Indices in the full ROI axis, so they can index dfof_sesh directly.
    ## np.where returns a 1-tuple for a 1-D mask, so always take element [0].
    idx_toPlot = np.where(is_classifier_cell & iscell_conj[ii])[0][-10:]
    ## Offsets follow the number of selected ROIs so fewer than 10 still broadcasts
    plt.plot(dfof_sesh[idx_toPlot].T + np.arange(len(idx_toPlot))*4, alpha=0.3)


<IPython.core.display.Javascript object>

In [148]:
## Plot the last (up to) 10 ROIs per session that the classifier accepted but that
## FAILED the trace-quality thresholds, each trace offset vertically by 4.
plt.figure()
for ii, dfof_sesh in enumerate(dFoF_roll_matched):
    ## Use a boolean mask: an integer 0/1 array would be treated as indices, not as a mask
    is_classifier_cell = np.asarray(iscell_roicatClassification_trackingMatch[ii]).astype(np.bool_)
    ## Indices in the full ROI axis, so they can index dfof_sesh directly.
    ## np.where returns a 1-tuple for a 1-D mask, so always take element [0].
    idx_toPlot = np.where(is_classifier_cell & ~iscell_conj[ii])[0][-10:]
    ## Offsets follow the number of selected ROIs so fewer than 10 still broadcasts
    plt.plot(dfof_sesh[idx_toPlot].T + np.arange(len(idx_toPlot))*4, alpha=0.3)


<IPython.core.display.Javascript object>

In [149]:
## Display the configured paths (quick check of the save directory before writing results)
params['paths']

{'dir_save': '/media/rich/bigSSD/downloads_tmp/tmp_data/mouse_0315N/day0_2/',
 'dir_s2p_outer': '/media/rich/bigSSD/downloads_tmp/tmp_data/mouse_0315N/statFiles/20230423/'}

In [159]:
## Both results files are written into the resolved save directory
dir_save = Path(params['paths']['dir_save']).resolve()

## File 1: trace quality metrics, the masks, and the dFoF settings
results_trace_quality = {
    "tqm": tqm,
    "iscell_tqm": iscell_tqm,
    "iscell_classifier_matching_tqm": iscell_conj,
    "dFoF_params": dFoF_params
}
bnpm.file_helpers.pickle_save(
    obj=results_trace_quality,
    filepath=str(dir_save / 'trace_quality.pkl')
)

## File 2: the final inclusion masks; the UCIDs exist only in multi-session mode
if use_multiple_sessions:
    ucids_toSave = ucids_iscell_fullMatch_tqm
else:
    ucids_toSave = None
bnpm.file_helpers.pickle_save(
    obj={
        "iscell_classifier_matching_tqm": iscell_conj,
        "UCIDs_classifier_matching_tqm": ucids_toSave,
    },
    filepath=str(dir_save / 'iscell_classifier_matching_tqm.pkl'),
)