# Convert the dataset to its latent-space representation

In [1]:
import sys
sys.path.append("/mnt/data1/bardella_data/gitRepos/Deep-Learning-Techniques-for-Image-Generation-from-Music")


import torch, importlib, numpy as np, json, os
from tqdm import tqdm
import albumentations
from PIL import Image
from torchvision import transforms
from omegaconf import OmegaConf

# Select the first CUDA device; the notebook stops right here when no GPU is
# available, because `device` stays None and the assertion fails.
device = None
if torch.cuda.is_available():
    device = "cuda:0"
assert device


def instantiate_from_config(config):
    """Build the object described by ``config``.

    Args:
        config: Mapping with a ``"target"`` entry (dotted path resolved by
            ``get_obj_from_str``) and an optional ``"params"`` entry whose
            items are passed as keyword arguments. The sentinel strings
            ``'__is_first_stage__'`` and ``'__is_unconditional__'`` are also
            accepted.

    Returns:
        The result of calling the resolved target, or ``None`` for either
        sentinel string.

    Raises:
        KeyError: If ``config`` has no ``"target"`` and is not a sentinel.
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"])
        # A missing "params" entry means "call with no keyword arguments".
        kwargs = config.get("params", dict())
        return factory(**kwargs)

    # No target: only the two sentinel markers are valid.
    if config == '__is_first_stage__' or config == "__is_unconditional__":
        return None
    raise KeyError("Expected key `target` to instantiate.")


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to the object.

    Args:
        string: Dotted path; everything before the last dot is the module to
            import, the remainder is the attribute looked up on it.
        reload: When true, the module is reloaded before the lookup.

    Returns:
        The attribute named by the last path component.

    Raises:
        ValueError: If ``string`` contains no dot (the split cannot be unpacked).
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    owner = importlib.import_module(module_name, package=None)
    return getattr(owner, attr_name)


def load_model_from_config(config, ckpt, device):
    """Instantiate the model described by ``config`` and load checkpoint weights.

    Args:
        config: Model configuration passed to ``instantiate_from_config``.
        ckpt: Path of the checkpoint file; it must contain a ``"state_dict"``
            entry.
        device: Device the model is moved to after loading.

    Returns:
        The model, on ``device`` and switched to evaluation mode.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load relies on pickle; only open checkpoints from a trusted source.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config)

    # The load is non-strict, so mismatched keys do not raise. Report them
    # instead of discarding the result, otherwise a checkpoint that does not
    # fit the config would go unnoticed.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    for kind, keys in (("missing", missing), ("unexpected", unexpected)):
        if keys:
            preview = ", ".join(keys[:5])
            print(f"Warning: {len(keys)} {kind} key(s) while loading state_dict, e.g. {preview}")

    model.to(device)
    model.eval()
    return model

def preprocess_image(size, image_path):
    """Load an image file and turn it into a square, normalised tensor.

    The image is converted to RGB if needed, rescaled with
    ``SmallestMaxSize(max_size=size)``, center-cropped to ``size`` x ``size``
    and mapped from the ``[0, 255]`` range to ``[-1, 1]``.

    Args:
        size: Side length, in pixels, of the square crop.
        image_path: Path of the image file to read.

    Returns:
        A ``float32`` tensor with the channel axis moved to the front.
    """
    preprocessor = albumentations.Compose([
        albumentations.SmallestMaxSize(max_size=size),
        albumentations.CenterCrop(height=size, width=size),
    ])

    # The context manager releases the file handle; the pixel data is decoded
    # inside the block because PIL reads image files lazily.
    with Image.open(image_path) as source:
        rgb = source if source.mode == "RGB" else source.convert("RGB")
        pixels = np.array(rgb).astype(np.uint8)

    cropped = preprocessor(image=pixels)["image"]
    scaled = (cropped / 127.5 - 1.0).astype(np.float32)

    # (H, W, C) -> (C, H, W)
    return torch.tensor(scaled).permute(2, 0, 1)

In [2]:
# Repository checkout that holds the pretrained weights and the configs.
ROOT_PATH = "/mnt/data1/bardella_data/gitRepos/Deep-Learning-Techniques-for-Image-Generation-from-Music"

# Checkpoint to load, source dataset, and destination for the encoded tensors.
ckpt_path = f"{ROOT_PATH}/pretrained_model/vq-f8/pretrained_last_openimage.ckpt"
dataset_path = "/mnt/data1/bardella_data/gitRepos/Thesis/Datasets/wikiart"
compressed_dataset_path = "/mnt/data1/bardella_data/gitRepos/Thesis/Datasets/wikiart_compressed"

# The dataset index is the "labels" entry of dataset.json.
dataset_info_path = f"{dataset_path}/dataset.json"
with open(dataset_info_path, "r") as info_file:
    dataset_meta = json.load(info_file)
dataset_infos = dataset_meta["labels"]

# Model definition comes from the experiment config; only its `model` section is used.
experiment_cfg_path = f"{ROOT_PATH}/configs/custom_vqgan.yaml"
config = OmegaConf.load(experiment_cfg_path)

model = load_model_from_config(
    config=config.model,
    ckpt=ckpt_path,
    device=device,
)

Loading model from /mnt/data1/bardella_data/gitRepos/Deep-Learning-Techniques-for-Image-Generation-from-Music/pretrained_model/vq-f8/pretrained_last_openimage.ckpt
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
loaded pretrained LPIPS loss from /mnt/data1/bardella_data/gitRepos/Thesis/ldm_porting/pretrained_model/vgg/vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.


In [3]:
# Number of leading dataset entries to convert; None converts every entry.
samples = None
for image_path, label in tqdm(dataset_infos[:samples]):

    # Mirror the image's parent directory under the output root and reuse the
    # file name (text before its first dot) for the saved tensor.
    input_path = "/".join([dataset_path, image_path])
    out_path = "/".join([compressed_dataset_path, image_path.split("/")[-2]])
    image_name = image_path.split("/")[-1].split(".")[0]

    # exist_ok makes the separate existence check unnecessary.
    os.makedirs(out_path, exist_ok=True)

    preprocessed_image = preprocess_image(size=512, image_path=input_path)
    # Add a leading batch axis of size 1 for the model.
    img_in = preprocessed_image.to(device=device).unsqueeze(dim=0)

    # Encoding only: without no_grad every iteration would build an autograd
    # graph and the saved tensor would carry requires_grad.
    with torch.no_grad():
        quant, _, _ = model.encode(img_in)
    # Drop the batch axis again before saving.
    quant.squeeze_(dim=0)

    # NOTE(review): the tensor is saved on the device it was computed on, so
    # loading it elsewhere needs `map_location` - confirm whether consumers
    # would prefer a CPU copy.
    torch.save(quant, f"{out_path}/{image_name}.pt")

100%|██████████| 19774/19774 [42:18<00:00,  7.79it/s]
