In [None]:
!pip install -qU transformers datasets evaluate jiwer

# Automatic Speech Recognition

**Automatic speech recognition (ASR)** converts a speech signal to text, mapping a sequence of audio inputs to text outputs. ASR can be used for live captioning, note-taking during meetings, and more.

## Load MINDS-14 dataset

In [1]:
from datasets import load_dataset, Audio

# Load only the first 100 examples of the en-US configuration's train split
minds = load_dataset(
    'PolyAI/minds14',
    name='en-US',
    split='train[:100]'
)

MInDS-14.zip:   0%|          | 0.00/471M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [2]:
# split into train and test sets (80/20); the fixed seed makes the split,
# and therefore every downstream result, reproducible across re-runs
minds = minds.train_test_split(test_size=0.2, seed=42)

In [3]:
minds

DatasetDict({
    train: Dataset({
        features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],
        num_rows: 80
    })
    test: Dataset({
        features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],
        num_rows: 20
    })
})

For the ASR task, we focus on the `audio` and `transcription` columns, so we can safely remove the other columns:

In [4]:
# Drop the label/metadata columns that are not used below;
# `path`, `audio` and `transcription` remain
minds = minds.remove_columns(['english_transcription', 'intent_class', 'lang_id'])

In [5]:
minds['train'][0]

{'path': '/root/.cache/huggingface/datasets/downloads/extracted/f9018fd3747971e77d59e6c5da3fdf9d5bb914c495e16c23e1fe47c921d76a7a/en-US~JOINT_ACCOUNT/602bae8ebb1e6d0fbce9226f.wav',
 'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/f9018fd3747971e77d59e6c5da3fdf9d5bb914c495e16c23e1fe47c921d76a7a/en-US~JOINT_ACCOUNT/602bae8ebb1e6d0fbce9226f.wav',
  'array': array([ 0.        , -0.00024414,  0.        , ...,  0.00024414,
         -0.00024414,  0.        ]),
  'sampling_rate': 8000},
 'transcription': 'how do I set up a joint account'}

* `audio` holds the speech signal as a 1D `array`, together with its `path` and `sampling_rate`; accessing this column is what loads (and, if necessary, resamples) the audio file
* `transcription` is the target text

## Preprocess

We will load a Wav2Vec2 processor to process the audio signal:

In [None]:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained('facebook/wav2vec2-base')

From the dataset card, the MINDS-14 dataset has a sampling rate of 8kHz, so we need to resample the dataset to 16kHz to use the pretrained Wav2Vec2 model:

In [7]:
# Re-declare the audio feature at 16 kHz, then display the first training example
minds = minds.cast_column('audio', Audio(sampling_rate=16_000))
minds['train'][0]

{'path': '/root/.cache/huggingface/datasets/downloads/extracted/f9018fd3747971e77d59e6c5da3fdf9d5bb914c495e16c23e1fe47c921d76a7a/en-US~JOINT_ACCOUNT/602bae8ebb1e6d0fbce9226f.wav',
 'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/f9018fd3747971e77d59e6c5da3fdf9d5bb914c495e16c23e1fe47c921d76a7a/en-US~JOINT_ACCOUNT/602bae8ebb1e6d0fbce9226f.wav',
  'array': array([-3.15820362e-05, -1.72566681e-04, -2.11391991e-04, ...,
         -1.99425383e-04, -3.09243887e-06,  7.71999621e-05]),
  'sampling_rate': 16000},
 'transcription': 'how do I set up a joint account'}

Since the `transcription` contains a mix of uppercase and lowercase characters and the Wav2Vec2 tokenizer is only trained on uppercase characters, we need to do some processing work:

In [8]:
def uppercase(example):
    """Upper-case the transcription of a single example.

    Args:
        example: mapping with a string under the ``'transcription'`` key.

    Returns:
        A new dict containing only the ``'transcription'`` key, holding the
        upper-cased text.
    """
    text = example['transcription']
    return {'transcription': text.upper()}

minds = minds.map(uppercase)
minds['train'][0]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

