In [None]:
import torch
import torch.nn as nn
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, ToPILImage
import torch.nn.functional as F

import time
import copy
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm

In [None]:
import os
# os.environ['http_proxy'] = "" 
# os.environ['https_proxy'] = ""
# Make `device_id` the current CUDA device so later bare 'cuda' placements land on it.
device_id = 0
torch.cuda.set_device(torch.device('cuda', device_id))

In [None]:
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from piecewise_rectified_flow.src.scheduler_perflow import PeRFlowScheduler

# PeRFlow SD 1.5 checkpoint (per the repo id); alternatives kept below for quick switching.
pipe = StableDiffusionPipeline.from_pretrained(
    "hansyan/perflow-sd15-dreamshaper",
    safety_checker=None,
    torch_dtype=torch.float16,
)
# pipe = StableDiffusionPipeline.from_pretrained("hansyan/perflow-sd15-realisticVisionV51", safety_checker=None, torch_dtype=torch.float16)
# pipe = StableDiffusionPipeline.from_pretrained("hansyan/perflow-sd15-disney", safety_checker=None, torch_dtype=torch.float16)

# Swap in the PeRFlow scheduler, built from the checkpoint's own scheduler config.
pipe.scheduler = PeRFlowScheduler.from_config(
    pipe.scheduler.config,
    prediction_type="diff_eps",
    num_time_windows=4,
)

In [None]:
# Freeze VAE, text encoder and UNet: the optimizer below only updates the latents.
for frozen_module in (pipe.vae, pipe.text_encoder, pipe.unet):
    frozen_module.requires_grad_(False)

pipe.to("cuda")
pipe.set_progress_bar_config(disable=True)

# The undecorated, unbound pipeline __call__ (later invoked as my_forward(pipe, ...)).
# NOTE(review): presumably this unwraps a @torch.no_grad() decorator on the pipeline's
# __call__ so gradients can reach the latents -- confirm against the installed diffusers.
my_forward = pipe.__call__.__wrapped__

In [None]:
from transformers import AutoImageProcessor, AutoModel, OwlViTProcessor, OwlViTForObjectDetection, AutoProcessor
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

# OWL-ViT: text-queried detection for the reference images, image-guided detection in the loop.
detector_processor = OwlViTProcessor.from_pretrained('google/owlvit-base-patch32')
# NOTE(review): AutoProcessor presumably resolves to the same OwlViTProcessor class; kept as a
# separate name because later cells use `detector_processor2`.
detector_processor2 = AutoProcessor.from_pretrained('google/owlvit-base-patch32')
detector = OwlViTForObjectDetection.from_pretrained('google/owlvit-base-patch32').to('cuda')
# The detector only localizes boxes; gradients never need to reach its weights, so freeze it
# (otherwise any forward pass outside torch.no_grad builds a graph over its parameters).
detector.requires_grad_(False)

# DINOv2: produces the embeddings compared in the identity loss; frozen as well.
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
model = AutoModel.from_pretrained('facebook/dinov2-base').to('cuda')
model.requires_grad_(False)

# Normalization constants as CUDA tensors. torch.as_tensor keeps this cell safe to re-run
# (on a second run the names already hold CUDA tensors and are returned unchanged).
OPENAI_CLIP_MEAN = torch.as_tensor(OPENAI_CLIP_MEAN, device='cuda')
OPENAI_CLIP_STD = torch.as_tensor(OPENAI_CLIP_STD, device='cuda')

In [None]:
from PIL import Image
import torchvision.transforms.functional as TF
import kornia

# Asset stems for the two subjects (images are read from assets/<stem>_00.jpg).
# ref1 = 'can'
# ref1 = 'cat'
ref1 = 'dog'
# ref1 = 'dog3'
ref2 = 'can'
# ref2 = 'cat'
# ref2 = 'dog'
# ref2 = 'dog3'


def extract_name(f):
    """Turn an asset stem into a readable noun: underscores become spaces, digits are dropped.

    e.g. 'dog3' -> 'dog', 'red_cartoon' -> 'red cartoon'.
    """
    return ''.join(ch for ch in f.replace('_', ' ') if not ch.isdigit())

# Detector text queries and RGB reference images for both subjects (subject 1 first).
ref_text1, ref_text2 = ('a photo of a %s' % extract_name(ref) for ref in (ref1, ref2))
ref_image1, ref_image2 = (Image.open('assets/%s_00.jpg' % ref).convert("RGB") for ref in (ref1, ref2))

