In [1]:
import torch
# Quick sanity check that PyTorch can see a CUDA GPU before training.
cuda_available = torch.cuda.is_available()
print(cuda_available)
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0) if cuda_available else "No CUDA device")

True
1
NVIDIA GeForce RTX 4060 Laptop GPU


In [2]:
import torch
import torch.nn as nn
from torch.distributions import Categorical

# Use the default CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

class RolloutBuffer:
    """Parallel per-step lists collected during rollouts and consumed by an update.

    Attributes (all plain lists): ``actions``, ``states``, ``logprobs``,
    ``rewards``, ``state_values``, ``is_terminals``.
    """

    # Names of the per-step lists, in the order they are created.
    _FIELDS = ("actions", "states", "logprobs", "rewards", "state_values", "is_terminals")

    def __init__(self):
        for field_name in self._FIELDS:
            setattr(self, field_name, [])

    def clear(self):
        """Empty every list in place (existing references to the lists see the reset)."""
        for field_name in self._FIELDS:
            getattr(self, field_name).clear()

class ActorCritic(nn.Module):
    """Two independent tanh MLPs: a softmax policy head and a scalar value head.

    ``actor`` maps states to a probability vector over ``action_dim`` discrete
    actions; ``critic`` maps states to a single value estimate. No parameters
    are shared between them.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        # Layer order Linear, Tanh, Linear, Tanh, Linear (+ Softmax for the actor)
        # fixes the state_dict keys to actor.{0,2,4}.* and critic.{0,2,4}.*.
        self.actor = nn.Sequential(*self._tanh_mlp(state_dim, action_dim), nn.Softmax(dim=-1))
        self.critic = nn.Sequential(*self._tanh_mlp(state_dim, 1))

    @staticmethod
    def _tanh_mlp(in_features, out_features, hidden=64):
        """Return the layer list Linear -> Tanh -> Linear -> Tanh -> Linear."""
        return [
            nn.Linear(in_features, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, out_features),
        ]

    def act(self, state):
        """Sample an action for a single state carrying a leading batch axis of size 1.

        Returns:
            ``(action, log_prob, value)`` with the leading batch axis removed:
            a 0-d action tensor, its 0-d log-probability, and the critic value
            with shape ``[1]``.
        """
        dist = Categorical(self.actor(state))   # batch shape [1]
        sampled = dist.sample()                 # [1]
        log_prob = dist.log_prob(sampled)       # [1]
        value = self.critic(state)              # [1, 1]
        return sampled.squeeze(0), log_prob.squeeze(0), value.squeeze(0)

    def evaluate(self, states, actions):
        """Score stored (state, action) pairs under the current networks.

        Returns:
            ``(log_probs, values, entropy)``; ``values`` has the critic's
            trailing size-1 axis squeezed away.
        """
        dist = Categorical(self.actor(states))
        values = self.critic(states).squeeze(-1)
        return dist.log_prob(actions), values, dist.entropy()

class PPO:
    """Proximal Policy Optimization agent with a clipped surrogate objective.

    Two copies of ``ActorCritic`` are kept: ``policy_old`` samples actions in
    :meth:`select_action`, and ``policy`` is optimized in :meth:`update`, after
    which its weights are copied back into ``policy_old``.

    Args:
        state_dim: Size of the flattened observation vector.
        action_dim: Number of discrete actions.
        lr_actor: Adam learning rate for the actor parameters.
        lr_critic: Adam learning rate for the critic parameters.
        gamma: Discount factor for the Monte-Carlo returns.
        K_epochs: Optimization passes over each collected batch in :meth:`update`.
        eps_clip: Clipping range for the probability ratio.
    """

    # Weight of the critic's MSE term in the combined loss.
    VALUE_COEF = 0.5
    # Weight of the entropy bonus (subtracted from the loss to encourage exploration).
    ENTROPY_COEF = 0.01

    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip):
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.buffer = RolloutBuffer()

        self.policy = ActorCritic(state_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam([
            {'params': self.policy.actor.parameters(), 'lr': lr_actor},
            {'params': self.policy.critic.parameters(), 'lr': lr_critic}
        ])

        self.policy_old = ActorCritic(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.MseLoss = nn.MSELoss()

    def _device(self):
        """Device of the trainable policy (policy_old is moved to the same one in __init__)."""
        return next(self.policy.parameters()).device

    def select_action(self, state):
        """Sample an action from ``policy_old`` and record the step in the buffer.

        Args:
            state: Observation; it is flattened into a single feature vector.

        Returns:
            The sampled action as a Python int.
        """
        state = torch.FloatTensor(state).flatten().unsqueeze(0).to(self._device())

        with torch.no_grad():
            action, action_logprob, state_value = self.policy_old.act(state)

        self.buffer.states.append(state.squeeze(0))
        self.buffer.actions.append(action)
        self.buffer.logprobs.append(action_logprob)
        self.buffer.state_values.append(state_value)

        return action.item()

    def update(self):
        """Run ``K_epochs`` clipped-PPO optimization passes over the buffered batch.

        Returns are discounted Monte-Carlo sums that restart at every step
        flagged in ``buffer.is_terminals`` (the tail of the buffer is not
        bootstrapped with a value estimate). They are normalized to zero mean /
        unit std when the batch has at least two samples. Afterwards ``policy``
        is copied into ``policy_old`` and the buffer is cleared.

        Returns:
            Mean loss of the last optimization epoch as a float, or ``None`` if
            the buffer was empty or ``K_epochs`` is 0.
        """
        if not self.buffer.rewards:
            # Nothing collected: torch.stack([]) below would raise.
            return None

        dev = self._device()

        # Discounted returns, accumulated backwards; a terminal flag resets the
        # running sum so rewards never leak across an episode boundary.
        returns = []
        running = 0.0
        for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):
            if is_terminal:
                running = 0.0
            running = reward + self.gamma * running
            returns.append(running)
        returns.reverse()

        returns = torch.tensor(returns, dtype=torch.float32, device=dev)
        if returns.numel() > 1:
            # std() of a single element is NaN, so only normalize real batches.
            returns = (returns - returns.mean()) / (returns.std() + 1e-7)

        old_states = torch.stack(self.buffer.states).to(dev)
        old_actions = torch.stack(self.buffer.actions).to(dev)
        old_logprobs = torch.stack(self.buffer.logprobs).to(dev).detach()
        # act() yields each value with shape [1], so the stack is [N, 1]. Flatten
        # to [N]; otherwise `returns - values` broadcasts to an [N, N] matrix.
        old_state_values = torch.stack(self.buffer.state_values).to(dev).detach().reshape(-1)

        advantages = returns - old_state_values

        last_loss = None
        for _ in range(self.K_epochs):
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)
            ratios = torch.exp(logprobs - old_logprobs)

            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            loss = (
                -torch.min(surr1, surr2)
                + self.VALUE_COEF * self.MseLoss(state_values, returns)
                - self.ENTROPY_COEF * dist_entropy
            ).mean()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            last_loss = loss.item()

        self.policy_old.load_state_dict(self.policy.state_dict())
        self.buffer.clear()
        return last_loss


Using device: cuda


In [3]:
import gymnasium as gym
import highway_env
import numpy as np

def train(env_name="highway-v0", max_training_timesteps=300_000, max_ep_len=1000,
          update_timestep=4000, K_epochs=80, eps_clip=0.2, gamma=0.99,
          lr_actor=0.0003, lr_critic=0.001, seed=None):
    """Train a PPO agent on a highway-env task and return it.

    Every hyperparameter defaults to the value this notebook was run with.

    Args:
        env_name: Gymnasium id of the environment to create.
        max_training_timesteps: No new episode starts once this many
            environment steps have been taken.
        max_ep_len: Hard cap on steps per episode.
        update_timestep: Run ``PPO.update`` every this many environment steps.
        K_epochs, eps_clip, gamma, lr_actor, lr_critic: Forwarded to ``PPO``.
        seed: Optional seed for torch, numpy and the first ``env.reset``.

    Returns:
        The trained ``PPO`` agent.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    env = gym.make(env_name, render_mode=None).unwrapped
    try:
        env.configure({"controlled_vehicles": 1, "observation": {"type": "Kinematics"}})
        print(env.observation_space)

        state_dim = int(np.prod(env.observation_space.shape))
        action_dim = env.action_space.n

        ppo_agent = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip)

        time_step = 0
        i_episode = 0

        while time_step <= max_training_timesteps:
            # Seed only the first reset; later resets continue the env's RNG stream.
            obs, _ = env.reset(seed=seed if i_episode == 0 else None)
            current_ep_reward = 0

            for t in range(max_ep_len):
                action = ppo_agent.select_action(obs)
                next_obs, reward, done, truncated, _info = env.step(action)

                # Also end the return at the step cap, so rewards from the next
                # episode are not discounted back into this one.
                episode_over = done or truncated or t == max_ep_len - 1
                ppo_agent.buffer.rewards.append(reward)
                ppo_agent.buffer.is_terminals.append(episode_over)

                time_step += 1
                current_ep_reward += reward

                if time_step % update_timestep == 0:
                    ppo_agent.update()

                obs = next_obs

                if done or truncated:
                    break

            i_episode += 1
            print(f"Episode {i_episode}, Reward: {current_ep_reward}")
    finally:
        # Release the environment even if training is interrupted (e.g. KeyboardInterrupt).
        env.close()

    return ppo_agent


