In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
from glob import glob
import seaborn as sns
import arviz as az
from os.path import join, split, splitext
from toolz import pipe
from fetch_data import create_arrays

In [None]:
# Directory containing model runs -- change as needed
run_dir = '/media/martin/External Drive/projects/pymc_vs_stan/rerun_01_10_2022/fits/'

# Directory to save plots in -- change as needed
plot_dir = '/home/martin/projects/pymc_vs_stan_revamp/plots/'

# Find all runtime files (one sub-directory per method). glob returns files in an
# arbitrary, filesystem-dependent order, so sort to make every downstream table
# reproducible across machines and re-runs.
all_runtimes = sorted(glob(join(run_dir, '*/*.txt')))

In [None]:
# Number of runtime files found. Fail early with a clear message when run_dir is
# wrong, instead of building empty frames that break obscurely further down.
assert len(all_runtimes) > 0, f'No runtime files found under {run_dir}'
len(all_runtimes)

In [None]:
# Extract runtimes from text files: the first line of each file holds the wall
# time in seconds (later cells divide by 60 to report minutes).
def _read_first_line_float(path):
    """Return the first line of the file at ``path`` parsed as a float.

    Uses a ``with`` block so the file handle is closed immediately rather than
    left for the garbage collector.
    """
    with open(path) as handle:
        return float(handle.readline().strip())

runtimes = [_read_first_line_float(x) for x in all_runtimes]

In [None]:
# Helper function to extract some more info from the filenames
def extract_info(filenames):
    """Parse method and year from paths shaped like ``.../<method>/<name>_<year>.<ext>``.

    Parameters
    ----------
    filenames : iterable of str
        Paths to per-run output files. The name of each file's parent directory
        is taken as the method, and the integer after the last underscore in the
        file stem as the year. Any iterable works, including a generator, because
        it is traversed only once.

    Returns
    -------
    dict
        ``{'year': [int, ...], 'method': [str, ...]}``; both lists are aligned
        with the order of ``filenames``.

    Raises
    ------
    ValueError
        If a file stem does not end in ``_<integer>``.
    """
    years = []
    methods = []

    for path in filenames:
        parent_dir, file_name = split(path)
        # os.path.split instead of str.split('/') so Windows paths also work
        methods.append(split(parent_dir)[-1])
        stem = splitext(file_name)[0]
        years.append(int(stem.split('_')[-1]))

    return {'year': years, 'method': methods}

In [None]:
# Columns for the runtime table: year and method parsed from each path, plus the runtime read from it
info = {**extract_info(all_runtimes), 'runtime': runtimes}

In [None]:
# One row per run: year, method, runtime (seconds)
results = pd.DataFrame.from_dict(info)

In [None]:
# Quick look at the earliest years
results.sort_values(by='year').head(5)

In [None]:
# Fetch the number of data points (matches) for each year from the dataset.
# Only years that actually appear among the runs are loaded. A dedicated name is
# used because `of_interest` is reused further down for the list of methods.
years_of_interest = results['year'].unique()

data = {year: create_arrays(year)['winner_ids'].shape[0] for year in years_of_interest}

In [None]:
# Attach each run's dataset size, looked up by its year (vectorised instead of a
# Python loop over itertuples).
results['n_matches'] = results['year'].map(data)

# Series.map yields NaN for unknown keys; the original dict lookup raised. Keep that strictness.
assert results['n_matches'].notna().all(), 'Missing match count for some years'

In [None]:
# Size of the largest fit; quoted in the bar-chart titles as "all matches in dataset"
total_matches = results.loc[:, 'n_matches'].max()

In [None]:
# log10 of the dataset size.
# NOTE(review): no later cell reads this column; the plots use log-scaled axes instead.
results = results.assign(log_matches=np.log10(results['n_matches']))

In [None]:
f, ax = plt.subplots(1, 1)

# Methods to compare. This list, and the colour each method gets, are reused by
# the bar charts and ESS plots in later cells.
of_interest = ['pymc', 'cmdstanpy', 'pymc_jax_gpu_vectorized', 'pymc_jax_gpu_parallel',
               'pymc_blackjax_gpu_vectorized', 'pymc_blackjax_cpu_parallel']

