Copyright 2021 DeepMind Technologies Limited.


Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Continual learning with pre-trained encoders and ensembles of classifiers

Murray Shanahan

July 2021

A classifier ensemble memory model that mitigates catastrophic forgetting. The model comprises
*   a pre-trained encoder, trained on a different dataset from the target dataset, and
*   a memory with fixed randomised keys and k-nearest neighbour lookup, where
*   each memory location stores the parameters of a trainable local classifier, and
*   the ensemble's output is the mean output of the k selected classifiers weighted according to the distance of their keys from the encoded input

The model is demonstrated on MNIST, where the encoder is pre-trained on Omniglot. The continual learning setting is
*   Task-free. The model doesn't know about task boundaries
*   Online. The dataset is only seen once, and there are no epochs
*   Incremental class learning. Evaluation is always on 10-way classification

This Colab accompanies the paper:

Shanahan, M., Kaplanis, C. & Mitrovic, J. (2021). Encoders and Ensembles for Task-Free Continual Learning. ArXiv preprint: https://arxiv.org/abs/2105.13327

# Preliminaries

In [None]:
# Dependencies that may require pip installation

!pip install dm-haiku
!pip install optax
!pip install dm-tree

In [None]:
# Imports

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

import jax
import jax.numpy as jnp
from jax import jit, grad, random
import optax
import haiku as hk
import tree

from tqdm.notebook import trange

from matplotlib import pyplot as plt

import tensorflow_datasets as tfds

In [None]:
# Experiment parameters


# MNIST config (high data) - for comparison with Lee, et al. (2020)

# Global experiment settings; the functions below read this dict directly
# rather than taking these values as arguments.
config = {
    'enc_size': 512,  # size of latent encoding (also the memory key width)
    'mem_size': 1024,  # number of memory locations (classifiers)
    'k': 32,  # k nearest neighbour lookup parameter
    'vub': 250,  # bound of the scaled-tanh activation - was 100
    'res': 28,  # resolution - 28 for MNIST & Omniglot
    'col_dims': 1,  # 3 for RGB, 1 for greyscale
    'num_classes': 10,  # number of classes
    'pretrain_n_batches': 10000,  # number of batches in pre-training
    'pretrain_dataset': 'omniglot',  # dataset used to pre-train the encoder
    'main_dataset': 'mnist',  # dataset used for continual learning
    'batch_size': 60,  # batch size for training main model
    'learning_rate': 1e-4,  # learning rate for training main model
    'weight_decay': 1e-4,  # optimiser weight decay
    'init_scale': 0.1,  # baseline classifier initialiser variance scaling
    'log_every': 10,  # interval (in batches) for logging accuracies
    'report_every': 500,  # interval for reporting accuracies
    'schedule_type': '5way_split',  # training schedule (defining splits, etc)
    'n_runs': 20,  # number of runs on the schedule
}

In [None]:
# Optimisers


class NaiveOptimiser():
  """Optimiser that discards magnitude of gradients and uses only their sign.

  Exposes `init` and `update` methods so that it can be used in the same way
  as the optax optimiser in this file; the updates it returns are meant to be
  added to the parameters. It keeps no state.
  """

  def __init__(self, learning_rate, weight_decay):
    """Stores the step size and the weight-decay coefficient."""
    self.learning_rate = learning_rate
    self.weight_decay = weight_decay
    self.state = None

  def init(self, params):
    """Returns the optimiser state, which is always None."""
    return None

  def update(self, grads, _, params=None):
    """Computes sign-based updates for a pytree of gradients.

    Args:
      grads: pytree of gradients.
      _: incoming optimiser state (ignored; this optimiser is stateless).
      params: pytree of current parameters with the same structure as `grads`.
        May be omitted only when `weight_decay` is zero.

    Returns:
      A tuple `(updates, state)` where `updates` has the structure of `grads`
      and `state` is always None.

    Raises:
      ValueError: if `params` is None while `weight_decay` is non-zero.
    """
    step_size = self.learning_rate
    weight_decay = self.weight_decay
    # jax.tree_util.tree_map accepts several trees; it replaces the removed
    # jax.tree_multimap and the deprecated top-level jax.tree_map alias.
    tree_map = jax.tree_util.tree_map
    updates = tree_map(lambda g: -jnp.sign(g), grads)
    if params is not None:
      # NOTE(review): the decay term is added *after* the gradient sign has
      # been negated, so once the updates are added to the parameters it
      # slightly increases their magnitude rather than shrinking it. Kept
      # as in the original to preserve results - confirm this is intended.
      updates = tree_map(lambda u, p: u + weight_decay * p, updates, params)
    elif weight_decay:
      raise ValueError('params must be provided when weight_decay is non-zero')
    updates = tree_map(lambda u: step_size * u, updates)
    return (updates, self.state)


def make_encoder_optimiser():
  """Returns the Adam optimiser (learning rate 0.001) for the autoencoder."""
  return optax.adam(learning_rate=0.001)


