##### Reused code from [StyleGAN3 Inversion](https://github.com/ouhenio/stylegan3-projector)

# Generating and projecting images using pre-trained StyleGAN3 models  

In this notebook, we're going to look at the [StyleGAN3](https://github.com/NVlabs/stylegan3) (aka Alias-Free GAN) image generation model. We'll walk through how to generate random images from the model and how to interpolate between them. We'll also cover how to project (encode) an image into a StyleGAN latent space to find a close match.  

Compared to DCGAN, StyleGAN is a model architecture that can generate images with much higher resolution, quality and diversity. We'll be using its third version (StyleGAN3). Unlike StyleGAN and StyleGAN2, StyleGAN3 models images as continuous signals (shown in the figure below), so it is free from the [texture sticking](https://www.youtube.com/watch?v=-2hLdOonvK0) effect and can be trained well on unaligned datasets!

<img src='./notebook_ims/stylegan3-teaser-1920x1006.png' width='500px'></img>

**Note:**  
Because StyleGAN requires a few custom C++/CUDA operations, this notebook most likely **cannot run on Windows devices without an NVIDIA GPU**. If any library is missing, please refer to StyleGAN3's [requirements](https://github.com/NVlabs/stylegan3?tab=readme-ov-file#requirements). We suggest first trying to run it on your own device, and moving to Colab if anything goes wrong. 

<a href="https://colab.research.google.com/drive/1ak3Yakel1BMqHfrgQi12GOZAQ-6Zf1FQ" target="_blank"><img src="./notebook_ims/colab-badge.svg" height=22></a>   

## 00. Clone repo, installation, download models

#### Clone the official StyleGAN3 repository

In [None]:
!git clone https://github.com/NVlabs/stylegan3.git

#### Install libraries

In [None]:
!pip install ninja

#### Download pre-trained models

We'll use a pre-trained model. The FFHQ 1024×1024 checkpoint downloaded below is a StyleGAN2 model, which the StyleGAN3 code base can also load. Many more pre-trained models are listed [here](https://github.com/NVlabs/stylegan3?tab=readme-ov-file#additional-material) and [here](https://github.com/justinpinkney/awesome-pretrained-stylegan3).

In [None]:
!curl -L 'https://api.ngc.nvidia.com/v2/models/org/nvidia/team/research/stylegan2/1/files?redirect=true&path=stylegan2-ffhq-1024x1024.pkl' -o stylegan2-ffhq-1024x1024.pkl


If the above `curl` command doesn't work, try `wget` instead (or open the URL in your browser, download the file and place it next to this notebook):

In [None]:
#!wget --content-disposition 'https://api.ngc.nvidia.com/v2/models/org/nvidia/team/research/stylegan2/1/files?redirect=true&path=stylegan2-ffhq-1024x1024.pkl' -O stylegan2-ffhq-1024x1024.pkl


#### Import

In [None]:
import sys
import os

base_dir = 'stylegan3'
sys.path.append(f'{base_dir}')

In [None]:
import torch
import numpy as np

from torch_utils import misc
from torch_utils.ops import upfirdn2d

from torchvision.transforms import ToTensor
from PIL import Image

from torchvision.transforms.functional import to_pil_image
from torch.nn.functional import interpolate
from torchvision.utils import make_grid
from IPython.display import display, HTML

import legacy
import dnnlib

Define some helper functions used throughout the notebook:

In [None]:
def slerp(val, low, high):
    '''
    Spherical linear interpolation between latents low and high at fraction val.
    original: Animating Rotation with Quaternion Curves, Ken Shoemake
    Code: https://github.com/soumith/dcgan.torch/issues/14, Tom White
    Accepts 1-D vectors or 2-D arrays (interpolated row by row); always returns a numpy array.
    '''
    if len(low.shape) == 1:
        # clip guards against rounding errors pushing the cosine slightly outside [-1, 1]
        omega = np.arccos(np.clip(np.dot(low/np.linalg.norm(low), high/np.linalg.norm(high)), -1.0, 1.0))
        so = np.sin(omega)
        return np.sin((1.0-val)*omega) / so * low + np.sin(val*omega)/so * high
    elif len(low.shape) == 2:
        ws = []
        for i in range(low.shape[0]):
            omega = np.arccos(np.clip(np.dot(low[i,:]/np.linalg.norm(low[i,:]), high[i,:]/np.linalg.norm(high[i,:])), -1.0, 1.0))
            so = np.sin(omega)
            w = np.sin((1.0-val)*omega) / so * low[i,:] + np.sin(val*omega)/so * high[i,:]
            ws.append(w)
        return np.array(ws)


def get_perceptual_loss(synth_image, target_features, perceptual_model):
    # Map generator output from [-1, 1] to [0, 255], the range the VGG16 feature extractor expects
    synth_image = (synth_image + 1) * (255/2)
    # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
    if synth_image.shape[2] > 256:
        synth_image = interpolate(synth_image, size=(256, 256), mode='area')

    # Features for synth images, shape [batch, feature_dim]
    synth_features = perceptual_model(synth_image, resize_images=False, return_lpips=True)
    # One squared-L2 distance per image in the batch (sum over features only)
    return (target_features - synth_features).square().sum(dim=1)

def get_target_features(target, perceptual_model, device):
    # target is a uint8 [C, H, W] tensor in [0, 255]; add a batch dimension and convert to float
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = interpolate(target_images, size=(256, 256), mode='area')
    return perceptual_model(target_images, resize_images=False, return_lpips=True)


def run_projector(projection_target, g_model, steps, perceptual_model, device, save_path=None):
    # Estimate the spread of w so we can optimise a normalised latent q instead of w directly
    zs = torch.randn([10000, g_model.mapping.z_dim], device=device)
    w_stds = g_model.mapping(zs, None).std(0)

    target_features = get_target_features(projection_target, perceptual_model, device)

    # Initialise q from the best of 8 batches x 4 random (truncated) candidates
    with torch.no_grad():
        qs = []
        losses = []
        for _ in range(8):
            z = torch.randn([4, g_model.mapping.z_dim], device=device)
            q = (g_model.mapping(z, None, truncation_psi=0.75) - g_model.mapping.w_avg) / w_stds
            images = g_model.synthesis(q * w_stds + g_model.mapping.w_avg)
            loss = get_perceptual_loss(images, target_features, perceptual_model)  # one loss per candidate
            i = torch.argmin(loss)
            qs.append(q[i])
            losses.append(loss[i])
        qs = torch.stack(qs)
        losses = torch.stack(losses)
        i = torch.argmin(losses)
        q = qs[i].unsqueeze(0).requires_grad_()

    # Optimisation loop
    q_ema = q
    opt = torch.optim.AdamW([q], lr=0.20, betas=(0.0, 0.999))
    # Target scaled from [0, 255] to [-1, 1] to match the generator output
    target_images = projection_target.unsqueeze(0).to(device).to(torch.float32) / 255 * 2 - 1

    for i in range(steps):
        opt.zero_grad()
        w = q * w_stds
        image = g_model.synthesis(w + g_model.mapping.w_avg, noise_mode='const')
        # In the first quarter of the steps use MSE loss, then switch to perceptual loss
        if i > (steps / 4):
            loss = get_perceptual_loss(image, target_features, perceptual_model).sum()
            if i % 10 == 0:
                print(f"step {i}/{steps} | perceptual_loss: {loss.item()}")
        else:
            loss = torch.nn.functional.mse_loss(image, target_images)
            if i % 10 == 0:
                print(f"step {i}/{steps} | mse_loss: {loss.item()}")

        loss.backward()
        opt.step()

        # Exponential moving average of q smooths the result; no gradients are needed here
        with torch.no_grad():
            q_ema = q_ema * 0.9 + q * 0.1
            if save_path is not None:
                image = g_model.synthesis(q_ema * w_stds + g_model.mapping.w_avg, noise_mode='const')
                pil_image = to_pil_image(image[0].add(1).div(2).clamp(0, 1).cpu())
                pil_image.save(f'{save_path}/{i:04}.jpg')

    return q_ema * w_stds + g_model.mapping.w_avg

def image_path_to_tensor(target_image_filename, model_resolution):
    target_pil = Image.open(target_image_filename).convert('RGB')
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((model_resolution, model_resolution), Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)
    target_tensor = torch.tensor(target_uint8.transpose([2, 0, 1]))
    return target_tensor

In [None]:
device = "cpu"

if torch.cuda.is_available():
    device = "cuda"

elif torch.backends.mps.is_available():
    device = "mps"

print(f'torch version {torch.__version__}')
print(f'Using device: {device}')

## 01. Load a StyleGAN model

Make sure `network_pkl` matches the path to the `.pkl` file we downloaded.

In [None]:
network_pkl = 'stylegan2-ffhq-1024x1024.pkl'

with dnnlib.util.open_url(network_pkl) as f:
    model = legacy.load_network_pkl(f)
    g_model = model['G'].eval().requires_grad_(False).to(device)

## 02. Interpolation

We'll start by randomly sampling two latent vectors and interpolating between them, generating a short video in the $z$ latent space of this model.

#### Create random latent vector

First, create two random latent vectors with fixed seeds and convert them to tensors:

In [14]:
np.random.seed(1)

z1 = np.random.randn(1, 512)
z1 = torch.Tensor(z1).to(device)

np.random.seed(42)

z2 = np.random.randn(1, 512)
z2 = torch.Tensor(z2).to(device)


#### Forward pass and generation

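Pass the two latent vectors through the generator. `truncation_psi=0.6` applies the truncation trick, pulling each mapped style vector towards the average style vector (`g_model.mapping.w_avg`) to trade some diversity for image quality: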

In [None]:
img1 = g_model(z1, None, truncation_psi=0.6)[0]
img2 = g_model(z2, None, truncation_psi=0.6)[0]

display(to_pil_image(torch.cat((img1,img2),dim=2).add(1).div(2).clamp(0, 1)))
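The truncation trick itself is a single interpolation. A minimal numpy sketch, assuming the convention used by the StyleGAN3 mapping network (`w_avg.lerp(w, truncation_psi)`), i.e. $w' = \bar{w} + \psi\,(w - \bar{w})$; the random vectors below are stand-ins, not real model latents:

```python
import numpy as np

def truncate(w, w_avg, psi):
    """Pull latent w towards the average latent w_avg by factor psi (truncation trick)."""
    return w_avg + psi * (w - w_avg)

rng = np.random.default_rng(0)
w_avg = rng.standard_normal(512)   # stand-in for g_model.mapping.w_avg
w = rng.standard_normal(512)       # stand-in for one mapped latent

for psi in (1.0, 0.6, 0.0):
    w_t = truncate(w, w_avg, psi)
    # distance to the average shrinks linearly with psi
    print(psi, np.linalg.norm(w_t - w_avg) / np.linalg.norm(w - w_avg))
```

With `psi=1` there is no truncation, and with `psi=0` every latent collapses to the average face; values in between, like the 0.6 used above, keep most of the variety while avoiding the rarer, lower-quality regions of the latent space.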

#### Create interpolated latent vectors  

To interpolate between the two latent vectors, we'll sample points between them with a `slerp` (spherical linear interpolation) function, which moves along the arc between the two vectors rather than along a straight line.
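Why spherical rather than linear? Latents drawn from a 512-dimensional standard Gaussian have norms concentrated around $\sqrt{512} \approx 22.6$, and two independent ones are nearly orthogonal, so the straight line between them passes through points with a noticeably smaller norm (about $22.6/\sqrt{2} \approx 16$ at the midpoint), which is atypical for the latent distribution. Slerp follows the arc instead:

$$\text{slerp}(p_0, p_1; t) = \frac{\sin((1-t)\Omega)}{\sin\Omega}\,p_0 + \frac{\sin(t\Omega)}{\sin\Omega}\,p_1, \qquad \cos\Omega = \frac{p_0 \cdot p_1}{\lVert p_0\rVert\,\lVert p_1\rVert}$$

The sketch below is a standalone numpy re-implementation (named `slerp_vec` so it doesn't clash with the notebook's `slerp`) that compares the midpoint norms on random vectors:

```python
import numpy as np

def slerp_vec(t, p0, p1):
    """Spherical interpolation between two 1-D vectors p0 and p1 at fraction t."""
    cos_omega = np.dot(p0 / np.linalg.norm(p0), p1 / np.linalg.norm(p1))
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
    so = np.sin(omega)
    return np.sin((1.0 - t) * omega) / so * p0 + np.sin(t * omega) / so * p1

def lerp_vec(t, p0, p1):
    """Straight-line interpolation, for comparison."""
    return (1.0 - t) * p0 + t * p1

rng = np.random.default_rng(0)
z1, z2 = rng.standard_normal(512), rng.standard_normal(512)

print("endpoint norms:", np.linalg.norm(z1), np.linalg.norm(z2))      # both close to sqrt(512)
print("lerp midpoint norm:", np.linalg.norm(lerp_vec(0.5, z1, z2)))   # noticeably smaller
print("slerp midpoint norm:", np.linalg.norm(slerp_vec(0.5, z1, z2))) # close to the endpoint norms
```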

In [None]:
# slerp() is defined in the helper-functions cell above
from base64 import b64encode
import torchvision.transforms as transforms

In [None]:
# define how many vectors we want
num_interp = 100

In [None]:
# create intervals
interp_vals = np.linspace(1./num_interp, 1, num=num_interp)

# Convert latent vectors to numpy arrays
latent_a_np = z1.cpu().numpy().squeeze()
latent_b_np = z2.cpu().numpy().squeeze()

# Create our spherical interpolation between two points
latent_interp = np.array([slerp(v, latent_a_np, latent_b_np) for v in interp_vals], dtype=np.float32)

Now we have 100 latent vectors, each with 512 dimensions (the dimensionality of $z$), so `latent_interp.shape` is `(100, 512)`.

In [None]:
latent_interp.shape

#### Generate animated frames using all latent vectors  

First, create a folder to store the generated frames; then loop through all the latent vectors and pass each one through the generator to produce an image.

In [None]:
folder_name = 'animation_frames_stylegan'

# create a directory to save the images
if not os.path.exists(folder_name):
    os.makedirs(folder_name)

In [None]:
# Array for images to save to for visualisation
img_list = []

# For each latent vector in our interpolation
for i,latent in enumerate(latent_interp):
    # Convert to torch tensor
    latent = torch.tensor(latent).unsqueeze(0).to(device)
    # Generate image from latent
    image_tensor = g_model(latent, None, truncation_psi=0.6)
    # Convert to PIL Image
    image = transforms.functional.to_pil_image(image_tensor.clamp(-1,1).add(1).div(2).cpu().squeeze(0))
    image.save(f'./{folder_name}/{i:05}.jpg')
    # Add to image array
    img_list.append(image)

#### Create a video using generated images as video frames  

We will need [FFMPEG](https://ffmpeg.org/) to create a video from the generated frames:

##### Installing FFMPEG

If you want to make the video animation in this notebook, you will need to install [ffmpeg](https://ffmpeg.org/download.html). We already installed ffmpeg in week 4, so you should have it. If you don't, follow these instructions:

To install FFMPEG on Mac:
- Step 1: [install homebrew](https://brew.sh/)
- Step 2: Run `brew install ffmpeg`

To install FFMPEG on Windows:
- Follow [these instructions](https://phoenixnap.com/kb/ffmpeg-windows) for Windows installation

To install FFMPEG on Ubuntu linux:
- Step 1: Run `sudo apt update`
- Step 2: Run `sudo apt install ffmpeg`

If you cannot install FFMPEG, you can [make a gif manually using this website](https://ezgif.com/maker), or use other video editing software that can turn frames into a video.
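Another FFMPEG-free option is to let Pillow write an animated GIF from the frames directly. A minimal sketch using Pillow's `Image.save` with `save_all`/`append_images`; the helper name `frames_to_gif` is hypothetical, and the demo uses small synthetic frames so it runs on its own:

```python
from PIL import Image

def frames_to_gif(frames, out_path, fps=30):
    """Save a list of PIL images as a looping GIF (fallback when FFMPEG is unavailable)."""
    duration_ms = int(1000 / fps)          # time each frame is shown, in milliseconds
    frames[0].save(
        out_path,
        save_all=True,                     # write every frame, not just the first
        append_images=frames[1:],
        duration=duration_ms,
        loop=0,                            # 0 = loop forever
    )
    return out_path

# Demo with synthetic, distinctly coloured frames
demo_frames = [Image.new('RGB', (64, 64), (i * 25, 0, 255 - i * 25)) for i in range(10)]
frames_to_gif(demo_frames, 'demo_animation.gif', fps=10)
```

In the notebook you could pass `img_list` instead, e.g. `frames_to_gif(img_list, 'animation_frames_stylegan.gif')`. Keep in mind that 1024×1024 frames make a large GIF, and GIF frame delays are stored in hundredths of a second, so 30 fps is only approximated.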


In [None]:
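# -y overwrites an existing video; -pix_fmt yuv420p improves compatibility with browsers and players
!ffmpeg -y -framerate 30 -i ./{folder_name}/%05d.jpg -pix_fmt yuv420p animation_frames_stylegan.mp4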

In [None]:
mp4 = open('animation_frames_stylegan.mp4','rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=512 controls loop>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)

## 03. Projection into latent space

So far we have generated images by sampling random latent vectors in the $z$ space. Now we'll take an image of anyone you choose and project it into StyleGAN's latent space. 

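A StyleGAN model has a fully connected mapping network that transforms a latent vector $z$ into a "style" latent vector $w$; the set of these vectors is called the $w$ space. Each synthesis layer in the StyleGAN generator applies a style latent vector. Here we're projecting the input image into the $w$ space; because the projector optimises a separate $w$ for each synthesis layer, this is often called the extended $w+$ space.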
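To make the $z \to w$ step and the difference between $w$ and $w+$ concrete, here is a toy numpy sketch. Everything in it is an assumption for illustration: random weights instead of learned ones, an 8-layer MLP with leaky ReLU, and `NUM_WS = 18` (the number of style inputs of a 1024×1024 StyleGAN2 generator). It only shows the shapes involved, not the real network:

```python
import numpy as np

rng = np.random.default_rng(0)
Z_DIM = W_DIM = 512
NUM_WS = 18          # assumption: style inputs of a 1024x1024 StyleGAN2 generator
NUM_LAYERS = 8       # assumption: depth of the toy MLP

# Hypothetical random weights standing in for the learned mapping network
weights = [rng.standard_normal((Z_DIM if i == 0 else W_DIM, W_DIM)) / np.sqrt(W_DIM)
           for i in range(NUM_LAYERS)]

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def toy_mapping(z):
    """Map z of shape [N, 512] to broadcast style vectors of shape [N, NUM_WS, 512]."""
    x = z / np.sqrt(np.mean(z ** 2, axis=1, keepdims=True) + 1e-8)  # normalise 2nd moment
    for W in weights:
        x = leaky_relu(x @ W)
    return np.repeat(x[:, None, :], NUM_WS, axis=1)   # same w copied to every layer

z = rng.standard_normal((1, Z_DIM))
ws = toy_mapping(z)
print(ws.shape)                               # (1, 18, 512)
print(np.allclose(ws[:, 0], ws[:, -1]))       # True: plain w space, every layer shares one w

# A projector that optimises each layer's vector separately works in the extended w+ space
ws_plus = ws + 0.1 * rng.standard_normal(ws.shape)
print(np.allclose(ws_plus[:, 0], ws_plus[:, -1]))  # False: layers now differ
```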



In [None]:
from stylegan3.dnnlib.util import open_url
# image_path_to_tensor(), run_projector() and slerp() are defined in the helper-functions cell above

from base64 import b64encode
import torchvision.transforms as transforms

#### Fetch a feature extractor

To measure how close a generated image is to the target image, we'll use a pre-trained feature extractor, [VGG16](https://www.geeksforgeeks.org/vgg-16-cnn-model/). Comparing VGG16 features (a perceptual, LPIPS-style distance) tracks visual similarity better than comparing raw pixels.
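The perceptual distance is the squared L2 distance between feature vectors. A minimal numpy sketch of that scoring, assuming features of shape `[batch, feature_dim]` (random stand-ins here, not real VGG16 output): summing over the feature axis only keeps one distance per candidate, so `argmin` can pick the closest one, whereas summing over everything collapses the batch into a single number.

```python
import numpy as np

rng = np.random.default_rng(0)
feature_dim = 8                                    # toy size; real VGG16 features are far larger
target_features = rng.standard_normal((1, feature_dim))

# Four hypothetical candidates at different distances along one random direction
direction = rng.standard_normal((1, feature_dim))
scales = np.array([[0.5], [0.1], [2.0], [1.0]])
candidate_features = target_features + scales * direction   # shape (4, 8)

# Squared L2 distance per candidate: broadcast [1, D] against [4, D], sum over D only
per_candidate = np.square(target_features - candidate_features).sum(axis=1)
best = int(np.argmin(per_candidate))
print(per_candidate.shape, best)                   # (4,) 1

# Summing over all axes gives a single scalar, so argmin could only ever return 0
collapsed = np.square(target_features - candidate_features).sum()
print(np.ndim(collapsed))                          # 0
```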

In [None]:
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
with open_url(url) as f:
    vgg16 = torch.jit.load(f).eval().to(device)
print('Using device:', device, file=sys.stderr)

#### Prepare an image  

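For best results, the face should be aligned like the images in the training dataset (similar position, rotation and scale). Note that `image_path_to_tensor` only centre-crops the image to a square and resizes it to the model's resolution; it does not align the face for you.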
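The centre crop in `image_path_to_tensor` keeps the largest centred square of the photo. Its box arithmetic, pulled out into a small hypothetical helper (`center_square_box` is not part of the notebook), returns a `(left, upper, right, lower)` box in the form PIL's `crop()` expects:

```python
def center_square_box(width, height):
    """Return the (left, upper, right, lower) box of the largest centred square."""
    s = min(width, height)
    return ((width - s) // 2, (height - s) // 2, (width + s) // 2, (height + s) // 2)

print(center_square_box(1200, 800))   # (200, 0, 1000, 800): trims 200 px from each side
print(center_square_box(600, 900))    # (0, 150, 600, 750): trims 150 px from top and bottom
```

If the face is not near the centre of your photo, crop it yourself first so that it fills the frame roughly as FFHQ faces do.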

If you're using another image, make sure to change the path here to your image file!

In [None]:
target_image1_filename = "../media/b.jpg"
target_image2_filename = "../media/k.jpg"

target_tensor1 = image_path_to_tensor(target_image1_filename, g_model.img_resolution).to(device)
target_tensor2 = image_path_to_tensor(target_image2_filename, g_model.img_resolution).to(device)
print(target_tensor1.shape)
print(target_tensor2.shape)
print('inputs:')
display(to_pil_image(torch.cat((target_tensor1,target_tensor2),dim=2).div(255).clamp(0, 1)))

In [None]:
# create a directory to save the projection frames
if not os.path.exists('animation_frames_stylegan_projector'):
    os.makedirs('animation_frames_stylegan_projector')

#### Project into the $w$ space

We are now going to take our images and project them into the $w$ space of StyleGAN. `run_projector` starts from a random latent vector, chosen from 8 batches of 4 random candidates by perceptual loss, and then repeatedly adjusts it with the AdamW optimiser until it converges on the closest matching image in StyleGAN's space. It uses a pixel-wise MSE loss for the first quarter of the steps and the VGG16 perceptual loss afterwards, and returns an exponential moving average of the optimised latent. This is quite a long process; to shorten it, set `steps` to a smaller number, at the cost of a less accurate match.
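The same recipe can be tried on a toy problem. This numpy sketch assumes a hypothetical linear "generator" $G(q) = Aq$ in place of StyleGAN, a single squared-error loss in place of the MSE/perceptual schedule, and plain gradient descent in place of AdamW. It keeps only the two ideas the projector relies on: start from the best of several random guesses, then descend on a reconstruction loss while tracking an exponential moving average (EMA) of the latent, `q_ema = 0.9 * q_ema + 0.1 * q`.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((32, 8)) / np.sqrt(32)   # hypothetical linear "generator": image = A @ q
q_true = rng.standard_normal(8)
target = A @ q_true                              # the "image" we want to project

def loss_fn(q):
    """Squared reconstruction error between the generated and target 'images'."""
    r = A @ q - target
    return float(r @ r)

# 1) Initialise from the best of several random candidates
candidates = rng.standard_normal((32, 8))
q = min(candidates, key=loss_fn).copy()
initial_loss = loss_fn(q)

# 2) Gradient descent on the loss, tracking an EMA of the latent
lr, steps = 0.1, 500
q_ema = q.copy()
for _ in range(steps):
    grad = 2 * A.T @ (A @ q - target)            # gradient of ||A q - target||^2
    q -= lr * grad
    q_ema = 0.9 * q_ema + 0.1 * q                # smoothed estimate, returned instead of q

print(f"loss at init: {initial_loss:.4f} | loss of EMA latent: {loss_fn(q_ema):.2e}")
```

The EMA lags slightly behind the raw iterate but averages out step-to-step jitter, which matters more with a noisy, non-convex loss like the real perceptual one.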

In [None]:
steps = 500

ws_ema1 = run_projector(projection_target = target_tensor1,
                       g_model = g_model, 
                       steps = steps,
                       perceptual_model = vgg16, 
                       device = device, 
                       save_path = None)

ws_ema2 = run_projector(projection_target = target_tensor2,
                       g_model = g_model, 
                       steps = steps,
                       perceptual_model = vgg16, 
                       device = device, 
                       save_path = None)

In [None]:
image1 = g_model.synthesis(ws_ema1, noise_mode='const')[0]
image2 = g_model.synthesis(ws_ema2, noise_mode='const')[0]

display(to_pil_image(torch.cat((image1,image2),dim=2).add(1).div(2).clamp(0, 1)))

#### Interpolating the projections

We now have the closest match StyleGAN could find for each image. How close that match is usually depends on how well the chosen image conforms to the content and style of the photos in the Flickr Faces High Quality (FFHQ) dataset the model was trained on.

We can also interpolate between these two projections by interpolating in the $w$ space:
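For the projected latents, `latent_a_np` and `latent_b_np` have one row per synthesis layer, and the `slerp` helper loops over those rows. The same per-layer interpolation can be written without a Python loop. A sketch, assuming inputs of shape `(num_layers, 512)`; the hypothetical `slerp_rows` below is not part of the notebook, and it falls back to linear interpolation for nearly parallel rows, where $\sin\Omega \approx 0$:

```python
import numpy as np

def slerp_rows(t, low, high, eps=1e-7):
    """Row-wise spherical interpolation between two arrays of shape (num_layers, dim)."""
    low_n = low / np.linalg.norm(low, axis=1, keepdims=True)
    high_n = high / np.linalg.norm(high, axis=1, keepdims=True)
    # Angle between matching rows, clipped so rounding cannot push arccos out of range
    omega = np.arccos(np.clip(np.sum(low_n * high_n, axis=1, keepdims=True), -1.0, 1.0))
    so = np.sin(omega)
    safe = np.abs(so) > eps                  # rows where the slerp formula is well defined
    so_safe = np.where(safe, so, 1.0)        # avoid dividing by ~0 on parallel rows
    slerped = np.sin((1.0 - t) * omega) / so_safe * low + np.sin(t * omega) / so_safe * high
    lerped = (1.0 - t) * low + t * high
    return np.where(safe, slerped, lerped)

rng = np.random.default_rng(0)
w_a, w_b = rng.standard_normal((18, 512)), rng.standard_normal((18, 512))
frames = np.stack([slerp_rows(t, w_a, w_b) for t in np.linspace(0.01, 1, 100)]).astype(np.float32)
print(frames.shape)   # (100, 18, 512)
```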

In [12]:
num_interp = 100

# create intervals
interp_vals = np.linspace(1./num_interp, 1, num=num_interp)

# Convert latent vectors to numpy arrays
latent_a_np = ws_ema1.cpu().detach().numpy().squeeze()
latent_b_np = ws_ema2.cpu().detach().numpy().squeeze()

# Create our spherical interpolation between two points
latent_interp = np.array([slerp(v, latent_a_np, latent_b_np) for v in interp_vals], dtype=np.float32)

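Now we have 100 latent style vectors $w$, each with shape `(num_of_layers, 512)`, where `num_of_layers` is the number of style inputs of the generator (18 for this 1024×1024 model), so `latent_interp.shape` is `(100, 18, 512)`.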

In [None]:
latent_interp.shape

In [None]:
folder_name = 'animation_frames_stylegan_projector'

# create a directory to save the images
if not os.path.exists(folder_name):
    os.makedirs(folder_name)

In [None]:
# Array for images to save to for visualisation
img_list = []

# For each latent vector in our interpolation
for i,latent in enumerate(latent_interp):
    # Convert to torch tensor
    latent = torch.tensor(latent).unsqueeze(0).to(device)
    # Generate image from latent
    image_tensor = g_model.synthesis(latent, noise_mode='const')
    # Convert to PIL Image
    image = transforms.functional.to_pil_image(image_tensor.clamp(-1,1).add(1).div(2).cpu().squeeze(0))
    image.save(f'./{folder_name}/{i:05}.jpg')
    # Add to image array
    img_list.append(image)

#### Create an interpolation video  

We'll use FFMPEG again to create an interpolation video:  

In [None]:
!ffmpeg -y -framerate 30 -i ./{folder_name}/%05d.jpg -pix_fmt yuv420p animation_frames_stylegan_projector.mp4

In [None]:
mp4 = open('animation_frames_stylegan_projector.mp4','rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=400 controls loop>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)