# Load dependencies and define functions

In [1]:
### dependencies

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import warnings
warnings.filterwarnings('ignore')

### functions

def xy_axis(ax):

    ### Strips the top and right frame lines from a matplotlib axis so that
    ### only the conventional x (bottom) and y (left) axes remain

    ### INPUTS
    ### ax: axis object (e.g. from fig,ax = plt.subplots(1,1))

    ### OUTPUTS
    ### ax: the same axis object, modified in place and returned for chaining

    # Make the right and top spines invisible
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)

    # Keep tick marks only on the two spines that are still drawn
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

    return ax

# Progress message for the notebook cell. Note that it runs as soon as the
# imports and xy_axis are defined, i.e. before plot_group_eye_head_traces
# (defined below) exists.
print('finished loading dependencies')

def _session_mean_traces(clust_df, trace_label, len_trace):
    """Return a (n_sessions, len_trace) array of per-session mean traces.

    Rows of ``clust_df`` are grouped by their 'session' value; the traces
    stored in column ``trace_label`` are averaged (NaN-aware) within each
    session so every session contributes exactly one row. Rows whose
    session is NaN are left out by ``groupby``.
    """
    session_means = []
    for _, traces in clust_df.groupby('session')[trace_label]:
        # Fill row by row so each stored trace is broadcast into a row of
        # length len_trace, exactly as a direct slice assignment would.
        sess_array = np.empty((len(traces), len_trace))
        for t, trace in enumerate(traces):
            sess_array[t, :] = trace
        session_means.append(np.nanmean(sess_array, axis=0))
    if not session_means:
        return np.empty((0, len_trace))
    return np.vstack(session_means)


def plot_group_eye_head_traces(df,sacs,mov_type,eye_head,baselines,baseline_pts,trace_range,trange,save_pdf,pp,recs=None):
    """Plot clustered traces of spiking vs. eye/head movements.

    One figure is produced for every (recording, eye/head variable,
    baseline flag) combination. Each figure holds a grid of subplots with
    one row per movement type and one column per saccade label. Within a
    subplot, one line (mean across sessions) and one shaded band (standard
    error across sessions) is drawn per cluster.

    Column names are built as '<rec>_<sac>_avg_<mov_type>_<eye_head>' for
    the traces and the same name plus '_cluster' for the cluster ids; the
    dataframe must also have a 'session' column.

    Parameters
    ----------
    df : pandas.DataFrame
        Pooled data; one row per unit, with the columns described above.
    sacs : list of str
        Saccade labels, e.g. ['upsacc','downsacc'].
    mov_type : list of str
        Movement types, e.g. ['gaze_shift','comp'].
    eye_head : list of str
        Eye/head movement labels, e.g. ['dEye','dHead'].
    baselines : list of bool
        For each entry a set of figures is made, with (True) or without
        (False) baseline subtraction, e.g. [True,False].
    baseline_pts : array-like of int
        Indices into the mean trace used for baseline subtraction,
        e.g. np.arange(20,36).
    trace_range : sequence of two floats
        x-axis limits of each subplot, e.g. [-0.25,0.5].
    trange : array-like
        Time points of the traces; must have the same length as each
        stored trace.
    save_pdf : bool
        If True each figure is written to ``pp`` and then closed;
        otherwise figures are left open.
    pp : PdfPages
        Open multi-page PDF object; only used when ``save_pdf`` is True.
    recs : list of str, optional
        Recording names, e.g. ['fm1']. When omitted, the module-level
        variable ``recs`` is used, which is how this function obtained the
        recording names before the parameter existed.

    Returns
    -------
    None

    Raises
    ------
    NameError
        If ``recs`` is not given and no module-level ``recs`` exists.
    """
    if recs is None:
        # Backward compatibility: older callers rely on a global `recs`.
        try:
            recs = globals()['recs']
        except KeyError:
            raise NameError(
                "name 'recs' is not defined; pass recs=[...] explicitly"
            ) from None

    n_rows, n_cols = len(mov_type), len(sacs)

    for rec in recs:
        for eh in eye_head:
            for baseline in baselines:
                # squeeze=False keeps axs two-dimensional even when there is
                # only one movement type or one saccade label, so axs[m, s]
                # is always valid.
                fig, axs = plt.subplots(n_rows, n_cols,
                                        figsize=(5*n_cols, 5*n_rows),
                                        squeeze=False)
                for m, m_t in enumerate(mov_type):
                    for s, sac in enumerate(sacs):
                        # Bind the axis before the cluster loop so a panel
                        # without clusters is still formatted correctly.
                        ax = axs[m, s]
                        trace_label = '%s_%s_avg_%s_%s' % (rec,sac,m_t,eh)
                        cluster = trace_label + '_cluster'
                        len_trace = len(df[trace_label].iloc[0])

                        # NaN cluster ids can never match `== clust`, so they
                        # are dropped rather than producing an empty subset.
                        for clust in np.unique(df[cluster].dropna()):
                            clust_df = df[df[cluster] == clust]
                            all_sess = _session_mean_traces(clust_df, trace_label, len_trace)

                            y = np.nanmean(all_sess, axis=0)
                            if baseline:
                                y -= np.nanmean(y[baseline_pts])

                            # nanstd ignores NaN entries, so divide by the
                            # number of sessions that actually contributed at
                            # each time point (equals the session count when
                            # there are no NaNs).
                            n_valid = np.sum(~np.isnan(all_sess), axis=0)
                            err = np.nanstd(all_sess, axis=0) / np.sqrt(n_valid)

                            ax.plot(trange, y, '-', label=str(clust))
                            ax.fill_between(trange, y-err, y+err, alpha=0.2)

                        ax.set_xlabel('time from movement (s)')
                        ax.set_title(trace_label)
                        ax.set_xlim(trace_range[0], trace_range[1])
                        if baseline:
                            ax.set_ylim(-15,15)
                            ax.set_ylabel('baselined sp/s')
                        else:
                            ax.set_ylim(0,40)
                            ax.set_ylabel('sp/s')
                        ax = xy_axis(ax)
                        ax.legend(fontsize=8)
                fig.tight_layout()
                if save_pdf:
                    pp.savefig(fig)
                    plt.close(fig)

    print('finished adding eye/head plots to PDF!')

