Copyright 2021 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at

[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)

Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

# RL Unplugged: Offline DQN - Bsuite
## Guide to training an Acme DQN agent on Bsuite data.
# <a href="https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/rl_unplugged/bsuite.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>





In [None]:
# @title Installation
# Install Acme together with its Reverb and TensorFlow extras, plus Sonnet.
!pip install dm-acme
!pip install dm-acme[reverb]
!pip install dm-acme[tf]
!pip install dm-sonnet
# Additional packages installed alongside Acme (dopamine-rl is pinned to 3.1.2).
!pip install dopamine-rl==3.1.2
!pip install atari-py
!pip install dm_env
# Clone the deepmind-research repository and make it the working directory.
!git clone https://github.com/deepmind/deepmind-research.git
%cd deepmind-research

# Install bsuite from a fresh clone of its repository (quiet pip output).
!git clone https://github.com/deepmind/bsuite.git
!pip install -q bsuite/

In [None]:
# @title Imports
import copy
import functools
from typing import Dict, Tuple


import acme
from acme.agents.tf import actors
from acme.agents.tf.dqn import learning as dqn
from acme.tf import utils as acme_utils
from acme.utils import loggers

import sonnet as snt
import tensorflow as tf

import numpy as np
import tree
import dm_env
import reverb
from acme.wrappers import base as wrapper_base
from acme.wrappers import single_precision

import bsuite

In [None]:
# @title Data Loading Utilities
def _parse_seq_tf_example(example, shapes, dtypes):
    """Parse tf.Example containing one or two episode steps.

    Args:
      example: Serialized tf.Example, passed to
        `tf.io.parse_single_example`.
      shapes: Dict mapping feature name to the shape of that feature.
      dtypes: Dict mapping feature name to the requested NumPy dtype. It must
        contain an entry for every key of `shapes`.

    Returns:
      Dict mapping each feature name to its parsed tensor, cast to the dtype
      requested in `dtypes` when the parsed dtype differs.

    Raises:
      ValueError: If a requested dtype is neither floating, integer nor
        boolean.
    """

    def to_feature(shape, dtype):
        # Floating features are read as float32 and integer/boolean features
        # as int64; they are cast to the requested dtype after parsing.
        if np.issubdtype(dtype, np.floating):
            return tf.io.FixedLenSequenceFeature(
                shape=shape, dtype=tf.float32, allow_missing=True)
        # `np.bool` was removed in NumPy 1.24 (accessing it raises
        # AttributeError), so the boolean check goes through `np.bool_`,
        # which also matches the builtin `bool`.
        if (np.issubdtype(dtype, np.bool_) or
                np.issubdtype(dtype, np.integer)):
            return tf.io.FixedLenSequenceFeature(
                shape=shape, dtype=tf.int64, allow_missing=True)
        raise ValueError(f'Unsupported type {dtype} to '
                         f'convert from TF Example.')

    feature_map = {
        name: to_feature(shape, dtypes[name])
        for name, shape in shapes.items()
    }

    parsed = tf.io.parse_single_example(example, features=feature_map)

    restructured = {}
    for name, tensor in parsed.items():
        wanted = tf.as_dtype(dtypes[name])
        # Only insert a cast op when the stored dtype differs from the
        # requested one.
        if tensor.dtype == wanted:
            restructured[name] = tensor
        else:
            restructured[name] = tf.cast(tensor, wanted)

    return restructured


def _build_sars_example(sequences):
    """Turn a parsed step sequence into a Reverb SARS' replay sample.

    Args:
      sequences: Dict with 'observation', 'action', 'reward', 'discount' and
        'step_type' entries. Each leaf is indexed along its leading axis,
        where index 0 is the first step and index 1 the second.

    Returns:
      A `reverb.ReplaySample` whose data is the tuple
      `(o_tm1, a_tm1, r_t, p_t, o_t)` and whose info is filled with constant
      values (key 0, probability 1.0, table_size 0, priority 1.0).
    """

    def first(t):
        return t[0]

    def second(t):
        return t[1]

    def masked_discount(d, st):
        # The discount is zeroed when the second step is the LAST step of an
        # episode.
        return d[0] * tf.cast(st[1] != dm_env.StepType.LAST, d.dtype)

    observations = sequences['observation']
    o_tm1 = tree.map_structure(first, observations)
    o_t = tree.map_structure(second, observations)
    a_tm1 = tree.map_structure(first, sequences['action'])
    r_t = tree.map_structure(first, sequences['reward'])
    p_t = tree.map_structure(
        masked_discount, sequences['discount'], sequences['step_type'])

    transition = (o_tm1, a_tm1, r_t, p_t, o_t)
    sample_info = reverb.SampleInfo(
        key=tf.constant(0, tf.uint64),
        probability=tf.constant(1.0, tf.float64),
        table_size=tf.constant(0, tf.int64),
        priority=tf.constant(1.0, tf.float64))
    return reverb.ReplaySample(info=sample_info, data=transition)


