# General Plotting
## Primary Results
### Time Domain

In [None]:
import os
import numpy as np
import pylab as plt
from pandas import read_csv

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## File parameters.
labels = ['dacc-lh', 'dacc-rh', 'dmpfc-lh', 'dmpfc-rh', 'dlpfc_1-lh', 'dlpfc_1-rh', 'dlpfc_2-lh', 'dlpfc_2-rh', 
          'dlpfc_3-lh', 'dlpfc_3-rh', 'dlpfc_4-lh', 'dlpfc_4-rh', 'dlpfc_5-lh', 'dlpfc_5-rh', 
          'dlpfc_6-lh', 'dlpfc_6-rh', 'pcc-lh', 'pcc-rh', 'racc-lh', 'racc-rh']
# labels = ['FCZ']

space = 'source'
model_name = 'revised'

## Plot parameters.
## One entry per figure column: (column of `info` that is also the value looked
## up in clusters.Contrast, line colors for levels 0/1, legend entries for levels 0/1).
contrasts = [('DBS',          ['#0571b0','#ca0020'], ['DBSoff','DBSon']),
             ('Interference', ['#7b3294','#008837'], ['Neu','Int'])]

## Y-axis label for each supported data space.
ylabels = dict(sensor='Voltage (uV)', source='dSPM')

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
## For every label, draw a 2x2 figure (rows: stimulus-/response-locked,
## columns: DBS / Interference contrast) showing mean +/- standard error per
## condition, shade the FDR-significant clusters, and save it as a PNG.
root_dir = '/space/sophia/2/users/EMOTE-DBS/afMSIT/%s' %space
img_dir = '/space/sophia/2/users/EMOTE-DBS/afMSIT/plots/%s' %space
info = read_csv(os.path.join(root_dir, 'afMSIT_%s_info.csv' %space))

## Fail immediately (KeyError) on an unsupported space instead of hitting an
## undefined `ylabel` only after the first figure has been drawn.
ylabel = ylabels[space]

for label in labels:

    for freq in [15]:

        ## Initialize figure.
        fig = plt.figure(figsize=(16,8))
        axes = []
        
        ## Initialize labels.
        title = '%s ERP' %(label.replace('_',' ').upper())
        out_path = os.path.join(img_dir, 'timedomain', '%s_%s.png' %(model_name,title.replace(' ','_').lower()))
            
        for n, analysis in enumerate(['stim', 'resp']):
 
            #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
            ### Load data.
            #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

            ## Load source data.
            npz = np.load(os.path.join(root_dir, 'afMSIT_%s_%s_%s_%s.npz' %(space,analysis,label,freq)))
            data = npz['data']
            times = npz['times']
            
            ## Load cluster results.
            f = os.path.join(root_dir, 'results', '%s_%s_timedomain_results.csv' %(model_name, analysis))
            clusters = read_csv(f)
            
            #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
            ### Plot DBS (column 0) and Interference (column 1) differences.
            #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
            
            for col, (contrast, colors, legends) in enumerate(contrasts):
            
                ax = plt.subplot2grid((2,2),(n,col))
                for m, color, legend in zip([0,1], colors, legends):
                    
                    ## Mean and standard error across the trials of this condition.
                    ix, = np.where(info[contrast]==m)
                    mu = data[ix].mean(axis=0)
                    se = data[ix].std(axis=0) / np.sqrt(len(ix))
                    ax.plot(times, mu, linewidth=2, color=color, label=legend)
                    ax.fill_between(times, mu-se, mu+se, color=color, alpha=0.2)
            
                ## Plot significant clusters.
                significant = ((clusters.Label==label)&(clusters.Freq==freq)&
                               (clusters.Contrast==contrast)&(clusters.FDR<0.05))
                for ix in np.where(significant)[0]:
                    y1, y2 = ax.get_ylim()
                    tmin, tmax = clusters.loc[ix,'Tmin'], clusters.loc[ix,'Tmax']
                    ## Sample count must be an int: a float such as 1e3 is rejected by NumPy.
                    ax.fill_between(np.linspace(tmin,tmax,1000), y1, y2, color='k', alpha=0.2)
                    ## Restore the limits so the shading does not rescale the axis.
                    ax.set_ylim(y1,y2)
                    
                axes.append(ax)
            
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        ### Add Flourishes
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

        for i, ax in enumerate(axes):

            ## Axes were appended row by row, so divmod recovers (row, column).
            ## Integer arithmetic keeps this correct under both Python 2 and 3
            ## (a plain i/2 is 0.5 for i=1 under Python 3).
            row, col = divmod(i, 2)

            ax.legend(loc=2, fontsize=18, markerscale=2, frameon=False, borderpad=0, handletextpad=0.2)
            ax.set_xlabel('Time (s)', fontsize=18)
            if col == 0: ax.set_ylabel(ylabel, fontsize=20)
            ax.tick_params(axis='both', which='major', labelsize=12)

            ## Time-lock specific.
            if row == 0:
                ax.set_title('Stimulus-Locked', fontsize=24)
                markers = [(0, 'IAPS'), (0.4, 'MSIT'), (1.127, 'Resp')]
            else:
                ax.set_title('Response-Locked', fontsize=24)
                markers = [(0.0, 'Resp')]

            ## Dashed vertical line plus a text tag at each marker time.
            y1, y2 = ax.get_ylim()
            for x, s in markers:
                ax.text(x+0.02,y1+np.abs(y1*0.05),s,fontsize=16)
                ax.vlines(x,y1,y2,linestyle='--',alpha=0.3)
            ax.set_ylim(y1,y2)

        plt.suptitle(title, y=0.99, fontsize=28)
        plt.subplots_adjust(left=0.05, right=0.975, hspace=0.35, wspace=0.15)
        #plt.show()
        plt.savefig(out_path)
        plt.close()

