<h1><center>Fiber Photometry Data Analysis</center></h1>


This package was developed by Ekaterina Martianova during her PhD studies in Christophe Proulx's lab at the CERVO Brain Research Center, Laval University.

The package includes: 

- Functions necessary to process raw signals recorded with camera-based fiber photometry;

- Helper functions to process behavior data recorded/created with behavior software, e.g. Med Associates, ANY-maze, DeepLabCut;

- Helper functions to create summary plots.

# Import necessary libraries


In [None]:
import os
import sys
import time 
import pandas as pd
import numpy as np
np.random.seed(0)
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
from collections import OrderedDict
mpl.style.use('classic')
import seaborn as sns
from scipy.stats.stats import pearsonr
from scipy import signal

In [None]:
# Show images only if asked: turn matplotlib's interactive mode off so figures
# are not displayed automatically as soon as they are created.
plt.ioff()

# Recording Class


In [None]:
class FiberPhotometryRecording:

  '''
    A class used to represent a Fiber Photometry Recording
    ...

    Attributes
    ----------
    signals : dict
        Calcium-dependent traces recorded with 470 nm (e.g. GCaMP) or 560 nm (e.g. jrGECO) LEDs
    references : dict
        Calcium-independent traces recorded with 405-415 nm LED
    time_ : numpay.ndarray
        Timestamps of the recording
    events : dict
        External or behavioral events that happening during the recoring
    meaurements : dict
        Other continious recordings happening along with the fiber photometry recording
    mouse : str
        Name of a mouse
    test : str
        Name of a test
    trial : str
        Number of a trial

    Methods
    -------
    loadRecording(filename,mouse,test,trial='1')
        Loads a recording from a specified filename file for specified mouse, test, and trial
    
    '''


  import h5py
  import numpy as np
  import pandas as pd

  def __init__(self,signals=None,references=None,time_=None,
               events=None,measurements=None,mouse='mouse',test='test',trial='1'):

    s = signals is None; r = references is None; t = time_ is None
    if (s^r or r^t or s^t):
      raise TypeError('To initialize the object, the function takes either 0 or 3-8 arguments (signals, references, time_, ...).')

    if signals and references and time_ is not None:
      if (type(signals) or type(references)) is not dict:
        raise TypeError('signals and references have to be dictionaries with keys indicating names of neural populations and values indicating corresponding traces as 1D numpy.ndarray.')

      if signals.keys() != references.keys():
        raise KeyError('Keys in signals and references dictionaries have to be the same.')

      for output in signals:
        if (type(signals[output]) or type(references[output])) is not np.ndarray:
          raise TypeError('Values of signals and references dictionaries have to be 1D numpy.ndarray.')
        if signals[output].ndim != 1 or references[output].ndim != 1:
          raise ValueError('Values of signals and refernces dictionaries have to be 1D numpy.ndarray.')

      if type(time_) is not np.ndarray:
        raise TypeError('time_ argument has to be 1D numpy.ndarray.')
      if time_.ndim != 1:
        raise ValueError('time_ argument has to be 1D numpy.ndarray.')
      
      for output in signals:
        if signals[output].size != time_.size or references[output].size != time_.size:
          raise ValueError('All signals, references, and time arrays have to be the same length.')

      if events is not None:
        if type(events) is not dict:
          raise TypeError('events argument has to be a dictionary.')
        for e in events:
          if type(events[e]) is not np.ndarray:
            raise TypeError('Values of events dictionary have to be 2D numpy.ndarrays.')
          if events[e].ndim != 2:
            raise ValueError('Values of events dictionary have to be 2D numpy.ndarrays.')

      if measurements is not None:
        if type(measurements) is not dict:
          raise TypeError('measurements argument has to be a dictionary.')
        for m in measurements:
          if type(measurements[m]) is not dict:
            raise TypeError('Values of measurements dictionary have to be dictionaries with keys "time" and "values".')
          keys = list(measurements[m].keys()); keys.sort()
          if keys != ['time', 'values']:
            raise TypeError('Values of measurements dictionary have to be dictionaries with keys "time" and "values".')
          if (type(measurements[m]['time']) or type(measurements[m]['values'])) is not np.ndarray:
            raise TypeError('Values of "time" and "values" in measurements dictionary have to be 1D numpy.ndarray.')
          if measurements[m]['time'].ndim != 1 or measurements[m]['values'].ndim != 1:
            raise ValueError('Values of "time" and "values" in measurements dictionary have to be 1D numpy.ndarray.')
          if measurements[m]['time'].size != measurements[m]['values'].size:
            raise ValueError('Arrays of "time" and "values" in measurements dictionary have to be the same length.')

        if type(mouse) is not str:
          raise TypeError('Argument mouse has to be a string.')
        if type(test) is not str:
          raise TypeError('Argument test has to be a string.')
        if type(trial) is not str:
          raise TypeError('Argument test has to be a string.')

          
    self.rawSignals = signals
    self.rawReferences = references
    self.signals = None
    self.references = None
    self.dFFs = None
    self.time = time_
    self.events = events
    self.measurements = measurements
    self.perievents = None
    self.measurePerievents = None
    self.mouse = mouse
    self.test = test
    self.trial = trial

    if self.rawSignals is not None:
      self.outputs = list(self.rawSignals.keys())
    else:
      self.outputs = None

    if time_ is not None:
      self.period = find_avg_period(self.time)
      self.frequency = 1 / self.period
    else:
      self.period = None
      self.frequency = None

    self.timeDFF = None
    


  def __repr__(self):
    if self.outputs is None:
      self = None
      return 'No recording was loaded.'

    outputs_string = ''
    for output in self.outputs:
      outputs_string += output + ', '
    return 'Fiber photometry recordings for mouse {} in outputs/pathways {}during test {}-{}.'\
            .format(self.mouse,outputs_string,self.test,self.trial)



  def saveRecording(self,fileHDF):

    with h5py.File(fileHDF, 'a') as f:

     # Raw
      for output in self.outputs:    
        path = 'Raw/'+self.test+'/'+output+'/'+self.mouse+'/'+self.trial+'/'
        saveToHDF(f,path+'time',self.time)
        saveToHDF(f,path+'signal',self.rawSignals[output])
        saveToHDF(f,path+'reference',self.rawReferences[output])

     # Recordings   
      if self.signals is not None:
        for output in self.outputs:    
          path = 'Recordings/'+self.test+'/'+output+'/'+self.mouse+'/'+self.trial+'/'
          saveToHDF(f,path+'time',self.time)
          saveToHDF(f,path+'signal',self.signals[output])
          saveToHDF(f,path+'reference',self.references[output])

     # dF/F
      if self.dFFs is not None:
        for output in self.outputs:    
          path = 'DFFs/'+self.test+'/'+output+'/'+self.mouse+'/'+self.trial+'/'
          saveToHDF(f,path+'dFF',self.dFFs[output])
          if self.timeDFF is not None:
            saveToHDF(f,path+'time',self.timeDFF)
          else: 
            saveToHDF(f,path+'time',self.time)

     # Events
      if self.events is not None:
        for event in self.events.keys():
          path = 'Events/'+self.test+'/'+event+'/'+self.mouse+'/'+self.trial+'/'
          saveToHDF(f,path+'timestamps',self.events[event])

     # Measurements
      if self.measurements is not None:
        for measure in self.measurements.keys():
          path = 'Measurements/'+self.test+'/'+measure+'/'+self.mouse+'/'+self.trial+'/'
          saveToHDF(f,path+'values',self.measurements[measure]['values'])
          saveToHDF(f,path+'time',self.measurements[measure]['time'])

     # Perievents
      if self.perievents is not None:
        for output in self.perievents:
          for event in self.perievents[output]:
            for onoffset in self.perievents[output][event]:
              path = 'Perievents/'+self.test+'/'+output+'/'+event+'/'+onoffset+'/'+self.mouse+'/'+self.trial
              saveToHDF(f,path,self.perievents[output][event][onoffset])

     # MeasurePerievents
      if self.measurePerievents is not None:
        for measure in self.measurePerievents:
          for event in self.measurePerievents[measure]:
            for onoffset in self.measurePerievents[measure][event]:
              path = 'MeasurePerievents/'+self.test+'/'+measure+'/'+event+'/'+onoffset+'/'+self.mouse+'/'+self.trial
              saveToHDF(f,path,self.measurePerievents[measure][event][onoffset])



  def removeRecording(self,fileHDF,remove='processed'):

    '''
    remove: 'all', 'processed'
            ['raw','events','measures','signal','dFF','perievents','measurePerievents']
    '''

    if remove=='all':
      remove = ['raw','events','measures','signal','dFF','perievents','measurePerievents']
    if remove=='processed':
      remove = ['signal','dFF','perievents','measurePerievents']

    with h5py.File(fileHDF, 'a') as f:

      if 'raw' in remove:
        if self.rawSignals is not None:
          for output in self.outputs:
            path = 'Raw/'+self.test+'/'+output+'/'+self.mouse+'/'+self.trial+'/'
            try:
              del f[path+'signal']
              del f[path+'reference']
              del f[path+'time']
            except KeyError:
              return print('The recording is not saved in the file.')

      if 'events' in remove:
        if self.events is not None:
          for event in self.events:
            path = 'Events/'+self.test+'/'+event+'/'+self.mouse+'/'+self.trial+'/'
            del f[path+'timestamps']

      if 'measurements' in remove:
        if self.measurements is not None:
          for measure in self.measurements:
            path = 'Measurements/'+self.test+'/'+measure+'/'+self.mouse+'/'+self.trial+'/'
            del f[path+'values']
            del f[path+'time']

      if 'signal' in remove:
        if self.signals is not None:
          for output in self.outputs:
            path = 'Recordings/'+self.test+'/'+output+'/'+self.mouse+'/'+self.trial+'/'
            del f[path+'time']
            del f[path+'signal']
            del f[path+'reference']

      if 'dFF' in remove:
        if self.dFFs is not None:
          for output in self.outputs:
            path = 'DFFs/'+self.test+'/'+output+'/'+self.mouse+'/'+self.trial+'/'
            del f[path+'dFF']
            del f[path+'time']

      if 'perievents' in remove:
        if self.perievents is not None:
          for output in self.perievents:
            for event in self.perievents[output]:
              for onoffset in self.perievents[output][event]:
                path = 'Perievents/'+self.test+'/'+output+'/'+event+'/'+onoffset+'/'+self.mouse+'/'+self.trial
                del f[path]

      if 'measurePerievents' in remove:
        if self.measurePerievents is not None:
          for measure in self.measurePerievents:
            for event in self.measurePerievents[measure]:
              for onoffset in self.measurePerievents[measure][event]:
                path = 'MeasurePerievents/'+self.test+'/'+measure+'/'+event+'/'+onoffset+'/'+self.mouse+'/'+self.trial
                del f[path]



  
  def loadRecording(self,fileHDF,mouse,test,trial='1'):
        
    self.mouse = mouse
    self.test = test
    self.trial = trial

    with h5py.File(fileHDF, 'r') as f:

      if 'Raw/'+test in f:
        outputs = list(f['Raw/'+test].keys())
      else:
        return print('No recordings are saved for test {}.'.format(test))

      self.rawSignals = {}
      self.rawReferences = {}
      for output in outputs:
        path = 'Raw/'+test+'/'+output+'/'+mouse+'/'+trial+'/'
        if path in f:
          self.rawSignals[output] = np.array(f[path].get('signal'))
          self.rawReferences[output] = np.array(f[path].get('reference'))
          self.time = np.array(f[path].get('time'))
      if self.rawSignals == {}:
        self.mouse = None
        self.test = None
        self.trial = None
        self.rawSignals = None
        self.rawReferences = None
        return 'The recording for animal {} in the experiment {}-{} is not saved in the file.'.format(mouse,test,trial)
        
      self.outputs = list(self.rawSignals.keys())

      self.period = find_avg_period(self.time)
      self.frequency = 1 / self.period
      
      if 'Recordings/'+self.test in f:
        self.signals = {}
        self.references = {}
        for output in self.outputs:
          path = 'Recordings/'+test+'/'+output+'/'+mouse+'/'+trial+'/'
          if path in f:
            self.signals[output] = np.array(f[path].get('signal'))
            self.references[output] = np.array(f[path].get('reference'))

      if 'DFFs/'+test in f:
        self.dFFs = {}
        for output in outputs:
          path = 'DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/'
          if path in f:
            self.dFFs[output] = np.array(f[path].get('dFF'))
            self.timeDFF = np.array(f[path].get('time'))

      if 'Events/'+test in f:
        self.events = {}
        for event in f['Events/'+test]:
          path = 'Events/'+test+'/'+event+'/'+mouse+'/'+trial+'/'
          if path in f:
            self.events[event] = np.array(f[path].get('timestamps'))

      if 'Measurements/'+test in f:    
        self.measurements = {}
        for measure in f['Measurements/'+test]:
          path = 'Measurements/'+test+'/'+measure+'/'+mouse+'/'+trial+'/'
          if path in f:
            self.measurements[measure] = {'time': np.array(f[path].get('time')),
                                        'values': np.array(f[path].get('values'))}
         
      if 'Perievents/'+test in f:
        self.perievents = {}
        for output in f['Perievents/'+test]:
          self.perievents[output] = {}
          for event in f['Perievents/'+test+'/'+output]:
            self.perievents[output][event] = {}
            for onoffset in f['Perievents/'+test+'/'+output+'/'+event]:
              path = 'Perievents/'+test+'/'+output+'/'+event+'/'+onoffset+'/'+mouse
              if path+'/'+trial in f:
                self.perievents[output][event][onoffset] = np.array(f[path].get(trial))
              else:
                self.perievents = {}
                break
            else:
              continue # continue if the inner loop wasn't broken
            break      # break if the inner loop was broken


      if 'MeasurePerievents/'+test in f:
        self.measurePerievents = {}
        for measure in f['MeasurePerievents/'+test]:
          self.measurePerievents[measure] = {}
          for event in f['MeasurePerievents/'+test+'/'+measure]:
            self.measurePerievents[measure][event] = {}
            for onoffset in f['MeasurePerievents/'+test+'/'+measure+'/'+event]:
              path = 'MeasurePerievents/'+test+'/'+measure+'/'+event+'/'+onoffset+'/'+mouse
              if path+'/'+trial in f:
                self.measurePerievents[measure][event][onoffset] = np.array(f[path].get(trial))
              else:
                self.measurePerievents = {}
                break
            else:
              continue
            break

    return print('The recording for mouse {} in the experiment {}-{} is successfully loaded.'.format(mouse,test,trial))





  def getDFF(self,
            smooth=True,smooth_filter='low-pass',smooth_parameter=1,
            remove_slope=True,airpls_lambda=1e4,absolute_intensities=False,
            remove_beginning=True,remove=10,
            standardize=True,
            model='Lasso',
            interpolate=False,period=0.1,
            plot=False,figsize=(20,13),save=False,image_format='.pdf'):

    self.preprocess(smooth,smooth_filter,smooth_parameter,
                    remove_slope,airpls_lambda,absolute_intensities,
                    remove_beginning,remove,standardize,
                    plot,figsize,save,image_format)

    self.align(model,plot,figsize,save,image_format)

    self.calculateDFF(standardize,plot,figsize,save,image_format)

    if interpolate:
      self.interpolateDFF(period)

    plt.close('all')




  def getPerievents(self,info_for_array=None,
                    plot=False,save=False,image_format='.pdf'):
    
    if self.timeDFF is not None:
      time_ = self.timeDFF
    else:
      time_ = self.time

   # Adjust events   
   # Remove events that are at the beginning of dF/F where are NANs
    dFF = self.dFFs[list(self.dFFs.keys())[0]]    
    idx = np.max(np.argwhere(np.isnan(dFF))) + 1
    events = self.events
    for event in events:
      e = events[event]
      if e.size != 0:
        events[event] = e[np.all(e > time_[idx], axis=1)]
    
    self.perievents = {}
    for output in self.outputs:
      self.perievents[output] = {}

    if self.measurements is not None:
      self.measurePerievents = {}
      for measure in self.measurements:
        m = self.measurements[measure]['values']
        if not isbinary(m):
          self.measurePerievents[measure] = {}

    period = find_avg_period(time_)

    cmap = get_cmap(len(self.events))

    for k,event in enumerate(self.events):

      if self.events[event].size != 0:

        try:
          window = info_for_array[event]['window']
        except:
          window = [-5.0,5.0]
        try: 
          dur = info_for_array[event]['duration']
        except:
          dur = None
        try:
          iei = info_for_array[event]['interval']
        except:
          iei = None
        try:    
          avg_win = info_for_array[event]['avg_frame']
        except:
          avg_win = None
        try:    
          figsize = info_for_array[event]['figsize']
        except:
          figsize = None  

        for output in self.dFFs:
          Array = create_perievents(self.dFFs[output],time_,events[event],
                                  window,dur,iei,avg_win)
          self.perievents[output][event] = Array
            
        if self.measurements is not None:
          for measure in self.measurements:
            measure_values = self.measurements[measure]['values']
            measure_time = self.measurements[measure]['time']
            if not isbinary(measure_values):
              Array1 = create_perievents(measure_values,measure_time,events[event],
                                        window,dur,iei)
              self.measurePerievents[measure][event] = Array1
            
      # Plot if asked 
        if plot:
          
          plt.close('all')

          if save:
            create_folder('./figures')
            create_folder('./figures/5_mean')
          for output in self.outputs:
            Array = self.perievents[output][event]
            if self.measurePerievents is not None:
              for measure in self.measurePerievents:
                Array1 = self.measurePerievents[measure][event]
                period1 = find_avg_period(self.measurements[measure]['time'])
                figtitle = self.mouse + ' ' + output + ' ' + self.test + \
                          self.trial + ' ' + event + ' ' + measure
                plot_perievents(Array,period,Array1,period1,
                                window,cmap(k),figtitle,figsize,
                                save,'./figures/5_mean/',image_format)
            else:
              figtitle = self.mouse + ' ' + output + ' ' + self.test + self.trial + ' ' + event 
              plot_perievents(Array,period,window=window,
                              color=cmap(k),figtitle=figtitle,figsize=figsize,
                              save=save,save_path='./figures/5_mean/',
                              image_format=image_format)
          
    plt.close('all')      



  # def preprocess(self,smooth_filter='low-pass',smooth_win=1,
  #                flatten=True,airpls_lambda=1e4,airpls_itermax=50,abs_int=False,
  #                remove=30,standardize=True,
  #                plot=False,figsize=(24,13),save=False,image_format='.pdf'):
  
  #   for i,t in enumerate(self.time):
  #     if t > remove:
  #       i0 = i-1
  #       break

  #   self.signals = {}
  #   self.references = {}

  # # Itterate through different outputs
  #   for output in self.outputs:

  #     s = self.rawSignals[output].copy()
  #     r = self.rawReferences[output].copy()

  #    # Smooth 
  #     if smooth_filter=='moving average':
  #       s = smooth_signal(s,window_len=smooth_win/self.period)
  #       r = smooth_signal(r,window_len=smooth_win/self.period)

  #     elif smooth_filter=='low-pass':
  #       cutoff = 1 / smooth_win
  #       f = int(round(self.frequency))
  #       s = butter_lowpass_filter(s, cutoff, f, order=5)
  #       r = butter_lowpass_filter(r, cutoff, f, order=5)

  #    # Flatten
  #     if flatten: 
  #       s, s_base = flatten_signal(s,lambda_=airpls_lambda,itermax=airpls_itermax)
  #       r, r_base = flatten_signal(r,lambda_=airpls_lambda,itermax=airpls_itermax)
        
  #       if abs_int:
  #         s = s + min(s_base)
  #         r = r + min(r_base)
        

  #    # Remove the begining
  #     r[:i0] = np.nan
  #     s[:i0] = np.nan


  #    # Standardize signal to mean 0 and std 1
  #     if standardize:
  #       s = standardize_signal(s)
  #       r = standardize_signal(r)


  #     self.signals[output] = s
  #     self.references[output] = r


  #    # Plot if asked
  #     if plot:
  #       plt.close('all')

  #       figtitle = self.mouse + ' ' + output + ' ' + self.test + self.trial

  #       if save:
  #         create_folder('figures')
  #         create_folder('figures/1_raw')

  #       plot_raw(self.rawSignals[output],self.rawReferences[output],s,r,s_base,r_base,self.time,
  #               self.events,self.measurements,
  #               figtitle,figsize,save,'./figures/1_raw/',image_format)
        
  def preprocess(self,
                 smooth=True,smooth_filter='low-pass',smooth_parameter=1,
                 remove_slope=True,airpls_lambda=1e4,absolute_intensities=False,
                 remove_beginning=True,remove=10,
                 standardize=True,
                 plot=False,figsize=(24,13),
                 save=False,image_format='.pdf'):
  
    self.signals = self.rawSignals.copy()
    self.references = self.rawReferences.copy()

   # Smooth
    if smooth:
      self.smooth(smooth_filter,smooth_parameter)

   # Remove the slope
    if remove_slope:
      s_slope,r_slope = self.removeSlope(airpls_lambda,absolute_intensities)

   # Remove the begining
    if remove_beginning:
      self.removeBeginning(remove)

   # Standardize signal to mean 0 and std 1
    if standardize:
      self.standardize()

            
   # Plot and save if needed
    if save:
      create_folder('figures')
      create_folder('figures/1_raw')

    if plot:
      for output in self.outputs:  
        
        figtitle = self.mouse + ' ' + output + ' ' + self.test + self.trial

        plot_raw(self.rawSignals[output],self.rawReferences[output],
                self.signals[output],self.references[output],
                s_slope[output],r_slope[output],
                self.time,self.events,self.measurements,
                figtitle,figsize,save,'./figures/1_raw/',image_format)


  def smooth(self, smooth_filter='low-pass',smooth_parameter=1,take_raw=False):
        
    if smooth_filter not in ['low-pass','moving average']:
      raise TypeError('Argument smooth_filter can be only "low-pass" or "moving average".')

    if take_raw or self.signals is None:
      self.signals = self.rawSignals.copy()
      self.references = self.rawReferences.copy()

    for output in self.outputs:

      s = self.signals[output].copy()
      r = self.references[output].copy()

     # Smooth 
      if smooth_filter=='moving average':
        s = smooth_signal(s,window_len=int(smooth_parameter/self.period))
        r = smooth_signal(r,window_len=int(smooth_parameter/self.period))

      elif smooth_filter=='low-pass':
        cutoff = smooth_parameter
        f = int(round(self.frequency))
        s = butter_lowpass_filter(s, cutoff, f, order=5)
        r = butter_lowpass_filter(r, cutoff, f, order=5)  

      self.signals[output] = s
      self.references[output] = r


  def removeSlope(self, airpls_lambda=1e4,absolute_intensities=False,take_raw=False):

    if take_raw or self.signals is None:
      self.signals = self.rawSignals.copy()
      self.references = self.rawReferences.copy()

    s_slope = {}
    r_slope = {}

    for output in self.outputs:

      s = self.signals[output].copy()
      r = self.references[output].copy()

      s, s_slope[output] = flatten_signal(s,lambda_=airpls_lambda)
      r, r_slope[output] = flatten_signal(r,lambda_=airpls_lambda)
        
      if absolute_intensities:
        s = s + min(s_slope)
        r = r + min(r_slope)

      self.signals[output] = s
      self.references[output] = r

    return s_slope,r_slope


  def removeBeginning(self, remove=10,take_raw=False):

    if take_raw or self.signals is None:
      self.signals = self.rawSignals.copy()
      self.references = self.rawReferences.copy()

    for i,t in enumerate(self.time):
      if t > remove:
        i0 = i-1
        break

    for output in self.outputs:
        
      self.signals[output][:i0] = np.nan
      self.references[output][:i0] = np.nan

           

  def standardize(self, take_raw=False):

    if take_raw or self.signals is None:
      self.signals = self.rawSignals.copy()
      self.references = self.rawReferences.copy()
  
    for output in self.outputs:

      self.signals[output] = standardize_signal(self.signals[output])
      self.references[output] = standardize_signal(self.references[output])




  def align(self, model='Lasso',
            plot=False,figsize=(24,13),
            save=False,image_format='.pdf'):

    for output in self.outputs:

      r_fitted = fit_signal(self.signals[output],self.references[output],model)

      if plot:
        plt.close('all')

        figtitle = self.mouse + ' ' + output + ' ' + self.test + self.trial

        if save:
          create_folder('figures')
          create_folder('figures/2_fit')
          create_folder('figures/3_align')

        plot_fit(self.signals[output],self.references[output],r_fitted,
                figtitle,(15,13),save,'./figures/2_fit/',image_format)
        plot_aligned(self.signals[output],r_fitted,self.time,self.events,self.measurements,
                    figtitle,figsize,save,'./figures/3_align/',image_format)
    
      self.references[output] = r_fitted




  def calculateDFF(self,standardized=True,
                   plot=False,figsize=(24, 13),save=False,image_format='.pdf'):
    
    self.dFFs = {}

    for output in self.outputs:
 
      self.dFFs[output] = calculate_dff(self.signals[output],self.references[output])

      if plot:
        plt.close('all')

        figtitle = self.mouse + ' ' + output + ' ' + self.test + self.trial

        if save:
          create_folder('figures')
          create_folder('figures/4_dFF')
        
        plot_dff(self.dFFs[output],self.time,self.events,self.measurements,
                figtitle,figsize,save,'./figures/4_dFF/',image_format)



  def interpolateDFF(self,period=0.1):

    time_ = self.time

    for output in self.outputs:

      signal = self.dFFs[output]

      i_nans = np.argwhere(np.isnan(signal))
      if i_nans.size != 0:
        i0 = np.max(i_nans) + 1
        t_nans = np.arange(0,time_[i0],period)
        t0 = np.max(t_nans) + period
        t1 = np.max(time_)
        t_new = np.arange(t0,t1,period)
        intrp_signal = interpolate_signal(signal[i0:],time_[i0:],t_new)
        nans = np.empty((len(t_nans),))
        nans[:] = np.nan
        new_signal = np.r_[nans,intrp_signal]
        new_time = np.r_[t_nans,t_new]
      else:
        t_new = np.arrange(0,time[-1],period)
        signal = interpolate_signal(signal,time_,t_new)

      self.dFFs[output] = new_signal
      self.timeDFF = new_time


  def smoothMeasurements(self,smooth_filter='low-pass',smooth_parameter=1):

    for measure in self.measurements:
      if not isbinary(self.measurements[measure]['values']):

        m = self.measurements[measure]['values']
        t = self.measurements[measure]['time']

        T = find_avg_period(t)

        i_nans = np.argwhere(np.isnan(m))
        if i_nans.size != 0:
          i0 = np.max(i_nans) + 1
          m = m[i0:]
          nans = m[:i0]

      
        if smooth_filter=='moving average':
          m = smooth_signal( m, window_len=int(round(smooth_parameter/T)) )

        elif smooth_filter=='low-pass':
          cutoff = smooth_parameter
          f = 1 / T
          m = butter_lowpass_filter(m, cutoff, f, order=10)


        if i_nans.size != 0:
          m = np.r_[nans,m]

        self.measurements[measure]['values'] = m




  def interpolateMeasurements(self,period=0.1):

    for measure in self.measurements:
      if not isbinary(self.measurements[measure]['values']):

        signal = self.measurements[measure]['values']
        time_ = self.measurements[measure]['time']

        if time_[0] < 0:
          i0 = np.max(np.argwhere(time_<0))
          signal = signal[i0:]
          time_ = time_[i0:]

        i_nans = np.argwhere(np.isnan(signal))
        if i_nans.size != 0:
          i0 = np.max(i_nans) + 1
          t_nans = np.arange(0,time_[i0],period)
          t0 = np.max(t_nans) + period
          t1 = np.max(time_)
          t_new = np.arange(t0,t1,period)
          intrp_signal = interpolate_signal(signal[i0:],time_[i0:],t_new)
          nans = np.empty((len(t_nans),))
          nans[:] = np.nan
          new_signal = np.r_[nans,intrp_signal]
          new_time = np.r_[t_nans,t_new]
        else:
          if time_[0] > 0:
            t_nans = np.arange(0,time_[0]+period,period)
            t0 = np.max(t_nans) + period
            t1 = np.max(time_)
            t_new = np.arange(t0,t1,period)
            intrp_signal = interpolate_signal(signal,time_,t_new)
            nans = np.empty((len(t_nans),))
            nans[:] = np.nan
            new_signal = np.r_[nans,intrp_signal]
            new_time = np.r_[t_nans,t_new]
          else:
            new_time = np.arange(0,time_[-1],period)
            new_signal = interpolate_signal(signal,time_,new_time)

        self.measurements[measure]['values'] = new_signal
        self.measurements[measure]['time'] = new_time




  def plotExample(self,outputs,event=None,measure=None,t0=0,t1=90,**kwargs):

    if self.timeDFF is not None:
      time_ = self.timeDFF
    else:
      time_ = self.time

    i0 = find_idx(t0,time_)
    i1 = find_idx(t1,time_)

    time_ = time_[i0:i1] - t0

    dFF1 = None
    dFF2 = None
    for i,output in enumerate(outputs):
      if i==0:
        dFF = self.dFFs[output][i0:i1]
      elif i==1:
        dFF1 = self.dFFs[output][i0:i1]
      elif i==2:
        dFF2 = self.dFFs[output][i0:i1]

    events = None
    if event is not None:
      events = self.events[event]
      events = events[np.all(events > t0, axis=1)]
      events = events[np.all(events < t1, axis=1)]
      events = events - t0

    measurement = None
    time_m = None
    if measure is not None:
      measurement = self.measurements[measure]['values']
      time_m = self.measurements[measure]['time']

      j0 = find_idx(t0,time_m)
      j1 = find_idx(t1,time_m)

      measurement = measurement[j0:j1]
      time_m = time_m[j0:j1]

  
    plot_example(dFF,time_,events,dFF1,dFF2,measurement,time_m,**kwargs)


