In [None]:
# import packages
from core.utils import get_config
from core.trainer import HiSD_Trainer
import argparse
import torchvision.utils as vutils
import sys
import torch
import os
from torchvision import transforms
from PIL import Image
import numpy as np
import time
import matplotlib.pyplot as plt

In [None]:
# Run on the CPU by default; set device = 'cuda:0' to run on a GPU instead.
device = 'cpu'

# Load the model configuration and the pretrained generator weights.
config = get_config('configs/celeba-hq_256.yaml')
noise_dim = config['noise_dim']
image_size = config['new_size']
checkpoint = 'checkpoint_256_celeba-hq.pt'
trainer = HiSD_Trainer(config)
# map_location lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine; without it torch.load tries to restore them onto CUDA.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
state_dict = torch.load(checkpoint, map_location=device)
trainer.models.gen.load_state_dict(state_dict['gen_test'])
trainer.models.gen.to(device)
# Inference only: switch any dropout / normalization layers to eval mode.
trainer.models.gen.eval()

# Short aliases for the generator's sub-networks, as used by translate():
E = trainer.models.gen.encode     # E(image) -> content code
T = trainer.models.gen.translate  # T(code, style, tag) -> edited code
G = trainer.models.gen.decode     # G(code) -> image
M = trainer.models.gen.map        # M(noise, tag, attribute) -> style
F = trainer.models.gen.extract    # F(reference_image, tag) -> style

# Preprocessing: resize, convert to a [0, 1] tensor, then normalize each
# channel to [-1, 1]. Resize with an int scales the shorter side only, which
# is why custom inputs must be cropped to a square first.
transform = transforms.Compose([transforms.Resize(image_size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

In [None]:
"""
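Translate an input image by applying a sequence of attribute edits, and return the result as an array ready for plotting.
Latent-guided steps with seed=None draw fresh random noise, so every run gives a different style; reference-guided steps copy the style of the given reference image.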
"""
def _load_image(path):
    """Load an image file as a normalized 1 x 3 x H x W tensor on `device`."""
    # The context manager releases the file handle; convert() reads the pixel
    # data into a new image before the file is closed.
    with Image.open(path) as img:
        rgb = img.convert('RGB')
    return transform(rgb).unsqueeze(0).to(device)


def translate(input, steps):
    """Apply ``steps`` to the image at path ``input`` and return the result.

    Args:
        input: path of the (pre-cropped) input image.
        steps: list of dicts, applied in order. Every step has ``'type'`` and
            ``'tag'``; the remaining keys depend on the type:

            * ``'latent-guided'``: ``'attribute'`` plus an optional ``'seed'``
              (``None`` or missing -> unseeded random noise).
            * ``'reference-guided'``: ``'reference'``, the path of an image
              whose style for ``tag`` is copied.

    Returns:
        numpy array of shape (H, W, C) with values in [0, 1], suitable for
        ``plt.imshow``.

    Raises:
        ValueError: if a step has an unknown ``'type'``.
    """
    # Pure inference: no_grad skips building an autograd graph and guarantees
    # the output does not require grad (otherwise .numpy() would raise).
    with torch.no_grad():
        c_trg = E(_load_image(input))
        for step in steps:
            tag = step['tag']
            if step['type'] == 'latent-guided':
                seed = step.get('seed')
                if seed is not None:
                    # torch.manual_seed seeds the RNGs of all devices, CUDA included.
                    torch.manual_seed(seed)
                # Sample with the default generator, then move, so a given seed
                # yields the same noise whatever `device` is.
                z = torch.randn(1, noise_dim).to(device)
                s_trg = M(z, tag, step['attribute'])
            elif step['type'] == 'reference-guided':
                s_trg = F(_load_image(step['reference']), tag)
            else:
                raise ValueError('unknown step type: %r' % (step['type'],))
            # Each step edits the result of the previous one.
            c_trg = T(c_trg, s_trg, tag)
        x_trg = G(c_trg)
    # Undo the [-1, 1] normalization and move channels last (C, H, W -> H, W, C).
    return x_trg.squeeze(0).cpu().permute(1, 2, 0).add(1).mul(0.5).clamp(0, 1).numpy()

In [None]:
# You need to crop the image if you use your own input.
# (Named input_path so the built-in input() is not shadowed.)
input_path = 'examples/input_0.jpg'

# e.g. 1: set tag 'Bangs' to attribute 'with' using three latent-guided styles (generated from random noise).
# seed=None draws fresh noise on every call, so each panel shows different bangs.
steps = [
    {'type': 'latent-guided', 'tag': 0, 'attribute': 0, 'seed': None}
]

n_samples = 3
# squeeze=False keeps `axes` 2-D even when n_samples == 1.
fig, axes = plt.subplots(1, n_samples, figsize=(4 * n_samples, 4), squeeze=False)
for i, ax in enumerate(axes[0]):
    # Default (equal) aspect keeps the square faces undistorted.
    ax.imshow(translate(input_path, steps))
    ax.set_title('latent sample %d' % (i + 1))
    ax.axis('off')
plt.show()

In [None]:
# You need to crop the image if you use your own input.
# (Named input_path so the built-in input() is not shadowed.)
input_path = 'examples/input_1.jpg'

# e.g. 2: set tag 'Glasses' to attribute 'with' using reference-guided styles (extracted from another image).
# One translation per reference image; each copies that reference's glasses.
references = [
    'examples/reference_glasses_0.jpg',
    'examples/reference_glasses_1.jpg',
    'examples/reference_glasses_2.jpg',
]

# squeeze=False keeps `axes` 2-D even for a single reference.
fig, axes = plt.subplots(1, len(references), figsize=(4 * len(references), 4), squeeze=False)
for ax, reference in zip(axes[0], references):
    steps = [
        {'type': 'reference-guided', 'tag': 1, 'reference': reference}
    ]
    # Default (equal) aspect keeps the square faces undistorted.
    ax.imshow(translate(input_path, steps))
    ax.set_title(os.path.basename(reference))
    ax.axis('off')
plt.show()


In [None]:
# You need to crop the image if you use your own input.
# (Named input_path so the built-in input() is not shadowed.)
input_path = 'examples/input_2.jpg'

# e.g. 3: set tags 'Glasses' and 'Bangs' to attribute 'with', and 'Hair color' to 'black', in one translation.
# The steps are applied in order, each on top of the previous edit.
steps = [
    {'type': 'reference-guided', 'tag': 0, 'reference': 'examples/reference_bangs_0.jpg'},
    {'type': 'reference-guided', 'tag': 1, 'reference': 'examples/reference_glasses_0.jpg'},
    {'type': 'latent-guided', 'tag': 2, 'attribute': 0, 'seed': None}
]

fig, ax = plt.subplots(figsize=(4, 4))
# Default (equal) aspect keeps the square face undistorted.
ax.imshow(translate(input_path, steps))
ax.set_title('bangs + glasses + black hair')
ax.axis('off')
plt.show()

In [None]:
# Try your own translation here. For example, can you remove someone's glasses and bangs?