In [None]:
import os, json
import argparse
from collections import OrderedDict
import torch, cv2
from PIL import Image

from diffusers import StableDiffusionPipeline
from templates.templates import inference_templates
import matplotlib.pyplot as plt
from typing import Optional, Union, Tuple, List, Callable, Dict
from diffusers import StableDiffusionPipeline, DDIMScheduler
import torch.nn.functional as nnf
import numpy as np
import abc
# import ptp_utils
# import seq_aligner
import shutil
from torch.optim.adam import Adam
from PIL import Image
from NullextInversion import NullInversion
import math
from lora import (
    save_lora_weight,
    TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
    get_target_module,
    save_lora_layername,
    monkeypatch_or_replace_lora,
    monkeypatch_remove_lora,
    set_lora_requires_grad,
    tune_lora_scale
)
from utils import encoder, p2p, ptp_utils

def make_image_grid(imgs, rows, cols):
    """Tile equally sized images into a single ``rows`` x ``cols`` mosaic.

    Args:
        imgs: Sequence of image arrays indexed as (height, width, channel).
            The first image's height and width define the tile size.
        rows: Number of tile rows in the mosaic.
        cols: Number of tile columns in the mosaic.

    Returns:
        A ``uint8`` array of shape ``(rows * h, cols * w, 3)`` holding the
        images in row-major order.

    Raises:
        AssertionError: If ``len(imgs)`` differs from ``rows * cols``.
    """
    assert len(imgs) == rows*cols

    first_shape = imgs[0].shape
    tile_h, tile_w = first_shape[0], first_shape[1]
    canvas = np.zeros((rows * tile_h, cols * tile_w, 3)).astype(np.uint8)

    for position, tile in enumerate(imgs):
        # Row-major placement: a new row starts every `cols` tiles.
        row, col = divmod(position, cols)
        top, left = row * tile_h, col * tile_w
        canvas[top:top + tile_h, left:left + tile_w, :] = tile.copy()
    return canvas

In [None]:
# Notebook-wide settings read by the cells below.
LOW_RESOURCE = False
NUM_DDIM_STEPS = 50
GUIDANCE_SCALE = 7.5
MAX_NUM_WORDS = 77
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Argument list used when parse_args() is called without an explicit argv.
_DEFAULT_NOTEBOOK_ARGV = ("--only_load_embeds", "--lora_scale", "0.6", "--lora", "--noise")


