In [None]:
%cd ..
import os
from pathlib import Path
import re
import pickle
import pandas as pd

import jax
import jax.numpy as jnp

from src.config.core import Config
import src.dataset as ds
import src.training.utils as train_utils
import src.inference.utils as inf_utils
from src.inference.evaluation import evaluate_bde
from src.types import ParamTree

from matplotlib import pyplot as plt
plt.rcParams['font.size'] = 10

import numpy as np

In [None]:
def load_results(path, type, ms, seeds, exploration_steps, schedule='constant'):
    """Collect the evaluation metrics of finished runs and aggregate them per ``m``.

    For every ``m`` in ``ms`` and every seed in ``seeds`` the run directory
    ``{path}/{m}x1_{schedule}_{exploration_steps}+500_seed{seed+i}`` (when
    ``type == 'parallel'``) or ``{path}/1x{m}_...`` (any other ``type``) is
    read, where ``i`` is the position of ``m`` inside ``ms``.

    Args:
        path: Directory containing the run directories.
        type: ``'parallel'`` selects the ``{m}x1`` layout, anything else the
            ``1x{m}`` layout.
        ms: Iterable of chain/cycle counts.
        seeds: Iterable of base seeds.
        exploration_steps: Exploration budget that is part of the directory name.
        schedule: Schedule name that is part of the directory name.

    Returns:
        DataFrame with one row per ``m`` and two-level columns: ``m`` plus
        ``mean``/``std`` for each of ``lppd``, ``rmse`` and ``time``.

    Runs without ``eval_metrics.pkl`` are skipped and counted; the per-``m``
    counts are printed at the end. Only that run is skipped - the remaining
    seeds of the same ``m`` are still loaded.
    """
    records = []
    nan_counts = []
    for i, m in enumerate(ms):
        nan_count = 0
        layout = f'{m}x1' if type == 'parallel' else f'1x{m}'
        for seed in seeds:
            # NOTE(review): the seed in the directory name is shifted by the
            # position of m in ms - presumably this mirrors how the runs were
            # launched; confirm against the launch script.
            run_seed = seed + i
            results_dir = f'{path}/{layout}_{schedule}_{exploration_steps}+500_seed{run_seed}'
            metrics_path = f'{results_dir}/eval_metrics.pkl'

            if not os.path.exists(metrics_path):
                nan_count += 1
                print(f'Skipping {results_dir} - does not exist')
                # Best-effort diagnosis: check whether the stored samples
                # contain NaNs. The samples may be missing as well, so a
                # failure here must not abort the whole aggregation.
                samples_dir = f'{results_dir}/samples'
                tree_path = f'{results_dir}/tree'
                try:
                    samples = train_utils.load_samples_from_dir(samples_dir, tree_path=tree_path)
                    has_nan = any(
                        bool(jnp.isnan(leaf).any()) for leaf in jax.tree.leaves(samples)
                    )
                    print(f'Found NaN in samples: {has_nan}')
                except Exception as exc:
                    print(f'Could not inspect samples in {samples_dir}: {exc!r}')
                # Skip only this run; keep going with the remaining seeds.
                continue

            # Pickle files are produced by our own training runs; do not point
            # this function at untrusted directories.
            with open(metrics_path, 'rb') as f:
                eval_metrics = pickle.load(f)
            with open(f'{results_dir}/samples/info.pkl', 'rb') as f:
                info = pickle.load(f)

            records.append({
                'm': m,
                'lppd': eval_metrics['lppd'],
                'rmse': eval_metrics['rmse'],
                'time': info['total_time'],
                'seed': run_seed,
            })
        nan_counts.append(nan_count)

    print(f'Nan counts: {nan_counts}')

    # Build the frame once instead of concatenating row by row.
    df = pd.DataFrame(records, columns=['m', 'lppd', 'rmse', 'time', 'seed'])
    return df.groupby('m').agg({
        'lppd': ['mean', 'std'],
        'rmse': ['mean', 'std'],
        'time': ['mean', 'std'],
    }).reset_index()

