In [8]:
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# 🧠 PPO Training for TrafficEnv
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
import torch.nn.functional as F
from simulator import TrafficEnv  # Make sure this points to your .py with TrafficEnv
from model import TrafficNet  # Make sure this points to your GNN-like TrafficNet
def print_state(obs, n):
    """Dump the traffic tensor of an observation, one slice of its last axis at a time.

    Args:
        obs: a 2-element pair ``(traffic_tensor, dijkstra_tensor)``; only the
            first element is printed.
        n: how many slices ``traffic_tensor[:, :, t]`` to print, for t = 0 .. n-1.
    """
    traffic_tensor, _unused_dijkstra = obs
    for slice_idx in range(n):
        # Separator line, then the slice index, then the 2-D slice itself.
        print('-------------------------')
        print('t =', slice_idx)
        print(traffic_tensor[:, :, slice_idx])

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Seed torch/numpy so runs are repeatable (the env may keep its own RNG, which is not seeded here).
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# Instantiate environment and model
env = TrafficEnv("traffic_maps", time_per_step=10, max_steps=10)  # Replace with actual path
obs_space = env.observation_space
action_space = env.action_space

# Create model
model = TrafficNet(
    n_vertices=env.n_vertices,
    n_timesteps=env.n_timesteps,
).to(device)

# PPO hyper-parameters
optimizer = optim.Adam(model.parameters(), lr=1e-3)
clip_epsilon = 0.2   # PPO probability-ratio clipping range
gamma = 0.99         # discount factor
ppo_epochs = 4       # optimisation passes over each collected episode
value_coef = 0.5     # weight of the value-function loss
entropy_coef = 0.1   # weight of the entropy bonus
max_grad_norm = 0.5  # gradient-norm clipping threshold
save_every = 50  # Save model every N episodes

# Training loop parameters
num_episodes = 10000
max_steps = env.max_steps


def _obs_to_model_inputs(obs):
    """Convert a raw env observation into the argument tuple for ``model(...)``.

    Reads obs[0] (traffic), obs[1] (next paths), obs[2] (visited) and obs[3]
    (current vertex index) and returns
    ``(traffic, next_paths, visited, current_vertex_onehot)`` as float tensors on
    ``device``; the vertex index is one-hot encoded over ``env.n_vertices``.
    """
    traffic = torch.tensor(obs[0], dtype=torch.float32, device=device)
    next_paths = torch.tensor(obs[1], dtype=torch.float32, device=device)
    visited = torch.tensor(obs[2], dtype=torch.float32, device=device)
    current_vertex = torch.tensor([obs[3]], dtype=torch.long, device=device)
    # current_vertex is already a tensor: pass it straight to one_hot instead of
    # re-wrapping it in torch.tensor(...), which copies it and emits a UserWarning.
    current_vertex_onehot = F.one_hot(current_vertex, num_classes=env.n_vertices).float().squeeze(0)
    return traffic, next_paths, visited, current_vertex_onehot


def _discounted_returns(rewards, gamma):
    """Discounted reward-to-go for every step of one episode.

    G_t = r_t + gamma * G_{t+1}, with the value after the final step taken as 0
    (Monte-Carlo returns, no bootstrap). Returns a 1-D float32 tensor on
    ``device`` with ``len(rewards)`` entries.
    """
    returns = []
    running = 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    return torch.tensor(returns, dtype=torch.float32, device=device)


