# Volumetric data processing
This is a simple demo of source extraction and deconvolution on toy 3D data using CaImAn.
For more information, see demo_pipeline.ipynb, which runs the complete pipeline on
2D two-photon imaging data.

In [None]:
from IPython import get_ipython
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.ndimage import gaussian_filter
from tifffile.tifffile import imwrite

import caiman as cm
from caiman.utils.visualization import nb_view_patches3d
import caiman.source_extraction.cnmf as cnmf
from caiman.paths import caiman_datadir

# Enable autoreload when running under IPython; outside IPython the
# __IPYTHON__ builtin is undefined and the NameError is deliberately ignored.
try:
    if __IPYTHON__:
        shell = get_ipython()
        for magic_name, magic_arg in (('load_ext', 'autoreload'), ('autoreload', '2')):
            shell.run_line_magic(magic_name, magic_arg)
except NameError:
    pass

# Route bokeh figures to the notebook output cells.
import bokeh.plotting as bpl
bpl.output_notebook()

logfile = None # Replace with a path if you want to log to a file
logger = logging.getLogger('caiman')
# Set to logging.INFO if you want much output, potentially much more output
logger.setLevel(logging.WARNING)
logfmt = logging.Formatter('%(relativeCreated)12d [%(filename)s:%(funcName)20s():%(lineno)s] [%(process)d] %(message)s')

# Re-running this cell must not stack a second handler on the logger (that
# would print every message twice), so first remove handlers that an earlier
# run of this cell installed. Handlers added by anyone else are left alone.
for stale in [h for h in logger.handlers if getattr(h, '_demo_notebook_handler', False)]:
    logger.removeHandler(stale)
    stale.close()

if logfile is not None:
    handler = logging.FileHandler(logfile)
else:
    handler = logging.StreamHandler()
handler.setFormatter(logfmt)
handler._demo_notebook_handler = True  # marker used by the cleanup loop above
logger.addHandler(handler)


Define a function to create some toy data

