In [1]:
# !apt-get install -y \
#     libgl1-mesa-dev \
#     libgl1-mesa-glx \
#     libglew-dev \
#     libosmesa6-dev \
#     software-properties-common

# !apt-get install -y patchelf

# !apt-get update --fix-missing
# !pip install stable-baselines3
# !pip install mujoco
# !pip install  --upgrade gymnasium==0.29
# !pip install free-mujoco-py

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
software-properties-common is already the newest version (0.99.22.7).
The following additional packages will be installed:
  libegl-dev libgl-dev libgles-dev libgles1 libglu1-mesa libglu1-mesa-dev
  libglvnd-core-dev libglvnd-dev libglx-dev libopengl-dev libosmesa6
The following NEW packages will be installed:
  libegl-dev libgl-dev libgl1-mesa-dev libgl1-mesa-glx libgles-dev libgles1
  libglew-dev libglu1-mesa libglu1-mesa-dev libglvnd-core-dev libglvnd-dev
  libglx-dev libopengl-dev libosmesa6 libosmesa6-dev
0 upgraded, 15 newly installed, 0 to remove and 16 not upgraded.
Need to get 3,952 kB of archives.
After this operation, 18.7 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/main amd64 libglx-dev amd64 1.4.0-1 [14.1 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/main amd64 libgl-dev amd64 1.4.0-1 [101 kB]
Get:3 http://archive.ubuntu.com/ubuntu 

In [2]:
import os
import time
import wandb
import random

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from stable_baselines3.common.buffers import ReplayBuffer

In [3]:
def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)

        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk

  and should_run_async(code)


In [15]:
class SoftQNetwork(nn.Module):

    def __init__(self, env):
        super().__init__()
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x, a):
        x = torch.cat([x,a], 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

  and should_run_async(code)


In [13]:
class Actor(nn.Module):

    def __init__(self, env):
        super().__init__()
        self.fc1 = nn.Linear(np.array(env.observation_space.shape).prod(), 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mean = nn.Linear(256, np.prod(env.single_action_space.shape))
        self.fc_logstd = nn.Linear(256, np.prod(env.single_action_space.shape))
        self.log_std_min = -5
        self.log_std_max = 2
        self.register_buffer(
            "action_scale", torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32)
        )

        self.register_buffer(
            "action_bias", torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32)
        )

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mean = self.fc_mean(x)
        log_std = self.fc_logstd(x)
        log_std = torch.tanh(log_std)
        log_std = self.log_std_min + 0.5 *  (self.log_std_max - self.log_std_min) * ( log_std + 1)
        return mean, log_std

    def get_actions(self, x):
        mean, log_std = self.forward(x)
        std = log_std.exp()
        normal = torch.distributions.Normal(mean, std)
        x_t = normal.rsample()
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)
        log_prob -= torch.log( self.action_scale * (1 - y_t.pow(2)) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean


  and should_run_async(code)


In [1]:
def train(env_id,
          seed,
          total_timesteps,
          buffer_size,
          gamma,
          tau,
          batch_size,
          learning_starts,
          policy_lr,
          q_lr,
          policy_frequency,
          target_network_frequency,
          noise_clip,
          alpha):

    run_name = f"{env_id}__{seed}__{int(time.time())}"
    wandb.init(
        project="sac-mujoco-benchmark",
        config={
            "env":env_id,
            "seed":seed,
            "timesteps":total_timesteps,
            "buffer_size":buffer_size,
            "gamma":gamma,
            "tau":tau,
            "batch_size":batch_size,
            "learning_starts":learning_starts,
            "policy_lr":policy_lr,
            "q_lr",q_lr,
            "policy_frequency":policy_frequency,
            "target_network_frequency":target_network_frequency,
            "noise_clip":noise_clip,
            "alpha":alpha,
        },
        sync_tensorboard=True,
        monitor_gym=True,
        name=run_name
    )
    writer = SummaryWriter(f"runs/{run_name}")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    envs = gym.vector.SyncVectorEnv([make_env(env_id, seed, 0, True, run_name)])

    max_action = float(envs.single_action_space.high[0])

    actor = Actor(envs).to(device)
    qf1 = SoftQNetwork(envs).to(device)
    qf2 = SoftQNetwork(envs).to(device)
    qf1_target = SoftQNetwork(envs).to(device)
    qf2_target = SoftQNetwork(envs).to(device)
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())
    q_optimiser = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=q_lr)
    actor_optimizer = optim.Adam(actor.parameters(), lr=policy_lr)

    envs.single_observation_space.dtype = np.float32
    rb = ReplayBuffer(
        buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        handle_timeout_termination=False,
    )


    start_time = time.time()

    obs,_ = envs.reset(seed=seed)
    for global_step in range(total_timesteps):
        if global_step < learning_starts:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            actions, _, _ = actor.get_actions(torch.Tensor(obs).to(device))
            actions = actions.detach().cpu().numpy()

        next_obs, rewards, terminated, truncated, infos = envs.step(actions)

        if "final_info" in infos:
            for info in infos["final_info"]:
                print(f"global_step={global_step} episodic_return={info['episode']['r']}")
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                break

        real_next_obs = next_obs.copy()
        for idx,d in enumerate(truncated):
            if d:
                real_next_obs[idx] = infos["final_observation"][idx]
        rb.add(obs, real_next_obs, actions, rewards, terminated, infos)

        obs = next_obs

        if global_step > learning_starts:
            data = rb.sample(batch_size)
            with torch.no_grad():
                next_state_actions, next_state_log_pi, _ = actor.get_actions(data.next_observations)
                qf1_next_target = qf1_target(data.next_observations, next_state_actions)
                qf2_next_target = qf2_target(data.next_observations, next_state_actions)
                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * gamma * (min_qf_next_target).view(-1)

            qf1_a_values = qf1(data.observations, data.actions).view(-1)
            qf2_a_values = qf2(data.observations, data.actions).view(-1)
            qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
            qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
            qf_loss = qf1_loss + qf2_loss

            q_optimiser.zero_grad()
            qf_loss.backward()
            q_optimiser.step()

            if global_step % policy_frequency == 0:
                pi, log_pi, _ = actor.get_actions(data.observations)
                qf1_pi = qf1(data.observations, pi)
                qf2_pi = qf2(data.observations, pi)
                min_qf_pi = torch.min(qf1_pi, qf2_pi).view(-1)
                actor_loss = ((alpha * log_pi) - min_qf_pi).mean()

                actor_optimizer.zero_grad()
                actor_loss.backward()
                actor_optimizer.step()

            if global_step % target_network_frequency == 0:
                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

            if global_step % 100 == 0:
                writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
                writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
                writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
                writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
                writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
                writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
                writer.add_scalar("losses/alpha", alpha, global_step)
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    envs.close()
    writer.close()
    wandb.finish()


