<a href="https://colab.research.google.com/github/maayanorner/RL_snippets/blob/main/colabs/rl_lib.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install ray version without device bugs
!pip install -U ray[default]==2.0 > /dev/null 2>&1
!pip install -U ray[rllib]==2.0 > /dev/null 2>&1
# Install other dependencies
!pip install pygame > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
!pip install -U colabgymrender > /dev/null 2>&1
!pip install imageio==2.4.1 > /dev/null 2>&1

# Install gym
!pip install gym[atari] > /dev/null 2>&1
!pip install gym[accept-rom-license] > /dev/null 2>&1
#!pip install gym[accept-rom-license]

# GPU monitoring
!pip install GPUtil > /dev/null 2>&1

In [None]:
import os

# Select SDL's "dummy" video driver *before* pygame is imported, so pygame
# can create a display surface without a real display attached.
os.environ["SDL_VIDEODRIVER"] = "dummy"

import pygame

pygame.display.set_mode((640, 480))

In [None]:
import ray
from ray import air, tune

# Start the Ray runtime used by Tune/RLlib below. `ignore_reinit_error=True`
# turns a re-run of this notebook cell into a no-op instead of an error about
# Ray already being initialized.
ray.init(ignore_reinit_error=True)

In [None]:
from gym import envs
import gym
#print(envs.registry.all())

In [None]:
from ray.rllib.env.wrappers.atari_wrappers import WarpFrame, MaxAndSkipEnv, FrameStack

# Gym id of the Atari game to train on (passed to gym.make).
env_name = "MsPacman-v4"
# Frame size handed to WarpFrame(dim=...).
dim = 84

class AtariWapped(gym.Env):
    """Atari env preprocessed with RLlib's Atari wrappers.

    The env created by ``gym.make(env_name, ...)`` is wrapped, innermost
    first, with ``WarpFrame(dim=dim)``, ``MaxAndSkipEnv(skip=4)`` and
    ``FrameStack(k=4)``. Its action/observation spaces are exposed as those
    of the outermost wrapper. (The class name's spelling is kept because
    other cells refer to it.)

    Args:
        env_config: Extra keyword arguments forwarded to ``gym.make``.
            ``None`` (the default) means no extra arguments.
    """

    def __init__(self, env_config=None):
        # Avoid a mutable ``{}`` default (one dict shared by every call);
        # build a fresh empty mapping instead.
        if env_config is None:
            env_config = {}
        self.env = gym.make(env_name, **env_config)
        self.env = WarpFrame(self.env, dim=dim)
        self.env = MaxAndSkipEnv(self.env, skip=4)
        self.env = FrameStack(self.env, k=4)
        # Alternative preprocessing with gym's built-in wrapper (unused):
        #env = gym.make(env_name, frameskip=0)
        #self.env = gym.wrappers.AtariPreprocessing(env, noop_max=30, frame_skip=16, screen_size=dim, terminal_on_life_loss=True, grayscale_obs=True, grayscale_newaxis=False, scale_obs=False)
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space

    def step(self, action):
        """Forward ``action`` to the wrapped env and return its result."""
        return self.env.step(action)

    def reset(self):
        """Reset the wrapped env and return its initial observation."""
        return self.env.reset()

    def render(self, *args, **kwargs):
        """Delegate rendering to the wrapped env."""
        return self.env.render(*args, **kwargs)

    def close(self):
        """Close the wrapped env.

        The base ``gym.Env.close`` does nothing, so without this override the
        env created by ``gym.make`` in ``__init__`` would never be closed.
        """
        self.env.close()

# Instance created here to read the preprocessed env's spaces.
e_info = AtariWapped()

from copy import deepcopy

# Deep-copied snapshot of the observation/action spaces.
# NOTE: the name `config` is rebound to the RLlib training config later on.
config = deepcopy(dict(
    observation_space=e_info.observation_space,
    action_space=e_info.action_space,
))

wrapped_name = AtariWapped

from ray.tune.registry import register_env

# Name under which the preprocessed env is registered with Ray; training
# configs refer to the env by this string.
final_env_name = "attari_wrapped"


def env_creator(env_config):
    """Env factory registered under ``final_env_name``.

    Builds an ``AtariWapped`` env, forwarding ``env_config`` to it.
    """
    return AtariWapped(env_config)


register_env(final_env_name, env_creator)


In [None]:
# Sanity check: shape of the observation produced by the preprocessing stack.
# The probe env is closed afterwards instead of being left open in the kernel.
_probe_env = AtariWapped()
try:
    _probe_obs_shape = _probe_env.reset().shape
finally:
    _probe_env.close()
_probe_obs_shape

In [None]:
from ray.rllib.algorithms.dqn import DQN
from ray.rllib.algorithms.ppo import PPO
from math import sqrt

In [None]:
from ray.tune.stopper import Stopper, TrialPlateauStopper, CombinedStopper, MaximumIterationStopper


# Stopping criteria (pass as `stop=es` to air.RunConfig to enable): end a
# trial once `episode_reward_mean` plateaus (TrialPlateauStopper, window of
# 4 results after a 4-result grace period), or after 2000 iterations.
es = CombinedStopper(
    TrialPlateauStopper(
        metric='episode_reward_mean',
        num_results=4,
        mode='max',
        grace_period=4,
    ),
    MaximumIterationStopper(max_iter=2000),
)

