"""
Adapted from "In search of robust measures of generalization" by Dziugaite et al.
https://github.com/nitarshan/robust-generalization-measures.git
"""
from contextlib import contextmanager
from copy import deepcopy
import math
from typing import List, Tuple
from enum import Enum
import numpy as np
import torch
from torch import Tensor
from torch.utils.data.dataloader import DataLoader
import torch.nn as nn
from utils.utils_CKA import *
from utils.data_utils import get_masks_and_count_tokens, get_src_and_trg_batches
from utils.optimizers_and_distributions import LabelSmoothingDistribution
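# A minimal usage sketch (hypothetical variable names), assuming a trained `model`,
# its initialization `init_model`, a training `dataloader`, and the training
# accuracy `acc`; `get_all_measures` below returns a dict of complexity measures:
#
#   measures = get_all_measures(model, init_model, dataloader, acc, seed=0,
#                               task_type='NMT', pad_token_id=pad_token_id,
#                               trg_vocab_size=trg_vocab_size)
#   print(measures["PACBAYES_FLATNESS"])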


class CT(Enum):
# Measures from Fantastic Generalization Measures (equation numbers)
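    # i.e. Jiang et al., "Fantastic Generalization Measures and Where to Find Them" (arXiv:1912.02178)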
PARAMS = 20
INVERSE_MARGIN = 22
LOG_SPEC_INIT_MAIN = 29
LOG_SPEC_ORIG_MAIN = 30
LOG_PROD_OF_SPEC_OVER_MARGIN = 31
LOG_PROD_OF_SPEC = 32
FRO_OVER_SPEC = 33
LOG_SUM_OF_SPEC_OVER_MARGIN = 34
LOG_SUM_OF_SPEC = 35
LOG_PROD_OF_FRO_OVER_MARGIN = 36
LOG_PROD_OF_FRO = 37
LOG_SUM_OF_FRO_OVER_MARGIN = 38
LOG_SUM_OF_FRO = 39
FRO_DIST = 40
DIST_SPEC_INIT = 41
PARAM_NORM = 42
PATH_NORM_OVER_MARGIN = 43
PATH_NORM = 44
PACBAYES_INIT = 48
PACBAYES_ORIG = 49
PACBAYES_FLATNESS = 53
PACBAYES_MAG_INIT = 56
PACBAYES_MAG_ORIG = 57
PACBAYES_MAG_FLATNESS = 61
# Other Measures
L2 = 100
L2_DIST = 101
# FFT Spectral Measures
LOG_SPEC_INIT_MAIN_FFT = 129
LOG_SPEC_ORIG_MAIN_FFT = 130
LOG_PROD_OF_SPEC_OVER_MARGIN_FFT = 131
LOG_PROD_OF_SPEC_FFT = 132
FRO_OVER_SPEC_FFT = 133
LOG_SUM_OF_SPEC_OVER_MARGIN_FFT = 134
LOG_SUM_OF_SPEC_FFT = 135
DIST_SPEC_INIT_FFT = 141


@torch.no_grad()
def eval_batch(batch, model, labels):
input_ids = batch['input_ids'].cuda()
attention_mask = batch['attention_mask'].cuda()
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
return outputs


@torch.no_grad()
def eval_acc(model, eval_loader):
    model.eval()
    num = 0
    correct = 0
    for batch in eval_loader:
        labels = batch['labels'].cuda()
        outputs = eval_batch(batch, model, labels)
        predictions = outputs.logits.argmax(dim=-1)
        num += len(labels)
        correct += (labels == predictions).sum().item()
    assert num > 0
    acc = correct / num
    print(f"Evaluated accuracy = {acc}.")
    return acc


@torch.no_grad()
def eval_NMT_loss(model, dataloader, pad_token_id=None, trg_vocab_size=0, NMT_maximum_samples=10000):
    """Estimate the machine-translation training loss on up to NMT_maximum_samples sentences."""
    num_processed_samples = 0
    device = next(model.parameters()).device
    training_loss = 0
    loss_step = 0
    kl_div_loss = nn.KLDivLoss(reduction='batchmean')
    # Use label smoothing = 0 here, so the targets are the regular one-hot distributions
    label_smoothing = LabelSmoothingDistribution(0, pad_token_id, trg_vocab_size, device)
    for token_ids_batch in dataloader:
        src_token_ids_batch, trg_token_ids_batch_input, target = get_src_and_trg_batches(token_ids_batch)
        num_processed_samples += token_ids_batch.batch_size
        src_mask, trg_mask, _, _ = get_masks_and_count_tokens(src_token_ids_batch, trg_token_ids_batch_input, pad_token_id, device)
        logits = model(src_token_ids_batch, trg_token_ids_batch_input, src_mask, trg_mask)
        smooth_target_distributions = label_smoothing(target)  # these are regular probabilities
        loss = kl_div_loss(logits, smooth_target_distributions)
        training_loss += loss.item()
        loss_step += 1
        if num_processed_samples >= NMT_maximum_samples:
            break
    training_loss = training_loss / loss_step
    print(f"NMT training loss is {training_loss}")
    return training_loss


# Adapted from https://github.com/bneyshabur/generalization-bounds/blob/master/measures.py
@torch.no_grad()
def _reparam(model):
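    """Fold every BatchNorm2d layer into the preceding Conv2d layer (on a deep copy),
    so that norm-based measures are computed on the network's effective weights."""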
def in_place_reparam(model, prev_layer=None):
for child in model.children():
prev_layer = in_place_reparam(child, prev_layer)
if child._get_name() == 'Conv2d':
prev_layer = child
elif child._get_name() == 'BatchNorm2d':
                scale = child.weight / (child.running_var + child.eps).sqrt()
                prev_layer.bias.copy_(child.bias + scale * (prev_layer.bias - child.running_mean))
                perm = list(reversed(range(prev_layer.weight.dim())))
                prev_layer.weight.copy_((prev_layer.weight.permute(perm) * scale).permute(perm))
                child.bias.fill_(0)
                child.weight.fill_(1)
                child.running_mean.fill_(0)
                child.running_var.fill_(1)
return prev_layer
model = deepcopy(model)
in_place_reparam(model)
return model
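# Note: transformer models use LayerNorm rather than BatchNorm2d, so for the NMT task
# _reparam effectively reduces to a deep copy (only Conv2d/BatchNorm2d pairs are folded).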


@contextmanager
def _perturbed_model(model, sigma, rng, magnitude_eps=None):
    """Yield a deep copy of `model` with independent Gaussian noise added to every parameter."""
    device = next(model.parameters()).device
    if magnitude_eps is not None:
        # Magnitude-aware perturbation: the noise scale for each parameter grows with its magnitude
        noise = [torch.normal(0, sigma ** 2 * torch.abs(p) ** 2 + magnitude_eps ** 2, generator=rng) for p in model.parameters()]
    else:
        noise = [torch.normal(0, sigma ** 2, p.shape, generator=rng).to(device) for p in model.parameters()]
    model = deepcopy(model)
    try:
        [p.add_(n) for p, n in zip(model.parameters(), noise)]
        yield model
    finally:
        [p.sub_(n) for p, n in zip(model.parameters(), noise)]
        del model


# Adapted from https://drive.google.com/file/d/1_6oUG94d0C3x7x2Vd935a2QqY-OaAWAM/view
def _pacbayes_sigma(
model,
dataloader,
accuracy,
seed,
magnitude_eps = None,
search_depth = 15,
montecarlo_samples = 10,
accuracy_displacement = 0.1,
displacement_tolerance = 1e-2,
task_type = 'normal',
pad_token_id = None,
trg_vocab_size = 0,
search_upper_limit = 0.2
) -> float:
    if task_type == 'NMT' and magnitude_eps:
        # Tricky case: the default search_upper_limit = 0.2 is not large enough here
        search_upper_limit = 2
lower, upper = 0, search_upper_limit
sigma = 0.1
BIG_NUMBER = 10348628753
device = next(model.parameters()).device
rng = torch.Generator(device=device) if magnitude_eps is not None else torch.Generator()
rng.manual_seed(BIG_NUMBER + seed)
if not accuracy and task_type == 'NMT':
# In this case, the training accuracy is hard to evaluate
# So we use the training loss instead
# It is training loss, but we still call it "accuracy" to follow the convention
print("Evaluate training loss using the original model.")
accuracy = eval_NMT_loss(model, dataloader, pad_token_id=pad_token_id, trg_vocab_size=trg_vocab_size)
accuracy_displacement = 0.5
displacement_tolerance = 0.05
    print("Start binary search for PAC-Bayes sigma.")
    for _ in range(search_depth):
        sigma = (lower + upper) / 2
        # If sigma exceeds 95% of search_upper_limit, the search is most likely stuck
        # because the upper limit is too small
        if sigma > search_upper_limit * 0.95:
            return search_upper_limit
        accuracy_samples = []
        print(f"Getting samples for sigma = {sigma}.")
        for _ in range(montecarlo_samples):
            with _perturbed_model(model, sigma, rng, magnitude_eps) as p_model:
                # Estimate the perturbed model's performance: training loss for NMT,
                # accuracy for ordinary classification
if task_type == 'NMT':
loss_estimate = eval_NMT_loss(p_model, dataloader, pad_token_id=pad_token_id, trg_vocab_size=trg_vocab_size)
else:
loss_estimate = eval_acc(p_model, dataloader)
accuracy_samples.append(loss_estimate)
displacement = abs(np.mean(accuracy_samples) - accuracy)
if abs(displacement - accuracy_displacement) < displacement_tolerance:
break
elif displacement > accuracy_displacement:
# Too much perturbation
upper = sigma
else:
# Not perturbed enough to reach target displacement
lower = sigma
return sigma
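# The returned sigma is (approximately) the largest perturbation scale at which the mean
# perturbed performance stays within `accuracy_displacement` of the unperturbed value;
# the flatness measures below are then computed as 1 / sigma ** 2, so flatter minima
# (which tolerate a larger sigma) receive smaller complexity values.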


def W_CKA(p, q, feature_space=True):
    """CKA similarity between two weight tensors (1.0 if they are numerically identical).

    feature_space_linear_cka, cka_compute and gram_linear come from utils.utils_CKA
    (brought in by the wildcard import above)."""
    eps = 1e-15
    p = p.data.numpy()
    q = q.data.numpy()
    if np.sqrt(np.sum((p - q) ** 2)) < eps:
        return 1.0
    if feature_space:
        return feature_space_linear_cka(p, q)
    else:
        return cka_compute(gram_linear(p, q))


@torch.no_grad()
def get_all_measures(
model,
init_model,
dataloader,
acc,
seed,
no_path_norm=True,
no_exact_spectral_norm=True,
no_pac_bayes=False,
no_margin=False,
no_basics=False,
no_CKA=True,
task_type='NMT',
path_norm_transformer=None,
pad_token_id=None,
trg_vocab_size=0,
pacbayes_depth=15
):
measures = {}
model = _reparam(model)
init_model = _reparam(init_model)
device = next(model.parameters()).device
m = len(dataloader.dataset)
def get_weights_only(model):
blacklist = {'bias', 'bn'}
return [p for name, p in model.named_parameters() if all(x not in name for x in blacklist)]
weights = get_weights_only(model)
init_weights = get_weights_only(init_model)
weights_cpu = [p.to("cpu") for p in weights]
init_weights_cpu = [p.to("cpu") for p in init_weights]
dist_init_weights = [p-q for p,q in zip(weights, init_weights)]
d = len(weights)
def get_vec_params(weights: List[Tensor]) -> Tensor:
return torch.cat([p.view(-1) for p in weights], dim=0)
w_vec = get_vec_params(weights)
dist_w_vec = get_vec_params(dist_init_weights)
num_params = len(w_vec)
if not no_CKA:
measures["W_CKA"] = np.mean([W_CKA(p,q, feature_space=True) for p,q in zip(weights_cpu, init_weights_cpu) if len(p.shape)>1])
def get_reshaped_weights(weights: List[Tensor]) -> List[Tensor]:
# If the weight is a tensor (e.g. a 4D Conv2d weight), it will be reshaped to a 2D matrix
return [p.view(p.shape[0],-1) for p in weights]
reshaped_weights = get_reshaped_weights(weights)
dist_reshaped_weights = get_reshaped_weights(dist_init_weights)
if not no_basics:
print("Vector Norm Measures")
measures["L2"] = w_vec.norm(p=2)
measures["L2_DIST"] = dist_w_vec.norm(p=2)
print("VC-Dimension Based Measures")
measures["PARAMS"] = torch.tensor(num_params) # 20
if not no_margin:
print("Measures on the output of the network")
        def _calculate_margin(logits, target):
            correct_logit = logits[torch.arange(logits.shape[0]), target].clone()
            logits[torch.arange(logits.shape[0]), target] = float('-inf')
            max_other_logit = logits.data.max(1).values  # the largest logit among the other classes
            margin = correct_logit - max_other_logit
            return margin
@torch.no_grad()
def _margin(
model,
dataloader,
task_type='normal',
pad_token_id=None,
NMT_maximum_samples = 10000,
) -> Tensor:
margins = []
if task_type=='NMT':
num_processed_samples = 0
for batch_id, token_ids_batch in enumerate(dataloader):
src_token_ids_batch, trg_token_ids_batch_input, target = get_src_and_trg_batches(token_ids_batch)
num_processed_samples += token_ids_batch.batch_size
src_mask, trg_mask, num_src_tokens, num_trg_tokens = get_masks_and_count_tokens(src_token_ids_batch, trg_token_ids_batch_input, pad_token_id, device)
logits = model(src_token_ids_batch, trg_token_ids_batch_input, src_mask, trg_mask, no_softmax=True) # do not use softmax
margins.append(_calculate_margin(logits.clone(),target.flatten()))
if num_processed_samples >= NMT_maximum_samples:
print(f"There are {num_processed_samples} sentences processed when calculating the margin.")
break
margin_distribution = torch.cat(margins)
return margin_distribution.kthvalue(len(margin_distribution) // 10)[0]
else:
for batch in dataloader:
target = batch['labels'].cuda()
outputs = eval_batch(batch, model, target)
logits = outputs.logits
margins.append(_calculate_margin(logits,target))
return torch.cat(margins).kthvalue(m // 10)[0]
true_margin = _margin(model, dataloader, task_type, pad_token_id)
measures["TRUE_MARGIN"] = true_margin # Only used for checking if the true margin could become negative
        margin = true_margin.abs()  # take the absolute value so margin.log() below stays defined even if the tenth-percentile margin is negative
measures["INVERSE_MARGIN"] = torch.tensor(1, device=device) / margin ** 2 # 22
if not no_basics:
print("(Norm & Margin)-Based Measures")
fro_norms = torch.cat([p.norm('fro').unsqueeze(0) ** 2 for p in reshaped_weights])
        print("Starting SVD calculations; these may use a large amount of memory.")
spec_norms = torch.cat([p.svd().S.max().unsqueeze(0) ** 2 for p in reshaped_weights])
print("End SVD calculations.")
dist_fro_norms = torch.cat([p.norm('fro').unsqueeze(0) ** 2 for p in dist_reshaped_weights])
dist_spec_norms = torch.cat([p.svd().S.max().unsqueeze(0) ** 2 for p in dist_reshaped_weights])
print("Approximate Spectral Norm")
# Note that these use an approximation from [Yoshida and Miyato, 2017]
# https://arxiv.org/abs/1705.10941 (Section 3.2, Convolutions)
measures["LOG_PROD_OF_SPEC"] = spec_norms.log().sum() # 32
measures["FRO_OVER_SPEC"] = (fro_norms / spec_norms).sum() # 33
measures["LOG_SUM_OF_SPEC"] = math.log(d) + (1/d) * measures["LOG_PROD_OF_SPEC"] # 35
if not no_margin:
measures["LOG_PROD_OF_SPEC_OVER_MARGIN"] = measures["LOG_PROD_OF_SPEC"] - 2 * margin.log() # 31
measures["LOG_SPEC_INIT_MAIN"] = measures["LOG_PROD_OF_SPEC_OVER_MARGIN"] + (dist_fro_norms / spec_norms).sum().log() # 29
measures["LOG_SPEC_ORIG_MAIN"] = measures["LOG_PROD_OF_SPEC_OVER_MARGIN"] + measures["FRO_OVER_SPEC"].log() # 30
measures["LOG_SUM_OF_SPEC_OVER_MARGIN"] = math.log(d) + (1/d) * (measures["LOG_PROD_OF_SPEC"] - 2 * margin.log()) # 34
if not no_basics:
print("Frobenius Norm")
measures["LOG_PROD_OF_FRO"] = fro_norms.log().sum() # 37
measures["LOG_SUM_OF_FRO"] = math.log(d) + (1/d) * measures["LOG_PROD_OF_FRO"] # 39
if not no_margin:
measures["LOG_PROD_OF_FRO_OVER_MARGIN"] = measures["LOG_PROD_OF_FRO"] - 2 * margin.log() # 36
measures["LOG_SUM_OF_FRO_OVER_MARGIN"] = math.log(d) + (1/d) * (measures["LOG_PROD_OF_FRO"] - 2 * margin.log()) # 38
print("Distance to Initialization")
measures["FRO_DIST"] = dist_fro_norms.sum() # 40
measures["DIST_SPEC_INIT"] = dist_spec_norms.sum() # 41
measures["PARAM_NORM"] = fro_norms.sum() # 42
if not no_path_norm:
print("Path-norm")
# Adapted from https://github.com/bneyshabur/generalization-bounds/blob/master/measures.py#L98
def _path_norm(model):
model = deepcopy(model)
model.eval()
for param in model.parameters():
if param.requires_grad:
param.data.pow_(2)
            # Path norm requires an all-ones input; we construct it from a length-1 sequence
            # by setting all embeddings to ones and all positional encodings to zeros
model.src_embedding.embeddings_table.weight.data = torch.ones_like(model.src_embedding.embeddings_table.weight.data)
model.src_pos_embedding.positional_encodings_table.data = torch.zeros_like(model.src_pos_embedding.positional_encodings_table.data)
model.trg_embedding.embeddings_table.weight.data = torch.ones_like(model.trg_embedding.embeddings_table.weight.data)
model.trg_pos_embedding.positional_encodings_table.data = torch.zeros_like(model.trg_pos_embedding.positional_encodings_table.data)
if task_type == 'NMT':
src_token=torch.ones(1,1).long()
trg_token=torch.ones(1,1).long()
src_mask=torch.ones(1,1,1,1)>0
trg_mask=torch.ones(1,1,1,1)>0
x = model(src_token, trg_token, src_mask, trg_mask)
else:
                raise ValueError(f"Path norm is only implemented for task_type='NMT', got '{task_type}'.")
del model
return x.sum()
measures["PATH_NORM"] = _path_norm(path_norm_transformer) # 44
if not no_margin:
measures["PATH_NORM_OVER_MARGIN"] = measures["PATH_NORM"] / margin ** 2 # 43
if not no_exact_spectral_norm:
print("Exact Spectral Norm")
# Proposed in https://arxiv.org/abs/1805.10408
# Adapted from https://github.com/brain-research/conv-sv/blob/master/conv2d_singular_values.py#L52
def _spectral_norm_fft(kernel: Tensor, input_shape: Tuple[int, int]) -> Tensor:
            # PyTorch Conv2d filters have shape (out, in, kh, kw);
            # the [Sedghi et al., 2018] code expects filters of shape (kh, kw, in, out).
            # Older PyTorch releases lack complex FFT/SVD support, so we compute in NumPy.
np_kernel = np.einsum('oihw->hwio', kernel.data.cpu().numpy())
transforms = np.fft.fft2(np_kernel, input_shape, axes=[0, 1]) # Shape(ih,iw,in,out)
singular_values = np.linalg.svd(transforms, compute_uv=False) # Shape(ih,iw,min(in,out))
spec_norm = singular_values.max()
return torch.tensor(spec_norm, device=kernel.device)
input_shape = (model.dataset_type.D[1], model.dataset_type.D[2])
fft_spec_norms = torch.cat([_spectral_norm_fft(p, input_shape).unsqueeze(0) ** 2 for p in weights])
fft_dist_spec_norms = torch.cat([_spectral_norm_fft(p, input_shape).unsqueeze(0) ** 2 for p in dist_init_weights])
        # Use string keys, consistent with the rest of `measures`, so adjust_measure below works
        measures["LOG_PROD_OF_SPEC_FFT"] = fft_spec_norms.log().sum()  # 32
        measures["LOG_PROD_OF_SPEC_OVER_MARGIN_FFT"] = measures["LOG_PROD_OF_SPEC_FFT"] - 2 * margin.log()  # 31
        measures["FRO_OVER_SPEC_FFT"] = (fro_norms / fft_spec_norms).sum()  # 33
        measures["LOG_SUM_OF_SPEC_OVER_MARGIN_FFT"] = math.log(d) + (1/d) * (measures["LOG_PROD_OF_SPEC_FFT"] - 2 * margin.log())  # 34
        measures["LOG_SUM_OF_SPEC_FFT"] = math.log(d) + (1/d) * measures["LOG_PROD_OF_SPEC_FFT"]  # 35
        measures["DIST_SPEC_INIT_FFT"] = fft_dist_spec_norms.sum()  # 41
        measures["LOG_SPEC_INIT_MAIN_FFT"] = measures["LOG_PROD_OF_SPEC_OVER_MARGIN_FFT"] + (dist_fro_norms / fft_spec_norms).sum().log()  # 29
        measures["LOG_SPEC_ORIG_MAIN_FFT"] = measures["LOG_PROD_OF_SPEC_OVER_MARGIN_FFT"] + measures["FRO_OVER_SPEC_FFT"].log()  # 30
if not no_pac_bayes:
print("Flatness-based measures")
sigma = _pacbayes_sigma(model, dataloader, acc, seed, search_depth=pacbayes_depth, task_type=task_type, pad_token_id=pad_token_id, trg_vocab_size=trg_vocab_size)
def _pacbayes_bound(reference_vec: Tensor) -> Tensor:
return (reference_vec.norm(p=2) ** 2) / (4 * sigma ** 2) + math.log(m / sigma) + 10
measures["PACBAYES_INIT"] = _pacbayes_bound(dist_w_vec) # 48
measures["PACBAYES_ORIG"] = _pacbayes_bound(w_vec) # 49
measures["PACBAYES_FLATNESS"] = torch.tensor(1 / sigma ** 2) # 53
print("Magnitude-aware Perturbation Bounds")
mag_eps = 1e-3
mag_sigma = _pacbayes_sigma(model, dataloader, acc, seed, mag_eps, search_depth=pacbayes_depth, task_type=task_type, pad_token_id=pad_token_id, trg_vocab_size=trg_vocab_size)
omega = num_params
def _pacbayes_mag_bound(reference_vec: Tensor) -> Tensor:
numerator = mag_eps ** 2 + (mag_sigma ** 2 + 1) * (reference_vec.norm(p=2)**2) / omega
denominator = mag_eps ** 2 + mag_sigma ** 2 * dist_w_vec ** 2
return 1/4 * (numerator / denominator).log().sum() + math.log(m / mag_sigma) + 10
measures["PACBAYES_MAG_INIT"] = _pacbayes_mag_bound(dist_w_vec) # 56
measures["PACBAYES_MAG_ORIG"] = _pacbayes_mag_bound(w_vec) # 57
measures["PACBAYES_MAG_FLATNESS"] = torch.tensor(1 / mag_sigma ** 2) # 61
# Adjust for dataset size
    def adjust_measure(measure: str, value: float) -> float:
        if measure.startswith('LOG_'):
            return 0.5 * (value - np.log(m))
        elif 'CKA' in measure or 'TRUE_MARGIN' in measure:
            return value
        else:
            return np.sqrt(value / m)
return {k: adjust_measure(k, v.item()) for k, v in measures.items()}
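# Worked example of the dataset-size adjustment: a LOG_* measure stores log(complexity),
# so 0.5 * (value - log m) equals log(sqrt(complexity / m)), which matches the
# sqrt(value / m) applied to the measures stored on a linear scale.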