
Obtaining an Exception "KeyError: 'labels'" while fine-tuning Whisper #33153

Open
blademoon opened this issue Aug 27, 2024 · 14 comments
Labels
bug · Good Second Issue (issues that are more difficult than "Good First" issues - give it a try if you want!)

Comments

@blademoon

System Info

WSL 2.0, Ubuntu 22.04
transformers 4.44.2
Python 3.10

Who can help?

@sanchit-gandhi

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Use the example from https://huggingface.co/blog/fine-tune-whisper as is, without modification.

This raises the following exception:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[22], line 1
----> 1 trainer.train()

File [~/.local/lib/python3.10/site-packages/transformers/trainer.py:1929](http://127.0.0.1:8888/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer.py#line=1928), in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1926 try:
   1927     # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
   1928     hf_hub_utils.disable_progress_bars()
-> 1929     return inner_training_loop(
   1930         args=args,
   1931         resume_from_checkpoint=resume_from_checkpoint,
   1932         trial=trial,
   1933         ignore_keys_for_eval=ignore_keys_for_eval,
   1934     )
   1935 finally:
   1936     hf_hub_utils.enable_progress_bars()

File [~/.local/lib/python3.10/site-packages/transformers/trainer.py:2236](http://127.0.0.1:8888/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer.py#line=2235), in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2233     rng_to_sync = True
   2235 step = -1
-> 2236 for step, inputs in enumerate(epoch_iterator):
   2237     total_batched_samples += 1
   2239     if self.args.include_num_input_tokens_seen:

File [~/.local/lib/python3.10/site-packages/accelerate/data_loader.py:454](http://127.0.0.1:8888/home/artyom/.local/lib/python3.10/site-packages/accelerate/data_loader.py#line=453), in DataLoaderShard.__iter__(self)
    452 # We iterate one batch ahead to check when we are at the end
    453 try:
--> 454     current_batch = next(dataloader_iter)
    455 except StopIteration:
    456     yield

File [~/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:631](http://127.0.0.1:8888/home/artyom/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py#line=630), in _BaseDataLoaderIter.__next__(self)
    628 if self._sampler_iter is None:
    629     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    630     self._reset()  # type: ignore[call-arg]
--> 631 data = self._next_data()
    632 self._num_yielded += 1
    633 if self._dataset_kind == _DatasetKind.Iterable and \
    634         self._IterableDataset_len_called is not None and \
    635         self._num_yielded > self._IterableDataset_len_called:

File ~/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:675, in _SingleProcessDataLoaderIter._next_data(self)
    673 def _next_data(self):
    674     index = self._next_index()  # may raise StopIteration
--> 675     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    676     if self._pin_memory:
    677         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File [~/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:54](http://127.0.0.1:8888/home/artyom/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py#line=53), in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     52 else:
     53     data = self.dataset[possibly_batched_index]
---> 54 return self.collate_fn(data)

Cell In[17], line 18, in DataCollatorSpeechSeq2SeqWithPadding.__call__(self, features)
     15 batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
     17 # get the tokenized label sequences
---> 18 label_features = [{"input_ids": feature["labels"]} for feature in features]
     19 # pad the labels to max length
     20 labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

Cell In[17], line 18, in <listcomp>(.0)
     15 batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
     17 # get the tokenized label sequences
---> 18 label_features = [{"input_ids": feature["labels"]} for feature in features]
     19 # pad the labels to max length
     20 labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

KeyError: 'labels'
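
For reference, the collator named in the traceback (Cell In[17]) is the one from the blog post; a rough sketch of it is below (the exact notebook version may differ slightly). The failing line shows that every dataset example must already contain a "labels" key, which prepare_dataset is responsible for adding.

from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they need different padding methods;
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences -- this is the line that raises
        # KeyError when the dataset examples carry no "labels" key
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 so it is ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if a bos token was prepended during tokenization, cut it here;
        # it is re-appended later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch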

Expected behavior

Model training starts normally.

@blademoon added the bug label Aug 27, 2024
@ArthurZucker
Collaborator

cc @ylacombe as well
Also feel free to open a PR for a fix!

@ArthurZucker added the Good Second Issue label Aug 28, 2024
@blademoon
Author

@ArthurZucker Unfortunately, I don't know how to fix it.

@blademoon
Author

blademoon commented Aug 29, 2024

I have carefully studied the code; apparently the difference is in only one function. The code from the blog post still runs, but my code for fine-tuning the model on two languages does not.

Here is the only function that is different:

def prepare_dataset(batch):
    # load and resample audio data from 48kHz to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from the input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # set the prefix tokens to the language of the target text
    tokenizer.set_prefix_tokens(language=batch["language"], task="transcribe")

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

Honestly, I don't see how adding prefix tokens could break training...
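
As a quick sanity check (sketch below; "common_voice" is the DatasetDict variable name used in the notebook), the mapped dataset should expose both columns the collator reads:

# Sanity-check sketch: confirm the mapped dataset carries what the collator expects.
print(common_voice["train"].column_names)  # should include "input_features" and "labels"

sample = common_voice["train"][0]
assert "input_features" in sample
assert "labels" in sample, "prepare_dataset did not add a 'labels' column"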

@blademoon
Author

blademoon commented Aug 29, 2024

@ArthurZucker I've isolated the problem and everything works now. But I still have a question: how do I properly instantiate an untrained model?

At first I did it as follows:

model_name = "openai/whisper-tiny"
....
from transformers import AutoConfig, AutoModel

configuration = AutoConfig.from_pretrained(model_name)
model = AutoModel.from_config(configuration)

This is what was causing the problem.

What is the correct way to do this for Whisper in order to train the model from scratch?

@ylacombe
Contributor

ylacombe commented Sep 3, 2024

Hey @blademoon, you can get inspiration from the code snippet provided in the Whisper documentation:

from transformers import WhisperConfig, WhisperModel

# Initializing a Whisper tiny style configuration
configuration = WhisperConfig()

# Initializing a model (with random weights) from the tiny style configuration
model = WhisperModel(configuration)

# Accessing the model configuration
configuration = model.config

In your own case, you can first load the config from the repository id and then instantiate your model from the config:

from transformers import WhisperConfig, WhisperForConditionalGeneration

configuration = WhisperConfig.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration(configuration)
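
A quick, purely illustrative check (sketch) that this path really starts from random weights rather than the pretrained ones:

# Illustrative sketch: the config-only path yields random weights, unlike from_pretrained.
import torch
from transformers import WhisperConfig, WhisperForConditionalGeneration

config = WhisperConfig.from_pretrained("openai/whisper-tiny")
random_model = WhisperForConditionalGeneration(config)
pretrained_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# the parameters should differ, confirming training would start from scratch
p_random = next(random_model.parameters())
p_pretrained = next(pretrained_model.parameters())
print(torch.allclose(p_random, p_pretrained))  # expected: False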

Hope it helps!

cc @eustlb for visibility

@blademoon
Author

blademoon commented Sep 3, 2024

@ylacombe Your solution works, but another problem arises.

If we instantiate the model as you suggested:

from transformers import WhisperConfig, WhisperForConditionalGeneration

model_name = "openai/whisper-tiny"

configuration = WhisperConfig.from_pretrained(model_name)
model = WhisperForConditionalGeneration(configuration)

then configure the model as in the notebook:

model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None

and start training, the following exception occurs:


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[36], line 1
----> 1 trainer.train()

File [~/.local/lib/python3.10/site-packages/transformers/trainer.py:1929](http://127.0.0.1:8888/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer.py#line=1928), in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1926 try:
   1927     # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
   1928     hf_hub_utils.disable_progress_bars()
-> 1929     return inner_training_loop(
   1930         args=args,
   1931         resume_from_checkpoint=resume_from_checkpoint,
   1932         trial=trial,
   1933         ignore_keys_for_eval=ignore_keys_for_eval,
   1934     )
   1935 finally:
   1936     hf_hub_utils.enable_progress_bars()

File [~/.local/lib/python3.10/site-packages/transformers/trainer.py:2356](http://127.0.0.1:8888/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer.py#line=2355), in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2353     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
   2354     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-> 2356     self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
   2357 else:
   2358     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

File [~/.local/lib/python3.10/site-packages/transformers/trainer.py:2804](http://127.0.0.1:8888/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer.py#line=2803), in Trainer._maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
   2802 metrics = None
   2803 if self.control.should_evaluate:
-> 2804     metrics = self._evaluate(trial, ignore_keys_for_eval)
   2806 if self.control.should_save:
   2807     self._save_checkpoint(model, trial, metrics=metrics)

File [~/.local/lib/python3.10/site-packages/transformers/trainer.py:2761](http://127.0.0.1:8888/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer.py#line=2760), in Trainer._evaluate(self, trial, ignore_keys_for_eval, skip_scheduler)
   2760 def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
-> 2761     metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   2762     self._report_to_hp_search(trial, self.state.global_step, metrics)
   2764     # Run delayed LR scheduler now that metrics are populated

File [~/.local/lib/python3.10/site-packages/transformers/trainer_seq2seq.py:180](http://127.0.0.1:8888/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer_seq2seq.py#line=179), in Seq2SeqTrainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix, **gen_kwargs)
    178 self.gather_function = self.accelerator.gather
    179 self._gen_kwargs = gen_kwargs
--> 180 return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

File [~/.local/lib/python3.10/site-packages/transformers/trainer.py:3666](http://127.0.0.1:8888/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer.py#line=3665), in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   3663 start_time = time.time()
   3665 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 3666 output = eval_loop(
   3667     eval_dataloader,
   3668     description="Evaluation",
   3669     # No point gathering the predictions if there are no metrics, otherwise we defer to
   3670     # self.args.prediction_loss_only
   3671     prediction_loss_only=True if self.compute_metrics is None else None,
   3672     ignore_keys=ignore_keys,
   3673     metric_key_prefix=metric_key_prefix,
   3674 )
   3676 total_batch_size = self.args.eval_batch_size * self.args.world_size
   3677 if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:

File [~/.local/lib/python3.10/site-packages/transformers/trainer.py:3857](http://127.0.0.1:8888/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer.py#line=3856), in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   3854         batch_size = observed_batch_size
   3856 # Prediction step
-> 3857 losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
   3858 main_input_name = getattr(self.model, "main_input_name", "input_ids")
   3859 inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None

File [~/.local/lib/python3.10/site-packages/transformers/trainer_seq2seq.py:310](http://127.0.0.1:8888/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer_seq2seq.py#line=309), in Seq2SeqTrainer.prediction_step(self, model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs)
    302 if (
    303     "labels" in generation_inputs
    304     and "decoder_input_ids" in generation_inputs
    305     and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
    306 ):
    307     generation_inputs = {
    308         k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
    309     }
--> 310 generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
    312 # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
    313 # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
    314 # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
    315 if self.model.generation_config._from_model_config:

File ~/.local/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:542, in WhisperGenerationMixin.generate(self, input_features, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_timestamps, task, language, is_multilingual, prompt_ids, prompt_condition_type, condition_on_prev_tokens, temperature, compression_ratio_threshold, logprob_threshold, no_speech_threshold, num_segment_frames, attention_mask, time_precision, return_token_timestamps, return_segments, return_dict_in_generate, **kwargs)
    536 self._set_prompt_condition_type(
    537     generation_config=generation_config,
    538     prompt_condition_type=prompt_condition_type,
    539 )
    541 # pass self.config for backward compatibility
--> 542 init_tokens = self._retrieve_init_tokens(
    543     input_features,
    544     batch_size=batch_size,
    545     generation_config=generation_config,
    546     config=self.config,
    547     num_segment_frames=num_segment_frames,
    548     kwargs=kwargs,
    549 )
    550 # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
    551 # where the input ids are handled explicitly by the generate method
    552 self._check_decoder_input_ids(kwargs=kwargs)

File [~/.local/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:1357](http://127.0.0.1:8888/home/artyom/.local/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py#line=1356), in WhisperGenerationMixin._retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs)
   1355 if task is not None:
   1356     if task in TASK_IDS:
-> 1357         init_tokens[i].append(generation_config.task_to_id[generation_config.task])
   1358         task_id = generation_config.task_to_id[generation_config.task]
   1360         # if task is defined it'll overwrite task ids that might have already been defined via the generation_config

AttributeError: 'GenerationConfig' object has no attribute 'task_to_id'

@blademoon
Author

@ylacombe If I comment out `model.generation_config.forced_decoder_ids = None`:

from transformers import WhisperConfig, WhisperForConditionalGeneration

configuration = WhisperConfig.from_pretrained(model_name)
model = WhisperForConditionalGeneration(configuration)

model.generation_config.task = "transcribe"
# model.generation_config.forced_decoder_ids = None

I get another exception while training:

You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, 50259], [2, 50359], [3, 50363]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[16], line 1
----> 1 trainer.train()

File [~/.local/lib/python3.10/site-packages/transformers/trainer.py:1929](http://127.0.0.1:8889/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer.py#line=1928), in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1926 try:
   1927     # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
   1928     hf_hub_utils.disable_progress_bars()
-> 1929     return inner_training_loop(
   1930         args=args,
   1931         resume_from_checkpoint=resume_from_checkpoint,
   1932         trial=trial,
   1933         ignore_keys_for_eval=ignore_keys_for_eval,
   1934     )
   1935 finally:
   1936     hf_hub_utils.enable_progress_bars()

File [~/.local/lib/python3.10/site-packages/transformers/trainer.py:2356](http://127.0.0.1:8889/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer.py#line=2355), in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2353     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
   2354     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-> 2356     self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
   2357 else:
   2358     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

File [~/.local/lib/python3.10/site-packages/transformers/trainer.py:2804](http://127.0.0.1:8889/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer.py#line=2803), in Trainer._maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
   2802 metrics = None
   2803 if self.control.should_evaluate:
-> 2804     metrics = self._evaluate(trial, ignore_keys_for_eval)
   2806 if self.control.should_save:
   2807     self._save_checkpoint(model, trial, metrics=metrics)

File [~/.local/lib/python3.10/site-packages/transformers/trainer.py:2761](http://127.0.0.1:8889/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer.py#line=2760), in Trainer._evaluate(self, trial, ignore_keys_for_eval, skip_scheduler)
   2760 def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
-> 2761     metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   2762     self._report_to_hp_search(trial, self.state.global_step, metrics)
   2764     # Run delayed LR scheduler now that metrics are populated

File [~/.local/lib/python3.10/site-packages/transformers/trainer_seq2seq.py:180](http://127.0.0.1:8889/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer_seq2seq.py#line=179), in Seq2SeqTrainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix, **gen_kwargs)
    178 self.gather_function = self.accelerator.gather
    179 self._gen_kwargs = gen_kwargs
--> 180 return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

File [~/.local/lib/python3.10/site-packages/transformers/trainer.py:3666](http://127.0.0.1:8889/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer.py#line=3665), in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   3663 start_time = time.time()
   3665 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 3666 output = eval_loop(
   3667     eval_dataloader,
   3668     description="Evaluation",
   3669     # No point gathering the predictions if there are no metrics, otherwise we defer to
   3670     # self.args.prediction_loss_only
   3671     prediction_loss_only=True if self.compute_metrics is None else None,
   3672     ignore_keys=ignore_keys,
   3673     metric_key_prefix=metric_key_prefix,
   3674 )
   3676 total_batch_size = self.args.eval_batch_size * self.args.world_size
   3677 if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:

File [~/.local/lib/python3.10/site-packages/transformers/trainer.py:3857](http://127.0.0.1:8889/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer.py#line=3856), in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   3854         batch_size = observed_batch_size
   3856 # Prediction step
-> 3857 losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
   3858 main_input_name = getattr(self.model, "main_input_name", "input_ids")
   3859 inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None

File [~/.local/lib/python3.10/site-packages/transformers/trainer_seq2seq.py:310](http://127.0.0.1:8889/home/artyom/.local/lib/python3.10/site-packages/transformers/trainer_seq2seq.py#line=309), in Seq2SeqTrainer.prediction_step(self, model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs)
    302 if (
    303     "labels" in generation_inputs
    304     and "decoder_input_ids" in generation_inputs
    305     and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
    306 ):
    307     generation_inputs = {
    308         k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
    309     }
--> 310 generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
    312 # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
    313 # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
    314 # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
    315 if self.model.generation_config._from_model_config:

File ~/.local/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:542, in WhisperGenerationMixin.generate(self, input_features, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_timestamps, task, language, is_multilingual, prompt_ids, prompt_condition_type, condition_on_prev_tokens, temperature, compression_ratio_threshold, logprob_threshold, no_speech_threshold, num_segment_frames, attention_mask, time_precision, return_token_timestamps, return_segments, return_dict_in_generate, **kwargs)
    536 self._set_prompt_condition_type(
    537     generation_config=generation_config,
    538     prompt_condition_type=prompt_condition_type,
    539 )
    541 # pass self.config for backward compatibility
--> 542 init_tokens = self._retrieve_init_tokens(
    543     input_features,
    544     batch_size=batch_size,
    545     generation_config=generation_config,
    546     config=self.config,
    547     num_segment_frames=num_segment_frames,
    548     kwargs=kwargs,
    549 )
    550 # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
    551 # where the input ids are handled explicitly by the generate method
    552 self._check_decoder_input_ids(kwargs=kwargs)

File [~/.local/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:1357](http://127.0.0.1:8889/home/artyom/.local/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py#line=1356), in WhisperGenerationMixin._retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs)
   1355 if task is not None:
   1356     if task in TASK_IDS:
-> 1357         init_tokens[i].append(generation_config.task_to_id[generation_config.task])
   1358         task_id = generation_config.task_to_id[generation_config.task]
   1360         # if task is defined it'll overwrite task ids that might have already been defined via the generation_config

AttributeError: 'GenerationConfig' object has no attribute 'task_to_id'

@ylacombe
Contributor

ylacombe commented Sep 4, 2024

Hey @blademoon,
It's true that, by default, the generation config of a Transformers model doesn't match the Whisper generation config. I think you might want to start from the pretrained Whisper generation config:

from transformers import WhisperConfig, WhisperForConditionalGeneration
from transformers.generation.configuration_utils import GenerationConfig


model_name = "openai/whisper-tiny"

configuration = WhisperConfig.from_pretrained(model_name)
generation_config = GenerationConfig.from_pretrained(model_name)
model = WhisperForConditionalGeneration(configuration)
model.generation_config = generation_config

You might also want to look at the tiny model's generation config to check that it fits your needs, and modify some parameters in the code above if necessary!
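
As a small follow-up check (sketch), the pretrained generation config carries the Whisper-specific mappings that a default GenerationConfig lacks, which is what caused the task_to_id AttributeError above:

from transformers.generation.configuration_utils import GenerationConfig

# pretrained Whisper generation config: includes Whisper-specific mappings
gen_config = GenerationConfig.from_pretrained("openai/whisper-tiny")
print(gen_config.task_to_id)              # e.g. {"transcribe": 50359, "translate": 50358}
print(hasattr(gen_config, "lang_to_id"))  # expected: True

# default generation config: no such mappings, hence the AttributeError
default_config = GenerationConfig()
print(hasattr(default_config, "task_to_id"))  # expected: False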

@blademoon
Author

blademoon commented Sep 4, 2024

@ylacombe Hello. I'm testing the variant you suggested.

A small question: if the randomly initialized model is trained on two languages at once, will the proposed variant still work correctly?

@ylacombe
Contributor

ylacombe commented Sep 4, 2024

Most probably, but you might want to guide the model a bit more by using language tokens.
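
For example (sketch; assumes the `model`, `processor` and `input_features` objects from the notebook, and the language value is illustrative), the language can also be forced explicitly at generation time:

# Sketch: forcing the language/task at generation time.
predicted_ids = model.generate(input_features, language="russian", task="transcribe")
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)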

@blademoon
Author

@ylacombe I'm using this guide from @sanchit-gandhi.

Now I get the same warning, but training works:

[screenshot: the same `forced_decoder_ids` conflict warning as above]

@blademoon
Author

@ylacombe It may be easier to understand if you can see the big picture.

@ylacombe
Contributor

ylacombe commented Sep 4, 2024

If the training does what you want (i.e. transcribing Russian and English, I guess?), then you can ignore the warning!
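
If the warning is noisy, one option (sketch, mirroring what the blog notebook does with a pretrained checkpoint) is to clear the legacy forced_decoder_ids on the pretrained generation config so that only the task/language settings drive generation:

# Sketch: assumes model.generation_config was loaded from "openai/whisper-tiny" as above.
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None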

@blademoon
Author

@ylacombe OK. I'll check it out and come back with feedback. Thank you.
