In [2]:
import os
os.chdir('../../')

In [3]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn.datasets as datasets
from models import cmlp
from utilities import fits
from flax import linen as nn
from sklearn.model_selection import train_test_split
from sklearn.metrics import brier_score_loss
from utilities.plot import plot_binary_class,plot_caliberation_classification
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions
from functools import partial
from jax.flatten_util import ravel_pytree
import blackjax
import probml_utils as pml
from probml_utils.blackjax_utils import inference_loop
from sklearn.calibration import calibration_curve,CalibrationDisplay
import logging
logger = logging.getLogger()
class CheckTypesFilter(logging.Filter):
    def filter(self, record):
        return "check_types" not in record.getMessage()
logger.addFilter(CheckTypesFilter())
from jax.random import PRNGKey as rng_key

2022-07-22 21:46:56.676303: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory


In [4]:
os.environ['LATEXIFY'] = ''
os.environ['FIG_DIR'] = '.'

In [5]:
pml.latexify(width_scale_factor=2.4,fig_height=2)

# Dataset Visualization

In [6]:
X,Y =  datasets.make_moons(1000, noise=0.20,random_state=6)

In [7]:
plt.scatter(X[:,0],X[:,1],c=Y,cmap="bwr")
sns.despine()

In [8]:
X_ood = tfd.MultivariateNormalDiag([-1,-2],[0.2,0.2]).sample(seed =rng_key(0),sample_shape=50 )

In [9]:
plt.scatter(X[:,0],X[:,1],c=Y,cmap="bwr")
plt.scatter(*X_ood.T,c='c')
sns.despine()

In [10]:
h=0.05
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 2
y_min, y_max = X[:, 1].min() - 2, X[:, 1].max() + 1
xx, yy = jnp.meshgrid(jnp.arange(x_min, x_max, h),
                     jnp.arange(y_min, y_max, h))

In [11]:
xx.shape

(112, 136)

In [12]:
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.5, random_state=45)

In [13]:
X_train = jnp.array(X_train)
X_test = jnp.array(X_test)
y_train = jnp.array(y_train)
y_test = jnp.array(y_test)


# Single MLP

In [28]:
model = cmlp.cmlp([8,16,16,8,1],[nn.relu,nn.relu,nn.relu,nn.relu],[0.03]*4)

In [29]:
auxs = {
    "X" : X_train,
    "y" : y_train,
    "deterministic": True,
}

## Train

In [30]:
seed=0
params = model.init(jax.random.PRNGKey(seed), X_train, deterministic = True)
params, loss =  fits.fit(model, params,X_train,y_train,True,learning_rate=0.01, epochs=1000,batch_size=64)

In [31]:
plt.plot(jnp.arange(0,loss.shape[0],1),loss)
sns.despine()

In [48]:
def plot_caliberation_classification(pred_train,pred_test,title,y_train,y_test,legend=0):
    fig,ax1 = plt.subplots(1,1)
    prob_true_train,prob_pred_train = calibration_curve(y_train,pred_train,n_bins=5)
    prob_true_test,prob_pred_test = calibration_curve(y_test,pred_test)

    disp_train = CalibrationDisplay(prob_true_train,prob_pred_train,pred_train,estimator_name="Train")
    disp_test = CalibrationDisplay(prob_true_test,prob_pred_test,pred_test,estimator_name="Test",)
    disp_train.plot(ax=ax1)
    disp_test.plot(ax=ax1)
    handles,labels = ax1.get_legend_handles_labels()
    labels[0] = 'Ideal'
    sns.despine(ax=ax1)

    ax1.legend(handles,labels,loc=legend)
    ax1.set_title(title)

In [33]:
y_pred_train = model.apply(params, X_train, deterministic=True)
y_pred_test = model.apply(params, X_test, deterministic=True)

In [49]:
plot_caliberation_classification = partial(plot_caliberation_classification,y_train=y_train,y_test=y_test)

