# Cologne8 using RLlib

In [None]:
import csv
import os
from typing import Optional

import ray
from ray.rllib.env.wrappers.multi_agent_env_compatibility import MultiAgentEnvCompatibility
from ray.tune.registry import register_env

from envs import MultiAgentSumoEnv
from observation import Cologne8ObservationFunction
from reward_functions import combined_reward

In [None]:
import random
import numpy as np
import torch

# Identifier of this experiment run; it is embedded in the checkpoint, output
# and log directory names used throughout the notebook.
TEST_NUM = 1
SEED = 23423  # default SUMO seed no.

# Seed the Python, NumPy and PyTorch RNGs with the same value.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

ENV_NAME = "cologne8"
OBS_CLASS = Cologne8ObservationFunction

# Refuse to continue when checkpoints for this TEST_NUM already exist, so an
# earlier run is never overwritten. An explicit raise (rather than `assert`)
# keeps the guard active under `python -O` and explains how to proceed.
if os.path.exists(os.path.join("ray_checkpoints", ENV_NAME, f"test_{TEST_NUM}")):
    raise FileExistsError(
        f"Checkpoints for {ENV_NAME} test_{TEST_NUM} already exist; "
        "increment TEST_NUM or remove the old checkpoint directory."
    )

In [None]:
def train_env_creator(args):
    """Create the SUMO multi-agent environment used for training.

    Args:
        args: Accepted for compatibility with env-creator callers (they pass
            an env config); the value is not used.

    Returns:
        A new ``MultiAgentSumoEnv`` for the ``ENV_NAME`` network, running
        3600 simulated seconds with ``combined_reward`` and ``SEED``.
    """
    # Report the reward setup: the first default of combined_reward is a
    # callable (its name is printed), the second is the congestion component
    # coefficient.
    reward_defaults = combined_reward.__defaults__
    print(reward_defaults[0].__name__, reward_defaults[1])

    net_dir = os.path.join("nets", ENV_NAME)
    return MultiAgentSumoEnv(
        net_file=os.path.join(net_dir, f"{ENV_NAME}.net.xml"),
        route_file=os.path.join(net_dir, f"{ENV_NAME}.rou.xml"),
        num_seconds=3600,
        reward_fn=combined_reward,
        sumo_seed=SEED,
        observation_class=OBS_CLASS,
        add_system_info=False,
    )

In [None]:
from ray.rllib.algorithms.ppo import PPOConfig

# PPO configuration for the multi-agent SUMO environment.
# From https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ppo/atari-ppo.yaml

# Build one environment instance up front; below it is only used to read
# `ts_ids` for the multi-agent policy set.
# NOTE(review): this instance is never closed in the visible notebook --
# confirm it does not keep a simulator process or connection open.
train_env = MultiAgentEnvCompatibility(train_env_creator({}))

config: PPOConfig
config = (
    PPOConfig()
    # The env is referenced by name; it must be registered under ENV_NAME
    # (via register_env) before config.build() is called.
    .environment(env=ENV_NAME)
    .framework(framework="torch")
    .rollouts(
        # 10 workers x 100-step fragments = 1000, the same number as
        # train_batch_size below.
        rollout_fragment_length=100,
        num_rollout_workers=10,
    )
    .training(
        lambda_=0.95,
        kl_coeff=0.5,
        clip_param=0.1,
        vf_clip_param=10.0,
        entropy_coeff=0.01,
        train_batch_size=1000,
        sgd_minibatch_size=100,
        num_sgd_iter=10,
    )
    .evaluation(
        evaluation_duration=1,
        evaluation_num_workers=1,
        evaluation_sample_timeout_s=300,
    )
    # Same seed value as the Python/NumPy/PyTorch RNGs and the SUMO env.
    .debugging(seed=SEED)
    # GPU count is read from the RLLIB_NUM_GPUS environment variable (default 1).
    .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "1")))
    .multi_agent(
        # One policy per entry in ts_ids (presumably traffic-signal ids --
        # verify in MultiAgentSumoEnv); every agent id maps to the policy
        # with the identical name.
        policies=set(train_env.env.ts_ids),
        policy_mapping_fn=(lambda agent_id, *args, **kwargs: agent_id),
    )
    .fault_tolerance(recreate_failed_workers=True)
)

In [None]:
# Directory that receives the evaluation CSV files for this test run.
csv_dir = os.path.join("outputs", ENV_NAME, f"test_{TEST_NUM}")
# exist_ok=True keeps the cell re-runnable and avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(csv_dir, exist_ok=True)

## Play Untrained Agent

In [None]:
def eval_env_creator(csv_path: Optional[str] = None, tb_log_dir: Optional[str] = None):
    """Create the SUMO multi-agent environment in evaluation mode.

    Args:
        csv_path: Forwarded to ``MultiAgentSumoEnv`` as ``csv_path``.
        tb_log_dir: Forwarded to ``MultiAgentSumoEnv`` as ``tb_log_dir``.

    Returns:
        A new ``MultiAgentSumoEnv`` constructed with ``eval=True`` and the
        same network, duration, reward and seed settings as the training env.
    """
    # Report the reward setup: the first default of combined_reward is a
    # callable (its name is printed), the second is the congestion component
    # coefficient.
    reward_defaults = combined_reward.__defaults__
    print(reward_defaults[0].__name__, reward_defaults[1])

    net_dir = os.path.join("nets", ENV_NAME)
    return MultiAgentSumoEnv(
        eval=True,
        csv_path=csv_path,
        tb_log_dir=tb_log_dir,
        net_file=os.path.join(net_dir, f"{ENV_NAME}.net.xml"),
        route_file=os.path.join(net_dir, f"{ENV_NAME}.rou.xml"),
        num_seconds=3600,
        reward_fn=combined_reward,
        sumo_seed=SEED,
        observation_class=OBS_CLASS,
        add_system_info=False,
    )