def make_ensemble_optimiser():
  """Returns a NaiveOptimiser set up from the learning rate and decay in config."""
  return NaiveOptimiser(
      learning_rate=config['learning_rate'],
      weight_decay=config['weight_decay'],
  )

# Datasets

In [None]:
# Datasets: MNIST and Omniglot


def get_dataset(dataset_name, train_or_test, batch_size, filter_labels=None):
  """Returns an endless iterator over shuffled batches of a TFDS dataset.

  Args:
    dataset_name: name passed to `tfds.load`.
    train_or_test: split passed to `tfds.load`.
    batch_size: number of examples per batch.
    filter_labels: optional collection of labels; when given, only examples
      whose 'label' equals one of them are kept.

  Returns:
    An iterator yielding batches (dicts, since `as_supervised=False`).
  """
  data = tfds.load(dataset_name, split=train_or_test, as_supervised=False)
  if filter_labels is not None:
    # Filtering happens before batching, so the predicate sees one example.
    def has_wanted_label(example):
      return tf.reduce_any(tf.equal(example['label'], filter_labels))

    data = data.filter(has_wanted_label)
  batches = data.shuffle(buffer_size=10000).batch(batch_size).repeat()
  return iter(batches)


def get_batch(dataset, dataset_name):
  """Fetches the next batch and converts it to numpy arrays.

  Images are divided by 255 and reshaped to
  [batch, config['res'], config['res'], channels]. For 'omniglot' the images
  are first resized to the configured resolution, only channel 0 is kept,
  and the pixel values are inverted.

  Args:
    dataset: iterator as returned by `get_dataset`.
    dataset_name: dataset name; 'omniglot' triggers the extra preprocessing.

  Returns:
    A tuple `(images, one_hots, labels)` of numpy arrays, where `one_hots`
    has depth config['num_classes'].
  """
  batch = next(dataset)
  images = batch['image']
  labels = batch['label']
  batch_size = images.shape[0]
  res = config['res']
  is_omniglot = dataset_name == 'omniglot'
  if is_omniglot:
    images = tf.image.resize(images, [res, res])
    images = images[:, :, :, 0]
    channels = 1  # only channel 0 was kept above
  else:
    channels = config['col_dims']
  # Use the configured resolution/channels rather than a hard-coded 28x28x1,
  # so the reshape stays consistent with the resize and with the encoder.
  images = tf.reshape(images, [batch_size, res, res, channels]) / 255
  if is_omniglot:
    images = 1 - images  # raw Omniglot characters are white (1) on black (0)
  one_hots = tf.one_hot(labels, config['num_classes'])
  return (images.numpy(), one_hots.numpy(), labels.numpy())

# Plotting

In [None]:
# Plotting accuracies


def smooth(data, degree=2):
  """Smooths a 1-D series with a triangular window for plotting.

  The window has 2 * degree + 1 weights (1, 2, ..., degree + 1, ..., 2, 1).
  The first output is the first data point unchanged; every other output i is
  the weighted mean of the window that starts at index i and extends forward,
  with the series padded at the end by repeating its last value.

  Args:
    data: non-empty 1-D array of values.
    degree: half-width of the triangular window.

  Returns:
    Array with the same length as `data`.
  """
  ramp = list(range(degree))
  kernel = jnp.array(ramp + [degree] + ramp[::-1]) + 1
  width = len(kernel)
  # Pad with one full window of copies of the last point so that every
  # window below is complete.
  padded = jnp.append(data, jnp.array([data[-1]] * width))
  kernel_total = sum(kernel)
  rest = [sum(padded[start:start + width] * kernel) / kernel_total
          for start in range(1, len(padded) - width)]
  return jnp.array([padded[0]] + rest)


def plot_x_accuracies(x_accuracies1, x_accuracies2, x_accuracies3, final=False):
  """Plots mean accuracy curves with error bands and prints final accuracies.

  Each argument is a nested sequence converted to a 3-D array. The curve is
  the mean over the first two axes; the band is the standard deviation over
  the first axis of the mean over the second axis. As built by `main`, the
  axes are [run][test label set][log point].

  Args:
    x_accuracies1: accuracies of the vanilla classifier.
    x_accuracies2: accuracies of the tanh classifier.
    x_accuracies3: accuracies of the ensemble.
    final: accepted for the caller's convenience; not used by this function.
  """
  series = (
      (x_accuracies1, 'Vanilla classifier',
       'Vanilla classifier accuracy: {:.4f} \u00b1 {:.4f}'),
      (x_accuracies2, 'Tanh classifier',
       'Tanh classifier accuracy: {:.4f} \u00b1 {:.4f}'),
      (x_accuracies3, 'Ensemble',
       'Ensemble accuracy: {:.4f} \u00b1 {:.4f}'),
  )

  # Statistics per series; final values are taken before smoothing.
  curves = []
  finals = []
  for (raw, _, _) in series:
    values = jnp.array(raw, dtype=jnp.float64)
    means = jnp.mean(values, axis=(0, 1))
    stds = jnp.std(jnp.mean(values, axis=1), axis=0)
    finals.append((means[-1], stds[-1]))
    curves.append((smooth(means), smooth(stds)))

  plt.figure(figsize=(8, 4))
  for (means, stds) in curves:
    xs = range(len(means))
    lines = plt.plot(xs, means)
    # Shade one standard deviation either side, in the colour of the line.
    plt.fill_between(xs, means - stds, means + stds,
                     facecolor=lines[-1].get_color(), alpha=0.2)

  plt.ylim([0.0, 1.0])
  plt.xlabel('Batch x{}'.format(config['log_every']))
  plt.ylabel('Accuracy')
  plt.legend([legend for (_, legend, _) in series])
  plt.show()

  # Report accuracies
  for ((_, _, report_format), (final_mean, final_std)) in zip(series, finals):
    print(report_format.format(final_mean, final_std))
  print()

