In [None]:
import gradio as gr
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
from IPython.display import display
from pathlib import Path

from InterFaceGAN.models.stylegan_generator import StyleGANGenerator
from models.image_to_latent import ImageToLatent
from models.latent_optimizer import PostSynthesisProcessing

## Load Synthesizer

In [None]:
# Maximum side length (pixels) for every thumbnail in this notebook; the same
# value is handed to the ImageToLatent constructor further down.
image_size = 256

# Name of the generator that StyleGANGenerator is asked to build.
synthesizer_name = 'stylegan_ffhq'

# Scales raw synthesizer output before it is cast to uint8 in `synthesize`
# (presumably to the 0-255 pixel range -- confirm in PostSynthesisProcessing).
post_processor = PostSynthesisProcessing()

# Keep only the synthesis sub-module of the generator; latents are fed to it
# directly.
synthesizer = StyleGANGenerator(synthesizer_name).model.synthesis

In [None]:
def synthesize(latent):
    """Render a latent code into a PIL image no larger than ``image_size``.

    Args:
        latent: Tensor accepted by the module-level ``synthesizer``.
            Presumably a batched dlatent such as (1, 18, 512) -- TODO confirm.
            Only the first item of the batch is rendered.

    Returns:
        PIL.Image.Image: the first synthesized image, shrunk in place with
        ``thumbnail`` so that it fits within ``image_size`` x ``image_size``.
    """
    # Inference only: disable autograd so no graph is kept alive through the
    # synthesizer (or through whatever model produced ``latent``).
    with torch.no_grad():
        pixels = post_processor(synthesizer(latent))

    # ``.cpu()`` is a no-op for CPU tensors but is required before ``.numpy()``
    # when the synthesizer lives on an accelerator.
    batch = pixels.detach().cpu().numpy().astype(np.uint8)

    # First batch item, reordered from channels-first to channels-last, which
    # is the array layout ToPILImage expects.
    frame = batch[0].transpose((1, 2, 0))

    output_image = transforms.ToPILImage()(frame)
    output_image.thumbnail((image_size, image_size))
    return output_image

In [None]:
# Root folder holding the saved dlatents and the images they belong to.
output_dir = Path('output/stylegan_ffhq')

# Locations of the two parent latents and of the attribute boundaries.
path1 = output_dir/'dlatents/maoka.npy'
path2 = output_dir/'dlatents/mamala.npy'
age_boundary_path = 'InterFaceGAN/boundaries/stylegan_ffhq_age_w_boundary.npy'
gender_boundary_path = 'InterFaceGAN/boundaries/stylegan_ffhq_gender_w_boundary.npy'

# Parent latents, with any singleton axes dropped.
latent1 = np.load(path1).squeeze()
latent2 = np.load(path2).squeeze()

# Boundary arrays (presumably InterFaceGAN direction vectors for age and
# gender -- confirm against the boundary files).
age_boundary = np.load(age_boundary_path)
gender_boundary = np.load(gender_boundary_path)

# Element-wise average of the two parents.
latent = np.stack((latent1, latent2)).mean(axis=0)
# Alternative: start from a single parent instead of the average.
# latent = latent2

# Shift the mixed latent against the age boundary (3x) and the gender
# boundary (1x). In-place so that `latent` keeps its dtype.
latent -= 3 * age_boundary
latent -= gender_boundary

In [None]:
# Load both parent portraits and shrink them in place so each fits within
# image_size x image_size for side-by-side display.
img1 = Image.open(output_dir/'images/maoka.jpg')
img1.thumbnail((image_size, image_size))

img2 = Image.open(output_dir/'images/mamala.jpg')
img2.thumbnail((image_size, image_size))

# Build the synthesizer input explicitly: float32 tensor with a leading batch
# axis. The previous transforms.ToTensor() call is an image converter (it
# rescales uint8 data and permutes 3-D arrays); for the 2-D latent used here
# it only happened to add the batch axis, which unsqueeze(0) states directly.
# Assumes `latent` is 2-D, e.g. (18, 512) -- TODO confirm.
offspring_latent = torch.from_numpy(latent).float().unsqueeze(0)
offspring = synthesize(offspring_latent)

display(img1, img2, offspring)

## Load Encoder

In [None]:
# Encoder checkpoint to restore (earlier checkpoint kept for reference).
# model_path = '2022-11-17_13:17_1000.pt'
model_path = '2022-11-20_11 06_50000_19.pt'

# Encoder network, built for image_size inputs.
image_to_latent = ImageToLatent(image_size)

# Deserialize onto the CPU regardless of the device the checkpoint was saved
# from.
checkpoint = torch.load(model_path, map_location='cpu')

# This checkpoint wraps the weights under 'model_state_dict'; a checkpoint
# that is a bare state dict would use the commented-out line instead.
encoder_state = checkpoint['model_state_dict']
image_to_latent.load_state_dict(encoder_state)
# image_to_latent.load_state_dict(checkpoint)

# Switch to evaluation mode; as the last expression this also echoes the model.
image_to_latent.eval()

In [None]:
def predict(image):
    """Encode an image with ``image_to_latent`` and re-synthesize it.

    Args:
        image: Any input ``transforms.ToTensor`` accepts (PIL image or
            ndarray). Presumably it must already match the size and channel
            count the encoder was built for -- TODO confirm.

    Returns:
        The PIL image that ``synthesize`` produces from the predicted latent.
    """
    # Convert to a tensor and add the leading batch axis the encoder is fed.
    batch = transforms.ToTensor()(image).unsqueeze(0)

    # Inference only: avoid recording an autograd graph through the encoder.
    with torch.no_grad():
        latent = image_to_latent(batch)

    return synthesize(latent)

In [None]:
# Open the test portrait and shrink it in place so that neither side exceeds
# image_size.
thumbnail_box = (image_size, image_size)
img = Image.open('data/us/maoka.png')
img.thumbnail(thumbnail_box)

In [None]:
# Show the input next to its encode -> synthesize round trip.
reconstruction = predict(img)
display(img, reconstruction)