# Bayesian logistic regression:

In [29]:
import jax.numpy as jnp
from jaxtyping import Array, Float
from dataclasses import dataclass

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

from coinem.model import AbstractModel
from coinem.dataset import Dataset

# Model

We consider two variants of the Bayesian logistic regression model of Liu and Wang (2016): one in which the parameters of the Gamma prior on the precision $\alpha$ are fixed, and one in which these parameters are learned.

In [48]:
import tensorflow_probability.substrates.jax.distributions as tfd

@dataclass
class LogisticRegressionLearnTheta(AbstractModel):
    """Bayesian logistic regression with a Gamma prior on α parameterised by θ.

    The latent vector stacks log(α) in its first entry and the regression
    weights β in the remaining entries. The joint log-density is the sum of a
    Bernoulli likelihood, a zero-mean Normal prior on β with standard
    deviation 1/sqrt(α), and a Gamma(exp(θ[0]), rate=exp(θ[1])) prior on α.
    """

    def log_prob(self, latent: Float[Array, "D 1"], theta: Float[Array, "Q"], data: Dataset) -> Float[Array, ""]:
        """Evaluate the joint log-density of the data and the latent variables.

        Args:
            latent (Float[Array, "D"]): Latent variables; ``latent[0]`` is
                log(α) and ``latent[1:]`` are the regression weights β.
            theta (Float[Array, "Q"]): Log-hyperparameters; ``exp(theta[0])``
                is the Gamma concentration and ``exp(theta[1])`` the Gamma rate.
            data (Dataset): Observations, read through ``data.X`` (inputs) and
                ``data.y`` (labels).

        Returns:
            Float[Array, ""]: Scalar log-likelihood + log-prior on β + log-prior on α.
        """
        alpha = jnp.exp(latent[0])
        beta = latent[1:]

        # Likelihood: Bernoulli with logits given by the linear predictor X β.
        z = jnp.matmul(data.X, beta)
        log_lik = tfd.Bernoulli(logits=z.squeeze()).log_prob(data.y.squeeze()).sum()

        # Prior on the weights: independent Normal(0, 1/sqrt(α)), where the
        # second argument is the standard deviation.
        log_prior = tfd.Normal(loc=0.0, scale=1.0 / jnp.sqrt(alpha)).log_prob(beta).sum().squeeze()

        # Prior on α with learnable (log-parameterised) concentration and rate.
        # NOTE(review): α enters through latent[0] = log(α), but no
        # log-Jacobian term (+ latent[0]) is added here, so this is the Gamma
        # density evaluated at α rather than the density of log(α). Confirm
        # that this is intended.
        log_prior_alpha = (
            tfd.Gamma(concentration=jnp.exp(theta[0]).squeeze(), rate=jnp.exp(theta[1]))
            .log_prob(alpha)
            .sum()
            .squeeze()
        )

        # Joint log-density.
        return (log_lik + log_prior + log_prior_alpha).squeeze()
    


@dataclass
class LogisticRegressionFixedTheta(AbstractModel):
    """Bayesian logistic regression with a fixed Gamma(1, rate=0.01) prior on α.

    The latent vector stacks log(α) in its first entry and the regression
    weights β in the remaining entries. The joint log-density is the sum of a
    Bernoulli likelihood, a zero-mean Normal prior on β with standard
    deviation 1/sqrt(α), and a Gamma prior on α whose concentration (1.0) and
    rate (0.01) are hard-coded.
    """

    def log_prob(self, latent: Float[Array, "D 1"], theta: Float[Array, "Q"], data: Dataset) -> Float[Array, ""]:
        """Evaluate the joint log-density of the data and the latent variables.

        Args:
            latent (Float[Array, "D"]): Latent variables; ``latent[0]`` is
                log(α) and ``latent[1:]`` are the regression weights β.
            theta (Float[Array, "Q"]): Accepted for interface compatibility
                but not read by this model, whose hyperparameters are fixed.
            data (Dataset): Observations, read through ``data.X`` (inputs) and
                ``data.y`` (labels).

        Returns:
            Float[Array, ""]: Scalar log-likelihood + log-prior on β + log-prior on α.
        """
        alpha = jnp.exp(latent[0])
        beta = latent[1:]

        # Likelihood: Bernoulli with logits given by the linear predictor X β.
        z = jnp.matmul(data.X, beta)
        log_lik = tfd.Bernoulli(logits=z.squeeze()).log_prob(data.y.squeeze()).sum()

        # Prior on the weights: independent Normal(0, 1/sqrt(α)), where the
        # second argument is the standard deviation.
        log_prior = tfd.Normal(loc=0.0, scale=1.0 / jnp.sqrt(alpha)).log_prob(beta).sum().squeeze()

        # Prior on α with fixed concentration and rate.
        # NOTE(review): α enters through latent[0] = log(α), but no
        # log-Jacobian term (+ latent[0]) is added here, so this is the Gamma
        # density evaluated at α rather than the density of log(α). Confirm
        # that this is intended.
        log_prior_alpha = tfd.Gamma(concentration=1.0, rate=0.01).log_prob(alpha).sum().squeeze()

        # Joint log-density.
        return (log_lik + log_prior + log_prior_alpha).squeeze()

