# Motion correction and cell detection
---

Microglial cells with tdTomato + GCaMP5f

In [None]:
import cv2
try:
    # Best effort: cap OpenCV's internal thread pool. Failing here must not stop
    # the notebook, but it should be reported rather than silently ignored.
    # (An empty tuple in `except ():` matches no exception at all.)
    cv2.setNumThreads(8)
except Exception as err:
    print(f'cv2.setNumThreads(8) failed, continuing with OpenCV defaults: {err}')

import os
import sys
import glob
import yaml
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import caiman as cm
from caiman.motion_correction import MotionCorrect
from caiman.source_extraction.cnmf import cnmf as cnmf
from caiman.source_extraction.cnmf import params as params
from caiman.source_extraction.cnmf import utilities as util

from skimage.util import montage
from skimage.filters import rank
from skimage import morphology
from skimage import exposure
from skimage import measure

from scipy.ndimage import measurements

import bokeh.plotting as bpl
import holoviews as hv
# Route Bokeh output into the notebook and load the HoloViews Bokeh backend
# (presumably needed by CaImAn's *_nb plotting helpers used below).
bpl.output_notebook()
hv.notebook_extension('bokeh')

# Parameters

#### Input file paths

In [None]:
samp_name = 'F7'
# Sample directory: every 'glia' substring is stripped from the first sys.path
# entry (presumably the notebook's directory) and 'data_glia/<samp_name>' is appended.
# NOTE(review): this removes *every* 'glia' occurrence in the path — fragile.
samp_path = os.path.join(''.join(sys.path[0].split('glia')), 'data_glia', samp_name)

# per-sample YAML metadata (e.g. 'Reg_time', used when saving the profiles CSV)
with open(f'{samp_path}/{samp_name}_meta.yaml', encoding='utf-8') as f:
    samp_meta = yaml.safe_load(f)

file_raw = f'{samp_path}/{samp_name}_ca.tif'
# C-order memmap name, derived from samp_name like the base_name used when it is
# saved (dimensions and frame count are specific to this recording)
file_memmap = f'{samp_path}/{samp_name}_memmap_d1_512_d2_512_d3_1_order_C_frames_1500.mmap'
file_fit = f'{samp_path}/{samp_name}_fit.hdf5'
file_refit = f'{samp_path}/{samp_name}_refit.hdf5'

#### CaImAn parameters

In [None]:
# data params
ca_path = [file_raw]
fr = 1                      # imaging rate in frames per second
decay_time = 3              # length of a typical transient in seconds (see source/Getting_Started.rst)
dxy = (0.311, 0.311)        # spatial resolution of FOV in pixels per um

# patch params
rf = 50                     # half-size of the patches in pixels. e.g., if rf=25, patches are 50x50
stride = 10                 # amount of overlap between the patches in pixels

# pre-process params
# (noise_method is defined once, together with noise_range, in the spatial/temporal group)
only_init = False

# motion correction params
max_deviation_rigid = 3     # maximum shifts deviation allowed for patch with respect to rigid shifts
max_shifts = (10, 10)       # maximum allowed rigid shifts (in pixels)
strides = (80, 80)          # start a new patch for pw-rigid motion correction every x pixels
overlaps = (20, 20)         # overlap between patches (patch size = strides + overlaps)
pw_rigid = True             # flag for performing non-rigid motion correction

# init params
K = 4                            # number of components to be found (per patch or whole FOV depending on whether rf=None)
gSig = [15, 15]                  # radius of average spatial components (in pixels)
ssub = 5                         # spatial subsampling during initialization
tsub = 2                         # temporal subsampling during initialization
method_init = 'greedy_roi'       # initialization method ('sparse_nmf' NOT WORKING!), 'graph_nmf'
seed_method = 'auto'             # methods for choosing seed pixels during greedy_roi or corr_pnr initialization

# merge params
merge_thr = 0.1                  # trace correlation threshold for merging two components.
merge_parallel = True            # perform merging in parallel

# spatial and temporal params
nb = 2                           # number of global background components
method_deconvolution = 'oasis'   # method for solving the constrained deconvolution problem ('oasis','cvx' or 'cvxpy') if method cvxpy, primary and secondary (if problem unfeasible for approx solution)
noise_range = [0.25, 0.5]        # range of normalized frequencies over which to compute the PSD for noise determination
noise_method = 'logmexp'         # PSD averaging method for computing the noise std
p = 2                            # order of the autoregressive system

