# RL training on the env

In [2]:
import numpy as np
import pandas as pd
from plotnine import (
    ggplot, aes, geom_density, geom_line, geom_point, 
    geom_violin, facet_grid, labs, theme, facet_wrap,
)

# for rl training
from stable_baselines3 import PPO, TD3
from sb3_contrib import TQC
from stable_baselines3.common.env_util import make_vec_env

# the rl environment
from rl4greencrab import greenCrabSimplifiedEnv as gcse

# helper that parallelizes episode simulations for evaluation purposes (agent -> reward)
from rl4greencrab import evaluate_agent

# helper that creates a single episode simulation keeping track of many variables
# of the internal env state
from rl4greencrab import simulator

In [4]:
# Build a vectorized environment for training: the RL agent gathers
# experience by interacting with n_envs copies of the environment
# at once, rather than stepping through one environment at a time.
vec_env = make_vec_env(
    env_id=gcse,
    n_envs=12,
)

## Algo 1: PPO

see docs here: https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html

In [5]:
# Train a PPO agent with the default MLP policy on the vectorized
# environment, logging to TensorBoard, then write the trained agent to disk.
model = PPO(
    policy="MlpPolicy",
    env=vec_env,
    verbose=0,
    tensorboard_log="/home/rstudio/logs",
)
model.learn(total_timesteps=250_000, progress_bar=True)
model.save("ppo_gcse")

Output()

## Algo 2: TD3

In [None]:
# Train a TD3 agent with the default MLP policy, then save it to disk.
#
# The algorithm needs an environment *instance* (or a registered env id
# string), not the environment class itself, so the class is instantiated
# here with `gcse()`.
# NOTE(review): a single environment is used rather than `vec_env`; whether
# TD3 can train on several envs depends on its `train_freq` setting in the
# installed stable-baselines3 version -- confirm before switching.
model = TD3("MlpPolicy", gcse(), verbose=0, tensorboard_log="/home/rstudio/logs")
model.learn(
    total_timesteps=250_000,
    progress_bar=True,
)
model.save("td3_gcse")

## Algo 3: TQC

In [7]:
# Train a TQC agent (from sb3_contrib) with the default MLP policy on the
# vectorized environment, logging to TensorBoard, then save it to disk.
model = TQC(
    policy="MlpPolicy",
    env=vec_env,
    verbose=0,
    tensorboard_log="/home/rstudio/logs",
)
model.learn(total_timesteps=250_000, progress_bar=True)
model.save("tqc_gcse")

Output()

## Loading and evaluating trained models

In [10]:
# Reload the saved agents onto the CPU explicitly (instead of letting the
# library pick a device), because the parallelization used to evaluate
# agents works with the CPU.

ppoAgent = PPO.load(path="ppo_gcse", device="cpu")
# td3Agent = TD3.load(path="td3_gcse", device="cpu")
tqcAgent = TQC.load(path="tqc_gcse", device="cpu")

# A fresh environment instance that is passed to the evaluation helper.
evalEnv = gcse()

In [13]:
# Number of evaluation episodes requested for each agent.
N_EPS = 30

# Evaluate each trained agent on the evaluation environment through the
# rl4greencrab helper, with its ray-based execution switched on.
ppo_rew = evaluate_agent(
    agent=ppoAgent,
    env=evalEnv,
    ray_remote=True,
).evaluate(n_eval_episodes=N_EPS)
# td3_rew = evaluate_agent(agent=td3Agent, ray_remote=True).evaluate(n_eval_episodes=N_EPS)
tqc_rew = evaluate_agent(
    agent=tqcAgent,
    env=evalEnv,
    ray_remote=True,
).evaluate(n_eval_episodes=N_EPS)

In [15]:
# Shut ray down once the evaluations are finished.
# NOTE(review): presumably this releases the workers started by the
# `ray_remote=True` evaluations in the previous cell -- confirm.
import ray
ray.shutdown()

In [16]:
# Report the evaluation result of each agent, padded by a blank line
# before and after the two result lines.
print(
    "\n".join(
        [
            "",
            f"PPO mean rew = {ppo_rew}",
            f"TQC mean rew = {tqc_rew}",
            "",
        ]
    )
)


PPO mean rew = -9.576186668916746
TQC mean rew = -6.698807471424942

