In [None]:
# Enumerate GPUs in PCI bus order and expose only device 0 to this kernel.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import torch
import numpy as np
from utils import show, renormalize, masking
from utils import util, imutil, pbar, losses, inversions
from networks import networks
from PIL import Image
import os
from torchvision import transforms
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def load_nets():
    """Build and return a fresh set of networks via ``networks.define_nets``.

    Uses the 'stylegan' / 'ffhq' configuration. The notebook calls this again
    before each finetuning run so that it starts from newly loaded networks.
    """
    # Alternative (bonus) stylegan encoder trained on real images + identity loss:
    # return networks.define_nets('stylegan', 'ffhq', ckpt_path='pretrained_models/sgan_encoders/ffhq_reals_RGBM/netE_epoch_best.pth')

    # Default: stylegan encoder trained on gsamples + identity loss.
    return networks.define_nets('stylegan', 'ffhq')

In [None]:
# Photo that will be inverted and used as the finetuning target.
im_path = 'img/torralba_cropped.png'

# Working resolution of the input image.
outdim = 1024  # for faces
# outdim = 256 # for churches

# Resize the short side to outdim, center-crop to a square, convert to a
# tensor, then shift/scale each channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(outdim),
    transforms.CenterCrop(outdim),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Load as RGB, add a leading batch dimension, and move the tensor to the GPU.
source_im = transform(Image.open(im_path).convert('RGB')).unsqueeze(0).cuda()
show([
    'Source Image',
    renormalize.as_image(source_im[0]).resize((256, 256), Image.LANCZOS),
])

In [None]:
# Always start from freshly loaded networks: the finetuning cells further
# down update the encoder, so reload before every run.
nets = load_nets()
outdim = nets.setting['outdim']

with torch.no_grad():
    # Single-channel mask of ones with the source image's batch/spatial size.
    mask = torch.ones_like(source_im)[:, [0], :, :]
    out = nets.invert(source_im, mask)
    # Alternative two-step form (presumably equivalent -- confirm in networks):
    # encoded = nets.encode(source_im, mask)
    # out = nets.decode(encoded)
    show([
        'Inverted Image',
        renormalize.as_image(out[0]).resize((256, 256), Image.LANCZOS),
    ])

# Finetune the encoder towards the real image

In [None]:
from networks.psp import id_loss

In [None]:

# --- Hyperparameters -------------------------------------------------------
batch_size = 1
lambda_mse = 1.0
lambda_lpips = 1.0
lambda_z = 0. # set lambda_z to 10.0 to optimize the latent first. (optional)
lambda_id = 0.1

# --- Networks --------------------------------------------------------------
# need to reload nets each time after finetuning
nets = load_nets()
outdim = nets.setting['outdim']

# Inversion produced by the un-finetuned networks, kept for later comparison.
with torch.no_grad():
    mask = torch.ones_like(source_im)[:, [0], :, :]
    initial_inversion = nets.invert(source_im, mask)

# --- Optional latent optimization -----------------------------------------
# Only runs when lambda_z > 0; its result (opt_ws) is then used as the target
# of the latent loss in the training loop.
if lambda_z > 0.:
    checkpoint_dict, opt_losses = inversions.invert_lbfgs(nets, source_im, num_steps=30)
    opt_ws = checkpoint_dict['current_z'].detach().clone().repeat(batch_size, 1, 1)
    # reenable grad after LBFGS
    torch.set_grad_enabled(True)

# --- Trainable / frozen parts ---------------------------------------------
# Both networks are put in eval mode; only the encoder's parameters get grads.
netG = nets.generator.eval()
netE = nets.encoder.eval()
util.set_requires_grad(False, netG)
util.set_requires_grad(True, netE)

# --- Losses ----------------------------------------------------------------
mse_loss = torch.nn.MSELoss()
l1_loss = torch.nn.L1Loss()
perceptual_loss = losses.LPIPS_Loss().cuda().eval()
identity_loss = id_loss.IDLoss().cuda().eval()
# The loss networks are fixed feature extractors here: no grads for them.
for frozen_net in (identity_loss, perceptual_loss):
    util.set_requires_grad(False, frozen_net)

# --- Optimizer and targets -------------------------------------------------
optimizer = torch.optim.Adam(netE.parameters(), lr=5e-5, betas=(0.5, 0.999))

# The single source image, tiled to the batch size.
target = source_im.repeat(batch_size, 1, 1, 1)

# Pools images down to 256x256 before the perceptual / identity losses.
reshape = torch.nn.AdaptiveAvgPool2d((256, 256))

In [None]:
# Per-step history of each loss term (plotted in the next cell).
all_losses = {'z': [], 'mse': [], 'lpips': [], 'id': [], 'sim_improvement': []}

# 30-50 steps is about enough
torch.manual_seed(0)

