# whole-brain timeseries extraction

### import rois

In [None]:
import os
import numpy as np
import itertools
import re

def rois(path):
    """
    Import ROI centers from an unzipped Fiji RoiSet folder and build a square
    pixel mask around each center.

    Each ``*.roi`` file name is parsed as dash-separated integers
    ``A-B-C.roi`` (leading zeros allowed). ``A`` is a 1-based frame/z-plane
    number; ``B`` and ``C`` are the center coordinates.

    arguments
    ------
    path (str): folder holding the ``.roi`` files

    returns
    ------
    roimask_sorted (dict): keys are ``(A - 1, filename)`` tuples (0-based
        plane); values are the 36 tuples of a 6x6 window, first element
        running over ``C-3 .. C+2`` and second element over ``B-3 .. B+2``.
        Timeseries.roipixels unpacks each tuple as ``(x, y)``. Entries are
        inserted in sorted key order.
    """
    half_width = 3
    roimask = {}
    for fname in os.listdir(path):
        if not fname.endswith('.roi'):
            continue
        # int() copes with zero padding, including an all-zero field such as
        # '0000', and replaces np.int (removed from NumPy in 1.24).
        fields = fname.split('.roi')[0].split('-')
        plane, second, third = (int(field) for field in fields[:3])
        window = itertools.product(
            range(third - half_width, third + half_width),
            range(second - half_width, second + half_width))
        # correct for offset by 1 from numZ: plane - 1
        roimask[plane - 1, fname] = list(window)
    return {key: roimask[key] for key in sorted(roimask)}

### sort tiff filenames

In [None]:
"""
sort filenames
"""
def sortfn(path):
    """
    List the ``*ome.tif`` files in ``path`` ordered by their sequence number.

    The sequence number is the text between ``Default`` and ``.ome.tif`` in
    the file name: an empty suffix counts as 0 and ``_N`` counts as N. Files
    are ordered numerically, so ``_100`` comes after ``_11``. Suffixes that
    are not integers sort after all numbered files, by suffix text.

    arguments
    ------
    path (str): folder to scan

    returns
    ------
    fn (list): full paths (``os.path.join(path, name)``) in sequence order

    raises
    ------
    IndexError: if a matching file name does not contain ``Default``
    """
    def sequence_key(fname):
        # parse the bare file name so a 'Default' in the directory path
        # cannot be mistaken for the marker
        suffix = fname.split('.ome.tif')[0].split('Default')[1]
        digits = suffix.lstrip('_')
        if digits == '':
            return (0, 0, '')
        try:
            return (0, int(digits), '')
        except ValueError:
            return (1, 0, suffix)

    names = [f for f in os.listdir(path) if f.endswith('ome.tif')]
    fn = [os.path.join(path, f) for f in sorted(names, key=sequence_key)]
    return fn

### organize sequential tiffs in callable chunks

In [None]:
import copy
import pdb
import json
import base64

import matplotlib
# matplotlib.use("tkAgg")
import matplotlib.pyplot as plt

def chunk_ix(shape=None, dims=None, req=None):
    """indices for a chunk request (expanding a chunk request)

    'shape' and 'dims' describe the shape and dims names of an array A of
    arbitrary dimensions. 'req' is a (compact, dictionary) request for a chunk
    of A (np.ndarray).

    arguments
    ------
    shape (list): array shape
    dims (list or str): array dims ('XYZTC' or ['X','Y','Z','T','C'])
    req (dict): chunk request (e.g. {'X':1, 'T':[0,2], 'Y':(0,5)}).
                None is treated as an empty request (everything).

    req format options:
        X=1             interpreted as [1] (python or numpy integer)
        X=(0,3)         interpreted as range(0,3)
        X=[0,3]         interpreted as itself (a list)
        X=slice(0,3)    python slice (e.g. slice(None,None,10) for every 10th)
        X=None          (default) all values

    Dims that are missing from the request are treated as None (all values).

    returns
    ------
    ix (list):  a list of lists of index values for each dimension. Using the
                output, numpy.ix_(*ix) is a valid slice of A. Similarly,
                itertools.product(*ix) makes a list of tuples for chunk
                elements in A.

    raises
    ------
    ValueError: a requested axis is not in dims
    TypeError:  a request value is not one of the formats above
    """
    if req is None:
        req = {}

    #== check if requested dims DNE in the data
    for k in req:
        if k not in dims:
            raise ValueError('requested axis (%s) not in dims (%s)' % (k, str(dims)))

    ix = []
    for size, dim in zip(shape, dims):
        #== a partial request, missing some dims, is made explicit
        r = req.get(dim)
        if r is None:
            ix.append(list(range(size)))
        elif isinstance(r, (int, np.integer)):
            ix.append([int(r)])
        elif isinstance(r, tuple):
            ix.append(list(range(*r)))
        elif isinstance(r, list):
            ix.append(r)
        elif isinstance(r, slice):
            # materialise so every entry is a list, as documented
            ix.append(list(range(size)[r]))
        else:
            raise TypeError('chunk index: (%s) not recognized' % (str(r)))
    return ix

class DataChunk(object):
    """
    multidimensional data and metadata

    keyword arguments
    ------
    dims (list or str): one name per leading axis of data (required)
    meta (dict): free-form metadata, defaults to {}
    axes: retired; passing a non-None value raises

    Any other keyword (e.g. shape) is accepted and ignored.

    attributes
    ------
    data, dims, meta: as passed in
    dim_len (dict): dim name -> axis length, from zip(dims, data.shape)
    """

    def __init__(self, data, **kwargs):

        if kwargs.get('axes', None) is not None:
            raise Exception('axes no longer accepted, used dims instead')

        dims = kwargs.get('dims', None)
        if dims is None:
            # zip(None, ...) below would fail with a message that does not
            # mention dims at all
            raise TypeError('DataChunk requires dims (one name per data axis)')

        meta = kwargs.get('meta', None)

        self.data = data
        self.dims = dims
        self.meta = {} if meta is None else meta

        #== derived
        self.dim_len = dict(zip(self.dims, self.data.shape))

    @property
    def shape(self):
        """shape of the underlying data array"""
        return self.data.shape

    @property
    def dtype(self):
        """dtype of the underlying data array"""
        return self.data.dtype

    def subchunk(self, req=None, inplace=False, squeeze=False, verbose=False):
        """return a subchunk of a data (also a DataChunk)

        chunk indexing format options:
            X=1         interpreted as index 1
            X=(0,3)     interpreted as a range
            X=[0,3]     interpreted as a list
            X=None      all values

        arguments
        ------
        req (dict): chunk request, expanded by chunk_ix; None means all
        inplace (bool): not implemented, True raises
        squeeze (bool): drop every named dim whose length is 1 in the result
        verbose (bool): print the request and the result

        The returned chunk carries a shallow copy of this chunk's meta.
        """

        if inplace:
            raise NotImplementedError('inplace=True not yet implemented')

        #== building the chunk request
        req_ix = chunk_ix(shape=self.shape, dims=self.dims,
                          req={} if req is None else req)
        ix = np.ix_(*req_ix)

        chunk = self.data[ix]
        dims = self.dims
        md = self.meta.copy()
        #md['generation'] = 2

        if squeeze:
            named = min(len(dims), chunk.ndim)
            drop = tuple(i for i in range(named) if chunk.shape[i] == 1)
            if drop:
                chunk = np.squeeze(chunk, axis=drop)
                kept = [dims[i] for i in range(len(dims)) if i not in drop]
                # keep dims the same kind of container the caller gave us
                dims = ''.join(kept) if isinstance(dims, str) else kept

        if verbose:
            print('-----------')
            print('dims   :', self.dims)
            print('shape   :', self.shape)
            print('request:', req)
            print('ix     :', ix)
            print('chunk  :', chunk)
            print('-----------')

        return DataChunk(data=chunk, dims=dims, meta=md)

