<a href="https://colab.research.google.com/github/katemartian/FiberPhotometryDataAnalysis/blob/master/FiberPhotometryDataAnalysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h1><center>Fiber Photometry Data Analysis</center></h1>


The package was developed by Ekaterina Martianova during her PhD studies in Christophe Proulx's lab at the CERVO Brain Research Center, Laval University.

The package includes: 

- Functions necessary to process raw signals recorded with camera-based fiber photometry;

- Helper functions to process behavior data recorded/created with behavior software, e.g. Med Associates, ANY-maze, DeepLabCut;

- Helper functions to create summary plots.

# Import necessary libraries


In [0]:
import os
import sys
import time
import pandas as pd
import numpy as np
np.random.seed(0)
import matplotlib as mpl
import matplotlib.pyplot as plt
from collections import OrderedDict
mpl.style.use('classic')
import seaborn as sns
from scipy.stats.stats import pearsonr
from scipy import signal

In [0]:
# Show images only if asked: switch matplotlib's interactive mode off so the
# figures created by the plotting helpers below are not displayed automatically.
plt.ioff()

# Data processing pipeline

### Preprocess

In [0]:
def allrecordings_preprocess(exp_info, events, Raws,
                             flatten=True,lambda_=1e4,order=0.5,itermax=50,
                             abs_int=False,
                             prefilter='smooth', smooth_win=1,
                             remove=300,
                             standardize=True,
                             plot=False, figsize=(24, 13),
                             save=False):
  """Preprocess every raw recording: flatten, filter, trim and standardize.

  Parameters
    exp_info: dict with at least 'frequency' (sampling rate in Hz); 'time',
              'mouse' and 'test' are also read when plot=True.
    events: behavior events, only used for plotting.
    Raws: nested dict Raws[output][signal_type] -> raw intensity vector.
    flatten: subtract the airPLS baseline computed by flatten_signal.
    lambda_, order, itermax: forwarded to flatten_signal.
    abs_int: after flattening, add min(baseline) back to the signal.
    prefilter: 'smooth' (moving average over smooth_win seconds),
               'lowpass' (Butterworth low-pass, cutoff 1/smooth_win Hz) or
               any other value for no filtering.
    smooth_win: filter window in seconds.
    remove: number of initial samples replaced with NaN.
    standardize: apply standardize_signal (median 0, std 1).
    plot, figsize, save: plot raw vs processed traces per output, optionally
               saving the figure under ./figures/1_raw/.

  Returns
    Intensities: nested dict with the same keys as Raws holding the
                 processed float vectors.
  """
  fs = int(np.round(exp_info['frequency']))

  Intensities = {}
  Bases = {}

  # Iterate through the different outputs
  for output in Raws:

    Intensities[output] = {}
    Bases[output] = {}

    # Iterate through all signal recordings from one output
    for signal_type in Raws[output]:

      # Float copy: NaN is assigned below, which fails on integer arrays
      intensity = np.array(Raws[output][signal_type], dtype=float)

      # Flatten: subtract the airPLS baseline
      if flatten:
        intensity, base = flatten_signal(intensity,lambda_=lambda_,order=order,itermax=itermax)
        Bases[output][signal_type] = base.copy()

        if abs_int:
          # Shift by the baseline minimum to keep absolute intensities
          intensity = intensity + min(base)
      else:
        # No baseline estimated: store an all-NaN placeholder so the plot
        # below can still look it up (NaN samples are not drawn).
        Bases[output][signal_type] = np.full(len(intensity), np.nan)

      # Smooth / low-pass filter
      if prefilter=='smooth':
        # smooth_signal slices with the window, so it must be an int number
        # of samples (a fractional smooth_win*fs used to fail)
        intensity = smooth_signal(intensity,window_len=int(np.round(smooth_win*fs)))

      elif prefilter=='lowpass':
        cutoff = 1 / smooth_win
        intensity = butter_lowpass_filter(intensity, cutoff, fs, order=10)

      # Remove the beginning
      intensity[:remove] = np.nan

      # Standardize signal to median 0 and std 1
      if standardize:
        intensity = standardize_signal(intensity)

      # Copy the result to the dictionary
      Intensities[output][signal_type] = intensity.copy()

    # Plot if asked
    if plot:
      exp_time = exp_info['time']
      fig = plot_raw_int(Raws[output],Intensities[output],Bases[output],exp_time,events,figsize=figsize)
      figtitle = exp_info['mouse'] + ' ' + output + ' ' + exp_info['test']
      fig.suptitle(figtitle, fontsize='xx-large')

      # Save if asked
      if save:
        create_folder('figures')
        create_folder('figures/1_raw')
        imgname = figtitle.replace(' ','_') + '_raw.png'
        fig.savefig('./figures/1_raw/'+imgname)

  return Intensities

In [0]:
def smooth_signal(x,window_len=10,window='flat'):
    """Smooth the data using a window with the requested size.

    This method is based on the convolution of a scaled window with the signal.
    The signal is prepared by introducing reflected copies of the signal
    (with the window size) in both ends so that transient parts are minimized
    in the beginning and end part of the output signal.
    Adapted from: https://scipy-cookbook.readthedocs.io/items/SignalSmooth.html

    input:
        x: the input signal (1-D numpy array)
        window_len: the dimension of the smoothing window in samples; a float
                    is truncated to int. Windows shorter than 3 return x as is.
        window: the type of window from 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'
                'flat' window will produce a moving average smoothing.

    output:
        the smoothed signal, with the same length as x

    raises:
        ValueError: if x is not 1-D, is shorter than the window, or the
                    window type is unknown.
    """
    window_len = int(window_len)

    if x.ndim != 1:
        raise ValueError("smooth only accepts 1 dimension arrays.")

    if x.size < window_len:
        raise ValueError("Input vector needs to be bigger than window size.")

    if window_len<3:
        return x

    if window not in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']:
        raise ValueError("Window is one of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'")

    # Mirror window_len-1 samples at each end (the edge sample is not repeated)
    s=np.r_[x[window_len-1:0:-1],x,x[-2:-window_len-1:-1]]

    if window == 'flat': # Moving average
        w=np.ones(window_len,'d')
    else:
        # Look the window function up by name instead of using eval()
        w=getattr(np,window)(window_len)

    y=np.convolve(w/w.sum(),s,mode='valid')

    # The 'valid' output has len(x)+window_len-1 samples. Keep len(x) of
    # them starting at (window_len-1)//2: for even windows this is exactly
    # the cookbook slice; for odd windows it centers the window and no
    # longer returns one extra sample.
    start = (window_len-1)//2
    return y[start:start+x.size]

In [0]:
from scipy.signal import butter, filtfilt, freqz

def butter_lowpass(cutoff, fs, order=5):
    """Design a digital Butterworth low-pass filter.

    cutoff: cutoff frequency in Hz.
    fs: sampling rate in Hz.
    order: filter order.
    Returns the (b, a) transfer-function coefficients from scipy.signal.butter.
    """
    # butter() expects the cutoff as a fraction of the Nyquist frequency
    wn = cutoff / (fs / 2.0)
    return butter(order, wn, btype='low', analog=False)


def butter_lowpass_filter(data, cutoff, fs, order=10):
    """Zero-phase Butterworth low-pass filtering of `data`.

    cutoff: cutoff frequency in Hz.
    fs: sampling rate in Hz.
    order: filter order (10 by default).
    Returns the filtered signal, same length as data.

    The filter is designed and applied in second-order-sections (SOS) form:
    for high orders such as the default 10 the (b, a) transfer-function form
    is numerically ill-conditioned and can yield an unstable filter.
    """
    from scipy.signal import butter, sosfiltfilt

    nyq = 0.5 * fs
    sos = butter(order, cutoff / nyq, btype='low', analog=False, output='sos')
    # Forward-backward filtering: no phase delay in the output
    return sosfiltfilt(sos, data)

In [0]:
def flatten_signal(signal,lambda_=1e4,order=0.5,itermax=50,pad=600):
  """Remove the slow baseline of a signal, estimated with airPLS.

  Before fitting, the signal is mirrored by up to `pad` samples at both ends
  so the baseline is less distorted at the edges; the padding is removed
  afterwards.

  Parameters
    signal: 1-D intensity vector.
    lambda_, order, itermax: forwarded to airPLS (lambda_, porder, itermax).
    pad: maximum mirror padding in samples (600, as before). It is capped by
         the signal length, so short recordings no longer produce an empty
         or misaligned result.

  Returns
    (signal - baseline, baseline), both arrays of len(signal) samples.
  """
  signal = np.asarray(signal)
  n = len(signal)
  # Samples mirrored on each side (the edge sample itself is not repeated)
  k = max(min(pad, n) - 1, 0)

  s = np.r_[signal[k:0:-1], signal, signal[-2:-k-2:-1]]
  b = airPLS(s,lambda_=lambda_,porder=order,itermax=itermax)

  # Drop the padding: the original samples are s[k:k+n]
  base = b[k:k+n]

  # Subtract the baseline
  return signal - base, base

In [0]:
def standardize_signal(signal):
  """Return a robust z-score of `signal`: (signal - median) / std.

  NaN samples are ignored when computing the median and the standard
  deviation, and remain NaN in the output.
  """
  center = np.nanmedian(signal)
  spread = np.nanstd(signal)
  return (signal - center) / spread