# quality params
min_SNR = 5                      # trace SNR threshold. Traces with SNR above this will get accepted
SNR_lowest = 2                   # minimum required trace SNR. Traces with SNR below this will get rejected
rval_thr = 0.65                  # space correlation threshold. Components with correlation higher than this will get accepted
rval_lowest = -2                 # minimum required space correlation. Components with correlation below this will get rejected
use_cnn = False                  # flag for using the CNN classifier


# every key appears exactly once (a repeated key in a dict literal silently overrides)
param_dict = {'fnames': ca_path,
              'fr': fr,
              'decay_time': decay_time,
              'dxy': dxy,
              'rf': rf,
              'stride': stride,
              'noise_method': noise_method,
              'only_init': only_init,
              'max_deviation_rigid': max_deviation_rigid,
              'max_shifts': max_shifts,
              'strides': strides,
              'overlaps': overlaps,
              'pw_rigid': pw_rigid,
              'K': K,
              'gSig': gSig,
              'ssub': ssub,
              'tsub': tsub,
              'method_init': method_init,
              'seed_method': seed_method,
              'merge_thr': merge_thr,
              'merge_parallel': merge_parallel,
              'nb': nb,
              'method_deconvolution': method_deconvolution,
              'noise_range': noise_range,
              'p': p,
              'min_SNR': min_SNR,
              'SNR_lowest': SNR_lowest,
              'rval_thr': rval_thr,
              'rval_lowest': rval_lowest,
              'use_cnn': use_cnn}

opts = params.CNMFParams(params_dict=param_dict)

# Motion correction
*Run this section only if there is no existing memmap file*

In [None]:
# Preview of the raw Ca channel before motion correction (set to False to skip)
display_movie = True
if display_movie:
    ds_ratio = 0.31            # third (frame) resize factor used for playback
    m_orig = cm.load_movie_chain(ca_path)
    m_orig.resize(1, 1, ds_ratio).play(q_max=99.5, fr=100, magnification=1)

### Motion correction

In [None]:
# start a cluster for parallel processing (if a cluster already exists it will be closed and a new session will be opened)
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='local', n_processes=None, single_thread=False)

mc = MotionCorrect(ca_path, dview=dview, **opts.get_group('motion'))

# save_movie is a boolean flag (a path string was only ever used for its truthiness);
# CaImAn picks the output location itself, presumably next to the raw file in samp_path
mc.motion_correct(save_movie=True)
m_els = cm.load(mc.fname_tot_els)
# with border_nan == 'copy' the borders are filled in, so nothing needs to be zeroed
border_to_0 = 0 if mc.border_nan == 'copy' else mc.border_to_0

In [None]:
# Playback of raw vs. motion-corrected movie, concatenated along axis 2 (side by side);
# play() also writes the movie to disk when save_avi is True
display_movie = True
save_avi = True
if display_movie:
    ds_ratio = 0.2
    m_orig = cm.load_movie_chain(ca_path)
    raw_ds = m_orig.resize(1, 1, ds_ratio) - mc.min_mov * mc.nonneg_movie
    corrected_ds = m_els.resize(1, 1, ds_ratio)
    side_by_side = cm.concatenate([raw_ds, corrected_ds], axis=2)
    side_by_side.play(fr=30, q_max=99.5, magnification=1, offset=0, save_movie=save_avi)

In [None]:
# Re-save the motion-corrected movie (mc.mmap_file) as a C-order memmap, the
# layout loaded for CNMF below. border_to_0 comes from the motion-correction cell.
# ca_new is presumably the path of the new memmap (it is later passed to cm.load_memmap).
ca_new = cm.save_memmap(mc.mmap_file, base_name=f'{samp_name}_memmap_', order='C',
                        border_to_0=border_to_0) # exclude borders


# CNMF fit & refit

*Run this section only if there is no existing refitted CNMF file*

## memmap loading

In [None]:
# Load the C-order memmap: reuse the file on disk if it exists, otherwise the one
# just written by the motion-correction cells (`ca_new`). file_memmap is always a
# str, so the choice must be made on file existence, not on type.
if os.path.isfile(file_memmap):
    memmap_to_load = file_memmap
else:
    memmap_to_load = ca_new    # NameError here means the motion-correction cells were not run
Yr, dims, T = cm.load_memmap(memmap_to_load)

# (T, d1, d2) movie built from the transposed memmap data
ca_images = np.reshape(Yr.T, [T] + list(dims), order='F')
# local correlation image, used as the background of all component plots
Cn = cm.local_correlations(ca_images, swap_dim=False)
Cn[np.isnan(Cn)] = 0

plt.figure(figsize=(8, 8))
plt.imshow(Cn, cmap='magma')
plt.show()

## Fit section

#### Start cluster