# Test Class

In [None]:
class FiberPhotometryTest:
  def __init__(self,filename,test):
    '''
    Open the experiment HDF file and load the experiment information for one test.

    Parameters
    ----------
    filename : str
        Path to the HDF5 file holding the experiment.
    test : str
        Name of the test; its raw data are expected under ``'Raw/'+test``.

    The file attributes ``'mice'``, ``'outputs'`` and ``'good recordings'`` are
    read into ``self.mice``, ``self.outputs`` and ``self.goodRecordings``. Only
    rows of ``'good recordings'`` that contain ``test`` are kept, as lists.
    If the test group or an attribute is missing, a message is printed and the
    attributes that could not be loaded keep their empty-list defaults. An
    empty ``goodRecordings`` makes the analysis methods of this class use
    every recording.
    '''
    self.filename = filename
    self.test = test
    # Defaults so the other methods do not fail with AttributeError when loading fails.
    self.mice = []
    self.outputs = []
    self.goodRecordings = []

    with h5py.File(filename,'r') as f:

      if 'Raw/'+test not in f:
        print('Data for test {} is not saved in the file.'.format(test))
        return

      try:
        self.mice = list(f.attrs['mice'])
        self.outputs = list(f.attrs['outputs'])
        good_recordings = f.attrs['good recordings']
        if good_recordings.size != 0:
          self.goodRecordings = [list(line) for line in list(good_recordings) if test in line]
        print('Experiment information for test {} is successfully loaded.'.format(test))
      except Exception:
        # Missing/malformed attributes: keep the defaults and tell the user how to fix it.
        print('Set names of mice, outputs/pathways recorded and good recordings\nas attributes of HDF file, and create the object again')



  def removeExperiment(self,remove='processed'):

    '''
    remove: 'all', 'processed
            ['Raw','Events','Measurements','Recordings','DFFs','Perievents','MeasurePerievents',
             'Means','MeasureMeans','MeasureCorrelation','OutputCorrelation',
             'MeasurePerieventCorrelation','OutputPerieventCorrelation']
    '''

    if remove=='all':
      remove = ['Raw','Events','Measurements','Recordings','DFFs','Perievents','MeasurePerievents',
                'Means','MeasureMeans','MeasureCorrelation','OutputCorrelation',
                'MeasurePerieventCorrelation','OutputPerieventCorrelation']

    if remove=='processed':
      remove = ['Recordings','DFFs','Perievents','MeasurePerievents',
                'Means','MeasureMeans','MeasureCorrelation','OutputCorrelation',
                'MeasurePerieventCorrelation','OutputPerieventCorrelation']


    with h5py.File(self.filename, 'a') as f:

      if 'Raw' in remove:
        if 'Raw/'+self.test in f:
          del f['Raw/'+self.test]

      if 'Events' in remove:
        if 'Events/'+self.test in f:  
          del f['Events/'+self.test]

      if 'Measurements' in remove:    
        if 'Measurements/'+self.test in f:  
          del f['Measurements/'+self.test]

      if 'Recordings' in remove:
        if 'Recordings/'+self.test in f:
          del f['Recordings/'+self.test]

      if 'DFFs' in remove:
        if 'DFFs/'+self.test in f:
          del f['DFFs/'+self.test]

      if 'Perievents' in remove:
        if 'Perievents/'+self.test in f:
          del f['Perievents/'+self.test]

      if 'MeasurePerievents' in remove:
        if 'MeasurePerievents/'+self.test in f:
          del f['MeasurePerievents/'+self.test]

      if 'Means' in remove:
        if 'Means/'+self.test in f:
          del f['Means/'+self.test]

      if 'MeasureMeans' in remove:
        if 'MeasureMeans/'+self.test in f:
          del f['MeasureMeans/'+self.test]

      if 'MeasureCorrelation' in remove:
        if 'MeasureCorrelation/'+self.test in f:
          del f['MeasureCorrelation/'+self.test]

      if 'OutputCorrelation' in remove:
        if 'OutputCorrelation/'+self.test in f:
          del f['OutputCorrelation/'+self.test]

      if 'MeasurePerieventCorrelation' in remove:
        if 'MeasurePerieventCorrelation/events/'+self.test in f:
          del f['MeasurePerieventCorrelation/events/'+self.test]
          del f['MeasurePerieventCorrelation/correlation/'+self.test]
          del f['MeasurePerieventCorrelation/counts/'+self.test]
        
      if 'OutputPerieventCorrelation' in remove:
        if 'OutputPerieventCorrelation/events/'+self.test in f:
          del f['OutputPerieventCorrelation/events/'+self.test]
          del f['OutputPerieventCorrelation/correlation/'+self.test]
          del f['OutputPerieventCorrelation/counts/'+self.test]
        
    print('The data is successfully removed from the file {}.'.format(self.filename))

    return

  
      


  def getMeans(self,period=None,perievent_windows=None,auc_frames=None):
    '''
    Average perievent traces per mouse and save the means and their AUCs to the HDF file.

    For every output/event/onset-offset group under ``'Perievents/'+test``, all
    perievent traces of a mouse (over its trials) are averaged into one mean
    trace. The per-mouse means, their AUCs (``calculate_auc``), the mouse names
    and the mice's average sampling periods are saved under
    ``'Means/'+test+'/'+output+'/'+event+'/'+onoffset+'/'``. If
    ``self.goodRecordings`` is not empty, only trials listed there as
    ``[mouse, test, trial, output]`` are used.
    The same is done for ``'MeasurePerievents/'+test`` (if present), saved
    under ``'MeasureMeans/'``; measurement trials are not filtered.

    Parameters
    ----------
    period : any
        Unused; kept for backward compatibility.
    perievent_windows : dict, optional
        Maps ``event+'-'+onoffset`` to the perievent window; ``[-5, 5]`` when
        the mapping or key is missing.
    auc_frames : dict, optional
        Maps ``event+'-'+onoffset`` to the AUC time frames;
        ``[[-2, -1], [1, 2]]`` when the mapping or key is missing.
    '''
    test = self.test

    with h5py.File(self.filename, 'a') as f:

      if 'Perievents/'+test not in f:
        print('Perievents were not created.')
        return

      for output in f['Perievents/'+test]:
        for event in f['Perievents/'+test+'/'+output]:
          for onoffset in f['Perievents/'+test+'/'+output+'/'+event]:
            group = 'Perievents/'+test+'/'+output+'/'+event+'/'+onoffset
            means, mice, mice_periods = self._collectMeanTraces(
              f, group, 'DFFs/'+test+'/'+output, output)

            if means.size != 0:
              self._saveMeanTraces(f, 'Means/'+test+'/'+output+'/'+event+'/'+onoffset+'/',
                                   means, mice, mice_periods, event+'-'+onoffset,
                                   perievent_windows, auc_frames)
              print('.')
              print('Saved dF/F mean traces for {} {}-{}'.format(output,event,onoffset))
            else:
              print('Empty dF/F mean traces for {} {}-{}.'.format(output,event,onoffset))

      if 'MeasurePerievents/'+test in f:

        for measure in f['MeasurePerievents/'+test]:
          for event in f['MeasurePerievents/'+test+'/'+measure]:
            for onoffset in f['MeasurePerievents/'+test+'/'+measure+'/'+event]:
              group = 'MeasurePerievents/'+test+'/'+measure+'/'+event+'/'+onoffset
              means, mice, mice_periods = self._collectMeanTraces(
                f, group, 'Measurements/'+test+'/'+measure)

              if means.size != 0:
                self._saveMeanTraces(f, 'MeasureMeans/'+test+'/'+measure+'/'+event+'/'+onoffset+'/',
                                     means, mice, mice_periods, event+'-'+onoffset,
                                     perievent_windows, auc_frames)
                print('.')
                print('Saved measure mean traces for {} {}-{}.'.format(measure,event,onoffset))
              else:
                print('Empty measure mean trace for {} {}-{}.'.format(measure,event,onoffset))

  def _collectMeanTraces(self,f,group,time_root,output=None):
    '''
    Return per-mouse mean perievent traces stored under ``group`` of the open file ``f``.

    ``group`` holds ``mouse/trial`` datasets of perievent traces; the matching
    time vectors are read from ``time_root+'/'+mouse+'/'+trial+'/time'``. When
    ``output`` is given and ``self.goodRecordings`` is not empty, trials not
    listed as ``[mouse, test, trial, output]`` are skipped.

    Returns ``(means, mice, mice_periods)``: an array with one mean trace per
    row (size 0 when nothing was collected), the names of the mice kept, and
    each mouse's average sampling period.
    '''
    means = []
    mice = []
    mice_periods = []
    for mouse in f[group]:

      mouse_perievents = []
      mouse_periods = []
      for trial in f[group+'/'+mouse]:
        if (output is not None and self.goodRecordings != []
            and [mouse, self.test, trial, output] not in self.goodRecordings):
          continue
        print(mouse+','+trial, end=' ')
        time_ = np.array(f[time_root+'/'+mouse+'/'+trial+'/time'])
        mouse_periods.append(find_avg_period(time_))
        mouse_perievents.extend(list(f[group+'/'+mouse+'/'+trial]))

      if mouse_perievents:
        traces = np.array(mouse_perievents).squeeze()
        # A single trace squeezes to 1-D; keep it as one row.
        if len(traces.shape) == 1:
          traces = traces.reshape(1,len(traces))
        means.append(np.mean(traces,axis=0))
        mice.append(mouse)
        mice_periods.append(np.mean(mouse_periods))

    means = np.array(means).squeeze()
    # A single mouse squeezes to 1-D; keep it as one row.
    if len(means.shape) == 1:
      means = means.reshape(1,len(means))
    return means, mice, mice_periods

  @staticmethod
  def _saveMeanTraces(f,path,means,mice,mice_periods,key,perievent_windows,auc_frames):
    '''Compute AUCs of ``means`` and save means, AUCs, mice and periods under ``path`` of ``f``.'''
    def setting(mapping,default):
      # A missing mapping or key falls back to the default window/frames.
      if mapping is None:
        return default
      try:
        return mapping[key]
      except (KeyError,IndexError,TypeError):
        return default

    window = setting(perievent_windows,[-5,5])
    time_frames = setting(auc_frames,[[-2,-1],[1,2]])
    auc = calculate_auc(means,np.mean(mice_periods),window,time_frames)

    saveToHDF(f,path+'means',means)
    saveToHDF(f,path+'auc',auc)
    saveToHDF(f,path+'mice',np.array(mice,dtype=h5py.string_dtype(encoding='utf-8')))
    saveToHDF(f,path+'periods',mice_periods)



  def plotMeans(self,output,event,onoffset='onset',output2=None,measure=None,**kwargs):
    '''
    Plot saved per-mouse mean traces and AUCs for one event, optionally with a second series.

    Parameters
    ----------
    output : str
        Output whose ``'Means'`` are plotted, or 'all' to pool the means of
        every output of the test.
    event : str
        Event name.
    onoffset : str
        Sub-group of the event (default 'onset').
    output2 : str, optional
        Second output plotted as ``array1``.
    measure : str, optional
        Measurement whose ``'MeasureMeans'`` are plotted as ``array1``; only
        rows of mice that are also present for ``output`` are kept.
    **kwargs
        Forwarded to ``plot_means``.

    ``output2`` and ``measure`` are mutually exclusive; giving both prints a
    message and plots nothing.
    '''
    if output2 is not None and measure is not None:
      print('Choose to plot 2 different outups/paths or 1 output/path and 1 measure.')
      return

    means_root = 'Means/'+self.test
    suffix = '/'+event+'/'+onoffset+'/'

    # Second series stays empty unless output2 or measure is requested.
    means1, auc1, T1 = None, None, None

    with h5py.File(self.filename, 'r') as f:

      if output != 'all':
        base = means_root+'/'+output+suffix
        means = np.array(f[base+'means'])
        auc = np.array(f[base+'auc'])
        mice = list(f[base+'mice'])
        T = np.mean(f[base+'periods'])
      else:
        # Pool the rows of every output into one set of arrays.
        pooled = {'means': [], 'auc': [], 'mice': [], 'periods': []}
        for name in f[means_root]:
          base = means_root+'/'+name+suffix
          for key, collected in pooled.items():
            collected.extend(list(f[base+key]))
        means = np.array(pooled['means']).squeeze()
        auc = np.array(pooled['auc']).squeeze()
        mice = np.array(pooled['mice']).squeeze()
        T = np.mean(pooled['periods'])

      if output2 is not None:
        base = means_root+'/'+output2+suffix
        means1 = np.array(f[base+'means'])
        auc1 = np.array(f[base+'auc'])
        T1 = np.mean(f[base+'periods'])
      elif measure is not None:
        base = 'MeasureMeans/'+self.test+'/'+measure+suffix
        means1 = np.array(f[base+'means'])
        auc1 = np.array(f[base+'auc'])
        mice1 = list(f[base+'mice'])
        T1 = np.mean(f[base+'periods'])
        # Keep only measure rows of mice that also have dF/F means.
        keep = [i for i, mouse in enumerate(mice1) if mouse in mice]
        means1 = means1[keep,:]
        auc1 = auc1[keep,:]

    plot_means(array=means,T=T,auc=auc,array1=means1,T1=T1,auc1=auc1,**kwargs)





  def getDataFrameAUC(self,event,onoffset,periods=('baseline','event'),
                      save=False,csvname='auc.csv'):
    '''
    Collect the saved per-mouse dF/F AUC values of one event into a long-format DataFrame.

    Parameters
    ----------
    event : str
        Event name.
    onoffset : str
        Sub-group of the event (e.g. 'onset').
    periods : sequence of str
        Labels of the AUC columns; column ``i`` of each saved AUC array is
        labelled ``periods[i]``.
    save : bool
        If True, the DataFrame is also written to ``csvname`` (without index).
    csvname : str
        Output CSV path.

    Returns
    -------
    pandas.DataFrame
        Columns 'mouse', 'output', 'period', 'auc'; one row per output, period
        and mouse, read from ``'Means/'+test+'/'+output+'/'+event+'/'+onoffset``.
    '''
    rows = {'mouse': [], 'output': [], 'period': [], 'auc': []}

    with h5py.File(self.filename,'r') as f:
      for output in f['Means/'+self.test]:
        path = 'Means/'+self.test+'/'+output+'/'+event+'/'+onoffset+'/'
        mice = list(f[path+'mice'])
        auc = np.array(f[path+'auc'])
        n = len(mice)
        for i,period in enumerate(periods):
          rows['mouse'].extend(mice)
          rows['output'].extend([output]*n)
          rows['period'].extend([period]*n)
          rows['auc'].extend(auc[:,i])

    # Built after the HDF file is closed; nothing below needs it.
    df = pd.DataFrame(rows)
    if save:
      df.to_csv(csvname,index=False)
    return df


  def getDataFrameAUCmeasure(self,event,onoffset,periods=('baseline','event'),
                            save=False,csvname='auc.csv'):
    '''
    Collect the saved per-mouse measurement AUC values of one event into a long-format DataFrame.

    Parameters
    ----------
    event : str
        Event name.
    onoffset : str
        Sub-group of the event (e.g. 'onset').
    periods : sequence of str
        Labels of the AUC columns; column ``i`` of each saved AUC array is
        labelled ``periods[i]``.
    save : bool
        If True, the DataFrame is also written to ``csvname`` (without index).
    csvname : str
        Output CSV path.

    Returns
    -------
    pandas.DataFrame
        Columns 'mouse', 'measure', 'period', 'auc'; one row per measure, period
        and mouse, read from ``'MeasureMeans/'+test+'/'+measure+'/'+event+'/'+onoffset``.
    '''
    rows = {'mouse': [], 'measure': [], 'period': [], 'auc': []}

    with h5py.File(self.filename,'r') as f:
      for measure in f['MeasureMeans/'+self.test]:
        path = 'MeasureMeans/'+self.test+'/'+measure+'/'+event+'/'+onoffset+'/'
        mice = list(f[path+'mice'])
        auc = np.array(f[path+'auc'])
        n = len(mice)
        for i,period in enumerate(periods):
          rows['mouse'].extend(mice)
          rows['measure'].extend([measure]*n)
          rows['period'].extend([period]*n)
          rows['auc'].extend(auc[:,i])

    # Built after the HDF file is closed; nothing below needs it.
    df = pd.DataFrame(rows)
    if save:
      df.to_csv(csvname,index=False)
    return df




  def getOutputCorrelation(self,output,output1):
    '''
    Compute per-mouse Pearson correlations between the dF/F traces of two outputs.

    For every trial of ``output`` that also exists for ``output1`` (or, when
    ``self.goodRecordings`` is not empty, that is listed there for both
    outputs), Pearson R and p value between the two dF/F traces are computed
    from just after the last NaN of the ``output`` trace to the end. Trial
    values are averaged per mouse and saved under
    ``'OutputCorrelation/'+test+'/'+output+'_'+output1+'/'`` as ``R``,
    ``pvalue`` and ``mice``.

    Parameters
    ----------
    output, output1 : str
        Names of the two outputs (groups under ``'DFFs/'+test``).
    '''
    from scipy.stats import pearsonr

    test = self.test

    def valid_start(trace):
      # Index just after the last NaN (0 when the trace has no NaN).
      nan_idx = np.flatnonzero(np.isnan(trace))
      return nan_idx[-1] + 1 if nan_idx.size else 0

    with h5py.File(self.filename, 'a') as f:

      if 'DFFs/'+test+'/'+output not in f:
        print('dF/F traces for output {} are not saved in the file.'.format(output))
        return

      Rs = []
      ps = []
      mice = []

      # Iterate the requested output only (the parameter used to be shadowed
      # by a loop over every output, including output1 itself).
      for mouse in f['DFFs/'+test+'/'+output]:

        mouse_Rs = []
        mouse_ps = []

        for trial in f['DFFs/'+test+'/'+output+'/'+mouse]:

          path1 = 'DFFs/'+test+'/'+output1+'/'+mouse+'/'+trial+'/dFF'
          if self.goodRecordings != []:
            if not ([mouse,test,trial,output] in self.goodRecordings and
                    [mouse,test,trial,output1] in self.goodRecordings):
              continue
          elif path1 not in f:
            continue

          print(mouse+','+trial, end=' ')

          trace = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/dFF'])
          trace1 = np.array(f[path1])

          i0 = valid_start(trace)
          trial_R,trial_p = pearsonr(trace[i0:],trace1[i0:])

          mouse_Rs.append(trial_R)
          mouse_ps.append(trial_p)

        if mouse_Rs:
          mice.append(mouse)
          Rs.append(np.mean(mouse_Rs))
          ps.append(np.mean(mouse_ps))

      path = 'OutputCorrelation/'+test+'/'+output+'_'+output1+'/'
      saveToHDF(f,path+'R',Rs)
      saveToHDF(f,path+'pvalue',ps)
      saveToHDF(f,path+'mice',np.array(mice,dtype=h5py.string_dtype(encoding='utf-8')))

      print('.')
      print('Saved pearson correlation R and p values between outputs {} and {}.'.format(output,output1 ))




  def getMeasureCorrelation(self,measure,new_period=0.1):
    '''
    Compute per-mouse Pearson correlations between each output's dF/F and a measurement.

    For every output under ``'DFFs/'+test`` and every trial (restricted to
    ``self.goodRecordings`` when it is not empty), the dF/F trace and
    ``'Measurements/'+test+'/'+measure+'/'+mouse+'/'+trial+'/values'`` are
    correlated sample by sample. The correlation runs from just after the last
    NaN of the dF/F trace to the end of the shorter of the two vectors.
    Trial values are averaged per mouse and saved under
    ``'MeasureCorrelation/'+test+'/'+measure+'/'+output+'/'`` as ``R``,
    ``pvalue`` and ``mice``.

    Samples are paired by index, so the two recordings must share their time
    base. If their first five timestamps differ, 'Interpolate signals.' is
    printed and the method stops; results of outputs already processed stay
    saved.

    Parameters
    ----------
    measure : str
        Name of the measurement.
    new_period : float
        Unused; kept for backward compatibility.
    '''
    from scipy.stats import pearsonr

    test = self.test

    def valid_start(trace):
      # Index just after the last NaN (0 when the trace has no NaN).
      nan_idx = np.flatnonzero(np.isnan(trace))
      return nan_idx[-1] + 1 if nan_idx.size else 0

    with h5py.File(self.filename, 'a') as f:

      for output in f['DFFs/'+test]:

        Rs = []
        ps = []
        mice = []

        for mouse in f['DFFs/'+test+'/'+output]:

          mouse_Rs = []
          mouse_ps = []

          for trial in f['DFFs/'+test+'/'+output+'/'+mouse]:

            if self.goodRecordings != [] and [mouse, test, trial, output] not in self.goodRecordings:
              continue

            print(mouse+','+trial, end=' ')

            dff_path = 'DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/'
            measure_path = 'Measurements/'+test+'/'+measure+'/'+mouse+'/'+trial+'/'

            time_s = np.array(f[dff_path+'time'])
            trace = np.array(f[dff_path+'dFF'])
            time_m = np.array(f[measure_path+'time'])
            measurement = np.array(f[measure_path+'values'])

            if not np.array_equal(time_s[:5],time_m[:5]):
              print('Interpolate signals.')
              return

            i0 = valid_start(trace)
            # Exclusive slice end covering the whole overlap (len-1 used to drop the last sample).
            i1 = min(len(trace),len(measurement))

            trial_R,trial_p = pearsonr(trace[i0:i1],measurement[i0:i1])

            mouse_Rs.append(trial_R)
            mouse_ps.append(trial_p)

          if mouse_Rs:
            mice.append(mouse)
            Rs.append(np.mean(mouse_Rs))
            ps.append(np.mean(mouse_ps))

        path = 'MeasureCorrelation/'+test+'/'+measure+'/'+output+'/'
        saveToHDF(f,path+'R',Rs)
        saveToHDF(f,path+'pvalue',ps)
        saveToHDF(f,path+'mice',np.array(mice,dtype=h5py.string_dtype(encoding='utf-8')))

        print('.')
        print('Saved pearson correlation R and p values for output {} and measurement {}.'.format(output,measure))




  def getOutputPerieventCorrelation(self,output,output1,event,other_events='inbetween',
                                    win=1,min_duration=2,min_interval=2):
    
    test = self.test

    with h5py.File(self.filename, 'a') as f:
       
      if isinstance(event,list):

        for mouse in f['Events/'+test+'/'+event[0]]:
          for trial in f['Events/'+test+'/'+event[0]+'/'+mouse]:

            if 'DFFs/'+test+'/'+output+'/'+mouse+'/'+trial in f:
              dFF = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/dFF'])
              time_ = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/time'])
              period = find_avg_period(time_)

            # Find where not NaN part of dFF starts 
              i0 = np.max(np.argwhere(np.isnan(dFF))) + 1
              t0 = time_[i0]+2*win
              t1 = time_[-1]-2*win

              timestamps1 = np.array(f['Events/'+test+'/'+event[0]+'/'+mouse+'/'+trial+'/timestamps'])
              if len(timestamps1) != 0:
              # Remove events that fall to NaN part of dFF
                on_off_1 = timestamps1[np.all(timestamps1 > t0, axis=1)]
                on_off_1 = on_off_1[np.all(on_off_1 < t1, axis=1)]
              # Adjust timestamps to time vector
                inds1 = [ [find_idx(i,time_),find_idx(j,time_)] for i,j in on_off_1]
                on_off_1 = [ [time_[i],time_[j]] for i,j in inds1]
                on_off_1 = np.array(on_off_1)
              # Off event time
                off_on_1 = np.zeros(on_off_1.shape)
                off_on_1[0,0] = t0
                off_on_1[1:,0] = on_off_1[:-1,1]
                off_on_1[:,1] = on_off_1[:,0]
              # Onsets and offset
                on_1 = np.array(on_off_1[:,0])
                off_1 = np.array(on_off_1[:,1])
              # Adjust off time
                off_on_1 = adjust_intervals_durations(off_on_1,min_duration=2*win)
              else:
                on_1 = []
                off_1 = []
                off_on_1 = []
              

              timestamps2 = np.array(f['Events/'+test+'/'+event[1]+'/'+mouse+'/'+trial+'/timestamps'])
              if len(timestamps2) != 0:
              # Remove events that fall to NaN part of dFF
                on_off_2 = timestamps2[np.all(timestamps2 > t0, axis=1)]
                on_off_2 = on_off_2[np.all(on_off_2 < t1, axis=1)]
              # Adjust timestamps to time vector
                inds2 = [ [find_idx(i,time_),find_idx(j,time_)] for i,j in on_off_2]
                on_off_2 = [ [time_[i],time_[j]] for i,j in inds2]
                on_off_2 = np.array(on_off_2)
              # Off event time
                off_on_2 = np.zeros(on_off_2.shape)
                off_on_2[0,0] = t0
                off_on_2[1:,0] = on_off_2[:-1,1]
                off_on_2[:,1] = on_off_2[:,0]
              # Onsets and offsets
                on_2 = np.array(on_off_2[:,0])
                off_2 = np.array(on_off_2[:,1])
              # Adjust off time
                off_on_2 = adjust_intervals_durations(off_on_2,min_duration=2*win)
              else:
                on_2 = []
                off_2 = []
                off_on_2 = []
            
            # Combine two off event times
              if len(off_on_1) == 0:
                off_on = off_on_2
              elif len(off_on_2) == 0:
                off_on = off_on_1
              else:
                off_on = np.concatenate((off_on_1,off_on_2),axis=0)

            # Get random timestamps in off event time
              other = []
              for i,j in off_on:
                other.extend(np.arange(i+win,j-win+period,period))

              # Choose subset of 50 events
              other = np.array(random_subset(other, 50))

              path = 'OutputPerieventCorrelation/events/'+test+'/'+event[0]+'-'+event[1]+'/'
              saveToHDF(f,path+event[0]+'-onset/'+mouse+'/'+trial,on_1)
              saveToHDF(f,path+event[0]+'-offset/'+mouse+'/'+trial,off_1)
              saveToHDF(f,path+event[1]+'-onset/'+mouse+'/'+trial,on_2)
              saveToHDF(f,path+event[1]+'-offset/'+mouse+'/'+trial,off_2)
              saveToHDF(f,path+'other/'+mouse+'/'+trial,other)

          event = event[0]+'-'+event[1]

      else:

       # Choose events
        for mouse in f['Events/'+test+'/'+event]:
          for trial in f['Events/'+test+'/'+event+'/'+mouse]:

            if 'DFFs/'+test+'/'+output+'/'+mouse+'/'+trial in f:
              dFF = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/dFF'])
              time_ = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/time'])
              period = find_avg_period(time_)
                
              timestamps = np.array(f['Events/'+test+'/'+event+'/'+mouse+'/'+trial+'/timestamps'])

              if len(timestamps) == 0:
                continue

              if timestamps.shape[1]==1:

              # Find where not NaN part of dFF starts 
                i0 = np.max(np.argwhere(np.isnan(dFF))) + 1
                t0 = time_[i0]+win
                t1 = time_[-1]-win

              # Remove events that fall to NaN part of dFF
                on = timestamps[np.all(timestamps > t0, axis=1)]
                on = on[np.all(on < t1, axis=1)]

              # Adjust timestamps to time vector
                inds = [ find_idx(i,time_) for i in on]
                on = [ time_[i] for i in inds]
                on = np.array(on)

                inbetween = [t0] + list(on) + [t1]
                other = []
                for i in range(len(inbetween)-1):
                  other.extend(np.arange(inbetween[i]+win,inbetween[i+1]-win+period,period))

              # Choose subset of 50 events
                other = np.array(random_subset(other, 50))

                path = 'OutputPerieventCorrelation/events/'+test+'/'+event+'/'
                saveToHDF(f,path+event+'/'+mouse+'/'+trial,on)
                saveToHDF(f,path+'other/'+mouse+'/'+trial,other)

              else:
                if other_events == 'other':

                # Find where not NaN part of dFF starts 
                  i0 = np.max(np.argwhere(np.isnan(dFF))) + 1
                  t0 = time_[i0]+2*win
                  t1 = time_[-1]-2*win

                # Remove events that fall to NaN part of dFF
                  on_off = timestamps[np.all(timestamps > t0, axis=1)]
                  on_off = on_off[np.all(on_off < t1, axis=1)]

                # Adjust timestamps to time vector
                  inds = [ [find_idx(i,time_),find_idx(j,time_)] for i,j in on_off]
                  on_off = [ [time_[i],time_[j]] for i,j in inds]
                  on_off = np.array(on_off)

                  off_on = np.zeros(on_off.shape)
                  off_on[0,0] = t0
                  off_on[1:,0] = on_off[:-1,1]
                  off_on[:,1] = on_off[:,0]

                  on = np.array(on_off[:,0])
                  off = np.array(on_off[:,1])

                  off_on = adjust_intervals_durations(off_on,min_duration=2*win)

                  other = []
                  for i,j in off_on:
                    other.extend(np.arange(i+win,j-win+period,period))

                # Choose subset of 50 events
                  other = np.array(random_subset(other, 50))

                  path = 'OutputPerieventCorrelation/events/'+test+'/'+event+'/'
                  saveToHDF(f,path+'onset/'+mouse+'/'+trial,on)
                  saveToHDF(f,path+'offset/'+mouse+'/'+trial,off)
                  saveToHDF(f,path+'other/'+mouse+'/'+trial,other)

                elif other_events == 'inbetween':

                # Find where not NaN part of dFF starts 
                  i0 = np.max(np.argwhere(np.isnan(dFF))) + 1
                  t0 = time_[i0]+2*win
                  t1 = time_[-1]-2*win

                # Remove events that fall to NaN part of dFF
                  on_off = timestamps[np.all(timestamps > t0, axis=1)]
                  on_off = on_off[np.all(on_off < t1,  axis=1)]

                # Adjust timestamps to time vector
                  inds = [ [find_idx(i,time_),find_idx(j,time_)] for i,j in on_off]
                  on_off = [ [time_[i],time_[j]] for i,j in inds]
                  on_off = np.array(on_off)

                  off_on = np.zeros(on_off.shape)
                  off_on[0,0] = t0
                  off_on[1:,0] = on_off[:-1,1]
                  off_on[:,1] = on_off[:,0]

                  onoffsets = adjust_intervals_durations(on_off,min_duration,min_interval)
                  on = np.array(onoffsets[:,0])
                  off = np.array(onoffsets[:,1])

                  on_off = adjust_intervals_durations(on_off,min_duration=2*win)
                  off_on = adjust_intervals_durations(off_on,min_duration=2*win)

                  dur_on_off = []
                  for i,j in on_off:
                    dur_on_off.extend(np.arange(i+win,j-win+period,period))
                  dur_off_on = []
                  for i,j in off_on:
                    dur_off_on.extend(np.arange(i+win,j-win+period,period))

                # Choose subset of 50 events
                  dur_on_off = np.array(random_subset(dur_on_off, 50))
                  dur_off_on = np.array(random_subset(dur_off_on, 50))

                  path = 'OutputPerieventCorrelation/events/'+test+'/'+event+'/'
                  saveToHDF(f,path+'onset/'+mouse+'/'+trial,on)
                  saveToHDF(f,path+'on-off/'+mouse+'/'+trial,dur_on_off)
                  saveToHDF(f,path+'offset/'+mouse+'/'+trial,off)
                  saveToHDF(f,path+'off-on/'+mouse+'/'+trial,dur_off_on)

     # Calculate perievent correlation
      for mouse in f['DFFs/'+test+'/'+output]:

        for e_part in f['OutputPerieventCorrelation/events/'+test+'/'+event]:

          lags = []
          corrs = []
          maxlags = []
          maxcorrs = []
          Rs = []
          ps = []

          for trial in f['DFFs/'+test+'/'+output+'/'+mouse]:

            if self.goodRecordings != []:
              if ([mouse,test,trial,output] in self.goodRecordings) and ([mouse,test,trial,output1] in self.goodRecordings):

                signal = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/dFF'])
                signal1 = np.array(f['DFFs/'+test+'/'+output1+'/'+mouse+'/'+trial+'/dFF'])
                time_s = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/time'])
                T_s = find_avg_period(time_s)

                timestamps = np.array(f['OutputPerieventCorrelation/events/'+test+'/'+event+'/'+e_part+'/'+mouse+'/'+trial])

                if len(timestamps) != 0:
                  for t in timestamps:

                    idx = find_idx(t,time_s)
                    w = int(round(win/T_s))

                    s = signal[idx-w:idx+w+1]
                    s1 = signal1[idx-w:idx+w+1]

                    lag, corr = xcorr(s,s1,True,True,maxlags=w)
                    idx_maxlag = np.argmax(np.abs(corr))
                    maxcorr = corr[idx_maxlag]
                    maxlag = lag[idx_maxlag]

                    s = signal[idx+maxlag-w:idx+maxlag+w+1]

                    R,p = pearsonr(s,s1)

                    lag = [i*T_s for i in lag]
                    maxlag = maxlag*T_s

                    lags.append(lag)
                    corrs.append(corr)
                    maxlags.append(maxlag)
                    maxcorrs.append(maxcorr)
                    Rs.append(R)
                    ps.append(p)

                path = 'OutputPerieventCorrelation/correlation/'+test+'/'+event+'/'+e_part+'/'+output+'_'+output1+'/'+mouse
                saveToHDF(f,path+'/lag',lags)
                saveToHDF(f,path+'/corr',corrs)
                #saveToHDF(f,path+'/max-lag',maxlags)
                #saveToHDF(f,path+'/max-corr',maxcorrs)
                saveToHDF(f,path+'/R',Rs)
                saveToHDF(f,path+'/pvalue',ps)

            else:
                
              signal = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/dFF'])
              signal1 = np.array(f['DFFs/'+test+'/'+output1+'/'+mouse+'/'+trial+'/dFF'])
              time_s = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/time'])
              T_s = find_avg_period(time_s)

              timestamps = np.array(f['OutputPerieventCorrelation/events/'+test+'/'+event+'/'+e_part+'/'+mouse+'/'+trial])

              if len(timestamps) != 0:
                for t in timestamps:

                  idx = find_idx(t,time_s)
                  w = int(round(win/T_s))

                  s = signal[idx-w:idx+w+1]
                  s1 = signal1[idx-w:idx+w+1]

                  lag, corr = xcorr(s,s1,True,True,maxlags=w)
                  idx_maxlag = np.argmax(np.abs(corr))
                  maxcorr = corr[idx_maxlag]
                  maxlag = lag[idx_maxlag]

                  s = signal[idx+maxlag-w:idx+maxlag+w+1]

                  R,p = pearsonr(s,s1)

                  lag = [i*T_s for i in lag]
                  maxlag = maxlag*T_s

                  lags.append(lag)
                  corrs.append(corr)
                  maxlags.append(maxlag)
                  maxcorrs.append(maxcorr)
                  Rs.append(R)
                  ps.append(p)

              path = 'OutputPerieventCorrelation/correlation/'+test+'/'+event+'/'+e_part+'/'+output+'_'+output1+'/'+mouse
              saveToHDF(f,path+'/lag',lags)
              saveToHDF(f,path+'/corr',corrs)
              #saveToHDF(f,path+'/max-lag',maxlags)
              #saveToHDF(f,path+'/max-corr',maxcorrs)
              saveToHDF(f,path+'/R',Rs)
              saveToHDF(f,path+'/pvalue',ps)


     # Count positive/negative/not correlations 
      for e_part in f['OutputPerieventCorrelation/correlation/'+test+'/'+event]:

          mice = []
          list_not_corr = []
          list_pos_corr = []
          list_neg_corr = []

          for mouse in f['OutputPerieventCorrelation/correlation/'+test+'/'+event+'/'+e_part+'/'+output+'_'+output1]:
            
            ps = np.array(f['OutputPerieventCorrelation/correlation/'+test+'/'+event+'/'+e_part+'/'+output+'_'+output1+'/'+mouse+'/pvalue'])
            Rs = np.array(f['OutputPerieventCorrelation/correlation/'+test+'/'+event+'/'+e_part+'/'+output+'_'+output1+'/'+mouse+'/R'])

            if len(ps) != 0:
              n = len(ps)                               # total number events
              corr = Rs[(ps<0.001) & (np.abs(Rs)>0.6)]  # list of correlated events
              not_corr = (n - len(corr)) / n            # % of not correlated event
              pos_corr = sum(corr>0) / n                # % of positive correlated
              neg_corr = sum(corr<0) / n                # % of negative correlated

              mice.append(mouse)
              list_not_corr.append(not_corr)
              list_pos_corr.append(pos_corr)
              list_neg_corr.append(neg_corr)

          path = 'OutputPerieventCorrelation/counts/'+test+'/'+event+'/'+e_part+'/'+output+'_'+output1
          saveToHDF(f,path+'/mice',np.array(mice,dtype=h5py.string_dtype(encoding='utf-8')))
          saveToHDF(f,path+'/not-corr',list_not_corr)
          saveToHDF(f,path+'/pos-corr',list_pos_corr)
          saveToHDF(f,path+'/neg-corr',list_neg_corr)




  def getMeasurePerieventCorrelation(self,event,measure,other_events='inbetween',win=1,
                                     min_duration=2,min_interval=2):
    
    test = self.test

    with h5py.File(self.filename, 'a') as f:
       
      if isinstance(event,list):

        for mouse in f['Events/'+test+'/'+event[0]]:
          for trial in f['Events/'+test+'/'+event[0]+'/'+mouse]:

            outputs = list(f['DFFs/'+test])
            for output in outputs:
              if 'DFFs/'+test+'/'+output+'/'+mouse in f:
                break
            dFF = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/dFF'])
            time_ = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/time'])
            time_m = np.array(f['Measurements/'+test+'/'+measure+'/'+mouse+'/'+trial+'/time'])
            period = find_avg_period(time_)

           # Find where not NaN part of dFF starts 
            i0 = np.max(np.argwhere(np.isnan(dFF))) + 1
            t0 = time_[i0]+win
            if time_m[-1] < time_[-1]:
              t1 = time_m[-1] - 2*win
            else:
              t1 = time_[-1] - 2*win

            timestamps1 = np.array(f['Events/'+test+'/'+event[0]+'/'+mouse+'/'+trial+'/timestamps'])
            if len(timestamps1) != 0:
             # Remove events that fall to NaN part of dFF
              on_off_1 = timestamps1[np.all(timestamps1 > t0, axis=1)]
              on_off_1 = on_off_1[np.all(on_off_1 < t1, axis=1)]
             # Adjust timestamps to time vector
              inds1 = [ [find_idx(i,time_),find_idx(j,time_)] for i,j in on_off_1]
              on_off_1 = [ [time_[i],time_[j]] for i,j in inds1]
              on_off_1 = np.array(on_off_1)
             # Off event time
              off_on_1 = np.zeros(on_off_1.shape)
              off_on_1[0,0] = t0
              off_on_1[1:,0] = on_off_1[:-1,1]
              off_on_1[:,1] = on_off_1[:,0]
             # Onsets and offset
              on_1 = np.array(on_off_1[:,0])
              off_1 = np.array(on_off_1[:,1])
             # Adjust off time
              off_on_1 = adjust_intervals_durations(off_on_1,min_duration=2*win)
            else:
              on_1 = []
              off_1 = []
              off_on_1 = []
            

            timestamps2 = np.array(f['Events/'+test+'/'+event[1]+'/'+mouse+'/'+trial+'/timestamps'])
            if len(timestamps2) != 0:
             # Remove events that fall to NaN part of dFF
              on_off_2 = timestamps2[np.all(timestamps2 > t0, axis=1)]
              on_off_2 = on_off_2[np.all(on_off_2 < t1, axis=1)]
             # Adjust timestamps to time vector
              inds2 = [ [find_idx(i,time_),find_idx(j,time_)] for i,j in on_off_2]
              on_off_2 = [ [time_[i],time_[j]] for i,j in inds2]
              on_off_2 = np.array(on_off_2)
             # Off event time
              off_on_2 = np.zeros(on_off_2.shape)
              off_on_2[0,0] = t0
              off_on_2[1:,0] = on_off_2[:-1,1]
              off_on_2[:,1] = on_off_2[:,0]
             # Onsets and offsets
              on_2 = np.array(on_off_2[:,0])
              off_2 = np.array(on_off_2[:,1])
             # Adjust off time
              off_on_2 = adjust_intervals_durations(off_on_2,min_duration=2*win)
            else:
              on_2 = []
              off_2 = []
              off_on_2 = []
           
           # Combine two off event times
            if len(off_on_1) == 0:
              off_on = off_on_2
            elif len(off_on_2) == 0:
              off_on = off_on_1
            else:
              off_on = np.concatenate((off_on_1,off_on_2),axis=0)

           # Get random timestamps in off event time
            other = []
            for i,j in off_on:
              other.extend(np.arange(i+win,j-win+period,period))

             # Choose subset of 50 events
            other = np.array(random_subset(other, 50))

            path = 'MeasurePerieventCorrelation/events/'+test+'/'+measure+'/'+event[0]+'-'+event[1]+'/'
            saveToHDF(f,path+event[0]+'-onset/'+mouse+'/'+trial,on_1)
            saveToHDF(f,path+event[0]+'-offset/'+mouse+'/'+trial,off_1)
            saveToHDF(f,path+event[1]+'-onset/'+mouse+'/'+trial,on_2)
            saveToHDF(f,path+event[1]+'-offset/'+mouse+'/'+trial,off_2)
            saveToHDF(f,path+'other/'+mouse+'/'+trial,other)

        event = event[0]+'-'+event[1]

      else:

       # Choose events
        for mouse in f['Events/'+test+'/'+event]:
          for trial in f['Events/'+test+'/'+event+'/'+mouse]:

            outputs = list(f['DFFs/'+test])
            for output in outputs:
              if 'DFFs/'+test+'/'+output+'/'+mouse in f:
                break
            dFF = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/dFF'])
            time_ = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/time'])
            time_m = np.array(f['Measurements/'+test+'/'+measure+'/'+mouse+'/'+trial+'/time'])
            period = find_avg_period(time_)
              
            timestamps = np.array(f['Events/'+test+'/'+event+'/'+mouse+'/'+trial+'/timestamps'])

            if timestamps.shape[1]==1:

            # Find where not NaN part of dFF starts 
              i0 = np.max(np.argwhere(np.isnan(dFF))) + 1
              t0 = time_[i0]+win
              if time_m[-1] < time_[-1]:
                t1 = time_m[-1]-2*win
              else:
                t1 = time_[-1]-2*win

            # Remove events that fall to NaN part of dFF
              on = timestamps[np.all(timestamps > t0, axis=1)]
              on = on[np.all(on < t1, axis=1)]

            # Adjust timestamps to time vector
              inds = [ find_idx(i,time_) for i in on]
              on = [ time_[i] for i in inds]
              on = np.array(on)

              other = []
              for i in on:
                other.extend(np.arange(i-win,i+win+period,period))

            # Choose subset of 50 events
              other = np.array(random_subset(other, 50))

              path = 'MeasurePerieventCorrelation/events/'+test+'/'+measure+'/'+event+'/'
              saveToHDF(f,path+event+'/'+mouse+'/'+trial,on)
              saveToHDF(f,path+'other/'+mouse+'/'+trial,other)

            else:
              if other_events == 'other':

              # Find where not NaN part of dFF starts 
                i0 = np.max(np.argwhere(np.isnan(dFF))) + 1
                t0 = time_[i0]+win
                if time_m[-1]<time_[-1]:
                  t1 = time_m[-1]-2*win
                else:
                  t1 = time_[-1]-2*win

              # Remove events that fall to NaN part of dFF
                on_off = timestamps[np.all(timestamps > t0, axis=1)]
                on_off = on_off[np.all(on_off < t1, axis=1)]

              # Adjust timestamps to time vector
                inds = [ [find_idx(i,time_),find_idx(j,time_)] for i,j in on_off]
                on_off = [ [time_[i],time_[j]] for i,j in inds]
                on_off = np.array(on_off)

                off_on = np.zeros(on_off.shape)
                off_on[0,0] = t0
                off_on[1:,0] = on_off[:-1,1]
                off_on[:,1] = on_off[:,0]

                on = np.array(on_off[:,0])
                off = np.array(on_off[:,1])

                off_on = adjust_intervals_durations(off_on,min_duration=2*win)

                other = []
                for i,j in off_on:
                  other.extend(np.arange(i+win,j-win+period,period))

              # Choose subset of 50 events
                other = np.array(random_subset(other, 50))

                path = 'MeasurePerieventCorrelation/events/'+test+'/'+measure+'/'+event+'/'
                saveToHDF(f,path+'onset/'+mouse+'/'+trial,on)
                saveToHDF(f,path+'offset/'+mouse+'/'+trial,off)
                saveToHDF(f,path+'other/'+mouse+'/'+trial,other)

              elif other_events == 'inbetween':

              # Find where not NaN part of dFF starts 
                i0 = np.max(np.argwhere(np.isnan(dFF))) + 1
                t0 = time_[i0]+2*win
                if time_m[-1]<time_[-1]:
                  t1 = time_m[-1]-2*win
                else:
                  t1 = time_[-1]-2*win

              # Remove events that fall to NaN part of dFF
                on_off = timestamps[np.all(timestamps > t0, axis=1)]
                on_off = on_off[np.all(on_off < t1, axis=1)]

              # Adjust timestamps to time vector
                inds = [ [find_idx(i,time_),find_idx(j,time_)] for i,j in on_off]
                on_off = [ [time_[i],time_[j]] for i,j in inds]
                on_off = np.array(on_off)

                off_on = np.zeros(on_off.shape)
                off_on[0,0] = t0
                off_on[1:,0] = on_off[:-1,1]
                off_on[:,1] = on_off[:,0]

                onoffsets = adjust_intervals_durations(on_off,min_duration,min_interval)
                on = np.array(onoffsets[:,0])
                off = np.array(onoffsets[:,1])

                on_off = adjust_intervals_durations(on_off,min_duration=2*win)
                off_on = adjust_intervals_durations(off_on,min_duration=2*win)

                dur_on_off = []
                for i,j in on_off:
                  dur_on_off.extend(np.arange(i+win,j-win+period,period))
                dur_off_on = []
                for i,j in off_on:
                  dur_off_on.extend(np.arange(i+win,j-win+period,period))

              # Choose subset of 50 events
                dur_on_off = np.array(random_subset(dur_on_off, 50))
                dur_off_on = np.array(random_subset(dur_off_on, 50))

                path = 'MeasurePerieventCorrelation/events/'+test+'/'+measure+'/'+event+'/'
                saveToHDF(f,path+'onset/'+mouse+'/'+trial,on)
                saveToHDF(f,path+'on-off/'+mouse+'/'+trial,dur_on_off)
                saveToHDF(f,path+'offset/'+mouse+'/'+trial,off)
                saveToHDF(f,path+'off-on/'+mouse+'/'+trial,dur_off_on)

     # Calculate perievent correlation
      for output in f['DFFs/'+test]:
        for mouse in f['DFFs/'+test+'/'+output]:

          for e_part in f['MeasurePerieventCorrelation/events/'+test+'/'+measure+'/'+event]:

            lags = []
            corrs = []
            maxlags = []
            maxcorrs = []
            Rs = []
            ps = []

            for trial in f['DFFs/'+test+'/'+output+'/'+mouse]:

              if self.goodRecordings != []:
                if [mouse, test, trial, output] in self.goodRecordings:

                  signal = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/dFF'])
                  time_s = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/time'])
                  measurement = np.array(f['Measurements/'+test+'/'+measure+'/'+mouse+'/'+trial+'/values'])
                  time_m = np.array(f['Measurements/'+test+'/'+measure+'/'+mouse+'/'+trial+'/time'])

                  T_s = find_avg_period(time_s)
                  T_m = find_avg_period(time_m)
                  if T_s != T_m:
                    return print('Interpolate signals.')

                  timestamps = np.array(f['MeasurePerieventCorrelation/events/'+test+'/'+measure+'/'+event+'/'+e_part+'/'+mouse+'/'+trial])

                  if len(timestamps) != 0:
                    for t in timestamps:

                      idx = find_idx(t,time_s)
                      w = int(round(win/T_s))

                      s = signal[idx-w:idx+w+1]
                      m = measurement[idx-w:idx+w+1]

                      lag, corr = xcorr(s,m,True,True,maxlags=w)
                      idx_maxlag = np.argmax(np.abs(corr))
                      maxcorr = corr[idx_maxlag]
                      maxlag = lag[idx_maxlag]

                      s = signal[idx+maxlag-w:idx+maxlag+w+1]

                      R,p = pearsonr(s, m)

                      lag = [i*T_s for i in lag]
                      maxlag = maxlag*T_s

                      lags.append(lag)
                      corrs.append(corr)
                      maxlags.append(maxlag)
                      maxcorrs.append(maxcorr)
                      Rs.append(R)
                      ps.append(p)

                  path = 'MeasurePerieventCorrelation/correlation/'+test+'/'+measure+'/'+event+'/'+e_part+'/'+output+'/'+mouse
                  saveToHDF(f,path+'/lag',lags)
                  saveToHDF(f,path+'/corr',corrs)
                  #saveToHDF(f,path+'/max-lag',maxlags)
                  #saveToHDF(f,path+'/max-corr',maxcorrs)
                  saveToHDF(f,path+'/R',Rs)
                  saveToHDF(f,path+'/pvalue',ps)

              else:
                
                signal = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/dFF'])
                time_s = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/time'])
                measurement = np.array(f['Measurements/'+test+'/'+measure+'/'+mouse+'/'+trial+'/values'])
                time_m = np.array(f['Measurements/'+test+'/'+measure+'/'+mouse+'/'+trial+'/time'])

                T_s = find_avg_period(time_s)
                T_m = find_avg_period(time_m)
                if T_s != T_m:
                  return print('Interpolate signals.')

                timestamps = np.array(f['MeasurePerieventCorrelation/events/'+test+'/'+measure+'/'+event+'/'+e_part+'/'+mouse+'/'+trial])

                if len(timestamps) != 0:
                  for t in timestamps:

                    idx = find_idx(t,time_s)
                    w = int(round(win/T_s))

                    s = signal[idx-w:idx+w+1]
                    m = measurement[idx-w:idx+w+1]

                    lag, corr = xcorr(s,m,True,True,maxlags=w)
                    idx_maxlag = np.argmax(np.abs(corr))
                    maxcorr = corr[idx_maxlag]
                    maxlag = lag[idx_maxlag]

                    s = signal[idx+maxlag-w:idx+maxlag+w+1]

                    R,p = pearsonr(s, m)

                    lag = [i*T_s for i in lag]
                    maxlag = maxlag*time_s
                    
                    lags.append(lag)
                    corrs.append(corr)
                    maxlags.append(maxlag)
                    maxcorrs.append(maxcorr)
                    Rs.append(R)
                    ps.append(p)

                path = 'MeasurePerieventCorrelation/correlation/'+test+'/'+measure+'/'+event+'/'+e_part+'/'+output+'/'+mouse
                saveToHDF(f,path+'/lag',lags)
                saveToHDF(f,path+'/corr',corrs)
                #saveToHDF(f,path+'/max-lag',maxlags)
                #saveToHDF(f,path+'/max-corr',maxcorrs)
                saveToHDF(f,path+'/R',Rs)
                saveToHDF(f,path+'/pvalue',ps)


     # Count positive/negative/not correlations 
      for e_part in f['MeasurePerieventCorrelation/correlation/'+test+'/'+measure+'/'+event]:
        for output in f['MeasurePerieventCorrelation/correlation/'+test+'/'+measure+'/'+event+'/'+e_part]:

          mice = []
          list_not_corr = []
          list_pos_corr = []
          list_neg_corr = []

          for mouse in f['MeasurePerieventCorrelation/correlation/'+test+'/'+measure+'/'+event+'/'+e_part+'/'+output]:
            
            ps = np.array(f['MeasurePerieventCorrelation/correlation/'+test+'/'+measure+'/'+event+'/'+e_part+'/'+output+'/'+mouse+'/pvalue'])
            Rs = np.array(f['MeasurePerieventCorrelation/correlation/'+test+'/'+measure+'/'+event+'/'+e_part+'/'+output+'/'+mouse+'/R'])

            if len(ps) != 0:
              n = len(ps)                               # total number events
              corr = Rs[(ps<0.001) & (np.abs(Rs)>0.6)]  # list of correlated events
              not_corr = (n - len(corr)) / n            # % of not correlated event
              pos_corr = sum(corr>0) / n                # % of positive correlated
              neg_corr = sum(corr<0) / n                # % of negative correlated

              mice.append(mouse)
              list_not_corr.append(not_corr)
              list_pos_corr.append(pos_corr)
              list_neg_corr.append(neg_corr)

          #print(output,e_part)  
          #print(np.array([mice,list_pos_corr,list_neg_corr,list_not_corr]).T)

          path = 'MeasurePerieventCorrelation/counts/'+test+'/'+measure+'/'+event+'/'+e_part+'/'+output
          saveToHDF(f,path+'/mice',np.array(mice,dtype=h5py.string_dtype(encoding='utf-8')))
          saveToHDF(f,path+'/not-corr',list_not_corr)
          saveToHDF(f,path+'/pos-corr',list_pos_corr)
          saveToHDF(f,path+'/neg-corr',list_neg_corr)




  def getMeasureCrossCorrelation(self,measure,event,e_part='onset',win=[-3,3]):

    test = self.test

    with h5py.File(self.filename, 'a') as f:
    
      for output in f['DFFs/'+test]:
          for mouse in f['DFFs/'+test+'/'+output]:

            lags = []
            corrs = []
            maxlags = []
            maxcorrs = []
            Rs = []
            ps = []

            for trial in f['DFFs/'+test+'/'+output+'/'+mouse]:

              if self.goodRecordings != []:
                if [mouse, test, trial, output] in self.goodRecordings:

                  signal = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/dFF'])
                  time_s = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/time'])
                  measurement = np.array(f['Measurements/'+test+'/'+measure+'/'+mouse+'/'+trial+'/values'])
                  time_m = np.array(f['Measurements/'+test+'/'+measure+'/'+mouse+'/'+trial+'/time'])

                  T_s = find_avg_period(time_s)
                  T_m = find_avg_period(time_m)
                  if T_s != T_m:
                    return print('Interpolate signals.')

                  timestamps = np.array(f['Events/'+test+'/'+event+'/'+mouse+'/'+trial+'/timestamps'])

                  if len(timestamps) != 0:

                    if e_part == 'onset':
                      timestamps = timestamps[:,0]
                    elif e_part == 'offset':
                      timestamps = timestamps[:,1]

                  # Find where not NaN part of dFF starts 
                    i0 = np.max(np.argwhere(np.isnan(signal))) + 1
                    t0 = time_s[i0] + (win[1]-win[0])
                    if time_m[-1]<time_s[-1]:
                      t1 = time_m[-1] - (win[1]-win[0])
                    else:
                      t1 = time_s[-1] - (win[1]-win[0])
                  # Remove events that fall to NaN part of dFF
                    timestamps = timestamps[timestamps > t0]
                    timestamps = timestamps[timestamps < t1]


                    for t in timestamps:

                      idx = find_idx(t,time_s)
                      w0 = int(round(win[0]/T_s))
                      w1 = int(round(win[1]/T_s))
                      w = int(round((win[1]-win[0])/T_s/2))

                      s = signal[idx+w0:idx+w1+1]
                      m = measurement[idx+w0:idx+w1+1]

                      lag, corr = xcorr(s,m,True,True,maxlags=w)
                      idx_maxlag = np.argmax(np.abs(corr))
                      maxcorr = corr[idx_maxlag]
                      maxlag = lag[idx_maxlag]

                      s = signal[idx+maxlag+w0:idx+maxlag+w1+1]
                      R,p = pearsonr(s, m)

                      lag = [i*T_s for i in lag]
                      maxlag = maxlag*T_s

                      lags.append(lag)
                      corrs.append(corr)
                      maxlags.append(maxlag)
                      maxcorrs.append(maxcorr)
                      Rs.append(R)
                      ps.append(p)

                    path = 'MeasurePerieventCorrelation/correlation/'+test+'/'+measure+'/'+event+'/'+e_part+'/'+output+'/'+mouse
                    saveToHDF(f,path+'/lag',lags)
                    saveToHDF(f,path+'/corr',corrs)
                    #saveToHDF(f,path+'/max-lag',maxlags)
                    #saveToHDF(f,path+'/max-corr',maxcorrs)
                    saveToHDF(f,path+'/R',Rs)
                    saveToHDF(f,path+'/pvalue',ps)

              else:
                
                signal = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/dFF'])
                time_s = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/time'])
                measurement = np.array(f['Measurements/'+test+'/'+measure+'/'+mouse+'/'+trial+'/values'])
                time_m = np.array(f['Measurements/'+test+'/'+measure+'/'+mouse+'/'+trial+'/time'])

                T_s = find_avg_period(time_s)
                T_m = find_avg_period(time_m)
                if T_s != T_m:
                  return print('Interpolate signals.')

                timestamps = np.array(f['Events/'+test+'/'+event+'/'+mouse+'/'+trial+'/timestamps'])

                if len(timestamps) != 0:

                  if e_part == 'onset':
                    timestamps = timestamps[:,0]
                  elif e_part == 'offset':
                    timestamps = timestamps[:,1]

                # Find where not NaN part of dFF starts 
                  i0 = np.max(np.argwhere(np.isnan(signal))) + 1
                  t0 = time_s[i0] + (win[1]-win[0])
                  if time_m[-1]<time_s[-1]:
                    t1 = time_m[-1] - (win[1]-win[0])
                  else:
                    t1 = time_s[-1] - (win[1])
                # Remove events that fall to NaN part of dFF
                  timestamps = timestamps[timestamps > t0]
                  timestamps = timestamps[timestamps < t1]


                  for t in timestamps:

                    idx = find_idx(t,time_s)
                    w0 = int(round(win[0]/T_s))
                    w1 = int(round(win[1]/T_s))
                    w = int(round((win[1]-win[0])/T_s/2))

                    s = signal[idx+w0:idx+w1+1]
                    m = measurement[idx+w0:idx+w1+1]

                    lag, corr = xcorr(s,m,True,True,maxlags=w)
                    idx_maxlag = np.argmax(np.abs(corr))
                    maxcorr = corr[idx_maxlag]
                    maxlag = lag[idx_maxlag]

                    s = signal[idx+maxlag+w0:idx+maxlag+w1+1]
                    R,p = pearsonr(s, m)

                    lag = [i*T_s for i in lag]
                    maxlag = maxlag*T_s
                    
                    lags.append(lag)
                    corrs.append(corr)
                    maxlags.append(maxlag)
                    maxcorrs.append(maxcorr)
                    Rs.append(R)
                    ps.append(p)

                  path = 'MeasurePerieventCorrelation/correlation/'+test+'/'+measure+'/'+event+'/'+e_part+'/'+output+'/'+mouse
                  saveToHDF(f,path+'/lag',lags)
                  saveToHDF(f,path+'/corr',corrs)
                  #saveToHDF(f,path+'/max-lag',maxlags)
                  #saveToHDF(f,path+'/max-corr',maxcorrs)
                  saveToHDF(f,path+'/R',Rs)
                  saveToHDF(f,path+'/pvalue',ps)




  def getOutputCrossCorrelation(self,output,output1,event,e_part='onset',win=[-3,3]):

    test = self.test

    with h5py.File(self.filename, 'a') as f:
    
      for mouse in f['DFFs/'+test+'/'+output]:

        lags = []
        corrs = []
        maxlags = []
        maxcorrs = []
        Rs = []
        ps = []

        for trial in f['DFFs/'+test+'/'+output+'/'+mouse]:

          if self.goodRecordings != []:
            if ([mouse,test,trial,output] in self.goodRecordings) and ([mouse,test,trial,output1] in self.goodRecordings):

              print(mouse,trial,end=' ')

              signal = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/dFF'])
              signal1 = np.array(f['DFFs/'+test+'/'+output1+'/'+mouse+'/'+trial+'/dFF'])
              time_s = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/time'])

              T_s = find_avg_period(time_s)

              timestamps = np.array(f['Events/'+test+'/'+event+'/'+mouse+'/'+trial+'/timestamps'])

              if len(timestamps) != 0:

                if e_part == 'onset':
                  timestamps = timestamps[:,0]
                elif e_part == 'offset':
                  timestamps = timestamps[:,1]

               # Find where not NaN part of dFF starts 
                i0 = np.max(np.argwhere(np.isnan(signal))) + 1
                t0 = time_s[i0] + (win[1]-win[0])
                t1 = time_s[-1] - (win[1]-win[0])
               # Remove events that fall to NaN part of dFF
                timestamps = timestamps[timestamps > t0]
                timestamps = timestamps[timestamps < t1]


                for t in timestamps:

                  idx = find_idx(t,time_s)
                  w0 = int(round(win[0]/T_s))
                  w1 = int(round(win[1]/T_s))
                  w = int(round((win[1]-win[0])/T_s/2))

                  s = signal[idx+w0:idx+w1+1]
                  s1 = signal1[idx+w0:idx+w1+1]

                  lag, corr = xcorr(s,s1,True,True,maxlags=w)
                  idx_maxlag = np.argmax(np.abs(corr))
                  maxcorr = corr[idx_maxlag]
                  maxlag = lag[idx_maxlag]

                  s = signal[idx+maxlag+w0:idx+maxlag+w1+1]
                  R,p = pearsonr(s,s1)

                  lag = [i*T_s for i in lag]
                  maxlag = maxlag*T_s

                  lags.append(lag)
                  corrs.append(corr)
                  maxlags.append(maxlag)
                  maxcorrs.append(maxcorr)
                  Rs.append(R)
                  ps.append(p)

                path = 'OutputPerieventCorrelation/correlation/'+test+'/'+event+'/'+e_part+'/'+output+'_'+output1+'/'+mouse
                saveToHDF(f,path+'/lag',lags)
                saveToHDF(f,path+'/corr',corrs)
                #saveToHDF(f,path+'/max-lag',maxlags)
                #saveToHDF(f,path+'/max-corr',maxcorrs)
                saveToHDF(f,path+'/R',Rs)
                saveToHDF(f,path+'/pvalue',ps)

          else:
                
            signal = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/dFF'])
            signal1 = np.array(f['DFFs/'+test+'/'+output1+'/'+mouse+'/'+trial+'/dFF'])
            time_s = np.array(f['DFFs/'+test+'/'+output+'/'+mouse+'/'+trial+'/time'])

            timestamps = np.array(f['Events/'+test+'/'+event+'/'+mouse+'/'+trial+'/timestamps'])

            if len(timestamps) != 0:

              if e_part == 'onset':
                timestamps = timestamps[:,0]
              elif e_part == 'offset':
                timestamps = timestamps[:,1]

             # Find where not NaN part of dFF starts 
              i0 = np.max(np.argwhere(np.isnan(signal))) + 1
              t0 = time_s[i0] + (win[1]-win[0])
              t1 = time_s[-1] - (win[1])
             # Remove events that fall to NaN part of dFF
              timestamps = timestamps[timestamps > t0]
              timestamps = timestamps[timestamps < t1]

              for t in timestamps:

                idx = find_idx(t,time_s)
                w0 = int(round(win[0]/T_s))
                w1 = int(round(win[1]/T_s))
                w = int(round((win[1]-win[0])/T_s/2))

                s = signal[idx+w0:idx+w1+1]
                s1 = signal1[idx+w0:idx+w1+1]

                lag, corr = xcorr(s,s1,True,True,maxlags=w)
                idx_maxlag = np.argmax(np.abs(corr))
                maxcorr = corr[idx_maxlag]
                maxlag = lag[idx_maxlag]

                s = signal[idx+maxlag+w0:idx+maxlag+w1+1]
                R,p = pearsonr(s,s1)

                lag = [i*T_s for i in lag]
                maxlag = maxlag*T_s
                    
                lags.append(lag)
                corrs.append(corr)
                maxlags.append(maxlag)
                maxcorrs.append(maxcorr)
                Rs.append(R)
                ps.append(p)

              path = 'OutputPerieventCorrelation/correlation/'+test+'/'+event+'/'+e_part+'/'+output+'_'+output1+'/'+mouse
              saveToHDF(f,path+'/lag',lags)
              saveToHDF(f,path+'/corr',corrs)
              #saveToHDF(f,path+'/max-lag',maxlags)
              #saveToHDF(f,path+'/max-corr',maxcorrs)
              saveToHDF(f,path+'/R',Rs)
              saveToHDF(f,path+'/pvalue',ps)




  def plotOutputCorrelationCounts(self,output,output1,event,
                                  event_labels=None,error_type='SEM',**kwargs):

    test = self.test

    mean = []
    error = []
    labels = []

    with h5py.File(self.filename,'r') as f:

      for e_part in f['OutputPerieventCorrelation/counts/'+test+'/'+event]:

        path = 'OutputPerieventCorrelation/counts/'+test+'/'+event+'/'+e_part+'/'+output+'_'+output1

        not_corr = list(f[path+'/not-corr'])
        pos_corr = list(f[path+'/pos-corr'])
        neg_corr = list(f[path+'/neg-corr'])

        mean_pos_corr,error_pos_corr = calculate_mean_error(pos_corr,error_type)
        mean_neg_corr,error_neg_corr = calculate_mean_error(neg_corr,error_type)
        mean_not_corr,error_not_corr = calculate_mean_error(not_corr,error_type)

        mean.append([mean_pos_corr,mean_neg_corr,mean_not_corr])
        error.append([error_pos_corr,error_neg_corr,error_neg_corr])
          
        if event_labels is None:
          labels.append(e_part)

      mean = np.array(mean).T
      error = np.array(error).T

    if event_labels is None:
      event_labels = labels

    plot_perievent_correlation_counts(mean,error,event_labels,**kwargs)





  def plotMeasureCorrelationCounts(self,measure,output,event,
                                   event_labels=None,error_type='SEM',**kwargs):

    test = self.test

    mean = []
    error = []
    labels = []

    with h5py.File(self.filename,'r') as f:

      if output == 'all':

        for e_part in f['MeasurePerieventCorrelation/counts/'+test+'/'+measure+'/'+event]:

          list_pos_corr = []
          list_neg_corr = []
          list_not_corr = []

          for output in f['MeasurePerieventCorrelation/counts/'+test+'/'+measure+'/'+event+'/'+e_part]:

            path = 'MeasurePerieventCorrelation/counts/'+test+'/'+measure+'/'+event+'/'+e_part+'/'+output

            not_corr = list(f[path+'/not-corr'])
            pos_corr = list(f[path+'/pos-corr'])
            neg_corr = list(f[path+'/neg-corr'])

            list_not_corr.extend(not_corr)
            list_pos_corr.extend(pos_corr)
            list_neg_corr.extend(neg_corr)

          mean_pos_corr,error_pos_corr = calculate_mean_error(list_pos_corr,error_type)
          mean_neg_corr,error_neg_corr = calculate_mean_error(list_neg_corr,error_type)
          mean_not_corr,error_not_corr = calculate_mean_error(list_not_corr,error_type)

          mean.append([mean_pos_corr,mean_neg_corr,mean_not_corr])
          error.append([error_pos_corr,error_neg_corr,error_neg_corr])
          
          if event_labels is None:
            labels.append(e_part)

        mean = np.array(mean).T
        error = np.array(error).T

      else:

        for e_part in f['MeasurePerieventCorrelation/counts/'+test+'/'+measure+'/'+event]:

          path = 'MeasurePerieventCorrelation/counts/'+test+'/'+measure+'/'+event+'/'+e_part+'/'+output

          not_corr = list(f[path+'/not-corr'])
          pos_corr = list(f[path+'/pos-corr'])
          neg_corr = list(f[path+'/neg-corr'])

          mean_pos_corr,error_pos_corr = calculate_mean_error(pos_corr,error_type)
          mean_neg_corr,error_neg_corr = calculate_mean_error(neg_corr,error_type)
          mean_not_corr,error_not_corr = calculate_mean_error(not_corr,error_type)

          mean.append([mean_pos_corr,mean_neg_corr,mean_not_corr])
          error.append([error_pos_corr,error_neg_corr,error_neg_corr])
          
          if event_labels is None:
            labels.append(e_part)

        mean = np.array(mean).T
        error = np.array(error).T

      if event_labels is None:
        event_labels = labels

      plot_perievent_correlation_counts(mean,error,event_labels,**kwargs)




  def plotMeasureCrossCorrelation(self,measure,output,event,e_part='onset',**kwargs):

    from scipy.stats import mode

    test = self.test

    corrs = []
    lags = []

    with h5py.File(self.filename,'r') as f:

      if output=='all':

        for output in f['MeasurePerieventCorrelation/correlation/'+test+'/'+measure+'/'+event+'/'+e_part]:

          path = 'MeasurePerieventCorrelation/correlation/'+test+'/'+measure+'/'+event+'/'+e_part+'/'+output+'/'

          for mouse in f[path]:

            Rs = np.array(f[path+mouse+'/R'])
            ps = np.array(f[path+mouse+'/pvalue'])

            corr = np.array(f[path+mouse+'/corr'])
            lag = np.array(f[path+mouse+'/lag'])

            if len(corrs) != 0: 
              corr = corr[(ps<0.001) & (Rs>0.6)]
              lag = lag[(ps<0.001) & (Rs>0.6)]

              corrs.append(np.median(corr,axis=0))
              lags.append(np.median(lag,axis=0))

      else:

        path = 'MeasurePerieventCorrelation/correlation/'+test+'/'+measure+'/'+event+'/'+e_part+'/'+output+'/'

        for mouse in f[path]:

          Rs = np.array(f[path+mouse+'/R'])
          ps = np.array(f[path+mouse+'/pvalue'])

          corr = np.array(f[path+mouse+'/corr'])
          lag = np.array(f[path+mouse+'/lag'])

          if len(corr) != 0:
            corr = corr[(ps<0.001) & (Rs>0.6)]
            lag = lag[(ps<0.001) & (Rs>0.6)]

            median_corr = np.mean(corr, axis=0)
            median_lag = np.mean(lag, axis=0)

            corrs.append(median_corr)
            lags.append(median_lag)

    nrow = len(corrs)
    ncol = len(corrs[0])
    corrs = np.array(corrs).reshape((nrow,ncol))
    lags = np.array(lags).reshape((nrow,ncol))

    plot_crosscorrelation(lags[0],corrs,**kwargs)

    #return corrs




  def plotOutputCrossCorrelation(self,output,output1,event,e_part='onset',**kwargs):

    from scipy.stats import mode

    test = self.test

    corrs = []
    lags = []

    with h5py.File(self.filename,'r') as f:

      path = 'OutputPerieventCorrelation/correlation/'+test+'/'+event+'/'+e_part+'/'+output+'_'+output1+'/'

      for mouse in f[path]:

        Rs = np.array(f[path+mouse+'/R'])
        ps = np.array(f[path+mouse+'/pvalue'])

        corr = np.array(f[path+mouse+'/corr'])
        lag = np.array(f[path+mouse+'/lag'])

        if len(corr) != 0:
          corr = corr[(ps<0.001) & (Rs>0.6)]
          lag = lag[(ps<0.001) & (Rs>0.6)]

          median_corr = np.median(corr, axis=0)
          median_lag =np.median(lag, axis=0)

          corrs.append(median_corr)
          lags.append(median_lag)

    nrow = len(corrs)
    ncol = len(corrs[0])
    corrs = np.array(corrs).reshape((nrow,ncol))
    lags = np.array(lags).reshape((nrow,ncol))

    plot_crosscorrelation(lags[0],corrs,**kwargs)

    #return corrs




  def getDataFrameCrossCorr(self,measure,event,e_part,
                          save=False,csvname='crosscorr.csv'):

    test = self.test

    mice_list = []
    output_list = []
    maxlag_list = []

    with h5py.File(self.filename,'r') as f:

      for output in f['MeasurePerieventCorrelation/correlation/'+test+'/'+measure+'/'+event+'/'+e_part]:

        path = 'MeasurePerieventCorrelation/correlation/'+test+'/'+measure+'/'+event+'/'+e_part+'/'+output+'/'

        for mouse in f[path]:

          Rs = np.array(f[path+mouse+'/R'])
          ps = np.array(f[path+mouse+'/pvalue'])

          corr = np.array(f[path+mouse+'/corr'])
          lag = np.array(f[path+mouse+'/lag'])[0]
 
          corr = corr[(ps<0.001) & (Rs>0.6)]
          median_corr = np.median(corr, axis=0)

          sm_corr = smooth_signal(median_corr,5)
          maxlag = get_midline(lag,sm_corr)

          mice_list.append(mouse)
          maxlag_list.append(maxlag)
          output_list.append(output)
          

    df = pd.DataFrame({'mouse': mice_list,
                      'output': output_list,
                      'maxlag': maxlag_list})
      
    if save:
      df.to_csv(csvname,index=False)

    return df
    

