In [1]:
import yaml
from datasets import load_dataset
from transformers import pipeline
import pathlib
from typing import Dict, List
import trlx
from trlx.data.default_configs import TRLConfig, default_ppo_config

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
default_config = default_ppo_config().to_dict()
default_config['train']['total_steps'] = 10
default_config['train']['tracker'] = 'tensorboard'
default_config['method']['chunk_size'] = 2

config = TRLConfig.update(default_config, {})

config.model.model_path = 'lvwerra/gpt2-imdb'

In [3]:
def reward_fn(samples: List[str], **kwargs) -> List[float]:
    sentiments = list(map(get_positive_score, sentiment_fn(samples)))
    return sentiments
    
def get_positive_score(scores):
    "Extract value associated with a positive sentiment from pipeline's output"
    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]

sentiment_fn = pipeline(
    "sentiment-analysis",
    "lvwerra/distilbert-imdb",
    top_k=2,
    truncation=True,
    batch_size=256,
    device=0,
)

imdb = load_dataset("imdb", split="train+test")
prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

Found cached dataset imdb (/home/sshuser/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


In [4]:
from trlx.utils.loading import get_pipeline, get_trainer

In [5]:
trainer = get_trainer(config.train.trainer)(
    config=config,
    reward_fn=reward_fn,
    **config.train.trainer_kwargs,
)

[RANK 0] Initializing model: lvwerra/gpt2-imdb


{
    "method": {
        "name": "PPOConfig",
        "ppo_epochs": 4,
        "num_rollouts": 128,
        "chunk_size": 2,
        "init_kl_coef": 0.05,
        "target": 6,
        "horizon": 10000,
        "gamma": 1,
        "lam": 0.95,
        "cliprange": 0.2,
        "cliprange_value": 0.2,
        "vf_coef": 1,
        "scale_reward": "ignored",
        "ref_mean": null,
        "ref_std": null,
        "cliprange_reward": 10,
        "gen_kwargs": {
            "max_new_tokens": 40,
            "top_k": 0,
            "top_p": 1.0,
            "do_sample": true
        },
        "gen_experience_kwargs": null
    },
    "model": {
        "model_path": "lvwerra/gpt2-imdb",
        "model_arch_type": "causal",
        "num_layers_unfrozen": 2,
        "delta_kwargs": null
    },
    "optimizer": {
        "name": "adamw",
        "kwargs": {
            "lr": 0.0001,
            "betas": [
                0.9,
                0.95
            ],
            "eps": 1e-08,
   

In [6]:
from trlx.utils import set_seed
set_seed(config.train.seed)

In [7]:
# freeze 10 h-layers, and unfreeze the last 2 h-layers.
trainer.model

AutoModelForCausalLMWithHydraValueHead(
  (base_model): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  )
  (v

In [8]:
batch_size = config.train.batch_size
batch_size

32

In [9]:
max_prompt_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"]
max_prompt_length

984

In [10]:
prompts = prompts or [trainer.tokenizer.bos_token] * batch_size
prompts = prompts[:2]
prompts

['I rented I AM', '"I Am Curious: Yellow"']

In [11]:
import pprint
pprint.pprint(trainer.tokenizer(prompts))

{'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]],
 'input_ids': [[40, 26399, 314, 3001], [1, 40, 1703, 44269, 25, 12550, 1]]}


In [12]:
pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer)
pprint.pprint(pipeline.prompts)

[{'attention_mask': [1, 1, 1, 1], 'input_ids': [40, 26399, 314, 3001]},
 {'attention_mask': [1, 1, 1, 1, 1, 1, 1],
  'input_ids': [1, 40, 1703, 44269, 25, 12550, 1]}]


In [13]:
trainer.add_prompt_pipeline(pipeline)
ppo_rl_elements = []

In [14]:
from trlx.data.accelerate_base_datatypes import PromptBatch
# random sampling
batch: PromptBatch = next(trainer.prompt_iterator)
pprint.pprint(batch)

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1]], device='cuda:0'),
 'input_ids': tensor([[    1,    40,  1703, 44269,    25, 12550,     1],
        [50256, 50256, 50256,    40, 26399,   314,  3001]], device='cuda:0')}


