In [None]:
import os
import sys
import PIL
import pickle
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline

sys.path.append('stylegan2encoder')

import dnnlib
import pretrained_networks
import dnnlib.tflib as tflib
from encoder.generator_model import Generator

In [None]:
# Make sure the images are rotated properly:
# https://medium.com/@ageitgey/the-dumb-reason-your-fancy-computer-vision-app-isnt-working-exif-orientation-73166c7d39da
def exif_transpose(img):
    """Return `img` flipped/rotated so that it displays upright.

    Reads the EXIF Orientation tag (274). Values 2-8 are undone with the
    matching flip/rotation. The input is returned unchanged in any of these cases:
    a falsy `img`, no `_getexif` method, EXIF data that is not a dict,
    no orientation tag, orientation 1, or an unrecognised orientation value.
    """
    if not img or not hasattr(img, '_getexif'):
        return img

    orientation_tag = 274
    exif_data = img._getexif()
    if not isinstance(exif_data, dict) or orientation_tag not in exif_data:
        return img

    # One fixer per EXIF orientation. Using lambdas keeps the PIL constant
    # lookups lazy: they only happen for the orientation actually applied.
    fixers = {
        2: lambda im: im.transpose(PIL.Image.FLIP_LEFT_RIGHT),
        3: lambda im: im.rotate(180),
        4: lambda im: im.rotate(180).transpose(PIL.Image.FLIP_LEFT_RIGHT),
        5: lambda im: im.rotate(-90, expand=True).transpose(PIL.Image.FLIP_LEFT_RIGHT),
        6: lambda im: im.rotate(-90, expand=True),
        7: lambda im: im.rotate(90, expand=True).transpose(PIL.Image.FLIP_LEFT_RIGHT),
        8: lambda im: im.rotate(90, expand=True),
    }
    fixer = fixers.get(exif_data[orientation_tag])
    return img if fixer is None else fixer(img)

def get_files_with_ext(path, extensions):
    """List the regular files directly inside `path` whose names end with one of `extensions`.

    Args:
        path: Directory to scan (not recursive).
        extensions: Either a single suffix such as '.npy', or an iterable of
            suffixes such as ['.jpg', '.png']. Matching is case-sensitive.

    Returns:
        Sorted list of full paths (``os.path.join(path, name)``). Sorting makes
        the result independent of ``os.listdir``'s arbitrary order, so callers
        that index into it get the same files on every run.
    """
    # A bare string must not go through tuple(): tuple('.npy') is
    # ('.', 'n', 'p', 'y'), which would match any name ending in n, p, y or '.'.
    if isinstance(extensions, str):
        suffixes = (extensions,)
    else:
        suffixes = tuple(extensions)

    candidates = (os.path.join(path, name) for name in sorted(os.listdir(path)))
    return [f for f in candidates if os.path.isfile(f) and f.endswith(suffixes)]

def rotate_images(src_path, dest_path):
    """Write an upright PNG copy of every .jpg/.png/.jpeg image in `src_path` into `dest_path`.

    Orientation is fixed with `exif_transpose`. `dest_path` is created if it
    does not exist. Each output keeps the input's base name with a '.png'
    extension. As a result, 'a.jpg' and 'a.png' in the same source folder
    overwrite each other in `dest_path`.
    """
    image_files = get_files_with_ext(src_path, ['.jpg', '.png', '.jpeg'])
    os.makedirs(dest_path, exist_ok=True)

    for file in image_files:
        stem = os.path.splitext(os.path.basename(file))[0]
        new_file = os.path.join(dest_path, stem + '.png')
        # The context manager closes the source file handle. Otherwise
        # PIL.Image.open keeps it open until the object is garbage-collected.
        with PIL.Image.open(file) as image:
            # Fix the orientation using EXIF data, then save as PNG.
            exif_transpose(image).save(new_file)

def align_images(src_path, dest_path):
    !python stylegan2encoder/align_images.py $src_path $dest_path
    
    
