In [1]:
import os
import glob
import sys
import numpy as np
import pickle
import tensorflow as tf
import PIL
import ipywidgets
import io

""" make sure this notebook is running from root directory """
while os.path.basename(os.getcwd()) in ('notebooks', 'src'):
    os.chdir('..')
assert ('README.md' in os.listdir('./')), 'Can not find project root, please cd to project root before running the following code'

import src.tl_gan.generate_image as generate_image
import src.tl_gan.feature_axis as feature_axis
import src.tl_gan.feature_celeba_organize as feature_celeba_organize

In [2]:
""" load feature directions """
path_feature_direction = './asset_results/pg_gan_celeba_feature_direction_40'

pathfile_feature_direction = glob.glob(os.path.join(path_feature_direction, 'feature_direction_*.pkl'))[-1]

with open(pathfile_feature_direction, 'rb') as f:
    feature_direction_name = pickle.load(f)

feature_direction = feature_direction_name['direction']
feature_name = feature_direction_name['name']
num_feature = feature_direction.shape[1]

print(type(feature_direction))
print(feature_direction)

import importlib
importlib.reload(feature_celeba_organize)
feature_name = feature_celeba_organize.feature_name_celeba_rename
feature_direction = feature_direction_name['direction']* feature_celeba_organize.feature_reverse[None, :]

print(type(feature_direction))
print(feature_direction)
print(np.shape(feature_direction))

<class 'numpy.ndarray'>
[[ 0.0199223  -0.03134961 -0.00460163 ...  0.02191978  0.05446521
  -0.0234645 ]
 [ 0.01922786  0.02906507 -0.00876675 ...  0.00351627  0.00407246
  -0.01586714]
 [ 0.01129154 -0.05090305 -0.05736918 ... -0.04276138  0.05990816
   0.0156728 ]
 ...
 [ 0.02499767 -0.00887705 -0.08504948 ... -0.04194508  0.02711248
  -0.02090008]
 [ 0.01635417  0.0572896  -0.01062254 ...  0.04207035  0.01289932
  -0.05572326]
 [ 0.00674798 -0.01110875 -0.03138209 ...  0.0064889   0.03154353
  -0.04592823]]
