In [1]:
import requests
import random
import torch
import time
import yaml
import io
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from PIL import Image
import struct
from queue import Queue
from threading import Thread

# Default request headers (a desktop Chrome User-Agent string); this is the
# default `headers` argument of download_image.
HEADERS = {
    'User-Agent': (
        'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) '
        'AppleWebKit/537.36 (KHTML, like Gecko) '
        'Chrome/39.0.2171.95 Safari/537.36'
    ),
}

# Input TSV and output binary file.
DATA_PATH = "D:/C12M/cc12m_tr.tsv"
OUT_PATH = "D:/C12M/cc12m_tr_encoded.bin"

# Use the first CUDA device when available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# Directory names used to build the VQGAN config/checkpoint paths.
PRETRAIN_DIR = "pretrained"
MODEL_DIR = "vqgan_f16_16384"

# Pipeline sizing and logging cadence.
NUM_PATCHES = 256      # tokens written per image
NUM_DOWNLOADERS = 15   # download threads
NUM_PREENCODERS = 1    # encoder threads
INFO_FREQ = 50         # progress message interval (in images)

print(f"DEVICE: {DEVICE}")

def load_config(config_path, display=False):
    """Load an OmegaConf configuration from ``config_path``.

    Args:
        config_path: Path handed to ``OmegaConf.load``.
        display: When true, also print the configuration as YAML.

    Returns:
        The loaded configuration object.
    """
    cfg = OmegaConf.load(config_path)
    if not display:
        return cfg
    print(yaml.dump(OmegaConf.to_container(cfg)))
    return cfg

def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    """Build a ``VQModel`` from ``config.model.params`` and return it in eval mode.

    Args:
        config: Configuration exposing ``model.params`` (keyword arguments for ``VQModel``).
        ckpt_path: Optional checkpoint file. Its ``"state_dict"`` entry is loaded
            onto the CPU and applied non-strictly, so key mismatches are ignored.
        is_gumbel: Accepted for interface compatibility; not used here.

    Returns:
        The model after ``.eval()`` has been called on it.
    """
    vqgan = VQModel(**config.model.params)
    if ckpt_path is None:
        return vqgan.eval()
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    vqgan.load_state_dict(checkpoint["state_dict"], strict=False)
    return vqgan.eval()

def preprocess_vqgan(x):
    """Rescale ``x`` linearly as ``2*x - 1`` (so 0 maps to -1 and 1 maps to 1)."""
    return 2. * x - 1.

def custom_to_pil(x):
    """Convert an image tensor with values nominally in [-1, 1] to an RGB PIL image.

    The tensor is detached, moved to the CPU, clamped to [-1, 1], rescaled to
    [0, 1], its first axis is moved last, and the result is scaled to 0..255 and
    cast to ``uint8`` before being wrapped in a PIL image.

    Args:
        x: A 3-axis tensor (presumably channels-first, i.e. (C, H, W) — confirm
            against callers).

    Returns:
        A PIL image in ``"RGB"`` mode.
    """
    # numpy is not imported at module level in this notebook, so the original
    # `np.uint8` reference raised NameError; import it where it is needed.
    import numpy as np

    x = x.detach().cpu()
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.) / 2.
    x = x.permute(1, 2, 0).numpy()
    x = (255 * x).astype(np.uint8)
    x = Image.fromarray(x)
    if x.mode != "RGB":
        x = x.convert("RGB")
    return x

@torch.no_grad()
def vqgan_encode(x, model):
    """Run ``model.encode(x)`` with gradients disabled.

    Returns:
        A pair ``(z, indices)``: the first element of the encoder output and the
        third entry of its three-element info structure.
    """
    latent, _unused, (_first, _second, token_ids) = model.encode(x)
    return latent, token_ids

@torch.no_grad()
def vqgan_decode(x, model):
    """Return ``model.decode(x)`` computed with gradients disabled."""
    decoded = model.decode(x)
    return decoded

@torch.no_grad()
def vqgan_reconstruct(x, model):
    """Encode ``x`` with ``model`` and decode the result, with gradients disabled.

    Only the first element returned by ``model.encode`` is passed to
    ``model.decode``; the remaining encoder outputs are discarded.
    """
    latent, _unused, (_first, _second, _token_ids) = model.encode(x)
    return model.decode(latent)