## Parenthesised form prints the same text under Python 2 and Python 3.
print('Done.')

### Power Domain

In [None]:
import os
import numpy as np
import pylab as plt
from pandas import read_csv

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## File parameters.
labels = ['dacc-lh', 'dacc-rh', 'dmpfc-lh', 'dmpfc-rh', 'dlpfc_1-lh', 'dlpfc_1-rh', 'dlpfc_2-lh', 'dlpfc_2-rh', 
          'dlpfc_3-lh', 'dlpfc_3-rh', 'dlpfc_4-lh', 'dlpfc_4-rh', 'dlpfc_5-lh', 'dlpfc_5-rh', 
          'dlpfc_6-lh', 'dlpfc_6-rh', 'pcc-lh', 'pcc-rh', 'racc-lh', 'racc-rh']
# labels = ['FCZ']

space = 'source'
model_name = 'revised'
baseline = (-0.5, -0.1)

## Plot parameters.
## One entry per figure column: (column of `info` that is also the value looked
## up in clusters.Contrast, line colors for levels 0/1, legend entries for levels 0/1).
contrasts = [('DBS',          ['#0571b0','#ca0020'], ['DBSoff','DBSon']),
             ('Interference', ['#7b3294','#008837'], ['Neu','Int'])]

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
## For every label and frequency band, draw a 2x2 figure (rows: stimulus-/
## response-locked, columns: DBS / Interference contrast) showing mean +/-
## standard error per condition, shade the FDR-significant clusters, and save
## it as a PNG.
root_dir = '/space/sophia/2/users/EMOTE-DBS/afMSIT/%s' %space
img_dir = '/space/sophia/2/users/EMOTE-DBS/afMSIT/plots/%s' %space
info = read_csv(os.path.join(root_dir, 'afMSIT_%s_info.csv' %space))