In [None]:
import itertools
import pdb
import os

import tifffile as tf
import numpy as np
import pandas as pd

class TiffMetadataParser(object):
    """Retired placeholder: constructing one always raises."""

    def __init__(self, f=None):
        reason = 'TiffMetadataParser DNE, parsing is handled by the TiffReader'
        raise Exception(reason)



class TiffReader(object):
    """for pulling frame(s) and chunks out of a tiff hyperstack

    One tifffile.TiffFile handle is opened per input file in __init__ and
    stays open until close() is called (the reader also works as a context
    manager).

    NOTE: a later cell in this file defines TiffReader again; whichever
    definition is executed last is the one in effect.

    TODO: lil-tiff (first N frames to a new file, allows fiji inspection
          to determine metadata parameters (dim order, num-Z, starting frame
          etc...))
    """

    def __init__(self, f=None):
        """open the file(s) and build the frame index

        f (str or list): one tiff path, or a list of paths whose pages are
            indexed back to back in list order
        """
        self.files = f if isinstance(f, list) else [f]
        self.TF = [tf.TiffFile(name) for name in self.files]
        self.build_index()
        #self.mmmd = [tiff.micromanager_metadata for tiff in self.TF]

    def close(self):
        """close every open tiff handle"""
        for handle in self.TF:
            handle.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    @staticmethod
    def _tzc_key(T, Z, C):
        """string key '(T, Z, C)' built from plain python ints

        Casting first keeps the key text identical whether the indices are
        python ints or numpy integers (whose repr differs across NumPy
        versions).
        """
        return str((int(T), int(Z), int(C)))

    def build_index(self):
        """global frame index for >=1 tiff file(s) and metadata

        The dataframe df holds all of the tiff page indexing info. Columns
        T, Z, and C are self-explanatory. F is the (input) file index and 'ndx'
        is the page index WITHIN the corresponding file. Given, TZC, one can
        determine F and ndx.

        tzc2fndx is a dictionary. Keys are (T,Z,C) tuples (converted to strings,
        e.g. '(0, 3, 0)') and values are (F, ndx) tuples of python ints.

        Everything is stored in self.meta.

        raises
        ------
        Exception: the last two axes are not XY/YX, or the total page count
            does not match the number of frames implied by the series shape
        """

        # extract some info from zeroth series of zeroth file
        series0 = self.TF[0].series[0]
        axes = series0.axes
        shape = series0.shape
        dtype = series0.dtype

        if axes[-2:] not in ['XY', 'YX']:
            raise Exception('expecting the last two axes to be XY or YX ')

        # split axes into tzc and xy
        xy_axes = axes[-2:]
        xy_shape = shape[-2:]
        tzc_axes = axes[:-2]
        tzc_shape = shape[:-2]
        tzc_ndx = list(itertools.product(*[range(n) for n in tzc_shape]))

        # a dataframe to hold all the indexing information
        df = pd.DataFrame(data=tzc_ndx, columns=list(tzc_axes))
        for col in ['T', 'Z', 'C']:
            if col not in df.columns:
                df[col] = np.zeros(len(df), dtype=int)

        filesinfo = []
        fndx = []
        ndx = []
        for i, ff in enumerate(self.files):
            num_pages = len(self.TF[i].pages)
            fndx += [i]*num_pages
            ndx += list(range(num_pages))
            filesinfo.append(dict(f=ff, num_pages=num_pages))

        if len(ndx) != len(df):
            raise Exception(
                'found %d tiff pages in %d file(s) but series shape %s (axes %s) '
                'implies %d frames' % (len(ndx), len(self.files), str(shape),
                                       axes, len(df)))

        df['F'] = fndx
        df['ndx'] = ndx
        df = df[['T', 'Z', 'C', 'F', 'ndx']]

        # given TZC, return [file,frame] indices
        tzc2fndx = {self._tzc_key(*row[:3]): (int(row[3]), int(row[4]))
                    for row in df.values}

        self.meta = dict(
            _about="tiff metadata",
            files=self.files,
            filesinfo=filesinfo,
            shape=shape,
            axes=axes,
            dtype=dtype,
            axes_xy=xy_axes,
            shape_xy=xy_shape,
            axes_tzc=tzc_axes,
            shape_tzc=tzc_shape,
            df_tzc=df,
            tzc2fndx=tzc2fndx
            )

    def about(self):
        """helpful information at a glance"""
        print('#------------------------')
        print('# TiffReader metadata')
        print('#------------------------')
        print('num_files :', len(self.meta['files']))
        for i, ff in enumerate(self.meta['filesinfo']):
            print('file%3.3i   :' % i, ff['f'])
        print('pages/file:', [ff['num_pages'] for ff in self.meta['filesinfo']])
        print('dtype     :', self.meta['dtype'])
        print('axes      :', self.meta['axes'])
        print('shape     :', self.meta['shape'])
        print('df_tzc    :')
        print(pd.concat([self.meta['df_tzc'].head(), self.meta['df_tzc'].tail()]))
        print()

    def getframe(self, Z=0, T=0, C=0):
        """returns a (YX) frame given (hyper)stack indices Z,T,C

        The result is a dict with keys data (the page as an array), axes,
        shape and meta.

        raises KeyError if the (T, Z, C) combination is not in the index.

        TODO: turbo request of a frame subregion?
        TODO: what other metadata to carry forward?
        TODO: return a datachunk?
        """
        f, index = self.meta['tzc2fndx'][self._tzc_key(T, Z, C)]
        data = self.TF[f].pages[index].asarray()

        meta = dict(Z=Z, T=T, C=C, dtype=self.meta['dtype'])
        return dict(data=data, axes=self.meta['axes_xy'],
                    shape=self.meta['shape_xy'], meta=meta)

    def getchunk(self, req=None):
        """get multiple frames and assemble a DataChunk

        req (dict): chunk request in the chunk_ix format, keyed by axis
            letter; None or {} requests everything

        raises Exception if a requested axis is not one of the tiff axes.

        TODO: better naming for chunk requests (compact vs full?)
        """
        if req is None:
            req = {}
        dtype = self.meta['dtype']

        #== confirm that requested axes exist in the file
        for a in req.keys():
            if a not in self.meta['axes']:
                raise Exception('requested axis (%s) not in the tiff axes (%s)'
                % (a, self.meta['axes']))

        #===========================================================
        #== enumerate TZC combinations (frames) and generate ALL
        #== of the getframe requests. Each request is a dictionary
        #===========================================================
        shape_tzc = self.meta['shape_tzc']
        axes_tzc = self.meta['axes_tzc']
        req_tzc = {k: req.get(k, None) for k in axes_tzc}
        ix_tzc = chunk_ix(shape=shape_tzc, dims=axes_tzc, req=req_tzc)
        #== list of tuples
        TZC = list(itertools.product(*ix_tzc))

        #== the XY request (basically a crop) is done on each TZC frame
        shape_xy = self.meta['shape_xy']
        axes_xy = self.meta['axes_xy']
        req_xy = {k: req.get(k, None) for k in 'XY'}
        ix_xy = chunk_ix(shape=shape_xy, dims=axes_xy, req=req_xy)
        crop = not (req_xy['X'] is None and req_xy['Y'] is None)

        #== determine the shape of the output chunk
        new_shape_tzc = [len(x) for x in ix_tzc]
        new_shape_xy = [len(x) for x in ix_xy]

        #== build a list of frames, then reshape it
        data = np.zeros([len(TZC)] + new_shape_xy, dtype=dtype)
        for i, row in enumerate(TZC):
            # NOTE(review): getframe only accepts T, Z and C keywords, so a
            # leading axis with any other letter would fail here - confirm
            # which axes the input files can carry
            frm = self.getframe(**dict(zip(axes_tzc, row)))
            if not crop:
                # if no req_xy, do not make a DataChunk
                data[i] = frm['data']
            else:
                frm['dims'] = frm.pop('axes')
                data[i] = DataChunk(**frm).subchunk(req=req_xy).data

        #== reshape
        data = data.reshape(new_shape_tzc + new_shape_xy)

        #== build the DataChunk
        meta = dict(_about='DataChunk from a tiff file',
                    req=req)
        return DataChunk(data=data, dims=axes_tzc + axes_xy, meta=meta)


