# Minimal implementation of Enjoy your Editing
Zhuang, P., Koyejo, O. O., and Schwing, A. (2021). 
Enjoy your editing: Controllable GANs for image editing via latent space navigation. 
In International Conference on Learning Representations (ICLR).
https://arxiv.org/pdf/2102.01187.pdf

In [None]:
# Display the architecture figure stored next to the notebook. The bare
# Image(...) expression is the last statement of the cell, which is what makes
# the notebook render it.
from IPython.display import Image
Image(filename='./jupyter_imgs/enjoy_your_editing_architecture.png')

## index for target features
'5_o_Clock_Shadow', 0
'Arched_Eyebrows', 1
'Attractive', 2
'Bags_Under_Eyes', 3 
'Bald', 4
'Bangs', 5
'Big_Lips', 6 
'Big_Nose', 7
'Black_Hair', 8 
'Blond_Hair', 9
'Blurry', 10
'Brown_Hair', 11 
'Bushy_Eyebrows', 12 
'Chubby', 13
'Double_Chin', 14 
'Eyeglasses', 15
'Goatee', 16
'Gray_Hair', 17 
'Heavy_Makeup', 18 
'High_Cheekbones', 19 
'Male', 20
'Mouth_Slightly_Open', 21 
'Mustache', 22
'Narrow_Eyes', 23 
'No_Beard', 24
'Oval_Face', 25
'Pale_Skin', 26
'Pointy_Nose', 27
'Receding_Hairline', 28 
'Rosy_Cheeks', 29
'Sideburns', 30
'Smiling', 31
'Straight_Hair', 32 
'Wavy_Hair', 33
'Wearing_Earrings', 34 
'Wearing_Hat', 35
'Wearing_Lipstick', 36 
'Wearing_Necklace', 37
'Wearing_Necktie', 38
'Young', 39

In [None]:
# Location of the StyleGAN2-ADA checkout. It is appended to sys.path before the
# imports below; presumably `dnnlib` and `legacy` are resolved from there.
stylegan2_path = "./stylegan2-ada-pytorch"

import sys
sys.path.append(stylegan2_path)

import torch
import pickle
import torch.nn as nn
import torch.nn.functional as F
import os
import click
import dnnlib
import legacy
import numpy as np
import time

# Device on which the networks and all tensors of this notebook are placed.
device = "cuda:0"
# Not referenced anywhere else in the visible notebook; presumably the
# generator's output resolution (the training loop comments mention 1024x1024).
size = 1024
# Passed to G.mapping as its truncation_psi argument.
truncation_psi = 0.5
# Number of optimisation steps of the training loop.
n_iters = 20001
# Number of latent codes sampled per optimisation step.
batch_size = 8
# Only used by the disabled random initialisation of `d` below and in the run name.
random_start_seed = 0
# Offset added to the iteration index to seed the per-iteration RandomState.
iter_start_seed = 0
# Alternative initialisation (disabled): small Gaussian noise instead of zeros.
#d = torch.Tensor(np.random.RandomState(random_start_seed).normal(0, 0.002, [1, 512])).to(device)
# The direction that is learned: a single (1, 512) vector, initialised to zero,
# that the training loop scales and adds to the w latents.
d = torch.Tensor(np.zeros([1, 512])).to(device)
# Column of the regressor output that is edited; 31 is 'Smiling' in the
# attribute index list at the top of the notebook.
target_feature_index = 31
# Weights of the three loss terms summed in the training loop.
lambda_regressor = 10.0
lambda_content = 0.05
lambda_gan = 0.05
# Learning rate of the Adam optimiser that updates `d`.
learning_rate = 0.0001

# Passed to G.synthesis as its noise_mode argument.
noise_mode = 'const'


In [None]:
# Name of this run. It spells out every hyperparameter so that runs with
# different settings get their own folder.
run_name = f'eye_start_zero_start_seed_{random_start_seed}_iter_start_seed_{iter_start_seed}_lambda_regressor_{lambda_regressor}_lambda_content_{lambda_content}_lambda_gan_{lambda_gan}_feature_{target_feature_index}_lr_{learning_rate}_batch_size{batch_size}'
logging_folder = stylegan2_path + '/training_runs/stylegan2/' + run_name

# exist_ok=True replaces the check-then-create pattern: re-running the cell is
# harmless and there is no window between the check and the creation.
os.makedirs(logging_folder, exist_ok=True)
os.makedirs(logging_folder + '/saved_latent_vecs', exist_ok=True)

# Append a header line identifying the run; the training loop appends one
# 'loss: ...' line per iteration to the same file.
with open(logging_folder + "/training_log.txt", "a") as myfile:
    myfile.write('./training_runs/stylegan2/' + run_name + '\n')


