# Network initialization

In [None]:
#@title ▶ Download the neural networks from GitHub and the pretrained weights for generating faces

#@markdown We will use StyleGAN2-ADA to generate faces and, at the end, we will use CLIP to search inside the latent space

%cd /content/

# Code: a StyleGAN2-ADA PyTorch fork plus OpenAI's CLIP package. The `clip`
# package directory is copied into the StyleGAN2 repo so that `import clip`
# works when running from that directory.
!git clone https://github.com/pbaylies/stylegan2-ada-pytorch.git
!git clone https://github.com/openai/CLIP.git
!cp -R CLIP/clip/ stylegan2-ada-pytorch

%cd stylegan2-ada-pytorch
%mkdir weights
%cd weights
# Pretrained FFHQ generator; `-nc` skips the download if the file is already there.
!wget -nc https://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-f.pkl
%cd ..

In [None]:
#@title ▶ Install the required libraries

#@markdown It may take a while (3 minutes)

# ninja is presumably needed to JIT-compile the repo's custom CUDA ops — confirm.
!apt install -y -q ninja-build
!pip -q install mediapy opensimplex ftfy ninja

import subprocess

# Toolkit version from `nvcc --version`: take the ", "-separated field that
# starts with "release" (e.g. "release 11.8") and keep its last word ("11.8").
CUDA_version = [s for s in subprocess.check_output(["nvcc", "--version"]).decode("UTF-8").split(", ") if s.startswith("release")][0].split(" ")[-1]
print("CUDA version:", CUDA_version)

# Pick the PyTorch wheel suffix for that toolkit; any version not listed
# falls back to the "+cu117" wheels.
# NOTE(review): confirm torch 1.13.1 actually publishes +cu100/+cu101 wheels and
# that the unsuffixed wheel is suitable for CUDA 10.2 — these branches look stale.
if CUDA_version == "10.0":
    torch_version_suffix = "+cu100"
elif CUDA_version == "10.1":
    torch_version_suffix = "+cu101"
elif CUDA_version == "10.2":
    torch_version_suffix = ""
else:
    torch_version_suffix = "+cu117"