In [None]:
# Chain counts (parallel runs) / cycle counts (sequential runs) to compare.
ms=np.array([2, 4, 6, 8, 10, 12])
# Base seeds; load_results adds the position of m within ms to each seed when
# it builds the run directory name.
seeds = [0, 42, 221, 476, 1453, 1644, 1840, 1973, 2025, 2100]
# Smaller alternative seed set (disabled):
# seeds = [0, 42, 1973, 2025, 2100]
# Aggregated metrics for the constant schedule, for both layouts and for
# exploration budgets of 2000 and 5000 steps.
res_parallel_5k = load_results('results/constant', 'parallel', ms=ms, seeds=seeds, exploration_steps=5000)
res_parallel_2k = load_results('results/constant', 'parallel', ms=ms, seeds=seeds, exploration_steps=2000)
res_sequential_2k = load_results('results/constant', 'sequential', ms=ms, seeds=seeds, exploration_steps=2000)
res_sequential_5k = load_results('results/constant', 'sequential', ms=ms, seeds=seeds, exploration_steps=5000)
# Last expression of the cell: display one aggregated frame as a sanity check.
res_parallel_2k

In [None]:
def plot_comparison(res_parallel, res_sequential):
    """Plot LPPD and RMSE (mean with std error bars) for parallel vs. sequential runs.

    Both arguments are aggregated frames with an ``m`` column and
    ``mean``/``std`` sub-columns under ``lppd`` and ``rmse``. The x positions
    come from ``res_parallel['m']``; parallel points are shifted by +0.1 and
    sequential points by -0.1 so that the error bars do not overlap.

    Returns:
        The ``(fig, axs)`` pair, with LPPD on ``axs[0]`` and RMSE on ``axs[1]``.
    """
    ms = res_parallel['m'].to_numpy()

    # (frame, x offset, legend label, color) for each of the two variants.
    variants = (
        (res_parallel, 0.1, 'parallel', 'blue'),
        (res_sequential, -0.1, 'sequential', 'red'),
    )
    panels = (('lppd', 'LPPD'), ('rmse', 'RMSE'))

    # Pull every series out of the frames before any figure is created.
    stats = {
        (label, metric): (frame[metric]['mean'].to_numpy(), frame[metric]['std'].to_numpy())
        for frame, _, label, _ in variants
        for metric, _ in panels
    }

    fig, axs = plt.subplots(figsize=(6.3, 6.3/2), ncols=2)
    for ax, (metric, ylabel) in zip(axs, panels):
        for _, offset, label, color in variants:
            mean, std = stats[(label, metric)]
            ax.errorbar(ms + offset, mean, yerr=std, fmt='s--', label=label, color=color)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Number of chains/cycles')

    # Ticks sit on the unshifted m values.
    for ax in axs:
        ax.set_xticks(ms)
        ax.set_xticklabels(ms)

    # One shared legend below the panels; the dict drops repeated labels.
    handles, labels = axs[0].get_legend_handles_labels()
    unique_entries = dict(zip(labels, handles))
    fig.legend(
        unique_entries.values(),
        unique_entries.keys(),
        loc='lower center',
        ncol=2,
        bbox_to_anchor=(0.5, -0.1),
    )
    fig.tight_layout()

    return fig, axs


In [None]:
# Parallel vs. sequential with an exploration budget of 2000 steps.
fig, axs = plot_comparison(res_parallel_2k, res_sequential_2k)
# plt.savefig('../ba/images/parallel_vs_sequential_constant_2k.pdf', bbox_inches='tight')

# Same comparison with an exploration budget of 5000 steps (fig/axs are rebound).
fig, axs = plot_comparison(res_parallel_5k, res_sequential_5k)
# plt.savefig('../ba/images/parallel_vs_sequential_constant_5k.pdf', bbox_inches='tight')

In [None]:
# Display the LPPD mean/std sub-columns of the parallel 2k results.
res_parallel_2k['lppd']

