In [1]:
%load_ext autoreload
%autoreload 2

import os
from PIL import Image, ImageOps
import requests
import torch
import matplotlib.pyplot as plt
import numpy as np
import json
import torch
import requests
from tqdm import tqdm
from io import BytesIO
from diffusersgrad import StableDiffusionImg2ImgPipeline, StableDiffusionInstructPix2PixPipeline
import torchvision.transforms as T

from utils import preprocess, recover_image
# Converter from tensors to PIL images.
to_pil = T.ToPILImage()

# Checkpoint used for the image-variation pipeline.
# Other checkpoints that were tried with the same pipeline class:
#   "CompVis/stable-diffusion-v1-4"
#   "stabilityai/stable-diffusion-2-1"
model_id_or_path = "runwayml/stable-diffusion-v1-5"

# Load the fp16 revision in half precision and place the pipeline on the GPU.
# InstructPix2Pix variant (disabled): set
#   model_id_or_path = "timbrooks/instruct-pix2pix"
# and build the pipeline with
#   StableDiffusionInstructPix2PixPipeline.from_pretrained(
#       model_id_or_path, revision="fp16", torch_dtype=torch.float16)
image_variation = StableDiffusionImg2ImgPipeline.from_pretrained(
    model_id_or_path, revision="fp16", torch_dtype=torch.float16
).to("cuda")

# prompts.json is indexed below as prompts_dict[image_name][j], i.e. it maps
# each image file name to a sequence of prompts.
with open("dataset/prompts.json", "r") as f:
    prompts_dict = json.load(f)
image_names = [*prompts_dict.keys()]
print(image_names[0])

# Fixed subset of positions in `image_names` used for the experiments; the
# generation loops below take a prefix of this list. The criterion by which
# these indices were chosen is not recorded in this notebook.
image_idxs = [0,6,7,10,13,17,20,21,25,27,32,34,35,36,38,39,40,44,46,47,48,49,51,54,55,57,59,67,68,70,72,77,78,79,
80,81,82,84,87,89,92,94,96,97,98,103,104,107,108,111,112,114,115,116,119,120,123,124,128,129,132,134,
139,140,141,143,144,151,154,155,158,159,165,172,173,175,176,177,178,180,182,184,185,186,187,189,190,192,195,
197,198,200,201,204,206,208,209,211,212,213,214,215,218,219,222,224,227,230,231,237,239,240,241,244,249,252,
253,254,257,263,264,265,266,268,269,271,275,278,281,282,286,289,294,296,298,303,304,305,306,309,314,318,319,
320,322,323,331,332,334,335,341,343,345,349,350,353,355,360,361,370,373,375,376,377,383,393,394,395,397,398]

# Sampling hyper-parameters; these values mirror the keyword defaults of
# BIM.generate().
strength = 0.7
guidance_scale = 7.5
num_inference_steps = 100

# One fixed, arbitrarily chosen seed shared by all experiments.
SEED = 9209
torch.manual_seed(SEED)

# Mean-squared-error loss module.
MSE = torch.nn.MSELoss()

# Release cached, currently unused GPU memory before the heavy cells run.
torch.cuda.empty_cache()

class BIM:
    """Wrap a diffusion pipeline and capture the output of one UNet sub-module.

    On construction a forward hook is attached to
    ``model.unet.down_blocks[1].resnets[0]``; every forward pass through that
    module appends ``output[0]`` to ``record_features``.

    Args:
        model: Pipeline object. It must expose ``.unet`` and be callable with
            ``prompt``, ``image``, ``strength``, ``guidance_scale`` and
            ``num_inference_steps`` keyword arguments, returning an object
            with an ``images`` sequence.
        epsilon: Stored as ``self.eps``. Not read by any method of this class
            (presumably the perturbation budget of the attack -- confirm).
        iteration: Stored as ``self.T``. Not read by any method of this class
            (presumably the number of attack iterations -- confirm).
        step_length: Stored unchanged. Not read by any method of this class
            (presumably the per-iteration step size -- confirm).
    """

    def __init__(self, model, epsilon=0.1, iteration=15, step_length=0.01):
        self.model = model
        # Filled by the forward hook; must exist before the hook is registered.
        self.record_features = []
        self._register_model()
        self.eps = epsilon
        self.T = iteration
        self.step_length = step_length
        # Not written or read by any method of this class.
        self.feature_ori = []

    def _register_model(self):
        """Attach the feature-recording forward hook and keep its handle in ``self.hook``."""
        def obtain_output_feature(module, feature_in, feature_out):
            # Index 0 of the module output (presumably the first sample of the
            # batch -- confirm against the hooked module's return type).
            self.record_features.append(feature_out[0])

        # Encoding
        #self.hook = self.model.vae.encoder.register_forward_hook(obtain_output_feature) # encoder
        #self.hook = self.model.vae.quant_conv.register_forward_hook(obtain_output_feature) # quant conv

        # Unet
        #self.hook = self.model.unet.down_blocks[1].attentions[0].transformer_blocks[0].attn1.register_forward_hook(obtain_output_feature) # self-attn
        #self.hook = self.model.unet.down_blocks[1].attentions[0].transformer_blocks[0].attn2.register_forward_hook(obtain_output_feature) # cross-attn
        #self.hook = self.model.unet.down_blocks[1].attentions[0].transformer_blocks[0].ff.register_forward_hook(obtain_output_feature) # feed-forward
        self.hook = self.model.unet.down_blocks[1].resnets[0].register_forward_hook(obtain_output_feature) # resnet

        # Decoding
        #self.hook = self.model.vae.post_quant_conv.register_forward_hook(obtain_output_feature) # post quant conv
        #self.hook = self.model.vae.decoder.register_forward_hook(obtain_output_feature) # decoder

        # other trials
        # downblock is the best choice
        #self.hook = self.model.unet.mid_block.attentions[0].transformer_blocks[0].attn1.register_forward_hook(obtain_output_feature)
        #self.hook = self.model.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn2.register_forward_hook(obtain_output_feature)

    def remove_hook(self):
        """Detach the forward hook so the wrapped module stops recording.

        Each ``BIM`` built on the same pipeline registers its own hook, so
        call this before discarding an instance to avoid stacking hooks.
        """
        self.hook.remove()

    def generate(self, image, prompt, strength=0.7, guidance_scale=7.5, num_inference_steps=100, seed=None):
        """Run the wrapped pipeline once without gradients and return the first image.

        Args:
            image: Input image forwarded to the pipeline as ``image``.
            prompt: Text prompt forwarded to the pipeline as ``prompt``.
            strength: Forwarded to the pipeline unchanged.
            guidance_scale: Forwarded to the pipeline unchanged.
            num_inference_steps: Forwarded to the pipeline unchanged.
            seed: Value passed to ``torch.manual_seed`` right before sampling.
                ``None`` (the default) uses the notebook-level ``SEED``.

        Returns:
            ``images[0]`` of the pipeline result.
        """
        if seed is None:
            seed = SEED
        # The hook fires on every UNet forward of the sampling loop. Those
        # activations are not used by generation, so remember where the list
        # ended and drop whatever this call adds; otherwise a long generation
        # loop keeps every activation tensor alive.
        n_kept = len(self.record_features)
        try:
            with torch.no_grad():
                torch.manual_seed(seed)
                result = self.model(
                    prompt=prompt,
                    image=image,
                    strength=strength,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                )
            return result.images[0]
        finally:
            del self.record_features[n_kept:]