In [None]:
class Normalization(nn.Module):
    """Channel-wise image normalisation: ``(img - mean) / std``.

    The three-channel statistics are the ImageNet values that torchvision's
    pretrained classification models are trained with. They are stored with
    shape ``(3, 1, 1)`` so they broadcast over the spatial dimensions of a
    ``(..., 3, H, W)`` input.

    The statistics are registered as non-persistent buffers: they follow the
    module through ``.to(...)`` / ``.cuda()`` instead of being pinned to a
    global device, and they do not show up in ``state_dict()``.
    """

    def __init__(self):
        super().__init__()
        mean = torch.tensor([0.485, 0.456, 0.406])
        std = torch.tensor([0.229, 0.224, 0.225])

        # view(-1, 1, 1): one value per channel, broadcast over height and width.
        self.register_buffer('mean', mean.view(-1, 1, 1), persistent=False)
        self.register_buffer('std', std.view(-1, 1, 1), persistent=False)

    def forward(self, img):
        """Return ``img`` normalised per channel; the input is not modified."""
        # Aligning to the input's device is a no-op when the module was already
        # moved there, and keeps callers that never call .to() working.
        mean = self.mean.to(img.device)
        std = self.std.to(img.device)
        return (img - mean) / std

In [None]:
def get_bce_loss(pred, y, eps=1e-12):
    """Binary cross-entropy between probabilities and (possibly soft) targets.

    Args:
        pred: tensor of predicted probabilities.
        y: tensor of targets, broadcastable against ``pred``.
        eps: lower bound applied to ``pred`` and ``1 - pred`` before taking the
            logarithm, so that a prediction of exactly 0 or 1 stays finite.

    Returns:
        The mean over all elements of
        ``-(y * log(pred) + (1 - y) * log(1 - pred))``.
    """
    log_positive = pred.clamp(min=eps).log()
    log_negative = (1 - pred).clamp(min=eps).log()
    log_likelihood = y * log_positive + (1 - y) * log_negative
    return -log_likelihood.mean()


# Cross-entropy on raw logits; the training loop applies it to the
# discriminator output.
BCE_loss_logits = nn.BCEWithLogitsLoss()

# Generator and Discriminator

In [None]:
network_pkl = './pretrained_models/ffhq.pkl'

# Unpickle the snapshot once and take both networks from the resulting dict,
# instead of opening and deserialising the same file a second time.
with dnnlib.util.open_url(network_pkl) as f:
    networks = legacy.load_network_pkl(f)

G = networks['G_ema'].to(device) # type: ignore
D = networks['D'].to(device) # type: ignore
# Drop the reference to the dict so entries that are not used here can be freed.
del networks

# Conditioning label handed to G.mapping and D: all zeros, one row, G.c_dim columns.
label = torch.zeros([1, G.c_dim], device=device)

# Content Loss

In [None]:
import torchvision.models as models
# Fixed feature extractor for the content loss. Its weights are never trained
# here, so they are frozen: backward() then skips computing and accumulating
# gradients for them, while gradients still flow through to the input images.
vgg19 = models.vgg19(pretrained=True).features.to(device).eval().requires_grad_(False)

def get_content_loss(org_img, shifted_img):
    """Return the mean VGG-19 feature distance between two image batches.

    Both batches are normalised with ``Normalization`` and pushed layer by
    layer through the module-level ``vgg19`` feature stack. Directly after each
    of the first four convolutions (``conv_1`` .. ``conv_4``, before the ReLU
    that follows) the MSE between the two feature maps is recorded. The result
    is the arithmetic mean of those four values.

    Args:
        org_img: reference images; treated as a constant (no gradient flows
            into them).
        shifted_img: edited images; gradients flow back through them.

    Returns:
        A 0-dim tensor holding the averaged content loss.

    Raises:
        RuntimeError: if a layer of an unsupported type is met before
            ``conv_4`` has been reached.
    """
    content_layers = ('conv_1', 'conv_2', 'conv_3', 'conv_4')
    norm = Normalization().to(device)

    # Each batch is carried through the network exactly once; the features of
    # one content layer are the input to the next layer.
    org_feat = norm(org_img.detach())
    shifted_feat = norm(shifted_img)

    conv_count = 0
    content_losses = []
    for layer in vgg19.children():
        if isinstance(layer, nn.Conv2d):
            conv_count += 1
            name = 'conv_{}'.format(conv_count)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(conv_count)
            # An in-place ReLU would overwrite the conv output that the MSE
            # above still needs for its backward pass.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(conv_count)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(conv_count)
        else:
            raise RuntimeError('Unrecognized layer: {}'
                               .format(layer.__class__.__name__))

        org_feat = layer(org_feat)
        shifted_feat = layer(shifted_feat)

        if name in content_layers:
            content_losses.append(F.mse_loss(org_feat.detach(), shifted_feat))
            if len(content_losses) == len(content_layers):
                # Nothing after the last content layer contributes to the loss.
                break

    # Plain average of the per-layer losses. The previous accumulation loop
    # divided the running value on every pass and started from the last
    # layer's loss, so it did not produce the mean.
    return torch.stack(content_losses).mean()

# Regressor

