# Imports and helper functions

In [1]:
import collections

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

import tensorflow as tf
import tensorflow_datasets as tfds
import sonnet as snt

In [2]:
def sample_trial(pi, batch_size, class1, class2):
  """Draws one batch of (image, reward) pairs for the two-class experiment.

  Every batch element is an independent trial. A fair coin picks the class:
  a class-2 trial yields a uniformly drawn `class2` image with reward 1.0;
  a class-1 trial yields a uniformly drawn `class1` image with reward 3.0
  with probability `pi`, otherwise -1.0.

  Args:
    pi: Scalar probability of the 3.0 reward on class-1 trials.
    batch_size: Number of trials to draw.
    class1: Indexable collection of class-1 images; `class1[0].shape` fixes
      the per-image shape of the returned batch.
    class2: Indexable collection of class-2 images.

  Returns:
    A pair `(images, rewards)`: a float array of shape
    `(batch_size,) + class1[0].shape` and a float array of shape
    `(batch_size,)`.
  """
  image_shape = class1[0].shape
  images = np.zeros((batch_size,) + image_shape)
  rewards = np.zeros(batch_size)

  for trial in range(batch_size):
    picked_class2 = np.random.random() < 0.5
    if picked_class2:
      images[trial] = class2[np.random.randint(len(class2))]
      rewards[trial] = 1.
      continue

    # Class-1 trial: the reward is stochastic, governed by `pi`.
    images[trial] = class1[np.random.randint(len(class1))]
    rewards[trial] = 3. if np.random.random() < pi else -1.

  return images, rewards

def running_mean(x, N):
  """Returns the average of every length-`N` window sliding over `x`.

  A zero is placed in front of the cumulative sum of `x`, so each window
  total is the difference of two prefix sums. For 1 <= N <= len(x) the
  result has `len(x) - N + 1` entries.
  """
  prefix_sums = np.insert(np.cumsum(x), 0, 0)
  window_totals = prefix_sums[N:] - prefix_sums[:-N]
  return window_totals / float(N)
  

In [3]:
# Result of calling `Cifar10ValueNet`: the pooled features that feed the final
# linear layer, and that layer's output.
_Outputs = collections.namedtuple('Outputs', ['activations', 'values'])

class Cifar10ValueNet(snt.AbstractModule):
  """Convolutional network that maps images to `num_atoms` value predictions.

  Based upon Sonnet's Cifar10ConvNet, with the classifier head replaced by a
  linear value head. Written against the Sonnet 1 module API
  (`snt.AbstractModule` / `_build`).

  The network is a stack of eleven 3x3 convolutions, each followed by a ReLU,
  then a mean over the two spatial axes and a single linear layer.
  """
  def __init__(self,
               num_atoms,
               initializers=None,
               regularizers=None,
               partitioners=None,
               custom_getter=None,
               name="cifar10_convnet"):
    """Creates the convolutional layers.

    Args:
      num_atoms: Output size of the final linear layer (1 for a scalar value
        prediction, more for a set of value predictions).
      initializers: Passed through to every `snt.Conv2D` and to the final
        `snt.Linear`.
      regularizers: Passed through to every `snt.Conv2D` and to the final
        `snt.Linear`.
      partitioners: Passed through to every `snt.Conv2D` and to the final
        `snt.Linear`.
      custom_getter: Forwarded to `snt.AbstractModule.__init__`.
      name: Module name, forwarded to `snt.AbstractModule.__init__`.
    """
    super(Cifar10ValueNet, self).__init__(custom_getter=custom_getter, name=name)

    self._num_atoms = num_atoms
    self._output_channels = [
        64, 64, 128, 128, 128, 256, 256, 256, 512, 512, 512
    ]

    self._num_layers = len(self._output_channels)
    self._kernel_shapes = [[3, 3]] * self._num_layers  # All kernels are 3x3.
    # Stride 2 at three layers; all other layers use stride 1.
    self._strides = [1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1]
    self._paddings = [snt.SAME] * self._num_layers

    # Kept so that `_build` can configure the linear head the same way.
    self._initializers = initializers
    self._regularizers = regularizers
    self._partitioners = partitioners

    with self._enter_variable_scope():
      self._layers = tuple(
          snt.Conv2D(
              name="conv_net_2d/conv_2d_{}".format(i),
              output_channels=self._output_channels[i],
              kernel_shape=self._kernel_shapes[i],
              stride=self._strides[i],
              padding=self._paddings[i],
              initializers=initializers,
              regularizers=regularizers,
              partitioners=partitioners,
              use_bias=True) for i in range(self._num_layers))

  def _build(self, inputs):
    """Connects the module into the graph.

    Args:
      inputs: A Tensor of size [batch_size, input_height, input_width,
        input_channels], representing a batch of input images.

    Returns:
      An `_Outputs` namedtuple with fields:
        activations: the output of the last convolution (after ReLU),
          averaged over axes 1 and 2.
        values: the output of a linear layer with `num_atoms` units applied
          to `activations`.
    """
    net = inputs
    for layer in self._layers:
      net = tf.nn.relu(layer(net))

    # Average over the two spatial axes.
    flat_output = tf.reduce_mean(
        net, axis=[1, 2], keepdims=False, name="avg_pool")

    # Replace classifier output with linear function predicting values
    values = snt.Linear(
        self._num_atoms,
        initializers=self._initializers,
        regularizers=self._regularizers,
        partitioners=self._partitioners)(flat_output)

    return _Outputs(activations=flat_output, values=values)


