# Third Time's the Charm? StyleGAN3 Inference Notebook

## Prepare Environment and Download Code

In [None]:
#@title Clone Repo and Install Ninja { display-mode: "form" }
# Environment setup: fetch the project code, install the ninja build tool and
# two extra Python packages, then make the cloned repo the working directory.
# The working directory is hard-coded to /content below.
import os
from pathlib import Path

os.chdir('/content')
# Directory the repository is cloned into; later cells also pass it to Downloader.
CODE_DIR = 'stylegan3-editing'

# Clone the repository into CODE_DIR.
!git clone https://github.com/ffletcherr/stylegan3-edit-api.git $CODE_DIR

# Install the ninja v1.8.2 release binary into /usr/local/bin and register it
# as the system `ninja`.
# NOTE(review): presumably needed to build the project's custom ops — confirm.
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force

# Install additional Python packages (versions are not pinned).
!pip install pyrallis
!pip install git+https://github.com/openai/CLIP.git

# Relative to /content (set above): later cells use paths relative to the repo root.
os.chdir(f'./{CODE_DIR}')

In [None]:
#@title Import Packages { display-mode: "form" }
import time
import sys
import pprint
import numpy as np
from PIL import Image
import dataclasses
import torch
import torchvision.transforms as transforms


from editing.interfacegan.face_editor import FaceEditor
from editing.styleclip_global_directions import edit as styleclip_edit
from models.stylegan3.model import GeneratorType
from utils.common import tensor2im
from utils.inference_utils import run_on_batch, load_encoder, get_average_image

%load_ext autoreload
%autoreload 2

In [None]:
import os
from oauth2client.client import GoogleCredentials
import dlib
import subprocess

from utils.alignment_utils import align_face, crop_face, get_stylegan_transform


# Google Drive locations of the pretrained ReStyle encoders, keyed by experiment type.
# In each entry, "id" is the Drive file id handed to Downloader.download_file and
# "name" is the file name the download is saved under.
ENCODER_PATHS = {
    "restyle_e4e_ffhq": {"id": "1z_cB187QOc6aqVBdLvYvBjoc93-_EuRm", "name": "restyle_e4e_ffhq.pt"},
    "restyle_pSp_ffhq": {"id": "12WZi2a9ORVg-j6d9x4eF-CKpLaURC2W-", "name": "restyle_pSp_ffhq.pt"},
}
# Drive locations of the InterFaceGAN boundary files, keyed by edit direction
# (the keys match the options of the `edit_direction` form field).
INTERFACEGAN_PATHS = {
    "age": {'id': '1NQVOpKX6YZKVbz99sg94HiziLXHMUbFS', 'name': 'age_boundary.npy'},
    "smile": {'id': '1KgfJleIjrKDgdBTN4vAz0XlgSaa9I99R', 'name': 'Smiling_boundary.npy'},
    "pose": {'id': '1nCzCR17uaMFhAjcg6kFyKnCCxAKOCT2d', 'name': 'pose_boundary.npy'},
    "Male": {'id': '18dpXS5j1h54Y3ah5HaUpT03y58Ze2YEY', 'name': 'Male_boundary.npy'}
}
# Drive locations of the auxiliary files downloaded for StyleCLIP global directions.
STYLECLIP_PATHS = {
    "delta_i_c": {"id": "1HOUGvtumLFwjbwOZrTbIloAwBBzs2NBN", "name": "delta_i_c.npy"},
    "s_stats": {"id": "1FVm_Eh7qmlykpnSBN1Iy533e_A2xM78z", "name": "s_stats"},
}