def resize_images(src_path, dest_path, max_size=1024):
    """Save a PNG copy of each .jpg/.png/.jpeg in `src_path` into `dest_path`, downscaled to fit `max_size`.

    The aspect ratio is preserved and the longer side becomes exactly
    `max_size`. Images already within the limit are saved at their original
    size. Each output keeps the input's base name with a '.png' extension.
    `dest_path` is created if it does not exist.
    """
    image_files = get_files_with_ext(src_path, ['.jpg', '.png', '.jpeg'])
    os.makedirs(dest_path, exist_ok=True)

    for file in image_files:
        stem = os.path.splitext(os.path.basename(file))[0]
        new_file = os.path.join(dest_path, stem + '.png')

        # The context manager closes the source file handle when we are done.
        with PIL.Image.open(file) as image:
            width, height = image.size
            longest = max(width, height)
            if longest > max_size:
                scale = max_size / longest
                if width > height:
                    new_size = (int(max_size), int(height * scale))
                else:
                    new_size = (int(width * scale), int(max_size))
                # Clamp to 1 px: a very thin image could otherwise truncate a
                # side to 0, which PIL cannot resize to.
                new_size = (max(1, new_size[0]), max(1, new_size[1]))
                image = image.resize(new_size, PIL.Image.BILINEAR)
            image.save(new_file)
        
def compute_image_latent_vectors(src_path, dest_path):
    !python stylegan2encoder/project_images.py $src_path $dest_path --video=True --video-mode=2
    

def get_latent_vectors(path):
    """Load every .npy file in `path` and return the arrays as a list.

    Files are loaded in sorted filename order, so index 0 and index 1 refer to
    the same latents on every run.
    """
    # Pass the suffix as a list so it is treated as one suffix, regardless of
    # how get_files_with_ext handles a bare string.
    files = sorted(get_files_with_ext(path, ['.npy']))
    return [np.load(f) for f in files]


In [None]:
def load_model(model_weights):
    """Load pretrained StyleGAN2 networks and wrap one of them for image generation.

    Args:
        model_weights: A location accepted by `pretrained_networks.load_networks`
            (this notebook passes a 'gdrive:...' path).

    Returns:
        A dict with one key, 'generator'. Its value is an encoder `Generator`
        built from the third network returned (`Gs`), with batch_size=1 and
        randomize_noise=False.
    """
    _, _, gs_network = pretrained_networks.load_networks(model_weights)
    return {
        'generator': Generator(gs_network, batch_size=1, randomize_noise=False),
    }


def generate_image(latent_vector, model):
    """Render one RGB PIL image from a latent vector.

    `latent_vector` must contain 18 * 512 values. It is reshaped to
    (1, 18, 512) and set as the generator's dlatents before rendering. The
    first image of the generated batch is returned.
    """
    dlatents = latent_vector.reshape((1, 18, 512))
    generator = model['generator']
    generator.set_dlatents(dlatents)
    first_image = generator.generate_images()[0]
    # Optional smaller preview: append .resize((256, 256)) to the result.
    return PIL.Image.fromarray(first_image, 'RGB')

def show(latent_vector, model):
    """Render `latent_vector` with `model` and display it in a new matplotlib figure."""
    plt.subplots()
    rendered = generate_image(latent_vector, model)
    plt.imshow(rendered)
    plt.show()

In [None]:
def mix_latents(latent_vector_a, latent_vector_b, t):
    """Linearly interpolate between two latents.

    t=0 returns `latent_vector_a` and t=1 returns `latent_vector_b`. Values of
    t outside [0, 1] extrapolate along the same line.
    """
    offset = latent_vector_b - latent_vector_a
    return latent_vector_a + offset * t

def get_latent_directions(path):
    """Load each .npy file in `path` as a named latent direction.

    Returns:
        A dict mapping each file's base name without extension (e.g. 'smile'
        for 'smile.npy') to the loaded array. Entries are in sorted filename
        order, so the order of the demo's sliders is stable across runs.
    """
    # Pass the suffix as a list so it is treated as one suffix, regardless of
    # how get_files_with_ext handles a bare string.
    files = sorted(get_files_with_ext(path, ['.npy']))
    return {
        os.path.splitext(os.path.basename(file))[0]: np.load(file)
        for file in files
    }

