In [2]:
import random
from functools import partial
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import yaml
from PIL import Image
from torch import nn
from torch.utils import data
from torchvision import transforms as T, utils

from imagen_pytorch.t5 import t5_encode_text

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [5]:
def get_emb_tensor(cfg, targets, device):
    """Encode ``targets`` with the configured T5 encoder and place the results on ``device``.

    Args:
        cfg: parsed config dict; ``cfg["model"]["text_encoder_name"]`` picks the encoder.
        targets: text input forwarded unchanged to ``t5_encode_text``.
        device: destination device for both returned tensors.

    Returns:
        ``(text_embeds, text_masks)`` as produced by ``t5_encode_text``, each moved to ``device``.
    """
    encoder_name = cfg["model"]["text_encoder_name"]
    embeds, masks = t5_encode_text(targets, name=encoder_name)
    return embeds.to(device), masks.to(device)


def exists(val):
    """Return True for any value other than ``None`` (falsy values such as 0 or "" still exist)."""
    return not (val is None)


def cycle(dl):
    """Yield items from ``dl`` endlessly, re-iterating it each time it is exhausted.

    Args:
        dl: any re-iterable, e.g. a ``torch.utils.data.DataLoader`` or a list.

    Yields:
        Successive items of ``dl``, wrapping around to the start after the last one.

    If a complete pass over ``dl`` produces nothing (empty dataset), the generator
    finishes instead of looping forever without yielding, matching the behaviour
    of ``itertools.cycle`` on an empty input.
    """
    while True:
        produced_any = False
        # `item` rather than `data`, so the imported torch.utils `data` module is not shadowed.
        for item in dl:
            produced_any = True
            yield item
        if not produced_any:
            return

            
def convert_image_to(img_type, image):
    """Return ``image`` in mode ``img_type``, calling ``convert`` only when the mode differs.

    Args:
        img_type: target image mode string passed to ``image.convert``.
        image: image object exposing ``.mode`` and ``.convert`` (a PIL image in this notebook).

    Returns:
        ``image`` itself when it is already in ``img_type`` mode, else ``image.convert(img_type)``.
    """
    return image if image.mode == img_type else image.convert(img_type)


