In [None]:
import wplus_utils
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
import ptp_utils

In [None]:
# DDIM sampler used for both inversion and generation.
# NOTE(review): these betas / offsets should match the noise schedule the
# checkpoint was trained with — verify against the model's scheduler config.
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load Stable Diffusion v1.4 from the local Hugging Face cache only (no network
# access; the weights must already be cached). `force_download=True` is not
# passed: it contradicts `local_files_only=True` (one forbids network access,
# the other demands a fresh download), and huggingface_hub can reject that
# combination outright.
ldm_stable = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    local_files_only=True,
    scheduler=scheduler,
).to(device)

# Turn xformers memory-efficient attention off.
# NOTE(review): presumably so the attention controllers in wplus_utils see the
# standard attention path — confirm.
try:
    ldm_stable.disable_xformers_memory_efficient_attention()
except AttributeError:
    # Pipelines without this method (e.g. older diffusers) just report it.
    print("Attribute disable_xformers_memory_efficient_attention() is missing")
tokenizer = ldm_stable.tokenizer

In [None]:
# Example image and the text prompt that describes it.
image_path = "./example_images/a black bear climb a tree in rain.png"
prompt = "a black bear climb a tree in rain"

# Build the W+ matrix inversion for the loaded pipeline.
# lambda_norm is 0.0 here; an earlier note suggested 1e-7 for derain runs.
matrix_inversion = wplus_utils.MatrixInversion(
    ldm_stable,
    inner_steps_num=None,
    lambda_norm=0.0,
    use_freq=False,
    use_attn_loss=False,
)

# invert() returns a nested tuple; the inverted latent `x_t` and `w_matrices`
# are consumed by the editing cell below.
(image_gt, image_enc), x_t, uncond_embeddings, w_matrices = matrix_inversion.invert(
    image_path,
    prompt,
    offsets=(0, 0, 0, 0),
    num_inner_steps=10,
    verbose=True,
    learning_rate=1e1,
)
# Optional (disabled): w_modify_end(w_matrices, 0.5)

In [None]:
# Settings for the derain / reconstruction runs below.
# "rain" is the text handed to the pipeline as the negative prompt.
negative_prompt = "rain"
# Forwarded as `tao` to run_and_display for the derain run; its meaning is
# defined in wplus_utils.
tao = 0.9
# Both attention-replacement fractions are disabled (0.0).
cross_replace_steps = self_replace_steps = 0.0
prompts = [prompt]

# Derain pass: run the inverted W+ latent with `negative_prompt` and visualize
# its cross-attention.
controller = wplus_utils.WplusAttentionStore(
    cross_replace_steps=cross_replace_steps,
    self_replace_steps=self_replace_steps,
)
# run_and_display also returns a latent. It is bound to its own name instead
# of overwriting `x_t`: previously the derain run's returned latent was fed
# into the reconstruction pass, and re-running this cell chained latents
# whenever the returned latent differed from the input. Both passes now start
# from the `x_t` produced by the inversion cell.
image_derain, x_t_derain = wplus_utils.run_and_display(
    ldm_stable, prompts, controller,
    run_baseline=False, latent=x_t, uncond_embeddings=None,
    optimize_matrices=w_matrices, negative_prompt=negative_prompt,
    verbose=False, tao=tao,
)
wplus_utils.show_cross_attention(controller, 16, ["up", "down"], prompts, ldm_stable, 0, negative_prompt=negative_prompt)

# Reconstruction pass: same latent and W+ matrices, no negative prompt.
controller = wplus_utils.AttentionStore()
image_inv, x_t_inv = wplus_utils.run_and_display(
    ldm_stable, prompts, controller,
    run_baseline=False, latent=x_t, uncond_embeddings=None,
    optimize_matrices=w_matrices, verbose=False,
)
wplus_utils.show_cross_attention(controller, 16, ["up", "down"], prompts, ldm_stable, 0)
# Optional (disabled): show_self_attention_comp(controller, 16, ["up", "down"], 10, 0)

print("showing from left to right: the ground truth image, w+ reconstruction, w+ derain")
ptp_utils.view_images([image_gt, image_inv[0], image_derain[0]])

# Reconstruction quality: compare the W+ reconstruction (image_inv[0]) with the
# original image via PSNR, SSIM and MSE.
p = wplus_utils.compare_psnr(image_gt, image_inv[0])
# channel_axis=2 tells SSIM that axis 2 holds the color channels of an
# (H, W, C) image (assumed RGB here — confirm). The legacy `multichannel`
# keyword is no longer passed. This assumes compare_ssim is scikit-image's
# structural_similarity, which it appears to be since it accepts channel_axis.
# In that function `multichannel` was deprecated in 0.19 in favour of
# channel_axis, removed in 0.21, and redundant with channel_axis anyway.
s = wplus_utils.compare_ssim(image_gt, image_inv[0], channel_axis=2)
m = wplus_utils.compare_mse(image_gt, image_inv[0])

print('PSNR：{}，SSIM：{}，MSE：{}'.format(p, s, m))