In [None]:
%load_ext autoreload
%autoreload 2

import os
from PIL import Image, ImageOps
import requests
import torch
import matplotlib.pyplot as plt
import numpy as np
import json
import torch
import requests
from tqdm import tqdm
from io import BytesIO

from diffusersgrad import StableDiffusionInpaintPipeline
import torchvision.transforms as T

from utils import preprocess, recover_image
# Converter used later to turn attack output tensors back into PIL images.
to_pil = T.ToPILImage()

# Checkpoint to attack; the commented line is an alternative checkpoint id.
model_id_or_path = "runwayml/stable-diffusion-inpainting"
# model_id_or_path = "stabilityai/stable-diffusion-2-inpainting"

# Load the fp16 revision of the pipeline in half precision and move it to the GPU.
image_inpainting = StableDiffusionInpaintPipeline.from_pretrained(
    model_id_or_path, revision="fp16", torch_dtype=torch.float16
).to("cuda")

# prompts.json is indexed by image file name (the keys are joined onto the
# image/mask folders further down) and each value is indexed by prompt number.
with open("dataset/prompts.json", "r") as prompt_file:
    prompts_dict = json.load(prompt_file)
image_names = list(prompts_dict)

In [None]:
# Selected image ids for the image inpainting model.
# Each entry is an index into `image_names` (the keys of prompts.json); the
# experiment loop further down walks this list in order.
image_idxs = [0,2,7,8,16,17,19,20,23,25,26,27,31,33,35,36,41,42,43,44,45,46,49,50,51,52,54,55,56,58,61,63,67,71,
77,78,82,83,85,87,88,89,91,92,97,98,99,100,116,118,129,131,133,134,137,138,148,155,157,159,162,166,167,172,
178,182,183,184,186,189,192,193,195,197,199,201,206,207,208,209,212,213,214,215,216,217,231,235,237,239,240,
241,250,251,252,254,255,256,266,268,269,271,272,279,285,286,287,288,293,299,300,301,310,321,326,327,331,332,
333,338,339,343,345,349,351,355,356,361,363,364,367,369,370,375,377,378,380,382,383,387,390,391,397]

# A fixed, randomly selected seed used in all the experiments.
# It is re-applied with torch.manual_seed before every pipeline call below.
SEED = 9209
torch.manual_seed(SEED)
# Module-level sampling settings. NOTE(review): the code visible in this
# notebook does not read these three names; BIM_inpainting.generate/attack use
# their own keyword defaults -- confirm whether later cells rely on them.
strength = 0.7
guidance_scale = 7.5
num_inference_steps = 100

# Release cached CUDA blocks left over from loading the pipeline.
torch.cuda.empty_cache()
# Mean-squared-error criterion (default reduction) used as the feature distance.
MSE = torch.nn.MSELoss()

