In [None]:
# Install/upgrade the Python packages this notebook imports; pip output is discarded.
!pip install transformers compel accelerate gdown diffusers pyrebase4 -U > /dev/null
# Fetch two files from Google Drive by file id; gdown output is discarded.
!gdown 1IohKG23i468616bPHBkDHiBnZvmPW3m5 > /dev/null #required csv with new prompts that are not in the database yet
# NOTE(review): presumably the firebase_auth.json credentials file that load_db checks for - TODO confirm
!gdown 1pF1u8mekNs_z_KvFIJRTdqhlFHm1lp5n > /dev/null

In [None]:
import numpy as np
import random
import sys
import os
import torch
import torchvision
from PIL import Image, ImageDraw, ImageFont, ImageChops
from scipy.spatial.qhull import QhullError
from scipy import spatial
spatial.QhullError = QhullError
from tqdm.auto import tqdm
import io
import glob
import pandas as pd
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import zlib
import requests
from transformers import AutoProcessor, AutoModel
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
import threading
import os
import shutil
from torchvision.transforms.functional import pil_to_tensor
import pyrebase
from transformers import BlipProcessor, BlipForConditionalGeneration
from compel import Compel, ReturnedEmbeddingsType

In [None]:
def load_db():
    """Create the Firebase Realtime Database handle used by this notebook.

    The service-account file ``firebase_auth.json`` must be present in the
    current working directory.

    Returns:
        The object returned by ``firebase.database()`` for the configured
        project.

    Raises:
        FileNotFoundError: if ``firebase_auth.json`` does not exist.
    """
    config_path = 'firebase_auth.json'
    # Explicit exception instead of `assert`: assertions are stripped under
    # `python -O` and give no hint about which file is missing.
    if not os.path.exists(config_path):
        raise FileNotFoundError(
            f"Firebase service-account file not found: {config_path}"
        )

    config = {
      # NOTE(review): this key is committed with the notebook; consider
      # loading it from the environment instead.
      "apiKey": "AIzaSyBnWywH3ZswQNyLblBohBAp__f_F2myt5M",
      "authDomain": "datasetcollect-81ac0.firebaseapp.com",
      "databaseURL": "https://datasetcollect-81ac0-default-rtdb.firebaseio.com",
      "storageBucket": "datasetcollect-81ac0.appspot.com",
      # pyrebase looks the credentials path up under the lower-camel-case key
      # "serviceAccount"; the previous "ServiceAccount" spelling was ignored.
      "serviceAccount": config_path,
    }

    firebase = pyrebase.initialize_app(config)
    return firebase.database()

# Notebook-wide database handle; the database helper functions below read it as a global.
db = load_db()

In [None]:
# Upper bound on how many prompts a single run takes from the database.
GEN_BATCH_SIZE = 100


def set_generated(prompt_id):
    """Flag the ``prompts_raw`` record stored under ``prompt_id`` as generated."""
    generated_flag = {"is_generated": True}
    db.child("prompts_raw").child(prompt_id).update(generated_flag)

def load_not_generated_db_prompts(batch_size=None):
    """Claim up to one batch of not-yet-generated prompts from ``prompts_raw``.

    Every returned prompt is flagged through ``set_generated`` as soon as it
    is picked, i.e. before any image has been produced for it.
    NOTE(review): a crash later in the run therefore leaves these prompts
    flagged without output - confirm this is acceptable.

    Args:
        batch_size: maximum number of prompts to return. Defaults to
            ``GEN_BATCH_SIZE``.

    Returns:
        A list of prompt strings (at most ``batch_size`` long). Empty when
        the table is empty or holds nothing pending.
    """
    if batch_size is None:
        batch_size = GEN_BATCH_SIZE

    not_generated_db_prompts = []
    if batch_size <= 0:
        return not_generated_db_prompts

    # `.val()` yields None for an empty or missing node; treat that as "no records".
    records = db.child("prompts_raw").get().val() or {}

    for prompt_id, record in records.items():
        # A record that lacks the flag has never been processed.
        if record.get("is_generated", False):
            continue
        not_generated_db_prompts.append(record["prompt"])
        set_generated(prompt_id)
        if len(not_generated_db_prompts) >= batch_size:
            break

    return not_generated_db_prompts

