# **How to use this notebook:**

Using this Colab, you can project a portrait photo* into latent space.

First, run 1. and 1.1 (this cell will print the list of available networks)   
If you have a saved state in an npz file, use 2.b; otherwise, if you have an image, paste the path to the file (or a URL) into 2.a. The seed is used to randomize the final result, which will not be particularly good most of the time, but that's OK. You can change it later.   
In 3., you choose a network for the secondary projection. The result will generally be awful; that's why you need to generate a new seed to copy a style from. You can use a specific seed or leave it blank for a random seed. After the preview, you can rerun 3.4 with "save_projected_image" checked to save an image.




----
* You can technically try to project anything, but the results are on you. To work around the alignment error when your image doesn't contain a face, just ignore the error: paste the path to the file (preferably a jpg) into _face_file_path_, run the cell, then rename your file to "original_name-aligned.jpg".

In [None]:
#@title 1. Installs
from IPython.display import clear_output 
# Everything below is downloaded into /content/.
%cd /content/
# stylegan2 is the working directory of the alignment cell; stylegan2-ada-pytorch
# provides projector.py and generate.py used by the later cells.
!git clone https://github.com/NVlabs/stylegan2.git
!git clone https://github.com/NVlabs/stylegan2-ada-pytorch/
# presumably needed to build the PyTorch repository's custom ops — TODO confirm
!pip install ninja

# dlib's 68-point landmark model, the default predictor of detect_face_landmarks().
!wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
# -d decompresses, -k keeps the downloaded archive next to the extracted .dat file.
!bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2

# Hide the installation logs and leave only a short confirmation.
clear_output()
print('Done!')

In [None]:
#@title # 1.1 Neural Networks handling
%cd /content/

import json

!curl https://raw.githubusercontent.com/dobrosketchkun/wd_network_zoo/main/wd_networks.json -O wd_networks.json


with open('/content/wd_networks.json', 'r') as f:
    jnetw = json.loads(f.read())

nn_list = list(jnetw.keys())

network_name = '' #@#param {type:"string"}
download_all = True #@#param {type:"boolean"}

if download_all:
    for nn in nn_list:
        url = jnetw[nn]
        id =  url.split('/')[-2]
        network_f_name = nn + '.pkl'
        !gdown --id $id -O $network_f_name
    clear_output()
    network_name = 'ffhq'
    print('Network:', network_name)
else:
    if not network_name:
        print('List of available networks:')
        print(', '.join(nn_list))
    elif network_name not in nn_list:
        print('Error, You can not use {}! Only names from this list are available:'.format(network_name))
        print(', '.join(nn_list))
    else:
        clear_output()
        print('Network:', network_name)
        url = jnetw[network_name]
        id =  url.split('/')[-2]
        network_f_name = network_name + '.pkl'
        !gdown --id $id -O $network_f_name

clear_output()

print('All networks are downloded.')
print('List of available networks:')
print(', '.join(nn_list))


In [None]:
#@title # 2.a Alignment
%cd /content/
clear_output()

face_file_path = 'https://hungarytoday.hu/wp-content/uploads/2020/06/Hide-the-Pain-Harold-prof..jpg' #@param {type:"string"}

import os.path



if os.path.isfile(face_file_path):
    pass
else:
    face_file_path = '"' + face_file_path + '"'
    !wget $face_file_path -O image.imgs
    face_file_path = '/content/image.imgs'
    # clear_output()



import PIL
from PIL import Image

im = Image.open(face_file_path)
rgb_im = im.convert('RGB')
face_file_path = face_file_path.split('.')[0] + '.jpg'
rgb_im.save(face_file_path, quality=90)#, subsampling=0)


# The aligned image is later written below this directory (realign1024x1024/...).
%cd /content/stylegan2
# NOTE(review): nothing in this cell appears to import TensorFlow — confirm
# whether this magic is still needed.
%tensorflow_version 1.x
# !nvcc test_nvcc.cu -o test_nvcc -run
import sys
import os
import dlib
import glob