### read tiffs for the requested chunks

In [None]:
import pdb
import tifffile as tf
import pandas as pd

class TiffReader(object):
    """
    for pulling frame(s) and chunks out of a tiff hyperstack

    One tifffile.TiffFile handle is opened per input file in __init__ and
    stays open until close() is called (the reader also works as a context
    manager).

    NOTE: this definition replaces the TiffReader defined in an earlier cell
    of this file once it is executed.
    """

    def __init__(self, f=None):
        """open the file(s) and build the frame index

        f (str or list): one tiff path, or a list of paths whose pages are
            indexed back to back in list order
        """
        self.files = f if isinstance(f, list) else [f]
        self.TF = [tf.TiffFile(name) for name in self.files]
        self.build_index()
        #self.mmmd = [tiff.micromanager_metadata for tiff in self.TF]

    def close(self):
        """close every open tiff handle"""
        for handle in self.TF:
            handle.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    @staticmethod
    def _tzc_key(T, Z, C):
        """string key '(T, Z, C)' built from plain python ints

        Casting first keeps the key text identical whether the indices are
        python ints or numpy integers (whose repr differs across NumPy
        versions).
        """
        return str((int(T), int(Z), int(C)))

    def build_index(self):
        """global frame index for >=1 tiff file(s) and metadata

        The dataframe df holds all of the tiff page indexing info. Columns
        T, Z, and C are self-explanatory. F is the (input) file index and 'ndx'
        is the page index WITHIN the corresponding file. Given, TZC, one can
        determine F and ndx.

        tzc2fndx is a dictionary. Keys are (T,Z,C) tuples (converted to strings,
        e.g. '(0, 3, 0)') and values are (F, ndx) tuples of python ints.

        Everything is stored in self.meta.

        raises
        ------
        Exception: the last two axes are not XY/YX, or the total page count
            does not match the number of frames implied by the series shape
        """

        # extract some info from zeroth series of zeroth file
        series0 = self.TF[0].series[0]
        axes = series0.axes
        shape = series0.shape
        dtype = series0.dtype

        if axes[-2:] not in ['XY', 'YX']:
            raise Exception('expecting the last two axes to be XY or YX ')

        # split axes into tzc and xy
        xy_axes = axes[-2:]
        xy_shape = shape[-2:]
        tzc_axes = axes[:-2]
        tzc_shape = shape[:-2]
        tzc_ndx = list(itertools.product(*[range(n) for n in tzc_shape]))

        # a dataframe to hold all the indexing information
        df = pd.DataFrame(data=tzc_ndx, columns=list(tzc_axes))
        for col in ['T', 'Z', 'C']:
            if col not in df.columns:
                df[col] = np.zeros(len(df), dtype=int)

        filesinfo = []
        fndx = []
        ndx = []
        for i, ff in enumerate(self.files):
            num_pages = len(self.TF[i].pages)
            fndx += [i]*num_pages
            ndx += list(range(num_pages))
            filesinfo.append(dict(f=ff, num_pages=num_pages))

        if len(ndx) != len(df):
            raise Exception(
                'found %d tiff pages in %d file(s) but series shape %s (axes %s) '
                'implies %d frames' % (len(ndx), len(self.files), str(shape),
                                       axes, len(df)))

        df['F'] = fndx
        df['ndx'] = ndx
        df = df[['T', 'Z', 'C', 'F', 'ndx']]

        # given TZC, return [file,frame] indices
        tzc2fndx = {self._tzc_key(*row[:3]): (int(row[3]), int(row[4]))
                    for row in df.values}

        self.meta = dict(
            _about="tiff metadata",
            files=self.files,
            filesinfo=filesinfo,
            shape=shape,
            axes=axes,
            dtype=dtype,
            axes_xy=xy_axes,
            shape_xy=xy_shape,
            axes_tzc=tzc_axes,
            shape_tzc=tzc_shape,
            df_tzc=df,
            tzc2fndx=tzc2fndx
            )

    def about(self):
        """helpful information at a glance"""
        print('#------------------------')
        print('# TiffReader metadata')
        print('#------------------------')
        print('num_files :', len(self.meta['files']))
        for i, ff in enumerate(self.meta['filesinfo']):
            print('file%3.3i   :' % i, ff['f'])
        print('pages/file:', [ff['num_pages'] for ff in self.meta['filesinfo']])
        print('dtype     :', self.meta['dtype'])
        print('axes      :', self.meta['axes'])
        print('shape     :', self.meta['shape'])
        print('df_tzc    :')
        print(pd.concat([self.meta['df_tzc'].head(), self.meta['df_tzc'].tail()]))
        print()

    def getframe(self, Z=0, T=0, C=0):
        """returns a (YX) frame given (hyper)stack indices Z,T,C

        The result is a dict with keys data (the page as an array), axes,
        shape and meta.

        raises KeyError if the (T, Z, C) combination is not in the index.

        TODO: turbo request of a frame subregion?
        TODO: what other metadata to carry forward?
        TODO: return a datachunk?
        """
        f, index = self.meta['tzc2fndx'][self._tzc_key(T, Z, C)]
        data = self.TF[f].pages[index].asarray()

        meta = dict(Z=Z, T=T, C=C, dtype=self.meta['dtype'])
        return dict(data=data, axes=self.meta['axes_xy'],
                    shape=self.meta['shape_xy'], meta=meta)

    def getchunk(self, req=None):
        """get multiple frames and assemble a DataChunk

        req (dict): chunk request in the chunk_ix format, keyed by axis
            letter; None or {} requests everything

        raises Exception if a requested axis is not one of the tiff axes.

        TODO: better naming for chunk requests (compact vs full?)
        """
        if req is None:
            req = {}
        dtype = self.meta['dtype']

        #== confirm that requested axes exist in the file
        for a in req.keys():
            if a not in self.meta['axes']:
                raise Exception('requested axis (%s) not in the tiff axes (%s)'
                % (a, self.meta['axes']))

        #===========================================================
        #== enumerate TZC combinations (frames) and generate ALL
        #== of the getframe requests. Each request is a dictionary
        #===========================================================
        shape_tzc = self.meta['shape_tzc']
        axes_tzc = self.meta['axes_tzc']
        req_tzc = {k: req.get(k, None) for k in axes_tzc}
        ix_tzc = chunk_ix(shape=shape_tzc, dims=axes_tzc, req=req_tzc)
        #== list of tuples
        TZC = list(itertools.product(*ix_tzc))

        #== the XY request (basically a crop) is done on each TZC frame
        shape_xy = self.meta['shape_xy']
        axes_xy = self.meta['axes_xy']
        req_xy = {k: req.get(k, None) for k in 'XY'}
        ix_xy = chunk_ix(shape=shape_xy, dims=axes_xy, req=req_xy)
        crop = not (req_xy['X'] is None and req_xy['Y'] is None)

        #== determine the shape of the output chunk
        new_shape_tzc = [len(x) for x in ix_tzc]
        new_shape_xy = [len(x) for x in ix_xy]

        #== build a list of frames, then reshape it
        data = np.zeros([len(TZC)] + new_shape_xy, dtype=dtype)
        for i, row in enumerate(TZC):
            # NOTE(review): getframe only accepts T, Z and C keywords, so a
            # leading axis with any other letter would fail here - confirm
            # which axes the input files can carry
            frm = self.getframe(**dict(zip(axes_tzc, row)))
            if not crop:
                # if no req_xy, do not make a DataChunk
                data[i] = frm['data']
            else:
                frm['dims'] = frm.pop('axes')
                data[i] = DataChunk(**frm).subchunk(req=req_xy).data

        #== reshape
        data = data.reshape(new_shape_tzc + new_shape_xy)

        #== build the DataChunk
        meta = dict(_about='DataChunk from a tiff file',
                    req=req)
        return DataChunk(data=data, dims=axes_tzc + axes_xy, meta=meta)