AttributeError: module 'sonnet' has no attribute 'AbstractModule'

# Bird-dog experiment

## Load Cifar10 train and test data

We load images for 'dog'/'bird' (primary experiment) and 'airplane'/'ship' (control experiment).

Class label indices:
airplane (0), automobile (1), bird (2), cat (3), deer (4), dog (5), frog (6), horse (7), ship (8), truck (9)

In [None]:
# CIFAR-10 label indices for the primary experiment.
class1_ind, class2_ind = 5, 2  # dog, bird

# CIFAR-10 label indices for the control experiment.
class1_neg_ind, class2_neg_ind = 0, 8  # airplane, ship

In [None]:
# Load training images for the targeted image classes
train_ds = tfds.load("cifar10", split=tfds.Split.TRAIN)
train_class1, train_class2 = [], []
train_neg_class1, train_neg_class2 = [], []

for example in tfds.as_numpy(train_ds):
  numpy_image, numpy_label = example["image"], example["label"]
  if numpy_label == class1_ind:
    train_class1.append(numpy_image)
  elif numpy_label == class2_ind:
    train_class2.append(numpy_image)
  elif numpy_label == class1_neg_ind:
    train_neg_class1.append(numpy_image)
  elif numpy_label == class2_neg_ind:
    train_neg_class2.append(numpy_image)

# Load testing images for the targeted image classes
test_ds = tfds.load("cifar10", split=tfds.Split.TEST)
test_class1, test_class2 = [], []
test_neg_class1, test_neg_class2 = [], []

for example in tfds.as_numpy(test_ds):
  numpy_image, numpy_label = example["image"], example["label"]
  if numpy_label == class1_ind:
    test_class1.append(numpy_image)
  elif numpy_label == class2_ind:
    test_class2.append(numpy_image)
  elif numpy_label == class1_neg_ind:
    test_neg_class1.append(numpy_image)
  elif numpy_label == class2_neg_ind:
    test_neg_class2.append(numpy_image)

train_class1 = np.array(train_class1)
train_class2 = np.array(train_class2)
train_neg_class1 = np.array(train_neg_class1)
train_neg_class2 = np.array(train_neg_class2)

test_class1 = np.array(test_class1)
test_class2 = np.array(test_class2)
test_neg_class1 = np.array(test_neg_class1)
test_neg_class2 = np.array(test_neg_class2)


In [None]:
# Report the array shape collected for each class, train split then test split.
print("Train:", *(images.shape for images in (
    train_class1, train_class2, train_neg_class1, train_neg_class2)))
print("Test:", *(images.shape for images in (
    test_class1, test_class2, test_neg_class1, test_neg_class2)))

In [None]:
# Display some example images for the primary experiment
# (test set in the final row)

fig = plt.figure(figsize=(8, 8))

# 5 x 6 grid: rows 0-3 sample the training set, row 4 samples the test set.
# Within a row, the first three columns show class 1 and the last three
# show class 2.
for row in range(5):
  if row < 4:
    first_pool, second_pool = train_class1, train_class2
  else:
    first_pool, second_pool = test_class1, test_class2

  for col in range(6):
    plt.subplot(5, 6, row * 6 + col + 1)
    pool = first_pool if col <= 2 else second_pool
    plt.imshow(pool[np.random.randint(len(pool))], interpolation='none')
    plt.gca().axis('off')


In [None]:
# Display some example images for the control experiment
# (test set in the final row)

fig = plt.figure(figsize=(8, 8))

# 5 x 6 grid: rows 0-3 sample the training set, row 4 samples the test set.
# Within a row, the first three columns show control class 1 and the last
# three show control class 2.
for row in range(5):
  if row < 4:
    first_pool, second_pool = train_neg_class1, train_neg_class2
  else:
    first_pool, second_pool = test_neg_class1, test_neg_class2

  for col in range(6):
    plt.subplot(5, 6, row * 6 + col + 1)
    pool = first_pool if col <= 2 else second_pool
    plt.imshow(pool[np.random.randint(len(pool))], interpolation='none')
    plt.gca().axis('off')


## Run BirdDog experiment

In [None]:
# Builds the TF1 graph for the bird-dog experiment: two independent
# `Cifar10ValueNet` instances (a scalar "classic TD" predictor and a
# multi-output "distributional TD" predictor) fed by the same placeholders and
# trained by one Adam step on the sum of their losses.
tf.reset_default_graph()
# tf.Session.reset("local")
print('Setting up graph...')

# Takes every axis between the first and the last of the training array.
# assumes train_class1 is (num_images, d1, d2, channels) — TODO confirm
width, height = train_class1.shape[1:-1]
batch_size = 256

# Both are passed as `pi` to `sample_trial`, i.e. the probability of the 3.0
# reward on class-1 trials.
pi = 0.5  # Probability of 'happy' dog (outcome 0) for Task 1
mu = 1.0  # Probability of 'happy' dog (outcome 0) for Task 2

input_ph = tf.placeholder(
    tf.float32, shape=[None, width, height, 3], name="images")
reward_ph = tf.placeholder(tf.float32, shape=[None], name="rewards")

# Classic TD
value_net = Cifar10ValueNet(1)
output = value_net(input_ph)
# Flatten the single-output head to one prediction per example.
values = tf.reshape(output.values, [-1])
td_features = output.activations

