In [None]:
import torch
from diffusers import StableDiffusionInpaintPipeline,DDIMScheduler
from torchvision.io import read_image, ImageReadMode
import torch.nn.functional as F
from torchvision.transforms.functional import gaussian_blur
from pytorch_lightning import seed_everything
import os
from torchvision.utils import save_image
import cv2
from matplotlib import pyplot as plt

In [2]:
# Switch to the "AttentiveEraser" repository root (the parent of the notebook's
# folder); later cells use paths relative to it, e.g. ./pipelines/pipeline_inp.py.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
# Only move up when the working directory is not already the repository root.
# Without this guard, re-running the cell climbs one more level each time and
# breaks every relative path used afterwards.
if not os.path.isfile(os.path.join(current_dir, "pipelines", "pipeline_inp.py")):
    os.chdir(parent_dir)

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Half precision is only used on the GPU; the CPU fallback runs in float32
# because many CPU kernels are missing or very slow in float16.
dtype = torch.float16 if device.type == "cuda" else torch.float32
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)

# Local checkpoint used for offline loading. Point this at another local model
# folder (e.g. "stable-diffusion-v1-5", "solarsync_v11") to switch models.
LOCAL_MODEL_PATH = "/hy-tmp/stable-diffusion-2-1-base"
# Hub id used when the local folder is not present on this machine.
HUB_MODEL_ID = "stabilityai/stable-diffusion-2-1-base"
model_path = LOCAL_MODEL_PATH if os.path.isdir(LOCAL_MODEL_PATH) else HUB_MODEL_ID

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    model_path,
    scheduler=scheduler,
    custom_pipeline="./pipelines/pipeline_inp.py",
    torch_dtype=dtype,
)
pipe.enable_attention_slicing()
if device.type == "cuda":
    # Model offloading handles device placement itself, so the pipeline is not
    # moved to the GPU manually first (that would only cost an extra transfer).
    pipe.enable_model_cpu_offload()
else:
    # No accelerator available: keep the whole pipeline on the CPU.
    pipe.to(device)

In [4]:
# FreeU can further improve results in some cases (https://github.com/ChenyangSi/FreeU).
from utils import register_free_upblock2d, register_free_crossattn_upblock2d

# Values marked "2.1" in the original notebook. The variant marked "1.5" only
# differs in b1: dict(b1=1.5, b2=1.6, s1=0.9, s2=0.2).
freeu_kwargs = dict(b1=1.4, b2=1.6, s1=0.9, s2=0.2)
register_free_upblock2d(pipe, **freeu_kwargs)
register_free_crossattn_upblock2d(pipe, **freeu_kwargs)

In [None]:

def load_image(image_path, device, size=(512, 512)):
    """Load an image file as a normalized RGB batch of one.

    Args:
        image_path: Path of the image file to decode.
        device: Device the returned tensor is moved to.
        size: Target ``(height, width)`` of the resized image. Defaults to
            ``(512, 512)``; pass ``(768, 768)`` for the 768 px setup.

    Returns:
        Tensor of shape ``(1, 3, size[0], size[1])`` in the notebook-level
        ``dtype``. Pixel values are mapped from [0, 255] to [-1, 1]; bicubic
        resampling may overshoot that range slightly.
    """
    # Decoding straight to RGB drops an alpha channel and replicates grayscale
    # into three channels. It also covers gray+alpha files, which the former
    # slice-and-expand approach could not handle.
    image = read_image(image_path, mode=ImageReadMode.RGB)
    image = image.unsqueeze(0).float() / 127.5 - 1.0  # [-1, 1]
    image = F.interpolate(image, size, mode="bicubic")
    return image.to(dtype).to(device)

def load_mask(mask_path, device, size=(512, 512)):
    """Load a mask file as a binary single-channel batch of one.

    Args:
        mask_path: Path of the mask image; it is decoded as grayscale.
        device: Device the returned tensor is moved to.
        size: Target ``(height, width)`` of the resized mask. Defaults to
            ``(512, 512)``; keep it equal to the size used for the image.

    Returns:
        Tensor of shape ``(1, 1, size[0], size[1])`` in the notebook-level
        ``dtype`` whose values are exactly 0 or 1.
    """
    mask = read_image(mask_path, mode=ImageReadMode.GRAY)
    mask = mask.unsqueeze(0).float() / 255.0  # [0, 1]
    mask = F.interpolate(mask, size, mode="bicubic")
    # Blur before thresholding: with the low 0.1 cut-off this grows the masked
    # region slightly past its original border.
    mask = gaussian_blur(mask, kernel_size=(7, 7))
    mask = (mask >= 0.1).to(dtype)  # re-binarize: 1 where >= 0.1, else 0
    return mask.to(device)

seed = 123
seed_everything(seed)
# Create the generator on the device selected for the pipeline instead of a
# hard-coded "cuda", so the cell also works on machines without a GPU.
generator = torch.Generator(device).manual_seed(seed)

sample = "an"
out_dir = f"./workdir_inp/{sample}/"
os.makedirs(out_dir, exist_ok=True)
# Number the run folder after the current entry count, then skip forward past
# any folder that already exists so an earlier run is never overwritten
# (the plain count can collide once older run folders have been deleted).
sample_count = len(os.listdir(out_dir))
while os.path.exists(os.path.join(out_dir, f"sample_{sample_count}")):
    sample_count += 1
