In [None]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":16:8"

from os import listdir
from os.path import isfile, join
import sys
import pickle
import dnnlib
import click
import legacy
import torch
import numpy as np
from PIL import Image
from scipy.spatial.distance import cosine
import torch.nn.functional as F
from training.networks import SynthesisNetwork,SynthesisBlock,SynthesisLayer,ToRGBLayer
from torch_utils import misc
import types

n_regressor_predictions = 40 # the regressor is pretrained on CelebA and predicts 40 face attributes
device = 'cuda:0'
batch_size = 1
truncation_psi = 0.5
noise_mode = 'const'
network_pkl =  '../pretrained_models/ffhq.pkl'

# Target change of the edited attribute's regressor output, measured *before* the sigmoid.
# We want the sigmoid output to stay within 0.1..0.9. Starting from a sigmoid output of 0.5
# (sigmoid input 0), the attribute should therefore move by 0.4, i.e. to 0.9 (or 0.1), which
# corresponds to a sigmoid input of logit(0.9) = ln(0.9 / 0.1) = ln(9) ~= 2.197.
goal_attr_change = 2.197
target_attr = 31  # default attribute index; the evaluation calls in this notebook pass their index explicitly
# Scaling-factor search (get_scaling_factor_s):
stepsize = 0.5     # initial step; halved whenever the search changes direction
break_delta=0.01   # stop once |mean attribute change - goal_attr_change| < break_delta
max_iter=50        # iteration limit for the search


def initialize_model():
    """Build the face-recognition network used to measure identity preservation.

    Returns:
        A facenet_pytorch ``InceptionResnetV1`` with VGGFace2 weights, moved to
        the module-level ``device`` and switched to eval mode.
    """
    # Only the recognition network is needed here (no MTCNN face detection).
    from facenet_pytorch import InceptionResnetV1
    resnet = InceptionResnetV1(pretrained='vggface2').to(device).eval()
    return resnet

face_rec = initialize_model()

# Attribute regressor stored as a pickled model. NOTE: pickle.load can execute arbitrary
# code from the file, so only load checkpoints from a trusted source.
# The `with` block closes the file even if unpickling fails.
with open("../pretrained_models/resnet_092_all_attr_5_epochs.pkl", "rb") as file_to_read:
    regressor = pickle.load(file_to_read)
regressor.eval()
regressor = regressor.to(device)

def get_face_embed(img):
    """Return the ``face_rec`` embedding of the first image in the batch ``img``.

    The image is mapped to uint8 via ``(x + 1) / 2 * 255`` with clipping (so it
    is assumed to lie roughly in [-1, 1] - TODO confirm). It is then resized to
    160x160 with PIL, turned back into a float tensor on ``device`` with a batch
    dimension of 1, and passed to ``face_rec``.

    NOTE(review): face_rec receives raw 0..255 pixel values here. facenet_pytorch's
    InceptionResnetV1 is usually fed standardized input ((x - 127.5) / 128), so
    confirm this is intended.
    """
    pixels = ((img.detach().cpu().numpy() + 1) / 2.0) * 255
    # Batch element 0 only, NCHW -> HWC for PIL.
    first_hwc = np.transpose(np.uint8(np.clip(pixels, 0, 255))[0], (1, 2, 0))
    resized_hwc = np.array(Image.fromarray(first_hwc).resize((160, 160)))
    # HWC -> CHW and restore the batch dimension.
    face_input = torch.Tensor(np.transpose(resized_hwc, (2, 0, 1))).to(device).unsqueeze(0)
    return face_rec(face_input)

In [None]:

