Imports

In [1]:
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    Wav2Vec2BertForCTC,
    Wav2Vec2BertProcessor,
    Wav2Vec2CTCTokenizer,
    SeamlessM4TFeatureExtractor,
    AutoProcessor,
    Trainer,
    Seq2SeqTrainingArguments,
    AutoConfig,
)
import json
import os
import random
import torch
from dataclasses import dataclass
from typing import Dict, List, Union
from accelerate import Accelerator
import warnings
from tqdm.auto import tqdm
import huggingface_hub
import datasets

In [2]:
# --- Model checkpoints ---------------------------------------------------
# Checkpoint that is loaded and fine-tuned.
BASE_MODEL = "flozi00/distilwhisper-german-canary"
# Output directory name for the fine-tuned model.
OUT_MODEL = "distilwhisper-german-canary-v2"

# --- Dataset -------------------------------------------------------------
DATASET_NAME = "flozi00/german-canary-asr-0324"
DATASET_SUBSET = "default"
# Name of the dataset column that holds the audio.
AUDIO_PATH = "audio"

# --- Optimisation --------------------------------------------------------
BATCH_SIZE = 64     # per-device batch size
LR = 1e-5           # learning rate
SAVE_STEPS = 1000   # checkpoint interval, in optimizer steps
# NOTE(review): EPOCHS is not referenced by the cells visible in this
# notebook (training length is set through max_steps) - confirm it is needed.
EPOCHS = 1

Data Class for Collator