In [None]:
#%% (re)start the local cluster used for parallel processing; a running one is shut down first
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(backend='local',
                                                 n_processes=None,
                                                 single_thread=False)

In [None]:
# Reuse a saved fit when it exists on disk; otherwise fit CNMF and save it.
# (file_fit is always a str, so an isinstance check could never reach the fit branch.)
if os.path.isfile(file_fit):
    cnm = cnmf.load_CNMF(file_fit, n_processes=1, dview=dview)
else:
    cnm = cnmf.CNMF(n_processes, params=opts, dview=dview)
    cnm = cnm.fit(ca_images)
    save_results = True
    if save_results:
        cnm.save(file_fit)

cnm.estimates.plot_contours_nb(img=Cn, cmap='magma')

## Refit section

#### Filtering

In [None]:
# Quality evaluation: fills estimates.idx_components / idx_components_bad
cnm.estimates.evaluate_components(ca_images, cnm.params, dview=dview)

min_size = 100                    # minimal component area in px
max_size = int(np.prod(dims))     # maximal component area in px: the whole FOV (dims from load_memmap)
                                  # NB: `^` is bitwise XOR in Python, so 512^2 == 514, not 512**2

# Threshold the spatial footprints (presumably required by the area filter), then
# drop components that are too small or too large from idx_components
cnm.estimates.threshold_spatial_components(maxthr=0.3, dview=dview)
cnm.estimates.remove_small_large_neurons(min_size_neuro=min_size, max_size_neuro=max_size)
cnm.estimates.plot_contours_nb(img=Cn, idx=cnm.estimates.idx_components, cmap='magma')

print(f'{len(cnm.estimates.idx_components)} good components: {cnm.estimates.idx_components}')
print(f'{len(cnm.estimates.idx_components_bad)} bad components: {cnm.estimates.idx_components_bad}')

# optional: keep only the accepted components in the estimates object
# cnm.estimates.select_components(idx_components=cnm.estimates.idx_components)

#### Refit

In [None]:
# Reuse a saved refit when it exists on disk; otherwise refit (see CaImAn's
# CNMF.refit, which starts from the current estimates of cnm) and save it.
if os.path.isfile(file_refit):
    cnm2 = cnmf.load_CNMF(file_refit, n_processes=1, dview=dview)
else:
    cnm2 = cnm.refit(ca_images, dview=dview)

    save_results = True
    if save_results:
        cnm2.save(file_refit)

In [None]:
#%% stop and relaunch the local cluster so worker memory is released
cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(backend='local',
                                                 n_processes=None,
                                                 single_thread=False)

In [None]:
# Evaluate the refitted components with the refitted model's own parameters
# (cnm2 may have been loaded from disk, so cnm.params is not guaranteed to match)
cnm2.estimates.evaluate_components(ca_images, cnm2.params, dview=dview)

In [None]:
# Accepted components of the first fit (cnm) vs. the refit (cnm2), both over Cn
cnm.estimates.plot_contours_nb(img=Cn, idx=cnm.estimates.idx_components, cmap='magma')
cnm2.estimates.plot_contours_nb(img=Cn, idx=cnm2.estimates.idx_components, cmap='magma')

#### Final CNMF selection

In [None]:
# The refitted model is the one used for ΔF/F, the plots and the CSV export below
fin_cnm = cnm2
print(fin_cnm.estimates.idx_components)

# Plot & output

#### Plot func

In [None]:
def dF_cascade_plot(prof_arr, y_shift=0.5):
    """Draw several ΔF/F profiles as an upward cascade.

    prof_arr -- 2d numpy array [prof_num, prof_val]; one ΔF/F profile per row.
    y_shift  -- vertical offset added to each successive profile.

    A vertical scale bar from -0.2 to 0.8 (labelled "100% ΔF/F") is drawn at x=-20.
    """
    plt.figure(figsize=(20, 8))

    offset = 0
    for profile in prof_arr:
        plt.plot(profile + offset, alpha=.5)
        offset = offset + y_shift

    # scale bar: 1.0 in ΔF/F units
    plt.vlines(x=[-20], ymin=[-0.2], ymax=[0.8], linewidth=3, color='k')
    plt.text(x=-60, y=-0.1, s="100% ΔF/F", size=15, rotation=90.)
    plt.axis('off')
    plt.show()

