# Calculate attribute and identity preservation
### Largely based on the implementation at https://github.com/KelestZ/Latent2im/blob/main/eval.py

In [None]:
from os import listdir
from os.path import isfile, join
from natsort import natsorted
import sys
stylegan2_path = './stylegan2-ada-pytorch'
sys.path.append(stylegan2_path)

import dnnlib
import click
import legacy
import torch
import numpy as np
from PIL import Image
from scipy.spatial.distance import cosine
import torch.nn.functional as F

n_regressor_predictions = 40  # the regressor is pretrained on CelebA and predicts 40 face attributes
device = 'cuda:0'
batch_size = 1
truncation_psi = 0.5
noise_mode = 'const'
network_pkl = './pretrained_models/ffhq.pkl'

# Parse the network pickle once and take both networks from the same result
# (previously the file was opened and unpickled twice).
with dnnlib.util.open_url(network_pkl) as f:
    network = legacy.load_network_pkl(f)
G = network['G_ema'].to(device)  # stylegan2-generator
D = network['D'].to(device)  # discriminator; not referenced in the cells shown here
# Drop the reference to the rest of the pickle contents (presumably other
# training-time objects) so they can be garbage-collected.
del network

# Conditioning label; all-zero with shape [1, G.c_dim] (an unconditional FFHQ model presumably has c_dim == 0).
label = torch.zeros([1, G.c_dim], device=device)

def initialize_model():
    '''
    Build the face-recognition network used to measure identity preservation.

    Returns facenet_pytorch's InceptionResnetV1 with 'vggface2' pretrained weights,
    moved to `device` and switched to eval mode.
    '''
    # Only the recognition network is needed; the unused MTCNN import was dropped.
    from facenet_pytorch import InceptionResnetV1
    return InceptionResnetV1(pretrained='vggface2').to(device).eval()

face_rec = initialize_model()  # face-recognition network used for identity preservation

# TorchScript attribute regressor (scores the CelebA face attributes of an image).
regressor = torch.jit.load('./pretrained_models/resnet50.pth')
regressor = regressor.eval().to(device)

# 10 target values for the edited attribute, evenly spaced in [0, 1], just like in Enjoy your Editing.
scaling_coeffs_eval = np.linspace(0.0, 1.0, num=10)

In [None]:
## Index of each target feature in the regressor output (the 40 CelebA attributes, in this order).
## Look an index up by name, e.g. CELEBA_ATTRIBUTE_INDEX['Smiling'] -> 31.
## (The previous listing ended with `'Young' 39`, which is a SyntaxError, so this cell could not run.)
CELEBA_ATTRIBUTES = [
    '5_o_Clock_Shadow',     # 0
    'Arched_Eyebrows',      # 1
    'Attractive',           # 2
    'Bags_Under_Eyes',      # 3
    'Bald',                 # 4
    'Bangs',                # 5
    'Big_Lips',             # 6
    'Big_Nose',             # 7
    'Black_Hair',           # 8
    'Blond_Hair',           # 9
    'Blurry',               # 10
    'Brown_Hair',           # 11
    'Bushy_Eyebrows',       # 12
    'Chubby',               # 13
    'Double_Chin',          # 14
    'Eyeglasses',           # 15
    'Goatee',               # 16
    'Gray_Hair',            # 17
    'Heavy_Makeup',         # 18
    'High_Cheekbones',      # 19
    'Male',                 # 20
    'Mouth_Slightly_Open',  # 21
    'Mustache',             # 22
    'Narrow_Eyes',          # 23
    'No_Beard',             # 24
    'Oval_Face',            # 25
    'Pale_Skin',            # 26
    'Pointy_Nose',          # 27
    'Receding_Hairline',    # 28
    'Rosy_Cheeks',          # 29
    'Sideburns',            # 30
    'Smiling',              # 31
    'Straight_Hair',        # 32
    'Wavy_Hair',            # 33
    'Wearing_Earrings',     # 34
    'Wearing_Hat',          # 35
    'Wearing_Lipstick',     # 36
    'Wearing_Necklace',     # 37
    'Wearing_Necktie',      # 38
    'Young',                # 39
]
CELEBA_ATTRIBUTE_INDEX = {name: idx for idx, name in enumerate(CELEBA_ATTRIBUTES)}

