In [None]:
!pip3 install --upgrade transformers peft bitsandbytes datasets accelerate loralib huggingface_hub jiwer evaluate wandb pythainlp

In [None]:
from huggingface_hub import notebook_login

notebook_login()

In [7]:
from datasets import load_dataset, DatasetDict, load_metric, Audio
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model, prepare_model_for_kbit_training

from dataclasses import dataclass
from typing import Any, Dict, List, Union
import os
import torch
import numpy as np
import pandas as pd
import evaluate



## Prepare Data

In [None]:
language_abbr = "th"
dataset_name = "mozilla-foundation/common_voice_13_0"

common_voice = DatasetDict()

common_voice["train"] = load_dataset(dataset_name, language_abbr, split="train+validation", token=True)
common_voice["test"] = load_dataset(dataset_name, language_abbr, split="test", token=True)

In [None]:
common_voice = common_voice.filter(lambda x: x == 0, input_columns=["down_votes"])

In [10]:
common_voice = common_voice.remove_columns(
    ["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes", "variant"]
)

In [11]:
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

In [None]:
feature_extractor = WhisperFeatureExtractor.from_pretrained("distil-whisper/distil-large-v2")

In [None]:
tokenizer = WhisperTokenizer.from_pretrained("distil-whisper/distil-large-v2", language="Thai", task="transcribe")

In [14]:
common_voice

DatasetDict({
    train: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 39140
    })
    test: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 9332
    })
})

In [15]:
input_str = common_voice["train"][0]["sentence"]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)

print(f"Input:                 {input_str}")
print(f"Decoded w/ special:    {decoded_with_special}")
print(f"Decoded w/out special: {decoded_str}")
print(f"Are equal:             {input_str == decoded_str}")

Input:                 ถ้าทำแบบนี้ ถูกไล่ออก ครอบครัวจะไปอยู่ไหน
Decoded w/ special:    <|startoftranscript|><|th|><|transcribe|><|notimestamps|>ถ้าทำแบบนี้ ถูกไล่ออก ครอบครัวจะไปอยู่ไหน<|endoftext|>
Decoded w/out special: ถ้าทำแบบนี้ ถูกไล่ออก ครอบครัวจะไปอยู่ไหน
Are equal:             True


In [16]:
processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v2", language="Thai", task="transcribe")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [17]:
def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

In [None]:
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=4)

In [19]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [20]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [21]:
cer_metric = load_metric("cer")

  cer_metric = load_metric("cer")


Downloading builder script:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

In [22]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}

## Prepare Model

In [51]:
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v2", load_in_4bit=True, device_map="auto")

In [52]:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

In [53]:
model = prepare_model_for_kbit_training(model)

In [None]:
config = LoraConfig(
    r=32, 
    lora_alpha=64, 
    target_modules=".*decoder.*(self_attn|encoder_attn).*(q_proj|v_proj)$",#["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none"
)

model = get_peft_model(model, config)
model.print_trainable_parameters()

## Training

In [58]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./distill-whisper-large-v2-thai-qlora",
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    learning_rate=1e-3,
    warmup_steps=100,
    max_steps=1000,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=500,
    eval_steps=500,
    logging_steps=500,
    num_train_epochs=3,
    report_to=["wandb"],
    load_best_model_at_end=True,
    metric_for_best_model="cer",
    greater_is_better=False,
    push_to_hub=False,
    save_total_limit=3,
    remove_unused_columns=False,  # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
    label_names=["labels"],  # same reason as above
    lr_scheduler_type="cosine",
)

In [59]:
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

In [None]:
trainer.train()

In [None]:
peft_model_id = "juierror/distill-whisper-large-v2-thai-qlora"
model.push_to_hub(peft_model_id)

In [None]:
torch.cuda.empty_cache()

## Evaluate

In [62]:
cer_metric = load_metric("cer")
wer_metric = load_metric("wer")

Downloading builder script:   0%|          | 0.00/1.90k [00:00<?, ?B/s]

In [None]:
peft_model_id = "juierror/distill-whisper-large-v2-thai-qlora"
peft_config = PeftConfig.from_pretrained(peft_model_id)
model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path, load_in_4bit=True, device_map="auto")
model = PeftModel.from_pretrained(model, peft_model_id)

In [68]:
from torch.utils.data import DataLoader
from pythainlp.tokenize import word_tokenize
from tqdm import tqdm

eval_dataloader = DataLoader(common_voice["test"], batch_size=16, collate_fn=data_collator)
all_labels = []
all_transcription = []
all_labels_token = []
all_transcription_token = []

model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
    with torch.cuda.amp.autocast():
        with torch.no_grad():
            generated_tokens = (
                model.generate(
                    input_features=batch["input_features"].to("cuda"),
                    max_new_tokens=255,
                    language="Thai",
                    task="transcribe"
                )
                .cpu()
                .numpy()
            )
            labels = batch["labels"].cpu().numpy()
            transcriptions = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            sentences = tokenizer.batch_decode(labels, skip_special_tokens=True)
            all_labels.extend(sentences)
            all_transcription.extend(transcriptions)
            
            cer_metric.add_batch(predictions=[pred_str.replace(" ", "") for pred_str in transcriptions], references=[label_str.replace(" ", "") for label_str in sentences])

            pred_str_newmm = [word_tokenize(text=e, engine='newmm', keep_whitespace=False) for e in transcriptions]
            label_str_newmm = [word_tokenize(text=e, engine='newmm', keep_whitespace=False) for e in sentences]
            all_labels_token.extend(label_str_newmm)
            all_transcription_token.extend(pred_str_newmm)
            wer_metric.add_batch(predictions=pred_str_newmm, references=label_str_newmm)

wer = 100 * wer_metric.compute()
cer = 100 * cer_metric.compute()
print(f"wer: {wer}")
print(f"cer: {cer}")

100%|██████████| 584/584 [2:00:02<00:00, 12.33s/it]  


wer: 62.49801320675654
cer: 19.122046340140493


In [71]:
report_df = pd.DataFrame(data={
    "labels": all_labels,
    "transcribe": all_transcription,
    "labels_tokenize": all_labels_token,
    "transcribe_tokenizer": all_transcription_token
})

In [72]:
report_df.to_csv("report.csv", index=False)