論文  
https://arxiv.org/abs/2203.13248<br>
GitHub<br>
https://github.com/williamyang1991/DualStyleGAN<br>
<br>
<a href="https://colab.research.google.com/github/kaz12tech/ai_demos/blob/master/DualStyleGAN_demo.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 環境セットアップ

## GPU確認

In [None]:
# Print the NVIDIA driver/GPU status to confirm that the runtime has a GPU attached.
!nvidia-smi

## GitHubからコード取得

In [None]:
# Work from /content so the repository is cloned to /content/DualStyleGAN,
# the path the later cells rely on.
%cd /content

# Clone the DualStyleGAN repository.
!git clone https://github.com/williamyang1991/DualStyleGAN.git

## ライブラリのインストール

In [None]:
%cd /content

# ninja: download the v1.8.2 Linux binary, unpack it into /usr/local/bin and
# register it as /usr/bin/ninja through update-alternatives.
# NOTE(review): presumably required to build the repository's custom ops — confirm.
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force 

# wget and gdown are imported by later cells; gdown is upgraded without the pip cache.
# NOTE(review): faiss-cpu is not imported directly in this notebook — presumably a
# dependency of the repository's sampler code; confirm.
!pip install faiss-cpu
!pip install wget
!pip install --upgrade --no-cache-dir gdown

## ライブラリのインポート

In [None]:
# Run from the repository root so its modules (util, model.*) can be imported.
%cd /content/DualStyleGAN

# Automatically reload modules before executing each cell.
%load_ext autoreload
%autoreload 2

import sys
# Make the current directory and its parent importable.
sys.path.append(".")
sys.path.append("..")

import numpy as np
import torch
from util import save_image, load_image, visualize
import argparse
from argparse import Namespace
from torchvision import transforms
from torch.nn import functional as F
import torchvision
import matplotlib.pyplot as plt
from model.dualstylegan import DualStyleGAN
from model.sampler.icp import ICPTrainer
from model.encoder.psp import pSp
from model.encoder.align_all_parallel import align_face

import os
import gdown
import wget
import bz2
import dlib

# Run on the first CUDA device when one is available, otherwise on the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Directories inside the cloned repository: pretrained checkpoints and style data.
MODEL_DIR = '/content/DualStyleGAN/checkpoint'
DATA_DIR = '/content/DualStyleGAN/data'

# Style選択

In [None]:
# Style domain to use. The trailing `#@param` annotation renders a Colab form dropdown.
style_type = 'caricature' #@param ['cartoon', 'caricature', 'anime']
# Other styles that have checkpoint entries in MODEL_PATHS:
# 'arcane', 'comic', 'pixar', 'slamdunk'

# Per-style checkpoint directory; exist_ok makes re-running this cell safe.
os.makedirs(os.path.join(MODEL_DIR, style_type), exist_ok=True)

# 学習済みモデルのセットアップ

In [None]:
# Google Drive file IDs and local file names of the pretrained checkpoints.
# Keys are "encoder" plus, for every style, "<style>-G" (generator),
# "<style>-N" (sampler) and "<style>-S" (extrinsic style codes).
_REFINED_CODES = "refined_exstyle_code.npy"
_PLAIN_CODES = "exstyle_code.npy"

