         GPT2ForSequenceClassification,
         GPT2LMHeadModel,
         GPT2Model,
+        GPT2Tokenizer,
     )
 
 
@@ -425,6 +426,50 @@ def test_gpt2_gradient_checkpointing(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
         self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
 
+    @slow
+    def test_batch_generation(self):
+        model = GPT2LMHeadModel.from_pretrained("gpt2")
+        model.to(torch_device)
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+        tokenizer.padding_side = "left"
+
+        # Define PAD Token = EOS Token = 50256
+        tokenizer.pad_token = tokenizer.eos_token
+        model.config.pad_token_id = model.config.eos_token_id
+
+        # use different length sentences to test batching
+        sentences = [
+            "Hello, my dog is a little",
+            "Today, I",
+        ]
+
+        inputs = tokenizer(sentences, return_tensors="pt", padding=True)
+
+        torch.manual_seed(0)
+        outputs = model.generate(
+            input_ids=inputs["input_ids"].to(torch_device),
+            attention_mask=inputs["attention_mask"].to(torch_device),
+        )
+
+        inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
+        output_non_padded = model.generate(input_ids=inputs_non_padded)
+
+        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+        inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
+        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
+        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+
+        expected_output_sentence = [
+            "Hello, my dog is a little bit of a mess. I'm not sure if he's going",
+            "Today, I'm going to be doing a lot of research on this. I",
+        ]
+        self.assertListEqual(expected_output_sentence, batch_out_sentence)
+        self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
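For context, the pattern this test exercises can be run on its own outside the test harness. The following is a minimal standalone sketch of left-padded batch generation with GPT-2, not the test itself: it assumes torch and transformers are installed, downloads the "gpt2" checkpoint, uses default greedy decoding, and the exact continuations may vary across library versions.

# Minimal sketch: batch generation with GPT-2 and left padding.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# GPT-2 has no pad token, so reuse EOS (id 50256) and pad on the left;
# with right padding the model would condition on pad tokens when
# continuing the shorter prompt.
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

sentences = ["Hello, my dog is a little", "Today, I"]
inputs = tokenizer(sentences, return_tensors="pt", padding=True)

# The attention mask tells the model to ignore the pad positions.
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))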