for i in pbar(range(30)):
    optimizer.zero_grad()

    # One (hints, mask) pair per batch element from masking.mask_upsample;
    # element 0 of each pair is the hint image, element 1 its mask.
    mask_data = [masking.mask_upsample(source_im) for _ in range(batch_size)]
    hints = torch.cat([pair[0] for pair in mask_data])
    masks = torch.cat([pair[1] for pair in mask_data])

    # Encoder sees hints and mask stacked along the channel axis.
    encoded = netE(torch.cat([hints, masks], dim=1))
    regenerated = netG(encoded)

    # Latent loss only applies when the optional LBFGS latents were computed.
    loss_z = mse_loss(encoded, opt_ws) if lambda_z > 0. else torch.Tensor((0.,)).cuda()
    loss_mse = mse_loss(regenerated, target)
    loss_perceptual = perceptual_loss.forward(reshape(regenerated), reshape(target)).mean()
    loss_id, sim_improvement, id_logs = identity_loss(
        reshape(regenerated), reshape(target), reshape(target))

    loss = (
        lambda_z * loss_z
        + lambda_mse * loss_mse
        + lambda_lpips * loss_perceptual
        + lambda_id * loss_id
    )
    # loss.backward(retain_graph=True)
    loss.backward()
    optimizer.step()

    # Record scalar values for plotting.
    for key, term in (('z', loss_z), ('mse', loss_mse),
                      ('lpips', loss_perceptual), ('id', loss_id)):
        all_losses[key].append(term.item())
    all_losses['sim_improvement'].append(sim_improvement)

In [None]:
# One panel per recorded loss term, plotted against the finetuning step.
# Driving the panels from a table replaces four copy-pasted plot/title pairs
# and, because the cell no longer ends in an expression, avoids the stray
# Text(...) repr in the cell output.
loss_panels = [
    ('z', 'Z loss'),
    ('mse', 'MSE loss'),
    ('lpips', 'LPIPS loss'),
    ('id', 'ID loss'),
]
f, ax = plt.subplots(1, len(loss_panels), figsize=(16, 3))
for panel, (key, title) in zip(ax, loss_panels):
    panel.plot(all_losses[key])
    panel.set_title(title)
    panel.set_xlabel('step')


In [None]:
def _thumb(image_tensor):
    """Render a tensor with renormalize.as_image and resize it to 256x256."""
    return renormalize.as_image(image_tensor).resize((256, 256), Image.LANCZOS)


# Row 1: inversion from before finetuning, plus the LBFGS result if it was run.
show.a(['Initial Inversion', _thumb(initial_inversion[0])])
if lambda_z > 0.:
    show.a(['optimized w', _thumb(checkpoint_dict['current_x'][0])])
show.flush()


with torch.no_grad():
    # Row 2: full source image as hints, with an all-ones mask.
    hints = source_im
    mask = torch.ones_like(source_im)[:, [0], :, :]

    # hints, mask = masking.mask_upsample(source_im, threshold=0.5)
    # mask = mask+0.5

    encoded = nets.encode(hints, mask)
    out = nets.decode(encoded)
    show.a(['hints Image', _thumb(hints[0])])
    show.a(['Inverted Image', _thumb(out[0])])
    show.flush()

    # Row 3: zero the mask (and therefore the hints) everywhere except a
    # 100-pixel border around the image.
    mask = torch.ones_like(source_im)[:, [0], :, :]
    mask[:, :, 100:-100, 100:-100] = 0.
    hints = source_im * mask

    encoded = nets.encode(hints, mask)
    out = nets.decode(encoded)
    show.a(['Hints Image', _thumb(hints[0])])
    show.a(['Inverted Hints', _thumb(out[0])])
    show.flush()

# Interactive mixing
Draw with your mouse on the image panels. The network input will show in the second-to-last panel, and the network output in the last panel.

In [None]:
# Images offered as paint sources; the first is the image used above.
collage_paths = [
    im_path,
    'img/efros_cropped.png',
    'img/phil_cropped.png',
    'img/biden_cropped.png',
]
num_components = len(collage_paths)

# Apply the same preprocessing to each image and batch them on the GPU.
collage_ims = torch.cat(
    [
        transform(Image.open(path).convert('RGB')).unsqueeze(0).cuda()
        for path in collage_paths
    ],
    dim=0,
)

In [None]:
from utils import paintwidget, labwidget