# style -> (generator id, sampler id, style-code id, style-code file name)
_STYLE_CHECKPOINTS = {
    "cartoon": (
        "1exS9cSFkg8J4keKPmq2zYQYfJYC5FkwL",
        "1JSCdO0hx8Z5mi5Q5hI9HMFhLQKykFX5N",
        "1ce9v69JyW_Dtf7NhbOkfpH77bS_RK0vB",
        _REFINED_CODES,
    ),
    "caricature": (
        "1BXfTiMlvow7LR7w8w0cNfqIl-q2z0Hgc",
        "1eJSoaGD7X0VbHS47YLehZayhWDSZ4L2Q",
        "1-p1FMRzP_msqkjndRK_0JasTdwQKDsov",
        _REFINED_CODES,
    ),
    "anime": (
        "1BToWH-9kEZIx2r5yFkbjoMw0642usI6y",
        "19rLqx_s_SUdiROGnF_C6_uOiINiNZ7g2",
        "17-f7KtrgaQcnZysAftPogeBwz5nOWYuM",
        _REFINED_CODES,
    ),
    "arcane": (
        "15l2O7NOUAKXikZ96XpD-4khtbRtEAg-Q",
        "1fa7p9ZtzV8wcasPqCYWMVFpb4BatwQHg",
        "1z3Nfbir5rN4CrzatfcgQ8u-x4V44QCn1",
        _PLAIN_CODES,
    ),
    "comic": (
        "1_t8lf9lTJLnLXrzhm7kPTSuNDdiZnyqE",
        "1RXrJPodIn7lCzdb5BFc03kKqHEazaJ-S",
        "1ZfQ5quFqijvK3hO6f-YDYJMqd-UuQtU-",
        _PLAIN_CODES,
    ),
    "pixar": (
        "1TgH7WojxiJXQfnCroSRYc7BgxvYH9i81",
        "18e5AoQ8js4iuck7VgI3hM_caCX5lXlH_",
        "1I9mRTX2QnadSDDJIYM_ntyLrXjZoN7L-",
        _PLAIN_CODES,
    ),
    "slamdunk": (
        "1MGGxSCtyf9399squ3l8bl0hXkf5YWYNz",
        "1-_L7YVb48sLr_kPpOcn4dUq7Cv08WQuG",
        "1Dgh11ZeXS2XIV2eJZAExWMjogxi_m_C8",
        _PLAIN_CODES,
    ),
}

# The encoder entry comes first, followed by G/N/S for each style in table order.
MODEL_PATHS = {
    "encoder": {"id": "1NgI4mPkboYvYw3MWcdUaQhkr0OWgs9ej", "name": "encoder.pt"},
}
for _style, (_g_id, _n_id, _s_id, _s_name) in _STYLE_CHECKPOINTS.items():
    MODEL_PATHS[_style + "-G"] = {"id": _g_id, "name": "generator.pt"}
    MODEL_PATHS[_style + "-N"] = {"id": _n_id, "name": "sampler.pt"}
    MODEL_PATHS[_style + "-S"] = {"id": _s_id, "name": _s_name}

In [None]:
def get_download_model_command(file_id, file_name):
    """Download a Google Drive file into MODEL_DIR unless it already exists.

    Despite its name, this function does not build a command; it performs the
    download immediately through gdown.

    Args:
        file_id: Google Drive file ID, appended to the ``uc?id=`` download URL.
        file_name: Destination path relative to MODEL_DIR.

    Returns:
        None.
    """
    target_path = os.path.join(MODEL_DIR, file_name)
    if os.path.exists(target_path):
        # Already downloaded by an earlier run; nothing to do.
        return
    url = 'https://drive.google.com/uc?id=' + file_id
    gdown.download(url, target_path, quiet=False)

In [None]:
# Shared pSp encoder checkpoint, stored directly under MODEL_DIR.
encoder_entry = MODEL_PATHS["encoder"]
get_download_model_command(encoder_entry["id"], encoder_entry["name"])

# Checkpoints of the selected style, stored under MODEL_DIR/<style_type>/:
#   -G: DualStyleGAN generator, -N: sampler, -S: extrinsic style codes
for suffix in ("-G", "-N", "-S"):
    entry = MODEL_PATHS[style_type + suffix]
    get_download_model_command(entry["id"], os.path.join(style_type, entry["name"]))

# モデルのロード

## Preprocess

In [None]:
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

## load models

In [None]:
# Load DualStyleGAN
# NOTE(review): the positional arguments presumably are output size, style dimension,
# mapping-network depth and channel multiplier — confirm against model/dualstylegan.py.
generator = DualStyleGAN(1024, 512, 8, 2, res_index=6)
generator.eval()
# The map_location callable returns each storage unchanged, so the weights are
# deserialized on the CPU and only moved to `device` by the .to() call below.
ckpt = torch.load(os.path.join(MODEL_DIR, style_type, 'generator.pt'), map_location=lambda storage, loc: storage)
# Only the "g_ema" entry of the checkpoint is used.
generator.load_state_dict(ckpt["g_ema"])
generator = generator.to(device)

