-
Notifications
You must be signed in to change notification settings - Fork 110
/
utils_tf2.py
210 lines (152 loc) · 7.85 KB
/
utils_tf2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import tensorflow as tf
import numpy as np
import torch
class ModelAdapter():
    """Wrap a tf.keras model so it can be queried with PyTorch tensors.

    Public methods take torch tensors, run the TensorFlow model, and return
    torch tensors placed on the same device as the input ``x``. When the model
    is 'channels_last', inputs are transposed with perm [0, 2, 3, 1] before the
    forward pass and input-gradients are transposed back with perm
    [0, 3, 1, 2], i.e. callers are presumed to pass channel-first batches.
    """

    def __init__(self, model, num_classes=10):
        """
        Please note that model should be tf.keras model without activation function 'softmax'

        Args:
            model: tf.keras model whose output is used as logits.
            num_classes: number of classes, used to build one-hot masks and
                forwarded to the DLR loss functions.
        """
        self.num_classes = num_classes
        self.tf_model = model
        self.__check_channel_ordering()

    def __check_channel_ordering(self):
        """Set ``self.data_format`` from the first Conv2D layer, or guess it.

        If the model has no Conv2D layer, a rank-4 input whose last dimension
        is 3 or 1 is treated as 'channels_last'; anything else falls back to
        'channels_first' (which means no transposition is applied).
        """
        for layer in self.tf_model.layers:
            if isinstance(layer, tf.keras.layers.Conv2D):
                print("[INFO] set data_format = '{:s}'".format(layer.data_format))
                self.data_format = layer.data_format
                return

        print("[INFO] Can not find Conv2D layer")
        input_shape = self.tf_model.input_shape

        # Only a rank-4 shape has an axis 3 to inspect; indexing it blindly
        # would raise IndexError for models with lower-rank inputs.
        last_dim = input_shape[3] if len(input_shape) == 4 else None

        if last_dim == 3:
            print("[INFO] Because detecting input_shape[3] == 3, set data_format = 'channels_last'")
            self.data_format = 'channels_last'
        elif last_dim == 1:
            print("[INFO] Because detecting input_shape[3] == 1, set data_format = 'channels_last'")
            self.data_format = 'channels_last'
        else:
            print("[INFO] set data_format = 'channels_first'")
            self.data_format = 'channels_first'

    # ------------------------------------------------------------------
    # torch <-> tensorflow conversion helpers
    # ------------------------------------------------------------------

    def _to_tf_input(self, x):
        """Convert a torch batch to a float32 tf tensor in the model's layout."""
        # detach() so tensors that require grad can still be exported to NumPy.
        x2 = tf.convert_to_tensor(x.detach().cpu().numpy(), dtype=tf.float32)
        if self.data_format == 'channels_last':
            x2 = tf.transpose(x2, perm=[0, 2, 3, 1])
        return x2

    @staticmethod
    def _to_tf_labels(y):
        """Convert a torch label tensor to an int32 tf tensor."""
        return tf.convert_to_tensor(y.detach().cpu().numpy(), dtype=tf.int32)

    def _grad_to_input_layout(self, grad):
        """Undo the 'channels_last' transposition on an input-gradient tensor."""
        if self.data_format == 'channels_last':
            grad = tf.transpose(grad, perm=[0, 3, 1, 2])
        return grad

    @staticmethod
    def _to_torch(tensor, like):
        """Convert a tf tensor to a torch tensor on the device of ``like``."""
        # Follow the caller's device instead of forcing CUDA, so the adapter
        # also works on CPU-only hosts.
        return torch.from_numpy(tensor.numpy()).to(like.device)

    def _pack_logits_loss_grad(self, logits, loss, grad, like):
        """Return (logits, loss, grad) as torch tensors, grad in input layout."""
        grad = self._grad_to_input_layout(grad)
        return (self._to_torch(logits, like),
                self._to_torch(loss, like),
                self._to_torch(grad, like))

    # ------------------------------------------------------------------
    # TensorFlow-side computations
    # ------------------------------------------------------------------

    def __get_logits(self, x_input):
        """Run the model in inference mode (``training=False``)."""
        logits = self.tf_model(x_input, training=False)
        return logits

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_jacobian(self, x_input):
        """Per-example Jacobian of the logits with respect to the input."""
        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x_input)
            logits = self.__get_logits(x_input)

        jacobian = g.batch_jacobian(logits, x_input)

        if self.data_format == 'channels_last':
            # Move the trailing input axis to position 2, right after the
            # logit axis.
            jacobian = tf.transpose(jacobian, perm=[0, 1, 4, 2, 3])

        return jacobian

    def __get_xent(self, logits, y_input):
        """Per-example sparse softmax cross-entropy."""
        xent = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y_input)
        return xent

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_grad_xent(self, x_input, y_input):
        """Return (logits, cross-entropy, d cross-entropy / d input)."""
        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x_input)
            logits = self.__get_logits(x_input)
            xent = self.__get_xent(logits, y_input)

        grad_xent = g.gradient(xent, x_input)

        return logits, xent, grad_xent

    def __get_dlr(self, logits, y_input):
        """Delegate to the module-level ``dlr_loss``."""
        val_dlr = dlr_loss(logits, y_input, num_classes=self.num_classes)
        return val_dlr

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_grad_dlr(self, x_input, y_input):
        """Return (logits, DLR loss, d DLR / d input)."""
        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x_input)
            logits = self.__get_logits(x_input)
            val_dlr = self.__get_dlr(logits, y_input)

        grad_dlr = g.gradient(val_dlr, x_input)

        return logits, val_dlr, grad_dlr

    def __get_dlr_target(self, logits, y_input, y_target):
        """Delegate to the module-level ``dlr_loss_targeted``."""
        dlr_target = dlr_loss_targeted(logits, y_input, y_target, num_classes=self.num_classes)
        return dlr_target

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_grad_dlr_target(self, x_input, y_input, y_target):
        """Return (logits, targeted DLR loss, d loss / d input)."""
        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x_input)
            logits = self.__get_logits(x_input)
            dlr_target = self.__get_dlr_target(logits, y_input, y_target)

        grad_target = g.gradient(dlr_target, x_input)

        return logits, dlr_target, grad_target

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_grad_diff_logits_target(self, x, la, la_target):
        """Return (logit[la_target] - logit[la], its gradient w.r.t. x)."""
        la_mask = tf.one_hot(la, self.num_classes)
        la_target_mask = tf.one_hot(la_target, self.num_classes)

        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x)
            logits = self.__get_logits(x)
            difflogits = tf.reduce_sum((la_target_mask - la_mask) * logits, axis=1)

        g2 = g.gradient(difflogits, x)

        return difflogits, g2

    # ------------------------------------------------------------------
    # Public torch-facing API
    # ------------------------------------------------------------------

    def predict(self, x):
        """Return the model's logits for the torch batch ``x``."""
        logits = self.__get_logits(self._to_tf_input(x))
        return self._to_torch(logits, x)

    def grad_logits(self, x):
        """Return the per-example Jacobian of the logits w.r.t. ``x``."""
        jacobian = self.__get_jacobian(self._to_tf_input(x))
        return self._to_torch(jacobian, x)

    def set_target_class(self, y, y_target):
        """No-op; the arguments are ignored."""
        pass

    def get_grad_diff_logits_target(self, x, y, y_target):
        """Return (logit[y_target] - logit[y], its gradient w.r.t. ``x``)."""
        x2 = self._to_tf_input(x)
        # Pass int32 tensors, like the other loss methods, rather than raw
        # NumPy arrays.
        la = self._to_tf_labels(y)
        la_target = self._to_tf_labels(y_target)

        difflogits, g2 = self.__get_grad_diff_logits_target(x2, la, la_target)
        g2 = self._grad_to_input_layout(g2)

        return self._to_torch(difflogits, x), self._to_torch(g2, x)

    def get_logits_loss_grad_xent(self, x, y):
        """Return (logits, per-example cross-entropy, gradient w.r.t. ``x``)."""
        x2 = self._to_tf_input(x)
        y2 = self._to_tf_labels(y)

        logits_val, loss_indiv_val, grad_val = self.__get_grad_xent(x2, y2)

        return self._pack_logits_loss_grad(logits_val, loss_indiv_val, grad_val, x)

    def get_logits_loss_grad_dlr(self, x, y):
        """Return (logits, per-example DLR loss, gradient w.r.t. ``x``)."""
        x2 = self._to_tf_input(x)
        y2 = self._to_tf_labels(y)

        logits_val, loss_indiv_val, grad_val = self.__get_grad_dlr(x2, y2)

        return self._pack_logits_loss_grad(logits_val, loss_indiv_val, grad_val, x)

    def get_logits_loss_grad_target(self, x, y, y_target):
        """Return (logits, per-example targeted DLR loss, gradient w.r.t. ``x``)."""
        x2 = self._to_tf_input(x)
        y2 = self._to_tf_labels(y)
        y_targ = self._to_tf_labels(y_target)

        logits_val, loss_indiv_val, grad_val = self.__get_grad_dlr_target(x2, y2, y_targ)

        return self._pack_logits_loss_grad(logits_val, loss_indiv_val, grad_val, x)