# Models and losses

In [None]:
# Autoencoder (for pretraining)


class Autoencoder(hk.Module):
  """Convolutional autoencoder with a sampled latent code (for pre-training)."""

  def encode(self, image):
    """Encodes images into the mean and standard deviation of the latent code.

    Args:
      image: array reshaped to [-1, res, res, col_dims] before encoding.

    Returns:
      A tuple `(enc_mean, enc_sd)`, each with config['enc_size'] features;
      the mean is passed through tanh and the standard deviation through relu.
    """
    res, channels = config['res'], config['col_dims']
    latent_size = config['enc_size']
    conv_stack = hk.Sequential([
        hk.Conv2D(output_channels=16, kernel_shape=4, name='enc1'),
        jax.nn.relu,
        hk.Conv2D(output_channels=16, kernel_shape=4, name='enc2'),
        jax.nn.relu,
        hk.Flatten(),
    ])
    mean_head = hk.Sequential([
        hk.Linear(128, name='enc3'),
        jax.nn.relu,
        hk.Linear(latent_size, name='enc4'),
    ])
    sd_head = hk.Sequential([
        hk.Linear(128, name='enc5'),
        jax.nn.relu,
        hk.Linear(latent_size, name='enc6'),
    ])
    features = conv_stack(image.reshape([-1, res, res, channels]))
    return (jnp.tanh(mean_head(features)), jax.nn.relu(sd_head(features)))

  def decode(self, latent):
    """Decodes latent codes into images of shape [-1, res, res, col_dims]."""
    res, channels = config['res'], config['col_dims']
    layers = [
        hk.Linear(128, name='dec1'),
        jax.nn.relu,
        hk.Linear(res * res * 16, name='dec2'),
        jax.nn.relu,
        hk.Reshape((res, res, 16)),
        hk.Conv2DTranspose(output_channels=16, kernel_shape=4, name='dec3'),
        jax.nn.relu,
        hk.Conv2DTranspose(output_channels=channels, kernel_shape=4,
                           name='dec4'),
        jax.nn.sigmoid,
    ]
    decoded = hk.Sequential(layers)(latent)
    return decoded.reshape([-1, res, res, channels])

  def forward(self, rng, image):
    """Encodes, samples a latent code and decodes it.

    Args:
      rng: JAX PRNG key used to draw the sampling noise.
      image: batch of input images.

    Returns:
      Dict with 'enc_mean', 'enc_sd' and the reconstruction 'image_dec'.
    """
    (enc_mean, enc_sd) = self.encode(image)
    # Sample: latent = mean + sd * standard normal noise
    (noise_key, _) = random.split(rng)
    noise = random.normal(noise_key, jnp.shape(enc_mean))
    return {
        'enc_mean': enc_mean,
        'enc_sd': enc_sd,
        'image_dec': self.decode(enc_mean + enc_sd * noise),
    }


def encoder(rng, image):
  """Builds an Autoencoder and returns its forward outputs (for hk.transform)."""
  return Autoencoder().forward(rng, image)

In [None]:
# Autoencoder loss (for pretraining)


def kl_divergence(mean, sd):
  """Element-wise KL divergence of N(mean, sd^2) from a standard normal."""
  variance = sd**2
  return -0.5 * (1.0 + jnp.log(variance) - mean**2 - variance)


def autoencoder_losses(enc_params, rng, images):
  """Computes the autoencoder losses for a batch of images.

  Args:
    enc_params: autoencoder parameters.
    rng: JAX PRNG key; split into the Haiku key and the sampling key.
    images: batch of images.

  Returns:
    Dict with 'tot_loss' (reconstruction + 0.001 * KL), 'decoder_loss' (mean
    squared reconstruction error) and 'kl_loss' (mean KL divergence).
  """
  kl_weight = 0.001  # weighting of KL term
  (sample_key, apply_key) = random.split(rng)
  outputs = hk.transform(encoder).apply(
      enc_params, apply_key, sample_key, images)
  # Decoder reconstruction loss
  reconstruction_loss = jnp.mean((images - outputs['image_dec'])**2)
  # KL loss; the epsilon keeps the log finite when the relu output sd is 0
  kl_terms = kl_divergence(outputs['enc_mean'], outputs['enc_sd'] + 1e-10)
  mean_kl = jnp.mean(kl_terms)
  return {
      'tot_loss': reconstruction_loss + kl_weight * mean_kl,
      'decoder_loss': reconstruction_loss,
      'kl_loss': mean_kl,
  }