In [None]:
# Start Ray for the untrained-agent evaluation.
ray.init()

# Output locations for this evaluation: a CSV file and a TensorBoard log dir.
csv_path = os.path.join(csv_dir, "untrained.csv")
tb_log_dir = os.path.join("logs", ENV_NAME, f"PPO_{TEST_NUM}", "eval_untrained")

# Write the CSV header row. The file is opened in append mode, so re-running
# this cell adds another header line to an already existing file.
# NOTE(review): presumably the evaluation env appends its data rows to this
# same file -- verify in MultiAgentSumoEnv.
with open(csv_path, "a", newline="") as f:
    csv_writer = csv.writer(f)
    csv_writer.writerow(["sim_time", "arrived_num", "sys_tyre_pm", "sys_stopped",
                         "sys_total_wait", "sys_avg_wait", "sys_avg_speed",
                         "agents_tyre_pm", "agents_stopped", "agents_total_wait",
                         "agents_avg_speed", "agents_total_pressure"])

# Register ENV_NAME so that it resolves to the evaluation environment; the
# `config` argument handed to the creator is ignored.
register_env(ENV_NAME, lambda config: MultiAgentEnvCompatibility(eval_env_creator(csv_path, tb_log_dir)))

In [None]:
# Build a PPO algorithm from `config`. No training has been run on this
# instance, and ENV_NAME currently resolves to the evaluation environment.
algo = config.build()

In [None]:
# Run the evaluation configured in `config.evaluation(...)` with the untrained policies.
algo.evaluate()

In [None]:
# Stop Ray; the training section below calls ray.init() again.
ray.shutdown()

## Train RL Agent

In [None]:
# Start Ray for training.
ray.init()

# Re-register ENV_NAME so that it now resolves to the training environment
# (the previous section registered the evaluation environment under it).
register_env(ENV_NAME, lambda config: MultiAgentEnvCompatibility(train_env_creator(config)))

In [None]:
# Build a new PPO algorithm instance for training, replacing the earlier `algo`.
algo = config.build()

In [None]:
from datetime import datetime

# Number of algo.train() calls to perform.
# NOTE(review): the arithmetic below treats each call as 720 timesteps
# (presumably one 3600 s episode at 5 s per step); algo.train() runs one
# RLlib training iteration and train_batch_size is 1000 in the config --
# confirm the total.
TRAIN_EPS = 1400  # 720 * 1400 == 1_008_000 total timesteps
CHECKPOINT_FREQ = 100
# Guarantees that the final iteration also produces a checkpoint.
assert TRAIN_EPS % CHECKPOINT_FREQ == 0

# Wall-clock start of training.
tic = datetime.now()

for i in range(TRAIN_EPS):
    results = algo.train()

    # Save a checkpoint after every CHECKPOINT_FREQ-th iteration (1-based count).
    if (i+1) % CHECKPOINT_FREQ == 0:
        algo.save(os.path.join("ray_checkpoints",ENV_NAME,f"test_{TEST_NUM}"))

# Wall-clock end of training.
toc = datetime.now()

In [None]:
# Show the elapsed wall-clock training time (string form of the timedelta).
str(toc - tic)

In [None]:
# Stop Ray after training; the next section starts it again.
ray.shutdown()

## Play Trained Agent

In [None]:
# Output locations for the trained-agent evaluation.
csv_path = os.path.join(csv_dir, "trained.csv")
tb_log_dir = os.path.join("logs", ENV_NAME, f"PPO_{TEST_NUM}", "eval_trained")

# Write the CSV header row. The file is opened in append mode, so re-running
# this cell adds another header line to an already existing file.
with open(csv_path, "a", newline="") as f:
    csv_writer = csv.writer(f)
    csv_writer.writerow(["sim_time", "arrived_num", "sys_tyre_pm", "sys_stopped",
                         "sys_total_wait", "sys_avg_wait", "sys_avg_speed",
                         "agents_tyre_pm", "agents_stopped", "agents_total_wait",
                         "agents_avg_speed", "agents_total_pressure"])

In [None]:
# Register ENV_NAME as the evaluation environment again, using the current
# csv_path / tb_log_dir.
# NOTE(review): unlike the other sections, this registration runs before
# ray.init() (next cell) -- confirm that ordering is intended.
register_env(ENV_NAME, lambda config: MultiAgentEnvCompatibility(eval_env_creator(csv_path, tb_log_dir)))

In [None]:
# Start Ray for the trained-agent evaluation.
ray.init()

In [None]:
from ray.rllib.algorithms.ppo import PPO

# Path of the checkpoint written after the final training iteration.
# NOTE(review): assumes algo.save() created a `checkpoint_<iteration, 6 digits>`
# sub-directory inside the directory passed to it -- this layout depends on
# the installed Ray version; confirm.
checkpoint_path = os.path.join("ray_checkpoints",ENV_NAME,f"test_{TEST_NUM}",f"checkpoint_{TRAIN_EPS:06}")
checkpoint_path = os.path.abspath(checkpoint_path)
print(checkpoint_path)

# Restore a PPO algorithm (config and policy state) from that checkpoint.
ppo_agent = PPO.from_checkpoint(checkpoint_path)

In [None]:
# Run the configured evaluation with the restored (trained) policies.
ppo_agent.evaluate()

In [None]:
# Stop Ray at the end of the notebook.
ray.shutdown()