# HyperStyle Inference Notebook

## Prepare Environment and Download HyperStyle Code

In [None]:
#@title Clone HyperStyle Repo and Install Ninja { display-mode: "form" } 
import os
os.chdir('/content')
CODE_DIR = 'hyperstyle'

## clone repo
!git clone https://github.com/yuval-alaluf/hyperstyle.git $CODE_DIR

## install ninja
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force

os.chdir(f'./{CODE_DIR}')

In [None]:
#@title Import Packages { display-mode: "form" } 
import time
import sys
import pprint
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms

# Put the current directory and its parent on the import path so the
# project-local packages imported below can be resolved.
sys.path.append(".")
sys.path.append("..")

# Project-local helpers (notebook utilities, inference and model loading).
from notebooks.notebook_utils import Downloader, HYPERSTYLE_PATHS, W_ENCODERS_PATHS, run_alignment
from utils.common import tensor2im
from utils.inference_utils import run_inversion
from utils.domain_adaptation_utils import run_domain_adaptation
from utils.model_utils import load_model, load_generator

# Automatically reload imported modules before each cell execution, so edits
# to project files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Define Download Configuration
Select below whether you wish to download all models using `pydrive`. Note that if you do not use `pydrive`, you may encounter a "quota exceeded" error from Google Drive.

In [None]:
#@title { display-mode: "form" } 
# Colab form toggle: whether model files are fetched through pydrive.
download_with_pydrive = True #@param {type:"boolean"} 
# Downloader instance reused by the model-download cells.
downloader = Downloader(code_dir=CODE_DIR, use_pydrive=download_with_pydrive)

## Select Domain for Inference

In [None]:
#@title Select which domain you wish to perform inference on: { display-mode: "form" }
# Domain key; it is used to index EXPERIMENT_DATA_ARGS and the model-path tables.
experiment_type = 'faces' #@param ['faces', 'cars', 'afhq_wild']

## Define Inference Parameters

Below we have a dictionary defining parameters such as the path to the pretrained model to use and the path to the image to perform inference on. While we provide default values to run this script, feel free to change them as needed.

In [None]:
def _make_transform(size):
    """Build the preprocessing pipeline: resize to `size`, convert to a tensor,
    then normalize every channel with mean 0.5 and std 0.5."""
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])


# Per-domain settings: HyperStyle checkpoint, W-encoder checkpoint, sample
# image, and the preprocessing transform applied before inference.
EXPERIMENT_DATA_ARGS = {
    "faces": dict(
        model_path="./pretrained_models/hyperstyle_ffhq.pt",
        w_encoder_path="./pretrained_models/faces_w_encoder.pt",
        image_path="./notebooks/images/face_image.jpg",
        transform=_make_transform((256, 256)),
    ),
    "cars": dict(
        model_path="./pretrained_models/hyperstyle_cars.pt",
        w_encoder_path="./pretrained_models/cars_w_encoder.pt",
        image_path="./notebooks/images/car_image.jpg",
        transform=_make_transform((192, 256)),
    ),
    "afhq_wild": dict(
        model_path="./pretrained_models/hyperstyle_afhq_wild.pt",
        w_encoder_path="./pretrained_models/afhq_wild_w_encoder.pt",
        image_path="./notebooks/images/afhq_wild_image.jpg",
        transform=_make_transform((256, 256)),
    ),
}

# Settings for the domain selected above.
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]

## Download Models
To reduce the number of requests to fetch the model, we'll check if the model was previously downloaded and saved before downloading the model. We'll download the model for the selected experiment and save it to the folder `hyperstyle/pretrained_models`.

We also need to verify that the model was downloaded correctly. All of our models should weigh approximately 1.3GB. Note that if the file weighs only several KBs, you have most likely encountered a "quota exceeded" error from Google Drive.

In [None]:
#@title Download HyperStyle Model { display-mode: "form" } 
# Download the HyperStyle checkpoint for the selected domain unless a plausible
# copy is already on disk. A file under 1,000,000 bytes is treated as a failed
# download (e.g. a Google Drive "quota exceeded" response saved in its place).
hyperstyle_ckpt = EXPERIMENT_ARGS['model_path']
if os.path.isfile(hyperstyle_ckpt) and os.path.getsize(hyperstyle_ckpt) >= 1000000:
    print(f'HyperStyle model for {experiment_type} already exists!')
else:
    print(f'Downloading HyperStyle model for {experiment_type}...')
    downloader.download_file(file_id=HYPERSTYLE_PATHS[experiment_type]['id'], file_name=HYPERSTYLE_PATHS[experiment_type]['name'])
    # if google drive receives too many requests, we'll reach the quota limit and be unable to download the model
    # Check for existence first: if the download produced no file at all,
    # os.path.getsize would raise FileNotFoundError instead of the error below.
    if not os.path.isfile(hyperstyle_ckpt) or os.path.getsize(hyperstyle_ckpt) < 1000000:
        raise ValueError("Pretrained model was unable to be downloaded correctly!")
    print('Done.')

In [None]:
#@title Download WEncoder Model { display-mode: "form" } 
# Download the W-encoder checkpoint for the selected domain unless a plausible
# copy is already on disk. A file under 1,000,000 bytes is treated as a failed
# download (e.g. a Google Drive "quota exceeded" response saved in its place).
w_encoder_ckpt = EXPERIMENT_ARGS['w_encoder_path']
if os.path.isfile(w_encoder_ckpt) and os.path.getsize(w_encoder_ckpt) >= 1000000:
    print(f'WEncoder model for {experiment_type} already exists!')
