In [None]:
import os

import imageio
import numpy as np
import torch

from PIL import Image, ImageEnhance
%matplotlib widget
import matplotlib.pylab as plt

from checkpointer import Checkpointer
from defaults import get_cfg_defaults
from model import Model


In [None]:
# Setup cell: load the config, build the model on GPU 0 in inference mode, and
# restore weights from a training checkpoint. Later cells read the globals
# defined here (`cfg`, `model`, `layer_count`).
print('Loading config...')
# NOTE(review): hard-coded config path, relative to the notebook's working directory.
config_file='configs/popgan.yaml'
# Start from the project defaults, overlay the YAML file, then freeze.
# Presumably a yacs-style CfgNode, where freeze() makes it read-only — confirm in defaults.py.
cfg = get_cfg_defaults()
cfg.merge_from_file(config_file)
cfg.freeze()

print('Initializing model...')
# Make GPU 0 the current CUDA device, so later bare `.cuda()` calls land on the same GPU as the model.
torch.cuda.set_device(0)
# Module-level copy used by encode()/decode() in the helper cell.
layer_count = cfg.MODEL.LAYER_COUNT
model = Model(
    startf=cfg.MODEL.START_CHANNEL_COUNT,
    layer_count=cfg.MODEL.LAYER_COUNT,
    maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
    latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
    # NOTE(review): 'TRUNCATIOM' spelling must match the key names declared in
    # defaults.py; do not correct it here without changing the config schema too.
    truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
    truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
    mapping_layers=cfg.MODEL.MAPPING_LAYERS,
    channels=cfg.MODEL.CHANNELS,
    generator=cfg.MODEL.GENERATOR,
    encoder=cfg.MODEL.ENCODER)
# Inference only: move to GPU 0, switch to eval mode, and disable gradients on all parameters.
model.cuda(0)
model.eval()
model.requires_grad_(False)

print('Loading checkpoint...')
# Checkpoint key -> live submodule to load into. The encoder is stored under
# 'discriminator_s' and the decoder under 'generator_s'.
model_dict = {
    'discriminator_s': model.encoder,
    'generator_s': model.decoder,
    'mapping_tl_s': model.mapping_tl,
    'mapping_fl_s': model.mapping_fl,
    'dlatent_avg': model.dlatent_avg
}
checkpointer = Checkpointer(cfg, model_dict)
# NOTE(review): hard-coded checkpoint path; `extra_checkpoint_data` is not used in the visible cells.
extra_checkpoint_data = checkpointer.load(file_name='training_artifacts/popgan/model_tmp_lod8_e200.pth')

print('READY!')

In [None]:
def encode(x, pop):
    """Encode images together with their pop maps into latent codes.

    `x` and `pop` are concatenated along dim 1. Each batch element is encoded on its
    own with `model.encode(..., layer_count - 1, 1)`, and the first returned value is
    kept. Results are concatenated along dim 0 and then repeated
    `model.mapping_fl.num_layers` times along dim 1.

    Assumes `x` and `pop` are (B, C, H, W) tensors with matching B, H and W — TODO confirm.
    """
    stacked = torch.cat((x, pop), 1)
    codes = []
    # Iterating a tensor yields its dim-0 slices; unsqueeze restores a batch dim of 1.
    for sample in stacked:
        code, _ = model.encode(sample.unsqueeze(0), layer_count - 1, 1)
        codes.append(code)
    batched = torch.cat(codes)
    return batched.repeat(1, model.mapping_fl.num_layers, 1)

def _make_noise_set(seed, max_log2_res=10):
    """Return one decoder noise set drawn from a CPU generator seeded with `seed`.

    The set is `[0, n(2x2), n(4x4), ..., n(2**max_log2_res square)]`, where each n is a
    (1, 1, R, R) standard-normal tensor. The leading literal 0 is kept exactly as before;
    presumably the decoder ignores that slot — confirm against model.py.
    """
    gen = torch.Generator().manual_seed(seed)
    return [0] + [torch.randn([1, 1, 2 ** k, 2 ** k], generator=gen)
                  for k in range(1, max_log2_res + 1)]