In [0]:
def plot_raw_int(raws,intensities,bases,
                 exp_time=None,events=None,
                 figsize=(24, 13)):
  """Plot raw signals with their baselines next to the processed signals.

  One row per signal type: the left panel shows the raw trace and its
  baseline (black), the right panel the processed trace. Behavior events
  are drawn on every panel.

  Parameters
    raws, intensities, bases: dicts keyed by signal type holding vectors of
        equal length; signal types missing from `bases` get no baseline line.
    exp_time: x-axis values; sample indices are used when None or empty.
    events: optional dict of behavior events. A dict value with
        'time'/'values' is drawn as a trace on a twin y-axis, an (n, 1)
        array as vertical lines, an (n, 2) array of onset/offset pairs as
        shaded spans. Events in any other format are skipped.
    figsize: matplotlib figure size.

  Returns
    the matplotlib Figure.
  """
  signal_colors = ['purple','blue','green']

  if events is None:
    events = {}
  # `exp_time == []` compares elementwise (and raises on recent NumPy or for
  # pandas Series) when exp_time is array-like, so test "not provided" explicitly.
  if exp_time is None or len(exp_time) == 0:
    exp_time = range(len(raws[list(raws.keys())[0]]))

  fig, axs = plt.subplots(len(raws),2,figsize=figsize)
  axs = axs.ravel()

  # Plot signals: left column raw + baseline, right column processed
  for i, signal_type in enumerate(raws):
    color = signal_colors[i % len(signal_colors)]  # cycle for > 3 signals
    # Raw signal and its baseline
    axs[2*i].plot(exp_time,raws[signal_type], color=color, linewidth=1.5,label='signal')
    if signal_type in bases:
      axs[2*i].plot(exp_time,bases[signal_type], color='black', linewidth=1.5)
    axs[2*i].set_ylabel(signal_type, fontsize='x-large', multialignment='center')
    # Flattened and smoothed signal
    axs[2*i+1].plot(exp_time,intensities[signal_type], color=color,linewidth=1.5,label='signal')
  n_rows = len(raws)
  axs[2*n_rows-2].set_xlabel('time', fontsize='x-large', multialignment='center')
  axs[2*n_rows-1].set_xlabel('time', fontsize='x-large', multialignment='center')

  # Plot events on every subplot
  cmap = get_cmap(len(events))
  for k,key in enumerate(events):
    try:
      for i in range(2*len(raws)):
        if isinstance(events[key], dict):
          # continuous measurement (expect only one) on its own y axis
          ax_e = axs[i].twinx()
          ax_e.plot(events[key]['time'], events[key]['values'], color=cmap(k),label=key)
          ax_e.set_ylabel(key,fontsize='x-large',multialignment='center',color=cmap(k))
          ax_e.tick_params('y', colors=cmap(k))
          e_max = np.nanmax(events[key]['values'])
          e_min = np.nanmin(events[key]['values'])
          ax_e.set_ylim([e_min, e_max + (e_max-e_min)]) # plot on the bottom half
        elif events[key].shape[1]==1:
          # single-occurrence events
          for event in events[key]:
            axs[i].axvline(event,linewidth=1,color=cmap(k),label=key)
        elif events[key].shape[1]==2:
          # events with onset and offset
          for event0,event1 in events[key]:
            axs[i].axvspan(event0,event1,alpha=0.3,color=cmap(k),label=key)
    except (AttributeError, IndexError, KeyError, TypeError, ValueError):
      # Event in an unsupported format: skip it (best effort, as before)
      pass
  for ax in axs:
    ax.set_xlim([0,max(exp_time)])
    ax.tick_params(labelsize='large')
  # Legend (one entry per label)
  handles, labels = plt.gca().get_legend_handles_labels()
  by_label = OrderedDict(zip(labels, handles))
  plt.legend(by_label.values(), by_label.keys(), prop={'size': 'small'})

  return fig

airPLS algorithm

In [0]:
'''
airPLS.py Copyright 2014 Renato Lombardo - renato.lombardo@unipa.it
Baseline correction using adaptive iteratively reweighted penalized least squares

This program is a translation in python of the R source code of airPLS version 2.0
by Yizeng Liang and Zhang Zhimin - https://code.google.com/p/airpls
Reference:
Z.-M. Zhang, S. Chen, and Y.-Z. Liang, Baseline correction using adaptive iteratively reweighted penalized least squares. Analyst 135 (5), 1138-1146 (2010).

Description from the original documentation:

Baseline drift always blurs or even swamps signals and deteriorates analytical results, particularly in multivariate analysis.  It is necessary to correct baseline drift to perform further data analysis. Simple or modified polynomial fitting has been found to be effective in some extent. However, this method requires user intervention and prone to variability especially in low signal-to-noise ratio environments. The proposed adaptive iteratively reweighted Penalized Least Squares (airPLS) algorithm doesn't require any user intervention and prior information, such as detected peaks. It iteratively changes weights of sum squares errors (SSE) between the fitted baseline and original signals, and the weights of SSE are obtained adaptively using between previously fitted baseline and original signals. This baseline estimator is general, fast and flexible in fitting baseline.


LICENCE
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>
'''

import numpy as np
from scipy.sparse import csc_matrix, eye, diags
from scipy.sparse.linalg import spsolve

def WhittakerSmooth(x,w,lambda_,differences=1):
    '''
    Penalized least squares algorithm for background fitting

    Solves (W + lambda_ * D'D) z = W x, where W = diag(w) and D is the
    first-order difference operator.

    input
        x: input data (i.e. chromatogram of spectrum), 1-D
        w: non-negative weights, one per sample (0 ignores a sample, e.g. a
           point that belongs to a peak)
        lambda_: parameter that can be adjusted by user. The larger lambda is,  the smoother the resulting background
        differences: kept for API compatibility; this implementation always
                     penalizes first-order differences and ignores the value
                     (as the original airPLS translation did)

    output
        the fitted background vector (1-D numpy array)
    '''
    x = np.asarray(x, dtype=float).ravel()
    m = x.size
    E = eye(m, format='csc')
    # First-order difference matrix (numpy.diff() does not work with sparse matrices)
    D = E[1:] - E[:-1]
    W = diags(w, 0, shape=(m, m))
    # Plain sparse products instead of the deprecated np.matrix class
    A = csc_matrix(W + lambda_ * (D.T @ D))
    B = W @ x
    background = spsolve(A, B)
    return np.asarray(background)

def airPLS(x, lambda_=100, porder=1, itermax=15):
    '''
    Adaptive iteratively reweighted penalized least squares for baseline fitting

    input
        x: input data (i.e. chromatogram of spectrum), 1-D array
        lambda_: parameter that can be adjusted by user. The larger lambda is,  the smoother the resulting background, z
        porder: forwarded to WhittakerSmooth as `differences`
        itermax: maximum number of reweighting iterations (must be >= 1)

    output
        the fitted background vector

    raises
        ValueError: if itermax < 1 (no background would be computed)
    '''
    if itermax < 1:
        raise ValueError('itermax must be >= 1')
    m=x.shape[0]
    w=np.ones(m)
    for i in range(1,itermax+1):
        z=WhittakerSmooth(x,w,lambda_, porder)
        d=x-z
        # Total magnitude of the residuals below the current baseline
        dssn=np.abs(d[d<0].sum())
        # Converged when those residuals are negligible. dssn == 0 also covers
        # an all-zero input, where the relative test (0 < 0) never succeeds and
        # the weight update below would take the max of an empty array.
        if dssn == 0 or dssn<0.001*(abs(x)).sum() or i==itermax:
            if i==itermax: print('WARNING max iteration reached!')
            break
        w[d>=0]=0 # d>0 means that this point is part of a peak, so its weight is set to 0 in order to ignore it
        w[d<0]=np.exp(i*np.abs(d[d<0])/dssn)
        # Both end points get the weight of the negative residual closest to 0
        w[0]=np.exp(i*(d[d<0]).max()/dssn)
        w[-1]=w[0]
    return z


### Fit and align the calcium-independent signal to the calcium-dependent signal

In [0]:
def allrecordings_align(exp_info,events,Intensities,
                        standardize=False,
                        plot=False, figsize=(24, 12),
                        save=False):
  """Replace each output's 'reference' trace with its fit onto the 'signal'.

  Intensities is modified in place: Intensities[output]['reference'] is
  overwritten with the output of fit_signal, and the same dict is returned.
  With plot=True, a fit scatter plot and an aligned-traces plot are made per
  output. With save=True as well, they are written under ./figures/2_fit/ and
  ./figures/3_align/.
  """
  for output, traces in Intensities.items():
    sig = traces['signal']
    ref = traces['reference']

    # Positive linear fit of the reference onto the signal
    fitted = fit_signal(sig, ref, standardize=standardize)

    if not plot:
      traces['reference'] = fitted
      continue

    title = exp_info['mouse'] + ' ' + output + ' ' + exp_info['test']
    fig_fit = plot_fit(sig, ref, fitted, figsize=figsize)
    fig_fit.suptitle(title, fontsize='xx-large')

    traces['reference'] = fitted

    # Aligned traces
    exp_time = exp_info['time']
    fig_align = plot_aligned(traces['signal'], traces['reference'], exp_time, events, figsize=figsize)
    fig_align.suptitle(title, fontsize='xx-large')

    if save:
      create_folder('figures')
      create_folder('./figures/2_fit')
      create_folder('./figures/3_align')
      stem = title.replace(' ','_')
      fig_fit.savefig('./figures/2_fit/' + stem + '_fit.png')
      fig_align.savefig('./figures/3_align/' + stem + '_align.png')

  return Intensities

In [0]:
def fit_signal(signal, reference, standardize=False):
  """Fit `reference` onto `signal` with a positive linear regression.

  Samples up to and including the last NaN of either trace (the NaN prefix
  left by preprocessing) are excluded from the fit and are NaN in the
  output. The model is sklearn's Lasso with positive=True and a small
  alpha, so the reference can only be positively scaled plus an offset.

  Parameters
    signal, reference: 1-D vectors of equal length.
    standardize: if True, both traces are z-scored (median/std) before the
                 fit and the fitted reference is mapped back to the
                 signal's units.

  Returns
    reference_fitted: 1-D array of len(signal), NaN over the skipped prefix.
  """
  from sklearn.linear_model import Lasso

  signal = np.asarray(signal, dtype=float)
  reference = np.asarray(reference, dtype=float)

  if standardize:
    median_signal = np.nanmedian(signal)
    std_signal = np.nanstd(signal)
    signal = (signal - median_signal) / std_signal
    median_ref = np.nanmedian(reference)
    std_ref = np.nanstd(reference)
    reference = (reference - median_ref) / std_ref

  # Fit only after the last NaN of either trace. Using max() on the NaN
  # indices crashed when there were none (e.g. remove=0).
  nan_idx = np.flatnonzero(np.isnan(signal) | np.isnan(reference))
  start = int(nan_idx[-1]) + 1 if nan_idx.size else 0

  sig_fit = signal[start:].reshape(-1, 1)
  ref_fit = reference[start:].reshape(-1, 1)

  # Positive linear regression
  lin = Lasso(alpha=0.0001,precompute=True,max_iter=1000,
              positive=True, random_state=9999, selection='random')
  lin.fit(ref_fit, sig_fit)
  fitted = np.ravel(lin.predict(ref_fit))

  reference_fitted = np.r_[np.full(start, np.nan), fitted]

  if standardize:
    # The fit approximates the standardized *signal*, so it is mapped back
    # with the signal's statistics (scaling by std_ref mixed units).
    reference_fitted = reference_fitted * std_signal + median_signal

  return reference_fitted

In [0]:
def plot_fit(signal,reference,reference_fitted,figsize=(24, 12)):
  """Scatter the signal against the reference and overlay the fit.

  Reference is on the x axis and signal on the y axis (blue dots). The
  fitted reference is drawn as a red dashed line. Returns the Figure.
  """
  fig = plt.figure(figsize=figsize)
  ax = fig.add_subplot(111)
  ax.plot(reference, signal, 'b.')
  ax.plot(reference, reference_fitted, 'r--', linewidth=1.5)
  for set_label, text in ((ax.set_xlabel, 'reference'), (ax.set_ylabel, 'signal')):
    set_label(text, fontsize='x-large', multialignment='center')
  ax.tick_params(labelsize='large')
  return fig