# Classic TD loss function
delta = reward_ph - values
td_loss = tf.reduce_mean(0.5 * tf.square(delta))
# Same squared error without the 0.5 factor; reported only, never optimised.
mse = tf.reduce_mean(tf.square(reward_ph - values))

# Distributional TD
n_atoms = 64
dist_net = Cifar10ValueNet(n_atoms)
# `n_atoms` evenly spaced levels strictly inside (0, 1): the endpoints 0 and 1
# produced by linspace are dropped.
taus_np = np.linspace(0., 1., n_atoms+2)[1:-1]
taus = tf.constant(taus_np, dtype=tf.float32)

# NOTE: `output` and `delta` are rebound below; the TD tensors needed later
# were already captured in `values`, `td_features`, `td_loss` and `mse`.
output = dist_net(input_ph)
dist_values = output.values
dtd_features = output.activations

# Distributional TD loss function
# Broadcast each reward against all `n_atoms` predictions of its example.
delta = reward_ph[:, None] - dist_values
indic = tf.cast(delta <= 0., dtype=tf.float32)
# Per-atom asymmetric weight: tau where delta > 0, (1 - tau) where delta <= 0.
weights = tf.stop_gradient(tf.abs(taus[None] - indic))

# Asymmetrically weighted absolute error (quantile-regression / pinball loss).
dtd_loss = tf.reduce_mean(weights * tf.abs(delta))
# Squared error of the mean over atoms; reported only, never optimised.
mse_dtd = tf.reduce_mean(tf.square(reward_ph - tf.reduce_mean(dist_values, -1)))

# A single optimiser step updates both networks via the summed loss.
optimizer = tf.train.AdamOptimizer(0.0005)
update = optimizer.minimize(td_loss + dtd_loss)


In [None]:
# Runs the bird-dog experiment: for each trial, re-initialise both networks,
# train on Task 1 (reward probability `pi`), then continue training on Task 2
# (reward probability `mu`), recording losses and test-set features.
num_updates = 1000
num_trials = 10

# Axis 0 indexes the learner: 0 = classic TD, 1 = distributional TD.
losses = np.zeros((2, num_trials, num_updates,))
transfer_losses = np.zeros((2, num_trials, num_updates,))
# Last axis indexes the task: 0 = Task 1, 1 = Task 2.
mserrors = np.zeros((2, num_trials, num_updates, 2))

for trial in range(num_trials):
  with tf.Session() as sess:
    # `tf.initialize_all_variables` is deprecated in favour of this call.
    sess.run(tf.global_variables_initializer())

    # NOTE: the feature/value variables assigned below are overwritten on
    # every trial, so after the loop they hold the last trial's results only.

    # Extract features and values on test set (before training)
    ex1_td_init_features, ex1_dtd_init_features = sess.run(
        [td_features, dtd_features],
        feed_dict={input_ph: test_class1})
    ex2_td_init_features, ex2_dtd_init_features = sess.run(
        [td_features, dtd_features],
        feed_dict={input_ph: test_class2})

    # Train for Task 1
    for epoch in range(num_updates):
      imgs, rwds = sample_trial(pi, batch_size, train_class1, train_class2)
      _, loss_np, dloss_np, mse_np, dmse_np = sess.run(
          [update, td_loss, dtd_loss, mse, mse_dtd],
          feed_dict={input_ph: imgs, reward_ph: rwds})

      # Record losses and mean-square errors for later analysis
      losses[0, trial, epoch] = loss_np
      losses[1, trial, epoch] = dloss_np
      mserrors[0, trial, epoch, 0] = mse_np
      mserrors[1, trial, epoch, 0] = dmse_np

      if epoch % 500 == 0:
        print("%d: Task 1 Loss %d: %f, %f" % (trial, epoch, loss_np, dloss_np))

    # Extract features and values on test set (after task 1, before task 2)
    ex1_td_values, ex1_td_features, ex1_dtd_values, ex1_dtd_features = sess.run(
        [values, td_features, dist_values, dtd_features],
        feed_dict={input_ph: test_class1})
    ex2_td_values, ex2_td_features, ex2_dtd_values, ex2_dtd_features = sess.run(
        [values, td_features, dist_values, dtd_features],
        feed_dict={input_ph: test_class2})

    # Train for Task 2 (observe we change from pi to mu)
    for epoch in range(num_updates):
      imgs, rwds = sample_trial(mu, batch_size, train_class1, train_class2)
      _, loss_np, dloss_np, mse_np, dmse_np = sess.run(
          [update, td_loss, dtd_loss, mse, mse_dtd],
          feed_dict={input_ph: imgs, reward_ph: rwds})

      # Record losses and mean-square errors for later analysis
      transfer_losses[0, trial, epoch] = loss_np
      transfer_losses[1, trial, epoch] = dloss_np
      mserrors[0, trial, epoch, 1] = mse_np
      mserrors[1, trial, epoch, 1] = dmse_np

      if epoch % 500 == 0:
        print("%d: Task 2 Loss %d: %f %f" % (trial, epoch, loss_np, dloss_np))

    # Extract features and values on test set (after task 1 *and* task 2)
    tex1_td_values, tex1_td_features, tex1_dtd_values, tex1_dtd_features = sess.run(
        [values, td_features, dist_values, dtd_features],
        feed_dict={input_ph: test_class1})
    tex2_td_values, tex2_td_features, tex2_dtd_values, tex2_dtd_features = sess.run(
        [values, td_features, dist_values, dtd_features],
        feed_dict={input_ph: test_class2})


