Skip to content

Commit

Permalink
Track corrupted outputs_similarity.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 317342104
  • Loading branch information
GhassenJ authored and edward-bot committed Jun 19, 2020
1 parent 990e3e7 commit 862f338
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions baselines/cifar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ def aggregate_corrupt_metrics(metrics,
Dictionary of aggregated results.
"""
diversity_keys = ['disagreement', 'cosine_similarity', 'average_kl']
diversity_keys = ['disagreement', 'cosine_similarity', 'average_kl',
'outputs_similarity']
results = {
'test/nll_mean_corrupted': 0.,
'test/accuracy_mean_corrupted': 0.,
Expand All @@ -305,7 +306,7 @@ def aggregate_corrupt_metrics(metrics,
disagreement = np.zeros(len(corruption_types))
cosine_similarity = np.zeros(len(corruption_types))
average_kl = np.zeros(len(corruption_types))

outputs_similarity = np.zeros(len(corruption_types))
for i in range(len(corruption_types)):
dataset_name = '{0}_{1}'.format(corruption_types[i], intensity)
nll[i] = metrics['test/nll_{}'.format(dataset_name)].result()
Expand All @@ -321,17 +322,21 @@ def aggregate_corrupt_metrics(metrics,
dataset_name)].result()
member_ece[i] = 0.
if corrupt_diversity is not None:
error = 1 - acc[i] + tf.keras.backend.epsilon()
disagreement[i] = (
corrupt_diversity['corrupt_diversity/disagreement_{}'.format(
dataset_name)].result())
dataset_name)].result()) / error
# Normalize the corrupt disagreement by its error rate.
error = 1 - acc[i] + tf.keras.backend.epsilon()
cosine_similarity[i] = (
corrupt_diversity['corrupt_diversity/cosine_similarity_{}'.format(
dataset_name)].result()) / error
dataset_name)].result())
average_kl[i] = (
corrupt_diversity['corrupt_diversity/average_kl_{}'.format(
dataset_name)].result())
outputs_similarity[i] = (
corrupt_diversity['corrupt_diversity/outputs_similarity_{}'.format(
dataset_name)].result())

if log_fine_metrics or output_dir is not None:
fine_metrics_results['test/nll_{}'.format(dataset_name)] = nll[i]
fine_metrics_results['test/accuracy_{}'.format(dataset_name)] = acc[i]
Expand All @@ -343,6 +348,9 @@ def aggregate_corrupt_metrics(metrics,
dataset_name)] = cosine_similarity[i]
fine_metrics_results['corrupt_diversity/average_kl_{}'.format(
dataset_name)] = average_kl[i]
fine_metrics_results['corrupt_diversity/outputs_similarity_{}'.format(
dataset_name)] = outputs_similarity[i]

avg_nll = np.mean(nll)
avg_accuracy = np.mean(acc)
avg_ece = np.mean(ece)
Expand All @@ -363,7 +371,7 @@ def aggregate_corrupt_metrics(metrics,
results['test/member_ece_mean_corrupted'] += avg_member_ece
if corrupt_diversity is not None:
avg_diversity_metrics = [np.mean(disagreement), np.mean(
cosine_similarity), np.mean(average_kl)]
cosine_similarity), np.mean(average_kl), np.mean(outputs_similarity)]
for key, avg in zip(diversity_keys, avg_diversity_metrics):
results['corrupt_diversity/{}_mean_{}'.format(
key, intensity)] = avg
Expand Down

0 comments on commit 862f338

Please sign in to comment.