In [0]:
def plot_aligned(signal,reference,exp_time=None,events=None,
                 figsize=(24, 12)):
  """Plot the signal and the fitted reference on one time axis.

  Parameters
    signal: trace drawn in black.
    reference: fitted reference drawn in purple.
    exp_time: x-axis values; sample indices are used when None or empty.
    events: optional dict of behavior events. A dict value with
        'time'/'values' is drawn as a trace on a twin y-axis, an (n, 1)
        array as vertical lines, an (n, 2) array of onset/offset pairs as
        shaded spans. Events in any other format are skipped.
    figsize: matplotlib figure size.

  Returns
    the matplotlib Figure.
  """
  if events is None:
    events = {}
  # `exp_time == []` compares elementwise (and raises on recent NumPy or for
  # pandas Series) when exp_time is array-like, so test "not provided" explicitly.
  if exp_time is None or len(exp_time) == 0:
    exp_time = range(len(signal))

  fig = plt.figure(figsize=figsize)
  ax = fig.add_subplot(111)
  # Signal
  ax.plot(exp_time, signal, 'black' ,linewidth=1.5)
  ax.plot(exp_time, reference, 'purple',linewidth=1.5)
  # Events
  cmap = get_cmap(len(events))
  for k,key in enumerate(events):
    try:
      if isinstance(events[key], dict):
        # continuous measurement (expect only one) on its own y axis
        ax_e = ax.twinx()
        ax_e.plot(events[key]['time'], events[key]['values'], color=cmap(k),label=key)
        ax_e.set_ylabel(key,fontsize='x-large',multialignment='center',color=cmap(k))
        ax_e.tick_params('y', colors=cmap(k))
        e_max = np.nanmax(events[key]['values'])
        e_min = np.nanmin(events[key]['values'])
        ax_e.set_ylim([e_min, e_max + (e_max-e_min)]) # plot on the bottom half
      elif events[key].shape[1]==1:
        # single-occurrence events
        for event in events[key]:
          ax.axvline(event,linewidth=1,color=cmap(k),label=key)
      elif events[key].shape[1]==2:
        # events with onset and offset
        for event0,event1 in events[key]:
          ax.axvspan(event0,event1,alpha=0.3,color=cmap(k),label=key)
    except (AttributeError, IndexError, KeyError, TypeError, ValueError):
      # Event in an unsupported format: skip it (best effort, as before)
      pass
  # Params
  ax.set_xlabel('time', fontsize='x-large', multialignment='center')
  ax.set_ylabel('Intensity', fontsize='x-large', multialignment='center')
  ax.set_xlim([0,max(exp_time)])
  ax.tick_params(labelsize='x-large')
  # Legend (one entry per label)
  handles, labels = plt.gca().get_legend_handles_labels()
  by_label = OrderedDict(zip(labels, handles))
  plt.legend(by_label.values(), by_label.keys(), prop={'size': 'medium'})

  return fig

### Calculate z dFF

In [0]:
def allrecordings_dff(exp_info,events,Intensities,
                      standardized=True,
                      standardize=True,
                      plot=False, figsize=(24, 12),
                      save=False):
  """Compute dF/F for every output from its 'signal' and 'reference' traces.

  standardized and standardize are forwarded to dff(). With plot=True, each
  output's dF/F is plotted with the events. With save=True as well, the
  figure is written under ./figures/4_dFF/.

  Returns
    dict output -> dF/F vector.
  """
  dFFs = {}
  for output, traces in Intensities.items():
    trace = dff(traces['signal'], traces['reference'],
                standardized=standardized, standardize=standardize)
    dFFs[output] = trace

    if not plot:
      continue

    exp_time = exp_info['time']
    fig = plot_dff(trace, exp_time, events, figsize=figsize)
    title = exp_info['mouse'] + ' ' + output + ' ' + exp_info['test']
    fig.suptitle(title, fontsize='xx-large')

    if save:
      create_folder('./figures')
      create_folder('./figures/4_dFF')
      fig.savefig('./figures/4_dFF/' + title.replace(' ','_') + '_dFF.png')

  return dFFs

In [0]:
def dff(signal,reference,standardized=True,standardize=False):
  """Compute dF/F of `signal` relative to the fitted `reference`.

  standardized=True: the traces are taken as already z-scored, so dF/F is
  the plain difference signal - reference (`standardize` is then ignored).
  standardized=False: the ratio (signal - reference) / reference, which is
  z-scored with standardize_signal when standardize=True.
  """
  delta = signal - reference
  if standardized:
    return delta
  ratio = delta / reference
  return standardize_signal(ratio) if standardize else ratio

In [0]:
def plot_dff(dFF,exp_time=None,events=None,figsize=(24,12)):
  """Plot a dF/F trace over time together with behavior events.

  Parameters
    dFF: dF/F vector, drawn in black.
    exp_time: x-axis values; sample indices are used when None or empty.
    events: optional dict of behavior events. A dict value with
        'time'/'values' is drawn as a trace on a twin y-axis, an (n, 1)
        array as vertical lines, an (n, 2) array of onset/offset pairs as
        shaded spans. Events in any other format are skipped.
    figsize: matplotlib figure size.

  Returns
    the matplotlib Figure (y range covers at least [-3, 3]).
  """
  if events is None:
    events = {}
  # `exp_time == []` compares elementwise (and raises on recent NumPy or for
  # pandas Series) when exp_time is array-like, so test "not provided" explicitly.
  if exp_time is None or len(exp_time) == 0:
    exp_time = range(len(dFF))

  # Show at least the [-3, 3] range
  ymin = np.nanmin([-3, np.nanmin(dFF)])
  ymax = np.nanmax([3, np.nanmax(dFF)])

  fig = plt.figure(figsize=figsize)
  ax = fig.add_subplot(111)
  # Signal
  ax.plot(exp_time, dFF, 'black' ,linewidth=1.5)
  # Events
  cmap = get_cmap(len(events))
  for k,key in enumerate(events):
    try:
      if isinstance(events[key], dict):
        # continuous measurement (expect only one) on its own y axis
        ax_e = ax.twinx()
        ax_e.plot(events[key]['time'], events[key]['values'], color=cmap(k),label=key)
        ax_e.set_ylabel(key,fontsize='x-large',multialignment='center',color=cmap(k))
        ax_e.tick_params('y', colors=cmap(k))
        e_max = np.nanmax(events[key]['values'])
        e_min = np.nanmin(events[key]['values'])
        ax_e.set_ylim([e_min, e_max + (e_max-e_min)]) # plot on the bottom half
      elif events[key].shape[1]==1:
        # single-occurrence events
        for event in events[key]:
          ax.axvline(event,linewidth=1,color=cmap(k),label=key)
      elif events[key].shape[1]==2:
        # events with onset and offset
        for event0,event1 in events[key]:
          ax.axvspan(event0,event1,alpha=0.3,color=cmap(k),label=key)
    except (AttributeError, IndexError, KeyError, TypeError, ValueError):
      # Event in an unsupported format: skip it (best effort, as before)
      pass
  # Params
  ax.set_xlabel('time', fontsize='x-large', multialignment='center')
  ax.set_ylabel('dF/F', fontsize='x-large', multialignment='center')
  ax.set_xlim([0,max(exp_time)])
  ax.set_ylim([ymin, ymax])
  ax.tick_params(labelsize='large')
  # Legend (one entry per label)
  handles, labels = plt.gca().get_legend_handles_labels()
  by_label = OrderedDict(zip(labels, handles))
  plt.legend(by_label.values(), by_label.keys(), prop={'size': 'medium'})

  return fig

### Peri-event data

In [0]:
def allrecordings_events_arrays(exp_info,events,dFFs,
                                info_for_array=None,
                                plot=False,
                                save=False):
  """Build peri-event dF/F arrays for every event and every output.

  Parameters
    exp_info: dict with 'frequency' and 'time' ('mouse' and 'test' are also
              read when plotting).
    events: dict event name -> event times (passed to create_event_array).
    dFFs: dict output -> dF/F vector.
    info_for_array: optional dict event name -> dict with any of 'window'
        (default [-5, 5]), 'dur', 'iei', 'avg_win' (default None) and
        'figsize' (plotting only). Missing entries use the defaults.
    plot, save: plot each array with plot_event_array and optionally save
        it under ./figures/5_mean/.

  Returns
    dict output -> dict event name -> dict of arrays from create_event_array.
  """
  missing = object()  # sentinel: setting not provided

  def _setting(event, name, default):
    # info_for_array[event][name], or `default` when info_for_array is None
    # or does not define this event/setting.
    try:
      return info_for_array[event][name]
    except (KeyError, TypeError, IndexError):
      return default

  freq = exp_info['frequency']
  time_vector = exp_info['time']

  Events_arrays = {output: {} for output in dFFs}

  cmap = get_cmap(len(events))

  for k,event in enumerate(events):

    window = _setting(event, 'window', [-5,5])
    dur = _setting(event, 'dur', None)
    iei = _setting(event, 'iei', None)
    avg_win = _setting(event, 'avg_win', None)
    figsize = _setting(event, 'figsize', missing)

    for output in dFFs:

      # Create an array
      Array = create_event_array(dFFs[output],events[event],time_vector=time_vector,
                                 window=window,dur=dur,iei=iei,avg_win=avg_win)

      Events_arrays[output][event] = Array.copy()

      # Plot if asked
      if plot:
        # Only pass figsize when configured, so plot errors are no longer
        # hidden by a silent retry without it
        if figsize is missing:
          fig = plot_event_array(Array,freq,window=window,color_mean=cmap(k))
        else:
          fig = plot_event_array(Array,freq,window=window,color_mean=cmap(k),figsize=figsize)
        figtitle = exp_info['mouse'] + ' ' + output + ' ' + exp_info['test'] + ' ' + event
        fig.suptitle(figtitle, fontsize='xx-large')

        # Save if asked
        if save:
          create_folder('./figures')
          create_folder('./figures/5_mean')
          imgname = figtitle.replace(' ','_') + '_mean.png'
          fig.savefig('./figures/5_mean/'+imgname)

  return Events_arrays