# Ten decoder noise sets. decode(..., noise_i) and generate_pop(..., noise_seed) pick
# one by index. Set i is drawn from a generator seeded with i, so a given index gives
# the same noise (and the same rendered image) after a kernel restart. Previously
# unseeded torch.randn changed the noise on every run.
noises = [_make_noise_set(i) for i in range(10)]

def sample_z(seed):
    """Draw a reproducible latent vector for `seed`.

    Uses numpy's RandomState(seed) to sample `cfg.MODEL.LATENT_SPACE_SIZE` standard
    normals. Returns them as a float32 CUDA tensor with a leading batch dimension of 1.
    """
    with torch.no_grad():
        draws = np.random.RandomState(seed).standard_normal(cfg.MODEL.LATENT_SPACE_SIZE)
        z = torch.from_numpy(draws).float().cuda()[None]
    return z

def z2w(z):
    """Map latent `z` through `model.mapping_fl` with gradient tracking disabled."""
    with torch.no_grad():
        return model.mapping_fl(z)

def decode(w, pop, noise_i=0):
    """Render styles `w` conditioned on `pop` using noise set `noises[noise_i]`.

    Calls `model.decoder(w, pop, layer_count - 1, 1, noise=...)` and clamps the result
    to [-1, 1]. The `layer_count - 1, 1` arguments are presumably the final LOD and full
    blend — confirm against model.py.
    """
    rendered = model.decoder(w, pop, layer_count - 1, 1, noise=noises[noise_i])
    return rendered.clamp(-1., 1.)

def mix_styles(style_shallow, style_deep, layer_range, strength):
    """Blend selected rows of `style_deep` into a copy of `style_shallow`.

    Rows selected by `layer_range` (any valid first-dimension index, e.g. a slice or a
    list) become ``deep * strength + shallow * (1 - strength)``. All other rows are
    copied from `style_shallow`. Neither input tensor is modified.
    """
    keep = 1 - strength
    mixed = style_shallow.clone()
    deep_rows = style_deep[layer_range, :]
    shallow_rows = style_shallow[layer_range, :]
    mixed[layer_range, :] = deep_rows * strength + shallow_rows * keep
    return mixed

def tensor_to_numpy(img):
    """Convert a [-1, 1] image tensor to a uint8 numpy array.

    Values are mapped to [0, 1], moved to the CPU, squeezed of size-1 dims, and clipped.
    Values are then scaled by 255 and truncated to uint8. If the squeezed array is 3-D,
    its first axis (channels) is moved last, giving HxWxC.
    """
    arr = (img * 0.5 + 0.5).cpu().squeeze().numpy()
    arr = np.clip(arr, 0, 1)
    if arr.ndim == 3:
        arr = np.moveaxis(arr, 0, -1)
    return (arr * 255).astype(np.uint8)

def tensor_to_PIL(img, brightness=1.4, contrast=1.4):
    """Convert a [-1, 1] image tensor to a PIL image, then boost brightness and contrast.

    Args:
        img: tensor accepted by tensor_to_numpy.
        brightness: PIL ImageEnhance.Brightness factor (default 1.4, as before).
        contrast: PIL ImageEnhance.Contrast factor (default 1.4, as before).

    Returns:
        The enhanced PIL image. With the defaults, the result is a post-processed view,
        not the raw model output. Pass 1.0 for both factors to leave the image unchanged
        (factor 1.0 gives the original image in ImageEnhance).
    """
    pil_img = Image.fromarray(tensor_to_numpy(img))
    pil_img = ImageEnhance.Brightness(pil_img).enhance(brightness)
    pil_img = ImageEnhance.Contrast(pil_img).enhance(contrast)
    return pil_img

