In [None]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Checkpoint to fine-tune. Larger variants of the same family
# (e.g. "openai/whisper-small", "openai/whisper-medium") can be substituted here.
model_name = "openai/whisper-tiny"

# The processor (feature extractor + tokenizer) and the model are loaded from
# the same checkpoint name.
processor, model = (
    WhisperProcessor.from_pretrained(model_name),
    WhisperForConditionalGeneration.from_pretrained(model_name),
)

In [None]:
import torchaudio


def preprocess_function(batch):
    """Convert one dataset row into Whisper training inputs.

    Intended to be used with ``dataset.map(..., batched=False)``, i.e. it
    receives a single example at a time.

    Args:
        batch: Mapping for one example. Reads ``batch["media_filepath"]``
            (path passed to ``torchaudio.load``) and
            ``batch["raw_transcript_stripped"]`` (text passed to the tokenizer).

    Returns:
        dict with two keys:
            ``"input_features"``: the feature extractor's output for this one
                clip, *without* a leading batch dimension, so that batching
                later adds exactly one batch axis.
            ``"labels"``: list of token ids for the transcript.
    """
    target_rate = 16000  # sampling rate the audio is resampled to before feature extraction

    waveform, sampling_rate = torchaudio.load(batch["media_filepath"])

    # torchaudio returns (channels, frames). Averaging over the channel axis
    # yields a 1-D mono signal for any channel count; a plain squeeze() would
    # leave stereo files 2-D and they would be misread as two separate clips.
    waveform = waveform.mean(dim=0)

    if sampling_rate != target_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=target_rate)
        waveform = resampler(waveform)

    # NOTE(review): the Whisper feature extractor pads/truncates audio to a
    # fixed window (30 s per its documentation) while the transcript is
    # tokenized in full -- confirm clips and transcripts are aligned to that.
    extracted = processor.feature_extractor(waveform.numpy(), sampling_rate=target_rate)

    # Tokenize the reference transcript; these ids are the training targets.
    labels = processor.tokenizer(batch["raw_transcript_stripped"]).input_ids

    return {
        # The extractor always returns a batch; take the single example so no
        # stray size-1 batch dimension is stored per row.
        "input_features": extracted.input_features[0],
        "labels": labels,
    }

In [None]:
from datasets import load_dataset
# Read the sample parquet file through the generic "parquet" builder of
# Hugging Face Datasets. The rows are exposed under the "train" split.
dataset = load_dataset(
    "parquet",
    data_files="../data/testing/veterans_history_project_sample.parquet",
)

In [None]:
# Run the preprocessing one example at a time and drop every original column,
# so only the keys returned by preprocess_function remain in the result.
source_columns = dataset.column_names["train"]
processed_dataset = dataset.map(preprocess_function, batched=False, remove_columns=source_columns)

In [None]:
from transformers import Seq2SeqTrainingArguments, TrainingArguments

# Seq2SeqTrainer reads seq2seq-specific fields (e.g. `generation_config`) from
# its args object; plain TrainingArguments does not define them and makes the
# trainer constructor fail. Seq2SeqTrainingArguments is a subclass of
# TrainingArguments, so every setting below keeps its meaning.
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-finetune",  # checkpoints and logs are written here
    per_device_train_batch_size=1,
    num_train_epochs=3,
    logging_steps=5,
    save_strategy="epoch",  # checkpoint once per epoch
    eval_strategy="epoch",  # evaluate once per epoch
    fp16=False,  # full-precision training
)

In [None]:
from transformers import Seq2SeqTrainer


def collate_speech_batch(features):
    """Batch preprocessed examples, padding audio features and labels separately.

    Args:
        features: list of examples, each holding ``"input_features"`` and
            ``"labels"`` as produced by the preprocessing step.

    Returns:
        A batch mapping with tensor ``"input_features"`` and ``"labels"``;
        padded label positions are set to -100 so the loss ignores them.
    """
    audio_batch = processor.feature_extractor.pad(
        [{"input_features": example["input_features"]} for example in features],
        return_tensors="pt",
    )
    label_batch = processor.tokenizer.pad(
        [{"input_ids": example["labels"]} for example in features],
        return_tensors="pt",
    )

    # Mask by attention_mask rather than by pad-token id, so a genuine
    # end-of-text token at the end of a transcript is still trained on.
    labels = label_batch["input_ids"].masked_fill(label_batch["attention_mask"].ne(1), -100)

    # The model prepends the decoder start token itself when it builds decoder
    # inputs from the labels; drop it here if the tokenizer already added it.
    if (labels[:, 0] == model.config.decoder_start_token_id).all().cpu().item():
        labels = labels[:, 1:]

    audio_batch["labels"] = labels
    return audio_batch


# A dataset loaded from a single data file only has a "train" split, so
# indexing "validation" directly raises KeyError. Use the split when it exists,
# otherwise hold out a reproducible 10% of the training rows for evaluation.
if "validation" in processed_dataset:
    train_split = processed_dataset["train"]
    eval_split = processed_dataset["validation"]
else:
    held_out = processed_dataset["train"].train_test_split(test_size=0.1, seed=42)
    train_split, eval_split = held_out["train"], held_out["test"]

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_split,
    eval_dataset=eval_split,
    data_collator=collate_speech_batch,
    # The feature extractor (not the text tokenizer) is passed so it is saved
    # alongside checkpoints. NOTE(review): newer transformers releases rename
    # this argument to `processing_class` -- confirm against the pinned version.
    tokenizer=processor.feature_extractor,
)

In [None]:
# Start fine-tuning with the trainer built above; as the last expression of
# the cell, the value returned by train() is displayed as the cell output.
trainer.train()