# Fine-tuning Whisper

This notebook fine-tunes Whisper on French-language speech to improve its transcription accuracy. It was the first step towards generating the `-fr` models included in the GitHub release: after a fine-tuned model was created with this notebook (or a similar, slightly modified one), the `-fr` models were processed with [whisper_more_efficient_encoding](./whisper_more_efficient_encoding.ipynb) to improve performance on shorter audio segments.

**Note**: See [whisper_more_efficient_encoding](./whisper_more_efficient_encoding.ipynb) for logic that converts models to a Joplin-compatible format.

This notebook roughly follows [this blog post](https://huggingface.co/blog/fine-tune-whisper).

**Goal**: Fine-tune `whisper-base` to have medium to high performance on French-language input *without* timestamps.

In [1]:
!pip install --upgrade pip
# jiwer is used for the word error rate (WER) metric
!pip install --upgrade datasets[audio] transformers evaluate jiwer

Collecting pip
  Downloading pip-25.0.1-py3-none-any.whl.metadata (3.7 kB)
Downloading pip-25.0.1-py3-none-any.whl (1.8 MB)
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.0.1
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting jiwer
  Downloading jiwer-3.1.0-py3-none-any.whl.metadata (2.6 kB)
Collecting datasets[audio]
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets[audio])
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets[audio])
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (

In [2]:
!pip install pyspellchecker==0.8.1

Collecting pyspellchecker==0.8.1
  Downloading pyspellchecker-0.8.1-py3-none-any.whl.metadata (9.4 kB)
Downloading pyspellchecker-0.8.1-py3-none-any.whl (6.8 MB)
Installing collected packages: pyspellchecker
Successfully installed pyspellchecker-0.8.1


In [3]:
import wandb
# See https://discuss.huggingface.co/t/how-to-turn-wandb-off-in-trainer/6237/10
wandb.init(mode='disabled')

In [4]:
from pathlib import Path

checkpoint_remote_path = Path('./final-checkpoints').resolve()
def connect_to_google_drive():
    """ Connects to Google Drive and configures the notebook to upload final
        checkpoints. """
    from google.colab import drive

    drive.mount('/content/drive')
    return Path('/content/drive/My Drive') / 'whisper' / 'checkpoints'

# Optional:
#checkpoint_remote_path = connect_to_google_drive()

In [5]:
if not checkpoint_remote_path.parent.exists():
    checkpoint_remote_path.parent.mkdir(parents=True)

In [6]:
checkpoint_path = Path('./whisper/checkpoints').resolve()

In [7]:
import shutil


## Load data

The [voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli/viewer/fr/train?f%5Braw_text%5D%5Bmin%5D=236&f%5Braw_text%5D%5Bmax%5D=354&f%5Braw_text%5D%5Btransform%5D=length&row=45) and CommonVoice datasets will be used to fine-tune Whisper.

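The datasets are loaded in streaming mode (`streaming=True`). Instead of downloading the full `train` and `test` splits up front, samples are downloaded and processed on demand as the data is iterated. This avoids a long initial download, but it means that data keeps being fetched over the network during training and evaluation.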
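
Because the loading code below passes `streaming=True`, `load_dataset` returns `IterableDataset` objects, and operations such as `filter` and `map` are applied lazily, one sample at a time, as the data is iterated. This is why those calls return almost immediately, while messages such as `Reading metadata...` only appear once `next(iter(...))` is called. The plain-Python generator sketch below is an analogy for this behavior, not the `datasets` implementation; `lazy_filter`, `lazy_map`, and `source` are hypothetical helpers.

```python
from typing import Callable, Iterable, Iterator

# Hypothetical helpers that mimic lazy filter/map on a stream (analogy only).
def lazy_filter(items: Iterable[str], keep: Callable[[str], bool]) -> Iterator[str]:
    for item in items:
        if keep(item):
            yield item

def lazy_map(items: Iterable[str], fn: Callable[[str], str]) -> Iterator[str]:
    for item in items:
        yield fn(item)

calls = []

def source() -> Iterator[str]:
    # Records each "download" so we can see when work actually happens
    for text in ['bonjour.', '', 'Merci.']:
        calls.append(text)
        yield text

pipeline = lazy_map(lazy_filter(source(), lambda t: t != ''), str.upper)
print(calls)         # []  (building the pipeline does no work yet)
first = next(pipeline)
print(first, calls)  # BONJOUR. ['bonjour.']  (only one item was fetched)
```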

In [8]:
from datasets import load_dataset, IterableDatasetDict, interleave_datasets

def load_dataset_from_id(dataset_id: str, language_code: str = 'fr'):
    data_raw = IterableDatasetDict()

    data_raw['train'] = load_dataset(dataset_id, language_code, split='train', streaming=True)
    print('Loaded training data. Loading test data:')
    data_raw['test'] = load_dataset(dataset_id, language_code, split='test', streaming=True)
    return data_raw


In [9]:
print('Loading Voxpopuli')
voxpopuli_data_raw = load_dataset_from_id('facebook/voxpopuli')\
    .rename_column('raw_text', 'text')\
    .select_columns(['audio', 'text'])


Loading Voxpopuli



The repository for facebook/voxpopuli contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/facebook/voxpopuli.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y
Loaded training data. Loading test data:


In [10]:
print('Loading CommonVoice...')
common_voice_data_raw = load_dataset_from_id('mozilla-foundation/common_voice_11_0')\
    .rename_column('sentence', 'text')\
    .select_columns(['audio', 'text'])

Loading CommonVoice...



The repository for mozilla-foundation/common_voice_11_0 contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/mozilla-foundation/common_voice_11_0.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y
Loaded training data. Loading test data:


We'll also include some data from FLEURS:

In [11]:
fleurs_data = load_dataset_from_id('google/fleurs', 'fr_fr')\
    .rename_column('raw_transcription', 'text')\
    .select_columns(['audio', 'text'])


The repository for google/fleurs contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/google/fleurs.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y
Loaded training data. Loading test data:


Let's start by inspecting a FLEURS sample:

In [12]:
test_data = next(iter(fleurs_data['train']))
print(test_data)

{'audio': {'path': None, 'array': array([ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
       -3.05175781e-05,  0.00000000e+00,  0.00000000e+00]), 'sampling_rate': 16000}, 'text': 'Quand la capsule rentrera dans l’atmosphère terrestre, vers 5\xa0heures du matin (heure de l’Est), elle offrira un spectacle lumineux spectaculaire aux habitants du nord de la Californie, de l’Oregon, du Nevada et de l’Utah.'}


In [13]:
from IPython.display import Audio as AudioDisplay
AudioDisplay(test_data['audio']['array'], rate=test_data['audio']['sampling_rate'])

### Postprocessing


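The GGML conversion script has trouble with some characters (for example, the combining acute accent `\u0301`), so these characters are replaced.

The following normalization function performs these replacements and also standardizes whitespace, capitalization, and final punctuation. It is used later, both when preparing the training labels and when computing evaluation metrics: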

In [14]:
# Normalize text
import unicodedata, re

punctuation_ending_exp = re.compile(r'^.*[\.\?!\)\]]$')
space_exp = re.compile(r'\s+')

# We'll use this later
def normalize_text(text: str):
    text = space_exp.sub(' ', text)
    text = text.strip()
    replacements = [
        ['’', '\''],
        ['‘', '\''],
        ['́a', 'á'], # Convert from two-character á to one-character á
        ['́u', 'ú'],
        ['́e', 'é'],
        ['̀e', 'è'],
        ['̀a', 'à'],
        ['\xa0', ' '], # Non-breaking spaces -> spaces
        # Some characters don't work with the GGML conversion script:
        ['œ', '[oe]'],
        ['́', '\''],
        ['̂', '\''],
        ['̀', '\''],
        ['—', '--'],
        ['…', '...'],
        ['の', ''],
    ]
    for [orig, replace] in replacements:
        text = text.replace(orig, replace)

    if len(text) == 0:
        return '[BLANK_AUDIO]'

    # Capitalize the first character (including single-character texts)
    text = text[0].upper() + text[1:]

    if not punctuation_ending_exp.match(text):
        text += '.'

    return text

def normalize_texts(texts):
    return [ normalize_text(text) for text in texts ]

assert normalize_text('  ') == '[BLANK_AUDIO]'
assert normalize_text('.\xa0\xa0.') == '. .'
assert normalize_text(' Test. ') == 'Test.'
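
For reference, the two-character accented forms that `normalize_text` handles manually are Unicode *decomposed* sequences: a base letter combined with an accent code point such as `\u0301`. Python's standard `unicodedata.normalize('NFC', ...)` composes such sequences into single code points. The sketch below only demonstrates this standard-library behavior; it is not what `normalize_text` does. Note that NFC expects the combining mark to come *after* the base letter, and that a stray combining mark is left untouched.

```python
import unicodedata

decomposed = 'e\u0301'  # 'e' followed by COMBINING ACUTE ACCENT: two code points
composed = unicodedata.normalize('NFC', decomposed)

print(len(decomposed), len(composed))  # 2 1
print(composed == '\u00e9')            # True: the single code point 'é'

# A combining mark with no preceding base letter is not changed by NFC,
# so a fallback replacement is still needed for such characters.
stray = unicodedata.normalize('NFC', '\u0301')
print(len(stray))                      # 1
```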

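We'll also remove bad samples: those with empty transcripts and those whose transcript starts with a lowercase letter (which suggests a sentence fragment rather than a full sentence):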

In [15]:
def should_keep_sample(text):
    trimmed_text = text.strip()
    if trimmed_text == '':
        return False
    # Exclude items that start with a lowercase letter (our model should operate on full sentences).
    return trimmed_text[0].upper() == trimmed_text[0]
def remove_bad_samples(dataset):
    return dataset.filter(should_keep_sample, input_columns=['text'])

voxpopuli_data = voxpopuli_data_raw
common_voice_data = common_voice_data_raw

common_voice_data = remove_bad_samples(common_voice_data)
voxpopuli_data = remove_bad_samples(voxpopuli_data)
fleurs_data = remove_bad_samples(fleurs_data)


Whisper expects a sampling rate of 16,000 Hz. Adjust the data so that it has this rate:

In [16]:
from datasets import Audio

audioFeature = Audio(sampling_rate=16_000)
def cast_audio(data):
    return data.cast_column('audio', audioFeature)

voxpopuli_data = cast_audio(voxpopuli_data)
fleurs_data = cast_audio(fleurs_data)
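
Casting the `audio` column to `Audio(sampling_rate=16_000)` makes `datasets` resample each clip to 16 kHz when it is decoded. As a rough illustration of what resampling means, the sketch below uses plain linear interpolation. This is a simplification chosen for readability and `resample_linear` is a hypothetical helper: a real resampler, such as the one `datasets` relies on, also applies low-pass filtering to avoid aliasing.

```python
def resample_linear(samples: list[float], src_rate: int, dst_rate: int) -> list[float]:
    """Naive linear-interpolation resampler (no anti-aliasing filter)."""
    if not samples:
        return []
    n_out = round(len(samples) * dst_rate / src_rate)
    out = []
    for i in range(n_out):
        # Position of output sample i on the input sample grid
        pos = i * src_rate / dst_rate
        left = min(int(pos), len(samples) - 1)
        right = min(left + 1, len(samples) - 1)
        frac = pos - left
        out.append(samples[left] * (1 - frac) + samples[right] * frac)
    return out

# One second of an 8 kHz ramp becomes 16,000 samples at 16 kHz
ramp = [i / 8000 for i in range(8000)]
upsampled = resample_linear(ramp, 8000, 16_000)
print(len(upsampled))  # 16000
```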

We can inspect a single item of the updated dataset using `next(iter(...))`:

In [19]:
# Inspect:
test_data_iterator = iter(voxpopuli_data['train'])
next(test_data_iterator)
next(test_data_iterator)

{'audio': {'path': None,
  'array': array([-0.00088501, -0.00619507, -0.00302124, ..., -0.00537109,
          0.00131226,  0.00106812]),
  'sampling_rate': 16000},
 'text': 'Cette régulation concerne les États et leurs économies, mais évidemment, les principaux acteurs économiques sont les entreprises et, en particulier, les entreprises multinationales, qui sont les principales actrices du commerce mondial.'}

The `voxpopuli` dataset has now been converted to a format that's ready for training. Now let's do something similar for the `commonvoice` dataset.

The processing for CommonVoice is a bit different: each CommonVoice sample contains a single sentence. Because we want the fine-tuned model to work with multiple sentences, we join neighboring sentences to create multi-sentence input:

In [20]:
from random import Random
import numpy as np
sentence_combine_random = Random(123456)

def combine_sentences(batch):
    audios = batch['audio']
    texts = batch['text']
    # See https://github.com/huggingface/datasets/issues/5361
    if len(audios) > 0:
        # Keep a random number (between 1 and len(audios)) of consecutive
        # sentences. Any remaining sentences in the batch are discarded.
        count = sentence_combine_random.randint(1, len(audios))
        audios = audios[0:count]
        texts = texts[0:count]
        joinedAudio = audioFeature.encode_example({
            'array': np.concatenate([ audio['array'] for audio in audios ]),
            'sampling_rate': audios[0]['sampling_rate']
        })
        batch['audio'] = [joinedAudio]
        batch['text'] = [ ' '.join(texts) ]
        return batch
    else:
        batch['audio'] = []
        batch['text'] = []
        return batch

def map_subdataset(key: str):
    common_voice_data[key] = common_voice_data[key].map(
        combine_sentences,
        batched=True,
        batch_size=3,
        # Pass features to allow casting audio later. See https://github.com/huggingface/datasets/issues/5828
        features=common_voice_data[key].features
    )
common_voice_data = cast_audio(common_voice_data)
map_subdataset('train')
map_subdataset('test')
common_voice_data

IterableDatasetDict({
    train: IterableDataset({
        features: ['audio', 'text'],
        num_shards: 13
    })
    test: IterableDataset({
        features: ['audio', 'text'],
        num_shards: 1
    })
})
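
The batching trick in `combine_sentences` can be summarized as follows: `map(..., batched=True, batch_size=3)` hands the function up to three consecutive samples, the function keeps a random-length prefix of them, and it returns a single merged sample. The stdlib sketch below mirrors that logic with plain lists standing in for audio arrays; `combine_batch` is a hypothetical helper, not part of the notebook.

```python
from random import Random

def combine_batch(audios: list[list[float]], texts: list[str], rng: Random):
    """Merge a random-length prefix of a batch into one sample (sketch)."""
    if not audios:
        return [], []
    count = rng.randint(1, len(audios))  # 1..len(audios), inclusive
    joined_audio = [x for audio in audios[:count] for x in audio]
    joined_text = ' '.join(texts[:count])
    return [joined_audio], [joined_text]

rng = Random(123456)
audios = [[0.1, 0.2], [0.3], [0.4, 0.5]]
texts = ['Bonjour.', 'Comment allez-vous ?', 'Très bien.']
merged_audio, merged_text = combine_batch(audios, texts, rng)
print(merged_text)
```

Because samples after the chosen prefix are dropped, this step also reduces the amount of CommonVoice data that reaches training.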

Now that the CommonVoice data is processed, let's generate additional training data. The model should be able to recognize audio that contains no speech, so we create a dataset of random noise (and some silence), labeled `[BLANK_AUDIO]`:

In [21]:
from datasets import Dataset
import numpy as np
from random import Random
np.random.seed(2)
# Seeded generator for reproducible amplitudes (random.SystemRandom would ignore the seed)
noise_random = Random(1234)

def build_noise_data():
    audios = []
    texts = []
    for length_seconds in range(1, 12):
        for j in range(0, 8):
            amplitude = noise_random.random() * 0.03
            if noise_random.random() < 0.1:
                amplitude = 0.0

            # Adjust the sample rate to change the frequencies in the noise
            sample_rate = 11_000 - j * 500
            audios.append({
                'array': (np.random.rand(length_seconds * sample_rate) - 0.5) * amplitude,
                'sampling_rate': sample_rate,
            })
            texts.append('[BLANK_AUDIO]')
    return Dataset.from_dict({
        'audio': audios,
        'text': texts,
    }).shuffle(seed=124)
noise_data_train = cast_audio(build_noise_data()).to_iterable_dataset()
noise_data_test = cast_audio(build_noise_data()).to_iterable_dataset()

example_noise_data = next(iter(noise_data_train))
AudioDisplay(example_noise_data['audio']['array'], rate=example_noise_data['audio']['sampling_rate'])
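
A detail worth knowing when generating noise like this: `random.SystemRandom` draws from the operating system's entropy source and ignores any seed, whereas a seeded `random.Random` produces the same sequence on every run. The sketch below generates uniform noise the same way `build_noise_data` does (`(rand - 0.5) * amplitude`, centered on zero) and shows the difference. It uses plain lists instead of NumPy, and `make_noise` is a hypothetical helper.

```python
from random import Random, SystemRandom

def make_noise(length_seconds: int, sample_rate: int, amplitude: float, rng: Random) -> list[float]:
    """Uniform noise roughly in [-amplitude/2, amplitude/2], like (rand - 0.5) * amplitude."""
    return [(rng.random() - 0.5) * amplitude for _ in range(length_seconds * sample_rate)]

# Two generators with the same seed produce identical noise...
a = make_noise(1, 8_000, 0.03, Random(1234))
b = make_noise(1, 8_000, 0.03, Random(1234))
print(a == b)  # True

# ...whereas SystemRandom ignores its seed, so these two runs are not reproducible.
c = make_noise(1, 8_000, 0.03, SystemRandom(1234))
d = make_noise(1, 8_000, 0.03, SystemRandom(1234))
```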

The next step is to interleave the datasets, drawing samples from each source with the given probabilities:

In [22]:

voice_data = IterableDatasetDict()

def interleaving_for(subset: str):
    """ Combines either the test or training sets """
    noise_data = noise_data_train
    if subset == 'test':
        noise_data = noise_data_test
    return interleave_datasets([
        voxpopuli_data[subset], common_voice_data[subset], fleurs_data[subset], noise_data,
    ], probabilities=[0.46, 0.40, 0.13, 0.01], stopping_strategy='all_exhausted')

voice_data['train'] = interleaving_for('train')
voice_data['test'] = interleaving_for('test')
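
`interleave_datasets` with `probabilities` picks the source of each next sample at random according to those weights. With `stopping_strategy='all_exhausted'`, a source that runs out is restarted (oversampled) until every source has been fully consumed at least once. The sketch below is a simplified pure-Python model of that behavior, not the `datasets` implementation; `interleave_all_exhausted` is a hypothetical helper.

```python
from random import Random
from typing import Iterator, Sequence, TypeVar

T = TypeVar('T')

def interleave_all_exhausted(sources: Sequence[Sequence[T]],
                             probabilities: Sequence[float],
                             seed: int) -> Iterator[T]:
    """Simplified model of weighted interleaving with stopping_strategy='all_exhausted'."""
    if any(p <= 0 for p in probabilities):
        raise ValueError('every source needs a positive probability')
    rng = Random(seed)
    positions = [0] * len(sources)
    exhausted = [len(source) == 0 for source in sources]
    while not all(exhausted):
        # Pick the source of the next sample according to the probabilities
        i = rng.choices(range(len(sources)), weights=probabilities)[0]
        if not sources[i]:
            continue
        yield sources[i][positions[i]]
        positions[i] += 1
        if positions[i] == len(sources[i]):
            exhausted[i] = True  # this source has now been fully seen once...
            positions[i] = 0     # ...and restarts from the beginning (oversampling)

mixed = list(interleave_all_exhausted(
    [['vox1', 'vox2', 'vox3'], ['cv1'], ['noise1']], [0.6, 0.3, 0.1], seed=0))
print(mixed)
```

One consequence of this strategy is that small, rarely sampled sources (such as the noise data here) can be repeated many times before the larger sources are exhausted.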

In [23]:
voice_data = voice_data\
    .remove_columns(['gender', 'normalized_text', 'accent', 'is_gold_transcript', 'audio_id', 'language'])

voice_data

IterableDatasetDict({
    train: IterableDataset({
        features: ['audio', 'text'],
        num_shards: 1
    })
    test: IterableDataset({
        features: ['audio', 'text'],
        num_shards: 1
    })
})

In [24]:
# Also remove bad data after postprocessing is finished
voice_data = remove_bad_samples(voice_data)

In [25]:
print(next(iter(voice_data['test'])))

Reading metadata...: 16089it [00:00, 34338.12it/s]


{'audio': {'path': None, 'array': array([ 1.15329749e-06, -1.28702959e-06,  1.41587225e-06, ...,
        2.56873318e-05,  4.29910142e-05, -5.58127649e-05]), 'sampling_rate': 16000}, 'text': "Ce dernier a évolué tout au long de l'histoire romaine. Son actionnaire majoritaire est le Conseil territorial de Saint-Pierre-et-Miquelon."}


This next part might not be necessary. In an attempt to save RAM, clear the now-unused dataset variables:

In [26]:
# Clear unused variables to save memory (important in Google Colab)
del voxpopuli_data_raw
del voxpopuli_data
del fleurs_data
del common_voice_data_raw
del common_voice_data
del noise_data_train
del noise_data_test
del test_data

In [27]:
import gc

gc.collect()

58784

## Inspecting a sample

Let's check that the expected columns are still present in the training data:

In [28]:
sample = next(iter(voice_data['train']))

Reading metadata...: 485034it [00:15, 31260.37it/s]


In [29]:
sample

{'audio': {'path': None,
  'array': array([-5.46788215e-09, -7.59428076e-10,  8.12633516e-09, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00]),
  'sampling_rate': 16000},
 'text': 'Il est dissous à Trèves.'}

## Create the feature extractor and tokenizer

We'll be fine-tuning the `openai/whisper-base` model. Here, the feature extractor and tokenizer for this model are fetched from Hugging Face:

In [30]:
from transformers import WhisperFeatureExtractor, WhisperTokenizer

finetune_from_id = 'openai/whisper-base'
# The feature extractor only converts audio to log-mel features; language and task are tokenizer settings
feature_extractor = WhisperFeatureExtractor.from_pretrained(finetune_from_id)
tokenizer = WhisperTokenizer.from_pretrained(finetune_from_id, language='french', task='transcribe')


## Create the processor

Next, load the `WhisperProcessor`, which combines a feature extractor and tokenizer.

In [31]:
from transformers import WhisperProcessor

processor = WhisperProcessor(feature_extractor, tokenizer)

Use the feature extractor to convert the data into a format suitable for the model:

In [32]:
def map_sample(batch):
    audio_data = batch['audio']['array']
    audio_sample_rate = batch['audio']['sampling_rate']
    features = processor.feature_extractor(audio_data, sampling_rate=audio_sample_rate)

    batch['input_features'] = features.input_features[0]
    batch['labels'] = processor.tokenizer(normalize_text(batch['text'])).input_ids
    return batch

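# Keep a reference to the unprocessed data (also used by run_on_sample_audio)
voice_data_original = voice_data
# Compute input features and labels, removing the columns that are no longer used
voice_data = voice_data.map(map_sample, remove_columns=['audio', 'text'])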
voice_data

IterableDatasetDict({
    train: IterableDataset({
        features: Unknown,
        num_shards: 1
    })
    test: IterableDataset({
        features: Unknown,
        num_shards: 1
    })
})
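
The shape of each `input_features` entry follows from the feature extractor's configuration. Assuming the standard Whisper defaults for `openai/whisper-base` (16 kHz audio, padded or truncated to 30-second windows, a hop length of 160 samples, 80 mel bins), each sample becomes an 80 × 3000 log-mel spectrogram. The arithmetic is sketched below; inspect `feature_extractor.hop_length`, `feature_extractor.chunk_length`, and `feature_extractor.feature_size` to confirm the values for your checkpoint.

```python
# Assumed WhisperFeatureExtractor defaults for openai/whisper-base
sampling_rate = 16_000   # samples per second
chunk_length = 30        # seconds; shorter audio is padded, longer audio is truncated
hop_length = 160         # samples between successive spectrogram frames (10 ms)
feature_size = 80        # mel bins

n_samples = chunk_length * sampling_rate   # 480,000 samples per window
n_frames = n_samples // hop_length         # 3,000 frames

print((feature_size, n_frames))            # (80, 3000)
```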

In [33]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(finetune_from_id)
model.generation_config.language = 'french'
model.generation_config.task = 'transcribe'
model.generation_config.forced_decoder_ids = None



In [34]:
from dataclasses import dataclass
from typing import Any
import torch
# See the linked blog post and https://huggingface.co/docs/transformers/main_classes/data_collator

@dataclass
class DataCollatorWithPadding:
    ''' Converts raw data into a batch ready for the model '''
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: list) -> dict[str, torch.Tensor]:
        input_features = [{'input_features': f['input_features']} for f in features]
        label_features = [{'input_ids': f['labels']} for f in features]

        # According to the linked blog post, the input and label features need
        # to be padded separately (due to different final lengths), then
        # recombined:
        batch = self.processor.feature_extractor.pad(input_features, return_tensors='pt')

        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors='pt')

        # transformers uses -100 for masking
        labels = labels_batch['input_ids'].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Don't double-prepend the beginning of sequence token:
        if (labels[:,0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch['labels'] = labels
        return batch

data_collator = DataCollatorWithPadding(processor=processor, decoder_start_token_id=model.config.decoder_start_token_id)
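
To make the label handling in `DataCollatorWithPadding` concrete, here is a plain-Python sketch of the same three steps for a batch of token-ID lists: pad to the longest sequence, replace padding positions with `-100` (the index ignored by the loss), and drop a leading decoder-start token when every sequence has one, since the model prepends that token itself when it shifts the labels right to build the decoder input. `collate_labels` is a hypothetical helper and the token IDs are made up for illustration.

```python
def collate_labels(label_ids: list[list[int]], start_token_id: int, pad_id: int = 0) -> list[list[int]]:
    """Pad label sequences, mask padding with -100, and strip a shared start token."""
    max_len = max(len(ids) for ids in label_ids)
    padded = [ids + [pad_id] * (max_len - len(ids)) for ids in label_ids]
    # Attention-mask equivalent: 1 for real tokens, 0 for padding
    masks = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in label_ids]
    labels = [[tok if m == 1 else -100 for tok, m in zip(row, mask)]
              for row, mask in zip(padded, masks)]
    if all(row[0] == start_token_id for row in labels):
        labels = [row[1:] for row in labels]
    return labels

# 50258 stands in for the decoder start token; the other IDs are arbitrary
batch = [[50258, 11, 12, 13], [50258, 21]]
print(collate_labels(batch, start_token_id=50258))
# [[11, 12, 13], [21, -100, -100]]
```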

## Viewing sample data

Let's look at a sample from the test data:

In [35]:
sample_data = next(iter(voice_data['test']))
sample_labels = sample_data['labels']

Reading metadata...: 16089it [00:00, 37863.36it/s]


In [36]:
processor.decode(sample_labels)

"<|startoftranscript|><|fr|><|transcribe|><|notimestamps|>Ce dernier a évolué tout au long de l'histoire romaine. Son actionnaire majoritaire est le Conseil territorial de Saint-Pierre-et-Miquelon. Ce site contient quatre tombeaux de la dynastie achéménide et sept des Sassanides.<|endoftext|>"

In [37]:
def run_on_sample_audio():
    """ Returns the decoded text (as a one-item list) of running the model on a single test audio sample. """
    sample_audio = next(iter(voice_data_original['test']))['audio']
    inputs = processor(sample_audio['array'], sampling_rate=sample_audio['sampling_rate'], return_tensors='pt')
    # Move the input features to the same device as the model (CPU or GPU)
    generated_ids = model.generate(inputs=inputs.input_features.to(model.device))
    return processor.batch_decode(generated_ids)

In [38]:
print(run_on_sample_audio())

Reading metadata...: 16089it [00:00, 38185.47it/s]
It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, 50259], [2, 50359], [3, 50363]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


[" Ce dernier évolué tout au long de l'histoire romaine."]


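## Preparing an evaluation function

`compute_metrics` decodes the model's predictions and the reference labels, normalizes both with `normalize_text`, and computes the word error rate (WER) and character error rate (CER):
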
In [39]:
import evaluate

wer_metric = evaluate.load('wer')
cer_metric = evaluate.load('cer')

def compute_metrics(data):
    true_labels = data.label_ids
    predictions = data.predictions

    # Convert padding from HF
    true_labels[true_labels == -100] = processor.tokenizer.pad_token_id

    predicted_text = processor.batch_decode(predictions, skip_special_tokens=True)
    label_text = processor.batch_decode(true_labels, skip_special_tokens=True)

    # Avoid empty labels/predictions (which can prevent wer_metric and cer_metric
    # from working).
    def fix_prediction(pred):
        if len(pred) == 0:
            return "[empty]"
        else:
            return normalize_text(pred)

    predicted_text = [ fix_prediction(prediction) for prediction in predicted_text ]
    label_text = [ fix_prediction(label) for label in label_text ]

    try:
        wer = wer_metric.compute(predictions=predicted_text, references=label_text)
        cer = cer_metric.compute(predictions=predicted_text, references=label_text)
    except ValueError as err: # E.g. "One or more references are empty strings"
        print('WARNING: While computing WER and CER:', err, label_text)
        wer = 100 # Use a large value in this case
        cer = 100
    return { 'wer': wer, 'cer': cer }


Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/5.60k [00:00<?, ?B/s]
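
The `wer` metric is the word-level edit distance between prediction and reference (substitutions + deletions + insertions) divided by the number of reference words; `cer` is the same quantity at the character level. A minimal pure-Python sketch of the word-level version for a single sentence pair is shown below. It illustrates the definition only: `evaluate` aggregates over the whole list of sentences (summing errors across the corpus), which this hypothetical `word_error_rate` helper does not do.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """(substitutions + deletions + insertions) / number of reference words.

    Assumes a non-empty reference (compute_metrics replaces empty strings).
    """
    ref = reference.split()
    hyp = hypothesis.split()
    # dist[i][j] = edit distance between ref[:i] and hyp[:j]
    dist = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dist[i][0] = i
    for j in range(len(hyp) + 1):
        dist[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dist[i][j] = min(dist[i - 1][j] + 1,         # deletion
                             dist[i][j - 1] + 1,         # insertion
                             dist[i - 1][j - 1] + cost)  # substitution or match
    return dist[len(ref)][len(hyp)] / len(ref)

ref = "Ce dernier a évolué tout au long de l'histoire romaine."
hyp = "Ce dernier évolué tout au long de l'histoire romaine."
print(word_error_rate(ref, hyp))  # 0.1 (one deleted word out of ten)
```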

## Preparing training arguments

In [40]:
from transformers import Seq2SeqTrainingArguments

# TODO: Update this if you're planning to push the custom model to
# huggingface (ignore otherwise):
hub_model_id = 'personalizedrefrigerator/whisper-base-fr'

def make_training_args(max_steps: int, learning_rate: float = 1e-4):
    return Seq2SeqTrainingArguments(
        output_dir = checkpoint_path,
        hub_model_id=hub_model_id,
        learning_rate=learning_rate,
        max_steps=max_steps,
        # gradient_checkpointing=True,
        logging_first_step=True,
        fp16=True,
        eval_strategy='steps',
        per_device_eval_batch_size=8,
        generation_max_length=256,
        predict_with_generate=True,
        #auto_find_batch_size = True,
        save_steps=5000,
        eval_steps=1000,
        logging_steps=25,
        save_total_limit=1,
    )
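
None of these arguments set a learning-rate schedule, so the `transformers` defaults apply. Assuming those defaults (`lr_scheduler_type='linear'`, `warmup_steps=0`), the learning rate decays linearly from its initial value to zero at `max_steps`. The sketch below computes that schedule; `linear_lr` is a hypothetical helper, not a `transformers` function.

```python
def linear_lr(step: int, base_lr: float, max_steps: int, warmup_steps: int = 0) -> float:
    """Learning rate at `step` for linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, max_steps - step)
    return base_lr * remaining / max(1, max_steps - warmup_steps)

# Base learning rate of 1e-4 (make_training_args's default), 33,000 steps, no warmup
for step in (0, 5_000, 16_500, 33_000):
    print(f'step {step:>6}: lr = {linear_lr(step, 1e-4, 33_000):.2e}')
```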

In [41]:
small_eval_dataset = voice_data['test'].shuffle(seed=11).take(128)
large_eval_dataset = voice_data['test'].shuffle(seed=14).take(512)

In [42]:
shuffled_train_dataset = voice_data['train'].shuffle(seed=15)
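
Because these are streaming (`IterableDataset`) splits, `.shuffle(seed=...)` cannot permute the whole dataset. Instead, it fills a fixed-size buffer (1,000 samples by default, set with `buffer_size`) and yields random elements from it, refilling as it goes; `datasets` also shuffles the order of the underlying shards. The sketch below models only the buffer part in plain Python; `buffer_shuffle` is a hypothetical helper.

```python
from random import Random
from typing import Iterable, Iterator, TypeVar

T = TypeVar('T')

def buffer_shuffle(items: Iterable[T], buffer_size: int, seed: int) -> Iterator[T]:
    """Approximate shuffle of a stream using a fixed-size buffer."""
    rng = Random(seed)
    buffer: list[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        i = rng.randrange(buffer_size)
        yield buffer[i]          # emit a random buffered element...
        buffer[i] = item         # ...and replace it with the incoming one
    rng.shuffle(buffer)          # flush the remaining elements in random order
    yield from buffer

out = list(buffer_shuffle(range(10), buffer_size=3, seed=15))
print(sorted(out) == list(range(10)))  # True: every element appears exactly once
```

One consequence is that `take(n)` after a buffered shuffle only draws from roughly the first `buffer_size + n` samples of the stream.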

In [45]:
from transformers import Seq2SeqTrainer

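def make_trainer(max_steps: int = 33_000, learning_rate: float = 1e-4):
    return Seq2SeqTrainer(
        # Forward learning_rate so that the argument takes effect. The default matches
        # make_training_args's default, which is the rate used by the training run below.
        args=make_training_args(max_steps, learning_rate),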
        model=model,
        train_dataset=shuffled_train_dataset,
        eval_dataset=small_eval_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        processing_class=processor.feature_extractor,
    )

trainer = make_trainer()

## Training and evaluation

In [46]:
trainer.evaluate(large_eval_dataset)

Reading metadata...: 16089it [00:00, 37235.05it/s]


{'eval_loss': 1.0241453647613525,
 'eval_model_preparation_time': 0.0031,
 'eval_wer': 0.43377164849262345,
 'eval_cer': 0.2355128760284266,
 'eval_runtime': 229.7313,
 'eval_samples_per_second': 2.229,
 'eval_steps_per_second': 0.279}

In [47]:
trainer.train()

Reading metadata...: 485034it [00:12, 38652.53it/s]


| Step | Training loss | Validation loss | Model preparation time | WER | CER |
|---|---|---|---|---|---|
| 1000 | 0.7687 | 1.001273 | 0.0031 | 0.439784 | 0.261968 |
| 2000 | 0.7665 | 0.965672 | 0.0031 | 0.478138 | 0.320303 |
| 3000 | 0.6699 | 0.888818 | 0.0031 | 0.4 | 0.203194 |
| 4000 | 0.5928 | 0.885811 | 0.0031 | 0.378223 | 0.187056 |
| 5000 | 0.5833 | 0.828014 | 0.0031 | 0.351428 | 0.171162 |
| 6000 | 0.4953 | 0.82157 | 0.0031 | 0.374325 | 0.188178 |
| 7000 | 0.4421 | 0.78262 | 0.0031 | 0.375914 | 0.21336 |
| 8000 | 0.4458 | 0.767467 | 0.0031 | 0.407455 | 0.208896 |
| 9000 | 0.4296 | 0.782993 | 0.0031 | 0.357835 | 0.177699 |
| 10000 | 0.442 | 0.774285 | 0.0031 | 0.349457 | 0.175729 |


'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 2c7d68e8-9fa0-4c4f-b2a0-51562a7cee4d)')' thrown while requesting GET https://huggingface.co/datasets/google/fleurs/resolve/d7c758a6dceecd54a98cac43404d3d576e721f07/data/fr_fr/audio/train.tar.gz
Retrying in 1s [Retry 1/5].
Reading metadata...: 16089it [00:00, 36104.19it/s]
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: add54613-068f-4365-8e18-51890dca729e)')' thrown while requesting GET https://huggingface.co/datasets/facebook/voxpopuli/resolve/719aaef8225945c0d80b277de6c79aa42ab053d5/data/fr/train/train_part_1.tar.gz
Retrying in 1s [Retry 1/5].
Reading metadata...: 16089it [00:00, 30988.15it/s]
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 0d4e4066-c436-45f0-a9fd-6409c34c3ec8)')' thrown while requesting GET h

TrainOutput(global_step=33000, training_loss=0.41040474846146324, metrics={'train_runtime': 31176.0953, 'train_samples_per_second': 8.468, 'train_steps_per_second': 1.059, 'total_flos': 1.712303898624e+19, 'train_loss': 0.41040474846146324, 'epoch': 1.0})

In [48]:
if checkpoint_remote_path.exists():
    shutil.rmtree(checkpoint_remote_path)
shutil.copytree(checkpoint_path, checkpoint_remote_path)

PosixPath('/content/final-checkpoints')

In [49]:
trainer.evaluate(large_eval_dataset)

Reading metadata...: 16089it [00:00, 26066.31it/s]
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 146fdd79-f157-460e-b455-847d4b262dc9)')' thrown while requesting GET https://huggingface.co/datasets/facebook/voxpopuli/resolve/719aaef8225945c0d80b277de6c79aa42ab053d5/data/fr/test/test_part_0.tar.gz
Retrying in 1s [Retry 1/5].


{'eval_loss': 0.5147430896759033,
 'eval_model_preparation_time': 0.0031,
 'eval_wer': 0.2426566579634465,
 'eval_cer': 0.11742522419781246,
 'eval_runtime': 262.37,
 'eval_samples_per_second': 1.951,
 'eval_steps_per_second': 0.244,
 'epoch': 1.0}
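
Both this evaluation and the one run before training use `large_eval_dataset`, so the two results can be compared directly. The relative reductions, computed from the reported values, are roughly 44% for WER and 50% for CER:

```python
# Values copied from the two trainer.evaluate(large_eval_dataset) outputs above
before = {'wer': 0.43377164849262345, 'cer': 0.2355128760284266}
after = {'wer': 0.2426566579634465, 'cer': 0.11742522419781246}

reductions = {metric: (before[metric] - after[metric]) / before[metric] for metric in before}
for metric, reduction in reductions.items():
    print(f'{metric}: {before[metric]:.3f} -> {after[metric]:.3f} '
          f'({reduction:.1%} relative reduction)')
```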

In [50]:
model_output_dir = Path('./final-model').resolve()
trainer.save_model(model_output_dir)
tokenizer.save_pretrained(model_output_dir)

('/content/final-model/tokenizer_config.json',
 '/content/final-model/special_tokens_map.json',
 '/content/final-model/vocab.json',
 '/content/final-model/merges.txt',
 '/content/final-model/normalizer.json',
 '/content/final-model/added_tokens.json')

In [51]:
print(run_on_sample_audio())

Reading metadata...: 16089it [00:00, 28045.54it/s]
It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


["Ce dernier évolait tout au long de l'histoire romaine. Son actionnaire majoritaire est le conseil d'un tel de Saint-Pierre-et-Miculon. Ce site contient quatre tombeaux de la dynastie, Hacheménide et sept et safangnide."]


# Model conversion

Next, we need to convert the model into a format usable by Joplin. This step converts the model from the Hugging Face (PyTorch) format to GGML, the format used by whisper.cpp.

In [52]:
!git clone https://github.com/openai/whisper whisper-github
!git clone https://github.com/ggerganov/whisper.cpp
!cd whisper.cpp && git checkout v1.7.4

Cloning into 'whisper-github'...
remote: Enumerating objects: 828, done.
remote: Counting objects: 100% (370/370), done.
remote: Compressing objects: 100% (69/69), done.
remote: Total 828 (delta 333), reused 301 (delta 301), pack-reused 458 (from 2)
Receiving objects: 100% (828/828), 8.26 MiB | 15.83 MiB/s, done.
Resolving deltas: 100% (496/496), done.
Cloning into 'whisper.cpp'...
remote: Enumerating objects: 16251, done.
remote: Counting objects: 100% (2938/2938), done.
remote: Compressing objects: 100% (619/619), done.
remote: Total 16251 (delta 2490), reused 2319 (delta 2319), pack-reused 13313 (from 5)
Receiving objects: 100% (16251/16251), 20.09 MiB | 12.20 MiB/s, done.
Resolving deltas: 100% (11166/11166), done.
Note: switching to 'v1.7.4'.

You are in 'detached HEAD' state. You can look around, make experimental
changes and commit them, and you can discard any commits you make in this
state without impacting any branches by switching back to a branch.

I

In [53]:
# Patch convert-h5-to-ggml to work with more recent model versions
conversion_script_path = Path('whisper.cpp/models/convert-h5-to-ggml.py')
conversion_script_content = conversion_script_path.read_text()
with open(conversion_script_path, 'w') as conversion_script:
    bad_if_statement = 'if "max_length" not in hparams:'
    replaced_if_statement = 'if "max_length" not in hparams or hparams["max_length"] == None:'
    conversion_script.write(conversion_script_content.replace(bad_if_statement, replaced_if_statement))

In [54]:
!mkdir -p ./ggml
!python whisper.cpp/models/convert-h5-to-ggml.py ./final-model ./whisper-github ./ggml

2025-03-21 02:25:28.959765: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742523928.981887  137249 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742523928.991911  137249 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
model.encoder.conv1.weight  ->  encoder.conv1.weight
encoder.conv1.weight 3 (512, 80, 3)
model.encoder.conv1.bias  ->  encoder.conv1.bias
  Reshaped variable:  encoder.conv1.bias  to shape:  (512, 1)
encoder.conv1.bias 2 (512, 1)
  Converting to float32
model.encoder.conv2.weight  ->  encoder.conv2.weight
encoder.conv2.weight 3 (512, 512, 3)
model.encoder.conv2.bias  ->  encoder.conv2.bias
  Reshaped variable:  encoder.conv2.bias  to

For smaller size and better performance, we can also quantize the GGML model:
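
As a rough picture of what the `q8_0` option used below means: GGML's `q8_0` format splits each weight tensor into blocks of 32 values and stores, per block, one scale (`max|x| / 127`, as a 16-bit float) plus 32 signed 8-bit integers, i.e. 34 bytes per 32 weights. The sketch below follows that idea in plain Python, keeping the scale as a Python float, to show the round trip and its worst-case error. It is an illustration, not the whisper.cpp code, and the helper names are hypothetical.

```python
BLOCK_SIZE = 32  # values per q8_0 block

def quantize_block_q8_0(values: list[float]) -> tuple[float, list[int]]:
    """Symmetric 8-bit quantization of one block: scale = max|x| / 127."""
    amax = max(abs(v) for v in values)
    scale = amax / 127
    inv = 1 / scale if scale else 0.0
    # Python's round() rounds halves to even (C's roundf rounds them away from zero)
    return scale, [round(v * inv) for v in values]

def dequantize_block_q8_0(scale: float, quants: list[int]) -> list[float]:
    return [q * scale for q in quants]

weights = [((i * 37) % 64 - 32) / 100 for i in range(BLOCK_SIZE)]  # arbitrary test data
scale, quants = quantize_block_q8_0(weights)
restored = dequantize_block_q8_0(scale, quants)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(f'scale={scale:.5f}, max error={max_error:.5f} (at most scale / 2)')
```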

In [55]:
!cd whisper.cpp && cmake -B build && cmake --build build --config Release
!./whisper.cpp/build/bin/quantize ./ggml/ggml-model.bin ./ggml/ggml-model-q8_0.bin q8_0
!./whisper.cpp/build/bin/quantize ./ggml/ggml-model.bin ./ggml/ggml-model-q5_0.bin q5_0

  Compatibility with CMake < 3.10 will be removed from a future version of
  CMake.

  Update the VERSION argument <min> value.  Or, use the <min>...<max> syntax
  to tell CMake that the project requires at least <min> but has been updated
  to work with policies introduced by <max> or earlier.

-- The C compiler identification is GNU 11.4.0
-- The CXX compiler identification is GNU 11.4.0
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /usr/bin/cc - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Found Git: /usr/bin/git (found version "2.34.1")
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Success
-- Found Threads: TRUE
-- CMAKE_SYSTEM_PROCES
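As a rough illustration of what `q8_0` means, here is a simplified NumPy sketch of block-wise symmetric 8-bit quantization: weights are split into blocks of 32, and each block stores one float16 scale (the block's largest magnitude divided by 127) plus 32 signed 8-bit integers. This follows the general idea of GGML's `q8_0` format, but it is not the whisper.cpp implementation; `quantize_q8_0_like` and `dequantize` are hypothetical helpers.

```python
import numpy as np

BLOCK_SIZE = 32  # weights per block; GGML's q8_0 also groups 32 weights


def quantize_q8_0_like(weights: np.ndarray):
    """Simplified q8_0-style quantization (hypothetical helper, not whisper.cpp code).

    `weights` must have a multiple of BLOCK_SIZE elements.
    Returns int8 values of shape (n_blocks, 32) and one float16 scale per block.
    """
    blocks = weights.astype(np.float32).reshape(-1, BLOCK_SIZE)
    # One scale per block: the largest magnitude in the block maps to 127
    amax = np.abs(blocks).max(axis=1, keepdims=True)
    scale = amax / 127.0
    # Guard all-zero blocks against division by zero
    inv_scale = np.divide(1.0, scale, out=np.zeros_like(scale), where=scale > 0)
    q = np.round(blocks * inv_scale).astype(np.int8)
    return q, scale.astype(np.float16)


def dequantize(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Recover approximate float32 weights from int8 values and per-block scales."""
    return (q.astype(np.float32) * scale.astype(np.float32)).reshape(-1)


# Storage per block of 32 weights: 32 int8 values + one 2-byte float16 scale,
# compared with 2 bytes per weight for float16
bytes_q8 = BLOCK_SIZE * 1 + 2
bytes_f16 = BLOCK_SIZE * 2
print(f"q8_0-style size relative to f16: {bytes_q8 / bytes_f16:.2f}")
```

Each reconstructed weight is off by at most about half a block's scale, which is why 8-bit quantization usually costs little accuracy; 5-bit formats such as `q5_0` shrink the file further at the cost of coarser steps.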

Now, let's make sure that the converted GGML model works. Start by downloading some test audio:

In [56]:
!mkdir ./test-audio
# Download chapter 4 of Alice in Wonderland (in French, from LibriVox)
!wget -P ./test-audio/ https://www.archive.org/download/alice_au_pays_des_merveilles_1811_librivox/aliceaupays_04_carroll_128kb.mp3
# Convert it to the format whisper.cpp expects:
# -t 30                 Only read the first 30 s of the input
# -i ...                Input path
# -ar 16000             Sample rate of 16,000 Hz
# -ac 1                 1 audio channel (mono)
# -codec:a pcm_s16le    Audio codec (16-bit little-endian PCM)
!ffmpeg -t 30 -i ./test-audio/aliceaupays_04_carroll_128kb.mp3 -ar 16000 -ac 1 -codec:a pcm_s16le ./test-audio/recording-fr-4.wav

--2025-03-21 02:27:14--  https://www.archive.org/download/alice_au_pays_des_merveilles_1811_librivox/aliceaupays_04_carroll_128kb.mp3
Resolving www.archive.org (www.archive.org)... 207.241.224.2
Connecting to www.archive.org (www.archive.org)|207.241.224.2|:443... connected.
HTTP request sent, awaiting response... 302 Moved Temporarily
Location: https://archive.org/download/alice_au_pays_des_merveilles_1811_librivox/aliceaupays_04_carroll_128kb.mp3 [following]
--2025-03-21 02:27:14--  https://archive.org/download/alice_au_pays_des_merveilles_1811_librivox/aliceaupays_04_carroll_128kb.mp3
Resolving archive.org (archive.org)... 207.241.224.2
Connecting to archive.org (archive.org)|207.241.224.2|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://ia803201.us.archive.org/25/items/alice_au_pays_des_merveilles_1811_librivox/aliceaupays_04_carroll_128kb.mp3 [following]
--2025-03-21 02:27:15--  https://ia803201.us.archive.org/25/items/alice_au_pays_des_mervei
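The whisper.cpp README recommends converting input to 16 kHz, mono, 16-bit PCM WAV, which is what the ffmpeg flags above produce. Before transcribing, a quick sanity check with Python's standard `wave` module can catch a bad conversion. `check_whisper_wav` is a hypothetical helper; the 30-second limit only reflects the `-t 30` trim used above, not a whisper.cpp requirement.

```python
import wave


def check_whisper_wav(path: str, max_seconds: float = 30.0) -> list[str]:
    """Return a list of problems; an empty list means the file looks usable."""
    problems = []
    with wave.open(path, "rb") as wav:
        rate = wav.getframerate()
        if rate != 16000:
            problems.append(f"sample rate is {rate} Hz, expected 16000")
        if wav.getnchannels() != 1:
            problems.append(f"{wav.getnchannels()} channels, expected 1 (mono)")
        if wav.getsampwidth() != 2:
            problems.append(f"sample width is {wav.getsampwidth()} bytes, expected 2 (16-bit PCM)")
        # Duration check reflects this notebook's `-t 30` trim only
        duration = wav.getnframes() / rate
        if duration > max_seconds:
            problems.append(f"duration is {duration:.1f} s, expected at most {max_seconds} s")
    return problems
```

Running `check_whisper_wav("./test-audio/recording-fr-4.wav")` on the file produced above should return an empty list.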

Next, use the `whisper-cli` command to transcribe the audio using our GGML model:

In [57]:
# Test converting the WAV file to text using the GGML file that we built
!./whisper.cpp/build/bin/whisper-cli --language fr -np --no-timestamps -m ./ggml/ggml-model.bin ./test-audio/recording-fr-4.wav


Capétre quatre de aventures d'Alice au Pays des Marseille par Louis Scarras. C'est un registrement libre-voxe fait partie du domaine public, enregistré par Linda Olsen-Fightac, Los Angeles. L'habitation du lapin blanc. C'était le lapin blanc qui revenait en trotinant et qui cherchait de tout côté d'un air inquiet, comme s'il avait perdu quelque chose.
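To run the same transcription from Python (for example, to loop over several models), the CLI call can be wrapped with `subprocess`. This sketch uses only the flags from the cell above; `whisper_cli_command` and `transcribe` are hypothetical helpers, and the binary path assumes the CMake build layout used in this notebook.

```python
import subprocess

WHISPER_CLI = "./whisper.cpp/build/bin/whisper-cli"


def whisper_cli_command(model_path: str, wav_path: str, language: str = "fr") -> list[str]:
    # Same flags as the cell above: -np suppresses extra log output,
    # --no-timestamps prints plain text without segment timestamps
    return [WHISPER_CLI, "--language", language, "-np", "--no-timestamps",
            "-m", model_path, wav_path]


def transcribe(model_path: str, wav_path: str, language: str = "fr") -> str:
    """Run whisper-cli and return its stdout (the transcript), stripped."""
    result = subprocess.run(whisper_cli_command(model_path, wav_path, language),
                            capture_output=True, text=True, check=True)
    return result.stdout.strip()
```

For example, `transcribe("./ggml/ggml-model.bin", "./test-audio/recording-fr-4.wav")` should return the same text as the cell above.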


In [63]:
# Compare with the upstream (not fine-tuned) whisper-base model
!mkdir ./ggml-upstream/
!sh ./whisper.cpp/models/download-ggml-model.sh base ./ggml-upstream/
!./whisper.cpp/build/bin/whisper-cli --language fr --no-timestamps -np -m ./ggml-upstream/ggml-base.bin ./test-audio/recording-fr-4.wav

mkdir: cannot create directory ‘./ggml-upstream/’: File exists
Downloading ggml model base from 'https://huggingface.co/ggerganov/whisper.cpp' ...
Model base already exists. Skipping download.

 Chapétre quatre de aventures d'alice au pays des merveilleux par Louis Carroll. Cet enregistrement librevox fait partie du domaine public, enregistré par Linda Olsen-Vitac Los Angeles. L'habitation du lapin blanc. C'était le lapin blanc qui revenait en trottinant et qui cherchait de tout côté d'un air inquiét comme s'il avait perdu quelque chose.
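There is no reference transcript for this clip, so a true word error rate can't be computed here. A word-level edit distance between the two outputs still shows how much the models disagree. This is the same distance WER is built on (this notebook uses jiwer for WER), but treating one hypothesis as the "reference" measures disagreement, not accuracy. `word_edit_distance` and `disagreement_rate` are hypothetical helpers, and text normalization (case, punctuation) is left out for brevity.

```python
def word_edit_distance(a: str, b: str) -> int:
    """Levenshtein distance over whitespace-separated words."""
    x, y = a.split(), b.split()
    prev = list(range(len(y) + 1))  # distances from the empty prefix of x
    for i, wx in enumerate(x, start=1):
        cur = [i]
        for j, wy in enumerate(y, start=1):
            cur.append(min(prev[j] + 1,                 # delete wx
                           cur[j - 1] + 1,              # insert wy
                           prev[j - 1] + (wx != wy)))   # substitute or match
        prev = cur
    return prev[-1]


def disagreement_rate(a: str, b: str) -> float:
    """Word edit distance normalized by the word count of `a` (WER-style)."""
    return word_edit_distance(a, b) / max(len(a.split()), 1)
```

Passing the fine-tuned and upstream transcripts above as `a` and `b` gives a single number to track across model versions.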


In [59]:
from huggingface_hub import notebook_login, HfApi

# (Optional) Publish the fine-tuned model, tokenizer, and GGML files to Hugging Face
def push_to_hub():
    notebook_login()

    revision = 'train-on-voxpopuli-and-commonvoice--v4'
    # Client used below to upload the GGML files
    api = HfApi()
    # Commit to base the new branch on (replace this):
    #base_on = '9dc99c95056795aaa8fbed87c976965c7ff0a129'
    #api.create_branch(repo_id = hub_model_id, branch=revision)

    # Publish the model, processor
    trainer.push_to_hub(
        # dataset_tags=['facebook/voxpopuli', 'mozilla-foundation/common_voice_11_0'],
        # language='fr',
        # model_name='Whisper Tiny (Finetuned on French)',
        # finetuned_from=finetune_from_id,
        # tasks='automatic-speech-recognition',
        #revision=revision
    )
    # Note: If this creates a new repo, it will be public
    tokenizer.push_to_hub(hub_model_id)#, revision=revision)
    api.upload_folder(
        folder_path='./ggml',
        repo_id=hub_model_id,
        path_in_repo='ggml/',
        #revision=revision,
        allow_patterns='ggml-*.bin',
        delete_patterns='ggml-*.bin'
    )
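`upload_folder` only uploads local files whose names match `allow_patterns`; per the huggingface_hub documentation, `delete_patterns` additionally removes matching remote files in the same commit, so stale GGML files don't linger. The sketch below reproduces only the local filtering step with Python's `fnmatch`, which is assumed here to behave like the library's glob matching. `files_to_upload` is a hypothetical helper.

```python
from fnmatch import fnmatch


def files_to_upload(local_files: list[str], allow_pattern: str = "ggml-*.bin") -> list[str]:
    """Return the local file names that the glob pattern would select, sorted."""
    return sorted(f for f in local_files if fnmatch(f, allow_pattern))
```

With the three files produced in this notebook plus an unrelated file, the full model and both quantized variants match, while the unrelated file is skipped.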

In [60]:
# Google Colab only: download the GGML files to the local machine
from google.colab import files
files.download('./ggml/ggml-model-q8_0.bin')
files.download('./ggml/ggml-model.bin')

In [None]:
# Uncomment to publish
#push_to_hub()

model.safetensors:   0%|          | 0.00/290M [00:00<?, ?B/s]

events.out.tfevents.1742523885.4170f5d27dac.1634.1:   0%|          | 0.00/527 [00:00<?, ?B/s]

Upload 4 LFS files:   0%|          | 0/4 [00:00<?, ?it/s]

training_args.bin:   0%|          | 0.00/5.56k [00:00<?, ?B/s]

events.out.tfevents.1742492389.4170f5d27dac.1634.0:   0%|          | 0.00/303k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/4.72k [00:00<?, ?B/s]

No files have been modified since last commit. Skipping to prevent empty commit.


ggml-model-q5_0.bin:   0%|          | 0.00/55.3M [00:00<?, ?B/s]

ggml-model-q8_0.bin:   0%|          | 0.00/81.8M [00:00<?, ?B/s]

Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]

ggml-model.bin:   0%|          | 0.00/148M [00:00<?, ?B/s]

The `ggml-model.bin` file still needs to be placed in a ZIP file with a `config.json`. For the expected format, see [the vocab cleanup notebook](./whisper_vocab_cleanup.ipynb).
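A minimal packaging sketch using Python's standard `zipfile` module is shown below. The contents of `config.json` and the member names inside the archive are assumptions: the keys the app expects are defined in the vocab cleanup notebook and are not reproduced here, so the `config` dict is a placeholder. `package_model` is a hypothetical helper.

```python
import json
import zipfile
from pathlib import Path


def package_model(model_path: str, config: dict, zip_path: str) -> None:
    """Write the GGML model and a config.json into one ZIP archive.

    `config` is a placeholder; use the format from the vocab cleanup notebook.
    """
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        # Archive member names are assumptions; adjust to the expected layout
        zf.write(model_path, arcname=Path(model_path).name)
        zf.writestr("config.json", json.dumps(config, indent=2))
```

For example, `package_model("./ggml/ggml-model.bin", config, "./ggml/model.zip")` writes both files side by side at the root of the archive.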