for label in labels:

    for freq in ['theta','alpha','beta']:

        ## Initialize figure.
        fig = plt.figure(figsize=(16,8))
        axes = []
        
        ## Initialize labels.
        ylabel = 'Power (dB)'
        title = '%s %s Power' %(label.replace('_',' ').upper(), freq.capitalize())
        out_path = os.path.join(img_dir, 'frequency', '%s_%s.png' %(model_name,title.replace(' ','_').lower()))
            
        for n, analysis in enumerate(['stim', 'resp']):
 
            #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
            ### Load data.
            #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

            ## Load source data.
            npz = np.load(os.path.join(root_dir, 'afMSIT_%s_%s_%s_%s.npz' %(space,analysis,label,freq)))
            data = npz['data']
            times = npz['times']
            
            ## Load cluster results.
            f = os.path.join(root_dir, 'results', '%s_%s_frequency_results.csv' %(model_name, analysis))
            clusters = read_csv(f)
            
            #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
            ### Plot DBS (column 0) and Interference (column 1) differences.
            #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
            
            for col, (contrast, colors, legends) in enumerate(contrasts):
            
                ax = plt.subplot2grid((2,2),(n,col))
                for m, color, legend in zip([0,1], colors, legends):
                    
                    ## Mean and standard error across the trials of this condition.
                    ix, = np.where(info[contrast]==m)
                    mu = data[ix].mean(axis=0)
                    ## Only the stimulus-locked mean is re-referenced to its own
                    ## average over the baseline window.
                    if analysis == 'stim':
                        in_baseline = (times >= baseline[0])&(times <= baseline[1])
                        mu -= mu[in_baseline].mean()
                    se = data[ix].std(axis=0) / np.sqrt(len(ix))
                    ax.plot(times, mu, linewidth=2, color=color, label=legend)
                    ax.fill_between(times, mu-se, mu+se, color=color, alpha=0.2)
            
                ## Plot significant clusters.
                significant = ((clusters.Label==label)&(clusters.Freq==freq)&
                               (clusters.Contrast==contrast)&(clusters.FDR<0.05))
                for ix in np.where(significant)[0]:
                    y1, y2 = ax.get_ylim()
                    tmin, tmax = clusters.loc[ix,'Tmin'], clusters.loc[ix,'Tmax']
                    ## Sample count must be an int: a float such as 1e3 is rejected by NumPy.
                    ax.fill_between(np.linspace(tmin,tmax,1000), y1, y2, color='k', alpha=0.2)
                    ## Restore the limits so the shading does not rescale the axis.
                    ax.set_ylim(y1,y2)
                    
                axes.append(ax)
            
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        ### Add Flourishes
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

        for i, ax in enumerate(axes):

            ## Axes were appended row by row, so divmod recovers (row, column).
            ## Integer arithmetic keeps this correct under both Python 2 and 3
            ## (a plain i/2 is 0.5 for i=1 under Python 3).
            row, col = divmod(i, 2)

            ax.legend(loc=2, fontsize=18, markerscale=2, frameon=False, borderpad=0, handletextpad=0.2)
            ax.set_xlabel('Time (s)', fontsize=18)
            if col == 0: ax.set_ylabel(ylabel, fontsize=20)
            ax.tick_params(axis='both', which='major', labelsize=12)

            ## Time-lock specific.
            if row == 0:
                ax.set_title('Stimulus-Locked', fontsize=24)
                markers = [(0, 'IAPS'), (0.4, 'MSIT'), (1.127, 'Resp')]
            else:
                ax.set_title('Response-Locked', fontsize=24)
                markers = [(0.0, 'Resp')]

            ## Dashed vertical line plus a text tag at each marker time.
            y1, y2 = ax.get_ylim()
            for x, s in markers:
                ax.text(x+0.02,y1+np.abs(y1*0.05),s,fontsize=16)
                ax.vlines(x,y1,y2,linestyle='--',alpha=0.3)
            ax.set_ylim(y1,y2)

        plt.suptitle(title, y=0.99, fontsize=28)
        plt.subplots_adjust(left=0.05, right=0.975, hspace=0.35, wspace=0.15)
        #plt.show()
        plt.savefig(out_path)
        plt.close()

## Parenthesised form prints the same text under Python 2 and Python 3.
print('Done.')

## Secondary Results
### Intertrial Interval Spectra Comparison


