In [1]:
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import io
import os
import urllib

import PIL.Image

from datasets import load_dataset
from datasets.utils.file_utils import get_datasets_user_agent


# User-agent string obtained from the `datasets` helper; it is sent as the
# "user-agent" header on the image requests made below.
USER_AGENT = get_datasets_user_agent()

# Downloads are I/O-bound, so use several threads per CPU core.
# os.cpu_count() returns None when the core count cannot be determined;
# fall back to 1 so the multiplication cannot raise TypeError.
num_threads = (os.cpu_count() or 1) * 5


def fetch_single_image(image_url, timeout=None, retries=0):
    """Download one image and open it with PIL on a best-effort basis.

    Args:
        image_url: URL of the image to download.
        timeout: Timeout value passed through to ``urllib.request.urlopen``.
        retries: Number of extra attempts made after a failed one. A
            negative value results in no attempt at all.

    Returns:
        The image opened by ``PIL.Image.open``, or ``None`` when every
        attempt failed (or no attempt was made).
    """
    # `import urllib` on its own does not guarantee that the `request`
    # submodule is loaded. Import it explicitly: otherwise the resulting
    # AttributeError would be swallowed by the broad `except` below and every
    # image would silently come back as None.
    import urllib.request

    # Bind the result up front so the function still returns None (instead of
    # raising UnboundLocalError) when `retries` is negative and the loop body
    # never runs.
    image = None
    for _ in range(retries + 1):
        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": USER_AGENT},
            )
            with urllib.request.urlopen(request, timeout=timeout) as req:
                # Read the whole payload into memory so the image does not
                # depend on the HTTP response staying open.
                image = PIL.Image.open(io.BytesIO(req.read()))
            break
        except Exception:
            # Deliberately best-effort: any failure (network error, bad URL,
            # unreadable image) marks this attempt as failed and moves on.
            image = None
    return image


def fetch_images(batch, num_threads, timeout=None, retries=0):
    """Download the images of a whole batch concurrently.

    Args:
        batch: Mapping with an "image_url" entry holding the URLs to fetch.
        num_threads: Maximum number of worker threads in the pool.
        timeout: Forwarded to ``fetch_single_image`` for every URL.
        retries: Forwarded to ``fetch_single_image`` for every URL.

    Returns:
        The same ``batch`` object, with an "image" entry added that holds one
        result per URL, in the same order as "image_url".
    """
    def _download(url):
        # Bind the per-call options once so the pool only has to pass the URL.
        return fetch_single_image(url, timeout=timeout, retries=retries)

    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        results = pool.map(_download, batch["image_url"])
        # Consume the iterator while the pool is still open.
        batch["image"] = list(results)
    return batch


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Fetch the images for the dataset loaded without a `split` argument,
# 100 rows per batch.
load_dataset(
    'conceptual_captions',
    cache_dir='E:/Datasets/conceptual',
).map(
    fetch_images,
    batched=True,
    batch_size=100,
    fn_kwargs={"num_threads": num_threads},
)

In [2]:
# Download the images for the validation split and persist the result so
# later cells can reload it with `load_from_disk` instead of re-fetching.
load_dataset(
    'conceptual_captions',
    split='validation',
    cache_dir='E:/Datasets/conceptual',
).map(
    fetch_images,
    batched=True,
    # `load_from_cache_file` is a boolean flag, not a path. The cache location
    # is already set through `cache_dir` above; the path string previously
    # passed here only worked because it is truthy.
    load_from_cache_file=True,
    batch_size=1024,
    fn_kwargs={"num_threads": num_threads},
).save_to_disk(
    'E:/Datasets/conceptual_captions_validation'
)

No config specified, defaulting to: conceptual_captions/unlabeled
Found cached dataset conceptual_captions (E:/Datasets/conceptual/conceptual_captions/unlabeled/1.0.0/05266784888422e36944016874c44639bccb39069c2227435168ad8b02d600d8)
100%|██████████| 16/16 [45:58<00:00, 172.40s/ba]


In [109]:
from datasets import load_from_disk
# Reload the validation split (with fetched images) from the directory it was
# saved to with `save_to_disk`.
cc_val = load_from_disk('E:/Datasets/conceptual_captions_validation/')
# Number of rows, including rows whose image could not be fetched.
len(cc_val)

15840

