In [None]:
# Select which GPU the notebook uses. These environment variables are set
# here, before torch is imported in the next cell, so that CUDA sees them
# when it initializes.
# Number devices by PCI bus id so the index below matches nvidia-smi ordering.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
# Expose only device 0 to this process; change the index to use another GPU.
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import torch
import numpy as np
from utils import show, renormalize, pbar
from utils import util, paintwidget, labwidget, imutil
from networks import networks
from PIL import Image
import os
import skvideo.io
from torchvision import transforms
import time

### Load networks

In [None]:
# Build the network wrapper used below for sampling (sample_zs, zs2image)
# and inversion (invert); presumably 'car' selects the StyleGAN domain/checkpoint.
nets = networks.define_nets('stylegan', 'car')
# Output side length in pixels; real images and painted masks are resized to outdim x outdim.
outdim = nets.setting['outdim']

### Sample (or load) a source image and re-encode it

In [None]:
# Choose the source image: a GAN sample (True) or a real image from disk (False).
use_g_sample = True
if use_g_sample:
    # Use a GAN image as the source. Change `n` or `seed` to get a different
    # image: the n-th of n+1 latents drawn with `seed` is used.
    n = 56
    seed = 0
    with torch.no_grad():
        source_z = nets.sample_zs(n + 1, seed=seed)[n][None]
        source_im = nets.zs2image(source_z)
else:
    # Use a real image as the source: resize and center-crop to the generator
    # size, then map pixel values from [0, 1] (ToTensor) to [-1, 1] (Normalize).
    im_path = 'img/car0.png'
    transform = transforms.Compose([
        transforms.Resize(outdim),
        transforms.CenterCrop(outdim),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # The context manager releases the file handle. convert('RGB') drops an
    # alpha channel (common in PNGs) or expands grayscale, so the 3-channel
    # Normalize above always applies.
    with Image.open(im_path) as pil_im:
        source_im = transform(pil_im.convert('RGB'))[None].cuda()
show(['Source Image', renormalize.as_image(source_im[0]).resize((256, 256), Image.LANCZOS)])

In [None]:
# Re-encode the full (unmasked) source image with nets.invert and display the
# result at 256x256 for comparison with the source shown above.
with torch.no_grad():
    out = nets.invert(source_im)
    show([
        'GAN Reconstruction',
        renormalize.as_image(out[0]).resize((256, 256), Image.LANCZOS),
    ])

### Visualize network priors
Drag the mouse over the left panel to paint a mask on the source image. The masked image is re-encoded with the GAN, and the reconstruction is shown in the right panel as you draw.

In [None]:
# Left panel: paint widget drawn over the source image.
src_painter = paintwidget.PaintWidget(
    oneshot=False,
    width=256,
    height=256,
    brushsize=20,
    save_sequence=False,
    track_move=True,
)
src_painter.image = renormalize.as_url(source_im[0], size=256)

# Right panel: HTML div whose <img> probe_changed replaces with each
# reconstruction; it starts with a placeholder built from an all-zero tensor.
img_url = renormalize.as_url(torch.zeros(3, 256, 256))
img_html = '<img src="{}"/>'.format(img_url)
output_div = labwidget.Div(img_html)

# State shared with the probe_changed callback.
counter = 0              # callback invocations, including throttled ones
prev_time = time.time()  # time of the last processed (non-throttled) update
update_freq = 0.1        # min seconds between processed updates; mouse events arrive ~0.05-0.07 s apart
mask_list = []           # painted masks (CPU tensors), one per processed update
reconstruction_list = [] # reconstruction images, aligned with mask_list

def probe_changed(c):
    """Re-encode the source image under the current painted mask and display it.

    Registered below as the ``mask_buffer`` callback of ``src_painter``.
    Calls arriving less than ``update_freq`` seconds after the last processed
    one are skipped to keep the UI responsive.

    Each processed update appends one entry to ``mask_list`` and one to
    ``reconstruction_list``. Both are appended together, and only after every
    step succeeded, so index i of the two lists always refers to the same
    update (``write_video`` zips them pairwise).

    Args:
        c: argument supplied by the widget's ``on`` callback (unused here).
    """
    global counter
    global prev_time
    counter += 1
    curr_time = time.time()
    if curr_time - prev_time < update_freq:
        return
    prev_time = curr_time

    mask_url = src_painter.mask_buffer
    mask = renormalize.from_url(mask_url, target='pt', size=(outdim, outdim)).cuda()[None]  # 1x3xHxW
    with torch.no_grad():
        # Keep a singleton channel axis so the mask broadcasts over RGB.
        mask = mask[:, [0], :, :]  # 1x1xHxW
        masked_im = source_im * mask
        regenerated_mask = nets.invert(masked_im, mask)

    # Compute everything that might fail before touching the shared lists.
    mask_cpu = mask.cpu()
    rec_image = renormalize.as_image(regenerated_mask[0])
    mask_list.append(mask_cpu)
    reconstruction_list.append(rec_image)

    img_url = renormalize.as_url(regenerated_mask[0], size=256)
    output_div.innerHTML = '<img src="%s"/>' % img_url
    
# Re-run the masked reconstruction whenever the painted mask changes.
src_painter.on('mask_buffer', probe_changed)

# Lay out the painter (left) and the reconstruction (right) in a 2-column grid.
for panel in (src_painter, output_div):
    show.a([panel], cols=2)
show.flush()

### Save the resulting video
Each frame shows the source image, the masked input, and the GAN reconstruction side by side.

In [None]:
def write_video(file_name, rate='15', out_dir='drawing/masking'):
    """Write the recorded masking session to a video file.

    Each frame concatenates, left to right: the source image, the source
    multiplied by the recorded mask, and the matching GAN reconstruction.
    Frames come from ``mask_list`` and ``reconstruction_list`` (filled by
    ``probe_changed``), paired by index.

    Args:
        file_name: output file name, placed inside ``out_dir``; must not exist yet.
        rate: frame-rate string passed to ffmpeg as ``-r`` for input and output.
        out_dir: directory to write into; created if missing.

    Raises:
        FileExistsError: if the output file already exists.
    """
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, file_name)
    # Explicit raise rather than assert, so the overwrite guard survives `python -O`.
    if os.path.isfile(out_path):
        raise FileExistsError(out_path)

    inputdict = {'-r': rate}
    outputdict = {'-pix_fmt': 'yuv420p', '-r': rate}
    writer = skvideo.io.FFmpegWriter(out_path, inputdict, outputdict)
    try:
        source_im_np = np.array(renormalize.as_image(source_im[0]))
        # Move the source to CPU once; masks were stored on CPU by probe_changed.
        source_im_cpu = source_im.cpu()
        for mask, rec_image in zip(pbar(mask_list), reconstruction_list):
            masked_im_np = np.array(renormalize.as_image((source_im_cpu * mask)[0]))
            rec_im_np = np.array(rec_image)
            frame = np.concatenate([source_im_np, masked_im_np, rec_im_np], axis=1)
            writer.writeFrame(frame)
    finally:
        # Always finalize the ffmpeg writer, even if a frame fails, so it is not leaked.
        writer.close()

In [None]:
# Render the recorded session at 15 fps.
write_video(file_name='car.mp4', rate='15')