# Experiment Class

In [None]:
class FiberPhotometryExperiment:
  def __init__(self,filename):
    self.filename = filename
    with h5py.File(filename,'r') as f:
      try:
        self.mice = list(f.attrs['mice'])
        self.outputs = list(f.attrs['outputs'])
        self.tests = list(f.attrs['tests'])
        self.goodRecordings = list(f.attrs['good recordings'])
        print('Tests information is successfully loaded.')
      except:
        print('Set names of mice, outputs/pathways recorded and good recordings\nas attributes of HDF file, and create the object again')

  

  def plotRoverTests(self,output,output1=None,measure=None,tests='all',**kwargs):

    if (output1 is not None) and (measure is not None):
      print('Choose to plot correlation of 2 different outups/paths or 1 output/path and 1 measure.')
      return

    if tests == 'all':
      tests = self.tests

    list_Rs = []
    list_tests = []

    with h5py.File(self.filename, 'r') as f:
      
      if output == 'all':

        if output1 is not None:

          for test in tests:
            for output in self.outputs:
              Rs = list(f['OutputCorrelation/'+test+'/'+output+'_'+output1+'/R'])
              list_Rs.extend(Rs)
              list_tests.extend([test]*len(Rs))
      
        if measure is not None:

          for test in tests:
            for output in self.outputs:
              Rs = list(f['MeasureCorrelation/'+test+'/'+measure+'/'+output+'/R'])
              list_Rs.extend(Rs)
              list_tests.extend([test]*len(Rs))
      
      else:

        if output1 is not None:
          for test in tests:
            Rs = list(f['OutputCorrelation/'+test+'/'+output+'_'+output1+'/R'])
            list_Rs.extend(Rs)
            list_tests.extend([test]*len(Rs))

        elif measure is not None:
          for test in tests:
            Rs = list(f['MeasureCorrelation/'+test+'/'+measure+'/'+output+'/R'])
            list_Rs.extend(Rs)
            list_tests.extend([test]*len(Rs))

    plot_violinplot(list_tests,list_Rs,**kwargs)

    return




  def getDataFrameAUC(self,tests,events,onoffset,periods=['baseline','event'],
                      save=False,csvname='auc.csv'):

    test = self.test

    mice_list = []
    output_list = []
    period_list = []
    auc_list = []

    with h5py.File(self.filename,'r') as f:

      for output in f['Means/'+test]:

        path = 'Means/'+test+'/'+output+'/'+event+'/'+onoffset+'/'
        mice = list(f[path+'mice'])
        auc = np.array(f[path+'auc'])

        for i,period in enumerate(periods):

          n = len(mice)
          mice_list.extend(mice)
          output_list.extend([output]*n)
          period_list.extend([period]*n)
          auc_list.extend(list(auc[:,i]))


      df = pd.DataFrame({'mouse': mice_list,
                        'output': output_list,
                        'period': period_list,
                           'auc': auc_list})
      
      if save:
        df.to_csv(csvname,index=False)

    return df



  
  def getDataFrameMeasureCorrCounts(self,measure,tests=None,outputs=None,
                                    save=True,csvname='corrcounts.csv'):
    
    mice_list = []
    test_list = []
    output_list = []
    event_list = []
    pos_corr_list = []
    neg_corr_list = []
    not_corr_list = []


    if tests == None:
      tests = self.tests
    if outputs == None:
      outputs = self.outputs
    
    with h5py.File(self.filename, 'r') as f:

      for test in tests:
        for output in outputs:

          for event in f['MeasurePerieventCorrelation/counts/'+test+'/'+measure]:
            for e_part in f['MeasurePerieventCorrelation/counts/'+test+'/'+measure+'/'+event]:

              path = 'MeasurePerieventCorrelation/counts/'+test+'/'+measure+'/'+event+'/'+e_part+'/'+output

              mice = np.array(f[path+'/mice'])
              pos_corr = np.array(f[path+'/pos-corr'])
              neg_corr = np.array(f[path+'/neg-corr'])
              not_corr = np.array(f[path+'/not-corr'])

              mice_list.extend(mice)
              test_list.extend([test]*len(mice))
              output_list.extend([output]*len(mice))
              event_list.extend([event+'_'+e_part]*len(mice))
              pos_corr_list.extend(pos_corr)
              neg_corr_list.extend(neg_corr)
              not_corr_list.extend(not_corr)
    

    df = pd.DataFrame({'mouse': mice_list,
                        'test': test_list,
                      'output': output_list,
                       'event': event_list,
                    'pos-corr': pos_corr_list,
                    'neg-corr': neg_corr_list,
                    'not-corr': not_corr_list})
    
    return df





  def getDataFrameOutputCorrCounts(self,tests=None,
                                  save=True,csvname='corrcounts.csv'):
    
    mice_list = []
    test_list = []
    output_list = []
    event_list = []
    pos_corr_list = []
    neg_corr_list = []
    not_corr_list = []


    if tests == None:
      tests = self.tests
    
    with h5py.File(self.filename, 'r') as f:

      for test in tests:

        for event in f['OutputPerieventCorrelation/counts/'+test]:
          for e_part in f['OutputPerieventCorrelation/counts/'+test+'/'+event]:
            for output in f['OutputPerieventCorrelation/counts/'+test+'/'+event+'/'+e_part]:

              path = 'OutputPerieventCorrelation/counts/'+test+'/'+event+'/'+e_part+'/'+output

              mice = np.array(f[path+'/mice'])
              pos_corr = np.array(f[path+'/pos-corr'])
              neg_corr = np.array(f[path+'/neg-corr'])
              not_corr = np.array(f[path+'/not-corr'])

              mice_list.extend(mice)
              test_list.extend([test]*len(mice))
              output_list.extend([output]*len(mice))
              event_list.extend([event+'_'+e_part]*len(mice))
              pos_corr_list.extend(pos_corr)
              neg_corr_list.extend(neg_corr)
              not_corr_list.extend(not_corr)
    

    df = pd.DataFrame({'mouse': mice_list,
                        'test': test_list,
                     'outputs': output_list,
                       'event': event_list,
                    'pos-corr': pos_corr_list,
                    'neg-corr': neg_corr_list,
                    'not-corr': not_corr_list})
    
    return df



  # def plotRoverTests(self,output,output1=None,measure=None,tests='all',**kwargs):

  #   if (output1 is not None) and (measure is not None):
  #     print('Choose to plot correlation of 2 different outups/paths or 1 output/path and 1 measure.')
  #     return

  #   if tests == 'all':
  #     tests = self.tests

  #   list_Rs = []
  #   list_tests = []

  #   with h5py.File(self.filename, 'r') as f:
      
  #     if output == 'all':

  #       if output1 is not None:

  #         for test in tests:
  #           for output in self.outputs:
  #             Rs = list(f['OutputCorrelation/'+test+'/'+output+'/'+output1+'/R'])
  #             list_Rs.extend(Rs)
  #             list_tests.extend([test]*len(Rs))
      
  #       if measure is not None:

  #         for test in tests:
  #           for output in self.outputs:
  #             Rs = list(f['MeasureCorrelation/'+test+'/'+measure+'/'+output+'/R'])
  #             list_Rs.extend(Rs)
  #             list_tests.extend([test]*len(Rs))
      
  #     else:

  #       if output1 is not None:
  #         for test in tests:
  #           Rs = list(f['OutputCorrelation/'+test+'/'+output+'/'+output1+'/R'])
  #           list_Rs.extend(Rs)
  #           list_tests.extend([test]*len(Rs))

  #       elif measure is not None:
  #         for test in tests:
  #           Rs = list(f['MeasureCorrelation/'+test+'/'+measure+'/'+output+'/R'])
  #           list_Rs.extend(Rs)
  #           list_tests.extend([test]*len(Rs))

  #   plot_violinplot(list_tests,list_Rs,**kwargs)

  #   return



  def getDataFrameRmeasure(self,measure,tests='all',
                          save=False,csvname='R.csv'):

    if tests == 'all':
      tests = self.tests

    mice_list = []
    test_list = []
    output_list = []
    R_list = []
    p_list = []

    with h5py.File(self.filename,'r') as f:

      for test in tests:
        for output in f['MeasureCorrelation/'+test+'/'+measure]:
          mice = list(f['MeasureCorrelation/'+test+'/'+measure+'/'+output+'/mice'])
          Rs = list(f['MeasureCorrelation/'+test+'/'+measure+'/'+output+'/R'])
          ps = list(f['MeasureCorrelation/'+test+'/'+measure+'/'+output+'/p-value'])

          mice_list.extend(mice)
          R_list.extend(Rs)
          p_list.extend(ps)
          output_list.extend([output]*len(Rs))
          test_list.extend([test]*len(Rs))

      df = pd.DataFrame({'mouse': mice_list,
                          'test': test_list,
                        'output': output_list,
                             'R': R_list,
                        'pvalue': p_list})
      
      if save:
        df.to_csv(csvname,index=False)

    return df



    

  def getDataFrameRoutputs(self,tests='all',
                          save=False,csvname='R.csv'):

    if tests == 'all':
      tests = self.tests

    mice_list = []
    test_list = []
    output_list = []
    R_list = []
    p_list = []

    with h5py.File(self.filename,'r') as f:

      for test in tests:
        for output in f['OutputCorrelation/'+test]:
          mice = list(f['OutputCorrelation/'+test+'/'+output+'/mice'])
          Rs = list(f['OutputCorrelation/'+test+'/'+output+'/R'])
          ps = list(f['OutputCorrelation/'+test+'/'+output+'/p-value'])

          mice_list.extend(mice)
          R_list.extend(Rs)
          p_list.extend(ps)
          output_list.extend([output]*len(Rs))
          test_list.extend([test]*len(Rs))

      df = pd.DataFrame({'mouse': mice_list,
                          'test': test_list,
                        'output': output_list,
                             'R': R_list,
                        'pvalue': p_list})
      
      if save:
        df.to_csv(csvname,index=False)

    return df