def download_image(url, headers=HEADERS, timeout=30):
    """Fetch ``url`` over HTTP and open the response body as a PIL image.

    Args:
        url: Address of the image.
        headers: Request headers; defaults to the module-level ``HEADERS``.
        timeout: Seconds passed to ``requests.get`` as its timeout. Without a
            timeout a stalled server can block the calling thread indefinitely;
            pass ``None`` to restore the previous unbounded wait.

    Returns:
        The result of ``Image.open`` on the downloaded bytes.

    Raises:
        requests.RequestException: On connection problems, timeouts, or a
            non-success HTTP status (via ``raise_for_status``).
    """
    resp = requests.get(url, headers=headers, timeout=timeout)
    resp.raise_for_status()
    return Image.open(io.BytesIO(resp.content))

def preprocess(img, target_image_size=256):
    """Resize, center-crop and tensorize an image.

    The image is scaled (LANCZOS) so that its shorter side equals
    ``target_image_size``, center-cropped to a ``target_image_size`` square,
    converted with ``ToTensor`` and given a leading axis of size 1.

    Args:
        img: Image exposing ``.size`` as (width, height), e.g. a PIL image.
        target_image_size: Side length of the square output.

    Returns:
        The tensor produced by ``ToTensor`` with an extra leading axis.
    """
    dims = img.size
    scale = target_image_size / min(dims)
    resized_hw = (round(scale * dims[1]), round(scale * dims[0]))
    resized = TF.resize(img, resized_hw, interpolation=Image.LANCZOS)
    cropped = TF.center_crop(resized, output_size=[target_image_size, target_image_size])
    return T.ToTensor()(cropped).unsqueeze(0)

DEVICE: cuda:0


In [2]:
def preencode(data_path, out_path, num_patches, device="cpu", max_writes=-1, info_rate=100, verbose=False):
	"""Sequentially download, encode and store image tokens for every row of a TSV.

	Output layout (native ``int`` via ``struct`` format ``"i"``): one header int
	holding ``num_patches``, then for each successfully encoded row one int with
	the row's line index followed by ``num_patches`` token ints.

	Args:
		data_path: UTF-8 TSV whose rows have three tab-separated fields; the
			first is the image URL.
		out_path: Binary file to create (overwritten).
		num_patches: Number of tokens expected per image.
		device: Device used for both the model and the image tensors.
		max_writes: Stop after this many records; the default ``-1`` never matches
			so all rows are processed.
		info_rate: Print a progress line every ``info_rate`` written records.
		verbose: Print each token tensor before writing it.

	Rows that fail at any stage (parsing, download, encoding, size check) are
	reported with ``ERR: ...`` and skipped; they do not abort the run.
	"""
	cfg_vqgan = load_config(f"../{PRETRAIN_DIR}/{MODEL_DIR}/configs/model.yaml", display=False)
	# The model must live on the same device the images are moved to below;
	# previously it was sent to the global DEVICE regardless of `device`.
	model_vqgan = load_vqgan(cfg_vqgan, ckpt_path=f"../{PRETRAIN_DIR}/{MODEL_DIR}/checkpoints/last.ckpt").to(device)
	model_vqgan.eval()
	with open(data_path, "r", encoding="utf-8") as src, open(out_path, "wb") as dest:
		dest.write(struct.pack("i", num_patches))
		num_writes = 0
		for line_index, line in enumerate(src):
			if num_writes == max_writes:
				break
			try:
				# Parsing happens inside the try so one malformed row is skipped
				# instead of terminating the whole encode.
				url, _en, _tr = line.strip().split("\t")
				# NOTE(review): num_patches is also used as the target image size;
				# this presumably relies on the model turning a 256x256 image into
				# 256 tokens — confirm before changing either value.
				image = preprocess(download_image(url), num_patches).to(device)
				_, image_tokens = vqgan_encode(image, model_vqgan)
				image_tokens = image_tokens.flatten().type(torch.int32)
				if len(image_tokens) != num_patches:
					raise ValueError(f"expected {num_patches} tokens, got {len(image_tokens)}")
				if verbose:
					print(image_tokens)
				# Build the full record first and write it in one call so a
				# failure cannot leave a partially written record in the file.
				record = struct.pack("i", line_index) + struct.pack(f"{num_patches}i", *image_tokens.tolist())
				dest.write(record)
				num_writes += 1
				if num_writes % info_rate == 0:
					print(f"INFO: Total lines encoded - {num_writes}.")
			except Exception as e:
				print(f"ERR: {e} at line {line_index}")

