In [None]:
#| default_exp extract_metrics

In [None]:
# Automatically reload edited imported modules before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import sys
import os
from os.path import expanduser
import itertools
from pathlib import Path

import numpy as np
import torch
import torchaudio
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity

from fastprogress import progress_bar
from fastcore.script import *

from pyannote.audio import Model
from brouhaha.pipeline import RegressiveActivityDetectionPipeline
from WhisperSpeech.whisperspeech import vq_stoks, utils, vad_merge
import webdataset as wds

from WhisperSpeech.whisperspeech.inference import get_compute_device

  torchaudio.set_audio_backend("soundfile")
  torchaudio.set_audio_backend("soundfile")


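# Audio quality metrics extraction (SNR and C50)

We take a webdataset shard of audio files and, for every VAD chunk (the `max` chunking produced by `vad_merge`), estimate the signal-to-noise ratio (SNR) and the C50 clarity index with the Brouhaha model.

The per-chunk means of both values are written to the output shard as `snr_c50.npy`, under the same keys as the input chunks.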

In [None]:
#| exporti
@call_parse
def prepare_metrics(
    input:str,  # audio file webdataset file path
    output:str, # output shard path
    n_samples:int=None, # process a limited amount of samples
    ckpt:str='~/.cache/brouhaha.ckpt', # path to the Brouhaha model checkpoint
):
    "Estimate per-chunk SNR and C50 with Brouhaha and write them to a webdataset shard."
    # Pipeline overview:
    # - `input` is read with `vad_merge.chunked_audio_dataset(..., 'max')`.
    # - Each chunk has its right padding stripped and is transformed as
    #   `(x - gain_shift[1]) * gain_shift[0]`, then fed to the Brouhaha
    #   `RegressiveActivityDetectionPipeline`.
    # - One record `{"__key__": key, "snr_c50.npy": [mean SNR, mean C50]}` per chunk
    #   is written to `output` through `utils.AtomicTarWriter`.
    # - When `n_samples` is given, only the first `n_samples` chunks are processed and
    #   the writer is opened with `throwaway=True`.
    device = get_compute_device()

    model = Model.from_pretrained(expanduser(ckpt), strict=False)
    snr_pipeline = RegressiveActivityDetectionPipeline(segmentation=model).to(torch.device(device))

    # A falsy n_samples (None or 0) means "no limit", as in the original truthiness check.
    limit = n_samples if n_samples else None

    if limit is None:
        # No limit: count the chunks up front so the progress bar can show a total.
        # The count comes from the derived `mvad` shard (presumably one speaker
        # embedding per `max` chunk — TODO confirm); it only drives the progress bar.
        import time
        start = time.time()
        mvad_ds = wds.WebDataset([utils.derived_name(input, 'mvad')]).decode()
        total = sum(len(x['max.spk_emb.npy']) for x in mvad_ds)
        print(f"Counting {total} chunks: {time.time()-start:.2f}")
    else:
        total = limit

    audio_ds = vad_merge.chunked_audio_dataset([input], 'max').compose(
        wds.to_tuple('__key__', 'rpad', 'gain_shift.npy', 'samples', 'sample_rate'),
    )

    # batch_size=None: the loader yields individual chunks, not batches.
    dl = wds.WebLoader(audio_ds, num_workers=1, batch_size=None)

    with utils.AtomicTarWriter(output, throwaway=n_samples is not None) as sink:
        # Enforce the sample limit explicitly rather than relying on the progress bar's `total`.
        chunks = itertools.islice(dl, limit)
        for key, rpad, gain_shift, samples, sr in progress_bar(chunks, total=total):
            with torch.no_grad():
                snd = samples
                # The `rpad > 0` guard matters: `snd[:-0]` would be an empty slice.
                if rpad > 0: snd = snd[:-rpad]
                snd = (snd - gain_shift[1]) * gain_shift[0]
                # Add a leading (channel) dimension for the pipeline input
                # — assumes `samples` is 1-D; TODO confirm.
                snd = snd.unsqueeze(0).to(device)

                res = snr_pipeline({
                    "sample_rate": sr, "waveform": snd
                })

            sink.write({
                "__key__": key,
                "snr_c50.npy": np.array([res['snr'].mean(), res['c50'].mean()])
            })
        sys.stdout.write("\n")

In [None]:
# Debugging aid: `%pdb` with no argument toggles IPython's automatic post-mortem
# debugger on uncaught exceptions (the cell output shows it was turned ON).
%pdb

Automatic pdb calling has been turned ON


In [None]:
# Trial run on a single MLS Polish shard, limited to the first 1024 chunks.
# Because n_samples is set, prepare_metrics opens its writer with throwaway=True.
src_shard = '/data2/mls-polish/audio/mls_polish_train-000000.tar'
dst_shard = '/data2/mls-polish/snr-c50/mls_polish_train-000000.tar.gz'
prepare_metrics(src_shard, dst_shard, n_samples=1024)

Lightning automatically upgraded your loaded checkpoint from v1.6.5 to v2.1.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/brouhaha.ckpt`


Model was trained with pyannote.audio 0.0.1, yours is 3.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.12.1+cu102, yours is 2.2.2+cu121. Bad things might happen unless you revert torch to 1.x.


Using default parameters optimized on Brouhaha



In [None]:
#| hide
# Export this notebook's `#| export` / `#| exporti` cells to the library module.
import nbdev; nbdev.nbdev_export()