# Datasets utils

In [None]:
#|default_exp hf.datasets.utils

In [None]:
#|hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#|export

from datasets import concatenate_datasets, DatasetDict, load_dataset, Dataset

from bellem.logging import get_logger
from bellem.utils import chunk_random

log = get_logger(__name__)

In [None]:
#|export

def load_datasets(dataset_kwargs_list: list[dict]) -> Dataset:
    """Load one dataset per kwargs dict and concatenate the results.

    Each dict in `dataset_kwargs_list` is unpacked as keyword arguments to
    `load_dataset`; the loaded objects are then passed, in input order, to
    `concatenate_datasets`.

    NOTE(review): presumably every kwargs dict must resolve to a single
    `Dataset` (e.g. by including a `split`) for the concatenation to work --
    confirm against callers.
    """
    loaded = [load_dataset(**kwargs) for kwargs in dataset_kwargs_list]
    return concatenate_datasets(loaded)

In [None]:
#|export

def concatenate_dataset_dicts(dataset_dicts: list[DatasetDict]) -> DatasetDict:
    """
    Concatenate multiple `DatasetDict` objects into a single `DatasetDict`.

    The result contains every split name that appears in at least one input.
    For each split, the datasets of the inputs that have that split are
    concatenated in input order; inputs lacking the split are skipped.

    Split names are kept in first-seen order so the layout of the returned
    `DatasetDict` is identical on every run (iterating a `set` of strings
    yields an order that depends on hash randomization).
    """
    # dict.fromkeys de-duplicates like a set but preserves insertion order.
    splits = dict.fromkeys(split for dd in dataset_dicts for split in dd)
    return DatasetDict(
        {
            split: concatenate_datasets(
                [dd[split] for dd in dataset_dicts if split in dd]
            )
            for split in splits
        }
    )


In [None]:
# Demo: load two dataset dicts, merge them split by split, and print the
# size of the combined train split.
dsds = [
    load_dataset(repo_id)
    for repo_id in (
        "bdsaglam/webnlg-jerx-sft-st-ms-openai",
        "bdsaglam/musique-jerx-sft-st-ms-openai",
    )
]
dsd = concatenate_dataset_dicts(dsds)
print(len(dsd['train']))

Downloading readme:   0%|          | 0.00/545 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.17M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/523k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/807k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/17733 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/2235 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3661 [00:00<?, ? examples/s]

Downloading readme:   0%|          | 0.00/341 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/43.8k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/56 [00:00<?, ? examples/s]

17789


In [None]:
#|export

def chunk_random_dataset(ds, min_chunk=1, max_chunk=3):
    """Yield the rows of `ds` in chunks, each chunk as a list.

    The row positions `range(len(ds))` are grouped by `chunk_random`, which is
    called with `min_chunk` and `max_chunk`. Each group of indices is selected
    from `ds` and converted with `to_list()` before being yielded.

    NOTE(review): presumably `min_chunk`/`max_chunk` bound the size of each
    chunk -- confirm against `bellem.utils.chunk_random`.
    """
    row_positions = range(len(ds))
    index_groups = chunk_random(row_positions, min_chunk, max_chunk)
    for group in index_groups:
        subset = ds.select(group)
        yield subset.to_list()

In [None]:
#|hide
# Run nbdev's export step for this project.
import nbdev
nbdev.nbdev_export()