**Imports**

In [None]:
%matplotlib notebook
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
from matplotlib import gridspec, lines, legend_handler
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
import numpy as np
import pandas as pd
import scipy.stats as scs
from scipy import ndimage
from statsmodels.formula.api import ols
from os import path
from io import StringIO
from IPython.display import display
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.collections import LineCollection
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from cycler import cycler
import warnings
import itertools

from python_scripts.utils import loc_utils as lut
from python_scripts.utils import vis_utils as vut
from python_scripts.utils import model_utils as mut
from python_scripts.utils.vis_utils import gcolors, glabels, fullglabels, gmarkers, colors, ncolors, tlabels
from python_scripts.utils.model_utils import *

plt.style.use('python_scripts/my.mplstyle')

# Detect (and optionally exclude) outliers

In [None]:
def make_clean_dataset(input_data_path, save_path, **kwargs):
    """Flag outlier subjects, optionally drop them, and save the dataset.

    Two per-subject statistics are computed and stored as columns:

    * ``alloc_bias`` -- standard deviation of the subject's trial counts over
      the activities A1-A4 (an even allocation gives 0);
    * ``resp_bias`` -- for each ``family``, the share of the subject's most
      frequent ``response``, averaged over families.

    A subject is flagged ``high_ab`` when ``alloc_bias >= ab_crit`` and
    ``high_rb`` when ``resp_bias`` is more than ``rb_crit`` sample standard
    deviations above the sample mean (subjects already flagged ``high_ab``
    are not flagged again as ``high_rb``).

    Parameters
    ----------
    input_data_path : str
        CSV with (at least) ``sid``, ``group``, ``activity``, ``family`` and
        ``response`` columns.
    save_path : str
        Output CSV; nothing is written when it is empty/None.
    **kwargs
        ``ab_crit`` and ``rb_crit`` (both required), and ``exclude``
        (optional, default False): drop the flagged subjects before saving.
    """
    def rbf(x):
        # Share of the most frequent response within one group of rows
        _, response_counts = np.unique(x.response, return_counts=True)
        return np.max(response_counts) / np.sum(response_counts)

    # Take the flag from the arguments. It used to be looked up as a notebook
    # global, so the value passed by the caller was silently ignored.
    exclude = kwargs.get('exclude', False)

    # Open combined data file
    df = pd.read_csv(input_data_path, index_col=None).set_index('sid')

    # Float initial values: the loop below stores floats in these columns
    df['alloc_bias'], df['resp_bias'] = 0., 0.

    # Calculate values of interest
    activities = ('A1', 'A2', 'A3', 'A4')
    for sid, sdf in tqdm(df.groupby(by='sid'), desc='Progress: '):
        # Allocation bias: SD of the number of trials spent on each activity
        counts = [sum(sdf.activity == i) for i in activities]
        df.loc[sid, 'alloc_bias'] = np.std(counts)

        # Response bias
        df.loc[sid, 'resp_bias'] = sdf.groupby('family').apply(rbf).mean()

    # One row per subject for the outlier detection
    df_ = df.reset_index().groupby('sid').head(1).reset_index()
    df_['high_ab'] = df_.alloc_bias >= kwargs['ab_crit']
    rb_threshold = df_.resp_bias.mean() + kwargs['rb_crit'] * df_.resp_bias.std()
    df_['high_rb'] = (df_.resp_bias > rb_threshold) & ~df_.high_ab

    display(df_.groupby(by='group')[['high_ab', 'high_rb']].sum().astype(int))
    is_outlier = df_.high_ab | df_.high_rb
    print('Found {} outliers'.format(is_outlier.sum()))

    # Exclude outliers (only when requested)
    outlier = df_.loc[is_outlier, 'sid']
    if exclude:
        df = df.loc[~df.index.isin(outlier), :]
    display(df.reset_index().groupby(by='group')['sid'].nunique())

    # Save data
    if save_path:
        print('saving to {}'.format(path.abspath(save_path)))
        df.reset_index().to_csv(save_path, index=False)
    

# When `exclude` is False the flagged subjects are kept and the result is
# written to the "unclean" file.
exclude = False
if exclude:
    save_path = 'data/clean_data.csv'
else:
    save_path = 'data/unclean_data.csv'

make_clean_dataset(
    input_data_path='data/combined_main.csv',
    save_path=save_path,
    # Outlier criteria
    ab_crit=100,  # allocation-bias (SD of per-activity trial counts) critical value
    rb_crit=2,  # response-bias critical value, in SDs above the sample mean
    # Drop flagged subjects before saving?
    exclude=exclude,
)

# Activity choices in different NAM groups

In [None]:
def make_fig(data_path, nam_data_path, figname, save_to, save_as=None):
    """Plot the % of free-play trials spent on each activity, by group and NAM.

    One panel per NAM (1-3). Each panel shows the mean +/- SEM over subjects
    of the per-subject percentage of trials spent on A1-A4, for both groups.
    The dashed line marks 25%.

    Parameters
    ----------
    data_path : str
        Trial-level CSV with ``sid``, ``group``, ``trial`` and ``activity``.
    nam_data_path : str
        CSV mapping ``sid`` to ``nam``.
    figname, save_to, save_as : str
        Figure name, output directory and file format; nothing is saved when
        ``save_as`` is falsy.
    """
    # Load data
    df = pd.read_csv(data_path).filter(items=['sid', 'group', 'trial', 'activity'])

    # Select only required free-play trials (N = 250)
    df = df.loc[df.trial.le(60 + 250) & df.trial.gt(60), :]
    nam_df = pd.read_csv(nam_data_path).filter(items=['sid', 'nam']).set_index('sid')
    df = df.dropna().drop(columns='trial')
    df = df.merge(nam_df, on='sid')
    df = df.loc[df.nam.gt(0), :]

    # Trials per activity for each subject. value_counts() only lists the
    # activities a subject actually chose, which silently dropped 0% entries
    # from the means below (biasing them upward); unstack/stack with
    # fill_value=0 keeps never-chosen activities as explicit zeros.
    counts = (df.groupby(['group', 'nam', 'sid']).activity.value_counts()
              .unstack(fill_value=0)
              .stack()
              .to_frame('counts'))
    counts = 100 * counts / counts.groupby(['group', 'nam', 'sid']).transform('sum')
    counts_stats = counts.groupby(['group', 'nam', 'activity']).agg(['mean', 'sem'])
    counts_stats.columns = counts_stats.columns.droplevel(0)
    display(counts_stats)

    fig = plt.figure(figname, figsize=[8, 3])
    x = np.array([1, 2, 3, 4])
    for nam in [1, 2, 3]:
        ax = vut.pretty(fig.add_subplot(1, 3, nam), 'y')
        ax.axhline(25, ls='--', color='k', alpha=.8)
        for group in [1, 0]:
            # Small horizontal offset so the two groups' error bars don't overlap
            x_ = x + [-.1, .1][group]
            y = counts_stats.loc[(group, nam, slice(None)), 'mean']
            yerr = counts_stats.loc[(group, nam, slice(None)), 'sem']
            ax.errorbar(x_, y, yerr=yerr, color=gcolors[group], marker=gmarkers[group],
                        capsize=5, markersize=8, lw=2, label=fullglabels[group])
        ax.set_ylim(10, 57)
        ax.set_xticks(x)
        ax.set_xticklabels(['A1', 'A2', 'A3', 'A4'], fontsize=12, fontweight='bold')
        for xt, c in zip(ax.get_xticklabels(), colors):
            xt.set_color(c)

        ax.set_title('NAM {}'.format(nam), fontsize=12)
        if nam == 1:
            ax.set_ylabel('Trials per activity\n(%; Mean and SEM)', fontsize=12)
        if nam == 2:
            ax.set_xlabel('Learning activity', fontsize=12)

    # Legend only on the last panel
    ax.legend(fontsize=12)

    fig.tight_layout()
    if save_as:
        vut.save_it(fig, save_to, figname, save_as=save_as, compress=False)


make_fig(
    data_path='data/clean_data.csv',
    nam_data_path='data/nam_data.csv',
    figname='choices_by_nam',
    save_to='figures',
    save_as='',  # File format (png, jpeg, svg, ...); empty string skips saving
)

# Proportion of correct responses over time

In [None]:
def make_bottom_fig(data_path, figname, save_to, save_as=None):
    """Plot the per-trial % of correct responses in each group with linear fits.

    Free-play trials 61-310 are re-indexed as 1-250. The percentage of
    correct responses is computed for every (group, trial), and the OLS model
    ``correct ~ trial * group`` is fitted with group 1 as the treatment-coding
    reference. Its summary is displayed, and each group's fitted line is
    drawn and printed as ``Y = a +/- b*X``.

    Parameters
    ----------
    data_path : str
        Trial-level CSV with ``sid``, ``group``, ``trial`` and ``correct``.
    figname, save_to, save_as : str
        Figure name, output directory and file format; nothing is saved when
        ``save_as`` is falsy.
    """
    # Load data, select columns
    df = pd.read_csv(data_path)[['sid', 'group', 'trial', 'correct']]

    # Select free-play trials (copy so the shift below doesn't write into a view)
    df = df.loc[(df.trial <= 60 + 250) & (df.trial > 60)].copy()
    df['trial'] -= 60

    # Percentage of subjects who guessed correctly for each trial in each group
    df = df.groupby(['group', 'trial'])[['correct']].mean()
    df['correct'] = df.correct * 100

    # Fit linear regression (group 1 is the contrast reference)
    lm = ols('correct ~ trial*C(group,Treatment(reference=1))', data=df.reset_index()).fit()
    display(lm.summary())

    # Make figure
    fig = plt.figure(figname, figsize=[5, 3])
    ax = vut.pretty(fig.add_subplot(111))

    for grp in [0, 1]:
        # Display raw percentages
        x = np.arange(1, 251)
        y = df.loc[(grp, slice(None)), :].values.squeeze()
        ax.plot(x, y, color=gcolors[grp], ls='', alpha=.3, marker='.')

        # Display model predictions
        x_ = np.array([1, 251])
        y_ = lm.predict({'group': (grp, grp), 'trial': x_})
        ax.plot(x_, y_, color=gcolors[grp], lw=2, alpha=.9, label=fullglabels[grp])

        # Fitted line equation for this group, read off the model at trial 0
        # and 1. The old code used params[0] + params[1] * grp, but with group 1
        # as the reference level params[1]/params[3] are the group-0 offsets,
        # so the two printed equations were swapped.
        at0, at1 = lm.predict({'group': (grp, grp), 'trial': np.array([0, 1])})
        intercept, slope = at0, at1 - at0
        sign = '+' if slope >= 0 else '-'
        txt = 'Y = {:.3f} {} {:.3f}*X'.format(intercept, sign, np.abs(slope))
        print('group {}: {}'.format(grp, txt))

    ax.set_xlabel('Trial', fontsize=14)
    ax.set_ylabel('% correct', fontsize=14)

    leg = ax.legend(fontsize=14, ncol=2)
    vut.color_legend(leg)

    ax.set_ylim(55, 85)
    ax.set_xlim(1, 250)
    fig.tight_layout()

    if save_as:
        vut.save_it(fig, save_to, figname, save_as=save_as, compress=False)
        

make_bottom_fig(
    data_path='data/clean_data.csv',
    figname='figure2b_bottom',
    save_to='figures',
    save_as='',  # File format (png, jpeg, svg, ...); empty string skips saving
)

# Self-challenge index

## Relationship with *flat* final performance