def detect_face_landmarks(face_file_path=None,
                          predictor_path=None,
                          img=None):
  """Detect every face in an image and predict its landmarks with dlib.

  Args:
    face_file_path: path of the image to load; only read when `img` is None.
    predictor_path: path of the dlib shape-predictor model. Defaults to
      '/content/shape_predictor_68_face_landmarks.dat'.
    img: already loaded image to analyse instead of reading `face_file_path`.

  Returns:
    A `dlib.full_object_detections` holding one landmark shape per detected
    face (empty when no face is found).

  Raises:
    ValueError: if neither `face_file_path` nor `img` is given.
  """
  # References:
  # -   http://dlib.net/face_landmark_detection.py.html
  # -   http://dlib.net/face_alignment.py.html

  if img is None and face_file_path is None:
    raise ValueError('Either face_file_path or img must be provided.')

  if predictor_path is None:
    predictor_path = '/content/shape_predictor_68_face_landmarks.dat'

  # Load all the models we need: a detector to find the faces, a shape predictor
  # to find face landmarks so we can precisely localize the face
  detector = dlib.get_frontal_face_detector()
  shape_predictor = dlib.shape_predictor(predictor_path)

  if img is None:
    # Load the image using Dlib
    img = dlib.load_rgb_image(face_file_path)

  # Ask the detector to find the bounding boxes of each face. The 1 in the
  # second argument indicates that we should upsample the image 1 time. This
  # will make everything bigger and allow us to detect more faces.
  dets = detector(img, 1)

  # Find the face landmarks we need to do the alignment.
  faces = dlib.full_object_detections()
  for d in dets:
    faces.append(shape_predictor(img, d))

  return faces
  
  ################
  
# Landmarks of every face found in the re-encoded input image; reused below for
# the preview plot and for the alignment step.
faces = detect_face_landmarks(face_file_path=face_file_path)
  
  ################
  
import collections
import matplotlib.pyplot as plt
import numpy as np

# Matplotlib keyword arguments shared by every landmark polyline.
plot_style = {
    'marker': 'o',
    'markersize': 4,
    'linestyle': '-',
    'lw': 2,
}

# A landmark group: which rows of the landmark array it covers and its RGBA colour.
pred_type = collections.namedtuple('prediction_type', 'slice color')

# (group name, first landmark index, end index (exclusive), RGBA colour)
_landmark_groups = (
    ('face',     0,  17, (0.682, 0.780, 0.909, 0.5)),
    ('eyebrow1', 17, 22, (1.0, 0.498, 0.055, 0.4)),
    ('eyebrow2', 22, 27, (1.0, 0.498, 0.055, 0.4)),
    ('nose',     27, 31, (0.345, 0.239, 0.443, 0.4)),
    ('nostril',  31, 36, (0.345, 0.239, 0.443, 0.4)),
    ('eye1',     36, 42, (0.596, 0.875, 0.541, 0.3)),
    ('eye2',     42, 48, (0.596, 0.875, 0.541, 0.3)),
    ('lips',     48, 60, (0.596, 0.875, 0.541, 0.3)),
    ('teeth',    60, 68, (0.596, 0.875, 0.541, 0.4)),
)

pred_types = {
    group_name: pred_type(slice(first, end), rgba)
    for group_name, first, end, rgba in _landmark_groups
}



def display_landmarks_raw(input_img, preds=None, fig_size=None):
  """Show an image and, optionally, draw landmark groups on top of it.

  Args:
    input_img: image passed to `ax.imshow`.
    preds: optional 2-D array indexed as preds[rows, 0] (x) and preds[rows, 1]
      (y); each group in `pred_types` plots the rows selected by its slice.
    fig_size: matplotlib figure size; defaults to `plt.figaspect(.5)`.

  Returns:
    None. The figure is created as a side effect.
  """
  # Plotting code adapted from:
  # https://github.com/1adrianb/face-alignment/blob/master/examples/detect_landmarks_in_image.py

  if fig_size is None:
    fig_size = plt.figaspect(.5)

  fig = plt.figure(figsize=fig_size)
  ax = fig.add_subplot(1, 1, 1) # only display one image
  ax.imshow(input_img)

  if preds is not None:
    # `group` rather than `pred_type`, so the module-level namedtuple of that
    # name is not shadowed inside this function.
    for group in pred_types.values():
      ax.plot(preds[group.slice, 0],
              preds[group.slice, 1],
              color=group.color, **plot_style)

  ax.axis('off')
