## Setup

In [1]:
import os
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pickle
from nadaray_watson_ce import nadaraya_watson_ece, dirichlet_calibration_error
from projected_ce import top_calibration_error, class_wise_calibration_error, projected_calibration_error
from utils import median_heuristic

In [2]:
def generative_model(key, alpha, beta, pi, n_samples):
    """
    Simulates the generative model described in the text.

    Parameters:
    - key: JAX random key.
    - alpha: Dirichlet concentration parameters (m-dimensional).
    - beta: Category probabilities for Y|Z=0 (m-dimensional).
    - pi: Bernoulli parameter for Z.
    - n_samples: Number of samples to generate.

    Returns:
    - g_X: Predictions g(X), sampled from Dir(alpha).
    - Z: Latent variable Z, sampled from Ber(pi).
    - Y: Labels Y, sampled conditionally on Z and g(X).
    """
    keys = jax.random.split(key, 3)

    # Sample g(X) ~ Dir(alpha)
    g_X = jax.random.dirichlet(keys[0], alpha, shape=(n_samples,))

    # Sample Z ~ Ber(pi)
    Z = jax.random.bernoulli(keys[1], p=pi, shape=(n_samples,)).astype(jnp.float32)

    # Compute the combined distribution for Y
    dist = (1 - Z[:, None]) * g_X + Z[:, None] * beta[None, :]

    # Sample Y from the categorical distribution using dist
    Y = jax.random.categorical(keys[2], logits=jnp.log(dist), axis=1)

    return g_X, Z, Y

## Calibration Errors

In [None]:

params = {
        "dim":10 , 
        "alpha":1*jnp.ones(10,), 
        "beta":1*jnp.ones(10,),
        "pi_m1":0.0 , 
        "pi_m2":0.2 , 
        "pi_m3":0.6, 
        "kernelized":False, 
        "num_bins":20, 
        "equal_size":False , 
        "num_proj"
    }

calibration_methods = {
        "top_class": lambda probs, labels: top_calibration_error(probs, labels, num_bins=params['num_bins'], equal_size=params['equal_size']),
        "random_projected": lambda key, probs, labels: projected_calibration_error(key, probs, labels, num_bins=params['num_bins'] ,  equal_size=params['equal_size']))
    }

jitted_calibration_methods = {k:jax.jit(calibration_methods[k]) for k in calibration_methods.keys()}

# Main experimental function
def run_experiments(key,params:dict, calibration_methods:dict, num_runs=100, num_samples=1000):
    
    
    methods_name = list(calibration_methods.keys())
    all_results = {"M1": {m:[] for m in methods_name},
                   "M2": {m:[] for m in methods_name},
                   "M3": {m:[] for m in methods_name}}
    
    models = {
        "M1": lambda key: generative_model(key,params['alpha'], params["beta"], params["pi_m1"], num_samples),
        "M2": lambda key: generative_model(key,params['alpha'], params["beta"], params["pi_m2"], num_samples),
        "M3": lambda key: generative_model(key,params['alpha'], params["beta"], params["pi_m3"], num_samples),
    }
    
    for run in range(num_runs):
        for model_name, model_fn in models.items():
            
            if "random" in model_name:
                key, _ = jax.random.split(key,2)
                error = 
            probs, labels = model_fn()
            errors = model_fn()
            all_results[model_name].append(errors)


# Load or run experiments
def load_or_run_experiments():
    if os.path.exists("calibration_experiment_results.pkl"):
        with open("calibration_experiment_results.pkl", "rb") as f:
            all_results = pickle.load(f)
    else:
        all_results = run_experiments()
    return all_results

# Plot results
def plot_results(all_results):
    metrics = list(next(iter(all_results.values()))[0].keys())
    num_models = len(all_results.keys())
    
    fig, axes = plt.subplots(len(metrics), num_models, figsize=(12, 8), sharex=False, sharey=False)
    fig.subplots_adjust(hspace=0.4, wspace=0.4)

    for i, metric in enumerate(metrics):
        for j, model in enumerate(all_results.keys()):
            ax = axes[i, j]
            data = [res[metric] for res in all_results[model]]

            # Plot histogram
            ax.hist(data, bins=30, alpha=0.7, color='teal', edgecolor='black', density=False)

            # Add vertical lines for mean (orange) and dashed blue for zero
            ax.axvline(x=np.mean(data), color='orange', linewidth=1.5, label='Mean')
            ax.axvline(x=0, color='blue', linestyle='--', linewidth=1.2)

            # Set labels and titles
            if i == 0:
                ax.set_title(f"{model}", fontsize=12)
            if j == 0:
                ax.set_ylabel(f"{metric}", fontsize=12)
            if i == len(metrics) - 1:
                ax.set_xlabel("Calibration Error Estimate", fontsize=10)

    # Add legend to one subplot
    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper right", fontsize=10)
    plt.tight_layout()
    plt.show()

# Execute
all_results = load_or_run_experiments()
plot_results(all_results)


SyntaxError: invalid syntax (2777148048.py, line 39)