def crop_image_embed(ref_image, ref_text):
    """Crop the object named by `ref_text` out of `ref_image` and embed the crop with DINOv2.

    OWL-ViT (text-queried detection) proposes boxes for `ref_text`; the highest-scoring
    box is cropped from the PIL image, run through the DINOv2 `processor`/`model`, and the
    first token of the model's first output is kept as the embedding.

    Args:
        ref_image: PIL image containing the reference subject.
        ref_text: detector text query, e.g. 'a photo of a dog'.

    Returns:
        (ref_image_cropped, ref_embedding): the PIL crop and its embedding tensor. The
        embedding carries no autograd graph.

    Raises:
        RuntimeError: if detector post-processing returns no boxes for `ref_text`.
    """
    # Pure inference on reference data, so no graph is needed for the detector or DINOv2.
    # torch.no_grad (not inference_mode): the embedding is later used inside a loss that is
    # back-propagated, and inference tensors cannot be saved for backward.
    with torch.no_grad():
        detector_inputs = detector_processor(text=ref_text, images=ref_image, return_tensors="pt")
        detector_inputs = {name: tensor.to(detector.device) for name, tensor in detector_inputs.items()}
        detector_outputs = detector(**detector_inputs)
        # PIL .size is (width, height); reversed to the (height, width) order for target_sizes.
        target_sizes = torch.Tensor([ref_image.size[::-1]])
        results = detector_processor.post_process_object_detection(outputs=detector_outputs, target_sizes=target_sizes)
        scores = results[0]['scores']
        if scores.numel() == 0:
            raise RuntimeError("OWL-ViT found no box for %r in the reference image" % ref_text)
        box = results[0]['boxes'][scores.argmax()].tolist()

        ref_image_cropped = ref_image.crop(box)
        pixel_values = processor(images=ref_image_cropped, return_tensors="pt")['pixel_values'].to(model.device)
        ref_embedding = model(pixel_values)[0][:, 0]
    return ref_image_cropped, ref_embedding

# Detect, crop and embed both reference subjects.
ref_image_cropped1, ref_embedding1 = crop_image_embed(ref_image1, ref_text1)
ref_image_cropped2, ref_embedding2 = crop_image_embed(ref_image2, ref_text2)

# Show both crops side by side as a visual sanity check of the detector boxes.
ref_fig, ref_axes = plt.subplots(1, 2)
for ref_ax, ref_crop in zip(ref_axes, (ref_image_cropped1, ref_image_cropped2)):
    ref_ax.imshow(ref_crop)
    ref_ax.axis("off")
plt.show()

In [None]:
from diffusers.utils.torch_utils import randn_tensor

# Fixed seed so the initial noise (and hence the whole run) is reproducible.
generator = torch.manual_seed(42)

# Trainable latent buffer with 4 rows (one per sampling step used below), in the text
# encoder's dtype and on the pipeline's execution device.
latents = nn.Parameter(
    randn_tensor(
        (4, 4, 64, 64),
        generator=generator,
        device=pipe._execution_device,
        dtype=pipe.text_encoder.dtype,
    )
)
# Copy of row 0 taken before any optimization: the value of the noise fed to the pipeline.
latents0 = latents.data[:1].clone()
# Plain SGD over the latent buffer (original note: lr of 1 or 2).
optimizer = torch.optim.SGD([latents], lr=2)

In [None]:
# Prompt template: '%s' is filled with "<subject 1> and a <subject 2>". Keep exactly one active.
# prompt = 'a %s in the jungle'
# prompt = 'a %s in the snow'
# prompt = 'a %s on the beach'
# prompt = 'a %s on a cobblestone street'
# prompt = 'a %s on top of pink fabric'
# prompt = 'a %s on top of a wooden floor'
# prompt = 'a %s with a city in the background'
# prompt = 'a %s with a mountain in the background'
# prompt = 'a %s with a blue house in the background'
# prompt = 'a %s on top of a purple rug in a forest'
# prompt = 'a %s wearing a red hat'
# prompt = 'a %s wearing a santa hat'
# prompt = 'a %s wearing a rainbow scarf'
# prompt = 'a %s wearing a black top hat and a monocle'
# prompt = 'a %s in a chef outfit'
# prompt = 'a %s in a firefighter outfit'
# prompt = 'a %s in a police outfit'
# prompt = 'a %s wearing pink glasses'
# prompt = 'a %s wearing a yellow shirt'
prompt = 'a %s in a purple wizard outfit'
# prompt = 'a red %s'
# prompt = 'a purple %s'
# prompt = 'a shiny %s'
# prompt = 'a wet %s'
# prompt = 'a cube shaped %s'

