In [None]:
# default_exp losses

# losses

> This module defines losses for a variety of RL agents.

In [None]:
#hide
from nbdev import *
%load_ext autoreload
%autoreload 2

In [None]:
%nbdev_export
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional, Union

In [None]:
%nbdev_export
def actor_critic_value_loss(value_estimates: torch.Tensor, env_returns: torch.Tensor) -> torch.Tensor:
    """
    Critic (state-value) loss for actor-critic agents.

    Computes the mean-squared error between the critic's value estimates and the
    returns observed from the environment.

    Args:
    - value_estimates (torch.Tensor): State-value estimates produced by the critic network.
    - env_returns (torch.Tensor): Returns collected from the environment.

    Returns:
    - value_loss (torch.Tensor): Mean-squared error between the estimates and the returns.
    """
    # Functional form of nn.MSELoss() with its default "mean" reduction.
    return F.mse_loss(value_estimates, env_returns, reduction="mean")

In [None]:
show_doc(actor_critic_value_loss)

<h4 id="actor_critic_value_loss" class="doc_header"><code>actor_critic_value_loss</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>actor_critic_value_loss</code>(**`value_estimates`**:`Tensor`, **`env_returns`**:`Tensor`)

Loss for an actor-critic value function.

Is just Mean-Squared-Error between the value estimates and the real returns.

Args:
- value_estimates (torch.Tensor): Estimates of state-value from the critic network.
- env_returns (torch.Tensor): Real returns from the environment.

Returns:
- value_loss (torch.Tensor): MSE loss between the estimates and real returns.

In [None]:
#hide
# Identical estimates and returns must give a loss of exactly zero.
vest = torch.tensor([0.])
rtrue = torch.tensor([0.])
value_loss = actor_critic_value_loss(vest, rtrue)
assert value_loss is not None, "Val loss fails to return proper value"
assert value_loss == torch.Tensor([0.]), "Val loss is calculated incorrectly."

In [None]:
%nbdev_export
def reinforce_policy_loss(logps: torch.Tensor, env_returns: torch.Tensor) -> torch.Tensor:
    r"""
    REINFORCE policy-gradient loss, $-(log(\pi(a | s)) * R_t)$, averaged over all elements.

    Args:
    - logps (torch.Tensor): Log-probabilities of the actions.
    - env_returns (torch.Tensor): Returns from the environment.

    Returns:
    - reinforce_loss (torch.Tensor): REINFORCE loss term.
    """
    weighted_logps = logps * env_returns
    # Negated so that minimising the loss increases the return-weighted log-probability.
    return torch.neg(weighted_logps.mean())


In [None]:
show_doc(reinforce_policy_loss)

<h4 id="reinforce_policy_loss" class="doc_header"><code>reinforce_policy_loss</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>reinforce_policy_loss</code>(**`logps`**:`Tensor`, **`env_returns`**:`Tensor`)

Reinforce Policy gradient loss. $-(log(\pi(a | s)) * R_t)$

Args:
- logps (PyTorch Tensor): Action log probabilities.
- env_returns (PyTorch Tensor): Returns from the environment.

Returns:
- reinforce_loss (torch.Tensor): REINFORCE loss term.

In [None]:
#hide
# Shared fixtures: the A2C and PPO check cells further down reuse these two tensors.
tmp_logp = torch.tensor([-0.3])
tmp_ret = torch.tensor([10.])
# -(-0.3 * 10) == 3. Compare numerically; a bare `is not None` check cannot fail
# for a function that always returns a tensor.
assert torch.isclose(reinforce_policy_loss(tmp_logp, tmp_ret), torch.tensor(3.0)), "REINFORCE loss is calculated incorrectly."

In [None]:
%nbdev_export
def a2c_policy_loss(logps: torch.Tensor, advs: torch.Tensor) -> torch.Tensor:
    r"""
    Loss function for an A2C policy. $-(logp(\pi(a|s)) * A_t)$

    The product of log-probabilities and advantages is averaged over all elements
    and negated.

    Args:
    - logps (torch.Tensor): Log-probabilities of selected actions.
    - advs (torch.Tensor): Advantage estimates of selected actions.

    Returns:
    - a2c_loss (torch.Tensor): A2C loss term.
    """
    # The docstring is a raw string because "\pi" is not a valid escape sequence
    # in a regular string literal and would trigger an invalid-escape warning.
    a2c_loss = -(logps * advs).mean()
    return a2c_loss

In [None]:
show_doc(a2c_policy_loss)

