Automatic speech recognition (ASR) converts a speech signal to text, mapping a sequence of audio inputs to text outputs. Virtual assistants like Siri and Alexa use ASR models to help users every day, and there are many other useful user-facing applications like live captioning and note-taking during meetings.

This guide shows how to:
1. Finetune Wav2Vec2 on the MInDS-14 dataset to transcribe audio to text.
2. Use your finetuned model for inference.

# Libraries

In [1]:
pip install transformers datasets evaluate jiwer

In [2]:
import torch
import evaluate
import numpy as np
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from datasets import load_dataset, Audio
from transformers import AutoProcessor, AutoModelForCTC, TrainingArguments, Trainer, pipeline

# Pick the best available accelerator instead of assuming Apple-silicon MPS:
# MPS first (keeps the original behaviour on Macs), then CUDA, then CPU.
# The name `mps_device` is kept because a later cell moves the model to it.
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
elif torch.cuda.is_available():
    mps_device = torch.device("cuda")
else:
    mps_device = torch.device("cpu")

# Load Data

In [3]:
# Load only the first 100 en-US training examples of MInDS-14
# (to experiment on a small dataset first)
minds = load_dataset("PolyAI/minds14", name="en-US", split="train[:100]")

# Train-test split: hold out 20% of the examples as the "test" split.
# NOTE(review): no seed is passed, so the split is presumably not reproducible across runs.
minds = minds.train_test_split(test_size=0.2)

In [4]:
# Inspect dataset detail
# NB: focusing on the audio and transcription
# audio: a 1-dimensional array of the speech signal that must be called to load and resample the audio file.
# transcription: the target text.
minds

In [5]:
# Inspect an example
minds["train"][0]

# Preprocessing

In [6]:
# Load a Wav2Vec2 processor to process the audio signal
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")

In [7]:
# MInDS-14 dataset has a sampling rate of 8,000 Hz (8 kHz) (you can find this information in its dataset card)
# You’ll need to resample the dataset to 16,000 Hz (16 kHz) to use the pretrained Wav2Vec2 model
minds = minds.cast_column("audio", Audio(sampling_rate=16_000))
minds["train"][0]

In [8]:
# The transcription text contains a mix of upper and lowercase characters
# The Wav2Vec2 tokenizer is only trained on uppercase characters
# ...make sure the text matches the tokenizer’s vocabulary
def uppercase(example):
    """Return a column update holding the example's transcription in upper case."""
    text = example["transcription"]
    shouted = text.upper()
    return {"transcription": shouted}

minds = minds.map(uppercase)

In [9]:
def prepare_dataset(batch):
    """Encode one example with the module-level ``processor``.

    The example's audio array (with its sampling rate) and its transcription
    are passed to the processor in a single call. The length of the first
    ``input_values`` entry is then stored under ``input_length`` and the
    processor output is returned.
    """
    # Accessing the audio column yields the array and its sampling rate
    speech = batch["audio"]

    # One processor call handles both the audio and the transcription text
    encoded = processor(speech["array"], sampling_rate=speech["sampling_rate"], text=batch["transcription"])

    # Record how long the processed audio input is
    first_input = encoded["input_values"][0]
    encoded["input_length"] = len(first_input)
    return encoded

encoded_minds = minds.map(prepare_dataset, remove_columns=minds.column_names["train"], num_proc=4)

In [10]:
# Transformers doesn’t have a data collator for speech recognition
# Adapt the DataCollatorWithPadding to create a batch of examples
# Also perform dynamic padding which is more efficient than setting padding=True
@dataclass
class DataCollatorCTCWithPadding:
    """Collate examples into one padded batch of inputs and labels.

    Inputs and labels are padded in two separate ``processor.pad`` calls, and
    label positions whose attention mask is not 1 are replaced with -100.
    """

    # Processor whose ``pad`` method is used for both the inputs and the labels
    processor: AutoProcessor
    # Padding strategy forwarded unchanged to ``processor.pad``
    padding: Union[bool, str] = "longest"

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Inputs and labels have different lengths and need different padding,
        # so they are gathered into two separate lists
        audio_inputs = []
        for example in features:
            audio_inputs.append({"input_values": example["input_values"][0]})

        label_inputs = []
        for example in features:
            label_inputs.append({"input_ids": example["labels"]})

        padded = self.processor.pad(audio_inputs, padding=self.padding, return_tensors="pt")
        padded_labels = self.processor.pad(labels=label_inputs, padding=self.padding, return_tensors="pt")

        # Padded label positions become -100 so that they are ignored by the loss
        label_ids = padded_labels["input_ids"]
        is_padding = padded_labels.attention_mask.ne(1)
        padded["labels"] = label_ids.masked_fill(is_padding, -100)

        return padded

