In [None]:
import numpy as np
import pandas as pd

from os.path import join

import matplotlib.pyplot as plt
%matplotlib inline

# Define plotting helper functions

In [None]:
def format_p(p):
    """
    Render a p-value as a LaTeX-ready string for plot annotations.

    Values below the fixed cutoffs 0.00001 and 0.0001 are reported as
    "< cutoff" (checked from the smallest cutoff upward); all other values
    are printed exactly with four decimal places.
    """
    cutoffs = (
        (0.00001, r'$P$ < 0.00001'),
        (0.0001, r'$P$ < 0.0001'),
    )
    for threshold, label in cutoffs:
        if p < threshold:
            return label
    return r'$P$ = {:.4f}'.format(p)

In [None]:
def plot_recovery_panel(recovery_df, variable, other_levels, ax=None, ci=0.95):
    """
    Scatter generating vs. recovered values of one parameter and overlay an OLS fit.

    Rows are selected where `<variable>_level == 'variable'` and, for every
    other parameter, `<parameter>_level` equals the level given in
    `other_levels`. The panel shows the data points, the fitted regression
    line with its `ci` confidence band for the fitted mean, a y = x reference
    line, and an annotation with the regression coefficients, the slope's
    t-value and its p-value.

    Parameters
    ----------
    recovery_df : pandas.DataFrame
        Recovery results with `<name>_level`, `<name>_gen` and `<name>_rec` columns.
    variable : str
        Parameter whose generating/recovered values are plotted.
    other_levels : dict
        Maps each other parameter name to the level it is held at.
    ax : matplotlib.axes.Axes, optional
        Axes to draw into; a new figure is created if None.
    ci : float
        Confidence level of the band around the fitted mean (default 0.95).

    Returns
    -------
    matplotlib.axes.Axes
        The axes that was drawn on.

    Raises
    ------
    ValueError
        If no rows match the requested level combination.
    """
    import statsmodels.api as sm
    from statsmodels.stats.outliers_influence import summary_table

    if ax is None:
        _, ax = plt.subplots()

    # Build one boolean mask for the subset instead of re-filtering (and copying) per parameter.
    mask = recovery_df['{}_level'.format(variable)] == 'variable'
    for parameter, level in other_levels.items():
        mask = mask & (recovery_df['{}_level'.format(parameter)] == level)
    df = recovery_df[mask]

    if df.empty:
        # Fail with a clear message rather than numpy's "zero-size array" error.
        raise ValueError('No rows for {} with other levels {}'.format(variable, other_levels))

    gen = df['{}_gen'.format(variable)].values
    rec = df['{}_rec'.format(variable)].values

    # Common range for both axes, padded by 10% of the data span on each side.
    mini = np.stack([gen, rec]).min()
    maxi = np.stack([gen, rec]).max()
    diff = maxi - mini
    minilim = mini - 0.1 * diff
    maxilim = maxi + 0.1 * diff

    ax.scatter(gen,
               rec,
               color='lightgray',
               edgecolor='k',
               linewidth=1,
               alpha=0.9)

    ax.set_xlim(minilim, maxilim)
    ax.set_ylim(minilim, maxilim)

    # Linear model fit: rec ~ intercept + slope * gen
    X = sm.add_constant(gen)
    lm = sm.OLS(rec, X).fit()
    intercept, slope = lm.params
    _, data, _ = summary_table(lm, alpha=1. - ci)
    # Columns 4 and 5 of the summary table: lower/upper CI of the fitted mean.
    mean_ci_lower, mean_ci_upper = data[:, 4], data[:, 5]

    # Statistics of the slope (the last coefficient, after the constant).
    tval = lm.tvalues[-1]
    pval = lm.pvalues[-1]

    xs = np.linspace(*ax.get_xlim(), 100)
    ax.plot(xs, intercept + slope * xs,
            color='deeppink', zorder=-8)
    sort_idx = np.argsort(gen)
    ax.fill_between(gen[sort_idx], mean_ci_lower[sort_idx], mean_ci_upper[sort_idx],
                    color='deeppink', alpha=0.1, zorder=-8)

    beta_annotation = [r'$\beta_{}$ = {:.4f}'.format(b, beta)
                       for b, beta in enumerate(lm.params)]
    annotation = '\n'.join(beta_annotation
                           + [r'$t$ = {:.2f}'.format(tval), format_p(pval)])
    ax.text(0.1,
            0.9,
            annotation,
            verticalalignment='top',
            transform=ax.transAxes,
            fontsize='small')

    # x and y limits were set to the same range above, so this draws the y = x line.
    ax.plot(ax.get_xlim(), ax.get_ylim(),
            linewidth=0.25,
            color='k', alpha=0.9, zorder=-9)

    ax.set_xlabel('Generating')
    ax.set_ylabel('Recovered')

    return ax

In [None]:
def plot_recovery(recovery_df,
                  variables=['v', 'gamma', 's', 'tau'],
                  levels=['low', 'medium', 'high'],
                  bounds=dict(v=[0.000015, 0.00015],
                              gamma=[-1, 1],
                              s=[0.004, 0.011],
                              tau=[0.1, 1.25])):
    """
    Plot a grid of parameter-recovery panels.

    Each row corresponds to one parameter in `variables`. Each column holds
    one combination of `levels` for all *other* parameters, so the grid has
    len(levels) ** (len(variables) - 1) columns. Every panel is drawn by
    `plot_recovery_panel`. Combinations that cannot be plotted (e.g. no
    matching rows) are skipped and left as empty axes.

    Parameters
    ----------
    recovery_df : pandas.DataFrame
        Recovery results, passed through to `plot_recovery_panel`.
    variables : list of str
        Parameters to plot, one per row. The default list is only read, never modified.
    levels : list of str
        Levels at which each of the other parameters can be held.
    bounds : dict
        Maps a parameter name to its [min, max] axis limits. Parameters
        without an entry keep the limits chosen by `plot_recovery_panel`.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axs : numpy.ndarray of matplotlib.axes.Axes, shape (n_rows, n_cols)
    """
    from itertools import product

    n_rows = len(variables)
    n_cols = len(levels) ** (len(variables) - 1)

    # squeeze=False keeps `axs` 2-D even for a single row or column.
    fig, axs = plt.subplots(n_rows, n_cols,
                            figsize=(n_cols * 3, n_rows * 3),
                            sharey='row',
                            sharex='row',
                            squeeze=False)

    for i, var in enumerate(variables):
        others = [v for v in variables if v != var]

        # One constellation per combination of levels of the other parameters.
        for j, constellation in enumerate(product(levels, repeat=len(others))):
            other_levels = dict(zip(others, constellation))
            ax = axs[i, j]
            try:
                plot_recovery_panel(recovery_df, variable=var,
                                    other_levels=other_levels, ax=ax)
            except Exception:
                # Best effort: skip combinations that cannot be plotted
                # instead of aborting the whole grid.
                continue
            ax.set_xlabel('')
            ax.set_ylabel('')
            title = ['{}: {}'.format(other, level)
                     for other, level in other_levels.items()]
            ax.set_title(', '.join(title), fontsize=8)
            if var in bounds:
                ax.set_xlim(*bounds[var])
                ax.set_ylim(*bounds[var])

        axs[i, 0].set_ylabel('{}\nRecovered'.format(var))

    for ax in axs[-1, :]:
        ax.set_xlabel('Generating')

    fig.tight_layout()

    return fig, axs

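# Visualization

For each inference method, plot generating vs. recovered parameter values and save one figure per parameter subset.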

In [None]:
# Results for each inference method are read from and written to results/parameter_recovery/<method>/.
methods = ['nuts', 'metropolis', 'advi']

# Parameter subsets to plot; each produces one figure per method.
variable_sets = [['v', 'gamma', 's', 'tau'],
                 ['v', 's', 'tau'],
                 ['v', 'gamma', 'tau'],
                 ['v', 'gamma', 's']]

for method in methods:
    print(method)
    method_dir = join('results', 'parameter_recovery', method)
    parameter_recovery_raw = pd.read_csv(
        join(method_dir, 'parameter_recovery_{}.csv'.format(method)))

    print('P(converged):', parameter_recovery_raw['converged'].mean())
    # `!= False` keeps rows whose flag is True *or* missing (NaN);
    # NOTE(review): presumably intentional — confirm missing flags should be kept.
    parameter_recovery = parameter_recovery_raw[
        parameter_recovery_raw['converged'] != False].copy()  # noqa: E712

    for variable_set in variable_sets:
        fig, _ = plot_recovery(recovery_df=parameter_recovery, variables=variable_set)
        fig.savefig(join(method_dir,
                         'parameter_recovery_{}_'.format(method) + '-'.join(variable_set) + '.png'),
                    dpi=300)
        # Close this figure explicitly so figures do not accumulate across the loop.
        plt.close(fig)