In [0]:
def plot_event_array(Array,freq=10,window=(-5,5),
                     color_trace='black',color_mean='green',
                     figsize=(12,10)):
  """Plot single peri-event traces with their mean +/- SEM.

  One panel per key of Array (e.g. 'events', or 'onset' and 'offset').

  Parameters
    Array: dict name -> 2-D array (n_events, n_samples) of peri-event traces.
    freq: sampling frequency in Hz, used to build the time axis.
    window: [start, end] in seconds; window[0] is the time of the first
            sample relative to the event.
    color_trace, color_mean: colors of the single traces and of the mean/SEM.
    figsize: matplotlib figure size.

  Returns
    the matplotlib Figure.
  """
  from matplotlib import gridspec

  Mean = {}
  Error = {}
  ymax = 1
  ymin = -1
  for key, traces in Array.items():
    Mean[key] = np.nanmean(traces,axis=0)
    # Standard error of the mean across events
    Error[key] = np.nanstd(traces,axis=0) / np.sqrt(traces.shape[0])
    # Shared y limits: at least [-1, 1], otherwise 10% beyond the data
    if traces.size:
      ymax = max(ymax,1.1*np.nanmax(traces))
      ymin = min(ymin,1.1*np.nanmin(traces))

  xmin = window[0]

  fig = plt.figure(figsize=figsize)
  gs = gridspec.GridSpec(1, len(Array))
  for i,(key, traces) in enumerate(Array.items()):
    n_samples = traces.shape[1]
    # Time axis built from the sample count: np.arange with a float step can
    # return one element more or fewer than n_samples because of rounding,
    # which made ax.plot fail with a length mismatch.
    t = xmin + np.arange(n_samples) / freq
    xmax = n_samples/freq + xmin
    # Second marker, window[1] seconds before the end of the axis
    event_line = xmax - window[1]

    ax = fig.add_subplot(gs[i])
    ax.set_title(key, fontsize='xx-large')
    ax.plot(t,traces.T,color=color_trace,alpha=0.3,linewidth=1)
    ax.plot(t,Mean[key],color=color_mean,linewidth=2)
    ax.fill_between(t, Mean[key]-Error[key],Mean[key]+Error[key],
                    alpha=0.3,edgecolor=color_mean,facecolor=color_mean,linewidth=0)
    ax.axvline(0,linestyle='--',color='black',linewidth=1.5)
    ax.axvline(event_line,linestyle='--',color='black',linewidth=1.5)
    ax.set_xlabel('time', fontsize='xx-large', multialignment='center')
    ax.set_ylabel('z dF/F', fontsize='xx-large', multialignment='center')
    ax.set_ylim([ymin,ymax])
    ax.set_xlim([xmin,xmax])
    ax.tick_params(labelsize='x-large')

  return fig

In [0]:
def create_event_array(dFF,event,time_vector,window=[-5,5],dur=None,iei=None,avg_win=None):
  """Cut the dF/F trace around behavior events into peri-event arrays.

  Parameters
    dFF: dF/F vector (presumably 1-D) sampled at time_vector.
    event: a dict (continuous measure, ignored here), an empty sequence, or
           a 2-D array of event times in seconds with one column (single
           occurrences) or two columns ([onset, offset] pairs).
    time_vector: time stamps of dFF in seconds.
    window: [start, end] of the window in seconds relative to the event;
            start is negative for the default [-5, 5].
    dur, iei: minimum duration and minimum inter-event interval. They are
              only used for onset/offset events of unequal duration
              (forwarded to adjust_intervals_durations).
    avg_win: optional [start, end] in seconds. The mean of each trace over
             the corresponding samples is subtracted from that trace (see
             the review note on how the indices are computed).

  Returns
    dict with one of two layouts:
      {'events': (n_events, n_samples) array} for single occurrences and
      for equal-duration onset/offset events;
      {'onset': ..., 'offset': ...} for onset/offset events of unequal
      duration.
    Traces containing NaN are dropped. The dict is empty for continuous or
    empty events.
  """

  # Median sampling interval, used to convert seconds to sample counts
  period = find_avg_period(time_vector, time_format='total seconds')

  Event_array = {}

 # Continuous events -----------------------------------------------------------
  if isinstance(event, dict) or len(event)==0:
    pass

  else:

   # Remove events at the beginning and end of test that less than win
    # NOTE(review): `event > window[0]` keeps every time later than window[0].
    # With the default window[0] = -5, no event near the start is removed, so
    # chunk_signal may get a negative start index (wrapped or short chunks).
    # `event > -window[0]` may have been intended — confirm.
    event = event[np.all(event > window[0], axis=1)]
    event = event[np.all(event < max(time_vector)-window[1], axis=1)]

   # Events with one occurence -------------------------------------------------
    if event.shape[1]==1:
      Array = []
      for e in event:
        # Trace from about window[0] to window[1] s around the event
        dFF_event = chunk_signal(dFF,e,time_vector,window)
        if avg_win != None:
          # NOTE(review): dFF_event[0] lies about window[0] s from the event,
          # so an avg_win given relative to the event would map to indices
          # (avg_win - window[0]) / period. Here window[0] + avg_win is used.
          # Confirm the intended avg_win convention.
          dFF_event_mean = (dFF_event[int((window[0]+avg_win[0])/period):int((window[0]+avg_win[1])/period)]).mean()
          dFF_event = dFF_event - dFF_event_mean
        Array.append(dFF_event)
      # Stack to (n_events, n_samples). squeeze() drops singleton axes, so a
      # single event is reshaped back into one row.
      Array = np.array(Array).squeeze()
      if Array.ndim == 1:
        Array = Array.reshape(1,len(Array))
      # Drop traces containing NaN (e.g. overlapping the NaN-filled start)
      Array = Array[~np.isnan(Array).any(axis=1)]

      Event_array['events'] = Array

   # Events with onset and offset ----------------------------------------------
    elif event.shape[1]==2:

      events_dur = [e1 - e0 for e0,e1 in event]

     # Events with same duration -----------------------------------------------
      if all(x == events_dur[0] for x in events_dur):

        Array = []
        for e0,e1 in event:
          # NOTE(review): the window passed here is [-window[0], window[1]+e1].
          # It starts -window[0] s (5 s for the default) AFTER the onset and
          # adds the absolute offset time e1 to window[1], so the chunk length
          # depends on e1 and stacking below may fail for several events.
          # A window spanning the event would be
          # [window[0], window[1] + (e1 - e0)] — confirm.
          dFF_event = chunk_signal(dFF,e0,time_vector,[-window[0],window[1]+e1])
         # Normalize in avg_win (see the avg_win review note above)
          if avg_win != None:
            dFF_event_mean = (dFF_event[int((window[0]+avg_win[0])/period):int((window[0]+avg_win[1])/period)]).mean()
            dFF_event = dFF_event - dFF_event_mean
          Array.append(dFF_event)
        Array = np.array(Array).squeeze()
        if Array.ndim == 1:
          Array = Array.reshape(1,len(Array))
        Array = Array[~np.isnan(Array).any(axis=1)]

        Event_array['events'] = Array

     # Events with different duration ------------------------------------------
      else:
      # Remove short intervals and durations
        event = adjust_intervals_durations(event, iei, dur)

       # Create Array
        Array_start = []
        Array_end = []
        for e0,e1 in event:
          # NOTE(review): same window expression as in the equal-duration
          # branch above — confirm.
          dFF_event = chunk_signal(dFF,e0,time_vector,[-window[0],window[1]+e1])
         # Normalize in avg_win
          if avg_win != None:
            dFF_event_mean = (dFF_event[int((window[0]+avg_win[0])/period):int((window[0]+avg_win[1])/period)]).mean()
            dFF_event = dFF_event - dFF_event_mean
          # First and last (window[1]-window[0]) s of the chunk, presumably
          # intended as onset-aligned and offset-aligned traces
          Array_start.append(dFF_event[:int((window[1]-window[0])/period)])
          Array_end.append(dFF_event[-int((window[1]-window[0])/period):])
        Array_start = np.array(Array_start).squeeze()
        Array_end = np.array(Array_end).squeeze()
        # Both lists hold one entry per event, so they squeeze alike
        if Array_start.ndim == 1:
          Array_start = Array_start.reshape(1,len(Array_start))
          Array_end = Array_end.reshape(1,len(Array_end))
        Array_start = Array_start[~np.isnan(Array_start).any(axis=1)]
        Array_end = Array_end[~np.isnan(Array_end).any(axis=1)]

        Event_array['onset'] = Array_start
        Event_array['offset'] = Array_end

  return Event_array

In [0]:
def chunk_signal(signal, t0, time_vector, w):
  """Slice `signal` around the sample closest to time t0.

  w: [start, end] offsets in seconds relative to t0. The offsets are
  converted to samples with the median sampling period (truncated toward
  zero by int()), and the slice is widened by one sample on each side.
  Indices are not clipped, so a negative start index follows Python's
  slicing rules.
  """
  center = find_idx(t0, time_vector, 'total seconds')
  dt = find_avg_period(time_vector, time_format='total seconds')
  first, last = (center + int(w[0] / dt) - 1,
                 center + int(w[1] / dt) + 1)
  return signal[first:last]

# OS helper functions

### Find file in the list of files by part of the name

In [0]:
def find_file(folder,string):
    """Return the names of entries in `folder` whose name contains `string`.

    Names are returned without the folder prefix and in os.listdir() order.
    """
    return [name for name in os.listdir(folder) if string in name]

### Create folder if it doesn't exist

In [0]:
def create_folder(new_folder):
  """Create `new_folder`, including any missing parent folders, if needed.

  os.makedirs(exist_ok=True) checks and creates the same path, so absolute
  paths and nested folders work, and there is no check-then-create race.

  Raises
    FileExistsError: if the path exists but is not a directory.
  """
  os.makedirs(new_folder, exist_ok=True)

# Time helper functions

### Transform real time 'HH:MM:SS.ms' to total seconds

In [0]:
def time_to_seconds(t, t0):
  """Convert time stamps to seconds elapsed since t0.

  t: iterable of stamps accepted by pd.Timedelta (e.g. 'HH:MM:SS.ms').
  t0: reference stamp in the same format.
  Returns a float numpy array (negative for stamps before t0).
  """
  origin = pd.Timedelta(t0)
  return np.fromiter(((pd.Timedelta(stamp) - origin).total_seconds() for stamp in t),
                     dtype=float)

### Find the average sampling period of a recording

In [0]:
def find_avg_period(t, time_format='total seconds'):
  """Return the median interval between consecutive time stamps.

  t: time stamps. With time_format='real time' they are 'HH:MM:SS.ms'
     strings, first converted to seconds with time_to_seconds. Otherwise
     they are numbers in seconds.
  """
  if time_format == 'real time':
    t = time_to_seconds(t, t[0])
  steps = np.diff(np.asarray(t))
  return np.median(steps)

### Find index based on time

In [0]:
def find_idx(t, time_vector, time_format='total seconds'):
  """Return the index of the sample in `time_vector` closest to time `t`.

  Parameters
    t: time to look up.
       - time_format='real time': a stamp parsed by pd.Timedelta
         (e.g. 'HH:MM:SS.ms'), expressed relative to time_vector[0].
       - time_format='total seconds': a number compared against time_vector
         shifted to start at 0 (see the review note below).
    time_vector: sequence of time stamps in the same format.
    time_format: 'real time' or 'total seconds'.

  Returns
    integer index, the argmin of |time_vector - t|.
  """

  if time_format == 'real time':
    # Convert both the vector and t to seconds since the first time stamp
    t0 = pd.Timedelta(time_vector[0])
    time_vector = np.array([(pd.Timedelta(t)-t0).total_seconds() for t in time_vector])
    t = (pd.Timedelta(t) - t0).total_seconds()
  
  elif time_format == 'total seconds':
    # np.array() copies, so the caller's time_vector is not modified
    time_vector = np.array(time_vector)
    time_vector -= time_vector[0]
    # NOTE(review): time_vector[0] is already 0 here, so t is NOT shifted and
    # is effectively treated as seconds since the first sample, unlike the
    # 'real time' branch. If t is a numpy array, this in-place subtraction
    # also acts on the caller's object (by 0). It can raise for integer
    # arrays when time_vector is float. Confirm the intended convention.
    t -= time_vector[0]

  idx = (np.abs(time_vector-t)).argmin()

  return idx 

