# Random Network Distillation (RND) - Colab version

In [None]:
# altered the original notebook for colab
# The only difference is the first 3 cells.
! git clone https://github.com/jnskkmhr/DRL_RND.git
%cd DRL_RND

!pip install -r requirements.txt

In [None]:
import os

# Select headless EGL rendering for MuJoCo. The variable is set before any of
# the libraries below are imported because MuJoCo reads MUJOCO_GL at import time.
os.environ["MUJOCO_GL"] = "egl"

import gymnasium as gym
# RecordVideo must come from gymnasium (not the legacy `gym` package) so that it
# matches the environments created through `gym.make` in this notebook.
from gymnasium.wrappers import RecordVideo
from IPython.display import Video, display, clear_output
from tqdm import tqdm
import torch
from torch.utils.tensorboard import SummaryWriter

# torch default device: use the GPU when available, otherwise fall back to CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
torch.set_default_device(device)

from rnd_rl.runner.policy_runner import PPOConfig, PolicyRunner

In [None]:
# @title Visualization code. Used later.

def visualize(agent):
    """Record one episode of ``agent`` on InvertedPendulum-v5 and show the video inline.

    Args:
        agent: Object exposing ``get_action(obs)``. It is called with a float32
            observation tensor that has a leading batch dimension of size 1 and
            lives on ``device``; the first element of its return value is used
            as the action batch.
            (assumes the actions are a tensor with a leading batch dimension
            of size 1 -- TODO confirm against the policy implementation)

    Raises:
        FileNotFoundError: If no ``.mp4`` file was produced in ``./videos``.

    Side effects:
        Writes video files to ``./videos``, clears the current cell output and
        displays the most recently written video.
    """
    video_dir = "./videos"  # Directory to save videos
    os.makedirs(video_dir, exist_ok=True)

    # rgb_array rendering lets the video wrapper capture frames off-screen.
    env = gym.make("InvertedPendulum-v5", render_mode="rgb_array", reset_noise_scale=0.2)

    # Record every episode (the trigger always returns True).
    env = RecordVideo(env, video_folder=video_dir, episode_trigger=lambda x: True)

    try:
        obs, _ = env.reset()

        for _ in range(4096):
            obs_batch = torch.as_tensor(obs, dtype=torch.float32, device=device)[None, :]
            # Inference only: no autograd graph is needed for a rollout.
            with torch.no_grad():
                actions, _ = agent.get_action(obs_batch)

            # Gymnasium's step returns (obs, reward, terminated, truncated, info).
            obs, _, terminated, truncated, _ = env.step(actions.squeeze(0).cpu().numpy())

            # Stop on either a true terminal state or a time-limit truncation.
            if terminated or truncated:
                break
    finally:
        # Always close the env so the recorder finalizes the video file,
        # even if the rollout raised.
        env.close()

    # Pick the newest video by modification time; ignore any non-video files.
    video_files = [
        os.path.join(video_dir, file_name)
        for file_name in os.listdir(video_dir)
        if file_name.endswith(".mp4")
    ]
    if not video_files:
        raise FileNotFoundError(f"No .mp4 video was written to {video_dir}")
    video_path = max(video_files, key=os.path.getmtime)

    clear_output(wait=True)
    display(Video(video_path, embed=True))

In [None]:
# Launch TensorBoard inline: load the notebook extension, then start it
# reading event files from the local "runs" directory.
%load_ext tensorboard
%tensorboard --logdir runs

In [None]:
# Number of environment copies stepped together by the vector env.
n_envs = 64


def _make_pendulum_env():
    """Create one InvertedPendulum-v5 environment with reset noise scale 0.2."""
    return gym.make("InvertedPendulum-v5", reset_noise_scale=0.2)


# SyncVectorEnv calls each factory once, yielding n_envs independent environments.
envs = gym.vector.SyncVectorEnv([_make_pendulum_env for _ in range(n_envs)])

### PPO baseline

In [None]:
# Configuration for the PPO baseline run (RND switched off).
ppo_cfg = PPOConfig(use_rnd=False, clip_params=0.2, init_noise_std=1.0)