<h4 id="a2c_policy_loss" class="doc_header"><code>a2c_policy_loss</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>a2c_policy_loss</code>(**`logps`**:`Tensor`, **`advs`**:`Tensor`)

Loss function for an A2C policy. $-(logp(\pi(a|s)) * A_t)$

Args:
- logps (torch.Tensor): Log-probabilities of selected actions.
- advs (torch.Tensor): Advantage estimates of selected actions.

Returns:
- a2c_loss (torch.Tensor): A2C loss term.

In [None]:
#hide
# Reuses tmp_logp / tmp_ret from the REINFORCE check cell: -(-0.3 * 10) == 3.
# A numeric comparison replaces the `is not None` check, which could never fail.
assert torch.isclose(a2c_policy_loss(tmp_logp, tmp_ret), torch.tensor(3.0)), "A2C loss is calculated incorrectly."

In [None]:
%nbdev_export
def ppo_clip_policy_loss(
    logps: torch.Tensor,
    logps_old: torch.Tensor,
    advs: torch.Tensor,
    clipratio: Optional[float] = 0.2
    ) -> Tuple[torch.Tensor, float]:
    """
    Loss function for a PPO-clip policy.
    See paper for full loss function math: https://arxiv.org/abs/1707.06347

    Args:
    - logps (torch.Tensor): Action log-probabilities under the current policy.
    - logps_old (torch.Tensor): Action log-probabilities under the old (pre-update) policy.
    - advs (torch.Tensor): Advantage estimates for the actions taken.
    - clipratio (float): Clipping parameter for PPO-clip loss. In general, is fine with being left as default.

    Returns:
    - ppo_loss (torch.Tensor): Loss term for PPO agent.
    - kl (float): Sample estimate of the KL-divergence between old and new policies,
      computed as mean(logps_old - logps) and returned as a plain Python float
      (it carries no gradient).
    """
    # Probability ratio pi_new(a|s) / pi_old(a|s), computed in log space.
    policy_ratio = torch.exp(logps - logps_old)
    # Ratio restricted to [1 - clipratio, 1 + clipratio] before weighting by the advantage.
    clipped_adv = torch.clamp(policy_ratio, 1 - clipratio, 1 + clipratio) * advs
    # Element-wise minimum of the unclipped and clipped objectives, negated for minimisation.
    ppo_loss = -(torch.min(policy_ratio * advs, clipped_adv)).mean()

    # .item() turns the scalar tensor into a float, so this is for logging only.
    kl = (logps_old - logps).mean().item()
    return ppo_loss, kl

In [None]:
show_doc(ppo_clip_policy_loss)

<h4 id="ppo_clip_policy_loss" class="doc_header"><code>ppo_clip_policy_loss</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>ppo_clip_policy_loss</code>(**`logps`**:`Tensor`, **`logps_old`**:`Tensor`, **`advs`**:`Tensor`, **`clipratio`**:`Optional`\[`float`\]=*`0.2`*)

Loss function for a PPO-clip policy. 
See paper for full loss function math: https://arxiv.org/abs/1707.06347

Args:
- logps (torch.Tensor): Action log-probabilities under the current policy.
- logps_old (torch.Tensor): Action log-probabilities under the old (pre-update) policy.
- advs (torch.Tensor): Advantage estimates for the actions taken.
- clipratio (float): Clipping parameter for PPO-clip loss. In general, is fine with being left as default.

Returns:
- ppo_loss (torch.Tensor): Loss term for PPO agent.
- kl (float): KL-divergence estimate between new and old policies, returned as a Python float.

In [None]:
#hide
tmp_logp_old = torch.tensor([-0.2])
# The function returns a (loss, kl) tuple, so an `is not None` check can never fail;
# unpack the result and check both values instead.
tmp_ppo_loss, tmp_kl = ppo_clip_policy_loss(tmp_logp, tmp_logp_old, tmp_ret)
# ratio = exp(-0.3 - (-0.2)) lies inside the default clip range, so loss = -ratio * advantage.
assert torch.isclose(tmp_ppo_loss, -torch.exp(torch.tensor(-0.1)) * 10.), "PPO-clip loss is calculated incorrectly."
# kl = mean(logps_old - logps) = -0.2 - (-0.3) = 0.1
assert abs(tmp_kl - 0.1) < 1e-6, "PPO KL estimate is calculated incorrectly."