# prompt = 'a %s with a wheat field in the background'
# prompt = 'a %s with a tree and autumn leaves in the background'
# prompt = 'a %s with the Eiffel Tower in the background'
# prompt = 'a %s floating on top of water'
# prompt = 'a %s floating in an ocean of milk'
# prompt = 'a %s on top of green grass with sunflowers around it'
# prompt = 'a %s on top of a mirror'
# prompt = 'a %s on top of the sidewalk in a crowded street'
# prompt = 'a %s on top of a dirt road'
# prompt = 'a %s on top of a white rug'

prompt = prompt % ' and a '.join((extract_name(ref1), extract_name(ref2)))

In [None]:
# Cross-epoch bookkeeping used by `callback`: per-step snapshots of the latent handed
# forward (latents_last) and of the raw scheduler output (latents_last_e), both starting
# as copies of the current latents; initialized_i = -1 means no step has been seeded yet.
latents_last, latents_last_e = latents.data.clone(), latents.data.clone()
initialized_i = -1

def callback(self, i, t, callback_kwargs):
    """Per-step hook (passed as ``callback_on_step_end``) that splices the optimized latents in.

    Module-level state read/written here:

    * ``latents`` -- the optimized ``nn.Parameter``; after the first epoch, row ``i`` holds
      last epoch's raw step-``i`` output plus whatever ``optimizer.step()`` changed since.
    * ``latents_last`` -- per-step latent actually returned (handed onward) last epoch.
    * ``latents_last_e`` -- per-step raw scheduler output from last epoch.
    * ``initialized_i`` -- highest step index whose slots have already been seeded.

    Args:
        self: the pipeline (unused).
        i: index of the step that just finished.
        t: that step's timestep (unused).
        callback_kwargs: dict whose ``'latents'`` entry is the step output; updated in place.

    Returns:
        ``callback_kwargs`` with the corrected ``'latents'`` (presumably the pipeline continues
        sampling from this returned value -- the whole scheme relies on it).
    """
    global latents_last, latents_last_e, initialized_i
    # First time step i is reached (first epoch only): seed every per-step slot with this
    # step's output so all corrections below start at exactly zero. Sources are detached so
    # the plain buffers never become part of the autograd graph.
    if initialized_i < i:
        latents[i:(i+1)].data.copy_(callback_kwargs['latents'])
        latents_last[i:(i+1)].copy_(callback_kwargs['latents'].detach())
        latents_last_e[i:(i+1)].copy_(callback_kwargs['latents'].detach())
        initialized_i = i
    # Straight-through: adds exactly zero to the value but links this output (identity
    # Jacobian) to latents[i+1], so backward deposits the gradient w.r.t. the latent returned
    # from this step into latents.grad[i+1]. The final step has no following row.
    if i + 1 < latents.shape[0]:
        callback_kwargs['latents'] += latents[(i+1):(i+2)] - latents[(i+1):(i+2)].detach()
    # Raw scheduler output of this epoch (the straight-through term added zero).
    latents_e = callback_kwargs['latents'].data.clone()
    # Value := last epoch's returned latent, while keeping the current autograd path ...
    callback_kwargs['latents'] += latents_last[i:(i+1)].detach() - callback_kwargs['latents'].detach()
    # ... plus how far the raw output moved since last epoch ...
    callback_kwargs['latents'] += latents_e.detach() - latents_last_e[i:(i+1)].detach()
    # callback_kwargs['latents'] += latents[i:(i+1)].detach() - latents_last_e[i:(i+1)].detach()
    # ... plus the damped optimizer update of row i: latents[i] was reset to the raw output at
    # the end of last epoch (below), so this difference is what optimizer.step() changed.
    # NOTE(review): 0.95796674 is an undocumented magic constant -- record where it comes from.
    callback_kwargs['latents'] += (latents[i:(i+1)].detach() - latents_last_e[i:(i+1)].detach()) * 0.95796674
    # Snapshot this epoch's values. latents_last is written from a detached source: copying a
    # grad-requiring tensor in place into a plain buffer would make it require grad and keep
    # references to every epoch's graph, though it is only ever read detached.
    latents_last[i:(i+1)].copy_(callback_kwargs['latents'].detach())
    latents_last_e[i:(i+1)].data.copy_(latents_e)
    # Reset row i to the raw output so next epoch's difference isolates the optimizer step.
    latents[i:(i+1)].data.copy_(latents_e)
    return callback_kwargs