In [35]:
mlp_pred_train = model.apply(params,X_train,deterministic=True).reshape(y_train.shape)
mlp_pred_test = model.apply(params,X_test,deterministic=True).reshape(y_test.shape)
plot_caliberation_classification(mlp_pred_train,mlp_pred_test,title="")
pml.savefig("figures/MLP_caliberation.pdf")

saving image to ./figures/MLP_caliberation_latexified.pdf
Figure size: [2.5 2. ]


  f"renaming {fname_full} to {fname}{extention} because LATEXIFY is {LATEXIFY}",


In [36]:
p =model.apply(params,jnp.c_[xx.ravel(), yy.ravel()],deterministic=True).reshape(xx.shape)
var = p*(1-p).reshape(xx.shape)

plot_binary_class(X,Y,X_ood,xx,yy,p,titles=("Probability with Single MLP","Variance with Single MLP"),legend=True,color_bar=True)
pml.savefig("figures/mlp_moons.pdf",tight_bbox=False)
#pml.savefig("MLP_moons.pdf")

  f"renaming {fname_full} to {fname}{extention} because LATEXIFY is {LATEXIFY}",


saving image to ./figures/mlp_moons_latexified.pdf
Figure size: [2.5 2. ]


# Deep Ensemble

In [37]:
params_lst= []
loss_lst =[]
for i in range(5):
    params =  model.init(jax.random.PRNGKey(i),X,deterministic=True)
    params,loss = fits.fit(model, params,X_train,y_train,True,learning_rate=0.01, epochs=100)
    params_lst.append(params)
    loss_lst.append(loss)

In [38]:
p_de_train_ensemble = []
p_de_test_ensemble = []

for i in range(5):
    params = params_lst[i]
    p =model.apply(params,X_train,deterministic=True)
    p_de_train_ensemble.append(p)
    p =model.apply(params,X_test,deterministic=True)
    p_de_test_ensemble.append(p)
ensemble_pred_train = jnp.array(p_de_train_ensemble).mean(axis=0).reshape(y_train.shape)
ensemble_pred_test= jnp.array(p_de_test_ensemble).mean(axis=0).reshape(y_test.shape)

In [39]:
plot_caliberation_classification(ensemble_pred_train,ensemble_pred_test,title="",legend=0)
pml.savefig("figures/Deep Ensemble caliberation.pdf")

  f"renaming {fname_full} to {fname}{extention} because LATEXIFY is {LATEXIFY}",


saving image to ./figures/Deep Ensemble caliberation_latexified.pdf
Figure size: [2.5 2. ]


In [40]:
p_grid_ensemble = []

for i in range(5):
    params = params_lst[i]
    p =  model.apply(params,jnp.c_[xx.ravel(), yy.ravel()],deterministic=True).reshape(xx.shape)
    p_grid_ensemble.append(p)
ensemble_p_mean = jnp.array(p_grid_ensemble).mean(axis=0)
# ensemble_p_sigma =  jnp.array(p_grid_ensemble).std(axis=0)
ensemble_p_sigma = jnp.sqrt(ensemble_p_mean*(1-ensemble_p_mean))


In [41]:
plot_binary_class(X,Y,X_ood,xx,yy,ensemble_p_mean,titles=("Probability with Deep ensemble","Variance with Deep Ensemble"),legend=False)
pml.savefig("figures/Deep Ensemble Moons.pdf")

  f"renaming {fname_full} to {fname}{extention} because LATEXIFY is {LATEXIFY}",


saving image to ./figures/Deep Ensemble Moons_latexified.pdf
Figure size: [2.5 2. ]


### MC Dropout

In [42]:
y_stacks = []

@jax.jit
def func(params, i):
    z = model.apply(params, jnp.c_[xx.ravel(), yy.ravel()], deterministic=False, 
                    rngs={"dropout": jax.random.PRNGKey(i)}).reshape(xx.shape)
    return z
y_stacks = jax.vmap(fun=func, in_axes=(None, 0))(params, jnp.arange(50))

In [43]:
mc_pred_mean= jnp.array(y_stacks).mean(axis=0).reshape(xx.shape)
mc_pred_sigma = (jnp.array(y_stacks).std(axis=0)).reshape(yy.shape)

