In [1]:


import cv2
import glob
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import time
import h5py

try:
    # 0 disables OpenCV's internal threading (functions run sequentially),
    # presumably to avoid oversubscribing CPUs alongside caiman's process pool.
    cv2.setNumThreads(0)
except Exception as e:
    # Best-effort only; some OpenCV builds may not support this call.
    print("cv2.setNumThreads(0) failed (continuing): %s" % e)

try:
    if __IPYTHON__:
        # This is used for debugging purposes only: reload changed modules /
        # classes automatically. `run_line_magic` is the supported replacement
        # for the deprecated `InteractiveShell.magic`.
        get_ipython().run_line_magic('load_ext', 'autoreload')
        get_ipython().run_line_magic('autoreload', '2')
except NameError:
    # Not running under IPython (`__IPYTHON__` / `get_ipython` undefined).
    pass

import caiman as cm
from caiman.motion_correction import MotionCorrect
from caiman.source_extraction.cnmf import cnmf as cnmf
from caiman.source_extraction.cnmf import params as params
from caiman.utils.utils import download_demo
from caiman.utils.visualization import plot_contours, nb_view_patches, nb_plot_contour


import pylab as pl
from functools import partial
import tifffile as tf
import multiprocessing as mp
import json
import time
import re
import optparse
import sys


def atoi(text):
    """Return ``int(text)`` when *text* consists only of decimal digits, else *text*.

    Per-chunk converter used by :func:`natural_keys`.
    """
    # isdecimal() rather than isdigit(): characters such as '²' satisfy
    # isdigit() but make int() raise; isdecimal() matches exactly what the
    # regex digit class in natural_keys captures.
    return int(text) if text.isdecimal() else text

def natural_keys(text):
    """Sort key that orders embedded numbers numerically ('f2' before 'f10').

    Splits *text* into alternating non-digit / digit chunks and converts the
    digit chunks to ints, e.g. ``'file10_2.tif' -> ['file', 10, '_', 2, '.tif']``.
    Use as ``sorted(names, key=natural_keys)``.
    """
    # Raw string: a plain '(\d+)' literal contains an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning on Python 3.12+).
    return [atoi(chunk) for chunk in re.split(r'(\d+)', text)]



from caiman.source_extraction.cnmf.initialization import downscale as cmdownscale



def extract_options(options):
    """Parse command-line style arguments for the downsample / seeded-CNMF run.

    Parameters
    ----------
    options : list or None
        Argument vector, e.g. ``['-i', 'JC084', '-n', '8']``. Every item is
        converted with ``str()`` first, so non-string values such as ``8`` are
        accepted. ``None`` makes optparse read ``sys.argv[1:]``.

    Returns
    -------
    optparse.Values
        Parsed options; ``n_processes`` and ``ds_factor`` are ints.
    """
    parser = optparse.OptionParser()

    # PATH opts:
    parser.add_option('-D', '--root', action='store', dest='rootdir', default='/n/coxfs01/2p-data', help='data root dir (root project dir containing all animalids) [default: /n/coxfs01/2p-data]')
    parser.add_option('-i', '--animalid', action='store', dest='animalid', default='', help='Animal ID')
    parser.add_option('-S', '--session', action='store', dest='session', default='', help='session dir (format: YYYYMMDD_ANIMALID)')
    parser.add_option('-A', '--acq', action='store', dest='fov', default='FOV1_zoom2p0x', help="acquisition folder (ex: 'FOV1_zoom2p0x') [default: FOV1_zoom2p0x]")
    parser.add_option('-E', '--exp', action='store', dest='experiment', default='', help="Name of experiment (stimulus type), e.g., rfs")
    parser.add_option('-t', '--traceid', action='store', dest='traceid', default='traces001', help="Traceid from which to get seeded rois (default: traces001)")

    # Processing opts. type="int" so the parsed values are ints, not strings.
    parser.add_option('-n', '--nproc', action="store", type="int",
                      dest="n_processes", default=2, help="N processes [default: 2]")
    parser.add_option('-d', '--downsample', action="store", type="int",
                      dest="ds_factor", default=5, help="Downsample factor (int, default: 5)")

    parser.add_option('--destdir', action="store",
                      dest="destdir", default='/n/scratchlfs/cox_lab/julianarhee/downsampled', help="output dir for movie files [default: /n/scratchlfs/cox_lab/julianarhee/downsampled]")
    parser.add_option('--plot', action='store_true', dest='plot_rois', default=False, help="set to plot results of each roi's analysis")
    parser.add_option('--processed', action='store_false', dest='use_raw', default=True, help="set to downsample on non-raw source")

    parser.add_option('--new', action='store_true', dest='create_new', default=False, help="Set to downsample and motion correct anew")
    parser.add_option('--prefix', action='store', dest='prefix', default='Yr', help="Prefix for sourced memmap/mc files (default: Yr)")

    # optparse expects string arguments; coerce values like an int 8 to '8'.
    argv = None if options is None else [str(arg) for arg in options]
    (parsed, _args) = parser.parse_args(argv)

    return parsed