class Downloader:
    """Download files from Google Drive into a sub-directory of the project.

    Files are saved under ``<parent of cwd>/<code_dir>/<subdir>``; the directory
    is created when the instance is constructed.

    Two download paths exist: a PyDrive client (``self.drive``) or a ``wget``
    shell command. This class does not create or authenticate a PyDrive client
    itself; ``drive`` stays ``None`` unless the caller assigns one. When PyDrive
    is requested but no client is attached, the download falls back to ``wget``
    instead of failing on a missing attribute.
    """

    def __init__(self, code_dir, use_pydrive, subdir):
        """Resolve and create the save directory.

        Args:
            code_dir: Name of the project directory, a sibling of the current
                working directory's parent.
            use_pydrive: Whether downloads should prefer the PyDrive client.
            subdir: Path below ``code_dir`` in which files are stored.
        """
        self.use_pydrive = use_pydrive
        # Placeholder for an authenticated PyDrive client; must be set by the
        # caller before PyDrive downloads can be used.
        self.drive = None
        current_directory = os.getcwd()
        self.save_dir = os.path.join(os.path.dirname(current_directory), code_dir, subdir)
        os.makedirs(self.save_dir, exist_ok=True)

    def download_file(self, file_id, file_name):
        """Download the Drive file ``file_id`` to ``<save_dir>/<file_name>``.

        Does nothing (besides printing a notice) when the destination already
        exists. Always returns ``None``. A failing ``wget`` command is reported
        with a printed warning rather than an exception, so callers that verify
        the file afterwards keep working.
        """
        file_dst = os.path.join(self.save_dir, file_name)
        if os.path.exists(file_dst):
            print(f'{file_name} already exists!')
            return

        # getattr keeps instances created without __init__ from raising here.
        drive = getattr(self, 'drive', None)
        if self.use_pydrive and drive is not None:
            downloaded = drive.CreateFile({'id': file_id})
            downloaded.FetchMetadata(fetch_all=True)
            downloaded.GetContentFile(file_dst)
            return

        if self.use_pydrive:
            print('PyDrive was requested but no authenticated client is available; '
                  'falling back to wget.')
        command = self._get_download_model_command(file_id=file_id, file_name=file_name)
        completed = subprocess.run(command, shell=True, stdout=subprocess.PIPE)
        if completed.returncode != 0:
            print(f'Warning: download command for {file_name} exited with status '
                  f'{completed.returncode}.')

    def _get_download_model_command(self, file_id, file_name):
        """ Get the wget command that downloads the Drive file and saves it as <save_dir>/<file_name>. """
        url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=self.save_dir)
        return url


def download_dlib_models():
    """Fetch and unpack the dlib 68-point landmark model into the cwd if it is missing.

    Returns immediately when ``shape_predictor_68_face_landmarks.dat`` already
    exists in the current working directory. Otherwise runs ``wget`` and
    ``bzip2`` through the shell; their exit codes are not checked.
    """
    landmarks_file = "shape_predictor_68_face_landmarks.dat"
    if os.path.exists(landmarks_file):
        return
    print('Downloading files for aligning face image...')
    shell_steps = (
        'wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2',
        'bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2',
    )
    for shell_step in shell_steps:
        os.system(shell_step)
    print('Done.')


def run_alignment(image_path):
    """Align the face found in ``image_path`` using dlib landmarks.

    Ensures the dlib landmark model is present, then delegates to
    ``align_face`` and returns its result unchanged.
    NOTE(review): later cells call ``.resize`` and ``.save`` on the result, so
    it is presumably a PIL image — confirm against ``align_face``.
    """
    download_dlib_models()
    face_detector = dlib.get_frontal_face_detector()
    landmark_predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
    print("Aligning image...")
    aligned = align_face(
        filepath=str(image_path),
        detector=face_detector,
        predictor=landmark_predictor,
    )
    print(f"Finished aligning image: {image_path}")
    return aligned


def crop_image(image_path):
    """Crop the face found in ``image_path`` using dlib landmarks.

    Ensures the dlib landmark model is present, then delegates to
    ``crop_face`` and returns its result unchanged.
    NOTE(review): later cells call ``.resize`` and ``.save`` on the result, so
    it is presumably a PIL image — confirm against ``crop_face``.
    """
    download_dlib_models()
    face_detector = dlib.get_frontal_face_detector()
    landmark_predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
    print("Cropping image...")
    cropped = crop_face(
        filepath=str(image_path),
        detector=face_detector,
        predictor=landmark_predictor,
    )
    print(f"Finished cropping image: {image_path}")
    return cropped