# Behavior events helper functions

### Get vector of values from Med Associates file

In [0]:
def values_from_medfile(file_name,letter):
  """Read the values stored under one array letter of a Med Associates file.

  The search for the first line containing `letter + ':'` starts at the
  14th line; the first 13 lines are skipped, presumably the file header.
  Values are read from that line up to (excluding) the next line that
  starts with another '<LETTER>:' label, or to the end of the file. Array
  labels ('A:' ... 'Z:') and row labels such as '0:' are ignored.

  Parameters
    file_name: path to the Med Associates text file.
    letter: array letter, e.g. 'A'.

  Returns
    list of float values.

  Raises
    ValueError: if no line containing `letter + ':'` is found.
  """
  import string
  med_letters = [ch + ':' for ch in string.ascii_uppercase]

  # Get info from the file (closed even if reading fails)
  with open(file_name, 'r') as file:
    lines = [line.rstrip('\n') for line in file]

  label = letter + ':'
  # Line holding the requested array label (header lines skipped)
  start_line = next((idx for idx in range(13, len(lines)) if label in lines[idx]), None)
  if start_line is None:
    raise ValueError('Array ' + label + ' not found in ' + str(file_name))

  # Next line that starts a new array, or the end of the file
  end_line = next((idx for idx in range(start_line+1, len(lines))
                   if lines[idx][:2] in med_letters), len(lines))

  # Create a list of values
  values = []
  for line in lines[start_line:end_line]:
    # Strip array labels like 'A:' so only row labels and numbers remain
    for med_letter in med_letters:
      if med_letter in line:
        line = line.replace(med_letter,'')
    for element in line.split(' '):
      # Skip empty strings and row labels such as '0:'
      if element != '' and ':' not in element:
        values.append(float(element))

  return values

### Find onsets and offsets based on binary vector

In [0]:
def event_onoffset(vector,event_time,t0='00:00:00.0',time_format='real time'):
  '''
  Find event onsets and offsets from a binary vector.

  input
    vector: binary vector, where 1 (any value > 0) means the event was
            present and 0 means no event;
    event_time: time of each sample of `vector`, same length as vector;
    t0: time when the fiber photometry recording started. All returned times
        are relative to t0.
        - With 'real time' it is a 'HH:MM:SS.ms' stamp.
        - With 'total seconds' it is a number of seconds; a 'HH:MM:SS.ms'
          string, such as the default, is converted to seconds.
        If event recordings started after photometry, the value should be
        negative.
    time_format: 'real time' is the format 'HH:MM:SS.ms',
                 'total seconds' is a vector of numbers

  output
    event: numpy array of shape (n_events, 2). Column 0 holds the onsets
           (first sample where the event is present) and column 1 the
           offsets (last sample where it is present). An event already
           present at the first sample starts at the first time stamp; one
           still present at the last sample ends at the last time stamp.

  raises
    ValueError: for an unknown time_format.
  '''
  # Convert times to total seconds if they are given as real time
  if time_format == 'real time':
    total_sec = [pd.Timedelta(t).total_seconds() for t in event_time]
    t0 = pd.Timedelta(t0).total_seconds()
  elif time_format == 'total seconds':
    total_sec = list(event_time)
    if isinstance(t0, str):
      # Allow a 'HH:MM:SS.ms' t0 (e.g. the default) with numeric times
      t0 = pd.Timedelta(t0).total_seconds()
  else:
    raise ValueError("time_format must be 'real time' or 'total seconds'")

  # Express times relative to t0
  total_sec = [t-t0 for t in total_sec]

  # Float copy so boolean vectors give signed differences
  vector = np.asarray(vector, dtype=float)
  diff_v = np.diff(vector)

  # Rising edges (diff > 0): onset at the next sample.
  # Falling edges (diff < 0): offset at the current sample.
  onset = [total_sec[i+1] for i in range(len(diff_v)) if diff_v[i] > 0]
  offset = [total_sec[i] for i in range(len(diff_v)) if diff_v[i] < 0]

  # Close events already running at the start or still running at the end,
  # otherwise onsets and offsets would not pair up
  if vector.size:
    if vector[0] > 0:
      onset.insert(0, total_sec[0])
    if vector[-1] > 0:
      offset.append(total_sec[-1])

  # Pair onsets with offsets: rows [onset, offset] per event. A transpose is
  # needed; reshape(len(onset), 2) interleaved values of different events.
  event = np.array([onset, offset]).T

  return event

### Remove events with short inter event intervals and/or short duration

In [0]:
def adjust_intervals_durations(event, min_inter=None, min_dur=None):
  """Merge events separated by short gaps and drop short events.

  Parameters
    event: (n_events, 2) array-like of [onset, offset] pairs in time order.
    min_inter: if given, whenever the gap between an offset and the next
               onset is shorter than min_inter, the two events are merged
               (that offset and the following onset are removed).
    min_dur: if given, only events lasting strictly longer than min_dur
             are kept.

  Returns
    numpy array of shape (n_kept, 2). It was previously a list when min_dur
    was given and an array otherwise.
  """
  event = np.asarray(event)
  if event.size == 0:
    event = event.reshape(0, 2)

  if min_inter is not None:
    onset = event[:, 0]
    offset = event[:, 1]
    gaps = onset[1:] - offset[:-1]
    short = np.flatnonzero(gaps < min_inter)
    # Gap i lies between event i and event i+1: drop offset i and onset i+1
    onset = np.delete(onset, short + 1)
    offset = np.delete(offset, short)
    event = np.column_stack([onset, offset])

  if min_dur is not None:
    durations = event[:, 1] - event[:, 0]
    event = event[durations > min_dur]

  return event

### Find speed from coordinates and time

In [0]:
def get_speed(x_coord,y_coord,rec_time,smooth_win=5):
  '''Estimate speed from 2D coordinates and their timestamps.

  Each sample uses a central difference over its two neighbours; the first
  and last samples use a one-sided difference. A sample whose x or y is None
  is treated as missing: a missing neighbour is replaced by the sample itself
  (one-sided difference), and when no valid pair is left the speed is NaN.
  NaN coordinates are not treated as missing and simply propagate.

  input
    x_coord, y_coord: coordinates, same length (None marks a missing sample)
    rec_time: timestamps, same length as the coordinates
    smooth_win: currently unused (kept for backward compatibility)

  output
    speed: numpy array with one value per sample (distance / time units)
  '''
  n = len(x_coord)
  speed = np.zeros(n)
  missing = [x is None or y is None for x, y in zip(x_coord, y_coord)]

  for i in range(n):
    # Fall back to the sample itself when a neighbour is absent or missing
    i_prev = i - 1 if i > 0 and not missing[i - 1] else i
    i_next = i + 1 if i < n - 1 and not missing[i + 1] else i

    if i_prev == i_next or missing[i_prev] or missing[i_next]:
      # No pair of valid samples to difference
      speed[i] = np.nan
      continue

    dist = np.sqrt((x_coord[i_next]-x_coord[i_prev])**2 + (y_coord[i_next]-y_coord[i_prev])**2)
    time_dif = rec_time[i_next]-rec_time[i_prev]
    speed[i] = dist / time_dif

  return speed

### Align continuous vector to fiber photometry recordings 

In [0]:
def align_vector_am2fp(vector_am,time_am,time_fp,time='real time',norm=True):
  '''Map a behavior (ANY-maze) vector onto the fiber photometry (FP) time base.

  Each behavior sample is written at the FP index matched to its timestamp.
  FP samples that are still 0 afterwards (no behavior sample, or a NaN/zero
  value) take the previous value; a leading run of zeros takes the next
  value. The result is smoothed with moving_average(..., 5) and optionally
  rescaled to the range 0-1.

  input
    vector_am: behavior values; only the first len(time_am) are used
    time_am: behavior timestamps
    time_fp: FP timestamps
    time: 'real time' -> FP indices from idx_event(time_fp, time_am);
          'test time' -> FP indices from np.searchsorted(time_fp, time_am)
    norm: if truthy, min-max rescale the result to 0-1

  output
    vector: numpy array of shape (len(time_fp), 1)

  raises
    ValueError: if ``time`` is neither 'real time' nor 'test time'
  '''
  # Behavior values as a column vector, truncated to the number of timestamps
  vector_am = np.array(vector_am[:len(time_am)]).reshape(len(time_am),1)

  # FP indices corresponding to each behavior timestamp
  if time=='real time':
    fp_idxs = idx_event(time_fp, time_am).reshape(len(time_am),)
  elif time=='test time':
    fp_idxs = np.searchsorted(time_fp,time_am)
  else:
    raise ValueError("time must be 'real time' or 'test time', got {!r}".format(time))

  # Drop behavior samples that map past the last FP sample
  # (searchsorted returns len(time_fp) for them, which would be out of range)
  fp_idxs = np.asarray(fp_idxs)
  in_range = fp_idxs < len(time_fp)
  fp_idxs = fp_idxs[in_range]
  vector_am = vector_am[in_range]

  # New vector with the length of the FP data; unmatched samples stay 0
  vector = np.zeros((len(time_fp),1))
  vector[fp_idxs,:] = vector_am

  # NaNs become 0 and are therefore filled like missing samples
  vector = np.nan_to_num(vector)

  # Forward fill; start at 1 so vector[i-1] never wraps to the last sample
  for i in range(1, len(vector)):
    if vector[i] == 0:
      vector[i] = vector[i-1]
  # Backward fill a leading gap; start at len-2 so vector[i+1] stays in range
  for i in range(len(vector)-2, -1, -1):
    if vector[i] == 0:
      vector[i] = vector[i+1]

  # Smooth
  vector[:,0] = moving_average(vector[:,0], 5)

  if norm:
    # Normalize to range 0-1
    old_min = vector.min()
    old_max = vector.max()
    span = old_max - old_min
    # A constant signal has no range: return zeros instead of 0/0 = NaN
    vector = (vector - old_min) / span if span > 0 else vector - old_min

  return vector

### Find immobility onsets and offsets

