## Whisper-AT: Noise-Robust Automatic Speech Recognizers are Also Strong Audio Event Taggers

This Colab script contains a step-by-step tutorial on how to use Whisper-AT for joint automatic speech recognition (ASR) and audio tagging (AT).

Please cite our paper if you find this repository useful.

```
@inproceedings{gong_whisperat,
  author={Gong, Yuan and Khurana, Sameer and Karlinsky, Leonid and Glass, James},
  title={Whisper-AT: Noise-Robust Automatic Speech Recognizers are Also Strong Audio Event Taggers},
  year=2023,
  booktitle={Proc. Interspeech 2023}
}
```
For more information, please check https://github.com/YuanGongND/whisper-at

### Step 1. Install Whisper-AT Package

We intentionally do not add any dependencies beyond those of the original Whisper, so if your environment can run the original Whisper, it can also run Whisper-AT. Like the original Whisper, it requires the command-line tool [`ffmpeg`](https://ffmpeg.org/) to be installed on your system. Please check the OpenAI Whisper repo for details.
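Because `ffmpeg` is an external tool rather than a Python package, it can be handy to check for it before installing. The snippet below is a small sketch and not part of Whisper-AT; `has_ffmpeg` is a hypothetical helper name, and it only tests whether an `ffmpeg` executable is on the `PATH`.

```python
import shutil

def has_ffmpeg() -> bool:
    """Return True if an `ffmpeg` executable is found on the PATH."""
    return shutil.which("ffmpeg") is not None

if has_ffmpeg():
    print("ffmpeg found at", shutil.which("ffmpeg"))
else:
    print("ffmpeg not found; install it before running Whisper-AT")
```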


Whisper-AT can be installed with `pip install whisper-at`.

In [7]:
!sudo -H pip install whisper-at


### Step 2. Use as the Original Whisper

In [8]:
# Create a folder for the demo script
import os
folder_path = "./content"
if not os.path.exists(folder_path):
    os.makedirs(folder_path)

In [9]:
# note: the module name is whisper_at (underscore), not whisper-at
import whisper_at as whisper

# the only new thing in whisper-at:
# specify the temporal resolution for audio tagging; 10 means Whisper-AT predicts audio events every 10 seconds (hop and window = 10 s).
audio_tagging_time_resolution = 10

#model_version = "large-v1"
model_version = "large-v2"
#model_version = "large-v3"
model = whisper.load_model(model_version)
# for large, medium, and small models, we provide low-dim projection AT models to save compute.
# model = whisper.load_model("large-v1", at_low_compute=True)
result = model.transcribe("./content/soccer-game.mp4", at_time_res=audio_tagging_time_resolution)
for segment in result['segments']:
  print(segment['start'], 's-', segment['end'], 's', segment['text'])

# # translation task is also supported
# result = model.transcribe("/content/soccer-game.flac", task='translate', at_time_res=audio_tagging_time_resolution)
# print(result["text"])

0.0 s- 3.68 s me 素晴らしい
3.68 s- 9.4 s さあここは ビートもが運んでいく立て板抜いていってきたがなぁ
9.4 s- 13.22 s きました 完璧でした
13.22 s- 21.68 s 説明で本気発動で開幕節はシストそして2節 この戦で本家発動ですみとかかわる
21.68 s- 24.72 s 先生やブライトーン
24.72 s- 26.72 s ん
26.72 s- 31.119999999999997 s 耳に当てるポーズがないのかわいいですね


`result["text"]` is the ASR transcript. It is identical to the output of the original Whisper and is not affected by `at_time_res`; the ASR function still follows Whisper's 30-second window. `at_time_res` only affects audio tagging.

Compared to the original Whisper, the only new argument is `at_time_res`, which is the hop and window size that Whisper-AT uses to predict audio events. For example, for a 60-second audio clip, setting `at_time_res = 10` splits the audio into six 10-second segments, and Whisper-AT predicts audio tags for each segment, for a total of 6 audio event predictions. **Note that `at_time_res` must be an integer multiple of 0.4, e.g., 0.4, 0.8, ...** The default value is 10.0, which is the value we use to train the model and should lead to the best performance.
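The two rules above (the 0.4-second multiple and the number of predictions) can be sketched in plain Python. This is only an illustration, not Whisper-AT code: `is_valid_at_time_res` and `num_at_segments` are hypothetical helper names, and the floating-point tolerance is an assumption.

```python
import math

def is_valid_at_time_res(at_time_res: float, base: float = 0.4) -> bool:
    """Check that at_time_res is a positive integer multiple of `base` seconds."""
    n = round(at_time_res / base)
    # Compare with a tolerance because 0.4 is not exactly representable as a float
    return n >= 1 and math.isclose(n * base, at_time_res, rel_tol=1e-9)

def num_at_segments(audio_length: float, at_time_res: float) -> int:
    """Number of audio tagging windows: ceil(audio_length / at_time_res)."""
    return math.ceil(audio_length / at_time_res)

print(is_valid_at_time_res(10.0), is_valid_at_time_res(0.5))  # True False
print(num_at_segments(60, 10))  # 6
```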


### Step 3. Get the Audio Tagging Output

Compared with the original Whisper, `result` contains a new entry called `audio_tag`. `result['audio_tag']` is a torch tensor of shape [⌈`audio_length`/`at_time_res`⌉, 527]. For example, for a 60-second audio clip and `at_time_res = 10`, `result['audio_tag']` is a tensor of shape [6, 527]. 527 is the size of the AudioSet label set, and `result['audio_tag'][i,j]` is the (unnormalised) logit of class `j` for the `i`-th segment.

If you are familiar with audio tagging and AudioSet, you can use the raw `result['audio_tag']` directly.
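If you prefer scores in [0, 1] over raw logits, one common option for multi-label tagging is an element-wise sigmoid. That choice is an assumption here, not something this tutorial prescribes. The sketch below uses plain Python on made-up logits so it runs without the model; on the real tensor, `torch.sigmoid(result['audio_tag'])` would be the equivalent call.

```python
import math

def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

# Toy logits for one segment (3 hypothetical classes, not real model output)
segment_logits = [0.0, 1.5, -2.0]
segment_probs = [sigmoid(v) for v in segment_logits]
print([round(p, 3) for p in segment_probs])
```

Each class is scored independently, so the values do not need to sum to 1.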

In [10]:
!pip install torchaudio==2.0.2



In [11]:
import torchaudio
audio, sr = torchaudio.load('./content/soccer-game_from_video.flac')
audio_len = audio.shape[1] / sr
print('Audio length is {:.2f}, at time resolution is {:.1f}, Whisper-AT output in shape'.format(audio_len, audio_tagging_time_resolution), result['audio_tag'].shape)

Audio length is 31.32, at time resolution is 10.0, Whisper-AT output in shape torch.Size([4, 527])


We also provide a tool to make this easier.
You can feed `result` to `whisper.parse_at_label` to get readable results.

In [12]:
audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=5, p_threshold=-2, include_class_list=list(range(527)))
for segment in audio_tag_result:
  print(segment)

{'time': {'start': 0, 'end': 10}, 'audio tags': [('スピーチの音', 0.16229504346847534), ('車両の音', -1.377264380455017), ('サウンド・オブ・ミュージック', -1.5451114177703857)]}
{'time': {'start': 10, 'end': 20}, 'audio tags': [('スピーチの音', 0.554402232170105), ('外、都会、人工の音', -1.923429250717163)]}
{'time': {'start': 20, 'end': 30}, 'audio tags': [('スピーチの音', 1.1601895093917847), ('静寂の音', -1.647493839263916)]}
{'time': {'start': 30, 'end': 40}, 'audio tags': [('静寂の音', 0.659826397895813)]}
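The parsed result is a plain list of dictionaries, so it can be post-processed without any extra tooling. As a sketch, the snippet below sums how many seconds each label is active; `tag_durations` is a hypothetical helper, and the data is copied from the output above so the example runs without the model.

```python
from collections import defaultdict

# Segments copied from the output above (same structure as the parsed result)
audio_tag_result = [
    {'time': {'start': 0, 'end': 10}, 'audio tags': [('スピーチの音', 0.16229504346847534), ('車両の音', -1.377264380455017), ('サウンド・オブ・ミュージック', -1.5451114177703857)]},
    {'time': {'start': 10, 'end': 20}, 'audio tags': [('スピーチの音', 0.554402232170105), ('外、都会、人工の音', -1.923429250717163)]},
    {'time': {'start': 20, 'end': 30}, 'audio tags': [('スピーチの音', 1.1601895093917847), ('静寂の音', -1.647493839263916)]},
    {'time': {'start': 30, 'end': 40}, 'audio tags': [('静寂の音', 0.659826397895813)]},
]

def tag_durations(segments):
    """Sum, per label, the length in seconds of the segments where it appears."""
    totals = defaultdict(float)
    for seg in segments:
        length = seg['time']['end'] - seg['time']['start']
        for label, _logit in seg['audio tags']:
            totals[label] += length
    return dict(totals)

for label, seconds in tag_durations(audio_tag_result).items():
    print(f"{label}: {seconds:.0f} s")
```

Note that the last segment is reported as 30 to 40 s even though the audio is 31.32 s long, so durations computed this way are upper bounds for the final window.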


If you change the audio tagging resolution to 2s, then the output will be more fine-grained.

In [13]:
audio_tagging_time_resolution = 2
result = model.transcribe("./content/soccer-game_from_video.flac", at_time_res=audio_tagging_time_resolution)
print('Audio length is {:.2f}, at time resolution is {:.1f}, Whisper-AT output in shape'.format(audio_len, audio_tagging_time_resolution), result['audio_tag'].shape)
audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=5, p_threshold=-2, include_class_list=list(range(527)))
for segment in audio_tag_result:
  print(segment)



Audio length is 31.32, at time resolution is 2.0, Whisper-AT output in shape torch.Size([16, 527])
{'time': {'start': 0, 'end': 2}, 'audio tags': [('スピーチの音', -0.251071035861969), ('サウンド・オブ・ミュージック', -1.718733310699463)]}
{'time': {'start': 2, 'end': 4}, 'audio tags': [('スピーチの音', 0.5234002470970154), ('車両の音', -1.6031737327575684)]}
{'time': {'start': 4, 'end': 6}, 'audio tags': [('スピーチの音', 0.16451811790466309)]}
{'time': {'start': 6, 'end': 8}, 'audio tags': [('スピーチの音', 1.104638934135437), ('車両の音', -1.4187698364257812)]}
{'time': {'start': 8, 'end': 10}, 'audio tags': [('スピーチの音', 1.3844512701034546), ('外、都会、人工の音', -1.9056485891342163)]}
{'time': {'start': 10, 'end': 12}, 'audio tags': [('スピーチの音', 0.6657971143722534)]}
{'time': {'start': 12, 'end': 14}, 'audio tags': [('スピーチの音', 0.4543112516403198)]}
{'time': {'start': 14, 'end': 16}, 'audio tags': [('スピーチの音', -0.06972355395555496)]}
{'time': {'start': 16, 'end': 18}, 'audio tags': [('スピーチの音', 0.3470880091190338), ('外、都会、人工の音', -1.5061168

In [14]:
# Go back to 10s for better readability
audio_tagging_time_resolution = 10
result = model.transcribe("./content/soccer-game_from_video.flac", at_time_res=audio_tagging_time_resolution)

Let's take a closer look at `whisper.parse_at_label`.

First, `top_k` and `p_threshold` control how many audio tags are output. Specifically, `whisper.parse_at_label` outputs up to `top_k` labels whose unnormalised logits are above `p_threshold`.
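The selection rule can be illustrated with a few lines of plain Python. This is a sketch of the described behavior, not the library's actual implementation: `select_tags` is a hypothetical function, the labels and logits are made up, and treating the threshold as strictly "greater than" is an assumption.

```python
def select_tags(logits, labels, top_k=5, p_threshold=-2.0):
    """Return up to top_k (label, logit) pairs with logit above p_threshold, highest first."""
    ranked = sorted(zip(labels, logits), key=lambda pair: pair[1], reverse=True)
    return [(label, logit) for label, logit in ranked if logit > p_threshold][:top_k]

# Toy example with 4 hypothetical classes
labels = ["Speech", "Vehicle", "Music", "Silence"]
logits = [0.16, -1.38, -1.55, -3.20]

print(select_tags(logits, labels, top_k=5, p_threshold=-2))
# [('Speech', 0.16), ('Vehicle', -1.38), ('Music', -1.55)]
print(select_tags(logits, labels, top_k=1, p_threshold=-2))
# [('Speech', 0.16)]
```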

For example, setting `top_k=1` allows the model to output at most 1 label.

In [15]:
audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=1, p_threshold=-2, include_class_list=list(range(527)))
for segment in audio_tag_result:
  print(segment)

{'time': {'start': 0, 'end': 10}, 'audio tags': [('スピーチの音', 0.1549655944108963)]}
{'time': {'start': 10, 'end': 20}, 'audio tags': [('スピーチの音', 0.5534265041351318)]}
{'time': {'start': 20, 'end': 30}, 'audio tags': [('スピーチの音', 0.22931687533855438)]}
{'time': {'start': 30, 'end': 40}, 'audio tags': [('静寂の音', 0.8264279961585999)]}


Setting a larger `top_k` and a smaller `p_threshold` makes the output more verbose, and vice versa.

In [16]:
audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=10, p_threshold=-5, include_class_list=[47])
for segment in audio_tag_result:
  print(segment)

{'time': {'start': 0, 'end': 10}, 'audio tags': []}
{'time': {'start': 10, 'end': 20}, 'audio tags': []}
{'time': {'start': 20, 'end': 30}, 'audio tags': []}
{'time': {'start': 30, 'end': 40}, 'audio tags': []}


Second, you can select the classes of interest by passing a list to `include_class_list`. For the name-to-index mapping, simply let Whisper-AT print it for you.

In [17]:
whisper.print_label_name(language='en')

index: 0 : Speech
index: 1 : Male speech, man speaking
index: 2 : Female speech, woman speaking
index: 3 : Child speech, kid speaking
index: 4 : Conversation
index: 5 : Narration, monologue
index: 6 : Babbling
index: 7 : Speech synthesizer
index: 8 : Shout
index: 9 : Bellow
index: 10 : Whoop
index: 11 : Yell
index: 12 : Battle cry
index: 13 : Children shouting
index: 14 : Screaming
index: 15 : Whispering
index: 16 : Laughter
index: 17 : Baby laughter
index: 18 : Giggle
index: 19 : Snicker
index: 20 : Belly laugh
index: 21 : Chuckle, chortle
index: 22 : Crying, sobbing
index: 23 : Baby cry, infant cry
index: 24 : Whimper
index: 25 : Wail, moan
index: 26 : Sigh
index: 27 : Singing
index: 28 : Choir
index: 29 : Yodeling
index: 30 : Chant
index: 31 : Mantra
index: 32 : Male singing
index: 33 : Female singing
index: 34 : Child singing
index: 35 : Synthetic singing
index: 36 : Rapping
index: 37 : Humming
index: 38 : Groan
index: 39 : Grunt
index: 40 : Whistling
index: 41 : Breathing
index: 4

Assume we are only interested in classes 0, 1, and 2 (speech classes). We can let Whisper-AT output only these classes.

In [18]:
audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=10, p_threshold=-5, include_class_list=[0, 1, 2])
for segment in audio_tag_result:
  print(segment)

{'time': {'start': 0, 'end': 10}, 'audio tags': [('スピーチの音', 0.1549655944108963)]}
{'time': {'start': 10, 'end': 20}, 'audio tags': [('スピーチの音', 0.5534265041351318), ('男性のスピーチの音、男性のスピーチ', -4.174395561218262)]}
{'time': {'start': 20, 'end': 30}, 'audio tags': [('スピーチの音', 0.22931687533855438)]}
{'time': {'start': 30, 'end': 40}, 'audio tags': [('スピーチの音', -2.1657419204711914)]}
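Rather than reading indices off the printed list by eye, you can search an index-to-name mapping by keyword. The sketch below assumes you have built such a mapping yourself (the tutorial only shows `whisper.print_label_name` printing the names); `label_names` holds just the first few entries from the list above, and `indices_matching` is a hypothetical helper.

```python
# First few AudioSet labels, copied from the printed list above (the full set has 527 entries)
label_names = {
    0: "Speech",
    1: "Male speech, man speaking",
    2: "Female speech, woman speaking",
    3: "Child speech, kid speaking",
    4: "Conversation",
    5: "Narration, monologue",
    6: "Babbling",
    7: "Speech synthesizer",
}

def indices_matching(names, keyword):
    """Return sorted class indices whose label contains keyword (case-insensitive)."""
    keyword = keyword.lower()
    return sorted(i for i, name in names.items() if keyword in name.lower())

speech_classes = indices_matching(label_names, "speech")
print(speech_classes)  # [0, 1, 2, 3, 7]
```

The resulting list can then be passed as `include_class_list=speech_classes`.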


Finally, Whisper-AT supports multiple languages. The default behavior is to output audio tag label names in the same language as the ASR transcript (i.e., `follow_asr`), but you can specify any supported language. Check the supported languages with:

In [19]:
whisper.print_support_language()

language code: en : english
language code: zh : chinese
language code: de : german
language code: es : spanish
language code: ru : russian
language code: ko : korean
language code: fr : french
language code: ja : japanese
language code: pt : portuguese
language code: tr : turkish
language code: pl : polish
language code: ca : catalan
language code: nl : dutch
language code: ar : arabic
language code: sv : swedish
language code: it : italian
language code: id : indonesian
language code: hi : hindi
language code: fi : finnish
language code: vi : vietnamese
language code: he : hebrew
language code: uk : ukrainian
language code: el : greek
language code: ms : malay
language code: cs : czech
language code: ro : romanian
language code: da : danish
language code: hu : hungarian
language code: ta : tamil
language code: no : norwegian
language code: th : thai
language code: ur : urdu
language code: hr : croatian
language code: bg : bulgarian
language code: lt : lithuanian
language code: mi : ma

Let's say we want the output labels in Chinese (zh):

In [20]:
audio_tag_result = whisper.parse_at_label(result, language='zh', top_k=5, p_threshold=-2, include_class_list=list(range(527)))
for segment in audio_tag_result:
  print(segment)

{'time': {'start': 0, 'end': 10}, 'audio tags': [('说话的声音', 0.1549655944108963), ('车辆声音', -1.3442070484161377), ('音乐之声', -1.5613001585006714)]}
{'time': {'start': 10, 'end': 20}, 'audio tags': [('说话的声音', 0.5534265041351318), ('室外、城市或人造的声音', -1.9178415536880493)]}
{'time': {'start': 20, 'end': 30}, 'audio tags': [('说话的声音', 0.22931687533855438), ('寂静的声音', -0.5010502934455872), ('音乐之声', -1.9494613409042358)]}
{'time': {'start': 30, 'end': 40}, 'audio tags': [('寂静的声音', 0.8264279961585999)]}


### Step 4. Transcribe a video

Let's check the result! The above audio track is actually from a video. You can of course generate an .srt file, but in this example we overlay the transcript and the audio tags directly on the video.
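For the `.srt` route, a minimal sketch could look like the following. It is not part of Whisper-AT: `format_srt_time` and `segments_to_srt` are hypothetical helpers, and the demo segments are made up but use the same `start`, `end`, and `text` keys as `result['segments']`.

```python
def format_srt_time(seconds: float) -> str:
    """Format seconds as an SRT timestamp HH:MM:SS,mmm."""
    total_ms = round(seconds * 1000)
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, ms = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"

def segments_to_srt(segments) -> str:
    """Build SRT text from dicts with 'start', 'end' and 'text' keys."""
    blocks = []
    for i, seg in enumerate(segments, start=1):
        start = format_srt_time(seg['start'])
        end = format_srt_time(seg['end'])
        blocks.append(f"{i}\n{start} --> {end}\n{seg['text'].strip()}\n")
    # Cues are separated by a blank line
    return "\n".join(blocks)

# Toy segments; the text is made up
demo = [
    {'start': 0.0, 'end': 3.68, 'text': ' hello'},
    {'start': 3.68, 'end': 9.4, 'text': ' world'},
]
print(segments_to_srt(demo))
```

With a real transcription, `segments_to_srt(result['segments'])` would produce the subtitle text, which you can write to a file with `encoding='utf-8'`.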

**Step 4 is independent of the above; replace the URL and play with your own video!**

In [21]:
from IPython.display import HTML
from base64 import b64encode
# Replace this URL to play with your own video
#wget.download('https://www.dropbox.com/s/pzc72c59xtluuc0/case_closed.mp4?dl=1', './content/soccer-game.mp4')
# mp4 = open('/content/soccer-game.mp4','rb').read()
# data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
# HTML("""
# <video width=800 controls>
#       <source src="%s" type="video/mp4">
# </video>
# """ % data_url)

In [22]:
# install packages for video processing
!pip install -q ffmpeg-python
!pip install opencv-python
!pip install pillow



In [25]:
import os
import ffmpeg
import cv2
import numpy as np
from PIL import ImageFont, ImageDraw, Image

def dubbing_video(video_path, out_video_path, text_info, font_size=0.5, font_v_pos=0.90, font_color=(0, 0, 255)):
    extract_audio(video_path, './temp_audio.wav')

    video = cv2.VideoCapture(video_path)
    # Get video properties
    frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = video.get(cv2.CAP_PROP_FPS)

    # Create output video writer
    output_video = cv2.VideoWriter('./temp_video.mp4', cv2.VideoWriter_fourcc(*"mp4v"), fps,
                                   (frame_width, frame_height))

    # Process each frame of the video
    current_frame = 0
    while video.isOpened():
        ret, frame = video.read()
        if not ret:
            break

        # Calculate current time in seconds
        current_time = current_frame / fps

        # Iterate through text information and add text if within the time interval
        for text_start, text_end, text in text_info:
            if text_start <= current_time <= text_end:
                text_position = (int(frame_width * 0.0), int(frame_height * font_v_pos))

                # Create font (cast the size to int, since some Pillow versions expect an integer size)
                font = ImageFont.truetype('/usr/share/fonts/opentype/noto/NotoSansCJK-Bold.ttc', int(font_size * 75))

                # Convert to PIL Image
                img_pil = Image.fromarray(frame)

                # Draw the text on the image
                draw = ImageDraw.Draw(img_pil)

                # Draw the text
                draw.text(text_position, text, font=font, fill=font_color)

                # Save the image back to frame
                frame = np.array(img_pil)

        # Write the frame to the output video
        output_video.write(frame)
        current_frame += 1

    # Release video resources
    video.release()
    output_video.release()

    combine_audio_video('./temp_video.mp4', './temp_audio.wav', out_video_path)
    os.remove('./temp_video.mp4')
    os.remove('./temp_audio.wav')

def combine_audio_video(video_path, audio_path, output_path):
    video = ffmpeg.input(video_path)
    audio = ffmpeg.input(audio_path)
    output_file = ffmpeg.output(video, audio, output_path)
    output_file.overwrite_output().run()

def extract_audio(video_path, output_path):
    video = ffmpeg.input(video_path)
    audio = video.audio
    output_file = ffmpeg.output(audio, output_path)
    output_file.overwrite_output().run()

extract_audio('./content/soccer-game.mp4', './content/soccer-game_from_video.wav')

#lang = 'ja'
lang = 'auto'

if lang == 'ja': # Japanese
    result = model.transcribe("./content/soccer-game_from_video.wav", at_time_res=audio_tagging_time_resolution, language=lang)
else: # Other languages
    result = model.transcribe("./content/soccer-game_from_video.wav", at_time_res=audio_tagging_time_resolution)

# ASR Output
text_segments = result['segments']
text_annotation = [(x['start'], x['end'], x['text']) for x in text_segments]

# Audio Tagging Output
audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=5, p_threshold=-2, include_class_list=list(range(527)))

all_seg = []
for segment in audio_tag_result:
    cur_start = segment['time']['start']
    cur_end = segment['time']['end']
    cur_tags = segment['audio tags']
    cur_tags = [x[0] for x in cur_tags]
    cur_tags = '; '.join(cur_tags)
    all_seg.append((cur_start, cur_end, cur_tags))

dubbing_video('./content/soccer-game.mp4', './content/soccer-game_at.mp4', all_seg)
dubbing_video('./content/soccer-game_at.mp4', f'./content/soccer-game_at_text_{model_version}_{lang}.mp4', text_annotation, font_color=(0,255,0), font_v_pos=0.80)


ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enab

That's all. If you like the project, give us a star at https://github.com/YuanGongND/whisper-at.