feat: Sequential beam search (#26304)
Saibo-creator committed Jan 19, 2024
1 parent 268fc1f commit d4fc1eb
Showing 3 changed files with 235 additions and 43 deletions.
3 changes: 2 additions & 1 deletion src/transformers/generation/configuration_utils.py
@@ -200,7 +200,8 @@ class GenerationConfig(PushToHubMixin):
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
prompt, usually at the expense of poorer quality.
low_memory (`bool`, *optional*):
Switch to sequential topk for contrastive search to reduce peak memory. Used with contrastive search.
Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
Used with beam search and contrastive search.
> Parameters that define the output variables of `generate`
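For reference, the flag is set at call time through `generate`. A minimal usage sketch of the low-memory beam search path (not part of the diff; it mirrors the new integration test below and assumes the `gpt2` checkpoint purely for illustration):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("I", return_tensors="pt")

# low_memory=True routes beam search through the new sequential sub-batch path;
# the generated sequence should be identical to the default parallel path.
low = model.generate(**inputs, num_beams=5, max_new_tokens=40, early_stopping=True, low_memory=True)
high = model.generate(**inputs, num_beams=5, max_new_tokens=40, early_stopping=True, low_memory=False)
assert torch.equal(low, high)
print(tokenizer.decode(low[0], skip_special_tokens=True))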
229 changes: 187 additions & 42 deletions src/transformers/generation/utils.py
@@ -1558,6 +1558,7 @@ def generate(
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
sequential=generation_config.low_memory,
**model_kwargs,
)

@@ -1951,8 +1952,7 @@ def contrastive_search(
model_kwargs["past_key_values"] = tuple(new_key_values)

if sequential:
all_outputs = {key: [] for key in outputs} # defined in first loop iteration
all_last_hstates, all_hstates, all_logits = [], [], []
all_outputs = []
for i in range(top_k):
# compute the candidate tokens by the language model and collect their hidden_states
next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
@@ -1963,32 +1963,8 @@
output_hidden_states=True,
output_attentions=output_attentions,
)
for key in all_outputs:
all_outputs[key].append(outputs[key])

if self.config.is_encoder_decoder:
next_hidden = outputs.decoder_hidden_states[-1]
full_hidden_states = outputs.decoder_hidden_states

else:
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states

all_last_hstates.append(torch.squeeze(next_hidden, 0))
all_hstates.append(full_hidden_states)
all_logits.append(outputs.logits[:, -1, :])

# stack hidden states
next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0)
final_full_hstates = [0 for i in range(len(full_hidden_states))]
for layer in range(len(full_hidden_states)):
final_full_hstates[layer] = torch.stack(
[torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0
)
full_hidden_states = tuple(final_full_hstates)

# stack logits
logits = torch.cat(all_logits, dim=0)
all_outputs.append(outputs)
outputs = stack_model_outputs(all_outputs)

else:
# compute the candidate tokens by the language model and collect their hidden_states
@@ -2001,15 +1977,15 @@
output_hidden_states=True,
output_attentions=output_attentions,
)
# name is different for encoder-decoder and decoder-only models
if self.config.is_encoder_decoder:
next_hidden = outputs.decoder_hidden_states[-1]
full_hidden_states = outputs.decoder_hidden_states
else:
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states
# name is different for encoder-decoder and decoder-only models
if self.config.is_encoder_decoder:
next_hidden = outputs.decoder_hidden_states[-1]
full_hidden_states = outputs.decoder_hidden_states
else:
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states

logits = outputs.logits[:, -1, :]
logits = outputs.logits[:, -1, :]

context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)

@@ -2747,6 +2723,7 @@ def beam_search(
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: bool = False,
sequential: Optional[bool] = None,
**model_kwargs,
) -> Union[GenerateBeamOutput, torch.LongTensor]:
r"""
@@ -2792,6 +2769,10 @@
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
sequential (`bool`, defaults to `False`):
By default, beam search has `batch_size * num_beams` as effective batch size (see `beam_search()` for
more details). This flag will avoid parallelizing the beam search and will instead run beam search
sequentially.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
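To make the trade-off concrete, a small illustrative sketch (not part of the diff; the numbers and variable names are hypothetical) of how the effective batch is carved up when `sequential=True`:

# Parallel path: every beam of every sample goes through one forward pass.
batch_size = 2
num_beams = 5
batch_beam_size = batch_size * num_beams          # 10 rows per forward pass

# Sequential path: the 10 rows are split into sub-batches of `batch_size` rows,
# so the model runs num_beams smaller forward passes of 2 rows each, and the
# per-sub-batch outputs are stacked back together afterwards.
split_size = batch_size
num_forward_passes = batch_beam_size // split_size  # == num_beams == 5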
@@ -2858,6 +2839,7 @@
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
sequential = sequential if sequential is not None else self.generation_config.low_memory
if max_length is not None:
warnings.warn(
"`max_length` is deprecated in this function, use"
@@ -2932,12 +2914,39 @@

model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# if sequential is True, split the input to batches of batch_size and run sequentially
if sequential:
if any(
model_name in self.__class__.__name__.lower()
for model_name in ["fsmt", "reformer", "bloom", "ctrl", "gpt_bigcode", "transo_xl", "xlnet", "cpm"]
):
raise RuntimeError(
f"Currently generation for {self.__class__.__name__} is not supported "
f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature."
)

inputs_per_sub_batches = _split_model_inputs(
model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
)
outputs_per_sub_batch = [
self(
**inputs_per_sub_batch,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
for inputs_per_sub_batch in inputs_per_sub_batches
]

outputs = stack_model_outputs(outputs_per_sub_batch)

else: # Unchanged original behavior
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
@@ -4656,3 +4665,139 @@ def _ranking_fast(
contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K]
_, selected_idx = contrastive_score.max(dim=-1) # [B]
return selected_idx


def _split(data, full_batch_size: int, split_size: int = None):
"""
Takes care of three cases:
1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim
2. data is a tuple: e.g. hidden_states, attentions etc. Keep the tuple as it is and split each tensor in it and
return a list of tuples
3. data is a tuple of tuples, e.g. past_key_values. Keep the tuple as it is and split each tuple in it and
return a list of tuples of tuples
(see documentation of ModelOutput)
"""
if data is None:
return [None] * (full_batch_size // split_size)
if isinstance(data, torch.Tensor):
return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
elif isinstance(data, tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0], tuple):
return [
tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data)
for i in range(0, full_batch_size, split_size)
]

else:
return [
tuple(sub_tensor[i : i + split_size] for sub_tensor in data)
for i in range(0, full_batch_size, split_size)
]
else:
raise ValueError(f"Unexpected attribute type: {type(data)}")


def _split_model_inputs(
model_input: Union[ModelOutput, Dict], split_size: int, full_batch_size: int
) -> List[Union[ModelOutput, Dict]]:
"""
Split a ModelOutput object (or its subclasses) or Dict into a list of same-class objects based on a specified split
size. The input object is dict when it was prepared for forward pass and ModelOutput when it was returned from
previous forward pass.
"""
# Edge case: if model_input is None, return a list of Nones
# this happens with Whisper where encoder_outputs is None
if model_input is None:
return [model_input] * (full_batch_size // split_size)
# Infer the class from the object
model_output_cls = type(model_input)
if (full_batch_size % split_size) != 0:
raise ValueError("`full_batch_size` must be divisible by `split_size`")

if split_size > full_batch_size:
raise ValueError("`split_size` must be smaller or equal to `full_batch_size`")

# Helper function to split tensors or tuples of tensors

# Find all the dataclass fields (e.g., last_hidden_state, pooler_output etc.) and split them
keys = (
model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys()
)
# We only keep keys that are in the model_input
keys = [k for k in keys if k in model_input]
# Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a
# ModelOutput object.
# bool should not be split but replicated for each split
bool_keys = [k for k in keys if isinstance(model_input[k], bool)]
non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"]

# we split the tensors and tuples of tensors
data_split_list = [
{k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys}
for i in range(full_batch_size // split_size)
]
# bool values are the same and replicated for each split
bool_data = {k: model_input[k] for k in bool_keys}
# encoder_outputs is a ModelOutput object and should be split by its own
if "encoder_outputs" in model_input:
encoder_outputs_split = _split_model_inputs(model_input["encoder_outputs"], split_size, full_batch_size)
data_split_list = [
{**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list)
]

# Convert each dictionary in the list to an object of the inferred class
split_model_inputs: List[Union[ModelOutput, Dict]] = [
model_output_cls(**data_split, **bool_data) for data_split in data_split_list
]

return split_model_inputs
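Similarly, a sketch (illustrative values, not from the commit) of `_split_model_inputs` on a prepared-inputs dict: tensor values are sliced per sub-batch, while boolean flags such as `use_cache` are replicated into every split:

import torch
from transformers.generation.utils import _split_model_inputs  # private helper introduced in this commit

model_inputs = {
    "input_ids": torch.arange(30).reshape(10, 3),           # batch_beam_size = 10
    "attention_mask": torch.ones(10, 3, dtype=torch.long),
    "use_cache": True,                                       # bool: copied into every split
}
sub_batches = _split_model_inputs(model_inputs, split_size=2, full_batch_size=10)
assert len(sub_batches) == 5
assert sub_batches[0]["input_ids"].shape == (2, 3)
assert sub_batches[0]["use_cache"] is True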


def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
"""
Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
specific ModelOutput subclass from the list provided.
"""
if not model_outputs:
raise ValueError("Input list is empty.")

# Infer the class from the first object in the list
model_output_cls = type(model_outputs[0])

# Ensure all objects are of the same type
if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
raise ValueError("All elements in the list should be of the same type.")

# Helper function to concat tensors or tuples of tensors
def _concat(data):
"""
Reverse of `_split` function above.
"""
if any(data is None for data in data):
return None
if isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
elif isinstance(data[0], tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0][0], tuple):
return tuple(
tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
for i in range(len(data[0]))
)
else:
return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
elif isinstance(data[0], (int, float)):
# If the elements are integers or floats, return a tensor
return torch.tensor(data)
else:
raise ValueError(f"Unexpected attribute type: {type(data[0])}")

# Use a dictionary comprehension to gather attributes from all objects and concatenate them
concatenated_data = {
k: _concat([getattr(model_output, k) for model_output in model_outputs])
for k in model_output_cls.__dataclass_fields__.keys()
}

# Return a new object of the inferred class with the concatenated attributes
return model_output_cls(**concatenated_data)
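And a round-trip sketch (illustrative, not from the commit) showing `stack_model_outputs` undoing the split: concatenating per-sub-batch `ModelOutput`s along the batch dimension restores the full-batch shapes, and optional fields that were never populated stay `None`:

import torch
from transformers.generation.utils import stack_model_outputs  # helper introduced in this commit
from transformers.modeling_outputs import CausalLMOutput

# Two sub-batch outputs of 2 rows each, as the sequential path would produce them.
sub_outputs = [
    CausalLMOutput(logits=torch.randn(2, 7, 11)),
    CausalLMOutput(logits=torch.randn(2, 7, 11)),
]
stacked = stack_model_outputs(sub_outputs)
assert isinstance(stacked, CausalLMOutput)
assert stacked.logits.shape == (4, 7, 11)   # batch dim concatenated back
assert stacked.hidden_states is None        # absent fields remain None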
46 changes: 46 additions & 0 deletions tests/generation/test_utils.py
@@ -1539,6 +1539,39 @@ def test_contrastive_generate_low_memory(self):
)
self.assertListEqual(low_output.tolist(), high_output.tolist())

def test_beam_search_low_memory(self):
# Check that choosing 'low_memory' does not change the model output
for model_class in self.all_generative_model_classes:
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
self.skipTest("Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in [
"bloom",
"ctrl",
"gptbigcode",
"transo_xl",
"xlnet",
"cpm",
]
):
self.skipTest("May fix in the future: need model-specific fixes")
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=2)
# batch_size=1 is ok, but batch_size>1 will cause non-identical output

config.use_cache = True
config.is_decoder = True

# test output equality of low versus high memory
model = model_class(config).to(torch_device).eval()

low_output = model.generate(input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True)

high_output = model.generate(
input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=False
)
self.assertListEqual(low_output.tolist(), high_output.tolist())

@is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
def test_assisted_decoding_matches_greedy_search(self):
# This test ensures that the assisted generation does not introduce output changes over greedy search.
@@ -2766,6 +2799,19 @@ def test_transition_scores_group_beam_search_encoder_decoder(self):

self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))

def test_beam_search_low_memory(self):
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
model_inputs = tokenizer("I", return_tensors="pt")["input_ids"]

low_output = model.generate(model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=True)

high_output = model.generate(
model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=False
)
self.assertListEqual(low_output.tolist(), high_output.tolist())

@slow
def test_beam_search_example_integration(self):
# PT-only test: TF doesn't have a BeamSearchScorer
