In [None]:
import pickle as pkl
import numpy as np
from env import TransitNetworkEnv

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from agent import GNNPolicy, FeatureExtractor

from matplotlib import pyplot as plt
from stable_baselines3.common.utils import obs_as_tensor

In [None]:
def make_env(seed):
    """Return a zero-argument factory that builds a ``TransitNetworkEnv``.

    ``SubprocVecEnv`` expects a list of callables rather than ready-made
    environments, so construction is deferred to the returned ``_init``.

    Args:
        seed: value forwarded as ``TransitNetworkEnv(seed=seed)``.

    Returns:
        A callable taking no arguments that creates the seeded environment.
    """
    def _init():
        return TransitNetworkEnv(seed=seed)
    return _init


# Number of environments stepped in parallel during training.
n_envs = 6

# Each worker gets a distinct seed (0, 1000, 2000, ...).
# The previous throw-away `env = TransitNetworkEnv()` was removed: it was
# rebound here before ever being used and was never closed.
env = SubprocVecEnv([make_env(seed=i*1000) for i in range(n_envs)])

In [None]:
# PPO agent with the project's GNN policy and feature extractor.
model = PPO(
    policy=GNNPolicy,
    env=env,
    # Optimisation hyper-parameters.
    learning_rate=1e-5,
    gamma=0.999,
    # Rollout length per environment and minibatch size for each update.
    n_steps=32,
    batch_size=32,
    # Custom feature extractor; the vectorised env is handed to it as well.
    policy_kwargs={
        "features_extractor_class": FeatureExtractor,
        "features_extractor_kwargs": {
            "hidden_dim": 512,
            "num_heads": 4,
            "out_dim": 256,
            "env": env,
        },
    },
    # Logging.
    tensorboard_log="./logs/",
    verbose=0,
)

In [None]:
from stable_baselines3.common.callbacks import BaseCallback

class InfoTensorboardCallback(BaseCallback):
    """Forward numeric entries of the environments' ``info`` dicts to the logger.

    Every numeric scalar found under key ``k`` is recorded as ``env/k``.
    Values are accumulated with ``record_mean`` so that, between two logger
    dumps, the reported number is the average over all steps and all parallel
    environments instead of whichever environment happened to be written last.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        """Record the current step's info values; always lets training continue."""
        infos = self.locals.get("infos", [])

        for info in infos:
            for key, value in info.items():
                # Only plain numbers can be averaged; skip arrays, dicts,
                # strings and other non-scalar entries.
                if not isinstance(value, (int, float, np.integer, np.floating)):
                    continue
                # No third positional argument: in the SB3 logger that slot is
                # `exclude` (output formats to skip), not a step counter. The
                # logger attaches the current timestep itself when it dumps.
                self.logger.record_mean(f"env/{key}", float(value))

        return True

In [None]:
# Train the agent, streaming env info to TensorBoard through the callback,
# then persist the result to disk.
callback = InfoTensorboardCallback()
model.learn(
    total_timesteps=1080 * 100,
    callback=callback,
    log_interval=1,
    tb_log_name="first_run",
    reset_num_timesteps=False,
    progress_bar=False,
)
model.save("model.pkl")

In [None]:
# --- Roll out the trained policy on a single environment ---

# `PPO.load` is a classmethod that returns a NEW model, so the result has to be
# bound; calling it on the instance and dropping the return value loads nothing.
# The model is placed on the CPU to match the CPU observation tensors built below.
model = PPO.load("model.pkl", env=env, device="cpu")

# Use a dedicated, non-vectorised environment for the rollout. The training
# `env` is a SubprocVecEnv, whose reset()/step() return batched values in a
# different tuple layout than the Gymnasium-style calls used in this loop.
eval_env = TransitNetworkEnv()
obs, _ = eval_env.reset()
rewards = []

i = 0
while i <= 2000:
    # NOTE(review): the stock SB3 `policy.predict` returns `(actions, states)`;
    # this call assumes GNNPolicy.predict accepts a tensor observation and
    # returns the bare action - confirm against agent.py.
    action = model.policy.predict(obs_as_tensor(obs, device="cpu"))
    obs, reward, terminated, truncated, info = eval_env.step(action.squeeze())
    eval_env.render()
    rewards.append(reward)
    if terminated or truncated:
        break
    i += 1
else:
    # Reached only when the step cap is hit without the episode ending.
    print("force breaking")

# Keep a copy: `rewards` is reused by the random-baseline rollout.
policy_rewards = list(rewards)
plt.plot(policy_rewards)

In [None]:
# --- Random baseline: sample actions uniformly from the action space ---
env = TransitNetworkEnv()
obs, _ = env.reset()
rewards = []

i = 0
while i <= 2000:
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action.squeeze())
    rewards.append(reward)
    if terminated or truncated:
        break
    i += 1
else:
    # Reached only when the step cap is hit without the episode ending.
    print("force breaking")

In [None]:
# --- Compare per-step rewards: trained policy vs. random baseline ---

# The two rollouts may end after a different number of steps. Cut BOTH series
# to the shorter one so the element-wise comparison never sees mismatched
# lengths (previously only the random series was truncated).
n_common = min(len(policy_rewards), len(rewards))
random_curve = np.array(rewards[:n_common])
policy_curve = np.array(policy_rewards[:n_common])

plt.plot(random_curve, label="random")
plt.plot(policy_curve, label="GNN")
plt.legend()

# Boolean mask of the steps on which the policy's reward beats the baseline's.
ch = ((policy_curve - random_curve) > 0)
if ch.shape[0] > 0:
    print(f"{ch.sum()/ch.shape[0] * 100:.3f}", r"% of the time GNN performs better than the random")
else:
    # Avoid a division by zero when one of the rollouts produced no rewards.
    print("no overlapping steps to compare")