In [None]:
# Load the pSp encoder
model_path = os.path.join(MODEL_DIR, 'encoder.pt')
ckpt = torch.load(model_path, map_location='cpu')
# Reuse the options stored in the checkpoint, pointing checkpoint_path at the
# downloaded file and overriding the device.
opts = ckpt['opts']
opts['checkpoint_path'] = model_path
opts = Namespace(**opts)
opts.device = device
# NOTE(review): pSp presumably loads its weights from opts.checkpoint_path itself — confirm.
encoder = pSp(opts)
encoder.eval()
encoder = encoder.to(device)

In [None]:
# Load the extrinsic style codes.
# The .npy file holds a pickled dict (later indexed by style image file name),
# so pickling must be enabled; .item() unwraps the 0-d object array into the dict.
# allow_pickle is a boolean flag: the former string 'TRUE' only worked because
# any non-empty string is truthy.
# NOTE: unpickling executes code embedded in the file — only load checkpoints
# from a trusted source.
exstyles = np.load(os.path.join(MODEL_DIR, style_type, MODEL_PATHS[style_type+'-S']["name"]), allow_pickle=True).item()

In [None]:
# Load the sampler networks
# Two ICPTrainer instances are built on empty (0-row) arrays: one with 512*11
# features for the color codes and one with 512*7 features for the structure codes.
# NOTE(review): the second argument (128) is presumably the sampler's latent size — confirm.
icptc = ICPTrainer(np.empty([0,512*11]), 128)
icpts = ICPTrainer(np.empty([0,512*7]), 128)
# Deserialize on the CPU first (the map_location callable returns the storage unchanged).
ckpt = torch.load(os.path.join(MODEL_DIR, style_type, 'sampler.pt'), map_location=lambda storage, loc: storage)
# The checkpoint stores separate state dicts under 'color' and 'structure'.
icptc.icp.netT.load_state_dict(ckpt['color'])
icpts.icp.netT.load_state_dict(ckpt['structure'])
icptc.icp.netT = icptc.icp.netT.to(device)
icpts.icp.netT = icpts.icp.netT.to(device)

print('Model successfully loaded!')

