In [59]:
import json 
from collections import OrderedDict
import glob
import webdataset as wds
from torch.utils.data import DataLoader
from pprint import pprint
import os

In [97]:
def read_json(file):
    with open(file, 'r') as f:
        # data = json.load(f)
        data = json.load(f, object_pairs_hook=OrderedDict)
        
    return data 

def load_sample_simple(sample):
    data = {}
    json_data = json.loads(sample["json"])
    for k, v in json_data.items():
        if v is not None:
            data[k] = v 
            
    image_keys = ['jpg', 'png', 'jpeg', 'bmp', 'tif', 'tiff']
    for k in image_keys:
        if k in sample:
            data["image_key"] = k 
            break 
    
    # # specical for instruction_tune
    # data.update(
    #     {
    #         "question": json_data.get("question", ""),
    #         "response": json_data.get("response", ""),
    #         "instruction": json_data.get("instruction", ""),
    #     }
    # )
    
    # if "response" not in json_data and "text" in json_data:
    #     data["response"] = data["text"]
                
    return data 

def test_wds_iterator(wd_path, num_samples=1):
    search_path = os.path.join(wd_path, "**", "*.tar")
    train_urls = glob.glob(search_path, recursive=True)
    
    # files = glob.glob(wd_path + '/**/*', recursive=True)
    
    dataset = wds.DataPipeline(
        wds.SimpleShardList(train_urls),
        wds.tariterators.tarfile_to_samples(),
        wds.map(load_sample_simple)
    )
    
    dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
    
    for i, sample in enumerate(dataloader):
        if i == num_samples:
            break 
        
        pprint(sample)

In [98]:
config_json = "/fsx_0/user/tranx/rsync/llm_mm_aligner/experiments/aws/mm10/stage2/MH21_70B_224px_0916_exp32a_20240927.json"
config = read_json(config_json)

wd_data_path = config['trainer_args']['wd_data_path']
print(wd_data_path) 

verification_results = {}

for path in wd_data_path:
    path = path.split(":")[0]
    print(f"Testing: {path}")
    try:
        test_wds_iterator(path)
        verification_results[path] = True
    except Exception as e:
        print(e)
        verification_results[path] = False

