Licensed under the Apache License, Version 2.0.

In [0]:
import os

import numpy as np
import tensorflow as tf

from edward2.experimental.attentive_uncertainty import attention
from edward2.experimental.attentive_uncertainty.contextual_bandits.pretrain import train

In [0]:
# Shared settings used by the experiment cells below.
savedir = '/tmp/wheel_bandit/models/multitask'

# Data settings: context dimensionality, number of arms, and the sizes of the
# context and target sets.
num_target = 50
num_context = 512
data_hparams = tf.contrib.training.HParams(
    context_dim=2,
    num_actions=5,
    num_target=num_target,
    num_context=num_context)

# Layer widths.
X_HIDDEN_SIZE = 100
HIDDEN_SIZE = 64
latent_units = 32
x_encoder_sizes = [X_HIDDEN_SIZE] * 2
x_y_encoder_sizes = [HIDDEN_SIZE] * 3
# The final layer of the global latent net has 2 * latent_units outputs.
global_latent_net_sizes = [HIDDEN_SIZE] * 2 + [2 * latent_units]
local_latent_net_sizes = [HIDDEN_SIZE] * 3 + [2]
heteroskedastic_net_sizes = None

# Attention settings.
mean_att_type = attention.laplace_attention
scale_att_type_1 = attention.laplace_attention
scale_att_type_2 = attention.laplace_attention
att_type = 'multihead'
att_heads = 8
data_uncertainty = False

# Prior predictive + freeform (`local_variational=False`)

In [0]:
# Prior predictive (local_variational=False) with uncertainty_type
# 'attentive_freeform'; every other setting comes from the shared config cell.
uncertainty_type = 'attentive_freeform'
local_variational = False
model_hparams = tf.contrib.training.HParams(
    activation=tf.nn.relu,
    output_activation=tf.nn.relu,
    x_encoder_sizes=x_encoder_sizes,
    x_y_encoder_sizes=x_y_encoder_sizes,
    global_latent_net_sizes=global_latent_net_sizes,
    local_latent_net_sizes=local_latent_net_sizes,
    heteroskedastic_net_sizes=heteroskedastic_net_sizes,
    uncertainty_type=uncertainty_type,
    att_type=att_type,
    att_heads=att_heads,
    mean_att_type=mean_att_type,
    scale_att_type_1=scale_att_type_1,
    scale_att_type_2=scale_att_type_2,
    data_uncertainty=data_uncertainty,
    local_variational=local_variational)

# Checkpoint file for the best model of this run.
save_path = os.path.join(savedir, 'best_prior_freeform_mse_unclipped.ckpt')
training_hparams = tf.contrib.training.HParams(
    lr=0.01,
    optimizer=tf.train.RMSPropOptimizer,
    num_iterations=10000,
    batch_size=10,
    num_context=num_context,
    num_target=num_target,
    print_every=50,
    save_path=save_path,
    max_grad_norm=1000.0)
                                               


In [0]:
# Train with the hyperparameters defined in the previous cell.
train(data_hparams, model_hparams, training_hparams)

it: 0, train nll: 175.937515259, mse: 326.51953125, local kl: 0.0 global kl: 7.74599175202e-05 valid nll: 201.171005249, mse: 393.021575928, local kl: 0.0 global kl: 0.000304171262542
Saving best model with MSE 393.02158
it: 50, train nll: 22.2970085144, mse: 111.704582214, local kl: 0.0 global kl: 2.50619686994e-05 valid nll: 21.1540718079, mse: 92.6402206421, local kl: 0.0 global kl: 2.43014455918e-05
Saving best model with MSE 92.64022
it: 100, train nll: 26.0193710327, mse: 79.1269226074, local kl: 0.0 global kl: 0.133565917611 valid nll: 27.6986160278, mse: 102.077888489, local kl: 0.0 global kl: 2.97966962535e-05
it: 150, train nll: 25.1414089203, mse: 65.0398178101, local kl: 0.0 global kl: 0.0110144298524 valid nll: 30.1620178223, mse: 98.4341812134, local kl: 0.0 global kl: 5.3404179198e-05
it: 200, train nll: 28.0258808136, mse: 81.0215988159, local kl: 0.0 global kl: 0.00028004439082 valid nll: 26.6753025055, mse: 81.8777999878, local kl: 0.0 global kl: 0.00149854901247
Savi

# Posterior predictive + freeform (`local_variational=True`)

In [0]:
# Posterior predictive (local_variational=True) with uncertainty_type
# 'attentive_freeform'; every other setting comes from the shared config cell.
uncertainty_type = 'attentive_freeform'
local_variational = True
model_hparams = tf.contrib.training.HParams(
    activation=tf.nn.relu,
    output_activation=tf.nn.relu,
    x_encoder_sizes=x_encoder_sizes,
    x_y_encoder_sizes=x_y_encoder_sizes,
    global_latent_net_sizes=global_latent_net_sizes,
    local_latent_net_sizes=local_latent_net_sizes,
    heteroskedastic_net_sizes=heteroskedastic_net_sizes,
    uncertainty_type=uncertainty_type,
    att_type=att_type,
    att_heads=att_heads,
    mean_att_type=mean_att_type,
    scale_att_type_1=scale_att_type_1,
    scale_att_type_2=scale_att_type_2,
    data_uncertainty=data_uncertainty,
    local_variational=local_variational)

# Checkpoint file for the best model of this run.
save_path = os.path.join(savedir, 'best_posterior_freeform_mse_unclipped.ckpt')
training_hparams = tf.contrib.training.HParams(
    lr=0.01,
    optimizer=tf.train.RMSPropOptimizer,
    num_iterations=10000,
    batch_size=10,
    num_context=num_context,
    num_target=num_target,
    print_every=50,
    save_path=save_path,
    max_grad_norm=1000.0)


In [0]:
# Train with the hyperparameters defined in the previous cell.
train(data_hparams, model_hparams, training_hparams)

it: 0, train nll: 291.345306396, mse: 327.003387451, local kl: 0.0251190047711 global kl: 8.42864028527e-05 valid nll: 446.904907227, mse: 410.268127441, local kl: 0.0530770793557 global kl: 0.00013029173715
Saving best model with MSE 410.26813
it: 50, train nll: 11.5967645645, mse: 131.161407471, local kl: 0.0747972652316 global kl: 1.34051947498e-06 valid nll: 8.17937660217, mse: 128.501358032, local kl: 0.0416157543659 global kl: 1.41697219078e-05
Saving best model with MSE 128.50136
it: 100, train nll: 39.7724533081, mse: 252.00189209, local kl: 0.0145239755511 global kl: 34902859776.0 valid nll: 20.704328537, mse: 202.928283691, local kl: 0.0488733649254 global kl: 9.995731034e-07
it: 150, train nll: 3.64748597145, mse: 105.372077942, local kl: 0.0107508469373 global kl: 0.0003227260313 valid nll: 3.00736522675, mse: 94.4865570068, local kl: 0.0274358782917 global kl: 3.69362642232e-05
Saving best model with MSE 94.48656
it: 200, train nll: 2.67741012573, mse: 80.642074585, local 

# Prior predictive + GP (`local_variational=False`)

In [0]:
# Prior predictive (local_variational=False) with uncertainty_type
# 'attentive_gp'; every other setting comes from the shared config cell.
uncertainty_type = 'attentive_gp'
local_variational = False
model_hparams = tf.contrib.training.HParams(
    activation=tf.nn.relu,
    output_activation=tf.nn.relu,
    x_encoder_sizes=x_encoder_sizes,
    x_y_encoder_sizes=x_y_encoder_sizes,
    global_latent_net_sizes=global_latent_net_sizes,
    local_latent_net_sizes=local_latent_net_sizes,
    heteroskedastic_net_sizes=heteroskedastic_net_sizes,
    uncertainty_type=uncertainty_type,
    att_type=att_type,
    att_heads=att_heads,
    mean_att_type=mean_att_type,
    scale_att_type_1=scale_att_type_1,
    scale_att_type_2=scale_att_type_2,
    data_uncertainty=data_uncertainty,
    local_variational=local_variational)

# Checkpoint file for the best model of this run.
save_path = os.path.join(savedir, 'best_prior_gp_mse_unclipped.ckpt')
training_hparams = tf.contrib.training.HParams(
    lr=0.01,
    optimizer=tf.train.RMSPropOptimizer,
    num_iterations=10000,
    batch_size=10,
    num_context=num_context,
    num_target=num_target,
    print_every=50,
    save_path=save_path,
    max_grad_norm=1000.0)

In [0]:
# Train with the hyperparameters defined in the previous cell.
train(data_hparams, model_hparams, training_hparams)

it: 0, train nll: 133.707901001, mse: 265.018371582, local kl: 0.0 global kl: 4.86915232614e-05 valid nll: 155.182296753, mse: 307.859466553, local kl: 0.0 global kl: 6.83785183355e-05
Saving best model with MSE 307.85947
it: 50, train nll: 30.6259536743, mse: 47.5792503357, local kl: 0.0 global kl: 0.00189643469639 valid nll: 44.8390197754, mse: 69.9601364136, local kl: 0.0 global kl: 0.00633073132485
Saving best model with MSE 69.96014
it: 100, train nll: 30.8255844116, mse: 46.4106445312, local kl: 0.0 global kl: 0.000299640523735 valid nll: 50.1438789368, mse: 65.9298400879, local kl: 0.0 global kl: 0.00054213067051
Saving best model with MSE 65.92984
it: 150, train nll: 20.5862350464, mse: 29.4456310272, local kl: 0.0 global kl: 2.68571693596e-05 valid nll: 45.0638008118, mse: 60.0532035828, local kl: 0.0 global kl: 8.25333918328e-05
Saving best model with MSE 60.053204
it: 200, train nll: 23.5716304779, mse: 32.0456619263, local kl: 0.0 global kl: 3.63523827218e-06 valid nll: 45.

# Posterior predictive + GP (`local_variational=True`)

In [0]:
# Posterior predictive (local_variational=True) with uncertainty_type
# 'attentive_gp'; every other setting comes from the shared config cell.
uncertainty_type = 'attentive_gp'
local_variational = True
model_hparams = tf.contrib.training.HParams(
    activation=tf.nn.relu,
    output_activation=tf.nn.relu,
    x_encoder_sizes=x_encoder_sizes,
    x_y_encoder_sizes=x_y_encoder_sizes,
    global_latent_net_sizes=global_latent_net_sizes,
    local_latent_net_sizes=local_latent_net_sizes,
    heteroskedastic_net_sizes=heteroskedastic_net_sizes,
    uncertainty_type=uncertainty_type,
    att_type=att_type,
    att_heads=att_heads,
    mean_att_type=mean_att_type,
    scale_att_type_1=scale_att_type_1,
    scale_att_type_2=scale_att_type_2,
    data_uncertainty=data_uncertainty,
    local_variational=local_variational)

# Checkpoint file for the best model of this run.
save_path = os.path.join(savedir, 'best_posterior_gp_mse_unclipped.ckpt')
training_hparams = tf.contrib.training.HParams(
    lr=0.01,
    optimizer=tf.train.RMSPropOptimizer,
    num_iterations=10000,
    batch_size=10,
    num_context=num_context,
    num_target=num_target,
    print_every=50,
    save_path=save_path,
    max_grad_norm=1000.0)

In [0]:
# Train with the hyperparameters defined in the previous cell.
train(data_hparams, model_hparams, training_hparams)

it: 0, train nll: 131.736602783, mse: 261.122589111, local kl: 0.0353629663587 global kl: 8.2984319306e-05 valid nll: 152.57359314, mse: 302.690948486, local kl: 0.0492436401546 global kl: 0.000122616358567
Saving best model with MSE 302.69095
it: 50, train nll: 8.20581817627, mse: 9.99204444885, local kl: 12.4604902267 global kl: 0.000217912674998 valid nll: 12.7324056625, mse: 17.2057094574, local kl: 14.8639354706 global kl: 0.000617719313595
Saving best model with MSE 17.20571
it: 100, train nll: 9.82409477234, mse: 13.3075666428, local kl: 7.85672092438 global kl: 0.000187600715435 valid nll: 10.4615631104, mse: 13.1943149567, local kl: 16.4872779846 global kl: 6.16395409452e-05
Saving best model with MSE 13.194315
it: 150, train nll: 5.35595321655, mse: 6.58239555359, local kl: 7.75857448578 global kl: 1.5606310626e-05 valid nll: 10.0951595306, mse: 12.7748575211, local kl: 15.3519239426 global kl: 2.25215171668e-05
Saving best model with MSE 12.7748575
it: 200, train nll: 7.0551

# Prior predictive + freeform (decoder-based hyperparameters)

In [0]:
# Prior predictive + freeform using the decoder-based hyperparameter set
# (freeform/global/global2local decoder sizes, `meta_learn`, `pred_type`).
# NOTE(review): these model hparam keys differ from the ones used by the
# `local_variational` cells above; confirm the imported `train` reads them.
#
# This cell's 3-layer x encoder is kept under its own name so that it does not
# overwrite the shared `x_encoder_sizes` ([100]*2) that the cells above read.
# The checkpoint also gets its own file, so it does not overwrite
# 'best_prior_freeform_mse_unclipped.ckpt', which is written by the earlier
# prior-predictive freeform run with a different architecture.
num_target = 50
num_context = 512
data_hparams = tf.contrib.training.HParams(context_dim=2,
                                           num_actions=5,
                                           num_target=num_target,
                                           num_context=num_context)
X_HIDDEN_SIZE = 100
decoder_x_encoder_sizes = [X_HIDDEN_SIZE]*3

HIDDEN_SIZE = 64
latent_units = 32
freeform_decoder_sizes = [HIDDEN_SIZE]*3 + [2]
global_decoder_sizes = [HIDDEN_SIZE]*2 + [2*latent_units]
global2local_decoder_sizes = None
x_y_encoder_sizes = [HIDDEN_SIZE]*3
heteroskedastic_sizes = None
uncertainty_type = None
mean_att_type = attention.laplace_attention
scale_att_type_1 = attention.laplace_attention
scale_att_type_2 = attention.laplace_attention
att_type = 'multihead'
att_heads = 8
data_uncertainty = False

model_hparams = tf.contrib.training.HParams(activation=tf.nn.relu,
                                            output_activation=tf.nn.relu,
                                            x_encoder_sizes=decoder_x_encoder_sizes,
                                            x_y_encoder_sizes=x_y_encoder_sizes,
                                            freeform_decoder_sizes=freeform_decoder_sizes,
                                            global_decoder_sizes=global_decoder_sizes,
                                            global2local_decoder_sizes=global2local_decoder_sizes,
                                            heteroskedastic_sizes=heteroskedastic_sizes,
                                            uncertainty_type=uncertainty_type,
                                            att_type=att_type,
                                            att_heads=att_heads,
                                            mean_att_type=mean_att_type,
                                            scale_att_type_1=scale_att_type_1,
                                            scale_att_type_2=scale_att_type_2,
                                            meta_learn=False,
                                            data_uncertainty=data_uncertainty)
save_path = os.path.join(savedir,
                         'best_prior_freeform_decoder_mse_unclipped.ckpt')
pred_type = 'prior_predictive'
training_hparams = tf.contrib.training.HParams(lr=0.01,
                                               optimizer=tf.train.RMSPropOptimizer,
                                               num_iterations=10000,
                                               batch_size=10,
                                               num_context=num_context,
                                               num_target=num_target,
                                               print_every=50,
                                               save_path=save_path,
                                               pred_type=pred_type,
                                               max_grad_norm=1000.0)


In [0]:
# Train with the hyperparameters defined in the previous cell.
train(data_hparams, model_hparams, training_hparams)

it: 0, train nll: 233.244369507, mse: 330.893981934, local kl: 0.0 global kl: 7.03447440173e-05 valid nll: 299.086914062, mse: 424.603118896, local kl: 7.28687268747e-07 global kl: 9.62082631304e-05
Saving best model with MSE 424.60312
it: 50, train nll: 37.6227493286, mse: 136.156738281, local kl: 0.0 global kl: 0.000122057666886 valid nll: 119.781051636, mse: 853.618164062, local kl: 0.141291186213 global kl: 0.000135921232868
it: 100, train nll: 35.6790542603, mse: 121.717926025, local kl: 0.0 global kl: 0.000265129288891 valid nll: 26.3708591461, mse: 106.190765381, local kl: 0.0135101545602 global kl: 3.95455208491e-05
Saving best model with MSE 106.190765
it: 150, train nll: 18.2585353851, mse: 67.5693893433, local kl: 0.0 global kl: 6.13677766523e-05 valid nll: 14.4377069473, mse: 92.0325088501, local kl: 0.0030499540735 global kl: 2.05152064154e-05
Saving best model with MSE 92.03251
it: 200, train nll: 12.2136974335, mse: 107.714637756, local kl: 0.0 global kl: 1.22998130792e-

# Prior predictive + GP (decoder-based hyperparameters)

In [0]:
# Prior predictive + GP using the decoder-based hyperparameter set
# (freeform/global/global2local decoder sizes, `meta_learn`, `pred_type`).
# NOTE(review): these model hparam keys differ from the ones used by the
# `local_variational` cells above; confirm the imported `train` reads them.
#
# This cell's 3-layer x encoder is kept under its own name so that it does not
# overwrite the shared `x_encoder_sizes` ([100]*2) that the cells above read.
# The checkpoint also gets its own file, so it does not overwrite
# 'best_prior_gp_mse_unclipped.ckpt', which is written by the earlier
# prior-predictive GP run with a different architecture.
num_target = 50
num_context = 512
data_hparams = tf.contrib.training.HParams(context_dim=2,
                                           num_actions=5,
                                           num_target=num_target,
                                           num_context=num_context)
X_HIDDEN_SIZE = 100
decoder_x_encoder_sizes = [X_HIDDEN_SIZE]*3

HIDDEN_SIZE = 64
latent_units = 32
freeform_decoder_sizes = None
global_decoder_sizes = [HIDDEN_SIZE]*2 + [2*latent_units]
global2local_decoder_sizes = [HIDDEN_SIZE]*3 + [2]
x_y_encoder_sizes = [HIDDEN_SIZE]*3
heteroskedastic_sizes = None
uncertainty_type = 'attentive_gp'
mean_att_type = attention.laplace_attention
scale_att_type_1 = attention.laplace_attention
scale_att_type_2 = attention.laplace_attention
att_type = 'multihead'
att_heads = 8
data_uncertainty = False

model_hparams = tf.contrib.training.HParams(activation=tf.nn.relu,
                                            output_activation=tf.nn.relu,
                                            x_encoder_sizes=decoder_x_encoder_sizes,
                                            x_y_encoder_sizes=x_y_encoder_sizes,
                                            freeform_decoder_sizes=freeform_decoder_sizes,
                                            global_decoder_sizes=global_decoder_sizes,
                                            global2local_decoder_sizes=global2local_decoder_sizes,
                                            heteroskedastic_sizes=heteroskedastic_sizes,
                                            uncertainty_type=uncertainty_type,
                                            att_type=att_type,
                                            att_heads=att_heads,
                                            mean_att_type=mean_att_type,
                                            scale_att_type_1=scale_att_type_1,
                                            scale_att_type_2=scale_att_type_2,
                                            meta_learn=False,
                                            data_uncertainty=data_uncertainty)
save_path = os.path.join(savedir, 'best_prior_gp_decoder_mse_unclipped.ckpt')
pred_type = 'prior_predictive'
training_hparams = tf.contrib.training.HParams(lr=0.01,
                                               optimizer=tf.train.RMSPropOptimizer,
                                               num_iterations=10000,
                                               batch_size=10,
                                               num_context=num_context,
                                               num_target=num_target,
                                               print_every=50,
                                               save_path=save_path,
                                               pred_type=pred_type,
                                               max_grad_norm=1000.0)


In [0]:
# Train with the hyperparameters defined in the previous cell.
train(data_hparams, model_hparams, training_hparams)

it: 0, train nll: 135.67288208, mse: 268.965209961, local kl: 0.0 global kl: 2.97951009998e-05 valid nll: 156.140930176, mse: 309.864196777, local kl: 0.0464261583984 global kl: 3.66583008145e-05
Saving best model with MSE 309.8642
it: 50, train nll: 38.713054657, mse: 47.3657722473, local kl: 0.0 global kl: 7.08528241375e-05 valid nll: 7.87959337234, mse: 9.00686168671, local kl: 20.2937602997 global kl: 0.000171415551449
Saving best model with MSE 9.006862
it: 100, train nll: 39.7649269104, mse: 56.1764793396, local kl: 0.0 global kl: 1.59876290127e-05 valid nll: 5.57811546326, mse: 5.43897247314, local kl: 24.2452487946 global kl: 2.80916428892e-05
Saving best model with MSE 5.4389725
it: 150, train nll: 22.0698070526, mse: 27.5763893127, local kl: 0.0 global kl: 1.70928542502e-05 valid nll: 5.52220392227, mse: 5.64932394028, local kl: 25.2032299042 global kl: 4.46284757345e-05
it: 200, train nll: 25.4558467865, mse: 34.0109519958, local kl: 0.0 global kl: 8.34837101138e-06 valid nl

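# Archive

Archived helper code: Wheel bandit data sampling and a standalone training loop. The experiment cells above use the imported `train` instead.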

In [0]:
def sample_training_wheel_bandit_data(num_total_states,
                                      num_actions,
                                      context_dim,
                                      delta,
                                      mean_v,
                                      std_v,
                                      mu_large,
                                      std_large):
  """Samples (context, random action) -> reward points from a Wheel bandit.

  See https://arxiv.org/abs/1802.09127. Contexts are drawn uniformly from the
  unit ball by rejection sampling. Each context is paired with a uniformly
  random action and the reward observed for that action only.

  Args:
    num_total_states: Number of points to sample.
    num_actions: Number of actions. Rewards 0-3 are overwritten by the
      quadrant logic, so this must be at least 4 (5 for the standard wheel).
    context_dim: Number of dimensions in the context. The quadrant test reads
      coordinates 0 and 1, so this must be at least 2.
    delta: Exploration parameter: high reward in one region if norm >= delta.
    mean_v: Per-action mean reward (length num_actions), used for every action
      except the quadrant's optimal one when the norm is >= delta.
    std_v: Per-action Gaussian reward std (length num_actions).
    mu_large: Mean reward of the quadrant's optimal action if norm >= delta.
    std_large: Reward std of that optimal action.

  Returns:
    A tuple (state_action_pairs, rewards), shuffled with one shared
    permutation:
      state_action_pairs: array of shape
        [num_total_states, context_dim + num_actions]. Each row is the context
        followed by a one-hot encoding of the sampled action.
      rewards: array of shape [num_total_states, 1] with the sampled action's
        reward.
  """
  if num_total_states <= 0:
    # Nothing to sample; np.stack would fail on an empty list.
    return (np.zeros((0, context_dim + num_actions)), np.zeros((0, 1)))

  data = []
  actions = []
  rewards = []

  # Rejection-sample contexts uniformly in the unit ball. Draw at least one
  # candidate per round so that small requests (< 3 points) cannot produce
  # empty batches and loop forever.
  candidates_per_round = max(1, int(num_total_states / 3))
  while len(data) < num_total_states:
    raw_data = np.random.uniform(-1, 1, (candidates_per_round, context_dim))
    for i in range(raw_data.shape[0]):
      if np.linalg.norm(raw_data[i, :]) <= 1:
        data.append(raw_data[i, :])

  states = np.stack(data)[:num_total_states, :]

  # Sample a reward for every action, then keep only a random action's reward.
  for i in range(num_total_states):
    r = [np.random.normal(mean_v[j], std_v[j]) for j in range(num_actions)]
    if np.linalg.norm(states[i, :]) >= delta:
      # Outside the radius-delta disc, the quadrant's action gets the large
      # reward.
      r_big = np.random.normal(mu_large, std_large)
      if states[i, 0] > 0:
        if states[i, 1] > 0:
          r[0] = r_big
        else:
          r[1] = r_big
      else:
        if states[i, 1] > 0:
          r[2] = r_big
        else:
          r[3] = r_big
    # Size the one-hot by num_actions so rows are context_dim + num_actions
    # wide for any action count.
    one_hot_vector = np.zeros(num_actions)
    random_action = np.random.randint(num_actions)
    one_hot_vector[random_action] = 1
    actions.append(one_hot_vector)
    rewards.append(r[random_action])

  rewards = np.expand_dims(np.array(rewards), -1)
  state_action_pairs = np.hstack([states, actions])
  perm = np.random.permutation(len(rewards))
  return state_action_pairs[perm, :], rewards[perm, :]


def get_training_wheel_data(num_total_states, num_actions, context_dim, delta):
  """Samples Wheel bandit training data with fixed reward parameters.

  The reward settings are fixed:
    - mean reward 1.0 for actions 0-3 and 1.2 for action 4;
    - reward std 0.01 for every action;
    - for contexts with norm >= delta, the quadrant's optimal action instead
      gets a reward drawn with mean 50 and std 0.01.
  The per-action lists have five entries, so num_actions is expected to be 5.

  Returns:
    The (state_action_pairs, rewards) tuple from
    sample_training_wheel_bandit_data.
  """
  reward_means = [1.0, 1.0, 1.0, 1.0, 1.2]
  reward_stds = [0.01] * 5
  return sample_training_wheel_bandit_data(num_total_states,
                                           num_actions,
                                           context_dim,
                                           delta,
                                           reward_means,
                                           reward_stds,
                                           mu_large=50,
                                           std_large=0.01)

In [0]:
def procure_dataset(hparams, num_wheels, seed=0):
  """Samples a multitask dataset of `num_wheels` independent wheel problems.

  NumPy's global RNG is reseeded with `seed` first, so repeated calls with the
  same arguments return the same data (and later np.random draws are
  affected). Each wheel draws its own delta ~ Uniform[0, 1) and samples
  hparams.num_target + hparams.num_context points.

  Args:
    hparams: Data hparams providing num_target, num_context, num_actions and
      context_dim.
    num_wheels: Number of wheels (tasks) to sample.
    seed: Value passed to np.random.seed.

  Returns:
    (all_state_action_pairs, all_rewards): the per-wheel arrays from
    get_training_wheel_data, stacked along a new leading wheel axis.
  """
  np.random.seed(seed)

  num_points = hparams.num_target + hparams.num_context
  all_state_action_pairs, all_rewards = [], []
  for _ in range(num_wheels):
    delta = np.random.uniform()
    state_action_pairs, rewards = get_training_wheel_data(
        num_points,
        hparams.num_actions,
        hparams.context_dim,
        delta)
    all_state_action_pairs.append(state_action_pairs)
    all_rewards.append(rewards)

  return np.stack(all_state_action_pairs), np.stack(all_rewards)

@tf.function
def step(model, data, optimizer_config, num_context):
  """Applies one gradient update and returns (nll, mse, local_kl, global_kl).

  Args:
    model: Called as model(context_x, context_y, target_x, target_y) and
      expected to return (prior_prediction, posterior_prediction). Its
      `losses[-1]` and `losses[-2]` are read as the local and global KL terms.
    data: 5-tuple (context_x, context_y, target_x, target_y, unseen_targets).
      The last entry is ignored; the unseen targets are re-sliced from
      target_y.
    optimizer_config: Optimizer whose apply_gradients is called.
    num_context: Number of leading target points treated as context; metrics
      use only the points after them.
  """
  # NOTE(review): `utils` is not imported in the visible notebook cells.
  context_x, context_y, target_x, target_y, _ = data
  with tf.GradientTape() as tape:
    _, posterior_prediction = model(context_x, context_y, target_x, target_y)
    held_out_y = target_y[:, num_context:]
    held_out_pred = posterior_prediction[:, num_context:]
    nll = utils.nll(held_out_y, held_out_pred)
    mse = utils.mse(held_out_y, held_out_pred)
    local_kl = tf.reduce_mean(model.losses[-1][:, num_context:])
    global_kl = tf.reduce_mean(model.losses[-2])
    # The objective uses MSE plus both KL terms. Variants that were tried and
    # commented out used nll instead of mse and/or dropped local_kl.
    loss = mse + local_kl + global_kl
  variables = model.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer_config.apply_gradients(zip(gradients, variables))
  return nll, mse, local_kl, global_kl

def training_loop(train_dataset,
                  valid_dataset,
                  model,
                  hparams):
  """Trains `model` with `step` and checkpoints on the best validation MSE.

  Every `hparams.print_every` iterations, one validation batch (with points
  left in their original order) is scored on the posterior prediction for the
  non-context points. The train and validation metrics are printed, and the
  weights are saved to `hparams.save_path` whenever the validation MSE
  improves.

  Args:
    train_dataset: (x, y) arrays indexed as [task, point, ...].
    valid_dataset: (x, y) arrays with the same layout, used for model
      selection.
    model: Called as model(context_x, context_y, target_x, target_y) and
      expected to return (prior_prediction, posterior_prediction). Its
      `losses[-1]` and `losses[-2]` are read as the local and global KL terms.
    hparams: Training hparams (optimizer, lr, num_context, num_iterations,
      batch_size, print_every, save_path).
  """
  # NOTE(review): `utils` is not imported in the visible notebook cells.
  optimizer_config = hparams.optimizer(hparams.lr)
  num_context = hparams.num_context
  best_mse = np.inf

  def _get_splits(dataset, n_context, batch_size, points_perm=True):
    """Samples a batch of tasks and splits its points into context and targets.

    The targets are all points of each sampled task, optionally shuffled with
    one permutation shared across the batch. The context is the first
    n_context target points, and unseen_targets is the rest of target_y.

    Returns:
      (context_x, context_y, target_x, target_y, unseen_targets).
    """
    full_x, full_y = dataset
    dataset_perm = np.random.permutation(len(full_x))[:batch_size]
    if points_perm:
      datapoints_perm = np.random.permutation(full_x.shape[1])
    else:
      datapoints_perm = np.arange(full_x.shape[1])

    # Advanced indexing: rows pick the sampled tasks, columns pick the points.
    target_x = tf.to_float(full_x[dataset_perm[:, None], datapoints_perm])
    target_y = tf.to_float(full_y[dataset_perm[:, None], datapoints_perm])
    context_x = target_x[:, :n_context, :]
    context_y = target_y[:, :n_context, :]
    unseen_targets = target_y[:, n_context:]

    return context_x, context_y, target_x, target_y, unseen_targets

  for it in range(hparams.num_iterations):
    batch_train_data = _get_splits(
        train_dataset, num_context, hparams.batch_size, points_perm=True)
    nll, mse, local_z_kl, global_z_kl = step(
        model,
        batch_train_data,
        optimizer_config,
        num_context)

    if it % hparams.print_every == 0:
      (batch_context_x, batch_context_y, batch_target_x, batch_target_y,
       batch_unseen_targets) = _get_splits(
           valid_dataset, num_context, hparams.batch_size, points_perm=False)
      _, posterior_prediction = model(
          batch_context_x,
          batch_context_y,
          batch_target_x,
          batch_target_y)

      valid_unseen_predictions = posterior_prediction[:, num_context:]
      valid_nll = utils.nll(batch_unseen_targets, valid_unseen_predictions)
      valid_mse = utils.mse(batch_unseen_targets, valid_unseen_predictions)
      valid_local_kl = tf.reduce_mean(model.losses[-1][:, num_context:])
      valid_global_kl = tf.reduce_mean(model.losses[-2])

      print('it: {}, train nll: {}, mse: {}, local kl: {} global kl: {} '
            'valid nll: {}, mse: {}, local kl: {} global kl: {}'
            .format(it, nll, mse, local_z_kl, global_z_kl,
                    valid_nll, valid_mse, valid_local_kl, valid_global_kl))
      current_mse = valid_mse.numpy()
      if current_mse < best_mse:
        print('Saving best model')
        best_mse = current_mse
        model.save_weights(hparams.save_path)

  print('Best MSE is', best_mse)
      
def pretrain(data_hparams,
             model_hparams,
             training_hparams,
             num_train_wheels=100,
             num_valid_wheels=10,
             train_seed=0,
             valid_seed=42):
  """Builds wheel-bandit train/valid splits, constructs a Regressor and trains it.

  Args:
    data_hparams: HParams passed to `procure_dataset`. Its `context_dim` and
      `num_actions` fields also size the model input
      (`input_dim = context_dim + num_actions`).
    model_hparams: HParams whose fields are forwarded to the `Regressor`
      constructor (see the NOTE below about field names).
    training_hparams: HParams forwarded unchanged to `training_loop`.
    num_train_wheels: `num_wheels` used to procure the training split
      (default 100, the value previously hard-coded).
    num_valid_wheels: `num_wheels` used to procure the validation split
      (default 10, the value previously hard-coded).
    train_seed: `seed` used to procure the training split (default 0).
    valid_seed: `seed` used to procure the validation split (default 42).
      Presumably this should differ from `train_seed` so the two splits do
      not contain the same wheels -- confirm against `procure_dataset`.

  Returns:
    None. Saving the best model is handled inside `training_loop`.
  """
  train_state_action_pairs, train_rewards = procure_dataset(
      data_hparams, num_wheels=num_train_wheels, seed=train_seed)
  train_dataset = (train_state_action_pairs, train_rewards)

  valid_state_action_pairs, valid_rewards = procure_dataset(
      data_hparams, num_wheels=num_valid_wheels, seed=valid_seed)
  valid_dataset = (valid_state_action_pairs, valid_rewards)

  # NOTE(review): the visible part of the model_hparams cell defines
  # `global_latent_net_sizes`, `local_latent_net_sizes` and
  # `heteroskedastic_net_sizes`, while this constructor reads
  # `freeform_decoder_sizes`, `global_decoder_sizes`,
  # `global2local_decoder_sizes`, `heteroskedastic_sizes` and `meta_learn`.
  # Reading a field that was never set on the HParams object will fail --
  # confirm the names match the Regressor definition in use.
  model = Regressor(
      input_dim=data_hparams.context_dim + data_hparams.num_actions,
      output_dim=1,
      x_encoder_sizes=model_hparams.x_encoder_sizes,
      x_y_encoder_sizes=model_hparams.x_y_encoder_sizes,
      freeform_decoder_sizes=model_hparams.freeform_decoder_sizes,
      global_decoder_sizes=model_hparams.global_decoder_sizes,
      global2local_decoder_sizes=model_hparams.global2local_decoder_sizes,
      heteroskedastic_sizes=model_hparams.heteroskedastic_sizes,
      att_type=model_hparams.att_type,
      att_heads=model_hparams.att_heads,
      uncertainty_type=model_hparams.uncertainty_type,
      mean_att_type=model_hparams.mean_att_type,
      scale_att_type_1=model_hparams.scale_att_type_1,
      scale_att_type_2=model_hparams.scale_att_type_2,
      activation=model_hparams.activation,
      output_activation=model_hparams.output_activation,
      meta_learn=model_hparams.meta_learn)

  training_loop(train_dataset,
                valid_dataset,
                model,
                training_hparams)
  # To check the saved checkpoint, reload it with
  # `model.load_weights(training_hparams.save_path)` (assuming training_loop
  # saves to `save_path` of the hparams it receives) and re-evaluate on
  # `valid_dataset`.

  

In [0]:
# Pretrain a Regressor with whatever data/model/training hparams are currently
# in scope. NOTE(review): this call is identical to the other pretrain cells;
# results differ only through hyperparameter cells run beforehand.
pretrain(data_hparams=data_hparams,
         model_hparams=model_hparams,
         training_hparams=training_hparams)

# Prior Predictive + MSE

In [0]:
# Pretraining run for the "prior predictive + mse" configuration.
# NOTE(review): the configuration comes from hparams set in earlier cells,
# not from this call -- confirm which settings were active for this run.
pretrain(data_hparams=data_hparams,
         model_hparams=model_hparams,
         training_hparams=training_hparams)

it: 0, train nll: 80.64142608642578, mse: 264.84783935546875, local kl: 2.2282488346099854 global kl: 0.012998933903872967valid nll: 73.07066345214844, mse: 239.69493103027344, local kl: 3.808199167251587 global kl: 0.020804043859243393
it: 50, train nll: 80.03545379638672, mse: 263.78887939453125, local kl: 3.2589476108551025 global kl: 0.0017950760666280985valid nll: 77.01897430419922, mse: 253.8538055419922, local kl: 1.3728536367416382 global kl: 0.0006710884626954794
it: 100, train nll: 76.54899597167969, mse: 252.11932373046875, local kl: 3.2033069133758545 global kl: 0.00012088955554645509valid nll: 70.98995208740234, mse: 233.72201538085938, local kl: 3.41615891456604 global kl: 0.000149666826473549
it: 150, train nll: 86.5409927368164, mse: 284.36871337890625, local kl: 6.026893138885498 global kl: 0.00034067913657054305valid nll: 64.36710357666016, mse: 210.2588653564453, local kl: 2.018472194671631 global kl: 0.00029383436776697636
it: 200, train nll: 83.56436157226562, mse:

# Prior Predictive + NLL

In [0]:
# Pretraining run for the "prior predictive + nll" configuration.
# NOTE(review): the configuration comes from hparams set in earlier cells,
# not from this call -- confirm which settings were active for this run.
pretrain(data_hparams=data_hparams,
         model_hparams=model_hparams,
         training_hparams=training_hparams)

it: 0, train nll: 73.67438507080078, mse: 247.44496154785156, local kl: 0.8368964195251465 global kl: 0.002964801387861371valid nll: 61.58158493041992, mse: 223.068359375, local kl: 3.387556314468384 global kl: 0.011656454764306545
it: 50, train nll: 4.328117847442627, mse: 263.22772216796875, local kl: 3.845857620239258 global kl: 0.0013694826047867537valid nll: 4.294598579406738, mse: 259.070556640625, local kl: 3.2342352867126465 global kl: 0.0010290180798619986
it: 100, train nll: 3.747567892074585, mse: 292.353271484375, local kl: 5.592900276184082 global kl: 0.00027055011014454067valid nll: 3.7165119647979736, mse: 283.5585021972656, local kl: 2.8959457874298096 global kl: 0.00016625048010610044
it: 150, train nll: 3.6381709575653076, mse: 282.8884582519531, local kl: 3.68673038482666 global kl: 4.3551444832701236e-05valid nll: 3.6170289516448975, mse: 273.54779052734375, local kl: 3.87436842918396 global kl: 3.7984536902513355e-05
it: 200, train nll: 3.6732494831085205, mse: 267

# Hide Run


Initializing model NeuroLinear-bnn.
Initializing model SNP - Attentive GP.
Initializing model SNP - Freeform.
Training NeuroLinear-bnn for 50 steps...
20 Training SNP - Freeform for 50 steps...
Average nll: 13.975089073181152, mse: 442.6430969238281, local kl: 0.0027215962763875723 global kl: 0.008305639028549194
Training NeuroLinear-bnn for 50 steps...
40 Training SNP - Freeform for 50 steps...
Average nll: 4.444005966186523, mse: 424.95819091796875, local kl: 0.05070192366838455 global kl: 0.006373311392962933
Training NeuroLinear-bnn for 50 steps...
60 Training SNP - Freeform for 50 steps...
Average nll: 4.522835731506348, mse: 495.9842834472656, local kl: 0.008277200162410736 global kl: 0.00125440105330199
Training NeuroLinear-bnn for 50 steps...
80 Training SNP - Freeform for 50 steps...
Average nll: 3.5025625228881836, mse: 274.245849609375, local kl: 0.023440886288881302 global kl: 0.0004953107563778758
Training NeuroLinear-bnn for 50 steps...
100 Training SNP - Freeform for 50 

# Posterior Predictive + MSE

In [0]:
# Pretraining run for the "posterior predictive + mse" configuration.
# NOTE(review): the configuration comes from hparams set in earlier cells,
# not from this call -- confirm which settings were active for this run.
pretrain(data_hparams=data_hparams,
         model_hparams=model_hparams,
         training_hparams=training_hparams)

it: 0, train nll: 69.30583953857422, mse: 227.4987335205078, local kl: 4.271792888641357 global kl: 0.01877990923821926valid nll: 86.39154815673828, mse: 283.9880676269531, local kl: 2.247591972351074 global kl: 0.00802246667444706
it: 50, train nll: 29.94801139831543, mse: 67.4618148803711, local kl: 65.69648742675781 global kl: 0.0041467431001365185valid nll: 43.102577209472656, mse: 105.57545471191406, local kl: 47.18109893798828 global kl: 0.002302606124430895
it: 100, train nll: 24.34380340576172, mse: 51.35904312133789, local kl: 84.17845916748047 global kl: 0.0002516347449272871valid nll: 37.54330825805664, mse: 84.63033294677734, local kl: 49.208221435546875 global kl: 0.0002653984120115638
it: 150, train nll: 20.86395835876465, mse: 41.2745361328125, local kl: 66.63032531738281 global kl: 1.979325679712929e-05valid nll: 39.090572357177734, mse: 83.65245056152344, local kl: 63.73513412475586 global kl: 5.718140891985968e-05
it: 200, train nll: 33.20714569091797, mse: 68.5636825

In [0]:
# Repeated pretraining run in this section (this run was interrupted manually).
# NOTE(review): identical call to the previous cell; any difference in results
# comes from hparams changed in between or from randomness -- confirm.
pretrain(data_hparams=data_hparams,
         model_hparams=model_hparams,
         training_hparams=training_hparams)

it: 0, train nll: 63.06072235107422, mse: 215.38909912109375, local kl: 3.032419443130493 global kl: 0.00995991937816143valid nll: 63.42512512207031, mse: 216.81292724609375, local kl: 3.3935022354125977 global kl: 0.01095118559896946
it: 50, train nll: 40.56180953979492, mse: 113.71380615234375, local kl: 146.83770751953125 global kl: 0.007505028508603573valid nll: 41.103233337402344, mse: 122.31623077392578, local kl: 49.39350128173828 global kl: 0.018857663497328758
it: 100, train nll: 30.77800178527832, mse: 79.66529083251953, local kl: 68.61537170410156 global kl: 6.1731977462768555valid nll: 26.246097564697266, mse: 72.5315933227539, local kl: 60.54109573364258 global kl: 0.02694229781627655
it: 150, train nll: 28.668663024902344, mse: 76.17350006103516, local kl: 70.8505859375 global kl: 0.005348356906324625valid nll: 23.36986541748047, mse: 59.470455169677734, local kl: 67.83472442626953 global kl: 0.003085080999881029
it: 200, train nll: 23.498933792114258, mse: 63.86367797851

KeyboardInterrupt: ignored

In [0]:
# Another pretraining run in this section (also interrupted manually).
# NOTE(review): the logged local KL grows to ~1e8 in this run; the hparams that
# produced it are not visible from this call -- confirm before reusing.
pretrain(data_hparams=data_hparams,
         model_hparams=model_hparams,
         training_hparams=training_hparams)

it: 0, train nll: 58.961856842041016, mse: 205.11830139160156, local kl: 2.90657639503479 global kl: 0.003845751751214266valid nll: 83.07482147216797, mse: 290.5240783691406, local kl: 1.5624231100082397 global kl: 0.0018705299589782953
it: 50, train nll: 2.700871467590332, mse: 3.628453493118286, local kl: 4039669.5 global kl: 0.012357242405414581valid nll: 3.038377523422241, mse: 4.814786911010742, local kl: 9913327.0 global kl: 0.009815489873290062
it: 100, train nll: 2.113877058029175, mse: 0.2591399848461151, local kl: 63664288.0 global kl: 0.0012066202471032739valid nll: 2.106165885925293, mse: 0.5803396105766296, local kl: 73767528.0 global kl: 0.00068276422098279
it: 150, train nll: 2.9890506267547607, mse: 0.21710118651390076, local kl: 368845952.0 global kl: 0.011972357518970966valid nll: 2.979311466217041, mse: 0.062382180243730545, local kl: 452203360.0 global kl: 0.010854621417820454
it: 200, train nll: 2.993539333343506, mse: 0.12942281365394592, local kl: 174734832.0 glo

KeyboardInterrupt: ignored