In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pylab as pl
from scipy import ndimage
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import inv
from skimage.measure import find_contours

import caiman as cm
from caiman.base.rois import com
from caiman.source_extraction.cnmf import params as params
from caiman.utils.utils import download_demo
from caiman.utils.visualization import get_contours

In [None]:
import caiman as cm

# Start a local cluster for parallel processing. If an earlier run left a
# `dview` in the namespace, that cluster is stopped first so that a new
# session is opened.
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='local',
    n_processes=None,
    single_thread=False,
)

### video load

In [None]:
# Alternative kept for reference: analyse a copy that already sits in ../data.
# filepath = os.path.join('../data', 'Sue_2x_3000_40_-46.tif')
# assert os.path.exists(filepath)
# Name of the demo recording; `fnames` is the one-element list of files that
# the rest of the notebook processes.
# NOTE(review): download_demo presumably fetches the file if needed and
# returns its local path - confirm against the CaImAn docs.
filepath = 'Sue_2x_3000_40_-46.tif'
fnames = [download_demo(filepath)]
fnames

In [None]:
# Optional preview of the raw recording; switch the flag on to play it.
display_movie = False
if display_movie:
    ds_ratio = 0.2  # third resize factor (presumably the temporal axis - TODO confirm)
    m_orig = cm.load_movie_chain(fnames)
    m_orig.resize(1, 1, ds_ratio).play(q_max=99.5, fr=30, magnification=2)

### Set up parameters

In [None]:
# --- dataset-dependent parameters ---
fr = 30             # imaging rate (frames per second)
decay_time = 0.4    # length of a typical transient (seconds)

# --- motion correction ---
strides = (48, 48)         # a new patch for pw-rigid motion correction starts every `strides` pixels
overlaps = (24, 24)        # overlap between patches (patch size is strides + overlaps)
max_shifts = (6, 6)        # maximum allowed rigid shifts, in pixels
max_deviation_rigid = 3    # maximum deviation of a patch shift from the rigid shifts
pw_rigid = True            # flag for performing non-rigid motion correction

# --- source extraction and deconvolution ---
p = 1                        # order of the autoregressive system
gnb = 2                      # number of global background components
merge_thr = 0.85             # merging threshold (maximum correlation allowed)
rf = 15                      # half-size of the patches in pixels, e.g. rf=25 gives 50x50 patches
stride_cnmf = 6              # amount of overlap between the patches, in pixels
K = 4                        # number of components per patch
gSig = [4, 4]                # expected half-size of neurons, in pixels
method_init = 'greedy_roi'   # initialization method (use 'sparse_nmf' when analyzing dendritic data)
ssub = 1                     # spatial subsampling during initialization
tsub = 1                     # temporal subsampling during initialization

# --- component evaluation ---
min_SNR = 2.0       # signal-to-noise ratio for accepting a component
rval_thr = 0.85     # space correlation threshold for accepting a component
cnn_thr = 0.99      # threshold for the CNN-based classifier
cnn_lowest = 0.1    # neurons with CNN probability lower than this value are rejected

In [None]:
# Border handling for motion correction. It is defined before the options
# dictionary and passed into it, so that the value tested when trimming the
# memory-mapped movie is the same one the motion-correction object receives.
# NOTE(review): 'copy' is believed to be CaImAn's default for this key, in
# which case wiring it in does not change the results - confirm.
border_nan = 'copy'

# Collect every tunable of the pipeline in one dictionary keyed by the CaImAn
# parameter names (note: gnb -> 'nb', stride_cnmf -> 'stride',
# cnn_thr -> 'min_cnn_thr').
opts_dict = {
    'fnames': fnames,
    'fr': fr,
    'decay_time': decay_time,
    'strides': strides,
    'overlaps': overlaps,
    'max_shifts': max_shifts,
    'max_deviation_rigid': max_deviation_rigid,
    'pw_rigid': pw_rigid,
    'border_nan': border_nan,
    'p': p,
    'nb': gnb,
    'rf': rf,
    'K': K,
    'stride': stride_cnmf,
    'method_init': method_init,
    'rolling_sum': True,
    'only_init': True,
    'ssub': ssub,
    'tsub': tsub,
    'merge_thr': merge_thr,
    'min_SNR': min_SNR,
    'rval_thr': rval_thr,
    'use_cnn': True,
    'min_cnn_thr': cnn_thr,
    'cnn_lowest': cnn_lowest,
}

# `params` is already imported in the notebook's import cell.
opts = params.CNMFParams(params_dict=opts_dict)
opts

