Copyright 2020 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at

[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)

Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

# RL Unplugged: CRR agent with GPU/TPU support - DM control

## Guide to training an Acme CRR agent on DM control data.
# <a href="https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/rl_unplugged/dm_control_suite_crr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


## Installation

In [None]:
# Install Acme from GitHub, its Reverb and TensorFlow extras, and Sonnet.
!pip install git+https://github.com/deepmind/acme.git#egg=dm-acme
!pip install dm-acme[reverb]
!pip install dm-acme[tf]
!pip install dm-sonnet
# Clone deepmind-research and make it the working directory, presumably so the
# `rl_unplugged` package imported in later cells is found from the cwd.
!git clone https://github.com/deepmind/deepmind-research.git
%cd deepmind-research

### dm_control

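More detailed instructions are available in [this tutorial](https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb#scrollTo=YvyGCsgSCxHQ). Run only one of the two cells below, depending on whether you have an institutional or a machine-locked MuJoCo license.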

#### Institutional MuJoCo license.

In [None]:
#@title Edit and run
# Paste your institutional MuJoCo license key between the triple quotes below;
# leading/trailing whitespace is stripped.
mjkey = """

REPLACE THIS LINE WITH YOUR MUJOCO LICENSE KEY

""".strip()

# In the `!` commands below, IPython substitutes `$mujoco_dir` with this Python
# value; the shell then expands `$HOME`.
mujoco_dir = "$HOME/.mujoco"

# Install OpenGL deps
!apt-get update && apt-get install -y --no-install-recommends \
  libgl1-mesa-glx libosmesa6 libglew2.0

# Fetch MuJoCo binaries from Roboti
!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip
!unzip -o -q mujoco.zip -d "$mujoco_dir"

# Copy over MuJoCo license (`$mjkey` is substituted from the Python variable).
!echo "$mjkey" > "$mujoco_dir/mjkey.txt"


# Configure dm_control to use the OSMesa rendering backend
%env MUJOCO_GL=osmesa

# Install dm_control
!pip install dm_control

#### Machine-locked MuJoCo license.

In [None]:
#@title Add your MuJoCo License and run
# Paste your machine-locked MuJoCo license key between the triple quotes
# below; it is empty as shipped, and surrounding whitespace is stripped.
mjkey = """
""".strip()

# In the `!` commands below, IPython substitutes `$mujoco_dir` with this Python
# value; the shell then expands `$HOME`.
mujoco_dir = "$HOME/.mujoco"

# Install OpenGL dependencies
!apt-get update && apt-get install -y --no-install-recommends \
  libgl1-mesa-glx libosmesa6 libglew2.0

# Get MuJoCo binaries
!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip
!unzip -o -q mujoco.zip -d "$mujoco_dir"

# Copy over MuJoCo license (`$mjkey` is substituted from the Python variable).
!echo "$mjkey" > "$mujoco_dir/mjkey.txt"

# Install dm_control together with its `locomotion_mazes` extra.
!pip install dm_control[locomotion_mazes]

# Configure dm_control to use the OSMesa rendering backend
%env MUJOCO_GL=osmesa

## Imports

In [None]:
import copy
from typing import Optional, Sequence
import acme
from acme import specs
from acme.agents.tf import actors
from acme.agents.tf import crr
from acme.tf import networks as acme_networks
from acme.tf import utils as tf2_utils
from acme.utils import loggers
import numpy as np
from rl_unplugged import dm_control_suite
from rl_unplugged import networks
import sonnet as snt
import tensorflow as tf

## Data

In [None]:
task_name = 'cartpole_swingup' #@param
gs_path = 'gs://rl_unplugged/dm_control_suite'

# Count the data shards for the chosen task: `gsutil ls` lists the matching
# objects (one per line) and `wc -l` counts them. IPython returns the `!`
# command's output as a list of lines, whose single element is unpacked here.
num_shards_str, = !gsutil ls {gs_path}/{task_name}/* | wc -l
num_shards = int(num_shards_str)

## Dataset and environment

In [None]:
batch_size = 256  #@param

# Wrap the selected DM Control Suite task. It supplies both the dataset
# metadata (data path, feature shapes) and the environment used for evaluation.
task = dm_control_suite.ControlSuite(task_name)

# Read `task.environment` exactly once and derive the specs from that instance.
environment = task.environment
environment_spec = specs.make_environment_spec(environment=environment)

## Networks

In [None]:
def make_networks(
    action_spec: specs.BoundedArray,
    policy_lstm_sizes: Optional[Sequence[int]] = None,
    critic_lstm_sizes: Optional[Sequence[int]] = None,
    num_components: int = 5,
    vmin: float = 0.,
    vmax: float = 100.,
    num_atoms: int = 21,
):
  """Creates recurrent networks with GMM head used by the agents.

  The policy is: observation encoder -> residual MLP neck -> stacked LSTMs ->
  multivariate Gaussian-mixture head. The critic follows the same layout but
  ends in a discrete-valued (distributional) head.

  Args:
    action_spec: Spec of the environment's actions. The product of its shape
      is the dimensionality of the policy's mixture head.
    policy_lstm_sizes: Hidden sizes of the policy LSTM layers, in order.
      `None` means `[1024, 1024]`.
    critic_lstm_sizes: Hidden sizes of the critic LSTM layers, in order.
      `None` means `[1024, 1024]`.
    num_components: Number of mixture components in the policy head.
    vmin: Lower bound of the critic head's value support.
    vmax: Upper bound of the critic head's value support.
    num_atoms: Number of atoms in the critic head's value support.

  Returns:
    A dict with key 'policy' (an `snt.DeepRNN`) and key 'critic' (an
    `acme_networks.CriticDeepRNN`).
  """

  action_size = np.prod(action_spec.shape, dtype=int)
  actor_head = acme_networks.MultivariateGaussianMixture(
      num_components=num_components, num_dimensions=action_size)

  # `None` defaults (instead of list literals in the signature) give every call
  # its own fresh list.
  if policy_lstm_sizes is None:
    policy_lstm_sizes = [1024, 1024]
  if critic_lstm_sizes is None:
    critic_lstm_sizes = [1024, 1024]

  # Policy: encoder -> residual MLP neck -> LSTM stack -> GMM head.
  actor_neck = acme_networks.LayerNormAndResidualMLP(hidden_size=1024,
                                                     num_blocks=4)
  actor_encoder = networks.ControlNetwork(
      proprio_encoder_size=300,
      activation=tf.nn.relu)

  policy_lstms = [snt.LSTM(s) for s in policy_lstm_sizes]

  policy_network = snt.DeepRNN([actor_encoder, actor_neck] + policy_lstms +
                               [actor_head])

  # Critic: encoder -> residual MLP neck -> LSTM stack -> distributional head.
  critic_encoder = networks.ControlNetwork(
      proprio_encoder_size=400,
      activation=tf.nn.relu)
  critic_neck = acme_networks.LayerNormAndResidualMLP(
      hidden_size=1024, num_blocks=4)
  distributional_head = acme_networks.DiscreteValuedHead(
      vmin=vmin, vmax=vmax, num_atoms=num_atoms)
  critic_lstms = [snt.LSTM(s) for s in critic_lstm_sizes]
  critic_network = acme_networks.CriticDeepRNN([critic_encoder, critic_neck] +
                                                critic_lstms + [
                                                    distributional_head,
                                                ])

  return {
      'policy': policy_network,
      'critic': critic_network,
  }

## Set up TPU if present

In [None]:
# Choose a distribution strategy: replicate across TPU cores when a TPU is
# attached, otherwise use Sonnet's CPU/GPU replicator.
try:
  # TPU detection only; the resolver is expected to raise ValueError when no
  # TPU is available.
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
  print('Running on CPU or GPU (no TPUs available)')
  accelerator_strategy = snt.distribute.Replicator()
else:
  # Keep connection/initialization outside the `try`. That way an error while
  # setting up a TPU that *was* detected surfaces instead of silently falling
  # back to CPU/GPU.
  tf.config.experimental_connect_to_cluster(tpu)
  tf.tpu.experimental.initialize_tpu_system(tpu)
  accelerator_strategy = snt.distribute.TpuReplicator()
  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])


## CRR learner

In [None]:
action_spec = environment_spec.actions
# Flattened action dimensionality; not referenced by the cells shown here
# (make_networks computes its own).
action_size = np.prod(action_spec.shape, dtype=int)

# Build the dataset and all networks inside the strategy scope so that the
# variables created here belong to `accelerator_strategy`.
with accelerator_strategy.scope():
  # Non-SARSA transitions, read from the RL Unplugged bucket using the task's
  # data path and feature shapes.
  dataset = dm_control_suite.dataset(
    'gs://rl_unplugged/',
    data_path=task.data_path,
    shapes=task.shapes,
    uint8_features=task.uint8_features,
    num_threads=1,
    batch_size=batch_size,
    num_shards=num_shards,
    sarsa=False)
  # CRR learner assumes that the dataset samples don't have metadata,
  # so let's remove it here.
  dataset = dataset.map(lambda sample: sample.data)
  nets = make_networks(action_spec)
  policy_network, critic_network = nets['policy'], nets['critic']

  # Create the target networks. They are copied before any variables exist, so
  # each target network receives its own variables from the
  # create_variables calls below.
  target_policy_network = copy.deepcopy(policy_network)
  target_critic_network = copy.deepcopy(critic_network)

  # Create variables. Policies take observations; critics take
  # (observation, action).
  tf2_utils.create_variables(network=policy_network,
                            input_spec=[environment_spec.observations])
  tf2_utils.create_variables(network=critic_network,
                            input_spec=[environment_spec.observations,
                                        environment_spec.actions])
  tf2_utils.create_variables(network=target_policy_network,
                            input_spec=[environment_spec.observations])
  tf2_utils.create_variables(network=target_critic_network,
                            input_spec=[environment_spec.observations,
                                        environment_spec.actions])

# The learner updates the parameters (and initializes them).
# It is built outside the `with` block; the strategy is handed to it explicitly
# via `accelerator_strategy`. target_update_period presumably means the target
# networks are refreshed every 100 learner steps — confirm in RCRRLearner.
learner = crr.RCRRLearner(
    policy_network=policy_network,
    critic_network=critic_network,
    accelerator_strategy=accelerator_strategy,
    target_policy_network=target_policy_network,
    target_critic_network=target_critic_network,
    dataset=dataset,
    discount=0.99,
    target_update_period=100)

## Training loop

In [None]:
# Run
#   tf.config.run_functions_eagerly(True)
# if you want to debug the code in eager mode.

# Number of learner updates to run. Exposed as a Colab form field, like the
# notebook's other settings; increase it for longer training.
num_learner_steps = 100  #@param

# Each `learner.step()` call performs one learner update (presumably on one
# batch drawn from `dataset`).
for _ in range(num_learner_steps):
  learner.step()

## Evaluation

In [None]:
# Print evaluation results to the terminal, at most once per second.
logger = loggers.TerminalLogger(time_delta=1., label='evaluation')

# Evaluate the trained policy: a recurrent actor drives `environment`, and each
# episode's results go to `logger`.
loop = acme.EnvironmentLoop(
    environment,
    actors.RecurrentActor(policy_network=policy_network),
    logger=logger,
)

loop.run(num_episodes=5)