finished loading dependencies


# Load the data

In [None]:
# NOTE(review): pd.read_pickle can execute arbitrary code while unpickling;
# only point group_file at pickles produced by a trusted pipeline.
group_file = r'\\Goeppert\nlab-nas\freely_moving_ephys\batch_files\092221\pooled_ephys_population_update_092221.pickle'
df = pd.read_pickle(group_file)

# remove inserted column if it exists (errors='ignore' makes this a no-op
# when the column is absent, without hiding unrelated failures)
df = df.drop(columns=['level_0'], errors='ignore')

# Set up the PDF

In [None]:
# Destination of the multi-page PDF that collects the figures
pdf_file = r'T:/freely_moving_ephys/group_analysis.pdf'
# Open the PDF for writing; pages are added with pp.savefig(...) and the
# file must be finalized with pp.close() (done in the last cell)
pp = PdfPages(pdf_file)

# Plot firing rate around eye/head movements for all sessions

In [None]:
# NOTE: recs is not passed in the call below; plot_group_eye_head_traces
# reads it as a global, so it must be defined before the call.
recs = ['fm1'] # list of recording names
sacs = ['upsacc','downsacc'] # list of saccade names
mov_type = ['gaze_shift','comp'] # list of movement types
eye_head = ['dEye','dHead'] # list of eye/head variables
trace_range = [-0.25,0.5] # x-axis limits (time range) the traces are plotted over
# NOTE(review): np.arange with a non-integer step can yield one element more
# or fewer than expected because of float rounding; the length of trange
# must match the length of the stored traces -- confirm.
trange = np.arange(-1,1.1-0.025,0.025) # time points for traces
baselines=[True,False] # one set of figures per entry: True = baselined, False = raw
baseline_pts = np.arange(20,36) # indices into the trace used for baselining
save_pdf=True # write each figure to the PDF (and close it) if True

plot_group_eye_head_traces(df,sacs,mov_type,eye_head,baselines,baseline_pts,trace_range,trange,save_pdf,pp)

# Close the PDF

In [None]:
# Finalize and close the PDF opened in the "set up PDF" cell
pp.close()