In [None]:
import matplotlib.patches as mpatches
import numpy as np
import os
import math
import matplotlib.pyplot as plt
%matplotlib inline  

In [None]:
# Matplotlib defaults shared by every figure in this notebook: small tick
# labels on x (many model ticks), larger on y, and a wide, short canvas.
params = {}
params['axes.labelsize'] = 12
params['font.size'] = 10
params['legend.fontsize'] = 10
params['xtick.labelsize'] = 8
params['ytick.labelsize'] = 12
params['text.usetex'] = False
params['figure.figsize'] = [6, 2]  # wide, short panels (instead of 4.5 x 4.5)
plt.rcParams.update(params)

## Helper function

In [None]:
def autolabel(rects, ses, ax=None):
    """Write each bar's height as a text label just above its error bar.

    Parameters
    ----------
    rects : iterable of bar patches (e.g. the container returned by ``plt.bar``)
        Each must provide ``get_x``, ``get_width`` and ``get_height``.
    ses : iterable of float
        Error-bar half-length for each bar, in the same order as ``rects``.
        The label is placed ``1.03 * se`` above the bar top so it clears the cap.
    ax : matplotlib Axes, optional
        Axes to draw the labels on. Defaults to the Axes each bar belongs to
        (``rect.axes``). This avoids ``plt.axes()``, which in recent
        matplotlib creates a new Axes instead of returning the current one.
    """
    for rect, se in zip(rects, ses):
        height = rect.get_height()
        target = ax if ax is not None else rect.axes
        # three decimals without the leading zero, e.g. 0.612 -> '.612'
        label = '{:2.3f}'.format(float(height)).lstrip('0')
        target.text(rect.get_x() + rect.get_width() / 2., height + 1.03 * se, label,
                    ha='center', va='bottom', fontsize=9.5)

## Fixed Parameters

In [None]:
# experiment label used in every figure title
exp_label = 'fMRI to Text Mapping'

# dictionaries
# chance-level accuracy per dataset index, keyed by accuracy type
cl_dict = {'class': {0: 1/7, 1: 1/4, 2: 1/4, 3: 1/25},
           'rank': {0: 1/2, 1: 1/2, 2: 1/2, 3: 1/2}}
# per-model value shown as 'k' under each model's tick label; list position = roi_dict index
feat_dict = {'multi_srm': [75, 150, 75], 'multi_dict': [75, 50, 50], 'indv_srm': [75, 50, 75],
             'indv_ica': [50, 50, 25], 'indv_gica': [50, 50, 50], 'indv_dict': [50, 25, 50],
             'avg': [50, 50, 50]}
roi_dict = {'dmn': 0, 'pt': 1, 'eac': 2}
# display names per model key. 'indv_ica' matches the feat_dict key; 'ica' is kept
# as an alias so any existing lookups by that key still work.
md_dict = {'avg': 'MNI', 'multi_srm': 'MDMS', 'ica': 'ICA', 'indv_ica': 'ICA',
           'indv_gica': 'GICA', 'indv_dict': 'DL', 'indv_srm': 'SRM', 'multi_dict': 'MDDL'}
ds_dict = {0: 'GreenEyes', 1: 'Milky', 2: 'Vodka', 3: 'Sherlock'}
actp_dict = {'class': 'Classification Accuracy', 'rank': 'Ranking Accuracy'}
# legend labels for hatched vs. plain bars
lg_dc = ['Prm.', 'Prm.+Sec.']
# lg_dc = ['Train separately','Train jointly']

# path templates
input_file = '../../output/accu_bar/{}/{}/{}/{}_ds{}.npz'  # exp, total_ds, roi, model, ds
output_path = '../../output/figures/{}/'  # exp
output_file = output_path + '{}_{}_ds{}'  # exp (from output_path), roi, accu_type, ds

## Plotting Parameters

