Imports

In [None]:
!pip install flash-attn transformers galore-torch --no-build-isolation --upgrade

In [None]:
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    Wav2Vec2BertForCTC,
    Wav2Vec2BertProcessor,
    Wav2Vec2CTCTokenizer,
    SeamlessM4TFeatureExtractor,
    AutoProcessor,
    Trainer,
    Seq2SeqTrainingArguments,
    AutoConfig,
)
import json
import os
import random
import torch
from dataclasses import dataclass
from typing import Dict, List, Union
from accelerate import Accelerator
import warnings
from tqdm.auto import tqdm
import huggingface_hub
import datasets

In [None]:
# --- Model ---------------------------------------------------------------
BASE_MODEL = "facebook/w2v-bert-2.0"  # checkpoint that gets fine-tuned
OUT_MODEL = "w2v-bert-german-canary"  # used as the training output directory

# --- Data ----------------------------------------------------------------
DATASET_NAME = "flozi00/german-canary-asr-0324"
DATASET_SUBSET = "default"
AUDIO_PATH = "audio"  # name of the dataset column holding the audio

# --- Optimisation --------------------------------------------------------
BATCH_SIZE = 2  # per-device train batch size
EPOCHS = 10
LR = 1e-5
SAVE_STEPS = 1_000  # checkpoint interval in optimizer steps

Data Class for Collator

