# Preparing the dataset

We build an image–caption dataset stored as TFRecord shards, with images encoded in WebP format.

In [None]:
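# install dependencies (uncomment to run; %pip installs into the active kernel's environment)
# %pip install img2dataset tensorflow tensorflow_io wandb
# later cells also import pandas (plus a parquet engine such as pyarrow), clip_jax, transformers, matplotlib and tqdm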

## Getting URL list

We use the [Conceptual Captions](https://ai.google.com/research/ConceptualCaptions/) dataset.

In [None]:
# Download Conceptual Captions: the validation and training TSV lists
# (caption / image URL pairs), saved locally as GCC-valid.tsv and GCC-train.tsv.
!wget https://storage.googleapis.com/gcc-data/Validation/GCC-1.1.0-Validation.tsv -O GCC-valid.tsv
!wget https://storage.googleapis.com/gcc-data/Train/GCC-training.tsv -O GCC-train.tsv

We convert the input TSV files to parquet, keeping the caption and URL columns.

In [None]:
import pandas as pd

In [None]:
# Convert each TSV split to parquet, keeping both the caption and the url column.
# Column names are given explicitly via names=, so no header row is expected.
tsv_to_split = {"GCC-train.tsv": "train", "GCC-valid.tsv": "valid"}
for tsv_path, split_name in tsv_to_split.items():
    split_df = pd.read_csv(tsv_path, sep="\t", names=["caption", "url"])
    split_df.to_parquet(f"{split_name}.parquet")

The full datasets are large, so we keep only the first 500,000 training rows and the first 10,000 validation rows.

In [None]:
# Overwrite each parquet split with its first `max_items` rows.
# Truncation is idempotent, so re-running this cell is safe.
for path, max_items in [("train.parquet", 500_000), ("valid.parquet", 10_000)]:
    df = pd.read_parquet(path)
    # report what is actually kept: a split may hold fewer rows than the limit
    n_kept = min(max_items, len(df))
    print(f"{path}: keeping {n_kept} / {len(df)}")
    df = df.head(max_items)
    df.to_parquet(path)

## Download images

In [None]:
from pathlib import Path

# Root folder for the img2dataset output. exist_ok=True makes the cell safe to
# re-run; a plain `mkdir` complains when cc3m already exists.
Path("cc3m").mkdir(exist_ok=True)

In [None]:
# parameters for validation set (interpolated into the img2dataset command below)
input_file, output_folder = "valid.parquet", "cc3m/valid"

# input description
input_format = "parquet"
caption_col = "caption"

# image sizing and encoding
image_size = 256
min_image_size = 128
encode_format = "webp"
encode_quality = 100

# parallelism and sharding
processes_count = 80
thread_count = 16
number_sample_per_shard = 1000

In [None]:
# Download the images listed in $input_file (parameters set in the previous
# cell) into $output_folder: center-cropped to $image_size, encoded as
# $encode_format, and written as tfrecord shards of $number_sample_per_shard
# samples. IPython expands each $name from the notebook's Python variables.
!img2dataset \
  --url_list $input_file \
  --image_size $image_size \
  --output_folder $output_folder \
  --input_format $input_format \
  --caption_col $caption_col \
  --processes_count $processes_count \
  --thread_count $thread_count \
  --resize_mode center_crop \
  --encode_quality $encode_quality \
  --encode_format $encode_format \
  --output_format tfrecord \
  --number_sample_per_shard $number_sample_per_shard \
  --extract_exif false \
  --min_image_size $min_image_size \
  --enable_wandb

In [None]:
# switch the source list and destination to the train split; all other
# img2dataset parameters from the validation cell are reused unchanged
input_file, output_folder = "train.parquet", "cc3m/train"

In [None]:
# Same img2dataset invocation as for the validation split; only $input_file
# and $output_folder differ (set in the previous cell), so the training images
# go to cc3m/train with identical preprocessing and sharding.
!img2dataset \
  --url_list $input_file \
  --image_size $image_size \
  --output_folder $output_folder \
  --input_format $input_format \
  --caption_col $caption_col \
  --processes_count $processes_count \
  --thread_count $thread_count \
  --resize_mode center_crop \
  --encode_quality $encode_quality \
  --encode_format $encode_format \
  --output_format tfrecord \
  --number_sample_per_shard $number_sample_per_shard \
  --extract_exif false \
  --min_image_size $min_image_size \
  --enable_wandb

## Dataloader

The files have been saved as TFRecords; we load them with `clip_jax.data.Dataset`.

In [None]:
from clip_jax.data import Dataset, logits_to_image
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
from tqdm.notebook import tqdm

## Optional: Calculate mean and std of the dataset

We compute the statistics on the training images, but load them through the validation pipeline so that images are center-cropped instead of randomly cropped.

In [None]:
# Only the validation pipeline is configured here, and it is deliberately
# pointed at the training images so the statistics use center crops.
dataset = Dataset(
    image_size=224,
    valid_folder="cc3m/train",
    valid_batch_size=1000,
)

We compute the mean and std of each batch in parallel (via `map` with `AUTOTUNE`) for speed, then combine the per-batch values into global statistics.

In [None]:
def get_mean_std(images, captions):
    """Return per-channel (mean, std) for a single batch of images.

    `captions` is unused. It is accepted because each dataset element is an
    (images, captions) pair, which `map` passes as two arguments.

    NOTE(review): the reduction runs over axes 0, 2 and 3, leaving axis 1.
    That presumes a channels-first (batch, channel, height, width) layout;
    TODO confirm against the clip_jax Dataset output.
    """
    pixels = tf.cast(images, tf.float64)  # accumulate in float64
    reduce_axes = (0, 2, 3)
    return (
        tf.reduce_mean(pixels, axis=reduce_axes),
        tf.math.reduce_std(pixels, axis=reduce_axes),
    )

In [None]:
# Per-batch statistics are independent, so let tf.data choose the parallelism.
autotune = tf.data.experimental.AUTOTUNE
ds_mean_std = dataset.valid.map(get_mean_std, num_parallel_calls=autotune)

In [None]:
# Gather every batch's (mean, std) pair; they are pooled in the next cell.
means, stds = [], []
for mean, std in tqdm(ds_mean_std):
    means.append(mean)
    stds.append(std)

In [None]:
# get the global mean and std by pooling the per-batch statistics.
# A plain average of the per-batch variances ignores how much the batch means
# differ from each other (law of total variance) and underestimates the std.
# Instead: Var = E[x^2] - E[x]^2, with E[x^2] = std_b^2 + mean_b^2 per batch.
# NOTE(review): every batch is weighted equally, which is exact only if all
# batches are the same size. The last batch may be smaller; TODO confirm or
# weight by batch size.
batch_means = tf.stack(means, axis=0)
batch_stds = tf.stack(stds, axis=0)
mean = tf.reduce_mean(batch_means, axis=0)
mean_sq = tf.reduce_mean(tf.math.square(batch_stds) + tf.math.square(batch_means), axis=0)
variance = mean_sq - tf.math.square(mean)
# clamp tiny negative values caused by floating-point cancellation before sqrt
std = tf.math.sqrt(tf.maximum(variance, tf.zeros_like(variance)))

mean, std

## Training and validation dataloader

In [None]:
# Training and validation pipelines over the downloaded shards.
# NOTE(review): mean/std are fixed at 0.5 rather than the statistics computed
# above. The pretrained checkpoint loaded later (CLIP_REPO) may have been
# trained with different normalization constants; TODO confirm 0.5 is intended.
dataset = Dataset(
    train_folder="cc3m/train",
    valid_folder="cc3m/valid",  # the validation split was downloaded here
    train_batch_size=10,
    valid_batch_size=10,
    image_size=224,
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
)

In [None]:
# Pull one training batch as numpy arrays.
sample_batch = next(iter(dataset.train.as_numpy_iterator()))
# Display the shape/dtype of each batch component instead of dumping every
# pixel value into the notebook output.
[(np.shape(component), np.asarray(component).dtype) for component in sample_batch]

In [None]:
# First element of the batch: the image arrays, plotted below and later fed to CLIP as pixel_values.
images = sample_batch[0]

In [None]:
# visualize the batch: up to 9 images on a 3x3 grid
fig, axes = plt.subplots(3, 3, figsize=(10, 10))
for ax in axes.flat:
    ax.axis("off")  # also hides empty panels when the batch has < 9 images
# zip stops at the shorter sequence, so small batches no longer raise IndexError
for ax, logits in zip(axes.flat, images):
    img = logits_to_image(
        logits, mean=dataset.mean, std=dataset.std, format=dataset.format
    )
    ax.imshow(img)
plt.show()

Let's check the captions.

In [None]:
from transformers import CLIPTokenizerFast

In [None]:
# Hugging Face checkpoint id, shared by the tokenizer here and the model below
CLIP_REPO = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_model_name_or_path=CLIP_REPO)

In [None]:
# Second element of the batch: the raw captions; preview the first five.
captions = sample_batch[1]
captions[0:5]

In [None]:
# tokenize
# Decode into a new name instead of overwriting `captions`: the raw values are
# bytes, and re-running this cell would otherwise call .decode() on str.
caption_texts = [caption.decode("utf-8") for caption in captions]
txt_inputs = tokenizer(
    caption_texts, padding="max_length", truncation=True, return_tensors="np"
)
txt_inputs

## Run CLIP on data

In [None]:
from transformers import FlaxCLIPModel

In [None]:
model = FlaxCLIPModel.from_pretrained(pretrained_model_name_or_path=CLIP_REPO)  # same checkpoint as the tokenizer

In [None]:
# Model kwargs: the image batch, then every field produced by the tokenizer
# (tokenizer fields win on a key clash, as with dict unpacking).
inputs = {"pixel_values": images}
inputs.update(txt_inputs)

In [None]:
# The first field of the CLIP output is logits_per_image (image-to-text similarity scores).
outputs = model(**inputs).logits_per_image