# Take this run's batch of pending prompts from the database and show them.
prompts = load_not_generated_db_prompts()
print(prompts)

In [None]:
def drop_dots(prompts):
    """Remove a single trailing period from each prompt, in place.

    Args:
        prompts: mutable sequence of prompt strings; it is modified in place.

    Returns:
        The same ``prompts`` object, for call chaining.
    """
    for i, text in enumerate(prompts):
        # `endswith` is safe on an empty string, where indexing `[-1]` raised IndexError.
        if text.endswith("."):
            prompts[i] = text[:-1]
    return prompts

# Strip the trailing period from each prompt (suffixes are appended to the prompts later) and show the result.
prompts = drop_dots(prompts)
print(prompts)

In [None]:
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
unet_id = "mhdang/dpo-sdxl-text2image-v1"


def _load_dpo_sdxl_pipeline(device):
    """Build the SDXL base pipeline with the UNet from ``unet_id`` and place it on ``device``."""
    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
    )
    # Swap the UNet in before the single move to the GPU, so the base UNet
    # that gets discarded is never copied to the device.
    pipe.unet = UNet2DConditionModel.from_pretrained(
        unet_id, subfolder="unet", torch_dtype=torch.float16
    )
    return pipe.to(device)


# One independent pipeline per GPU; `unet1` / `unet2` stay available as before.
pipe1 = _load_dpo_sdxl_pipeline("cuda:0")
unet1 = pipe1.unet

pipe2 = _load_dpo_sdxl_pipeline("cuda:1")
unet2 = pipe2.unet

In [None]:
# Checkpoint name shared by the processor and both captioning models.
blip_version = "noamrot/FuseCap"

# A single processor instance is shared; the model is loaded twice, one copy per GPU.
blip_processor = BlipProcessor.from_pretrained(blip_version)
blip_model1 = BlipForConditionalGeneration.from_pretrained(blip_version).to("cuda:0")

blip_model2 = BlipForConditionalGeneration.from_pretrained(blip_version).to("cuda:1")

In [None]:
def generate_caption(raw_image, device):
    """Caption ``raw_image`` with the BLIP model assigned to ``device``.

    ``'cuda:0'`` selects ``blip_model1``; any other value selects
    ``blip_model2``. Returns the decoded caption with surrounding
    whitespace removed.
    """
    model_inputs = blip_processor(raw_image, return_tensors="pt").to(device)
    captioner = blip_model1 if device == 'cuda:0' else blip_model2
    token_ids = captioner.generate(**model_inputs)
    decoded = blip_processor.decode(token_ids[0], skip_special_tokens=True)
    return decoded.strip()

In [None]:
def _make_sdxl_compel(pipe):
    """Wrap both tokenizer / text-encoder pairs of ``pipe`` in a Compel instance."""
    return Compel(
        tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
        text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
        returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
        requires_pooled=[False, True]
    )


# One prompt-embedding helper per pipeline (and therefore per GPU).
base_compel1 = _make_sdxl_compel(pipe1)
base_compel2 = _make_sdxl_compel(pipe2)

In [None]:
def gen_sticker_with_mixed_prompts(prompt_sticker, caption, generator, device, sticker_cond = 'a cartoon sticker', x1=1.0, x2=0.55,x3=1.0):
    """Generate a sticker image from a weighted blend of three text prompts.

    The sticker prompt, the photo caption and ``sticker_cond`` are combined
    with Compel's blend syntax ``("a", "b", "c").blend(w1, w2, w3)`` and
    rendered with the pipeline assigned to ``device``; ``'low-quality'`` is
    used as the negative prompt.

    Args:
        prompt_sticker: sticker prompt text (first blend fragment, weight ``x1``).
        caption: caption of the reference photo (second fragment, weight ``x2``).
        generator: random generator forwarded to the pipeline call.
        device: ``'cuda:0'`` selects ``pipe1``/``base_compel1``; any other
            value selects ``pipe2``/``base_compel2``.
        sticker_cond: style-conditioning text (third fragment, weight ``x3``).
        x1, x2, x3: blend weights.

    Returns:
        The first image of the pipeline output.
    """
    # The opening parenthesis is required by the blend syntax; without it the
    # string is not a blend expression at all.
    # NOTE(review): a double quote inside a fragment would end that fragment
    # early - confirm the inputs cannot contain one.
    prompt = f'("{prompt_sticker}", "{caption}", "{sticker_cond}").blend({x1}, {x2}, {x3})'

    # Same steps for both GPUs; only the objects differ.
    if device == 'cuda:0':
        compel, pipe = base_compel1, pipe1
    else:
        compel, pipe = base_compel2, pipe2

    positive_embeds, positive_pooled = compel(prompt)
    negative_embeds, negative_pooled = compel('low-quality')
    # Positive and negative embeddings must have matching lengths for the pipeline.
    positive_embeds, negative_embeds = compel.pad_conditioning_tensors_to_same_length([
        positive_embeds, negative_embeds
    ])

    image_sticker_new = pipe(prompt_embeds=positive_embeds,
        pooled_prompt_embeds=positive_pooled,
        negative_prompt_embeds=negative_embeds,
        negative_pooled_prompt_embeds=negative_pooled,
        output_type="pil", guidance_scale=5, generator=generator).images[0]

    return image_sticker_new


