# 4. Benchmark `webdataset`

In [8]:
import webdataset as wds
import io
import numpy as np
from pathlib import Path
import soundfile as sf
import json
import IPython.display as ipd
from utils import time_me, SAMPLE_RATE
from benchmarks import run_all_benchmarks, run_cpu_benchmark

In [2]:
def webdataset_parser():
    """Build a webdataset stage that decodes each raw sample into a flat dict.

    For every sample the stage:
      * decodes the FLAC bytes stored under ``"flac"`` into a float32 numpy
        array, peak-normalizes it and stores it as ``"audio"``;
      * merges the key/value pairs of the ``"json"`` metadata into the sample,
        except ``"sentence"`` and ``"duration"``, which are dropped;
      * removes the raw ``"flac"`` and ``"json"`` entries.

    Returns:
        A ``wds.map`` pipeline stage applying the parser to each sample.
    """
    # Metadata fields that are not needed downstream.
    dropped_keys = ("sentence", "duration")

    def parse(datum):
        audio, _ = sf.read(io.BytesIO(datum.pop("flac")), dtype="float32")
        # Peak normalize. Guard against silent (all-zero) clips, where dividing
        # by a zero peak would turn the whole clip into NaNs, and against empty
        # clips, where max() of an empty array raises.
        peak = np.abs(audio).max() if audio.size else 0.0
        if peak > 0:
            audio = audio / peak
        datum["audio"] = audio
        # Skip unwanted fields instead of deleting them afterwards, so metadata
        # that lacks one of them does not raise KeyError.
        for key, val in json.loads(datum.pop("json")).items():
            if key not in dropped_keys:
                datum[key] = val
        return datum

    return wds.map(parse)

In [3]:
def webdataset_cropper(crop_duration: float):
    """Build a webdataset stage that forces every clip to a fixed length.

    Clips shorter than the target are zero-padded at the end. Longer clips are
    cut to a window starting at a uniformly random offset. A leading axis is
    then added, so mono audio of shape ``(n,)`` becomes ``(1, crop_samples)``.

    Args:
        crop_duration: Target clip length in seconds, converted to samples
            with ``SAMPLE_RATE``.

    Returns:
        A ``wds.map`` pipeline stage applying the crop to each sample.

    Raises:
        ValueError: If ``crop_duration`` corresponds to fewer than one sample.
    """
    crop_samples = int(SAMPLE_RATE * crop_duration)
    if crop_samples <= 0:
        raise ValueError(
            f"crop_duration={crop_duration} gives {crop_samples} samples; must be > 0"
        )

    def crop(datum):
        audio = datum["audio"]
        num_samples = audio.shape[0]
        assert num_samples != 0
        if num_samples < crop_samples:
            # pad if input is short
            audio = np.pad(audio, (0, crop_samples - num_samples))
        elif num_samples > crop_samples:
            # crop if it is too long. randint's upper bound is exclusive, so +1
            # makes the last valid window (ending on the final sample) reachable.
            rand_start = np.random.randint(0, num_samples - crop_samples + 1)
            audio = audio[rand_start : rand_start + crop_samples]
        assert audio.shape[0] == crop_samples
        datum["audio"] = audio[None, :]
        return datum

    return wds.map(crop)

In [4]:
def webdataset_redict():
    """Build a webdataset stage that relabels a positional batch as a dict.

    Each incoming item is indexed positionally: index 0 becomes ``"audio"``,
    1 becomes ``"accent"``, 2 becomes ``"age"`` and 3 becomes ``"gender"``.
    Any positions past 3 are ignored.

    Returns:
        A ``wds.map`` pipeline stage applying the conversion.
    """
    field_names = ("audio", "accent", "age", "gender")

    def redict(datum):
        return {name: datum[position] for position, name in enumerate(field_names)}

    return wds.map(redict)

In [5]:
@time_me
def build_webdataset_dataloader(
    urls,
    crop_duration=3.0,
    batch_size=32,
    shuffle_buffer=2048,
    num_workers=4,
    pin_memory=True,
):
    """Build a multi-worker and a single-process loader over one webdataset pipeline.

    Pipeline stages: shuffle the shard list, split shards across workers, read
    tar samples, shuffle at the sample level, decode/parse, crop to a fixed
    length, batch, and convert each batch tuple into a dict.

    Args:
        urls: Shard spec accepted by ``wds.SimpleShardList``, e.g. a
            brace-range pattern such as ``"shard_{000..009}.tar"``.
        crop_duration: Clip length in seconds after cropping/padding.
        batch_size: Samples per batch. Batching happens inside the pipeline,
            so both loaders use ``batch_size=None``.
        shuffle_buffer: Size of the sample-level shuffle buffer, also used as
            its initial fill level.
        num_workers: Worker processes for the first (multi-worker) loader.
        pin_memory: Forwarded to both loaders.

    Returns:
        ``(dataloader, single_worker_dataloader)``, both iterating the same
        pipeline.
    """
    data_pipeline = [
        wds.SimpleShardList(urls),
        wds.shuffle(),
        wds.split_by_worker,
        wds.tarfile_to_samples(),
        wds.shuffle(bufsize=shuffle_buffer, initial=shuffle_buffer),
        webdataset_parser(),
        webdataset_cropper(crop_duration=crop_duration),
        # The order must match the positions webdataset_redict reads
        # (0 audio, 1 accent, 2 age, 3 gender); the previous
        # ("audio", "accent", "gender", "age") order swapped age and gender.
        wds.to_tuple("audio", "accent", "age", "gender"),
        wds.batched(batchsize=batch_size),
        webdataset_redict(),
    ]
    webdataset = wds.DataPipeline(*data_pipeline)
    dataloader = wds.WebLoader(
        webdataset,
        num_workers=num_workers,
        batch_size=None,
        # prefetch_factor is only valid with worker processes; without them
        # the loader rejects anything but None.
        prefetch_factor=2 if num_workers > 0 else None,
        pin_memory=pin_memory,
    )
    single_worker_dataloader = wds.WebLoader(
        webdataset,
        num_workers=0,
        batch_size=None,
        prefetch_factor=None,
        pin_memory=pin_memory,
    )
    return dataloader, single_worker_dataloader