else:
    print(f'Downloading the WEncoder model for {experiment_type}...')
    downloader.download_file(file_id=W_ENCODERS_PATHS[experiment_type]['id'], file_name=W_ENCODERS_PATHS[experiment_type]['name'])
    # if google drive receives too many requests, we'll reach the quota limit and be unable to download the model
    # Check for existence first: if the download produced no file at all,
    # os.path.getsize would raise FileNotFoundError instead of the error below.
    if not os.path.isfile(w_encoder_ckpt) or os.path.getsize(w_encoder_ckpt) < 1000000:
        raise ValueError("Pretrained model was unable to be downloaded correctly!")
    print('Done.')

## Load Pretrained Model
We assume that you have downloaded all relevant models and placed them in the directory defined by the `EXPERIMENT_DATA_ARGS` dictionary.

In [None]:
#@title Load HyperStyle Model { display-mode: "form" } 
model_path = EXPERIMENT_ARGS['model_path']
# Override the stored options so the W-encoder checkpoint path points at the
# file selected for this domain.
opts_overrides = {"w_encoder_checkpoint_path": EXPERIMENT_ARGS['w_encoder_path']}
net, opts = load_model(model_path, update_opts=opts_overrides)
print('Model successfully loaded!')
# Show the full set of options the model was loaded with.
pprint.pprint(vars(opts))

## Define and Visualize Input

In [None]:
# Load the sample image for the selected domain and resize it for display/inference.
image_path = EXPERIMENT_DATA_ARGS[experiment_type]["image_path"]
original_image = Image.open(image_path).convert("RGB")
# PIL's resize expects (width, height). Car images are landscape: the cars
# transform resizes to (height=192, width=256), so the PIL size is (256, 192).
if experiment_type == 'cars':
    original_image = original_image.resize((256, 192))
else:
    original_image = original_image.resize((256, 256))
original_image

In [None]:
#@title Align Image (If Needed)
input_is_aligned = False #@param {type:"boolean"}
# Alignment is only run for the faces domain, and only when the input has not
# been aligned already; every other case uses the resized original as-is.
needs_alignment = (experiment_type == "faces") and (not input_is_aligned)
input_image = run_alignment(image_path) if needs_alignment else original_image

# Preview only: the resized copy is displayed, input_image itself is unchanged.
input_image.resize((256, 256))

## Perform Inference
Now we'll run inference. By default, we'll run using 5 inference steps. You can change the parameter in the cell below.

In [None]:
#@title { display-mode: "form" } 
# Colab form field: number of inference steps (see the section text above).
n_iters_per_batch = 5 #@param {type:"integer"}
# Copy the form value onto the loaded options object.
opts.n_iters_per_batch = n_iters_per_batch
opts.resize_outputs = False  # generate outputs at full resolution

In [None]:
#@title Run Inference! { display-mode: "form" }
# Preprocess the input with the transform configured for the selected domain.
img_transforms = EXPERIMENT_ARGS['transform']
transformed_image = img_transforms(input_image)

# Gradients are not needed for inference.
with torch.no_grad():
    tic = time.time()
    # Add a leading batch dimension and move the input to the GPU.
    input_batch = transformed_image.unsqueeze(0).cuda()
    result_batch, result_latents, _ = run_inversion(
        input_batch,
        net,
        opts,
        return_intermediate_results=True,
    )
    toc = time.time()
    elapsed = toc - tic
    print('Inference took {:.4f} seconds.'.format(elapsed))

## Visualize Result

In [None]:
def get_coupled_results(result_batch, transformed_image, output_size=None):
    """Build a side-by-side image of the input and its final reconstruction.

    Args:
        result_batch: inference results indexed by batch position; entry 0 holds
            the outputs for the single image in the batch, and the last element
            of that entry is used as the final reconstruction.
        transformed_image: the preprocessed input tensor that was fed to the model.
        output_size: optional size passed to `resize` for both images. When
            omitted, the notebook-level `resize_amount` is used, which must be
            defined before the function is called.

    Returns:
        A PIL image with the input on the left and the reconstruction on the right.
    """
    # Fall back to the notebook global so existing two-argument calls keep working.
    size = resize_amount if output_size is None else output_size
    result_tensors = result_batch[0]  # there's one image in our batch
    final_rec = tensor2im(result_tensors[-1]).resize(size)
    input_im = tensor2im(transformed_image).resize(size)
    # axis=1 joins the two arrays along the width axis (side by side).
    res = np.concatenate([np.array(input_im), np.array(final_rec)], axis=1)
    return Image.fromarray(res)

In [None]:
# Choose the display size for the side-by-side result. Cars use a 4:3 frame,
# every other domain is square; the larger sizes apply when outputs are kept
# at full resolution.
is_cars = opts.dataset_type == "cars"
if opts.resize_outputs:
    resize_amount = (256, 192) if is_cars else (256, 256)
elif is_cars:
    resize_amount = (512, 384)
else:
    resize_amount = (opts.output_size, opts.output_size)

res = get_coupled_results(result_batch, transformed_image)
res

## Save Result

In [None]:
# Write the side-by-side result into ./outputs, reusing the input image's file name.
outputs_path = "./outputs"
os.makedirs(outputs_path, exist_ok=True)
output_name = os.path.basename(image_path)
res.save(os.path.join(outputs_path, output_name))