In [None]:
%nbdev_export
def ddpg_policy_loss(states: torch.Tensor, qfunc: nn.Module, policy: nn.Module):
    """
    DDPG policy loss: the negated mean Q-value of the policy's own actions.
    See the paper: https://arxiv.org/abs/1509.02971

    Args:
    - states (torch.Tensor): States to evaluate the policy on.
    - qfunc (nn.Module): Q-function network, called as qfunc(states, actions).
    - policy (nn.Module): Policy network, called as policy(states).

    Returns:
    - q_policy_loss (torch.Tensor): Loss term for DDPG policy.
    """
    actions = policy(states)
    q_values = qfunc(states, actions)
    # Minimising the negated mean pushes the policy toward higher-valued actions.
    return q_values.mean().neg()


In [None]:
show_doc(ddpg_policy_loss)

<h4 id="ddpg_policy_loss" class="doc_header"><code>ddpg_policy_loss</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>ddpg_policy_loss</code>(**`states`**:`Tensor`, **`qfunc`**:`Module`, **`policy`**:`Module`)

Policy loss function for DDPG agent. See the paper: https://arxiv.org/abs/1509.02971

Args:
- states (torch.Tensor): States to get Q-policy estimates for.
- qfunc (nn.Module): Q-function network.
- policy (nn.Module): Policy network.

Returns:
- q_policy_loss (torch.Tensor): Loss term for DDPG policy.

In [None]:
%nbdev_export
def ddpg_qfunc_loss(
    data: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    qfunc: nn.Module,
    qfunc_target: nn.Module,
    policy_target: nn.Module,
    gamma: Optional[float] = 0.99
    ):
    """
    Loss for a DDPG Q-function. See the paper: https://arxiv.org/abs/1509.02971

    The loss is the mean-squared error between qfunc(states, actions) and the
    Bellman backup built from the target networks.

    Args:
    - data (tuple of torch.Tensor): input data batch. Contains 5 PyTorch Tensors. The tensors contain the
    following: (states, next_states, actions, rewards, dones).
    - qfunc (nn.Module): Q-function network being trained.
    - qfunc_target (nn.Module): Q-function target network.
    - policy_target (nn.Module): Policy target network.
    - gamma (float): Discount factor.

    Returns:
    - loss_q (torch.Tensor): DDPG loss for the Q-function.
    - loss_info (dict): Dictionary containing useful loss info for logging. Holds
      "MeanQValues", the mean Q estimate as a NumPy value.
    """
    o, o2, a, r, d = data

    q = qfunc(o, a)

    # Bellman backup for Q function. Built under no_grad so no gradient flows
    # into the target networks.
    with torch.no_grad():
        q_pi_targ = qfunc_target(o2, policy_target(o2))
        # (1 - d) removes the bootstrapped term wherever d == 1.
        backup = r + gamma * (1 - d) * q_pi_targ

    # MSE loss against Bellman backup
    loss_q = ((q - backup) ** 2).mean()

    # Useful info for logging. .cpu() comes first because Tensor.numpy() raises
    # for tensors that live on a non-CPU device.
    loss_info = dict(MeanQValues=q.mean().detach().cpu().numpy())

    return loss_q, loss_info

In [None]:
show_doc(ddpg_qfunc_loss)

<h4 id="ddpg_qfunc_loss" class="doc_header"><code>ddpg_qfunc_loss</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>ddpg_qfunc_loss</code>(**`data`**:`Tuple`\[`Tensor`, `Tensor`, `Tensor`, `Tensor`, `Tensor`\], **`qfunc`**:`Module`, **`qfunc_target`**:`Module`, **`policy_target`**:`Module`, **`gamma`**:`Optional`\[`float`\]=*`0.99`*)

Loss for a DDPG Q-function. See the paper: https://arxiv.org/abs/1509.02971

Args:
- data (tuple of torch.Tensor): input data batch. Contains 5 PyTorch Tensors. The tensors contain the
following: (states, next_states, actions, rewards, dones).
- qfunc (nn.Module): Q-function network being trained.
- qfunc_target (nn.Module): Q-function target network.
- policy_target (nn.Module): Policy target network.
- gamma (float): Discount factor.

Returns:
- loss_q (torch.Tensor): DDPG loss for the Q-function.
- loss_info (dict): Dictionary containing useful loss info for logging.

