In [1]:
import torch
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
import evaluate
from datasets import Dataset, DatasetDict, load_from_disk

import numpy as np
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import re
import string

from text2digits import text2digits

In [2]:
device = "cuda:0"
torch_dtype = torch.float32

model_id = "openai/whisper-large-v3"

# Language/task used both for tokenizing the training labels and for generation.
# They must agree: the tokenizer writes these as prefix tokens into every label,
# and generation asks the model to produce the same prefix at inference time.
LANGUAGE = "english"
TASK = "transcribe"

processor = AutoProcessor.from_pretrained(model_id, language=LANGUAGE, task=TASK)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
)
model.to(device)

# Configure generation to match the label prefix the tokenizer produces.
model.generation_config.language = LANGUAGE
model.generation_config.task = TASK
# Let language/task above drive the decoder prompt instead of legacy forced ids.
model.generation_config.forced_decoder_ids = None

In [3]:
# Hold out 20% of the saved CSLU kids' corpus for evaluation; the fixed seed
# keeps the split identical across kernel restarts.
cslu = load_from_disk("../data/cslu_kids.ds").train_test_split(
    test_size=0.2,
    seed=42,
)

In [4]:
# Show the train/test DatasetDict: split sizes and columns.
cslu

DatasetDict({
    train: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 880
    })
    test: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 221
    })
})

In [38]:
# Peek at a few raw training transcripts to see the annotation markup
# (<pau>, <ln>, <br>, ...) that normalize_transcript has to strip.
for sentence in cslu["train"].select(range(10))["sentence"]:
    print(sentence)

b c d e <pau> h i l m n o p<ln> q r s<ln> t<ln> u v w x<ln> y z <br> i went to my brother's friend and then we went to<ln> burger king and then we went to the park<ln> and we uhm came to a house uhm i went down the<ln> slide mom like <bs> my mom's name's caitlin <br> and my dad's name's miguel and my <br> big brother's<ln> name's freddie<ln> and my little brother's eric<ln> uhm my little brother's four years old and my big brother's<ln> sixteen years old and <br> i uhm i don't fight with them i always play<ln> with them and share<ln> stuff with them cada<nitl> hand and i mean <br> cada<nitl> face hay<nitl> un<nitl> neve<nitl> se<nitl> perrito<nitl> que<nitl> se<nitl> hay<nitl> muerto<nitl> there was there was a dog that <br> uhm he died <laugh>
<br> a b c d e f g h i j k<ln> l m n o p<ln> q r s t u <br> v w x y z <br> okay <br> uhm i have a room <br> that i share with my brother <br> and both of us have two of he exactly same stereos but mine is a little bit taller<ln> <br> and then in

In [35]:
t2d = text2digits.Text2Digits()
punctuation_remover = str.maketrans('', '', string.punctuation)


def normalize_transcript(text):
    """Normalize a transcript or model prediction so WER compares like with like.

    Steps: strip angle-bracket annotation tags (e.g. ``<pau>``, ``<ln>``), drop
    false-start fragments ending in ``*`` (e.g. ``th*``), remove ASCII
    punctuation, collapse whitespace, lowercase, and convert spelled-out
    numbers to digits with ``text2digits`` (e.g. "thirteen" -> "13").

    Args:
        text: Raw transcript or decoded model output.

    Returns:
        The normalized, lowercase, whitespace-trimmed string. If number
        conversion fails, the text is returned without that step instead of
        raising.
    """
    # The original transcript has annotations, for example a pause is <pau>.
    text = re.sub(r'<[^>]*>', '', text)

    # "False starts" such as th* are ignored by ASR, so drop them.
    # This must run before punctuation removal, which would strip the '*'.
    text = re.sub(r'\S*\*', '', text)

    text = text.translate(punctuation_remover)

    # Collapse whitespace left over from the source or from the removals above.
    # Lowercase before number conversion so "Thirteen" from a prediction is
    # treated the same as "thirteen" from a reference.
    text = re.sub(r'\s+', ' ', text).strip().lower()

    # Imperfect (cannot tell pronoun "one" from the number), but the same
    # normalization is applied to predictions and references, so WER stays fair.
    try:
        text = t2d.convert(text)
    except Exception as exc:
        # Best effort: keep the unconverted text rather than failing the whole
        # dataset map / evaluation on one odd transcript.
        print(f"text2digits failed ({exc!r}); keeping unconverted text: {text!r}")
    return text.strip().lower()

