In [None]:
# Switch the kernel's working directory to the parent directory. The imports below
# (`core.*`, `examples.*`) and the relative checkpoint path 'pretrained/...' presumably
# rely on it -- TODO confirm the expected launch location.
# NOTE(review): the move is relative, so running this cell twice goes up two levels;
# restart the kernel before re-running it.
%cd ..

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import PIL
import random

from torchvision.transforms import Resize
from omegaconf import OmegaConf
from core.utils.common import load_clip, mixing_noise
from core.utils.example_utils import Inferencer, to_im, vstack_with_lines, hstack_with_lines, insert_image
from core.utils.image_utils import construct_paper_image_grid
from core.utils.reading_weights import read_weights
from core.uda_models import uda_models

from pathlib import Path
from collections import defaultdict, OrderedDict
from copy import deepcopy

from examples.draw_util import IdentityEditor, StyleEditor, morph_g_ema, weights, set_seed

In [None]:
device = 'cuda:0'

# Map each domain name to a StyleEditor built from its weights file.
# Entries whose file name contains '.pt' are left out of this mapping.
dom_to_editor = {
    domain: StyleEditor(read_weights(weights_path), device)
    for domain, weights_path in weights.items()
    if '.pt' not in weights_path.name
}

In [None]:
# Constructor arguments for the 'stylegan2' entry of `uda_models`.
stylegan2_kwargs = dict(
    img_size=1024,
    latent_size=512,
    map_layers=8,
    checkpoint_path='pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt',
)

# Build the model, then put it in eval mode and move it to `device`.
g = uda_models['stylegan2'](**stylegan2_kwargs)
g = g.eval().to(device)

In [None]:
# Seed first -- presumably this makes the noise sampled below repeatable across runs.
set_seed(3)

# Positional arguments forwarded unchanged to `mixing_noise`.
# NOTE(review): they look like (batch, latent size, mixing probability, device) -- confirm
# against core.utils.common.mixing_noise.
noise_args = (16, 512, 0, device)
z = mixing_noise(*noise_args)

In [None]:
# Preview every sampled latent. This is inference only, so run it without autograd:
# otherwise the global `im` would keep the whole computation graph of the batch alive.
with torch.no_grad():
    im, _ = g(z, truncation=0.7)

# Last expression of the cell -> displayed as the cell output (kept outside the `with`
# block, since only a top-level trailing expression is rendered).
to_im(Resize(256)(im))

In [None]:
# Index (within the batch previewed above) of the latent used for the rest of the notebook.
idx = 4

# Select that latent, restore a leading dimension with unsqueeze(0) and wrap it in a list,
# i.e. the same container layout as `z`.
# `z` itself is intentionally NOT cleared: emptying it in place made this cell (and the
# preview cell) fail on a second run, and released no memory because `z_single` is a view
# on the same tensor storage.
z_single = [z[0][idx].unsqueeze(0)]

# Inference only -> no autograd graph needed.
with torch.no_grad():
    im, _ = g(z_single, truncation=0.7)

to_im(Resize(256)(im))

## Combined Morphing

In [None]:
# ---- Grid layout ----
phase_step = 8      # images per one morphing stage
linear_size = 256   # size each single image is resized to
skip_horiz = 10     # skip distance between images of one row

# ---- Stylization ----
offset_power = 0.8  # offset_power handed to the stylization editor

# Interpolation coefficients: `phase_step` evenly spaced values from 0 to 1 (inclusive),
# stored as a plain Python list.
alphas = np.linspace(0.0, 1.0, num=phase_step).tolist()

# Transform applied to every generated image before it is placed in the grid.
resize = Resize(linear_size)