In [None]:
exp = 'mapping'
total_ds = list(range(12))
# y-axis upper limits: max_accu_all[accuracy-type index][ROI index]
max_accu_all = [[0.75, 0.61, 0.54], [1.5, 1.5, 1.5]]  # accuracy for classification and ranking plots
datasets = [0, 3]
# NOTE(review): this alternative is flat, but the plotting cell indexes max_accu[m]
# per ROI, so it would need one list per accuracy type before use.
# max_accu_all = [0.8,1.2]
# datasets = [1,2]

model_all = ['avg', 'indv_gica', 'indv_dict', 'indv_srm', 'multi_dict', 'multi_srm']
roi_all = ['dmn', 'pt', 'eac']
accu_type = ['class', 'rank']

# output_path has a single placeholder (exp); exist_ok avoids the check-then-create race.
os.makedirs(output_path.format(exp).replace(' ', ''), exist_ok=True)

# Set the figure size for both cases so re-running this cell with a different
# `datasets` never keeps a size left over from a previous run.
# [6, 2] matches the default set in the rcParams cell.
params['figure.figsize'] = [5.5, 2] if len(datasets) == 1 else [6, 2]
plt.rcParams.update(params)

## Aggregate Accuracies

In [None]:
# Collect mean and standard-error accuracies from the per-model .npz files.
# Resulting layout: all_mean[roi index][accuracy-type index] is a flat list
# ordered dataset-major, model-minor (the plotting cell slices it per dataset
# with [i*nmodel:(i+1)*nmodel]). all_se has the same layout.
all_mean = []  # one entry per ROI
all_se = []    # one entry per ROI
for roi in roi_all:
    roi_mean = [[] for _ in accu_type]
    roi_se = [[] for _ in accu_type]
    for ds in datasets:
        for model in model_all:
            # spaces are stripped because total_ds is a list formatted into the path
            npz_path = input_file.format(exp, total_ds, roi, model, ds).replace(' ', '')
            # the NpzFile keeps the archive open until closed; `with` releases it
            with np.load(npz_path) as ws:
                for r, ac_tp in enumerate(accu_type):
                    roi_mean[r].append(ws[ac_tp + '_mean'].item())
                    roi_se[r].append(ws[ac_tp + '_se'].item())
    all_mean.append(roi_mean)
    all_se.append(roi_se)

## Plot all ROIs

In [None]:
# Bar layout shared by all figures (x positions are in units of bar width).
width = 1
nmodel = len(model_all)  # number of models, i.e. bars per dataset group
ndata = len(datasets)
group_width = 1.1*width*nmodel + 0.4*width
center_all = np.linspace(0, group_width*(ndata-1), ndata)  # anchor x of each dataset group
# bar colours, one per entry of model_all (same order)
color_all = ['lightgrey', 'mediumseagreen', 'dodgerblue', 'mediumorchid', 'gold', 'red']
# pattern[0] hatches bars 1..nmodel-3; legend labels for both patterns come from lg_dc
pattern = ['///', ' ']

# One figure per (ROI, accuracy type); max_accu holds the per-ROI y-limits for this type.
for m, roi in enumerate(roi_all):
    for r, (ac_tp, max_accu) in enumerate(zip(accu_type, max_accu_all)):
        # x ticks: a leading 'model / k' header tick, then one tick per bar
        xtick_idx = [center_all[0] - 0.5*width - 0.8]
        xtick_name = ['model\nk']
        # positions and names of the per-dataset caption boxes
        ds_idx = []
        ds_name = []
        plt.figure()
        # configure bars
        for i, (ds, center) in enumerate(zip(datasets, center_all)):
            # all_mean[m][r] is ordered dataset-major, model-minor
            mean = list(all_mean[m][r][i*nmodel:(i+1)*nmodel])
            se = list(all_se[m][r][i*nmodel:(i+1)*nmodel])
            # hand-tuned offsets: the first bar and the final two bars sit slightly
            # apart from the evenly spaced middle run
            idx = np.concatenate(((center - 0.15*width)[None],
                                  np.arange(center + width, center + nmodel*width - 2.2*width, width),
                                  (center + nmodel*width - 1.85*width)[None],
                                  (center + nmodel*width - 0.8*width)[None]))
            error_config = {'ecolor': '0', 'capsize': 3}
            rects = plt.bar(idx, mean, yerr=se, align='center', error_kw=error_config, width=width-0.1)
            # set colors
            for rect_i, rect in enumerate(rects):
                rect.set_color(color_all[rect_i])
            # hatch every bar except the first and the last two
            for rect_i in range(1, len(rects) - 2):
                rects[rect_i].set_hatch(pattern[0])
            autolabel(rects, se)
            ds_idx.append(center + (nmodel-2)*width/2)
            ds_name.append(ds_dict[ds])
            # xtick names: display name plus k (blank second line for 'avg')
            xtick_idx.extend(idx)
            for model in model_all:
                feat = str(feat_dict[model][roi_dict[roi]])
                if model == 'avg':
                    xtick_name.append(md_dict[model] + '\n ')
                else:
                    xtick_name.append(md_dict[model] + '\n' + feat)