class BIM_inpainting(object):
    """Feature-space BIM (iterative sign-gradient) attack on an inpainting pipeline.

    A forward hook records the output of one sub-module of the pipeline. The
    attack perturbs the input image so that the features recorded for the
    perturbed image move *away* (in mean-squared error) from the features
    recorded for the clean image, while keeping the perturbation inside an
    L-infinity ball of radius ``epsilon`` around the clean image.

    Args:
        model: Inpainting pipeline. It must expose the hooked sub-module (by
            default ``model.vae.encoder``) and be callable with ``prompt``,
            ``image``, ``mask_image``, ``guidance_scale`` and
            ``num_inference_steps`` keyword arguments, returning an object with
            an ``images`` sequence.
            NOTE(review): the attack assumes the pipeline accepts a float
            tensor as ``image`` and keeps the autograd graph from that tensor
            to the hooked features (presumably what ``diffusersgrad``
            provides) -- confirm.
        epsilon: L-infinity bound on the perturbation, in [0, 1] pixel units.
        iteration: Number of gradient steps per ``attack`` call.
        step_length: Step size applied to the sign of the gradient.
        seed: Seed passed to ``torch.manual_seed`` before every pipeline call.
            ``None`` (default) uses the notebook-level ``SEED`` constant,
            looked up at call time.
        device: Device the attack tensors are moved to (default ``"cuda"``).
    """

    def __init__(self, model, epsilon=0.1, iteration=15, step_length=0.01, seed=None, device="cuda"):
        self.model = model
        # Filled by the forward hook on every forward pass of the hooked module.
        self.record_features = []
        self._register_model()
        self.eps = epsilon
        self.T = iteration
        self.step_length = step_length
        # Clean-image reference features; rebuilt at the start of each attack().
        self.feature_ori = []
        self.seed = seed
        self.device = device

    def _register_model(self):
        """Attach the feature-recording forward hook to the target sub-module."""
        def obtain_output_feature(module, feature_in, feature_out):
            # Keeps element 0 of the module output.
            # NOTE(review): for a tensor output this is the first batch entry,
            # for a tuple output it is the first item -- confirm per layer.
            self.record_features.append(feature_out[0])

        # Encoding
        self.hook = self.model.vae.encoder.register_forward_hook(obtain_output_feature) # encoder
        #self.hook = self.model.vae.quant_conv.register_forward_hook(obtain_output_feature) # quant conv

        # Unet
        #self.hook = self.model.unet.down_blocks[1].attentions[0].transformer_blocks[0].attn1.register_forward_hook(obtain_output_feature) # self-attn
        #self.hook = self.model.unet.down_blocks[1].attentions[0].transformer_blocks[0].attn2.register_forward_hook(obtain_output_feature) # cross-attn
        #self.hook = self.model.unet.down_blocks[1].attentions[0].transformer_blocks[0].ff.register_forward_hook(obtain_output_feature) # feed-forward
        #self.hook = self.model.unet.down_blocks[1].resnets[0].register_forward_hook(obtain_output_feature) # resnet

        # Decoding
        #self.hook = self.model.vae.post_quant_conv.register_forward_hook(obtain_output_feature) # post quant conv
        #self.hook = self.model.vae.decoder.register_forward_hook(obtain_output_feature) # decoder

        # other trials
        # downblock is the best choice
        #self.hook = self.model.unet.mid_block.attentions[0].transformer_blocks[0].attn1.register_forward_hook(obtain_output_feature)
        #self.hook = self.model.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn2.register_forward_hook(obtain_output_feature)

    def remove_hook(self):
        """Detach the forward hook so the model stops recording features.

        Call this before building another attacker on the same model;
        otherwise every attacker ever created keeps receiving (and holding on
        to) features from each forward pass.
        """
        self.hook.remove()

    def _seed(self):
        """Return the seed to use: the explicit one, else the notebook-level ``SEED``."""
        return SEED if self.seed is None else self.seed

    @staticmethod
    def _empty_cache():
        """Release cached CUDA memory; does nothing when CUDA is unavailable."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def generate(self, image, prompt, mask_image, strength=0.7,guidance_scale=7.5,num_inference_steps=100):
        """Run the pipeline once without gradients and return its first image.

        Args:
            image: Input image forwarded to the pipeline.
            prompt: Text prompt forwarded to the pipeline.
            mask_image: Inpainting mask forwarded to the pipeline.
            strength, guidance_scale, num_inference_steps: Forwarded unchanged.

        Returns:
            ``images[0]`` of the pipeline output.
        """
        # The hook fires during this call as well; start from an empty list so
        # repeated generate() calls do not pile up feature tensors.
        self.record_features = []
        with torch.no_grad():
            torch.manual_seed(self._seed())
            img_tmp = self.model(prompt=prompt, image=image, mask_image = mask_image, strength=strength, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps).images[0]
        return img_tmp

    def attack(self, ori_image, prompt, mask_image, strength=0.7,guidance_scale=7.5,num_inference_steps=15):
        """Craft an adversarial version of ``ori_image``.

        Args:
            ori_image: Clean image; anything ``np.array`` turns into an
                ``(H, W, C)`` array with values in 0..255 (e.g. a PIL image).
            prompt: Text prompt forwarded to the pipeline.
            mask_image: Mask in the same layout as ``ori_image``; only its
                first channel is passed on during the gradient passes.
            strength: Accepted for interface compatibility only; it is not
                forwarded to the pipeline.
            guidance_scale: Forwarded to the pipeline.
            num_inference_steps: Forwarded to the pipeline.

        Returns:
            Detached float tensor of shape ``(1, C, H, W)`` on ``self.device``
            with values in [0, 1] that differs from the clean image by at most
            ``self.eps`` per element.

        Raises:
            ValueError: If ``self.T`` is smaller than 1.
            RuntimeError: If the hook recorded no features during a gradient pass.
        """
        if self.T < 1:
            raise ValueError("iteration must be at least 1, got {}".format(self.T))

        # --- clean reference pass -------------------------------------------
        self.record_features = []
        torch.manual_seed(self._seed())
        with torch.no_grad():
            self.model(prompt=prompt, image=ori_image, mask_image = mask_image, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps)
        # Rebuild the reference list on every call. Appending to a list that
        # survives between calls would leave index 0 pointing at the features
        # of the very first image ever attacked.
        self.feature_ori = [feature.detach() for feature in self.record_features]

        # --- tensors for the optimisation -----------------------------------
        ori = np.array(ori_image).astype(np.float32) / 255.0
        ori = ori[None].transpose(0, 3, 1, 2)
        ori_image = torch.from_numpy(ori)
        ori_mask = np.array(mask_image).astype(np.float32) / 255.0
        ori_mask = ori_mask[None].transpose(0, 3, 1, 2)
        ori_mask_image = torch.from_numpy(ori_mask)[:,0,:,:]
        # initialize with a small noise to start attack
        adv_image = ori_image + torch.normal(0.0, 0.1, size=ori_image.shape)
        adv_image = torch.clamp(adv_image, 0.0, 1.0)
        adv_image = adv_image.to(self.device)
        ori_image = ori_image.to(self.device)
        ori_mask_image = ori_mask_image.to(self.device)

        # --- iterative sign-gradient ascent ---------------------------------
        for _ in range(self.T):
            self.record_features = []
            torch.manual_seed(self._seed())
            # Fresh leaf each step so the graph never chains across iterations.
            adv_image = adv_image.detach().requires_grad_()
            self._empty_cache()
            self.model(prompt=prompt, image=adv_image, mask_image = ori_mask_image, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps)

            if not self.record_features:
                raise RuntimeError("the forward hook recorded no features; cannot compute the attack loss")
            # Distance between perturbed and clean features, summed over all
            # recorded hook outputs; dtype and device follow the features.
            cost = sum(
                torch.nn.functional.mse_loss(adv_feature, ori_feature)
                for adv_feature, ori_feature in zip(self.record_features, self.feature_ori)
            )
            grad, = torch.autograd.grad(cost, [adv_image])

            with torch.no_grad():
                # Only the sign of the gradient is used, so no rescaling of
                # `grad` is needed beforehand.
                adv_image = adv_image + self.step_length * grad.sign()
                # Project back into the eps-ball and the valid pixel range.
                pert = torch.clamp(adv_image - ori_image, -self.eps, self.eps)
                adv_image = torch.clamp(ori_image + pert, 0.0, 1.0)

            # Drop references to this step's graph before the next forward pass.
            self.record_features = []
            del pert, grad, cost
            self._empty_cache()

        return adv_image.detach()

In [None]:
# Input/output locations.
image_folder = "dataset/images_crop/"
mask_folder = "dataset/mask_crop/"
folder = "data/inpainting/encoder/"
# The output path is nested, so create missing parent directories as well and
# do not fail when the directory is already there.
os.makedirs(folder, exist_ok=True)

attack = BIM_inpainting(image_inpainting)

# We do experiments on 100*5 data triplets (image, mask, prompt).
NUM_IMAGES = 100
PROMPTS_PER_IMAGE = 5
IMAGE_SIZE = (512, 512)

# Raise `start` to resume an interrupted run at a given position in image_idxs.
start = 0
stop = min(NUM_IMAGES, len(image_idxs))
for i in tqdm(range(start, stop), initial=start, total=stop, desc="images"):
    image_idx = image_idxs[i]
    image_name = image_names[image_idx]
    torch.cuda.empty_cache()
    prompts = prompts_dict[image_name]
    ori_image = Image.open(image_folder + image_name).convert('RGB').resize(IMAGE_SIZE)
    mask_image = Image.open(mask_folder + image_name).convert('RGB').resize(IMAGE_SIZE)
    # The stored masks are inverted before being handed to the pipeline.
    mask_image = ImageOps.invert(mask_image)
    # Strip the extension whatever its length (".png", ".jpeg", ...).
    stem = os.path.splitext(image_name)[0]
    for j, prompt in enumerate(prompts[:PROMPTS_PER_IMAGE]):
        torch.manual_seed(SEED)
        img = attack.attack(ori_image, prompt, mask_image)
        # attack() returns a batch of one image; convert it back to PIL.
        img = to_pil(img[0]).convert("RGB")
        save_path = folder + stem + "_" + str(j) + ".png"
        img.save(save_path)