In [1]:
import zhusuan as zs
import tensorflow as tf
import numpy as np
import sys
from matplotlib import pyplot as plt
from zhusuan.variational import svgd
import seaborn as sns
%matplotlib inline

sess = tf.InteractiveSession()

In [2]:
n_covariates = 13

n_dat, n_train = 600, 300
w_true = np.random.normal(size=[n_covariates, 1])
b_true = np.random.normal() * 0.5
X = np.random.uniform(-2, 2, size=[n_dat, n_covariates])
Y = np.squeeze((X + 0.1 * X ** 3) @ w_true) +\
    np.diag(X @ np.random.normal(size=[n_covariates, n_dat]) * 0.1) + b_true
Y += np.random.normal(size=[n_dat]) * 0.1
X, Xt = X[:n_train], X[n_train:]
Y, Yt = Y[:n_train], Y[n_train:]
# Y_prob = 1 / (1 + np.exp(-Y_logits))
# Y = np.zeros((X.shape[0], ), dtype=np.int32)
# Y[np.random.uniform(0, 1, size=[X.shape[0]]) < Y_prob] = 1

@zs.reuse('linear_reg')
def linear_regression(inp, observed):
    with zs.BayesianNet(observed) as model:
        w = zs.Normal(
            'w', 
            mean=tf.zeros([n_covariates, 1]),
            std=np.float32(1),
            group_ndims=2)
        b = zs.Normal(
            'b', 
            mean=np.float32(0),
            std=np.float32(1))
        n_particles = tf.shape(w.tensor)[0]  # TODO: this assumes rank(w)==3. Does this hold when n_particles=1?
        inp = tf.tile(tf.expand_dims(inp, 0), [n_particles, 1, 1])
        # print(tf.squeeze(inp @ w, axis=-1).shape, b.tensor.shape)
        mean = tf.squeeze(inp @ w, axis=-1) + tf.expand_dims(b, 1) # [n_particles, n_batch]
        # print(logits.shape)
        var_out = zs.Normal('var_out', mean=np.float32(0.1), std=np.float32(1))
        std = tf.expand_dims(tf.nn.softplus(var_out), 1)
        out = zs.Normal('out', mean, std, group_ndims=1) 
    return model

x_ph = tf.placeholder(tf.float32, [None, n_covariates])
y_ph = tf.placeholder(tf.float32, [None])

def log_joint(observed):
    model = linear_regression(x_ph, observed)
    ret = tf.add_n(model.local_log_prob(['w', 'b', 'var_out', 'out']))
    return ret

hmc = zs.HMC(step_size=1e-2, n_leapfrogs=20, adapt_step_size=True,
             target_acceptance_rate=0.6)
n_chain = 10
w_hmc = tf.Variable(tf.zeros([n_chain, n_covariates, 1]), name='w_hmc')
b_hmc = tf.Variable(tf.zeros([n_chain]), name='b_hmc')
var_out_hmc = tf.Variable(tf.zeros([n_chain]), name='var_out')
sample_op, hmc_info = hmc.sample(
    log_joint, observed={'out': y_ph}, 
    latent={'w': w_hmc, 'b': b_hmc, 'var_out': var_out_hmc})

model = linear_regression(x_ph, {
    'out': y_ph, 'w': w_hmc, 'b': b_hmc, 'var_out': var_out_hmc})
y_pred = tf.reduce_mean(model.get('out').distribution.mean, axis=0)
test_rmse = tf.sqrt(tf.reduce_mean((y_pred - y_ph)**2))



In [3]:
hmc_T = 300
sess.run(tf.global_variables_initializer())
traces = np.zeros([hmc_T, n_chain, n_covariates+1, 1])
feed_dict = {x_ph: X, y_ph: Y}
for i in range(hmc_T):
    _, ws, bs, accr = sess.run(
        [sample_op, hmc_info.samples['w'], hmc_info.samples['b'], hmc_info.acceptance_rate],
        feed_dict)
    traces[i] = np.concatenate([ws, bs.reshape((-1, 1, 1))], axis=1)
    if i % 50 == 0:
        print()
        if i<100:
            avg_w = traces[:i+1].mean(axis=0)
        else:
            avg_w = traces[100:i+1].mean(axis=0)
        avg_w, avg_b = avg_w[:, :n_covariates], avg_w[:, -1].reshape((-1))
        rmse = sess.run(test_rmse, {x_ph: Xt, y_ph: Yt, w_hmc: avg_w, b_hmc: avg_b})
    print('\r Iter {} Acc.Rate {} RMSE {}'.format(
        i, np.mean(accr), rmse), end='')
