In [1]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from argparse import Namespace
from pprint import pprint

import sys
sys.path.append('.')
sys.path.append('..')

import torch
from torch import nn
import torch.nn.functional as F

# import torchvision
from torchvision import transforms

# from models.stylegan2.model import Generator, Encoder_BL
# from models.encoders.psp_encoders import GradualStyleEncoder
from models.psp import pSp
from criteria.vgg_loss import VGGLoss

In [2]:
from utils.common import tensor2im

In [3]:
# Use the GPU when available; otherwise fall back to the CPU instead of failing here.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
def image2tensor(image):
    """Convert a channels-last image with values in [0, 255] to a batched tensor in [-1, 1].

    Args:
        image: 3-D array-like laid out (H, W, C) with values in [0, 255] -- a NumPy
            array or a PIL image (anything ``np.asarray`` accepts).

    Returns:
        float32 tensor of shape (1, C, H, W) equal to (image / 255 - 0.5) / 0.5.
    """
    # np.asarray also accepts PIL images; torch.tensor copies the data, replacing the
    # legacy torch.FloatTensor(...) constructor.
    pixels = torch.tensor(np.asarray(image), dtype=torch.float32)
    batched = pixels.permute(2, 0, 1).unsqueeze(0) / 255.
    return (batched - 0.5) / 0.5

def tensor2image(tensor):
    """Convert an image tensor with values in [-1, 1] to a channels-last NumPy array in [0, 1].

    Args:
        tensor: image tensor of shape (C, H, W) or (1, C, H, W); singleton dimensions
            are squeezed away before the channel axis is moved last.

    Returns:
        numpy.ndarray of shape (H, W, C) with values clamped to [0, 1].
    """
    # Detach first and clamp out of place: an in-place clamp_ would mutate the caller's
    # tensor and raises on leaf tensors that require grad.
    image = tensor.detach().clamp(-1., 1.).squeeze().permute(1, 2, 0).cpu().numpy()
    return image * 0.5 + 0.5

def imshow(img, size=5, cmap='jet'):
    """Show ``img`` in a square matplotlib figure with the axes hidden.

    Args:
        img: anything matplotlib's ``imshow`` accepts, e.g. an (H, W, C) array in [0, 1].
        size: width and height of the figure, in inches.
        cmap: colormap name; matplotlib ignores it for RGB(A) data.
    """
    _, ax = plt.subplots(figsize=(size, size))
    ax.imshow(img, cmap=cmap)
    ax.set_axis_off()
    plt.show()


# Load Input Image

In [5]:
image_size = 1024  # select from [256, 512, 1024]


def make_transform(size):
    """Build a PIL -> tensor pipeline producing a (C, size, size) tensor in [-1, 1].

    An explicit (size, size) target guarantees a square output; a bare int would only
    resize the shorter edge and leave non-square inputs non-square.
    """
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])


# Encoder input resolution (the result is what gets passed to net.encoder).
transform = make_transform(256)
# Target resolution for the reconstruction losses.
transform2 = make_transform(image_size)

In [6]:
imgs_root = './images/'
img_name = 'gl1_01.png'

# Force 3-channel RGB: PNGs are often RGBA, palette or grayscale, and the
# 3-channel Normalize in `transform` cannot handle those.
ori_img = Image.open(os.path.join(imgs_root, img_name)).convert('RGB')
# Batch of one at the encoder's resolution, on the compute device.
imgs = transform(ori_img).unsqueeze(0).to(device)

In [None]:
allocated_bytes = torch.cuda.memory_allocated()
allocated_bytes * 1e-9  # bytes occupied by tensors, displayed in GB

In [None]:
# Release unoccupied memory cached by PyTorch's CUDA caching allocator
# (memory still referenced by live tensors is not freed).
torch.cuda.empty_cache()

# Load Models

In [8]:
# psp model
model_path = '../pretrained_models/psp_ffhq_encode.pt'
ckpt = torch.load(model_path, map_location='cpu')

In [9]:
opts = ckpt['opts']

