# Figures (Original Submission)

### MSIT Overlay

In [None]:
import os
from surfer import Brain
%matplotlib qt4

## FreeSurfer tree holding the "fscopy" labels and overlay used below.
fs_dir = '/autofs/space/sophia_002/users/EMOTE-DBS/freesurfs'
## os.environ[...] raises KeyError if SUBJECTS_DIR is unset.
## NOTE(review): subj_dir is never used below, and Brain() is called without
## subjects_dir, so pysurfer presumably falls back to $SUBJECTS_DIR -- confirm it
## points at the same tree as fs_dir.
subj_dir = os.environ["SUBJECTS_DIR"]

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Surface parameters.
subject = "fscopy"
surf = "inflated"
hemi = 'lh'

## I/O parameters: left-hemisphere MSIT overlay volume and the outline color for ROI labels.
overlay = os.path.join(fs_dir, subject, 'label', 'april2016', 'darpa_msit_overlay-lh.mgz')
color = '#AFFF94'

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Make Figure.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Left-hemisphere ROI label files (fs_dir/<subject>/label/april2016/<name>.label).
labels = ['dacc-lh', 'dmpfc-lh', 'dlpfc_1-lh', 'dlpfc_2-lh', 'dlpfc_3-lh', 
          'dlpfc_4-lh', 'dlpfc_5-lh', 'dlpfc_6-lh', 'pcc-lh', 'racc-lh']

## Draw every ROI as a 3-px border, then the overlay (positive values only, 1.301 to 5).
## NOTE(review): 1.301 ~= -log10(0.05); presumably the overlay stores -log10(p) so only
## p < 0.05 is shown -- confirm.
brain = Brain(subject, hemi, surf, background='white')
for label in labels:
    label = os.path.join(fs_dir, subject, 'label', 'april2016', '%s.label' %label)
    brain.add_label(label, color=color, alpha=1, borders=3)
brain.add_overlay(overlay, min=1.301, max=5, sign='pos', name='msit')
## Hide the overlay's positive color bar (presumably its colorbar widget).
brain.overlays['msit'].pos_bar.visible = False

## Lateral view.
brain.show_view(dict(azimuth=150, roll=90), distance=350)
brain.save_image('plots/manuscript/fig1/msit_overlay_lateral.png')

## Medial view.
brain.show_view('medial', distance=425)
brain.save_image('plots/manuscript/fig1/msit_overlay_medial.png')

In [None]:
import os
from surfer import Brain
%matplotlib qt4

## Ad-hoc viewer: show one ROI label on the fscopy pial surface.
fs_dir = '/media/SZORO/arc-fir/recons/'

## NOTE(review): the surface comes from fs_dir, but the label below is read from a
## different FreeSurfer tree (/media/SZORO/EMOTE-DBS/freesurfs) -- confirm both
## trees share the same fscopy subject.
brain = Brain('fscopy', 'lh', 'pial', subjects_dir=fs_dir)
# brain.add_label('/media/SZORO/arc-fir/recons/fscopy/label/laus125/superiorfrontal_4-lh.label')
# brain.add_label('/media/SZORO/arc-fir/recons/fscopy/label/laus125/caudalmiddlefrontal_1-lh.label')
# brain.show_view('medial')
brain.add_label('/media/SZORO/EMOTE-DBS/freesurfs/fscopy/label/april2016/dlpfc_2-lh.label')

## Figure 2

### Grand Average Topoplots (Time-Domain)

In [None]:
import os
import numpy as np
import pylab as plt
from mne import EpochsArray, combine_evoked, grand_average, read_epochs, set_log_level
from mne.channels import read_montage
from mne.filter import low_pass_filter
set_log_level(verbose=False)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Subjects and epoch-file parameters (ave/<subject>_<task>_<h_freq>_<analysis>-epo.fif).
subjects = ['BRTU','CHDR','CRDA','JADE','JASE','M5','MEWA','S2']
analysis = 'resp'
task = 'msit'
h_freq = 50

## Topomap latency (s) for each analysis; any other analysis value plots nothing.
topo_times = {'stim': 0.52, 'resp': -0.7}

## matplotlib's old 'spectral' colormap was an alias of 'nipy_spectral'
## (deprecated in 2.0, removed in 2.2), so use the surviving name.
cmap = 'nipy_spectral'

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Load data.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
montage = read_montage('standard_1020')

## Case-insensitive lookup of montage channel spellings, built once instead of once per
## channel. setdefault keeps the first match, like the linear list.index search did.
montage_lookup = dict()
for name in montage.ch_names:
    montage_lookup.setdefault(name.lower(), name)

evokeds = []

for subject in subjects:

    ## Load epochs.
    epochs = read_epochs('ave/%s_%s_%s_%s-epo.fif' %(subject,task,h_freq,analysis))

    ## Rename channels to the montage's spelling (matching ignores case).
    ch_map = dict()
    for ch in epochs.ch_names:
        if ch.lower() not in montage_lookup:
            raise ValueError('Channel %s is not in the standard_1020 montage.' %ch)
        ch_map[ch] = montage_lookup[ch.lower()]
    epochs.rename_channels(ch_map)

    ## Set montage.
    epochs.set_montage(montage)

    ## Low-pass filter at 15 Hz, then rebuild an Epochs object around the filtered data.
    data = epochs.get_data()
    data = low_pass_filter(data, epochs.info['sfreq'], 15., filter_length='2s', n_jobs=3,)
    epochs = EpochsArray(data, epochs.info, epochs.events, epochs.tmin, epochs.event_id, proj=False)

    ## Compute evoked.
    evokeds.append( epochs.average() )

## Compute grand average across subjects.
evokeds = grand_average(evokeds)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Plotting.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Two otherwise identical topomaps per analysis: first without, then with, a colorbar.
if analysis in topo_times:
    for colorbar in (False, True):
        fig = plt.figure(figsize=(6,6))
        ax = plt.subplot2grid((1,1),(0,0))
        evokeds.plot_topomap(times=topo_times[analysis], cmap=cmap, colorbar=colorbar,
                             average=0.05, axes=ax)


### dACC Figure

In [None]:
import os
import numpy as np
import pylab as plt
import seaborn as sns
from pandas import read_csv
sns.set_style("white")
sns.set_context('notebook', font_scale=2)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## File parameters.
model_name = 'revised'
space = 'source'
label = 'dacc-lh'
freq = 15

## Plotting parameters: one row per contrast, with a color and legend entry for
## contrast levels 0 and 1; (y1, y2) is the vertical extent of shading/markers.
contrasts = ['Interference','DBS']
palettes = [ ['#7b3294','#008837'], ['#0571b0','#ca0020'] ]
annotations = [ ['Control', 'Interference'], ['DBS OFF','DBS ON'] ]
y1, y2 = -0.2, 0.3

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Initialize figure: rows = contrasts, columns = stimulus-/response-locked.
fig, axes = plt.subplots(2,2,figsize=(12,8),sharey=True)
info = read_csv(os.path.join(space, 'afMSIT_%s_info.csv' %space))

for n, contrast, colors, legends in zip(range(2), contrasts, palettes, annotations):
    
    for m, analysis in enumerate(['stim', 'resp']):

        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        ### Load data.
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

        ## Load source data (rows are indexed with positions taken from `info`).
        npz = np.load(os.path.join(space, 'afMSIT_%s_%s_%s_%s.npz' %(space,analysis,label,freq)))
        data = npz['data']
        times = npz['times']

        ## Load cluster results.
        f = os.path.join(space, 'results', '%s_%s_timedomain_results.csv' %(model_name, analysis))
        clusters = read_csv(f)
        
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        ### Plotting.
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

        ## Mean +/- standard error over the selected rows for each contrast level.
        for i, color, legend in zip(range(2),colors,legends):

            ix, = np.where(info[contrast]==i)
            mu = data[ix].mean(axis=0)
            se = data[ix].std(axis=0) / np.sqrt(len(ix))
            axes[n,m].plot(times, mu, linewidth=3, color=color, label=legend)
            axes[n,m].fill_between(times, mu-se, mu+se, color=color, alpha=0.2)

        ## Shade significant clusters (FDR < 0.05) for this label/freq/contrast.
        axes[n,m].set_ylim(-0.2,0.2)
        for ix in np.where((clusters.Label==label)&(clusters.Freq==freq)&
                           (clusters.Contrast==contrast)&(clusters.FDR<0.05))[0]:
            tmin, tmax = clusters.loc[ix,'Tmin'], clusters.loc[ix,'Tmax']
            ## np.linspace needs an integer sample count: NumPy >= 1.18 raises TypeError for 1e3.
            axes[n,m].fill_between(np.linspace(tmin,tmax,1000), y1, y2, color='k', alpha=0.2)    
    
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Add flourishes.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

for n in range(2):
    
    for m in range(2):
        
        ## Stimulus-locked edits.
        if not m:
            
            ## Fix axes: tick labels are relative to MSIT onset (0.4 s).
            xticks = np.array([0.0, 0.4, 0.9, 1.4])
            axes[n,m].set(xticks=xticks, xticklabels=xticks - 0.4, 
                          xlim=(-0.25,1.5), ylim=(y1,y2))

            
            ## Add event markers.
            for x,s in zip([0, 0.4, 1.127],['IAPS','MSIT','Resp']): 
                axes[n,m].text(x+0.02,y1+np.abs(y1*0.05),s,fontsize=22)
                axes[n,m].vlines(x,y1,y2,linestyle='--',alpha=0.3)
                
        ## Response-locked edits.
        else:
            
            ## Fix axes.
            xticks = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])
            axes[n,m].set(xticks=xticks, xlim=(-1.0, 1.0))
                        
            ## Add response marker.
            axes[n,m].text(0.02,y1+np.abs(y1*0.05),'Resp',fontsize=22)
            axes[n,m].vlines(0.0,y1,y2,linestyle='--',alpha=0.3)
        
            ## Add legend in the upper-right corner (loc=1).
            axes[n,m].legend(loc=1, handlelength=1.2, handletextpad=0.5, 
                             labelspacing=0.1, borderpad=0)
        
        ## Add x-labels to the bottom row only.
        if n: axes[n,m].set_xlabel('Time (s)')
            

sns.despine()
plt.subplots_adjust(top=0.97, left = 0.08, right = 0.98, 
                    bottom=0.1, hspace=0.35, wspace=0.1)
plt.savefig('plots/manuscript/fig2/dacc_erp.png')
plt.savefig('plots/manuscript/fig2/dacc_erp.svg')
plt.show()
plt.close()

### Significant ERP Clusters

