In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [4]:
# Row filters (pandas `query` syntax) describing epochs to drop.
# Each rule is wrapped in `not (...)` and the results are AND-ed together,
# so a row is kept only if it matches none of the rules.
exclusions = (
    'state_name=="Artifact"',
    'state_name=="Transition"',
    'state_name=="REM"',
    'animal=="RW9" and (state_name=="Active" or tetrode==6)', # allow Active in trace plots (NOTE(review): this rule drops RW9 Active epochs — confirm intent)
    'animal=="RW7" and day==5',
    'animal=="RZ7" and day==3',
    'animal=="RZ9" and day==5 and tetrode==8',
)
qStr = ' and '.join('not ({})'.format(rule) for rule in exclusions)

In [3]:
# Analysis settings, presumably passed to the plotting helpers defined below.
method, measure = 'CWT', 'mean'  # method: CWT or STFT; measure: mean, median, total or norm_mean
group = 'genotype'               # genotype, gender or animal
group_order = ('WT', 'Df1')      # display order of the `group` levels

In [None]:
def logMean(x):
    """Return the arithmetic mean of ``x`` converted to decibels.

    Computes ``10 * log10(mean(x))``: values are averaged on the linear
    scale first and only the average is converted to dB. ``spec_plot``
    passes this as the seaborn ``estimator``.
    """
    return 10 * np.log10(np.mean(x))
def spec_plot(data,method,group,hue='state_name',asp=1.33,group_order=None,nBoot=1000,
              hue_order=None):
    """Plot mean power spectra (in dB), one facet column per level of ``group``.

    Parameters
    ----------
    data : pandas.DataFrame
        Long-form table with columns ``group``, ``hue``, ``'frequency'`` and
        ``method + '_Power'``.
    method : str
        Prefix of the power column, e.g. 'CWT' or 'STFT'.
    group : str
        Column whose levels become the facet columns.
    hue : str, default 'state_name'
        Column mapped to line colour.
    asp : float
        Aspect ratio of each facet.
    group_order : sequence, optional
        Order (and subset) of the facet columns.
    nBoot : int
        Bootstrap iterations for the confidence band.
    hue_order : sequence, optional
        Order of the hue levels. When omitted and ``hue`` is 'state_name',
        ('NREM', 'Rest', 'Active') is used.

    Returns
    -------
    seaborn.FacetGrid
    """
    yCol = method+'_Power'
    # The fixed state order only makes sense for the default state hue.
    if hue_order is None and hue == 'state_name':
        hue_order = ('NREM','Rest','Active')
    # Many groups -> wrap four facets per row and raise the legend.
    if len(data[group].unique())>4:
        legend_y, cWrap = 1.2, 4
    else:
        legend_y, cWrap = 1.05, 2
    # NOTE(review): col_palette is a notebook global not defined in this view;
    # confirm the cell defining it runs before this function is called.
    g = sns.FacetGrid(data=data,col=group,hue=hue,
                      margin_titles=True,hue_order=hue_order,
                      palette=col_palette,col_order=group_order,height=6,aspect=asp,col_wrap=cWrap)
    # logMean averages linear power and then converts the average to dB.
    g = g.map(sns.lineplot,'frequency',yCol,sort=True,n_boot=nBoot,estimator=logMean)
    g.fig.subplots_adjust(top=0.8)
    g.fig.suptitle('Mean '+method+' Spectra')
    # Blank existing axes texts before set_titles (presumably to avoid duplicated margin titles).
    for ax in g.axes.flat:
        plt.setp(ax.texts, text="")
    g.set_titles(row_template='{row_name}', col_template='{col_name}')
    g.set_axis_labels('',method+' Power (dB)')
    g.fig.text(0.5,0,'Frequency (Hz)',ha='center',fontsize=24)
    # plt.legend attaches to the current axes (the last facet drawn).
    plt.legend(loc='upper right', bbox_to_anchor=(1.1, legend_y),
               ncol=1, fancybox=True, shadow=True,fontsize=18,
               title='State' if hue == 'state_name' else hue)
    g.despine()
    return g

