In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter
import gym
import math

In [2]:
# Training configuration
EPISODES = 500  # number of rollouts collected (passed as ``episodes`` to ``learn``)
EPOCHS = 5  # optimisation passes over each collected rollout
ROLLOUTS = 350  # maximum environment steps per rollout

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Utils

In [3]:
class AverageMeter:
    """Keeps the latest value, weighted sum, sample count and running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every tracked statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as having been observed ``n`` times.

        ``val`` becomes the latest value, ``val * n`` is added to the sum,
        ``n`` is added to the count and the mean is recomputed from both.
        """
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

In [4]:
class PytorchWrapper(gym.Wrapper):
    """Gym wrapper that hands observations and rewards back as float tensors.

    ``step`` returns a 3-tuple ``(obs, reward, done)`` (the ``info`` dict is
    dropped) and ``reset`` returns only the observation. Both the legacy gym
    API (4-tuple ``step``, bare-observation ``reset``) and the newer API
    (5-tuple ``step``, ``(obs, info)`` ``reset``) are accepted from the
    wrapped environment.
    """

    def __init__(self, env):
        super().__init__(env)

    def step(self, action):
        """Advance the wrapped environment by one step.

        Args:
            action: action forwarded unchanged to the wrapped environment.

        Returns:
            ``(obs, reward, done)`` where ``obs`` and ``reward`` are
            ``torch.float`` tensors and ``done`` is truthy once the episode
            has ended for any reason (failure or time limit).
        """
        result = self.env.step(action)
        if len(result) == 5:
            # Newer API: termination and truncation are reported separately.
            obs, reward, terminated, truncated, _ = result
            done = terminated or truncated
        else:
            # Legacy API: a time-limit cut-off is only visible through ``info``.
            obs, reward, done, info = result
            truncated = bool(info.get("TimeLimit.truncated", False))
            terminated = done and not truncated
        if terminated:
            # Bigger negative reward if the episode really terminates; a pure
            # time-limit cut-off is not a failure and keeps its reward.
            reward = -10
        obs = torch.tensor(obs, dtype=torch.float)
        reward = torch.tensor(reward, dtype=torch.float)
        return obs, reward, done

    def reset(self, **kwargs):
        """Reset the wrapped environment and return the first observation.

        Args:
            **kwargs: forwarded unchanged to the wrapped ``reset``.

        Returns:
            The initial observation as a ``torch.float`` tensor.
        """
        result = self.env.reset(**kwargs)
        # The newer API returns ``(obs, info)``; keep only the observation.
        if (
            isinstance(result, tuple)
            and len(result) == 2
            and isinstance(result[1], dict)
        ):
            result = result[0]
        return torch.tensor(result, dtype=torch.float)

In [5]:
def make_env(env):
    """Return ``env`` wrapped in ``PytorchWrapper`` so observations are PyTorch tensors."""
    return PytorchWrapper(env)

# Policy Network

In [6]:
class PolicyNetwork(nn.Module):
    """Two independent MLPs over the same input: action probabilities and state value.

    Args:
        in_features: size of the input (state) vector.
        out_features: number of entries in the action probability distribution.
    """

    def __init__(self, in_features=4, out_features=2):
        super().__init__()
        # Action head: hidden stack -> logits -> softmax over the last dim
        self.action_head = nn.Sequential(
            *self._hidden_stack(in_features),
            nn.Linear(16, out_features),
            nn.Softmax(dim=-1),
        )

        # Value head: hidden stack -> single scalar per input row
        self.value_head = nn.Sequential(
            *self._hidden_stack(in_features),
            nn.Linear(16, 1),
        )

    @staticmethod
    def _hidden_stack(in_features):
        """Build the shared-shape (not shared-weight) 32 -> 16 hidden layers."""
        return [
            nn.Linear(in_features, 32),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Dropout(p=0.3),
        ]

    def forward(self, x):
        """Return ``(action_prob, state_value)`` for input ``x``."""
        return self.action_head(x), self.value_head(x)