### extract timeseries of imported rois from images

In [None]:
class Timeseries(object):
    """
    Extract per-ROI fluorescence time series from one recording.

    Intended call order (each step stores what the next one reads):
    ZorganizedROIs -> ZorganizedTimepoints -> roipixels -> dFF -> roiplots
    """

    def __init__(self,
                 path=r"C:\Users\heeun\Documents\Garrison_lab\somatag_2019\20190925w2_nls",
                 wormID="20190925_w2_nls",
                 numZ=10,
                 exptime=0.02):
        """
        path (str): directory where image files for the recording are; it must
            also hold the 'RoiSet' and 'bckg' ROI folders
        wormID (str): label used in plot titles and output file names
        numZ (int): number of z positions in 1 volume; 1 if singleplane
        exptime (float): exposure time in seconds

        The defaults reproduce the values that were previously hard-coded.
        """
        self.path = path
        self.wormID = wormID
        self.fnames = sortfn(self.path)
        self.roimasks = rois(self.path +'/RoiSet') #rois are drawn in fiji on frames of first volume and their centers are imported in
        self.bck = rois(self.path + '/bckg') #one background roi per each frame in first volume is drawn in the darkest region in frame
        self.tiffs = TiffReader(f=self.fnames)
        self.numZ = numZ
        self.exptime = exptime

    @staticmethod
    def _masks_on_plane(masks, z):
        """masks (in dict order) whose key is a tuple starting with plane z"""
        return [mask for key, mask in masks.items()
                if isinstance(key, tuple) and key[0] == z]

    @staticmethod
    def _normalise(trace):
        """trace divided by its own mean, as a list of floats

        A zero mean yields NaN for every sample instead of dividing.
        """
        trace = np.asarray(trace, dtype=float)
        if trace.size == 0:
            return []
        baseline = np.mean(trace)
        if baseline == 0:
            return [float('nan')] * trace.size
        return (trace / baseline).tolist()

    def ZorganizedTimepoints(self):
        """
        request z organized timepoints across OME Tiff files

        Frame t of the recording is assigned to plane t % numZ. Stores the
        per-plane chunks in self.zTchunk and returns, per plane, the chunk of
        that plane's first frame (for verifying correct association of rois
        with image).
        """
        total_frames = self.tiffs.meta['shape'][0]
        zTchunk, zT0 = {}, {}
        for z in range(self.numZ):
            frames = list(range(z, total_frames, self.numZ))
            zTchunk[z] = self.tiffs.getchunk(req=dict(T=frames)) #this takes the longest amount of time to run..
            zT0[z] = self.tiffs.getchunk(req=dict(T=z))
        self.zTchunk = zTchunk
        return zT0

    def ZorganizedROIs(self):
        """
        Make a list of x,y coordinates of all rois indexed by z plane for the first volume in time

        Sets self.roixy, self.bckxy and self.roilabels, and pickles the labels
        to <path>/rlabels.pkl.
        """
        self.roixy = {z: self._masks_on_plane(self.roimasks, z)
                      for z in range(self.numZ)}
        self.bckxy = {z: self._masks_on_plane(self.bck, z)
                      for z in range(self.numZ)}
        self.roilabels = list(enumerate(self.roimasks.keys()))
        rlabels = pd.DataFrame(self.roilabels)
        rlabels.to_pickle(self.path+"/rlabels.pkl")
        return

    def roipixels(self):
        """
        Obtain fluorescence intensity values from corresponding image for each (x,y) coordinate

        Sets self.roi_pixels and self.bck_pixels: dicts keyed by (z, t) whose
        values hold one list of pixel values per roi on that plane.
        """
        roi_pixels, bck_pixels = {}, {}
        for z in range(self.numZ):
            frames = self.zTchunk[z].data
            for t in range(len(frames)):
                frame = frames[t]
                # masks hold (x, y) tuples; frames are indexed [y, x]
                roi_pixels[z,t] = [[frame[y,x] for (x,y) in mask] for mask in self.roixy[z]]
                bck_pixels[z,t] = [[frame[y,x] for (x,y) in mask] for mask in self.bckxy[z]]
        self.roi_pixels = roi_pixels
        self.bck_pixels = bck_pixels
        return

    def dFF(self):
        """
        Calculate and save dFF for each roi with and without background subtraction

        For each roi the per-frame mean intensity (minus the per-frame mean
        of the plane's background pixels in the subtracted variant) is
        divided by the mean of that trace over time.

        Writes <path>/Quant/rn_dFF.pkl and rn_dFF_nobsub.pkl and sets
        self.roin_dFF / self.roin_dFF_nobsub.

        returns
        ------
        rn_dFF (dict): (z, r) -> trace cut to one frame less than the
            shortest plane, so every column has the same length
        roin_dFF (dict): (z, r) -> full-length trace
        """
        os.makedirs(self.path+"/Quant", exist_ok=True) # safe to rerun

        roin_dFF, roin_dFF_nobsub = {}, {}
        frames_per_plane = []
        for z in range(self.numZ):
            n_rois = len(self.roixy[z])
            if n_rois == 0:
                continue
            n_t = len(self.zTchunk[z].data)
            frames_per_plane.append(n_t)
            bckg = np.array([np.mean(self.bck_pixels[z,t]) for t in range(n_t)])
            for r in range(n_rois):
                raw = np.array([np.mean(self.roi_pixels[z,t][r]) for t in range(n_t)])
                roin_dFF[(z,r)] = self._normalise(raw - bckg)
                # baseline is the mean over ALL frames, same as the
                # background-subtracted trace
                roin_dFF_nobsub[(z,r)] = self._normalise(raw)

        # Planes can differ by one frame when the recording length is not a
        # multiple of numZ; cutting every trace to a common length keeps the
        # DataFrame columns equal. The extra -1 is kept from the previous
        # behaviour (each trace lost its last frame).
        common = max(min(frames_per_plane) - 1, 0) if frames_per_plane else 0
        rn_dFF = {k: v[:common] for k, v in roin_dFF.items()}
        rn_dFF_nobsub = {k: v[:common] for k, v in roin_dFF_nobsub.items()}

        pd.DataFrame(rn_dFF).to_pickle(self.path+"/Quant/rn_dFF.pkl")
        pd.DataFrame(rn_dFF_nobsub).to_pickle(self.path+"/Quant/rn_dFF_nobsub.pkl")
        self.roin_dFF = roin_dFF
        self.roin_dFF_nobsub = roin_dFF_nobsub
        return rn_dFF, roin_dFF

    def roiplots(self):
        """
        Save dFF timetrace of each roi as pdf and pkl
        """
        os.makedirs(self.path+"/Quant/roiplots", exist_ok=True) # safe to rerun
        for z in range(self.numZ):
            for r in range(len(self.roixy[z])):
                f=plt.figure(1,figsize=(15,3),facecolor='white')
                plt.subplot(111)
                plt.plot(self.roin_dFF[z,r],linewidth=0.3)
                # relabel the frame-number ticks in seconds
                fr,labels=plt.xticks()
                plt.xticks(fr,(fr*self.exptime*self.numZ).astype(int))
                plt.xlabel("time (s)")
                plt.ylabel("dF/F")
                roin = self.wormID +"_z"+str(z)+"r"+str(r)
                plt.title(roin)
                plt.tight_layout()
                f.savefig(self.path +'/Quant/roiplots/' + "dFF_" + roin + ".pdf",dpi=1200,format='pdf')
                plt.clf()
                rn_df=pd.DataFrame(self.roin_dFF[z,r])
                rn_df.to_pickle(self.path+ "/Quant/roiplots/"+ "dFF_" + roin+".pkl")