# Install a matching torch / torchvision pair from the PyTorch wheel index.
!pip -q install torch==1.13.1{torch_version_suffix} torchvision==0.14.1{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html

In [None]:
#@title ▶ Initialize the network and define some general functions

%cd stylegan2-ada-pytorch

# General imports
import cv2
from tqdm import tqdm
import PIL.Image
import random
from random import randint
import numpy as np
import scipy
import scipy.ndimage
import os
import os.path
import copy

# Notebook imports
from IPython.display import Image
from IPython.display import clear_output
import mediapy as media

# Neural network imports
import torch
import torchvision
import torch.nn.functional as F

# Imports from Network
import dnnlib
import legacy
from generate import seeds_to_zs, line_interpolate
import clip
import apply_factor


# Pretrained FFHQ generator downloaded by the first cell (path relative to the repo root).
network_pkl = 'weights/stylegan2-ffhq-config-f.pkl'

# Options forwarded to this fork's legacy.load_network_pkl. The fork presumably
# also accepts keys such as size / scale_type, which are left unset here.
custom = False
G_kwargs = dnnlib.EasyDict()

device = torch.device('cuda')
print('Loading networks from "%s"...' % network_pkl)
with dnnlib.util.open_url(network_pkl) as f:
    # Only the exponential-moving-average generator is used for inference.
    G = legacy.load_network_pkl(f, custom=custom, **G_kwargs)['G_ema'].to(device) # type: ignore

# All-zero class label of width G.c_dim: no class conditioning is used in this notebook.
label = torch.zeros([1, G.c_dim], device=device)

def generate_image(z, truncation_psi, to_pil=True):
  """Render one image from a numpy latent with the global generator G.

  Args:
    z: numpy array latent (as produced by RandomState(seed).randn(1, G.z_dim)).
    truncation_psi: truncation passed to G.
    to_pil: return a PIL image when True, otherwise an HWC uint8 numpy array.

  Uses the global class `label` and constant noise, so the result is deterministic.
  """
  latent = torch.from_numpy(z).to(device)
  return torch_to_image(
      G(latent, label, truncation_psi=truncation_psi, noise_mode='const'),
      to_pil)

def torch_to_image(tensor, to_pil=True):
  """Convert the first image of an NCHW generator output to 8-bit RGB.

  Values are mapped with x * 127.5 + 128 (so nominally [-1, 1] -> [0, 255]),
  clamped and truncated to uint8.

  Args:
    tensor: batch of images, assumed [N, 3, H, W].
    to_pil: return a PIL 'RGB' image when True, otherwise an HWC uint8 numpy array.
  """
  pixels = (tensor.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
  first = pixels[0].cpu().numpy()
  return PIL.Image.fromarray(first, 'RGB') if to_pil else first



# Generating images

In [None]:
#@title 🖼 Generate an image { display-mode: "form", run: "auto" }

#@markdown 💬 A seed is a shortcut to a point in the latent space so there's no need to write a value for each of the 512 dimensions.
seed = 1#@param {type:"integer"}
#@markdown 💬 Truncation PSI is a parameter of the network that controls how "strange" the generated images are: lower values keep them closer to the "average" face, higher values allow more unusual ones.
truncation_psi = 0.8#@param {type:"slider", min:0, max:1, step:0.1}

# The seed deterministically selects a z vector of length G.z_dim.
z = np.random.RandomState(seed).randn(1, G.z_dim)
im = generate_image(z, truncation_psi)
# Shown at 512x512 regardless of the generator's native resolution.
display(im.resize((512, 512)))

In [None]:
#@title 🖼 Generate multiple images { display-mode: "form" }

#@markdown 💬 Select the number of images and the start seed. The images will be generated for the seeds start_seed to start_seed + num_seeds - 1.
start_seed = 1#@param {type:"integer"}
num_seeds = 64#@param {type:"integer"}
truncation_psi = 0.8#@param {type:"slider", min:0, max:1, step:0.1}
#@markdown 💬 Display size for each image.
image_size = 128#@param [32, 64, 128, 256, 512]

image_size = int(image_size)

# Exactly num_seeds consecutive seeds (range's end is exclusive).
seeds = list(range(start_seed, start_seed + num_seeds))
seeds_text = [str(s) for s in seeds]

points = seeds_to_zs(G, seeds)

images = []
for point in tqdm(points, leave=False):
  im = generate_image(point, truncation_psi, False)
  images.append(cv2.resize(im, dsize=(image_size, image_size), interpolation=cv2.INTER_CUBIC))

# Fit roughly 1536 px of thumbnails per row.
columns = int(1536/image_size)
media.show_images(images, border=True, height=image_size, columns=columns, titles=seeds_text)

# Navigating the latent space

In [None]:
#@title 🖼 Generate images from points close to one seed { display-mode: "form", run: "auto" }

#@markdown By generating images from points close to one seed we see that they are similar. Starting from the seed's latent vector, each new image sets one more of its coordinates to 1.

def generate_variations(seed, num_variations, truncation_psi=0.8):
  """Show the image of `seed` followed by `num_variations` nearby variations.

  The seed's z vector (shape [1, G.z_dim]) is edited in place and cumulatively:
  variation k is the original z with its first k coordinates set to 1.

  Args:
    seed: seed used to draw the starting z vector.
    num_variations: number of variations; must not exceed G.z_dim.
    truncation_psi: truncation passed to generate_image. It is explicit
      (like in generate_matrix) instead of being read from the notebook global.
  """
  z = np.random.RandomState(seed).randn(1, G.z_dim)
  images = [generate_image(z, truncation_psi)]

  for i in range(num_variations):
    # Cumulative edit: coordinates set in earlier iterations stay at 1.
    z[0, i] = 1
    images.append(generate_image(z, truncation_psi))

  media.show_images(images, border=True, height=128)

seed = 1#@param {type:"integer"}
truncation_psi = 0.8#@param {type:"slider", min:0, max:1, step:0.1}
generate_variations(seed, 10, truncation_psi)

In [None]:
#@title 🖼 Generate an interpolation video between two seeds { display-mode: "form" }

#@markdown We can also generate images for the points that lie between two seeds, and then turn them into a video.

def generate_interpolation(seed_1, seed_2, frames, easing, video_size, truncation_psi=0.8, fps=8):
  """Show a video that interpolates between the latents of two seeds.

  Args:
    seed_1, seed_2: seeds of the start and end z vectors.
    frames: number of frames, passed to line_interpolate.
    easing: easing mode name for line_interpolate (e.g. 'linear').
    video_size: side length in pixels of each square frame.
    truncation_psi: truncation passed to generate_image. It is explicit
      (like in generate_matrix) instead of being read from the notebook global.
    fps: playback frame rate of the displayed video.
  """
  points = line_interpolate(seeds_to_zs(G, [seed_1, seed_2]), frames, easing)

  images = []
  for point in tqdm(points, leave=False):
    im = generate_image(point, truncation_psi, False)
    images.append(cv2.resize(im, dsize=(video_size, video_size), interpolation=cv2.INTER_CUBIC))

  media.show_video(images, fps=fps)

#@markdown 💬 Starting and ending seed
seed1 = 1#@param {type:"integer"}
seed2 = 2#@param {type:"integer"}
truncation_psi = 1#@param {type:"slider", min:0, max:1, step:0.1}

#@markdown

#@markdown 💬 Number of frames of the video: more frames take longer to render and make the transition slower
frames = 100#@param {type:"integer"}
#@markdown 💬 Dimensions of the generated video
video_size = "512"#@param [256, 512, 1024]

#@markdown ☝ The video can be downloaded by right-clicking on it

generate_interpolation(seed1, seed2, frames, 'linear', int(video_size), truncation_psi)

In [None]:
#@title 🖼 Generate an interpolation matrix between four seeds { display-mode: "form" }

#@markdown Visualize a grid that interpolates between four seeds placed at its corners

def generate_matrix(seed_1, seed_2, seed_3, seed_4, steps, easing, truncation_psi=0.8, image_size=64):
  """Show a grid interpolating between the latents of four seeds.

  Row r interpolates between point r of the seed_1 -> seed_2 path and point r
  of the seed_3 -> seed_4 path. So, presumably (if line_interpolate includes
  its endpoints), seed_1/seed_3 are on the first row and seed_2/seed_4 on the last.

  Args:
    seed_1, seed_2, seed_3, seed_4: corner seeds.
    steps: number of rows, of interpolation points per row and of grid columns.
    easing: easing mode name passed to line_interpolate (e.g. 'linear').
    truncation_psi: truncation passed to generate_image.
    image_size: side length in pixels of each displayed image.
  """
  points = seeds_to_zs(G, [seed_1, seed_2, seed_3, seed_4])
  a_to_b = line_interpolate([points[0], points[1]], steps, easing)
  c_to_d = line_interpolate([points[2], points[3]], steps, easing)

  images = []
  for step in range(steps):
    row_points = line_interpolate([a_to_b[step], c_to_d[step]], steps, easing)
    for point in row_points:
      im = generate_image(point, truncation_psi, False)
      # Resize to the requested display size.
      images.append(cv2.resize(im, dsize=(image_size, image_size), interpolation=cv2.INTER_CUBIC))

  media.show_images(images, border=True, height=image_size, columns=steps)

#@markdown 💬 Seeds at the corners
seed1 = 1#@param {type:"integer"}
seed2 = 2#@param {type:"integer"}
seed3 = 3#@param {type:"integer"}
seed4 = 4#@param {type:"integer"}
truncation_psi = 0.8#@param {type:"slider", min:0, max:1, step:0.1}

#@markdown 💬 Number and size of the interpolation images
steps = 8#@param {type:"integer"}
image_size = "64"#@param [64, 128, 256]

generate_matrix(seed1, seed2, seed3, seed4, steps, 'linear', truncation_psi, int(image_size))

# Trying to make sense of the directions in this space

In [None]:
#@title ▶ Extracting the directions and defining necessary functions { display-mode: "form" }

#@markdown Using a command from the network a new file is created that contains some directions of the latents space that we will later use.

#@markdown Also, some functions to work with the directions are defined.

%cd stylegan2-ada-pytorch

! python closed_form_factorization.py --ckpt {network_pkl} --out {network_pkl}.pt

def generate_factorized_images(z_w, label, truncation_psi, noise_mode, direction, space):
    """Render a latent and its two shifts along `direction`.

    In 'w' space the latent is fed straight to G.synthesis (in fp32) and
    `label`/`truncation_psi` are unused. In any other space it is treated as
    z and passed through the full generator G.

    Returns:
        [image(z_w - direction), image(z_w), image(z_w + direction)] as PIL images.
    """
    if space == 'w':
        def render(latent):
            return G.synthesis(latent, noise_mode=noise_mode, force_fp32=True)
    else:
        def render(latent):
            return G(latent, label, truncation_psi=truncation_psi, noise_mode=noise_mode)

    # Generation order (base, plus, minus) is kept, which matters if noise_mode is random.
    base = render(z_w)
    plus = render(z_w + direction)
    minus = render(z_w - direction)
    return [torch_to_image(minus), torch_to_image(base), torch_to_image(plus)]


def factorize(latents, space, index, degree, truncation_psi=0.8, image_size=128):
  """Show latents shifted along closed-form factorization directions.

  Loads the eigenvectors that closed_form_factorization.py wrote to
  "<network_pkl>.pt". For every latent and every selected direction j, it renders
  the latent moved by -degree, 0 and +degree times column j of the eigenvector
  matrix. The display has three groups of images: all "minus" images, then all
  originals, then all "plus" images.

  Args:
    latents: sequence of latents (w tensors when space == 'w', otherwise z).
    space: 'w' to feed latents to G.synthesis; anything else uses the full G.
    index: list of eigenvector (column) indices; [-1] selects all of them.
    degree: step size along each direction.
    truncation_psi: only used when space != 'w'.
    image_size: display height of each image.
  """
  device = torch.device('cuda')
  eigvec = torch.load(network_pkl + ".pt")["eigvec"].to(device)

  label = torch.zeros([1, G.c_dim], device=device) # assume no class label
  noise_mode = "const" # default

  # The selected directions do not depend on the latent: resolve them once.
  # NOTE: [-1] enumerates range(len(eigvec)), i.e. assumes a square eigvec matrix.
  if len(index) == 1 and index[0] == -1:
    directions = range(len(eigvec))
  else:
    directions = index

  minus_row, base_row, plus_row = [], [], []
  # Rendering only: no autograd graph is needed.
  with torch.no_grad():
    for latent in latents:
      for j in directions:
        direction = degree * eigvec[:, j].unsqueeze(0)
        minus, base, plus = generate_factorized_images(latent, label, truncation_psi, noise_mode, direction, space)
        minus_row.append(minus)
        base_row.append(base)
        plus_row.append(plus)

  # A single latent is shown three images per row. Otherwise each of the three
  # groups gets exactly one display row: one column per (latent, direction)
  # pair, so rows don't wrap into each other when several directions are used.
  columns = 3 if len(latents) == 1 else len(minus_row)
  media.show_images(minus_row + base_row + plus_row, border=False, height=image_size, columns=columns)

In [None]:
#@title 🖼 Modify some seeds in one of the found directions { display-mode: "form" }

#@markdown In the previous step the tool defined a list of semantic directions; now we can modify the seeds in those directions. We don't know what each direction means, so we have to try each one.

#@markdown The tool will generate the seed image and the seed image modified in the positive and negative direction for 8 seeds.

#@markdown 💬 Index of the direction
factor_index = 1#@param {type:"integer"}
#@markdown 💬 Amount of change
factor_degree = 2.5#@param {type:"slider", min:0, max:10, step:0.25}

#@markdown &nbsp;
start_seed = 1#@param {type:"integer"}
truncation_psi = 0.8#@param {type:"slider", min:0, max:1, step:0.1}
image_size = "128"#@param [64, 128, 256, 512]

# Eight consecutive seeds starting at start_seed.
factor_seeds = [start_seed + i for i in range(8)]
# factorize() takes a list of direction indices ([-1] would select all of them).
factor_index = [factor_index]

# One [1, G.z_dim] z tensor per seed, on the GPU.
latents_z = []
for seed in factor_seeds:
  z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
  latents_z.append(z)

# z -> w via the repo's apply_factor helper (presumably G.mapping with this truncation — confirm).
latents_w = apply_factor.zs_to_ws(G,torch.device('cuda'),label,truncation_psi,latents_z)

# Shows the minus / original / plus images along the chosen direction for every seed.
factorize(latents_w, 'w', factor_index, factor_degree, truncation_psi, int(image_size))

# Finding you inside the latent space

In [None]:
#@title ⬆ Upload a picture

from google.colab import files

!mkdir /content/images
%cd /content/images

uploaded = files.upload()

for fn in uploaded.keys():
  print('Uploaded file "{name}" with length {length} bytes'.format(name=fn, length=len(uploaded[fn])))

%cd /content/stylegan2-ada-pytorch/

In [None]:
#@title ▶ Define some functions to crop and align the protrait { display-mode: "form" }

%cd stylegan2-ada-pytorch

!wget -nc http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
!bunzip2 --keep shape_predictor_68_face_landmarks.dat.bz2

import dlib

predictor = dlib.shape_predictor('./shape_predictor_68_face_landmarks.dat')

def get_landmark(filepath):
    """Get the 68 dlib facial landmarks of a face in an image.

    Runs dlib's frontal face detector and the global `predictor` on every
    detection, printing each box. The landmarks of the LAST detection are returned.

    :param filepath: str, path to the image
    :return: np.array shape=(68, 2) of (x, y) pixel coordinates
    :raises ValueError: if no face is detected
    """
    detector = dlib.get_frontal_face_detector()

    img = dlib.load_rgb_image(filepath)
    dets = detector(img, 1)

    print("Number of faces detected: {}".format(len(dets)))
    if len(dets) == 0:
        # Without this check `shape` below would be unbound (UnboundLocalError).
        raise ValueError("No face detected in image: {}".format(filepath))

    for k, d in enumerate(dets):
        print("Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format(
            k, d.left(), d.top(), d.right(), d.bottom()))
        # Get the landmarks/parts for the face in box d.
        shape = predictor(img, d)
        print("Part 0: {}, Part 1: {} ...".format(shape.part(0), shape.part(1)))

    # `shape` now holds the landmarks of the last detection.
    return np.array([[part.x, part.y] for part in shape.parts()])


def align_face(filepath):
    """Crop, rotate and scale a portrait around its facial landmarks.

    Appears to follow the FFHQ dataset alignment recipe. It builds an oriented
    square from the eye and mouth landmarks, then shrinks, crops, pads (reflect
    + blur + fade to the median colour) and warps the image to a
    1024x1024 output.

    :param filepath: str, path to the input image
    :return: PIL Image (RGB, 1024x1024)
    """

    lm = get_landmark(filepath)

    lm_chin          = lm[0  : 17]  # left-right
    lm_eyebrow_left  = lm[17 : 22]  # left-right
    lm_eyebrow_right = lm[22 : 27]  # left-right
    lm_nose          = lm[27 : 31]  # top-down
    lm_nostrils      = lm[31 : 36]  # top-down
    lm_eye_left      = lm[36 : 42]  # left-clockwise
    lm_eye_right     = lm[42 : 48]  # left-clockwise
    lm_mouth_outer   = lm[48 : 60]  # left-clockwise
    lm_mouth_inner   = lm[60 : 68]  # left-clockwise

    # Calculate auxiliary vectors.
    eye_left     = np.mean(lm_eye_left, axis=0)
    eye_right    = np.mean(lm_eye_right, axis=0)
    eye_avg      = (eye_left + eye_right) * 0.5
    eye_to_eye   = eye_right - eye_left
    mouth_left   = lm_mouth_outer[0]
    mouth_right  = lm_mouth_outer[6]
    mouth_avg    = (mouth_left + mouth_right) * 0.5
    eye_to_mouth = mouth_avg - eye_avg

    # Choose oriented crop rectangle: x is the half-width axis, y = x rotated
    # by 90 degrees, c the centre; quad holds the four corners.
    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
    x /= np.hypot(*x)
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
    y = np.flipud(x) * [-1, 1]
    c = eye_avg + eye_to_mouth * 0.1
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    qsize = np.hypot(*x) * 2


    # Read the image as RGB so grayscale / RGBA / palette files also work with
    # the 3-channel padding below and the 'RGB' fromarray call.
    img = PIL.Image.open(filepath).convert('RGB')

    output_size=1024
    transform_size=4096
    enable_padding=True

    # Shrink.
    shrink = int(np.floor(qsize / output_size * 0.5))
    if shrink > 1:
        rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
        # LANCZOS is the filter formerly aliased as ANTIALIAS (removed in Pillow 10).
        img = img.resize(rsize, PIL.Image.LANCZOS)
        quad /= shrink
        qsize /= shrink

    # Crop.
    border = max(int(np.rint(qsize * 0.1)), 3)
    crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
    if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
        img = img.crop(crop)
        quad -= crop[0:2]

    # Pad.
    pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
    pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
    if enable_padding and max(pad) > border - 4:
        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
        img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
        h, w, _ = img.shape
        y, x, _ = np.ogrid[:h, :w, :1]
        # mask grows towards 1 near the padded borders.
        mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
        blur = qsize * 0.02
        img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
        img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
        img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
        quad += pad[:2]

    # Transform.
    img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
    if output_size < transform_size:
        img = img.resize((output_size, output_size), PIL.Image.LANCZOS)

    # Return the aligned image (the caller saves it).
    return img


In [None]:
#@title 🖼 Align the portrait { display-mode: "form" }

#@markdown 💬 Name of the previously uploaded image
input_image = "portrait.jpg" #@param {type:"string"}
#@markdown 💬 Name of the cropped and aligned image
output_image = "portrait_aligned.jpg" #@param {type:"string"}

# "../images/" assumes the cwd is /content/stylegan2-ada-pytorch, i.e. it points to /content/images.
aligned = align_face(os.path.join("../images/", input_image))
aligned.save(os.path.join("../images/", output_image))

# NOTE(review): Image is already imported in the initialization cell; this re-import is redundant.
from IPython.display import Image
Image(os.path.join("../images/", output_image), width=256, height=256)

In [None]:
#@title ▶ Define some functions needed to find an image in the latent space { display-mode: "form" }

def project(
    G,
    target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
    *,
    num_steps                  = 1000,
    w_avg_samples              = 10000,
    initial_learning_rate      = 0.1,
    initial_noise_factor       = 0.05,
    lr_rampdown_length         = 0.25,
    lr_rampup_length           = 0.05,
    noise_ramp_length          = 0.75,
    regularize_noise_weight    = 1e5,
    verbose                    = False,
    device: torch.device
):
    """Optimize one W vector (and G's noise buffers) so G reproduces `target`.

    The loss is the squared distance between VGG16 LPIPS features of the target
    and of the synthesized image, plus a regularizer on the noise buffers.
    Progress thumbnails (newest first) are shown every 25 steps.

    Args:
        G: generator; a deep copy is optimized, so the caller's G is not modified.
        target: [C,H,W] tensor in [0,255]; H and W must equal G.img_resolution.
        num_steps: number of Adam steps.
        w_avg_samples: number of z samples (fixed RandomState(123)) used for the W mean/std.
        initial_learning_rate, lr_rampdown_length, lr_rampup_length: learning-rate
            schedule with a cosine ramp-down and a linear ramp-up, lengths given
            as fractions of num_steps.
        initial_noise_factor, noise_ramp_length: scale of the Gaussian noise
            added to w, decaying quadratically to 0 over noise_ramp_length.
        regularize_noise_weight: weight of the noise regularizer.
        verbose: print progress messages via logprint.
        device: torch device used for G, VGG16 and the optimization.

    Returns:
        Tensor [num_steps, G.mapping.num_ws, C]: the W of every step, with the
        single optimized vector repeated across all layers.
    """
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore

    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)       # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)      # [1, 1, C]
    # Scalar spread of the W samples around their mean; scales the w noise below.
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    # Setup noise inputs.
    noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }

    # Load VGG16 feature detector.
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    with dnnlib.util.open_url(url) as f:
        vgg16 = torch.jit.load(f).eval().to(device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    # Start from the mean W; the noise buffers are optimized together with it.
    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    # Progress thumbnails, newest first.
    pimages = []

    pbar = tqdm(range(num_steps), position=0)
    for step in pbar:
        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
        synth_images = G.synthesis(ws, noise_mode='const')

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255/2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization.
        # Penalizes correlation between neighbouring noise values at several
        # scales (halving resolution until it is <= 8).
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist + reg_loss * regularize_noise_weight

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        # logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
        pbar.set_description(f'dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Save projected W for each optimization step.
        w_out[step] = w_opt.detach()[0]
        if step % 25 == 0:
          # Show a 128x128 preview of the current W (clears the cell output first).
          synth_image = G.synthesis(w_out[step].repeat([G.mapping.num_ws, 1]).unsqueeze(0), noise_mode='const')
          im = torch_to_image(synth_image)
          pimages.insert(0, im.resize((128, 128)))

          clear_output()
          media.show_images(pimages, height=128)

        # Normalize noise.
        # Re-center each buffer to zero mean and unit RMS after every step.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    return w_out.repeat([1, G.mapping.num_ws, 1])

def run_projection(image_file, seed, num_steps):
    """Project an image file into W, display the result and save the latent.

    The image is center-cropped to a square and resized to G.img_resolution
    before running project(). The last step's W (shape [1, num_ws, C]) is
    written to "<image_file>.npz" under the key "w".
    """
    # Seed the global RNGs; project() draws its noise with torch.randn_like.
    np.random.seed(seed)
    torch.manual_seed(seed)

    source = PIL.Image.open(image_file).convert('RGB')
    width, height = source.size
    side = min(width, height)
    box = ((width - side) // 2, (height - side) // 2, (width + side) // 2, (height + side) // 2)
    resolution = G.img_resolution
    square = source.crop(box).resize((resolution, resolution), PIL.Image.LANCZOS)
    chw_pixels = np.array(square, dtype=np.uint8).transpose([2, 0, 1])

    w_steps = project(
        G,
        target=torch.tensor(chw_pixels, device=device), # pylint: disable=not-callable
        num_steps=num_steps,
        device=device,
        verbose=True
    )

    final_w = w_steps[-1].unsqueeze(0)
    display(torch_to_image(G.synthesis(final_w, noise_mode='const')))

    np.savez(image_file + ".npz", w=final_w.cpu().numpy())



In [None]:
#@title 🖼 Find the uploaded image in the latent space

#@markdown Now we will try to find the point in the latent space that generates the image that is more similar to the portrait.

%cd stylegan2-ada-pytorch

#@markdown 💬 Name of the cropped and aligned image
input_image = "portrait_aligned.jpg" #@param {type:"string"}
#@markdown 💬 Number of steps of the search. More steps is better, but it takes longer.
steps = 1000#@param {type:"integer"}

run_projection(os.path.join("../images/", input_image), 27, steps)

In [None]:
#@title 🖼 Modify yourself inside the latent space { display-mode: "form", run: "auto" }

#@markdown Now we can modify our portrait in the latent space as we did before with the seeds.

#@markdown 💬 Name of the latent file (.npz) saved by the previous step
input_position = "portrait_aligned.jpg.npz" #@param {type:"string"}

#@markdown 💬 Index of the direction and amount of change
factor_index = 1#@param {type:"integer"}
factor_degree = 3#@param {type:"slider", min:0, max:10, step:0.25}
#@markdown 💬 Size of the output images
image_size = "512"#@param [128, 256, 512, 1024]

# factorize() takes a list of direction indices.
factor_index = [factor_index]

portrait_w = np.load(os.path.join("../images/", input_position))

# Wrapping w in a list adds a leading axis, so factorize iterates over exactly one latent.
factorize(torch.from_numpy(np.array([portrait_w['w']])).to(device), 'w', factor_index, factor_degree, image_size=int(image_size))

# Finding someone inside the latent space

In [None]:
#@title ▶ Define the functions necessary to find inside the latent space by a description { display-mode: "form" }

def clip_approach(
    G,
    *,
    num_steps                  = 100,
    w_avg_samples              = 10000,
    initial_learning_rate      = 0.02,
    initial_noise_factor       = 0.02,
    noise_floor                = 0.02,
    psi                        = 0.8,
    noise_ramp_length          = 1.0, # was 0.75
    regularize_noise_weight    = 10000, # was 1e5
    seed                       = 69097,
    autoseed                   = True,
    autoseed_samples           = 128,
    noise_opt                  = True,
    ws                         = None,
    text                       = 'a computer generated image',
    device: torch.device
):
    """Optimize a W latent so that G's image matches `text` according to CLIP.

    The starting point is the first-layer W of a seed (optionally chosen by
    clip_find_best_seed) or of the supplied `ws`. Each step synthesizes from a
    noisy copy of w and maximizes the CLIP cosine similarity between the image
    and the prompt. Progress thumbnails are shown every 25 steps.

    Args:
        G: generator; a deep copy is optimized, so the caller's G is not modified.
        num_steps: number of Adam steps (constant learning rate).
        w_avg_samples: unused; kept for interface compatibility.
        initial_learning_rate: Adam learning rate.
        initial_noise_factor, noise_ramp_length, noise_floor: the w noise scale
            decays quadratically over noise_ramp_length and never drops below noise_floor.
        psi: truncation for the starting latent (and for autoseed).
        regularize_noise_weight: weight of the noise-buffer regularizer.
        seed: starting seed, or the RNG seed of the autoseed search.
        autoseed: replace `seed` with the best of `autoseed_samples` random seeds.
        noise_opt: also optimize the noise buffers (and add the regularizer).
        ws: optional starting latents; when given, the seed is not used for the start point.
        text: prompt to match.
        device: torch device.

    Returns:
        Tensor [num_steps, G.mapping.num_ws, C] with the W of every step.
    """
    G = copy.deepcopy(G).eval().requires_grad_(False).to(device)

    # Load CLIP and embed the prompt once.
    print('Loading perceptor for text:', text)
    perceptor, _ = clip.load('ViT-B/32', jit=True)
    perceptor = perceptor.eval()
    tx = clip.tokenize(text)
    whispers = perceptor.encode_text(tx.cuda()).detach().clone()

    # autoseed
    if autoseed:
        seed = clip_find_best_seed(seed, perceptor, whispers, autoseed_samples, psi)

    # derive W from seed (or from the supplied ws)
    if ws is None:
        print('Generating w for seed %i' % seed )
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
        w_samples = G.mapping(z,  None, truncation_psi=psi)
    else:
        w_samples = torch.tensor(ws, device=device)
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)
    w_avg = np.mean(w_samples, axis=0, keepdims=True)
    # Fixed spread instead of estimating it from w_avg_samples random samples.
    w_std = 2 # ~9.9 for portraits network. should compute if using median median

    # Setup noise inputs.
    noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }
    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)

    if noise_opt:
        optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)
        print('optimizer: w + noise')
    else:
        optimizer = torch.optim.Adam([w_opt] , betas=(0.9, 0.999), lr=initial_learning_rate)
        print('optimizer: w')

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    # CLIP's input normalization (mean/std from CLIP's own preprocess).
    clip_normalize = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

    pimages = []

    # Descend
    pbar = tqdm(range(num_steps))
    for step in pbar:
        # Noise schedule: quadratic decay, floored at noise_floor.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        if w_noise_scale < noise_floor:
            w_noise_scale = noise_floor

        # Synthesize from a noisy copy of w broadcast to every layer.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        synth_ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
        synth_images = G.synthesis(synth_ws, noise_mode='const')

        # Same preprocessing as clip_find_best_seed: resize to CLIP's 224x224,
        # map [-1, 1] -> [0, 1], then apply CLIP's normalization. The
        # normalized tensor is the one that gets encoded.
        into = torch.nn.functional.interpolate(synth_images, (224,224), mode='bilinear', align_corners=True)
        into = clip_normalize(into.add(1).div(2))

        glimmers = perceptor.encode_image(into)
        # Negative scaled cosine similarity; 30 is an empirically chosen scale.
        proximity =  -30 * torch.cosine_similarity(whispers, glimmers, dim = -1).mean()

        # noise reg, from og projector
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)

        if noise_opt:
            loss = proximity + reg_loss * regularize_noise_weight
        else:
            loss = proximity

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        pbar.set_description(f'loss {float(loss):<5.2f} | proximity: {float(proximity / (-30)):<5.6f}')

        # Save projected W for each optimization step.
        w_out[step] = w_opt.detach()[0]
        if step % 25 == 0:
          synth_image = G.synthesis(w_out[step].repeat([G.mapping.num_ws, 1]).unsqueeze(0), noise_mode='const')
          im = torch_to_image(synth_image)
          pimages.insert(0, im.resize((128, 128)))

          clear_output()
          media.show_images(pimages, height=128)


        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    return w_out.repeat([1, G.mapping.num_ws, 1])

