In [1]:
import os
import clip
import numpy as np
from PIL import Image
import torch

from embedding import get_delta_t
from manipulator import Manipulator
from mapper import get_delta_s
from wrapper import Generator

In [2]:
# GPU device
device = torch.device('cuda:2')
# pretrained ffhq generator
ckpt = 'pretrained/ffhq.pkl'
G = Generator(ckpt, device)
# CLIP
model, preprocess = clip.load("ViT-B/32", device=device)
# global image direction
fs3 = np.load('tensor/fs3.npy')

In [3]:
manipulator = Manipulator(G, device)

In [4]:
# test image dir path
H,W=256,256
imgdir = 'samples'
origs=[]
for filename in sorted(os.listdir(imgdir)):
    origs.append(np.asanyarray(Image.open(f'{imgdir}/{filename}').resize(H,W)))
origs=np.stack(origs)

ValueError: Unknown resampling filter (256). Use Image.Resampling.NEAREST (0), Image.Resampling.LANCZOS (1), Image.Resampling.BILINEAR (2), Image.Resampling.BICUBIC (3), Image.Resampling.BOX (4) or Image.Resampling.HAMMING (5)

In [None]:
# manipulator mode
# inv_mode : inversion mode
# 'w' : use w projector proposed by Karras et al.
# 'w+' : use e4e encoder (only implemented for ffhq1024 now)
# pti_mode : pivot tuning mode
# 'w' : W latent space pivot tuning
# 's' : Style space pivot tuning
manipulator.set_real_img_projection(imgdir, inv_mode='w+', pti_mode='s')

In [None]:
# text direction : neutral -> target

targets = [

    'face with Arched Eyebrows',
    'face with Bushy Eyebrows',

    'face with Big Lips',

    'face with Big Nose',
    'face with Pointy Nose',

    'face with Black Hair',
    'face with Blond Hair',
    'face with Brown Hair',
    'face with Gray Hair',
    'face with Curly Hair',
    'face with Straight Hair',
    'face with Wavy Hair',
    'face with Receding Hairline',
    'face with Bangs',

    'face with Eyeglasses',
    'face with Sunglasses',

    'face with Eyes Open',
    'face with Narrow Eyes',
    'face with Brown Eyes',
    'face with Bags Under Eyes',

    'face with Heavy Makeup',
    'face with Lipstick',

    'face with Mouth Closed',
    'face with Mouth Slightly Open',
    'face with Mouth Wide Open',

    'face with Beard',
    'face with No Beard',
    'face with Mustache',
    'face with Goatee',

    'face with Pale Skin',
    'face with Shiny Skin',
    'face with Rosy Cheeks',
    'face with Sideburns',

    'face with Earrings',
    'face with Hat',
    'face with Necklace',
    'face with Necktie',

    'face with Double Chin',
    'face with High Cheekbones',
    'face with Frowning',
    'face with Round Jaw',

    'Asian face',
    'White face',
    'Black face',
    'Indian face',

    'Baby face',
    'Child face',
    'Middle Aged face',
    'Senior face',
    'Youth face',

    'Oval Face',
    'Square Face',
    'Round Face',

    'Attractive face',
    'Bald face',
    'Blurry face',
    'Chubby face',
    'Smiling face',

    'Surprised face',
    'Fearful face',
    'Disgusted face',
    'Happy face',
    'Sad face',
    'Angry face',

    'face under Harsh Lighting',
    'face under Flash Lighting',

]
neutral = 'face'

# beta_threshold : Determines the degree of disentanglement, # channels manipulated
beta_threshold = 0.10

In [None]:
def visualize(imgs,gh,gw):
    _,H,W,C=imgs.shape
    imgs=imgs.reshape(gh,gw,H,W,C)
    imgs=imgs.transpose(0,2,1,3,4)
    imgs = imgs.reshape(gh*H, gw*W, C)
    print(imgs.shape)
    display(Image.fromarray(imgs, 'RGB'))

In [None]:
import ipywidgets as widgets
delta_s=None

In [None]:
slider = widgets.FloatSlider(
    value=0,
    min=-5,
    max=5,
    step=0.01,
    description='Alpha:',
    readout_format='.2f',
)

In [None]:
dropdown = widgets.Dropdown(options=targets)

In [None]:
def slider_handler():
    # manipulate styles
    manipulator.set_alpha([slider.value])
    styles = manipulator.manipulate(delta_s)
    all_imgs = manipulator.synthesis_from_styles(styles, 0, manipulator.num_images)

    # visualize
    lst = []
    for imgs in all_imgs:
        lst.append((imgs.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).numpy())

    imgs = lst[0]
    # visualize(origs,gw,gh)
    # visualize(imgs,gw,gh)
    output=np.concatenate([origs,imgs],axis=2)
    display(Image.fromarray(output.reshape(-1,output.shape[-2],output.shape[-1]), 'RGB'))

slider.observe(slider_handler, names='value')

In [None]:
def dropdown_handler():
    global delta_s
    classnames = [neutral, dropdown.value]
    delta_t = get_delta_t(classnames, model)
    delta_s, _ = get_delta_s(fs3, delta_t, manipulator, beta_threshold=beta_threshold)

dropdown.observe(dropdown_handler,names='value')

In [None]:
display(slider,dropdown)

In [None]:
import ipywidgets as widgets

@widgets.interact(alpha=(-5, 5, 0.01),target=targets)
def run(alpha,target):
    classnames = [neutral, target]
    delta_t = get_delta_t(classnames, model)
    delta_s, num_channel = get_delta_s(fs3, delta_t, manipulator, beta_threshold=beta_threshold)
    manipulator.set_alpha([alpha])

    # manipulate styles
    styles = manipulator.manipulate(delta_s)
    all_imgs = manipulator.synthesis_from_styles(styles, 0, manipulator.num_images)

    # visualize
    lst = []
    for imgs in all_imgs:
        lst.append((imgs.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).numpy())

    gw, gh = (manipulator.num_images, 1)

    imgs = lst[0]
    # visualize(origs,gw,gh)
    # visualize(imgs,gw,gh)
    output=np.concatenate([origs,imgs],axis=2)
    display(Image.fromarray(output.reshape(-1,output.shape[-2],output.shape[-1]), 'RGB'))