# Installing Whisper

The commands below will install the Python packages needed to use Whisper models and evaluate the transcription results.

In [1]:
! pip install transformers

Defaulting to user installation because normal site-packages is not writeable
Collecting transformers
  Obtaining dependency information for transformers from https://files.pythonhosted.org/packages/13/30/54b59e73400df3de506ad8630284e9fd63f4b94f735423d55fc342181037/transformers-4.33.1-py3-none-any.whl.metadata
  Downloading transformers-4.33.1-py3-none-any.whl.metadata (119 kB)
     ---------------------------------------- 0.0/119.9 kB ? eta -:--:--
     ---------------------------------------- 0.0/119.9 kB ? eta -:--:--
     --- ------------------------------------ 10.2/119.9 kB ? eta -:--:--
     --- ------------------------------------ 10.2/119.9 kB ? eta -:--:--
     --- ------------------------------------ 10.2/119.9 kB ? eta -:--:--
     --------- --------------------------- 30.7/119.9 kB 163.8 kB/s eta 0:00:01
     ------------ ------------------------ 41.0/119.9 kB 163.4 kB/s eta 0:00:01
     ------------ ------------------------ 41.0/119.9 kB 163.4 kB/s eta 0:00:01
     ------



In [None]:
from transformers import WhisperFeatureExtractor

In [1]:
! pip install git+https://github.com/openai/whisper.git

Defaulting to user installation because normal site-packages is not writeable
Collecting git+https://github.com/openai/whisper.git
  Cloning https://github.com/openai/whisper.git to c:\users\ckumarsingh\appdata\local\temp\pip-req-build-2lkf91hh
  Resolved https://github.com/openai/whisper.git to commit e8622f9afc4eba139bf796c210f5c01081000472
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'


  Running command git clone --filter=blob:none --quiet https://github.com/openai/whisper.git 'C:\Users\ckumarsingh\AppData\Local\Temp\pip-req-build-2lkf91hh'


In [2]:
import io
import os
import numpy as np

try:
    import tensorflow  # required in Colab to avoid protobuf compatibility issues
except ImportError:
    pass

import torch
import pandas as pd
import urllib
import tarfile
import whisper
import torchaudio

from scipy.io import wavfile
from tqdm.notebook import tqdm


# Show up to 100 rows and long transcripts without truncation in DataFrame output.
pd.set_option("display.max_rows", 100)
pd.set_option("display.max_colwidth", 1000)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Loading the FLEURS dataset

