In [1]:
from matplotlib.patches import Patch
import matplotlib.pyplot as plt
import matplotlib.cm as cm

import numpy as np
import scipy.io
plt.style.use('fast')

In [None]:
# Input files: full factorial fit plus the single-model (PI-only / PD-only)
# fits. All three share the same run prefix and differ only by suffix.
path_full, path_pi, path_pd = (
    f'data/pc_factorial_run1{suffix}.mat'
    for suffix in ('', '_pi_only', '_pd_only')
)

# Number of subjects in the dataset (length of the per-subject axis).
n_subjects = 32

### Model selection

- load MCMC samples from the `.mat` file
- calculate per-subject sample counts and Bayes factors for the PI vs. PD model comparison

In [None]:
# Read only the `samples` struct from the full-model fit. squeeze_me removes
# MATLAB's unit dimensions, so the struct's `z` field is unwrapped with .item().
mat = scipy.io.loadmat(
    path_full,
    squeeze_me=True,
    variable_names=['samples'],
)
z = mat['samples']['z'].item()

In [None]:
# Floor for zero counts: a subject whose samples never visited one of the two
# models would otherwise cause a division by zero. NOTE: with this floor such
# a subject gets |log10(BF)| >= log10(1/eps) ~= 15.7, far outside the plot range.
eps = np.finfo(float).eps

# z holds the sampled model indicator (1 = PI, 2 = PD). All leading axes
# (presumably chains x samples -- TODO confirm against the sampler output) are
# pooled, so the last axis must index subjects.
assert z.shape[-1] == n_subjects, (
    f'z has {z.shape[-1]} subjects on its last axis, expected {n_subjects}')
sample_axes = tuple(range(z.ndim - 1))

# Row 0: per-subject number of samples assigned to PI; row 1: to PD.
pmp = np.empty((2, n_subjects), dtype=float)
pmp[0] = np.sum(z == 1, axis=sample_axes)  # PI model
pmp[1] = np.sum(z == 2, axis=sample_axes)  # PD model
pmp[pmp == 0] = eps

# Ratio of sample counts = posterior odds PI:PD per subject. This equals the
# Bayes factor only if both models have equal prior probability (assumed --
# TODO confirm against the sampler's model prior).
bf_main = pmp[0] / pmp[1]

In [None]:
def gen_logbf_barplot(bf, modelname, outpath=None, ylim=(-3, 3)):
    """Create a barplot of subject-wise Bayes factors on a log10 scale.

    Each bar is coloured by evidence strength, judged from |log10(BF)|:
    anecdotal (BF < 3), moderate (< 10), strong (< 30), very strong (< 100)
    and extreme (>= 100). Positive bars favour ``modelname[0]``, negative
    bars favour ``modelname[1]``.

    Args:
        bf: 1-D array-like of Bayes factors, one per subject
            (``modelname[0]`` vs ``modelname[1]``).
        modelname: two-element sequence of model names; the first is
            annotated above zero, the second below.
        outpath: optional file path; if given, the figure is saved there.
            Missing parent directories are created.
        ylim: y-axis limits in log10 units. Bars beyond this range are
            visually clipped but still coloured by their true magnitude.
    """
    from pathlib import Path

    logbf = np.log10(np.asarray(bf, dtype=float))
    # Derive the bar count from the data rather than a global constant.
    n_bars = logbf.size

    fig, ax = plt.subplots(nrows=1, ncols=1, facecolor='w', figsize=(20, 5))
    bars = ax.bar(range(n_bars), logbf)

    # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9;
    # pyplot.get_cmap is the supported replacement.
    cmap = plt.get_cmap('bone')
    color = {
        'extreme': cmap(0),
        'vstrong': cmap(1/5),
        'strong': cmap(2/5),
        'moderate': cmap(3/5),
        'anecdotal': cmap(4/5),
    }

    # Exclusive upper BF bound of each evidence category, in ascending order;
    # anything at or above the last bound is 'extreme'.
    bounds = (
        (3, 'anecdotal'),
        (10, 'moderate'),
        (30, 'strong'),
        (100, 'vstrong'),
    )
    for rect in bars:
        magnitude = np.abs(rect.get_height())
        category = next(
            (label for upper, label in bounds if magnitude < np.log10(upper)),
            'extreme',
        )
        rect.set_color(color[category])

    ax.set_title('Individual Bayes Factors')
    ax.set_xlim((-1, n_bars))
    ax.set_xlabel('Subjects')
    ax.set_xticks(range(n_bars))
    # Subject labels start at m02 (presumably the dataset's subject IDs
    # start at 2 -- TODO confirm).
    ax.set_xticklabels([f'm{sub:02}' for sub in range(2, n_bars + 2)],
                       rotation=-45)
    ax.set_ylim(ylim)
    ax.set_ylabel(r'$\log_{10}(BF)$', rotation=90)
    ax.set_axisbelow(True)
    ax.grid()

    legend_elements = [Patch(facecolor=col, edgecolor='k', label=key)
                       for key, col in color.items()]
    ax.legend(handles=legend_elements, bbox_to_anchor=(1, .7))
    # Arrows to the right of the axes show which model each bar sign favours.
    ax.annotate(r'$\uparrow$' + modelname[0],
                [1.01, .8], xycoords='axes fraction', fontsize=20)
    ax.annotate(r'$\downarrow$' + modelname[1],
                [1.01, .1], xycoords='axes fraction', fontsize=20)
    fig.tight_layout()

    if outpath:
        print('saving...')
        # savefig does not create directories; make sure the target exists.
        Path(outpath).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(outpath)

In [None]:
# Positive log10(BF) favours the first model (PI), negative the second (PD).
model_names = ('PI', 'PD')
gen_logbf_barplot(
    bf_main,
    model_names,
    outpath='pygures/logbf_pi_vs_pd.png',
)