# Understanding the Learner module for guiding our neural network updates

According to the [Ray RLlib documentation](https://docs.ray.io/en/latest/rllib/rllib-learner.html), the Learner API lets you abstract the training logic of RLModules. It supports both gradient-based and non-gradient-based updates (e.g., Polyak averaging), and it lets you distribute the Learner using distributed data parallel (DDP). The Learner achieves the following:

1. Facilitates gradient-based updates on RLModule.
2. Provides abstractions for non-gradient-based updates such as Polyak averaging.
3. Reports training statistics.
4. Checkpoints the modules and optimizer states for durable training.

In this notebook, we are going to focus on the first point.
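
Although the focus here is on gradient-based updates, the non-gradient update named in point 2 is simple enough to sketch. The snippet below is a minimal, framework-free illustration of Polyak averaging. The function name `polyak_update`, the coefficient `tau`, and the dict-of-floats "weights" are illustrative assumptions, not part of the RLlib API.

```python
def polyak_update(target_params, online_params, tau=0.005):
    """Move each target parameter a fraction `tau` toward its online value.

    Both arguments are plain dicts mapping parameter names to floats
    (hypothetical stand-ins for network weights).
    """
    return {
        name: (1.0 - tau) * target_params[name] + tau * online_params[name]
        for name in target_params
    }


target = {"w": 0.0, "b": 1.0}
online = {"w": 1.0, "b": 3.0}

# With tau=0.5 the target moves halfway toward the online parameters.
target = polyak_update(target, online, tau=0.5)
print(target)  # {'w': 0.5, 'b': 2.0}
```

No gradient is involved: the target parameters are just an exponentially smoothed copy of the online ones.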

In [None]:
from ray.rllib.core.learner import Learner
from ray.rllib.core.learner.torch.torch_learner import TorchLearner
from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner

If the RLModule is responsible for implementing the neural networks of our DRL system, the Learner is responsible for implementing the mechanisms that update those networks according to the chosen RL method (usually expressed through the loss function).

Instead of creating a custom Learner, we take a different approach here and analyze the existing Proximal Policy Optimization (PPO) learner implementation from Ray RLlib. Ray RLlib provides a [Learner base class](https://github.com/ray-project/ray/blob/master/rllib/core/learner/learner.py) that implements the base methods for calculating gradients from a given loss definition and updating the RLModule parameters (the neural networks) accordingly. This class contains a lot of information, but you can focus on a few functions, in particular the abstract functions `compute_gradients()`, `apply_gradients()`, and `compute_loss_for_module()`, which subclasses of the Learner base class implement to handle the gradients and compute the loss values.

Ray RLlib supports two frameworks (PyTorch and TensorFlow) for computing gradient operations on neural networks, so there are two base classes, [TorchLearner](https://github.com/ray-project/ray/blob/master/rllib/core/learner/torch/torch_learner.py) and [TfLearner](https://github.com/ray-project/ray/blob/master/rllib/core/learner/tf/tf_learner.py), that implement these operations for each framework. You can compare the `compute_gradients()` and `apply_gradients()` functions of these classes to see how computing and applying gradient updates differs between the two frameworks. You can also check that neither class implements `compute_loss_for_module()`. Unlike `compute_gradients()` and `apply_gradients()`, which are common to all RL algorithms that train neural networks, `compute_loss_for_module()` depends directly on the RL algorithm, so each algorithm, such as PPO or SAC, needs its own implementation.
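
The division of labor described above (a base class that fixes the update loop, a framework layer that handles gradients, an algorithm layer that supplies only the loss) can be sketched in plain Python. Every class name below is hypothetical and the "framework" uses finite differences instead of automatic differentiation. This is a toy analogy under those assumptions, not RLlib's implementation.

```python
from abc import ABC, abstractmethod


class ToyLearner(ABC):
    """Hypothetical base class: fixes the update loop, leaves details abstract."""

    def __init__(self, weight=0.0, lr=0.1):
        self.weight = weight
        self.lr = lr

    @abstractmethod
    def compute_loss(self, batch):
        """Algorithm-specific: return a scalar loss for the batch."""

    @abstractmethod
    def compute_gradients(self, batch):
        """Framework-specific: return d(loss)/d(weight)."""

    @abstractmethod
    def apply_gradients(self, grad):
        """Framework-specific: change the weight using the gradient."""

    def update(self, batch):
        loss = self.compute_loss(batch)
        self.apply_gradients(self.compute_gradients(batch))
        return loss


class ToyNumericLearner(ToyLearner):
    """Hypothetical 'framework' layer: gradients via central finite differences.

    It still lacks `compute_loss`, so it cannot be instantiated on its own.
    """

    def compute_gradients(self, batch, eps=1e-6):
        original = self.weight
        self.weight = original + eps
        loss_plus = self.compute_loss(batch)
        self.weight = original - eps
        loss_minus = self.compute_loss(batch)
        self.weight = original  # restore the weight before returning
        return (loss_plus - loss_minus) / (2 * eps)

    def apply_gradients(self, grad):
        self.weight -= self.lr * grad  # plain gradient descent step


class ToySquaredErrorLearner(ToyNumericLearner):
    """Hypothetical 'algorithm' layer: only the loss is defined here."""

    def compute_loss(self, batch):
        return sum((self.weight - y) ** 2 for y in batch) / len(batch)


learner = ToySquaredErrorLearner(weight=0.0, lr=0.1)
for _ in range(200):
    loss = learner.update([1.0, 2.0, 3.0])
print(round(learner.weight, 3))
```

Swapping the loss means writing a new subclass of `ToyNumericLearner`. The gradient code stays untouched, which is the same reason the algorithm-specific loss lives in its own class.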

Everything up to this point is common to different RL methods. We now turn to the PPO-specific Learner implemented in the class [PPOTorchLearner](https://github.com/ray-project/ray/blob/master/rllib/algorithms/ppo/torch/ppo_torch_learner.py) for the PyTorch framework. The Ray RLlib code for `PPOTorchLearner` is presented below. Remember that `compute_loss_for_module()` was not implemented in `TorchLearner`? It is implemented in `PPOTorchLearner`. Before reading the code, it helps to recall the PPO loss functions for the Actor and the Critic from the [paper](https://arxiv.org/abs/1707.06347).

The Value Function (VF) loss (Critic loss) is
$$
L_t^{\mathsf{VF}}(W) = \mathbb{E} \left[(V_W(s(t))-R_t)^2\right],
$$
where $V_W(s(t))$ is the critic output (the value-function output), which approximates the true return, and $R_t$ is the return obtained in the batch of experiences.
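
As a small worked example of this formula, the sketch below computes the mean squared error between a few critic predictions and the corresponding returns. The function name `value_function_loss` and the numbers are illustrative assumptions.

```python
def value_function_loss(value_predictions, returns):
    """Mean squared error between critic outputs and observed returns R_t."""
    squared_errors = [(v - r) ** 2 for v, r in zip(value_predictions, returns)]
    return sum(squared_errors) / len(squared_errors)


# Three hypothetical time steps: critic predictions versus observed returns.
predictions = [1.0, 2.0, 4.0]
returns = [1.0, 3.0, 2.0]

# Squared errors are 0, 1 and 4, so the loss is their average, 5/3.
vf_loss = value_function_loss(predictions, returns)
print(vf_loss)
```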

The Actor objective, which is maximized (so its negative enters the loss), is
$$
L_{t}^{\mathsf{clip}}(\theta) = \hat{\mathbb{E}}_t \left[ \min \left( r_t(\theta) \hat A_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat A_t \right) \right],
$$
where the probability ratio
$$
    r_t(\theta)=\frac{\pi_{\theta}(a(t)|s(t))}{\pi_{\theta_{\mathsf{old}}}(a(t)|s(t))}
$$
measures how much the current policy $\pi_{\theta}$ has changed relative to the old policy $\pi_{\theta_{\mathsf{old}}}$. $\hat A_t$ is the generalized advantage estimate, which represents how much better a particular action was compared to the average action at a given state $s(t)$. The clipping range $\epsilon$ is a PPO hyperparameter. You can check the original paper for more information about the loss functions.
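
To make the effect of the clipping concrete, here is a minimal per-sample sketch in plain Python. It forms the probability ratio from log-probabilities and then applies the min/clip rule. The function name `clipped_surrogate` and the example probabilities are assumptions made for illustration.

```python
import math


def clipped_surrogate(logp_new, logp_old, advantage, epsilon=0.2):
    """Per-sample PPO clipped objective computed from log-probabilities."""
    ratio = math.exp(logp_new - logp_old)  # r_t(theta) = pi_new / pi_old
    clipped_ratio = max(1.0 - epsilon, min(ratio, 1.0 + epsilon))
    return min(ratio * advantage, clipped_ratio * advantage)


# The action probability rose from 0.4 to 0.6, so the ratio is about 1.5.
logp_new, logp_old = math.log(0.6), math.log(0.4)

# Positive advantage: the gain is capped at (1 + epsilon) * advantage.
capped = clipped_surrogate(logp_new, logp_old, advantage=2.0)

# Negative advantage: the unclipped, more pessimistic term is kept.
uncapped = clipped_surrogate(logp_new, logp_old, advantage=-2.0)

print(capped, uncapped)
```

With these numbers the first call returns about 2.4 instead of 3.0, while the second returns about -3.0. The clip removes the incentive to push the ratio further in the favorable direction but does not hide the penalty in the unfavorable one.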

**Here we ignore the KL contribution for simplicity.**

Finally, let's look at the PPO Learner implementation for the PyTorch framework (class `PPOTorchLearner`) and try to find the loss functions presented above.

In [None]:
import logging
from typing import Any, Dict

import numpy as np

from ray.rllib.algorithms.ppo.ppo import (
    LEARNER_RESULTS_KL_KEY,
    LEARNER_RESULTS_CURR_KL_COEFF_KEY,
    LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY,
    LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY,
    PPOConfig,
)
from ray.rllib.algorithms.ppo.ppo_learner import PPOLearner
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY
from ray.rllib.core.learner.torch.torch_learner import TorchLearner
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import explained_variance
from ray.rllib.utils.typing import ModuleID, TensorType

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


class PPOTorchLearner(PPOLearner, TorchLearner):
    """Implements torch-specific PPO loss logic on top of PPOLearner.

    This class implements the ppo loss under `self.compute_loss_for_module()`.
    """

    @override(TorchLearner)
    def compute_loss_for_module(  # Here we define the loss function calculation
        self,
        *,
        module_id: ModuleID,
        config: PPOConfig,
        batch: Dict[str, Any],
        fwd_out: Dict[str, TensorType],
    ) -> TensorType:
        module = self.module[
            module_id
        ].unwrapped()  # That's the policy model (RLModule)

        ###### Ignore these implementation details for now ######
        if Columns.LOSS_MASK in batch:
            mask = batch[Columns.LOSS_MASK]
            num_valid = torch.sum(mask)

            def possibly_masked_mean(data_):
                return torch.sum(data_[mask]) / num_valid

        else:
            possibly_masked_mean = torch.mean
        ################################################

        # Build the action distributions for the current policy (from this forward pass)
        # and the previous policy (from the logits stored in the batch at sampling time,
        # before the network parameters were updated).
        action_dist_class_train = module.get_train_action_dist_cls()
        action_dist_class_exploration = module.get_exploration_action_dist_cls()
        curr_action_dist = action_dist_class_train.from_logits(
            fwd_out[Columns.ACTION_DIST_INPUTS]
        )
        prev_action_dist = action_dist_class_exploration.from_logits(
            batch[Columns.ACTION_DIST_INPUTS]
        )

        # Calculate the probability ratio between the current and previous policy, that is our r_t(\theta).
        # It is computed in log space, since A/B is the same as exp(log(A) - log(B)).
        logp_ratio = torch.exp(
            curr_action_dist.logp(batch[Columns.ACTIONS]) - batch[Columns.ACTION_LOGP]
        )

        ##### The PPO loss can also include a KL term; we are ignoring it for now (config.use_kl_loss=False)
        # Only calculate kl loss if necessary (kl-coeff > 0.0).
        if config.use_kl_loss:
            action_kl = prev_action_dist.kl(curr_action_dist)
            mean_kl_loss = possibly_masked_mean(action_kl)
        else:
            mean_kl_loss = torch.tensor(0.0, device=logp_ratio.device)
        #################################

        # Usually the entropy is used to encourage exploration. I omitted it in the previous explanation to simplify the understanding.
        curr_entropy = curr_action_dist.entropy()
        mean_entropy = possibly_masked_mean(curr_entropy)

        # Compute the surrogate loss, that's our L^{CLIP}(\theta) in the PPO paper
        surrogate_loss = torch.min(
            batch[Postprocessing.ADVANTAGES] * logp_ratio,
            batch[Postprocessing.ADVANTAGES]
            * torch.clamp(logp_ratio, 1 - config.clip_param, 1 + config.clip_param),
        )  # This is the clipped version of the surrogate loss

        # Compute a value function loss.
        if config.use_critic:
            # Here we calculate our value function loss (L^{VF}(W))
            value_fn_out = module.compute_values(
                batch, embeddings=fwd_out.get(Columns.EMBEDDINGS)
            )
            vf_loss = torch.pow(
                value_fn_out - batch[Postprocessing.VALUE_TARGETS], 2.0
            )  # Here we are calculating the squared error between the value function output and the target value
            vf_loss_clipped = torch.clamp(
                vf_loss, 0, config.vf_clip_param
            )  # There is a parameter to also clip the value function loss
            mean_vf_loss = possibly_masked_mean(vf_loss_clipped)
            mean_vf_unclipped_loss = possibly_masked_mean(vf_loss)
        # Ignore the value function -> Set all to 0.0.
        else:
            z = torch.tensor(0.0, device=surrogate_loss.device)
            value_fn_out = mean_vf_unclipped_loss = vf_loss_clipped = mean_vf_loss = z

        # Finally, our total loss (actor and critic losses plus the entropy contribution) is calculated here
        total_loss = possibly_masked_mean(
            -surrogate_loss
            + config.vf_loss_coeff * vf_loss_clipped
            - (
                self.entropy_coeff_schedulers_per_module[module_id].get_current_value()
                * curr_entropy
            )
        )

        ###### We are ignoring the code below since it is related to the KL contribution, which we disabled by setting use_kl_loss=False

        # Add mean_kl_loss (already processed through `possibly_masked_mean`),
        # if necessary.
        if config.use_kl_loss:
            total_loss += self.curr_kl_coeffs_per_module[module_id] * mean_kl_loss

        # Log important loss stats.
        self.metrics.log_dict(
            {
                POLICY_LOSS_KEY: -possibly_masked_mean(surrogate_loss),
                VF_LOSS_KEY: mean_vf_loss,
                LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY: mean_vf_unclipped_loss,
                LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY: explained_variance(
                    batch[Postprocessing.VALUE_TARGETS], value_fn_out
                ),
                ENTROPY_KEY: mean_entropy,
                LEARNER_RESULTS_KL_KEY: mean_kl_loss,
            },
            key=module_id,
            window=1,  # <- single items (should not be mean/ema-reduced over time).
        )
        # Return the total loss.
        return total_loss

    @override(PPOLearner)
    def _update_module_kl_coeff(
        self,
        *,
        module_id: ModuleID,
        config: PPOConfig,
        kl_loss: float,
    ) -> None:
        if np.isnan(kl_loss):
            logger.warning(
                f"KL divergence for Module {module_id} is non-finite, this "
                "will likely destabilize your model and the training "
                "process. Action(s) in a specific state have near-zero "
                "probability. This can happen naturally in deterministic "
                "environments where the optimal policy has zero mass for a "
                "specific action. To fix this issue, consider setting "
                "`kl_coeff` to 0.0 or increasing `entropy_coeff` in your "
                "config."
            )

        # Update the KL coefficient.
        curr_var = self.curr_kl_coeffs_per_module[module_id]
        if kl_loss > 2.0 * config.kl_target:
            # TODO (Kourosh) why not 2?
            curr_var.data *= 1.5
        elif kl_loss < 0.5 * config.kl_target:
            curr_var.data *= 0.5

        # Log the updated KL-coeff value.
        self.metrics.log_value(
            (module_id, LEARNER_RESULTS_CURR_KL_COEFF_KEY),
            curr_var.item(),
            window=1,
        )
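
The adaptive rule inside `_update_module_kl_coeff` is easier to see in isolation. The sketch below restates the same thresholds and multipliers as a pure function on floats. The name `adapt_kl_coeff` and the example values are hypothetical. The real method mutates a tensor in place and logs the result.

```python
def adapt_kl_coeff(kl_coeff, kl_loss, kl_target):
    """Return the KL coefficient after one adaptation step.

    Uses the same thresholds as the listing: grow by 1.5x when the measured
    KL exceeds twice the target, halve when it is below half the target.
    """
    if kl_loss > 2.0 * kl_target:
        return kl_coeff * 1.5  # policy moved too far: penalize KL more
    if kl_loss < 0.5 * kl_target:
        return kl_coeff * 0.5  # policy barely moved: relax the penalty
    return kl_coeff  # within the tolerated band: leave unchanged


kl_target = 0.01
increased = adapt_kl_coeff(0.2, kl_loss=0.05, kl_target=kl_target)
decreased = adapt_kl_coeff(0.2, kl_loss=0.001, kl_target=kl_target)
unchanged = adapt_kl_coeff(0.2, kl_loss=0.01, kl_target=kl_target)
print(increased, decreased, unchanged)
```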

Having walked through the PPO Learner in detail and seen how it implements the theory from the paper in Ray RLlib, you now have the knowledge to implement different loss functions for PPO in order to improve its performance, or to find out that the defined loss function is very good and hard to outperform :)
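
As a starting point for such experiments, the way `compute_loss_for_module()` assembles the total loss can be reproduced on plain numbers. The sketch below combines per-sample surrogate, value-function and entropy terms and then takes an optionally masked mean. The function names and the coefficient values are illustrative assumptions and are not necessarily RLlib defaults.

```python
def masked_mean(values, mask=None):
    """Average `values`, skipping entries whose mask flag is False."""
    if mask is None:
        return sum(values) / len(values)
    kept = [v for v, keep in zip(values, mask) if keep]
    return sum(kept) / len(kept)


def total_ppo_loss(surrogate, vf_loss, entropy, vf_loss_coeff=0.5,
                   vf_clip_param=10.0, entropy_coeff=0.01, mask=None):
    """Per-sample: -surrogate + c_vf * clip(vf_loss) - c_ent * entropy, then mean."""
    per_sample = [
        -s + vf_loss_coeff * min(max(v, 0.0), vf_clip_param) - entropy_coeff * h
        for s, v, h in zip(surrogate, vf_loss, entropy)
    ]
    return masked_mean(per_sample, mask)


# Two hypothetical samples; the second value loss (20.0) is clipped to 10.0.
surrogate = [1.0, 2.0]
vf_loss = [4.0, 20.0]
entropy = [1.0, 1.0]

full = total_ppo_loss(surrogate, vf_loss, entropy)
first_only = total_ppo_loss(surrogate, vf_loss, entropy, mask=[True, False])
print(full, first_only)
```

The signs follow the listing: the surrogate and the entropy are quantities to maximize, so they enter with a minus sign, while the value-function error is minimized directly.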