In [None]:
# Debug check: shows whether a `dview` from an earlier cluster start is
# already present in the namespace (the cluster cells test the same condition).
'dview' in locals()

### setup cluster

In [None]:
# (Re)start the local processing cluster: an existing `dview` means a cluster
# is already running, so it is closed before a new session is opened.
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='local',
    n_processes=None,
    single_thread=False,
)

### motion correct

In [None]:
from caiman.motion_correction import MotionCorrect

# Create the motion-correction object; its settings are taken from the
# 'motion' parameter group of `opts`.
motion_kwargs = opts.get_group('motion')
mc = MotionCorrect(fnames, dview=dview, **motion_kwargs)

In [None]:
# Debug display of the cluster handle that was passed to MotionCorrect.
dview

In [None]:
%%capture
#%% Run piecewise-rigid motion correction using NoRMCorre
mc.motion_correct(save_movie=True)
m_els = cm.load(mc.fname_tot_els)
border_to_0 = 0 if mc.border_nan is 'copy' else mc.border_to_0 
    # maximum shift to be used for trimming against NaNs

In [None]:
# Optional side-by-side playback of the original and the motion-corrected
# movie; switch the flag on to play it.
display_movie = False
if display_movie:
    ds_ratio = 0.2
    m_orig = cm.load_movie_chain(fnames)
    original_ds = m_orig.resize(1, 1, ds_ratio) - mc.min_mov*mc.nonneg_movie
    corrected_ds = m_els.resize(1, 1, ds_ratio)
    side_by_side = cm.concatenate([original_ds, corrected_ds], axis=2)
    side_by_side.play(fr=60, gain=15, magnification=2, offset=0)  # press q to exit

In [None]:
import numpy as np

# Motion-corrected output to memory-map: the piecewise-rigid result when
# `pw_rigid` is set, otherwise the rigid one.
fname_mc = mc.fname_tot_els if pw_rigid else mc.fname_tot_rig

# Number of border pixels handed to save_memmap as `border_to_0`. Nothing is
# trimmed when borders were filled with border_nan == 'copy'; otherwise the
# maximum shift computed right after motion correction (`border_to_0`) is used.
# The earlier `else bord_px` read the variable before it was assigned, and
# `is 'copy'` compared identity rather than value.
bord_px = 0 if border_nan == 'copy' else border_to_0
fname_new = cm.save_memmap(fname_mc, base_name='memmap_', order='C', border_to_0=bord_px)

In [None]:
# Load the memory-mappable file written by save_memmap.
# `images` is a view of the transposed data reshaped to (T,) + dims in
# Fortran order, i.e. the first axis has length T.
Yr, dims, T = cm.load_memmap(fname_new)
images = Yr.T.reshape((T,) + dims, order='F')

In [None]:
from caiman.source_extraction import cnmf

# No initial spatial footprints are supplied.
# NOTE(review): with Ain=None CNMF presumably initialises components itself
# using `method_init` from `opts` - confirm.
Ain = None
cnm = cnmf.CNMF(
    n_processes=n_processes,
    dview=dview,
    Ain=Ain,
    params=opts,
)
cnm.fit(images)

In [None]:
import caiman as cm

# Plot the contours of the found components over the local-correlation image.
# The movie is passed with its first axis moved to the end.
Cn = cm.local_correlations(images.transpose(1, 2, 0))
nan_pixels = np.isnan(Cn)
Cn[nan_pixels] = 0  # NaN correlations are set to zero before plotting
cnm.estimates.plot_contours_nb(img=Cn)

In [None]:
# Inspect the spatial footprint matrix; the mask-building cell reads one
# column of it per component.
cnm.estimates.A

In [None]:
# Contour description of every component, computed with get_contours'
# default threshold settings and the shape of Cn as the image dimensions.
# Each entry is indexed by 'coordinates' in the cells below.
coordinates = get_contours(cnm.estimates.A, np.shape(Cn))

In [None]:
# Sanity check: number of contour entries, then the length of the first
# entry's 'coordinates' array.
print(len(coordinates))
print(len(coordinates[0]['coordinates']))

In [None]:
# Draw the outline of every component. The y axis is inverted so that the
# plot uses image orientation (row 0 at the top).
# The loop variable is deliberately not called `c`: that name holds the value
# returned by setup_cluster and would be overwritten.
for roi in coordinates:
    plt.plot(*roi['coordinates'].T)
plt.gca().invert_yaxis()

In [None]:
# Threshold settings shared with the mask-building cell further down.
thr = 0.9
thr_method = 'nrg'
cont = get_contours(
    cnm.estimates.A, cnm.dims, thr=thr, thr_method=thr_method, swap_dim=False)

