## Load Synthesizer

In [None]:
import numpy as np
import torch
from pathlib import Path
from functools import partial
from torchvision import transforms
from PIL import Image
from IPython.display import display
from tqdm.notebook import tqdm

from interfacegan.models.stylegan_generator import StyleGANGenerator
from models.image_to_latent import ImageToLatent
from models.latent_optimizer import LatentOptimizer, VGGProcessing, PostSynthesisProcessing
from models.losses import LatentLoss

In [None]:
# Side length, in pixels, of the square images used throughout this notebook;
# also passed to the ImageToLatent constructor.
image_size = 256
# Size pair handed to Image.thumbnail and gr.Image in the cells below.
image_shape = (image_size, image_size)

In [None]:
# synthesizer generates images from latents
# (it is the `.model.synthesis` attribute of the StyleGANGenerator built for this name)
synthesizer_name = 'stylegan_ffhq'
synthesizer = StyleGANGenerator(synthesizer_name).model.synthesis

# to clip the values of the generated image
# NOTE(review): the output range is defined inside PostSynthesisProcessing; `synthesize`
# casts the result to uint8, so it presumably ends up in [0, 255] -- confirm.
post_processor = PostSynthesisProcessing()

In [None]:
def synthesize(latent):
    """ Synthesize an image from the WP latents.

    Only the first item of the batch is converted to an image; any further
    items are run through the synthesizer but then discarded.

    Args:
        latent: WP latent as numpy array of shape (batch_size, 18, 512)

    Returns:
        synthesized image as PIL.Image, shrunk to fit within ``image_shape``
    """
    latent = torch.from_numpy(latent.astype(np.float32))
    # Pure inference: disable autograd so no graph is built (and kept alive)
    # for the forward pass through the synthesizer.
    with torch.no_grad():
        synth = synthesizer(latent)
        postproc = post_processor(synth)
    # First batch item, channels-first -> channels-last for ToPILImage.
    pixels = postproc.detach().numpy().astype(np.uint8)[0].transpose((1, 2, 0))
    from_tensor = transforms.ToPILImage()
    output_image = from_tensor(pixels)
    # thumbnail() resizes in place, preserves the aspect ratio and never enlarges.
    output_image.thumbnail(image_shape)
    return output_image

## Validate Data

Make sure that the image generated from a given dlatent looks the same as the corresponding reference image stored in the dataset.

In [None]:
# Root folder of the StyleGAN FFHQ data used for validation.
data_dir = Path('data/stylegan_ffhq')

# Stored latents; the last expression displays their shape as the cell output.
# `synthesize` documents its input as (batch_size, 18, 512) -- compare against the shape shown.
dlatents = np.load(data_dir/'dlatents/wp.npy')
dlatents.shape

In [None]:
# Index of the latent to validate.
idx = 0

# Reference image for this latent, resolved relative to data_dir so that the
# dataset location is configured in one place only.
# NOTE(review): file names are offset by 10000 from the latent index --
# presumably the stored images start at 010000.jpg; confirm.
img = Image.open(data_dir / 'images' / f'{idx + 10000:06d}.jpg')
img.thumbnail(image_shape)

# Image.fromarray(np.interp(dlatent, (dlatent.min(), dlatent.max()), (0, 255)).astype(np.uint8))

# Show the reference image and the image synthesized from its latent;
# np.newaxis adds the batch dimension expected by synthesize.
display(img, synthesize(dlatents[np.newaxis, idx]))

## Load Latent Optimizer

Optionally load the `ImageToLatent` network, which predicts a latent representation of the image. This prediction can be used to initialize the optimization instead of starting from zeros.

In [None]:
# index of the VGG layer handed to LatentOptimizer
vgg_layer = 12
# learning rate of the SGD optimizer created in optimize_latents
learning_rate = 1
# default number of optimization steps for optimize_latents and predict
iterations = 100
# alternative ImageToLatent checkpoint, kept for reference
# model_path = '2022-11-17_13:17_1000.pt'
model_path = '2022-11-20_11 06_50000_19.pt'

# preprocessing for VGG and ImageToLatent nets
vgg_processing = VGGProcessing()

# image to latent creates an initial latent to start the optimization from
image_to_latent = ImageToLatent(image_size)
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
# map_location remaps all stored tensors to the CPU.
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
# the checkpoint is indexed by key; the weights are stored under 'model_state_dict'
image_to_latent.load_state_dict(checkpoint['model_state_dict'])
# variant for checkpoints that hold the bare state dict
# image_to_latent.load_state_dict(checkpoint)
# switch to evaluation mode (this does not disable gradient tracking)
image_to_latent.eval()

# latent optimizer iteratively improves the latent of the image using embeddings from vgg
latent_optimizer = LatentOptimizer(synthesizer, vgg_layer)