def load_tensor(f, patches):
	"""Read one ``(index, tokens)`` record from an open binary file.

	A record is one native ``int`` (the index) followed by ``patches`` native
	ints (the tokens), matching what the encoders in this file write.

	Args:
		f: Binary file-like object positioned at the start of a record.
		patches: Number of token ints in the record.

	Returns:
		``(index, data)`` where ``index`` is an ``int`` and ``data`` is an
		``int32`` tensor of length ``patches``; ``(-1, None)`` at end of file.

	Raises:
		struct.error: If the file ends in the middle of a record.
	"""
	int_size = struct.calcsize("i")
	idx_bytes = f.read(int_size)
	if not idx_bytes:
		return -1, None
	# struct.unpack always returns a tuple; take its single element so the
	# index is a plain int, consistent with the -1 returned at EOF.
	index = struct.unpack("i", idx_bytes)[0]
	# Read and decode the whole payload at once instead of one int per call.
	payload = f.read(int_size * patches)
	values = struct.unpack(f"{patches}i", payload)
	data = torch.tensor(values, dtype=torch.int32)
	return index, data

In [3]:
def download_scheduler(data_path, url_queue:Queue):
	"""Feed ``(line_index, url)`` tasks from a TSV file into ``url_queue``.

	Each row is expected to have three tab-separated fields with the URL first.
	Rows with a different field count are reported and skipped; their line
	number is still consumed so indices keep matching the source file.

	A single ``None`` sentinel is always queued at the end — even if reading the
	file fails — so that consumers waiting on the queue can terminate.

	Args:
		data_path: Path of the UTF-8 TSV file.
		url_queue: Queue receiving the tasks and the final ``None``.
	"""
	try:
		with open(data_path, "r", encoding="utf-8") as src:
			for line_index, line in enumerate(src):
				fields = line.strip().split("\t")
				if len(fields) != 3:
					print(f"SCHEDULER: skipping malformed line {line_index}")
					continue
				url_queue.put((line_index, fields[0]))
	finally:
		# Without the sentinel every downloader would block forever on get().
		url_queue.put(None)

def download_worker(url_queue:Queue, image_queue:Queue, worker_id):
	"""Consume URL tasks, download each image and queue ``(index, image)`` pairs.

	The worker stops when it receives ``None`` from ``url_queue`` and puts the
	``None`` back so the remaining workers see it as well. Downloads are
	best-effort: a task that fails is dropped silently.

	Args:
		url_queue: Source of ``(line_index, url)`` tasks terminated by ``None``.
		image_queue: Destination for ``(line_index, image)`` results.
		worker_id: Identifier of this worker (not used in the body).
	"""
	while True:
		task = url_queue.get()
		if task is None:
			# Re-queue the sentinel for the other download workers.
			url_queue.put(None)
			break
		try:
			idx = task[0]
			url = task[1]
			img = download_image(url)
			image_queue.put((idx, img))
		except Exception:
			# Deliberately ignore failed downloads, but do not swallow
			# KeyboardInterrupt/SystemExit as the former bare `except:` did.
			continue