In [None]:
%nbdev_export
def td3_policy_loss(states: torch.Tensor, qfunc: nn.Module, policy: nn.Module):
    """
    Calculate policy loss for TD3 agent. See paper here: https://arxiv.org/abs/1802.09477

    The loss is the negated mean of qfunc(states, policy(states)); only the single
    Q-function passed in is used.

    Args:
    - states (torch.Tensor): Input states to get policy loss for.
    - qfunc (nn.Module): TD3 q-function network, called as qfunc(states, actions).
    - policy (nn.Module): Policy network, called as policy(states).

    Returns:
    - q_policy_loss (torch.Tensor): The TD3 policy loss term.
    """
    # Use the `qfunc` argument; the previous body referenced an undefined name
    # `qfunc1` and raised NameError on every call.
    q1_pi = qfunc(states, policy(states))
    q_policy_loss = -q1_pi.mean()
    return q_policy_loss

In [None]:
show_doc(td3_policy_loss)

<h4 id="td3_policy_loss" class="doc_header"><code>td3_policy_loss</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>td3_policy_loss</code>(**`states`**:`Tensor`, **`qfunc`**:`Module`, **`policy`**:`Module`)

Calculate policy loss for TD3 agent. See paper here: https://arxiv.org/abs/1802.09477

Args:
- states (torch.Tensor): Input states to get policy loss for.
- qfunc (nn.Module): TD3 q-function network.
- policy (nn.Module): Policy network.

Returns:
- q_policy_loss (torch.Tensor): The TD3 policy loss term.

In [None]:
%nbdev_export
def td3_qfunc_loss(
    data: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    qfunc1: nn.Module,
    qfunc2: nn.Module,
    qfunc1_target: nn.Module,
    qfunc2_target: nn.Module,
    policy: nn.Module,
    act_limit: Union[float, int],
    target_noise: Optional[float] = 0.2,
    noise_clip: Optional[float] = 0.5,
    gamma: Optional[float] = 0.99,
    ):
    """
    Calculate Q-function loss for TD3 agent. See paper here: https://arxiv.org/abs/1802.09477

    Both Q-functions are regressed onto one shared Bellman backup that uses the
    element-wise minimum of the two target Q-values at a noise-perturbed action.

    Args:
    - data (tuple of torch.Tensor): input data batch. Contains 5 PyTorch Tensors. They are unpacked in the
    following order: (states, actions, rewards, next_states, dones).
    NOTE(review): ddpg_qfunc_loss unpacks its batch as (states, next_states, actions, rewards, dones);
    confirm which ordering the caller actually supplies.
    - qfunc1 (nn.Module): First Q-function network being trained.
    - qfunc2 (nn.Module): Other Q-function network being trained.
    - qfunc1_target (nn.Module): First Q-function target network.
    - qfunc2_target (nn.Module): Other Q-function target network.
    - policy (nn.Module): Policy network used to produce the target actions for next_states.
    - act_limit (float or int): Action limit from the environment; target actions are clamped to
    [-act_limit, act_limit].
    - target_noise (float): Scale of the Gaussian noise added to the target actions.
    - noise_clip (float): Clip the noise within + and - this range.
    - gamma (float): Gamma discount factor.

    Returns:
    - loss_q (torch.Tensor): TD3 loss for the Q-function (sum of the two MSE terms).
    - loss_info (dict): Dictionary containing useful loss info for logging. Holds
      "Q1Values" and "Q2Values" as NumPy arrays.
    """
    o, a, r, o2, d = data

    q1 = qfunc1(o, a)
    q2 = qfunc2(o, a)

    # Bellman backup for Q functions. Built under no_grad so no gradient flows
    # into the policy or the target networks.
    with torch.no_grad():
        pi_targ = policy(o2)

        # Target policy smoothing: add clipped Gaussian noise, then keep the
        # resulting action inside the valid range.
        epsilon = torch.randn_like(pi_targ) * target_noise
        epsilon = torch.clamp(epsilon, -noise_clip, noise_clip)
        a2 = pi_targ + epsilon
        a2 = torch.clamp(a2, -act_limit, act_limit)

        # Target Q-values: take the smaller of the two target estimates.
        q1_pi_targ = qfunc1_target(o2, a2)
        q2_pi_targ = qfunc2_target(o2, a2)
        q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
        # (1 - d) removes the bootstrapped term wherever d == 1.
        backup = r + gamma * (1 - d) * q_pi_targ

    # MSE loss against Bellman backup
    loss_q1 = ((q1 - backup) ** 2).mean()
    loss_q2 = ((q2 - backup) ** 2).mean()
    loss_q = loss_q1 + loss_q2

    # Useful info for logging. .cpu() comes first because Tensor.numpy() raises
    # for tensors that live on a non-CPU device.
    loss_info = dict(
        Q1Values=q1.detach().cpu().numpy(),
        Q2Values=q2.detach().cpu().numpy(),
    )

    return loss_q, loss_info

