# Dataloader experiments

In [None]:
import yaml
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from time import time

import torch
from torchvision import transforms as T, utils
from torch.utils.data import DataLoader
from imagen_pytorch.t5 import t5_encode_text
import webdataset as wds

%config InlineBackend.figure_format = 'retina'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Load the training config. The path is built from pathlib components instead of a
# Windows-only "\\" separator, so the notebook also works on Linux/macOS.
cfg = yaml.safe_load(
    (Path("configs") / "imagen-medium-config.yaml").read_text(encoding="utf-8")
)

# Per-sample image preprocessing:
# - resize to image_size (shorter side, assuming image_size is an int — confirm in config);
# - random horizontal flip as augmentation;
# - center-crop to image_size;
# - convert the PIL image to a float tensor.
preproc = T.Compose([
    T.Resize(cfg["dataset"]["image_size"]),
    T.RandomHorizontalFlip(),
    T.CenterCrop(cfg["dataset"]["image_size"]),
    T.ToTensor(),
])

In [None]:
def get_emb_tensor(cfg, targets, device, *, name="google/t5-v1_1-xl"):
    """Encode a batch of captions with a T5 encoder and move the result to ``device``.

    :param cfg: loaded config dict; currently unused, kept so existing call sites work.
    :param targets: captions to encode, passed unchanged to ``t5_encode_text``.
    :param device: torch device the embeddings are moved to.
    :param name: T5 checkpoint name forwarded to ``t5_encode_text``; defaults to the
        previously hard-coded ``google/t5-v1_1-xl``.
    :return: the embedding tensor returned by ``t5_encode_text`` (no attention mask),
        moved to ``device``.
    """
    text_embeds = t5_encode_text(targets, name=name, return_attn_mask=False)
    return text_embeds.to(device)


def padding_tensor(sequences, emb_dim=2048):
    """Zero-pad a list of per-sample embedding tensors into a single batch tensor.

    Adapted from
    https://discuss.pytorch.org/t/how-to-do-padding-based-on-lengths/24442/2

    Each element is indexed as ``tensor[:, :, 0]`` and transposed, so it must be a 3-D
    tensor shaped ``(emb_dim, length, k)``; only the first slice of the last axis is used.
    NOTE(review): confirm this matches the layout of the stored ``emb.pyd`` embeddings.

    :param sequences: non-empty list of tensors shaped ``(emb_dim, length_i, k)``.
    :param emb_dim: size of each input's first (feature) axis; defaults to 2048, the
        value that used to be hard-coded.
    :return: tensor of shape ``(len(sequences), max_i length_i, emb_dim)`` with the
        dtype/device of ``sequences[0]``; positions past a sample's length are zero.
    :raises ValueError: if ``sequences`` is empty or an element's first axis is not
        ``emb_dim``.
    """
    if not sequences:
        raise ValueError("padding_tensor() needs at least one tensor")

    max_len = max(s.size(1) for s in sequences)
    out_tensor = sequences[0].new_zeros((len(sequences), max_len, emb_dim))
    for i, tensor in enumerate(sequences):
        if tensor.size(0) != emb_dim:
            # Without this check a leading axis of size 1 would silently broadcast the
            # first feature across all emb_dim columns instead of failing.
            raise ValueError(
                f"expected first dimension {emb_dim}, got shape {tuple(tensor.shape)}"
            )
        length = tensor.size(1)
        # (emb_dim, length) -> (length, emb_dim) to fill this sample's row of the batch.
        out_tensor[i, :length, :] = tensor[:, :, 0].permute(1, 0)
    return out_tensor


def my_collate(batch):
    """Collate ``(image, embedding)`` samples for the DataLoader.

    Images are returned as a plain Python list (not stacked). Embeddings are merged
    into one zero-padded batch tensor by ``padding_tensor``.

    :param batch: sequence of samples; element 0 is the image, element 1 the embedding.
    :return: ``[list_of_images, padded_embeddings]``.
    """
    images, embeddings = [], []
    for sample in batch:
        images.append(sample[0])
        embeddings.append(sample[1])
    return [images, padding_tensor(embeddings)]



def benchmark_regular(iters=50, urls="cc12m/{00000..00030}.tar", num_workers=0):
    """Time on-the-fly T5 caption encoding on batches from a raw image/caption WebDataset.

    Only the ``get_emb_tensor`` call is timed per step; loading and decoding the batch
    is not. Step 0 is excluded as warm-up and the loop stops after step ``iters``, so
    at most ``iters`` steps are timed. Prints total wall time plus the mean and standard
    deviation of the timed step durations.

    :param iters: number of steps to time after the warm-up step.
    :param urls: WebDataset shard spec; defaults to the previously hard-coded path.
    :param num_workers: ``DataLoader`` worker processes; 0 (the DataLoader default)
        loads in the main process, as before.
    """
    from time import perf_counter  # monotonic, high-resolution clock for intervals

    cc_dataset = (
        wds.WebDataset(urls)
        .shuffle(240)
        .decode("pilrgb")
        .rename(image="jpg;png", caption="txt")
        .map_dict(image=preproc)
        .to_tuple("image", "caption")
    )

    cc_dataloader = DataLoader(
        dataset=cc_dataset,
        batch_size=cfg["train"]["batch_size"],
        drop_last=True,
        num_workers=num_workers,
    )

    step_times = []
    start = perf_counter()
    for step, (_images, texts) in enumerate(cc_dataloader):
        print(f"\r Step {step}", end='')

        step_start = perf_counter()
        get_emb_tensor(cfg, texts, device)
        if device.type == "cuda":
            # CUDA work runs asynchronously; wait for it so the interval includes it.
            torch.cuda.synchronize()
        if step > 0:  # step 0 is warm-up (presumably includes first-call model loading)
            step_times.append(perf_counter() - step_start)

        if step == iters:
            break
    end = perf_counter()

    if not step_times:
        # np.mean/np.std of an empty list would warn and report nan.
        print(f"\n Finished in {end - start:.1f}s with no timed steps")
        return
    step_time = np.mean(step_times)
    step_std = np.std(step_times)
    print(f"\n Finished in {end - start:.1f}s at {step_time:.4f}s/it +/- {step_std:.4f}")
    
    
def benchmark_aug(
    iters=50,
    urls="file:E:/datasets/cc12m/{00000..00030}.tar",
    num_workers=0,
):
    """Time batches whose caption embeddings were precomputed and stored in the shards.

    Each sample carries a PNG image and an ``emb.pyd`` embedding; batches are built by
    ``my_collate``. Per step, only stacking the images and moving the padded embeddings
    to ``device`` is timed; loading and decoding are not. Step 0 is excluded as warm-up
    and the loop stops after step ``iters``. Prints total wall time plus the mean and
    standard deviation of the timed step durations.

    NOTE(review): webdataset's default decoders unpickle ``.pyd`` entries, so only
    point ``urls`` at trusted shards.

    :param iters: number of steps to time after the warm-up step.
    :param urls: WebDataset shard spec; defaults to the previously hard-coded path.
    :param num_workers: ``DataLoader`` worker processes; 0 (the DataLoader default)
        loads in the main process, as before.
    """
    from time import perf_counter  # monotonic, high-resolution clock for intervals

    cc_dataset = (
        wds.WebDataset(urls)
        .shuffle(240)
        .decode("pilrgb")
        .rename(image="png", embedding="emb.pyd")
        .map_dict(image=preproc)
        .to_tuple("image", "embedding")
    )

    cc_dataloader = DataLoader(
        dataset=cc_dataset,
        batch_size=cfg["train"]["batch_size"],
        drop_last=True,
        num_workers=num_workers,
        collate_fn=my_collate,
    )

    step_times = []
    start = perf_counter()
    for step, (images, text_embeds) in enumerate(cc_dataloader):
        print(f"\r Step {step}", end='')

        step_start = perf_counter()
        images = torch.stack(images, dim=0)
        text_embeds = text_embeds.to(device)
        if device.type == "cuda":
            # CUDA copies can be asynchronous; wait so the interval includes them.
            torch.cuda.synchronize()
        if step > 0:  # step 0 is warm-up and is not timed
            step_times.append(perf_counter() - step_start)

        if step == 0:
            # Report the batch shape once, outside the timed interval, so console I/O
            # neither inflates step times nor overwrites the progress line each step.
            print(f"\n image batch shape: {tuple(images.shape)}")
        if step == iters:
            break
    end = perf_counter()

    if not step_times:
        # np.mean/np.std of an empty list would warn and report nan.
        print(f"\n Finished in {end - start:.1f}s with no timed steps")
        return
    step_time = np.mean(step_times)
    step_std = np.std(step_times)
    print(f"\n Finished in {end - start:.1f}s at {step_time:.4f}s/it +/- {step_std:.4f}")

In [None]:
# Baseline: captions are encoded with T5 inside the loop.
benchmark_regular(iters=50)

In [None]:
# Comparison: embeddings are read precomputed from the shards.
benchmark_aug(iters=50)

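Results with `iters=50`, two runs per configuration. "Precomputed embeddings" = F is `benchmark_regular`, which encodes the captions with T5 on the fly. T is `benchmark_aug`, which loads embeddings stored in the shards. *Step time* covers only the timed per-batch work after a batch has been loaded, not data loading itself. *Total time* covers the whole loop, including data loading and the untimed warm-up step.

| precomputed embeddings | batch size | total time | mean step time | step time std |
| --- | --- | --- | --- | --- |
| F | 240 | 180.8s | 2.2411s | 0.7672s |
| F | 240 | 128.5s | 2.0428s | 0.4996s |
| T | 240 | 51.7s | 0.0960s | 0.0146s |
| T | 240 | 80.3s | 0.0939s | 0.0212s |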