In [None]:
def isbinary(vector):
  '''
  Return True when every element of the numpy array `vector` equals its own
  boolean cast, i.e. every element is 0 or 1 (or False/True).
  '''
  as_bool = vector.astype(bool)
  return np.array_equal(vector, as_bool)

In [None]:
def arebinary(dictionary):
  '''
  Return True when the 'values' array of every entry of `dictionary` is
  binary according to isbinary.

  Each value of `dictionary` must be a dict with a 'values' key. An empty
  dictionary counts as binary.
  '''
  return all(isbinary(entry['values']) for entry in dictionary.values())

In [None]:
def interpolate_signal(signal,t_old,t_new):
  '''
  Resample `signal`, sampled at times `t_old`, onto the times `t_new`.

  Uses scipy.interpolate.interp1d with its default settings (linear
  interpolation, no extrapolation).
  '''
  from scipy.interpolate import interp1d

  return interp1d(t_old,signal)(t_new)

In [None]:
def saveToHDF(f,path,data):
  '''
  Store `data` at `path` in the open HDF5 file/group `f`.

  Strategy:
    1. create the dataset;
    2. if creation raises ValueError (presumably because it already exists),
       overwrite its contents in place;
    3. if the in-place write raises TypeError (presumably an incompatible
       shape), delete the dataset and create it again from `data`.
  Any other exception propagates.
  '''
  try:
    f.create_dataset(path,data=data)
    return
  except ValueError:
    pass

  try:
    f[path][()] = data
  except TypeError:
    del f[path]
    f.create_dataset(path,data=data)

