In [60]:
import io

import numpy as np
import requests
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import yaml
from omegaconf import OmegaConf
from PIL import Image
from taming.models.vqgan import VQModel

# Browser-like User-Agent; used as the default request headers in download_image.
HEADERS = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36'}
# Tab-separated input read by preencode: each row is split into url, en, tr.
DATA_PATH = "../data/cc12m_sample/cc12m_tr.tsv"
# Binary output file written by preencode.
OUT_PATH = "../data/cc12m_sample/cc12m_tr_encoded.bin"
# First CUDA device when one is available, otherwise the CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# Path components used to locate the pretrained VQGAN config and checkpoint.
PRETRAIN_DIR = "pretrained"
MODEL_DIR = "vqgan_f16_16384"
# Passed to preencode as num_patches (number of tokens stored per image).
# The same value is also handed to preprocess as the target image size.
NUM_PATCHES = 256

def load_config(config_path, display=False):
    """Load an OmegaConf configuration from ``config_path``.

    When ``display`` is truthy the configuration is also printed as YAML.
    The loaded OmegaConf object is returned in either case.
    """
    loaded = OmegaConf.load(config_path)
    if not display:
        return loaded
    as_container = OmegaConf.to_container(loaded)
    print(yaml.dump(as_container))
    return loaded

def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    """Build a ``VQModel`` from ``config.model.params`` and return it in eval mode.

    Args:
        config: Configuration object exposing ``config.model.params``.
        ckpt_path: Optional checkpoint path. When given, the ``"state_dict"``
            entry of the checkpoint is loaded into the model.
        is_gumbel: Accepted for interface compatibility; not used here.
    """
    vqgan = VQModel(**config.model.params)
    if ckpt_path is None:
        return vqgan.eval()
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # strict=False tolerates key mismatches between checkpoint and model;
    # the returned lists of missing/unexpected keys are not inspected.
    vqgan.load_state_dict(checkpoint["state_dict"], strict=False)
    return vqgan.eval()

def preprocess_vqgan(x):
    """Return ``2*x - 1``, i.e. map 0 to -1 and 1 to 1."""
    return 2. * x - 1.

def custom_to_pil(x):
    """Convert a channel-first image tensor to an RGB PIL image.

    The tensor is detached, moved to the CPU, clamped to [-1, 1], rescaled to
    [0, 1], permuted from (C, H, W) to (H, W, C) and scaled to 8-bit values.

    Args:
        x: A 3-D tensor laid out as (channels, height, width).

    Returns:
        A ``PIL.Image.Image`` in mode ``"RGB"``.

    Note:
        Relies on ``np`` (numpy) being imported at the top of the notebook.
    """
    x = x.detach().cpu()
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.) / 2.
    pixels = x.permute(1, 2, 0).numpy()
    # astype truncates towards zero, so 255 is only reached at exactly 1.0.
    pixels = (255 * pixels).astype(np.uint8)
    img = Image.fromarray(pixels)
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img

def vqgan_encode(x, model):
    """Call ``model.encode(x)`` with autograd disabled.

    ``model.encode`` must return a 3-item sequence whose last item is itself a
    3-item sequence. Returns the first item of the outer sequence together with
    the last item of the inner one.
    """
    with torch.no_grad():
        latent, _, (_, _, token_ids) = model.encode(x)
    return latent, token_ids

def vqgan_decode(x, model):
    """Return ``model.decode(x)``, evaluated with autograd disabled."""
    with torch.no_grad():
        decoded = model.decode(x)
    return decoded

def vqgan_reconstruct(x, model):
    """Encode ``x`` with ``model`` and decode the result, with autograd disabled.

    Only the first item returned by ``model.encode`` is passed to
    ``model.decode``; the remaining items are discarded.
    """
    with torch.no_grad():
        latent, _, (_, _, _) = model.encode(x)
        reconstruction = model.decode(latent)
    return reconstruction

def download_image(url, headers=HEADERS, timeout=30):
    """Download ``url`` and open the response body as a PIL image.

    Args:
        url: Address of the image to fetch.
        headers: HTTP headers sent with the request (defaults to ``HEADERS``).
        timeout: Forwarded to ``requests.get``. Without a timeout a server
            that stops responding would block the caller indefinitely.

    Returns:
        The image object produced by ``Image.open`` on the downloaded bytes.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: For other request failures, including
            timeouts.
        Exception: Whatever ``Image.open`` raises when the payload is not a
            recognisable image.
    """
    resp = requests.get(url, headers=headers, timeout=timeout)
    resp.raise_for_status()
    return Image.open(io.BytesIO(resp.content))

