In [1]:
import json
import os
import sys
from typing import List

import torch
from datasets import load_dataset
from transformers import pipeline

import trlx
from trlx.data.default_configs import TRLConfig, default_ppo_config
from trlx.utils.loading import get_pipeline, get_trainer
from trlx.utils import set_seed

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Start from the stock PPO configuration (as a plain dict) and override it
# for a minimal run: one training step, two rollouts, chunks of two,
# with TensorBoard as the tracker.
default_config = default_ppo_config().to_dict()
default_config['train'].update(total_steps=1, tracker='tensorboard')
default_config['method'].update(num_rollouts=2, chunk_size=2)

# No additional overrides are layered on top of the edited defaults.
config = TRLConfig.update(default_config, {})

In [3]:
# Use the first four whitespace-separated words of every IMDB review as a prompt.
imdb = load_dataset("imdb", split="train+test")
prompts = list(map(lambda text: " ".join(text.split()[:4]), imdb["text"]))

# Device index handed to the classifier: 0 when LOCAL_RANK is 0 (or unset), -1 otherwise.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
sentiment_device = 0 if local_rank == 0 else -1

# Sentiment classifier whose scores are turned into rewards by `reward_fn`.
sentiment_fn = pipeline(
    "sentiment-analysis",
    "lvwerra/distilbert-imdb",
    top_k=2,
    truncation=True,
    batch_size=256,
    device=sentiment_device,
)


def get_positive_score(scores):
    """Return the score paired with the "POSITIVE" label.

    `scores` is an iterable of mappings whose values, in order, are a
    label followed by its score. A KeyError is raised when no entry is
    labelled "POSITIVE".
    """
    score_by_label = {}
    for entry in scores:
        label, score = tuple(entry.values())
        score_by_label[label] = score
    return score_by_label["POSITIVE"]


def reward_fn(samples: List[str], **kwargs) -> List[float]:
    """Reward each sample with its POSITIVE sentiment score.

    The samples are scored in one call to `sentiment_fn`; any extra keyword
    arguments are accepted and ignored.
    """
    return [get_positive_score(scores) for scores in sentiment_fn(samples)]