def compute_transforms(aligned_path, cropped_path):
    """Compute the landmarks-based transform between the cropped and aligned images.

    Calls ``get_stylegan_transform`` on the two image paths. When it returns
    ``None`` a failure message is printed and ``None`` is returned; otherwise
    the result is unpacked into four values and only the last one, the inverse
    transform, is returned.
    """
    download_dlib_models()
    face_detector = dlib.get_frontal_face_detector()
    landmark_predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
    print("Computing landmarks-based transforms...")
    transform_info = get_stylegan_transform(
        str(cropped_path), str(aligned_path), face_detector, landmark_predictor
    )
    print("Done!")
    if transform_info is not None:
        # Only the inverse transform is needed by the callers of this helper.
        _rotation_angle, _translation, _transform, inverse_transform = transform_info
        return inverse_transform
    print(f"Failed computing transforms on: {cropped_path}")
    return None


## Define Download Configuration
Select below whether you wish to download all models using `pydrive`. Note that if you do not use `pydrive`, you may encounter a "quota exceeded" error from Google Drive.

In [None]:
#@title { display-mode: "form" }
# Form toggle: whether the Downloader should use its PyDrive code path.
download_with_pydrive = False #@param {type:"boolean"}
# Downloader for the encoder checkpoint; files are saved in the
# "pretrained_models" sub-directory of CODE_DIR.
downloader = Downloader(code_dir=CODE_DIR,
                        use_pydrive=download_with_pydrive,
                        subdir="pretrained_models")

## Select Model for Inference
Currently, we have ReStyle-pSp and ReStyle-e4e encoders trained for human faces.

In [None]:
#@title Select which model/domain you wish to perform inference on: { display-mode: "form" }
# Used as the key into ENCODER_PATHS and EXPERIMENT_DATA_ARGS in later cells.
experiment_type = 'restyle_pSp_ffhq' #@param ['restyle_e4e_ffhq', 'restyle_pSp_ffhq']

## Define Inference Parameters

Below we have a dictionary defining parameters such as the path to the pretrained model to use and the path to the image to perform inference on. While we provide default values to run this script, feel free to change as needed.

In [None]:
def _ffhq_transform():
    """Build the input preprocessing: resize to 256x256, convert to a tensor, normalize with mean/std 0.5."""
    return transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])


# Per-experiment settings: checkpoint location, input image and preprocessing.
# Each experiment gets its own transform instance.
EXPERIMENT_DATA_ARGS = {
    "restyle_pSp_ffhq": dict(
        model_path="./pretrained_models/restyle_pSp_ffhq.pt",
        image_path="./images/face_image.jpg",
        transform=_ffhq_transform(),
    ),
    "restyle_e4e_ffhq": dict(
        model_path="./pretrained_models/restyle_e4e_ffhq.pt",
        image_path="./images/face_image.jpg",
        transform=_ffhq_transform(),
    ),
}

# Settings of the experiment selected in the form above.
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]

## Download Models
To avoid unnecessary download requests, we first check whether the model was already downloaded and saved.
If it was not, we download the model for the selected experiment and save it to the folder `stylegan3-editing/pretrained_models`.

We also need to verify that the model was downloaded correctly.
Note that if the downloaded file is only a few KB in size, you most likely encountered a "quota exceeded" error from Google Drive.

In [None]:
#@title Download ReStyle SG3 Encoder { display-mode: "form" }
# Checkpoints smaller than this many bytes are treated as failed downloads
# (e.g. an error page saved in place of the model after a Drive quota error).
MIN_MODEL_SIZE_BYTES = 1000000


def _is_valid_checkpoint(path):
    """Return True when ``path`` exists and is at least MIN_MODEL_SIZE_BYTES in size."""
    return os.path.exists(path) and os.path.getsize(path) >= MIN_MODEL_SIZE_BYTES