# Point pSp at the checkpoint file that was just read.
opts['checkpoint_path'] = model_path

# Fill in defaults only for keys the stored options lack; existing values are kept.
opts.setdefault('learn_in_w', False)
opts.setdefault('output_size', image_size)

In [10]:
# Build pSp from the checkpoint options, switch it to evaluation mode and place it on
# the same device as the input tensors (instead of hard-coding CUDA).
net = pSp(Namespace(**opts))
net.eval()
net.to(device)
print('Model successfully loaded!')

Loading pSp from checkpoint: ../pretrained_models/psp_ffhq_encode.pt
Model successfully loaded!


In [None]:
# with torch.no_grad():
#     imgOut, z0 = net(imgs, resize=True, randomize_noise=False, return_latents=True)

# imgOut = imgOut.to('cpu')
# torch.cuda.empty_cache() 

In [None]:
# Encode the encoder-resolution input into latent codes. No autograd graph is needed:
# z0 is only the starting point and the fixed latent target of the optimisation below.
# NOTE(review): this calls the encoder alone rather than pSp's full forward pass; if the
# checkpoint adds a mean latent there (start_from_latent_avg), it is missing from z0 --
# confirm before decoding z0 with input_is_latent=True.
with torch.set_grad_enabled(False):
    z0 = net.encoder(imgs)

# Improve Latent Representation

In [None]:
# Number of Adam steps used to refine the latent code in the optimisation cell.
iterations = 20

In [None]:
# Reconstruction target at `image_size`, kept on the same device as the model output
# (a CPU target would make the losses mix CPU and GPU tensors).
# NOTE: this rebinds `imgs`; re-running the encoder cell afterwards would feed it the
# full-resolution target instead of the encoder-resolution input.
imgs = transform2(ori_img.convert('RGB')).unsqueeze(0).to(device)

In [None]:
vgg_loss = VGGLoss(device)

# Only the latent code is optimised. Freezing the networks stops backward() from
# computing, and endlessly accumulating, gradients for every weight.
net.requires_grad_(False)
if isinstance(vgg_loss, nn.Module):  # VGGLoss is defined elsewhere; freeze only if it is a Module
    vgg_loss.requires_grad_(False)

encoder_size = 256    # resolution of the encoder input z0 was computed from (see `transform`)
latent_weight = 2.0   # weight of the latent-consistency term
display_every = 5     # show progress every N steps

z = z0.detach().clone().requires_grad_(True)
optimizer = torch.optim.Adam([z], lr=0.01)

for step in range(iterations):
    imgs_gen, _ = net.decoder([z],
                              input_is_latent=True,
                              randomize_noise=False)
    # Match the target's resolution so the pixel/VGG losses compare equal-sized images
    # (the decoder's output size need not equal `image_size`).
    if imgs_gen.shape[-2:] != imgs.shape[-2:]:
        imgs_gen = F.adaptive_avg_pool2d(imgs_gen, imgs.shape[-2:])

    # Re-encode at the encoder's input resolution, mirroring how z0 was produced.
    z_hat = net.encoder(F.adaptive_avg_pool2d(imgs_gen, (encoder_size, encoder_size)))

    loss = (F.mse_loss(imgs_gen, imgs)
            + vgg_loss(imgs_gen, imgs)
            + F.mse_loss(z0, z_hat) * latent_weight)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (step + 1) % display_every == 0:
        print(f'step:{step+1}, loss:{loss.item()}')
        with torch.no_grad():
            # Batch elements stacked vertically; real images left, generated right.
            imgs_real = torch.cat(list(imgs), dim=1)
            imgs_fakes = torch.cat(list(imgs_gen), dim=1)
            imshow(tensor2image(torch.cat([imgs_real, imgs_fakes], dim=2)), 10)

# Show Results

In [None]:
# output_image = tensor2im(imgOut)
# res = np.concatenate([np.array(input_image.resize((256, 256))),
#                     np.array(output_image.resize((256, 256)))], axis=1)
# Image.fromarray(res)