# Lab 1 - Level E: PPO Loss Function Implementation

## Objective

In this level, you will implement the missing components of the PPO (Proximal Policy Optimization) loss function and test it by training a quadruped robot to follow simple velocity commands.

## What You'll Learn

- Core concepts of the PPO algorithm
- How to implement policy gradient loss functions
- Working with JAX and MuJoCo simulation environments

## Prerequisites

Before starting, make sure you have completed the setup instructions in the main `README.md` file. You should have:
- Set up your Python virtual environment
- Installed all required dependencies
- A basic familiarity with reinforcement learning concepts

## Task Overview

Your main task is to complete the `compute_custom_ppo_loss` function. This function implements the core PPO algorithm components:

1. **Policy Loss**: The clipped surrogate objective that prevents large policy updates
2. **Value Function Loss**: Mean squared error between predicted and target values  
3. **Entropy Loss**: Regularization term to encourage exploration

PPO clips the probability ratio between the new and the old policy. This removes the incentive to move the policy far away from the one that collected the data, which keeps training stable. You'll implement the clipped formulation from the original PPO paper.
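For reference, the three terms can be written as below. This is a summary under the conventions of the template in this notebook (batch means over $N$ samples instead of expectations, value targets $v_t$ returned by `compute_gae`, and the $0.5 \cdot 0.5 = 1/4$ value-loss weight taken from the TODO hints), not a verbatim copy of the paper's notation.

$$
\rho_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)} = \exp\big(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_\text{old}}(a_t \mid s_t)\big)
$$

$$
L^{\text{CLIP}}(\theta) = \frac{1}{N}\sum_t \min\Big(\rho_t(\theta)\,\hat A_t,\; \operatorname{clip}\big(\rho_t(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\,\hat A_t\Big)
$$

$$
L(\theta) = -L^{\text{CLIP}}(\theta) + \frac{1}{4}\cdot\frac{1}{N}\sum_t \big(v_t - V_\theta(s_t)\big)^2 - c_H \cdot \frac{1}{N}\sum_t \mathcal{H}\big[\pi_\theta(\cdot \mid s_t)\big]
$$

Here $\epsilon$ is `clipping_epsilon`, $c_H$ is `entropy_cost`, and $\hat A_t$ are the (optionally normalized) advantages. The total $L$ is minimized, and its three terms correspond to `policy_loss`, `v_loss` and `entropy_loss` in the template.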

## Expected Outcome

After successful implementation, you should see:
- The robot learning to follow velocity commands
- Training progress with increasing reward over time
- Video outputs showing the robot's locomotion behavior

## Code Structure

The notebook contains:
- Environment setup and visualization
- The PPO loss function template (your main task)
- Training loop and progress tracking
- Evaluation and video generation

## Setup

In [None]:
import json
import os

# Clear LD_LIBRARY_PATH
if 'LD_LIBRARY_PATH' in os.environ:
    os.environ['LD_LIBRARY_PATH'] = ''
    print("LD_LIBRARY_PATH has been cleared")
else:
    print("LD_LIBRARY_PATH was not set")

# Register the NVIDIA EGL vendor library in a user-writable glvnd directory
user_vendor_dir = os.path.expanduser('~/.local/share/glvnd/egl_vendor.d')
os.makedirs(user_vendor_dir, exist_ok=True)
icd_config = {
    "file_format_version": "1.0.0",
    "ICD": {
        "library_path": "libEGL_nvidia.so.0"
    }
}
with open(os.path.join(user_vendor_dir, '10_nvidia.json'), 'w') as f:
    json.dump(icd_config, f, indent=2)

# Search the user vendor directory before the system default
current_dirs = os.environ.get('__EGL_VENDOR_LIBRARY_DIRS', '/usr/share/glvnd/egl_vendor.d')
os.environ['__EGL_VENDOR_LIBRARY_DIRS'] = f'{user_vendor_dir}:{current_dirs}'

# Configure MuJoCo to use the EGL rendering backend (requires a GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

# Tell XLA to use Triton GEMM kernels
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

import subprocess
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from datetime import datetime
import functools
from brax.training.agents.ppo import networks as ppo_networks
from mujoco_playground import wrapper
from mujoco_playground import registry
from mujoco_playground.config import locomotion_params
from brax.training.agents.ppo import losses as ppo_losses
from IPython.display import HTML, clear_output, display
import mujoco
import jax
import jax.numpy as jp
import cv2
import custom_ppo_train
from utils import render_video_during_training, evaluate_policy
import mediapy as media

# Visualization options used when rendering videos
scene_option = mujoco.MjvOption()
scene_option.geomgroup[2] = True   # Show visual geoms
scene_option.geomgroup[3] = False  # Hide collision geoms
scene_option.geomgroup[5] = True   # Show sites (including height scanner visualization)
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True  # Show contact points
scene_option.flags[mujoco.mjtVisFlag.mjVIS_RANGEFINDER] = True   # Show rangefinder rays

In [None]:
print("JAX devices:", jax.devices())

## Load the Environment
We load a flat-terrain environment and show the robot in several random initial conditions.

In [None]:
env_name = 'Go1JoystickFlatTerrain'
env = registry.load(env_name)
key = jax.random.PRNGKey(15)

num_resets = 20
frames_per_reset = 30

jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
rollout = []

for reset_idx in range(num_resets):
    key, reset_key = jax.random.split(key)
    
    # Reset to new random position
    state = jit_reset(reset_key)

    # Record several frames per reset: step with zero action for the first
    # 5 frames to let the robot settle, then repeat the final state
    for frame_idx in range(frames_per_reset):
        rollout.append(state)
        if frame_idx < 5:
            action = jp.zeros(env.action_size)
            state = jit_step(state, action)


render_every = 2  # Render every 2nd frame
fps = 1.0 / env.dt / render_every
traj = rollout[::render_every]

frames = env.render(
    traj,
    camera="track",  # Use tracking camera
    scene_option=scene_option,
    width=640,
    height=480,
)
media.show_video(frames, fps=fps)

# Separate environment (with JIT-compiled reset/step) used to render videos during training
env_cfg = registry.get_default_config(env_name)
eval_env_for_video = registry.load(env_name, config=env_cfg)
jit_reset = jax.jit(eval_env_for_video.reset)
jit_step = jax.jit(eval_env_for_video.step)

## Level E Implementation Task

**Your main task**: Complete the missing components in the `compute_custom_ppo_loss` function below.

The template already prepares the data and computes the GAE value targets and advantages; you need to fill in the key PPO components:

### Key Components to Understand:

1. **Policy Loss (Clipped Surrogate Objective)**:
   - Compute probability ratio: `rho_s = exp(new_log_prob - old_log_prob)`
   - Apply clipping to prevent large policy updates
   - Use minimum of clipped and unclipped objectives

2. **Value Function Loss**:
   - Mean squared error between predicted values and GAE targets
   - Helps the critic learn to estimate state values accurately

3. **Entropy Loss**:
   - Regularization term to encourage exploration
   - Prevents the policy from becoming too deterministic too early
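The effect of clipping is easiest to see on a few hand-picked numbers. The sketch below is a standalone illustration on toy arrays, not the lab's solution: `clipped_surrogate` and `toy_policy_loss` are hypothetical helpers, and the log-probabilities and advantages are invented so that the ratios are 1.5, 1.0 and 0.5 with a clipping epsilon of 0.3.

```python
import jax
import jax.numpy as jnp


def clipped_surrogate(new_log_prob, old_log_prob, advantages, clipping_epsilon=0.3):
    """Per-sample PPO clipped surrogate objective (higher is better)."""
    rho = jnp.exp(new_log_prob - old_log_prob)  # probability ratio pi_new / pi_old
    unclipped = rho * advantages
    clipped = jnp.clip(rho, 1.0 - clipping_epsilon, 1.0 + clipping_epsilon) * advantages
    # Pessimistic bound: take the smaller of the two objectives
    return jnp.minimum(unclipped, clipped)


def toy_policy_loss(new_log_prob, old_log_prob, advantages):
    # Negative mean, because the optimizer minimizes the loss
    return -jnp.mean(clipped_surrogate(new_log_prob, old_log_prob, advantages))


# Toy data (invented): ratios of 1.5, 1.0 and 0.5
old_lp = jnp.zeros(3)
new_lp = jnp.log(jnp.array([1.5, 1.0, 0.5]))
adv = jnp.array([1.0, 1.0, -1.0])

per_sample = clipped_surrogate(new_lp, old_lp, adv)
loss = toy_policy_loss(new_lp, old_lp, adv)
# Gradient w.r.t. the new log-probabilities (old ones and advantages held fixed)
grads = jax.grad(toy_policy_loss)(new_lp, old_lp, adv)
print(per_sample, loss, grads)
```

The per-sample objective is `[1.3, 1.0, -0.7]`. The first and third samples land on the clipped branch, so their gradient is zero: the policy gains nothing from pushing those ratios further. Only the middle sample, whose ratio lies inside `[0.7, 1.3]`, contributes to the update (gradient `[0, -1/3, 0]`).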

### Implementation Hints:

- Only the lines marked `# YOUR CODE HERE` need to be filled in; follow the formulas described in the `TODO` comments
- Pay attention to the clipping mechanism in the policy loss
- Make sure you understand how advantages are normalized
- The total loss combines all three components with appropriate weighting
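The last two hints can be checked on toy numbers. The sketch below is illustrative only: `normalize` is a hypothetical helper that repeats the template's standardization line, and the per-term values (`policy_loss`, `v_error`, `entropy`) are invented. Only the 0.5 * 0.5 value weight and the sign and weight of the entropy term come from the TODO hints.

```python
import jax.numpy as jnp


def normalize(advantages, eps=1e-8):
    # Same standardization as in compute_custom_ppo_loss: zero mean, unit std
    return (advantages - advantages.mean()) / (advantages.std() + eps)


raw_adv = jnp.array([10.0, 20.0, 30.0, 40.0])
adv = normalize(raw_adv)
# Rescaling the rewards (and hence the advantages) leaves the result unchanged
adv_scaled = normalize(100.0 * raw_adv)

# Invented per-term values, only to show how the total loss is assembled
policy_loss = jnp.array(-0.5)
v_error = jnp.array([1.0, -1.0, 2.0, 0.0])  # value targets minus predictions
v_loss = jnp.mean(v_error * v_error) * 0.5 * 0.5
entropy = jnp.array(2.0)                     # mean policy entropy
entropy_cost = 1e-4
entropy_loss = entropy_cost * -entropy       # maximizing entropy = minimizing its negative
total_loss = policy_loss + v_loss + entropy_loss
print(adv, v_loss, total_loss)
```

Because the standardization is scale-invariant, `adv` and `adv_scaled` agree, so the size of the policy update does not depend on the raw reward scale. The value term is not normalized this way, so `reward_scaling` still changes the magnitude of `v_loss`.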

### Questions to Consider:

1. What does clipping the probability ratio achieve compared with an unclipped policy-gradient update?
2. What role does the advantage normalization play?
3. How do the different loss components balance exploration vs exploitation?

In [None]:
# Imports
from typing import Any, Tuple
from brax.training import types
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.types import Params
import flax
import jax
import jax.numpy as jnp
from brax.training.agents.ppo.losses import compute_gae

# Helper Struct for policy and value params
@flax.struct.dataclass
class PPONetworkParams:
  """Contains training state for the learner."""

  policy: Params
  value: Params

# PPO loss function to be implemented
def compute_custom_ppo_loss(
    params: PPONetworkParams,
    normalizer_params: Any,
    data: types.Transition,
    rng: jnp.ndarray,
    ppo_network: ppo_networks.PPONetworks,
    entropy_cost: float = 1e-4,
    discounting: float = 0.9,
    reward_scaling: float = 1.0,
    gae_lambda: float = 0.95,
    clipping_epsilon: float = 0.3,
    normalize_advantage: bool = True,
) -> Tuple[jnp.ndarray, types.Metrics]:
  """Computes PPO loss.

  Args:
    params: Network parameters.
    normalizer_params: Parameters of the normalizer.
    data: Transition with leading dimensions [B, T]. Extra fields required
      are ['state_extras']['truncation'], ['policy_extras']['raw_action'] and
      ['policy_extras']['log_prob'].
    rng: Random key.
    ppo_network: PPO networks.
    entropy_cost: Entropy cost.
    discounting: Discount factor.
    reward_scaling: Reward multiplier.
    gae_lambda: Generalized advantage estimation lambda.
    clipping_epsilon: Policy loss clipping epsilon.
    normalize_advantage: Whether to normalize the advantage estimates.

  Returns:
    A tuple (loss, metrics)
  """
  parametric_action_distribution = ppo_network.parametric_action_distribution
  policy_apply = ppo_network.policy_network.apply
  value_apply = ppo_network.value_network.apply

  # Put the time dimension first.
  data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data)
  policy_logits = policy_apply(
      normalizer_params, params.policy, data.observation
  )

  baseline = value_apply(normalizer_params, params.value, data.observation)
  terminal_obs = jax.tree_util.tree_map(lambda x: x[-1], data.next_observation)
  bootstrap_value = value_apply(normalizer_params, params.value, terminal_obs)

  rewards = data.reward * reward_scaling
  truncation = data.extras['state_extras']['truncation']
  termination = (1 - data.discount) * (1 - truncation)

  target_action_log_probs = parametric_action_distribution.log_prob(
      policy_logits, data.extras['policy_extras']['raw_action']
  )
  behaviour_action_log_probs = data.extras['policy_extras']['log_prob']

  vs, advantages = compute_gae(
      truncation=truncation,
      termination=termination,
      rewards=rewards,
      values=baseline,
      bootstrap_value=bootstrap_value,
      lambda_=gae_lambda,
      discount=discounting,
  )
  if normalize_advantage:
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
  
  # TODO: Compute probability ratio between new and old policy
  # Hint: Use jnp.exp() and compute the difference between log probabilities
  rho_s = # YOUR CODE HERE
  
  # TODO: Implement the PPO clipped surrogate objective
  # Hint: Compare unclipped (rho_s * advantages) vs clipped version
  # Use jnp.clip() with clipping_epsilon parameter
  surrogate_loss1 = # YOUR CODE HERE
  surrogate_loss2 = # YOUR CODE HERE

  # TODO: PPO policy loss is the negative mean of the minimum of both surrogates
  policy_loss = # YOUR CODE HERE

  # TODO: Implement value function loss
  # Hint: Mean squared error between vs (targets) and baseline (predictions)
  # Scale by 0.5 * 0.5, as in Brax's reference PPO loss
  v_error = # YOUR CODE HERE
  v_loss = # YOUR CODE HERE

  # TODO: Implement entropy loss for exploration
  # Hint: Average parametric_action_distribution.entropy(policy_logits, rng) over the batch, then multiply by entropy_cost
  # Make it negative since we want to maximize entropy (minimize negative entropy)
  entropy = # YOUR CODE HERE
  entropy_loss = # YOUR CODE HERE

  # TODO: Combine all loss components
  total_loss = # YOUR CODE HERE
  
  return total_loss, {
      'total_loss': total_loss,
      'policy_loss': policy_loss,
      'v_loss': v_loss,
      'entropy_loss': entropy_loss,
  }

You can reduce the total training time by evaluating less often (and thus generating fewer videos), but when you are getting started, the frequent feedback is helpful.

You can also lower `num_timesteps` if a good policy emerges after fewer steps. This notebook trains for 200,000,000 environment steps by default.

In [None]:
env_cfg = registry.get_default_config(env_name)
randomizer = registry.get_domain_randomizer(env_name)
print(f"Environment '{env_name}' loaded successfully.")


x_data, y_data, y_dataerr = [], [], []
times = [datetime.now()]

# Store the current policy for video rendering
current_policy = None

def progress(num_steps, metrics):
    clear_output(wait=True)

    times.append(datetime.now())
    x_data.append(num_steps)
    y_data.append(metrics["eval/episode_reward"])
    y_dataerr.append(metrics["eval/episode_reward_std"])

    plt.xlim([0, ppo_training_params["num_timesteps"] * 1.25])
    plt.xlabel("# environment steps")
    plt.ylabel("reward per episode")
    plt.title(f"y={y_data[-1]:.3f}")
    plt.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")

    # Display the updated reward curve in the notebook
    display(plt.gcf())

    # Render video if we have a current policy
    if current_policy is not None:
        render_video_during_training(current_policy, num_steps, jit_step, jit_reset, env_cfg, eval_env_for_video)


ppo_params = locomotion_params.brax_ppo_config(env_name)
ppo_training_params = dict(ppo_params)

ppo_training_params["num_evals"] = 25  # Fewer evaluations (and videos) make the final training run faster
ppo_training_params["num_timesteps"] = 200_000_000  # Total number of environment steps

network_factory = ppo_networks.make_ppo_networks

if "network_factory" in ppo_params:
    del ppo_training_params["network_factory"]
    network_factory = functools.partial(
        ppo_networks.make_ppo_networks,
        **ppo_params.network_factory
    )
print(ppo_training_params)

# Create a policy parameters callback to capture the current policy
def policy_params_callback(_, make_policy_fn, params):
    # Update the current policy for video rendering
    global current_policy
    current_policy = make_policy_fn(params, deterministic=True)
    
train_fn = functools.partial(
        custom_ppo_train.train,
        **ppo_training_params,
        network_factory=network_factory,
        randomization_fn=randomizer,
        progress_fn=progress,
        policy_params_fn=policy_params_callback,
)


Run the training. The first steps take a while because JAX first has to compile the training functions.

In [None]:
print("Starting training...")
make_policy, params, _ = train_fn(environment=env,
                                  eval_env=registry.load(env_name, config=env_cfg),
                                  wrap_env_fn=wrapper.wrap_for_brax_training,
                                  compute_custom_ppo_loss_fn=compute_custom_ppo_loss
                                 )
print("Training completed.")

## Evaluate the trained model
You should see the robot following the commanded velocities.

In [None]:
env_cfg = registry.get_default_config(env_name)
env_cfg.pert_config.enable = True
env_cfg.pert_config.velocity_kick = [3.0, 6.0]
env_cfg.pert_config.kick_wait_times = [5.0, 15.0]
env_cfg.command_config.a = [1.5, 0.8, 2*jp.pi]
eval_env = registry.load(env_name, config=env_cfg)
velocity_kick_range = [0.0, 0.0]  # Disable velocity kick.
kick_duration_range = [0.05, 0.2]

jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)
jit_inference_fn = jax.jit(make_policy(params, deterministic=True))

evaluate_policy(
    eval_env,
    jit_inference_fn,
    jit_step,
    jit_reset,
    env_cfg,
    eval_env,
    velocity_kick_range,
    kick_duration_range,
)