### use the functions defined above to process the movies recorded and get dFF timeseries for each ROI

In [None]:
# Run the extraction pipeline for the recording configured in Timeseries.
x = Timeseries()
# ZorganizedROIs returns None. Its result is deliberately NOT bound to the
# name `rois`: doing so would replace the rois() helper that Timeseries()
# calls, and building a second Timeseries would then fail.
x.ZorganizedROIs()
zt  = x.ZorganizedTimepoints()  # first frame of every plane, for checking ROI placement
rpixels = x.roipixels()
dff = x.dFF()                   # (rn_dFF, roin_dFF)
rplots = x.roiplots()


# risetimes

### plot rise times manually for each trace

In [None]:
from matplotlib.widgets import Cursor
%matplotlib qt
path    = x.path #For saving the rise time files 
exptime = x.exptime
numZ    = x.numZ
wormID = x.wormID 
rn_dFF = dff[0]
# On every rise, note the risetime (90%-10% of baseline value).
# Workflow per trace: zoom/pan, press a key, then click start/end pairs;
# the third mouse button ends the selection.
for n, p in enumerate(rn_dFF.keys()):
    fig,ax = plt.subplots(1,1,num=n,squeeze=True,figsize=(8,6))
    ax.plot(rn_dFF[p],linewidth=0.3,marker='.',markersize=1)
    ax.set_title(p)
    ax.text(0.8, 0.8, numZ*exptime, transform=ax.transAxes, fontsize=8, verticalalignment='top',horizontalalignment='right')              
    ax.set_ylabel("dF/F")
    cursor = Cursor(ax, useblit=True, color='k', linewidth=1)
    zoom_ok = False
    print('\nZoom or pan to view, \npress spacebar when ready to click:\n')
    while not zoom_ok:
        zoom_ok=plt.waitforbuttonpress()
    print('Click once to select timepoint:')
    val=plt.ginput(n=-1,timeout=0,show_clicks=True,mouse_add=1,mouse_pop=2,mouse_stop=3)
    if len(val) % 2:
        # clicks are consumed as (start, end) pairs; an unpaired last click
        # would otherwise raise IndexError below
        print('WARNING: odd number of clicks, ignoring the last one')
        val = val[:-1]
    addrtime,rtimebegins,rtimeends=([] for i in range (3))
    for num in range(0,len(val),2):
        begin, end = val[num], val[num+1]
        ax.plot(begin[0],begin[1],'m.')
        ax.plot(end[0],end[1],'y.')
        ylim=ax.get_ylim()
        # x is in frames (volumes); one volume lasts numZ*exptime seconds
        rtime=np.round(((end[0]-begin[0])*(numZ*exptime)),decimals=2)
        ax.text(begin[0],ylim[0]+0.05,rtime,fontsize=8)
        rtimebegins.append(begin[0])
        rtimeends.append(end[0])
        addrtime.append(rtime)
    xlim=ax.get_xlim()
    # one tick every ~20 s; labels are derived from the tick positions so the
    # two sequences always have the same length (int replaces np.int, which
    # was removed from NumPy)
    tick_step=max(1,int(20/(exptime*numZ)))
    ticks=list(range(int(xlim[0]),int(xlim[1]),tick_step))
    ax.set_xticks(ticks)
    ax.set_xticklabels([int(round(t*exptime*numZ)) for t in ticks])#, fontsize=8
    ax.axes.tick_params(axis='y')#, labelsize=8
    ax.set_xlabel("time (s)")#, fontsize=8
#     z[p]=roin_dFF[p].split('r')[0]
#     z[p]=np.int(z[p].split('z')[1])
#     n[p]=np.int(roin_dFF[p].split('r')[1])
#     ax[1].imshow(zT0[z[p]])
#     for pix in range(len(roixy[z[p]][n[p]])):
#         ax[1].scatter(roixy[z[p]][n[p]][pix][0],roixy[z[p]][n[p]][pix][1],marker='.')            
    # save inside `path`; the name was previously glued onto the directory
    # name without a separator, which put the file next to the folder
    fig.savefig(os.path.join(path,str(p[0])+str(p[1])+".pdf"),dpi=1200,format='pdf')
    plt.close(fig)
    addrtimesv=pd.DataFrame([rtimebegins,rtimeends,addrtime])
    addrtimesv.to_pickle(os.path.join(path, x.wormID +"_z"+str(p[0])+"r"+str(p[1])+".pkl"))

### compile dFF timeseries from select ROIs from each worm in one datastructure for that worm

In [None]:
import numpy as np
import pandas as pd
import os
import glob

path       = r"C:\Users\heeun\Documents\Garrison_lab\somatag_2019\20190925w2_nls\Quant\roiplots" ## Path with individual ROI plots in .pdf and .pkl files
numZ       = x.numZ #number of z positions; 1 if singleplane
exp_time   = x.exptime # exposure time 

## Go through each file in the directory to get the ROI data
# sorted so the column order (and the file that defines wormID and the time
# axis) does not depend on directory listing order
all_files = sorted(glob.glob(path + "/*.pkl"))
ROI_list      = [] 

analysis_path = path + "/analysis" 
os.makedirs(analysis_path, exist_ok=True)

wormID = None
df_dFF = None
for file_name in all_files:
    # file name without directory and '.pkl'
    file_name_cut = os.path.splitext(os.path.basename(file_name))[0]
    ROI_list.append(file_name_cut.split("_")[-1])
    temp          = pd.read_pickle(file_name)
    if df_dFF is None:
        # the first file defines the worm label and the time axis
        wormID        = file_name_cut.split("_z")[0]
        total_frames  = len(temp)
        # index * step gives exactly total_frames samples; arange with a
        # float stop can yield one extra sample from rounding
        df_dFF        = pd.DataFrame(np.arange(total_frames)*exp_time*numZ, columns = ['time_sec']) 
    df_dFF[file_name_cut] = temp

print (ROI_list)

if df_dFF is None:
    raise FileNotFoundError("no .pkl files found in " + path)

save_file_path  = os.path.join(analysis_path, wormID)
df_dFF.to_pickle(save_file_path + ".pkl")
    

### compile all dFF timeseries from all worms in one datastructure, as input for automated risetime calculation by peakfinding


In [None]:
import glob
import pickle as pk


path = r"C:\Users\heeun\Documents\Garrison_lab\somatag_2019\all_trace_pkls" 
 # directory with .pkl files from all worms with all selected ROIs
all_files = glob.glob(path + "/*.pkl")
analysis_path = path + "/analysis" 
os.makedirs(analysis_path, exist_ok=True)

wormID_list   = [] 
dict_all_ROIs = {}
ROI_list      = {}

# One pass per worm file. Every ROI entry is filled from the file it came
# from; a second loop used to reuse the wormID and dataframe left over from
# the last file, so only the last worm's ROIs were ever populated.
for file_path in all_files:
    # drop the directory, the separator and a 4-character file-name prefix
    # (the per-worm files written earlier in this notebook start with
    # 'dFF_'), plus the '.pkl' extension
    wormID           = file_path[(len(path)+5):-4]
    wormID_list.append(wormID) 
    dFF_all          = pd.read_pickle(file_path)
    time_sec         = dFF_all["time_sec"]
    dFF_all          = dFF_all.drop(columns=['time_sec'])
    ROI_list[wormID] = list(dFF_all.columns)

    for ROIid in ROI_list[wormID]:
        # ROI ids end in '_z<plane>r<roi>'
        plane_and_roi = ROIid.split("_z")[1].split("r")
        dict_all_ROIs[ROIid] = {
            "WormID":    wormID,
            "GCaMPtype": wormID[-3:],
            "Zplane":    plane_and_roi[0],
            "ROInumber": plane_and_roi[1],
            "dFF":       dFF_all[ROIid],
            "time_sec":  time_sec,
        }