def dlr_loss(x, y, num_classes=10):
    """Difference-of-Logits-Ratio (DLR) loss, one value per row of ``x``.

    For a row whose largest logit belongs to class ``y`` the value is
    ``-(z_1 - z_2) / (z_1 - z_3 + 1e-12)``, where ``z_1 >= z_2 >= z_3`` are
    the three largest logits. For a row that is already misclassified the
    numerator becomes ``z_y - z_1``, so the loss turns non-negative instead
    of silently ignoring the label.

    Args:
        x: logits; indexed as (row, class), needs at least 3 classes.
        y: integer class labels, one per row.
        num_classes: depth of the one-hot encoding of ``y``.

    Returns:
        Tensor with one loss value per row.
    """
    x = tf.convert_to_tensor(x)
    x_sort = tf.sort(x, axis=1)
    y_onehot = tf.one_hot(y, num_classes, dtype=x.dtype)

    # Logit of the labelled class.
    logit_true = tf.reduce_sum(x * y_onehot, axis=1)

    # True where the labelled class currently has the largest logit.
    is_top = tf.equal(tf.argmax(x, axis=1), tf.cast(y, tf.int64))

    # Keep the sorted-logit form for correctly classified rows; switch to
    # "true logit minus the winning logit" for misclassified rows.
    margin_correct = x_sort[:, -1] - x_sort[:, -2]
    margin_wrong = logit_true - x_sort[:, -1]
    numerator = tf.where(is_top, margin_correct, margin_wrong)

    loss = -numerator / (x_sort[:, -1] - x_sort[:, -3] + 1e-12)

    return loss