out_dir = os.path.join(out_dir, f"sample_{sample_count}")
os.makedirs(out_dir, exist_ok=True)

# Source image and the mask of the region to erase.
SOURCE_IMAGE_PATH = f"./examples/img/{sample}.png"
MASK_PATH = f"./examples/mask/{sample}_mask.png"
prompt = ""
source_image = load_image(SOURCE_IMAGE_PATH, device)
mask_an = load_mask(MASK_PATH, device)


In [None]:
from AAS.AAS import AAS ,AAS_768
from AAS.AAS_utils import regiter_attention_editor_diffusers
# Sampling schedule.
strength = 0.8
num_inference_steps = 50

# Step window passed to the AAS editor; the end step scales with `strength`.
START_STEP = 0
END_STEP = int(strength * num_inference_steps)

# Layer window passed to the AAS editor.
# Layer numbering: 0~5 down blocks, 6 mid block, 7~15 up blocks.
LAYER = 7       # layer at which AAS starts
END_LAYER = 16  # layer at which AAS ends
layer_idx = [*range(LAYER, END_LAYER)]

attentionstore = None
ss_steps = 9    # similarity suppression steps
ss_scale = 0.3  # similarity suppression scale

# Hijack the attention modules of the pipeline with the AAS editor.
# For the 768 px setup, build the editor with AAS_768 and the same arguments.
editor = AAS(
    attentionstore,
    START_STEP,
    END_STEP,
    LAYER,
    END_LAYER,
    layer_idx=layer_idx,
    mask=mask_an,
    ss_steps=ss_steps,
    ss_scale=ss_scale,
)
regiter_attention_editor_diffusers(pipe, editor)

In [None]:
rm_guidance_scale = 9  # removal guidance scale

# Collect the call arguments first so they can be inspected or tweaked
# before the (expensive) pipeline run.
pipe_kwargs = dict(
    prompt=prompt,
    image=source_image,
    mask_image=mask_an,
    num_inference_steps=num_inference_steps,
    strength=strength,
    generator=generator,
    rm_guidance_scale=rm_guidance_scale,
    guidance_scale=1,
    return_intermediates=False,
)
image = pipe(**pipe_kwargs)

In [8]:
def make_redder(img, mask, increase_factor=0.4):
    """Return a copy of ``img`` with channel 0 brightened inside the mask.

    Args:
        img: Image tensor whose first dimension is the channel axis.
        mask: Tensor expandable to the shape of ``img``; positions equal to 1
            select the pixels to tint.
        increase_factor: Amount added to channel 0 at the selected pixels.

    Returns:
        A new tensor; channel 0 is raised by ``increase_factor`` and clamped
        to [0, 1] where the mask is 1. The input ``img`` is left untouched.
    """
    tinted = img.clone()
    selected = mask.expand_as(img)[0] == 1
    first_channel = tinted[0]  # view into `tinted`, so writes land in the copy
    first_channel[selected] = torch.clamp(first_channel[selected] + increase_factor, 0, 1)
    return tinted
# Map the source image from [-1, 1] back to [0, 1] and drop the batch axis,
# then tint the masked region so it is visible in the comparison grid.
img = (source_image * 0.5 + 0.5).squeeze(0)
mask_red = mask_an.squeeze(0)
img_redder = make_redder(img=img, mask=mask_red)

In [9]:
from torchvision.transforms.functional import to_pil_image, to_tensor
from PIL import Image, ImageFilter
# Build a feathered blending mask: blur the binary mask with PIL, then merge
# it with the original so the result stays 1 wherever the binary mask is 1
# and fades out smoothly around it.
blur_filter = ImageFilter.GaussianBlur(radius=15)
pil_mask = to_pil_image(mask_an.squeeze(0))
pil_mask_blurred = pil_mask.filter(blur_filter)
mask_blurred = to_tensor(pil_mask_blurred).unsqueeze(0).to(mask_an.device)
# 1 - (1 - a) * (1 - b) is 1 where either mask is 1 and 0 only where both are 0.
msak_f = 1 - (1 - mask_an) * (1 - mask_blurred)

In [None]:
# Blend the last pipeline output back into the source image using the
# feathered mask (source image mapped from [-1, 1] to [0, 1] first).
source_01 = source_image * 0.5 + 0.5
out_tile = msak_f * image[-1:] + (1 - msak_f) * source_01

# Comparison grid: tinted source | raw pipeline output | blended result.
out_image = torch.cat([img_redder.unsqueeze(0), image[-1:], out_tile], dim=0)

# Save the full grid and then each panel on its own.
file_suffix = f"_step{END_STEP}_layer{LAYER}.png"
grid_path = os.path.join(out_dir, "all" + file_suffix)
save_image(out_image, grid_path)
for panel_idx, panel_name in enumerate(("source", "AE", "AE_tile")):
    save_image(out_image[panel_idx], os.path.join(out_dir, panel_name + file_suffix))
print("Syntheiszed images are saved in", out_dir)

# Read the saved grid back (OpenCV loads BGR) and show it inline.
img_ori = cv2.cvtColor(cv2.imread(grid_path), cv2.COLOR_BGR2RGB)
plt.figure(figsize=(18, 24))
plt.imshow(img_ori)