In [None]:
from utils.finetuning_utils import create_dataset, load_model, DataCollatorSpeechSeq2SeqWithPadding
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, WhisperProcessor
import os 
import mlflow

# Build the training and validation datasets from their manifest files.
train_dataset = create_dataset("train_manifest_no_synthetic_wps_processed.json")
val_dataset = create_dataset("val_manifest_wps_processed.json")

# Identifier of the pretrained checkpoint that will be fine-tuned.
model_pretrained = "openai/whisper-small"

# Make sure the root output directories exist before anything writes to them.
for _output_root in ("finetuning/checkpoints", "finetuning/experiments"):
    os.makedirs(_output_root, exist_ok=True)

# The last component of the model id ("whisper-small") names the per-model
# sub-folders under each output root.
chekpoint_name = model_pretrained.rsplit("/", 1)[-1]
checkpoint_folder, experiment_folder = (
    os.path.join(root, chekpoint_name)
    for root in ("finetuning/checkpoints", "finetuning/experiments")
)
# NOTE(review): experiment_folder is not referenced in the visible part of
# this script -- confirm whether a later cell relies on it.

# Load the pretrained model and its processor; the processor is configured
# for Portuguese ("pt") transcription.
model = load_model(model_pretrained)
processor = WhisperProcessor.from_pretrained(
    model_pretrained,
    language="pt",
    task="transcribe",
)

# The collator receives the processor plus the decoder start token id taken
# from the model configuration.
_decoder_start_token_id = model.config.decoder_start_token_id
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=_decoder_start_token_id,
)


# Schedule constants. They feed both the step-interval arithmetic below and
# the trainer arguments, so the two can no longer drift apart.
_train_epochs = 3
_per_device_batch = 4
_grad_accum_steps = 8
# NOTE(review): the original formula divided by 32 * 4. The 32 matches
# per-device batch (4) x gradient accumulation (8); the extra factor of 4 is
# presumably the number of training devices -- confirm against the hardware.
_assumed_num_devices = 4
_examples_per_optimizer_step = (
    _per_device_batch * _grad_accum_steps * _assumed_num_devices
)

# Evaluate and checkpoint roughly ten times over the whole run. For small
# training sets the estimate truncates to 0, which is not a usable step
# interval, so clamp it to at least one step.
_eval_save_interval = max(
    1,
    int(len(train_dataset) * _train_epochs / _examples_per_optimizer_step / 10),
)

# Warm up for half of one evaluation interval (about 5% of the estimated run).
# Zero is a legitimate value here (no warmup), so it is not clamped.
_warmup_steps = int(
    0.5 * len(train_dataset) * _train_epochs / _examples_per_optimizer_step / 10
)

training_args = Seq2SeqTrainingArguments(
    output_dir=checkpoint_folder,
    gradient_accumulation_steps=_grad_accum_steps,
    per_device_train_batch_size=_per_device_batch,
    per_device_eval_batch_size=16,
    learning_rate=1e-5,
    warmup_steps=_warmup_steps,
    gradient_checkpointing=False,
    fp16=True,
    num_train_epochs=_train_epochs,
    evaluation_strategy="steps",
    generation_max_length=448,
    predict_with_generate=True,
    save_steps=_eval_save_interval,
    eval_steps=_eval_save_interval,
    logging_steps=10,
    report_to=["mlflow"],
    push_to_hub=False,
)


# Wire the model, the training configuration, both dataset splits and the
# batching collator into the sequence-to-sequence trainer.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # Only the feature extractor half of the processor is handed over here.
    tokenizer=processor.feature_extractor,
)


# Run training inside an explicit MLflow run under a fixed experiment name.
mlflow.set_experiment("finetuning/experiments")
with mlflow.start_run() as run:
    trainer.train()
    # Persist the final model weights/config into the checkpoint folder.
    trainer.save_model(checkpoint_folder)
    # The trainer was only given processor.feature_extractor, so also save the
    # full processor (feature extractor + tokenizer) next to the model. This
    # lets the folder be reloaded on its own with
    # WhisperProcessor.from_pretrained(checkpoint_folder).
    processor.save_pretrained(checkpoint_folder)