In [3]:
@dataclass
class ASRDataCollator:
    """Turn a list of raw dataset rows into one padded training batch.

    For every row the audio array is run through the processor's feature
    extractor and the transcript through the processor's tokenizer. Rows
    that cannot be read, or whose audio is longer than
    ``max_audio_in_seconds``, are dropped with a printed notice.

    Attributes:
        processor: Object exposing ``feature_extractor`` and ``tokenizer``.
        wav_key: Row key of the audio entry; the entry must provide the
            samples under ``"array"``.
        locale_key: Language passed to ``tokenizer.set_prefix_tokens`` when
            the tokenizer offers that method.
        text_key: Row key of the transcript.
        max_audio_in_seconds: Clips strictly longer than this are skipped.
        sampling_rate: Sample rate (Hz) used both to compute the clip
            duration and when calling the feature extractor.

    Note:
        The ``os.getenv`` defaults are evaluated once, when the class is
        defined; changing the environment afterwards has no effect.
    """

    processor: "AutoProcessor"
    wav_key: str = os.getenv("AUDIO_PATH", "audio")
    locale_key: str = os.getenv("LOCALE_KEY", "german")
    text_key: str = os.getenv("TEXT_KEY", "transkription")
    max_audio_in_seconds: float = float(os.getenv("MAX_AUDIO_IN_SECONDS", 20.0))
    sampling_rate: int = 16000

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Collate ``features`` into padded model inputs plus ``labels``.

        Args:
            features: Dataset rows, each holding the audio under
                ``wav_key`` and the transcript under ``text_key``.

        Returns:
            The padded feature-extractor batch with an added ``"labels"``
            tensor in which padded positions are set to ``-100``.

        Raises:
            ValueError: If no row of ``features`` survives filtering, since
                an empty batch cannot be padded or trained on.
        """
        feature_extractor = self.processor.feature_extractor
        tokenizer = self.processor.tokenizer

        # The language is the same for every row, so configure the prefix
        # tokens once per batch. Tokenizers without this method are left
        # untouched; a failing call is reported instead of silently ignored.
        set_prefix_tokens = getattr(tokenizer, "set_prefix_tokens", None)
        if callable(set_prefix_tokens):
            try:
                set_prefix_tokens(self.locale_key)
            except Exception as exc:
                print(f"could not set prefix tokens for {self.locale_key!r}: {exc}")

        input_features = []
        label_features = []

        for feature in features:
            # Best effort: a malformed row is reported and dropped.
            try:
                myaudio = feature[self.wav_key]["array"]
                mytext = feature[self.text_key]
            except Exception as e:
                print(e)
                continue

            # Compare the exact duration; truncating to whole seconds would
            # let clips up to one second over the limit slip through.
            audio_seconds = len(myaudio) / self.sampling_rate
            if audio_seconds > self.max_audio_in_seconds:
                print("skipping audio")
                continue

            extracted = feature_extractor(
                myaudio,
                sampling_rate=self.sampling_rate,
                return_tensors="pt",
            )

            # Feature extractors expose their output under either
            # "input_values" or "input_features"; use whichever is present.
            ft = (
                "input_values"
                if hasattr(extracted, "input_values")
                else "input_features"
            )

            # Index 0 drops the batch dimension of the single-clip output.
            input_features.append({ft: getattr(extracted, ft)[0]})
            label_features.append({"input_ids": tokenizer(mytext).input_ids})

        if not input_features:
            raise ValueError(
                "ASRDataCollator: every sample in the batch was skipped "
                "(unreadable or longer than max_audio_in_seconds)"
            )

        batch = feature_extractor.pad(
            input_features,
            padding="longest",
            return_tensors="pt",
        )

        labels_batch = tokenizer.pad(
            label_features,
            padding="longest",
            return_tensors="pt",
        )

        # replace padding with -100 to ignore loss correctly
        batch["labels"] = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        return batch

Preprocessing: model and dataset loading

In [4]:
def get_model(
    model_name: str,
    device: Union[str, torch.device, None] = None,
):
    """Load a speech seq2seq model and its processor from ``model_name``.

    The model is loaded in bfloat16 with the "sdpa" attention
    implementation and then moved to ``device``.

    Args:
        model_name: Hub id or local path accepted by ``from_pretrained``.
        device: Target device for the model. When ``None`` (the default)
            CUDA is used if available, otherwise the model stays on CPU
            instead of failing.

    Returns:
        A ``(model, processor)`` tuple.
    """
    # SECURITY: trust_remote_code=True lets the config and processor run
    # Python code shipped in the model repository; only use trusted repos.
    conf = AutoConfig.from_pretrained(
        pretrained_model_name_or_path=model_name,
        trust_remote_code=True,
    )

    processor = AutoProcessor.from_pretrained(
        model_name,
        legacy=False,
        trust_remote_code=True,
    )

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_name,
        config=conf,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    return model, processor

In [5]:
# Stream the training split instead of downloading it up front.
cv_data = datasets.load_dataset(
    DATASET_NAME,
    DATASET_SUBSET,
    split="train",
    streaming=True,
)

# Decode the audio column on access, resampled to 16 kHz.
audio_feature = datasets.Audio(sampling_rate=16000, decode=True)
cv_data = cv_data.cast_column(AUDIO_PATH, audio_feature)

# Return values as torch tensors.
cv_data = cv_data.with_format("torch")

Resolving data files:   0%|          | 0/84 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/84 [00:00<?, ?it/s]

Main: model setup and training

In [6]:
# Load the base checkpoint together with its processor.
model, processor = get_model(
    model_name=BASE_MODEL
)

# Clear decoding constraints stored on the model config. This stays best
# effort, but a failure is reported instead of being silently swallowed.
try:
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []
except Exception as exc:
    warnings.warn(
        f"Could not reset forced_decoder_ids/suppress_tokens on the model config: {exc}"
    )

# Draw the shuffle seed once and print it so the data order of this run can
# be reproduced later by passing the same seed.
shuffle_seed = random.randint(0, 1000)
print(f"Shuffling training stream with seed {shuffle_seed}")
cv_data = cv_data.shuffle(seed=shuffle_seed)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
training_args = Seq2SeqTrainingArguments(
    # Output location and Hub upload
    output_dir=OUT_MODEL,
    push_to_hub=True,
    hub_private_repo=True,
    # Batching and schedule
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    # NOTE(review): 985000 is presumably the number of training samples,
    # giving roughly one pass over the data - confirm.
    max_steps=int(985000/BATCH_SIZE),
    # Optimizer
    learning_rate=LR,
    optim="paged_lion_8bit",
    optim_target_modules=["attn", "mlp"],
    # Precision and memory
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    dataloader_pin_memory=False,
    # Keep the raw dataset columns; the data collator reads them directly.
    remove_unused_columns=False,
    # Logging and checkpoints
    logging_steps=10,
    save_strategy="steps",
    save_steps=SAVE_STEPS,
)

In [8]:
# Read the audio from the same column that was cast when the dataset was
# loaded, rather than relying on the collator's environment-variable default.
data_collator = ASRDataCollator(processor=processor, wav_key=AUDIO_PATH)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=cv_data,
    eval_dataset=None,
    # Passing the feature extractor here makes the Trainer save it with the model.
    tokenizer=processor.feature_extractor,
    data_collator=data_collator,
    compute_metrics=None,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [9]:
# Show the GPU status (driver version, memory in use) before training starts.
!nvidia-smi

Fri Apr  5 07:20:49 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.29.06              Driver Version: 545.29.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA L40S                    Off | 00000000:00:10.0 Off |                    0 |
| N/A   36C    P0              70W / 350W |   1997MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
# Run training with the trainer configured above; the return value is stored
# in train_result.
train_result = trainer.train()
# Write the final model to the trainer's output directory.
trainer.save_model()  # Saves the feature extractor too for easy upload

[34m[1mwandb[0m: Currently logged in as: [33mflozi00[0m. Use [1m`wandb login --relogin`[0m to force relogin


`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


Step,Training Loss
100,0.0537