# 画像のセットアップ
[使用画像1](https://www.pakutaso.com/20210224036post-33401.html)<br>
[使用画像2](https://www.pakutaso.com/20160130026post-6693.html)


In [None]:
%cd /content/DualStyleGAN
# Start from empty input/output directories on every run.
!rm -rf images output_images
!mkdir images output_images

# Alternative test image (swap the commented lines with the active ones to use it):
# !wget -c https://www.pakutaso.com/shared/img/thumb/soraPAR59476_TP_V.jpg \
#       -O ./images/test1.jpg
# Download the content image that the following cells read from ./images/test1.jpg.
!wget -c https://www.pakutaso.com/shared/img/thumb/max16011524_TP_V.jpg \
      -O ./images/test1.jpg

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Content image downloaded by the image setup cell.
image_path = './images/test1.jpg'
# load_image comes from the repository's util module; its result is indexed with [0]
# below, so it presumably returns a batched tensor — TODO confirm.
original_image = load_image(image_path)

# Preview the unaligned input image.
plt.figure(figsize=(10,10),dpi=30)
visualize(original_image[0])
plt.show()

## Align face

In [None]:
def run_alignment(image_path):
    """Align the face found in an image using dlib's 68-point landmark predictor.

    On first use the predictor model is downloaded from dlib.net as a .bz2
    archive and decompressed into MODEL_DIR; later calls reuse the file.

    Args:
        image_path: Path of the image, forwarded to ``align_face`` as ``filepath``.

    Returns:
        The result of ``align_face`` (the caller feeds it to the torchvision
        preprocessing pipeline, so presumably a PIL image — confirm).
    """
    modelname = os.path.join(MODEL_DIR, 'shape_predictor_68_face_landmarks.dat')
    if not os.path.exists(modelname):
        archive_path = modelname + '.bz2'
        wget.download('http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2', archive_path)
        # Decompress fully before creating the target file, so a corrupt archive
        # cannot leave behind an empty model file that later runs would accept.
        with bz2.BZ2File(archive_path) as archive:
            data = archive.read()
        # Context managers guarantee both handles are closed (and the data
        # flushed to disk) before dlib opens the model.
        with open(modelname, 'wb') as model_file:
            model_file.write(data)
    predictor = dlib.shape_predictor(modelname)
    aligned_image = align_face(filepath=image_path, predictor=predictor)
    return aligned_image

In [None]:
# Align the face, run the preprocessing pipeline and add a leading batch
# dimension before moving the tensor to the target device.
aligned_face = run_alignment(image_path)
I = transform(aligned_face).unsqueeze(dim=0).to(device)

# Preview the aligned, preprocessed input.
plt.figure(figsize=(10, 10), dpi=30)
visualize(I[0].cpu())
plt.show()

# スタイル転送

## style_id指定
[こちら](https://github.com/williamyang1991/DualStyleGAN/#1-dataset-preparation)からtrain画像を取得していない場合は、repositoryにデフォルトで格納されている数枚の画像から選択してください。

## styleimageのロード

In [None]:
# Default style image for each style type that has sample train images in the
# repository. Alternatives that can be swapped in:
#   anime:      /content/DualStyleGAN/data/anime/images/train/16031200.jpg
#   caricature: /content/DualStyleGAN/data/caricature/images/train/Liv_Tyler_C00009.jpg
#   cartoon:    /content/DualStyleGAN/data/cartoon/images/train/Cartoons_00038_07.jpg
#   cartoon:    /content/DualStyleGAN/data/cartoon/images/train/Cartoons_00167_01.jpg
_default_style_images = {
    "anime": "/content/DualStyleGAN/data/anime/images/train/23075800.jpg",
    "caricature": "/content/DualStyleGAN/data/caricature/images/train/Hillary_Clinton_C00034.jpg",
    "cartoon": "/content/DualStyleGAN/data/cartoon/images/train/Cartoons_00003_01.jpg",
}

if style_type not in _default_style_images:
    # No bundled sample for this style: show the available style-code keys and stop.
    print(exstyles.keys())
    raise Exception("Please download train images.")

stylepath = _default_style_images[style_type]
# The file name is the key used to look up the extrinsic style code.
stylename = os.path.basename(stylepath)

In [None]:
# Load and preview the style image
print('loading %s'%stylepath)
if not os.path.exists(stylepath):
    # Only report the problem; the style code is looked up by file name later.
    print('%s is not found'%stylename)
else:
    S = load_image(stylepath)
    plt.figure(figsize=(10, 10), dpi=30)
    visualize(S[0])
    plt.show()

## style転送

In [None]:
with torch.no_grad():
    # Encode the aligned face with pSp: img_rec is the encoder's reconstruction,
    # instyle the returned latent code (Z+ latent requested).
    img_rec, instyle = encoder(I, randomize_noise=False, return_latents=True, 
                            z_plus_latent=True, return_z_plus_latent=True, resize=False)    
    img_rec = torch.clamp(img_rec.detach(), -1, 1)
    
    # Two copies of the extrinsic style code of the selected style image.
    latent = torch.tensor(exstyles[stylename]).repeat(2,1,1).to(device)
    # latent[0] for both color and structure transfer and latent[1] for only structure transfer
    # Rows 7..17 of the second copy (11 rows, matching the 11 color weights used
    # below) are overwritten with the content image's own code.
    latent[1,7:18] = instyle[0,7:18]
    # Pass every row through generator.generator.style: flatten to 2-D, map, restore the shape.
    exstyle = generator.generator.style(latent.reshape(latent.shape[0]*latent.shape[1], latent.shape[2])).reshape(latent.shape)
    
    # Batch of two results (one per latent copy). interp_weights holds
    # 7 structure weights (0.6) followed by 11 color weights (1).
    img_gen, _ = generator([instyle.repeat(2,1,1)], exstyle, z_plus_latent=True, 
                           truncation=0.7, truncation_latent=0, use_res=True, interp_weights=[0.6]*7+[1]*11)
    img_gen = torch.clamp(img_gen.detach(), -1, 1)
    # deactivate color-related layers by setting w_c = 0
    img_gen2, _ = generator([instyle], exstyle[0:1], z_plus_latent=True, 
                            truncation=0.7, truncation_latent=0, use_res=True, interp_weights=[0.6]*7+[0]*11)
    img_gen2 = torch.clamp(img_gen2.detach(), -1, 1)

## 結果の表示
左から


1.   pSpで再構成したコンテンツ画像
2.   colorとstructureをスタイル転送した画像
3.   スタイルコードの色に関する部分をコンテンツ画像のものに置き換え、コンテンツ画像の色を再現したスタイル転送画像
4.   色関連のレイヤーを非アクティブにすることによりコンテンツ画像の色を保持したスタイル転送画像

In [None]:
# Stack the reconstruction, the two stylized variants and the color-preserving
# result, shrink them to 256px and tile them four per row with 1px padding.
grid_batch = torch.cat([img_rec, img_gen, img_gen2], dim=0)
grid_batch = F.adaptive_avg_pool2d(grid_batch, 256)
vis = torchvision.utils.make_grid(grid_batch, nrow=4, padding=1)
plt.figure(figsize=(10, 10), dpi=120)
visualize(vis.cpu())
plt.show()

# weightを調整したスタイル転送

In [None]:
# Recreate the output directory so frames from a previous run cannot end up in the video.
!rm -rf "/content/DualStyleGAN/output_images"
!mkdir "/content/DualStyleGAN/output_images"

In [None]:
# Sweep the style weights from 0 up to (num-1)/num, generating one image per
# step. Every frame is saved as result_XXXXXX.jpg for the video cell and also
# collected into an s_root x s_root overview grid.
output_dir = "/content/DualStyleGAN/output_images"
results = []
s_root = 12
num = s_root*s_root

# Inference only: without no_grad each of the `num` forward passes would record
# an autograd graph, wasting GPU memory and time.
with torch.no_grad():
    for i in range(num):
        strength = i/num
        structure_w = [strength]*7  # weights of the 7 structure codes
        color_w = [strength]*11  # weights of the 11 color codes

        w = structure_w + color_w
        img_gen, _ = generator(
            [instyle], exstyle[0:1], z_plus_latent=True,
            truncation=0.7, truncation_latent=0, use_res=True, interp_weights=w)
        # Downsample to 512px and clamp to the [-1, 1] range.
        img_gen = torch.clamp(F.adaptive_avg_pool2d(img_gen, 512), -1, 1)
        results.append(img_gen)

        # Save the frame: map [-1, 1] to [0, 255] and convert CHW to HWC for imsave.
        sv_img = torchvision.utils.make_grid(img_gen, 1, 1)
        sv_img = ((sv_img.cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
        filename = os.path.join(output_dir, f"result_{i:06}.jpg")
        plt.imsave(filename, sv_img)

vis = torchvision.utils.make_grid(torch.cat(results, dim=0), s_root, 1)
plt.figure(figsize=(10,10),dpi=120)
visualize(vis.cpu())
plt.show()

In [None]:
# Encode the saved frames (result_000000.jpg, result_000001.jpg, ...) into an
# H.264 MP4, converting to the yuv420p pixel format.
!ffmpeg -i "/content/DualStyleGAN/output_images/result_%06d.jpg" -c:v libx264 -vf "format=yuv420p" "/content/DualStyleGAN/output_images/result.mp4"

In [None]:
# NOTE(review): wildcard import; this cell only uses VideoFileClip from it.
from moviepy.editor import *
from moviepy.video.fx.resize import resize
# Load the encoded video, scale it to a height of 420px and embed it in the notebook output.
clip = VideoFileClip("/content/DualStyleGAN/output_images/result.mp4")
clip = resize(clip, height=420)
clip.ipython_display()