In [7]:
class RolloutBuffer:

    """
    Rollout buffer to store data during a policy rollout

    Holds up to ``rollout_steps`` transitions (state, reward, action,
    log-probability) and turns them into training tensors with discounted
    returns via ``get_values``.
    """

    def __init__(self, rollout_steps, gamma, device):
        """
        Args:
            rollout_steps: capacity of the buffer (maximum stored steps).
            gamma: discount factor used when computing returns.
            device: device the tensors returned by ``get_values`` are moved to.
        """
        self.rollout_steps = rollout_steps
        self.gamma = gamma
        self.device = device
        self.states = None
        self.rewards = None
        self.actions = None
        self.log_probs = None
        self.count = None
        self.reset()

    def reset(self):
        """Discard all stored steps and start an empty rollout."""
        self.states = [None] * self.rollout_steps
        self.rewards = [None] * self.rollout_steps
        self.actions = [None] * self.rollout_steps
        self.log_probs = [None] * self.rollout_steps
        self.count = 0

    def store(self, state, reward, action, log_prob):
        """Append one step to the rollout.

        Rewards are stored undiscounted; discounting happens in
        ``compute_returns``. Raises ``IndexError`` once more than
        ``rollout_steps`` steps are stored without a reset.
        """
        self.states[self.count] = state
        self.rewards[self.count] = reward
        self.actions[self.count] = action
        self.log_probs[self.count] = log_prob
        self.count += 1

    def compute_returns(self):
        """Return the discounted return ``G_t`` for every stored step.

        Uses the backward recursion ``G_t = r_t + gamma * G_{t+1}``, which is
        linear in the rollout length and never divides by ``gamma ** t`` (so
        ``gamma == 0`` and very small discounts stay finite).
        """
        returns = [None] * self.count
        running = 0.0
        for i in reversed(range(self.count)):
            running = self.rewards[i] + self.gamma * running
            returns[i] = running
        return returns

    def get_values(self):
        """Package the stored rollout as tensors on ``self.device`` and reset.

        Returns:
            ``(states, returns, actions, log_probs)``: stacked states, stacked
            discounted returns, actions as a ``long`` tensor and stacked
            log-probabilities, each covering the ``count`` stored steps.
        """
        states = torch.stack(self.states[: self.count]).to(self.device)
        actions = torch.tensor(self.actions[: self.count]).to(self.device).long()
        log_probs = torch.stack(self.log_probs[: self.count]).to(self.device)
        returns = self.compute_returns()
        returns = torch.stack(returns).to(self.device)
        self.reset()  ## reset rollout buffer
        return states, returns, actions, log_probs

# Proximal Policy Optimization (PPO)