def caiman_params(fnames, **overrides):
    """Build the CaImAn ``CNMFParams`` used for motion correction and CNMF.

    Parameters
    ----------
    fnames : list of str
        Movie / memmap file paths handed to CaImAn as ``'fnames'``.
    **overrides
        Optional CNMFParams keys (e.g. ``fr=30.0``, ``gSig=[4, 4]``) that
        replace the defaults below before the params object is built.

    Returns
    -------
    caiman.source_extraction.cnmf.params.CNMFParams
    """
    # dataset dependent parameters
    fr = 44.65                  # imaging rate in frames per second
    decay_time = 0.4            # length of a typical transient in seconds

    # motion correction parameters
    strides = (48, 48)          # start a new patch for pw-rigid motion correction every x pixels
    overlaps = (24, 24)         # overlap between patches (size of patch strides+overlaps)
    max_shifts = (6, 6)         # maximum allowed rigid shifts (in pixels)
    max_deviation_rigid = 3     # maximum shifts deviation allowed for patch with respect to rigid shifts
    pw_rigid = False            # True -> piecewise (non-rigid) motion correction; False -> rigid only

    # parameters for source extraction and deconvolution
    # NOTE(review): AR order 1 is the value passed to CNMF; confirm 1 (not 2) is intended.
    p = 1                       # order of the autoregressive system
    gnb = 2                     # number of global background components
    merge_thr = 0.85            # merging threshold, max correlation allowed
    rf = 25                     # half-size of the patches in pixels. e.g., if rf=25, patches are 50x50
    stride_cnmf = 12            # amount of overlap between the patches in pixels
    K = 8                       # number of components per patch
    # NOTE(review): gSig is not placed in opts_dict, so CaImAn's own default is
    # used; pass gSig=... through **overrides to apply a value.
    gSig = [2, 2]               # expected half size of neurons in pixels
    method_init = 'greedy_roi'  # initialization method (if analyzing dendritic data using 'sparse_nmf')
    ssub = 1                    # spatial subsampling during initialization
    tsub = 1                    # temporal subsampling during initialization

    # parameters for component evaluation
    min_SNR = 2.0               # signal to noise ratio for accepting a component
    rval_thr = 0.85             # space correlation threshold for accepting a component
    cnn_thr = 0.99              # threshold for CNN based classifier
    cnn_lowest = 0.1            # neurons with cnn probability lower than this value are rejected

    opts_dict = {'fnames': fnames,
                 'fr': fr,
                 'decay_time': decay_time,
                 'strides': strides,
                 'overlaps': overlaps,
                 'max_shifts': max_shifts,
                 'max_deviation_rigid': max_deviation_rigid,
                 'pw_rigid': pw_rigid,
                 'p': p,
                 'nb': gnb,
                 'rf': rf,
                 'K': K,
                 'stride': stride_cnmf,
                 'method_init': method_init,
                 'rolling_sum': True,
                 'only_init': True,
                 'ssub': ssub,
                 'tsub': tsub,
                 'merge_thr': merge_thr,
                 'min_SNR': min_SNR,
                 'rval_thr': rval_thr,
                 'use_cnn': True,
                 'min_cnn_thr': cnn_thr,
                 'cnn_lowest': cnn_lowest}
    # Caller-supplied values take precedence over the defaults above.
    opts_dict.update(overrides)

    return params.CNMFParams(params_dict=opts_dict)

In [2]:

def save_mc_results(results_dir, prefix='Yr', mc=None):
    """Save a rigid motion-correction result to ``<results_dir>/<prefix>_mc-rigid.npz``.

    The file name matches the one ``load_mc_results`` reads.

    Parameters
    ----------
    results_dir : str
        Existing output directory.
    prefix : str
        File-name prefix.
    mc : object, optional
        Motion-correction object (e.g. ``MotionCorrect``) exposing the
        attributes saved below. When omitted, a module-level ``mc`` variable
        is used, for compatibility with callers that relied on a global.

    Returns
    -------
    str
        Path of the written ``.npz`` file.

    Raises
    ------
    NameError
        If ``mc`` is omitted and no module-level ``mc`` exists.
    """
    if mc is None:
        if 'mc' not in globals():
            raise NameError("save_mc_results: no `mc` passed and no global `mc` defined")
        mc = globals()['mc']

    out_fpath = os.path.join(results_dir, '%s_mc-rigid.npz' % prefix)
    # `mc` itself is stored as a pickled 0-d object array; loading it back
    # requires np.load(..., allow_pickle=True).
    np.savez(out_fpath,
             mc=mc,
             fname=mc.fname, max_shifts=mc.max_shifts, min_mov=mc.min_mov,
             border_nan=mc.border_nan,
             fname_tot_rig=mc.fname_tot_rig,
             total_template_rig=mc.total_template_rig,
             templates_rig=mc.templates_rig,
             shifts_rig=mc.shifts_rig,
             mmap_file=mc.mmap_file,
             border_to_0=mc.border_to_0)
    print("--- saved MC results: %s" % out_fpath)
    return out_fpath

    