def LoadModel(network_pkl,device):
    """Load the ``G_ema`` generator from ``network_pkl`` and patch its synthesis methods.

    The generator is loaded with ``requires_grad_(False)`` and moved to ``device``.
    Afterwards the following methods are rebound with ``types.MethodType`` to the
    classes imported from ``training.networks``:
      * the synthesis network's ``forward`` and ``W2S``,
      * every synthesis block's ``forward``,
      * the ``forward`` of each block's conv/toRGB layers.
    Each patched layer also gets a ``name`` attribute of the form
    ``'<layer>_resolution_<res>'``.

    NOTE(review): presumably this makes the pickled network use this repo's modified
    implementations (style-space ``W2S`` / ``encoded_styles`` support). Confirm.

    Returns:
        The patched generator.
    """
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore

    synthesis = G.synthesis
    synthesis.forward = types.MethodType(SynthesisNetwork.forward, synthesis)
    synthesis.W2S = types.MethodType(SynthesisNetwork.W2S, synthesis)

    for res in synthesis.block_resolutions:
        block = getattr(synthesis, f'b{res}')
        block.forward = types.MethodType(SynthesisBlock.forward, block)

        # (child attribute, class providing forward, layer-name prefix);
        # conv0 is patched only for resolutions other than 4.
        patches = [('conv1', SynthesisLayer, 'conv1_resolution_'),
                   ('torgb', ToRGBLayer, 'toRGB_resolution_')]
        if res != 4:
            patches.insert(0, ('conv0', SynthesisLayer, 'conv0_resolution_'))

        for child_name, impl_cls, name_prefix in patches:
            layer = getattr(block, child_name)
            layer.forward = types.MethodType(impl_cls.forward, layer)
            layer.name = name_prefix + str(res)

    return G

G = LoadModel(network_pkl=network_pkl, device=device)
# All-zero conditioning label (shape (1, G.c_dim)) passed to every G.mapping call.
label = torch.zeros((1, G.c_dim), device=device)

def add_attr_vec(encoded_styles, attr_vec, epsilon):
    """Shift every style tensor by its slice of a flattened direction.

    ``attr_vec`` is read column-wise in the key order of ``encoded_styles``. Each
    entry takes the next ``style.shape[1]`` columns, and
    ``style + slice * epsilon`` is stored under the same key in a new dict.
    The input dict is left unmodified.
    """
    shifted_styles = {}
    offset = 0
    for name, style in encoded_styles.items():
        width = style.shape[1]
        shifted_styles[name] = style + attr_vec[0:, offset:offset + width] * epsilon
        offset += width
    return shifted_styles

In [None]:
def eval_attr_vec_s(l_vec, target_attr, num_samples):
    """Measure how well a style-space direction edits one attribute on generated faces.

    For each seed ``j`` in ``range(num_samples)``:
      1. A face is generated.
      2. The regressor prediction for ``target_attr`` is read.
      3. ``l_vec``, scaled by a per-seed uniform factor in [0, 1), is subtracted
         when that prediction is positive and added otherwise.
      4. The original and edited images are compared.

    Args:
        l_vec: flattened style-space direction, laid out in the key order of
            ``G.synthesis.W2S(w)`` (see ``add_attr_vec``).
        target_attr: index of the regressor output being edited.
        num_samples: number of seeds to evaluate.

    Returns:
        ``(target_change, target_change_std, attr_dist, attr_dist_std,
        arcf_dist, arcf_dist_std)``:
          * ``target_change``: mean absolute change of the target prediction.
          * ``attr_dist``: mean absolute change of all other predictions.
          * ``arcf_dist``: mean cosine distance between the face-recognition
            embeddings of the original and edited images.
        Each mean is followed by its standard deviation over the samples.
    """
    pred_diff_list = []
    embedding_distance_list = []
    attr_pres_list = []

    # Evaluation only: no gradients are needed for G, the regressor or face_rec.
    with torch.no_grad():
        for j in range(num_samples):
            z = torch.from_numpy(np.random.RandomState(j).randn(1, G.z_dim)).to(device)
            w = G.mapping(z, label, truncation_psi=truncation_psi)
            img_orig = G.synthesis(w, noise_mode=noise_mode)
            img_orig = F.interpolate(img_orig, size=256)  # reshape the image for the face_rec and regressor
            pred_orig = regressor(img_orig).detach().cpu().numpy()
            embedding_orig = get_face_embed(img_orig)
            # NOTE(review): RandomState(j) is the same seed used for z above, so this scale
            # is drawn from the start of the same random stream. Confirm this is intended.
            random_scaling = torch.from_numpy(np.random.RandomState(j).rand(1, 1)).to(device)

            # Push the attribute across 0: subtract the direction when the original
            # prediction is positive, add it otherwise.
            sign = -1 if pred_orig[0, target_attr] > 0.0 else 1
            s_new = add_attr_vec(G.synthesis.W2S(w), l_vec, sign * random_scaling)
            # Same noise mode as the original image so the two are directly comparable.
            img_shifted = G.synthesis(None, encoded_styles=s_new, noise_mode=noise_mode)
            img_shifted = F.interpolate(img_shifted, size=256)  # reshape the image for the face_rec and regressor
            pred_shifted = regressor(img_shifted).detach().cpu().numpy()

            # Identity preservation: cosine distance between face embeddings.
            # scipy's cosine expects 1-D vectors, so the per-image embeddings are flattened.
            embedding_shifted = get_face_embed(img_shifted)
            embedding_distance_list.append(cosine(
                embedding_orig.detach().cpu().numpy().ravel(),
                embedding_shifted.detach().cpu().numpy().ravel(),
            ))

            # Attribute change.
            pred_diff_list.append(np.abs(pred_shifted[0, target_attr] - pred_orig[0, target_attr]))

            # Attribute preservation: mean absolute change of all non-target predictions.
            other_pred_org = np.delete(pred_orig, int(target_attr), axis=1)
            other_pred_shifted = np.delete(pred_shifted, int(target_attr), axis=1)
            attr_pres_list.append(np.mean(np.abs(other_pred_org - other_pred_shifted)))

    target_change = np.array(pred_diff_list).mean()
    attr_dist = np.array(attr_pres_list).mean()
    # Mean over every sample's distance, not just the last one.
    arcf_dist = np.array(embedding_distance_list).mean()
    target_change_std = np.array(pred_diff_list).std()
    attr_dist_std = np.array(attr_pres_list).std()
    arcf_dist_std = np.array(embedding_distance_list).std()
    return target_change, target_change_std, attr_dist, attr_dist_std, arcf_dist, arcf_dist_std

