In [None]:
import copy
import re
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, ToPILImage
from tqdm import tqdm

In [None]:
import os

# Optional: route model downloads through an HTTP(S) proxy by filling these in.
# os.environ['http_proxy'] = ""
# os.environ['https_proxy'] = ""

# Index of the CUDA device; made the current torch device here and also handed to the
# insightface detector's provider options further down.
device_id = 0
torch.cuda.set_device(f'cuda:{device_id}')

In [None]:
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from piecewise_rectified_flow.src.scheduler_perflow import PeRFlowScheduler

# PeRFlow SD 1.5 checkpoints (per the repo ids); switch the active id to try another one.
model_id = "hansyan/perflow-sd15-dreamshaper"
# model_id = "hansyan/perflow-sd15-realisticVisionV51"
# model_id = "hansyan/perflow-sd15-disney"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    safety_checker=None,  # safety checker disabled
    torch_dtype=torch.float16,
)
pipe.scheduler = PeRFlowScheduler.from_config(
    pipe.scheduler.config,
    prediction_type="diff_eps",
    num_time_windows=4,
)

# Our method also supports SD 2.1 (not used in the paper)
# pipe = StableDiffusionPipeline.from_pretrained("hansyan/perflow-sd21-artius", safety_checker=None, torch_dtype=torch.bfloat16)
# pipe.scheduler = PeRFlowScheduler.from_config(pipe.scheduler.config, prediction_type="velocity", num_time_windows=4)

In [None]:
# Freeze every sub-network: the only tensor the optimizer later updates is `latents`.
for frozen_module in (pipe.vae, pipe.text_encoder, pipe.unet):
    frozen_module.requires_grad_(False)

pipe.to("cuda")
pipe.set_progress_bar_config(disable=True)

# Undecorated pipeline call. diffusers wraps `__call__` (presumably with torch.no_grad —
# confirm for your diffusers version); calling the raw function keeps autograd enabled so
# gradients can flow through sampling back to the latents.
my_forward = pipe.__call__.__wrapped__

In [None]:
import insightface
from onnx2torch import convert

# antelopev2
# https://github.com/deepinsight/insightface/tree/master/python-package#model-zoo

# Face detector, using the same CUDA device index as torch.
detector = insightface.model_zoo.get_model(
    'scrfd_10g_bnkps.onnx',
    provider_options=[{'device_id': device_id}, {}],
)
detector.prepare(ctx_id=0, input_size=(640, 640))

# Face-recognition network converted from ONNX into a torch module, so the identity loss
# can be backpropagated through it; its own weights stay frozen.
model = convert('glintr100.onnx').eval().to('cuda')
model.requires_grad_(False)

In [None]:
from PIL import Image
import torchvision.transforms.functional as TF
import kornia

ref = 'assets/hinton.jpg'
# ref = 'assets/bengio.jpg'
# ref = 'assets/schmidhuber.jpg'
# ref = 'assets/johansson.jpg'
# ref = 'assets/newton.jpg'

ref_image = Image.open(ref).convert("RGB")
# Converted once and reused by every detection retry below.
ref_image_np = np.array(ref_image)

with torch.no_grad():
    # Lower the detection threshold in 0.1 steps until a face is found. The original
    # threshold is restored in `finally`, so an exception from detect() cannot leave the
    # shared detector with a lowered threshold for later cells.
    det_thresh_backup = detector.det_thresh
    try:
        boxes = []
        while len(boxes) == 0:
            boxes, kpss = detector.detect(ref_image_np, max_num=1)
            detector.det_thresh -= 0.1
    finally:
        detector.det_thresh = det_thresh_backup

    # Affine matrix aligning the detected keypoints to insightface's 112x112 template.
    M = insightface.utils.face_align.estimate_norm(kpss[0])
    ref_image_cropped = kornia.geometry.transform.warp_affine(
        TF.to_tensor(ref_image).unsqueeze(0).to('cuda'), torch.tensor(M).float().unsqueeze(0).to('cuda'), (112, 112)
    ) * 2 - 1  # [0, 1] -> [-1, 1] before the recognition model

    ref_embedding = model(ref_image_cropped)

# Undo the [-1, 1] scaling to show the aligned crop.
cropped_image = np.array((ref_image_cropped[0] / 2 + 0.5).cpu().permute(1, 2, 0) * 255, dtype=np.uint8)
plt.imshow(cropped_image)
plt.axis("off")
plt.show()

In [None]:
import tensorflow as tf
from deepface import DeepFace

# Hide every GPU from TensorFlow (DeepFace's backend) — presumably so it does not claim
# GPU memory torch is using; DeepFace then runs on CPU.
tf.config.set_visible_devices([], device_type='GPU')
attribute = DeepFace.analyze(
    img_path=ref,
    actions=['gender', 'race'],
)

