# Masked Modeling Duo (M2D) Example (Non-Colab Version) -- Audio Tagging

We show an example of audio tagging using a fine-tuned M2D model.
[M2D](https://github.com/nttcslab/m2d) is an audio self-supervised learning model pre-trained on [AudioSet](https://research.google.com/audioset/) without using labels.
After the M2D pre-training, the pre-trained model was fine-tuned on AudioSet (with labels).

We use the fine-tuned model and demonstrate how it predicts AudioSet classes for audio segments.

In [1]:
# The code depends on these external modules.
# ! pip install timm einops nnAudio librosa >& /dev/null

import warnings; warnings.simplefilter('ignore')
import logging; logging.basicConfig(level=logging.INFO)
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import zipfile
import librosa



In [2]:
# Downloads the AudioSet class definitions (527 classes).
! wget http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv >& /dev/null
# Sorts the classes by mid so that row i lines up with the model's i-th output, which the classes.loc[i] lookups below assume.
classes = pd.read_csv('class_labels_indices.csv').sort_values('mid').reset_index()
classes[:3]

|   | level_0 | index | mid | display_name |
|---|---|---|---|---|
| 0 | 433 | 433 | /g/122z_qxw | Firecracker |
| 1 | 169 | 169 | /m/011k_j | Timpani |
| 2 | 108 | 108 | /m/01280g | Wild animals |
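Sorting by `mid` and resetting the index means that position `i` in `classes` is what the later cells look up for the model's output index `i`. Assuming the model's 527 outputs follow this sorted-by-`mid` order (the `classes.loc[i]` lookups rely on it), the same mapping can be sketched with the standard library only. The three rows are the ones shown in the table above; `load_classes_sorted_by_mid` is a hypothetical helper, not part of M2D.

```python
import csv
import io

# A tiny excerpt in the format of class_labels_indices.csv (columns: index, mid, display_name).
# The real file has 527 rows; these three are the ones displayed in the table above.
CSV_EXCERPT = """index,mid,display_name
108,/m/01280g,"Wild animals"
169,/m/011k_j,"Timpani"
433,/g/122z_qxw,"Firecracker"
"""

def load_classes_sorted_by_mid(text):
    """Return display names ordered by mid, mirroring sort_values('mid').reset_index()."""
    rows = list(csv.DictReader(io.StringIO(text)))
    rows.sort(key=lambda row: row["mid"])  # plain string comparison, as pandas does for object columns
    return [row["display_name"] for row in rows]

names = load_classes_sorted_by_mid(CSV_EXCERPT)
print(names)  # ['Firecracker', 'Timpani', 'Wild animals']
```

With the full file, `names[i]` would play the role of `classes.loc[i].display_name`.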


In [3]:
# Also downloads example audio files for demonstration.
! wget https://github.com/nttcslab/msm-mae/releases/download/v0.0.1/AudioSetWav16k_examples.zip >& /dev/null
with zipfile.ZipFile("AudioSetWav16k_examples.zip", "r") as zip_ref:
    zip_ref.extractall(".")
! ls AudioSetWav16k/eval_segments

--U7joUcTCo_0.000.wav	-1nilez17Dg_30.000.wav	3tUlhM80ObM_0.000.wav
-0xzrMun0Rs_30.000.wav	-22tna7KHzI_28.000.wav	5hlsVoxJPNI_30.000.wav


## Download M2D
- portable_m2d.py -- A portable loader with no dependency on other files in the M2D repository.
- m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d.zip -- A weight file fine-tuned on AudioSet.

In [4]:
! wget https://github.com/nttcslab/m2d/releases/download/v0.3.0/m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d.zip
with zipfile.ZipFile("m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d.zip", "r") as zip_ref:
    zip_ref.extractall(".")
! find m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d/ -name '*.pth'

m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d/weights_ep69it3124-0.47929.pth


## Create model

Two lines of code get a model ready for classification.

In [5]:
from portable_m2d import PortableM2D
model = PortableM2D(weight_file='m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d/weights_ep69it3124-0.47929.pth', num_classes=527)

INFO:root:<All keys matched successfully>
INFO:root:Model input size: [80, 1001]
INFO:root:Using weights: m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d/weights_ep69it3124-0.47929.pth
INFO:root:Feature dimension: 768
INFO:root:Norm stats: -7.099999904632568, 4.199999809265137
INFO:root:Runtime MelSpectrogram(16000, 400, 400, 160, 80, 50, 8000):
INFO:root:MelSpectrogram(
  Mel filter banks size = (80, 201), trainable_mel=False
  (stft): STFT(n_fft=400, Fourier Kernel size=(201, 1, 400), iSTFT=False, trainable=False)
)


 using 150 parameters, while dropped 10 out of 160 parameters from m2d_vit_base-80x1001p16x16-221006-mr7_as_46ab246d/weights_ep69it3124-0.47929.pth
 (dropped: ['module.ar.runtime.to_spec.mel_basis', 'module.ar.runtime.to_spec.stft.wsin', 'module.ar.runtime.to_spec.stft.wcos', 'module.ar.runtime.to_spec.stft.window_mask', 'module.head.norm.running_mean'] ...)
<All keys matched successfully>


### An audio tagging example

In [10]:
from IPython.display import display, Audio

def show_topk(classes, m2d, wav_file, k=5):
    print(wav_file)
    # Loads and plays back an audio clip at the model's sample rate.
    wav, _ = librosa.load(wav_file, mono=True, sr=m2d.cfg.sample_rate)
    display(Audio(wav, rate=m2d.cfg.sample_rate))
    wav = torch.tensor(wav).unsqueeze(0)  # [samples] -> [1, samples]: a batch of one clip
    # Predicts class probabilities for the whole clip.
    with torch.no_grad():
        probs = m2d(wav).squeeze(0).softmax(0)  # logits [1, 527] -> probabilities [527]
    # Shows the top-k prediction results.
    topk_values, topk_indices = probs.topk(k=k)
    print(', '.join([f'{classes.loc[idx].display_name} ({v*100:.1f}%)' for idx, v in zip(topk_indices.numpy(), topk_values.numpy())]))
    print()

files = list(Path('AudioSetWav16k/eval_segments').glob('*.wav'))
files = np.random.choice(files, size=3, replace=False)

for fn in files:
    show_topk(classes, model, fn)

AudioSetWav16k/eval_segments/-0xzrMun0Rs_30.000.wav


Music (73.1%), Stomach rumble (10.4%), Music for children (1.8%), Purr (0.9%), Inside, small room (0.6%)

AudioSetWav16k/eval_segments/--U7joUcTCo_0.000.wav


Laughter (29.9%), Snicker (29.5%), Cough (10.9%), Chuckle, chortle (5.3%), Wheeze (1.7%)

AudioSetWav16k/eval_segments/5hlsVoxJPNI_30.000.wav


Music (62.4%), Speech (26.3%), Singing (1.1%), Music for children (1.1%), Lullaby (0.8%)



### Audio tagging with sliding window

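The following shows how the audio tags change over time, second by second, in the manner of sound event detection: each clip is cropped into 2-s windows starting every second, and each window is tagged separately.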
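As a stand-alone sketch of the windowing in the next cell, assuming a 16 kHz, 10-s clip and the defaults `hop=1`, `duration=2.`, the helpers below compute each window's sample range and repeat-pad every window to 48000 samples (3 s), which is what `repeat_if_short` does with its default `min_duration`. `window_bounds` and `repeat_pad` are hypothetical names; `np.tile` replaces the doubling loop of `repeat_if_short` but produces the same periodic padding.

```python
import numpy as np

def window_bounds(n_samples, sr, hop=1.0, duration=2.0):
    """Start/end sample indices of each window, one window starting every `hop` seconds."""
    starts = np.arange(0, n_samples / sr, hop)  # window start times in seconds
    # The end is clipped to the clip length, as NumPy slicing would do implicitly.
    return [(int(c * sr), min(int((c + duration) * sr), n_samples)) for c in starts]

def repeat_pad(w, length):
    """Tile w until it has at least `length` samples, then crop it to exactly `length`."""
    reps = -(-length // len(w))  # ceiling division; assumes w is non-empty
    return np.tile(w, reps)[:length]

sr = 16000
rng = np.random.default_rng(0)
wav = rng.standard_normal(10 * sr).astype(np.float32)  # stand-in for a 10-s clip
bounds = window_bounds(len(wav), sr)
batch = np.stack([repeat_pad(wav[s:e], 3 * sr) for s, e in bounds])
print(len(bounds), bounds[-1], batch.shape)  # 10 (144000, 160000) (10, 48000)
```

The last window starts at 9 s and holds only 1 s of audio, so it is tiled three times; the earlier 2-s windows are tiled to 4 s and cropped to 3 s.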

In [11]:
def repeat_if_short(w, min_duration=48000):
    # Repeats w until it has at least min_duration samples, then crops it to exactly
    # min_duration (48000 samples = 3 s at 16 kHz) so that all segments share one length.
    while w.shape[-1] < min_duration:
        w = np.concatenate([w, w], axis=-1)
    return w[..., :min_duration]

def show_topk_sliding_window(classes, m2d, wav_file, k=5, hop=1, duration=2.):
    print(wav_file)
    # Loads and plays back an audio clip.
    wav, sr = librosa.load(wav_file, mono=True, sr=m2d.cfg.sample_rate)
    display(Audio(wav, rate=sr))
    # Crops the clip into segments of `duration` seconds, one starting every `hop` seconds.
    wavs = [wav[int(c * sr) : int((c + duration) * sr)] for c in np.arange(0, wav.shape[-1] / sr, hop)]
    # Repeat-pads every segment to a common length and stacks them into one batch tensor.
    wavs = torch.tensor(np.stack([repeat_if_short(w) for w in wavs]))
    # Predicts class probabilities for each segment in the batch.
    with torch.no_grad():
        probs_per_chunk = m2d(wavs).softmax(1)
    # Shows the top-k prediction results for each segment.
    for i, probs in enumerate(probs_per_chunk):
        topk_values, topk_indices = probs.topk(k=k)
        sec = f'{i * hop:d}s '
        print(sec, ', '.join([f'{classes.loc[idx].display_name} ({v*100:.1f}%)' for idx, v in zip(topk_indices.numpy(), topk_values.numpy())]))
    print()

for fn in files:
    show_topk_sliding_window(classes, model, fn)

AudioSetWav16k/eval_segments/-0xzrMun0Rs_30.000.wav


0s  Music (70.6%), Trumpet (7.9%), Jazz (2.7%), Saxophone (2.4%), Brass instrument (1.6%)
1s  Music (71.8%), Trumpet (7.1%), Brass instrument (3.3%), Saxophone (2.3%), Jazz (2.0%)
2s  Music (91.9%), Soundtrack music (1.0%), Video game music (0.7%), Funny music (0.5%), Purr (0.4%)
3s  Music (89.6%), Stomach rumble (0.7%), Purr (0.6%), Music for children (0.5%), Sound effect (0.5%)
4s  Music (89.8%), Music for children (0.8%), Background music (0.6%), Funny music (0.6%), Video game music (0.4%)
5s  Music (87.5%), Musical instrument (0.9%), Happy music (0.7%), Background music (0.6%), Guitar (0.4%)
6s  Music (87.8%), Musical instrument (0.9%), Background music (0.6%), Speech (0.4%), Tender music (0.3%)
7s  Music (69.1%), Single-lens reflex camera (5.3%), Crunch (4.3%), Camera (3.3%), Biting (2.0%)
8s  Music (76.2%), Single-lens reflex camera (8.0%), Crunch (2.9%), Camera (2.1%), Crack (1.2%)
9s  Music (91.2%), Musical instrument (0.8%), Guitar (0.3%), Beatboxing (0.3%), Drum machine (0.2%

0s  Laughter (38.9%), Snicker (19.8%), Chuckle, chortle (5.4%), Cough (4.1%), Speech (1.7%)
1s  Snicker (31.5%), Chuckle, chortle (13.4%), Laughter (9.8%), Cough (4.0%), Speech (2.8%)
2s  Cough (23.5%), Snicker (22.2%), Laughter (9.4%), Chuckle, chortle (8.6%), Wheeze (3.8%)
3s  Wheeze (15.8%), Cough (6.4%), Whimper (6.2%), Speech (4.0%), Snort (3.4%)
4s  Cough (26.1%), Whimper (13.1%), Wheeze (5.9%), Throat clearing (5.3%), Breathing (4.4%)
5s  Wheeze (26.0%), Breathing (18.3%), Sneeze (5.6%), Cough (5.4%), Whimper (3.8%)
6s  Cough (79.5%), Throat clearing (10.0%), Sneeze (3.5%), Wheeze (0.9%), Burping, eructation (0.7%)
7s  Cough (17.6%), Sneeze (12.0%), Throat clearing (12.0%), Burping, eructation (11.1%), Laughter (3.8%)
8s  Snicker (31.9%), Laughter (14.8%), Chuckle, chortle (10.7%), Sound effect (3.1%), Whimper (2.4%)
9s  Silence (96.0%), Music (0.6%), Vehicle (0.2%), Speech (0.2%), Inside, small room (0.1%)

AudioSetWav16k/eval_segments/5hlsVoxJPNI_30.000.wav


0s  Music (53.8%), Lullaby (7.6%), Female singing (6.6%), Singing (2.5%), Music for children (2.3%)
1s  Music (61.1%), Music for children (7.3%), Electronic music (5.0%), A capella (2.7%), Female singing (2.2%)
2s  Music (74.2%), Lullaby (9.3%), Music for children (4.3%), Singing (1.5%), Electronic music (0.4%)
3s  Music (68.1%), Music for children (4.6%), Humming (4.2%), Lullaby (3.6%), Singing (2.4%)
4s  Speech (89.5%), Female speech, woman speaking (1.2%), Narration, monologue (1.0%), Ping (0.9%), Busy signal (0.8%)
5s  Speech (60.2%), Female speech, woman speaking (26.5%), Narration, monologue (7.7%), Speech synthesizer (1.6%), Conversation (0.5%)
6s  Speech (81.9%), Female speech, woman speaking (7.9%), Speech synthesizer (1.7%), Narration, monologue (1.3%), Music (0.4%)
7s  Silence (14.6%), Owl (7.2%), Busy signal (6.8%), Heart sounds, heartbeat (4.3%), Music (4.1%)
8s  Music (74.9%), Silence (3.6%), Tick (1.7%), Tick-tock (1.6%), Violin, fiddle (1.2%)
9s  Music (74.3%), Pizzicat

### Audio tagging for all available frames

A ViT splits its input into patches along both the frequency and time axes. For example, an 80x1001 spectrogram of a 10-s clip, split into 16x16 patches, yields 5x62 = 310 patches.
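The 80x1001 and 5x62 figures can be worked out from the model's settings. The sketch below reads the `MelSpectrogram(16000, 400, 400, 160, 80, 50, 8000)` log line as a 16000 Hz sample rate, a 160-sample hop, and 80 mel bins, and takes the 16x16 patch size from `p16x16` in the weight file name. It assumes a centered STFT and non-overlapping patches that drop any remainder; the function names are hypothetical.

```python
def num_spec_frames(n_samples, hop_length):
    # Assumes a centered STFT (n_fft // 2 padding on both sides, even n_fft),
    # which gives 1 + n_samples // hop_length frames.
    return 1 + n_samples // hop_length

def patch_grid(n_mels, n_frames, patch=16):
    # Non-overlapping patches (stride == patch size): any remainder is dropped.
    return n_mels // patch, n_frames // patch

sr, hop_length, n_mels = 16000, 160, 80  # as read from the MelSpectrogram log of the loaded model
n_frames = num_spec_frames(10 * sr, hop_length)
f_patches, t_patches = patch_grid(n_mels, n_frames)
print(n_frames, f_patches, t_patches, f_patches * t_patches)  # 1001 5 62 310
```

Under these assumptions, 62 x 16 = 992 of the 1001 spectrogram frames fall into full patches; the remaining 9 are simply not covered by a patch.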

The ViT then encodes all patches into embeddings $X$ (e.g., $X \in \mathbb{R}^{5\times62\times768}$, where 768 is the feature dimension).

The `forward_frames` method of our wrapper class `PortableM2D` averages $X$ over the frequency axis into per-time-frame embeddings $X' \in \mathbb{R}^{62\times768}$ and predicts per-frame logits $Y \in \mathbb{R}^{62\times527}$, where 527 is the number of classes.
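The shape flow described above can be sketched in NumPy with random stand-in embeddings. A single linear layer (`W`, `b`) stands in for the classification head; this is a hypothetical simplification, not the `forward_frames` implementation, and the real head in `PortableM2D` may differ. Only the averaging-then-projecting arithmetic is illustrated.

```python
import numpy as np

rng = np.random.default_rng(0)
n_freq, n_time, dim, n_classes = 5, 62, 768, 527

X = rng.standard_normal((n_freq, n_time, dim))  # stand-in for the ViT patch embeddings
X_prime = X.mean(axis=0)  # average the 5 frequency patches of each time frame -> [62, 768]

# Hypothetical linear classifier head (not the actual PortableM2D head).
W = rng.standard_normal((dim, n_classes)) * 0.01
b = np.zeros(n_classes)
Y = X_prime @ W + b  # per-frame logits -> [62, 527]

print(X.shape, X_prime.shape, Y.shape)  # (5, 62, 768) (62, 768) (62, 527)
```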

In summary, a 10-s audio clip yields 62 frames, spaced about 0.16 s apart, and each frame has 527 logits that become 527 class probabilities after a softmax operation.
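A minimal sketch of the per-frame softmax and the frame timing, with random logits standing in for $Y$. The timestamp formula is an assumption: it spaces the 62 frames evenly over 10000 ms (about 161 ms per frame). That is not necessarily how `forward_frames` computes its timestamps, but it reproduces the first labels printed below (0.0s, 0.2s, 0.3s, ..., 1.6s). `frame_timestamps_ms` is a hypothetical helper.

```python
import numpy as np

def softmax(logits, axis=-1):
    # Subtract the per-row max for numerical stability before exponentiating.
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def frame_timestamps_ms(num_frames, clip_ms=10_000):
    # Assumption: frames evenly divide the clip duration.
    return np.arange(num_frames) * clip_ms / num_frames

rng = np.random.default_rng(0)
logits = rng.standard_normal((62, 527))  # stand-in for Y, one row of logits per frame
probs = softmax(logits)                  # each row now sums to 1
ts = frame_timestamps_ms(62)
labels = [f'{t / 1000:.1f}s' for t in ts[:11]]
print(probs.shape, labels)
# labels -> ['0.0s', '0.2s', '0.3s', '0.5s', '0.6s', '0.8s', '1.0s', '1.1s', '1.3s', '1.5s', '1.6s']
```

Note that a fixed 160 ms stride would print 1.4s for the tenth frame instead of the 1.5s shown in the output below, which is why the evenly-divided spacing is used here.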

In [12]:
def show_topk_for_all_frames(classes, m2d, wav_file, k=5):
    print(wav_file)
    # Loads and plays back an audio clip.
    wav, _ = librosa.load(wav_file, mono=True, sr=m2d.cfg.sample_rate)
    display(Audio(wav, rate=m2d.cfg.sample_rate))
    wav = torch.tensor(wav)
    # Predicts class probabilities for all frames.
    with torch.no_grad():
        logits_per_frame, timestamps = m2d.forward_frames(wav.unsqueeze(0))  # logits_per_frame: [1, 62, 527], timestamps: [1, 62]
        probs_per_frame = logits_per_frame.squeeze(0).softmax(-1)  # logits [1, 62, 527] -> probabilities [62, 527]
        timestamps = timestamps[0]  # [1, 62] -> [62], in milliseconds
    # Shows the top-k prediction results for each frame.
    for probs, ts in zip(probs_per_frame, timestamps):
        topk_values, topk_indices = probs.topk(k=k)
        sec = f'{ts/1000:.1f}s '
        print(sec, ', '.join([f'{classes.loc[idx].display_name} ({v*100:.1f}%)' for idx, v in zip(topk_indices.numpy(), topk_values.numpy())]))
    print()

show_topk_for_all_frames(classes, model, files[2])

AudioSetWav16k/eval_segments/5hlsVoxJPNI_30.000.wav


0.0s  Music (29.3%), Choir (17.8%), Speech (12.6%), Singing (7.3%), Christmas music (6.3%)
0.2s  Music (38.9%), Speech (18.4%), Choir (13.5%), Singing (5.1%), Music for children (4.5%)
0.3s  Music (51.6%), Singing (10.9%), Speech (8.0%), Child speech, kid speaking (7.6%), Choir (6.2%)
0.5s  Music (42.9%), Singing (11.3%), Choir (10.0%), Speech (8.6%), Child speech, kid speaking (8.1%)
0.6s  Music (39.3%), Speech (17.1%), Choir (11.3%), Singing (7.2%), Christmas music (4.4%)
0.8s  Music (42.4%), Speech (17.9%), Choir (6.3%), Singing (4.9%), Music for children (4.1%)
1.0s  Music (78.0%), Speech (9.0%), Christmas music (2.3%), Tender music (2.0%), Music for children (1.2%)
1.1s  Music (50.9%), Speech (14.4%), Music for children (7.4%), Female singing (4.7%), Synthetic singing (3.2%)
1.3s  Music (60.9%), Speech (12.1%), Music for children (4.2%), Mantra (2.9%), Singing (2.0%)
1.5s  Music (52.5%), Speech (24.7%), Female singing (3.2%), Christmas music (2.1%), Singing (2.1%)
1.6s  Music (63.