In [None]:
# Train the PPO baseline: each epoch runs one rollout followed by one update.
num_epochs = 250
policy_runner = PolicyRunner(
    envs=envs,
    policy_cfg=ppo_cfg,
    num_mini_epochs=10,
    device=device,
)
for epoch in tqdm(range(num_epochs)):
    # The epoch index is handed to rollout(); update() takes no arguments.
    policy_runner.rollout(epoch)
    policy_runner.update()

In [None]:
# Record and display one episode played by the PPO baseline.
visualize(agent=policy_runner.alg)
print("PPO trained agent")

### PPO with RND

In [None]:
# Same settings as the baseline configuration, but with RND switched on.
ppo_rnd_cfg = PPOConfig(use_rnd=True, clip_params=0.2, init_noise_std=1.0)

In [None]:
# Train PPO with RND: each epoch runs one rollout followed by one update.
num_epochs = 250
rnd_policy_runner = PolicyRunner(
    envs=envs,
    policy_cfg=ppo_rnd_cfg,
    num_mini_epochs=10,
    device=device,
)
for epoch in tqdm(range(num_epochs)):
    # The epoch index is handed to rollout(); update() takes no arguments.
    rnd_policy_runner.rollout(epoch)
    rnd_policy_runner.update()

In [None]:
# Record and display one episode played by the PPO + RND agent.
visualize(agent=rnd_policy_runner.alg)
print("RND PPO trained agent")

### Reward normalization only

In [None]:

# RND configuration with reward normalization switched on.
ppo_rnd_reward_normalization_cfg = PPOConfig(
    use_rnd=True,
    clip_params=0.2,
    init_noise_std=1.0,
    reward_normalization=True,
)


In [None]:
# Train PPO + RND with reward normalization.
num_epochs = 250
rnd_reward_norm_policy_runner = PolicyRunner(
    envs=envs,
    policy_cfg=ppo_rnd_reward_normalization_cfg,
    num_mini_epochs=10,
    device=device,
)
# Replace the runner's writer so this run logs to its own TensorBoard directory.
# A plain string is used: the former f-string only wrapped a constant literal.
# NOTE(review): the writer being replaced is not closed here -- confirm whether
# PolicyRunner creates one on construction.
rnd_reward_norm_policy_runner.writer = SummaryWriter(log_dir="runs/RND_reward_normalization")
for epoch in tqdm(range(num_epochs)):
    # The epoch index is handed to rollout(); update() takes no arguments.
    rnd_reward_norm_policy_runner.rollout(epoch)
    rnd_reward_norm_policy_runner.update()

In [None]:
# Record and display one episode played by the reward-normalized RND agent.
visualize(agent=rnd_reward_norm_policy_runner.alg)
print("RND PPO trained agent with reward normalization")

### Reward and observation normalization

In [None]:
# RND configuration with both reward and observation normalization switched on.
ppo_rnd_all_normalization_cfg = PPOConfig(
    use_rnd=True,
    clip_params=0.2,
    init_noise_std=1.0,
    reward_normalization=True,
    obs_normalization=True,
)


In [None]:
# Train PPO + RND with reward and observation normalization.
num_epochs = 250
rnd_all_norm_policy_runner = PolicyRunner(
    envs=envs,
    policy_cfg=ppo_rnd_all_normalization_cfg,
    num_mini_epochs=10,
    device=device,
)
# Replace the runner's writer so this run logs to its own TensorBoard directory.
# A plain string is used: the former f-string only wrapped a constant literal.
# NOTE(review): the writer being replaced is not closed here -- confirm whether
# PolicyRunner creates one on construction.
rnd_all_norm_policy_runner.writer = SummaryWriter(log_dir="runs/RND_all_normalization")
for epoch in tqdm(range(num_epochs)):
    # The epoch index is handed to rollout(); update() takes no arguments.
    rnd_all_norm_policy_runner.rollout(epoch)
    rnd_all_norm_policy_runner.update()

In [None]:
# Record and display one episode played by the agent trained with both
# reward and observation normalization.
visualize(rnd_all_norm_policy_runner.alg)
# The caption names both normalizations, matching the configuration of this run.
print("RND PPO trained agent with reward and observation normalization")