# Fine-tune a wav2vec 2.0 checkpoint for Automatic Speech Recognition (ASR) on IPUs

This notebook demonstrates how to fine-tune a pre-trained wav2vec 2.0 model with PyTorch on Graphcore IPUs. We will use the `wav2vec2-base` model and fine-tune it for a CTC downstream task using the LibriSpeech dataset.

We will show how to use a wav2vec 2.0 model written in PyTorch from the 🤗`transformers` library from Hugging Face and parallelize it using the 🤗`optimum-graphcore` library.

### Background

ASR (Automatic Speech Recognition), the task of transcribing audio automatically, has historically required large amounts of labelled data. In addition, these systems predominantly used fixed feature extraction methods that do not learn from the raw signal, such as the short-time Fourier transform (STFT) or mel-frequency features. Research conducted by Facebook AI (now Meta AI) demonstrates a “framework for self-supervised learning of speech representations”: a pre-training phase and architecture that learn feature representations, and the relationships between them, from large amounts of unlabelled raw audio.

There are two phases to training: pre-training on unlabelled data, and fine-tuning on a downstream task. In the original paper the model is fine-tuned for ASR using a CTC (connectionist temporal classification) loss. The modules shared between pre-training and fine-tuning are what you would expect in a CTC system: a feature extractor and an encoder. Unlike many earlier systems, however, the feature extractor is a convolutional neural network, which makes it trainable. It is followed by a BERT-style Transformer encoder in which a large convolutional block before the first layer provides positional information, instead of sinusoidal positional encodings.
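The convolutional feature encoder downsamples the raw waveform before it reaches the Transformer. As a sketch, assuming the seven-layer convolution stack of the `wav2vec2-base` configuration (kernel sizes `(10, 3, 3, 3, 3, 2, 2)` and strides `(5, 2, 2, 2, 2, 2, 2)`, the defaults of the `conv_kernel` and `conv_stride` fields of the Hugging Face `Wav2Vec2Config`), the number of output frames and the receptive field of each frame can be computed as follows. The helper names are hypothetical.

```python
from typing import Tuple

# Assumed feature-encoder geometry: the Wav2Vec2Config defaults used by wav2vec2-base.
CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)


def num_output_frames(num_samples: int) -> int:
    """Number of frames produced by the stack of unpadded 1-D convolutions."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNEL, CONV_STRIDE):
        length = (length - kernel) // stride + 1
    return length


def receptive_field_and_hop() -> Tuple[int, int]:
    """Receptive field and hop size of one output frame, both in samples."""
    receptive_field, hop = 1, 1
    for kernel, stride in zip(CONV_KERNEL, CONV_STRIDE):
        receptive_field += (kernel - 1) * hop
        hop *= stride
    return receptive_field, hop


rf, hop_size = receptive_field_and_hop()
print(rf, hop_size)               # 400 320 -> a 25 ms window every 20 ms at 16 kHz
print(num_output_frames(16_000))  # 49 frames for one second of audio
```

At roughly 50 frames per second, the 15.6-second maximum clip length used later in this notebook corresponds to 779 encoder frames.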

### Running on Paperspace

The Paperspace environment lets you run this notebook with no setup. To improve your experience, we preload datasets and pre-install packages. This can take a few minutes; if you experience errors immediately after starting a session, please try restarting the kernel before contacting support. If a problem persists or you want to give us feedback on the content of this notebook, please reach out through our community of developers using our [Slack channel](https://www.graphcore.ai/join-community) or raise a [GitHub issue](https://github.com/gradient-ai/Graphcore-HuggingFace/issues).

Requirements:
- Python packages installed with `python -m pip install -r requirements.txt`

In [2]:
!df -h

Filesystem      Size  Used Avail Use% Mounted on
overlay         7.3T  6.9T  440G  95% /
tmpfs            64M     0   64M   0% /dev
tmpfs           252G     0  252G   0% /sys/fs/cgroup
/dev/md126      7.3T  6.9T  440G  95% /datasets
tmpfs           252G  1.3M  252G   1% /dev/shm
tmpfs           252G     0  252G   0% /proc/acpi
tmpfs           252G     0  252G   0% /proc/scsi
tmpfs           252G     0  252G   0% /sys/firmware


In [3]:
%%bash
apt update
apt-get install libsndfile1 -y


In order to improve usability and support for future users, Graphcore would like to collect information about the applications and code being run in this notebook. The following information will be anonymised before being sent to Graphcore:

- User progression through the notebook
- Notebook details: number of cells, code being run and the output of the cells
- Environment details

You can disable logging at any time by running `%unload_ext graphcore_cloud_tools.notebook_logging.gc_logger` from any cell.

In [4]:
%pip install -r requirements.txt
%load_ext graphcore_cloud_tools.notebook_logging.gc_logger

### Graphcore Hugging Face models
Hugging Face provides convenient access to pre-trained transformer models. The partnership between Hugging Face and Graphcore allows us to run these models on the IPU.

Hugging Face models ported to the IPU can be found on the Graphcore organisation page on Hugging Face. 

### Utility imports
We start by importing the utilities that will be used later in the tutorial: 

In [5]:
import functools
import json
import logging
import os
import re
import sys
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

import datasets
import numpy as np
import torch
from datasets import DatasetDict, load_dataset, load_metric
from pathlib import Path
import transformers
from optimum.graphcore import IPUConfig, IPUTrainer
from optimum.graphcore import IPUTrainingArguments
from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoModelForCTC,
    AutoProcessor,
    AutoTokenizer,
    HfArgumentParser,
    Wav2Vec2Processor,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

In [6]:
set_seed(0)

Values for machine size and cache directories can be configured through environment variables or directly in the notebook:

In [7]:
import os

pod_type = os.getenv("GRAPHCORE_POD_TYPE", "pod4")
executable_cache_dir = os.path.join(os.getenv("POPLAR_EXECUTABLE_CACHE_DIR", "/tmp/exe_cache/"), "wav2vec2_fine_tuning")
checkpoint_directory = Path(os.getenv("PERSISTENT_CHECKPOINT_DIR", "/tmp")) / "demo"

## Preparing the LibriSpeech dataset

The 🤗`datasets` library from Hugging Face can be used to load the LibriSpeech dataset, and it also provides tools to process the data.

First we create a `DatasetDict` to hold our data, and then load the LibriSpeech splits for training and validation. For this notebook we use `train.100`, which contains 100 hours of clean training speech. Section C of the appendix in the [paper](https://arxiv.org/abs/2006.11477) suggests that fine-tuning a `Base` model on this data can yield 6.1% WER without an additional language model.

In [8]:
raw_datasets = DatasetDict()
raw_datasets["train"] = load_dataset("librispeech_asr", "clean", split="train.100")
raw_datasets["eval"] = load_dataset("librispeech_asr", "clean", split="validation")

In [9]:
raw_datasets

### Text normalisation

Using the `map` function from 🤗`datasets`, special characters are removed from each transcription and the result is lower-cased. The model therefore does not need to learn punctuation or capitalisation, which makes the task much easier.

Text normalisation can also include other steps, such as converting digits into words. This is not performed here because LibriSpeech transcripts already spell out numbers as words.

In [10]:
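# Characters to strip from the transcripts. They are escaped and wrapped in a
# character class so that each one is matched (and removed) individually.
chars_to_ignore = [",", "?", ".", "!", "-", ";", ":", "\"", "“", "%", "‘", "”", "�"]
chars_to_ignore_regex = "[" + "".join(re.escape(char) for char in chars_to_ignore) + "]"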
text_column_name = "text"


def remove_special_characters(batch):
    if chars_to_ignore_regex is not None:
        batch["target_text"] = re.sub(chars_to_ignore_regex, "", batch[text_column_name]).lower() + " "
    else:
        batch["target_text"] = batch[text_column_name].lower() + " "
    return batch


raw_datasets = raw_datasets.map(
    remove_special_characters,
    remove_columns=[text_column_name],
    desc="remove special characters from datasets",
)
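To see the effect of the normalisation on a single transcript, here is a minimal standalone sketch. It assumes the same list of characters as the cell above; `normalise_transcript` and the constant names are hypothetical. The characters are escaped with `re.escape` and wrapped in square brackets, so the pattern is a character class that matches each listed character on its own.

```python
import re

# Characters to strip (same list as the notebook); re.escape makes "-", "?", "." literal.
CHARS_TO_IGNORE = [",", "?", ".", "!", "-", ";", ":", '"', "“", "%", "‘", "”", "�"]
CHARS_TO_IGNORE_REGEX = "[" + "".join(re.escape(c) for c in CHARS_TO_IGNORE) + "]"


def normalise_transcript(text: str) -> str:
    """Remove ignored characters, lower-case, and append a trailing space."""
    return re.sub(CHARS_TO_IGNORE_REGEX, "", text).lower() + " "


print(normalise_transcript("HELLO, WORLD! DON'T PANIC."))  # -> "hello world don't panic "
```

Apostrophes are not in the list, so contractions such as "don't" keep their apostrophe.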

### Create vocabulary and tokenizer

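We now create a vocabulary from the dataset by collecting every unique character in the normalised transcripts of both the training and evaluation splits. The space character is replaced by the word delimiter token `|`, and `[UNK]` and `[PAD]` tokens are appended.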

In [11]:
def create_vocabulary_from_data(
        datasets: DatasetDict,
        word_delimiter_token=None,
        unk_token=None,
        pad_token=None,
):
    # Given training and test labels create vocabulary
    def extract_all_chars(batch):
        all_text = " ".join(batch["target_text"])
        vocab = list(set(all_text))
        return {"vocab": [vocab], "all_text": [all_text]}

    vocabs = datasets.map(
        extract_all_chars,
        batched=True,
        batch_size=-1,
        keep_in_memory=True,
        remove_columns=datasets["train"].column_names,
    )

    # take union of all unique characters in each dataset
    vocab_set = functools.reduce(
        lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]), vocabs.values()
    )

    vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}

    # replace white space with delimiter token
    if word_delimiter_token is not None:
        vocab_dict[word_delimiter_token] = vocab_dict[" "]
        del vocab_dict[" "]

    # add unk and pad token
    if unk_token is not None:
        vocab_dict[unk_token] = len(vocab_dict)

    if pad_token is not None:
        vocab_dict[pad_token] = len(vocab_dict)

    return vocab_dict


word_delimiter_token = "|"
unk_token = "[UNK]"
pad_token = "[PAD]"

vocab_dict = create_vocabulary_from_data(raw_datasets,
                                         word_delimiter_token=word_delimiter_token,
                                         unk_token=unk_token,
                                         pad_token=pad_token)

In [12]:
vocab_dict
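The steps performed by `create_vocabulary_from_data` can be reproduced on a toy list of transcripts. This is a plain-Python sketch (the function name `build_vocab` and the toy data are hypothetical) that follows the same order: sort the unique characters, give the space character's index to the word delimiter `|`, then append `[UNK]` and `[PAD]`.

```python
def build_vocab(transcripts, word_delimiter_token="|", unk_token="[UNK]", pad_token="[PAD]"):
    """Character vocabulary built in the same order as create_vocabulary_from_data."""
    chars = set()
    for text in transcripts:
        chars.update(text)
    vocab = {ch: idx for idx, ch in enumerate(sorted(chars))}
    # The space character is re-labelled as the word delimiter, keeping its index.
    vocab[word_delimiter_token] = vocab.pop(" ")
    # Special tokens are appended at the end of the index range.
    vocab[unk_token] = len(vocab)
    vocab[pad_token] = len(vocab)
    return vocab


print(build_vocab(["ab ", "ba c "]))
# {'a': 1, 'b': 2, 'c': 3, '|': 0, '[UNK]': 4, '[PAD]': 5}
```

Because the space sorts before letters, the word delimiter ends up with index 0 in this toy example.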

With the vocabulary generated from the normalised transcripts, we create a `tokenizer` using the 🤗`transformers` library. It will later be used to encode text into indexes and to decode indexes back into text.

In [13]:
tokenizer_name_or_path = "/tmp/wav2vec2-notebook"

vocab_file = os.path.join(tokenizer_name_or_path, "vocab.json")

if os.path.isfile(vocab_file):
    os.remove(vocab_file)

os.makedirs(tokenizer_name_or_path, exist_ok=True)

with open(vocab_file, "w") as file:
    json.dump(vocab_dict, file)

tokenizer_kwargs = {
    "config": None,
    "tokenizer_type": "wav2vec2",
    "unk_token": unk_token,
    "pad_token": pad_token,
    "word_delimiter_token": word_delimiter_token,
}

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_auth_token=False, **tokenizer_kwargs)

Let's look at an example of using the tokenizer. The vocabulary does not contain any digits, so these are mapped to `[UNK]`. Remember that special characters (such as commas) have already been removed from the dataset.

In [14]:
tokenizer("wav2vec2 finetuning on ipu")

In [15]:
tokenizer.decode(tokenizer("wav2vec2 finetuning on ipu").input_ids)
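The tokenizer works character by character. A minimal sketch of the encoding step, assuming the behaviour shown above (spaces mapped to the word delimiter `|`, characters missing from the vocabulary mapped to `[UNK]`); `encode` and `toy_vocab` are hypothetical names, and the real `Wav2Vec2CTCTokenizer` handles further options that this sketch ignores.

```python
def encode(text, vocab, word_delimiter_token="|", unk_token="[UNK]"):
    """Map each character to its vocabulary index; spaces become the delimiter."""
    tokens = list(text.replace(" ", word_delimiter_token))
    # Unknown characters (such as digits here) fall back to the [UNK] index.
    return [vocab.get(tok, vocab[unk_token]) for tok in tokens]


# Hypothetical toy vocabulary, not the one built from LibriSpeech.
toy_vocab = {"|": 0, "a": 1, "b": 2, "[UNK]": 3, "[PAD]": 4}
print(encode("ab 2a", toy_vocab))  # [1, 2, 0, 3, 1]
```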

### Feature extraction

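Next we load the feature extractor for the model. Because wav2vec 2.0 learns from the raw audio signal, the feature extractor does not compute spectral features; it prepares the raw waveform as `input_values` (optionally normalising it to zero mean and unit variance, controlled by its `do_normalize` setting) and expects audio sampled at 16 kHz. If the dataset's sampling rate differs, the `audio` column is cast to 16 kHz so that 🤗`datasets` resamples the audio when it is decoded. LibriSpeech is already sampled at 16 kHz, so no resampling is needed here.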

We then set the minimum and maximum input lengths in samples: 2.0 and 15.6 seconds, which at 16 kHz correspond to 32,000 and 249,600 samples.
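Converting seconds to samples is a multiplication by the sampling rate. A small sketch, assuming a 16 kHz rate and the strict comparison used by `is_audio_in_length_range` in the preprocessing step; the helper names are hypothetical.

```python
SAMPLING_RATE = 16_000  # Hz, the rate expected by wav2vec2-base


def seconds_to_samples(seconds: float, sampling_rate: int = SAMPLING_RATE) -> int:
    """Truncate seconds * rate to a whole number of samples."""
    return int(seconds * sampling_rate)


MIN_SAMPLES = seconds_to_samples(2.0)   # 32000
MAX_SAMPLES = seconds_to_samples(15.6)  # 249600


def keep_example(num_samples: int) -> bool:
    # Strict bounds: clips of exactly 2.0 s or 15.6 s are dropped.
    return MIN_SAMPLES < num_samples < MAX_SAMPLES
```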

In [16]:
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")

dataset_sampling_rate = next(iter(raw_datasets.values())).features["audio"].sampling_rate
if dataset_sampling_rate != 16000:
    raw_datasets = raw_datasets.cast_column("audio", datasets.features.Audio(sampling_rate=16000))

max_input_length = int(15.6 * feature_extractor.sampling_rate)
min_input_length = int(2.0 * feature_extractor.sampling_rate)

### Prepare dataset

In this step, the feature extractor and the tokenizer are applied to the audio and the transcript, respectively. The feature extractor turns the raw waveform into `input_values`, and the tokenizer converts the normalised text into indexes stored as `labels`.

After mapping, the dataset is filtered by audio length: any example whose raw audio is not strictly between 2.0 and 15.6 seconds long is removed. The result of filtering is cached.

In [17]:
def prepare_dataset(batch):
    sample = batch["audio"]

    inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
    batch["input_values"] = inputs.input_values[0]
    batch["input_length"] = len(inputs.input_values[0])

    batch["labels"] = tokenizer(batch["target_text"]).input_ids

    return batch


def is_audio_in_length_range(length):
    # Keep only clips strictly between the minimum and maximum lengths (in samples)
    return min_input_length < length < max_input_length


vectorized_datasets = raw_datasets.map(prepare_dataset,
                                       remove_columns=raw_datasets["train"].column_names,
                                       num_proc=8,
                                       desc="preprocess datasets")

vectorized_datasets = vectorized_datasets.filter(is_audio_in_length_range,
                                                 input_columns=["input_length"],
                                                 num_proc=8)

## Data loading

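With the dataset prepared, most of the processing is complete and the data is nearly ready to be sent to the model. The role of the collator is to pad the audio and the encoded text of each batch to a static size: the audio is padded to a multiple of `max_input_length` and the labels to a multiple of 1000. The padding value for audio is `0.0`, whereas padded label positions are set to `-100` so that they cannot be confused with an index in the vocabulary and are ignored by the loss. The collator also casts the audio to FP16.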
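The padding rule can be sketched in plain Python. It assumes that each sequence is right-padded to the longest length in the batch rounded up to `pad_to_multiple_of`, with `0.0` for audio and `-100` for labels. `pad_to_multiple` is a hypothetical helper; the real collator pads labels with the tokenizer's pad token and then replaces them with `-100` using the attention mask.

```python
import math
from typing import List


def pad_to_multiple(seqs: List[list], multiple: int, pad_value) -> List[list]:
    """Right-pad every sequence to the longest length rounded up to `multiple`."""
    longest = max(len(s) for s in seqs)
    target = math.ceil(longest / multiple) * multiple
    return [list(s) + [pad_value] * (target - len(s)) for s in seqs]


audio = [[0.1, 0.2, 0.3], [0.5]]
labels = [[7, 8], [9, 9, 9]]
print(pad_to_multiple(audio, multiple=4, pad_value=0.0))
# [[0.1, 0.2, 0.3, 0.0], [0.5, 0.0, 0.0, 0.0]]
print(pad_to_multiple(labels, multiple=4, pad_value=-100))
# [[7, 8, -100, -100], [9, 9, 9, -100]]
```

Because every clip kept after filtering is shorter than `max_input_length`, using that value as the multiple makes every audio batch exactly `max_input_length` samples long, so the shape never changes between batches.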

In [18]:
@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that pads the inputs and labels it receives.
    Args:
        processor (:class:`~transformers.AutoProcessor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`'longest'`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding if
              only a single sequence is provided).
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set, pad ``input_values`` to a multiple of the provided value. Setting it to the maximum input
            length gives every batch the same, static shape.
        pad_to_multiple_of_labels (:obj:`int`, `optional`):
            If set, pad ``labels`` to a multiple of the provided value.
    The padded ``input_values`` are cast to half precision (FP16).
    """

    processor: AutoProcessor
    padding: Union[bool, str] = "longest"
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        batch["labels"] = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["input_values"] = batch["input_values"].half()

        return batch.data


processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

data_collator = DataCollatorCTCWithPadding(processor=processor, pad_to_multiple_of=int(max_input_length),
                                           pad_to_multiple_of_labels=1000)

## Preparing the model


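For the model we use `wav2vec2-base` from the Hugging Face Model Hub. This checkpoint has only been pre-trained, so the CTC output layer is newly initialised when the model is loaded.
Some of the default options for the model need to be changed for training:
* `ctc_loss_reduction="mean"`: the CTC loss of each example is divided by its target length, and the results are averaged over the batch.
* Time and feature masking are disabled (`mask_time_prob` and `mask_feature_prob` are set to 0.0) because the current masking strategy isn't supported on the IPU. LayerDrop is also disabled (`layerdrop=0.0`).
* The `[PAD]` index serves as the CTC blank token, and the vocabulary size sets the size of the final output layer.
* `layer_norm_eps` is raised to 1e-4 for FP16 training.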


The IPU config describes how to parallelise the model across several IPUs. It also includes additional options such as gradient accumulation, device iterations, and memory proportion. 
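In this notebook `ipu_config.layers_per_ipu` is set to `[5, 5, 5, 6]`, that is, how many pipeline layers go on each of four IPUs. A sketch of the resulting placement, assuming layers are assigned in order and in contiguous blocks. The helper name is hypothetical, and which modules of wav2vec 2.0 count as a pipeline layer is decided by `optimum-graphcore`, not by this code.

```python
from typing import List


def layer_to_ipu(layers_per_ipu: List[int]) -> List[int]:
    """Return the IPU index for each pipeline layer, assigned in contiguous blocks."""
    placement = []
    for ipu, count in enumerate(layers_per_ipu):
        placement.extend([ipu] * count)
    return placement


placement = layer_to_ipu([5, 5, 5, 6])
print(len(placement))  # 21 pipeline layers in total
print(placement[:6])   # [0, 0, 0, 0, 0, 1]
```

Under this assumption, the entries of `layers_per_ipu` should add up to the number of layers the pipelined model defines.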

In [None]:
config = AutoConfig.from_pretrained("facebook/wav2vec2-base")
config.update(
    {
        "ctc_loss_reduction": "mean",
        "mask_time_prob": 0.0,
        "mask_feature_prob": 0.0,
        "layerdrop": 0.0,
        "pad_token_id": tokenizer.pad_token_id,
        "vocab_size": len(tokenizer),
        "layer_norm_eps": 0.0001,
    }
)

model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base", config=config)

ipu_config = IPUConfig.from_pretrained("Graphcore/wav2vec2-base-ipu", executable_cache_dir=executable_cache_dir)

In [None]:
ipu_config.layers_per_ipu = [5, 5, 5, 6]

Let's set our training hyperparameters using `IPUTrainingArguments`. This subclasses the Hugging Face `TrainingArguments` class, adding parameters specific to the IPU and its execution characteristics.

In [None]:
training_args = IPUTrainingArguments(output_dir=str(checkpoint_directory),
                                     overwrite_output_dir=True,
                                     do_train=True,
                                     do_eval=True,
                                     evaluation_strategy="epoch",
                                     learning_rate=3e-4,
                                     num_train_epochs=5.0,
                                     adam_epsilon=0.0001,
                                     warmup_steps=400,
                                     dataloader_drop_last=True,
                                     dataloader_num_workers=16,
                                     )

In [None]:
feature_extractor.save_pretrained(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)
processor.save_pretrained(training_args.output_dir)

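The performance of the model is measured using the word error rate (WER). For a single sentence, WER is the word-level edit distance (the number of substitutions, deletions and insertions) between the predicted and the reference transcript, divided by the number of words in the reference. Over many sentences, the edit distances are summed and divided by the total number of reference words.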
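A minimal sketch of the corpus-level WER described above, using a word-level Levenshtein distance. The function names are hypothetical; the `wer` metric loaded with `load_metric` uses a separate implementation (the `jiwer` package), which may apply its own text transformations.

```python
from typing import List, Sequence


def word_edit_distance(hyp: Sequence[str], ref: Sequence[str]) -> int:
    """Levenshtein distance over words (substitutions, insertions, deletions)."""
    prev = list(range(len(ref) + 1))  # distances from an empty hypothesis
    for i, h in enumerate(hyp, start=1):
        curr = [i] + [0] * len(ref)
        for j, r in enumerate(ref, start=1):
            curr[j] = min(
                prev[j] + 1,             # extra hypothesis word (insertion)
                curr[j - 1] + 1,         # missing reference word (deletion)
                prev[j - 1] + (h != r),  # match or substitution
            )
        prev = curr
    return prev[-1]


def corpus_wer(predictions: List[str], references: List[str]) -> float:
    """Total word edits divided by the total number of reference words."""
    edits = sum(word_edit_distance(p.split(), r.split()) for p, r in zip(predictions, references))
    ref_words = sum(len(r.split()) for r in references)
    return edits / ref_words


print(corpus_wer(["the cat sat"], ["the cat sat on"]))  # 0.25
```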

To add this metric to our evaluation we define a `compute_metrics` function and load the metric from the `datasets` package. This is performed once after all the evaluation outputs have been computed.

In [None]:
eval_metrics = {"wer": load_metric("wer")}


def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id

    pred_str = tokenizer.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)

    metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}

    return metrics
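`compute_metrics` decodes the predictions with best-path (greedy) CTC decoding: take the most likely token in each frame, merge consecutive repeats, and drop the blank token, which for this model is `[PAD]`. The labels, by contrast, are decoded with `group_tokens=False` because they are real character sequences, where repeated letters such as the double `l` in "all" must be kept. A plain-Python sketch of the greedy step, with a hypothetical five-token vocabulary and hypothetical helper names; it approximates, rather than reproduces, what `tokenizer.batch_decode` does.

```python
from typing import Dict, Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the largest value (first one on ties)."""
    return max(range(len(row)), key=lambda k: row[k])


def greedy_ctc_decode(frame_logits: Sequence[Sequence[float]], id_to_token: Dict[int, str],
                      blank_id: int, word_delimiter_token: str = "|") -> str:
    """Best-path CTC decoding: argmax per frame, merge repeats, drop blanks."""
    best_path = [argmax(row) for row in frame_logits]
    collapsed = [t for k, t in enumerate(best_path) if k == 0 or t != best_path[k - 1]]
    tokens = [id_to_token[t] for t in collapsed if t != blank_id]
    return "".join(tokens).replace(word_delimiter_token, " ").strip()


# Hypothetical vocabulary; index 4 plays the role of [PAD], the CTC blank.
id_to_token = {0: "|", 1: "a", 2: "b", 3: "l", 4: "[PAD]"}
frame_ids = [1, 1, 4, 3, 4, 3, 0, 2]
one_hot = [[1.0 if k == i else 0.0 for k in range(5)] for i in frame_ids]
print(greedy_ctc_decode(one_hot, id_to_token, blank_id=4))  # all b
```

The blank between the two `l` frames is what keeps them from being merged into a single letter.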

To train the model, we define a trainer using the `IPUTrainer` class which takes care of compiling the model to run on IPUs, and of performing training and evaluation. The `IPUTrainer` class works just like the HuggingFace `Trainer` class, but takes the additional `ipu_config` argument.

In [None]:
# Initialize Trainer
trainer = IPUTrainer(
    model=model,
    ipu_config=ipu_config,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=vectorized_datasets["train"],
    eval_dataset=vectorized_datasets["eval"],
    tokenizer=feature_extractor,
)

## Run the training

In [None]:
trainer.train()
trainer.save_model()

In [None]:
!df -h