In [None]:
def trace_plot(data,method,group,group_order=None,nBoot=1000,col_order=None):
    """Plot ``method + '_delta_power'`` against ``epoch_time``.

    The grid has one facet row per level of ``group`` and one facet column
    (and colour) per ``epoch_type``.

    Parameters
    ----------
    data : pandas.DataFrame
        Long-form table with columns ``group``, ``'epoch_type'``,
        ``'epoch_time'`` and ``method + '_delta_power'``.
    method : str
        Prefix of the power column, e.g. 'CWT' or 'STFT'.
    group : str
        Column whose levels become the facet rows.
    group_order : sequence, optional
        Order (and subset) of the facet rows.
    nBoot : int
        Bootstrap iterations for the confidence band.
    col_order : sequence, optional
        Order of the epoch-type columns. Defaults to
        ['Baseline', 'Saline', 'Ketamine'].

    Returns
    -------
    seaborn.FacetGrid
    """
    if col_order is None:
        col_order = ['Baseline','Saline','Ketamine']
    yCol = method+'_delta_power'
    # NOTE(review): col_palette is a notebook global not defined in this view;
    # confirm the cell defining it runs before this function is called.
    # sharex=False: presumably each epoch type spans its own time range.
    g = sns.FacetGrid(data,hue='epoch_type',row=group,col='epoch_type',
                      margin_titles=True,sharex=False,height=5,aspect=2,
                      row_order=group_order,palette=col_palette,col_order=col_order)
    g = g.map(sns.lineplot,'epoch_time',yCol,n_boot=nBoot)
    g.fig.subplots_adjust(top=0.9)
    g.fig.suptitle(method+' Delta Power')
    # Blank existing axes texts before set_titles (presumably to avoid duplicated margin titles).
    for ax in g.axes.flat:
        plt.setp(ax.texts, text="")
    g.set_titles(row_template='{row_name}', col_template='{col_name}')
    g.set_axis_labels('','')
    g.fig.text(0,.5,'Normalized '+method+' Power (dB)',rotation=90,va='center',fontsize=24)
    # The x axis is epoch_time, not frequency. TODO: confirm units of epoch_time and add them.
    g.fig.text(0.5,0,'Time',ha='center',fontsize=24)
    g.despine()
    return g

In [None]:
def bar_plot(data,method,measure,group,hue='state_name',col=None,
             group_order=None,nBoot=1000,col_order=None,row_order=None,
             hue_order=None,row='power_band'):
    """Bar plot of ``method + '_' + measure`` with ``group`` on the x axis.

    Parameters
    ----------
    data : pandas.DataFrame
        Long-form table containing ``group``, ``method + '_' + measure`` and
        any of the ``hue``/``col``/``row`` columns that are used.
    method : str
        Prefix of the value column, e.g. 'CWT' or 'STFT'.
    measure : str
        Suffix of the value column, e.g. 'mean', 'median', 'total'.
    group : str
        Column plotted on the x axis.
    hue, col, row : str or None
        Columns mapped to bar colour, facet columns and facet rows.
    group_order, hue_order, col_order, row_order : sequence, optional
        Level orders for the corresponding variables.
    nBoot : int
        Bootstrap iterations for the error bars.

    Returns
    -------
    seaborn.FacetGrid
    """
    yCol = method+'_'+measure
    # Many groups -> wider wrap and a lower legend.
    if len(data[group].unique())>4:
        legend_y, cWrap = 1.2, 4
    else:
        legend_y, cWrap = 1.3, 2
    # col_wrap only applies to a `col` facet, and seaborn rejects it together
    # with `row`. Passing it with col=None would leave the grid without axes.
    if row is not None or col is None:
        cWrap = None
    g = sns.catplot(kind='bar',data=data,x=group,y=yCol,col=col,hue=hue,
                    margin_titles=True,row=row,n_boot=nBoot,order=group_order,
                    row_order=row_order,col_order=col_order,height=5,aspect=2,col_wrap=cWrap,
                    hue_order=hue_order,legend=False)
    g.fig.subplots_adjust(top=0.85)
    g.fig.suptitle(method+' Band Power')
    # Blank existing axes texts before set_titles (presumably to avoid duplicated margin titles).
    for ax in g.axes.flat:
        plt.setp(ax.texts, text="")
    g.set_titles(row_template='{row_name}', col_template='{col_name}')
    # Without a hue variable there are no labelled bars to put in a legend.
    if hue is not None:
        g.axes.flat[0].legend(loc='best', bbox_to_anchor=(1.05, legend_y),
                              ncol=1, fancybox=True, shadow=True,fontsize=18,
                              title='State' if hue == 'state_name' else hue)
    g.set_axis_labels('','')
    g.fig.text(0.01,0.5,method+' '+measure+' Power',rotation=90,va='center',fontsize=24)
    # The x axis shows the levels of `group` (e.g. genotype), not frequency.
    g.fig.text(0.5,0.03,group.replace('_',' ').title(),ha='center',fontsize=24)
    g.despine()
    return g
                    