# Every entry of `row_domains` describes one combined morphing as a 3-tuple:
#   1. the domain of the initial stylization, e.g. 'anime', 'pixar', ...
#   2. the domain of the fine-tuned GAN model, e.g. 'AFHQDog', 'AFHQCat', ...
#   3. the domain of the final stylization
# The commented-out entries are further examples; keep a single active entry to look at
# one result at a time.
# NOTE(review): the commented examples use keys such as 'afhqdog' while the active entry
# uses 'to_afhqdog' -- verify the key names against `weights` before enabling them.
row_domains = [
    ('impressionism_painting', 'to_afhqdog', 'pop_art_indomain'),
    # ('ukiyo-e', 'afhqdog', 'sketch'),
    # ('dali_painting', 'afhqcat', 'cubism_painting'),
    # ('ukiyo-e', 'afhqdog', 'werewolf'),
    # ('ukiyo-e', 'afhqcat', 'werewolf')
]

In [None]:
def _load_generator_weights(state_dict):
    """Load `state_dict` into `g.generator` and recompute `g.mean_latent`.

    Mirrors the sequence that has to follow every weight change: load the weights,
    put the model in eval mode on `device`, then call `g.generator.mean_latent(4096)`.
    """
    g.generator.load_state_dict(state_dict)
    g.eval().to(device)
    g.mean_latent = g.generator.mean_latent(4096)


def _render(s_code, editor):
    """Edit `s_code` with `editor` (using `offset_power`), synthesize, return the resized image."""
    edited_s = editor(s_code, offset_power)
    out, _ = g(edited_s, is_s_code=True)
    return to_im(resize(out))


# `images` collects four rows per entry of `row_domains`, one row per morphing stage.
images = []
ckpt_ffhq = torch.load(weights['ffhq'])

# Everything below is inference, so autograd is switched off to avoid building and
# retaining computation graphs for the full-resolution generator.
with torch.no_grad():
    for first_style, g_domain, next_style in row_domains:
        first_editor = dom_to_editor[first_style]
        next_editor = dom_to_editor[next_style]

        # ---- Stage 1: FFHQ weights, first style scaled by alpha (0 -> 1) ----
        _load_generator_weights(ckpt_ffhq['g_ema'])

        row_image = []
        for alpha in alphas:
            s = g.get_s_code(z_single, truncation=0.7)
            row_image.append(_render(s, first_editor * alpha))
        images.append(row_image)

        # ---- Stage 2: full first style, generator weights morphed with 1 - alpha ----
        ckpt2 = torch.load(weights[g_domain])

        row_image = []
        for alpha in alphas:
            morphed_ckpt = morph_g_ema(ckpt_ffhq, ckpt2, 1 - alpha)
            _load_generator_weights(morphed_ckpt['g_ema'])
            s = g.get_s_code(z_single, truncation=0.7)
            row_image.append(_render(s, first_editor))
        # This row is stored in reverse loop order.
        images.append(row_image[::-1])

        # ---- Stage 3: generator fixed, style blended from first to next ----
        # Reuses `s` and the generator state left by the last step of stage 2.
        # (Appending in loop order equals the former "prepend, then reverse".)
        row_image = []
        for alpha in alphas:
            style_editor = next_editor * alpha + first_editor * (1 - alpha)
            row_image.append(_render(s, style_editor))
        images.append(row_image)

        # ---- Stage 4: generator weights morphed with alpha ----
        # Uses `style_editor` as left by the last step of stage 3 (alpha = alphas[-1]).
        row_image = []
        for alpha in alphas:
            morphed_ckpt = morph_g_ema(ckpt_ffhq, ckpt2, alpha)
            _load_generator_weights(morphed_ckpt['g_ema'])
            s = g.get_s_code(z_single, truncation=0.7)
            row_image.append(_render(s, style_editor))
        # Stored in reverse loop order (equals the former prepend-based construction).
        images.append(row_image[::-1])

In [None]:
# Skip passed to `vstack_with_lines` between stacked rows.
# `skip_horiz` is deliberately not redefined here: it is taken from the configuration
# cell, so changing it there actually affects the grid.
skip_vertical = 15

# One horizontally stacked strip per row collected in `images`.
row_strips = [hstack_with_lines(row_stack, skip_horiz) for row_stack in images]

# Stack the strips vertically and display the result as the cell output.
final_image = vstack_with_lines(row_strips, skip_vertical)
PIL.Image.fromarray(final_image)