def preencode_worker(image_queue:Queue, token_queue:Queue, worker_id, num_patches, device, info_freq, wait_thresh):
	"""Encode queued images into token tensors and forward them to ``token_queue``.

	The worker loads its own model, then loops over ``(index, image)`` tasks from
	``image_queue``. On ``None`` it re-queues the sentinel for sibling workers and
	stops. Exactly one ``None`` is sent to ``token_queue`` when the worker exits,
	whether it finished normally or crashed, so the consumer's retirement count
	stays correct.

	Every ``info_freq`` tasks it prints the average queue-wait and processing
	times. Workers with ``worker_id > 0`` whose average wait exceeds their
	average processing time sleep for ``worker_id * fallback * 15`` seconds and
	increase ``fallback``; otherwise ``fallback`` decays towards 1.

	Args:
		image_queue: Source of ``(index, image)`` tasks terminated by ``None``.
		token_queue: Destination for ``(index, int32 CPU tensor)`` results.
		worker_id: Index of this worker; scales the start-up delay and back-off.
		num_patches: Passed to ``preprocess`` as the target image size.
		device: Device used for both the model and the image tensors.
		info_freq: Number of tasks between profiling reports.
		wait_thresh: Accepted for interface compatibility; not used.
	"""
	try:
		time.sleep(worker_id * 3.5) # delay so models are not loaded in parallel
		cfg_vqgan = load_config(f"../{PRETRAIN_DIR}/{MODEL_DIR}/configs/model.yaml", display=False)
		# Keep the model on the same device as the images; it used to be sent
		# to the global DEVICE regardless of the `device` argument.
		model_vqgan = load_vqgan(cfg_vqgan, ckpt_path=f"../{PRETRAIN_DIR}/{MODEL_DIR}/checkpoints/last.ckpt").to(device)
		model_vqgan.eval()
		waiting_time = 0
		process_time = 0
		profile_count = 0
		fallback = 1
		while True:
			wait_start = time.time()
			task = image_queue.get()
			wait_end = time.time()
			waiting_time += wait_end - wait_start
			if task is None:
				# Hand the sentinel on to the other encoder workers.
				image_queue.put(None)
				break
			process_start = time.time()
			idx = task[0]
			try:
				img = preprocess(task[1], num_patches).to(device)
				_, tokens = vqgan_encode(img, model_vqgan)
				tokens = tokens.flatten().type(torch.int32).to("cpu")
				token_queue.put((idx, tokens))
			except Exception:
				# Best-effort: images that cannot be decoded or encoded are skipped.
				pass
			process_end = time.time()
			process_time += process_end - process_start
			profile_count += 1
			if profile_count == info_freq:
				avg_wait = waiting_time / info_freq
				avg_proc = process_time / info_freq
				print(f"ENCODER-{worker_id}: AVG Wait - {avg_wait} AVG Process - {avg_proc}")
				waiting_time = 0
				process_time = 0
				profile_count = 0
				if worker_id > 0 and avg_wait > avg_proc: # Too many workers
					sleep_amt = worker_id * fallback * 15
					print(f"ENCODER-{worker_id}: AVG Wait > AVG Process. Sleeping for {sleep_amt} with fallback {fallback}.")
					time.sleep(sleep_amt)
					fallback += 1
				else:
					fallback = max(1, fallback - 1)
	finally:
		# Always tell the combiner this worker is gone, even after a crash
		# (e.g. a failed model load); otherwise it would wait forever.
		token_queue.put(None)

def token_combiner(out_path, token_queue:Queue, num_patches, num_preencoders, info_freq):
	"""Collect encoded tokens from ``token_queue`` and write them to ``out_path``.

	Output layout (native ``int`` via ``struct`` format ``"i"``): one header int
	holding ``num_patches``, then per record one index int followed by
	``num_patches`` token ints.

	The loop ends once ``num_preencoders`` ``None`` sentinels have been received
	(immediately if ``num_preencoders`` is 0). Invalid records are reported with
	``COMBINER: ...`` and skipped without writing anything for them.

	Args:
		out_path: Binary file to create (overwritten).
		token_queue: Source of ``(index, tokens)`` items and ``None`` sentinels.
		num_patches: Required number of tokens per record.
		num_preencoders: Number of sentinels to wait for.
		info_freq: Print throughput every ``info_freq`` written records.
	"""
	retired_preencoders = 0
	written_lines = 0
	# The context manager guarantees the file is flushed and closed even if an
	# unexpected error escapes the loop.
	with open(out_path, "wb") as file:
		file.write(struct.pack("i", num_patches))
		start = time.time()
		while retired_preencoders < num_preencoders:
			task = token_queue.get()
			if task is None:
				retired_preencoders += 1
				continue
			idx = task[0]
			tokens = task[1]
			try:
				if len(tokens) != num_patches:
					raise ValueError(f"expected {num_patches} tokens, got {len(tokens)}")
				# Tensors are converted to plain ints once; other sequences are used as-is.
				values = tokens.tolist() if hasattr(tokens, "tolist") else tokens
				# Pack the complete record before touching the file so a packing
				# error cannot leave an index without its tokens.
				record = struct.pack("i", idx) + struct.pack(f"{num_patches}i", *values)
				file.write(record)
				written_lines += 1
				if written_lines % info_freq == 0:
					elapsed = time.time() - start
					# Coarse clocks can report zero elapsed time.
					rate = info_freq / elapsed if elapsed > 0 else float("inf")
					print(f"COMBINER: Encoded {written_lines} images. EPS: {rate}")
					start = time.time()
			except Exception as e:
				print(f"COMBINER: {e}")
		