In [None]:
# Upper edges of the attribute-change buckets (0-0.3, 0.3-0.6, 0.6-0.9); larger changes are discarded.
_BUCKET_EDGES = (0.3, 0.6, 0.9)


def _bucket_index(attr_change):
    '''Return the index of the first bucket whose upper edge is >= attr_change, or None if it exceeds all edges.'''
    for k, edge in enumerate(_BUCKET_EDGES):
        if attr_change <= edge:
            return k
    return None


def _face_embedding(img):
    '''
    Return the face_rec embedding of one generated image as a flat (1-D) numpy vector.

    img: torch Tensor; the `[0]` index and (1, 2, 0) transpose below assume a single image of
         shape [1, 3, H, W] with values roughly in [-1, 1] — TODO confirm for new callers.
    The image is mapped to uint8 [0, 255], resized to 160x160 with PIL and fed to face_rec.
    '''
    img_np = np.uint8(np.clip(((img.cpu().numpy() + 1) / 2.0) * 255, 0, 255))
    pil_img = Image.fromarray(np.transpose(img_np[0], (1, 2, 0))).resize((160, 160))
    face_input = torch.Tensor(np.transpose(np.array(pil_img), (2, 0, 1))).to(device).unsqueeze(0)
    # NOTE(review): face_rec receives raw 0-255 pixel values (unchanged from the reference
    # implementation); confirm whether the face-recognition model expects standardized input.
    # ravel(): scipy's `cosine` requires 1-D vectors, while the output presumably keeps the batch dim of 1.
    return face_rec(face_input).detach().cpu().numpy().ravel()


