In [None]:
# Switch the kernel's working directory to the parent of the current one.
# NOTE(review): presumably the notebook lives one level below the repository
# root and this makes the `core` / `examples` imports resolvable - confirm.
# The path is relative, so re-running this cell without restarting the kernel
# moves up one more level each time.
%cd ..

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import PIL
import random

from torchvision.transforms import Resize
from omegaconf import OmegaConf
from core.utils.common import load_clip, mixing_noise
from core.utils.example_utils import (
    Inferencer, to_im, vstack_with_lines, hstack_with_lines, insert_image
)
from core.utils.image_utils import construct_paper_image_grid
from core.utils.reading_weights import read_weights
from core.uda_models import OffsetsTunningGenerator

from pathlib import Path
from collections import defaultdict

from examples.draw_util import weights, set_seed, IdentityEditor, StyleEditor

In [None]:
# Device every model and latent in this notebook is placed on.
device = 'cuda:0'

# Key of the base generator weights and key of the stylised-domain weights.
gan_domain = 'ffhq'
s_domain = 'pixar'

# Checkpoint of the stylised domain.
ckpt = read_weights(weights[s_domain])

# Point the generator parameters at the base generator weights and build a
# second checkpoint that carries only those parameters.
# NOTE(review): `ckpt_ffhq['sg2_params']` is the very object stored in `ckpt`,
# so (assuming `ckpt` is a plain mapping) the overridden `checkpoint_path` is
# also what the `Inferencer` below receives - confirm this is intended.
sg2_params = ckpt['sg2_params']
sg2_params['checkpoint_path'] = weights[gan_domain]
ckpt_ffhq = {'sg2_params': sg2_params}

model = Inferencer(ckpt, device)

In [None]:
# Fix the seed before sampling (presumably seeds the torch RNG - see
# examples.draw_util.set_seed).
set_seed(1)

# Draw 16 latent vectors of size 512 on the CPU, then move them to `device`.
# The model is called with a list of latent tensors, hence the wrapping list.
latent_batch = torch.randn(16, 512).to(device)
z = [latent_batch]

In [None]:
# Downscaling transform used only for the preview shown by this cell.
preview_resize = Resize(256)

# Run the whole latent batch through the model; it returns two image batches.
src_im, trg_im = model(z, truncation=0.7, offset_power=0.9)

# Only the first returned batch is displayed; `trg_im` is kept but not shown.
to_im(preview_resize(src_im))

In [None]:
# Index (within the 16-sample batch) of the latent used for the final grid.
idx = 15

# Pick one latent and restore the leading batch dimension with `unsqueeze(0)`.
# `z` itself is left untouched: the previous `z.clear()` emptied the shared
# list in place, so re-running this cell (or the batch preview cell) without
# re-sampling failed on the empty list.
z_single = [z[0][idx].unsqueeze(0)]

# Convert the selected latent into the generator's "s code" representation.
s_single = model.sg2_source.get_s_code(z_single, truncation=0.7)

In [None]:
# Regenerate the source image straight from the selected s code; the second
# returned value is not used.
src, _ = model.sg2_source(s_single, is_s_code=True)

# Downscale to 256 and display as the cell output.
src_small = Resize(256)(src)
to_im(src_small)

In [None]:
# Columns of the grid: keys into `weights` whose checkpoints provide the
# generator weights (`g_ema`) loaded into the source generator.
column_domains = [
    'ffhq',
    'to_metfaces',
    'to_mega',
    'to_afhqdog',
    'to_afhqcat',
]

# Rows of the grid: style edits applied to the s code; 'original' applies
# the identity editor.
row_domains = [
    'original',
    'pixar',
    'anime',
    'ukiyo-e',
    'botero',
    'joker',
    'anastasia',
    'speed_paint',
]

linear_size = 256
truncation = 0.7
offset_pow = 0.85

# Per-style edit strength; styles not listed here fall back to `offset_pow`.
dom_to_pow = defaultdict(lambda: offset_pow, {
    'original': 0.,
    'sketch': 0.7,
    'pixar': 0.75,
    'botero': 0.75,
    'joker': 0.65,
    'edvard_munch_painting': 0.95,
    'modigliani_painting': 0.75,
})

# One editor per row, built once up front.
style_to_editor = {
    d: StyleEditor(read_weights(weights[d])) if d != 'original' else IdentityEditor()
    for d in row_domains
}

model = Inferencer(ckpt_ffhq, device)
resize = Resize(linear_size)

# Read every column checkpoint from disk exactly once. Previously the file
# was re-read inside the inner loop, i.e. once per (row, column) pair.
# The cache is kept on the CPU; `load_state_dict` copies the values into the
# generator's existing parameters on its own device.
# NOTE: `torch.load` unpickles its input - only use trusted checkpoint files.
column_state_dicts = {
    column_domain: torch.load(weights[column_domain], map_location='cpu')['g_ema']
    for column_domain in column_domains
}

# stack[r][c] is the image for row_domains[r] rendered with column_domains[c].
stack = []

for row_domain in row_domains:
    row_image = []
    for column_domain in column_domains:
        # Swap the generator weights for this column, then render the edited
        # s code. The (row, column) call order matches the original loop.
        model.sg2_source.generator.load_state_dict(column_state_dicts[column_domain])
        s_edited = style_to_editor[row_domain](s_single, power=dom_to_pow[row_domain])
        src_im, _ = model.sg2_source(
            s_edited, is_s_code=True
        )
        im = np.array(to_im(resize(src_im)))
        row_image.append(im)

    stack.append(row_image)

In [None]:
# Spacing arguments handed to the stacking helpers (presumably the gap, in
# pixels, between neighbouring tiles - see core.utils.example_utils).
skip_width = 10
skip_vertical = 15

# Join the tiles of each row side by side, then pile the rows vertically.
row_strips = [hstack_with_lines(row_stack, skip_width) for row_stack in stack]
final_image = vstack_with_lines(row_strips, skip_vertical)

# Display the assembled grid as the cell output.
PIL.Image.fromarray(final_image)