In [44]:
plot_binary_class(X,Y,X_ood,xx,yy,mc_pred_mean,titles=("Probability with MC Droput","Variance with MC Dropout"))
pml.savefig("figures/MC Dropout Moons.pdf")

  f"renaming {fname_full} to {fname}{extention} because LATEXIFY is {LATEXIFY}",


saving image to ./figures/MC Dropout Moons_latexified.pdf
Figure size: [2.5 2. ]


In [45]:
y_stacks_train = []
y_stacks_test = []


def create_apply_func(X_test):
    def func(params, i):
        z = model.apply(params, X_test, deterministic=False, 
                        rngs={"dropout": jax.random.PRNGKey(i)})
        return z
    return func
y_stacks_train = jax.vmap(fun=create_apply_func(X_train), in_axes=(None, 0))(params, jnp.arange(50))
y_stacks_test = jax.vmap(fun=create_apply_func(X_test), in_axes=(None, 0))(params, jnp.arange(50))


In [46]:
mc_pred_train = y_stacks_train.mean(axis=0)
mc_pred_test = y_stacks_test.mean(axis=0)

In [47]:
plot_caliberation_classification(mc_pred_train,mc_pred_test,title='')
pml.savefig("figures/MC_dropout_caliberation.pdf")

  f"renaming {fname_full} to {fname}{extention} because LATEXIFY is {LATEXIFY}",


saving image to ./figures/MC_dropout_caliberation_latexified.pdf
Figure size: [2.5 2. ]


### Bootstrap

In [48]:
num_bootstraps = 10
params_list=[]
loss_list=[]
keys = jax.random.split(jax.random.PRNGKey(0),num_bootstraps)
for i in range(num_bootstraps):
    ids = jax.random.choice(keys[i], jnp.array(range(len(X_train))), (len(X_train),))
    x, y = X_train[ids], y_train[ids]
    params =  model.init(jax.random.PRNGKey(i),X_train,deterministic=True)
    params,loss = fits.fit(model, params, x, y, True, batch_size=len(x), learning_rate=0.01, epochs=100)
    params_list.append(params) 
    loss_list.append(loss)

In [49]:
bs_train = []
bs_test = []

for i in range(num_bootstraps):
    params = params_list[i]
    b =model.apply(params,X_train,deterministic=True)
    bs_train.append(b)
    b =model.apply(params,X_test,deterministic=True)
    bs_test.append(b)
bootstrap_pred_train = jnp.array(bs_train).mean(axis=0).reshape(y_train.shape)
bootstrap_pred_test= jnp.array(bs_test).mean(axis=0).reshape(y_test.shape)

In [50]:
plot_caliberation_classification(bootstrap_pred_train,bootstrap_pred_test,title="")
pml.savefig("figures/Bootstrap caliberation.pdf")

  f"renaming {fname_full} to {fname}{extention} because LATEXIFY is {LATEXIFY}",


saving image to ./figures/Bootstrap caliberation_latexified.pdf
Figure size: [2.5 2. ]


In [51]:
bootstrap = []

for i in range(num_bootstraps):
    params = params_list[i]
    b =  model.apply(params,jnp.c_[xx.ravel(), yy.ravel()],deterministic=True).reshape(xx.shape)
    bootstrap.append(b)
bootstrap_mean = jnp.array(bootstrap).mean(axis=0)
bootstrap_sigma = jnp.sqrt(bootstrap_mean*(1-bootstrap_mean))

In [52]:
plot_binary_class(X,Y,X_ood,xx,yy,bootstrap_mean,titles=("Probability with Bootstrap","Variance with Bootstrap"))
pml.savefig("figures/Bootstrap Moons.pdf")

  f"renaming {fname_full} to {fname}{extention} because LATEXIFY is {LATEXIFY}",


saving image to ./figures/Bootstrap Moons_latexified.pdf
Figure size: [2.5 2. ]


# NUTS

In [53]:
params = model.init(jax.random.PRNGKey(seed), X_train, deterministic = True)


