In [1]:
from datetime import datetime
import io
from pathlib import Path
import pickle

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf
import tqdm.notebook as tqdm

In [2]:
import lib_stylegan

# Hyper-parameters of the pretrained StyleGAN. They presumably must match the
# architecture the checkpoint below was saved with — TODO confirm.
im_size = 256
batch_size = 8
latent_size = 512
channels = 32  # Should be at least 32 for good results
WEIGHTS_PATH = "samples/model_small_lr_009"

gan_config = dict(
    im_size=im_size,
    latent_size=latent_size,
    nb_style_mapper_layer=6,
    channels=channels,
)
model = lib_stylegan.style_gan.StyleGan(**gan_config)

# Kept as the cell's last expression so the returned load status is displayed.
model.load_weights(WEIGHTS_PATH)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f8c73d1d160>

In [3]:
def get_images(M, S, G, nb=16, thumb_size=(128, 128)):
    """Sample random latents, run them through the GAN and return PNG thumbnails.

    Noise is drawn with the module-level ``model.get_noise``. The same ``w`` is
    repeated for every one of ``model.n_layers`` layers (no style mixing).

    Args:
        M: called on the sampled ``z`` latents to produce ``w`` (the style
            mapper ``model.M`` at the call sites).
        S: called on the stacked per-layer style to produce the generator seed.
        G: generator, called as ``G([seed, style, noise])``.
        nb: batch size passed to ``model.get_noise``.
        thumb_size: ``(width, height)`` each image is resized to before PNG
            encoding. Defaults to the previously hard-coded 128x128.

    Returns:
        ``(zs, ws, pngs)``: a list of ``z`` rows and a list of ``w`` rows
        (numpy arrays), and a list of PNG-encoded ``bytes``, with one entry per
        generated image (presumably ``nb`` of each; this depends on
        ``model.get_noise``).
    """
    l_z, _, _, noise = model.get_noise(tf.ones((nb,)))
    l_w = M(l_z)
    # Same w for every layer.
    style = tf.stack([l_w] * model.n_layers, axis=1)
    seed = S(style)
    generated = G([seed, style, noise])

    pngs = []
    for frame in generated.numpy():
        # Assumes generator output is roughly in [0, 1] (TODO confirm). Values
        # are clipped, then scaled to 8-bit pixels, truncating as before.
        pixels = np.uint8(np.clip(frame, 0, 1) * 255)
        thumb = Image.fromarray(pixels).resize(tuple(thumb_size))
        with io.BytesIO() as buf:
            thumb.save(buf, format='PNG')
            pngs.append(buf.getvalue())
    return list(l_z.numpy()), list(l_w.numpy()), pngs


## Labelling: toggle "Click to select" under an image to label it 1 (selected), then press "Next"

In [4]:
# Accumulates one (z, w, selected) tuple per image shown in the labelling UI.
result = list()

In [5]:
COLS = 4
ROWS = 4
IMG_WIDTH = 150
IMG_HEIGHT = 150
TOTAL_BATCH = 100
# One tick per "Next" click (one batch); use the constant rather than
# repeating the literal 100.
progress = tqdm.tqdm(total=TOTAL_BATCH)

import ipywidgets as widgets
import functools

# Request exactly one image per grid cell. This keeps the widgets, the
# selections and the recorded latents aligned even if ROWS/COLS are changed
# (get_images would otherwise return its default batch of 16).
N_IMAGES = ROWS * COLS

z, w, im = get_images(model.M, model.S, model.G, nb=N_IMAGES)


def next_images(*args):
    """Record the current selections, then display a freshly sampled batch.

    Callback for the "Next" button; any positional arguments it receives are
    ignored. Appends one ``(z, w, selected)`` tuple per displayed image to the
    global ``result`` list. It then samples a new batch into the globals ``z``,
    ``w`` and ``im``, resets every toggle, refreshes the image widgets and
    advances the progress bar.
    """
    global z, w, im
    values = [bw.value for bw in button_widgets]
    result.extend(zip(z, w, values))
    z, w, im = get_images(model.M, model.S, model.G, nb=N_IMAGES)
    for bw in button_widgets:
        bw.value = False
    # Index explicitly so a short batch fails loudly instead of leaving stale images.
    for i, iw in enumerate(images_widgets):
        iw.value = im[i]
    progress.update(1)


images_widgets = []
button_widgets = []
rows = []
for row in range(ROWS):
    cols = []
    for col in range(COLS):
        index = row * COLS + col
        image = widgets.Image(
            value=im[index], width=IMG_WIDTH, height=IMG_HEIGHT
        )
        images_widgets.append(image)
        # next_images() reads this toggle's value as the image's selection flag.
        button = widgets.ToggleButton(description='Click to select')
        button_widgets.append(button)
        # Vertical box: image above its toggle button.
        cols.append(widgets.VBox([image, button]))

    # One horizontal strip per grid row.
    rows.append(widgets.HBox(cols))

next_button = widgets.Button(description='Next')
next_button.on_click(next_images)
rows.append(next_button)

# Stack the rows and the Next button vertically. As the cell's last
# expression, this renders the UI.
widgets.VBox(rows)

  0%|          | 0/100 [00:00<?, ?it/s]

VBox(children=(HBox(children=(VBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x80…

In [6]:
# Selection flag (third field) of every recorded (z, w, selected) entry.
selected_flags = [selected for _z, _w, selected in result]
print(selected_flags)

[]


## Saving the labelled (z, w, selected) samples

In [12]:
# Write one pickle per session. The filename encodes the save time and the
# number of recorded (z, w, selected) entries.
# NOTE(review): "lattent" looks like a typo of "latent". It is kept unchanged
# in case downstream loaders match on the existing filenames.
out_dir = Path('data')
# open() does not create parent directories, so make sure data/ exists on a
# fresh checkout before writing.
out_dir.mkdir(parents=True, exist_ok=True)
out_path = out_dir / f'lattent_{datetime.now().strftime("%Y%m%d-%H%M%S")}_nb_{len(result)}.pkl'
with open(out_path, 'wb') as f:
    pickle.dump(result, f)