# Experiment: PPO

* Let's repeat the run of the previous notebook, but with a smaller network and more time to run.
* We are going to compare PPO to another algorithm.
* Decide on the size of the network you want to train, then run the cells below.
* We will then move on to the next notebook while this model trains.

In [None]:
from environments.env_utils import env_from_env_config
from environments.observations import TreeObsForRailEnv
from environments.preprocessor import TreeObsPreprocessor
from models.dense_model import DenseModel
from ray import tune
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env

In [None]:
# We want to see many more iterations of training for PPO, so pick a small network.
# "u" is the number of hidden units per layer used by the dense model set up below.
u =  # TODO: set this to something very small. Suggested: on the order of 10 units per layer.

In [None]:
# Dense model layout: every hidden layer has "u" units and a ReLU activation.
# The embedding gets two hidden layers; the actor and the critic get one each.
custom_model = "dense_model"
model_params = {
    part: {"hidden_layers": [u] * n_layers, "activation_fn": "relu"}
    for part, n_layers in (("embedding", 2), ("actor", 1), ("critic", 1))
}

# Environment settings: an 8x8 grid with 5 agents, built with the "complex"
# rail and schedule generators and a tree observation of depth 2.
env_config = dict(
    obs_config=dict(max_depth=2),
    rail_generator="complex_rail_generator",
    rail_config=dict(nr_start_goal=12, nr_extra=0, min_dist=8, seed=10),
    width=8,
    height=8,
    number_of_agents=5,
    schedule_generator="complex_schedule_generator",
    schedule_config={},
    frozen=False,
    remove_agents_at_target=True,
    wait_for_all_done=False,
)

# Instantiate one environment locally so that its action and observation
# spaces can be read off and reused when describing the policies.
env = env_from_env_config(env_config)
action_space, observation_space = env.action_space, env.observation_space

# Define 1 policy per agent, keyed "policy_0" ... "policy_<n-1>".
# Each value is (None, observation_space, action_space, {}) -- presumably
# RLlib's (policy class, obs space, action space, config overrides) spec, with
# None selecting the trainer's default policy; confirm for the Ray version in use.
# NOTE(review): the experiment config maps every agent to "policy_0", so the
# remaining entries look unused -- confirm which behaviour is intended.
num_policies = env_config["number_of_agents"]
policies = dict(
    (f"policy_{i}", (None, observation_space, action_space, {}))
    for i in range(num_policies)
)

# Register the custom pieces with RLlib/Tune under the string names that the
# experiment config refers to: the tree-observation preprocessor, the dense
# model, and the training environment factory.
ModelCatalog.register_custom_preprocessor("tree_obs_preprocessor", TreeObsPreprocessor)
ModelCatalog.register_custom_model(custom_model, DenseModel)
register_env("train_env", env_from_env_config)

# Multi-agent settings: all defined policies are passed in, and the mapping
# function assigns every agent id to "policy_0".
# NOTE(review): with this mapping only "policy_0" is ever selected -- confirm
# that a single shared policy is what is intended.
multiagent_config = {
    "policies": policies,
    "policy_mapping_fn": lambda agent_id: "policy_0",
}

# Model settings: the registered preprocessor and custom model, plus the
# options forwarded to the model (tree depth taken from the environment's
# observation config, followed by the per-part layer definitions).
model_config = {
    "custom_preprocessor": "tree_obs_preprocessor",
    "custom_model": custom_model,
    "custom_options": {
        "tree_depth": env_config["obs_config"]["max_depth"],
        "observation_radius": 0,
        **model_params,
    },
}

# Full experiment config
config = dict(
    # Run parameters
    num_cpus_per_worker=1,
    num_cpus_for_driver=1,
    num_workers=7,
    num_gpus=0,  # TODO: change this to 4

    # Environment parameters
    env="train_env",
    env_config=env_config,
    log_level="ERROR",

    # Training parameters
    horizon=60,
    num_sgd_iter=15,
    lr=1e-4,

    # Policy parameters
    vf_loss_coeff=1e-6,
    multiagent=multiagent_config,

    # Model parameters
    model=model_config,
)

# The GPU count is read back from the config so that the experiment name
# records it alongside the model name and the layer width "u".
n_GPUS = config["num_gpus"]
experiment_name = f"PPO_multi_agent-MODEL={custom_model}_{u}-GPUS={n_GPUS}"

# Train PPO until 80 training iterations have been reached, with
# checkpoint_freq=1 and a final checkpoint requested at the end of the run.
# ray_auto_init=True presumably lets Tune start Ray itself if it is not
# already running -- confirm for the Ray version in use.
tune.run(
    "PPO",
    config=config,
    name=experiment_name,
    stop={"training_iteration": 80},
    loggers=tune.logger.DEFAULT_LOGGERS,
    checkpoint_freq=1,
    checkpoint_at_end=True,
    ray_auto_init=True,
)