In [1]:
import torch
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
import evaluate
from datasets import Dataset, DatasetDict, load_from_disk

import numpy as np
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import re
import string

from text2digits import text2digits

In [2]:
# Run configuration: checkpoint name, weight precision, and target device.
model_id = "openai/whisper-large-v3"
torch_dtype = torch.float32
device = "cuda:0"

# The processor exposes the feature extractor and tokenizer used below.
processor = AutoProcessor.from_pretrained(model_id)

# Load the weights and move the model onto the target device in one expression.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    torch_dtype=torch_dtype,
).to(device)

# Generation defaults: transcribe English and clear any forced decoder ids.
model.generation_config.forced_decoder_ids = None
model.generation_config.task = "transcribe"
model.generation_config.language = "english"

In [3]:
# Load the segmented dataset from disk and hold out 20% of it as a test split.
# The fixed seed keeps the split identical across kernel restarts.
cslu = (
    load_from_disk("../data/cslu_kids_segmented.ds")
    .train_test_split(seed=42, test_size=0.2)
)

In [4]:
# Number-word converter reused by every normalize_transcript call.
t2d = text2digits.Text2Digits()

# Translation table mapping each character of string.punctuation to None,
# so that str.translate deletes it.
punctuation_remover = str.maketrans(dict.fromkeys(string.punctuation))

def normalize_transcript(text: str) -> str:
    """Normalize a transcript so gold and predicted text can be compared for WER.

    The steps are applied in this order:
      1. drop annotation tags in angle brackets (e.g. ``<pau>`` for a pause);
      2. drop "false start" tokens that end in an asterisk (e.g. ``th*``),
         which the ASR model does not emit;
      3. strip punctuation;
      4. collapse runs of whitespace into a single space;
      5. convert number words to digits (e.g. "thirteen" -> "13");
      6. trim and lowercase.

    Args:
        text: Raw transcript or model output.

    Returns:
        The normalized, lowercased string.
    """
    # Remove tags in angle brackets
    text = re.sub(r'<[^>]*>', '', text)

    # Remove words that end with asterisks (e.g., th*)
    text = re.sub(r'\S*\*', '', text)

    # Remove all punctuation
    text = text.translate(punctuation_remover)

    # Clean up excess spaces in the original transcript or left by the steps above
    text = re.sub(r'\s+', ' ', text)

    # Number conversion is imperfect (e.g., it cannot tell "one" the pronoun from
    # "one" the number), but the same normalization is applied to both predicted
    # and gold text, so it works fine for Word Error Rate.
    try:
        normalized_text = t2d.convert(text)
    except Exception:
        # Best effort: keep the unconverted text if the converter fails.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate, so a long map/eval run can still be interrupted.
        normalized_text = text

    return normalized_text.strip().lower()

In [5]:
def prepare_dataset(batch):
    """Convert one raw example into model inputs and label token ids.

    Reads ``batch["audio"]`` (its ``array`` and ``sampling_rate``) and
    ``batch["text"]``; writes ``batch["input_features"]`` and ``batch["labels"]``
    and returns the same ``batch`` object.
    """
    clip = batch["audio"]
    rate = clip["sampling_rate"]

    # Pad or truncate every clip to 30 seconds' worth of samples at its own
    # sampling rate, so all examples yield features of the same length.
    features = processor.feature_extractor(
        clip["array"],
        sampling_rate=rate,
        max_length=30 * rate,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # The extractor returns a batch of one; keep only that single item.
    batch["input_features"] = features.input_features[0]

    # Labels are the token ids of the normalized transcript.
    target_text = normalize_transcript(batch["text"])
    encoded = processor.tokenizer(target_text, padding=True)
    batch["labels"] = encoded.input_ids

    return batch

# Featurize both splits in a single process. The raw columns are dropped so
# only what prepare_dataset adds ("input_features" and "labels") remains.
cslu_processed = cslu.map(
    prepare_dataset,
    num_proc=1,
    remove_columns=cslu["train"].column_names,
)

# Display the resulting splits and row counts.
cslu_processed

DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 8749
    })
    test: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 2188
    })
})

In [6]:
# "wer" metric loaded through `evaluate`; compute_metrics and get_baseline
# both call wer.compute(predictions=..., references=...) on normalized text.
wer = evaluate.load("wer")

def compute_metrics(pred):
    """Compute Word Error Rate for a batch of generated predictions.

    Args:
        pred: Object exposing ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids, with -100 marking ignored
            positions).

    Returns:
        ``{"wer": <score>}`` computed on normalized predictions and references.
    """
    pad_token_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    # Some models return a tuple of outputs; the token ids come first.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 is not a valid token id and cannot be decoded. Swap it for the pad
    # token in both arrays. np.where builds new arrays, so the caller's
    # pred.label_ids / pred.predictions are left unmodified.
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_token_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_token_id)

    # Decode predictions and labels
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # Normalize both sides identically so WER reflects word differences only
    pred_normalized = [normalize_transcript(text) for text in pred_str]
    label_normalized = [normalize_transcript(text) for text in label_str]

    # Compute WER
    wer_score = wer.compute(predictions=pred_normalized, references=label_normalized)

    return {"wer": wer_score}

