<a href="https://colab.research.google.com/github/kenken0830/Graphing/blob/main/e4e_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# セットアップ

In [1]:
# --- セットアップ ---

import os
os.chdir('/content')
CODE_DIR = 'encoder4editing'

!git clone https://github.com/cedro3/encoder4editing.git $CODE_DIR
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force
os.chdir(f'./{CODE_DIR}')

# Standard-library and third-party modules imported for the rest of the notebook.
from argparse import Namespace
import time
import os
import sys
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms

# Add the current directory and its parent to the module search path;
# presumably this is what lets the `utils` and `models` imports below resolve
# against the cloned repository (confirm if the directory layout changes).
sys.path.append(".")
sys.path.append("..")

from utils.common import tensor2im
from models.psp import pSp  # we use the pSp framework to load the e4e encoder.

# IPython autoreload extension in mode 2: modules are reloaded automatically
# before code is executed.
%load_ext autoreload
%autoreload 2

# 学習済みパラメータのダウンロード
! pip install --upgrade gdown
import os
import gdown
os.makedirs('pretrained_models', exist_ok=True)
gdown.download('https://drive.google.com/u/1/uc?id=1Du_8FzOPKJhk6aJmiOBhAWVe3_6vAyET', 'pretrained_models/e4e_ffhq_encode.pt', quiet=False)

# ランドマークデータのダウンロード
! wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
! bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2

# Load the pretrained parameters into the model
model_path = 'pretrained_models/e4e_ffhq_encode.pt'
ckpt = torch.load(model_path, map_location='cpu')
# The options stored in the checkpoint are reused, with the checkpoint path
# replaced by the location of the file that was just loaded.
ckpt['opts']['checkpoint_path'] = model_path
opts = Namespace(**ckpt['opts'])
net = pSp(opts)
# Move the network to the GPU and switch it to evaluation mode
net.cuda()
net.eval()
print('Model successfully loaded!')

