In [1]:
# !pip install --upgrade pip
!pip install --upgrade datasets[audio] transformers accelerate evaluate jiwer tensorboard gradio


Collecting transformers
  Downloading transformers-4.52.3-py3-none-any.whl.metadata (40 kB)
Collecting accelerate
  Downloading accelerate-1.7.0-py3-none-any.whl.metadata (19 kB)
Collecting gradio
  Downloading gradio-5.31.0-py3-none-any.whl.metadata (16 kB)
Collecting gradio-client==1.10.1 (from gradio)
  Downloading gradio_client-1.10.1-py3-none-any.whl.metadata (7.1 kB)
Downloading transformers-4.52.3-py3-none-any.whl (10.5 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.5/10.5 MB[0m [31m147.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading accelerate-1.7.0-py3-none-any.whl (362 kB)
Downloading gradio-5.31.0-py3-none-any.whl (54.2 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.2/54.2 MB[0m [31m151.7 MB/s[0m eta [36m0:00:00[0m MB/s[0m eta [36m0:00:01[0m
[?25hDownloading gradio_client-1.10.1-py3-none-any.whl (323 kB)
Installing collected packages: gradio-client, transformers, gradio, accelerate
[2K  Attem

In [1]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration, TrainingArguments, Trainer
from datasets import Dataset, Audio
import torch
import json
import os

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


Using device: cpu


# Dataset creation

Using the guideline: https://huggingface.co/blog/fine-tune-whisper

In [2]:
audio_path = "../creolese-audio-dataset/finetune_eligible"
transcription_path = "../creolese-audio-dataset/finetune_eligible/transcripts.json"

# Load transcripts JSON
with open(transcription_path, 'r') as f:
    transcripts = json.load(f)

# Create a list of dicts pairing audio files and transcripts
data = []
for item in transcripts:
    audio_file = os.path.join(audio_path, item['audio'])
    if os.path.exists(audio_file):
        print(f"Found file: {audio_file}")
        data.append({'audio': audio_file, 'text': item['text']})
    else:
        print(f"Missing file: {audio_file}")

Found file: ../creolese-audio-dataset/finetune_eligible/1a.wav
Found file: ../creolese-audio-dataset/finetune_eligible/1b.wav
Found file: ../creolese-audio-dataset/finetune_eligible/1c.wav
Found file: ../creolese-audio-dataset/finetune_eligible/1d.wav
Found file: ../creolese-audio-dataset/finetune_eligible/1e.wav
Found file: ../creolese-audio-dataset/finetune_eligible/3a.wav
Found file: ../creolese-audio-dataset/finetune_eligible/3b.wav
Found file: ../creolese-audio-dataset/finetune_eligible/2.wav
Found file: ../creolese-audio-dataset/finetune_eligible/4a.wav
Found file: ../creolese-audio-dataset/finetune_eligible/4b.wav
Found file: ../creolese-audio-dataset/finetune_eligible/4c.wav
Found file: ../creolese-audio-dataset/finetune_eligible/4d.wav
Found file: ../creolese-audio-dataset/finetune_eligible/4e.wav
Found file: ../creolese-audio-dataset/finetune_eligible/5a.wav
Found file: ../creolese-audio-dataset/finetune_eligible/5b.wav
Found file: ../creolese-audio-dataset/finetune_eligible/

In [4]:
dataset = Dataset.from_list(data)

# Cast the audio column to automatically load audio
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
print(dataset)



Dataset({
    features: ['audio', 'text'],
    num_rows: 239
})


## Load the Model

In [5]:
from transformers import WhisperFeatureExtractor

model_id = "openai/whisper-large-v3"  
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-large-v3")
processor = WhisperProcessor.from_pretrained(model_id, task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained(model_id)

model.to(device)


WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(128, 1280, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 1280)
      (layers): ModuleList(
        (0-31): 32 x WhisperEncoderLayer(
          (self_attn): WhisperSdpaAttention(
            (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
            (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1280, out_features=5120, bias=True)
          (fc2): Linear(in_features=5120, out_features=1280, bia

In [6]:
import numpy as np

#explicitly pad the data
def pad_input_features(features, target_length=3000):
    padded = []
    for f in features:
        current_length = f.shape[-1]
        if current_length < target_length:
            pad_width = target_length - current_length
            f = np.pad(f, ((0, 0), (0, pad_width)), mode="constant", constant_values=0)
        padded.append(f)
    return padded

def batch_prepare_dataset(examples):
    audio_arrays = [audio["array"] for audio in examples["audio"]]
    sampling_rates = [audio["sampling_rate"] for audio in examples["audio"]]

    # Extract input features (mel spectrograms)
    input_features = processor.feature_extractor(
        audio_arrays,
        sampling_rate=sampling_rates[0],
        return_tensors="np"
    )["input_features"]

    # Pad each feature to length 3000
    input_features = pad_input_features(input_features)

    # Convert to torch tensors
    input_features = torch.tensor(np.array(input_features))

    # Process labels
    labels = processor.tokenizer(
        examples["text"], return_tensors="pt", padding=True
    ).input_ids

    return {
        "input_features": input_features,
        "labels": labels
    }


# Process in batches
prepared_dataset = dataset.map(
    batch_prepare_dataset,
    batched=True,
    batch_size=1,  # Adjust based on memory
    remove_columns=dataset.column_names,
    num_proc=1
)

Map:   0%|          | 0/239 [00:00<?, ? examples/s]

In [9]:
model.generation_config.task = "transcribe"
model.generation_config.language = None 
model.generation_config.forced_decoder_ids = None

#This is unnecessary

In [7]:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [ f["input_features"] for f in features]
        batch = self.processor.feature_extractor.pad({"input_features": input_features}, return_tensors="pt", padding=True)
    
        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt", padding=True)
    
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
    
        # Remove BOS if present
        if (labels[:, 0] == self.decoder_start_token_id).all():
            labels = labels[:, 1:]
    
        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id
)

In [8]:
# First, update your batch_prepare_dataset function


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Handle input features separately with proper padding
        input_features = [{"input_features": feature["input_features"].unsqueeze(0)} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Handle labels
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 for loss calculation
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Remove BOS token if present
        if (labels[:, 0] == self.decoder_start_token_id).all():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

In [19]:
import jiwer

transform = jiwer.Compose([
        jiwer.ToLowerCase(),
        jiwer.RemovePunctuation(),
        jiwer.Strip(),
        jiwer.RemoveMultipleSpaces(),
])
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = torch.argmax(torch.tensor(pred_logits), dim=-1)
    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = jiwer.wer(label_str, pred_str)
    mer = jiwer.mer(label_str, pred_str)
    cer = jiwer.cer(label_str, pred_str)
    return {"wer": wer, "mer": mer, "cer": cer}



In [20]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-large-v3-creolese-finetuned",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    max_steps=50,  
    gradient_checkpointing=True,
    fp16=torch.cuda.is_available(),
    do_eval= True,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=2,
    logging_steps=5,
    report_to=["tensorboard"],  # or ["none"]
    push_to_hub=False,
    eval_steps=5,            # Every 50 steps
    eval_strategy="steps",
)


In [21]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=prepared_dataset,
    eval_dataset=prepared_dataset,  # or add eval split if available
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor.feature_extractor
)


In [23]:
trainer.train(resume_from_checkpoint=True)
# trainer.train()


There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


Step,Training Loss,Validation Loss,Wer,Mer,Cer
5,2.5153,2.106143,0.999943,0.999943,0.999401
10,2.0832,1.820553,0.999943,0.999943,0.99936
15,1.6062,1.636009,0.999887,0.999887,0.999442
20,1.5324,1.517083,0.999943,0.999943,0.999513


KeyboardInterrupt: 