In [None]:
from ray.tune.schedulers import HyperBandScheduler

# HyperBand trial scheduler: budgets trials by `time_total_s` (up to
# max_t=1800) and favors trials with the highest `episode_reward_mean`.
sc = HyperBandScheduler(
    time_attr='time_total_s',
    metric='episode_reward_mean',
    mode="max",
    max_t=1800,
)

In [None]:
import psutil

# Small epsilon constant; not referenced by the cells visible here.
eps_n_workes = 1e-3

In [None]:
Algo = DQN
save_per_training_iteration = 10
train_batch_size = 512
num_workers = 30  # 100
# One CPU is held back; the rest is split evenly across the rollout workers.
num_av_cpus = psutil.cpu_count() - 1

# Rollout workers share `workers_gpu_frac` of one GPU (0 -> they run on CPU);
# the driver gets the remaining `1 - workers_gpu_frac`.
# NOTE: with workers_gpu_frac = 0 the driver requests a whole GPU, so this
# assumes a GPU runtime; set "num_gpus" to 0 on a CPU-only machine.
workers_gpu_frac = 0
num_gpus_per_worker = (1 / num_workers) * workers_gpu_frac

lr = [1e-4]  # previously tried: [1e-5, 1e-4], [0.01, 1e-3, 1e-4, 1e-5]
config = {
    "env": final_env_name,
    "num_gpus": 1 - workers_gpu_frac,
    "num_gpus_per_worker": num_gpus_per_worker,
    "num_cpus_per_worker": (1 / num_workers) * num_av_cpus,
    "num_workers": num_workers,
    "lr": tune.grid_search(lr),
    "framework": "torch",
    # 'evaluation_num_workers': 0,
    # 'evaluation_parallel_to_training': False,
    "train_batch_size": train_batch_size,
    # 'num_rollout_workers': 1,
    'soft_horizon': False,
    'horizon': 100000,
}

# `Algo` is a class, so test it with issubclass: `type(Algo)` is its
# metaclass and can never equal PPO, which made this branch dead code.
if isinstance(Algo, type) and issubclass(Algo, PPO):
    # Settings following the PPO paper.
    # The paper's clipping parameter epsilon clips the surrogate objective;
    # RLlib calls that `clip_param` (`grad_clip` is gradient-norm clipping).
    config['clip_param'] = 0.2
    # The paper's horizon T = 128 is the rollout length per actor, i.e.
    # RLlib's `rollout_fragment_length`; RLlib's `horizon` would instead cut
    # every episode off after 128 steps.
    # NOTE(review): RLlib may auto-adjust rollout_fragment_length to fit
    # train_batch_size; consider config["train_batch_size"] = 128*num_workers.
    config['rollout_fragment_length'] = 128
    # No optimizer override: NOTE(review) RLlib's torch policies build an
    # Adam optimizer by default -- confirm for the installed Ray version.
    # Scale the learning rate with the number of workers. The grid search is
    # rebuilt because config["lr"] captured the old list above.
    lr = [1e-4 * sqrt(num_workers)]
    config['lr'] = tune.grid_search(lr)


results = tune.Tuner(
    Algo,
    tune_config=tune.TuneConfig(scheduler=sc),
    run_config=air.RunConfig(
        # stop=es,  # enable plateau / max-iteration stopping
        checkpoint_config=air.CheckpointConfig(checkpoint_frequency=save_per_training_iteration),
        name="example-experiment-atari",
        local_dir="./example-experiment"
        ),
    param_space=config,
).fit()

In [None]:
import os

from ray.tune import ExperimentAnalysis

# Load the results Tune wrote under `local_dir="./example-experiment"`.
# Resolve that same relative path rather than hard-coding "/content/", which
# only matches when the notebook's working directory is Colab's /content.
analysis = ExperimentAnalysis(os.path.abspath("./example-experiment"))

In [None]:
# Log directory of the best trial by `episode_reward_mean` (mode="max").
# A trial directory path could also be assigned here directly.
trial_logdir = analysis.get_best_logdir(metric="episode_reward_mean", mode="max")

# Best checkpoint of that trial, ranked by the same metric.
best_checkpoint = analysis.get_best_checkpoint(
    trial_logdir,
    metric="episode_reward_mean",
    mode="max",
)
# All checkpoints of that trial, as (path, metric value) tuples.
checkpoints = analysis.get_trial_checkpoints_paths(trial_logdir)

In [None]:
# Fresh algorithm instance on the registered env (torch framework, other
# settings left at their defaults), restored from the best checkpoint.
agent = Algo(env=final_env_name, config={"framework": "torch"})
agent.restore(best_checkpoint)

In [None]:
from colabgymrender.recorder import Recorder

video_every = 1  # not referenced by the cells visible here
# Wrap a fresh preprocessed env in a Recorder writing into ./video.
# NOTE(review): check which Recorder parameter the positional 0.2 binds to
# in the installed colabgymrender version.
env = Recorder(AtariWapped(), "./video", 0.2)

In [None]:
# run until episode ends
episode_reward = 0
done = False
obs = env.reset()
while not done:
    action = agent.compute_action(obs)
    obs, reward, done, info = env.step(action)
    episode_reward += reward

print(episode_reward)