def autoencoder_loss(enc_params, rng, images):
  """Returns only the total autoencoder loss (the scalar that is differentiated)."""
  return autoencoder_losses(enc_params, rng, images)['tot_loss']

In [None]:
# Update autoencoder parameters

@jit
def update_autoencoder(enc_params, rng, opt_state, images):
  """Performs one optimiser step on the autoencoder parameters.

  Args:
    enc_params: current autoencoder parameters.
    rng: JAX PRNG key passed to the loss.
    opt_state: current state of the encoder optimiser.
    images: batch of training images.

  Returns:
    A tuple `(new_params, new_opt_state)`.
  """
  loss_grads = grad(autoencoder_loss)(enc_params, rng, images)
  (updates, new_opt_state) = make_encoder_optimiser().update(
      loss_grads, opt_state)
  return optax.apply_updates(enc_params, updates), new_opt_state

In [None]:
# Ensemble memory (main model)


def activation(values):
  """Scaled tanh: squashes values smoothly into (-vub, vub), vub from config."""
  bound = config['vub']
  return jnp.tanh(values / bound) * bound


def l2_normalize(x, axis=None, epsilon=1e-12):
  """Scales `x` to unit L2 norm along `axis`.

  The squared norm is clamped from below by `epsilon` before the reciprocal
  square root is taken, so all-zero inputs map to zeros instead of NaNs.
  """
  squared_norm = jnp.sum(jnp.square(x), axis=axis, keepdims=True)
  clamped = jnp.maximum(squared_norm, epsilon)
  return x * jax.lax.rsqrt(clamped)


class Memory(hk.Module):
  """Memory module: keyed locations, each holding a linear classifier.

  Parameters (shapes taken from `config`):
    mem_keys: [mem_size, enc_size], created as zeros here; `initialise_model`
      replaces them with random normal keys after initialisation.
    mem_weights: [mem_size, enc_size, num_classes], variance-scaling init.
    mem_biases: [mem_size, 1, num_classes], created as zeros.
  """

  def __init__(self, name=None):
    super().__init__(name)
    self.keys = hk.get_parameter('mem_keys', [config['mem_size'],
                                              config['enc_size']],
                                 init=hk.initializers.Constant(0))
    self.weights = hk.get_parameter('mem_weights', [config['mem_size'],
                                                    config['enc_size'],
                                                    config['num_classes']],
                                    init=hk.initializers.VarianceScaling())
    self.biases = hk.get_parameter('mem_biases', [config['mem_size'], 1,
                                                  config['num_classes']],
                                   init=hk.initializers.Constant(0))

  def lookup(self, enc):
    """k-nearest neighbour lookup in ensemble memory.

    Args:
      enc: encodings of shape [batch, enc_size].

    Returns:
      A tuple `(mean_key, mean_value, k_sims, k_keys, k_values, idx)`:
        mean_key: [batch, enc_size], mean of the selected (unnormalised) keys.
        mean_value: [batch, num_classes], classifier outputs averaged with
          the cosine similarities as weights (no gradient flows through the
          weights).
        k_sims: [batch, k], cosine similarities of the selected keys.
        k_keys: [batch, k, enc_size], the selected keys.
        k_values: [batch, k, num_classes], outputs of the selected
          classifiers after the scaled-tanh activation.
        idx: [batch, k], indices of the selected memory locations.
    """
    enc = l2_normalize(enc, axis=1)
    keys = l2_normalize(self.keys, axis=1)
    sims = jnp.matmul(enc, jnp.transpose(keys))   # cosine similarities
    (k_sims, idx) = jax.lax.top_k(sims, config['k'])  # k nearest neighbours
    # Keys
    k_keys = jnp.take(self.keys, idx, axis=0)
    mean_key = jnp.mean(k_keys, axis=1)
    # Values: apply each selected classifier to the (normalised) encoding
    k_encs = jnp.expand_dims(enc, axis=(1, 2))  # [batch, 1, 1, enc_size]
    k_encs = jnp.tile(k_encs, (1, config['k'], 1, 1))
    k_weights = jnp.take(self.weights, idx, axis=0)
    k_biases = jnp.take(self.biases, idx, axis=0)
    k_values = jnp.matmul(k_encs, k_weights) + k_biases
    # Drop only the singleton row axis left by the matmul. An axis-less
    # squeeze would also remove the batch, k or class axis whenever one of
    # them has size 1 and break the axis=1 reductions below.
    k_values = jnp.squeeze(k_values, axis=2)
    k_values = activation(k_values)
    # Mean of values weighted by key similarity
    k_sims2 = jax.lax.stop_gradient(jnp.expand_dims(k_sims, axis=2))
    mean_value = jnp.sum(k_values * k_sims2, axis=1) / jnp.sum(k_sims2, axis=1)
    return (mean_key, mean_value, k_sims, k_keys, k_values, idx)


