In [None]:
import itertools
import json
import math
from typing import List, Union
import torchvision.transforms.functional as TF
import webdataset as wds
from braceexpand import braceexpand
from torch.utils.data import default_collate
from torchvision import transforms
from webdataset.tariterators import (
    base_plus_ext,
    tar_file_expander,
    url_opener,
    valid_sample,
)

In [None]:
def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
    """Expand tar shards into grouped samples without raising on repeated suffixes.

    Re-implementation of the webdataset tar-to-samples step that groups files
    with `group_by_keys_nothrow` instead of the throwing webdataset grouper.

    :param src: shard source stream handed to `url_opener`
    :param handler: error handler forwarded to every stage of the chain
    :return: iterator over sample dicts produced by `group_by_keys_nothrow`
    """
    opened_shards = url_opener(src, handler=handler)
    tar_members = tar_file_expander(opened_shards, handler=handler)
    return group_by_keys_nothrow(tar_members, handler=handler)


def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Group consecutive (file name, data) entries into per-sample dicts.

    This is a generator: it consumes `data` and yields one dict per sample.

    :param data: iterable of dicts carrying "fname", "data" and "__url__" entries
    :param keys: function splitting a file name into (prefix, suffix) (Default value = base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    :param suffixes: if not None, only suffixes contained in it are stored in a sample
    :param handler: accepted for signature compatibility; not used by this function
    :return: yields dicts with "__key__", "__url__" and one entry per stored suffix
    """
    pending = None
    for entry in data:
        assert isinstance(entry, dict)
        fname = entry["fname"]
        value = entry["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            # File names that cannot be split into prefix/suffix are skipped.
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME the webdataset version throws when the suffix is already present in the
        #  current sample. That can happen in the current LAION400m dataset when a tar ends
        #  with the same prefix the next one begins with (rare, but prefixes are not unique
        #  across tar files there), so a repeated suffix starts a new sample instead.
        starts_new_sample = (
            pending is None
            or prefix != pending["__key__"]
            or suffix in pending
        )
        if starts_new_sample:
            if valid_sample(pending):
                yield pending
            pending = {"__key__": prefix,
                       "__url__": entry["__url__"]}
        if suffixes is None or suffix in suffixes:
            pending[suffix] = value
    # Flush the sample that was still being assembled when the input ended.
    if valid_sample(pending):
        yield pending


# NOTE(review): `tarfile_to_samples_nothrow` is also defined earlier in this file with an
# equivalent body; this later definition replaces the earlier one when the cell runs.
def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
    """Open shard URLs, expand the tar members and group them into samples.

    Re-implementation of the webdataset step that uses `group_by_keys_nothrow`,
    which does not throw on a repeated suffix.

    :param src: shard source stream handed to `url_opener`
    :param handler: error handler forwarded to each stage
    :return: iterator over grouped sample dicts
    """
    return group_by_keys_nothrow(
        tar_file_expander(url_opener(src, handler=handler), handler=handler),
        handler=handler,
    )


def filter_keys(key_set):
    """Build a function that keeps only the entries of a dict whose key is in `key_set`.

    :param key_set: container of keys to keep (anything supporting `in`)
    :return: function mapping a dict to a new dict restricted to `key_set`,
        preserving the input's key order
    """
    def _select(dictionary):
        kept = {}
        for name, item in dictionary.items():
            if name in key_set:
                kept[name] = item
        return kept

    return _select


class WebdatasetFilter:
    """Predicate that keeps samples which are large enough and unlikely to be watermarked.

    The sample's "json" entry is parsed and three fields are read:
    "original_width", "original_height" and "pwatermark".

    :param min_size: minimum value required for both original width and height
    :param max_pwatermark: maximum accepted "pwatermark" value
    """

    def __init__(self, min_size=1024, max_pwatermark=0.5):
        self.min_size = min_size
        self.max_pwatermark = max_pwatermark

    def __call__(self, x):
        """Return True when sample `x` passes the size and watermark checks.

        Samples without a "json" entry, with unparsable metadata, or with values
        that cannot be compared are rejected (False) rather than raising.

        :param x: sample dict; its "json" entry must be a JSON document
        :return: bool
        """
        try:
            if "json" not in x:
                return False
            meta = json.loads(x["json"])

            width = meta.get("original_width")
            height = meta.get("original_height")
            pwatermark = meta.get("pwatermark")

            # Missing or null sizes count as 0. Width and height are treated the
            # same way (the height lookup previously had no null guard).
            if width is None:
                width = 0.0
            if height is None:
                height = 0.0
            # Only a missing or null probability falls back to the pessimistic 1.0.
            # A real value of 0.0 must stay 0.0: `value or 1.0` would turn it into
            # 1.0 and reject samples that are certainly not watermarked.
            if pwatermark is None:
                pwatermark = 1.0

            large_enough = width >= self.min_size and height >= self.min_size
            clean_enough = pwatermark <= self.max_pwatermark
            return bool(large_enough and clean_enough)
        except Exception:
            # Deliberate best-effort: any malformed sample is filtered out.
            return False


class Text2ImageDataset:
    """WebDataset-backed image/caption source for text-to-image training.

    Builds a `wds.DataPipeline` over resampled tar shards and wraps it in a
    `wds.WebLoader`. Every batch is the collated tuple
    ``(image, orig_size, crop_coords)``, where ``orig_size`` holds
    ``(width, height, caption)`` and ``crop_coords`` holds ``(top, left)``.

    :param train_shards_path_or_url: one shard spec, or a list of specs; each
        entry of a list is brace-expanded and the results are flattened
    :param num_train_examples: number of examples used to size one epoch
    :param per_gpu_batch_size: batch size produced by each dataloader worker
    :param global_batch_size: batch size used to derive batch/sample counts
    :param num_workers: number of dataloader workers
    :param resolution: side length images are resized and cropped to
    :param shuffle_buffer_size: size of the sample shuffle buffer
    :param pin_memory: forwarded to the loader
    :param persistent_workers: forwarded to the loader
    :param use_fix_crop_and_size: when True, report ``(resolution, resolution)``
        as the original size and ``(0, 0)`` as the crop coordinates
    """

    def __init__(
        self,
        train_shards_path_or_url: Union[str, List[str]],
        num_train_examples: int,
        per_gpu_batch_size: int,
        global_batch_size: int,
        num_workers: int,
        resolution: int = 1024,
        shuffle_buffer_size: int = 1000,
        pin_memory: bool = False,
        persistent_workers: bool = False,
        use_fix_crop_and_size: bool = False,
    ):
        if not isinstance(train_shards_path_or_url, str):
            train_shards_path_or_url = [
                list(braceexpand(urls)) for urls in train_shards_path_or_url]
            # flatten list using itertools
            train_shards_path_or_url = list(
                itertools.chain.from_iterable(train_shards_path_or_url))

        def get_orig_size(metadata):
            """Map a sample's decoded json metadata to (width, height, caption)."""
            # `wds.map_dict` calls this with the metadata only, so the constructor's
            # `resolution` / `use_fix_crop_and_size` are read from the enclosing scope.
            # Taking them as defaulted parameters meant the constructor values were
            # never applied.
            caption = metadata.get("caption", "")
            caption = "" if caption is None else str(caption)
            if use_fix_crop_and_size:
                return (resolution, resolution, caption)
            # `or 0` keeps null width/height metadata from crashing int().
            width = int(metadata.get("original_width") or 0)
            height = int(metadata.get("original_height") or 0)
            return (width, height, caption)

        def transform(example):
            """Resize, randomly crop, tensorize and normalize the sample's image."""
            # resize image
            image = example["image"]
            image = TF.resize(
                image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)

            # get crop coordinates and crop image
            c_top, c_left, _, _ = transforms.RandomCrop.get_params(
                image, output_size=(resolution, resolution))
            image = TF.crop(image, c_top, c_left, resolution, resolution)
            image = TF.to_tensor(image)
            image = TF.normalize(image, [0.5], [0.5])

            example["image"] = image
            # NOTE: the crop itself stays random; only the reported coordinates are fixed.
            example["crop_coords"] = (
                c_top, c_left) if not use_fix_crop_and_size else (0, 0)
            return example

        processing_pipeline = [
            wds.decode("pil", handler=wds.ignore_and_continue),
            wds.rename(image="jpg;png;jpeg;webp", orig_size="json",
                       handler=wds.warn_and_continue),
            wds.map(filter_keys({"image", "orig_size"})),
            wds.map_dict(orig_size=get_orig_size),
            wds.map(transform),
            wds.to_tuple("image", "orig_size", "crop_coords"),
        ]

        # Create train dataset and loader
        pipeline = [
            wds.ResampledShards(train_shards_path_or_url),
            tarfile_to_samples_nothrow,
            # wds.select(WebdatasetFilter(min_size=960)),
            wds.shuffle(shuffle_buffer_size),
            *processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False,
                        collation_fn=default_collate),
        ]

        # With num_workers == 0 the loader iterates the dataset in the main process,
        # i.e. there is effectively one producer; counting it as 1 avoids a
        # ZeroDivisionError below.
        effective_workers = max(num_workers, 1)
        num_worker_batches = math.ceil(
            num_train_examples / (global_batch_size * effective_workers))  # per dataloader worker
        num_batches = num_worker_batches * effective_workers
        num_samples = num_batches * global_batch_size

        # each worker is iterating over this
        self._train_dataset = wds.DataPipeline(
            *pipeline).with_epoch(num_worker_batches)
        self._train_dataloader = wds.WebLoader(
            self._train_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )
        # add meta-data to dataloader instance for convenience
        self._train_dataloader.num_batches = num_batches
        self._train_dataloader.num_samples = num_samples

    @property
    def train_dataset(self):
        """The underlying `wds.DataPipeline`."""
        return self._train_dataset

    @property
    def train_dataloader(self):
        """The `wds.WebLoader`; carries `num_batches` and `num_samples` attributes."""
        return self._train_dataloader