['/fsx_0/user/yetian12/datasets/pdfa-eng-wds-converted', '/fsx_0/user/yetian12/datasets/idl-wds_v2', '/fsx_1/datasets_30days/sg_mmllm_stage_2_exp31a/20240905', '/fsx_0/datasets_30days/flywheel/sg_mmllm_sft_sg_data_curation_flywheel_6p8k_0916:300', '/fsx_1/datasets_30days/sg_mmllm_stage_2_exp32a_subtable_ocr', '/fsx_1/datasets_30days/sg_mmllm_stage_2_exp32a_subtable_sg_qa', '/fsx_1/datasets_30days/datarecipe_source/sg_mmllm_stage2_compliant_cap_qa_exp28_kosher_v2/20240827v4/sg_mmllm_coco_captions_blur', '/fsx_1/datasets_30days/datarecipe_source/sg_mmllm_stage2_compliant_cap_qa_exp28_kosher_v2/20240827v4/sg_mmllm_coco_object_count_rewritten_cap_46k:25', '/fsx_1/datasets_30days/mmllm_m2c2_mitigated_id_inc_filtered_clip_0_36_unique_id']
Testing: /fsx_0/user/yetian12/datasets/pdfa-eng-wds-converted
{'__key__': ['0004531-0'],
 'image_key': ['png'],
 'text': ['Pliocene Ridge Community Services District Policies & Procedures: '
          'OPERATIONS Policy 3015 NON-DISCRIMINATION/AFFIRMATIVE A

In [99]:
verification_results

{'/fsx_0/user/yetian12/datasets/pdfa-eng-wds-converted': True,
 '/fsx_0/user/yetian12/datasets/idl-wds_v2': True,
 '/fsx_1/datasets_30days/sg_mmllm_stage_2_exp31a/20240905': True,
 '/fsx_0/datasets_30days/flywheel/sg_mmllm_sft_sg_data_curation_flywheel_6p8k_0916': True,
 '/fsx_1/datasets_30days/sg_mmllm_stage_2_exp32a_subtable_ocr': True,
 '/fsx_1/datasets_30days/sg_mmllm_stage_2_exp32a_subtable_sg_qa': True,
 '/fsx_1/datasets_30days/datarecipe_source/sg_mmllm_stage2_compliant_cap_qa_exp28_kosher_v2/20240827v4/sg_mmllm_coco_captions_blur': True,
 '/fsx_1/datasets_30days/datarecipe_source/sg_mmllm_stage2_compliant_cap_qa_exp28_kosher_v2/20240827v4/sg_mmllm_coco_object_count_rewritten_cap_46k': True,
 '/fsx_1/datasets_30days/mmllm_m2c2_mitigated_id_inc_filtered_clip_0_36_unique_id': True}

# Test wds

In [110]:
import re 
import random 

def _get_sampling_multiplier(path: str) -> tuple[str, float]:
    """Split path and get sampling multiplier, if it exists.
    :param path: path with or without a sampling multiplier
    :return: path without the multiplier, sampling multiplier or None
    Examples:
        _get_sampling_multiplier("foo") -> "foo", 1.0
        _get_sampling_multiplier("foo:0.5") -> "foo", 0.5
        _get_sampling_multiplier("foo:1.5") -> "foo", 1.5
        _get_sampling_multiplier("foo:2") -> "foo", 2.0
    """
    pattern = r"^[^\s:]+\:(\d*\.)?\d+$"
    match = re.search(pattern, path)
    if match:
        path, multiplier = path.split(":")
        return (path, float(multiplier))
    path = path.split(" ")[0].split(":")[0].strip()
    return (path, 1.0)

def _update_data_filters(
    wd_train_urls: list[str],
    data_filters: dict[str, dict],
    dataset_config: dict[str, dict],
) -> None:
    """Update data filters for each tar file url"""
    for url in wd_train_urls:
        data_filters[url] = {
            "target_keys": dataset_config.get("target_keys", None),
            "rename_map": dataset_config.get("rename_map", None),
        }
        
def _get_train_urls_from_datarecipe(
    wd_datarecipe: dict[str, dict],
    data_seed: int | None,
) -> tuple[list[str], dict[str, dict]]:
    """Get train urls from wd_datarecipe.
    Args:
        wd_datarecipe: a mixture of data sources, w/o a sampling multiplier
        data_seed: random seed to use for sampling or None
    Returns:
        train_urls: list of paths
        data_filters: dict of data filters for each url
    Examples:
        "wd_datarecipe": {
            "dataset1": {
                "path": 'path3',                        # required, str or list
                "chunk_size": 1000,                     # required, int
                "multiplier": 2,                        # optional, int, float, or list, default=1,
                "target_keys": ["caption", "response"], # optional, default to load all
                "rename_map": {"caption": "text"}       # optional, default with no rename
            },
            "dataset2": {
                "path": ['/path4', '/path5']            # required, str or list
                "chunk_size": 1000,                     # required, int
                "multiplier": [0.5, 2],                 # optional, int, float, or list, default=1,
                "target_keys": ["caption", "response"], # optional, default to load all
                "rename_map": {"caption": "text"}       # optional, default with no rename
            },
        }
    """
    train_urls = []
    data_filters = {}

    for dataset_name, dataset_config in wd_datarecipe.items():
        path = dataset_config.get("paths", None)
        if not path:
            raise ValueError(f"Missing path for dataset {dataset_name}")

        # datasets with single path
        if isinstance(path, str):
            multiplier = dataset_config.get("multiplier", 1.0)
            wd_train_urls = _sample_webds_dataset(
                path,
                multiplier,
                data_seed,
            )
            train_urls.extend(wd_train_urls)
            _update_data_filters(list(set(wd_train_urls)), data_filters, dataset_config)

        # datasets with multiple paths
        elif isinstance(path, list):
            multipliers = dataset_config.get("multiplier", [1.0] * len(path))

            # make multipliers the same length as path
            if len(multipliers) != len(path):
                raise ValueError(
                    "number of multipliers must be less than or equal to the number of paths"
                )

            for single_path, multiplier in zip(path, multipliers):
                wd_train_urls = _sample_webds_dataset(
                    single_path,
                    multiplier,
                    data_seed,
                )
                train_urls.extend(wd_train_urls)
                # each url has unique data filters
                _update_data_filters(
                    list(set(wd_train_urls)), data_filters, dataset_config
                )
        else:
            raise ValueError(
                f"Invalid type for path in dataset {dataset_name}. Expected str or list, got {type(path)}."
            )
    return (train_urls, data_filters)


def _sample_webds_dataset(
    path: str, multiplier: float, data_seed: int | None
) -> list[str]:
    """Sample the dataset based on the multiplier.

    :param path: tar file url
    :param multiplier: sampling multiplier
        when float between 0.0 and 1.0, subsample
        when integer >= 1, oversample
        when float > 1.0, oversample + random sample
    :param data_seed: random seed to use for sampling or None
        when not None, initialize random sampler with the specified seed
        when None, initialize random sampler with seed=0
    """
    if multiplier <= 0.0:
        raise ValueError("wd_data_sampling_multiplier must be greater than 0.0")

    search_path = os.path.join(path, "**", "*.tar")
    train_urls = glob.glob(search_path, recursive=True)

    rand = random.Random(data_seed or 0)
    if multiplier < 1.0:
        # subsampling the data
        train_urls = rand.sample(train_urls, int(len(train_urls) * multiplier))
    elif multiplier.is_integer():
        # oversampling the data, when multplier is 1, 2, 3...
        train_urls = train_urls * int(multiplier)
    else:
        # oversampling the data + random sampling when multiplier is 1.5, 2.7, 3.1...
        q, r = divmod(multiplier, 1)
        train_urls = train_urls * int(q) + rand.sample(
            train_urls, int(len(train_urls) * r)
        )
    return train_urls

def _get_train_urls(
    wd_data_path: str | list[str] | None,
    wd_datarecipe: dict[str, dict] | None,
    data_seed: int | None,
) -> tuple[list[str], dict[str, dict]]:
    train_urls = []
    data_filters = {}
    if not wd_data_path and not wd_datarecipe:
        return train_urls, data_filters

    if wd_data_path:
        if isinstance(wd_data_path, list):
            path_list = wd_data_path
        else:
            path_list = wd_data_path.split(",")
        for path in path_list:
            path, multiplier = _get_sampling_multiplier(path.strip())
            wd_train_urls = _sample_webds_dataset(
                path,
                multiplier,
                data_seed,
            )
            train_urls.extend(wd_train_urls)

    if wd_datarecipe:
        wd_train_urls, data_filters = _get_train_urls_from_datarecipe(
            wd_datarecipe, data_seed
        )
        train_urls.extend(wd_train_urls)
    return train_urls, data_filters

In [111]:
config_json = "/fsx_0/user/tranx/rsync/llm_mm_aligner/experiments/aws/mm10/stage2/MH21_70B_336px_exp32a_20240930.json"
config = read_json(config_json)

wd_data_path = config['trainer_args']['wd_data_path']

train_urls, data_filters = _get_train_urls(wd_data_path, None, None)

In [112]:
train_urls

['/fsx_0/user/yetian12/datasets/pdfa-eng-wds-converted/converted_pdfa-eng-train-0897.tar',
 '/fsx_0/user/yetian12/datasets/pdfa-eng-wds-converted/converted_pdfa-eng-train-0342.tar',
 '/fsx_0/user/yetian12/datasets/pdfa-eng-wds-converted/converted_pdfa-eng-train-1502.tar',
 '/fsx_0/user/yetian12/datasets/pdfa-eng-wds-converted/converted_pdfa-eng-train-1163.tar',
 '/fsx_0/user/yetian12/datasets/pdfa-eng-wds-converted/converted_pdfa-eng-train-0084.tar',
 '/fsx_0/user/yetian12/datasets/pdfa-eng-wds-converted/converted_pdfa-eng-train-0723.tar',
 '/fsx_0/user/yetian12/datasets/pdfa-eng-wds-converted/converted_pdfa-eng-train-0055.tar',
 '/fsx_0/user/yetian12/datasets/pdfa-eng-wds-converted/converted_pdfa-eng-train-1615.tar',
 '/fsx_0/user/yetian12/datasets/pdfa-eng-wds-converted/converted_pdfa-eng-train-1274.tar',
 '/fsx_0/user/yetian12/datasets/pdfa-eng-wds-converted/converted_pdfa-eng-train-0846.tar',
 '/fsx_0/user/yetian12/datasets/pdfa-eng-wds-converted/converted_pdfa-eng-train-0434.tar',

In [108]:
"wd_datarecipe": [
	"dataset1": {
		"path": 'path1', #required, str or list
		"chunk_size": 1000, #required, str or list
		"multiplier": 2 # optional, default=1, must be same as length as path if provided
		"target_keys": ["caption", "response"], # optional, default load all
		"rename_keys": {"caption": "text"} # optional, default no rename
       },
]