from skimage import io

def display_landmarks(image_name, 
                      dlib_output_faces=None, 
                      face_no=0,
                      fig_size=None):
  """Plot an image file together with the landmarks of one detected face.

  Args:
    image_name: path of the image, read with `skimage.io.imread`.
    dlib_output_faces: landmark shapes as returned by `detect_face_landmarks`;
      computed from the image when None.
    face_no: index of the face whose landmarks are drawn.
    fig_size: matplotlib figure size; defaults to [15, 15].

  Returns:
    None. When `face_no` is out of range a message is printed and the image is
    shown without landmarks.
  """
  if fig_size is None:
    fig_size = [15, 15]

  input_img = io.imread(image_name)

  if dlib_output_faces is None:
    dlib_output_faces = detect_face_landmarks(face_file_path=image_name,
                                              img=input_img)

  try:
    current_face = dlib_output_faces[face_no]
  except IndexError:
    current_face = None
    # The degree sign used to be mis-encoded as 'Â°' in this message.
    print('No face found for index n°{} (max={}).'.format(
        face_no, 
        len(dlib_output_faces)-1,
        ))

  if current_face is None:
    preds = None
  else:
    # One (x, y) row per landmark point.
    preds = np.array([
                      [v.x, v.y] 
                      for v in current_face.parts()
                      ])

  display_landmarks_raw(input_img=input_img, 
                        preds=preds,
                        fig_size=fig_size)

# Preview the landmarks of the first detected face on the input image.
face_no=0
fig_size=[15,15]

display_landmarks(image_name=face_file_path,
                  dlib_output_faces=faces,
                  face_no=face_no,
                  fig_size=fig_size)


################
# Destination of the aligned image: the input path with '-aligned' inserted
# before the extension. splitext only strips the final extension, whereas
# split('.')[0] cut the path at the first dot anywhere in it.
ffhq_aligned_image_name = os.path.splitext(face_file_path)[0] + '-aligned.jpg'
################

import os
import sys
import requests
import html
import hashlib
import PIL.Image
import PIL.ImageFile
import numpy as np
import scipy.ndimage
import threading
import queue
import time
import json
import uuid
import glob
import argparse
import itertools
import shutil
from collections import OrderedDict, defaultdict

# Reference: https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py

