
[MMS] Scaling Speech Technology to 1,000+ Languages | Add attention adapter to Wav2Vec2 #23813

Merged
merged 29 commits into main from add_wav2vec2_mms on Jun 2, 2023

Conversation

@patrickvonplaten (Contributor) commented May 27, 2023

What does this PR do?

This PR adds the MMS models fine-tuned on speech recognition.
See official announcement here: https://about.fb.com/news/2023/05/ai-massively-multilingual-speech-technology/
See more details here: https://github.com/facebookresearch/fairseq/blob/main/examples/mms/README.md#asr

Fixes #23811 and #23665

For now, checkpoints are uploaded here:

  • Pretrained-only
  • ASR fine-tuned

The fine-tuned checkpoints are based on adapter layers, as can be seen in this PR. The ASR fine-tuned weights consist of two parts:

  • The non-adapter weights, which are exactly the same as the base model weights
  • Language-specific fine-tuned adapter layer weights. This means we have 1,000+ sets of adapter weights for mms-1b-all

If one wants to use a specific language, the corresponding adapter weights need to be loaded into mms-1b-all.
By default, mms-1b-all et al. load the English adapter layer weights, as is currently done in https://huggingface.co/patrickvonplaten/mms-1b-all

The following works with this PR:

from transformers import Wav2Vec2ForCTC, AutoProcessor
import soundfile as sf
import torch

ckpt = "./mms-1b-fl102/"
ckpt = "./mms-1b-l1107"
ckpt = "./mms-1b-all/"

processor = AutoProcessor.from_pretrained(ckpt)
model = Wav2Vec2ForCTC.from_pretrained(ckpt)

# get audio.flac from https://huggingface.co/datasets/patrickvonplaten/audios/blob/main/audio.flac
audio, sr = sf.read("./audio.flac")

inputs = processor(audio, sampling_rate=sr, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

transcription = processor.batch_decode(logits.argmax(-1))[0]

print(f"Transcription: {transcription}")

Now, the question is what API we want to build to allow the user to easily switch between languages for the fine-tuned weights.

Note:

  • To switch from one language to another, both the tokenizer's vocab and the model's adapter layers need to be switched out
  • The tokenizer can easily hold all language dicts in RAM, because each language has around 150 entries, so ~150,000 entries in total, which is not too much for RAM
  • However, things are a bit trickier for the model. The base model requires 3.1 GB of RAM in FP32 and each set of adapter weights is around 9 MB, so loading all adapter layers into RAM would cost ~9 GB, which is quite a bit (see the quick estimate below).
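For reference, a quick back-of-the-envelope estimate using the numbers above (rough figures only):

```py
# Rough figures from the notes above (not exact numbers)
num_langs = 1000             # 1,000+ languages for mms-1b-all
entries_per_lang = 150       # tokenizer vocab entries per language
adapter_size_mb = 9          # each set of adapter weights is ~9 MB
base_model_gb = 3.1          # FP32 base model

print(f"Tokenizer entries: ~{num_langs * entries_per_lang:,}")                  # ~150,000
print(f"All adapters: ~{num_langs * adapter_size_mb / 1024:.1f} GB")            # ~8.8 GB
print(f"Base model + all adapters: ~{base_model_gb + num_langs * adapter_size_mb / 1024:.1f} GB")
```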

How should we design this model? We need some kind of language-switching function anyway. I see the following APIs that could work.

1.) By default, we download all adapter layers and load them all into RAM, but we provide functionality to remove all languages but one from RAM:

from transformers import Wav2Vec2ForCTC, AutoProcessor

ckpt = "./mms-1b-all/"

processor = AutoProcessor.from_pretrained(ckpt)
model = Wav2Vec2ForCTC.from_pretrained(ckpt)  # requires at least 10GB of CPU RAM

target_lang = "esp"

processor.set_lang("esp")
adapter_id = processor.lang_to_id["esp"]
model.set_adapter_weights(adapter_id)  # throw away all but one set of adapter weights => 3.1 GB of CPU RAM

model.to("cuda")

A problem with this is that it's not trivial to switch between languages, because one needs to load the whole model again and then set the language again. Also, we would have to add a set_adapter_weights function to Wav2Vec2, which is not ideal.
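For illustration, a hedged sketch of what switching languages would look like under this option, reusing the hypothetical set_adapter_weights function from the snippet above:

```py
# Hypothetical sketch only: set_adapter_weights is the proposed (not existing)
# function from the snippet above. Once the other adapters were thrown away,
# switching languages means re-loading the full checkpoint first.
model = Wav2Vec2ForCTC.from_pretrained(ckpt)            # back to ~10 GB of CPU RAM
processor.set_lang("fra")
model.set_adapter_weights(processor.lang_to_id["fra"])  # drop everything but French again
model.to("cuda")
```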

2.) By default, we only load the adapter weights of one language (e.g. English) and then load more adapter layers upon request:

from transformers import Wav2Vec2ForCTC, AutoProcessor

ckpt = "./mms-1b-all/"

processor = AutoProcessor.from_pretrained(ckpt)
model = Wav2Vec2ForCTC.from_pretrained(ckpt)  # requires only 3GB of CPU RAM

target_lang = "esp"

processor.set_lang("esp")
model.load_adapter("esp") # This will load a file called "adapter.esp.bin" from: https://huggingface.co/patrickvonplaten/mms-1b-all , cache it and replace the adapter

model.to("cuda")

I think this is quite user-friendly and intuitive, and this way we also never require more than 3.1 GB of RAM. It does, however, require adding a pretty specific load_adapter function to Wav2Vec2 (I think that's fine though).
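For comparison, switching languages under this option would just be another load_adapter call (sketch based on the API proposed above; exact names may still change):

```py
# Sketch of switching languages under the proposed option 2 API (names as
# proposed above, subject to change): each call only downloads/caches a ~9 MB
# adapter file and swaps it in place, so CPU RAM stays around the 3.1 GB base model.
processor.set_lang("fra")
model.load_adapter("fra")
model.to("cuda")
```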

3.) We just upload 1,000+ repos, one for each language. This way we don't need any "set" or "load" function and we just treat each set of adapter weights as its own model:

from transformers import Wav2Vec2ForCTC, AutoProcessor

ckpt = "./mms-1b-all-esp/" # repo names then become lang specific

processor = AutoProcessor.from_pretrained(ckpt)
model = Wav2Vec2ForCTC.from_pretrained(ckpt)  # requires only 3GB of CPU RAM
model.to("cuda")

The big disadvantage is that it's pretty wasteful, since an adapter layer is just 0.3% of all the model's weights.

=> Overall, I'm leaning towards API 2.) because it's the most user-friendly and intuitive. It'd just require adding a somewhat specific "load_adapter" function to Wav2Vec2, but I think that's totally fine.

Thoughts @sanchit-gandhi @Vaibhavs10 @sgugger @LysandreJik @amyeroberts ?

@patrickvonplaten patrickvonplaten changed the title add fine-tuned with adapter layer [RFC] Add fine-tuned with adapter layer May 27, 2023
@HuggingFaceDocBuilderDev commented May 27, 2023

The documentation is not available anymore as the PR was closed or merged.

@Vaibhavs10 (Member)

Hey @patrickvonplaten - thanks for working on this; I reviewed the options provided. I believe the second one would work best from a developer standpoint. IMO it ensures that all the adapter weights are in one repository, and everything works the way it should if someone wants to use a different language with the base model.

I am not a big fan of option 1 because it would make it difficult to run the model in a resource-constrained environment.

I am a bit conflicted about option 3, primarily because it gives the end-user the same experience as regular Wav2Vec2 without them having to worry about the specific language adapter layers and so on. However, having 1,000+ repos for this sounds a bit wasteful IMO.

Question: how would this work for fine-tuning? I am assuming that if someone fine-tunes Wav2Vec2-MMS on a language "X", they'll push their adapter weights to a new repo and pull from there. Purely from a UX perspective, that would mean we should allow the load_adapter function to pull from a separate repository too, right?
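For example, something like the following (purely hypothetical; the adapter_repo keyword and the community repo name below are made up just to make the question concrete):

```py
from transformers import Wav2Vec2ForCTC

# Purely hypothetical: the `adapter_repo` keyword and the community repo name
# are made up to illustrate the question, they are not part of the proposed API.
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/mms-1b-all")
model.load_adapter("xyz", adapter_repo="my-user/mms-1b-all-xyz-adapter")
```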

@sgugger (Collaborator) commented May 30, 2023

I think 2 is probably the better solution, and I would also make it possible to set the lang in the from_pretrained call:

from transformers import Wav2Vec2ForCTC, AutoProcessor

ckpt = "./mms-1b-all/"

processor = AutoProcessor.from_pretrained(ckpt)
model = Wav2Vec2ForCTC.from_pretrained(ckpt, target_lang="esp")

processor.set_lang("esp")

model.to("cuda")

# Later, if we want to change the language:
model.load_adapter("fra") 

@sanchit-gandhi (Contributor)

+1 on the composite solution proposed by @sgugger. Regarding fine-tuning @Vaibhavs10: users will save both the fine-tuned base weights and the adapter layer weights to the same repo. This is different from PEFT, where we only save the adapter weights, since here the base weights are also trainable. The way to view the adapter layer is as an extra small feed-forward network on top of the transformer block, i.e. a regular layer of weights rather than a parameter-efficient one. So we can probably assume we're loading the base weights and adapter weights from the same repo.
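To make the "extra small feed-forward network" picture concrete, here is a minimal, schematic bottleneck-adapter block (an illustration of the general idea, not the actual implementation in this PR):

```py
import torch
import torch.nn as nn


class BottleneckAdapter(nn.Module):
    """Schematic adapter: layer norm, down-projection, non-linearity,
    up-projection, added residually on top of a transformer block's output."""

    def __init__(self, hidden_size: int, adapter_dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.down = nn.Linear(hidden_size, adapter_dim)
        self.act = nn.ReLU()
        self.up = nn.Linear(adapter_dim, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = self.up(self.act(self.down(hidden_states)))
        # residual connection: the block behaves like the base model when the
        # adapter output is (near-)zero
        return residual + hidden_states
```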

@amyeroberts (Collaborator)

Agreed with all of the above - 2 would be my choice:

  • 1 doesn't feel very user-friendly. I'd expect most people to only use a consistent subset, so downloading everything is slow and wasteful.
  • 2 feels the most intuitive with the current API, and flexible. Seconding @Vaibhavs10's questions about fine-tuning, pushing to the hub, and loading fine-tuned weights: if we load model weights from mms-1b-fl102 and want our own fine-tuned adapter weights, how do I specify that when loading, and how is this information saved? How would we differentiate the weights such that, when I call model.push_to_hub, the adapter weights are uploaded separately from the rest of the model (pattern matching? see the rough sketch after this list)? Should the adapter weights be tied to a specific version of the 'base model' weights?
  • 3 is probably the simplest to do - but seems like a waste with many repeated weights.
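One possible way to differentiate them would be simple name-based pattern matching on the state dict (rough sketch; the "adapter_layer" substring is an assumption, not necessarily the real parameter naming):

```py
# Rough sketch: split a checkpoint's state dict into adapter vs. base weights
# by matching on parameter names. "adapter_layer" is an assumed substring,
# not necessarily the real naming used in the modeling code.
state_dict = model.state_dict()
adapter_weights = {k: v for k, v in state_dict.items() if "adapter_layer" in k}
base_weights = {k: v for k, v in state_dict.items() if "adapter_layer" not in k}

# The adapter weights could then be saved/pushed separately, e.g. as an
# "adapter.<lang>.bin" file, while the base weights go into the usual model file.
```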

@patrickvonplaten (Contributor, Author)

I'll leave more detailed functionality for fine-tuning adapter weights to a future PR, but in short we can already do the following:

from transformers import Wav2Vec2ForCTC

ckpt = "patrickvonplaten/mms-1b"
model = Wav2Vec2ForCTC.from_pretrained(ckpt, num_attn_adapters=1, vocab_size=277)

# freeze everything except the adapter weights
adapter_keys = set(model._adapters.keys())
for name, param in model.named_parameters():
    if name not in adapter_keys:
        param.requires_grad = False

So once we add adapter fine-tuning to the wav2vec2 fine-tuning script, we could also add a simple "freeze_all_but_adapter()" function or something.
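Such a helper could essentially just wrap the loop above (hypothetical name, as suggested; it relies on the same model._adapters mapping used in the snippet):

```py
def freeze_all_but_adapter(model):
    """Hypothetical helper: freeze every parameter except the adapter weights."""
    adapter_keys = set(model._adapters.keys())
    for name, param in model.named_parameters():
        # only the adapter parameters stay trainable
        param.requires_grad = name in adapter_keys
```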

@patrickvonplaten patrickvonplaten changed the title [RFC] Add fine-tuned with adapter layer Add fine-tuned with adapter layer May 31, 2023
@patrickvonplaten patrickvonplaten changed the title Add fine-tuned with adapter layer [MMS] Scaling Speech Technology to 1,000+ Languages | Add attention adapter to Wav2Vec2 May 31, 2023
@patrickvonplaten (Contributor, Author) commented May 31, 2023

The code is now finished. I still need to upload the adapters for the smaller checkpoints, transfer them to Facebook and write some nice docs.

All modeling files except Wav2Vec2's are changed due to the # Copied from mechanism. I think this is better than removing the copy mechanism, but happy to change.

@amyeroberts (Collaborator) left a comment


Really nice PR and new model abilities. Thanks for adding! Tests especially were clear and helped show the expected behaviour 🤗

Mostly just nits - main comment is some missing asserts in the tests.

I'm not so sure about the added logic and layers in the models copying from Wav2Vec2. It's the classic inheritance problem, but it's a bit counter-intuitive to have calls to a method -- load_adapter -- which the model doesn't have in its modeling code. Not a big issue -- if we find it confuses users, then we can handle it, e.g. with a def load_adapter that raises NotImplementedError. Noting this just as a limitation of adapting models with copying logic.
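i.e. something along these lines (sketch only):

```py
# Possible stub in the copied models' modeling code to make the limitation explicit
def load_adapter(self, target_lang: str, **kwargs):
    raise NotImplementedError(
        f"{self.__class__.__name__} does not support loading language adapters."
    )
```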

@patrickvonplaten patrickvonplaten merged commit 5dfd407 into main Jun 2, 2023
@patrickvonplaten patrickvonplaten deleted the add_wav2vec2_mms branch June 2, 2023 09:30
gojiteji pushed a commit to gojiteji/transformers that referenced this pull request Jun 5, 2023
[MMS] Scaling Speech Technology to 1,000+ Languages | Add attention adapter to Wav2Vec2 (huggingface#23813)

* add fine-tuned with adapter layer

* Add set_target_lang to tokenizer

* Implement load adapter

* add tests

* make style

* Apply suggestions from code review

* Update src/transformers/models/wav2vec2/tokenization_wav2vec2.py

* make fix-copies

* Apply suggestions from code review

* make fix-copies

* make style again

* mkae style again

* fix doc string

* Update tests/models/wav2vec2/test_tokenization_wav2vec2.py

* Apply suggestions from code review

* fix

* Correct wav2vec2 adapter

* mkae style

* Update src/transformers/models/wav2vec2/modeling_wav2vec2.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* add more nice docs

* finish

* finish

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Apply suggestions from code review

* all finish

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@ydshieh ydshieh mentioned this pull request Jun 12, 2023
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
[MMS] Scaling Speech Technology to 1,000+ Languages | Add attention adapter to Wav2Vec2 (huggingface#23813)
