<a href="https://colab.research.google.com/github/kuds/rl-mujoco-tennis/blob/main/%5BWall%20Ball%5D%20Soft%20Actor-Critic%20(SAC).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wall Ball: Soft Actor-Critic (SAC)

In [None]:
!pip install mujoco

# Set up GPU rendering.
from google.colab import files
import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

# Check if installation was succesful.
try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')

# Other imports and helper functions
import time
import itertools
import numpy as np

# Graphics and plotting.
print('Installing mediapy:')
# Install ffmpeg through apt only when no `ffmpeg` executable is already on PATH.
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

# Clear this cell's output (the installation logs printed above).
from IPython.display import clear_output
clear_output()


In [None]:
# Install Stable-Baselines3, the package the SAC algorithm, vectorized-env
# helpers and callbacks imported in the next cell come from.
!pip install stable-baselines3

In [None]:
import gymnasium
import mujoco
from stable_baselines3 import SAC
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecVideoRecorder
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.callbacks import CallbackList
import numpy
import os
import csv
import torch
import platform
from importlib.metadata import version
import matplotlib
import matplotlib.pyplot
from gymnasium import utils
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.spaces import Box

In [None]:
# Report the interpreter and library versions in use, one "<label>: <value>"
# line each. The probes are callables so that every value is looked up right
# before its own line is printed, exactly in the listed order.
version_report = (
    ("Python Version", platform.python_version),
    ("Torch Version", lambda: version('torch')),
    ("Is Cuda Available", torch.cuda.is_available),
    ("Cuda Version", lambda: torch.version.cuda),
    ("Gymnasium Version", lambda: version('gymnasium')),
    ("Numpy Version", lambda: version('numpy')),
    ("Mujoco Version", lambda: version('mujoco')),
    ("Stable-Baselines3 Version", lambda: version('stable-baselines3')),
    ("Matplotlib Version", lambda: version('matplotlib')),
)
for label, probe in version_report:
    print(f"{label}: {probe()}")

In [None]:
# Identifiers shared by the rest of the notebook.
rl_type = "SAC"            # algorithm label (shown in the plot title)
env_str = "WallBall"       # environment label (log directory and plot title)
name_prefix = "wall_ball"  # base name of the model XML and of recorded videos
log_dir = f"./logs/{env_str}"  # root directory for logs, models and videos

In [None]:
# Settings for this training run, gathered in one dictionary.
hyperparams = dict(
    env_str=env_str,
    rl_type=rl_type,
    # Call frequency handed to both the evaluation and the video callbacks.
    eval_freq=25_000,
    # Number of environments in the training vectorized env.
    n_envs=4,
    # Touch-sensor reading at or above which a contact counts as a bounce.
    min_force=100.0,
    # Training budget passed to `model.learn`.
    total_timesteps=2_000_000,
    log_dir=log_dir,
)