In [0]:
def find_onoffset_immobility(movement, rec_time, min_duration=2, on_threshold=0.01, off_threshold=0.015):
  '''Detect immobility bouts in a movement trace.

  Scanning starts after the last NaN in ``movement`` (or at 0 if it has no
  NaN). A bout starts at a sample below ``on_threshold`` if all samples from
  it up to min_duration later (in rec_time units) stay below
  ``on_threshold``. A window cut short by the end of the recording is not
  accepted. The bout ends at the last sample before movement reaches
  ``off_threshold``. Bouts separated by less than ``min_duration`` are merged.

  input
    movement: numpy array of movement values
    rec_time: timestamps matching ``movement``
    min_duration: minimal bout length and minimal gap between bouts
    on_threshold: movement level below which a bout may start
    off_threshold: movement level at or above which a bout ends

  output
    immobility: numpy array of shape (n_bouts, 2) with [onset, offset] times
  '''
  n = len(movement)
  nan_idx = np.flatnonzero(np.isnan(movement))
  # np.max on an empty argwhere would raise, so start at 0 when there is no NaN
  i = nan_idx[-1] + 1 if len(nan_idx) else 0

  onset = []
  offset = []

  while i < n:

    if movement[i] < on_threshold:
      i1 = i
      t0 = rec_time[i]
      # Advance to the first sample at least min_duration after t0
      while i < n and rec_time[i] - t0 < min_duration:
        i += 1

      # i == n means the window was truncated by the end of the recording
      if i < n and (movement[i1:i] < on_threshold).all():
        onset.append(rec_time[i1])

        while i < n and movement[i] < off_threshold:
          i += 1

        offset.append(rec_time[i-1])

    i += 1

  # Merge bouts whose gap is shorter than min_duration
  onset = np.array(onset)
  offset = np.array(offset)
  gaps = onset[1:] - offset[:-1]
  short = np.flatnonzero(gaps < min_duration)
  onset = np.delete(onset, short + 1)
  offset = np.delete(offset, short)

  immobility = np.array([onset,offset]).T

  return immobility

# Plot helper functions

### Make patch spines invisible

In [0]:
def make_patch_spines_invisible(ax):
    '''Hide the background patch and every spine of ``ax``.

    The frame itself stays enabled, so callers can re-show a single spine
    afterwards (e.g. the right spine of an extra, offset twin y-axis).
    '''
    keep_frame = True
    ax.set_frame_on(keep_frame)
    ax.patch.set_visible(False)
    for side in ax.spines:
        ax.spines[side].set_visible(False)

### Plot colors

In [0]:
def get_cmap(n, name='gist_rainbow'):
    '''Return a colormap resampled to n colors.

    The returned Colormap is callable: it maps each index in 0, 1, ..., n-1 to
    a distinct RGBA color. ``name`` must be a standard matplotlib colormap name.

    raises
      ValueError: if ``name`` is not a registered colormap
    '''
    try:
        registry = mpl.colormaps
    except AttributeError:
        # matplotlib < 3.5 has no colormap registry: use the legacy API
        return plt.cm.get_cmap(name, n)

    try:
        base = registry[name]
    except KeyError:
        # Keep the exception type of the legacy plt.cm.get_cmap
        raise ValueError('{!r} is not a valid colormap name'.format(name)) from None

    if hasattr(base, 'resampled'):
        return base.resampled(n)
    # matplotlib 3.5 has the registry but not Colormap.resampled
    return plt.cm.get_cmap(name, n)

### Plot green and red dFF signals on one figure

In [0]:
def plot_dff_green_red(dFF_green,dFF_red,time,events,exp_info,fiber,test,movement=None,save=False):
    '''Plot green and red dF/F traces of one fiber on twin y-axes over time.

    input
      dFF_green, dFF_red: traces plotted against ``time``
      time: time vector (s)
      events: 'AP' -> iterable of indices into ``time``, drawn as vertical lines;
              'CT'/'TST' -> 2D array of [start, end] indices into ``time``,
              drawn as shaded spans
      exp_info: dict with 'mouse', 'test', 'green', 'red' (and 'folder' when saving)
      fiber: key into exp_info['green'] / exp_info['red']
      test: 'AP', 'CT' or 'TST'; any other value draws no events
      movement: optional trace drawn on a third, offset y-axis
      save: if True, save a PNG under <folder>/figures/dFF/

    output
      fig: the created matplotlib Figure (returned so callers can close it)
    '''
    has_movement = movement is not None and len(movement) != 0

    # Symmetric minimum range of +/-1, widened by 10% around the data
    ymax_green = max(1,1.1*max(dFF_green))
    ymin_green = min(-1,1.1*min(dFF_green))

    ymax_red = max(1,1.1*max(dFF_red))
    ymin_red = min(-1,1.1*min(dFF_red))

    fig, ax = plt.subplots(figsize=(24, 12))
    fig.suptitle(exp_info['mouse'] + ' ' + exp_info['green'][fiber] + ' vs ' + exp_info['red'][fiber] + ' ' + exp_info['test'], fontsize=32)
    ax1 = ax.twinx()
    ax.patch.set_visible(False) # hide the 'canvas'

    p1, = ax.plot(time, dFF_green, 'green', linewidth=1.5)
    p2, = ax1.plot(time, dFF_red,'red',linewidth=1.5)

    if test == 'AP':
        for i in events:
            ax.axvline(x=time[i],color='red',linewidth=1.5)

    elif test == 'CT':
        for i in range(len(events)):
            ax.axvspan(time[events[i,0]],time[events[i,1]],color='red',alpha=0.3)

    elif test == 'TST':
        for i in range(len(events)):
            ax.axvspan(time[events[i,0]],time[events[i,1]],color='blue',alpha=0.2)

    ax.set_xlim(min(time), max(time))
    # Offset the two y-ranges so the green and red traces overlap less
    ax.set_ylim(ymin_green-2, ymax_green+2)
    ax1.set_ylim(ymin_red, ymax_red+4)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_color(p1.get_color())
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    ax.spines['left'].set_linewidth(2)
    ax.spines['bottom'].set_linewidth(2)

    ax1.spines['top'].set_visible(False)
    ax1.spines['left'].set_visible(False)
    ax1.spines['right'].set_color(p2.get_color())
    ax1.spines['right'].set_linewidth(2)

    ax.set_xlabel('Time (s)', fontsize=24, multialignment='center')
    ax.set_ylabel('z dF/F ('+exp_info['green'][fiber]+')',fontsize=24,multialignment='center')
    ax1.set_ylabel('z dF/F ('+exp_info['red'][fiber]+')',fontsize=24,multialignment='center')

    ax.yaxis.label.set_color(p1.get_color())
    ax1.yaxis.label.set_color(p2.get_color())

    tkw = dict(size=4, width=1.5)
    ax.tick_params(axis='x', **tkw)
    ax.tick_params(axis='y', colors=p1.get_color(), **tkw)
    ax1.tick_params(axis='y', colors=p2.get_color(), **tkw)

    ax.xaxis.set_tick_params(labelsize=18)
    ax.yaxis.set_tick_params(labelsize=18)
    ax1.yaxis.set_tick_params(labelsize=18)

    if has_movement:
        # Third y-axis: an extra twin whose right spine is pushed outwards
        fig.subplots_adjust(right=0.9)
        ax2 = ax.twinx()
        ax2.spines["right"].set_position(("axes", 1.07))
        make_patch_spines_invisible(ax2)
        ax2.spines["right"].set_visible(True)

        p3, = ax2.plot(time, movement,'blue',linewidth=1.5)
        ax2.set_ylim(-2, 1)

        ax2.spines['top'].set_visible(False)
        ax2.spines['left'].set_visible(False)
        ax2.spines['right'].set_color(p3.get_color())
        ax2.spines['right'].set_linewidth(2)

        ax2.set_ylabel('Mobility score',fontsize=24,multialignment='center')
        ax2.yaxis.label.set_color(p3.get_color())
        ax2.tick_params(axis='y', colors=p3.get_color(), **tkw)
        ax2.yaxis.set_tick_params(labelsize=18)

    # Save figure
    if save:
        figfile = exp_info['mouse'] + '_' + exp_info['green'][fiber] + '_' + exp_info['red'][fiber] + '_' + exp_info['test'] + '_dFF.png'
        fig.savefig(exp_info['folder']+'/figures/dFF/'+figfile)

    return fig

### Plot red and green normalized events on one figure

In [0]:
def _norm_mean_sem(arr):
    '''Return (mean, SEM) of a samples x trials array across trials (axis 1).'''
    # Call the object's own methods so pandas inputs keep pandas' std (ddof=1)
    return arr.mean(axis=1), arr.std(axis=1) / np.sqrt(arr.shape[1])


def _norm_panel(host, t, sec, green, red, move, exp_info, fiber,
                ylim_green, ylim_red, move_ref=None,
                label_green=True, label_red=True, label_move=True,
                move_offset=1.08):
    '''Draw one peri-event panel of plot_norm_green_red.

    Draws the green mean +/- SEM on ``host``, the red one on a twin y-axis and,
    if ``move`` is given, movement on a third y-axis offset to ``move_offset``.
    green, red, move: (mean, sem) pairs; move may be None
    ylim_green, ylim_red: (ymin, ymax) before the fixed padding is added
    move_ref: trace whose maximum sets the movement axis (-1 .. 1.5*max)
    label_*: whether to write the y-label of the corresponding axis
    '''
    tkw = dict(size=4, width=1.5)
    ax_red = host.twinx()
    host.patch.set_visible(False) # hide the 'canvas'

    p_green, = host.plot(t, green[0], 'green', linewidth=2)
    host.fill_between(t, green[0]-green[1], green[0]+green[1],
                      alpha=0.3,edgecolor='green',facecolor='green',linewidth=0)
    p_red, = ax_red.plot(t, red[0],'red',linewidth=2)
    ax_red.fill_between(t, red[0]-red[1], red[0]+red[1],
                        alpha=0.3,edgecolor='red',facecolor='red',linewidth=0)
    ax_red.axvline(0,color='black',linestyle='--',linewidth=2)

    host.set_xlim([-sec,sec])
    host.set_ylim(ylim_green[0]-0.5, ylim_green[1]+0.5)
    ax_red.set_ylim(ylim_red[0], ylim_red[1]+1)

    host.spines['top'].set_visible(False)
    host.spines['right'].set_visible(False)
    host.spines['left'].set_color(p_green.get_color())
    host.yaxis.set_ticks_position('left')
    host.xaxis.set_ticks_position('bottom')
    host.spines['left'].set_linewidth(2)
    host.spines['bottom'].set_linewidth(2)

    ax_red.spines['top'].set_visible(False)
    ax_red.spines['left'].set_visible(False)
    ax_red.spines['right'].set_color(p_red.get_color())
    ax_red.spines['right'].set_linewidth(2)

    host.set_xlabel('Time (s)', fontsize=24, multialignment='center')
    if label_green:
        host.set_ylabel('z dF/F ('+exp_info['green'][fiber]+')',fontsize=24,multialignment='center')
    if label_red:
        ax_red.set_ylabel('z dF/F ('+exp_info['red'][fiber]+')',fontsize=24,multialignment='center')

    host.yaxis.label.set_color(p_green.get_color())
    ax_red.yaxis.label.set_color(p_red.get_color())

    host.tick_params(axis='x', **tkw)
    host.tick_params(axis='y', colors=p_green.get_color(), **tkw)
    ax_red.tick_params(axis='y', colors=p_red.get_color(), **tkw)

    host.xaxis.set_tick_params(labelsize=18)
    host.yaxis.set_tick_params(labelsize=18)
    ax_red.yaxis.set_tick_params(labelsize=18)

    if move is None:
        return

    ax_move = host.twinx()
    ax_move.spines["right"].set_position(("axes", move_offset))
    make_patch_spines_invisible(ax_move)
    ax_move.spines["right"].set_visible(True)

    p_move, = ax_move.plot(t, move[0], color='blue', linewidth=2)
    ax_move.fill_between(t, move[0]-move[1], move[0]+move[1],
                         alpha=0.3,edgecolor='blue',facecolor='blue',linewidth=0)
    ax_move.set_ylim(-1, 1.5*np.max(move_ref))

    ax_move.spines['top'].set_visible(False)
    ax_move.spines['left'].set_visible(False)
    ax_move.spines['right'].set_color(p_move.get_color())
    ax_move.spines['right'].set_linewidth(2)

    if label_move:
        ax_move.set_ylabel('Mobility score',fontsize=24,multialignment='center')
    ax_move.yaxis.label.set_color(p_move.get_color())
    ax_move.tick_params(axis='y', colors=p_move.get_color(), **tkw)
    ax_move.yaxis.set_tick_params(labelsize=18)