In [None]:
import os
import numpy as np
import pylab as plt
from cmap_utils import *

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Load and extract data.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Load the intertrial-interval spectra for the left dlpfc_5 label.
root_dir = '/space/sophia/2/users/EMOTE-DBS/afMSIT/source'
npz = np.load(os.path.join(root_dir, 'afMSIT_source_iti_dlpfc_5-lh_spectra.npz'))

## Pull out the arrays; trials whose condition code exceeds 2 are flagged as 1 in `dbs`.
power, times, freqs = npz['power'], npz['times'], npz['freqs']
dbs = (npz['conds'] > 2).astype(int)

## Median power across the trials of each group, converted to dB, and their difference.
dbs_off = 10 * np.log10(np.median(power[dbs==0], axis=0))
dbs_on = 10 * np.log10(np.median(power[dbs==1], axis=0))
contrast = dbs_on - dbs_off

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Plot spectrograms.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

fig, axes = plt.subplots(1, 3, figsize=(15,5), sharex=True)

## The two condition panels share one color range, rounded outward to whole dB.
both = [dbs_off, dbs_on]
vmin = np.floor(np.min(both))
vmax = np.ceil(np.max(both))

## (image, color limits, colorbar ticks) for DBS-off, DBS-on and the contrast.
panels = [(dbs_off,  (vmin, vmax), np.arange(vmin, vmax+0.01, 0.5)),
          (dbs_on,   (vmin, vmax), np.arange(vmin, vmax+0.01, 0.5)),
          (contrast, (-1, 1),      np.arange(-1,1.1,0.5))]

for ax, (image, (lo, hi), ticks) in zip(axes, panels):
    cmap = center_color_map(image, 'jet')
    cbar = ax.imshow(image, aspect='auto', origin='lower', cmap=cmap, vmin=lo, vmax=hi)
    plt.colorbar(cbar, ax=ax, ticks=ticks)

## Only the left-most panel carries the y-axis label.
axes[0].set_ylabel('Frequency', fontsize=18)

## Add flourishes: relabel pixel indices with the times / frequencies they stand for.
freq_ticks = [0,5,10,15,18,20,22,24]
for ax, title in zip(axes, ['DBSoff', 'DBSon', 'On-Off']):

    ax.set_xticks(np.linspace(0, times.shape[0], 6))
    ax.set_xticklabels(np.linspace(times.min(), times.max(), 6).round(2))
    ax.set_xlabel('Time (s)', fontsize=16)

    ax.set_yticks(freq_ticks)
    ax.set_yticklabels(freqs[freq_ticks].astype(int))
    ax.set_title(title, fontsize=20)

plt.subplots_adjust(left=0.05, right=0.98, top=0.8)
plt.suptitle('Intertrial Power Spectra', fontsize=28)
plt.show()

### Intertrial Interval Power Spectrum Density

In [None]:
import os
import numpy as np
import pylab as plt
from mne import EpochsArray, create_info
from mne.time_frequency import psd_multitaper

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Data parameters.
subjects = ['BRTU','CHDR', 'CRDA', 'JADE', 'JASE', 'M5', 'MEWA', 'S2']
method = 'dSPM'
h_freq = 50
roi = 'dlpfc_5-lh'
sfreq = 1450

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Setup.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
root_dir = '/space/sophia/2/users/EMOTE-DBS/afMSIT/source'

## Gather the per-subject label time courses and condition codes.
epochs, conds = [], []
for subject in subjects:

    ## Load NPZ.
    fname = '%s_msit_iti_%s_%s_epochs.npz' %(subject,method,h_freq)
    npz = np.load(os.path.join(root_dir, 'stcs', fname))
    epochs.append(npz['ltcs'])
    conds.append(npz['conds'])

## Stack all subjects along the trial axis.
epochs = np.concatenate(epochs, axis=0)
conds = np.concatenate(conds, axis=0)
## The time vector is read from the last subject's file.
times = npz['times']

## Dimensions of the stacked array.
n_trials, n_channels, n_times = epochs.shape

## Measurement info describing a single 'eeg'-typed channel named after the ROI.
info = create_info([roi], sfreq, ['eeg'])