In [15]:
batch

{'input_ids': tensor([[    1,    40,  1703, 44269,    25, 12550,     1],
        [50256, 50256, 50256,    40, 26399,   314,  3001]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1]], device='cuda:0')}

In [16]:
pipeline.tokenizer(prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=False)

{'input_ids': [[40, 26399, 314, 3001], [1, 40, 1703, 44269, 25, 12550, 1]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]]}

In [17]:
samples = trainer.generate(**batch)
pprint.pprint(samples)

tensor([[    1,    40,  1703, 44269,    25, 12550,     1,  3584,  6159,   286,
           262,   584,  7328,   373,  7924,   355,  2092,   355, 11609,     6,
           268,   338, 41727,  6592,    13,  1406,    11, 17826, 24868,   318,
           262,   691,   530,   508,  1107, 12766,   284,  1833,   644,   428,
           318,   477,   546,    13,   383,  3437,  2241],
        [50256, 50256, 50256,    40, 26399,   314,  3001,   257,  1218,  1057,
           290,   373,  6655,   284,  1064, 33728,  1338,  4224,  1107, 16403,
           284,   617,  3354,   286,   262,  1382,   281, 10059,  1204,  2854,
           287,  1103,  1204,    13,   383, 16031,   287,   428,  3807,   547,
          8531,   290, 19989,  1103,  1204,  2694,    13]], device='cuda:0')


In [18]:
batch

{'input_ids': tensor([[    1,    40,  1703, 44269,    25, 12550,     1],
        [50256, 50256, 50256,    40, 26399,   314,  3001]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1]], device='cuda:0')}

In [19]:
prompt_tensors = batch.input_ids

In [20]:
import torch
input_ids_expr = torch.Tensor([[1],[2]])
input_ids_expr = input_ids_expr.to('cuda', dtype=torch.int)
attention_mask_expr = torch.Tensor([[1],[1]])
attention_mask_expr = attention_mask_expr.to('cuda', dtype=torch.int)

In [21]:
trainer.model.generate(input_ids=input_ids_expr, attention_mask=attention_mask_expr)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


tensor([[    1,   290,   366,   464,  1869,  5338,   509,  3605, 14190, 13111,
             1,   389,   262,   691,   734,  7328,   314,  1053,  1775,   326],
        [    2,    13, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256]],
       device='cuda:0')

In [22]:
input_ids_expr

tensor([[1],
        [2]], device='cuda:0', dtype=torch.int32)

In [23]:
a = trainer.model(input_ids_expr)

In [24]:
trainer.model.base_model.generate(input_ids=input_ids_expr, attention_mask=attention_mask_expr)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


tensor([[    1,   290,   366,   464,  1869,  5338,   509,  3605, 14190, 13111,
             1,   389,   262,   691,   734,  7328,   314,  1053,  1775,   326],
        [    2,    13, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256]],
       device='cuda:0')

In [25]:
kwargs = {'max_new_tokens': 40, 'top_k': 1, 'top_p': 1.0, 'do_sample': True, 'eos_token_id': 50256, 'pad_token_id': 50256}
trainer.model.base_model.generate(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], **kwargs)

tensor([[    1,    40,  1703, 44269,    25, 12550,     1,   318,   257,  1049,
          3807,    13,   632,   318,   257,  1049,  3807,    13,   632,   318,
           257,  1049,  3807,    13,   632,   318,   257,  1049,  3807,    13,
           632,   318,   257,  1049,  3807,    13,   632,   318,   257,  1049,
          3807,    13,   632,   318,   257,  1049,  3807],
        [50256, 50256, 50256,    40, 26399,   314,  3001,  3336, 32957,  3963,
          3336, 32957,  3963,  3336, 32957,  3963,  3336, 32957,  3963,  3336,
         32957,  3963,  3336, 32957,  3963,  3336, 32957,  3963,  3336, 32957,
          3963,  3336, 32957,  3963,  3336, 32957,  3963,  3336, 32957,  3963,
          3336, 32957,  3963,  3336, 32957,  3963,  3336]], device='cuda:0')