In [None]:
def calculate_auc(means, period=0.10, window=(-5.0,5.0), time_frames=((-3,0),(0,3))):
  '''
  Compute the time-normalized area under the curve of each trace in several
  time frames.

  Parameters
  ----------
  means : 2-D array-like, (n_traces, n_samples)
      Traces sampled on create_centered_time_vector(period, window).
  period : float
      Forwarded to create_centered_time_vector.
  window : sequence of two floats
      Forwarded to create_centered_time_vector.
  time_frames : sequence of (start, end) pairs
      Frames over which the AUC is computed. Only samples strictly inside
      (start, end) are used.

  Returns
  -------
  numpy.ndarray, (n_traces, len(time_frames))
      sklearn.metrics.auc of each trace in each frame, divided by the frame
      length (end - start).
  '''
  from sklearn import metrics

  t = np.asarray(create_centered_time_vector(period,window))
  means = np.asarray(means)

  stats = np.zeros([len(means),len(time_frames)])

  for col,(start,end) in enumerate(time_frames):
    # Strict inequalities: samples exactly on the frame edges are excluded.
    inside = (t > start) & (t < end)
    t_frame = t[inside]
    for m in range(len(means)):
      stats[m,col] = metrics.auc(t_frame,means[m,inside]) / (end - start)

  return stats

In [None]:
# Check whether a recording is in the list of good recordings
def isGoodRecording(goodRecordings,mouse,test,trial=None,output=None):
  '''
  Return True if some row of `goodRecordings` matches the given recording.

  Each row is read as (mouse, test, trial, output). A row matches when its
  mouse and test fields equal `mouse` and `test` and, for each of `trial` /
  `output` that is not None, the corresponding field is equal as well.

  Parameters
  ----------
  goodRecordings : 2-D numpy array or sequence of rows
  mouse, test : str
  trial, output : optional
      Not checked when None.
  '''
  for record in goodRecordings:
    if record[0] != mouse or record[1] != test:
      continue
    if trial is not None and record[2] != trial:
      continue
    if output is not None and record[3] != output:
      continue
    return True
  return False

# Data processing helper functions

### Preprocess

In [None]:
def smooth_signal(x,window_len=10,window='flat'):

    """Smooth the data using a window with the requested size.

    This method is based on the convolution of a scaled window with the signal.
    The signal is prepared by introducing reflected copies of the signal
    (with the window size) in both ends so that transient parts are minimized
    in the beginning and end part of the output signal.
    Adapted from: https://scipy-cookbook.readthedocs.io/items/SignalSmooth.html

    input:
        x: the input signal (1-D numpy array)
        window_len: the dimension of the smoothing window. Odd values are
                    reduced by one, because the final trimming returns exactly
                    len(x) samples only for an even window length.
        window: the type of window from 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'
                'flat' window will produce a moving average smoothing.

    output:
        the smoothed signal, same length as x (x itself is returned unchanged
        when window_len < 3)

    raises:
        ValueError: if x is not 1-D, if x is shorter than window_len, or if
                    window is not one of the supported names.
    """

    if x.ndim != 1:
        raise ValueError("smooth only accepts 1 dimension arrays.")

    if x.size < window_len:
        raise ValueError("Input vector needs to be bigger than window size.")

    if window_len<3:
        return x

    # The trimming at the end assumes an even window length.
    if window_len % 2 == 1:
        window_len -= 1

    if window not in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']:
        raise ValueError("Window is one of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'")

    # Reflect window_len-1 samples at each end (the edge sample is not repeated).
    s=np.r_[x[window_len-1:0:-1],x,x[-2:-window_len-1:-1]]

    if window == 'flat': # Moving average
        w=np.ones(window_len,'d')
    else:
        # np.hanning, np.hamming, ... looked up by name (no eval).
        w=getattr(np, window)(window_len)

    y=np.convolve(w/w.sum(),s,mode='valid')

    half = window_len // 2
    return y[half-1:-half]

In [None]:
from scipy.signal import butter, filtfilt, freqz

def butter_lowpass(cutoff, fs, order=10):
    """Design a digital Butterworth low-pass filter of the given order.

    `cutoff` and `fs` must use the same unit. The cutoff is normalized by
    the Nyquist frequency (fs / 2), as scipy.signal.butter expects for
    digital filters. Returns the (b, a) transfer-function coefficients.
    """
    nyquist = fs / 2.0
    numerator, denominator = butter(order, cutoff / nyquist, btype='low', analog=False)
    return numerator, denominator


def butter_lowpass_filter(data, cutoff, fs, order=10):
    """Low-pass filter `data` with the Butterworth filter from butter_lowpass.

    The filter is applied forward and backward with scipy.signal.filtfilt,
    so the result has no phase shift.
    """
    coefficients = butter_lowpass(cutoff, fs, order=order)
    return filtfilt(coefficients[0], coefficients[1], data)

In [None]:
def flatten_signal(signal,lambda_=1e4,order=0.5,itermax=50,pad=600):
  '''
  Remove a slowly varying baseline from a 1-D signal using airPLS.

  To reduce edge effects, the signal is first extended at both ends by a
  mirror reflection of pad-1 samples (the edge sample is not repeated). The
  baseline is estimated on the extended signal, and both are then cropped
  back to the original length.

  Parameters
  ----------
  signal : 1-D numpy array
  lambda_, order, itermax :
      Forwarded to airPLS as lambda_, porder and itermax.
  pad : int
      Reflection length (default 600). It is clamped to len(signal), so
      signals shorter than `pad` are handled correctly.

  Returns
  -------
  (flattened, base) : tuple of numpy arrays of len(signal)
      flattened = signal - base, and the estimated baseline.
  '''
  n = len(signal)
  # A full reflection needs pad <= n; p == 1 means no padding.
  p = max(1, min(pad, n))
  s = np.r_[signal[p-1:0:-1],signal,signal[-2:-p-1:-1]]
  b = airPLS(s,lambda_=lambda_,porder=order,itermax=itermax)
  # Crop with an explicit end index: s[p-1:-p+1] would be empty when p == 1.
  signal = s[p-1:p-1+n]
  base = b[p-1:p-1+n]

  # Subtract the whole estimated baseline.
  return signal - base, base

In [None]:
def standardize_signal(signal):
  '''
  Center `signal` on its median and divide by its standard deviation,
  ignoring NaNs in both statistics (NaNs stay NaN in the output).

  Note: the center is the median, not the mean.
  '''
  center = np.nanmedian(signal)
  spread = np.nanstd(signal)
  return (signal - center) / spread

airPLS algorithm

In [None]:
'''
airPLS.py Copyright 2014 Renato Lombardo - renato.lombardo@unipa.it
Baseline correction using adaptive iteratively reweighted penalized least squares

This program is a translation in python of the R source code of airPLS version 2.0
by Yizeng Liang and Zhang Zhimin - https://code.google.com/p/airpls
Reference:
Z.-M. Zhang, S. Chen, and Y.-Z. Liang, Baseline correction using adaptive iteratively reweighted penalized least squares. Analyst 135 (5), 1138-1146 (2010).

Description from the original documentation:

Baseline drift always blurs or even swamps signals and deteriorates analytical results, particularly in multivariate analysis.  It is necessary to correct baseline drift to perform further data analysis. Simple or modified polynomial fitting has been found to be effective in some extent. However, this method requires user intervention and prone to variability especially in low signal-to-noise ratio environments. The proposed adaptive iteratively reweighted Penalized Least Squares (airPLS) algorithm doesn't require any user intervention and prior information, such as detected peaks. It iteratively changes weights of sum squares errors (SSE) between the fitted baseline and original signals, and the weights of SSE are obtained adaptively using between previously fitted baseline and original signals. This baseline estimator is general, fast and flexible in fitting baseline.


LICENCE
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>
'''

import numpy as np
from scipy.sparse import csc_matrix, eye, diags
from scipy.sparse.linalg import spsolve

def WhittakerSmooth(x,w,lambda_,differences=1):
    '''
    Penalized least squares algorithm for background fitting

    Solves (W + lambda_ * D'D) z = W x for z, where W = diag(w) and D is the
    first-order difference matrix.

    input
        x: input data (i.e. chromatogram of spectrum), 1-D
        w: weights (binary masks in the original formulation: zero if a point belongs to peaks and one otherwise)
        lambda_: parameter that can be adjusted by user. The larger lambda is,  the smoother the resulting background
        differences: accepted for API compatibility but unused -- the penalty
            always uses first-order differences, as in the ported code.

    output
        the fitted background vector (1-D numpy array of len(x))
    '''
    # Plain arrays instead of the deprecated np.matrix; same linear system.
    x = np.asarray(x, dtype=float).ravel()
    m = x.size
    E = eye(m, format='csc')
    D = E[1:] - E[:-1] # numpy.diff() does not work with sparse matrix. This is a workaround.
    W = diags(w, 0, shape=(m, m))
    A = csc_matrix(W + lambda_ * D.T @ D)
    B = W @ x
    background = spsolve(A, B)
    return np.asarray(background)

def airPLS(x, lambda_=100, porder=1, itermax=15):
    '''
    Adaptive iteratively reweighted penalized least squares for baseline fitting

    Repeatedly fits a background with WhittakerSmooth. Points above the
    current background get zero weight; points below it get exponentially
    growing weights. Stops when the total negative residual falls below 0.1%
    of sum(|x|) or after `itermax` iterations (a warning is printed then).

    input
        x: input data (i.e. chromatogram of spectrum), 1-D numpy array
        lambda_: parameter that can be adjusted by user. The larger lambda is,  the smoother the resulting background, z
        porder: passed to WhittakerSmooth as its `differences` argument
        itermax: maximum number of iterations (must be >= 1)

    output
        the fitted background vector

    raises
        ValueError: if itermax < 1 (no background would ever be computed)
    '''
    if itermax < 1:
        raise ValueError('itermax must be at least 1')
    m=x.shape[0]
    w=np.ones(m)
    for i in range(1,itermax+1):
        z=WhittakerSmooth(x,w,lambda_, porder)
        d=x-z
        dssn=np.abs(d[d<0].sum())
        # dssn == 0 with sum(|x|) == 0 (e.g. all-zero input) would otherwise
        # continue and fail on max() of an empty array below.
        if(dssn<0.001*(abs(x)).sum() or dssn==0 or i==itermax):
            if(i==itermax): print('WARNING max iteration reached!')
            break
        w[d>=0]=0 # d>0 means that this point is part of a peak, so its weight is set to 0 in order to ignore it
        w[d<0]=np.exp(i*np.abs(d[d<0])/dssn)
        w[0]=np.exp(i*(d[d<0]).max()/dssn)
        w[-1]=w[0]
    return z


In [None]:
def _plot_raw_panels(raw_signal,raw_reference,signal,reference,s_base,r_base,
                     time_,figsize):
  # Build the 2x2 grid. Left column: raw trace with its baseline; right column:
  # corrected trace. Top row: signal; bottom row: reference.
  fig, axs = plt.subplots(2,2,figsize=figsize)
  axs = axs.ravel()

  axs[0].plot(time_,raw_signal, color='blue', linewidth=1.5)
  axs[0].plot(time_,s_base, color='black', linewidth=1.5)
  axs[0].set_ylabel('signal', fontsize='x-large', multialignment='center')
  axs[1].plot(time_,signal, color='blue',linewidth=1.5)

  axs[2].plot(time_,raw_reference, color='purple', linewidth=1.5)
  axs[2].plot(time_,r_base, color='black', linewidth=1.5)
  axs[2].set_ylabel('reference', fontsize='x-large', multialignment='center')
  axs[3].plot(time_,reference, color='purple',linewidth=1.5)

  axs[2].set_xlabel('time', fontsize='x-large', multialignment='center')
  axs[3].set_xlabel('time', fontsize='x-large', multialignment='center')

  return fig, axs


def _plot_raw_events(axs,events):
  # Mark events on every panel: a vertical line per row of an (n, 1) event
  # array, a shaded span per (onset, offset) row of an (n, 2) array.
  if events is None:
    return
  cmap = get_cmap(len(events))
  for k,key in enumerate(events):
    if len(events[key])==0:
      continue
    for ax in axs:
      if events[key].shape[1]==1:
        for event in events[key]:
          ax.axvline(event,linewidth=1,color=cmap(k),label=key)
      elif events[key].shape[1]==2:
        for event0,event1 in events[key]:
          ax.axvspan(event0,event1,alpha=0.3,color=cmap(k),label=key)


def _plot_raw_measurement(axs,measure,measurement):
  # Overlay a continuous measurement on a twin y-axis behind each panel. The
  # y-limits keep the measurement in the lower half of the panel.
  values = measurement['values']
  m_max = np.nanmax(values)
  m_min = np.nanmin(values)
  for ax in axs:
    ax_m = ax.twinx()
    ax_m.plot(measurement['time'], values, color='black',label=measure)
    ax_m.set_ylabel(measure,fontsize='x-large',multialignment='center',color='black')
    ax_m.tick_params('y', colors='black')
    ax_m.set_ylim([m_min, m_max + (m_max-m_min)])
    ax.set_zorder(ax_m.get_zorder()+1) # put ax in front of ax_m
    ax.patch.set_visible(False) # hide the 'canvas'


def _finish_raw_figure(fig,axs,time_,figtitle):
  # Common x-range and tick size, one legend entry per label, optional title.
  for ax in axs:
    ax.set_xlim([0,max(time_)])
    ax.tick_params(labelsize='large')
  for ax in axs:
    handles, labels = ax.get_legend_handles_labels()
    by_label = OrderedDict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), prop={'size': 'small'})
  if figtitle is not None:
    fig.suptitle(figtitle, fontsize='xx-large')