class EnsembleModel(hk.Module):
  """Ensemble memory together with two single-layer baseline classifiers."""

  def __init__(self, name=None):
    super().__init__(name)
    self.memory = Memory()

  def enc_to_class_vanilla(self, enc_image):
    """Baseline: one linear layer followed by log-softmax."""
    linear = hk.Linear(
        config['num_classes'],
        w_init=hk.initializers.VarianceScaling(scale=config['init_scale']),
        name='classifier1')
    classifier = hk.Sequential([linear, jax.nn.log_softmax])
    return classifier(enc_image)  # predicted class

  def enc_to_class_tanh(self, enc_image):
    """Baseline: one linear layer followed by the scaled-tanh activation."""
    linear = hk.Linear(
        config['num_classes'],
        w_init=hk.initializers.VarianceScaling(scale=config['init_scale']),
        name='classifier2')
    classifier = hk.Sequential([linear])
    return activation(classifier(enc_image))  # predicted class

  def forward(self, enc_image):
    """Runs the memory lookup and both baseline classifiers.

    Args:
      enc_image: batch of encoded images.

    Returns:
      Dict holding the two baseline outputs and every result of the memory
      lookup ('k_sims', 'k_keys', 'k_values', 'mean_key', 'mean_value',
      'idx').
    """
    (mean_key, mean_value,
     k_sims, k_keys, k_values, idx) = self.memory.lookup(enc_image)
    return {
        'classifier_out_vanilla': self.enc_to_class_vanilla(enc_image),
        'classifier_out_tanh': self.enc_to_class_tanh(enc_image),
        'k_sims': k_sims,
        'k_keys': k_keys,
        'k_values': k_values,
        'mean_key': mean_key,
        'mean_value': mean_value,
        'idx': idx,
    }


def model(enc_image):
  """Builds an EnsembleModel and returns its forward outputs (for hk.transform)."""
  return EnsembleModel().forward(enc_image)

In [None]:
# Loss and accuracy for ensemble memory (main model)


def model_loss(params, rng, enc_images, one_hots, labels):
  """Sum of the losses of the two baseline classifiers and the ensemble.

  Each loss is the negated batch mean of the model output at the target
  class (selected with the one-hot targets).

  Args:
    params: model parameters.
    rng: JAX PRNG key passed to the transformed model.
    enc_images: batch of encoded images.
    one_hots: one-hot targets.
    labels: integer targets (not used by this function).

  Returns:
    Scalar total loss.
  """
  outputs = hk.transform(model).apply(params, rng, enc_images)

  def target_loss(scores):
    return -jnp.mean(jnp.sum(scores * one_hots, axis=1))

  # Classifier losses followed by the memory loss
  return (target_loss(outputs['classifier_out_vanilla'])
          + target_loss(outputs['classifier_out_tanh'])
          + target_loss(outputs['mean_value']))


@jit
def accuracy(params, rng, enc_images, labels):
  """Accuracies for each type of model.

  Returns:
    A tuple `(vanilla, tanh, ensemble)` with the fraction of the batch for
    which the argmax of the respective output equals the label.
  """
  outputs = hk.transform(model).apply(params, rng, enc_images)

  def fraction_correct(scores):
    return jnp.mean(jnp.argmax(scores, axis=1) == labels)

  return (fraction_correct(outputs['classifier_out_vanilla']),
          fraction_correct(outputs['classifier_out_tanh']),
          fraction_correct(outputs['mean_value']))

In [None]:
# Update ensemble memory parameters

@jit
def update_model(params, rng, opt_state,
                 enc_images, one_hots, labels):
  """Performs one optimiser step on the ensemble model parameters.

  Returns:
    A tuple `(new_params, new_opt_state)`.
  """
  loss_grads = grad(model_loss)(params, rng, enc_images, one_hots, labels)
  (updates, new_opt_state) = make_ensemble_optimiser().update(
      loss_grads, opt_state, params)
  return optax.apply_updates(params, updates), new_opt_state

# Training

In [None]:
def get_encoder():
  """Returns the Haiku-transformed encoder function."""
  return hk.transform(encoder)


def initialise_encoder(rng):
  """Initialise internal encoder for pre-training.

  A small batch from the pre-training dataset is fetched only to give the
  initialiser example inputs.

  Args:
    rng: JAX PRNG key.

  Returns:
    Freshly initialised autoencoder parameters.
  """
  dataset_name = config['pretrain_dataset']
  dummy_set = get_dataset(dataset_name, 'train', 24)
  (images, _, _) = get_batch(dummy_set, dataset_name)
  (init_key, remainder) = random.split(rng)
  (sample_key, _) = random.split(remainder)
  return get_encoder().init(init_key, sample_key, images)


