# 4. Benchmark `webdataset`

In [1]:
import webdataset as wds
import io
import numpy as np
import torchaudio
import torch
from m5 import M5
from torch import nn
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm
import soundfile as sf
import json
import IPython.display as ipd
from utils import time_me, SAMPLE_RATE, count_parameters

In [2]:
def parse(datum):
    """Decode one raw webdataset sample into a training-ready dict.

    The sample is modified in place and also returned.

    Args:
        datum: Sample dict with at least a ``"flac"`` entry (encoded audio
            bytes) and a ``"json"`` entry (JSON-encoded metadata object).

    Returns:
        The same dict, where:
          * ``"audio"`` holds the decoded waveform as float32,
          * every top-level key of the JSON metadata is copied into the dict,
          * the raw ``"flac"`` / ``"json"`` payloads are removed,
          * the ``"sentence"`` and ``"duration"`` fields are removed if present.

    Raises:
        KeyError: If ``"flac"`` or ``"json"`` is missing from ``datum``.
    """
    # The sample rate returned by soundfile is intentionally ignored here;
    # NOTE(review): presumably all shards share one sample rate — confirm.
    audio, _ = sf.read(io.BytesIO(datum["flac"]), dtype="float32")
    datum["audio"] = audio
    del datum["flac"]

    # Flatten the metadata so later pipeline stages can address fields directly.
    for key, val in json.loads(datum["json"]).items():
        datum[key] = val
    del datum["json"]

    # These fields are not needed downstream. Use pop() with a default so that
    # a sample whose metadata lacks them does not abort the whole stream.
    for unused_key in ("sentence", "duration"):
        datum.pop(unused_key, None)
    return datum

In [14]:
# Local alternative (disabled): read the shards from disk instead of over HTTPS.
# webdataset_directory = Path("./data/webdataset_small_chunk/")
# urls = list(map(str, webdataset_directory.glob("*.tar")))

# Remote shards. `{000..835}` is brace-range notation, which webdataset is
# expected to expand into one URL per shard — TODO confirm the range matches
# the shards actually present in the bucket.
urls = "https://storage.googleapis.com/hpml-project/webdataset/common_voice_{000..835}.tar"



In [15]:
def cropper(crop_duration: float):
    """Build a pipeline stage that turns variable-length clips into fixed-length crops.

    Args:
        crop_duration: Length of every emitted crop, in seconds. Converted to a
            sample count using the module-level ``SAMPLE_RATE``.

    Returns:
        A generator function ``crop(src)`` that consumes an iterable of sample
        dicts (each with an ``"audio"`` array indexed by sample along axis 0)
        and yields sample dicts whose ``"audio"`` has exactly
        ``int(SAMPLE_RATE * crop_duration)`` samples:

          * clips shorter than one crop are zero-padded at the end;
          * longer (or exactly one-crop-long) clips are cut into consecutive
            non-overlapping crops;
          * a trailing remainder is kept (zero-padded) only if it is longer
            than half a crop, otherwise it is discarded.

    Raises:
        ValueError: If ``crop_duration`` yields a non-positive sample count.
    """
    # Metadata copied onto each full crop; every one of these keys must be
    # present on the incoming sample.
    safe_keys = ["__key__", "__url__", "gender", "client", "age", "accent"]
    crop_samples = int(SAMPLE_RATE * crop_duration)
    if crop_samples <= 0:
        # Fail early instead of hitting a ZeroDivisionError mid-stream.
        raise ValueError(
            f"crop_duration must yield at least one sample, got {crop_duration!r}"
        )

    def crop(src):
        for datum in src:
            num_samples = datum["audio"].shape[0]
            assert num_samples != 0

            if num_samples < crop_samples:
                # Short clip: zero-pad the tail up to one full crop.
                datum["audio"] = np.pad(datum["audio"], (0, crop_samples - num_samples))
                assert datum["audio"].shape[0] == crop_samples
                yield datum
                continue

            # num_samples >= crop_samples. A clip of exactly crop_samples lands
            # here too and is emitted as a single crop (it used to match
            # neither branch and was silently dropped).
            full_crops = num_samples // crop_samples
            for i in range(full_crops):
                new_datum = {key: datum[key] for key in safe_keys}
                new_datum["audio"] = datum["audio"][
                    i * crop_samples : (i + 1) * crop_samples
                ]
                assert new_datum["audio"].shape[0] == crop_samples
                yield new_datum

            # Keep the remainder only when it is more than half a crop, so that
            # padding never dominates an emitted example.
            leftover = datum["audio"][full_crops * crop_samples :]
            leftover_samples = leftover.shape[0]
            if leftover_samples > crop_samples / 2:
                datum["audio"] = np.pad(
                    leftover, (0, crop_samples - leftover_samples)
                )
                assert datum["audio"].shape[0] == crop_samples

                yield datum

    return crop

