
Adding more tests.
Narsil committed Aug 31, 2021
1 parent bfc203b commit 5e5db45
Showing 1 changed file with 24 additions and 0 deletions.
tests/test_generation_utils.py (24 additions, 0 deletions)
@@ -1617,6 +1617,30 @@ def test_beam_search_warning_if_max_length_is_passed(self):
         # BeamSearchScorer max_length should not influence "real" max_length
         self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
 
+    def test_custom_stopping_criteria(self):
+        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
+        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
+        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+
+        stopping_criteria = StoppingCriteriaList()
+        stopping_criteria.append(MaxLengthCriteria(max_length=10))
+        # XXX: Used to fail with `stopping_criteria` being defined twice in call arguments
+        # https://github.com/huggingface/transformers/issues/12118
+        bart_model.generate(input_ids, stopping_criteria=stopping_criteria)
+
+    def test_custom_logits_processor(self):
+        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
+        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
+        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+
+        logits_processor = LogitsProcessorList()
+        logits_processor.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0))
+        # XXX: Used to fail with `logits_processor` being defined twice in call arguments
+        # https://github.com/huggingface/transformers/issues/12118
+        bart_model.generate(input_ids, logits_processor=logits_processor)
+
     def test_max_new_tokens(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
         bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
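As context for test_custom_stopping_criteria: a minimal sketch, not part of this commit, of the hook the test exercises. StopAfterNTokens is a hypothetical example class; it assumes the StoppingCriteria interface of the transformers version current at this commit, where __call__ receives the running input_ids and the scores and returns a bool.

from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)


class StopAfterNTokens(StoppingCriteria):
    # Hypothetical example criterion: stop once the sequence holds n tokens.
    def __init__(self, n):
        self.n = n

    def __call__(self, input_ids, scores, **kwargs):
        return input_ids.shape[-1] >= self.n


tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random")
input_ids = tokenizer(
    "Justin Timberlake and Jessica Biel, welcome to parenthood.", return_tensors="pt"
).input_ids
# Passing custom criteria used to fail with `stopping_criteria` being defined
# twice in call arguments (issue 12118); the test above guards that regression.
model.generate(input_ids, stopping_criteria=StoppingCriteriaList([StopAfterNTokens(10)]))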
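Similarly, for test_custom_logits_processor: a hedged sketch, also not part of this commit, of a user-defined processor. SuppressTokenLogitsProcessor is a hypothetical name; it assumes the LogitsProcessor interface where __call__ maps (input_ids, scores) to adjusted scores.

from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    LogitsProcessor,
    LogitsProcessorList,
)


class SuppressTokenLogitsProcessor(LogitsProcessor):
    # Hypothetical example processor: force one token id to -inf so it is never generated.
    def __init__(self, token_id):
        self.token_id = token_id

    def __call__(self, input_ids, scores):
        scores[:, self.token_id] = -float("inf")
        return scores


tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random")
input_ids = tokenizer(
    "Justin Timberlake and Jessica Biel, welcome to parenthood.", return_tensors="pt"
).input_ids
# As with stopping criteria, the user-supplied list is combined with the
# internally built `logits_processor` rather than colliding with it.
model.generate(input_ids, logits_processor=LogitsProcessorList([SuppressTokenLogitsProcessor(0)]))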