In [None]:
# DeepFace returns one record per detected face; keep the one with the largest box area.
region_areas = [face['region']['w'] * face['region']['h'] for face in attribute]
idx = np.argmax(region_areas)
largest_face = attribute[idx]
print(largest_face['dominant_gender'], largest_face['dominant_race'])

In [None]:
from diffusers.utils.torch_utils import randn_tensor

generator = torch.manual_seed(42)

# Leading dim 4 = one optimizable latent per denoising step (the loop samples 4 steps).
# The remaining (4, 64, 64) is presumably SD's 4-channel latent at 512/8 resolution — confirm.
latents = nn.Parameter(
    randn_tensor(
        (4, 4, 64, 64),
        generator=generator,
        device=pipe._execution_device,
        dtype=pipe.text_encoder.dtype,
    )
)
# Detached copy of the initial step-0 noise; the sampling input is built on top of it.
latents0 = latents.data[:1].clone()
optimizer = torch.optim.SGD([latents], lr=1)  # lr of 1 or 2

In [None]:
prompt = 'Selfie of a middle-aged person on a yacht'
# prompt = 'A photo of a person wearing a suit, holding red roses in hand, upper body, behind is the Eiffel Tower'
# prompt = 'a man sitting in the cafe, comic, graphic illustration, comic art, graphic novel art, vibrant, highly detailed, colored, 2d minimalistic'

# prompt = 'a photo of a person'
# prompt = 'a person with a sad expression'
# prompt = 'a person with a happy expression'
# prompt = 'a person with a puzzled expression'
# prompt = 'a person with an angry expression'
# prompt = 'a person plays the LEGO toys'
# prompt = 'a person on the beach'
# prompt = 'a person piloting a fighter jet'
# prompt = 'a person wearing the sweater, a backpack and camping stove, outdoors, RAW, ultra high res'
# prompt = 'a person wearing a scifi spacesuit in space'
# prompt = 'cubism painting of a person'
# prompt = 'fauvism painting of a person'
# prompt = 'cave mural depicting a person'
# prompt = 'pointillism painting of a person'
# prompt = 'a person latte art'

# Specialize the generic word "person" with the race and gender DeepFace detected on the
# reference face (e.g. "person" -> "<race> man"). Only the whole word is matched, so words
# like "personal" are left intact; a callable replacement keeps the text literal for re.sub.
face_attributes = attribute[idx]
subject_noun = 'man' if face_attributes['dominant_gender'] == 'Man' else 'woman'
subject_text = face_attributes['dominant_race'] + ' ' + subject_noun
prompt = re.sub(r'\bperson\b', lambda _match: subject_text, prompt)
# prompt = prompt + ', face'

In [None]:
# Per-step buffers used by `callback`; row i corresponds to denoising step i.
#   latents_last[i]  : the latent value `callback` handed back to the pipeline at step i
#                      on the previous sampling run.
#   latents_last_e[i]: the latent the pipeline itself produced at step i on the previous
#                      run, before `callback` modified it.
latents_last = latents.data.clone()
latents_last_e = latents.data.clone()
initialized_i = -1  # highest step index whose rows have been seeded from a real pipeline output

def callback(self, i, t, callback_kwargs):
    """`callback_on_step_end` hook that makes the sampling trajectory optimizable.

    Args:
        self: the pipeline instance (diffusers passes it first); unused here.
        i: index of the denoising step that just finished.
        t: timestep of that step; unused here.
        callback_kwargs: dict from the pipeline; its 'latents' tensor (the latent
            produced by step i, call it `produced`) is modified in place.

    Returns:
        callback_kwargs, whose 'latents' value becomes
            latents_last[i] + (produced - latents_last_e[i])
            + 0.95796674 * (latents[i] - latents_last_e[i]).
        On the first call for a given i, all three buffers are seeded with `produced`,
        so the value passes through unchanged. The autograd graph of `produced` is
        kept. For i < 3 a zero-valued `latents[i+1] - latents[i+1].detach()` term is
        also added, so backward yields a gradient for latents[i+1].

    Side effects:
        Writes latents_last[i] and latents_last_e[i], and overwrites latents[i].data
        with `produced`. After the next optimizer.step(), the difference
        latents[i] - latents_last_e[i] is therefore exactly the optimizer's update
        to latents[i], which the next run applies through the 0.95796674 term.
    """
    global latents_last, latents_last_e, initialized_i
    # First time this step is reached: seed all buffers from the pipeline's own output.
    if initialized_i < i:
        latents[i:(i+1)].data.copy_(callback_kwargs['latents'])
        latents_last[i:(i+1)].copy_(callback_kwargs['latents'])
        latents_last_e[i:(i+1)].copy_(callback_kwargs['latents'])
        initialized_i = i
    # Zero-valued term that only adds a gradient path to the next step's parameter.
    # NOTE(review): 3 presumably equals num_inference_steps - 1 (the last step has no successor).
    if i < 3:
        callback_kwargs['latents'] += latents[(i+1):(i+2)] - latents[(i+1):(i+2)].detach()
    # Gradient-free snapshot of what the pipeline produced at this step.
    latents_e = callback_kwargs['latents'].data.clone()
    # Value -> latents_last[i]; subtracting the detached self keeps the original graph.
    callback_kwargs['latents'] += latents_last[i:(i+1)].detach() - callback_kwargs['latents'].detach()
    # Add how much the pipeline's own output moved since the previous run.
    callback_kwargs['latents'] += latents_e.detach() - latents_last_e[i:(i+1)].detach()
    # Unscaled variant of the next line, kept for reference:
    # callback_kwargs['latents'] += latents[i:(i+1)].detach() - latents_last_e[i:(i+1)].detach()
    # Add the optimizer's update to latents[i], scaled.
    # NOTE(review): 0.95796674 is an undocumented empirical factor; origin not shown here.
    callback_kwargs['latents'] += (latents[i:(i+1)].detach() - latents_last_e[i:(i+1)].detach()) * 0.95796674
    # Record this run's values and reset latents[i] so the next step() delta is measurable.
    latents_last[i:(i+1)].copy_(callback_kwargs['latents'])
    latents_last_e[i:(i+1)].data.copy_(latents_e)
    latents[i:(i+1)].data.copy_(latents_e)
    return callback_kwargs

