In [None]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Reproducibility: one seed drives NumPy's and PyTorch's global RNGs
# (the same value is also passed to env.reset in the training cell).
SEED = 42  # the "classic" 42
seed = SEED  # lower-case alias referenced by later cells
np.random.seed(seed=seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7acd35031790>

In [None]:
class ActorCritic(nn.Module):
    """Actor-critic network: a shared trunk feeding a policy head and a value head.

    Args:
        state_dim: length of the observation vector (4 for CartPole).
        action_dim: number of discrete actions (2 for CartPole).
        hidden_dim: width of every hidden layer.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()

        def two_layer_head(out_features):
            # Hidden Linear+ReLU on top of the trunk, then a plain Linear output
            # (no activation, so the actor head emits raw logits).
            return nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, out_features),
            )

        # Trunk shared by the actor and the critic.
        self.shared_fc = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.ReLU())
        # Actor: unnormalised action logits (softmax is left to the caller).
        self.policy_head = two_layer_head(action_dim)
        # Critic: one scalar state-value per input row.
        self.value_head = two_layer_head(1)

    def forward(self, x):
        """Run the trunk once and both heads on its features.

        Args:
            x: observations, expected shape [batch_size, state_dim].

        Returns:
            (logits, value): actor logits of shape [batch_size, action_dim] and
            critic values of shape [batch_size, 1].
        """
        features = self.shared_fc(x)
        return self.policy_head(features), self.value_head(features)

In [None]:
class RolloutBuffer:
    """List-backed store for one batch of collected transitions.

    Every attribute named in ``_FIELDS`` is a parallel list indexed by step:
    states, actions, logprobs, rewards, is_terminals (episode-end flags) and
    values (critic estimates).
    """

    _FIELDS = ("states", "actions", "logprobs", "rewards", "is_terminals", "values")

    def __init__(self):
        self.clear()

    def clear(self):
        """Replace every per-step list with a fresh empty list for the next rollout."""
        for field in self._FIELDS:
            setattr(self, field, [])

In [None]:
class PPO:
    """Proximal Policy Optimization with the clipped surrogate objective (discrete actions).

    Two networks are kept:
      * ``policy``     - optimised in ``update``;
      * ``policy_old`` - the snapshot ``act`` samples from while collecting data.
    At the end of every ``update`` the snapshot is overwritten with ``policy``.
    """

    def __init__(self, state_dim, action_dim,
                 lr=3e-4, gamma=0.99,
                 K_epochs=4, eps_clip=0.2):
        """
        Args:
            state_dim: length of the observation vector.
            action_dim: number of discrete actions.
            lr: Adam learning rate.
            gamma: discount factor for returns.
            K_epochs: optimisation passes over each collected batch.
            eps_clip: clip range for the probability ratio in the objective.
        """
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        # 1) Network that gets optimised.
        self.policy = ActorCritic(state_dim, action_dim)

        # 2) Optimiser over the trainable network only.
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)

        # 3) Rollout network, initialised as an exact copy of `policy`.
        self.policy_old = ActorCritic(state_dim, action_dim)
        self.policy_old.load_state_dict(self.policy.state_dict())

        # 4) Critic regression loss.
        self.mse_loss = nn.MSELoss()

        self.action_dim = action_dim

    def act(self, state):
        """Sample an action from ``policy_old`` without tracking gradients.

        Args:
            state: one observation (array-like of length state_dim).

        Returns:
            (action, logprob, value) as Python scalars: the sampled action index,
            log pi_old(action | state), and the critic's estimate for ``state``.
        """
        state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)  # [1, state_dim]
        with torch.no_grad():
            logits, value = self.policy_old(state)
            policy_dist = torch.distributions.Categorical(logits=logits)
            action = policy_dist.sample()
            action_logprob = policy_dist.log_prob(action)

        return action.item(), action_logprob.item(), value.item()

    def _discounted_returns(self, rewards, is_terminals):
        """Return the discounted return G_t for every step, restarting at episode ends.

        Walks the trajectory backwards; a True flag at step t means nothing after
        step t contributes to G_t.
        """
        returns = []
        G = 0.0
        for reward, done in zip(reversed(rewards), reversed(is_terminals)):
            if done:
                G = 0.0
            G = reward + self.gamma * G
            returns.append(G)
        returns.reverse()  # built back-to-front; append+reverse avoids O(n^2) insert(0, ...)
        return returns

    @staticmethod
    def _normalize(x, eps=1e-8):
        """Standardise a 1-D tensor to zero mean and unit std.

        With fewer than two elements the sample std is NaN, so the tensor is only
        centred rather than letting NaNs poison the parameters.
        """
        if x.numel() < 2:
            return x - x.mean()
        return (x - x.mean()) / (x.std() + eps)

    def update(self, buffer: "RolloutBuffer"):
        """Run K_epochs of clipped-PPO optimisation on one rollout, then sync ``policy_old``.

        Args:
            buffer: object exposing parallel lists ``states``, ``actions``,
                ``logprobs``, ``rewards``, ``is_terminals`` and ``values``
                collected with ``act``. An empty buffer is a no-op.
        """
        if len(buffer.rewards) == 0:
            return

        # Stack observations into one ndarray first: building a tensor directly
        # from a list of ndarrays is very slow in PyTorch.
        states = torch.as_tensor(np.asarray(buffer.states), dtype=torch.float32)
        actions = torch.as_tensor(buffer.actions, dtype=torch.long)
        old_logprobs = torch.as_tensor(buffer.logprobs, dtype=torch.float32)
        values = torch.as_tensor(buffer.values, dtype=torch.float32)

        # Normalised Monte-Carlo returns; the critic below is regressed onto this
        # same normalised scale, so the stored `values` are on a comparable scale.
        returns = torch.as_tensor(
            self._discounted_returns(buffer.rewards, buffer.is_terminals),
            dtype=torch.float32,
        )
        returns = self._normalize(returns)

        # Advantage A_t = G_t - V(s_t), normalised for a stable step size.
        advantages = self._normalize(returns - values)

        for _ in range(self.K_epochs):
            # (a) evaluate the current policy on the stored states/actions
            logits, state_values = self.policy(states)
            policy_dist = torch.distributions.Categorical(logits=logits)
            new_logprobs = policy_dist.log_prob(actions)

            # (b) probability ratio pi_new / pi_old, computed in log space
            ratios = torch.exp(new_logprobs - old_logprobs)

            # (c) clipped surrogate: maximise min(ratio*A, clip(ratio)*A),
            #     i.e. minimise its negation
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            # (d) critic loss; squeeze(-1) keeps shape [N] even when N == 1
            value_loss = self.mse_loss(state_values.squeeze(-1), returns)

            loss = policy_loss + 0.5 * value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # Next rollout is collected with the freshly trained weights.
        self.policy_old.load_state_dict(self.policy.state_dict())


In [None]:
def _reset_env(env, seed=None):
    """Reset `env` and return only the observation, for both old and new gym reset APIs."""
    out = env.reset(seed=seed) if seed is not None else env.reset()
    # gym>=0.26 returns (obs, info); older versions return obs alone.
    if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
        return out[0]
    return out


def _step_env(env, action):
    """Step `env` and return (next_state, reward, done) for both 4- and 5-tuple step APIs."""
    out = env.step(action)
    if len(out) == 5:  # gym>=0.26: (obs, reward, terminated, truncated, info)
        next_state, reward, terminated, truncated, _ = out
        return next_state, reward, terminated or truncated
    next_state, reward, done, _ = out  # legacy: (obs, reward, done, info)
    return next_state, reward, done


def train_ppo_cartpole(max_episodes=3000, max_timesteps=200,
                       update_timestep=2000, print_freq=100):
    """Train a PPO agent on gym's CartPole-v1 and print periodic episode rewards.

    Args:
        max_episodes: number of episodes to play.
        max_timesteps: per-episode step cap; an episode is cut off here even if
            the environment has not signalled done.
        update_timestep: run a PPO update every this many environment steps,
            counted across episodes.
        print_freq: print the total reward of every ``print_freq``-th episode.

    The module-level ``seed`` is used for the first ``env.reset`` only.
    """
    env = gym.make('CartPole-v1')
    try:
        state_dim = env.observation_space.shape[0]   # CartPole: 4
        action_dim = env.action_space.n              # CartPole: 2

        ppo_agent = PPO(state_dim, action_dim, lr=3e-4, gamma=0.99, K_epochs=4, eps_clip=0.2)
        buffer = RolloutBuffer()

        timestep = 0  # global step counter that triggers updates

        for episode in range(1, max_episodes + 1):
            # Seed only the first reset: re-seeding every episode would make every
            # episode start from the identical initial state.
            state = _reset_env(env, seed if episode == 1 else None)
            total_reward = 0

            for t in range(max_timesteps):
                timestep += 1

                action, logprob, value = ppo_agent.act(state)
                next_state, reward, done = _step_env(env, action)

                # Close the episode in the buffer when the step cap cuts it off too;
                # otherwise the discounted return would leak into the next episode.
                episode_over = done or t == max_timesteps - 1

                buffer.states.append(state)
                buffer.actions.append(action)
                buffer.logprobs.append(logprob)
                buffer.values.append(value)
                buffer.rewards.append(reward)
                buffer.is_terminals.append(episode_over)

                state = next_state
                total_reward += reward

                if timestep % update_timestep == 0:
                    ppo_agent.update(buffer)
                    buffer.clear()  # PPO is on-policy: discard data from the old policy

                if done:
                    break

            if episode % print_freq == 0:
                print(f"Episode: {episode}, Reward: {total_reward}")
    finally:
        env.close()


def _main():
    """Entry point: run the full PPO CartPole training."""
    train_ppo_cartpole()


if __name__ == "__main__":
    # Inside a Jupyter kernel __name__ is "__main__", so this runs when the cell executes.
    _main()


  deprecation(
  deprecation(


Episode: 100, Reward: 15.0
Episode: 200, Reward: 19.0
Episode: 300, Reward: 24.0
Episode: 400, Reward: 15.0
Episode: 500, Reward: 25.0
Episode: 600, Reward: 13.0
Episode: 700, Reward: 78.0
Episode: 800, Reward: 27.0
Episode: 900, Reward: 69.0
Episode: 1000, Reward: 98.0
Episode: 1100, Reward: 71.0
Episode: 1200, Reward: 164.0
Episode: 1300, Reward: 51.0
Episode: 1400, Reward: 141.0
Episode: 1500, Reward: 200.0
Episode: 1600, Reward: 200.0
Episode: 1700, Reward: 200.0
Episode: 1800, Reward: 179.0
Episode: 1900, Reward: 200.0
Episode: 2000, Reward: 200.0
Episode: 2100, Reward: 200.0
Episode: 2200, Reward: 200.0
Episode: 2300, Reward: 200.0
Episode: 2400, Reward: 200.0
Episode: 2500, Reward: 200.0
Episode: 2600, Reward: 200.0
Episode: 2700, Reward: 200.0
Episode: 2800, Reward: 200.0
Episode: 2900, Reward: 200.0
Episode: 3000, Reward: 200.0