def plot_raw(raw_signal,raw_reference,signal,reference,s_base,r_base,
            time_=None,events=None,measurements=None,
            figtitle=None,figsize=(22, 13),
            save=False,save_path='',image_format='.pdf'):
  '''
  Plot raw traces with their baselines next to the baseline-corrected traces.

  Without continuous measurements (measurements is None or all binary), one
  2x2 figure is drawn. Otherwise one figure is drawn per non-binary
  measurement, with that measurement overlaid on a twin y-axis.

  Parameters
  ----------
  raw_signal, raw_reference : array-like
      Raw signal / reference traces.
  signal, reference : array-like
      Baseline-corrected traces.
  s_base, r_base : array-like
      Baselines of the raw signal / reference.
  time_ : array-like, optional
      Timestamps; defaults to sample indices.
  events : dict, optional
      Event name -> array of shape (n, 1) (time points) or (n, 2)
      (onset/offset pairs). Empty arrays are ignored.
  measurements : dict, optional
      Measure name -> dict with 'time' and 'values' arrays.
  figtitle : str, optional
      Figure title; also used to build the saved file name.
  figsize : tuple
      Matplotlib figure size.
  save : bool
      Save each figure to save_path + <name> + image_format. The name is
      '<figtitle>_raw' / '<figtitle>_<measure>_raw' (spaces as underscores),
      or 'raw' / '<measure>_raw' without a title.
  '''

  if time_ is None:
    time_ = range(len(raw_signal))

  if (measurements is None) or arebinary(measurements):
    fig, axs = _plot_raw_panels(raw_signal,raw_reference,signal,reference,
                                s_base,r_base,time_,figsize)
    _plot_raw_events(axs,events)
    _finish_raw_figure(fig,axs,time_,figtitle)

    if save:
      if figtitle is None:
        imgname = 'raw' + image_format
      else:
        imgname = figtitle.replace(' ','_') + '_raw' + image_format
      fig.savefig(save_path+imgname)
    return

  for measure in measurements:
    if isbinary(measurements[measure]['values']):
      continue

    fig, axs = _plot_raw_panels(raw_signal,raw_reference,signal,reference,
                                s_base,r_base,time_,figsize)
    _plot_raw_events(axs,events)
    _plot_raw_measurement(axs,measure,measurements[measure])
    _finish_raw_figure(fig,axs,time_,figtitle)

    if save:
      if figtitle is None:
        # Include the measure so figures of different measures don't overwrite each other.
        imgname = measure + '_raw' + image_format
      else:
        imgname = figtitle.replace(' ','_') + '_' + measure + '_raw' + image_format
      fig.savefig(save_path+imgname)

### Fit and align the calcium-independent signal to the calcium-dependent signal

In [None]:
def fit_signal(signal, reference, model='RANSAC'):
  '''
  Fit `reference` to `signal` with a linear regression and return the fitted
  reference.

  Samples up to and including the last NaN of `signal` are excluded from the
  fit and are NaN in the output. This presumably targets NaNs at the start of
  the recording (TODO confirm). The output has the same length as `signal`.

  Parameters
  ----------
  signal, reference : 1-D array-like of equal length
  model : {'RANSAC', 'Lasso'}
      'RANSAC': sklearn RANSACRegressor(max_trials=1000).
      'Lasso': sklearn Lasso constrained to positive coefficients.

  Returns
  -------
  numpy.ndarray
      The reference predicted by the fitted model.

  Raises
  ------
  ValueError
      If `model` is not supported.
  '''
  if model not in ('RANSAC','Lasso'):
    raise ValueError("model must be 'RANSAC' or 'Lasso', got %r" % (model,))

  # Index of the last NaN in signal, or -1 when there is none.
  nan_idx = np.flatnonzero(np.isnan(signal))
  i0 = nan_idx[-1] if nan_idx.size else -1

  signal = signal[i0+1:]
  reference = reference[i0+1:]

  signal = np.array(signal).reshape(len(signal),1)
  reference = np.array(reference).reshape(len(reference),1)

  if model == 'RANSAC':
    from sklearn.linear_model import RANSACRegressor
    lin = RANSACRegressor(max_trials=1000,random_state=9999)
  else:
    # Positive linear regression
    from sklearn.linear_model import Lasso
    lin = Lasso(alpha=0.0001,precompute=True,max_iter=1000,
                positive=True, random_state=9999, selection='random')

  lin.fit(reference, signal)
  reference_fitted = lin.predict(reference)
  reference_fitted = reference_fitted.reshape(len(reference_fitted),)

  # Re-insert NaNs for the excluded leading samples (none when i0 == -1).
  a = np.empty((i0+1,))
  a[:] = np.nan
  reference_fitted = np.r_[a,reference_fitted]

  return reference_fitted

In [None]:
def plot_fit(signal,reference,reference_fitted,
             figtitle=None,figsize=(15,13),save=False,save_path='',image_format='.pdf'):
  '''
  Scatter the signal against the reference and overlay the fitted reference.

  Parameters
  ----------
  signal, reference, reference_fitted : array-like
      Raw signal (y, blue dots), raw reference (x) and fitted reference
      (red dashed line).
  figtitle : str, optional
      Figure title; also used to build the file name when saving.
  figsize : tuple
      Figure size in inches.
  save : bool
      If True, write the figure to ``save_path`` + name + ``image_format``.
  '''
  fig = plt.figure(figsize=figsize)
  ax = fig.add_subplot(111)

  ax.plot(reference, signal, 'b.')
  ax.plot(reference, reference_fitted, 'r--', linewidth=1.5)

  for set_label, text in ((ax.set_xlabel, 'reference'), (ax.set_ylabel, 'signal')):
    set_label(text, fontsize='x-large', multialignment='center')
  ax.tick_params(labelsize='large')

  if figtitle is not None:
    fig.suptitle(figtitle, fontsize='xx-large')

  if not save:
    return
  stem = 'fit' if figtitle is None else figtitle.replace(' ', '_') + '_fit'
  fig.savefig(save_path + stem + image_format)

In [None]:
def plot_aligned(signal,reference,time_=None,events=None,measurements=None,
                 figtitle=None,figsize=(20,13),save=False,save_path='',image_format='.pdf'):
  '''
  Plot the signal and the aligned (fitted) reference over time.

  A single figure is drawn when ``measurements`` is None or every measurement
  is binary (``arebinary``). Otherwise one figure is drawn for each non-binary
  measurement, which is shown on a twin y-axis.

  Parameters
  ----------
  signal, reference : array-like
      Traces drawn in black and purple respectively.
  time_ : array-like, optional
      Timestamps; defaults to sample indices.
  events : dict, optional
      Event name -> 2D array. One column: single timestamps drawn as vertical
      lines. Two columns: onset/offset pairs drawn as shaded spans.
  measurements : dict, optional
      Measurement name -> dict with 'time' and 'values' entries.
  figtitle : str, optional
      Figure title; also used to build the file name when saving.
  save : bool
      If True, each figure is written to ``save_path`` + file name.
  '''
  if time_ is None:
    time_ = range(len(signal))

  def _draw_traces_and_events():
    # New figure with the signal (black), the reference (purple) and all events.
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    ax.plot(time_, signal, 'black', linewidth=1.5)
    ax.plot(time_, reference, 'purple', linewidth=1.5)
    if events:
      cmap = get_cmap(len(events))
      for k, key in enumerate(events):
        if len(events[key]) == 0:
          continue
        # one occurrence event
        if events[key].shape[1] == 1:
          for event in events[key]:
            ax.axvline(event, linewidth=1, color=cmap(k), label=key)
        # event with onset and offset
        elif events[key].shape[1] == 2:
          for event0, event1 in events[key]:
            ax.axvspan(event0, event1, alpha=0.3, color=cmap(k), label=key)
    return fig, ax

  def _finish(fig, ax, imgname):
    # Axis labels and limits, de-duplicated legend, title and optional saving.
    ax.set_xlabel('time', fontsize='x-large', multialignment='center')
    ax.set_ylabel('Intensity', fontsize='x-large', multialignment='center')
    ax.set_xlim([0, max(time_)])
    ax.tick_params(labelsize='x-large')
    handles, labels = ax.get_legend_handles_labels()
    by_label = OrderedDict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), prop={'size': 'medium'})
    if figtitle is not None:
      fig.suptitle(figtitle, fontsize='xx-large')
    if save:
      fig.savefig(save_path + imgname)

  if (measurements is None) or arebinary(measurements):
    fig, ax = _draw_traces_and_events()
    # Fall back to a default file name when there is no title to derive it from.
    if figtitle is None:
      imgname = 'align' + image_format
    else:
      imgname = figtitle.replace(' ','_') + '_align' + image_format
    _finish(fig, ax, imgname)
    return

  # Colour the measurement like the last event. Use a fixed colour when there
  # are no events to borrow one from.
  if events:
    m_color = get_cmap(len(events))(len(events) - 1)
  else:
    m_color = 'magenta'

  for measure in measurements:
    if isbinary(measurements[measure]['values']):
      continue
    fig, ax = _draw_traces_and_events()
    # Measurements
    ax_m = ax.twinx()
    ax_m.plot(measurements[measure]['time'], measurements[measure]['values'],
              color=m_color, label=measure)
    ax_m.set_ylabel(measure, fontsize='x-large', multialignment='center', color=m_color)
    ax_m.tick_params('y', colors=m_color)
    m_max = np.nanmax(measurements[measure]['values'])
    m_min = np.nanmin(measurements[measure]['values'])
    ax_m.set_ylim([m_min, m_max + (m_max-m_min)]) # plot on the bottom half
    ax.set_zorder(ax_m.get_zorder()+1) # put ax in front of ax_m
    ax.patch.set_visible(False) # hide the 'canvas'
    if figtitle is None:
      imgname = 'align' + image_format
    else:
      imgname = figtitle.replace(' ','_') + '_' + measure + '_align' + image_format
    _finish(fig, ax, imgname)

### Calculate z-scored dF/F

In [None]:
def calculate_dff(signal,reference,standardized=True):
  '''
  Compute dF/F from a signal and its fitted reference, then standardize it.

  Parameters
  ----------
  signal, reference : numpy.ndarray
      Signal and (fitted) reference traces of equal length.
  standardized : bool
      If True, use the plain difference ``signal - reference``. Presumably
      the inputs are then already standardized; confirm against the callers.
      If False, divide the difference by ``reference``.

  Returns
  -------
  numpy.ndarray
      The difference passed through ``standardize_signal`` (defined elsewhere
      in this module).
  '''
  delta = signal - reference
  if not standardized:
    delta = delta / reference
  return standardize_signal(delta)

In [None]:
def plot_dff(dFF,time_=None,events=None,measurements=None,
            figtitle=None,figsize=(20,13),save=False,save_path='',image_format='.pdf'):
  '''
  Plot dF/F over time with events and, optionally, continuous measurements.

  A single figure is drawn when ``measurements`` is None or every measurement
  is binary (``arebinary``). Otherwise one figure is drawn for each non-binary
  measurement, which is shown on a twin y-axis. The dF/F y-axis always spans
  at least [-3, 3].

  Parameters
  ----------
  dFF : array-like
      dF/F trace.
  time_ : array-like, optional
      Timestamps; defaults to sample indices.
  events : dict, optional
      Event name -> 2D array. One column: single timestamps drawn as vertical
      lines. Two columns: onset/offset pairs drawn as shaded spans.
  measurements : dict, optional
      Measurement name -> dict with 'time' and 'values' entries.
  figtitle : str, optional
      Figure title; also used to build the file name when saving.
  save : bool
      If True, each figure is written to ``save_path`` + file name.
  '''
  if time_ is None:
    time_ = range(len(dFF))

  ymin = np.nanmin([-3, np.nanmin(dFF)])
  ymax = np.nanmax([3, np.nanmax(dFF)])

  def _draw_trace_and_events():
    # New figure with the dF/F trace (black) and all events.
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    ax.plot(time_, dFF, 'black', linewidth=1.5)
    if events:
      cmap = get_cmap(len(events))
      for k, key in enumerate(events):
        if len(events[key]) == 0:
          continue
        # one occurrence event
        if events[key].shape[1] == 1:
          for event in events[key]:
            ax.axvline(event, linewidth=1, color=cmap(k), label=key)
        # event with onset and offset
        elif events[key].shape[1] == 2:
          for event0, event1 in events[key]:
            ax.axvspan(event0, event1, alpha=0.3, color=cmap(k), label=key)
    return fig, ax

  def _finish(fig, ax, imgname):
    # Axis labels and limits, de-duplicated legend, title and optional saving.
    ax.set_xlabel('time', fontsize='x-large', multialignment='center')
    ax.set_ylabel('dF/F', fontsize='x-large', multialignment='center')
    ax.set_xlim([0, max(time_)])
    ax.set_ylim([ymin, ymax])
    ax.tick_params(labelsize='large')
    handles, labels = ax.get_legend_handles_labels()
    by_label = OrderedDict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), prop={'size': 'medium'})
    if figtitle is not None:
      fig.suptitle(figtitle, fontsize='xx-large')
    if save:
      fig.savefig(save_path + imgname)

  if (measurements is None) or arebinary(measurements):
    fig, ax = _draw_trace_and_events()
    # Fall back to a default file name when there is no title to derive it from.
    if figtitle is None:
      imgname = 'dFF' + image_format
    else:
      imgname = figtitle.replace(' ','_') + '_dFF' + image_format
    _finish(fig, ax, imgname)
    return

  # Colour the measurement like the last event. Use a fixed colour when there
  # are no events to borrow one from.
  if events:
    m_color = get_cmap(len(events))(len(events) - 1)
  else:
    m_color = 'magenta'

  for measure in measurements:
    if isbinary(measurements[measure]['values']):
      continue
    fig, ax = _draw_trace_and_events()
    # Measurements
    ax_m = ax.twinx()
    ax_m.plot(measurements[measure]['time'], measurements[measure]['values'],
              color=m_color, label=measure)
    ax_m.set_ylabel(measure, fontsize='x-large', multialignment='center', color=m_color)
    ax_m.tick_params('y', colors=m_color)
    m_max = np.nanmax(measurements[measure]['values'])
    m_min = np.nanmin(measurements[measure]['values'])
    ax_m.set_ylim([m_min, m_max + (m_max-m_min)]) # plot on the bottom half
    ax.set_zorder(ax_m.get_zorder()+1) # put ax in front of ax_m
    ax.patch.set_visible(False) # hide the 'canvas'
    if figtitle is None:
      imgname = 'dFF' + image_format
    else:
      imgname = figtitle.replace(' ','_') + '_' + measure + '_dFF' + image_format
    _finish(fig, ax, imgname)

### Perievent data

In [None]:
def create_perievents(signal,time_,event,window=(-5.0,5.0),dur=None,iei=None,avg_win=None):
  '''
  Cut the signal into peri-event chunks around each event.

  Parameters
  ----------
  signal : array-like
      Trace to chunk (e.g. z dF/F), sampled at ``time_``.
  time_ : array-like
      Timestamps in total seconds.
  event : array-like
      Event times in seconds. One column: single occurrences. Two columns:
      (onset, offset). A flat vector is treated as one column.
  window : sequence of two floats
      Seconds before (negative) and after each event to keep.
  dur : float, optional
      Two-column events whose duration is not greater than ``dur`` are
      dropped (``min_duration`` of ``adjust_intervals_durations``).
  iei : float, optional
      Two-column events separated by gaps not greater than ``iei`` are merged
      (``min_interval`` of ``adjust_intervals_durations``).
  avg_win : sequence of two floats, optional
      Window in seconds, relative to the event, whose mean is subtracted from
      each chunk. For two-column events the onset chunk's mean is subtracted
      from both the onset and the offset chunks.

  Returns
  -------
  dict
      {'onset': 2D array} for single events, or {'onset': ..., 'offset': ...}
      for onset/offset events. There is one row per event, and rows
      containing NaN are dropped. With no usable events the arrays have zero
      rows.

  Raises
  ------
  ValueError
      If ``event`` does not have one or two columns.
  '''
  period = find_avg_period(time_, time_format='total seconds')

  event = np.asarray(event)
  if event.ndim == 1:
    # A flat vector of timestamps is a one-column event array.
    event = event.reshape(-1, 1)
  if event.ndim != 2 or event.shape[1] not in (1, 2):
    raise ValueError('event must have one column (single events) or two columns (onset, offset)')

  # Remove events at the beginning and end of the test that are too close for a full window
  event = event[np.all(event > abs(window[0]), axis=1)]
  event = event[np.all(event < max(time_)-window[1], axis=1)]

  def _baseline_mean(chunk):
    # Mean of the chunk over avg_win; chunk index 0 is treated as time window[0].
    i0 = int((-window[0]+avg_win[0])/period)
    i1 = int((-window[0]+avg_win[1])/period)
    return chunk[i0:i1].mean()

  def _stack(chunks):
    # Stack the chunks into (n_events, n_samples) and drop rows containing NaN.
    if len(chunks) == 0:
      return np.empty((0, 0))
    arr = np.array(chunks).squeeze()
    if arr.ndim == 1:
      # A single event is squeezed to 1D; restore the row axis.
      arr = arr.reshape(1, len(arr))
    return arr[~np.isnan(arr).any(axis=1)]

  # Events with one occurence ---------------------------------------------------
  if event.shape[1] == 1:
    chunks = []
    for e in event:
      s_event = chunk_signal(signal, e, time_, window)
      if avg_win is not None:
        s_event = s_event - _baseline_mean(s_event)
      chunks.append(s_event)
    return {'onset': _stack(chunks)}

  # Events with onset and offset ------------------------------------------------
  # Merge short intervals and remove short durations (nothing to adjust when empty)
  if len(event) > 0:
    event = adjust_intervals_durations(event, iei, dur)

  onset_chunks = []
  offset_chunks = []
  for e0, e1 in event:
    s_onset = chunk_signal(signal, e0, time_, window)
    s_offset = chunk_signal(signal, e1, time_, window)
    # Normalize both chunks to the onset chunk's mean over avg_win
    if avg_win is not None:
      baseline = _baseline_mean(s_onset)
      s_onset = s_onset - baseline
      s_offset = s_offset - baseline
    onset_chunks.append(s_onset)
    offset_chunks.append(s_offset)

  return {'onset': _stack(onset_chunks),
          'offset': _stack(offset_chunks)}

In [None]:
def chunk_signal(signal, t0, t, w):
  '''
  Return the samples of ``signal`` inside window ``w`` around time ``t0``.

  Parameters
  ----------
  signal : array-like
      Trace sampled at ``t``.
  t0 : float
      Centre time in total seconds.
  t : array-like
      Timestamps of ``signal`` in total seconds.
  w : sequence of two floats
      Window in seconds relative to ``t0`` (``w[0]`` is negative).

  Returns
  -------
  array-like
      Slice of ``signal`` around the sample nearest ``t0``. If the window
      extends past either end of ``signal``, the missing samples are NaN.
      Every chunk for the same window and period therefore has the same
      length.
  '''
  idx = find_idx(t0, t, 'total seconds')
  period = find_avg_period(t, 'total seconds')

  i0 = idx + int(w[0]/period - 1/2)
  i1 = idx + int(w[1]/period + 1/2) + 1

  n = len(signal)
  if 0 <= i0 and i1 <= n:
    return signal[i0:i1]

  # The window runs past the recording. A plain slice would be shorter, or for
  # a negative i0 would wrap around to the end, so fill the gap with NaN.
  chunk = np.full(i1 - i0, np.nan)
  lo, hi = max(i0, 0), min(i1, n)
  if hi > lo:
    chunk[lo-i0:hi-i0] = np.asarray(signal[lo:hi], dtype=float)
  return chunk

In [None]:
def create_centered_time_vector(period=0.1,window=[-5.0,5.0]):
  '''
  Build an ascending time axis sampled every ``period`` seconds that contains 0 exactly.

  Parameters
  ----------
  period : float
      Sampling period in seconds.
  window : sequence of two floats
      Start (negative) and end of the axis in seconds.

  Returns
  -------
  numpy.ndarray
      Times from about ``window[0]`` to about ``window[1]``.
  '''
  half_step = period / 2
  before_zero = np.flip(np.arange(-period, window[0] - half_step, -period))
  from_zero = np.arange(0, window[1] + half_step, period)
  return np.concatenate([before_zero, from_zero])

In [None]:
def plot_perievents(Array,period=0.10,
                    Array1=None,period1=None,
                    window=[-5.0,5.0],color='green',
                    figtitle=None,figsize=None,
                    save=False,save_path='',image_format='.pdf'):
  '''
  Plot peri-event traces with their mean and SEM, one subplot per key.

  Parameters
  ----------
  Array : dict
      Key (e.g. 'onset', 'offset') -> 2D array (n_events, n_samples).
  period : float
      Sampling period of ``Array`` in seconds.
  Array1 : dict, optional
      Second measure with the same keys, drawn in black on a twin y-axis.
  period1 : float, optional
      Sampling period of ``Array1``; defaults to ``period``.
  window : sequence of two floats
      Peri-event window in seconds (x-limits).
  color : str
      Colour for ``Array``.
  figtitle : str, optional
      Figure title; also used to build the file name when saving.
  figsize : tuple, optional
      Defaults to (12, 10) for one key and (20, 10) otherwise.
  save : bool
      If True, save to ``save_path`` + name + ``image_format``.
  '''
  from matplotlib import gridspec

  def _mean_sem(arrays):
    # Per key: column-wise nanmean, and nanstd / sqrt(number of rows).
    means, sems = {}, {}
    for key in arrays:
      means[key] = np.nanmean(arrays[key], axis=0)
      sems[key] = np.nanstd(arrays[key], axis=0) / np.sqrt(arrays[key].shape[0])
    return means, sems

  Mean, Error = _mean_sem(Array)

  # y-limits: at least [-2, 2], extended to 1.1x the extreme trace values.
  ymax = 2
  ymin = -2
  for key in Array:
    ymax = max(ymax, 1.1*Array[key].max())
    ymin = min(ymin, 1.1*Array[key].min())

  ts = create_centered_time_vector(period,window)

  if Array1 is not None:
    if period1 is None:
      # create_centered_time_vector cannot take None; assume a shared sampling period.
      period1 = period
    Mean1, Error1 = _mean_sem(Array1)
    ymax1 = max(Array1[key1].max() for key1 in Array1)
    ymin1 = min(Array1[key1].min() for key1 in Array1)
    std1 = np.nanstd([np.nanstd(Array1[key1]) for key1 in Array1])
    ts1 = create_centered_time_vector(period1,window)

  if figsize is None:
    figsize = (12,10) if len(Array) == 1 else (20,10)

  fig = plt.figure(figsize=figsize)
  gs = gridspec.GridSpec(1, len(Array))
  for i,key in enumerate(Array):
    ax = fig.add_subplot(gs[i])
    ax.set_title(key, fontsize='xx-large')
    ax.plot(ts,Array[key].T,color=color,alpha=0.5,linewidth=1)
    ax.plot(ts,Mean[key],color=color,linewidth=2)
    ax.fill_between(ts, Mean[key]-Error[key],Mean[key]+Error[key],
                    alpha=0.3,edgecolor=color,facecolor=color,linewidth=0)
    ax.axvline(0,linestyle='--',color='black',linewidth=1.5)
    ax.set_xlabel('time', fontsize='xx-large', multialignment='center')
    ax.set_ylabel('z dF/F', fontsize='xx-large', multialignment='center')
    ax.set_ylim([ymin,ymax])
    ax.set_xlim(window)
    ax.tick_params(labelsize='x-large')
    if Array1 is not None:
      ax_m = ax.twinx()
      ax_m.plot(ts1,Array1[key].T,color='black',alpha=0.5,linewidth=1)
      ax_m.plot(ts1,Mean1[key],color='black',linewidth=2)
      ax_m.fill_between(ts1, Mean1[key]-Error1[key], Mean1[key]+Error1[key],
                    alpha=0.3,edgecolor='black',facecolor='black',linewidth=0)
      ax_m.set_ylim([ymin1, ymax1 + std1])
      ax.set_zorder(ax_m.get_zorder()+1) # put ax in front of ax_m
      ax.patch.set_visible(False) # hide the 'canvas'
  # Title
  if figtitle is not None:
    fig.suptitle(figtitle, fontsize='xx-large')

  # Save figure
  if save:
    if figtitle is None:
      imgname = 'mean' + image_format
    else:
      imgname = figtitle.replace(' ','_') +'_mean' + image_format
    fig.savefig(save_path+imgname)

# OS helper functions

### Find files whose names contain all of the given substrings

In [None]:
def contains(name, strings):
  '''
  Return True if every substring in ``strings`` occurs in ``name``.

  An empty ``strings`` returns True.
  '''
  return all(string in name for string in strings)


def find_files(folder,strings):
  '''
  List the entries of ``folder`` whose names contain every string in ``strings``.

  Parameters
  ----------
  folder : str
      Directory to search (not recursive).
  strings : iterable of str
      Substrings that must all be present in an entry name.

  Returns
  -------
  list of str
      Matching entry names (not full paths). They are sorted so that the
      result does not depend on the arbitrary order of ``os.listdir``.
  '''
  return sorted(file_name for file_name in os.listdir(folder)
                if contains(file_name, strings))

### Create folder if it doesn't exist

In [None]:
def create_folder(new_folder):
  '''
  Create ``new_folder``, including any missing parent folders, if it does not exist.

  Calling it on an existing directory is a no-op. Because ``exist_ok=True``
  is used, there is no race between checking for the folder and creating it.
  Raises ``FileExistsError`` if the path exists but is not a directory.
  '''
  os.makedirs(new_folder, exist_ok=True)

# Time helper functions

