Fix kwargs handling in generate_with_fallback (#29225)
* Fix generate_with_fallback **kwargs

* Change pop to get

* Delete keys from kwargs to prevent overriding generation_config

* Revert to passing kwargs by reference, but make a (shallow) copy

* dict -> copy.copy

* Add test_whisper_longform_multi_batch_beam
cifkao committed Apr 3, 2024
1 parent 851f253 commit bcd42c4
Showing 2 changed files with 63 additions and 2 deletions.
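For context on the commit messages above: `generate_with_fallback` receives the same `kwargs` dict for every roughly 30-second chunk of a long-form transcription, so the old `kwargs.pop("num_beams", 1)` removed the key on the first chunk and every later chunk silently fell back to a single beam. A minimal sketch of that hazard and of the shallow-copy fix (illustrative Python, not transformers code):

import copy

def chunk_step_old(kwargs):
    # pop() mutates the caller's dict: the first chunk sees 2 beams,
    # every later chunk falls back to the default of 1
    return kwargs.pop("num_beams", 1)

def chunk_step_fixed(kwargs):
    kwargs = copy.copy(kwargs)  # shallow copy, as in the commit
    return kwargs.get("num_beams", 1)

shared = {"num_beams": 2}
print([chunk_step_old(shared) for _ in range(3)])    # [2, 1, 1]
shared = {"num_beams": 2}
print([chunk_step_fixed(shared) for _ in range(3)])  # [2, 2, 2]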
10 changes: 8 additions & 2 deletions src/transformers/models/whisper/generation_whisper.py
@@ -755,6 +755,8 @@ def generate_with_fallback(
         do_condition_on_prev_tokens,
         kwargs,
     ):
+        kwargs = copy.copy(kwargs)
+
         # 6.6 Batch generate current chunk
         seek_sequence_list = [None for _ in range(cur_bsz)]
         seek_outputs_list = [None for _ in range(cur_bsz)]
@@ -769,8 +771,12 @@ def generate_with_fallback(
             generation_config.do_sample = temperature is not None and temperature > 0.0
 
             generation_config.temperature = temperature if generation_config.do_sample else 1.0
-            generation_config.num_beams = kwargs.pop("num_beams", 1) if not generation_config.do_sample else 1
+            generation_config.num_beams = kwargs.get("num_beams", 1) if not generation_config.do_sample else 1
 
+            generate_kwargs = copy.copy(kwargs)
+            for key in ["do_sample", "temperature", "num_beams"]:
+                if key in generate_kwargs:
+                    del generate_kwargs[key]
             seek_outputs = super().generate(
                 segment_input,
                 generation_config,
@@ -779,7 +785,7 @@ def generate_with_fallback(
                 prefix_allowed_tokens_fn,
                 synced_gpus,
                 decoder_input_ids=decoder_input_ids,
-                **kwargs,
+                **generate_kwargs,
             )
 
             # post-process sequence tokens and outputs to be in list form
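The `generate_kwargs` filtering above exists because `super().generate()` treats explicit keyword arguments as overrides for the `generation_config` it is given, so leaving `do_sample`, `temperature`, or `num_beams` in `kwargs` would undo the per-fallback values set just before the call. A rough sketch of that override pattern (hypothetical names, not the actual transformers internals):

class Config:
    num_beams = 1
    temperature = 1.0

def downstream_generate(config, **kwargs):
    # kwargs win over the config object, mirroring how generation
    # kwargs update a GenerationConfig downstream
    for key, value in kwargs.items():
        setattr(config, key, value)
    return config.num_beams

config = Config()
config.num_beams = 2                             # set by the fallback logic
print(downstream_generate(config))               # 2: config value survives
print(downstream_generate(config, num_beams=5))  # 5: stray kwarg overrides it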
55 changes: 55 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
@@ -1533,6 +1533,12 @@ def test_longform_generate_multi_batch_cond_prev(self):
 @require_torch
 @require_torchaudio
 class WhisperModelIntegrationTests(unittest.TestCase):
+    def setUp(self):
+        self._unpatched_generation_mixin_generate = transformers.GenerationMixin.generate
+
+    def tearDown(self):
+        transformers.GenerationMixin.generate = self._unpatched_generation_mixin_generate
+
     @cached_property
     def default_processor(self):
         return WhisperProcessor.from_pretrained("openai/whisper-base")
@@ -1544,6 +1550,16 @@ def _load_datasamples(self, num_samples):
 
         return [x["array"] for x in speech_samples]
 
+    def _patch_generation_mixin_generate(self, check_args_fn=None):
+        test = self
+
+        def generate(self, *args, **kwargs):
+            if check_args_fn is not None:
+                check_args_fn(*args, **kwargs)
+            return test._unpatched_generation_mixin_generate(self, *args, **kwargs)
+
+        transformers.GenerationMixin.generate = generate
+
     @slow
     def test_tiny_logits_librispeech(self):
         torch_device = "cpu"
@@ -2426,6 +2442,45 @@ def test_whisper_longform_single_batch_prev_cond(self):
 
         assert decoded == EXPECTED_TEXT
 
+    @slow
+    def test_whisper_longform_multi_batch_beam(self):
+        # fmt: off
+        EXPECTED_TEXT = [' A man said to the universe, Sir, I exist. Sweat-covered Brienne\'s body trickling into the titling cloth that was the only german he wore. The cut on his chest was still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators, rich trivialities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were, triggered his muscles into complete relaxation. Oily his heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied. The thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I\'m here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The 20s, he must have drawn his gun because the intruder said quickly, but that away, you\'re being a fool. Out, there was silence then, and still wondering, Breon was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man, with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing. Just thrust and parry, and victory to the stronger. Every man who entered the 20s had his own training tricks. There appeared to be an immediate association with the death trauma, as if the two were andextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported, except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the 20s, and death during the last round was, in some ways, easier than defeat. Breeding deeply, Breon\'s softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent\'s face when the man finally recognized his error. A wave of despair rolled out from our rogue. Breon sensed it and knew the fifth point was his. Then the powerful twist that\'s rested aside, in and under the guard, Mr. Quilter is the apostle of the middle classes, and we\'re glad to welcome his gospel. Nor is Mr. Quilter\'s manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and Rose beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton\'s work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell\'s pictures are a sort of up-gards and atom paintings, and Mason\'s exquisite idles are as national as a jingo poem. Mr. Burkett Foster\'s landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampooer and a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate in expression. From the general principles of art, Mr. Quilter writes with equal lucidity. Painting he tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. The customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer, near the fire, and the ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man. In remarks was pleasing courtesy and fellas of this grace that many faces are feeling. Only unfortunately his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tupper of painting. By Harry Quilter M.A. Because you are sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, accoing dove. He has gone, and gone for good," answered Polychrom, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. A little girl had been asleep, but she heard the wraps and opened the door. The king has fled and disgraced, and your friends are asking for you. I begged Ruggido long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn\'t work too hard, since Shaggy. He doesn\'t work at all. In fact, there is nothing he can do in these dominions, as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we\'ve turned Calico, whereas my brother now, inquired Shaggy. In the metal forest. Where is that? The metal forest is in the great domed cavern, the largest in all our dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I\'m quite sure he didn\'t. That\'s funny, remarked Betsy thoughtfully. I don\'t believe and knew any magic, or she\'d have worked it before. I do not know, confessed Shaggy. True, a great Calico. Calico went to the big gong and pounded on it, just as we\'re good to be used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the throne, wearing Regido\'s discarded ruby crown, and holding in his hand to scepter which Regido had so often thrown at his head.']
+        # fmt: on
+
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+        model = model.to(torch_device)
+
+        ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
+        one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32)
+
+        input_features = processor(one_audio, return_tensors="pt", truncation=False, padding="longest")[
+            "input_features"
+        ]
+        input_features = input_features.to(device=torch_device)
+
+        gen_kwargs = {
+            "return_timestamps": True,
+            "no_speech_threshold": 0.6,
+            "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
+            "num_beams": 2,
+            "compression_ratio_threshold": 1.35,
+            "condition_on_prev_tokens": True,
+            "logprob_threshold": -1.0,
+        }
+
+        def check_gen_kwargs(inputs, generation_config, *args, **kwargs):
+            assert generation_config.num_beams == gen_kwargs["num_beams"]
+
+        self._patch_generation_mixin_generate(check_args_fn=check_gen_kwargs)
+
+        torch.manual_seed(0)
+        result = model.generate(input_features, **gen_kwargs)
+        decoded = processor.batch_decode(result, skip_special_tokens=True)
+
+        assert decoded == EXPECTED_TEXT
+
     @slow
     def test_whisper_longform_multi_batch(self):
         # fmt: off
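Note how `check_gen_kwargs` reads `generation_config` positionally: as the first diff shows, `generate_with_fallback` calls `super().generate(segment_input, generation_config, ...)`, so the patched wrapper sees the config as its second positional argument on every inner call, once per chunk and per fallback temperature. A condensed sketch of what the callback observes (hypothetical values, assuming the fix is in place):

from types import SimpleNamespace

def check_gen_kwargs(inputs, generation_config, *args, **kwargs):
    # runs for every chunk and every fallback temperature; before the fix,
    # num_beams read 2 on the first inner call and 1 on all later ones
    assert generation_config.num_beams == 2

# the wrapper invokes it with the same positional layout as the real call
check_gen_kwargs(None, SimpleNamespace(num_beams=2))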