<class 'numpy.ndarray'>
[[ 0.0199223  -0.03134961 -0.00460163 ...  0.02191978  0.05446521
   0.0234645 ]
 [ 0.01922786  0.02906507 -0.00876675 ...  0.00351627  0.00407246
   0.01586714]
 [ 0.01129154 -0.05090305 -0.05736918 ... -0.04276138  0.05990816
  -0.0156728 ]
 ...
 [ 0.02499767 -0.00887705 -0.08504948 ... -0.04194508  0.02711248
   0.02090008]
 [ 0.01635417  0.0572896  -0.01062254 ...  0.04207035  0.01289932
   0.05572326]
 [ 0.00674798 -0.01110875 -0.03138209 ...  0.00648

In [5]:
""" start tf session and load GAN model """

# path to model code and weight
path_pg_gan_code = './src/model/pggan'
path_model = './asset_model/karras2018iclr-celebahq-1024x1024.pkl'
sys.path.append(path_pg_gan_code)


""" create tf session """
yn_CPU_only = False

if yn_CPU_only:
    config = tf.ConfigProto(device_count = {'GPU': 0}, allow_soft_placement=True)
else:
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True

sess = tf.InteractiveSession(config=config)

try:
    with open(path_model, 'rb') as file:
        G, D, Gs = pickle.load(file)
except FileNotFoundError:
    print('before running the code, download pre-trained model to project_root/asset_model/')
    raise

print(type(Gs))

len_z = Gs.input_shapes[0][1]
z_sample = np.random.randn(len_z)
x_sample = generate_image.gen_single_img(z_sample, Gs=Gs)

<class 'tfutil.Network'>


In [6]:
def img_to_bytes(x_sample):
    imgObj = PIL.Image.fromarray(x_sample)
    imgByteArr = io.BytesIO()
    imgObj.save(imgByteArr, format='PNG')
    imgBytes = imgByteArr.getvalue()
    return imgBytes


In [7]:
z_sample = np.random.randn(len_z)
x_sample = generate_image.gen_single_img(Gs=Gs)

w_img = ipywidgets.widgets.Image(value=img_to_bytes(x_sample), format='png', width=512, height=512)

class GuiCallback(object):
    counter = 0
    #     latents = z_sample
    def __init__(self):
        self.latents = z_sample
        self.feature_direction = feature_direction
        self.feature_lock_status = np.zeros(num_feature).astype('bool')
        self.feature_directoion_disentangled = feature_axis.disentangle_feature_axis_by_idx(
            self.feature_direction, idx_base=np.flatnonzero(self.feature_lock_status))

    def random_gen(self, event):
        self.latents = np.random.randn(len_z)
        self.update_img()

    def modify_along_feature(self, event, idx_feature, step_size=0.01):
        self.latents += self.feature_directoion_disentangled[:, idx_feature] * step_size
        self.update_img()

    def set_feature_lock(self, event, idx_feature, set_to=None):
        if set_to is None:
            self.feature_lock_status[idx_feature] = np.logical_not(self.feature_lock_status[idx_feature])
        else:
            self.feature_lock_status[idx_feature] = set_to
        self.feature_directoion_disentangled = feature_axis.disentangle_feature_axis_by_idx(
            self.feature_direction, idx_base=np.flatnonzero(self.feature_lock_status))
    
    def update_img(self):        
        x_sample = generate_image.gen_single_img(z=self.latents, Gs=Gs)
        x_byte = img_to_bytes(x_sample)
        w_img.value = x_byte

guicallback = GuiCallback()

step_size = 0.4
def create_button(idx_feature, width=96, height=40):
    """ function to built button groups for one feature """
    w_name_toggle = ipywidgets.widgets.ToggleButton(
        value=False, description=feature_name[idx_feature],
        tooltip='{}, Press down to lock this feature'.format(feature_name[idx_feature]),
        layout=ipywidgets.Layout(height='{:.0f}px'.format(height/2), 
                                 width='{:.0f}px'.format(width),
                                 margin='2px 2px 2px 2px')
    )
    w_neg = ipywidgets.widgets.Button(description='-',
                                      layout=ipywidgets.Layout(height='{:.0f}px'.format(height/2), 
                                                               width='{:.0f}px'.format(width/2),
                                                               margin='1px 1px 5px 1px'))
    w_pos = ipywidgets.widgets.Button(description='+',
                                      layout=ipywidgets.Layout(height='{:.0f}px'.format(height/2), 
                                                               width='{:.0f}px'.format(width/2),
                                                               margin='1px 1px 5px 1px'))
    
    w_name_toggle.observe(lambda event: 
                      guicallback.set_feature_lock(event, idx_feature))
    w_neg.on_click(lambda event: 
                     guicallback.modify_along_feature(event, idx_feature, step_size=-1 * step_size))
    w_pos.on_click(lambda event: 
                     guicallback.modify_along_feature(event, idx_feature, step_size=+1 * step_size))
    
    button_group = ipywidgets.VBox([w_name_toggle, ipywidgets.HBox([w_neg, w_pos])],
                                  layout=ipywidgets.Layout(border='1px solid gray'))
    
    return button_group
  

list_buttons = []
for idx_feature in range(num_feature):
    list_buttons.append(create_button(idx_feature))

yn_button_select = True
def arrange_buttons(list_buttons, yn_button_select=True, ncol=4):
    num = len(list_buttons)
    if yn_button_select:
        feature_celeba_layout = feature_celeba_organize.feature_celeba_layout
        layout_all_buttons = ipywidgets.VBox([ipywidgets.HBox([list_buttons[item] for item in row]) for row in feature_celeba_layout])
    else:
        layout_all_buttons = ipywidgets.VBox([ipywidgets.HBox(list_buttons[i*ncol:(i+1)*ncol]) for i in range(num//ncol+int(num%ncol>0))])
    return layout_all_buttons
    

# w_button.on_click(on_button_clicked)
guicallback.update_img()
w_button_random = ipywidgets.widgets.Button(description='random face', button_style='success',
                                           layout=ipywidgets.Layout(height='40px', 
                                                               width='128px',
                                                               margin='1px 1px 5px 1px'))
w_button_random.on_click(guicallback.random_gen)

w_box = ipywidgets.HBox([w_img, 
                         ipywidgets.VBox([w_button_random, 
                                         arrange_buttons(list_buttons, yn_button_select=True)])
                        ], layout=ipywidgets.Layout(height='1024}px', width='1024px')
                       )

print('press +/- to adjust feature, toggle feature name to lock the feature')
display(w_box)

press +/- to adjust feature, toggle feature name to lock the feature


HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x04\x00\x00\x00\x04\x00\x08\x02\x00\x…

In [40]:
# n_iters = 1
# for i in range(n_iters):
#     len_z = Gs.input_shapes[0][1]
#     z_sample = np.random.randn(len_z)
#     print(type(z_sample))
#     print(np.shape(z_sample))
#     z_sample = np.array([0] * 512)
#     x_sample = generate_image.gen_single_img(z_sample, Gs=Gs)
# #     generate_image.save_img(x_sample, "src/notebooks/out/test_" + str(i)  + ".jpg")
#     generate_image.save_img(x_sample, "src/notebooks/out/yeee.jpg")

In [8]:
def generate_image_from_features(z_init, features, feature_direction, feature_directoion_disentangled):
    # features: list of 40 values
    # feature_direction: np array of shape (512, 40)
    # returns: z_sample of shape (512, )
    
    assert len(features) == 40
    
    print(features)
    
    z_sample = z_init.copy()
    
    feature_direction_transposed = np.transpose(feature_direction)
    
    step_size = 5
    
    print((features[0] * feature_directoion_disentangled[:, 0])[0])
    
    for direction, feature_val, idx_feature in zip(feature_direction_transposed, features, range(len(features))):
        z_sample = np.add(z_sample, feature_val * feature_directoion_disentangled[:, idx_feature] * step_size)
#         print(z_sample)
    
    return z_sample
    

In [48]:
num_iter = 10

feature_lock_status = np.zeros(len(feature_direction)).astype('bool')
feature_directoion_disentangled = feature_axis.disentangle_feature_axis_by_idx(
    feature_direction, idx_base=np.flatnonzero(feature_lock_status))


z_init = np.random.randn(512)

for i in range(num_iter):
    test = [(i - 5) * 0.2] + [-0.4] * 39
    
    print(feature_name)
    print(test[0])

    z_sample = generate_image_from_features(z_init, test, feature_direction, feature_directoion_disentangled)
    print(z_sample[0])
    x_sample = generate_image.gen_single_img(z=z_sample, Gs=Gs)
    generate_image.save_img(x_sample, "src/notebooks/out/yeeeT{}.jpg".format(i))

['Shadow', 'Arched_Eyebrows', 'Attractive', 'Eye_bags', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Makeup', 'High_Cheekbones', 'Male', 'Mouth_Open', 'Mustache', 'Narrow_Eyes', 'Beard', 'Oval_Face', 'Skin_Tone', 'Pointy_Nose', 'Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Earrings', 'Hat', 'Lipstick', 'Necklace', 'Necktie', 'Age']
-1.0
[-1.0, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4, -0.4]
-0.01992229897026372
1.4056223193676884
['Shadow', 'Arched_Eyebrows', 'Attractive', 'Eye_bags', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee',

In [14]:
feature_lock_status = np.zeros(len(feature_direction)).astype('bool')
feature_directoion_disentangled = feature_axis.disentangle_feature_axis_by_idx(
    feature_direction, idx_base=np.flatnonzero(feature_lock_status))
z_init = np.random.randn(512)
test =  [-0.9583767 , -0.94203347,  0.11940472, -0.30442804, -0.99999946,
       -0.9999996 , -0.7525675 ,  0.5638125 ,  0.7467824 , -0.99928004,
       -0.9999967 , -0.9972268 ,  0.74412024, -0.9978278 , -0.9988206 ,
       -0.9999995 , -0.99997133, -0.99987286, -0.98980856, -0.10975579,
        0.996792  ,  0.79417837, -0.999932  , -0.9641402 ,  0.9906429 ,
       -0.6010362 , -0.99996775, -0.84928405, -0.9913011 , -0.9998905 ,
       -0.9999994 ,  0.9119818 ,  0.12061694, -0.9988287 , -0.9971764 ,
       -0.9999827 , -0.977496  , -0.88524956, -0.9921849 ,  0.8898498 ]
z_sample = generate_image_from_features(z_init, test, feature_direction, feature_directoion_disentangled)
x_sample = generate_image.gen_single_img(z=z_sample, Gs=Gs)
generate_image.save_img(x_sample, "src/notebooks/out/yeeeTus-kabeetus.jpg")

[-0.9583767, -0.94203347, 0.11940472, -0.30442804, -0.99999946, -0.9999996, -0.7525675, 0.5638125, 0.7467824, -0.99928004, -0.9999967, -0.9972268, 0.74412024, -0.9978278, -0.9988206, -0.9999995, -0.99997133, -0.99987286, -0.98980856, -0.10975579, 0.996792, 0.79417837, -0.999932, -0.9641402, 0.9906429, -0.6010362, -0.99996775, -0.84928405, -0.9913011, -0.9998905, -0.9999994, 0.9119818, 0.12061694, -0.9988287, -0.9971764, -0.9999827, -0.977496, -0.88524956, -0.9921849, 0.8898498]
-0.01909306714353474
