Skip to content

Commit

Permalink
update robustness_eval and test_robustness
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghonglie committed Jul 28, 2022
1 parent a0a0211 commit cd06834
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 220 deletions.
23 changes: 12 additions & 11 deletions tools/analysis_tools/robustness_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@ def _print(result, ap=1, iouThr=None, areaRng='all', maxDets=100):
stats[3] = _print(results[3], 1, areaRng='small')
stats[4] = _print(results[4], 1, areaRng='medium')
stats[5] = _print(results[5], 1, areaRng='large')
# TODO support recall metric
'''
stats[6] = _print(results[6], 0, maxDets=1)
stats[7] = _print(results[7], 0, maxDets=10)
stats[8] = _print(results[8], 0)
stats[9] = _print(results[9], 0, areaRng='small')
stats[10] = _print(results[10], 0, areaRng='medium')
stats[11] = _print(results[11], 0, areaRng='large')
'''


def get_coco_style_results(filename,
Expand All @@ -49,29 +52,27 @@ def get_coco_style_results(filename,

if metric is None:
metrics = [
'AP', 'AP50', 'AP75', 'APs', 'APm', 'APl', 'AR1', 'AR10', 'AR100',
'ARs', 'ARm', 'ARl'
'mAP',
'mAP_50',
'mAP_75',
'mAP_s',
'mAP_m',
'mAP_l',
]
elif isinstance(metric, list):
metrics = metric
else:
metrics = [metric]

for metric_name in metrics:
assert metric_name in [
'AP', 'AP50', 'AP75', 'APs', 'APm', 'APl', 'AR1', 'AR10', 'AR100',
'ARs', 'ARm', 'ARl'
]

eval_output = mmcv.load(filename)

num_distortions = len(list(eval_output.keys()))
results = np.zeros((num_distortions, 6, len(metrics)), dtype='float32')

for corr_i, distortion in enumerate(eval_output):
for severity in eval_output[distortion]:
for metric_j, metric_name in enumerate(metrics):
mAP = eval_output[distortion][severity][task][metric_name]
mAP = eval_output[distortion][severity]['_'.join(
(task, metric_name))]
results[corr_i, severity, metric_j] = mAP

P = results[0, 0, :]
Expand Down Expand Up @@ -155,7 +156,7 @@ def get_voc_style_results(filename, prints='mPC', aggregate='benchmark'):

def get_results(filename,
dataset='coco',
task='bbox',
task='coco/bbox',
metric=None,
prints='mPC',
aggregate='benchmark'):
Expand Down
Loading

0 comments on commit cd06834

Please sign in to comment.