In [410]:
import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import Trainer, Callback, seed_everything
from pytorch_lightning.core.lightning import LightningModule

from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3 import PPO, DQN
from stable_baselines3.common.env_util import make_vec_env

from src.environments import Scratch_Pad_Environment, AddGymEnv
from src.FasterMCTS import FasterMCTS
from src.RL_trainer import AlphaZero_Trainer

import gym
from transformers import BertConfig, BertModel
from transformers.modeling_outputs import CausalLMOutput
from transformers.generation_utils import GenerationMixin

%load_ext autoreload
%autoreload 2
%load_ext line_profiler
%load_ext memory_profiler

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler
The memory_profiler extension is already loaded. To reload it, use:
  %reload_ext memory_profiler


In [414]:
# Vectorised training environment: N_ENVS AddGymEnv workers, each running in its own subprocess.
# Seed every RNG (python/numpy/torch) once, and seed each worker (seed + rank) for reproducible rollouts.
SEED = 0
N_ENVS = 20
MAX_TOKEN_LENGTH = 13
MAX_VAL = 1

seed_everything(SEED)
env = make_vec_env(
    AddGymEnv,
    n_envs=N_ENVS,
    seed=SEED,
    vec_env_cls=SubprocVecEnv,
    env_kwargs=dict(max_token_length=MAX_TOKEN_LENGTH, max_val=MAX_VAL),
)

In [415]:
# BERT configuration for the policy's feature extractor, used as a causal (decoder-style) model.
config = BertConfig(
    num_attention_heads=8,
    hidden_size=256,               # must be divisible by num_attention_heads (32-dim heads here)
    num_hidden_layers=12,
    intermediate_size=1024,
    max_position_embeddings=512,   # the positional-size field BERT actually reads
    is_decoder=True,               # BertModel applies a causal self-attention mask
    vocab_size=env.action_space.n, # one token id per discrete action
    position_embedding_type='relative_key_query',
)
# GPT-2-style names kept only in case other code reads them; BertModel does not use them.
config.n_ctx = 512
config.n_positions = 512
# Custom attribute, not a standard BertConfig field.
config.temp = 1

In [416]:
class CausalBERTFeatureExtractor(BaseFeaturesExtractor):
    """
    SB3 feature extractor that runs a ``BertModel`` over a token sequence and
    returns the hidden state of the last non-padding token.

    :param observation_space: (gym.spaces.MultiDiscrete) observation space. ``forward``
        assumes SB3 hands it concatenated one-hot blocks, each ``config.vocab_size`` wide.
    :param config: (BertConfig) configuration for the underlying ``BertModel``. Its
        ``hidden_size`` becomes this extractor's ``features_dim``. Required.
    :param pad_id: (int) token id treated as padding and masked out of attention.
    """

    def __init__(self, observation_space: gym.spaces.MultiDiscrete, config=None, pad_id=0):
        if config is None:
            raise ValueError("CausalBERTFeatureExtractor requires a BertConfig passed as `config`")
        super(CausalBERTFeatureExtractor, self).__init__(observation_space, config.hidden_size)
        self.transformer = BertModel(config)
        self.config = config
        self.pad_id = pad_id

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """
        :param observations: (batch_size, n_tokens * vocab_size) one-hot encoded tokens.
        :return: (batch_size, hidden_size) hidden state at the last non-padding position.
        """
        batch_size = observations.shape[0]
        one_hot_tokens = observations.reshape(batch_size, -1, self.config.vocab_size)
        input_ids = one_hot_tokens.argmax(dim=-1)

        # HF documents attention_mask as 0/1 integers; cast from the boolean comparison.
        attention_mask = (input_ids != self.pad_id).long()
        # The last real token sits at (count of non-pad tokens - 1). This assumes padding is trailing.
        # Clamp so an all-padding row picks position 0 instead of wrapping to -1.
        last_token_positions = (attention_mask.sum(dim=1) - 1).clamp(min=0)

        hidden_states = self.transformer(input_ids, attention_mask=attention_mask)[0]

        # Build the batch index on the same device as the hidden states.
        batch_index = torch.arange(batch_size, device=hidden_states.device)
        return hidden_states[batch_index, last_token_positions, :]

In [417]:
class CausalBertLMPolicyWrapper(LightningModule, GenerationMixin):
    """
    Expose an SB3 policy (BERT feature extractor + action head) as a causal LM,
    so that HuggingFace ``generate`` can decode token sequences with it.

    :param policy: SB3 policy whose ``features_extractor`` has ``transformer`` and ``config``
        attributes (as ``CausalBERTFeatureExtractor`` does). Its ``action_net`` maps hidden
        states to per-token logits.
    :param kwargs: forwarded to ``LightningModule.__init__``.
    """

    def __init__(self, policy, **kwargs):
        super().__init__(**kwargs)

        self.policy = policy
        # These share modules/parameters with `policy`; they are not copies.
        self.action_head = policy.action_net
        self.value_head = policy.value_net
        self.vocab_size = policy.action_space.n
        self.config = policy.features_extractor.config
        # Follow the wrapped policy's own device instead of reading a global `learner`.
        self.to(policy.device)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """
        :param input_ids: int tensor (batch_size, sequence_length).
        :param attention_mask: optional 0/1 mask of the same shape, passed to the transformer.
        :param kwargs: any other ``generate`` inputs; accepted and ignored.
        :return: ``CausalLMOutput`` whose ``logits`` are (batch_size, sequence_length, vocab_size).
        """
        transformer_outputs = self.policy.features_extractor.transformer(
            input_ids, attention_mask=attention_mask)
        hidden_states = transformer_outputs[0]
        lm_logits = self.action_head(hidden_states)
        return CausalLMOutput(logits=lm_logits)

