[bug] support for large forward_batch_size in seq2seq models #100

Closed
examples/scripts/ppo-sentiment-t5-small.py (2 changes: 1 addition, 1 deletion)
@@ -45,7 +45,7 @@
     model_name="lvwerra/t5-imdb",
     learning_rate=5e-5,
     batch_size=256,
-    forward_batch_size=1
+    forward_batch_size=16,
 )
 # We then define the arguments to pass to the sentiment analysis pipeline.
 # We set `return_all_scores` to True to get the sentiment score for each token.
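The bump from forward_batch_size=1 to 16 is the user-facing effect of this PR: with padding handled, forward passes over the batch no longer have to run one sample at a time. A minimal sketch of the surrounding config call, assuming the script builds a standard TRL `PPOConfig` (the class name is inferred, not shown in this hunk; the values are from the diff):

    from trl import PPOConfig

    config = PPOConfig(
        model_name="lvwerra/t5-imdb",
        learning_rate=5e-5,
        batch_size=256,
        forward_batch_size=16,  # previously 1; larger values now work for seq2seq models
    )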
trl/models/modeling_value_head.py (3 changes: 3 additions, 0 deletions)
@@ -306,6 +306,9 @@ def forward(
         attention_mask=None,
         **kwargs,
     ):
+        if attention_mask is None:
+            attention_mask = input_ids.ne(self.pretrained_model.config.pad_token_id).float()
+
         base_model_output = self.pretrained_model(
             input_ids=input_ids,
             past_key_values=past_key_values,
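The added fallback builds an attention mask straight from the pad token whenever the caller does not pass one. A small sketch of what that expression produces, with made-up token ids:

    import torch

    pad_token_id = 0  # stand-in for self.pretrained_model.config.pad_token_id
    input_ids = torch.tensor([[0, 0, 17, 23, 5]])  # one left-padded sequence

    attention_mask = input_ids.ne(pad_token_id).float()
    # tensor([[0., 0., 1., 1., 1.]]) -- pad positions contribute nothing to attention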
trl/trainer/ppo_trainer.py (132 changes: 117 additions, 15 deletions)
@@ -15,7 +15,7 @@
 import random
 import time
 import warnings
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import datasets
 import torch
@@ -464,6 +464,7 @@ def batched_forward_pass(
                     "input_ids": input_ids,
                     "decoder_input_ids": decoder_input_ids,
                 }
+                model_input = decoder_input_ids
             else:
                 input_ids = self.data_collator([torch.cat([q, r]) for q, r in zip(query_batch, response_batch)])[
                     "input_ids"
@@ -472,17 +473,17 @@
                 input_kwargs = {
                     "input_ids": input_ids,
                 }
+                model_input = input_ids
 
             with torch.no_grad():
                 logits, _, v = self.model(**input_kwargs)
                 ref_logits, _, _ = self.ref_model(**input_kwargs)
 
-            if self.is_encoder_decoder:
-                logprobs = logprobs_from_logits(logits[:, :-1, :], decoder_input_ids[:, 1:])
-                ref_logprobs = logprobs_from_logits(ref_logits[:, :-1, :], decoder_input_ids[:, 1:])
-            else:
-                logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
-                ref_logprobs = logprobs_from_logits(ref_logits[:, :-1, :], input_ids[:, 1:])
+            # mask out pad tokens
+            logits, ref_logits, v = self.remove_padding(logits, ref_logits, v, input_kwargs)
+
+            logprobs = logprobs_from_logits(logits[:, :-1, :], model_input[:, 1:])
+            ref_logprobs = logprobs_from_logits(ref_logits[:, :-1, :], model_input[:, 1:])
 
             for j in range(fbs):
                 if self.is_encoder_decoder:
@@ -493,6 +494,15 @@
                     start = len(query_batch[j]) - 1
                     end = len(query_batch[j]) + len(response_batch[j]) - 1
 
+                if self.tokenizer.padding_side == "left":
+                    offset = len(torch.where(model_input[j] == self.tokenizer.pad_token_id)[0])
+                    start += offset
+                    end += offset
+                end = min(end, model_input[j].shape[-1] - 1)
+                # start, end = self._return_slice_index(model_input[j], query_batch[j])
+                # indexes = torch.where(model_input[j] != self.tokenizer.eos_token_id)[0]
+                # start, end = indexes[0], indexes[-1]
+
                 if len(logprobs[j, start:end]) < 2:
                     raise ValueError("Responses are too short. Make sure they are at least 4 tokens long.")
 
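With left padding, the real tokens sit at the right end of the tensor, so slice indices computed from raw query/response lengths have to be shifted right by the number of pad tokens, then clamped to the last valid index. A worked example of the arithmetic above, with invented ids:

    import torch

    pad_token_id = 0
    # query = [11, 12], response = [21, 22, 23], left-padded to length 8
    model_input_j = torch.tensor([0, 0, 0, 11, 12, 21, 22, 23])

    start = 2 - 1    # len(query) - 1 = 1
    end = 2 + 3 - 1  # len(query) + len(response) - 1 = 4

    offset = len(torch.where(model_input_j == pad_token_id)[0])  # 3 pad tokens
    start += offset  # 4: logprob position of the first response token
    end += offset    # 7
    end = min(end, model_input_j.shape[-1] - 1)  # still 7, already in range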
@@ -502,6 +512,40 @@
 
         return all_logprobs, all_ref_logprobs, all_values
 
+    def remove_padding(
+        self,
+        logits: torch.FloatTensor,
+        ref_logits: torch.FloatTensor,
+        values: torch.FloatTensor,
+        input_kwargs: Dict[str, torch.LongTensor],
+    ):
+        """
+        Remove padding from logits and values.
+
+        Args:
+            logits (`torch.FloatTensor`):
+                Logits from the model, shape (`batch_size`, `response_length`, `vocab_size`)
+            ref_logits (`torch.FloatTensor`):
+                Logits from the reference model, shape (`batch_size`, `response_length`, `vocab_size`)
+            values (`torch.FloatTensor`):
+                Values from the value head, shape (`batch_size`, `response_length`)
+            input_kwargs (`Dict[str, torch.LongTensor]`):
+                Input kwargs for the model
+        """
+        if "decoder_input_ids" in input_kwargs:
+            input_ids = input_kwargs["decoder_input_ids"]
+        else:
+            input_ids = input_kwargs["input_ids"]
+
+        if hasattr(self.tokenizer, "pad_token_id") and self.tokenizer.pad_token_id is not None:
+            mask = (input_ids != self.tokenizer.pad_token_id).float()
+        else:
+            # no pad token to mask on, so keep every position
+            mask = torch.ones_like(input_ids).float()
+
+        logits = logits * mask.unsqueeze(-1)
+        ref_logits = ref_logits * mask.unsqueeze(-1)
+        values = values * mask
+
+        return logits, ref_logits, values
+
     def train_minibatch(
         self,
         logprobs: torch.FloatTensor,
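For intuition, the body of the new `remove_padding` helper reduces to an elementwise mask multiply; a toy run with random stand-in tensors:

    import torch

    pad_token_id = 0
    input_ids = torch.tensor([[5, 7, 0, 0]])  # one right-padded sequence
    logits = torch.randn(1, 4, 3)  # (batch_size, seq_len, vocab_size)
    values = torch.randn(1, 4)     # (batch_size, seq_len)

    mask = (input_ids != pad_token_id).float()  # tensor([[1., 1., 0., 0.]])
    logits = logits * mask.unsqueeze(-1)  # logits at pad positions become 0
    values = values * mask                # values at pad positions become 0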
@@ -600,10 +644,6 @@ def loss(
             advantages_reversed.append(lastgaelam)
         advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
 
-        returns = advantages + values
-        advantages = whiten(advantages)
-        advantages = advantages.detach()
-
         input_kwargs = {
             "input_ids": model_input,
         }
@@ -613,8 +653,32 @@
             input_kwargs["decoder_input_ids"] = response
             model_input = response
 
+        if hasattr(self.tokenizer, "pad_token_id") and self.tokenizer.pad_token_id is not None:
+            attention_mask = model_input.ne(self.tokenizer.pad_token_id).float()
+        else:
+            attention_mask = torch.ones_like(model_input)
+
         logits, _, vpred = self.model(**input_kwargs)
 
+        # advantages = advantages * attention_mask
+        # values = values * attention_mask
+
+        logits = logits * attention_mask.unsqueeze(-1)
+        old_logprobs = old_logprobs * attention_mask.unsqueeze(-1)
+        values = values * attention_mask.unsqueeze(-1)
+
+        returns = advantages + values
+        advantages = whiten(advantages)
+        advantages = advantages.detach()
+
+        # logprob = logprobs_from_logits(logits[:, :-1, :], model_input[:, 1:])
+        # start, end = self._return_slice_index(model_input, query)
+        # indexes = torch.where(model_input != self.tokenizer.pad_token_id)[0]
+        # start, end = indexes[0], indexes[-1]
+
+        # logprob = logprob[:, start:end]
+        # vpred = vpred[:, start - 1 : end - 1]
+
         if self.is_encoder_decoder:
             logprob = logprobs_from_logits(logits[:, :-1, :], model_input[:, 1:])
             start, end = 1, response.shape[-1] - 1
@@ -628,15 +692,15 @@
 
         vf_losses1 = (vpred - returns) ** 2
         vf_losses2 = (vpredclipped - returns) ** 2
-        vf_loss = 0.5 * torch.mean(torch.max(vf_losses1, vf_losses2))
-        vf_clipfrac = torch.mean(torch.gt(vf_losses2, vf_losses1).double())
+        vf_loss = 0.5 * torch.mean(torch.max(vf_losses1, vf_losses2) * attention_mask.unsqueeze(-1))
+        vf_clipfrac = torch.mean(torch.gt(vf_losses2, vf_losses1).double() * attention_mask.unsqueeze(-1))
 
         ratio = torch.exp(logprob - old_logprobs)
         pg_losses = -advantages * ratio
         pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)
 
-        pg_loss = torch.mean(torch.max(pg_losses, pg_losses2))
-        pg_clipfrac = torch.mean(torch.gt(pg_losses2, pg_losses).double())
+        pg_loss = torch.mean(torch.max(pg_losses, pg_losses2) * attention_mask.unsqueeze(-1))
+        pg_clipfrac = torch.mean(torch.gt(pg_losses2, pg_losses).double() * attention_mask.unsqueeze(-1))
 
         loss = pg_loss + self.config.vf_coef * vf_loss
 
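One caveat worth flagging on the masked losses above: torch.mean over a mask-multiplied tensor still divides by every position, padding included, which shrinks the loss as padding grows. A masked mean that normalizes by the number of real tokens is the usual alternative (illustrative numbers):

    import torch

    losses = torch.tensor([2.0, 4.0, 0.0, 0.0])  # last two positions are padding
    mask = torch.tensor([1.0, 1.0, 0.0, 0.0])

    torch.mean(losses * mask)           # 1.5 -- divides by all 4 positions
    (losses * mask).sum() / mask.sum()  # 3.0 -- divides by the 2 real tokens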
@@ -668,6 +732,44 @@
         )
         return pg_loss, self.config.vf_coef * vf_loss, flatten_dict(stats)
 
+    def _return_slice_index(
+        self,
+        model_input,
+        query=None,
+    ):
+        """
+        Return the slice indices for the query and response. These are used to
+        slice the logits and values so that they match the length of the
+        response. For encoder-decoder models, the slice also skips the special
+        token that is prepended to the response.
+
+        Args:
+            model_input (`torch.LongTensor`):
+                Model input tensor - either `input_ids` or `decoder_input_ids`
+                depending on the model type
+            query (`torch.LongTensor`, optional):
+                Query input tensor. Only used for decoder models. Defaults to None.
+        """
+        if self.is_encoder_decoder:
+            # Enc-Dec models have a special token at the beginning of the response
+            # that we need to manually remove
+            indexes = torch.where(model_input != self.tokenizer.pad_token_id)[0]
+            start = 1
+            if len(indexes) > 1:
+                end = indexes[1].item() - 1
+            else:
+                end = model_input.shape[-1] - 1
+        else:
+            indexes = torch.where(model_input == self.tokenizer.pad_token_id)[0]
+            if hasattr(self.tokenizer, "padding_side") and self.tokenizer.padding_side == "left":
+                start = indexes[-1].item() + 1 if len(indexes) > 0 else 0
+                end = model_input.shape[-1] - 1
+            else:
+                # For decoder based models, always start at the last token of the query
+                start = max(query.shape) - 1
+                # .. and finish at the first occurrence of the `pad` token.
+                end = indexes[0].item() - 1 if len(indexes) > 0 else model_input.shape[-1] - 1
+        return start, end
+
     def record_step_stats(self, kl_coef: float, **data):
         """
         Record training step statistics.
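And a toy trace of `_return_slice_index` for the right-padded decoder-only branch (ids invented; the function itself is the one added above):

    import torch

    pad_token_id = 0
    # decoder-only, right padding: query = [11, 12], response = [21, 22], padded to 6
    model_input = torch.tensor([11, 12, 21, 22, 0, 0])
    query = torch.tensor([11, 12])

    indexes = torch.where(model_input == pad_token_id)[0]  # tensor([4, 5])
    start = max(query.shape) - 1  # 1: last token of the query
    end = indexes[0].item() - 1 if len(indexes) > 0 else model_input.shape[-1] - 1  # 3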