In [None]:
import os
import numpy as np
import pylab as plt
import seaborn as sns
from pandas import read_csv
sns.set_style("white")
sns.set_context('notebook', font_scale=2.5)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Load and prepare data.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Load stimulus-locked time-domain cluster results.
f = 'source/results/revised_stim_timedomain_results.csv'
df = read_csv(f)

## Keep clusters with FDR < 0.05; reset the index so the labels 0..len(df)-1 used
## by df.loc[n, ...] below coincide with row positions.
df = df[df.FDR<0.05].reset_index(drop=True)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Plotting.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Single tall axis with one row per ROI/hemisphere.
fig, ax = plt.subplots(1,1,figsize=(6,12))

## ROI identifiers, left hemisphere first. Rows are indexed in the reversed list, so
## racc-lh lands in the top row and the right-hemisphere ROIs fill the lower half.
labels = ['racc-lh', 'dacc-lh', 'pcc-lh', 'dmpfc-lh', 'dlpfc_1-lh', 'dlpfc_2-lh', 
          'dlpfc_3-lh', 'dlpfc_4-lh', 'dlpfc_5-lh', 'dlpfc_6-lh', 
          'racc-rh', 'dacc-rh', 'pcc-rh', 'dmpfc-rh', 'dlpfc_1-rh', 'dlpfc_2-rh', 
          'dlpfc_3-rh', 'dlpfc_4-rh', 'dlpfc_5-rh', 'dlpfc_6-rh']

## Draw each cluster as a band from Tmin to Tmax inside its ROI's row.
## Only the Interference and nsArousal contrasts are drawn; all others are skipped.
## A Label missing from `labels` raises ValueError from list.index.
for n in range(len(df)):
    
    if df.loc[n,'Contrast'] == 'Interference': color = '#008837'
    elif df.loc[n,'Contrast'] == 'nsArousal': color = '#e6550d'
    else: continue
        
    y = labels[::-1].index(df.loc[n,'Label'])
    ax.fill_between(df.loc[n,['Tmin','Tmax']].astype(float), y+0.05, y+0.95, color=color, alpha=0.8)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Add flourishes.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Add legend using empty plots as proxy handles ('nsArousal' is shown as 'Arousal').
for label, color in zip(['Interference','Arousal'],['#008837','#e6550d']): 
    ax.plot([],[],lw=10,color=color,label=label,alpha=0.7)
ax.legend(bbox_to_anchor=(0.7,1.1), handlelength=1.25, borderaxespad=0)
    
## Add timing details: dashed event markers (s, stimulus-locked) across all rows.
y1, y2 = 0, len(labels)
for x,s in zip([0, 0.4, 1.127],['IAPS','MSIT','Resp']): 
    ax.text(x+0.02,0.25,s,fontsize=20)
    ax.vlines(x, y1, y2, linewidth=2.5, linestyle='--',alpha=0.2)    

## Fix x-axis: tick labels are relative to MSIT onset (tick - 0.4 s).
xticks = np.array([0.0, 0.4, 0.9, 1.4])
ax.set(xticks=xticks, xticklabels=xticks-0.4, xlim=(-0.25,1.5),xlabel='Time (s)')

## Fix y-axis: display names replace the ROI ids (rebinding `labels`), one per row.
## NOTE(review): relies on positional correspondence with the ROI list above
## (e.g. pcc -> 'mCC', dlpfc_5 -> 'aIFG'); confirm the mapping if either list changes.
labels = ['rACC', 'dACC', 'mCC', 'SFG', 'pMFG 1', 'pMFG 2', 'aMFG 1', 'aMFG 2', 'aIFG', 'pIFG'] * 2
ax.set(yticks=np.arange(len(labels))+0.5, yticklabels=labels[::-1], ylim=(0,len(labels)))

## Add dendrograms.
def dendrogram(ax, x, y1, y2, text):
    """Draw a bracket in the axes margin that groups a span of rows.

    All positions are axes fractions. ``x`` is where the vertical spine sits
    (the calls in this cell pass a negative value, i.e. left of the axis),
    and ``y1``/``y2`` are the spine's ends. A horizontal tick joins each end to
    the axis edge, and ``text`` is written rotated 90 degrees, vertically
    centred between ``y1`` and ``y2`` at ``1.4 * x``.
    """
    ## Shared styling for the three line segments (a copy is passed to each call).
    segment = dict(arrowstyle='-', color='k', linewidth=2.0, alpha=0.2)

    ## Vertical spine.
    ax.annotate('', (x, y1), xycoords='axes fraction', xytext=(x, y2),
                arrowprops=dict(segment))

    ## Bottom tick, then top tick.
    for end in (y1, y2):
        ax.annotate('', (x * 1.02, end), xycoords='axes fraction', xytext=(-1e-3, end),
                    arrowprops=dict(segment))

    ## Rotated group label.
    centre = np.mean([y1, y2])
    ax.annotate(text, (0, 0), xycoords='axes fraction', xytext=(x * 1.4, centre),
                rotation=90, va='center')

## Hemisphere brackets in the left margin (axes fractions): the lower half of the rows
## holds the right-hemisphere ROIs, the upper half the left-hemisphere ROIs.
dendrogram(ax, -0.38, 0, 0.495, 'Right Hemisphere')
dendrogram(ax, -0.38, 0.505, 1, 'Left Hemisphere')

## Finalize layout (wide left margin for labels and brackets) and save PNG + SVG.
sns.despine()
plt.subplots_adjust(left=0.35, right=0.975, top=0.925, bottom=0.075)
plt.savefig('plots/manuscript/fig2/all_erps.png', dpi=180)
plt.savefig('plots/manuscript/fig2/all_erps.svg', dpi=180)
plt.show()
plt.close()

## Figure 3

### Cluster Mass Calculations

In [None]:
import numpy as np
from pandas import read_csv, concat

## Collect significant DBS theta clusters from both alignments.
combined = []
for analysis in ['stim', 'resp']:

    ## Load cluster results; keep significant (FDR < 0.05) DBS theta clusters.
    df = read_csv('source/results/revised_%s_frequency_results.csv' %analysis)
    df = df[(df.Contrast=='DBS') & (df.Freq=='theta') & (df.FDR<0.05)]

    ## Time grid sampled at 1450 Hz, and the event window each cluster is scored against.
    if analysis == 'stim':
        times = np.arange(0,1.5,1/1450.)
        in_window = (times > 0.4) & (times < 1.127)
    elif analysis == 'resp':
        times = np.arange(-1,1,1/1450.)
        in_window = (times < 0)

    ## Percentage of each cluster's time points (strictly between Tmin and Tmax) that
    ## also fall inside the event window.
    percentages = []
    for _, cluster in df.iterrows():
        in_cluster = (times > cluster.Tmin) & (times < cluster.Tmax)
        shared = np.logical_and(in_cluster, in_window).sum()
        percentages.append(1e2 * (shared / in_cluster.sum().astype(float)))

    print('%s: %0.3f' %(analysis, np.mean(percentages)))

    combined.append(df)

## Hemisphere dominance: share of the summed Tdiff contributed by '*lh' labels.
combined = concat(combined)
combined['hemi'] = ['lh' if name.endswith('lh') else 'rh' for name in combined.Label]
gb = combined.groupby('hemi')['Tdiff'].sum()
print(gb['lh'] / gb.sum())

### Grand Average Topoplots (Power Domain)

In [None]:
import os
import numpy as np
import pylab as plt
from mne import concatenate_epochs, read_epochs, set_log_level
from mne.channels import read_montage
from mne.time_frequency import tfr_morlet
from mne.viz.topomap import _prepare_topo_plot
set_log_level(verbose=False)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

subjects = ['BRTU','CHDR','CRDA','JADE','JASE','M5','MEWA','S2']
## NOTE(review): overwritten by the `for analysis in ...` loop below; effectively unused.
analysis = 'stim'
task = 'msit'
h_freq = 50

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
montage = read_montage('standard_1020')

## For each alignment: load all subjects, keep common channels, compute theta TFR,
## baseline-correct in dB, and save topography inputs for the plotting cell.
for analysis in ['stim','resp']:

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ### Load data.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    
    print 'Beginning processing for %s.' %analysis,
    data = []
    
    for subject in subjects:

        ## Load epochs.
        epochs = read_epochs('ave/%s_%s_%s_%s-epo.fif' %(subject,task,h_freq,analysis))

        ## Update channel names according to montage (case-insensitive match;
        ## ValueError from list.index if a channel is absent).
        ch_map = dict()
        for ch in epochs.ch_names:
            ix = [m.lower() for m in montage.ch_names].index(ch.lower())
            ch_map[ch] = montage.ch_names[ix]
        epochs.rename_channels(ch_map)

        ## Set montage.
        epochs.set_montage(montage)

        ## Subtract the evoked (average) response from every epoch, presumably to keep
        ## only non-phase-locked (induced) activity for the TFR -- confirm intent.
        epochs = epochs.subtract_evoked()
        
        ## Clear SSP projectors before collecting the epochs.
        ## NOTE(review): presumably so concatenate_epochs does not reject subjects whose
        ## projectors differ -- confirm.
        epochs.info['projs'] = []
        data.append( epochs )

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ### Concatenate epochs.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

    ## Identify list of common channels (present once in every subject).
    channels = np.concatenate([epochs.ch_names for epochs in data])
    channels = [ch for ch, count in zip(*np.unique(channels, return_counts=True)) 
                if count==len(subjects)]

    ## Iteratively drop non-common channels.
    for n in range(len(subjects)):

        epochs = data[n]
        epochs = epochs.drop_channels([ch for ch in epochs.ch_names if ch not in channels])
        data[n] = epochs

    data = concatenate_epochs(data)

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ### TFR.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

    ## Compute Morlet TFR at 4, 6 and 8 Hz with 3 cycles (power only). 
    freqs = np.arange(4,8+1e-6,2)
    n_cycles = 3
    tfr = tfr_morlet(data, freqs, n_cycles, return_itc=False, verbose=False)
    
    ## Compute baseline: median power from -0.5 to -0.1 s along the last (time) axis.
    ## NOTE(review): only assigned for 'stim'; the 'resp' iteration reuses the
    ## stimulus-locked baseline from the previous iteration. Confirm this is intended
    ## and that both alignments end up with the same common-channel set, otherwise
    ## the division below will not line up.
    if analysis == 'stim':
        mask = (tfr.times >= -0.5) & (tfr.times <= -0.1)
        baseline = np.median(tfr.data[:,:,mask], axis=-1)
        
    ## Baseline correct in dB. The transposes broadcast the baseline over time
    ## (assumes tfr.data is (channels, freqs, times) -- TODO confirm).
    data = tfr.data.copy().T / baseline.T
    data = 10 * np.log10(data.T)
    
    ## Median across axis 1 (frequencies under the assumption above).
    data = np.apply_along_axis(np.median, 1, data).squeeze()
    
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ### Plotting.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ## Save data, times and sensor positions for the next cell (nothing is drawn here).
    ## NOTE(review): _prepare_topo_plot is private MNE API and may change between versions.
    _, pos, _, _, _ = _prepare_topo_plot(tfr, 'eeg', None)
    np.savez_compressed('plots/manuscript/fig3/%s' %analysis, data=data, times=tfr.times, pos=pos)
    
