# Feature Pipeline for Swedish ASR Fine Tuning

## Introduction

<figure>
<img src="https://raw.githubusercontent.com/sanchit-gandhi/notebooks/main/whisper_architecture.svg" alt="Whisper architecture diagram" style="width:100%">
<figcaption align = "center"><b>Figure 1:</b> Whisper model. The architecture
follows the standard Transformer-based encoder-decoder model. A
log-Mel spectrogram is input to the encoder. The last encoder
hidden states are input to the decoder via cross-attention mechanisms. The
decoder autoregressively predicts text tokens, jointly conditional on the
encoder hidden states and previously predicted tokens. Figure source:
<a href="https://openai.com/blog/whisper/">OpenAI Whisper Blog</a>.</figcaption>
</figure>

The Whisper checkpoints come in five configurations of varying model sizes.
The smallest four are trained on either English-only or multilingual data.
The largest checkpoint is multilingual only. All nine of the pre-trained checkpoints
are available on the [Hugging Face Hub](https://huggingface.co/models?search=openai/whisper). The
checkpoints are summarised in the following table with links to the models on the Hub:

| Size   | Layers | Width | Heads | Parameters | English-only                                         | Multilingual                                      |
|--------|--------|-------|-------|------------|------------------------------------------------------|---------------------------------------------------|
| tiny   | 4      | 384   | 6     | 39 M       | [✓](https://huggingface.co/openai/whisper-tiny.en)   | [✓](https://huggingface.co/openai/whisper-tiny)   |
| base   | 6      | 512   | 8     | 74 M       | [✓](https://huggingface.co/openai/whisper-base.en)   | [✓](https://huggingface.co/openai/whisper-base)   |
| small  | 12     | 768   | 12    | 244 M      | [✓](https://huggingface.co/openai/whisper-small.en)  | [✓](https://huggingface.co/openai/whisper-small)  |
| medium | 24     | 1024  | 16    | 769 M      | [✓](https://huggingface.co/openai/whisper-medium.en) | [✓](https://huggingface.co/openai/whisper-medium) |
| large  | 32     | 1280  | 20    | 1550 M     | x                                                    | [✓](https://huggingface.co/openai/whisper-large)  |

For demonstration purposes, we'll fine-tune the multilingual version of the
[`"small"`](https://huggingface.co/openai/whisper-small) checkpoint with 244M params (≈ 1GB).
As for our data, we'll train and evaluate our system on Swedish, a low-resource language
taken from the [Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0)
dataset. We'll show that with as little as 8 hours of fine-tuning data, we can achieve
strong performance in this language.
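
The "244M params ≈ 1GB" figure can be checked with a quick back-of-the-envelope calculation. The sketch below assumes the weights are stored as 32-bit floats (4 bytes per parameter) and uses decimal gigabytes; `checkpoint_size_gb` is a hypothetical helper, and the parameter counts are copied from the table above.

```python
# Rough checkpoint size estimate: parameters x bytes per parameter.
# Assumption: weights are stored as float32 unless stated otherwise.
BYTES_PER_PARAM = {"float32": 4, "float16": 2}

# Parameter counts taken from the checkpoint table.
PARAMS = {"tiny": 39e6, "base": 74e6, "small": 244e6, "medium": 769e6, "large": 1550e6}


def checkpoint_size_gb(num_params, dtype="float32"):
    """Return the approximate size of the weights in decimal gigabytes."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


for name, num_params in PARAMS.items():
    print(f"{name:>6}: {checkpoint_size_gb(num_params):.2f} GB")
```

Under these assumptions the `small` checkpoint comes out just under 1GB, which matches the estimate in the text.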

## Prepare Environment

We need to log in to Hugging Face to download the dataset.

In [5]:
from huggingface_hub import notebook_login

notebook_login()


Some global variables for configuration.

In [1]:
MODEL_SIZE = "small" # tiny, base, small, ...
LANG_CODE = "sv-SE"
LANG_NAME = "Swedish"
MODEL_VERSION = "v7"

We need to install a few dependencies.

In [36]:
!add-apt-repository -y ppa:jonathonf/ffmpeg-4
!apt update
!apt install -y ffmpeg
!pip install "datasets>=2.6.1" git+https://github.com/huggingface/transformers
!pip install audiomentations


We mount Google Drive so that the extracted features can be stored on it.

In [3]:
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


## Load Dataset

We use the [mozilla-foundation/common_voice_11_0](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) dataset.


### Download Dataset

In [41]:
from datasets import load_dataset, DatasetDict, DownloadConfig

common_voice = DatasetDict()

raw_data_path = "./raw_data/"

download_conf = DownloadConfig(
    token=True,
    cache_dir=raw_data_path,
)
common_voice["train"] = load_dataset("mozilla-foundation/common_voice_11_0", LANG_CODE, split="train+validation", download_config=download_conf)
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_11_0", LANG_CODE, split="test", download_config=download_conf)

print(common_voice)

DatasetDict({
    train: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
        num_rows: 12360
    })
    test: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
        num_rows: 5069
    })
})


In [42]:
from datasets import load_dataset, DatasetDict, DownloadConfig

fleurs = DatasetDict()

raw_data_path = "./raw_data/"

download_conf = DownloadConfig(
    token=True,
    cache_dir=raw_data_path,
)
fleurs["train"] = load_dataset("google/fleurs", "sv_se", split="train+validation", download_config=download_conf)
fleurs["test"] = load_dataset("google/fleurs", "sv_se", split="test", download_config=download_conf)

print(fleurs)

DatasetDict({
    train: Dataset({
        features: ['id', 'num_samples', 'path', 'audio', 'transcription', 'raw_transcription', 'gender', 'lang_id', 'language', 'lang_group_id'],
        num_rows: 2715
    })
    test: Dataset({
        features: ['id', 'num_samples', 'path', 'audio', 'transcription', 'raw_transcription', 'gender', 'lang_id', 'language', 'lang_group_id'],
        num_rows: 759
    })
})


## Prepare Feature Extractor, Tokenizer and Data

### Load WhisperFeatureExtractor

The Whisper feature extractor performs two operations:
1. Pads / truncates the audio inputs to 30s: any audio inputs shorter than 30s are padded to 30s with silence (zeros), and those longer than 30s are truncated to 30s
2. Converts the audio inputs to _log-Mel spectrogram_ input features, a visual representation of the audio and the form of the input expected by the Whisper model
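
The first operation, and the spectrogram shape that results from the second, can be sketched in plain Python. This is a simplified illustration rather than the library's implementation: `pad_or_truncate` is a hypothetical helper, and the hop length of 160 samples and 80 Mel bins are the default `WhisperFeatureExtractor` settings we assume for the `small` checkpoint.

```python
SAMPLING_RATE = 16_000                       # Hz, the rate Whisper expects
CHUNK_SECONDS = 30                           # fixed input window
N_SAMPLES = SAMPLING_RATE * CHUNK_SECONDS    # 480_000 samples per window
HOP_LENGTH = 160                             # assumed default: samples between frames
N_MELS = 80                                  # assumed default: Mel bins per frame


def pad_or_truncate(samples, n_samples=N_SAMPLES):
    """Cut the waveform to n_samples, then right-pad with zeros (silence)."""
    clipped = list(samples[:n_samples])
    return clipped + [0.0] * (n_samples - len(clipped))


short_clip = [0.1] * (5 * SAMPLING_RATE)     # 5s of audio
long_clip = [0.1] * (40 * SAMPLING_RATE)     # 40s of audio

print(len(pad_or_truncate(short_clip)), len(pad_or_truncate(long_clip)))

# Every window therefore yields a spectrogram of the same shape.
n_frames = N_SAMPLES // HOP_LENGTH
print((N_MELS, n_frames))
```

Because every example ends up with the same number of frames, the input features have a fixed shape regardless of the original clip length.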

<figure>
<img src="https://raw.githubusercontent.com/sanchit-gandhi/notebooks/main/spectrogram.jpg" alt="Audio waveform and its log-Mel spectrogram" style="width:100%">
<figcaption align = "center"><b>Figure 2:</b> Conversion of sampled audio array to log-Mel spectrogram.
Left: sampled 1-dimensional audio signal. Right: corresponding log-Mel spectrogram. Figure source:
<a href="https://ai.googleblog.com/2019/04/specaugment-new-data-augmentation.html">Google SpecAugment Blog</a>.
</figcaption>
</figure>

We'll load the feature extractor from the pre-trained checkpoint with the default values:

In [43]:
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained(f"openai/whisper-{MODEL_SIZE}")

### Load WhisperTokenizer

See the [`WhisperTokenizer` documentation](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperTokenizer).

The Whisper model outputs a sequence of _token ids_. The tokenizer maps each of these token ids to their corresponding text string.
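
To make the id-to-text mapping concrete, here is a toy decoder. Everything in it is hypothetical: the six-entry `TOY_VOCAB` and the `decode` helper only illustrate the idea, whereas the real Whisper vocabulary contains tens of thousands of byte-level BPE tokens and is handled by `WhisperTokenizer`.

```python
# Hypothetical toy vocabulary, NOT the real Whisper vocabulary.
TOY_VOCAB = {
    0: "<|startoftranscript|>",
    1: "<|endoftext|>",
    2: "Han",
    3: " har",
    4: " lovat",
    5: ".",
}
SPECIAL_IDS = {0, 1}


def decode(token_ids, skip_special_tokens=True):
    """Map each token id to its text piece and join the pieces."""
    pieces = [
        TOY_VOCAB[token_id]
        for token_id in token_ids
        if not (skip_special_tokens and token_id in SPECIAL_IDS)
    ]
    return "".join(pieces)


predicted_ids = [0, 2, 3, 4, 5, 1]
print(decode(predicted_ids))
print(decode(predicted_ids, skip_special_tokens=False))
```

The special tokens mark the start and end of the transcript; dropping them leaves only the transcribed text.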

In [44]:
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained(f"openai/whisper-{MODEL_SIZE}", language=LANG_NAME, task="transcribe")

## Prepare Data

Let's print the first example of the Common Voice and FLEURS datasets to see
what form the data is in:

In [45]:
print(common_voice["train"][0])
print(fleurs["train"][0])

{'client_id': '782ec7b338418a4966cc49ae09265d258705091874fb4d3a7fc76c9541738a997af0f51e9ef6045dc01874a060b482c7adfbfff2a51b50fa8d03764248956d48', 'path': './raw_data/extracted/40784a27e162d09ad00f11f09f5a86a0cd56ee87ffaa2341f563f63a5cc19a5d/sv-SE_train_0/common_voice_sv-SE_20466896.mp3', 'audio': {'path': './raw_data/extracted/40784a27e162d09ad00f11f09f5a86a0cd56ee87ffaa2341f563f63a5cc19a5d/sv-SE_train_0/common_voice_sv-SE_20466896.mp3', 'array': array([0., 0., 0., ..., 0., 0., 0.]), 'sampling_rate': 48000}, 'sentence': 'Du ser ut att ha gjort det här hela livet.', 'up_votes': 2, 'down_votes': 0, 'age': 'twenties', 'gender': 'female', 'accent': '', 'locale': 'sv-SE', 'segment': ''}
{'id': 927, 'num_samples': 243840, 'path': './raw_data/extracted/529b5fb5c1abd2ad090d99933ea3ab4662dc02dfa27741904a0c368c34921a68/10005022996767235104.wav', 'audio': {'path': 'train/10005022996767235104.wav', 'array': array([ 0.        ,  0.        ,  0.        , ..., -0.00037664,
       -0.00037396, -0.0001

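Since the Common Voice audio is sampled at 48kHz, we need to _downsample_ it to
16kHz prior to passing it to the Whisper feature extractor, 16kHz being the sampling rate expected by the Whisper model.
We also reduce both datasets to the same two columns (`audio` and `sentence`) so that they can be concatenated.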
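
The effect of downsampling on the number of samples is simple arithmetic, sketched below. The helper names are hypothetical, and the sketch only covers sample counts and durations: real resampling also applies a low-pass filter, which the `Audio` feature handles for us.

```python
def resampled_length(num_samples, source_rate, target_rate):
    """Number of samples a clip has after changing its sampling rate."""
    return round(num_samples * target_rate / source_rate)


def duration_seconds(num_samples, sampling_rate):
    """Clip duration, which resampling leaves unchanged."""
    return num_samples / sampling_rate


# A 3s clip at 48kHz keeps its duration but has a third of the samples at 16kHz.
samples_48k = 3 * 48_000
samples_16k = resampled_length(samples_48k, 48_000, 16_000)
print(samples_16k, duration_seconds(samples_16k, 16_000))

# The FLEURS example printed above reports num_samples = 243840;
# assuming that clip is sampled at 16kHz, this gives its duration in seconds.
print(duration_seconds(243_840, 16_000))
```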

In [46]:
from datasets import Audio, concatenate_datasets

AUDIO_COLUMN_NAME = "audio"
TEXT_COLUMN_NAME = "sentence"

def normalize_dataset(ds, audio_column_name=None, text_column_name=None):
    if audio_column_name is not None and audio_column_name != AUDIO_COLUMN_NAME:
        ds = ds.rename_column(audio_column_name, AUDIO_COLUMN_NAME)
    if text_column_name is not None and text_column_name != TEXT_COLUMN_NAME:
        ds = ds.rename_column(text_column_name, TEXT_COLUMN_NAME)
    # resample to the same sampling rate (applied lazily when a sample is loaded)
    ds = ds.cast_column(AUDIO_COLUMN_NAME, Audio(sampling_rate=16_000))
    # normalise columns to ["audio", "sentence"]
    ds = ds.remove_columns(list(set(ds.features.keys()) - {AUDIO_COLUMN_NAME, TEXT_COLUMN_NAME}))
    return ds

common_voice["train"] = normalize_dataset(common_voice["train"])
fleurs["train"] = normalize_dataset(fleurs["train"], text_column_name="raw_transcription")

common_voice["test"] = normalize_dataset(common_voice["test"])
fleurs["test"] = normalize_dataset(fleurs["test"], text_column_name="raw_transcription")

data = DatasetDict()
data["train"] = concatenate_datasets([common_voice["train"], fleurs["train"]])
# NB: shuffle concatenated dataset
data["train"] = data["train"].shuffle(seed=10)


data["test"] = concatenate_datasets([common_voice["test"], fleurs["test"]])
# NB: shuffle concatenated dataset
data["test"] = data["test"].shuffle(seed=10)

The `cast_column` call does not resample anything immediately; each audio sample is
resampled to the desired sampling rate when it is next loaded.

### Data Augmentation

In [48]:
from audiomentations import (
    AddBackgroundNoise,
    Compose,
    Gain,
    PitchShift,
    PolarityInversion,
    TimeStretch,
)

# define augmentation
augmentation = Compose(
    [
        TimeStretch(min_rate=0.9, max_rate=1.1, p=0.2, leave_length_unchanged=False),
        Gain(min_gain_in_db=-4, max_gain_in_db=4, p=0.1),
        PitchShift(min_semitones=-2, max_semitones=2, p=0.2),
    ]
)


def augment_dataset(batch):
    # load and (possibly) resample audio data to 16kHz
    sample = batch["audio"]

    # apply augmentation
    augmented_waveform = augmentation(sample["array"], sample_rate=sample["sampling_rate"])
    batch["audio"]["array"] = augmented_waveform
    return batch

# augment training data
augmented_raw_training_dataset = data["train"].map(
    augment_dataset, num_proc=2, desc="augment train dataset"
)

# combine
data["train"] = concatenate_datasets([data["train"], augmented_raw_training_dataset])
data["train"] = data["train"].shuffle(seed=10)

augment train dataset (num_proc=2):   0%|          | 0/15075 [00:00<?, ? examples/s]
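
It is worth estimating how many of the augmented copies actually differ from their originals. The sketch below assumes, as we understand `Compose` to behave, that each transform is applied independently with its own probability `p`; `probability_any_applied` is a hypothetical helper, and the probabilities are copied from the pipeline defined above.

```python
# Probabilities copied from the Compose(...) definition in the cell above.
TRANSFORM_PROBABILITIES = {"TimeStretch": 0.2, "Gain": 0.1, "PitchShift": 0.2}


def probability_any_applied(probabilities):
    """P(at least one transform fires), assuming independent decisions."""
    p_none = 1.0
    for p in probabilities:
        p_none *= 1.0 - p
    return 1.0 - p_none


p_changed = probability_any_applied(TRANSFORM_PROBABILITIES.values())
print(f"modified: {p_changed:.3f}, unchanged: {1 - p_changed:.3f}")
```

Under that independence assumption, a little over half of the augmented copies would be identical to their originals, so concatenating them mostly duplicates existing training examples. Raising the `p` values would increase the share of genuinely new examples.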



In [49]:
import random

# randrange excludes the upper bound, so the index is always valid
print(data["train"][random.randrange(len(data["train"]))])
print(data)

{'audio': {'path': None, 'array': array([ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
        0.00000000e+00,  0.00000000e+00, -3.05175781e-05]), 'sampling_rate': 16000}, 'sentence': 'Han har lovat att vara lat.'}
DatasetDict({
    train: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 30150
    })
    test: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 5828
    })
})


### Transform Data
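This is the main data processing step that creates the features.
A batched variant, which can be a bit faster than the per-example version from the tutorial, is kept below as commented-out code; the run shown here maps one example at a time.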
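
The difference between the two modes is the shape of the argument: without `batched=True`, `map` passes one example as a dict of values, while with `batched=True` it passes a dict of column lists. The sketch below shows one way to lift a per-example function to the batched format. `make_batched` and `toy_prepare` are hypothetical stand-ins used only for illustration; they are not part of `datasets`.

```python
def make_batched(per_example_fn):
    """Wrap a per-example function so it accepts a dict of column lists."""

    def batched_fn(batch):
        columns = list(batch)
        num_rows = len(batch[columns[0]])
        # Rebuild each row as a single-example dict and process it.
        outputs = [
            per_example_fn({column: batch[column][i] for column in columns})
            for i in range(num_rows)
        ]
        # Transpose the list of output dicts back into a dict of lists.
        return {key: [output[key] for output in outputs] for key in outputs[0]}

    return batched_fn


def toy_prepare(example):
    # Stand-in for prepare_dataset: duration in seconds and a naive "tokenization".
    return {
        "input_length": len(example["audio"]) / 16_000,
        "labels": example["sentence"].split(),
    }


batch = {
    "audio": [[0.0] * 16_000, [0.0] * 32_000],
    "sentence": ["Han har lovat", "att vara lat"],
}
result = make_batched(toy_prepare)(batch)
print(result["input_length"])
print(result["labels"])
```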

In [50]:
def prepare_dataset(batch):
    # load the audio; the earlier cast_column call resamples it to 16kHz on access
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # compute input length of audio sample in seconds
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

#def prepare_dataset(batch):
    # compute log-Mel input features from input audio array
    #batch["input_features"] = [feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0] for audio in batch["audio"]]

    # encode target text to label ids
    #batch["labels"] = [tokenizer(sentence).input_ids for sentence in batch["sentence"]]
    #return batch

#common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=2, batched=True, batch_size=128)
prepared_data = data.map(prepare_dataset, remove_columns=data.column_names["train"], num_proc=2)

Map (num_proc=2):   0%|          | 0/30150 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/5828 [00:00<?, ? examples/s]

In [51]:
import random

print(prepared_data["train"][random.randrange(len(prepared_data["train"]))])




In [52]:
print(prepared_data["test"][random.randrange(len(prepared_data["test"]))])
print(prepared_data)




### Remove long audio

Whisper processes audio in chunks of at most 30s, so we drop every example whose audio is 30s or longer.

In [53]:
max_input_length = 30

def is_audio_in_length_range(length):
    return length < max_input_length

filtered_data = prepared_data.filter(
    is_audio_in_length_range, num_proc=2, input_columns=["input_length"]
)
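# NB: remove_columns returns a new DatasetDict. The result is only displayed here,
# not assigned, so filtered_data itself keeps the "input_length" column.
filtered_data.remove_columns("input_length")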

Filter (num_proc=2):   0%|          | 0/30150 [00:00<?, ? examples/s]

Filter (num_proc=2):   0%|          | 0/5828 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 30125
    })
    test: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 5827
    })
})

### Remove text that is too long

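Whisper's decoder can only handle label sequences up to a maximum length (448 tokens, `model.config.max_target_positions`), so we drop examples with longer labels.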

In [54]:
max_label_length = 448 #model.config.max_target_positions

def is_labels_in_length_range(labels):
    return len(labels) < max_label_length

filtered_data = filtered_data.filter(
    is_labels_in_length_range, num_proc=2, input_columns=["labels"]
)

Filter (num_proc=2):   0%|          | 0/30125 [00:00<?, ? examples/s]

Filter (num_proc=2):   0%|          | 0/5827 [00:00<?, ? examples/s]

In [55]:
print(filtered_data)

DatasetDict({
    train: Dataset({
        features: ['input_features', 'input_length', 'labels'],
        num_rows: 30125
    })
    test: Dataset({
        features: ['input_features', 'input_length', 'labels'],
        num_rows: 5827
    })
})


## Save Features in Google Drive

In [56]:
# save features to drive
drive_features_path = f"/content/drive/MyDrive/ID2223/lab2/{MODEL_VERSION}/{LANG_NAME}/features/{MODEL_SIZE}/"
filtered_data.save_to_disk(drive_features_path, max_shard_size="1GB")

Saving the dataset (0/29 shards):   0%|          | 0/30125 [00:00<?, ? examples/s]

Saving the dataset (0/6 shards):   0%|          | 0/5827 [00:00<?, ? examples/s]
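
The shard counts reported above can be roughly reproduced from the feature shape. This is only an estimate under stated assumptions: each example stores an 80 x 3000 spectrogram as 32-bit floats, "1GB" means 10^9 bytes, and the much smaller `labels` and `input_length` columns are ignored. `estimated_shards` is a hypothetical helper.

```python
import math

N_MELS = 80                     # assumed Mel bins per frame
N_FRAMES = 3000                 # assumed frames per 30s window
BYTES_PER_VALUE = 4             # assumed float32 storage
SHARD_BYTES = 1_000_000_000     # max_shard_size="1GB", assumed decimal


def estimated_shards(num_rows):
    """Estimate the shard count from the size of the input_features column."""
    feature_bytes = num_rows * N_MELS * N_FRAMES * BYTES_PER_VALUE
    return math.ceil(feature_bytes / SHARD_BYTES)


print(estimated_shards(30_125), estimated_shards(5_827))
```

Each example takes roughly 1MB, so the train split needs about 29GB of Drive space and the test split about 6GB, which is worth checking before saving.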

In [None]:
test_v1 = "9d72e2e08c8a3c47"
train_v1 = "7564f9602fee7770"
test_v2 = "f01f52aef4fa5fba"
train_v2 = "0a0def596547b5a1"

# load features from drive
from datasets import load_from_disk

drive_features_path = f"/content/drive/MyDrive/ID2223/lab2/v1/swedish/features/{MODEL_SIZE}/"
common_voice_v1 = load_from_disk(drive_features_path)






In [None]:
print(common_voice["train"][12000]["input_features"][16])
print(common_voice_v1["train"][12000]["input_features"][16])
print(common_voice["train"][12000]["labels"])
print(common_voice_v1["train"][12000]["labels"])

[-0.7454781532287598, -0.7454781532287598, -0.7454781532287598, -0.47161686420440674, -0.4697835445404053, -0.5579724311828613, -0.5971349477767944, -0.5251418352127075, -0.4510633945465088, -0.418179988861084, -0.5701056718826294, -0.442257285118103, -0.493937611579895, -0.3336390256881714, -0.5032496452331543, -0.48250865936279297, -0.47455036640167236, -0.6313267946243286, -0.4681462049484253, -0.48624396324157715, -0.5473730564117432, -0.6280766725540161, -0.5531842708587646, -0.5454846620559692, -0.48922228813171387, -0.5870174169540405, -0.49881649017333984, -0.5354394912719727, -0.42693662643432617, -0.554484486579895, -0.5857061147689819, -0.4437748193740845, -0.35414111614227295, -0.45352840423583984, -0.6511633396148682, -0.36377596855163574, -0.7454781532287598, -0.5811154842376709, -0.4593040943145752, -0.5245180130004883, -0.3889038562774658, -0.3174154758453369, -0.22594690322875977, -0.3608362674713135, -0.48156237602233887, -0.3279153108596802, -0.23989784717559814, -0.