# jax_MCMC_blog_post/logistic_regression_model.py
import numpy as np
import jax.numpy as jnp
from jax import jit, grad, value_and_grad, vmap, random
from jax.scipy.special import logsumexp
import warnings

# Silence the "No GPU/TPU found" warning emitted by JAX on CPU-only machines.
warnings.filterwarnings("ignore", message='No GPU/TPU found, falling back to CPU.')
def genCovMat(key, d):
    # Identity covariance; `key` is unused here but kept in the signature so a
    # random covariance could be swapped in without changing callers.
    return jnp.eye(d)

def logistic(theta, x):
    return 1. / (1. + jnp.exp(-jnp.dot(theta, x)))

# Batched versions: map over data points while holding the parameters fixed.
batch_logistic = jit(vmap(logistic, in_axes=(None, 0)))
batch_bernoulli = vmap(random.bernoulli, in_axes=(0, 0))
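
# Illustration (not in the original file): a minimal sanity check of the
# batched logistic. The seed and shapes below are assumptions for the demo.
def _check_batch_logistic():
    key, subkey = random.split(random.PRNGKey(0))
    theta = random.normal(subkey, shape=(3,))
    X = random.normal(key, shape=(5, 3))
    p = batch_logistic(theta, X)  # one probability per row of X
    assert p.shape == (5,) and jnp.all((p > 0) & (p < 1))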
def gen_data(key, dim, N):
    """
    Generate a synthetic logistic regression dataset with `N` data points of
    dimension `dim`.

    Parameters
    ----------
    key: jax.random.PRNGKey
        random key
    dim: int
        dimension of the data
    N: int
        size of the dataset

    Returns
    -------
    theta_true: ndarray
        parameter vector used to generate the data
    X: ndarray
        input data, shape=(N, dim)
    y_data: ndarray
        binary outputs (0s and 1s), shape=(N,)
    """
    key, subkey1, subkey2, subkey3 = random.split(key, 4)
    print(f"generating data, with N={N} and dim={dim}")
    theta_true = random.normal(subkey1, shape=(dim,)) * jnp.sqrt(10)
    covX = genCovMat(subkey2, dim)
    X = jnp.dot(random.normal(subkey3, shape=(N, dim)), jnp.linalg.cholesky(covX))
    p_array = batch_logistic(theta_true, X)
    keys = random.split(key, N)
    y_data = batch_bernoulli(keys, p_array).astype(jnp.int32)
    return theta_true, X, y_data
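
# Illustration (not in the original file): drawing a small dataset. The seed
# and sizes are arbitrary choices for the demo.
def _demo_gen_data():
    key = random.PRNGKey(0)
    theta_true, X, y_data = gen_data(key, dim=5, N=100)
    # X has shape (100, 5); y_data holds 0/1 labels of shape (100,).
    return theta_true, X, y_data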
def build_grad_log_post(X, y_data, N):
    """
    Builds grad_log_post: the gradient of the log-posterior over the full
    dataset, as a function of theta.
    """
    @jit
    def loglikelihood(theta, x_val, y_val):
        # Numerically stable log-likelihood of a single data point:
        # log sigma((2y-1) * theta.x) = -logsumexp([0, (1-2y) * theta.x])
        return -logsumexp(jnp.array([0., (1. - 2. * y_val) * jnp.dot(theta, x_val)]))

    @jit
    def log_prior(theta):
        # Gaussian prior theta ~ N(0, 10*I), up to an additive constant.
        return -(0.5 / 10) * jnp.dot(theta, theta)

    batch_loglik = jit(vmap(loglikelihood, in_axes=(None, 0, 0)))

    def log_post(theta):
        return log_prior(theta) + N * jnp.mean(batch_loglik(theta, X, y_data), axis=0)

    grad_log_post = jit(grad(log_post))
    return grad_log_post
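
# Illustration (not in the original file): evaluating the full-data gradient
# at theta = 0. The seed and sizes are assumptions for the demo.
def _demo_grad_log_post():
    theta_true, X, y_data = gen_data(random.PRNGKey(1), dim=5, N=100)
    grad_log_post = build_grad_log_post(X, y_data, N=100)
    g = grad_log_post(jnp.zeros(5))  # gradient of the log-posterior, shape (5,)
    return g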
def build_value_and_grad_log_post(X, y_data, N):
    """
    Builds val_and_grad_log_post: returns both the log-posterior and its
    gradient in a single pass.
    """
    @jit
    def loglikelihood(theta, x_val, y_val):
        return -logsumexp(jnp.array([0., (1. - 2. * y_val) * jnp.dot(theta, x_val)]))

    @jit
    def log_prior(theta):
        return -(0.5 / 10) * jnp.dot(theta, theta)

    batch_loglik = jit(vmap(loglikelihood, in_axes=(None, 0, 0)))

    def log_post(theta):
        return log_prior(theta) + N * jnp.mean(batch_loglik(theta, X, y_data), axis=0)

    val_and_grad_log_post = jit(value_and_grad(log_post))
    return val_and_grad_log_post
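
# Illustration (not in the original file): value_and_grad returns the
# log-posterior value together with its gradient.
def _demo_value_and_grad():
    theta_true, X, y_data = gen_data(random.PRNGKey(2), dim=5, N=100)
    val_and_grad = build_value_and_grad_log_post(X, y_data, N=100)
    val, g = val_and_grad(jnp.zeros(5))  # scalar log-posterior and its gradient
    return val, g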
def build_batch_grad_log_post(X, y_data, N):
    """
    Builds grad_log_post that takes minibatches of X and y_data. The mean
    minibatch log-likelihood is rescaled by N, the full dataset size, so it
    is an unbiased estimate of the full-data log-likelihood.
    """
    @jit
    def loglikelihood(theta, x_val, y_val):
        return -logsumexp(jnp.array([0., (1. - 2. * y_val) * jnp.dot(theta, x_val)]))

    @jit
    def log_prior(theta):
        return -(0.5 / 10) * jnp.dot(theta, theta)

    batch_loglik = jit(vmap(loglikelihood, in_axes=(None, 0, 0)))

    def log_post(theta, X, y_data):
        return log_prior(theta) + N * jnp.mean(batch_loglik(theta, X, y_data), axis=0)

    # grad differentiates with respect to the first argument (theta) by default.
    grad_log_post = jit(grad(log_post))
    return grad_log_post
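
# End-to-end sketch (not in the original file): build the stochastic gradient
# and evaluate it on a random minibatch. The seed, sizes, and batch size are
# illustrative assumptions.
if __name__ == "__main__":
    key = random.PRNGKey(42)
    dim, N, batch_size = 10, 1000, 32
    theta_true, X, y_data = gen_data(key, dim, N)

    batch_grad_log_post = build_batch_grad_log_post(X, y_data, N)
    idx = random.choice(random.PRNGKey(0), N, shape=(batch_size,), replace=False)
    g = batch_grad_log_post(jnp.zeros(dim), X[idx], y_data[idx])
    print(f"minibatch gradient shape: {g.shape}")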