In [None]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

from dataset_benchmark.datasets import NpzBenchmarkDataset, HfArrowBenchmarkDataset
from dataset_benchmark.evaluation import evaluate_batched_loading

In [None]:
# Call the project's batched-loading evaluation with the sweep settings below.
# Its return value is handed straight to pd.DataFrame in the next statement.
results_batched_dataloading = evaluate_batched_loading(
    dataset_classes=[NpzBenchmarkDataset],
    n_rows=500,
    n_values_per_row=4_000_000,
    batch_sizes=[4, 8],  # 16 and 32 are currently left out of the sweep
    shuffle=[True, False],
    n_dataloader_workers=[0],
    dataloder_pin_memory=[True, False],
    max_dataset_size=100_000_000_000,
    max_batch_size=4_000_000_000,
    n_repeats=2,
    dataset_init_kwargs=None,
)

# Tabulate the results and derive the throughput columns in one chained step.
# The lambdas are evaluated in order, so later columns can read earlier ones.
df_batched = pd.DataFrame(results_batched_dataloading).assign(
    # Factor 8: assumes every value occupies 8 bytes — TODO confirm against
    # the dtype the benchmark datasets actually store.
    total_bytes=lambda frame: frame["n_rows"] * frame["n_values_per_row"] * 8,
    # Decimal megabytes (10**6 bytes) divided by the "time" column.
    mb_per_second=lambda frame: frame["total_bytes"] / frame["time"] / 1_000_000,
    rows_per_second=lambda frame: frame["n_rows"] / frame["time"],
)

In [None]:
# visualize the results for the batched loading


def _plot_batched_metric(results, metric):
    """Draw one faceted line plot of ``metric`` against ``batch_size``.

    Both throughput figures share every visual encoding and differ only in
    the y-axis column, so the encoding is defined once here.

    Parameters
    ----------
    results : pandas.DataFrame
        Frame holding the columns referenced below (``batch_size``,
        ``n_dataloader_workers``, ``shuffle``, ``dataloder_pin_memory``,
        ``datset_type``) plus the column named by ``metric``.
    metric : str
        Name of the column to draw on the y axis.

    Returns
    -------
    The object returned by ``sns.relplot``.
    """
    # NOTE(review): "dataloder_pin_memory" and "datset_type" are spelled as in
    # the original cell; they presumably match the keys emitted by
    # evaluate_batched_loading — confirm before "fixing" the spelling.
    return sns.relplot(
        data=results,
        x="batch_size",
        y=metric,
        # Cast to str — presumably so worker counts are treated as discrete
        # categories rather than a numeric hue scale.
        hue=results["n_dataloader_workers"].astype(str),
        style=results["shuffle"],
        size=results["dataloder_pin_memory"],
        # Percentile interval of width 100 (see seaborn's errorbar docs).
        errorbar=("pi", 100),
        markers=True,
        col="datset_type",
        kind="line",
    )


# One figure per throughput metric, in the same order as before.
for _metric in ("mb_per_second", "rows_per_second"):
    _plot_batched_metric(df_batched, _metric)
    plt.show()