NUM_ITERS = 51  # optimization iterations; with LOG_EVERY = 10 this logs at 0, 10, ..., 50
LOG_EVERY = 10


def _timestamp():
    """Return wall-clock time after draining queued CUDA work.

    CUDA kernels launch asynchronously. Without the synchronize, each interval would be
    charged to whichever later stage first blocks on the GPU (e.g. the `.cpu()` copy).
    """
    torch.cuda.synchronize()
    return time.time()


# One optimized latent per denoising step, so the step count follows the parameter's
# leading dimension instead of a hard-coded 4.
n_latents = latents.shape[0]

for epoch in tqdm(range(NUM_ITERS)):
    t0 = _timestamp()
    # The initial noise equals latents0 in value; the zero-valued
    # `latents[:1] - latents[:1].detach()` term routes gradients back to latents[0].
    image = my_forward(
        pipe,
        prompt=prompt,
        num_inference_steps=n_latents,
        guidance_scale=3.0,
        latents=latents0 + latents[:1] - latents[:1].detach(),
        output_type='pt',
        return_dict=False,
        callback_on_step_end=callback,
    )[0][0]
    t1 = _timestamp()

    # HWC uint8 copy of the generated image. Computed once and reused for every detection
    # retry and for display, instead of a fresh GPU->CPU copy each time.
    image_np = np.array(image.permute(1, 2, 0).detach().cpu().numpy() * 255, dtype=np.uint8)

    # Lower the detection threshold in 0.1 steps until a face is found. Always restore it,
    # even if detect() raises, so the shared detector is not left mis-configured.
    det_thresh_backup = detector.det_thresh
    try:
        boxes = []
        while len(boxes) == 0:
            boxes, kpss = detector.detect(image_np, max_num=1)
            detector.det_thresh -= 0.1
        det_thresh_backup2 = detector.det_thresh + 0.1  # threshold that produced the detection
    finally:
        detector.det_thresh = det_thresh_backup

    t2 = _timestamp()
    # Align the detected face to the 112x112 template (differentiably, via kornia) and embed it.
    M = insightface.utils.face_align.estimate_norm(kpss[0])
    image_cropped = kornia.geometry.transform.warp_affine(
        image.float().unsqueeze(0), torch.tensor(M).float().unsqueeze(0).to('cuda'), (112, 112)
    ) * 2 - 1
    embedding = model(image_cropped)
    # Identity loss: cosine distance to the reference embedding, scaled by 100.
    loss = (1 - F.cosine_similarity(embedding, ref_embedding)) * 100
    t3 = _timestamp()

    optimizer.zero_grad()
    loss.backward()
    t4 = _timestamp()
    # Per-step gradient norms (reported before rescaling); any step whose norm exceeds 1
    # is rescaled to unit norm.
    grad_norm = latents.grad.reshape(n_latents, -1).norm(dim=-1)
    latents.grad /= grad_norm.reshape(n_latents, 1, 1, 1).clamp(min=1)
    # latents.grad.clamp_(min=-2e-2, max=2e-2)  # optional for removing artifacts
    optimizer.step()
    t5 = _timestamp()

    if epoch % LOG_EVERY == 0:
        print('loss:', loss.data.item())
        print('grad:', grad_norm.tolist())
        print('time:', t1-t0, t2-t1, '(%f)' % det_thresh_backup2, t3-t2, t4-t3, t5-t4)
        plt.imshow(image_np)
        plt.axis("off")
        plt.show()
    del image