In [None]:
import matplotlib.patches as mpatches
import numpy as np
import os
import math
import matplotlib.pyplot as plt
%matplotlib inline  

In [None]:
# Notebook-wide Matplotlib styling applied to every figure below:
# 12-pt axis labels, legend text and tick labels; 10-pt default font;
# text rendered without an external LaTeX install; and a wide, short
# 7 x 2 inch canvas (instead of 4.5 x 4.5).
_label_pt = 12
params = dict([
    ('axes.labelsize', _label_pt),
    ('font.size', 10),
    ('legend.fontsize', _label_pt),
    ('xtick.labelsize', _label_pt),
    ('ytick.labelsize', _label_pt),
    ('text.usetex', False),
    ('figure.figsize', [7, 2]),
])
plt.rcParams.update(params)

## Helper Function

In [None]:
def autolabel(rects, ses, ax=None, fontsize=10):
    """Write each bar's height as a text label just above its error bar.

    Parameters
    ----------
    rects : iterable of bar patches (e.g. the container returned by ``bar``).
    ses : iterable of error-bar half-lengths, one per bar, in data units;
        the label is raised by ``1.03 * se`` so it clears the error cap.
    ax : Axes to draw on. Defaults to the Axes each bar belongs to.
    fontsize : label font size (default 10).

    Labels show three decimals with any leading zero stripped
    (0.512 -> '.512', 1.25 -> '1.250').
    """
    for rect, se in zip(rects, ses):
        height = rect.get_height()
        # Draw on the bar's own Axes. The previous ``plt.axes()`` call adds a
        # brand-new full-figure Axes in current Matplotlib, which covers the
        # chart, puts the labels in the wrong coordinates and changes the
        # "current" Axes for every later pyplot call.
        target = ax if ax is not None else rect.axes
        label = '{:2.3f}'.format(float(height)).lstrip('0')
        target.text(rect.get_x() + rect.get_width() / 2., height + 1.03 * se,
                    label, ha='center', va='bottom', fontsize=fontsize)

## Fixed Parameters

In [None]:
# Experiment selector (used to build the input/output paths) and its
# human-readable title for the figure.
exp = 'corr'
exp_label = 'Pearson Correlation Between Two Subject Groups'

# Per-dataset lookups keyed by dataset index 0-3.
# Chance level of each dataset, written as 2 / N with these N values.
_cl_denominators = (450, 297, 297, 1973)
cl_dict = {idx: 2 / n for idx, n in enumerate(_cl_denominators)}
nfeat = 75  # shown on the figure as "k = ..."

ds_dict = dict(enumerate(('GreenEyes', 'Milky', 'Vodka', 'Sherlock')))

# Three parallel lists, one entry per bar in each ROI group:
# model name, accuracy-key prefix inside the .npz, and legend label.
model_all = ['avg'] + ['indv_srm'] * 2
accu_type = ['sep_'] * 2 + ['tgr_']
lg_lb = ['MNI', 'Train Separately', 'Train Together']

# Path templates; placeholders are filled later with str.format.
input_file = '../../output/accu_bar/{}/{}/{}_ds{}.npz'  # exp, roi, model, ds
output_path = '../../output/figures/{}/'  # exp
# exist_ok=True makes this a no-op on re-runs and avoids the race between a
# separate existence check and the directory creation.
os.makedirs(output_path.format(exp), exist_ok=True)
output_file = output_path + 'ds{}'  # exp, ds

## Plotting Parameters

In [None]:
# ROIs to plot, in left-to-right order on the x-axis.
roi_all = 'dmn pt eac'.split()
# Upper y-limit of the accuracy axis.
max_accu = 1.2
# Dataset index into ds_dict / cl_dict (3 -> 'Sherlock').
ds = 3

## Aggregate Accuracy

