In [None]:
import torch, copy
from seeing import nethook, setting, show, renormalize, zdataset, pbar
from seeing import encoder_net
from seeing import imgviz, segmenter
from torchvision import models
from torch.nn.functional import mse_loss, l1_loss
# Disable autograd globally for the whole notebook (inference by default).
# Any code that needs gradients must re-enable them explicitly,
# e.g. inside a `with torch.enable_grad():` block.
torch.set_grad_enabled(False)


In [None]:
# Which training image to reconstruct, and which pretrained GAN/encoder to use.
image_number = 883
model = 'church'

# load_test_image yields a pair; only the first element (the image batch,
# presumably shape (1, C, H, W) -- TODO confirm) is kept.
loaded_image, _ = setting.load_test_image(image_number, 'train', model)

# Display the image number next to the first image of the batch.
show([
    image_number,
    renormalize.as_image(loaded_image[0]),
])

In [None]:
def optimize_residuals(target_x, optimize_over=None, num_steps=3000, show_every=500,
                       lr=0.01, milestones=(800, 1200, 1800), model='church'):
    """Reconstruct ``target_x`` by optimizing per-layer residuals of a GAN generator.

    The target is first inverted with a trained encoder to obtain an initial
    latent. The generator is then wrapped in ``encoder_net.ResidualGenerator``.
    Only its top-level residual parameters are optimized, using Adam with a
    step-wise halving learning-rate schedule. The loss is the sum of three terms:

    * an L1 pixel loss,
    * an MSE loss on VGG16 features (up to layer '20'),
    * a mean-squared penalty on the residuals.

    Args:
        target_x: image batch to reconstruct. It must already be on the GPU,
            because all models are moved there with ``.cuda()``. Presumably it
            is in the generator's output value range (TODO confirm).
        optimize_over: generator layer names (e.g. ``'layer1'``) at which
            residuals are added. Defaults to ``['layer1', 'layer2', 'layer3']``.
        num_steps: number of optimizer updates to perform.
        show_every: display progress every this many steps; 0/None disables it.
        lr: initial Adam learning rate.
        milestones: update counts after which the learning rate is halved.
        model: name of the pretrained generator/encoder pair to load.

    Returns:
        The generator output after the final update, computed without autograd.

    Raises:
        ValueError: if ``optimize_over`` is empty (there would be nothing to
            optimize).
    """
    if optimize_over is None:
        optimize_over = ['layer1', 'layer2', 'layer3']
    if not optimize_over:
        raise ValueError('optimize_over must name at least one layer')
    layernums = [name.replace('layer', '') for name in optimize_over]
    show.reset()

    # Load a GAN generator, a trained encoder, and a pretrained VGG.
    unwrapped_G = setting.load_proggan(model)
    E = setting.load_proggan_inversion(model)
    vgg = models.vgg16(pretrained=True)
    VF = nethook.subsequence(vgg.features, last_layer='20')

    # Move models to GPU.
    for m in [unwrapped_G, E, VF]:
        m.cuda()

    # Fixed quantities: the encoder's initial latent and the target's VGG features.
    with torch.no_grad():
        init_z = E(target_x)
        target_v = VF(target_x)

    # Wrap (a copy of) the GAN in an instrumented model that adds residuals
    # at the requested layers.
    G = encoder_net.ResidualGenerator(copy.deepcopy(unwrapped_G), init_z, optimize_over)
    parameters = list(G.parameters(recurse=False))

    # Only the top-level residual parameters of G are trained. VF is frozen
    # as well; otherwise every backward() would compute and accumulate
    # gradients for all VGG weights.
    nethook.set_requires_grad(False, G, E, VF)
    nethook.set_requires_grad(True, *parameters)
    optimizer = torch.optim.Adam(parameters, lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=list(milestones), gamma=0.5)

    with torch.enable_grad():
        # Step 0 only evaluates (and optionally displays) the initial
        # reconstruction; steps 1..num_steps each apply one update.
        for step_num in pbar(range(num_steps + 1)):
            current_x = G()
            loss_x = l1_loss(target_x, current_x)
            loss_v = mse_loss(target_v, VF(current_x))
            loss_d = sum(getattr(G, 'd%s' % n).pow(2).mean() for n in layernums)
            loss = loss_x + loss_v + loss_d
            if show_every and step_num % show_every == 0:
                with torch.no_grad():
                    show.a(
                        ['step %d' % step_num] +
                        ['loss: %f' % loss.item()] +
                        ['loss_x: %f' % loss_x.item()] +
                        ['loss_v: %f' % loss_v.item()] +
                        ['loss_d: %f' % loss_d.item()] +
                        [[renormalize.as_image(current_x[0])]], cols=3)
            if step_num == 0:
                continue
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Advance the LR schedule once per update so `milestones` take effect
            # (must come after optimizer.step()).
            scheduler.step()
        show.flush()

    # Re-render after the last update. The loop's final `current_x` predates
    # that step and is still attached to the autograd graph.
    with torch.no_grad():
        final_x = G()
    return final_x

In [None]:
# Run the residual optimization on the GPU copy of the loaded image.
reconst = optimize_residuals(target_x=loaded_image.cuda(), model=model)

In [None]:
# Image visualizer (size 256) and a unified-parsing semantic segmenter.
iv = imgviz.ImageVisualizer(256)
upp = segmenter.UnifiedParsingSegmenter()

# Side-by-side view of the input image and its reconstruction.
show([
    ['original', iv.image(loaded_image)],
    ['reconstruction', iv.image(reconst)],
])

# Segment both images (original first). For each, keep batch item 0 and
# only the first segmentation channel, as a 1-channel slice.
orig_seg, reconst_seg = (
    upp.segment_batch(img.cuda())[0, 0:1]
    for img in (loaded_image, reconst)
)

# Show both segmentations plus a key built from the two stacked together.
show([
    [iv.segmentation(orig_seg)],
    [iv.segmentation(reconst_seg)],
    iv.segment_key(torch.cat([orig_seg, reconst_seg]), upp),
])
