# M2D 32kHz Inout Example

This is an example of the model with a 32 kHz input.

We pre-trained a 32 kHz input M2D model and then fine-tuned it on AudioSet 2M at the sampling rate of 32 kHz.

We use both fine-tuned and pre-trained only models and demonstrate how you can load the weight, encode raw audio into features, or just predict AudioSet classes using the loaded model.

In [1]:
# The code depends on these external modules.
# ! pip install timm einops nnAudio librosa >& /dev/null

import warnings; warnings.simplefilter('ignore')
import logging; logging.basicConfig(level=logging.INFO)
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import zipfile
import librosa

INFO:numexpr.utils:Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.


In [2]:
# Downloads the AudioSet class definition. It has 527 classes.
# ! wget http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv >& /dev/null
classes = pd.read_csv('class_labels_indices.csv').sort_values('mid').reset_index()
classes[:3]

Unnamed: 0,level_0,index,mid,display_name
0,433,433,/g/122z_qxw,Firecracker
1,169,169,/m/011k_j,Timpani
2,108,108,/m/01280g,Wild animals


In [3]:
# Also downloads example audio files for demonstration.
# ! wget https://github.com/nttcslab/msm-mae/releases/download/v0.0.1/AudioSetWav16k_examples.zip >& /dev/null
# with zipfile.ZipFile("AudioSetWav16k_examples.zip", "r") as zip_ref:
#     zip_ref.extractall(".")
! ls AudioSetWav16k/eval_segments

-0xzrMun0Rs_30.000.wav	-22tna7KHzI_28.000.wav	5hlsVoxJPNI_30.000.wav
-1nilez17Dg_30.000.wav	3tUlhM80ObM_0.000.wav	--U7joUcTCo_0.000.wav


## Download M2D

We download two weights.

- m2d_as_vit_base-80x608p16x16p32k-240413_enconly.zip -- A pre-trained weight file (not fine-tuned), for general-purpose use
- m2d_as_vit_base-80x1001p16x16p32k-240413_AS-FT_enconly.zip -- An AudioSet fine-tuned weight file, for SED or tagging

We also use `portable_m2d.py` for instantiating models.


In [4]:
# ! wget https://github.com/nttcslab/m2d/releases/download/v0.3.0/m2d_as_vit_base-80x1001p16x16p32k-240413_AS-FT_enconly.zip
# with zipfile.ZipFile("m2d_as_vit_base-80x1001p16x16p32k-240413_AS-FT_enconly.zip", "r") as zip_ref:
#     zip_ref.extractall(".")
! find m2d_as_vit_base-80x1001p16x16p32k-240413_AS-FT_enconly/ -name *.pth

m2d_as_vit_base-80x1001p16x16p32k-240413_AS-FT_enconly/weights_ep69it3124-0.47998.pth


In [5]:
# ! wget https://github.com/nttcslab/m2d/releases/download/v0.1.0/m2d_as_vit_base-80x608p16x16p32k-240413_enconly.zip
# with zipfile.ZipFile("m2d_as_vit_base-80x608p16x16p32k-240413_enconly.zip", "r") as zip_ref:
#     zip_ref.extractall(".")
! find m2d_as_vit_base-80x608p16x16p32k-240413_enconly/ -name *.pth

m2d_as_vit_base-80x608p16x16p32k-240413_enconly/checkpoint-300.pth


## Example 1: Using the fine-tuned model for getting features

We load the model for getting features for the audio.

In [6]:
from portable_m2d import PortableM2D
model = PortableM2D(weight_file='m2d_as_vit_base-80x1001p16x16p32k-240413_AS-FT_enconly/weights_ep69it3124-0.47998.pth')

