#### Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.

#### Licensed under the Apache License, Version 2.0 (the "License");

#### Full license text

In [0]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Epistemic neural network demo

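Demo Colab notebook for the Epistemic Neural Network (ENN) library. It trains an ensemble ENN with a prior network on a regression problem with 1D output and on a 2-class classification problem, then plots the results.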


## Install ENN

In [0]:
#@title Installing enn
# Clone the ENN sources (skipped when a previous run already cloned them, so
# re-running this cell does not fail on an existing `enn/` directory) and
# install the package from the local checkout. `%pip` installs into the
# environment of the running kernel, which `!pip` does not guarantee.
![ -d enn ] || git clone https://github.com/deepmind/enn.git
%pip install -q enn/

## Imports

In [0]:
#@title General imports

import warnings

warnings.filterwarnings('ignore')

In [None]:
#@title Development imports
from acme.utils.loggers.terminal import TerminalLogger
import dataclasses
import optax

In [0]:
#@title ENN imports
from enn import losses
from enn import networks
from enn import supervised
from enn import data_noise
from enn.supervised import classification_data, regression_data

## Parameters (set using a dataclass)

In [0]:
@dataclasses.dataclass
class Config:
  """Hyperparameters shared by the regression and classification demos.

  The experiment cells read these values through the module-level `FLAGS`
  instance defined right after this class.
  """
  # Value passed to `experiment.train`.
  num_batch: int = 1000
  # Index dimension; the experiment cells pass it as `num_ensemble`.
  index_dim: int = 10
  # Passed as `num_index_samples` to `losses.average_single_index_loss`.
  num_index_samples: int = 10
  # Seed handed to the ENN constructor, the target-noise function and the
  # experiment.
  seed: int = 0
  # Passed as `prior_scale` to `networks.MLPEnsembleMatchedPrior`.
  prior_scale: float = 5.0
  # Learning rate for `optax.adam`.
  learning_rate: float = 0.001
  # Noise scale for `data_noise.GaussianTargetNoise` (regression demo only).
  noise_std: float = 0.1


FLAGS = Config()

## Regression problem with 1D output



In [0]:
#@title Create the regression experiment

# Regression data: an iterator of batches exposing an `.x` field. Note that
# `next(dataset)` below draws one batch just to give the network an example
# input, so that batch is consumed before training starts.
dataset = regression_data.make_dataset()

# Terminal logger for training metrics.
logger = TerminalLogger('supervised_regression')

# Ensemble of `FLAGS.index_dim` MLPs (hidden widths 50, 50; a single output)
# with a matched prior network scaled by `FLAGS.prior_scale`.
enn = networks.MLPEnsembleMatchedPrior(
    output_sizes=[50, 50, 1],
    dummy_input=next(dataset).x,
    num_ensemble=FLAGS.index_dim,
    prior_scale=FLAGS.prior_scale,
    seed=FLAGS.seed,
)

# L2 loss on noise-perturbed data: `GaussianTargetNoise` presumably perturbs
# the regression *targets* (not the network outputs) with scale
# `FLAGS.noise_std`. The single-index loss is then averaged over
# `FLAGS.num_index_samples` sampled indices.
target_noise = data_noise.GaussianTargetNoise(enn, FLAGS.noise_std, FLAGS.seed)
noisy_l2_loss = losses.add_data_noise(losses.L2Loss(), target_noise)
loss_fn = losses.average_single_index_loss(
    single_loss=noisy_l2_loss,
    num_index_samples=FLAGS.num_index_samples,
)

# Adam with the configured learning rate.
optimizer = optax.adam(FLAGS.learning_rate)

# Bundle network, loss, optimizer, data and logger into one trainable object.
experiment = supervised.Experiment(
    enn, loss_fn, optimizer, dataset, FLAGS.seed, logger=logger)

In [0]:
# Train the regression experiment built in the previous cell.
# `FLAGS.num_batch` sets the training length (presumably the number of batches
# / update steps -- confirm against `supervised.Experiment.train`).
# NOTE: `experiment` is rebound by the classification section below, so
# re-running this cell after that section trains the classifier instead.
experiment.train(FLAGS.num_batch)

In [0]:
#@title Plot the output
# Build the regression plot for the trained experiment and draw it; the
# return value of `draw()` is bound to `_` so the cell prints no repr.
regression_plot = regression_data.make_plot(experiment)
_ = regression_plot.draw()

## Classification problem with 2 labels



In [0]:
#@title Create the classification experiment

# Classification data: an iterator of batches exposing an `.x` field. As in
# the regression cell, `next(dataset)` below consumes one batch to give the
# network an example input.
dataset = classification_data.make_dataset()
num_classes = 2

# Terminal logger for training metrics, labelled for this experiment so its
# output is not confused with the regression run.
logger = TerminalLogger('supervised_classification')

# Ensemble of `FLAGS.index_dim` MLPs (hidden widths 50, 50) whose final layer
# has `num_classes` outputs, with a matched prior network scaled by
# `FLAGS.prior_scale`.
enn = networks.MLPEnsembleMatchedPrior(
    output_sizes=[50, 50, num_classes],
    dummy_input=next(dataset).x,
    num_ensemble=FLAGS.index_dim,
    prior_scale=FLAGS.prior_scale,
    seed=FLAGS.seed,
)

# Cross-entropy loss, averaged over `FLAGS.num_index_samples` sampled indices.
# Unlike the regression cell, no data noise is added here.
loss_fn = losses.average_single_index_loss(
    single_loss=losses.XentLoss(num_classes),
    num_index_samples=FLAGS.num_index_samples,
)

# Adam with the configured learning rate.
optimizer = optax.adam(FLAGS.learning_rate)

# Bundle network, loss, optimizer, data and logger into one trainable object.
# This rebinds `experiment`, replacing the regression experiment.
experiment = supervised.Experiment(
    enn, loss_fn, optimizer, dataset, FLAGS.seed, logger=logger)

In [0]:
# Train the classification experiment built in the previous cell.
# `FLAGS.num_batch` sets the training length (presumably the number of batches
# / update steps -- confirm against `supervised.Experiment.train`).
experiment.train(FLAGS.num_batch)

In [0]:
#@title Plot the output
# Plot the trained classification experiment using the Colab plotting helper
# from `classification_data`.
classification_data.colab_plots(experiment)