colours = sns.color_palette(palette=None, n_colors=len(of_interest))
colour_lookup = dict(zip(of_interest, colours))

rel_results = results[results['method'].isin(of_interest)]

# One line per method, in order of first appearance; GPU runs are drawn dashed.
for cur_method, cur_data in rel_results.groupby('method', sort=False):
    cur_data = cur_data.sort_values('n_matches')
    ax.plot(
        cur_data['n_matches'],
        cur_data['runtime'] / 60,  # seconds -> minutes
        label=cur_method,
        marker='o',
        linestyle='--' if 'gpu' in cur_method else None,
        color=colour_lookup[cur_method],
    )

ax.set_xscale('log')
ax.set_yscale('log')

ax.set_xlabel('Number of matches (log scale)')
ax.set_ylabel('Runtime in minutes (log scale)')

ax.grid(alpha=0.5, linestyle='--')
ax.legend()

f.set_size_inches(8, 5)
f.tight_layout()

#plt.savefig(join(plot_dir, 'walltime_plot.png'), dpi=300)

In [None]:
f, ax = plt.subplots(1, 1)

# Keep only the runs on the largest dataset and express runtime in minutes,
# fastest method first.
largest = rel_results['n_matches'] == rel_results['n_matches'].max()
to_plot = rel_results.loc[largest].copy()
to_plot['runtime (minutes)'] = to_plot['runtime'] / 60
to_plot = to_plot.sort_values('runtime (minutes)')

# Reuse each method's line-plot colour so the figures match
colours_to_plot = [colour_lookup[method] for method in to_plot['method']]

sns.barplot(x=to_plot['runtime (minutes)'], y=to_plot['method'], ax=ax, palette=colours_to_plot)
ax.grid(alpha=0.5)

ax.set_title(f'Runtime when using all matches in dataset ({total_matches} matches)')

f.set_size_inches(8, 4)
f.tight_layout()

#plt.savefig(join(plot_dir, 'walltime_full.png'), dpi=300)

In [None]:
to_plot = to_plot.set_index('method', drop=False)

# Slowdown of each method relative to the fastest one (fastest = 1.0)
fastest_runtime = to_plot['runtime'].min()
to_plot['runtime'] / fastest_runtime

In [None]:
# Full-dataset runtimes converted from seconds to minutes
to_plot['runtime'].div(60)

In [None]:
# Compute ESS / second: the minimum effective sample size over all parameters, divided by the run's wall-clock runtime

In [None]:
# Find all posterior sample files (one sub-directory per method). Sorted so the
# row order of the ESS table built from this list is reproducible.
all_draws = sorted(glob(join(run_dir, '*/*.netcdf')))

In [None]:
def compute_min_ess(arviz_draws):
    """Return the smallest effective sample size across every variable and element.

    ``arviz_draws`` is anything ``az.ess`` accepts; in this notebook it is the
    InferenceData loaded from a ``.netcdf`` sample file. ``az.ess`` returns a
    Dataset of per-element ESS values. This reduces that first within each
    variable, then across variables, and returns a plain Python float.
    """
    per_variable_min = az.ess(arviz_draws).min()
    stacked = per_variable_min.to_array()
    return float(stacked.min().values)

In [None]:
# Minimum ESS of every sample file, in the same order as all_draws. Files are
# loaded one at a time, each just before its ESS is computed.
min_ess = [compute_min_ess(az.from_netcdf(path)) for path in all_draws]

In [None]:
# Method and year for each sample file, aligned with min_ess
draw_info = extract_info(filenames=all_draws)

In [None]:
# One row per sample file. min_ess and draw_info both come from all_draws, so
# their entries line up row by row.
ess_info = pd.DataFrame({
    'min_ess': min_ess,
    'method': draw_info['method'],
    'year': draw_info['year'],
})

