# CleanRL's baseline PPO (Atari) setup on EnvPool, using a PufferLib install

# Install

In [None]:
# Don't restart the runtime until both installs finish (you may be prompted to restart after the first one).
!pip install git+https://github.com/pufferai/pufferlib.git@1.0
!pip install gymnasium[atari,accept-rom-license]==0.29.1 tensorboard==2.11.2 stable_baselines3==2.1.0 torch wandb envpool kron-torch

# Weights & Biases (wandb) API key

In [None]:
import os
from getpass import getpass

# Prompt for the key without echoing it, then expose it to wandb through the
# process environment (wandb picks up WANDB_API_KEY from there).
wandb_key = getpass("Enter your Wandb API key: ")
os.environ.update({"WANDB_API_KEY": wandb_key})
print("Wandb API key has been set as an environment variable.")

# Train

In [None]:
import os
import math
import random
import time
from collections import deque
from IPython.display import display, Image
import matplotlib.pyplot as plt
import zipfile

import envpool
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from kron_torch import Kron
from kron_torch.kron import precond_update_prob_schedule


class RecordEpisodeStatistics(gym.Wrapper):
    """Track per-environment episode returns and lengths for a batched env.

    On every ``step`` the wrapper adds ``infos["reward"]`` (rather than the
    returned ``rewards``) to a per-env running return and bumps a per-env
    step counter. It then publishes the totals as ``infos["r"]`` and
    ``infos["l"]``. Counters are zeroed for envs whose ``infos["terminated"]``
    is set, but the published values are taken before that zeroing.

    ``reset()`` must be called before ``step()``; the accumulators are only
    allocated there. ``infos["r"]``/``infos["l"]`` are the wrapper's own
    buffers and are overwritten by the next ``step()``, so copy them if they
    must outlive it. ``deque_size`` is accepted but not used.
    """

    def __init__(self, env, deque_size=100):
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        # Allocated lazily in reset().
        self.episode_returns = None
        self.episode_lengths = None

    def reset(self, **kwargs):
        observations = super().reset(**kwargs)
        size = self.num_envs
        # Running per-env accumulators.
        self.episode_returns = np.zeros(size, dtype=np.float32)
        self.episode_lengths = np.zeros(size, dtype=np.int32)
        # Initialised here but not read anywhere in this wrapper.
        self.lives = np.zeros(size, dtype=np.int32)
        # Reused output buffers that are handed out through infos.
        self.returned_episode_returns = np.zeros(size, dtype=np.float32)
        self.returned_episode_lengths = np.zeros(size, dtype=np.int32)
        return observations

    def step(self, action):
        observations, rewards, dones, infos = super().step(action)
        self.episode_returns += infos["reward"]
        self.episode_lengths += 1
        # Snapshot the totals before finished envs are zeroed so callers see
        # the complete episode values on the terminating step.
        np.copyto(self.returned_episode_returns, self.episode_returns)
        np.copyto(self.returned_episode_lengths, self.episode_lengths)
        still_running = 1 - infos["terminated"]
        self.episode_returns *= still_running
        self.episode_lengths *= still_running
        infos["r"] = self.returned_episode_returns
        infos["l"] = self.returned_episode_lengths
        return observations, rewards, dones, infos


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and hand it back.

    The weight gets an orthogonal initialisation scaled by ``std`` (the
    ``gain``), and every bias entry is set to ``bias_const``. The same layer
    object is returned so calls can be nested inside ``nn.Sequential(...)``.
    """
    init = torch.nn.init
    init.orthogonal_(layer.weight, gain=std)
    init.constant_(layer.bias, val=bias_const)
    return layer


class Agent(nn.Module):
    """Actor-critic network: a shared conv trunk with a policy head and a value head.

    Observations are divided by 255.0 before entering the trunk. The trunk
    hard-codes 4 input channels and a ``64 * 7 * 7`` flattened size. That
    matches 4 stacked 84x84 frames (84 -> 20 -> 9 -> 7 through the convs);
    TODO confirm against the env's observation shape. The actor head emits one
    logit per discrete action (``envs.single_action_space.n``), and the critic
    emits one scalar per observation.
    """

    def __init__(self, envs):
        super().__init__()
        # Layers are built (and initialised) in this exact order; the order
        # determines how the torch RNG is consumed under a fixed seed.
        trunk_layers = [
            layer_init(nn.Conv2d(4, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 7 * 7, 512)),
            nn.ReLU(),
        ]
        self.network = nn.Sequential(*trunk_layers)
        num_actions = envs.single_action_space.n
        # Small-gain policy head keeps initial action logits near uniform.
        self.actor = layer_init(nn.Linear(512, num_actions), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1)

    def get_value(self, x):
        """Return the critic's value estimate for observations ``x``."""
        features = self.network(x / 255.0)
        return self.critic(features)

    def get_action_and_value(self, x, action=None):
        """Return ``(action, log_prob, entropy, value)`` for observations ``x``.

        When ``action`` is None an action is sampled from the policy;
        otherwise the given action is scored.
        """
        features = self.network(x / 255.0)
        dist = Categorical(logits=self.actor(features))
        chosen = dist.sample() if action is None else action
        return chosen, dist.log_prob(chosen), dist.entropy(), self.critic(features)


