# **Face Sketch Stylization Using JoJoGAN**

We use JoJoGAN to turn a face photo into a sketch-style portrait by fine-tuning a pretrained StyleGAN2 generator on a few sketch reference images.

**Clone The JoJoGAN Repo And Install Required Libraries**

In [1]:
!git clone https://github.com/mchong6/JoJoGAN.git
%cd JoJoGAN
!pip install tqdm gdown scikit-learn==0.22 scipy lpips dlib opencv-python wandb
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/

Cloning into 'JoJoGAN'...
remote: Enumerating objects: 490, done.[K
remote: Counting objects: 100% (40/40), done.[K
remote: Compressing objects: 100% (9/9), done.[K
remote: Total 490 (delta 31), reused 31 (delta 31), pack-reused 450[K
Receiving objects: 100% (490/490), 63.51 MiB | 11.81 MiB/s, done.
Resolving deltas: 100% (203/203), done.
/content/JoJoGAN
Collecting scikit-learn==0.22
  Downloading scikit-learn-0.22.tar.gz (6.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m49.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
Collecting wandb
  Downloading wandb-0.16.1-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m75.2 MB/s[0m eta [36m0:00:00[0m
Colle

In [3]:
# Install wandb on its own. The recorded output shows it was still missing after the
# first setup cell (presumably that combined pip install did not complete).
!pip install wandb

Collecting wandb
  Using cached wandb-0.16.1-py3-none-any.whl (2.1 MB)
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Using cached GitPython-3.1.40-py3-none-any.whl (190 kB)
Collecting sentry-sdk>=1.0.0 (from wandb)
  Using cached sentry_sdk-1.39.1-py2.py3-none-any.whl (254 kB)
Collecting docker-pycreds>=0.4.0 (from wandb)
  Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting setproctitle (from wandb)
  Using cached setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wandb)
  Using cached gitdb-4.0.11-py3-none-any.whl (62 kB)
Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb)
  Using cached smmap-5.0.1-py3-none-any.whl (24 kB)
Installing collected packages: smmap, setproctitle, sentry-sdk, docker-pycreds, gitdb, GitPython, wandb
Successfully installed GitPython-3.1.40 docker-pycreds-0.4.0 gitdb-

**Import Libraries**

In [4]:
import torch
torch.backends.cudnn.benchmark = True
from torchvision import transforms, utils
from util import *
from PIL import Image
import math
import random
import os
import numpy
from torch import nn, autograd, optim
from torch.nn import functional
from tqdm import tqdm
import wandb
from model import *
from e4e_projection import projection
from google.colab import files
from copy import deepcopy
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials



**Create Storage Folders**

In [6]:
# Working folders: cached inversion codes, raw and aligned style images, checkpoints.
for folder in ('inversion_codes', 'style_images', 'style_images_aligned', 'models'):
    os.makedirs(folder, exist_ok=True)

**Download Settings**

Choose whether model files are downloaded with PyDrive (requires Google authentication) or with gdown, and select the compute device.

In [7]:
# True: fetch checkpoints through PyDrive (needs Google auth); False: use gdown.
download_with_pydrive = True
# Compute device: 'cuda' or 'cpu'.
device = 'cuda'

In [8]:
!wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
!bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2
!mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat
%matplotlib inline

--2024-01-08 07:44:03--  http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
Resolving dlib.net (dlib.net)... 107.180.26.78
Connecting to dlib.net (dlib.net)|107.180.26.78|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 64040097 (61M)
Saving to: ‘shape_predictor_68_face_landmarks.dat.bz2’


2024-01-08 07:44:17 (4.25 MB/s) - ‘shape_predictor_68_face_landmarks.dat.bz2’ saved [64040097/64040097]



**Set Drive IDs**

In [9]:
# Google Drive file IDs of downloadable checkpoints, keyed by the file name the
# checkpoint is saved under in `models/` (looked up by Downloader.download_file).
drive_ids = {
    "stylegan2-ffhq-config-f.pt": "1Yr7KuD959btpmcKGAUsbAk5rPjX2MytK",
    "e4e_ffhq_encode.pt": "1o6ijA3PkcewZvwJJ73dJ0fxhndn0nnh7",
    "restyle_psp_ffhq_encode.pt": "1nbxCIVw9H3YnQsoIPykNEFwWJnHVHlVd",
    "arcane_caitlyn.pt": "1gOsDTiTPcENiFOrhmkkxJcTURykW1dRc",
    "arcane_caitlyn_preserve_color.pt": "1cUTyjU-q98P75a8THCaO545RTwpVV-aH",
    "arcane_jinx_preserve_color.pt": "1jElwHxaYPod5Itdy18izJk49K1nl4ney",
    "arcane_jinx.pt": "1quQ8vPjYpUiXM4k1_KIwP4EccOefPpG_",
    "arcane_multi_preserve_color.pt": "1enJgrC08NpWpx2XGBmLt1laimjpGCyfl",
    "arcane_multi.pt": "15V9s09sgaw-zhKp116VHigf5FowAy43f",
    "sketch_multi.pt": "1GdaeHGBGjBAFsWipTL0y-ssUiAqk8AxD",
    "disney.pt": "1zbE2upakFUAx8ximYnLofFwfT8MilqJA",
    "disney_preserve_color.pt": "1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi",
    "jojo.pt": "13cR2xjIBj8Ga5jMO7gtxzIJj2PDsBYK4",
    "jojo_preserve_color.pt": "1ZRwYLRytCEKi__eT2Zxv1IlV6BGVQ_K2",
    "jojo_yasuho.pt": "1grZT3Gz1DLzFoJchAmoj3LoM9ew9ROX_",
    "jojo_yasuho_preserve_color.pt": "1SKBu1h0iRNyeKBnya_3BBmLr4pkPeg_L",
    "art.pt": "1a0QDEHwXQ6hE_FcYEyNMuv5r5UnRQLKT",
}

**Define Downloader Class**

Define a helper class that downloads model checkpoints from Google Drive into `models/`. This implementation is adapted from StyleGAN-NADA.

In [10]:
class Downloader(object):
    """Download checkpoints listed in `drive_ids` into the `models/` folder.

    Adapted from StyleGAN-NADA. Files are fetched either through PyDrive
    (authenticated Colab session) or through the `gdown` command-line tool.
    """

    def __init__(self, use_pydrive):
        """
        Args:
            use_pydrive: if True, authenticate immediately and download through
                PyDrive; otherwise downloads shell out to `gdown`.
        """
        self.use_pydrive = use_pydrive
        if self.use_pydrive:
            self.authenticate()

    def authenticate(self):
        """Authenticate the Colab user and store a PyDrive client in `self.drive`."""
        auth.authenticate_user()
        gauth = GoogleAuth()
        gauth.credentials = GoogleCredentials.get_application_default()
        self.drive = GoogleDrive(gauth)

    def download_file(self, file_name):
        """Download `file_name` to `models/<file_name>` unless that file already exists.

        Args:
            file_name: a key of `drive_ids`.

        Raises:
            KeyError: if `file_name` has no entry in `drive_ids`.
            subprocess.CalledProcessError: if the `gdown` download fails.
        """
        import subprocess

        file_dst = os.path.join('models', file_name)
        file_id = drive_ids[file_name]
        if os.path.exists(file_dst):
            return
        print(f'Downloading {file_name}')
        if self.use_pydrive:
            downloaded = self.drive.CreateFile({'id': file_id})
            downloaded.FetchMetadata(fetch_all=True)
            downloaded.GetContentFile(file_dst)
        else:
            # Argument list instead of `!gdown --id $file_id -O $file_dst`: no shell
            # interpolation, a failed download raises instead of being silently ignored,
            # and the `uc?id=` URL form works with gdown releases that dropped `--id`.
            subprocess.run(
                ['gdown', f'https://drive.google.com/uc?id={file_id}', '-O', file_dst],
                check=True,
            )

**Download Files**

In [11]:
downloader = Downloader(download_with_pydrive)

# Base StyleGAN2 FFHQ generator and the e4e encoder used for GAN inversion.
for checkpoint_name in ('stylegan2-ffhq-config-f.pt', 'e4e_ffhq_encode.pt'):
    downloader.download_file(checkpoint_name)

Downloading stylegan2-ffhq-config-f.pt
Downloading e4e_ffhq_encode.pt


**Load Generators**

Load the pretrained StyleGAN2 FFHQ generator and create a copy of it that will be fine-tuned.

In [40]:
latent_dim = 512

# Original generator, loaded from the pretrained StyleGAN2 FFHQ checkpoint
# (EMA weights under "g_ema"). Constructor arguments are presumably
# (resolution, latent size, mapping layers, channel multiplier) — see model.py.
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
ckpt = torch.load('models/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
original_generator.load_state_dict(ckpt["g_ema"], strict=False)

# Truncation center used when sampling. It is only ever used as a constant, so
# compute it without autograd; otherwise (unless model.py already disables grad)
# the returned tensor keeps a graph over all 10000 sampled latents alive in memory.
with torch.no_grad():
    mean_latent = original_generator.mean_latent(10000)

# Generator to be fine-tuned; starts as an exact copy of the original.
generator = deepcopy(original_generator)

**Set Image Transform**

Define the transform that resizes images to 1024×1024 and normalizes pixel values to the range [-1, 1].

In [41]:
# Resize to the generator's 1024x1024 output size, convert to a tensor in [0, 1],
# then shift/scale each RGB channel to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

**Set Input Image Location**

In [16]:
# Photo to stylize (Colab form field); it is read from the test_input/ folder.
filename = 'iu.jpeg' #@param {type:"string"}
filepath = 'test_input/' + filename
# Path passed to the projection step for the input face's latent code.
name = strip_path_extension(filepath) + '.pt'

**Align And Crop Face**

In [17]:
# Detect, crop and align the face in the input photo. The result is fed to
# `projection` and `transform` below (presumably a PIL image — verify in util).
aligned_face = align_face(filepath)

**Invert The Face Into The StyleGAN Latent Space (e4e)**

In [18]:
# Invert the aligned face with the e4e encoder (see the loading message below);
# `name` is the path the latent is presumably cached to. unsqueeze(0) adds a
# leading batch dimension for the generators.
my_w = projection(aligned_face, name, device).unsqueeze(0)

Loading e4e over the pSp framework from checkpoint: models/e4e_ffhq_encode.pt


**Select Pre-Trained Sketch Type**

In [21]:
# Pretrained JoJoGAN style to load. Options:
# 'art', 'arcane_multi', 'sketch_multi', 'arcane_jinx', 'arcane_caitlyn', 'jojo_yasuho', 'jojo', 'disney'
pretrained = 'sketch_multi'
plt.rcParams['figure.dpi'] = 150  # higher-resolution inline figures

**Generate Results**

Load the selected pretrained checkpoint into the generator, set a seed, and generate stylized images.

In [23]:
# Fetch the pretrained JoJoGAN checkpoint chosen above and load it into `generator`.
# Without this step `generator` is still an unmodified copy of `original_generator`,
# so the "stylized" outputs would be identical to the original ones.
pretrained_ckpt_file = f'{pretrained}.pt'
downloader.download_file(pretrained_ckpt_file)
pretrained_ckpt = torch.load(os.path.join('models', pretrained_ckpt_file), map_location=lambda storage, loc: storage)
# JoJoGAN's released checkpoints store generator weights under "g" (verify for other files).
generator.load_state_dict(pretrained_ckpt["g"], strict=False)

n_sample = 5  # number of random faces to sample
seed = 3000
torch.manual_seed(seed)
with torch.no_grad():
    generator.eval()
    z = torch.randn(n_sample, latent_dim, device=device)
    # Same random codes rendered by the original and the stylized generator.
    original_sample = original_generator([z], truncation=0.7, truncation_latent=mean_latent)
    sample = generator([z], truncation=0.7, truncation_latent=mean_latent)
    # The inverted input face rendered by both generators.
    original_my_sample = original_generator(my_w, input_is_latent=True)
    my_sample = generator(my_w, input_is_latent=True)



**Display Reference Images**

In [24]:
# Multi-style checkpoints have no reference image of their own; show a representative one.
reference_overrides = {'arcane_multi': 'arcane_jinx', 'sketch_multi': 'sketch'}
style_path = f"style_images_aligned/{reference_overrides.get(pretrained, pretrained)}.png"

# convert('RGB') guards against RGBA/greyscale PNGs, which would not match the
# 3-channel Normalize in `transform` (same as the aligned-image loading elsewhere).
style_image = transform(Image.open(style_path).convert('RGB')).unsqueeze(0).to(device)
face = transform(aligned_face).unsqueeze(0).to(device)

# Style reference | input face | stylized face. Tensors are in [-1, 1];
# display_image presumably comes from the repo's util module (star-imported).
my_output = torch.cat([style_image, face, my_sample], 0)
display_image(utils.make_grid(my_output, normalize=True, value_range=(-1, 1)), title='My sample')

**Train Model With Style Images**

Select the sketch style images. Each one is face-aligned and inverted into a latent code; these images and latents are used to fine-tune the generator.

In [28]:
# Style reference files (in style_images/) used to fine-tune the generator.
names = ['sketch.jpeg', 'sketch2.jpeg', 'sketch3.jpeg']
targets = []
latents = []
# A dedicated loop variable keeps the global `name` (the input face's latent path) intact.
for style_file in names:
    style_path = os.path.join('style_images', style_file)
    assert os.path.exists(style_path), f"{style_path} does not exist!"
    style_name = strip_path_extension(style_file)

    # Crop and align the face once; reuse the cached aligned image on later runs.
    style_aligned_path = os.path.join('style_images_aligned', f'{style_name}.png')
    if not os.path.exists(style_aligned_path):
        style_aligned = align_face(style_path)
        style_aligned.save(style_aligned_path)
    else:
        style_aligned = Image.open(style_aligned_path).convert('RGB')

    # GAN inversion, cached in inversion_codes/.
    style_code_path = os.path.join('inversion_codes', f'{style_name}.pt')
    if not os.path.exists(style_code_path):
        latent = projection(style_aligned, style_code_path, device)
    else:
        latent = torch.load(style_code_path)['latent']

    # The transformed reference image is the training target for this latent;
    # the training loop feeds `targets` to the discriminator.
    targets.append(transform(style_aligned).to(device))
    latents.append(latent.to(device))

targets = torch.stack(targets, 0)
latents = torch.stack(latents, 0)

# Grid of the style references (also logged to wandb when enabled).
target_im = utils.make_grid(targets, normalize=True, value_range=(-1, 1))
display_image(target_im, title='Style References')

**Finetune StyleGAN**

Set `alpha`, which controls the strength of the style (from 0 to 1; higher means stronger stylization).

In [29]:
style_strength = 1.0  # min:0, max:1, step:0.1
# The training loop mixes alpha*latents + (1-alpha)*random codes, so it needs the
# complement of the requested strength.
alpha = 1 - style_strength

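Set `preserve_color` to keep the colors of the original image by limiting the family of allowable transformations. This cell also sets the number of fine-tuning steps and the wandb logging options.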

In [31]:
# Only mix a fixed subset of latent layers during training, to keep the input's colors.
preserve_color = False
# Number of fine-tuning steps.
num_iter = 300
# Log training to wandb, with a stylized sample every `log_interval` steps.
use_wandb = False
log_interval = 50

if use_wandb:
    wandb.init(project="JoJoGAN")
    config = wandb.config
    config.num_iter = num_iter
    config.preserve_color = preserve_color
    # Build the reference grid here instead of relying on an undefined `target_im`.
    style_reference_grid = utils.make_grid(targets, normalize=True, value_range=(-1, 1))
    wandb.log(
        {"Style reference": [wandb.Image(transforms.ToPILImage()(style_reference_grid))]},
        step=0)

Load the StyleGAN2 discriminator; its features are used as a perceptual loss during fine-tuning.

In [32]:
# StyleGAN2 FFHQ discriminator, used only as a fixed feature extractor for the loss.
# It is never optimized, so its parameters are frozen: otherwise every
# loss.backward() in the training loop would also compute and accumulate unused
# gradients for its weights. Gradients still flow through it to the generator.
discriminator = Discriminator(1024, 2).eval().requires_grad_(False).to(device)
ckpt = torch.load('models/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
discriminator.load_state_dict(ckpt["d"], strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['final_conv.0.weight', 'final_conv.1.bias', 'final_linear.0.weight', 'final_linear.0.bias', 'final_linear.1.weight', 'final_linear.1.bias'])

**Reset Generator**

In [33]:
# Start fine-tuning from a fresh copy of the original FFHQ generator.
del generator
generator = deepcopy(original_generator)
g_optim = optim.Adam(
    generator.parameters(),
    lr=2e-3,
    betas=(0, 0.99),
)

**Train Generator**

Fine-tune the generator so that images generated from the style latents match the style references, using an L1 loss on discriminator features.

In [37]:
# Latent layers mixed with random codes during training. With preserve_color only
# a fixed subset is mixed; otherwise layers 7 and up (as in the upstream JoJoGAN
# notebook). Previously the whole loop sat under `if preserve_color:`, so with
# preserve_color=False no training happened at all.
if preserve_color:
    id_swap = [9, 11, 15, 16, 17]
else:
    id_swap = list(range(7, generator.n_latent))

# Targets and discriminator never change, so the target features are computed once.
with torch.no_grad():
    real_feat = discriminator(targets)

for idx in tqdm(range(num_iter)):
    # One random code per style image, repeated across all latent layers.
    mean_w = generator.get_latent(
        torch.randn([latents.size(0), latent_dim], device=device)
    ).unsqueeze(1).repeat(1, generator.n_latent, 1)
    in_latent = latents.clone()
    in_latent[:, id_swap] = alpha * latents[:, id_swap] + (1 - alpha) * mean_w[:, id_swap]

    img = generator(in_latent, input_is_latent=True)
    fake_feat = discriminator(img)
    # Mean L1 distance between discriminator features of generated and target images.
    loss = sum(functional.l1_loss(a, b) for a, b in zip(fake_feat, real_feat)) / len(fake_feat)

    if use_wandb:
        wandb.log({"loss": loss.item()}, step=idx)
        if idx % log_interval == 0:
            generator.eval()
            with torch.no_grad():
                preview = generator(my_w, input_is_latent=True)
            generator.train()
            preview = transforms.ToPILImage()(utils.make_grid(preview, normalize=True, value_range=(-1, 1)))
            wandb.log(
                {"Current stylization": [wandb.Image(preview)]},
                step=idx)

    g_optim.zero_grad()
    loss.backward()
    g_optim.step()

**Generate JoJoGAN Results**

In [38]:
n_sample = 5
seed = 3000
torch.manual_seed(seed)

generator.eval()
with torch.no_grad():
    # Same random codes rendered by the original and the fine-tuned generator.
    z = torch.randn(n_sample, latent_dim, device=device)
    truncation_kwargs = dict(truncation=0.7, truncation_latent=mean_latent)
    original_sample = original_generator([z], **truncation_kwargs)
    sample = generator([z], **truncation_kwargs)
    # The inverted input face rendered by both generators.
    original_my_sample = original_generator(my_w, input_is_latent=True)
    my_sample = generator(my_w, input_is_latent=True)



**Display Reference Images**

In [39]:
# Aligned style references used for fine-tuning. convert('RGB') guards against
# non-RGB PNGs that would not match the 3-channel Normalize in `transform`.
style_images = []
for style_file in names:
    style_path = f'style_images_aligned/{strip_path_extension(style_file)}.png'
    style_images.append(transform(Image.open(style_path).convert('RGB')))
style_images = torch.stack(style_images, 0).to(device)
# display_image presumably comes from the repo's util module (star-imported).
display_image(utils.make_grid(style_images, normalize=True, value_range=(-1, 1)), title='References')

# Input face next to its stylized version.
face = transform(aligned_face).to(device).unsqueeze(0)
my_output = torch.cat([face, my_sample], 0)
display_image(utils.make_grid(my_output, normalize=True, value_range=(-1, 1)), title='My sample')

# Random samples: original generator (top row) vs fine-tuned generator (bottom row).
output = torch.cat([original_sample, sample], 0)
display_image(utils.make_grid(output, normalize=True, value_range=(-1, 1), nrow=n_sample), title='Random samples')