<a href="https://colab.research.google.com/github/lacykaltgr/hworldmodel/blob/main/coding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torchrl
!pip install tensordict


In [2]:
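import copy
import multiprocessing
import os
import tempfile
import uuid
import warnings
from dataclasses import dataclass
from typing import Optional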

import torch
from torch import nn
from torchrl.collectors import MultiaSyncDataCollector
from torchrl.data import LazyMemmapStorage, MultiStep, TensorDictReplayBuffer
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    ParallelEnv,
    RewardScaling,
    StepCounter,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
    FrameSkipTransform,
    CatFrames,
    Compose,
    GrayScale,
    ObservationNorm,
    Resize,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.modules import DuelingCnnDQNet, EGreedyModule, QValueActor

from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl.record.loggers.csv import CSVLogger
from torchrl.trainers import (
    LogReward,
    Recorder,
    ReplayBufferTrainer,
    Trainer,
    UpdateWeights,
)

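from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential
from tensordict.utils import NestedKey
from torchrl._utils import timeit  # private helper used to time the rollout steps
from torchrl.envs.utils import step_mdp
from torchrl.modules import (
    IndependentNormal,
    MLP,
    SafeModule,
    SafeProbabilisticModule,
    SafeProbabilisticTensorDictSequential,
    SafeSequential,
)
from torchrl.objectives.common import LossModule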



def is_notebook() -> bool:
    try:
        shell = get_ipython().__class__.__name__
        if shell == "ZMQInteractiveShell":
            return True  # Jupyter notebook or qtconsole
        elif shell == "TerminalInteractiveShell":
            return False  # Terminal running IPython
        else:
            return False  # Other type (?)
    except NameError:
        return False  # Probably standard Python interpreter

Let's get started with the various pieces we need for our algorithm:

- An environment;
- A policy (and related modules that we group under the "model" umbrella);
- A data collector, which makes the policy play in the environment and
  delivers training data;
- A replay buffer to store the training data;
- A loss module, which computes the objective function to train our policy
  to maximise the return;
- An optimizer, which performs parameter updates based on our loss.

Additional modules include a logger, a recorder (executes the policy in
"eval" mode) and a target network updater. With all these components in
place, it is easy to see how one could misplace or misuse one component in
the training script. The trainer is there to orchestrate everything for you!

## Building the environment

First let's write a helper function that will output an environment. As usual,
the "raw" environment may be too simple to be used in practice and we'll need
some data transformation to expose its output to the policy.

We will be using the following transforms:

- :class:`~torchrl.envs.StepCounter` to count the number of steps in each trajectory;
- :class:`~torchrl.envs.transforms.ToTensorImage` will convert a ``[W, H, C]`` uint8
  tensor into a floating point tensor in the ``[0, 1]`` space with shape
  ``[C, W, H]``;
- :class:`~torchrl.envs.transforms.RewardScaling` to reduce the scale of the return;
- :class:`~torchrl.envs.transforms.GrayScale` will turn our image into grayscale;
- :class:`~torchrl.envs.transforms.Resize` will resize the image to 64x64;
- :class:`~torchrl.envs.transforms.FrameSkipTransform` will repeat each action
  for two consecutive environment steps;
- :class:`~torchrl.envs.transforms.CatFrames` will concatenate a fixed number of
  successive frames (``N=4``) in a single tensor along the channel dimension.
  This is useful as a single image does not carry information about the
  motion of the car. Some memory about past observations and actions
  is needed, either via a recurrent neural network or using a stack of
  frames. This transform is currently commented out in the builder below;
- :class:`~torchrl.envs.transforms.ObservationNorm` which will normalize our observations
  given some custom summary statistics.

In practice, our environment builder has two arguments:

- ``parallel``: determines whether multiple environments have to be run in
  parallel. We stack the transforms after the
  :class:`~torchrl.envs.ParallelEnv` to take advantage
  of vectorization of the operations on device, although this would
  technically work with every single environment attached to its own set of
  transforms.
- ``obs_norm_sd`` will contain the normalizing constants for
  the :class:`~torchrl.envs.ObservationNorm` transform.
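
To make the frame stacking described for ``CatFrames`` concrete, here is a minimal, library-free sketch. The ``FrameStack`` class is hypothetical and only illustrates the idea; padding the stack with copies of the first frame at reset is an assumption of this sketch, not a statement about TorchRL's implementation.

```python
from collections import deque


class FrameStack:
    """Hypothetical helper that keeps the last ``num_frames`` frames."""

    def __init__(self, num_frames):
        # a bounded deque drops the oldest frame automatically
        self.frames = deque(maxlen=num_frames)

    def reset(self, frame):
        # assumption: fill the stack with copies of the first frame
        self.frames.clear()
        for _ in range(self.frames.maxlen):
            self.frames.append(frame)
        return list(self.frames)

    def step(self, frame):
        # the newest frame is appended, the oldest one is discarded
        self.frames.append(frame)
        return list(self.frames)


frame_stack = FrameStack(num_frames=4)
stacked_at_reset = frame_stack.reset("f0")
stacked_after_step = frame_stack.step("f1")
```

With images, each entry would be a ``[1, 64, 64]`` grayscale frame and the stack would be concatenated along the channel dimension.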




In [3]:
def make_env(
    parallel=False,
    obs_norm_sd=None,
):
    if obs_norm_sd is None:
        obs_norm_sd = {"standard_normal": True}
    if parallel:
        base_env = ParallelEnv(
            num_workers,
            EnvCreator(
                lambda: GymEnv(
                    "MountainCarContinuous-v0",
                    from_pixels=True,
                    pixels_only=True,
                    device=device,
                )
            ),
        )
    else:
        base_env = GymEnv(
            "MountainCarContinuous-v0",
            from_pixels=True,
            pixels_only=True,
            device=device,
        )

    env = TransformedEnv(
        base_env,
        Compose(
            StepCounter(),  # to count the steps of each trajectory
            ToTensorImage(),
            RewardScaling(loc=0.0, scale=0.1),  # TODO ???
            GrayScale(),
            Resize(64, 64),
            FrameSkipTransform(2),
            #CatFrames(4, in_keys=["pixels"], dim=-3),
            ObservationNorm(in_keys=["pixels"], **obs_norm_sd),
        ),
    )
    return env

### Compute normalizing constants

To normalize images, we don't want to normalize each pixel independently
with a full ``[C, W, H]`` normalizing mask, but with simpler ``[C, 1, 1]``
shaped set of normalizing constants (loc and scale parameters).
We will be using the ``reduce_dim`` argument
of :meth:`~torchrl.envs.ObservationNorm.init_stats` to instruct which
dimensions must be reduced, and the ``keep_dims`` parameter to ensure that
not all dimensions disappear in the process:
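
As a rough illustration of what a ``[C, 1, 1]`` set of constants means, the sketch below reduces a small batch of images over the batch, height and width dimensions and keeps one ``loc`` and one ``scale`` per channel. It uses nested Python lists and the population standard deviation for simplicity; the ``per_channel_stats`` helper is hypothetical and is not how ``init_stats`` is implemented.

```python
from statistics import fmean, pstdev


def per_channel_stats(frames):
    """Compute per-channel statistics.

    ``frames`` is a list of images, each image a list of C channels and
    each channel a list of rows of pixel values. Returns ``loc`` and
    ``scale`` as nested lists of shape ``[C][1][1]``.
    """
    num_channels = len(frames[0])
    loc, scale = [], []
    for c in range(num_channels):
        # gather every pixel of channel ``c`` across all frames
        values = [p for frame in frames for row in frame[c] for p in row]
        loc.append([[fmean(values)]])
        scale.append([[pstdev(values)]])
    return loc, scale


# two single-channel 2x2 images
example_frames = [
    [[[0.0, 0.2], [0.4, 0.6]]],
    [[[0.2, 0.4], [0.6, 0.8]]],
]
example_loc, example_scale = per_channel_stats(example_frames)
```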




In [4]:
def get_norm_stats():
    test_env = make_env()
    test_env.transform[-1].init_stats(
        num_iter=1000, cat_dim=0, reduce_dim=[-1, -2, -4], keep_dims=(-1, -2)
    )
    obs_norm_sd = test_env.transform[-1].state_dict()
    # let's check that normalizing constants have a size of ``[C, 1, 1]`` where
    # ``C=1`` (grayscale image; ``C=4`` if :class:`~torchrl.envs.CatFrames` is enabled).
    print("state dict of the observation norm:", obs_norm_sd)
    test_env.close()
    return obs_norm_sd

## Building the model (MPC Dreamer)





In [None]:
class StatefulObsEncoder(nn.Module):
  """
  ObsEncoder class of torchrl extended with GRU memory cell
  """
  def __init__(self, obs_encoder_kwargs, memory_kwargs):
      super().__init__()
      self.encoder = ObsEncoder(**obs_encoder_kwargs)
      # memory_kwargs must provide ``input_size`` (size of the encoded
      # observation) and ``hidden_size`` (size of the recurrent state)
      self.memory_cell = nn.GRUCell(**memory_kwargs)

  def forward(self, observation, state):
      encoded = self.encoder(observation)
      new_state = self.memory_cell(encoded, state)
      return new_state, encoded

class ObsEncoderWithTarget(nn.Module):
  """
  Observation encoder class wrapper
  Adds a target encoder, which is updated slowly
  """

  def __init__(self, encoder_kwargs, memory_kwargs):
    super().__init__()
    self.encoder = StatefulObsEncoder(encoder_kwargs, memory_kwargs)
    self.target_encoder = copy.deepcopy(self.encoder)

    for param in self.target_encoder.parameters():
      param.requires_grad = False

  def forward(self, observation, state):
    return self.encoder(observation, state)

  def forward_target(self, observation, state):
    with torch.no_grad():
      return self.target_encoder(observation, state)

  def update_target(self, momentum: float = 0.99):
      with torch.no_grad():
          # use momentum to update the EMA encoder; the parameters of the
          # GRU memory cell belong to ``self.encoder`` and are covered too
          for param_q, param_k in zip(
              self.encoder.parameters(), self.target_encoder.parameters()
          ):
              param_k.data.mul_(momentum).add_((1.-momentum) * param_q.detach().data)
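
The momentum update used for the target encoder is an exponential moving average of the online parameters. A minimal sketch on plain Python floats, with a hypothetical ``ema_update`` helper, shows the arithmetic:

```python
def ema_update(online, target, momentum=0.99):
    """Return the updated target values.

    Each target value keeps a ``momentum`` share of itself and takes a
    ``1 - momentum`` share of the matching online value.
    """
    return [momentum * t + (1.0 - momentum) * o for o, t in zip(online, target)]


online_params = [1.0, 2.0]
target_params = [0.0, 0.0]
updated_target = ema_update(online_params, target_params, momentum=0.9)
```

A momentum close to 1 makes the target move slowly; a momentum of 0 copies the online values directly.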



In [None]:
class Predictor(nn.Module):
  """
  Predictor class
  Takes the encoded observation (state) and an action as input and tries to
  predict the next state

  cell        recurrent cell type (only 'gru' is supported)
  action_dim  size of the action vector
  state_dim   size of the state vector
  """

  def __init__(self, cell, action_dim, state_dim):
    super().__init__()
    if cell == 'gru':
      self.cell = nn.GRUCell(action_dim, state_dim)
    else:
      raise NotImplementedError()


  def forward(self, state, action):
    predicted_state = self.cell(action, state)
    return predicted_state

In [None]:
class GRURollout(TensorDictModuleBase):
    """Rollout the encoder and the transition model over a sequence.

    Given a sequence of observations and actions, this module applies the
    encoder and the transition model at every time step and carries the
    resulting state over to the next step.
    The forward method returns a stack of all intermediate states.

    Adapted from TorchRL's ``RSSMRollout`` (reference: https://arxiv.org/abs/1811.04551).

    Args:
        encoder (TensorDictModule): Observation encoder.
        transition_model (TensorDictModule): Transition model.

    """

    def __init__(self, encoder: TensorDictModule, transition_model: TensorDictModule):
        super().__init__()
        _module = TensorDictSequential(encoder, transition_model)
        self.in_keys = _module.in_keys
        self.out_keys = _module.out_keys
        self.encoder = encoder
        self.transition_model = transition_model

    def forward(self, tensordict):
        """Runs a rollout in the latent space given a sequence of actions and environment observations.

        The rollout requires an initial state (the state primer) in the
        first time step of the input tensordict.

        At each step:
        - the encoder reads the current observation and state and writes
            the updated state and the encoded observation;
        - the transition model reads the updated state and the action and
            writes the predicted next state.

        The entries written under ``"next"`` are then moved to the root of
        the following time step with ``step_mdp``.

        """
        tensordict_out = []
        *batch, time_steps = tensordict.shape

        update_values = tensordict.exclude(*self.out_keys).unbind(-1)
        _tensordict = update_values[0]
        for t in range(time_steps):
            # encode the current observation and update the recurrent state
            with timeit("rollout/time-encoder"):
                self.encoder(_tensordict)

            # predict the next state from the updated state and the action
            with timeit("rollout/time-transition-model"):
                self.transition_model(_tensordict)

            tensordict_out.append(_tensordict)
            if t < time_steps - 1:
                _tensordict = step_mdp(
                    _tensordict.select(*self.out_keys, strict=False), keep_other=False
                )
                _tensordict = update_values[t + 1].update(_tensordict)

        return torch.stack(tensordict_out, tensordict.ndim - 1)


In [None]:
class WorldModel(nn.Module):
  """
  World Model class

  encoder             - CNN + GRU
  transition model    - GRU
  reward model        - MLP
  """

  def __init__(self, encoder, transition_model, reward_model):
    super().__init__()

    self.world_model = GRURollout(
        encoder = SafeModule(
            encoder,
            in_keys=["observation", "state"],
            out_keys=["state", "encoded"],
        ),
        transition_model = SafeModule(
            transition_model,
            in_keys=["state", "action"],
            out_keys=[("next", "state")],
        ),
    )

    # assumes ``encoder`` is an ObsEncoderWithTarget: its slowly-updated copy
    # provides the prediction target; the "_" key discards the second output
    # (the encoded observation)
    self.target_encoder = SafeModule(
            encoder.forward_target,
            in_keys=[("next", "observation"), "state"],
            out_keys=[("next", "state_target"), "_"],
    )

    self.reward_model = SafeProbabilisticTensorDictSequential(
        SafeModule(
            reward_model,
            in_keys=[("next", "state")],
            out_keys=[("next", "loc")],
        ),
        SafeProbabilisticModule(
            in_keys=[("next", "loc")],
            out_keys=[("next", "reward")],
            distribution_class=IndependentNormal,
            distribution_kwargs={"scale": 1.0, "event_dim": 1},
        ),
    )


In [None]:
from torchrl.modules.models.model_based import ObsEncoder, ObsDecoder
from torchrl.modules import MLP

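state_dim = 200
action_dim = 1  # MountainCarContinuous-v0 has a one-dimensional action
# assumed size of the ObsEncoder output for 64x64 inputs with default settings
encoded_dim = 1024
mlp_depth = 3
mlp_dims = 200
activation = torch.nn.SiLU

def make_model(dummy_env):
  encoder = ObsEncoderWithTarget(
      encoder_kwargs={},
      memory_kwargs={"input_size": encoded_dim, "hidden_size": state_dim},
  )
  transition = Predictor('gru', action_dim, state_dim)
  reward_model = MLP(out_features=1, depth=mlp_depth, num_cells=mlp_dims, activation_class=activation)

  return WorldModel(encoder, transition, reward_model)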



## Collecting and storing data

### Replay buffers

Replay buffers play a central role in off-policy RL algorithms such as DQN.
They constitute the dataset we will be sampling from during training.

Here, we will use a regular sampling strategy, although a prioritized RB
could improve the performance significantly.

We place the storage on disk using
:class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage` class. This
storage is created in a lazy manner: it will only be instantiated once the
first batch of data is passed to it.

The only requirement of this storage is that the data passed to it at write
time must always have the same shape.
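
The behaviour of a fixed-capacity buffer with uniform sampling can be sketched without any library. The ``RingReplayBuffer`` class below is hypothetical and keeps items in a Python list, whereas the storage used in this notebook writes tensors to memory-mapped files on disk.

```python
import random


class RingReplayBuffer:
    """Hypothetical fixed-capacity buffer that overwrites its oldest entries."""

    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self.storage = []
        self.cursor = 0  # index of the next slot to write
        self.rng = random.Random(seed)

    def extend(self, items):
        for item in items:
            if len(self.storage) < self.capacity:
                self.storage.append(item)
            else:
                # buffer is full: overwrite the oldest entry
                self.storage[self.cursor] = item
            self.cursor = (self.cursor + 1) % self.capacity

    def sample(self, batch_size):
        # uniform sampling with replacement
        return [self.rng.choice(self.storage) for _ in range(batch_size)]


ring_buffer = RingReplayBuffer(capacity=4)
ring_buffer.extend(range(6))  # items 0 and 1 are overwritten by 4 and 5
sampled_batch = ring_buffer.sample(3)
```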



In [None]:
def get_replay_buffer(buffer_size, n_optim, batch_size):
    replay_buffer = TensorDictReplayBuffer(
        batch_size=batch_size,
        storage=LazyMemmapStorage(buffer_size),
        prefetch=n_optim,
    )
    return replay_buffer

### Data collector

As in [PPO](https://pytorch.org/rl/tutorials/coding_ppo.html) and
[DDPG](https://pytorch.org/rl/tutorials/coding_ddpg.html), we will be using
a data collector as a dataloader in the outer loop.

We choose the following configuration: each collector runs a series of
environments synchronously in parallel, and the collectors themselves run in
parallel but asynchronously.

<div class="alert alert-info"><h4>Note</h4><p>This feature is only available when running the code within the "spawn"
  start method of python multiprocessing library. If this tutorial is run
  directly as a script (thereby using the "fork" method) we will be using
  a regular :class:`~torchrl.collectors.SyncDataCollector`.</p></div>

The advantage of this configuration is that we can balance the amount of
compute that is executed in batch with what we want to be executed
asynchronously. We encourage the reader to experiment with how the collection
speed is impacted by modifying the number of collectors (i.e., the number of
environment constructors passed to the collector) and the number of
environments executed in parallel in each collector (controlled by the
``num_workers`` hyperparameter).

Collector's devices are fully parametrizable through the ``device`` (general),
``policy_device``, ``env_device`` and ``storing_device`` arguments.
The ``storing_device`` argument will modify the
location of the data being collected: if the batches that we are gathering
have a considerable size, we may want to store them on a different location
than the device where the computation is happening. For asynchronous data
collectors such as ours, different storing devices mean that the data that
we collect won't sit on the same device each time, which is something that
our training loop must account for. For simplicity, we set the devices to
the same value for all sub-collectors.



In [None]:
def get_collector(
    stats,
    num_collectors,
    actor_explore,
    frames_per_batch,
    total_frames,
    device,
):
    cls = MultiaSyncDataCollector
    env_arg = [make_env(parallel=True, obs_norm_sd=stats)] * num_collectors
    data_collector = cls(
        env_arg,
        policy=actor_explore,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        # this is the default behaviour: the collector runs in ``"random"`` (or explorative) mode
        exploration_type=ExplorationType.RANDOM,
        # We set all the devices to be identical
        device=device,
        storing_device=device,
        split_trajs=False,
        postproc=MultiStep(gamma=gamma, n_steps=5),
    )
    return data_collector
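
The ``MultiStep`` post-processing passed to the collector replaces each reward with a discounted sum of the rewards that follow it. The sketch below shows that arithmetic on a plain list; the ``n_step_returns`` helper is hypothetical, and the exact window convention and the handling of episode ends in TorchRL may differ from this simplified version.

```python
def n_step_returns(rewards, discount, n_steps):
    """Discounted sum of up to ``n_steps`` rewards starting at each step."""
    returns = []
    for t in range(len(rewards)):
        # the window is truncated at the end of the trajectory
        window = rewards[t : t + n_steps]
        returns.append(sum(discount**k * r for k, r in enumerate(window)))
    return returns


example_returns = n_step_returns([1.0, 1.0, 1.0, 1.0], discount=0.5, n_steps=2)
```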

## Loss function

Building a loss function is straightforward with a ready-made objective such
as the DQNLoss class: we only need to provide the model and a bunch of
hyperparameters. For the world model, we write a custom loss module instead.

### Target parameters

Many off-policy RL algorithms use the concept of "target parameters" when it
comes to estimating the value of the next state or state-action pair.
The target parameters are lagged copies of the model parameters. Because
they change slowly, the targets computed with them are more stable than
targets computed with the current model parameters, which helps learning.
This trick is ubiquitous in similar algorithms; the world model uses the same
idea for its target encoder.




In [None]:
class ModelLoss(LossModule):
    """World model loss.

    Computes the loss of the world model. The loss is composed of the
    distance between the predicted next state and the state produced by the
    target encoder, and the reward loss over the predicted reward.

    Adapted from TorchRL's Dreamer model loss (reference: https://arxiv.org/abs/1912.01603).

    Args:
        world_model (TensorDictModule): the world model.
        lambda_kl (float, optional): the weight of the latent distance loss. Default: 1.0.
        lambda_reward (float, optional): the weight of the reward loss. Default: 1.0.
        lambda_reco, kl_balance, reco_loss, reward_loss, delayed_clamp,
            global_average: kept from the Dreamer loss signature; currently unused.
    """

    @dataclass
    class _AcceptedKeys:
        """Maintains default values for all configurable tensordict keys.

        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
        default values

        Attributes:
            reward (NestedKey): The reward is expected to be in the tensordict
                key ("next", reward). Defaults to ``"reward"``.
            true_reward (NestedKey): The `true_reward` will be stored in the
                tensordict key ("next", true_reward). Defaults to ``"true_reward"``.
            prior (NestedKey): Defaults to ``"prior_logits"``. Currently unused.
            posterior (NestedKey): Defaults to ``"posterior_logits"``. Currently unused.
            pixels (NestedKey): The pixels is expected to be in the tensordict key ("next", pixels).
                Defaults to ``"pixels"``.
            reco_pixels (NestedKey): The reconstruction pixels is expected to be
                in the tensordict key ("next", reco_pixels). Defaults to ``"reco_pixels"``.
        """

        reward: NestedKey = "reward"
        true_reward: NestedKey = "true_reward"
        prior: NestedKey = "prior_logits"
        posterior: NestedKey = "posterior_logits"
        pixels: NestedKey = "pixels"
        reco_pixels: NestedKey = "reco_pixels"

    default_keys = _AcceptedKeys()

    def __init__(
        self,
        world_model: TensorDictModule,
        *,
        lambda_kl: float = 1.0,
        lambda_reco: float = 1.0,
        lambda_reward: float = 1.0,
        kl_balance: float = 0.8,
        reco_loss: Optional[str] = None,
        reward_loss: Optional[str] = None,
        delayed_clamp: bool = False,
        global_average: bool = False,
    ):
        super().__init__()
        self.world_model = world_model
        self.lambda_kl = lambda_kl
        self.lambda_reward = lambda_reward

    def _forward_value_estimator_keys(self, **kwargs) -> None:
        pass

    def forward(self, tensordict: TensorDict):
        tensordict = tensordict.copy()
        tensordict.rename_key_(
            ("next", self.tensor_keys.reward),
            ("next", self.tensor_keys.true_reward),
        )
        tensordict = self.world_model(tensordict)

        # compute model loss: the prediction keeps its gradient, the target
        # (produced by the slowly-updated encoder) is treated as a constant
        latents = tensordict.get(("next", "state"))
        targets = tensordict.get(("next", "state_target")).detach()

        latent_distance_loss = torch.nn.functional.mse_loss(latents, targets)

        reward_model = self.world_model.reward_model
        dist = reward_model.get_dist(tensordict)
        reward_loss = -dist.log_prob(
            tensordict.get(("next", self.tensor_keys.true_reward))
        ).mean()

        return (
            TensorDict(
                {
                    "loss_model_distance": self.lambda_kl * latent_distance_loss,
                    "loss_model_reward": self.lambda_reward * reward_loss,
                },
                [],
            ),
            tensordict.detach(),
        )
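
The reward loss above is the negative log-likelihood of the observed reward under a normal distribution with unit scale. For a scalar reward this equals half the squared error plus a constant, as the following sketch checks numerically. The ``unit_normal_nll`` helper is hypothetical and written with the standard library only.

```python
import math


def unit_normal_nll(target, loc):
    """Negative log-density of ``target`` under a Normal(loc, 1)."""
    return 0.5 * (target - loc) ** 2 + 0.5 * math.log(2 * math.pi)


nll_value = unit_normal_nll(target=1.0, loc=0.5)
half_squared_error = 0.5 * (1.0 - 0.5) ** 2
# what remains after removing the squared-error term does not depend on ``loc``
nll_constant = nll_value - half_squared_error
```

Minimizing this loss therefore has the same minimizer as minimizing the mean squared error between predicted and observed rewards.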

## Hyperparameters

Let's start with our hyperparameters. The following setting should work well
in practice, and the performance of the algorithm should hopefully not be
too sensitive to slight variations of these.



In [None]:
is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)

### Optimizer



In [None]:
# the learning rate of the optimizer
lr = 2e-3
# weight decay
wd = 1e-5
# the beta parameters of Adam
betas = (0.9, 0.999)
# Optimization steps per batch collected (aka UPD or updates per data)
n_optim = 8

### DQN parameters

The ``gamma`` discount (decay) factor:



In [None]:
gamma = 0.99

Smooth target network update decay parameter.
This loosely corresponds to a 1/tau interval with hard target network
update



In [None]:
tau = 0.02
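
To see why ``tau`` loosely corresponds to a ``1/tau`` interval, assume the soft update has the usual form ``target = (1 - tau) * target + tau * online``. After ``1/tau`` updates towards fixed online parameters, roughly 63% of the initial gap has been closed. The ``replaced_fraction`` helper below is a hypothetical illustration of that arithmetic.

```python
def replaced_fraction(soft_tau, num_updates):
    """Share of the initial target-online gap closed after ``num_updates``."""
    # each soft update keeps a (1 - soft_tau) share of the remaining gap
    return 1.0 - (1.0 - soft_tau) ** num_updates


updates_in_interval = round(1 / 0.02)  # 50 updates
fraction_after_interval = replaced_fraction(0.02, updates_in_interval)
```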

### Data collection and replay buffer

<div class="alert alert-info"><h4>Note</h4><p>Values to be used for proper training have been commented.</p></div>

Total frames collected in the environment. In other implementations, the
user defines a maximum number of episodes.
This is harder to do with our data collectors since they return batches
of N collected frames, where N is a constant.
However, one can easily get the same restriction on the number of episodes by
breaking the training loop when a certain number of
episodes has been collected.



In [None]:
total_frames = 5_000  # 500000
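
Stopping on an episode budget rather than a frame budget amounts to counting the ``done`` flags of each collected batch. The sketch below uses plain lists of booleans in place of collected batches; the ``frames_until_episode_budget`` helper is hypothetical.

```python
def frames_until_episode_budget(batches, max_episodes):
    """Consume batches of done flags until ``max_episodes`` are completed."""
    episodes = 0
    frames = 0
    for done_flags in batches:
        frames += len(done_flags)
        episodes += sum(done_flags)  # each True marks the end of an episode
        if episodes >= max_episodes:
            break
    return frames, episodes


done_batches = [
    [False, False, True, False],
    [False, True, False, True],
    [True, False, False, False],
]
frames_used, episodes_seen = frames_until_episode_budget(done_batches, max_episodes=3)
```

Because the check happens once per batch, the loop can overshoot the budget by the number of episodes that end within the last batch.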

Random frames used to initialize the replay buffer.



In [None]:
init_random_frames = 100  # 1000

Frames in each batch collected.



In [None]:
frames_per_batch = 32  # 128

Frames sampled from the replay buffer at each optimization step



In [None]:
batch_size = 32  # 256

Size of the replay buffer in terms of frames



In [None]:
buffer_size = min(total_frames, 100000)

Number of environments run in parallel in each data collector



In [None]:
num_workers = 2  # 8
num_collectors = 2  # 4

### Environment and exploration

We set the initial and final value of the epsilon factor in Epsilon-greedy
exploration.
Since our policy is deterministic, exploration is crucial: without it, the
only source of randomness would be the environment reset.



In [None]:
eps_greedy_val = 0.1
eps_greedy_val_env = 0.005
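
A linear annealing schedule between the initial and final epsilon values can be sketched as follows. The ``annealed_epsilon`` helper and the ``annealing_steps`` value are hypothetical and only illustrate the interpolation; the schedule actually applied depends on the exploration module.

```python
def annealed_epsilon(step, eps_init, eps_end, annealing_steps):
    """Linearly interpolate from ``eps_init`` to ``eps_end``."""
    # the fraction is capped at 1 so epsilon stays at eps_end afterwards
    fraction = min(step / annealing_steps, 1.0)
    return eps_init + fraction * (eps_end - eps_init)


eps_start = annealed_epsilon(0, eps_init=0.1, eps_end=0.005, annealing_steps=1000)
eps_midway = annealed_epsilon(500, eps_init=0.1, eps_end=0.005, annealing_steps=1000)
eps_final = annealed_epsilon(5000, eps_init=0.1, eps_end=0.005, annealing_steps=1000)
```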

To speed up learning, we set the bias of the last layer of our value network
to a predefined value (this is not mandatory)



In [None]:
init_bias = 2.0

<div class="alert alert-info"><h4>Note</h4><p>For fast rendering of the tutorial ``total_frames`` hyperparameter
  was set to a very low number. To get a reasonable performance, use a greater
  value e.g. 500000</p></div>




## Building a Trainer

TorchRL's :class:`~torchrl.trainers.Trainer` class constructor takes the
following keyword-only arguments:

- ``collector``
- ``loss_module``
- ``optimizer``
- ``logger``: an optional logger, such as the
  :class:`~torchrl.record.loggers.csv.CSVLogger` used below;
- ``total_frames``: this parameter defines the lifespan of the trainer.
- ``frame_skip``: when a frame-skip is used, the collector must be made
  aware of it in order to accurately count the number of frames
  collected etc. Making the trainer aware of this parameter is not
  mandatory but helps to have a fairer comparison between settings where
  the total number of frames (budget) is fixed but the frame-skip is
  variable.



In [None]:
stats = get_norm_stats()
test_env = make_env(parallel=False, obs_norm_sd=stats)
# Get model
actor, actor_explore = make_model(test_env)
loss_module, target_net_updater = get_loss_module(actor, gamma)

collector = get_collector(
    stats=stats,
    num_collectors=num_collectors,
    actor_explore=actor_explore,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    device=device,
)
optimizer = torch.optim.Adam(
    loss_module.parameters(), lr=lr, weight_decay=wd, betas=betas
)
exp_name = f"dqn_exp_{uuid.uuid1()}"
tmpdir = tempfile.TemporaryDirectory()
logger = CSVLogger(exp_name=exp_name, log_dir=tmpdir.name)
warnings.warn(f"log dir: {logger.experiment.log_dir}")

We can control how often the scalars should be logged. Here we set this
to a low value as our training loop is short:



In [None]:
log_interval = 500

trainer = Trainer(
    collector=collector,
    total_frames=total_frames,
    frame_skip=1,
    loss_module=loss_module,
    optimizer=optimizer,
    logger=logger,
    optim_steps_per_batch=n_optim,
    log_interval=log_interval,
)

### Registering hooks

Registering hooks can be achieved in two separate ways:

- If the hook has it, the :meth:`~torchrl.trainers.TrainerHookBase.register`
  method is the first choice. One just needs to provide the trainer as input
  and the hook will be registered with a default name at a default location.
  For some hooks, the registration can be quite complex: :class:`~torchrl.trainers.ReplayBufferTrainer`
  requires 3 hooks (``extend``, ``sample`` and ``update_priority``) which
  can be cumbersome to implement.



In [None]:
buffer_hook = ReplayBufferTrainer(
    get_replay_buffer(buffer_size, n_optim, batch_size=batch_size),
    flatten_tensordicts=True,
)
buffer_hook.register(trainer)
weight_updater = UpdateWeights(collector, update_weights_interval=1)
weight_updater.register(trainer)
recorder = Recorder(
    record_interval=100,  # log every 100 optimization steps
    record_frames=1000,  # maximum number of frames in the record
    frame_skip=1,
    policy_exploration=actor_explore,
    environment=test_env,
    exploration_type=ExplorationType.MODE,
    log_keys=[("next", "reward")],
    out_keys={("next", "reward"): "rewards"},
    log_pbar=True,
)
recorder.register(trainer)

The exploration module epsilon factor is also annealed:




In [None]:
trainer.register_op("post_steps", actor_explore[1].step, frames=frames_per_batch)

- Any callable (including :class:`~torchrl.trainers.TrainerHookBase`
  subclasses) can be registered using :meth:`~torchrl.trainers.Trainer.register_op`.
  In this case, a location must be explicitly passed (for example
  ``"post_steps"`` or ``"post_optim"``, as in the cells of this section). This
  method gives more control over the location of the hook but it also requires
  more understanding of the Trainer mechanism.
  Check the [trainer documentation](https://pytorch.org/rl/reference/trainers.html)
  for a detailed description of the trainer hooks.
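
The idea of registering callables at named locations can be sketched with a small stand-in class. ``MiniTrainer`` is hypothetical and much simpler than TorchRL's ``Trainer``; it assumes that ``"post_optim"`` hooks run after every optimization step and ``"post_steps"`` hooks run once per collected batch.

```python
from collections import defaultdict


class MiniTrainer:
    """Hypothetical trainer that runs hooks at named locations."""

    def __init__(self):
        self._ops = defaultdict(list)

    def register_op(self, dest, op, **kwargs):
        # hooks registered at the same location run in registration order
        self._ops[dest].append((op, kwargs))

    def _run(self, dest):
        for op, kwargs in self._ops[dest]:
            op(**kwargs)

    def train(self, num_batches, optim_steps_per_batch):
        for _ in range(num_batches):
            for _ in range(optim_steps_per_batch):
                self._run("post_optim")
            self._run("post_steps")


hook_calls = []
mini_trainer = MiniTrainer()
mini_trainer.register_op("post_optim", lambda: hook_calls.append("target_update"))
mini_trainer.register_op(
    "post_steps", lambda frames: hook_calls.append(("anneal", frames)), frames=32
)
mini_trainer.train(num_batches=2, optim_steps_per_batch=2)
```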




In [None]:
trainer.register_op("post_optim", target_net_updater.step)

We can log the training rewards too. With MountainCarContinuous, each step
incurs a small penalty proportional to the squared action, and reaching the
goal yields a large bonus, so the logged reward mostly reflects whether the
car reaches the flag.
This will be reflected by the `total_rewards` value displayed in the
progress bar.




In [None]:
log_reward = LogReward(log_pbar=True)
log_reward.register(trainer)

<div class="alert alert-info"><h4>Note</h4><p>It is possible to link multiple optimizers to the trainer if needed.
  In this case, each optimizer will be tied to a field in the loss
  dictionary.
  Check the :class:`~torchrl.trainers.OptimizerHook` to learn more.</p></div>

Here we are, ready to train our algorithm! A simple call to
``trainer.train()`` and we'll be getting our results logged in.




In [None]:
trainer.train()

We can now quickly check the CSVs with the results.



In [None]:
def print_csv_files_in_folder(folder_path):
    """
    Find all CSV files in a folder and prints the first 10 lines of each file.

    Args:
        folder_path (str): The relative path to the folder.

    """
    csv_files = []
    output_str = ""
    for dirpath, _, filenames in os.walk(folder_path):
        for file in filenames:
            if file.endswith(".csv"):
                csv_files.append(os.path.join(dirpath, file))
    for csv_file in csv_files:
        output_str += f"File: {csv_file}\n"
        with open(csv_file, "r") as f:
            for i, line in enumerate(f):
                if i == 10:
                    break
                output_str += line.strip() + "\n"
        output_str += "\n"
    print(output_str)


print_csv_files_in_folder(logger.experiment.log_dir)

## Conclusion and possible improvements

In this tutorial we have learned:

- How to write a Trainer, including building its components and registering
  them in the trainer;
- How to code a DQN algorithm, including how to create a policy that picks
  up the action with the highest value with
  :class:`~torchrl.modules.QValueActor`;
- How to build a multiprocessed data collector.

Possible improvements to this tutorial could include:

- A prioritized replay buffer could also be used. This will give a
  higher priority to samples that have the worst value accuracy.
  Learn more on the
  [replay buffer section](https://pytorch.org/rl/reference/data.html#composable-replay-buffers)
  of the documentation.
- A distributional loss (see :class:`~torchrl.objectives.DistributionalDQNLoss`
  for more information).
- Fancier exploration techniques, such as :class:`~torchrl.modules.NoisyLinear` layers.



In [None]:
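from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
from torchrl.data import CompositeSpec
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator


class ActorCritic(nn.Module):
  """
  Actor critic module

  policy_module
  value_module
  """

  def __init__(self, proof_environment):
    super().__init__()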
    # Define input/output distributions
    input_shape = proof_environment.observation_spec["observation"].shape
    num_outputs = proof_environment.action_spec.shape[-1]
    distribution_class = TanhNormal
    distribution_kwargs = {
        "min": proof_environment.action_spec.space.low,
        "max": proof_environment.action_spec.space.high,
        "tanh_loc": False,
    }

    # Define policy architecture
    policy_mlp = MLP(
        in_features=input_shape[-1],
        activation_class=torch.nn.Tanh,
        out_features=num_outputs,  # predict only loc
        num_cells=[64, 64],
    )
    for layer in policy_mlp.modules():
        if isinstance(layer, torch.nn.Linear):
            torch.nn.init.orthogonal_(layer.weight, 1.0)
            layer.bias.data.zero_()

    # Add state-independent normal scale
    policy_mlp = torch.nn.Sequential(
        policy_mlp,
        AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]),
    )

    # Add probabilistic sampling of the actions
    self.policy_module = ProbabilisticActor(
        TensorDictModule(
            module=policy_mlp,
            in_keys=["observation"],
            out_keys=["loc", "scale"],
        ),
        in_keys=["loc", "scale"],
        spec=CompositeSpec(action=proof_environment.action_spec),
        distribution_class=distribution_class,
        distribution_kwargs=distribution_kwargs,
        return_log_prob=True,
        default_interaction_type=ExplorationType.RANDOM,
    )

    # Define value architecture
    value_mlp = MLP(
        in_features=input_shape[-1],
        activation_class=torch.nn.Tanh,
        out_features=1,
        num_cells=[64, 64],
    )
    for layer in value_mlp.modules():
        if isinstance(layer, torch.nn.Linear):
            torch.nn.init.orthogonal_(layer.weight, 0.01)
            layer.bias.data.zero_()

    # Define value module
    self.value_module = ValueOperator(
        value_mlp,
        in_keys=["observation"],
    )

In [None]:
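class Plan2ExploreModule(nn.Module):
  """
  Intrinsic reward from ensemble disagreement (Plan2Explore)

  num_models and out_features replace the ``config`` entries of the
  reference code and are assumptions of this port
  """

  def __init__(self, reward_scale, use_log, num_models, out_features, **head_kwargs):
    super().__init__()
    # a ModuleList registers the heads so that their parameters are trainable
    self._networks = nn.ModuleList(
        [MLP(out_features=out_features, **head_kwargs) for _ in range(num_models)]
    )
    self.use_log = use_log
    self.reward_scale = reward_scale

  def forward(self, state, action):
    state_action = torch.cat([state, action], dim=-1)
    # stack the head predictions along a new leading "ensemble" dimension
    preds = torch.stack([head(state_action) for head in self._networks], dim=0)
    disagreement = preds.std(0).mean(-1)

    if self.use_log:
      disagreement = torch.log(disagreement)

    # TODO: normalize the intrinsic reward (``intr_rewnorm`` in the reference code)
    reward = self.reward_scale * disagreement
    #if self.config.expl_extr_scale:
    #  reward += self.config.expl_extr_scale * self.extr_rewnorm(
    #      self.reward(seq))[0]
    return reward


class Plan2ExploreLoss(nn.Module):
  """
  Training loss of the ensemble, ported from the TensorFlow pseudocode
  """

  def __init__(self, networks, disag_offset=0):
    super().__init__()
    self._networks = networks
    self.disag_offset = disag_offset

  def forward(self, inputs, targets):
    if self.disag_offset:
      targets = targets[:, self.disag_offset:]
      inputs = inputs[:, :-self.disag_offset]
    # the ensemble is trained on fixed inputs and targets
    targets = targets.detach()
    inputs = inputs.detach()
    preds = [head(inputs) for head in self._networks]
    # the MLP heads output tensors rather than distributions, so a squared
    # error stands in for the negative log-likelihood of the pseudocode
    loss = sum(((pred - targets) ** 2).mean() for pred in preds)
    return loss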
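
The disagreement signal used as intrinsic reward is the standard deviation of the ensemble predictions, averaged over the feature dimension. A library-free sketch with a hypothetical ``ensemble_disagreement`` helper shows the computation; it uses the sample standard deviation, which is an assumption chosen to match the default of ``torch.std``.

```python
from statistics import fmean, stdev


def ensemble_disagreement(predictions):
    """Mean over features of the standard deviation across ensemble members.

    ``predictions`` is a list with one feature vector per ensemble member.
    """
    num_features = len(predictions[0])
    per_feature_std = [
        stdev([member[i] for member in predictions]) for i in range(num_features)
    ]
    return fmean(per_feature_std)


# members disagree on the first feature and agree on the second
disagreeing = ensemble_disagreement([[0.0, 1.0], [2.0, 1.0]])
agreeing = ensemble_disagreement([[1.0, 1.0], [1.0, 1.0]])
```

States where the heads agree yield a reward close to zero, so the agent is drawn towards states the ensemble has not yet learned to predict.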