if _is_valid_checkpoint(EXPERIMENT_ARGS['model_path']):
    print(f'Model for {experiment_type} already exists!')
else:
    print(f'Downloading ReStyle encoder model: {experiment_type}...')
    # Downloader.download_file skips files that already exist, so a truncated
    # leftover from an earlier failed attempt must be removed to allow a retry.
    if os.path.exists(EXPERIMENT_ARGS['model_path']):
        os.remove(EXPERIMENT_ARGS['model_path'])
    try:
        downloader.download_file(file_id=ENCODER_PATHS[experiment_type]['id'],
                                 file_name=ENCODER_PATHS[experiment_type]['name'])
    except Exception as e:
        # Chain the original exception so its traceback is not lost.
        raise ValueError(f"Unable to download model correctly! {e}") from e
    # if google drive receives too many requests, we'll reach the quota limit and be unable to download the model
    # (a missing file is reported the same way instead of raising FileNotFoundError)
    if not _is_valid_checkpoint(EXPERIMENT_ARGS['model_path']):
        raise ValueError("Pretrained model was unable to be downloaded correctly!")
    print('Done.')


## Load Pretrained Model
We assume that you have downloaded all relevant models and placed them in the directory defined by the
`EXPERIMENT_DATA_ARGS` dictionary.

In [None]:
#@title Load ReStyle SG3 Encoder { display-mode: "form" }
model_path = EXPERIMENT_ARGS['model_path']
# load_encoder returns the network together with the options stored for it.
net, opts = load_encoder(checkpoint_path=model_path)
# opts is a dataclass instance; print its fields for inspection.
pprint.pprint(dataclasses.asdict(opts))

## Prepare Inputs



In [None]:
#@title Define and Visualize Input { display-mode: "form" }

# Path of the input image configured for the selected experiment.
image_path = Path(EXPERIMENT_DATA_ARGS[experiment_type]["image_path"])
original_image = Image.open(image_path).convert("RGB")
# Downscale to 256x256 for the preview shown as this cell's output.
original_image = original_image.resize((256, 256))
original_image

In [None]:
#@title Get Aligned and Cropped Input Images
# Produce the aligned and the cropped version of the input face image.
input_image = run_alignment(image_path)
cropped_image = crop_image(image_path)
# Preview: both images resized to 256x256 and concatenated along axis 1.
joined = np.concatenate([input_image.resize((256, 256)), cropped_image.resize((256, 256))], axis=1)
Image.fromarray(joined)

In [None]:
#@title Compute Landmarks-Based Transforms { display-mode: "form" }
# get_stylegan_transform works on file paths, so the aligned and cropped
# images are written to disk first.
images_dir = Path("./images")
images_dir.mkdir(exist_ok=True, parents=True)
cropped_path = images_dir / f"cropped_{image_path.name}"
aligned_path = images_dir / f"aligned_{image_path.name}"
cropped_image.save(cropped_path)
input_image.save(aligned_path)
landmarks_transform = compute_transforms(aligned_path=aligned_path, cropped_path=cropped_path)
# compute_transforms returns None on failure. Every later cell converts or
# passes this value on, so fail here with a clear message instead of an opaque
# error further down.
if landmarks_transform is None:
    raise ValueError(f"Unable to compute landmarks-based transforms for: {image_path}")

## Perform Inversion
Now we'll run inference. By default, we'll run using 3 inference steps. You can change the parameter in the cell below.

In [None]:
#@title { display-mode: "form" }
# Number of inference steps, set through the form and stored on the encoder options.
n_iters_per_batch = 3 #@param {type:"integer"}
opts.n_iters_per_batch = n_iters_per_batch
opts.resize_outputs = False  # generate outputs at full resolution

In [None]:
#@title Run Inference { display-mode: "form" }
# Apply the selected experiment's preprocessing to the aligned image.
img_transforms = EXPERIMENT_ARGS['transform']
transformed_image = img_transforms(input_image)