In [None]:
def get_pred_from_img_s(s, target_attr):
    """Synthesize an image from the style dict ``s`` and return one regressor output.

    The image is generated with constant noise and resized to 256 before it is
    passed to the regressor. The returned value is the raw output at index
    ``target_attr``; the config treats it as a pre-sigmoid value.
    """
    synthesized = G.synthesis(None, encoded_styles=s, noise_mode='const')
    # reshape the image for the face_rec and regressor
    resized = F.interpolate(synthesized, size=256)
    return regressor(resized).detach().cpu().numpy()[0, target_attr]

def get_scaling_factor_s(l_vec1, goal_attr_change, target_attr, batch_size, scaling_factor, stepsize, break_delta, max_iter):
    """Search for the scale of ``l_vec1`` that gives a target mean attribute change.

    Seeds ``0 .. batch_size - 1`` are used. For the current ``scaling_factor`` the
    direction is applied to every sample: it is subtracted when the original
    prediction for ``target_attr`` is positive and added otherwise. The mean
    absolute change of that prediction is then compared with ``goal_attr_change``:
      * if the change is too small, the factor is increased by ``stepsize``;
      * if it is too large, the factor is decreased by ``stepsize``.
    ``stepsize`` is halved each time the search direction flips.

    Returns:
        The first factor whose mean change is within ``break_delta`` of
        ``goal_attr_change``. If the search does not converge, the factor reached
        after ``max_iter + 1`` evaluations (always at least one) is returned.
    """
    with torch.no_grad():
        # The unedited styles, their predictions and the edit direction do not depend
        # on scaling_factor, so compute them once instead of on every iteration.
        samples = []
        for j in range(batch_size):
            z = torch.from_numpy(np.random.RandomState(j).randn(1, G.z_dim)).to(device)
            w = G.mapping(z, label, truncation_psi=truncation_psi)
            s = G.synthesis.W2S(w)
            pred_orig = get_pred_from_img_s(s, target_attr)
            direction = -1 if pred_orig > 0.0 else 1
            samples.append((s, pred_orig, direction))

        scaling_direction_flag = 0  # +1 after an increase, -1 after a decrease
        for _ in range(max(max_iter, 0) + 1):
            target_attr_change_list = [
                get_pred_from_img_s(add_attr_vec(s, l_vec1, direction * scaling_factor), target_attr) - pred_orig
                for s, pred_orig, direction in samples
            ]
            avg_attr_delta = np.abs(np.array(target_attr_change_list)).mean()

            if np.abs(avg_attr_delta - goal_attr_change) < break_delta:
                return scaling_factor

            if avg_attr_delta < goal_attr_change:
                scaling_factor += stepsize
                if scaling_direction_flag == -1:
                    stepsize = stepsize / 2
                scaling_direction_flag = 1
            elif avg_attr_delta > goal_attr_change:
                scaling_factor -= stepsize
                if scaling_direction_flag == 1:
                    stepsize = stepsize / 2
                scaling_direction_flag = -1

    return scaling_factor