class Dataset(data.Dataset):
    """Image/caption dataset backed by a pickled pandas DataFrame.

    The DataFrame at ``cfg["dataset"]["info_file"]`` must have a ``file_path``
    column (image paths) and a ``caption`` column. Each ``caption`` entry is
    sliced/indexed positionally, so it is presumably a list of caption strings
    per image — TODO confirm against the info file (a plain string would be
    indexed character by character).

    Items are ``(image, text_embed, text_mask)``. The image is resized, randomly
    horizontally flipped, centre-cropped and turned into a tensor. The text
    tensors come from encoding one randomly chosen caption with T5.
    """

    def __init__(
        self,
        cfg,
        exts = ['jpg','jpeg','png','tiff'],
        convert_image_to_type = None
    ):
        """
        Args:
            cfg: parsed config dict. Reads ``cfg["dataset"]["captions_per_image"]``,
                ``cfg["dataset"]["info_file"]``, ``cfg["dataset"]["image_size"]`` and,
                when encoding text, ``cfg["model"]["text_encoder_name"]``.
            exts: currently unused; kept for interface compatibility.
            convert_image_to_type: optional image mode every image is converted to
                before the other transforms; ``None`` means no conversion step.
        """
        super().__init__()

        self.captions_per_img = cfg["dataset"]["captions_per_image"]
        # NOTE(review): unpickling can execute arbitrary code — only load trusted info files.
        self.info_df = pd.read_pickle(cfg["dataset"]["info_file"])
        self.image_size = cfg["dataset"]["image_size"]
        self.cfg = cfg

        self.image_paths = self.info_df["file_path"].values.tolist()
        self.captions = self.info_df["caption"].values.tolist()
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        # Add the mode-conversion step only when a target mode was requested; otherwise
        # the pipeline starts directly with resizing.
        steps = []
        if exists(convert_image_to_type):
            steps.append(T.Lambda(partial(convert_image_to, convert_image_to_type)))
        steps += [
            T.Resize(self.image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(self.image_size),
            T.ToTensor(),
        ]
        self.transform = T.Compose(steps)

    def get_embed_tensor(self, caption):
        """Encode ``caption`` with the configured T5 encoder.

        Returns:
            ``(text_embeds, text_masks)`` from ``t5_encode_text``, both moved to ``self.device``.
        """
        text_embeds, text_masks = t5_encode_text(caption, name=self.cfg["model"]["text_encoder_name"])
        return text_embeds.to(self.device), text_masks.to(self.device)

    def compute_embeddings(self):
        """Encode each row's captions and save them back into the info pickle.

        Adds ``text_embeds`` / ``text_masks`` object columns to ``self.info_df``.
        There is one tensor pair per row, in row order. The updated DataFrame then
        overwrites ``cfg["dataset"]["info_file"]``.

        Tensors are moved to the CPU before saving, so the pickle can be loaded on
        machines without CUDA.
        """
        n_rows = len(self.captions)
        # pandas treats a tensor as list-like, so storing one into a single cell via
        # .loc does not work; fill 1-D object arrays element by element instead.
        embeds_col = np.empty(n_rows, dtype=object)
        masks_col = np.empty(n_rows, dtype=object)

        for pos, row_captions in enumerate(self.captions):
            print(f"\r computing {pos+1}/{n_rows}", end='')
            text_embeds, text_masks = self.get_embed_tensor(row_captions)
            embeds_col[pos] = text_embeds.cpu()
            masks_col[pos] = text_masks.cpu()
        print()

        # Whole-column assignment is positional, so any index labels work.
        self.info_df["text_embeds"] = embeds_col
        self.info_df["text_masks"] = masks_col
        self.info_df.to_pickle(self.cfg["dataset"]["info_file"])

    def __len__(self):
        """Number of images (rows) listed in the info file."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return ``(image, text_embed, text_mask)`` for row ``index``.

        One caption is drawn uniformly at random from the first
        ``captions_per_img`` captions of the row (fewer if the row has fewer).
        The text tensors are returned on the CPU, like the image. The DataLoader's
        ``pin_memory`` only accepts CPU tensors, so move batches to the training
        device inside the training loop.
        """
        path = self.image_paths[index]
        # The context manager closes the underlying file once the tensor has been built.
        with Image.open(path) as image:
            img = self.transform(image)

        candidates = self.captions[index][:self.captions_per_img]
        caption = random.choice(candidates)
        text_embed, text_mask = self.get_embed_tensor(caption)
        # NOTE(review): if t5_encode_text pads each call to its longest caption, sequence
        # lengths can differ between items and break default collation — TODO confirm.
        return img, text_embed.cpu(), text_mask.cpu()
    
    
    
def get_coco_dataloader(
    config,
    *,
    batch_size,
    shuffle = True,
    cycle_dl = False,
    pin_memory = True
):
    """Build a DataLoader over ``Dataset(config)``, optionally wrapped in ``cycle``.

    Args:
        config: parsed config dict handed straight to ``Dataset``.
        batch_size: samples per batch (keyword-only).
        shuffle: forwarded to ``DataLoader``.
        cycle_dl: when True, return ``cycle(loader)`` (an endless generator) instead.
        pin_memory: forwarded to ``DataLoader``.

    Returns:
        The ``torch.utils.data.DataLoader``, or a generator cycling over it if ``cycle_dl``.
    """
    loader = data.DataLoader(
        Dataset(config),
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=pin_memory,
    )
    return cycle(loader) if cycle_dl else loader

In [None]:
# Training/data configuration; safe_load builds only plain Python objects from the YAML.
CONFIG_PATH = Path("configs/imagen-config.yaml")
cfg = yaml.safe_load(CONFIG_PATH.read_text())

In [1]:
# ds = Dataset(cfg)
# dl = get_coco_dataloader(cfg, batch_size=64)