def comp_mask_plot(samp_cnmf, samp_img):
    """Overlay the masks of the accepted spatial components (A) on a control image.

    Each accepted component ``i`` is painted with label ``i + 1`` (label 0 is the
    masked-out background, so component 0 would otherwise vanish) and annotated
    with that same 1-based number at its center of mass.

    samp_cnmf -- CaImAn CNMF object; reads ``estimates.A`` and
                 ``estimates.idx_components`` (all components when it is None).
    samp_img  -- 2D image of the FOV (e.g. ``Cn``); its shape is used to
                 unflatten the spatial footprints.
    """
    # (n_components, d1, d2) spatial footprints
    A = samp_cnmf.estimates.A.toarray().reshape(samp_img.shape + (-1,), order='F').transpose([2, 0, 1])
    idx_components = samp_cnmf.estimates.idx_components
    if idx_components is None:
        idx_components = range(A.shape[0])

    all_A = np.zeros_like(A[0])
    component_centers = {}
    for i in idx_components:
        A_frame = A[i]
        component_centers[i] = np.asarray(measurements.center_of_mass(A_frame), dtype=int)
        all_A[A_frame != 0] = i + 1

    all_A = np.ma.masked_where(all_A == 0, all_A, copy=False)

    plt.figure(figsize=(10, 10))
    plt.imshow(samp_img, cmap='magma')
    plt.imshow(all_A, cmap='jet', alpha=.5)
    for component_num, center_coord in component_centers.items():
        # label text at (x=col, y=row) of the component's center of mass
        plt.annotate(f'{component_num + 1}',
                     (center_coord[1], center_coord[0]),
                     textcoords="offset points",
                     xytext=(2, 2),
                     ha='center',
                     color='white')
    plt.show()

def comp_contour_plot(samp_cnmf, samp_img):
    """Draw contours of the accepted spatial components (A) over a control image.

    Each accepted component gets a filled overlay, a contour in the next color
    of Matplotlib's default property cycle, and an "ROI <i+1>" label at its
    center of mass.

    samp_cnmf -- CaImAn CNMF object; reads ``estimates.A`` and
                 ``estimates.idx_components`` (all components when it is None).
    samp_img  -- 2D image of the FOV (e.g. ``Cn``); its shape is used to
                 unflatten the spatial footprints.

    https://stackoverflow.com/questions/28779559/how-to-set-same-color-for-markers-and-lines-in-a-matplotlib-plot-loop
    """
    import itertools

    # (n_components, d1, d2) spatial footprints
    A = samp_cnmf.estimates.A.toarray().reshape(samp_img.shape + (-1,), order='F').transpose([2, 0, 1])
    idx_components = samp_cnmf.estimates.idx_components
    if idx_components is None:
        idx_components = range(A.shape[0])

    plt.figure(figsize=(10, 10))
    plt.imshow(samp_img, cmap='magma')
    # public equivalent of the private ax._get_lines.prop_cycler (removed in Matplotlib 3.8);
    # explicit-color plot() calls never advance the axes cycle, so the sequence is the same
    colors = itertools.cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])

    for i in idx_components:
        A_bin = (A[i] != 0).astype(float)
        A_masked = np.ma.masked_where(A_bin == 0, A_bin, copy=False)
        A_center = measurements.center_of_mass(A_bin)
        # find_contours returns a list of (n, 2) arrays of differing lengths;
        # it must not be wrapped in np.asarray (ragged arrays raise in NumPy >= 1.24)
        A_contours = measure.find_contours(A_bin, level=0.5)

        plt.imshow(A_masked, cmap='jet', alpha=.5)
        color = next(colors)
        for cont in A_contours:
            plt.plot(cont[:, 1], cont[:, 0], linewidth=2, color=color)
        plt.annotate(f'ROI {i+1}',
                     (A_center[1], A_center[0]),
                     textcoords="offset points",
                     xytext=(2, 2),
                     ha='center',
                     color='white',
                     weight='bold',
                     fontsize=10)
    plt.axis('off')
    plt.show()

def comp_dF_plot(samp_cnmf, y_shift=0.5):
    """Plot the detrended ΔF/F trace of every accepted component.

    Traces come from ``samp_cnmf.estimates.F_dff`` for the indices in
    ``samp_cnmf.estimates.idx_components``. They are stacked downwards, each
    ``y_shift`` below the previous one, and labelled "ROI <i+1>" in the legend.
    A "100% ΔF/F" scale bar is drawn at x=-20.
    """
    plt.figure(figsize=(20, 8))

    baseline = 0
    for comp_idx in samp_cnmf.estimates.idx_components:
        trace = samp_cnmf.estimates.F_dff[comp_idx]
        plt.plot(trace + baseline, alpha=.5, label=f'ROI {comp_idx+1}')
        baseline = baseline - y_shift

    # scale bar spanning 1.0 in ΔF/F units
    plt.vlines(x=[-20], ymin=[-0.2], ymax=[0.8], linewidth=3, color='k')
    plt.text(x=-60, y=-0.5, s="100% ΔF/F", size=15, rotation=90.)
    plt.axis('off')
    plt.legend(loc=1)
    plt.show()

