
Generation / FIX: Fix multi-device generation #30746

Merged · 11 commits · May 13, 2024
25 changes: 17 additions & 8 deletions src/transformers/generation/utils.py
@@ -476,6 +476,7 @@ def _prepare_attention_mask_for_generation(
         )
         can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
         attention_mask_from_padding = inputs.ne(pad_token_id).long()
+
Collaborator comment: weird that this is changed 😄

         attention_mask = (
             attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
         )
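
For context on this hunk: the assignment above is an arithmetic select between the padding-derived mask and the default all-ones mask. A minimal self-contained sketch of that pattern, with made-up token ids and inputs (not part of this PR; variable names mirror the diff):

```python
import torch

# Made-up ids for illustration: pad token 0, a single eos token 2.
pad_token_id = torch.tensor(0)
eos_token_id = torch.tensor([2])
inputs = torch.tensor([[5, 6, 0], [7, 8, 9]])
default_attention_mask = torch.ones_like(inputs)

# The mask can only be inferred from padding when the pad token actually appears
# in the inputs and is distinct from every eos token.
is_pad_token_in_inputs = torch.isin(inputs, pad_token_id).any()
is_pad_token_not_equal_to_eos_token_id = ~torch.isin(pad_token_id, eos_token_id).any()
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
attention_mask_from_padding = inputs.ne(pad_token_id).long()

# Elementwise select: inferred mask where inference is valid, all-ones otherwise.
attention_mask = (
    attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
)
print(attention_mask)  # tensor([[1, 1, 0], [1, 1, 1]])
```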
@@ -1340,7 +1341,10 @@ def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache:
         return self._static_cache

     def _prepare_special_tokens(
-        self, generation_config: GenerationConfig, kwargs_has_attention_mask: Optional[bool] = None
+        self,
+        generation_config: GenerationConfig,
+        kwargs_has_attention_mask: Optional[bool] = None,
+        device: Optional[Union[torch.device, str]] = None,
     ):
         """
         Prepares the special tokens for generation, overwriting the generation config with their processed versions
@@ -1352,15 +1356,18 @@ def _prepare_special_tokens(
"""

# Convert special tokens to tensors (if they exist)
def _tensor_or_none(token):
def _tensor_or_none(token, device=None):
if device is None:
device = self.device

if token is None or isinstance(token, torch.Tensor):
return token
return torch.tensor(token, device=self.device, dtype=torch.long)
return torch.tensor(token, device=device, dtype=torch.long)

bos_token_id = _tensor_or_none(generation_config.bos_token_id)
eos_token_id = _tensor_or_none(generation_config.eos_token_id)
pad_token_id = _tensor_or_none(generation_config.pad_token_id)
decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id)
bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device)
eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device)
pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device)
decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id

# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
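
The core of the fix is visible here: `_tensor_or_none` now accepts an explicit `device` and only falls back to `self.device` when none is passed, so the special-token tensors land on the same device as the inputs. A standalone sketch of that pattern, where the function name and the `fallback` argument are illustrative stand-ins (the real helper is a nested method that uses `self.device`):

```python
from typing import Optional, Union

import torch


def tensor_or_none(token, device: Optional[Union[torch.device, str]] = None, fallback: str = "cpu"):
    # Pass tensors/None through untouched; otherwise build a long tensor on the
    # requested device. `fallback` stands in for `self.device` in the real method.
    if device is None:
        device = fallback
    if token is None or isinstance(token, torch.Tensor):
        return token
    return torch.tensor(token, device=device, dtype=torch.long)


# With a model sharded via `device_map`, `self.device` may differ from the device
# holding `input_ids`; creating the special tokens on the inputs' device avoids
# cross-device comparisons later (eos checks, pad masking).
input_ids = torch.tensor([[1, 2, 3]])  # pretend this lives on e.g. "cuda:1"
eos_token_id = tensor_or_none(2, device=input_ids.device)
assert eos_token_id.device == input_ids.device
```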
@@ -1511,14 +1518,16 @@ def generate(
         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs
         kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
-        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask)

         # 3. Define model inputs
         inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
             inputs, generation_config.bos_token_id, model_kwargs
         )
         batch_size = inputs_tensor.shape[0]

+        device = inputs_tensor.device
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
         # decoder-only models must use left-padding for batched generation.
         if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
             # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
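
A hedged end-to-end sketch of the scenario this PR addresses: a checkpoint sharded across devices with `device_map="auto"`. The model id, prompt, and generation settings below are placeholders, not taken from the PR:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # placeholder checkpoint; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Requires `accelerate`; with several GPUs the layers are spread across them.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

# Before this change, the special-token tensors were created on `self.device`,
# which can differ from the device holding `input_ids` when the model spans
# multiple GPUs; now they follow the inputs' device.
inputs = tokenizer("Multi-device generation test:", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```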