In [26]:
trainer.model.base_model.transformer.h[0].mlp.c_fc.weight

Parameter containing:
tensor([[ 0.0946,  0.0982, -0.0275,  ..., -0.1776,  0.1447,  0.0717],
        [-0.1265, -0.0656,  0.0332,  ...,  0.1970, -0.1247, -0.0649],
        [ 0.0536, -0.0389, -0.0499,  ...,  0.0678, -0.0733,  0.0841],
        ...,
        [ 0.0514,  0.1575,  0.0029,  ..., -0.3976,  0.0898,  0.0218],
        [ 0.0326,  0.1229, -0.0413,  ..., -0.1904,  0.1263, -0.0408],
        [-0.0344,  0.0023, -0.0510,  ..., -0.0417,  0.0563,  0.1917]],
       device='cuda:0')

In [27]:
import torch
device = samples.device
prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device)

In [28]:
prompt_tensors[1]

tensor([50256, 50256, 50256,    40, 26399,   314,  3001], device='cuda:0')

In [29]:
prompt_sizes

tensor([7, 7], device='cuda:0')

In [30]:
padded_samples = trainer.accelerator.pad_across_processes(
    samples, dim=1, pad_index=trainer.tokenizer.eos_token_id, pad_first=False
)

In [31]:
prompt_tensors

tensor([[    1,    40,  1703, 44269,    25, 12550,     1],
        [50256, 50256, 50256,    40, 26399,   314,  3001]], device='cuda:0')

In [32]:
samples

tensor([[    1,    40,  1703, 44269,    25, 12550,     1,  3584,  6159,   286,
           262,   584,  7328,   373,  7924,   355,  2092,   355, 11609,     6,
           268,   338, 41727,  6592,    13,  1406,    11, 17826, 24868,   318,
           262,   691,   530,   508,  1107, 12766,   284,  1833,   644,   428,
           318,   477,   546,    13,   383,  3437,  2241],
        [50256, 50256, 50256,    40, 26399,   314,  3001,   257,  1218,  1057,
           290,   373,  6655,   284,  1064, 33728,  1338,  4224,  1107, 16403,
           284,   617,  3354,   286,   262,  1382,   281, 10059,  1204,  2854,
           287,  1103,  1204,    13,   383, 16031,   287,   428,  3807,   547,
          8531,   290, 19989,  1103,  1204,  2694,    13]], device='cuda:0')

In [33]:
padded_samples

tensor([[    1,    40,  1703, 44269,    25, 12550,     1,  3584,  6159,   286,
           262,   584,  7328,   373,  7924,   355,  2092,   355, 11609,     6,
           268,   338, 41727,  6592,    13,  1406,    11, 17826, 24868,   318,
           262,   691,   530,   508,  1107, 12766,   284,  1833,   644,   428,
           318,   477,   546,    13,   383,  3437,  2241],
        [50256, 50256, 50256,    40, 26399,   314,  3001,   257,  1218,  1057,
           290,   373,  6655,   284,  1064, 33728,  1338,  4224,  1107, 16403,
           284,   617,  3354,   286,   262,  1382,   281, 10059,  1204,  2854,
           287,  1103,  1204,    13,   383, 16031,   287,   428,  3807,   547,
          8531,   290, 19989,  1103,  1204,  2694,    13]], device='cuda:0')

In [34]:
padded_prompts = trainer.accelerator.pad_across_processes(
    prompt_tensors, dim=1, pad_index=trainer.tokenizer.eos_token_id, pad_first=False
)
padded_prompts

tensor([[    1,    40,  1703, 44269,    25, 12550,     1],
        [50256, 50256, 50256,    40, 26399,   314,  3001]], device='cuda:0')

In [35]:
gathered_samples = trainer.accelerator.gather(padded_samples)
gathered_samples

