In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gym
import numpy as np

class Actor(nn.Module):
    """Softmax policy network for a discrete action space, trained with the
    PPO clipped surrogate objective.

    Args:
        state_size: Number of input features per state.
        action_size: Number of discrete actions (width of the softmax output).
        hidden_dim: Width of the two hidden layers. When omitted, the
            module-level ``hidden_size`` is used.
        clip: PPO clipping range. When omitted, the module-level
            ``clip_ratio`` is looked up each time a loss is computed.
    """

    def __init__(self, state_size, action_size, hidden_dim=None, clip=None):
        super(Actor, self).__init__()
        # Fall back to the notebook-level global only when no width is given.
        width = hidden_size if hidden_dim is None else hidden_dim
        self.fc1 = nn.Linear(state_size, width)
        self.fc2 = nn.Linear(width, width)
        self.policy = nn.Linear(width, action_size)

        self.action_size = action_size
        self._clip = clip

    def forward(self, x):
        """Return action probabilities (softmax over the last dimension)."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return F.softmax(self.policy(x), dim=-1)

    def compute_loss(self, old_policy, new_policy, actions, gaes, clip=None):
        """Compute the mean PPO clipped surrogate loss.

        Args:
            old_policy: Action probabilities recorded when the data was
                collected; treated as a constant.
            new_policy: Action probabilities from the current network.
            actions: One-hot encoding of the actions taken; it is multiplied
                with both policies and summed over ``dim=1``.
            gaes: Advantage estimates; treated as a constant.
            clip: Optional clipping range overriding the instance / global
                setting for this call.

        Returns:
            Scalar tensor: the negated, averaged clipped surrogate objective.
        """
        if clip is None:
            clip = clip_ratio if self._clip is None else self._clip

        gaes = gaes.detach()
        # Floor the selected probabilities before the log: log(0) is -inf and
        # its gradient turns into NaN (0 * inf) once it passes through exp().
        # Probabilities above the floor are left untouched.
        floor = 1e-10
        old_p = torch.sum(old_policy * actions, dim=1).clamp_min(floor)
        new_p = torch.sum(new_policy * actions, dim=1).clamp_min(floor)

        old_log_p = torch.log(old_p).detach()
        log_p = torch.log(new_p)
        ratios = torch.exp(log_p - old_log_p)
        clipped_ratios = torch.clamp(ratios, 1 - clip, 1 + clip)
        surrogate = -torch.min(ratios * gaes, clipped_ratios * gaes)
        return surrogate.mean()

    def train(self, optimizer=True, old_policy=None, states=None, actions=None, gaes=None):
        """Run one policy-gradient step, or switch train/eval mode.

        This method has the same name as ``nn.Module.train(mode)``, which
        ``eval()`` calls internally with a bool. To keep ``eval()`` and
        ``train()`` / ``train(False)`` working, a bool first argument is
        forwarded to ``nn.Module.train``; anything else is treated as the
        optimizer for an update step.

        Args:
            optimizer: Optimizer over this module's parameters, or a bool
                mode flag (see above).
            old_policy: Action probabilities recorded at collection time.
            states: Batch of states fed to the network.
            actions: Integer action indices; one-hot encoded here with
                ``action_size`` classes.
            gaes: Advantage estimates for the batch.

        Returns:
            The detached scalar loss for an update step, or ``self`` when
            used as a mode switch.
        """
        if isinstance(optimizer, bool):
            return super(Actor, self).train(optimizer)

        one_hot = F.one_hot(actions.long(), num_classes=self.action_size).float()

        optimizer.zero_grad()
        curr_policy = self(states)
        loss = self.compute_loss(old_policy, curr_policy, one_hot, gaes)
        loss.backward()
        optimizer.step()
        # Detach so callers that keep the value do not keep the graph alive.
        return loss.detach()


In [None]:
class Critic(nn.Module):
    """State-value network: two ReLU hidden layers followed by a single
    linear output unit.

    Args:
        state_size: Number of input features per state.
        hidden_dim: Width of the two hidden layers. When omitted, the
            module-level ``hidden_size`` is used.
    """

    def __init__(self, state_size, hidden_dim=None):
        super(Critic, self).__init__()
        # Fall back to the notebook-level global only when no width is given.
        width = hidden_size if hidden_dim is None else hidden_dim
        self.fc1 = nn.Linear(state_size, width)
        self.fc2 = nn.Linear(width, width)
        self.value = nn.Linear(width, 1)

    def forward(self, x):
        """Return the value estimate; the last dimension has size 1."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.value(x)

    def compute_loss(self, v_pred, td_targets):
        """Mean squared error between predicted values and TD targets.

        Both tensors are flattened to 1-D first. Squeezing only ``v_pred``
        (as an ``[n, 1]`` -> ``[n]`` conversion) would let ``[n, 1]`` targets
        broadcast against it into an ``[n, n]`` error matrix, giving a wrong
        loss without raising. Flattening also turns a 0-dim target into a
        1-element vector.

        Args:
            v_pred: Predicted values, e.g. ``[n, 1]`` or ``[n]``.
            td_targets: Regression targets with the same number of elements.

        Returns:
            Scalar loss tensor.
        """
        v_flat = v_pred.reshape(-1)
        target_flat = td_targets.reshape(-1)
        return F.mse_loss(v_flat, target_flat)

    def train(self, optimizer=True, states=None, td_targets=None):
        """Run one value-regression step, or switch train/eval mode.

        This method has the same name as ``nn.Module.train(mode)``, which
        ``eval()`` calls internally with a bool. To keep ``eval()`` and
        ``train()`` / ``train(False)`` working, a bool first argument is
        forwarded to ``nn.Module.train``; anything else is treated as the
        optimizer for an update step.

        Args:
            optimizer: Optimizer over this module's parameters, or a bool
                mode flag (see above).
            states: Batch of states fed to the network.
            td_targets: Regression targets; detached before use.

        Returns:
            The detached scalar loss for an update step, or ``self`` when
            used as a mode switch.
        """
        if isinstance(optimizer, bool):
            return super(Critic, self).train(optimizer)

        optimizer.zero_grad()
        v_pred = self(states)
        loss = self.compute_loss(v_pred, td_targets.detach())
        loss.backward()
        optimizer.step()
        # Detach so callers that keep the value do not keep the graph alive.
        return loss.detach()