save_file_path  = analysis_path+ "/allROIs.pkl" 

with open(save_file_path, 'wb') as output:
    pk.dump(dict_all_ROIs, output)



### automated risetime calculation by peakfinding

In [None]:
import pickle
def peakfind(vec, thresh=1, relthreshflag = True, peaksearchingmode=1):
    """find the peaks and valleys of a time series using threshold tracking

        [PEAKS,VALLEYS] = peakfind(VEC,THRESH,RELTHRESHFLAG,PEAKSEARCHINGMODE)

    input:
    VEC is a vector time series (list, numpy array or pandas Series; it is
        read by position, so a Series index is ignored)
    THRESH is the threshold criterion (scalar). With RELTHRESHFLAG true (the
        default) it is a fraction of the range max(VEC)-min(VEC); otherwise
        it is an absolute value.

    output:
    PEAKS and VALLEYS are lists of indices; both are empty for empty input

    if peaksearchingmode=1 then first search for a peak as we
    scan from left to right
    else if peaksearchingmode ~= 1 otherwise first search for a valley

    by default, search for a peak first (say, if initial signal deflection
    is positive)

    A peak is recorded once the signal has dropped more than the threshold
    below the running maximum; a valley once it has risen more than the
    threshold above the running minimum.

    hints:
    len(PEAKS) gives number of peaks
    len(PEAKS) and len(VALLEYS) differ by at most 1
    peaks and valleys always alternate

    saul.kato@ucsf.edu
    """
    values = np.asarray(vec)
    if values.size == 0:
        # np.max/np.min raise on empty input
        return [], []

    if relthreshflag:
        thresh = thresh * (np.max(values) - np.min(values))

    peaks = []
    valleys = []
    max_tracker_val = -np.inf
    min_tracker_val = np.inf
    max_tracker_index = np.nan
    min_tracker_index = np.nan

    for i, vi in enumerate(values):

        # update max_tracker_val if needed
        if vi > max_tracker_val:
            max_tracker_val = vi
            max_tracker_index = i

        # update min_tracker_val if needed
        if vi < min_tracker_val:
            min_tracker_val = vi
            min_tracker_index = i

        if peaksearchingmode == 1:
            if vi < (max_tracker_val - thresh):
                peaks.append(max_tracker_index)  # add entry to peaks
                min_tracker_index = i  # move up min tracker to current time
                min_tracker_val = vi
                peaksearchingmode = 0  # switch to valley searching
        else:
            if vi > (min_tracker_val + thresh):
                valleys.append(min_tracker_index)  # add entry to valleys
                max_tracker_index = i  # move up max tracker to current time
                max_tracker_val = vi
                peaksearchingmode = 1  # switch to peak searching

    return peaks, valleys


"""
define function to find rise transients, using peakfinding method
"""

def risetimes(timeseries, sm=20, thresh=5):
    """Locate rise transients by walking backwards from each detected peak.

    Peaks are found with peakfind() on the raw series, using thresh as a
    relative threshold. For every peak, the rise is followed backwards for
    as long as the first difference of the rolling-mean-smoothed series
    stays positive.

    arguments
    ------
    timeseries: series supporting .rolling(sm).mean() (presumably a pandas
        Series - confirm against callers)
    sm (int): rolling-mean window passed to .rolling()
    thresh (scalar): threshold forwarded to peakfind (relthreshflag=True)

    returns
    ------
    rise_start_frames (int array): frame at which each rise starts
    rise_end_frames (list): the peak indices returned by peakfind
    """
    peaks, _ = peakfind(timeseries, thresh=thresh, relthreshflag=True)

    # slope of the smoothed series; leading entries are NaN from the rolling window
    smoothed_slope = np.diff(timeseries.rolling(sm).mean())

    rise_lengths = np.zeros(len(peaks))
    for slot, peak_frame in enumerate(peaks):
        cursor = peak_frame - 1
        steps = 0
        # NaN slopes are treated as 0 and therefore end the walk
        while np.nan_to_num(smoothed_slope[cursor]) > 0 and cursor > 0:
            steps += 1
            cursor -= 1
        rise_lengths[slot] = steps

    rise_end_frames = peaks
    # np.subtract yields floats; cast back so the result can be used as indices
    rise_start_frames = np.subtract(peaks, rise_lengths).astype(int)
    return rise_start_frames, rise_end_frames
"""
measure all rise transients
"""
# detection settings, also reused below when naming the output figures
threshold=.4
smoothing=40
for d in data:
    trace = data[d]
    starts, ends = risetimes(trace['dFF'],thresh=threshold,sm=smoothing)
    trace['rise_start_frames'] = starts
    trace['rise_end_frames'] = ends
    # duration of each rise: time at the peak minus time at the rise start
    trace['risetimes'] = np.zeros(len(starts))
    for t in range(len(starts)):
        trace['risetimes'][t] = trace['time_sec'][ends[t]] - trace['time_sec'][starts[t]]
"""
save to .pkl
"""
# the context manager closes the file even if pickling raises
with open(r'C:\Users\heeun\Documents\Garrison_lab\somatag_2019\all_trace_pkls\analysis\allROIs_output.pkl', 'wb') as output:
    pickle.dump(data, output)
"""
plot traces with detected transients overlaid
"""
# one subplot row per trace
num_traces=len(data)
figsize = (20, 2*num_traces)
fig = plt.figure(figsize=figsize)
for t, d in enumerate(data):
    trace = data[d]
    ax = plt.subplot(num_traces,1,t+1)
    ax.plot(trace['time_sec'],trace['dFF'])
    # overlay every detected rise in red and label it with its duration
    for r1, r2, duration in zip(trace['rise_start_frames'],
                                trace['rise_end_frames'],
                                trace['risetimes']):
        ax.plot(trace['time_sec'][r1:r2],trace['dFF'][r1:r2],'r')
        ax.text(trace['time_sec'][r1],0,"{:.2f}".format(duration))
    ax.set_xlabel('time (s)')
    ax.set_ylabel('DF/F')
    # trace key as a centered title inside the axes
    ax.text(.5,.9,d,
            horizontalalignment='center',
            fontsize=10,
            transform=ax.transAxes)

fig.savefig(r"C:\Users\heeun\Documents\Garrison_lab\somatag_2019\all_trace_pkls\analysis\traces_with_risetimes-sm"+ str(smoothing)+ "-tr"+ str(threshold) +".pdf", bbox_inches='tight')

"""
histograms of risetimes, throwing out spurious fast events
"""
rib_risetimes=[]
nls_risetimes=[]

# rise times at or below this value are discarded
min_thresh=1.5

for d in data:
    if data[d]['GCaMPtype']=='rib':
        rib_risetimes.extend([x for x in data[d]['risetimes'] if x > min_thresh])
    else:
        nls_risetimes.extend([x for x in data[d]['risetimes'] if x > min_thresh])

figsize = (4, 4)
fig = plt.figure(figsize=figsize)

# the second histogram reuses the bin edges of the first
_, bins, _ = plt.hist(rib_risetimes,range=[0, 30],bins=80,label='ribo')
_ = plt.hist(nls_risetimes,range=[0, 30],bins=bins,label='nls',alpha=0.5)
plt.legend(loc='upper right')
plt.xlabel("rise time (s)")
plt.show()
fig.savefig(r'C:\Users\heeun\Documents\Garrison_lab\somatag_2019\all_trace_pkls\analysis\histogram_all_ribo_vs_nls' + str(smoothing)+ "-tr"+ str(threshold) + '.pdf', bbox_inches='tight')


