In [72]:
import numpy as np
import glob
import os
import importlib
import yaml
import albumentations
import glob
import json
import torch
from PIL import Image
from torch.utils.data import random_split, DataLoader, Dataset
from omegaconf import OmegaConf
from tqdm.notebook import tqdm
from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex

def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to the named object.

    Args:
        string: dotted path. Everything before the last dot is imported as a
            module with importlib; the part after it is the attribute name.
        reload: if True, reload the module before the lookup (useful in a
            live kernel after editing the module's source).

    Returns:
        The attribute ``Name`` of module ``pkg.module``.

    Raises:
        ValueError: if ``string`` contains no dot.
        ImportError / AttributeError: if the module or attribute is missing.
    """
    module_name, attr_name = string.rsplit(".", 1)
    module = importlib.import_module(module_name, package=None)
    if reload:
        # importlib.reload re-executes the module in place and returns it.
        module = importlib.reload(module)
    return getattr(module, attr_name)

def instantiate_from_config(config):
    """Construct the object described by a ``{"target": ..., "params": {...}}`` mapping.

    Args:
        config: mapping with a required ``"target"`` key (dotted path to a
            callable, resolved with ``get_obj_from_str``) and an optional
            ``"params"`` mapping passed as keyword arguments.

    Returns:
        The result of calling ``target(**params)``.

    Raises:
        KeyError: if ``"target"`` is missing.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    factory = get_obj_from_str(config["target"])
    kwargs = config.get("params", {})
    return factory(**kwargs)

# Load the transformer config. It also embeds the configs of the two VQGAN
# stages used for tokenization, each with its checkpoint path:
# `first_stage_config` (encodes images) and `cond_stage_config` (encodes
# segmentation maps).
config_path = "configs/owt_transformer.yaml"
config = OmegaConf.load(config_path)

vggan_config, cond_vggan_config = (
    config.model.params.first_stage_config,
    config.model.params.cond_stage_config,
)
vggan_ckpt_path, cond_vggan_ckpt_path = (
    vggan_config.params.ckpt_path,
    cond_vggan_config.params.ckpt_path,
)

print(f"Read checkpoint path {vggan_ckpt_path} for dqgan")
print(f"Read checkpoint path {cond_vggan_ckpt_path} for cond dqgan")

Read checkpoint path logs/2022-07-28T00-02-08_usc_pretrained_vggan/checkpoints/last.ckpt for dqgan
Read checkpoint path logs/2022-08-01T01-37-16_use_2048_cond_stage/checkpoints/last.ckpt for cond dqgan


In [76]:
# Load the image VQGAN and the segmentation (cond) VQGAN from their configs.
# They are only used to encode here, so switch them to eval mode so layers
# such as dropout behave deterministically. Fall back to CPU when CUDA is not
# available instead of failing inside .to().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vggan = instantiate_from_config(vggan_config).to(device).eval()
vggan_cond = instantiate_from_config(cond_vggan_config).to(device).eval()

Working with z of shape (1, 256, 16, 16) = 65536 dimensions.
Restored from logs/2022-07-28T00-02-08_usc_pretrained_vggan/checkpoints/last.ckpt
Working with z of shape (1, 256, 16, 16) = 65536 dimensions.
Restored from logs/2022-08-01T01-37-16_use_2048_cond_stage/checkpoints/last.ckpt


In [82]:
# Inspect the full transformer config loaded in the first cell.
config

{'model': {'base_learning_rate': 4.5e-06, 'target': 'taming.models.cond_transformer.Net2NetTransformer', 'params': {'cond_stage_key': 'segmentation', 'transformer_config': {'target': 'taming.modules.transformer.mingpt.CondGPT', 'params': {'vocab_size': 8192, 'vocab_size_cond': 1024, 'block_size': 512, 'n_layer': 40, 'n_head': 16, 'n_embd': 1408, 'embd_pdrop': 0.1, 'resid_pdrop': 0.1, 'attn_pdrop': 0.1}}, 'first_stage_config': {'target': 'taming.models.vqgan.VQModel', 'params': {'ckpt_path': 'logs/2022-07-28T00-02-08_usc_pretrained_vggan/checkpoints/last.ckpt', 'embed_dim': 256, 'n_embed': 8192, 'ddconfig': {'double_z': False, 'z_channels': 256, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 1, 2, 2, 4], 'num_res_blocks': 2, 'attn_resolutions': [16], 'dropout': 0.0}, 'lossconfig': {'target': 'taming.modules.losses.DummyLoss'}}}, 'cond_stage_config': {'target': 'taming.models.vqgan.VQSegmentationModel', 'params': {'ckpt_path': 'logs/2022-08-01T01-37-16_use_20

In [83]:
# Tokenization setup: output directory, plus the resize applied jointly to
# each image and its segmentation map.
size = config.data.params.train.params.size
dataroot = config.data.params.train.params.dataroot
datasetname = f"tokenized_{size}"
os.makedirs(os.path.join(dataroot, datasetname), exist_ok=True)

# Rescale so the shortest side equals `size`. There is no crop here, because
# patch_encode tiles the full image afterwards.
# The segmentation is registered as a "mask" target rather than an "image"
# target, so albumentations resizes it with nearest-neighbour interpolation.
# Interpolating label values would invent labels along class boundaries
# (e.g. 1 between 0 and 2).
rescaler = albumentations.SmallestMaxSize(max_size=size)
preprocessor = albumentations.Compose([rescaler], additional_targets={"segmentation": "mask"})

def preprocess_image(image_path, segmentation_path, onehot=True, n_labels=3):
    """Load an image and its label map, resize both together, and normalize.

    Args:
        image_path: path to the image file (opened with PIL, converted to RGB).
        segmentation_path: path to a ``.npy`` label map aligned with the image.
        onehot: if True, return the segmentation one-hot encoded along a new
            trailing axis. If False, return labels mapped linearly onto [-1, 1].
        n_labels: number of classes. Labels must lie in ``[0, n_labels)``.

    Returns:
        ``(image, segmentation)``:
        - ``image`` is float32 in [-1, 1] with a trailing RGB axis.
        - ``segmentation`` has shape ``(H, W, n_labels)`` and dtype int when
          ``onehot``; otherwise it is a float32 ``(H, W)`` array.

    Raises:
        ValueError: if ``onehot`` and a label is ``>= n_labels``.
    """
    with Image.open(image_path) as pil_image:
        rgb = pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")
        image = np.array(rgb).astype(np.uint8)

    segmentation = np.load(segmentation_path)
    # Values <= -1 (e.g. a -1 "ignore" label) are folded into label 0.
    segmentation = segmentation * (segmentation > -1)
    segmentation = segmentation.astype(np.uint8)

    # `preprocessor` is the notebook-level albumentations pipeline. It
    # transforms image and segmentation in one call so they stay aligned.
    processed = preprocessor(image=image, segmentation=segmentation)
    image, segmentation = processed["image"], processed["segmentation"]
    image = (image / 127.5 - 1.0).astype(np.float32)  # uint8 [0, 255] -> [-1, 1]

    if onehot:
        assert segmentation.dtype == np.uint8
        if segmentation.size and segmentation.max() >= n_labels:
            raise ValueError(
                f"segmentation label {int(segmentation.max())} is out of range "
                f"for n_labels={n_labels}"
            )
        # Indexing an identity matrix by label appends a one-hot axis:
        # (H, W) -> (H, W, n_labels). Plain `bool`/`int` are used because the
        # `np.bool` alias no longer exists in NumPy >= 1.24.
        segmentation = np.eye(n_labels, dtype=int)[segmentation]
    else:
        # Map labels 0 .. n_labels-1 linearly onto [-1, 1]. For the default
        # n_labels=3 this is exactly `label / 1.0 - 1.0`.
        half_range = (n_labels - 1) / 2.0 if n_labels > 1 else 1.0
        segmentation = (segmentation / half_range - 1.0).astype(np.float32)

    return image, segmentation

def patch_encode(x, model, patch_size=None, downsample=16, target_device=None):
    """Tile ``x`` into square crops, encode each one, and stitch the code indices.

    Args:
        x: numpy array of shape ``(H, W, C)``. Each crop is moved to
            channels-first ``(1, C, p, p)`` and cast to float before encoding.
        model: object whose ``encode(patch)`` returns ``(quant, diff, info)``
            with the flat codebook indices as the third element of ``info``.
        patch_size: side length ``p`` of each crop. Defaults to the notebook
            global ``crop_size`` (looked up at call time).
        downsample: spatial reduction factor of the encoder. Each crop yields
            a ``g x g`` index grid with ``g = patch_size // downsample``.
        target_device: torch device for the crops. Defaults to the notebook
            global ``device``.

    Returns:
        int64 array of shape ``((H // p) * g, (W // p) * g)``. Border strips
        narrower than one crop are dropped, not left as zero-filled
        rows/columns: 0 is a valid codebook index, so leftover zeros would be
        indistinguishable from real tokens. Inputs smaller than one crop give
        an empty array.

    Raises:
        ValueError: if ``patch_size`` is not a multiple of ``downsample``.
    """
    if patch_size is None:
        patch_size = crop_size  # notebook global, set in the tokenization cell
    if target_device is None:
        target_device = device  # notebook global, set in the model-loading cell
    if patch_size % downsample:
        raise ValueError(
            f"patch_size={patch_size} must be a multiple of downsample={downsample}"
        )

    grid = patch_size // downsample
    h, w = x.shape[:2]
    n_rows, n_cols = h // patch_size, w // patch_size  # only full crops
    index_grid = np.zeros((n_rows * grid, n_cols * grid), dtype=np.int64)

    for r in range(n_rows):
        for c in range(n_cols):
            i, j = r * patch_size, c * patch_size
            patch_np = x[i:i + patch_size, j:j + patch_size]
            patch = torch.from_numpy(patch_np[None]).permute(0, 3, 1, 2).to(target_device).float()
            _, _, info = model.encode(patch)
            _, _, z_index = info  # flat indices, grid * grid of them
            z_index = z_index.detach().cpu().numpy().reshape(grid, grid)
            index_grid[r * grid:(r + 1) * grid, c * grid:(c + 1) * grid] = z_index
    return index_grid

# Image ids are the basenames (without extension) of the *.JPG files in
# dataroot. They are sorted so the processing order does not depend on the
# filesystem's listing order.
ids = sorted(
    os.path.splitext(os.path.basename(path))[0]
    for path in glob.glob(os.path.join(dataroot, "*.JPG"))
)
# Read as a global by patch_encode (its default tile size).
crop_size = config.data.params.train.params.crop_size

with torch.no_grad():  # encoding only; no autograd graph needed
    # NOTE: `id` shadows the builtin; the name is kept because a later cell
    # displays it.
    for id in tqdm(ids):
        image_path = os.path.join(dataroot, id + '.JPG')
        segmentation_path = os.path.join(dataroot, id + '.npy')
        img, seg = preprocess_image(image_path, segmentation_path)
        # Image tiles go through the image VQGAN; label tiles go through the
        # cond VQGAN.
        index_img = patch_encode(img, vggan)
        index_seg = patch_encode(seg, vggan_cond)

        # Save both index grids under dataroot/tokenized_<size>/.
        np.save(os.path.join(dataroot, datasetname, f"{id}_img.npy"), index_img)
        np.save(os.path.join(dataroot, datasetname, f"{id}_cond.npy"), index_seg)

  0%|          | 0/269 [00:00<?, ?it/s]

In [39]:
# Debug: encode one tile (top-left crop) of the last processed image with the
# image VQGAN and inspect the returned indices. `patch_np` exists only inside
# patch_encode, so it is rebuilt here. The tensor is moved to `device` to
# match the model.
patch_np = img[:crop_size, :crop_size]
patch = torch.from_numpy(patch_np[None]).permute(0, 3, 1, 2).to(device).float()
with torch.no_grad():
    quant, diff, info = vggan.encode(patch)
_, _, z_index = info  # flat indices; presumably 16x16 = 256 for a 256x256 tile — confirm

In [81]:
# Show the id from the last iteration of the tokenization loop, e.g. to see
# which image it was processing when it stopped.
id

'DJI_0838'