# Masked Modeling Duo (M2D) Example -- Audio Tagging

We show an example of audio tagging using a fine-tuned M2D model.
[M2D](https://github.com/nttcslab/m2d) is an audio self-supervised learning model pre-trained on [AudioSet](https://research.google.com/audioset/) without using labels.
The pre-trained model was then fine-tuned on AudioSet with labels.

We use the fine-tuned model and demonstrate how it predicts AudioSet classes for audio segments.

In [1]:
# The code depends on these external modules.
! pip install timm einops nnAudio librosa >& /dev/null

import warnings; warnings.simplefilter('ignore')
import logging; logging.basicConfig(level=logging.INFO)
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import zipfile
import librosa


In [2]:
# Downloads the AudioSet class definition. It has 527 classes.
! wget http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv >& /dev/null
classes = pd.read_csv('class_labels_indices.csv').sort_values('mid').reset_index()
classes[:3]

|   | level_0 | index | mid | display_name |
|---|---|---|---|---|
| 0 | 433 | 433 | /g/122z_qxw | Firecracker |
| 1 | 169 | 169 | /m/011k_j | Timpani |
| 2 | 108 | 108 | /m/01280g | Wild animals |


In [3]:
# Also downloads example audio files for demonstration.
! wget https://github.com/nttcslab/msm-mae/releases/download/v0.0.1/AudioSetWav16k_examples.zip >& /dev/null
with zipfile.ZipFile("AudioSetWav16k_examples.zip", "r") as zip_ref:
    zip_ref.extractall(".")
! ls AudioSetWav16k/eval_segments

-0xzrMun0Rs_30.000.wav	-22tna7KHzI_28.000.wav	5hlsVoxJPNI_30.000.wav
-1nilez17Dg_30.000.wav	3tUlhM80ObM_0.000.wav	--U7joUcTCo_0.000.wav


## Download M2D
- portable_m2d.py -- A portable loader with no dependence on other files in the M2D repository.
- m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d.zip -- The weights fine-tuned on AudioSet.

In [4]:
! wget https://raw.githubusercontent.com/nttcslab/m2d/master/examples/portable_m2d.py
! wget https://github.com/nttcslab/m2d/releases/download/v0.3.0/m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d.zip

with zipfile.ZipFile("m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d.zip", "r") as zip_ref:
    zip_ref.extractall(".")
! find m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d -name '*.pth'

--2024-03-26 02:55:14--  https://raw.githubusercontent.com/nttcslab/m2d/master/examples/portable_m2d.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 12693 (12K) [text/plain]
Saving to: ‘portable_m2d.py’


2024-03-26 02:55:14 (99.8 MB/s) - ‘portable_m2d.py’ saved [12693/12693]

--2024-03-26 02:55:14--  https://github.com/nttcslab/m2d/releases/download/v0.3.0/m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d.zip
Resolving github.com (github.com)... 140.82.112.3
Connecting to github.com (github.com)|140.82.112.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/589370928/0bdeb8a7-c3f3-44c5-afb9-9b9edaa3e861?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credenti

## Create model

Two lines of code get a model ready for classification.

In [5]:
from portable_m2d import PortableM2D
model = PortableM2D(weight_file='m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d/weights_ep69it3124-0.47929.pth', num_classes=527)

 using 150 parameters, while dropped 10 out of 160 parameters from m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d/weights_ep69it3124-0.47929.pth
 (dropped: ['module.ar.runtime.to_spec.mel_basis', 'module.ar.runtime.to_spec.stft.wsin', 'module.ar.runtime.to_spec.stft.wcos', 'module.ar.runtime.to_spec.stft.window_mask', 'module.head.norm.running_mean'] ...)
<All keys matched successfully>


### An audio tagging example

In [6]:
from IPython.display import display, Audio

def show_topk(classes, m2d, wav_file, k=5):
    print(wav_file)
    # Loads and shows an audio clip.
    wav, _ = librosa.load(wav_file, mono=True, sr=m2d.cfg.sample_rate)
    display(Audio(wav, rate=m2d.cfg.sample_rate))
    wav = torch.tensor(wav).unsqueeze(0)
    # Predicts class probabilities for the whole clip (a batch of one).
    with torch.no_grad():
        probs = m2d(wav).squeeze(0).softmax(0)
    # Shows the top-k prediction results.
    topk_values, topk_indices = probs.topk(k=k)
    print(', '.join([f'{classes.loc[i].display_name} ({v*100:.1f}%)' for i, v in zip(topk_indices.numpy(), topk_values.numpy())]))
    print()

files = list(Path('AudioSetWav16k/eval_segments').glob('*.wav'))
files = np.random.choice(files, size=3, replace=False)

for fn in files:
    show_topk(classes, model, fn)

AudioSetWav16k/eval_segments/-1nilez17Dg_30.000.wav


Speech (45.9%), Heart sounds, heartbeat (24.5%), Heart murmur (14.4%), Music (2.0%), Throbbing (1.8%)

AudioSetWav16k/eval_segments/-22tna7KHzI_28.000.wav


Eruption (12.7%), Sound effect (12.4%), Explosion (12.0%), Whoosh, swoosh, swish (8.0%), White noise (4.7%)

AudioSetWav16k/eval_segments/--U7joUcTCo_0.000.wav


Laughter (29.9%), Snicker (29.5%), Cough (10.9%), Chuckle, chortle (5.3%), Wheeze (1.7%)
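
The last step of `show_topk` turns logits into probabilities with a softmax and keeps the `k` largest. The following is a minimal sketch of that step in plain Python, without the model; the class names and logit values are made up for illustration and do not come from M2D.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def topk(probs, k):
    """Return (index, probability) pairs for the k largest probabilities."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in order[:k]]

# Hypothetical class names and logits, for illustration only.
names = ['Speech', 'Music', 'Cough', 'Silence']
logits = [2.0, 0.5, 1.0, -1.0]

probs = softmax(logits)
for i, p in topk(probs, k=2):
    print(f'{names[i]} ({p * 100:.1f}%)')
```

Because a softmax normalizes across all classes, the printed percentages compete with each other and sum to 100% over the full class list.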



### Audio tagging with sliding window

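The following shows how the predicted tags change over time, similar to sound event detection. The clip is cut into overlapping windows (2 s long, starting every 1 s by default), and each window is tagged independently.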

In [7]:
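def repeat_if_short(w, min_duration=48000):
    # Tiles the waveform until it has at least min_duration samples (note: samples, not seconds),
    # then crops it to exactly min_duration samples.
    while w.shape[-1] < min_duration:
        w = np.concatenate([w, w], axis=-1)
    return w[..., :min_duration]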

def show_topk_sliding_window(classes, m2d, wav_file, k=5, hop=1, duration=2.):
    print(wav_file)
    # Loads and shows an audio clip.
    wav, sr = librosa.load(wav_file, mono=True, sr=m2d.cfg.sample_rate)
    display(Audio(wav, rate=sr))
    # Cuts the clip into segments of `duration` seconds, starting every `hop` seconds.
    wavs = [wav[int(c * sr) : int((c + duration) * sr)] for c in np.arange(0, wav.shape[-1] / sr, hop)]
    # Repeats each short segment up to the fixed input length, then stacks the segments into one batch tensor.
    wavs = [repeat_if_short(w) for w in wavs]
    wavs = torch.tensor(np.stack(wavs))
    # Predicts class probabilities for the batch segments.
    with torch.no_grad():
        probs_per_chunk = m2d(wavs).softmax(1)
    # Shows the top-k prediction results.
    for i, probs in enumerate(probs_per_chunk):
        topk_values, topk_indices = probs.topk(k=k)
        sec = f'{i * hop:g}s '  # :g also works when hop is a float
        print(sec, ', '.join([f'{classes.loc[j].display_name} ({v*100:.1f}%)' for j, v in zip(topk_indices.numpy(), topk_values.numpy())]))
    print()

for fn in files:
    show_topk_sliding_window(classes, model, fn)

AudioSetWav16k/eval_segments/-1nilez17Dg_30.000.wav


0s  Heart sounds, heartbeat (78.8%), Heart murmur (13.3%), Throbbing (3.5%), Hum (1.2%), Pulse (0.5%)
1s  Heart sounds, heartbeat (72.5%), Heart murmur (13.3%), Throbbing (4.1%), Hum (1.9%), Speech (1.3%)
2s  Speech (50.4%), Heart sounds, heartbeat (26.6%), Heart murmur (15.1%), Throbbing (0.8%), Music (0.7%)
3s  Speech (79.2%), Female speech, woman speaking (11.7%), Narration, monologue (2.7%), Speech synthesizer (0.8%), Conversation (0.7%)
4s  Speech (75.0%), Female speech, woman speaking (16.6%), Narration, monologue (3.3%), Speech synthesizer (0.5%), Male speech, man speaking (0.3%)
5s  Speech (72.2%), Female speech, woman speaking (14.2%), Narration, monologue (8.9%), Speech synthesizer (1.2%), Male speech, man speaking (0.5%)
6s  Speech (73.4%), Female speech, woman speaking (16.0%), Narration, monologue (2.9%), Speech synthesizer (0.6%), Tick (0.3%)
7s  Speech (71.1%), Female speech, woman speaking (21.9%), Narration, monologue (2.4%), Speech synthesizer (1.0%), Conversation (0.

0s  Eruption (15.3%), Explosion (14.6%), Sound effect (8.3%), Music (4.7%), White noise (4.4%)
1s  Eruption (13.5%), Explosion (5.8%), White noise (4.0%), Music (3.9%), Rumble (3.5%)
2s  Eruption (19.4%), Explosion (6.8%), Rumble (5.4%), Field recording (3.2%), White noise (3.1%)
3s  Eruption (23.7%), Explosion (11.4%), Field recording (6.9%), Fixed-wing aircraft, airplane (3.6%), Aircraft (3.1%)
4s  Eruption (19.0%), White noise (5.8%), Explosion (5.6%), Field recording (5.0%), Sound effect (4.3%)
5s  Rumble (11.2%), Sound effect (5.6%), White noise (4.9%), Music (4.6%), Static (4.1%)
6s  Eruption (11.4%), Rumble (10.7%), Sound effect (8.4%), Explosion (4.2%), White noise (4.0%)
7s  Explosion (26.7%), Eruption (12.8%), Sound effect (8.8%), Music (4.2%), Burst, pop (3.9%)
8s  Sound effect (24.4%), Explosion (16.0%), Eruption (4.7%), Music (3.0%), Rumble (2.8%)
9s  Music (18.7%), Sound effect (14.2%), Electronic music (5.0%), Vehicle (3.7%), Heart murmur (1.7%)

AudioSetWav16k/eval_segm

0s  Laughter (38.9%), Snicker (19.8%), Chuckle, chortle (5.4%), Cough (4.1%), Speech (1.7%)
1s  Snicker (31.5%), Chuckle, chortle (13.4%), Laughter (9.8%), Cough (4.0%), Speech (2.8%)
2s  Cough (23.5%), Snicker (22.2%), Laughter (9.4%), Chuckle, chortle (8.6%), Wheeze (3.8%)
3s  Wheeze (15.8%), Cough (6.4%), Whimper (6.2%), Speech (4.0%), Snort (3.4%)
4s  Cough (26.1%), Whimper (13.1%), Wheeze (5.9%), Throat clearing (5.3%), Breathing (4.4%)
5s  Wheeze (26.0%), Breathing (18.3%), Sneeze (5.6%), Cough (5.4%), Whimper (3.8%)
6s  Cough (79.5%), Throat clearing (10.0%), Sneeze (3.5%), Wheeze (0.9%), Burping, eructation (0.7%)
7s  Cough (17.6%), Sneeze (12.0%), Throat clearing (12.0%), Burping, eructation (11.1%), Laughter (3.8%)
8s  Snicker (31.9%), Laughter (14.8%), Chuckle, chortle (10.7%), Sound effect (3.1%), Whimper (2.4%)
9s  Silence (96.0%), Music (0.6%), Vehicle (0.2%), Speech (0.2%), Inside, small room (0.1%)
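
The windowing inside `show_topk_sliding_window` can be checked in isolation. The sketch below reproduces only the start and end sample indices of each window; `window_bounds` is a hypothetical helper written for this illustration, and the 16 kHz sample rate is an assumption based on the example file names (the notebook reads the actual rate from `m2d.cfg.sample_rate`).

```python
def window_bounds(n_samples, sr, hop=1, duration=2.0):
    """Return (start, end) sample indices for each sliding window.

    Mirrors wav[int(c * sr) : int((c + duration) * sr)] for
    c = 0, hop, 2 * hop, ... while c * sr < n_samples. The end index
    is clipped to the clip length, as Python slicing does.
    """
    bounds = []
    c = 0
    while c * sr < n_samples:
        start = int(c * sr)
        end = min(int((c + duration) * sr), n_samples)
        bounds.append((start, end))
        c += hop
    return bounds

sr = 16000                        # assumed sample rate
bounds = window_bounds(10 * sr, sr)
print(len(bounds))                # number of windows for a 10-s clip
print(bounds[0], bounds[-1])      # first and last window
```

With these defaults a 10-s clip gives ten windows, one per printed line, and the last window is only 1 s long because it runs into the end of the clip. Every window is shorter than the 48000 samples required by `repeat_if_short`, so each one is tiled and cropped before it reaches the model.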



### Audio tagging for all available frames

A ViT splits its input into patches along both the frequency and time axes (e.g., an 80x1001 spectrogram of a 10-s clip becomes 5x62=310 patches).

The ViT then encodes all patches into embeddings $X$ (e.g., $X \in \mathbb{R}^{5\times62\times768}$, where 768 is the feature dimension).

The `forward_frames` function in our wrapper class `PortableM2D` summarizes $X$ into $X' \in \mathbb{R}^{62\times768}$, the embeddings averaged within each time frame, and predicts $Y \in \mathbb{R}^{62\times527}$, the logits for each frame (527 is the number of classes).
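
The shape bookkeeping above can be sketched with NumPy and random data. This is not the code of `forward_frames`: it assumes that averaging per time frame means a mean over the 5 frequency patches, and it uses a hypothetical single linear layer (`W`, `b`) as a stand-in for the real classification head, which may contain additional layers.

```python
import numpy as np

rng = np.random.default_rng(0)
n_freq, n_time, dim, n_classes = 5, 62, 768, 527

# Random stand-in for the ViT patch embeddings X.
X = rng.standard_normal((n_freq, n_time, dim))

# Average over the frequency axis: one embedding per time frame.
X_frames = X.mean(axis=0)                       # shape (62, 768)

# Hypothetical linear head; the real head in PortableM2D may differ.
W = rng.standard_normal((dim, n_classes)) * 0.01
b = np.zeros(n_classes)
Y = X_frames @ W + b                            # shape (62, 527), logits

# Row-wise softmax turns each frame's logits into probabilities.
Y_shift = Y - Y.max(axis=-1, keepdims=True)
P = np.exp(Y_shift) / np.exp(Y_shift).sum(axis=-1, keepdims=True)

print(X_frames.shape, Y.shape, P.shape)
```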

In summary, a 10-s audio clip yields 62 frames, and each frame has 527 logits that become 527 class probabilities after a softmax operation.
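
The numbers in this section follow from simple arithmetic. The sketch below takes the spectrogram size (80x1001) and patch size (16x16) from the weight file name. How the leftover spectrogram frames and the timestamps are actually handled inside `forward_frames` is not shown in this notebook, so the floor division and the even split of 10 s over 62 frames are assumptions.

```python
n_mels, n_frames = 80, 1001       # spectrogram size, from "80x1001" in the model name
patch_f, patch_t = 16, 16         # patch size, from "p16x16" in the model name

grid_f = n_mels // patch_f        # patches along the frequency axis
grid_t = n_frames // patch_t      # patches along the time axis
n_patches = grid_f * grid_t
print(grid_f, grid_t, n_patches)  # 5 62 310

# Assumption: the 62 time frames are spread evenly over the 10-s clip.
clip_ms = 10_000
frame_ms = clip_ms / grid_t       # about 161 ms per frame
stamps = [f'{i * frame_ms / 1000:.1f}s' for i in range(4)]
print(stamps)
```

Under that assumption each frame covers roughly 0.16 s, which is much finer than the 1-s hop of the sliding-window example.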

In [12]:
def show_topk_for_all_frames(classes, m2d, wav_file, k=5):
    print(wav_file)
    # Loads and shows an audio clip.
    wav, _ = librosa.load(wav_file, mono=True, sr=m2d.cfg.sample_rate)
    display(Audio(wav, rate=m2d.cfg.sample_rate))
    wav = torch.tensor(wav)
    # Predicts class probabilities for all frames.
    with torch.no_grad():
        logits_per_chunk, timestamps = m2d.forward_frames(wav.unsqueeze(0))  # logits_per_chunk: [1, 62, 527], timestamps: [1, 62]
        probs_per_chunk = logits_per_chunk.squeeze(0).softmax(-1)  # logits [1, 62, 527] -> probabilities [62, 527]
        timestamps = timestamps[0]  # [1, 62] -> [62]
    # Shows the top-k prediction results.
    for probs, ts in zip(probs_per_chunk, timestamps):
        topk_values, topk_indices = probs.topk(k=k)
        sec = f'{ts/1000:.1f}s '  # timestamps are in milliseconds
        print(sec, ', '.join([f'{classes.loc[j].display_name} ({v*100:.1f}%)' for j, v in zip(topk_indices.numpy(), topk_values.numpy())]))
    print()

show_topk_for_all_frames(classes, model, files[0])

AudioSetWav16k/eval_segments/-1nilez17Dg_30.000.wav


0.0s  Heart murmur (32.2%), Speech (29.9%), Heart sounds, heartbeat (22.6%), Silence (5.1%), Chirp tone (2.6%)
0.2s  Heart murmur (55.4%), Heart sounds, heartbeat (21.7%), Speech (14.7%), Silence (1.8%), Throbbing (1.4%)
0.3s  Heart murmur (66.1%), Heart sounds, heartbeat (11.9%), Speech (10.8%), Silence (4.3%), Music (1.1%)
0.5s  Heart murmur (77.3%), Heart sounds, heartbeat (12.6%), Speech (5.8%), Throbbing (1.4%), Silence (0.5%)
0.6s  Heart murmur (56.5%), Heart sounds, heartbeat (28.5%), Speech (6.1%), Throbbing (2.6%), Silence (1.5%)
0.8s  Heart murmur (69.9%), Heart sounds, heartbeat (13.9%), Silence (5.2%), Speech (5.1%), Music (1.2%)
1.0s  Heart murmur (72.5%), Heart sounds, heartbeat (10.8%), Speech (10.5%), Throbbing (1.1%), Music (1.1%)
1.1s  Heart murmur (59.4%), Heart sounds, heartbeat (15.3%), Speech (11.1%), Throbbing (3.1%), Hum (2.9%)
1.3s  Heart murmur (60.6%), Heart sounds, heartbeat (15.9%), Speech (10.7%), Throbbing (3.4%), Music (1.9%)
1.5s  Heart murmur (40.4%), 