In [1]:
# stylegan2-ada-pytorch/blending.ipynb

from training.networks import Generator
from copy import deepcopy
import math
import torch
import dnnlib
import legacy
import torchvision
from PIL import Image
import numpy as np
import torchvision
import matplotlib.pyplot as plt

In [2]:
def gather_params(G: Generator) -> dict:
    """Bucket every parameter and buffer of ``G`` by the synthesis block that owns it.

    Returns a dict whose keys are the entries of ``G.synthesis.block_resolutions``
    (in that order) followed by the string ``"mapping"``. Each value is a dict from
    the full tensor name (as reported by ``named_buffers`` / ``named_parameters``)
    to the tensor itself, filled in sorted-name order.

    Names starting with ``"mapping"`` go to the ``"mapping"`` bucket. Every other
    name is expected to look like ``synthesis.b<res>.<...>`` and is filed under the
    integer ``<res>``. A name that does not fit this pattern raises
    (IndexError / ValueError / KeyError).
    """
    groups = {res: {} for res in G.synthesis.block_resolutions}
    groups["mapping"] = {}

    named_tensors = [*G.named_buffers(), *G.named_parameters()]
    for name, tensor in sorted(named_tensors):
        if name.startswith("mapping"):
            bucket = "mapping"
        else:
            # e.g. "synthesis.b128.conv0.weight" -> "b128" -> 128
            block_tag = name.split(".")[1]
            bucket = int(block_tag[1:])
        groups[bucket][name] = tensor
    return groups


def blend_models(G_low: Generator, G_high: Generator, swap_layer: int, blend_width: float = 3) -> Generator:
    """Blend two generators block-by-block into a new generator.

    Synthesis blocks are indexed by their position in
    ``G_low.synthesis.block_resolutions``. Each block's tensors become
    ``y * high + (1 - y) * low``, where, for ``x = layer_idx - swap_layer``:

    * ``blend_width`` is a number: ``y = sigmoid(x / blend_width)``. This gives a
      smooth transition centred on ``swap_layer``, where ``y == 0.5``.
    * ``blend_width is None``: a hard swap. Blocks with ``layer_idx > swap_layer``
      come from ``G_high``; the rest, including ``swap_layer`` itself, come from
      ``G_low``.

    Tensors whose names start with ``"mapping"`` are taken unchanged from
    ``G_high``. Parameters and buffers are both blended.

    Args:
        G_low: Generator supplying the blocks below the swap point.
        G_high: Generator supplying the blocks above the swap point. Its tensor
            names must match ``G_low``'s, otherwise a KeyError is raised.
        swap_layer: Block index at which the transition happens.
        blend_width: Positive softness of the transition, or ``None`` for a hard swap.

    Returns:
        A deep copy of ``G_high`` loaded with the blended weights. ``G_low`` and
        ``G_high`` are not modified.

    Raises:
        AssertionError: if ``blend_width`` is not None and not positive.
    """
    if blend_width is not None:
        assert blend_width > 0, "blend_width must be positive (or None for a hard swap)"

    params_low = gather_params(G_low)
    params_high = gather_params(G_high)

    # Only values are needed; skip building an autograd graph for tensors that
    # require grad (load_state_dict copies the values anyway).
    with torch.no_grad():
        for layer_idx, res in enumerate(G_low.synthesis.block_resolutions):
            x = layer_idx - swap_layer

            if blend_width is None:
                y = 1 if x > 0 else 0
            else:
                t = x / blend_width
                # Numerically stable logistic 1 / (1 + e^-t). math.exp raises
                # OverflowError above ~709, which a small blend_width (a near-hard
                # swap) would otherwise hit on blocks far below swap_layer.
                if t >= 0:
                    y = 1 / (1 + math.exp(-t))
                else:
                    e = math.exp(t)
                    y = e / (1 + e)

            low_block = params_low[res]
            params_high[res] = {
                name: high_tensor * y + low_block[name] * (1 - y)
                for name, high_tensor in params_high[res].items()
            }

    # Flatten the per-block groups (including the untouched "mapping" group from
    # G_high) back into a single name -> tensor state dict.
    state_dict = {
        name: tensor
        for group in params_high.values()
        for name, tensor in group.items()
    }

    G_mix = deepcopy(G_high)
    G_mix.load_state_dict(state_dict)
    return G_mix