In [None]:
def make_fig(data_path, figname, save_to, save_as=''):
    """Final performance vs. self-challenge (SC), with per-NAM marginals.

    Panels:

    * the central scatter shows ``fpc`` against ``sc_lep`` (filled markers
      for group 1, open for group 0, colored by NAM);
    * the top and right panels hold relative-frequency polygons of ``sc_lep``
      and ``fpc`` per NAM and group;
    * a quadratic OLS fit ``fpc ~ ipc + sc_lep + sc_lep**2 + group`` and its
      confidence band are overlaid, with ``ipc`` and ``group`` held at their
      sample means.

    Also displayed:

    * the summary of ``dwfpc ~ dwipc + group + sc_lep + sc_lep**2`` (with
      standardized ``sc_lep``);
    * its AIC difference to the same model without the squared term;
    * a ``sc_lep ~ group * nam`` model;
    * subject counts per group and NAM.

    Parameters
    ----------
    data_path : str
        Learning-data CSV (``sid``, ``group``, ``nam``, ``sc_lep``, ``fpc``,
        ``ipc``, ``dwfpc``, ``dwipc``).
    figname, save_to, save_as : str
        Figure name, output directory and file format; nothing is saved when
        ``save_as`` is empty.
    """
    df = pd.read_csv(data_path)
    # Copy: `sc_lep` is standardized in place further down
    df = df.loc[df.nam > 0, :].copy()

    # Share of group-1 subjects, used as the "average" group value for the fit line
    propS = np.sum(df.group == 1) / df.shape[0]

    fig = plt.figure(figname, figsize=[7, 7])
    gs = gridspec.GridSpec(2, 2)

    # Invisible "ghost" axes carry the shared axis labels; the marginal
    # polygons live in three inset axes stacked inside each ghost.
    ghost_top = fig.add_subplot(gs[0, 0])
    ghost_top.set_ylabel('Relative frequency', fontsize=14, labelpad=30)
    ghost_top.tick_params(left=False, labelleft=False, bottom=False, labelbottom=False)
    for spine in ghost_top.spines.values():
        spine.set_visible(False)

    ax_top1 = vut.pretty(inset_axes(ghost_top, width='100%', height='30%', loc=9, borderpad=0))
    ax_top2 = vut.pretty(inset_axes(ghost_top, width='100%', height='30%', loc=10, borderpad=0))
    ax_top3 = vut.pretty(inset_axes(ghost_top, width='100%', height='30%', loc=8, borderpad=0))

    ghost_right = fig.add_subplot(gs[1, 1])
    ghost_right.set_xlabel('Relative frequency', fontsize=14, labelpad=30)
    ghost_right.tick_params(left=False, labelleft=False, bottom=False, labelbottom=False)
    for spine in ghost_right.spines.values():
        spine.set_visible(False)
    ax_right1 = vut.pretty(inset_axes(ghost_right, width='30%', height='100%', loc=6, borderpad=0))
    ax_right2 = vut.pretty(inset_axes(ghost_right, width='30%', height='100%', loc=10, borderpad=0))
    ax_right3 = vut.pretty(inset_axes(ghost_right, width='30%', height='100%', loc=7, borderpad=0))

    ax_scat = vut.pretty(fig.add_subplot(gs[1, 0]))

    bins = np.arange(0, 1.02, .1)
    axes_by_nam = {
        'top': {3: ax_top1, 2: ax_top2, 1: ax_top3},
        'right': {1: ax_right1, 2: ax_right2, 3: ax_right3}}
    for nam in [1, 2, 3]:
        ax_top = axes_by_nam['top'][nam]
        ax_right = axes_by_nam['right'][nam]
        color = ncolors[nam - 1]
        ax_top.set_xlim(.0, .8)
        ax_top.set_ylim(0., .4)
        ax_right.set_xlim(0., .45)
        ax_right.set_ylim(0.38, 1.)
        for group in [0, 1]:
            sel = (df.nam == nam) & (df.group == group)
            x = df.loc[sel, 'sc_lep']
            y = df.loc[sel, 'fpc']
            ls = '-' if group else '--'
            label = '{} / NAM-{}'.format(glabels[group], nam)
            ax_scat.scatter(x, y, s=30, alpha=.7,
                            facecolors=color if group else 'w',
                            edgecolors=color)

            rf, _ = np.histogram(x, bins=bins, weights=np.ones_like(x) / x.size)
            ax_top.plot(bins[:-1], rf, c=color, lw=2, ls=ls, label=label)

            # Weights built from `y` itself (they used to be built from `x`)
            rf, _ = np.histogram(y, bins=bins, weights=np.ones_like(y) / y.size)
            ax_right.plot(rf, bins[1:], c=color, lw=2, ls=ls, label=label)

        # Tick cosmetics and NAM tag, once per NAM (were drawn once per group)
        ax_top.tick_params(labelbottom=False)
        ax_top.text(.02, .9, 'NAM-{}'.format(nam), ha='left', va='top', fontsize=12,
                    color=color, transform=ax_top.transAxes)
        ax_right.tick_params(labelleft=False)
        ax_right.text(.95, .02, 'NAM-{}'.format(nam), ha='right', va='bottom', fontsize=12,
                      color=color, transform=ax_right.transAxes)

    ax_scat.set_xlim(.0, .8)
    ax_scat.set_ylim(0.38, 1.)
    ax_scat.set_xlabel('Self-challenge (SC)', fontsize=14)
    ax_scat.set_ylabel('Mean performance', fontsize=14)

    # Edit legend: marker + line proxy per group
    c = 'darkgray'
    mark_ig = lines.Line2D([0], [0], ls='', marker='o', label=fullglabels[0], markerfacecolor='w', markeredgecolor=c)
    line_ig = lines.Line2D([0], [0], color=c, lw=2, label=fullglabels[0], ls='--', dashes=(2, 1))

    mark_eg = lines.Line2D([0], [0], color=c, ls='', marker='o', label=fullglabels[1])
    line_eg = lines.Line2D([0], [0], color=c, lw=2, label=fullglabels[1])

    ax_scat.legend(((line_ig, mark_ig), (line_eg, mark_eg)), fullglabels.values(),
                   bbox_to_anchor=(.5, 1.1,),
                   fontsize=12, ncol=3, loc='center',
                   handler_map={tuple: legend_handler.HandlerTuple(ndivide=None)})

    # Plot line of best fit for unstandardized data
    qreg = ols('fpc ~ (ipc + {0} + np.power({0}, 2) + group)'.format('sc_lep'), data=df).fit()
    x = np.linspace(df.loc[:, 'sc_lep'].min(), df.loc[:, 'sc_lep'].max(), 100)
    y_hat = qreg.get_prediction({'sc_lep': x,
                                 # The covariate in this model is `ipc`, so hold
                                 # it at the mean of `ipc` (was `dwipc`).
                                 'ipc': np.full_like(x, df.ipc.mean()),
                                 'group': np.full_like(x, propS)
                                 }).summary_frame()
    display(y_hat.head())
    c, alpha = 'k', .7
    ax_scat.plot(x, y_hat['mean'], c=c, alpha=alpha)
    ax_scat.plot(x, y_hat['mean_ci_lower'], c=c, lw=1, ls='--', alpha=alpha)
    ax_scat.plot(x, y_hat['mean_ci_upper'], c=c, lw=1, ls='--', alpha=alpha)

    # Quadratic regression of final performance (x standardized first)
    df['sc_lep'] = scs.zscore(df['sc_lep'])
    qreg = ols('dwfpc ~ dwipc + group + sc_lep + np.power(sc_lep, 2)', data=df).fit()
    display(qreg.summary())

    # Same model without the squared term. AIC is only comparable between
    # models of the same response, so this now uses `dwfpc ~ dwipc` like the
    # quadratic model (it used to be `fpc ~ ipc`).
    nonqreg = ols('dwfpc ~ dwipc + group + sc_lep', data=df).fit()
    print('Delta AIC = {:.2f}'.format(qreg.aic - nonqreg.aic))

    # Run model of average SC as a function of Group x NAM
    lreg = ols('sc_lep ~ C(group) * C(nam)', data=df).fit()
    display(lreg.summary())

    # Show group and subgroup counts
    display(df.groupby(['group', 'nam'])['sid'].agg('count'))

    # Save figure
    fig.tight_layout()
    if save_as:
        vut.save_it(fig, save_to, figname=figname, save_as=save_as, compress=False, dpi=100)


make_fig(
    data_path='data/learning_data.csv',
    figname='sm_fig4',
    save_to='figures',
    save_as='',  # File format (png, jpeg, svg, ...); empty string skips saving
)

## Relationship with activity choices

In [None]:
def main(model_data_path, learning_data_path, figname, save_to, save_as=''):
    """Grid of pairwise activity-preference vs. average-SC regressions.

    For every ordered pair of distinct activities (row, column), the panel
    regresses ``sc_flat`` on the difference between the two activities'
    ``abst`` values at trial 250. Diagonal panels are hidden.

    Parameters
    ----------
    model_data_path : str
        CSV indexed by ``sid`` with ``trial`` and ``abst1``..``abst4``.
    learning_data_path : str
        CSV indexed by ``sid`` with ``sc_flat``.
    figname, save_to, save_as : str
        Figure name, output directory and file format; nothing is saved when
        ``save_as`` is empty.
    """
    acts = ['A1', 'A2', 'A3', 'A4']
    renamer = {'abst{}'.format(k): name for k, name in enumerate(acts, start=1)}

    # Keep the values recorded at trial 250, one row per subject
    model_df = pd.read_csv(model_data_path, index_col='sid')
    model_df = model_df.filter(items=['trial'] + list(renamer)).set_index('trial', append=True)
    model_df = model_df.loc[(slice(None), 250), :]
    model_df.index = model_df.index.droplevel(1)
    model_df = model_df.rename(columns=renamer)

    sc = pd.read_csv(learning_data_path, index_col='sid').filter(items=['sc_flat'])
    df = model_df.merge(sc, on='sid')
    display(df.head())

    fig = plt.figure(num=figname, figsize=[8, 8])
    grid = GridSpec(4, 4)

    # Full-figure invisible axes that only carries title and shared labels
    ghost = fig.add_subplot(111)
    ghost.set_title('Correlations of activity preferences with average SC', pad=10)
    ghost.set_ylabel('Average SC', fontsize=14, labelpad=35)
    ghost.set_xlabel('Choice preference', fontsize=14, labelpad=30)
    ghost.tick_params(left=False, labelleft=False, bottom=False, labelbottom=False)
    for spine in ghost.spines.values():
        spine.set_visible(False)

    for row, act1 in enumerate(acts):
        for col, act2 in enumerate(acts):
            ax = vut.pretty(fig.add_subplot(grid[row, col]))
            ax.set_ylim(0, 1)
            # Tick labels only on the outer edge of the grid
            if row != 3:
                ax.tick_params(labelbottom=False)
            if col:
                ax.tick_params(labelleft=False)
            if row == col:
                ax.axis('off')
                continue
            sns.regplot(
                x=df[act1] - df[act2], y=df['sc_flat'], ax=ax, color='k',
                scatter_kws={'alpha': .3, 's': 3}
            )
            ax.set_ylabel('')
            ax.text(0, .83, '{} - {}'.format(act1, act2), ha='center')

    if save_as:
        vut.save_it(fig, save_to, figname=figname, save_as=save_as, compress=False, dpi=100)
    

main(
    model_data_path='data/model_data.csv',
    learning_data_path='data/learning_data.csv',
    figname='response_fig3',
    save_to='figures',
    save_as='png',  # File format (png, jpeg, svg, ...); empty string skips saving
)

# Learning and preference for A3 vs A1

