In [1]:
import cv2
import time
import torch
import ntpath
import numpy as np
import matplotlib.pyplot as plt
from options.test_options import TestOptions
import os
from data import create_dataset
from models import create_model
from torchvision import transforms as tf
from util import html,util

In [2]:
# Experiment name: generator weights are loaded from ./checkpoints/<NAME>.
NAME = 'gen_new4'
# Sub-folder of ../../generator_data to test on ('' uses the folder itself).
DS_NAME = ''

# Option overrides handed to TestOptions(defaults=...).
defaults = dict(
    dataroot=f'../../generator_data/{DS_NAME}',
    model='my',
    dataset_mode='my',
    dataset_root=f'../../generator_data/{DS_NAME}',
    embedding_save_dir='./checkpoints/emb_vidit4',
    name=NAME,
)

# Alternative configuration for the embedding ('emb') model on the VIDIT data:
# defaults = dict(
#     dataroot='../../embedding_data/vidit/',
#     model='emb',
#     dataset_mode='emb',
#     dataset_root='../../embedding_data/vidit/',
#     # embedding_save_dir='./checkpoints/hdr1/',
#     name=NAME,
# )

opt = TestOptions(defaults=defaults).parse()

----------------- Options ---------------
             aspect_ratio: 1.0                           
               batch_size: 1                             
          checkpoints_dir: ./checkpoints                 
                crop_size: 256                           
                 dataroot: ../../generator_data/         
             dataset_mode: my                            
             dataset_root: ../../generator_data/         
                direction: AtoB                          
             display_port: 9333                          
          display_winsize: 256                           
       embedding_save_dir: ./checkpoints/emb_vidit4      
                    epoch: latest                        
                     eval: False                         
                  gpu_ids: 0                             
                init_gain: 0.02                          
                init_type: normal                        
                 input_nc: 5  

In [3]:
# Hard-code the options the test loop below relies on.
opt.num_threads = 0        # test code only supports num_threads = 0
opt.batch_size = 1         # test code only supports batch_size = 1
opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
opt.no_flip = True         # no flip; comment this line if results on flipped images are needed.
opt.display_id = -1        # no visdom display; the test code saves the results to a HTML file.

# opt.gpu_ids = []         # uncomment to presumably force CPU execution (verify against the options parser)

In [4]:
# TestOptions must configure inference mode; fail fast if training options slipped in.
assert not opt.isTrain, 'expected test options (opt.isTrain should be False)'

In [5]:
dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options

dataset_size = len(dataset)    # get the number of images in the dataset.
print('The number of test images = %d' % dataset_size)

model = create_model(opt)      # create a model given opt.model and other options
# model.setup(opt) is not called here; the generator weights are loaded explicitly instead.
# Load the epoch named in the options so the weights match the "Epoch = ..." on the results page.
model.load_networks(opt.epoch)
model.print_networks(verbose=False)

loading test file
dataset [MyDataset] was created
The number of training images = 7393
initialize network with normal
initialize network with normal
loading the model from ./checkpoints/emb_vidit4/150_net_G.pth
---------- Networks initialized -------------
[Network G] Total number of parameters : 54.415 M
-----------------------------------------------
model [MyModel] was created
loading the model from ./checkpoints/gen_new4/latest_net_G.pth
---------- Networks initialized -------------
[Network G] Total number of parameters : 93.475 M
-----------------------------------------------


In [6]:
class NormalizeInverse(tf.Normalize):
    """Invert a ``Normalize(mean, std)`` transform.

    The parent transform computes ``(x - mean') / std'``. Choosing
    ``mean' = -mean / std`` and ``std' = 1 / std`` turns that into
    ``x * std + mean``, i.e. the original input. A small epsilon keeps the
    reciprocal finite if a std entry is zero.
    """

    def __init__(self, mean, std):
        mean_t, std_t = torch.as_tensor(mean), torch.as_tensor(std)
        inv_std = 1 / (std_t + 1e-7)
        super().__init__(mean=-mean_t * inv_std, std=inv_std)

    def __call__(self, tensor):
        # Work on a copy so the caller's tensor is never modified.
        return super().__call__(tensor.clone())