In [11]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding="longest")

# Evaluation

In [12]:
# For speech recognition, load the word error rate (WER) metric
wer = evaluate.load("wer")

In [13]:
# Function that passes predictions + labels to compute to calculate the WER
# Called during training
def compute_metrics(pred):
    """Compute the word error rate (WER) for a set of predictions.

    Args:
        pred: Prediction object exposing ``predictions`` (the logits) and
            ``label_ids``; label positions equal to -100 are treated as
            padding.

    Returns:
        A dict with a single ``"wer"`` entry holding the value returned by
        the module-level ``wer`` metric.
    """
    # Take the highest-scoring token id at every position
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # -100 marks padded label positions; swap it for the pad token id so the
    # labels can be decoded. np.where builds a new array, so pred.label_ids
    # is not modified in place.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids)
    # NOTE(review): group_tokens=False presumably keeps repeated characters
    # in the reference text from being collapsed - confirm against the
    # tokenizer documentation.
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    # The result must not be bound to the name `wer`: assigning to it would
    # make `wer` local to this function and `wer.compute` would raise
    # UnboundLocalError before the metric is ever called.
    wer_score = wer.compute(predictions=pred_str, references=label_str)

    return {"wer": wer_score}

# Training

In [14]:
# Load Wav2Vec2 with AutoModelForCTC
# Specify the reduction to apply with the ctc_loss_reduction parameter...
# NB: it is often better to use the mean instead of the default summation
model = AutoModelForCTC.from_pretrained(
    "facebook/wav2vec2-base",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
)

model.to(mps_device)

In [15]:
# Define your training hyperparameters in TrainingArguments.
# The only required parameter is output_dir which specifies where to save your model.
training_args = TrainingArguments(
    output_dir="speech_recognition_model",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=2000,
    gradient_checkpointing=True,
    # Enable fp16 mixed precision only when a CUDA GPU is present; requesting
    # it unconditionally can be rejected on other back ends (CPU, Apple MPS).
    fp16=torch.cuda.is_available(),
    group_by_length=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    # Saving and evaluating happen on the same 1000-step schedule
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    # The best checkpoint is selected by WER, where lower is better
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)

# Pass training arguments to Trainer
# along with the model, dataset, tokenizer, data collator, and compute_metrics function.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_minds["train"],
    eval_dataset=encoded_minds["test"],
    tokenizer=processor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Fine-tune model
trainer.train()

# Write the final model to the root of output_dir. The inference cells load
# "speech_recognition_model" directly, and training alone is expected to
# leave only checkpoint sub-folders there. The processor passed as
# `tokenizer` should be saved alongside the model - confirm for the
# installed transformers version.
trainer.save_model()

# Inference

In [None]:
# Load audio file you’d like to run inference on
# Resample the sampling rate of the audio file to match the sampling rate of the model if required
dataset = load_dataset("PolyAI/minds14", "en-US", split="train")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
sampling_rate = dataset.features["audio"].sampling_rate
audio_file = dataset[0]["audio"]["path"]

In [None]:
# Inference using a pipeline
transcriber = pipeline("automatic-speech-recognition", model="speech_recognition_model")
transcriber(audio_file)

In [None]:
# Inference using PyTorch

# Load a processor to preprocess the audio file and transcription and return the input as PyTorch tensors
processor = AutoProcessor.from_pretrained("speech_recognition_model")
inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")

# Pass your inputs to the model and return the logits
model = AutoModelForCTC.from_pretrained("speech_recognition_model")
with torch.no_grad():
    logits = model(**inputs).logits
    
# Get the predicted input_ids with the highest probability
# use the processor to decode the predicted input_ids back into text
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
transcription