def bsuite_dataset_params(env):
    """Collect the feature shapes and dtypes of a bsuite offline dataset.

    Args:
      env: Environment exposing `observation_spec()`, `action_spec()`,
        `discount_spec()` and `reward_spec()`, each returning an object with
        `shape` and `dtype` attributes.

    Returns:
      A dict `{'shapes': ..., 'dtypes': ...}`. Both inner dicts are keyed by
      'observation', 'action', 'discount', 'reward', 'episodic_reward' and
      'step_type'. 'episodic_reward' reuses the reward spec, and 'step_type'
      is a scalar `np.int64`.
    """
    reward_spec = env.reward_spec()
    specs = {
        'observation': env.observation_spec(),
        'action': env.action_spec(),
        'discount': env.discount_spec(),
        'reward': reward_spec,
        'episodic_reward': reward_spec,
    }

    shapes = {name: spec.shape for name, spec in specs.items()}
    dtypes = {name: spec.dtype for name, spec in specs.items()}

    # The step type has no spec on the environment; it is a scalar int64.
    shapes['step_type'] = ()
    dtypes['step_type'] = np.int64

    return {'shapes': shapes, 'dtypes': dtypes}


def bsuite_dataset(path: str,
                   shapes: Dict[str, Tuple[int]],
                   dtypes: Dict[str, type],  # pylint:disable=g-bare-generic
                   num_threads: int,
                   batch_size: int,
                   num_shards: int,
                   shuffle_buffer_size: int = 100000,
                   shuffle: bool = True) -> tf.data.Dataset:
    """Build the batched tf.data pipeline of SARS' samples for training.

    Args:
      path: Common prefix of the data files; shard `i` is read from
        `{path}-{i:05d}-of-{num_shards:05d}`.
      shapes: Feature shapes forwarded to `_parse_seq_tf_example`.
      dtypes: Feature dtypes forwarded to `_parse_seq_tf_example`.
      num_threads: Number of parallel calls used when parsing examples.
      batch_size: Number of samples per batch; incomplete batches are
        dropped.
      num_shards: Number of shard files.
      shuffle_buffer_size: Buffer size of the record-level shuffle.
      shuffle: If True, the pipeline repeats and shuffles at the file,
        record and parsed-example level. If False, no repeat or shuffle is
        applied.

    Returns:
      A `tf.data.Dataset` of batched samples built by `_build_sars_example`.
    """
    autotune = tf.data.experimental.AUTOTUNE

    shard_files = []
    for shard in range(num_shards):
        shard_files.append(f'{path}-{shard:05d}-of-{num_shards:05d}')

    ds = tf.data.Dataset.from_tensor_slices(shard_files)
    if shuffle:
        ds = ds.repeat().shuffle(num_shards)

    # Shard files are GZIP-compressed TFRecords, read in blocks of five
    # records per file.
    read_records = functools.partial(
        tf.data.TFRecordDataset, compression_type='GZIP')
    ds = ds.interleave(read_records, cycle_length=autotune, block_length=5)
    if shuffle:
        ds = ds.shuffle(shuffle_buffer_size)

    def parse(serialized):
        return _parse_seq_tf_example(serialized, shapes, dtypes)

    ds = ds.map(parse, num_parallel_calls=num_threads)
    if shuffle:
        ds = ds.repeat().shuffle(batch_size * 10)

    ds = ds.map(_build_sars_example, num_parallel_calls=autotune)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds.prefetch(autotune)