In [4]:
def encode_pipelined():
	"""Run the threaded download -> encode -> write pipeline over ``DATA_PATH``.

	Threads: one scheduler reading URLs, ``NUM_DOWNLOADERS`` download workers,
	``NUM_PREENCODERS`` encoder workers and one combiner writing ``OUT_PATH``.
	Shutdown is driven by ``None`` sentinels: the scheduler ends the URL queue,
	this function ends the image queue once every downloader has finished, and
	each encoder notifies the combiner when it exits.
	"""
	# Bounded queues apply back-pressure so neither the whole URL list nor an
	# ever-growing backlog of downloaded images has to sit in memory.
	url_queue = Queue(maxsize=NUM_DOWNLOADERS * 64)
	image_queue = Queue(maxsize=NUM_DOWNLOADERS * 8)
	token_queue = Queue()
	scheduler = Thread(target=download_scheduler, args=(DATA_PATH, url_queue))
	downloaders = [
		Thread(target=download_worker, args=(url_queue, image_queue, i))
		for i in range(NUM_DOWNLOADERS)
	]
	preencoders = [
		Thread(target=preencode_worker, args=(image_queue, token_queue, i, NUM_PATCHES, DEVICE, INFO_FREQ, 0.5))
		for i in range(NUM_PREENCODERS)
	]
	combiner = Thread(target=token_combiner, args=(OUT_PATH, token_queue, NUM_PATCHES, NUM_PREENCODERS, INFO_FREQ * NUM_PREENCODERS))
	scheduler.start()
	for thread in downloaders:
		thread.start()
	for thread in preencoders:
		thread.start()
	combiner.start()
	scheduler.join()
	print("SCHEDULER: Done")
	for i, thread in enumerate(downloaders):
		thread.join()
		print(f"DOWNLOADER-{i}: Done")
	# No download worker signals the end of the image stream, so do it here once
	# all of them have finished; without this the encoders (and therefore the
	# combiner) would block on the queue forever.
	image_queue.put(None)
	for i, thread in enumerate(preencoders):
		thread.join()
		print(f"ENCODER-{i}: Done")
	combiner.join()
	# There is a single combiner; the old message reused a stale loop index.
	print("COMBINER: Done")

In [5]:
def sanity_check():
	"""Encode a single image with ``preencode`` and print the record read back.

	NOTE: ``preencode`` opens ``OUT_PATH`` for writing, so running this replaces
	any existing encoded output at that path.
	"""
	preencode(DATA_PATH, OUT_PATH, NUM_PATCHES, DEVICE, max_writes=1, verbose=True)
	# Use a context manager so the handle is closed after the check.
	with open(OUT_PATH, "rb") as file:
		file.read(4)  # skip the header int holding the patch count
		print(load_tensor(file, NUM_PATCHES))

In [6]:
# Run the full threaded pipeline: reads DATA_PATH and writes OUT_PATH.
encode_pipelined()

Working with z of shape (1, 256, 16, 16) = 65536 dimensions.
SCHEDULER: Done
loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips\vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.




ENCODER-0: AVG Wait - 0.00045638084411621095 AVG Process - 0.1200715446472168
COMBINER: Encoded 50 images. EPS: 4.625401095437032
ENCODER-0: AVG Wait - 0.0003170013427734375 AVG Process - 0.08993515968322754
COMBINER: Encoded 100 images. EPS: 11.082502910465468
ENCODER-0: AVG Wait - 0.00042668342590332034 AVG Process - 0.08934927463531495
COMBINER: Encoded 150 images. EPS: 11.137612730945985
ENCODER-0: AVG Wait - 0.00026773929595947266 AVG Process - 0.09678948402404786
COMBINER: Encoded 200 images. EPS: 9.933631383831711
ENCODER-0: AVG Wait - 0.00037690162658691404 AVG Process - 0.10600512504577636
COMBINER: Encoded 250 images. EPS: 9.37821863787871
ENCODER-0: AVG Wait - 0.0005341625213623047 AVG Process - 0.11143348217010499
COMBINER: Encoded 300 images. EPS: 8.946264829748111
ENCODER-0: AVG Wait - 0.00039630889892578125 AVG Process - 0.09597640037536621
COMBINER: Encoded 350 images. EPS: 10.145582180443492
ENCODER-0: AVG Wait - 0.0003555917739868164 AVG Process - 0.1031099796295166
C