In [16]:
# Stages of the streaming pipeline, listed in execution order; each stage
# consumes the output of the one before it.
data_pipeline = [
    # Source: the shard URLs described by `urls`.
    wds.SimpleShardList(urls),
    # Passed uncalled, unlike the other stages. Per the webdataset docs this
    # gives each DataLoader worker its own subset of shards.
    wds.split_by_worker,
    # Shard-level shuffle using the library's default buffer settings.
    wds.shuffle(),
    # Turn each tar shard into per-sample dicts (parse() reads "flac" and "json").
    wds.tarfile_to_samples(),
    # Decode the audio and flatten the JSON metadata into the sample dict.
    wds.map(parse),
    # Re-cut every clip into fixed 3-second crops; one input may yield zero,
    # one, or several outputs.
    cropper(crop_duration=3.0),
    # Sample-level shuffle with a small buffer, so that crops from the same
    # recording are not always adjacent.
    wds.shuffle(32),
    # Keep only these fields, in this order.
    wds.to_tuple("audio", "accent", "gender", "age"),
    # Batching happens here, inside the pipeline, rather than in the loader.
    wds.batched(batchsize=32),
]

In [17]:
# Chain the stages defined in `data_pipeline` into a single dataset object.
webdataset = wds.DataPipeline(*data_pipeline)

In [18]:
# batch_size=None: the pipeline's wds.batched stage already emits complete
# batches, so the loader must not batch again.
# NOTE(review): num_workers=8 and prefetch_factor=12 look hand-tuned for the
# benchmark machine — confirm before reusing elsewhere.
loader = wds.WebLoader(webdataset, num_workers=8, batch_size=None, prefetch_factor=12, pin_memory=True)

In [19]:
# Instantiate the model on the GPU and report its parameter count.
# NOTE(review): the model is never switched to eval mode; if M5 contains
# dropout or batch-norm layers, consider calling .eval() for an inference-only
# benchmark — confirm against the M5 definition.
m5 = M5().cuda()
print(f"Number of parameters of M5: {count_parameters(m5):,}")

Number of parameters of M5: 26,074,659


In [20]:
# Benchmark pass: stream every batch through M5 with autograd disabled.
# The model output is discarded; only loader + forward-pass throughput matters.
#
# Disabled wav2vec2 variant, kept for reference:
#   wav2vec2_features = feature_extractor(audio, sampling_rate=16_000, return_tensors="pt").input_values[0].cuda()
#   with torch.no_grad():
#       wav2vec2_embeddings = wav2vec2(features).last_hidden_state
with torch.no_grad():
    for audio, accent, gender, age in tqdm(loader):
        # Insert a singleton dimension at position 1, then move the batch to the GPU.
        m5(audio[:, None, :].cuda())

5650it [04:17, 21.90it/s]Exception ignored from cffi callback <function SoundFile._init_virtual_io.<locals>.vio_tell at 0x7f6dcef2a840>:
Traceback (most recent call last):
  File "/home/kaandonbekci/miniconda3/envs/ccu/lib/python3.11/site-packages/soundfile.py", line 1264, in vio_tell
    @_ffi.callback("sf_vio_tell")

KeyboardInterrupt


KeyboardInterrupt: 

: 
