# Hierarchical AR(1) inference with ABI

We evaluate a coupling-flow estimator (Heinrich et al., 2024; Habermann et al., 2024) and the v-prediction diffusion model (DM) that is used as the backbone of our compositional approach.

In [None]:
import os
os.environ['KERAS_BACKEND'] = 'jax'

import numpy as np
from scipy.stats import median_abs_deviation as mad

import keras
import bayesflow as bf
from bayesflow import diagnostics

from problems import AR1GridPrior, AR1GridSimulator

In [None]:
# Number of local grid cells per dataset: index 0 selects the 4x4 grid,
# index 1 would select the 16x16 grid.
grid_size = (4 * 4, 16 * 16)[0]
grid_size

In [None]:
# Prior and simulator objects for the AR(1) grid problem.
prior = AR1GridPrior()
sim = AR1GridSimulator()

# Match budget of compositional: a fixed total of 10000 local samples is
# spread over datasets that each contain `grid_size` local samples.
train_dict = prior.sample(batch_size=10000 // grid_size, n_local_samples=grid_size)

# Normalize the global parameters with the prior's helper, then turn the
# backend tensor into a NumPy array.
train_dict['global_params'] = prior.normalize_theta(train_dict['global_params'], global_params=True)
train_dict['global_params'] = keras.ops.convert_to_numpy(train_dict['global_params'])

# Same two steps for the simulated data.
train_dict['data'] = prior.normalize_data(train_dict['data'])
train_dict['data'] = keras.ops.convert_to_numpy(train_dict['data'])

# Displayed as the cell output as a quick sanity check.
train_dict['global_params'].shape

In [None]:
# Held-out evaluation set: 100 datasets drawn from the same prior, each with
# `grid_size` local samples.
test_dict = prior.sample(batch_size=100, n_local_samples=grid_size)

# Apply the same normalization as for the training set and convert the
# results to NumPy arrays.
test_dict['global_params'] = prior.normalize_theta(test_dict['global_params'], global_params=True)
test_dict['global_params'] = keras.ops.convert_to_numpy(test_dict['global_params'])

test_dict['data'] = prior.normalize_data(test_dict['data'])
test_dict['data'] = keras.ops.convert_to_numpy(test_dict['data'])

In [None]:
# Adapter that maps the raw sample dictionaries onto the inputs expected by
# the workflow. Built step by step; each call hands back the adapter to keep
# configuring.
adapter = bf.adapters.Adapter()
adapter = adapter.to_array()
adapter = adapter.convert_dtype("float64", "float32")  # float64 -> float32
# The global parameters are the inference targets ...
adapter = adapter.rename("global_params", "inference_variables")
# ... and the data is what gets passed to the summary network.
adapter = adapter.rename("data", "summary_variables")
# Restrict the adapted dictionary to these two entries.
adapter = adapter.keep(["inference_variables", "summary_variables"])

In [None]:
# Inference networks to compare.
# Maps a short model name to a (network class, constructor kwargs) pair; the
# network is instantiated later as `network_class(**kwargs)`.
models = {
    # Coupling flow with spline transforms.
    "coupling": (
        bf.networks.CouplingFlow,
        dict(transform="spline", depth=2),
    ),
    # Diffusion model with cosine noise schedule and velocity prediction.
    "dm_cosine_v": (
        bf.networks.DiffusionModel,
        dict(
            subnet_kwargs=dict(widths=(256, 256, 256, 256, 256), dropout=0.1),
            noise_schedule="cosine",
            schedule_kwargs=dict(weighting="likelihood_weighting"),
            prediction_type="velocity",
            integration_kwargs=dict(method="euler_maruyama", steps=300),
        ),
    ),
}

In [None]:
# Train one amortized estimator per entry in `models` on `train_dict` and
# report RMSE and posterior contraction of the global parameters on
# `test_dict`.
for model_name, model_packet in models.items():
    # Each registry entry is a (network class, constructor kwargs) pair.
    network_cls, network_kwargs = model_packet

    workflow_global = bf.BasicWorkflow(
        adapter=adapter,
        summary_network=bf.networks.DeepSet(summary_dim=5, dropout=0.1, depth=1),  # shallow summary net
        inference_network=network_cls(**network_kwargs),
        #checkpoint_filepath=f"bf_checkpoints/ar1_{model_name}_{grid_size}",
        # Parameters and data were already normalized with the prior's helpers
        # when the dictionaries were built, so no standardization is requested.
        standardize=None,
    )

    # The coupling flow is trained for 100 epochs, every other model for 1000.
    n_epochs = 100 if model_name == "coupling" else 1000
    history = workflow_global.fit_offline(
        train_dict, batch_size=32, epochs=n_epochs, verbose=2
    )

    # Evaluation: draw 1000 posterior samples per test dataset.
    test_global_samples = workflow_global.sample(conditions=test_dict, num_samples=1000)

    # RMSE aggregated with the median, plus the median absolute deviation as a
    # spread measure; both are averaged over the returned values and rounded.
    global_rmse = diagnostics.metrics.root_mean_squared_error(
        test_global_samples, test_dict, aggregation=np.median
    )['values'].mean().round(2)
    global_rmse_mad = diagnostics.metrics.root_mean_squared_error(
        test_global_samples, test_dict, aggregation=mad
    )['values'].mean().round(2)
    print('Global RMSE:', global_rmse, global_rmse_mad)

    # Posterior contraction, summarized the same way. Stored under its own
    # names so the RMSE results above are not overwritten.
    global_contraction = diagnostics.posterior_contraction(
        test_global_samples, test_dict, aggregation=np.median
    )['values'].mean().round(2)
    global_contraction_mad = diagnostics.posterior_contraction(
        test_global_samples, test_dict, aggregation=mad
    )['values'].mean().round(2)
    print('Global Contraction:', global_contraction, global_contraction_mad)

## 10x Simulation Budget

In [None]:
# Ten times the simulation budget used to match the compositional setup:
# the number of training datasets is multiplied by 10.
train_dict = prior.sample(batch_size=(10000 // grid_size) * 10, n_local_samples=grid_size)

# Normalize the global parameters with the prior's helper, then turn the
# backend tensor into a NumPy array.
train_dict['global_params'] = prior.normalize_theta(train_dict['global_params'], global_params=True)
train_dict['global_params'] = keras.ops.convert_to_numpy(train_dict['global_params'])

# Same two steps for the simulated data.
train_dict['data'] = prior.normalize_data(train_dict['data'])
train_dict['data'] = keras.ops.convert_to_numpy(train_dict['data'])

# Displayed as the cell output as a quick sanity check.
train_dict['global_params'].shape

In [None]:
# Repeat training and evaluation with the enlarged `train_dict`: one
# estimator per entry in `models`, scored on the unchanged `test_dict`.
for model_name, model_packet in models.items():
    # Each registry entry is a (network class, constructor kwargs) pair.
    network_cls, network_kwargs = model_packet

    workflow_global = bf.BasicWorkflow(
        adapter=adapter,
        summary_network=bf.networks.DeepSet(summary_dim=5, dropout=0.1, depth=1),  # shallow summary net
        inference_network=network_cls(**network_kwargs),
        #checkpoint_filepath=f"bf_checkpoints/ar1_{model_name}_{grid_size}_x10",
        # Parameters and data were already normalized with the prior's helpers
        # when the dictionaries were built, so no standardization is requested.
        standardize=None,
    )

    # The coupling flow is trained for 100 epochs, every other model for 1000.
    n_epochs = 100 if model_name == "coupling" else 1000
    history = workflow_global.fit_offline(
        train_dict, batch_size=32, epochs=n_epochs, verbose=2
    )

    # Evaluation: draw 1000 posterior samples per test dataset.
    test_global_samples = workflow_global.sample(conditions=test_dict, num_samples=1000)

    # RMSE aggregated with the median, plus the median absolute deviation as a
    # spread measure; both are averaged over the returned values and rounded.
    global_rmse = diagnostics.metrics.root_mean_squared_error(
        test_global_samples, test_dict, aggregation=np.median
    )['values'].mean().round(2)
    global_rmse_mad = diagnostics.metrics.root_mean_squared_error(
        test_global_samples, test_dict, aggregation=mad
    )['values'].mean().round(2)
    print('Global RMSE:', global_rmse, global_rmse_mad)

    # Posterior contraction, summarized the same way. Stored under its own
    # names so the RMSE results above are not overwritten.
    global_contraction = diagnostics.posterior_contraction(
        test_global_samples, test_dict, aggregation=np.median
    )['values'].mean().round(2)
    global_contraction_mad = diagnostics.posterior_contraction(
        test_global_samples, test_dict, aggregation=mad
    )['values'].mean().round(2)
    print('Global Contraction:', global_contraction, global_contraction_mad)