In [None]:
!pip install jedi stable-baselines3[extra] pyglet

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import gym
import numpy as np
import matplotlib.pyplot as plt

In [None]:
class Actor(nn.Module):
  """Gaussian policy network mapping a state to per-action-dimension mean and std.

  Three ReLU hidden layers (64 -> 32 -> 16) feed two linear heads:
    * mean head, squashed with tanh and scaled by 2 (Pendulum's action range [-2, 2]);
    * log-std head, exponentiated so the returned std is always positive.
  The module owns its own Adam optimizer over all of its parameters.
  """

  def __init__(self, state_dim, action_dim, learning_rate):
    super().__init__()
    # Registration order is significant: it fixes the order of random weight
    # initialization under a seed and the state_dict keys.
    self.fc1 = nn.Linear(in_features=state_dim, out_features=64)
    self.fc2 = nn.Linear(in_features=64, out_features=32)
    self.fc3 = nn.Linear(in_features=32, out_features=16)
    self.mean = nn.Linear(in_features=16, out_features=action_dim)
    self.log_std = nn.Linear(in_features=16, out_features=action_dim)
    self.optimizer = optim.Adam(params=self.parameters(), lr=learning_rate)

  def forward(self, state):
    """Return ``(mean, std)`` of the action distribution for ``state``."""
    features = state
    for hidden_layer in (self.fc1, self.fc2, self.fc3):
      features = torch.relu(hidden_layer(features))
    action_mean = torch.tanh(self.mean(features)) * 2  # pendulum: action space in [-2, 2]
    action_std = self.log_std(features).exp()
    return action_mean, action_std

In [None]:
class Critic(nn.Module):
  """State-value network V(s).

  Three ReLU hidden layers (64 -> 32 -> 16) followed by a single linear output
  unit. The module owns its own Adam optimizer over all of its parameters.
  """

  def __init__(self, state_dim, learning_rate):
    super().__init__()
    # Registration order is significant: it fixes the order of random weight
    # initialization under a seed and the state_dict keys.
    self.fc1 = nn.Linear(in_features=state_dim, out_features=64)
    self.fc2 = nn.Linear(in_features=64, out_features=32)
    self.fc3 = nn.Linear(in_features=32, out_features=16)
    self.value = nn.Linear(in_features=16, out_features=1)
    self.optimizer = optim.Adam(params=self.parameters(), lr=learning_rate)

  def forward(self, state):
    """Return V(state) with a trailing dimension of size 1."""
    features = state
    for hidden_layer in (self.fc1, self.fc2, self.fc3):
      features = torch.relu(hidden_layer(features))
    return self.value(features)