## Events array: first column is a running trial index, last column the condition code.
events = np.zeros((n_trials,3), dtype=int)
events[:, 0] = np.arange(n_trials)
events[:, -1] = conds
event_id = {'FN': 1, 'FI': 2, 'NN': 3, 'NI': 4}

## Wrap everything in an MNE epochs object.
epochs = EpochsArray(epochs, info, events, tmin=-1.25, event_id=event_id, verbose=False)

## Multitaper PSD up to 50 Hz; F* conditions are pooled as DBS-off, N* conditions as DBS-on.
psd_kwargs = dict(fmin=0, fmax=50, n_jobs=3, verbose=False)
psd_dbsoff, freqs = psd_multitaper(epochs[['FI','FN']], **psd_kwargs)
psd_dbson,  freqs = psd_multitaper(epochs[['NI','NN']], **psd_kwargs)

In [None]:
## Plot the PSD of each DBS condition as mean +/- standard error across epochs,
## after scaling each condition by the summed mean power at or below 30 Hz.
fig, ax = plt.subplots(1, 1, figsize=(8,8))

conditions = zip([psd_dbsoff,psd_dbson], ['DBSoff','DBSon'], ['#0571b0','#ca0020'])
for arr, label, color in conditions:

    ## Normalizing constant: mean spectrum summed over the bins with freqs <= 30.
    mask = (freqs <= 30)
    auc = arr.squeeze().mean(axis=0)[mask].sum()

    ## Rescale every epoch, then summarize across epochs.
    arr2 = arr / auc
    scaled = arr2.squeeze()
    mu = scaled.mean(axis=0)
    sd = scaled.std(axis=0)
    se = sd / np.sqrt(arr2.shape[0])

    ax.plot(freqs, mu, linewidth=3, label=label, color=color)
    ax.fill_between(freqs, mu-se, mu+se, color=color, alpha=0.2)

ax.legend(loc=1)

plt.tight_layout()
plt.show()

In [None]:
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Plot spectrograms.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
## Three side-by-side spectrograms: DBS-off, DBS-on and their difference.
## NOTE(review): this cell defines none of dbs_off, dbs_on, contrast, times,
## freqs or center_color_map; it relies on whatever an earlier cell left in the
## notebook namespace. `times` and `freqs` are also assigned by the PSD setup
## cell, so the tick labels below depend on cell execution order -- confirm.

fig, axes = plt.subplots(1,3,figsize=(15,5),sharex=True)
## Shared color range for the two condition panels, rounded outward to integers.
vmin, vmax = np.floor( np.min([dbs_off, dbs_on]) ), np.ceil( np.max([dbs_off, dbs_on]) )


## Plot DBS-off.
cmap = center_color_map(dbs_off, 'jet')
cbar = axes[0].imshow(dbs_off , aspect='auto', origin='lower', cmap=cmap, vmin=vmin, vmax=vmax)
plt.colorbar(cbar, ax=axes[0],  ticks=np.arange(vmin, vmax+0.01, 0.5))
axes[0].set_ylabel('Frequency', fontsize=18)

## Plot DBS-on.
cmap = center_color_map(dbs_on, 'jet')
cbar = axes[1].imshow(dbs_on , aspect='auto', origin='lower', cmap=cmap, vmin=vmin, vmax=vmax)
plt.colorbar(cbar, ax=axes[1], ticks=np.arange(vmin, vmax+0.01, 0.5))

## Plot contrast (fixed color range of -1 to 1).
cmap = center_color_map(contrast, 'jet')
cbar = axes[2].imshow(contrast , aspect='auto', origin='lower', cmap=cmap, vmin=-1, vmax=1)
plt.colorbar(cbar, ax=axes[2], ticks=np.arange(-1,1.1,0.5))

## Add flourishes.
for ax, title in zip(axes, ['DBSoff', 'DBSon', 'On-Off']):
    
    ## Six evenly spaced x ticks, labelled with values spanning times.min()..times.max().
    ax.set_xticks(np.linspace(0, times.shape[0], 6))
    ax.set_xticklabels(np.linspace(times.min(), times.max(), 6).round(2))
    ax.set_xlabel('Time (s)', fontsize=16)
    
    ## Y ticks sit at row indices and are labelled with the matching entries of freqs.
    ax.set_yticks([0,5,10,15,18,20,22,24])
    ax.set_yticklabels(freqs[[0,5,10,15,18,20,22,24]].astype(int))
    ax.set_title(title, fontsize=20)
    