# Modeling the temporal transformation between ribo and nls

### create synthetic ribo-gcamp ramp trace generator

In [None]:
time_vec=np.arange(-40,40,0.1)

def gen_ribo_trace(rise_time,in_time_vec):
    """Generate a synthetic ramp trace sampled on in_time_vec.

    The trace is 0 for times <= 0, rises linearly as time/rise_time for
    0 < time < rise_time, and is 1 from rise_time onwards.

    arguments
    ------
    rise_time (scalar): duration of the linear ramp
    in_time_vec (array-like): sample times

    returns
    ------
    out_vec (array): trace values, one per entry of in_time_vec
    """
    # size the output from the argument, not from the global time_vec,
    # so time vectors of any length are handled correctly
    times = np.asarray(in_time_vec, dtype=float)
    out_vec = np.zeros(len(times))

    after_onset = times > 0
    ramping = after_onset & (times < rise_time)
    out_vec[ramping] = times[ramping] / rise_time
    out_vec[after_onset & ~ramping] = 1.0
    return out_vec

# example ramp with a 4 s rise; the gray line marks t = 0
fig, ax = plt.subplots()
ax.plot(time_vec,gen_ribo_trace(4.0,time_vec))
ax.set_xlabel("time (s)")
ax.set_ylabel("magnitude")
ax.axvline(x=0, color='gray')
plt.show()

### create first-order impulse response function

In [None]:
def gen_irf(tau,in_time_vec):
    """Generate a first-order (exponential decay) impulse response.

    The response is exp(-time/tau) for times > 0 and 0 elsewhere.

    arguments
    ------
    tau (scalar): decay time constant
    in_time_vec (array-like): sample times

    returns
    ------
    out_vec (array): response values, one per entry of in_time_vec
    """
    times = np.asarray(in_time_vec, dtype=float)
    out_vec = np.zeros(len(times))
    # evaluate the exponential only where the response is non-zero
    after_onset = times > 0
    out_vec[after_onset] = np.exp(-times[after_onset] / tau)
    return out_vec

time_vec=np.arange(-40,40,0.1)

# example impulse response with tau = 9; the gray line marks t = 0
fig, ax = plt.subplots()
ax.plot(time_vec,gen_irf(9.0,time_vec))
ax.set_xlabel("time (s)")
ax.set_ylabel("magnitude")
ax.axvline(x=0, color='gray')
plt.show()

### create nls trace generator as convolution of ribo-gcamp traces and irf

In [None]:
def gen_nls_trace(rise_time,tau,in_time_vec):
    """Generate a synthetic nls trace on in_time_vec.

    The trace is the convolution (mode='same') of the ramp from
    gen_ribo_trace(rise_time, ...) with the impulse response from
    gen_irf(tau, ...); the result has the same length as the inputs.
    """
    ramp = gen_ribo_trace(rise_time, in_time_vec)
    impulse_response = gen_irf(tau, in_time_vec)
    return np.convolve(ramp, impulse_response, mode='same')


# example nls trace: 4 s ramp filtered with tau = 9; the gray line marks t = 0
fig, ax = plt.subplots()
ax.plot(time_vec,gen_nls_trace(4.0,9.0,time_vec))
ax.set_xlabel("time (s)")
ax.set_ylabel("magnitude")
ax.axvline(x=0, color='gray')
plt.show()

### measure rise time of trace

In [None]:
def measure_rise_time(trace,in_time_vec,threshold_low=0.05,threshold_high=0.95):
    """Measure the rise time of a trace between two fractional levels.

    Each level is min + fraction * (max - min) of the trace. The start and
    end times are the times of the first sample strictly above the low and
    high level respectively. If no sample exceeds a level, np.argmax
    returns 0 and the first entry of in_time_vec is used.

    arguments
    ------
    trace (array): trace values
    in_time_vec (array): sample times, same length as trace
    threshold_low, threshold_high (scalar): fractional levels

    returns
    ------
    rise_time, start_time, end_time
    """
    ceiling = max(trace)
    floor = min(trace)

    def first_time_above(fraction):
        # time of the first sample exceeding the given fraction of the range
        level = (ceiling - floor) * fraction + floor
        return in_time_vec[np.argmax(trace > level)]

    start_time = first_time_above(threshold_low)
    end_time = first_time_above(threshold_high)
    return end_time - start_time, start_time, end_time
 
    
# measure the rise time of an example nls trace (4 s ramp, tau = 6)
test_trace=gen_nls_trace(4.0,6.0,time_vec)
test_rise_time, test_start_time, test_end_time = measure_rise_time(test_trace,time_vec)

fig, ax = plt.subplots()
ax.plot(time_vec,test_trace)
ax.text(test_start_time, 0, 'risetime={:.2f}'.format(test_rise_time))
ax.set_xlabel("time (s)")
ax.set_ylabel("magnitude")
# gray line at t = 0, red lines at the measured start and end of the rise
for marker_time, marker_color in ((0, 'gray'),
                                  (test_start_time, 'red'),
                                  (test_end_time, 'red')):
    ax.axvline(x=marker_time, color=marker_color)
plt.show()

### transform ribo-gcamp rise times into nls-gcamp rise times given an i.r.f. tau

In [None]:
def transform_rise_times(ribo_vec,tau):
    """Transform ribo rise times into modeled nls rise times.

    For every rise time in ribo_vec a synthetic nls trace is generated with
    gen_nls_trace() on a fixed time axis (-40 to 40 in steps of 0.1) and
    its rise time is measured between the 5% and 99% levels.

    arguments
    ------
    ribo_vec (iterable of scalars): ribo rise times
    tau (scalar): time constant of the impulse response

    returns
    ------
    nls_vec (float array): modeled nls rise times, one per entry of
        ribo_vec (an empty array for empty input)
    """
    this_time_vec=np.arange(-40,40,0.1)

    # collect in a list and convert once instead of re-allocating the whole
    # array with np.append on every iteration
    nls_rise_times = [
        measure_rise_time(gen_nls_trace(rise_time,tau,this_time_vec),this_time_vec,
                          threshold_low=0.05,threshold_high=0.99)[0]
        for rise_time in ribo_vec
    ]
    return np.array(nls_rise_times, dtype=float)

### create synthetic test data

In [None]:
# synthetic ribo rise times: |N(4, 1.3)| samples, then pushed through the model
test_ribo=np.abs((np.random.randn(200)*1.3)+4)

test_nls=transform_rise_times(test_ribo,tau=3.1)

# the second histogram reuses the bin edges of the first
_, bins, _ = plt.hist(test_ribo,range=[0, 30],bins=50,label='ribo')
_ = plt.hist(test_nls,range=[0, 30],bins=bins,label='nls',alpha=0.5)
plt.legend(loc='upper right')
plt.xlabel("rise time (s)")
plt.show()

In [None]:
# modeled nls rise time as a function of the ribo rise time it came from
plt.plot(test_ribo, test_nls, '.')
plt.xlabel("ribo rise time (s)")
plt.ylabel("nls rise time (s)")
plt.show()

### define histogram-comparison error function

In [None]:
def error_func(tau,vec1,vec2,num_bins=30,range=[0, 20]):
    """Histogram-comparison error for a candidate tau.

    vec1 is transformed with transform_rise_times(vec1, tau) and the
    histogram of the result is compared to the histogram of vec2.

    arguments
    ------
    tau (scalar): candidate time constant
    vec1: rise times to transform
    vec2: rise times to compare against
    num_bins (int): number of histogram bins
    range (list): histogram range, forwarded to np.histogram

    returns
    ------
    err: Euclidean distance between the two vectors of bin counts
    """
    # histogram of the reference rise times
    reference_counts, _ = np.histogram(vec2, bins=num_bins, range=range)

    # histogram of the transformed rise times
    modeled = transform_rise_times(vec1, tau)
    modeled_counts, _ = np.histogram(modeled, bins=num_bins, range=range)

    return distance.euclidean(reference_counts, modeled_counts)