In [54]:
from jax import tree_map,tree_leaves
def glorot_prior(param):
    if(param.ndim==1):
        return tfd.Normal(loc = jnp.zeros_like(param), scale = (1.414/jnp.sqrt(param.shape[0])) )
    else:
        return tfd.Normal(loc=jnp.zeros_like(param),scale=(1.414/jnp.sqrt(param.shape[0]+param.shape[1])))

def one_prior(param):
    if(param.ndim==1):
        return tfd.Normal(loc = jnp.zeros_like(param), scale =1)
    else:
        return tfd.Normal(loc=jnp.zeros_like(param),scale=1)


def set_prior(mean,scale,params):
    priors = tree_map(one_prior,params)
    return priors


def eval_log(prior,param):
    return prior.log_prob(param).sum()

def eval_log_prior(params,priors):
    is_leaf = lambda x: isinstance(x, tfd.Distribution)
    # params_unfreezed = unfreeze(params)
    log_priors = tree_map(eval_log,priors,params,is_leaf=is_leaf)
    return jnp.array(tree_leaves(log_priors)).sum()

def eval_likelihood(params,model,X,Y):
    return tfd.Bernoulli(logits=model.apply(params,X,deterministic=True)).log_prob(Y.reshape((-1,1))).sum()

def log_joint(params,priors,model,X,Y):
    log_prior = eval_log_prior(params,priors=priors)
    log_likelihood = eval_likelihood(params,model,X,Y)
    return log_prior+log_likelihood
    # return log_likelihood

priors = set_prior(0,1,params)
bnn_log_joint_partial = partial(log_joint,model=model,priors=priors,X=X_train,Y=y_train)
bnn_log_joint_partial(params)

DeviceArray(-922.1937, dtype=float32)

In [55]:
def bnn_log_joint(params:dict, X, y, model):
    """
    computes the numerator term of the posterior funtion.

    params: dictionary initalized by model 
    X: shape-(n_samples,2) training points 
    y: shape-(n_samples,) labels for training points
    model: bnn model
    """
    
    logits = model.apply(params, X,deterministic=True).ravel()
    # flatten_params gives a list with all the parameters. its dimension will be 1.
    flatten_params, _ = ravel_pytree(params)
    log_prior = tfd.Normal(0.0, 1.0).log_prob(flatten_params).sum()
    log_likelihood = tfd.Bernoulli(logits=logits).log_prob(y).sum()
    log_joint = log_prior + log_likelihood
    return log_joint +log_likelihood

params = model.init(jax.random.PRNGKey(0), X_train, deterministic = True)
bnn_log_joint_partial =partial(bnn_log_joint,X=X,y=Y,model=model)
bnn_log_joint_partial(params)

DeviceArray(-1999.243, dtype=float32)

In [56]:
from pandas import read_pickle
states,infos = read_pickle("mcmc_moons.pkl")['states'],read_pickle("mcmc_moons.pkl")['infos']

In [57]:
num_warmup = 3000
key = jax.random.PRNGKey(314)
key_samples, key_init, key_warmup, key = jax.random.split(key, 4)
adapt = blackjax.window_adaptation(blackjax.nuts, bnn_log_joint_partial, num_warmup,progress_bar=True)
final_state, kernel, _ = adapt.run(key_warmup, params)

Running window adaptation





In [58]:
num_samples = 3000
states,infos = inference_loop(key_samples, kernel, final_state, num_samples)

In [59]:
from pandas import to_pickle
to_pickle({"states":states,"infos":infos},"mcmc_moons_good.pkl")

In [60]:
seed = jax.random.PRNGKey(0)

def predict_mcmc(x_test):
    def one_step(carry, state):
        params_cur = state.position
        # samples = tfd.Bernoulli(probs=model.apply(params_cur,x_test,deterministic=True)).sample(seed=seed)
        probs=model.apply(params_cur,x_test,deterministic=True)
        return carry, probs
    _, y_stacks = jax.lax.scan(one_step, None, states)
    y_stacks = y_stacks>=0.5
    mean = y_stacks.mean(axis=0)
    std_dev = y_stacks.std(axis=0)
    return mean,std_dev