def load_offline_bsuite_dataset(
        bsuite_id: str,
        path: str,
        batch_size: int,
        num_shards: int = 1,
        num_threads: int = 1,
        single_precision_wrapper: bool = True,
        shuffle: bool = True,
        shuffle_buffer_size: int = 2,
) -> Tuple[tf.data.Dataset, dm_env.Environment]:
    """Load a bsuite offline dataset together with its environment.

    Data files are expected at `{path}-?????-of-{num_shards:05d}`.

    Args:
      bsuite_id: Identifier passed to `bsuite.load_from_id`.
      path: Common prefix of the data files.
      batch_size: Number of samples per batch.
      num_shards: Number of shard files.
      num_threads: Number of parallel calls used when parsing examples.
      single_precision_wrapper: If True, the environment is wrapped in
        `SinglePrecisionWrapper` before its specs are read, so the dataset
        dtypes follow the wrapped specs.
      shuffle: If True, the dataset is shuffled and repeated. If False, it
        is not repeated (the original note also states that it is not
        deterministic in that case).
      shuffle_buffer_size: Buffer size of the record-level shuffle. It
        defaults to 2, the value that was previously hard-coded.

    Returns:
      A `(dataset, environment)` tuple.
    """
    environment = bsuite.load_from_id(bsuite_id)
    if single_precision_wrapper:
        environment = single_precision.SinglePrecisionWrapper(environment)

    # `params` holds the 'shapes' and 'dtypes' keyword arguments.
    params = bsuite_dataset_params(environment)
    dataset = bsuite_dataset(path=path,
                             num_threads=num_threads,
                             batch_size=batch_size,
                             num_shards=num_shards,
                             shuffle_buffer_size=shuffle_buffer_size,
                             shuffle=shuffle,
                             **params)
    return dataset, environment

## Dataset and environment

In [None]:
# Root location of the offline bsuite data (a gs:// URL).
tmp_path = 'gs://rl_unplugged/bsuite'
# Level name; it is both a path component and the prefix of the bsuite id.
level = 'catch'
# NOTE(review): `dir` shadows the Python builtin of the same name. The meaning
# of '0_0.0' is not visible in this file — TODO confirm.
dir = '0_0.0'
filename = '0_full'
# Prefix of the data files; the loader appends '-?????-of-{num_shards:05d}'.
path = f'{tmp_path}/{level}/{dir}/{filename}'

In [None]:
# Number of samples per training batch (Colab form parameter).
batch_size = 2  #@param
# bsuite id passed to the loader; here it evaluates to 'catch/0'.
bsuite_id = level + '/0'
# Build the offline dataset and the matching environment. The remaining
# loader arguments keep their defaults.
dataset, environment = load_offline_bsuite_dataset(bsuite_id=bsuite_id,
                                                   path=path,
                                                   batch_size=batch_size)
# NOTE(review): bsuite_dataset already ends with prefetch(AUTOTUNE), so this
# extra prefetch looks redundant — confirm before removing.
dataset = dataset.prefetch(1)

## DQN learner

In [None]:
# Number of discrete actions; it sets the width of the network output.
num_actions = environment.action_spec().num_values
obs_spec = environment.observation_spec()
print(obs_spec)

# Q network: flatten the observation, apply a two-layer MLP of width 56,
# then a final layer with one output per action.
network = snt.Sequential([
    snt.flatten,
    snt.nets.MLP([56, 56]),
    snt.nets.MLP([num_actions]),
])
# Create the network variables from the observation spec.
acme_utils.create_variables(network, [obs_spec])


In [None]:
# Create a logger that prints learner output to the terminal.
# NOTE(review): time_delta=1. presumably rate-limits logging to once per
# second — confirm against acme.utils.loggers.
logger = loggers.TerminalLogger(label='learner', time_delta=1.)

# Create the DQN learner. It is fed from the offline `dataset` built above.
# The target network starts as an independent deep copy of the online
# network.
learner = dqn.DQNLearner(
    network=network,
    target_network=copy.deepcopy(network),
    discount=0.99,
    learning_rate=3e-4,
    importance_sampling_exponent=0.2,
    target_update_period=2500,
    dataset=dataset,
    logger=logger)

## Training loop

In [None]:
# Run 10000 learner steps.
# NOTE(review): each step presumably consumes one batch from the offline
# dataset — confirm against DQNLearner.step.
for _ in range(10000):
  learner.step()

## Evaluation

In [None]:
# Create a logger for evaluation output. This rebinds `logger`; the learner
# keeps the logger object it was constructed with.
logger = loggers.TerminalLogger(label='evaluation', time_delta=1.)

# Create an environment loop.
# The policy is greedy: it takes the argmax over the last axis of the trained
# network's output.
policy_network = snt.Sequential([
    network,
    lambda q: tf.argmax(q, axis=-1),
])
loop = acme.EnvironmentLoop(
    environment=environment,
    actor=actors.FeedForwardActor(policy_network=policy_network),
    logger=logger)

# NOTE(review): 400 is presumably the number of episodes (first positional
# argument of EnvironmentLoop.run) — confirm.
loop.run(400)