print()


 Iter 49 Acc.Rate 0.5636581182479858 RMSE 1.51702880859375
 Iter 99 Acc.Rate 0.7184661030769348 RMSE 0.6041986942291266
 Iter 149 Acc.Rate 0.6586695313453674 RMSE 0.5087842941284188
 Iter 199 Acc.Rate 0.5569800138473511 RMSE 0.5083652734756477
 Iter 249 Acc.Rate 0.47831469774246216 RMSE 0.508267343044281
 Iter 299 Acc.Rate 0.8362857103347778 RMSE 0.50872433185577399


In [4]:
def test_svgd(n_svgd_particles):
    w_particles = tf.get_variable(
        'w_svgd', [n_svgd_particles, n_covariates, 1], tf.float32, 
        tf.random_uniform_initializer(-1, 1))
    b_particles = tf.get_variable(
        'b_svgd', [n_svgd_particles], tf.float32, 
        tf.zeros_initializer())
    var_out_particles = tf.get_variable(
        'var_out_svgd', [n_svgd_particles], tf.float32,
        tf.zeros_initializer())
    grad_and_vars = svgd.stein_variational_gradient(
        log_joint, {'out': y_ph}, {
            'w': w_particles,
            'b': b_particles,
            'var_out': var_out_particles
        })
    optimizer = tf.train.AdamOptimizer(0.05)
    opt_op = optimizer.apply_gradients([(-g, v) for g, v in grad_and_vars])
    
    sess.run(tf.global_variables_initializer())
    SVGD_T = 1000
    for i in range(SVGD_T):
        _ = sess.run([opt_op], feed_dict)
        if i % 100 == 0:
            fd = {x_ph: Xt, y_ph: Yt}
            wp, bp, vp = sess.run([w_particles, b_particles, var_out_particles], fd)
            if n_svgd_particles != n_chain:
                wp0 = np.zeros((n_chain, n_covariates, 1)).astype('f')
                bp0 = np.zeros((n_chain,)).astype('f')
                vp0 = np.zeros((n_chain,)).astype('f')
                for i in range(0, n_chain, n_svgd_particles):
                    wp0[i:i+n_svgd_particles] = wp
                    bp0[i:i+n_svgd_particles] = bp
                    vp0[i:i+n_svgd_particles] = vp
                wp, bp, vp = wp0, bp0, vp0
            fd.update({w_hmc: wp, b_hmc: bp, var_out_hmc: vp})
            print('Iter {} RMSE {}'.format(i, sess.run([test_rmse], fd)))

In [5]:
with tf.variable_scope('v10'):
    test_svgd(10)



Iter 0 RMSE [3.4812257]
Iter 100 RMSE [0.5099669]
Iter 200 RMSE [0.5096439]
Iter 300 RMSE [0.50967413]
Iter 400 RMSE [0.50972116]
Iter 500 RMSE [0.50975984]
Iter 600 RMSE [0.509758]
Iter 700 RMSE [0.5098064]
Iter 800 RMSE [0.5096244]
Iter 900 RMSE [0.5096958]


In [6]:
with tf.variable_scope('v2'):
    test_svgd(2)



Iter 8 RMSE [4.0294843]
Iter 8 RMSE [0.5107512]
Iter 8 RMSE [0.5097867]
Iter 8 RMSE [0.50982356]
Iter 8 RMSE [0.5097693]
Iter 8 RMSE [0.50978273]
Iter 8 RMSE [0.5097728]
Iter 8 RMSE [0.50977445]
Iter 8 RMSE [0.5097622]
Iter 8 RMSE [0.5097647]