{'path': '/root/.cache/huggingface/datasets/downloads/extracted/f9018fd3747971e77d59e6c5da3fdf9d5bb914c495e16c23e1fe47c921d76a7a/en-US~JOINT_ACCOUNT/602bae8ebb1e6d0fbce9226f.wav',
 'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/f9018fd3747971e77d59e6c5da3fdf9d5bb914c495e16c23e1fe47c921d76a7a/en-US~JOINT_ACCOUNT/602bae8ebb1e6d0fbce9226f.wav',
  'array': array([-3.15820362e-05, -1.72566681e-04, -2.11391991e-04, ...,
         -1.99425383e-04, -3.09243887e-06,  7.71999621e-05]),
  'sampling_rate': 16000},
 'transcription': 'HOW DO I SET UP A JOINT ACCOUNT'}

Now we can create a preprocessing function that:
* calls the `audio` column to load and resample the audio file
* extracts the `input_values` from the audio file and tokenizes the `transcription` column with the processor

In [9]:
def prepare_dataset(batch):
    """Run the processor on one example and record its input length.

    Args:
        batch: a single dataset example providing an ``'audio'`` entry (with
            ``'array'`` and ``'sampling_rate'``) and a ``'transcription'``.

    Returns:
        The processor output for the audio array and transcription, with an
        extra ``'input_length'`` key holding the length of the first
        ``'input_values'`` entry.
    """
    audio = batch['audio']
    encoded = processor(
        audio['array'],
        sampling_rate=audio['sampling_rate'],
        text=batch['transcription'],
    )
    # Length of the (single) processed waveform
    encoded['input_length'] = len(encoded['input_values'][0])
    return encoded

In [10]:
# Apply prepare_dataset to every split using 4 worker processes, then display the result
encoded_minds = minds.map(
    prepare_dataset,
    remove_columns=minds.column_names['train'], # drop every original column of the train split
    num_proc=4
)
encoded_minds

Map (num_proc=4):   0%|          | 0/80 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/20 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_values', 'labels', 'input_length'],
        num_rows: 80
    })
    test: Dataset({
        features: ['input_values', 'labels', 'input_length'],
        num_rows: 20
    })
})

HuggingFace Transformers does not have a data collator for ASR, so we will need to adapt the `DataCollatorWithPadding` to create a batch of examples. It will also dynamically pad our audio inputs and labels to the length of the longest element in its batch (instead of the entire dataset) so they are a uniform length.

Unlike other data collators, this specific data collator needs to apply a different padding method to `input_values` and `labels`:

In [11]:
import torch
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union


