In [None]:
%%capture
# Install the notebook's dependencies; `%%capture` hides this cell's output.
import os
# Colab is detected by looking for "COLAB_" in the concatenated names of all
# environment variables.
if "COLAB_" not in "".join(os.environ.keys()):
    # Not on Colab: plain install, letting pip resolve unsloth's dependencies.
    !pip install unsloth
else:
    # On Colab: the first and last commands use --no-deps, so pip installs only
    # the packages listed and skips dependency resolution (xformers is pinned).
    # NOTE(review): presumably done to keep Colab's preinstalled stack intact — confirm.
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" huggingface_hub hf_transfer
    !pip install --no-deps unsloth
# Installed in both environments: extra packages for audio handling and WER scoring.
!pip install librosa soundfile evaluate jiwer

In [None]:
%%capture
# Pin transformers to an exact version. This cell runs after the installs
# above, so 4.51.3 is the version left in the environment.
# NOTE(review): presumably the version this notebook was validated with — confirm before bumping.
!pip install transformers==4.51.3

In [None]:
from unsloth import FastModel
from transformers import WhisperForConditionalGeneration
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import pipeline
import torch
import evaluate

from datasets import load_dataset, Audio

import numpy as np
from dataclasses import dataclass
from typing import Any
import tqdm

In [None]:
# Load the Whisper large-v3 checkpoint through Unsloth.
# Despite its name, `tokenizer` is used as a full processor in later cells:
# they access both `tokenizer.feature_extractor` and `tokenizer.tokenizer`.
model, tokenizer = FastModel.from_pretrained(
    model_name="unsloth/whisper-large-v3",
    dtype=None,  # no explicit dtype requested; presumably auto-selected by Unsloth — confirm
    load_in_4bit=False,  # 4-bit loading disabled
    auto_model=WhisperForConditionalGeneration,
    whisper_language="English",
    whisper_task="transcribe",
)

In [None]:
# Attach LoRA adapters; the returned model replaces `model`.
# NOTE: this cell reassigns `model` from itself, so it is not idempotent —
# re-run the model-loading cell first if this cell has to be executed again.
model = FastModel.get_peft_model(
    model,
    r=64,  # LoRA rank
    target_modules=["q_proj", "v_proj"],  # adapters are added only to modules with these names
    lora_alpha=64,  # same value as r
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,  # same value as the trainer seed
    use_rslora=False,
    loftq_config=None,
    task_type=None,
)

In [None]:
# Fix the generation language and task to English transcription, and clear
# the suppressed-token list and forced decoder ids.
model.generation_config.language = "<|en|>"
model.generation_config.task = "transcribe"
model.config.suppress_tokens = []  # note: set on model.config, not on generation_config
model.generation_config.forced_decoder_ids = None

def formatting_prompts_func(example):
    """Convert one dataset row into a training example for Whisper.

    Args:
        example: Row holding an ``"audio"`` mapping (with ``"array"`` and
            ``"sampling_rate"`` entries) and a ``"text"`` transcript.

    Returns:
        A dict with ``"input_features"`` (the first item of the feature
        extractor's ``input_features``) and ``"labels"`` (the token ids the
        text tokenizer produced for the transcript).
    """
    audio = example["audio"]
    extracted = tokenizer.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    )
    transcript_ids = tokenizer.tokenizer(example["text"]).input_ids
    return {
        "input_features": extracted.input_features[0],
        "labels": transcript_ids,
    }

# Load the speech dataset and turn it into in-memory lists of training examples.
dataset = load_dataset("MrDragonFox/Elise", split="train")

# Have the audio column decoded at 16 kHz before feature extraction.
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
# Seed the split (same value as the LoRA/trainer seeds) so the train/test
# partition — and therefore the reported eval numbers — is reproducible
# across kernel restarts.
dataset = dataset.train_test_split(test_size=0.06, seed=3407)
# Pre-compute features and label ids for every row of each split.
train_dataset = [formatting_prompts_func(example) for example in tqdm.tqdm(dataset['train'], desc='Train split')]
test_dataset = [formatting_prompts_func(example) for example in tqdm.tqdm(dataset['test'], desc='Test split')]

In [None]:
# Word-error-rate metric loaded via `evaluate`; `compute_metrics` below calls it.
metric = evaluate.load("wer")