@jit
def apply_encoder(enc_params, rng, images):
  """Encodes images and returns the latent means ('enc_mean')."""
  (sample_key, apply_key) = random.split(rng)
  outputs = get_encoder().apply(enc_params, apply_key, sample_key, images)
  return outputs['enc_mean']

In [None]:
# Autoencoder training


def pretrain_encoder(rng, n_batches, filter_labels=None):
  """Pre-trains autoencoders until one reconstructs well enough.

  An autoencoder is initialised and trained for `n_batches` batches, then
  evaluated on one fixed test batch. If its reconstruction loss is not below
  0.025, a new autoencoder is initialised and trained from scratch; this
  repeats until one succeeds.

  Args:
    rng: JAX PRNG key.
    n_batches: number of training batches per attempt.
    filter_labels: optional labels used to filter the pre-training data.

  Returns:
    Parameters of the first autoencoder that meets the loss threshold.
  """

  print('Encoder pre-training on {}'.format(config['pretrain_dataset']))
  print()
  # Get train and test data
  train_batch_size = 48
  test_batch_size = 256
  train_set = get_dataset(config['pretrain_dataset'], 'train',
                          train_batch_size, filter_labels=filter_labels)
  test_set = get_dataset(config['pretrain_dataset'], 'test',
                         test_batch_size, filter_labels=filter_labels)
  test_batch = get_batch(test_set, config['pretrain_dataset'])
  (test_images, _, _) = test_batch
  # Train encoders until a good enough one is found
  success = False
  loss_threshold = 0.025
  while not success:
    # Initialise parameters
    (rng2, rng) = random.split(rng)
    enc_params = initialise_encoder(rng2)
    # Initialise optimiser
    opt = make_encoder_optimiser()
    opt_state = opt.init(enc_params)
    # Training
    for _ in trange(n_batches, desc='Training'):
      batch = get_batch(train_set, config['pretrain_dataset'])
      (images, _, _) = batch
      (rng2, rng) = random.split(rng)
      (enc_params, opt_state) = update_autoencoder(enc_params, rng2,
                                                   opt_state, images)
    (rng2, rng) = random.split(rng)
    losses = autoencoder_losses(enc_params, rng2, test_images)
    # Report the batch count directly; the loop variable is unbound when
    # n_batches is 0.
    print('Batch {}'.format(n_batches))
    print('Reconstruction loss {:.8f}'.format(losses['decoder_loss']))
    print('KL loss {:.4f}'.format(losses['kl_loss']))
    print()
    success = losses['decoder_loss'] < loss_threshold
    if not success:
      print('Reconstruction loss too high - retraining')
      print()

  return enc_params

In [None]:
# Model testing with recording of accuracies


def test_model(model_params, enc_params, test_labels, test_batches,
               accuracies, batch_number):
  """Evaluates the model on each test batch and records the accuracies.

  Args:
    model_params: ensemble model parameters.
    enc_params: encoder parameters.
    test_labels: list of label sets; only its length is used here.
    test_batches: list of test batches, one per entry of `test_labels`.
    accuracies: dict of per-label-set lists; appended to in place.
    batch_number: number of batches trained so far (not used).

  Returns:
    The same `accuracies` dict, with one new entry per label set and model.
  """
  log_names = ('accuracies_vanilla', 'accuracies_tanh', 'accuracies_ensemble')
  rng = random.PRNGKey(42)  # fixed key: evaluation uses the same keys each call
  for (position, _) in enumerate(test_labels):
    (images, _, labels) = test_batches[position]
    # Encode images
    (enc_key, rng) = random.split(rng)
    enc_images = apply_encoder(enc_params, enc_key, images)
    # Get accuracies
    (acc_key, rng) = random.split(rng)
    scores = accuracy(model_params, acc_key, enc_images, labels)
    for (log_name, score) in zip(log_names, scores):
      accuracies[log_name][position].append(score)
  return accuracies

In [None]:
# Model training with pre-trained encoder


def initialise_model(rng, enc_params):
  """Initialises the ensemble model and gives it random memory keys.

  The model is initialised on a small encoded batch from the main dataset.
  The memory keys created by the initialiser are then replaced by keys drawn
  from a standard normal distribution.

  Args:
    rng: JAX PRNG key.
    enc_params: parameters of the pre-trained encoder.

  Returns:
    Model parameters with the new memory keys.
  """
  memory_module = 'ensemble_model/~/memory'
  # Get dummy batch
  dataset_name = config['main_dataset']
  dummy_set = get_dataset(dataset_name, 'train', 24)
  (images, _, _) = get_batch(dummy_set, dataset_name)
  # Encode images
  (enc_key, rng) = random.split(rng)
  enc_images = apply_encoder(enc_params, enc_key, images)
  # Initialise the model
  (init_key, rng) = random.split(rng)
  model_params = hk.transform(model).init(init_key, enc_images)
  # Draw fixed random memory keys
  (keys_key, _) = random.split(rng)
  keys = jax.random.normal(keys_key, [config['mem_size'], config['enc_size']])

  def is_memory_key(module_name, param_name, _):
    return module_name == memory_module and param_name == 'mem_keys'

  # Discard the initialiser's keys and merge the new ones with the rest
  (_, other_params) = hk.data_structures.partition(is_memory_key, model_params)
  return hk.data_structures.merge({memory_module: {'mem_keys': keys}},
                                  other_params)


