In [None]:
import torch
import torch.nn as nn
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, ToPILImage
import torch.nn.functional as F

import time
import copy
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm

In [None]:
import os

# Uncomment and fill in if model downloads must go through an HTTP(S) proxy:
# os.environ['http_proxy'] = ""
# os.environ['https_proxy'] = ""

# Index of the CUDA device used by torch; also passed to the face detector's
# ONNX Runtime provider options.
device_id = 0
torch.cuda.set_device(f'cuda:{device_id}')

In [None]:
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from piecewise_rectified_flow.src.scheduler_perflow import PeRFlowScheduler

# PeRFlow SD1.5 checkpoint to load. Alternatives:
#   "hansyan/perflow-sd15-realisticVisionV51"
#   "hansyan/perflow-sd15-disney"
PERFLOW_CHECKPOINT = "hansyan/perflow-sd15-dreamshaper"

# fp16 weights, no safety checker.
pipe = StableDiffusionPipeline.from_pretrained(
    PERFLOW_CHECKPOINT,
    safety_checker=None,
    torch_dtype=torch.float16,
)

# Replace the checkpoint's scheduler with PeRFlow's, reusing its config.
pipe.scheduler = PeRFlowScheduler.from_config(
    pipe.scheduler.config,
    prediction_type="diff_eps",
    num_time_windows=4,
)

In [None]:
# Freeze every sub-network; only the latent tensors are optimised later.
for frozen in (pipe.vae, pipe.text_encoder, pipe.unet):
    frozen.requires_grad_(False)

pipe.to("cuda")
pipe.set_progress_bar_config(disable=True)

# diffusers decorates `__call__` with `torch.no_grad()`. `__wrapped__` is the
# undecorated function, so calling it as `my_forward(pipe, ...)` keeps the
# autograd graph and lets the loss backpropagate into the latents.
my_forward = pipe.__call__.__wrapped__

In [None]:
import insightface
from onnx2torch import convert

# Model files from the antelopev2 pack:
# https://github.com/deepinsight/insightface/tree/master/python-package#model-zoo

# SCRFD face detector (with keypoints), run through ONNX Runtime. The two
# option dicts presumably map to the GPU and CPU execution providers.
detector = insightface.model_zoo.get_model(
    'scrfd_10g_bnkps.onnx',
    provider_options=[{'device_id': device_id}, {}],
)
detector.prepare(ctx_id=0, input_size=(640, 640))

# Recognition network converted from ONNX into a frozen PyTorch module, so
# gradients can pass through it back to the generated image.
model = convert('glintr100.onnx').eval().to('cuda')
model.requires_grad_(False)

In [None]:
from PIL import Image
import torchvision.transforms.functional as TF
import kornia


# Reference identities. Available assets: hinton, bengio, schmidhuber,
# johansson, newton (all as 'assets/<name>.jpg'); any pair may be used.
ref1 = 'assets/hinton.jpg'
ref2 = 'assets/bengio.jpg'

ref_image1, ref_image2 = (Image.open(path).convert("RGB") for path in (ref1, ref2))

def crop_image_embed(ref_image):
    """Detect, align and embed the most prominent face in ``ref_image``.

    The detector threshold is lowered in 0.1 steps until SCRFD returns at
    least one face. The original ``detector.det_thresh`` is always restored,
    even if detection raises or the cell is interrupted mid-search.

    Args:
        ref_image: PIL image containing a face (callers pass RGB images).

    Returns:
        (ref_image_cropped, ref_embedding):
        - ref_image_cropped: the aligned 112x112 crop as a batched CUDA
          tensor, rescaled from [0, 1] to [-1, 1].
        - ref_embedding: the output of the recognition ``model`` on that crop.
    """
    image_np = np.array(ref_image)  # loop-invariant: convert once
    with torch.no_grad():
        det_thresh_backup = detector.det_thresh
        boxes = []
        try:
            while len(boxes) == 0:
                boxes, kpss = detector.detect(image_np, max_num=1)
                detector.det_thresh -= 0.1
        finally:
            # Never leave the shared detector with a lowered threshold.
            detector.det_thresh = det_thresh_backup
        # Similarity transform from the detected keypoints to insightface's
        # default (112x112) alignment template.
        M = insightface.utils.face_align.estimate_norm(kpss[0])
        ref_image_cropped = kornia.geometry.transform.warp_affine(
            TF.to_tensor(ref_image).unsqueeze(0).to('cuda'), torch.tensor(M).float().unsqueeze(0).to('cuda'), (112, 112)
        ) * 2 - 1  # [0, 1] -> [-1, 1]

        ref_embedding = model(ref_image_cropped)
    return ref_image_cropped, ref_embedding

ref_image_cropped1, ref_embedding1 = crop_image_embed(ref_image1)
ref_image_cropped2, ref_embedding2 = crop_image_embed(ref_image2)

# Map the aligned crops from [-1, 1] back to displayable uint8 HWC arrays.
cropped_image1, cropped_image2 = (
    np.array((crop[0] / 2 + 0.5).cpu().permute(1, 2, 0) * 255, dtype=np.uint8)
    for crop in (ref_image_cropped1, ref_image_cropped2)
)