def get_attribute_preservation_identity_preservation_buckets(attr_vec, target_attr_index, n_seeds=1000):
    '''
    Evaluate identity and attribute preservation of an attribute vector, bucketed by attribute change.

    attr_vec: torch Tensor of shape [1, 512], which should change a semantically meaningful attribute
              in the stylegan2-latent space
    target_attr_index: int, which defines the attribute which should be changed, the index for each feature
              is listed in the attribute table cell
    n_seeds: int, number of random seeds (0 .. n_seeds-1) used to generate original images; the default
              of 1000 gives the 1000 original / 10,000 shifted images of the Enjoy your Editing protocol

    For every seed an original image is generated and then edited once per value in
    scaling_coeffs_eval. Each edit is assigned to a bucket by the absolute change of the target
    attribute as predicted by the regressor (0-0.3, 0.3-0.6, 0.6-0.9; larger changes are dropped).
    Per bucket, it prints:
      * identity preservation: mean of (1 - cosine distance) between face_rec embeddings of the
        original and edited image, and the std of the cosine distance;
      * attribute preservation: mean / std of the absolute change of all regressor outputs EXCEPT
        the target attribute.
    Buckets without samples are reported as nan so that every printed list has one entry per bucket.

    Returns interval_counter, the number of samples in each bucket. To have the same amount of
    attribute change for a fair comparison of different approaches, vectors are normalized so that
    the first element matches, as described in our paper.
    '''
    n_buckets = len(_BUCKET_EDGES)
    id_dists = [[] for _ in range(n_buckets)]  # cosine distances (original vs. edited embedding)
    pred_orig_segments = [[] for _ in range(n_buckets)]  # regressor outputs of the originals, 1-D rows
    pred_shifted_segments = [[] for _ in range(n_buckets)]  # regressor outputs of the edits, 1-D rows

    with torch.no_grad():
        for random_seed in range(n_seeds):
            z = torch.from_numpy(np.random.RandomState(random_seed).randn(batch_size, 512)).to(device)
            w = G.mapping(z, label, truncation_psi=truncation_psi)
            img_orig = G.synthesis(w, noise_mode=noise_mode)
            img_orig = F.interpolate(img_orig, size=256)  # resize to 256x256 for the regressor
            pred_orig = regressor(img_orig).detach().cpu().numpy()
            target_orig = pred_orig[0, target_attr_index]
            embed_orig = None  # computed lazily, at most once per seed

            # as defined in the enjoy your editing paper we create 1000 original and 10,000 shifted images
            # with an increasing amount of attribute change
            for alpha in scaling_coeffs_eval:
                delta = alpha - target_orig
                img_shifted = G.synthesis(w + attr_vec * delta, noise_mode=noise_mode)
                img_shifted = F.interpolate(img_shifted, size=256)
                pred_shifted = regressor(img_shifted).detach().cpu().numpy()

                k = _bucket_index(np.abs(pred_shifted[0, target_attr_index] - target_orig))
                if k is None:
                    continue
                # Embed right away instead of keeping every image on the GPU until the end.
                if embed_orig is None:
                    embed_orig = _face_embedding(img_orig)
                id_dists[k].append(cosine(_face_embedding(img_shifted), embed_orig))
                pred_orig_segments[k].append(pred_orig[0])
                pred_shifted_segments[k].append(pred_shifted[0])

    interval_counter = [len(rows) for rows in pred_orig_segments]

    # Identity preservation: `cosine` is a distance, so similarity = 1 - distance.
    id_avg, id_std = [], []
    for dists in id_dists:
        if dists:
            id_avg.append(1 - np.mean(dists))
            id_std.append(np.std(dists))
        else:
            id_avg.append(float('nan'))
            id_std.append(float('nan'))

    print('[IDENTITY PRESERVATION MEAN] Results on 3 epsilon segments', ['%.4f' % i for i in id_avg])
    print('[IDENTITY PRESERVATION STD ] Results on 3 epsilon segments', ['%.4f' % i for i in id_std])

    # Attribute preservation: compare every regressor output except the target attribute.
    attr_avg, attr_std = [], []
    for orig_rows, shifted_rows in zip(pred_orig_segments, pred_shifted_segments):
        if not orig_rows:
            attr_avg.append(float('nan'))
            attr_std.append(float('nan'))
            continue
        # Stack to [N, n_attributes] before dropping the target column. Stacking the raw
        # [1, n_attributes] predictions gives [N, 1, n_attributes], where slicing axis 1 does
        # not remove the target attribute.
        orig = np.delete(np.stack(orig_rows), int(target_attr_index), axis=1)
        changed = np.delete(np.stack(shifted_rows), int(target_attr_index), axis=1)
        abs_change = np.abs(changed - orig)
        attr_avg.append(abs_change.mean())
        attr_std.append(abs_change.std())

    print('[ATTRIBUTE PRESERVATION MEAN] Results on 3 epsilon segments', ['%.4f' % i for i in attr_avg])
    print('[ATTRIBUTE PRESERVATION STD ] Results on 3 epsilon segments', ['%.4f' % i for i in attr_std])

    print('intervall_counter ', interval_counter)
    print('===========================================')
    print('attribute vector length', attr_vec.norm())
    return interval_counter


