# Training Bayesian Neural Network (BNN) models

Import the libraries needed for numerical computing, data manipulation, and probabilistic programming:

In [None]:
import numpy as np
import pandas as pd
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_mean
from numpyro.infer.autoguide import AutoDiagonalNormal, AutoMultivariateNormal
from sklearn.model_selection import train_test_split
import pickle
import os
import time

# Ask XLA to expose two host (CPU) devices so NumPyro can run the MCMC chains
# in parallel. NumPyro documents that this only takes effect at the start of
# the program, i.e. before JAX initialises its backend.
n_host_devices = 2
numpyro.set_host_device_count(n_host_devices)

Load the training dataset and separate the features (X) from the class labels (y):

In [None]:
# Read the interim training table; 'label' is the target column and every
# other column is treated as a feature.
train = pd.read_csv('data/interim/train.csv')
y = train['label']
X = train.drop(columns='label')

Split the dataset into training and validation sets, holding out 1,000 rows for validation and stratifying on the label so both sets keep the same class distribution:

In [None]:
# Hold out exactly 1000 rows for validation (an int test_size is an absolute
# count, not a fraction); stratify=y keeps the class proportions approximately
# equal in both splits, and the fixed seed makes the split reproducible.
X_train, X_val, y_train, y_val = train_test_split(
    X,
    y,
    test_size=1000,
    stratify=y,
    random_state=123,
)

Wrap the split features and labels in DataFrames with explicit column names (the feature columns of X and a single `label` column):

In [33]:
# Re-wrap the split outputs as DataFrames with explicit column labels:
# features keep X's column names and the target column is called 'label'.
X_train_df, X_val_df = (
    pd.DataFrame(part, columns=X.columns) for part in (X_train, X_val)
)
y_train_df, y_val_df = (
    pd.DataFrame(part, columns=['label']) for part in (y_train, y_val)
)

Combine labels and features into single DataFrames for training and validation:

In [34]:
# Label column first, then the features. concat aligns the two frames on
# their shared row index, so each label stays attached to its own features.
train_df = pd.concat([y_train_df, X_train_df], axis='columns')
val_df = pd.concat([y_val_df, X_val_df], axis='columns')

Save processed datasets for reproducibility:

In [35]:
# Persist the exact train/validation split so later notebooks can reuse it.
# to_csv does not create missing parent directories, so ensure the target
# folder exists first (mirrors how the results folders are created below).
os.makedirs('data/processed', exist_ok=True)
train_df.to_csv("data/processed/train.csv", index=False)
val_df.to_csv("data/processed/validation.csv", index=False)

Convert datasets to JAX arrays for NumPyro compatibility:

In [36]:
# Replace the pandas objects with JAX arrays so they can be passed straight
# into the NumPyro model (X_* are 2-D feature matrices, y_* 1-D label vectors).
X_train, X_val, y_train, y_val = (
    jnp.array(data) for data in (X_train, X_val, y_train, y_val)
)

Define the experimental grid: hidden-layer widths, prior means for the weight precision, inference methods, and the VI guides and MCMC kernels to compare:

In [None]:
# Experimental grid: every (kernel or guide, width, precision prior)
# combination below becomes one run of the main loop.
widths = [5, 10, 14]                  # hidden-layer sizes
precision_priors = [0.01, 0.1, 1.0]   # must be keys of precision_prior_map (defined below)
inference_methods = ['mcmc', 'vi']
vi_guides = dict(
    AutoDiag=AutoDiagonalNormal,
    AutoMult=AutoMultivariateNormal,
)
mcmc_kernels = dict(NUTS=NUTS)

Map each prior mean for the precision to the (concentration α, rate β) parameters of its Gamma prior. With α = 2 fixed, β = 2 / mean, so the prior mean α/β equals the key:

In [None]:
# Gamma(concentration=alpha, rate=beta) prior on the network precision.
# alpha is fixed at 2.0 and beta = 2.0 / mean, so the prior mean alpha / beta
# equals the key: 0.01 -> (2.0, 200.0), 0.1 -> (2.0, 20.0), 1.0 -> (2.0, 2.0).
precision_prior_map = {
    prior_mean: (2.0, 2.0 / prior_mean) for prior_mean in (0.01, 0.1, 1.0)
}

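Define the BNN model: one ReLU hidden layer, a Bernoulli likelihood on the output logits, and a shared Gamma-distributed precision that scales the Normal priors on all weights and biases: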

In [None]:
def bnn_model(X, y=None, hidden_dim=10, precision_prior=1.0):
    """One-hidden-layer ReLU Bayesian neural network with a Bernoulli likelihood.

    A single latent precision ``precision_nn`` (Gamma prior, parameters looked
    up in ``precision_prior_map``) controls the zero-mean Normal priors on all
    weights and biases. Each prior variance is ``1 / (precision_nn * fan_in)``,
    where fan_in is the number of input features for layer 1 and
    ``hidden_dim`` for layer 2.

    Args:
        X: 2-D design matrix of shape (n, m); rows are observations.
        y: binary labels of length n, or None to leave ``obs`` unobserved
            (e.g. for predictive sampling).
        hidden_dim: number of hidden units.
        precision_prior: key into ``precision_prior_map`` selecting the Gamma
            (concentration, rate) pair; a key not in the map raises KeyError.
    """
    num_obs, num_feat = X.shape
    conc, rate = precision_prior_map[precision_prior]
    precision_nn = numpyro.sample('precision_nn', dist.Gamma(conc, rate))

    # Fan-in-scaled prior standard deviations for each layer.
    in_scale = jnp.sqrt(1.0 / (precision_nn * num_feat))
    out_scale = jnp.sqrt(1.0 / (precision_nn * hidden_dim))

    # Layer 1: bias of shape (hidden_dim,); the nested plate gives the weights
    # shape (num_feat, hidden_dim), matching X @ w1 below.
    with numpyro.plate('l1_hidden', hidden_dim):
        b1 = numpyro.sample('nn_b1', dist.Normal(0.0, in_scale))
        with numpyro.plate('l1_feat', num_feat):
            w1 = numpyro.sample('nn_w1', dist.Normal(0.0, in_scale))

    # Layer 2 (output): one weight per hidden unit plus a scalar bias.
    with numpyro.plate('l2_hidden', hidden_dim):
        w2 = numpyro.sample('nn_w2', dist.Normal(0.0, out_scale))
    b2 = numpyro.sample('nn_b2', dist.Normal(0.0, out_scale))

    # Forward pass: ReLU hidden layer, then a linear readout to one logit per row.
    activations = jnp.maximum(X @ w1 + b1, 0)
    logits = activations @ w2 + b2

    with numpyro.plate('data', num_obs):
        numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=y)

Create output directories for results, ELBO values, and extra diagnostics:

In [40]:
# results/bnn/ receives the posterior samples; elbo/ receives the VI loss
# traces and extra/ the MCMC extra fields. exist_ok makes re-runs safe.
for _subdir in ('results/bnn/', 'results/bnn/elbo/', 'results/bnn/extra/'):
    os.makedirs(_subdir, exist_ok=True)

Initialise experiment tracking variables:

In [None]:
# Bookkeeping shared by the main experiment loop.
main_key = random.PRNGKey(123)  # root PRNG key; split once per run for reproducibility
mcmc_chains = 2                 # chains per MCMC run (same as the host device count set above)
experiment_logs = []            # one dict per run; written to CSV after the loop

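Main experiment loop: for every combination of inference method, kernel/guide, width and precision prior, fit the model, save the posterior samples and diagnostics, and log each run's status and duration: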

In [None]:
# Run the full (method x kernel/guide x width x prior) grid, pickling the
# posterior samples plus diagnostics for each run and logging its outcome.
from itertools import product

# NUTS only records 'diverging' per draw unless more fields are requested, so
# ask explicitly for the energy, leapfrog-step count and acceptance probability
# that the saved "extra diagnostics" are meant to contain.
MCMC_EXTRA_FIELDS = (
    'diverging',
    'potential_energy',
    'energy',
    'num_steps',
    'accept_prob',
)


def _dump_pickle(obj, path):
    """Pickle ``obj`` to the file at ``path``, overwriting any existing file."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def _log_experiment(run_id, method, component_col, component_name, width,
                    prior, start_time, status):
    """Append one run record to ``experiment_logs``.

    ``component_col`` is ``'kernel'`` for MCMC runs and ``'guide'`` for VI
    runs. The duration is measured from ``start_time`` (a
    ``time.perf_counter()`` reading) to now and rounded to 0.01 s.
    """
    experiment_logs.append({
        'id': run_id,
        'method': method,
        component_col: component_name,
        'width': width,
        'precision_prior': prior,
        'duration_seconds': round(time.perf_counter() - start_time, 2),
        'status': status,
    })


# try/finally guarantees the log CSV is written even if the grid is
# interrupted part-way (e.g. KeyboardInterrupt, which the per-run
# `except Exception` deliberately does not catch).
try:
    for method in inference_methods:
        if method == 'mcmc':
            # Same order as nested loops: kernel, then width, then prior.
            mcmc_grid = product(mcmc_kernels.items(), widths, precision_priors)
            for (kernel_name, kernel_cls), width, prior in mcmc_grid:
                run_id = f'{method}_{kernel_name}_w{width}_p{prior:.3g}'
                print(f'Running {run_id}')

                # Fresh key per run, split into one independent key per chain;
                # MCMC.run accepts a batch of num_chains keys directly.
                main_key, run_key = random.split(main_key)
                chain_keys = random.split(run_key, num=mcmc_chains)

                start_time = time.perf_counter()
                try:
                    kernel = kernel_cls(bnn_model, init_strategy=init_to_mean)
                    mcmc = MCMC(
                        kernel, num_warmup=100, num_samples=250,
                        num_chains=mcmc_chains,
                    )
                    mcmc.run(
                        chain_keys, X_train, y_train, hidden_dim=width,
                        precision_prior=prior, extra_fields=MCMC_EXTRA_FIELDS,
                    )
                    # Samples keep a leading chain axis: (chains, draws, ...).
                    results = mcmc.get_samples(group_by_chain=True)
                    _dump_pickle(results, f'results/bnn/{run_id}_samples.pkl')

                    extras = mcmc.get_extra_fields()
                    _dump_pickle(extras, f'results/bnn/extra/{run_id}_extra.pkl')
                except Exception as e:
                    # Record the failure and continue with the next configuration.
                    _log_experiment(
                        run_id, method, 'kernel', kernel_name, width, prior,
                        start_time, f'error: {e}',
                    )
                    print(f'Experiment {run_id} failed: {e}')
                else:
                    _log_experiment(
                        run_id, method, 'kernel', kernel_name, width, prior,
                        start_time, 'success',
                    )

        else:  # Variational inference branch
            vi_grid = product(vi_guides.items(), widths, precision_priors)
            for (guide_name, guide_cls), width, prior in vi_grid:
                run_id = f'{method}_{guide_name}_w{width}_p{prior:.3g}'
                print(f'Running {run_id}')

                # One key for SVI optimisation, one for posterior sampling.
                main_key, run_key, post_key = random.split(main_key, 3)

                start_time = time.perf_counter()
                try:
                    guide = guide_cls(bnn_model)
                    svi = SVI(
                        bnn_model, guide, numpyro.optim.Adam(1e-3),
                        Trace_ELBO(),
                    )
                    svi_result = svi.run(
                        run_key, 2000, X_train, y_train, hidden_dim=width,
                        precision_prior=prior,
                    )

                    # Per-step Trace_ELBO losses, i.e. negative ELBO estimates.
                    elbo_vals = svi_result.losses

                    # Draw 500 posterior samples from the fitted guide.
                    results = guide.sample_posterior(
                        post_key, svi_result.params, sample_shape=(500,)
                    )

                    _dump_pickle(results, f'results/bnn/{run_id}_samples.pkl')
                    _dump_pickle(elbo_vals, f'results/bnn/elbo/{run_id}_elbo.pkl')
                except Exception as e:
                    # Record the failure and continue with the next configuration.
                    _log_experiment(
                        run_id, method, 'guide', guide_name, width, prior,
                        start_time, f'error: {e}',
                    )
                    print(f'Experiment {run_id} failed: {e}')
                else:
                    _log_experiment(
                        run_id, method, 'guide', guide_name, width, prior,
                        start_time, 'success',
                    )
finally:
    # Save experiment log to CSV
    log_df = pd.DataFrame(experiment_logs)
    log_df.to_csv('results/experiment_log.csv', index=False)
    print('Experiment summary saved to results/experiment_log.csv')

Running mcmc_NUTS_w5_s0.5


  0%|          | 0/750 [00:00<?, ?it/s]

  0%|          | 0/750 [00:00<?, ?it/s]

Running mcmc_NUTS_w5_s1.0


  0%|          | 0/750 [00:00<?, ?it/s]

  0%|          | 0/750 [00:00<?, ?it/s]

Running mcmc_NUTS_w5_s2.0


  0%|          | 0/750 [00:00<?, ?it/s]

  0%|          | 0/750 [00:00<?, ?it/s]

Running mcmc_NUTS_w10_s0.5


  0%|          | 0/750 [00:01<?, ?it/s]

  0%|          | 0/750 [00:00<?, ?it/s]

Running mcmc_NUTS_w10_s1.0


  0%|          | 0/750 [00:00<?, ?it/s]

  0%|          | 0/750 [00:00<?, ?it/s]