In [418]:
# NOTE: needs `learner` from the PPO cell further down (it was executed first: In[422] vs In[418]).
# .eval() turns off BERT dropout so that greedy generate() is deterministic. This also switches the
# shared learner.policy to eval mode; SB3's PPO is expected to set training mode itself when training
# (TODO confirm for the installed SB3 version).
lm_model = CausalBertLMPolicyWrapper(learner.policy).eval()

In [420]:
# Smoke test: a batch of 4 random 7-token prompts (ids in [0, 100)), created on the model's device
# instead of a hard-coded 'cuda'. The logits should be (batch, seq_len, vocab_size).
indices = torch.randint(0, 100, size=(4, 7), device=lm_model.device)
lm_model(indices).logits.shape

torch.Size([4, 7, 110])

In [421]:
# Decode continuations of the random prompts, up to the config's default max_length.
lm_model.generate(indices, max_length=lm_model.config.max_length)

tensor([[ 77,  15,  71,  35,  91,  74,  25,  58,  23,  23,  24,  50,  51,  51,
          23,  77,  75,  11,  40,   6],
        [ 95,  65,  52,  72,  95,  22,  27,  32,  24,  60,  93,  90,  24,  50,
         103,  23, 103,  38,  52,  90],
        [ 60,  27,  16,  76,  89,  76,  71,  50,  38,  94,  50,  95,  38,   9,
          26,  18,  38,  38,  77,  86],
        [ 70,  47,  44,  59,  69,  99,  78,  90, 103, 105,  50,  79,  11,  67,
          92,  50,  50,   6,  78,  91]], device='cuda:0')

In [422]:
# PPO learner whose policy encodes observations with the BERT feature extractor defined above.
# Tie the extractor's padding id to the BERT config (which uses it as the embedding padding index)
# rather than repeating a literal.
pad_id = config.pad_token_id
policy_kwargs = dict(
    features_extractor_class=CausalBERTFeatureExtractor,
    features_extractor_kwargs=dict(config=config, pad_id=pad_id),
)

learner = PPO(
    'MlpPolicy',
    env,
    policy_kwargs=policy_kwargs,
    n_steps=256,     # steps per env per rollout; the rollout holds n_steps * env.num_envs transitions
    batch_size=256,  # minibatch size for each PPO gradient step
    verbose=10,      # SB3 documents 0/1/2 (none/info/debug); 10 presumably behaves like 2
)

Using cuda device


In [278]:
# Train PPO for 100k environment steps in total, counted across all vectorised envs.
learner.learn(total_timesteps=100_000)

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 4.95     |
|    ep_rew_mean     | -1       |
| time/              |          |
|    fps             | 1220     |
|    iterations      | 1        |
|    time_elapsed    | 4        |
|    total_timesteps | 5120     |
| train/             |          |
|    learning_rate   | 0.0003   |
---------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 4.9        |
|    ep_rew_mean          | -1         |
| time/                   |            |
|    fps                  | 433        |
|    iterations           | 2          |
|    time_elapsed         | 23         |
|    total_timesteps      | 10240      |
| train/                  |            |
|    approx_kl            | 0.01352766 |
|    clip_fraction        | 0.319      |
|    clip_range           | 0.2        |
|    entropy_loss         | -5.59      |
|    explained_var

<stable_baselines3.ppo.ppo.PPO at 0x7f3c9267aba8>

In [430]:
# Roll out the trained policy for 10 steps on a single (non-vectorised) env and render each step.
# Use `eval_env` so this cell does not overwrite the vectorised training `env` that PPO uses.
eval_env = AddGymEnv(max_token_length=13, max_val=1)
obs = eval_env.reset()
for _ in range(10):
    action, _ = learner.predict(obs)
    obs, reward, done, info = eval_env.step(action)
    eval_env.render()
    if done:
        # Keep the fresh observation; otherwise the next predict() acts on the terminal state.
        obs = eval_env.reset()

[BOS]What is 0+0?)
[BOS]What is 0+0?)`
[BOS]What is 0+0?)`[VALUE]
[BOS]What is 0+0?)`[VALUE]5
[BOS]What is 0+0?)`[VALUE]5'
[BOS]What is 0+0?j
[BOS]What is 0+0?jK
[BOS]What is 0+0?jKa
[BOS]What is 0+0?jKa;
[BOS]What is 0+0?jKa;n
