## Create a virtual display 🔽

During this notebook, we'll need to generate a replay video. To do so in Colab, **we need a virtual screen to be able to render the environment** (and thus record the frames).

Hence, the following cells install the libraries and then create and start a virtual screen 🖥

In [1]:
%%capture
# System packages for headless rendering and video encoding.
# - python3-opengl replaces python-opengl (a Python 2 package that is not available
#   on recent Ubuntu images such as current Colab runtimes).
# - -y answers apt's confirmation prompt; with %%capture the prompt would be invisible.
!apt install -y python3-opengl
!apt install -y ffmpeg
!apt install -y xvfb
!pip3 install pyvirtualdisplay
!apt install -y x11-utils
!pip install pyglet

In [2]:
# Virtual display
from pyvirtualdisplay import Display

# Start an invisible 1400x900 virtual display so environments can render frames
# for video recording without a physical screen (`visible=0`).
# NOTE(review): presumably backed by the xvfb package installed in the previous cell — confirm.
virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()

<pyvirtualdisplay.display.Display at 0x7c00f1eeab40>

In [3]:
# Display video
import glob
import io
import base64
from IPython.display import HTML

def show_video(pattern='videos/**/*.mp4'):
    """Embed the most recently written mp4 video in the notebook output.

    Args:
        pattern: glob pattern for candidate videos; ``**`` is expanded
            recursively. The default also matches the per-run sub-folders that
            training videos are written to (``videos/<run_name>/*.mp4``) as
            well as files directly under ``videos/``.

    Returns:
        An ``IPython.display.HTML`` object with the video inlined as base64,
        or ``None`` (after printing a message) when no file matches.
    """
    import os  # local: this cell may run before the notebook's main imports

    mp4list = glob.glob(pattern, recursive=True)
    if not mp4list:
        print(f"No video files found matching {pattern!r}.")
        return None

    # glob order is arbitrary; show the newest recording deterministically.
    video = max(mp4list, key=os.path.getmtime)
    # Read-only is sufficient; the previous 'r+b' mode also demanded write permission.
    with open(video, 'rb') as f:
        encoded = base64.b64encode(f.read()).decode()
    # Create an HTML display object for Colab
    return HTML(data=f'<video width="1000" controls><source src="data:video/mp4;base64,{encoded}" type="video/mp4" /></video>')

### Install dependencies 🔽

The first step is to install the dependencies; we'll install several of them:
- `gymnasium`
- `panda-gym`: Contains the robotics arm environments.
- `stable-baselines3`: The SB3 deep reinforcement learning library.
- `huggingface_sb3`: Additional code for Stable-baselines3 to load and upload models from the Hugging Face 🤗 Hub.
- `huggingface_hub`: Library allowing anyone to work with the Hub repositories.

⏲ The installation can **take 10 minutes**.

In [4]:
# gymnasium supplies the environments used below (the code imports it as `gym`).
#!pip install stable-baselines3[extra]
!pip install gymnasium



In [5]:
# !pip install huggingface_sb3
# !pip install huggingface_hub
# !pip install panda_gym

## W&B Prerequisites

Install the W&B Python SDK and log in:

In [6]:
# NOTE(review): the legacy `gym` package appears unused — the code imports `gymnasium as gym`.
# NOTE(review): installing torch from the cu118 index may replace the runtime's preinstalled torch build — confirm this is intended.
!pip install wandb -qU
!pip install -q gym numpy tensorboard
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [7]:
# Log in to your W&B account
import wandb
import random
import math

In [8]:
# Authenticate with W&B; prompts for an account choice / API key when none is stored,
# and the notebook displays the returned value (True in the captured output).
wandb.login()

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

 2


