In [76]:
import torch
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import mmcv
import matplotlib.pyplot as plt
import os.path as osp
import webdataset as wds
from datasets.formatting import ToDataContainer
from datasets.builder import WordAugTokenizeWrapper, Tokenize
from datasets.tokenizer import SimpleTokenizer
from mmcv.parallel import collate
from functools import partial



In [77]:
# Image pipeline: random-resized crop to 224 px, random horizontal flip,
# tensor conversion, normalisation with the timm ImageNet mean/std, and a
# final ToDataContainer step (project-local; the displayed batch is a
# DataContainer, presumably because of this wrapper -- confirm).
transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ToDataContainer(),
    ]
)

# Caption pipeline: tokenise with SimpleTokenizer to at most 77 tokens,
# wrapped by the project-local word-augmentation helper.
# NOTE(review): max_word/max_key/word_type semantics live in
# datasets.builder -- presumably "pick up to 3 nouns, no key words"; confirm.
text_transform = WordAugTokenizeWrapper(
    Tokenize(SimpleTokenizer(), max_seq_len=77),
    max_word=3,
    max_key=0,
    word_type=['noun'],
)

In [78]:
# Shard directory and the single tar shard read by this notebook.
path = "/workspace/Dataset/local_data/gcc3m_shards/"
prefix = "gcc-train-000000.tar"
tar_file = osp.join(path, prefix)

# Single source of truth for the batch size: it sizes the webdataset batches
# and is the samples_per_gpu handed to mmcv's collate, so the two values can
# no longer drift apart when one of them is edited.
BATCH_SIZE = 16

dataset = (
    wds.WebDataset(tar_file, repeat=True)
    .decode('pil')
    # image1 and image2 are both read from the same image field; each one is
    # then passed through `transform` on its own by map_dict below.
    .rename(image1='jpg;png;jpeg', image2='jpg;png;jpeg', text='text;txt', keep=False,)
    .map_dict(image1=transform, image2=transform, text=text_transform)
)

# Collate function that groups samples into mmcv-style batches.
dc_collate = partial(collate, samples_per_gpu=BATCH_SIZE)

# Batching is done inside the dataset pipeline (.batched), so the loader is
# given batch_size=None and receives already-collated batches.
data_loader = wds.WebLoader(
    dataset.batched(BATCH_SIZE, dc_collate),
    batch_size=None,
    shuffle=False,
)


In [79]:
# Pull one collated batch from the loader for inspection.
data = next(iter(data_loader))

In [85]:
# Display the `image2` entry of the fetched batch.
data["image2"]

DataContainer([tensor([[[[ 2.0605,  2.0605,  2.0777,  ..., -1.0904, -1.0562, -1.1760],
          [ 2.0605,  2.0605,  2.0777,  ..., -0.6623, -0.8507, -1.0390],
          [ 2.0605,  2.0605,  2.0777,  ...,  0.9646,  0.6392,  0.1768],
          ...,
          [-1.8610, -1.8782, -1.8610,  ...,  2.1975,  2.2318,  2.2318],
          [-1.9638, -1.9467, -1.3987,  ...,  2.1975,  2.1975,  2.1975],
          [-1.5185, -0.0801,  0.9132,  ...,  2.2318,  2.2489,  2.2318]],

         [[ 2.1660,  2.1660,  2.1835,  ..., -1.5280, -1.4580, -1.5455],
          [ 2.1660,  2.1660,  2.1835,  ..., -1.3880, -1.5105, -1.6155],
          [ 2.1660,  2.1660,  2.1835,  ..., -0.0224, -0.2325, -0.5651],
          ...,
          [-1.7731, -1.7381, -1.6856,  ...,  2.1134,  2.1485,  2.1485],
          [-1.7906, -1.6155, -0.9153,  ...,  2.1310,  2.1310,  2.1310],
          [-1.1429,  0.5028,  1.7283,  ...,  2.1485,  2.1835,  2.1660]],

         [[ 2.1868,  2.1868,  2.2043,  ..., -1.6302, -1.5604, -1.5953],
          [ 2.1