# Whisper consumes 30 s windows of 16 kHz audio, and its decoder has 448
# positions; examples beyond either limit are filtered out below.
MAX_AUDIO_SAMPLES = 30 * 16000
MAX_LABEL_TOKENS = 448


def prepare_dataset(batch):
    """Turn one raw example into Whisper model inputs.

    Args:
        batch: A single dataset row with ``audio`` (a dict holding ``array``
            and ``sampling_rate``) and ``sentence`` (the raw transcript).

    Returns:
        The same row with ``input_features`` (log-Mel features) and ``labels``
        (token ids of the normalized transcript) added. Rows are never dropped
        here: a non-batched ``Dataset.map`` has no way to drop a row by
        returning None, so length limits are enforced by the ``filter_*``
        predicates instead.
    """
    audio = batch["audio"]

    # Keep the feature extractor's default padding ("max_length"). The Whisper
    # encoder expects features for a full 30 s window; padding only to the
    # clip's own length produces inputs of the wrong size.
    batch["input_features"] = processor.feature_extractor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
    ).input_features[0]

    normalized_text = normalize_transcript(batch["sentence"])
    batch["labels"] = processor.tokenizer(normalized_text).input_ids
    return batch


def filter_long_samples(batch):
    """Return True if the example's audio fits in 30 s at 16 kHz."""
    return len(batch["audio"]["array"]) <= MAX_AUDIO_SAMPLES


def filter_long_labels(batch):
    """Return True if the tokenized transcript fits in MAX_LABEL_TOKENS ids."""
    return len(batch["labels"]) <= MAX_LABEL_TOKENS


# NOTE(review): on the last recorded run the 30 s limit kept only 6/880 train
# and 3/221 test rows; consider chunking long recordings instead of dropping them.
cslu_processed = (
    cslu.filter(filter_long_samples)
    .map(
        prepare_dataset,
        remove_columns=cslu.column_names["train"],
        num_proc=1,
    )
    .filter(filter_long_labels)
)

Map:   0%|          | 0/6 [00:00<?, ? examples/s]

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

In [36]:
# Check how many rows survived the length filtering and that only the model
# inputs (input_features, labels) remain after remove_columns.
cslu_processed

DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 6
    })
    test: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 3
    })
})

In [31]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Batch Whisper input features and labels for ``Seq2SeqTrainer``.

    Attributes:
        processor: Processor whose ``feature_extractor.pad`` batches the audio
            features and whose ``tokenizer.pad`` batches the label ids.
        decoder_start_token_id: Token the model prepends itself when it shifts
            ``labels`` right to build decoder inputs; a leading copy already
            present in the labels is stripped.
    """

    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad a list of examples into one training batch.

        Args:
            features: Examples, each with ``input_features`` and ``labels``
                (token ids).

        Returns:
            The padded feature batch plus a ``labels`` tensor in which padding
            positions are ``-100`` so the loss ignores them.
        """
        # Inputs and labels need different padding, so pad them separately.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 so the loss ignores those positions.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # If the tokenizer already put the decoder start token at the front of
        # every label sequence, drop it. The model prepends
        # decoder_start_token_id when shifting labels right (the usual HF
        # seq2seq behaviour when only labels are passed), and keeping both
        # would train on a doubled start token.
        if labels.size(1) > 0 and bool((labels[:, 0] == self.decoder_start_token_id).all()):
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# Pair the processor with the model's own decoder start token id.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor,
    model.config.decoder_start_token_id,
)

In [32]:
wer = evaluate.load("wer")