Found cached dataset imdb (/home/sshuser/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


In [4]:
# Seed first so that trainer construction happens after seeding.
set_seed(config.train.seed)

# Look up the trainer class named in the config, then instantiate it with
# the reward function and any extra trainer kwargs from the config.
trainer_cls = get_trainer(config.train.trainer)
trainer = trainer_cls(
    reward_fn=reward_fn,
    config=config,
    **config.train.trainer_kwargs,
)

[RANK 0] Initializing model: lvwerra/gpt2-imdb


{
    "method": {
        "name": "PPOConfig",
        "ppo_epochs": 4,
        "num_rollouts": 2,
        "chunk_size": 2,
        "init_kl_coef": 0.05,
        "target": 6,
        "horizon": 10000,
        "gamma": 1,
        "lam": 0.95,
        "cliprange": 0.2,
        "cliprange_value": 0.2,
        "vf_coef": 1,
        "scale_reward": "ignored",
        "ref_mean": null,
        "ref_std": null,
        "cliprange_reward": 10,
        "gen_kwargs": {
            "max_new_tokens": 40,
            "top_k": 0,
            "top_p": 1.0,
            "do_sample": true
        },
        "gen_experience_kwargs": null
    },
    "model": {
        "model_path": "lvwerra/gpt2-imdb",
        "model_arch_type": "causal",
        "num_layers_unfrozen": 2,
        "delta_kwargs": null
    },
    "optimizer": {
        "name": "adamw",
        "kwargs": {
            "lr": 0.0001,
            "betas": [
                0.9,
                0.95
            ],
            "eps": 1e-08,
     

In [5]:
# Configured batch size multiplied by WORLD_SIZE (1 when the variable is unset).
world_size = int(os.environ.get("WORLD_SIZE", 1))
batch_size = config.train.batch_size * world_size

# Prompt budget: configured sequence length minus the tokens reserved for generation.
max_new_tokens = config.method.gen_kwargs["max_new_tokens"]
max_prompt_length = config.train.seq_length - max_new_tokens

(batch_size, max_prompt_length)

(32, 984)

In [6]:
# Only when no prompts were loaded, fall back to one batch of BOS-token prompts.
if not prompts:
    prompts = [trainer.tokenizer.bos_token] * batch_size
len(prompts)

50000

In [7]:
# Wrap the prompts in the pipeline class named in the config and register it
# with the trainer.
# The result is bound to `prompt_pipeline` rather than `pipeline` so that the
# `pipeline` factory imported from transformers at the top of the notebook is
# not shadowed; with the old name, re-running the sentiment-classifier cell
# after this one would call a pipeline object instead of the factory.
prompt_pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer)
trainer.add_prompt_pipeline(prompt_pipeline)

In [8]:
# Collect `num_rollouts` rollouts (2 in this notebook's config) with the trainer.
trainer.make_experience(config.method.num_rollouts)

[RANK 0] Collecting rollouts
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [9]:
# Evaluate on 64 copies of a single fixed prompt, wrapped in the same
# pipeline class that the config names for the training prompts.
eval_prompts = 64 * ["It's a bad movie but "]
eval_pipeline_cls = get_pipeline(config.train.pipeline)
eval_pipeline = eval_pipeline_cls(eval_prompts, max_prompt_length, trainer.tokenizer)
trainer.add_eval_pipeline(eval_pipeline)

In [10]:
# Inspect the registry of pipeline classes by name.
# NOTE(review): `_DATAPIPELINE` is a private (underscore-prefixed) name imported
# mid-notebook purely for inspection; nothing later in this notebook uses it.
from trlx.pipeline import _DATAPIPELINE
_DATAPIPELINE

{'basepipeline': trlx.pipeline.BasePipeline,
 'promptpipeline': trlx.pipeline.offline_pipeline.PromptPipeline}

In [11]:
# Show the evaluation pipeline now attached to the trainer.
trainer.eval_pipeline

<trlx.pipeline.offline_pipeline.PromptPipeline at 0x7f72d41b9630>

In [12]:
# Run one evaluation pass by hand instead of through the regular training loop.
# NOTE(review): `prepare_learning()` presumably sets up the dataloaders that a
# later cell reads via `trainer.train_dataloader` — confirm against trlx.
trainer.prepare_learning()
# Counters are set manually here; presumably `evaluate()` reads them — TODO confirm.
trainer.iter_count = 0
trainer.nth_evaluation = 0

results = trainer.evaluate()

[RANK 0] Evaluating model
[generation sweep 1/1 | eval batch 2/2]: 100%|██████████████████████| 2/2 [00:00<00:00,  2.74it/s]
[RANK 0] Computing rewards
[RANK 0] Summarizing evaluation


In [13]:
# Start at negative infinity, i.e. below any finite reward.
best_reward = float("-inf")

In [14]:
# Pull one batch of rollouts out of the training dataloader for inspection.
# The previous loop iterated the whole dataloader with a no-op body
# (`batch = batch`) just to keep a batch around; taking the first batch
# directly avoids the full pass. With the two rollouts collected above the
# dataloader yields a single batch, so the batch inspected is the same.
batch = next(iter(trainer.train_dataloader))
batch

PPORLBatch(query_tensors=tensor([[20670,   257,  1256,   286],
        [18357,   257,  1200,   286]]), response_tensors=tensor([[  670,  1816,   656,   262,  7110,    11,   475,   356,   991,  1392,
           257,  4735,  2126,   644,   284,  1607,   422,   617, 28201,   582,
           492, 11246,  3307,  1234,   287,   257,   264, 16406, 26951,   422,
           597,   640,   329, 42402,   492, 11246,  6877,  4190,  8188,   492],
        [  257,  1218,  1363,   290,   262,  3173,   286,   262,  2975,  1422,
           470,  1107,  4174,   284,   502,  7620,    11,   314,   750,   281,
          6672, 16835,   371,  2470,   284,   257,  1295,   884,   355,   428,
            11, 19167,  4152,   287,  9131,   357,    40,   423,   281,  6565]]), logprobs=tensor([[-4.6365e+00, -1.3491e+00, -1.0836e-02, -2.0695e+00, -4.1903e+00,
         -1.0977e+00, -1.4188e+00, -5.0114e+00, -4.0047e+00, -1.7658e+00,
         -1.4360e+00, -4.2761e+00, -2.4873e+00, -3.3030e+00, -3.3897e+00,
         -1.7

In [15]:
# Compute the trainer's loss and accompanying statistics for the sampled
# batch, then display the loss tensor.
loss, stats = trainer.loss(batch)
loss

tensor(0.1329, device='cuda:0', grad_fn=<AddBackward0>)

In [16]:
# Show the per-token rewards stored on the batch.
batch.rewards

tensor([[-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
         -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
         -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
         -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
         -0.0000, -0.0000, -0.0000, 0.8653],
        [-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
         -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
         -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
         -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
         -0.0000, -0.0000, -0.0000, 0.0793]])

In [17]:
# Show the statistics returned alongside the loss by `trainer.loss(batch)`.
stats

{'losses/total_loss': 0.13293874263763428,
 'losses/policy_loss': -2.3841858265427618e-08,
 'losses/value_loss': 0.13293877243995667,
 'values/mean': tensor(-0.0566, device='cuda:0', grad_fn=<DivBackward0>),
 'values/min': tensor(-3.2174, device='cuda:0', grad_fn=<MinBackward1>),
 'values/max': tensor(0.4171, device='cuda:0', grad_fn=<MaxBackward1>),
 'values/std': tensor(0.4370, device='cuda:0', grad_fn=<SqrtBackward0>),
 'values/values_error': tensor(0.2659, device='cuda:0', grad_fn=<DivBackward0>),
 'values/clipfrac': tensor(0., device='cuda:0'),
 'old_values/mean': tensor(-0.0566, device='cuda:0'),
 'old_values/min': tensor(-3.2174, device='cuda:0'),
 'old_values/max': tensor(0.4171, device='cuda:0'),
 'old_values/std': tensor(0.4370, device='cuda:0'),
 'returns/mean': tensor(0.1989, device='cuda:0'),
 'returns/min': tensor(-0.2411, device='cuda:0'),
 'returns/max': tensor(0.8653, device='cuda:0'),
 'returns/std': tensor(0.2925, device='cuda:0'),
 'policy/approx_kl': 0.0,
 'policy/

In [18]:
# Pull the rollout fields off the batch; the "old_" prefix marks quantities
# stored on the batch rather than recomputed below.
query_tensors, response_tensors = batch.query_tensors, batch.response_tensors
old_logprobs, old_values, old_rewards = batch.logprobs, batch.values, batch.rewards

# Number of response positions per sample (40 for this batch).
response_length = old_rewards.size(1)

In [19]:
# Show the PPO method configuration (gamma, lam, clip ranges, ...) read by the cells below.
trainer.config.method

PPOConfig(name='PPOConfig', ppo_epochs=4, num_rollouts=2, chunk_size=2, init_kl_coef=0.05, target=6, horizon=10000, gamma=1, lam=0.95, cliprange=0.2, cliprange_value=0.2, vf_coef=1, scale_reward='ignored', ref_mean=None, ref_std=None, cliprange_reward=10, gen_kwargs={'max_new_tokens': 40, 'top_k': 0, 'top_p': 1.0, 'do_sample': True}, gen_experience_kwargs=None)

In [20]:
# Manual version of get_advantages_and_returns(old_values, old_rewards, response_length):
# walk the response backwards, accumulating the discounted sum of TD residuals.
gamma = trainer.config.method.gamma
lam = trainer.config.method.lam
last_step = response_length - 1

lastgaelam = 0
advantages_reversed = []
for t in range(last_step, -1, -1):
    # Value of the following position; nothing follows the final one, so use 0.
    if t < last_step:
        nextvalues = old_values[:, t + 1]
    else:
        nextvalues = 0.0
    # TD residual at step t.
    delta = old_rewards[:, t] + gamma * nextvalues - old_values[:, t]
    # Fold the residual into the running (gamma * lam)-discounted sum.
    lastgaelam = delta + gamma * lam * lastgaelam
    advantages_reversed.append(lastgaelam)

In [21]:
# The loop above filled the list back-to-front; restore chronological order
# and stack the per-step columns along dimension 1.
advantages = torch.stack(list(reversed(advantages_reversed)), dim=1)

In [22]:
# Show the stacked advantages (one row per sample, one column per response position).
advantages

tensor([[ 0.5445, -0.3393, -0.2265,  0.3362,  0.3114,  0.3457,  0.2286,  0.6270,
          0.0851,  0.3294,  0.6319,  0.3858,  0.5135,  0.5743,  0.4898,  0.4111,
          0.5383,  0.8656,  0.6726,  0.3542,  0.3168,  0.2019,  0.3104,  0.3112,
          0.3779,  0.2985,  0.3322,  0.2030,  0.5828,  0.3228,  0.5661,  0.2726,
          0.5033,  0.5753,  0.4507,  0.3857,  0.7315,  0.5187,  0.4311,  0.8034],
        [ 0.0823, -0.0236, -0.3469, -0.3913, -0.5028, -0.4381, -0.3602,  0.0734,
          0.2247, -0.4628,  3.1330,  0.1851,  0.1343,  0.1699,  0.3894,  0.1319,
          0.2259,  0.0815,  0.0689,  0.1625, -0.0895,  0.0614,  0.0512,  0.4468,
          0.2123,  0.0596,  0.2504,  0.1480,  0.7243,  0.3388,  0.1566, -0.1626,
          0.0956,  0.0620,  0.1907,  0.1005, -0.1866, -0.2902, -0.1056, -0.3378]])

In [23]:
# Show the value estimates stored on the batch, for comparison with the advantages.
old_values

tensor([[-0.4607,  0.4061,  0.2820, -0.2639, -0.2235, -0.2405, -0.1120, -0.4790,
          0.0671, -0.1608, -0.4317, -0.1662, -0.2683, -0.3003, -0.1914, -0.0921,
         -0.1924, -0.4764, -0.2498,  0.0863,  0.1396,  0.2646,  0.1715,  0.1863,
          0.1385,  0.2328,  0.2157,  0.3551,  0.0044,  0.2806,  0.0656,  0.3727,
          0.1672,  0.1240,  0.2711,  0.3553,  0.0461,  0.2849,  0.3941,  0.0620],
        [-0.2120, -0.1072,  0.1986,  0.2235,  0.3098,  0.2233,  0.1274, -0.3026,
         -0.4426,  0.2217, -3.2174, -0.2603, -0.2027, -0.2299, -0.4299, -0.1658,
         -0.2485, -0.1000, -0.0840, -0.1694,  0.0781, -0.0698, -0.0570, -0.4302,
         -0.1851, -0.0295, -0.2077, -0.0979, -0.6380, -0.2356, -0.0456,  0.2655,
          0.0121,  0.0488, -0.0704,  0.0248,  0.3027,  0.3917,  0.2019,  0.4171]])

In [24]:
# Returns are the advantages plus the stored value estimates. This must run
# before the whitening cell below, which overwrites `advantages`.
returns = advantages + old_values
returns

tensor([[ 0.0838,  0.0668,  0.0555,  0.0723,  0.0879,  0.1052,  0.1166,  0.1479,
          0.1522,  0.1687,  0.2003,  0.2195,  0.2452,  0.2739,  0.2984,  0.3190,
          0.3459,  0.3892,  0.4228,  0.4405,  0.4563,  0.4664,  0.4820,  0.4975,
          0.5164,  0.5313,  0.5480,  0.5581,  0.5873,  0.6034,  0.6317,  0.6453,
          0.6705,  0.6993,  0.7218,  0.7411,  0.7777,  0.8036,  0.8251,  0.8653],
        [-0.1297, -0.1309, -0.1482, -0.1678, -0.1929, -0.2148, -0.2328, -0.2292,
         -0.2179, -0.2411, -0.0844, -0.0752, -0.0685, -0.0600, -0.0405, -0.0339,
         -0.0226, -0.0185, -0.0151, -0.0070, -0.0114, -0.0084, -0.0058,  0.0165,
          0.0272,  0.0301,  0.0427,  0.0501,  0.0863,  0.1032,  0.1110,  0.1029,
          0.1077,  0.1108,  0.1203,  0.1253,  0.1160,  0.1015,  0.0962,  0.0793]])

In [25]:
# NOTE(review): import placed mid-notebook; it could live with the imports in the first cell.
from trlx.utils.modeling import whiten
# Overwrites `advantages` with its whitened version (presumably normalised to
# zero mean / unit variance — confirm against trlx). `returns` was already
# computed from the un-whitened values in the previous cell.
advantages = whiten(advantages)

In [26]:
# Show the advantages after whitening.
advantages

tensor([[ 0.6413, -1.3197, -1.0693,  0.1791,  0.1240,  0.2002, -0.0596,  0.8243,
         -0.3780,  0.1641,  0.8353,  0.2891,  0.5724,  0.7073,  0.5198,  0.3453,
          0.6275,  1.3536,  0.9255,  0.2190,  0.1360, -0.1189,  0.1219,  0.1236,
          0.2717,  0.0955,  0.1702, -0.1163,  0.7263,  0.1494,  0.6891,  0.0380,
          0.5498,  0.7095,  0.4331,  0.2890,  1.0562,  0.5840,  0.3896,  1.2156],
        [-0.3842, -0.6192, -1.3364, -1.4350, -1.6823, -1.5389, -1.3661, -0.4039,
         -0.0683, -1.5936,  6.3842, -0.1561, -0.2689, -0.1899,  0.2971, -0.2742,
         -0.0655, -0.3859, -0.4138, -0.2064, -0.7655, -0.4305, -0.4533,  0.4244,
         -0.0958, -0.4345, -0.0113, -0.2385,  1.0401,  0.1848, -0.2193, -0.9275,
         -0.3547, -0.4293, -0.1437, -0.3438, -0.9809, -1.2107, -0.8012, -1.3162]])

In [27]:
# Join prompt ids and response ids along the sequence dimension, then move
# the result to the first GPU.
tokens = torch.cat([query_tensors, response_tensors], dim=1)
tokens = tokens.to('cuda:0')
tokens

tensor([[20670,   257,  1256,   286,   670,  1816,   656,   262,  7110,    11,
           475,   356,   991,  1392,   257,  4735,  2126,   644,   284,  1607,
           422,   617, 28201,   582,   492, 11246,  3307,  1234,   287,   257,
           264, 16406, 26951,   422,   597,   640,   329, 42402,   492, 11246,
          6877,  4190,  8188,   492],
        [18357,   257,  1200,   286,   257,  1218,  1363,   290,   262,  3173,
           286,   262,  2975,  1422,   470,  1107,  4174,   284,   502,  7620,
            11,   314,   750,   281,  6672, 16835,   371,  2470,   284,   257,
          1295,   884,   355,   428,    11, 19167,  4152,   287,  9131,   357,
            40,   423,   281,  6565]], device='cuda:0')

In [28]:
# 1 wherever the token differs from the tokenizer's pad id, 0 on padding.
pad_token_id = trainer.tokenizer.pad_token_id
attention_mask = tokens.ne(pad_token_id).long().to('cuda:0')
attention_mask

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0')

In [29]:
# Run the full prompt+response sequences through the trainer's model and keep
# the logits; the same `outputs` object also supplies `.value` in a later cell.
outputs = trainer.model(tokens, attention_mask, return_dict=True)
logits = outputs.logits

In [30]:
# Check the logits shape (2 x 44 x 50257 for this batch).
logits.shape

torch.Size([2, 44, 50257])

In [31]:
# Value-head predictions with the last position dropped, so the length
# matches the shifted log-probabilities computed in the next cell.
values_pred = outputs.value[:, :-1]
values_pred.shape

torch.Size([2, 43])

In [32]:
from trlx.utils.modeling import logprobs_of_labels

# Pair each token with the logits of the position before it: the logits lose
# their last position and the labels lose their first.
# NOTE(review): going by the helper's name, the result holds log-probabilities
# of the label tokens rather than softmax probabilities — confirm against trlx.
logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])
logprobs.shape

torch.Size([2, 43])

In [33]:
# Keep only the positions that belong to the response. The window starts one
# position before the end of the prompt because the log-probabilities and
# values are already shifted by one relative to `tokens`.
start = query_tensors.shape[1] - 1
end = start + response_length
response_window = slice(start, end)

# Slice from the unsliced sources (`logits`, `outputs.value`) rather than from
# this cell's own outputs. The earlier form re-sliced `logprobs` and
# `values_pred` in place, so running the cell twice silently truncated them.
logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])[:, response_window]
values_pred = outputs.value[:, :-1][:, response_window]
mask = attention_mask[:, response_window]

In [34]:
# Confirm the sliced values now cover exactly the response positions.
values_pred.shape

torch.Size([2, 40])

In [35]:
# Put every tensor the loss consumes on the same device before the call.
logprobs = logprobs.to('cuda:0')
values = values_pred.to('cuda:0')
old_logprobs = old_logprobs.to('cuda:0')
old_values = old_values.to('cuda:0')
advantages = advantages.to('cuda:0')
returns = returns.to('cuda:0')
mask = mask.to('cuda:0')

# Call the PPO loss from the method config directly. `values` (the
# device-moved tensor) is passed here; previously `values_pred` was passed,
# which left the moved copy unused.
loss, stats = trainer.config.method.loss(
    logprobs=logprobs,
    values=values,
    old_logprobs=old_logprobs,
    old_values=old_values,
    advantages=advantages,
    returns=returns,
    mask=mask,
)

In [36]:
# Limit the new value predictions to within `cliprange_value` of the stored ones.
cliprange_value = trainer.config.method.cliprange_value
value_floor = old_values - cliprange_value
value_ceiling = old_values + cliprange_value
values_clipped = torch.clamp(values_pred, value_floor, value_ceiling)

# Sum of the mask entries, used as the normaliser for the masked means below.
n = mask.sum()

In [37]:
# Squared error of the raw and of the clipped value predictions against the returns.
vf_loss1 = (values - returns).pow(2)
vf_loss2 = (values_clipped - returns).pow(2)

# Value loss: masked mean of the larger of the two errors, halved.
larger_sq_err = torch.max(vf_loss1, vf_loss2)
vf_loss = 0.5 * torch.sum(larger_sq_err * mask) / n

# Masked fraction of positions where the clipped error is the larger one.
clipped_is_larger = (vf_loss2 > vf_loss1).float()
vf_clipfrac = torch.sum(clipped_is_larger * mask) / n

# Log of the new/old probability ratio, zeroed wherever the mask is 0,
# then exponentiated to get the ratio itself.
log_ratio = (logprobs - old_logprobs) * mask
ratio = log_ratio.exp()

In [38]:
# Unbiased KL-div estimates (`k3`). Ref: http://joschu.net/blog/kl-approx.html
# Computed under no_grad because it is only a diagnostic.
with torch.no_grad():
    approx_kl = ((ratio - 1) - log_ratio).mean()

# Clipped surrogate objective: compare the loss using the raw ratio with the
# loss using the ratio clamped to [1 - cliprange, 1 + cliprange].
cliprange = trainer.config.method.cliprange
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)

# Policy loss: masked mean of the larger of the two candidate losses.
larger_pg_loss = torch.max(pg_loss1, pg_loss2)
pg_loss = torch.sum(larger_pg_loss * mask) / n

# Masked fraction of positions where the clamped-ratio loss is the larger one.
clamped_is_larger = (pg_loss2 > pg_loss1).float()
pg_clipfrac = torch.sum(clamped_is_larger * mask) / n

In [39]:
# Total loss: policy loss plus the value loss weighted by `vf_coef`.
loss = pg_loss + trainer.config.method.vf_coef * vf_loss

In [40]:
# Show the value-loss weight used in the total loss above.
trainer.config.method.vf_coef

1