In [None]:
# The string below is a disabled earlier loader, kept for reference only (it is
# a bare string expression and has no effect). It built a torch.hub resnet50
# with a 40-output linear head and loaded its weights from './003_dict.model'.
'''def get_reg_module():
    # Scene/Face, hard code resnet50 here
    model = torch.hub.load('pytorch/vision:v0.5.0', 'resnet50', pretrained=False)
    model.fc = torch.nn.Linear(2048, 40)
    model = model.cuda()
    ckpt = torch.load('./003_dict.model')
    model.load_state_dict(ckpt['model'])
    """
    If fine-tune or jointly train the classifier
    """
    # optimizer.load_state_dict(ckpt['optm'])
    # return model, optimizer
    return model, None

regressor, _ = get_reg_module()
regressor.eval()
'''
# Active loader: a TorchScript module read from disk, switched to eval mode and
# moved to `device`. The training loop indexes its output with
# [:, target_feature_index].
# NOTE(review): presumably one output column per attribute in the index list at
# the top of the notebook -- confirm against the exported model.
regressor = torch.jit.load('./pretrained_models/resnet50.pth').eval().to(device)

In [None]:
# Display the pseudocode figure stored next to the notebook; presumably the
# "Algorithm ... step N" comments in the training loop refer to its numbering.
Image(filename='./jupyter_imgs/enjoy_your_editing_pseudocode.png')

In [None]:
# Only the direction `d` is optimised: it is the sole tensor handed to Adam.
d.requires_grad = True
optimizer = torch.optim.Adam([d], lr=learning_rate, betas=(0.5, 0.99))
time1 = time.time()
log_path = logging_folder + "/training_log.txt"

for i in range(n_iters):
    optimizer.zero_grad()

    # Algorithm image step 2: sample the latents and the requested attribute change.
    # Both are drawn one after the other from a single generator seeded per
    # iteration. Seeding two separate generators with the same value made
    # epsilon a function of the first entries of z.
    rng = np.random.RandomState(i + iter_start_seed)
    z = torch.from_numpy(rng.randn(batch_size, 512)).to(device)
    epsilon = torch.from_numpy((rng.rand(batch_size) - 0.5) * 2).to(device)  # uniform in [-1, 1)

    # Nothing in the reference pass depends on `d`, so no autograd graph is needed for it.
    with torch.no_grad():
        # mapping z --> w, we use w instead of z as latent vec
        w = G.mapping(z, label, truncation_psi=truncation_psi)

        # Algorithm image step 3, computing the image and extract attributes with regressor.
        # The image is resized to 256x256, the size fed to the regressor.
        # NOTE(review): the generator output is passed to the regressor and to
        # VGG without rescaling its value range -- confirm both expect that range.
        img_orig = F.interpolate(G.synthesis(w, noise_mode=noise_mode), size=256)
        alpha = regressor(img_orig)[:, target_feature_index]

        # Algorithm image step 4: clamp the target to [0, 1] and keep the
        # change that is actually achievable.
        delta = torch.clamp(alpha + epsilon, min=0.0, max=1.0) - alpha
        alpha_prime = alpha + delta

    # Algorithm image step 5 and 6 compute shifted img.
    # (1, 512) * (batch, 1) -> (batch, 512); the second unsqueeze lets the same
    # shift broadcast over the per-layer axis of w, whatever its length.
    attribute_vector = d * delta.unsqueeze(1)
    w_shifted = w + attribute_vector.unsqueeze(1)
    img_shifted = G.synthesis(w_shifted, noise_mode=noise_mode)

    # Same resize as for the original image.
    img_shifted_256 = F.interpolate(img_shifted, size=256)

    # Algorithm image step 7
    alpha_shifted = regressor(img_shifted_256)[:, target_feature_index]

    # Algorithm step 8, calculate loss
    content_loss = get_content_loss(img_orig, img_shifted_256)

    regressor_loss = get_bce_loss(alpha_shifted, alpha_prime, eps=1e-12) # paper implementation

    # The discriminator sees the full-resolution edited image and is asked to call it real.
    discriminator_pred = D(img_shifted, c=label)
    y_real = torch.ones_like(discriminator_pred)
    gan_loss = BCE_loss_logits(discriminator_pred, y_real)

    loss = lambda_regressor*regressor_loss + lambda_content*content_loss + lambda_gan*gan_loss

    # Read the scalar once; every .item() call synchronises with the device.
    loss_value = loss.item()
    print('iter: ', i, ' loss:', loss_value)

    # Re-opened every iteration so the log is complete even if the run is interrupted.
    with open(log_path, "a") as myfile:
        myfile.write('loss: ' + str(loss_value) + '\n')

    # Snapshot of the direction before this iteration's update, every 100 iterations.
    if i%100 == 0:
        torch.save(d, logging_folder+'/saved_latent_vecs/latent_vec_' + str(i) + '.pt')

    # Algorithm step 9
    loss.backward()
    optimizer.step()

time2 = time.time()
print('training_time: ', time2 - time1)