### find tau that best explains data

In [None]:
# Nelder-Mead search for the tau minimizing error_func, starting from 2.5
res = optimize.minimize(
    error_func,
    x0=2.5,
    args=(test_ribo, test_nls),
    method='nelder-mead',
    options={'disp': True},
)
res.x[0]

### plot tau error landscape

In [None]:
# evaluate the error over a grid of tau values
my_tau_vec = np.arange(1,5,0.05)
my_error_vec = []
for my_tau in my_tau_vec:
    my_error_vec.append(error_func(my_tau,test_ribo,test_nls))
my_error_vec = np.array(my_error_vec)

plt.plot(my_tau_vec,my_error_vec)
plt.show()

# Now run this model on experimental data, comparing ribo and nls aggregate data

### create time series excerpts of ribo rise transients

In [None]:
rib_excerpts_dFF=[]
rib_excerpts_time_sec=[]
for d in data:
    if data[d]['GCaMPtype']!='rib':
        continue
    for start, end in zip(data[d]['rise_start_frames'], data[d]['rise_end_frames']):
        # each excerpt runs from the rise start to 50 frames past the rise end
        window = slice(start, 50+end)
        rib_excerpts_dFF.append(data[d]['dFF'][window])
        rib_excerpts_time_sec.append(data[d]['time_sec'][window])

for x in rib_excerpts_dFF:
    plt.plot(x.values)

In [None]:
# rise time of every non-empty ribo excerpt, measured between the 5% and 95% levels
risetimes_rib=[]
for r,t in zip(rib_excerpts_dFF,rib_excerpts_time_sec):
    if r.empty:
        continue
    rt,_,_ = measure_rise_time(r.values,t.values,threshold_low=0.05,threshold_high=0.95)
    risetimes_rib.append(rt)
# keep only rise times above min_thresh
risetimes_rib=[x for x in risetimes_rib if x > min_thresh]
plt.hist(risetimes_rib,bins=30)
plt.show()

### function to convolve 1st order filter

In [None]:
def gen_irf(tau,in_time_vec):
    """Generate a first-order (exponential decay) impulse response.

    The response is exp(-time/tau) for times > 0 and 0 elsewhere.

    arguments
    ------
    tau (scalar): decay time constant
    in_time_vec (array-like): sample times

    returns
    ------
    out_vec (array): response values, one per entry of in_time_vec
    """
    times = np.asarray(in_time_vec, dtype=float)
    out_vec = np.zeros(len(times))
    # evaluate the exponential only where the response is non-zero
    after_onset = times > 0
    out_vec[after_onset] = np.exp(-times[after_onset] / tau)
    return out_vec


def convolve_ribo_trace(ribo_trace,tau,in_time_vec):
    """Convolve a ribo trace with a first-order impulse response.

    The impulse response is gen_irf(tau, in_time_vec); the convolution uses
    mode='same'.

    NOTE(review): gen_irf is evaluated on in_time_vec exactly as given. The
    callers in this file pass the excerpt's time_sec values; if those are
    absolute timestamps the kernel is exp(-t/tau) at large t rather than
    measured from the excerpt start - confirm this is intended.
    """
    impulse_response = gen_irf(tau, in_time_vec)
    return np.convolve(ribo_trace, impulse_response, mode='same')

In [None]:
# overlay every non-empty ribo excerpt after filtering with tau = 1
tau=1
for r,t in zip(rib_excerpts_dFF,rib_excerpts_time_sec):
    if r.empty:
        continue
    plt.plot(convolve_ribo_trace(r.values,tau,t.values))

In [None]:
def make_nls_traces(tau=10):
    """Simulate nls rise times from the ribo excerpts for a given tau.

    Every non-empty excerpt in rib_excerpts_dFF is convolved with the
    first-order filter (convolve_ribo_trace) and its rise time is measured
    between the 5% and 95% levels. Rise times not above min_thresh are
    dropped.

    arguments
    ------
    tau (scalar): time constant of the filter

    returns
    ------
    risetimes_sim (list): simulated rise times above min_thresh
    """
    risetimes_sim=[]
    for r,t in zip(rib_excerpts_dFF,rib_excerpts_time_sec):
        if r.empty:
            continue
        nls_sim_trace=convolve_ribo_trace(r.values,tau,t.values)
        rt,_,_ = measure_rise_time(nls_sim_trace,t.values,threshold_low=0.05,threshold_high=0.95)
        risetimes_sim.append(rt)
    # filter once after the loop; re-filtering the whole list on every
    # iteration gave the same result at quadratic cost
    return [x for x in risetimes_sim if x > min_thresh]


plt.hist(make_nls_traces(7),bins=30)
plt.show()

### create time series excerpts of nls rise transients

In [None]:
nls_excerpts_dFF=[]
nls_excerpts_time_sec=[]
for d in data:
    if data[d]['GCaMPtype']!='nls':
        continue
    for start, end in zip(data[d]['rise_start_frames'], data[d]['rise_end_frames']):
        # each excerpt runs from the rise start to the rise end
        window = slice(start, end)
        nls_excerpts_dFF.append(data[d]['dFF'][window])
        nls_excerpts_time_sec.append(data[d]['time_sec'][window])

for x in nls_excerpts_dFF:
    plt.plot(x.values)

In [None]:
# rise time of every non-empty nls excerpt, measured between the 5% and 95% levels
risetimes_nls_actual=[]

for r,t in zip(nls_excerpts_dFF,nls_excerpts_time_sec):
    if not r.empty:
        rt,_,_ = measure_rise_time(r.values,t.values,threshold_low=0.05,threshold_high=0.95)
        risetimes_nls_actual.append(rt)

# keep only rise times above min_thresh
risetimes_nls_actual=[x for x in risetimes_nls_actual if x > min_thresh]
plt.hist(risetimes_nls_actual,bins=30)
# plt.show was referenced without being called
plt.show()

In [None]:
def hist_error_func(tau):
    """Histogram-comparison error between simulated and measured nls rise times.

    The rise times simulated by make_nls_traces(tau) and the measured
    risetimes_nls_actual are both binned into 30 bins over [0, 10]; the
    error is the Euclidean distance between the two vectors of bin counts.
    """
    num_bins=30
    bin_range=[0, 10]

    simulated_counts, _ = np.histogram(make_nls_traces(tau), bins=num_bins, range=bin_range)
    measured_counts, _ = np.histogram(risetimes_nls_actual, bins=num_bins, range=bin_range)

    return distance.euclidean(simulated_counts, measured_counts)

In [None]:
# Nelder-Mead search for the tau minimizing hist_error_func, starting from 3
res = optimize.minimize(
    hist_error_func,
    x0=3,
    method='nelder-mead',
    options={'disp': True},
)
res.x[0]

In [None]:
# evaluate the histogram error over a grid of tau values
my_tau_vec = np.arange(1,8,0.05)
my_error_vec = []
for my_tau in my_tau_vec:
    my_error_vec.append(hist_error_func(my_tau))
my_error_vec = np.array(my_error_vec)

plt.plot(my_tau_vec,my_error_vec)
plt.show()

In [None]:
# Nelder-Mead search for the tau minimizing error_func, starting from 2.5
res = optimize.minimize(
    error_func,
    x0=2.5,
    args=(test_ribo, test_nls),
    method='nelder-mead',
    options={'disp': True},
)
res.x[0]