print 'Done.'

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from mne.viz import plot_topomap
%matplotlib inline

## One topomap per alignment; v is the symmetric color limit (dB).
for analysis, v in zip(['stim', 'resp'], [3.5, 4.5]):

    ## Load data saved by the power-domain export cell (second axis indexed by time).
    npz = np.load('plots/manuscript/fig3/%s.npz' %analysis)
    data = npz['data']
    times = npz['times']
    pos = npz['pos']
    
    ## Average over the window of interest and plot.
    ## 'nipy_spectral' is the surviving name of the removed 'spectral' alias (matplotlib 2.2).
    if analysis == 'stim': mask = (times > 0.4) & (times < 0.8)
    elif analysis == 'resp': mask = (times > -0.2) & (times < 0.2)
    plot_topomap(data[:,mask].mean(axis=-1), pos, cmap='nipy_spectral', 
                 vmin=-v, vmax=v, contours=0)

### DLPFC_5-LH Theta

In [None]:
import os
import numpy as np
import pylab as plt
import seaborn as sns
from pandas import read_csv
sns.set(style="white", font_scale=1.00)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## I/O parameters.
space = 'source'
model_name = 'revised'

label = 'dlpfc_5-lh'
freq = 'theta'
contrast = 'DBS'

## Stimulus-locked baseline window (s); its mean is subtracted from each condition mean.
baseline = (-0.5, -0.1)

## Vertical extent (y1, y2) of cluster shading and event markers for each analysis.
## These are set once per panel so the markers below work even when no cluster is
## significant. Previously they were only assigned inside the cluster loop, which gave
## a NameError (stim) or stale limits (resp).
ylims = {'stim': (-1.0, 2.5), 'resp': (-1.5, 1.5)}

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Initialize figure: left = stimulus-locked, right = response-locked.
fig, axes = plt.subplots(1, 2, figsize=(12,4), dpi=300)

colors = ['#0571b0','#ca0020']
labels = ['DBS OFF','DBS ON']

for ax, analysis in zip(axes, ['stim','resp']):

    y1, y2 = ylims[analysis]

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ### Load data.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    
    ## Load trial information
    info = read_csv(os.path.join(space, 'afMSIT_%s_info.csv' %space))

    ## Load source data.
    npz = np.load(os.path.join(space, 'afMSIT_%s_%s_%s_%s.npz' %(space,analysis,label,freq)))
    data = npz['data']
    times = npz['times']

    ## Load cluster results.
    f = os.path.join(space, 'results', '%s_%s_frequency_results.csv' %(model_name, analysis))
    clusters = read_csv(f)
    
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ### Main plotting.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    
    ## Plot lines: DBS == 0 (OFF) and DBS == 1 (ON).
    for m, color, legend in zip([0,1],colors,labels):

        ## Identify DBS on/off trials.
        ix, = np.where(info.DBS==m)
        
        ## Compute average time course.
        mu = data[ix].mean(axis=0)
        
        ## If stimulus-locked, baseline subtract.
        if analysis == 'stim': mu -= mu[(times >= baseline[0])&(times <= baseline[1])].mean()
            
        ## Compute standard error (unaffected by the constant baseline shift). 
        se = data[ix].std(axis=0) / np.sqrt(len(ix))
        
        ## Plotting.
        ax.plot(times, mu, linewidth=3, color=color, label=legend)
        ax.fill_between(times, mu-se, mu+se, color=color, alpha=0.15)

    ## Shade significant clusters (FDR < 0.05) for this label/freq/contrast.
    for ix in np.where((clusters.Label==label)&(clusters.Freq==freq)&
                       (clusters.Contrast==contrast)&(clusters.FDR<0.05))[0]:

        tmin, tmax = clusters.loc[ix,'Tmin'], clusters.loc[ix,'Tmax']
        ax.fill_between(np.linspace(tmin,tmax,1000), y1, y2, color='k', alpha=0.2)   

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ### Add flourishes.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

    ## Universal fixes.
    ax.set_xlabel('Time (s)', fontsize=24)
    ax.tick_params(axis='both', which='major', labelsize=20)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    
    if analysis == 'stim':
        
        ## Fix labels/legends.
        ax.set_ylabel(r' aIFG $\theta$ Power (dB)', fontsize=24)
        ax.legend(loc=2, fontsize=16, frameon=False, borderpad=0)
        
        ## Fix timing: tick labels are relative to MSIT onset (0.4 s).
        xticks = np.array([0.0, 0.4, 0.9, 1.4])
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticks - 0.4)
        ax.set_xlim(-0.25,1.5)
        
        ## Add time markers.
        for x,s in zip([0, 0.4, 1.127],['IAPS','MSIT','Resp']): 
            ax.text(x+0.02,-0.95,s,fontsize=16)
            ax.vlines(x,y1,y2,linestyle='--',alpha=0.3)
        
    elif analysis == 'resp':
        
        ## Add time markers.
        ax.text(0.02, y1+0.05,'Resp', fontsize=16)
        ax.vlines(0.0,y1,y2,linestyle='--',alpha=0.3)
            
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Save figure.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
            
plt.tight_layout()
# plt.show()
plt.savefig('plots/manuscript/fig3/dlpfc_5-lh.png')
plt.savefig('plots/manuscript/fig3/dlpfc_5-lh.svg')
plt.close('all')
print('Done.')

### Significant Theta Clusters

In [None]:
import os
import numpy as np
import pylab as plt
import seaborn as sns
from pandas import read_csv
sns.set_style("white")
sns.set_context('notebook', font_scale=2.5)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Load and prepare data.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Load stimulus-locked frequency-domain cluster results.
f = 'source/results/revised_stim_frequency_results.csv'
df = read_csv(f)

## Limit data to significant (FDR < 0.05) theta clusters.
df = df[df.FDR<0.05]
df = df[df.Freq=='theta'].reset_index(drop=True)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Plotting.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

fig, ax = plt.subplots(1,1,figsize=(6,12))

## ROI identifiers, left hemisphere first; rows are indexed in reverse so racc-lh is on top.
labels = ['racc-lh', 'dacc-lh', 'pcc-lh', 'dmpfc-lh', 'dlpfc_1-lh', 'dlpfc_2-lh', 
          'dlpfc_3-lh', 'dlpfc_4-lh', 'dlpfc_5-lh', 'dlpfc_6-lh', 
          'racc-rh', 'dacc-rh', 'pcc-rh', 'dmpfc-rh', 'dlpfc_1-rh', 'dlpfc_2-rh', 
          'dlpfc_3-rh', 'dlpfc_4-rh', 'dlpfc_5-rh', 'dlpfc_6-rh']

## Add timing details: dashed event markers (s, stimulus-locked) across all rows.
for x,s in zip([0, 0.4, 1.127],['IAPS','MSIT','Resp']): 
    ax.text(x+0.01,-0.6,s,fontsize=20)
    ax.vlines(x, -1, len(labels), linewidth=2.5, linestyle='--',alpha=0.2)  

conds = ['DBS','Interference']
colors = ['#ca0020','#008837']

## Each ROI owns one row; within it the first contrast is drawn in the lower half
## (y + 0.25) and the second in the upper half (y + 0.75).
for n, label in enumerate(labels[::-1]):
    
    for m, contrast in enumerate(conds):
        
        ## Extract clusters.
        clusters = df.loc[(df.Contrast==contrast)&(df.Label==label),['Tmin','Tmax']]
        if not len(clusters): continue
        
        ## Plot clusters as thick horizontal bars from Tmin to Tmax.
        ## `.values` replaces DataFrame.as_matrix(), which was removed in pandas 1.0.
        y = n + m * 0.5
        for cluster in clusters.values: 
            ax.hlines(y+0.25, cluster.min(), cluster.max(), color=colors[m], lw=24)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Add flourishes.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Add legend using empty plots as proxy handles.
for label, color in zip(conds,colors): 
    ax.plot([],[],lw=10,color=color,label=label,alpha=0.7)
ax.legend(bbox_to_anchor=(0.7,1.1), handlelength=1.25, borderaxespad=0)

## Fix x-axis: tick labels are relative to MSIT onset (tick - 0.4 s).
xticks = np.array([0.0, 0.4, 0.9, 1.4])
ax.set(xticks=xticks, xticklabels=xticks-0.4, xlim=(-0.25,1.5),xlabel='Time (s)')

## Fix y-axis: display names (positionally matched to the ROI ids) replace `labels`;
## ylim starts at -0.7 to leave room for the event-marker text.
labels = ['rACC', 'dACC', 'mCC', 'SFG', 'pMFG 1', 'pMFG 2', 'aMFG 1', 'aMFG 2', 'aIFG', 'pIFG'] * 2
ax.set(yticks=np.arange(len(labels))+0.5, yticklabels=labels[::-1], ylim=(-0.7,len(labels)))

## Add dendrograms.
def dendrogram(ax, x, y1, y2, text):
    """Draw a bracket in the axes margin that groups a span of rows.

    All positions are axes fractions. ``x`` is where the vertical spine sits
    (the calls in this cell pass a negative value, i.e. left of the axis),
    and ``y1``/``y2`` are the spine's ends. A horizontal tick joins each end to
    the axis edge, and ``text`` is written rotated 90 degrees in 30 pt,
    vertically centred between ``y1`` and ``y2`` at ``1.4 * x``.
    """
    ## Shared styling for the three line segments (a copy is passed to each call).
    segment = dict(arrowstyle='-', color='k', linewidth=2.0, alpha=0.2)

    ## Vertical spine.
    ax.annotate('', (x, y1), xycoords='axes fraction', xytext=(x, y2),
                arrowprops=dict(segment))

    ## Bottom tick, then top tick.
    for end in (y1, y2):
        ax.annotate('', (x * 1.02, end), xycoords='axes fraction', xytext=(-1e-3, end),
                    arrowprops=dict(segment))

    ## Rotated group label.
    centre = np.mean([y1, y2])
    ax.annotate(text, (0, 0), xycoords='axes fraction', xytext=(x * 1.4, centre),
                rotation=90, fontsize=30, va='center')