YES


`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


000000276284.jpg
dict_keys(['info', 'licenses', 'images', 'annotations', 'categories'])


In [2]:
# Generate edited images for adversarial samples
folder = "data/variation/resnet/"
folder_save = "data/variation/resnet_generate/"
# makedirs(exist_ok=True) also creates missing parent directories and has no
# gap between an existence check and the creation.
os.makedirs(folder_save, exist_ok=True)

attack = BIM(image_variation)
# BIM.__init__ registers a forward hook that appends an activation to
# attack.record_features on every UNet forward pass. Plain generation never
# reads those features, so detach the hook: otherwise the list grows for all
# 500 generations and the hook stays on the shared pipeline afterwards.
attack.hook.remove()

# We do experiments on 100*5 data pairs
for image_idx in image_idxs[:100]:
    torch.cuda.empty_cache()
    image_name = image_names[image_idx]
    # File naming assumes a 4-character suffix such as ".jpg".
    stem, ext = image_name[:-4], image_name[-4:]
    prompts = prompts_dict[image_name]
    for j in range(5):
        ori_image = Image.open(os.path.join(folder, f"{stem}_{j}.png")).convert('RGB').resize((512,512))
        prompt = prompts[j]
        # generate() reseeds torch with SEED itself, so every image starts
        # from the same random state. Pass the config-cell hyper-parameters
        # explicitly so that editing them there actually takes effect.
        img = attack.generate(
            ori_image,
            prompt,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        )
        save_path = os.path.join(folder_save, f"{stem}_{j}{ext}")
        img.save(save_path)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


In [None]:
# Generate edited images for benign inputs
folder = "dataset/images/"
folder_save = "data/variation/ori_image_generate/"
# makedirs(exist_ok=True) also creates missing parent directories and has no
# gap between an existence check and the creation.
os.makedirs(folder_save, exist_ok=True)

attack = BIM(image_variation)
# BIM.__init__ registers a forward hook that appends an activation to
# attack.record_features on every UNet forward pass. Plain generation never
# reads those features, so detach the hook: otherwise the list grows for every
# generation and the hook stays on the shared pipeline afterwards.
attack.hook.remove()

# We do experiments on 100*5 data pairs: use the same 100-image prefix as the
# adversarial run so every adversarial result has a benign counterpart
# (this loop previously covered only the first 2 images).
for image_idx in image_idxs[:100]:
    torch.cuda.empty_cache()
    image_name = image_names[image_idx]
    # File naming assumes a 4-character suffix such as ".jpg".
    stem, ext = image_name[:-4], image_name[-4:]
    prompts = prompts_dict[image_name]
    # The benign source image is the same for all prompts of this entry, so
    # load and resize it once.
    ori_image = Image.open(os.path.join(folder, image_name)).convert('RGB').resize((512,512))
    for j in range(5):
        prompt = prompts[j]
        # generate() reseeds torch with SEED itself, so every image starts
        # from the same random state. Pass the config-cell hyper-parameters
        # explicitly so that editing them there actually takes effect.
        img = attack.generate(
            ori_image,
            prompt,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        )
        save_path = os.path.join(folder_save, f"{stem}_{j}{ext}")
        img.save(save_path)