def load_image(path):
    """Open the image at `path` and return it together with a model-ready tensor.

    Returns:
        (original PIL image as opened, tensor from numpy_to_tensor of a 1024x1024 resize).
    """
    original = Image.open(path)
    pixels = np.asarray(original.resize((1024, 1024)))
    return original, numpy_to_tensor(pixels)

def numpy_to_tensor(im, device="cuda"):
    """Convert an HxW or HxWxC array with values in [0, 255] to a CxHxW float tensor in [-1, 1].

    Args:
        im: numpy array. A 2-D array gets a trailing channel axis, so the result is 1xHxW.
        device: target torch device. The default "cuda" keeps the previous behaviour
            (current CUDA device). Pass "cpu" to use the helper without a GPU.

    Returns:
        float32 tensor on `device`, computed as ``im / 127.5 - 1``, with no gradient tracking.
    """
    if im.ndim < 3:
        im = np.expand_dims(im, 2)
    chw = np.ascontiguousarray(im.transpose((2, 0, 1)), dtype=np.float32)
    # Build directly on the target device rather than creating a CPU tensor and copying it.
    return torch.as_tensor(chw, device=device) / 127.5 - 1.

def numpy_to_PIL(img):
    """Convert a [0, 1] float array to a 1024x1024 PIL image.

    Values are clipped to [0, 1], scaled by 255 and truncated to uint8, then the PIL
    image is resized to 1024x1024.
    """
    as_bytes = (np.clip(img, 0, 1) * 255).astype(np.uint8)
    return Image.fromarray(as_bytes).resize((1024, 1024))

def save_mp4_from_images(dst, images, fps=10):
    """Write `images` as consecutive frames of a video at `dst`.

    Args:
        dst: output path passed to imageio.get_writer (format chosen by imageio,
            presumably from the extension, e.g. .mp4).
        images: iterable of frames accepted by writer.append_data (presumably HxWx3
            uint8 arrays — confirm).
        fps: frames per second (default 10, as before).
    """
    # The context manager closes the writer even if append_data raises, so a bad frame
    # no longer leaves the output file / encoder open.
    with imageio.get_writer(dst, fps=fps) as writer:
        for frame in images:
            writer.append_data(frame)
    
def db_to_tensor(im, device="cuda"):
    """Convert array-like pixel data in [0, 255] to a float32 tensor in [-1, 1] on `device`.

    Unlike numpy_to_tensor, no axes are added or reordered. Presumably `im` is already
    channel-first — confirm with callers.

    The previous body called `.cuda(gpu)`, but `gpu` is not defined in the visible
    notebook, so every call raised NameError. The default "cuda" targets the current
    CUDA device, which the setup cell sets to 0.
    """
    return torch.tensor(np.asarray(im, dtype=np.float32), requires_grad=False, device=device) / 127.5 - 1.

In [None]:
# Generation: seed for sample_z, and index into the precomputed `noises` list
# (despite its name, noise_seed is used as a list index by decode()).
z_seed, noise_seed = 7, 4
# Brush: half-width in pop-map cells, and the amount added/subtracted per click.
brush_size, brush_strength = 2, 0.5
# Sizes: generator resolution, editable pop-map grid, and on-screen preview.
img_size, pop_size, disp_size = 1024, 64, 512

def generate_pop(w, np_pop, noise_seed):
    """Render an image for styles `w` conditioned on the pop map `np_pop`.

    `np_pop` is scaled by 255 so numpy_to_tensor maps [0, 1] to [-1, 1]. A batch dim is
    added, and the result is decoded with noise set `noise_seed` (an index into
    `noises`). Returns the enhanced image from tensor_to_PIL as a numpy array.
    """
    pop_tensor = numpy_to_tensor(np_pop * 255).unsqueeze(0)
    rendered = decode(w, pop_tensor, noise_seed)
    return np.array(tensor_to_PIL(rendered))

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for a in axes:
    # Both panels show pictures, so hide the axis ticks and labels.
    a.xaxis.set_visible(False)
    a.yaxis.set_visible(False)