Cloning into 'encoder4editing'...
remote: Enumerating objects: 233, done.[K
remote: Total 233 (delta 0), reused 0 (delta 0), pack-reused 233[K
Receiving objects: 100% (233/233), 35.00 MiB | 26.41 MiB/s, done.
Resolving deltas: 100% (78/78), done.
--2022-10-09 11:36:41--  https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
Resolving github.com (github.com)... 140.82.112.4
Connecting to github.com (github.com)|140.82.112.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/1335132/d2f252e2-9801-11e7-9fbf-bc7b4e4b5c83?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20221009%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20221009T113641Z&X-Amz-Expires=300&X-Amz-Signature=0da4a4cd4bbbec589ef813414fe04933949e930f21b645a70ef1c527101f2669&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=1335132&response-content-disposition=attachment%3B%2

Downloading...
From: https://drive.google.com/u/1/uc?id=1Du_8FzOPKJhk6aJmiOBhAWVe3_6vAyET
To: /content/encoder4editing/pretrained_models/e4e_ffhq_encode.pt
100%|██████████| 1.20G/1.20G [00:09<00:00, 128MB/s]


--2022-10-09 11:38:11--  http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
Resolving dlib.net (dlib.net)... 107.180.26.78
Connecting to dlib.net (dlib.net)|107.180.26.78|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 64040097 (61M)
Saving to: ‘shape_predictor_68_face_landmarks.dat.bz2’


2022-10-09 11:38:12 (75.0 MB/s) - ‘shape_predictor_68_face_landmarks.dat.bz2’ saved [64040097/64040097]

Loading e4e over the pSp framework from checkpoint: pretrained_models/e4e_ffhq_encode.pt
Model successfully loaded!


# 事前準備


In [7]:
# --- 顔画像の切り出し ---

import os
import shutil
from tqdm import tqdm

# Recreate the align folder from scratch so files from an earlier run are gone
align_dir = 'align'
if os.path.isdir(align_dir):
    shutil.rmtree(align_dir)
os.makedirs(align_dir, exist_ok=True)

def run_alignment(image_path, predictor=None):
    """Align the face in ``image_path`` using dlib landmarks.

    Args:
        image_path: Path of the image file passed to ``align_face``.
        predictor: Optional, already constructed ``dlib.shape_predictor``.
            When omitted, a predictor is built from
            ``shape_predictor_68_face_landmarks.dat`` on the first call and
            the same instance is reused on later calls.

    Returns:
        The value returned by ``utils.alignment.align_face`` for this image.
    """
    # Imported inside the function, as before, so that defining it does not
    # require dlib or the repository modules to be importable yet.
    import dlib
    from utils.alignment import align_face

    if predictor is None:
        # Building the predictor reads the landmark model from disk; doing it
        # once and keeping it on the function object avoids repeating that
        # work for every image.
        predictor = getattr(run_alignment, '_predictor', None)
        if predictor is None:
            predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
            run_alignment._predictor = predictor
    return align_face(filepath=image_path, predictor=predictor)

# Align every image in ./images and store the result under the same file
# name in ./align.
path = './images'
files = sorted(os.listdir(path))
for file in tqdm(files):
    # The notebook environment can leave a checkpoint folder among the images
    if file == '.ipynb_checkpoints':
        continue
    input_image = run_alignment(path + '/' + file)
    # resize() hands back a new image instead of changing the one it is
    # called on, so the result has to be assigned for the resize to apply.
    input_image = input_image.resize((256, 256))
    input_image.save('./align/' + file)

100%|██████████| 23/23 [00:27<00:00,  1.18s/it]


In [8]:
# --- 潜在変数の推定 ---

# Empty and recreate both output folders used by this cell
for out_dir in ('vec_pic', 'vec'):
    if os.path.isdir(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir, exist_ok=True)

# Preprocessing applied to each aligned image: resize to 256x256, convert to
# a tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
img_transforms = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

path = './align'
files = sorted(os.listdir(path))
for file in tqdm(files):
    if file == '.ipynb_checkpoints':
        continue
    # Normalize above is defined for three channels, so force RGB (a
    # grayscale or RGBA file would otherwise fail). The context manager
    # closes the file once the pixels have been copied by convert().
    with Image.open(path + '/' + file) as opened:
        input_image = opened.convert('RGB')
    transformed_image = img_transforms(input_image)
    with torch.no_grad():
        images, latents = net(transformed_image.unsqueeze(0).to('cuda').float(), randomize_noise=False, return_latents=True)
        # No global named `latent` is created here: the settings form uses
        # that name for the latent file name, and overwriting it with a
        # tensor breaks the generation cell when cells are re-run.
        result_image = images[0]
        tensor2im(result_image).save('./vec_pic/' + file)  # save reconstruction to vec_pic
        # splitext removes the extension whatever its length (.jpg, .jpeg, ...)
        torch.save(latents, './vec/' + os.path.splitext(file)[0] + '.pt')  # save latents to vec

100%|██████████| 22/22 [00:07<00:00,  2.75it/s]


In [9]:
# --- 元画像と生成画像の表示 ---

import matplotlib.pyplot as plt
from PIL import Image
import os
def display_pic(folder):
    """Show every image file in ``folder`` on a single figure, ten per row.

    Args:
        folder: Directory whose files are displayed in sorted name order.
            Each panel is labelled with ``folder/<file name>``.
    """
    columns = 10
    # Only regular files are shown; a sub-directory such as
    # .ipynb_checkpoints cannot be opened as an image.
    files = sorted(
        name for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name))
    )
    # Up to 100 images use the 10x10 grid on a 30x40 figure. Beyond that, rows
    # are added (4 inches each) because a subplot index larger than
    # rows * columns is rejected.
    rows = max(10, (len(files) + columns - 1) // columns)
    fig = plt.figure(figsize=(30, 4 * rows))
    for i, file in enumerate(files):
        # Copy the pixels while the file is open, then release the handle
        with Image.open(folder+'/'+file) as img:
            pixels = np.array(img)
        ax = fig.add_subplot(rows, columns, i+1, xticks=[], yticks=[])
        ax.imshow(pixels)
        ax.set_xlabel(folder+'/'+file, fontsize=15)
    plt.show()
    plt.close(fig)

# Show the contents of the align folder, then of the vec_pic folder
for gallery in ('align', 'vec_pic'):
    display_pic(gallery)

# 画像編集

In [102]:
#@title 設定
# File name of the saved latent to edit; the generation cell loads it from the vec folder.
latent = "oonishi1.pt"#@param {type:"string"}
# Key used to look up the InterFaceGAN direction file in the generation cell.
direction = "age+pose" #@param ["age", "pose", "smile", "age+pose"] {allow-input: true}
# Lower and upper bound of the edit factor sweep used by the generation cell.
# NOTE(review): `min` and `max` rebind Python's built-in functions of the same
# name for the rest of the session. Renaming them also requires updating every
# cell that reads them.
min = -50 #@param {type:"slider", min:-50, max:0, step:10}
max = 50 #@param {type:"slider", min:0, max:50, step:10}

In [103]:
# --- Generate still frames ---
# Edits the selected latent along the chosen direction and writes one
# 512x512 JPEG per edit factor into ./pic, numbered consecutively.

import os
import shutil
# Start from an empty pic folder so frames of an earlier run are not mixed in
if os.path.isdir('pic'):
    shutil.rmtree('pic')
os.makedirs('pic', exist_ok=True)

from editings import latent_editor
from tqdm import trange

folder = 'vec'
latents = torch.load(folder+'/'+latent)
editor = latent_editor.LatentEditor(net.decoder, False)

interfacegan_directions = {
        'age': 'editings/interfacegan_directions/age.pt',
        'smile': 'editings/interfacegan_directions/smile.pt',
        'pose': 'editings/interfacegan_directions/pose.pt',
        'age+pose':  'editings/interfacegan_directions/age+pose.pt'
    }

# The settings form accepts free text for `direction`, so report an unknown
# name with the list of valid choices instead of a bare KeyError.
if direction not in interfacegan_directions:
    raise ValueError(
        f"Unknown direction {direction!r}; expected one of {sorted(interfacegan_directions)}"
    )
interfacegan_direction = torch.load(interfacegan_directions[direction]).cuda()
cnt = 0

# Factor sweeps, in playback order. Each entry holds the arguments for trange
# and the progress-bar label; `min` and `max` come from the settings form.
sweeps = [
    ((0, min, -1), '0 -> min'),
    ((min, max), 'min -> max'),
    ((max, 0, -1), 'max -> 0'),
]
for bounds, label in sweeps:
    for factor in trange(*bounds, desc=label):
        result = editor.apply_interfacegan(latents, interfacegan_direction, factor=factor).resize((512,512))
        # Zero-padded six-digit frame number, matching the pattern ffmpeg reads
        result.save(f'./pic/{cnt:06d}.jpg')
        cnt += 1

0 -> min: 100%|██████████| 50/50 [00:06<00:00,  7.79it/s]
min -> max: 100%|██████████| 100/100 [00:12<00:00,  8.17it/s]
max -> 0: 100%|██████████| 50/50 [00:05<00:00,  8.50it/s]


In [104]:
# --- mp4動画の作成 ---

# 既に output.mp4 があれば削除する
import os
if os.path.exists('./output.mp4'):
   os.remove('./output.mp4')

# pic フォルダーの静止画から動画を作成
! ffmpeg -r 30 -i pic/%6d.jpg\
               -vcodec libx264 -pix_fmt yuv420p output.mp4

# movieフォルダへ名前を付けてコピー
import shutil
os.makedirs('movie', exist_ok=True)
shutil.copy('output.mp4', 'movie/'+direction+'_'+latent[:-3]+'.mp4')

ffmpeg version 3.4.11-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers
  built with gcc 7 (Ubuntu 7.5.0-3ubuntu1~18.04)
  configuration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --enable-gpl --disable-stripping --enable-avresample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librubberband --enable-librsvg --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-li

'movie/age+pose_oonishi1.mp4'

In [105]:
# --- Play the mp4 video ---
from IPython.display import HTML
from base64 import b64encode

# Read the file inside a context manager so the handle is closed immediately
# instead of staying open until garbage collection.
with open('./output.mp4', 'rb') as video_file:
    mp4 = video_file.read()
# Embed the video bytes as a base64 data URL so the player needs no file access
data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()
HTML(f"""
<video width="50%" height="50%" controls>
      <source src="{data_url}" type="video/mp4">
</video>""")