def print_model_summary(model):
    """Print a shallow, per-layer summary of ``model`` and its trainable parameter count.

    Only the model's direct children are listed; a child that is an
    ``nn.Sequential`` is expanded one level as ``"<name>.<index>"``. Layers
    with a ``weight`` attribute report its shape (plus the bias shape when a
    bias exists). ``nn.Flatten``/``nn.ReLU`` layers are listed by type only.
    Any other weight-less layer is omitted.
    """

    def describe(module, label=''):
        kind = type(module).__name__
        if not hasattr(module, 'weight'):
            if isinstance(module, (nn.Flatten, nn.ReLU)):
                print(f"{label}: {kind}")
            return
        print(f"{label}: {kind}, Weight shape: {tuple(module.weight.shape)}")
        bias = getattr(module, 'bias', None)
        if bias is not None:
            print(f"  Bias shape: {tuple(bias.shape)}")

    print("Model Summary:")

    # Flatten one level of nn.Sequential into (label, module) pairs.
    entries = []
    for child_name, child in model.named_children():
        if isinstance(child, nn.Sequential):
            entries.extend(
                (f"{child_name}.{idx}", sub) for idx, sub in enumerate(child)
            )
        else:
            entries.append((child_name, child))

    for label, module in entries:
        describe(module, label)

    trainable = sum(
        param.numel() for param in model.parameters() if param.requires_grad
    )
    print(f"\nTotal trainable parameters: {trainable}")