mcmc_pred_mean,mcmc_pred_sigma = predict_mcmc(jnp.c_[xx.ravel(), yy.ravel()])
mcmc_pred_mean,mcmc_pred_sigma = mcmc_pred_mean.reshape(xx.shape),mcmc_pred_sigma.reshape(xx.shape)

In [61]:
mcmc_pred_mean[0][0]

DeviceArray(0.445, dtype=float32)

In [62]:
mcmc_pred_train, _= predict_mcmc(X_train)
mcmc_pred_test,_ = predict_mcmc(X_test)
plot_caliberation_classification(mcmc_pred_train,mcmc_pred_test,title="")
pml.savefig("figures/MCMC_caliberation_glorot_.pdf")

saving image to ./figures/MCMC_caliberation_glorot__latexified.pdf
Figure size: [2.5 2. ]


  f"renaming {fname_full} to {fname}{extention} because LATEXIFY is {LATEXIFY}",


In [63]:
# def plot_binary_class(
#     X_scatters,
#     y_scatters,
#     XX1_grid,
#     XX2_grid,
#     grid_preds_mean,
#     grid_preds_sigma,
#     titles: tuple,
# ):
#     """
#   funtion to binary classificaton outputs
# @@ -160,22 +161,25 @@ def plot_binary_class(
#   titles: tuple with title of the two images. 
#   """

#     fig, ax = plt.subplots(1, 2, figsize=(20, 6))

#     ax[0].set_title(titles[0], fontsize=16)
#     ax[0].contourf(XX1_grid, XX2_grid, grid_preds_mean, cmap="coolwarm", alpha=0.8)
#     hs = ax[0].scatter(X_scatters.T[0], X_scatters.T[1], c=y_scatters, cmap="bwr")
#     ax[0].legend(*hs.legend_elements(), fontsize=20)

#     ax[1].set_title(titles[1], fontsize=16)
#     CS = ax[1].contourf(XX1_grid, XX2_grid, grid_preds_sigma, cmap="viridis", alpha=0.8)
#     hs = ax[1].scatter(X_scatters.T[0], X_scatters.T[1], c=y_scatters, cmap="bwr")
#     # ax[1].legend(*hs.legend_elements(), fontsize=20)
#     fig.colorbar(CS, ax=ax[1])
#     sns.despine()

In [66]:
titles_mcmc = ("Mean of Predictions of NUTS","Variance of Predictions of NUTS")
plot_binary_class(X_scatters=X,y_scatters=Y,XX1_grid=xx,
XX2_grid = yy,grid_preds_mean=mcmc_pred_mean,X_outside = X_ood,titles=titles_mcmc)
pml.savefig("figures/BNN using MCMC.pdf")
plt.show()

  f"renaming {fname_full} to {fname}{extention} because LATEXIFY is {LATEXIFY}",


