Getting a "KeyError: 'labels'" exception while fine-tuning Whisper #33153
Comments
cc @ylacombe as well
@ArthurZucker I don't know how to fix it :(
I have studied the code carefully, and the only difference appears to be in one function. The code from the publication still runs, but the code for fine-tuning the model on two languages does not. Here is the only function that differs:

def prepare_dataset(batch):
    # `feature_extractor` and `tokenizer` are the Whisper feature extractor and tokenizer created earlier
    # load audio data (resampled from 48 to 16 kHz if the column was cast with Audio(sampling_rate=16000))
    audio = batch["audio"]
    # compute log-Mel input features from the input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # set the decoder prefix tokens for this example's language
    tokenizer.set_prefix_tokens(language=batch["language"], task="transcribe")
    # encode the target text to label ids
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

Honestly, I don't see how adding prefix tokens could break training...
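For context, in the fine-tuning guide the `labels` produced by `prepare_dataset` (which begin with the prefix tokens) are later padded by a data collator: padding is replaced with -100 so the loss ignores it, and a leading decoder start token is dropped when every sequence starts with it, because `WhisperForConditionalGeneration` prepends `decoder_start_token_id` itself when it shifts the labels right. The sketch below mimics that step on plain Python lists instead of tensors; `pad_labels` is a hypothetical helper, and 50258 is assumed to be the id of `<|startoftranscript|>`.

```python
from typing import List

IGNORE_INDEX = -100       # label value ignored by the cross-entropy loss
DECODER_START_ID = 50258  # assumed id of <|startoftranscript|>


def pad_labels(batch: List[List[int]], start_id: int = DECODER_START_ID) -> List[List[int]]:
    """Right-pad label sequences to equal length, using -100 for the padding."""
    # Drop the leading start token only if every sequence has it, because the
    # model prepends decoder_start_token_id itself when shifting labels right.
    if batch and all(seq and seq[0] == start_id for seq in batch):
        batch = [seq[1:] for seq in batch]
    max_len = max((len(seq) for seq in batch), default=0)
    return [seq + [IGNORE_INDEX] * (max_len - len(seq)) for seq in batch]


print(pad_labels([[50258, 50259, 50359, 50363, 11, 12],
                  [50258, 50263, 50359, 50363, 7]]))
# [[50259, 50359, 50363, 11, 12], [50263, 50359, 50363, 7, -100]]
```

The language token stays in the labels, so per-example prefixes survive the collation step unchanged.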
@ArthurZucker I've isolated the problem, and everything works now. But I still have a question: how do I properly instantiate an untrained model? At first I did it as follows:

model_name = "openai/whisper-tiny"
....
from transformers import AutoConfig, AutoModel
configuration = AutoConfig.from_pretrained(model_name)
model = AutoModel.from_config(configuration)

This is what was causing the problem. What is the correct way to do this for Whisper in order to train the model from scratch?
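A likely explanation, offered as an assumption rather than a confirmed diagnosis: with the default `remove_unused_columns=True`, `Trainer` keeps only the dataset columns whose names appear in the model's `forward` signature. `AutoModel` resolves a Whisper config to the bare `WhisperModel`, whose `forward` takes no `labels` argument, so the `labels` column would be dropped before the data collator reads it, surfacing as `KeyError: 'labels'`. The sketch below imitates that filtering with `inspect.signature`; both `forward` functions are hypothetical stand-ins, not the real signatures.

```python
import inspect
from typing import Callable, List


def bare_model_forward(input_features=None, decoder_input_ids=None, attention_mask=None):
    """Hypothetical stand-in for a base encoder-decoder forward without `labels`."""


def lm_head_model_forward(input_features=None, decoder_input_ids=None,
                          attention_mask=None, labels=None):
    """Hypothetical stand-in for a forward that accepts `labels` to compute a loss."""


def kept_columns(forward: Callable, columns: List[str]) -> List[str]:
    """Keep only the dataset columns whose names match a parameter of `forward`."""
    accepted = set(inspect.signature(forward).parameters)
    # Generic label column names are kept as well.
    accepted |= {"label", "label_ids"}
    return [c for c in columns if c in accepted]


columns = ["input_features", "labels"]
print(kept_columns(bare_model_forward, columns))     # ['input_features']
print(kept_columns(lm_head_model_forward, columns))  # ['input_features', 'labels']
```

If this is the cause, building a model class that accepts `labels` (for example `WhisperForConditionalGeneration`, or `AutoModelForSpeechSeq2Seq.from_config(configuration)`) would keep the column.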
Hey @blademoon, you can take inspiration from the code snippet in the Whisper documentation:

from transformers import WhisperConfig, WhisperModel

# Initializing a Whisper tiny style configuration
configuration = WhisperConfig()

# Initializing a model (with random weights) from the tiny style configuration
model = WhisperModel(configuration)

# Accessing the model configuration
configuration = model.config

In your case, you can first load the config from the repository id and then instantiate your model from the config:

from transformers import WhisperConfig, WhisperForConditionalGeneration

configuration = WhisperConfig.from_pretrained("openai/whisper-tiny")
# random weights, but with the language-modeling head that computes a loss from `labels`
model = WhisperForConditionalGeneration(configuration)

Hope it helps! cc @eustlb for visibility
@ylacombe Your solution works, but another problem arises. If we instantiate the model as you suggested:

from transformers import WhisperConfig, WhisperForConditionalGeneration

model_name = "openai/whisper-tiny"
configuration = WhisperConfig.from_pretrained(model_name)
model = WhisperForConditionalGeneration(configuration)

then configure it as in the notebook:

model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None

and start training, an exception occurs:

---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[36], line 1
----> 1 trainer.train()
File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:1929, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1926 try:
1927 # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
1928 hf_hub_utils.disable_progress_bars()
-> 1929 return inner_training_loop(
1930 args=args,
1931 resume_from_checkpoint=resume_from_checkpoint,
1932 trial=trial,
1933 ignore_keys_for_eval=ignore_keys_for_eval,
1934 )
1935 finally:
1936 hf_hub_utils.enable_progress_bars()
File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:2356, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2353 self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
2354 self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-> 2356 self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
2357 else:
2358 self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:2804, in Trainer._maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
2802 metrics = None
2803 if self.control.should_evaluate:
-> 2804 metrics = self._evaluate(trial, ignore_keys_for_eval)
2806 if self.control.should_save:
2807 self._save_checkpoint(model, trial, metrics=metrics)
File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:2761, in Trainer._evaluate(self, trial, ignore_keys_for_eval, skip_scheduler)
2760 def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
-> 2761 metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
2762 self._report_to_hp_search(trial, self.state.global_step, metrics)
2764 # Run delayed LR scheduler now that metrics are populated
File ~/.local/lib/python3.10/site-packages/transformers/trainer_seq2seq.py:180, in Seq2SeqTrainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix, **gen_kwargs)
178 self.gather_function = self.accelerator.gather
179 self._gen_kwargs = gen_kwargs
--> 180 return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:3666, in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
3663 start_time = time.time()
3665 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 3666 output = eval_loop(
3667 eval_dataloader,
3668 description="Evaluation",
3669 # No point gathering the predictions if there are no metrics, otherwise we defer to
3670 # self.args.prediction_loss_only
3671 prediction_loss_only=True if self.compute_metrics is None else None,
3672 ignore_keys=ignore_keys,
3673 metric_key_prefix=metric_key_prefix,
3674 )
3676 total_batch_size = self.args.eval_batch_size * self.args.world_size
3677 if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:3857, in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
3854 batch_size = observed_batch_size
3856 # Prediction step
-> 3857 losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
3858 main_input_name = getattr(self.model, "main_input_name", "input_ids")
3859 inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None
File ~/.local/lib/python3.10/site-packages/transformers/trainer_seq2seq.py:310, in Seq2SeqTrainer.prediction_step(self, model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs)
302 if (
303 "labels" in generation_inputs
304 and "decoder_input_ids" in generation_inputs
305 and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
306 ):
307 generation_inputs = {
308 k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
309 }
--> 310 generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
312 # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
313 # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
314 # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
315 if self.model.generation_config._from_model_config:
File ~/.local/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:542, in WhisperGenerationMixin.generate(self, input_features, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_timestamps, task, language, is_multilingual, prompt_ids, prompt_condition_type, condition_on_prev_tokens, temperature, compression_ratio_threshold, logprob_threshold, no_speech_threshold, num_segment_frames, attention_mask, time_precision, return_token_timestamps, return_segments, return_dict_in_generate, **kwargs)
536 self._set_prompt_condition_type(
537 generation_config=generation_config,
538 prompt_condition_type=prompt_condition_type,
539 )
541 # pass self.config for backward compatibility
--> 542 init_tokens = self._retrieve_init_tokens(
543 input_features,
544 batch_size=batch_size,
545 generation_config=generation_config,
546 config=self.config,
547 num_segment_frames=num_segment_frames,
548 kwargs=kwargs,
549 )
550 # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
551 # where the input ids are handled explicitly by the generate method
552 self._check_decoder_input_ids(kwargs=kwargs)
File ~/.local/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:1357, in WhisperGenerationMixin._retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs)
1355 if task is not None:
1356 if task in TASK_IDS:
-> 1357 init_tokens[i].append(generation_config.task_to_id[generation_config.task])
1358 task_id = generation_config.task_to_id[generation_config.task]
1360 # if task is defined it'll overwrite task ids that might have already been defined via the generation_config
AttributeError: 'GenerationConfig' object has no attribute 'task_to_id'
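The failing line reads `generation_config.task_to_id[generation_config.task]`. The generation config published with `openai/whisper-tiny` carries that mapping, but, as the error shows, the config attached to a model built only from a `WhisperConfig` does not. The much simplified sketch below reproduces the lookup; `SimpleGenerationConfig`, `init_tokens` and the token ids are hypothetical illustrations, not the library's code.

```python
from typing import Dict, List, Optional


class SimpleGenerationConfig:
    """Hypothetical minimal stand-in for a generation config object."""

    def __init__(self, decoder_start_token_id: int, task: Optional[str] = None,
                 task_to_id: Optional[Dict[str, int]] = None):
        self.decoder_start_token_id = decoder_start_token_id
        self.task = task
        # Only set the mapping when provided, mirroring a config that lacks it.
        if task_to_id is not None:
            self.task_to_id = task_to_id


def init_tokens(cfg: SimpleGenerationConfig) -> List[int]:
    """Build the decoder prefix: start token, then the task token if a task is set."""
    tokens = [cfg.decoder_start_token_id]
    if cfg.task is not None:
        # Raises AttributeError when the config has no `task_to_id` mapping.
        tokens.append(cfg.task_to_id[cfg.task])
    return tokens


ok = SimpleGenerationConfig(50258, "transcribe", {"transcribe": 50359, "translate": 50358})
print(init_tokens(ok))  # [50258, 50359]

broken = SimpleGenerationConfig(50258, "transcribe")
try:
    init_tokens(broken)
except AttributeError as err:
    print(err)  # the same kind of "no attribute 'task_to_id'" error as above
```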
@ylacombe If I comment out the `forced_decoder_ids` line:

from transformers import WhisperConfig, WhisperForConditionalGeneration

configuration = WhisperConfig.from_pretrained(model_name)
model = WhisperForConditionalGeneration(configuration)
model.generation_config.task = "transcribe"
# model.generation_config.forced_decoder_ids = None

I get a warning and then the same exception during training:

You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, 50259], [2, 50359], [3, 50363]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[16], line 1
----> 1 trainer.train()
... (same frames as in the traceback above, from Trainer.train through WhisperGenerationMixin.generate to generation_whisper.py:1357, in WhisperGenerationMixin._retrieve_init_tokens)
AttributeError: 'GenerationConfig' object has no attribute 'task_to_id'
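`forced_decoder_ids` is a list of `[position, token_id]` pairs, so the value in the warning forces decoder positions 1 to 3 to fixed tokens, which in the multilingual vocabulary are understood to be `<|en|>`, `<|transcribe|>` and `<|notimestamps|>` (an assumption worth checking against the tokenizer). Since `task=transcribe` also determines the task token, the two settings overlap, which is the conflict the warning reports. The hypothetical helper below expands such pairs into the resulting decoder prefix, assuming 50258 is the decoder start token at position 0.

```python
from typing import List, Sequence


def prefix_from_forced_ids(start_id: int, forced: Sequence[Sequence[int]]) -> List[int]:
    """Expand [position, token_id] pairs into a contiguous decoder prefix."""
    prefix = [start_id]  # position 0 is the decoder start token
    for position, token_id in sorted(forced):
        # Forced positions are expected to follow on directly from the prefix so far.
        if position != len(prefix):
            raise ValueError(f"gap before forced position {position}")
        prefix.append(token_id)
    return prefix


print(prefix_from_forced_ids(50258, [[1, 50259], [2, 50359], [3, 50363]]))
# [50258, 50259, 50359, 50363]
```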
Hey @blademoon, try loading the generation config from the same repository as well:

from transformers import WhisperConfig, WhisperForConditionalGeneration
from transformers.generation.configuration_utils import GenerationConfig

model_name = "openai/whisper-tiny"
configuration = WhisperConfig.from_pretrained(model_name)
# the published generation config carries Whisper-specific fields such as task_to_id
generation_config = GenerationConfig.from_pretrained(model_name)
model = WhisperForConditionalGeneration(configuration)
model.generation_config = generation_config

You might also want to look at the tiny model's generation config to check whether it fits your needs, and modify some parameters in the code above if necessary!
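Following the suggestion to inspect the tiny model's generation config, a quick pre-flight check can confirm that the Whisper-specific fields needed during evaluation are present before training starts. A minimal sketch over a plain mapping (in practice it could come from `model.generation_config.to_dict()`); the key list is an assumption based on the fields involved in this thread and is not exhaustive.

```python
from typing import Iterable, List, Mapping

# Assumed fields that Whisper generation relies on when `task`/`language` are set.
REQUIRED_KEYS = ("decoder_start_token_id", "task_to_id", "lang_to_id")


def missing_generation_keys(config: Mapping, required: Iterable[str] = REQUIRED_KEYS) -> List[str]:
    """Return the required keys that are absent (or None) in a generation config dict."""
    return [key for key in required if config.get(key) is None]


derived = {"decoder_start_token_id": 50258}
print(missing_generation_keys(derived))  # ['task_to_id', 'lang_to_id']
```

Failing fast on a non-empty result surfaces the problem at setup time rather than at the first evaluation step.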
@ylacombe Hello. I'm testing the variant you suggested. A small question: if the model with random weights is trained on two languages at once, will the proposed variant still work correctly?
Most probably, but you might want to guide the model a bit more by using language tokens.
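Using language tokens here means that each training example's labels start with the prefix for that example's own language, which is what calling `tokenizer.set_prefix_tokens(language=...)` per example in `prepare_dataset` achieves. A hypothetical pure-Python sketch for a Russian/English mix; the token ids are assumed values for the multilingual vocabulary (the English, transcribe and no-timestamps ids match the `forced_decoder_ids` in the warning quoted earlier), and `build_prefix` is illustrative only.

```python
from typing import Dict, List

# Assumed token ids for the multilingual Whisper vocabulary.
START_OF_TRANSCRIPT = 50258
LANG_TO_ID: Dict[str, int] = {"english": 50259, "russian": 50263}
TRANSCRIBE = 50359
NO_TIMESTAMPS = 50363


def build_prefix(language: str) -> List[int]:
    """Return <|startoftranscript|><|lang|><|transcribe|><|notimestamps|> as ids."""
    if language not in LANG_TO_ID:
        raise KeyError(f"unsupported language: {language}")
    return [START_OF_TRANSCRIPT, LANG_TO_ID[language], TRANSCRIBE, NO_TIMESTAMPS]


# Hypothetical (language, text token ids) pairs standing in for tokenized sentences.
examples = [("english", [11, 12]), ("russian", [21])]
labels = [build_prefix(lang) + text_ids for lang, text_ids in examples]
print(labels)
# [[50258, 50259, 50359, 50363, 11, 12], [50258, 50263, 50359, 50363, 21]]
```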
@ylacombe I'm using this guide from @sanchit-gandhi. Now I get the same warning, but training works.
@ylacombe It may be easier to understand if you can see the big picture.
If the training does what you want (i.e. transcribing Russian and English, I guess?), then you can ignore the warning!
@ylacombe OK, I'll check it out and come back with feedback. Thank you.
System Info
WSL 2.0, Ubuntu 22.04
transformers 4.44.2
Python 3.10
Who can help?
@sanchit-gandhi
Reproduction
I get an exception:
Expected behavior
Model training starts.