In [8]:
env = {"hopper":"Hopper-v2","humanoid":"Humanoid-v2","halfCheetah":"HalfCheetah-v2","ant":"Ant-v2"}
seed = 1
total_timesteps = 500000
policy_learning_rate = 0.0003
q_learning_rate = 0.0001
buffer_size = 100000
gamma = 0.99
tau = 0.005
batch_size = 256
exploration_noise = 0.1
learning_starts = 50000
policy_frequency = 2
target_network_frequency = 1
alpha = 0.2
noise_clip = 0.5

In [25]:
train(env["halfCheetah"],
      seed,
      total_timesteps,
      buffer_size,
      gamma,
      tau,
      batch_size,
      learning_starts,
      policy_learning_rate,
      q_learning_rate,
      policy_frequency,
      target_network_frequency,
      noise_clip,
      alpha)

  logger.deprecation(
  logger.deprecation(
  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")
  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")


Moviepy - Building video /content/videos/HalfCheetah-v2__1__1691937044/rl-video-episode-0.mp4.
Moviepy - Writing video /content/videos/HalfCheetah-v2__1__1691937044/rl-video-episode-0.mp4





Moviepy - Done !
Moviepy - video ready /content/videos/HalfCheetah-v2__1__1691937044/rl-video-episode-0.mp4
global_step=999 episodic_return=[-290.36285]
Moviepy - Building video /content/videos/HalfCheetah-v2__1__1691937044/rl-video-episode-1.mp4.
Moviepy - Writing video /content/videos/HalfCheetah-v2__1__1691937044/rl-video-episode-1.mp4





Moviepy - Done !
Moviepy - video ready /content/videos/HalfCheetah-v2__1__1691937044/rl-video-episode-1.mp4
global_step=1999 episodic_return=[-265.64658]
global_step=2999 episodic_return=[-277.56036]
global_step=3999 episodic_return=[-177.74356]
global_step=4999 episodic_return=[-304.93332]
global_step=5999 episodic_return=[-218.62822]
global_step=6999 episodic_return=[-216.49706]
global_step=7999 episodic_return=[-349.96634]
Moviepy - Building video /content/videos/HalfCheetah-v2__1__1691937044/rl-video-episode-8.mp4.
Moviepy - Writing video /content/videos/HalfCheetah-v2__1__1691937044/rl-video-episode-8.mp4





Moviepy - Done !
Moviepy - video ready /content/videos/HalfCheetah-v2__1__1691937044/rl-video-episode-8.mp4
global_step=8999 episodic_return=[-348.958]
global_step=9999 episodic_return=[-166.24995]
global_step=10999 episodic_return=[-198.58427]
global_step=11999 episodic_return=[-193.47505]
global_step=12999 episodic_return=[-117.61518]
global_step=13999 episodic_return=[-364.74878]
global_step=14999 episodic_return=[-171.0557]
global_step=15999 episodic_return=[-285.139]
global_step=16999 episodic_return=[-218.25229]
global_step=17999 episodic_return=[-205.81152]
global_step=18999 episodic_return=[-326.8549]
global_step=19999 episodic_return=[-387.1138]
global_step=20999 episodic_return=[-302.94595]
global_step=21999 episodic_return=[-160.01483]
global_step=22999 episodic_return=[-145.06184]
global_step=23999 episodic_return=[-34.949318]
global_step=24999 episodic_return=[-326.93106]
global_step=25999 episodic_return=[-280.37436]
global_step=26999 episodic_return=[-198.99377]
Moviepy 



Moviepy - Done !
Moviepy - video ready /content/videos/HalfCheetah-v2__1__1691937044/rl-video-episode-27.mp4
global_step=27999 episodic_return=[-269.1459]
global_step=28999 episodic_return=[-211.94794]
global_step=29999 episodic_return=[-324.1697]
global_step=30999 episodic_return=[-298.2644]
global_step=31999 episodic_return=[-222.77611]
global_step=32999 episodic_return=[-229.46997]
global_step=33999 episodic_return=[-290.52527]
global_step=34999 episodic_return=[-243.12852]
global_step=35999 episodic_return=[-337.54956]
global_step=36999 episodic_return=[-328.91647]
global_step=37999 episodic_return=[-256.67102]
global_step=38999 episodic_return=[-268.15323]
global_step=39999 episodic_return=[-367.25827]
global_step=40999 episodic_return=[-282.81506]
global_step=41999 episodic_return=[-291.27274]
global_step=42999 episodic_return=[-416.1082]
global_step=43999 episodic_return=[-195.68433]
global_step=44999 episodic_return=[-273.47717]
global_step=45999 episodic_return=[-317.45792]
gl



Moviepy - Done !
Moviepy - video ready /content/videos/HalfCheetah-v2__1__1691937044/rl-video-episode-64.mp4
global_step=64999 episodic_return=[-527.8817]
global_step=65999 episodic_return=[-517.5039]
global_step=66999 episodic_return=[-517.91736]
global_step=67999 episodic_return=[-487.32184]
global_step=68999 episodic_return=[-489.32846]
global_step=69999 episodic_return=[-485.72675]
global_step=70999 episodic_return=[-491.18024]
global_step=71999 episodic_return=[-499.59402]
global_step=72999 episodic_return=[-422.36078]
global_step=73999 episodic_return=[-461.14386]
global_step=74999 episodic_return=[-473.17523]
global_step=75999 episodic_return=[-433.60135]
global_step=76999 episodic_return=[-448.52777]
global_step=77999 episodic_return=[-444.9023]
global_step=78999 episodic_return=[-428.50986]
global_step=79999 episodic_return=[-453.74527]
global_step=80999 episodic_return=[-434.34912]
global_step=81999 episodic_return=[-437.01117]
global_step=82999 episodic_return=[-467.45078]
g



Moviepy - Done !
Moviepy - video ready /content/videos/HalfCheetah-v2__1__1691937044/rl-video-episode-125.mp4
global_step=125999 episodic_return=[-376.3235]
global_step=126999 episodic_return=[-376.86282]
global_step=127999 episodic_return=[-366.40778]
global_step=128999 episodic_return=[-381.72992]
global_step=129999 episodic_return=[-382.21994]
global_step=130999 episodic_return=[-386.13715]
global_step=131999 episodic_return=[-375.02954]
global_step=132999 episodic_return=[-382.4894]
global_step=133999 episodic_return=[-382.91943]
global_step=134999 episodic_return=[-382.70465]
global_step=135999 episodic_return=[-375.74857]
global_step=136999 episodic_return=[-374.56204]
global_step=137999 episodic_return=[-345.268]
global_step=138999 episodic_return=[-362.09613]
global_step=139999 episodic_return=[-379.59378]
global_step=140999 episodic_return=[-361.62286]
global_step=141999 episodic_return=[-355.4788]
global_step=142999 episodic_return=[-372.67386]
global_step=143999 episodic_ret



Moviepy - Done !
Moviepy - video ready /content/videos/HalfCheetah-v2__1__1691937044/rl-video-episode-216.mp4
global_step=216999 episodic_return=[-316.0373]
global_step=217999 episodic_return=[-321.3237]
global_step=218999 episodic_return=[-303.06613]
global_step=219999 episodic_return=[-315.2012]
global_step=220999 episodic_return=[-328.48187]
global_step=221999 episodic_return=[-331.32974]
global_step=222999 episodic_return=[-305.25485]
global_step=223999 episodic_return=[-283.40903]
global_step=224999 episodic_return=[-332.05267]
global_step=225999 episodic_return=[-295.77548]
global_step=226999 episodic_return=[-333.07455]
global_step=227999 episodic_return=[-293.27646]
global_step=228999 episodic_return=[-286.24738]
global_step=229999 episodic_return=[-322.3173]
global_step=230999 episodic_return=[-300.1896]
global_step=231999 episodic_return=[-291.33475]
global_step=232999 episodic_return=[-308.9093]
global_step=233999 episodic_return=[-282.49658]
global_step=234999 episodic_retu



Moviepy - Done !
Moviepy - video ready /content/videos/HalfCheetah-v2__1__1691937044/rl-video-episode-343.mp4
global_step=343999 episodic_return=[679.85645]
global_step=344999 episodic_return=[514.1515]
global_step=345999 episodic_return=[681.74585]
global_step=346999 episodic_return=[771.13873]
global_step=347999 episodic_return=[779.39874]
global_step=348999 episodic_return=[795.208]
global_step=349999 episodic_return=[396.26105]
global_step=350999 episodic_return=[769.7005]
global_step=351999 episodic_return=[887.06036]
global_step=352999 episodic_return=[625.3501]
global_step=353999 episodic_return=[933.1343]
global_step=354999 episodic_return=[441.64984]
global_step=355999 episodic_return=[1071.3202]
global_step=356999 episodic_return=[1030.1238]
global_step=357999 episodic_return=[1083.3474]
global_step=358999 episodic_return=[955.9901]
global_step=359999 episodic_return=[1111.9424]
global_step=360999 episodic_return=[1115.705]
global_step=361999 episodic_return=[1046.1641]
globa



Moviepy - Done !
Moviepy - video ready /content/videos/HalfCheetah-v2__1__1691937044/rl-video-episode-512.mp4
global_step=512999 episodic_return=[2204.3438]
global_step=513999 episodic_return=[2224.696]
global_step=514999 episodic_return=[2187.8833]
global_step=515999 episodic_return=[2258.0798]
global_step=516999 episodic_return=[2227.1199]
global_step=517999 episodic_return=[2248.6084]
global_step=518999 episodic_return=[2259.7312]
global_step=519999 episodic_return=[2234.4626]
global_step=520999 episodic_return=[2268.6958]
global_step=521999 episodic_return=[2291.1252]
global_step=522999 episodic_return=[2238.9736]
global_step=523999 episodic_return=[2371.8618]
global_step=524999 episodic_return=[2253.3445]
global_step=525999 episodic_return=[2272.8533]
global_step=526999 episodic_return=[2302.851]
global_step=527999 episodic_return=[2235.0808]
global_step=528999 episodic_return=[2289.522]
global_step=529999 episodic_return=[2180.5186]
global_step=530999 episodic_return=[2275.1746]




Moviepy - Done !
Moviepy - video ready /content/videos/HalfCheetah-v2__1__1691937044/rl-video-episode-729.mp4
global_step=729999 episodic_return=[2265.598]
global_step=730999 episodic_return=[2237.3757]
global_step=731999 episodic_return=[2258.6562]
global_step=732999 episodic_return=[2355.3076]
global_step=733999 episodic_return=[2202.4995]
global_step=734999 episodic_return=[2349.037]
global_step=735999 episodic_return=[2308.035]
global_step=736999 episodic_return=[2356.153]
global_step=737999 episodic_return=[2338.9646]
global_step=738999 episodic_return=[2189.5933]
global_step=739999 episodic_return=[2256.6716]
global_step=740999 episodic_return=[2173.1855]
global_step=741999 episodic_return=[2234.1045]
global_step=742999 episodic_return=[2227.5483]
global_step=743999 episodic_return=[2342.8767]
global_step=744999 episodic_return=[2330.8066]
global_step=745999 episodic_return=[2340.1436]
global_step=746999 episodic_return=[2335.048]
global_step=747999 episodic_return=[2387.3328]
gl

KeyboardInterrupt: ignored