In [1]:
import torch
# Sanity check: report whether PyTorch can see a CUDA device, then name
# device 0 (or say that none is present).
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("No CUDA device")

True
NVIDIA GeForce RTX 4060 Laptop GPU


In [11]:
import torch
import torch.nn as nn
from torch.distributions import Categorical

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
# Every network and tensor built below is placed on this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

class RolloutBuffer:
    """Plain container of parallel lists holding one batch of rollout data.

    Each attribute is an independent Python list; entry ``i`` of every list is
    expected to describe the same environment step (the class itself does not
    enforce this).
    """

    # Attribute names, in the order they are created on the instance.
    _FIELDS = (
        "actions",
        "states",
        "logprobs",
        "rewards",
        "state_values",
        "is_terminals",
    )

    def __init__(self):
        # Give every field its own fresh, empty list.
        for field in self._FIELDS:
            setattr(self, field, [])

    def clear(self):
        """Empty every list in place (the list objects themselves are kept)."""
        for field in self._FIELDS:
            getattr(self, field).clear()

class ActorCritic(nn.Module):
    """Two independent MLPs: a softmax policy (actor) and a scalar value head (critic).

    Both networks take the same state vector and use two hidden Tanh layers of
    64 units. The actor outputs a probability for each of ``action_dim``
    discrete actions; the critic outputs a single value per state.
    """

    @staticmethod
    def _mlp_layers(in_features, out_features):
        """Return the Linear/Tanh stack shared by actor and critic (no output activation)."""
        return [
            nn.Linear(in_features, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, out_features),
        ]

    def __init__(self, state_dim, action_dim):
        super().__init__()

        # The actor is built before the critic so parameter initialisation
        # order (and therefore seeded weights) is stable.
        self.actor = nn.Sequential(
            *self._mlp_layers(state_dim, action_dim),
            nn.Softmax(dim=-1),
        )
        self.critic = nn.Sequential(*self._mlp_layers(state_dim, 1))

    def act(self, state, deterministic=False):
        """Choose an action for a batch of states.

        Args:
            state: tensor whose last dimension is ``state_dim``; the
                deterministic branch gathers along dim 1, so it needs a
                2-D ``(batch, state_dim)`` input.
            deterministic: if True pick the most probable action, otherwise
                sample from the categorical policy.

        Returns:
            ``(action, action_logprob, state_value)`` where ``state_value`` is
            the raw critic output (trailing dimension of size 1 kept).
        """
        probs = self.actor(state)
        if not deterministic:
            policy = Categorical(probs)
            action = policy.sample()
            action_logprob = policy.log_prob(action)
        else:
            action = probs.argmax(dim=-1)
            # Probability of the greedy action for every row, then its log.
            chosen_prob = probs.gather(1, action.unsqueeze(-1)).squeeze(-1)
            action_logprob = chosen_prob.log()
        return action, action_logprob, self.critic(state)

    def evaluate(self, states, actions):
        """Score previously taken ``actions`` under the current policy.

        Returns:
            ``(action_logprobs, state_values, dist_entropy)``; the critic
            output has its trailing size-1 dimension squeezed away.
        """
        policy = Categorical(self.actor(states))
        values = self.critic(states).squeeze(-1)
        return policy.log_prob(actions), values, policy.entropy()