def load_mc_results(results_dir, prefix='Yr'):
    """Load the motion-correction object saved in ``<results_dir>/<prefix>_mc-rigid.npz``.

    Parameters
    ----------
    results_dir : str
        Directory containing the saved ``.npz`` file.
    prefix : str
        File-name prefix used when saving.

    Returns
    -------
    object or None
        The object stored under the ``'mc'`` key, or ``None`` if the file is
        missing or cannot be read.
    """
    mc_fpath = os.path.join(results_dir, '%s_mc-rigid.npz' % prefix)
    try:
        # 'mc' is a pickled 0-d object array, so allow_pickle is required.
        # Only load files produced by this pipeline: unpickling untrusted
        # data can execute arbitrary code.
        with np.load(mc_fpath, allow_pickle=True) as mc_results:
            mc = mc_results['mc'].item()
    except Exception as e:
        print("--- unable to load MC results from %s: %s" % (mc_fpath, e))
        return None

    return mc

def get_file_paths(results_dir, prefix='Yr'):
    """Return the source file names recorded for the memmap with *prefix*.

    Looks first at ``<results_dir>/<prefix>_memmap-params.json`` (key
    ``'fnames'``). If that fails, falls back to the first (sorted)
    ``<results_dir>/memmap/*<prefix>*.npz`` file and returns its
    ``'mmap_fnames'`` entry sorted.

    Returns
    -------
    list of str or None
        File names, or ``None`` if neither source could be read.
    """
    mparams_fpath = os.path.join(results_dir, '%s_memmap-params.json' % prefix)
    try:
        print("Loading memmap params...")
        with open(mparams_fpath, 'r') as f:
            return json.load(f)['fnames']
    except Exception:
        # Deliberate fallback to the memmap info .npz below.
        pass

    # sorted(): glob order is arbitrary; choose deterministically.
    npz_matches = sorted(glob.glob(os.path.join(results_dir, 'memmap', '*%s*.npz' % prefix)))
    try:
        # Context manager closes the NpzFile's underlying file handle.
        with np.load(npz_matches[0]) as minfo:
            fnames = sorted(minfo['mmap_fnames'].tolist())
    except Exception as e:
        print("unable to load file names.")
        return None

    return fnames

def get_full_memmap_path(results_dir, prefix='Yr'):
    """Locate the memmap file for *prefix* and derive its full base prefix.

    Searches ``<results_dir>/memmap/*<prefix>*_d*_.mmap`` and takes the first
    match in sorted order.

    Returns
    -------
    tuple of (str, str)
        ``(fname_new, corrected_prefix)``, where ``corrected_prefix`` is the
        file's basename (without extension) up to the first ``'_d1_'``.

    Raises
    ------
    IndexError
        If no file matches (same exception type as the earlier ``glob(...)[0]``).
    """
    print("Getting full mmap path for prefix: %s" % prefix)
    pattern = os.path.join(results_dir, 'memmap', '*%s*_d*_.mmap' % prefix)
    # sorted(): glob order is arbitrary; choose deterministically when
    # several files match.
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise IndexError("No memmap file matches %s" % pattern)

    fname_new = matches[0]
    basename = os.path.splitext(os.path.basename(fname_new))[0]
    prefix = basename.split('_d1_')[0]
    print("CORRECTED PREFIX: %s" % prefix)
    return fname_new, prefix


In [3]:

