In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

%matplotlib inline
import edward as ed
import matplotlib.pyplot as plt
import numpy as np
import pprint
import scipy
import tensorflow as tf

from sklearn.model_selection import train_test_split
from edward.models import Bernoulli, Normal, MultivariateNormalTriL

In [2]:
class FLAGS:
    N=1000   # Number of data points
    D=5     # Number of features


In [3]:
def build_toy_dataset(N, D, noise_std=1):
    X = np.random.uniform(-6, 6, size=(N, D))
    w = np.random.uniform(-1, 1, size=D)
    b = np.random.uniform(-4, 4)
    epsilon = np.random.normal(0, noise_std, size=N)
    y = (np.dot(X, w) + b + epsilon > 0).astype(int)
    # note this is actually generated from a probit model
    return X, y

In [4]:
ed.set_seed(42)

# DATA
X_all, y_all = build_toy_dataset(2*FLAGS.N, FLAGS.D)
X_train, X_next, y_train, y_next = train_test_split(X_all, y_all, train_size=FLAGS.N)

# MODEL
X = tf.placeholder(tf.float32, [FLAGS.N, FLAGS.D])
w = Normal(loc=tf.zeros(FLAGS.D), scale=tf.ones(FLAGS.D))
b = Normal(loc=tf.zeros([1]), scale=tf.ones([1]))
y = Bernoulli(logits=ed.dot(X, w) + b)

# INFERENCE
qb = Normal(
    loc=tf.Variable(tf.zeros([1])), 
    scale=tf.Variable(tf.ones([1])))  # should probably initialize to random values

w_init = np.random.randn(FLAGS.D)
print(w_init)

qw = MultivariateNormalTriL(
    loc=tf.Variable(tf.cast(w_init, tf.float32)),
    scale_tril=tf.Variable(tf.random_normal([FLAGS.D, FLAGS.D])))

# inference = ed.KLqp({w: qw, b: qb}, data={X: X_train, y: y_train})
inference = ed.Laplace({w: qw, b: qb}, data={X: X_train, y: y_train})
inference.initialize(n_print=10, n_iter=600)

inference.run()




[-0.45701382  0.22541014  1.6967911  -1.13522536 -1.30567135]


  not np.issubdtype(value.dtype, np.float) and \
  not np.issubdtype(value.dtype, np.int) and \


1000/1000 [100%] ██████████████████████████████ Elapsed: 1s | Loss: 161.571


In [5]:
if FLAGS.D == 1:
    n_posterior_samples = 10

    w_post = qw.sample(n_posterior_samples).eval()
    b_post = qb.sample(n_posterior_samples).eval()

    plt.rcParams["figure.figsize"] = (8,6)
    plt.scatter(X_train, y_train)

    inputs = np.linspace(-6, 6, num=400)
    for ns in range(n_posterior_samples):
        output = scipy.special.expit(np.dot(inputs[:,np.newaxis], w_post[ns]) + b_post[ns])
        plt.plot(inputs, output)

    plt.show()

In [6]:
# these give same result
qw.scale.to_dense().eval()
tf.cholesky(qw.covariance()).eval()

array([[ 0.04134055,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.00395684,  0.04146119,  0.        ,  0.        ,  0.        ],
       [-0.00021316,  0.00259639,  0.03781373,  0.        ,  0.        ],
       [-0.00136431, -0.00312252,  0.00105848,  0.04131312,  0.        ],
       [ 0.00027188,  0.00444304, -0.00186874, -0.00429154,  0.04052261]],
      dtype=float32)

In [7]:
# this is inverse of observed Fisher information, used in Laplace approximation
qw.covariance().eval()

array([[ 1.7090413e-03,  1.6357777e-04, -8.8121587e-06, -5.6401255e-05,
         1.1239480e-05],
       [ 1.6357777e-04,  1.7346867e-03,  1.0680584e-04, -1.3486191e-04,
         1.8528968e-04],
       [-8.8121587e-06,  1.0680584e-04,  1.4366646e-03,  3.2208438e-05,
        -5.9186139e-05],
       [-5.6401255e-05, -1.3486191e-04,  3.2208438e-05,  1.7195056e-03,
        -1.9351946e-04],
       [ 1.1239480e-05,  1.8528968e-04, -5.9186139e-05, -1.9351946e-04,
         1.6838056e-03]], dtype=float32)

In [8]:
qw.mean().eval()

array([-0.415916  , -0.49535665, -0.9982332 ,  0.38619578, -0.05535197],
      dtype=float32)

In [9]:
qb.scale.eval()

array([0.01915632], dtype=float32)

In [10]:
qb.loc.eval()

array([4.6540523], dtype=float32)

In [11]:
# verify Segall formula (with c_i=0) is the same
w_map = qw.mean().eval()
b_map = qb.loc.eval()
p = y.mean().eval(feed_dict={X: X_train, w: w_map, b: b_map})

In [12]:
W = np.diag(p * (1-p))

In [13]:
# note this is the inverse covariance for weighted least squares! The weights are just Bernoulli variances
hess_segall = np.matmul(np.matmul(X_train.T, W), X_train) / FLAGS.N

In [14]:
print(hess_segall)

[[ 0.58412335 -0.05584107  0.00713261  0.01491958  0.00410587]
 [-0.05584107  0.58605255 -0.04062332  0.04356729 -0.06125131]
 [ 0.00713261 -0.04062332  0.70118987 -0.02082551  0.03458291]
 [ 0.01491958  0.04356729 -0.02082551  0.58921989  0.05666981]
 [ 0.00410587 -0.06125131  0.03458291  0.05666981  0.62226782]]


