# vqgan-jax-encoding-yfcc100m

Encoding notebook for YFCC100M.

This dataset was prepared by @borisdayma in JSON Lines format. We'll load it in streaming mode.

In [1]:
import io

import requests
from PIL import Image
import numpy as np
from tqdm import tqdm

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
import os

import jax
from jax import pmap

## Dataset and Parameters

This is a local test on a small subset stored on disk. It can also be loaded in streaming mode, which makes it useful for validating how everything works.

In [2]:
from pathlib import Path

In [3]:
# Root of the YFCC100M OpenAI subset. Set the YFCC100M_DIR environment variable to
# point elsewhere instead of editing this machine-specific default path.
yfcc100m = Path(os.environ.get('YFCC100M_DIR', '/home/pedro/data/YFCC100M_OpenAI_subset'))
# Images are 'sharded' into key-prefix subdirectories below this directory
yfcc100m_images = yfcc100m/'data'/'images'
yfcc100m_metadata = yfcc100m/'metadata_existing.jsonl'
yfcc100m_output = yfcc100m/'metadata_encoded.tsv'

In [4]:
batch_size = 128     # Per device; one global step processes batch_size * jax.device_count() items

In [5]:
import datasets
from datasets import Dataset, load_dataset

We load the dataset in streaming mode. This allows us to process the whole OpenAI subset of YFCC100M without having to download a local copy to the TPU host.

Datasets loaded in streaming mode are iterables and have no `len`.

In [6]:
# Stream the JSON Lines metadata and keep only its "train" split.
dataset = load_dataset('json', data_files=str(yfcc100m_metadata), streaming=True)["train"]

Using custom data configuration default-d6917ade73efa578


In [7]:
# Peek at the first raw record to inspect the available metadata fields.
next(iter(dataset))

{'accuracy': 12.0,
 'capturedevice': '',
 'datetaken': '2004-09-01 15:21:46.0',
 'dateuploaded': '1094077306',
 'description': 'The+door+we+climbed+through+to+the+cafeteria',
 'description_clean': 'The door we climbed through to the cafeteria',
 'downloadurl': 'http://farm1.staticflickr.com/1/317823_1b42b71779.jpg',
 'ext': 'jpg',
 'farmid': 1,
 'key': '10752f5dcc9b9ca5309542708b1bacf',
 'latitude': 51.897893,
 'licensename': 'Attribution-NonCommercial-NoDerivs License',
 'licenseurl': 'http://creativecommons.org/licenses/by-nc-nd/2.0/',
 'longitude': -8.506336,
 'machinetags': '',
 'marker': 0,
 'pageurl': 'http://www.flickr.com/photos/51035594319@N01/317823/',
 'photoid': 317823,
 'secret': '1b42b71779',
 'secretoriginal': '1b42b71779',
 'serverid': 1,
 'title': 'st+annes-18',
 'title_clean': 'st annes-18',
 'uid': '51035594319@N01',
 'unickname': 'twelves',
 'usertags': 'abandoned,asylum,cork,door,ireland,urban+decay,urban+exploration'}

### Data preparation

* Images

We retrieve them based on their `key`, then we transform them so they are center-cropped and square, all of the same size so we can build batches for TPU/GPU processing.

* Captions: we extract a single `caption` column from the source data, by concatenating the cleaned title and description.

These transformations are done using the Datasets `map` function. In the case of streaming datasets, transformations will run as needed instead of pre-processing the dataset at once.

The following function retrieves an image based on its `key`. In my tests I'll read it from the filesystem. In the final dataset, it will be downloaded remotely from the appropriate zip file.

In [8]:
def get_image_data(path, key, ext):
    """Load the image stored for `key` in a sharded directory tree, as RGB.

    Files live at ``path/<key[0:3]>/<key[3:6]>/<key>.<ext>``.

    :param path: Root directory (a ``pathlib.Path``) of the sharded image tree.
    :param key: Image key; its first six characters select the shard subdirectories.
    :param ext: File extension without the leading dot (e.g. ``'jpg'``).
    :return: A loaded ``PIL.Image.Image`` in RGB mode.
    """
    image_path = (path/key[0:3]/key[3:6]/key).with_suffix("." + ext)
    # Close the file handle deterministically: PIL opens files lazily and may keep the
    # handle alive (e.g. multi-frame formats). `convert` returns a new, fully loaded
    # image, so the result stays valid after the file is closed.
    with Image.open(image_path) as image:
        return image.convert('RGB')

This function does the center-cropping.

In [9]:
def center_crop(image, max_size=256):
    """Scale `image` so its shorter side equals `max_size`, then center-crop a square.

    Upscaling is allowed, so images smaller than `max_size` get enlarged (such images
    should ideally be excluded upstream).

    :param image: ``PIL.Image.Image`` whose ``size`` is ``(width, height)``.
    :param max_size: Side length, in pixels, of the square output.
    :return: ``numpy`` array of shape ``(1, max_size, max_size, channels)``; the leading
        singleton axis is kept for compatibility with downstream code.
    """
    width, height = image.size
    scale = max_size / min(width, height)
    # torchvision expects the target size as (height, width).
    target_hw = (round(scale * height), round(scale * width))
    resized = TF.resize(image, target_hw, interpolation=InterpolationMode.LANCZOS)
    cropped = TF.center_crop(resized, output_size=[max_size, max_size])
    # Copy the pixels out of PIL and prepend an axis of length 1.
    return np.array(cropped, copy=True)[np.newaxis]