#### ΔF/F calc & ctrl plot

In [None]:
# Detrended ΔF/F of the final model; the result is read from estimates.F_dff by the
# plots and the CSV export. quantileMin/frames_window configure the running-percentile
# baseline; with flag_auto=True CaImAn presumably picks the quantile itself — verify.
fin_cnm.estimates.detrend_df_f(quantileMin=5, frames_window=500,
                            flag_auto=True, use_fast=False, detrend_only=False)

# control plots: component contours over Cn, then the stacked ΔF/F traces
comp_contour_plot(fin_cnm, Cn)
comp_dF_plot(fin_cnm)

#### Output CSV saving

In [None]:
def save_prof_df(samp_cnmf, samp_img, reg_time, samp_name, samp_path):
    """Save the profiles of all accepted components as a long-format CSV.

    One row per (component, frame) is written to
    ``{samp_path}/{samp_name}_components_df.csv`` and the head of the table
    is printed.

    Parameters
    ----------
    samp_cnmf : CaImAn CNMF object; reads ``estimates.A``, ``estimates.C``,
        ``estimates.F_dff`` and ``estimates.idx_components``.
    samp_img : movie array shaped (frames, d1, d2), e.g. ``ca_images``.
    reg_time : total registration time; frame times are spread evenly over
        ``[0, reg_time]`` with both ends included.
    samp_name : sample name, used for the ``reg_name`` column and the file name.
    samp_path : output directory.

    Returns
    -------
    pandas.DataFrame
        The saved table. Columns: reg_name, index (frame index), time,
        comp (0-based component number), profile_raw (mean raw value inside
        the component mask), profile_C (denoised temporal component) and
        profile_ddf (detrended ΔF/F).
    """
    # (n_components, d1, d2) spatial footprints
    A = samp_cnmf.estimates.A.toarray().reshape(samp_img.shape[1:] + (-1,), order='F').transpose([2, 0, 1])

    columns = ['reg_name',      # registration name
               'index',         # frame index
               'time',          # registration time
               'comp',          # component num
               'profile_raw',   # component raw value, total mean
               'profile_C',     # component denoised value
               'profile_ddf']   # component detrended ΔF/F value

    frame_num = samp_cnmf.estimates.C.shape[1]
    i_col = np.arange(frame_num)
    time_col = np.linspace(0, reg_time, num=frame_num)
    reg_name_col = np.full(frame_num, samp_name)

    # collect per-component frames and concatenate once (avoids quadratic growth)
    component_dfs = []
    for component_num in samp_cnmf.estimates.idx_components:
        A_mask = A[component_num] != 0

        # mean of every frame over the component's spatial mask (vectorized over frames)
        est_raw = samp_img[:, A_mask].mean(axis=1)

        component_dfs.append(pd.DataFrame({
            'reg_name': reg_name_col,
            'index': i_col,
            'time': time_col,
            'comp': np.full(frame_num, component_num),
            'profile_raw': est_raw,
            'profile_C': samp_cnmf.estimates.C[component_num],       # temporal component
            'profile_ddf': samp_cnmf.estimates.F_dff[component_num],  # detrended ΔF/F
        }, columns=columns))

    if component_dfs:
        output_df = pd.concat(component_dfs, ignore_index=True)
    else:
        output_df = pd.DataFrame(columns=columns)

    output_df.to_csv(f'{samp_path}/{samp_name}_components_df.csv', index=False)
    print(output_df.head())
    return output_df

# export the accepted components' profiles of the final model to CSV
save_prof_df(samp_cnmf=fin_cnm, samp_img=ca_images,
             samp_name=samp_name, samp_path=samp_path,
             reg_time=samp_meta['Reg_time'])

#### Stop cluster and clean up log files

In [None]:
#%% stop the cluster, then delete files matching '*_LOG_*' in the working directory
# (presumably the per-worker logs written by CaImAn)
cm.stop_server(dview=dview)
for log_path in glob.glob('*_LOG_*'):
    os.remove(log_path)

In [None]:
# miscellaneous notes and snippets

# cnm.estimates.play_movie(ca_images, q_max=95, magnification=1, include_bck=False, use_color=True, save_movie=True)

# time-series correlation: https://towardsdatascience.com/four-ways-to-quantify-synchrony-between-time-series-data-b99136c4a9c9