## Hemisphere brackets in the left margin (axes fractions): lower rows = right
## hemisphere, upper rows = left hemisphere. The 0.025 bottom offset presumably
## accounts for the extra space below row 0 (ylim starts at -0.7) -- confirm.
dendrogram(ax, -0.38, 0.025, 0.51, 'Right Hemisphere')
dendrogram(ax, -0.38, 0.515, 1, 'Left Hemisphere')

## Finalize layout and save PNG + SVG.
## NOTE(review): unlike the ERP-cluster cell, the figure is not closed afterwards.
sns.despine()
plt.subplots_adjust(left=0.35, right=0.975, top=0.925, bottom=0.075)
plt.savefig('plots/manuscript/fig3/all_theta.png', dpi=180)
plt.savefig('plots/manuscript/fig3/all_theta.svg', dpi=180)
plt.show()

### Spectral Barplots

In [None]:
import os
import numpy as np
import pylab as plt
import seaborn as sns
from pandas import DataFrame, read_csv
sns.set_style("white")
sns.set_context('notebook', font_scale=2.5)
%matplotlib inline

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## I/O parameters
space = 'source'
model_name = 'revised'
contrast = 'DBS'

## Label parameters.
labels = ['dlpfc_5-lh', 'dlpfc_4-lh', 'pcc-lh']
xlabels = ['aIFG', 'aMFG 2', 'mCC']

## Define averaging parameters.
baseline = (-0.5, -0.1)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Preallocate space.
analyses = []
freqs = []
rois = []
values = []

for analysis in ['stim','resp']:
        
    for label, xlabel in zip(labels,xlabels):
    
        for freq, ffreq in zip(['theta','alpha','beta'],
                               [r'$\theta$',r'$\alpha$',r'$\beta$']):
    
            #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
            ### Load data.
            #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

            ## Load trial information
            info = read_csv(os.path.join(space, 'afMSIT_%s_info.csv' %space))

            ## Load source data.
            npz = np.load(os.path.join(space, 'afMSIT_%s_%s_%s_%s.npz' %(space,analysis,label,freq)))
            data = npz['data']
            times = npz['times']

            #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
            ### Compute differences.
            #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

            ## Define mask.
            if analysis == 'stim': tmin, tmax = 0.4, 0.8
            elif analysis == 'resp': tmin, tmax = -0.2, 0.2

            ## Compute averages within window.
            delta = []
            for i in range(2):

                ## Identify DBS on/off trials.
                ix, = np.where(info.DBS==i)

                ## Compute average time course.
                mu = data[ix].mean(axis=0)

                ## Reduce to time of interest.
                mu = mu[(times >= tmin)&(times <= tmax)]
                delta.append(mu)

            ## Compute difference.
            delta = np.diff(delta, axis=0).squeeze()
            
            ## Append information.
            analyses += [analysis] * len(delta)
            freqs += [ffreq] * len(delta)
            rois += [xlabel]* len(delta)
            values += delta.tolist()

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Compute differences.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Conver to DataFrame.
df = DataFrame([analyses, freqs, rois, values], index=('Analysis','Freq','ROI','Delta')).T

## Plot.
g = sns.FacetGrid(df, col='Analysis', size=6, aspect=1.5)
g.map(sns.barplot, 'ROI', 'Delta', 'Freq', ci='sd',
      palette=sns.color_palette(n_colors=3))

## Add flourishes.
for n, ax in enumerate(g.axes.squeeze()):
    x1, x2 = ax.get_xlim()
    ax.hlines(0,x1,x2)
    ax.set(xlabel = '', title='')
    ax.legend(loc=1, labelspacing=0, borderpad=0)
    if not n: ax.set_ylabel('Power (ON - OFF)')
        
plt.savefig('plots/manuscript/fig3/barplots.png', dpi=180)
plt.savefig('plots/manuscript/fig3/barplots.svg', dpi=180)
plt.show()

## Figure 4

### DLPFC_5-LH Correlations/ROC plots

In [None]:
import os
import numpy as np
import pylab as plt
import seaborn as sns
import matplotlib.gridspec as gridspec
from pandas import DataFrame, Series, read_csv
from scipy.stats import pearsonr
from sklearn.metrics import auc, roc_curve
sns.set_style('white')
sns.set_context('notebook', font_scale=2)
%matplotlib inline

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## I/O parameters.
space = 'source'
model = 'revised'
analysis = 'stim'
domain = 'frequency'
contrast = 'DBS'
label = 'dlpfc_5-lh'
freq = 'theta'
fdr = 0.05
    
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Load and prepare clinical data.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Rating scales indexed by subject (first CSV column).
scores = read_csv('behavior/Subject_Rating_Scales.csv', index_col=0)
subjects = scores.index

## Change in MADRS from baseline (now - base); hypomania score as recorded.
madrs = scores['MADRS_Now'] - scores['MADRS_Base']
mania = scores['Hypomania']

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Load and prepare reaction time data.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Per-subject mean RT for DBS == 1 minus DBS == 0.
rt = read_csv('behavior/EMOTE_behav_data.csv')
rt = rt.groupby(['DBS','subject']).origResponseTimes.mean()
rt = rt[1] - rt[0]

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Load and prepare power data.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

info = read_csv(os.path.join(space, 'afMSIT_%s_info.csv' %space))

## Load and limit cluster results to this contrast, FDR level, frequency band and label.
## The Freq filter matches the cluster selection used for the Figure 3 time course;
## without it, clusters from other bands would widen the theta window below.
results = read_csv(os.path.join(space, 'results', '%s_%s_%s_results.csv' %(model,analysis,domain)))
results = results[results.Contrast==contrast]
results = results[results.FDR<fdr]
results = results[results.Freq==freq]
results = results[results.Label == label].reset_index(drop=True)
if not len(results):
    raise ValueError('No significant %s clusters for %s (%s, FDR < %s).' %(freq, label, contrast, fdr))

## Load time series data.
npz = np.load(os.path.join(space, 'afMSIT_%s_%s_%s_%s.npz' %(space, analysis, label, freq)))
data = npz['data']
times = npz['times']

## Compute condition differences.
## NOTE: collapsing across clusters -- a single window from the earliest Tmin to the latest Tmax.
mask = (times >= results.Tmin.min()) & (times <= results.Tmax.max())

## Per subject: time-averaged difference of condition-mean time courses (contrast 1 - 0).
delta = np.zeros(subjects.shape[0])
for m, subject in enumerate(subjects):
    i, = np.where((info['Subject']==subject)&(info[contrast]==0))
    j, = np.where((info['Subject']==subject)&(info[contrast]==1))
    delta[m] = (data[j][:,mask].mean(axis=0) - data[i][:,mask].mean(axis=0)).mean()
delta = Series(delta, index=subjects)
    
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Construct DataFrame.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#  
## Fix NumPy's global RNG (presumably for seaborn's bootstrapped regplot CIs -- confirm).
np.random.seed(47404)

## Concatenate data: Series are aligned on subject, giving one row per subject.
df = DataFrame([madrs,mania,rt,delta], index=['MADRS','Hypomania','RT','Delta']).T

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 

def swap_arr(x,y):
    """Return independent copies of the two arrays in swapped order: ``(y, x)``."""
    first = y.copy()
    second = x.copy()
    return first, second

def simple_roc(y,x):
    '''Compute an ROC curve by ranking samples on their score.

    Based on http://blog.revolutionanalytics.com/2016/08/roc-curves-in-two-lines-of-code.html

    Parameters
    ----------
    y : array-like of 0/1 (or False/True)
        Binary class labels.
    x : array-like
        Scores; a higher score ranks a sample as more likely positive.

    Returns
    -------
    tpr, fpr : ndarray of float
        Cumulative true- and false-positive rates after each sample, taken in
        order of decreasing score (tied scores are not grouped). A rate is NaN
        if its class is absent from ``y``.

    Raises
    ------
    ValueError
        If ``y`` contains values other than 0 and 1.
    '''
    ## Coerce to arrays so lists work and pandas Series (whose [] indexing is
    ## label-based) are ranked by position.
    y = np.asarray(y)
    x = np.asarray(x)

    ## Explicit check instead of `assert`, which is stripped under `python -O`.
    if not np.all((y == 0) | (y == 1)):
        raise ValueError('simple_roc expects binary labels (0/1).')

    ## Labels reordered by descending score.
    hits = y[np.argsort(x)[::-1]].astype(bool)
    tpr = np.cumsum(hits, dtype=float) / hits.sum()
    fpr = np.cumsum(~hits, dtype=float) / (~hits).sum()
    return tpr, fpr

def ROC(y,x):
    '''Return an ROC curve and its AUC, oriented so the AUC is at least 0.5.

    Parameters
    ----------
    y : array of {0, 1} class labels.
    x : array of scores.

    Returns
    -------
    tpr, fpr, roc_auc
        When the raw AUC is below 0.5 (the scores rank the classes in the
        opposite direction), TPR and FPR are exchanged and 1 - AUC is reported.
        A NaN AUC (e.g. only one class present) is returned unflipped.
    '''
    tpr, fpr = simple_roc(y, x)
    roc_auc = auc(fpr, tpr)

    ## Flip a "misidentified" curve across the chance diagonal.
    if roc_auc < 0.5:
        tpr, fpr = swap_arr(tpr, fpr)
        return tpr, fpr, 1 - roc_auc

    return tpr, fpr, roc_auc

## Initialize figure.
fig = plt.figure(figsize=(16,8))

## Define plotting variables.
colors = np.array([['#1f77b4','#2ca02c'], ['#d62728', '#9467bd']])
xticklabels = [['No Response', 'Remission'], ['No History', 'Converted']]
ylabels = [r'$\Delta$ RT (s)', r'$\Delta$ $\theta$-power (dB)']

