-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
120 lines (96 loc) · 4.11 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import argparse
import glob
import os
import random
import warnings

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from detector_cleanse import detector_cleanse
from model import FasterRCNNVGG16
from transform import preprocess
warnings.filterwarnings("ignore")
def parse_arguments():
    """Declare the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace carrying ``n``, ``m``, ``delta``, ``alpha``,
        ``iouthresh``, ``image_path``, ``clean_feature_path`` and ``weight``.
        ``--n``, ``--m``, ``--delta`` and ``--weight`` are mandatory; argparse
        exits with an error message when any of them is missing.
    """
    cli = argparse.ArgumentParser(description='Run Detector Cleanse on an image')
    # (flag, add_argument keyword options) in the order they appear in --help.
    option_specs = (
        ('--n', dict(type=int, required=True, help='Number of features to randomly select')),
        ('--m', dict(type=float, required=True, help='Detection mean')),
        ('--delta', dict(type=float, required=True, help='Detection threshold')),
        ('--alpha', dict(type=float, default=0.5, help='Blending ratio')),
        ('--iouthresh', dict(type=float, default=0.5, help='Threshold iou')),
        ('--image_path', dict(type=str, default='images', help='Path to the image(s) to be analyzed')),
        ('--clean_feature_path', dict(type=str, default='clean_feature_images', help='Path to the clean_feature image folder')),
        ('--weight', dict(type=str, required=True, help='Path to weight of the model')),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def _load_image(path):
    """Read an image file and return it run through ``preprocess``.

    The image is converted to RGB and read as a float32 array of shape
    (H, W, 3). It is then transposed to channel-first (3, H, W) before being
    handed to ``preprocess``; what that returns is defined in ``transform``.
    """
    # The context manager closes the file handle even if decoding fails, so
    # looping over many images does not leak descriptors.
    with Image.open(path) as f:
        pixels = np.asarray(f.convert('RGB'), dtype=np.float32)
    return preprocess(pixels.transpose((2, 0, 1)))


def _load_clean_features(feature_dir, n):
    """Randomly pick ``n`` ``*.jpg`` images from ``feature_dir`` and load them.

    Raises:
        ValueError: if ``n`` is negative or larger than the number of
            ``*.jpg`` files found (checked up front so the message is clear).
    """
    # Sorted so that selection depends only on the RNG, not on filesystem order.
    feature_files = sorted(glob.glob(f'{feature_dir}/*.jpg'))
    if not 0 <= n <= len(feature_files):
        raise ValueError(
            f"--n must be between 0 and {len(feature_files)} "
            f"(number of .jpg files in '{feature_dir}'), got {n}"
        )
    return [_load_image(path) for path in random.sample(feature_files, n)]


def _load_model(weight_path):
    """Build ``FasterRCNNVGG16`` and load the checkpoint at ``weight_path``.

    Accepts either a dict wrapping the weights under a ``'model'`` key or a
    bare state dict (legacy format).
    """
    # NOTE(review): n_fg_class=20 presumably matches the 20 PASCAL VOC classes
    # the checkpoint was trained on — confirm.
    model = FasterRCNNVGG16(n_fg_class=20)
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
    # load_state_dict copies the values into the model's own parameters anyway.
    # torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = torch.load(weight_path, map_location='cpu')
    if 'model' in state_dict:
        model.load_state_dict(state_dict['model'])
    else:  # legacy way, for backward compatibility
        model.load_state_dict(state_dict)
    return model


def _evaluate_directory(image_dir, model, clean_features, args):
    """Run Detector Cleanse on every ``*.jpg`` in ``image_dir`` and print statistics.

    Ground truth comes from the file name. A file whose name contains
    ``"modified"`` counts as poisoned; any other file counts as clean.

    The progress bar shows:
    - running accuracy;
    - FAR, the share of poisoned images reported clean;
    - FRR, the share of clean images reported poisoned.

    Afterwards the raw counts are printed in the order total_clean,
    total_poison, success, false_accept, false_reject.
    """
    image_files = sorted(glob.glob(f'{image_dir}/*.jpg'))
    total_clean = 0
    total_poison = 0
    false_accept = 0
    false_reject = 0
    success = 0
    pbar = tqdm(image_files)
    for image_file in pbar:
        img = _load_image(image_file)
        poisoned, _ = detector_cleanse(img, model, clean_features, args.m, args.delta, args.alpha, args.iouthresh)
        # Only the file name is inspected: the directory part is identical for
        # every file, so matching on it would label all images the same way.
        if "modified" in os.path.basename(image_file):
            total_poison += 1
            if poisoned:
                success += 1
            else:
                false_accept += 1
        else:
            total_clean += 1
            if poisoned:
                false_reject += 1
            else:
                success += 1
        far = false_accept/total_poison if total_poison != 0 else 0
        frr = false_reject/total_clean if total_clean != 0 else 0
        pbar.set_description(f"accuracy {success/(total_clean + total_poison)} FAR {far},{total_poison} FRR {frr},{total_clean}")
    print(total_clean)
    print(total_poison)
    print(success)
    print(false_accept)
    print(false_reject)


def _analyze_single_image(image_path, model, clean_features, args):
    """Run Detector Cleanse on one image and print the verdict (and coordinates if poisoned)."""
    img = _load_image(image_path)
    poisoned, coordinates = detector_cleanse(img, model, clean_features, args.m, args.delta, args.alpha, args.iouthresh)
    print()
    if poisoned:
        print("Image is poisoned")
        print(f"Coordinate : {coordinates}")
    else:
        print("Image is clean")


def main():
    """Load the clean features and the model, then run Detector Cleanse.

    If ``--image_path`` is a directory, every ``*.jpg`` inside it is
    evaluated and summary statistics are printed. Otherwise the path is
    treated as a single image and its verdict is printed.
    """
    args = parse_arguments()
    print("Loading clean feature files...")
    clean_features = _load_clean_features(args.clean_feature_path, args.n)
    print("Complete")
    print("Loading model...")
    model = _load_model(args.weight)
    print("Complete")
    print("Detecting")
    # Decide on the filesystem type rather than a substring of the path, so
    # a folder named e.g. "jpg_images" or a single ".png" file is handled right.
    if os.path.isdir(args.image_path):
        _evaluate_directory(args.image_path, model, clean_features, args)
    else:
        _analyze_single_image(args.image_path, model, clean_features, args)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()