def preprocess(img, target_image_size=256):
    """Turn a PIL image into a square, batched RGB tensor.

    The image is converted to RGB, resized so that its shorter side equals
    ``target_image_size`` (Lanczos resampling), centre-cropped to a square and
    converted with ``ToTensor``.

    Args:
        img: A PIL image in any mode.
        target_image_size: Side length of the square output, in pixels.

    Returns:
        A tensor with a leading batch dimension of 1, followed by the
        channel, height and width dimensions produced by ``ToTensor``.
    """
    # Downloaded files may be palette, grayscale, RGBA or CMYK; without this
    # conversion ToTensor would yield a channel count other than 3.
    if img.mode != "RGB":
        img = img.convert("RGB")
    scale = target_image_size / min(img.size)
    # PIL reports size as (width, height); TF.resize takes (height, width).
    resized_hw = (round(scale * img.size[1]), round(scale * img.size[0]))
    img = TF.resize(img, resized_hw, interpolation=Image.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    return torch.unsqueeze(T.ToTensor()(img), 0)

In [61]:
# Read the VQGAN YAML config, build the model, load the checkpoint weights
# and move the model to DEVICE.
cfg_vqgan = load_config(f"../{PRETRAIN_DIR}/{MODEL_DIR}/configs/model.yaml", display=False)
model_vqgan = load_vqgan(cfg_vqgan, ckpt_path=f"../{PRETRAIN_DIR}/{MODEL_DIR}/checkpoints/last.ckpt").to(DEVICE)
# load_vqgan already returns the model in eval mode; this call repeats that.
model_vqgan.eval()
print("VQGAN Loaded.")

Working with z of shape (1, 256, 16, 16) = 65536 dimensions.
loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips\vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.
VQGAN Loaded.


In [62]:
# Encode a single downloaded image with the loaded VQGAN.
# NUM_PATCHES is passed here as preprocess's target_image_size.
# NOTE(review): preprocess_vqgan (2*x - 1) is defined above but is not applied
# before encoding - confirm which input range the checkpoint expects.
image = preprocess(download_image("https://www.abc.net.au/news/image/9329676-3x2-940x627.jpg"), NUM_PATCHES).to(DEVICE)
image_encode, image_tokens = vqgan_encode(image, model_vqgan)
# Cast the token indices to 32-bit integers.
image_tokens = image_tokens.type(torch.int32)



In [69]:
import struct

def preencode(data_path, out_path, num_patches, device="cpu", max_writes=-1, info_rate=100, model=None, image_size=None):
    """Download, encode and store the image tokens for every row of a TSV file.

    Output layout (native-endian C ints, as produced by ``struct`` format "i"):
    one header int holding ``num_patches``, then one record per successfully
    encoded row consisting of the row's zero-based line index followed by
    ``num_patches`` token ids.

    Rows that cannot be parsed, downloaded or encoded are reported on stdout
    and skipped; they never abort the run and never leave a partial record.

    Args:
        data_path: UTF-8 text file whose rows are "url<TAB>en<TAB>tr".
        out_path: Destination binary file (overwritten).
        num_patches: Number of token ids expected, and stored, per image.
        device: Torch device the image tensor is moved to before encoding.
        max_writes: Stop after this many records; -1 means no limit.
        info_rate: Print a progress line every ``info_rate`` records; values
            <= 0 disable progress output.
        model: Encoder handed to ``vqgan_encode``. Defaults to the
            notebook-level ``model_vqgan``.
        image_size: Target size handed to ``preprocess``. Defaults to
            ``num_patches``, which is what earlier versions always used.
    """
    if model is None:
        # Backward-compatible default: the model loaded earlier in the notebook.
        model = model_vqgan
    if image_size is None:
        image_size = num_patches
    # One record = line index + num_patches token ids, packed in a single call.
    record_format = f"{num_patches + 1}i"
    num_writes = 0
    with open(data_path, "r", encoding="utf-8") as src, open(out_path, "wb") as dest:
        dest.write(struct.pack("i", num_patches))
        for line_index, line in enumerate(src):
            if num_writes == max_writes:
                break
            try:
                # Parsing lives inside the try so a malformed row is skipped
                # like any other per-row failure.
                url, _en, _tr = line.strip().split("\t")
                # NOTE(review): preprocess_vqgan is not applied before encoding;
                # confirm which input range the checkpoint expects.
                image = preprocess(download_image(url), image_size).to(device)
                # vqgan_encode disables autograd itself.
                _, image_tokens = vqgan_encode(image, model)
                # A single device-to-host transfer instead of one per token.
                tokens = image_tokens.flatten().to(torch.int32).cpu().tolist()
                if len(tokens) != num_patches:
                    raise ValueError(f"expected {num_patches} tokens, got {len(tokens)}")
                dest.write(struct.pack(record_format, line_index, *tokens))
            except Exception as e:
                print(f"ERR: {e} at line {line_index}")
                continue
            num_writes += 1
            if info_rate > 0 and num_writes % info_rate == 0:
                print(f"INFO: Total lines encoded - {num_writes}.")

def load_tensor(f, patches):
    """Read one (index, tokens) record from a file written by ``preencode``.

    The leading header int that ``preencode`` writes is not consumed here; the
    caller must read it before the first call.

    Args:
        f: Binary file object positioned at the start of a record.
        patches: Number of token ids stored in each record.

    Returns:
        ``(index, data)`` where ``index`` is an ``int`` and ``data`` is a 1-D
        ``torch.int32`` tensor of length ``patches``; ``(-1, None)`` when the
        file is exhausted.

    Raises:
        struct.error: If the file ends in the middle of a record.
    """
    int_size = struct.calcsize("i")
    idx_bytes = f.read(int_size)
    if not idx_bytes:
        return -1, None
    # struct.unpack always returns a tuple; take the single value so the
    # index is a plain int, matching the -1 returned at end of file.
    index = struct.unpack("i", idx_bytes)[0]
    payload = f.read(int_size * patches)
    data = torch.tensor(struct.unpack(f"{patches}i", payload), dtype=torch.int32)
    return index, data

In [70]:
# Encode every row of DATA_PATH and write the tokens to OUT_PATH.
# One download per row, so this cell is long-running.
preencode(DATA_PATH, OUT_PATH, NUM_PATCHES, DEVICE)

ERR: cannot identify image file <_io.BytesIO object at 0x000002179631FA40> at line 8
INFO: 10 lines encoded.
ERR: 403 Client Error: Forbidden for url: https://assets.nst.com.my/images/articles/18nt23sunway_1531880949.jpg at line 11


KeyboardInterrupt: 