Select the language of the [FLEURS](https://arxiv.org/abs/2205.12446) dataset to download. Note that transcription and translation performance varies widely by language; Appendix D.2 of the Whisper paper ([Radford et al., 2022](https://arxiv.org/abs/2212.04356)) gives the per-language breakdown.

In [3]:
import ipywidgets as widgets

# Maps each FLEURS locale code (used in the archive URL and cache path) to a language
# name. The name is what later cells pass to Whisper as `language`, and it also selects
# the plotting font and the word-splitting strategy.
languages = {"af_za": "Afrikaans", "am_et": "Amharic", "ar_eg": "Arabic", "as_in": "Assamese", "az_az": "Azerbaijani", "be_by": "Belarusian", "bg_bg": "Bulgarian", "bn_in": "Bengali", "bs_ba": "Bosnian", "ca_es": "Catalan", "cmn_hans_cn": "Chinese", "cs_cz": "Czech", "cy_gb": "Welsh", "da_dk": "Danish", "de_de": "German", "el_gr": "Greek", "en_us": "English", "es_419": "Spanish", "et_ee": "Estonian", "fa_ir": "Persian", "fi_fi": "Finnish", "fil_ph": "Tagalog", "fr_fr": "French", "gl_es": "Galician", "gu_in": "Gujarati", "ha_ng": "Hausa", "he_il": "Hebrew", "hi_in": "Hindi", "hr_hr": "Croatian", "hu_hu": "Hungarian", "hy_am": "Armenian", "id_id": "Indonesian", "is_is": "Icelandic", "it_it": "Italian", "ja_jp": "Japanese", "jv_id": "Javanese", "ka_ge": "Georgian", "kk_kz": "Kazakh", "km_kh": "Khmer", "kn_in": "Kannada", "ko_kr": "Korean", "lb_lu": "Luxembourgish", "ln_cd": "Lingala", "lo_la": "Lao", "lt_lt": "Lithuanian", "lv_lv": "Latvian", "mi_nz": "Maori", "mk_mk": "Macedonian", "ml_in": "Malayalam", "mn_mn": "Mongolian", "mr_in": "Marathi", "ms_my": "Malay", "mt_mt": "Maltese", "my_mm": "Myanmar", "nb_no": "Norwegian", "ne_np": "Nepali", "nl_nl": "Dutch", "oc_fr": "Occitan", "pa_in": "Punjabi", "pl_pl": "Polish", "ps_af": "Pashto", "pt_br": "Portuguese", "ro_ro": "Romanian", "ru_ru": "Russian", "sd_in": "Sindhi", "sk_sk": "Slovak", "sl_si": "Slovenian", "sn_zw": "Shona", "so_so": "Somali", "sr_rs": "Serbian", "sv_se": "Swedish", "sw_ke": "Swahili", "ta_in": "Tamil", "te_in": "Telugu", "tg_tj": "Tajik", "th_th": "Thai", "tr_tr": "Turkish", "uk_ua": "Ukrainian", "ur_pk": "Urdu", "uz_uz": "Uzbek", "vi_vn": "Vietnamese", "yo_ng": "Yoruba"}
# Real languages, labelled "Name (code)" and listed alphabetically by that label.
# They follow two placeholder entries that both map to None.
language_choices = sorted((f"{name} ({code})", code) for code, name in languages.items())

selection = widgets.Dropdown(
    options=[("Select language", None), ("----------", None)] + language_choices,
    value="hi_in",
    description="Language:",
    disabled=False,
)

selection

Dropdown(description='Language:', index=29, options=(('Select language', None), ('----------', None), ('Afrika…

In [4]:
lang = selection.value
# Validate before the dictionary lookup. Both placeholder dropdown entries map to None,
# and `languages[None]` would otherwise fail with a bare KeyError instead of this message.
assert lang is not None, "Please select a language"
language = languages[lang]

print(f"Selected language: {language} ({lang})")

Selected language: Hindi (hi_in)


In [5]:
def download(url: str, target_path: str):
    """
    Stream ``url`` to ``target_path`` while showing a byte-level progress bar.

    The response is first written to ``<target_path>.part`` and only renamed to
    ``target_path`` once the transfer has finished. An interrupted download therefore
    never leaves a truncated file at ``target_path``. This matters because the callers
    in this notebook skip the download whenever that path already exists.

    Args:
        url: address fetched with ``urllib.request.urlopen``.
        target_path: destination file; replaced if it already exists.
    """
    # `import urllib` alone does not guarantee that the `request` submodule is loaded.
    import urllib.request

    partial_path = f"{target_path}.part"
    try:
        with urllib.request.urlopen(url) as source, open(partial_path, "wb") as output:
            # Content-Length can be absent (e.g. chunked responses). tqdm accepts
            # total=None, whereas int(None) would raise.
            content_length = source.info().get("Content-Length")
            total = int(content_length) if content_length is not None else None
            with tqdm(total=total, ncols=80, unit="iB", unit_scale=True, unit_divisor=1024) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break

                    output.write(buffer)
                    loop.update(len(buffer))
    except BaseException:
        # Includes KeyboardInterrupt: drop the incomplete file so the next attempt restarts cleanly.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise

    os.replace(partial_path, target_path)


class Fleurs(torch.utils.data.Dataset):
    """
    A simple class to wrap Fleurs and subsample a portion of the dataset as needed.

    On first use, the per-language archive is downloaded to
    ``~/.cache/fleurs/{lang}.tgz``; later instances reuse the cached copy.
    Each item is an ``(audio, text)`` pair:

    - ``audio`` is a tensor of the samples returned by ``scipy.io.wavfile.read``.
    - ``text`` is the record's ``transcription`` column, not ``raw_transcription``.

    Args:
        lang: FLEURS locale code, e.g. ``"hi_in"``.
        split: which split to load. It selects the ``{split}.tsv`` label file and
            the WAV files under a ``/{split}/`` directory.
        subsample_rate: keep every ``subsample_rate``-th labelled utterance.
        device: stored as ``self.device``; not used elsewhere in this class.

    Raises:
        ValueError: if the archive contains no ``{split}.tsv`` label file.
    """
    def __init__(self, lang, split="test", subsample_rate=1, device=DEVICE):
        url = f"https://storage.googleapis.com/xtreme_translations/FLEURS102/{lang}.tar.gz"
        tar_path = os.path.expanduser(f"~/.cache/fleurs/{lang}.tgz")
        os.makedirs(os.path.dirname(tar_path), exist_ok=True)

        if not os.path.exists(tar_path):
            download(url, tar_path)

        labels = None
        all_audio = {}
        with tarfile.open(tar_path, "r:gz") as tar:
            for member in tar.getmembers():
                name = member.name
                if name.endswith(f"{split}.tsv"):
                    labels = pd.read_table(tar.extractfile(member), names=("id", "file_name", "raw_transcription", "transcription", "_", "num_samples", "gender"))

                if f"/{split}/" in name and name.endswith(".wav"):
                    audio_bytes = tar.extractfile(member).read()
                    all_audio[os.path.basename(name)] = wavfile.read(io.BytesIO(audio_bytes))[1]

        if labels is None:
            # Previously this surfaced as an UnboundLocalError on `labels`.
            raise ValueError(f"No '{split}.tsv' label file found in {tar_path}")

        self.labels = labels.to_dict("records")[::subsample_rate]
        # Keep only the waveforms referenced by the retained labels. With subsampling,
        # this avoids holding the whole split's audio in memory.
        kept_files = {record["file_name"] for record in self.labels}
        self.all_audio = {name: audio for name, audio in all_audio.items() if name in kept_files}
        self.device = device

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, item):
        record = self.labels[item]
        # .copy() hands torch a fresh array it can safely take ownership of.
        audio = torch.from_numpy(self.all_audio[record["file_name"]].copy())
        text = record["transcription"]

        return (audio, text)

In [6]:
# Keep every 10th utterance of the test split (about 10%) so the demo runs quickly.
SUBSAMPLE_RATE = 10
dataset = Fleurs(lang, split="test", subsample_rate=SUBSAMPLE_RATE)

# Running inference on the dataset using a medium Whisper model

The following cells transcribe each utterance in its source language and also translate it into English. This takes a few minutes on a GPU, and considerably longer on a CPU.

In [7]:
model = whisper.load_model("medium")

# Summarize the checkpoint: language coverage and total parameter count.
n_parameters = sum(np.prod(param.shape) for param in model.parameters())
model_kind = "multilingual" if model.is_multilingual else "English-only"
print(f"Model is {model_kind} and has {n_parameters:,} parameters.")

Model is multilingual and has 762,321,920 parameters.


In [8]:
# Decoding settings shared by both passes; only the Whisper task differs.
options = {"language": language, "beam_size": 5, "best_of": 5}
transcribe_options = {"task": "transcribe", **options}
translate_options = {"task": "translate", **options}

In [9]:
references, transcriptions, translations = [], [], []

for audio, text in tqdm(dataset):
    # Run the same utterance through both tasks: transcription first, then translation.
    transcription, translation = (
        model.transcribe(audio, **task_options)["text"]
        for task_options in (transcribe_options, translate_options)
    )

    transcriptions.append(transcription)
    translations.append(translation)
    references.append(text)

  0%|          | 0/42 [00:00<?, ?it/s]



In [10]:
# One row per utterance: ground-truth text, Whisper transcription, Whisper translation.
data = pd.DataFrame(
    {
        "reference": references,
        "transcription": transcriptions,
        "translation": translations,
    }
)
data

Unnamed: 0,reference,transcription,translation
0,स्कीइंग मार्ग को एक हाईकिंग लंबी पैदल यात्रा मार्ग जैसा ही सोचें।,अस्कीन्मार्क को एक हाईकिन् लंबी पैदल यात्रा मार्ग जैसा ही सोचें।,Think of skiing as a long hiking trail.
1,अधिकांश छोटे द्वीप स्वतंत्र राष्ट्र हैं या फ़्रांस से संबंधित हैं और लग्ज़री बीच रिसॉर्ट के रूप में जाने जाते हैं,अधिकाश चोटे दे पस्वतंत्र राश्टर हैं या फ्रांस से सम्बंदित हैं और लक्स्री बीट्स रिसॉर्ट के रूप में जाने जाते हैं।,Most of them are related to the small and independent state or France and are known as Luxury Beach Resort.
2,तूफान और बवंडर की तरह आंधी ओले भारी बारिश और जंगल की आग तीव्र मौसम का हिस्सा और असर हैं,"तूफान और भवंदर की तरा आंधी, आूले, भारी, बारिश और जंगल की आक तीवर मौसम का हिस्सा और असर है।","Like storms and storms, winds, storms, heavy rains and forest fires are a part and effect of severe weather."
3,महिलाएं यह अनुशंसा की जाती है कि कोई भी महिला यात्री वास्तविक वैवाहिक स्थिति के बावजूद कहती है कि वह विवाहित है,महिलाये या अनुश्रणशा की जाती है कि कोई भी महिलायाइादरी वास्तविक विवाई क्षिति के बावजुत कहती है कि वो विवाईत है।,Women are said to be married despite the fact that they are married.
4,वाइल्डलाइफ़ हैबिटेंट्स के रूप में काम करने वाली रेती और तटों को बनाने के लिए गाद ज़रूरी थी,Wildlife habitants के रूप में काम करने वाली रेती और तटों को बनाने के लिए गाद जरूडी थी।,It was very important to make the soil and soil to work as wildlife habitants.
5,बाली में इस एजेंडे के अन्य विषयों में दुनिया के बचे हुए जंगलों को बचाने और ऐसी तकनीकों का आदान प्रदान करना का विषय शामिल है जिससे कि विकासशील देशों को कम प्रदूषणकारी तरीकों से आगे बढ़ने में मदद मिले,"बाली में इस अजेंडे के अन्य विष्यों में दुन्या के बचे हुए जंगलों को बचाने और ऐसी तेक्नीकों का आदान पर्दान करने का विष्य शामिल हैं, जिससे कि विकास जीर देशों को कम परदुषन कारी तरीकों से आगे बरने में मदद मेले.","In this agenda, Bali is involved in protecting the remaining forests of the world and providing solutions to such techniques, which will help developing countries to move forward in less polluting ways."
6,1889 में यह बंदरगाह कुख्यात नौसैनिक गतिरोध का ठिकाना था उस समय जर्मनी अमेरिका और ब्रिटेन के सात जहाजों ने इस बंदरगाह से जाने से इनकार कर दिया था,"1889 में यह बंदर्गाख युक्यात नोसेनिक गतिरोत का थिकाना था। उस समय जर्मनी, अमेरिका और बिटन के साथ जहाजों ने इस बंदर्गाख से जाने से इनकार कर दिया था।","In 1889, this was the place of the 9th World War, during which Germany, America and Britain refused to leave this place."
7,सन 1976 तक माचू पिचू के तीस प्रतिशत हिस्से का जीर्णोद्धार कर दिया गया था और जीर्णोद्धार का कार्य आज तक जारी है,1976 तक माचु पिचु के 30% हिस्से का जिर्नोधार कर दिया गया था और जिर्नोधार का कारे आज तक जारी है।,"In 1976, 30% of Machu Picchu's share of Jirnodhar was done and the work of Jirnodhar is still going on."
8,ms बीमारी केंद्रीय तंत्रिका तंत्र पर असर करती है जिसमें दिमाग स्पाइनल कॉर्ड और ऑप्टिक नर्व शामिल हैं,"MS बिमारी केंद्रिय तंत्र का तंत्र पर असर करती है, जिसमें दिमाग, स्पाइनल कौड और आप्टिक नर्व शामिल है।","The MS disease centre affects the immune system, which includes the brain, spinal cord and optic nerve."
9,समाजीकरण के महत्व को स्पष्ट करने के लिए इस्तेमाल किए जाने वाले सबसे आम तरीकों में से कुछ बच्चों के दुर्भाग्यपूर्ण मामलों को आकर्षित करना है जो बड़े होने के दौरान वयस्कों द्वारा उपेक्षित नहीं बल्कि उपेक्षा दुर्भाग्य या दुर्व्यवहार के माध्यम से होते थे,"समाजी करन के महत्तो को इसपष्ट करने के लिए इस्तमाल किया जाने वाले सबसे आम तरीकों में से कुछ बच्चों के दृर्भाग्य पुर्न मामलों को आकर्षित करना हैं, जो बड़े होने के दौरान वैसको दौरा उपेक्छित नहीं बल्कि उपेक्छा दृर्भाग्य या दृर्वेवार के माध्यम से होते थे।","To clarify the importance of socialization, the most common method used is to attract the unfortunate events of some children, which were not avoided by adults, but were avoided through misfortune or misbehavior."


# Word-level timestamps using attention weights

Below, we use the cross-attention weights to determine more granular, word-level timestamps. It uses a set of heuristics and dynamic time warping (DTW) to find the alignment between the audio and the transcript.

In [11]:
! pip install dtw-python

Defaulting to user installation because normal site-packages is not writeable
Collecting dtw-python
  Downloading dtw_python-1.3.0-cp311-cp311-win_amd64.whl (302 kB)
     ---------------------------------------- 0.0/302.8 kB ? eta -:--:--
     - -------------------------------------- 10.2/302.8 kB ? eta -:--:--
     --- --------------------------------- 30.7/302.8 kB 262.6 kB/s eta 0:00:02
     ----------- ------------------------- 92.2/302.8 kB 585.1 kB/s eta 0:00:01
     -------------------------------------- 302.8/302.8 kB 1.6 MB/s eta 0:00:00
Installing collected packages: dtw-python
Successfully installed dtw-python-1.3.0




In [12]:
import string
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.ticker as ticker

from IPython.display import display, HTML
from whisper.tokenizer import get_tokenizer
from dtw import dtw
from scipy.ndimage import median_filter

%matplotlib inline
%config InlineBackend.figure_format = "retina"

Importing the dtw module. When using in academic works please cite:
  T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.
  J. Stat. Soft., doi:10.18637/jss.v031.i07.



In [13]:
# Each cross-attention frame spans two mel hops. These constants give that span in
# audio samples and in seconds.
AUDIO_SAMPLES_PER_TOKEN = 2 * whisper.audio.HOP_LENGTH
AUDIO_TIME_PER_TOKEN = AUDIO_SAMPLES_PER_TOKEN / whisper.audio.SAMPLE_RATE

# Alignment heuristics:
# - medfilt_width: width of the median filter applied along the time axis.
# - qk_scale: multiplier applied to the attention weights before the softmax.
medfilt_width = 7
qk_scale = 1.0

# Tokenizer matching the model's vocabulary, configured for the selected language.
tokenizer = get_tokenizer(model.is_multilingual, language=languages[lang])

In [14]:
# Fetch (once) a repackaged Noto Sans font so Matplotlib can render the selected
# language's script. CJK languages need the CJK build; everything else uses the
# general one.
cjk_languages = {"Chinese", "Japanese", "Korean"}
font = "GoNotoCJKCore.ttf" if languages[lang] in cjk_languages else "GoNotoCurrent.ttf"

font_release = "https://github.com/satbyy/go-noto-universal/releases/download/v5.2"
if not os.path.exists(font):
    download(f"{font_release}/{font}", font)

prop = fm.FontProperties(fname=font)
props = dict(fontproperties=prop)

  0%|                                              | 0.00/14.2M [00:00<?, ?iB/s]

In [15]:
def split_tokens_on_unicode(tokens: torch.Tensor):
    """
    Split a token sequence into the shortest runs that decode to valid text.

    A token may carry only part of a multi-byte character; decoding such a token on
    its own yields U+FFFD (the replacement character). Tokens are accumulated until
    the run decodes without an *artificial* U+FFFD. A U+FFFD that also appears at the
    same position in the fully decoded sequence is genuine text and is accepted, so it
    cannot make a run grow indefinitely.

    Args:
        tokens: token ids exposing ``.tolist()``. In this notebook this is a 1-D tensor.

    Returns:
        (words, word_tokens): the decoded string of each run and the list of token ids
        it came from. Every input token belongs to exactly one run. A trailing run that
        never decodes cleanly is still returned, so token counts stay aligned with the
        input.
    """
    replacement_char = "\ufffd"
    token_ids = tokens.tolist()
    decoded_full = tokenizer.decode_with_timestamps(token_ids)

    words = []
    word_tokens = []
    current_tokens = []
    unicode_offset = 0  # character offset of the current run within decoded_full

    for token in token_ids:
        current_tokens.append(token)
        decoded = tokenizer.decode_with_timestamps(current_tokens)

        bad_index = decoded.find(replacement_char)
        is_complete = bad_index == -1
        if not is_complete:
            # Accept the run if the replacement character is really in the text.
            position = unicode_offset + bad_index
            is_complete = position < len(decoded_full) and decoded_full[position] == replacement_char

        if is_complete:
            words.append(decoded)
            word_tokens.append(current_tokens)
            current_tokens = []
            unicode_offset += len(decoded)

    if current_tokens:
        # Flush leftovers instead of silently dropping them.
        words.append(tokenizer.decode_with_timestamps(current_tokens))
        word_tokens.append(current_tokens)

    return words, word_tokens

In [16]:
def split_tokens_on_spaces(tokens: torch.Tensor):
    """
    Group tokens into whitespace-delimited words, for scripts that use spaces.

    Tokens are first split into decodable sub-words with ``split_tokens_on_unicode``.
    A sub-word starts a new word when any of these holds:

    - it is a special token (first id >= ``tokenizer.eot``);
    - it begins with a space;
    - it is punctuation;
    - there is no preceding word yet.

    Otherwise it is appended to the preceding word.

    Returns:
        (words, word_tokens): decoded word strings and the token-id list behind each.
    """
    subwords, subword_tokens_list = split_tokens_on_unicode(tokens)
    words = []
    word_tokens = []

    for subword, subword_tokens in zip(subwords, subword_tokens_list):
        special = subword_tokens[0] >= tokenizer.eot
        with_space = subword.startswith(" ")
        # NOTE: "" is a substring of any string, so whitespace-only sub-words also match here.
        punctuation = subword.strip() in string.punctuation
        # `not words` guards the continuation branch: without it, input that does not
        # start with a special/spaced token raises IndexError on words[-1].
        if special or with_space or punctuation or not words:
            words.append(subword)
            word_tokens.append(list(subword_tokens))
        else:
            words[-1] += subword
            word_tokens[-1].extend(subword_tokens)

    return words, word_tokens

In [17]:
# Languages written without spaces between words can't be split on whitespace
# without morpheme analysis. For them, every position where the tokens decode to
# valid Unicode is treated as a word boundary instead.
NO_SPACE_LANGUAGES = {"Chinese", "Japanese", "Thai", "Lao", "Myanmar"}
split_tokens = (
    split_tokens_on_unicode
    if languages[lang] in NO_SPACE_LANGUAGES
    else split_tokens_on_spaces
)

In [18]:
# Install forward hooks on each decoder block's cross-attention layer. On every
# forward pass, each hook stores the last element of its layer's output (the
# attention weights) in QKs[layer_index].
# NOTE(review): assumes the installed whisper version returns the QK matrix as the
# last output of cross_attn; confirm when upgrading whisper.
#
# The hook handles are kept so that re-running this cell first removes the hooks it
# installed previously, instead of stacking duplicates on the same model.
for _old_hook in globals().get("cross_attn_hooks", []):
    _old_hook.remove()

QKs = [None] * model.dims.n_text_layer

cross_attn_hooks = [
    block.cross_attn.register_forward_hook(
        lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1])
    )
    for i, block in enumerate(model.decoder.blocks)
]

In [22]:
# Align, plot and tabulate word-level timestamps for the first 10 examples in the dataset.
for (audio, reference), transcription in zip(dataset, transcriptions[:10]):
    print(transcription)

    # whisper.pad_or_trim fits the audio to whisper.audio.N_SAMPLES samples by default.
    # Clamp the frame count to the same window so the end-timestamp token and the
    # attention slice stay in range for long utterances.
    duration = len(audio)
    n_frames = min(duration, whisper.audio.N_SAMPLES) // AUDIO_SAMPLES_PER_TOKEN

    # Move inputs to wherever the model lives. Hard-coding .cpu()/.cuda() breaks
    # on machines with (or without) a GPU.
    mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(audio)).to(model.device)
    tokens = torch.tensor(
        [
            *tokenizer.sot_sequence,
            tokenizer.timestamp_begin,
        ] + tokenizer.encode(transcription) + [
            tokenizer.timestamp_begin + n_frames,
            tokenizer.eot,
        ]
    ).to(model.device)

    # The forward pass is run only for its side effect: the hooks fill QKs.
    with torch.no_grad():
        model(mel.unsqueeze(0), tokens.unsqueeze(0))

    weights = torch.cat(QKs)  # layers * heads * tokens * frames
    weights = weights[:, :, :, :n_frames].cpu()
    weights = median_filter(weights, (1, 1, 1, medfilt_width))
    weights = torch.tensor(weights * qk_scale).softmax(dim=-1)

    # Normalize over the token axis, then average the last 6 layers and all heads.
    w = weights / weights.norm(dim=-2, keepdim=True)
    matrix = w[-6:].mean(dim=(0, 1))

    alignment = dtw(-matrix.double().numpy())

    # A "jump" is where the DTW path moves to the next token row; its column gives that token's start time.
    jumps = np.pad(np.diff(alignment.index1s), (1, 0), constant_values=1).astype(bool)
    jump_times = alignment.index2s[jumps] * AUDIO_TIME_PER_TOKEN
    words, word_tokens = split_tokens(tokens.cpu())

    # display the normalized attention weights and the alignment
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(matrix, aspect="auto")
    ax.plot(alignment.index2s, alignment.index1s, color="red")

    xticks = np.arange(0, matrix.shape[1], 1 / AUDIO_TIME_PER_TOKEN)
    ax.set_xticks(xticks)
    ax.set_xticklabels((xticks * AUDIO_TIME_PER_TOKEN).round().astype(np.int32))
    ax.set_xlabel("Time (s)")

    # display tokens and words as tick labels
    ylims = ax.get_ylim()
    ax.tick_params("both", length=0, width=0, which="minor", pad=6)

    ax.yaxis.set_ticks_position("left")
    ax.yaxis.set_label_position("left")
    ax.invert_yaxis()
    ax.set_ylim(ylims)

    # Major ticks separate words; each minor tick sits at the middle of a word's token rows.
    major_ticks = [-0.5]
    minor_ticks = []
    current_y = 0

    for word_token in word_tokens:
        minor_ticks.append(current_y + len(word_token) / 2 - 0.5)
        current_y += len(word_token)
        major_ticks.append(current_y - 0.5)

    ax.yaxis.set_minor_locator(ticker.FixedLocator(minor_ticks))
    ax.yaxis.set_minor_formatter(ticker.FixedFormatter(words))
    ax.set_yticks(major_ticks)
    ax.yaxis.set_major_formatter(ticker.NullFormatter())

    # Use the downloaded Noto font so non-Latin scripts render in the word labels.
    for tick_label in ax.get_yminorticklabels():
        tick_label.set_fontproperties(prop)

    ax.set_ylabel("Words")
    plt.show()
    plt.close(fig)

    # display the word-level timestamps in a table
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
    begin_times = jump_times[word_boundaries[:-1]]
    end_times = jump_times[word_boundaries[1:]]

    word_rows = [
        dict(word=word, begin=begin, end=end)
        for word, begin, end in zip(words[:-1], begin_times, end_times)
        if not word.startswith("<|") and word.strip() not in ".,!?、。"
    ]

    display(pd.DataFrame(word_rows))
    display(HTML("<hr>"))

 अस्कीन्मार्क को एक हाईकिन् लंबी पैदल यात्रा मार्ग जैसा ही सोचें।


AssertionError: Torch not compiled with CUDA enabled