# Show both aligned reference faces side by side.
fig, axes = plt.subplots(1, 2)
for ax, crop_np in zip(axes, (cropped_image1, cropped_image2)):
    ax.imshow(crop_np)
    ax.axis("off")
plt.show()

In [None]:
import tensorflow as tf
from deepface import DeepFace

# Hide all GPUs from TensorFlow (used by DeepFace) so it runs on CPU and does
# not allocate GPU memory alongside PyTorch.
tf.config.set_visible_devices([], device_type='GPU')

# Gender / race analysis for each reference image, in (ref1, ref2) order.
attribute1, attribute2 = (
    DeepFace.analyze(img_path=path, actions=['gender', 'race'])
    for path in (ref1, ref2)
)

In [None]:
# DeepFace returns one entry per detected face; keep the one with the largest
# bounding-box area in each reference image.
idx1, idx2 = (
    np.argmax([face['region']['w'] * face['region']['h'] for face in analysis])
    for analysis in (attribute1, attribute2)
)

for analysis, idx in ((attribute1, idx1), (attribute2, idx2)):
    print(analysis[idx]['dominant_gender'], analysis[idx]['dominant_race'])

In [None]:
from diffusers.utils.torch_utils import randn_tensor

generator = torch.manual_seed(42)

# Optimised variables: a (4, 4, 64, 64) tensor; the first dimension is indexed
# per denoising step by the step callback (4 inference steps are used).
latents = nn.Parameter(
    randn_tensor(
        (4, 4, 64, 64),
        generator=generator,
        device=pipe._execution_device,
        dtype=pipe.text_encoder.dtype,
    )
)
# Fixed starting noise; a private copy, unaffected by later in-place edits of `latents`.
latents0 = latents.data[:1].clone()

optimizer = torch.optim.SGD([latents], 1)  # learning rate: 1 or 2

In [None]:
# Prompt template with two 'person' placeholders, one per reference identity.
# Alternatives:
# prompt = 'A person and a person in the snow'
# prompt = 'a movie poster of a person and a person'
# prompt = "A person and a person on the beach"
# prompt = "A person and a person sitting in a park"
# prompt = "A person and a person holding a bottle of red wine"
# prompt = "A person and a person standing together"
# prompt = "A person and a person riding a horse"
prompt = "A person and a person having dinner together"

# Fill the placeholders left to right with "<race> man" / "<race> woman", as
# predicted by DeepFace for ref1 and then ref2.
for analysis, idx in ((attribute1, idx1), (attribute2, idx2)):
    subject = ' man' if analysis[idx]['dominant_gender'] == 'Man' else ' woman'
    prompt = prompt.replace('person', analysis[idx]['dominant_race'] + subject, 1)

prompt += ', closeup'
# prompt += ', faces'
# prompt = prompt + ', faces'

In [None]:
# Per-step state carried across optimisation epochs (re-running this cell resets it):
#   latents_last   -- slot i: latent handed on after step i (after the edits below) in the previous pass
#   latents_last_e -- slot i: raw latent produced by step i (before the edits) in the previous pass
#   initialized_i  -- highest step index whose slots have been seeded; -1 = none yet
latents_last = latents.data.clone()
latents_last_e = latents.data.clone()
initialized_i = -1

def callback(self, i, t, callback_kwargs):
    """diffusers `callback_on_step_end` hook: (pipeline, step index, timestep, callback_kwargs).

    Edits the latent produced by step ``i`` so that
      * gradients reach the optimised ``latents`` tensor, and
      * updates the optimiser made to ``latents[i]`` are applied to this step's latent
        (accumulated across epochs via ``latents_last``).
    Returns ``callback_kwargs`` with ``'latents'`` modified in place.
    """
    global latents_last, latents_last_e, initialized_i
    # First time step i is seen: seed all three slot-i buffers with this step's output.
    # This also overwrites latents[0]'s initial noise (the fixed copy is `latents0`).
    # With equal buffers, every offset added below is zero on this pass.
    if initialized_i < i:
        latents[i:(i+1)].data.copy_(callback_kwargs['latents'])
        latents_last[i:(i+1)].copy_(callback_kwargs['latents'])
        latents_last_e[i:(i+1)].copy_(callback_kwargs['latents'])
        initialized_i = i
    # Zero-valued term (x - x.detach()) that routes d(loss)/d(output of step i) into
    # latents[i+1]'s gradient. Not done for the last step (i == 3 with 4 steps).
    if i < 3:
        callback_kwargs['latents'] += latents[(i+1):(i+2)] - latents[(i+1):(i+2)].detach()
    # Snapshot of the raw (unedited) output of step i.
    latents_e = callback_kwargs['latents'].data.clone()
    # The next three lines only shift the VALUE (detached operands, so gradients pass
    # straight through to the raw output). The resulting value is:
    #   latents_last[i] + (latents_e - latents_last_e[i]) + 0.95796674 * (latents[i] - latents_last_e[i])
    # i.e. the previous pass's edited latent, plus how the raw output changed, plus the
    # optimiser's latest update to latents[i] (which held latents_last_e[i] before optimizer.step()).
    callback_kwargs['latents'] += latents_last[i:(i+1)].detach() - callback_kwargs['latents'].detach()
    callback_kwargs['latents'] += latents_e.detach() - latents_last_e[i:(i+1)].detach()
    # callback_kwargs['latents'] += latents[i:(i+1)].detach() - latents_last_e[i:(i+1)].detach()
    # NOTE(review): 0.95796674 is an undocumented scale on the optimiser update; its origin is
    # not shown here (perhaps tied to the PeRFlow time windows) -- confirm.
    callback_kwargs['latents'] += (latents[i:(i+1)].detach() - latents_last_e[i:(i+1)].detach()) * 0.95796674
    # Record this pass's edited and raw latents. Reset latents[i] to the raw output, so the
    # next optimizer.step() leaves latents[i] - latents_last_e[i] equal to just that step's update.
    latents_last[i:(i+1)].copy_(callback_kwargs['latents'])
    latents_last_e[i:(i+1)].data.copy_(latents_e)
    latents[i:(i+1)].data.copy_(latents_e)
    return callback_kwargs