In [None]:
def get_attr_vector_scaling(attr_vec, reference_bucket, stepsize, optimization_direction, scaling_coeff,
                            attr_index=None, max_iters=20):
    '''
    Search for a scaling of attr_vec whose first interval_counter bucket matches a reference.

    attr_vec: torch Tensor of shape [1, 512], which should change a semantically meaningful attribute
              in the stylegan2-latent space
    reference_bucket: int, first element of a reference interval_counter; the search stops once
              interval_counter[0] of the scaled vector is within reference_bucket +/- 1%
    stepsize: float, initial step by which scaling_coeff is changed; it is halved every time the
              search reverses direction
    optimization_direction: int, +1 or -1, direction of the previous step (+1: the last step
              decreased scaling_coeff, -1: it increased it); used to detect a reversal
    scaling_coeff: float, initial scaling coefficient, optimized by this function
    attr_index: int or None, index of the target attribute (see the attribute table cell). When None,
              the notebook-level `target_attr_index` is used, which is what this function always did.
    max_iters: int, maximum number of bucket evaluations (each one is a full evaluation run)

    If interval_counter[0] is below the reference, the scaling is decreased (fewer samples in the
    smallest-change bucket presumably means the edit changes the attribute too much), otherwise it
    is increased. Progress is printed after each step.

    Returns the final scaling_coeff. If the search does not converge within max_iters, this is the
    next candidate after the last evaluated one.
    '''
    if attr_index is None:
        attr_index = target_attr_index  # notebook-level global set in the calibration cell
    tolerance = reference_bucket * 0.01
    interval_counter = None
    for _ in range(max_iters):
        interval_counter = get_attribute_preservation_identity_preservation_buckets(attr_vec * scaling_coeff, attr_index)
        first_bucket = interval_counter[0]
        # `<=` so an exact match also terminates when reference_bucket is 0 (tolerance 0).
        if np.abs(first_bucket - reference_bucket) <= tolerance:
            break
        elif first_bucket < reference_bucket:
            scaling_coeff -= stepsize
            if optimization_direction == -1:  # reversed direction: refine the step
                stepsize /= 2
            optimization_direction = 1
        else:
            scaling_coeff += stepsize
            if optimization_direction == 1:  # reversed direction: refine the step
                stepsize /= 2
            optimization_direction = -1
        print(stepsize, interval_counter, reference_bucket, scaling_coeff)
    print(stepsize, interval_counter, reference_bucket, scaling_coeff)
    return scaling_coeff


In [None]:
# Attribute vectors compared below, loaded onto `device`.
# The Shen vectors are stored as .npy arrays; a leading batch axis is added and they are
# converted to float32 directly (building a tensor from a Python list of ndarrays, as
# `torch.Tensor([d_np])` did, takes PyTorch's slow element-wise path).
# smiling
attr_vec_zhuang_smile = torch.load("./stylegan2-ada-pytorch/training_runs/stylegan2/eye_start_zero_start_seed_0_iter_start_seed_0_lambda_regressor_10.0_lambda_content_0.05_lambda_gan_0.05_feature_31_lr_0.0001_batch_size1/saved_latent_vecs/latent_vec_20000.pt", map_location=device)
d_np = np.load('./stylegan2-ada-pytorch/training_runs/stylegan2/shen/shen_smiling_w.npy')
attr_vec_shen_smile = torch.tensor(d_np[np.newaxis], dtype=torch.float32, device=device)
attr_vec_ours_bs1_smile = torch.load("./stylegan2-ada-pytorch/training_runs/stylegan2/our_approach_feature_31_maxLenght_0.8_lr_0.0003_batch_size1/saved_latent_vecs/latent_vec_61000.pt", map_location=device)
attr_vec_ours_bs8_smile = torch.load("./stylegan2-ada-pytorch/training_runs/stylegan2/our_approach_feature_31_maxLenght_0.8_lr_0.0003_batch_size8/saved_latent_vecs/latent_vec_12000.pt", map_location=device)
###########################################
# hair color
attr_vec_zhuang_hair = torch.load("./stylegan2-ada-pytorch/training_runs/stylegan2/eye_start_zero_start_seed_0_iter_start_seed_20000_lambda_regressor_10.0_lambda_content_0.05_lambda_gan_0.05_feature_9_lr_0.0001_batch_size1/saved_latent_vecs/latent_vec_20000.pt", map_location=device)
d_np = np.load('./stylegan2-ada-pytorch/training_runs/stylegan2/shen/shen_hair_col_9_w.npy')
attr_vec_shen_hair = torch.tensor(d_np[np.newaxis], dtype=torch.float32, device=device)
attr_vec_ours_bs1_hair = torch.load("./stylegan2-ada-pytorch/training_runs/stylegan2/our_approach_feature_9_maxLenght_0.8_lr_0.0003_batch_size1_seed_offset_0/saved_latent_vecs/latent_vec_61000.pt", map_location=device)
attr_vec_ours_bs8_hair = torch.load("./stylegan2-ada-pytorch/training_runs/stylegan2/our_approach_feature_9_maxLenght_0.8_lr_0.0003_batch_size8_seed_offset_0/saved_latent_vecs/latent_vec_12000.pt", map_location=device)