def recreate_aligned_images(json_data, 
                            dst_dir='realign1024x1024',
                            output_size=1024, 
                            transform_size=4096, 
                            enable_padding=True):
    """Crop, rotate and rescale each described image around its face landmarks.

    Adapted from download_ffhq.py of the FFHQ dataset repository.

    Args:
        json_data: mapping whose values provide
            item['in_the_wild']['face_landmarks'] (sliced below as 68 (x, y)
            points) and item['in_the_wild']['file_path'] (image to open).
        dst_dir: root directory of the output images.
        output_size: side, in pixels, of the saved square image.
        transform_size: side of the intermediate quad transform, which is
            downscaled to `output_size` when larger.
        enable_padding: when True, reflect-pad and blur the borders if the crop
            rectangle reaches outside the source image.

    Returns:
        None. Item number i (in iteration order) is saved as
        <dst_dir>/<i rounded down to a multiple of 1000, 5 digits>/<i, 5 digits>.png.
    """
    # print('Recreating aligned images...')
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)

    for item_idx, item in enumerate(json_data.values()):
        # print('\r%d / %d ... ' % (item_idx, len(json_data)), end='', flush=True)

        # Parse landmarks.
        # pylint: disable=unused-variable
        lm = np.array(item['in_the_wild']['face_landmarks'])
        lm_chin          = lm[0  : 17]  # left-right
        lm_eyebrow_left  = lm[17 : 22]  # left-right
        lm_eyebrow_right = lm[22 : 27]  # left-right
        lm_nose          = lm[27 : 31]  # top-down
        lm_nostrils      = lm[31 : 36]  # top-down
        lm_eye_left      = lm[36 : 42]  # left-clockwise
        lm_eye_right     = lm[42 : 48]  # left-clockwise
        lm_mouth_outer   = lm[48 : 60]  # left-clockwise
        lm_mouth_inner   = lm[60 : 68]  # left-clockwise

        # Calculate auxiliary vectors.
        eye_left     = np.mean(lm_eye_left, axis=0)
        eye_right    = np.mean(lm_eye_right, axis=0)
        eye_avg      = (eye_left + eye_right) * 0.5
        eye_to_eye   = eye_right - eye_left
        mouth_left   = lm_mouth_outer[0]
        mouth_right  = lm_mouth_outer[6]
        mouth_avg    = (mouth_left + mouth_right) * 0.5
        eye_to_mouth = mouth_avg - eye_avg

        # Choose oriented crop rectangle.
        # x is the half-width vector of the crop, y the perpendicular half-height
        # vector, c the centre; quad lists the four corners.
        x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
        x /= np.hypot(*x)
        x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
        y = np.flipud(x) * [-1, 1]
        c = eye_avg + eye_to_mouth * 0.1
        quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
        qsize = np.hypot(*x) * 2

        # Load in-the-wild image.
        src_file = item['in_the_wild']['file_path']
        img = PIL.Image.open(src_file)

        # Shrink.
        shrink = int(np.floor(qsize / output_size * 0.5))
        if shrink > 1:
            rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
            # LANCZOS is the filter ANTIALIAS was an alias of; the ANTIALIAS
            # name no longer exists in Pillow 10+.
            img = img.resize(rsize, PIL.Image.LANCZOS)
            quad /= shrink
            qsize /= shrink

        # Crop.
        border = max(int(np.rint(qsize * 0.1)), 3)
        crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
        crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
        if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
            img = img.crop(crop)
            quad -= crop[0:2]

        # Pad.
        pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
        pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
        if enable_padding and max(pad) > border - 4:
            pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
            # The three-axis padding assumes an H x W x C image (the caller in
            # this notebook passes an RGB JPEG).
            img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
            h, w, _ = img.shape
            # x and y are reused here as pixel coordinate grids.
            y, x, _ = np.ogrid[:h, :w, :1]
            mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
            blur = qsize * 0.02
            # Blend towards a blurred copy, then towards the median colour,
            # weighted by the mask computed above.
            img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
            img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
            img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
            quad += pad[:2]

        # Transform.
        img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
        if output_size < transform_size:
            img = img.resize((output_size, output_size), PIL.Image.LANCZOS)

        # Save aligned image.
        dst_subdir = os.path.join(dst_dir, '%05d' % (item_idx - item_idx % 1000))
        os.makedirs(dst_subdir, exist_ok=True)
        img.save(os.path.join(dst_subdir, '%05d.png' % item_idx))

    # All done.
    # print('\r%d / %d ... done' % (len(json_data), len(json_data)))

    return

################

# The first face which is detected:
# NB: we assume that there is exactly one face per picture!
f = faces[0]

parts = f.parts()

num_face_landmarks=68

v = np.zeros(shape=(num_face_landmarks, 2))
for k, e in enumerate(parts):
  v[k, :] = [e.x, e.y]


json_data = dict()

item_idx = 0

json_data[item_idx] = dict()
json_data[item_idx]['in_the_wild'] = dict()
json_data[item_idx]['in_the_wild']['file_path'] = face_file_path
json_data[item_idx]['in_the_wild']['face_landmarks'] = v

recreate_aligned_images(json_data)

!cp '/content/stylegan2/realign1024x1024/00000/00000.png' $ffhq_aligned_image_name
  