avg_image = get_average_image(net)

# Inference only, so gradient tracking is disabled. The inputs and the
# landmarks transform are moved with .cuda(), so this cell needs a CUDA device.
# unsqueeze(0) adds the batch dimension for the single input image.
with torch.no_grad():
    tic = time.time()
    result_batch, result_latents = run_on_batch(inputs=transformed_image.unsqueeze(0).cuda().float(),
                                                net=net,
                                                opts=opts,
                                                avg_image=avg_image,
                                                landmarks_transform=torch.from_numpy(landmarks_transform).cuda().float())
    toc = time.time()
    print('Inference took {:.4f} seconds.'.format(toc - tic))

In [None]:
#@title Visualize Result { display-mode: "form" }

def get_coupled_results(result_batch, cropped_image):
    """Return one image showing the cropped input next to the final reconstruction.

    Reads the notebook-level ``opts``: both halves are sized 256x256 when
    ``opts.resize_outputs`` is set, otherwise ``opts.output_size`` squared.
    """
    # The batch holds a single image; take its outputs.
    per_step_outputs = result_batch[0]
    if opts.resize_outputs:
        tile_size = (256, 256)
    else:
        tile_size = (opts.output_size, opts.output_size)
    # The last entry is used as the final reconstruction.
    reconstruction = tensor2im(per_step_outputs[-1]).resize(tile_size)
    source = cropped_image.resize(tile_size)
    side_by_side = np.concatenate([np.array(source), np.array(reconstruction)], axis=1)
    return Image.fromarray(side_by_side)

# Build the input/reconstruction pair and display it at 1024x512.
# resize returns a new image, so `res` itself keeps its original size.
res = get_coupled_results(result_batch, cropped_image)
res.resize((1024, 512))

In [None]:
#@title Save Result { display-mode: "form" }

# save image
# The output file reuses the input image's file name inside ./outputs.
# NOTE(review): `res` is rebound by the later "Show Result" editing cell, so
# re-running this cell after editing saves that image instead.
outputs_path = "./outputs"
os.makedirs(outputs_path, exist_ok=True)
res.save(os.path.join(outputs_path, os.path.basename(image_path)))

# Editing
Given the resulting latent code obtained above, we can perform various edits, which we demonstrate below.

In [None]:
#@title Download pretrained boundaries and editing files: { display-mode: "form" }
# Defaults to the wget-based download, like the encoder download cell: this
# notebook never creates an authenticated PyDrive client, so the PyDrive path
# cannot work unless a client is attached to the Downloader manually.
download_with_pydrive = False #@param {type:"boolean"}

# download files for interfacegan
downloader = Downloader(code_dir=CODE_DIR,
                        use_pydrive=download_with_pydrive,
                        subdir="editing/interfacegan/boundaries/ffhq")
print("Downloading InterFaceGAN boundaries...")
for editing_file, params in INTERFACEGAN_PATHS.items():
    print(f"Downloading {editing_file} boundary...")
    downloader.download_file(file_id=params['id'],
                             file_name=params['name'])

# download files for styleclip
downloader = Downloader(code_dir=CODE_DIR,
                        use_pydrive=download_with_pydrive,
                        subdir="editing/styleclip_global_directions/sg3-r-ffhq-1024")
print("Downloading StyleCLIP auxiliary files...")
for editing_file, params in STYLECLIP_PATHS.items():
    print(f"Downloading {editing_file}...")
    downloader.download_file(file_id=params['id'],
                             file_name=params['name'])

### InterFaceGAN

In [None]:
# Editor built on the decoder of the loaded encoder network.
editor = FaceEditor(stylegan_generator=net.decoder, generator_type=GeneratorType.ALIGNED)