def get_roiid_from_traceid(animalid, session, fov, run_type=None, traceid='traces001', rootdir='/n/coxfs01/2p-data'):
    """Return the ROI-set id (``PARAMS['roi_id']``) recorded for *traceid*.

    Reads the first (sorted) ``traceids*.json`` found under
    ``<rootdir>/<animalid>/<session>/<fov>/<run dir>/traces/``. The run-dir
    glob is ``*run*`` when *run_type* is None, or when *run_type* is
    ``'gratings'`` and the session date is before 20190511 (presumably older
    gratings runs used generic run-folder names; TODO confirm). Otherwise,
    the glob is ``*<run_type>*``.

    Parameters
    ----------
    session : str
        Session folder name. Its date is taken from the text before the
        first ``'_'``, so both ``'20190525'`` and ``'20190505_JC083'`` work.

    Raises
    ------
    IndexError
        If no traceids file matches.
    KeyError
        If *traceid* (or its ``PARAMS``/``roi_id``) is missing from the file.
    """
    use_generic_run_dir = run_type is None or (
        run_type == 'gratings' and int(str(session).split('_')[0]) < 20190511)
    run_pattern = '*run*' if use_generic_run_dir else '*%s*' % run_type

    # sorted(): glob order is arbitrary; choose deterministically.
    a_traceid_dict = sorted(glob.glob(os.path.join(rootdir, animalid, session, fov,
                                                   run_pattern, 'traces', 'traceids*.json')))[0]
    with open(a_traceid_dict, 'r') as f:
        tracedict = json.load(f)

    return tracedict[traceid]['PARAMS']['roi_id']


def load_roi_masks(animalid, session, fov, rois=None, rootdir='/n/coxfs01/2p-data'):
    """Load ROI masks and the z-projection image for ROI set *rois*.

    Reads the first (sorted) match of
    ``<rootdir>/<animalid>/<session>/ROIs/<rois>*/masks.hdf5``. From its
    first top-level group, it takes ``masks/Slice01`` and
    ``zproj_img/Slice01``.

    Returns
    -------
    tuple
        ``(masks, zimg)``. Either element is ``None`` if reading it from the
        HDF5 file failed.

    Raises
    ------
    IndexError
        If no ``masks.hdf5`` file matches.
    """
    # NOTE(review): `fov` is not part of the mask path; confirm ROI sets are
    # stored per session rather than per FOV.
    masks = None
    zimg = None
    mask_fpath = sorted(glob.glob(os.path.join(rootdir, animalid, session, 'ROIs', '%s*' % rois, 'masks.hdf5')))[0]
    try:
        # `with` closes the file even on error. The previous
        # `finally: mfile.close()` failed with UnboundLocalError when opening
        # the file itself raised.
        with h5py.File(mask_fpath, 'r') as mfile:
            fkey = list(mfile.keys())[0]
            masks = mfile[fkey]['masks']['Slice01'][:]
            zimg = mfile[fkey]['zproj_img']['Slice01'][:]
    except Exception as e:
        print("error loading masks: %s" % e)

    return masks, zimg

def reshape_and_binarize_masks(masks):
    """Flatten ``(nrois, d1, d2)`` masks into a boolean ``(d1*d2, nrois)`` matrix.

    Each column is one ROI's footprint, flattened in C (row-major) pixel
    order. Any non-zero pixel (including negative or NaN values) counts as
    inside the ROI. The input array is not modified.
    """
    nrois, d1, d2 = masks.shape
    # np.reshape may return a view of `masks`. Comparing into a new array,
    # rather than assigning in place, leaves the caller's data untouched.
    flat = np.reshape(masks, (nrois, d1 * d2))
    return (flat != 0).T


In [4]:
# Arguments for extract_options(); all values are strings, as optparse expects.
options = ['-i', 'JC084', '-S', '20190525', '-A', 'FOV1_zoom2p0x', '-E', 'gratings',
           '--prefix=JC084-20190525-FOV1_zoom2p0x-gratings-downsample-5', '-n', '8']

In [5]:
# Parse the run configuration and expose each setting as a notebook-level
# name used by the cells below.
opts = extract_options(options)

# Dataset identity and paths
rootdir, animalid, session = opts.rootdir, opts.animalid, opts.session
fov, experiment = opts.fov, opts.experiment

# Numeric processing settings (coerced to int)
ds_factor = int(opts.ds_factor)

# Output / source selection
destdir, use_raw = opts.destdir, opts.use_raw
n_processes = int(opts.n_processes)
create_new, prefix, traceid = opts.create_new, opts.prefix, opts.traceid



In [None]:

# Load the manually drawn ROIs for this trace set and format them as CNMF seeds.
# rootdir comes from the parsed options, so the --root CLI flag takes effect.
print("Getting seeds...")
roiid = get_roiid_from_traceid(animalid, session, fov, run_type=experiment,
                               traceid=traceid, rootdir=rootdir)
masks, zimg = load_roi_masks(animalid, session, fov, rois=roiid, rootdir=rootdir)
if masks is None:
    # load_roi_masks returns None on read failure; fail here with context
    # instead of an opaque AttributeError inside the reshape.
    raise RuntimeError("No ROI masks loaded for roi id %s" % roiid)
Ain = reshape_and_binarize_masks(masks)
Ain.shape

Getting seeds...
