In [None]:
! pip install swig
! pip install box2d-py
! pip install gym[box2d]

In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import os
import glob
import io
import base64
import pickle
import matplotlib.pyplot as plt
from IPython.display import HTML, display
from google.colab import drive

# Mount Google Drive at /content/drive (Colab helper).
drive.mount('/content/drive')

# Folder on the mounted Drive that receives checkpoints, final weights,
# observation statistics and the training plot written by this cell.
DRIVE_FOLDER = "/content/drive/MyDrive/BipedalWalker_Project"
os.makedirs(DRIVE_FOLDER, exist_ok=True)
print(f"Saving all models to: {DRIVE_FOLDER}")

def show_video(video_folder="videos"):
    """Embed the most recently created ``.mp4`` from ``video_folder`` inline.

    The newest file (by ``os.path.getctime``) is base64-encoded and displayed
    through an HTML ``<video>`` tag. A notice is printed when the folder
    contains no ``.mp4`` files.

    Args:
        video_folder: Directory searched (non-recursively) for ``*.mp4``.
    """
    mp4list = glob.glob(f'{video_folder}/*.mp4')
    if not mp4list:
        print("No video found yet.")
        return

    mp4 = max(mp4list, key=os.path.getctime)
    # Read-only access is sufficient, and the context manager releases the
    # file handle as soon as the bytes have been read.
    with open(mp4, 'rb') as f:
        encoded = base64.b64encode(f.read())
    display(HTML(data='''<video alt="test" autoplay
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))

# Model architecture
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and return it.

    The weight matrix receives an orthogonal initialisation scaled by ``std``
    and every bias entry is set to ``bias_const``. Returning the layer allows
    the call to be nested directly inside ``nn.Sequential(...)``.
    """
    weight, bias = layer.weight, layer.bias
    torch.nn.init.orthogonal_(weight, gain=std)
    torch.nn.init.constant_(bias, val=bias_const)
    return layer

class ActorCritic(nn.Module):
    """Separate actor and critic MLPs with a state-independent Gaussian policy.

    Both networks are ``obs_dim -> 512 -> 512 -> out`` with tanh activations.
    The critic outputs one value per observation; the actor outputs the action
    mean, while the log standard deviation is a learned parameter shared by
    all observations.
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        hidden = 512

        # Critic is built before the actor so layers are created and
        # initialised in a fixed order.
        critic_layers = [
            layer_init(nn.Linear(obs_dim, hidden)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden, hidden)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden, 1), std=1.0),
        ]
        self.critic = nn.Sequential(*critic_layers)

        # The small gain on the output layer keeps initial action means near 0.
        actor_layers = [
            layer_init(nn.Linear(obs_dim, hidden)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden, hidden)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden, act_dim), std=0.01),
        ]
        self.actor_mean = nn.Sequential(*actor_layers)

        self.actor_logstd = nn.Parameter(torch.full((1, act_dim), -0.5))

    def get_value(self, x):
        """Return the critic output for ``x`` (trailing dimension of size 1 kept)."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Evaluate the policy and critic on a batch of observations.

        When ``action`` is ``None`` a fresh action is sampled; otherwise the
        supplied action is scored under the current policy.

        Returns:
            ``(action, log_prob, entropy, value)`` where ``log_prob`` and
            ``entropy`` are summed over dimension 1 and ``value`` has its
            dimension 1 squeezed away.
        """
        mu = self.actor_mean(x)
        sigma = self.actor_logstd.expand_as(mu).exp()
        policy = Normal(mu, sigma)

        chosen = policy.sample() if action is None else action

        log_prob = policy.log_prob(chosen).sum(1)
        entropy = policy.entropy().sum(1)
        state_value = self.critic(x).squeeze(1)
        return chosen, log_prob, entropy, state_value

class RolloutBuffer:
    """Fixed-size on-policy storage with GAE(lambda) advantage estimation.

    Transitions are appended with ``store``; ``finish_path`` must be called at
    the end of every trajectory segment to fill in advantages and returns for
    that segment; ``get`` hands the full buffer back as torch tensors and
    rewinds it for the next rollout.
    """

    def __init__(self, size, obs_dim, act_dim, device):
        self.size = size
        self.device = device

        def per_step():
            # One float32 scalar slot per stored transition.
            return np.zeros(size, dtype=np.float32)

        self.obs = np.zeros((size, obs_dim), dtype=np.float32)
        self.actions = np.zeros((size, act_dim), dtype=np.float32)
        self.log_probs = per_step()
        self.rewards = per_step()
        self.dones = per_step()
        self.values = per_step()
        self.advantages = per_step()
        self.returns = per_step()

        # ``ptr`` is the next free slot; ``path_start_idx`` marks where the
        # trajectory segment currently being collected began.
        self.ptr = 0
        self.path_start_idx = 0

    def store(self, obs, action, log_prob, reward, done, value):
        """Append one transition at the current write position."""
        i = self.ptr
        assert i < self.size
        self.obs[i] = obs
        self.actions[i] = action
        self.log_probs[i] = log_prob
        self.rewards[i] = reward
        self.dones[i] = done
        self.values[i] = value
        self.ptr = i + 1

    def finish_path(self, last_value, gamma, lam):
        """Compute GAE advantages and returns for the segment just collected.

        Args:
            last_value: Value estimate used to bootstrap past the final stored
                step of the segment.
            gamma: Discount factor.
            lam: GAE lambda.
        """
        seg = slice(self.path_start_idx, self.ptr)
        seg_rewards = self.rewards[seg]
        seg_dones = self.dones[seg]
        # One extra entry so ``values[t + 1]`` is defined for the last step.
        values = np.append(self.values[seg], last_value)

        adv = np.zeros_like(seg_rewards)
        running = 0.0
        for t in range(len(seg_rewards) - 1, -1, -1):
            not_done = 1 - seg_dones[t]
            td_error = seg_rewards[t] + gamma * values[t + 1] * not_done - values[t]
            running = td_error + gamma * lam * not_done * running
            adv[t] = running

        self.advantages[seg] = adv
        self.returns[seg] = adv + self.values[seg]
        self.path_start_idx = self.ptr

    def get(self):
        """Return the full rollout as tensors and reset the write position.

        Advantages are standardised (zero mean, unit std) over the whole
        buffer before being returned. The buffer must be completely filled.
        """
        assert self.ptr == self.size
        self.ptr = 0
        self.path_start_idx = 0

        raw_adv = self.advantages
        norm_adv = (raw_adv - raw_adv.mean()) / (raw_adv.std() + 1e-8)

        def as_tensor(arr):
            return torch.tensor(arr, dtype=torch.float32, device=self.device)

        return {
            "obs": as_tensor(self.obs),
            "actions": as_tensor(self.actions),
            "log_probs": as_tensor(self.log_probs),
            "advantages": as_tensor(norm_adv),
            "returns": as_tensor(self.returns),
            "values": as_tensor(self.values),
        }

class PPOAgent:
    """PPO trainer for a single Gymnasium continuous-control environment.

    The environment is wrapped with episode statistics, action clipping,
    running observation normalisation and clipping of the normalised
    observation to [-10, 10]. Rollouts of ``rollout_steps`` transitions are
    collected into a ``RolloutBuffer`` and each is followed by
    ``train_epochs`` passes of clipped-surrogate updates.

    Args:
        env_id: Gymnasium environment id.
        total_timesteps: Total number of environment steps to collect.
        rollout_steps: Transitions gathered before each policy update.
        gamma: Discount factor.
        lam: GAE lambda.
        clip_eps: PPO probability-ratio clipping range.
        learning_rate: Initial Adam learning rate (decayed linearly to 0
            over the planned number of updates).
        train_epochs: Passes over each rollout per update.
        minibatch_size: Number of samples per gradient step.
        vf_coef: Weight of the value loss in the total loss.
        ent_coef: Weight of the entropy bonus in the total loss.
        render_freq: Record and show a video every ``render_freq`` updates.
        device: Torch device for the networks and batch tensors.
    """

    def __init__(
        self,
        env_id="BipedalWalker-v3",
        total_timesteps=2_000_000,
        rollout_steps=4096,
        gamma=0.99,
        lam=0.95,
        clip_eps=0.2,
        learning_rate=2.5e-4,
        train_epochs=10,
        minibatch_size=512,
        vf_coef=0.5,
        ent_coef=0.00,
        render_freq=25,
        device="cpu",
    ):
        self.env_id = env_id
        self.total_timesteps = total_timesteps
        self.rollout_steps = rollout_steps
        self.gamma = gamma
        self.lam = lam
        self.clip_eps = clip_eps
        self.learning_rate = learning_rate
        self.train_epochs = train_epochs
        self.minibatch_size = minibatch_size
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef
        self.render_freq = render_freq
        self.device = device

        # Per-episode return and the global step at which it finished.
        self.ep_returns = []
        self.ep_timesteps = []

        self.env = gym.make(env_id)
        self.env = gym.wrappers.RecordEpisodeStatistics(self.env)
        self.env = gym.wrappers.ClipAction(self.env)

        self.env = gym.wrappers.NormalizeObservation(self.env)
        self.env = gym.wrappers.TransformObservation(self.env, lambda obs: np.clip(obs, -10, 10), self.env.observation_space)

        self.obs_dim = self.env.observation_space.shape[0]
        self.act_dim = self.env.action_space.shape[0]

        self.ac = ActorCritic(self.obs_dim, self.act_dim).to(device)
        self.optimizer = optim.Adam(self.ac.parameters(), lr=self.learning_rate, eps=1e-5)

        self.num_updates = total_timesteps // rollout_steps
        self.lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            self.optimizer, start_factor=1.0, end_factor=0.0, total_iters=self.num_updates
        )

        self.buffer = RolloutBuffer(self.rollout_steps, self.obs_dim, self.act_dim, self.device)

    def visualize_agent(self, update_count):
        """Record one deterministic episode to ``videos/`` and display it.

        The evaluation environment reuses the training observation statistics
        but keeps them frozen, so evaluation steps never alter the statistics
        the policy is being trained with.

        Args:
            update_count: Current update index, used in log text and the
                video file prefix.
        """
        print(f"\n--- Visualizing Agent at Update {update_count} ---")
        vis_env = gym.make(self.env_id, render_mode="rgb_array")
        vis_env = gym.wrappers.RecordVideo(
            vis_env,
            video_folder="videos",
            name_prefix=f"update_{update_count}",
            disable_logger=True
        )
        vis_env = gym.wrappers.ClipAction(vis_env)
        vis_norm = gym.wrappers.NormalizeObservation(vis_env)
        try:
            vis_norm.obs_rms = self.env.get_wrapper_attr('obs_rms')
        except AttributeError:
            # Best effort: continue with fresh statistics, but say so.
            print("Warning: training obs_rms not found; visualizing with fresh statistics.")
        # The statistics object is shared with the training env, so it must
        # not be updated from evaluation observations.
        vis_norm.update_running_mean = False
        vis_env = gym.wrappers.TransformObservation(vis_norm, lambda obs: np.clip(obs, -10, 10), vis_env.observation_space)

        ret = 0
        try:
            obs, _ = vis_env.reset()
            while True:
                obs_tensor = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
                with torch.no_grad():
                    # Deterministic rollout: act with the policy mean.
                    action = self.ac.actor_mean(obs_tensor).squeeze(0).cpu().numpy()
                obs, reward, terminated, truncated, _ = vis_env.step(action)
                ret += reward
                if terminated or truncated:
                    break
        finally:
            # Closing the env is what finalises the video file.
            vis_env.close()
        print(f"Visualization finished with return: {ret:.2f}")
        show_video("videos")

    def train(self):
        """Run the full collect/update loop until ``total_timesteps`` is reached.

        A rollout that ends before the buffer is full (only possible on the
        final iteration) is discarded without an update.
        """
        obs, _ = self.env.reset()
        timesteps_collected = 0
        update_count = 0

        while timesteps_collected < self.total_timesteps:
            done = False
            for _ in range(self.rollout_steps):
                obs_tensor = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
                with torch.no_grad():
                    action, log_prob, _, value = self.ac.get_action_and_value(obs_tensor)

                action = action.cpu().numpy().squeeze(0)
                log_prob = log_prob.item()
                value = value.item()

                next_obs, reward, terminated, truncated, infos = self.env.step(action)
                done = terminated or truncated

                # Store only true termination as the done flag. finish_path
                # masks the bootstrap value with (1 - done), so flagging a
                # time-limit truncation as done would zero out the value
                # estimate computed for the truncated state below.
                self.buffer.store(obs, action, log_prob, reward, terminated, value)
                timesteps_collected += 1
                obs = next_obs

                if "episode" in infos:
                    ret = infos['episode']['r']
                    self.ep_returns.append(ret)
                    self.ep_timesteps.append(timesteps_collected)
                    print(f"Update {update_count} | Steps: {timesteps_collected} | Return: {ret:.2f}")

                if done:
                    if truncated and not terminated:
                        # Episode cut by the time limit: bootstrap from the
                        # value of the state it was cut at.
                        last_val_obs = torch.tensor(next_obs, dtype=torch.float32, device=self.device).unsqueeze(0)
                        with torch.no_grad():
                            last_value = self.ac.get_value(last_val_obs).item()
                    else:
                        last_value = 0
                    self.buffer.finish_path(last_value=last_value, gamma=self.gamma, lam=self.lam)
                    obs, _ = self.env.reset()

                if timesteps_collected >= self.total_timesteps:
                    break

            if not done:
                # Rollout ended mid-episode: bootstrap from the current state.
                obs_tensor = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
                with torch.no_grad():
                    last_value = self.ac.get_value(obs_tensor).item()
                self.buffer.finish_path(last_value=last_value, gamma=self.gamma, lam=self.lam)

            if self.buffer.ptr == self.buffer.size:
                data = self.buffer.get()
                self._update(data)
                self.lr_scheduler.step()
                update_count += 1

                if update_count % 50 == 0:
                    self.save(os.path.join(DRIVE_FOLDER, "ppo_bipedal_checkpoint.pt"))
                    print(f"Checkpoint saved to drive at update {update_count}")

                # Render only right after an update so a trailing partial
                # rollout cannot re-render the same update.
                if update_count % self.render_freq == 0:
                    self.visualize_agent(update_count)

        self.env.close()

    def _update(self, data):
        """Run ``train_epochs`` passes of minibatch PPO updates on one rollout.

        Args:
            data: Dict of tensors as returned by ``RolloutBuffer.get``.
        """
        obs = data["obs"]
        actions = data["actions"]
        old_log_probs = data["log_probs"]
        advantages = data["advantages"]
        returns = data["returns"]

        batch_size = len(obs)
        inds = np.arange(batch_size)

        for _ in range(self.train_epochs):
            np.random.shuffle(inds)
            for start in range(0, batch_size, self.minibatch_size):
                end = start + self.minibatch_size
                mb_inds = inds[start:end]
                mb_obs = obs[mb_inds]
                mb_actions = actions[mb_inds]
                mb_old_log_probs = old_log_probs[mb_inds]
                mb_adv = advantages[mb_inds]
                mb_returns = returns[mb_inds]

                _, new_log_probs, entropy, values = self.ac.get_action_and_value(mb_obs, mb_actions)
                ratio = torch.exp(new_log_probs - mb_old_log_probs)
                surr1 = ratio * mb_adv
                surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * mb_adv

                actor_loss = -torch.min(surr1, surr2).mean()
                critic_loss = 0.5 * ((values - mb_returns) ** 2).mean()
                entropy_loss = entropy.mean()
                loss = actor_loss + self.vf_coef * critic_loss - self.ent_coef * entropy_loss

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.ac.parameters(), max_norm=0.5)
                self.optimizer.step()

    def plot_results(self, window_size=50):
        """Plot per-episode returns and a running average, and save to Drive.

        The running average is drawn only once at least ``window_size``
        episodes have been recorded.

        Args:
            window_size: Number of episodes in the running-average window.
        """
        if not self.ep_returns:
            print("Can not plot, no rewards")
            return

        returns = np.array(self.ep_returns)

        plt.figure(figsize=(10, 6))
        plt.plot(self.ep_timesteps, returns, alpha=0.3, color='blue', label='Episode Return')

        # With fewer episodes than the window, a 'valid' convolution no longer
        # lines up with the timestep axis, so skip the average entirely.
        if 1 <= window_size <= len(returns):
            running_avg = np.convolve(returns, np.ones(window_size)/window_size, mode='valid')
            avg_timesteps = self.ep_timesteps[window_size-1:]
            plt.plot(avg_timesteps, running_avg, color='red', linewidth=2, label=f'Running Avg (last {window_size} eps)')

        plt.title(f"PPO Training Progress: {self.env_id}")
        plt.xlabel("Total Timesteps")
        plt.ylabel("Return")
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)

        # Save plot to drive
        plot_path = os.path.join(DRIVE_FOLDER, "training_plot.png")
        plt.savefig(plot_path)
        plt.show()
        print(f"Plot saved to: {plot_path}")

    def save(self, path):
        """Write the actor-critic ``state_dict`` to ``path``."""
        torch.save(self.ac.state_dict(), path)

if __name__ == "__main__":
    # Use the GPU whenever torch can see one.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    run_config = dict(
        env_id="BipedalWalker-v3",
        total_timesteps=2_000_000,
        rollout_steps=4096,
        render_freq=25,
        device=device,
    )
    agent = PPOAgent(**run_config)
    agent.train()

    # Learning curve first, then the final artefacts.
    agent.plot_results(window_size=50)

    final_model_path = os.path.join(DRIVE_FOLDER, "ppo_bipedal_final.pt")
    final_stats_path = os.path.join(DRIVE_FOLDER, "obs_stats.pkl")

    agent.save(final_model_path)

    # The observation statistics are stored next to the weights so the
    # normalisation can be restored when the model is reused.
    with open(final_stats_path, "wb") as stats_file:
        pickle.dump(agent.env.get_wrapper_attr('obs_rms'), stats_file)

    print(f"Training Complete. Files saved to Drive:\n1. {final_model_path}\n2. {final_stats_path}")

In [None]:
import os, glob, io, base64, pickle, random
from collections import deque

import numpy as np
import gymnasium as gym

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal

import matplotlib.pyplot as plt
from IPython.display import HTML, display
from google.colab import drive


# Mount Google Drive only when /content/drive is not already present.
if not os.path.exists("/content/drive"):
    drive.mount("/content/drive")

# Project folder on the mounted Drive; created if missing.
DRIVE_FOLDER = "/content/drive/MyDrive/BipedalWalker_Project"
os.makedirs(DRIVE_FOLDER, exist_ok=True)

# Artefacts written by the non-hardcore training cell; the hardcore trainer
# loads them as a warm start when they exist.
BASE_NORMAL_MODEL = os.path.join(DRIVE_FOLDER, "ppo_bipedal_final.pt")
BASE_NORMAL_STATS = os.path.join(DRIVE_FOLDER, "obs_stats.pkl")

# RUN_TAG is embedded in every output filename below.
RUN_TAG = "v11"
CKPT_PATH  = os.path.join(DRIVE_FOLDER, f"ppo_hardcore_{RUN_TAG}_checkpoint.pt")
BEST_PATH  = os.path.join(DRIVE_FOLDER, f"ppo_hardcore_{RUN_TAG}_best.pt")
FINAL_PATH = os.path.join(DRIVE_FOLDER, f"ppo_hardcore_{RUN_TAG}_final.pt")
PLOT_PATH  = os.path.join(DRIVE_FOLDER, f"training_plot_{RUN_TAG}.png")

print("Drive folder:", DRIVE_FOLDER)
print("Checkpoint:", CKPT_PATH)
print("Best:", BEST_PATH)
print("Final:", FINAL_PATH)
print("Plot:", PLOT_PATH)


def show_video(folder="videos"):
    """Embed the most recently created ``.mp4`` from ``folder`` inline.

    The newest file (by ``os.path.getctime``) is base64-encoded and displayed
    through an HTML ``<video>`` tag. A notice is printed when the folder
    contains no ``.mp4`` files.

    Args:
        folder: Directory searched (non-recursively) for ``*.mp4``.
    """
    mp4list = glob.glob(f"{folder}/*.mp4")
    if not mp4list:
        print("No video found.")
        return
    mp4 = max(mp4list, key=os.path.getctime)
    # Read-only access is sufficient, and the context manager releases the
    # file handle as soon as the bytes have been read.
    with open(mp4, "rb") as f:
        encoded = base64.b64encode(f.read())
    display(HTML(data=f"""
    <video autoplay loop controls style="height: 420px;">
      <source src="data:video/mp4;base64,{encoded.decode('ascii')}" type="video/mp4" />
    </video>
    """))


def make_one_env_video(seed=None, video_folder="videos", name_prefix="eval"):
    """Build a video-recording BipedalWalkerHardcore-v3 environment.

    Wrapper order, innermost first: episode statistics, running observation
    normalisation, clipping of the normalised observation to [-10, 10], and a
    video recorder that records every episode into ``video_folder``.

    Args:
        seed: When given, the environment is reset once with this seed.
        video_folder: Output directory for the recorded videos.
        name_prefix: Filename prefix for the recorded videos.
    """
    def clip_obs(o):
        return np.clip(o, -10, 10)

    base = gym.make("BipedalWalkerHardcore-v3", render_mode="rgb_array")
    with_stats = gym.wrappers.RecordEpisodeStatistics(base)
    normalized = gym.wrappers.NormalizeObservation(with_stats)
    clipped = gym.wrappers.TransformObservation(normalized, clip_obs, normalized.observation_space)

    env = gym.wrappers.RecordVideo(
        clipped,
        video_folder=video_folder,
        episode_trigger=lambda ep: True,
        name_prefix=name_prefix,
        disable_logger=True,
    )

    if seed is None:
        return env
    env.reset(seed=int(seed))
    return env

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and return it.

    The weight matrix receives an orthogonal initialisation scaled by ``std``
    and every bias entry is set to ``bias_const``. Returning the layer allows
    the call to be nested directly inside ``nn.Sequential(...)``.
    """
    weight, bias = layer.weight, layer.bias
    nn.init.orthogonal_(weight, gain=std)
    nn.init.constant_(bias, val=bias_const)
    return layer


class ActorCritic(nn.Module):
    """Actor-critic with a tanh-squashed Gaussian policy.

    Both networks are ``obs_dim -> 512 -> 512 -> out`` with tanh activations.
    The actor produces the mean of a Gaussian over pre-squash actions ``u``;
    the emitted action is ``tanh(u)``. The log standard deviation is a learned
    parameter shared by all observations.
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        hidden = 512

        # Critic is built before the actor so layers are created and
        # initialised in a fixed order.
        critic_layers = [
            layer_init(nn.Linear(obs_dim, hidden)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden, hidden)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden, 1), std=1.0),
        ]
        self.critic = nn.Sequential(*critic_layers)

        actor_layers = [
            layer_init(nn.Linear(obs_dim, hidden)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden, hidden)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden, act_dim), std=0.01),
        ]
        self.actor_mean = nn.Sequential(*actor_layers)

        self.actor_logstd = nn.Parameter(torch.full((1, act_dim), -0.7))

    def get_value(self, x):
        """Return the critic output for ``x`` with dimension 1 squeezed away."""
        return self.critic(x).squeeze(1)

    def _dist(self, x):
        """Return ``(Normal(mean, std), mean, logstd)`` for observations ``x``."""
        mean = self.actor_mean(x)
        logstd = self.actor_logstd.expand_as(mean)
        return Normal(mean, logstd.exp()), mean, logstd

    @staticmethod
    def _squash_action(u):
        """Map a pre-squash action into (-1, 1) with tanh."""
        return u.tanh()

    @staticmethod
    def _squash_logprob(dist, u, eps=1e-6):
        """Log-probability of ``tanh(u)`` under ``dist``.

        Subtracts the tanh change-of-variables term
        ``sum(log(1 - tanh(u)^2 + eps))`` from the Gaussian log-probability of
        ``u``; both are summed over dimension 1.
        """
        squashed = torch.tanh(u)
        jacobian = torch.log(1.0 - squashed.pow(2) + eps).sum(dim=1)
        gaussian_logp = dist.log_prob(u).sum(dim=1)
        return gaussian_logp - jacobian

    def get_action_and_value(self, x, action=None, deterministic=False):
        """Sample (or score) an action and evaluate the critic.

        Args:
            x: Batch of observations.
            action: Squashed action to score. When ``None`` an action is
                produced from the policy instead.
            deterministic: With ``action=None``, use the mean instead of
                sampling.

        Returns:
            ``(action, log_prob, entropy, value, mean_action)``. ``entropy`` is
            that of the pre-squash Gaussian and ``mean_action`` is
            ``tanh(mean)``.
        """
        dist, mean, _ = self._dist(x)

        if action is not None:
            a = action
            # Keep the inverse tanh finite when the action sits at +/-1.
            bounded = torch.clamp(a, -0.999, 0.999)
            u = 0.5 * (torch.log1p(bounded) - torch.log1p(-bounded))  # atanh
        else:
            u = mean if deterministic else dist.rsample()
            a = self._squash_action(u)
        logp = self._squash_logprob(dist, u)

        ent = dist.entropy().sum(dim=1)
        val = self.get_value(x)
        return a, logp, ent, val, torch.tanh(mean)


def make_one_env(seed=None):
    """Build one BipedalWalkerHardcore-v3 environment for training/evaluation.

    Wrapper order, innermost first: episode statistics, running observation
    normalisation, and clipping of the normalised observation to [-10, 10].

    Args:
        seed: When given, the environment is reset once with this seed.
    """
    env = gym.wrappers.NormalizeObservation(
        gym.wrappers.RecordEpisodeStatistics(gym.make("BipedalWalkerHardcore-v3"))
    )
    env = gym.wrappers.TransformObservation(
        env,
        lambda o: np.clip(o, -10, 10),
        env.observation_space,
    )
    if seed is None:
        return env
    env.reset(seed=int(seed))
    return env


def make_vec_env(n, seed_list=None):
    """Create a ``SyncVectorEnv`` of ``n`` environments from ``make_one_env``.

    Args:
        n: Number of sub-environments.
        seed_list: Optional per-environment seeds; entry ``i`` is looked up
            when sub-environment ``i`` is constructed.
    """
    def build(index):
        seed = seed_list[index] if seed_list is not None else None
        return make_one_env(seed=seed)

    # Bind the index as a default argument so each factory keeps its own value.
    factories = [lambda index=i: build(index) for i in range(n)]
    return gym.vector.SyncVectorEnv(factories)


def find_normalize_obs_wrapper(env):
    """Return the first ``NormalizeObservation`` wrapper in ``env``'s chain.

    Walks inward through ``.env`` attributes starting at ``env``. Returns
    ``None`` once an object without an ``env`` attribute is reached.
    """
    current = env
    while True:
        if not hasattr(current, "env"):
            return None
        if isinstance(current, gym.wrappers.NormalizeObservation):
            return current
        current = current.env


class SeedPool:
    """Bounded FIFO collection of integer seeds.

    Once ``maxlen`` seeds are stored, adding another drops the oldest one.
    """

    def __init__(self, maxlen=4000):
        self.maxlen = maxlen
        self.seeds = deque(maxlen=maxlen)

    def add(self, seed):
        """Append ``seed`` (coerced to ``int``); ``None`` is ignored."""
        if seed is not None:
            self.seeds.append(int(seed))

    def sample(self, k):
        """Return up to ``k`` distinct stored entries, or ``None`` when empty."""
        available = len(self.seeds)
        if available == 0:
            return None
        return random.sample(list(self.seeds), min(k, available))

    def state_dict(self):
        """Return a plain-dict snapshot suitable for checkpointing."""
        return dict(maxlen=self.maxlen, seeds=list(self.seeds))

    def load_state_dict(self, d):
        """Restore from a snapshot; a missing ``maxlen`` keeps the current one."""
        self.maxlen = int(d.get("maxlen", self.maxlen))
        self.seeds = deque(d.get("seeds", []), maxlen=self.maxlen)


class PPOHardcoreVectorV11:
    def __init__(
        self,
        device="cuda",
        num_envs=16,
        rollout_steps=2048,          # per env
        total_env_steps=40_000_000,  # across envs
        gamma=0.99,
        lam=0.95,
        clip_eps=0.15,
        lr=1.5e-4,
        train_epochs=10,
        minibatch_size=4096,
        vf_coef=0.5,
        ent_coef_start=0.02,
        ent_coef_end=0.0,
        max_grad_norm=0.5,
        target_kl=0.02,

        logstd_min=-5.0,
        logstd_max=-0.5,

        jerk_start_p75=240.0,
        jerk_coef_max=0.0010,
        jerk_ramp_updates=200,

        eval_every_updates=10,
        eval_episodes=16,

        checkpoint_every_updates=10,

        reseed_p_random=0.60,
        reseed_p_hard=0.20,
        reseed_p_success=0.20,

        hard_pool_bottom_frac=0.20,
        success_pool_threshold=300.0,
    ):
        """Set up environments, networks, buffers and optional warm start.

        Stores every hyperparameter on ``self``, creates ``num_envs``
        randomly seeded environments, builds the actor-critic and its Adam
        optimiser, and allocates rollout buffers of shape
        ``(rollout_steps, num_envs, ...)``.

        If a checkpoint exists at ``CKPT_PATH`` training state is restored
        from it. Otherwise, when present, weights from ``BASE_NORMAL_MODEL``
        and observation statistics from ``BASE_NORMAL_STATS`` are loaded as a
        warm start.

        Notes on arguments whose handling is visible here:
            ent_coef_start / ent_coef_end: Endpoints of the entropy
                coefficient schedule (see ``current_ent_coef``).
            logstd_min / logstd_max: Bounds applied by ``clamp_logstd``.
            jerk_coef_max / jerk_ramp_updates: Peak value and ramp length of
                the jerk coefficient (see ``current_jerk_coef``);
                ``jerk_ramp_updates`` is coerced to an int of at least 1.
            reseed_p_random / reseed_p_hard / reseed_p_success: Relative
                weights, normalised here so they sum to 1.

        The remaining arguments are stored unchanged for the training loop;
        their use is not fully visible in this part of the file.
        """
        self.device = device
        self.num_envs = num_envs
        self.rollout_steps = rollout_steps
        self.total_env_steps = total_env_steps

        self.gamma = gamma
        self.lam = lam
        self.clip_eps = clip_eps
        self.lr = lr
        self.train_epochs = train_epochs
        self.minibatch_size = minibatch_size
        self.vf_coef = vf_coef
        self.ent_coef_start = ent_coef_start
        self.ent_coef_end = ent_coef_end
        self.max_grad_norm = max_grad_norm
        self.target_kl = target_kl

        self.logstd_min = logstd_min
        self.logstd_max = logstd_max

        self.jerk_start_p75 = jerk_start_p75
        self.jerk_coef_max = jerk_coef_max
        # At least 1 so the ramp fraction never divides by zero.
        self.jerk_ramp_updates = max(1, int(jerk_ramp_updates))

        self.eval_every_updates = eval_every_updates
        self.eval_episodes = eval_episodes
        self.checkpoint_every_updates = checkpoint_every_updates

        # Normalise the three reseed weights into probabilities.
        s = reseed_p_random + reseed_p_hard + reseed_p_success
        self.reseed_p_random  = reseed_p_random / s
        self.reseed_p_hard    = reseed_p_hard / s
        self.reseed_p_success = reseed_p_success / s

        self.hard_pool_bottom_frac = hard_pool_bottom_frac
        self.success_pool_threshold = success_pool_threshold

        self.hard_pool = SeedPool(maxlen=4000)
        self.success_pool = SeedPool(maxlen=2000)

        # One random seed per sub-environment; kept so it can be checkpointed.
        self.env_seeds = [random.randint(0, 2_000_000_000) for _ in range(num_envs)]
        self.env = make_vec_env(num_envs, seed_list=self.env_seeds)

        self.obs_dim = self.env.single_observation_space.shape[0]
        self.act_dim = self.env.single_action_space.shape[0]

        self.ac = ActorCritic(self.obs_dim, self.act_dim).to(self.device)
        self.opt = optim.Adam(self.ac.parameters(), lr=self.lr, eps=1e-5)

        # Progress counters: environment steps taken and updates performed.
        self.global_step = 0
        self.update = 0

        self.best_eval_p75 = -1e9

        # Jerk penalty starts disabled; the enable update/step are recorded
        # elsewhere once it is switched on.
        self.jerk_enabled = False
        self.jerk_enable_update = None
        self.jerk_enable_step = None
        self.jerk_coef_current = 0.0

        # Rollout storage indexed as [time step, environment].
        T, N = rollout_steps, num_envs
        self.buf_obs  = np.zeros((T, N, self.obs_dim), dtype=np.float32)
        self.buf_act  = np.zeros((T, N, self.act_dim), dtype=np.float32)
        self.buf_logp = np.zeros((T, N), dtype=np.float32)
        self.buf_rew  = np.zeros((T, N), dtype=np.float32)
        self.buf_done = np.zeros((T, N), dtype=np.float32)
        self.buf_val  = np.zeros((T, N), dtype=np.float32)

        self.buf_vold = np.zeros((T, N), dtype=np.float32)

        # Squashed policy mean from the previous step, one row per environment.
        self.prev_mean = np.zeros((N, self.act_dim), dtype=np.float32)

        self.ep_returns = []
        self.ep_timesteps = []
        self.eval_history = []

        if os.path.exists(CKPT_PATH):
            # Resuming replaces the model, optimiser, counters, pools and envs
            # created above with the checkpointed state.
            print(f"Resuming from checkpoint: {CKPT_PATH}")
            self.load(CKPT_PATH)
        else:
            if os.path.exists(BASE_NORMAL_MODEL):
                w = torch.load(BASE_NORMAL_MODEL, map_location=self.device, weights_only=False)
                self.ac.load_state_dict(w)
                # Reset the log-std after loading the warm-start weights.
                with torch.no_grad():
                    self.ac.actor_logstd.fill_(-0.7)

            if os.path.exists(BASE_NORMAL_STATS):
                with open(BASE_NORMAL_STATS, "rb") as f:
                    base_rms = pickle.load(f)
                # Every sub-environment receives the same statistics object.
                for e in self.env.envs:
                    wno = find_normalize_obs_wrapper(e)
                    if wno is not None:
                        wno.obs_rms = base_rms

    def clamp_logstd(self):
        with torch.no_grad():
            self.ac.actor_logstd.clamp_(self.logstd_min, self.logstd_max)

    def current_ent_coef(self):
        frac = max(0.0, 1.0 - (self.global_step / self.total_env_steps))
        return self.ent_coef_end + (self.ent_coef_start - self.ent_coef_end) * frac

    def current_jerk_coef(self):
        if not self.jerk_enabled:
            return 0.0
        k = self.update - int(self.jerk_enable_update)
        k = max(0, min(k, self.jerk_ramp_updates))
        return self.jerk_coef_max * (k / self.jerk_ramp_updates)

    def reseed_envs_if_needed(self, done_mask):
        done_idxs = np.where(done_mask)[0]
        if len(done_idxs) == 0:
            return

        hard = self.hard_pool.sample(len(done_idxs))
        succ = self.success_pool.sample(len(done_idxs))

        for j, env_i in enumerate(done_idxs):
            r = random.random()
            if (hard is not None) and (r < self.reseed_p_hard):
                seed = hard[j % len(hard)]
            elif (succ is not None) and (r < self.reseed_p_hard + self.reseed_p_success):
                seed = succ[j % len(succ)]
            else:
                seed = random.randint(0, 2_000_000_000)

            self.env_seeds[env_i] = seed
            self.env.envs[env_i].reset(seed=int(seed))

    @torch.no_grad()
    def evaluate(self, episodes=16, deterministic=True):
        env = make_one_env()

        train_wno = find_normalize_obs_wrapper(self.env.envs[0])
        eval_wno  = find_normalize_obs_wrapper(env)
        if train_wno is not None and eval_wno is not None:
            eval_wno.obs_rms = train_wno.obs_rms

        seed_ret = []
        for _ in range(episodes):
            seed = random.randint(0, 2_000_000_000)
            obs, _ = env.reset(seed=int(seed))
            ep_ret = 0.0
            done = False
            while not done:
                obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
                a, _, _, _, _ = self.ac.get_action_and_value(obs_t, action=None, deterministic=deterministic)
                obs, r, term, trunc, _ = env.step(a.squeeze(0).cpu().numpy())
                ep_ret += float(r)
                done = term or trunc
            seed_ret.append((seed, ep_ret))

        env.close()

        rets = np.array([x[1] for x in seed_ret], dtype=np.float32)
        mean = float(np.mean(rets))
        med  = float(np.median(rets))
        p75  = float(np.percentile(rets, 75))
        best = float(np.max(rets))
        worst= float(np.min(rets))
        return mean, med, p75, best, worst, seed_ret

    def save(self, path):
        """Write a full training checkpoint to ``path`` with ``torch.save``.

        The checkpoint holds model and optimiser state, progress counters,
        jerk-penalty state, the observation statistics of sub-environment 0
        (or ``None`` when no normalisation wrapper is found), both seed pools,
        the current per-env seeds and the logged histories.
        """
        wno = find_normalize_obs_wrapper(self.env.envs[0])
        obs_rms = None if wno is None else wno.obs_rms

        ckpt = {"run_tag": RUN_TAG}
        # Model / optimiser / progress.
        ckpt.update(
            model=self.ac.state_dict(),
            opt=self.opt.state_dict(),
            global_step=self.global_step,
            update=self.update,
            best_eval_p75=self.best_eval_p75,
        )
        # Jerk-penalty schedule state.
        ckpt.update(
            jerk_enabled=self.jerk_enabled,
            jerk_enable_update=self.jerk_enable_update,
            jerk_enable_step=self.jerk_enable_step,
        )
        # Environment-side state.
        ckpt.update(
            obs_rms=obs_rms,
            hard_pool=self.hard_pool.state_dict(),
            success_pool=self.success_pool.state_dict(),
            env_seeds=self.env_seeds,
        )
        # Logged histories.
        ckpt.update(
            ep_returns=self.ep_returns,
            ep_timesteps=self.ep_timesteps,
            eval_history=self.eval_history,
        )
        torch.save(ckpt, path)

    def load(self, path):
        """Restore training state from a checkpoint written by ``save``.

        Model and optimiser state are required; every other entry falls back
        to a default when absent. The vectorised environment is rebuilt with
        the restored seeds, and the saved observation statistics (if any) are
        installed into every sub-environment.
        """
        ckpt = torch.load(path, map_location=self.device, weights_only=False)

        self.ac.load_state_dict(ckpt["model"])
        self.opt.load_state_dict(ckpt["opt"])

        # Progress counters.
        self.global_step = int(ckpt.get("global_step", 0))
        self.update = int(ckpt.get("update", 0))
        self.best_eval_p75 = float(ckpt.get("best_eval_p75", -1e9))

        # Jerk-penalty schedule state.
        self.jerk_enabled = bool(ckpt.get("jerk_enabled", False))
        self.jerk_enable_update = ckpt.get("jerk_enable_update", None)
        self.jerk_enable_step = ckpt.get("jerk_enable_step", None)

        # Seed pools, hard pool first.
        for pool, key in ((self.hard_pool, "hard_pool"), (self.success_pool, "success_pool")):
            saved_pool = ckpt.get(key, None)
            if saved_pool is not None:
                pool.load_state_dict(saved_pool)

        # Only accept saved seeds when they match the current env count.
        saved_seeds = ckpt.get("env_seeds", None)
        if saved_seeds is not None and len(saved_seeds) == self.num_envs:
            self.env_seeds = list(saved_seeds)

        self.ep_returns = list(ckpt.get("ep_returns", []))
        self.ep_timesteps = list(ckpt.get("ep_timesteps", []))
        self.eval_history = list(ckpt.get("eval_history", []))

        # Best-effort close of the old envs before rebuilding them.
        try:
            self.env.close()
        except Exception:
            pass
        self.env = make_vec_env(self.num_envs, seed_list=self.env_seeds)

        rms = ckpt.get("obs_rms", None)
        if rms is not None:
            normalizers = (find_normalize_obs_wrapper(e) for e in self.env.envs)
            for wno in normalizers:
                if wno is not None:
                    wno.obs_rms = rms

        print(
            f"Loaded ckpt: upd={self.update} step={self.global_step} "
            f"jerk={'on' if self.jerk_enabled else 'off'} "
            f"hard={len(self.hard_pool.seeds)} success={len(self.success_pool.seeds)} "
            f"best_p75={self.best_eval_p75:.2f}"
        )

    def plot_results(self, window_size=50):
        if not self.ep_returns:
            print("No rewards to plot")
            return

        returns = np.array(self.ep_returns, dtype=np.float32)
        timesteps = np.array(self.ep_timesteps, dtype=np.int64)

        plt.figure(figsize=(10, 6))

        plt.plot(
            timesteps,
            returns,
            alpha=0.3,
            color="blue",
            linewidth=1.0,
            label="Raw Episode Return"
        )

        if len(returns) >= window_size:
            running_avg = np.convolve(returns, np.ones(window_size) / window_size, mode="valid")
            avg_timesteps = timesteps[window_size - 1:]
            plt.plot(
                avg_timesteps,
                running_avg,
                color="red",
                linewidth=2.0,
                label=f"Running Avg ({window_size})"
            )

        if self.jerk_enable_step is not None:
            plt.axvline(
                x=self.jerk_enable_step,
                color="green",
                linestyle="--",
                linewidth=2.0,
                label="Jerk penalty on"
            )

        plt.title("BipedalWalker Hardcore: Full Training Progress")
        plt.xlabel("Environment Timesteps")
        plt.ylabel("Return")
        plt.legend(loc="lower right")
        plt.grid(True, linestyle="--", alpha=0.35)

        plt.savefig(PLOT_PATH, dpi=160, bbox_inches="tight")
        plt.show()
        print(f"Plot saved to: {PLOT_PATH}")

    @torch.no_grad()
    def record_eval_video(self, deterministic=True, max_steps=2000):
        """Record one evaluation episode to Drive and display it inline.

        The video environment receives the observation statistics of training
        sub-environment 0. The statistics object is shared, so its running
        update is switched off here; recording never alters training
        normalisation. The environment is reset exactly once, with the seed,
        after the statistics are installed, so the recorded episode is the
        seeded one and its first observation is normalised with the training
        statistics.

        Args:
            deterministic: Act with the policy mean instead of sampling.
            max_steps: Upper bound on the number of steps recorded.
        """
        video_folder = os.path.join(DRIVE_FOLDER, f"videos_{RUN_TAG}")
        os.makedirs(video_folder, exist_ok=True)

        name_prefix = f"upd{self.update}_step{self.global_step}"
        seed = random.randint(0, 2_000_000_000)
        # Build without seeding: the single seeded reset happens below.
        env = make_one_env_video(
            seed=None,
            video_folder=video_folder,
            name_prefix=name_prefix,
        )

        ep_ret = 0.0
        try:
            train_wno = find_normalize_obs_wrapper(self.env.envs[0])
            eval_wno  = find_normalize_obs_wrapper(env)
            if train_wno is not None and eval_wno is not None:
                eval_wno.obs_rms = train_wno.obs_rms
                # Use the shared statistics read-only while recording.
                eval_wno.update_running_mean = False

            obs, _ = env.reset(seed=int(seed))
            for _ in range(max_steps):
                obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
                a, _, _, _, _ = self.ac.get_action_and_value(obs_t, action=None, deterministic=deterministic)
                obs, r, term, trunc, _ = env.step(a.squeeze(0).cpu().numpy())
                ep_ret += float(r)
                if term or trunc:
                    break
        finally:
            # Closing the env finalises the video file, even on error.
            env.close()

        print(f"Evaluation video | ret={ep_ret:.2f} | folder={video_folder}")
        show_video(video_folder)

    def train(self):
        """Run the PPO training loop.

        Every iteration ("update") does the following:

        1. Collect ``rollout_steps`` transitions from each of the ``num_envs``
           environments, optionally subtracting a jerk penalty (mean squared
           change of the policy mean between consecutive steps) from the reward.
        2. Compute GAE advantages and value targets.
        3. Run up to ``train_epochs`` passes of clipped-PPO minibatch updates,
           stopping early once the approximate KL exceeds ``1.5 * target_kl``.
        4. Adapt the learning rate from the last approximate KL.
        5. Every ``eval_every_updates`` updates: evaluate, feed the hard/success
           seed pools, possibly switch the jerk penalty on, save the best model
           and a checkpoint, plot progress and record a video. Otherwise save a
           checkpoint every ``checkpoint_every_updates`` updates.

        The loop ends when ``global_step`` reaches ``total_env_steps`` or the
        best evaluation p75 reaches 300. The final weights and a checkpoint are
        written on exit.
        """
        obs, _ = self.env.reset()
        self.prev_mean[:] = 0.0

        batch_size = self.num_envs * self.rollout_steps

        while self.global_step < self.total_env_steps:
            self.clamp_logstd()
            ent_coef = self.current_ent_coef()
            self.jerk_coef_current = self.current_jerk_coef()

            # ---------------- Rollout collection ----------------
            for t in range(self.rollout_steps):
                self.buf_obs[t] = obs

                obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device)
                with torch.no_grad():
                    action, logp, _, val, mean_a = self.ac.get_action_and_value(obs_t)

                act_np = action.cpu().numpy().astype(np.float32)
                mean_np = mean_a.cpu().numpy().astype(np.float32)

                next_obs, raw_r, term, trunc, infos = self.env.step(act_np)
                done = np.logical_or(term, trunc)

                # Reward shaping: penalise the per-env mean squared change of
                # the policy mean relative to the previous step.
                shaped = raw_r.astype(np.float32)
                if self.jerk_coef_current > 0.0:
                    da2 = np.mean((mean_np - self.prev_mean) ** 2, axis=1)
                    shaped = shaped - (self.jerk_coef_current * da2).astype(np.float32)

                self.prev_mean = mean_np

                self.buf_act[t]  = act_np
                self.buf_logp[t] = logp.cpu().numpy().astype(np.float32)
                self.buf_val[t]  = val.cpu().numpy().astype(np.float32)
                self.buf_vold[t] = self.buf_val[t]
                self.buf_rew[t]  = shaped
                self.buf_done[t] = done.astype(np.float32)

                obs = next_obs
                self.global_step += self.num_envs

                # Log finished episodes reported through the "episode" info entry.
                if isinstance(infos, dict) and "episode" in infos:
                    ep_r = infos["episode"]["r"]
                    ep_l = infos["episode"]["l"]
                    for i in range(self.num_envs):
                        if done[i]:
                            r_i = float(ep_r[i])
                            l_i = int(ep_l[i])
                            print(
                                f"step={self.global_step:9d} | upd={self.update:6d} | "
                                f"env={i:2d} | ret={r_i:7.2f} | len={l_i:4d} | "
                                f"jerk={'on' if self.jerk_coef_current > 0.0 else 'off'} coef={self.jerk_coef_current:.5f}"
                            )
                            self.ep_returns.append(r_i)
                            self.ep_timesteps.append(int(self.global_step))

                self.reseed_envs_if_needed(done)

            # ---------------- GAE advantages / returns ----------------
            with torch.no_grad():
                obs_t = torch.tensor(obs, dtype=torch.float32, device=self.device)
                next_val = self.ac.get_value(obs_t).cpu().numpy().astype(np.float32)

            adv = np.zeros((self.rollout_steps, self.num_envs), dtype=np.float32)
            lastgaelam = np.zeros((self.num_envs,), dtype=np.float32)
            for t in reversed(range(self.rollout_steps)):
                # buf_done marks termination OR truncation, so both cut the
                # bootstrap and the GAE trace here.
                nonterminal = 1.0 - self.buf_done[t]
                nextv = next_val if t == self.rollout_steps - 1 else self.buf_val[t + 1]
                delta = self.buf_rew[t] + self.gamma * nextv * nonterminal - self.buf_val[t]
                lastgaelam = delta + self.gamma * self.lam * nonterminal * lastgaelam
                adv[t] = lastgaelam
            ret = adv + self.buf_val

            # Flatten (rollout_steps, num_envs, ...) into one batch dimension.
            b_obs  = torch.tensor(self.buf_obs.reshape(batch_size, self.obs_dim), device=self.device)
            b_act  = torch.tensor(self.buf_act.reshape(batch_size, self.act_dim), device=self.device)
            b_log  = torch.tensor(self.buf_logp.reshape(batch_size), device=self.device)
            b_adv  = torch.tensor(adv.reshape(batch_size), device=self.device)
            b_ret  = torch.tensor(ret.reshape(batch_size), device=self.device)
            b_vold = torch.tensor(self.buf_vold.reshape(batch_size), device=self.device)

            # Normalise advantages over the whole batch.
            b_adv = (b_adv - b_adv.mean()) / (b_adv.std() + 1e-8)

            inds = np.arange(batch_size)
            approx_kl = 0.0

            # Apply the current (possibly KL-adapted) learning rate.
            for g in self.opt.param_groups:
                g["lr"] = self.lr

            # ---------------- PPO optimisation ----------------
            for _ in range(self.train_epochs):
                np.random.shuffle(inds)
                for start in range(0, batch_size, self.minibatch_size):
                    mb = inds[start : start + self.minibatch_size]

                    _, new_logp, ent, new_val, _ = self.ac.get_action_and_value(b_obs[mb], b_act[mb])

                    log_ratio = new_logp - b_log[mb]
                    ratio = torch.exp(log_ratio)

                    # Clipped surrogate policy objective.
                    surr1 = ratio * b_adv[mb]
                    surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * b_adv[mb]
                    actor_loss = -torch.min(surr1, surr2).mean()

                    # Clipped value loss: limit how far the new value may move
                    # from the rollout-time value (fixed +/-0.2 band).
                    v_pred = new_val
                    v_old = b_vold[mb]
                    v_clipped = v_old + torch.clamp(v_pred - v_old, -0.2, 0.2)

                    v_loss_unclipped = (v_pred - b_ret[mb]).pow(2)
                    v_loss_clipped   = (v_clipped - b_ret[mb]).pow(2)
                    v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()

                    ent_loss = ent.mean()

                    loss = actor_loss + self.vf_coef * v_loss - ent_coef * ent_loss

                    self.opt.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(self.ac.parameters(), self.max_grad_norm)
                    self.opt.step()

                    # KL estimate from the ratio computed before this step.
                    with torch.no_grad():
                        approx_kl = float((ratio - 1.0 - log_ratio).mean().item())

                    # Early stop: leave both the minibatch and the epoch loop.
                    if approx_kl > 1.5 * self.target_kl:
                        break
                if approx_kl > 1.5 * self.target_kl:
                    break

            # KL-adaptive learning rate. The adapted value is stored on
            # ``self.lr`` because the param groups are re-initialised from
            # ``self.lr`` at the start of every update; adjusting only the
            # param groups would be overwritten before the next optimiser step.
            if approx_kl < 0.5 * self.target_kl:
                self.lr = min(self.lr * 1.10, 3e-4)
            elif approx_kl > 1.5 * self.target_kl:
                self.lr = max(self.lr * 0.70, 1e-5)
            for g in self.opt.param_groups:
                g["lr"] = self.lr

            self.update += 1

            # ---------------- Periodic evaluation ----------------
            do_eval = (self.update % self.eval_every_updates == 0)
            if do_eval:
                mean_raw, med_raw, p75_raw, best_raw, worst_raw, seed_ret = self.evaluate(
                    episodes=self.eval_episodes, deterministic=True
                )

                print(
                    f"Evaluation | upd {self.update:5d} | step {self.global_step:9d} | "
                    f"mean {mean_raw:7.2f} median {med_raw:7.2f} p75 {p75_raw:7.2f} "
                    f"best {best_raw:7.2f} worst {worst_raw:7.2f} | "
                    f"approx_kl {approx_kl:.4f} | ent {ent_coef:.4f} | "
                    f"jerk_coef {self.jerk_coef_current:.5f} | hard {len(self.hard_pool.seeds)} success {len(self.success_pool.seeds)}"
                )

                # Worst-performing fraction of eval seeds goes to the hard pool
                # (always at least one seed).
                seed_ret_sorted = sorted(seed_ret, key=lambda x: x[1])  # ascending by return
                k = max(1, int(len(seed_ret_sorted) * self.hard_pool_bottom_frac))
                hard_seeds = [s for (s, r) in seed_ret_sorted[:k]]
                for s in hard_seeds:
                    self.hard_pool.add(s)

                # Seeds whose return clears the threshold go to the success pool.
                for (s, r) in seed_ret_sorted:
                    if r >= self.success_pool_threshold:
                        self.success_pool.add(s)

                # One-way switch: enable the jerk penalty the first time the
                # eval p75 reaches the configured level.
                if (not self.jerk_enabled) and (p75_raw >= self.jerk_start_p75):
                    self.jerk_enabled = True
                    self.jerk_enable_update = int(self.update)
                    self.jerk_enable_step = int(self.global_step)
                    print(
                        f"Jerk penalty started (eval p75 >= {self.jerk_start_p75}). "
                        f"Increasing to {self.jerk_coef_max} over {self.jerk_ramp_updates} updates."
                    )

                if p75_raw > self.best_eval_p75:
                    self.best_eval_p75 = p75_raw
                    torch.save(self.ac.state_dict(), BEST_PATH)
                    print(f"New best saved -> {BEST_PATH}")

                self.eval_history.append({
                    "update": int(self.update),
                    "step": int(self.global_step),
                    "mean": float(mean_raw),
                    "median": float(med_raw),
                    "p75": float(p75_raw),
                    "best": float(best_raw),
                    "worst": float(worst_raw),
                    "approx_kl": float(approx_kl),
                    "ent_coef": float(ent_coef),
                    "jerk_coef": float(self.jerk_coef_current),
                    "hard_pool": int(len(self.hard_pool.seeds)),
                    "success_pool": int(len(self.success_pool.seeds)),
                })

                self.save(CKPT_PATH)
                self.plot_results(window_size=50)
                self.record_eval_video(deterministic=True)

            # Eval updates already checkpoint above; avoid saving twice.
            if (not do_eval) and (self.update % self.checkpoint_every_updates == 0):
                self.save(CKPT_PATH)
                print(f"Checkpoint saved -> {CKPT_PATH}")

            # Stop early once the best eval p75 reaches 300.
            if self.best_eval_p75 >= 300.0:
                break

        torch.save(self.ac.state_dict(), FINAL_PATH)
        self.save(CKPT_PATH)
        print("Final model saved:", FINAL_PATH)


if __name__ == "__main__":
    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print("Using device:", device)

    # Hyperparameters for this run, grouped by concern.
    hyperparams = {
        # Rollout size and total budget.
        "num_envs": 16,
        "rollout_steps": 2048,
        "total_env_steps": 40_000_000,
        # PPO optimisation.
        "lr": 1.5e-4,
        "clip_eps": 0.15,
        "target_kl": 0.02,
        # Entropy bonus range and log-std clamp.
        "ent_coef_start": 0.02,
        "ent_coef_end": 0.0,
        "logstd_min": -5.0,
        "logstd_max": -0.5,
        # Jerk penalty trigger and ramp.
        "jerk_start_p75": 240.0,
        "jerk_coef_max": 0.0010,
        "jerk_ramp_updates": 200,
        # Evaluation and checkpoint cadence.
        "eval_every_updates": 10,
        "eval_episodes": 16,
        "checkpoint_every_updates": 10,
        # Reseeding mix and seed-pool settings.
        "reseed_p_random": 0.60,
        "reseed_p_hard": 0.20,
        "reseed_p_success": 0.20,
        "hard_pool_bottom_frac": 0.20,
        "success_pool_threshold": 300.0,
    }

    trainer = PPOHardcoreVectorV11(device=device, **hyperparams)
    trainer.train()