In [3]:
# ffhq_path: checkpoint for G_low, the FFHQ base generator loaded in the next cell.
# met_path: checkpoint for G_high. The run name suggests it was resumed from the
# FFHQ 512 model. A previously used alternative is './metfaces.pkl'.
ffhq_path, met_path = (
    './ffhq_512.pkl',
    './training-runs/00024-acanev3-mirror-paper512-batch8-resumeffhq512/network-snapshot-000400.pkl',
)
# Truncation is applied as w_avg + (w - w_avg) * truncation_psi; 1 leaves w unchanged.
truncation_psi = 1

In [4]:
# Load the EMA generator from the FFHQ checkpoint onto the GPU as G_low.
# NOTE(review): legacy.load_network_pkl presumably unpickles the file, which can
# execute arbitrary code. Only point ffhq_path at a trusted checkpoint.
device = torch.device("cuda")
with dnnlib.util.open_url(ffhq_path) as pkl_stream:
    loaded = legacy.load_network_pkl(pkl_stream)
    G_low = loaded['G_ema'].to(device)  # type: ignore
    del loaded  # keep only G_ema referenced; release the rest of the loaded object

In [5]:
# Load the EMA generator from the met_path snapshot onto the GPU as G_high.
# `device` is rebuilt here so this cell does not rely on the previous one.
# NOTE(review): legacy.load_network_pkl presumably unpickles the file, which can
# execute arbitrary code. Only point met_path at a trusted checkpoint.
device = torch.device("cuda")
with dnnlib.util.open_url(met_path) as pkl_stream:
    loaded = legacy.load_network_pkl(pkl_stream)
    G_high = loaded['G_ema'].to(device)  # type: ignore
    del loaded  # keep only G_ema referenced; release the rest of the loaded object

In [30]:
import os

# Build a paired dataset. For each random latent, the same w is rendered by the
# base generator (G_low -> "real") and by the blended generator (-> "style").
NUM_PAIRS = 10000
SWAP_LAYER = 7
BLEND_WIDTH = 3
MASTER_SEED = 0  # fixes the sequence of per-image seeds so runs are reproducible
REAL_DIR = '../../datasets/style_transfer/real'
STYLE_DIR = '../../datasets/style_transfer/style'

max_seed = 2**32 - 1  # RandomState seeds must lie in [0, 2**32 - 1]
os.makedirs(REAL_DIR, exist_ok=True)
os.makedirs(STYLE_DIR, exist_ok=True)

seed_rng = np.random.RandomState(MASTER_SEED)

with torch.no_grad():
    # The blended model does not depend on the loop index, so build it once
    # instead of repeating the deepcopy + load_state_dict for every image.
    G_blend = blend_models(G_low, G_high, swap_layer=SWAP_LAYER, blend_width=BLEND_WIDTH)
    w_avg = G_blend.mapping.w_avg

    for i in range(NUM_PAIRS):
        seed = seed_rng.randint(0, max_seed, dtype=np.int64)
        all_z = np.random.RandomState(seed).randn(1, G_blend.z_dim)  # batch of one latent
        all_w = G_blend.mapping(torch.from_numpy(all_z).to(device), None)
        all_w = w_avg + (all_w - w_avg) * truncation_psi

        real_img = G_low.synthesis(all_w, noise_mode="const")
        style_img = G_blend.synthesis(all_w, noise_mode="const")

        # Map the generator's nominal [-1, 1] output to [0, 1] with one fixed
        # transform. normalize=True would min-max stretch each image separately,
        # so colours and contrast would not match between the two images of a pair.
        torchvision.utils.save_image((real_img.clamp(-1, 1) + 1) / 2, f'{REAL_DIR}/{i}.png')
        torchvision.utils.save_image((style_img.clamp(-1, 1) + 1) / 2, f'{STYLE_DIR}/{i}.png')