def parse_args(argv: Optional[List[str]] = None):
    """Build the inference argument parser and parse ``argv``.

    Args:
        argv: Command-line style tokens to parse. When ``None`` (the default)
            the notebook's fixed argument list ``_DEFAULT_NOTEBOOK_ARGV`` is
            used, which reproduces the original hard-coded behaviour. Pass an
            explicit list (possibly empty) to override it.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "--prompt",
        type=str,
        default="A man {} a man",
        help="input a single text prompt for generation",
    )
    parser.add_argument(
        "--template_name",
        type=str,
        help="select a batch of text prompts from templates.py for generation",
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default = "/internfs/xxxxxx/huggingface_models/stable-diffusion-v1-5",
        help="path to the pretrained Stable Diffusion model loaded by the pipeline",
    )
    parser.add_argument(
        "--exp_id",
        type=str,
        help="absolute path to the experiment folder that contains the learned relation embedding and lora weights",
    )
    parser.add_argument(
        "--inference_string",
        type=str,
        default="<R>",
        help="inference_string of the relation prompt",
    )
    parser.add_argument(
        "--placeholder_string",
        type=str,
        default="<R>",
        help="place holder string of the relation prompt",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=10,
        help="number of samples to generate for each prompt",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=7.5,
        help="scale for classifier-free guidance",
    )
    parser.add_argument(
        "--strength",
        type=float,
        default=0.5,
        help="strength for inversion",
    )
    parser.add_argument(
        "--only_load_embeds",
        action="store_true",
        default=False,
        help="If specified, the experiment folder only contains the relation prompt, but does not contain the entire folder",
    )
    parser.add_argument(
        "--lora",
        action="store_true",
        default=False,
        help="If specified, utilize lora weights for the generation",
    )
    parser.add_argument(
        "--lora_scale",
        type=float,
        default=0.8,
        help="scale for lora weights",
    )
    parser.add_argument(
        "--noise",
        action="store_true",
        default=False,
        help="If specified, utilize noise weights for the generation",
    )
    parser.add_argument(
        "--pretrain",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--pretrain_id",
        type=str,
        default = "/internfs/xxxxxx/concept_customization/noise_opti/shake_hands",
        help="absolute path to the folder that contains the trained results",
    )
    parser.add_argument(
        "--ori",
        action="store_true",
        default=False,
        help="If specified, utilize original noise weights for the generation",
    )
    parser.add_argument(
        "--splitid",
        type=str,
        default=None,
        help="id (file name without extension) of the exemplar image whose bbox and DDIM inversion noise are loaded",
    )
    parser.add_argument(
        "--paste",
        action="store_true",
        default=False,
        help="If specified, copy the trained noise into the specific patch(bbox) for the intial noise",
    )
    parser.add_argument(
        "--dp",
        action="store_true",
        help="to do a dp image training or a natural relation training",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default='./reversion_benchmark_v1/shake_hands',
        # required=True,
        help=
        "The folder that contains the exemplar images (and coarse descriptions) of the specific relation."
    )
    parser.add_argument(
        "--offset_scale",
        type=float,
        default=0.8,
        help="scale for offset_noise weights",
    )
    parser.add_argument(
        "--encoder",
        type=str,
        default=None,
        # required=True,
        help=
        "The folder that contains the encoder of ddim-style noise."
    )
    parser.add_argument(
        "--noise_mode",
        type=str,
        default='warp',
        help="mode to get the initial latent noise, such as random, DDIM, encoder, or else.",
    )

    # An explicit token list is always passed to the parser so that the
    # process's own sys.argv (kernel launcher arguments inside a notebook)
    # is never consumed.
    if argv is None:
        argv = _DEFAULT_NOTEBOOK_ARGV
    args = parser.parse_args(args=list(argv))
    return args
args = parse_args()

# args.lora_root = "noise_astar_diffe/n1.0_a0.0_m0.1_detach/shake_hands_t0.0002_lora0.0001"



In [None]:
# ---------------------------------------------------------------------------
# Notebook-level overrides of the parsed defaults.
# ---------------------------------------------------------------------------
args.lora_scale=0.6
args.exp_id = "your/path"     # lora weights

# args.exp_id = f"./noise_astar_diffe/n1.0_a0.0_m{mix_weights[e1]}_detach/shake_hands_t{lrs[e2]}_lora{loras[e3]}"

args.encoder = "your/path" # noise encoder

args.lora_root = args.exp_id

scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False, steps_offset = 1)
MY_TOKEN = ''
# Drop one trailing slash so the '_noise' suffix and the basename below are
# derived from the folder name itself.
if args.train_data_dir[-1] == '/':
    args.train_data_dir = args.train_data_dir[:-1]
noise_embeds_root = args.train_data_dir + '_noise'
# Name of the relation folder; used to locate the noise / encoder checkpoints.
# Defined here (not inside a branch) because it is needed on every code path.
exp_name = args.train_data_dir.split('/')[-1]

# Annotation files are read through context managers so the handles are closed.
with open(os.path.join(args.train_data_dir, 'text.json')) as local_f:
    templates = json.load(local_f, object_pairs_hook=OrderedDict)
# Map every exemplar file name to its position in text.json.
noise_dict = {key: index for index, key in enumerate(templates)}
with open(os.path.join(args.train_data_dir, 'bbox.json')) as local_f:
    all_bboxes = json.load(local_f, object_pairs_hook=OrderedDict)


# ---------------------------------------------------------------------------
# Create the inference pipeline.
# ---------------------------------------------------------------------------
if args.only_load_embeds:

    print('load relation prompt only')
    learned_embeds = None
    learned_lora = None
    # ddim_noise = None
    for filename in os.listdir(args.exp_id):
        if filename.endswith('loraemb.pth'):
            embed_path = os.path.join(args.lora_root, filename)
            learned_embeds = torch.load(embed_path, weights_only=True)
        elif filename.endswith('lora.pth'):
            lora_path = os.path.join(args.exp_id, filename)
            learned_lora = torch.load(lora_path, weights_only=True)
        # elif filename.endswith('noisemb.pth'):
    if learned_embeds is None:
        # Fail here with a clear message instead of a NameError further down.
        raise FileNotFoundError(
            f"no '*loraemb.pth' relation embedding found in {args.exp_id!r}"
        )
    noise_path = os.path.join(args.pretrain_id, f'{exp_name}_noisemb_ori.pth' if args.ori else f'{exp_name}_noisemb.pth')

    pipe = StableDiffusionPipeline.from_pretrained(args.model_id,torch_dtype=torch.float32, scheduler=scheduler).to(device)

    text_encoder = pipe.text_encoder
    tokenizer = pipe.tokenizer
    unet = pipe.unet

    # keep original embeddings as reference
    orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()

    # Add the placeholder token in tokenizer
    tokenizer.add_tokens(args.placeholder_string)
    text_encoder.get_input_embeddings().weight.data = torch.cat((orig_embeds_params, orig_embeds_params[0:1]))
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Let's make sure we don't update any embedding weights besides the newly added token
    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_string)
    index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
    text_encoder.get_input_embeddings().weight.data[index_no_updates] = orig_embeds_params
    text_encoder.get_input_embeddings().weight.data[placeholder_token_id] = learned_embeds[0]


    if args.lora and (learned_lora is not None) and args.lora_scale>0:
        unet_lora_params = None
        use_lora_extended = False
        lora_unet_rank = 32
        lora_txt_rank = 32
        injectable_lora = get_target_module("injection", use_lora_extended)
        target_module = get_target_module("module", use_lora_extended)

        # Inject the learned LoRA weights into the UNet, then set their scale.
        monkeypatch_or_replace_lora(unet, learned_lora, r = lora_unet_rank)

        tune_lora_scale(unet, args.lora_scale)
else:
    # now this works
    print('load full model')
    pipe = StableDiffusionPipeline.from_pretrained(args.model_id,torch_dtype=torch.float32).to(device)
    # The statements below use `unet` on both branches.
    unet = pipe.unet

unet.eval()
unet.enable_gradient_checkpointing()
# inversion for a existing image
null_inversion = NullInversion(pipe)
blend_word = None
key_words = ['man']

# blend_word = ((('people')))
eq_params = None
args.strength = 1.0

# The checkpoint is called like a module later on, so it holds a pickled
# object rather than a plain state dict. weights_only=False is spelled out
# because the torch.load default differs between PyTorch releases.
# NOTE(review): unpickling executes arbitrary code; only load trusted files.
noise_embedding_encoder = torch.load(
    os.path.join(args.encoder,'{}_noisewarpper.pth'.format(exp_name)),
    weights_only=False,
)



# single text prompt
# NOTE(review): prompt_list stays undefined when args.prompt is None.
if args.prompt is not None:
    prompt_list = [args.prompt]


In [None]:
# For copy-paste methods: randomly sample (or let the user pick) an exemplar
# image and load its DDIM inversion noise.
args.splitid = "7-1"

# Annotations are keyed by file name; try the PNG name first, then the JPG one.
png_name, jpg_name = f'{args.splitid}.png', f'{args.splitid}.jpg'
splitid_name = png_name if png_name in all_bboxes else jpg_name
if splitid_name not in all_bboxes:
    raise KeyError(f'no bbox annotation found for {png_name!r} or {jpg_name!r}')
bboxes = all_bboxes[splitid_name]
ddim_noise = torch.load(
    os.path.join(noise_embeds_root, f'{args.splitid}.pth'),
    map_location=device,
    weights_only=True,
)
# Position of the exemplar in text.json.
relation_index = noise_dict[splitid_name]

src_path = os.path.join(args.train_data_dir, splitid_name)
src_img = cv2.imread(src_path)
if src_img is None:
    # cv2.imread signals a missing/unreadable file by returning None.
    raise FileNotFoundError(f'could not read exemplar image: {src_path}')
h, w = src_img.shape[0], src_img.shape[1]

In [None]:
# Prompt built from the CLI template; it is replaced by the hard-coded prompt
# further down in this cell.
prompt = prompt_list[0]
prompt = prompt.lower().replace("<r>", "<R>").format(args.inference_string.replace('-',' '))


resolution = 512


# Reload the noise encoder from the folder used for this experiment.
# weights_only=False is explicit because the checkpoint is called like a module
# below (a pickled object, not a state dict) and the torch.load default differs
# between PyTorch releases.
# NOTE(review): unpickling executes arbitrary code; only load trusted files.
args.encoder = 'noise_new_encoder_shake_hands/'
noise_embedding_encoder = torch.load(
    os.path.join(args.encoder,'{}_noisewarpper.pth'.format(exp_name)),
    weights_only=False,
)


prompt = "a woman <R> a woman"
# random_seed = 123
# torch.cuda.manual_seed_all(random_seed)

# Fresh Gaussian noise with the same shape/device as the exemplar's DDIM noise.
noise = torch.randn_like(ddim_noise)
mix_noise = noise.clone()
noise_offsets = []
mix_weights = 1
# Created on `device` so it lives next to the other tensors of this cell.
R_attn_mask = torch.zeros((resolution//8, resolution//8)).to(device)

# Reference window: the exemplar's first bbox, given as
# (min_height, min_width, max_height, max_width) and divided by 8 to index the
# noise tensor.
min_height, min_width, max_height, max_width = bboxes[0]
mask = torch.zeros((1,1,noise.shape[2], noise.shape[3])).to(device)
mask[0:1,:,min_height//8:max_height//8, min_width//8:max_width//8] = 1

input_warp = torch.cat([mask, noise], 1)
output_warp = noise_embedding_encoder(input_warp)
base_encoder_noise = output_warp[:,:,min_height//8:max_height//8, min_width//8:max_width//8].clone()
# DDIM noise patch of the exemplar that the copy-paste baseline reuses.
base_mix_noise = ddim_noise[:,:,min_height//8:max_height//8, min_width//8:max_width//8].clone()

bbox_fix = []
mix_sets = []
encoder_sets = []
grid_size = 8

# Pixel offset applied to every bbox before it is pasted / encoded.
height_fix, width_fix = -8 * 8, 10 * 8

mix_noise = noise.clone()
for bbox in bboxes:
    min_height, min_width, max_height, max_width = bbox
    min_width += width_fix
    max_width += width_fix
    min_height += height_fix
    max_height += height_fix

    rows = slice(min_height//8, max_height//8)
    cols = slice(min_width//8, max_width//8)
    # A negative or oversized index would wrap around / be clipped by slicing
    # and paste into the wrong region, so reject it explicitly.
    if not (0 <= rows.start <= rows.stop <= noise.shape[2]
            and 0 <= cols.start <= cols.stop <= noise.shape[3]):
        raise ValueError(
            f'shifted bbox {[min_height, min_width, max_height, max_width]} '
            f'falls outside the {noise.shape[2]}x{noise.shape[3]} noise map'
        )

    # Copy-paste baseline: drop the exemplar's DDIM patch into the new window.
    mix_noise[:,:,rows, cols] = base_mix_noise

    R_attn_mask[rows, cols] = 1


    mask = torch.zeros((1,1,noise.shape[2], noise.shape[3])).to(device)
    mask[0:1,:,rows, cols] = 1

    # Encoder variant: blend the encoder output into the window, keep the
    # random noise everywhere else.
    # NOTE(review): encoder_noise is rebuilt from `noise` on every iteration,
    # so only the last bbox is reflected when several are present - confirm.
    input_warp = torch.cat([mask, noise], 1)
    output_warp = noise_embedding_encoder(input_warp)
    encoder_noise = mix_weights * output_warp * mask + (1 - mix_weights) *noise * mask + noise * (1-mask)


# Dilate the window mask with a ksize x ksize kernel.
ksize = 7
R_attn_mask = ptp_utils.tensor_dilate(R_attn_mask.view(1,1,resolution//8,resolution//8), ksize = ksize)[0,0].float()


In [None]:


# Baseline: plain random noise, with the relation token replaced by "and".
images = null_inversion.noise_inference(noise, prompt.replace('<R>', 'and'))
img0 = images[0].copy()

min_height, min_width, max_height, max_width = bboxes[0]
mix_imgs, encoder_imgs = [], []
src_mix_imgs, src_encoder_imgs = [], []

# Shifted target window as [min_height, min_width, max_height, max_width];
# cv2.rectangle expects its corner points as (x, y).
final_bbox = [min_height+height_fix, min_width+width_fix, max_height+height_fix, max_width+width_fix]
box_top_left = (final_bbox[1], final_bbox[0])
box_bottom_right = (final_bbox[3], final_bbox[2])
box_color, box_thickness = (255, 0, 0), 2

# Copy-paste noise: keep a clean copy and an annotated copy of the result.
images = null_inversion.noise_inference(mix_noise, prompt)
img1 = images[0].copy()
src_mix_imgs = images[0].copy()
cv2.rectangle(img1, box_top_left, box_bottom_right, box_color, box_thickness)
mix_img_ = img1

# Encoder noise: same treatment.
images = null_inversion.noise_inference(encoder_noise, prompt)
img2 = images[0].copy()
src_encoder_imgs = images[0].copy()
cv2.rectangle(img2, box_top_left, box_bottom_right, box_color, box_thickness)
encoder_img_ = img2

# cv2.rectangle(img0, (min_width-width_fix, min_height-height_fix ), (max_width-width_fix, max_height-height_fix ), (0, 0, 255), 2)

# Show the three results side by side, one per column, with the axes hidden.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    (img0, "base noise"),
    (mix_img_, "Copy-Paste"),
    (encoder_img_, "Ours"),
]
for ax, (panel_img, panel_title) in zip(axes, panels):
    ax.imshow(panel_img)
    ax.axis('off')
    ax.set_title(panel_title)

# Adjust the spacing between subplots automatically, then render the figure.
plt.tight_layout()
plt.show()

In [None]:
# Index used as the file-name stem for everything saved by this cell.
idi = 0

image_folder = "./2-results/shake_hands"
os.makedirs(image_folder, exist_ok=True)

name = os.path.join(image_folder, f'{idi}.pth')
if not os.path.exists(name):
    # Persist the random noise so the sample can be regenerated later.
    torch.save(noise, name)

    # (file-name suffix, image) pairs; images are converted RGB -> BGR because
    # cv2.imwrite expects BGR channel order.
    rgb_outputs = (
        ('src', img0),
        ('bbox-copypaste', mix_img_),
        ('bbox-encodernoise', encoder_img_),
        ('copypaste', src_mix_imgs),
        ('encodernoise', src_encoder_imgs),
    )
    for suffix, rgb_image in rgb_outputs:
        cv2.imwrite(os.path.join(image_folder, f'{idi}-{suffix}.png'), cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))

    # Settings needed to reproduce this sample, stored next to the images.
    paras = {
        'prompt': prompt,
        'bbox': final_bbox,
        'lora_scale': args.lora_scale,
        'exp_id': args.exp_id,
        'encoder': args.encoder,
    }

    para_name = f'{idi}.json'
    with open(os.path.join(image_folder, para_name), 'w') as f:
        json.dump(paras, f, indent=2, sort_keys=True, ensure_ascii=False)
    idi += 1
else:
    # Never overwrite an existing sample with the same index.
    print('ERROR: existing name of random noise!!!')