In [None]:
def combine_mean_std(res):
    """Turn an aggregated mean/std frame into ``mean $\\pm$ std`` text columns.

    ``res`` has two-level columns (``m`` plus ``mean``/``std`` under ``lppd``,
    ``rmse`` and ``time``). The input is not modified. The result has the
    plain columns ``m``, ``lppd``, ``rmse`` and ``time``; the three metric
    columns are strings with both numbers rounded to three decimals.
    """
    flat = res.copy()
    # ('lppd', 'mean') -> 'lppd_mean'; ('m', '') -> 'm_'
    flat.columns = ["_".join(levels) for levels in flat.columns]

    three_decimals = '{:.3f}'.format
    for metric in ('lppd', 'rmse', 'time'):
        mean_key = f'{metric}_mean'
        std_key = f'{metric}_std'
        for key in (mean_key, std_key):
            flat[key] = flat[key].astype("float64").map(three_decimals)
        flat[metric] = flat[mean_key].astype(str) + r' $\pm$ ' + flat[std_key].astype(str)

    table = flat[["m_", "lppd", "rmse", "time"]]
    return table.rename(columns={'m_': 'm'})

def full_performance_table(res_par, res_seq, fn="tmp.tex"):
    """Write a LaTeX table comparing parallel and sequential LPPD/RMSE to ``fn``.

    Both aggregated frames are first formatted with ``combine_mean_std``. The
    table has a two-level header: the metric on top and ``parallel`` /
    ``sequential`` below; the ``$M$`` column is taken from ``res_par``.
    Runtime is not part of the table. Nothing is returned.
    """
    par_fmt = combine_mean_std(res_par)
    seq_fmt = combine_mean_std(res_seq)

    m_label = r'$M$'
    lppd_label = r'LPPD ($\uparrow$)'
    rmse_label = r'RMSE ($\downarrow$)'

    # Header key -> formatted column, in output order.
    columns = [
        ((m_label, ''), par_fmt['m']),
        ((lppd_label, 'parallel'), par_fmt['lppd']),
        ((lppd_label, 'sequential'), seq_fmt['lppd']),
        ((rmse_label, 'parallel'), par_fmt['rmse']),
        ((rmse_label, 'sequential'), seq_fmt['rmse']),
    ]
    header = pd.MultiIndex.from_tuples([key for key, _ in columns])

    table = pd.DataFrame(columns=header)
    for key, values in columns:
        table[key] = values

    with open(fn, 'w') as out:
        table.to_latex(
            out,
            index=False,
            column_format='c' * len(header),
            escape=False,
            multicolumn=True,
            multicolumn_format='c',
        )

# Export one LaTeX comparison table per exploration budget.
full_performance_table(res_parallel_2k, res_sequential_2k, fn="tmp_2k.tex")
full_performance_table(res_parallel_5k, res_sequential_5k, fn="tmp_5k.tex")

In [None]:
# more quantitative comparison
df = pd.DataFrame({
    r'$M$': res_parallel_2k['m'],
    'LPPD (2k)': res_sequential_2k['lppd']['mean'] - res_parallel_2k['lppd']['mean'],
    'LPPD (5k)': res_sequential_5k['lppd']['mean'] - res_parallel_5k['lppd']['mean'],
    'RMSE (2k)': res_sequential_2k['rmse']['mean'] - res_parallel_2k['rmse']['mean'],
    # '2k_runtime': res_sequential_2k['time']['mean'] - res_parallel_2k['time']['mean'],
    'RMSE (5k)': res_sequential_5k['rmse']['mean'] - res_parallel_5k['rmse']['mean'],
    # '5k_runtime': res_sequential_5k['time']['mean'] - res_parallel_5k['time']['mean']
})
summary = df.agg(["mean"])
# "average" of m is meaningless
print(df)
df = pd.concat([df, summary])
formatter = {
    'LPPD (2k)': lambda x: f'{x:.3f}',
    'LPPD (5k)': lambda x: f'{x:.3f}',
    'RMSE (2k)': lambda x: f'{x:.3f}',
    'RMSE (5k)': lambda x: f'{x:.3f}',
    r'$M$': lambda x: f'{int(x)}' if x != 7 else 'Average',
}
df.to_latex("tmp_diff.tex", index=False, formatters=formatter, column_format='rrrrr')