In [None]:
class VideoRecordCallback(BaseCallback):
    """Periodically record a video and a per-step CSV trace of the model in training.

    Every ``save_freq`` invocations of ``_on_step`` a fresh single-environment
    ``VecVideoRecorder`` is built from the notebook-level ``make_env`` factory.
    The model currently being trained (not the best checkpoint) is rolled out
    for at most ``video_length`` steps, stopping early when the episode ends.
    For each step the touch-sensor reading, the reward, the running reward
    total and the done flag are appended to
    ``<save_path>/<name_prefix>_<num_timesteps>.csv``.

    Args:
        save_path: Directory that receives the videos and CSV files.
        video_length: Maximum number of steps to roll out and record.
        save_freq: Record once every ``save_freq`` calls of ``_on_step``.
        name_prefix: Prefix of the generated video and CSV file names.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(
        self,
        save_path: str,
        video_length: int,
        save_freq: int = 5_000,
        name_prefix: str = "rl_model",
        verbose: int = 0,
    ):
        super().__init__(verbose)
        self.save_freq = save_freq
        self.video_length = video_length
        self.save_path = save_path
        self.name_prefix = name_prefix
        # Attributes used below that come from BaseCallback:
        #   self.model         - the RL model being trained
        #   self.n_calls       - number of times the callback was called
        #   self.num_timesteps - n_envs * number of env.step() calls so far

    def _on_step(self) -> bool:
        """Record one rollout when ``n_calls`` is a multiple of ``save_freq``.

        Returns:
            Always ``True``, so training is never interrupted by this callback.
        """
        if self.n_calls % self.save_freq != 0:
            return True

        name_prefix = f"{self.name_prefix}_{self.num_timesteps}"
        csv_file_name = os.path.join(self.save_path, f"{name_prefix}.csv")

        # Make sure the CSV can be created without depending on the recorder
        # having created the output directory as a side effect.
        os.makedirs(self.save_path, exist_ok=True)

        # Record only the rollout that starts at the recorder's step 0.
        rec_val = VecVideoRecorder(make_vec_env(make_env, n_envs=1),
                                   self.save_path,
                                   video_length=self.video_length,
                                   record_video_trigger=lambda x: x == 0,
                                   name_prefix=name_prefix)

        session_length = 0
        # Kept as a plain float so the summary below also works when no step
        # was executed (video_length == 0).
        total_reward = 0.0
        try:
            obs = rec_val.reset()
            # newline='' is what the csv module requires of files it writes to;
            # without it, rows gain an extra '\r' on platforms that translate
            # line endings.
            with open(csv_file_name, 'w', newline='') as csvfile:
                csv_writer = csv.writer(csvfile, delimiter=',')
                csv_writer.writerow(["Sensor Data",
                                     "Reward",
                                     "Total Reward",
                                     "Done"])
                for _ in range(self.video_length):
                    session_length += 1
                    # NOTE(review): `deterministic` is not passed, so the
                    # policy's default action selection is used here - confirm
                    # that is intended for the recordings.
                    action, _states = self.model.predict(obs)
                    obs, rewards, dones, info = rec_val.step(action)
                    # The recorder wraps exactly one environment: index 0.
                    total_reward += float(rewards[0])
                    csv_writer.writerow([info[0]["sensor_data"],
                                         rewards[0],
                                         total_reward,
                                         dones[0]])
                    rec_val.render()

                    if dones[0]:
                        break

            print(f"Step: {self.num_timesteps} | Session Length: {session_length} |Total Bounces: {int(total_reward)}")
        finally:
            # Always release the recorder and its environment, even when the
            # rollout fails part-way through.
            rec_val.close()
        return True

In [None]:
class WallBallEnv(MujocoEnv, utils.EzPickle):
    """MuJoCo environment that pays +1 for every newly detected bounce.

    The model is loaded from ``/content/<name_prefix>.xml``, where
    ``name_prefix`` is the notebook-level variable.

    Observation (18 float64 values):
        the first three ``qpos`` and first three ``qvel`` entries of the
        ``ball_freejoint`` joint, followed by ``qpos`` and ``qvel`` of each
        joint in ``_OBSERVED_JOINTS`` (the size of 18 assumes each of those
        joints contributes one position and one velocity value).
    Reward:
        1.0 on a step where the ``touch_sensor`` reading is at least
        ``min_force`` while the previous step's reading was <= 0; else 0.0.
    Termination:
        the observation contains a non-finite value, or its third entry
        (the ball's third position coordinate) is negative.
    Truncation:
        the step counter exceeds ``episode_len``.

    Args:
        episode_len: Step count after which the episode is truncated.
        min_force: Minimum touch-sensor reading that counts as a bounce.
        **kwargs: Forwarded to ``MujocoEnv.__init__`` (e.g. ``render_mode``).
    """

    metadata = {
        "render_modes": [
            "human",
            "rgb_array",
            "depth_array",
        ],
        # 100 = 1 / (0.002 s timestep * 5 frame skip) for the notebook's XML.
        # NOTE(review): presumably has to stay in sync with timestep and
        # frame skip - confirm against the MujocoEnv base class.
        "render_fps": 100,
    }

    # Joints, besides the ball's free joint, whose qpos and qvel are appended
    # to the observation, in this order.
    _OBSERVED_JOINTS = (
        "rotate_x",
        "rotate_y",
        "rotate_z",
        "slider_x",
        "slider_y",
        "slider_z",
    )

    # set default episode_len for truncate episodes
    def __init__(self, episode_len=750, min_force=0.0, **kwargs):
        # Register *all* constructor arguments with EzPickle; otherwise a
        # pickled/copied environment is rebuilt with the default episode_len
        # and min_force instead of the values it was created with.
        utils.EzPickle.__init__(self, episode_len, min_force, **kwargs)

        self.min_force = min_force
        self.bounce_count = 0
        self.previous_touch_value = 0
        # change shape of observation to your observation space size
        observation_space = Box(low=-numpy.inf,
                                high=numpy.inf,
                                shape=(18,),
                                dtype=numpy.float64)
        # load your MJCF model with env and choose frames count between actions
        MujocoEnv.__init__(
            self,
            os.path.abspath(f"/content/{name_prefix}.xml"),
            5,
            observation_space=observation_space,
            **kwargs
        )
        self.step_number = 0
        self.episode_len = episode_len

    # determine the reward depending on observation or other properties of the simulation
    def step(self, a):
        """Apply action ``a`` for ``frame_skip`` simulation steps.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``info`` holds
            the current touch-sensor reading under ``"sensor_data"``.
        """
        self.do_simulation(a, self.frame_skip)
        self.step_number += 1

        # Bounce detection: reward only the step on which the reading reaches
        # min_force while the previous step reported no contact (<= 0), so a
        # sustained contact is not rewarded on every step.
        current_touch_value = self.data.sensor("touch_sensor").data[0]
        is_new_bounce = (current_touch_value >= self.min_force
                         and self.previous_touch_value <= 0)
        if is_new_bounce:
            self.bounce_count += 1
        reward = 1.0 if is_new_bounce else 0.0
        self.previous_touch_value = current_touch_value

        obs = self._get_obs()
        terminated = bool(not numpy.isfinite(obs).all() or (obs[2] < 0))
        truncated = self.step_number > self.episode_len
        return obs, reward, terminated, truncated, {"sensor_data": current_touch_value}

    # define what should happen when the model is reset (at the beginning of each episode)
    def reset_model(self):
        """Reset the counters and start from a slightly perturbed initial state.

        Returns:
            The first observation of the new episode.
        """
        self.step_number = 0
        self.bounce_count = 0
        # Forget the last reading of the previous episode; otherwise an episode
        # that ended during contact would suppress the next episode's first
        # bounce.
        self.previous_touch_value = 0

        # Uniform noise in [-0.01, 0.01] is added to positions and velocities.
        qpos = self.init_qpos + self.np_random.uniform(
            size=self.model.nq, low=-0.01, high=0.01
        )
        qvel = self.init_qvel + self.np_random.uniform(
            size=self.model.nv, low=-0.01, high=0.01
        )
        self.set_state(qpos, qvel)
        return self._get_obs()

    # determine what should be added to the observation
    def _get_obs(self):
        """Concatenate the ball state and the observed joints' states.

        Joint positions and velocities are looked up by joint name.
        """
        ball = self.data.joint("ball_freejoint")
        # Only the first three position and velocity entries of the free joint
        # are observed.
        parts = [ball.qpos[:3], ball.qvel[:3]]
        for joint_name in self._OBSERVED_JOINTS:
            joint = self.data.joint(joint_name)
            parts.append(joint.qpos)
            parts.append(joint.qvel)
        return numpy.concatenate(parts, axis=0)

In [None]:
# Ensure environment XML (MuJoCo model) is available
# The MJCF text below is written to "<name_prefix>.xml" in the current working
# directory, overwriting any file of that name.
# NOTE(review): WallBallEnv loads "/content/<name_prefix>.xml"; the two paths
# only coincide when the working directory is /content - confirm.
# NOTE(review): this model defines the joints paddle_slide_y, paddle_slide_z,
# paddle_rotate_x and ball_x and has no <sensor> section, whereas WallBallEnv
# looks up ball_freejoint, rotate_x/y/z, slider_x/y/z and touch_sensor -
# confirm this is the model the environment is meant to load.
xml_content = """
<mujoco model="tennis_wall_paddle">
    <compiler angle="degree" coordinate="local" inertiafromgeom="true" />
    <option integrator="RK4" timestep="0.002" />
    <size nconmax="500" njmax="1000" nstack="300000" />
    <default>
        <joint armature="0.1" damping="1" limited="true" />
        <geom condim="3" density="1000" friction="1 0.5 0.5" margin="0.001" rgba="0.8 0.8 0.8 1" />
        <motor ctrlrange="-1 1" ctrllimited="true" />
    </default>
    <asset>
        <texture builtin="flat" height="1278" name="texplane" rgb1="0.9 0.9 0.9" rgb2="0.1 0.1 0.1" type="2d" width="1279" />
        <material name="matplane" reflectance="0.5" shininess="0.1" specular="0.1" texture="texplane" />
    </asset>
    <worldbody>
        <light diffuse=".8 .8 .8" pos="0 0 3" specular="0.1 0.1 0.1" />
        <geom name="floor" pos="0 0 0" size="5 5 0.1" type="plane" material="matplane" />
        <geom name="wall" pos="3 0 1.5" size="0.1 5 1.5" type="box" rgba="0.7 0.7 0.9 1" />
        <body name="paddle_base" pos="-2 0 1">
            <joint name="paddle_slide_y" type="slide" axis="0 1 0" range="-2 2" />
            <joint name="paddle_slide_z" type="slide" axis="0 0 1" range="0.5 2" />
            <geom name="paddle_base_geom" pos="0 0 0" size="0.1 0.1 0.1" type="box"
                rgba="0.2 0.2 0.8 1" />

            <body name="paddle_handle" pos="0 0 0.15">
                <joint name="paddle_rotate_x" type="hinge" axis="1 0 0" range="-90 90" />
                <geom name="paddle_handle_geom" pos="0 0 0" size="0.05 0.05 0.15" type="box" rgba="0.2 0.2 0.8 1" />
                <body name="paddle_head" pos="0 0 0.15">
                    <geom name="paddle_head_geom" pos="0 0 0.05" size="0.2 0.05 0.2" type="box" rgba="0.8 0.2 0.2 1" />
                </body>
            </body>
        </body>
        <body name="ball" pos="-4 0 1.5">
            <joint name="ball_x" type="free" limited="false" />
            <geom name="ball_geom" pos="0 0 0" size="0.1" type="sphere" rgba="1 0 0 1" />
        </body>
    </worldbody>
    <actuator>
        <motor name="paddle_slide_y" joint="paddle_slide_y" gear="100" />
        <motor name="paddle_slide_z" joint="paddle_slide_z" gear="100" />
        <motor name="paddle_rotate_x" joint="paddle_rotate_x" gear="100" />
    </actuator>
</mujoco>
"""

# Write the model text to disk.
with open(f"{name_prefix}.xml", "w") as f:
    f.write(xml_content)

In [None]:
# Build the environment once to inspect its observation and action spaces.
# The environment class defined in this notebook is WallBallEnv; the previous
# name used here (BallBounceEnv) is not defined anywhere and raised NameError.
env = WallBallEnv(render_mode="rgb_array")
print("Observation Space Size: ", env.observation_space.shape)
print('Actions Space: ', env.action_space)
env.close()

In [None]:
def make_env():
    """Create one validated ``WallBallEnv`` that renders to RGB arrays.

    The bounce threshold comes from ``hyperparams["min_force"]``. Each new
    environment is passed through ``check_env`` before being returned.
    """
    wall_ball_env = WallBallEnv(
        min_force=hyperparams["min_force"],
        render_mode="rgb_array",
    )
    check_env(wall_ball_env)
    return wall_ball_env

# Create Training environment
# (hyperparams["n_envs"] copies; monitor files go to <log_dir>/monitor)
env = make_vec_env(make_env,
                   n_envs=hyperparams["n_envs"],
                   monitor_dir=os.path.join(log_dir, "monitor"))

# Create Evaluation environment
env_val = make_vec_env(make_env, n_envs=1)

# Periodic evaluation on env_val, 20 deterministic episodes per evaluation.
# Later cells read "best_model" and "evaluations.npz" back from log_dir.
# NOTE(review): eval_freq is presumably counted in callback calls rather than
# in total timesteps, so with n_envs > 1 evaluations are n_envs times further
# apart in timesteps - confirm against the Stable-Baselines3 callback docs.
eval_callback = EvalCallback(env_val,
                             best_model_save_path=log_dir,
                             log_path=log_dir,
                             render=False,
                             deterministic=True,
                             n_eval_episodes=20,
                             eval_freq=hyperparams["eval_freq"])

# Video + CSV recording at the same call frequency as the evaluation,
# written to <log_dir>/videos, with rollouts of at most 10,000 steps.
video_record_callback = VideoRecordCallback(
    save_path=os.path.join(log_dir, "videos"),
    video_length=10_000,
    save_freq=hyperparams["eval_freq"],
    name_prefix=name_prefix)

# Create the callback list
callbackList = CallbackList([video_record_callback,
                             eval_callback])

# learning with tensorboard logging and saving model
model = SAC("MlpPolicy",
            env,
            verbose=0,
            tensorboard_log=os.path.join(log_dir, "tensorboard"))

model.learn(total_timesteps=hyperparams["total_timesteps"],
            callback=callbackList,
            progress_bar=False)

# Save the model
model.save(os.path.join(log_dir, "final_model"))

# Final score of the trained model.
# NOTE(review): this evaluates on the training environment `env`, not on
# `env_val` - confirm that is intended.
mean_reward, std_reward = evaluate_policy(model, env)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

# Release both vectorized environments.
env.close()
env_val.close()

In [None]:
# Create Evaluation environment
env_val = make_vec_env(make_env, n_envs=1)

# Load the best model saved during training and attach it to the fresh
# evaluation environment. (The training environment `env` was already closed
# at the end of the training cell, so it must not be reused here.)
best_model_path = os.path.join(log_dir, "best_model")
best_model = SAC.load(best_model_path, env=env_val)

mean_reward, std_reward = evaluate_policy(best_model,
                                          env_val,
                                          n_eval_episodes=20)

print(f"Best Model - Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

# Record video of the best model: wrap the evaluation environment so that the
# rollout starting at step 0 (at most 10,000 steps) is saved to
# <log_dir>/videos.
env = VecVideoRecorder(env_val, os.path.join(log_dir, "videos"),
                       video_length=10_000,
                       record_video_trigger=lambda x: x == 0,
                       name_prefix="best_model_{}".format(name_prefix))

# Roll out one episode with deterministic actions, summing the reward.
# `rewards` and `dones` are length-1 arrays because the env wraps a single
# environment; `total_reward` therefore becomes an array after the first step.
total_reward = 0
obs = env.reset()
for _ in range(10_000):
    action, _states = best_model.predict(obs, deterministic=True)
    obs, rewards, dones, info = env.step(action)
    total_reward += rewards
    env.render()
    if dones:
        break

# Closing the recorder also closes the wrapped evaluation environment.
env.close()
print(f"Total reward: {total_reward[0]}")

In [None]:
# Read back the evaluation log stored under log_dir.
data = numpy.load(os.path.join(log_dir, "evaluations.npz"))

# x-axis values and the per-evaluation results (one row per evaluation).
timesteps = data['timesteps']
results = data['results']

# Aggregate each evaluation's row into its mean and standard deviation.
mean_results = results.mean(axis=1)
std_results = results.std(axis=1)

# Draw the mean curve with a one-standard-deviation band around it.
lower_band = mean_results - std_results
upper_band = mean_results + std_results

figure, axes = matplotlib.pyplot.subplots()
axes.plot(timesteps, mean_results)
axes.fill_between(timesteps, lower_band, upper_band, alpha=0.3)

axes.set_xlabel('Timesteps')
axes.set_ylabel('Mean Reward')
axes.set_title(f"{rl_type} Performance on {env_str}")
matplotlib.pyplot.show()