tensor([[    1,    40,  1703, 44269,    25, 12550,     1,  3584,  6159,   286,
           262,   584,  7328,   373,  7924,   355,  2092,   355, 11609,     6,
           268,   338, 41727,  6592,    13,  1406,    11, 17826, 24868,   318,
           262,   691,   530,   508,  1107, 12766,   284,  1833,   644,   428,
           318,   477,   546,    13,   383,  3437,  2241],
        [50256, 50256, 50256,    40, 26399,   314,  3001,   257,  1218,  1057,
           290,   373,  6655,   284,  1064, 33728,  1338,  4224,  1107, 16403,
           284,   617,  3354,   286,   262,  1382,   281, 10059,  1204,  2854,
           287,  1103,  1204,    13,   383, 16031,   287,   428,  3807,   547,
          8531,   290, 19989,  1103,  1204,  2694,    13]], device='cuda:0')

In [36]:
gathered_prompts = trainer.accelerator.gather(padded_prompts)
gathered_prompts

tensor([[    1,    40,  1703, 44269,    25, 12550,     1],
        [50256, 50256, 50256,    40, 26399,   314,  3001]], device='cuda:0')

In [37]:
gathered_prompt_sizes = trainer.accelerator.gather(prompt_sizes)
gathered_prompt_sizes

tensor([7, 7], device='cuda:0')

In [38]:
# - samples: full sentences.
# - prompts: take few words off of movies reviews.
# - outputs: outputs generated using prompts
all_str_samples, all_str_prompts, all_str_outputs = trainer.decode(
    gathered_prompts, gathered_samples, gathered_prompt_sizes
)

In [39]:
all_str_samples

['"I Am Curious: Yellow" although neither of the other films was directed as similar as Ye\'en\'s koja. So, Jamie Lynn is the only one who really struggles to understand what this is all about. The director himself',
 'I rented I AM a second run and was surprised to find Andreas Spitz really appealing to some parts of the build an unexpected life performance in real life. The saves in this movie were stupid and lacked real life ability.']

In [40]:
all_str_prompts

['"I Am Curious: Yellow"', 'I rented I AM']

In [41]:
all_str_outputs

[" although neither of the other films was directed as similar as Ye'en's koja. So, Jamie Lynn is the only one who really struggles to understand what this is all about. The director himself",
 ' a second run and was surprised to find Andreas Spitz really appealing to some parts of the build an unexpected life performance in real life. The saves in this movie were stupid and lacked real life ability.']

In [42]:
all_scores = torch.tensor(
    trainer.reward_fn(
        samples=all_str_samples,
        prompts=all_str_prompts,
        outputs=all_str_outputs,
    ),
    dtype=torch.float,
    device=device,
)

In [43]:
all_scores

tensor([0.4048, 0.0483], device='cuda:0')

In [44]:
all_scores = list(all_scores.reshape(trainer.accelerator.num_processes, -1).unbind())
all_scores

[tensor([0.4048, 0.0483], device='cuda:0')]

In [45]:
scores = torch.tensor(all_scores[0])
scores

  scores = torch.tensor(all_scores[0])


tensor([0.4048, 0.0483], device='cuda:0')

In [46]:
str_samples, str_prompts, str_outputs = trainer.decode(prompt_tensors, samples)

In [47]:
str_samples

['"I Am Curious: Yellow" although neither of the other films was directed as similar as Ye\'en\'s koja. So, Jamie Lynn is the only one who really struggles to understand what this is all about. The director himself',
 'I rented I AM a second run and was surprised to find Andreas Spitz really appealing to some parts of the build an unexpected life performance in real life. The saves in this movie were stupid and lacked real life ability.']

In [48]:
all_str_samples

['"I Am Curious: Yellow" although neither of the other films was directed as similar as Ye\'en\'s koja. So, Jamie Lynn is the only one who really struggles to understand what this is all about. The director himself',
 'I rented I AM a second run and was surprised to find Andreas Spitz really appealing to some parts of the build an unexpected life performance in real life. The saves in this movie were stupid and lacked real life ability.']