def train_model(model_params, enc_params, rng,
                label_set, test_labels, test_batch, accuracies, run,
                n_batches=0, tot_batches=0):
  """Trains the model for `n_batches` batches drawn from one label set.

  Args:
    model_params: ensemble model parameters.
    enc_params: encoder parameters.
    rng: JAX PRNG key.
    label_set: labels used to filter the training data.
    test_labels: label sets used for evaluation.
    test_batch: test batches passed on to `test_model`.
    accuracies: accuracy log, extended every config['log_every'] batches.
    run: index of the current run (not used).
    n_batches: number of batches to train on.
    tot_batches: number of batches trained before this call.

  Returns:
    A tuple `(model_params, accuracies, tot_batches)`.
  """
  dataset_name = config['main_dataset']
  train_set = get_dataset(dataset_name, 'train', config['batch_size'],
                          filter_labels=label_set)
  for _ in range(n_batches):
    # Re-initialise optimiser
    opt_state = make_ensemble_optimiser().init(model_params)
    tot_batches += 1
    (images, one_hots, labels) = get_batch(train_set, dataset_name)
    # Encode images
    (enc_key, rng) = random.split(rng)
    enc_images = apply_encoder(enc_params, enc_key, images)
    # Update the model
    (step_key, rng) = random.split(rng)
    (model_params, opt_state) = update_model(
        model_params, step_key, opt_state, enc_images, one_hots, labels)
    # Log accuracies
    if tot_batches % config['log_every'] == 0:
      accuracies = test_model(model_params, enc_params, test_labels,
                              test_batch, accuracies, tot_batches)
  return (model_params, accuracies, tot_batches)

In [None]:
def train_with_schedule(model_params, enc_params, rng,
                        schedule, test_labels, test_batch, run):
  """Train the model with a given schedule of label sets (tasks).

  Accuracies are recorded once before training and then at the logging
  interval while the episodes of the schedule are trained in order.

  Returns:
    A tuple `(model_params, accuracies)`.
  """
  n_label_sets = len(test_labels)
  log_names = ('accuracies_vanilla', 'accuracies_tanh', 'accuracies_ensemble')
  accuracies = {log_name: [[] for _ in range(n_label_sets)]
                for log_name in log_names}
  # Initial accuracy
  tot_batches = 0
  accuracies = test_model(model_params, enc_params, test_labels, test_batch,
                          accuracies, tot_batches)
  # Go through the schedule
  for ep_no in trange(len(schedule), desc='Training Schedule'):
    episode = schedule[ep_no]
    (episode_key, rng) = random.split(rng)
    (model_params, accuracies, tot_batches) = train_model(
        model_params, enc_params, episode_key,
        episode['label_set'], test_labels, test_batch, accuracies, run,
        n_batches=episode['n_batches'], tot_batches=tot_batches)

  return (model_params, accuracies)

# Schedules

In [None]:
def gaussian(peak, width, x):
  """Unnormalised Gaussian bump centred on `peak`; equals 1 at x == peak."""
  exponent = (x - peak)**2 / (2 * width**2)
  return jnp.exp(-exponent)


