In [1]:
# Automatically reload edited modules (e.g. minichatgpt) before each cell executes,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [7]:
import torch

from minichatgpt.experiments.imdb import config, sent_kwargs
from minichatgpt import PPOTrainer, Lab

In [3]:
# To keep this demonstration fast, the batch size is temporarily reduced from 256 to 4.
# NOTE: this mutates the shared `config` object imported from minichatgpt.experiments.imdb in place.
batch_size = 4
config.batch_size = batch_size
# Half of the batch per forward chunk, clamped to at least 1 so that batch_size == 1
# does not yield a zero-sized forward batch.
config.forward_batch_size = max(1, batch_size // 2)

In [4]:
# Assemble the experiment from the shared config. The `lab` calls stay in this order
# because later initialisers may rely on state set up by earlier ones.
lab = Lab(config)

dataset = lab.build_dataset(
    dataset_name="imdb",
    input_min_text_length=2,
    input_max_text_length=8,
)

new_policy, old_policy, tokenizer = lab.init_policies_tokenizer()

lab.set_generation_config(
    do_sample=True,
    output_min_length=4,
    output_max_length=16,
    # Reuse the tokenizer's EOS id as the padding id during generation.
    pad_token_id=tokenizer.eos_token_id,
)

ppo_trainer = lab.init_ppo_trainer(config, new_policy, old_policy, tokenizer, dataset)
reward_model = lab.init_reward_model()

Found cached dataset imdb (/Users/carson/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)
Loading cached processed dataset at /Users/carson/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-2bd6a5d7d39a840d.arrow
Loading cached processed dataset at /Users/carson/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-2ecf25d24c93f132.arrow


In [16]:
# Take a single batch from the trainer's dataloader for this walkthrough.
batch = next(iter(ppo_trainer.dataloader))
queries = batch['input_ids']

#### Get response from gpt2
responses = []
for query in queries:
    gen_len = lab.output_length_sampler()
    # Use a per-call copy instead of mutating the shared lab.generation_kwargs dict,
    # so no stale max_new_tokens value leaks into later cells.
    gen_kwargs = {**lab.generation_kwargs, "max_new_tokens": gen_len}
    response = ppo_trainer.generate(query, **gen_kwargs)
    # Strip the prompt instead of keeping the last gen_len tokens. If generation stops
    # early (EOS), [-gen_len:] would pull prompt tokens into the response.
    # NOTE(review): assumes generate() returns prompt + continuation (trl's default) — confirm.
    responses.append(response.squeeze()[len(query):])

batch['response'] = [tokenizer.decode(r) for r in responses]

#### Compute sentiment score
texts = [q + r for q, r in zip(batch['query'], batch['response'])]
pipe_outputs = lab.reward_model(texts, **sent_kwargs)
# One scalar reward per sample: the score of the second label in each pipeline output.
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
rewards

[tensor(-0.9046), tensor(1.5498), tensor(0.3456), tensor(2.5015)]


In [17]:
# Walk through the trainer's internal steps one at a time on the batch from the previous cell.
# NOTE: `rewards` enters this cell as the per-sample scores and is overwritten with the
# per-response reward tensors, so this cell is not idempotent: re-run the previous cell
# before re-running this one.
queries, responses, scores = ppo_trainer._step_safety_checker(batch_size, queries, responses, rewards)
logprobs, ref_logprobs, values = ppo_trainer.batched_forward_pass(queries, responses)
rewards, non_score_reward = ppo_trainer.compute_rewards(scores, logprobs, ref_logprobs)
rewards

[tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.9046]), tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        1.5498]), tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        -0.0000, 0.3456]), tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        -0.0000, 2.5015])]


In [22]:
# Run one minibatch update on a single sample (index 0) and inspect the returned stats.
# unsqueeze(0) adds a leading batch dimension of size 1 to each per-sample tensor.
# NOTE(review): the stats include 'time/ppo/optimizer_step', so this presumably updates the
# policy; re-running the cell trains again on the same sample.
idx = 0
train_stats = ppo_trainer.train_minibatch(
    logprobs[idx].unsqueeze(0),
    values[idx].unsqueeze(0),
    rewards[idx].unsqueeze(0),
    queries[idx].unsqueeze(0),
    responses[idx].unsqueeze(0),
    torch.cat([queries[idx], responses[idx]]).unsqueeze(0),
)
train_stats

{'loss/policy': tensor(-2.3842e-08, grad_fn=<MeanBackward0>),
 'loss/value': tensor(2.1323, grad_fn=<MulBackward0>),
 'loss/total': tensor(0.2132, grad_fn=<AddBackward0>),
 'policy/entropy': tensor(4.6713, grad_fn=<MeanBackward0>),
 'policy/approxkl': tensor(0., grad_fn=<MulBackward0>),
 'policy/policykl': tensor(0., grad_fn=<MeanBackward0>),
 'policy/clipfrac': tensor(0., dtype=torch.float64),
 'policy/advantages': tensor([[ 0.5925,  0.4375,  1.0873, -1.1887, -0.9287]]),
 'policy/advantages_mean': tensor(2.3842e-08),
 'policy/ratio': tensor([[1., 1., 1., 1., 1.]], grad_fn=<ExpBackward0>),
 'returns/mean': tensor(-0.8628),
 'returns/var': tensor(0.0058),
 'val/vpred': tensor(-2.5561, grad_fn=<MeanBackward0>),
 'val/error': tensor(3.2110, grad_fn=<MeanBackward0>),
 'val/clipfrac': tensor(0.4000, dtype=torch.float64),
 'val/mean': tensor(-1.3650),
 'val/var': tensor(3.8172),
 'time/ppo/optimizer_step': tensor([1.1034])}

In [18]:
# Inspect the trainer's configured `ppo_epochs` setting; presumably the number of optimisation
# passes a full PPO step would make (the cell above ran only a single minibatch update).
ppo_trainer.config.ppo_epochs

4