In [1]:
import torch
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    pipeline,
)
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from datasets import Dataset, DatasetDict, load_from_disk

import numpy as np
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import re
import jiwer

In [2]:
device = "cuda:0"
torch_dtype = torch.float32
model_id = "openai/whisper-large-v3"

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
)
model.to(device)

# Configure generation
model.generation_config.language = "english"
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None

In [3]:
myst = load_from_disk("/home/jovyan/active-projects/tla-asr-finetune/data/myst_dataset.ds")
myst

Loading dataset from disk:   0%|          | 0/25 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['audio', 'transcription', 'speaker_id', 'date', 'time', 'session_type', 'version', 'split'],
        num_rows: 76924
    })
    validation: Dataset({
        features: ['audio', 'transcription', 'speaker_id', 'date', 'time', 'session_type', 'version', 'split'],
        num_rows: 12238
    })
    test: Dataset({
        features: ['audio', 'transcription', 'speaker_id', 'date', 'time', 'session_type', 'version', 'split'],
        num_rows: 13169
    })
})

In [4]:
normalizer = BasicTextNormalizer()

def normalize_transcript(text):
    # The original transcript has annotations, for example a pause is <pau>
    # Remove tags in angle brackets
    text = re.sub(r'<[^>]*>', '', text)
    
    # These are "false starts" in the original transcript, for example th*
    # These are ignored by ASR
    # Remove words that end with asterisks (e.g., th*)
    text = re.sub(r'\S*\*', '', text)

    # Remove all punctuation
    normalized_text = normalizer(text)

    return normalized_text

In [10]:
def filter_short_long_samples(batch):
    audio_len = batch["audio"]["array"].shape[0] / batch["audio"]["sampling_rate"]
    return 5 <= audio_len <= 30

def filter_empty_transcripts(batch):
    return len(batch["transcription"].strip()) > 0

myst_filtered = myst.filter(filter_short_long_samples).filter(filter_empty_transcripts)

print(len(myst_filtered["train"]))

Filter:   0%|          | 0/36544 [00:00<?, ? examples/s]

Filter:   0%|          | 0/6013 [00:00<?, ? examples/s]

Filter:   0%|          | 0/6321 [00:00<?, ? examples/s]

36544


In [11]:
def prepare_dataset(batch):
    # Load audio
    audio = batch["audio"]
    sampling_rate = batch["audio"]["sampling_rate"]

    inputs = processor.feature_extractor(
        audio["array"],
        sampling_rate=sampling_rate,
        return_tensors="pt",
        padding="max_length",  # This ensures padding to max length in batch
        max_length=30 * sampling_rate,  # 30 seconds at 16kHz
        truncation=True,  # Truncate if longer than max_length
    )

    # Reprocess the filtered audio
    batch["input_features"] = inputs.input_features[0]

    # Normalize and encode target text
    normalized_text = normalize_transcript(batch["transcription"])
    batch["labels"] = processor.tokenizer(normalized_text, padding=True).input_ids

    return batch

myst_processed = myst_filtered.map(
    prepare_dataset,
    remove_columns=myst_filtered.column_names["train"],
    num_proc=1,
)

myst_processed["train"]


Map:   0%|          | 0/36544 [00:00<?, ? examples/s]

Map:   0%|          | 0/6013 [00:00<?, ? examples/s]

Map:   0%|          | 0/6321 [00:00<?, ? examples/s]

Dataset({
    features: ['input_features', 'labels'],
    num_rows: 36544
})

