# Computation of caption text embeddings

To do
- Iron out bugs in the batch downloader
- Use [dask delayed](https://docs.dask.org/en/stable/delayed.html). Configure scheduling so that GPU computations wait until a few hundred captions have accumulated. Then run the computation, extract the individual embeddings, truncate them, and write them out.

In [1]:
import json
import torch
import braceexpand
from tqdm import tqdm
from imagen_pytorch.t5 import t5_encode_text
import webdataset as wds

# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:
def _trim_padding(embedding):
    """Cut a 2-D embedding off after its last row that holds a non-zero value.

    NOTE(review): this relies on padded token positions being all-zero rows in
    the output of ``t5_encode_text(..., return_attn_mask=False)`` -- confirm
    against the installed imagen_pytorch version.

    Args:
        embedding: 2-D tensor, one row per token position.

    Returns:
        A tensor holding rows ``0 .. last_nonzero_row`` (inclusive). If the
        input has no non-zero entry, a tensor with zero rows is returned.
    """
    rows, _ = embedding.nonzero(as_tuple=True)
    if rows.numel() == 0:
        # Nothing to keep; also avoids calling max() on an empty tensor.
        return embedding[:0, :].clone()
    # Slice ends are exclusive, so +1 is needed to keep the last non-zero row.
    last_row = int(rows.max())
    # clone() gives the slice its own storage; serialising a view may
    # otherwise drag the whole batch tensor's storage along with it.
    return embedding[: last_row + 1, :].clone()


def _write_embedded_batch(dst, keys, imgs, caps):
    """Encode ``caps`` in one T5 call and write one sample per caption to ``dst``."""
    batch_embeds = t5_encode_text(caps, name="google/t5-v1_1-xl", return_attn_mask=False)
    batch_embeds = batch_embeds.cpu()  # consider removing
    for key_, img_, cap_, emb_ in zip(keys, imgs, caps, batch_embeds):
        dst.write({"__key__": key_, "png": img_, "txt": cap_, "emb.pyd": _trim_padding(emb_)})


def batch_augment_wds(input, output, batch_size=256):
    """Copy a WebDataset shard, adding a T5 caption embedding to every sample.

    Captions are accumulated and encoded ``batch_size`` at a time; each
    embedding is trimmed of trailing all-zero rows before being written under
    the ``emb.pyd`` key. Samples left over after the last full batch are
    encoded and written as a final, smaller batch.

    Args:
        input: Path/URL of the source ``.tar`` shard. A sibling file named
            ``<input without its last 4 characters>_stats.json`` must exist
            and contain a ``"successes"`` entry (used only for the progress
            display).
        output: Path/URL of the shard to write.
        batch_size: Number of captions encoded per ``t5_encode_text`` call.

    Raises:
        OSError: If the stats file cannot be opened.
        KeyError: If the stats file has no ``"successes"`` entry.
    """
    stats_file = input[:-4] + "_stats.json"
    with open(stats_file) as f:
        count = json.load(f)["successes"]

    src = wds.DataPipeline(
        wds.SimpleShardList(input),
        wds.tarfile_to_samples(),
        wds.decode("pil"),
        wds.to_tuple("__key__", "jpg;png", "txt"),
        wds.map_tuple(None, None, None)
    )

    keys, imgs, caps = [], [], []
    idx = 0
    # The context manager guarantees the archive is finalised and closed,
    # even if encoding fails part-way through the shard.
    with wds.TarWriter(output) as dst:
        for idx, (key, img, cap) in enumerate(src):
            keys.append(key)
            imgs.append(img)
            caps.append(cap)
            if (idx + 1) % batch_size == 0:
                print(f"\r Step {idx}/{count}", end='')
                _write_embedded_batch(dst, keys, imgs, caps)
                keys, imgs, caps = [], [], []

        # Flush the remainder: the shard size is generally not a multiple of
        # batch_size, and these samples would otherwise be dropped.
        if caps:
            print(f"\r Step {idx}/{count}", end='')
            _write_embedded_batch(dst, keys, imgs, caps)

In [3]:
def get_emb_tensor(text):
    """Encode a single caption with T5 (google/t5-v1_1-xl) and return it on the CPU.

    The caption is wrapped in a one-element list, so the encoder is called
    with a batch of one.
    """
    embedding = t5_encode_text(
        [text],
        name="google/t5-v1_1-xl",
        return_attn_mask=False,
    )
    return embedding.cpu()


def augment_wds(input, output):
    """Copy a WebDataset shard, adding a T5 caption embedding to every sample.

    Each sample's caption is read twice from the ``txt`` field: once passed
    through unchanged and once mapped through ``get_emb_tensor``, so captions
    are encoded one at a time. The result is written under the ``emb.pyd`` key
    next to the image (``png``) and caption (``txt``).

    Args:
        input: Path/URL of the source ``.tar`` shard. A sibling file named
            ``<input without its last 4 characters>_stats.json`` must exist
            and contain a ``"successes"`` entry (used as the progress-bar
            total).
        output: Path/URL of the shard to write.

    Raises:
        OSError: If the stats file cannot be opened.
        KeyError: If the stats file has no ``"successes"`` entry.
    """
    stats_file = input[:-4] + "_stats.json"
    # Context manager so the handle is released even if the JSON is malformed.
    with open(stats_file) as f:
        stats = json.load(f)
    count = stats["successes"]

    src = wds.DataPipeline(
        wds.SimpleShardList(input),
        wds.tarfile_to_samples(),
        wds.decode("pil"),
        # "txt" is listed twice: the second copy is turned into the embedding.
        wds.to_tuple("__key__", "jpg;png", "txt", "txt"),
        wds.map_tuple(None, None, None, get_emb_tensor)
    )

    with wds.TarWriter(output) as dst:
        for key, img, cap, emb in tqdm(src, total=count, desc=f"Writing {output}"):
            dst.write({"__key__": key, "png": img, "txt": cap, "emb.pyd": emb})

In [None]:
# Embed captions for shards 00234..01242, pairing each source shard with the
# destination shard of the same number.
input_shards = braceexpand.braceexpand("cc12m/{00234..01242}.tar")
output_shards = braceexpand.braceexpand("file:E:/datasets/cc12m/{00234..01242}.tar")
shard_pairs = zip(input_shards, output_shards)
for input_shard, output_shard in shard_pairs:
    augment_wds(input_shard, output_shard)

Writing file:E:/datasets/cc12m/00234.tar: 100%|██████████| 8561/8561 [06:52<00:00, 20.76it/s]
Writing file:E:/datasets/cc12m/00235.tar: 100%|██████████| 8504/8504 [06:17<00:00, 22.56it/s]
Writing file:E:/datasets/cc12m/00236.tar: 100%|██████████| 8464/8464 [06:13<00:00, 22.67it/s]
Writing file:E:/datasets/cc12m/00237.tar: 100%|██████████| 8500/8500 [06:13<00:00, 22.77it/s]
Writing file:E:/datasets/cc12m/00238.tar: 100%|██████████| 8453/8453 [06:07<00:00, 23.01it/s]
Writing file:E:/datasets/cc12m/00239.tar: 100%|██████████| 8497/8497 [06:02<00:00, 23.41it/s]
Writing file:E:/datasets/cc12m/00240.tar: 100%|██████████| 8461/8461 [06:03<00:00, 23.27it/s]
Writing file:E:/datasets/cc12m/00241.tar: 100%|██████████| 8504/8504 [06:09<00:00, 23.01it/s]
Writing file:E:/datasets/cc12m/00242.tar: 100%|██████████| 8520/8520 [06:08<00:00, 23.11it/s]
Writing file:E:/datasets/cc12m/00243.tar: 100%|██████████| 8450/8450 [06:08<00:00, 22.93it/s]
Writing file:E:/datasets/cc12m/00244.tar: 100%|██████████| 8

In [None]:
# input_shards = braceexpand.braceexpand("cc12m/{00154..00200}.tar")
# output_shards = braceexpand.braceexpand("cc12m_aug/{00154..00200}.tar")
# for input_shard, output_shard in zip(input_shards, output_shards):
#     batch_augment_wds(input=input_shard, output=output_shard, batch_size=4)

In [None]:
# input_shards = braceexpand.braceexpand("cc12m/{00100..00102}.tar")
# output_shards = braceexpand.braceexpand("file:E:/datasets/cc12m/{00100..00102}.tar")
# results = []
# for input_shard, output_shard in zip(input_shards, output_shards):
#     results.append(dask.delayed(augment_wds)(input_shard, output_shard))
# dask.compute(*results)