def dlr_loss_targeted(x, y, y_target, num_classes=10):
    """Targeted DLR loss, one value per row of ``x``.

    Computes ``-(z_y - z_t) / (z_1 - 0.5 * z_3 - 0.5 * z_4 + 1e-12)`` where
    ``z_y`` / ``z_t`` are the logits of the labelled and target classes and
    ``z_1``, ``z_3``, ``z_4`` are the largest, third and fourth largest logits.

    Args:
        x: logits; indexed as (row, class), needs at least 4 classes.
        y: integer class labels, one per row.
        y_target: integer target classes, one per row.
        num_classes: depth of the one-hot encodings.

    Returns:
        Tensor with one loss value per row.
    """
    ascending = tf.sort(x, axis=1)
    true_mask = tf.one_hot(y, num_classes)
    target_mask = tf.one_hot(y_target, num_classes)

    # Pick out the labelled / target logits by masking and summing.
    logit_true = tf.reduce_sum(x * true_mask, axis=1)
    logit_target = tf.reduce_sum(x * target_mask, axis=1)

    # Entries -1, -3, -4 of the ascending sort are the 1st, 3rd and 4th
    # largest logits.
    scale = ascending[:, -1] - .5 * ascending[:, -3] - .5 * ascending[:, -4] + 1e-12

    return -(logit_true - logit_target) / scale