In [None]:
import random
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.distributions import Categorical
from torch.nn.utils import clip_grad_norm_
from einops import rearrange
from functools import partial
from transformers import GPT2Model

# env
import gym
import slimevolleygym
from slimevolleygym import FrameStack
from stable_baselines3.common.atari_wrappers import ClipRewardEnv, NoopResetEnv, MaxAndSkipEnv, WarpFrame, ClipRewardEnv

# logging and vis
import matplotlib.pyplot as plt
import wandb
from tqdm.notebook import tqdm
import cv2
from time import sleep
from gym.envs.classic_control import rendering as rendering
from array2gif import write_gif
from pathlib import Path

## GPT

In [None]:
class FPT(nn.Module):
    """GPT-2 backbone wrapped between an input network and an output network.

    The input network maps each token from ``input_dim`` features to the
    transformer's embedding width, the GPT-2 model processes the resulting
    sequence via ``inputs_embeds``, and the output network maps hidden states
    to ``output_dim`` features.

    Args:
        input_dim: Feature size of one input token (or vocabulary size when
            ``use_embeddings_for_in`` is True).
        output_dim: Feature size produced per output token.
        model_name: One of 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'.
        pretrained: Load pretrained weights if True; otherwise build a randomly
            initialised model with the same architecture.
        return_last_only: If True, ``forward`` keeps only the trailing token(s)
            of the transformer output before applying the output network.
        use_embeddings_for_in: Use an ``nn.Embedding`` instead of linear layers
            as the input network.
        in_layer_sizes: Hidden sizes of the input MLP (None means no hidden layers).
        out_layer_sizes: Hidden sizes of the output MLP (None means no hidden layers).
        freeze_trans: If True, ``requires_grad`` of every transformer parameter is
            set according to the ``freeze_*`` flags below; parameters matching
            none of the groups are frozen. If False, the transformer is left as is.
        freeze_in: Freeze all parameters of the input network.
        freeze_pos: Freeze parameters whose name contains 'wpe'.
        freeze_ln: Freeze parameters whose name contains 'ln'.
        freeze_attn: Freeze parameters whose name contains 'attn'.
        freeze_ff: Freeze parameters whose name contains 'mlp'.
        freeze_out: Freeze all parameters of the output network.
        dropout: Dropout probability used in the input and output networks.
        orth_gain: Gain for orthogonal initialisation of the input network's
            linear layers; None keeps PyTorch's default weight initialisation.

    Raises:
        NotImplementedError: If ``model_name`` does not contain 'gpt'.
    """

    def __init__(
            self,
            input_dim,
            output_dim,
            model_name='gpt2',
            pretrained=False,
            return_last_only=True,
            use_embeddings_for_in=False,
            in_layer_sizes=None,
            out_layer_sizes=None,
            freeze_trans=True,
            freeze_in=False,
            freeze_pos=False,
            freeze_ln=False,
            freeze_attn=True,
            freeze_ff=True,
            freeze_out=False,
            dropout=0.1,
            orth_gain=1.41,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.model_name = model_name
        self.return_last_only = return_last_only
        self.use_embeddings_for_in = use_embeddings_for_in

        self.in_layer_sizes = [] if in_layer_sizes is None else in_layer_sizes
        self.out_layer_sizes = [] if out_layer_sizes is None else out_layer_sizes
        self.dropout = dropout

        if 'gpt' in model_name:
            assert model_name in ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']

            from transformers import GPT2Config, GPT2Model

            if pretrained:
                self.transformer = GPT2Model.from_pretrained(model_name)
            else:
                # Only the architecture is needed for a randomly initialised model,
                # so fetch the config alone instead of loading (and then discarding)
                # the full pretrained checkpoint.
                self.transformer = GPT2Model(GPT2Config.from_pretrained(model_name))

            # Hidden width of the chosen GPT-2 variant (768 for 'gpt2', 1024 for
            # 'gpt2-medium', 1280 for 'gpt2-large', 1600 for 'gpt2-xl').
            embedding_size = self.transformer.config.n_embd

        else:
            raise NotImplementedError('model_name not implemented')

        if use_embeddings_for_in:
            self.in_net = nn.Embedding(input_dim, embedding_size)
        else:
            in_layers = []
            last_output_size = input_dim
            for size in self.in_layer_sizes:
                layer = nn.Linear(last_output_size, size)
                if orth_gain is not None:
                    torch.nn.init.orthogonal_(layer.weight, gain=orth_gain)
                layer.bias.data.zero_()

                in_layers.append(layer)
                in_layers.append(nn.ReLU())
                in_layers.append(nn.Dropout(dropout))
                last_output_size = size

            # Final projection into the transformer's embedding space.
            final_linear = nn.Linear(last_output_size, embedding_size)
            if orth_gain is not None:
                torch.nn.init.orthogonal_(final_linear.weight, gain=orth_gain)
            final_linear.bias.data.zero_()

            in_layers.append(final_linear)
            in_layers.append(nn.Dropout(dropout))

            self.in_net = nn.Sequential(*in_layers)

        out_layers = []
        last_output_size = embedding_size
        for size in self.out_layer_sizes:
            out_layers.append(nn.Linear(last_output_size, size))
            out_layers.append(nn.ReLU())
            out_layers.append(nn.Dropout(dropout))
            last_output_size = size
        out_layers.append(nn.Linear(last_output_size, output_dim))
        self.out_net = nn.Sequential(*out_layers)

        if freeze_trans:
            # Parameter groups are identified by substrings of the parameter name;
            # the first matching group in the chain below decides.
            for name, p in self.transformer.named_parameters():
                name = name.lower()
                if 'ln' in name:
                    p.requires_grad = not freeze_ln
                elif 'wpe' in name:
                    p.requires_grad = not freeze_pos
                elif 'mlp' in name:
                    p.requires_grad = not freeze_ff
                elif 'attn' in name:
                    p.requires_grad = not freeze_attn
                else:
                    p.requires_grad = False
        if freeze_in:
            for p in self.in_net.parameters():
                p.requires_grad = False
        if freeze_out:
            for p in self.out_net.parameters():
                p.requires_grad = False

    def forward(self, x, output_attentions=False):
        """Run the input network, the transformer and the output network.

        When the last dimension of ``x`` differs from ``input_dim`` (and linear
        input layers are used) every token is split into ``ratio`` consecutive
        tokens of size ``input_dim``; the split is undone on the output when
        ``return_last_only`` is True.

        Args:
            x: Input tensor; its first two dimensions are used as batch and
                sequence when tokens have to be split.
            output_attentions: Also return the transformer's attention maps.

        Returns:
            The output tensor, or ``(output, attentions)`` when
            ``output_attentions`` is True.

        Raises:
            ValueError: If the token size is not a multiple of ``input_dim``.
        """

        orig_dim = x.shape[-1]
        if orig_dim != self.input_dim and not self.use_embeddings_for_in:
            if orig_dim % self.input_dim != 0:
                raise ValueError('dimension of x must be divisible by patch size')
            ratio = orig_dim // self.input_dim
            x = x.reshape(x.shape[0], x.shape[1] * ratio, self.input_dim)
        else:
            ratio = 1

        x = self.in_net(x)

        transformer_outputs = self.transformer(
            inputs_embeds=x,
            return_dict=True,
            output_attentions=output_attentions,
        )
        x = transformer_outputs.last_hidden_state

        if self.return_last_only:
            # Keep the trailing `ratio` tokens, i.e. the last original token.
            x = x[:,-ratio:]

        x = self.out_net(x)
        if self.return_last_only and ratio > 1:
            # Merge the split tokens back into one token per original position.
            x = x.reshape(x.shape[0], x.shape[1] // ratio, ratio * self.output_dim)

        if output_attentions:
            return x, transformer_outputs.attentions
        else:
            return x

## GPT Actor-Critic

In [None]:
def ortho_init(module, gain):
    """Orthogonally initialise the weight of a Linear/Conv2d module and zero its bias.

    Modules of any other type are left untouched, which makes the function
    usable with ``nn.Module.apply``.
    """
    if not isinstance(module, (nn.Linear, nn.Conv2d)):
        return

    nn.init.orthogonal_(module.weight, gain=gain)

    bias = module.bias
    if bias is not None:
        bias.data.zero_()


class GPTActorCritic(nn.Module):
    """Actor-critic network that uses an ``FPT`` model as shared feature extractor.

    Image observations are cut into square patches, each patch becomes one
    token for the ``FPT`` model, and the resulting feature vector feeds a
    categorical policy head and a scalar value head.

    Args:
        input_dim: Token size passed to ``FPT`` as ``input_dim``.
        out_dim: Feature size produced by ``FPT`` and consumed by both heads.
        n_actions: Number of discrete actions (size of the policy logits).
        patch_size: Side length, in pixels, of the square patches.
        device: Device the prepared observations are moved to.

    All remaining arguments are forwarded unchanged to ``FPT``.
    """

    def __init__(self,
                 input_dim,
                 out_dim,
                 n_actions,
                 patch_size,
                 device='cuda',
                 model_name='gpt2',
                 pretrained=False,
                 return_last_only=True,
                 use_embeddings_for_in=False,
                 in_layer_sizes=None,
                 out_layer_sizes=None,
                 freeze_trans=True,
                 freeze_in=False,
                 freeze_pos=False,
                 freeze_ln=False,
                 freeze_attn=True,
                 freeze_ff=True,
                 freeze_out=False,
                 dropout=0.1,
                 orth_gain=1.41):

        super().__init__()

        self.patch_size = patch_size
        self.device = device

        self.fpt = FPT(input_dim,
                       out_dim,
                       model_name,
                       pretrained,
                       return_last_only,
                       use_embeddings_for_in,
                       in_layer_sizes,
                       out_layer_sizes,
                       freeze_trans,
                       freeze_in,
                       freeze_pos,
                       freeze_ln,
                       freeze_attn,
                       freeze_ff,
                       freeze_out,
                       dropout,
                       orth_gain)

        self.policy_head = nn.Linear(out_dim, n_actions) #TODO output dim
        self.value_head  = nn.Linear(out_dim, 1)

        self.init_head_weights()

    def forward(self, obs):
        """Return the action distribution and the value estimate for a batch.

        Args:
            obs: Batch of image observations, see ``prepare_obs``.

        Returns:
            ``(policy, value)`` where ``policy`` is a ``Categorical`` over the
            actions and ``value`` is the output of the value head.
        """
        obs = self.prepare_obs(obs) # (bs, n_patches, features_dim)

        features = self.fpt(obs).squeeze(1) # (bs, features)
        policy = Categorical(logits=self.policy_head(features))
        value  = self.value_head(features)
        return policy, value

    def prepare_obs(self, obs):
        """
        Prepares image observations for the GPT feature extractor.

        input: np.array or torch.tensor (bs, h, w, c) holding raw pixel values
        returns: torch.tensor (bs, n_patches, features_dim) on ``self.device``

        The 1/255 scaling is applied to every input. Previously it was skipped
        for tensors, but the rollout buffer in this file stores observations
        exactly as the environment returned them and hands them back as
        tensors, so PPO updates ran on unscaled inputs while action selection
        ran on scaled ones.
        """
        obs = torch.as_tensor(obs, dtype=torch.float32) / 255.

        # tokenise: one token per (patch_size x patch_size) patch, with the patch's
        # pixels and channels flattened in (p1 p2 c) order
        obs = rearrange(obs, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)
        return obs.to(self.device)

    def init_head_weights(self):
        """
        Orthogonal initialization for policy and value head with OpenAI baseline gains
        """
        module_gains = {
                self.policy_head: 0.01,
                self.value_head: 1,
        }
        for module, gain in module_gains.items():
            module.apply(partial(ortho_init, gain=gain))

## SlimeVolley Environment

In [None]:
def _make_wrapped_env(env_id, seed):
    """Create the environment `env_id` with the shared wrapper stack and seed it."""
    # almost the same as typical Atari processing for CNN agent.
    env = gym.make(env_id)
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    env = WarpFrame(env)
    #env = ClipRewardEnv(env)
    env = FrameStack(env, 4)
    env.seed(seed)
    return env


def make_env(seed):
    """Build the seeded training environment ("SlimeVolleySurvivalNoFrameskip-v0")."""
    return _make_wrapped_env("SlimeVolleySurvivalNoFrameskip-v0", seed)


def make_eval_env(seed):
    """Build the seeded evaluation environment ("SlimeVolleyNoFrameskip-v0")."""
    return _make_wrapped_env("SlimeVolleyNoFrameskip-v0", seed)

## PPO

### Rollout Buffer

In [None]:
from collections import namedtuple

# One minibatch of rollout data as handed to the PPO update.
RolloutBufferSamples = namedtuple(
    'RolloutBufferSamples',
    ('observations', 'actions', 'old_values', 'old_log_probs', 'advantages', 'returns'),
)

class RolloutBuffer:
    """Fixed-size store for one on-policy rollout of a single environment.

    Transitions are written with ``add`` until ``full`` becomes True, then
    ``compute_returns_and_advantage`` fills in advantages and returns and
    ``get`` yields shuffled minibatches as torch tensors on ``device``.
    """

    def __init__(
        self,
        buffer_size,
        observation_size,
        device = "cpu",
        gae_lambda = 1,
        gamma = 0.99
    ):
        self.device = device
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.observation_size = observation_size
        self.buffer_size = buffer_size

        self.reset()

    def reset(self):
        """Rewind the write position and replace all storage with fresh zero arrays."""
        self.pos = 0
        self.full = False
        self.generator_ready = False

        def column():
            # One float32 scalar slot per stored step.
            return np.zeros((self.buffer_size, 1), dtype=np.float32)

        self.observations = np.zeros((self.buffer_size,) + self.observation_size, dtype=np.float32)
        self.actions      = column()
        self.rewards      = column()
        self.returns      = column()
        self.dones        = column()
        self.values       = column()
        self.log_probs    = column()
        self.advantages   = column()

    def compute_returns_and_advantage(self, last_values, dones):
        """Fill ``advantages`` and ``returns`` by a backward pass over the buffer.

        Args:
            last_values: Torch tensor with the value estimate used to bootstrap
                after the final stored step.
            dones: Done flag belonging to that bootstrap step; the bootstrap
                value is multiplied by ``1.0 - dones``.
        """
        bootstrap_values = last_values.clone().cpu().numpy().flatten()
        running_gae = 0
        final_step = self.buffer_size - 1

        for step in range(final_step, -1, -1):
            if step == final_step:
                not_terminal = 1.0 - dones
                following_values = bootstrap_values
            else:
                not_terminal = 1.0 - self.dones[step + 1]
                following_values = self.values[step + 1]
            td_error = self.rewards[step] + self.gamma * following_values * not_terminal - self.values[step]
            running_gae = td_error + self.gamma * self.gae_lambda * not_terminal * running_gae
            self.advantages[step] = running_gae

        self.returns = self.advantages + self.values

    def add(self, obs, action, reward, done, value, log_prob):
        """Write one transition at the current position and advance it.

        ``value`` and ``log_prob`` are torch tensors; the other arguments are
        converted with ``np.array``. ``full`` is set once ``buffer_size``
        transitions have been written.
        """
        if log_prob.ndim == 0:
            # Give scalar log-probabilities an explicit shape before storing.
            log_prob = log_prob.reshape(-1, 1)

        slot = self.pos
        array_fields = (
            (self.observations, obs),
            (self.actions, action),
            (self.rewards, reward),
            (self.dones, done),
        )
        for storage, item in array_fields:
            storage[slot] = np.array(item).copy()
        self.values[slot]    = value.clone().cpu().numpy().flatten()
        self.log_probs[slot] = log_prob.clone().cpu().numpy()

        self.pos = slot + 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size):
        """Yield ``RolloutBufferSamples`` minibatches covering a random permutation of the buffer."""
        shuffled = np.random.permutation(self.buffer_size)
        cursor = 0
        while cursor < self.buffer_size:
            stop = cursor + batch_size
            yield self._get_samples(shuffled[cursor:stop])
            cursor = stop

    def _get_samples(self, batch_inds):
        """Gather the rows at ``batch_inds`` and convert every field to a tensor."""
        return RolloutBufferSamples(
            observations=self.to_torch(self.observations[batch_inds]),
            actions=self.to_torch(self.actions[batch_inds]),
            old_values=self.to_torch(self.values[batch_inds].flatten()),
            old_log_probs=self.to_torch(self.log_probs[batch_inds].flatten()),
            advantages=self.to_torch(self.advantages[batch_inds].flatten()),
            returns=self.to_torch(self.returns[batch_inds].flatten()),
        )

    def to_torch(self, array):
        """Copy a numpy array into a torch tensor on ``self.device``."""
        tensor = torch.tensor(array)
        return tensor.to(self.device)

In [None]:
def explained_variance(y_pred, y_true):
    """Return 1 - Var[y_true - y_pred] / Var[y_true] for two 1-D arrays.

    The result is NaN when ``y_true`` has zero variance.
    """
    assert y_true.ndim == 1 and y_pred.ndim == 1
    target_variance = np.var(y_true)
    if target_variance == 0:
        return np.nan
    residual_variance = np.var(y_true - y_pred)
    return 1 - residual_variance / target_variance

### Evaluation and Visualisation

In [None]:
def render_atari(obs):
    """
    Helper function that takes in a processed obs (84,84,4)
    Useful for visualizing what an Atari agent actually *sees*
    Outputs in Atari visual format (Top: resized to orig dimensions, bottom: 4 frames)

    NOTE(review): pixel values are multiplied by 255.0 before the uint8 cast,
    which looks like it expects inputs in [0, 1], while the separator value 141
    suggests a 0-255 scale — confirm against the environment's output range.
    """
    def to_rgb(gray, size):
        # Replicate the single channel three times, scale, cast and resize.
        gray = np.expand_dims(gray, axis=2)
        rgb = np.concatenate([gray*255.0] * 3, axis=2).astype(np.uint8)
        return cv2.resize(rgb, size, interpolation=cv2.INTER_NEAREST)

    frames = np.copy(obs)
    # The enlarged top panel shows the newest frame without a separator line,
    # so copy it before the lines are drawn.
    latest = np.copy(frames[:, :, 3])
    frames[:, 0, 1:4] = 141 # insert vertical lines
    strip = np.concatenate([frames[:, :, idx] for idx in range(4)], axis=1)

    top = to_rgb(latest, (84 * 8, 84 * 4))
    bottom = to_rgb(strip, (84 * 8, 84 * 2))
    return np.concatenate([top, bottom], axis=0)

In [None]:
@torch.no_grad()
def evaluate_agent(agent, env, n_episodes=100, render=False, gif_path=None):
    """Play `n_episodes` episodes with actions sampled from the policy.

    Args:
        agent: Callable returning ``(policy, value)`` for a batch of observations.
        env: Environment with the gym ``reset``/``step`` interface.
        n_episodes: Number of episodes to play.
        render: Show every frame in an on-screen viewer.
        gif_path: Directory in which one ``rollout_<i>.gif`` per episode is
            written; None disables GIF export.

    Returns:
        Mean of the per-episode total rewards.
    """
    echo('Evaluating Agent...')

    total_rewards = []

    save_gifs = gif_path is not None
    if save_gifs:
        print(f'Writing gifs to {gif_path}...')
        gif_path = Path(gif_path)
        # Also works for nested target directories and when the directory exists.
        gif_path.mkdir(parents=True, exist_ok=True)

    viewer = rendering.SimpleImageViewer(maxwidth=2160) if render else None

    try:
        for i in range(n_episodes):

            rollout_frames = []

            obs  = env.reset()
            done = False
            total_reward = 0.

            while not done:

                policy, _ = agent(obs[None])
                #action = policy.logits.argmax(dim=1) # exploit at test time
                action = policy.sample()

                action = action.cpu().numpy()
                obs, reward, done, _ = env.step(action.item())

                if render or save_gifs:
                    frame = render_atari(obs)

                if render:
                    viewer.imshow(frame)
                    sleep(0.08)

                if save_gifs:
                    rollout_frames.append(np.einsum('hwc->chw', frame))

                total_reward += reward

            total_rewards.append(total_reward)
            echo(f'Episode {i+1} done.')

            if save_gifs:
                # Write each episode immediately so that only one episode's
                # frames are held in memory at a time.
                write_gif(rollout_frames, gif_path/f'rollout_{i+1}.gif', fps=5)
                echo(f'Rollout {i+1}/{n_episodes} done.')
    finally:
        # Close the window even if the rollout is interrupted or raises.
        if viewer is not None:
            viewer.close()

    return np.mean(total_rewards)

## Training

In [None]:
# reproducibility
def set_random_seeds(seed):
    """Seed Python's `random`, NumPy and PyTorch and set the cuDNN determinism flags."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # for cudnn
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    
def echo(message):
    """Print `message` padded with 1000 spaces and end with a carriage return.

    The carriage return lets the next status message overwrite this one.
    """
    padding = ' ' * 1000
    print(message + padding, end='\r')

In [None]:
# Hyperparameters

## General
SEED = 101
# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
MAX_STEPS = 1_000_000
EVAL_FREQ = 10_000
EVAL_EPISODES = 100 # make sure to rerun on 1000 episodes afterwards to measure that performance, don't use only 100.

## PPO
LEARNING_RATE = 3e-4
MAX_GRAD_NORM = 0.5
VALUE_COEF = 0.5
ENTROPY_COEF = 0.0
TARGET_KL = None
BATCH_SIZE = 12
PPO_CLIP_RANGE = 0.2
BUFFER_SIZE = 4096 # 2048
PPO_EPOCHS = 4 #10
DISCOUNT = 0.99
TRACE_DECAY = 0.95

## GPT Actor Critic
PATCH_SIZE = 4
# Token size: 4 channels times PATCH_SIZE x PATCH_SIZE pixels per patch.
INPUT_SIZE = 4 * PATCH_SIZE**2
OUT_SIZE = 64

## LOGGING
LOGGING_MODE = 'online'

In [None]:
# Seed everything first so that environment creation and weight initialisation
# below are reproducible.
set_random_seeds(SEED)

# Environments: one for collecting training rollouts, one for evaluation.
env = make_env(SEED)
eval_env = make_eval_env(SEED)

# Model: pretrained GPT-2 with the attention and feed-forward weights frozen;
# the remaining flags are passed as False.
agent = GPTActorCritic(
    input_dim=INPUT_SIZE,
    out_dim=OUT_SIZE,
    n_actions=env.action_space.n,
    patch_size=PATCH_SIZE,
    device=DEVICE,
    pretrained=True,
    freeze_trans=True,
    freeze_in=False,
    freeze_pos=False,
    freeze_ln=False,
    freeze_attn=True,
    freeze_ff=True,
    freeze_out=False,
).to(DEVICE)

optimizer = optim.Adam(agent.parameters(), lr=LEARNING_RATE)

# Rollout buffer holding BUFFER_SIZE transitions between PPO updates.
rollout_buffer = RolloutBuffer(
    BUFFER_SIZE,
    env.observation_space.shape,
    device=DEVICE,
    gae_lambda=TRACE_DECAY,
    gamma=DISCOUNT,
)

In [None]:
# PPO training loop: collect BUFFER_SIZE environment steps, run PPO_EPOCHS
# epochs of minibatch updates on them, and evaluate every EVAL_FREQ steps.
with wandb.init(project="rl-gpt", mode=LOGGING_MODE):

    last_obs  = env.reset()
    last_done = False
    total_reward = 0.

    # Rollouts, bootstrap values and evaluations all run in eval mode so that the
    # stored log-probs and values are not perturbed by dropout; train mode is
    # switched on only for the duration of the PPO update below.
    agent.eval()

    pbar = tqdm(range(MAX_STEPS), unit_scale=1, smoothing=0)
    for step in pbar:

        with torch.no_grad():
            policy, value = agent(last_obs[None])
            action = policy.sample()
            log_prob = policy.log_prob(action)

        action = action.cpu().numpy()
        obs, reward, done, _ = env.step(action.item())

        # `last_done` marks whether `last_obs` is the first observation of an episode.
        rollout_buffer.add(last_obs, action, reward, last_done, value, log_prob)

        last_obs = obs
        last_done = done
        total_reward += reward

        if done:
            last_obs = env.reset()
            pbar.set_description('Step: %i | Reward: %f' % (step, total_reward))
            # Every metric is logged against the environment step so that all
            # charts share one x-axis.
            wandb.log({'total_reward': total_reward}, step=step)
            total_reward = 0.

        # evaluate policy
        if step % EVAL_FREQ == 0:
            agent.eval()
            mean_eval_reward = evaluate_agent(agent, eval_env, EVAL_EPISODES)
            wandb.log({'mean_eval_reward': mean_eval_reward}, step=step)

            # early stopping
            if mean_eval_reward >= 0.0:
                echo(f'Mean Eval Reward: {mean_eval_reward}. Early stopping...')
                break

        # train policy
        if rollout_buffer.full:

            with torch.no_grad():
                # Compute value for the last timestep (still in eval mode)
                _, value = agent(obs[None])

            # compute returns and advantages
            rollout_buffer.compute_returns_and_advantage(value, done)

            agent.train()

            clip_range = PPO_CLIP_RANGE

            # containers for collecting training stats
            entropy_losses = []
            all_kl_divs    = []  # one mean KL estimate per epoch
            policy_losses  = []
            value_losses   = []
            clip_fractions = []
            approx_kl_divs = []  # one KL estimate per minibatch, over all epochs

            # Start Training
            for epoch in range(PPO_EPOCHS):
                echo(f'Step: {step}; Epoch: {epoch}; Starting PPO Training...')
                epoch_kl_divs = []
                for rollout_data in rollout_buffer.get(BATCH_SIZE):
                    actions = rollout_data.actions.long().flatten()

                    # evaluate actions
                    policy, value = agent(rollout_data.observations)
                    log_prob = policy.log_prob(actions)
                    entropy = policy.entropy()
                    value = value.flatten()

                    # Extra step: normalize advantage. The standard deviation of a
                    # single sample is NaN, so a one-element minibatch is left as is.
                    advantages = rollout_data.advantages
                    if advantages.numel() > 1:
                        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                    # ratio between old and new policy
                    ratio = torch.exp(log_prob - rollout_data.old_log_probs)

                    # clipped surrogate loss
                    policy_loss_1 = advantages * ratio
                    policy_loss_2 = advantages * torch.clamp(ratio, 1 - clip_range, 1 + clip_range)
                    policy_loss   = -torch.min(policy_loss_1, policy_loss_2).mean()

                    # Value loss using the TD(gae_lambda) target
                    value_loss = F.mse_loss(value, rollout_data.returns)

                    entropy_loss = -torch.mean(entropy)

                    loss = policy_loss + ENTROPY_COEF * entropy_loss + VALUE_COEF * value_loss

                    # Optimization step
                    optimizer.zero_grad()
                    loss.backward()
                    clip_grad_norm_(agent.parameters(), MAX_GRAD_NORM)
                    optimizer.step()

                    # Logging
                    epoch_kl_divs.append(torch.mean(rollout_data.old_log_probs - log_prob).item())
                    policy_losses.append(policy_loss.item())
                    clip_fraction = torch.mean((torch.abs(ratio - 1) > clip_range).float()).item()
                    clip_fractions.append(clip_fraction)
                    value_losses.append(value_loss.item())
                    entropy_losses.append(entropy_loss.item())

                echo(f'PPO Epoch {epoch} done.')

                # The early-stopping check uses this epoch's KL only, so earlier
                # epochs cannot dilute a large policy change.
                epoch_kl = np.mean(epoch_kl_divs)
                approx_kl_divs.extend(epoch_kl_divs)
                all_kl_divs.append(epoch_kl)

                if TARGET_KL is not None and epoch_kl > 1.5 * TARGET_KL:
                    print(f'Early stopping at step {step} due to reaching max kl: {epoch_kl:.2f}')
                    break

            # Back to eval mode for the next rollout phase.
            agent.eval()

            explained_var = explained_variance(rollout_buffer.values.flatten(), rollout_buffer.returns.flatten())

            # Empty the rollout buffer
            rollout_buffer.reset()

            # Log Training Progress
            wandb.log({
                "entropy_loss": np.mean(entropy_losses),
                "policy_gradient_loss": np.mean(policy_losses),
                "value_loss": np.mean(value_losses),
                "approx_kl": np.mean(approx_kl_divs),
                "clip_fraction": np.mean(clip_fractions),
                "loss": loss.item(),
                "explained_variance": explained_var,
            }, step=step)

## Visualise Trained Agent

In [None]:
# Switch to eval mode, then play a single episode on screen and save it as a GIF
# under ./gifs; the cell displays the episode's mean reward.
agent.eval()
evaluate_agent(agent=agent, env=eval_env, n_episodes=1, render=True, gif_path='./gifs')