In [None]:
def gen_data(p=1, noise=.5, T=256, framerate=30, firerate=2., motion=True, plot=False):
    """Simulate a toy 3D calcium-imaging movie with known ground truth.

    Twenty point sources are placed at random positions inside a 70x50x10
    volume, driven by random spike trains filtered through an AR(p) process,
    blurred with a Gaussian and embedded in a constant background plus
    Gaussian noise. The global NumPy random state is re-seeded with 0 on
    every call, so the output is deterministic.

    Args:
        p: order of the autoregressive calcium model; must be 1 or 2.
        noise: scale factor applied to the standard-normal noise.
        T: number of frames.
        framerate: frames per second; only used through firerate / framerate.
        firerate: spikes per second; firerate / framerate is the per-frame
            spike probability.
        motion: if True, simulate on a volume padded by 2*sig on each side,
            shift every frame with cm.motion_correction.apply_shifts_dft and
            crop the padding away again.
        plot: if True, show the ground-truth traces, the raw traces at the
            neuron centers and max-projections of the correlation image.

    Returns:
        Tuple (Y, truth, trueSpikes, centers, dims, shifts):
        Y is the movie with shape (T,) + dims; truth is the (N, T) float32
        calcium traces; trueSpikes is the (N, T) boolean spike raster;
        centers is the (N, 3) integer neuron positions (when motion is True
        these are in the padded volume, i.e. offset by 2*sig from the
        coordinates of the returned Y); dims is (70, 50, 10); shifts is the
        negated (T, 3) array of applied shifts, or None when motion is False.

    Raises:
        ValueError: if p is neither 1 nor 2.
    """
    if p == 2:
        gamma = np.array([1.5, -.55])
    elif p == 1:
        gamma = np.array([.9])
    else:
        # A bare `raise` here would only produce "No active exception to reraise".
        raise ValueError('p must be 1 or 2, got %r' % (p,))
    dims = (70, 50, 10)  # size of image
    sig = (4, 4, 2)  # neurons size
    bkgrd = 10
    N = 20  # number of neurons
    pad = np.array(sig) * 2  # margin added on each side when simulating motion
    np.random.seed(0)
    centers = np.asarray([[np.random.randint(s, x - s)
                           for x, s in zip(dims, sig)] for i in range(N)])
    if motion:
        centers += pad
        Y = np.zeros((T,) + tuple(np.array(dims) + 2 * pad), dtype=np.float32)
    else:
        Y = np.zeros((T,) + dims, dtype=np.float32)
    trueSpikes = np.random.rand(N, T) < firerate / float(framerate)
    trueSpikes[:, 0] = 0
    truth = trueSpikes.astype(np.float32)
    # AR(p) filtering of the spike trains, in place.
    for i in range(2, T):
        if p == 2:
            truth[:, i] += gamma[0] * truth[:, i - 1] + gamma[1] * truth[:, i - 2]
        else:
            truth[:, i] += gamma[0] * truth[:, i - 1]
    for i in range(N):
        Y[:, centers[i, 0], centers[i, 1], centers[i, 2]] = truth[i]
    # Norm of a unit impulse blurred with the same kernel, used to scale the signal.
    tmp = np.zeros(dims)
    tmp[tuple(np.array(dims)//2)] = 1.
    z = np.linalg.norm(gaussian_filter(tmp, sig).ravel())
    Y = bkgrd + noise * np.random.randn(*Y.shape) + 10 * gaussian_filter(Y, (0,) + sig) / z
    if motion:
        # Smoothed random walk per axis; the 'full' convolution has length T.
        shifts = np.transpose([np.convolve(np.random.randn(T-10), np.ones(11)/11*s) for s in sig])
        Y = np.array([cm.motion_correction.apply_shifts_dft(img, (sh[0], sh[1], sh[2]), 0,
                                                                     is_freq=False, border_nan='copy')
                               for img, sh in zip(Y, shifts)])
        Y = Y[:, 2*sig[0]:-2*sig[0], 2*sig[1]:-2*sig[1], 2*sig[2]:-2*sig[2]]
        # Centers expressed in the cropped movie's coordinates (plotting only).
        plot_centers = centers - pad
    else:
        shifts = None
        plot_centers = centers

    if plot:
        Cn = cm.local_correlations(Y, swap_dim=False)
        plt.figure(figsize=(15, 3))
        plt.plot(truth.T)
        plt.figure(figsize=(15, 3))
        for c in plot_centers:
            # Time trace of the voxel at the neuron center (frames are axis 0).
            plt.plot(Y[:, c[0], c[1], c[2]])

        d1, d2, d3 = dims
        x, y = (int(1.2 * (d1 + d3)), int(1.2 * (d2 + d3)))
        scale = 6/x
        fig = plt.figure(figsize=(scale*x, scale*y))
        axz = fig.add_axes([1-d1/x, 1-d2/y, d1/x, d2/y])
        axz.imshow(Cn.max(2).T, cmap='gray')
        axz.scatter(*plot_centers.T[:2], c='r')
        axz.set_title('Max.proj. z')
        axz.set_xlabel('x')
        axz.set_ylabel('y')
        axy = fig.add_axes([0, 1-d2/y, d3/x, d2/y])
        axy.imshow(Cn.max(0), cmap='gray')
        axy.scatter(*plot_centers.T[:0:-1], c='r')
        axy.set_title('Max.proj. x')
        axy.set_xlabel('z')
        axy.set_ylabel('y')
        axx = fig.add_axes([1-d1/x, 0, d1/x, d3/y])
        axx.imshow(Cn.max(1).T, cmap='gray')
        axx.scatter(*plot_centers.T[np.array([0,2])], c='r')
        axx.set_title('Max.proj. y')
        axx.set_xlabel('x')
        axx.set_ylabel('z')
        plt.show()

    # Without motion there are no shifts to negate.
    return Y, truth, trueSpikes, centers, dims, (None if shifts is None else -shifts)

### Select file(s) to be processed
- create a file with a toy 3d dataset.

In [None]:
# Simulate the toy recording (AR(2) dynamics) and store it as a TIFF inside
# CaImAn's example_movies folder so the pipeline below can read it from disk.
Y, truth, trueSpikes, centers, dims, shifts = gen_data(p=2)
fname = os.path.join(caiman_datadir(), 'example_movies', 'demoMovie3D.tif')
imwrite(fname, Y)
print(fname)

### Display the raw movie (optional)

Show a max-projection of the correlation image

In [None]:
# Reload the movie from disk and compute its local correlation image.
Y = cm.load(fname)
Cn = cm.local_correlations(Y, swap_dim=False)

# Figure layout: three panels whose sizes are proportional to the volume extents.
d1, d2, d3 = dims
x = int(1.2 * (d1 + d3))
y = int(1.2 * (d2 + d3))
scale = 6/x
fig = plt.figure(figsize=(scale*x, scale*y))

# Projection along z (top right).
axz = fig.add_axes([1-d1/x, 1-d2/y, d1/x, d2/y])
axz.imshow(Cn.max(2).T, cmap='gray')
axz.set_title('Max.proj. z')
axz.set_xlabel('x')
axz.set_ylabel('y')

# Projection along x (top left).
axy = fig.add_axes([0, 1-d2/y, d3/x, d2/y])
axy.imshow(Cn.max(0), cmap='gray')
axy.set_title('Max.proj. x')
axy.set_xlabel('z')
axy.set_ylabel('y')

# Projection along y (bottom right).
axx = fig.add_axes([1-d1/x, 0, d1/x, d3/y])
axx.imshow(Cn.max(1).T, cmap='gray')
axx.set_title('Max.proj. y')
axx.set_xlabel('x')
axx.set_ylabel('z')

plt.show()

Play the movie (optional). This requires loading the movie into memory, which in general is not needed by the pipeline. Displaying the movie uses the OpenCV library. Press `q` to close the video panel.

In [None]:
# Play only index 5 along the last axis of Y (one slice of the volume over time), magnified 2x.
Y[...,5].play(magnification=2)

### Setup a cluster

In [None]:
#%% start a cluster for parallel processing (if a cluster already exists it will be closed and a new session will be opened)
# 'dview' is only defined if a cluster cell has already run in this kernel;
# in that case the old cluster is stopped before a new one is started.
if 'dview' in locals():
    cm.stop_server(dview=dview)
# dview and n_processes are handed to the processing steps in later cells.
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='multiprocessing', n_processes=None, single_thread=False)

### Motion Correction
First we create a motion correction object with the parameters specified. Note that the file is not loaded into memory.

In [None]:
# Parameters for motion correction of the 3D movie.
opts_dict = dict(
    fnames=fname,
    strides=(24, 24, 6),     # a new patch for pw-rigid motion correction starts every this many pixels
    overlaps=(12, 12, 2),    # overlap between patches (patch size is strides+overlaps)
    max_shifts=(4, 4, 2),    # largest rigid shift allowed (in pixels)
    max_deviation_rigid=5,   # largest deviation of a patch shift from the rigid shift
    pw_rigid=False,          # whether to perform non-rigid motion correction
    is3D=True,
)

opts = cnmf.params.CNMFParams(params_dict=opts_dict)

In [None]:
# first we create a motion correction object with the parameters specified;
# the 'motion' parameter group of opts is expanded into keyword arguments
mc = cm.motion_correction.MotionCorrect(fname, dview=dview, **opts.get_group('motion'))
# note that the file is not loaded in memory

In [None]:
%%capture
#%% Run motion correction using NoRMCorre
# %%capture suppresses this cell's output. save_movie=True is requested so the
# corrected movie is written out; the memory-mapping step below reads mc.mmap_file.
mc.motion_correct(save_movie=True)

In [None]:
# Side-by-side comparison of the shifts recovered by motion correction
# and the shifts that were used to generate the data.
plt.figure(figsize=(12,3))
panels = (('inferred shifts', mc.shifts_rig), ('true shifts', shifts))
for column, (panel_title, shift_series) in enumerate(panels, start=1):
    plt.subplot(1, 2, column)
    shift_array = np.array(shift_series)
    for axis_index, axis_name in enumerate(('x', 'y', 'z')):
        plt.plot(shift_array[:, axis_index], label=axis_name)
    plt.legend()
    plt.title(panel_title)
    plt.xlabel('frames')
    plt.ylabel('pixels')

### Memory mapping 

The cell below memory maps the file in order `'C'` and then loads the new memory mapped file. The saved files from motion correction are memory mapped files stored in `'F'` order. Their paths are stored in `mc.mmap_file`.

In [None]:
#%% MEMORY MAPPING
# Re-save the motion-corrected output (mc.mmap_file) as a memory-mapped file in order 'C'.
# NOTE(review): border_to_0=0 presumably leaves all border pixels untouched - confirm.
fname_new = cm.save_memmap(mc.mmap_file, base_name='memmap_', order='C',
                           border_to_0=0, dview=dview)

# Open the new memory-mapped file; this yields the array, the volume dims and the frame count.
Yr, dims, T = cm.load_memmap(fname_new)
# Frames-first view of the data with shape (T, *dims).
images = Yr.T.reshape((T,) + tuple(dims), order='F')

Now restart the cluster to clean up memory

In [None]:
#%% restart cluster to clean up memory
# Stop the running cluster, then start a fresh one with the same settings;
# c, dview and n_processes are rebound to the new cluster.
cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='multiprocessing', n_processes=None, single_thread=False)