def generate_seed_from_prompt(prompt):
    """Derive a deterministic integer seed: the Adler-32 checksum of the encoded prompt."""
    prompt_bytes = prompt.encode()
    checksum = zlib.adler32(prompt_bytes)
    return checksum

def get_rand_gen(prompt, device):
    """Return ``(generator, seed)``: a ``torch.Generator`` on ``device`` seeded from ``prompt``."""
    seed = generate_seed_from_prompt(prompt)
    rng = torch.Generator(device)
    rng.manual_seed(seed)
    return rng, seed

def get_pair_data(prompt_sticker, device):
    """Generate a photo, its caption and two sticker variants for one prompt.

    The photo prompt is the sticker prompt with the first ``'A sticker'``
    replaced by ``'A photo'`` plus a photography suffix; the sticker prompt
    gets a cartoon-style suffix. All generations share one generator that is
    seeded from the photo prompt.

    Returns:
        ``(sticker_prompt, sticker_image, mixed_sticker_image, photo_prompt,
        photo_image, seed, photo_caption)``.
    """
    photo_prompt = prompt_sticker.replace('A sticker', 'A photo', 1) + ', 8k, fullHD, realistic photography.'
    sticker_prompt = prompt_sticker + ', cartoon style.'
    generator, rnd_seed = get_rand_gen(photo_prompt, device)

    # 'cuda:0' -> first pipeline, anything else -> second pipeline.
    pipe = pipe1 if device == 'cuda:0' else pipe2

    photo = pipe(photo_prompt, guidance_scale=5, generator=generator).images[0]
    photo_caption = generate_caption(photo, device)
    plain_sticker = pipe(sticker_prompt, guidance_scale=5, generator=generator).images[0]
    mixed_sticker = gen_sticker_with_mixed_prompts(sticker_prompt, photo_caption, generator, device)

    return (sticker_prompt, plain_sticker, mixed_sticker, photo_prompt, photo, rnd_seed, photo_caption)


def runner(prompt_stickers, device, max_results=None):
    """Worker loop: generate data for each prompt on ``device``.

    Each tuple returned by ``get_pair_data`` is appended to the module-level
    list ``result``.

    Args:
        prompt_stickers: iterable of sticker prompts handled by this worker.
        device: device string forwarded to ``get_pair_data``.
        max_results: optional cap on the total number of results across both
            workers; this worker stops after ``max_results // 2`` items.
            When omitted, a notebook-level ``max_results`` variable is used
            if one exists; otherwise every prompt is processed.
    """
    if max_results is None:
        # The notebook never defines `max_results`; reading it as a bare
        # global raised NameError after the first item and silently ended
        # the worker thread. Look it up defensively instead.
        max_results = globals().get("max_results")

    # Half of the budget per worker; a non-positive value means "no limit".
    per_worker_limit = max_results // 2 if max_results else 0

    count = 0
    for prompt_sticker in tqdm(prompt_stickers):
        data = get_pair_data(prompt_sticker, device)
        count += 1
        result.append(data)

        if per_worker_limit > 0 and count >= per_worker_limit:
            break
    
