From 5e5db4589dd0bbdf635e60e6718aeec801a86ad6 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 31 Aug 2021 15:44:15 +0200
Subject: [PATCH] Adding more tests.

---
 tests/test_generation_utils.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py
index de986b696d8aa..e5d5d54bf16a8 100644
--- a/tests/test_generation_utils.py
+++ b/tests/test_generation_utils.py
@@ -1617,6 +1617,30 @@ def test_beam_search_warning_if_max_length_is_passed(self):
         # BeamSearchScorer max_length should not influence "real" max_length
         self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
 
+    def test_custom_stopping_criteria(self):
+        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
+        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
+        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+
+        stopping_criteria = StoppingCriteriaList()
+        stopping_criteria.append(MaxLengthCriteria(max_length=10))
+        # XXX: Used to fail with `stopping_criteria` being defined twice in call arguments
+        # https://github.com/huggingface/transformers/issues/12118
+        bart_model.generate(input_ids, stopping_criteria=stopping_criteria)
+
+    def test_custom_logits_processor(self):
+        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
+        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
+        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+
+        logits_processor = LogitsProcessorList()
+        logits_processor.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0))
+        # XXX: Used to fail with `logits_processor` being defined twice in call arguments
+        # https://github.com/huggingface/transformers/issues/12118
+        bart_model.generate(input_ids, logits_processor=logits_processor)
+
     def test_max_new_tokens(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
         bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
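
For context, a minimal standalone sketch of the usage the two new tests exercise: passing a user-built StoppingCriteriaList / LogitsProcessorList directly to `generate()`. This is illustration only, not part of the patch; it assumes the top-level `transformers` exports for these classes and substitutes a plain device check for the test suite's `torch_device` helper.

import torch
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    LogitsProcessorList,
    MaxLengthCriteria,
    MinLengthLogitsProcessor,
    StoppingCriteriaList,
)

device = "cuda" if torch.cuda.is_available() else "cpu"  # stand-in for `torch_device`
tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(device)
input_ids = tokenizer(
    "Justin Timberlake and Jessica Biel, welcome to parenthood.", return_tensors="pt"
).input_ids.to(device)

# User-supplied criteria/processors are handed to `generate()` alongside the
# defaults it builds internally; per issue #12118 this used to raise a
# TypeError about `stopping_criteria`/`logits_processor` being defined twice.
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=10)])
logits_processor = LogitsProcessorList([MinLengthLogitsProcessor(min_length=10, eos_token_id=0)])
output_ids = model.generate(
    input_ids,
    stopping_criteria=stopping_criteria,
    logits_processor=logits_processor,
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))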