# Implementation of Keep it Simple: Local Search-based Latent Space Editing
Meißner, A.; Fröhlich, A. and Geierhos, M. (2022). Keep It Simple: Local Search-based Latent Space Editing. In Proceedings of the 14th International Joint Conference on Computational Intelligence - NCTA, ISBN 978-989-758-611-8; ISSN 2184-2825, pages 273-283. DOI: 10.5220/0011524700003332


In [None]:
from IPython.display import Image
# Show the architecture figure inline. This call must remain the last
# expression of the cell so the notebook renders the returned Image object.
Image(filename='./jupyter_imgs/local_search_architecture.png', width=300)

## Index of each target feature (value to use for `target_feature_index`)
'5_o_Clock_Shadow', 0
'Arched_Eyebrows', 1
'Attractive', 2
'Bags_Under_Eyes', 3 
'Bald', 4
'Bangs', 5
'Big_Lips', 6 
'Big_Nose', 7
'Black_Hair', 8 
'Blond_Hair', 9
'Blurry', 10
'Brown_Hair', 11 
'Bushy_Eyebrows', 12 
'Chubby', 13
'Double_Chin', 14 
'Eyeglasses', 15
'Goatee', 16
'Gray_Hair', 17 
'Heavy_Makeup', 18 
'High_Cheekbones', 19 
'Male', 20
'Mouth_Slightly_Open', 21 
'Mustache', 22
'Narrow_Eyes', 23 
'No_Beard', 24
'Oval_Face', 25
'Pale_Skin', 26
'Pointy_Nose', 27
'Receding_Hairline', 28 
'Rosy_Cheeks', 29
'Sideburns', 30
'Smiling', 31
'Straight_Hair', 32 
'Wavy_Hair', 33
'Wearing_Earrings', 34 
'Wearing_Hat', 35
'Wearing_Lipstick', 36 
'Wearing_Necklace', 37
'Wearing_Necktie', 38
'Young', 39

In [None]:
# Path to the local stylegan2-ada-pytorch checkout. It is also used further
# down as the root under which the logging folder is created.
stylegan2_path = "./stylegan2-ada-pytorch"

import sys
# Put the checkout on the module search path. This has to happen before the
# imports below -- presumably so that `dnnlib` and `legacy` resolve from it
# (TODO confirm they are not installed separately).
sys.path.append(stylegan2_path)

import torch
import pickle
import torch.nn as nn
import torch.nn.functional as F
import os
import click
import dnnlib
import legacy
import numpy as np
import time

# ---- Run configuration ------------------------------------------------------
device = "cuda:0"            # device the networks and all tensors are moved to
noise_mode = 'const'         # forwarded unchanged to G.synthesis
truncation_psi = 0.5         # forwarded unchanged to G.mapping
size = 1024                  # not referenced by the cells visible in this notebook

# ---- Local-search hyperparameters -------------------------------------------
target_feature_index = 31    # column of the regressor output to optimise (31 = 'Smiling' in the table above)
n_iters = 65001              # number of search iterations; iteration i seeds its latent sample with i
batch_size = 1               # number of latent codes sampled per iteration
learning_rate = 0.001        # scale of the random perturbation added to the direction each iteration
norm_length = 1.0            # upper bound on the L2 norm of the direction vector

# Editing direction being searched for; starts at the origin and is added
# (with a random sign) to the mapped latent code in the search loop.
d = torch.zeros([1, 512], device=device)


In [None]:
# Name of this run, built from the hyperparameters so that different settings
# log to different folders. The spelling "maxLenght" is kept on purpose so
# that folders created by earlier runs are still matched.
run_name = f'our_approach_feature_{target_feature_index}_maxLenght_{norm_length}_lr_{learning_rate}_batch_size{batch_size}'
logging_folder = stylegan2_path + '/training_runs/stylegan2/' + run_name

# exist_ok=True replaces the separate "exists?" check: it avoids the
# check-then-create race and is a no-op when the folder is already there.
os.makedirs(logging_folder, exist_ok=True)
os.makedirs(logging_folder + '/saved_latent_vecs', exist_ok=True)

# Append a header line identifying the run; the file is opened in append mode
# so re-running the notebook keeps earlier log content.
with open(logging_folder + "/training_log.txt", "a") as myfile:
    myfile.write('./training_runs/stylegan2/' + run_name + '\n')


In [None]:
class Normalization(nn.Module):
    """Standardise an image tensor channel-wise: ``(img - mean) / std``.

    The three mean/std constants are hard-coded and reshaped to ``(3, 1, 1)``
    so they broadcast over the last two dimensions of ``img``.
    NOTE(review): the values match the commonly used ImageNet RGB statistics;
    confirm that this is the preprocessing the regressor expects.
    """

    def __init__(self):
        super().__init__()
        # Stored as plain tensor attributes on the global `device` (they are
        # not registered buffers, so `module.to(...)` will not move them).
        self.mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(-1, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225], device=device).view(-1, 1, 1)

    def forward(self, img):
        """Return ``img`` shifted by the stored mean and divided by the stored std."""
        centered = img - self.mean
        return centered / self.std

