In [3]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np

In [4]:
LEARNING_RATE = 0.0003
GAMMA = 0.99
EPSILON_CLIP = 0.2
ENTROPY_COEFF = 0.01
EPOCHS = 10
BATCH_SIZE = 64 # Replay Buffer에 있는 거 64개씩 가져오겠다
TIMESTEPS = 2048

In [5]:
class PPOActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PPOActorCritic, self).__init__()

        #Actor
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1)
        )

        # Critic
        self.critic = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

    def forward(self):
        raise NotImplementedError

    def get_action_and_value(self, state):
        action_probs = self.actor(state) # pi(a|s) => left|right => [0.7, 0.3]
        state_values = self.critic(state) # V(s), R+rV(s')-V(s) -> V(s) 때문에 있음

        dist = Categorical(action_probs)
        action = dist.sample() # Action Prob의 분포로부터 샘플링한다.=> [0.7, 0.3] => 0번 Action
        action_logprobs = dist.log_prob(action) #log(0.75)
        entropy = dist.entropy() #entropy([0.7, 0.3])

        return action, action_logprobs, state_values, entropy

In [6]:
class RolloutBuffer():
    def __init__(self):
        self.actions = []
        self.states = []
        self.log_probs = []
        self.rewards = []
        self.state_values = []
        self.dones = []

    def clear(self):
        self.actions = []
        self.states = []
        self.log_probs = []
        self.rewards = []
        self.state_values = []
        self.dones = []

In [23]:
def train_ppo(buffer, old_model, new_model, optimizer):

    state = buffer.states[-1]
    done = buffer.dones[-1]
    with torch.no_grad():
        discounted_rewards = 0 if done else old_model.get_action_and_value(torch.FloatTensor(state))[2].item()

    returns = []
    for reward in reversed(buffer.rewards):
        discounted_rewards = reward + GAMMA * discounted_rewards
        returns.insert(0, discounted_rewards)

    advantages = torch.FloatTensor(returns) - torch.FloatTensor(buffer.state_values)

    for _ in range(EPOCHS):
        for idx in range(0, len(buffer.states), BATCH_SIZE):
            batch_states = torch.FloatTensor(buffer.states[idx:idx+BATCH_SIZE])
            batch_actions = torch.LongTensor(buffer.actions[idx:idx+BATCH_SIZE])

            batch_returns = torch.FloatTensor(returns[idx:idx+BATCH_SIZE])
            batch_advantages = torch.FloatTensor(advantages[idx:idx+BATCH_SIZE])

            # 동일한 state에 대해 New model이라면 이렇게 했을 것이다
            new_policy_logits = new_model.actor(batch_states)
            values = new_model.critic(batch_states)
            new_policy_dist = Categorical(logits=new_policy_logits)
            new_log_probs = new_policy_dist.log_prob(batch_actions)
            entropy = new_policy_dist.entropy()

            with torch.no_grad():
                old_policy_logits = old_model.actor(batch_states)
                old_policy_dist = Categorical(logits=old_policy_logits)
                old_log_probs = old_policy_dist.log_prob(batch_actions)

            ratios = torch.exp(new_log_probs - old_log_probs)

            surrogate1 = ratios * batch_advantages
            surrogate2 = torch.clamp(ratios, 1-EPSILON_CLIP, 1+EPSILON_CLIP) * batch_advantages
            policy_loss = -torch.min(surrogate1, surrogate2).mean()

            value_loss = nn.MSELoss()(values.squeeze(), batch_returns)
            entropy_loss = -ENTROPY_COEFF * entropy.mean()

            loss = policy_loss + value_loss + entropy_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

In [27]:
env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

old_model = PPOActorCritic(state_dim, action_dim)
new_model = PPOActorCritic(state_dim, action_dim)
new_model.load_state_dict(old_model.state_dict())

optimizer = optim.Adam([
    {'params': new_model.actor.parameters(), 'lr': LEARNING_RATE},
    {'params': new_model.critic.parameters(), 'lr': LEARNING_RATE},
])

buffer = RolloutBuffer()

n_episodes = 100
for episode in range(n_episodes):
    state, _ = env.reset()
    state = torch.FloatTensor(state)
    episode_reward = 0

    buffer.clear()

    for t in range(TIMESTEPS):
        with torch.no_grad():
            action, log_prob, value, _ = old_model.get_action_and_value(state)
            next_state, reward, done, _, _ = env.step(action.item())

        # Store data
        buffer.states.append(state.numpy())
        buffer.actions.append(action.item())
        buffer.rewards.append(reward)
        buffer.dones.append(done)
        buffer.log_probs.append(log_prob.item())
        buffer.state_values.append(value.item())

        state = torch.FloatTensor(next_state)
        episode_reward += reward

        if done:
            state, _ = env.reset()
            state = torch.FloatTensor(state)
            break

        # Train PPO
        train_ppo(buffer, old_model, new_model, optimizer)
        old_model.load_state_dict(new_model.state_dict())

        print(f"\rEpisode {episode + 1}, episode reward {episode_reward}", end="")
    print()

Episode 1, episode reward 8.0
Episode 2, episode reward 9.0
Episode 3, episode reward 9.0
Episode 4, episode reward 9.0
Episode 5, episode reward 8.0
Episode 6, episode reward 8.0
Episode 7, episode reward 8.0
Episode 8, episode reward 8.0
Episode 9, episode reward 8.0
Episode 10, episode reward 9.0
Episode 11, episode reward 8.0
Episode 12, episode reward 9.0
Episode 13, episode reward 7.0
Episode 14, episode reward 9.0
Episode 15, episode reward 9.0
Episode 16, episode reward 9.0
Episode 17, episode reward 8.0
Episode 18, episode reward 8.0
Episode 19, episode reward 8.0
Episode 20, episode reward 8.0
Episode 21, episode reward 8.0
Episode 22, episode reward 9.0
Episode 23, episode reward 8.0
Episode 24, episode reward 7.0
Episode 25, episode reward 7.0
Episode 26, episode reward 9.0
Episode 27, episode reward 7.0
Episode 28, episode reward 9.0
Episode 29, episode reward 9.0
Episode 30, episode reward 9.0
Episode 31, episode reward 9.0
Episode 32, episode reward 8.0
Episode 33, episo

In [29]:
import time
max_ep_len = 300

total_test_episodes = 2
test_running_reward = 0

env = gym.make("CartPole-v1", render_mode='human')

for ep in range(1, total_test_episodes+1):
    ep_reward = 0
    state, info = env.reset()
    
    for i in range(max_ep_len):
        action_probs = new_model.actor(torch.FloatTensor(state))
        dist = Categorical(action_probs)
        action = dist.sample()
        state, reward, done, trancated, _ = env.step(action.numpy())
        ep_reward += reward
        env.render()
        time.sleep(0.01)
        
        if done:
            state, info = env.reset()
            
    test_running_reward += ep_reward
    print(f"Episode: {ep} \t\t Reward: {ep_reward:.2f}")
    ep_reward = 0
    
env.close()

Episode: 1 		 Reward: 300.00
Episode: 2 		 Reward: 300.00