In [13]:
def weighted_wer(ref: list[str], pred: list[str]):
    # Normalize both predictions and references
    pred_normalized = [normalize_transcript(text) for text in pred]
    label_normalized = [normalize_transcript(text) for text in ref]
    
    total_errors = 0
    total_words = 0
    
    for pred_text, ref_text in zip(pred_normalized, label_normalized):
        ref_words = ref_text.split()
            
        # Compute WER for this sample
        if len(ref_words) > 0:
            sample_wer = jiwer.wer(ref_text, pred_text)
        else:
            sample_wer = 0
        
        # Accumulate weighted errors
        sample_errors = sample_wer * len(ref_words)
        total_errors += sample_errors
        total_words += len(ref_words)
    
    weighted_wer = total_errors / total_words if total_words > 0 else 0.0
    
    return {"wer": weighted_wer}
    

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    
    # Replace -100 with pad token
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    
    # Decode predictions and labels
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    return weighted_wer(label_str, pred_str)

In [14]:
myst

DatasetDict({
    train: Dataset({
        features: ['audio', 'transcription', 'speaker_id', 'date', 'time', 'session_type', 'version', 'split'],
        num_rows: 76924
    })
    validation: Dataset({
        features: ['audio', 'transcription', 'speaker_id', 'date', 'time', 'session_type', 'version', 'split'],
        num_rows: 12238
    })
    test: Dataset({
        features: ['audio', 'transcription', 'speaker_id', 'date', 'time', 'session_type', 'version', 'split'],
        num_rows: 13169
    })
})

In [16]:
def get_baseline(ds):
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        batch_size=16,
        device=device,
        chunk_length_s=30,
    )
    results = pipe(ds["test"]["audio"])
    return results

def get_wer(y_true, preds):
    y_pred = [d["text"] for d in preds]
    wer_score = weighted_wer(y_true, y_pred)
    return(wer_score)

preds = get_baseline(myst)
get_wer(myst["test"]["transcription"], preds)
# 0.19560642190320807

{'wer': 0.19560642190320807}

In [17]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

In [18]:
training_args = Seq2SeqTrainingArguments(
    output_dir="../bin/whisper-myst",
    per_device_train_batch_size=8, # About 30 GiB of VRAM with 2x gradient accumulation
    gradient_accumulation_steps=2,
    learning_rate=1e-6,
    warmup_steps=100,
    max_steps=1_000,
    gradient_checkpointing=True,
    fp16=True,
    eval_strategy="steps",
    per_device_eval_batch_size=16, # About 30GiB of VRAM
    predict_with_generate=True,
    save_steps=200,
    eval_steps=200,
    logging_steps=50,
    report_to=[],  # Disable logging for simplicity
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=myst_processed["train"],
    eval_dataset=myst_processed["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor.feature_extractor,
)

In [19]:
# Start training
print("Starting training...")
trainer.train()

# 0.136113 with learning_rate 1e-6 on MyST

Starting training...


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


Step,Training Loss,Validation Loss,Wer
200,0.5871,0.549029,0.178696
400,0.3455,0.414291,0.150277
600,0.2982,0.401005,0.137361
800,0.2975,0.3952,0.137282
1000,0.299,0.391985,0.136113


There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


TrainOutput(global_step=1000, training_loss=0.47948250007629395, metrics={'train_runtime': 36098.4593, 'train_samples_per_second': 0.443, 'train_steps_per_second': 0.028, 'total_flos': 5.435997290496e+19, 'train_loss': 0.47948250007629395, 'epoch': 0.43782837127845886})

In [21]:
trainer.save_model("../bin/myst1")

In [22]:
# Evaluate final model
print("Evaluating final model...")
final_results = trainer.evaluate()
print(f"Final WER: {final_results['eval_wer']:.4f}")

Evaluating final model...


Final WER: 0.1361


In [None]:
# Test on a few samples
print("\nTesting on sample predictions:")
test_samples = cslu_processed["test"].select(range(3))
predictions = trainer.predict(test_samples)

pred_ids = predictions.predictions
pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

for i, (pred, label) in enumerate(zip(pred_str, test_samples["labels"])):
    label_str = processor.tokenizer.decode(label, skip_special_tokens=True)
    print(f"\nSample {i+1}:")
    print(f"Prediction: {normalize_transcript(pred)}")
    print(f"Reference:  {normalize_transcript(label_str)}")