saving image to ./figures/BNN using MCMC_latexified.pdf
Figure size: [2.5 2. ]


  """


# GP

In [67]:
import GPy
import numpy as np
# h=0.02
# x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
# y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
# xx, yy = jnp.meshgrid(jnp.arange(x_min, x_max, h),
#                      jnp.arange(y_min, y_max, h))

model = GPy.models.GPClassification(np.array(X_train), np.array(y_train.reshape(-1,1)))
print(model)
print(model.log_likelihood())
model.optimize(messages=True)
model.optimize_restarts(num_restarts=5)
print(model)
model.plot_f()
probability_gp,na = model.predict(jnp.c_[xx.ravel(), yy.ravel()])


Name : gp_classification
Objective : 109.74521507089162
Number of Parameters : 2
Number of Optimization Parameters : 2
Updates : True
Parameters:
  [1mgp_classification.[0;0m  |  value  |  constraints  |  priors
  [1mrbf.variance      [0;0m  |    1.0  |      +ve      |        
  [1mrbf.lengthscale   [0;0m  |    1.0  |      +ve      |        
-109.74521507089162


HBox(children=(VBox(children=(IntProgress(value=0, max=1000), HTML(value=''))), Box(children=(HTML(value=''),)â€¦

Optimization restart 1/5, f = 57.93479031079556
Optimization restart 2/5, f = 87.89248015636133
Optimization restart 3/5, f = 65.79696928778316
Optimization restart 4/5, f = 76.6686355350339
Optimization restart 5/5, f = 66.22382809115038

Name : gp_classification
Objective : 68.24582311586948
Number of Parameters : 2
Number of Optimization Parameters : 2
Updates : True
Parameters:
  [1mgp_classification.[0;0m  |               value  |  constraints  |  priors
  [1mrbf.variance      [0;0m  |  13.753845426549372  |      +ve      |        
  [1mrbf.lengthscale   [0;0m  |  0.7116805207154344  |      +ve      |        




In [68]:
probability_gp = probability_gp.reshape(xx.shape)
# sigma_gp = jnp.sqrt(probability_gp*(1-probability_gp)).reshape(xx.shape)


In [69]:
plot_binary_class(X,Y,X_ood,xx,yy,probability_gp,titles=("Probability with GP","Variance with GP"))
pml.savefig("figures/GP moons.pdf")



saving image to ./figures/GP moons_latexified.pdf
Figure size: [2.5 2. ]


In [70]:
gp_pred_train,na = model.predict(X_train)
gp_pred_test, na = model.predict(X_test)
plot_caliberation_classification(gp_pred_train,gp_pred_test,title="")
pml.savefig("figures/GP_caliberation.pdf")

saving image to ./figures/GP_caliberation_latexified.pdf
Figure size: [2.5 2. ]




# SNGP

In [None]:
%pip install -qq tf-models-official

In [14]:
import pkg_resources
import importlib
importlib.reload(pkg_resources)

<module 'pkg_resources' from '/home/anand/anaconda3/envs/srip/lib/python3.7/site-packages/pkg_resources/__init__.py'>

In [15]:
import matplotlib.pyplot as plt
import matplotlib.colors as colors

from sklearn.model_selection import train_test_split
import numpy as np
import tensorflow as tf
try:
    import official.nlp.modeling.layers as nlp_layers
except ModuleNotFoundError:
    %pip install -qq tf-models-official
    import official.nlp.modeling.layers as nlp_layers

    

In [16]:
class DeepResNet(tf.keras.Model):
  """Defines a multi-layer residual network."""
  def __init__(self, num_classes, num_layers=3, num_hidden=128,
               dropout_rate=0.1, **classifier_kwargs):
    super().__init__()
    # Defines class meta data.
    self.num_hidden = num_hidden
    self.num_layers = num_layers
    self.dropout_rate = dropout_rate
    self.classifier_kwargs = classifier_kwargs

    # Defines the hidden layers.
    self.input_layer = tf.keras.layers.Dense(self.num_hidden, trainable=False)
    self.dense_layers = [self.make_dense_layer() for _ in range(num_layers)]

    # Defines the output layer.
    self.classifier = self.make_output_layer(num_classes)

  def call(self, inputs):
    # Projects the 2d input data to high dimension.
    hidden = self.input_layer(inputs)

    # Computes the resnet hidden representations.
    for i in range(self.num_layers):
      resid = self.dense_layers[i](hidden)
      resid = tf.keras.layers.Dropout(self.dropout_rate)(resid)
      hidden += resid

    return self.classifier(hidden)

  def make_dense_layer(self):
    """Uses the Dense layer as the hidden layer."""
    return tf.keras.layers.Dense(self.num_hidden, activation="relu")

  def make_output_layer(self, num_classes):
    """Uses the Dense layer as the output layer."""
    return tf.keras.layers.Dense(
        num_classes, **self.classifier_kwargs)

In [17]:
class DeepResNetSNGP(DeepResNet):
  def __init__(self, spec_norm_bound=0.9, **kwargs):
    self.spec_norm_bound = spec_norm_bound
    super().__init__(**kwargs)

  def make_dense_layer(self):
    """Applies spectral normalization to the hidden layer."""
    dense_layer = super().make_dense_layer()
    return nlp_layers.SpectralNormalization(
        dense_layer, norm_multiplier=self.spec_norm_bound)

  def make_output_layer(self, num_classes):
    """Uses Gaussian process as the output layer."""
    return nlp_layers.RandomFeatureGaussianProcess(
        num_classes, 
        gp_cov_momentum=-1,
        **self.classifier_kwargs)

  def call(self, inputs, training=False, return_covmat=False):
    # Gets logits and covariance matrix from GP layer.
    logits, covmat = super().call(inputs)

    # Returns only logits during training.
    if not training and return_covmat:
      return logits, covmat

    return logits

In [18]:
class ResetCovarianceCallback(tf.keras.callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    """Resets covariance matrix at the beginning of the epoch."""
    if epoch > 0:
      self.model.classifier.reset_covariance_matrix()

In [19]:
class DeepResNetSNGPWithCovReset(DeepResNetSNGP):
  def fit(self, *args, **kwargs):
    """Adds ResetCovarianceCallback to model callbacks."""
    kwargs["callbacks"] = list(kwargs.get("callbacks", []))
    kwargs["callbacks"].append(ResetCovarianceCallback())

    return super().fit(*args, **kwargs)

In [20]:
def compute_posterior_mean_probability(logits, covmat, lambda_param=np.pi / 8.):
  # Computes uncertainty-adjusted logits using the built-in method.
  logits_adjusted = nlp_layers.gaussian_process.mean_field_logits(
      logits, covmat, mean_field_factor=lambda_param)

  return tf.nn.softmax(logits_adjusted, axis=-1)[:, 0]

In [21]:
DEFAULT_X_RANGE = (-3.5, 3.5)
DEFAULT_Y_RANGE = (-2.5, 2.5)
DEFAULT_N_GRID = 100


In [22]:
# def make_training_data(sample_size=500):
#   """Create two moon training dataset."""
#   train_examples, train_labels = sklearn.datasets.make_moons(
#       n_samples=2 * sample_size, noise=0.1)

#   # Adjust data position slightly.
#   train_examples[train_labels == 0] += [-0.1, 0.2]
#   train_examples[train_labels == 1] += [0.1, -0.2]

#   return train_examples, train_labels

# def make_testing_data(x_range=DEFAULT_X_RANGE, y_range=DEFAULT_Y_RANGE, n_grid=DEFAULT_N_GRID):
#   """Create a mesh grid in 2D space."""
#   # testing data (mesh grid over data space)
#   x = np.linspace(x_range[0], x_range[1], n_grid)
#   y = np.linspace(y_range[0], y_range[1], n_grid)
#   xv, yv = np.meshgrid(x, y)
#   return np.stack([xv.flatten(), yv.flatten()], axis=-1),xv,yv

# X_train,y_train = make_training_data(
#     sample_size=500)

# X_test,xx,yy = make_testing_data()

In [23]:
def train_and_test_sngp(X_train, X_test):
  sngp_model = DeepResNetSNGPWithCovReset(**resnet_config)

  sngp_model.compile(**train_config)
  sngp_model.fit(X_train, np.array(y_train), verbose=0, **fit_config)

  sngp_logits, sngp_covmat = sngp_model(X_test, return_covmat=True)
  sngp_probs = compute_posterior_mean_probability(sngp_logits, sngp_covmat)

  return sngp_model, sngp_probs

In [24]:
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.keras.metrics.SparseCategoricalAccuracy(),
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
resnet_config = dict(num_classes=2, num_layers=6, num_hidden=128)
train_config = dict(loss=loss, metrics=metrics, optimizer=optimizer)
fit_config = dict(batch_size=128, epochs=1000)

2022-07-22 21:47:21.497754: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2022-07-22 21:47:21.497822: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)


In [25]:
X_train.shape, X_test.shape

((500, 2), (500, 2))

In [26]:
import numpy as np
sngp_model, sngp_probs = train_and_test_sngp(np.array(X_train), np.array(X_test))

In [27]:
sngp_std = np.sqrt(sngp_probs*(1-sngp_probs))

In [28]:
sngp_logits, sngp_covmat = sngp_model(np.c_[xx.ravel(), yy.ravel()], return_covmat=True)
sngp_grid_probs = compute_posterior_mean_probability(sngp_logits, sngp_covmat)

In [64]:
sngp_grid_probs = 1-np.array(sngp_grid_probs).reshape(xx.shape)

In [65]:
# plot_binary_class(X,Y,xx,yy,sngp_probs,sngp_std,titles=("Probability with SNGP","Variance with SNGP"))
plot_binary_class(X,Y,X_ood,xx,yy,sngp_grid_probs,titles=("Probability with GP","Variance with GP"))

pml.savefig("figures/SNGP Moons.pdf")


  f"renaming {fname_full} to {fname}{extention} because LATEXIFY is {LATEXIFY}",


saving image to ./figures/SNGP Moons_latexified.pdf
Figure size: [2.5 2. ]


In [39]:
sngp_pred_train = train_and_test_sngp(np.array(X_train), np.array(X_train))

In [59]:
sngp_pred_train_actual = 1-sngp_pred_train[1]
sngp_probs = 1-sngp_probs

In [63]:
plot_caliberation_classification(jnp.array(sngp_pred_train_actual),jnp.array(sngp_probs),title="")
pml.savefig("figures/SNGP_caliberation.pdf")

saving image to ./figures/SNGP_caliberation_latexified.pdf
Figure size: [2.5 2. ]


  f"renaming {fname_full} to {fname}{extention} because LATEXIFY is {LATEXIFY}",


# VI

In [81]:
from flax.core.frozen_dict import freeze, unfreeze
from jax.flatten_util import ravel_pytree
from ajax.advi import ADVI
from ajax.utils import train
tfb = tfp.bijectors
import optax


In [82]:
class MLP(nn.Module):
    layers: list
    
    @nn.compact
    def __call__(self, x):
        for num_features in self.layers[:-1]:
            x = nn.relu(nn.Dense(num_features)(x))
        x = nn.Dense(self.layers[-1])(x)
        return x.ravel()

mlp = MLP([8,16,16,8,1])

In [83]:
seed = jax.random.PRNGKey(45)
frozen_params = mlp.init(seed, X_train)
params = unfreeze(frozen_params)

In [84]:
prior = jax.tree_map(
    lambda param: tfd.Independent(
        tfd.Normal(loc=jnp.zeros(param.shape), scale=jnp.ones(param.shape)),
        reinterpreted_batch_ndims=len(param.shape),
    ),
    params,
)

bijector = jax.tree_map(lambda param: tfb.Identity(), params)

def get_log_likelihood(latent_sample,data,aux, **kwargs):
    frozen_params = freeze(latent_sample)
    logit = mlp.apply(frozen_params, aux['X'])
    def sigmoid(x, scale=100):
        return 0.5 * (jnp.tanh(x * scale / 2) + 1)
    #prob = sigmoid(logit)
    return tfd.Bernoulli(logits=logit).log_prob(data).sum()

model = ADVI(prior, bijector, get_log_likelihood, vi_type = "mean_field")


In [85]:
params = model.init(jax.random.PRNGKey(8))
mean = params["posterior"].mean()
params["posterior"] = tfd.MultivariateNormalDiag(
    loc=mean,
    scale_diag=jax.random.normal(jax.random.PRNGKey(3), shape=(len(mean),))-10,
)

In [87]:
%%time
tx = optax.adam(learning_rate=0.01)
seed1 = jax.random.PRNGKey(4)
seed2 = jax.random.PRNGKey(100)

loss_fn = partial(model.loss_fn, aux={"X": X_train}, batch=y_train, data_size=len(y_train), n_samples = 100)
results = train(
    loss_fn,
    params,
    n_epochs=3000,
    optimizer=tx,
    seed=seed2,
    return_args={"losses"}
)

TypeError: Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[tuple(seq)]` instead of `arr[seq]`. See https://github.com/google/jax/issues/4564 for more information.