In [1]:
import os

import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env import VecMonitor, is_vecenv_wrapped
from sumo_rl import parallel_env
import supersuit as ss

from evaluate import evaluate

In [2]:
from stable_baselines3.common.monitor import Monitor
from sumo_rl import SumoEnvironment

from reward_functions import diff_wait_time, tyre_pm
    
# The network and route files sit in the same scenario folder.
scenario_dir = os.path.join("nets", "simple_nets", "cross1ltl")

# Keyword arguments forwarded verbatim to SumoEnvironment (also read by later cells).
env_params = dict(
    net_file=os.path.join(scenario_dir, "net.net.xml"),
    route_file=os.path.join(scenario_dir, "input_routes.rou.xml"),
    num_seconds=1200,
    single_agent=True,
    reward_fn=diff_wait_time,
    sumo_seed=42,
)

env = SumoEnvironment(**env_params)
check_env(env)  # run SB3's environment checker on the raw env, before any wrapper is added
env = Monitor(env)  # wrap env to know episode reward, length, time

In [3]:
from helper_functions import linear_schedule

# Using hyperparams for Atari (except for n_steps) from
# https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/ppo.yml
# NOTE(review): linear_schedule is imported in this cell but never used; learning_rate
# and clip_range are passed as constants below -- confirm that this is intended.
ppo_hyperparams = dict(
    learning_rate=2.5e-4,
    n_steps=1024,
    batch_size=256,
    n_epochs=4,
    clip_range=0.1,
    ent_coef=1e-3,
    verbose=1,
)

# PPO agent with an MLP policy acting on the monitored single-agent environment.
model = PPO("MlpPolicy", env, **ppo_hyperparams)

Using cuda device
Wrapping the env in a DummyVecEnv.


In [4]:
# Run the project's evaluate() helper on the untrained model for a single episode.
# NOTE(review): this call currently fails with
# "AttributeError: 'DummyVecEnv' object has no attribute 'traffic_signals'" (see this cell's output).
# The cells that follow step through an evaluation loop by hand, presumably to debug it.
evaluate(model, env, n_eval_episodes=1)

AttributeError: 'DummyVecEnv' object has no attribute 'traffic_signals'

In [None]:
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv

In [None]:
from stable_baselines3.common.monitor import Monitor

# Make sure we are working with a vectorised environment; a plain env is wrapped
# in a one-element DummyVecEnv.
if not isinstance(env, VecEnv):
    env = DummyVecEnv([lambda: env])

# The env counts as monitored when it is wrapped in a VecMonitor or, failing that,
# when its first sub-environment is wrapped in a Monitor.
has_vec_monitor = is_vecenv_wrapped(env, VecMonitor)
is_monitor_wrapped = has_vec_monitor or env.env_is_wrapped(Monitor)[0]
is_monitor_wrapped

In [None]:
# Number of episodes to evaluate and number of sub-environments in the vectorised env.
n_eval_episodes = 1
n_envs = env.num_envs

# Per-episode results are collected in these lists.
episode_rewards, episode_lengths = [], []

n_envs

In [None]:
# Episodes counted so far for each sub-environment (all zero at the start).
episode_counts = np.zeros(n_envs, dtype=int)

# Divides episodes among different sub environments in the vector as evenly as possible:
# sub-environment i is assigned (n_eval_episodes + i) // n_envs episodes.
env_indices = np.arange(n_envs)
episode_count_targets = ((n_eval_episodes + env_indices) // n_envs).astype(int)

print(episode_counts, episode_count_targets, sep="\n")

In [None]:
# Start a fresh episode and keep the initial observations.
observations = env.reset()
states = None  # nothing to pass as policy state on the first predict() call

# Running totals for the episode in progress, one entry per sub-environment.
current_rewards = np.zeros(n_envs, dtype=float)
current_lengths = np.zeros(n_envs, dtype=int)

# Every sub-environment is flagged as being at the start of an episode.
episode_starts = np.ones(env.num_envs, dtype=bool)

print(observations.shape, episode_starts, sep="\n")

In [None]:
# Ask the policy for deterministic actions for the current observations.
actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=True)  # type: ignore[arg-type]

# Show the chosen actions, how many there are, and the returned policy state.
print(actions, len(actions), states, sep="\n")

In [None]:
# Apply the predicted actions and unpack the vectorised step result.
step_result = env.step(actions)
new_observations, rewards, dones, infos = step_result

print(rewards, dones, sep="\n")

In [None]:
# Book-keeping for the step taken in the previous cell. The rewards of that step are
# accumulated here as well, matching the later step cells which update both
# current_rewards and current_lengths; without this the episode return would miss
# the first step's reward.
# NOTE: this cell mutates in place, so run it exactly once per env.step() call.
current_rewards += rewards
current_lengths += 1
current_lengths

In [None]:
for i in range(n_envs):
    # Sub-environments that already reached their episode target are left untouched.
    if episode_counts[i] >= episode_count_targets[i]:
        continue
    # unpack values so that the callback can access the local variables
    # (no callback is invoked in this cell; the names are simply left in the namespace)
    reward, done, info = rewards[i], dones[i], infos[i]
    # The next predict() call is told whether this sub-environment just finished an episode.
    episode_starts[i] = done

print(dones, episode_starts, sep="\n")

# The observations returned by the last step become the input of the next predict() call.
observations = new_observations

In [None]:
import traci
# Simulated seconds covered by one env.step() call. This was hard-coded as 5 in two
# places; it is now read from env_params so the cell keeps working if "delta_time"
# is ever configured there.
# NOTE(review): the fallback of 5 assumes that is SumoEnvironment's default delta_time -- confirm against sumo_rl.
delta_time = env_params.get("delta_time", 5)

# Number of further env.step() calls to run in the loop below: the simulated time left
# until num_seconds, minus one step's worth, expressed in steps. Presumably the last
# step is held back so it can be executed and inspected in its own cell -- confirm.
num_steps = (env_params["num_seconds"] - delta_time - traci.simulation.getTime()) / delta_time
int(num_steps)

In [None]:
# Roll the policy forward for the number of steps computed in the previous cell,
# accumulating reward and length per sub-environment.
# NOTE(review): episode_starts is not refreshed inside this loop -- confirm this is acceptable for the policy in use.
steps_remaining = int(num_steps)
for _ in range(steps_remaining):
    actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=True)  # type: ignore[arg-type]
    new_observations, rewards, dones, infos = env.step(actions)
    observations = new_observations

    current_rewards += rewards
    current_lengths += 1

# State after the last executed step: done flags, accumulated rewards and step counts.
print(dones, current_rewards, current_lengths, sep="\n")

In [None]:
# Display the current simulation time reported by TraCI, to see how far the episode has advanced.
traci.simulation.getTime()

In [None]:
# Take one more predict/step cycle on its own so its outputs can be inspected.
actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=True)  # type: ignore[arg-type]
step_result = env.step(actions)
new_observations, rewards, dones, infos = step_result

# Same book-keeping as in the rollout loop.
current_rewards += rewards
current_lengths += 1

# Done flags and info dicts of this step, followed by the running totals.
print(dones, infos, current_rewards, current_lengths, sep="\n")

In [None]:
# Carry the observations returned by the step above forward as the current observations,
# then display the simulation time reported by TraCI after that step.
observations = new_observations
traci.simulation.getTime()