@dataclass
class DataCollatorCTCWithPadding:
    """Collate CTC training examples into one padded batch.

    Audio inputs and label ids are padded in two separate
    ``processor.pad`` calls because they have different lengths and need
    different padding methods.
    """
    # Processor whose `pad` method is used for both inputs and labels
    processor: AutoProcessor
    # Padding strategy forwarded unchanged to `processor.pad`
    padding: Union[bool, str] = 'longest'

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad a list of examples and return a batch with masked labels.

        Args:
            features: examples with ``'input_values'`` and ``'labels'`` keys.

        Returns:
            The padded input batch with an added ``'labels'`` entry in which
            padded positions are set to -100.
        """
        audio_items = []
        label_items = []
        for example in features:
            # Each example stores its waveform as the first element of `input_values`
            audio_items.append({'input_values': example['input_values'][0]})
            label_items.append({'input_ids': example['labels']})

        batch = self.processor.pad(
            audio_items,
            padding=self.padding,
            return_tensors='pt',
        )
        padded_labels = self.processor.pad(
            labels=label_items,
            padding=self.padding,
            return_tensors='pt',
        )

        # Positions whose attention mask is not 1 are padding; set them to
        # -100 so the loss ignores them
        is_padding = padded_labels.attention_mask.ne(1)
        batch['labels'] = padded_labels['input_ids'].masked_fill(is_padding, -100)

        return batch

Now we can instantiate our `data_collator`:

In [12]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding='longest')

## Evaluate

For the ASR task, we need to load the **word error rate (WER)** metric:

In [None]:
import evaluate

wer = evaluate.load('wer')

Then we can create a function that passes our predictions and labels to compute the WER:

In [13]:
import numpy as np

def compute_metrics(pred):
    """Compute the word error rate for a set of predictions.

    Args:
        pred: object exposing ``predictions`` (logits) and ``label_ids``.

    Returns:
        A dict ``{'wer': score}`` as computed by the loaded ``wer`` metric.
    """
    # Pick the highest-scoring token id along the last axis
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # -100 marks label positions to ignore; swap them for the pad token id so
    # they can be decoded. np.where builds a new array, so the caller's
    # `pred.label_ids` is left unmodified.
    label_ids = np.where(
        pred.label_ids == -100,
        processor.tokenizer.pad_token_id,
        pred.label_ids,
    )

    pred_str = processor.batch_decode(pred_ids)
    # Labels are decoded with token grouping disabled
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer_score = wer.compute(
        predictions=pred_str,
        references=label_str
    )

    return {'wer': wer_score}

## Train

We start training our model by loading Wav2Vec2 with `AutoModelForCTC` and specifying the reduction to apply with the `ctc_loss_reduction` parameter:

In [14]:
from transformers import AutoModelForCTC, TrainingArguments, Trainer

# Load the base checkpoint for CTC; the loss is reduced with 'mean' and the
# pad token id is taken from the processor's tokenizer
model = AutoModelForCTC.from_pretrained(
    'facebook/wav2vec2-base',
    ctc_loss_reduction='mean',
    pad_token_id=processor.tokenizer.pad_token_id
)



pytorch_model.bin:   0%|          | 0.00/380M [00:00<?, ?B/s]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Now we need to set up the training arguments:

In [15]:
training_args = TrainingArguments(
    output_dir='my_awesome_asr_mind_model',
    # 8 examples per device, accumulated over 2 steps
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=2000,
    gradient_checkpointing=True,
    # NOTE(review): fp16 presumably needs a GPU — confirm before running on CPU
    fp16=True,
    group_by_length=True,
    # Evaluate and save on the same 1000-step interval
    eval_strategy='steps',
    per_device_eval_batch_size=8,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    # Best model is selected by 'wer', where lower is better (greater_is_better=False)
    load_best_model_at_end=True,
    metric_for_best_model='wer',
    greater_is_better=False,
    push_to_hub=False
)

In [16]:
# Wire together the model, arguments, encoded splits, collator and WER metric
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_minds['train'],
    eval_dataset=encoded_minds['test'],
    processing_class=processor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()

## Inference

In [None]:
from datasets import load_dataset, Audio

# Reload the en-US train split, cast its audio to 16 kHz, and keep the
# sampling rate and the file path of the first example for inference
dataset = load_dataset("PolyAI/minds14", "en-US", split="train")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
sampling_rate = dataset.features["audio"].sampling_rate
audio_file = dataset[0]["audio"]["path"]

We can load our fine-tuned model for inference by using a `pipeline()`:

In [None]:
from transformers import pipeline

# NOTE(review): this model id differs from the local output_dir
# ('my_awesome_asr_mind_model') used for training above, which ran with
# push_to_hub=False — presumably a published checkpoint; point `model` at the
# local directory to use your own run
transcriber = pipeline(
    'automatic-speech-recognition',
    model='stevhliu/my_awesome_asr_mind_model'
)

In [None]:
transcriber(audio_file)

We can also manually load models

In [None]:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained('stevhliu/my_awesome_asr_mind_model')
# return_tensors='pt' asks the processor for PyTorch tensors, which the
# model call on `inputs` below requires
inputs = processor(
    dataset[0]['audio']['array'],
    sampling_rate=sampling_rate,
    return_tensors='pt'
)

In [None]:
from transformers import AutoModelForCTC

model = AutoModelForCTC.from_pretrained('stevhliu/my_awesome_asr_mind_model')

In [None]:
import torch

# Forward pass without gradient tracking
with torch.no_grad():
    logits = model(**inputs).logits

# Take the highest-scoring token id along the last dimension, then decode to text
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
transcription