plt.subplots_adjust(left=0.05, right=0.98, top=0.8)
plt.suptitle('Intertrial Power Spectra', fontsize=28)
plt.show()

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Plot power spectrum densities.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

fig, ax = plt.subplots(1,1,figsize=(8,8))

## NOTE(review): only `freqs` is passed, so this draws the frequency values
## against their index rather than a spectrum. Looks like an unfinished stub;
## the y-data argument appears to be missing -- TODO confirm intent.
ax.plot(freqs, )

# Manuscript Figures

## Methods: AIC Plot

In [None]:
import rpy2
import numpy as np
import pylab as plt
import seaborn as sns
from pandas import DataFrame, read_csv
%load_ext rpy2.ipython
%R require(lme4)

## Load the behavioral data and keep only the rows whose response time is
## truthy (i.e. drops rows where responseTimes is zero), renumbering the index.
behav = read_csv('EMOTE_behav_data.csv')
has_rt = np.where(behav.responseTimes, True, False)
df = behav[has_rt].reset_index(drop=True)

In [None]:
%%R -i df -o AICs,BICs

# Fit a sequence of nested Gamma mixed models of response time and record the
# AIC and BIC of each one. AICs and BICs are the names handed back to Python
# by the -o option of this cell's %%R magic line.

# Start from the random-intercept-only model (one intercept per subject).
formula = 'responseTimes ~ (1|subject)'
# Fixed effects, in the order in which they are added to the model.
variables = c('interference','DBS','valence','arousal','TrialNum')
AICs = c()
BICs = c()

for (variable in variables){
    # The formula string is cumulative: each pass appends one more term, so the
    # k-th model contains the first k variables.
    formula = paste(formula, variable, sep=' + ')
    # Gamma-family GLMM with an inverse link.
    model = glmer(formula, data=df, family=Gamma(link='inverse'))
    # Append this model's fit criteria (same order as `variables`).
    AICs = c(AICs, AIC(model))
    BICs = c(BICs, BIC(model))
}

In [None]:
## Build a long-format dataframe with one row per (model, metric) pair.
## Each model is named after the last term that was added to it; the order must
## match the order in which the models were fitted.
model_names = ['Con','DBS','Val','Aro','Trial']
variables = np.tile(model_names, 2)
## Repeat counts follow the number of models rather than a hard-coded 5.
metrics = np.repeat(['AIC','BIC'], len(model_names))
fits = DataFrame(dict(Fit = np.concatenate([AICs,BICs]),
                      Model = variables,
                      Metric = metrics))
## Sign-preserving log transform: compresses the scale while keeping the sign
## of the fit values.
fits['Fit'] = np.sign(fits['Fit']) * np.log(np.abs(fits['Fit']))

## Plotting
## NOTE(review): sns.set() runs after the axes are created, so the "white"
## style only affects this axes if an earlier cell already applied it -- confirm
## whether it should precede plt.subplots().
fig, ax = plt.subplots(1,1,figsize=(6,6))
sns.set(style="white", font_scale=1.75)
## Axes-level pointplot draws straight onto `ax`. The figure-level factorplot
## used previously opens its own figure (which then had to be closed by hand)
## and newer seaborn releases ignore an `ax` argument passed to it.
sns.pointplot(x='Model', y='Fit', hue='Metric', data=fits,
              palette='colorblind', ax=ax)

## Flourishes
ax.legend_.set_title(None)
ax.set_xticklabels(fits.Model.unique(), ha='center', weight='bold')
ax.set_xlabel('')
ax.set_yticks(np.linspace(-8.52, -8.48, 2))
ax.set_ylabel('', labelpad=-1.0)
## The y-axis caption is placed by hand in data coordinates.
ax.text(-1.0, -8.5, 'Model Deviance (Log Scale)', rotation=90, va='center')

