Imports

In [None]:
from transformers import AutoConfig
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    Wav2Vec2BertForCTC,
    Wav2Vec2BertProcessor,
    Wav2Vec2CTCTokenizer,
    SeamlessM4TFeatureExtractor,
    AutoProcessor,
)
import json
import os
import random
from torch.utils.data import DataLoader
import torch
from dataclasses import dataclass
from typing import Dict, List, Union
from accelerate import Accelerator
import warnings
from tqdm.auto import tqdm
from torch.optim.lr_scheduler import ExponentialLR
from bitsandbytes.optim import PagedAdamW8bit
import datasets

In [None]:
# --- Optimisation settings -------------------------------------------------
BATCH_SIZE = 2  # rows per DataLoader batch (the collator may skip some)
ACCUMULATION_STEPS = 4  # micro-batches accumulated per optimizer step
EPOCHS = 10
SAVE_STEPS = 1000  # save + push every this many optimizer steps
LR = 1e-4  # initial learning rate

# --- Model and dataset identifiers -----------------------------------------
BASE_MODEL = "flozi00/distilwhisper-german-canary"  # checkpoint to fine-tune
OUT_MODEL = "distilwhisper-german-canary"  # output directory and Hub repo name
DATASET_NAME = "flozi00/german-canary-asr-0324"
DATASET_SUBSET = "default"
AUDIO_PATH = "audio"  # dataset column holding the audio

Data collator