In [8]:
class PPO:
    """Proximal Policy Optimization agent using the clipped surrogate objective.

    Two ``PolicyNetwork`` instances are kept: ``policy`` is the network being
    optimised, ``policy_old`` is a copy used to collect rollouts and is
    re-synchronised from ``policy`` after every optimisation phase.
    """

    def __init__(
        self,
        action_size,
        rollout_steps,
        device="cpu",
        eps_clip=0.2,
        gamma=0.99,
        lr=0.001,
    ):
        """
        Args:
            action_size: number of discrete actions; sizes the action head.
            rollout_steps: maximum number of environment steps per rollout.
            device: torch device for the networks and training batches.
            eps_clip: clipping range applied to the probability ratio.
            gamma: discount factor handed to the rollout buffer.
            lr: learning rate of the Adam optimizer.
        """
        self.buffer = RolloutBuffer(rollout_steps, gamma, device)

        # Size the action head from ``action_size`` instead of relying on the
        # network's default of two actions.
        self.action_size = action_size
        self.policy = PolicyNetwork(out_features=action_size).to(
            device
        )  ## Current policy
        self.policy_old = PolicyNetwork(out_features=action_size).to(
            device
        )  ## Old policy
        self.update_old()

        self.policy_optim = torch.optim.Adam(self.policy.parameters(), lr=lr)

        self.rollout_steps = rollout_steps
        self.device = device
        self.eps_clip = eps_clip  ## Epsilon to clip ratios

    def update_old(self):
        """Copy the current policy's weights into the old policy."""
        self.eval_models()
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.train_models()

    def eval_models(self):
        """Put both networks in eval mode."""
        self.policy.eval()
        self.policy_old.eval()

    def train_models(self):
        """Put the current policy back in train mode (the old policy stays in eval)."""
        self.policy.train()

    def save(self, path="PPO.pth"):
        """Save both policies and the optimizer state to ``path``."""
        self.eval_models()
        torch.save(
            {
                "policy": self.policy.state_dict(),
                "policy_old": self.policy_old.state_dict(),
                "optimizer": self.policy_optim.state_dict(),
            },
            path,
        )
        self.train_models()

    def load(self, path="PPO.pth"):
        """Restore both policies and the optimizer state from ``path``.

        Tensors are mapped onto ``self.device`` so a checkpoint written on a
        GPU can be loaded on a CPU-only machine (and vice versa).
        """
        ckpt = torch.load(path, map_location=self.device)
        self.policy.load_state_dict(ckpt["policy"])
        self.policy_old.load_state_dict(ckpt["policy_old"])
        self.policy_optim.load_state_dict(ckpt["optimizer"])

    def evaluate(self, env):
        """Run one rendered episode with the old policy, printing every step's reward.

        ``env`` must expose ``reset() -> obs`` and ``step(a) -> (obs, reward, done)``
        with tensor observations, plus ``render()``.
        """
        obs = env.reset()
        done = False
        step = 0
        while not done:
            obs = obs.unsqueeze(0)
            obs = obs.to(self.device)
            env.render()
            action_prob, _ = self.forward(self.policy_old, obs)
            action = Categorical(action_prob).sample().item()
            obs, reward, done = env.step(action)
            step += 1
            print(f"Step: {step}, Reward: {reward}")

    def forward(self, model, obs, grad=False):
        """Run ``model`` on ``obs``; outputs are detached unless ``grad`` is true.

        Returns:
            ``(action_prob, state_value)`` as produced by ``model``.
        """
        action_prob, state_value = model(obs)
        if not grad:
            action_prob = action_prob.detach()
            state_value = state_value.detach()
        return action_prob, state_value

    def learn(self, env, episodes, epochs):
        """Train the policy.

        For each of ``episodes`` iterations: collect one rollout (at most
        ``rollout_steps`` steps, ending early on ``done``) with the old
        policy, optimise the current policy on it for ``epochs`` passes, then
        sync the old policy. Loss and summed reward are logged to
        TensorBoard; every 20th iteration is printed and checkpointed.

        Args:
            env: environment with ``reset() -> obs`` and
                ``step(a) -> (obs, reward, done)`` returning tensors.
            episodes: number of rollout/optimise iterations.
            epochs: optimisation passes per rollout.
        """
        count = 0
        writer = SummaryWriter()
        try:
            for eps in range(episodes):
                obs = env.reset()
                reward_tracker = AverageMeter()
                for _ in range(self.rollout_steps):
                    obs = obs.to(self.device)
                    action_prob, _ = self.forward(
                        self.policy_old, obs.unsqueeze(0)
                    )
                    dist = Categorical(action_prob)
                    action = (
                        dist.sample()
                    )  ## Action is sampled from action prob distribution
                    log_prob = dist.log_prob(action)
                    action = action.item()
                    next_obs, reward, done = env.step(action)
                    reward_tracker.update(reward.item())
                    self.buffer.store(
                        obs.cpu(), reward.squeeze(), action, log_prob
                    )
                    obs = next_obs

                    if done:
                        break

                (
                    old_states,
                    old_returns,
                    old_actions,
                    old_log_probs,
                ) = self.buffer.get_values()

                ## Normalize returns once per rollout (they do not change
                ## between epochs). ``std`` of a single element is NaN, which
                ## would poison every weight, so treat it as zero spread.
                returns_std = old_returns.std() if old_returns.numel() > 1 else 0.0
                returns = (old_returns - old_returns.mean()) / (returns_std + 1e-6)

                # NOTE(review): ``policy`` is in train mode here (dropout
                # active) while the rollout used ``policy_old`` in eval mode,
                # so the ratios are not exactly 1 on the first epoch —
                # confirm this is intended.
                for _ in range(epochs):
                    ## Optimize on same trajectory for k epochs
                    action_prob, values = self.forward(
                        self.policy, old_states, grad=True
                    )
                    action_cat = Categorical(action_prob)
                    log_probs = action_cat.log_prob(old_actions)
                    entropy = action_cat.entropy()

                    ratios = torch.exp(
                        log_probs - old_log_probs.squeeze()
                    )  ## new/old probability ratio, from the log-prob difference
                    advantages = (
                        returns - values.detach().squeeze()
                    )  ## Compute advantages using value function network and returns

                    self.policy_optim.zero_grad()
                    surr1 = ratios * advantages
                    surr2 = (
                        torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip)
                        * advantages
                    )  ## Clipped surrogate objective

                    ## Minimum of two surrogate loss, update value network using smooth_l1_loss
                    loss = (
                        -torch.min(surr1, surr2)
                        + 0.5 * F.smooth_l1_loss(values.squeeze(), returns.squeeze())
                        - 0.01 * entropy
                    ).mean()
                    loss.backward()
                    self.policy_optim.step()

                self.update_old()  ## Update old policy after k epochs

                writer.add_scalar("Loss", loss.item(), count)
                writer.add_scalar("Reward", reward_tracker.sum, count)
                count += 1

                if eps % 20 == 0:
                    print(
                        f"Episode: {eps+1}/{episodes}, loss: {loss.item()}, reward: {reward_tracker.sum}"
                    )
                    self.save()
        finally:
            # Flush and close the event file even if training is interrupted.
            writer.close()