In [None]:
#@title # 2.a.1 Projecting
%cd /content/stylegan2-ada-pytorch/


import os
import re
from typing import List, Optional
import click
import numpy as np
import PIL.Image
import torch
from io import BytesIO
from math import ceil
import argparse
import numpy as np
import re
import sys
from io import BytesIO
import IPython.display
import numpy as np
from math import ceil
from PIL import Image, ImageDraw
import imageio
import os
import pickle
from google.colab import files
from IPython.display import display


out_dir = '/content/out'
save_projected_w = True #@param {type:"boolean"}
# save_video = str(save_video)
network_name_ = '/content/' + network_name + '.pkl'

seed = '' #@param {type:"string"}

# /content/ffhq.pkl

if seed:
    pass
else:
    seed = np.random.randint(2**32-1)

!python projector.py --outdir=$out_dir \
    --target=/$ffhq_aligned_image_name --num-steps=800\
    --network=$network_name_ --save-video=False --seed=410

clear_output()

def createImageGrid(images, scale=0.25, rows=1):
    """Paste PIL images onto one white canvas, filling it row by row.

    Args:
        images: non-empty sequence of PIL images. Every image is resized to
            the scaled size of the first one.
        scale: factor applied to the first image's width and height to get
            the cell size.
        rows: number of grid rows; the column count is
            ceil(len(images) / rows).

    Returns:
        A new RGBA `PIL.Image` containing the grid.
    """
    w, h = images[0].size
    w = int(w*scale)
    h = int(h*scale)
    height = rows*h
    cols = ceil(len(images) / rows)
    width = cols*w
    canvas = PIL.Image.new('RGBA', (width,height), 'white')
    for i, img in enumerate(images):
        # LANCZOS is the filter ANTIALIAS was an alias of; the ANTIALIAS name
        # no longer exists in Pillow 10+.
        img = img.resize((w,h), PIL.Image.LANCZOS)
        canvas.paste(img, (w*(i % cols), h*(i // cols)))
    return canvas


# Offer the projected latent vector as a browser download when requested.
if save_projected_w:
    npz_path = '{}/projected_w.npz'.format(out_dir)
    files.download(npz_path)

# Show the projector's target image next to its projection.
target_png = '{}/target.png'.format(out_dir)
proj_png = '{}/proj.png'.format(out_dir)
orig = PIL.Image.open(target_png)
proj = PIL.Image.open(proj_png)
side_by_side = createImageGrid(images=[orig, proj], scale=0.5, rows=1)
display(side_by_side)


In [None]:
#@title # 2.b If you have a projected_w.npz file
%cd /content/


out_dir = '/content/out'

!mkdir $out_dir

projected_w = '' #@param {type:"string"}

!cp $projected_w /content/out/projected_w.npz

#@title # Load style network

%cd /content/stylegan2-ada-pytorch/
import dnnlib
import legacy
import os
import re
from typing import List, Optional
import click
import numpy as np
import PIL.Image
import torch
from io import BytesIO
from math import ceil
import argparse
import numpy as np
import re
import sys
from io import BytesIO
import IPython.display
import numpy as np
from math import ceil
from PIL import Image, ImageDraw
import imageio
import os
import pickle
from google.colab import files
from IPython.display import display


def createImageGrid(images, scale=0.25, rows=1):
    """Paste PIL images onto one white canvas, filling it row by row.

    Args:
        images: non-empty sequence of PIL images. Every image is resized to
            the scaled size of the first one.
        scale: factor applied to the first image's width and height to get
            the cell size.
        rows: number of grid rows; the column count is
            ceil(len(images) / rows).

    Returns:
        A new RGBA `PIL.Image` containing the grid.
    """
    w, h = images[0].size
    w = int(w*scale)
    h = int(h*scale)
    height = rows*h
    cols = ceil(len(images) / rows)
    width = cols*w
    canvas = PIL.Image.new('RGBA', (width,height), 'white')
    for i, img in enumerate(images):
        # LANCZOS is the filter ANTIALIAS was an alias of; the ANTIALIAS name
        # no longer exists in Pillow 10+.
        img = img.resize((w,h), PIL.Image.LANCZOS)
        canvas.paste(img, (w*(i % cols), h*(i // cols)))
    return canvas



network_name_ = '' #@param {type: "string"}

noise_mode = 'const'
network_pkl = ''.join(['/content/', network_name_, '.pkl'])

# Load the generator and move it to the GPU.
# NOTE(review): .pkl networks are presumably unpickled by legacy.load_network_pkl,
# which can execute code from the file — only load networks you trust.
device = torch.device('cuda')
with dnnlib.util.open_url(network_pkl) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

# Read the saved latent vector from the user-supplied npz file.
ws = torch.tensor(np.load(projected_w)['w'], device=device)
all_w = ws

# Render the latent and convert the result to uint8 pixels: axis 1 is moved to
# the end, then values are scaled by 127.5, offset by 128 and clamped to 0..255.
orig_im_all_images = G.synthesis(all_w, noise_mode=noise_mode)
orig_im_all_images = orig_im_all_images.permute(0, 2, 3, 1)
orig_im_all_images = (orig_im_all_images * 127.5 + 128).clamp(0, 255)
orig_im_all_images = orig_im_all_images.to(torch.uint8).cpu().numpy()

# Save the first rendered image as the reference picture used by later cells.
orig_im = PIL.Image.fromarray(orig_im_all_images[0], 'RGB')
target_png = '{}/target.png'.format(out_dir)
orig_im.save(target_png)

display(createImageGrid([orig_im], scale=0.5, rows=1))

# 3. Copy style

In [None]:
#@title # 3.1 Secondary network preview 

from pathlib import Path

print('List of available networks:')
print(', '.join(nn_list))


network_name_project = '' #@param {type:"string"}
if network_name_project:
    projected_w = out_dir + '/projected_w.npz'
    out_dir2_dump = '/content/projected/'
    out_dir2 = '/content/proj_temp/'

    Path(out_dir2).mkdir(parents=True, exist_ok=True)

    network_name_ = '/content/' + network_name_project + '.pkl'

    print(network_name_)

    %cd /content/stylegan2-ada-pytorch/
    !python generate.py --outdir=$out_dir2 --projected-w=$projected_w \
        --network=$network_name_

    orig = PIL.Image.open('{}/target.png'.format(out_dir))
    proj = PIL.Image.open('{}/proj00.png'.format(out_dir2))
    display(createImageGrid(images=[orig, proj], scale=0.5, rows=1))

    
else:
    raise Exception('Choose some network to project on.')


In [None]:
#@title # 3.2 Load style network
%cd /content/stylegan2-ada-pytorch/

import dnnlib
import legacy

noise_mode = 'const'
# network_name_ already holds the full path of the secondary network's .pkl.
network_pkl = network_name_

device = torch.device('cuda')

# NOTE(review): .pkl networks are presumably unpickled by legacy.load_network_pkl,
# which can execute code from the file — only load networks you trust.
with dnnlib.util.open_url(network_pkl) as pkl_file:
    G = legacy.load_network_pkl(pkl_file)['G_ema'].to(device) # type: ignore


In [None]:
#@title # 3.3 Choose the style
import numpy as np

seed = '' #@param {type:"string"}


def parse_seed(seed_text):
    """Turn the seed form field into a one-element integer array.

    Args:
        seed_text: text typed into the form; blank means "pick a random seed".

    Returns:
        numpy array of shape (1,) holding the seed.

    Raises:
        ValueError: if the text is neither blank nor an integer in the range
            0 .. 2**32 - 1.
    """
    seed_text = str(seed_text).strip()
    if not seed_text:
        return np.random.randint(2**32-1, size=1, dtype=np.int64)
    try:
        value = int(seed_text)
    except ValueError:
        raise ValueError('Seed must be an integer between 0 and 2**32 - 1, got {!r}.'.format(seed_text))
    if not 0 <= value <= 2**32 - 1:
        raise ValueError('Seed must be an integer between 0 and 2**32 - 1, got {!r}.'.format(seed_text))
    # Blankness is decided on the text, so an explicit seed of 0 is kept; the
    # old truthiness test on the parsed array replaced it with a random seed.
    return np.array([value])


seed = parse_seed(seed)
print('Seed of this image is:', seed[0])

truncation_psi = 0.7

# Draw one latent z from the seeded generator and map it to w.
rand = np.random.RandomState(seed).randn(1, G.z_dim)
all_z = np.stack(rand)
z_tensor = torch.from_numpy(all_z).to(device)
all_w = G.mapping(z_tensor, None)

# Truncation: pull w towards the network's average w by truncation_psi.
w_avg = G.mapping.w_avg
w_offset = all_w - w_avg
all_w = w_avg + w_offset * truncation_psi

# Render and convert to uint8 pixels: axis 1 is moved to the end, then values
# are scaled by 127.5, offset by 128 and clamped to 0..255.
all_images = G.synthesis(all_w, noise_mode=noise_mode)
all_images = all_images.permute(0, 2, 3, 1)
all_images = (all_images * 127.5 + 128).clamp(0, 255)
all_images = all_images.to(torch.uint8).cpu().numpy()

im = PIL.Image.fromarray(all_images[0], 'RGB')
createImageGrid([im], scale=0.5, rows=1)

In [None]:
#@title # 3.4 Preview
import uuid

deeper_copy = True #@param {type:"boolean"}
even_deeper_copy = False #@param {type:"boolean"}
save_projected_image = True #@param {type:"boolean"}
save_npz = False #@param {type:"boolean"}


# Start from the projected latent and overwrite some of its layers with the
# style latent produced in 3.3 (all_w).
ws = np.load(projected_w)['w']
ws = torch.tensor(ws, device=device) 

# Layers 8 and above always take the style; 7 and 6 are opt-in. The upper bound
# follows the tensors instead of a hard-coded 17, so networks with fewer than
# 18 layers no longer raise an IndexError.
num_layers = min(ws.shape[1], all_w.shape[1])
style_layers = list(range(8, num_layers))
if deeper_copy:
    style_layers.append(7)
if even_deeper_copy:
    style_layers.append(6)
for layer in style_layers:
    ws[0][layer] = all_w[0][layer]

# Render and convert to uint8 pixels: axis 1 is moved to the end, then values
# are scaled by 127.5, offset by 128 and clamped to 0..255.
all_images = G.synthesis(ws, noise_mode=noise_mode)
all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
im = PIL.Image.fromarray(all_images[0], 'RGB')

# createImageGrid([im], scale=0.5, rows=1)

# Show the original target next to the restyled result.
orig = PIL.Image.open('{}/target.png'.format(out_dir))
proj = im
display(createImageGrid(images=[orig, proj], scale=0.5, rows=1))


###################


out_dir2_dump = '/content/projected/'
Path(out_dir2_dump).mkdir(parents=True, exist_ok=True)

if save_projected_image:
    # A random suffix keeps every saved preview under its own file name.
    filename = network_name_project + '_' + uuid.uuid4().hex + '.jpg'
    im.save(out_dir2_dump +  filename, quality=90)#, subsampling=0)

if save_npz:
    # ws lives on the GPU; numpy cannot convert a CUDA tensor, so move it to
    # the CPU before saving.
    np.savez('/content/' + uuid.uuid4().hex + '_projected_w_ns.npz', w=ws.cpu().numpy())

In [None]:
#@title # 4. Preview all saved images
import glob
import random

%cd $out_dir2_dump
images = glob.glob('*.jpg')
images = list(sorted(images)[::-1])

size = len(images)
in_a_row =  4#@param  {type:"integer"}
scale = 0.4 #@param  {type:"number"}
shuffle = True #@param {type:"boolean"}
if shuffle:
   random.shuffle(images)

images = [orig] + [PIL.Image.open(_) for _ in images]


rows = ceil(size/in_a_row)


display(createImageGrid(images=images, scale=scale, rows=rows))