In [None]:
show_doc(td3_qfunc_loss)

<h4 id="td3_qfunc_loss" class="doc_header"><code>td3_qfunc_loss</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>td3_qfunc_loss</code>(**`data`**:`Tuple`\[`Tensor`, `Tensor`, `Tensor`, `Tensor`, `Tensor`\], **`qfunc1`**:`Module`, **`qfunc2`**:`Module`, **`qfunc1_target`**:`Module`, **`qfunc2_target`**:`Module`, **`policy`**:`Module`, **`act_limit`**:`Union`\[`float`, `int`\], **`target_noise`**:`Optional`\[`float`\]=*`0.2`*, **`noise_clip`**:`Optional`\[`float`\]=*`0.5`*, **`gamma`**:`Optional`\[`float`\]=*`0.99`*)

Calculate Q-function loss for TD3 agent. See paper here: https://arxiv.org/abs/1802.09477

Args:
- data (tuple of torch.Tensor): input data batch. Contains 5 PyTorch Tensors. They are unpacked in the
following order: (states, actions, rewards, next_states, dones).
- qfunc1 (nn.Module): First Q-function network being trained.
- qfunc2 (nn.Module): Other Q-function network being trained.
- qfunc1_target (nn.Module): First Q-function target network.
- qfunc2_target (nn.Module): Other Q-function target network.
- policy (nn.Module): Policy network.
- act_limit (float or int): Action limit from the environment.
- target_noise (float): Noise to apply to policy target network.
- noise_clip (float): Clip the noise within + and - this range.
- gamma (float): Gamma discount factor.

Returns:
- loss_q (torch.Tensor): TD3 loss for the Q-function.
- loss_info (dict): Dictionary containing useful loss info for logging.

In [None]:
%nbdev_export
def sac_policy_loss(
    states: torch.Tensor,
    qfunc1: nn.Module,
    qfunc2: nn.Module,
    policy: nn.Module,
    alpha: Optional[float] = 0.2
    ):
    """
    Calculate policy loss for Soft-Actor Critic agent. See paper here: https://arxiv.org/abs/1801.01290

    The loss is mean(alpha * logp_pi - min(Q1, Q2)), evaluated at actions produced by
    the policy for the given states.

    Args:
    - states (torch.Tensor): Input states for the policy.
    - qfunc1 (nn.Module): First Q-function in SAC agent.
    - qfunc2 (nn.Module): Second Q-function in SAC agent.
    - policy (nn.Module): Policy network. Must return a pair (actions, log-probabilities)
    when called on states.
    - alpha (float): alpha factor for entropy-regularized policy loss.

    Returns:
    - loss_policy (torch.Tensor): The policy loss term.
    - policy_info (dict): Useful logging info for the policy. Holds "PolicyLogP",
      the action log-probabilities as a NumPy array.
    """
    o = states
    pi, logp_pi = policy(o)
    q1_pi = qfunc1(o, pi)
    q2_pi = qfunc2(o, pi)
    # Use the smaller of the two Q estimates for each element.
    q_pi = torch.min(q1_pi, q2_pi)

    # Entropy-regularized policy loss
    loss_policy = (alpha * logp_pi - q_pi).mean()

    # Useful info for logging. .cpu() comes first because Tensor.numpy() raises
    # for tensors that live on a non-CPU device.
    policy_info = dict(PolicyLogP=logp_pi.detach().cpu().numpy())

    return loss_policy, policy_info

In [None]:
show_doc(sac_policy_loss)

<h4 id="sac_policy_loss" class="doc_header"><code>sac_policy_loss</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>sac_policy_loss</code>(**`states`**:`Tensor`, **`qfunc1`**:`Module`, **`qfunc2`**:`Module`, **`policy`**:`Module`, **`alpha`**:`Optional`\[`float`\]=*`0.2`*)

Calculate policy loss for Soft-Actor Critic agent. See paper here: https://arxiv.org/abs/1801.01290

Args:
- states (torch.Tensor): Input states for the policy.
- qfunc1 (nn.Module): First Q-function in SAC agent.
- qfunc2 (nn.Module): Second Q-function in SAC agent.
- policy (nn.Module): Policy network.
- alpha (float): alpha factor for entropy-regularized policy loss.

Returns:
- loss_policy (torch.Tensor): The policy loss term.
- policy_info (dict): Useful logging info for the policy.

