Generate: add Bloom fixes for contrastive search (huggingface#20213)
gante authored and Magnus Pierrau committed Dec 15, 2022
1 parent 5b7d944 commit db4291c
Showing 3 changed files with 72 additions and 27 deletions.
25 changes: 19 additions & 6 deletions src/transformers/generation/utils.py
@@ -672,22 +672,32 @@ def _expand_inputs_for_generation(

         return input_ids, model_kwargs

-    @staticmethod
-    def _extract_past_from_model_output(outputs: ModelOutput):
+    def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
         past = None
         if "past_key_values" in outputs:
             past = outputs.past_key_values
         elif "mems" in outputs:
             past = outputs.mems
         elif "past_buckets_states" in outputs:
             past = outputs.past_buckets_states
+
+        # Bloom fix: standardizes the cache format when requested
+        if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
+            batch_size = outputs.logits.shape[0]
+            past = self._convert_to_standard_cache(past, batch_size=batch_size)
         return past

     def _update_model_kwargs_for_generation(
-        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        standardize_cache_format: bool = False,
     ) -> Dict[str, Any]:
         # update past
-        model_kwargs["past"] = self._extract_past_from_model_output(outputs)
+        model_kwargs["past"] = self._extract_past_from_model_output(
+            outputs, standardize_cache_format=standardize_cache_format
+        )

         # update token_type_ids with last value
         if "token_type_ids" in model_kwargs:
@@ -1939,7 +1949,10 @@ def contrastive_search(
                 logit_for_next_step = outputs.logits[:, -1, :]

                 model_kwargs = self._update_model_kwargs_for_generation(
-                    outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                    outputs,
+                    model_kwargs,
+                    is_encoder_decoder=self.config.is_encoder_decoder,
+                    standardize_cache_format=True,
                 )

             # Expands model inputs top_k times, for batched forward passes (akin to beam search).
@@ -2001,7 +2014,7 @@ def contrastive_search(
             outputs = self(
                 **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
             )
-            next_past_key_values = self._extract_past_from_model_output(outputs)
+            next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)

             logits = outputs.logits[:, -1, :]
             # name is different for encoder-decoder and decoder-only models
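The net effect of the generation/utils.py changes is that callers can opt into a standardized cache layout: `_update_model_kwargs_for_generation` forwards `standardize_cache_format` to `_extract_past_from_model_output`, which only rewrites the cache when the model actually defines `_convert_to_standard_cache`. A minimal, self-contained sketch of that dispatch follows (toy classes and a plain dict in place of the real ModelOutput; every name here is illustrative, not the transformers API):

class ToyMixin:
    def _extract_past_from_model_output(self, outputs, standardize_cache_format=False):
        past = outputs.get("past_key_values")
        # Only models that define a converter (e.g. Bloom) get their cache rewritten.
        if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
            past = self._convert_to_standard_cache(past, batch_size=outputs["batch_size"])
        return past


class ToyBloomLike(ToyMixin):
    def _convert_to_standard_cache(self, past, batch_size):
        # Stand-in for the real reshape from [batch * heads, ...] to [batch, heads, ...].
        return tuple(("standardized", layer) for layer in past)


outputs = {"batch_size": 2, "past_key_values": ("layer0", "layer1")}
print(ToyMixin()._extract_past_from_model_output(outputs, standardize_cache_format=True))
# ('layer0', 'layer1') -- no converter, so the cache passes through untouched
print(ToyBloomLike()._extract_past_from_model_output(outputs, standardize_cache_format=True))
# (('standardized', 'layer0'), ('standardized', 'layer1'))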
68 changes: 51 additions & 17 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -506,6 +506,45 @@ def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
         if isinstance(module, BloomModel):
             module.gradient_checkpointing = value

+    @staticmethod
+    def _convert_to_standard_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
+        num_heads, ...]))
+        """
+        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        num_heads = batch_size_times_num_heads // batch_size
+        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
+    @staticmethod
+    def _convert_to_bloom_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
+        """
+        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size * num_heads
+        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+

 BLOOM_START_DOCSTRING = r"""
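For reference, a standalone sketch of the two conversions above using dummy tensors (requires torch; the batch size, head count, and dimensions are arbitrary example values). It also exercises the shape check that `prepare_inputs_for_generation` relies on below: in the standard layout the leading cache dimension equals the batch size, while in Bloom's fused layout it equals batch_size * num_heads.

import torch

batch_size, num_heads, head_dim, seq_length, num_layers = 2, 4, 8, 5, 3

# Bloom's native layout: key is [batch * heads, head_dim, seq], value is [batch * heads, seq, head_dim].
bloom_cache = tuple(
    (
        torch.randn(batch_size * num_heads, head_dim, seq_length),
        torch.randn(batch_size * num_heads, seq_length, head_dim),
    )
    for _ in range(num_layers)
)


def to_standard(past, batch_size):
    # Same reshape as _convert_to_standard_cache above.
    bnh, head_dim, seq_length = past[0][0].shape
    num_heads = bnh // batch_size
    return tuple(
        (k.view(batch_size, num_heads, head_dim, seq_length), v.view(batch_size, num_heads, seq_length, head_dim))
        for k, v in past
    )


def to_bloom(past):
    # Same reshape as _convert_to_bloom_cache above.
    batch_size, num_heads, head_dim, seq_length = past[0][0].shape
    return tuple(
        (k.view(batch_size * num_heads, head_dim, seq_length), v.view(batch_size * num_heads, seq_length, head_dim))
        for k, v in past
    )


standard_cache = to_standard(bloom_cache, batch_size)
assert standard_cache[0][0].shape == (batch_size, num_heads, head_dim, seq_length)

# The leading dimension is how the two layouts can be told apart for a given batch size
# (this distinction assumes num_heads > 1).
assert standard_cache[0][0].shape[0] == batch_size
assert bloom_cache[0][0].shape[0] == batch_size * num_heads

# The round trip is lossless: view() only reinterprets the shape.
round_trip = to_bloom(standard_cache)
assert all(torch.equal(a, b) for la, lb in zip(bloom_cache, round_trip) for a, b in zip(la, lb))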
@@ -811,6 +850,10 @@ def prepare_inputs_for_generation(
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)

+            # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
+            if past[0][0].shape[0] == input_ids.shape[0]:
+                past = self._convert_to_bloom_cache(past)
+
         return {
             "input_ids": input_ids,
             "past_key_values": past,
@@ -896,9 +939,8 @@ def forward(
             attentions=transformer_outputs.attentions,
         )

-    @staticmethod
     def _reorder_cache(
-        past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
@@ -907,28 +949,20 @@ def _reorder_cache(
         Output shares the same memory storage as `past`.
         """
-        batch_size_times_num_heads, head_dim, seq_length = past[0][0].shape
-        batch_size = len(beam_idx)
-        num_heads = batch_size_times_num_heads // batch_size
+        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))

         # Get a copy of `beam_idx` on all the devices where we need those indices.
         device_to_beam_idx = {
             past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
         }
-        # key: layer_past[0] [batch_size * num_heads, head_dim, seq_length]
-        # value: layer_past[1] [batch_size * num_heads, seq_length, head_dim]
-        return tuple(
+        reordered_past = tuple(
             (
-                layer_past[0]
-                .view(batch_size, num_heads, head_dim, seq_length)
-                .index_select(0, device_to_beam_idx[layer_past[0].device])
-                .view(batch_size_times_num_heads, head_dim, seq_length),
-                layer_past[1]
-                .view(batch_size, num_heads, seq_length, head_dim)
-                .index_select(0, device_to_beam_idx[layer_past[0].device])
-                .view(batch_size_times_num_heads, seq_length, head_dim),
+                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
+                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
             )
-            for layer_past in past
+            for layer_past in standardized_past
         )
+        return self._convert_to_bloom_cache(reordered_past)


 @add_start_docstrings(
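The rewritten `_reorder_cache` now round-trips through the standard layout so that `index_select` can act directly on the batch dimension. A compact sketch of the same idea on a single dummy cache layer (the shapes and the beam permutation are invented for the example; requires torch):

import torch

batch_size, num_heads, head_dim, seq_length = 3, 2, 4, 6

# One layer of a Bloom-style cache: [batch * heads, ...].
key = torch.randn(batch_size * num_heads, head_dim, seq_length)
value = torch.randn(batch_size * num_heads, seq_length, head_dim)

beam_idx = torch.tensor([2, 0, 1])  # example permutation of the batch

# Standardize, reorder along the batch dimension, then fold back, mirroring the method above.
key_std = key.view(batch_size, num_heads, head_dim, seq_length)
value_std = value.view(batch_size, num_heads, seq_length, head_dim)
key_reordered = key_std.index_select(0, beam_idx).view(batch_size * num_heads, head_dim, seq_length)
value_reordered = value_std.index_select(0, beam_idx).view(batch_size * num_heads, seq_length, head_dim)

# Sanity check: batch item 2 of the old cache is now batch item 0.
assert torch.equal(key_reordered.view(batch_size, num_heads, head_dim, seq_length)[0], key_std[2])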
6 changes: 2 additions & 4 deletions tests/generation/test_utils.py
@@ -1411,9 +1411,8 @@ def test_contrastive_generate(self):
         # check `generate()` and `contrastive_search()` are equal
         for model_class in self.all_generative_model_classes:

-            # TODO: Fix Bloom. Bloom fails because `past` has a different shape.
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
-            if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 return

             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
@@ -1434,9 +1433,8 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
     def test_contrastive_generate_dict_outputs_use_cache(self):
         for model_class in self.all_generative_model_classes:

-            # TODO: Fix Bloom. Bloom fails because `past` has a different shape.
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
-            if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 return

             # enable cache
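With Bloom removed from the skip list, contrastive search can be exercised end to end. A usage sketch (the checkpoint name and generation hyperparameters are examples only, and a transformers version containing this commit is assumed):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Any Bloom causal-LM checkpoint should behave the same way; this one is just small.
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

inputs = tokenizer("DeepMind Company is", return_tensors="pt")
# penalty_alpha > 0 together with top_k > 1 selects the contrastive search decoding path in generate().
outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))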