And this is the basic transformation function to use in `map`.

In [10]:
def prepare_data(examples):
    """Batched `map` function: load, crop and caption each example one after another.

    :param examples: Batch dict with list-valued ``'key'``, ``'ext'``, ``'title_clean'``
        and ``'description_clean'`` columns.
    :return: Dict of parallel lists under ``'key'``, ``'caption'`` and ``'image'``.
    """
    batch_keys, batch_captions, batch_images = [], [], []
    columns = (examples['key'], examples['ext'], examples['title_clean'], examples['description_clean'])
    for key, ext, title, description in zip(*columns):
        # Sequential I/O: each image is read from `yfcc100m_images` in turn.
        cropped = center_crop(get_image_data(yfcc100m_images, key, ext))
        batch_keys.append(key)
        batch_captions.append(f"{title}. {description}")
        batch_images.append(cropped)
    return {'key': batch_keys, 'caption': batch_captions, 'image': batch_images}

Unlike with non-streaming datasets, the following operation completes immediately in streaming mode; samples will be processed as needed. We set the `map` `batch_size` to the global processing batch size (`batch_size * jax.device_count()`), so each call reads exactly as many items as one DataLoader step consumes. A different size would also work; matching them means we perform one retrieval per batch.

In [11]:
# Lazily attach the sequential preparation step, one call per global batch.
prepared_dataset = dataset.map(prepare_data, batch_size=batch_size * jax.device_count(), batched=True)

In [12]:
%%time
# Time how long the sequential `prepare_data` path takes to produce the first batch.
_ = next(iter(prepared_dataset))

CPU times: user 8.76 s, sys: 0 ns, total: 8.76 s
Wall time: 8.55 s


We have a problem here. Our `prepare_data` function receives a batch, but it loads and transforms images sequentially. We'll now try to retrieve images in parallel, using a simple `Pool`.

In [13]:
from multiprocessing import Pool

In [14]:
def get_image(key, ext='jpg'):
    """Load and center-crop the image for `key` from `yfcc100m_images`.

    :param key: Image key (selects the sharded file path).
    :param ext: File extension without the dot. Defaults to ``'jpg'`` for backward
        compatibility; pass each record's ``ext`` so non-JPEG files are found.
    :return: Cropped image array as produced by ``center_crop``.
    """
    image = get_image_data(yfcc100m_images, key, ext)
    image = center_crop(image)
    return image

# Create a single pool of 16 worker processes that will be reused across batches.
# NOTE(review): the pool is not closed anywhere visible; call pool.close() when done.
pool = Pool(16)

def parallel_prepare_data(examples):
    """Batched `map` function that loads a batch's images in parallel via `pool`.

    Like ``prepare_data``, it uses each record's ``ext`` column rather than assuming
    every file is a JPEG (falls back to ``'jpg'`` if the column is missing).

    :param examples: Batch dict with list-valued ``'key'``, ``'title_clean'`` and
        ``'description_clean'`` columns, and normally ``'ext'``.
    :return: Dict of parallel lists under ``'key'``, ``'caption'`` and ``'image'``.
    """
    keys = examples['key']
    exts = examples['ext'] if 'ext' in examples else ['jpg'] * len(keys)
    # `starmap` returns results in input order, keeping images aligned with keys.
    images = pool.starmap(get_image, zip(keys, exts))
    captions = [f"{title}. {description}" for (title, description) in zip(examples['title_clean'], examples['description_clean'])]

    result = {'key': keys, 'caption': captions, 'image': images}
    return result

In [15]:
# Same lazy pipeline, but images for each global batch are loaded by the process pool.
prepared_dataset = dataset.map(parallel_prepare_data, batch_size=batch_size * jax.device_count(), batched=True)

In [16]:
%%time
# Time the first batch again, now using the pool-based `parallel_prepare_data` path.
_ = next(iter(prepared_dataset))

CPU times: user 404 ms, sys: 704 ms, total: 1.11 s
Wall time: 1.03 s


We'll use this method for encoding.

### Torch DataLoader

We'll create a PyTorch DataLoader for convenience. This allows us to easily take batches of our desired size.

We won't be using parallel processing of the DataLoader for now, as the items will be retrieved on the fly. We could attempt to do it using these recommendations: https://pytorch.org/docs/stable/data.html#multi-process-data-loading. For now, we'll just leverage our parallel image loading method and retrieve batches sequentially.

In [17]:
import torch
from torch.utils.data import DataLoader

In [18]:
# Ask the streaming dataset to yield torch-formatted columns so DataLoader can batch them.
torch_dataset = prepared_dataset.with_format("torch")

