# Distil-Whisper

Distil-Whisper is a distilled version of the Whisper model that is 6 times faster, 49% smaller, and performs within 1% WER on out-of-distribution evaluation sets.

Before getting started, let's quickly recap Whisper. Whisper is a general-purpose speech recognition model proposed by OpenAI in the paper [*Robust Speech Recognition via Large-Scale Weak Supervision*](https://cdn.openai.com/papers/whisper.pdf). The Whisper architecture is a Transformer-based encoder-decoder model. First, the encoder maps the input audio to encoder hidden-states in a single forward pass. The decoder then auto-regressively predicts text tokens, conditional on both the previous tokens and the encoder hidden-states.

OpenAI's best Whisper checkpoint, named [Whisper large-v2](https://huggingface.co/openai/whisper-large-v2), has 32 encoder layers and 32 decoder layers. 32 layers is quite a lot! Let's visualize the model:

![img](https://huggingface.co/datasets/patrickvonplaten/images_distil/resolve/main/whiser_arch_old.png)

Here, $\mathbf{X}_{1:T}$ represents the speech input. It is mapped by the encoder (shown in green) through a single forward pass. The encoder outputs, i.e. the encoder hidden-states $\mathbf{H}_{1:M}$, are then used in the cross-attention layers in each decoder block.

Starting with a start-of-sequence token $y_0$, the decoder (shown in orange) auto-regressively generates the text tokens in the transcription. In the visualization above, there are 5 decoder forward passes, one for each $\mathbf{P}(y_i | \mathbf{y}_{0: i - 1}),  \forall i$.
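Written out, the auto-regressive generation described above corresponds to the usual chain-rule factorization of the transcription probability. The symbol $N$ (the number of generated tokens) is introduced here for illustration, and the conditioning on the encoder hidden-states $\mathbf{H}_{1:M}$ is made explicit:

$$
\mathbf{P}(\mathbf{y}_{1:N} | \mathbf{X}_{1:T}) = \prod_{i=1}^{N} \mathbf{P}(y_i | \mathbf{y}_{0: i - 1}, \mathbf{H}_{1:M})
$$

Each factor requires one decoder forward pass, while $\mathbf{H}_{1:M}$ is computed once by the encoder and reused in every factor.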

In practice, the decoder is run up to 128 times (depending on the length of the transcription), which means that there are many more forward passes through the decoder than through the encoder. The result is that the decoder is responsible for over **90% of the inference time** in Whisper.

This is the motivation behind Distil-Whisper: we make the decoder faster in order to reduce the inference time of the model. With this in mind, let's take a look at the Distil-Whisper architecture:

![img](https://huggingface.co/datasets/patrickvonplaten/images_distil/resolve/main/distil_arch_old.png)

Just two decoder layers! That means to generate a transcription of 128 tokens, Distil-Whisper needs to run only 256 decoder layer forward passes, while Whisper large-v2 has to run 4096 forward passes. Since the encoder is only run once, we copy the entire encoder and *freeze* it during training. This means Distil-Whisper inherits Whisper's robustness to different audio conditions.
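The arithmetic above can be checked in a few lines. The second half of the sketch is a rough Amdahl's-law style estimate and rests on two assumptions that are not measurements: the decoder share is taken to be exactly 90%, and decoder time is assumed to scale linearly with the number of decoder layers.

```python
# Decoder layer forward passes needed for a 128-token transcription.
NUM_TOKENS = 128
WHISPER_DECODER_LAYERS = 32
DISTIL_DECODER_LAYERS = 2

whisper_passes = NUM_TOKENS * WHISPER_DECODER_LAYERS
distil_passes = NUM_TOKENS * DISTIL_DECODER_LAYERS
print(whisper_passes, distil_passes)  # 4096 256

# Rough estimate (assumptions): the decoder accounts for 90% of inference
# time and its cost scales linearly with the number of decoder layers.
DECODER_SHARE = 0.9
layer_ratio = WHISPER_DECODER_LAYERS / DISTIL_DECODER_LAYERS
estimated_speedup = 1 / ((1 - DECODER_SHARE) + DECODER_SHARE / layer_ratio)
print(f"estimated speed-up: {estimated_speedup:.1f}x")
```

Under these assumptions the estimate lands in the same ballpark as the 6x figure quoted at the top, but it is only a sanity check; the benchmark below is what actually measures the speed-up.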

## Benchmarking

Great, now that we've understood why Distil-Whisper should be faster in theory, let's see if it holds true in practice.

To begin with, we install `transformers`, `accelerate`, and `datasets`.

In this notebook, we use an A100 GPU that is available through a Colab Pro subscription, as this is the device we used for benchmarking in the [Distil-Whisper paper](https://huggingface.co/papers/2311.00430). Other GPUs will most likely lead to different speed-ups, but they should be in the same ballpark:

In [None]:
!which python3

In [None]:
# %pip uninstall -y torch flash-attn

# torch must be imported before its version can be printed
import torch

print(torch.__version__)

In [None]:
%pip install --upgrade --quiet transformers accelerate datasets torch
%pip install flash-attn==2.2.2


In addition, we will make use of [Flash Attention 2](https://github.com/Dao-AILab/flash-attention), as it saves a lot of memory and speeds up the attention computation.

In [None]:
%pip install --upgrade --quiet flash-attn --no-build-isolation

In [None]:
# Load environment variables
%pip install --upgrade python-dotenv

# import os
# from dotenv import load_dotenv

# load_dotenv('/home/bigdaddy/Documents/GitHub/UCSD-ML-AI-Projects/hf.env')
# HF_TOKEN_PATH = os.getenv("HF_TOKEN_PATH")
# HF_USERNAME = os.getenv("HF_USERNAME")

# # Read the token if it is available
# try:
#     with open(HF_TOKEN_PATH, 'r') as token_file:
#         HF_TOKEN = token_file.read().strip()
# except FileNotFoundError as e:
#     print(f"Error: {e}")



In [None]:
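# Authenticate with Hugging Face Hub
import os
from huggingface_hub import login

# Read the token from the HF_TOKEN environment variable, if it is set
HF_TOKEN = os.environ.get("HF_TOKEN")

if HF_TOKEN:
    # Use the token to log in
    login(token=HF_TOKEN)
else:
    # Fall back to the interactive login prompt
    login()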

Next, let's load the dataset that we will use for benchmarking. We'll load a small dataset consisting of 73 samples from the [LibriSpeech ASR](https://huggingface.co/datasets/librispeech_asr) validation-clean dataset. This amounts to ~9MB of data, so it's very lightweight and quick to download:

In [None]:
from datasets import load_dataset

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

We start by benchmarking [Whisper large-v2](https://huggingface.co/openai/whisper-large-v2) to get our baseline number. We'll load the model in `float16` precision and keep the loading time as short as possible by passing `low_cpu_mem_usage=True`. In addition, we make sure that the weights are loaded in [`safetensors`](https://github.com/huggingface/safetensors) format by passing `use_safetensors=True`:

In [None]:
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

print(f'device: {device}')
print(f'torch_dtype: {torch_dtype}')

model_id = "openai/whisper-large-v2"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="flash_attention_2"
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

Great! For the benchmark, we will only measure the generation time (encoder + decoder), so let's write a short helper function that measures this step:

In [None]:
import time

def generate_with_time(model, inputs):
    start_time = time.time()
    outputs = model.generate(**inputs)
    generation_time = time.time() - start_time
    return outputs, generation_time

This function will return both the decoded tokens as well as the time
it took to run the model.
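CUDA kernels are launched asynchronously, so wall-clock timings around GPU code are generally more trustworthy when the device is synchronized before and after the measured region. The variant below is a minimal sketch of that idea, not the helper used for the numbers in this notebook; the name `generate_with_sync_time` and its `synchronize` parameter are hypothetical.

```python
import time


def generate_with_sync_time(model, inputs, synchronize=None):
    """Time model.generate, optionally synchronizing the device around it.

    `synchronize` is a hypothetical hook: any zero-argument callable that
    blocks until pending device work is finished.
    """
    if synchronize is not None:
        synchronize()  # make sure earlier work is not counted
    start_time = time.perf_counter()
    outputs = model.generate(**inputs)
    if synchronize is not None:
        synchronize()  # wait until generation has actually finished
    generation_time = time.perf_counter() - start_time
    return outputs, generation_time
```

On a GPU, one would pass `synchronize=torch.cuda.synchronize`; on CPU the argument can simply be omitted.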

We now iterate over the audio samples and sum up the generation time.

In [None]:
from tqdm import tqdm

all_time = 0

for sample in tqdm(dataset):
  audio = sample["audio"]
  inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
  inputs = inputs.to(device=device, dtype=torch_dtype)

  output, gen_time = generate_with_time(model, inputs)
  all_time += gen_time
  print(processor.batch_decode(output, skip_special_tokens=True))

print(all_time)

Alright! In total it took roughly 63 seconds to transcribe 73 audio samples.

Next, let's see how much time it takes with [Distil-Whisper](https://huggingface.co/distil-whisper/distil-large-v2):

In [None]:
model_id = "distil-whisper/distil-large-v2"

distil_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="flash_attention_2"
)
distil_model = distil_model.to(device)

We run the same benchmarking loop:

In [None]:
all_time = 0

for sample in tqdm(dataset):
  audio = sample["audio"]
  inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
  inputs = inputs.to(device=device, dtype=torch_dtype)

  output, gen_time = generate_with_time(distil_model, inputs)
  all_time += gen_time
  print(processor.batch_decode(output, skip_special_tokens=True))

print(all_time)

Only 10 seconds - that amounts to a 6x speed-up!
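To make the comparison explicit, the two totals can be turned into a speed-up factor and a per-sample latency. The numbers below are the rounded timings quoted in this notebook (63 s and 10 s for 73 samples); on other hardware they will differ, so treat this as a sketch of the calculation rather than a result.

```python
# Rounded totals quoted in this notebook (A100 GPU)
whisper_seconds = 63.0
distil_seconds = 10.0
num_samples = 73

speedup = whisper_seconds / distil_seconds
whisper_per_sample = whisper_seconds / num_samples
distil_per_sample = distil_seconds / num_samples

print(f"speed-up: {speedup:.1f}x")
print(f"Whisper large-v2: {whisper_per_sample:.2f} s/sample")
print(f"Distil-Whisper:   {distil_per_sample:.2f} s/sample")
```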

## Memory

In addition to being significantly faster, Distil-Whisper also has fewer parameters. Let's have a look at how many fewer exactly.

In [None]:
distil_model.num_parameters() / model.num_parameters() * 100

Distil-Whisper is 49% of the size of Whisper. Note that this ratio is much lower if we compare only the size of the decoders:

In [None]:
distil_model.model.decoder.num_parameters() / model.model.decoder.num_parameters() * 100


As expected, the decoder is much smaller. One might have guessed that the ratio should be even lower, around 2/32 (or 6%), but we can't forget that the decoder has a very large word embedding that requires a lot of parameters.
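A back-of-the-envelope estimate shows how much the embedding matters. The sketch below assumes the dimensions from the public Whisper large-v2 configuration (hidden size 1280, vocabulary size 51865, 448 decoder positions) and ignores biases and layer norms, so the result is approximate rather than an exact parameter count.

```python
# Assumed dimensions (from the public Whisper large-v2 config)
D_MODEL = 1280
VOCAB_SIZE = 51865
MAX_TARGET_POSITIONS = 448

# Per decoder layer: self-attention and cross-attention (4 projections each)
# plus a feed-forward block with a 4x expansion; biases and norms are ignored.
attention_params = 2 * 4 * D_MODEL * D_MODEL
ffn_params = 2 * D_MODEL * (4 * D_MODEL)
params_per_layer = attention_params + ffn_params

# Token and position embeddings are the same size for both decoders
embedding_params = (VOCAB_SIZE + MAX_TARGET_POSITIONS) * D_MODEL


def decoder_params(num_layers):
    return num_layers * params_per_layer + embedding_params


ratio = decoder_params(2) / decoder_params(32)
print(f"naive layer ratio:       {2 / 32:.1%}")
print(f"estimated decoder ratio: {ratio:.1%}")
```

Because the embeddings do not shrink with the number of layers, the estimated ratio ends up well above the naive 2/32.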

## Next steps

Hopefully this notebook has shed some light on the motivation behind Distil-Whisper! So far, we've measured Distil-Whisper mainly on GPU, but we are actively looking into collaborations to release code showing how to effectively accelerate Distil-Whisper on CPU as well. Updates will be posted on the Distil-Whisper [repository](https://github.com/huggingface/distil-whisper).

Another key application of Distil-Whisper is *speculative decoding*. In speculative decoding, we can use Distil-Whisper as an *assistant model* to Whisper-large-v2 to reach a speed-up of 2x without **any** loss in performance. More on that in a follow-up notebook that will come out soon!
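To give some intuition for why the output is unchanged, here is a toy, greedy version of speculative decoding in plain Python. It is an illustrative sketch only: `target_next` and `draft_next` are hypothetical stand-ins for the main and assistant models (each maps a token sequence to the next token), and real implementations verify all drafted tokens in a single batched forward pass of the main model.

```python
def greedy_decode(target_next, prompt, max_new_tokens):
    """Reference: plain greedy decoding with the main model only."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(target_next(tokens))
    return tokens


def speculative_greedy_decode(target_next, draft_next, prompt, max_new_tokens, k=4):
    """Toy greedy speculative decoding; returns (tokens, verification_rounds)."""
    tokens = list(prompt)
    rounds = 0
    while len(tokens) - len(prompt) < max_new_tokens:
        # 1) The assistant drafts k tokens auto-regressively (cheap).
        draft = []
        for _ in range(k):
            draft.append(draft_next(tokens + draft))
        # 2) The main model checks the draft. Here this is a loop; in a real
        #    model it would be one forward pass over all k positions.
        rounds += 1
        accepted = []
        for drafted in draft:
            expected = target_next(tokens + accepted)
            accepted.append(expected)  # always keep the main model's token
            if expected != drafted:
                break  # first mismatch: discard the rest of the draft
        tokens.extend(accepted)
    return tokens[: len(prompt) + max_new_tokens], rounds


# Toy "models": the assistant agrees with the main model most of the time.
def target_next(seq):
    return (len(seq) * 3) % 7


def draft_next(seq):
    return target_next(seq) if len(seq) % 5 else -1


reference = greedy_decode(target_next, [0], 20)
speculative, rounds = speculative_greedy_decode(target_next, draft_next, [0], 20)
print(speculative == reference, rounds)
```

Since every kept token is the one the main model would have chosen, the result matches plain greedy decoding, while the main model is consulted in fewer rounds than there are tokens. In `transformers`, assisted generation is exposed through the `assistant_model` argument of `generate`.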