# Welcome to pyHaiCS

In [13]:
import pyHaiCS as haics

In [14]:
print(f"Running pyHaiCS v.{haics.__version__}")

Running pyHaiCS v.0.0.1


## Example - Bayesian Logistic Regression + HMC for Breast Cancer Classification

In [15]:
import jax
import jax.numpy as jnp

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load the breast cancer dataset
data = load_breast_cancer()
X, y = data.data, data.target

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Standardize the data & convert to jax arrays
scaler = StandardScaler()
X_train = jnp.array(scaler.fit_transform(X_train))
X_test = jnp.array(scaler.transform(X_test))

# Add column of ones to the input data (for intercept terms)
X_train = jnp.hstack([X_train, jnp.ones((X_train.shape[0], 1))])
X_test = jnp.hstack([X_test, jnp.ones((X_test.shape[0], 1))])

In [16]:
# Bayesian Logistic Regression model (in JAX)
@jax.jit
def model_fn(params, x):
    return jax.nn.sigmoid(jnp.matmul(x, params))

@jax.jit
def log_prior_fn(params):
    return jnp.sum(jax.scipy.stats.norm.logpdf(params))

@jax.jit
def likelihood_fn(params, x, y):
    preds = model_fn(params, x)
    return jnp.prod(preds ** y * (1 - preds) ** (1 - y))

@jax.jit
def log_likelihood_fn(params, x, y):
    preds = model_fn(params, x)
    return jnp.sum(y * jnp.log(preds) + (1 - y) * jnp.log(1 - preds))

@jax.jit
def log_posterior_fn(params, x, y):
    return log_prior_fn(params) + log_likelihood_fn(params, x, y)

# Define a wrapper function to negate the log posterior
@jax.jit
def neg_log_posterior_fn(params, x, y):
    return -log_posterior_fn(params, x, y)

# Initialize the model parameters (includes intercept term)
key = jax.random.PRNGKey(42)
mean_vector = jnp.zeros(X_train.shape[1])
cov_mat = jnp.eye(X_train.shape[1])
params = jax.random.multivariate_normal(key, mean_vector, cov_mat)

# HMC for posterior sampling
# params_samples = haics.samplers.hamiltonian.HMC(params, n_samples=5000, burn_in=2000, 
#                             step_size=1e-3, n_steps=100, 
#                             potential=neg_log_posterior_fn,  
#                             mass_matrix=jnp.eye(X_train.shape[1]), 
#                             integrator=haics.integrators.VerletIntegrator())

In [17]:
from tqdm import tqdm

# Learn the model parameters using HMC
num_samples = 1000
num_burnin = 200
step_size = 1e-3
num_leapfrog_steps = 100

@jax.jit
def Hamiltonian(params, x, y):
    return -log_posterior_fn(params, x, y)


def hmc_step(params, x, y, step_size, num_leapfrog_steps, key, grad_hamiltonian):
    @jax.jit
    def leapfrog(params, momentum, step_size):
        grads = grad_hamiltonian(params, x, y)
        momentum = momentum - 0.5 * step_size * grads
        params = params + step_size * momentum
        grads = grad_hamiltonian(params, x, y)
        momentum = momentum - 0.5 * step_size * grads
        return params, momentum

    momentum = jax.random.normal(key, shape=params.shape)
    new_params, new_momentum = params, momentum

    for _ in range(num_leapfrog_steps):
        new_params, new_momentum = leapfrog(new_params, new_momentum, step_size)

    current_H = Hamiltonian(params, x, y) + 0.5 * jnp.sum(momentum ** 2)
    new_H = Hamiltonian(new_params, x, y) + 0.5 * jnp.sum(new_momentum ** 2)

    accept_prob = jnp.exp(current_H - new_H)
    accept = jax.random.uniform(key) < accept_prob

    return jax.lax.cond(accept, lambda _: new_params, lambda _: params, operand=None)

# Run HMC
samples = []
grad_hamiltonian = jax.grad(Hamiltonian)

for i in tqdm(range(num_samples + num_burnin)):
    key, subkey = jax.random.split(key)
    params = hmc_step(params, X_train, y_train, step_size, num_leapfrog_steps, subkey, grad_hamiltonian)
    if i >= num_burnin:
        samples.append(params)

samples = jnp.array(samples)

# Make predictions using the posterior samples
preds = jax.vmap(lambda params: model_fn(params, X_test))(samples)
mean_preds = jnp.mean(preds, axis=0)
mean_preds = mean_preds > 0.5

# Evaluate the model
accuracy = jnp.mean(mean_preds == y_test)
print(f"Accuracy: {accuracy}")

100%|██████████| 1200/1200 [03:55<00:00,  5.09it/s]


Accuracy: 0.1315789520740509


In [104]:
# Do the same using pymc (same model, and parameters) Use numpy instead of Jax
import pymc as pm
import numpy as np
from sklearn.metrics import accuracy_score

X_train = np.array(X_train)
y_train = np.array(y_train)
X_test = np.array(X_test)
y_test = np.array(y_test)

with pm.Model() as model:
    X_pymc = pm.MutableData('x', X_train)
    y_pymc = pm.MutableData('y', y_train)
    beta = pm.Normal("beta", mu=0, sigma=1, shape=X_pymc.shape[1])

    preds = pm.math.sigmoid(pm.math.dot(X_pymc, beta))
    y_obs = pm.Bernoulli("y_obs", p=preds, observed=y_pymc)

    trace = pm.sample(num_samples, tune=num_burnin, step=pm.HamiltonianMC(adapt_step_size = False), chains = 1)

with model:
    pm.set_data({'x': X_test, 'y': y_test})
    # Make predictions using the posterior samples
    trace.extend(pm.sample_posterior_predictive(trace))
    p_test_pred = trace.posterior_predictive["y_obs"].mean(dim=["chain", "draw"])
    y_test_pred = (p_test_pred >= 0.5).astype("int").to_numpy()

# Evaluate the model
print(f"Accuracy: {accuracy_score(y_test, y_test_pred)}")

Sequential sampling (1 chains in 1 job)
HamiltonianMC: [beta]


Output()

Sampling 1 chain for 200 tune and 1_000 draw iterations (200 + 1_000 draws total) took 4 seconds.
Only one chain was sampled, this makes it impossible to run some convergence checks
Sampling: [beta, y_obs]


Output()

Accuracy: 0.7456140350877193