In [49]:
outputs = trainer.tokenizer(str_outputs).input_ids
outputs = list(map(torch.LongTensor, outputs))
outputs

[tensor([ 3584,  6159,   286,   262,   584,  7328,   373,  7924,   355,  2092,
           355, 11609,     6,   268,   338, 41727,  6592,    13,  1406,    11,
         17826, 24868,   318,   262,   691,   530,   508,  1107, 12766,   284,
          1833,   644,   428,   318,   477,   546,    13,   383,  3437,  2241]),
 tensor([  257,  1218,  1057,   290,   373,  6655,   284,  1064, 33728,  1338,
          4224,  1107, 16403,   284,   617,  3354,   286,   262,  1382,   281,
         10059,  1204,  2854,   287,  1103,  1204,    13,   383, 16031,   287,
           428,  3807,   547,  8531,   290, 19989,  1103,  1204,  2694,    13])]

In [50]:
maxsize = max(map(len, outputs))
maxsize

40

In [51]:
import torch.nn.functional as F
outputs = [
    F.pad(
        output,
        (0, maxsize - len(output)),
        value=trainer.tokenizer.pad_token_id,
    )
    for output in outputs
]
outputs

[tensor([ 3584,  6159,   286,   262,   584,  7328,   373,  7924,   355,  2092,
           355, 11609,     6,   268,   338, 41727,  6592,    13,  1406,    11,
         17826, 24868,   318,   262,   691,   530,   508,  1107, 12766,   284,
          1833,   644,   428,   318,   477,   546,    13,   383,  3437,  2241]),
 tensor([  257,  1218,  1057,   290,   373,  6655,   284,  1064, 33728,  1338,
          4224,  1107, 16403,   284,   617,  3354,   286,   262,  1382,   281,
         10059,  1204,  2854,   287,  1103,  1204,    13,   383, 16031,   287,
           428,  3807,   547,  8531,   290, 19989,  1103,  1204,  2694,    13])]

In [52]:
sample_outputs = torch.vstack(outputs).to(device)
sample_outputs

tensor([[ 3584,  6159,   286,   262,   584,  7328,   373,  7924,   355,  2092,
           355, 11609,     6,   268,   338, 41727,  6592,    13,  1406,    11,
         17826, 24868,   318,   262,   691,   530,   508,  1107, 12766,   284,
          1833,   644,   428,   318,   477,   546,    13,   383,  3437,  2241],
        [  257,  1218,  1057,   290,   373,  6655,   284,  1064, 33728,  1338,
          4224,  1107, 16403,   284,   617,  3354,   286,   262,  1382,   281,
         10059,  1204,  2854,   287,  1103,  1204,    13,   383, 16031,   287,
           428,  3807,   547,  8531,   290, 19989,  1103,  1204,  2694,    13]],
       device='cuda:0')

In [53]:
trainer.ref_mean, trainer.ref_std = scores.mean(), scores.std()

In [54]:
all_scores_mean, all_scores_std = trainer.running_moments.update(scores)

In [55]:
trainer.ref_mean, all_scores_mean, trainer.ref_std, all_scores_std

(tensor(0.2266, device='cuda:0'),
 tensor(0.2266, device='cuda:0'),
 tensor(0.2521, device='cuda:0'),
 tensor(0.2521, device='cuda:0'))

In [56]:
clip_reward = trainer.config.method.cliprange_reward
clip_reward

10

In [57]:
scores = torch.clip(scores, -clip_reward, clip_reward)
scores

tensor([0.4048, 0.0483], device='cuda:0')

In [58]:
all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1)
all_tokens