if __name__ == "__main__":
    train()


Box(-inf, inf, (5, 5), float32)
Episode 1, Reward: 5.530725363273095
Episode 2, Reward: 10.371813406591151
Episode 3, Reward: 9.15646275004832
Episode 4, Reward: 2.664512183073266
Episode 5, Reward: 3.491384621727922
Episode 6, Reward: 10.576416203900411
Episode 7, Reward: 16.925857732386092
Episode 8, Reward: 2.552307953561
Episode 9, Reward: 13.804558413144099
Episode 10, Reward: 15.081003827471994
Episode 11, Reward: 10.013383288979737
Episode 12, Reward: 7.217647369614492
Episode 13, Reward: 7.267683459711742
Episode 14, Reward: 13.108262311215332
Episode 15, Reward: 7.669160042374724
Episode 16, Reward: 7.865316725156708
Episode 17, Reward: 6.0461508124950365
Episode 18, Reward: 4.868382791906869
Episode 19, Reward: 7.969786817694332
Episode 20, Reward: 7.336505705048753
Episode 21, Reward: 7.750823355234628
Episode 22, Reward: 7.7116790135021835
Episode 23, Reward: 6.782824178254533
Episode 24, Reward: 12.711372724885639
Episode 25, Reward: 6.088814053131497
Episode 26, Reward: 9

KeyboardInterrupt: 

In [24]:
import sys
# Show which Python interpreter backs this kernel.
interpreter_path = sys.executable
print(interpreter_path)

c:\Python311\python.exe