## If data is small enough use a single patch approach

In [None]:
# set parameters for the single-patch run (the whole field of view is one patch)
K = 20  # number of neurons expected per patch (gen_data simulates 20 neurons)
gSig = [4, 4, 2]  # expected half size of neurons, one value per axis (x, y, z)
merge_thresh = 0.8  # merging threshold, max correlation allowed
p = 2  # order of the autoregressive system (the data was generated with p=2)

### Run CNMF

In [None]:
# INIT
# Build the CNMF object for a single-patch run with the parameters set above.
cnm = cnmf.CNMF(n_processes, k=K, gSig=gSig, merge_thresh=merge_thresh, p=p, dview=dview)
# Override 'se' in the 'spatial' parameter group with a 3x3x1 array of ones.
# NOTE(review): presumably a structuring element with no extent along z - confirm.
cnm.params.set('spatial', {'se': np.ones((3,3,1), dtype=np.uint8)})

In [None]:
%%capture
# FIT
# Run CNMF on the memory-mapped frames; %%capture suppresses the cell output.
cnm.fit(images)

### View the results
View components per plane

In [None]:
# Show the components over the mean image, sliced along axis 2;
# the trailing ';' keeps the return value from being echoed.
cnm.estimates.nb_view_components_3d(image_type='mean', dims=dims, axis=2);

