# Lab 1 - Level C: Reward Tuning for Challenging Terrain

## Objective

In this level, you will tune the reward function to train the quadruped robot to navigate more challenging environments, specifically to walk over elevated steps.

## Prerequisites

- Successfully completed Level E (PPO loss function implementation)
- Understanding of reward shaping in reinforcement learning
- Completed setup instructions from main `README.md`

## What You'll Learn

- Reward function design for complex locomotion tasks
- How different reward components affect robot behavior

## Task Overview

Your main tasks for Level C:

1. **Environment Analysis**: Understand the challenging terrain environment
2. **Reward Function Design**: Modify and tune reward components for step climbing
3. **Training**: Train the robot with your improved reward function
4. **Evaluation**: Test the trained policy

## Expected Outcomes

After successful completion, you should see:
- Robot successfully navigating elevated steps
- Stable locomotion on uneven terrain

## Setup

In [None]:
# --- Rendering and XLA setup: run before importing mujoco or jax ---
import json
import os

# Clear LD_LIBRARY_PATH (set it to an empty string).
if 'LD_LIBRARY_PATH' in os.environ:
    os.environ['LD_LIBRARY_PATH'] = ''
    print("LD_LIBRARY_PATH has been cleared")
else:
    print("LD_LIBRARY_PATH was not set")

# Register the NVIDIA EGL vendor library (ICD config) in a user-writable directory.
user_vendor_dir = os.path.expanduser('~/.local/share/glvnd/egl_vendor.d')
os.makedirs(user_vendor_dir, exist_ok=True)
icd_config = {
    "file_format_version": "1.0.0",
    "ICD": {
        "library_path": "libEGL_nvidia.so.0"
    }
}
with open(os.path.join(user_vendor_dir, '10_nvidia.json'), 'w') as f:
    json.dump(icd_config, f, indent=2)

# Search the user vendor directory first, then the system default.
current_dirs = os.environ.get('__EGL_VENDOR_LIBRARY_DIRS', '/usr/share/glvnd/egl_vendor.d')
os.environ['__EGL_VENDOR_LIBRARY_DIRS'] = f'{user_vendor_dir}:{current_dirs}'

# Configure MuJoCo to use the EGL rendering backend (requires GPU).
os.environ['MUJOCO_GL'] = 'egl'
print('Using GPU rendering: MUJOCO_GL =', os.environ['MUJOCO_GL'])

# Tell XLA to use Triton GEMM kernels.
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

import functools
import subprocess
from datetime import datetime

import cv2
import jax
import jax.numpy as jp
import matplotlib
import matplotlib.pyplot as plt
import mediapy as media
import mujoco
import numpy as np
from brax.training.agents.ppo import losses as ppo_losses
from brax.training.agents.ppo import networks as ppo_networks
from IPython.display import HTML, clear_output, display
from mujoco_playground import registry, wrapper
from mujoco_playground.config import locomotion_params

# Local lab modules.
import custom_ppo_train
from utils import render_video_during_training, evaluate_policy

# Visualization options used when rendering rollouts.
scene_option = mujoco.MjvOption()
scene_option.geomgroup[2] = True   # Show visual geoms
scene_option.geomgroup[3] = False  # Hide collision geoms
scene_option.geomgroup[5] = True   # Show sites (including height scanner visualization)
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True  # Show contact points
scene_option.flags[mujoco.mjtVisFlag.mjVIS_RANGEFINDER] = True   # Show rangefinder (height scanner) rays

In [None]:
print("JAX devices:", jax.devices())

## Load the Challenging Environment

We'll load a more challenging environment with elevated steps and rough terrain that the robot needs to navigate.

In [None]:
env_name = 'Go1JoystickRoughTerrain'  # Go1 joystick task on rough terrain with elevated steps
env = registry.load(env_name)
key = jax.random.PRNGKey(15)

# Visualize the challenging environment
num_resets = 10
frames_per_reset = 50

jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
rollout = []

for reset_idx in range(num_resets):
    key, reset_key = jax.random.split(key)
    
    # Reset to new random position
    state = jit_reset(reset_key)
    
    # Add multiple frames of the same reset position
    for frame_idx in range(frames_per_reset):
        rollout.append(state)
        # Step with zero actions for the first 10 frames so the robot settles on the terrain;
        # the remaining frames repeat the settled state.
        if frame_idx < 10:
            action = jp.zeros(env.action_size)
            state = jit_step(state, action)

render_every = 2  # Render every 2nd frame
fps = 1.0 / env.dt / render_every
traj = rollout[::render_every]

frames = env.render(
    traj,
    camera="track",  # Use tracking camera
    scene_option=scene_option,
    width=640,
    height=480,
)
media.show_video(frames, fps=fps)
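The playback rate in the cell above follows from the control timestep. A worked version of that arithmetic, assuming `env.dt` is 0.02 s (matching the `ctrl_dt` in the reward config defined below):

```python
# Worked arithmetic for the visualization cell (ASSUMPTION: env.dt == 0.02 s).
ctrl_dt = 0.02          # seconds per environment step (assumed value of env.dt)
render_every = 2        # keep every 2nd frame, as in rollout[::render_every]
num_resets = 10
frames_per_reset = 50

fps = 1.0 / ctrl_dt / render_every
# Number of frames that survive the slice rollout[::render_every].
num_rendered = len(range(0, num_resets * frames_per_reset, render_every))
video_seconds = num_rendered / fps

print(f"fps={fps:.1f}, rendered frames={num_rendered}, video length={video_seconds:.1f} s")
```

Each rendered frame spans `render_every * dt` seconds of simulation, so this `fps` plays the clip back in real time.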

## Level C Task: Reward Function Analysis and Tuning

### Current Challenge

The flat terrain policy from Level E likely won't work well on challenging terrain with steps and obstacles. You need to:

1. **Analyze the current reward function**
2. **Identify what behaviors need to be encouraged/discouraged**
3. **Design new reward components**
4. **Tune reward weights and parameters**

You can find more details on the reward function definitions here: https://github.com/finnBsch/mujoco_playground/blob/lab1_rl/mujoco_playground/_src/locomotion/go1/joystick.py
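When analyzing the reward, it helps to see how the terms combine into a single number. The sketch below assumes the fork keeps the upstream MuJoCo Playground Go1 joystick behavior: each raw term is multiplied by its entry in `reward_config.scales`, the sum is multiplied by the control timestep, and the result is clipped to be non-negative. Verify this against the linked `joystick.py`; the term values are made-up placeholders, not measurements.

```python
# Sketch of how scaled reward terms combine into one per-step reward.
# ASSUMPTION: mirrors upstream MuJoCo Playground Go1 joystick; check the fork.

def combine_rewards(terms, scales, dt, clip_min=0.0, clip_max=10000.0):
    """Weight each raw term by its scale, sum, multiply by dt, and clip."""
    weighted = {name: value * scales[name] for name, value in terms.items()}
    total = sum(weighted.values()) * dt
    return max(clip_min, min(clip_max, total)), weighted


# Placeholder raw term values for a single step (illustrative only).
scales = {"tracking_lin_vel": 1.0, "orientation": -5.0, "feet_slip": -0.1}
terms = {"tracking_lin_vel": 0.8, "orientation": 0.05, "feet_slip": 0.5}

reward, weighted = combine_rewards(terms, scales, dt=0.02)
print(weighted)        # per-term contributions before the dt factor
print(reward)

# If penalties outweigh the positive terms, the clipped reward bottoms out at 0.
penalized, _ = combine_rewards(
    {"tracking_lin_vel": 0.1, "orientation": 0.5, "feet_slip": 1.0}, scales, dt=0.02
)
print(penalized)       # 0.0
```

If the clip behaves as assumed, very large penalty scales can push many steps to exactly zero reward, which hides the difference between "slightly bad" and "very bad" behavior from the learner.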

### Your Task:

Tune the reward function weights so that the robot can successfully navigate the challenging terrain. Consider both positive rewards (for desired behaviors) and negative rewards, or penalties (for undesired behaviors). You will need to explain your choice for each reward weight.
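Since each weight change needs a justification, it can help to list exactly which entries differ from the defaults. A small sketch on plain dictionaries; `ml_collections` configs can be converted with `ConfigDict.to_dict()`. The helper name `diff_scales` and the tuned values are hypothetical illustrations, not recommendations.

```python
# Hypothetical helper: report reward-scale changes relative to the defaults.

def diff_scales(default, tuned):
    """Return {name: (default_value, tuned_value)} for every changed entry."""
    changes = {}
    for name in sorted(set(default) | set(tuned)):
        old, new = default.get(name), tuned.get(name)
        if old != new:
            changes[name] = (old, new)
    return changes


default_scales = {"feet_air_time": 0.1, "feet_clearance": -0.2, "orientation": -5.0}
# Illustrative edits only; choose and justify your own values.
tuned_scales = {"feet_air_time": 0.3, "feet_clearance": -0.2, "orientation": -2.0}

for name, (old, new) in diff_scales(default_scales, tuned_scales).items():
    print(f"{name}: {old} -> {new}")
```

In the notebook, the inputs would be something like `reward_config().reward_config.scales.to_dict()` and the same call on your edited config.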

## Training with Custom Reward Function
We first define the default environment configuration, including the reward settings. Your task is to tune the values between the `ADJUST SETTINGS` markers.

In [None]:
from ml_collections import config_dict

def reward_config() -> config_dict.ConfigDict:
  return config_dict.create(
      ctrl_dt=0.02,
      sim_dt=0.004,
      episode_length=1000,
      Kp=35.0,
      Kd=0.5,
      action_repeat=1,
      action_scale=1.0,
      history_len=1,
      soft_joint_pos_limit_factor=0.95,
      noise_config=config_dict.create(
          level=1.0,
          scales=config_dict.create(
              joint_pos=0.03,
              joint_vel=1.5,
              gyro=0.2,
              gravity=0.05,
              linvel=0.1,
          ),
      ),
      reward_config=config_dict.create(
        ### ----- ADJUST SETTINGS BELOW ----- ###
          # Each term is multiplied by its scale; terms with negative scales act as penalties.
          scales=config_dict.create(
              torso_height=-0.0,
              # Tracking.
              tracking_lin_vel=1.0,
              tracking_ang_vel=0.5,

              # Base reward.
              lin_vel_z=-0.5,
              ang_vel_xy=-0.05,
              orientation=-5.0,
              # Other.
              dof_pos_limits=-1.0,
              pose=0.5,
              # Termination and standing still.
              termination=-1.0,
              stand_still=-1.0,
              # Regularization.
              torques=-0.0002,
              action_rate=-0.01,
              energy=-0.001,
              # Feet.
              feet_clearance=-0.2,
              feet_slip=-0.1,
              feet_air_time=0.1,
          ),
          tracking_sigma=0.25,
          max_foot_height=0.11,
          desired_foot_air_time=0.15,
          desired_torso_height=0.36
        ### ----- ADJUST SETTINGS ABOVE ----- ###
      ),
      pert_config=config_dict.create(
          enable=False,
          velocity_kick=[0.0, 3.0],
          kick_durations=[0.05, 0.2],
          kick_wait_times=[1.0, 3.0],
      ),
      command_config=config_dict.create(
          # Uniform distribution for command amplitude.
          a=[1.5, 0.8, 1.2],
          # Probability of not zeroing out new command.
          b=[0.9, 0.25, 0.5],
      ),
      impl="jax",
      nconmax=4 * 8192,
      njmax=40,
  )
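The `tracking_sigma` parameter controls how sharply the velocity-tracking terms fall off with error. Upstream MuJoCo Playground computes these terms as `exp(-error / tracking_sigma)`, where `error` is the squared difference between commanded and measured velocity. Assuming the fork keeps that form, this sketch compares two sigma values:

```python
import math

# ASSUMPTION: upstream form of the tracking reward, exp(-squared_error / sigma).
def tracking_reward(cmd_xy, vel_xy, sigma):
    squared_error = sum((c - v) ** 2 for c, v in zip(cmd_xy, vel_xy))
    return math.exp(-squared_error / sigma)


cmd = (1.0, 0.0)  # commanded forward/lateral velocity in m/s (illustrative)
for vel in [(1.0, 0.0), (0.8, 0.0), (0.5, 0.0), (0.0, 0.0)]:
    sharp = tracking_reward(cmd, vel, sigma=0.25)
    loose = tracking_reward(cmd, vel, sigma=0.5)
    print(f"vel={vel}: sigma=0.25 -> {sharp:.3f}, sigma=0.5 -> {loose:.3f}")
```

A smaller sigma gives credit only for near-perfect tracking, which can make early learning on rough terrain harder; a larger sigma gives partial credit for approximate tracking.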

You can reduce the final training time by evaluating less often (lower `num_evals`, which also means fewer generated videos), but frequent feedback is helpful when you are getting started.

You can also stop earlier by lowering `num_timesteps` if the policy already looks good after fewer steps. The default is 200,000,000 training steps.
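To see how `num_evals` trades feedback against overhead: Brax's PPO trainer runs one evaluation before training and spreads the remaining `num_evals - 1` evenly over `num_timesteps`, with the exact spacing rounded to whole training iterations. Assuming `custom_ppo_train.train` keeps that scheme, the approximate spacing is:

```python
# Approximate environment steps between evaluations (and thus between videos).
# ASSUMPTION: Brax-style scheduling with one evaluation before training starts.

def steps_between_evals(num_timesteps, num_evals):
    num_evals_after_init = max(num_evals - 1, 1)
    return num_timesteps / num_evals_after_init


for num_evals in (25, 10, 5):
    gap = steps_between_evals(200_000_000, num_evals)
    print(f"num_evals={num_evals}: ~{gap / 1e6:.1f}M steps between evaluations")
```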

In [None]:

env_name = 'Go1JoystickRoughTerrain'
env = registry.load(env_name, config=reward_config())
key = jax.random.PRNGKey(15)

randomizer = registry.get_domain_randomizer(env_name)
print(f"Environment '{env_name}' loaded successfully.")

x_data, y_data, y_dataerr = [], [], []
times = [datetime.now()]

current_policy = None

env_cfg = reward_config()

eval_env_for_video = registry.load(env_name, config=env_cfg)
jit_reset = jax.jit(eval_env_for_video.reset)
jit_step = jax.jit(eval_env_for_video.step)


def progress(num_steps, metrics):
    clear_output(wait=True)

    times.append(datetime.now())
    x_data.append(num_steps)
    y_data.append(metrics["eval/episode_reward"])
    y_dataerr.append(metrics["eval/episode_reward_std"])

    plt.xlim([0, ppo_training_params["num_timesteps"] * 1.25])
    plt.xlabel("# environment steps")
    plt.ylabel("reward per episode")
    plt.title(f"Challenging Terrain Training: reward={y_data[-1]:.3f}")
    plt.errorbar(x_data, y_data, yerr=y_dataerr, color="red")
    
    display(plt.gcf())

    # Render video if we have a current policy
    if current_policy is not None:
        render_video_during_training(current_policy, num_steps, jit_step, jit_reset, env_cfg, eval_env_for_video)

# PPO training parameters - you may want to tune these for challenging terrain
ppo_params = locomotion_params.brax_ppo_config(env_name)
ppo_training_params = dict(ppo_params)

ppo_training_params["num_evals"] = 25  # Lower this for the final run: less feedback, faster training.
ppo_training_params["num_timesteps"] = 200_000_000  # Total number of environment steps

network_factory = ppo_networks.make_ppo_networks

if "network_factory" in ppo_params:
    del ppo_training_params["network_factory"]
    network_factory = functools.partial(
        ppo_networks.make_ppo_networks,
        **ppo_params.network_factory
    )

print("Training parameters:")
print(ppo_training_params)

# Create a policy parameters callback to capture the current policy
def policy_params_callback(_, make_policy_fn, params):
    global current_policy
    current_policy = make_policy_fn(params, deterministic=True)
    
train_fn = functools.partial(
        custom_ppo_train.train,
        **ppo_training_params,
        network_factory=network_factory,
        randomization_fn=randomizer,
        progress_fn=progress,
        policy_params_fn=policy_params_callback,
)

In [None]:
print("Starting training on challenging terrain...")
make_policy, params, _ = train_fn(environment=env,
                                  eval_env=registry.load(env_name, config=env_cfg),
                                  wrap_env_fn=wrapper.wrap_for_brax_training,
                                  compute_custom_ppo_loss_fn=ppo_losses.compute_ppo_loss
                                 )
print("Training completed.")
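The training setup records a timestamp in `times` before training and at every progress callback, but never reports them. A minimal sketch for summarizing those timestamps; note that the first interval also covers JIT compilation, the first evaluation, and any idle time between running the setup and training cells. The helper name `summarize_times` is hypothetical.

```python
from datetime import datetime, timedelta

def summarize_times(times):
    """Split recorded timestamps into time-to-first-eval and training time.

    times[0] is taken in the setup cell; times[1] at the first progress callback,
    so the first interval includes JIT compilation and the initial evaluation.
    """
    if len(times) < 2:
        raise ValueError("need at least two timestamps")
    first_eval_time = times[1] - times[0]
    train_time = times[-1] - times[1]
    return first_eval_time, train_time


# In the notebook: first_eval_time, train_time = summarize_times(times)
# Synthetic timestamps for illustration:
start = datetime(2024, 1, 1, 12, 0, 0)
demo = [start, start + timedelta(minutes=2), start + timedelta(minutes=30)]
first_eval_time, train_time = summarize_times(demo)
print(f"time to first eval (incl. JIT): {first_eval_time}")
print(f"time to train: {train_time}")
```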

## Evaluate the Trained Policy

Test your trained policy on the challenging terrain. If training succeeded, the robot should climb the steps and handle obstacles while walking stably.

In [None]:
# Set up evaluation environment with challenging conditions
env_cfg = reward_config()
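env_cfg.pert_config.enable = True
env_cfg.pert_config.velocity_kick = [2.0, 4.0]  # Kick magnitude range for the environment's perturbation config
env_cfg.pert_config.kick_wait_times = [3.0, 10.0]  # Time between kicks, in seconds
env_cfg.command_config.a = [1.2, 0.6, 2*jp.pi]  # Command amplitudes: forward velocity, lateral velocity, yaw rate

eval_env = registry.load(env_name, config=env_cfg)
# Kick ranges passed to evaluate_policy; [0.0, 0.0] disables velocity kicks for a clearer evaluation.
velocity_kick_range = [0.0, 0.0]
kick_duration_range = [0.05, 0.2]  # Kick duration, in seconds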

jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)
jit_inference_fn = jax.jit(make_policy(params, deterministic=True))

print("Evaluating policy on challenging terrain...")
evaluate_policy(
    eval_env,
    jit_inference_fn,
    jit_step,
    jit_reset,
    env_cfg,
    eval_env,
    velocity_kick_range,
    kick_duration_range,
)
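The two ranges passed to `evaluate_policy` describe a velocity kick: a magnitude and a duration. The helper below is a hypothetical sketch, assuming `evaluate_policy` samples kicks in the common way: a magnitude drawn uniformly from `velocity_kick_range`, and a duration drawn uniformly from `kick_duration_range` and converted to control steps. Under that assumption, `[0.0, 0.0]` always yields a zero-magnitude kick. Check `utils.py` for the actual behavior.

```python
import random

# Hypothetical sketch of sampling one velocity kick from the evaluation ranges.
def sample_kick(velocity_kick_range, kick_duration_range, dt, rng=random):
    magnitude = rng.uniform(*velocity_kick_range)        # m/s
    duration_s = rng.uniform(*kick_duration_range)       # seconds
    duration_steps = max(1, round(duration_s / dt))      # whole control steps
    return magnitude, duration_steps


rng = random.Random(0)
mag, steps = sample_kick([0.0, 0.0], [0.05, 0.2], dt=0.02, rng=rng)
print(mag, steps)      # magnitude is always 0.0 with the [0.0, 0.0] range

mag, steps = sample_kick([2.0, 4.0], [0.05, 0.2], dt=0.02, rng=rng)
print(f"{mag:.2f} m/s for {steps} steps")
```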