def compute_metrics(pred):
    """Compute the word error rate (as a percentage) for an evaluation pass.

    Args:
        pred: Object exposing ``predictions`` and ``label_ids``.
            ``predictions`` may be an array of token ids, an array of logits
            shaped (batch, sequence, vocab), or a tuple whose first element
            is one of those. ``label_ids`` uses -100 for ignored positions.

    Returns:
        ``{"wer": <float>}`` where the value is 100 * WER.
    """
    # `tokenizer` is used as a processor elsewhere in this notebook; read the
    # pad id from the wrapped text tokenizer when there is one, because the
    # processor object may not expose `pad_token_id` itself.
    pad_id = getattr(tokenizer, "tokenizer", tokenizer).pad_token_id

    predictions = pred.predictions
    if isinstance(predictions, tuple):
        # Model outputs arrive as a tuple; the first entry holds logits/ids.
        predictions = predictions[0]
    predictions = np.asarray(predictions)

    # Work on copies (np.where) so the caller's label array is not mutated.
    label_ids = np.asarray(pred.label_ids)
    ignored = label_ids == -100

    if predictions.ndim == 3:
        # Teacher-forced logits: pick the most likely token per position.
        pred_ids = predictions.argmax(axis=-1)
        if pred_ids.shape == label_ids.shape:
            # Positions whose label is ignored carry no target; blank them so
            # they are dropped by `skip_special_tokens` instead of being scored.
            pred_ids = np.where(ignored, pad_id, pred_ids)
    else:
        # Already token ids; -100 padding cannot be decoded, so replace it.
        pred_ids = np.where(predictions == -100, pad_id, predictions)

    label_ids = np.where(ignored, pad_id, label_ids)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: list[dict[str, list[int] | torch.Tensor]]) -> dict[str, torch.Tensor]:
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [None]:
# Pick the mixed-precision mode from the dtype the model weights were actually
# loaded in (the model was loaded with dtype=None, i.e. no fixed dtype), so the
# trainer never asks for fp16 autocast on bfloat16 weights. Weights in any
# other dtype keep the previous behaviour (fp16=True, bf16=False).
model_is_bf16 = next(model.parameters()).dtype == torch.bfloat16

# Build the trainer: effective batch size is 2 * 4 = 8 per device, training is
# capped at 60 optimizer steps, and evaluation runs every 5 steps.
# NOTE(review): `compute_metrics` is not passed here, so evaluation only
# reports what the trainer computes by default — confirm this is intended.
trainer = Seq2SeqTrainer(
    model=model,
    train_dataset=train_dataset,
    data_collator=DataCollatorSpeechSeq2SeqWithPadding(processor=tokenizer),
    eval_dataset=test_dataset,
    tokenizer=tokenizer.feature_extractor,
    args=Seq2SeqTrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_steps=5,
        max_steps=60,
        learning_rate=1e-4,
        logging_steps=1,
        optim="adamw_8bit",
        fp16=not model_is_bf16,
        bf16=model_is_bf16,
        weight_decay=0.01,
        # Keep every key the collator emits instead of letting the trainer
        # drop columns it does not recognise.
        remove_unused_columns=False,
        lr_scheduler_type="linear",
        label_names=['labels'],
        eval_steps=5,
        eval_strategy="steps",
        seed=3407,
        output_dir="outputs",
        report_to="none",
    ),
)

In [None]:
# Record the GPU memory baseline before training; `start_gpu_memory` and
# `max_memory` are read again by the report cell after training.
gpu_stats = torch.cuda.get_device_properties(0)  # properties of device index 0
# Values are divided by 1024 three times and rounded to 3 decimals.
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

In [None]:
# Run training; the returned object's `metrics` dict is read by the report cell below.
trainer_stats = trainer.train()

In [None]:
# Report training time and peak reserved GPU memory. Relies on
# `start_gpu_memory` and `max_memory` from the baseline cell and on
# `trainer_stats` from the training cell.
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
# Difference between the current peak and the pre-training baseline.
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

In [None]:
# Download a sample FLAC recording into the working directory; the preview
# and inference cells below open it by file name.
!wget https://upload.wikimedia.org/wikipedia/commons/5/5b/Speech_12dB_s16.flac

In [None]:
# Preview the downloaded sample in the notebook.
# The player class is imported under an alias on purpose: the name `Audio` is
# already bound to `datasets.Audio` by the imports at the top of the notebook,
# and the dataset cell (`cast_column("audio", Audio(sampling_rate=16000))`)
# would fail on re-run if this cell rebound it to the IPython class.
from IPython.display import Audio as NotebookAudio, display
display(NotebookAudio("Speech_12dB_s16.flac", rate = 24000))

In [None]:
# Transcribe the downloaded sample with the fine-tuned model.
# NOTE(review): `for_inference` presumably switches Unsloth to its inference path — confirm.
FastModel.for_inference(model)
model.eval()

# Build an ASR pipeline around the in-memory model. `tokenizer` is the
# processor, so its text tokenizer and feature extractor are passed separately.
whisper = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=tokenizer.tokenizer,
    feature_extractor=tokenizer.feature_extractor,
    processor=tokenizer,
    return_language=True,
    torch_dtype=torch.float16
)

audio_file = "Speech_12dB_s16.flac"
transcribed_text = whisper(audio_file)
# Only the "text" entry of the pipeline result is printed.
print(transcribed_text["text"])

In [None]:
# Save the model and the processor into the local "lora_model" directory.
# NOTE(review): for a PEFT-wrapped model this presumably writes only the adapter weights — confirm.
model.save_pretrained("lora_model")
tokenizer.save_pretrained("lora_model")

In [None]:
# Optional exports. Every call is disabled behind `if False:`; flip the guard
# of the variant you need. Hub uploads need a real token in place of "".

# Merge to 16bit
# The local save names its method explicitly, matching the hub push below
# (it previously passed save_method = None).
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")

# Merge to 4bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

# Just LoRA adapters
if False:
    model.save_pretrained("model")
    tokenizer.save_pretrained("model")
if False:
    model.push_to_hub("hf/model", token = "")
    tokenizer.push_to_hub("hf/model", token = "")