def _detect_box(image, query_image):
    """Find `query_image` in the generated `image` with OWL-ViT image-guided detection.

    Args:
        image: generated image tensor laid out (C, H, W) (the layout the crops and the
            ``permute(1, 2, 0)`` display below rely on); presumably values in [0, 1].
        query_image: PIL reference crop used as the visual query.

    Returns:
        The highest-scoring box as integer [x0, y0, x1, y1], rounded and clamped to the
        image bounds. Computed without gradients.
    """
    height, width = image.shape[1:]
    with torch.no_grad():
        detector_inputs = detector_processor2(images=(image * 255).int(), query_images=query_image, return_tensors="pt")
        detector_inputs = {key: value.to('cuda') for key, value in detector_inputs.items()}
        detector_outputs = detector.image_guided_detection(**detector_inputs)
        target_sizes = torch.Tensor([image.shape[1:]])
        results = detector_processor2.post_process_image_guided_detection(outputs=detector_outputs, target_sizes=target_sizes)
        box = results[0]['boxes'][results[0]['scores'].argmax()].tolist()
    # x coordinates are clamped to the width, y coordinates to the height (previously a
    # hard-coded 512, which only matched 512x512 outputs).
    bounds = (width, height, width, height)
    return [min(max(round(coord), 0), bound) for coord, bound in zip(box, bounds)]


# Reference crops resized once for the optional pixel-space L1 terms (loop-invariant).
ref_pixels1 = F.interpolate(TF.to_tensor(ref_image_cropped1).half().to('cuda').unsqueeze(0), (224, 224))
ref_pixels2 = F.interpolate(TF.to_tensor(ref_image_cropped2).half().to('cuda').unsqueeze(0), (224, 224))

for epoch in tqdm(range(51)):
    t0 = time.time()
    # Straight-through input: its value is the fixed starting noise latents0, but autograd
    # routes the gradient w.r.t. the input noise into latents[0].
    image = my_forward(pipe, prompt=prompt, num_inference_steps=4, guidance_scale=3.0, latents=latents0+latents[:1]-latents[:1].detach(), output_type='pt', return_dict=False, callback_on_step_end=callback)[0][0]
    t1 = time.time()

    # Locate both subjects; detection is gradient-free, but the crops are slices of `image`
    # itself, so the losses below stay differentiable w.r.t. the latents.
    box1 = _detect_box(image, ref_image_cropped1)
    box2 = _detect_box(image, ref_image_cropped2)
    image_cropped1 = TF.crop(image, box1[1], box1[0], box1[3]-box1[1], box1[2]-box1[0])
    image_cropped2 = TF.crop(image, box2[1], box2[0], box2[3]-box2[1], box2[2]-box2[0])
    t2 = time.time()

    # Identity losses: cosine distance between the first token of DINOv2's first output for
    # each crop and for its reference, plus optional pixel L1 terms (weights 0 and 1000).
    # NOTE(review): the reference embeddings went through DINOv2's own AutoImageProcessor,
    # whereas here crops are plainly resized and normalized with CLIP mean/std -- presumably
    # not the statistics that processor uses; confirm the mismatch is intended.
    image_processed = (F.interpolate(image_cropped1.unsqueeze(0), (224, 224)) - OPENAI_CLIP_MEAN[..., np.newaxis, np.newaxis]) / OPENAI_CLIP_STD[..., np.newaxis, np.newaxis]
    embedding = model(image_processed)[0][:, 0]
    loss1_1 = (1 - F.cosine_similarity(embedding, ref_embedding1)) * 100
    loss1_2 = F.l1_loss(F.interpolate(image_cropped1.unsqueeze(0), (224, 224)), ref_pixels1) * 0  # optional
    image_processed = (F.interpolate(image_cropped2.unsqueeze(0), (224, 224)) - OPENAI_CLIP_MEAN[..., np.newaxis, np.newaxis]) / OPENAI_CLIP_STD[..., np.newaxis, np.newaxis]
    embedding = model(image_processed)[0][:, 0]
    loss2_1 = (1 - F.cosine_similarity(embedding, ref_embedding2)) * 100
    loss2_2 = F.l1_loss(F.interpolate(image_cropped2.unsqueeze(0), (224, 224)), ref_pixels2) * 1000  # optional
    loss = loss1_1 + loss1_2 + loss2_1 + loss2_2
    t3 = time.time()

    optimizer.zero_grad()
    loss.backward()
    t4 = time.time()
    # Per-row gradient clipping: rows whose L2 norm exceeds 1 are rescaled to norm 1.
    grad_norm = latents.grad.reshape(4, -1).norm(dim=-1)
    latents.grad /= grad_norm.reshape(4, 1, 1, 1).clamp(min=1)
    optimizer.step()
    t5 = time.time()

    if epoch % 10 == 0:
        print('loss:', loss.data.item())
        print('grad:', grad_norm.tolist())
        print('time:', t1-t0, t2-t1, t3-t2, t4-t3, t5-t4)
        plt.imshow(np.array(image.permute(1, 2, 0).detach().cpu() * 255, dtype=np.uint8))
        plt.axis("off")
        plt.show()
    del image