In [None]:
from tqdm import tqdm

class PPOAgent:
    """PPO training loop that owns a Gym environment, an Actor and a Critic.

    Besides its constructor arguments, this class reads the module-level
    names ``actor_lr``, ``critic_lr``, ``lmbda`` and ``epochs``.

    Args:
        env_name: Environment id passed to ``gym.make``.
        gamma: Discount factor used for the TD residuals and GAE.
    """

    def __init__(self, env_name, gamma):
        self.env = gym.make(env_name)

        self.state_size = self.env.observation_space.shape[0]
        self.action_size = self.env.action_space.n

        self.gamma = gamma

        self.actor = Actor(self.state_size, self.action_size)
        self.critic = Critic(self.state_size)

        self.actor_opt = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_opt = optim.Adam(self.critic.parameters(), lr=critic_lr)

    def gae_target(self, rewards, curr_Qs, next_Q, done, lam=None):
        """Compute generalized advantage estimates and value targets.

        Args:
            rewards: Sequence of per-step rewards (ints or floats).
            curr_Qs: Critic values for the same steps; any shape with one
                value per step (e.g. ``(n,)`` or ``(n, 1)``).
            next_Q: Critic value of the state following the last step; only
                used when ``done`` is false.
            done: Whether the last step ended the episode.
            lam: GAE lambda. When omitted, the module-level ``lmbda`` is used.

        Returns:
            ``(gae, td_targets)``: two float64 arrays of shape ``(n,)`` where
            ``td_targets[k] = gae[k] + curr_Qs[k]``.
        """
        if lam is None:
            lam = lmbda

        # Work on flat float arrays. np.zeros_like on a list of integer
        # rewards would allocate integer outputs (truncating every value),
        # and (n, 1) critic outputs would assign 1-element arrays into
        # scalar slots, which NumPy has deprecated.
        rewards = np.asarray(rewards, dtype=np.float64).reshape(-1)
        values = np.asarray(curr_Qs, dtype=np.float64).reshape(-1)

        td_targets = np.zeros_like(rewards)
        gae = np.zeros_like(rewards)
        gae_cumulative = 0.0
        # Bootstrap from the critic only if the episode is still running.
        if done:
            future_value = 0.0
        else:
            future_value = float(np.asarray(next_Q, dtype=np.float64).reshape(-1)[0])

        for k in reversed(range(len(rewards))):
            delta = rewards[k] + self.gamma * future_value - values[k]
            gae_cumulative = self.gamma * lam * gae_cumulative + delta
            gae[k] = gae_cumulative
            future_value = values[k]
            td_targets[k] = gae[k] + values[k]
        return gae, td_targets

    def _reset_env(self):
        """Reset the environment and return only the initial observation."""
        out = self.env.reset()
        # Gym >= 0.26 returns (observation, info); older releases return the
        # observation alone.
        if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
            return out[0]
        return out

    def _step_env(self, action):
        """Step the environment and return ``(next_state, reward, done)``."""
        out = self.env.step(action)
        # Gym >= 0.26 splits the end-of-episode flag into terminated and
        # truncated; either one ends the rollout here, as the single `done`
        # flag of the older 4-tuple API does.
        if len(out) == 5:
            next_state, reward, terminated, truncated, _ = out
            return next_state, reward, bool(terminated or truncated)
        next_state, reward, done, _ = out
        return next_state, reward, done

    def train(self, max_episodes, update_interval):
        """Collect rollouts and update actor and critic.

        An update is triggered whenever ``update_interval`` steps have been
        collected or the episode ends; each update runs ``epochs`` passes
        (module-level name) over the collected batch, after which the batch
        is cleared.

        Args:
            max_episodes: Number of episodes to run.
            update_interval: Number of collected steps that triggers an update.
        """
        progress_bar = tqdm(range(max_episodes), desc="Training Progress")
        for episode in progress_bar:
            episode_reward = 0
            done = False
            state = self._reset_env()

            states = []
            actions = []
            rewards = []
            old_policys = []

            while not done:
                state_tensor = torch.tensor(np.asarray(state), dtype=torch.float32)
                probs = self.actor(state_tensor).detach().numpy()
                action = np.random.choice(self.action_size, p=probs)

                next_state, reward, done = self._step_env(action)

                states.append(state)
                actions.append(action)
                rewards.append(reward)
                old_policys.append(probs)

                state = next_state
                episode_reward += reward

                if len(states) >= update_interval or done:
                    # Stack into one ndarray first: building a tensor from a
                    # list of ndarrays goes through a slow per-element path.
                    states_tensor = torch.tensor(np.asarray(states), dtype=torch.float32)
                    actions_tensor = torch.tensor(np.asarray(actions), dtype=torch.long)
                    old_policys_tensor = torch.tensor(np.asarray(old_policys), dtype=torch.float32)

                    curr_Qs = self.critic(states_tensor).detach().numpy()
                    next_Q = self.critic(torch.tensor(np.asarray(next_state), dtype=torch.float32)).item()

                    gaes, td_targets = self.gae_target(rewards, curr_Qs, next_Q, done)
                    gaes = torch.tensor(gaes, dtype=torch.float32)
                    td_targets = torch.tensor(td_targets, dtype=torch.float32)

                    for _ in range(epochs):
                        self.actor.train(self.actor_opt, old_policys_tensor, states_tensor, actions_tensor, gaes)
                        self.critic.train(self.critic_opt, states_tensor, td_targets)

                    states = []
                    actions = []
                    rewards = []
                    old_policys = []

                progress_bar.set_postfix({"Episode_Reward": episode_reward})

if __name__ == "__main__":
    # Hyperparameters are bound as module-level names here; the Actor, Critic
    # and PPOAgent definitions read several of them as globals, so they must
    # exist before the agent is built and trained.

    # Environment and rollout schedule.
    env_name = "CartPole-v0"
    max_episodes = 500
    update_interval = 50  # collected steps that trigger an update
    epochs = 7            # optimisation passes per collected batch

    # Network width and optimiser step sizes.
    hidden_size = 128
    actor_lr, critic_lr = 0.0005, 0.001

    # Discounting, GAE lambda and PPO clipping range.
    gamma, lmbda = 0.99, 0.95
    clip_ratio = 0.1

    agent = PPOAgent(env_name, gamma)
    agent.train(max_episodes, update_interval)


Training Progress: 100%|██████████| 500/500 [03:38<00:00,  2.29it/s, Episode_Reward=200]
