In [1]:
import os, sys
import torch
import argparse
import yaml
import importlib
import numpy as np
from tqdm import tqdm
import pytorch_lightning as pl
from PIL import Image
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as T

from utils_func import *

In [2]:
# Notebook configuration. parse_args(args=[]) parses an empty list instead of the
# kernel's own sys.argv, so the defaults below are what actually take effect;
# edit them here to point at a different run.
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda:0')

parser.add_argument('--root_dir', type=str, default='C:/MyFiles/CondTran/frameworks/bert_baseline')
# Relative path: it is resolved against root_dir because the working directory
# is switched to root_dir below.
parser.add_argument('--log_dir', type=str, default='logs/vqgan_imagenet_full')
# Training step whose checkpoint is loaded (matched against '<step>.ckpt').
parser.add_argument('--step', type=str, default='279999')
args = parser.parse_args(args=[])
os.chdir(args.root_dir)
# Guard the append so re-running this cell does not keep growing sys.path.
if args.root_dir not in sys.path:
    sys.path.append(args.root_dir)

def find(path, name):
    """Return the first file under ``path`` whose filename contains ``name``.

    The tree is walked top-down with ``os.walk`` and files are tested in the
    order it yields them; the match is a plain substring test on the filename.

    Returns:
        The joined path of the first matching file, or ``None`` if nothing matches.
    """
    candidates = (
        os.path.join(dirpath, fname)
        for dirpath, _subdirs, fnames in os.walk(path)
        for fname in fnames
        if name in fname
    )
    # The generator is lazy, so the walk stops at the first hit.
    return next(candidates, None)
# Resolve the checkpoint for the requested step. `find` does a substring match,
# so this picks the first file under log_dir whose name contains '<step>.ckpt'.
args.checkpoint = find(args.log_dir, args.step + '.ckpt')
if args.checkpoint is None:
    # Fail early with a clear message instead of handing None to Lightning.
    raise FileNotFoundError(
        f"No checkpoint matching '{args.step}.ckpt' found under '{args.log_dir}'"
    )

module = importlib.import_module('vqgan.models')
vq_model = getattr(importlib.import_module('vqgan.models.hybrid_vqgan'), 'VQModel')
vq_loss = getattr(importlib.import_module('vqgan.models.vqperceptual'), 'VQLPIPSWithDiscriminator')

# Training config saved alongside the run; supplies the model/loss constructor args.
config_path = os.path.join(args.log_dir, 'config.yaml')
with open(config_path, 'rb') as fin:
    config = yaml.safe_load(fin)

# Load the pretrained model frozen for inference. Tensors are mapped to CPU first
# so a GPU-saved checkpoint also loads on CPU-only machines; .to() then places it.
vqgan_model = vq_model.load_from_checkpoint(
    args.checkpoint,
    map_location='cpu',
    ddconfig=config['model']['ddconfig'],
    loss=vq_loss(**config['loss']),
    n_embed=config['model']['n_embed'],
    embed_dim=config['model']['embed_dim'],
    learning_rate=0.0,
).to(args.device).eval().requires_grad_(False)


  from collections import namedtuple, Mapping
  from collections import Mapping, MutableMapping


loaded pretrained LPIPS loss from C:\MyFiles\CondTran\frameworks\bert_baseline\vqgan\models\..\data\lpips\vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.
Working with z of shape (1, 1024, 16, 16) = 262144 dimensions.


In [9]:
# Encode/decode sanity check: crop a small window from each selected image,
# quantize its colour features, rebuild them from the codebook indices, decode
# together with the grayscale features, and show input vs. reconstruction.
img_dir = 'C:/MyFiles/ColorizationTran/data/raw'
img_size = [256, 256]
# Only files whose name contains this substring are processed.
name_filter = 'in3'
# 16x16 pixel window (rows, cols) cut from the preprocessed image. Per the model
# log (z of 16x16 for a 256x256 input) this presumably maps to a single code.
crop_rows = slice(32, 48)
crop_cols = slice(48, 64)

# Sorted so the processing order does not depend on the filesystem.
pbar = tqdm(sorted(os.listdir(img_dir)))
for filename in pbar:
    if name_filter not in filename:
        continue

    # Context manager closes the file handle; convert() returns an independent,
    # fully loaded copy, so it stays usable after the file is closed.
    with Image.open(os.path.join(img_dir, filename)) as img:
        I_color = img.convert('RGB')
    I_gray = I_color.convert('L')

    # NOTE(review): assumes preprocess returns a (B, C, H, W) tensor — confirm in utils_func.
    x_color = preprocess(I_color, img_size).to(args.device)
    x_gray = preprocess(I_gray, img_size).to(args.device)

    x_color = x_color[:, :, crop_rows, crop_cols]
    x_gray = x_gray[:, :, crop_rows, crop_cols]

    with torch.no_grad():
        # Encoding
        f_gray = vqgan_model.gray_encoder(x_gray)
        h = vqgan_model.encoder(x_color)
        h = vqgan_model.quant_conv(h)
        quant, _, info = vqgan_model.quantize(h)
        color_idx = info[2].view(quant.shape[0], -1)
        # Decoding: rebuild the quantized features from the indices alone to
        # exercise the codebook lookup path.
        q_shape = [f_gray.shape[0], f_gray.shape[2], f_gray.shape[3]]
        quant = vqgan_model.quantize.get_codebook_entry(color_idx.view(-1), q_shape)
        feat = torch.cat([quant, f_gray], dim=1)
        feat = vqgan_model.post_quant_conv(feat)
        rec = vqgan_model.decoder(feat)

    output_to_pil(x_color[0]).show()
    output_to_pil(rec[0]).show()

    print(color_idx)

20it [00:03,  6.60it/s]

tensor([[1491]], device='cuda:0')