# Train the models on covertype

In [49]:
# Import the submodule explicitly: a bare `import scipy` does not guarantee
# that `scipy.io` has been loaded on every SciPy version.
import scipy.io
import numpy as np
from sklearn.model_selection import train_test_split

# Load the covertype data from the MATLAB file.
data = scipy.io.loadmat('data/covertype.mat')

# Column 0 of the 'covtype' matrix holds the labels; the rest are features.
X_input = data['covtype'][:, 1:]
y_input = data['covtype'][:, 0]
# Recode label 2 as 0.
# NOTE(review): presumably the raw labels are {1, 2}, so this yields binary
# {0, 1} targets — confirm against the data file.
y_input[y_input == 2] = 0

# Append a constant column of ones to the features (intercept term).
N = X_input.shape[0]
X_input = np.hstack([X_input, np.ones([N, 1])])

# split the dataset into training (80%) and testing (20%)
X_train, X_test, y_train, y_test = train_test_split(X_input, y_input, test_size=0.2, random_state=42)

# Training set wrapped in the project's Dataset container.
D = Dataset(jnp.array(X_train), jnp.array(y_train))

In [65]:
# Hyper-parameters a0, b0 and the initial theta, which stores their logarithms.
a0, b0 = 1, 0.01
th0 = jnp.array([jnp.log(a0), jnp.log(b0)])

# Fix the NumPy seed so the initial particles are reproducible.
np.random.seed(42)

# Initial particle cloud: each row is [log(alpha), beta_1, ..., beta_d].
N = 100  # number of particles
d = D.X.shape[-1]
X0 = np.zeros((N, 1 + d))
# NOTE(review): np.random.gamma takes (shape, scale), so b0 acts as a scale
# here, whereas the model's Gamma prior uses 0.01 as a rate — confirm which
# is intended for the initialisation.
alpha0 = np.random.gamma(shape=a0, scale=b0, size=N)
for i in range(N):
    # First entry: log of the sampled alpha for this particle.
    X0[i, 0] = np.log(alpha0[i])
    # Remaining entries: weights drawn from Normal(0, sqrt(1 / alpha)).
    X0[i, 1:] = np.random.normal(loc=0, scale=np.sqrt(1 / alpha0[i]), size=d)

X0 = jnp.array(X0)

(100, 56)

In [66]:
from coinem.zoo import coin_svgd
from jax import vmap
import jax.random as jr

# One model per treatment of the Gamma-prior parameters: fixed vs. learned.
fixed_theta, learn_theta = LogisticRegressionFixedTheta(), LogisticRegressionLearnTheta()

# Number of steps, shared by both runs.
K = 1000

# Run coin_svgd on each model, starting from the same initial particles X0
# and the same initial theta th0.
x_coin_fixed, theta_coin_fixed = coin_svgd(fixed_theta, D, X0, th0, K, batch_size=100)
x_coin_learn, theta_coin_learn = coin_svgd(learn_theta, D, X0, th0, K, batch_size=100)


In [67]:
# Compute AUC:
from sklearn.metrics import roc_auc_score
def predict_prob(test_inputs, latent):
    """Average the per-particle Bernoulli success probabilities of the test inputs.

    For every row ``x`` of ``latent`` the logits are ``test_inputs @ x[1:]``
    (the first entry of each row is skipped). The resulting Bernoulli means
    are averaged over the leading axis of ``latent``, giving one probability
    per test input rather than a hard label.
    """
    def particle_prob(x):
        # Bernoulli mean for a single particle's weights.
        return tfd.Bernoulli(logits=jnp.matmul(test_inputs, x[1:])).mean()

    return vmap(particle_prob)(latent).mean(0)

# Test-set AUC for each run, scored with the last entry of its returned x.
coin_fixed_auc = roc_auc_score(y_true=y_test, y_score=predict_prob(X_test, x_coin_fixed[-1]))
coin_learn_auc = roc_auc_score(y_true=y_test, y_score=predict_prob(X_test, x_coin_learn[-1]))

print(f"Coin Fixed AUC: {coin_fixed_auc}")
print(f"Coin Learn AUC: {coin_learn_auc}")

Coin Fixed AUC: 0.8040118644813954
Coin Learn AUC: 0.8049609884562804