# Both map values normalized with mean 0.5 / std 0.5 (range [-1, 1]) back to [0, 1].
unnorm = NormalizeInverse((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 3-channel images
unnorm1 = NormalizeInverse((0.5,), (0.5,))                     # 1-channel images (masks)

In [7]:
# HTML results page, stored under <results_dir>/<name>/<phase>_<epoch>.
web_dir = os.path.join(opt.results_dir, opt.name, f'{opt.phase}_{opt.epoch}')
webpage = html.HTML(web_dir, f'Experiment = {opt.name}, Phase = {opt.phase}, Epoch = {opt.epoch}')

In [8]:
def read_rgb(img_path):
    """Read an image file and return it as an RGB uint8 array.

    NOTE: the file is read with IMREAD_UNCHANGED, so it must decode to 3 channels
    for the BGR->RGB conversion to succeed.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'could not read image: {img_path}')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img.astype(np.uint8)

def read_rgba(img_path):
    """Read an image file and return it as an RGBA uint8 array.

    NOTE: the file must decode to 4 channels for the BGRA->RGBA conversion to succeed.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'could not read image: {img_path}')
    img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
    return img.astype(np.uint8)

def read_bytes(img_path):
    """Return the raw, undecoded contents of the file at ``img_path``."""
    with open(img_path, mode='rb') as handle:
        payload = handle.read()
    return payload

def fread_rgb(img_path):
    """Read an image as an RGB float32 array scaled from [0, 255] to [0, 1]."""
    rgb = read_rgb(img_path).astype(np.float32)
    return rgb / 255

def read_mask(path):
    """Read an image file as a single-channel (grayscale) uint8 mask.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'could not read mask: {path}')
    return img.astype(np.uint8)

def write_rgb(img, path):
    """Write an RGB image to ``path``; channels are swapped to OpenCV's BGR order first."""
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(path, bgr)

def write_rgba(img, path):
    """Write an RGBA image to ``path``; channels are swapped to OpenCV's BGRA order first."""
    bgra = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
    cv2.imwrite(path, bgra)

def write_mask(img, path):
    """Write a mask image to ``path`` as-is (no channel reordering)."""
    cv2.imwrite(path, img)
    

def save(img, path, aspect_ratio=1.0):
    """Write an image array to ``path``.

    Parameters:
        img (np.ndarray)     -- (H, W), (H, W, 1), (H, W, 3) RGB or (H, W, 4) RGBA image
        path (str)           -- destination file
        aspect_ratio (float) -- stretch applied before saving: > 1 widens the image by that
                                factor, < 1 makes it taller by 1 / aspect_ratio, and 1.0
                                (the default) keeps the original size

    Raises:
        ValueError: if ``aspect_ratio`` is not positive.
    """
    if aspect_ratio <= 0:
        raise ValueError(f'aspect_ratio must be positive, got {aspect_ratio}')

    h, w = img.shape[:2]
    # cv2.resize takes the target size as (width, height).
    if aspect_ratio > 1.0:
        img = cv2.resize(img, (int(w * aspect_ratio), h), interpolation=cv2.INTER_CUBIC)
    elif aspect_ratio < 1.0:
        img = cv2.resize(img, (w, int(h / aspect_ratio)), interpolation=cv2.INTER_CUBIC)

    # Check ndim after resizing: cv2.resize drops a trailing singleton channel axis.
    if img.ndim == 2 or img.shape[2] == 1:
        cv2.imwrite(path, img)
    elif img.shape[2] == 4:
        write_rgba(img, path)
    else:
        write_rgb(img, path)


def show(img):
    """Display ``img`` in an 8x8-inch matplotlib figure."""
    _, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(img)
    plt.show()

In [9]:
def to_img(tensor):
    """Convert a (1, C, H, W) tensor with values in [-1, 1] to an (H, W, C) uint8 image.

    All singleton dimensions are squeezed away, so the squeezed tensor must be
    (C, H, W). Values are mapped to [0, 1], clipped and scaled to [0, 255].
    """
    chw = tensor.detach().cpu().squeeze().numpy()
    hwc = chw.transpose(1, 2, 0)
    unit = np.clip((hwc + 1) / 2.0, 0, 1)
    return (unit * 255).astype(np.uint8)
    
def to_img1(tensor, inverse=None):
    """Convert the first item of a normalized (B, C, H, W) batch to an (H, W, C) uint8 image.

    Parameters:
        tensor (torch.Tensor) -- batch whose first item is converted (used here for the mask)
        inverse (callable)    -- un-normalization applied to the (C, H, W) item; defaults to
                                 the module-level ``unnorm1``. It is skipped if that is None.

    The un-normalized values are clipped to [0, 1] and scaled to [0, 255].
    """
    if inverse is None:
        inverse = unnorm1
    image = tensor.detach().to('cpu')[0]
    # Guard on the transform that is actually applied (previously a different global was tested).
    if inverse is not None:
        image = inverse(image)
    np_img = image.numpy().transpose((1, 2, 0))
    return (np_img.clip(0, 1) * 255).astype(np.uint8)

def tensor2im(input_image, imtype=np.uint8):
    """Convert the first image of a [-1, 1] tensor batch to an (H, W, 3) numpy image.

    Parameters:
        input_image -- a (B, C, H, W) torch.Tensor, a numpy array, or anything else
        imtype      -- dtype of the returned array (default uint8)

    Returns:
        Tensors: the first item is mapped from [-1, 1] to [0, 255], clipped to that range
        (so out-of-range values saturate instead of wrapping on the integer cast), and
        single-channel images are tiled to 3 channels.
        Numpy arrays: returned cast to ``imtype`` without rescaling.
        Anything else: returned unchanged.
    """
    if isinstance(input_image, np.ndarray):
        image_numpy = input_image
    elif isinstance(input_image, torch.Tensor):
        image_numpy = input_image.detach()[0].cpu().float().numpy()
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
        image_numpy = np.clip(image_numpy, 0.0, 255.0)
    else:
        return input_image
    return image_numpy.astype(imtype)


def display(visuals):
    """Show the composite, real and harmonized images side by side (in that order).

    NOTE(review): this name shadows IPython's built-in ``display`` in the notebook namespace.
    """
    panels = {key: to_img(visuals[key]) for key in ('real', 'comp', 'harmonized')}

    print('composition/real/harmonized')
    strip = np.concatenate([panels['comp'], panels['real'], panels['harmonized']], axis=1)
    plt.figure(figsize=(15, 15))
    plt.imshow(strip)
    plt.show()

In [10]:
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
    """Save images to the disk and add them as one row of the HTML page.

    Parameters:
        webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
        image_path (str | list)  -- path of the source image; a list/tuple (as collated by a
                                    DataLoader with batch_size 1) is accepted and its first entry used
        aspect_ratio (float)     -- forwarded to save()
        width (int)              -- forwarded to webpage.add_images

    Each image is written as '<source file name>_<label>.png' in the webpage's image directory.
    """
    if isinstance(image_path, (list, tuple)):
        image_path = image_path[0]
    # ntpath.basename splits on both '/' and '\\', so POSIX and Windows paths both work.
    name = ntpath.basename(image_path)

    image_dir = webpage.get_image_dir()
    webpage.add_header(name)
    ims, txts, links = [], [], []

    for label, im_data in visuals.items():
        # Tensors are converted to HWC uint8 arrays; numpy arrays are saved as they are.
        im = tensor2im(im_data) if isinstance(im_data, torch.Tensor) else im_data
        image_name = '%s_%s.png' % (name, label)
        save_path = os.path.join(image_dir, image_name)
        save(im, save_path, aspect_ratio=aspect_ratio)
        ims.append(image_name)
        txts.append(label)
        links.append(image_name)
    webpage.add_images(ims, txts, links, width=width)

In [11]:
# Inference loop: run the model on every test sample, then write the composite,
# real and harmonized images of that sample to the HTML results page.
PASTE_ONTO_REAL = False  # True: keep only the masked foreground of the prediction (for real images)

if opt.eval:
    model.eval()

for i, data in enumerate(dataset):
    model.set_input(data)  # unpack data from data loader
    model.test()           # run inference
    model_visuals = model.get_current_visuals()  # get image results

    # The loader collates each sample's path into a list (str() of it prints as "['...']");
    # with batch_size 1 that list holds exactly one entry.
    img_paths = data['img_path']
    img_path = img_paths[0] if isinstance(img_paths, (list, tuple)) else img_paths
    print('processing (%04d)-th image... %s' % (i, img_path))

    harmonized = tensor2im(model_visuals['harmonized'])
    real = to_img(data['real'])
    comp = to_img(data['comp'])

    if PASTE_ONTO_REAL:
        # Paste the harmonized foreground onto the real background.
        # NOTE(review): assumes a single-channel 0/255 mask — confirm for the dataset in use.
        mask = to_img1(data['mask'])
        mask3 = np.dstack([mask // 255] * 3)
        harmonized = mask3 * harmonized + (1 - mask3) * real

    # Copy so the dict returned by the model is not mutated, then swap in the numpy images.
    visuals = dict(model_visuals)
    visuals.update(real=real, comp=comp, harmonized=harmonized)

    print('saving (%04d)-th image... %s' % (i, img_path))
    save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)

webpage.save()  # save the HTML
print('Done')

processing (0000)-th image... ['../../generator_data/HAdobe5k/composite_images/a3630_1_5.jpg']
saving (0000)-th image... ['../../generator_data/HAdobe5k/composite_images/a3630_1_5.jpg']
Done