def plot_norm_green_red(Array_green,Array_red,exp_info,fiber,test,Array_movement=None,save=False):
    '''Plot peri-event mean +/- SEM of green and red signals (and movement).

    input
      Array_green, Array_red: dicts with a 'start' entry (and an 'end' entry
        for 'CT'/'TST'); each entry is samples x trials and is averaged
        across trials (axis 1)
      exp_info: dict with 'mouse', 'test', 'green', 'red' (and 'folder' when saving)
      fiber: key into exp_info['green'] / exp_info['red']
      test: 'CT' or 'TST' -> two panels (event start and end);
            'AP' -> one panel around the event
      Array_movement: optional dict with the same keys, drawn on a third y-axis
      save: if True, save a PNG under <folder>/figures/mean/

    output
      fig: the created matplotlib Figure

    raises
      ValueError: for any other ``test`` value

    NOTE(review): the time axis uses the notebook-level global ``Hz``; traces
    are assumed to start ``sec`` = 5 s before the event.
    '''
    plt.close('all')
    sec = 5
    has_move = Array_movement is not None and len(Array_movement) != 0
    title = (exp_info['mouse'] + ' ' + exp_info['green'][fiber] +
             ' vs ' + exp_info['red'][fiber] + ' ' + exp_info['test'])

    # TST test: start of mobility and immobility. Consumption test: onset and offset
    if test == 'CT' or test == 'TST':
        green_start = _norm_mean_sem(Array_green['start'])
        green_end = _norm_mean_sem(Array_green['end'])
        red_start = _norm_mean_sem(Array_red['start'])
        red_end = _norm_mean_sem(Array_red['end'])
        move_start = _norm_mean_sem(Array_movement['start']) if has_move else None
        move_end = _norm_mean_sem(Array_movement['end']) if has_move else None
        # Both panels share one movement scale, taken from the start trace
        move_ref = move_start[0] if has_move else None

        # Shared y-ranges: upper bound from start traces, lower bound from end traces
        ylim_green = (min(-0.5,1.5*green_end[0].min()), max(0.5,1.5*green_start[0].max()))
        ylim_red = (min(-0.5,1.5*red_end[0].min()), max(0.5,1.5*red_start[0].max()))

        # One timestamp per sample; a float-step np.arange may be one element off
        t = -sec + np.arange(len(green_start[0])) / Hz

        fig = plt.figure(figsize=(25, 10))
        fig.suptitle(title, fontsize=32)

        host1 = fig.add_subplot(121)
        host1.set_title('Start', fontsize=24)
        _norm_panel(host1, t, sec, green_start, red_start, move_start,
                    exp_info, fiber, ylim_green, ylim_red, move_ref=move_ref,
                    label_red=False, label_move=False, move_offset=1.08)

        host2 = fig.add_subplot(122)
        host2.set_title('End', fontsize=24)
        _norm_panel(host2, t, sec, green_end, red_end, move_end,
                    exp_info, fiber, ylim_green, ylim_red, move_ref=move_ref,
                    label_green=False, move_offset=1.15)

    # One event
    elif test == 'AP':
        green_start = _norm_mean_sem(Array_green['start'])
        red_start = _norm_mean_sem(Array_red['start'])
        move_start = _norm_mean_sem(Array_movement['start']) if has_move else None

        ylim_green = (min(-0.5,1.5*green_start[0].min()), max(0.5,1.5*green_start[0].max()))
        ylim_red = (min(-0.5,1.5*red_start[0].min()), max(0.5,1.5*red_start[0].max()))

        t = -sec + np.arange(len(green_start[0])) / Hz

        fig = plt.figure(figsize=(13, 10))
        fig.suptitle(title, fontsize=32)

        host1 = fig.add_subplot(111)
        _norm_panel(host1, t, sec, green_start, red_start, move_start,
                    exp_info, fiber, ylim_green, ylim_red,
                    move_ref=move_start[0] if has_move else None,
                    move_offset=1.08)

    else:
        raise ValueError("test must be 'CT', 'TST' or 'AP', got {!r}".format(test))

    # Save figure
    if save:
        figfile = exp_info['mouse'] + '_' + exp_info['green'][fiber] + '_' + exp_info['red'][fiber] + '_' + exp_info['test'] + '_mean.png'
        fig.savefig(exp_info['folder']+'/figures/mean/'+figfile)

    return fig

### Plot summary with mean for each animal

In [0]:
def plot_sum(Array,Means,title,periods=('baseline','event'),color='green',
             figsize=(3.5, 3.5),xmin=-5,xmax=5,ymin=-1,ymax=5,
             Array_move=None,color0='blue',ymin0=-1.5,ymax0=1):
  '''Plot peri-event traces next to their per-period means.

  Left panel: ``Array`` against time relative to the event (rows are samples;
  presumably one column per animal). Right panel: ``Means`` at one x position
  per entry of ``periods``. An optional ``Array_move`` is drawn on a twin
  y-axis of the left panel.

  NOTE(review): the time axis uses the notebook-level globals ``sec`` and
  ``Hz``, which are not defined in this function.

  output
    fig: the created matplotlib Figure
  '''
  from matplotlib import gridspec

  # One timestamp per sample; a float-step np.arange(-sec, sec, 1/Hz) can be
  # one element longer or shorter than the data and break plt.plot
  t = -sec + np.arange(len(Array)) / Hz
  x_periods = range(1,2*len(periods),2)
  tkw = dict(size=4, width=1.5)

  fig = plt.figure(figsize=figsize)
  fig.suptitle(title,fontsize=24,color=color)

  gs = gridspec.GridSpec(1, 2, width_ratios=[4, 1])
  ax1 = fig.add_subplot(gs[0])
  ax1.patch.set_visible(False)

  ax1.axvline(0,linestyle='--',color='black',linewidth=1.5)
  ax1.plot(t,Array,color=color,linewidth=1)

  ax1.set_xlim([xmin,xmax])
  ax1.set_ylim(ymin, ymax)

  ax1.spines['top'].set_visible(False)
  ax1.spines['right'].set_visible(False)
  ax1.yaxis.set_ticks_position('left')
  ax1.xaxis.set_ticks_position('bottom')
  ax1.spines['left'].set_linewidth(2)
  ax1.spines['bottom'].set_linewidth(2)

  ax1.set_ylabel('z dF/F',fontsize=24,multialignment='center')
  ax1.set_xlabel('Time (s)', fontsize=24, multialignment='center')

  ax1.tick_params(axis='y', **tkw)
  ax1.tick_params(axis='x', **tkw)
  ax1.xaxis.set_tick_params(labelsize=18)
  ax1.yaxis.set_tick_params(labelsize=18)

  ax2 = fig.add_subplot(gs[1])
  ax2.patch.set_visible(False)
  ax2.plot(x_periods,Means,color=color,linewidth=1)

  ax2.set_ylim(ymin, ymax)
  ax2.set_xlim([0,2*len(periods)])
  ax2.xaxis.set_ticks(x_periods)
  ax2.set_xticklabels(periods, rotation=60)

  ax2.spines['top'].set_visible(False)
  ax2.spines['right'].set_visible(False)
  ax2.yaxis.set_ticks_position('left')
  ax2.xaxis.set_ticks_position('bottom')
  ax2.spines['left'].set_linewidth(2)
  ax2.spines['bottom'].set_linewidth(2)

  ax2.tick_params(axis='y', **tkw)
  ax2.tick_params(axis='x', **tkw)
  ax2.xaxis.set_tick_params(labelsize=14)
  ax2.yaxis.set_tick_params(labelsize=18)

  if Array_move is not None and len(Array_move) != 0:
    t_move = -sec + np.arange(len(Array_move)) / Hz
    ax0 = ax1.twinx()
    ax0.plot(t_move,Array_move,color=color0,linewidth=1)

    ax0.set_xlim(xmin,xmax)
    ax0.set_ylim(ymin0, ymax0)
    ax0.yaxis.set_ticks([0,0.5])

    ax0.spines['top'].set_visible(False)
    ax0.spines['left'].set_visible(False)
    ax0.yaxis.set_ticks_position('right')
    ax0.xaxis.set_ticks_position('bottom')
    ax0.spines['right'].set_color(color0)
    ax0.spines['right'].set_linewidth(2)
    ax0.spines['bottom'].set_linewidth(2)

    ax0.set_ylabel('Mobility score',color=color0,fontsize=18,multialignment='center')
    ax0.tick_params(axis='y',colors=color0, **tkw)
    ax0.yaxis.label.set_color(color0)
    ax0.tick_params(axis='x', **tkw)
    ax0.xaxis.set_tick_params(labelsize=14)
    ax0.yaxis.set_tick_params(labelsize=14)

  plt.tight_layout()
  plt.show()

  return fig

### Plot two summaries with the mean for each animal