In [None]:
def make_fig(data_path, data_path2, figname, save_to, save_as=''):
    """Scatter of A3 selection vs. ``dwipc``, plus related regressions.

    Displayed/printed output:

    * ``pref ~ dwipc``, where ``pref`` is the standardized A3 value at
      trial 250;
    * ``dwfpc ~ dwipc + pref + group``;
    * the AIC difference between that linear model and the same model with
      an added ``pref**2`` term;
    * subject counts per group and NAM.

    Parameters
    ----------
    data_path : str
        Learning-data CSV (``sid``, ``nam``, ``group``, ``dwipc``, ``dwfpc``).
    data_path2 : str
        Model-data CSV indexed by ``sid`` with ``trial`` and
        ``abst1``..``abst4``; values at trial 250 are used.
    figname, save_to, save_as : str
        Figure name, output directory and file format; nothing is saved when
        ``save_as`` is empty.
    """
    df2 = pd.read_csv(data_path2, index_col='sid')
    df2 = df2.filter(items='trial,abst1,abst2,abst3,abst4'.split(',')).set_index('trial', append=True)
    df2 = df2.loc[(slice(None), 250), :]
    df2.index = df2.index.droplevel(1)
    df2 = df2.rename(columns={'abst1': 'A1', 'abst2': 'A2', 'abst3': 'A3', 'abst4': 'A4'})
    display(df2.head())

    df = pd.read_csv(data_path).merge(df2, on='sid')
    # Copy: `pref` is added and later overwritten on this frame
    df = df.loc[df.nam > 0, :].copy()
    df['pref'] = df.A3

    fig = plt.figure(figname, figsize=[4, 4])
    ax_scat = vut.pretty(fig.add_subplot(111))

    for nam in [1, 2, 3]:
        for group in [0, 1]:
            sel = (df.nam == nam) & (df.group == group)
            ax_scat.scatter(df.loc[sel, 'pref'], df.loc[sel, 'dwipc'], s=30, alpha=.7,
                            facecolors=ncolors[nam - 1] if group else 'w',
                            edgecolors=ncolors[nam - 1])

    ax_scat.set_ylim(0.38, 1.)
    ax_scat.set_xlim(0, 250)
    # x is the raw A3 column (`pref = df.A3`); the label used to say "A4 vs. A1"
    ax_scat.set_xlabel('Selection of A3', fontsize=14)
    # NOTE(review): y is `dwipc`; confirm 'Mean performance' is the intended label
    ax_scat.set_ylabel('Mean performance', fontsize=14)

    # Edit legend: marker + line proxy per group
    c = 'darkgray'
    mark_ig = lines.Line2D([0], [0], ls='', marker='o', label=fullglabels[0], markerfacecolor='w', markeredgecolor=c)
    line_ig = lines.Line2D([0], [0], color=c, lw=2, label=fullglabels[0], ls='--', dashes=(2, 1))

    mark_eg = lines.Line2D([0], [0], color=c, ls='', marker='o', label=fullglabels[1])
    line_eg = lines.Line2D([0], [0], color=c, lw=2, label=fullglabels[1])

    ax_scat.legend(((line_ig, mark_ig), (line_eg, mark_eg)), fullglabels.values(),
                   bbox_to_anchor=(.5, 1.1,),
                   fontsize=12, ncol=3, loc='center',
                   handler_map={tuple: legend_handler.HandlerTuple(ndivide=None)})

    # Standardize the preference before fitting
    df['pref'] = scs.zscore(df['pref'])

    # Association between initial performance and A3 preference
    prefreg = ols('pref ~ dwipc', data=df).fit()
    display(prefreg.summary())

    # Linear vs. quadratic effect of A3 preference on final performance. AIC is
    # only comparable between models of the same response; the old code
    # compared `pref ~ dwipc` against `dwfpc ~ ...`.
    nonqreg = ols('dwfpc ~ (dwipc + pref + group)', data=df).fit()
    display(nonqreg.summary())
    qreg = ols('dwfpc ~ dwipc + pref + np.power(pref, 2) + group', data=df).fit()
    print('Delta AIC = {:.2f}'.format(qreg.aic - nonqreg.aic))

    # Show group and subgroup counts
    display(df.groupby(['group', 'nam'])['sid'].agg('count'))

    # Save figure
    fig.tight_layout()
    if save_as:
        vut.save_it(fig, save_to, figname=figname, save_as=save_as, compress=False, dpi=100)


make_fig(
    data_path='data/learning_data.csv',
    data_path2='data/model_data.csv',
    figname='response_fig4',
    save_to='figures',
    save_as='',  # File format (png, jpeg, svg, ...); empty string skips saving
)

# Self-reports

## Average ratings by instruction and NAM

In [None]:
def make_fig(data_path, nam_data_path, item, norm, figname, save_to, save_as=None, ylabel=None):
    """Plot mean +/- SEM self-report ratings per activity, by group and NAM.

    Parameters
    ----------
    data_path : str
        Ratings CSV with ``sid``, ``group``, ``item``, ``activity``, ``rating``
        and ``rating_norm``.
    nam_data_path : str
        CSV mapping ``sid`` to ``nam``.
    item : str
        Value of the ``item`` column to plot.
    norm : bool
        Plot ``rating_norm`` instead of ``rating``.
    figname, save_to, save_as : str
        Figure name, output directory and file format; nothing is saved when
        ``save_as`` is falsy.
    ylabel : str, optional
        Y-axis label of the first panel. Defaults to
        ``'Subjective interest\\n(Mean and SEM)'`` (the previous fixed label).
    """
    # Load data
    df = pd.read_csv(data_path)
    df = df.merge(pd.read_csv(nam_data_path).loc[:, ['sid', 'nam']], on='sid')
    df = df.loc[df.item.eq(item) & df.nam.gt(0), :]

    if norm:
        # Build a new frame instead of writing into a filtered view (the old
        # `.loc[:, 'rating'] =` triggered SettingWithCopy / dtype warnings)
        df = df.assign(rating=df.rating_norm).drop(columns='rating_norm')

    if ylabel is None:
        ylabel = 'Subjective interest\n(Mean and SEM)'

    # Calculate average scores
    df = df.groupby(['group', 'nam', 'activity'])[['rating']].agg(['mean', 'sem'])
    df.columns = df.columns.droplevel(0)
    display(df)

    # Plot results
    fig = plt.figure(figname, figsize=[8, 3])
    x = np.array([1, 2, 3, 4])
    for nam in [1, 2, 3]:
        ax = vut.pretty(fig.add_subplot(1, 3, nam), 'y')
        for grp in [1, 0]:
            # Small horizontal offset so the two groups' error bars don't overlap
            x_ = x + [-.1, .1][grp]
            y = df.loc[(grp, nam, slice(None)), 'mean']
            yerr = df.loc[(grp, nam, slice(None)), 'sem']
            ax.errorbar(x_, y, yerr=yerr, color=gcolors[grp], marker=gmarkers[grp],
                        capsize=5, markersize=8, lw=2, label=fullglabels[grp])
        ax.set_xticks(x)
        ax.set_xticklabels(['A1', 'A2', 'A3', 'A4'], fontsize=12, fontweight='bold')
        for xt, c in zip(ax.get_xticklabels(), colors):
            xt.set_color(c)

        ax.set_title('NAM {}'.format(nam), fontsize=12)
        if nam == 1:
            ax.set_ylabel(ylabel, fontsize=12)
        if nam == 2:
            ax.set_xlabel('Learning activity', fontsize=12)

    # Legend only on the last panel
    ax.legend(fontsize=12)

    fig.tight_layout()
    if save_as:
        vut.save_it(fig, save_to, figname, save_as=save_as, compress=False)


# NOTE(review): item='time' is plotted under the 'interest_by_nam' name with a
# 'Subjective interest' y-label, while interest ratings elsewhere in this
# notebook are selected with item == 'int'; confirm which item is meant.
make_fig(
    data_path='data/combined_extra.csv',
    nam_data_path='data/nam_data.csv',
    item='time',
    norm=True,
    figname='interest_by_nam',
    save_to='figures',
    save_as='',  # File format (png, jpeg, svg, ...); empty string skips saving
)

## Relationship between time and interest by group

In [None]:
def prepare_data(data_path, ratings_data_path):
    """Join end-of-free-play activity allocations with interest ratings.

    Returns one row per (subject, activity). The ``abst<k>`` value at trial
    250 is stored as ``time`` and labelled ``A<k>``, and the row carries the
    subject's ratings for that activity where ``item == 'int'``. Subjects
    with ``nam == 0`` are excluded.
    """
    wide = pd.read_csv(data_path).filter(
        items=['sid', 'nam', 'trial', 'abst1', 'abst2', 'abst3', 'abst4'])
    at_end = wide.trial.eq(250) & wide.nam.gt(0)
    wide = wide.loc[at_end, :].drop(columns='trial')

    # abst1..abst4 -> one row per activity; the numeric suffix becomes `act_ind`
    long_df = (
        pd.wide_to_long(wide, stubnames='abst', i=['sid', 'nam'], j='act_ind')
        .reset_index()
        .assign(activity=lambda d: 'A' + d.act_ind.astype(str))
        .drop(columns='act_ind')
        .rename(columns={'abst': 'time'})
    )

    ratings = pd.read_csv(ratings_data_path)
    interest = ratings.loc[ratings.item.eq('int'), :].drop(columns='item')
    return long_df.merge(interest, on=['sid', 'activity'])


def subplot_a(ax, df):
    """Regress mean-centered interest (``rating_norm``) on ``time``, one line per group."""
    for grp in (0, 1):
        rows = df.group.eq(grp)
        sns.regplot(x=df.loc[rows, 'time'], y=df.loc[rows, 'rating_norm'],
                    color=gcolors[grp], ax=ax, scatter_kws={'alpha': .3, 's': 5})
    # Set after plotting so regplot's automatic labels are overridden
    ax.set_xlabel('Number of trials')
    ax.set_ylabel('Interest rating\n(Mean centered)')
        
        
def subplot_b(axes, df):
    """One panel per group (``axes[0]``, ``axes[1]``): interest vs. time, a line per NAM."""
    for grp in (0, 1):
        ax = axes[grp]
        in_grp = df.group.eq(grp)
        for nam in (1, 2, 3):
            rows = in_grp & df.nam.eq(nam)
            sns.regplot(x=df.loc[rows, 'time'], y=df.loc[rows, 'rating_norm'],
                        color=ncolors[nam - 1], ax=ax,
                        scatter_kws={'alpha': .3, 's': 5})
        ax.set_ylabel('Interest rating\n(Mean centered)')
        ax.set_xlabel('Number of trials')
        ax.set_title(fullglabels[grp], color=gcolors[grp], fontweight='bold')
            


def make_fig(modeling_data_path, ratings_data_path, figname, save_to, save_as=None):
    """Figure and OLS models relating interest ratings to time per activity.

    Left panel: both groups together (model ``rating ~ C(group) * time``).
    Right two panels: one per group, split by NAM (model
    ``rating ~ C(nam) * time`` within each group).
    """
    df = prepare_data(modeling_data_path, ratings_data_path)
    fig = plt.figure(figname, figsize=[9, 3])

    # NOTE(review): the panels plot `rating_norm`, but the models below use the
    # raw `rating`; confirm this is intended.
    ax_all = vut.pretty(fig.add_subplot(1, 3, 1))
    subplot_a(ax_all, df)
    display(ols('rating ~ C(group) * time', data=df).fit().summary())

    by_group_axes = [vut.pretty(fig.add_subplot(1, 3, k)) for k in (2, 3)]
    subplot_b(by_group_axes, df)
    for grp in (0, 1):
        print('===' * 20)
        subset = df.loc[df.group.eq(grp), :]
        display(ols('rating ~ C(nam) * time', data=subset).fit().summary())

    fig.tight_layout()
    if save_as:
        vut.save_it(fig, save_to, figname, save_as=save_as, compress=False)