# Optimise the per-step latents so the two generated faces match the two
# reference identities, using whichever face/identity assignment scores better.
# Loop-invariant: reference embeddings ordered (ref1, ref2, ref1, ref2).
ref_embeddings = torch.cat([ref_embedding1, ref_embedding2, ref_embedding1, ref_embedding2])

for epoch in tqdm(range(51)):
    t0 = time.time()
    # Run the undecorated pipeline so autograd reaches `latents`. The term
    # `latents[:1] - latents[:1].detach()` is zero-valued: every epoch starts from
    # the same noise `latents0`, but the gradient w.r.t. it is routed into latents[0].
    image = my_forward(pipe, prompt=prompt, num_inference_steps=4, guidance_scale=3.0, latents=latents0+latents[:1]-latents[:1].detach(), output_type='pt', return_dict=False, callback_on_step_end=callback)[0][0]
    t1 = time.time()

    # Detect two faces, lowering the detector threshold by 0.1 until two are returned.
    # `finally` restores the threshold even if detection raises or the loop is
    # interrupted (KeyboardInterrupt), so the shared detector is never left desensitised.
    image_uint8 = np.array(image.permute(1, 2, 0).detach().cpu().numpy() * 255, dtype=np.uint8)
    det_thresh_backup = detector.det_thresh
    boxes = []
    try:
        while len(boxes) <= 1:
            boxes, kpss = detector.detect(image_uint8, max_num=2)
            detector.det_thresh -= 0.1
        # Threshold at which the two faces were found (logged below).
        det_thresh_backup2 = detector.det_thresh + 0.1
    finally:
        detector.det_thresh = det_thresh_backup
    t2 = time.time()

    # Differentiably align both detected faces to 112x112 (kornia), rescale to [-1, 1], embed.
    M1 = insightface.utils.face_align.estimate_norm(kpss[0])
    image_cropped_1 = kornia.geometry.transform.warp_affine(
        image.float().unsqueeze(0), torch.tensor(M1).float().unsqueeze(0).to('cuda'), (112, 112)
    ) * 2 - 1
    M2 = insightface.utils.face_align.estimate_norm(kpss[1])
    image_cropped_2 = kornia.geometry.transform.warp_affine(
        image.float().unsqueeze(0), torch.tensor(M2).float().unsqueeze(0).to('cuda'), (112, 112)
    ) * 2 - 1
    embedding_1 = model(image_cropped_1)
    embedding_2 = model(image_cropped_2)
    # sim[0:2]: (face1, face2) vs (ref1, ref2); sim[2:4]: (face2, face1) vs (ref1, ref2).
    proposal_embeddings = torch.cat([embedding_1, embedding_2, embedding_2, embedding_1])
    sim = F.cosine_similarity(ref_embeddings, proposal_embeddings)
    # Zero when the better assignment has cosine similarity 1 for both faces.
    loss = (2 - max(sim[0] + sim[1], sim[2] + sim[3])) * 100
    t3 = time.time()

    optimizer.zero_grad()
    loss.backward()
    t4 = time.time()
    # Rescale each step's gradient so its L2 norm is at most 1 (smaller norms untouched).
    grad_norm = latents.grad.reshape(4, -1).norm(dim=-1)
    latents.grad /= grad_norm.reshape(4, 1, 1, 1).clamp(min=1)
    # latents.grad.clamp_(min=-2e-2, max=2e-2)  # optional for removing artifacts
    optimizer.step()
    t5 = time.time()

    if epoch % 10 == 0:
        print('loss:', loss.data.item())
        print('grad:', grad_norm.tolist())
        print('time:', t1-t0, t2-t1, '(%f)' % det_thresh_backup2, t3-t2, t4-t3, t5-t4)
        plt.imshow(np.array(image.permute(1, 2, 0).detach().cpu() * 255, dtype=np.uint8))
        plt.axis("off")
        plt.show()
    del image