#@title Select which edit you wish to perform: { display-mode: "form" }
# edit_direction options match the keys of INTERFACEGAN_PATHS; min_value and
# max_value are passed to the editor as its factor range in the next cell.
edit_direction = 'age' #@param ['age', 'smile', 'pose', 'Male']
min_value = -5 #@param {type:"slider", min:-10, max:10, step:1}
max_value = 5 #@param {type:"slider", min:-10, max:10, step:1}

In [None]:
#@title Perform Edit! { display-mode: "form" }
print(f"Performing edit for {edit_direction}...")
# Take the last latent of the first (only) input, add a batch dimension and
# move it to the GPU.
input_latent = torch.from_numpy(result_latents[0][-1]).unsqueeze(0).cuda()
edit_images, edit_latents = editor.edit(latents=input_latent,
                                        direction=edit_direction,
                                        factor_range=(min_value, max_value),
                                        user_transforms=landmarks_transform,
                                        apply_user_transformations=True)
print("Done!")

In [None]:
#@title Show Result { display-mode: "form" }
def prepare_edited_result(edit_images):
    """Join the edited images into one horizontal strip.

    Args:
        edit_images: Non-empty sequence of images exposing ``resize``, or a
            sequence of lists of such images, in which case the first image
            of every inner list is used.

    Returns:
        An RGB image with all inputs resized to 512x512 and placed side by side.
    """
    if isinstance(edit_images[0], list):
        edit_images = [image[0] for image in edit_images]
    # Convert every tile once and concatenate in a single call, rather than
    # re-copying the growing strip for every image.
    tiles = [np.array(image.resize((512, 512))) for image in edit_images]
    strip = np.concatenate(tiles, axis=1)
    return Image.fromarray(strip).convert("RGB")

# Build the strip of edited images and display it as the cell output.
# This rebinds `res`, which earlier held the inversion result.
res = prepare_edited_result(edit_images)
res

### StyleCLIP (Global Directions)

In [None]:
#@title Prepare StyleCLIP Editor { display-mode: "form" }
styleclip_args = styleclip_edit.EditConfig()
global_direction_calculator = styleclip_edit.load_direction_calculator(stylegan_model=net.decoder, opts=styleclip_args)

In [None]:
#@title Select which edit you wish to perform. { display-mode: "form" }
#@markdown We recommend keeping the neutral_text as is by default.
# These form values are copied into the StyleCLIP edit configuration in the
# next cell: alpha and beta are each used as both the min and max of their range.
neutral_text = "a face" #@param {type:"raw"}
target_text = "a smiling face" #@param {type:"raw"}
alpha = 4 #@param {type:"slider", min:-5, max:5, step:0.5}
beta = 0.13 #@param {type:"slider", min:-1, max:1, step:0.1}

In [None]:
#@title Perform Edit! { display-mode: "form" }
opts = styleclip_edit.EditConfig()
opts.alpha_min = alpha
opts.alpha_max = alpha
opts.num_alphas = 1
opts.beta_min = beta
opts.beta_max = beta
opts.num_betas = 1
opts.neutral_text = neutral_text
opts.target_text = target_text

input_latent = result_latents[0][-1]
input_transforms = torch.from_numpy(landmarks_transform).cpu().numpy()
print(f'Performing edit for: "{opts.target_text}"...')
edit_res, edit_latent = styleclip_edit.edit_image(latent=input_latent,
                                                  landmarks_transform=input_transforms,
                                                  stylegan_model=net.decoder,
                                                  global_direction_calculator=global_direction_calculator,
                                                  opts=opts,
                                                  image_name=None,
                                                  save=False)
print("Done!")

In [None]:
#@title Show Result { display-mode: "form" }
# Left: the preprocessed input fed to the encoder; right: the first edited result.
input_im = tensor2im(transformed_image).resize((512, 512))
edited_im = tensor2im(edit_res[0]).resize((512, 512))
# Place both 512x512 images side by side and display them at 1024x512.
edit_coupled = np.concatenate([np.array(input_im), np.array(edited_im)], axis=1)
edit_coupled = Image.fromarray(edit_coupled)
edit_coupled.resize((1024, 512))