In [None]:
def scale_and_eval_s(attr_vec, target_attr, batch_sizes=(8, 32, 128, 1000), num_eval_samples=1000):
    """Calibrate the scale of ``attr_vec`` on growing batches and evaluate after each.

    Starting from a factor of 1, ``get_scaling_factor_s`` is run once per entry of
    ``batch_sizes``, and each run is warm-started from the previous factor. After
    every run, the scaled direction is evaluated on ``num_eval_samples`` seeds with
    ``eval_attr_vec_s`` and the metrics are printed.

    Args:
        attr_vec: flattened style-space direction.
        target_attr: index of the regressor output being edited.
        batch_sizes: sample counts used for the successive calibration runs.
        num_eval_samples: number of seeds used for each evaluation.
    """
    scaling_factor = 1
    for batch_size in batch_sizes:
        scaling_factor = get_scaling_factor_s(attr_vec, goal_attr_change, target_attr, batch_size, scaling_factor, stepsize, break_delta, max_iter)
        print("bs: ", batch_size, "scaling: ", scaling_factor)

        attribute_change, attr_ch_std, attribute_preservation, attr_pres_std, similarity_metric, sim_std = eval_attr_vec_s(scaling_factor*attr_vec, target_attr, num_eval_samples)
        print("attribute_change: ", attribute_change, attr_ch_std)
        print("attribute_preservation: ", attribute_preservation, attr_pres_std)
        print("similarity_metric: ", similarity_metric, sim_std)
        print("-------------------------------")



In [None]:
def load_attribute_vec_into_row(attr_vec_path):
    """Load a per-layer attribute direction from ``.npy`` and flatten it into one row.

    The style dict of a seed-0 sample is generated only to learn the layer order
    and widths. Layer ``j`` contributes ``a[0, j, :width_j]`` to the row, so ``a``
    is assumed to be indexed as (1, n_layers, >= max width). TODO confirm.

    Returns:
        A float32 tensor of shape (1, total style width) on ``device``.
    """
    a = np.load(attr_vec_path)
    z = torch.from_numpy(np.random.RandomState(0).randn(1, G.z_dim)).to(device)
    w = G.mapping(z, label, truncation_psi=truncation_psi)
    s = G.synthesis.W2S(w)

    # Derive the total width from the generator instead of hard-coding it.
    widths = [style.shape[1] for style in s.values()]
    l_vec1 = torch.zeros(1, sum(widths))
    offset = 0
    for j, width in enumerate(widths):
        l_vec1[0:, offset:offset + width] = torch.from_numpy(a[0, j, 0:width])
        offset += width

    return l_vec1.to(device)

# StyleMC directions are loaded with load_attribute_vec_into_row (flattened into one
# style-space row) and then scaled and evaluated on their CelebA attribute index.
for vec_label, vec_path, vec_target in [
    ("l_vec_stylemc_blonde.npy", "../attribute_vectors/l_vec_stylemc_blonde.npy", 9),
    ("l_vec_stylemc_smile.npy", "../attribute_vectors/l_vec_stylemc_smile.npy", 31),
]:
    print(vec_label)
    attr_vec = load_attribute_vec_into_row(vec_path)
    scale_and_eval_s(attr_vec, vec_target)

In [None]:
# LS-StyleEdit directions are loaded as tensors and used directly as flattened
# style-space rows. map_location places them on `device` even if they were saved
# on a different device.
# The hair direction is scored on attribute 9, the same index used for the blonde
# StyleMC direction; scoring it on 31 (smile) measured the wrong attribute.
print("l_vec_LS-StyleEdit_hair9_s.pt")
attr_vec = torch.load("../attribute_vectors/l_vec_LS-StyleEdit_hair9_s.pt", map_location=device)
scale_and_eval_s(attr_vec, 9)

print("l_vec_LS-StyleEdit_smile_s.pt")
attr_vec = torch.load("../attribute_vectors/l_vec_LS-StyleEdit_smile_s.pt", map_location=device)
scale_and_eval_s(attr_vec, 31)


