In [5]:
import os
import pickle
import pandas as pd
import jax.numpy as jnp
import jax.random as random
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, log_loss, brier_score_loss

import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive, init_to_mean
from numpyro.infer.autoguide import AutoDiagonalNormal, AutoMultivariateNormal

In [3]:
# Load the held-out validation split: "label" is the target column and every
# other column is used as a model input feature.
val_df = pd.read_csv("data/processed/validation.csv")
X_val, y_val = (
    jnp.asarray(val_df.drop(columns="label").to_numpy()),
    jnp.asarray(val_df["label"].to_numpy()),
)

In [4]:
# Gamma(concentration, rate) hyper-priors for the shared network precision
# `prec_nn` in bnn_model. Each key equals the prior mean concentration / rate,
# so a run's "p<level>" tag names the expected precision.
prec_prior_map = dict([
    (0.5, (1.0, 2.0)),  # mean 0.5, variance 1/4
    (1.0, (6.0, 6.0)),  # mean 1.0, variance 1/6
    (2.0, (4.0, 2.0)),  # mean 2.0, variance 1
])

In [6]:
def bnn_model(X, y=None, hidden_dim=10, prec_level=1.0, use_weights=False, weights=None):
    """One-hidden-layer ReLU Bayesian neural network for binary classification.

    A single precision ``prec_nn`` ~ Gamma(*prec_prior_map[prec_level]) is shared
    by every weight and bias prior. Each layer's prior std is further scaled by
    1 / sqrt(fan_in + 1).

    Args:
        X: design matrix; its two dimensions are read as (n, m).
        y: observed labels for the Bernoulli likelihood, or None to sample them.
        hidden_dim: number of hidden units.
        prec_level: key into ``prec_prior_map`` selecting the Gamma hyper-prior.
        use_weights: when True *and* ``weights`` is given, each observation's
            log-likelihood is multiplied by its weight via ``handlers.scale``.
        weights: per-observation scale factors (presumably length n; TODO confirm).
    """
    n_obs, n_feat = X.shape
    concentration, rate = prec_prior_map[prec_level]
    prec_nn = numpyro.sample('prec_nn', dist.Gamma(concentration, rate))

    # Prior std per layer; the "+ 1" presumably accounts for the bias term.
    std_in = 1.0 / jnp.sqrt(prec_nn * (n_feat + 1))
    std_out = 1.0 / jnp.sqrt(prec_nn * (hidden_dim + 1))

    # Site order (prec_nn, b1, w1, w2, b2, obs) is kept fixed so PRNG splitting
    # matches earlier runs.
    with numpyro.plate('l1_hidden', hidden_dim):
        b1 = numpyro.sample('nn_b1', dist.Normal(0.0, std_in))
        with numpyro.plate('l1_feat', n_feat):
            # nested plates give w1 the shape (m, hidden_dim) used in X @ w1
            w1 = numpyro.sample('nn_w1', dist.Normal(0.0, std_in))
    with numpyro.plate('l2_hidden', hidden_dim):
        w2 = numpyro.sample('nn_w2', dist.Normal(0.0, std_out))
    b2 = numpyro.sample('nn_b2', dist.Normal(0.0, std_out))

    activations = jnp.maximum(X @ w1 + b1, 0)  # ReLU
    likelihood = dist.Bernoulli(logits=activations @ w2 + b2)

    weighted = use_weights and weights is not None
    with numpyro.plate('data', n_obs):
        if not weighted:
            numpyro.sample('obs', likelihood, obs=y)
        else:
            with numpyro.handlers.scale(scale=weights):
                numpyro.sample('obs', likelihood, obs=y)

In [7]:
def compute_ece(probabilities, true_labels, n_bins=10):
    """Expected Calibration Error (ECE) for binary classification.

    [0, 1] is split into ``n_bins`` equal-width bins. ECE is the sample-weighted
    average, over bins, of |mean predicted probability - fraction of positive
    labels| in that bin. Empty bins contribute nothing. A probability of exactly
    1.0 belongs to the top bin; values outside [0, 1] are assigned to the
    nearest end bin.

    Args:
        probabilities: predicted probability of the positive class, one per sample.
        true_labels: 0/1 ground-truth labels aligned with ``probabilities``.
        n_bins: number of equal-width bins; must be >= 1.

    Returns:
        The ECE as a float (0.0 for empty input).

    Raises:
        ValueError: if ``n_bins < 1`` or the two inputs differ in length.
    """
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")
    probs = np.asarray(probabilities, dtype=float).ravel()
    labels = np.asarray(true_labels, dtype=float).ravel()
    if probs.shape != labels.shape:
        raise ValueError(
            f"length mismatch: {probs.size} probabilities vs {labels.size} labels"
        )
    if probs.size == 0:
        return 0.0

    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    # digitize(p, edges) - 1 yields n_bins for p == 1.0 (past the last edge).
    # Clip it into the top bin so fully confident predictions are not dropped.
    bin_indices = np.clip(np.digitize(probs, bin_edges) - 1, 0, n_bins - 1)

    # Per bin: |mean_conf - mean_acc| * count / N == |sum_conf - sum_acc| / N,
    # so the whole sum is computed in one vectorized pass.
    conf_sums = np.bincount(bin_indices, weights=probs, minlength=n_bins)
    label_sums = np.bincount(bin_indices, weights=labels, minlength=n_bins)
    return float(np.abs(conf_sums - label_sums).sum() / probs.size)