In [None]:
@dataclass
class ASRDataCollator:
    """Turn a list of raw ASR samples into one padded training batch.

    Each sample is expected to provide ``sample[wav_key]["array"]`` (the
    waveform) and ``sample[text_key]`` (the transcription). Samples that lack
    one of these entries, or whose audio is longer than
    ``max_audio_in_seconds``, are reported on stdout and left out of the batch.

    Attributes:
        processor: Object exposing ``feature_extractor`` and ``tokenizer``.
        wav_key: Key of the audio entry (default from env ``AUDIO_PATH``).
        locale_key: Language handed to ``tokenizer.set_prefix_tokens`` when
            the tokenizer offers that method (default from env ``LOCALE_KEY``).
        text_key: Key of the transcription (default from env ``TEXT_KEY``).
        max_audio_in_seconds: Longest clip that is still kept
            (default from env ``MAX_AUDIO_IN_SECONDS``).
        sampling_rate: Sampling rate of the incoming audio in Hz; used both
            for the duration check and for the feature extractor.
    """

    # Written as a forward reference: the name is only a type hint here, so the
    # class definition itself does not need it to be resolvable.
    processor: "AutoProcessor"
    wav_key: str = os.getenv("AUDIO_PATH", "audio")
    locale_key: str = os.getenv("LOCALE_KEY", "german")
    text_key: str = os.getenv("TEXT_KEY", "transkription")
    max_audio_in_seconds: float = float(os.getenv("MAX_AUDIO_IN_SECONDS", 20.0))
    sampling_rate: int = 16000

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Build the model inputs and ``labels`` for one batch.

        Args:
            features: Raw dataset samples (see the class docstring).

        Returns:
            The padded output of ``feature_extractor.pad`` with an additional
            ``"labels"`` entry in which padded positions are set to ``-100``.

        Raises:
            ValueError: If no sample of the batch survives the filtering.
        """
        feature_extractor = self.processor.feature_extractor
        tokenizer = self.processor.tokenizer

        # Best effort and identical for every sample, so it is done once per
        # batch: not every tokenizer provides set_prefix_tokens.
        try:
            tokenizer.set_prefix_tokens(self.locale_key)
        except Exception:
            pass

        input_features = []
        label_features = []

        for feature in features:
            try:
                myaudio = feature[self.wav_key]["array"]
                mytext = feature[self.text_key]
            except Exception as e:
                # Malformed sample: report it and keep the rest of the batch.
                print(e)
                continue

            # Compare the exact duration; truncating to whole seconds would let
            # clips up to one second over the limit slip through.
            audio_seconds = len(myaudio) / self.sampling_rate
            if audio_seconds > self.max_audio_in_seconds:
                print("skipping audio")
                continue

            extracted = feature_extractor(
                myaudio,
                sampling_rate=self.sampling_rate,
                return_tensors="pt",
            )

            # Feature extractors name their output either input_values or
            # input_features; keep whichever this one produced.
            ft = (
                "input_values"
                if hasattr(extracted, "input_values")
                else "input_features"
            )
            input_features.append({ft: getattr(extracted, ft)[0]})

            label_features.append({"input_ids": tokenizer(mytext).input_ids})

        if not input_features:
            raise ValueError(
                "ASRDataCollator: every sample of the batch was skipped "
                "(missing keys or audio longer than max_audio_in_seconds)"
            )

        batch = feature_extractor.pad(
            input_features,
            padding="longest",
            return_tensors="pt",
        )

        labels_batch = tokenizer.pad(
            label_features,
            padding="longest",
            return_tensors="pt",
        )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        batch["labels"] = labels

        return batch

Preprocessing: vocabulary, processor and model loading

In [None]:
def extract_all_chars(sentences):
    """Collect the distinct characters that occur in ``sentences``.

    The sentences are concatenated with single spaces. The result holds the
    list of unique characters under ``"vocab"`` and the concatenated text
    under ``"all_text"``, each wrapped in a one-element list.
    """
    joined = " ".join(sentences)
    unique_chars = [*set(joined)]
    return dict(vocab=[unique_chars], all_text=[joined])


def _build_ctc_vocab(chars):
    """Assign contiguous ids to ``chars`` plus the CTC special tokens.

    The space character is stored as the word delimiter ``"|"`` and is always
    part of the vocabulary, even when ``chars`` contains no space. A literal
    ``"|"`` in ``chars`` is folded into the delimiter, so ids stay contiguous
    and the ``[UNK]``/``[PAD]`` ids appended at the end cannot collide with a
    character id.
    """
    symbols = (set(chars) - {"|"}) | {" "}
    vocab = {
        ("|" if symbol == " " else symbol): index
        for index, symbol in enumerate(sorted(symbols))
    }
    for special in ("[UNK]", "[PAD]"):
        vocab[special] = len(vocab)
    return vocab


def make_ctc_processor(cv_data):
    """Create a character-level CTC processor for ``cv_data``.

    Reads the transcription column (env ``TEXT_KEY``, default
    ``"transkription"``), derives the character vocabulary, writes it to
    ``vocab.json`` in the current working directory and loads a
    ``Wav2Vec2CTCTokenizer`` from that directory. The feature extractor is
    taken from ``BASE_MODEL``.

    Args:
        cv_data: Dataset that can be indexed by the transcription column name.

    Returns:
        A ``Wav2Vec2BertProcessor`` combining feature extractor and tokenizer.
    """
    text_key = os.getenv("TEXT_KEY", "transkription")
    chars = extract_all_chars(cv_data[text_key])["vocab"][0]
    vocab_dict = _build_ctc_vocab(chars)

    with open("vocab.json", "w", encoding="utf-8") as vocab_file:
        json.dump(vocab_dict, vocab_file)

    # The tokenizer is loaded from the working directory, i.e. from the
    # vocab.json written just above.
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
        "./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|"
    )
    feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained(BASE_MODEL)
    processor = Wav2Vec2BertProcessor(
        feature_extractor=feature_extractor, tokenizer=tokenizer
    )

    return processor

def get_model(
    model_name: str,
    processor_name: str = None,
    vocab_size=None,
    processor=None,
):
    """Load the model to fine-tune together with its processor.

    Model names containing ``"wav2vec"`` or ``"w2v"`` are treated as CTC
    models: they are loaded as ``Wav2Vec2BertForCTC`` with the regularisation
    settings below. Every other name is loaded through
    ``AutoModelForSpeechSeq2Seq`` with ``attn_implementation="sdpa"``.

    Args:
        model_name: Checkpoint name or path of the base model.
        processor_name: Optional checkpoint to load the processor from;
            defaults to ``model_name``. Ignored when ``processor`` is given.
        vocab_size: If set, overrides ``config.vocab_size`` and allows weights
            whose shape no longer matches to be re-initialised.
        processor: Ready-made processor to return instead of loading one.

    Returns:
        Tuple ``(model, processor)``. The model is moved to the GPU when CUDA
        is available and stays on the CPU otherwise.
    """
    # get the config of the base model and extract the model type from it
    conf = AutoConfig.from_pretrained(
        pretrained_model_name_or_path=model_name,
        trust_remote_code=True,
    )

    ctc_model = any(key in model_name for key in ("wav2vec", "w2v"))
    if ctc_model:
        conf.attention_dropout = 0.01
        conf.hidden_dropout = 0.01
        conf.feat_proj_dropout = 0.01
        conf.mask_time_prob = 0.05
        conf.layerdrop = 0.01
        conf.ctc_loss_reduction = "mean"
        conf.add_adapter = False
        conf.num_adapter_layers = 1

    model_class = Wav2Vec2BertForCTC if ctc_model else AutoModelForSpeechSeq2Seq

    if processor is None:
        processor = AutoProcessor.from_pretrained(
            model_name if processor_name is None else processor_name,
            legacy=False,
            trust_remote_code=True,
        )

    kwargs = {}
    if not ctc_model:
        # Only the seq2seq path requests an explicit attention implementation.
        kwargs["attn_implementation"] = "sdpa"

    if vocab_size is not None:
        # A new vocabulary changes the output layer shape, so the checkpoint
        # weights for it cannot be reused.
        conf.vocab_size = vocab_size
        kwargs["ignore_mismatched_sizes"] = True

    model = model_class.from_pretrained(
        model_name,
        config=conf,
        low_cpu_mem_usage=True,
        **kwargs,
    )

    # An unconditional .cuda() raises on machines without a GPU.
    if torch.cuda.is_available():
        model = model.cuda()

    return model, processor

In [None]:
# Load the full training split (not streamed).
cv_data = datasets.load_dataset(
    DATASET_NAME, DATASET_SUBSET, split="train", streaming=False
)
# Have the audio column decoded at a sampling rate of 16 kHz.
cv_data = cv_data.cast_column(
    AUDIO_PATH, datasets.Audio(sampling_rate=16000, decode=True)
)
# Return samples in the "torch" format.
cv_data = cv_data.with_format("torch")

Main: build the processor and model, then configure and run training

In [None]:
# CTC checkpoints (w2v / wav2vec) get a processor built from the dataset's own
# characters; other models load their processor inside get_model.
ctc_processor = (
    make_ctc_processor(cv_data)
    if ("w2v" in BASE_MODEL or "wav2vec" in BASE_MODEL)
    else None
)
vocab_size = None if ctc_processor is None else len(ctc_processor.tokenizer)

model, processor = get_model(
    model_name=BASE_MODEL,
    processor=ctc_processor,
    vocab_size=vocab_size,
)

# Clear the decoder-related config entries; any failure while doing so is
# deliberately ignored.
try:
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []
except Exception:
    pass

# Shuffle with a seed drawn at run time, so the order differs between runs.
shuffle_kwargs = {"seed": random.randint(0, 1000)}
cv_data = cv_data.shuffle(**shuffle_kwargs)
del shuffle_kwargs

In [None]:
training_args = Seq2SeqTrainingArguments(
    # where to write / publish
    output_dir=OUT_MODEL,
    push_to_hub=True,
    hub_private_repo=True,
    # schedule
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    learning_rate=LR,
    # optimizer
    optim="galore_adamw_layerwise",
    # NOTE(review): these are matched against module names of the model;
    # confirm that both "attn" and "mlp" actually occur in BASE_MODEL.
    optim_target_modules=["attn", "mlp"],
    # memory / data loading
    gradient_checkpointing=True,
    dataloader_pin_memory=False,
    # the collator reads the raw audio and text columns, so keep them
    remove_unused_columns=False,
    # logging and checkpoints
    logging_steps=100,
    save_strategy="steps",
    save_steps=SAVE_STEPS,
)

In [None]:
# Training only: no evaluation set and no metrics are configured.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=ASRDataCollator(processor=processor),
    train_dataset=cv_data,
    eval_dataset=None,
    compute_metrics=None,
    # Only the feature extractor is handed over here, not the tokenizer.
    tokenizer=processor.feature_extractor,
)

In [None]:
# Shell escape: print the current GPU status before training starts.
!nvidia-smi

In [None]:
train_result = trainer.train()

# The Trainer was only given the feature extractor, so the tokenizer built for
# this run would otherwise never reach the output directory. Save the complete
# processor there first so it is stored next to the model.
processor.save_pretrained(trainer.args.output_dir)
trainer.save_model()  # Saves the feature extractor too for easy upload