# Computation of caption text embeddings

In [1]:
import json
import torch
import random
import braceexpand
from time import time
from tqdm import tqdm
import webdataset as wds
from imagen_pytorch.t5 import t5_encode_text
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
def get_emb_batch(text):
    """Encode a batch of captions and trim each embedding after its last nonzero row.

    Args:
        text: Batch of captions, forwarded unchanged to ``t5_encode_text``.

    Returns:
        A list with one CPU tensor per item of the encoder output. Each tensor
        has a leading axis of size 1 followed by the rows of that item up to
        and including the last row containing a nonzero value. An item that is
        entirely zero yields a tensor with zero rows.

    NOTE(review): trimming at the last nonzero row presumably removes padding
    positions that the encoder fills with zeros - confirm against
    ``imagen_pytorch.t5.t5_encode_text``.
    """
    text_embeds = t5_encode_text(text, name="google/t5-v1_1-xl", return_attn_mask=False)
    text_embeds = text_embeds.cpu()
    emb_batch = []
    for tensor in text_embeds:
        # Row index of every nonzero element; only the largest one matters.
        row_idx = tensor.nonzero(as_tuple=True)[0]
        # Use the tensor reduction instead of the builtin max(): the builtin
        # iterates element by element in Python and raises on an empty input.
        length = int(row_idx.max()) + 1 if row_idx.numel() > 0 else 0
        # None adds the leading axis of size 1; the slice drops trailing rows.
        emb_batch.append(tensor[None, :length])
    return emb_batch


