In [None]:
import sys
import os
import cv2
import importlib
import torch
import numpy as np
import argparse
import yaml
from tqdm import tqdm
from PIL import Image, ImageDraw, ImageOps
import json
import nltk
import torchvision.transforms as T
import matplotlib.pyplot as plt
import torch.nn.functional as F
from scipy import signal
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
import clip
import kornia
import torchvision
import skimage.color
import random

from utils_func import *
from html_images import *
from sample_func import *
from ImageMatch.warp import ImageWarper
from colorizer import *

# Create colorizer

In [None]:
# Defaults point at the author's machine; override them through the environment
# (COLORIZER_CKPT / COLORIZER_DEVICE) instead of editing the notebook.
ckpt_file = os.environ.get(
    'COLORIZER_CKPT',
    'C:/MyFiles/CondTran/finals/bert_final/logs/bert/epoch=14-step=142124.ckpt',
)
device = os.environ.get('COLORIZER_DEVICE', 'cuda:0')

# Fail early with a readable message rather than somewhere inside checkpoint loading.
if not os.path.isfile(ckpt_file):
    raise FileNotFoundError(f'Colorizer checkpoint not found: {ckpt_file}')

# NOTE(review): third argument presumably is the model's working resolution — confirm against Colorizer.
colorizer = Colorizer(ckpt_file, device, [256, 256])

# Sample dataset (raw)

Colorize ImageNet validation images without any color hints (an empty stroke list).

In [None]:
# Sampling arguments
img_dir = 'C:\\MyFiles\\Dataset\\imagenet\\val5000\\val'
save_dir = 'C:\\Users\\lucky\\Desktop\\raw_diverse'
topk = 100
num_samples = 5                          # hint-free colorizations drawn per image
img_size = [256, 256]
sample_size = [img_size[0]//16, img_size[1]//16]
image_exts = ('.jpg', '.jpeg', '.png')   # compared against the lower-cased filename
seed = 0

html = HTML(save_dir, 'Sample')

# Sort first so the seeded shuffle yields the same order on every machine.
files = sorted(os.listdir(img_dir))
np.random.seed(seed)
np.random.shuffle(files)
# Presumably the sampler draws from torch's RNG; seed it too so reruns are repeatable.
torch.manual_seed(seed)

for filename in tqdm(files):
    if not filename.lower().endswith(image_exts):
        continue
    # splitext strips only the extension; split('.')[0] truncated names with extra dots.
    fname = os.path.splitext(filename)[0]
    I_color = Image.open(os.path.join(img_dir, filename)).convert('RGB')
    I_gray = I_color.convert('L')

    # Row: ground truth, grayscale input, then num_samples independent samples
    # (num_samples used to be set but only a single sample was drawn).
    gen_imgs = [I_color, I_gray]
    for _ in range(num_samples):
        gen_imgs.append(colorizer.sample(I_gray, strokes=[], topk=topk))

    save_result(html, index=fname, images=gen_imgs, texts=[fname])
    # Saved every iteration so an interrupted run still leaves a usable page.
    html.save()

# Sample strokes

Colorize COCO val2017 images from a random subset (2–16) of each image's stored strokes.

In [None]:
# Sampling arguments: colorize each COCO image from a random subset of its
# stored strokes, drawing `num_samples` samples per image.
topk = 100
num_samples = 1
img_size = [256, 256]
sample_size = [img_size[0]//16, img_size[1]//16]
stroke_path = 'C:\\MyFiles\\CondTran\\data\\coco_strokes.json'
dataset_path = 'C:\\MyFiles\\Dataset\\coco\\val2017'
save_dir = 'C:\\Users\\lucky\\Desktop\\final_stroke_coco'
num_strokes = [2, 16]   # inclusive [min, max] strokes kept per image

html = HTML(save_dir, 'Sample')

# Load the stroke records (each has an 'image' filename and a 'strokes' list)
with open(stroke_path, 'r') as fh:
    all_strokes = json.load(fh)
    print(len(all_strokes))

lo_strokes, hi_strokes = num_strokes
random.seed(100)  # fixes which strokes get picked for every image
for idx, record in tqdm(enumerate(all_strokes)):
    # Keep between lo and hi strokes, never more than the record provides.
    wanted = random.randint(lo_strokes, hi_strokes)
    strokes = random.sample(record['strokes'], k=min(wanted, len(record['strokes'])))

    I_color = Image.open(os.path.join(dataset_path, record['image'])).convert('RGB')
    I_gray = I_color.convert('L')

    # Render the hints on a grayscale canvas at model resolution, then scale it
    # back to the source size so the row lines up.
    canvas = draw_strokes(I_gray.copy().resize(img_size).convert('RGB'), img_size, strokes)

    gen_imgs = [I_color, canvas.resize(I_color.size)]
    gen_imgs.extend(colorizer.sample(I_gray, strokes, topk) for _ in range(num_samples))

    save_result(html, index=idx, images=gen_imgs, texts=['original', 'strokes', 'colorized'])
    html.save()

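# Sample text

Colorize COCO images from the strokes stored in `all_text_strokes.json`; each row shows the drawn hints next to the image caption.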

In [None]:
# Sampling arguments
topk = 100
num_samples = 1
img_size = [256, 256]
patch = 16   # pixel size of one hint cell; stroke indices in the JSON are in cell units
sample_size = [img_size[0]//patch, img_size[1]//patch]
stroke_path = 'C:/MyFiles/CondTran/data/all_text_strokes.json'
dataset_path = 'C:/MyFiles/Dataset/coco/val2017'
save_dir = 'C:/MyFiles/CondTran/sample_result/slic_text'

html = HTML(save_dir, 'Sample')
# The script-era header used an undefined `args` (root_dir/log_dir/step);
# label the page with the checkpoint loaded in the "Create colorizer" cell instead.
html.add_header(ckpt_file)

# Load strokes
with open(stroke_path, 'r') as fh:
    all_strokes = json.load(fh)

np.random.seed(10)
np.random.shuffle(all_strokes)
subset = all_strokes[:200]

for i, entry in tqdm(enumerate(subset), total=len(subset)):
    filename = entry['image']
    # Convert cell indices to pixel offsets without mutating the loaded records,
    # so iterating all_strokes again never scales an index twice.
    strokes = [
        {**stk, 'index': [stk['index'][0] * patch, stk['index'][1] * patch]}
        for stk in entry['strokes']
    ]

    I_color = Image.open(os.path.join(dataset_path, filename)).convert('RGB')
    I_gray = I_color.convert('L')

    gen_imgs = [I_color]

    # Draw each hint as a (patch-6)x(patch-6) color swatch with a 3 px white border.
    draw_img = I_gray.copy().resize(img_size).convert('RGB')
    for stk in strokes:
        ind = stk['index']
        color = np.expand_dims(np.array(stk['color']), axis=(0, 1))
        color = cv2.resize(color, (patch - 6, patch - 6), interpolation=cv2.INTER_NEAREST)
        color = cv2.copyMakeBorder(color, 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=(255, 255, 255))
        draw_img = draw_full_color(draw_img, color, [ind[0], ind[0] + patch, ind[1], ind[1] + patch])
    gen_imgs.append(draw_img.resize(I_color.size))

    # `filltran` and `args.device` are not defined in this notebook; sample through the
    # shared colorizer the same way the stroke cell does.
    # NOTE(review): assumes colorizer.sample accepts {'index': [pixel, pixel], 'color': [...]}
    # strokes like the old filltran.sample did — confirm.
    for _ in range(num_samples):
        gen_imgs.append(colorizer.sample(I_gray, strokes, topk))

    save_result(html, index=i, images=gen_imgs, texts=[' ', entry['caption']])
    html.save()



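# Sample exemplar

Warp a reference (exemplar) image onto each grayscale target, turn the well-matched cells of the warped image into color hints, and colorize from those hints.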

In [None]:
# Sampling arguments
topk = 100
num_samples = 1
img_size = [256, 256]
patch = 16                # pixel size of one hint cell at img_size
sample_size = [img_size[0]//patch, img_size[1]//patch]
min_similarity = 0.23     # cells whose warp similarity reaches this become hints...
min_hints = 10            # ...relaxed when needed so at least this many cells qualify
pair_path = 'C:/MyFiles/CondTran/data/all_exemplars.json'
dataset_path = 'C:/MyFiles/Dataset/imagenet/val5000/val'
ref_path = 'C:/MyFiles/Dataset/imagenet/full/train'
save_dir = 'C:/MyFiles/CondTran/sample_result/bert_final_exp'

html = HTML(save_dir, 'Sample')
# `args` does not exist in this notebook; label the page with the checkpoint path instead.
html.add_header(ckpt_file)

# Load (target image, exemplar) pairs
with open(pair_path, 'r') as fh:
    all_pairs = json.load(fh)
np.random.seed(10)  # reproducible pair order
np.random.shuffle(all_pairs)

# Load image warper
warper = ImageWarper('cuda')

for i, pair in enumerate(tqdm(all_pairs)):
    I_color = Image.open(os.path.join(dataset_path, pair['image'])).convert('RGB')
    I_gray = I_color.convert('L')
    I_ref = Image.open(os.path.join(ref_path, pair['exemplar'])).convert('RGB')

    gen_imgs = [I_color, I_ref]

    # Warp the exemplar onto the grayscale target.
    warped_img, similarity_map = warper.warp_image(I_gray.convert('RGB'), I_ref)
    gen_imgs.append(warped_img.resize(I_color.size))
    warped = np.array(warped_img.resize(img_size))

    # One similarity score per hint cell. cv2 dsize is (width, height), matching
    # sample_size (derived from img_size, which PIL's resize treats as (width, height));
    # the flattened map is row-major, so cell -> (row, col) = divmod(cell, grid_w).
    grid_w, grid_h = sample_size
    scores = cv2.resize(similarity_map, (grid_w, grid_h)).reshape(-1)

    # Bug fix: the relaxed threshold used to be computed and then ignored in favour of
    # a fixed 0.23, so poorly matched images could end up with few or no hints.
    k = min(min_hints, scores.size)
    threshold = min(min_similarity, np.sort(scores)[-k])
    cells = np.flatnonzero(scores >= threshold)

    # Each selected cell becomes a hint colored with the warped exemplar's mean color there.
    strokes = []
    for cell in cells:
        row, col = divmod(int(cell), grid_w)
        top, left = row * patch, col * patch
        color = warped[top:top + patch, left:left + patch, :].mean(axis=(0, 1))
        strokes.append({'index': [top, left], 'color': color.tolist()})

    # Draw each hint as a (patch-6)x(patch-6) color swatch with a 3 px white border.
    draw_img = I_gray.copy().resize(img_size).convert('RGB')
    for stk in strokes:
        ind = stk['index']
        color = np.expand_dims(np.array(stk['color']), axis=(0, 1))
        color = cv2.resize(color, (patch - 6, patch - 6), interpolation=cv2.INTER_NEAREST)
        color = cv2.copyMakeBorder(color, 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=(255, 255, 255))
        draw_img = draw_full_color(draw_img, color, [ind[0], ind[0] + patch, ind[1], ind[1] + patch])
    gen_imgs.append(draw_img.resize(I_color.size))

    # `filltran` and `args.device` are undefined here; sample through the shared colorizer
    # the same way the stroke cell does.
    # NOTE(review): assumes colorizer.sample accepts these {'index', 'color'} strokes — confirm.
    for _ in range(num_samples):
        gen_imgs.append(colorizer.sample(I_gray, strokes, topk))

    save_result(html, index=i, images=gen_imgs)
    html.save()
    

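# Upsample

Compare `color_resize` with `colorizer.upsample` for lifting a 256×256 color image back to the original resolution.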

In [None]:
# Sampling arguments
img_dir = 'C:/MyFiles/Dataset/imagenet/val5000/val'
save_dir = 'C:/MyFiles/CondTran/sample_result/bert_final_upsample'
topk = 1
img_size = [256, 256]
num_images = 10                          # how many images to compare
image_exts = ('.jpg', '.jpeg', '.png')   # compared against the lower-cased filename

html = HTML(save_dir, 'Sample')

# Filter before slicing so non-image entries cannot shrink the set below num_images;
# sorting makes the selection reproducible.
image_files = [
    f for f in sorted(os.listdir(img_dir)) if f.lower().endswith(image_exts)
][:num_images]

for i, filename in enumerate(tqdm(image_files)):
    I_color = Image.open(os.path.join(img_dir, filename)).convert('RGB')
    I_gray = I_color.convert('L')
    # Low-resolution color image that both methods lift back to I_gray's resolution.
    I_color_small = I_color.resize(img_size)

    I_bilinear = color_resize(I_gray, I_color_small)        # non-learned baseline
    I_upsample = colorizer.upsample(I_gray, I_color_small)  # learned upsampler

    # Row: full-resolution ground truth, baseline, learned result.
    save_result(html, index=i, images=[I_color, I_bilinear, I_upsample])
    html.save()