def clip_search(text, seed, autoseed, num_steps, truncation_psi=0.8, save_video=False):
  """Search W for an image matching `text`, display it and save the latent.

  Args:
    text: prompt. Spaces and '/' are replaced with '_' to form the output file name.
    seed: starting seed, or the RNG seed of the autoseed search.
    autoseed: let clip_approach choose the starting seed with CLIP.
    num_steps: optimization steps.
    truncation_psi: truncation of the starting latent.
    save_video: also write every optimization step to ../images/out.mp4.

  Writes "../images/<text>.npz" with the final W (shape [1, num_ws, C]) under key "w".
  """
  outdir = '../images'

  projected_w_steps = clip_approach(
      G,
      num_steps=num_steps,
      device=device,
      initial_learning_rate=0.02,
      psi=truncation_psi,
      seed=seed,
      initial_noise_factor=0.04, # 0.02 originally
      noise_floor=0.02,
      text=text,
      autoseed=autoseed,
      ws=None,
      noise_opt=True, # optimize both w and the noise buffers (default behaviour)
  )

  os.makedirs(outdir, exist_ok=True)
  if save_video:
      print (f'Saving optimization progress video "{outdir}/out.mp4"')
      frames = [torch_to_image(G.synthesis(w.unsqueeze(0), noise_mode='const'), False)
                for w in projected_w_steps]
      # mediapy is already imported by the notebook (imageio was never imported).
      media.write_video(f'{outdir}/out.mp4', frames, fps=10)

  # Display the final result and save its W.
  projected_w = projected_w_steps[-1]
  display(torch_to_image(G.synthesis(projected_w.unsqueeze(0), noise_mode='const')))

  # '/' would otherwise point into a (missing) subdirectory of outdir.
  file_stem = text.replace(" ", "_").replace("/", "_")
  np.savez(f'{outdir}/{file_stem}.npz', w=projected_w.unsqueeze(0).cpu().numpy())