INFO:root:<All keys matched successfully>
INFO:root:Model input size: [80, 1001]
INFO:root:Using weights: m2d_as_vit_base-80x1001p16x16p32k-240413_AS-FT_enconly/weights_ep69it3124-0.47998.pth
INFO:root:Feature dimension: 3840
INFO:root:Norm stats: -6.957574844360352, 4.923122406005859
INFO:root:Runtime MelSpectrogram(32000, 800, 800, 320, 80, 50, 16000):
INFO:root:MelSpectrogram(
  Mel filter banks size = (80, 401), trainable_mel=False
  (stft): STFT(n_fft=800, Fourier Kernel size=(401, 1, 800), iSTFT=False, trainable=False)
)


 using 151 parameters, while dropped 9 out of 160 parameters from m2d_as_vit_base-80x1001p16x16p32k-240413_AS-FT_enconly/weights_ep69it3124-0.47998.pth
 (dropped: ['module.ar.runtime.to_spec.mel_basis', 'module.ar.runtime.to_spec.stft.wsin', 'module.ar.runtime.to_spec.stft.wcos', 'module.ar.runtime.to_spec.stft.window_mask', 'module.head.norm.running_mean'] ...)
<All keys matched successfully>


### Getting feature for raw audio

In [7]:
from IPython.display import display, Audio

files = list(Path('AudioSetWav16k/eval_segments').glob('*.wav'))

wav, _ = librosa.load(files[0], mono=True, sr=model.cfg.sample_rate)
display(Audio(wav, rate=model.cfg.sample_rate))
wav = torch.tensor(wav).unsqueeze(0)

with torch.no_grad():
    features = model(wav).squeeze(0)

print('Duration (sec)', wav.shape[-1] / model.cfg.sample_rate)
print('Number of frames', int((wav.shape[-1] + model.cfg.hop_size * model.cfg.patch_size[1] - 1) / (model.cfg.hop_size) / model.cfg.patch_size[1]))
print('Feature dimensions (frame, dimension)', features.shape)
features

Duration (sec) 7.848375
Number of frames 50
Feature dimensions (frame, dimension) torch.Size([50, 3840])


tensor([[-0.7707, -3.3738,  0.9126,  ...,  0.8008, -0.1172,  0.1381],
        [-0.6634, -2.6753,  1.1871,  ...,  0.8921, -0.0850,  0.1616],
        [-0.5289,  0.2070,  1.4464,  ...,  0.6510, -0.0574,  0.2789],
        ...,
        [-0.7452,  0.8443,  1.6440,  ...,  1.0685, -0.0808,  0.2771],
        [-0.3094, -1.3019,  1.6771,  ...,  0.9246, -0.1958,  0.3928],
        [-0.8025, -1.9669,  0.8107,  ...,  1.8111, -1.3075,  0.3630]])

As you may have found, the output of the M2D pre-trained model consists of 3840-d features that stack five 768-d features per frequency axis for each time frame.

You can average features along the frequency axis for each time frame to get 768-d features that might show better SED results.

In [8]:
model = PortableM2D(weight_file='m2d_as_vit_base-80x1001p16x16p32k-240413_AS-FT_enconly/weights_ep69it3124-0.47998.pth', flat_features=True)

with torch.no_grad():
    features = model(wav, average_per_time_frame=True).squeeze(0)

print('Feature dimensions (frame, dimension)', features.shape)
features

INFO:root:<All keys matched successfully>
INFO:root:Model input size: [80, 1001]
INFO:root:Using weights: m2d_as_vit_base-80x1001p16x16p32k-240413_AS-FT_enconly/weights_ep69it3124-0.47998.pth
INFO:root:Feature dimension: 768
INFO:root:Norm stats: -6.957574844360352, 4.923122406005859
INFO:root:Runtime MelSpectrogram(32000, 800, 800, 320, 80, 50, 16000):
INFO:root:MelSpectrogram(
  Mel filter banks size = (80, 401), trainable_mel=False
  (stft): STFT(n_fft=800, Fourier Kernel size=(401, 1, 800), iSTFT=False, trainable=False)
)


 using 151 parameters, while dropped 9 out of 160 parameters from m2d_as_vit_base-80x1001p16x16p32k-240413_AS-FT_enconly/weights_ep69it3124-0.47998.pth
 (dropped: ['module.ar.runtime.to_spec.mel_basis', 'module.ar.runtime.to_spec.stft.wsin', 'module.ar.runtime.to_spec.stft.wcos', 'module.ar.runtime.to_spec.stft.window_mask', 'module.head.norm.running_mean'] ...)
<All keys matched successfully>
Feature dimensions (frame, dimension) torch.Size([50, 768])