def move(latent_vector, feature, dist, latent_directions):
    """Shift `latent_vector` by `dist` units along the named latent direction.

    Raises KeyError if `feature` is not a key of `latent_directions`.
    """
    direction = latent_directions[feature]
    return latent_vector + direction * dist

# Interactive Stuff

## Requirements

In [None]:
import io
import PIL
from ipywidgets import interact, FloatSlider, Layout, FileUpload, \
    Button, Output, Image, HBox, Label
from IPython.display import display

## Defs

In [None]:
def bytes_to_image(b):
    """Decode encoded image bytes (e.g. PNG or JPEG) into a PIL image."""
    stream = io.BytesIO(b)
    return PIL.Image.open(stream)

def image_to_bytes(img):
    """Encode `img` as PNG and return the raw bytes (e.g. for an ipywidgets Image value)."""
    with io.BytesIO() as buffer:
        img.save(buffer, format='PNG')
        return buffer.getvalue()

def update_generated_image(image_widget, model, latent_vector_1,
                           latent_vector_2, latent_directions,
                           mix=None, key=None, value=None, verbose=False):
    """Regenerate the blended image in `image_widget` after a slider change.

    The slider state is stored on the function object itself:
    ``update_generated_image.mix`` (the mix factor) and
    ``update_generated_image.latent_direction_magnitudes`` (a dict mapping
    direction name to magnitude). Both attributes must be set before the first
    call. Passing `mix` updates the mix factor; passing both `key` and `value`
    updates one direction's magnitude.

    The new latent is ``mix_latents(latent_vector_1, latent_vector_2, mix)``,
    moved along every direction in `latent_directions` that has a stored
    magnitude. The latent is saved to 'generated_latent.npy' in the working
    directory and rendered into `image_widget`.

    Args:
        verbose: If True, log the incoming update and the current state into
            the global `output` widget. This is a debug aid and is off by default.
    """
    if verbose:
        with output:
            print('Hello from update: mix: {}, key: {}, value: {}'.format(mix,key,value))
            print('\tMagnitudes: ',update_generated_image.latent_direction_magnitudes)
            print('\tMix: ',update_generated_image.mix)

    # Store whichever part of the state this call updates.
    if mix is not None:
        update_generated_image.mix = mix
    if key is not None and value is not None:
        update_generated_image.latent_direction_magnitudes[key] = value

    current_mix = update_generated_image.mix
    magnitudes = update_generated_image.latent_direction_magnitudes

    # Build the latent: interpolate first, then apply each direction that has
    # a magnitude. Iterating latent_directions (rather than magnitudes) keeps
    # the order of additions fixed by the directions dict.
    latent = mix_latents(latent_vector_1, latent_vector_2, current_mix)
    for direction_name in latent_directions:
        if direction_name in magnitudes:
            latent = move(latent, direction_name, magnitudes[direction_name],
                          latent_directions)

    # Keep the latest latent on disk so it can be reloaded as a source latent.
    np.save('generated_latent.npy', latent)

    image_widget.value = image_to_bytes(generate_image(latent, model))
    
def make_mix_latents_slider(image_widget, model, latent_vector_1, latent_vector_2, latent_directions):
    """Create, via `interact`, a 'Mix' slider (0.0 to 1.0, starting at 0.5) that blends the two latents.

    Each slider change calls ``update_generated_image(..., mix=x)``, which
    re-renders into `image_widget`.
    """
    def on_mix_change(x):
        return update_generated_image(image_widget, model,
                                      latent_vector_1, latent_vector_2,
                                      latent_directions, mix=x)

    mix_slider = FloatSlider(
        min=0.0,
        max=1.0,
        step=0.01,
        value=0.5,
        description='Mix',
        layout=Layout(width='80%', height='20px'),
        # False: re-render only on mouse release. True would update
        # continuously, but each render has a significant delay.
        continuous_update=False,
    )
    interact(on_mix_change, x=mix_slider)
    
