Google Colab上でHairCLIPを動作させる環境です。

- Respect: [[HairCLIP] 機械学習で顔写真の髪型や髪色を変更する](https://www.12-technology.com/2022/05/hairclip.html)
- GitHub: [HairCLIP_demo.ipynb](https://github.com/kaz12tech/ai_demos/blob/main/HairCLIP_demo.ipynb)

上記を参考にさせていただきました。基本的には上記のipynbファイルのまま、GoogleDriveマウントとカメラ表示だけ追加しています。

HairCLIPの詳細は下記参照。
- 論文: https://arxiv.org/abs/2112.05142
- GitHub: https://github.com/wty-ustc/HairCLIP


# 環境セットアップ
## GPU確認

In [None]:
# Print the GPU status reported by nvidia-smi to confirm a GPU is attached.
!nvidia-smi

# Google Drive マウント

In [None]:
# Mount Google Drive at /content/drive; the rest of the notebook works there.
from google.colab import drive
drive.mount('/content/drive')

1回目だけ、Drive上にフォルダ作成、GitHubからコード取得。

In [None]:
!mkdir drive/MyDrive/Colab\ Notebooks/HairCLIP
%cd drive/MyDrive/Colab\ Notebooks/HairCLIP
base_dir = "/content/drive/MyDrive/Colab Notebooks/HairCLIP"

# GitHubからコード取得

In [None]:
# %cd /content
%cd /content/drive/MyDrive/Colab\ Notebooks/HairCLIP

# Clone the HairCLIP and encoder4editing repositories into the Drive folder.
!git clone https://github.com/wty-ustc/HairCLIP.git
!git clone https://github.com/omertov/encoder4editing.git
# Download ninja v1.8.2, unpack it into /usr/local/bin and register it as
# /usr/bin/ninja through update-alternatives.
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force

# ライブラリのインストール

In [None]:
# %cd HairCLIP
%cd  /content/drive/MyDrive/Colab\ Notebooks/HairCLIP/HairCLIP

# Install the Python packages used below: CLIP (from GitHub) with its
# ftfy/regex/tqdm dependencies, tensorflow-io, and an up-to-date gdown.
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install tensorflow-io
!pip install --upgrade --no-cache-dir gdown

# ライブラリのインポート

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks/HairCLIP/encoder4editing

from utils.alignment import align_face
from models.psp import pSp

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks/HairCLIP/HairCLIP

import os
import gdown
from argparse import ArgumentParser

import sys
sys.path.append(".")
sys.path.append("..")
import tempfile
from argparse import Namespace

import dlib
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import imageio
from IPython.display import HTML
from base64 import b64encode
import glob
import numpy as np
from PIL import Image
import random

from criteria.parse_related_loss import average_lab_color_loss

In [None]:
# Take a webcam photo and show it.
# NOTE(review): take_photo() is defined in a later cell of this notebook; when
# the cells are run top to bottom this cell only prints the resulting
# NameError message. It also rebinds `Image` to IPython.display.Image.
from IPython.display import Image
try:
  filename = take_photo()
  print('Saved to {}'.format(filename))
  
  # Show the image which was just taken.
  display(Image(filename))
except Exception as err:
  # Errors will be thrown if the user does not have a webcam or if they do not
  # grant the page permission to access it.
  print(str(err))

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks/HairCLIP/HairCLIP/mapper

from mapper.datasets.latents_dataset_inference import LatentsDatasetInference
from mapper.hairclip_mapper import HairCLIPMapper

# 学習済みモデルのダウンロード
Access denied with the following error:
が発生する場合、何回か実行

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks/HairCLIP

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# When True, Downloader authenticates the user and downloads through PyDrive;
# when False it falls back to gdown.
download_with_pydrive = True

class Downloader(object):
    """Download files from Google Drive into the pretrained-models folder.

    Files are stored under ``base_dir + "/HairCLIP/pretrained_models"``; the
    folder is created on construction when it does not exist.
    """

    def __init__(self, use_pydrive):
        """Prepare the target folder and, if requested, authenticate.

        Args:
            use_pydrive: When true, authenticate the Colab user and download
                through PyDrive; otherwise download with gdown.
        """
        self.use_pydrive = use_pydrive

        self.save_dir = base_dir + "/HairCLIP/pretrained_models"
        os.makedirs(self.save_dir, exist_ok=True)
        if self.use_pydrive:
            self.authenticate()

    def authenticate(self):
        """Authenticate the Colab user and create the PyDrive client.

        Stores a ``GoogleDrive`` instance built from the application-default
        credentials in ``self.drive``.
        """
        auth.authenticate_user()
        gauth = GoogleAuth()
        gauth.credentials = GoogleCredentials.get_application_default()
        self.drive = GoogleDrive(gauth)

    def download_file(self, file_id, file_name):
        """Download the Drive file ``file_id`` to ``save_dir/file_name``.

        Does nothing except print a message when the destination already
        exists.

        Args:
            file_id: Google Drive file id.
            file_name: Name of the file to create inside ``self.save_dir``.
        """
        file_dst = f'{self.save_dir}/{file_name}'
        if os.path.exists(file_dst):
            print(f'{file_dst} already exists!')
            return
        if self.use_pydrive:
            downloaded = self.drive.CreateFile({'id': file_id})
            downloaded.FetchMetadata(fetch_all=True)
            downloaded.GetContentFile(file_dst)
        else:
            # Use the gdown Python API (same call form as the other model
            # downloads in this notebook) instead of the shell command
            # `gdown --id $file_id -O $file_dst`: the unquoted destination
            # was split at the space in "Colab Notebooks".
            gdown.download(
                f'https://drive.google.com/uc?id={file_id}',
                file_dst,
                quiet=False,
            )

# Create the downloader and fetch the e4e FFHQ encoder checkpoint
# (skipped by download_file when the file is already present).
downloader = Downloader(download_with_pydrive)
downloader.download_file(file_id="1cUv_reLE6k3604or78EranS7XzuVMWeO", file_name="e4e_ffhq_encode.pt")

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks/HairCLIP/HairCLIP

if not os.path.exists("./pretrained_models/hairclip.pt"):
  gdown.download('https://drive.google.com/uc?id=1hqZT6ZMldhX3M_x378Sm4Z2HMYr-UwQ4', "./pretrained_models/hairclip.pt", quiet=False)
if not os.path.exists("./pretrained_models/stylegan2-ffhq-config-f.pt"):
  gdown.download('https://drive.google.com/uc?id=1pts5tkfAcWrg4TpLDu6ILF5wHID32Nzm', "./pretrained_models/stylegan2-ffhq-config-f.pt", quiet=False)
if not os.path.exists("./pretrained_models/model_ir_se50.pth"):
  gdown.download('https://drive.google.com/uc?id=1FS2V756j-4kWduGxfir55cMni5mZvBTv', "./pretrained_models/model_ir_se50.pth", quiet=False)


# if not os.path.exists("./pretrained_models/test_faces.pt"):
#   gdown.download('https://drive.google.com/uc?id=1j7RIfmrCoisxx3t-r-KC02Qc8barBecr', "./pretrained_models/test_faces.pt", quiet=False)

if not os.path.exists("./pretrained_models/shape_predictor_68_face_landmarks.dat.bz2"):
  !wget -c http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 \
        -O ./pretrained_models/shape_predictor_68_face_landmarks.dat.bz2
  !bzip2 -dk ./pretrained_models/shape_predictor_68_face_landmarks.dat.bz2

# テスト画像のセットアップ
## テスト画像のダウンロード
[サンプル画像](https://www.pakutaso.com/20190117010post-18492.html)



In [None]:
%cd /content/drive/MyDrive/Colab Notebooks/HairCLIP/HairCLIP
!mkdir demo

!wget -c https://www.pakutaso.com/shared/img/thumb/model10211041_TP_V4.jpg \
      -O ./demo/model10211041_TP_V4.jpg

# Edit Hair
## Set parameters

In [None]:
%cd /content/drive/MyDrive/Colab Notebooks/HairCLIP/HairCLIP

# @markdown 入力画像
# image_path = base_dir + "/HairCLIP/demo/model10211041_TP_V4.jpg" #@param {type:"string"}
image_path = base_dir + "/HairCLIP/demo/face_demo.JPG" #@param {type:"string"}

# @markdown editing_type="both"でrandomに生成<br>
# @markdown randomの場合下記の設定は反映されません。
IsRandom = False #@param {type:"boolean"}
random_num = 30 #@param {type:"integer"}

# @markdown randomではない場合は以下設定<br>
# @markdown 編集タイプ colorのみ、styleのみ、両方
editing_type = "both" #@param["hairstyle", "color", "both"]
# @markdown 髪型選択
hairstyle_description = "crew cut hairstyle" #@param["afro hairstyle", "bob cut hairstyle", "bowl cut hairstyle", "braid hairstyle", "caesar cut hairstyle", "chignon hairstyle", "cornrows hairstyle", "crew cut hairstyle", "crown braid hairstyle", "curtained hair hairstyle", "dido flip hairstyle", "dreadlocks hairstyle", "extensions hairstyle", "fade hairstyle", "fauxhawk hairstyle", "finger waves hairstyle", "french braid hairstyle", "frosted tips hairstyle", "full crown hairstyle", "harvard clip hairstyle", "high and tight hairstyle", "hime cut hairstyle", "hi-top fade hairstyle","jewfro hairstyle", "jheri curl hairstyle", "liberty spikes hairstyle", "marcel waves hairstyle", "mohawk hairstyle", "pageboy hairstyle", "perm hairstyle", "pixie cut hairstyle", "psychobilly wedge hairstyle", "quiff hairstyle", "regular taper cut hairstyle", "ringlets hairstyle", "shingle bob hairstyle", "short hair hairstyle", "slicked-back hairstyle", "spiky hair hairstyle","surfer hair hairstyle", "taper cut hairstyle", "the rachel hairstyle", "undercut hairstyle", "updo hairstyle"]
# @markdown 髪色選択
color_description = "yellow" #@param["purple", "red", "orange", "yellow", "green", "blue", "gray", "brown", "black", "white", "blond", "pink"]

# 出力先ディレクトリ作成
!mkdir outputs

from IPython.display import Image,display_jpeg
display_jpeg(Image(image_path))

# Define functions

In [None]:
def run_alignment(image_path):
    """Align the face found in ``image_path``.

    Loads the dlib shape predictor from the pretrained-models folder, passes
    it to ``align_face`` and prints the ``size`` of the aligned image.

    Args:
        image_path: Path of the image to align.

    Returns:
        The aligned image as returned by ``align_face``.
    """
    landmark_model = (
        base_dir
        + "/HairCLIP/pretrained_models/shape_predictor_68_face_landmarks.dat"
    )
    face = align_face(
        filepath=image_path,
        predictor=dlib.shape_predictor(landmark_model),
    )
    print(f"Aligned image has shape: {face.size}")
    return face

In [None]:
def run_on_batch_e4e(inputs, net):
    """Run the encoder ``net`` on ``inputs`` and return (images, latents).

    ``inputs`` is moved to CUDA and cast to float before the call; ``net`` is
    invoked with ``randomize_noise=False`` and ``return_latents=True``.
    """
    cuda_inputs = inputs.to("cuda").float()
    outputs = net(cuda_inputs, randomize_noise=False, return_latents=True)
    images, latents = outputs
    return images, latents

In [None]:
def run_on_batch(
    inputs,
    hairstyle_text_inputs,
    color_text_inputs,
    hairstyle_tensor_hairmasked,
    color_tensor_hairmasked,
    net,
):
    """Apply the mapper of ``net`` to a latent and decode both versions.

    The mapper output is scaled by 0.1 and added to ``inputs``; the edited
    latent and the untouched one are then passed through ``net.decoder``.
    Everything runs under ``torch.no_grad()``.

    Returns:
        A tuple ``(x_hat, w_hat, x)``: the decoder output for the edited
        latent, the latent returned by the decoder, and the decoder output
        for the original latent.
    """
    # Decoder settings shared by both decoding passes.
    decode_kwargs = dict(input_is_latent=True, randomize_noise=False, truncation=1)

    with torch.no_grad():
        delta = net.mapper(
            inputs,
            hairstyle_text_inputs,
            color_text_inputs,
            hairstyle_tensor_hairmasked,
            color_tensor_hairmasked,
        )
        edited_latent = inputs + 0.1 * delta

        x_hat, w_hat = net.decoder(
            [edited_latent], return_latents=True, **decode_kwargs
        )
        x, _ = net.decoder([inputs], **decode_kwargs)

    return (x_hat, w_hat, x)

In [None]:
from IPython.display import Image,display_png
def predict(
    edit_t, hair_d, color_d,
    ck, im_path, trans, device):
    """Edit the hair in one image with HairCLIP, then save and show the result.

    The face in ``im_path`` is aligned, encoded with the module-level
    ``e4e_net``, edited by a ``HairCLIPMapper`` built from ``ck["opts"]`` and
    the given text descriptions, written as a PNG under
    ``base_dir + "/HairCLIP/outputs"`` and displayed inline. A new
    ``HairCLIPMapper`` is constructed on every call.

    Args:
        edit_t: Editing type. "both" and "hairstyle" are checked explicitly;
            any other value is validated as a color-only edit.
        hair_d: Hairstyle description text, or None.
        color_d: Color description text, or None.
        ck: Loaded HairCLIP checkpoint; only ``ck["opts"]`` is read here.
        im_path: Path of the input image.
        trans: Callable that turns the aligned image into the encoder input
            tensor.
        device: Device the average-color loss module is moved to.

    Raises:
        AssertionError: If a description required by ``edit_t`` is None.
        ValueError: If a non-text tensor from the dataset has a width other
            than 1024 or 2048.
    """
    if edit_t == "both":
        assert (
            hair_d is not None and color_d is not None
        ), ("Please provide description " "for both hairstyle and color.")
    elif edit_t == "hairstyle":
        assert hair_d is not None, "Please provide description for hairstyle."
    else:
        assert color_d is not None, "Please provide description for color."

    models_dir = base_dir + "/HairCLIP/pretrained_models"
    outputs_dir = base_dir + "/HairCLIP/outputs"
    # Make sure the output folder exists before anything is written into it.
    os.makedirs(outputs_dir, exist_ok=True)

    opts = Namespace(**ck["opts"])
    opts.editing_type = edit_t
    opts.input_type = "text"
    opts.color_description = color_d

    if hair_d is not None:
        # The hairstyle description is written to a text file and the options
        # carry the path of that file.
        description_file = outputs_dir + "/hairstyle_description.txt"
        with open(description_file, "w") as file:
            file.write(hair_d)
        opts.hairstyle_description = description_file

    opts.checkpoint_path = models_dir + "/hairclip.pt"
    opts.parsenet_weights = models_dir + "/parsenet.pth"
    opts.stylegan_weights = models_dir + "/stylegan2-ffhq-config-f.pt"
    opts.ir_se50_weights = models_dir + "/model_ir_se50.pth"
    net = HairCLIPMapper(opts)
    net.eval()
    net.cuda()

    # Align the face and hand it to `trans` unchanged. The former
    # `input_image.resize((256, 256))` call discarded its result, so it never
    # resized anything; it was removed to keep the output identical.
    input_image = run_alignment(str(im_path))
    transformed_image = trans(input_image)

    with torch.no_grad():
        _, latents = run_on_batch_e4e(transformed_image.unsqueeze(0), e4e_net)
        print("Latent code calculated!")

    dataset = LatentsDatasetInference(latents=latents.cpu(), opts=opts)
    dataloader = DataLoader(dataset)

    average_color_loss = average_lab_color_loss.AvgLabLoss(opts).to(device).eval()

    # Skip missing descriptions so hairstyle-only / color-only edits do not
    # fail on `str + None`; with both descriptions the name is unchanged.
    name_parts = (edit_t, hair_d, color_d)
    out_filename = "_".join(part for part in name_parts if part is not None) + ".png"
    out_path = os.path.join(outputs_dir, out_filename)
    print("output path:", out_path)

    for input_batch in tqdm(dataloader):
        with torch.no_grad():
            (w, hairstyle_text_inputs_list, color_text_inputs_list,
             _selected_descriptions, hairstyle_tensor_list,
             color_tensor_list) = input_batch

            w = w.cuda().float()
            hairstyle_text_inputs = hairstyle_text_inputs_list[0].cuda()
            color_text_inputs = color_text_inputs_list[0].cuda()
            hairstyle_tensor = hairstyle_tensor_list[0].cuda()
            color_tensor = color_tensor_list[0].cuda()

            # NOTE(review): a size of 1 on dim 1 looks like the dataset's
            # placeholder for "no tensor input" — confirm against
            # LatentsDatasetInference.
            use_hairstyle_tensor = hairstyle_tensor.shape[1] != 1
            use_color_tensor = color_tensor.shape[1] != 1

            if use_hairstyle_tensor:
                hairstyle_tensor_hairmasked = (
                    hairstyle_tensor
                    * average_color_loss.gen_hair_mask(hairstyle_tensor))
            else:
                hairstyle_tensor_hairmasked = torch.Tensor([0]).unsqueeze(0).cuda()

            if use_color_tensor:
                color_tensor_hairmasked = (
                    color_tensor * average_color_loss.gen_hair_mask(color_tensor))
            else:
                color_tensor_hairmasked = torch.Tensor([0]).unsqueeze(0).cuda()

            # Run the edit for every batch. This call used to sit inside the
            # `else` branch above, so `result_batch` was missing (or stale)
            # whenever a color tensor was supplied.
            result_batch = run_on_batch(
                w, hairstyle_text_inputs, color_text_inputs,
                hairstyle_tensor_hairmasked, color_tensor_hairmasked, net)

            original_img = result_batch[2][0].unsqueeze(0)
            edited_img = result_batch[0][0].unsqueeze(0)

            if use_hairstyle_tensor and use_color_tensor:
                img_tensor = torch.cat([hairstyle_tensor, color_tensor], dim=3)
            elif use_hairstyle_tensor:
                img_tensor = hairstyle_tensor
            elif use_color_tensor:
                img_tensor = color_tensor
            else:
                img_tensor = None

            # Output rows: original, edited, then any tensor inputs.
            if img_tensor is None:
                couple_output = torch.cat([original_img, edited_img])
            elif img_tensor.shape[3] == 1024:
                couple_output = torch.cat([original_img, edited_img, img_tensor])
            elif img_tensor.shape[3] == 2048:
                couple_output = torch.cat(
                    [original_img, edited_img,
                     img_tensor[:, :, :, 0:1024], img_tensor[:, :, :, 1024::]])
            else:
                # Previously this case left `couple_output` undefined.
                raise ValueError(
                    "Unsupported reference width: {}".format(img_tensor.shape[3]))

            try:
                torchvision.utils.save_image(
                    couple_output, str(out_path), normalize=True,
                    value_range=(-1, 1))
            except TypeError:
                # torchvision releases that predate `value_range` only accept
                # the former keyword name `range`.
                torchvision.utils.save_image(
                    couple_output, str(out_path), normalize=True, range=(-1, 1))

            display_png(Image(out_path))

# Predict

In [None]:
# Root of the project folder on Google Drive.
base_dir = "/content/drive/MyDrive/Colab Notebooks/HairCLIP"

# Candidate descriptions used for random generation: hairstyles are read from
# the repository's list file and sorted, colors are listed inline.
with open(f"{base_dir}/HairCLIP/mapper/hairstyle_list.txt") as infile:
    HAIRSTYLE_LIST = [line.rstrip() for line in infile]
HAIRSTYLE_LIST.sort()
COLORSTYLE_LIST = [
    "purple", "red", "orange", "yellow", "green", "blue",
    "gray", "brown", "black", "white", "blond", "pink",
]

device = "cuda:0"

# Load the e4e FFHQ encoder: read the checkpoint on CPU, point its stored
# options at the checkpoint file, then build the network and move it to CUDA.
e4e_model_path = f"{base_dir}/HairCLIP/pretrained_models/e4e_ffhq_encode.pt"
e4e_ckpt = torch.load(e4e_model_path, map_location="cpu")
e4e_ckpt["opts"]["checkpoint_path"] = e4e_model_path
e4e_opts = Namespace(**e4e_ckpt["opts"])

e4e_net = pSp(e4e_opts)
e4e_net.eval()
e4e_net.cuda()
print("e4e model successfully loaded!")

# Input preprocessing: resize to 256x256, convert to a tensor and normalize
# each channel with mean 0.5 and std 0.5.
img_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Load the HairCLIP checkpoint on CPU; predict() reads its "opts" entry.
checkpoint_path = f"{base_dir}/HairCLIP/pretrained_models/hairclip.pt"
ckpt = torch.load(checkpoint_path, map_location="cpu")

In [None]:
from IPython.display import display, Javascript
from google.colab.output import eval_js
from base64 import b64decode

def take_photo(filename='photo.jpg', quality=0.8):
    """Capture one webcam frame in the browser and save it as a JPEG.

    Injects a JavaScript helper that shows a live video with a "Capture"
    button, waits for the click, and returns the frame as a data URL; the
    base64 payload is decoded and written to ``filename``.

    Args:
        filename: Path of the file to write.
        quality: JPEG quality passed to ``canvas.toDataURL``.

    Returns:
        ``filename``.
    """
    capture_script = Javascript('''
    async function takePhoto(quality) {
      const div = document.createElement('div');
      const capture = document.createElement('button');
      capture.textContent = 'Capture';
      div.appendChild(capture);

      const video = document.createElement('video');
      video.style.display = 'block';
      const stream = await navigator.mediaDevices.getUserMedia({video: true});

      document.body.appendChild(div);
      div.appendChild(video);
      video.srcObject = stream;
      await video.play();

      // Resize the output to fit the video element.
      google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);

      // Wait for Capture to be clicked.
      await new Promise((resolve) => capture.onclick = resolve);

      const canvas = document.createElement('canvas');
      canvas.width = video.videoWidth;
      canvas.height = video.videoHeight;
      canvas.getContext('2d').drawImage(video, 0, 0);
      stream.getVideoTracks()[0].stop();
      div.remove();
      return canvas.toDataURL('image/jpeg', quality);
    }
    ''')
    display(capture_script)

    # The helper returns a data URL; the base64 payload follows the comma.
    data_url = eval_js(f'takePhoto({quality})')
    encoded = data_url.split(',')[1]
    jpeg_bytes = b64decode(encoded)

    with open(filename, 'wb') as out_file:
        out_file.write(jpeg_bytes)
    return filename

In [None]:
%cd /content/drive/MyDrive/Colab Notebooks/HairCLIP/HairCLIP

from IPython.display import Image
try:
  name = 'demo/photo.jpg'
  filename = take_photo(name)
  print('Saved to {}'.format(filename))
  
  # Show the image which was just taken.
  display(Image(filename))
  
except Exception as err:
  # Errors will be thrown if the user does not have a webcam or if they do not
  # grant the page permission to access it.
  print(str(err))

In [None]:
# @markdown 入力画像
# image_path = base_dir + "/HairCLIP/demo/model10211041_TP_V4.jpg" #@param {type:"string"}
# Use the webcam photo saved as demo/photo.jpg as the input image.
image_path = base_dir + "/HairCLIP/demo/photo.jpg" #@param {type:"string"}

# Preview the selected input image inline.
from IPython.display import Image,display_jpeg
display_jpeg(Image(image_path))

In [None]:
# Either generate `random_num` edits with a random hairstyle and color each,
# or run a single edit with the selected editing_type and descriptions.
if IsRandom:
    for _ in range(random_num):
        # random.choice can return every entry. The previous
        # randrange(0, len(...) - 1) excluded the last element of each list.
        predict(
            "both",
            random.choice(HAIRSTYLE_LIST),
            random.choice(COLORSTYLE_LIST),
            ckpt, image_path, img_transforms, device)
else:
    predict(
        editing_type, hairstyle_description, color_description,
        ckpt, image_path, img_transforms, device)


# Show Result

mp4動画を作成して再生するところまで。必要あれば実行。

In [None]:
# !pip install imageio-ffmpeg

In [None]:
# Disabled by default: delete the two ''' lines to run this cell.
# The outer delimiter is ''' because the body itself contains a """..."""
# literal; wrapping it in """ ended the string early and made the cell a
# SyntaxError.
# NOTE(review): earlier cells rebind `Image` to IPython.display.Image, so
# `Image.open` below needs `from PIL import Image` to be run again first.
'''
def generate_mp4(out_name, images, kwargs):
  print(out_name + '.mp4')
  writer = imageio.get_writer(out_name + '.mp4', **kwargs)
  for image in images:
    writer.append_data(image)
  writer.close()

def show_mp4(filename, width):
  print(filename+'.mp4')
  mp4 = open(filename + '.mp4', 'rb').read()
  data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
  display(HTML("""
  <video width="%d" controls autoplay loop>
    <source src="%s" type="video/mp4">
  </video>
  """ % (width, data_url)))

res_list = glob.glob(base_dir + "/HairCLIP/outputs/*.png")
images = []
for img_path in res_list:
  images.append(np.array(Image.open(img_path)))

kwargs = {'fps': 2}

gif_path = os.path.join(base_dir + "/HairCLIP/outputs", "animation")
generate_mp4(gif_path, images, kwargs)
show_mp4(gif_path, width=514)
'''

# GradCAM
ここからテスト、うまく作業中。


In [None]:
# Disabled cell: the pip commands are wrapped in a string literal, so nothing
# is installed when it runs.
"""
!pip install pytorch-gradcam
!pip install grad-cam
"""


In [None]:
# Disabled Grad-CAM experiment: the whole cell is a string literal, so none of
# it runs.
# NOTE(review): the code references `config`, which is not defined anywhere in
# the visible notebook — define it before enabling this cell.
"""
# Basic Modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

# PyTorch Modules
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets
import torchvision.transforms as transforms
from torch.utils.data.dataset import Subset
import torchvision.models as models
import torch.optim as optim
from torchvision.utils import make_grid, save_image

# Grad-CAM
from gradcam.utils import visualize_cam
from gradcam import GradCAM, GradCAMpp


device = torch.device("cuda:0" if torch.cuda.is_available()  else "cpu")
model = models.densenet161(pretrained=True)
model.fc = nn.Linear(2048,5)
model = torch.nn.DataParallel(model).to(device)
model.eval()
#  model.load_state_dict(torch.load('trained_model.pt'))
# model.load_state_dict(torch.load('/content/HairCLIP/pretrained_models/e4e_ffhq_encode.pt'))
model.load_state_dict(torch.load(base_dir + '/HairCLIP/pretrained_models/hairclip.pt'))

# Grad-CAM
target_layer = model.module.features
gradcam = GradCAM(model, target_layer)
gradcam_pp = GradCAMpp(model, target_layer)

images = []
# あるラベルの検証用データセットを呼び出してる想定
for path in glob.glob("{}/label1/*".format(config['dataset'])):
    img = Image.open(path)
    torch_img = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])(img).to(device)
    normed_torch_img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(torch_img)[None]
    mask, _ = gradcam(normed_torch_img)
    heatmap, result = visualize_cam(mask, torch_img)

    mask_pp, _ = gradcam_pp(normed_torch_img)
    heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img)

    images.extend([torch_img.cpu(), heatmap, heatmap_pp, result, result_pp])
grid_image = make_grid(images, nrow=5)

# 結果の表示
transforms.ToPILImage()(grid_image)
"""