def gaussian_schedule(rng):
  """Returns a schedule where one task blends smoothly into the next.

  The schedule spans 1000 batches in episodes of 5 batches. Each class has a
  Gaussian-shaped inclusion probability (width 50) whose peak is placed at an
  evenly spaced position, with classes assigned to peaks in random order. For
  every episode each class is included with its probability at the start of
  the episode; sampling is repeated until at least one class is included.

  Args:
    rng: JAX PRNG key.

  Returns:
    List of episodes, each a dict with 'label_set' (list of ints) and
    'n_batches'.
  """

  schedule_length = 1000  # schedule length in batches
  episode_length = 5  # episode length in batches

  # Each class label appears according to a Gaussian probability distribution
  # with peaks spread evenly over the schedule
  peak_every = schedule_length // config['num_classes']
  width = 50  # width of Gaussian
  peaks = range(peak_every // 2, schedule_length, peak_every)

  # Split before use: a key that has been consumed by a sampler should not
  # also be the parent of later keys.
  (perm_key, rng) = random.split(rng)
  labels = jnp.array(list(range(config['num_classes'])))
  labels = random.permutation(perm_key, labels)  # labels in random order

  schedule = []
  for ep_no in range(schedule_length // episode_length):
    position = ep_no * episode_length
    lbls = []
    while not lbls:  # make sure lbls isn't empty
      for (j, peak) in enumerate(peaks):
        # Include label j with the probability given by its Gaussian
        p = gaussian(peak, width, position)
        (draw_key, rng) = random.split(rng)
        if random.bernoulli(draw_key, p=p):
          lbls.append(int(labels[j]))

    schedule.append({'label_set': lbls, 'n_batches': episode_length})

  return schedule

In [None]:
def get_schedule(name, rng=None):
  """Builds a named training schedule and the label sets used for testing.

  Args:
    name: one of '1way_split', '2way_split', '5way_split', '10way_split' or
      'gaussian_schedule'.
    rng: JAX PRNG key; required by every schedule except '1way_split'.

  Returns:
    A tuple `(schedule, test_labels)`. `schedule` is a list of episodes, each
    a dict with 'label_set' and 'n_batches'; `test_labels` is a list with one
    single-label list per class 0..9.

  Raises:
    ValueError: if `name` is not a known schedule.
  """

  # Full set of labels
  test_labels = [[x] for x in range(10)]

  # Random n-way split schedules (MNIST or CIFAR-10): the ten labels are
  # shuffled and divided into n equal tasks sharing 1000 batches.
  n_way_splits = {'2way_split': 2, '5way_split': 5, '10way_split': 10}

  if name == '1way_split':
    # 1-way split schedule (multi-task setting) (MNIST or CIFAR-10)
    schedule = [{'label_set': list(range(10)),
                 'n_batches': 1000}]

  elif name in n_way_splits:
    n_splits = n_way_splits[name]
    lbls = jnp.array(list(range(10)))
    lbls = random.permutation(rng, lbls)
    lbls = jnp.reshape(lbls, (n_splits, 10 // n_splits))
    schedule = [{'label_set': lbl.tolist(), 'n_batches': 1000 // n_splits}
                for lbl in lbls]

  elif name == 'gaussian_schedule':
    # Gaussian schedule
    schedule = gaussian_schedule(rng)

  else:
    # Fail with a clear error instead of falling through to the return with
    # `schedule` unbound.
    raise ValueError('Error: no such schedule: {!r}'.format(name))

  return (schedule, test_labels)

# Scripts

In [None]:
# Training scripts


def get_test_batches(test_labels):
  """Split test batch up into label-wise batches according to 'test_labels'.

  For each label set, a filtered test dataset is created and its first batch
  (of at most 10000 // len(test_labels) examples) is taken.

  Returns:
    List of batches in the order of `test_labels`.
  """
  dataset_name = config['main_dataset']
  per_set_size = 10000 // len(test_labels)

  def first_batch(labels):
    filtered = get_dataset(dataset_name, 'test', per_set_size,
                           filter_labels=labels)
    return get_batch(filtered, dataset_name)

  return [first_batch(labels) for labels in test_labels]


def main():
  """Main script - multiple runs.

  Every run pre-trains a new encoder, draws a schedule, initialises and
  trains the ensemble model, and then plots the accuracies of all runs so
  far. A final plot is produced after the last run.
  """

  print('STARTING EXPERIMENT')
  print()
  rng = random.PRNGKey(78)
  n_runs = config['n_runs']
  schedule_name = config['schedule_type']
  # Per-run accuracy logs
  vanilla_runs = []  # vanilla classifier accuracies
  tanh_runs = []  # tanh classifier accuracies
  ensemble_runs = []  # ensemble accuracies
  for run in range(n_runs):
    print('RUN {} of {}'.format(run+1, n_runs))
    print()
    # Encoder pretraining - new encoder every run
    (pretrain_key, rng) = random.split(rng)
    enc_params = pretrain_encoder(pretrain_key,
                                  n_batches=config['pretrain_n_batches'])
    # Get a schedule
    (schedule_key, rng) = random.split(rng)
    (schedule, test_labels) = get_schedule(schedule_name, schedule_key)
    # Get batches for testing
    test_batches = get_test_batches(test_labels)
    # Initialise the model
    (init_key, rng) = random.split(rng)
    model_params = initialise_model(init_key, enc_params)
    # Carry out schedule
    print('Ensemble training on {}'.format(config['main_dataset']))
    print()
    (train_key, rng) = random.split(rng)
    (model_params, accuracies) = train_with_schedule(
        model_params, enc_params, train_key, schedule,
        test_labels, test_batches, run)
    # Record the results
    vanilla_runs.append(accuracies['accuracies_vanilla'])
    tanh_runs.append(accuracies['accuracies_tanh'])
    ensemble_runs.append(accuracies['accuracies_ensemble'])
    plot_x_accuracies(vanilla_runs, tanh_runs, ensemble_runs)
  print('FINAL PLOT')
  print()
  plot_x_accuracies(vanilla_runs, tanh_runs, ensemble_runs, final=True)

In [None]:
# Main script

main()