plt.tight_layout()
plt.show()

## Results: 4-panel Behavior Plot

In [None]:
import os
import numpy as np
import pylab as plt
import seaborn as sns
from pandas import read_csv
## Apply the seaborn "white" style with a 1.75x font scale; this runs before the
## figure of this cell is created.
sns.set(style="white", font_scale=1.75)

def standard_error(x,y):
    """Return the standard error of `y` within each distinct value of `x`.

    For every unique value in `x` (in the sorted order produced by
    np.unique), the standard deviation of the matching entries of `y` is
    divided by the square root of how many entries match.

    Parameters
    ----------
    x : array
        Group identifier for each observation.
    y : array
        Observations; indexed with the boolean mask `x == value`.

    Returns
    -------
    numpy.ndarray
        One standard error per unique value of `x`.
    """
    levels, sizes = np.unique(x, return_counts=True)
    errors = []
    for level, size in zip(levels, sizes):
        spread = np.std(y[x == level])
        errors.append(spread / np.sqrt(size))
    return np.array(errors)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Initialize.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
## 2x2 figure of response times: bar plots for DBS and interference on the top
## row, point plots across binned valence and arousal on the bottom row.

fig, axes = plt.subplots(2,2,figsize=(12,12))
df = read_csv('EMOTE_behav_data.csv')

## DBS effect.
df['DBS'] = np.where(df.DBS,'ON','OFF')
## Sorting is kept so the row order seen by the later panels is unchanged.
df = df.sort_values('DBS')
## x/y are passed by keyword (newer seaborn rejects them positionally) and the
## bar order is pinned so each color always lands on the same condition.
## (Grayscale variant: palette='Greys'.)
sns.barplot(x='DBS', y='origResponseTimes', order=['OFF','ON'],
            palette=sns.color_palette(['#0571b0','#ca0020']), 
            data=df, ci=None, ax=axes[0,0])

## Interference effect.
df['interference'] = np.where(df.interference, 'Interference', 'Control')
## Without an explicit order the bar order would follow whichever label
## appears first in the DBS-sorted rows.
sns.barplot(x='interference', y='origResponseTimes', order=['Control','Interference'],
            palette=sns.color_palette(['#7b3294','#008837']), 
            data=df, ci=None, ax=axes[0,1])

## Valence and arousal effects.
## Each rating is discretized against `bins` evenly spaced edges on [0, 1.01].
## The axes-level pointplot draws directly onto the target axes; the
## figure-level factorplot used previously opens a separate figure (which had
## to be closed by hand) and newer seaborn releases ignore `ax` for it.
## (Grayscale variant: color='Grey'.)
bins = 10
edges = np.linspace(0,1.01,bins)
for column, color, ax in [('valence', '#02818a', axes[1,0]),
                          ('arousal', '#e6550d', axes[1,1])]:
    df[column] = np.digitize(df[column], edges)
    sns.pointplot(x=column, y='origResponseTimes', color=color, data=df, ax=ax)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Flourishes.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

titles = [['DBS','Conflict'],['Valence','Arousal']]

for n in range(2):
    for m in range(2):
        ax = axes[n,m]
        if n == 0: 
            ## Bar panels: fixed y-range plus a bracket joining the two bars
            ## (drawn 0.075 higher in the right-hand panel).
            ax.set_ylim(0,1.0)
            ax.hlines(0.875+m*0.075,0.0,1.0)
            ax.vlines([0,1], 0.85+m*0.075, 0.875+m*0.075)
        else: 
            ## Point panels: label every other category position, 1-based.
            ax.set_ylim(0.6,1.2)
            ax.set_xticks(np.arange(0,9,2))
            ax.set_xticklabels(np.arange(0,9,2)+1)
        ax.set_xlabel(titles[n][m], fontsize=24, weight='bold')
        ax.set_ylabel('Response Time (s)', fontsize=20, weight='bold')

        
plt.subplots_adjust(top=0.97, bottom=0.07, hspace=0.25, wspace=0.32)
plt.show()