def main(
    exp_name: str = "ppo_atari",
    seed: int = 1,
    torch_deterministic: bool = True,
    cuda: bool = True,
    track: bool = True,
    wandb_project_name: str = "cleanRL",
    wandb_entity: str = None,
    env_id: str = "Breakout-v5",
    total_timesteps: int = 10000000,
    optimizer: str = "adam",
    learning_rate: float = 2.5e-4,
    weight_decay: float = 0.0,
    num_envs: int = 8,
    num_steps: int = 128,
    anneal_lr: bool = True,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
    num_minibatches: int = 4,
    update_epochs: int = 4,
    norm_adv: bool = True,
    clip_coef: float = 0.1,
    clip_vloss: bool = True,
    ent_coef: float = 0.01,
    vf_coef: float = 0.5,
    max_grad_norm: float = 0.5,
    target_kl: float = None,
):
    """Train a PPO agent on one EnvPool Atari game and return its learning curve.

    Each iteration does two things:

    * Collects ``num_steps`` transitions from each of ``num_envs`` parallel
      environments and computes GAE advantages.
    * Runs ``update_epochs`` epochs of clipped-surrogate PPO updates over
      ``num_minibatches`` shuffled minibatches.

    ``optimizer`` selects AdamW (``"adam"``) or ``kron_torch.Kron``
    (``"kron"``). Scalars are written to TensorBoard under
    ``runs/<run_name>``. When ``track`` is true, a Weights & Biases run is
    also initialised with ``sync_tensorboard=True``.

    Returns:
        A list of ``(global_step, mean_return)`` tuples. One is appended each
        time an environment reports ``done`` with ``info["lives"] == 0``.
        ``mean_return`` is the mean of the most recent (up to 20) such
        episodic returns.

    Raises:
        ValueError: if ``optimizer`` is not ``"adam"`` or ``"kron"``.
    """
    # Transitions gathered per iteration (all envs) and per gradient step.
    batch_size: int = int(num_envs * num_steps)
    minibatch_size: int = int(batch_size // num_minibatches)
    # Floor division: timesteps beyond the last whole batch are not run.
    num_iterations: int = total_timesteps // batch_size
    run_name = f"{env_id}__{exp_name}__{seed}__{int(time.time())}"
    if track:
        import wandb

        wandb.init(
            project=wandb_project_name,
            entity=wandb_entity,
            sync_tensorboard=True,
            config={
                "exp_name": exp_name,
                "seed": seed,
                "torch_deterministic": torch_deterministic,
                "cuda": cuda,
                "track": track,
                "wandb_project_name": wandb_project_name,
                "wandb_entity": wandb_entity,
                "env_id": env_id,
                "total_timesteps": total_timesteps,
                "optimizer": optimizer,
                "learning_rate": learning_rate,
                "weight_decay": weight_decay,
                "num_envs": num_envs,
                "num_steps": num_steps,
                "anneal_lr": anneal_lr,
                "gamma": gamma,
                "gae_lambda": gae_lambda,
                "num_minibatches": num_minibatches,
                "update_epochs": update_epochs,
                "norm_adv": norm_adv,
                "clip_coef": clip_coef,
                "clip_vloss": clip_vloss,
                "ent_coef": ent_coef,
                "vf_coef": vf_coef,
                "max_grad_norm": max_grad_norm,
                "target_kl": target_kl,
            },
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    # The table is built from locals(), so it records every local bound so far,
    # not just the hyperparameters. That includes the arguments, the derived
    # sizes, run_name, `writer` itself and, when track=True, the `wandb` module.
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s"
        % ("\n".join([f"|{key}|{value}|" for key, value in locals().items()])),
    )

    # TRY NOT TO MODIFY: seeding
    def set_seed_everywhere(seed):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    set_seed_everywhere(seed)

    # benchmark=False stops cuDNN from auto-selecting (possibly different)
    # kernels per run; together with deterministic=True this favours
    # reproducibility over speed.
    torch.backends.cudnn.deterministic = torch_deterministic
    torch.backends.cudnn.benchmark = False

    device = torch.device("cuda" if torch.cuda.is_available() and cuda else "cpu")

    # env setup
    # Both flags are enabled here:
    #   episodic_life=True -- presumably each lost life is reported as `done`
    #     (see the lives == 0 check in the rollout loop);
    #   reward_clip=True -- presumably clips the training reward.
    # Full-game returns are read from the wrapper's info["r"].
    envs = envpool.make(
        env_id,
        env_type="gym",
        num_envs=num_envs,
        episodic_life=True,
        reward_clip=True,
        seed=seed,
    )
    # Alias the pool's spaces under the single_* names used by Agent and the
    # storage code below.
    envs.num_envs = num_envs
    envs.single_action_space = envs.action_space
    envs.single_observation_space = envs.observation_space
    envs = RecordEpisodeStatistics(envs)
    assert isinstance(
        envs.action_space, gym.spaces.Discrete
    ), "only discrete action space is supported"

    agent = Agent(envs).to(device)
    agent.train()

    # Print model summary
    print_model_summary(agent)

    # NOTE(review): torch.compile compiles the module's forward(). Agent
    # defines no forward(), and the loop below only calls
    # get_action_and_value()/get_value(). The compiled wrapper appears to
    # delegate those to the original (uncompiled) module, so confirm that
    # compilation actually takes effect.
    agent = torch.compile(agent)

    ########## optimizers ##########
    # The `optimizer` argument (a name string) is rebound to the optimizer
    # instance here. Both choices get a single parameter group, which is what
    # the LR annealing below (param_groups[0]) relies on.
    if optimizer == "adam":
        optimizer = optim.AdamW(
            agent.parameters(), lr=learning_rate, weight_decay=weight_decay, eps=1e-5
        )  # cleanrl baseline
    elif optimizer == "kron":
        optimizer = Kron(
            agent.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
        )
    else:
        raise ValueError(f"Invalid optimizer: {optimizer}")
    ################################

    # ALGO Logic: Storage setup
    # Rollout buffers, indexed [step, env, ...].
    obs = torch.zeros((num_steps, num_envs) + envs.single_observation_space.shape).to(
        device
    )
    actions = torch.zeros((num_steps, num_envs) + envs.single_action_space.shape).to(
        device
    )
    logprobs = torch.zeros((num_steps, num_envs)).to(device)
    rewards = torch.zeros((num_steps, num_envs)).to(device)
    dones = torch.zeros((num_steps, num_envs)).to(device)
    values = torch.zeros((num_steps, num_envs)).to(device)

    # Print shapes of batch tensors
    print("Batch tensor shapes:")
    print(f"obs: {obs.shape}")
    print(f"actions: {actions.shape}")
    print(f"logprobs: {logprobs.shape}")
    print(f"rewards: {rewards.shape}")
    print(f"dones: {dones.shape}")
    print(f"values: {values.shape}")

    # avg_returns: the last 20 full-game returns (sliding window).
    # collected_rewards: (global_step, window mean) curve returned to the caller.
    avg_returns = deque(maxlen=20)
    collected_rewards = []

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    # NOTE(review): assumes reset() returns observations only (older gym API).
    # This depends on the installed gym/envpool versions -- TODO confirm.
    next_obs = torch.Tensor(envs.reset()).to(device)
    next_done = torch.zeros(num_envs).to(device)

    for iteration in range(1, num_iterations + 1):
        # Annealing the rate if instructed to do so.
        # Linear decay from learning_rate toward 0 across iterations.
        if anneal_lr:
            frac = 1.0 - (iteration - 1) / num_iterations
            lrnow = frac * learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        # Agent has no dropout/batch-norm layers, so this eval()/train()
        # toggle does not change its outputs.
        agent.eval()
        for step in range(0, num_steps):
            global_step += num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
            # Assumes the 4-tuple (obs, reward, done, info) step API.
            next_obs, reward, next_done, info = envs.step(action.cpu().numpy())
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = (
                torch.from_numpy(next_obs).to(device),
                torch.from_numpy(next_done).to(device).float(),
            )

            # With episodic_life, `done` may only mean a lost life. Only
            # lives == 0 is treated as the end of a full game and logged.
            # NOTE(review): iterating a (possibly CUDA) tensor element by
            # element syncs with the device per element; consider
            # next_done.cpu() if this shows up in profiles.
            for idx, d in enumerate(next_done):
                if d and info["lives"][idx] == 0:
                    if iteration % 25 == 0:
                        print(
                            f"global_step={global_step}, episodic_return={info['r'][idx]}"
                        )
                    avg_returns.append(info["r"][idx])
                    writer.add_scalar(
                        "charts/avg_episodic_return",
                        np.average(avg_returns),
                        global_step,
                    )
                    collected_rewards.append((global_step, np.average(avg_returns)))
                    writer.add_scalar(
                        "charts/episodic_return", info["r"][idx], global_step
                    )
                    writer.add_scalar(
                        "charts/episodic_length", info["l"][idx], global_step
                    )

        agent.train()

        # bootstrap value if not done
        # Generalised Advantage Estimation, computed backwards in time.
        # `nextnonterminal` zeroes the bootstrap across episode boundaries.
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(num_steps)):
                if t == num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = (
                    delta + gamma * gae_lambda * nextnonterminal * lastgaelam
                )
            # Value-function targets.
            returns = advantages + values

        # flatten the batch
        # Merge the (step, env) axes into a single batch axis.
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Optimizing the policy and value network
        b_inds = np.arange(batch_size)
        clipfracs = []
        for epoch in range(update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, batch_size, minibatch_size):
                end = start + minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(
                    b_obs[mb_inds], b_actions.long()[mb_inds]
                )
                # Probability ratio pi_new / pi_old, computed in log space.
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    # Fraction of samples whose ratio falls outside the clip range.
                    clipfracs += [
                        ((ratio - 1.0).abs() > clip_coef).float().mean().item()
                    ]

                mb_advantages = b_advantages[mb_inds]
                if norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (
                        mb_advantages.std() + 1e-8
                    )

                # Policy loss
                # Clipped surrogate objective (negated for minimisation).
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(
                    ratio, 1 - clip_coef, 1 + clip_coef
                )
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                # When clip_vloss is set, the value update is clipped with the
                # same clip_coef as the policy, around the old value estimates.
                newvalue = newvalue.view(-1)
                if clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds], -clip_coef, clip_coef
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)
                optimizer.step()

            # Early stop is checked once per epoch, using the KL of the
            # epoch's last minibatch only.
            if target_kl is not None and approx_kl > target_kl:
                break

        # Explained variance of the rollout-time value predictions vs returns
        # (NaN when the returns have zero variance).
        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        # Loss/KL scalars below are from the last minibatch processed, not
        # averages over the update.
        writer.add_scalar(
            "charts/learning_rate", optimizer.param_groups[0]["lr"], global_step
        )
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)
        writer.add_scalar(
            "charts/SPS", int(global_step / (time.time() - start_time)), global_step
        )

    # NOTE(review): cleanup is not in a `finally`. If training raises, the env
    # pool, the TensorBoard writer and the wandb run are left open -- consider
    # wrapping the loop in try/finally.
    envs.close()
    writer.close()
    if track:
        wandb.finish()

    return collected_rewards


