In [None]:
# GPU selection. These must be set before torch initializes CUDA (this cell runs
# before `import torch`). Number devices by PCI bus ID (the ordering nvidia-smi
# reports) and expose only device 0 to this kernel.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import torch
import numpy as np
from utils import show, renormalize
from utils import util, paintwidget, labwidget, imutil
from networks import networks
from PIL import Image
import os

### Load the networks and sample images

In [None]:
# Which pretrained model to load: architecture and dataset/domain name.
MODEL_ARCH, MODEL_DOMAIN = 'stylegan', 'church'

nets = networks.define_nets(MODEL_ARCH, MODEL_DOMAIN)
# Side length (pixels) of the generator's output images, read from the model config.
outdim = nets.setting['outdim']

In [None]:
# Number of source images to composite; one paint panel is created per image.
num_components = 4

# Draw a pool of 1000 latent codes (the second argument, 0, looks like a seed --
# NOTE(review): confirm against nets.sample_zs), then pick `num_components`
# *distinct* codes at random. The chosen indices are saved by save_drawing so
# a composite can be reproduced later.
zs = nets.sample_zs(1000, 0)
seeds = np.random.choice(len(zs), num_components, replace=False)
source_zs = zs[seeds]

# Render the chosen latents to images; inference only, so skip autograd.
with torch.no_grad():
    source_ims = nets.zs2image(source_zs)

In [None]:
# Display the sampled source images as 200x200 thumbnails.
# Image.LANCZOS is the filter the old Image.ANTIALIAS alias referred to;
# ANTIALIAS was removed in Pillow 10, so use the canonical name.
for im in source_ims:
    show.a([renormalize.as_image(im).resize((200, 200), Image.LANCZOS)])
show.flush()

### Set up the interactive panels
Draw on the image panels with your mouse. The composited network input is shown in the second-to-last panel, and the network output in the last panel.

In [None]:
def make_callback(painter):
    """Build the 'mask' event handler for one paint panel.

    The returned ``probe_changed(c)`` does the following:
    - pastes the region painted on ``painter`` into the global ``composite`` image;
    - adds that region to the global ``mask_composite`` (clamped to [0, 1]);
    - runs ``nets.invert(composite, mask_composite)``;
    - shows the composite in ``collage_div`` and the network output in ``encoded_div``.

    Args:
        painter: a ``paintwidget.PaintWidget`` whose ``image`` attribute holds
            an image URL and whose ``mask`` attribute is falsy until the user
            has painted something (presumably a data URL afterwards).

    Returns:
        A one-argument callback suitable for ``painter.on('mask', ...)``.
    """
    def probe_changed(c):
        # `c` is the widget's change event; it is not used.
        global composite
        global mask_composite
        with torch.no_grad():
            # Load the source image first: the empty-mask branch below needs
            # its shape (previously it was referenced before assignment).
            sample = renormalize.from_url(
                painter.image, size=(outdim, outdim)).cuda()[None]
            if painter.mask:
                mask = renormalize.from_url(
                    painter.mask, target='pt', size=(outdim, outdim)).cuda()[None]
                # Keep a single channel so the mask broadcasts over RGB.
                mask = mask[:, [0], :, :]
            else:
                # Nothing painted yet: all-zero single-channel mask.
                mask = torch.zeros_like(sample[:, [0], :, :])
            mask_composite += mask
            # Painted pixels come from `sample`; the rest keep the prior composite.
            composite = sample * mask + composite * (1 - mask)
            mask_composite = torch.clamp(mask_composite, 0., 1.)
            out = nets.invert(composite, mask_composite)
        # Refresh the network-input and network-output panels.
        for div, image in ((collage_div, composite[0]), (encoded_div, out[0])):
            div.innerHTML = '<img src="%s"/>' % renormalize.as_url(image, size=256)
    return probe_changed

In [None]:
# Black placeholder shown in both result panels until something is drawn.
img_url = renormalize.as_url(torch.zeros(3, outdim, outdim), size=256)
img_html = '<img src="%s"/>' % img_url
collage_div = labwidget.Div(img_html)   # network input (composite)
encoded_div = labwidget.Div(img_html)   # network output

# Running composite image and its single-channel accumulated mask; the
# callbacks produced by make_callback reassign these globals as the user draws.
composite = torch.zeros(1, 3, outdim, outdim).cuda()
mask_composite = torch.zeros_like(composite)[:, [0], :, :]

# One paint panel per source image, each wired to its own callback.
painters = []
for i in range(num_components):
    src_painter = paintwidget.PaintWidget(
        oneshot=False,
        width=256,
        height=256,
        brushsize=20,
        save_sequence=False,
        track_move=True,
    )
    src_painter.image = renormalize.as_url(source_ims[i], size=256)
    src_painter.on('mask', make_callback(src_painter))
    painters.append(src_painter)
    show.a([src_painter], cols=3)

# Result panels go last: composite input, then the network's output.
for result_div in (collage_div, encoded_div):
    show.a([result_div], cols=3)
show.flush()

In [None]:
def save_drawing(save_name, overwrite=False):
    """Save and display the current drawing.

    Writes into ``drawing/composites/<save_name>/``:
      * ``part<i>.png``   -- painter ``i``'s source image with its painted
        mask, as rendered by ``imutil.draw_masked_image``;
      * ``composite.png`` -- ``nets.invert(composite, mask_composite)``;
      * ``data.pth``      -- ``dict(masks=..., seeds=...)``: the stacked
        single-channel masks (moved to CPU) and the latent indices used to
        pick the source images.
    Each part and the composite are also shown inline.

    Args:
        save_name: name of the output sub-directory.
        overwrite: if False (default), raise ``FileExistsError`` when the
            directory already exists so earlier results are not clobbered;
            if True, write into the existing directory.
    """
    save_path = os.path.join('drawing', 'composites', save_name)
    os.makedirs(save_path, exist_ok=overwrite)
    masks = []
    for i, p in enumerate(painters):
        with torch.no_grad():
            sample = renormalize.from_url(p.image, size=(outdim, outdim)).cuda()[None]
            if p.mask:
                mask = renormalize.from_url(
                    p.mask, target='pt', size=(outdim, outdim)).cuda()[None]
            else:
                # Nothing painted on this panel.
                mask = torch.zeros(1, 3, outdim, outdim).cuda()
            # Keep a single channel so masks stack to (N, 1, H, W).
            mask = mask[:, [0], :, :]
        masks.append(mask)
        im_pil = imutil.draw_masked_image(sample, mask, size=256)[1]
        im_pil.save(os.path.join(save_path, 'part%d.png' % i))
        # Image.LANCZOS replaces Image.ANTIALIAS, which Pillow 10 removed.
        show.a(['part %d' % i, im_pil.resize((200, 200), Image.LANCZOS)], cols=3)
    with torch.no_grad():
        out = nets.invert(composite, mask_composite)
    composite_pil = renormalize.as_image(out[0])
    composite_pil.save(os.path.join(save_path, 'composite.png'))
    torch.save(dict(masks=torch.cat(masks).cpu(), seeds=seeds),
               os.path.join(save_path, 'data.pth'))
    show.a(['composite', composite_pil.resize((200, 200), Image.LANCZOS)])
    show.flush()

In [None]:
# Persist the current drawing under drawing/composites/church/.
save_drawing(save_name='church')