make_fig(
    modeling_data_path='data/model_data.csv',
    ratings_data_path='data/combined_extra.csv',
    figname='interest_time_reg',
    save_to='figures',
    save_as='',  # File format (png, jpeg, svg, ...); empty string skips saving
)

# Modeling

## Individual model predictions

In [None]:
def main(data_path, sid, n_max, n_stop, init_dict, figname, save_to='', save_as=''):
    """Fit the softmax choice model to one subject and visualize the fit.

    Panels, top to bottom:

    * one per utility component (``model.fit_data``);
    * the combined utility, with the normalized-weight equation as title;
    * the predicted choice probabilities;
    * the subject's actual choices and correct responses.

    Parameters
    ----------
    data_path : str
        Model-data CSV passed to ``lut.get_sdf`` together with ``sid``.
    sid : int
        Subject id.
    n_max, n_stop : int
        Forwarded to ``model.n_best_stop`` as ``max_iter`` and ``n_stop``.
    init_dict : dict
        Parameter initialization spec forwarded to ``SoftmaxChoiceModel``.
    figname, save_to, save_as : str
        Figure name, output directory and file format; nothing is saved when
        ``save_as`` is empty.
    """
    # Load data (this used to ignore `data_path` and always read
    # 'data/model_data.csv')
    sdf = lut.get_sdf(data_path, sid)
    resps = sdf.correct.values.astype(float)

    # Create model
    model = SoftmaxChoiceModel(
        objective=neg_log_likelihood,
        data=sdf,
        init_dict=init_dict,
        hist=True,
    )
    n_comp = len(model.components)
    letters = 'abcdefghijklmnopqrstuvwxyz'

    # Layout: one row per component, then utility and probability rows (full
    # height) and the raw-choice strip (half height). Height ratios and panel
    # letters used to be hard-coded for exactly two components.
    fig = plt.figure(figname, figsize=[7.5, 10])
    gs = GridSpec(n_comp + 3, 2, height_ratios=[1] * (n_comp + 2) + [.5], width_ratios=[.05, 1])

    def new_row(row):
        # Put the panel letter in the narrow left column; return the main axes
        vut.add_subplot_label(x=1, y=1, label=letters[row], size=20, ax=fig.add_subplot(gs[row, 0]))
        return vut.pretty(fig.add_subplot(gs[row, 1]))

    # Transform data
    model.transform_inp_data(normalize)

    # Fit
    model.n_best_stop(n_stop=n_stop, max_iter=n_max, show_progress=True)
    print(model.params[-1])

    # Plot utility components
    for j, comp in enumerate(model.components):
        ax = new_row(j)
        ax.set_prop_cycle('color', colors)
        ax.plot(model.fit_data[j])
        ax.set_ylabel('Recent ' + comp[1:].upper())
        ax.set_xlim(1, 250)
        ax.set_ylim(-.05, 1.05)
        ax.axhline(0, color='gray', ls=':')
        ax.tick_params(labelbottom=False)

    # Plot utility, titled with the unit-norm weight equation
    u, p = model.get_predictions()
    ax = new_row(n_comp)
    ax.set_prop_cycle('color', colors)
    ax.plot(u)
    ax.set_xlim(1, 250)
    ax.tick_params(labelbottom=False)
    ax.set_ylabel('Utility')
    weights = model.params[:-1]  # all but the last parameter
    norm_weights = weights / np.sqrt(np.sum(weights * weights))
    title = r'$U$ ='
    for k, (w, comp) in enumerate(zip(norm_weights, model.components)):
        sign = r'$-$ ' if w < 0 else r'$+$ '
        if k == 0 and w >= 0:
            sign = ''
        title += fr' {sign}{np.abs(w):.2f}$\times {comp[1:].upper()}$'
    ax.set_title(title, fontsize=14, loc='right', pad=-5)

    # Plot predicted choice probabilities
    ax = new_row(n_comp + 1)
    ax.set_prop_cycle('color', colors)
    ax.plot(p)
    ax.set_xlim(1, 250)
    ax.set_ylim(0, 1.05)
    ax.tick_params(labelbottom=False)
    ax.set_ylabel('Probability')

    legend_elements = [Patch(facecolor=c, label=l) for c, l in zip(colors, tlabels.values())]
    ax.legend(handles=legend_elements, bbox_to_anchor=[0.5, 1.1], loc='center', ncol=4)

    # Plot subject data: one colored band per activity where it was chosen
    ax = new_row(n_comp + 2)
    ax.set_xlim(1, 250)
    # x must match the number of rows of choice_data (it used to be
    # np.arange(1, 250), i.e. always 249 points)
    x = np.arange(1, model.choice_data.shape[0] + 1)
    for i in range(4):
        y = np.full(x.shape, .2)
        y[~model.choice_data[:, i].astype(bool)] = np.nan
        ax.plot(x, y, c=colors[i], lw=10)
    # Show only correct responses (NaN hides the rest)
    resps[~resps.astype(bool)] = np.nan
    # NOTE(review): responses are plotted against their 0-based index, while the
    # choice bands above start at x=1; confirm the intended alignment.
    ax.plot(resps * .4, ls='', marker='|', color='k')
    ax.set_ylim(0, 1)
    vut.despine(ax, ['left', 'right', 'top'])
    ax.set_ylabel('Actual choices\nand responses', labelpad=25)
    ax.set_xlabel('Trial')
    ax.tick_params(labelleft=False, left=False)
    legend_elements = [
        Line2D([0], [0], ls='', marker='|', color='k', label='Correct guesses')
    ]
    ax.legend(handles=legend_elements, bbox_to_anchor=[0.5, 1], loc='center')

    fig.tight_layout()
    fig.subplots_adjust(hspace=.3, right=.90)
    if save_as:
        vut.save_it(fig, save_to, figname,
                    save_as=save_as, compress=False, dpi=100)

    
# Parameter initialization: 'param_handle': (init_range, apply_boundary)
init_spec = {
    'rpc': ([-1, 1], True),
    'rlp': ([-1, 1], True),
    'tau': ([0, 100], True),
}
main(
    data_path='data/model_data.csv',
    sid=104,
    n_max=500,
    n_stop=10,
    init_dict=init_spec,
    figname='sm_fig1',
    save_to='figures',
    save_as='',  # empty string skips saving
)

## Group level coefficients and behavior

