Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add BartConfig.force_bos_token_to_be_generated #6526

Merged
merged 4 commits into from
Aug 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/transformers/configuration_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@
for SequenceClassification
is_encoder_decoder (:obj:`int`, optional, defaults to True):
True
force_bos_token_to_be_generated (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force BOS token to be generated at step 1 (after ``decoder_start_token_id``), only true for `bart-large-cnn`.

"""

Expand Down Expand Up @@ -137,6 +139,7 @@ def __init__(
normalize_embedding=True,
static_position_embeddings=False,
add_bias_logits=False,
force_bos_token_to_be_generated=False,
**common_kwargs
):
r"""
Expand Down Expand Up @@ -195,6 +198,8 @@ def __init__(
# pos embedding offset
self.extra_pos_embeddings = self.pad_token_id + 1

self.force_bos_token_to_be_generated = force_bos_token_to_be_generated

@property
def num_attention_heads(self) -> int:
return self.encoder_attention_heads
Expand Down
18 changes: 5 additions & 13 deletions src/transformers/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,23 +1073,15 @@ def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask,
}

def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == 1:
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
self._force_token_ids_generation(logits, self.config.bos_token_id)
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_ids_generation(logits, self.config.eos_token_id)
return logits

def _force_token_ids_generation(self, scores, token_ids) -> None:
"""force one of token_ids to be generated by setting prob of all other tokens to 0"""
if isinstance(token_ids, int):
token_ids = [token_ids]
all_but_token_ids_mask = torch.tensor(
[x for x in range(self.config.vocab_size) if x not in token_ids],
dtype=torch.long,
device=next(self.parameters()).device,
)
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
scores[:, all_but_token_ids_mask] = -float("inf")
def _force_token_ids_generation(self, scores, token_id) -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just cleanup

"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")

@staticmethod
def _reorder_cache(past, beam_idx):
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/modeling_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class MarianMTModel(BartForConditionalGeneration):
"""

def adjust_logits_during_generation(self, logits, cur_len, max_length):
logits[:, self.config.pad_token_id] = float("-inf")
logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token.
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_ids_generation(logits, self.config.eos_token_id)
return logits
2 changes: 1 addition & 1 deletion tests/test_modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def test_xsum_summarization_same_as_fairseq(self):
self.assertFalse(model.config.is_valid_mbart())
tok = BartTokenizer.from_pretrained("facebook/bart-large")

EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state."
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

summaries change a bit for xsum, but ROUGE increases.

dct = tok.batch_encode_plus(
[PGE_ARTICLE], max_length=1024, padding="max_length", truncation=True, return_tensors="pt",
).to(torch_device)
Expand Down