In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pandas import DataFrame, read_csv, concat
from arviz import hdi
sns.set_theme(style='ticks', context='notebook', font_scale=1.2)

## Section 1: Model Diagnostics

In [2]:
## Define parameters.
models = ['pgng_m1', 'pgng_m2', 'pgng_m3', 'pgng_m4', 'pgng_m5', 'pgng_m6', 'pgng_m7']
sessions = ['s1', 's2', 's3']

### 1.1 Stan diagnostics


In [None]:
## Main loop.
diagnostics = []
for m in models:
    
    for s in sessions:

        ## Load Stan summary.
        samples = read_csv(os.path.join('stan_results', s, f'{m}.tsv.gz'), sep='\t', compression='gzip')
        summary = read_csv(os.path.join('stan_results', s, f'{m}_summary.tsv'), sep='\t', index_col=0)
        ppc = read_csv(os.path.join('stan_results', s, f'{m}_ppc.csv'))

        ## Apply restrictions.
        ppc = ppc[~np.isinf(ppc.k_u)]                        # Removed fixed parameters.

        ## Identify number of divergences.
        divergence = samples.divergent__.sum()

        ## Identify parameters failing to reach convergence.
        rhat = len(summary.query('R_hat >= 1.02'))

        ## Identify parameters with low effective sample size.
        n_eff = len(summary.query('N_Eff < 400'))

        ## Identify number of effective parameters.
        p_loo = ppc.pwaic.sum()

        ## Identify number of poorly predicted observations.
        pk = np.sum(ppc.k_u > 0.7)

        ## Convert to dictionary. Append.
        diagnostics.append(dict(
            model = m,
            session = s,
            divergence = divergence,
            rhat = rhat,
            n_eff = n_eff,
            p_loo = np.round(p_loo, 1),
            pk = np.round(pk, 3)
        ))

## Convert to DataFrame.
diagnostics = DataFrame(diagnostics).sort_values(['session','model']).set_index(['session','model'])
diagnostics

## Section 2: Model Comparison

In [None]:
## Define parameters.
models = ['pgng_m1', 'pgng_m2', 'pgng_m3', 'pgng_m4', 'pgng_m5', 'pgng_m6', 'pgng_m7']
sessions = ['s1', 's2', 's3']

### 2.1 LOO-CV indices

In [None]:
## Main loop.
loocv = []
for m in models:

    for s in sessions:
    
        ## Load posterior predictive check.
        ppc = read_csv(os.path.join('stan_results', s, f'{m}_ppc.csv'))

        ## Compute LOO-CV.
        loo = -2 * ppc.loo.sum()

        ## Convert to dictionary. Append.
        loocv.append(dict(model=m, session=s, loocv=loo))
        
## Convert to DataFrame.
loocv = DataFrame(loocv).pivot_table('loocv', 'session', 'model')
loocv.round(1)

### 2.2 Model comparisons

In [3]:
from itertools import combinations

loocv = []
for s in sessions:

    for a, b in list(combinations(models, 2)):

        ## Load data.
        ppc1 = read_csv(os.path.join('stan_results', s, f'{a}_ppc.csv'))
        ppc2 = read_csv(os.path.join('stan_results', s, f'{b}_ppc.csv'))

        arr = -2 * (ppc2.loo - ppc1.loo)

        ## Compute stats.
        N = len(ppc1)
        mu = np.sum(arr)
        se = np.std(arr) * np.sqrt(N)

        ## Convert to dictionary. Append.
        loocv.append(dict(model=m, session=s, a=a, b=b, loocv='%0.1f (%0.1f)' %(mu, se)))
        
## Convert to DataFrame.
loocv = DataFrame(loocv).pivot_table('loocv', 'a', ['session','b'], aggfunc=lambda x: x).fillna('-')
loocv

NameError: name 'm' is not defined

### 2.3 Table 2

In [None]:
from sklearn.metrics import accuracy_score

## Define winning model.
winning = 'pgng_m7'