### Create time strings of format 'HH:MM:SS.ms' from lists of hours, minutes, seconds, and milliseconds

In [None]:
def create_realtime(hh,mm,ss,ms):
  '''
  Build 'H:MM:SS.mmm' time strings from hour, minute, second and millisecond values.

  Parameters
  ----------
  hh : int or sequence of int
      Either the starting hour (a scalar) or one hour value per sample. For a
      scalar, the hour is incremented every time the minutes decrease
      (wrap around).
  mm, ss, ms : sequence of numbers
      Minutes, seconds and milliseconds per sample. They are not modified.

  Returns
  -------
  list of str
      e.g. '10:05:03.045'. Minutes and seconds are zero-padded to two digits
      and milliseconds to three; the hour is not padded. All values are
      truncated to integers.
  '''
  if np.ndim(hh) == 0:
    # Scalar start hour: add one hour at every point where the minutes decrease.
    rollover = np.concatenate([[0], np.cumsum(np.diff(mm) < 0)])
    hh = int(hh) + rollover

  # Build new lists instead of overwriting the caller's mm/ss/ms in place.
  hours = [str(int(h)) for h in hh]
  minutes = ['%02d' % int(m) for m in mm]
  seconds = ['%02d' % int(s) for s in ss]
  millis = ['%03d' % int(x) for x in ms]

  return [h+':'+m+':'+s+'.'+x for h,m,s,x in zip(hours,minutes,seconds,millis)]

### Transform real time 'HH:MM:SS.ms' to total seconds

In [None]:
def time_to_seconds(t, t0=None):
  '''
  Convert time values (e.g. 'HH:MM:SS.ms' strings) to seconds.

  Parameters
  ----------
  t : iterable
      Values accepted by ``pd.Timedelta``.
  t0 : optional
      Reference time, also accepted by ``pd.Timedelta``. If given, it is
      subtracted from every element.

  Returns
  -------
  numpy.ndarray
      Seconds as floats, one per element of ``t``.
  '''
  def _seconds(x):
    delta = pd.Timedelta(x)
    if t0 is not None:
      delta = delta - pd.Timedelta(t0)
    return delta.total_seconds()

  return np.array([_seconds(x) for x in t])

### Find the average period of a recording

In [None]:
def find_avg_period(t, time_format='total seconds'):
  '''
  Return the typical sampling period of a time vector.

  The period is the median of consecutive differences, rounded to 10 decimals.

  Parameters
  ----------
  t : array-like
      Timestamps. With ``time_format='real time'`` they are converted with
      ``time_to_seconds`` relative to ``t[0]``; otherwise they are numbers
      in seconds.
  time_format : str
      'real time' or 'total seconds'.

  Returns
  -------
  float
      Median period.
  '''
  seconds = time_to_seconds(t, t[0]) if time_format == 'real time' else t
  return round(np.median(np.diff(seconds)), 10)

### Find index based on time

In [None]:
def find_idx(t, time_vector, time_format='total seconds'):
  '''
  Return the index of the element of ``time_vector`` closest to ``t``.

  Parameters
  ----------
  t : float or str
      Target time: a number of seconds or, with ``time_format='real time'``,
      a value accepted by ``pd.Timedelta`` (e.g. 'HH:MM:SS.ms').
  time_vector : array-like
      Timestamps in the same format as ``t``.
  time_format : str
      'total seconds' (default) or 'real time'.

  Returns
  -------
  numpy integer
      Index of the nearest timestamp (the first one on ties).
  '''
  if time_format == 'real time':
    # Convert both the vector and the target to seconds since 00:00 so they
    # share one reference. No separate start time is needed.
    time_vector = np.array([pd.Timedelta(x).total_seconds() for x in time_vector])
    t = pd.Timedelta(t).total_seconds()
  else:
    time_vector = np.asarray(time_vector)

  return (np.abs(time_vector-t)).argmin()

# Behavior events helper functions

### Get vector of values from Med Associates file

In [None]:
def values_from_medfile(file_name,letter):
  '''
  Read a header field or an array of values from a Med Associates (MED-PC) text file.

  Parameters
  ----------
  file_name : str
      Path to the MED-PC data file.
  letter : str
      Either a header field name ('File', 'Start Date', 'End Date',
      'Subject', 'Experiment', 'Group', 'Box', 'Start Time', 'End Time',
      'MSN') or a single upper-case array letter such as 'C'.

  Returns
  -------
  str or list of float
      For header fields, the text after '<field>: '. Otherwise, every number
      listed under '<letter>:' up to the next array letter, with row indices
      such as '5:' skipped.

  Raises
  ------
  ValueError
      If the header field or array letter is not found in the file.
  '''
  import string
  med_letters = [l + ':' for l in string.ascii_uppercase]

  # Get info from the file (closed even if reading fails)
  with open(file_name, 'r') as file:
    lines = [line.rstrip('\n') for line in file]

  header_fields = ['File','Start Date','End Date','Subject','Experiment',
                   'Group','Box','Start Time','End Time','MSN']
  if letter in header_fields:
    for line in lines:
      if letter+':' in line:
        return line.replace(letter+': ','')
    raise ValueError('Header field %r not found in %s' % (letter, file_name))

  # Get the indices of start and end lines.
  # NOTE(review): the search starts at line 13, presumably to skip the header
  # block (where e.g. 'N:' would match 'MSN:'). An array starting before
  # line 13 is therefore not found; confirm this against real files.
  start_line = None
  for idx in range(13, len(lines)):
    if letter+':' in lines[idx]:
      start_line = idx
      break
  if start_line is None:
    raise ValueError('Array %r not found in %s' % (letter, file_name))

  # The array ends at the next line that starts with another array letter.
  end_line = len(lines)
  for idx in range(start_line+1, len(lines)):
    if lines[idx][:2] in med_letters:
      end_line = idx
      break

  # Create a list of values: strip array letters, skip 'N:' row indices
  values = []
  for line in lines[start_line:end_line]:
    for med_letter in med_letters:
      if med_letter in line:
        line = line.replace(med_letter,'')
    for element in line.split(' '):
      if element != '' and ':' not in element:
        values.append(float(element))

  return values

### Find onsets and offsets based on binary vector

In [None]:
def event_onoffset(vector,time_=None,t0=None,time_format='total seconds'):

  '''
  Finds event onsets and offsets based on the binary vector

  input
    vector: binary vector (0/1 or bool), where 1 means event was present and 0 - no event;
    time_: time, the same length as vector; if None, sample indices are returned;
    t0: start time, should be in the same format as time_; it is subtracted from time_;
    time_format: 'real time' is format of 'HH:MM:SS.ms',
                 'total seconds' is a vector of numbers

  output
    event: numpy array of shape (n_events, 2) with event onsets in the first
           column and offsets (last sample of each event) in the second.
           An event already running at the first sample starts at index 0;
           an event still running at the last sample ends at the last index.

  raises
    ValueError: if time_ is given with an unknown time_format
  '''
  # Float copy: np.diff of a bool array flags every change as True instead of
  # giving +1/-1, which would turn offsets into onsets.
  vector = np.asarray(vector, dtype=float)

  # Rising edges (0 -> 1) start an event at the next sample; falling edges
  # (1 -> 0) end it at the current sample.
  diff_v = np.diff(vector)
  onset = [i+1 for i in np.flatnonzero(diff_v > 0)]
  offset = [i for i in np.flatnonzero(diff_v < 0)]

  # Close events cut by the recording edges. Checking the edge samples directly
  # also covers a vector that starts AND ends inside an event, where the
  # onset/offset counts are equal but the pairs are misaligned.
  if len(vector) > 0:
    if vector[0] > 0:
      onset = [0] + onset
    if vector[-1] > 0:
      offset = offset + [len(vector)-1]

  if time_ is not None:
    # Create list of time in total seconds if it was in the format of real time
    if time_format == 'real time':
      total_sec = time_to_seconds(time_,t0=t0)
    elif time_format == 'total seconds':
      total_sec = list(time_)
      if t0 is not None:
        total_sec = [t-t0 for t in total_sec]
    else:
      raise ValueError("time_format must be 'real time' or 'total seconds'")

    # Convert indices to time
    onset = [total_sec[i] for i in onset]
    offset = [total_sec[i] for i in offset]

  onset = np.array(onset).reshape(len(onset),1)
  offset = np.array(offset).reshape(len(offset),1)

  return np.concatenate([onset,offset],axis=1)

### Remove events with short inter event intervals and/or short duration

In [None]:
def adjust_intervals_durations(event, min_interval=None, min_duration=None):
  '''
  Merge events separated by short gaps and drop short events.

  Parameters
  ----------
  event : array-like
      (n_events, 2) onset/offset pairs. A single pair may be given as a 1D
      array of length 2.
  min_interval : float, optional
      Consecutive events whose gap (next onset - previous offset) is not
      greater than this are merged into one event.
  min_duration : float, optional
      Events whose duration (offset - onset) is not greater than this are
      removed. This is applied after merging.

  Returns
  -------
  numpy.ndarray
      (n_kept, 2) array of onset/offset pairs.
  '''
  event = np.asarray(event)
  if event.ndim == 1:
    event = event.reshape(-1, 2)

  if min_interval is not None and len(event) > 0:
    onset = event[:,0]
    offset = event[:,1]
    keep_gap = (onset[1:] - offset[:-1]) > min_interval
    # A kept gap ends one event and starts the next; merged gaps disappear.
    onset = np.concatenate([onset[:1], onset[1:][keep_gap]])
    offset = np.concatenate([offset[:-1][keep_gap], offset[-1:]])
    event = np.column_stack([onset, offset])

  if min_duration is not None:
    event = event[(event[:,1] - event[:,0]) > min_duration]

  return event

### Find speed from coordinates and time

In [None]:
def get_speed(x,y,t,smooth_filter=None,smooth_parameter=1):
  '''
  Compute speed from x/y coordinates and timestamps.

  The speed at each sample is the distance between its neighbours divided by
  their time difference (a central difference). The first and last samples
  use a one-sided difference.

  Parameters
  ----------
  x, y : array-like
      Coordinates, one per timestamp.
  t : array-like
      Timestamps in seconds.
  smooth_filter : {None, 'moving average', 'low-pass'}
      Optional smoothing of the speed trace.
  smooth_parameter : float
      For 'moving average', the window length in seconds. For 'low-pass', the
      value passed to ``butter_lowpass_filter`` (presumably the cut-off
      frequency in Hz).

  Returns
  -------
  numpy.ndarray
      Speed in coordinate units per second.

  Raises
  ------
  ValueError
      If ``smooth_filter`` is not one of the supported values.
  '''
  if smooth_filter not in (None, 'moving average', 'low-pass'):
    raise ValueError("smooth_filter must be None, 'moving average' or 'low-pass'")

  x = np.asarray(x, dtype=float)
  y = np.asarray(y, dtype=float)
  t = np.asarray(t, dtype=float)

  # Neighbour indices, clamped at both ends of the recording.
  idx = np.arange(len(x))
  i_next = np.minimum(idx + 1, len(x) - 1)
  i_prev = np.maximum(idx - 1, 0)

  dl = np.sqrt((x[i_next]-x[i_prev])**2 + (y[i_next]-y[i_prev])**2)
  dt = np.abs(t[i_next]-t[i_prev])
  v = dl / dt

  # Smooth speed (previously called the undefined find_average)
  if smooth_filter is not None:
    T = find_avg_period(t)
    if smooth_filter == 'moving average':
      v = smooth_signal(v, window_len=int(smooth_parameter/T))
    elif smooth_filter == 'low-pass':
      v = butter_lowpass_filter(v, smooth_parameter, 1/T, order=10)

  return v

In [None]:
def remove_outliers(v,t,percentile=0.995):
  '''
  Remove samples of ``v`` above a quantile, together with their timestamps.

  If the first (or last) sample is removed, the new first (or last) value is
  repeated at the original first (or last) timestamp. The returned time span
  therefore still covers the original one.

  Parameters
  ----------
  v : array-like
      Values (e.g. speed).
  t : array-like
      Timestamps, same length as ``v``.
  percentile : float
      Quantile in [0, 1] passed to ``np.quantile``. Samples with values
      strictly greater than that quantile of ``v`` are removed.

  Returns
  -------
  (numpy.ndarray, numpy.ndarray)
      Filtered values and timestamps.
  '''
  v = np.asarray(v)
  t = np.asarray(t)

  q = np.quantile(v,percentile)
  # flatnonzero is always 1D. argwhere(...).squeeze() became 0-d for a single
  # outlier (or empty for none), and the indexing below then failed.
  outliers = np.flatnonzero(v > q)
  v_new = np.delete(v, outliers)
  t_new = np.delete(t, outliers)

  if outliers.size == 0:
    return v_new, t_new

  if outliers[0] == 0:
    v_new = np.insert(v_new, 0, v_new[0])
    t_new = np.insert(t_new, 0, t[0])

  if outliers[-1] == t.size - 1:
    v_new = np.insert(v_new, v_new.size, v_new[-1])
    t_new = np.insert(t_new, t_new.size, t[-1])

  return v_new, t_new

### Find immobility onsets and offsets

In [None]:
def find_onoffset_immobility(movement,time_, min_duration=1, on_threshold=0.1, off_threshold=0.15,output='time'):
  '''
  Detect immobility bouts in a movement trace using two thresholds.

  A bout starts at a sample where ``movement`` drops below ``on_threshold``
  and stays below it for ``min_duration`` (in the units of ``time_``). It
  then continues while ``movement`` stays below ``off_threshold``. Each
  offset is afterwards moved back to the sample with the lowest movement
  among the last ``win`` samples of the bout (``win`` ~ ``min_duration`` /
  sampling period). Bouts separated by gaps shorter than ``min_duration``
  are merged.

  Parameters
  ----------
  movement : numpy.ndarray
      Movement trace. Samples up to the last NaN are skipped. It must be an
      array, because ``movement[start:i] < on_threshold`` is compared
      element-wise.
  time_ : array-like
      Timestamps for ``movement``. With ``output='index'`` it must support
      integer-list indexing (e.g. a numpy array).
  min_duration : float
      Minimum bout length and minimum gap between bouts.
  on_threshold : float
      Movement below this value can start a bout.
  off_threshold : float
      A bout continues while movement stays below this value.
  output : {'time', 'index'}
      Return onsets/offsets as timestamps or as sample indices.

  Returns
  -------
  numpy.ndarray
      (n_bouts, 2) array of [onset, offset] pairs.
  '''

  # Start scanning just after the last NaN, if there is one.
  if np.isnan(movement).any():
    i = np.max(np.argwhere(np.isnan(movement)))+1
  else: i = 0

  onset = []
  offset = []

  while i < len(movement)-1:

    if movement[i] < on_threshold:
      # Candidate bout: advance i until min_duration has elapsed or the trace ends.
      start = i
      t0 = time_[start]
      while time_[i] - t0 < min_duration:
        i += 1
        if i >= len(movement)-1:
          break
      
      # Accept the bout only if every sample in [start, i) is below on_threshold.
      # NOTE(review): if this check fails, scanning resumes after i, so a bout
      # starting between start+1 and i is never tested. Confirm this is intended.
      if (movement[start:i] < on_threshold).all():
        if output=='time': onset.append(time_[start])
        elif output=='index': onset.append(start)

        # Extend the bout while movement stays below the off threshold.
        while movement[i] < off_threshold:
          i += 1
          if i >= len(movement)-1:
            break

        # The offset is sample i-1: the one before the first sample at or above
        # off_threshold, or before the end of the trace.
        if output=='time': offset.append(time_[i-1])
        elif output=='index': offset.append(i-1)
        
    i += 1 

 # Adjust mobility onset (= immobility offset): move each offset back to the
 # sample with the lowest movement among the `win` samples ending at it.
 # NOTE(review): when i1 < win, the negative slice stop wraps around and the
 # slice can be empty, which makes np.argmin raise. Confirm bouts never end that early.
  T = find_avg_period(time_)
  win = int(round(min_duration/T))
  for i in range(len(offset)):
    if output=='time': i1 = find_idx(offset[i],time_)
    elif output=='index': i1 = offset[i]
    off = i1 - np.argmin(movement[i1:i1-win:-1])
    if output=='time': offset[i] = time_[off]
    elif output=='index': offset[i] = off

  # Merge bouts separated by gaps shorter than min_duration: drop the later
  # bout's onset and the earlier bout's offset.
  if output=='time' : intervals = np.array(onset[1:]) - np.array(offset[:-1])
  elif output=='index': intervals = np.array(time_[onset[1:]]) - np.array(time_[offset[:-1]])
  remove = [i for i in range(len(intervals)) if intervals[i] < min_duration]
  onset = np.delete(onset,[i+1 for i in remove])
  offset = np.delete(offset,remove) 

  immobility = np.array([onset,offset]).T

  return immobility

In [None]:
def convert_onoffset(onoff, end, start=0):
  '''
  Return the complementary intervals (gaps) of a set of onset/offset pairs.

  For bouts [[on0, off0], [on1, off1], ...] within [start, end] this returns
  [[start, on0], [off0, on1], ..., [off_last, end]]. The first gap is omitted
  when ``on0 == start``, and the last gap when ``off_last == end``.

  Parameters
  ----------
  onoff : numpy.ndarray
      (n, 2) array of onset/offset pairs, assumed sorted and non-overlapping.
  end : float
      End of the recording.
  start : float
      Start of the recording.

  Returns
  -------
  numpy.ndarray
      (m, 2) array of gaps with the dtype of ``onoff``. With no bouts, the
      whole span [[start, end]] is returned.
  '''
  onoff = np.asarray(onoff)
  if onoff.ndim == 1:
    onoff = onoff.reshape(-1, 2)
  if len(onoff) == 0:
    # With no bouts the whole recording is a single gap.
    return np.array([[start, end]], dtype=onoff.dtype)

  # Interleave the boundaries (start, on0, off0, on1, ..., off_last, end) and
  # pair them up as [start, on0], [off0, on1], ..., [off_last, end].
  bounds = np.empty(2*len(onoff) + 2, dtype=onoff.dtype)
  bounds[0] = start
  bounds[1:-1] = onoff.ravel()
  bounds[-1] = end
  offon = bounds.reshape(-1, 2)

  # Drop the zero-length gaps at either edge.
  first = 1 if onoff[0,0] == start else 0
  last = len(offon) - 1 if onoff[-1,1] == end else len(offon)
  return offon[first:last]

# Plot helper functions

### Make patch spines invisible

In [None]:
def make_patch_spines_invisible(ax):
  '''
  Keep the frame of ``ax`` on, but hide its background patch and every spine.

  Callers can then re-enable individual spines, e.g. the right one.
  '''
  ax.set_frame_on(True)
  ax.patch.set_visible(False)
  for name in ax.spines:
    ax.spines[name].set_visible(False)

### Plot colors

In [None]:
def get_cmap(n, name='gist_rainbow'):
  '''
  Return the matplotlib colormap ``name`` resampled to ``n`` colours.

  Calling the result with an integer k in 0..n-1 gives the k-th of n
  distinct RGBA colours along the colormap.
  '''
  base = mpl.colormaps.get_cmap(name)
  return base.resampled(n)

### Plot means

In [None]:
def plot_means(array,T=0.1,w=[-5,5],auc=None,plot_type='individual traces',plot_event=None,
              title='',periods=['baseline', 'event'],figsize=(3.5,3),
              color='green',ylab='z-dF/F',xlim=[-5,5],ylim=[-1,5],yticks=None,
              array1=None,auc1=None,T1=0.1,title1='AUC',
              color1='magenta',ylab1='Mobility score',ylim1=[-1,1],yticks1=None,
              subplot_ratio=[5,2],save=False,imgname='means.pdf'):
  '''
  Plot per-animal peri-event traces and, optionally, area-under-curve values.

  Parameters
  ----------
  array : numpy.ndarray
      (n_animals, n_samples) traces sampled every ``T`` seconds over window ``w``.
  T : float
      Sampling period of ``array`` in seconds.
  w : sequence of two floats
      Peri-event window in seconds.
  auc : numpy.ndarray, optional
      (n_animals, len(periods)) values drawn in a side panel.
  plot_type : {'individual traces', 'mean and SEM', 'mixed'}
      Individual lines, mean with SEM band, or both.
  plot_event : sequence of two floats, optional
      With two values, that span is shaded. If None, a dashed line is drawn at 0.
  title, color, ylab, xlim, ylim, yticks :
      Settings for the ``array`` panel.
  periods : list of str
      x tick labels of the AUC panel.
  array1, auc1, T1, color1, ylab1, ylim1, yticks1 :
      Optional second measure drawn on twin y-axes. ``auc1`` requires ``auc``.
  title1 : str
      Currently unused.
  subplot_ratio : sequence of two numbers
      Width ratio of the trace and AUC panels.
  save : bool
      If True, save the figure to ``imgname``.

  Raises
  ------
  ValueError
      If ``auc1`` is given without ``auc``.
  '''
  import seaborn as sns
  from matplotlib import gridspec

  if auc1 is not None and auc is None:
    # The auc1 panel is drawn as a twin of the auc panel.
    raise ValueError('auc1 can only be plotted together with auc')

  if plot_type in ['mean and SEM','mixed']:
    mean,error = calculate_mean_error(array)
    if auc is not None:
      auc_m,auc_e = calculate_mean_error(auc)
    if array1 is not None:
      mean1,error1 = calculate_mean_error(array1)
    if auc1 is not None:
      auc_m1,auc_e1 = calculate_mean_error(auc1)

  # Create time vector
  t = create_centered_time_vector(T,w)

  fig = plt.figure(figsize=figsize)
  if auc is not None:
    gs = gridspec.GridSpec(1, 2, width_ratios=subplot_ratio)
    ax = fig.add_subplot(gs[0])
  else:
    ax = fig.add_subplot()
  ax.set_title(title,size='x-large')
  ax.set_ylabel(ylab,size='large')
  ax.set_xlabel('Time (s)',size='large')
  ax.patch.set_visible(False)
  ax.set_xlim(xlim)
  ax.set_ylim(ylim)
  if yticks is not None:
    ax.set_yticks(yticks)
  # Plot mean trace for each mouse.
  # x/y are passed as keywords: seaborn >= 0.12 rejects positional x and y.
  sns.set(style="ticks")
  sns.despine(ax=ax)
  if plot_type=='individual traces':
    for mouse in range(len(array)):
      sns.lineplot(x=t,y=array[mouse,:],color=color,ax=ax)
  if plot_type=='mean and SEM':
    ax.plot(t,mean,color=color)
    ax.fill_between(t,mean-error,mean+error,alpha=0.3,edgecolor=color,facecolor=color,linewidth=0)
  elif plot_type=='mixed':
    for mouse in range(len(array)):
      sns.lineplot(x=t,y=array[mouse,:],color='black',alpha=0.2,ax=ax)
    ax.plot(t,mean,color=color)
    ax.fill_between(t,mean-error,mean+error,alpha=0.3,edgecolor=color,facecolor=color,linewidth=0)

  if plot_event is not None:
    if len(plot_event) == 2:
      ax.axvspan(plot_event[0],plot_event[1],color='blue',alpha=0.2)
  else:
    ax.axvline(0,linestyle='--',color='black')

  if array1 is not None:
    ax.tick_params(axis='y',colors=color)
    ax.yaxis.label.set_color(color)
    ax.spines['left'].set_color(color)

  # Plot values for area under the curve
  if auc is not None:
    x = np.arange(1,2*len(periods),2)
    ax1 = fig.add_subplot(gs[1])
    sns.despine(ax=ax1)
    if plot_type=='individual traces':
      for mouse in range(len(auc)):
        sns.lineplot(x=x,y=auc[mouse,:],color=color,ax=ax1)
    if plot_type=='mean and SEM':
      ax1.errorbar(x,auc_m,auc_e,color=color)
    if plot_type=='mixed':
      ax1.plot(x,auc.T,color='black',alpha=0.2)
      ax1.errorbar(x,auc_m.T,auc_e.T,color=color)
    # Fix the tick positions before labelling them so each label lands on its tick.
    ax1.xaxis.set_ticks(x)
    ax1.set_xticklabels(periods, rotation=50)
    ax1.set_xlim([0,2*len(periods)])
    ax1.set_ylim(ylim)
    if yticks is not None:
      ax1.set_yticks(yticks)
    ax1.set_yticklabels([])
    for label in ax1.get_xticklabels():
      label.set_horizontalalignment('center')
    if auc1 is not None:
      ax1.tick_params(axis='y',colors=color)
      ax1.yaxis.label.set_color(color)
      ax1.spines['left'].set_color(color)

  # Plot a second measure
  if array1 is not None:
    # Create time vector for the second measure
    t1 = create_centered_time_vector(T1,w)
    # Plot means for each animal
    ax2 = ax.twinx()
    sns.despine(ax=ax2,right=False)
    if plot_type=='individual traces':
      for mouse in range(len(array1)):
        sns.lineplot(x=t1,y=array1[mouse,:],color=color1,ax=ax2)
    elif plot_type=='mean and SEM':
      ax2.plot(t1,mean1,color=color1)
      ax2.fill_between(t1,mean1-error1,mean1+error1,alpha=0.3,edgecolor=color1,facecolor=color1,linewidth=0)
    elif plot_type=='mixed':
      for mouse in range(len(array1)):
        sns.lineplot(x=t1,y=array1[mouse,:],color='black',alpha=0.2,ax=ax2)
      ax2.plot(t1,mean1,color=color1)
      ax2.fill_between(t1,mean1-error1,mean1+error1,alpha=0.3,edgecolor=color1,facecolor=color1,linewidth=0)
    ax2.set_ylim(ylim1)
    if yticks1 is not None:
      ax2.set_yticks(yticks1)
    ax2.tick_params(axis='y',colors=color1)
    ax2.yaxis.label.set_color(color1)
    ax2.spines['right'].set_color(color1)
    ax2.spines['left'].set_visible(False)
    if auc1 is None:
      ax2.set_ylabel(ylab1,color=color1,size='large')
    else:
      ax2.set_yticklabels([])

  if auc1 is not None:
    ax3 = ax1.twinx()
    sns.despine(ax=ax1,right=False)
    if plot_type=='individual traces':
      for mouse in range(len(auc1)):
        sns.lineplot(x=x,y=auc1[mouse,:],color=color1,ax=ax3)
    if plot_type=='mean and SEM':
      ax3.errorbar(x,auc_m1,auc_e1,color=color1)
    if plot_type=='mixed':
      ax3.plot(x,auc1.T,color='black',alpha=0.2)
      ax3.errorbar(x,auc_m1.T,auc_e1.T,color=color1)
    sns.despine(ax=ax3,right=False)
    ax3.set_ylim(ylim1)
    if yticks1 is not None:
      ax3.set_yticks(yticks1)
    ax3.tick_params(axis='y',colors=color1)
    ax3.yaxis.label.set_color(color1)
    ax3.spines['right'].set_color(color1)
    ax3.spines['left'].set_visible(False)
    ax3.set_ylabel(ylab1,color=color1,size='large')
    for label in ax3.get_xticklabels():
      label.set_horizontalalignment('center')

  plt.tight_layout()

  if save:
    fig.savefig(imgname)

  return