def make_callback(painter):
    """Return a 'mask' event handler bound to ``painter``.

    The returned handler pastes the region painted on ``painter`` into the
    global ``composite`` image, accumulates the painted area into the global
    ``mask_composite`` (clamped to [0, 1]), runs ``nets.invert`` on the result,
    and refreshes the ``collage_div`` / ``encoded_div`` preview panels.

    Args:
        painter: widget exposing ``.image`` and ``.mask`` URLs readable by
            ``renormalize.from_url``; ``.mask`` may be empty.

    Returns:
        A one-argument callable suitable for ``painter.on('mask', ...)``.
        Its argument (the change event) is ignored.
    """
    def probe_changed(c):
        global composite
        global mask_composite
        p = painter
        # Load the painter's image first: the empty-mask fallback is shaped
        # after it. The earlier version read `sample` before assigning it,
        # which raised UnboundLocalError whenever p.mask was empty.
        sample = renormalize.from_url(p.image, size=(outdim, outdim)).cuda()[None]
        if p.mask:
            mask = renormalize.from_url(p.mask, target='pt', size=(outdim, outdim)).cuda()[None]
        else:
            # Nothing painted yet: an all-zero mask leaves the composite as is.
            mask = torch.zeros_like(sample)
        with torch.no_grad():
            # Keep a single mask channel.
            mask = mask[:, [0], :, :].cuda()
            mask_composite += mask

            # Painted pixels come from this painter's image; the rest is kept.
            composite = sample * mask + composite * (1-mask)
            mask_composite = torch.clamp(mask_composite, 0., 1.)
            out = nets.invert(composite, mask_composite)
        # Second-to-last panel: network input (the collage so far).
        img_url = renormalize.as_url(composite[0], size=256)
        img_html = '<img src="%s"/>'% img_url
        collage_div.innerHTML = img_html
        # Last panel: network output.
        img_url = renormalize.as_url(out[0], size=256)
        img_html = '<img src="%s"/>'% img_url
        encoded_div.innerHTML = img_html
    return probe_changed

In [None]:
# Both output panels start out showing an all-zeros tensor as a placeholder.
img_url = renormalize.as_url(torch.zeros(3, outdim, outdim), size=256)
img_html = '<img src="%s"/>' % img_url
encoded_div = labwidget.Div(img_html)
collage_div = labwidget.Div(img_html)

# Running collage image and its accumulated single-channel mask; the paint
# callbacks created by make_callback update these globals.
composite = torch.zeros(1, 3, outdim, outdim).cuda()
mask_composite = torch.zeros_like(composite)[:, [0], :, :]

# One paint widget per collage image, each wired to its own callback.
painters = []
for i in range(num_components):
    src_painter = paintwidget.PaintWidget(
        width=256, height=256, brushsize=20,
        oneshot=False, save_sequence=False, track_move=True)  # , on_move=True)
    src_painter.image = renormalize.as_url(collage_ims[i], size=256)
    src_painter.on('mask', make_callback(src_painter))
    painters.append(src_painter)
    show.a([src_painter], cols=3)

# Last two panels: the collage fed to the network, then the network output.
show.a([collage_div], cols=3)
show.a([encoded_div], cols=3)
show.flush()

In [None]:
def show_drawing():
    """Display a static summary of the current interactive collage.

    Shows, for every painter, its image overlaid with the painted mask
    (via ``imutil.draw_masked_image``), followed by the collage input with
    unfilled pixels lightened and the network's inversion of that collage.
    Reads the globals ``painters``, ``composite``, ``mask_composite``,
    ``nets`` and ``outdim``; returns nothing.
    """
    for i, p in enumerate(painters):
        if p.mask:
            mask = renormalize.from_url(p.mask, target='pt', size=(outdim, outdim)).cuda()[None]
        else:
            # Nothing painted on this widget: use an all-zero mask.
            mask = torch.zeros(1, 3, outdim, outdim).cuda()
        mask = mask[:, [0], :, :].cuda()
        sample = renormalize.from_url(p.image, size=(outdim, outdim)).cuda()[None]
        im_pil = imutil.draw_masked_image(sample, mask, size=256)[1]
        # im_pil.save(os.path.join(save_path, 'part%d.png' % i))
        # Image.LANCZOS is the same filter formerly exposed as Image.ANTIALIAS,
        # which newer Pillow releases no longer provide.
        show.a(['part %d' % i, im_pil.resize((200, 200), Image.LANCZOS)], cols=3)
    with torch.no_grad():
        out = nets.invert(composite, mask_composite)
    composite_pil = renormalize.as_image(out[0])
    input_np = np.array(renormalize.as_image(composite[0]))
    # Broadcast the single-channel mask to three channels, laid out H x W x 3.
    mask_np = np.stack([np.array(mask_composite.cpu()[0][0])] * 3, axis=2)
    input_np[mask_np == 0] = 200 # lighten the unfilled region
    input_pil = Image.fromarray(input_np)
    # input_pil = renormalize.as_image(composite[0])
    # composite_pil.save(os.path.join(save_path, 'composite.png'))
    show.a(['input', input_pil.resize((200, 200), Image.LANCZOS)])
    show.a(['composite', composite_pil.resize((200, 200), Image.LANCZOS)])
    show.flush()

In [None]:
# Render each painter's strokes, then the collage input and its inversion.
show_drawing()