In [None]:
# Target attribute: 31 = Smiling (uncomment the next line for 9 = hair colour).
target_attr_index = 31
#target_attr_index = 9 # hair_col

print("Get the first Bucket of Zhuang for the attribute Smiling as a reference to scale the other vectors")
buckets_zhuang = get_attribute_preservation_identity_preservation_buckets(attr_vec_zhuang_smile, target_attr_index)
reference_bucket = buckets_zhuang[0]

# Search settings for get_attr_vector_scaling: initial step size, initial direction flag, initial scale.
stepsize, optimization_direction, scaling_coeff = 0.5, 1, 4.0

print("Find the right scaling such that the first bucket has as many samples as Zhuang +/- 1%")
get_attr_vector_scaling(attr_vec_shen_smile, reference_bucket, stepsize, optimization_direction, scaling_coeff)

In [None]:
# Evaluate every attribute vector with a hard-coded scaling (presumably found with
# get_attr_vector_scaling against Zhuang's first bucket), so that all approaches have roughly
# as many samples in the first bucket. Each run gets a distinct label; previously both
# "Ours" hair-colour runs were printed as "Ours Hair9" and could not be told apart.
evaluation_runs = [
    # (label, attribute vector, scaling, target attribute index)
    # smiling
    ("Zhuang smiling", attr_vec_zhuang_smile, 1.0, 31),
    ("Shen smiling", attr_vec_shen_smile, 1.31, 31),
    ("Ours Batch Size=1 smiling", attr_vec_ours_bs1_smile, 1.75, 31),
    ("Ours Batch Size=8 smiling", attr_vec_ours_bs8_smile, 1.625, 31),
    # hair color
    ("Zhuang Hair9", attr_vec_zhuang_hair, 1.0, 9),
    ("Shen Hair9", attr_vec_shen_hair, 2.4375, 9),
    ("Ours Batch Size=1 Hair9", attr_vec_ours_bs1_hair, 4.0, 9),
    ("Ours Batch Size=8 Hair9", attr_vec_ours_bs8_hair, 3.625, 9),
]

for run_label, run_vec, run_scaling, run_attr_index in evaluation_runs:
    print(run_label)
    get_attribute_preservation_identity_preservation_buckets(run_vec * run_scaling, run_attr_index)
    print("----------------------------------------")
    print(' ')

# Visualize the influence of an attribute vector

In [None]:
def normalize_tensor(tensor):
    '''
    Min-max rescale a copy of `tensor` into [0, 1) for display.

    The input tensor is not modified. A small epsilon (1e-5) in the denominator avoids
    division by zero for constant tensors.
    '''
    out = tensor.clone()  # work on a copy so the caller's tensor stays untouched
    lo, hi = float(out.min()), float(out.max())
    out.clamp_(min=lo, max=hi)
    out.add_(-lo).div_(hi - lo + 1e-5)
    return out

In [None]:
import matplotlib.pyplot as plt
random_seed = 0
truncation_psi = 0.5
noise_mode = 'const'
label = torch.zeros([1, G.c_dim], device=device)

# Generate the unedited image for one seed. This is inference only, so no autograd graph is needed.
with torch.no_grad():
    z = torch.from_numpy(np.random.RandomState(random_seed).randn(1, G.z_dim)).to(device)
    w = G.mapping(z, label, truncation_psi=truncation_psi)
    img_pt1 = G.synthesis(w, noise_mode=noise_mode)

img = normalize_tensor(img_pt1[0])
# detach() like the edited-image cell: numpy() refuses tensors that require grad.
img = img.detach().cpu().numpy()
img = np.transpose(img, (1, 2, 0))  # CHW -> HWC for imshow
plt.imshow(img)

In [None]:
# Same latent as above, edited with the batch-size-8 hair-colour vector at its scaling of 3.625.
img_pt1 = G.synthesis(w + attr_vec_ours_bs8_hair * 3.625, noise_mode=noise_mode)
img = np.rollaxis(normalize_tensor(img_pt1[0]).detach().cpu().numpy(), 0, 3)  # CHW -> HWC for imshow
plt.imshow(img)