# Per component: the NaN-ignoring mean of its contour coordinates (used as a
# centre estimate) and the transposed coordinate array, whose rows can be
# unpacked directly as plot arguments.
cont_cent = np.zeros([len(cont), 2])
sparse_rois = []
for i, roi in enumerate(cont):
    outline = roi['coordinates']
    cont_cent[i, :] = np.nanmean(outline, axis=0)
    sparse_rois.append(outline.T)

# Flag the accepted components. The builtin `bool` replaces the `np.bool`
# alias, which no longer exists in current NumPy releases.
iscell = np.zeros(cont_cent.shape[0], dtype=bool)
accepted = cnm.estimates.idx_components
if accepted is None:
    # NOTE(review): no visible cell runs component evaluation, so
    # idx_components may still be None. Indexing with None used to set every
    # entry to True implicitly; that outcome is kept but made explicit.
    iscell[:] = True
else:
    iscell[accepted] = True

In [None]:
# Draw every thresholded ROI outline in image orientation (y axis inverted).
# The loop variable is deliberately not called `c`: that name holds the value
# returned by setup_cluster and would be overwritten.
for outline in sparse_rois:
    plt.plot(*outline)
plt.gca().invert_yaxis()

In [None]:
# Prints the shapes of the non-NaN and NaN selections of `im`.
# NOTE(review): `im` is not assigned in any cell visible in this notebook, so
# this cell raises NameError after Restart & Run All - define `im` first or
# remove the cell.
print(im[~np.isnan(im)].shape)
print(im[np.isnan(im)].shape)

In [None]:

# Build one boolean mask per component (column of A): threshold the
# footprint, trace the contour at level `thr`, rasterise the contour and fill
# its interior. `ims` ends up with shape (number of components,) + dims and
# ims[i] always belongs to column i of A.
A = cnm.estimates.A  # assumes a CSC sparse matrix (data / indices / indptr are sliced per column) - TODO confirm
d, nr = np.shape(A)

ims = []
for i in range(nr):
    column = slice(A.indptr[i], A.indptr[i + 1])
    patch_data = A.data[column]      # stored values of component i
    patch_pixels = A.indices[column]  # flat pixel index of each stored value

    r_mask = np.zeros(dims, dtype=bool)
    if len(patch_data) == 0:
        # Empty component: keep an all-False mask so that `ims` stays aligned
        # with the columns of A (and leave the `coordinates` list untouched).
        ims.append(r_mask)
        continue

    if thr_method == 'nrg':
        # Cumulative energy of the values sorted from largest to smallest,
        # normalised to end at 1; pixels outside the component stay at 1.
        indx = np.argsort(patch_data)[::-1]
        cumEn = np.cumsum(patch_data[indx]**2)
        cumEn /= cumEn[-1]
        Bvec = np.ones(d)
        Bvec[patch_pixels[indx]] = cumEn
    else:
        # Values normalised by the component maximum; zero elsewhere.
        Bvec = np.zeros(d)
        Bvec[patch_pixels] = patch_data / patch_data.max()

    Bmat = np.reshape(Bvec, dims, order='F')

    # find_contours returns (row, column) vertices of the array it is given,
    # so tracing Bmat itself yields indices that address r_mask directly.
    # Tracing Bmat.T (as before) wrote the outline transposed.
    for vertices in find_contours(Bmat, thr):
        rows = np.round(vertices[:, 0]).astype('int')
        cols = np.round(vertices[:, 1]).astype('int')
        r_mask[rows, cols] = True

    # Fill in the hole enclosed by the contour boundary.
    ims.append(ndimage.binary_fill_holes(r_mask))

# np.stack rejects an empty list, so the no-component case gets an explicit
# empty stack of the right shape.
ims = np.stack(ims) if ims else np.zeros((0,) + tuple(dims), dtype=bool)

In [None]:
# Projection of all component masks onto one image: the NaN-ignoring maximum
# over the first axis of `ims`. The two commented lines are the float/NaN
# conversion that the following cell applies.
# ims = ims.astype(np.float32)
# ims[ims == 0] = np.nan
plt.imshow(np.nanmax(ims, axis=0))

In [None]:
# Same projection with the background blanked out: the masks are converted to
# float32 and every zero entry becomes NaN before the NaN-ignoring maximum
# over the first axis is shown. Note that `ims` itself is replaced.
ims = ims.astype(np.float32)
background = ims == 0
ims[background] = np.nan
plt.imshow(np.nanmax(ims, axis=0))