# Training Whisper with Hugging Face

In my previous article, we learned about the new end-to-end speech recognition model developed by OpenAI: Whisper.

Today, we will go through the steps required to fine-tune a Whisper model using several Hugging Face libraries. We will also look in detail at how these libraries handle the different steps under the hood, and how the model learns from the training examples.

But first, let me clarify an important point: Whisper models are already trained on downstream tasks, which means that they can be used out of the box to perform several tasks such as language-to-language transcription, language-to-English translation, and language identification. However, you will achieve better performance on a specific distribution (language, domain, particular background noise, ...) if the model is fine-tuned on a dataset drawn from that distribution.

If you want to learn more about the approach used, training data, model architecture and the extensive evaluation performed by the OpenAI team, [this](https://marinone94.github.io/Whisper-paper/) is the place to start! You will benefit much more from this post afterwards.

Since the goal of this article is to understand the training process, we will fine-tune the smallest model available - [Whisper Tiny](https://huggingface.co/openai/whisper-tiny) - on the English subset of the [Fleurs](https://huggingface.co/datasets/google/fleurs) dataset. Whisper models have been trained largely on English data, so I don't expect fine-tuning to improve performance further; still, it is a good starting point to understand how the training process works, and it allows everyone to verify the model's predictions.

## Environment setup

To reproduce the following examples, I recommend setting up a virtual environment. The code has been tested with the packages listed in the `requirements.txt` file, so I cannot guarantee that different package or Python versions will run smoothly (although they mostly will).

In [None]:
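# Note: in a notebook, each `!` command runs in its own subshell, so the
# activation on the second line does not carry over to the third one.
# Running these commands (without the leading `!`) in a terminal before
# starting the notebook avoids the problem.
!python3.8 -m venv venv
!source venv/bin/activate
!pip install -r requirements.txt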
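
If you are unsure whether your environment matches the tested one, a small check along the following lines may help. It is only a sketch: it assumes that `requirements.txt` uses simple `name==version` pins (the file's actual contents are not shown here), and the helper names `parse_pins` and `find_mismatches` are hypothetical.

```python
from importlib import metadata


def parse_pins(lines):
    """Extract {package: version} from 'name==version' lines, skipping comments."""
    pins = {}
    for line in lines:
        line = line.split("#", 1)[0].strip()  # drop inline comments
        if "==" in line:
            name, version = line.split("==", 1)
            pins[name.strip().lower()] = version.strip()
    return pins


def find_mismatches(pins):
    """Return {package: (expected, installed or None)} for pins that do not match."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


# Made-up pins for illustration; in practice, pass open("requirements.txt").
example = ["datasets==0.0.0  # hypothetical pin", "", "# a comment"]
print(parse_pins(example))
print(find_mismatches(parse_pins(example)))
```

An empty dictionary from `find_mismatches` would mean that every pinned package is installed at the pinned version.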

## Training dataset

Fleurs is a dataset open-sourced by Google which contains approximately 2000 examples for each language. Each training set has around **10 hours** of supervision, and speakers of the training sets are different from the speakers of the dev and test sets.

This dataset has also been used to evaluate the translation capabilities of Whisper models, since all sentences are translated into all languages and can be matched using their ids.
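
To make the id-based matching concrete, here is a minimal sketch that pairs sentences from two language subsets. The toy examples are made up, and the `transcription` field name is an assumption about the dataset schema; the sketch also assumes that an id may appear more than once within a subset (for instance, if several speakers read the same sentence), in which case only the first occurrence is kept.

```python
def match_by_id(source_examples, target_examples, text_key="transcription"):
    """Pair source and target sentences that share the same id."""
    # Index the target sentences by id, keeping the first occurrence of each id.
    targets = {}
    for example in target_examples:
        targets.setdefault(example["id"], example[text_key])

    pairs = []
    seen = set()
    for example in source_examples:
        sentence_id = example["id"]
        if sentence_id in targets and sentence_id not in seen:
            seen.add(sentence_id)
            pairs.append((example[text_key], targets[sentence_id]))
    return pairs


# Made-up toy data standing in for two language subsets.
english = [
    {"id": 1, "transcription": "hello"},
    {"id": 2, "transcription": "thank you"},
    {"id": 1, "transcription": "hello"},  # same sentence, repeated id
]
swedish = [
    {"id": 2, "transcription": "tack"},
    {"id": 1, "transcription": "hej"},
    {"id": 3, "transcription": "god natt"},  # no English counterpart above
]

pairs = match_by_id(english, swedish)
print(pairs)  # [('hello', 'hej'), ('thank you', 'tack')]
```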

Before proceeding with the training, let's take a look at the data.

In [6]:
from datasets import load_dataset

dataset = load_dataset("google/fleurs", "en_us", streaming=True)

In [7]:
dataset

{'train': <datasets.iterable_dataset.IterableDataset at 0x7f8ff41ec310>,
 'validation': <datasets.iterable_dataset.IterableDataset at 0x7f8ff41ecc10>,
 'test': <datasets.iterable_dataset.IterableDataset at 0x7f8ff41e83d0>}

As you can see, the dataset contains three splits. Each split is an `IterableDataset`, since we have loaded it in streaming mode. This means that the dataset is not downloaded upfront; instead, examples are fetched on the fly as we iterate over it. This is useful when the dataset would occupy too much disk space, or when you want to avoid waiting for the whole dataset to be downloaded. The Hugging Face [docs](https://huggingface.co/docs/datasets/stream) are an excellent place to learn more about the datasets library and the streaming mode.
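
The lazy behaviour can be illustrated without touching the network. The sketch below uses a hypothetical generator, `fake_stream`, as a stand-in for a streamed split, and a small helper, `peek`, that reads only the first few examples; neither name comes from the datasets library.

```python
from itertools import islice


def peek(split, n=2):
    """Return the first n examples of an iterable split, reading no more than n."""
    return list(islice(split, n))


# Hypothetical stand-in for a streamed split: it records how many examples
# were actually produced, so we can see that nothing is read ahead of time.
produced = []


def fake_stream(total=2000):
    for i in range(total):
        produced.append(i)
        yield {"id": i}


first = peek(fake_stream(), n=2)
print(first)          # [{'id': 0}, {'id': 1}]
print(len(produced))  # 2
```

Since a streamed split is also an iterable, `peek(dataset["train"])` should work in the same way on the dataset loaded above, fetching only as much data as the first examples require.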

But we can still explore the dataset features without downloading it. So let's have a look.

In [9]:
from pprint import pprint
features = dataset['train'].features
pprint(features)

{'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None),
 'gender': ClassLabel(names=['male', 'female', 'other'], id=None),
 'id': Value(dtype='int32', id=None),
 'lang_group_id': ClassLabel(names=['western_european_we', 'eastern_european_ee', 'central_asia_middle_north_african_cmn', 'sub_saharan_african_ssa', 'south_asian_sa', 'south_east_asian_sea', 'chinese_japanase_korean_cjk'], id=None),
 'lang_id': ClassLabel(names=['af_za', 'am_et', 'ar_eg', 'as_in', 'ast_es', 'az_az', 'be_by', 'bg_bg', 'bn_in', 'bs_ba', 'ca_es', 'ceb_ph', 'ckb_iq', 'cmn_hans_cn', 'cs_cz', 'cy_gb', 'da_dk', 'de_de', 'el_gr', 'en_us', 'es_419', 'et_ee', 'fa_ir', 'ff_sn', 'fi_fi', 'fil_ph', 'fr_fr', 'ga_ie', 'gl_es', 'gu_in', 'ha_ng', 'he_il', 'hi_in', 'hr_hr', 'hu_hu', 'hy_am', 'id_id', 'ig_ng', 'is_is', 'it_it', 'ja_jp', 'jv_id', 'ka_ge', 'kam_ke', 'kea_cv', 'kk_kz', 'km_kh', 'kn_in', 'ko_kr', 'ky_kg', 'lb_lu', 'lg_ug', 'ln_cd', 'lo_la', 'lt_lt', 'luo_ke', 'lv_lv', 'mi_nz', 'mk_mk', 'ml_in', 'mn_mn