#         plt.xticks(rotation=15,ha='center')
        plt.xticks(xtick_idx, xtick_name)
        plt.yticks(np.arange(0, max_accu[m] + 0.001, 0.2))
        plt.ylabel('Accuracy')
        left_lim = center_all[0] - 0.5*width - 0.8
        right_lim = center_all[-1] + (nmodel-0.5)*width + 0.5
        plt.xlim([left_lim, right_lim])
        plt.ylim([0, max_accu[m]])

        # dash-dot chance-level line across each dataset group; the last handle feeds the legend
        for d, center in enumerate(center_all):
            cl = cl_dict[ac_tp][datasets[d]]
            line = plt.plot([center - width, center + nmodel*width], [cl, cl], 'k-.', linewidth=2)

        # Add texts: title above the legends, plus a caption box per dataset
        title_height = 1.4*max_accu[m]
#         plt.text((right_lim+left_lim)/2, title_height, actp_dict[ac_tp]+' for '+exp_label+' (ROI: '+roi.upper()+')',fontsize=11,horizontalalignment='center', verticalalignment='bottom')
        plt.text((right_lim + left_lim)/2, title_height, exp_label + ' (ROI: ' + roi.upper() + ')',
                 fontsize=13, horizontalalignment='center', verticalalignment='bottom')
        props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
        for d in range(len(datasets)):
            plt.text(ds_idx[d], 0.96*max_accu[m], 'Prm.: ' + ds_name[d], bbox=props, fontsize=11,
                     horizontalalignment='center', verticalalignment='top')

        # legends: model colours (top row), chance line (right) and hatch key (left)
        legend_handle = [mpatches.Patch(color=color_all[patch_i], label=md_dict[model_all[patch_i]])
                         for patch_i in range(len(model_all))]
        ncol = int(len(model_all)/2) + 1 if len(datasets) == 1 else len(model_all)
        l1 = plt.legend(handles=legend_handle, bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
                        ncol=ncol, mode="expand", borderaxespad=0.)
        # each plt.legend call replaces the Axes legend, so earlier legends are kept
        # alive as plain artists; each one is added exactly once
        plt.gca().add_artist(l1)
        l2 = plt.legend(line, ['chance'], bbox_to_anchor=(0., 1.42, 1.015, 0.), loc=1, ncol=1)
        plt.gca().add_artist(l2)
        legend_handle2 = [mpatches.Patch(hatch=pattern[patch_i], color='beige', label=lg_dc[patch_i])
                          for patch_i in range(2)]
        plt.legend(handles=legend_handle2, bbox_to_anchor=(-0.015, 1.42, 1.005, 0.), loc=2, ncol=2)

    #     plt.savefig(output_file.format(exp,roi,ac_tp,datasets).replace(' ','')+'.eps', format='eps', dpi=200,bbox_inches='tight')
        plt.savefig(output_file.format(exp, roi, ac_tp, datasets).replace(' ', '') + '.pdf',
                    format='pdf', dpi=200, bbox_inches='tight')