# World Models Library

World Models Library is a pure Python tool that facilitates planning from pixels. Its major components are the following:
*   **World Model:** A *world model* is a stateless model of the environment that is used for planning. Any world model should implement the following three functions:
  *   `reset_fn(**kwargs) -> state`. This function resets the state of the world model and is called at the beginning of every episode. The *state* can have any structure.
  *   `observe_fn(last_frames, last_actions, last_rewards, state) -> state`. This function updates the state of the world model given the latest observed changes in the environment.
  *   `predict_fn(future_actions, state) -> predictions`. This function predicts the future and is used by planners to evaluate action proposals. The *predictions* should be compatible with the `objective_fn` that is passed to the planner. This function should not change the input *state*.

  Any state that the model needs in order to operate (e.g. recurrent layer state) should be part of the `state` object that is passed around and returned from `observe_fn`.
*   **Planner:** A *planner* decides what actions to take next with the help of a *world model*. The planner is responsible for keeping track of the world model's *state* and for calling `reset_fn`, `observe_fn` and `predict_fn` to interact with the *world model*.
*   **Task:** A *task* is a thin wrapper around `gym.Env` that adds a name to the underlying environment and provides convenient factory methods to instantiate environments.
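
To make the world model interface concrete, here is a minimal sketch of the three functions for a toy one-dimensional world. Everything inside the functions is hypothetical: the `position` and `steps_observed` fields, the reward definition and the `{"reward": ...}` prediction format are assumptions for illustration, not the library's data layout. Only the function names and signatures follow the description above.

```python
import copy


def reset_fn(**kwargs):
    """Returns a fresh state; called at the beginning of every episode."""
    return {"position": 0.0, "steps_observed": 0}


def observe_fn(last_frames, last_actions, last_rewards, state):
    """Returns an updated copy of `state` given the latest observations."""
    new_state = copy.deepcopy(state)  # The model is stateless: never mutate the input.
    new_state["position"] = float(last_frames[-1])
    new_state["steps_observed"] += len(last_actions)
    return new_state


def predict_fn(future_actions, state):
    """Predicts per-step rewards for each action sequence; `state` is left untouched."""
    predictions = []
    for actions in future_actions:
        position = state["position"]
        rewards = []
        for action in actions:
            position += action
            rewards.append(-abs(position - 1.0))  # Toy reward that peaks at position 1.
        predictions.append(rewards)
    return {"reward": predictions}


state = reset_fn()
state = observe_fn(last_frames=[0.5], last_actions=[0.5], last_rewards=[-0.5], state=state)
predictions = predict_fn(future_actions=[[0.5, 0.0], [-0.5, -0.5]], state=state)
```

Because every function returns a new state rather than mutating its input, a planner can call `predict_fn` on the same state many times while it compares action proposals.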

![World Models Library Components](https://i.imgur.com/JmuCcRI.png)

## Colab Setup

In [None]:
#@title Imports
import tensorflow.compat.v1 as tf
from world_models.simulate import simulate
from world_models.agents import planet
from world_models.planners import planners
from world_models.objectives import objectives
from world_models.utils import npz
from world_models.loops import train_eval
from world_models.tasks import tasks

tf.enable_eager_execution()

%load_ext tensorboard

## Task

As a first step, we should choose a task to solve. Several task suites are already defined in the `tasks/tasks.py` file. We will use DeepMind Control's Cheetah task as an example.

In [None]:
task = tasks.DeepMindControl(domain_name='cheetah',
                             task_name='run',
                             action_repeat=4)

## World Model

A few world models are already defined in the `agents` folder, including *PlaNet*, *SV2P*, etc. For this colab we will instantiate a *PlaNet* agent.

In [None]:
model = planet.RecurrentStateSpaceModel(task=task)
model_dir = '/tmp/experiment/model'
dist_strategy = tf.distribute.MirroredStrategy()

reset_fn = planet.create_planet_reset_fn(model=model)
observe_fn = planet.create_planet_observe_fn(model=model,
                                             model_dir=model_dir,
                                             strategy=dist_strategy)
predict_fn = planet.create_planet_predict_fn(model=model,
                                             strategy=dist_strategy)

In addition to `reset_fn`, `observe_fn` and `predict_fn`, we also need to define a `train_fn` as an extra hook that trains the model on the latest collected episodes. It has the signature `train_fn(data_path) -> None`. Utility functions for fast data processing in `utils/npz.py` can be used in a training loop, but the library is agnostic to how training, checkpointing and restoring are done.

In [None]:
train_steps = 100  # How many training steps per episode
batch = 50  # Training batch size
duration = 50  # How many timesteps in a single training sequence
learning_rate = 1e-3

train_fn = planet.create_planet_train_fn(model=model,
                                         train_steps=train_steps,
                                         batch=batch,
                                         duration=duration,
                                         learning_rate=learning_rate,
                                         model_dir=model_dir,
                                         strategy=dist_strategy)

## Planner

The *planner* is responsible for decision making. It can use the world model to make predictions about the future and make informed decisions about which actions to take next. We normally need **separate planners** for training and evaluation, since training may require some form of exploration that is not appropriate for evaluation. We implemented a few planners in the `planners/planners.py` file, including continuous and discrete *Cross Entropy Method (CEM)*. The diagram below shows how CEM iteratively refines its proposal distribution to choose an optimal action.
![CEM iteration](https://i.imgur.com/2iUmnIK.png)
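
To make the refinement loop concrete, here is a minimal pure-Python sketch of continuous CEM. It is not the library's `planners.CEM` implementation: `cem_plan`, `toy_score`, the initial unit Gaussian and the per-step independent Gaussian are assumptions for illustration. The parameter names `horizon`, `iterations`, `proposals` and `fraction` are chosen to mirror the planner arguments used in this colab.

```python
import random
import statistics


def cem_plan(score_fn, horizon, iterations, proposals, fraction, seed=0):
    """Returns the mean action sequence after iteratively refitting a Gaussian."""
    rng = random.Random(seed)
    mean = [0.0] * horizon
    stdev = [1.0] * horizon
    num_elites = max(2, int(proposals * fraction))
    for _ in range(iterations):
        # Sample candidate action sequences from the current distribution.
        candidates = [[rng.gauss(m, s) for m, s in zip(mean, stdev)]
                      for _ in range(proposals)]
        # Keep the highest-scoring fraction of the candidates.
        elites = sorted(candidates, key=score_fn, reverse=True)[:num_elites]
        # Refit the distribution to the elites, one time step at a time.
        per_step = list(zip(*elites))
        mean = [statistics.fmean(step) for step in per_step]
        stdev = [statistics.pstdev(step) for step in per_step]
    return mean


def toy_score(actions):
    """Hypothetical objective: the best action at every step is 0.5."""
    return -sum((a - 0.5) ** 2 for a in actions)


plan = cem_plan(toy_score, horizon=3, iterations=10, proposals=200, fraction=0.1)
```

In a real planner, `score_fn` would be replaced by a call to `predict_fn` followed by the `objective_fn`, so that proposals are scored by the world model rather than by a closed-form function.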

A planner also needs an `objective_fn` to compute scores from *world model* predictions. If a world model predicts rewards directly, we can use a `DiscountedReward` objective.
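
As a rough illustration of what such an objective computes, the sketch below scores one predicted reward sequence by its discounted sum. The function name `discounted_reward` and the `discount` argument are assumptions; the actual signature and default discount of `objectives.DiscountedReward` are not shown in this colab.

```python
def discounted_reward(rewards, discount=0.99):
    """Returns sum over t of discount**t * rewards[t] for one reward sequence."""
    return sum(reward * discount ** t for t, reward in enumerate(rewards))


# Three predicted rewards of 1.0, each step weighted half as much as the previous one.
score = discounted_reward([1.0, 1.0, 1.0], discount=0.5)
```

A planner would compute one such score per action proposal and rank the proposals by it.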

In [None]:
objective_fn = objectives.DiscountedReward()

horizon = 12  # CEM planning horizon
iterations = 10  # CEM iterations
proposals = 1000  # Number of proposals to evaluate per iteration
top_fraction = 0.1  # Fraction of proposals with highest scores for fitting

# Base CEM planner to use for evaluation.
base_cem = planners.CEM(predict_fn=predict_fn,
                        observe_fn=observe_fn,
                        reset_fn=reset_fn,
                        task=task,
                        objective_fn=objective_fn,
                        horizon=horizon,
                        iterations=iterations,
                        proposals=proposals,
                        fraction=top_fraction)

# Training CEM planner with initial random cold start and random noise.
# Pure random actions for the first `n` episodes to bootstrap the world model.
random_cold_start_episodes = 5
train_cem = planners.RandomColdStart(task=task,
                                     random_episodes=random_cold_start_episodes,
                                     base_planner=base_cem)
# Add some Gaussian noise for active exploration.
noise_scale = 0.3
train_cem = planners.GaussianRandomNoise(task=task,
                                         stdev=noise_scale,
                                         base_planner=train_cem)
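
The two wrappers above follow a decorator-style pattern: each one holds a `base_planner` and modifies the action it proposes. The sketch below illustrates that pattern only. The class names, the `act(episode_index)` method, the scalar action, the action bounds and the clipping step are all hypothetical and do not reflect the library's planner interface.

```python
import random


class ConstantPlanner:
    """Hypothetical base planner that always proposes the same scalar action."""

    def __init__(self, action):
        self._action = action

    def act(self, episode_index):
        return self._action


class RandomColdStartSketch:
    """Acts uniformly at random for the first `random_episodes` episodes."""

    def __init__(self, base_planner, random_episodes, low=-1.0, high=1.0, seed=0):
        self._base_planner = base_planner
        self._random_episodes = random_episodes
        self._low, self._high = low, high
        self._rng = random.Random(seed)

    def act(self, episode_index):
        if episode_index < self._random_episodes:
            return self._rng.uniform(self._low, self._high)
        return self._base_planner.act(episode_index)


class GaussianNoiseSketch:
    """Adds zero-mean Gaussian noise to the wrapped planner's action."""

    def __init__(self, base_planner, stdev, low=-1.0, high=1.0, seed=0):
        self._base_planner = base_planner
        self._stdev = stdev
        self._low, self._high = low, high
        self._rng = random.Random(seed)

    def act(self, episode_index):
        action = self._base_planner.act(episode_index)
        noisy = action + self._rng.gauss(0.0, self._stdev)
        # Clipping to the action bounds is an assumption of this sketch.
        return min(self._high, max(self._low, noisy))


base = ConstantPlanner(action=0.25)
train_planner = GaussianNoiseSketch(
    RandomColdStartSketch(base, random_episodes=5), stdev=0.3)
noisy_action = train_planner.act(episode_index=7)
```

Stacking wrappers this way keeps the evaluation planner untouched: the same base planner can be used directly for evaluation while the wrapped version is used for training.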

## Simulation

To run an agent on the task, we can use the `simulate` function in the `simulate/simulate.py` file. Below is a diagram of the chain of events during an episode.

![Simulation logic](https://i.imgur.com/JjwNHwj.png)
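
For readers who prefer code to diagrams, the sketch below shows one plausible shape of a single episode, assuming a gym-style environment and the `reset_fn` / `observe_fn` interface described at the top of this colab. The exact ordering of calls inside `simulate.simulate` is given by the diagram, not by this sketch. `CountdownEnv`, `toy_reset_fn`, `toy_observe_fn`, `run_episode` and `choose_action` are hypothetical names.

```python
class CountdownEnv:
    """Hypothetical gym-style environment that terminates after `length` steps."""

    def __init__(self, length=3):
        self._length = length
        self._t = 0

    def reset(self):
        self._t = 0
        return 0.0  # Initial observation.

    def step(self, action):
        self._t += 1
        observation = float(self._t)
        reward = -abs(action)
        done = self._t >= self._length
        return observation, reward, done, {}


def toy_reset_fn(**kwargs):
    return {"count": 0}


def toy_observe_fn(last_frames, last_actions, last_rewards, state):
    # Return a new state instead of mutating the input.
    return {"count": state["count"] + len(last_actions)}


def run_episode(env, reset_fn, observe_fn, choose_action):
    """Runs one episode and returns (frames, actions, rewards)."""
    state = reset_fn()  # Reset the world model state at episode start.
    frame = env.reset()
    frames, actions, rewards = [frame], [], []
    done = False
    while not done:
        action = choose_action(state)  # A planner would call predict_fn here.
        frame, reward, done, _ = env.step(action)
        frames.append(frame)
        actions.append(action)
        rewards.append(reward)
        # Fold the newest transition into the world model state.
        state = observe_fn([frame], [action], [reward], state)
    return frames, actions, rewards


frames, actions, rewards = run_episode(
    CountdownEnv(length=3), toy_reset_fn, toy_observe_fn,
    choose_action=lambda state: float(state["count"]))
```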

In [None]:
episode_num = 0
train_data_dir = '/tmp/experiment/data/train'
train_summary_dir = '/tmp/experiment/train'
episodes = list()

for i in range(random_cold_start_episodes):
  episode, predictions, score = simulate.simulate(task=task,
                                                  planner=train_cem,
                                                  num_episodes=1)
  scalar_summaries = {'score': score}
  train_eval.visualize(summary_dir=train_summary_dir,
                       global_step=i,
                       episodes=episode,
                       predictions=predictions,
                       scalars=scalar_summaries)
  episodes.extend(episode)
  episode_num += 1

In [None]:
%tensorboard --logdir=/tmp/experiment/ --port=0

We normally need to update our *world model* periodically on all the episodes collected so far, so simulation has to be interleaved with model training. Since the collected episodes grow over time, we should persist them to disk and use optimized, cacheable data iterators for training. The utility functions in `utils/npz.py` can be used here.

In [None]:
npz.save_dictionaries(episodes, train_data_dir)
train_fn(train_data_dir)

In [None]:
%tensorboard --logdir=/tmp/experiment/ --port=0

Now we can evaluate our agent by using the noise-free base planner, `base_cem`.

In [None]:
eval_summary_dir = '/tmp/experiment/eval'
episode, predictions, score = simulate.simulate(task=task,
                                                planner=base_cem,
                                                num_episodes=1)
scalar_summaries = {'score': score}
train_eval.visualize(summary_dir=eval_summary_dir,
                     global_step=episode_num,  # Training episodes collected so far.
                     episodes=episode,
                     predictions=predictions,
                     scalars=scalar_summaries)

In [None]:
%tensorboard --logdir=/tmp/experiment/ --port=0

## Off the Shelf Train-Eval Loop

A utility function named `train_eval_loop` in `loops/train_eval.py` encapsulates training, evaluation, data collection and TensorBoard summary writing in one place. If this off-the-shelf functionality is sufficient, we recommend using it instead of assembling the same steps from the lower-level functions shown above.
![Architecture diagram](https://i.imgur.com/VjHWDhx.png)
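
Conceptually, such a loop interleaves data collection, training and periodic evaluation. The sketch below shows that control flow with stub callables. It is an assumption about the general structure, not the implementation of `train_eval.train_eval_loop`: in particular, the exact iteration at which evaluation runs and the stub names `collect_episode` and `evaluate` are hypothetical.

```python
def train_eval_loop_sketch(collect_episode, train_fn, evaluate, episodes_dir,
                           num_iterations, num_train_episodes_per_iteration,
                           eval_every_n_iterations):
    """Interleaves data collection, training and periodic evaluation."""
    eval_scores = []
    for iteration in range(num_iterations):
        # 1. Collect fresh training episodes with the (noisy) training planner.
        for _ in range(num_train_episodes_per_iteration):
            collect_episode()
        # 2. Update the world model on everything persisted so far.
        train_fn(episodes_dir)
        # 3. Evaluate every n iterations with the noise-free planner.
        if (iteration + 1) % eval_every_n_iterations == 0:
            eval_scores.append(evaluate())
    return eval_scores


calls = {"collect": 0, "train": 0}


def collect_episode():
    calls["collect"] += 1


def train_fn(data_path):
    calls["train"] += 1


def evaluate():
    return calls["train"]  # Stand-in score: number of training rounds so far.


eval_scores = train_eval_loop_sketch(
    collect_episode, train_fn, evaluate, episodes_dir="/tmp/experiment/loop/data/",
    num_iterations=4, num_train_episodes_per_iteration=2,
    eval_every_n_iterations=2)
```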

In [None]:
train_episodes_per_iter = 1  # How many training episodes to collect per train/eval iteration
eval_every_n_iters = 10  # A single eval episode every n iterations
num_iters = 100  # Total number of train/eval iterations
data_dir = '/tmp/experiment/loop/data/'
model_dir = '/tmp/experiment/loop/model'

train_eval.train_eval_loop(task=task,
                           train_planner=train_cem,
                           eval_planner=base_cem,
                           train_fn=train_fn,
                           num_train_episodes_per_iteration=train_episodes_per_iter,
                           eval_every_n_iterations=eval_every_n_iters,
                           num_iterations=num_iters,
                           episodes_dir=data_dir,
                           model_dir=model_dir
                           )

In [None]:
%tensorboard --logdir=/tmp/experiment/ --port=0

## Gin Configs

Finally, note that all of the functionality described above is gin configurable. To assemble an experiment, we need to provide a gin config that correctly instantiates the task, model, planner and other parameters. There are example configs in the `configs/` folder.

The main binary `bin/train_eval.py` defines two bindings, `model_dir=<output_dir>/model` and `episodes_dir=<output_dir>/episodes`, that are populated from the command-line argument `output_dir`. Any gin configurable component that needs either `model_dir` or `episodes_dir` should use bindings like `<component>.<property>=%model_dir` instead of hard-coding the path in the gin config. Below is an example gin config to instantiate a PlaNet agent.

```
import world_models.agents.planet

# Parameters for model:
RecurrentStateSpaceModel.frame_size = (64, 64, 3)
RecurrentStateSpaceModel.reward_stop_gradient = False
RecurrentStateSpaceModel.task = %TASK
# Use singleton to inject the same instance of model later on.
MODEL = @model/singleton()
model/singleton.constructor = @RecurrentStateSpaceModel

# Parameters for predict, observe and reset
STRATEGY = @strategy/singleton()
strategy/singleton.constructor = @tf.distribute.MirroredStrategy
create_planet_predict_fn.model = %MODEL  # Singleton reference
create_planet_predict_fn.strategy = %STRATEGY
create_planet_observe_fn.model = %MODEL
create_planet_observe_fn.model_dir = %model_dir
create_planet_observe_fn.strategy = %STRATEGY
create_planet_reset_fn.model = %MODEL

# Parameters for train_fn:
create_planet_train_fn.train_steps = 100
create_planet_train_fn.batch = 50
create_planet_train_fn.duration = 50
create_planet_train_fn.learning_rate = 1e-3
create_planet_train_fn.model_dir = %model_dir  # Is populated from flags.
create_planet_train_fn.model = %MODEL
create_planet_train_fn.strategy = %STRATEGY
```
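
The `%model_dir` and `%episodes_dir` macros have to be defined before the config is finalized. How `bin/train_eval.py` does this is not shown here; the sketch below is one hedged way to derive the two macro bindings from an output directory. `make_output_bindings` is a hypothetical helper, and the resulting strings are meant to be passed as the `bindings` argument of `gin.parse_config_files_and_bindings`.

```python
import posixpath


def make_output_bindings(output_dir):
    """Builds gin macro bindings for `model_dir` and `episodes_dir`."""
    model_dir = posixpath.join(output_dir, "model")
    episodes_dir = posixpath.join(output_dir, "episodes")
    # `NAME = value` at the top level defines a gin macro usable as %NAME.
    return [f"model_dir = {model_dir!r}", f"episodes_dir = {episodes_dir!r}"]


bindings = make_output_bindings("/tmp/experiment")
```

Deriving both paths from a single `output_dir` keeps the config files free of hard-coded locations, so the same config can be reused across experiments.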