In [9]:
# Build CartPole-v1 and wrap it so observations and rewards come back as tensors
env = make_env(gym.make("CartPole-v1"))
action_size = env.action_space.n

In [10]:
# Agent sized for the environment's action space, rolling out at most ROLLOUTS steps
agent = PPO(
    action_size=action_size,
    rollout_steps=ROLLOUTS,
    device=device,
)

In [12]:
## Load TensorBoard to visualize the logged loss and reward (uncomment the two magics below)
# %load_ext tensorboard
# %tensorboard --logdir runs

In [11]:
# Train for EPISODES rollouts, optimising EPOCHS times on each one
agent.learn(env, episodes=EPISODES, epochs=EPOCHS)

Episode: 1/500, loss: -0.08507605642080307, reward: 27.0
Episode: 21/500, loss: 0.12620748579502106, reward: 2.0
Episode: 41/500, loss: 0.11203133314847946, reward: 13.0
Episode: 61/500, loss: -0.01715271547436714, reward: 2.0
Episode: 81/500, loss: -0.13787035644054413, reward: 4.0
Episode: 101/500, loss: 0.4552017152309418, reward: 28.0
Episode: 121/500, loss: -0.3362264633178711, reward: -2.0
Episode: 141/500, loss: 0.29309141635894775, reward: 23.0
Episode: 161/500, loss: -0.43723297119140625, reward: 85.0
Episode: 181/500, loss: -0.1673368662595749, reward: 138.0
Episode: 201/500, loss: 0.5344097018241882, reward: 45.0
Episode: 221/500, loss: -0.1828581690788269, reward: 159.0
Episode: 241/500, loss: -0.17927084863185883, reward: 137.0
Episode: 261/500, loss: -0.023509884253144264, reward: 272.0
Episode: 281/500, loss: 0.35203856229782104, reward: 350.0
Episode: 301/500, loss: 0.5279468894004822, reward: 350.0
Episode: 321/500, loss: 0.696567714214325, reward: 350.0
Episode: 341/5

In [13]:
# agent.load()
# agent.evaluate(env)