## Main loop.
loocv = []
for i, m in enumerate(models):

    ## Load posterior predictive checks.
    ppc1 = concat([read_csv(os.path.join('stan_results', s, f'{m}_ppc.csv'))
                  for s in sessions])

    ## Load posterior predictive checks.
    ppc2 = concat([read_csv(os.path.join('stan_results', s, f'{winning}_ppc.csv'))
                  for s in sessions])
    
    ## Compute classification accuracy.
    score = accuracy_score(ppc1.choice, ppc1.Y_hat > 0.5) * 1e2
    
    ## Compute LOO-CV.
    loo = -2 * ppc1.loo.sum()
    
    ## Comute delta LOO-CV.
    arr = -2 * (ppc2.loo - ppc1.loo)
    mu = np.sum(arr)
    se = np.std(arr) * np.sqrt(len(arr))
    
    ## Append.
    loocv.append(dict(model=i+1, score='%0.1f' %score, loo='%0.1f' %loo, delta='%0.1f (%0.1f)' %(mu, se)))
    
## Convert to DataFrame.
loocv = DataFrame(loocv).set_index('model')
loocv

### 2.4 Table S3

In [None]:
from sklearn.metrics import accuracy_score

## Define winning model.
winning = 'pgng_m7'

## Main loop.
loocv = []
for i, m in enumerate(models):

    for s in sessions:
    
        ## Load posterior predictive checks.
        ppc1 = read_csv(os.path.join('stan_results', s, f'{m}_ppc.csv'))

        ## Load posterior predictive checks.
        ppc2 = read_csv(os.path.join('stan_results', s, f'{winning}_ppc.csv'))

        ## Compute classification accuracy.
        score = accuracy_score(ppc1.choice, ppc1.Y_hat > 0.5) * 1e2

        ## Compute LOO-CV.
        loo = -2 * ppc1.loo.sum()

        ## Comute delta LOO-CV.
        arr = -2 * (ppc2.loo - ppc1.loo)
        mu = np.sum(arr)
        se = np.std(arr) * np.sqrt(len(arr))

        ## Append.
        loocv.append(dict(model=i+1, session=s, score='%0.1f' %score, loo='%0.1f' %loo, 
                          delta='%0.1f (%0.1f)' %(mu, se)))
    
## Convert to DataFrame.
loocv = DataFrame(loocv).set_index(['session','model']).sort_index()
loocv

## Section 3: Posterior Predictive Checks

In [None]:
## Define parameters.
models = ['pgng_m1', 'pgng_m2', 'pgng_m3', 'pgng_m4', 'pgng_m5', 'pgng_m6', 'pgng_m7']
sessions = ['s1', 's2', 's3']

### 3.1 Group-level

In [None]:
## Initialize canvas.
fig, axes = plt.subplots(len(sessions), len(models), figsize=(len(models)*4, len(sessions)*3),
                         sharex=True, sharey=True)

## Define aesthetics.
order = ['gw', 'ngw', 'gal', 'ngal']
palette = sns.diverging_palette(220, 20, n=4)

for i, s in enumerate(sessions):
    
    for j, m in enumerate(models):
        
        ## Load posterior predictive check.
        ppc = read_csv(os.path.join('stan_results', s, f'{m}_ppc.csv'))
            
        ## Plot learning curves.
        sns.lineplot(x='exposure', y='choice', hue='robot', data=ppc, hue_order=order,
                      palette=palette, lw=3, ci=None, ax=axes[i,j])
        sns.lineplot(x='exposure', y='Y_hat', hue='robot', data=ppc, hue_order=order, 
                      palette=palette, lw=3, ci=None, linestyle='--', ax=axes[i,j])

        ## Add trend line.
        axes[i,j].axhline(0.5, color='0.5', alpha=0.4, zorder=-10)
        
        ## Adjust legend.
        axes[i,j].legend_.set_visible(False)
        
sns.despine()
plt.tight_layout()

In [None]:
## Initialize canvas.
fig, axes = plt.subplots(len(sessions), len(models), figsize=(len(models)*3, len(sessions)*3),
                         sharex=True, sharey=True)

## Define aesthetics.
order = ['gw', 'ngw', 'gal', 'ngal']
palette = sns.diverging_palette(220, 20, n=4)