## Layout: rows (n) are clinical outcomes and columns (m) are the predictors.
## Each quadrant holds a correlation panel (gs[0]) and an ROC panel (gs[1]).
for n, xlabel in enumerate(['MADRS', 'Hypomania']):
    
    for m, ylabel in enumerate(['RT','Delta']):
           
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        ### Preparations.
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 
            
        ## Initialize axes (figure-fraction bounds of this quadrant).
        if not n: top, bottom = 0.95, 0.6
        else: top, bottom = 0.4, 0.05
        if not m: left, right = 0.05, 0.435
        else: left, right = 0.565, 0.95
        gs = gridspec.GridSpec(1,2)
        gs.update(left=left, right=right, top=top, bottom=bottom, wspace=0.4)
        
        ## Extract variables for complete cases only. The same row mask is
        ## reused for the MADRS response labels below, so the ROC labels and
        ## scores always refer to the same subjects.
        valid = df[[xlabel,ylabel]].notnull().all(axis=1).values
        x, y = df.loc[valid, [xlabel,ylabel]].values.T
        
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        ### Correlation Plot.
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 
    
        ## Plot correlation (keyword arguments: positional x/y/data are
        ## rejected by newer seaborn releases).
        ax = plt.subplot(gs[0])
        sns.regplot(x=x, y=y, data=df, color=colors[n,m], ax=ax)
        
        ## Add flourishes.
        if not n and not m: ax.set(xticks=[-40,-20,0], xlabel=r'$\Delta$ MADRS')
        elif not n:  ax.set(xticks=[-26,-13,0], xlabel=r'$\Delta$ MADRS')
        else: ax.set(xticks=[0,1], xticklabels=xticklabels[n])
        if not m: ax.set(ylim=(-0.15, 0.10), yticks=[-0.10,0.0,0.10],
                         ylabel=ylabels[m])
        else: ax.set(ylim=(-0.5,2), yticks=np.linspace(-0.5,2,3),
                     ylabel=ylabels[m])
        ax.tick_params(axis='x', which='major', pad=15)
    
        ## Add text.
        r, p = pearsonr(x,y)
        if not m and n: 
            ax.annotate('r = %0.2f, p = %0.2f' %(r,p), xy=(0,0), xytext=(0.05,0.05),
                              xycoords = 'axes fraction', fontsize=16)
        else:
            ax.annotate('r = %0.2f, p = %0.2f' %(r,p), xy=(0,0), xytext=(0.3,0.05),
                              xycoords = 'axes fraction', fontsize=16)
    
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        ### RoC plots.
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# 
    
        if xlabel == 'MADRS':
            ## Binary label: 1 if MADRS_Now / MADRS_Base <= 0.5 (at least a 50%
            ## reduction), else 0. NOTE(review): assumes `scores` rows are in
            ## the same order as `df` rows -- confirm.
            x = np.where(scores['MADRS_Now'] / scores['MADRS_Base'] > 0.5, 0, 1)
            x = x[valid]
    
        ## Initialize plot with the chance diagonal.
        ax = plt.subplot(gs[1])
        ax.plot(np.linspace(0,1,10),np.linspace(0,1,10),lw=1,linestyle='--',color='k')
        
        ## Plot true RoC (x = binary labels, y = scores).
        tpr, fpr, roc_auc = ROC(x, y)
        ax.plot(fpr, tpr, lw=2, color=colors[n,m])
        
        ## Bootstrap the AUC. Iteration 0 uses the original sample; the others
        ## resample subjects with replacement from NumPy's global RNG.
        ## Resamples containing a single class give NaN and are ignored by
        ## nanpercentile.
        auc_sim = []
        for i in range(1000):
            
            ## Bootstrap-resample subject indices.
            if i: ix = np.random.choice(np.arange(len(x)), len(x), replace=True)
            else: ix = np.arange(len(x))
            x_p, y_p = x[ix], y[ix]  # fancy indexing already returns copies
            
            ## Compute AUC. 
            _, _, sim = ROC(x_p, y_p)
            auc_sim.append( sim )
            
        ## Annotate AUC with its 95% bootstrap CI.
        text = 'AUC = %0.2f [%0.2f, %0.2f]' %(roc_auc, np.nanpercentile(auc_sim, 2.5),
                                              np.nanpercentile(auc_sim, 97.5))
        ax.annotate(text, xy=(0,0), xytext=(0.15,0.05), xycoords = 'axes fraction', 
                    fontsize=16)
        
        ## Add flourishes.
        ax.set(xticks=np.linspace(0,1,3), xlim=(-0.01,1.00), xlabel='FPR', 
               yticks=np.linspace(0,1,3), ylim=(0.00,1.01), ylabel='TPR')
    
sns.despine()
plt.savefig('plots/manuscript/fig4/combo_plot.png')
plt.savefig('plots/manuscript/fig4/combo_plot.svg')

## Supplementary Figures

### Figure S3: AIC Plot

In [None]:
import rpy2
import numpy as np
import pylab as plt
import seaborn as sns
from pandas import DataFrame, read_csv
sns.set_style("white")
sns.set_context('notebook', font_scale=2)
## IPython magics: enable the %R / %%R bridge and load lme4 (for glmer) in the R session.
%load_ext rpy2.ipython
%R require(lme4)
%matplotlib inline

## Load data.
df = read_csv('behavior/EMOTE_behav_data.csv')
## Keep only trials whose responseTimes is truthy, i.e. drop zero RTs.
## NOTE(review): NaN is truthy, so NaN RTs are kept -- confirm that is intended.
df = df[np.where(df.responseTimes,True,False)].reset_index(drop=True)
## Interaction terms. DBSxInt is not among the `variables` used by the R cell
## below; AROxVAL is.
df['DBSxInt'] = df.DBS * df.interference
df['AROxVAL'] = df.arousal * df.valence

In [None]:
%%R -i df -o AICs,BICs

# Fit a sequence of nested Gamma GLMMs: start from a random-intercept-only
# formula and add one fixed effect per iteration (in the order of `variables`),
# recording AIC and BIC for each model. `df` comes in from Python; AICs and
# BICs are passed back to Python.
formula = 'responseTimes ~ (1|subject)'
variables = c('interference','DBS','valence','arousal','TrialNum', 'AROxVAL')
AICs = c()
BICs = c()

for (variable in variables){
    # Append the next predictor to the cumulative formula string.
    formula = paste(formula, variable, sep=' + ')
    # Gamma family with inverse link (presumably chosen for positive, skewed RTs).
    model = glmer(formula, data=df, family=Gamma(link='inverse'))
    AICs = c(AICs, AIC(model))
    BICs = c(BICs, BIC(model))
}

In [None]:
## Build dataframe: one row per (nested model, fit metric). `variables`
## repeats the six model labels once per metric; `metrics` labels the first
## six rows AIC and the next six BIC.
variables = np.tile(['Interference', 'DBS', 'Valence', 'Arousal', 'Trial #',
                     r'Arousal $\cdot$ Valence'], 2)
metrics = np.repeat(['AIC', 'BIC'], 6)
fits = DataFrame({'Fit': np.concatenate([AICs, BICs]),
                  'Model': variables,
                  'Metric': metrics})

## Signed log transform: keep the sign, compress the magnitude.
fits['Fit'] = np.sign(fits['Fit']) * np.log(np.abs(fits['Fit']))

In [None]:
## Plotting
fig, ax = plt.subplots(1,1,figsize=(10,5))
sns.set(style="white", font_scale=1.75)
## One line per metric across the nested models. `kind` is a catplot/factorplot
## argument, not a pointplot one; newer seaborn forwards stray kwargs to
## Line2D and errors on it. The hue legend is drawn by default, so `legend`
## is not needed either.
g = sns.pointplot(x='Model', y='Fit', hue='Metric', data=fits, 
                  palette='colorblind', ax=ax)

## Flourishes
ax.legend_.set_title(None)
ax.set(xlabel='', yticks=[-8.52, -8.50, -8.48])
ax.set_xticklabels(fits.Model.unique(), ha='left', rotation=-15, fontsize=18)
ax.set_ylabel('Model Deviance (Log Scale)')

sns.despine()
plt.tight_layout()
plt.savefig('plots/manuscript/supplementary/aic.png', dpi=300)
plt.savefig('plots/manuscript/supplementary/aic.svg', dpi=300)

### Figure S2: EEfRT Behavior

In [None]:
import numpy as np 
import pylab as plt
import seaborn as sns
from pandas import read_csv
from scipy.stats import gamma, mannwhitneyu
## Figure styling for the EEfRT panels.
sns.set_style("white")
sns.set_context('notebook', font_scale=2.5)
%matplotlib inline

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define useful functions.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

def gamma_censor(arr, threshold = 0.005):
    '''Censor outliers under a fitted gamma distribution (two-tailed).

    A gamma distribution with location fixed at 0 is fit to the finite values
    of `arr`. Values whose CDF is below `threshold` or above `1 - threshold`
    are replaced with NaN. Non-finite inputs (e.g. RTs already removed as NaN
    upstream) are excluded from the fit and stay NaN in the output.

    Parameters
    ----------
    arr : array-like of float
        Values to censor (e.g. one subject's response times).
    threshold : float
        Tail probability censored on each side.

    Returns
    -------
    ndarray of float
        Same shape as `arr`, with censored values set to NaN.
    '''
    arr = np.asarray(arr, dtype=float)
    finite = np.isfinite(arr)

    ## Nothing to fit (empty or all-NaN): return unchanged.
    if not finite.any():
        return arr.copy()

    ## Estimate fit on finite values only: NaNs either make SciPy raise or
    ## yield NaN parameters, which would silently disable censoring.
    p = gamma.fit(arr[finite], floc=0)

    ## CDF of each value under the fit (NaN inputs give NaN, which fails both
    ## comparisons below and so stays NaN through np.where).
    likelihood = gamma.cdf(arr, *p)

    ## Censor both tails and return.
    censor = (likelihood < threshold) | (likelihood > 1 - threshold)
    return np.where(censor, np.nan, arr)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Load and prepare data.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Load data.
df = read_csv('behavior/clean_eefrt_behavior.csv')

## Remove outlier RTs (removed values become NaN).
df.ChoiceRT = np.where(df.ChoiceRT > 0.3, df.ChoiceRT, np.nan)       # Remove RTs faster than 300 ms.
df.ChoiceRT = df.groupby('Subject').ChoiceRT.transform(gamma_censor) # Gamma censoring.

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Plotting.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

fig, axes = plt.subplots(1,2,figsize=(15,6))
palette = ['#0571b0','#ca0020']
yvars = ['ChoiceRT', 'ButtonPressRate']
ylabels = ['Response Time (s)','Button Press Rate']
titles = ['EEfRT Choice', 'EEfRT Bar Fill']

for ax, y, ylabel, title in zip(axes, yvars, ylabels, titles):

    ## Mann-Whitney U (rank-sum) test, DBS OFF vs. ON, on pooled trials.
    ## NaNs (censored RTs) are dropped first; left in, they corrupt the ranks
    ## or turn the result into NaN, depending on the SciPy version. The
    ## two-sided alternative is stated explicitly because older SciPy releases
    ## defaulted to a one-sided p-value.
    off = df.loc[df.DBS==0, y].dropna()
    on = df.loc[df.DBS==1, y].dropna()
    U, p = mannwhitneyu(off, on, alternative='two-sided')

    ## Plot (keyword arguments: positional x/y are rejected by newer seaborn).
    sns.barplot(x='DBS', y=y, data=df, palette=palette, ax=ax)
    
    ## Annotate the test result with a bracket spanning both bars.
    _, y2 = ax.get_ylim()
    ax.hlines(y2, 0, 1)
    ax.vlines(0, y2 * 0.975, y2)
    ax.vlines(1, y2 * 0.975, y2)
    ax.text(0.5, y2*1.025, 'U = %0.1f, p = %0.3f' %(U,p), ha='center', fontsize=24)
    ax.set(xlabel='', xticklabels=['OFF','ON'], ylim=(0, y2*1.15), 
           ylabel=ylabel, title=title)