# Initial state: latent from z_seed and an all-zero pop map, rendered once.
z = sample_z(z_seed)
w = z2w(z)
np_pop = np.zeros((pop_size,) * 2)
resized_pop = np.array(Image.fromarray(np_pop).resize((img_size,) * 2))
img = generate_pop(w, resized_pop, noise_seed)
disp_img = np.array(Image.fromarray(img).resize((disp_size,) * 2))
axes[0].imshow(np_pop)    # left: editable pop map
axes[1].imshow(disp_img)  # right: generated preview

# Scratch lists the callbacks append to (click records / exceptions, key events).
debug, events = [], []
def onclick(event):
    """Paint the global pop map `np_pop` with a square brush on mouse press.

    Left click adds `brush_strength` and right click subtracts it. The painted region
    spans about `brush_size` cells each side of the click, clamped to the map. Values
    are then clipped to [0, 1] and the left panel is redrawn. Clicks that are not on the
    left (pop map) axes are ignored.

    A record of each stroke, or any exception raised, is appended to the global `debug`
    list rather than raised. Presumably this is because exceptions inside widget
    callbacks are not surfaced in the notebook output.
    """
    try:
        # Only axes[0] shows np_pop. Clicks on the generated preview, or outside both
        # axes (where xdata/ydata are None), must not paint the map.
        if event.inaxes is not axes[0]:
            return
        if event.button.name not in ('LEFT', 'RIGHT'):
            return
        direction = 1 if event.button.name == 'LEFT' else -1
        x_min = int(max(0, event.xdata - brush_size))
        x_max = int(min(pop_size, event.xdata + brush_size))
        y_min = int(max(0, event.ydata - brush_size))
        y_max = int(min(pop_size, event.ydata + brush_size))

        debug.append({
            'event': event.button.name,
            'direction': direction,
            'x_min': x_min,
            'x_max': x_max,
            'y_min': y_min,
            'y_max': y_max,
            # Rows are y and columns are x: report the same slice that is painted below.
            'pop_shape': np_pop[y_min:y_max, x_min:x_max].shape,
            'brush_effect': brush_strength * direction
        })

        np_pop[y_min:y_max, x_min:x_max] += brush_strength * direction
        np.clip(np_pop, 0., 1., out=np_pop)
        # Replace the previous image instead of stacking a new artist on every click.
        for old_image in list(axes[0].images):
            old_image.remove()
        axes[0].imshow(np_pop)
        event.canvas.draw_idle()
    except Exception as e:
        debug.append(e)
        
def onpress(event):
    """Re-render the right-hand preview from the current `np_pop` when 'w' is pressed.

    Every key event is appended to the global `events` list. On 'w', `np_pop` is resized
    to `img_size`, decoded via generate_pop, resized to `disp_size`, and shown on
    `axes[1]`. Exceptions are appended to `debug` instead of being raised.
    """
    try:
        events.append(event)
        if event.key != 'w':
            return
        resized_pop = np.array(Image.fromarray(np_pop).resize((img_size, img_size)))
        img = generate_pop(w, resized_pop, noise_seed)
        disp_img = np.array(Image.fromarray(img).resize((disp_size, disp_size)))
        # Replace the old preview rather than stacking another full-size image per keypress.
        for old_image in list(axes[1].images):
            old_image.remove()
        axes[1].imshow(disp_img)
        event.canvas.draw_idle()
    except Exception as e:
        debug.append(e)

# Wire the interactive handlers: mouse presses paint the pop map (onclick), and the
# 'w' key re-renders the preview (onpress).
fig.canvas.mpl_connect(s='button_press_event', func=onclick)
fig.canvas.mpl_connect(s='key_press_event', func=onpress)
# Display the figure, then tighten the spacing between the two panels.
fig.show()
fig.tight_layout()