## For larger data use a patch approach

In [None]:
# set parameters for the patch-based run
rf = 15  # half-size of the patches in pixels (e.g. rf=25 would give patches 50 pixels wide)
stride = 10  # amount of overlap between the patches in pixels
K = 10  # number of neurons expected per patch
gSig = [4, 4, 2]  # expected half size of neurons, one value per axis (x, y, z)
merge_thresh = 0.8  # merging threshold, max correlation allowed
p = 2  # order of the autoregressive system
print('set')

### Run CNMF

In [None]:
%%time
cnm = cnmf.CNMF(n_processes, 
                k=K, 
                gSig=gSig, 
                merge_thresh=merge_thresh, 
                p=p, 
                dview=dview,
                rf=rf, 
                stride=stride, 
                only_init_patch=True)
cnm.params.set('spatial', {'se': np.ones((3,3,1), dtype=np.uint8)})
cnm.fit(images)
print(('Number of components:' + str(cnm.estimates.A.shape[-1])))

### Component Evaluation

In [None]:
#%% COMPONENT EVALUATION
# Components are judged on two criteria:
#   a) the spatial footprint has to correlate with the data
#   b) the trace has to reach a minimum peak SNR over the length of a transient

# NOTE(review): gen_data simulates at framerate=30 by default - confirm fr=10 is intended.
fr = 10          # approximate final frame rate (after any downsampling)
decay_time = 1.  # typical transient length in seconds
use_cnn = False  # the CNN classifier is designed for 2d (real) data
min_SNR = 3      # peak-SNR a component needs to be accepted
rval_thr = 0.6   # space-correlation a component needs to be accepted

quality_params = dict(fr=fr,
                      decay_time=decay_time,
                      min_SNR=min_SNR,
                      rval_thr=rval_thr,
                      use_cnn=use_cnn)
cnm.params.change_params(params_dict=quality_params)

cnm.estimates.evaluate_components(images, cnm.params, dview=dview)

In [None]:
# Report how many components passed and how many failed the evaluation.
print(f'Keeping {len(cnm.estimates.idx_components)}'
      f' and discarding  {len(cnm.estimates.idx_components_bad)}')

### Re-run seeded CNMF
Now we re-run CNMF on the whole FOV seeded with the accepted components.

In [None]:
%%time
# Set the AR order p in the 'temporal' parameter group, then refit on the full
# data; the refit result is kept separately as cnm2.
cnm.params.set('temporal', {'p': p})
cnm2 = cnm.refit(images)

### View the results
Unlike the above layered view, here we view the components as max-projections (frontal in the XY direction, sagittal in YZ direction and transverse in XZ), and we also show the denoised trace.

In [None]:
# Show the refitted components as max-projections over the correlation image,
# with the denoised trace drawn in red; the trailing ';' keeps the return value
# from being echoed.
cnm2.estimates.nb_view_components_3d(image_type='corr', 
                                     dims=dims, 
                                     Yr=Yr, 
                                     denoised_color='red', 
                                     max_projection=True,
                                     axis=2);

In [None]:
# STOP CLUSTER
# Shut down the cluster started earlier now that processing is finished.
cm.stop_server(dview=dview)