**Edit face images using text.**

Update a StyleGAN 2 latent code guided by CLIP, aiming to maximize the cosine similarity (i.e. minimize the distance) between the CLIP embedding of the input text and the CLIP embedding of the image synthesized from the trained latent vector.

In [None]:
from argparse import Namespace
import IPython
import math
import matplotlib.pyplot as plt
import numpy as np
import os
from pathlib import Path
import PIL
import sys
import tempfile
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
# (see the optional cell below that lets the StyleGAN CUDA-only ops run on CPU).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Encoder

We need to get the latent code that corresponds to the input image. Projection doesn't work that well, so we'll use a pretrained encoder, developed by the creators of https://github.com/eladrich/pixel2style2pixel. It also contains the decoder, a StyleGAN 2 face generator.

The encoding code used in this notebook has been extracted and simplified from https://github.com/eladrich/pixel2style2pixel/blob/master/notebooks/inference_playground.ipynb

## Download requirements

Install Ninja to load C++ extensions:

In [None]:
# Download a prebuilt Ninja binary and register it as the system `ninja`;
# it is needed to JIT-compile the C++/CUDA extensions loaded with torch.utils.cpp_extension.load().
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force 

## Setup encoder

Download encoder code and add to Python path:

In [None]:
# Optionally start from /content (the Colab working directory) before cloning.
# os.chdir('/content')
# Folder the pixel2style2pixel repo is cloned into; a later cell chdir's into it.
ENCODER_CODE_DIR = 'encoder'
!git clone https://github.com/eladrich/pixel2style2pixel.git $ENCODER_CODE_DIR
# Make the repo's packages (e.g. `models.psp`) importable.
sys.path.append(str(Path(ENCODER_CODE_DIR).resolve()))

### (Optional) Modify CUDA-only layers implementation to accept CPU as device (credits to https://github.com/rosinality/stylegan2-pytorch.git):

In [None]:
%%writefile {ENCODER_CODE_DIR}/models/stylegan2/op/fused_act.py
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load


# JIT-compile the fused bias + activation extension only when a GPU is present.
# On CPU-only machines `fused` is never defined and fused_leaky_relu() below
# takes its pure-PyTorch CPU path instead.
if torch.cuda.is_available():
    module_path = os.path.dirname(__file__)
    fused = load(
        "fused",
        sources=[
            os.path.join(module_path, "fused_bias_act.cpp"),
            os.path.join(module_path, "fused_bias_act_kernel.cu"),
        ],
    )


class FusedLeakyReLUFunctionBackward(Function):
    """Gradient of the fused bias + leaky ReLU op, written as its own autograd Function.

    Having a separate Function with its own backward() makes the first-order
    gradient itself differentiable (double backward).
    """

    @staticmethod
    def forward(ctx, grad_output, out, bias, negative_slope, scale):
        """Return (grad_input, grad_bias).

        Args:
            grad_output: gradient w.r.t. the fused op's output.
            out: output of the fused op's forward pass.
            bias: bool flag, True when the forward pass used a bias.
            negative_slope, scale: activation parameters of the forward pass.
        """
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        empty = grad_output.new_empty(0)

        # NOTE(review): the integer arguments (3, 1) select the activation / gradient
        # mode inside the compiled kernel; see fused_bias_act_kernel.cu to confirm.
        grad_input = fused.fused_bias_act(
            grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
        )

        # Sum over every axis except the channel axis (dim 1) to get the bias gradient.
        dim = [0]

        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        if bias:
            grad_bias = grad_input.sum(dim).detach()

        else:
            grad_bias = empty

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        """Second-order gradient: run the kernel on the incoming gradients, using the saved output."""
        out, = ctx.saved_tensors
        gradgrad_out = fused.fused_bias_act(
            gradgrad_input.contiguous(),
            gradgrad_bias,
            out,
            3,
            1,
            ctx.negative_slope,
            ctx.scale,
        )

        # Only grad_output (first forward argument) receives a gradient.
        return gradgrad_out, None, None, None, None


class FusedLeakyReLUFunction(Function):
    """Autograd Function for the compiled fused bias + leaky ReLU + scale kernel."""

    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        """Apply the fused op to `input`; `bias` may be None."""
        empty = input.new_empty(0)

        # Remember whether a bias was given so backward() can return None for it.
        ctx.bias = bias is not None

        if bias is None:
            # The kernel takes a tensor argument; presumably an empty tensor means "no bias".
            bias = empty

        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
        # Only the output is saved; the gradient is computed from it.
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Delegate to FusedLeakyReLUFunctionBackward so the gradient stays differentiable."""
        out, = ctx.saved_tensors

        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
            grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
        )

        if not ctx.bias:
            grad_bias = None

        # No gradients for negative_slope and scale.
        return grad_input, grad_bias, None, None


class FusedLeakyReLU(nn.Module):
    """Leaky ReLU with an optional learnable per-channel bias, followed by a constant gain.

    Args:
        channel: number of channels, i.e. length of the bias vector.
        bias: when True, learn a bias initialised to zeros; otherwise no bias.
        negative_slope: slope used for negative inputs.
        scale: multiplier applied after the activation.
    """

    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(channel)) if bias else None
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        slope, gain = self.negative_slope, self.scale
        return fused_leaky_relu(input, self.bias, slope, gain)


def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    """Add an optional per-channel bias, apply leaky ReLU, then multiply by `scale`.

    CPU tensors use plain PyTorch ops; other devices dispatch to the compiled
    kernel through FusedLeakyReLUFunction.

    Args:
        input: tensor whose dim 1 is the channel axis (the bias is broadcast along it).
        bias: optional 1-D tensor with one value per channel, or None.
        negative_slope: slope used for negative inputs.
        scale: gain applied after the activation.

    Returns:
        Tensor with the same shape as `input`.
    """
    if input.device.type == "cpu":
        if bias is not None:
            # Reshape the bias to (1, C, 1, ..., 1) so it broadcasts over the channel axis.
            rest_dim = [1] * (input.ndim - bias.ndim - 1)
            input = input + bias.view(1, bias.shape[0], *rest_dim)
        # Use the caller's negative_slope (it used to be hard-coded to 0.2 here),
        # so the CPU path agrees with the CUDA kernel for every slope.
        return F.leaky_relu(input, negative_slope=negative_slope) * scale

    return FusedLeakyReLUFunction.apply(
        input.contiguous(), bias, negative_slope, scale
    )

In [None]:
%%writefile {ENCODER_CODE_DIR}/models/stylegan2/op/upfirdn2d.py
from collections import abc
import os

import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load


# JIT-compile the upfirdn2d extension only when a GPU is present. On CPU-only
# machines `upfirdn2d_op` is never defined and upfirdn2d() uses upfirdn2d_native().
if torch.cuda.is_available():
    module_path = os.path.dirname(__file__)
    upfirdn2d_op = load(
        "upfirdn2d",
        sources=[
            os.path.join(module_path, "upfirdn2d.cpp"),
            os.path.join(module_path, "upfirdn2d_kernel.cu"),
        ],
    )


class UpFirDn2dBackward(Function):
    """Gradient of UpFirDn2d, written as its own autograd Function.

    forward() computes grad_input by running the compiled upfirdn2d kernel on
    grad_output with the flipped kernel, the up/down factors swapped and the
    gradient padding `g_pad`. backward() makes that gradient differentiable
    again (double backward).
    """

    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):
        """Return the gradient w.r.t. UpFirDn2d's input.

        Args:
            grad_output: gradient w.r.t. UpFirDn2d's output.
            kernel: the original FIR kernel (saved for the double-backward pass).
            grad_kernel: `kernel` flipped along both axes.
            up, down: (x, y) factors of the forward op.
            pad: (x0, x1, y0, y1) padding of the forward op.
            g_pad: (x0, x1, y0, y1) padding for the gradient computation.
            in_size: shape of the forward input, (N, C, H, W).
            out_size: (H, W) of the forward output.
        """

        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        # Fold batch and channels together: the kernel works on (N*C, H, W, 1).
        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        # Note the swapped factors: the gradient is upsampled by `down`, downsampled by `up`.
        grad_input = upfirdn2d_op.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        # Keep the forward-op parameters: backward() re-applies the forward op.
        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        """Second-order gradient: apply the forward upfirdn2d op to `gradgrad_input`."""
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_op.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        # Only grad_output (first forward argument) receives a gradient.
        return gradgrad_out, None, None, None, None, None, None, None, None


class UpFirDn2d(Function):
    """Autograd Function wrapping the compiled upfirdn2d kernel (upsample, FIR filter, downsample)."""

    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        """Resample `input` with the FIR `kernel`.

        Args:
            input: (N, C, H, W) tensor.
            kernel: 2-D FIR kernel of shape (kernel_h, kernel_w).
            up, down: (x, y) integer factors.
            pad: (x0, x1, y0, y1) padding.

        Returns:
            (N, C, out_h, out_w) tensor.
        """
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        # Fold batch and channels together: the kernel works on (N*C, H, W, 1).
        input = input.reshape(-1, in_h, in_w, 1)

        # Save the kernel and its flipped copy (used to compute the gradient).
        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        # Padding for the gradient pass performed by UpFirDn2dBackward.
        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Return the input gradient (via UpFirDn2dBackward); the other arguments get None."""
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = None

        if ctx.needs_input_grad[0]:
            grad_input = UpFirDn2dBackward.apply(
                grad_output,
                kernel,
                grad_kernel,
                ctx.up,
                ctx.down,
                ctx.pad,
                ctx.g_pad,
                ctx.in_size,
                ctx.out_size,
            )

        return grad_input, None, None, None, None


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, filter with the 2-D FIR `kernel`, then downsample `input`.

    Args:
        input: (N, C, H, W) tensor.
        kernel: 2-D FIR kernel.
        up, down: integer factor applied to both axes, or an (x, y) pair.
        pad: (x0, x1) used for both axes, or a full (x0, x1, y0, y1).
    """
    up = up if isinstance(up, abc.Iterable) else (up, up)
    down = down if isinstance(down, abc.Iterable) else (down, down)
    if len(pad) == 2:
        pad_lo, pad_hi = pad
        pad = (pad_lo, pad_hi, pad_lo, pad_hi)

    # Non-CPU tensors use the compiled kernel; CPU tensors the pure-PyTorch version.
    if input.device.type != "cpu":
        return UpFirDn2d.apply(input, kernel, up, down, pad)
    return upfirdn2d_native(input, kernel, *up, *down, *pad)


def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    """Pure-PyTorch (CPU-capable) implementation of upfirdn2d.

    Steps: zero-insertion upsampling, padding (negative pads crop), filtering with
    the kernel via F.conv2d on the flipped kernel, then strided subsampling.

    Args:
        input: (N, C, H, W) tensor.
        kernel: 2-D FIR kernel of shape (kernel_h, kernel_w).

    Returns:
        (N, C, out_h, out_w) tensor.
    """
    _, channel, in_h, in_w = input.shape
    # Treat every channel as an independent single-channel image.
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Upsample: insert (up - 1) zeros after every pixel along each axis.
    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Apply the positive part of the padding...
    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    # ...and crop for the negative part.
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    # F.conv2d computes cross-correlation, so the kernel is flipped to get a true convolution.
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    # Downsample: keep every down-th row and column.
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x

    return out.view(-1, channel, out_h, out_w)

### Download weights:

In [None]:
def get_download_encoder_command(file_id, file_name):
    """Build a shell command that downloads a Google Drive file into the encoder's weights folder.

    The target folder is ``<parent of cwd>/ENCODER_CODE_DIR/pretrained_models``. When
    called from inside the cloned encoder directory, as this notebook does, that is
    ``./pretrained_models``. The folder is created if it does not exist.

    Args:
        file_id: Google Drive file id.
        file_name: name under which the downloaded file is saved.

    Returns:
        A ``wget`` command string, meant to be run with ``!{command}``.
    """
    import shlex

    current_directory = os.getcwd()
    save_path = os.path.join(os.path.dirname(current_directory), ENCODER_CODE_DIR, "pretrained_models")
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(save_path, exist_ok=True)
    # Quote the destination so paths with spaces or shell metacharacters survive the shell
    # (ordinary paths are returned unchanged by shlex.quote).
    destination = shlex.quote(os.path.join(save_path, file_name))
    url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {DESTINATION} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, DESTINATION=destination)
    return url


# Google Drive id and local file name of the pSp FFHQ encoder checkpoint.
ENCODER_DOWNLOAD_PATH = dict(id="1bMTNWkh5LArlaWSc_wa8VKyq2V42T2z0", name="psp_ffhq_encode.pt")
# Per-channel mean/std of the Normalize applied to encoder inputs (and undone by denormalize).
NORMALIZE_MEAN, NORMALIZE_STD = 0.5, 0.5

Download encoder and decoder weights:

In [None]:
os.chdir(f'./{ENCODER_CODE_DIR}')
download_command = get_download_encoder_command(file_id=ENCODER_DOWNLOAD_PATH["id"], file_name=ENCODER_DOWNLOAD_PATH["name"])
!{download_command}

### Load and test encoder

In [None]:
from models.psp import pSp


# Paths are relative to the cloned encoder repo. The transform resizes to 256x256,
# converts to a tensor and normalises each RGB channel with NORMALIZE_MEAN / NORMALIZE_STD.
encoder_args = dict(
    model_path="pretrained_models/psp_ffhq_encode.pt",
    image_path="notebooks/images/input_img.jpg",
    transform=transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[NORMALIZE_MEAN] * 3, std=[NORMALIZE_STD] * 3),
    ]),
)


def denormalize(t:torch.Tensor):
    """Invert the encoder input normalisation: map `t` back via NORMALIZE_STD and NORMALIZE_MEAN."""
    return NORMALIZE_MEAN + NORMALIZE_STD * t


def tensor_to_image(t:torch.Tensor) -> "PIL.Image.Image":
    """Convert a normalised (C, H, W) image tensor into a PIL image.

    The tensor is denormalised, clamped to [0, 1], moved to HWC layout and scaled
    to 8-bit values (fractions are truncated, as before).
    """
    # `import PIL` alone does not load the Image submodule; import it explicitly
    # instead of relying on another library having done so.
    import PIL.Image

    pixels = denormalize(t).detach().cpu().clamp(0, 1).permute(1, 2, 0) * 255
    return PIL.Image.fromarray(pixels.numpy().astype(np.uint8))


def run_encoder(inputs, net, device):
    """Move `inputs` to `device` as float32 and run `net` on them with `randomize_noise=False`."""
    batch = inputs.to(device).float()
    return net(batch, randomize_noise=False)

Load model:

In [None]:
encoder_pt_path = encoder_args['model_path']

# A checkpoint smaller than ~1 MB means the download failed (presumably an HTML error page was saved).
if os.path.getsize(encoder_pt_path) < 1000000:
    raise ValueError("Pretrained encoder model was unable to be downloaded correctly!")

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
ckpt = torch.load(encoder_pt_path, map_location='cpu')
opts = ckpt['opts']

# Adapt the training options stored in the checkpoint for inference here.
opts['checkpoint_path'] = encoder_pt_path
opts.setdefault('learn_in_w', False)
opts.setdefault('output_size', 1024)
opts['device'] = device

opts = Namespace(**opts)
net = pSp(opts)
net.eval().to(device)
print('Model successfully loaded!')

In [None]:
# Display the pSp architecture (encoder + StyleGAN 2 decoder).
net

Load and show input image:

In [None]:
# Load the sample image referenced by encoder_args (a path inside the cloned repo).
image_path = encoder_args["image_path"]
original_image = PIL.Image.open(image_path)
print('Size = ', original_image.size)
original_image

Prepare image as Pytorch input tensor:

In [None]:
# Resize + normalise; show shape and value range as a quick sanity check.
img_transforms = encoder_args['transform']
image_tensor = img_transforms(original_image)
image_tensor.shape, image_tensor.min(), image_tensor.max()

Test encoder:

In [None]:
# Round-trip the image through encoder + decoder. This is inference only, so skip
# building an autograd graph (the network's weights presumably still require grad at this point).
with torch.no_grad():
    out_image_tensor = run_encoder(image_tensor.unsqueeze(0), net, device)
tensor_to_image(out_image_tensor[0])

# Face editing

## Install requirements

Install HuggingFace transformers library (needed for CLIP):

In [None]:
!pip install transformers

## Helper methods and classes

In [None]:
from dataclasses import dataclass
from transformers import CLIPProcessor, CLIPModel
from typing import Callable, List, Tuple


def generate(latent, enc_dec=None):
    """Decode a latent code into an image tensor with the pSp StyleGAN 2 decoder.

    Args:
        latent: latent code, passed to the decoder with `input_is_latent=True`
            (same format as the latents optimised by `train`).
        enc_dec: pSp network whose `decoder` is used; defaults to the notebook-level `net`.

    Returns:
        The generated image tensor.
    """
    if enc_dec is None:
        enc_dec = net
    # The decoder returns an (image, latent) pair; only the image is needed.
    img_t, _ = enc_dec.decoder([latent], randomize_noise=False, input_is_latent=True)
    return img_t

    
def get_default_device():
    """Return 'cuda' when a CUDA device is available, otherwise 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'
    

@dataclass
class HyperParameters:
    """Optimisation settings used by `train`.

    Attributes:
        lr, betas, wd: learning rate, beta coefficients and weight decay for Adam.
        low_grad_discarded_pct: fraction in [0, 1] of latent entries, those with the smallest
            absolute gradient, that are zeroed before every optimizer step.
    """
    lr:float=1e-3
    betas:Tuple[float, float]=(0.9, 0.999)
    wd:float=0.
    low_grad_discarded_pct:float=0.

    def __post_init__(self):
        # Fail early with a clear message instead of an obscure topk error during training.
        if not 0. <= self.low_grad_discarded_pct <= 1.:
            raise ValueError(
                f'low_grad_discarded_pct must be in [0, 1], got {self.low_grad_discarded_pct}'
            )
    

class ClipLoss:
    """CLIP-based loss that rewards images matching `target_text`.

    Calling an instance with a batch of generator images returns
    ``(loss, display_string)``. Without negative texts, the loss is the negated mean
    image-text logit. With negative texts, it is the negated target logit plus the
    mean logit of the negative texts, which pushes the image away from them.

    Args:
        target_text: description the edited image should match.
        negative_texts: descriptions the image should not match.
        device: device for the CLIP model; defaults to `get_default_device()`.
    """

    def __init__(self, target_text:str, negative_texts:List[str]=None, device=None):
        if negative_texts is None:
            negative_texts = []
        if device is None:
            device = get_default_device()

        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
        self.model.requires_grad_(False).to(device)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        # Tokenise all texts once; the target text is always at index 0.
        self.inputs = self.processor(text=[target_text] + negative_texts, images=None, return_tensors="pt", padding=True).to(device)
        self.inner_loss = self._single_text_loss if len(negative_texts) == 0 else self._multi_text_loss

    def _single_text_loss(self, clip_out):
        return -clip_out.logits_per_image.mean()

    def _multi_text_loss(self, clip_out):
        logits = clip_out.logits_per_image  # rows: images, columns: texts (target first)
        n_texts = logits.shape[-1]
        target_idx = 0
        negative_idxs = [i for i in range(n_texts) if i != target_idx]
        loss_val = -logits[:, target_idx] + logits[:, negative_idxs].mean()
        return loss_val.mean()

    def _image_stats(self):
        """Return CLIP's per-channel image (mean, std) from the processor.

        Newer `transformers` versions expose them on `processor.image_processor` and
        deprecate `processor.feature_extractor`; both are supported.
        """
        image_processor = getattr(self.processor, 'image_processor', None)
        if image_processor is None:
            image_processor = self.processor.feature_extractor
        return image_processor.image_mean, image_processor.image_std

    def _normalize(self, img_t):
        image_mean, image_std = self._image_stats()
        norm_mean = torch.tensor(image_mean, device=img_t.device)[:, None, None]
        norm_std = torch.tensor(image_std, device=img_t.device)[:, None, None]
        return (img_t - norm_mean) / norm_std

    def __call__(self, img_t:torch.Tensor, *args):
        # Prepare input for CLIP: denormalize using our encoder stats, normalize using CLIP
        # stats, and resize to 224x224.
        # NOTE(review): F.interpolate defaults to nearest-neighbour; 'area'/'bilinear' would
        # spread gradients over all pixels. Kept as-is so existing results stay reproducible.
        clip_in_img_t = F.interpolate(self._normalize(denormalize(img_t)), size=224)
        assert clip_in_img_t.requires_grad

        clip_out = self.model(pixel_values=clip_in_img_t, **self.inputs)
        loss = self.inner_loss(clip_out)
        return loss, f'{loss:.2f}'
    
    
class FeatureLoss:
    """Perceptual loss between ConvNeXt-Tiny feature maps of an image and of `initial_img`.

    Forward hooks on the selected `net.features` stages collect activations. The loss
    is the sum over stages of `inner_loss` (MSE by default) between the activations of
    the input and those of `initial_img`.

    Args:
        initial_img: reference image tensor, presumably (C, H, W); inputs are resized
            to its spatial size before feature extraction.
        inner_loss: per-stage loss, default `nn.MSELoss()`.
        device: device for the feature network; defaults to `get_default_device()`.
        denormalize: if True, inputs are first mapped back from the encoder normalisation.
        layers_idxs: indices of `net.features` stages to use (default: all of them).
    """
    def __init__(self, initial_img, inner_loss=None, device=None, denormalize=True, layers_idxs:List[int]=None):
        self.initial_img = initial_img
        self.inner_loss = inner_loss if inner_loss is not None else nn.MSELoss()
        if device is None:
            device = get_default_device()
        # Not sure whether convnext_tiny is the best option.
        # Prefer the `weights=` API (torchvision >= 0.13, where `pretrained` is deprecated);
        # DEFAULT is the same ImageNet checkpoint that `pretrained=True` loads.
        weights_enum = getattr(models, 'ConvNeXt_Tiny_Weights', None)
        if weights_enum is not None:
            self.net = models.convnext_tiny(weights=weights_enum.DEFAULT).eval()
        else:
            self.net = models.convnext_tiny(pretrained=True).eval()
        self.net.requires_grad_(False)
        self.net.to(device)

        # Standard ImageNet mean/std used by torchvision's pretrained models.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # Note: inside __init__ this parameter shadows the module-level `denormalize` function.
        self.denormalize = denormalize

        # Forward hooks append each selected stage's output here during a forward pass.
        self.hook_outputs = []
        max_layer_idx = len(self.net.features) - 1
        if layers_idxs is None:
            layers_idxs = range(max_layer_idx + 1)
        for i in layers_idxs:
            assert 0 <= i <= max_layer_idx, f'layers_idxs must be between 0 and {max_layer_idx}'
            self.net.features[i].register_forward_hook(self._hook_fn)

        # Reference features never change, so compute them once without autograd.
        with torch.no_grad():
            self.initial_img_ftrs_by_layer = self._extract_ftrs(self.initial_img.unsqueeze(0).to(device))

    def _hook_fn(self, module_self, inp, out):
        self.hook_outputs.append(out)

    def _normalize(self, x):
        """Optionally undo the encoder normalisation, then apply the ImageNet normalisation."""
        if self.denormalize:
            x = denormalize(x)
        return self.normalize(x)

    def _maybe_reshape(self, x):
        """Resize `x` to the spatial size of `initial_img` when it differs."""
        expected_spatial_shape = self.initial_img.shape[-2:]
        if x.shape[-2:] != expected_spatial_shape:
            x = F.interpolate(x, size=(expected_spatial_shape))
        return x

    def _extract_ftrs(self, x):
        """Run the feature extractor on `x` and return the hooked activations in layer order."""
        # Discard anything left over from an interrupted forward pass (e.g. a
        # KeyboardInterrupt mid-training); otherwise stale activations would be
        # silently paired with the wrong reference layers on the next call.
        self.hook_outputs.clear()
        try:
            self.net.features(self._maybe_reshape(self._normalize(x)))
            # Copy before the list is cleared: the hooks keep appending to it.
            return list(self.hook_outputs)
        finally:
            self.hook_outputs.clear()

    def __call__(self, x, *args):
        """Return (summed per-stage loss, loss formatted with 3 decimals)."""
        x_ftrs_by_layer = self._extract_ftrs(x)
        losses = torch.stack([
            self.inner_loss(x_ftrs, initial_img_ftrs)
            for x_ftrs, initial_img_ftrs in zip(x_ftrs_by_layer, self.initial_img_ftrs_by_layer)
        ])
        loss = losses.sum()
        return loss, f'{loss:.3f}'

    
class ComposedLoss:
    """Weighted sum of several losses.

    Each loss is a callable returning ``(loss_tensor, display_value)``. Calling the
    composite returns ``(total_loss, {short_name: display_value})``, where the short
    name is the loss's class name with "Loss" removed.
    """

    def __init__(self, losses:List[Tuple[Callable, float]]):
        # Sequence of (loss_callable, weight) pairs, evaluated in order.
        self.losses = losses

    def __call__(self, x, *args):
        weighted_terms = []
        displayable_losses = {}
        for loss_obj, weight in self.losses:
            value, shown = loss_obj(x, *args)
            displayable_losses[loss_obj.__class__.__name__.replace('Loss', '')] = shown
            weighted_terms.append(weight * value)
        return torch.stack(weighted_terms).sum(), displayable_losses
    
    
class DisplayResultsCallback:
    """`train` callback that draws the current image every `display_interval` steps.

    Images are laid out on a grid with 5 columns, which is refreshed in place in
    the notebook output.

    Args:
        n_steps: total number of training steps (used to size the grid).
        display_interval: draw one image every this many steps.
    """
    def __init__(self, n_steps:int, display_interval=20):
        self.display_interval = display_interval
        n_images_to_display = math.ceil(n_steps / display_interval)
        self.n_cols = 5
        # At least one row so plt.subplots always gets a valid grid (e.g. n_steps == 0).
        self.n_rows = max(1, math.ceil(n_images_to_display / self.n_cols))
        # squeeze=False keeps self.axs 2-D even with a single row. Without it, a 1-row grid
        # (e.g. n_steps=100 with the default interval of 20) makes the loop below and the
        # [row, col] indexing in __call__ fail.
        self.fig, self.axs = plt.subplots(
            self.n_rows, self.n_cols, figsize=(5 * self.n_cols, 5 * self.n_rows), squeeze=False
        )
        for axis in self.axs.flat:
            axis.set_axis_off()
        self.display_handle = None

    def __call__(self, step_idx, loss_val, image_t):
        """Draw `image_t[0]` into the next grid cell every `display_interval` steps."""
        if (step_idx+1) % self.display_interval == 0:
            flattened_idx = (step_idx+1) // self.display_interval - 1
            row_idx, col_idx = divmod(flattened_idx, self.n_cols)
            if row_idx >= self.n_rows:
                # More steps than the grid was sized for: no cell left to draw into.
                return
            image = tensor_to_image(image_t[0].detach().cpu())
            axis = self.axs[row_idx, col_idx]
            axis.imshow(image)
            axis.set_title(f'Step{step_idx},loss={loss_val}')
            IPython.display.clear_output(wait=True)
            if self.display_handle is None:
                self.display_handle = IPython.display.display(self.fig, display_id=True)
            else:
                self.display_handle.update(self.fig)

    def __del__(self):
        # Close (not just clear) the figure so pyplot releases it; getattr guards
        # against __init__ having failed before the figure was created.
        fig = getattr(self, 'fig', None)
        if fig is not None:
            plt.close(fig)


class GifRecorderCallback:
    """`train` callback that saves a frame every `save_interval` steps and writes a GIF after the last step.

    Frames are stored as JPEGs in a temporary directory, which is removed once the
    GIF has been written (or when there is nothing to write).

    Args:
        n_steps: total number of training steps; the GIF is written at step n_steps - 1.
        save_interval: save one frame every this many steps.
        frame_duration_ms: display time of each GIF frame, in milliseconds.
        gif_filename: output path of the GIF.
    """
    def __init__(self, n_steps:int, save_interval=5, frame_duration_ms=100, gif_filename="transformation.gif"):
        self.n_steps = n_steps
        self.save_interval = save_interval
        self.frame_duration_ms = frame_duration_ms
        self.gif_filename = gif_filename
        self.temp_dir = tempfile.TemporaryDirectory()

    def _get_img_path(self, idx):
        filename = f'img{idx}.jpg'
        path = f"{self.temp_dir.name}/{filename}"
        return path

    def _save_gif(self):
        n_frames = self.n_steps // self.save_interval
        frames = []
        try:
            for i in range(1, n_frames + 1):
                frames.append(PIL.Image.open(self._get_img_path(i)))
            if not frames:
                return

            frames[0].save(
                self.gif_filename,
                format="GIF",
                append_images=frames[1:],
                save_all=True,
                duration=self.frame_duration_ms,
                loop=0
            )
        finally:
            # Release the frame files before deleting the temp dir (otherwise the handles
            # leak, and deletion fails on Windows). Clean up even when there was nothing
            # to save or saving failed.
            for frame in frames:
                frame.close()
            self.temp_dir.cleanup()

    def __call__(self, step_idx, loss_val, image_t):
        if (step_idx+1) % self.save_interval == 0:
            img_idx = (step_idx+1) // self.save_interval
            image = tensor_to_image(image_t[0].detach().cpu())
            image.save(self._get_img_path(img_idx))

        if step_idx == (self.n_steps - 1):
            self._save_gif()

    
def unravel_index(indices, shape):
    """Convert flat (row-major) indices into per-dimension coordinates.

    Args:
        indices: 1-D integer tensor of flat indices.
        shape: 1-D integer tensor with the target shape.

    Returns:
        Tensor of shape (len(shape), len(indices)); row d holds the coordinates along dimension d.
    """
    # Row-major strides: product of all later dimension sizes, ending with 1.
    trailing_products = shape[1:].flip(0).cumprod(dim=0).flip(0)
    strides = torch.cat((trailing_products, trailing_products.new_ones(1)), dim=0)
    coords = torch.div(indices.unsqueeze(-1), strides, rounding_mode='trunc') % shape
    return coords.t()
    
    
def discard_grad_(x, bottom_grad_pct=0.2):
    """Zero, in place, the `bottom_grad_pct` fraction of `x.grad` entries with the smallest magnitude.

    Does nothing when `bottom_grad_pct` is 0.
    """
    if bottom_grad_pct == 0:
        return

    with torch.no_grad():
        n_items_zeroed = int(bottom_grad_pct * x.grad.numel())
        # TODO: should relative grad be better? (x.grad / x)
        flat_magnitudes = x.grad.abs().view(-1)
        smallest_flat_idxs = flat_magnitudes.topk(n_items_zeroed, largest=False)[1]
        coords = unravel_index(smallest_flat_idxs, torch.tensor(x.shape, device=x.device))
        x.grad[tuple(coords)] = 0
        

def train(
    n_steps:int, enc_dec:nn.Module, initial_img_t, target_text, negative_texts=None, hp:HyperParameters=None, device=None,
    loss_fn:Callable=None, after_step:List[Callable]=None
):
    """Optimise the latent code of `initial_img_t` so the decoded image matches `target_text`.

    The image is encoded once (without autograd) to get the starting latent. Then, for
    `n_steps` Adam steps, the latent is decoded, scored by `loss_fn`, its gradient is
    optionally sparsified by `discard_grad_`, and the latent is updated. Every callback
    in `after_step` receives ``(step_idx, loss_to_display, img_t)`` after each step.

    Args:
        n_steps: number of optimizer steps.
        enc_dec: pSp network (encoder + StyleGAN 2 decoder); frozen and moved to `device`.
        initial_img_t: normalised (C, H, W) image tensor to start from.
        target_text, negative_texts: used to build the default `ClipLoss`.
        hp: optimiser settings; defaults to `HyperParameters()`.
        device: defaults to `get_default_device()`.
        loss_fn: callable ``(img_t, latent, initial_latent) -> (loss, loss_to_display)``.
        after_step: list of callbacks run after every step.
    """
    hp = HyperParameters() if hp is None else hp
    device = get_default_device() if device is None else device
    if loss_fn is None:
        loss_fn = ClipLoss(target_text, negative_texts=negative_texts, device=device)
    callbacks = [] if after_step is None else after_step

    enc_dec.requires_grad_(False).eval().to(device)

    # Encode the starting image once; its latent is the optimisation's starting point.
    with torch.no_grad():
        batch = initial_img_t.unsqueeze(0).to(device).float()
        _, initial_latent = enc_dec(batch, randomize_noise=False, return_latents=True)

    # Presumably shape (1, 18, 512) for the FFHQ pSp model.
    latent = initial_latent.detach().clone().requires_grad_(True)
    opt = torch.optim.Adam([latent], lr=hp.lr, betas=hp.betas, weight_decay=hp.wd)

    for step_idx in range(n_steps):
        img_t, _ = enc_dec.decoder([latent], randomize_noise=False, input_is_latent=True)

        loss, loss_to_display = loss_fn(img_t, latent, initial_latent)
        loss.backward()
        discard_grad_(latent, hp.low_grad_discarded_pct)
        opt.step()

        for callback in callbacks:
            callback(step_idx, loss_to_display, img_t)

        opt.zero_grad()
        opt.zero_grad()

## (Optional) Install Kaggle and download FFHQ dataset

It's only needed if you wish to use images from FFHQ as starting images, like some examples do.

In [None]:
!pip install kaggle

Upload token `kaggle.json`:

In [None]:
# Colab only: opens a file picker to upload your Kaggle API token (kaggle.json).
from google.colab import files
files.upload()

Copy token to the directory that Kaggle expects:

In [None]:
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

Download the images used in the examples

In [None]:
# Fetch only the three FFHQ images used in the examples into ./ffhq.
!kaggle datasets download arnaud58/flickrfaceshq-dataset-ffhq -f 00001.png -p ffhq --unzip
!kaggle datasets download arnaud58/flickrfaceshq-dataset-ffhq -f 00011.png -p ffhq --unzip
!kaggle datasets download arnaud58/flickrfaceshq-dataset-ffhq -f 00012.png -p ffhq --unzip

To download the full dataset, execute:

In [None]:
#!kaggle datasets download arnaud58/flickrfaceshq-dataset-ffhq --unzip

## Train latent

Set an input image:

In [None]:
!wget https://upload.wikimedia.org/wikipedia/commons/a/a0/Bill_Gates_2018.jpg
initial_image_path = './Bill_Gates_2018.jpg'
initial_image = PIL.Image.open(initial_image_path).crop((70, 40, 630, 600))
print('Size = ', initial_image.size)
initial_image_tensor = img_transforms(initial_image)
initial_image

Align input image (optional, skip if the input image is already centered):

In [None]:
# TODO

Prepare image as Pytorch input tensor:

In [None]:
# (Re)build the encoder input from `initial_image`; expected shape (3, 256, 256).
initial_image_tensor = img_transforms(initial_image)
initial_image_tensor.shape

`train` parameters:
- `n_steps`: number of updates/optimizer steps.
- `enc_dec`: pSp network (encoder + StyleGAN 2 decoder). You should just pass `net` always.
- `initial_image_tensor`: PyTorch tensor of shape (3, 256, 256) normalized with `NORMALIZE_MEAN` and `NORMALIZE_STD`. The conversion from a PIL image can be performed by `img_transforms`, as shown above.
- `target_text`: the text that describes the face you want to get after several updates, starting from `initial_image_tensor`.
- `negative_texts`: list of strings that shouldn't be a good description of the desired target image.
- `hp`: hyperparameters
  - `lr`, `betas`, `wd`: hyperparameters of Adam optimizer.
  - `low_grad_discarded_pct`: the items of the gradient vector whose absolute value is one of the `low_grad_discarded_pct * 100`% smallest are discarded (zeroed) before every optimizer step. 0 means all elements of the latent vector are updated, and 1 means none, which wouldn't make sense as you'd get the original image. It's meant to reduce entanglement but it may not be the best option.
- `device`: specified with PyTorch format.
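- `loss_fn`: loss function, a callable that must return a tuple `(loss, loss_to_display)`, where `loss` is the scalar tensor to minimize and `loss_to_display` is a printable summary (e.g. a formatted string). Defaults to `ClipLoss(target_text, negative_texts)`. Parameters:
  - img_t: image tensor generated from the current `latent`.
  - latent: current latent vector.
  - initial_latent: latent vector that makes StyleGAN generate `initial_image_tensor`.
- `after_step`: list of callables, each invoked after every optimizer step. Parameters:
  - step_idx: index of last step.
  - loss_val: the display value (second element) returned by the loss function at that step.
  - img_t: image tensor generated from the `latent` used in that step (i.e. before the optimizer update).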
  
Other variables to configure:
- `FEATURE_LOSS_W`: weight that multiplies the feature loss. Set to 0 to disable the feature loss. The higher its value, the more the generated images will resemble the initial image.
- `save_gif`: if `True`, a gif of the transformation is stored in `./transformation.gif`

Train:

In [None]:
n_steps = 100
target_text = "Mark Zuckerberg"
negative_texts = []
clip_loss = ClipLoss(target_text, negative_texts=negative_texts, device=device)
FEATURE_LOSS_W = 0  # > 0 adds a feature loss that keeps the result closer to the initial image
save_gif = True
gif_filename = "transformation.gif"

# Loss: CLIP alone, or CLIP plus the weighted feature loss.
if FEATURE_LOSS_W > 0:
    feature_loss = FeatureLoss(initial_image_tensor, device=device)
    loss_fn = ComposedLoss([(clip_loss, 1.), (feature_loss, FEATURE_LOSS_W)])
else:
    loss_fn = clip_loss

# Live preview every 5 steps and, optionally, one GIF frame every 5 steps.
callbacks = [DisplayResultsCallback(n_steps, display_interval=5)]
if save_gif:
    callbacks += [GifRecorderCallback(n_steps, save_interval=5, frame_duration_ms=100, gif_filename=gif_filename)]

train(
    n_steps=n_steps,
    enc_dec=net,
    initial_img_t=initial_image_tensor,
    target_text=target_text,
    negative_texts=negative_texts,
    hp=HyperParameters(lr=1e-2, low_grad_discarded_pct=0.),
    after_step=callbacks,
    loss_fn=loss_fn,
)

Download the gif if you have set `save_gif=True`:

In [None]:
# Link to the GIF written by GifRecorderCallback (it exists only if save_gif was True).
IPython.display.FileLink(gif_filename)

Release memory:

In [None]:
# Drop the reference to the callbacks (and their figure) so they can be freed.
del callbacks

# Examples

## An old woman

In [None]:
# FFHQ image downloaded from Kaggle above.
initial_image_path = 'ffhq/00001.png'
initial_image = PIL.Image.open(initial_image_path)
print('Size = ', initial_image.size)
initial_image_tensor = img_transforms(initial_image)
initial_image

In [None]:
n_steps = 100
# CLIP-only edit of the current initial image, previewed every 5 steps.
train(
    n_steps=n_steps,
    enc_dec=net,
    initial_img_t=initial_image_tensor,
    target_text="An old woman",
    hp=HyperParameters(lr=1e-2, low_grad_discarded_pct=0),
    after_step=[DisplayResultsCallback(n_steps, display_interval=5)],
)

## An old man

In [None]:
n_steps = 100
# CLIP-only edit of the current initial image, previewed every 5 steps.
train(
    n_steps=n_steps,
    enc_dec=net,
    initial_img_t=initial_image_tensor,
    target_text="An old man",
    hp=HyperParameters(lr=1e-2, low_grad_discarded_pct=0),
    after_step=[DisplayResultsCallback(n_steps, display_interval=5)],
)

## Nami, One Piece

In [None]:
n_steps = 100
# CLIP-only edit of the current initial image, previewed every 5 steps.
train(
    n_steps=n_steps,
    enc_dec=net,
    initial_img_t=initial_image_tensor,
    target_text="Nami, One Piece",
    hp=HyperParameters(lr=1e-2, low_grad_discarded_pct=0),
    after_step=[DisplayResultsCallback(n_steps, display_interval=5)],
)

## A pink-haired woman

In [None]:
n_steps = 100
# CLIP-only edit of the current initial image, previewed every 5 steps.
train(
    n_steps=n_steps,
    enc_dec=net,
    initial_img_t=initial_image_tensor,
    target_text="A pink-haired woman",
    hp=HyperParameters(lr=1e-2, low_grad_discarded_pct=0),
    after_step=[DisplayResultsCallback(n_steps, display_interval=5)],
)

## A pink-haired man

In [None]:
n_steps = 100
# CLIP-only edit of the current initial image, previewed every 5 steps.
train(
    n_steps=n_steps,
    enc_dec=net,
    initial_img_t=initial_image_tensor,
    target_text="A pink-haired man",
    hp=HyperParameters(lr=1e-2, low_grad_discarded_pct=0),
    after_step=[DisplayResultsCallback(n_steps, display_interval=5)],
)

## A woman with big eyes

In [None]:
n_steps = 100
# CLIP-only edit of the current initial image, previewed every 5 steps.
train(
    n_steps=n_steps,
    enc_dec=net,
    initial_img_t=initial_image_tensor,
    target_text="A woman with big eyes",
    hp=HyperParameters(lr=1e-2, low_grad_discarded_pct=0),
    after_step=[DisplayResultsCallback(n_steps, display_interval=5)],
)

With feature loss, the result is more similar to the initial image:

In [None]:
n_steps = 100
target_text = "A woman with big eyes"
negative_texts = []
feature_loss = FeatureLoss(initial_image_tensor, device=device)
FEATURE_LOSS_W = 100.
# CLIP pulls towards the prompt; the heavily weighted feature loss keeps the face
# close to the initial image.
train(
    n_steps=n_steps,
    enc_dec=net,
    initial_img_t=initial_image_tensor,
    target_text=target_text,
    negative_texts=negative_texts,
    hp=HyperParameters(lr=1e-2, low_grad_discarded_pct=0),
    after_step=[DisplayResultsCallback(n_steps, display_interval=5)],
    loss_fn=ComposedLoss([
        (ClipLoss(target_text, negative_texts=negative_texts, device=device), 1.),
        (feature_loss, FEATURE_LOSS_W),
    ]),
)

## A woman that didn't apply sun screen

In [None]:
n_steps = 100
# CLIP-only edit of the current initial image, previewed every 5 steps.
train(
    n_steps=n_steps,
    enc_dec=net,
    initial_img_t=initial_image_tensor,
    target_text="A woman that didn't apply sun screen",
    hp=HyperParameters(lr=1e-2, low_grad_discarded_pct=0),
    after_step=[DisplayResultsCallback(n_steps, display_interval=5)],
)

Without blonde hair:

In [None]:
n_steps = 100
# Same prompt, plus a negative text that pushes the image away from "Blond hair".
train(
    n_steps=n_steps,
    enc_dec=net,
    initial_img_t=initial_image_tensor,
    target_text="A woman that didn't apply sun screen",
    negative_texts=["Blond hair"],
    hp=HyperParameters(lr=1e-2, low_grad_discarded_pct=0),
    after_step=[DisplayResultsCallback(n_steps, display_interval=5)],
)

## Rihanna

In [None]:
n_steps = 100
# CLIP-only edit of the current initial image, previewed every 5 steps.
train(
    n_steps=n_steps,
    enc_dec=net,
    initial_img_t=initial_image_tensor,
    target_text="Rihanna",
    hp=HyperParameters(lr=1e-2, low_grad_discarded_pct=0),
    after_step=[DisplayResultsCallback(n_steps, display_interval=5)],
)

## Andrew Ng

In [None]:
# FFHQ image downloaded from Kaggle above.
initial_image_path = 'ffhq/00011.png'
initial_image = PIL.Image.open(initial_image_path)
print('Size = ', initial_image.size)
initial_image_tensor = img_transforms(initial_image)
initial_image

In [None]:
n_steps = 100
# CLIP-only edit of the current initial image, previewed every 5 steps.
train(
    n_steps=n_steps,
    enc_dec=net,
    initial_img_t=initial_image_tensor,
    target_text="Andrew Ng",
    hp=HyperParameters(lr=1e-2, low_grad_discarded_pct=0),
    after_step=[DisplayResultsCallback(n_steps, display_interval=5)],
)

## A Draculesque man

In [None]:
n_steps = 100
# CLIP-only edit of the current initial image, previewed every 5 steps.
train(
    n_steps=n_steps,
    enc_dec=net,
    initial_img_t=initial_image_tensor,
    target_text="A Draculesque man",
    hp=HyperParameters(lr=1e-2, low_grad_discarded_pct=0),
    after_step=[DisplayResultsCallback(n_steps, display_interval=5)],
)

## A white-haired young man

In [None]:
# FFHQ image downloaded from Kaggle above.
initial_image_path = 'ffhq/00012.png'
initial_image = PIL.Image.open(initial_image_path)
print('Size = ', initial_image.size)
initial_image_tensor = img_transforms(initial_image)
initial_image

In [None]:
n_steps = 100
# CLIP-only edit of the current initial image, previewed every 5 steps.
train(
    n_steps=n_steps,
    enc_dec=net,
    initial_img_t=initial_image_tensor,
    target_text="A white-haired young man",
    hp=HyperParameters(lr=1e-2, low_grad_discarded_pct=0),
    after_step=[DisplayResultsCallback(n_steps, display_interval=5)],
)

## Bill Gates -> Elon Musk interpolation

In [None]:
!wget https://upload.wikimedia.org/wikipedia/commons/a/a0/Bill_Gates_2018.jpg
initial_image_path = './Bill_Gates_2018.jpg'
initial_image = PIL.Image.open(initial_image_path).crop((70, 40, 630, 600))
print('Size = ', initial_image.size)
initial_image_tensor = img_transforms(initial_image)
initial_image

In [None]:
n_steps = 100
# CLIP-only edit of the current initial image, previewed every 5 steps.
train(
    n_steps=n_steps,
    enc_dec=net,
    initial_img_t=initial_image_tensor,
    target_text="Elon Musk",
    hp=HyperParameters(lr=1e-2, low_grad_discarded_pct=0),
    after_step=[DisplayResultsCallback(n_steps, display_interval=5)],
)

## Bill Gates -> Mark Zuckerberg interpolation

In [None]:
n_steps = 100
# CLIP-only edit of the current initial image, previewed every 5 steps.
train(
    n_steps=n_steps,
    enc_dec=net,
    initial_img_t=initial_image_tensor,
    target_text="Mark Zuckerberg",
    hp=HyperParameters(lr=1e-2, low_grad_discarded_pct=0),
    after_step=[DisplayResultsCallback(n_steps, display_interval=5)],
)