In [None]:
network_pkl = './pretrained_models/ffhq.pkl'

# Read and deserialize the network file once and take both networks from the
# same result, instead of opening and parsing the file a second time for D.
# NOTE(review): the file is a .pkl and is presumably unpickled by
# `legacy.load_network_pkl`; only load files from a trusted source.
with dnnlib.util.open_url(network_pkl) as f:
    networks = legacy.load_network_pkl(f)

G = networks['G_ema'].to(device)  # type: ignore
D = networks['D'].to(device)  # type: ignore

# Drop the remaining entries so they are not kept alive for the whole session.
del networks

# Conditioning label passed to G.mapping; all zeros, with G.c_dim columns.
label = torch.zeros([1, G.c_dim], device=device)

# Regressor

In [None]:
# Load the serialized TorchScript regressor, switch it to eval mode and move
# it to the configured device. The search loop indexes its output with
# `target_feature_index`.
regressor = torch.jit.load('./pretrained_models/resnet50.pth')
regressor = regressor.eval().to(device)

# Algorithm

In [None]:
# Show the pseudocode figure inline. This call must remain the last
# expression of the cell so the notebook renders the returned Image object.
Image(filename='./jupyter_imgs/local_search_pseudocode.png', width=500)

In [None]:
# Local search for the editing direction `d`.
#
# Each iteration:
#   1. samples a latent code (seeded by the iteration index) and maps it to w,
#   2. draws a random sign epsilon in {-1, +1},
#   3. scores the current direction d and a randomly perturbed candidate d_new
#      (clipped to L2 norm <= norm_length) with the regressor on the images
#      synthesised from w + epsilon * d and w + epsilon * d_new,
#   4. keeps d_new if it moves the target feature further in the direction of
#      epsilon than d does.
#
# NOTE(review): epsilon comes from NumPy's global RNG, which is not seeded
# here, so two runs will not follow exactly the same trajectory.
# NOTE(review): the synthesised image is resized and passed to the regressor
# directly; the `Normalization` module defined in this notebook is not
# applied -- confirm this matches the preprocessing the regressor expects.
time1 = time.time()
log_path = logging_folder + "/training_log.txt"
for random_seed in range(n_iters):
    with torch.no_grad():
        # Reproducible latent sample for this iteration, mapped into W space.
        z = torch.from_numpy(np.random.RandomState(random_seed).randn(batch_size, 512)).to(device)
        w = G.mapping(z, label, truncation_psi=truncation_psi)

        # Random sign: the direction is evaluated either forwards or backwards.
        epsilon = np.random.choice([-1, 1])

        # Regressor prediction for the current direction.
        img_d = G.synthesis(w + (d * epsilon), noise_mode=noise_mode)
        img_d = F.interpolate(img_d, size=256)
        alpha_d = regressor(img_d)[:, target_feature_index]

        # Candidate direction: a small Gaussian step from d. The seed is offset
        # by n_iters so it never coincides with a seed used for z above.
        d_new = d + learning_rate * torch.Tensor(np.random.RandomState(random_seed + n_iters).randn(1, 512)).to(device)
        d_new_norm = torch.norm(d_new)
        if d_new_norm.item() > norm_length:
            # Project back onto the ball of radius norm_length.
            d_new = norm_length * d_new / d_new_norm

        # Regressor prediction for the candidate direction.
        img_d_new = G.synthesis(w + (d_new * epsilon), noise_mode=noise_mode)
        img_d_new = F.interpolate(img_d_new, size=256)
        alpha_d_new = regressor(img_d_new)[:, target_feature_index]

        # Batch-averaged predictions, computed once and reused below.
        pred_old = alpha_d.mean().item()
        pred_new = alpha_d_new.mean().item()

        print('pred_old:', pred_old, 'pred_new:', pred_new)
        # Re-open in append mode every iteration so each line reaches disk
        # even if the run is interrupted. A space now separates the old value
        # from the 'pred_new:' label (they were previously run together).
        with open(log_path, "a") as myfile:
            myfile.write('pred_old: ' + str(pred_old) + ' pred_new: ' + str(pred_new) + '\n')

        # Checkpoint the current direction (before this iteration's update).
        if random_seed % 100 == 0:
            torch.save(d, logging_folder + '/saved_latent_vecs/latent_vec_' + str(random_seed) + '.pt')

        # Accept the candidate only if it is strictly better along epsilon.
        if epsilon * pred_old < epsilon * pred_new:
            d = d_new

time2 = time.time()
print('training-time: ', time2 - time1)