In [None]:
# Collect, for every ROI, the mean and standard error of each model's accuracy.
# all_mean[r][m] / all_se[r][m] hold the value for roi_all[r] and the m-th
# (model_all, accu_type) pair.
all_mean = []  # one list per ROI
all_se = []    # one list per ROI
for roi in roi_all:
    roi_mean = []
    roi_se = []
    for model, ac_tp in zip(model_all, accu_type):
        # np.load on an .npz returns an NpzFile that keeps the archive open;
        # the context manager closes it once both arrays have been read.
        with np.load(input_file.format(exp, roi, model, ds)) as ws:
            roi_mean.append(ws[ac_tp + 'mean'])
            roi_se.append(ws[ac_tp + 'se'])
    all_mean.append(roi_mean)
    all_se.append(roi_se)

## Plot All ROIs

In [None]:
# Grouped bar chart: one group per ROI, one bar per (model, accuracy type),
# with the chance level as a dash-dot line.
width = 1
nmodel = len(model_all)  # bars per ROI group (previously hard-coded to 3)
ndata = len(roi_all)
group_width = 1.25 * width * nmodel  # extra 25% leaves a gap between groups
center_all = np.linspace(0, group_width * (ndata - 1), ndata)
# One color per bar position; same order as lg_lb.
color_all = ['lightgrey', 'dodgerblue', 'red']
assert len(color_all) >= nmodel and len(lg_lb) >= nmodel, \
    'need a color and a legend label for every model'
error_config = {'ecolor': '0', 'capsize': 3}  # black error bars with caps

# Explicit Axes, so every call below targets this chart even if another
# helper changes pyplot's "current" Axes.
fig, ax = plt.subplots()
xtick_idx = []
xtick_name = []
for roi, center, mean, se in zip(roi_all, center_all, all_mean, all_se):
    # Integer-indexed positions: a float np.arange(start, stop, step) can
    # produce an extra element through rounding.
    idx = center + width * np.arange(nmodel)
    rects = ax.bar(idx, mean, yerr=se, align='center',
                   error_kw=error_config, width=width - 0.1)
    for rect, color in zip(rects, color_all):
        rect.set_color(color)
    autolabel(rects, se)
    xtick_idx.append(center + (nmodel - 1) * width / 2)  # middle of the group
    xtick_name.append(roi.upper())
ax.set_xticks(xtick_idx)
ax.set_xticklabels(xtick_name)

ax.set_ylabel('Accuracy')
ax.set_xlabel('ROI')
left_lim = center_all[0] - 0.5 * width - 0.5
right_lim = center_all[-1] + (nmodel - 0.5) * width + 0.5
ax.set_xlim([left_lim, right_lim])
ax.set_ylim([0, max_accu])

# Chance accuracy for the selected dataset, across the full x-range.
cl = cl_dict[ds]
line = ax.plot([left_lim, right_lim], [cl, cl], 'k-.', linewidth=2)

# Title (in data coordinates, above the model legend) and two info boxes.
x_mid = (right_lim + left_lim) / 2
x_span = right_lim - left_lim
ax.text(x_mid, 1.25 * max_accu, exp_label, fontsize=12,
        horizontalalignment='center', verticalalignment='bottom')
box_props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
ax.text(right_lim - 0.02 * x_span, 0.92 * max_accu, 'k = ' + str(nfeat),
        bbox=dict(box_props), fontsize=12,
        horizontalalignment='right', verticalalignment='top')
ax.text(left_lim + 0.02 * x_span, 0.92 * max_accu, 'Dataset: ' + ds_dict[ds],
        bbox=dict(box_props), fontsize=12,
        horizontalalignment='left', verticalalignment='top')

# Legends: model colors in a strip above the axes, chance line inside.
legend_handle = [mpatches.Patch(color=color, label=label)
                 for color, label in zip(color_all[:nmodel], lg_lb[:nmodel])]
l1 = ax.legend(handles=legend_handle, bbox_to_anchor=(0., 1.02, 1., .102),
               loc=3, ncol=nmodel, mode="expand", borderaxespad=0.)
# A second legend() call replaces the first, so re-attach it as an artist.
ax.add_artist(l1)
ax.legend(line, ['chance'], loc=9, ncol=1)

# Add 'eps' to also write an EPS copy.
save_formats = ('pdf',)
for fmt in save_formats:
    fig.savefig(output_file.format(exp, ds) + '.' + fmt, format=fmt,
                dpi=200, bbox_inches='tight')