In [0]:
def plot_2sums(Array1,Means1,Array2,Means2,periods,
               title1='GCaMP',title2='jrGECO',
               color1='green',color2='red',
               figsize = (5.5, 3.5),xmin=-5,xmax=5,
               ymin1=-2,ymax1=4,ymin2=-1,ymax2=5):
  '''Plot two peri-event signals and their per-period means on twin y-axes.

  Left panel: Array1 (left axis, color1) and Array2 (right axis, color2)
  against time relative to the event. Right panel: Means1 and Means2 at one x
  position per entry of ``periods``, on the same two y-scales.

  NOTE(review): the time axis uses the notebook-level globals ``sec`` and
  ``Hz``, which are not defined in this function.

  output
    fig: the created matplotlib Figure
  '''
  from matplotlib import gridspec

  # One timestamp per sample; a float-step np.arange(-sec, sec, 1/Hz) can be
  # one element longer or shorter than the data and break plt.plot
  t1 = -sec + np.arange(len(Array1)) / Hz
  t2 = -sec + np.arange(len(Array2)) / Hz
  x_periods = range(1,2*len(periods),2)

  fig = plt.figure(figsize=figsize)
  gs = gridspec.GridSpec(1, 2, width_ratios=[4, 1])

  # Traces: Array1 on ax1, Array2 on its twin ax2
  ax1 = fig.add_subplot(gs[0])
  ax1.patch.set_visible(False)
  ax1.axvline(0,linestyle='--',color='black',linewidth=1.5)
  ax1.plot(t1,Array1,color=color1,linewidth=1)

  ax2 = ax1.twinx()
  ax2.plot(t2,Array2,color=color2,linewidth=1)

  # Means: Means1 on ax3, Means2 on its twin ax4
  ax3 = fig.add_subplot(gs[1])
  ax3.patch.set_visible(False)
  ax3.plot(x_periods,Means1,color=color1,linewidth=1)

  ax4 = ax3.twinx()
  ax4.plot(x_periods,Means2,color=color2,linewidth=1)

  ax1.set_xlim(xmin,xmax)
  ax1.set_ylim(ymin1, ymax1)
  ax1.yaxis.set_ticks(range(int(ymin1),int(ymax1)+1,1))

  ax2.set_xlim(xmin,xmax)
  ax2.set_ylim(ymin2, ymax2)
  ax2.yaxis.set_ticks(range(int(ymin2),int(ymax2)+1,1))

  ax3.set_xlim([0,2*len(periods)])
  ax3.xaxis.set_ticks(x_periods)
  ax3.set_xticklabels(periods, rotation=40,ha='right')
  ax3.set_ylim(ymin1, ymax1)
  ax3.yaxis.set_ticks(range(int(ymin1),int(ymax1)+1,1))

  ax4.set_xlim([0,2*len(periods)])
  ax4.xaxis.set_ticks(x_periods)
  ax4.set_ylim(ymin2, ymax2)
  ax4.yaxis.set_ticks(range(int(ymin2),int(ymax2)+1,1))

  ax1.spines['top'].set_visible(False)
  ax1.spines['right'].set_visible(False)
  ax1.yaxis.set_ticks_position('left')
  ax1.xaxis.set_ticks_position('bottom')
  ax1.spines['left'].set_color(color1)
  ax1.spines['left'].set_linewidth(2)
  ax1.spines['bottom'].set_linewidth(2)

  ax2.spines['top'].set_visible(False)
  ax2.spines['left'].set_visible(False)
  ax2.yaxis.set_ticks_position('right')
  ax2.xaxis.set_ticks_position('bottom')
  ax2.spines['right'].set_color(color2)
  ax2.spines['right'].set_linewidth(2)
  ax2.spines['bottom'].set_linewidth(2)

  ax3.spines['top'].set_visible(False)
  ax3.spines['right'].set_visible(False)
  ax3.yaxis.set_ticks_position('left')
  ax3.xaxis.set_ticks_position('bottom')
  ax3.spines['left'].set_color(color1)
  ax3.spines['left'].set_linewidth(2)
  ax3.spines['bottom'].set_linewidth(2)

  ax4.spines['top'].set_visible(False)
  ax4.spines['left'].set_visible(False)
  ax4.yaxis.set_ticks_position('right')
  ax4.xaxis.set_ticks_position('bottom')
  ax4.spines['right'].set_color(color2)
  ax4.spines['right'].set_linewidth(2)
  ax4.spines['bottom'].set_linewidth(2)

  ax1.set_xlabel('Time (s)', fontsize=18, multialignment='center')
  ax1.set_ylabel('z dF/F ({})'.format(title1),color=color1,fontsize=18,multialignment='center')
  ax4.set_ylabel('z dF/F ({})'.format(title2),color=color2,fontsize=18,multialignment='center')

  tkw = dict(size=4, width=1.5)
  for axis, col in ((ax1, color1), (ax2, color2), (ax3, color1), (ax4, color2)):
    axis.tick_params(axis='y', colors=col, **tkw)
    axis.tick_params(axis='x', **tkw)
    axis.yaxis.label.set_color(col)

  for axis in (ax1, ax2, ax3, ax4):
    axis.xaxis.set_tick_params(labelsize=14)
    axis.yaxis.set_tick_params(labelsize=14)

  plt.tight_layout()
  plt.show()

  return fig

### Plot example trace

In [0]:
def example_2traces(dFF1,dFF2,events,test='AP',
                    title1='GCaMP',title2='jrGECO',
                    color1='green',color2='red',
                    figsize=(7, 3.5),ymin1=-2,ymax1=3,ymin2=-1,ymax2=4,
                    movement=None,color='blue',ymin=-0.1,ymax=1.1):
  '''Plot an example pair of dF/F traces with events, optionally above movement.

  input
    dFF1, dFF2: traces on the left (color1) and right (color2) y-axes
    events: 'AP' -> event times (s) drawn as dashed vertical lines;
            'CT' -> 2D array of [start, end] times (s) drawn as shaded spans
    test: 'AP' or 'CT'; any other value draws no events
    ymin1/ymax1, ymin2/ymax2: y-limits of the two trace axes
    movement: optional trace shown in a separate panel above, limits ymin/ymax

  NOTE(review): the time axis uses the notebook-level global ``Hz``.

  output
    fig: the created matplotlib Figure
  '''
  from matplotlib import gridspec

  # One timestamp per sample; a float-step np.arange(0, len/Hz, 1/Hz) can
  # have one element too many and break plt.plot
  time = np.arange(len(dFF1)) / Hz
  tkw = dict(size=4, width=1.5)

  fig = plt.figure(figsize=figsize)

  if movement is not None and len(movement) != 0:

    gs = gridspec.GridSpec(2, 1, height_ratios=[2, 5])

    ax0 = fig.add_subplot(gs[0])
    ax0.patch.set_visible(False)
    ax0.plot(time, movement, color, linewidth=1.5)

    ax0.spines['top'].set_visible(False)
    ax0.spines['bottom'].set_visible(False)
    ax0.spines['left'].set_visible(False)
    ax0.yaxis.set_ticks_position('right')
    ax0.spines['right'].set_color(color)
    ax0.set_xticks([])
    ax0.set_ylabel('Mobility\n score', fontsize=18, color=color, multialignment='center')
    ax0.tick_params(axis='y', colors=color, **tkw)
    ax0.set_xlim(min(time), max(time))
    ax0.set_ylim(ymin,ymax)
    # set_ticks widens the view to include every tick, so keep only the
    # ticks inside [ymin, ymax] to preserve the requested limits
    ticks = np.arange(int(ymin),int(ymax)+1,0.5)
    ax0.yaxis.set_ticks(ticks[(ticks >= ymin) & (ticks <= ymax)])

    ax = fig.add_subplot(gs[1])

  else:

    ax = fig.add_subplot(111)

  ax.patch.set_visible(False) # hide the 'canvas'
  ax.plot(time, dFF1, color1, linewidth=1.5)
  if test == 'AP':
    # Lines live on the dFF1 axis, so span its own y-range
    ax.vlines(events,ymin1,ymax1,color='black',linestyle='--',linewidth=2)
  elif test == 'CT':
    for i in range(len(events)):
      ax.axvspan(events[i,0],events[i,1],color='r',alpha=0.3)

  ax1 = ax.twinx()
  ax1.plot(time, dFF2,color2,linewidth=1.5)

  ax.set_xlim(min(time), max(time))
  ax.set_ylim(ymin1, ymax1)
  ax1.set_ylim(ymin2, ymax2)

  ax.spines['top'].set_visible(False)
  ax.spines['right'].set_visible(False)
  ax.spines['left'].set_color(color1)
  ax.yaxis.set_ticks_position('left')
  ax.xaxis.set_ticks_position('bottom')
  ax.spines['left'].set_linewidth(2)
  ax.spines['bottom'].set_linewidth(2)
  ax1.spines['top'].set_visible(False)
  ax1.spines['left'].set_visible(False)
  ax1.spines['right'].set_color(color2)
  ax1.spines['right'].set_linewidth(2)

  ax.set_xlabel('Time (s)', fontsize=24, multialignment='center')
  ax.set_ylabel('z dF/F ({})'.format(title1),color=color1,fontsize=18,multialignment='center')
  ax1.set_ylabel('z dF/F ({})'.format(title2),color=color2,fontsize=18,multialignment='center')

  ax.tick_params(axis='x', **tkw)
  ax.tick_params(axis='y', colors=color1, **tkw)
  ax1.tick_params(axis='y', colors=color2, **tkw)

  ax.xaxis.set_tick_params(labelsize=18)
  ax.yaxis.set_tick_params(labelsize=18)
  ax1.yaxis.set_tick_params(labelsize=18)

  plt.tight_layout()
  plt.show()

  return fig

# Correlation analysis

### Cross-correlation

In [0]:
def xcorr(x, y, normed=True, detrend=False, maxlags=10):
  """
  Cross-correlation of two signals of equal length.

  Returns correlation coefficients when normed=True and raw inner products
  when normed=False. The full correlation comes from
  np.correlate(x, y, mode='full') and is trimmed to lags -maxlags..maxlags.
  Usage: lags, c = xcorr(x, y, maxlags=len(x)-1)

  input
    x, y: 1D signals of equal length
    normed: divide by sqrt(dot(x, x) * dot(y, y))
    detrend: subtract each signal's mean before correlating
    maxlags: largest lag to return; None means len(x) - 1

  output
    lags: integer array -maxlags..maxlags
    c: correlation values, one per lag

  raises
    ValueError: if the lengths differ or maxlags is not in [1, len(x) - 1]
  """
  Nx = len(x)
  if Nx != len(y):
    raise ValueError('x and y must be equal length')

  # Validate before the O(N^2) correlation rather than after it
  if maxlags is None:
    maxlags = Nx - 1

  if maxlags >= Nx or maxlags < 1:
    raise ValueError('maxlags must be None or strictly '
                     'positive < %d' % Nx)

  if detrend:
    # Mean removal, equivalent to matplotlib.mlab.detrend_mean for 1D input
    x = np.asarray(x)
    y = np.asarray(y)
    x = x - x.mean()
    y = y - y.mean()

  c = np.correlate(x, y, mode='full')

  if normed:
    n = np.sqrt(np.dot(x, x) * np.dot(y, y))
    c = np.true_divide(c,n)

  lags = np.arange(-maxlags, maxlags + 1)
  # Zero lag sits at index Nx - 1 of the 'full' output
  c = c[Nx - 1 - maxlags:Nx + maxlags]
  return lags, c

### Random sampling

In [0]:
import random

def random_subset( iterator, K ):
    '''Draw a uniform random sample of up to K items from an iterable.

    Uses reservoir sampling (Algorithm R): the first K items fill the
    reservoir; after that, the N-th item (1-based) replaces a random slot
    when a draw from [0, N) lands below K. Fewer than K items are returned
    unchanged and in order.
    '''
    reservoir = []
    for seen, item in enumerate(iterator, start=1):
        if len(reservoir) < K:
            reservoir.append(item)
            continue
        slot = int(random.random() * seen)
        if slot < K:
            reservoir[slot] = item
    return reservoir

# Print done

In [0]:
# Report that the function-definition cells have been executed
print("All FP functions are ready to use")

All FP functions are ready to use