In [None]:
# Side-by-side runtime statistics (mean and std of the recorded total time)
# for the 5k-exploration runs; displayed as the cell output.
pd.DataFrame({
    'm': res_parallel_5k['m'],
    'mean_time_parallel': res_parallel_5k['time']['mean'],
    'std_time_parallel': res_parallel_5k['time']['std'],
    'mean_time_sequential': res_sequential_5k['time']['mean'],
    'std_time_sequential': res_sequential_5k['time']['std']
})

In [None]:
# runtime
fig, ax = plt.subplots(figsize=(0.8*6.3, 0.8*6.3/3*2))
# ax.plot(ms, res_parallel_5k['time']['mean'], 'o--', label='parallel', color='blue')
# ax.plot(ms, res_sequential_5k['time']['mean'], 'o--', label='sequential', color='red')
ax.errorbar(ms, res_parallel_5k['time']['mean'], yerr=np.asarray(res_parallel_5k['time']['std']), fmt='o--', label='parallel+constant', color='blue', markersize=3)
ax.errorbar(ms, res_sequential_5k['time']['mean'], yerr=np.asarray(res_sequential_5k['time']['std']), fmt='o--', label='sequential+constant', color='red', markersize=3)
ax.set_ylabel('Runtime [s]')
ax.set_xlabel('Number of chains/cycles')
plt.legend(loc='upper left')
# plt.savefig('../ba/images/parallel_vs_sequential_constant_5k_runtime.pdf', bbox_inches='tight')

### Exploration budget

In [None]:
def evaluate_bde_from_file(
    results_dir: Path,
    samples: ParamTree | None = None,
    cycle: int | None = None,
    chain: int | None = None,
    batch_size: int | None = None
):
    """Evaluate the samples of one run on the test split.

    Args:
        results_dir: Run directory; ``config.yaml`` is read from it, and
            ``samples`` / ``tree`` as well when ``samples`` is not given.
        samples: Already loaded samples. Loaded from ``results_dir`` if None.
        cycle: If given, only the samples of this cycle are evaluated. The
            second axis of every leaf is split into ``n_cycles`` equally long
            consecutive slices, with ``n_cycles`` taken from the scheduler
            parameters in the config.
        chain: If given, only this index along the first axis of every leaf
            is evaluated; the axis is kept with length one.
        batch_size: Forwarded to ``evaluate_bde``.

    Returns:
        The ``(logits, metrics)`` pair produced by ``evaluate_bde``.

    Raises:
        AssertionError: If ``cycle`` or ``chain`` is out of range.
    """
    if samples is None:
        samples = train_utils.load_samples_from_dir(
            results_dir / 'samples',
            results_dir / 'tree',
        )

    config = Config.from_yaml(results_dir / 'config.yaml')
    n_samples = inf_utils.count_samples(samples)
    n_cycles = config.training.sampler.scheduler_config.parameters['n_cycles']
    n_chains = inf_utils.count_chains(samples)
    n_samples_per_cycle = n_samples // n_cycles

    if cycle is not None:
        assert 0 <= cycle < n_cycles, f'Cycle index {cycle} must be between 0 and {n_cycles}-1'
        start = cycle * n_samples_per_cycle
        stop = start + n_samples_per_cycle

        def take_cycle(x):
            return x[:, start:stop]

        samples = jax.tree.map(take_cycle, samples)

    if chain is not None:
        assert 0 <= chain < n_chains, f'Chain index {chain} must be between 0 and {n_chains}-1'

        def take_chain(x):
            # Re-add the leading axis so the chain dimension is always present.
            return x[chain, ...][None, ...]

        samples = jax.tree.map(take_chain, samples)

    module = config.get_flax_model()
    loader = ds.TabularLoader(
        config.data,
        rng=config.jax_rng,
        target_len=config.data.target_len
    )

    logits, metrics = evaluate_bde(
        params=samples, # type: ignore
        module=module,
        features=loader.test_x, # (B x F)
        labels=loader.test_y, # (B x T)
        task=config.data.task,
        batch_size=batch_size,
        verbose=False,
        metrics_dict={},
        nominal_coverages=[0.5, 0.75, 0.9, 0.95],
        per_chain=False
    )

    return logits, metrics