In [None]:
@dataclass
class ASRDataCollator:
    """Turn a list of dataset rows into one padded training batch.

    Each row must provide the decoded audio under ``wav_key`` (a mapping with an
    ``"array"`` entry) and the transcript under ``text_key``. Audio arrays are
    assumed to be sampled at ``sampling_rate`` Hz. The defaults of the ``*_key``
    fields and of ``max_audio_in_seconds`` are read from environment variables
    once, when the class is defined.
    """

    # Processor exposing ``feature_extractor`` and ``tokenizer``.
    processor: AutoProcessor
    # Dataset column holding the decoded audio.
    wav_key: str = os.getenv("AUDIO_PATH", "audio")
    # Language passed to ``tokenizer.set_prefix_tokens`` when the tokenizer has it.
    locale_key: str = os.getenv("LOCALE_KEY", "german")
    # Dataset column holding the transcript.
    text_key: str = os.getenv("TEXT_KEY", "transkription")
    # Clips longer than this many seconds are skipped.
    max_audio_in_seconds: float = float(os.getenv("MAX_AUDIO_IN_SECONDS", 20.0))
    # Sampling rate of the audio arrays (the dataset is expected to be cast to it).
    sampling_rate: int = 16000

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Build one batch from ``features``.

        Rows missing the audio/text keys (the error is printed) and rows whose
        audio is longer than ``max_audio_in_seconds`` are skipped.

        Returns:
            The feature extractor's padded batch plus a ``"labels"`` tensor in
            which padding positions are set to -100.
        """
        tokenizer = self.processor.tokenizer
        feature_extractor = self.processor.feature_extractor

        # The language is identical for every row, so configure it once per
        # batch. Tokenizers without ``set_prefix_tokens`` (e.g. the CTC one) are
        # left unchanged; a failure keeps the tokenizer's current prefix.
        set_prefix_tokens = getattr(tokenizer, "set_prefix_tokens", None)
        if set_prefix_tokens is not None:
            try:
                set_prefix_tokens(self.locale_key)
            except Exception as err:
                warnings.warn(
                    f"Could not set prefix tokens for {self.locale_key!r}: {err}"
                )

        input_features = []
        label_features = []
        for feature in features:
            try:
                audio = feature[self.wav_key]["array"]
                text = feature[self.text_key]
            except Exception as e:
                # Malformed row: report it and continue with the rest of the batch.
                print(e)
                continue

            # Compare the exact duration; truncating to whole seconds would let
            # clips up to one second over the limit through.
            if len(audio) / self.sampling_rate > self.max_audio_in_seconds:
                print("skipping audio")
                continue

            extracted = feature_extractor(
                audio,
                sampling_rate=self.sampling_rate,
                return_tensors="pt",
            )
            # Extractors name their output either "input_values" or "input_features".
            ft = (
                "input_values"
                if hasattr(extracted, "input_values")
                else "input_features"
            )
            # [0] drops the batch dimension of this single-clip extraction.
            input_features.append({ft: getattr(extracted, ft)[0]})
            label_features.append({"input_ids": tokenizer(text).input_ids})

        # NOTE(review): if every row was skipped, both pad calls receive empty
        # lists; confirm how the extractor/tokenizer handle that.
        batch = feature_extractor.pad(
            input_features,
            padding="longest",
            return_tensors="pt",
        )
        labels_batch = tokenizer.pad(
            label_features,
            padding="longest",
            return_tensors="pt",
        )

        # Padding positions become -100 so the loss ignores them.
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        batch["labels"] = self._strip_start_of_transcript(tokenizer, labels)
        return batch

    @staticmethod
    def _strip_start_of_transcript(tokenizer, labels):
        """Drop column 0 of ``labels`` when every row starts with Whisper's SOT token.

        Whisper tokenizers prepend ``<|startoftranscript|>`` to each label
        sequence, and Whisper models also insert the decoder start token when
        shifting labels right, so without this the token would be fed twice.
        Tokenizers without ``prefix_tokens`` are left untouched.
        """
        prefix_tokens = getattr(tokenizer, "prefix_tokens", None)
        if not prefix_tokens or labels.shape[1] == 0:
            return labels
        sot_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if prefix_tokens[0] != sot_id:
            return labels
        if bool((labels[:, 0] == sot_id).all()):
            return labels[:, 1:]
        return labels

Preprocessing

In [None]:
def extract_all_chars(sentences):
    """Join ``sentences`` with single spaces and collect their distinct characters.

    Returns:
        ``{"vocab": [chars], "all_text": [text]}`` where ``text`` is the joined
        string and ``chars`` is an unordered list of its unique characters.
    """
    joined = " ".join(sentences)
    distinct = [*{*joined}]
    return dict(vocab=[distinct], all_text=[joined])


def make_ctc_processor(cv_data):
    """Build a character-level CTC processor from the transcripts in ``cv_data``.

    The vocabulary is built as follows:

    - It holds every distinct character of the transcript column, named by
      the ``TEXT_KEY`` environment variable (default ``"transkription"``).
    - Characters are numbered in sorted order.
    - The space is replaced by the word delimiter ``"|"``.
    - ``[UNK]`` and ``[PAD]`` are appended last.

    The vocabulary is written to ``./vocab.json`` and the tokenizer is loaded
    back from the current directory.

    Args:
        cv_data: dataset indexable by column name; the transcript column must
            yield strings.

    Returns:
        A ``Wav2Vec2BertProcessor`` combining that tokenizer with the
        ``SeamlessM4TFeatureExtractor`` of ``BASE_MODEL``.
    """
    text_key = os.getenv("TEXT_KEY", "transkription")
    # extract_all_chars already de-duplicates the characters.
    vocab_list = extract_all_chars(cv_data[text_key])["vocab"][0]

    vocab_dict = {char: idx for idx, char in enumerate(sorted(vocab_list))}
    if " " in vocab_dict:
        # The CTC tokenizer writes word boundaries as "|": give it the space's id.
        vocab_dict["|"] = vocab_dict.pop(" ")
    else:
        # No space in the corpus (e.g. a single one-word transcript): still
        # provide a delimiter id instead of failing with KeyError.
        vocab_dict.setdefault("|", len(vocab_dict))

    # Special tokens go after the highest id in use so they can never collide
    # with a character id (ids have a gap if "|" itself occurs in the corpus).
    next_id = max(vocab_dict.values(), default=-1) + 1
    vocab_dict["[UNK]"] = next_id
    vocab_dict["[PAD]"] = next_id + 1
    with open("vocab.json", "w", encoding="utf-8") as vocab_file:
        json.dump(vocab_dict, vocab_file)

    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
        "./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|"
    )
    feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained(BASE_MODEL)
    return Wav2Vec2BertProcessor(
        feature_extractor=feature_extractor, tokenizer=tokenizer
    )

def get_model(
    model_name: str,
    processor_name: str = None,
    vocab_size=None,
    processor=None,
):
    """Load a speech model and its processor for fine-tuning.

    How the checkpoint is loaded depends on its name:

    - Names containing ``"wav2vec"`` or ``"w2v"`` are loaded as
      ``Wav2Vec2BertForCTC`` with reduced dropout/masking settings.
    - Everything else is loaded with ``AutoModelForSpeechSeq2Seq`` and
      ``attn_implementation="sdpa"``.

    Args:
        model_name: Hub id or local path of the checkpoint.
        processor_name: checkpoint to load the processor from when
            ``processor`` is not given; defaults to ``model_name``.
        vocab_size: if given, overrides the config's vocabulary size and loads
            with ``ignore_mismatched_sizes=True``.
        processor: an already-built processor (e.g. from ``make_ctc_processor``).

    Returns:
        ``(model, processor)``, with the model switched to training mode.
    """
    kwargs = {}
    # Read the base config so it can be adjusted before the weights are loaded.
    conf = AutoConfig.from_pretrained(
        pretrained_model_name_or_path=model_name,
        trust_remote_code=True,
    )

    ctc_model = any(key in model_name for key in ("wav2vec", "w2v"))
    if ctc_model:
        conf.attention_dropout = 0.01
        conf.hidden_dropout = 0.01
        conf.feat_proj_dropout = 0.01
        conf.mask_time_prob = 0.05
        conf.layerdrop = 0.01
        conf.ctc_loss_reduction = "mean"
        conf.add_adapter = False
        conf.num_adapter_layers = 1
        model_class = Wav2Vec2BertForCTC
    else:
        # Only the seq2seq path requests PyTorch scaled-dot-product attention.
        kwargs["attn_implementation"] = "sdpa"
        model_class = AutoModelForSpeechSeq2Seq

    if processor is None:
        processor = AutoProcessor.from_pretrained(
            model_name if processor_name is None else processor_name,
            legacy=False,
            trust_remote_code=True,
        )

    if vocab_size is not None:
        conf.vocab_size = vocab_size
        # Lets layers whose shape depends on the vocabulary be re-initialised
        # instead of failing to load.
        kwargs["ignore_mismatched_sizes"] = True

    if ctc_model:
        # The CTC loss uses config.pad_token_id as its blank label. The base
        # checkpoint's pad id need not match this processor's [PAD] token (a
        # freshly built CTC vocab puts [PAD] last), so align them; otherwise a
        # real character would be treated as blank.
        pad_token_id = getattr(
            getattr(processor, "tokenizer", None), "pad_token_id", None
        )
        if pad_token_id is not None:
            conf.pad_token_id = pad_token_id

    model = model_class.from_pretrained(
        model_name,
        config=conf,
        **kwargs,
    )
    model = model.train()
    return model, processor

Training loop

In [None]:
def get_lr(optimizer):
    """Return the learning rate of ``optimizer``'s first parameter group.

    Returns ``None`` if the optimizer has no parameter groups.
    """
    return next((group["lr"] for group in optimizer.param_groups), None)


def start_training(
    model,
    processor,
    dloader,
    OUT_MODEL,
    callback=None,
):
    """Fine-tune ``model`` on ``dloader`` with Accelerate, saving periodically.

    Training setup:

    - fp16 mixed precision.
    - Gradient accumulation over ``ACCUMULATION_STEPS`` micro-batches.
    - ``PagedAdamW8bit`` at learning rate ``LR`` with an ``ExponentialLR``
      schedule (gamma 0.9995).
    - Gradient value clipping at 0.9.

    Metrics are logged through the ``"wandb"`` tracker. Every ``SAVE_STEPS``
    optimizer steps and at the end of each epoch, the model and processor are
    saved to ``OUT_MODEL`` and pushed to the Hub under that name. A failed
    push only emits a warning.

    Args:
        model: model whose forward accepts the collated batch and returns an
            output with ``.loss``.
        processor: processor saved and pushed alongside the model.
        dloader: DataLoader yielding dicts of tensors.
        OUT_MODEL: output directory and Hub repository name.
        callback: optional zero-argument callable returning an evaluation
            metric (or ``None``); run before training and at every save.
    """
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=ACCUMULATION_STEPS,
        mixed_precision="fp16",
    )
    accelerator.init_trackers("huggingface")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Number of parameters: {total_params}")
    total_trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    print(f"Number of trainable parameters: {total_trainable_params}")

    optim = PagedAdamW8bit(model.parameters(), lr=LR)
    scheduler = ExponentialLR(optim, gamma=0.9995)
    model, optim, dloader, scheduler = accelerator.prepare(
        model, optim, dloader, scheduler
    )

    def log_eval(step):
        """Run ``callback`` (if any) and log its metric at ``step``."""
        if callback is None:
            return
        eval_ = callback()
        if eval_ is not None:
            accelerator.log({"eval_metric": eval_}, step=step)

    def do_save_stuff(step):
        """Evaluate, save locally and push to the Hub; ``step`` is in optimizer steps."""
        # Same step axis as the training metrics: wandb ignores data logged at
        # a step lower than one it has already seen.
        log_eval(step)

        accelerator.wait_for_everyone()
        # Called on every process, as the Accelerate saving guide does.
        state_dict = accelerator.get_state_dict(model)
        unwrapped = accelerator.unwrap_model(model)
        unwrapped.save_pretrained(
            OUT_MODEL,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            state_dict=state_dict,
            safe_serialization=False,
        )
        if not accelerator.is_main_process:
            return
        processor.save_pretrained(OUT_MODEL)

        try:
            unwrapped.push_to_hub(
                OUT_MODEL,
                save_function=accelerator.save,
                state_dict=state_dict,
                safe_serialization=False,
            )
            processor.push_to_hub(OUT_MODEL)
        except Exception as e:
            warnings.warn(f"Could not push to hub: {e}")

    log_eval(0)

    # Counts micro-batches; index // ACCUMULATION_STEPS approximates the
    # optimizer step and is used as the logging step.
    index = 1
    for _ in range(EPOCHS):
        for data in (pbar := tqdm(dloader)):
            if index % (ACCUMULATION_STEPS * SAVE_STEPS) == 0:
                do_save_stuff(index // ACCUMULATION_STEPS)

            with accelerator.accumulate(model), accelerator.autocast():
                # accelerator.device works for wrapped (e.g. DDP) models, which
                # have no ``.device`` attribute.
                data = {k: v.to(accelerator.device) for k, v in data.items()}
                output = model(return_dict=True, **data)
                loss = output.loss
                accelerator.backward(loss)

                loss_value = loss.detach().item()
                lr = get_lr(optim.optimizer)
                pbar.set_description(
                    f"Loss: {loss_value} LR: {lr}",
                    refresh=True,
                )
                accelerator.log(
                    values={
                        "training_loss": loss_value,
                        "learning_rate": lr,
                    },
                    step=index // ACCUMULATION_STEPS,
                )
                if accelerator.sync_gradients:
                    accelerator.clip_grad_value_(model.parameters(), 0.9)
                optim.step()
                scheduler.step()
                # Inside accumulate(), the prepared optimizer only zeroes on
                # gradient-sync steps, so gradients accumulate in between.
                optim.zero_grad()

            index += 1
        # The save after the last epoch is the final save; no extra one follows.
        do_save_stuff(index // ACCUMULATION_STEPS)

    # Flush and close the trackers (e.g. finish the wandb run).
    accelerator.end_training()


Main function

In [None]:
def main():
    """Load the dataset, build model, processor and collator, then train.

    Steps:

    - Load the ``"train"`` split of ``DATASET_NAME``/``DATASET_SUBSET``, decode
      the ``AUDIO_PATH`` column at 16 kHz and set the format to torch.
    - For wav2vec-style ``BASE_MODEL`` names, first build a character-level CTC
      processor from the transcripts.
    - Build the model, collator and DataLoader, then call ``start_training``.
    """
    cv_data = (
        datasets.load_dataset(
            DATASET_NAME,
            DATASET_SUBSET,
            split="train",
        )
        .cast_column(
            AUDIO_PATH,
            datasets.Audio(sampling_rate=16000, decode=True),
        )
        .with_format("torch")
    )

    if "w2v" in BASE_MODEL or "wav2vec" in BASE_MODEL:
        ctc_processor = make_ctc_processor(cv_data)
        vocab_size = len(ctc_processor.tokenizer)
    else:
        ctc_processor = None
        vocab_size = None

    model, processor = get_model(
        model_name=BASE_MODEL, vocab_size=vocab_size, processor=ctc_processor
    )

    # Clear forced/suppressed generation tokens on the config; best effort,
    # kept tolerant of configs that reject these attributes.
    try:
        model.config.forced_decoder_ids = None
        model.config.suppress_tokens = []
    except Exception:
        pass

    # Read audio from the same column that was cast above, rather than from the
    # collator's environment-variable default.
    collator = ASRDataCollator(processor=processor, wav_key=AUDIO_PATH)

    cv_data = cv_data.shuffle(seed=random.randint(0, 1000))
    dloader = DataLoader(
        cv_data,
        collate_fn=collator,
        batch_size=BATCH_SIZE,
        pin_memory=True,
        num_workers=0,
    )

    start_training(
        model=model,
        processor=processor,
        dloader=dloader,
        OUT_MODEL=OUT_MODEL,
    )

In [None]:
# Start training when run as a notebook or script, but not on import.
if __name__ == "__main__":
    main()