In [None]:
from diffusers import DDIMScheduler
from diffusers import StableDiffusionPipeline, AutoencoderKL
import torch

In [None]:
# Compute device: use the GPU when one is visible, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
from multi_token_clip import MultiTokenCLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModel

# --- Checkpoint locations -------------------------------------------------
# Local path (or Hub name) of the Stable-Diffusion-v1.5 checkpoint.
model_or_name = "/root/autodl-tmp/model/"
# CLIP image encoder and IP-Adapter weights, consumed by the IPAdapter wrapper later on.
image_encoder_path = "/root/autodl-tmp/ip-composition-adapter/sd15_models/image_encoder"
ip_ckpt = "/root/autodl-tmp/ip-composition-adapter/sd15_models/ip-adapter_sd15.bin"

# --- Components and pipeline ----------------------------------------------
# The DDIM scheduler and the fp16 VAE are loaded separately and injected into the pipeline.
scheduler = DDIMScheduler.from_pretrained(model_or_name, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(model_or_name, subfolder="vae").to(dtype=torch.float16)
pipeline = StableDiffusionPipeline.from_pretrained(
    model_or_name,
    scheduler=scheduler,
    torch_dtype=torch.float16,
    vae=vae,
)

# --- Random seed ----------------------------------------------------------
# torch.Generator.manual_seed returns the generator itself, so creation and seeding chain.
seed = 4
generator = torch.Generator(device=device).manual_seed(seed)

# NOTE(review): this attaches an ad-hoc attribute to the pipeline object; presumably the
# project's sampling code reads it, since diffusers pipelines take `generator=` as a call
# argument instead — confirm.
pipeline.generator = generator

In [None]:
# 设置一些路径和参数
from inversion_free import gen_inversion_free, gen_inversion_free_test
from utils import img2latent, latent2img
from datetime import datetime
from PIL import Image
import os
from ip_adapter import IPAdapter

# Wrap the SD pipeline with IP-Adapter so that images can be encoded into prompt embeddings.
ip_pipeline = IPAdapter(pipeline, image_encoder_path, ip_ckpt, device)

# --- Input images ---------------------------------------------------------
# Images are looked up under <first-level dir>/<second-level dir>/.
# Source and target share the same first-level (root) directory.
img_dir_src_1 = "./DVCT/examples/"
img_dir_tar_1 = img_dir_src_1

# Second-level directories: one for the source image, one for the target image.
img_dir_src_2 = "dog"
tar_img_dir_2 = "cat_hat"

src_img_dir = os.path.join(img_dir_src_1, img_dir_src_2)
tar_img_dir = os.path.join(img_dir_tar_1, tar_img_dir_2)

def get_image_file(path, img_extensions=('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
    """Return the file name of the first image found in directory ``path``.

    Entries are matched case-insensitively by extension. Only regular files are
    considered. Matches are sorted by name so the choice is deterministic:
    ``os.listdir`` returns entries in arbitrary, filesystem-dependent order.

    Args:
        path: Directory to search.
        img_extensions: Iterable of accepted extensions, including the leading dot.

    Returns:
        The bare file name (not the joined path) of the alphabetically first image.

    Raises:
        FileNotFoundError: If ``path`` contains no matching image file.
    """
    # str.endswith needs a tuple; normalise case so user-supplied extensions also match.
    extensions = tuple(ext.lower() for ext in img_extensions)
    img_files = sorted(
        name
        for name in os.listdir(path)
        if name.lower().endswith(extensions) and os.path.isfile(os.path.join(path, name))
    )
    if not img_files:
        raise FileNotFoundError(
            f"No image file with extension in {extensions} found in {path!r}"
        )
    return img_files[0]

# Pick one image from each directory and build its full path once; the same path feeds
# both the PIL loader and img2latent below.
src_img_file = get_image_file(src_img_dir)
tar_img_file = get_image_file(tar_img_dir)
src_img_path = os.path.join(src_img_dir, src_img_file)
tar_img_path = os.path.join(tar_img_dir, tar_img_file)

pipeline = pipeline.to(device)

src_image = Image.open(src_img_path)
tar_image = Image.open(tar_img_path)

# IP-Adapter image embeddings; presumably (conditional, unconditional) given the `_uc` suffix.
src_embedding, src_embedding_uc = ip_pipeline.get_image_embeds(src_image)
tar_embedding, tar_embedding_uc = ip_pipeline.get_image_embeds(tar_image)
print(src_embedding.shape, tar_embedding.shape)

# img2latent is called with a file path (not a PIL image) plus the pipeline.
src_latent = img2latent(src_img_path, pipeline)
tar_latent = img2latent(tar_img_path, pipeline)

In [None]:
from attention_control import make_controller

# Set up the output directory, then generate.
# One timestamp drives both the directory date and the file-name time. Two separate
# datetime.now() calls could straddle midnight and disagree.
run_started = datetime.now()
date = run_started.strftime("%Y-%m-%d")
time = run_started.strftime("%H-%M-%S")  # NOTE: shadows the stdlib `time` module name
output_dir = f"./output_img/{date}-{img_dir_src_2}-to-{tar_img_dir_2}-IP/"
os.makedirs(output_dir, exist_ok=True)

num_inference_steps = 20
save_all = False  # True: save every latent returned by the sampler; False: only the last one

# File-name stems. os.path.splitext handles any extension length; slicing [:-4]
# left a trailing "." for ".jpeg" files.
src_img_stem = os.path.splitext(src_img_file)[0]
tar_img_stem = os.path.splitext(tar_img_file)[0]

for attn in [False]:
    for cfg in [3]:
        # Whether to use the attention controller.
        # NOTE(review): attention control is not passed to gen_inversion_free_test below,
        # so this flag currently only tags the output file name. The earlier commented-out
        # controller code referenced names no longer defined here (placeholders, tokenizer).
        use_attention = attn

        # Weighting scheme type and CFG scale.
        # NOTE(review): `inclination` and `mode` are not passed to gen_inversion_free_test;
        # they only appear in the output file name.
        inclination = "none-tar"
        mode = "cosine"
        cfg_guidance_scale = cfg

        # Coefficients along the src and tar directions (also file-name only at present).
        src_coefficient = 0
        tar_coefficient = 0

        output_img_name = (
            f"{time}_{src_img_stem}2{tar_img_stem}_{num_inference_steps}steps_{inclination}_{mode[:3]}"
            f"_cfg{cfg_guidance_scale}_src{src_coefficient}_tar{tar_coefficient}_attn{use_attention}"
        )

        # Generate: sample from the source latent, conditioned on the target image embedding.
        latents = gen_inversion_free_test(
            pipeline, src_latent, tar_embedding, tar_embedding_uc,
            num_inference_steps=num_inference_steps,
            cfg_guidance=cfg_guidance_scale, return_all=save_all,
        )

        if save_all:
            # One image per returned latent, suffixed with its index.
            for i, latent in enumerate(latents):
                img = latent2img(latent[0].detach(), pipeline)
                print(f"Saving {output_img_name}_{i}.png in {output_dir}")
                img.save(os.path.join(output_dir, f"{output_img_name}_{i}.png"))
        else:
            img = latent2img(latents[-1].detach(), pipeline)
            print(f"Saving {output_img_name}.png in {output_dir}")
            img.save(os.path.join(output_dir, f"{output_img_name}.png"))