sns.despine()
plt.subplots_adjust(top=0.9, bottom=0.1, left=0.08, right=0.99, wspace=.2)
plt.savefig('plots/manuscript/supplementary/eefrt.png', dpi=180)
plt.savefig('plots/manuscript/supplementary/eefrt.svg', dpi=180)
plt.show()

### Figure S4: FCZ ERP

In [None]:
import os
import numpy as np
import pylab as plt
import seaborn as sns
from pandas import read_csv
sns.set_style("white")
sns.set_context('notebook', font_scale=2.5)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## File parameters.
model_name = 'revised'
space = 'sensor'
label = 'FCZ'
freq = 15

## Plotting parameters.
contrasts = ['Interference','DBS']
palettes = [ ['#7b3294','#008837'], ['#0571b0','#ca0020'] ]
annotations = [ ['Control', 'Interference'], ['DBS OFF','DBS ON'] ]
ylimits = {'stim':(-5,1), 'resp':(-2.5,2.5)}

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Initialize figure. Rows (n): contrast; columns (m): stimulus- / response-locked.
fig, axes = plt.subplots(2,2,figsize=(12,9))
info = read_csv(os.path.join(space, 'afMSIT_%s_info.csv' %space))

for n, contrast, colors, legends in zip(range(2), contrasts, palettes, annotations):
    
    for m, analysis in enumerate(['stim', 'resp']):

        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        ### Load data.
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

        ## Load sensor data (first axis = trials, indexed by row of `info`).
        npz = np.load(os.path.join(space, 'afMSIT_%s_%s_%s_%s.npz' %(space,analysis,label,freq)))
        data = npz['data']
        times = npz['times']

        ## Load cluster results.
        f = os.path.join(space, 'results', '%s_%s_timedomain_results.csv' %(model_name, analysis))
        clusters = read_csv(f)
        
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        ### Plotting.
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

        y1, y2 = ylimits[analysis]
        
        ## Trial mean +/- standard error for each level of the contrast.
        for i, color, legend in zip(range(2),colors,legends):

            ix, = np.where(info[contrast]==i)
            mu = data[ix].mean(axis=0)
            se = data[ix].std(axis=0) / np.sqrt(len(ix))
            axes[n,m].plot(times, mu, linewidth=3, color=color, label=legend)
            axes[n,m].fill_between(times, mu-se, mu+se, color=color, alpha=0.2)

        ## Shade significant clusters (FDR < 0.05). np.where returns positions,
        ## so rows are looked up positionally; `num` must be an integer.
        for ix in np.where((clusters.Label==label)&(clusters.Freq==freq)&
                           (clusters.Contrast==contrast)&(clusters.FDR<0.05))[0]:
            tmin, tmax = clusters.Tmin.iloc[ix], clusters.Tmax.iloc[ix]
            axes[n,m].fill_between(np.linspace(tmin,tmax,1000), y1, y2, color='k', alpha=0.2)    
    
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Add flourishes.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

for n in range(2):
    
    for m, analysis in enumerate(['stim', 'resp']):
        
        ## Y-limits for this column, taken from the single `ylimits` table
        ## (the previous code read a stale loop variable here).
        y1, y2 = ylimits[analysis]
        
        ## Stimulus-locked edits.
        if not m:
            
            ## Fix axes. Ticks sit relative to IAPS onset (t = 0) and are
            ## relabeled relative to MSIT onset (0.4 s). Rounding avoids
            ## float noise such as 1.4 - 0.4 -> 0.9999999999999999.
            xticks = np.array([0.0, 0.4, 0.9, 1.4])
            axes[n,m].set(xticks=xticks, xticklabels=np.round(xticks - 0.4, 1), 
                          xlim=(-0.25,1.5), ylim=(y1,y2), 
                          ylabel = r'FCz Voltage ($\mu$V)')

            
            ## Add event markers.
            for x,s in zip([0, 0.4, 1.127],['IAPS','MSIT','Resp']): 
                axes[n,m].text(x+0.02,y1+np.abs(y1*0.05),s,fontsize=22)
                axes[n,m].vlines(x,y1,y2,linestyle='--',alpha=0.3)
                
        ## Response-locked edits.
        else:
            
            ## Fix axes.
            xticks = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])
            axes[n,m].set(xticks=xticks, xlim=(-1.0, 1.0), ylim=(y1, y2))
                        
            ## Add markers
            axes[n,m].text(0.02,y1+np.abs(y1*0.05),'Resp',fontsize=22)
            axes[n,m].vlines(0.0,y1,y2,linestyle='--',alpha=0.3)
        
            ## Add legends above plot.
            axes[n,m].legend(loc=1, handlelength=1.2, handletextpad=0.5, 
                             labelspacing=0.1, borderpad=0)
        
        ## Add x-labels to the bottom row.
        if n: axes[n,m].set_xlabel('Time (s)')
            

sns.despine()
plt.subplots_adjust(top=0.97, left = 0.08, right = 0.98, 
                    bottom=0.1, hspace=0.35, wspace=0.1)
plt.savefig('plots/manuscript/supplementary/fcz.png')
plt.savefig('plots/manuscript/supplementary/fcz.svg')
plt.show()
plt.close()

### Figure S5: DBS-Related Power Changes by ROI

In [None]:
import os
import numpy as np
import pylab as plt
import seaborn as sns
from pandas import DataFrame, read_csv
sns.set_style("white")
sns.set_context('notebook', font_scale=2.5)
%matplotlib inline

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## I/O parameters
space = 'source'
model_name = 'revised'
contrast = 'DBS'

## Label parameters.
labels = ['racc-lh', 'dacc-lh', 'pcc-lh', 'dmpfc-lh', 'dlpfc_1-lh', 'dlpfc_2-lh', 
          'dlpfc_3-lh', 'dlpfc_4-lh', 'dlpfc_5-lh', 'dlpfc_6-lh', 
          'racc-rh', 'dacc-rh', 'pcc-rh', 'dmpfc-rh', 'dlpfc_1-rh', 'dlpfc_2-rh', 
          'dlpfc_3-rh', 'dlpfc_4-rh', 'dlpfc_5-rh', 'dlpfc_6-rh']
xlabels = ['rACC', 'dACC', 'mCC', 'SFG', 'pMFG 1', 'pMFG 2',
           'aMFG 1', 'aMFG 2', 'aIFG', 'pIFG'] * 2

## Define averaging parameters.
baseline = (-0.5, -0.1)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Preallocate space.
analyses = []
freqs = []
rois = []
legend = []
values = []

for analysis in ['stim','resp']:
        
    for label, xlabel in zip(labels,xlabels):
    
        for freq, ffreq in zip(['theta','alpha','beta'],
                               [r'$\theta$',r'$\alpha$',r'$\beta$']):
    
            #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
            ### Load data.
            #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

            ## Load trial information
            info = read_csv(os.path.join(space, 'afMSIT_%s_info.csv' %space))

            ## Load source data.
            npz = np.load(os.path.join(space, 'afMSIT_%s_%s_%s_%s.npz' %(space,analysis,label,freq)))
            data = npz['data']
            times = npz['times']

            #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
            ### Compute differences.
            #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

            ## Define mask.
            if analysis == 'stim': tmin, tmax = 0.4, 0.8
            elif analysis == 'resp': tmin, tmax = -0.2, 0.2

            ## Compute averages within window.
            delta = []
            for i in range(2):

                ## Identify DBS on/off trials.
                ix, = np.where(info.DBS==i)

                ## Compute average time course.
                mu = data[ix].mean(axis=0)

                ## Reduce to time of interest.
                mu = mu[(times >= tmin)&(times <= tmax)]
                delta.append(mu)

            ## Compute difference.
            delta = np.diff(delta, axis=0).squeeze()
            
            ## Append information.
            analyses += [analysis] * len(delta)
            freqs += [ffreq] * len(delta)
            rois += [xlabel] * len(delta)
            legend += [label] * len(delta)
            values += delta.tolist()

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Plotting.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Conver to DataFrame.
df = DataFrame([analyses, freqs, rois, legend, values], 
               index=('Analysis','Freq','ROI','Label','Delta')).T
df['Hemi'] = [s.split('-')[-1] for s in df.Label]

## Plot.
fig, axes = plt.subplots(4,1,figsize=(15,12))
analyses = ['stim','stim','resp','resp']
hemis = ['lh','rh','lh','rh']
ylabels = ['Left Hemisphere', 'Right Hemisphere', 'Left Hemisphere', 'Right Hemisphere']
titles = ['Stimulus-Locked',False,'Response-Locked',False]

for ax, analysis, hemi, ylabel, title in zip(axes, analyses, hemis, ylabels, titles):
    
    ## Plot.
    ix = np.logical_and(df.Analysis==analysis,df.Hemi==hemi)
    sns.barplot('ROI', 'Delta', 'Freq', df[ix], palette=sns.color_palette(n_colors=3), ax=ax)

    ## Add flouishes.
    ax.hlines(0,*ax.get_xlim())
    ax.set(xlabel = '', yticks=[0,0.5,1])
    ax.set_xticklabels(df.ROI.unique(), fontsize=20, rotation=-15)
    ax.set_ylabel(ylabel, fontsize=20)
    ax.legend(loc=1, bbox_to_anchor=(1.125,0.9), labelspacing=0, borderpad=0, 
              handletextpad=0.25)
    ax.legend_.set_title('Power (On - Off)', prop = {'size':'x-large'})
    if title: ax.set_title(title)

## Draw asterisks.
for ax, analysis, hemi in zip(axes, analyses, hemis):
    
    ## Load significant clusters.
    info = read_csv('source/results/revised_%s_frequency_results.csv' %analysis)
    info = info[np.logical_and(info.Contrast=='DBS', info.FDR<0.05)]
    info = info[[True if label.endswith(hemi) else False for label in info.Label]]
    info = info[['Label','Freq']].drop_duplicates()
    
    ## Iteratively draw asterisks.
    for _, row in info.iterrows():
        
        y = df.loc[(df.Analysis==analysis)&(df.Label==row.Label)&
                   (df.Freq==r'$\%s$' %row.Freq),'Delta'].mean()
        x1 = np.argmax(np.in1d(df.loc[df.Hemi==hemi,'Label'].unique(), row.Label))
        x2 = np.argmax(np.in1d(['theta','alpha','beta'], row.Freq))
        ax.annotate('*', xy=(0,0), xytext=(-0.32 + x1 + 0.28 * x2, y+0.05),
                              xycoords='data', fontsize=24)
        
