In [1]:
import torch
# Quick environment check: report whether PyTorch can see a CUDA GPU and, if so, its name.
if torch.cuda.is_available():
    print(True)
    print(torch.cuda.get_device_name(0))
else:
    print(False)
    print("No CUDA device")

True
NVIDIA GeForce RTX 4060 Laptop GPU


In [34]:
import torch
import torch.nn as nn
from torch.distributions import Categorical

# Prefer the GPU when PyTorch can see one; later cells move models and
# observation tensors onto `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

class RolloutBuffer:
    """Per-step storage for one on-policy rollout.

    Every attribute named in ``_FIELDS`` is a plain Python list that the agent
    and training loop append to once per environment step; the update step
    indexes them in parallel, so they are expected to stay the same length.
    """

    # Names of the parallel per-step lists kept by the buffer.
    _FIELDS = ("actions", "states", "logprobs", "rewards", "state_values", "is_terminals")

    def __init__(self):
        for field in self._FIELDS:
            setattr(self, field, [])

    def clear(self):
        """Empty every list in place, so existing references observe the reset."""
        for field in self._FIELDS:
            getattr(self, field).clear()

class ActorCritic(nn.Module):
    """Separate actor and critic MLPs for a discrete action space.

    ``actor`` maps a state vector of length ``state_dim`` to ``action_dim``
    action probabilities (its last layer is a Softmax); ``critic`` maps the
    same state to a single value estimate. The two heads share no parameters.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()

        self.actor = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Linear(128, action_dim),
            nn.Softmax(dim=-1)
        )

        self.critic = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

    def _action_distribution(self, state):
        """Categorical policy built from the actor's pre-softmax logits.

        ``self.actor[:-1]`` runs every layer except the trailing Softmax, so the
        distribution normalises with a log-softmax instead of taking
        ``log(softmax(x))``, which loses precision (and can hit -inf) when a
        probability underflows.
        """
        return Categorical(logits=self.actor[:-1](state))

    def act(self, state, deterministic=False):
        """Choose an action for ``state``.

        Args:
            state: a single state vector or a batch of them (last dim = state_dim).
            deterministic: take the most probable action instead of sampling.

        Returns:
            (action, log-probability of that action, critic value). The value
            keeps the critic's trailing dimension of size 1.
        """
        dist = self._action_distribution(state)
        if deterministic:
            action = torch.argmax(dist.logits, dim=-1)
        else:
            action = dist.sample()
        action_logprob = dist.log_prob(action)
        state_value = self.critic(state)
        return action, action_logprob, state_value

    def evaluate(self, states, actions):
        """Score stored ``actions`` under the current policy.

        ``actions`` should have the same batch shape as ``states`` without the
        feature dimension (e.g. (N,) for (N, state_dim) states).

        Returns:
            (log-probabilities, critic values with the last dim squeezed, entropies).
        """
        dist = self._action_distribution(states)
        action_logprobs = dist.log_prob(actions)
        dist_entropy = dist.entropy()
        state_values = self.critic(states).squeeze(-1)
        return action_logprobs, state_values, dist_entropy

class PPO:
    """Clipped-surrogate PPO agent with GAE advantages for discrete actions.

    ``policy_old`` acts in the environment (``select_action`` caches its
    outputs in ``buffer``); ``policy`` is optimised in ``update``, after which
    ``policy_old`` is re-synced from it and the buffer is cleared.

    Args:
        state_dim: length of the flattened observation vector.
        action_dim: number of discrete actions.
        lr_actor: Adam learning rate for the actor head.
        lr_critic: Adam learning rate for the critic head.
        gamma: discount factor.
        K_epochs: optimisation passes over the rollout per ``update``.
        eps_clip: clipping range for the probability ratio.
        gae_lambda: GAE lambda (default 0.95, the value previously hard-coded).
        value_coef: weight of the critic MSE term in the loss (default 0.5).
        entropy_coef: weight of the entropy bonus in the loss (default 0.01).
    """

    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip,
                 gae_lambda=0.95, value_coef=0.5, entropy_coef=0.01):
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.gae_lambda = gae_lambda
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef

        self.buffer = RolloutBuffer()

        self.policy = ActorCritic(state_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam([
            {'params': self.policy.actor.parameters(), 'lr': lr_actor},
            {'params': self.policy.critic.parameters(), 'lr': lr_critic}
        ])

        self.policy_old = ActorCritic(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.MseLoss = nn.MSELoss()

    def select_action(self, state):
        """Sample an action from ``policy_old`` for one observation and record it.

        The flattened state, the sampled action, its log-probability and the
        critic value are appended to ``buffer``. Rewards and terminal flags are
        not recorded here; the training loop appends them.

        Returns:
            The action as a Python int.
        """
        state = torch.FloatTensor(state).flatten().unsqueeze(0).to(device)
        with torch.no_grad():
            action, action_logprob, state_value = self.policy_old.act(state, deterministic=False)

        self.buffer.states.append(state.squeeze(0))
        self.buffer.actions.append(action)
        self.buffer.logprobs.append(action_logprob)
        self.buffer.state_values.append(state_value.squeeze())

        return action.item()

    @staticmethod
    def _compute_gae(rewards, values, is_terminals, gamma, lam, last_value=0.0):
        """Generalized Advantage Estimation over one flat rollout.

        Args:
            rewards: per-step rewards.
            values: critic estimates V(s_t) for the same steps.
            is_terminals: per-step flags; a truthy flag zeroes both the
                bootstrap value and the GAE recursion at that step.
            gamma: discount factor.
            lam: GAE lambda.
            last_value: value that bootstraps the final step when it is not
                terminal (0.0 bootstraps with zero).

        Returns:
            (advantages, returns) as lists of floats, where
            returns[t] = advantages[t] + values[t].
        """
        n = len(rewards)
        advantages = [0.0] * n
        gae = 0.0
        next_value = last_value
        for t in reversed(range(n)):
            not_done = 0.0 if is_terminals[t] else 1.0
            delta = rewards[t] + gamma * next_value * not_done - values[t]
            gae = delta + gamma * lam * not_done * gae
            advantages[t] = gae
            next_value = values[t]
        returns = [adv + val for adv, val in zip(advantages, values)]
        return advantages, returns

    def update(self, mini_batch_size=256, next_state=None):
        """Run ``K_epochs`` of clipped-PPO mini-batch updates on the buffered rollout.

        Args:
            mini_batch_size: samples per gradient step.
            next_state: optional observation that follows the last buffered
                transition. When given and that transition is not terminal,
                ``policy_old``'s value for it bootstraps the final step;
                otherwise the final step is bootstrapped with 0.

        Does nothing when the buffer is empty. After optimising, ``policy_old``
        is synced to ``policy`` and the buffer is cleared.
        """
        if not self.buffer.rewards:
            return  # nothing collected since the last update

        dev = next(self.policy.parameters()).device

        # ==== Compute GAE advantages and returns ====
        # Move the value estimates to the CPU once so the backward recursion runs
        # on plain floats instead of issuing tiny GPU ops (and syncs) per step.
        values = torch.stack(self.buffer.state_values).detach().reshape(-1).cpu().tolist()
        last_value = 0.0
        if next_state is not None and not self.buffer.is_terminals[-1]:
            with torch.no_grad():
                next_t = torch.as_tensor(next_state, dtype=torch.float32, device=dev).flatten().unsqueeze(0)
                last_value = self.policy_old.critic(next_t).item()
        adv_list, ret_list = self._compute_gae(
            self.buffer.rewards, values, self.buffer.is_terminals,
            self.gamma, self.gae_lambda, last_value,
        )
        advantages = torch.tensor(adv_list, dtype=torch.float32, device=dev)
        returns = torch.tensor(ret_list, dtype=torch.float32, device=dev)
        # std() of a single sample is NaN, which would poison the whole update.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # ==== Convert rollout buffer to tensors ====
        # select_action stores actions and log-probs with a leading batch dim of
        # 1, so stacking gives (N, 1). Flatten to (N,): feeding (N, 1) actions to
        # Categorical.log_prob broadcasts against the (N,) batch and returns an
        # (N, N) matrix that pairs every action with every state.
        old_states = torch.stack(self.buffer.states).to(dev)
        old_actions = torch.stack(self.buffer.actions).reshape(-1).to(dev)
        old_logprobs = torch.stack(self.buffer.logprobs).reshape(-1).detach().to(dev)

        # ==== PPO mini-batch updates ====
        dataset_size = old_states.size(0)
        for _ in range(self.K_epochs):
            indices = torch.randperm(dataset_size, device=dev)
            for start in range(0, dataset_size, mini_batch_size):
                mb_idx = indices[start:start + mini_batch_size]

                mb_states = old_states[mb_idx]
                mb_actions = old_actions[mb_idx]
                mb_logprobs = old_logprobs[mb_idx]
                mb_advantages = advantages[mb_idx]
                mb_returns = returns[mb_idx]

                # Evaluate stored actions under the current policy
                logprobs, state_values, dist_entropy = self.policy.evaluate(mb_states, mb_actions)
                ratios = torch.exp(logprobs - mb_logprobs)

                # Clipped surrogate loss
                surr1 = ratios * mb_advantages
                surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * mb_advantages
                actor_loss = -torch.min(surr1, surr2).mean()

                # Value function loss
                critic_loss = self.MseLoss(state_values, mb_returns)

                # Total loss (actor + critic - entropy bonus)
                loss = actor_loss + self.value_coef * critic_loss - self.entropy_coef * dist_entropy.mean()

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        # ==== Sync old policy and clear buffer ====
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.buffer.clear()

    def select_action_eval(self, state):
        """Greedy action from ``policy_old`` for evaluation; nothing is recorded."""
        state = torch.FloatTensor(state).flatten().unsqueeze(0).to(device)
        with torch.no_grad():
            action, _, _ = self.policy_old.act(state, deterministic=True)
        return action.item()

Using device: cuda


In [35]:
import gymnasium as gym
import highway_env
import numpy as np
import matplotlib.pyplot as plt


# Base highway-v0 settings shared by every training scenario; the driver copies
# this dict, overlays a scenario's keys, and the result reaches env.configure.
BASE_CONFIG = dict(
    observation={"type": "Kinematics"},
    action={"type": "DiscreteMetaAction"},
    lanes_count=4,
    controlled_vehicles=1,
    duration=40,
    ego_spacing=2,
    vehicles_density=1,
    collision_reward=-20,
    right_lane_reward=0.1,
    high_speed_reward=1,
    lane_change_reward=0,
    normalize_reward=True,
    offroad_terminal=False,
)

In [40]:

# One training run per entry: each overrides "reward_speed_range" as a
# [low, high] pair (presumably m/s per highway-env's convention — confirm).
SCENARIOS = {
    label: {"reward_speed_range": [low, high]}
    for label, low, high in (
        ("Slow", 10, 20),
        ("Normal", 20, 30),
        ("Fast", 30, 40),
    )
}

def _bootstrap_last_transition(ppo_agent, next_obs, gamma):
    """Fold gamma * V(next_obs) into the newest reward and mark that step terminal.

    Used when a trajectory is cut without a true terminal state (time-limit
    truncation, the per-episode step cap, or the rollout buffer filling
    mid-episode). PPO's GAE zeroes the bootstrap at terminal steps and stops
    the recursion there, so storing r' = r + gamma * V(s') with is_terminal=True
    yields exactly the bootstrapped TD error r + gamma * V(s') - V(s) while
    still preventing GAE from leaking into whatever transition comes next.
    """
    critic = ppo_agent.policy_old.critic
    dev = next(critic.parameters()).device
    next_state = torch.as_tensor(next_obs, dtype=torch.float32, device=dev).flatten().unsqueeze(0)
    with torch.no_grad():
        next_value = critic(next_state).item()
    ppo_agent.buffer.rewards[-1] += gamma * next_value
    ppo_agent.buffer.is_terminals[-1] = True


def _plot_training_curve(episode_rewards, running_rewards, window, label):
    """Plot per-episode rewards and their running mean for one scenario."""
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(episode_rewards, label="Episode Reward")
    ax.plot(running_rewards, label=f"Running Avg (last {window})", linewidth=2)
    ax.set_xlabel("Episode")
    ax.set_ylabel("Reward")
    ax.set_title(f"PPO Training Rewards - {label}")
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()


def train(cfg, label="", max_episodes=1200):
    """Train a fresh PPO agent on ``highway-v0`` configured with ``cfg``.

    The agent is updated every ``update_timestep`` environment steps. After
    training, per-episode rewards and their running mean are plotted.

    Args:
        cfg: configuration dict passed to ``env.configure``.
        label: scenario name used in log lines and the plot title.
        max_episodes: number of episodes to run (default 1200).

    Returns:
        The trained ``PPO`` agent.
    """
    env = gym.make("highway-v0", render_mode="rgb_array").unwrapped
    try:
        env.configure(cfg)
        print(f"\nTraining scenario: {label}")
        print(env.observation_space)

        # The observation is flattened before reaching the networks.
        state_dim = int(np.prod(env.observation_space.shape))
        action_dim = env.action_space.n

        # PPO hyperparameters
        max_ep_len = 1000
        update_timestep = 1500
        K_epochs = 10
        eps_clip = 0.2
        gamma = 0.99
        lr_actor = 0.0005
        lr_critic = 0.0005

        episode_reward_history = []
        running_rewards = []
        last_n_reward = 50

        ppo_agent = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip)

        time_step = 0
        for i_episode in range(1, max_episodes + 1):
            obs, _ = env.reset()
            current_ep_reward = 0

            for step in range(max_ep_len):
                action = ppo_agent.select_action(obs)
                next_obs, reward, done, truncated, _ = env.step(action)

                time_step += 1
                current_ep_reward += reward
                rollout_full = time_step % update_timestep == 0

                ppo_agent.buffer.rewards.append(reward)
                ppo_agent.buffer.is_terminals.append(done)
                # A cut that is not a real terminal state must bootstrap from
                # V(next_obs) rather than be scored as if the episode ended.
                hit_step_cap = step == max_ep_len - 1
                if not done and (truncated or rollout_full or hit_step_cap):
                    _bootstrap_last_transition(ppo_agent, next_obs, gamma)

                if rollout_full:
                    print("update with length "+ str(len(ppo_agent.buffer.rewards)))
                    ppo_agent.update()

                obs = next_obs

                if done or truncated:
                    break

            episode_reward_history.append(current_ep_reward)
            # Slicing already handles histories shorter than the window.
            running_reward = np.mean(episode_reward_history[-last_n_reward:])
            running_rewards.append(running_reward)
            if i_episode % 10 == 0:
                print(f"Episode {i_episode}, Reward: {current_ep_reward:.2f}, Running Avg ({last_n_reward}): {running_reward:.2f}")
    finally:
        # Release the environment even if training is interrupted.
        env.close()

    _plot_training_curve(episode_reward_history, running_rewards, last_n_reward, label)
    return ppo_agent

if __name__ == "__main__":
    import copy

    # Train one agent per scenario, keyed by scenario label.
    trained_agents = {}
    for label, scenario_cfg in SCENARIOS.items():
        # Deep copies: BASE_CONFIG.copy() would share the nested "observation"/
        # "action" dicts (and the scenario's lists) across every scenario and
        # every env that receives the config.
        cfg = copy.deepcopy(BASE_CONFIG)
        cfg.update(copy.deepcopy(scenario_cfg))
        trained_agents[label] = train(cfg, label)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


Training scenario: Slow
Box(-inf, inf, (5, 5), float32)
Episode 10, Reward: 3.05, Running Avg (50): 6.61
Episode 20, Reward: 6.03, Running Avg (50): 8.86
Episode 30, Reward: 3.04, Running Avg (50): 9.60
Episode 40, Reward: 2.02, Running Avg (50): 9.88
Episode 50, Reward: 38.00, Running Avg (50): 10.96
Episode 60, Reward: 4.01, Running Avg (50): 12.16
Episode 70, Reward: 39.96, Running Avg (50): 12.38
Episode 80, Reward: 11.98, Running Avg (50): 12.94
Episode 90, Reward: 2.02, Running Avg (50): 12.71
Episode 100, Reward: 32.95, Running Avg (50): 11.70
Episode 110, Reward: 18.93, Running Avg (50): 11.10
Episode 120, Reward: 5.02, Running Avg (50): 11.13
update with length 1500
Episode 130, Reward: 24.90, Running Avg (50): 10.69
Episode 140, Reward: 6.02, Running Avg (50): 12.64
Episode 150, Reward: 5.01, Running Avg (50): 14.06
Episode 160, Reward: 17.96, Running Avg (50): 14.72
Episode 170, Reward: 26.91, Running Avg (50): 14.96
Episode 180, Reward: 8.99, Running Avg (50): 14.66
Episod

KeyboardInterrupt: 