In [6]:
# Shard patterns for the benchmark. webdataset expands each `{start..end}`
# brace range into one URL per shard: 836 shards for the regular sets and
# 26,729 shards for the "small chunk" sets.

# Local copies under ./data
local_urls = "./data/webdataset/common_voice_{000..835}.tar"
local_urls_small_chunk = (
    "./data/webdataset_small_chunk/common_voice_{00000..26728}.tar"
)

# The same shards served from the GCS bucket over HTTPS
cloud_urls = (
    "https://storage.googleapis.com/hpml-project/"
    "webdataset/common_voice_{000..835}.tar"
)
cloud_urls_small_chunk = (
    "https://storage.googleapis.com/hpml-project/"
    "webdataset_small_chunk/common_voice_{00000..26728}.tar"
)

In [9]:
# Benchmark the 836-shard set read from the local ./data directory.
tag = "webdataset_local"
dataloader, single_worker_dataloader = build_webdataset_dataloader(local_urls)
run_all_benchmarks(
    dataloader,
    single_worker_dataloader,
    tag=tag,
)

Number of parameters of model: 64,628,259


STAGE:2023-12-17 18:22:10 3417434:3417434 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
100%|███████████████████████████████████████████████████████████████████████████████████| 200/200 [00:27<00:00,  7.27it/s]
STAGE:2023-12-17 18:22:37 3417434:3417434 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-12-17 18:22:37 3417434:3417434 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


Time to 1st batch: 1.87 seconds


In [10]:
# Benchmark the 26,729-shard "small chunk" set read from the local ./data directory.
tag = "webdataset_local_small_chunk"
dataloader, single_worker_dataloader = build_webdataset_dataloader(local_urls_small_chunk)
run_all_benchmarks(
    dataloader,
    single_worker_dataloader,
    tag=tag,
)

STAGE:2023-12-17 18:25:49 3417434:3417434 ActivityProfilerController.cpp:311] Completed Stage: Warm Up


Number of parameters of model: 64,628,259


100%|███████████████████████████████████████████████████████████████████████████████████| 200/200 [00:27<00:00,  7.36it/s]
STAGE:2023-12-17 18:26:16 3417434:3417434 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-12-17 18:26:16 3417434:3417434 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


Time to 1st batch: 2.11 seconds


In [11]:
# Benchmark the 836-shard set fetched over HTTPS from storage.googleapis.com.
tag = "webdataset_cloud"
dataloader, single_worker_dataloader = build_webdataset_dataloader(cloud_urls)
run_all_benchmarks(
    dataloader,
    single_worker_dataloader,
    tag=tag,
)

Number of parameters of model: 64,628,259


STAGE:2023-12-17 18:26:39 3417434:3417434 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
100%|███████████████████████████████████████████████████████████████████████████████████| 200/200 [00:58<00:00,  3.39it/s]
STAGE:2023-12-17 18:27:38 3417434:3417434 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-12-17 18:27:38 3417434:3417434 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


Time to 1st batch: 11.52 seconds


In [12]:
# Benchmark the 26,729-shard "small chunk" set fetched over HTTPS from storage.googleapis.com.
tag = "webdataset_cloud_small_chunk"
dataloader, single_worker_dataloader = build_webdataset_dataloader(cloud_urls_small_chunk)
run_all_benchmarks(
    dataloader,
    single_worker_dataloader,
    tag=tag,
)

STAGE:2023-12-17 18:28:25 3417434:3417434 ActivityProfilerController.cpp:311] Completed Stage: Warm Up


Number of parameters of model: 64,628,259


100%|███████████████████████████████████████████████████████████████████████████████████| 200/200 [01:40<00:00,  1.99it/s]
STAGE:2023-12-17 18:30:05 3417434:3417434 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-12-17 18:30:05 3417434:3417434 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


Time to 1st batch: 36.23 seconds