sns.despine()
plt.subplots_adjust(left=0.075, right=0.9, top=0.95, bottom=0.06, hspace=0.5)
plt.savefig('plots/manuscript/supplementary/S5.png', dpi=180)
plt.savefig('plots/manuscript/supplementary/S5.svg', dpi=180)
plt.show()

### Figure S6: Out of Task Power
#### PSD of Eyes-Open Resting State

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from mne import Epochs, make_fixed_length_events, pick_channels, read_proj, set_log_level
from mne.io import Raw
from mne.time_frequency import psd_multitaper
set_log_level(verbose=False)
sns.set_style("white")
sns.set_context('notebook', font_scale=2.5)
%matplotlib inline

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define Parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## I/O parameters. Condition order (OFF, ON) fixes the first axis of PSD below.
subjects = ['BRTU', 'CHDR', 'JADE', 'S2']
conds = ['resting_dbsoff_eo', 'resting_dbson_eo']

## Filtering parameters.
## NOTE(review): n_jobs is defined but not passed to raw.filter below.
l_freq = 0.5
h_freq = 50
l_trans_bandwidth = l_freq / 2.
h_trans_bandwidth = 1.0
filter_length = '20s'
n_jobs = 3

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

root_dir = '../resting/raw/'

## PSD[subject][cond] = multitaper PSD of the FZ channel for each epoch.
PSD = []
for subject in subjects:
    
    s = []
    for cond in conds:
        
        ## Load data.
        raw = Raw('%s/%s_%s_raw.fif' %(root_dir, subject, cond), preload=True)

        ## Apply projection: a saved projector if one exists for this
        ## recording, otherwise an EEG reference; then apply it.
        proj = '%s/%s_%s-proj.fif' %(root_dir, subject, cond)
        if os.path.isfile(proj): raw.add_proj(read_proj(proj))
        else: raw.set_eeg_reference()
        raw = raw.apply_proj()
        
        ## Band-pass filter raw (l_freq - h_freq Hz).
        raw = raw.filter(l_freq, h_freq, filter_length=filter_length, l_trans_bandwidth=l_trans_bandwidth,
                        h_trans_bandwidth=h_trans_bandwidth)
        
        ## Create epochs: events every 1 s between 1 s and 61 s, each epoch
        ## spanning -0.5 to 1 s, so consecutive epochs overlap by 0.5 s.
        events = make_fixed_length_events(raw, 1, start=1, stop=61, duration=1)
        epochs = Epochs(raw, events, tmin=-0.5, tmax=1, baseline=(-0.5,0))

        
        ## Compute PSD (l_freq - 30 Hz) on channel FZ only.
        picks = pick_channels(raw.ch_names, ['FZ'])
        psd, freqs = psd_multitaper(epochs, fmin=l_freq, fmax=30, picks=picks)
        s.append(psd)
        
    PSD.append(s)
    
## Merge into one array: (cond, subject x epoch, freq). np.array requires every
## recording to yield the same number of epochs. NOTE(review): squeeze()
## removes the single-channel axis but would also drop any other length-1 axis.
PSD = np.array(PSD).squeeze().swapaxes(0,1)
n_cond, n_subj, n_trial, n_freq = PSD.shape
PSD = PSD.reshape(n_cond,n_subj*n_trial,n_freq)

## Plot.
fig, ax = plt.subplots(1,1,figsize=(8,4))

for n, color, label in zip(range(2),['#0571b0','#ca0020'],['DBS OFF','DBS ON']):
    
    ## Median spectrum across epochs, normalized to sum to 1.
    mu = np.median(PSD[n], axis=0)
    mu /= mu.sum()
    
    ax.plot(freqs, mu, lw=3, color=color, label=label)
    
## Flourishes. Dashed lines mark 4 and 8 Hz (the band tested in the next cell);
## the first tick sits at 0.5 Hz but is labeled 0.
ax.vlines([4,8], 0, 0.15, linestyle='--', alpha=0.5)
ax.set(xlim=(0.5,30), xticks=(0.5,10,20,30), xticklabels=(0,10,20,30), xlabel='Frequency (Hz)',
       ylim=(0,0.12), ylabel='Normalized PSD');
ax.legend(loc=0, borderpad=0, labelspacing=0)

sns.despine()
plt.tight_layout()
plt.savefig('plots/manuscript/supplementary/S6b.png')
plt.savefig('plots/manuscript/supplementary/S6b.svg')

In [None]:
from scipy.stats import mannwhitneyu

## Boolean mask over PSD frequency bins covering 4-8 Hz (inclusive).
mask = (freqs >= 4) & (freqs <= 8)

## Mean band power per epoch, log10-scaled; one row per condition.
psd = np.log10(PSD[..., mask].mean(axis=-1))

## Rank-sum comparison across conditions (cell output).
mannwhitneyu(*psd)

#### Power Timecourses of all bands in FZ

In [None]:
import os
import numpy as np
import pylab as plt
import seaborn as sns
from pandas import DataFrame, read_csv
sns.set_style("white")
sns.set_context('notebook', font_scale=2)
%matplotlib inline

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## File parameters.
model_name = 'revised'
space = 'sensor'
label = 'FZ'

baseline = (-0.5, -0.1)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Load trial information.
info = read_csv('%s/afMSIT_sensor_info.csv' %space)
cdict = dict(DBS = ['#0571b0','#ca0020'], Interference = ['#7b3294','#008837'])
ldict = dict(DBS = ['DBS OFF','DBS ON'], Interference = ['Control', 'Interference'])

## Initialize figure.
fig = plt.figure(figsize=(18,12))

for i, analysis in enumerate(['stim','resp']):
    
    results = read_csv('sensor/results/revised_%s_frequency_results.csv' %analysis)
    
    for j, freq in enumerate(['theta','alpha','beta']):
        
        ## Load data.
        npz = np.load('%s/afMSIT_%s_%s_%s_%s.npz' %(space,space,analysis,label,freq))
        data = npz['data']
        times = npz['times']
        
        for k, contrast in enumerate(['DBS', 'Interference']):
            
            ## Initialize canvas.
            ax = plt.subplot2grid((4,3),(k+i*2,j))
            
            ## Plot power.
            for n, color, legend in zip(range(2),cdict[contrast],ldict[contrast]):

                ## Identify indices per condition of contrast.
                ix, = np.where(info[contrast]==n)
                
                ## Compute mean and standard error.
                mu = data[ix].mean(axis=0)
                
                ## If stimulus-locked, baseline subtract.
                if analysis == 'stim': mu -= mu[(times >= baseline[0])&(times <= baseline[1])].mean()
                
                se = data[ix].std(axis=0) / np.sqrt(len(ix))
                
                ## Plot.
                ax.plot(times, mu, linewidth=3, color=color, label=legend)
                ax.fill_between(times, mu-se, mu+se, color=color, alpha=0.2)
            
                ## Plot significant clusters.
                for _, row in results.loc[(results.Contrast==contrast)&(results.Label==label)&
                                       (results.Freq==freq)&(results.FDR<0.05),
                                       ('Tmin','Tmax')].iterrows():
                    ax.fill_between(np.linspace(row.Tmin,row.Tmax,1e3), -10, 10, color='k', alpha=0.1)    
            
            
            ## Clean-up.
            if analysis == 'stim':

                ## Fix timing.
                ax.set(xlim=(-0.25,1.5), xticks=[0.0, 0.4, 0.9, 1.4], 
                       xticklabels=[-0.4, 0.0, 0.5, 1.0], ylim=(-1.5,2), yticks=[-1, 0, 1, 2])

                ## Add time markers.
                ax.vlines([0, 0.4, 1.127],*ax.get_ylim(),linestyle='--',alpha=0.3)
                ax.hlines(0, *ax.get_xlim(), linestyle='--',alpha=0.3)
                
            elif analysis == 'resp':

                ## Fix timing
                ax.set(xlim=(-1.0,1.0), xticks=np.arange(-1,1.1,0.5), ylim=(-2.5,1.5), yticks=[-2,-1,0,1])
                
                ## Add time markers.
                ax.vlines(0.0,*ax.get_ylim(),linestyle='--',alpha=0.3)
                ax.hlines(0, *ax.get_xlim(), linestyle='--',alpha=0.3)
                
            ## Special cases.
            if j: ax.set(yticklabels=[])
            if j == 2 and not k: ax.legend(loc=7, bbox_to_anchor=(1.53,0.5), labelspacing=0, handlelength=1.5)
            if j == 2 and k: ax.legend(loc=7, bbox_to_anchor=(1.6,0.5), labelspacing=0, handlelength=1.5)
            if not k: ax.set(xticklabels=[])
            if not i and not k: ax.set_title(r'$\%s$-power' %freq, fontsize=36)

## Additional annotations.
ax.annotate('Stimulus-Locked', xy=(0,0), xytext=(0.01, 0.75), xycoords='figure fraction',
            rotation=90, fontsize=36, va='center')
ax.annotate('Response-Locked', xy=(0,0), xytext=(0.01, 0.25), xycoords='figure fraction',
            rotation=90, fontsize=36, va='center')

sns.despine()
plt.subplots_adjust(left=0.07, right=0.85, top=0.95, bottom=0.05, hspace=0.175, wspace=0.15)
plt.savefig('plots/manuscript/supplementary/S6a.png')
plt.savefig('plots/manuscript/supplementary/S6a.svg')

### Figure S7: Significant Alpha/Beta clusters

In [None]:
import os
import numpy as np
import pylab as plt
import seaborn as sns
from pandas import read_csv
sns.set_style("white")
sns.set_context('notebook', font_scale=2.5)
%matplotlib inline

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Load and prepare data.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Load data.
f = 'source/results/revised_stim_frequency_results.csv'
df = read_csv(f)

## Limit data.
df = df[df.FDR<0.05]

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Defining plotting info.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Define label information.
labels = ['racc-lh', 'dacc-lh', 'pcc-lh', 'dmpfc-lh', 'dlpfc_1-lh', 'dlpfc_2-lh', 
          'dlpfc_3-lh', 'dlpfc_4-lh', 'dlpfc_5-lh', 'dlpfc_6-lh', 
          'racc-rh', 'dacc-rh', 'pcc-rh', 'dmpfc-rh', 'dlpfc_1-rh', 'dlpfc_2-rh', 
          'dlpfc_3-rh', 'dlpfc_4-rh', 'dlpfc_5-rh', 'dlpfc_6-rh']
