In [None]:
!pip install transformers
!pip install torch
!pip install librosa
!pip install soundfile

## Batch transcription with fine-tuned Whisper checkpoints

This notebook contains everything needed to batch-transcribe audio files with the checkpoints loaded below. The checkpoints currently loaded are the fine-tuned versions, which are available on my Hugging Face (HF) account.

In [None]:
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import librosa
import soundfile as sf

In [None]:
import os

In [None]:
# Mount Google Drive at /content/drive so files stored there can be read and written
# from this Colab runtime.
from google.colab import drive

drive.mount("/content/drive")

In [None]:
# Fine-tuned checkpoints to load, keyed by model size.
model_names = dict(
    base="mariatepei/whisper-base-synthetic_model",
    small="mariatepei/whisper-small-synthetic_model",
    medium="mariatepei/whisper-medium-synthetic_model",
    large="mariatepei/whisper-large-synthetic_model",
)

# Load the processor and the model of every checkpoint up front.
# Both dictionaries end up with the same keys as `model_names`.
models, processors = {}, {}
for size, checkpoint in model_names.items():
    processors[size] = WhisperProcessor.from_pretrained(checkpoint)
    models[size] = WhisperForConditionalGeneration.from_pretrained(checkpoint)


In [None]:
# Folder on the mounted Drive that is scanned for .wav files; the transcription
# .txt files are written back into this same folder.
test_folder_path = "/content/drive/My Drive/test"

In [None]:
# Batch-transcribe every WAV file in `test_folder_path` with each loaded model.
# One text file per (audio file, model size) pair is written next to the audio.

# Sampling rate the audio is resampled to on load and reported to the processor.
TARGET_SAMPLE_RATE = 16000

# Specify the language and task, otherwise whisper will predict it (often wrongly).
# The prompt ids depend only on the processor, so they are built once per model size
# instead of once per (file, model) pair.
decoder_prompt_ids = {
    size: processors[size].get_decoder_prompt_ids(language="dutch", task="transcribe")
    for size in model_names
}

# os.listdir returns entries in arbitrary order; sort so every run processes the
# files in the same order.
for audio_file in sorted(os.listdir(test_folder_path)):
    # Case-insensitive match so files named "*.WAV" are picked up as well.
    if not audio_file.lower().endswith(".wav"):
        continue

    audio_file_path = os.path.join(test_folder_path, audio_file)
    # NOTE(review): Whisper's feature extractor presumably works on a fixed-length
    # window, so long recordings may be cut off here - confirm the test clips are short.
    audio_input, sample_rate = librosa.load(audio_file_path, sr=TARGET_SAMPLE_RATE)

    for size in model_names:
        processor = processors[size]
        model = models[size]

        # Features are extracted with the processor that belongs to this model size.
        input_features = processor(
            audio_input, sampling_rate=sample_rate, return_tensors="pt"
        ).input_features
        # Keep the input on the same device as the model so this cell still works
        # when the models have been moved to a GPU.
        input_features = input_features.to(model.device)

        predicted_ids = model.generate(
            input_features, forced_decoder_ids=decoder_prompt_ids[size]
        )

        # A single audio clip was passed in, so take the only decoded sequence.
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        transcription_file_path = os.path.join(
            test_folder_path, f"{audio_file}_transcription_n_{size}.txt"
        )

        with open(transcription_file_path, "w", encoding="utf8") as f:
            f.write(transcription)

        print(f"Transcription for {audio_file} using {size} model saved to {transcription_file_path}")