In [None]:
from inversion_free import *
from utils import *
from transformers.models.clip.modeling_clip import CLIPTextModel
from diffusers import DDIMScheduler, StableDiffusionPipeline
from multi_token_clip import MultiTokenCLIPTokenizer

In [None]:
def init_model(
        pretrained_model_name_or_path,
        seed=0,
):
    """Build the Stable Diffusion pipeline and a random generator.

    Args:
        pretrained_model_name_or_path: Value forwarded to every
            ``from_pretrained`` call below (text encoder, tokenizer, pipeline).
        seed: Seed applied to the returned ``torch.Generator``. Pass ``None``
            to leave the generator unseeded.

    Returns:
        Tuple ``(generator, ldm_stable, device, tokenizer)`` where ``device``
        is the string ``"cuda"`` when CUDA is available and ``"cpu"`` otherwise.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ``revision`` expects a branch / tag / commit identifier; ``None`` selects
    # the library default. The previous value ``False`` is not a valid revision.
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=None,
    )
    tokenizer = MultiTokenCLIPTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="tokenizer",
    )

    # The scheduler is configured explicitly here instead of being loaded
    # from the checkpoint.
    scheduler = DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
    )
    ldm_stable = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        scheduler=scheduler,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
    ).to(device)
    # Replace the safety checker with a pass-through that hands the images
    # back unchanged and reports ``False``.
    ldm_stable.safety_checker = lambda images, clip_input: (images, False)

    generator = torch.Generator(device=device)
    if seed is not None:
        generator.manual_seed(seed)

    return generator, ldm_stable, device, tokenizer

In [None]:
# Set model_path to the path or name of the Stable-Diffusion-v1.5 model.
model_path = "runwayml/stable-diffusion-v1-5"
# model_path = "../../s"
# init_model returns (generator, pipeline, device, tokenizer); the seed is left at its default of 0.
generator, model, device, tokenizer = init_model(model_path)

In [None]:
# Paths and parameters for this run.
from datetime import datetime
import os

# Make sure the pipeline sits on the device returned by init_model.
model = model.to(device)

# --- Input images ---
# First-level directory, shared by the source and target images.
img_dir_src_1 = img_dir_tar_1 = "./examples/"
# Second-level directory names for the source and target images.
img_dir_src_2, tar_img_dir_2 = "dog", "dog_hat"
# Full directories searched for the source / target image.
src_img_dir = os.path.join(img_dir_src_1, img_dir_src_2)
tar_img_dir = os.path.join(img_dir_tar_1, tar_img_dir_2)


def get_image_file(path):
    """Return the name of one image file located directly inside ``path``.

    Files are matched case-insensitively on their extension
    (.png, .jpg, .jpeg, .bmp, .gif). Candidates are sorted so the choice is
    deterministic: ``os.listdir`` returns entries in arbitrary order.

    Args:
        path: Directory to search (not recursive).

    Returns:
        The file name (not the full path) of the first image in sorted order.

    Raises:
        FileNotFoundError: If ``path`` contains no file with a supported
            image extension.
    """
    img_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
    img_files = sorted(
        name for name in os.listdir(path) if name.lower().endswith(img_extensions)
    )
    if not img_files:
        raise FileNotFoundError(
            f"No image file with extension {img_extensions} found in {path!r}"
        )
    return img_files[0]


# Select one image file name from each of the source / target directories.
src_img_file = get_image_file(src_img_dir)
tar_img_file = get_image_file(tar_img_dir)

# Convert both images with img2latent (not defined in this notebook;
# presumably provided by one of the star-imported project modules).
src_latent = img2latent(os.path.join(src_img_dir, src_img_file), model, device)
tar_latent = img2latent(os.path.join(tar_img_dir, tar_img_file), model, device)

# --- Embeddings ---
# First-level directory, shared by the source and target embeddings.
embed_dir_src_1 = "./output/"
embed_dir_tar_1 = embed_dir_src_1

# Second-level directories.
embed_dir_src_2 = os.path.join(embed_dir_src_1, "dog")
tar_embed_dir_2 = os.path.join(embed_dir_tar_1, "dog_hat")

# Directories named after the individual embedding runs.
src_embed_dir = os.path.join(embed_dir_src_2, "dog_05_08_2024_1919")
tar_embed_dir = os.path.join(tar_embed_dir_2, "dog_hat_05_08_2024_2000")

# Training step count of the embeddings; it selects the "<steps>.bin" file.
src_steps = 1000
tar_steps = src_steps

# map_location lets a file saved from a CUDA device load on a CPU-only machine.
# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# installed torch version and its weights_only default -- only load trusted files.
src_embedding = torch.load(os.path.join(src_embed_dir, f"{src_steps}.bin"), map_location=device).to(device)
tar_embedding = torch.load(os.path.join(tar_embed_dir, f"{tar_steps}.bin"), map_location=device).to(device)

In [None]:
from attention_control import make_controller

# Create the output directory for this run and generate the image.
date = datetime.now().strftime("%Y-%m-%d")
output_dir = f"./output_img/{date}-{img_dir_src_2}-to-{tar_img_dir_2}/"
os.makedirs(output_dir, exist_ok=True)

# String timestamp used as the file-name prefix.
time = datetime.now().strftime("%H-%M-%S")
num_inference_steps = 30
save_all = False
use_attention = True

# Coefficients for the src and tar directions.
src_weight = 0
tar_weight = 1

# Strip the extension with splitext rather than a fixed [:-4] slice, which
# leaves a trailing dot for ".jpeg" files.
src_stem = os.path.splitext(src_img_file)[0]
tar_stem = os.path.splitext(tar_img_file)[0]
output_img_name = f"{time}_{src_stem}2{tar_stem}_{num_inference_steps}_steps_src{src_weight}_tar{tar_weight}"

if use_attention:
    # Configure the attention controller.
    placeholder = ["", "<s2>"]
    cross_injection_ratio = 0.
    self_injection_ratio = 0.7
    eq_param = {
        'words': (placeholder[-1],),
        'values': (0.5,),
    }
    controller = make_controller(
        prompts=placeholder,
        tokenizer=tokenizer,
        is_replace_controller=False,
        cross_replace_steps={'default_': cross_injection_ratio},
        self_replace_steps=self_injection_ratio,
        equilizer_params=eq_param,
    )
else:
    controller = None
# NOTE(review): `controller` is not passed to gen_inversion_free below --
# confirm whether make_controller takes effect through side effects or whether
# the controller still needs to be wired in.

# Generate the latents.
# src_latent = tar_latent
latents = gen_inversion_free(model,
                             src_latent,
                             tar_latent,
                             src_embedding,
                             tar_embedding,
                             num_inference_steps=num_inference_steps,
                             tar_weight=tar_weight, src_weight=src_weight,
                             )

if save_all:
    # Save one image per entry of `latents`, suffixed with its index.
    for i, latent in enumerate(latents):
        img = latent2img(latent.detach(), model)
        print(f"Saving {output_img_name}_{i}.png in {output_dir}")
        img.save(os.path.join(output_dir, f"{output_img_name}_{i}.png"))
else:
    # Save only the last entry of `latents`.
    img = latent2img(latents[-1].detach(), model)
    print(f"Saving {output_img_name}.png in {output_dir}")
    img.save(os.path.join(output_dir, f"{output_img_name}.png"))