def spherical_dist_loss(x, y):
    """Angular distance loss between the directions of `x` and `y`.

    Both inputs are L2-normalized along their last dimension. The chord
    length |u - v| between the resulting unit vectors is halved, mapped
    through arcsin, squared and doubled, i.e. the result is
    2 * arcsin(|u - v| / 2) ** 2, which equals half the squared angle between
    u and v. The last dimension is reduced away; leading dimensions follow
    normal PyTorch broadcasting of `x` against `y`.
    """
    chord = F.normalize(x, dim=-1) - F.normalize(y, dim=-1)
    half_chord_len = chord.norm(dim=-1) / 2
    return 2 * half_chord_len.asin() ** 2

import torchvision.transforms as transforms

def clip_find_best_seed(seed, perceptor, whispers, autoseed_samples, psi):
  """Return the sampled seed whose generated image scores best against `whispers`.

  Draws `autoseed_samples` candidate seeds in [0, 500000] from Python's global
  `random` module (re-seeded with `seed`, so the candidate list is
  reproducible). For each candidate it renders one image with the module-level
  generator `G` on the module-level `device`, encodes it with
  `perceptor.encode_image` and scores it with `spherical_dist_loss` against
  `whispers`. A lower loss is better.

  Args:
    seed: value used to re-seed Python's global `random` module.
    perceptor: model exposing `encode_image` (presumably the loaded CLIP
      model; confirm against the caller).
    whispers: embedding the image embeddings are compared with (presumably the
      encoded text prompt; confirm against the caller).
    autoseed_samples: number of candidate seeds to evaluate.
    psi: truncation psi forwarded to `G`.

  Returns:
    The best-scoring candidate seed as an int. Ties keep the earliest
    candidate. When `autoseed_samples` is not positive there is nothing to
    choose from, and `seed` is returned unchanged.
  """
  print(f'Guessing the best seed using {autoseed_samples} samples')

  random.seed(seed)
  if autoseed_samples <= 0:
    return seed

  candidate_seeds = [randint(0, 500000) for _ in range(autoseed_samples)]

  # Mean/std used by OpenAI CLIP's own image preprocessing.
  normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])

  best_seed, best_fitness = None, None
  # Scoring only: no gradients are needed, so avoid building an autograd graph
  # through G and the image encoder for every candidate.
  with torch.no_grad():
    for candidate in candidate_seeds:
      z = torch.from_numpy(np.random.RandomState(candidate).randn(1, G.z_dim)).to(device)
      snap = G(z, None, truncation_psi=psi, noise_mode='const')
      snap = torch.nn.functional.interpolate(snap, (224, 224), mode='bilinear', align_corners=True)
      # Shift the generator output from [-1, 1] (assumed range) to [0, 1]
      # before applying CLIP's normalization.
      loss = spherical_dist_loss(whispers, perceptor.encode_image(normalize(snap.add(1).div(2))))
      # Keep the score in full precision: truncating it to an int created
      # artificial ties between close candidates.
      fitness = loss.float().mean().item()
      # Strict '<' keeps the earliest candidate on ties (like a stable sort).
      if best_fitness is None or fitness < best_fitness:
        best_seed, best_fitness = candidate, fitness

  return best_seed