tensor([[-0.7146, -3.0938,  0.4316,  ...,  1.1008, -0.4324,  0.2317],
        [-0.6089, -2.2139,  0.6941,  ...,  1.1408, -0.3069,  0.2180],
        [-0.5863,  0.2654,  0.9498,  ...,  0.7830, -0.3708,  0.3043],
        ...,
        [-0.5535,  0.8131,  1.1728,  ...,  1.3421, -0.2334,  0.3127],
        [-0.2945, -1.1928,  1.2050,  ...,  1.1016, -0.3221,  0.4574],
        [-0.9051, -1.9163,  0.6895,  ...,  2.0121, -1.0794,  0.2841]])

### Getting clip-level feature

Average features along the time axis.

In [9]:
clip_level = features.mean(-2)  # Or simply 0. Please make sure to average at the time frame axis.
clip_level.shape, clip_level

(torch.Size([768]),
 tensor([-8.2208e-01, -2.3113e-01,  5.0966e-01, -8.6173e-01,  2.4128e-01,
          8.3806e-02, -2.8640e-01,  1.5146e+00,  8.1342e-01, -1.5676e-01,
          4.6441e-01, -6.8438e-01, -1.3593e-01,  2.8585e-01,  1.7892e-01,
          1.1848e-02, -6.0438e-01,  7.8305e-01, -2.5908e-01,  3.2183e-01,
         -6.5071e-01,  2.8048e-01,  1.0449e+00, -3.3696e-01,  1.1092e-01,
          2.2780e-01, -4.7895e-01,  1.4319e-01,  2.6760e-01, -1.0467e+00,
         -1.0409e-01, -4.0325e-01,  1.1485e+00, -6.9287e-01, -9.4560e-01,
         -3.4892e-01,  7.5113e-01, -1.0444e+00,  8.3146e-02, -8.5039e-01,
          1.1663e+00, -1.9453e-01, -3.6808e-01, -1.6599e-01,  3.7980e-01,
          3.4345e-01,  4.6431e-01, -2.5698e-01, -8.1584e-01, -8.1261e-01,
         -3.2231e-01, -8.7979e-01, -2.7383e-01,  4.9184e-01,  8.1145e-02,
         -3.3567e-01,  2.6732e-01,  1.4526e-01,  4.2770e-02, -8.1908e-01,
         -1.1213e+00,  1.0881e-01,  8.8395e-01, -8.9552e-01,  9.4076e-01,
          2.5043e-

## Example 2: Fine-tuned model for SED or tagging

The fine-tuned model can tag the events.

In [10]:
from portable_m2d import PortableM2D
model = PortableM2D(weight_file='m2d_as_vit_base-80x1001p16x16p32k-240413_AS-FT_enconly/weights_ep69it3124-0.47998.pth', num_classes=527)


INFO:root:<All keys matched successfully>
INFO:root:Model input size: [80, 1001]
INFO:root:Using weights: m2d_as_vit_base-80x1001p16x16p32k-240413_AS-FT_enconly/weights_ep69it3124-0.47998.pth
INFO:root:Feature dimension: 768
INFO:root:Norm stats: -6.957574844360352, 4.923122406005859
INFO:root:Runtime MelSpectrogram(32000, 800, 800, 320, 80, 50, 16000):
INFO:root:MelSpectrogram(
  Mel filter banks size = (80, 401), trainable_mel=False
  (stft): STFT(n_fft=800, Fourier Kernel size=(401, 1, 800), iSTFT=False, trainable=False)
)


 using 151 parameters, while dropped 9 out of 160 parameters from m2d_as_vit_base-80x1001p16x16p32k-240413_AS-FT_enconly/weights_ep69it3124-0.47998.pth
 (dropped: ['module.ar.runtime.to_spec.mel_basis', 'module.ar.runtime.to_spec.stft.wsin', 'module.ar.runtime.to_spec.stft.wcos', 'module.ar.runtime.to_spec.stft.window_mask', 'module.head.norm.running_mean'] ...)
<All keys matched successfully>


### An audio tagging example

In [11]:
from IPython.display import display, Audio

def show_topk(classes, m2d, wav_file, k=5):
    print(wav_file)
    # Loads and shows an audio clip.
    wav, _ = librosa.load(wav_file, mono=True, sr=m2d.cfg.sample_rate)
    display(Audio(wav, rate=m2d.cfg.sample_rate))
    wav = torch.tensor(wav).unsqueeze(0)
    # Predicts class probabilities for the batch segments.
    with torch.no_grad():
        probs = m2d(wav).squeeze(0).softmax(0)
    # Shows the top-k prediction results.
    topk_values, topk_indices = probs.topk(k=k)
    print(', '.join([f'{classes.loc[i].display_name} ({v*100:.1f}%)' for i, v in zip(topk_indices.numpy(), topk_values.numpy())]))
    print()

files = list(Path('AudioSetWav16k/eval_segments').glob('*.wav'))
files = np.random.choice(files, size=3, replace=False)

for fn in files:
    show_topk(classes, model, fn)

AudioSetWav16k/eval_segments/3tUlhM80ObM_0.000.wav


Knock (95.9%), Sound effect (0.9%), Music (0.8%), Silence (0.4%), Rattle (0.2%)

AudioSetWav16k/eval_segments/-22tna7KHzI_28.000.wav


Sound effect (9.4%), Explosion (8.5%), Whoosh, swoosh, swish (8.4%), Eruption (7.0%), White noise (4.7%)

AudioSetWav16k/eval_segments/-1nilez17Dg_30.000.wav


Heart sounds, heartbeat (83.1%), Heart murmur (11.6%), Speech (2.7%), Throbbing (0.9%), Hum (0.6%)



### Audio tagging with sliding window

The following demonstrates the progress of audio tags over seconds, like a sound event detection.

In [12]:
def repeat_if_short(w, min_duration=48000):
    while w.shape[-1] < min_duration:
        w = np.concatenate([w, w], axis=-1)
    return w[..., :min_duration]

def show_topk_sliding_window(classes, m2d, wav_file, k=5, hop=1, duration=2.):
    print(wav_file)
    # Loads and shows an audio clip.
    wav, sr = librosa.load(wav_file, mono=True, sr=m2d.cfg.sample_rate)
    display(Audio(wav, rate=sr))
    # Makes a batch of short segments of the wav into wavs, cropped by the sliding window of [hop, duration].
    wavs = [wav[int(c * sr) : int((c + duration) * sr)] for c in np.arange(0, wav.shape[-1] / sr, hop)]
    wavs = [repeat_if_short(wav) for wav in wavs]
    wavs = torch.tensor(wavs)
    # Predicts class probabilities for the batch segments.
    with torch.no_grad():
        probs_per_chunk = m2d(wavs).softmax(1)
    # Shows the top-k prediction results.
    for i, probs in enumerate(probs_per_chunk):
        topk_values, topk_indices = probs.topk(k=k)
        sec = f'{i * hop:d}s '
        print(sec, ', '.join([f'{classes.loc[i].display_name} ({v*100:.1f}%)' for i, v in zip(topk_indices.numpy(), topk_values.numpy())]))
    print()

for fn in files:
    show_topk_sliding_window(classes, model, fn)

AudioSetWav16k/eval_segments/3tUlhM80ObM_0.000.wav


0s  Knock (99.4%), Music (0.2%), Sound effect (0.1%), Chopping (food) (0.1%), Silence (0.0%)
1s  Knock (86.9%), Music (3.5%), Silence (3.1%), Chopping (food) (2.7%), Sound effect (0.3%)
2s  Silence (80.4%), Music (3.3%), Speech (1.3%), Inside, small room (0.6%), Musical instrument (0.4%)
3s  Silence (80.4%), Music (3.3%), Speech (1.3%), Inside, small room (0.6%), Musical instrument (0.4%)
4s  Silence (80.4%), Music (3.3%), Speech (1.3%), Inside, small room (0.6%), Musical instrument (0.4%)
5s  Knock (60.7%), Music (13.7%), Silence (11.3%), Musical instrument (1.4%), Wood block (1.1%)
6s  Knock (99.6%), Sound effect (0.1%), Music (0.1%), Chopping (food) (0.0%), Silence (0.0%)
7s  Music (33.0%), Heart murmur (21.3%), Synthesizer (7.8%), Knock (3.0%), Silence (2.1%)

AudioSetWav16k/eval_segments/-22tna7KHzI_28.000.wav


0s  Explosion (44.8%), Eruption (13.0%), White noise (10.2%), Whoosh, swoosh, swish (2.6%), Burst, pop (2.2%)
1s  Explosion (30.3%), Eruption (13.6%), White noise (12.5%), Rumble (6.6%), Whoosh, swoosh, swish (3.0%)
2s  White noise (37.8%), Rumble (11.9%), Eruption (7.4%), Explosion (6.3%), Field recording (4.4%)
3s  White noise (46.1%), Rumble (11.7%), Eruption (8.5%), Explosion (7.9%), Field recording (5.0%)
4s  Explosion (52.0%), Eruption (13.4%), White noise (8.5%), Whoosh, swoosh, swish (5.5%), Burst, pop (1.7%)
5s  White noise (28.5%), Whoosh, swoosh, swish (12.5%), Explosion (11.8%), Rumble (5.0%), Eruption (3.9%)
6s  White noise (36.5%), Rumble (18.3%), Eruption (6.8%), Explosion (4.1%), Field recording (2.7%)
7s  Explosion (50.7%), Eruption (12.0%), Whoosh, swoosh, swish (7.7%), Burst, pop (3.7%), Sound effect (3.4%)
8s  Whoosh, swoosh, swish (60.8%), Sound effect (15.2%), Explosion (13.0%), Burst, pop (1.4%), Eruption (0.7%)
9s  Sound effect (48.0%), Bouncing (6.0%), Echo (4.

0s  Heart sounds, heartbeat (88.2%), Heart murmur (9.5%), Throbbing (1.5%), Hum (0.4%), Pulse (0.0%)
1s  Heart sounds, heartbeat (91.8%), Heart murmur (4.5%), Throbbing (2.4%), Hum (0.4%), Pulse (0.1%)
2s  Heart sounds, heartbeat (42.2%), Speech (34.5%), Heart murmur (5.6%), Hum (4.0%), Inside, small room (3.4%)
3s  Speech (84.3%), Female speech, woman speaking (7.7%), Narration, monologue (6.2%), Inside, small room (0.1%), Conversation (0.1%)
4s  Speech (85.1%), Female speech, woman speaking (12.3%), Narration, monologue (1.1%), Conversation (0.2%), Inside, small room (0.2%)
5s  Speech (71.0%), Female speech, woman speaking (25.7%), Narration, monologue (2.5%), Conversation (0.1%), Male speech, man speaking (0.1%)
6s  Speech (96.6%), Female speech, woman speaking (1.9%), Narration, monologue (0.5%), Inside, small room (0.1%), Zipper (clothing) (0.1%)
7s  Speech (73.5%), Female speech, woman speaking (22.9%), Narration, monologue (2.3%), Inside, small room (0.2%), Conversation (0.1%)
8

### Audio tagging for all available frames

A ViT splits inputs into patches for both frequency and time axes. (ex., An 80x1001 10-s spectrogram will be 5x62=310 patches)

Then, ViT encodes all patches into embeddings $X$. (e.g., $X \in R^{5\times62\times768}$, where 768 is a feature dimension)

The forward_frames function in our wrapper class PortableM2D summarizes $X$ into $X' \in R^{62\times768}$, embeddings averaged per each time frame, and predicts $Y \in R^{62\times527}$, logits for each frame. (527 is the number of classes)

In summary, 62 frames are available for a 10-s audio, and each frame has 527 logits that become 527 class probabilities after a softmax operation.

In [13]:
def show_topk_for_all_frames(classes, m2d, wav_file, k=5):
    print(wav_file)
    # Loads and shows an audio clip.
    wav, _ = librosa.load(wav_file, mono=True, sr=m2d.cfg.sample_rate)
    display(Audio(wav, rate=m2d.cfg.sample_rate))
    wav = torch.tensor(wav)
    # Predicts class probabilities for all frames.
    with torch.no_grad():
        logits_per_chunk, timestamps = m2d.forward_frames(wav.unsqueeze(0))  # logits_per_chunk: [1, 62, 527], timestamps: [1, 62]
        probs_per_chunk = logits_per_chunk.squeeze(0).softmax(-1)  # logits [1, 62, 527] -> probabilities [62, 527]
        timestamps = timestamps[0]  # [1, 62] -> [62]
    # Shows the top-k prediction results.
    for i, (probs, ts) in enumerate(zip(probs_per_chunk, timestamps)):
        topk_values, topk_indices = probs.topk(k=k)
        sec = f'{ts/1000:.1f}s '
        print(sec, ', '.join([f'{classes.loc[i].display_name} ({v*100:.1f}%)' for i, v in zip(topk_indices.numpy(), topk_values.numpy())]))
    print()

show_topk_for_all_frames(classes, model, files[0])

AudioSetWav16k/eval_segments/3tUlhM80ObM_0.000.wav


0.0s  Knock (93.6%), Sound effect (2.3%), Silence (2.1%), Music (0.4%), Chopping (food) (0.2%)
0.2s  Knock (87.4%), Sound effect (4.4%), Silence (3.1%), Music (1.0%), Chopping (food) (0.7%)
0.3s  Knock (80.7%), Silence (4.4%), Sound effect (3.6%), Music (2.4%), Rattle (0.4%)
0.5s  Knock (93.6%), Music (3.8%), Sound effect (0.7%), Rattle (0.2%), Sampler (0.2%)
0.6s  Knock (91.5%), Music (2.3%), Sound effect (1.6%), Rodents, rats, mice (0.9%), Chopping (food) (0.4%)
0.8s  Knock (97.0%), Music (0.8%), Sound effect (0.5%), Oink (0.4%), Chopping (food) (0.3%)
0.9s  Knock (93.5%), Chopping (food) (1.8%), Sound effect (1.1%), Music (0.7%), Dishes, pots, and pans (0.5%)
1.1s  Knock (99.0%), Sound effect (0.2%), Chopping (food) (0.1%), Speech (0.1%), Music (0.1%)
1.3s  Knock (95.7%), Whack, thwack (1.3%), Chopping (food) (0.7%), Music (0.7%), Bouncing (0.3%)
1.4s  Knock (87.6%), Sound effect (5.1%), Music (1.3%), Wood block (0.9%), Silence (0.7%)
1.6s  Knock (85.3%), Sound effect (6.1%), Silenc

## Example 3: Use a pre-trained only model (not fine-tuned) model

You can use the m2d_as_vit_base-80x1001p16x16p32k-240413_enconly weight.

In [14]:
from portable_m2d import PortableM2D
model = PortableM2D(weight_file='m2d_as_vit_base-80x608p16x16p32k-240413_enconly/checkpoint-300.pth')

INFO:root:<All keys matched successfully>
INFO:root:Model input size: [80, 608]
INFO:root:Using weights: m2d_as_vit_base-80x608p16x16p32k-240413_enconly/checkpoint-300.pth
INFO:root:Feature dimension: 3840
INFO:root:Norm stats: -6.957574844360352, 4.923122406005859
INFO:root:Runtime MelSpectrogram(32000, 800, 800, 320, 80, 50, 16000):
INFO:root:MelSpectrogram(
  Mel filter banks size = (80, 401), trainable_mel=False
  (stft): STFT(n_fft=800, Fourier Kernel size=(401, 1, 800), iSTFT=False, trainable=False)
)


 using 151 parameters from m2d_as_vit_base-80x608p16x16p32k-240413_enconly/checkpoint-300.pth
 (dropped: [] )
<All keys matched successfully>


### Getting feature for raw audio

In [15]:
from IPython.display import display, Audio

files = list(Path('AudioSetWav16k/eval_segments').glob('*.wav'))

wav, _ = librosa.load(files[0], mono=True, sr=model.cfg.sample_rate)
wav = torch.tensor(wav).unsqueeze(0)

with torch.no_grad():
    features = model(wav).squeeze(0)

print('Duration (sec)', wav.shape[-1] / model.cfg.sample_rate)
print('Number of frames', int((wav.shape[-1] + model.cfg.hop_size * model.cfg.patch_size[1] - 1) / (model.cfg.hop_size) / model.cfg.patch_size[1]))
print('Feature dimensions (frame, dimension)', features.shape)
features

Duration (sec) 7.848375
Number of frames 50
Feature dimensions (frame, dimension) torch.Size([50, 3840])


tensor([[-0.3146, -3.9031,  1.0022,  ...,  0.1946,  0.5961, -0.2602],
        [-0.2664,  0.3402,  1.3693,  ...,  0.5310,  0.5292,  0.0199],
        [-0.5151,  4.6126,  1.6979,  ...,  0.4123,  0.4939,  0.1638],
        ...,
        [-0.0172,  4.1656,  2.3489,  ...,  0.0362,  0.7520, -0.1036],
        [-0.3306,  0.3951,  1.9232,  ..., -0.4530,  0.6817, -0.0958],
        [ 0.2220, -4.2947, -0.2630,  ...,  1.0635, -2.0272,  0.1666]])

As you may have found, the output of the M2D pre-trained model consists of 3840-d features that stack five 768-d features per frequency axis for each time frame.

You can average features along the frequency axis for each time frame to get 768-d features that might show better SED results.

In [16]:
model = PortableM2D(weight_file='m2d_as_vit_base-80x608p16x16p32k-240413_enconly/checkpoint-300.pth', flat_features=True)

with torch.no_grad():
    features = model(wav, average_per_time_frame=True).squeeze(0)

print('Feature dimensions (frame, dimension)', features.shape)
features

INFO:root:<All keys matched successfully>
INFO:root:Model input size: [80, 608]
INFO:root:Using weights: m2d_as_vit_base-80x608p16x16p32k-240413_enconly/checkpoint-300.pth
INFO:root:Feature dimension: 768
INFO:root:Norm stats: -6.957574844360352, 4.923122406005859
INFO:root:Runtime MelSpectrogram(32000, 800, 800, 320, 80, 50, 16000):
INFO:root:MelSpectrogram(
  Mel filter banks size = (80, 401), trainable_mel=False
  (stft): STFT(n_fft=800, Fourier Kernel size=(401, 1, 800), iSTFT=False, trainable=False)
)


 using 151 parameters from m2d_as_vit_base-80x608p16x16p32k-240413_enconly/checkpoint-300.pth
 (dropped: [] )
<All keys matched successfully>
Feature dimensions (frame, dimension) torch.Size([50, 768])


tensor([[-0.3426, -3.6046,  0.2380,  ...,  0.6638,  0.0142, -0.1915],
        [-0.2013,  0.5296,  0.6978,  ...,  1.1017, -0.0171,  0.0930],
        [-0.4502,  4.7145,  1.0644,  ...,  0.7539, -0.3713,  0.1416],
        ...,
        [ 0.1390,  4.2695,  1.6372,  ...,  0.7442,  0.3882,  0.0471],
        [-0.3462,  0.6596,  1.3660,  ...,  0.2081,  0.3408,  0.0663],
        [ 0.5234, -4.3592, -0.2124,  ...,  1.4339, -1.0066, -0.1325]])

### Getting clip-level feature

Average features along the time axis.

In [17]:
clip_level = features.mean(-2)  # Or simply 0. Please make sure to average at the time frame axis.
clip_level.shape, clip_level

(torch.Size([768]),
 tensor([-1.7680e-01, -2.1681e-01,  2.8168e-01,  5.3847e-01,  1.6550e-01,
          1.3680e-01,  2.7224e-01, -4.3612e-01,  1.1943e-01, -6.3486e-01,
         -1.0796e+00,  1.2167e-01, -3.5008e-01,  1.8700e-01,  1.2298e-01,
         -3.4330e-01,  1.7071e-01, -3.5502e-01,  2.5027e-01, -3.9269e-01,
         -2.9977e-01, -4.2100e-01,  4.4574e-01, -3.8586e-01,  3.7667e-01,
          2.3137e-01, -7.0861e-01,  3.2073e-01,  4.9340e-03, -1.7379e-01,
         -5.0967e-01, -1.3372e-01, -7.2402e-02, -9.9767e-01, -2.0125e-01,
         -1.0943e-01,  1.0334e-01,  2.9785e-01, -3.1583e-01, -2.6530e-01,
          1.8449e-01, -4.5851e-02,  2.4763e-01, -1.5177e-01, -4.6638e-01,
         -5.7079e-01,  1.4816e-01, -1.9007e-01, -1.6070e-01, -1.6379e-01,
          6.0893e-01, -1.4595e-01, -5.1740e-01, -4.3270e-01,  3.5188e-01,
          6.5143e-01,  2.2737e-01,  3.5120e-01,  2.5439e-01, -6.9190e-01,
         -9.9554e-02, -8.7255e-02, -2.3036e-01, -3.3670e-01, -2.2351e-01,
          1.2041e-