if __name__ == "__main__":
    # Games swept by the benchmark below.
    games = [
        "Breakout-v5",
        "MsPacman-v5",
        "Qbert-v5",
        "KungFuMaster-v5",
        "Amidar-v5",
        "BankHeist-v5",
        "Kangaroo-v5",
        "Klax-v5",
    ]
    # Alternative suite; defined but not currently run.
    shooter_games = [
        "YarsRevenge-v5",
        "Hero-v5",
        "Krull-v5",
        "Tutankham-v5",
        "Assault-v5",
        "Asteroids-v5",
        "BeamRider-v5",
        "Berzerk-v5",
        "DemonAttack-v5",
        "Phoenix-v5",
        "SpaceInvaders-v5",
        "StarGunner-v5"
    ]
    optimizers = ["adam", "kron"]

    # Settings shared by every run; deviations from CleanRL defaults are noted.
    shared_settings = dict(
        learning_rate=0.00025,
        weight_decay=0.0001,  # cleanrl default was 0.0
        total_timesteps=10000000,
        anneal_lr=True,
        num_envs=16,  # cleanrl default was 8
    )

    # Every (game, optimizer) pair, game-major, in the same order as before.
    for game, optimizer in [(g, o) for g in games for o in optimizers]:
        print(f"Running with optimizer: {optimizer}")
        main(
            exp_name=f"ppo_atari_{optimizer}",
            env_id=game,
            optimizer=optimizer,
            **shared_settings,
        )