-
Notifications
You must be signed in to change notification settings - Fork 56
/
Copy pathsimba_single.py
29 lines (26 loc) · 1.01 KB
/
simba_single.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import utils
def normalize(x):
    """Apply the project's 'imagenet' normalization to ``x``.

    Thin wrapper that forwards ``x`` to ``utils.apply_normalization`` with
    the dataset key fixed to ``'imagenet'`` and returns its result unchanged.
    NOTE(review): presumably this is per-channel mean/std normalization --
    confirm against ``utils.apply_normalization``.
    """
    dataset = 'imagenet'
    normalized = utils.apply_normalization(x, dataset)
    return normalized
def get_probs(model, x, y):
    """Return the softmax probability that ``model`` assigns to ``y`` on ``x``.

    The input is moved to the GPU, normalized via ``normalize``, run through
    ``model``, and the resulting scores are copied back to the CPU before the
    softmax is taken.

    Args:
        model: Callable taking the normalized CUDA batch and returning class
            scores. Assumed to return a 2-D ``(batch, classes)`` tensor --
            TODO confirm against callers.
        x: Input batch tensor (must support ``.cuda()``).
        y: Class index, or 1-D index tensor with one label per batch row.

    Returns:
        A detached CPU tensor: ``torch.diag`` of the selected probability
        columns. For a 1-D index tensor ``y`` of length ``batch`` this is the
        1-D vector of each row's probability for its own label; for an
        integer ``y`` the column is 1-D and ``torch.diag`` turns it into a
        square diagonal matrix (1x1 for a single image).
    """
    # Queries are inference-only, so skip building an autograd graph.
    with torch.no_grad():
        output = model(normalize(x.cuda())).cpu()
    # Softmax over the class axis. Stating dim=1 explicitly matches what the
    # deprecated implicit-dim rule picks for 2-D input, without the warning.
    probs = torch.softmax(output, dim=1)[:, y]
    return torch.diag(probs.detach())
# 20-line implementation of (untargeted) SimBA for single image input
def simba_single(model, x, y, num_iters=10000, epsilon=0.2):
    """Run untargeted SimBA on a single image and return the perturbed image.

    Coordinates of ``x`` are visited in a random order without repetition.
    For each one, a step of ``-epsilon`` is tried first, then ``+epsilon``;
    a step is kept only if it lowers the probability ``get_probs`` reports
    for ``y``. Every candidate is clamped to ``[0, 1]``.

    Args:
        model: Classifier passed through to ``get_probs``.
        x: Image tensor; it is flattened with ``view(1, -1)``, so it is
            treated as one image. Values are presumably already in
            ``[0, 1]`` -- TODO confirm against callers.
        y: Label whose probability is to be reduced.
        num_iters: Maximum number of coordinates to try. At most one pass
            over all coordinates is made, so the effective count is
            ``min(num_iters, x.numel())``.
        epsilon: Step size applied to one coordinate per iteration.

    Returns:
        The perturbed image, same shape as ``x``. The caller's tensor is not
        modified in place; accepted steps rebind ``x`` to a new tensor.
    """
    n_dims = x.view(1, -1).size(1)
    perm = torch.randperm(n_dims)
    last_prob = get_probs(model, x, y)
    # perm holds only n_dims entries; capping the loop avoids an IndexError
    # when num_iters exceeds the number of coordinates (e.g. small images).
    for i in range(min(num_iters, n_dims)):
        diff = torch.zeros(n_dims)
        diff[perm[i]] = epsilon
        step = diff.view(x.size())

        # Try the negative direction first; build the candidate once and
        # reuse it if the step is accepted.
        left = (x - step).clamp(0, 1)
        left_prob = get_probs(model, left, y)
        if left_prob < last_prob:
            x, last_prob = left, left_prob
            continue

        # Negative step did not help; fall back to the positive direction.
        right = (x + step).clamp(0, 1)
        right_prob = get_probs(model, right, y)
        if right_prob < last_prob:
            x, last_prob = right, right_prob
    return x