In [None]:
%nbdev_export
def sac_qfunc_loss(
    data: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    qfunc1: nn.Module,
    qfunc2: nn.Module,
    qfunc1_target: nn.Module,
    qfunc2_target: nn.Module,
    policy: nn.Module,
    gamma: Optional[float] = 0.99,
    alpha: Optional[float] = 0.2
    ):
    """
    Q-function loss for Soft-Actor Critic agent.

    Both Q-functions are regressed onto one shared Bellman backup that uses the
    element-wise minimum of the two target Q-values minus alpha times the
    log-probability of the next action.

    Args:
    - data (tuple of torch.Tensor): input data batch. Contains 5 PyTorch Tensors. They are unpacked in the
    following order: (states, actions, rewards, next_states, dones).
    NOTE(review): ddpg_qfunc_loss unpacks its batch as (states, next_states, actions, rewards, dones);
    confirm which ordering the caller actually supplies.
    - qfunc1 (nn.Module): First Q-function network being trained.
    - qfunc2 (nn.Module): Other Q-function network being trained.
    - qfunc1_target (nn.Module): First Q-function target network.
    - qfunc2_target (nn.Module): Other Q-function target network.
    - policy (nn.Module): Policy network. Must return a pair (actions, log-probabilities)
    when called on states.
    - gamma (float): Gamma discount factor.
    - alpha (float): Loss term alpha factor.

    Returns:
    - loss_q (torch.Tensor): SAC loss for the Q-function (sum of the two MSE terms).
    - loss_info (dict): Dictionary containing useful loss info for logging. Holds
      "Q1Values" and "Q2Values" as NumPy arrays.
    """
    o, a, r, o2, d = data

    q1 = qfunc1(o, a)
    q2 = qfunc2(o, a)

    # Bellman backup for Q functions. Built under no_grad so no gradient flows
    # into the policy or the target networks.
    with torch.no_grad():
        # Target actions come from *current* policy
        a2, logp_a2 = policy(o2)

        # Target Q-values: take the smaller of the two target estimates.
        q1_pi_targ = qfunc1_target(o2, a2)
        q2_pi_targ = qfunc2_target(o2, a2)
        q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
        # (1 - d) removes the bootstrapped term wherever d == 1.
        backup = r + gamma * (1 - d) * (q_pi_targ - alpha * logp_a2)

    # MSE loss against Bellman backup
    loss_q1 = ((q1 - backup) ** 2).mean()
    loss_q2 = ((q2 - backup) ** 2).mean()
    loss_q = loss_q1 + loss_q2

    # Useful info for logging. .cpu() comes first because Tensor.numpy() raises
    # for tensors that live on a non-CPU device.
    q_info = dict(
        Q1Values=q1.detach().cpu().numpy(),
        Q2Values=q2.detach().cpu().numpy(),
    )

    return loss_q, q_info

In [None]:
show_doc(sac_qfunc_loss)

<h4 id="sac_qfunc_loss" class="doc_header"><code>sac_qfunc_loss</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>sac_qfunc_loss</code>(**`data`**, **`qfunc1`**:`Module`, **`qfunc2`**:`Module`, **`qfunc1_target`**:`Module`, **`qfunc2_target`**:`Module`, **`policy`**:`Module`, **`gamma`**:`Optional`\[`float`\]=*`0.99`*, **`alpha`**:`Optional`\[`float`\]=*`0.2`*)

Q-function loss for Soft-Actor Critic agent.

Args:
- data (tuple of torch.Tensor): input data batch. Contains 5 PyTorch Tensors. They are unpacked in the
following order: (states, actions, rewards, next_states, dones).
- qfunc1 (nn.Module): First Q-function network being trained.
- qfunc2 (nn.Module): Other Q-function network being trained.
- qfunc1_target (nn.Module): First Q-function target network.
- qfunc2_target (nn.Module): Other Q-function target network.
- policy (nn.Module): Policy network.
- gamma (float): Gamma discount factor.
- alpha (float): Loss term alpha factor.

Returns:
- loss_q (torch.Tensor): SAC loss for the Q-function.
- loss_info (dict): Dictionary containing useful loss info for logging.

In [None]:
#hide
# Run nbdev's notebook conversion; the recorded output below lists each notebook it converted.
notebook2script()

Converted 00_utils.ipynb.
Converted 01_datasets.ipynb.
Converted 02_buffers.ipynb.
Converted 03_neuralnets.ipynb.
Converted 04_losses.ipynb.
Converted index.ipynb.