In [None]:
def optimize_latents(
    image,
    use_image_to_latent=False,
    iterations=iterations,
):
    """ Fit WP latents to an image by gradient descent on a feature loss.

    The features that ``latent_optimizer`` produces for the current latents are
    compared with the VGG features of the input image using ``LatentLoss``.

    Args:
        image: input image as numpy array of shape (3, 256, 256)
        use_image_to_latent: start from the ImageToLatent prediction instead of zeros
        iterations: number of SGD steps to run

    Returns:
        optimized WP latents as numpy array of shape (1, 18, 512)
    """
    # Freeze every network weight; the latents are the only trainable tensor.
    for frozen in latent_optimizer.parameters():
        frozen.requires_grad_(False)

    # Preprocess the target once and cache its features as the fixed reference.
    target = latent_optimizer.vgg_processing(torch.from_numpy(image))
    reference_features = latent_optimizer.vgg16(target).detach()
    target = target.detach()

    # Starting point: network prediction (needs a batch dimension) or all zeros.
    if use_image_to_latent:
        start = image_to_latent(target.unsqueeze(0)).detach()
    else:
        start = torch.zeros((1, 18, 512))
    latents = start.requires_grad_(True)

    criterion = LatentLoss()
    optimizer = torch.optim.SGD([latents], lr=learning_rate)

    progress_bar = tqdm(range(iterations))
    for step in progress_bar:
        optimizer.zero_grad()

        # squeeze() drops the batch dimension so the shapes match the unbatched reference
        features = latent_optimizer(latents).squeeze()
        loss = criterion(features, reference_features)
        loss.backward()
        loss_value = loss.item()

        optimizer.step()
        progress_bar.set_description("Step: {}, Loss: {}".format(step, loss_value))

    return latents.detach().numpy()

In [None]:
def predict(
    image,
    use_image_to_latent=False,
    iterations=iterations,
):
    """ Reconstruct an image: optimize its latents, then synthesize from them.

    Args:
        image: input image as channels-last numpy array (height, width, channels);
            it is transposed to channels-first before being optimized
        use_image_to_latent: whether to generate the initial latent using ImageToLatent net
        iterations: number of optimization iterations

    Returns:
        image synthesized from the optimized latents
    """
    # optimize_latents expects the channel axis first
    channels_first = image.transpose((2, 0, 1))
    return synthesize(
        optimize_latents(
            channels_first,
            use_image_to_latent=use_image_to_latent,
            iterations=iterations,
        )
    )

## Use Gradio

In [None]:
import gradio as gr

In [None]:
# Sample input photos (the same files are listed as examples in the Gradio interface),
# shrunk in place to fit within image_shape.
image1 = Image.open('data/us/maoka.png')
image1.thumbnail(image_shape)
image2 = Image.open('data/us/mamala.png')
image2.thumbnail(image_shape)

# vector in the WP latent space to move along to change the age
# NOTE(review): predict_offspring subtracts a multiple of it from a single blended latent,
# so its shape must broadcast against that latent -- confirm.
age_boundary = np.load('InterFaceGAN/boundaries/stylegan_ffhq_age_w_boundary.npy')

In [None]:
def predict_offspring(
    image1,
    image2,
    weight=0.5,
    child_scale=1,
):
    """ Blend the latents of two images and shift the result along the age boundary.

    Both images are encoded with the ImageToLatent net (no iterative
    optimization), their latents are mixed linearly, and a multiple of
    ``age_boundary`` is subtracted before synthesizing.

    Args:
        image1: first image, channels-last (height, width, channels);
            anything ``np.asarray`` can convert
        image2: second image, same layout as ``image1``
        weight: share of ``image1``'s latent in the blend; ``image2`` gets ``1 - weight``
        child_scale: multiple of ``age_boundary`` subtracted from the blended latent

    Returns:
        image synthesized from the resulting latent
    """
    # Pure inference: disable autograd so no graph is built for the encoder pass.
    with torch.no_grad():
        images = torch.stack([
            # channels-last -> channels-first, then the shared preprocessing
            vgg_processing(torch.from_numpy(np.asarray(image).transpose((2, 0, 1))))
            for image in (image1, image2)
        ])
        latents = image_to_latent(images).detach().numpy()
    latent = (weight * latents[0] + (1 - weight) * latents[1]) - child_scale * age_boundary
    # synthesize expects a leading batch dimension
    offspring = synthesize(latent[np.newaxis])
    return offspring

In [None]:
?gr.Image

In [None]:
# Build and launch the demo UI around predict_offspring: two image inputs and
# two sliders are listed in the order of the function's parameters
# (image1, image2, weight, child_scale); the returned PIL image is the output.
gr.Interface(
    fn=predict_offspring,
    inputs=[
        # first image (image1)
        gr.Image(
            label='XY',
            shape=image_shape,
        ),
        # second image (image2)
        gr.Image(
            label='XX',
            shape=image_shape,
        ),
        # disabled controls; predict_offspring has no matching parameters for them
#         gr.Checkbox(
# #             value=True,
#             label='Use ImageToLatent',
#         ),
#         gr.Slider(
#             minimum=0,
#             maximum=500,
#             value=0,
#             label='Iterations',
#         ),
        # blend weight of the first image's latent (weight)
        gr.Slider(
            label='XY weight',
            minimum=0,
            maximum=1,
            value=.5,
            step=.1,
        ),
        # multiple of the age boundary to subtract (child_scale);
        # note the UI default of 1.5 differs from the function default of 1
        gr.Slider(
            label='child scale',
            minimum=0,
            maximum=3,
            value=1.5,
            step=.5,
        )
    ],
    # Each example row supplies the two image inputs only; the slider values are commented out.
    # NOTE(review): confirm the installed Gradio version accepts rows shorter than `inputs`.
    examples=[
        [
            'data/us/maoka.png',
            'data/us/mamala.png',
#             .5,
#             2,
        ],
    ],
    outputs=gr.Image(type='pil'),
).launch()