def make_latent_direction_slider(image_widget, model, latent_vector_1, latent_vector_2, feature, latent_directions):
    """Create, via `interact`, a slider (-20.0 to 20.0, starting at 0.0) for the latent direction `feature`.

    The label is `feature` with every 'emotion_' substring removed. Each
    slider change calls ``update_generated_image(..., key=feature, value=x)``,
    which re-renders into `image_widget`.
    """
    def on_direction_change(x):
        return update_generated_image(image_widget, model,
                                      latent_vector_1, latent_vector_2,
                                      latent_directions,
                                      key=feature, value=x)

    # Drop the 'emotion_' text to keep the label short.
    short_label = str(feature).replace('emotion_', '')
    direction_slider = FloatSlider(
        min=-20.0,
        max=20.0,
        step=0.01,
        value=0.0,
        description=short_label,
        layout=Layout(width='80%', height='20px'),
        # False: re-render only on mouse release. True would update
        # continuously, but each render has a significant delay.
        continuous_update=False,
    )
    interact(on_direction_change, x=direction_slider)


## Setup

In [None]:
# ---------------------------------------------------------
# Uncomment and run this section to process the images in 'raw_images'
# ---------------------------------------------------------
# # Verify that images are upright
# rotate_images('raw_images','rotated_images')

# # Crop to faces in images
# align_images('rotated_images','aligned_images')

# # Make images be no bigger than 1024
# resize_images('aligned_images','resized_images',1024)

# # Compute latent representation of images
# compute_image_latent_vectors('resized_images','processed_images')

# # Load the latent representations
# latent_reps = get_latent_vectors('processed_images')

# stylegan2weights = 'gdrive:networks/stylegan2-ffhq-config-f.pkl'
# model = load_model(stylegan2weights)

# # Draw the latent
# show(latent_reps[0], model)

# # Load the latent direction basis vectors
# latent_directions = get_latent_directions('stylegan2encoder/latent_directions')
# print('-- Available Latent Directions --')
# for k in latent_directions.keys():
#     print('\t',k)
# ---------------------------------------------------------


# ---------------------------------------------------------
# Run this section to use the precomputed latent vectors in 'latent_representations'
# ---------------------------------------------------------
# Load the latent representations (one array per .npy file). The demo below
# uses latent_reps[0] and latent_reps[1], so at least two files are expected.
latent_reps = get_latent_vectors('latent_representations')

# Weights file (per its name: StyleGAN2 FFHQ, config-f). load_model resolves
# the 'gdrive:' path through pretrained_networks.load_networks.
stylegan2weights = 'gdrive:networks/stylegan2-ffhq-config-f.pkl'
model = load_model(stylegan2weights)

# Load the latent direction basis vectors (name -> array, one per .npy file)
# and list their names; each one gets a slider in the interactive demo.
latent_directions = get_latent_directions('stylegan2encoder/latent_directions')
print('-- Available Latent Directions --')
for k in latent_directions.keys():
    print('\t',k)
# ---------------------------------------------------------

# ---------------------------------------------------------
# Test the latents created above:
# ---------------------------------------------------------
# # Draw the latent
# show(latent_reps[0], model)
# show(latent_reps[1], model)

# # Mix the first two latents
# show(mix_latents(latent_reps[0],latent_reps[1], 0.5), model)

# # Load the latent direction basis vectors
# latent_directions = get_latent_directions('stylegan2encoder/latent_directions')
# print('-- Available Latent Directions --')
# for k in latent_directions.keys():
#     print('\t',k)

# # Make them happy
# show(move(latent_reps[0], 'smile', -2.0, latent_directions), model)
# ---------------------------------------------------------

## Interactive Demo

In [None]:
output = Output()
layout = Layout(width='300px')

# Three image panes, left to right: latent 1, the generated blend, latent 2.
image_widget_1 = Image(layout=layout)
image_widget_2 = Image(layout=layout)
image_widget_3 = Image(layout=layout)
images_container = HBox([image_widget_1, image_widget_3, image_widget_2])
display(images_container)

# The demo blends the first two latents. Fail with a clear message rather
# than an IndexError if fewer were loaded.
assert len(latent_reps) >= 2, (
    'Need at least two latent vectors in latent_reps, found {}'.format(len(latent_reps)))
latent_1 = latent_reps[0]
latent_2 = latent_reps[1]