In [None]:
# Pair each run's ESS with its runtime. 'method' and 'year' are exactly the
# columns the two frames share; spelling them out makes the join key visible.
# Inner join: runs lacking either a sample file or a runtime file are dropped.
with_runtime = ess_info.merge(results, on=['method', 'year'], how='inner')

with_runtime['ESS / second'] = with_runtime['min_ess'].div(with_runtime['runtime'])

In [None]:
# First few rows of the joined ESS / runtime table
with_runtime.head(n=5)

In [None]:
f, ax = plt.subplots(1, 1)

# Same methods and colours as the runtime plot. This rebinds `rel_results`,
# which later cells reuse.
rel_results = with_runtime[with_runtime['method'].isin(of_interest)]

# One line per method, in order of first appearance; GPU runs are drawn dashed.
for cur_method, cur_data in rel_results.groupby('method', sort=False):
    cur_data = cur_data.sort_values('n_matches')
    ax.plot(
        cur_data['n_matches'],
        cur_data['ESS / second'],
        label=cur_method,
        marker='o',
        linestyle='--' if 'gpu' in cur_method else None,
        color=colour_lookup[cur_method],
    )

ax.set_xscale('log')
ax.set_yscale('log')

ax.set_xlabel('Number of matches (log scale)')
ax.set_ylabel('Minimum ESS / second (log scale)')

# Minor gridlines too, since both axes are logarithmic
ax.grid(alpha=0.5, linestyle='--', which='both')

ax.legend(loc='lower left')

f.set_size_inches(8, 5)
f.tight_layout()

plt.savefig(join(plot_dir, 'ess_values.png'), dpi=300)

In [None]:
f, ax = plt.subplots(1, 1)

# Runs on the full dataset. Select by match count rather than a hard-coded year,
# so the rows shown always correspond to the `total_matches` quoted in the title.
to_plot = rel_results[rel_results['n_matches'] == total_matches].sort_values('ESS / second')
assert not to_plot.empty, f'No ESS results for the full dataset ({total_matches} matches)'

# Reuse each method's line-plot colour so the figures match
colours_to_plot = [colour_lookup[x] for x in to_plot['method']]

sns.barplot(x=to_plot['ESS / second'], y=to_plot['method'], ax=ax, palette=colours_to_plot)

ax.grid(alpha=0.5)

ax.set_title(f'ESS / second when using all matches in dataset ({total_matches} matches)')

f.set_size_inches(8, 4)
f.tight_layout()

plt.savefig(join(plot_dir, 'ess_per_second_full.png'), dpi=300)

In [None]:
rel = to_plot.set_index('method', drop=False)

# How many times more ESS / second the best method achieves than each method (best = 1.0)
ess_rate = rel['ESS / second']
ess_rate.max() / ess_rate

In [None]:
# NOTE(review): hard-coded numbers with no visible source, presumably copied by hand
# from an earlier output. TODO: replace with the computed quantities, or delete the cell.
11.3 / 3.9

In [None]:
# Check estimates agree across backends, starting with the Stan (cmdstanpy) fit for 1968
stan_path = join(run_dir, 'cmdstanpy/samples_1968.netcdf')
stan_res = az.from_netcdf(stan_path)

In [None]:
# Show what the loaded InferenceData contains (presumably its group names, e.g. posterior -- verify)
list(stan_res.keys())

In [None]:
p_skills = stan_res.posterior['player_skills']

# Per-player posterior mean and sd, pooled over axes 0 and 1 of the raw array
# (chain and draw in ArviZ's posterior layout). Result is a plain numpy array.
skill_draws = p_skills.values
player_means_stan = skill_draws.mean(axis=(0, 1))
player_sds_stan = skill_draws.std(axis=(0, 1))

In [None]:
# PyMC fit on the same 1968 data
pymc_path = join(run_dir, 'pymc/samples_1968.netcdf')
pymc3_res = az.from_netcdf(pymc_path)

In [None]:
# Per-player posterior mean and sd, pooling chains and draws
pymc_skill_draws = pymc3_res.posterior['player_skills']
p_skills_pymc3 = pymc_skill_draws.mean(dim=('chain', 'draw'))
p_skills_pymc3_sd = pymc_skill_draws.std(dim=('chain', 'draw'))