In [10]:
def parse_run_id(run_id):
    """Split a results-file stem such as ``"vi_AutoDiag_w10_p0.5"`` into its parts.

    Returns:
        ``(method, hidden_dim, prec_level)``, e.g. ``("vi_AutoDiag", 10, 0.5)``.

    Raises:
        ValueError: if the stem does not end in ``_w<int>_p<float>``.
    """
    try:
        method, width_tag, prec_tag = run_id.rsplit("_", 2)
    except ValueError:
        raise ValueError(f"unexpected run id {run_id!r}") from None
    if not (width_tag.startswith("w") and prec_tag.startswith("p")):
        raise ValueError(f"unexpected run id {run_id!r}")
    return method, int(width_tag[1:]), float(prec_tag[1:])


def score_predictions(true_labels, pred_probs, threshold=0.5):
    """Classification and calibration metrics for binary predictions.

    Args:
        true_labels: 0/1 ground truth as a NumPy array.
        pred_probs: predicted probability of class 1, aligned with ``true_labels``.
        threshold: probability at or above which a sample is labelled 1.

    Returns:
        dict with ``precision``/``recall``/``f1`` rounded to 3 d.p. and
        ``log_loss``/``brier``/``ece`` rounded to 4 d.p.
    """
    pred_labels = (pred_probs >= threshold).astype(int)
    return {
        # zero_division=0 is the value sklearn's default uses, minus the warning
        'precision': round(precision_score(true_labels, pred_labels, zero_division=0), 3),
        'recall': round(recall_score(true_labels, pred_labels, zero_division=0), 3),
        'f1': round(f1_score(true_labels, pred_labels, zero_division=0), 3),
        # explicit labels keep log_loss defined even if one class is absent
        'log_loss': round(log_loss(true_labels, pred_probs, labels=[0, 1]), 4),
        'brier': round(brier_score_loss(true_labels, pred_probs), 4),
        'ece': round(compute_ece(pred_probs, true_labels, n_bins=10), 4),
    }


# Sorted so runs (and the CSV rows) come out in a reproducible order.
result_files = sorted(f for f in os.listdir("results") if f.endswith(".pkl"))
evaluation_logs = []
key = random.PRNGKey(0)  # same key for every run -> reproducible, comparable predictions
true_labels_np = np.asarray(y_val)  # loop-invariant; sklearn wants NumPy input

for filename in result_files:
    run_id = os.path.splitext(filename)[0]  # not `id`: that would shadow the builtin
    path = os.path.join("results", filename)
    print(f"Evaluating {run_id}")

    try:
        method, hidden_dim, prec_level = parse_run_id(run_id)

        # NOTE: pickle.load can execute arbitrary code; only load files this project wrote.
        with open(path, "rb") as f:
            posterior_samples = pickle.load(f)

        predictive = Predictive(
            bnn_model,
            posterior_samples=posterior_samples,
            return_sites=["obs"],
            parallel=True,
        )
        preds = predictive(
            key, X_val, hidden_dim=hidden_dim, prec_level=prec_level, use_weights=False
        )['obs']  # shape: (samples, N), each entry a 0/1 draw

        # Monte Carlo estimate of P(y=1 | x): fraction of posterior draws predicting 1.
        # NOTE(review): averaging sigmoid(logits) per draw would have lower variance,
        # but bnn_model does not expose the logits as a site.
        pred_probs = preds.mean(axis=0)
        pred_probs_np = np.asarray(pred_probs)

        evaluation_logs.append({
            'id': run_id,
            'method': method,
            'hidden_dim': hidden_dim,
            'prec_level': prec_level,
            **score_predictions(true_labels_np, pred_probs_np),
        })

    except Exception as exc:  # best effort: keep evaluating the remaining runs
        print(f"Failed to evaluate {run_id}: {exc}")
        # Metric columns stay NaN (numeric) for this row; the reason goes in its own column.
        evaluation_logs.append({'id': run_id, 'error': f"{type(exc).__name__}: {exc}"})

# Save evaluation results
eval_df = pd.DataFrame(evaluation_logs)
eval_df.to_csv('results/evaluation_metrics.csv', index=False)
print("Evaluation metrics saved to results/evaluation_metrics.csv")

Evaluating mcmc_NUTS_w10_p0.5
Evaluating mcmc_NUTS_w10_p1.0
Evaluating mcmc_NUTS_w10_p2.0
Evaluating mcmc_NUTS_w14_p0.5
Evaluating mcmc_NUTS_w14_p1.0
Evaluating mcmc_NUTS_w14_p2.0
Evaluating mcmc_NUTS_w5_p0.5
Evaluating mcmc_NUTS_w5_p1.0
Evaluating mcmc_NUTS_w5_p2.0
Evaluating vi_AutoDiag_w10_p0.5
Evaluating vi_AutoDiag_w10_p1.0
Evaluating vi_AutoDiag_w10_p2.0
Evaluating vi_AutoDiag_w14_p0.5
Evaluating vi_AutoDiag_w14_p1.0
Evaluating vi_AutoDiag_w14_p2.0
Evaluating vi_AutoDiag_w5_p0.5
Evaluating vi_AutoDiag_w5_p1.0
Evaluating vi_AutoDiag_w5_p2.0
Evaluating vi_AutoMult_w10_p0.5
Evaluating vi_AutoMult_w10_p1.0
Evaluating vi_AutoMult_w10_p2.0
Evaluating vi_AutoMult_w14_p0.5
Evaluating vi_AutoMult_w14_p1.0
Evaluating vi_AutoMult_w14_p2.0
Evaluating vi_AutoMult_w5_p0.5
Evaluating vi_AutoMult_w5_p1.0
Evaluating vi_AutoMult_w5_p2.0
Evaluation metrics saved to results/evaluation_metrics.csv