# Generated latents are saved here, you can use them, too:
# latent_1 = np.load('generated_latent.npy')

image_widget_1.value = image_to_bytes(generate_image(latent_1, model))
image_widget_2.value = image_to_bytes(generate_image(latent_2, model))


# --------------------------------------------------
# Set the initial state read by update_generated_image. The mix value
# mirrors the Mix slider's starting value (0.5).
# --------------------------------------------------
update_generated_image.mix = 0.5
update_generated_image.latent_direction_magnitudes = {}
# --------------------------------------------------
make_mix_latents_slider(image_widget_3, model, latent_1, latent_2, latent_directions)
for k in latent_directions:
    make_latent_direction_slider(image_widget_3, model, latent_1, latent_2, k, latent_directions)
# --------------------------------------------------
# --------------------------------------------------

# Experimental stuff

In [None]:
# --------------------------------
# File Uploaders
# --------------------------------
# This file uploader works, but during development we lost the ability to compute new latent vectors
# (Google Drive rate-limits how many times the model weights can be downloaded),
# so I had to abandon it.
# --
# What remains to be done: compute the latent vectors of the uploaded images
# and use them to drive the slider values.
# --------------------------------
def _latest_upload(value):
    """Return (filename, content bytes) for the first file in a FileUpload value, or None if it is empty."""
    if not value:
        return None
    if isinstance(value, dict):
        # ipywidgets 7: {filename: {'content': bytes, ...}}
        filename = next(iter(value))
        return filename, bytes(value[filename]['content'])
    # ipywidgets 8: a tuple of {'name': ..., 'content': memoryview, ...}
    first = value[0]
    return first['name'], bytes(first['content'])


def on_value_changed(change):
    """FileUpload observer: show the uploaded image and a 50/50 pixel blend of the two panes.

    The upload fills the first empty pane out of `image_widget_1` and
    `image_widget_2`. If both are already filled, pane 2 shifts into pane 1
    and the upload takes pane 2. Once both panes hold images, `image_widget_3`
    shows a pixel-space blend of them (not a latent-space mix). The handler
    runs inside the global `output` widget, so prints go there.
    """
    with output:
        upload = _latest_upload(change['new'])
        if upload is None:
            # The value was cleared (e.g. the uploader was reset); nothing to show.
            return
        filename, image_data = upload
        print('Uploaded file: ',filename)
        image = bytes_to_image(image_data)

        # Orient the image properly:
        image = exif_transpose(image)
        image_data = image_to_bytes(image)

        if len(image_widget_1.value) == 0:
            image_widget_1.value = image_data
        elif len(image_widget_2.value) == 0:
            image_widget_2.value = image_data
        else:
            image_widget_1.value = image_widget_2.value
            image_widget_2.value = image_data

        # If we have both images, update the generated image:
        if len(image_widget_1.value) > 0 and len(image_widget_2.value) > 0:
            # Image.blend requires both images to have the same mode. Uploads
            # may be RGBA, P or L, so normalise both to RGB first.
            image_1 = bytes_to_image(image_widget_1.value).convert('RGB')
            image_2 = bytes_to_image(image_widget_2.value).convert('RGB')
            image_2_resized = image_2.resize(image_1.size, PIL.Image.LANCZOS)

            image_3 = PIL.Image.blend(image_1, image_2_resized, 0.5)
            image_widget_3.value = image_to_bytes(image_3)


# Restrict the file picker to images; on_value_changed decodes uploads with PIL.
uploader = FileUpload(accept='image/*')
uploader.observe(on_value_changed, names='value')
# Display `output` as well. on_value_changed runs inside `with output:`, so
# its prints (and any traceback the Output widget captures) only appear if
# the widget is shown.
display(uploader, output)
# -----------------------------------------------------------------------------------

# --------------------------------
# button
# --------------------------------
# This button implementation is a working stub.
# Not sure what we were planning on using it for, 
# but here's a button, if you ever need one.
# --------------------------------
# button = Button(description='Click Me!')

# display(button, output)

# def on_button_clicked(button):
#     with output:
#         print('Button clicked')

# button.on_click(on_button_clicked)
# --------------------------------