This notebook adapts the PPO example from `https://github.com/higgsfield-ai/higgsfield/rl/rl_adventure_2/3.ppo.ipynb` to train an actor-critic agent on the `DavinciCode-v0` environment.

In [2]:
import gymnasium as gym
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

In [3]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [4]:
import davinci_code_env

<h2>Use CUDA</h2>

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(f"Using device: {device}")

<h2>Create Environments</h2>

In [6]:
def make_env():
    """Create a ``DavinciCode-v0`` environment wrapped in ``FlattenObservation``.

    The game settings are fixed inside this function and passed to
    ``gym.make`` as keyword arguments.

    Returns:
        The environment returned by ``gym.make``, wrapped in
        ``gym.wrappers.FlattenObservation``.
    """
    game_kwargs = dict(
        num_players=3,
        initial_player=0,
        max_tile_num=11,
        initial_tiles=4,
    )
    base_env = gym.make("DavinciCode-v0", **game_kwargs)
    return gym.wrappers.FlattenObservation(base_env)


# Two separate instances: one is stepped by the training loop, the other is
# used only for evaluation episodes.
train_env, eval_env = make_env(), make_env()

<h2>Neural Network</h2>

In [7]:
def init_weights(m):
    """Initialise one module in place; meant to be passed to ``nn.Module.apply``.

    ``nn.Linear`` layers receive Kaiming-normal weights (ReLU gain) and a
    constant bias of 0.1. Every other module type is left untouched.

    Args:
        m: The module currently visited by ``apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
    # Layers built with ``bias=False`` have ``m.bias is None``; skip them
    # instead of crashing inside ``constant_``.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.1)


class ActorCritic(nn.Module):
    """Actor-critic network with a shared trunk and a multi-head discrete policy.

    The input passes through a shared stack of ``Linear -> ReLU`` layers whose
    output feeds two heads:

    * the critic, ending in a single linear unit (the state value);
    * the actor, ending in ``sum(nvec_outputs)`` linear units (logits) that
      are split into one ``Categorical`` distribution per entry of
      ``nvec_outputs``.

    Args:
        num_inputs: Size of the (flattened) observation vector.
        nvec_outputs: Number of choices of each discrete action component,
            e.g. the ``nvec`` of a multi-discrete action space.
        shared_sizes: Hidden layer widths of the shared trunk.
        critic_sizes: Hidden layer widths of the critic head.
        actor_sizes: Hidden layer widths of the actor head.

    The three size lists are only read; they are never modified.
    """

    def __init__(self, num_inputs, nvec_outputs, shared_sizes, critic_sizes, actor_sizes):
        super().__init__()

        self.nvec_outputs = nvec_outputs
        # Plain Python ints for ``torch.split``; accepts numpy arrays and lists alike.
        self._split_sizes = [int(n) for n in nvec_outputs]

        # Build new size chains instead of inserting into the caller's lists,
        # so constructing a model never alters the hyper-parameters passed in.
        shared_chain = [num_inputs, *shared_sizes]
        self.shared = nn.Sequential(*self._hidden_layers(shared_chain))

        critic_chain = [shared_chain[-1], *critic_sizes]
        self.critic = nn.Sequential(
            *self._hidden_layers(critic_chain),
            nn.Linear(critic_chain[-1], 1),
        )

        # The actor emits raw logits. A single softmax across all heads would
        # let one head's large logits push another head's probabilities to
        # exactly zero, which turns into NaN when that head is renormalised.
        actor_chain = [shared_chain[-1], *actor_sizes]
        self.actor = nn.Sequential(
            *self._hidden_layers(actor_chain),
            nn.Linear(actor_chain[-1], sum(self._split_sizes)),
        )

        self.apply(init_weights)

    @staticmethod
    def _hidden_layers(sizes):
        """Return ``Linear -> ReLU`` pairs connecting consecutive entries of ``sizes``."""
        layers = []
        for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        return layers

    def forward(self, x):
        """Evaluate the policy and the value function.

        Args:
            x: Observation tensor whose last dimension has ``num_inputs``
                entries; it may be a single observation or a batch.

        Returns:
            A tuple ``(dist, value)`` where ``dist`` is a list with one
            ``Categorical`` per action component and ``value`` is the critic
            output (last dimension of size 1).
        """
        features = self.shared(x)
        value = self.critic(features)
        logits = self.actor(features)

        # Splitting along the last dimension covers batched and non-batched
        # inputs; each head is normalised on its own (log-softmax over its logits).
        dist = [
            Categorical(logits=head_logits)
            for head_logits in torch.split(logits, self._split_sizes, dim=-1)
        ]

        return dist, value

In [8]:
def plot(frame_idx, rewards):
    """Redraw the reward curve, replacing the cell's previous output.

    Args:
        frame_idx: Frame counter shown in the figure title.
        rewards: Sequence of rewards to plot; its last entry is shown in the
            title, so it must not be empty.
    """
    clear_output(wait=True)
    plt.figure(figsize=(20, 5))
    plt.subplot(1, 3, 1)
    latest = rewards[-1]
    plt.title('frame %s. reward: %s' % (frame_idx, latest))
    plt.plot(rewards)
    plt.show()
    
def test_env(model):
    """Play one episode on ``eval_env`` by sampling from the policy.

    Gradients are disabled for the whole episode. Actions are sampled from
    the distributions returned by ``model`` (not taken greedily).

    Args:
        model: Callable mapping an observation tensor to ``(dist, value)``,
            where ``dist`` is a list of distributions, one per action component.

    Returns:
        The sum of the rewards collected until the episode terminates or is
        truncated.
    """
    total_reward = 0
    state, _ = eval_env.reset()
    done = False
    with torch.no_grad():
        while not done:
            state_t = torch.as_tensor(state, dtype=torch.float32, device=device)
            dist, _ = model(state_t)
            # ``.item()`` copies each sampled index to a host-side Python int.
            # Building a numpy array straight from device tensors fails when
            # the model lives on CUDA.
            action = np.array([head.sample().item() for head in dist])
            state, reward, terminated, truncated, _ = eval_env.step(action)
            done = terminated or truncated
            total_reward += reward
    return total_reward

<h2>GAE</h2>

In [9]:
def compute_gae(next_value, rewards, masks, values, gamma=0.99, tau=0.95):
    """Compute per-step returns with Generalized Advantage Estimation.

    Walks the rollout backwards, accumulating the discounted sum of TD
    errors, and adds the value estimate back to obtain a return target.

    Args:
        next_value: Value estimate for the state that follows the last step.
        rewards: Reward received at each step.
        masks: 0 where the step ended an episode, 1 otherwise; a zero cuts
            both the bootstrap term and the accumulated advantage.
        values: List of value estimates, one per step.
        gamma: Discount factor.
        tau: Decay applied (together with ``gamma``) to the running advantage.

    Returns:
        A list with one entry per step holding ``advantage + value``.
    """
    bootstrapped = values + [next_value]
    running_adv = 0
    newest_first = []
    for t in range(len(rewards) - 1, -1, -1):
        not_done = masks[t]
        td_error = rewards[t] + gamma * bootstrapped[t + 1] * not_done - bootstrapped[t]
        running_adv = td_error + gamma * tau * not_done * running_adv
        newest_first.append(running_adv + bootstrapped[t])
    # Collected last-step-first; flip once to restore chronological order.
    newest_first.reverse()
    return newest_first

<h1> Proximal Policy Optimization Algorithm</h1>
<h2><a href="https://arxiv.org/abs/1707.06347">Arxiv</a></h2>

In [10]:
def ppo_iter(mini_batch_size, states, actions, log_probs, returns, advantage):
    """Yield random mini-batches drawn from one rollout.

    Produces ``batch_size // mini_batch_size`` mini-batches. Indices are
    drawn independently with ``np.random.randint``, so a sample can appear
    more than once (or not at all) within a pass.

    Args:
        mini_batch_size: Number of samples per mini-batch.
        states, actions, log_probs, returns, advantage: Rollout tensors
            sharing the same first dimension; ``states`` and ``log_probs``
            are indexed as two-dimensional.

    Yields:
        Tuples ``(states, actions, log_probs, returns, advantage)`` restricted
        to the sampled rows.
    """
    batch_size = states.size(0)
    remaining = batch_size // mini_batch_size
    while remaining > 0:
        picked = np.random.randint(0, batch_size, mini_batch_size)
        yield (
            states[picked, :],
            actions[picked],
            log_probs[picked, :],
            returns[picked],
            advantage[picked],
        )
        remaining -= 1


def ppo_update(
    model,
    optimizer,
    ppo_epochs,
    mini_batch_size,
    states,
    actions,
    log_probs,
    returns,
    advantages,
    clip_param=0.2,
    value_coef=0.5,
    entropy_coef=0.001,
):
    """Run the clipped-surrogate PPO update on one rollout.

    For ``ppo_epochs`` passes, mini-batches are drawn with ``ppo_iter`` and
    the model is optimised on
    ``value_coef * critic_loss + actor_loss - entropy_coef * entropy``.

    Args:
        model: Network returning ``(dist, value)`` for a batch of states,
            where ``dist`` is a list with one distribution per action component.
        optimizer: Optimiser over ``model``'s parameters.
        ppo_epochs: Number of passes over the rollout.
        mini_batch_size: Samples per mini-batch.
        states: Rollout states, one row per step.
        actions: Sampled action indices, shape ``(steps, components)``.
        log_probs: Log-probabilities of ``actions`` under the policy that
            collected them, shape ``(steps, components)``.
        returns: Return target per step.
        advantages: Advantage estimate per step.
        clip_param: PPO clipping range for the probability ratio.
        value_coef: Weight of the critic loss.
        entropy_coef: Weight of the entropy bonus.
    """
    # Autograd anomaly detection is a process-wide debugging switch that slows
    # every backward pass; it is deliberately not turned on here.
    for _ in range(ppo_epochs):
        for state, action, old_log_probs, return_, advantage in ppo_iter(
            mini_batch_size, states, actions, log_probs, returns, advantages
        ):
            dist, value = model(state)
            entropy = torch.stack([head.entropy() for head in dist]).mean()
            new_log_probs = torch.stack(
                [
                    head.log_prob(head_action)
                    for head, head_action in zip(dist, action.unbind(dim=1))
                ],
                dim=1,
            )

            # The joint action probability is the product over components,
            # i.e. the exponential of the summed log-probability differences.
            ratio = (new_log_probs - old_log_probs).sum(dim=1).exp()

            # Give per-sample quantities matching shapes before combining
            # them. A (B, 1) tensor against a (B,) tensor would silently
            # broadcast to (B, B) and average over every pair of samples.
            advantage = advantage.reshape(ratio.shape)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantage
            actor_loss = -torch.min(surr1, surr2).mean()

            value = value.reshape(return_.shape)
            critic_loss = (return_ - value).pow(2).mean()

            loss = value_coef * critic_loss + actor_loss - entropy_coef * entropy

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

In [11]:
# Network input/output sizes come straight from the environment's spaces.
num_inputs = train_env.observation_space.shape[0]
nvec_outputs = train_env.action_space.nvec

# Hyper params:
shared_sizes = [1024, 512, 256, 128, 96, 64]
critic_sizes = [48, 48, 32, 16]
actor_sizes = [48, 48, 48]
# NOTE(review): hidden_num / hidden_size are not referenced anywhere else in
# the visible notebook; kept in case other cells rely on them.
hidden_num = 7
hidden_size = 48
lr = 1e-5
num_steps = 20
mini_batch_size = 5
ppo_epochs = 3
threshold_reward = 200

# Pass copies so the hyper-parameter lists above keep exactly the values
# written here, whatever the network constructor does with its arguments.
# This keeps the cell idempotent when it is inspected or re-used later.
model = ActorCritic(
    num_inputs,
    nvec_outputs,
    list(shared_sizes),
    list(critic_sizes),
    list(actor_sizes),
).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

In [12]:
# Training budget, plus the frame counter and evaluation history that the
# training loop updates. Re-run this cell to restart training from zero.
max_frames = 100_000
frame_count, test_rewards = 0, []

In [None]:
state, _ = train_env.reset()
early_stop = False
# Reward of the episode currently in progress. It lives outside the rollout
# buffers because an episode can span several rollouts.
episode_reward = 0.0

while frame_count < max_frames and not early_stop:

    # Rollout buffers for the next ``num_steps`` environment steps.
    log_probs = []
    values = []
    states = []
    actions = []
    rewards = []
    masks = []

    for _ in range(num_steps):
        state_t = torch.as_tensor(state, dtype=torch.float32, device=device)
        with torch.no_grad():
            dist, value = model(state_t)
            # Keep actions and log-probs as tensors on ``device`` so the
            # update step never mixes CPU and GPU tensors.
            action = torch.stack([head.sample() for head in dist])
            log_prob = torch.stack(
                [head.log_prob(head_action) for head, head_action in zip(dist, action)]
            )

        # The environment expects a host-side array of action indices.
        next_state, reward, terminated, truncated, _ = train_env.step(action.cpu().numpy())
        done = bool(terminated or truncated)
        episode_reward += float(reward)

        log_probs.append(log_prob)
        values.append(value)
        rewards.append(torch.tensor(reward, dtype=torch.float32, device=device))
        # 0 at episode boundaries so GAE does not bootstrap across episodes.
        masks.append(torch.tensor(0.0 if done else 1.0, device=device))

        states.append(state_t)
        actions.append(action)

        state = next_state
        frame_count += 1

        if done:
            # Report the finished episode's own reward, not the sum of the
            # rollout buffer (which may cover several or partial episodes).
            print(
                "Training episode ended at: ",
                frame_count,
                "With total reward: ",
                episode_reward,
                flush=True,
            )
            episode_reward = 0.0
            state, _ = train_env.reset()

        if frame_count % 1000 == 0:
            # Periodic evaluation: average of 5 sampled episodes on eval_env.
            test_reward = np.mean([test_env(model) for _ in range(5)])
            test_rewards.append(test_reward)
            plot(frame_count, test_rewards)

            # Saves the whole module object (not just its state_dict).
            torch.save(model, "ppo_model")

            if test_reward > threshold_reward:
                early_stop = True

    # Bootstrap value for the state after the last collected step. It is
    # ignored by GAE when that step ended an episode (mask == 0).
    with torch.no_grad():
        next_state_t = torch.as_tensor(next_state, dtype=torch.float32, device=device)
        _, next_value = model(next_state_t)
    returns = compute_gae(next_value, rewards, masks, values)

    returns = torch.cat(returns).detach()
    log_probs = torch.stack(log_probs).detach()
    values = torch.cat(values).detach()
    states = torch.stack(states)
    actions = torch.stack(actions)

    advantage = returns - values

    ppo_update(
        model,
        optimizer,
        ppo_epochs,
        mini_batch_size,
        states,
        actions,
        log_probs,
        returns,
        advantage,
    )

In [None]:
# Close both environments created at the top of the notebook.
train_env.close()
eval_env.close()