In [6]:
import os
import pickle
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

import utils.inverter as inv
from models.stylegan_generator_idinvert import StyleGANGeneratorIdinvert

In [7]:
# settings
# Root directory of the FFHQ image folders. Set the FFHQ_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
path_images = os.environ.get('FFHQ_DIR', '/Users/max/Desktop/FFHQ')
# passed to the inverter as `iteration` (presumably optimisation steps per image)
n_iter = 200

In [8]:
# initialize generator & inverter
# A single model name is used for both, so the generator weights and the
# encoder weights loaded by the inverter come from the same pretrained model.
model_name = 'styleganinv_ffhq256'
G = StyleGANGeneratorIdinvert(model_name)
Inverter = inv.StyleGANInverter(G, model_name, iteration=n_iter)

[2020-12-28 14:01:12,859][INFO] Build network for module `generator` in model `styleganinv_ffhq256`.
[2020-12-28 14:01:13,132][INFO] Loading pytorch weights from `models/pretrain/styleganinv_ffhq256_generator.pth`.
[2020-12-28 14:01:13,517][INFO] Successfully loaded!
[2020-12-28 14:01:13,540][INFO] Current `lod` is 0.0.
[2020-12-28 14:01:13,542][INFO] Build network for module `encoder` in model `styleganinv_ffhq256`.
[2020-12-28 14:01:14,718][INFO] Loading pytorch weights from `models/pretrain/styleganinv_ffhq256_encoder.pth`.
[2020-12-28 14:01:16,815][INFO] Successfully loaded!


In [4]:
def preprocess(image, size=256):
    """Convert an image array into a CHW float32 array scaled to [-1, 1].

    Args:
        image: image as returned by ``plt.imread``, either H x W (grayscale)
            or H x W x C. Floating-point inputs are assumed to already lie in
            [0, 1] (``plt.imread`` returns that for PNG files). Integer inputs
            (e.g. uint8 from JPEG files) are divided by their dtype's maximum
            so they land in the same range.
        size: side length of the square output in pixels (default 256).

    Returns:
        C-contiguous float32 array of shape (3, size, size). Values are in
        [-1, 1] for inputs in the documented range.
    """
    image = np.asarray(image)
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if image.shape[2] < 3:
        # grayscale (optionally with alpha): replicate the intensity channel to RGB
        image = np.repeat(image[:, :, :1], 3, axis=2)
    # keep RGB, drop an alpha channel if present
    image = image[:, :, :3]
    if np.issubdtype(image.dtype, np.integer):
        # read the max before the cast: afterwards the dtype is float
        max_value = np.iinfo(image.dtype).max
        image = image.astype(np.float32) / np.float32(max_value)
    else:
        image = image.astype(np.float32)
    # cv2 takes dsize as (width, height); the output is square so order is moot
    image = cv2.resize(image, (size, size))
    image = image * 2 - 1
    return np.ascontiguousarray(image.transpose(2, 0, 1), dtype=np.float32)

In [34]:
# Invert every image: for each one, run the inverter and keep the latent code,
# the reconstructed ("fake") image and the loss it reports.
# Images live in sub-folders of `images_per_folder` files each, named by the
# index of their first image: 00000/00000.png ... 00000/00999.png, 01000/01000.png, ...
n_folders = 11              # 11 x 1000 = 11000 images in total
images_per_folder = 1000

latents = []
fakes = []
losses = []
start = time.time()

for i in range(n_folders):
    folder = os.path.join(path_images, str(i * images_per_folder).zfill(5))
    for j in range(images_per_folder):
        file_path = os.path.join(folder, str(i * images_per_folder + j).zfill(5) + '.png')

        # Load one image at a time rather than materialising a whole folder up front
        # (1000 x 3x256x256 float32 is ~0.8 GB, plus a second copy from np.array).
        real = torch.from_numpy(np.ascontiguousarray(preprocess(plt.imread(file_path)))).unsqueeze(0)

        # create optimized latent code & fake image
        latent, fake, loss = Inverter.invert_offline(image=real)
        # .cpu() is a no-op on CPU tensors and makes .numpy() work on CUDA ones
        latents.append(latent.squeeze().detach().cpu().numpy())
        # NOTE(review): fakes keeps every reconstruction in RAM, presumably
        # 3x256x256 float32 each (~8.6 GB for 11000 images) -- consider saving per folder.
        fakes.append(fake.squeeze().detach().cpu().numpy())
        # NOTE(review): the type of `loss` depends on invert_offline; if it is a
        # CUDA tensor, np.array(losses) below will fail -- TODO confirm.
        losses.append(loss)

latents = np.array(latents)
fakes = np.array(fakes)
losses = np.array(losses)

# measure once so both printed figures describe the same duration
elapsed = time.time() - start
print(elapsed / 60, 'min')
print('=', elapsed / 3600, 'h')

loss_pix: 0.096, loss_feat: 6908.397, loss_reg: 0.095, loss: 0.632: 100%|██████████| 3/3 [00:25<00:00,  8.37s/it]    
loss_pix: 0.047, loss_feat: 4491.565, loss_reg: 0.038, loss: 0.348: 100%|██████████| 3/3 [00:26<00:00,  8.80s/it]
loss_pix: 0.147, loss_feat: 11304.095, loss_reg: 0.067, loss: 0.846: 100%|██████████| 3/3 [00:25<00:00,  8.65s/it]
loss_pix: 0.063, loss_feat: 6518.699, loss_reg: 0.048, loss: 0.484: 100%|██████████| 3/3 [00:25<00:00,  8.40s/it]


1.8976312677065532 min
= 0.0316273252831565 h


In [39]:
# save
# `with` closes (and flushes) each file deterministically. Protocol 4 can
# serialize objects larger than 4 GB, which `fakes` may reach for the full
# dataset; protocol 3, the default before Python 3.8, cannot.
for file_name, obj in (("lat.p", latents), ("fak.p", fakes), ("los.p", losses)):
    with open(file_name, "wb") as f:
        pickle.dump(obj, f, protocol=4)