In [None]:
def calculate_mean_error(data,error='SEM',confidence=0.95):
  '''
  Compute the mean and an error estimate along axis 0.

  Parameters
  ----------
  data : array-like
      Observations along axis 0 (e.g. animals x time points).
  error : {'SEM', 'CI'}
      Return the standard error of the mean, or the half-width of the
      Student-t confidence interval at ``confidence``.
  confidence : float
      Confidence level used for 'CI'.

  Returns
  -------
  (numpy.ndarray, numpy.ndarray)
      Mean and error along axis 0. Missing values (NaN) are ignored, and a
      notice is printed when any are present.

  Raises
  ------
  ValueError
      If ``error`` is neither 'SEM' nor 'CI'.
  '''
  from scipy import stats

  if error not in ('SEM', 'CI'):
    raise ValueError("error must be 'SEM' or 'CI', got %r" % (error,))

  a = 1.0 * np.array(data)
  missing = np.isnan(a)
  if missing.any():
    # np.mean never returns None, so test for NaNs directly and switch to
    # NaN-aware statistics with per-column sample counts.
    print('There are missing values.')
    n = np.sum(~missing, axis=0)
    m = np.nanmean(a, axis=0)
    se = np.nanstd(a, axis=0, ddof=1) / np.sqrt(n)
  else:
    n = len(a)
    m = np.mean(a, axis=0)
    se = stats.sem(a, axis=0)

  if error == 'SEM':
    return m, se
  return m, se * stats.t.ppf((1 + confidence) / 2., n-1)

### Plot example


In [None]:
def plot_example(dFF,time_,events=None,dFF1=None,dFF2=None,measurement=None,time_m=None,
                 color='purple',ylim=None,yticks=None,ylabel=None,xticks=None,
                 color1='blue',ylim1=None,yticks1=None,ylabel1=None,
                 color2='red',ylim2=None,yticks2=None,ylabel2=None,
                 color_m='magenta',ylim_m=None,yticks_m=None,
                 color_e='red',
                 figsize=(5,3),save=False,imgname='example.pdf'):
  '''
  Plot an example trace with optional extra traces, events and a measurement panel.

  Parameters
  ----------
  dFF : array-like
      Main trace, plotted against ``time_`` on the left y-axis.
  time_ : array-like
      Timestamps shared by ``dFF``, ``dFF1`` and ``dFF2``; the x-limits are
      set to its first and last values.
  events : array-like, optional
      Either a 1-D / single-column array of event times (drawn as dashed
      black vertical lines) or a two-column array of (start, end) pairs
      (drawn as shaded spans in ``color_e``).
  dFF1, dFF2 : array-like, optional
      Extra traces drawn on twin y-axes on the right; the ``dFF2`` axis is
      shifted outward so both stay readable.
  measurement, time_m : array-like, optional
      Measurement trace and its timestamps, drawn in a panel above the main
      axes (height ratio 1:3) with a 'Mobility score' y-label.
  color, color1, color2, color_m : color
      Colors of the respective traces and their axes.
  ylim, yticks, ylabel, xticks, ylim1, yticks1, ylabel1, ylim2, yticks2,
  ylabel2, ylim_m, yticks_m : optional
      Axis limits, ticks and labels for the respective axes; ignored when None.
      Without ``ylabel`` the main axis is labelled 'z dF/F' in black.
  figsize : tuple
      Figure size in inches.
  save : bool
      If True, save the figure to ``imgname``.
  imgname : str
      Output file name used when ``save`` is True.
  '''
  fig = plt.figure(figsize=figsize)
  sns.set(style='ticks')

  if measurement is not None:
    from matplotlib import gridspec
    gs = gridspec.GridSpec(2,1, height_ratios=[1,3])

    ax_m = fig.add_subplot(gs[0])
    # Target the measurement axes explicitly instead of relying on the current axes.
    sns.lineplot(x=time_m,y=measurement,color=color_m,ax=ax_m)

    ax_m.set_xticks([])
    make_patch_spines_invisible(ax_m)
    ax_m.spines['right'].set_visible(True)
    ax_m.yaxis.set_ticks_position('right')
    ax_m.yaxis.set_label_position('right')

    # Use the requested measurement color (previously hard-coded to 'magenta').
    ax_m.spines['right'].set_color(color_m)
    ax_m.set_ylabel('Mobility\n score',color=color_m)
    ax_m.tick_params(axis='y', colors=color_m)

    ax_m.set_xlim(time_m[0], time_m[-1])
    if ylim_m is not None:
      ax_m.set_ylim(ylim_m)
    if yticks_m is not None:
      ax_m.set_yticks(yticks_m)

    ax = fig.add_subplot(gs[1])

  else:
    ax = fig.add_subplot()

  sns.lineplot(x=time_,y=dFF,color=color,ax=ax)

  if events is not None:
    # Normalise to a 2-D ndarray so lists, 1-D arrays and DataFrames all work
    # (iterating a DataFrame directly would yield its column labels).
    events = np.asarray(events)
    if events.ndim == 1:
      events = events.reshape(-1,1)
    if events.shape[1]==1:
      for e0 in events[:,0]:
        ax.axvline(e0,linestyle='--',color='black')
    elif events.shape[1]==2:
      for e0,e1 in events:
        ax.axvspan(e0,e1,color=color_e,alpha=0.3)

  ax.spines['top'].set_visible(False)
  ax.spines['right'].set_visible(False)

  ax.spines['left'].set_color(color)
  ax.yaxis.label.set_color(color)
  ax.tick_params(axis='y',colors=color)

  ax.set_xlabel('Time (s)',fontsize='large')
  ax.set_ylabel('z dF/F',fontsize='large',color='black')

  ax.set_xlim(time_[0],time_[-1])
  if xticks is not None:
    ax.set_xticks(xticks)
  if ylim is not None:
    ax.set_ylim(ylim)
  if yticks is not None:
    ax.set_yticks(yticks)
  if ylabel is not None:
    ax.set_ylabel(ylabel,fontsize='large',color=color)

  if dFF1 is not None:
    ax1 = ax.twinx()
    sns.lineplot(x=time_,y=dFF1,color=color1,ax=ax1)

    if ylim1 is not None:
      ax1.set_ylim(ylim1)
    if yticks1 is not None:
      ax1.set_yticks(yticks1)
    if ylabel1 is not None:
      ax1.set_ylabel(ylabel1,fontsize='large',color=color1)

    make_patch_spines_invisible(ax1)
    ax1.spines['right'].set_visible(True)

    ax1.spines['right'].set_color(color1)
    ax1.yaxis.label.set_color(color1)
    ax1.tick_params(axis='y',colors=color1)

  if dFF2 is not None:
    ax2 = ax.twinx()
    sns.lineplot(x=time_,y=dFF2,color=color2,ax=ax2)

    if ylim2 is not None:
      ax2.set_ylim(ylim2)
    if yticks2 is not None:
      ax2.set_yticks(yticks2)
    if ylabel2 is not None:
      ax2.set_ylabel(ylabel2,fontsize='large',color=color2)

    # Push the second twin axis outward so it does not overlap the dFF1 axis.
    fig.subplots_adjust(right=0.9)
    ax2.spines['right'].set_position(('axes',1.15))

    make_patch_spines_invisible(ax2)
    ax2.spines['right'].set_visible(True)

    ax2.spines['right'].set_color(color2)
    ax2.yaxis.label.set_color(color2)
    ax2.tick_params(axis='y', colors=color2)

  plt.tight_layout()

  if save:
    fig.savefig(imgname)

### Plot perievent correlation

In [None]:
def plot_perievent_correlation_counts(mean,sem,labels,figsize=(3,3),
                                      save=False,imgname='correlationCounts.pdf'):
  '''
  Draw stacked bars of three fractions per group, with upward-only error bars.

  Row 0 of ``mean`` is drawn in green at the bottom, row 1 in red on top of
  it, and row 2 in grey on top of both. ``sem`` supplies the matching upper
  error for each row. The y-axis is labelled in percent (0.25 -> '25%').

  Parameters
  ----------
  mean, sem : numpy.ndarray
      Arrays with (at least) three rows and one column per group.
  labels : sequence of str
      Tick labels for the groups, rotated by 30 degrees.
  figsize : tuple
      Figure size in inches.
  save : bool
      If True, save the figure to ``imgname``.
  imgname : str
      Output file name used when ``save`` is True.

  Returns
  -------
  matplotlib.figure.Figure
  '''
  import seaborn as sns

  stacked_base = np.add(mean[0], mean[1])
  n_groups = mean.shape[1]
  positions = range(n_groups)
  no_lower_error = np.zeros(n_groups)

  # (row index, colour, bottom offset); drawn top layer first, as before.
  layers = (
    (2, 'grey', stacked_base),
    (1, 'red', mean[0]),
    (0, 'green', None),
  )

  sns.set(style="ticks")
  fig = plt.figure(figsize=figsize)
  ax = fig.add_subplot()
  for row, colour, base in layers:
    ax.bar(positions, mean[row], yerr=[no_lower_error, sem[row]],
           color=colour, edgecolor=colour, ecolor=colour, bottom=base)

  ax.set_yticks([0.25, 0.5, 0.75, 1])
  ax.set_yticklabels(['25%', '50%', '75%', '100%'])
  ax.set_xticks(positions)
  ax.set_xticklabels(labels, rotation=30, ha='right')
  sns.despine()
  plt.tight_layout()

  if save:
    fig.savefig(imgname)

  return fig

### Plot cross-correlation

In [None]:
def plot_crosscorrelation(lags,corr,plot_type='mean and SEM',
                          color='green',ylim=None,xlim=None,yticks=None,xticks=None,
                          ylim1=None,yticks1=None,
                          figsize=(4,2),save=False,imgname='cross-correlation.pdf'):
  '''
  Plot cross-correlation curves and a per-curve lag summary.

  The wide left panel shows the curves in ``corr`` against ``lags``, with a
  dashed line at lag 0. The narrow right panel shows, for each curve, the
  lag returned by ``get_midline`` on a smoothed copy of the curve (labelled
  'max lag (s)'), with a solid line at 0.

  Parameters
  ----------
  lags : array-like
      Lag values shared by all curves (axis labelled in seconds).
  corr : numpy.ndarray
      One cross-correlation curve per row (e.g. one per mouse).
  plot_type : {'mean and SEM', 'mean and CI', 'individual traces'}
      How curves and lags are summarised. Any other value leaves both panels
      without data.
  color : color
      Color of curves, shading and lag markers.
  ylim, xlim, yticks, xticks : optional
      Limits/ticks of the cross-correlation panel; ignored when None.
  ylim1, yticks1 : optional
      Limits/ticks of the lag panel; ignored when None.
  figsize : tuple
      Figure size in inches.
  save : bool
      If True, save the figure to ``imgname``.
  imgname : str
      Output file name used when ``save`` is True.
  '''
  from matplotlib import gridspec

  # Per-curve lag from get_midline on a smoothed curve.
  # NOTE(review): the 5 passed to smooth_signal is presumably a window length;
  # confirm against smooth_signal's definition.
  maxlag = np.zeros(len(corr))
  for i in range(len(corr)):
    sm_corr = list(smooth_signal(corr[i],5))
    maxlag[i] = get_midline(lags,sm_corr)

  # Map summary plot types onto calculate_mean_error's error kinds.
  error_kind = {'mean and SEM':'SEM','mean and CI':'CI'}.get(plot_type)

  sns.set(style="ticks")
  fig = plt.figure(figsize=figsize)
  gs = gridspec.GridSpec(1, 2, width_ratios=[5, 1])
  ax = fig.add_subplot(gs[0])

  if plot_type=='individual traces':
    for mouse in range(corr.shape[0]):
      # x/y must be keywords: seaborn >= 0.12 rejects them positionally.
      sns.lineplot(x=lags,y=corr[mouse],color=color,ax=ax)
  elif error_kind is not None:
    mean,error = calculate_mean_error(corr,error=error_kind)
    ax.plot(lags,mean,color=color)
    ax.fill_between(lags,mean-error,mean+error,alpha=0.3,edgecolor=color,facecolor=color,linewidth=0)

  ax.axvline(0,linestyle='--',color='black')
  if ylim is not None:
    ax.set_ylim(ylim)
  if xlim is not None:
    ax.set_xlim(xlim)
  if yticks is not None:
    ax.set_yticks(yticks)
  if xticks is not None:
    ax.set_xticks(xticks)
  ax.set_xlabel('lags (s)',size='large')
  ax.set_ylabel('correlation',size='large')

  ax1 = fig.add_subplot(gs[1])
  if plot_type=='individual traces':
    ax1.plot(np.ones(len(maxlag)),maxlag,'o',color=color)
  elif error_kind is not None:
    m_maxlag,e_maxlag = calculate_mean_error(maxlag,error=error_kind)
    ax1.errorbar(1,m_maxlag,e_maxlag,color=color)
    ax1.plot(1,m_maxlag,'o',color=color)
  ax1.set_ylabel('max lag (s)',size='large')
  ax1.set_xticks([])
  ax1.spines['bottom'].set_visible(False)
  ax1.axhline(0,color='black')
  if ylim1 is not None:
    ax1.set_ylim(ylim1)
  if yticks1 is not None:
    ax1.set_yticks(yticks1)

  sns.despine()
  plt.tight_layout()

  if save:
    fig.savefig(imgname)

In [None]:
def get_midline(x,y):
  '''
  Return the x value at which the area under ``y`` is split in half.

  The area is accumulated as a left Riemann sum: sample ``i`` contributes
  ``y[i]*(x[i+1]-x[i])``. The function finds the first sample whose
  contribution pushes the running area above half of the total. It then
  returns that sample's x or the previous sample's x, whichever running
  area is closer to the half-way point.

  Parameters
  ----------
  x : sequence of numbers
      Sample positions (e.g. lags). Assumed to be the same length as ``y``.
  y : sequence of numbers
      Sample values.

  Returns
  -------
  The element of ``x`` at the chosen index.

  Raises
  ------
  ValueError
      If the running area never exceeds half of the total (e.g. ``y`` is all
      zeros, or fewer than two samples are given).
  '''
  n_segments = len(x)-1

  total_sum = 0
  for i in range(n_segments):
    total_sum += y[i]*(x[i+1]-x[i])
  half = total_sum/2

  left_sum = 0
  for i in range(n_segments):
    segment = y[i]*(x[i+1]-x[i])
    left_sum += segment
    if left_sum > half:
      overshoot = left_sum - half
      # Distance of the previous running area below half. The parentheses
      # matter: the earlier form subtracted `segment` instead of adding it.
      undershoot = half - (left_sum - segment)
      if overshoot < undershoot:
        middle = i
      else:
        # Clamp so a crossing in the first segment cannot wrap to x[-1].
        middle = max(i-1, 0)
      return x[middle]

  raise ValueError('the cumulative area of y never exceeds half of its total')

In [None]:
# def plot_crosscorrelation(lags,corr,plot_type='mean and SEM',
#                           color='green',ylim=None,xlim=None,yticks=None,xticks=None,
#                           figsize=(3,2),save=False,imgname='cross-correlation.pdf'):
  
#   sns.set(style="ticks")
#   fig = plt.figure(figsize=figsize)
#   ax = fig.add_subplot()

#   if plot_type=='individual traces':
#     for mouse in range(corr.shape[0]):
#       sns.lineplot(lags,corr[mouse],color=color)
#   if plot_type=='mean and SEM':
#     mean,error = calculate_mean_error(corr)
#     ax.plot(lags,mean,color=color)
#     ax.fill_between(lags,mean-error,mean+error,alpha=0.3,edgecolor=color,facecolor=color,linewidth=0)
#   if plot_type=='mean and CI':
#     mean,error = calculate_mean_error(corr,error='CI')
#     ax.plot(lags,mean,color=color)
#     ax.fill_between(lags,mean-error,mean+error,alpha=0.3,edgecolor=color,facecolor=color,linewidth=0)

#   ax.axvline(0,linestyle='--',color='black')
#   if ylim is not None:
#     ax.set_ylim(ylim)
#   if xlim is not None:
#     ax.set_xlim(xlim)
#   if yticks is not None:
#     ax.set_yticks(yticks)
#   if xticks is not None:
#     ax.set_xticks(xticks)
#   ax.set_xlabel('lags (s)')
#   ax.set_ylabel('correlation')
#   sns.despine()
#   plt.tight_layout()

#   if save==True:
#     fig.savefig(imgname)

### Plot measures over tests

In [None]:
def plot_boxplot(groups,values,color='green',ylabel=None,ylim=None,yticks=None,
                 figsize=(3,3),save=False,imgname='boxplot.pdf'):
  '''
  Draw a box plot of ``values`` split by ``groups``, with every point overlaid.

  Parameters
  ----------
  groups : array-like
      Group label of each value (x-axis categories).
  values : array-like
      Values to summarise (y-axis).
  color : color
      Box color; individual points are drawn in black by a swarm plot.
  ylabel, ylim, yticks : optional
      Y-axis label, limits and ticks; ignored when None.
  figsize : tuple
      Figure size in inches.
  save : bool
      If True, save the figure to ``imgname``.
  imgname : str
      Output file name used when ``save`` is True.
  '''
  sns.set(style="ticks")
  fig = plt.figure(figsize=figsize)
  ax = fig.add_subplot()
  # x/y as keywords: seaborn >= 0.12 no longer accepts them positionally.
  sns.boxplot(x=groups,y=values,color=color,ax=ax)
  sns.swarmplot(x=groups,y=values,color='black',ax=ax)
  if ylabel is not None:
    ax.set_ylabel(ylabel)
  if ylim is not None:
    ax.set_ylim(ylim)
  if yticks is not None:
    ax.set_yticks(yticks)
  sns.despine()
  plt.tight_layout()

  if save:
    fig.savefig(imgname)

In [None]:
def plot_violinplot(groups,values,color='green',ylabel=None,ylim=None,yticks=None,
                    figsize=(3,3),save=False,imgname='boxplot.pdf'):
  '''
  Draw a violin plot of ``values`` split by ``groups``.

  Unlike plot_boxplot, individual points are not overlaid.

  Parameters
  ----------
  groups : array-like
      Group label of each value (x-axis categories).
  values : array-like
      Values whose distribution is drawn (y-axis).
  color : color
      Violin color.
  ylabel, ylim, yticks : optional
      Y-axis label, limits and ticks; ignored when None.
  figsize : tuple
      Figure size in inches.
  save : bool
      The figure is saved only when ``save == True``.
  imgname : str
      Output file name. NOTE(review): the default 'boxplot.pdf' matches
      plot_boxplot's default, so saving both with defaults overwrites one file.
  '''
  sns.set(style="ticks")
  fig = plt.figure(figsize=figsize)
  ax = fig.add_subplot()
  sns.violinplot(x=groups,y=values,color=color)

  # Apply each optional y-axis setting in a fixed order, skipping unset ones.
  optional_settings = (
    (ylabel, ax.set_ylabel),
    (ylim, ax.set_ylim),
    (yticks, ax.set_yticks),
  )
  for setting, apply_setting in optional_settings:
    if setting is not None:
      apply_setting(setting)

  sns.despine()
  plt.tight_layout()

  if save == True:
    fig.savefig(imgname)

### Plot 3 phases

In [None]:
def plot_3phases(green_signal,red_signal,
                 figtitle,figsize=(24,13),
                 save=False,save_path='./figures/',image_format='.pdf'):
  '''
  Plot paired green/red traces in a 3x2 grid, one row per key.

  Parameters
  ----------
  green_signal, red_signal : dict-like
      For the i-th key of ``green_signal`` (in iteration order),
      ``green_signal[key]`` is drawn in the left column of row i and
      ``red_signal[key]`` in the right column; the key labels the row.
      The grid has three rows, so at most three keys fit.
  figtitle : str or None
      Figure title. When saving, it is also the file-name prefix, with
      spaces replaced by underscores.
  figsize : tuple
      Figure size in inches.
  save : bool
      If True, save to ``save_path`` as '<title>_raw3phases<image_format>',
      or 'raw3phases<image_format>' when ``figtitle`` is None.
  save_path : str
      Directory the image is written to.
  image_format : str
      File extension including the dot, e.g. '.pdf'.
  '''
  import os

  fig, axs = plt.subplots(3,2,figsize=figsize)
  axs = axs.ravel()

  for i,key in enumerate(green_signal):
    axs[2*i].plot(green_signal[key],color='green')
    axs[2*i+1].plot(red_signal[key],color='red')
    axs[2*i].set_ylabel(key,fontsize='x-large')
  axs[4].set_xlabel('time', fontsize='x-large', multialignment='center')
  axs[5].set_xlabel('time', fontsize='x-large', multialignment='center')

  # Title
  if figtitle is not None:
    fig.suptitle(figtitle, fontsize='xx-large')

  # Save figure
  if save:
    # Without a title there is no prefix (previously this crashed on None.replace).
    prefix = figtitle.replace(' ','_') + '_' if figtitle is not None else ''
    imgname = prefix + 'raw3phases' + image_format
    # os.path.join also copes with a save_path lacking a trailing separator.
    fig.savefig(os.path.join(save_path,imgname))

# Correlation analysis

### Cross-correlation

In [None]:
def xcorr(x, y, normed=True, detrend=False, maxlags=10):
  """
  Cross-correlation of two signals of equal length.

  Parameters
  ----------
  x, y : array-like
      1-D signals of the same length ``Nx``.
  normed : bool
      If True, divide by sqrt(dot(x, x) * dot(y, y)) so values are
      correlation coefficients; otherwise return raw inner products.
  detrend : bool
      If True, subtract each signal's mean before correlating.
  maxlags : int or None
      Largest lag (in samples) to return; None means ``Nx - 1``.

  Returns
  -------
  lags : numpy.ndarray
      Integer lags from ``-maxlags`` to ``maxlags``.
  c : numpy.ndarray
      Correlation at each lag, following numpy.correlate's 'full' ordering.

  Raises
  ------
  ValueError
      If the signals differ in length or ``maxlags`` is not in ``[1, Nx)``.

  Usage: lags, c = xcorr(x,y,maxlags=len(x)-1)
  """
  Nx = len(x)
  if Nx != len(y):
    raise ValueError('x and y must be equal length')

  if maxlags is None:
    maxlags = Nx - 1

  # Validate before running the (direct, non-FFT) full correlation.
  if maxlags >= Nx or maxlags < 1:
    raise ValueError('maxlags must be None or strictly '
                     'positive < %d' % Nx)

  if detrend:
    # Same as matplotlib.mlab.detrend_mean, without importing matplotlib.
    x = np.asarray(x)
    x = x - x.mean()
    y = np.asarray(y)
    y = y - y.mean()

  c = np.correlate(x, y, mode='full')

  if normed:
    n = np.sqrt(np.dot(x, x) * np.dot(y, y))
    c = np.true_divide(c, n)

  lags = np.arange(-maxlags, maxlags + 1)
  # The zero-lag value sits at index Nx-1 of the 'full' output.
  c = c[Nx - 1 - maxlags:Nx + maxlags]
  return lags, c

### Random sampling

In [None]:
import random

def random_subset( iterator, K, seed=30 ):
    """
    Draw a uniform random sample of up to K items from an iterable.

    Uses reservoir sampling, so the iterable is consumed once and need not
    have a known length. The result is reproducible for a given ``seed``.

    Parameters
    ----------
    iterator : iterable
        Items to sample from.
    K : int
        Maximum number of items to return; if the iterable has fewer than K
        items, all of them are returned in their original order.
    seed : int
        Seed for the private random generator.

    Returns
    -------
    list
        The sampled items.
    """
    # A private generator yields the same sequence that random.seed(seed)
    # produced, without resetting the module-level random state as a side effect.
    rng = random.Random(seed)

    result = []
    for N, item in enumerate(iterator, start=1):
        if len(result) < K:
            result.append(item)
        else:
            s = int(rng.random() * N)
            if s < K:
                result[s] = item

    return result

# Print done

In [None]:
# Signal that every helper-defining cell above has executed.
_ready_message = 'All Fiber Photometry functions are ready to use'
print(_ready_message)

All Fiber Photometry functions are ready to use
