Skip to content
Permalink
master
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
import numpy as np
import jax.numpy as jnp
from jax import jit, grad, value_and_grad, vmap, random
from jax.scipy.special import logsumexp
import warnings
# Silence the CPU-fallback notice that jax emits when no GPU/TPU is present.
warnings.filterwarnings(action="ignore", message='No GPU/TPU found, falling back to CPU.')
def genCovMat(key, d):
    """Return the covariance matrix used for the synthetic inputs.

    Currently this is always the ``d x d`` identity matrix; ``key`` is
    accepted but not used.
    """
    return jnp.identity(d)
def logistic(theta, x):
    """Return the logistic sigmoid of the linear score ``theta . x``."""
    score = jnp.dot(theta, x)
    return 1 / (1 + jnp.exp(-score))
# Vectorise `logistic` over the rows of x (theta is shared, not mapped),
# stacking results along axis 0, then JIT-compile the batched function.
batch_logistic = jit(vmap(logistic, in_axes=(None, 0), out_axes=0))
# Vectorised Bernoulli sampler: takes one PRNG key and one probability per
# entry (both mapped along axis 0) and returns one boolean draw per entry.
batch_bernoulli = vmap(random.bernoulli, in_axes=(0, 0))
# Backward-compatible alias for the original misspelled name, which existing
# callers (e.g. gen_data) still reference.
batch_benoulli = batch_bernoulli
def gen_data(key, dim, N):
    """
    Generate a synthetic logistic-regression dataset with dimension `dim`
    and `N` data points.

    Draws a true parameter vector, inputs whose covariance is
    ``genCovMat(...)``, and 0/1 labels sampled with probability
    ``logistic(theta_true, x)`` for each input row.

    Parameters
    ----------
    key: jax PRNG key
        random key; it is split into the sub-keys used for every draw
    dim: int
        dimension of data
    N: int
        Size of dataset

    Returns
    -------
    theta_true: ndarray
        Theta array used to generate data, shape=(dim,), drawn from N(0, 10 I)
    X: ndarray
        Input data, shape=(N,dim)
    y_data: ndarray
        Output data: 0 or 1s (int32). shape=(N,)
    """
    key, subkey1, subkey2, subkey3 = random.split(key, 4)
    print(f"generating data, with N={N} and dim={dim}")
    # Standard normal scaled by sqrt(10): variance 10 per component.
    theta_true = random.normal(subkey1, shape=(dim, ))*jnp.sqrt(10)
    covX = genCovMat(subkey2, dim)
    # cholesky returns lower-triangular L with covX = L @ L.T. For row
    # samples z ~ N(0, I), x = z @ L.T has covariance L @ L.T = covX;
    # multiplying by L itself would give L.T @ L, which differs from covX
    # whenever covX is not the identity.
    chol = jnp.linalg.cholesky(covX)
    X = jnp.dot(random.normal(subkey3, shape=(N, dim)), chol.T)
    p_array = batch_logistic(theta_true, X)
    # One independent key per data point for the vectorised Bernoulli draws.
    keys = random.split(key, N)
    y_data = batch_benoulli(keys, p_array).astype(jnp.int32)
    return theta_true, X, y_data
def build_grad_log_post(X, y_data, N):
    """Build a jitted gradient of the full-data log posterior.

    The posterior combines a Gaussian prior ``-(0.5/10) * theta.theta`` with
    the logistic-regression log-likelihood of every row of ``X`` against the
    0/1 labels in ``y_data``. The mean log-likelihood is scaled by ``N``.

    Returns
    -------
    callable
        ``grad_log_post(theta)``, the gradient of the log posterior w.r.t. theta.
    """
    def _point_loglik(theta, x_val, y_val):
        # log sigmoid(+score) when y=1, log sigmoid(-score) when y=0,
        # written as -log(1 + exp(-signed_score)) via a stable logsumexp.
        signed_score = (1. - 2. * y_val) * jnp.dot(theta, x_val)
        return -logsumexp(jnp.array([0., signed_score]))

    def _log_prior(theta):
        return -(0.5 / 10) * jnp.dot(theta, theta)

    point_loglik = jit(_point_loglik)
    prior = jit(_log_prior)
    loglik_per_row = jit(vmap(point_loglik, in_axes=(None, 0, 0)))

    def log_post(theta):
        mean_loglik = jnp.mean(loglik_per_row(theta, X, y_data), axis=0)
        return prior(theta) + N * mean_loglik

    return jit(grad(log_post))
def build_value_and_grad_log_post(X, y_data, N):
    """
    Builds val_and_grad_log_post: a jitted function returning both the
    full-data log posterior and its gradient.

    The log posterior is the prior ``-(0.5/10) * theta.theta`` plus ``N``
    times the mean logistic-regression log-likelihood over the rows of
    ``X`` and the 0/1 labels ``y_data``.

    Returns
    -------
    callable
        ``val_and_grad_log_post(theta) -> (log_post_value, grad_wrt_theta)``
    """
    @jit
    def loglikelihood(theta, x_val, y_val):
        # -log(1 + exp(-(2y-1) * theta.x)), i.e. log sigmoid of the signed
        # score, computed stably with logsumexp.
        return -logsumexp(jnp.array([0., (1.-2.*y_val)*jnp.dot(theta, x_val)]))

    @jit
    def log_prior(theta):
        return -(0.5/10)*jnp.dot(theta, theta)

    batch_loglik = jit(vmap(loglikelihood, in_axes=(None, 0, 0)))

    def log_post(theta):
        return log_prior(theta) + N*jnp.mean(batch_loglik(theta, X, y_data), axis=0)

    val_and_grad_log_post = jit(value_and_grad(log_post))
    return val_and_grad_log_post
def build_batch_grad_log_post(X, y_data, N):
    """Build a jitted minibatch gradient of the log posterior.

    The returned function has signature ``grad_log_post(theta, X, y_data)``
    and differentiates w.r.t. ``theta`` only. The minibatch mean of the
    per-point logistic log-likelihood is rescaled by ``N``, and the Gaussian
    prior ``-(0.5/10) * theta.theta`` is added.

    Note: the ``X`` and ``y_data`` given to this builder are not used. The
    returned function's own ``X``/``y_data`` arguments shadow them, so only
    ``N`` is captured.
    """
    def _point_loglik(theta, x_val, y_val):
        signed_score = (1. - 2. * y_val) * jnp.dot(theta, x_val)
        return -logsumexp(jnp.array([0., signed_score]))

    def _log_prior(theta):
        return -(0.5 / 10) * jnp.dot(theta, theta)

    point_loglik = jit(_point_loglik)
    prior = jit(_log_prior)
    loglik_per_row = jit(vmap(point_loglik, in_axes=(None, 0, 0)))

    def log_post(theta, X, y_data):
        minibatch_mean = jnp.mean(loglik_per_row(theta, X, y_data), axis=0)
        return prior(theta) + N * minibatch_mean

    return jit(grad(log_post))