In [None]:
# Switch the working directory to the parent folder; the project imports and
# the relative 'pretrained/...' checkpoint path used later presumably resolve
# from there — TODO confirm.
# NOTE: this is not idempotent — every re-run moves one more level up.
%cd ..

In [None]:
import random
import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import PIL

from torchvision.transforms import Resize
from itertools import combinations, product
from pathlib import Path

from core.uda_models import uda_models
from core.utils.common import mixing_noise
from core.utils.reading_weights import read_weights
from core.utils.example_utils import (
    to_im, Inferencer, 
    vstack_with_lines, 
    hstack_with_lines, 
    insert_image
)

from pprint import pprint


from draw_util import IdentityEditor, StyleEditor, morph_g_ema, weights, set_seed

In [None]:
# Device handed to the generator and to every StyleEditor.
device = 'cuda:0'

# Rendering settings: side length given to Resize and the spacing argument
# passed to hstack_with_lines.
IMAGE_SIZE = 256
SKIP_HORIZ = 20

# Truncation value forwarded to the generator as `truncation=`.
tr = 0.7
# NOTE(review): m_iter is not referenced by the cells in this part of the
# notebook — confirm whether it is still needed.
m_iter = 199


# One StyleEditor per entry of `weights` whose file name does not contain '.pt'.
dom_to_editor = dict(
    (domain, StyleEditor(read_weights(ckpt_path), device))
    for domain, ckpt_path in weights.items()
    if '.pt' not in ckpt_path.name
)

In [None]:
# Instantiate the 'stylegan2' entry of uda_models from the FFHQ checkpoint.
g = uda_models['stylegan2'](
    checkpoint_path='pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt',
    img_size=1024,
    latent_size=512,
    map_layers=8,
    device=device,
)
# The object returned by patch_layers is what the rest of the notebook uses as `g`.
g = g.patch_layers('s_delta')

## Choose target image

In [None]:
# set_seed runs right before sampling, presumably so that the same candidates
# are drawn on every run — TODO confirm what set_seed seeds.
set_seed(1)

# 16 candidate latents of width 512 (matching latent_size=512 of the
# generator), sampled first and then moved to `device`; wrapped in a list
# because that is how `z` is passed to `g`.
z = [torch.randn((16, 512)).to(device)]

In [None]:
# Render every candidate latent with truncation `tr`; the second return value
# of the generator is not needed here.
im, _ = g(z, truncation=tr)
# Resize to 256 before converting, and leave the result as the cell output.
to_im(Resize(size=256)(im))

In [None]:
# Index, among the 16 sampled candidates, of the face used as the morphing source.
good_lat_idx = -1

# Pick the chosen row and restore the leading batch axis (batch of one).
# `z` is deliberately NOT cleared here: z_single is a view of z[0], so emptying
# the list frees no memory and only makes this cell raise IndexError on `z[0]`
# when it is re-run (e.g. to try another good_lat_idx).
z_single = [z[0].detach()[good_lat_idx].unsqueeze(0)]

# Convert the latent to its style code once, then render from that code so the
# morphing cell can start from exactly the same representation.
s_single = g.get_s_code(z_single, truncation=tr)
im, _ = g(s_single, is_s_code=True)

# Preview at IMAGE_SIZE, the size at which this same image is later placed
# first in the morphing strip.
to_im(Resize(IMAGE_SIZE)(im))

## Morphing

In [None]:
# Possible domains: exactly the keys of dom_to_editor, which is the mapping
# the morphing loop indexes. Reading the keys here avoids repeating the
# '.pt' file-name filter used when the editors were built.
pprint(list(dom_to_editor))

In [None]:
# Domains applied one after another by the morphing loop; every entry is used
# as a key into dom_to_editor.
domain_order = [
    'pixar',
    'joker',
    'anime_indomain',
]

# Value passed as `power=` to each editor.
# NOTE(review): this name shadows the builtin pow() for the rest of the
# notebook; renaming it requires updating the morphing cell as well.
pow = 0.6

In [None]:
# Work on a copy of the source style code so that any in-place edits (if the
# editors make them) cannot touch s_single.
cur_st = list(map(torch.clone, s_single))

# Single resize transform shared by every panel of the strip.
shrink = Resize(IMAGE_SIZE)

# First panel: the unedited source image rendered by the selection cell.
images = [to_im(shrink(im))]

# Chain the edits: each domain's editor is applied on top of the previous
# result, and the intermediate image is recorded after every step.
for new_domain in domain_order:
    editor = dom_to_editor[new_domain]
    cur_st = editor(cur_st, power=pow)
    cur_im, _ = g(cur_st, is_s_code=True)
    images.append(to_im(shrink(cur_im)))


# Lay the panels out side by side (SKIP_HORIZ is the second argument of
# hstack_with_lines) and show the strip as the cell output.
final_image = hstack_with_lines(images, SKIP_HORIZ)
PIL.Image.fromarray(final_image)