In [None]:
from IPython.display import clear_output

dir = Path('results/exploration_budget')

# Run directories are matched by name; the first captured number is the
# exploration budget.
pattern = r'1x12_constant_(\d+)\+(\d+)_10_seed(\d+)'
n_cycles = 12

# One record per (run, cycle).
all_metrics = []
for result in sorted(os.listdir(dir)):
    match = re.match(pattern, result)
    if match is None:
        continue

    results_dir = dir / result
    exploration_steps = int(match.group(1))
    # Load the samples once per run and reuse them for every cycle.
    samples = train_utils.load_samples_from_dir(
        results_dir / 'samples',
        tree_path=results_dir / 'tree',
    )

    for cycle in range(n_cycles):
        print(f'Evaluating cycle {cycle} in {results_dir}')
        logits, metrics = evaluate_bde_from_file(
            results_dir=results_dir,
            samples=samples,
            cycle=cycle,
        )
        clear_output(wait=True)
        all_metrics.append(dict(
            cycle=cycle + 1,  # 1-indexed for plotting
            exploration_steps=exploration_steps,
            lppd=metrics['lppd'],
            rmse=metrics['rmse'],
        ))

df = pd.DataFrame(all_metrics)

In [None]:
# Mean/std of the per-cycle metrics over all runs with the same exploration budget.
df_agg = df.groupby(['exploration_steps', 'cycle']).agg(
    {'lppd': ['mean', 'std'], 'rmse': ['mean', 'std']}
)

# One panel per exploration budget, filled row by row; all panels share the y axis.
fig, axs = plt.subplots(2, 3, figsize=(6.3, 6.3/4*3), sharey='all')
for i, exploration_steps in enumerate([2000, 3000, 4000, 5000, 6000, 7000]):
    ax = axs.flat[i]
    df_subset = df_agg.xs(exploration_steps, level='exploration_steps')
    cycles = df_subset.index.get_level_values('cycle')
    lppd_means = df_subset['lppd']['mean']
    lppd_stds = df_subset['lppd']['std']

    ax.errorbar(
        cycles,
        lppd_means,
        yerr=lppd_stds,
        fmt='o',
        color='red',
        markersize=2,
    )
    ax.set(title=f'{exploration_steps//1000}k steps', xlabel='Cycle Index')
    ax.set_xticks([2, 4, 6, 8, 10, 12])

# Only the left-most panel of each row gets a y label.
axs[0,0].set_ylabel('LPPD')
axs[1,0].set_ylabel('LPPD')
fig.tight_layout()
# fig.savefig('../ba/images/ablation_cycle_length.pdf', bbox_inches='tight')

### Cosine schedule (only parallel)

In [None]:
# Cosine schedule: only parallel runs are available, with fewer M values and seeds.
ms = [2, 4, 6, 8]
seeds = [0, 42, 1973, 2025, 2100]
res_parallel_cos = load_results('results/cosine', 'parallel', ms=ms, seeds=seeds, exploration_steps=11500, schedule='cosine')

# combine_mean_std already returns the chain-count column as 'm' (not 'm_'),
# so 'm' is the key that has to be renamed for the $M$ header to show up.
# The 'time' column keeps its plain name.
res_parallel_cos_combined = combine_mean_std(res_parallel_cos).rename(columns={
    'm': r'$M$',
    'lppd': r'LPPD ($\uparrow$)',
    'rmse': r'RMSE ($\downarrow$)',
})
res_parallel_cos_combined.to_latex(
    'tmp_cosine.tex',
    index=False,
    column_format='c' * len(res_parallel_cos_combined.columns),
    escape=False,
    multicolumn=True,
    multicolumn_format='c',
)