In [None]:
class PPOagent(object):
  """PPO-clip agent with a Gaussian actor and a state-value critic.

  Each episode:
    1. roll out the current policy for up to NUM_STEPS steps;
    2. compute Generalized Advantage Estimates (GAE) and value targets once,
       using the critic as it was at collection time;
    3. run EPOCHS full-batch gradient steps on the clipped surrogate (actor)
       and on the squared error to the fixed value targets (critic).

  Pendulum-specific assumptions: actions are clamped to [-2, 2] and per-step
  rewards are rescaled with (r + 8) / 8 before learning.
  """

  def __init__(self, env):
    self.env = env
    self.state_dim = env.observation_space.shape[0]
    self.action_dim = env.action_space.shape[0]

    # Hyperparameters
    self.GAMMA = 0.99            # discount factor
    self.LAMBDA = 0.95           # for GAE
    self.EPSILON = 0.2           # for clipping the ratio
    self.LR_ACTOR = 2e-4
    self.LR_CRITIC = 2e-3
    self.NUM_STEPS = 1024        # max steps collected per episode (rollout length cap)
    self.EPOCHS = 50             # gradient steps taken on each collected rollout
    self.NUM_EPISODES = 1000

    self.actor = Actor(self.state_dim, self.action_dim, self.LR_ACTOR)
    self.critic = Critic(self.state_dim, self.LR_CRITIC)

    self.ep_rewards = []         # mean normalized per-step reward of each episode

  @staticmethod
  def _as_float_tensor(x):
    """Convert a tensor, array, or list of arrays to a float32 tensor."""
    if isinstance(x, torch.Tensor):
      return x.float()
    # Stacking via numpy first avoids torch's slow list-of-ndarrays path.
    return torch.as_tensor(np.asarray(x), dtype=torch.float32)

  @staticmethod
  def _reset_env(env):
    """Reset ``env`` and return only the initial observation.

    gym>=0.26 returns ``(obs, info)``; older gym returns ``obs`` alone.
    """
    out = env.reset()
    if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
      return out[0]
    return out

  @staticmethod
  def _step_env(env, action):
    """Step ``env`` and return ``(next_state, reward, terminated, truncated)``.

    Supports both the 4-tuple ``(obs, reward, done, info)`` API of gym<0.26 and
    the 5-tuple ``(obs, reward, terminated, truncated, info)`` API of gym>=0.26.
    """
    out = env.step(action)
    if len(out) == 5:
      next_state, reward, terminated, truncated, _ = out
    else:
      next_state, reward, done, info = out
      # Old gym's TimeLimit wrapper marks time-outs in info. A time-out is not a
      # true terminal state, so the critic must still bootstrap past it.
      truncated = bool(info.get("TimeLimit.truncated", False))
      terminated = bool(done) and not truncated
    return next_state, float(reward), bool(terminated), bool(truncated)

  def GAE(self, rewards, values, next_value, done):
    """Compute Generalized Advantage Estimates for one trajectory.

    Args:
      rewards: sequence of T per-step rewards.
      values: critic values V(s_t) of the T visited states; any shape holding
        T elements (e.g. the (T, 1) output of Critic).
      next_value: critic value of the state after the last step (one-element
        tensor or number).
      done: True if the trajectory ended in a true terminal state; then no
        value is bootstrapped past the last step.

    Returns:
      Float32 tensor of shape (T,) with the advantage of every step.
    """
    rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32))
    values = torch.as_tensor(values, dtype=torch.float32).reshape(-1)
    if done:
      next_v = torch.zeros(())
    else:
      next_v = torch.as_tensor(next_value, dtype=torch.float32).reshape(())

    gaes = torch.zeros(len(rewards))
    gae_sum = torch.zeros(())
    for i in reversed(range(len(rewards))):
      delta = rewards[i] + self.GAMMA * next_v - values[i]
      gae_sum = delta + self.GAMMA * self.LAMBDA * gae_sum
      gaes[i] = gae_sum
      next_v = values[i]
    return gaes

  def update_actor(self, old_log_probs, states, actions, gaes):
    """Take one gradient step on the clipped PPO surrogate objective.

    Args:
      old_log_probs: log-probs of ``actions`` under the rollout policy, either
        per action dimension (T, action_dim) or already summed (T,).
      states: the T visited states (tensor, array, or list of arrays).
      actions: tensor of shape (T, action_dim) with the actions taken.
      gaes: advantages, shape (T,) (or broadcastable to it).
    """
    states = self._as_float_tensor(states)
    mean, std = self.actor(states)
    dist = torch.distributions.Normal(mean, torch.clamp(std, 1e-2, 1.0))
    # Joint log-prob of a multi-dimensional action = sum over action dims.
    log_probs = dist.log_prob(actions).sum(dim=-1)
    if old_log_probs.dim() == log_probs.dim() + 1:
      old_log_probs = old_log_probs.sum(dim=-1)
    # Flatten to (T,) so advantages pair with log_probs elementwise instead of
    # broadcasting (T, 1) against (T,) into a (T, T) matrix.
    advantages = gaes.reshape(-1)

    ratios = torch.exp(log_probs - old_log_probs)
    clipped_ratios = torch.clamp(ratios, 1.0 - self.EPSILON, 1.0 + self.EPSILON)
    actor_loss = -torch.min(ratios * advantages, clipped_ratios * advantages).mean()
    self.actor.optimizer.zero_grad()
    actor_loss.backward()
    self.actor.optimizer.step()

  def update_critic(self, gaes):
    """Take one gradient step minimizing the mean squared value error.

    ``gaes`` must be the residual ``target - V(s)`` with autograd flowing
    through the critic's prediction V(s).
    """
    critic_loss = torch.mean(torch.square(gaes))
    self.critic.optimizer.zero_grad()
    critic_loss.backward()
    self.critic.optimizer.step()

  def sample_action(self, state):
    """Sample an action for ``state`` from the current policy.

    Returns ``(action, log_prob)``, each of shape (action_dim,). The action is
    clamped to [-2, 2], and ``log_prob`` is evaluated at the clamped action.
    No autograd graph is built, because rollouts are not differentiated.
    """
    with torch.no_grad():
      mean, std = self.actor(self._as_float_tensor(state))
      dist = torch.distributions.Normal(mean, torch.clamp(std, 1e-2, 1.0))
      action = torch.clamp(dist.sample(), -2, 2) # pendulum: action space in [-2, 2]
      log_prob = dist.log_prob(action)
    return action, log_prob

  def _collect_rollout(self):
    """Run the current policy for one episode, capped at NUM_STEPS steps.

    Returns:
      (states, actions, rewards, log_probs, last_state, terminated). Here
      ``actions`` and ``log_probs`` are stacked (T, action_dim) tensors,
      ``rewards`` are normalized floats, ``last_state`` is the observation after
      the final step, and ``terminated`` tells whether the episode reached a true
      terminal state (as opposed to a time-out or the step cap).
    """
    states, actions, rewards, log_probs = [], [], [], []
    terminated = False
    state = self._reset_env(self.env)
    for _ in range(self.NUM_STEPS):
      action, log_prob = self.sample_action(state)
      next_state, reward, terminated, truncated = self._step_env(self.env, action.numpy())
      reward = (reward + 8) / 8 # pendulum: per-step reward in [-16.xx, 0]

      states.append(state)
      actions.append(action)
      rewards.append(reward)
      log_probs.append(log_prob)
      state = next_state

      if terminated or truncated:
        break
    return states, torch.stack(actions), rewards, torch.stack(log_probs), state, terminated

  def train(self):
    """Run NUM_EPISODES rollout + EPOCHS-update cycles, recording rewards."""
    for ep in range(self.NUM_EPISODES):
      states, actions, rewards, old_log_probs, last_state, terminated = self._collect_rollout()
      old_log_probs = old_log_probs.detach()
      states_t = self._as_float_tensor(states)

      # Advantages and value targets are computed once with the rollout-time
      # critic and held fixed across all epochs (standard PPO).
      with torch.no_grad():
        values = self.critic(states_t).reshape(-1)
        next_value = self.critic(self._as_float_tensor(last_state))
        gaes = self.GAE(rewards, values, next_value, terminated)
        returns = gaes + values

      for _ in range(self.EPOCHS):
        self.update_actor(old_log_probs, states_t, actions, gaes)
        # Residual of the current critic against the fixed targets.
        self.update_critic(returns - self.critic(states_t).reshape(-1))

      ep_reward = np.mean(rewards)
      print(f"Episode: {ep+1}, Mean Reward: {ep_reward}")
      self.ep_rewards.append(ep_reward)

  def plot_ep_rewards(self):
    """Plot the per-episode mean normalized reward recorded by train()."""
    plt.plot(self.ep_rewards)
    plt.xlabel("Episode")
    plt.ylabel("Mean normalized reward")
    plt.title("PPO training progress")
    plt.show()

In [None]:
# Fix torch's global RNG so network initialization and action sampling are
# reproducible under Restart & Run All.
# NOTE(review): the environment's own RNG (initial pendulum state) is not
# seeded here; the seeding call depends on the installed gym version.
SEED = 0
torch.manual_seed(SEED)

env = gym.make("Pendulum-v1")
agent = PPOagent(env)

agent.train()
agent.plot_ep_rewards()