## Define convenience functions.
RMSE = lambda x: np.sqrt(np.mean(np.square(x)))

for i, s in enumerate(sessions):
    
    for j, m in enumerate(models):
        
        ## Load posterior predictive check.
        ppc = read_csv(os.path.join('stan_results', s, f'{m}_ppc.csv'))

        ## Compute accuracy by participant / condition.
        gb = ppc.groupby(['subject','robot']).agg({'choice':'mean', 'Y_hat':'mean'}).reset_index()
        
        ## Compute fit statistics.
        rmse = RMSE(gb.choice - gb.Y_hat)
        corr = gb[['choice','Y_hat']].corr().values[0,1]
        
        ## Plot learning curves.
        sns.scatterplot(x='choice', y='Y_hat', hue='robot', data=gb, hue_order=order, 
                        palette=palette, ax=axes[i,j])
        axes[i,j].plot([-1,2], [-1,2], color='0.8')
        
        ## Adjust x-axis.
        axes[i,j].set(xlim=(-0.05,1.05))
        
        ## Adjust y-axis
        axes[i,j].set(ylim=(-0.05,1.05))
        
        ## Adjust legend.
        axes[i,j].legend(loc=4, frameon=False, ncol=2, borderpad=0, handletextpad=0, columnspacing=0.3)
        
        ## Add annotation.
        annot = 'RMSE = %0.3f\nr = %0.3f' %(rmse, corr)
        axes[i,j].annotate(annot, (0,0), (0.04, 0.98), 'axes fraction', ha='left', va='top', fontsize=11)
        
sns.despine()
plt.tight_layout()

In [None]:
## Define parameters.
sessions = ['s1', 's2', 's3']

## Iteratively load data.
data = concat([read_csv(os.path.join('stan_results', session, 'pgng_m7_ppc.csv'))
               for session in ['s1','s2','s3']])
data['exposure'] -= 1

g = sns.FacetGrid(data, col='runsheet', col_order=['1a','2a','3a','1b','2b','3b'], col_wrap=3)
palette = sns.diverging_palette(220, 20, n=4)

g.map(sns.pointplot, 'exposure', 'choice', 'robot', order=np.arange(12),
      hue_order=['gw','gal','ngw','ngal'], palette=palette, ci=None)
g.map(sns.lineplot, 'exposure', 'Y_hat', 'robot', 
      hue_order=['gw','gal','ngw','ngal'], palette=palette, linestyle='--', ci=None)

## Section 4: Parameter stability

### 4.1 Between-session comparisons

In [4]:
## Define parameters.
pairs = list(combinations(['s1','s2','s3'], 2))
params = ['b1','b2','b3','b4','a1','a2','c1']
model = 'pgng_m7'

## Main loop.
comparisons = []
for s1, s2 in pairs:
    
    ## Load samples.
    samples_1 = read_csv(os.path.join('stan_results', s1, f'{model}.tsv.gz'), 
                                      sep='\t', compression='gzip')
    samples_2 = read_csv(os.path.join('stan_results', s2, f'{model}.tsv.gz'), 
                         sep='\t', compression='gzip')
    
    ## Iterate over parameters.
    for p in params:
        
        ## Extract parameters.
        a = samples_1[f'{p}_mu'].values
        b = samples_2[f'{p}_mu'].values
        
        ## Summarize & report.
        mu1 = np.mean(a); mu2 = np.mean(b); delta = np.mean(a - b)
        lb, ub = hdi(a - b, hdi_prob=0.95)
        is_credible = '**' if np.sign(lb) == np.sign(ub) else ''
        comparisons.append({'s1': s1, 's2': s2, 'param': p, 'Mean[1]': mu1, 'Mean[2]': mu2, 
                            'delta': delta, 'lb': lb, 'ub': ub, 'credible': is_credible})
        