def compute_metrics(pred):
    """Compute word error rate for a ``Seq2SeqTrainer`` evaluation pass.

    Args:
        pred: Evaluation output whose ``predictions`` are generated token ids
            (``predict_with_generate=True``) and whose ``label_ids`` use
            ``-100`` for ignored positions.

    Returns:
        ``{"wer": float}`` computed on ``normalize_transcript``-normalized text,
        over examples with a non-empty normalized reference (NaN if none).
    """
    pad_token_id = processor.tokenizer.pad_token_id

    # -100 is not a valid token id. It marks ignored label positions, and the
    # Trainer may also use it to pad predictions when concatenating batches, so
    # map it to the pad token in both arrays. np.where returns new arrays, so the
    # caller's prediction object is not mutated.
    pred_ids = np.where(pred.predictions != -100, pred.predictions, pad_token_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_token_id)

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # Apply the same normalization to both sides so WER is comparable.
    pred_normalized = [normalize_transcript(text) for text in pred_str]
    label_normalized = [normalize_transcript(text) for text in label_str]

    # A reference can normalize to "" (e.g. a transcript of only annotation
    # tags). The WER backend (jiwer) rejects empty references, so score only
    # pairs whose reference has words.
    scored = [(hyp, ref) for hyp, ref in zip(pred_normalized, label_normalized) if ref]
    if not scored:
        return {"wer": float("nan")}
    hypotheses, references = zip(*scored)

    wer_score = wer.compute(predictions=list(hypotheses), references=list(references))
    return {"wer": wer_score}

In [33]:
training_args = Seq2SeqTrainingArguments(
    output_dir="../bin/whisper-cslu-kids",
    # Optimisation: effective train batch is 8 x 2 accumulation steps per device.
    per_device_train_batch_size=8,  # Smaller batch size for kids dataset
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    warmup_steps=100,
    max_steps=1000,  # Fewer steps for proof of concept
    gradient_checkpointing=True,
    fp16=True,
    # Evaluation: generate transcripts so compute_metrics can score WER.
    eval_strategy="steps",
    eval_steps=250,
    per_device_eval_batch_size=4,
    predict_with_generate=True,
    generation_max_length=225,
    # Checkpointing: save on the eval cadence and reload the lowest-WER
    # checkpoint when training ends.
    save_steps=250,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    # Logging / publishing.
    logging_steps=50,
    report_to=[],  # No external experiment trackers
    push_to_hub=False,  # Set to True to upload
)

# 7. Create trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=cslu_processed["train"],
    eval_dataset=cslu_processed["test"],
    processing_class=processor.feature_extractor,
    compute_metrics=compute_metrics,
)

In [34]:
# 8. Start training
# Run the schedule configured in training_args; the value returned by
# trainer.train() is displayed as this cell's result.
print("Starting training...")
trainer.train()

Starting training...


Step,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# 9. Evaluate final model
# Score the trainer's current model on the held-out split; the Trainer prefixes
# metric names with "eval_", so the WER comes back as "eval_wer".
print("Evaluating final model...")
final_results = trainer.evaluate()
final_wer = final_results["eval_wer"]
print(f"Final WER: {final_wer:.4f}")

In [None]:
# 10. Test on a few samples
print("\nTesting on sample predictions:")
# After length filtering the test split can be tiny; never select past its end.
num_preview = min(3, len(cslu_processed["test"]))
test_samples = cslu_processed["test"].select(range(num_preview))
predictions = trainer.predict(test_samples)

# -100 padding in predictions is not a decodable token id; map it to the pad
# token, which skip_special_tokens then drops.
pad_token_id = processor.tokenizer.pad_token_id
pred_ids = np.where(predictions.predictions != -100, predictions.predictions, pad_token_id)
pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

# Show normalized prediction and reference side by side.
for i, (pred, label) in enumerate(zip(pred_str, test_samples["labels"])):
    label_str = processor.tokenizer.decode(label, skip_special_tokens=True)
    print(f"\nSample {i+1}:")
    print(f"Prediction: {normalize_transcript(pred)}")
    print(f"Reference:  {normalize_transcript(label_str)}")