# PEFT (LoRA) training of the Whisper model

## Create LoRA model

In [5]:
import os
import sys

# Make the parent of the notebook's working directory importable so the local
# `modules` package resolves. The membership check keeps repeated runs of this
# cell from piling duplicate entries onto sys.path.
project_root = os.path.abspath('..')
if project_root not in sys.path:
    sys.path.append(project_root)

from modules.data import get_quebecois_data
from modules.peft import add_peft_to_model
from modules.model import get_whisper

model = get_whisper()

# Generation settings: transcribe (not translate) and force French output.
model.generation_config.language = "french"
model.generation_config.task = "transcribe"

# Freeze all base-model weights; add_peft_to_model is expected to attach the
# trainable adapter parameters on top (the printed summary confirms the split).
model.requires_grad_(False)
lora_model = add_peft_to_model(model)
lora_model.print_trainable_parameters()


  from .autonotebook import tqdm as notebook_tqdm


trainable params: 3,594,240 || all params: 245,329,152 || trainable%: 1.4651


## Set up the trainer

In [None]:
from modules.training import compute_metrics, DataCollatorSpeechSeq2SeqWithPadding
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, WhisperProcessor

# NOTE(review): the checkpoint name is hard-coded here while the model comes
# from get_whisper(); confirm both refer to the same Whisper variant.
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="French", task="transcribe")

# No mkdir needed: the Trainer creates output_dir when it saves a checkpoint.
checkpoint_dir = "../checkpoints/"
training_args = Seq2SeqTrainingArguments(
    output_dir=checkpoint_dir,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # increase by 2x for every 2x decrease in batch size
    # NOTE(review): 1e-5 is a typical full fine-tuning rate; LoRA adapters are
    # often trained with a larger one (e.g. 1e-3 in the PEFT Whisper example) — confirm.
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=4000,
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=1000,  # must stay a multiple of eval_steps for load_best_model_at_end
    eval_steps=500,
    gradient_checkpointing=False,
    fp16=True,
    per_device_eval_batch_size=1,
    predict_with_generate=True,  # evaluate on generated transcripts so WER can be computed
    generation_max_length=225,
    logging_steps=25,
    report_to=[],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower WER is better
    push_to_hub=False,
    save_only_model=True,
    save_total_limit=2,
    # A PeftModel wrapper may hide the wrapped model's forward() signature, in
    # which case the Trainer cannot infer the label column; name it explicitly.
    label_names=["labels"],
)
data = get_quebecois_data()
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor,
                                                     decoder_start_token_id=model.config.decoder_start_token_id)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=lora_model,
    train_dataset=data["train"],
    eval_dataset=data["test"],
    data_collator=data_collator,
    processing_class=processor,
    compute_metrics=compute_metrics,
)

## Training

In [None]:
# Launch fine-tuning with the Seq2SeqTrainer built in the previous cell. As the
# cell's last expression, its return value is displayed by the notebook.
trainer.train()