In [7]:
def get_baseline():
    """Measure the WER of the current ``model`` on the raw test split.

    Runs an automatic-speech-recognition pipeline over ``cslu["test"]["audio"]``
    in batches of 16, normalizes predictions and reference transcripts with
    ``normalize_transcript``, then prints and returns the score.

    Returns:
        The score produced by ``wer.compute`` (previously this was only printed,
        so callers could not reuse it).
    """
    from transformers import pipeline

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        batch_size=16,
        device=device,
    )
    results = pipe(cslu["test"]["audio"])

    # Apply the same normalization to both sides before scoring.
    predictions = [normalize_transcript(d["text"]) for d in results]
    references = [normalize_transcript(text) for text in cslu["test"]["text"]]

    wer_score = wer.compute(predictions=predictions, references=references)
    print(wer_score)
    return wer_score

# get_baseline()
# Recorded output of an earlier get_baseline() run (WER): 0.7680394631180452

In [8]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate examples into a batch, padding audio features and labels separately.

    ``processor`` must expose ``feature_extractor.pad`` and ``tokenizer.pad``.
    ``decoder_start_token_id`` is the token that is stripped from the front of
    the labels when every sequence in the batch begins with it.
    """

    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Inputs and labels have different lengths and need different padding,
        # so collect them into two separate lists.
        audio_items = []
        label_items = []
        for example in features:
            audio_items.append({"input_features": example["input_features"]})
            label_items.append({"input_ids": example["labels"]})

        # Audio features only need to be turned into a torch tensor batch.
        batch = self.processor.feature_extractor.pad(audio_items, return_tensors="pt")

        # Token ids are padded to the longest label sequence in the batch.
        padded_labels = self.processor.tokenizer.pad(label_items, return_tensors="pt")

        # Padding positions become -100 so the loss ignores them.
        is_padding = padded_labels.attention_mask != 1
        labels = padded_labels["input_ids"].masked_fill(is_padding, -100)

        # If tokenization already prepended the start token to every sequence,
        # drop it here because it is appended again later anyway.
        starts_with_bos = bool(torch.all(labels[:, 0] == self.decoder_start_token_id))
        if starts_with_bos:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# Collator instance handed to the trainer. The start token id comes from the
# model config so the collator can strip it from the front of the labels.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    decoder_start_token_id=model.config.decoder_start_token_id,
    processor=processor,
)

In [9]:
training_args = Seq2SeqTrainingArguments(
    # Where checkpoints are written; nothing is pushed or reported externally.
    output_dir="../bin/whisper-cslu-kids",
    push_to_hub=False,
    report_to=[],  # Disable logging for simplicity
    # Optimization: 8 examples per device x 2 accumulation steps per update.
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=5e-6,
    warmup_steps=100,
    max_steps=500,  # Fewer steps for proof of concept
    # Memory / precision settings.
    gradient_checkpointing=True,
    fp16=True,
    # Evaluation by generation, every 250 steps.
    eval_strategy="steps",
    eval_steps=250,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    # Checkpointing and logging cadence (save interval matches eval interval).
    save_steps=250,
    logging_steps=50,
    # Keep the checkpoint with the lowest WER.
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)

# Wire model, data, collation and metric into the trainer. Evaluation runs on
# the processed test split.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=cslu_processed["train"],
    eval_dataset=cslu_processed["test"],
    processing_class=processor.feature_extractor,
)

In [None]:
# Start training. trainer.train() is the cell's last expression, so its
# return value is what the notebook displays once training finishes.
print("Starting training...")
trainer.train()

# Note from an earlier run (presumably WER — confirm): 0.73 with learning_rate=1e-5

Starting training...


Step,Training Loss,Validation Loss,Wer
250,1.5592,1.542537,0.747448


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [None]:
# Evaluate the final model on the trainer's eval_dataset (the processed test
# split) and report the "eval_wer" entry of the returned metrics.
print("Evaluating final model...")
final_results = trainer.evaluate()
print(f"Final WER: {final_results['eval_wer']:.4f}")

In [None]:
# Test on a few samples: run prediction on the first three processed test
# examples and print each prediction next to its reference.
print("\nTesting on sample predictions:")
test_samples = cslu_processed["test"].select(range(3))
predictions = trainer.predict(test_samples)

# Decode the predicted token ids to text, dropping special tokens.
pred_ids = predictions.predictions
pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

# Each reference is decoded from its stored label ids; both sides go through
# normalize_transcript before printing so they are directly comparable.
for i, (pred, label) in enumerate(zip(pred_str, test_samples["labels"])):
    label_str = processor.tokenizer.decode(label, skip_special_tokens=True)
    print(f"\nSample {i+1}:")
    print(f"Prediction: {normalize_transcript(pred)}")
    print(f"Reference:  {normalize_transcript(label_str)}")