tensor([[    1,    40,  1703, 44269,    25, 12550,     1,  3584,  6159,   286,
           262,   584,  7328,   373,  7924,   355,  2092,   355, 11609,     6,
           268,   338, 41727,  6592,    13,  1406,    11, 17826, 24868,   318,
           262,   691,   530,   508,  1107, 12766,   284,  1833,   644,   428,
           318,   477,   546,    13,   383,  3437,  2241],
        [50256, 50256, 50256,    40, 26399,   314,  3001,   257,  1218,  1057,
           290,   373,  6655,   284,  1064, 33728,  1338,  4224,  1107, 16403,
           284,   617,  3354,   286,   262,  1382,   281, 10059,  1204,  2854,
           287,  1103,  1204,    13,   383, 16031,   287,   428,  3807,   547,
          8531,   290, 19989,  1103,  1204,  2694,    13]], device='cuda:0')

In [59]:
samples

tensor([[    1,    40,  1703, 44269,    25, 12550,     1,  3584,  6159,   286,
           262,   584,  7328,   373,  7924,   355,  2092,   355, 11609,     6,
           268,   338, 41727,  6592,    13,  1406,    11, 17826, 24868,   318,
           262,   691,   530,   508,  1107, 12766,   284,  1833,   644,   428,
           318,   477,   546,    13,   383,  3437,  2241],
        [50256, 50256, 50256,    40, 26399,   314,  3001,   257,  1218,  1057,
           290,   373,  6655,   284,  1064, 33728,  1338,  4224,  1107, 16403,
           284,   617,  3354,   286,   262,  1382,   281, 10059,  1204,  2854,
           287,  1103,  1204,    13,   383, 16031,   287,   428,  3807,   547,
          8531,   290, 19989,  1103,  1204,  2694,    13]], device='cuda:0')

In [60]:
attention_mask = all_tokens.not_equal(trainer.tokenizer.pad_token_id).long().to(device)
attention_mask

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0')

In [61]:
# values : Uses base_model's output, and goes through v_head(2 lyrs) one more time.
logits, *_, values = trainer.model(
    all_tokens,
    attention_mask=attention_mask,
)

In [62]:
values