for episode in range(num_episodes):
    obs, _ = env.reset()
    state = _obs_to_model_inputs(obs)

    # Rollout buffers: entry t describes the state the agent acted IN at step t.
    states = []
    actions = []
    old_log_probs = []
    values = []
    rewards = []
    total_reward = 0.0

    # --- Rollout: act with the current policy. No autograd graph is needed here
    # because the PPO update below re-evaluates every stored state.
    with torch.no_grad():
        for _ in range(max_steps):
            action_logits, state_value = model(*state)
            dist = Categorical(logits=action_logits)
            action = dist.sample()

            # Record the state the action was sampled in. (The previous version
            # stored the successor state, so the update scored each action
            # under the wrong observation.)
            states.append(state)
            actions.append(action)
            old_log_probs.append(dist.log_prob(action))
            values.append(state_value.squeeze())

            new_obs, reward, done, truncated, _ = env.step(action.item())
            rewards.append(float(reward))
            total_reward += float(reward)

            if done or truncated:
                break
            state = _obs_to_model_inputs(new_obs)

    # --- Advantages: Monte-Carlo return minus the value baseline ---
    returns = _discounted_returns(rewards, gamma)
    # reshape(-1) keeps everything 1-D ([T]) so nothing broadcasts to [T, T]
    # if the model emits [1, n] logits / [1, 1] values.
    values = torch.stack(values).reshape(-1)
    old_log_probs = torch.stack(old_log_probs).reshape(-1)
    advantages = returns - values

    # --- PPO update: several clipped-objective passes over this episode ---
    for _ in range(ppo_epochs):
        new_log_probs = []
        new_values = []
        entropies = []
        for state_inputs, action in zip(states, actions):
            logits, value = model(*state_inputs)
            dist = Categorical(logits=logits)
            new_log_probs.append(dist.log_prob(action))
            new_values.append(value.squeeze())
            # Entropy bonus over every visited state, not only the last one.
            entropies.append(dist.entropy())

        new_log_probs = torch.stack(new_log_probs).reshape(-1)
        new_values = torch.stack(new_values).reshape(-1)
        entropy = torch.stack(entropies).mean()

        ratio = (new_log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()
        value_loss = (returns - new_values).pow(2).mean()
        loss = policy_loss + value_coef * value_loss - entropy_coef * entropy

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

    print(f"Episode {episode} | Total Reward: {total_reward:.2f} | Loss: {loss.item():.4f}")

    # Save model every N episodes
    if episode % save_every == 0 and episode > 0:
        torch.save(model.state_dict(), f"ppo_traffic_model_ep{episode}.pt")
        print(f"✅ Saved model at episode {episode}")


In [11]:
from trainer_ppo_class import train
import os
# Start a fresh log: truncate output.csv and write its header row, then train.
# NOTE(review): the header has a space after the comma, so e.g. pandas.read_csv
# will name the second column " reward" unless skipinitialspace=True is passed.
# Presumably train() appends rows to this file; confirm in trainer_ppo_class.
with open('output.csv', mode='w') as csv_file:
    csv_file.write('step, reward\n')

early_stop, reward_list = train(
    num_steps=10000,
    mini_batch_size=16,
    ppo_epochs=4,
    threshold=0,
)

  current_vertex_onehot = F.one_hot(torch.tensor(current_vertex), num_classes=env.n_vertices).float().to(device).squeeze(0)
  state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)


[4, 6, 8, 2, 3, 5, 6, 9]
[6, 0, 5, 8, 1, 2, 5, 6]
[6, 3, 9, 2, 1, 3, 2, 2]
[8, 1, 3, 4, 6, 2, 4, 3]
[4, 4, 3, 5, 4, 6, 7, 4]
[2, 9, 3, 0, 3, 4, 2, 4]
[2, 7, 1, 6, 5, 1, 1, 1]
[5, 5, 8, 1, 3, 8, 5, 7]
[3, 1, 4, 3, 2, 7, 3, 3]
[7, 0, 1, 5, 0, 3, 8, 2]
Step: 0	Reward: 32.0
[7, 6, 6, 5, 8, 9, 6, 5]
[8, 5, 9, 9, 9, 9, 9, 9]
[4, 8, 4, 3, 9, 1, 2, 9]
[9, 5, 0, 5, 7, 4, 8, 3]
[3, 3, 2, 7, 9, 6, 5, 2]
[6, 2, 3, 3, 4, 9, 2, 2]
[3, 7, 4, 0, 8, 3, 5, 6]
[3, 5, 9, 7, 0, 2, 4, 0]
[0, 2, 6, 9, 3, 2, 0, 2]
[1, 8, 1, 2, 5, 7, 3, 6]
Step: 5	Reward: 30.0
[7, 4, 8, 4, 5, 9, 2, 0]
[2, 2, 9, 1, 8, 2, 3, 9]
[8, 4, 1, 0, 9, 5, 2, 8]
[2, 9, 9, 1, 3, 8, 8, 1]
[1, 0, 3, 2, 0, 3, 1, 0]
[2, 9, 9, 5, 3, 6, 6, 0]
[4, 0, 2, 1, 6, 0, 7, 2]
[1, 4, 3, 4, 0, 6, 1, 8]
[2, 5, 2, 0, 3, 8, 5, 1]
[4, 5, 2, 9, 0, 2, 2, 7]
Step: 10	Reward: 30.0
[9, 3, 4, 5, 7, 6, 5, 3]
[9, 5, 4, 9, 7, 3, 6, 4]
[3, 8, 8, 6, 3, 6, 1, 9]
[1, 4, 9, 4, 7, 5, 4, 5]
[5, 5, 3, 2, 0, 2, 9, 3]
[1, 4, 3, 5, 0, 4, 5, 5]
[0, 9, 2, 1, 4, 3, 3, 1]
[5, 4, 3, 3

In [None]:
from trainer_ppo_class import evaluate_policy
from simulator import TrafficEnv

def nearest_neighbor(next_paths):
    """Greedy baseline policy: index of the largest entry of ``next_paths``.

    Returns the 0-d tensor from ``torch.argmax``. When ``next_paths`` has more
    than one dimension, this is an index into the flattened tensor.

    NOTE(review): despite the name this takes the maximum. If ``next_paths``
    holds distances/costs, a nearest-neighbour choice would be argmin; confirm
    what the env's obs[1] encodes.
    """
    return torch.argmax(next_paths)

def test_nn(n):
    # env = TrafficEnv(trafficmap_dir='traffic_maps', time_per_step=20, max_steps=10)
    for i in range(n):
        total_reward += evaluate_policy(nearest_neighbor)
    return total_reward/n
    