[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmishra39[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## Setup

In [9]:
import argparse
import os
import random
import time
from distutils.util import strtobool

import gymnasium as gym  # Use gymnasium instead of gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

## Arguments

In [10]:
from dataclasses import dataclass
from typing import Optional

@dataclass
class PPOConfig:
    """Hyperparameters for a PPO run.

    ``batch_size`` (= num_envs * num_steps) and ``minibatch_size``
    (= batch_size // num_minibatches) are derived in ``__post_init__``. They are
    plain instance attributes, so they appear in ``vars(config)`` but are not
    constructor arguments.

    Raises:
        ValueError: if num_envs, num_steps or num_minibatches is not positive,
            or if num_minibatches exceeds the batch size (empty minibatches).
    """

    # Experiment settings
    exp_name: str = "ppo_experiment"          # used in the run name
    gym_id: str = "CartPole-v1"               # Gymnasium environment id
    learning_rate: float = 2.5e-4             # initial Adam learning rate
    seed: int = 1                             # seed for random / numpy / torch and the envs
    total_timesteps: int = 25000              # environment steps summed over all envs
    torch_deterministic: bool = True          # value for torch.backends.cudnn.deterministic
    cuda: bool = True                         # use CUDA when it is available
    track: bool = False                       # log the run to Weights & Biases
    wandb_project_name: str = "ppo-implementation-details"
    wandb_entity: Optional[str] = None
    capture_video: bool = False               # record training episodes to mp4

    # Algorithm specific arguments
    num_envs: int = 4                         # parallel environments
    num_steps: int = 128                      # rollout length per env per update
    anneal_lr: bool = True                    # decay the learning rate linearly towards 0
    gae: bool = True                          # use GAE for advantage estimation
    gamma: float = 0.99                       # discount factor
    gae_lambda: float = 0.95                  # GAE lambda
    num_minibatches: int = 4                  # minibatches per optimisation epoch
    update_epochs: int = 4                    # optimisation passes over each batch
    norm_adv: bool = True                     # normalise advantages per minibatch
    clip_coef: float = 0.2                    # clipping range for ratio (and value)
    clip_vloss: bool = True                   # use the clipped value loss
    ent_coef: float = 0.01                    # entropy bonus coefficient
    vf_coef: float = 0.5                      # value loss coefficient
    max_grad_norm: float = 0.5                # gradient-norm clipping threshold
    target_kl: Optional[float] = None         # stop epochs early above this approx KL

    def __post_init__(self):
        # Validate the sizes the derived values depend on; otherwise a bad value
        # surfaces much later as ZeroDivisionError or a zero-step range() in the
        # minibatch loop.
        for name in ("num_envs", "num_steps", "num_minibatches"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be a positive integer, got {value!r}")

        # Computed values
        self.batch_size = int(self.num_envs * self.num_steps)
        self.minibatch_size = int(self.batch_size // self.num_minibatches)
        if self.minibatch_size == 0:
            raise ValueError(
                f"num_minibatches ({self.num_minibatches}) exceeds batch_size "
                f"({self.batch_size}); every minibatch would be empty"
            )

# Create instance with default values
# args = PPOConfig()

# Or customize specific values
# args = PPOConfig(learning_rate=1e-3, num_envs=8, total_timesteps=50000)

# print(f"Batch size: {args.batch_size}")
# print(f"Minibatch size: {args.minibatch_size}")

## Gym Environment

In [23]:
# def make_env(gym_id, seed, idx, capture_video, run_name):
#     def thunk():
#       env = gym.make(gym_id, render_mode="rgb_array" if capture_video else None)
#       env = gym.wrappers.RecordEpisodeStatistics(env)
#       if capture_video:
#         env = gym.wrappers.RecordVideo(env, f"videos/{run_name}", episode_trigger=lambda episode_id: True,  # we'll recreate env only on selected episodes
#             disable_logger=True,  # silence ffmpeg logs) # added to avoid conflict with wandb logger
#       env.action_space.seed(seed)
#       env.observation_space.seed(seed)
#       return env

#     return thunk

def make_env(env_id, seed, idx, capture_video, run_name, video_dir=None):
    """Return a zero-argument factory that builds one (optionally recorded) environment.

    The factory form is what ``gym.vector.SyncVectorEnv`` expects.

    Args:
        env_id: Gymnasium environment id, e.g. ``"CartPole-v1"``.
        seed: seed applied to the env's action and observation spaces.
        idx: index of this env inside the vector env; only index 0 records video.
        capture_video: whether env ``idx == 0`` records every episode to mp4.
        run_name: sub-folder (under the video root) and file-name prefix part.
        video_dir: root folder for training videos. When omitted, uses
            ``config.video_dir`` if a global ``config`` exists, else ``"videos"``.
    """
    def thunk():
        # base env with render_mode for video
        env = gym.make(env_id, render_mode="rgb_array")
        env = gym.wrappers.RecordEpisodeStatistics(env)

        if capture_video and idx == 0:
            root = video_dir
            if root is None:
                # `config` may not exist (e.g. when W&B tracking is disabled);
                # fall back to "videos" rather than raising NameError.
                cfg = globals().get("config")
                root = cfg.video_dir if cfg is not None else "videos"
            env = gym.wrappers.RecordVideo(
                env,
                video_folder=os.path.join(root, run_name),
                name_prefix=f"{env_id}__{run_name}",
                episode_trigger=lambda episode_id: True,  # Record every episode for that env
                disable_logger=True,
            )
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk

## Layer Initialization

In [12]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise a linear layer in place and return it.

    The weight matrix gets an orthogonal initialisation scaled by ``std`` and
    every bias entry is set to ``bias_const``. Returning the layer lets calls be
    nested directly inside ``nn.Sequential(...)``.
    """
    init = torch.nn.init
    init.orthogonal_(layer.weight, gain=std)
    init.constant_(layer.bias, val=bias_const)
    return layer

## Main Loop

In [24]:
# Experiment setup: config, run name, W&B run, TensorBoard writer, seeding, device.
args = PPOConfig(track=True, capture_video=True)
run_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"

# Run-level settings that later cells read as `config.<name>` (environment
# creation, video folders, evaluation). Built once so `config` exists whether
# or not W&B tracking is enabled.
run_settings = dict(
    env_id=args.gym_id,
    total_timesteps=args.total_timesteps,
    num_envs=args.num_envs,
    video_dir="videos",
    video_eval_dir="videos_eval",
    train_video_every_n_episodes=50,
    eval_episodes=5,
)

if args.track:
    # Log the full PPO hyperparameter set alongside the run settings. They
    # overlap only on total_timesteps and num_envs, which hold the same values.
    run = wandb.init(
        project=args.wandb_project_name,
        entity=args.wandb_entity,
        name=run_name,
        config={**vars(args), **run_settings},
    )
    config = wandb.config
else:
    from types import SimpleNamespace

    # Attribute-access stand-in for wandb.config so `config.video_dir` etc. still work.
    run = None
    config = SimpleNamespace(**run_settings)

writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
    "hyperparameters",
    "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic

device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

0,1
charts/SPS,‚ñÅ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
charts/episodic_length,‚ñÅ‚ñÅ‚ñÉ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÅ‚ñÇ‚ñÑ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÑ‚ñÅ‚ñÇ‚ñÉ‚ñÑ‚ñÅ‚ñÉ‚ñÉ‚ñÑ‚ñÜ‚ñÉ‚ñà‚ñÜ‚ñÇ‚ñÑ‚ñÉ
charts/episodic_return,‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÇ‚ñÇ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÇ‚ñÅ‚ñÉ‚ñÇ‚ñÇ‚ñÅ‚ñÇ‚ñÇ‚ñÑ‚ñÇ‚ñÖ‚ñÅ‚ñÖ‚ñà‚ñÑ‚ñÅ‚ñÇ‚ñÉ‚ñÇ‚ñÇ‚ñà‚ñÇ‚ñÖ‚ñÑ‚ñà‚ñà‚ñÇ
charts/learning_rate,‚ñà‚ñà‚ñà‚ñá‚ñá‚ñÜ‚ñÜ‚ñà‚ñà‚ñà‚ñá‚ñá‚ñá‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÖ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÑ‚ñÑ‚ñÇ‚ñÇ‚ñÑ‚ñÉ‚ñÉ‚ñÉ‚ñÇ‚ñÅ‚ñÅ‚ñÅ
global_step,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñà
losses/approx_kl,‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñà‚ñá‚ñÑ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÉ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
losses/clipfrac,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
losses/entropy,‚ñà‚ñà‚ñà‚ñá‚ñá‚ñÜ‚ñÜ‚ñÖ‚ñÑ‚ñÖ‚ñÖ‚ñÑ‚ñÑ‚ñÑ‚ñÉ‚ñÖ‚ñÑ‚ñÖ‚ñÑ‚ñÑ‚ñÇ‚ñÉ‚ñÑ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÅ‚ñÑ‚ñÉ‚ñÑ‚ñÑ‚ñÉ‚ñÉ‚ñÉ‚ñÇ‚ñÉ‚ñÑ‚ñÉ
losses/explained_variance,‚ñÅ‚ñÇ‚ñÅ‚ñÇ‚ñÅ‚ñÅ‚ñÇ‚ñÅ‚ñÇ‚ñÅ‚ñÅ‚ñÇ‚ñÅ‚ñÇ‚ñÅ‚ñÇ‚ñÅ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÅ‚ñÑ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÑ‚ñÇ‚ñÉ‚ñÑ‚ñà‚ñÉ‚ñá‚ñÑ‚ñÑ
losses/old_approx_kl,‚ñÑ‚ñÑ‚ñÇ‚ñÑ‚ñÑ‚ñÜ‚ñá‚ñÜ‚ñÑ‚ñà‚ñÇ‚ñÉ‚ñÉ‚ñÅ‚ñÖ‚ñÑ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÅ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÉ‚ñÑ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ

0,1
charts/SPS,1203
charts/episodic_length,98
charts/episodic_return,98
charts/learning_rate,1e-05
global_step,24576
losses/approx_kl,0.0
losses/clipfrac,0
losses/entropy,0.61023
losses/explained_variance,0.08197
losses/old_approx_kl,0.0




In [25]:
# env setup: build the vectorised training environments (only env 0 records
# video, see make_env). `run_name` is reused from the setup cell so the
# TensorBoard log dir, the W&B run and the training-video folder share one name.
os.makedirs(config.video_dir, exist_ok=True)

envs = gym.vector.SyncVectorEnv(
    [
        make_env(config.env_id, seed=args.seed + i, idx=i,
                 capture_video=args.capture_video, run_name=run_name)
        for i in range(config.num_envs)
    ]
)
# Passing a seed to reset() seeds each sub-environment (seed, seed + 1, ...).
# Without it, initial states are not reproducible even though random/numpy/torch
# are seeded. Later unseeded resets continue from these RNG states.
obs, infos = envs.reset(seed=args.seed)

In [29]:
import glob

def log_latest_training_video(step: int):
    """Upload this run's most recently modified training video to W&B as ``train/video``.

    Searches ``<config.video_dir>/<run_name>/*.mp4`` and returns quietly when
    nothing has been recorded yet.

    Args:
        step: W&B step to log at; also shown in the video caption.
    """
    run_video_dir = os.path.join(config.video_dir, run_name)
    by_age = sorted(glob.glob(os.path.join(run_video_dir, "*.mp4")), key=os.path.getmtime)
    if by_age:
        newest = by_age[-1]
        wandb.log({"train/video": wandb.Video(newest, caption=f"Step {step}")}, step=step)

## Agent Setup

In [26]:
class Agent(nn.Module):
    """Actor-critic MLP for a discrete action space.

    Two independent networks share the same layout
    (obs_dim -> 64 -> 64 -> out, Tanh activations, orthogonal init via `layer_init`):

    - ``critic`` outputs one state-value estimate per observation; its last
      layer is initialised with std=1.
    - ``actor`` outputs one logit per discrete action; its last layer uses a
      small std=0.01, so the initial logits are close to each other and the
      starting policy is close to uniform.

    Args:
        envs: vector env exposing ``single_observation_space.shape`` and
            ``single_action_space.n``.
    """

    def __init__(self, envs):
        super().__init__()
        in_features = np.array(envs.single_observation_space.shape).prod()

        def mlp_head(out_features, last_std):
            # Layers are created and initialised strictly in sequence, which
            # fixes the order of RNG draws during construction.
            return nn.Sequential(
                layer_init(nn.Linear(in_features, 64)),
                nn.Tanh(),
                layer_init(nn.Linear(64, 64)),
                nn.Tanh(),
                layer_init(nn.Linear(64, out_features), std=last_std),
            )

        # The critic is built and registered before the actor; this also fixes
        # the order of agent.parameters().
        self.critic = mlp_head(1, 1)
        self.actor = mlp_head(envs.single_action_space.n, 0.01)

    def get_value(self, x):
        """Return the critic's value estimate for observations ``x``."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Return ``(action, log_prob(action), entropy, value)`` for observations ``x``.

        When ``action`` is None, one is sampled from the categorical policy.
        Otherwise the given actions are scored, as in the PPO update.
        """
        dist = Categorical(logits=self.actor(x))
        chosen = dist.sample() if action is None else action
        return chosen, dist.log_prob(chosen), dist.entropy(), self.critic(x)

## Training

In [27]:
agent = Agent(envs).to(device)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

# ALGO Logic: rollout storage, indexed [step, env]. Observation and action
# buffers append the per-env space shape as trailing dimensions.
rollout_shape = (args.num_steps, args.num_envs)
obs = torch.zeros(rollout_shape + envs.single_observation_space.shape, device=device)
actions = torch.zeros(rollout_shape + envs.single_action_space.shape, device=device)
logprobs = torch.zeros(rollout_shape, device=device)
rewards = torch.zeros(rollout_shape, device=device)
dones = torch.zeros(rollout_shape, device=device)
values = torch.zeros(rollout_shape, device=device)

# TRY NOT TO MODIFY: start the game
global_step = 0
next_obs, _ = envs.reset()
next_obs = torch.tensor(next_obs, dtype=torch.float32, device=device)
next_done = torch.zeros(args.num_envs, device=device)
num_updates = args.total_timesteps // args.batch_size

for label, val in (
    ("total_timesteps", args.total_timesteps),
    ("batch_size", args.batch_size),
    ("num_updates", num_updates),
):
    print(f"{label}: {val}")

total_timesteps: 25000
batch_size: 512
num_updates: 48


## Training Loop

In [30]:
# PPO training loop. Each update:
#   1. optionally anneals the learning rate linearly towards 0,
#   2. collects args.num_steps transitions from each of the args.num_envs envs,
#   3. computes advantages and returns (GAE, or plain discounted returns),
#   4. runs args.update_epochs passes of clipped-surrogate minibatch updates,
#   5. logs diagnostics to TensorBoard and, when tracking, to W&B.
# Uses globals from earlier cells: args, agent, optimizer, envs, device, writer,
# the rollout buffers, next_obs/next_done, global_step and num_updates.
start_time = time.time()
for update in range(1, num_updates + 1):  # ALGO 1: Line #1
    # --- learning-rate annealing -------------------------------------------
    if args.anneal_lr:
        frac = 1.0 - (update - 1.0) / num_updates
        optimizer.param_groups[0]["lr"] = frac * args.learning_rate

    # Periodically upload the newest training video. This is independent of
    # anneal_lr, so disabling annealing does not also disable video logging.
    if args.track and args.capture_video and update % 10 == 0:
        log_latest_training_video(global_step)

    # --- policy rollout ----------------------------------------------------
    for step in range(args.num_steps):  # ALGO 1: Line #3
        global_step += args.num_envs
        obs[step] = next_obs
        dones[step] = next_done

        # Algo Logic: Action logic
        with torch.no_grad():
            action, log_prob, _, value = agent.get_action_and_value(next_obs)
            values[step] = value.squeeze(-1)  # critic output (num_envs, 1) -> (num_envs,)
        actions[step] = action
        logprobs[step] = log_prob

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, reward, terminated, truncated, info = envs.step(action.cpu().numpy())
        # Truncation (time limit) is treated like termination, so the value of a
        # truncated state is not bootstrapped.
        done = np.logical_or(terminated, truncated)

        rewards[step] = torch.tensor(reward).to(device).view(-1)
        next_obs = torch.Tensor(next_obs).to(device)
        next_done = torch.Tensor(done).to(device)

        # Finished episodes appear in info["episode"] (from RecordEpisodeStatistics)
        # with a boolean mask info["_episode"]; only the first finished env at this
        # step is logged.
        if "episode" in info and info["_episode"].any():
            env_i = np.flatnonzero(info["_episode"])[0]
            episodic_return = info["episode"]["r"][env_i]
            episodic_length = info["episode"]["l"][env_i]

            print(f"global_step={global_step}, episodic_return={episodic_return}")
            writer.add_scalar("charts/episodic_return", episodic_return, global_step)
            writer.add_scalar("charts/episodic_length", episodic_length, global_step)
            if args.track:
                wandb.log({
                    "charts/episodic_return": episodic_return,
                    "charts/episodic_length": episodic_length,
                }, step=global_step)

    # --- advantages / returns (bootstrap from V(s_T) if not done) ----------
    with torch.no_grad():
        next_value = agent.get_value(next_obs).reshape(1, -1)
        if args.gae:
            advantages = torch.zeros_like(rewards).to(device)  # (num_steps, num_envs), like rewards
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):  # ALGO 1: Line #4
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + nextvalues * nextnonterminal * args.gamma - values[t]  # Eq. 12 (TD error)
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam  # Eq. 11
            returns = advantages + values
        else:
            returns = torch.zeros_like(rewards).to(device)
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    next_return = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    next_return = returns[t + 1]
                returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
            advantages = returns - values

    # --- flatten the batch -------------------------------------------------
    b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
    b_logprobs = logprobs.reshape(-1)
    b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
    b_advantages = advantages.reshape(-1)
    b_returns = returns.reshape(-1)
    b_values = values.reshape(-1)
    # Actions are stored as floats; cast once per update instead of per minibatch.
    b_actions_long = b_actions.long()

    # --- optimise policy and value network ---------------------------------
    b_inds = np.arange(args.batch_size)
    clipfracs = []
    for epoch in range(args.update_epochs):
        np.random.shuffle(b_inds)
        for start in range(0, args.batch_size, args.minibatch_size):
            end = start + args.minibatch_size
            mb_inds = b_inds[start:end]
            _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions_long[mb_inds])
            logratio = newlogprob - b_logprobs[mb_inds]
            ratio = logratio.exp()

            with torch.no_grad():
                # calculate approx_kl http://joschu.net/blog/kl-approx.html
                old_approx_kl = (-logratio).mean()
                approx_kl = ((ratio - 1) - logratio).mean()
                clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

            mb_advantages = b_advantages[mb_inds]
            if args.norm_adv:
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

            # Policy loss
            pg_loss1 = -mb_advantages * ratio  # Eq. 6 (L_CPI)
            pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)  # Eq. 7 (L_CLIP)
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()

            # Value loss
            newvalue = newvalue.view(-1)
            if args.clip_vloss:
                v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                v_clipped = b_values[mb_inds] + torch.clamp(
                    newvalue - b_values[mb_inds],
                    -args.clip_coef,
                    args.clip_coef,
                )
                v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
            else:
                v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

            # Entropy bonus and overall loss
            entropy_loss = entropy.mean()
            loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

            # Backprop with gradient-norm clipping
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
            optimizer.step()

        # Early stopping: checked once per epoch, using the last minibatch's KL.
        if args.target_kl is not None and approx_kl > args.target_kl:
            break

    y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
    var_y = np.var(y_true)
    explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

    # --- logging (once per update) -----------------------------------------
    sps = int(global_step / (time.time() - start_time))
    update_metrics = {
        "charts/learning_rate": optimizer.param_groups[0]["lr"],
        "losses/value_loss": v_loss.item(),
        "losses/policy_loss": pg_loss.item(),
        "losses/entropy": entropy_loss.item(),
        "losses/old_approx_kl": old_approx_kl.item(),
        "losses/approx_kl": approx_kl.item(),
        "losses/clipfrac": np.mean(clipfracs),
        "losses/explained_variance": explained_var,
        "charts/SPS": sps,
    }
    for tag, scalar in update_metrics.items():
        writer.add_scalar(tag, scalar, global_step)
    # wandb.init is not given sync_tensorboard, so the TensorBoard scalars above
    # do not reach W&B by themselves; log them explicitly.
    if args.track:
        wandb.log(update_metrics, step=global_step)
    print("SPS:", sps)

envs.close()
writer.close()



global_step=24592, episodic_return=109.0
global_step=24760, episodic_return=97.0
global_step=24952, episodic_return=47.0
global_step=24964, episodic_return=120.0
global_step=25072, episodic_return=119.0
SPS: 53630
global_step=25216, episodic_return=62.0
global_step=25248, episodic_return=43.0
global_step=25284, episodic_return=82.0
global_step=25436, episodic_return=120.0
global_step=25472, episodic_return=55.0
SPS: 14458
global_step=25736, episodic_return=129.0
global_step=25788, episodic_return=87.0
global_step=25812, episodic_return=84.0
global_step=25884, episodic_return=149.0
global_step=25952, episodic_return=53.0
global_step=26064, episodic_return=68.0
global_step=26088, episodic_return=68.0
SPS: 6348
global_step=26160, episodic_return=51.0
global_step=26268, episodic_return=95.0
global_step=26288, episodic_return=31.0
global_step=26592, episodic_return=125.0
SPS: 4550
global_step=26628, episodic_return=140.0
global_step=26756, episodic_return=116.0
global_step=26796, episodic_r



SPS: 2120
global_step=29464, episodic_return=246.0
global_step=29608, episodic_return=141.0
SPS: 2006
global_step=30004, episodic_return=285.0
SPS: 1886
global_step=30328, episodic_return=215.0
global_step=30420, episodic_return=399.0
global_step=30488, episodic_return=120.0
SPS: 1453
global_step=30820, episodic_return=302.0
global_step=30872, episodic_return=112.0
SPS: 1355
global_step=31424, episodic_return=233.0
global_step=31716, episodic_return=72.0
SPS: 1322
global_step=31748, episodic_return=354.0
global_step=31764, episodic_return=222.0
global_step=31852, episodic_return=257.0
global_step=32092, episodic_return=93.0
SPS: 1242
global_step=32536, episodic_return=192.0
global_step=32608, episodic_return=214.0
global_step=32624, episodic_return=192.0
SPS: 1185
global_step=32836, episodic_return=185.0
global_step=33088, episodic_return=137.0
SPS: 1139
global_step=33404, episodic_return=141.0
SPS: 1132
global_step=33868, episodic_return=194.0
global_step=33976, episodic_return=142.0




SPS: 1088
global_step=34616, episodic_return=159.0
global_step=34728, episodic_return=214.0
SPS: 1062
global_step=34944, episodic_return=203.0
SPS: 1038
global_step=35332, episodic_return=365.0
SPS: 925
global_step=36068, episodic_return=362.0
global_step=36076, episodic_return=185.0
SPS: 878
global_step=36376, episodic_return=411.0
global_step=36500, episodic_return=388.0
global_step=36564, episodic_return=123.0
global_step=36780, episodic_return=175.0
SPS: 839
global_step=37308, episodic_return=232.0
SPS: 835
global_step=37768, episodic_return=300.0
global_step=37808, episodic_return=256.0
SPS: 812
SPS: 812
global_step=38420, episodic_return=162.0
global_step=38504, episodic_return=500.0
global_step=38696, episodic_return=221.0
global_step=38812, episodic_return=375.0
SPS: 788
global_step=39276, episodic_return=213.0




SPS: 782
global_step=39492, episodic_return=198.0
global_step=39712, episodic_return=301.0
SPS: 762
global_step=40192, episodic_return=174.0
global_step=40212, episodic_return=124.0
SPS: 748
global_step=40452, episodic_return=293.0
global_step=40816, episodic_return=500.0
SPS: 749
global_step=40964, episodic_return=192.0
global_step=41160, episodic_return=236.0
global_step=41300, episodic_return=211.0
SPS: 736
global_step=41516, episodic_return=137.0
global_step=41568, episodic_return=187.0
SPS: 718
global_step=42056, episodic_return=188.0
global_step=42100, episodic_return=132.0
global_step=42308, episodic_return=197.0
global_step=42344, episodic_return=295.0
SPS: 706
global_step=42728, episodic_return=167.0
SPS: 707
global_step=43060, episodic_return=187.0
global_step=43096, episodic_return=248.0
global_step=43344, episodic_return=249.0
SPS: 691
global_step=43780, episodic_return=170.0
global_step=43900, episodic_return=209.0
SPS: 673
global_step=44076, episodic_return=336.0
global_s



SPS: 674
global_step=44764, episodic_return=215.0
global_step=44808, episodic_return=182.0
global_step=44920, episodic_return=123.0
SPS: 664
global_step=45244, episodic_return=119.0
global_step=45476, episodic_return=138.0
global_step=45552, episodic_return=260.0
global_step=45564, episodic_return=188.0
SPS: 657
global_step=45860, episodic_return=153.0
SPS: 650
global_step=46176, episodic_return=155.0
global_step=46320, episodic_return=188.0
global_step=46376, episodic_return=224.0
SPS: 651
global_step=46716, episodic_return=134.0
global_step=46976, episodic_return=278.0
SPS: 641
global_step=47124, episodic_return=186.0
global_step=47184, episodic_return=215.0
global_step=47344, episodic_return=156.0
SPS: 642
global_step=48116, episodic_return=192.0
SPS: 636
global_step=48168, episodic_return=297.0
global_step=48368, episodic_return=310.0
SPS: 619
global_step=48768, episodic_return=149.0
global_step=48980, episodic_return=448.0
global_step=49076, episodic_return=239.0
global_step=49088

## Evaluation

In [31]:
def make_eval_env():
    """Build one evaluation environment that records every episode.

    Videos go to ``config.video_eval_dir`` (created if missing) with the
    ``eval`` file-name prefix.
    """
    def _record_every_episode(episode_id):
        return True

    base_env = gym.make(config.env_id, render_mode="rgb_array")
    out_dir = config.video_eval_dir
    os.makedirs(out_dir, exist_ok=True)
    return gym.wrappers.RecordVideo(
        base_env,
        video_folder=out_dir,
        name_prefix="eval",
        episode_trigger=_record_every_episode,
        disable_logger=True,
    )

def evaluate_and_log_videos(model: Agent, episodes: Optional[int] = None):
    """Roll out ``model`` in a recorded eval env and log returns and videos to W&B.

    Actions are sampled from the policy (``get_action_and_value``), not argmax.

    Args:
        model: agent whose ``get_action_and_value`` chooses actions.
        episodes: number of evaluation episodes; defaults to ``config.eval_episodes``.
            Values <= 0 do nothing.

    Logs, at the current ``global_step`` and in a single ``wandb.log`` call:
        ``eval/return``, ``eval/return_min``, ``eval/return_max``: mean / min / max
        episode return, and ``eval/video``: the videos recorded by this call.
        Logging each episode separately at the same step would let W&B keep only
        the last value.
    """
    if episodes is None:
        episodes = config.eval_episodes
    if episodes <= 0:
        # Also avoids video_files[-0:], which would select every old video.
        return

    returns = []
    env = make_eval_env()
    try:
        for _ in range(episodes):
            obs, info = env.reset()
            terminated = truncated = False
            total_r = 0.0
            while not (terminated or truncated):
                obs_t = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
                with torch.no_grad():
                    action, _, _, _ = model.get_action_and_value(obs_t)
                obs, reward, terminated, truncated, info = env.step(action.cpu().numpy()[0])
                total_r += reward
            returns.append(total_r)
    finally:
        # Close before globbing so the final recording is on disk.
        # NOTE(review): assumes RecordVideo finalises the last video on close — verify.
        env.close()

    if wandb.run is None:
        # wandb.log requires an active run (tracking disabled).
        print(f"eval returns: {returns}")
        return

    returns_np = np.asarray(returns, dtype=float)
    payload = {
        "eval/return": float(returns_np.mean()),
        "eval/return_min": float(returns_np.min()),
        "eval/return_max": float(returns_np.max()),
    }
    video_files = sorted(glob.glob(os.path.join(config.video_eval_dir, "*.mp4")), key=os.path.getmtime)
    recent_videos = video_files[-episodes:]
    if recent_videos:
        payload["eval/video"] = [wandb.Video(vf) for vf in recent_videos]
    wandb.log(payload, step=global_step)

In [32]:
# Run the evaluation: roll out the trained agent and upload returns/videos.
evaluate_and_log_videos(model=agent, episodes=config.eval_episodes)

  logger.warn(