In [None]:
# Load data for 1968 to get player names
cur_data = create_arrays(1968)

# JAX (GPU, vectorized) fit on the same data
jax_path = join(run_dir, 'pymc_jax_gpu_vectorized/samples_1968.netcdf')
jax_res = az.from_netcdf(jax_path)

In [None]:
# Inspect the player encoder; its `.classes_` provide the player names used as the table index at the end
cur_data['player_encoder']

In [None]:
# Per-player posterior mean and sd for the JAX fit, pooling chains and draws
jax_skill_draws = jax_res.posterior['player_skills']
p_skills_jax = jax_skill_draws.mean(dim=('chain', 'draw'))
p_skills_jax_sd = jax_skill_draws.std(dim=('chain', 'draw'))

In [None]:
f, ax = plt.subplots(1, 3)

# Pairwise comparison of per-player posterior means between backends:
# (x values, y values, x label, y label) for each panel.
mean_panels = [
    (p_skills_jax.values.reshape(-1), p_skills_pymc3.values.reshape(-1), 'JAX GPU means', 'PyMC means'),
    (p_skills_jax.values.reshape(-1), player_means_stan.reshape(-1), 'JAX GPU means', 'Stan means'),
    (p_skills_pymc3.values.reshape(-1), player_means_stan.reshape(-1), 'PyMC means', 'Stan means'),
]

for cur_ax, (x_vals, y_vals, x_label, y_label) in zip(ax, mean_panels):
    cur_ax.scatter(x_vals, y_vals)
    # y = x reference line, spanning the data on BOTH axes so no point lies
    # beyond its ends (previously it covered only the x range)
    lo = min(x_vals.min(), y_vals.min())
    hi = max(x_vals.max(), y_vals.max())
    cur_ax.plot([lo, hi], [lo, hi])
    cur_ax.set_xlabel(x_label)
    cur_ax.set_ylabel(y_label)

f.set_size_inches(12, 4)
f.tight_layout()

plt.savefig(join(plot_dir, 'mean_comparison.png'), dpi=300)

In [None]:
f, ax = plt.subplots(1, 3)

# Pairwise comparison of per-player posterior standard deviations between
# backends: (x values, y values, x label, y label) for each panel.
sd_panels = [
    (p_skills_jax_sd.values.reshape(-1), p_skills_pymc3_sd.values.reshape(-1), 'JAX GPU sds', 'PyMC sds'),
    (p_skills_jax_sd.values.reshape(-1), player_sds_stan.reshape(-1), 'JAX GPU sds', 'Stan sds'),
    (p_skills_pymc3_sd.values.reshape(-1), player_sds_stan.reshape(-1), 'PyMC sds', 'Stan sds'),
]

for cur_ax, (x_vals, y_vals, x_label, y_label) in zip(ax, sd_panels):
    cur_ax.scatter(x_vals, y_vals)
    # y = x reference line, spanning the data on BOTH axes so no point lies
    # beyond its ends (previously it covered only the x range)
    lo = min(x_vals.min(), y_vals.min())
    hi = max(x_vals.max(), y_vals.max())
    cur_ax.plot([lo, hi], [lo, hi])
    cur_ax.set_xlabel(x_label)
    cur_ax.set_ylabel(y_label)

f.set_size_inches(12, 4)
f.tight_layout()

plt.savefig(join(plot_dir, 'sd_comparison.png'), dpi=300)

In [None]:
# Shape of the raw Stan skill draws (the cell above averages over its first two axes)
p_skills.shape

In [None]:
# Shape of the JAX posterior means after pooling chains and draws
p_skills_jax.shape

In [None]:
# Top 20 players by posterior mean skill (JAX GPU fit), printed as a markdown table
skill_summary = pd.DataFrame(
    {'mean_skill': p_skills_jax.values, 'skill_sd': p_skills_jax_sd.values},
    index=cur_data['player_encoder'].classes_,
)
top_players = skill_summary.sort_values('mean_skill', ascending=False).head(20).round(2)
print(top_players.to_markdown())