In [None]:
import matplotlib.patches as mpatches
import numpy as np
import os
import math
import matplotlib.pyplot as plt
%matplotlib inline  

In [None]:
# Global matplotlib styling applied to every figure drawn in this notebook.
params = dict([
    ('axes.labelsize', 11),      # axis-label font size
    ('font.size', 10),           # base font size
    ('legend.fontsize', 11),
    ('xtick.labelsize', 11),
    ('ytick.labelsize', 12),
    ('text.usetex', False),      # do not route text rendering through LaTeX
    ('figure.figsize', [8, 2]),  # wide, short figure (instead of 4.5, 4.5)
])
plt.rcParams.update(params)

## Fixed Parameters

In [None]:
# ROIs and model to plot.
roi_all = ['dmn']
model = 'multi_srm'

# Lookup tables. Dataset indices 0-3 are shared by cl_dict and ds_dict.
cl_dict = {0: 1/450, 1: 1/297, 2: 1/297, 3: 1/1973}  # chance-level accuracy per dataset index
# k shown in the plot annotation ('k = ...'), per model, indexed by roi_dict[roi]
feat_dict = {'multi_srm': [75, 75, 100], 'multi_dict': [25, 50, 50], 'avg': [50, 50, 50]}
roi_dict = {'dmn': 0, 'pt': 1, 'eac': 2}  # ROI name -> position in each feat_dict list
# TODO: change the display name used for multi_srm.
md_dict = {'avg': 'MNI', 'multi_srm': 'MDMSL', 'multi_dict': 'MDDict'}  # model -> display label
ds_dict = {0: 'GreenEyes', 1: 'Milky', 2: 'Vodka', 3: 'Sherlock'}  # dataset index -> display name

# Paths (relative to the notebook's working directory).
input_file = '../../output/accu_bar/{}/{}/{}_ds{}.npz'  # format args: exp, roi, model, ds
output_path = '../../output/figures/{}/'  # format arg: exp
output_file = output_path + '{}_{}_{}'  # format args: exp, roi, model, ds (no extension)

## Plotting Parameters

In [None]:
# Experiment being plotted and its figure title.
exp = 'shared_subj'
exp_label = 'Add Shared Subjects to Sec. Dataset'

# One (primary, secondary) dataset-index pair per panel, plus that panel's y-axis range.
# The plotting cell lays out exactly two panels, so all three lists must have length 2.
ds_all = [[2,0],[1,0]]
min_accu = [0.1,0.1]  # per-panel y-axis lower limit
max_accu = [0.4,0.2]  # per-panel y-axis upper limit
assert len(ds_all) == 2, 'ds_all must contain exactly 2 dataset pairs'
assert len(min_accu) == len(max_accu) == len(ds_all), 'need one y-range per dataset pair'

# exist_ok avoids the exists-then-create race and keeps the cell idempotent.
os.makedirs(output_path.format(exp), exist_ok=True)

## Aggregate Accuracy

In [None]:
# Load the saved accuracy curves.
# all_mean[m][r] / all_se[m][r] are the 'mean' / 'se' arrays for roi_all[m] and ds_all[r].
all_mean = []
all_se = []
for roi in roi_all:
    roi_mean = []
    roi_se = []
    for ds in ds_all:
        # ds is a list such as [2, 0]; str() renders it as '[2, 0]', and the
        # spaces are stripped so the file name reads '..._ds[2,0].npz'.
        npz_path = input_file.format(exp, roi, model, ds).replace(' ', '')
        # np.load on an .npz returns an NpzFile that holds the file open;
        # the context manager closes it once both arrays have been read.
        with np.load(npz_path) as ws:
            roi_mean.append(ws['mean'])
            roi_se.append(ws['se'])
    all_mean.append(roi_mean)
    all_se.append(roi_se)

## Plot all ROIs (line plot)

In [None]:
# One figure per ROI. Each panel plots accuracy (mean with se error bars) against
# the number of shared subjects for one (primary, secondary) pair in ds_all, with a
# dash-dot line at the primary dataset's chance level, and saves the figure as PDF.
color = 'red'
left_lim = -0.5  # x-axis starts half a unit left of the first subject index
k_props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)    # "k = ..." box style
ds_props = dict(boxstyle='square', facecolor='ivory', alpha=0.5)  # dataset-pair box style
for m, roi in enumerate(roi_all):
    fig = plt.figure()
    for r, ds in enumerate(ds_all):
        mean = all_mean[m][r]
        se = all_se[m][r]
        num_subj = mean.shape[0]
        idx = np.arange(num_subj)
        y_lo, y_hi = min_accu[r], max_accu[r]
        right_lim = num_subj - 0.5

        # Fixed two-panel grid: a third pair raises instead of silently
        # over-plotting the second panel.
        ax = fig.add_subplot(1, 2, r + 1)
        if r == 0:
            ax.set_ylabel('Accuracy')
        ax.errorbar(idx, mean, se, color=color, capsize=3, ecolor='k', linewidth=2.0)
        ax.set_xlim([left_lim, right_lim])
        ax.set_ylim([y_lo, y_hi])
        tick_pos = list(range(0, num_subj, 2))
        ax.set_xticks(tick_pos)
        ax.set_xticklabels([n + 1 for n in tick_pos])  # 1-based subject counts
        ax.set_yticks(np.arange(y_lo, y_hi + 0.001, 0.1))

        # Chance-level accuracy of the primary dataset.
        cl = cl_dict[ds[0]]
        ax.plot([left_lim, right_lim], [cl, cl], 'k-.', linewidth=2)

        # Annotations: k for this model/ROI, and the dataset pair.
        ax.text(4 * (right_lim + left_lim) / 5, 0.22 * (y_hi - y_lo) + y_lo,
                'k = ' + str(feat_dict[model][roi_dict[roi]]), bbox=k_props, fontsize=12,
                horizontalalignment='center', verticalalignment='top')
        ds_text = 'Prm.:' + ds_dict[ds[0]] + '\nSec.:' + ds_dict[ds[1]]
        ax.text(left_lim + 0.5, 0.95 * (y_hi - y_lo) + y_lo, ds_text, bbox=ds_props, fontsize=12,
                horizontalalignment='left', verticalalignment='top')

    # The figure title and shared x-label are drawn in the last panel's data
    # coordinates. x = left_lim - 0.6 lies just left of that panel's x-range
    # (the gap between panels), and y uses the last panel's range. This
    # placement assumes the two-panel layout.
    last = len(ds_all) - 1
    y_lo, y_hi = min_accu[last], max_accu[last]
    ax.text(left_lim - 0.6, 1.02 * (y_hi - y_lo) + y_lo, exp_label + ' (ROI: ' + roi.upper() + ')',
            fontsize=12, horizontalalignment='center', verticalalignment='bottom')
    ax.text(left_lim - 0.6, y_lo - 0.015, 'Number of shared subjects',
            fontsize=11, horizontalalignment='center', verticalalignment='top')
    fig.subplots_adjust(wspace=0.15)
    # NOTE(review): the file name encodes only the last pair in ds_all, although the
    # figure shows every pair; kept as-is so existing output names do not change.
    # For EPS output, save the same path with '.eps' and format='eps'.
    fig.savefig(output_file.format(exp, roi, model, ds_all[last]).replace(' ', '') + '.pdf',
                format='pdf', dpi=200, bbox_inches='tight')