In [19]:
# Check that formatting kept the dataset iterable (streaming) rather than materializing it.
type(torch_dataset)

datasets.iterable_dataset.iterable_dataset.<locals>.TorchIterableDataset

In [20]:
# One loader step yields one global batch: the per-device batch size on every device.
torch_loader = DataLoader(dataset=torch_dataset, batch_size=batch_size * jax.device_count())

In [21]:
# Pull a single global batch to inspect its structure.
batch = next(iter(torch_loader))

In [22]:
# (global_batch, 1, 256, 256, 3): `center_crop` keeps a leading singleton axis per image.
batch['image'].shape

torch.Size([1024, 1, 256, 256, 3])

## VQGAN-JAX model

In [23]:
from vqgan_jax.modeling_flax_vqgan import VQModel

We'll use a VQGAN trained with Taming Transformers and converted to a JAX model.

In [24]:
# Load the pretrained Flax VQGAN checkpoint (presumably fetched from the Hugging Face Hub).
model = VQModel.from_pretrained("flax-community/vqgan_f16_16384")

Working with z of shape (1, 256, 16, 16) = 65536 dimensions.


## Encoding

In [25]:
from flax.training.common_utils import shard
from functools import partial

In [26]:
@partial(pmap, axis_name="batch")
def encode(batch):
    """Encode image shards into VQGAN codebook indices, one shard per device.

    ``pmap`` maps over the leading axis of `batch` (as produced by ``shard``).

    NOTE(review): `model`'s parameters are captured from the enclosing scope instead of
    being passed in replicated, so they are likely compiled in as constants; passing
    replicated params explicitly may be preferable — confirm against the VQModel API.
    """
    _quantized, code_indices = model.encode(batch)
    return code_indices

### Putting it all together in an encoding function

In [27]:
import os
import pandas as pd

def _pad_to_multiple(images, multiple):
    """Pad `images` along axis 0, repeating its last item, up to a multiple of `multiple`."""
    remainder = len(images) % multiple
    if remainder == 0:
        return images
    padding = np.repeat(images[-1:], multiple - remainder, axis=0)
    return np.concatenate([images, padding], axis=0)


def _encoding_to_string(row):
    """Serialize a 1-D array of codebook indices as ``"[i0,i1,...]"``.

    Same text that ``np.array2string(row, separator=',', formatter={'int': str})``
    produces for short rows, but it never summarizes long rows with ``...`` (which
    array2string does above its ``threshold``) and never wraps lines.
    """
    return "[" + ",".join(map(str, row.tolist())) + "]"


def encode_captioned_dataset(dataset, prepare_function, output_tsv, batch_size=32):
    """
    Encode every image of a streaming dataset with the VQGAN and write the results as TSV.

    Each output row is ``key<TAB>caption<TAB>[i0,i1,...]`` (no header, no index column).

    :param dataset: Streaming Dataset with source data.
    :param prepare_function: Data preparation function, to be applied to the dataset using `map`.
        It must produce ``'key'``, ``'caption'`` and ``'image'`` columns.
    :param output_tsv: Destination file. Must not exist.
    :param batch_size: Per-device batch size.
    """
    if os.path.isfile(output_tsv):
        print(f"Destination file {output_tsv} already exists, please move away.")
        return

    # `shard` and `pmap` split work across the *local* devices, so size batches by them.
    num_devices = jax.local_device_count()
    global_batch_size = batch_size * num_devices

    prepared_dataset = dataset.map(prepare_function, batched=True, batch_size=global_batch_size)
    torch_dataset = prepared_dataset.with_format("torch")
    dataloader = DataLoader(torch_dataset, batch_size=global_batch_size)

    # We save each batch to avoid reallocation of buffers as we process them.
    # We keep the file open to prevent excessive file seeks.
    # TODO: save to .jsonl, create a new file every half million images or so
    with open(output_tsv, "w") as file:
        for batch in tqdm(dataloader):
            images = batch["image"].numpy()
            # Drop the per-image singleton axis: (B, 1, H, W, C) -> (B, H, W, C).
            # Unlike `squeeze()`, this never removes the batch axis when B == 1.
            images = images.reshape((-1,) + images.shape[-3:])
            num_images = images.shape[0]

            # The final batch may not split evenly across devices: pad it rather than
            # dropping it, then discard the encodings of the padding rows.
            codes = np.asarray(encode(shard(_pad_to_multiple(images, num_devices))))
            codes = codes.reshape(-1, codes.shape[-1])[:num_images]

            batch_df = pd.DataFrame({
                "key": batch["key"],
                "caption": batch["caption"],
                "encoding": [_encoding_to_string(row) for row in codes],
            })
            batch_df.to_csv(file, sep='\t', header=False, index=False)

In [29]:
# Encode the full stream with the pool-based loader; returns early if the output file already exists.
encode_captioned_dataset(dataset, parallel_prepare_data, yfcc100m_output, batch_size=batch_size)

57it [01:50,  1.93s/it]


This is not as efficient as a regular dataset loaded with multiple workers, but it's not bad. Non-streaming parallel processing took ~1:30 in a similar test.

----