In [None]:
#@title 🖼 Search inside the latent space by a description { display-mode: "form" }

#@markdown 💬 The description of the image to search for.
text = "a portrait of john malkovich" #@param {type:"string"}
#@markdown 💬 Starting seed. Choose a starting seed whose face is close to the description: the search is then faster and more likely to succeed.
seed =  1#@param {type:"integer"}
#@markdown 💬 Try to automatically find a good starting seed (this sometimes works, sometimes not).
autoseed =  False#@param {type:"boolean"}
#@markdown 💬 Number of search steps. More steps usually give better results, but take longer.
num_steps = 100#@param {type:"integer"}

# Run the latent-space search with the form values above.
clip_search(text, seed, autoseed, num_steps)

# Download

Download the latent-space positions found by the search (.npz files).

In [None]:
!zip -qr -0 /content/images.zip /content/images

from google.colab import files
files.download("/content/images.zip")

# Related Colab notebooks and resources

Search inside faces and other latent spaces: [StyleGAN3+CLIP Online](https://replicate.com/ouhenio/stylegan3-clip) /
[StyleGAN3+CLIP Notebook](https://colab.research.google.com/github/ouhenio/StyleGAN3-CLIP-notebook/blob/main/StyleGAN3%2BCLIP.ipynb)

Find a face inside the latent space and modify it with text descriptions: [StyleCLIP online](https://replicate.com/orpatashnik/styleclip) / [StyleCLIP notebooks on GitHub](https://github.com/orpatashnik/StyleCLIP)

Search inside a more general latent space: [VQGAN+CLIP online](https://huggingface.co/spaces/multimodalart/vqgan) / [VQGAN+CLIP notebook](https://colab.research.google.com/github/justinjohn0306/VQGAN-CLIP/blob/main/VQGAN%2BCLIP%28Updated%29.ipynb)


# Credits

Taller Estampa https://tallerestampa.com / https://github.com/estampa

### Based on
Based on notebooks from the [pbaylies fork](https://github.com/pbaylies/stylegan2-ada-pytorch) of the [dvschultz StyleGAN2-ADA PyTorch fork](https://github.com/dvschultz/stylegan2-ada-pytorch)