tensor([[-0.1144, -1.6477, -0.7635, -0.7435, -0.4699, -0.5392, -0.7150, -1.2081,
         -1.1719, -0.7929, -1.2340, -1.2667, -1.4014, -1.4717, -0.7298, -0.7724,
         -0.4831, -0.8247, -0.5923, -0.5602, -0.3411, -0.7314, -0.4289, -0.5899,
         -0.2241, -0.8277, -1.2634,  0.2456, -0.2927, -0.9984, -0.8142, -1.0425,
         -1.0788, -1.0885, -1.2277, -0.7496, -0.9655, -0.7919, -1.1532, -1.0226,
         -0.8372, -0.6618, -0.3051, -0.2710, -0.8394, -0.8853, -0.8451],
        [-0.1301, -0.1811, -0.1861, -0.4533, -0.9991, -0.7957,  0.2317, -0.1800,
         -0.4699, -0.7473, -0.7988, -0.6913, -0.4419, -0.4495, -0.2182, -0.1093,
         -0.1637, -0.1946,  0.0114, -0.0950, -0.1235, -0.4455, -0.3586, -0.3516,
         -0.5608, -0.2130, -0.1212, -0.4105, -0.3970, -0.2640, -0.2569, -0.2457,
         -0.1830,  0.0729, -0.1008, -0.3603, -0.2137,  0.0231,  0.1038, -0.3651,
         -0.2889, -0.2839, -0.1293, -0.2133, -0.1907, -0.3309, -0.0599]],
       device='cuda:0', grad_fn=<SqueezeBac

In [63]:
logits.shape

torch.Size([2, 47, 50257])

In [64]:
values.shape

torch.Size([2, 47])

In [65]:
all_tokens.shape

torch.Size([2, 47])

In [66]:
samples.shape

torch.Size([2, 47])

In [67]:
import trlx.utils.modeling as modeling_utils
hidden_layers = modeling_utils.hf_get_decoder_blocks(trainer.model.base_model)
hidden_layers

ModuleList(
  (0-11): 12 x GPT2Block(
    (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (attn): GPT2Attention(
      (c_attn): Conv1D()
      (c_proj): Conv1D()
      (attn_dropout): Dropout(p=0.1, inplace=False)
      (resid_dropout): Dropout(p=0.1, inplace=False)
    )
    (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): GPT2MLP(
      (c_fc): Conv1D()
      (c_proj): Conv1D()
      (act): NewGELUActivation()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)

In [68]:
ref_logits = trainer.model.forward_hydra(
    all_tokens,
    attention_mask=attention_mask,
    return_dict=True,
).logits
ref_logits.shape

torch.Size([2, 47, 50257])

In [69]:
trainer.model

AutoModelForCausalLMWithHydraValueHead(
  (base_model): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  )
  (v

In [70]:
logits, *_, values = trainer.model.forward(
    all_tokens,
    attention_mask=attention_mask,
)

In [71]:
values

tensor([[-0.1144, -1.6477, -0.7635, -0.7435, -0.4699, -0.5392, -0.7150, -1.2081,
         -1.1719, -0.7929, -1.2340, -1.2667, -1.4014, -1.4717, -0.7298, -0.7724,
         -0.4831, -0.8247, -0.5923, -0.5602, -0.3411, -0.7314, -0.4289, -0.5899,
         -0.2241, -0.8277, -1.2634,  0.2456, -0.2927, -0.9984, -0.8142, -1.0425,
         -1.0788, -1.0885, -1.2277, -0.7496, -0.9655, -0.7919, -1.1532, -1.0226,
         -0.8372, -0.6618, -0.3051, -0.2710, -0.8394, -0.8853, -0.8451],
        [-0.1301, -0.1811, -0.1861, -0.4533, -0.9991, -0.7957,  0.2317, -0.1800,
         -0.4699, -0.7473, -0.7988, -0.6913, -0.4419, -0.4495, -0.2182, -0.1093,
         -0.1637, -0.1946,  0.0114, -0.0950, -0.1235, -0.4455, -0.3586, -0.3516,
         -0.5608, -0.2130, -0.1212, -0.4105, -0.3970, -0.2640, -0.2569, -0.2457,
         -0.1830,  0.0729, -0.1008, -0.3603, -0.2137,  0.0231,  0.1038, -0.3651,
         -0.2889, -0.2839, -0.1293, -0.2133, -0.1907, -0.3309, -0.0599]],
       device='cuda:0', grad_fn=<SqueezeBac

In [72]:
logits.shape

torch.Size([2, 47, 50257])

In [73]:
from trlx.utils.modeling import RunningMoments, logprobs_of_labels
logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:])
logprobs.shape

torch.Size([2, 46])

In [74]:
# Original value
ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])
ref_logprobs.shape

torch.Size([2, 46])

In [75]:
n_samples: int = samples.shape[0]
n_samples

2

In [76]:
# Estimate the KL divergence between the model and reference model
start = prompt_tensors.shape[1] - 1
start

6

In [77]:
log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1]
log_ratio

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       device='cuda:0', grad_fn=<MulBackward0>)

In [78]:
trainer.mean_kl = (log_ratio.exp() - 1 - log_ratio).mean().to(device)
trainer.mean_kl

tensor(0., device='cuda:0', grad_fn=<MeanBackward0>)

In [79]:
ends = start + attention_mask[:, start:].sum(1)
ends

tensor([47, 47], device='cuda:0')

In [80]:
values.shape

torch.Size([2, 47])

In [81]:
all_values = [values[ix, start: ends[ix]] for ix in range(n_samples)]
all_values