## Convert to DataFrame.
comparisons = DataFrame(comparisons).set_index(['param','s1','s2']).sort_index()
comparisons.loc[params].round(3)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Mean[1],Mean[2],delta,lb,ub,credible
param,s1,s2,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
b1,s1,s2,7.582,9.409,-1.827,-4.181,0.589,
b1,s1,s3,7.582,13.328,-5.746,-9.217,-2.52,**
b1,s2,s3,9.409,13.328,-3.918,-7.755,-0.458,**
b2,s1,s2,6.626,11.093,-4.468,-7.548,-1.697,**
b2,s1,s3,6.626,13.054,-6.429,-10.281,-3.126,**
b2,s2,s3,11.093,13.054,-1.961,-6.239,2.066,
b3,s1,s2,1.448,1.041,0.407,0.083,0.722,**
b3,s1,s3,1.448,0.982,0.466,0.151,0.781,**
b3,s2,s3,1.041,0.982,0.058,-0.262,0.359,
b4,s1,s2,0.182,0.105,0.077,-0.11,0.284,


## Section 6: Reliability

In [None]:
## Define paramaters.
model = 'pgng_m7'

## Load summary.
reliability = read_csv(os.path.join('stan_results', f'{model}_reliability.csv'))

### 6.1 Split-half reliability

In [None]:
## Initialize canvas.
fig, ax = plt.subplots(1, 1, figsize=(10,4))
palette = np.append('k', sns.color_palette('crest_r', n_colors=3).as_hex())
labels = ['Overall','Session 1','Session 2', 'Session 3']
offsets = np.linspace(-0.2,0.2,4)

for i, (offset, color, label) in enumerate(zip(offsets, palette, labels)):

    ## Define points.
    y = reliability.query(f'Type == "sh" and Group == {i}').Mean
    x = np.arange(len(y)) + offset
    yerr = np.array([
        reliability.query(f'Type == "sh" and Group == {i}')['97.5%'] - y,
        y - reliability.query(f'Type == "sh" and Group == {i}')['2.5%']
    ])
    
    ## Plot.
    ax.errorbar(x, y, fmt='o', yerr=yerr, color=color, label=label, capsize=3, elinewidth=1.33)
    
## Add detail.
ax.axhline(0.7, color='0.8', lw=0.8, linestyle='--')
ax.legend(loc=4, frameon=False, borderpad=0, handletextpad=0.2)
ax.set(xticks=np.arange(7), ylim=(0,1.05), ylabel='Split-half reliability')
ax.set_xticklabels(['Inverse\ntemperature\n(Positive)','Inverse\ntemperature\n(Negative)',
                    'Go Bias\n(Positive)','Go Bias\n(Negative)','Learning\nRate\n(Positive)',
                    'Learning\nRate\n(Negative)','Lapse Rate'])

sns.despine()
plt.tight_layout()

### 6.2 Test-retest reliability

In [None]:
## Initialize canvas.
fig, ax = plt.subplots(1, 1, figsize=(10,4))
palette = np.append('k', sns.color_palette('crest_r', n_colors=3).as_hex())
labels = ['Overall','S1 vs. S2','S1 vs. S3', 'S2 vs. S3']
offsets = np.linspace(-0.2,0.2,4)

for i, (offset, color, label) in enumerate(zip(offsets, palette, labels)):

    ## Define query.
    query = f'Type == "trt" and Group == {i}'
    
    ## Define points.
    y = reliability.query(query).Mean
    x = np.arange(len(y)) + offset
    yerr = np.array([
        reliability.query(query)['97.5%'] - y,
        y - reliability.query(query)['2.5%']
    ])
    
    ## Plot.
    ax.errorbar(x, y, fmt='o', yerr=yerr, color=color, label=label, capsize=3, elinewidth=1.33)
    
## Add detail.
ax.axhline(0.7, color='0.8', lw=0.8, linestyle='--')
ax.legend(loc=4, frameon=False, borderpad=0, handletextpad=0.2)
ax.set(xticks=np.arange(7), ylim=(0,1.05), ylabel='Test-retest reliability')
ax.set_xticklabels(['Inverse\ntemperature\n(Positive)','Inverse\ntemperature\n(Negative)',
                    'Go Bias\n(Positive)','Go Bias\n(Negative)','Learning\nRate\n(Positive)',
                    'Learning\nRate\n(Negative)','Lapse Rate'])

sns.despine()
plt.tight_layout()