In [15]:
obs_fisher = np.linalg.inv(qw.covariance().eval()) / FLAGS.N

In [16]:
print(obs_fisher)

[[ 0.5909487  -0.05544468  0.00757476  0.0153646   0.00418875]
 [-0.05544468  0.59431267 -0.04795557  0.03868396 -0.06226909]
 [ 0.00757476 -0.04795557  0.7011369  -0.01345808  0.02832496]
 [ 0.0153646   0.03868396 -0.01345808  0.59247214  0.06326022]
 [ 0.00418875 -0.06226909  0.02832496  0.06326022  0.60898316]]


In [17]:
# The expression Segall derived for the Fisher information is close enough 
# to the observed Fisher information computed using TF autodiff, available from Laplace approximation.
# What causes the discrepancy?
# Is the Edward Laplace approx really computed at the mode? We evaluate the Segall expression at the MAP estimate...

# Discrepancy increases if the number of samples is one order of magnitude more/less (not sure why)

In [19]:
# y.log_prob requires argument of length y
# construct a new variable representing candidate from y_next
# compute likelihood with autodiff

X = tf.placeholder(tf.float32, [1, FLAGS.D])
w = Normal(loc=tf.zeros(FLAGS.D), scale=tf.ones(FLAGS.D))
b = Normal(loc=tf.zeros([1]), scale=tf.ones([1]))
y = Bernoulli(logits=ed.dot(X, w) + b)


In [20]:
y_next = tf.get_variable("y_next", [1])

In [39]:
y_next_hess = tf.hessians(y.log_prob(y_next.value()), w)[0]

In [80]:
# y_next_hess.eval(feed_dict={X: X_next[[0]], w: w_map, b: b_map})
# hessian doesn't involve y_next, so don't need to feed it in

In [86]:
def new_point_info(X_new):
    return -y_next_hess.eval(feed_dict={X: X_new, w: w_map, b: b_map})

obs_fisher_2 = np.zeros((FLAGS.D, FLAGS.D))
for i in range(len(X_train)):
    obs_fisher_2 += new_point_info(X_train[[i]])

In [87]:
obs_fisher_2 / FLAGS.N

array([[ 0.58412336, -0.05584107,  0.00713259,  0.01491957,  0.00410587],
       [-0.05584107,  0.58605253, -0.0406233 ,  0.04356728, -0.06125131],
       [ 0.0071326 , -0.0406233 ,  0.70118985, -0.02082552,  0.03458291],
       [ 0.01491957,  0.04356728, -0.02082552,  0.58921987,  0.05666981],
       [ 0.00410587, -0.06125132,  0.03458291,  0.05666981,  0.62226782]])

In [78]:
hess_segall

array([[ 0.58412335, -0.05584107,  0.00713261,  0.01491958,  0.00410587],
       [-0.05584107,  0.58605255, -0.04062332,  0.04356729, -0.06125131],
       [ 0.00713261, -0.04062332,  0.70118987, -0.02082551,  0.03458291],
       [ 0.01491958,  0.04356729, -0.02082551,  0.58921989,  0.05666981],
       [ 0.00410587, -0.06125131,  0.03458291,  0.05666981,  0.62226782]])

In [79]:
# Yay, we match Segall exactly if we compute the Hessians with autodiff ourselves, 
# instead of using the observed Fisher information from the Laplace approximation!
# Now let's implement item selection.

In [83]:
from collections import namedtuple

OptCriteria = namedtuple('OptCriteria', ['trace', 'logdet', 'max_eigval'])

# Lambda is the precision matrix
# return measures of the variance, Lambda^{-1}
def compute_opt_criteria(Lambda):
    eigvals_Lambda = np.linalg.eigvals(Lambda)
    # When there are not enough previous questions for full rank,
    # some eigenvalues will be 0. Ensure numerical stability here.
    # Originally chose threshold 1e-8, but sometimes an eigenvalue
    # that should've been 0 exceeded this threshold.
    # TODO: handle in a safer way.
    eigvals_Lambda[eigvals_Lambda < 1e-6] = 1e-6
    eigvals_Var = 1.0 / eigvals_Lambda
    return OptCriteria(
        trace = np.sum(eigvals_Var),
        logdet = np.sum(np.log(eigvals_Var)),
        max_eigval = np.max(eigvals_Var)
    )

opt_criteria_comparators = dict(
    A = lambda o1, o2: o1.trace - o2.trace,
    D = lambda o1, o2: o1.logdet - o2.logdet,
    E = lambda o1, o2: o1.max_eigval - o2.max_eigval
)

def cmp_criteria_for(optimality_type):
    return opt_criteria_comparators[optimality_type]


In [92]:
Lambda_prev = obs_fisher_2
best_new, best_opt_criteria = None, None
cmp_criteria = cmp_criteria_for('A')

for i in range(FLAGS.N):
    X_cand = X_next[[i]]
    Lambda_cand = Lambda_prev + new_point_info(X_cand)
    opt_criteria = compute_opt_criteria(Lambda_cand)
    
    if best_opt_criteria is None or cmp_criteria(opt_criteria, best_opt_criteria) < 0:
        best_new = i
        best_opt_criteria = opt_criteria

In [95]:
best_new, X_next[[best_new]]

(259,
 array([[ 4.09315405,  5.96219141, -0.83588192, -2.15641509,  4.87983956]]))