## Visualize results

In [None]:
# Line colours for the figures: index 0 is used for distributional TD and
# index 1 for classic TD.
colours = ['#e74c3c', '#38495C']

In [None]:
# Training-loss curves for both learners, Task 1 on the left and Task 2 on
# the right. Bold lines are the across-trial mean, faint lines are the
# individual trials; everything is smoothed with a moving average.
sns.set_style('whitegrid')
plt.figure(figsize=(12, 4))
window = 10

# (row in the loss arrays, legend label, line colour) for each learner.
learners = [(0, "TD", colours[1]), (1, "Distributional TD", colours[0])]

plt.subplot(1, 2, 1)
for row, name, colour in learners:
  plt.plot(running_mean(losses[row].mean(0), window), label=name, color=colour, lw=2)

for trial in range(num_trials):
  for row, _, colour in learners:
    plt.plot(running_mean(losses[row, trial], window), label=None, color=colour, lw=1, alpha=0.25)

plt.title("Task 1", fontsize=18)
plt.xlabel("Training Iterations", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.tick_params(top=False, right=False, labelsize=14)
plt.ylim([0., 5.])

plt.subplot(1, 2, 2)
for row, name, colour in learners:
  plt.plot(running_mean(transfer_losses[row].mean(0), window), label=name, color=colour, lw=2)

for trial in range(num_trials):
  for row, _, colour in learners:
    plt.plot(running_mean(transfer_losses[row, trial], window), label=None, color=colour, lw=1, alpha=0.25)

plt.title("Task 2", fontsize=18)
plt.xlabel("Training Iterations", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.tick_params(top=False, right=False, labelsize=14)
plt.legend(loc='best', fontsize=16)

plt.savefig("birddog_loss.pdf")
%download_file birddog_loss.pdf

In [None]:
# Mean-squared-error curves for both learners, Task 1 on the left and Task 2
# on the right. Bold lines are the across-trial mean, faint lines are the
# individual trials; everything is smoothed with a moving average.
plt.figure(figsize=(12, 4))
window = 10
sns.set_style('whitegrid')

# (row in `mserrors`, legend label, line colour) for each learner.
learners = [(0, "TD", colours[1]), (1, "Distributional TD", colours[0])]

plt.subplot(1, 2, 1)
for row, name, colour in learners:
  plt.plot(running_mean(mserrors[row, ..., 0].mean(0), window), label=name, color=colour, lw=2)

for trial in range(num_trials):
  for row, _, colour in learners:
    plt.plot(running_mean(mserrors[row, trial, :, 0], window), label=None, color=colour, lw=1, alpha=0.25)

plt.title("Task 1", fontsize=18)
plt.xlabel("Training Iterations", fontsize=16)
plt.ylabel("MSE", fontsize=16)
plt.tick_params(top=False, right=False, labelsize=14)
_ = plt.xlim([0, 1000])
plt.ylim([0., 3.])

plt.subplot(1, 2, 2)
for row, name, colour in learners:
  plt.plot(running_mean(mserrors[row, ..., 1].mean(0), window), label=name, color=colour, lw=2)

for trial in range(num_trials):
  for row, _, colour in learners:
    plt.plot(running_mean(mserrors[row, trial, :, 1], window), label=None, color=colour, lw=1, alpha=0.25)

plt.title("Task 2", fontsize=18)
plt.xlabel("Training Iterations", fontsize=16)
plt.ylabel("MSE", fontsize=16)
plt.tick_params(top=False, right=False, labelsize=14)
plt.legend(loc='best', fontsize=16)
_ = plt.xlim([0, 1000])
plt.ylim([0., 3.])

plt.savefig("birddog_mse.pdf")
%download_file birddog_mse.pdf


## Analyze representation

In [None]:
# 2-D t-SNE embeddings of the features extracted before training.
# Each feature set is embedded independently with a fresh estimator.
(ex1_td_init_tsne, ex1_dtd_init_tsne,
 ex2_td_init_tsne, ex2_dtd_init_tsne) = [
    TSNE(n_components=2).fit_transform(features)
    for features in (ex1_td_init_features, ex1_dtd_init_features,
                     ex2_td_init_features, ex2_dtd_init_features)]

# 2-D PCA projections of the same feature sets.
(ex1_td_init_pca, ex1_dtd_init_pca,
 ex2_td_init_pca, ex2_dtd_init_pca) = [
    PCA(n_components=2).fit_transform(features)
    for features in (ex1_td_init_features, ex1_dtd_init_features,
                     ex2_td_init_features, ex2_dtd_init_features)]

In [None]:
# 2-D t-SNE embeddings of the features extracted after Task 1 training.
# Each feature set is embedded independently with a fresh estimator.
(ex1_td_tsne, ex1_dtd_tsne,
 ex2_td_tsne, ex2_dtd_tsne) = [
    TSNE(n_components=2).fit_transform(features)
    for features in (ex1_td_features, ex1_dtd_features,
                     ex2_td_features, ex2_dtd_features)]

# 2-D PCA projections of the same feature sets.
(ex1_td_pca, ex1_dtd_pca,
 ex2_td_pca, ex2_dtd_pca) = [
    PCA(n_components=2).fit_transform(features)
    for features in (ex1_td_features, ex1_dtd_features,
                     ex2_td_features, ex2_dtd_features)]

In [None]:
# 2-D t-SNE embeddings of the features extracted after Task 2 training.
# Each feature set is embedded independently with a fresh estimator.
(tex1_td_tsne, tex1_dtd_tsne,
 tex2_td_tsne, tex2_dtd_tsne) = [
    TSNE(n_components=2).fit_transform(features)
    for features in (tex1_td_features, tex1_dtd_features,
                     tex2_td_features, tex2_dtd_features)]

# 2-D PCA projections of the same feature sets.
(tex1_td_pca, tex1_dtd_pca,
 tex2_td_pca, tex2_dtd_pca) = [
    PCA(n_components=2).fit_transform(features)
    for features in (tex1_td_features, tex1_dtd_features,
                     tex2_td_features, tex2_dtd_features)]

In [None]:
def color_sort(values):
  """Returns each element's rank in `values`, scaled into [0, 1).

  Entry `i` of the result is the position element `i` would take if
  `values` were sorted ascending, divided by `len(values)`. The result is
  suitable as input to a matplotlib colormap.

  Args:
    values: 1-D sequence or array of comparable numbers.

  Returns:
    A float array of length `len(values)`; empty if `values` is empty.
  """
  n = len(values)
  if n == 0:
    return np.zeros(0)

  order = np.argsort(values)
  # Invert the sorting permutation: the element at sorted position k gets
  # rank k.
  ranks = np.empty(n, dtype=order.dtype)
  ranks[order] = np.arange(n)
  return ranks / float(n)


In [None]:
N = ex1_td_tsne.shape[0]

# Colour every test point by its rank along the summed coordinates of its
# *initial* t-SNE embedding, so a point keeps the same colour in every panel.
order_td1, order_td2 = [
    color_sort(embedding.sum(-1))
    for embedding in (ex1_td_init_tsne, ex2_td_init_tsne)]

order_dtd1, order_dtd2 = [
    color_sort(embedding.sum(-1))
    for embedding in (ex1_dtd_init_tsne, ex2_dtd_init_tsne)]


In [None]:
# 3 x 4 grid of 2-D scatter plots of the test-set features. Rows: before
# training, after Task 1, after Task 2. Columns: TD t-SNE, DistTD t-SNE,
# TD PCA, DistTD PCA. Class 1 is drawn in blues and class 2 in reds.
plt.figure(figsize=(20, 12))

td_colors = [plt.cm.Blues, plt.cm.Reds]
dtd_colors = [plt.cm.Blues, plt.cm.Reds]

kwargs = dict(linewidth=0, alpha=0.75)

# Colormaps and per-point colour ranks for each learner.
td_style = (td_colors, order_td1, order_td2)
dtd_style = (dtd_colors, order_dtd1, order_dtd2)

# One entry per panel, in subplot order:
# (title, class-1 embedding, class-2 embedding, learner style).
# Driving every panel from this table keeps both axes of a panel tied to the
# same embedding (the "TD Task1 (PCA)" panel used to mix PCA x with t-SNE y).
panels = [
    # Before training
    ("TD Task1 (t-SNE)", ex1_td_init_tsne, ex2_td_init_tsne, td_style),
    ("DistTD Task1 (t-SNE)", ex1_dtd_init_tsne, ex2_dtd_init_tsne, dtd_style),
    ("TD Task1 (PCA)", ex1_td_init_pca, ex2_td_init_pca, td_style),
    ("DistTD Task1 (PCA)", ex1_dtd_init_pca, ex2_dtd_init_pca, dtd_style),
    # Task1 trained
    ("TD Task1 (t-SNE)", ex1_td_tsne, ex2_td_tsne, td_style),
    ("DistTD Task1 (t-SNE)", ex1_dtd_tsne, ex2_dtd_tsne, dtd_style),
    ("TD Task1 (PCA)", ex1_td_pca, ex2_td_pca, td_style),
    ("DistTD Task1 (PCA)", ex1_dtd_pca, ex2_dtd_pca, dtd_style),
    # Transfer
    ("TD Task2 (t-SNE)", tex1_td_tsne, tex2_td_tsne, td_style),
    ("DistTD Task2 (t-SNE)", tex1_dtd_tsne, tex2_dtd_tsne, dtd_style),
    ("TD Task2 (PCA)", tex1_td_pca, tex2_td_pca, td_style),
    ("DistTD Task2 (PCA)", tex1_dtd_pca, tex2_dtd_pca, dtd_style),
]

for position, (title, class1_xy, class2_xy, style) in enumerate(panels, 1):
  cmaps, order1, order2 = style
  plt.subplot(3, 4, position)
  plt.scatter(class1_xy[:, 0], class1_xy[:, 1], c=cmaps[0](order1), **kwargs)
  plt.scatter(class2_xy[:, 0], class2_xy[:, 1], c=cmaps[1](order2), **kwargs)
  plt.tick_params(top=False, right=False)
  plt.title(title, fontsize=14)

plt.savefig("birddog_analysis2d.pdf")
%download_file birddog_analysis2d.pdf


In [None]:
from mpl_toolkits.mplot3d import Axes3D

In [None]:
# 3-D t-SNE embeddings of the features extracted before training.
# These rebind the names used for the 2-D embeddings.
(ex1_td_init_tsne, ex1_dtd_init_tsne,
 ex2_td_init_tsne, ex2_dtd_init_tsne) = [
    TSNE(n_components=3).fit_transform(features)
    for features in (ex1_td_init_features, ex1_dtd_init_features,
                     ex2_td_init_features, ex2_dtd_init_features)]

# 3-D PCA projections of the same feature sets.
(ex1_td_init_pca, ex1_dtd_init_pca,
 ex2_td_init_pca, ex2_dtd_init_pca) = [
    PCA(n_components=3).fit_transform(features)
    for features in (ex1_td_init_features, ex1_dtd_init_features,
                     ex2_td_init_features, ex2_dtd_init_features)]

In [None]:
# 3-D t-SNE embeddings of the features extracted after Task 1 training.
# These rebind the names used for the 2-D embeddings.
(ex1_td_tsne, ex1_dtd_tsne,
 ex2_td_tsne, ex2_dtd_tsne) = [
    TSNE(n_components=3).fit_transform(features)
    for features in (ex1_td_features, ex1_dtd_features,
                     ex2_td_features, ex2_dtd_features)]

# 3-D PCA projections of the same feature sets.
(ex1_td_pca, ex1_dtd_pca,
 ex2_td_pca, ex2_dtd_pca) = [
    PCA(n_components=3).fit_transform(features)
    for features in (ex1_td_features, ex1_dtd_features,
                     ex2_td_features, ex2_dtd_features)]

In [None]:
# 3-D t-SNE embeddings of the features extracted after Task 2 training.
# These rebind the names used for the 2-D embeddings.
(tex1_td_tsne, tex1_dtd_tsne,
 tex2_td_tsne, tex2_dtd_tsne) = [
    TSNE(n_components=3).fit_transform(features)
    for features in (tex1_td_features, tex1_dtd_features,
                     tex2_td_features, tex2_dtd_features)]

# 3-D PCA projections of the same feature sets.
(tex1_td_pca, tex1_dtd_pca,
 tex2_td_pca, tex2_dtd_pca) = [
    PCA(n_components=3).fit_transform(features)
    for features in (tex1_td_features, tex1_dtd_features,
                     tex2_td_features, tex2_dtd_features)]

In [None]:
N = ex1_td_tsne.shape[0]

# Colour every test point by its rank along the summed coordinates of its
# *initial* t-SNE embedding, so a point keeps the same colour in every panel.
order_td1, order_td2 = [
    color_sort(embedding.sum(-1))
    for embedding in (ex1_td_init_tsne, ex2_td_init_tsne)]

order_dtd1, order_dtd2 = [
    color_sort(embedding.sum(-1))
    for embedding in (ex1_dtd_init_tsne, ex2_dtd_init_tsne)]

# Alternative (disabled): rank by the post-Task-1 embeddings instead, i.e.
# ex1_td_tsne / ex2_td_tsne / ex1_dtd_tsne / ex2_dtd_tsne.


In [None]:
# 3 x 4 grid of 3-D scatter plots of the test-set features. Rows: before
# training, after Task 1, after Task 2. Columns: TD t-SNE, DistTD t-SNE,
# TD PCA, DistTD PCA. Class 1 is drawn in blues and class 2 in reds.
plt.figure(figsize=(20, 12))

td_colors = [plt.cm.Blues, plt.cm.Reds]
dtd_colors = [plt.cm.Blues, plt.cm.Reds]

kwargs = dict(linewidth=0, alpha=0.75)

# Colormaps and per-point colour ranks for each learner.
td_style = (td_colors, order_td1, order_td2)
dtd_style = (dtd_colors, order_dtd1, order_dtd2)

# One entry per panel, in subplot order:
# (title, class-1 embedding, class-2 embedding, learner style).
panels = [
    # Before training
    ("TD Task1 (t-SNE)", ex1_td_init_tsne, ex2_td_init_tsne, td_style),
    ("DistTD Task1 (t-SNE)", ex1_dtd_init_tsne, ex2_dtd_init_tsne, dtd_style),
    ("TD Task1 (PCA)", ex1_td_init_pca, ex2_td_init_pca, td_style),
    ("DistTD Task1 (PCA)", ex1_dtd_init_pca, ex2_dtd_init_pca, dtd_style),
    # Task1 trained
    ("TD Task1 (t-SNE)", ex1_td_tsne, ex2_td_tsne, td_style),
    ("DistTD Task1 (t-SNE)", ex1_dtd_tsne, ex2_dtd_tsne, dtd_style),
    ("TD Task1 (PCA)", ex1_td_pca, ex2_td_pca, td_style),
    ("DistTD Task1 (PCA)", ex1_dtd_pca, ex2_dtd_pca, dtd_style),
    # Transfer
    ("TD Task2 (t-SNE)", tex1_td_tsne, tex2_td_tsne, td_style),
    ("DistTD Task2 (t-SNE)", tex1_dtd_tsne, tex2_dtd_tsne, dtd_style),
    ("TD Task2 (PCA)", tex1_td_pca, tex2_td_pca, td_style),
    ("DistTD Task2 (PCA)", tex1_dtd_pca, tex2_dtd_pca, dtd_style),
]

for position, (title, class1_xyz, class2_xyz, style) in enumerate(panels, 1):
  cmaps, order1, order2 = style
  ax = plt.subplot(3, 4, position, projection='3d')
  ax.scatter(class1_xyz[:, 0], class1_xyz[:, 1], class1_xyz[:, 2], c=cmaps[0](order1), **kwargs)
  ax.scatter(class2_xyz[:, 0], class2_xyz[:, 1], class2_xyz[:, 2], c=cmaps[1](order2), **kwargs)
  plt.tick_params(top=False, right=False)
  plt.title(title, fontsize=14)

plt.savefig("birddog_analysis3d.pdf")
%download_file birddog_analysis3d.pdf


## Negative transfer experiment

In [None]:
def sample_test_images_neg(batch_size):
  """Returns the control-experiment test images in freshly shuffled order.

  Args:
    batch_size: Unused; every test image of each class is returned. Kept so
      that existing call sites keep working.

  Returns:
    A pair `(class1_images, class2_images)`: `test_neg_class1` and
    `test_neg_class2`, each reordered by its own random permutation.
  """
  del batch_size  # Unused.
  permutation1 = np.random.permutation(len(test_neg_class1))
  permutation2 = np.random.permutation(len(test_neg_class2))
  return test_neg_class1[permutation1], test_neg_class2[permutation2]


def sample_trajectory_neg(pi, batch_size):
  """Draws one batch of (image, reward) pairs from the control classes.

  Same dynamics as `sample_trial`, applied to `train_neg_class1` and
  `train_neg_class2`.

  Args:
    pi: Probability of the 3.0 reward on class-1 trials. Either a scalar or
      a sequence whose first element is that probability.
    batch_size: Number of trials to draw.

  Returns:
    A pair `(images, rewards)` as returned by `sample_trial`.
  """
  # Accept both the scalar used elsewhere in this notebook and the indexed
  # form (`pi[0]`) this function originally expected.
  reward_prob = float(np.ravel(pi)[0])
  return sample_trial(reward_prob, batch_size, train_neg_class1, train_neg_class2)



In [None]:
# Builds the TF1 graph for the negative-transfer experiment: two independent
# `Cifar10ValueNet` instances (a scalar "classic TD" predictor and a
# multi-output "distributional TD" predictor) fed by the same placeholders and
# trained by one Adam step on the sum of their losses.
tf.reset_default_graph()
# NOTE(review): the equivalent call is commented out in the bird-dog setup
# cell; confirm that the "local" session target exists in this environment.
tf.Session.reset("local")
print('Setting up graph...')

width, height = 32, 32
batch_size = 256

input_ph = tf.placeholder(tf.float32, shape=[None, width, height, 3], name="images")
reward_ph = tf.placeholder(tf.float32, shape=[None], name="rewards")

# Both are passed as `pi` to `sample_trial`, i.e. the probability of the 3.0
# reward on class-1 trials.
pi = 0.5  # Probability of 'happy' dog (outcome 0) for Task 1
mu = 1.0  # Probability of 'happy' dog (outcome 0) for Task 2

# Classic TD
value_net = Cifar10ValueNet(1)
output = value_net(input_ph)
# Flatten the single-output head to one prediction per example.
values = tf.reshape(output.values, [-1])
td_features = output.activations

delta = reward_ph - values
# Unlike the bird-dog setup cell, there is no 0.5 factor here, so `td_loss`
# and `mse` are the same quantity.
td_loss = tf.reduce_mean(tf.square(delta))
mse = tf.reduce_mean(tf.square(reward_ph - values))


# Distributional TD
# NOTE(review): 65 atoms here versus 64 in the bird-dog setup cell — confirm
# the difference is intentional.
n_atoms = 65
dist_net = Cifar10ValueNet(n_atoms)
# `n_atoms` evenly spaced levels strictly inside (0, 1): the endpoints 0 and 1
# produced by linspace are dropped.
taus_np = np.linspace(0., 1., n_atoms+2)[1:-1]
taus = tf.constant(taus_np, dtype=tf.float32)

# NOTE: `output` and `delta` are rebound below; the TD tensors needed later
# were already captured in `values`, `td_features`, `td_loss` and `mse`.
output = dist_net(input_ph)
dist_values = output.values
dtd_features = output.activations

# Broadcast each reward against all `n_atoms` predictions of its example.
delta = reward_ph[:, None] - dist_values
indic = tf.cast(delta <= 0., dtype=tf.float32)
# Per-atom asymmetric weight: tau where delta > 0, (1 - tau) where delta <= 0.
weights = tf.abs(taus[None] - indic)

# Asymmetrically weighted *squared* error; the absolute-error variant used in
# the bird-dog setup cell is kept below for reference.
#dtd_loss = tf.reduce_mean(weights * tf.abs(delta))
dtd_loss = tf.reduce_mean(weights * tf.square(delta))
# Squared error of the mean over atoms; reported only, never optimised.
mse_dtd = tf.reduce_mean(tf.square(reward_ph - tf.reduce_mean(dist_values, -1)))

# A single optimiser step updates both networks via the summed loss.
optimizer = tf.train.AdamOptimizer(0.0005)
update = optimizer.minimize(td_loss + dtd_loss)


In [None]:

# Runs the negative-transfer experiment: for each trial, re-initialise both
# networks, train on the primary classes with reward probability `pi`, then
# continue training on the control classes with reward probability `mu`.
num_updates = 1000  # 2500
num_trials = 10

# Axis 0 indexes the learner: 0 = classic TD, 1 = distributional TD.
losses = np.zeros((2, num_trials, num_updates,))
transfer_losses = np.zeros((2, num_trials, num_updates,))
# Last axis indexes the phase: 0 = original task, 1 = transfer task.
mserrors = np.zeros((2, num_trials, num_updates, 2))

for trial in range(num_trials):
  with tf.Session() as sess:
    # `tf.initialize_all_variables` is deprecated in favour of this call.
    sess.run(tf.global_variables_initializer())

    # Train (on train set) for original task/policy
    for epoch in range(num_updates):
      imgs, rwds = sample_trial(pi, batch_size, train_class1, train_class2)
      _, loss_np, dloss_np, mse_np, dmse_np = sess.run(
          [update, td_loss, dtd_loss, mse, mse_dtd],
          feed_dict={input_ph: imgs, reward_ph: rwds})

      losses[0, trial, epoch] = loss_np
      losses[1, trial, epoch] = dloss_np

      mserrors[0, trial, epoch, 0] = mse_np
      mserrors[1, trial, epoch, 0] = dmse_np

      if epoch % 100 == 0:
        print("%d: Loss %d: %f, %f" % (trial, epoch, loss_np, dloss_np))

    # Train (on train set) for transfer
    for epoch in range(num_updates):
      imgs, rwds = sample_trial(mu, batch_size, train_neg_class1, train_neg_class2)
      _, loss_np, dloss_np, mse_np, dmse_np = sess.run(
          [update, td_loss, dtd_loss, mse, mse_dtd],
          feed_dict={input_ph: imgs, reward_ph: rwds})

      transfer_losses[0, trial, epoch] = loss_np
      transfer_losses[1, trial, epoch] = dloss_np

      mserrors[0, trial, epoch, 1] = mse_np
      mserrors[1, trial, epoch, 1] = dmse_np

      if epoch % 100 == 0:
        print("%d: Loss %d: %f %f" % (trial, epoch, loss_np, dloss_np))


## Visualize results

In [None]:
# Line colours for the figures: index 0 is used for distributional TD and
# index 1 for classic TD.
colours = ['#e74c3c', '#38495C']

In [None]:
# Training-loss curves for both learners, Task 1 on the left and Task 2 on
# the right. Bold lines are the across-trial mean, faint lines are the
# individual trials; everything is smoothed with a moving average.
plt.figure(figsize=(12, 4))
window = 10

# (row in the loss arrays, legend label, line colour) for each learner.
learners = [(0, "TD", colours[1]), (1, "Distributional TD", colours[0])]

plt.subplot(1, 2, 1)
for row, name, colour in learners:
  plt.plot(running_mean(losses[row].mean(0), window), label=name, color=colour, lw=2)

for trial in range(num_trials):
  for row, _, colour in learners:
    plt.plot(running_mean(losses[row, trial], window), label=None, color=colour, lw=1, alpha=0.25)

plt.title("Task 1", fontsize=18)
plt.xlabel("Training Iterations", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.tick_params(top=False, right=False, labelsize=14)
plt.ylim([0., 5.])

plt.subplot(1, 2, 2)
for row, name, colour in learners:
  plt.plot(running_mean(transfer_losses[row].mean(0), window), label=name, color=colour, lw=2)

for trial in range(num_trials):
  for row, _, colour in learners:
    plt.plot(running_mean(transfer_losses[row, trial], window), label=None, color=colour, lw=1, alpha=0.25)

plt.title("Task 2", fontsize=18)
plt.xlabel("Training Iterations", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.tick_params(top=False, right=False, labelsize=14)
plt.legend(loc='best', fontsize=16)

# NOTE(review): same file name as the bird-dog loss figure saved earlier in
# this notebook, so this overwrites it — confirm that is intended.
plt.savefig("birddog_loss.pdf")
%download_file birddog_loss.pdf

In [None]:
# Mean-squared-error curves for both learners, Task 1 on the left and the
# transfer task (different classes) on the right. Bold lines are the
# across-trial mean, faint lines are the individual trials; everything is
# smoothed with a moving average.
plt.figure(figsize=(12, 4))
window = 10
sns.set_style('whitegrid')

# (row in `mserrors`, legend label, line colour) for each learner.
learners = [(0, "TD", colours[1]), (1, "Distributional TD", colours[0])]

plt.subplot(1, 2, 1)
for row, name, colour in learners:
  plt.plot(running_mean(mserrors[row, ..., 0].mean(0), window), label=name, color=colour, lw=2)

for trial in range(num_trials):
  for row, _, colour in learners:
    plt.plot(running_mean(mserrors[row, trial, :, 0], window), label=None, color=colour, lw=1, alpha=0.25)

plt.title("Task 1", fontsize=18)
plt.xlabel("Training Iterations", fontsize=16)
plt.ylabel("MSE", fontsize=16)
plt.tick_params(top=False, right=False, labelsize=14)
plt.ylim([1.8, 3.2])

plt.subplot(1, 2, 2)
for row, name, colour in learners:
  plt.plot(running_mean(mserrors[row, ..., 1].mean(0), window), label=name, color=colour, lw=2)

for trial in range(num_trials):
  for row, _, colour in learners:
    plt.plot(running_mean(mserrors[row, trial, :, 1], window), label=None, color=colour, lw=1, alpha=0.25)

plt.title("Task 2 (Different Classes)", fontsize=18)
plt.xlabel("Training Iterations", fontsize=16)
plt.ylabel("MSE", fontsize=16)
plt.tick_params(top=False, right=False, labelsize=14)
plt.legend(loc='best', fontsize=16)
plt.ylim([0., 1.6])

# NOTE(review): same file name as the bird-dog MSE figure saved earlier in
# this notebook, so this overwrites it — confirm that is intended.
plt.savefig("birddog_mse.pdf")
%download_file birddog_mse.pdf