Effect size conventions (Cohen's *d*), for reference:

|Effect size|*d*|
|---|---|
|Very small|0.01|
|Small|0.20|
|Medium|0.50|
|Large|0.80|
|Very large|1.20|
|Huge|2.0|

### Full version
(includes all groups)

In [None]:
def load_data(heuristics_data_path, params_data_path, learning_data_path, nq=1, nam=None):
    """Load the heuristics, parameter-fit and learning datasets.

    Parameters
    ----------
    heuristics_data_path, params_data_path, learning_data_path : str or file-like
        CSV sources passed to ``pd.read_csv``.
    nq : int
        Number of equal-width bins (``pd.cut``) used to label the normalized
        weights. Despite the ``qi_`` prefix these are not quantiles.
    nam : optional
        When not None, keep only rows whose ``nam`` equals this value in the
        heuristics and parameter datasets.

    Returns
    -------
    dict
        * ``'p'``: parameter fits with ``vars == 'rpc,rlp'``, plus the weights
          scaled to unit L2 norm (``norm_rpc``, ``norm_rlp``) and their bin
          labels (``qi_rpc``, ``qi_rlp``);
        * ``'h'``: ``sid``, ``group``, ``trial`` and the ``rpc*``, ``rlp*``
          and ``ch*`` columns for activities 1-4;
        * ``'l'``: ``sid``, ``dwfpc``, ``dwipc``.
    """
    # Heuristics dataset
    hdf = pd.read_csv(heuristics_data_path)
    # Params dataset (only fits of the rpc+rlp model)
    pdf = pd.read_csv(params_data_path)
    pdf = pdf.loc[pdf.vars.eq('rpc,rlp'), :]
    # Learning dataset
    ldf = pd.read_csv(learning_data_path)

    # Optionally, filter by NAM. This used to reference an undefined `df`
    # (NameError), and `if nam:` made NAM 0 impossible to select.
    if nam is not None:
        hdf = hdf.loc[hdf.nam.eq(nam), :]
        pdf = pdf.loc[pdf.nam.eq(nam), :]
    # Standalone copy so the columns below are not added to a view
    pdf = pdf.copy()

    # Scale each (rpc, rlp) weight vector to unit length, then bin each component
    norm = np.linalg.norm(pdf[['rpc', 'rlp']].to_numpy(), axis=1)
    pdf['norm_rpc'] = pdf.rpc / norm
    pdf['norm_rlp'] = pdf.rlp / norm
    pdf['qi_rpc'] = pd.cut(pdf.norm_rpc, bins=nq)
    pdf['qi_rlp'] = pd.cut(pdf.norm_rlp, bins=nq)

    # Select columns in each df
    pcols = ['sid', 'group', 'nam', 'rpc', 'rlp', 'norm_rpc', 'norm_rlp', 'qi_rpc', 'qi_rlp']
    per_activity = [v + k for v in ('rpc', 'rlp', 'ch') for k in '1234']
    hcols = ['sid', 'group', 'trial'] + per_activity
    lcols = ['sid', 'dwfpc', 'dwipc']
    return {'p': pdf.filter(items=pcols), 'h': hdf.filter(items=hcols), 'l': ldf.filter(items=lcols)}


def subfig_a(axes, df, qi_rlp, qi_rpc, bins=25):
    """Joint and marginal distributions of the normalized model weights.

    ``axes`` is (top marginal, joint scatter, right marginal). Points in the
    selected (``qi_rlp``, ``qi_rpc``) bin pair are drawn larger and more
    opaque, and their count is written at (0, 0). The selected bins are
    shaded in the marginals. Groups 0/1 are offset by -/+0.05 to reduce
    overlap.

    ``df`` is expected to carry categorical interval columns ``qi_rlp`` and
    ``qi_rpc`` (as produced by ``pd.cut`` in ``load_data``).
    """
    # Group offsets
    off = [-.05, .05]
    # Pick bins by position in the sorted category list. The old
    # `.unique()[i]` indexed bins in order of first appearance in the rows,
    # so the same index could select different bins depending on row order
    # (and empty bins shifted the positions).
    qrlp = df.qi_rlp.cat.categories[qi_rlp]
    qrpc = df.qi_rpc.cat.categories[qi_rpc]
    q_flt = df.qi_rlp.eq(qrlp) & df.qi_rpc.eq(qrpc)

    # Plot joint data
    ax = axes[1]
    for group in [0, 1]:
        grp_flt = df.group.eq(group)
        ax.scatter(
            df.loc[grp_flt & ~q_flt, 'norm_rlp'] + off[group],
            df.loc[grp_flt & ~q_flt, 'norm_rpc'] + off[group],
            alpha=.05, color=gcolors[group], marker='o', s=5)
        ax.scatter(
            df.loc[grp_flt & q_flt, 'norm_rlp'] + off[group],
            df.loc[grp_flt & q_flt, 'norm_rpc'] + off[group],
            alpha=.3, color=gcolors[group], marker='o', s=10)
    ax.text(0, 0, 'N = {:}'.format(q_flt.sum()), ha='center', va='center')

    # Plot marginal data
    sns.histplot(x='norm_rlp', data=df, ax=axes[0], stat='probability', bins=bins, element='step', color='gray')
    axes[0].axvspan(qrlp.left, qrlp.right, color='magenta', alpha=.2)
    sns.histplot(y='norm_rpc', data=df, ax=axes[2], stat='probability', bins=bins, element='step', color='gray')
    axes[2].axhspan(qrpc.left, qrpc.right, color='magenta', alpha=.2)

    # Labels
    for ax in (axes[0], axes[2]):
        ax.set_xlabel('')
        ax.set_ylabel('')
    axes[0].tick_params(labelbottom=False)
    axes[2].tick_params(labelleft=False, labelrotation=-90)
    axes[0].set_xlabel(r'$w_{LP}$')
    axes[0].xaxis.set_label_position('top')
    axes[2].set_ylabel(r'$w_{PC}$', rotation=-90, va='top')
    axes[2].yaxis.set_label_position('right')


def subfig_var(ax, df, variable, errb=False, sample_size=False, lw=2):
    """Plot the across-participant mean of ``<variable>1`` .. ``<variable>4`` per trial.

    Parameters
    ----------
    ax : matplotlib Axes
    df : pandas.DataFrame
        Long-format data with a ``trial`` column, the four columns
        ``<variable>1`` .. ``<variable>4`` and (if ``sample_size``) ``sid``.
    variable : str
        Column prefix, e.g. ``'rpc'``, ``'rlp'`` or ``'ch'``.
    errb : bool
        If True, shade mean +/- SEM around each line.
    sample_size : bool
        If True, title the axes with the number of unique ``sid`` values.
    lw : float
        Currently unused; lines are drawn with width 1 (kept for signature compatibility).
    """
    ax.set_prop_cycle(cycler(color=colors))
    value_cols = ['{}{}'.format(variable, k) for k in range(1, 5)]
    # Aggregate only the plotted columns: id/group columns are not meaningful
    # to average and may be non-numeric.
    grouped = df.groupby('trial')[value_cols].agg(['mean', 'sem'])
    y = grouped.xs('mean', axis=1, level=1)[value_cols].values
    yerr = grouped.xs('sem', axis=1, level=1)[value_cols].values
    # One x column per plotted series: 0..n_rows-1 (row positions, not trial labels)
    x = np.stack([np.arange(y.shape[0]) for _ in range(y.shape[1])], axis=1)
    ax.plot(x, y, lw=1)
    if errb:
        for i in range(x.shape[1]):
            ax.fill_between(x[:, i], y1=y[:, i]+yerr[:, i], y2=y[:, i]-yerr[:, i], alpha=.2)
    ax.set_xlim(0, 250)
    if sample_size:
        ax.set_title('N = {}'.format(len(df.sid.unique())))

        
def subfig_stats(ax, df, alpha=.05):
    """Bar plot comparing ``dwiPC`` and ``dwfPC``, annotated with Cohen's d.

    Parameters
    ----------
    ax : matplotlib Axes
    df : pandas.DataFrame
        Must contain ``dwipc`` and ``dwfpc``. Every column is melted into the bar
        plot, so it should hold only these two (the callers here index by ``sid``).
    alpha : float
        Significance level for the paired t-test (``scipy.stats.ttest_rel``);
        an asterisk is appended to d only when ``p < alpha``.
    """
    _, pval = scs.ttest_rel(df.dwipc, df.dwfpc)
    # Cohen's d: mean difference over the root-mean of the two sample variances
    d = (df.dwfpc.mean() - df.dwipc.mean()) / np.sqrt((df.dwfpc.std()**2 + df.dwipc.std()**2)/2)
    # Previously the star was printed unconditionally even though pval was computed
    star = r'^*' if pval < alpha else ''
    long_df = df.melt().replace({'dwfpc': 'dwfPC', 'dwipc': 'dwiPC'})
    sns.barplot(
        x='variable', y='value', data=long_df, order=['dwiPC', 'dwfPC'], ax=ax,
        linewidth=1, facecolor=(1, 1, 1, 0), errcolor='k', edgecolor='k'
    )
    ax.text(.5, .90, r'$d={:.3f}{}$'.format(d, star), va='top', ha='center')
    ax.set_xlabel('')
    ax.set_ylabel('')
    vut.change_width(ax, .6)
    

def make_fig(nq, figname, save_to, save_as=''):
    """Build the full weight-bin grid figure, one column per selected (w_LP, w_PC) cell.

    Rows: (a) joint scatter of normalised weights with marginals, current cell
    highlighted; (b) ``ch1..ch4`` over trials; (c) recent LP; (d) recent PC;
    (e) dwiPC vs dwfPC scores; last row: shared legend.

    Parameters
    ----------
    nq : int
        Bins per normalised weight, passed to ``load_data``. The cell list below
        indexes bins 0..2, so it is written for ``nq = 3``.
    figname : str
        Matplotlib figure name, also used when saving.
    save_to, save_as : str
        Output directory and file format; nothing is saved when ``save_as`` is empty.
    """
    # (qi_rlp, qi_rpc) positions, one per figure column. They index into
    # ``.unique()`` (order of first appearance), the same convention subfig_a
    # uses, so the highlighted cell and the plotted sids agree.
    cells = list(zip([0, 1, 2, 0, 2, 2, 0, 1], [0, 2, 2, 2, 0, 1, 1, 1]))
    nrows, ncols = 5 + 1, len(cells)
    fig = plt.figure(num=figname, figsize=[2.5 + 2*ncols, 8])
    gs = fig.add_gridspec(
        ncols=1 + ncols,
        nrows=nrows,
        # Narrow leading column holds the row letters
        width_ratios=[.1] + [1] * ncols,
        # Tall row (a), equal data rows, short legend row. (The former
        # ``np.ones(nrows-1) + [.25]`` added .25 to every row instead.)
        height_ratios=[2] + [1] * (nrows - 2) + [.25],
    )
    add = fig.add_subplot

    # Load data
    data = load_data(
        heuristics_data_path='data/model_data.csv',
        params_data_path='data/model_results/param_fits_clean.csv',
        learning_data_path='data/learning_data.csv',
        nam=None,
        nq=nq
    )

    # Annotate figure rows
    for i, letter in enumerate('abcde'):
        vut.add_subplot_label(x=0, y=1, label=letter, size=18, ax=fig.add_subplot(gs[i, 0]))

    qs_rlp = data['p'].qi_rlp.unique()
    qs_rpc = data['p'].qi_rpc.unique()
    # Index once; each column only needs .loc lookups by sid
    hdf = data['h'].set_index('sid')
    ldf = data['l'].set_index('sid')
    for ci, (i, j) in enumerate(cells, 1):
        # Subplot (a)
        main_ax = vut.pretty(add(gs[0, ci], aspect='equal'))
        divider = make_axes_locatable(main_ax)
        marg_ax1 = vut.pretty(divider.append_axes('top', '30%', pad=0.2, sharex=main_ax))
        marg_ax2 = vut.pretty(divider.append_axes('right', '30%', pad=0.2, sharey=main_ax))
        subfig_a([marg_ax1, main_ax, marg_ax2], data['p'], qi_rlp=i, qi_rpc=j)

        # Participants in this cell (used by subplots (b) to (e))
        sids = data['p'].loc[
                data['p'].qi_rlp.eq(qs_rlp[i]) & data['p'].qi_rpc.eq(qs_rpc[j])
            ].sid.unique()

        # Data for subplots (b) to (d)
        df = hdf.loc[sids, :].reset_index()

        # Subplot (b)
        ax = vut.pretty(add(gs[1, ci]))
        subfig_var(ax, df, variable='ch', errb=False)
        ax.set_ylim(0, 0.75)
        if ci == 1:
            ax.set_ylabel('% selection')

        # Subplot (c)
        ax = vut.pretty(add(gs[2, ci]))
        subfig_var(ax, df, variable='rlp', errb=True)
        ax.set_ylim(0.03, 0.25)
        if ci == 1:
            ax.set_ylabel('Recent LP')

        # Subplot (d)
        ax = vut.pretty(add(gs[3, ci]))
        subfig_var(ax, df, variable='rpc', errb=True)
        ax.set_ylim(0.45, 0.92)
        if ci == 1:
            ax.set_ylabel('Recent PC')

        # Subplot (e): learning scores of the same participants
        df = ldf.loc[sids, :]
        ax = vut.pretty(add(gs[4, ci]))
        subfig_stats(ax, df)
        ax.set_ylim(.5, .9)
        if ci == 1:
            ax.set_ylabel('score')

    # Add legend
    ax = vut.ghost(add(gs[5, 1:]))
    handles = [lines.Line2D([0], [0], color=colors[k], ls='', marker='o', label=tlabels[k + 1]) for k in range(4)]
    handles += [lines.Line2D([0], [0], ls='', marker='o', markerfacecolor=c, markeredgecolor=c, color=c) for c in gcolors]
    legw, legh = .25, .2
    leg = ax.legend(handles, list(tlabels.values())+list(fullglabels.values()), handletextpad=.05,
                    bbox_to_anchor=(.5-legw/2, .5, legw, legh), loc='center', mode='expand', ncol=3)
    vut.color_legend(leg)

    fig.tight_layout()
    fig.subplots_adjust(hspace=.4)
    if save_as:
        vut.save_it(fig, save_to, figname=figname, save_as=save_as, compress=False, dpi=100)
    
    
# Render the full grid; set save_as to a file format (png, jpeg, svg, ...) to also save it.
make_fig(nq=3, figname='sm_fig2a', save_to='figures', save_as='')

### Simplified version
(includes only the cardinal groups)

In [None]:
def load_data(heuristics_data_path, params_data_path, learning_data_path, nq=1, nam=None):
    """Load the heuristics, parameter-fit and learning datasets for the grid figures.

    Parameters
    ----------
    heuristics_data_path, params_data_path, learning_data_path : str
        CSV paths of the heuristics data, the parameter fits and the learning scores.
    nq : int
        Number of equal-width bins used to discretise each normalised weight.
    nam : optional
        If truthy, keep only rows whose ``nam`` equals it (heuristics and params data).

    Returns
    -------
    dict
        ``'p'``: fits of the ``rpc,rlp`` model with ``norm_rpc``/``norm_rlp``
        (weights divided by the Euclidean norm of ``(rpc, rlp)``) and their
        interval bins ``qi_rpc``/``qi_rlp``; ``'h'``: heuristics columns
        (sid, group, trial, rpc1-4, rlp1-4, ch1-4); ``'l'``: sid, dwfpc, dwipc.
        Each frame keeps only those of the listed columns that exist.
    """
    hdf = pd.read_csv(heuristics_data_path)
    pdf = pd.read_csv(params_data_path)
    # Only the two-weight model form is analysed here
    pdf = pdf.loc[pdf.vars.eq('rpc,rlp'), :]
    ldf = pd.read_csv(learning_data_path)
    # Optionally, filter by NAM (each frame by its own column)
    if nam:
        hdf = hdf.loc[hdf.nam.eq(nam), :]
        pdf = pdf.loc[pdf.nam.eq(nam), :]
    # Independent copy before adding derived columns
    pdf = pdf.copy()
    # Scale each (rpc, rlp) weight vector to unit length
    norm = np.linalg.norm(pdf[['rpc', 'rlp']].values, axis=1)
    pdf['norm_rpc'] = pdf.rpc / norm
    pdf['norm_rlp'] = pdf.rlp / norm
    # Equal-width bins (pd.qcut(..., q=nq) would give equal-count quantiles instead)
    pdf['qi_rpc'] = pd.cut(pdf.norm_rpc, bins=nq)
    pdf['qi_rlp'] = pd.cut(pdf.norm_rlp, bins=nq)
    # Select columns in each df
    pcols = ['sid', 'group', 'nam', 'rpc', 'rlp', 'norm_rpc', 'norm_rlp', 'qi_rpc', 'qi_rlp']
    hcols = ['sid', 'group', 'trial'] + [v + si for v in ['rpc', 'rlp', 'ch'] for si in '1234']
    lcols = ['sid', 'dwfpc', 'dwipc']
    return {'p': pdf.filter(items=pcols), 'h': hdf.filter(items=hcols), 'l': ldf.filter(items=lcols)}


def subfig_a(axes, df, qi_rlp, qi_rpc, bins=25):
    """Draw the joint scatter of normalised weights with its two marginal histograms.

    Parameters
    ----------
    axes : sequence of Axes
        ``axes[0]``: top marginal (histogram of ``norm_rlp``);
        ``axes[1]``: joint scatter; ``axes[2]``: right marginal (histogram of ``norm_rpc``).
    df : pandas.DataFrame
        Needs ``group``, ``norm_rlp``, ``norm_rpc`` and the interval-bin columns
        ``qi_rlp`` / ``qi_rpc``.
    qi_rlp, qi_rpc : int
        Positions into ``df.qi_rlp.unique()`` / ``df.qi_rpc.unique()`` selecting the
        highlighted cell. ``unique()`` keeps order of first appearance, so these are
        not necessarily sorted-bin indices.
    bins : int
        Number of bins of the marginal histograms.
    """
    top_ax, joint_ax, side_ax = axes[0], axes[1], axes[2]
    # Per-group offset added to both coordinates so the groups are visually separated
    group_shift = (-.05, .05)
    rlp_bin = df.qi_rlp.unique()[qi_rlp]
    rpc_bin = df.qi_rpc.unique()[qi_rpc]
    in_cell = df.qi_rlp.eq(rlp_bin) & df.qi_rpc.eq(rpc_bin)

    # Joint panel: faint points outside the selected cell, stronger points inside it
    for group, shift in enumerate(group_shift):
        in_group = df.group.eq(group)
        for mask, alpha, size in ((in_group & ~in_cell, .05, 5), (in_group & in_cell, .3, 10)):
            joint_ax.scatter(
                df.loc[mask, 'norm_rlp'] + shift,
                df.loc[mask, 'norm_rpc'] + shift,
                alpha=alpha, color=gcolors[group], marker='o', s=size)
    joint_ax.text(0, 0, 'N = {:}'.format(in_cell.sum()), ha='center', va='center')

    # Marginal panels, selected bin shaded
    sns.histplot(x='norm_rlp', data=df, ax=top_ax, stat='probability', bins=bins, element='step', color='gray')
    top_ax.axvspan(rlp_bin.left, rlp_bin.right, color='magenta', alpha=.2)
    sns.histplot(y='norm_rpc', data=df, ax=side_ax, stat='probability', bins=bins, element='step', color='gray')
    side_ax.axhspan(rpc_bin.left, rpc_bin.right, color='magenta', alpha=.2)

    # Clear seaborn's default labels, then name the weights on the outer edges
    for marginal in (top_ax, side_ax):
        marginal.set_xlabel('')
        marginal.set_ylabel('')
    top_ax.tick_params(labelbottom=False)
    side_ax.tick_params(labelleft=False, labelrotation=-90)
    top_ax.set_xlabel(r'$w_{LP}$')
    top_ax.xaxis.set_label_position('top')
    side_ax.set_ylabel(r'$w_{PC}$', rotation=-90, va='top', labelpad=15)
    side_ax.yaxis.set_label_position('right')

    
def subfig_var(ax, df, variable, errb=False, sample_size=False, lw=2):
    """Plot the across-participant mean of ``<variable>1`` .. ``<variable>4`` per trial.

    Parameters
    ----------
    ax : matplotlib Axes
    df : pandas.DataFrame
        Long-format data with a ``trial`` column, the four columns
        ``<variable>1`` .. ``<variable>4`` and (if ``sample_size``) ``sid``.
    variable : str
        Column prefix, e.g. ``'rpc'``, ``'rlp'`` or ``'ch'``.
    errb : bool
        If True, shade mean +/- SEM around each line.
    sample_size : bool
        If True, title the axes with the number of unique ``sid`` values.
    lw : float
        Currently unused; lines are drawn with width 1 (kept for signature compatibility).
    """
    ax.set_prop_cycle(cycler(color=colors))
    value_cols = ['{}{}'.format(variable, k) for k in range(1, 5)]
    # Aggregate only the plotted columns: id/group columns are not meaningful
    # to average and may be non-numeric.
    grouped = df.groupby('trial')[value_cols].agg(['mean', 'sem'])
    y = grouped.xs('mean', axis=1, level=1)[value_cols].values
    yerr = grouped.xs('sem', axis=1, level=1)[value_cols].values
    # One x column per plotted series: 0..n_rows-1 (row positions, not trial labels)
    x = np.stack([np.arange(y.shape[0]) for _ in range(y.shape[1])], axis=1)
    ax.plot(x, y, lw=1)
    if errb:
        for i in range(x.shape[1]):
            ax.fill_between(x[:, i], y1=y[:, i]+yerr[:, i], y2=y[:, i]-yerr[:, i], alpha=.2)
    ax.set_xlim(0, 250)
    if sample_size:
        ax.set_title('N = {}'.format(len(df.sid.unique())))
    
        
def subfig_choices(ax, df, n_trials=250):
    """Plot each group's mean (+/- SEM) per-participant ``ch1`` .. ``ch4`` totals as percentages.

    For every participant the columns ``ch1`` .. ``ch4`` are summed over rows,
    divided by ``n_trials`` and multiplied by 100. Mean and SEM across participants
    are drawn per group as error bars (circles for group 0, squares for group 1).

    Parameters
    ----------
    ax : matplotlib Axes
    df : pandas.DataFrame
        Needs ``sid``, ``group`` and ``ch1`` .. ``ch4`` columns.
    n_trials : int
        Denominator of the percentage (defaults to 250, the previous constant).
    """
    ch_cols = ['ch1', 'ch2', 'ch3', 'ch4']
    # Aggregate only the ch columns so the id never enters the group mean/SEM
    per_sid = df.groupby(['sid', 'group'])[ch_cols].sum() / n_trials * 100
    stats = per_sid.groupby(level='group').agg(['mean', 'sem'])
    stats.columns = stats.columns.map('_'.join)
    # Group offsets so the two error-bar series do not overlap
    off = [-.05, .05]
    for group in [0, 1]:
        mean = stats.loc[group, [f'{c}_mean' for c in ch_cols]]
        sem = stats.loc[group, [f'{c}_sem' for c in ch_cols]]
        ax.errorbar(x=np.arange(4)+off[group], y=mean, yerr=sem, color=gcolors[group], marker='os'[group])
    ax.set_ylim(0, 60)
    ax.set_xticks(np.arange(4))
    ax.set_xticklabels(['A1', 'A2', 'A3', 'A4'], fontweight='bold')
    for xt, c in zip(ax.get_xticklabels(), colors):
        xt.set_color(c)

        
def subfig_stats(ax, df, alpha=.05):
    """Bar plot comparing ``dwiPC`` and ``dwfPC``, annotated with Cohen's d.

    Parameters
    ----------
    ax : matplotlib Axes
    df : pandas.DataFrame
        Must contain ``dwipc`` and ``dwfpc``. Every column is melted into the bar
        plot, so it should hold only these two (the callers here index by ``sid``).
    alpha : float
        Significance level for the paired t-test (``scipy.stats.ttest_rel``);
        an asterisk is appended to d only when ``p < alpha``.
    """
    _, pval = scs.ttest_rel(df.dwipc, df.dwfpc)
    # Cohen's d: mean difference over the root-mean of the two sample variances
    d = (df.dwfpc.mean() - df.dwipc.mean()) / np.sqrt((df.dwfpc.std()**2 + df.dwipc.std()**2)/2)
    # Previously the star was printed unconditionally even though pval was computed
    star = r'^*' if pval < alpha else ''
    long_df = df.melt().replace({'dwfpc': 'dwfPC', 'dwipc': 'dwiPC'})
    sns.barplot(
        x='variable', y='value', data=long_df, order=['dwiPC', 'dwfPC'], ax=ax,
        linewidth=1, facecolor=(1, 1, 1, 0), errcolor='k', edgecolor='k'
    )
    ax.text(.5, .90, r'$d={:.3f}{}$'.format(d, star), va='top', ha='center')
    ax.set_xlabel('')
    ax.set_ylabel('')
    vut.change_width(ax, .6)
    

def make_fig(nq, figname, save_to, save_as=''):
    """Build the simplified grid figure, one column per selected (w_LP, w_PC) cell.

    Rows: (a) joint scatter of normalised weights with marginals, current cell
    highlighted; (b) ``ch1..ch4`` over trials; (c) per-activity percentages
    from ``subfig_choices``; (d) dwiPC vs dwfPC scores; last row: shared legend.

    Parameters
    ----------
    nq : int
        Bins per normalised weight, passed to ``load_data``. The cell list below
        indexes bins 0..2, so it is written for ``nq = 3``.
    figname : str
        Matplotlib figure name, also used when saving.
    save_to, save_as : str
        Output directory and file format; nothing is saved when ``save_as`` is empty.
    """
    # (qi_rlp, qi_rpc) positions, one per figure column: a subset of the columns
    # of the full figure. They index into ``.unique()`` (order of first
    # appearance), the same convention subfig_a uses.
    cells = [(0, 0), (1, 2), (2, 0), (1, 1)]
    nrows, ncols = 4 + 1, len(cells)
    fig = plt.figure(num=figname, figsize=[2.5 + 2*ncols, 8])
    gs = fig.add_gridspec(
        ncols=1 + ncols,
        nrows=nrows,
        # Narrow leading column holds the row letters
        width_ratios=[.1] + [1] * ncols,
        # Tall row (a), equal data rows, short legend row. (The former
        # ``np.ones(nrows-1) + [.25]`` added .25 to every row instead.)
        height_ratios=[2] + [1] * (nrows - 2) + [.25],
    )
    add = fig.add_subplot

    # Load data
    data = load_data(
        heuristics_data_path='data/model_data.csv',
        params_data_path='data/model_results/param_fits_clean.csv',
        learning_data_path='data/learning_data.csv',
        nam=None,
        nq=nq
    )

    # Annotate figure rows
    for i, letter in enumerate('abcd'):
        vut.add_subplot_label(x=0, y=1, label=letter, size=18, ax=fig.add_subplot(gs[i, 0]))

    qs_rlp = data['p'].qi_rlp.unique()
    qs_rpc = data['p'].qi_rpc.unique()
    # Index once; each column only needs .loc lookups by sid
    hdf = data['h'].set_index('sid')
    ldf = data['l'].set_index('sid')
    for ci, (i, j) in enumerate(cells, 1):
        # Subplot (a)
        main_ax = vut.pretty(add(gs[0, ci], aspect='equal'))
        divider = make_axes_locatable(main_ax)
        marg_ax1 = vut.pretty(divider.append_axes('top', '30%', pad=0.2, sharex=main_ax))
        marg_ax2 = vut.pretty(divider.append_axes('right', '30%', pad=0.2, sharey=main_ax))
        subfig_a([marg_ax1, main_ax, marg_ax2], data['p'], qi_rlp=i, qi_rpc=j)

        # Participants in this cell (used by subplots (b) to (d))
        sids = data['p'].loc[
                data['p'].qi_rlp.eq(qs_rlp[i]) & data['p'].qi_rpc.eq(qs_rpc[j])
            ].sid.unique()

        # Data for subplots (b) and (c)
        df = hdf.loc[sids, :].reset_index()

        # Subplot (b)
        ax = vut.pretty(add(gs[1, ci]))
        subfig_var(ax, df, variable='ch', errb=False)
        ax.set_ylim(0, 0.75)
        if ci == 1:
            ax.set_ylabel('selection\nrates (%)')

        # Subplot (c)
        ax = vut.pretty(add(gs[2, ci]))
        subfig_choices(ax, df)
        ax.set_xlim(-.5, 3.5)
        if ci == 1:
            ax.set_ylabel('time\nallocation (%)')

        # Subplot (d): learning scores of the same participants
        df = ldf.loc[sids, :]
        ax = vut.pretty(add(gs[3, ci]))
        subfig_stats(ax, df)
        ax.set_ylim(.5, .9)
        if ci == 1:
            ax.set_ylabel('score')

    # Add legend
    ax = vut.ghost(add(gs[4, 1:]))
    handles = [lines.Line2D([0], [0], color=colors[k], ls='', marker='o', label=tlabels[k + 1]) for k in range(4)]
    handles += [lines.Line2D([0], [0], ls='', marker='o', markerfacecolor=c, markeredgecolor=c, color=c) for c in gcolors]
    legw, legh = .25, .2
    leg = ax.legend(handles, list(tlabels.values())+list(fullglabels.values()), handletextpad=.05,
                    bbox_to_anchor=(.5-legw/2, .5, legw, legh), loc='center', mode='expand', ncol=3)
    vut.color_legend(leg)

    fig.tight_layout()
    fig.subplots_adjust(hspace=.4)
    if save_as:
        vut.save_it(fig, save_to, figname=figname, save_as=save_as, compress=False, dpi=100)
    
    
# Render the simplified grid; set save_as to a file format (png, jpeg, svg, ...) to also save it.
make_fig(nq=3, figname='sm_fig2b', save_to='figures', save_as='')

## Model comparisons

### Split by groups and NAMs

In [None]:
def make_fig(data_path, figname, save_to, save_as='', __=slice(None)):
    """Box plots of AIC per model form, one panel per NAM (1-3), coloured by group.

    The y order in every panel is the seven lowest-mean-AIC model forms of the
    first (group, nam) combination of the sorted summary table. A dashed line
    marks ``get_baseline_aic(250, 4)`` (labelled 'Random model').

    Parameters
    ----------
    data_path : str
        CSV of parameter fits with a ``vars`` column (used as index) and
        ``group``, ``nam``, ``aic`` columns.
    figname : str
        Matplotlib figure name, also used when saving.
    save_to, save_as : str
        Output directory and file format; nothing is saved when ``save_as`` is empty.
    __ : unused
    """
    # Load data; relabel groups 0 -> IG, 1 -> EG
    plot_df = pd.read_csv(data_path, index_col='vars').filter(items=['group','nam','aic'])
    plot_df = plot_df.replace(to_replace={'group': {0: 'IG', 1: 'EG'}})

    # Optionally, filter out model forms with RelT:
    # plot_df = plot_df[~plot_df.index.to_series().str.contains('relt')]

    # Relabel model forms (e.g. 'rpc,rlp' -> 'PC + LP')
    new_ind = plot_df.index.to_series().str.replace(',',' + ')
    new_ind = new_ind.str.replace('rpc', 'PC')
    new_ind = new_ind.str.replace('rlp', 'LP')
    new_ind = new_ind.str.replace('abst','EXP')
    plot_df.index = pd.Index(new_ind, name='vars')

    # Summary table: AIC mean/std/count per (group, nam, model form)
    df = plot_df.groupby(['group','nam','vars']).agg({'aic':['mean', 'std', 'count']})
    df.columns = df.columns.map('_'.join)
    df.sort_values(by=['group','nam','aic_mean'], inplace=True)
    display(df)

    # Loop invariants. ``reversed`` returns a one-shot iterator, so the order
    # is materialised as a list before being handed to seaborn.
    plot_order = list(reversed(list(df.index.get_level_values(2))[:7]))
    baseline_aic = get_baseline_aic(250, 4)

    fig = plt.figure(figname, figsize=(6, 8))
    gs = GridSpec(3, 1)
    for i in range(3):
        ax = vut.pretty(fig.add_subplot(gs[i, 0]), 'x')
        sub_df = plot_df.loc[plot_df.nam.eq(i+1), :].reset_index()
        sns.boxplot(
            x='aic', y='vars', data=sub_df, ax=ax,
            linewidth=1, order=plot_order, whis=100,
            hue='group', palette=gcolors, saturation=1,
        )
        for v in ax.collections:
            v.set_linewidth(.5)
        ax.axvline(baseline_aic, ls='--', zorder=3, color='orange',
                   label='Random model')
        ax.set_xlim([10, 800])
        ax.tick_params(axis='both', labelsize=12)
        ax.text(760, 3, 'NAM {}'.format(i+1))

        # Legend only above the top panel
        if i == 0:
            ax.legend(bbox_to_anchor=[0, 1.1, 1, .2], loc='lower center', mode='expand', ncol=3)
        else:
            ax.legend().remove()

        # Shared y label on the middle panel only
        if i == 1:
            ax.set_ylabel('Model', fontsize=14)
        else:
            ax.set_ylabel('')

        # x label and tick labels on the bottom panel only
        if i == 2:
            ax.set_xlabel('AIC', fontsize=14)
        else:
            ax.set_xlabel('')
            ax.tick_params(axis='x', labelbottom=False)

    print('Baseline AIC = {:.3f}'.format(baseline_aic))
    fig.tight_layout()
    if save_as:
        vut.save_it(fig, save_to, figname=figname, save_as=save_as, compress=False, dpi=100)


# AIC by model form, split by group and NAM; set save_as (png, jpeg, svg, ...) to also save.
make_fig(data_path='data/model_results/param_fits_clean.csv',
         figname='model_comparisons_full', save_to='figures', save_as='')

### Overall

In [None]:
def make_fig(data_path, figname, save_to, save_as='', __=slice(None)):
    """Strip + box plot of AIC per model form, pooled over all fits.

    Model forms are listed in decreasing mean-AIC order; a dashed red line
    marks ``get_baseline_aic(250, 4)`` (the random model), which is also printed.

    Parameters
    ----------
    data_path : str
        CSV of parameter fits with a ``vars`` column (used as index) and ``aic``.
    figname : str
        Matplotlib figure name, also used when saving.
    save_to, save_as : str
        Output directory and file format; nothing is saved when ``save_as`` is empty.
    __ : unused
    """
    # Load data
    plot_df = pd.read_csv(data_path, index_col='vars').filter(items=['aic'])
    # Relabel model forms (e.g. 'rpc,rlp' -> 'PC + LP')
    new_ind = plot_df.index.to_series().str.replace(',',' + ')
    new_ind = new_ind.str.replace('rpc', 'PC')
    new_ind = new_ind.str.replace('rlp', 'LP')
    new_ind = new_ind.str.replace('abst','EXP')
    plot_df.index = pd.Index(new_ind, name='vars')
    # Average scores per model form, worst first
    df = plot_df.groupby('vars').agg({'aic':['mean', 'std']})
    df.columns = df.columns.map('_'.join)
    df.sort_values(by='aic_mean', ascending=False, inplace=True)
    display(df)

    long_df = plot_df.reset_index()
    fig = plt.figure(figname, figsize=(6, 6))
    ax = vut.pretty(fig.add_subplot(111))
    sns.stripplot(
        x='aic', y='vars', data=long_df, ax=ax,
        color='k', alpha=.6, size=2, order=df.index
    )
    sns.boxplot(
        x='aic', y='vars', data=long_df, ax=ax,
        color='lightgray', linewidth=2, order=df.index, whis=100, width=.5
    )
    for v in ax.collections:
        v.set_edgecolor('w')

    # Random-model reference, computed once and reused for the printout
    baseline_aic = get_baseline_aic(250, 4)
    ax.axvline(baseline_aic, ls='--', zorder=3, color='red',
               label='Random model')
    ax.set_ylabel('Model form')
    ax.set_xlabel('AIC')
    ax.set_xlim(50, 800)
    ax.tick_params(labelsize=12)
    ax.text(780, 1, 'Random model', color='red', ha='center', va='center', rotation=-90, fontsize=14)

    print('Baseline AIC = {:.3f}'.format(baseline_aic))
    fig.tight_layout()
    if save_as:
        vut.save_it(fig, save_to, figname=figname, save_as=save_as, compress=False, dpi=100)


# Pooled AIC comparison; set save_as (png, jpeg, svg, ...) to also save.
make_fig(data_path='data/model_results/param_fits_clean.csv',
         figname='sm_fig3', save_to='figures', save_as='')

# MP-aligned plots

In [None]:
def make_fig(choice_data_path, xlim, boxcars, figname, save_to='figures', save_as=''):
    """Plot choice persistence and ``sc`` aligned to each participant's mastery point(s).

    Panels are grouped by NAM: NAM n gets n columns, one per mastery point
    ``mp1`` .. ``mpn``. Top row: per trial after the alignment trial (``mp - 1``),
    the fraction of participants whose activity equals the one chosen on that
    alignment trial. Bottom row: mean ``sc`` from trial ``mp`` on. One curve per
    group, boxcar-smoothed over ``boxcars`` trials, with a smoothed SEM band.

    Parameters
    ----------
    choice_data_path : str
        CSV with ``sid, group, nam, trial, sc, activity, mp1..mp3`` columns.
    xlim : int
        Right x-limit of every panel.
    boxcars : int
        Rolling-window length used to smooth means and SEMs.
    figname : str
        Matplotlib figure name, also used when saving.
    save_to, save_as : str
        Output directory and file format; nothing is saved when ``save_as`` is empty.
    """
    # Activities are recoded below to category codes + 1; assuming at most four
    # activities, 5 can only be padding added by lut.boolean_indexing.
    choice_pad = 5

    def smoothed_mean_sem(arr):
        """Return the NaN-aware column mean and SEM of ``arr``, each boxcar-smoothed."""
        mean = np.nanmean(arr, axis=0)
        err = scs.sem(arr, axis=0, nan_policy='omit')
        smooth_mean = pd.Series(mean).rolling(boxcars, min_periods=1).mean()
        smooth_err = pd.Series(err).rolling(boxcars, min_periods=1).mean()
        return smooth_mean, smooth_err

    df = pd.read_csv(choice_data_path).filter(
        items=['sid','group','nam','trial','sc','activity','mp1','mp2','mp3']
    )
    df.activity = pd.Categorical(df.activity)
    # Replace the column wholesale: an in-place ``df.loc[:, 'activity'] = ...``
    # would try to write integer codes into the categorical dtype.
    df['activity'] = df.activity.cat.codes + 1

    fig = plt.figure(figname, figsize=[12,4])
    gs = gridspec.GridSpec(nrows=2, ncols=8, width_ratios=[1, 0.1, 1,1, 0.1, 1,1,1])
    # Ghost axes: common x label, NAM separators and NAM titles
    ax = fig.add_subplot(gs[:, :])
    vut.ghost(ax)
    ax.set_xlabel('Trials after mastery point (MP)', labelpad=30)
    for ax in [fig.add_subplot(gs[:, 1]), fig.add_subplot(gs[:, 4])]:
        vut.ghost(ax)
        ax.axvline(1, lw=3, c='k')
    for i, ax in enumerate([fig.add_subplot(gs[:, 0]), fig.add_subplot(gs[:, 2:4]), fig.add_subplot(gs[:, 5:])]):
        vut.ghost(ax)
        ax.set_title(f'NAM {i+1}', pad=30, fontsize=16)

    getax = lambda row, col: vut.pretty(fig.add_subplot(gs[row, col]))
    # sp_rows[row][nam] -> list of axes, one per mastery point
    sp_rows = []
    for r in [0, 1]:
        sp_rows.append({
            1: [getax(r, 0)],
            2: [getax(r, 2), getax(r, 3)],
            3: [getax(r, 5), getax(r, 6), getax(r, 7)]
        })

    # Plot data
    df.set_index(['sid','trial','group','nam'], inplace=True)
    for nam in [1, 2, 3]:
        for group in [0, 1]:
            sub_df = df.loc[(slice(None), slice(None), group, nam), :].droplevel([2, 3])
            sids = sub_df.index.get_level_values(0).unique().tolist()
            for mp in range(nam):
                # Mastery trial of each participant, read from their trial-1 row
                mp_trials = [sub_df.loc[(sid, 1), f'mp{mp+1}'] for sid in sids]

                # Top row: choice persistence
                ax = sp_rows[0][nam][mp]
                ax.set_xlim([1,xlim])
                ax.set_ylim(0, 1)
                ax.tick_params(labelbottom=False)
                if nam > 1:
                    ax.tick_params(labelleft=False)
                choices = lut.boolean_indexing(
                    [list(sub_df.loc[(sid, slice(max(t - 1, 0), 250)), 'activity'])
                     for sid, t in zip(sids, mp_trials)],
                    fillval=choice_pad
                )
                anchor, later = choices[:, :1], choices[:, 1:]
                # Padded cells are missing data: NaN excludes them from mean/SEM
                # instead of counting them as switches.
                valid = (anchor != choice_pad) & (later != choice_pad)
                stayed = np.where(valid, anchor == later, np.nan)
                smooth_mean, smooth_err = smoothed_mean_sem(stayed)
                ax.plot(smooth_mean, color=gcolors[group])
                ax.fill_between(np.arange(smooth_mean.size), smooth_mean+smooth_err, smooth_mean-smooth_err, color=gcolors[group], alpha=.3)
                ax.set_title(f'MP {mp+1}', fontsize=14)

                # Bottom row: sc from the mastery trial onwards
                aligned_sc = lut.boolean_indexing(
                    [list(sub_df.loc[(sid, slice(t, 250)), 'sc']) for sid, t in zip(sids, mp_trials)],
                    fillval=np.nan
                )
                ax = sp_rows[1][nam][mp]
                if nam > 1:
                    ax.tick_params(labelleft=False)
                ax.set_ylim(0, 1)
                ax.set_xlim([1,xlim])
                smooth_mean, smooth_err = smoothed_mean_sem(aligned_sc)
                ax.plot(np.arange(smooth_mean.size), smooth_mean, color=gcolors[group])
                ax.fill_between(np.arange(smooth_mean.size), smooth_mean+smooth_err, smooth_mean-smooth_err, color=gcolors[group], alpha=.3)

    fig.tight_layout()
    fig.subplots_adjust(wspace=.2)
    if save_as:
        vut.save_it(fig, save_to, figname=figname, save_as=save_as, compress=False, dpi=100)

# MP-aligned persistence and SC (100-trial window, 15-trial boxcar); save_as='' skips saving.
make_fig(choice_data_path='data/model_data.csv', xlim=100, boxcars=15,
         figname='sm_fig4a', save_to='figures', save_as='')

# Effect of learning criterion on NAM grouping

In [None]:
def make_fig(data_path, omit_nans, figname, save_to, save_as):
    """Plot, per activity, the fraction of participants meeting each mastery criterion.

    For every criterion ``crit`` in 10..14, a participant masters activity k if
    ``rpc<k>`` exceeds ``crit / 15`` on at least one of their free-stage rows.
    One panel per activity, one line per group.

    Parameters
    ----------
    data_path : str
        CSV with ``group``, ``sid``, ``stage`` and ``rpc1`` .. ``rpc4`` columns.
    omit_nans : bool
        Accepted for backward compatibility; it has no effect on the figure.
    figname : str
        Matplotlib figure name, also used when saving.
    save_to, save_as : str
        Output directory and file format; nothing is saved when ``save_as`` is empty.
    """
    activities = [1, 2, 3, 4]
    criteria = [10, 11, 12, 13, 14]
    cols = [f'rpc{act_ind}' for act_ind in activities]

    df = pd.read_csv(data_path).set_index(['group', 'sid'])
    df = df.loc[df.stage.eq('free'), cols]
    # NOTE(review): drops the last free-stage row of the whole file; the reason
    # is not visible here - confirm it is intentional.
    df = df.iloc[:-1, :]

    frames = []
    for crit in criteria:
        above = df > crit / 15
        above.columns = activities[:]
        # Mastered = criterion exceeded on at least one row of the participant
        mastered = above.groupby(['group', 'sid']).any()
        # Fraction of participants per group (mean of booleans)
        frac = mastered.groupby(['group']).mean().reset_index()
        frac['crit'] = crit
        frames.append(frac)
    fractions_mastered_df = pd.concat(frames).sort_values(['group', 'crit']).set_index(['group', 'crit'])

    fig = plt.figure(figname, figsize=[8, 3])
    ax = fig.add_subplot(111)
    vut.ghost(ax)
    ax.set_xlabel('Mastery criterion (N of 15 correct)', labelpad=30, fontsize=14)

    # Fractions of participants mastering each task in each group
    for activity in activities:
        ax = vut.pretty(fig.add_subplot(1, 4, activity))
        ax.set_ylim(0, 1)
        ax.set_title(tlabels[activity], color=colors[activity-1], fontsize=14, fontweight='bold')
        if activity == 1:
            ax.set_ylabel('Fraction\nmastering', fontsize=14)
        else:
            ax.tick_params(labelleft=False)
        for group in [0, 1]:
            series = fractions_mastered_df.loc[(group, slice(None)), activity]
            # x values taken from the selected rows themselves
            x = series.index.get_level_values('crit').tolist()
            ax.plot(x, series.tolist(), color=gcolors[group])
            if activity == 4:
                ax.text(12.5, .80 if group else .65, fullglabels[group],
                        fontsize=14, fontweight='bold', color=gcolors[group])
        ax.set_xticks(criteria)
        ax.set_xticklabels([str(c) for c in criteria])
    fig.tight_layout()
    if save_as:
        vut.save_it(fig, save_to, figname, save_as=save_as, compress=False)
    
    

# Fraction of participants mastering each activity vs. the mastery criterion.
make_fig(data_path='data/model_data.csv', omit_nans=False,
         figname='sm_fig4b', save_to='figures', save_as='')

# Weights and SC

In [None]:
def make_fig(data_path, data_path2, figname, save_to, save_as):
    """Relate ``sc_flat`` to the normalised PC and LP weights.

    Prints an OLS summary of ``sc_flat ~ nrpc * nrlp``, draws one regression
    panel per normalised weight against ``sc_flat``, and displays an OLS summary
    of ``dwfpc`` on that weight for each panel.

    Parameters
    ----------
    data_path : str
        Parameter-fit CSV (``vars``, ``sid``, ``group``, ``nam``, ``rpc``, ``rlp``).
    data_path2 : str
        Learning-data CSV merged on ``sid``.
    figname : str
        Matplotlib figure name, also used when saving.
    save_to, save_as : str
        Output directory and file format; nothing is saved when ``save_as`` is empty.
    """
    df = pd.read_csv(data_path)
    # Independent copy so the derived columns below are not written into a view
    df = df.loc[df.vars.eq('rpc,rlp'), ['sid','group','nam','rpc','rlp']].copy()
    # Scale each (rpc, rlp) weight vector to unit length
    df['norm'] = np.sqrt(df.rpc**2 + df.rlp**2)
    df['nrpc'] = df.rpc / df.norm
    df['nrlp'] = df.rlp / df.norm
    learning = pd.read_csv(data_path2).loc[:, ['sid', 'sc_flat', 'sc_lep', 'sc_streaks', 'dwfpc', 'fpc']]
    df = df.merge(learning, on='sid')
    print(ols('sc_flat ~ nrpc * nrlp', data=df).fit().summary())

    fig = plt.figure(figname, figsize=[8, 4])
    for i, var in enumerate(['nrpc', 'nrlp'], 1):
        ax = vut.pretty(fig.add_subplot(1,2,i))
        sns.regplot(
            x=df.loc[:, var], y=df.loc[:, 'sc_flat'], ax=ax, color='k',
            scatter_kws={'alpha': .3, 's': 20}
        )
        # 'nrpc' -> '$\hat w_{PC}$', 'nrlp' -> '$\hat w_{LP}$'
        ax.set_xlabel(r'$\hat w_{'+f'{var.upper()[-2:]}'+r'}$')
        if i == 1:
            ax.set_ylabel('Average SC')
        else:
            ax.set_ylabel('')
            ax.tick_params(labelleft=False)

        # NOTE(review): this model uses dwfpc as outcome while the panel plots
        # sc_flat - confirm which outcome is intended.
        lm = ols(f'dwfpc ~ {var}', data=df).fit()
        display(lm.summary())

    fig.tight_layout()
    if save_as:
        vut.save_it(fig, save_to, figname, save_as=save_as, compress=False)
    
    

# Normalised weights vs. average SC; set save_as to a file format to also save.
make_fig(data_path='data/model_results/param_fits_clean.csv',
         data_path2='data/learning_data.csv',
         figname='response_fig2', save_to='figures', save_as='')

# PC separately for each activity

In [None]:
def make_fig(data_path, figname, save_to, save_as):
    """Plot mean +/- SEM of ``rpc1`` .. ``rpc4`` over trials for each group.

    Group 0 is drawn with solid lines and group 1 with dashed lines; colours
    index ``colors`` by activity. A legend outside the axes explains both.

    Parameters
    ----------
    data_path : str
        CSV with ``group``, ``trial`` and ``rpc1`` .. ``rpc4`` columns.
    figname : str
        Matplotlib figure name, also used when saving.
    save_to, save_as : str
        Output directory and file format; nothing is saved when ``save_as`` is empty.
    """
    pc_cols = ['rpc1', 'rpc2', 'rpc3', 'rpc4']
    # Keep only the grouping keys and the PC columns: the participant id is
    # not averaged (and may not be numeric).
    df = pd.read_csv(data_path).filter(items=['group', 'trial'] + pc_cols)
    df = df.groupby(['group', 'trial']).agg(['mean', 'sem'])
    display(df.head())

    fig, ax = plt.subplots(num=figname, figsize=[6, 4])
    vut.pretty(ax)
    ax.set_xlim(0, 250)
    ax.set_ylabel('recent PC\n(Mean and SEM)')
    ax.set_xlabel('Trial')
    for group in [0, 1]:
        for i, col in enumerate(pc_cols):
            m = df.loc[(group, slice(None)), (col, 'mean')].values
            sem = df.loc[(group, slice(None)), (col, 'sem')].values
            ax.plot(m, c=colors[i], ls='--' if group else '-')
            ax.fill_between(np.arange(m.size), m+sem, m-sem, color=colors[i], alpha=.4)

    legend_elements = [Line2D([0],[0], color=colors[i-1], label=f'A{i}') for i in [1,2,3,4]]
    legend_elements += [Line2D([0],[0], color='gray', label=fullglabels[i], ls=['-','--'][i]) for i in [0,1]]
    ax.legend(handles=legend_elements, bbox_to_anchor=(1.1, 1), loc='upper left')

    fig.tight_layout()
    if save_as:
        vut.save_it(fig, save_to, figname, save_as=save_as, compress=False)
    
    

# Recent PC per activity and group; set save_as to a file format to also save.
make_fig(data_path='data/model_data.csv', figname='response_fig1',
         save_to='figures', save_as='')