In [110]:
# Keep only rows that actually have an image (the fetch step stores None on
# failure) and whose image is in RGB mode.
cc_val_filterd = cc_val.filter(lambda x: x['image'] is not None and x['image'].mode == 'RGB')
# Number of rows that survive the filter.
len(cc_val_filterd)

100%|██████████| 16/16 [00:39<00:00,  2.49s/ba]


12913

In [29]:
import requests
from PIL import Image
from torchvision import transforms

In [30]:
# Sanity check: fetch the first row's image with urllib and confirm that it
# converts to a tensor.
request = urllib.request.Request(
    cc_val[0]['image_url'], data=None, headers={"user-agent": USER_AGENT}
)
with urllib.request.urlopen(request, timeout=None) as req:
    # The payload is read fully into memory here, so the conversion below
    # does not need the response to stay open.
    image = PIL.Image.open(io.BytesIO(req.read()))
image = transforms.ToTensor()(image)
print(type(image))

<class 'torch.Tensor'>


In [32]:
# Sanity check: fetch the first row's image with `requests` and confirm that
# it converts to a tensor.
# The response is used as a context manager so the streamed connection is
# released afterwards, and the status is checked so an HTTP error surfaces as
# such rather than as an image-decoding failure.
with requests.get(cc_val[0]['image_url'], stream=True) as response:
    response.raise_for_status()
    # Convert inside the block: PIL reads from the raw stream lazily, so the
    # pixel data must be consumed before the connection is closed.
    img = transforms.ToTensor()(Image.open(response.raw))
print(type(img))

<class 'torch.Tensor'>


In [34]:
from PIL import Image, JpegImagePlugin
# Sanity check: inspect the type of the image stored in the saved dataset.
print(type(cc_val[0]['image']))
# The stored image should convert to a tensor just like a freshly fetched one.
image_loaded = transforms.ToTensor()(cc_val[0]['image'])
print(type(image_loaded))

<class 'PIL.JpegImagePlugin.JpegImageFile'>
<class 'torch.Tensor'>


In [132]:
import torch
from transformers import PerceiverTokenizer

tokenizer = PerceiverTokenizer()


def collate_fn(batch):
    """Merge a list of dataset items into one batch dictionary.

    Args:
        batch: Sequence of items, each providing a 'caption' string and an
            'image' tensor.

    Returns:
        A dict with 'text' (the tokenizer's padded ``input_ids`` for all
        captions) and 'image' (the item image tensors joined by
        ``torch.cat``).
    """
    captions = [item['caption'] for item in batch]
    pixel_values = [item['image'] for item in batch]
    encoded = tokenizer(captions, padding=True, return_tensors='pt')
    return {
        'text': encoded['input_ids'],
        # NOTE(review): torch.cat joins along the existing first axis, so this
        # presumably relies on every item carrying a leading axis of size 1 -
        # confirm against the dataset's __getitem__.
        'image': torch.cat(pixel_values),
    }

In [130]:
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import PerceiverFeatureExtractor, PerceiverTokenizer

class ConceptualCaptionsDataset(Dataset):
    """Torch ``Dataset`` view over a dataset of captioned images.

    Each item is the underlying row with its 'image' entry replaced by the
    feature extractor's ``pixel_values`` for that image.
    """

    def __init__(self, hf_dataset) -> None:
        """Store the wrapped dataset and create the Perceiver preprocessors."""
        super().__init__()
        self.dataset = hf_dataset
        self.feature_extractor = PerceiverFeatureExtractor()
        self.tokenizer = PerceiverTokenizer()

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return row ``index`` with its image converted to pixel values."""
        row = self.dataset[index]
        extracted = self.feature_extractor(row['image'], return_tensors='pt')
        # Overwrite the image in place; every other field is passed through.
        row['image'] = extracted['pixel_values']
        return row

In [134]:
from torch.utils.data import DataLoader

# Smoke test: wrap the filtered split, pull one batch through the loader and
# show the shapes of the collated tensors.
ds = ConceptualCaptionsDataset(cc_val_filterd)
dl = DataLoader(
    ds,
    batch_size=32,
    num_workers=0,
    pin_memory=True,
    collate_fn=collate_fn,
)

batch = next(iter(dl))
print(batch['text'].shape)
print(batch['image'].shape)

torch.Size([32, 174])
torch.Size([32, 3, 224, 224])
