# Localized Semantic Editing of StyleGAN outputs

Introduced in the paper:<br>
> Edo Collins, Raja Bala, Bob Price and Sabine Süsstrunk. _Editing in Style: Uncovering the Local Semantics of GANs_.  IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2020.

This demo illustrates a simple and effective method for making local, semantically-aware edits to a _target_ GAN output image. This is accomplished by borrowing styles from a _source_ image, also a GAN output.
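As a minimal sketch of what "borrowing styles" means, assume each layer's style is a vector of per-channel values and a mask q in [0, 1] picks which channels to take from the source: y = y_T + q ⊙ (y_S − y_T). The function `blend_styles` and the toy vectors are hypothetical. The notebook's `StyleInterpolator` implements its own 'sequential' variant (Eq. 5) and may build q differently.

```python
import numpy as np

def blend_styles(y_target, y_source, q):
    """Per-channel blend: q=0 keeps the target style, q=1 copies the source style."""
    y_target = np.asarray(y_target, dtype=float)
    y_source = np.asarray(y_source, dtype=float)
    q = np.clip(np.asarray(q, dtype=float), 0.0, 1.0)  # mask values limited to [0, 1]
    return y_target + q * (y_source - y_target)

y_t = np.array([1.0, 2.0, 3.0, 4.0])   # toy target style vector (4 channels)
y_s = np.array([9.0, 9.0, 9.0, 9.0])   # toy source style vector
q = np.array([0.0, 1.0, 0.5, 0.0])     # copy channel 1, half-blend channel 2
print(blend_styles(y_t, y_s, q))       # -> [1. 9. 6. 4.]
```

Channels with q = 0 are untouched, which is what keeps the edit local.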

The method requires neither supervision from an external model nor complex spatial morphing operations. Instead, it relies on the emergent disentanglement of semantic objects that StyleGAN learns during training, which we detect using spherical _k_-means.
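Spherical _k_-means is _k_-means on unit-normalized vectors: each sample is assigned to the centroid with the highest cosine similarity, and each centroid is renormalized to unit length after averaging its members. A minimal NumPy sketch, assuming each row of `X` is one feature vector (for instance, a channel-activation vector at one spatial position of a generator layer). The farthest-point initialization is a choice made here for determinism, not necessarily what the catalog code uses.

```python
import numpy as np

def spherical_kmeans(X, k, n_iter=50, seed=0):
    """Cluster the rows of X by cosine similarity; returns (unit-norm centroids, labels)."""
    rng = np.random.default_rng(seed)
    # Project every sample onto the unit sphere (small epsilon guards all-zero rows)
    X = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-12)
    # Farthest-point initialization: start from a random sample, then repeatedly add
    # the sample least similar to the centroids chosen so far
    idx = [int(rng.integers(len(X)))]
    for _ in range(1, k):
        sims = np.max(X @ X[idx].T, axis=1)
        idx.append(int(np.argmin(sims)))
    C = X[idx].copy()
    for _ in range(n_iter):
        labels = np.argmax(X @ C.T, axis=1)        # assign to the most similar centroid
        for j in range(k):
            s = X[labels == j].sum(axis=0)
            norm = np.linalg.norm(s)
            if norm > 0:                           # keep the old centroid if cluster is empty
                C[j] = s / norm                    # mean direction, renormalized
    labels = np.argmax(X @ C.T, axis=1)
    return C, labels

# Toy data: two groups of 3-D vectors pointing along different axes
rng = np.random.default_rng(1)
a = rng.normal([5.0, 0.0, 0.0], 0.3, size=(20, 3))
b = rng.normal([0.0, 5.0, 0.0], 0.3, size=(20, 3))
C, labels = spherical_kmeans(np.vstack([a, b]), k=2)
```

Because only directions matter, two vectors that differ just in magnitude land in the same cluster, unlike in ordinary _k_-means.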

The implementation below relies on PyTorch and requires downloading additional parameter files found here: https://drive.google.com/open?id=1GYzEzOCaI8FUS6JHdt6g9UfNTmpO08Tt

In [1]:
%load_ext autoreload
%autoreload 2
import torch
import stylegan                                     # StyleGAN model
from stylegan_output import GANOutputs              # Data structure to hold GAN outputs
import ptutils                                      # Helper tensor functions
import visutils                                     # Visualization functions
# from factor_catalog import FactorCatalog            # Spherical k-means and the M matrix (Eq. 1)
from style2_interpolator import StyleInterpolator    # The 'sequential' style-interpolator (Eq. 5)
import cielab                                       # Helper functions for CIELAB color-space
torch.cuda.set_device(3)                             # GPU index; adjust to match your machine

Choose a dataset

In [2]:
dataset_name = 'FFHQ'
#dataset_name = 'bedrooms'

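Load the pre-computed spherical k-means clusters (the catalog), provide them to the style interpolator, and load the appropriate StyleGAN model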

In [3]:
catalog = torch.load('catalogs/stylegan1_{}.pkl'.format(dataset_name)) # See comment above regarding additional files
si_wf = StyleInterpolator(catalog, bias=False)



In [4]:
# Inspect entry 6 of the catalog's M (Eq. 1):
# print the min, max and sum of each of its first 8 rows
M6 = catalog.M[6]
print(type(M6))
for i in range(8):
    row = M6[i]
    print(torch.min(row), torch.max(row), torch.sum(row))

<class 'torch.Tensor'>
tensor(0.0018) tensor(0.6879) tensor(52.2493)
tensor(0.0071) tensor(0.7265) tensor(68.7965)
tensor(0.0021) tensor(0.3560) tensor(19.9632)
tensor(0.0064) tensor(0.8137) tensor(75.3605)
tensor(0.0038) tensor(0.4327) tensor(26.2488)
tensor(0.0048) tensor(0.5194) tensor(32.5841)
tensor(0.0048) tensor(0.5246) tensor(47.5136)
tensor(0.0056) tensor(0.5826) tensor(60.4719)


In [5]:
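root_dir = 'state_dicts' # See comment above regarding additional files
if dataset_name == 'bedrooms':
    base_filename = 'bedrooms-256x256'
    truncation = 0.5
elif dataset_name == 'FFHQ':
    base_filename = 'ffhq'
    truncation = 0.7
else:
    raise ValueError(f'Unknown dataset: {dataset_name}')

# Note: this checkpoint path is hard-coded to an FFHQ model and ignores root_dir/base_filename;
# point it at the parameter file downloaded for the chosen dataset
G = stylegan.StyleGAN.load_from_pth('../karass_ckpt/ffhqE.pt').eval()
G = G.cuda()
G.g_synthesis.set_noise(mode='fixed')   # use fixed noise inputs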

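The `truncation` values above (0.7 for FFHQ, 0.5 for bedrooms) are presumably consumed by StyleGAN's truncation trick, which pulls each intermediate latent toward the average latent: w' = w̄ + ψ(w − w̄). A minimal sketch, assuming `G.z_to_w` applies this standard form (its internals are not shown here); `truncate`, `w_avg` and the toy latents are hypothetical.

```python
import numpy as np

def truncate(w, w_avg, psi):
    """Truncation trick: move w toward the average latent w_avg by a factor psi."""
    w = np.asarray(w, dtype=float)
    return w_avg + psi * (w - w_avg)

w_avg = np.array([0.5, -0.5, 0.0, 1.0])   # hypothetical average latent
w = np.array([2.5, 1.5, -2.0, 1.0])       # hypothetical latent for one sample
w_trunc = truncate(w, w_avg, psi=0.7)
```

ψ = 1 leaves w unchanged and ψ = 0 collapses every sample to w̄; values in between trade sample diversity for image quality.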

With the pre-computed spherical k-means clusters provided to the style interpolator and the generator loaded, generate some examples

In [None]:
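# The first output is the target; the remaining outputs are the style references
if dataset_name == 'FFHQ':
    gs = GANOutputs.from_seed(5, 2001)
elif dataset_name == 'bedrooms':
    gs = GANOutputs.from_seed((0,33,3,19,34), 6813)

with torch.no_grad():
    print('z: ', gs.z.size())
    # Map z to the intermediate latent w, applying the truncation chosen above
    tmp_w = G.z_to_w(gs.z.cuda(), truncation=truncation)
    print('w: ', tmp_w[0].size(), len(tmp_w))
    # Convert w into per-layer styles ys, then synthesize the images
    gs.ys = G.w_to_ys(tmp_w)
    print('ys: ', gs.ys[0].size(), len(gs.ys))
    rgb = G.ys_to_rgb(gs.ys)
    rgb = (rgb.clamp(-1, 1) + 1) / 2      # map generator output from [-1, 1] to [0, 1]
    rgb = rgb.cpu()
    gs.rgb = ptutils.MultiResolutionStore(rgb)

    gs1 = gs[:1]   # target
    gs2 = gs[1:]   # references

res = 256
i, n = 0, 4        # show up to n images starting at index i
print(gs1.rgb.get(res)[i:i+n].size())
# permute NCHW -> NHWC for display
visutils.show(gs1.rgb.get(res)[i:i+n].permute(0,2,3,1).cpu(), title='Target')
visutils.show(gs2.rgb.get(res)[i:i+n].permute(0,2,3,1).cpu(), title='References')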
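`gs.rgb` is a `ptutils.MultiResolutionStore`, and later cells read it back with `.get(res)`. Its implementation is not shown; the hypothetical `MultiResStore` below sketches one way such a store could work, assuming square (N, C, H, W) batches, block-average downsampling by an integer factor, and a per-resolution cache. NumPy stands in for PyTorch.

```python
import numpy as np

class MultiResStore:
    """Hypothetical stand-in: holds a full-resolution (N, C, H, W) batch and returns
    cached block-averaged copies at smaller square resolutions."""

    def __init__(self, images):
        self.full = np.asarray(images, dtype=float)
        self._cache = {self.full.shape[-1]: self.full}

    def get(self, res):
        if res not in self._cache:
            n, c, h, w = self.full.shape
            f = w // res
            if h != w or f * res != w:
                raise ValueError(f"res must evenly divide the square image size {w}")
            # Average each f x f block of pixels
            blocks = self.full.reshape(n, c, h // f, f, w // f, f)
            self._cache[res] = blocks.mean(axis=(3, 5))
        return self._cache[res]

store = MultiResStore(np.arange(16.0).reshape(1, 1, 4, 4))
half = store.get(2)
```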

Transfer object styles from references to target

In [None]:
print('gs1.ys: ', gs1.ys[0].size())
print('gs2.ys: ', gs2.ys[0].size())

In [None]:
part_gs = {}
print(len(G.AdaIN_layers))

def get_epsilons(epsilon, low_res_epsilon=0):
    # One epsilon per AdaIN layer; the first 4 (lowest-resolution) layers get low_res_epsilon
    epsilons = [epsilon] * len(G.AdaIN_layers)
    for i in range(4):
        epsilons[i] = low_res_epsilon
    return epsilons

# For each part: (rho threshold, per-layer epsilons)
if dataset_name == 'FFHQ':
    parts_thresholds = {
        'eyes': (0.1, get_epsilons(50, 5)),
        'nose': (0.1, get_epsilons(30, 5)),
        'mouth': (0.1, get_epsilons(50, 5)),
    }

elif dataset_name == 'bedrooms':
    parts_thresholds = {
        'bed': (0.01, get_epsilons(120)),
        'pillow': (0.05, get_epsilons(100)),
        'window': (0.05, get_epsilons(100)),
    }

for label, (rho, epsilon) in parts_thresholds.items():
    part_gs[label] = GANOutputs()
    print(len(gs1.ys), gs1.ys[0].size(), len(gs2.ys), gs2.ys[0].size())
    # Interpolate the styles associated with this part from the references into the target
    part_gs[label].ys = si_wf.interpolate_ys(gs1.ys, gs2.ys, label, rho, epsilon)
    with torch.no_grad():
        rgb = G.ys_to_rgb(part_gs[label].ys)
        rgb = (rgb.clamp(-1, 1) + 1) / 2
        rgb = rgb.cpu()
        part_gs[label].rgb = ptutils.MultiResolutionStore(rgb)
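The transfer loop works on lists of per-layer styles: `gs1.ys` holds one target and `gs2.ys` several references, and each part yields one edited image per reference. A minimal sketch of that bookkeeping, assuming the per-layer channel masks are already given (how `interpolate_ys` derives them from the catalog, `rho` and `epsilon` is not shown here); `transfer_layers` is a hypothetical name and NumPy stands in for PyTorch.

```python
import numpy as np

def transfer_layers(ys_target, ys_refs, masks):
    """Blend per-layer styles of one target with R references.

    ys_target: list of (1, C_l) arrays, one per layer
    ys_refs:   list of (R, C_l) arrays, one per layer
    masks:     list of (C_l,) arrays in [0, 1] (hypothetical, assumed given)
    Returns a list of (R, C_l) arrays: one edited style set per reference.
    """
    edited = []
    for y_t, y_s, q in zip(ys_target, ys_refs, masks):
        q = np.clip(np.asarray(q, dtype=float), 0.0, 1.0)[None, :]  # (1, C_l)
        edited.append(y_t + q * (y_s - y_t))  # (1, C_l) target broadcasts to (R, C_l)
    return edited

# Two toy layers with 3 and 2 channels, one target and 4 references
ys_t = [np.zeros((1, 3)), np.zeros((1, 2))]
ys_r = [np.ones((4, 3)), 2 * np.ones((4, 2))]
masks = [np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0])]
edited = transfer_layers(ys_t, ys_r, masks)
```

Broadcasting the single target against all references avoids copying the target R times.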

View the results

In [None]:
res = 256
visutils.part_grid(gs1.rgb.get(res), gs2.rgb.get(res), {k: v.rgb.get(res) for k,v in part_gs.items()});

View the per-pixel squared error in CIELAB color space between each edited output and the target image, scaled by the largest error for each part

In [None]:
res = 256
normalize = lambda x: x/x.max()
visutils.part_grid(gs1.rgb.get(res), gs2.rgb.get(res),
                 {k: normalize(cielab.squared_error(v.rgb.get(res), gs1.rgb.get(res))) for k,v in part_gs.items()});
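The internals of `cielab.squared_error` are not shown. A minimal sketch, assuming standard sRGB input in [0, 1], the D65 white point, and the squared Euclidean distance summed over L*, a* and b* at each pixel; `srgb_to_lab` and `lab_squared_error` are hypothetical names, and NumPy arrays in (N, 3, H, W) layout stand in for tensors.

```python
import numpy as np

# Linear sRGB -> CIE XYZ matrix and D65 reference white
_M = np.array([[0.4124564, 0.3575761, 0.1804375],
               [0.2126729, 0.7151522, 0.0721750],
               [0.0193339, 0.1191920, 0.9503041]])
_WHITE = np.array([0.95047, 1.0, 1.08883])

def srgb_to_lab(rgb):
    """rgb: (N, 3, H, W) in [0, 1] -> CIELAB (N, 3, H, W)."""
    rgb = np.asarray(rgb, dtype=float)
    # Undo the sRGB gamma curve
    lin = np.where(rgb <= 0.04045, rgb / 12.92, ((rgb + 0.055) / 1.055) ** 2.4)
    # Per-pixel matrix multiply, then normalize by the white point
    xyz = np.einsum('ij,njhw->nihw', _M, lin) / _WHITE[None, :, None, None]
    d = 6 / 29
    f = np.where(xyz > d ** 3, np.cbrt(xyz), xyz / (3 * d ** 2) + 4 / 29)
    L = 116 * f[:, 1] - 16
    a = 500 * (f[:, 0] - f[:, 1])
    b = 200 * (f[:, 1] - f[:, 2])
    return np.stack([L, a, b], axis=1)

def lab_squared_error(x, y):
    """Per-pixel squared Euclidean distance in CIELAB, shape (N, H, W)."""
    return ((srgb_to_lab(x) - srgb_to_lab(y)) ** 2).sum(axis=1)
```

Dividing the result by its maximum, as `normalize` does above, maps it to [0, 1] for display.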