[tensor([-0.7150, -1.2081, -1.1719, -0.7929, -1.2340, -1.2667, -1.4014, -1.4717,
         -0.7298, -0.7724, -0.4831, -0.8247, -0.5923, -0.5602, -0.3411, -0.7314,
         -0.4289, -0.5899, -0.2241, -0.8277, -1.2634,  0.2456, -0.2927, -0.9984,
         -0.8142, -1.0425, -1.0788, -1.0885, -1.2277, -0.7496, -0.9655, -0.7919,
         -1.1532, -1.0226, -0.8372, -0.6618, -0.3051, -0.2710, -0.8394, -0.8853,
         -0.8451], device='cuda:0', grad_fn=<SliceBackward0>),
 tensor([ 0.2317, -0.1800, -0.4699, -0.7473, -0.7988, -0.6913, -0.4419, -0.4495,
         -0.2182, -0.1093, -0.1637, -0.1946,  0.0114, -0.0950, -0.1235, -0.4455,
         -0.3586, -0.3516, -0.5608, -0.2130, -0.1212, -0.4105, -0.3970, -0.2640,
         -0.2569, -0.2457, -0.1830,  0.0729, -0.1008, -0.3603, -0.2137,  0.0231,
          0.1038, -0.3651, -0.2889, -0.2839, -0.1293, -0.2133, -0.1907, -0.3309,
         -0.0599], device='cuda:0', grad_fn=<SliceBackward0>)]

In [82]:
all_logprobs = [logprobs[ix, start: ends[ix]] for ix in range(n_samples)]
all_logprobs

[tensor([-8.0056e+00, -7.1409e+00, -1.8413e+00, -1.4369e+00, -2.7020e+00,
         -1.2303e+00, -3.0204e+00, -2.9217e+00, -5.7553e+00, -7.2738e+00,
         -1.4152e+00, -9.8927e+00, -5.8801e+00, -5.2949e+00, -1.7356e+00,
         -1.4803e+01, -5.6885e+00, -1.8524e+00, -4.5328e+00, -1.5505e+00,
         -1.1950e+01, -4.3547e+00, -1.8808e+00, -2.2701e+00, -2.0085e+00,
         -9.3521e-01, -1.0281e+00, -3.1574e+00, -6.5488e+00, -1.0721e+00,
         -3.2327e+00, -2.0036e+00, -3.6637e+00, -2.1087e+00, -7.9226e-01,
         -1.2594e-02, -7.9068e-01, -2.4174e+00, -4.6679e+00, -4.7881e+00],
        device='cuda:0', grad_fn=<SliceBackward0>),
 tensor([-12.2430,  -7.2060,  -9.7343,  -5.3320,  -8.3455, -10.2839,  -5.4186,
          -4.7302, -13.8492, -16.6296,  -6.6775, -13.8959, -15.1789,  -5.2657,
          -8.3728,  -3.2701,  -4.7287,  -5.7416, -14.0443,  -6.4452,  -8.1692,
          -9.0305, -11.1253,  -5.2776,  -8.1605,  -0.9804,  -4.5875,  -3.3391,
         -12.2452,  -6.0990,  -6.0489, 

In [83]:
kl_penalty = trainer.kl_ctl.value * -log_ratio.cpu()
kl_penalty

tensor([[-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
         -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
        [-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
         -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.]],
       grad_fn=<MulBackward0>)

In [84]:
kl_penalty = [xs[start: ends[ix]] for ix, xs in enumerate(kl_penalty)]
kl_penalty

[tensor([-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
         -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
        grad_fn=<SliceBackward0>),
 tensor([-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
         -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
        grad_fn=<SliceBackward0>)]

In [85]:
rewards = kl_penalty[0]
rewards

tensor([-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
        -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
       grad_fn=<SliceBackward0>)

In [86]:
trainer.kl_ctl.value

0.05

In [87]:
kl_penalty

[tensor([-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
         -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
        grad_fn=<SliceBackward0>),
 tensor([-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
         -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
        grad_fn=<SliceBackward0>)]

In [88]:
rewards

tensor([-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
        -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
       grad_fn=<SliceBackward0>)

In [89]:
scores[0].cpu()

tensor(0.4048)

In [90]:
# rewards[-1] += scores[0].cpu()

In [91]:
rewards.shape, logprobs[0].shape, values[0].shape

(torch.Size([40]), torch.Size([46]), torch.Size([47]))

In [92]:
response_length = rewards.shape[0]
response_length

40

In [93]:
values.shape, rewards.shape

(torch.Size([2, 47]), torch.Size([40]))