class PPO:
    """Clipped-surrogate PPO agent for a discrete action space.

    Two copies of the actor-critic network are kept: ``policy`` is the one
    being optimised in :meth:`update`, while ``policy_old`` is the snapshot used
    to collect rollouts. :meth:`select_action` records state / action /
    log-prob / value in ``self.buffer``; the caller must append the matching
    reward and terminal flag to ``buffer.rewards`` / ``buffer.is_terminals``
    after every environment step.
    """

    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip):
        """
        Args:
            state_dim: length of the flattened observation vector.
            action_dim: number of discrete actions.
            lr_actor: Adam learning rate for the actor parameters.
            lr_critic: Adam learning rate for the critic parameters.
            gamma: discount factor used for the returns-to-go.
            K_epochs: number of passes over the rollout per update.
            eps_clip: PPO ratio clipping range (ratio is kept in
                ``[1 - eps_clip, 1 + eps_clip]``).
        """
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.buffer = RolloutBuffer()

        self.policy = ActorCritic(state_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam([
            {'params': self.policy.actor.parameters(), 'lr': lr_actor},
            {'params': self.policy.critic.parameters(), 'lr': lr_critic}
        ])

        # Rollout policy starts as an exact copy of the trainable policy.
        self.policy_old = ActorCritic(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.MseLoss = nn.MSELoss()

    def _state_to_tensor(self, state):
        """Flatten an observation into a ``(1, state_dim)`` float tensor on the rollout policy's device."""
        target_device = next(self.policy_old.parameters()).device
        # torch.FloatTensor copies the data, so the buffer never aliases the
        # environment's observation array.
        return torch.FloatTensor(state).flatten().unsqueeze(0).to(target_device)

    def select_action(self, state):
        """Sample an action from the rollout policy and record the step in the buffer.

        Args:
            state: array-like observation (any shape; it is flattened).

        Returns:
            The sampled action as a Python int.
        """
        state = self._state_to_tensor(state)
        with torch.no_grad():
            action, action_logprob, state_value = self.policy_old.act(state, deterministic=False)

        self.buffer.states.append(state.squeeze(0))
        self.buffer.actions.append(action)
        self.buffer.logprobs.append(action_logprob)
        self.buffer.state_values.append(state_value.squeeze())

        return action.item()

    def update(self, mini_batch_size=256):
        """Run ``K_epochs`` of mini-batch PPO updates on the buffered rollout, then clear it.

        Returns-to-go are computed by discounting backwards and resetting at
        every terminal flag; the advantage is ``return - recorded value``.
        After optimisation ``policy_old`` is synchronised with ``policy``.

        Args:
            mini_batch_size: number of transitions per gradient step.

        Raises:
            ValueError: if the buffer lists do not all have the same length.
        """
        buf = self.buffer
        n_steps = len(buf.rewards)
        if n_steps == 0:
            # Nothing was collected since the last update; stacking empty
            # lists would raise, so just do nothing.
            return

        lengths = {
            len(buf.states), len(buf.actions), len(buf.logprobs),
            len(buf.state_values), len(buf.is_terminals), n_steps,
        }
        if len(lengths) != 1:
            raise ValueError(
                "Rollout buffer is inconsistent: every select_action() call must be "
                "followed by exactly one reward and one is_terminal entry."
            )

        # Use the network's own device rather than relying on a global.
        dev = next(self.policy.parameters()).device

        # Compute discounted rewards-to-go (built back-to-front, reversed once).
        # NOTE(review): a rollout that ends mid-episode is not bootstrapped with
        # a value estimate, and truncation is treated like termination by the
        # caller — confirm this approximation is intended.
        returns = []
        running_return = 0.0
        for reward, is_terminal in zip(reversed(buf.rewards), reversed(buf.is_terminals)):
            if is_terminal:
                running_return = 0.0
            running_return = reward + self.gamma * running_return
            returns.append(running_return)
        returns.reverse()
        rewards = torch.tensor(returns, dtype=torch.float32, device=dev)

        # Convert rollout buffer to tensors. select_action stores actions and
        # log-probs with shape (1,), so they are flattened to 1-D here: passing
        # (B, 1) actions to Categorical.log_prob would broadcast against the
        # (B,) batch and silently produce a (B, B) matrix of mismatched
        # ratio/advantage pairs.
        old_states = torch.stack(buf.states).to(dev)
        old_actions = torch.stack(buf.actions).reshape(-1).to(dev)
        old_logprobs = torch.stack(buf.logprobs).reshape(-1).to(dev).detach()
        old_state_values = torch.stack(buf.state_values).reshape(-1).to(dev).detach()
        advantages = rewards - old_state_values

        for _ in range(self.K_epochs):
            # Fresh shuffle every epoch.
            indices = torch.randperm(n_steps, device=dev)
            for start in range(0, n_steps, mini_batch_size):
                mb_idx = indices[start:start + mini_batch_size]

                mb_states = old_states[mb_idx]
                mb_actions = old_actions[mb_idx]
                mb_logprobs = old_logprobs[mb_idx]
                mb_advantages = advantages[mb_idx]
                mb_returns = rewards[mb_idx]

                # Evaluate with current policy
                logprobs, state_values, dist_entropy = self.policy.evaluate(mb_states, mb_actions)
                ratios = torch.exp(logprobs - mb_logprobs)

                # Clipped surrogate objective, value regression and entropy
                # bonus, each reduced to a scalar before being combined.
                surr1 = ratios * mb_advantages
                surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * mb_advantages
                policy_loss = -torch.min(surr1, surr2).mean()
                value_loss = self.MseLoss(state_values, mb_returns)
                entropy_bonus = dist_entropy.mean()
                loss = policy_loss + 0.5 * value_loss - 0.01 * entropy_bonus

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        # The next rollout is collected with the freshly optimised weights.
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.buffer.clear()

    def select_action_eval(self, state):
        """Return the greedy (argmax) action for ``state`` without touching the buffer."""
        state = self._state_to_tensor(state)
        with torch.no_grad():
            action, _, _ = self.policy_old.act(state, deterministic=True)
        return action.item()

Using device: cuda


In [12]:
import gymnasium as gym
import highway_env
import numpy as np
import matplotlib.pyplot as plt


# Environment settings shared by every training scenario. The dict is handed
# to `env.configure(...)`; individual scenarios add their own keys on top.
BASE_CONFIG = dict(
    observation=dict(type="Kinematics"),
    action=dict(type="DiscreteMetaAction"),
    lanes_count=4,
    controlled_vehicles=1,
    duration=40,
    ego_spacing=2,
    vehicles_density=1,
    # Reward-shaping terms.
    collision_reward=-1,
    right_lane_reward=0.1,
    high_speed_reward=0.4,
    lane_change_reward=0,
    normalize_reward=True,
    offroad_terminal=False,
)

In [None]:

# One training scenario per speed band: each entry contributes only a
# `reward_speed_range` override that is merged into the base config.
SCENARIOS = {
    name: {"reward_speed_range": [low, high]}
    for name, low, high in (
        ("Slow", 10, 20),
        ("Normal", 20, 30),
        ("Fast", 30, 40),
    )
}

def train(cfg, label="", max_episodes=1500, max_ep_len=1000, update_timestep=1000):
    """Train a fresh PPO agent on ``highway-v0`` and plot its reward curve.

    Args:
        cfg: configuration dict passed to ``env.configure``.
        label: scenario name used in log lines and in the plot title.
        max_episodes: number of episodes to run.
        max_ep_len: upper bound on steps within a single episode.
        update_timestep: a PPO update is triggered every this many environment
            steps, counted across episodes (so an update can happen mid-episode).

    Returns:
        The trained ``PPO`` agent.
    """
    # PPO hyperparameters
    K_epochs = 20
    eps_clip = 0.1
    gamma = 0.99
    lr_actor = 0.003
    lr_critic = 0.003

    # Window size for the running-average reward curve.
    last_n_reward = 50

    episode_reward_history = []
    running_rewards = []

    env = gym.make("highway-v0", render_mode="rgb_array").unwrapped
    try:
        env.configure(cfg)
        print(f"\nTraining scenario: {label}")
        print(env.observation_space)

        # Cast to plain ints: np.prod / space.n may return NumPy integer types.
        state_dim = int(np.prod(env.observation_space.shape))
        action_dim = int(env.action_space.n)

        ppo_agent = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip)

        time_step = 0
        for i_episode in range(1, max_episodes + 1):
            obs, _ = env.reset()
            current_ep_reward = 0

            for _ in range(max_ep_len):
                action = ppo_agent.select_action(obs)
                next_obs, reward, done, truncated, info = env.step(action)

                # NOTE(review): truncation is stored as a terminal flag, so the
                # return is not bootstrapped past a time-limit cut — confirm
                # this is intended.
                ppo_agent.buffer.rewards.append(reward)
                ppo_agent.buffer.is_terminals.append(done or truncated)

                time_step += 1
                current_ep_reward += reward

                if time_step % update_timestep == 0:
                    print(f"update with length {len(ppo_agent.buffer.rewards)}")
                    ppo_agent.update()

                obs = next_obs

                if done or truncated:
                    break

            episode_reward_history.append(current_ep_reward)

            # Slicing already copes with histories shorter than the window.
            running_reward = np.mean(episode_reward_history[-last_n_reward:])
            running_rewards.append(running_reward)
            if i_episode % 10 == 0:
                print(f"Episode {i_episode}, Reward: {current_ep_reward:.2f}, Running Avg ({last_n_reward}): {running_reward:.2f}")
    finally:
        # Release the environment even if training is interrupted or fails.
        env.close()

    plt.figure(figsize=(10, 5))
    plt.plot(episode_reward_history, label="Episode Reward")
    plt.plot(running_rewards, label=f"Running Avg (last {last_n_reward})", linewidth=2)
    plt.xlabel("Episode")
    plt.ylabel("Reward")
    plt.title(f"PPO Training Rewards - {label}")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    return ppo_agent

if __name__ == "__main__":
    # Train one independent agent per scenario and keep them keyed by name.
    trained_agents = {}
    for label, scenario_cfg in SCENARIOS.items():
        # Scenario keys are layered over a shallow copy of the shared base
        # settings; BASE_CONFIG itself is left untouched.
        cfg = {**BASE_CONFIG, **scenario_cfg}
        agent = train(cfg, label)
        trained_agents[label] = agent
        # Release cached GPU memory before the next scenario starts.
        torch.cuda.empty_cache()


Training scenario: Slow
Box(-inf, inf, (5, 5), float32)
Episode 10, Reward: 3.19, Running Avg (50): 10.94
Episode 20, Reward: 24.93, Running Avg (50): 12.61
Episode 30, Reward: 3.15, Running Avg (50): 13.17
Episode 40, Reward: 7.09, Running Avg (50): 12.22
Episode 50, Reward: 17.68, Running Avg (50): 11.90
Episode 60, Reward: 7.01, Running Avg (50): 12.38
Episode 70, Reward: 1.22, Running Avg (50): 10.89
Episode 80, Reward: 5.79, Running Avg (50): 9.68
update with length 1000
Episode 90, Reward: 9.28, Running Avg (50): 9.87
Episode 100, Reward: 18.71, Running Avg (50): 10.88
Episode 110, Reward: 11.56, Running Avg (50): 11.40
Episode 120, Reward: 3.09, Running Avg (50): 12.48
Episode 130, Reward: 4.13, Running Avg (50): 12.66
Episode 140, Reward: 39.57, Running Avg (50): 12.94
Episode 150, Reward: 7.76, Running Avg (50): 11.88
update with length 1000
Episode 160, Reward: 12.58, Running Avg (50): 12.08
Episode 170, Reward: 5.99, Running Avg (50): 12.67
Episode 180, Reward: 5.22, Runnin