In [4]:
%load_ext autoreload
%autoreload 2


In [2]:
# 🧠 PPO Training for TrafficEnv
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
import torch.nn.functional as F
from simulator import TrafficEnv  # Make sure this points to your .py with TrafficEnv
from model import TrafficNet  # Make sure this points to your GNN-like TrafficNet
def print_state(obs, n):
    traffic_tensor, dijkstra_tensor = obs
    for t in range(n):
        print('-------------------------')
        print('t =', t)
        print(traffic_tensor[:, :, t])

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Instantiate environment and model
env = TrafficEnv("traffic_maps", time_per_step=10, max_steps=10)  # Replace with actual path
obs_space = env.observation_space
action_space = env.action_space

# Create model
model = TrafficNet(
    n_vertices=env.n_vertices,
    n_timesteps=env.n_timesteps,
).to(device)



optimizer = optim.Adam(model.parameters(), lr=1e-3)
clip_epsilon = 0.2
gamma = 0.99
save_every = 50  # Save model every N episodes

# Training loop parameters
num_episodes = 10000
max_steps = env.max_steps

for episode in range(num_episodes):
    obs, _ = env.reset()
    traffic_tensor = torch.tensor(obs[0], dtype=torch.float32, device=device)
    next_paths = torch.tensor(obs[1], dtype=torch.float32, device=device)
    visited = torch.tensor(obs[2], dtype=torch.float32, device=device)
    current_vertex = torch.tensor([obs[3]], dtype=torch.long, device=device)

    current_vertex_onehot = F.one_hot(torch.tensor(current_vertex), num_classes=env.n_vertices).float().squeeze(0)

    # print_state((traffic_tensor, next_paths), traffic_tensor.size(0))

    log_probs = []
    values = []
    rewards = []
    states = []
    actions = []

    total_reward = 0

    for step in range(max_steps):


        action_logits, state_value = model(traffic_tensor, next_paths, visited.float(), current_vertex_onehot)
        dist = Categorical(logits=action_logits)
        action = dist.sample()

        new_obs, reward, done, truncated, _ = env.step(action.item())
        traffic_tensor = torch.tensor(new_obs[0], dtype=torch.float32, device=device)
        next_paths = torch.tensor(new_obs[1], dtype=torch.float32, device=device)
        visited = torch.tensor(new_obs[2], dtype=torch.float32, device=device)
        current_vertex = torch.tensor([new_obs[3]], dtype=torch.long, device=device)
        current_vertex_onehot = F.one_hot(torch.tensor(current_vertex), num_classes=env.n_vertices).float().squeeze(0)
        visited = visited.float()

        total_reward += reward

        states.append((traffic_tensor.clone(), next_paths.clone(), visited.clone(), current_vertex_onehot.clone()))
        actions.append(action)
        log_probs.append(dist.log_prob(action))
        values.append(state_value.squeeze())
        rewards.append(torch.tensor([reward], dtype=torch.float32, device=device))

        if done:
            break

        traffic_tensor = torch.tensor(new_obs[0], dtype=torch.float32, device=device)
        next_paths = torch.tensor(new_obs[1], dtype=torch.float32, device=device)
    print(env.visited_vertices.sum())

    # Compute returns and advantages
    returns = []
    G = 0
    print(actions)
    print(rewards)
    for r in reversed(rewards):
        G = r + gamma * G
        returns.insert(0, G)
    returns = torch.cat(returns).detach()
    values = torch.stack(values).detach()
    log_probs = torch.stack(log_probs)
    actions = torch.stack(actions)

    advantages = (returns - values).detach()

    # PPO loss
    for _ in range(4):  # Multiple epochs
        new_log_probs = []
        new_values = []

        for i, (state_tensor, path_tensor, visited_tensor, one_hot_tensor) in enumerate(states):
            logits, value = model(state_tensor, path_tensor, visited_tensor, one_hot_tensor)
            dist = Categorical(logits=logits)
            new_log_probs.append(dist.log_prob(actions[i]))
            new_values.append(value.squeeze())

        new_log_probs = torch.stack(new_log_probs)
        new_values = torch.stack(new_values)

        ratio = (new_log_probs - log_probs.detach()).exp()
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()
        value_loss = (returns - new_values).pow(2).mean()
        entropy = dist.entropy().mean()
        loss = policy_loss + 0.5 * value_loss - 0.1 * entropy

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Episode {episode} | Total Reward: {total_reward:.2f} | Loss: {loss.item():.4f}")

    # Save model every N episodes
    if episode % save_every == 0 and episode > 0:
        torch.save(model.state_dict(), f"ppo_traffic_model_ep{episode}.pt")
        print(f"✅ Saved model at episode {episode}")


Using device: cpu
tensor(7)
[tensor(8), tensor(3), tensor(3), tensor(3), tensor(3), tensor(2), tensor(5), tensor(3), tensor(6), tensor(0)]
[tensor([17.5183]), tensor([25.9750]), tensor([27.6357]), tensor([-70.]), tensor([-70.]), tensor([36.6212]), tensor([46.1148]), tensor([-50.]), tensor([56.1928]), tensor([66.2889])]


  current_vertex_onehot = F.one_hot(torch.tensor(current_vertex), num_classes=env.n_vertices).float().squeeze(0)
  current_vertex_onehot = F.one_hot(torch.tensor(current_vertex), num_classes=env.n_vertices).float().squeeze(0)


Episode 0 | Total Reward: 86.35 | Loss: 1635.3099
tensor(3)
[tensor(5), tensor(5), tensor(5), tensor(5), tensor(5), tensor(8), tensor(5), tensor(5), tensor(5), tensor(5)]
[tensor([-80.]), tensor([-80.]), tensor([-80.]), tensor([-80.]), tensor([-80.]), tensor([30.]), tensor([-70.]), tensor([-70.]), tensor([-70.]), tensor([-70.])]
Episode 1 | Total Reward: -650.00 | Loss: 74802.7109
tensor(4)
[tensor(2), tensor(2), tensor(2), tensor(2), tensor(6), tensor(2), tensor(2), tensor(7), tensor(6), tensor(8)]
[tensor([20.]), tensor([20.]), tensor([-80.]), tensor([-80.]), tensor([-80.]), tensor([-80.]), tensor([-80.]), tensor([30.]), tensor([-70.]), tensor([40.])]
Episode 2 | Total Reward: -360.00 | Loss: 27376.1953
tensor(4)
[tensor(6), tensor(6), tensor(6), tensor(6), tensor(2), tensor(4), tensor(6), tensor(4), tensor(2), tensor(4)]
[tensor([-80.]), tensor([-80.]), tensor([-80.]), tensor([-80.]), tensor([30.]), tensor([40.]), tensor([-60.]), tensor([-60.]), tensor([-60.]), tensor([-60.])]
Episo

KeyboardInterrupt: 

In [6]:
from trainer_dql import train
early_stop, reward_list = train(num_steps=100, mini_batch_size=16, ppo_epochs=4, threshold=0)

  actions = []


TypeError: linear(): argument 'input' (position 1) must be Tensor, not list