In [None]:
# Smoke test: build the dataset from a single local shard and print a few batches.
# NOTE(review): the shard path is an absolute local path; adjust it for your environment.
train_shards_path_or_url = "/home/dataset/songyun/songyun_small.tar"
max_train_samples = 10
train_batch_size = 4
dataloader_num_workers = 1
resolution = 1024
use_fix_crop_and_size = False

dataset = Text2ImageDataset(
    train_shards_path_or_url=train_shards_path_or_url,
    num_train_examples=max_train_samples,
    per_gpu_batch_size=train_batch_size,
    # single-process demo: the global batch equals the per-GPU batch
    global_batch_size=train_batch_size,
    num_workers=dataloader_num_workers,
    resolution=resolution,
    shuffle_buffer_size=1000,
    pin_memory=True,
    persistent_workers=True,
    use_fix_crop_and_size=use_fix_crop_and_size,
)

train_dataloader = dataset.train_dataloader

for step, batch in enumerate(train_dataloader):
    # batch is (image, orig_size, crop_coords); orig_size is (width, height, caption)
    image = batch[0]
    size_and_caption = batch[1]
    text = size_and_caption[2]
    orig_size = [size_and_caption[0], size_and_caption[1]]
    crop_coords = batch[2]
    print(image.shape, text, orig_size, crop_coords)