def main(prompt_stickers):
    """Split the prompts into two halves and process them in two threads.

    The first half is handled by ``runner`` on ``'cuda:0'``, the rest on
    ``'cuda:1'``. Both threads are started, then joined. A
    ``KeyboardInterrupt`` raised while setting up or waiting is swallowed.
    """
    try:
        midpoint = len(prompt_stickers) // 2
        assignments = (
            (prompt_stickers[:midpoint], 'cuda:0'),
            (prompt_stickers[midpoint:], 'cuda:1'),
        )
        workers = [
            threading.Thread(target=runner, args=(chunk, device), daemon=False)
            for chunk, device in assignments
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
    except KeyboardInterrupt:
        pass


In [None]:
# Shared output list; both worker threads append their generated tuples to it.
result = []

# Run generation for the claimed prompts, one worker thread per GPU.
main(prompts)

In [None]:
# Create the output folder; unlike `mkdir`, this is a no-op when the cell is re-run.
os.makedirs("imgs", exist_ok=True)

In [None]:
def _image_file_name(prefix, prompt_text, seed, max_bytes=255):
    """Return ``<prefix><prompt_text>_<seed>.png`` made safe to use as one file name."""
    suffix = f"_{seed}.png"
    # A path separator inside a prompt would otherwise point into a
    # (non-existent) sub-directory; control characters are replaced as well.
    cleaned = "".join(
        "_" if ch in "/\\" or ord(ch) < 32 else ch for ch in prompt_text
    )
    # Many filesystems limit a file name to 255 bytes. Trim only the prompt
    # part so the prefix and the seed stay intact.
    budget = max(0, max_bytes - len((prefix + suffix).encode("utf-8")))
    cleaned = cleaned.encode("utf-8")[:budget].decode("utf-8", errors="ignore")
    return f"{prefix}{cleaned}{suffix}"


def dump_sticker_info(info):
    """Persist one generation result: three PNG files plus a database record.

    Args:
        info: the 7-tuple ``(prompt_sticker, image_sticker,
            image_sticker_mixed, prompt_photo, image_photo, rnd_seed,
            image_photo_caption)``.

    The images are written to the ``imgs`` directory (created if missing),
    then a record with the prompt, the caption and the three file names is
    pushed to ``prompts_generated``.
    """
    prompt_sticker, image_sticker, image_sticker_mixed, \
        prompt_photo, image_photo, rnd_seed, image_photo_caption = info

    sticker_record = {
        "prompt_sticker": prompt_sticker,
        "caption": image_photo_caption,
        "image_photo_name": _image_file_name("", prompt_photo, rnd_seed),
        "image_sticker_name": _image_file_name("", prompt_sticker, rnd_seed),
        "image_sticker_mixed_name": _image_file_name("mixed_", prompt_sticker, rnd_seed)
    }

    os.makedirs("imgs", exist_ok=True)

    # Write the files first so the database never references an image that
    # failed to save.
    image_photo.save(os.path.join("imgs", sticker_record["image_photo_name"]))
    image_sticker.save(os.path.join("imgs", sticker_record["image_sticker_name"]))
    image_sticker_mixed.save(os.path.join("imgs", sticker_record["image_sticker_mixed_name"]))

    db.child("prompts_generated").push(sticker_record)


In [None]:
# Persist every generated tuple: image files under imgs/ plus one database record each.
for info in result:
    dump_sticker_info(info)

In [None]:
# Pack the imgs directory into imgs_result.zip; zip output is discarded.
!zip -r imgs_result.zip imgs > /dev/null

In [None]:
# Toggle for the visual sanity check below and how many results to show.
show_res = True
N_display = 10

if show_res:
    for item in result[:N_display]:
        prompt_sticker, image_sticker1, image_sticker2, prompt, image, rnd_seed, img_cap = item

        fig, axs = plt.subplots(1, 3, figsize=(20*3, 10*3))

        # Left to right: photo, plain sticker, blended-prompt sticker (titled with the photo caption).
        panels = (
            (image, prompt),
            (image_sticker1, prompt_sticker),
            (image_sticker2, img_cap),
        )
        for ax, (picture, title) in zip(axs, panels):
            ax.imshow(picture)
            ax.set_title(title)
            ax.axis('off')

        plt.show()
        # Release the figure so open figures do not pile up across iterations.
        plt.close(fig)