def batch_augment_wds(input_shard, output_shard, batch_size):
    """Copy a webdataset shard, adding a text embedding to every sample.

    Samples are read from ``input_shard``, their captions are embedded in
    groups of ``batch_size`` with ``get_emb_batch``, and every sample is then
    written to ``output_shard`` with the keys ``png``, ``txt`` and ``emb.pyd``.
    All samples of the shard are held in memory between the two phases.

    Args:
        input_shard: Path of the source ``.tar`` shard without a URL scheme;
            a ``file:`` prefix is added before it is handed to webdataset.
            ``get_count`` is called on the unprefixed path to obtain the
            progress-bar total.
        output_shard: Destination passed unchanged to ``wds.TarWriter``.
        batch_size: Number of captions embedded per ``get_emb_batch`` call.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    start = time()
    count = get_count(input_shard)
    input_shard = "file:"+input_shard

    src = wds.DataPipeline(
        wds.SimpleShardList(input_shard),
        wds.tarfile_to_samples(),
        wds.decode("pil"),
        wds.to_tuple("__key__", "jpg;png", "txt")
    )

    keys = []
    imgs = []
    caps = []
    embs = []
    for key, img, cap in tqdm(src, total=count, desc=f"Extracting {input_shard}"):
        keys.append(key)
        imgs.append(img)
        caps.append(cap)

        # len(embs) is the number of captions already embedded, so the slice
        # below always starts at the first pending caption. This keeps embs
        # aligned with caps for every batch size, including 1.
        if len(caps) - len(embs) >= batch_size:
            embs.extend(get_emb_batch(caps[len(embs):]))

    # Flush the trailing partial batch from what was actually read, rather
    # than trusting the count in the stats file to match the shard contents.
    if len(embs) < len(caps):
        embs.extend(get_emb_batch(caps[len(embs):]))

    # Close the writer even if a write fails, so the archive is finalized and
    # the file handle is released.
    with wds.TarWriter(output_shard) as dst:
        for key, img, cap, emb in tqdm(zip(keys, imgs, caps, embs), total=len(keys), desc=f"Writing {output_shard}"):
            dst.write({
                "__key__":key,
                "png":img,
                "txt":cap,
                "emb.pyd":emb
            })

    end = time()
    print(f"Finished - {end-start:.0f}s")

In [2]:
def get_emb_tensor(text):
    """Encode a single caption with ``t5_encode_text`` and return the result on the CPU.

    The caption is wrapped in a one-element list before encoding, so the
    encoder receives a batch of size one.
    """
    single_caption_batch = [text]
    encoded = t5_encode_text(
        single_caption_batch,
        name="google/t5-v1_1-xl",
        return_attn_mask=False,
    )
    on_cpu = encoded.cpu()
    return on_cpu


def get_count(input_file):
    """Return the ``"successes"`` value from the stats file that accompanies a shard.

    The stats path is built by dropping the last four characters of
    ``input_file`` (a ``.tar`` suffix for the shards used in this notebook)
    and appending ``_stats.json``.

    Args:
        input_file: Path of the shard whose stats file should be read.

    Returns:
        The value stored under the ``"successes"`` key of the stats JSON.

    Raises:
        OSError: If the stats file cannot be opened.
        json.JSONDecodeError: If the stats file is not valid JSON.
        KeyError: If the JSON has no ``"successes"`` entry.
    """
    stats_file = input_file[:-4] + "_stats.json"
    # The context manager closes the file even when json.load raises.
    # JSON is decoded explicitly as UTF-8 instead of the platform default.
    with open(stats_file, encoding="utf-8") as f:
        stats = json.load(f)
    return stats["successes"]


def shuffle_augment_wds(input, output):
    """Add a text embedding to every sample of a shard and write the samples shuffled.

    Takes ~300s for each .tar file.

    Every sample is read from ``input``, its caption is embedded one at a time
    with ``get_emb_tensor``, the full list of samples is shuffled in memory
    with ``random.shuffle``, and the result is written to ``output`` with the
    keys ``png``, ``txt`` and ``emb.pyd``.

    Args:
        input: Path of the source ``.tar`` shard without a URL scheme; a
            ``file:`` prefix is added before it is handed to webdataset.
            ``get_count`` is called on the unprefixed path to obtain the
            progress-bar total.
        output: Destination passed unchanged to ``wds.TarWriter``.
    """
    start = time()
    count = get_count(input)
    input = "file:"+input
    src = wds.DataPipeline(
        wds.SimpleShardList(input),
        wds.tarfile_to_samples(),
        wds.decode("pil"),
        # The caption is selected twice: once to keep as text, and once to be
        # turned into an embedding by the map_tuple stage below.
        wds.to_tuple("__key__", "jpg;png", "txt", "txt"),
        wds.map_tuple(None, None, None, get_emb_tensor)
    )

    samples = []
    for key, img, cap, emb in tqdm(src, total=count, desc=f"Extracting {input}"):
        samples.append([key, img, cap, emb])
    random.shuffle(samples)

    # Close the writer even if a write fails, so the archive is finalized and
    # the file handle is released.
    with wds.TarWriter(output) as dst:
        for key, img, cap, emb in tqdm(samples, total=len(samples), desc=f"Writing {output}"):
            dst.write({
                "__key__":key,
                "png":img,
                "txt":cap,
                "emb.pyd":emb
            })
    end = time()
    print(f"Finished - {end-start:.0f}s")

In [4]:
# input_shards = braceexpand.braceexpand("cc12m_original/{01095..01242}.tar")
# output_shards = braceexpand.braceexpand("file:E:/datasets/cc12m_w_embeds/{01095..01242}.tar")
# for input_shard, output_shard in zip(input_shards, output_shards):
#     batch_augment_wds(input_shard, output_shard, batch_size=8)

In [5]:
# Process shards 01229 through 01242: each input shard is paired with the
# output shard of the same number and augmented with shuffled embeddings.
input_shards = braceexpand.braceexpand("cc12m_original/{01229..01242}.tar")
output_shards = braceexpand.braceexpand("file:E:/datasets/cc12m_w_embeds/{01229..01242}.tar")
shard_pairs = zip(input_shards, output_shards)
for input_shard, output_shard in shard_pairs:
    shuffle_augment_wds(input_shard, output_shard)

Extracting file:cc12m_original/01229.tar: 100%|██████████| 8461/8461 [03:46<00:00, 37.37it/s]
Writing file:E:/datasets/cc12m_w_embeds/01229.tar: 100%|██████████| 8461/8461 [01:20<00:00, 105.54it/s]


Finished - 307s


Extracting file:cc12m_original/01230.tar: 100%|██████████| 8536/8536 [03:21<00:00, 42.32it/s]
Writing file:E:/datasets/cc12m_w_embeds/01230.tar: 100%|██████████| 8536/8536 [01:24<00:00, 100.73it/s]


Finished - 286s


Extracting file:cc12m_original/01231.tar: 100%|██████████| 8542/8542 [03:21<00:00, 42.38it/s]
Writing file:E:/datasets/cc12m_w_embeds/01231.tar: 100%|██████████| 8542/8542 [01:27<00:00, 97.23it/s] 


Finished - 289s


Extracting file:cc12m_original/01232.tar: 100%|██████████| 8437/8437 [03:19<00:00, 42.25it/s]
Writing file:E:/datasets/cc12m_w_embeds/01232.tar: 100%|██████████| 8437/8437 [01:25<00:00, 98.62it/s] 


Finished - 285s


Extracting file:cc12m_original/01233.tar: 100%|██████████| 8476/8476 [03:19<00:00, 42.45it/s]
Writing file:E:/datasets/cc12m_w_embeds/01233.tar: 100%|██████████| 8476/8476 [01:27<00:00, 97.34it/s] 


Finished - 287s


Extracting file:cc12m_original/01234.tar: 100%|██████████| 8471/8471 [03:20<00:00, 42.20it/s]
Writing file:E:/datasets/cc12m_w_embeds/01234.tar: 100%|██████████| 8471/8471 [01:26<00:00, 98.10it/s] 


Finished - 287s


Extracting file:cc12m_original/01235.tar: 100%|██████████| 8607/8607 [03:23<00:00, 42.33it/s]
Writing file:E:/datasets/cc12m_w_embeds/01235.tar: 100%|██████████| 8607/8607 [01:28<00:00, 97.11it/s] 


Finished - 292s


Extracting file:cc12m_original/01236.tar: 100%|██████████| 8495/8495 [03:19<00:00, 42.52it/s]
Writing file:E:/datasets/cc12m_w_embeds/01236.tar: 100%|██████████| 8495/8495 [01:27<00:00, 97.24it/s] 


Finished - 287s


Extracting file:cc12m_original/01237.tar: 100%|██████████| 8489/8489 [03:20<00:00, 42.27it/s]
Writing file:E:/datasets/cc12m_w_embeds/01237.tar: 100%|██████████| 8489/8489 [01:27<00:00, 96.96it/s] 


Finished - 288s


Extracting file:cc12m_original/01238.tar: 100%|██████████| 8543/8543 [03:21<00:00, 42.48it/s]
Writing file:E:/datasets/cc12m_w_embeds/01238.tar: 100%|██████████| 8543/8543 [01:28<00:00, 96.05it/s] 


Finished - 290s


Extracting file:cc12m_original/01239.tar: 100%|██████████| 8541/8541 [03:22<00:00, 42.25it/s]
Writing file:E:/datasets/cc12m_w_embeds/01239.tar: 100%|██████████| 8541/8541 [01:29<00:00, 95.47it/s] 


Finished - 292s


Extracting file:cc12m_original/01240.tar: 100%|██████████| 8479/8479 [03:19<00:00, 42.54it/s]
Writing file:E:/datasets/cc12m_w_embeds/01240.tar: 100%|██████████| 8479/8479 [01:28<00:00, 95.66it/s] 


Finished - 288s


Extracting file:cc12m_original/01241.tar: 100%|██████████| 8485/8485 [03:20<00:00, 42.24it/s]
Writing file:E:/datasets/cc12m_w_embeds/01241.tar: 100%|██████████| 8485/8485 [01:28<00:00, 96.22it/s] 


Finished - 289s


Extracting file:cc12m_original/01242.tar: 100%|██████████| 2869/2869 [01:07<00:00, 42.58it/s]
Writing file:E:/datasets/cc12m_w_embeds/01242.tar: 100%|██████████| 2869/2869 [00:27<00:00, 105.28it/s]


Finished - 95s