rois = ['rACC', 'dACC', 'mCC', 'SFG', 'pMFG 1', 'pMFG 2',
        'aMFG 1', 'aMFG 2', 'aIFG', 'pIFG'] * 2

## Define plotting features.
conds = ['DBS','Interference']
colors = ['#ca0020','#008837']

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Initialize figure.
fig, axes = plt.subplots(1,2,figsize=(15,12),sharex=True, sharey=True)

for ax, freq in zip(axes,['alpha','beta']):

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ### Plotting Clusters.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    
    ## Reduce DataFrame to frequency of interest.
    copy = df[df.Freq==freq].copy()
    
    ## Plot timings.
    for x,s in zip([0, 0.4, 1.127],['IAPS','MSIT','Resp']): 
        ax.text(x+0.01,-0.6,s,fontsize=20)
        ax.vlines(x, -1, len(labels), linewidth=2.5, linestyle='--',alpha=0.2)  

    ## Plot clusters.
    for n, label in enumerate(labels[::-1]):

        for m, contrast in enumerate(conds):

            ## Extract clusters.
            ix = np.logical_and(copy.Contrast==contrast, copy.Label==label)
            clusters = copy.loc[ix,['Tmin','Tmax']]
            if not len(clusters): continue

            ## Plot clusters.
            y = n + m * 0.5
            for cluster in clusters.as_matrix(): 
                ax.hlines(y+0.25, cluster.min(), cluster.max(), color=colors[m], lw=12)
                
    ## Fix x-axis.
    xticks = np.array([0.0, 0.4, 0.9, 1.4])
    ax.set(xticks=xticks, xticklabels=xticks-0.4, xlim=(-0.25,1.5), xlabel='Time (s)')
                
    ## Set title
    ax.set_title(r'$\%s$-Power' %freq)
                
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Add flourishes.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Add legend.
for label, color in zip(conds,colors): 
    axes[1].plot([],[],lw=10,color=color,label=label,alpha=0.7)
axes[1].legend(loc=7, bbox_to_anchor=(1.5,0.5), handlelength=1.25, borderaxespad=0)

## Fix y-axis.
ax.set(yticks=np.arange(len(rois))+0.5, yticklabels=rois[::-1], ylim=(-0.7,len(rois)))

## Add dendrograms.
def dendrogram(ax, x, y1, y2, text):
    '''Draw a labeled bracket in the left margin of `ax`.

    All coordinates are axes fractions. A vertical bar is drawn at `x` from
    `y1` to `y2`; horizontal ticks run from just beyond it (x * 1.02) to the
    axis edge at both ends. `text` is written rotated 90 degrees at x * 1.4,
    vertically centered between y1 and y2.
    '''
    def line_props():
        ## A fresh dict per annotation, exactly like an inline literal.
        return dict(arrowstyle='-', color='k', linewidth=2.0, alpha=0.2)

    ## (end point, start point) pairs: the spine, then the lower and upper ticks.
    segments = [((x, y1), (x, y2)),
                ((x*1.02, y1), (-1e-3, y1)),
                ((x*1.02, y2), (-1e-3, y2))]
    for end, start in segments:
        ax.annotate('', end, xycoords='axes fraction', xytext=start,
                    arrowprops=line_props())

    ## Bracket label.
    midpoint = np.mean([y1, y2])
    ax.annotate(text, (0,0), xycoords='axes fraction', xytext=(x*1.4, midpoint),
                rotation=90, fontsize=30, va='center')

## Hemisphere brackets on the first panel. The y-axis lists labels reversed
## (rois[::-1]), so right-hemisphere ROIs occupy the lower half of the axis.
dendrogram(axes[0], -0.3, 0.025, 0.51, 'Right Hemisphere')
dendrogram(axes[0], -0.3, 0.515, 1, 'Left Hemisphere')

sns.despine()
plt.subplots_adjust(left=0.15, right=0.85, top=0.95, bottom=0.1, wspace=0.225)
plt.savefig('plots/manuscript/supplementary/S7.png', dpi=180)
plt.savefig('plots/manuscript/supplementary/S7.svg', dpi=180)
plt.show()

### Figure S8: Correlation of DBS-Related Theta Power and RT Changes

In [None]:
import os
import numpy as np
import pylab as plt
import seaborn as sns
from pandas import DataFrame, concat, read_csv
sns.set_style("white")
sns.set_context('notebook', font_scale=2.5)
%matplotlib inline

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## I/O parameters
space = 'source'
model_name = 'revised'
contrast = 'DBS'
freq = 'theta'

## Label parameters.
labels = ['racc-lh', 'dacc-lh', 'pcc-lh', 'dmpfc-lh', 'dlpfc_1-lh', 'dlpfc_2-lh', 
          'dlpfc_3-lh', 'dlpfc_4-lh', 'dlpfc_5-lh', 'dlpfc_6-lh']
xlabels = ['rACC', 'dACC', 'mCC', 'SFG', 'pMFG 1', 'pMFG 2', 'aMFG 1', 'aMFG 2', 'aIFG', 'pIFG']

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        
## Load trial information
info = read_csv(os.path.join(space, 'afMSIT_%s_info.csv' %space))
n_subj, = info.Subject.unique().shape
n_cond, = info.DBS.unique().shape
  
corr = []
for analysis in ['stim','resp']:
    
    ## Define mask.
    if analysis == 'stim': tmin, tmax = 0.4, 0.8
    else: tmin, tmax = -0.2, 0.2
    
    df = []
    for label, xlabel in zip(labels,xlabels):

        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        ### Load data.
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

        ## Load source data.
        npz = np.load(os.path.join(space, 'afMSIT_%s_%s_%s_%s.npz' %(space,analysis,label,freq)))
        data = npz['data']
        times = npz['times']

        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        ### Compute differences.
        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

        ## Preallocate space.
        mu = np.zeros((n_subj,n_cond))

        for n, subject in enumerate(info.Subject.unique()):

            for m in [0,1]:

                ## Locate trials.
                ix = np.logical_and(info.Subject==subject, info.DBS==m)

                ## Compute average.
                mu[n,m] = data[ix][:,(times >= tmin)&(times <= tmax)].mean()

        ## Convert to DataFrame.
        mu = DataFrame(mu, columns=('DBS_off','DBS_on'), index=info.Subject.unique())

        ## Compute DBSon - DBSoff power differential.
        mu['DBS_diff'] = mu.DBS_on - mu.DBS_off

        ## Compute DBSon - DBSoff RT differential.
        mu['RT_off'] = info.groupby(['DBS','Subject']).RT.mean()[0]
        mu['RT_on'] = info.groupby(['DBS','Subject']).RT.mean()[1]
        mu['RT_diff'] = mu.RT_on - mu.RT_off

        ## Store label information and label.
        mu['Label'] = label
        mu['ROI'] = xlabel
        df.append(mu)

    ## Concatenate DataFrames.
    df = concat(df)

    ## Compute correlations.
    gb = df.groupby('ROI')[['DBS_diff','RT_diff']].corr().reset_index()
    gb = gb[gb.level_1=='DBS_diff'].drop(['level_1','DBS_diff'], 1)
    gb['Analysis'] = analysis
    corr.append(gb)
    
## Concatenate.
corr = concat(corr)
corr.Analysis = np.where(corr.Analysis=='stim','Stimulus-locked','Response-locked')

## Plot.
fig, ax = plt.subplots(1,1,figsize=(12,4))
sns.barplot(x='ROI', y='RT_diff', hue='Analysis', data=corr, 
            order=xlabels, ax=ax)
ax.hlines(0,-0.5,len(xlabels)+0.5)
ax.legend(loc=7, bbox_to_anchor=(1.3,0.5), fontsize=16, handletextpad=0.2)
ax.set(xlabel='', ylabel="Pearson's $r$", title='DBS x RT Correlation')
ax.set_xticklabels(xlabels, fontsize=16, rotation=-30)

sns.despine()
plt.subplots_adjust(left=0.12, right=0.8, top=0.85, bottom=0.17)
plt.savefig('plots/manuscript/supplementary/S8.png', dpi=180)
plt.savefig('plots/manuscript/supplementary/S8.svg', dpi=180)

### Figure S9: Lausanne Mapping
#### Plot Labels

In [None]:
import os
from surfer import Brain
from pandas import read_csv
%matplotlib qt4

## Initialize brain: split-hemisphere inflated surface, lateral and medial views.
brain = Brain('fscopy', 'split', 'inflated', views = ['lateral','medial'], 
              size = (1200,800), subjects_dir='../freesurfs')

## Load mapping info (one row per mapping entry, with 'EMOTE Label' and 'Color' columns).
mapping = read_csv('../freesurfs/fscopy/label/april2016/mapping.csv')

## Add each EMOTE label when its name differs from the previous row's.
## NOTE(review): this assumes rows are grouped by 'EMOTE Label'; a label that
## reappears non-consecutively would be added again. The color comes from the
## first row of each run.
emote_label = ''
for _, row in mapping.iterrows():
    
    ## Load EMOTE label, drawn as borders on the hemisphere given by its suffix.
    if not row['EMOTE Label'] == emote_label:
        emote_label = row['EMOTE Label']
        brain.add_label('../freesurfs/fscopy/label/april2016/%s.label' %emote_label, 
                        hemi='lh' if emote_label.endswith('lh') else 'rh',
                        borders = True, color = row['Color'])

#### Find number of vertices

In [None]:
import os
import numpy as np
from mne import read_label, read_source_spaces
from pandas import read_csv

## Load mapping info.
label_dir = '../freesurfs/fscopy/label/april2016'
mapping = read_csv('%s/mapping.csv' %label_dir)

## Load source space (index 0 = left hemisphere, 1 = right, per the lookup below).
src = read_source_spaces('../freesurfs/fscopy/bem/fscopy-oct-6p-src.fif', verbose=False)

## Locate label files in the directory.
labels = [f for f in os.listdir(label_dir) if f.endswith('label')]

## Count, for each label, how many of its vertices are source-space vertices.
mapping['Vertices'] = 0
for label in labels:
    
    ## Load label (rebinds `label` from the filename to the Label object).
    label = read_label('%s/%s' %(label_dir,label))
    
    ## Compute number of vertices in source space.
    n_vert = np.in1d(src[0 if label.hemi == 'lh' else 1]['vertno'], label.vertices).sum()
    
    ## Store in every mapping row whose 'EMOTE Label' equals the label name.
    mapping.loc[mapping['EMOTE Label']==label.name, 'Vertices'] = n_vert
    
## Overwrites mapping.csv in place with the added 'Vertices' column.
mapping.to_csv('%s/mapping.csv' %label_dir, index=False)