Commit

resolve conflict
wanghonglie committed Sep 14, 2022
1 parent 4c39bd8 commit e774e18
Showing 1 changed file with 21 additions and 171 deletions.
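
This change resolves the merge conflict in favour of the MMEngine Runner-based flow, removing the leftover MMCV 1.x test path (MMDataParallel, single_gpu_test/multi_gpu_test) and the local COCO/VOC evaluation helpers. For orientation, the per-corruption loop that the surviving code implements looks roughly like the sketch below. It is assembled from the added lines in the diff; everything not visible in the shown hunks (the Runner.from_cfg construction, the deepcopy of cfg.test_dataloader, the config, checkpoint and output paths) is an assumption, not part of this commit.

# Rough sketch of the Runner-based robustness loop (assembled from the added
# lines in the diff; paths and the Runner.from_cfg step are assumptions).
import copy

from mmengine.config import Config
from mmengine.fileio import dump
from mmengine.runner import Runner

from mmdet.utils import register_all_modules

register_all_modules()
cfg = Config.fromfile('configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py')  # hypothetical config
cfg.launcher = 'none'
cfg.work_dir = 'work_dirs/robustness'   # assumed; the new --work-dir option
cfg.load_from = 'faster_rcnn_r50.pth'   # assumed; the checkpoint argument
runner = Runner.from_cfg(cfg)           # assumed; construction is outside the shown hunks

aggregated_results = {}
for corruption in ['gaussian_noise']:   # normally the full --corruptions list
    aggregated_results[corruption] = {}
    for severity in [1, 3, 5]:          # normally the --severities list
        test_loader_cfg = copy.deepcopy(cfg.test_dataloader)
        if severity > 0:
            # same dict the diff inserts at index 1 of the test pipeline
            corruption_trans = dict(
                type='Corrupt', corruption=corruption, severity=severity)
            test_loader_cfg.dataset.pipeline.insert(1, corruption_trans)
        runner.test_loop.dataloader = runner.build_dataloader(test_loader_cfg)
        eval_results = runner.test()
        aggregated_results[corruption][severity] = eval_results
        dump(aggregated_results, 'robustness_results.pkl')  # hypothetical output path
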
192 changes: 21 additions & 171 deletions tools/analysis_tools/test_robustness.py
@@ -4,93 +4,16 @@
import os
import os.path as osp

import mmcv
import torch
# TODO need refactor
from mmcv.runner import MMDataParallel, wrap_fp16_model
from mmengine.config import Config, DictAction
from mmengine.dist import get_dist_info, init_dist
from mmengine.fileio import dump, load
from mmengine.model import MMDistributedDataParallel
from mmengine.runner import load_checkpoint
from mmengine.utils import is_str
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from mmengine.dist import get_dist_info
from mmengine.fileio import dump
from mmengine.runner import Runner

from mmdet.registry import RUNNERS
from mmdet.utils import register_all_modules
from tools.analysis_tools.robustness_eval import get_results


def coco_eval_with_return(result_files,
                          result_types,
                          coco,
                          max_dets=(100, 300, 1000)):
    for res_type in result_types:
        assert res_type in ['proposal', 'bbox', 'segm', 'keypoints']

    if is_str(coco):
        coco = COCO(coco)
    assert isinstance(coco, COCO)

    eval_results = {}
    for res_type in result_types:
        result_file = result_files[res_type]
        assert result_file.endswith('.json')

        coco_dets = coco.loadRes(result_file)
        img_ids = coco.getImgIds()
        iou_type = 'bbox' if res_type == 'proposal' else res_type
        cocoEval = COCOeval(coco, coco_dets, iou_type)
        cocoEval.params.imgIds = img_ids
        if res_type == 'proposal':
            cocoEval.params.useCats = 0
            cocoEval.params.maxDets = list(max_dets)
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()
        if res_type == 'segm' or res_type == 'bbox':
            metric_names = [
                'AP', 'AP50', 'AP75', 'APs', 'APm', 'APl', 'AR1', 'AR10',
                'AR100', 'ARs', 'ARm', 'ARl'
            ]
            eval_results[res_type] = {
                metric_names[i]: cocoEval.stats[i]
                for i in range(len(metric_names))
            }
        else:
            eval_results[res_type] = cocoEval.stats

    return eval_results
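
For reference, the removed coco_eval_with_return helper above evaluates dumped JSON detections with pycocotools and returns the metric dict instead of only printing it. A hypothetical call, relying on the definition above (file paths are placeholders), would look like:

# Hypothetical usage of the removed helper; both paths are placeholders.
eval_results = coco_eval_with_return(
    result_files={'bbox': 'work_dirs/results.bbox.json'},
    result_types=['bbox'],
    coco='data/coco/annotations/instances_val2017.json')
print(eval_results['bbox']['AP'], eval_results['bbox']['AP50'])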


def voc_eval_with_return(result_file,
                         dataset,
                         iou_thr=0.5,
                         logger='print',
                         only_ap=True):
    det_results = load(result_file)
    annotations = [dataset.get_ann_info(i) for i in range(len(dataset))]
    if hasattr(dataset, 'year') and dataset.year == 2007:
        dataset_name = 'voc07'
    else:
        dataset_name = dataset.CLASSES
    mean_ap, eval_results = eval_map(
        det_results,
        annotations,
        scale_ranges=None,
        iou_thr=iou_thr,
        dataset=dataset_name,
        logger=logger)

    if only_ap:
        eval_results = [{
            'ap': eval_results[i]['ap']
        } for i in range(len(eval_results))]

    return mean_ap, eval_results


def parse_args():
    parser = argparse.ArgumentParser(description='MMDet test detector')
    parser.add_argument('config', help='test config file path')
@@ -110,6 +33,9 @@ def parse_args():
            'spatter', 'saturate'
        ],
        help='corruptions')
    parser.add_argument(
        '--work-dir',
        help='the directory to save the file containing evaluation metrics')
    parser.add_argument(
        '--severities',
        type=int,
@@ -196,7 +122,9 @@ def main():
    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    # load config
    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

@@ -286,104 +214,26 @@
                    severity=corruption_severity)
                # TODO: hard coded "1", we assume that the first step is
                # loading images, which needs to be fixed in the future
                test_data_cfg['pipeline'].insert(1, corruption_trans)
                test_loader_cfg.dataset.pipeline.insert(1, corruption_trans)

            # print info
            print(f'\nTesting {corruption} at severity {corruption_severity}')
            test_loader = runner.build_dataloader(test_loader_cfg)

            # build the dataloader
            # TODO: support multiple images per gpu
            # (only minor changes are needed)
            dataset = build_dataset(test_data_cfg)
            data_loader = build_dataloader(
                dataset,
                samples_per_gpu=1,
                workers_per_gpu=args.workers,
                dist=distributed,
                shuffle=False)
            runner.test_loop.dataloader = test_loader
            # runner._test_evaluator.metrics
            # set random seeds
            if args.seed is not None:
                runner.set_randomness(args.seed)

            # build the model and load checkpoint
            cfg.model.train_cfg = None
            model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
            fp16_cfg = cfg.get('fp16', None)
            if fp16_cfg is not None:
                wrap_fp16_model(model)
            checkpoint = load_checkpoint(
                model, args.checkpoint, map_location='cpu')
            # old versions did not save class info in checkpoints,
            # this workaround is for backward compatibility
            if 'CLASSES' in checkpoint.get('meta', {}):
                model.CLASSES = checkpoint['meta']['CLASSES']
            else:
                model.CLASSES = dataset.CLASSES

            if not distributed:
                # TODO
                model = MMDataParallel(model, device_ids=[0])
                show_dir = args.show_dir
                if show_dir is not None:
                    show_dir = osp.join(show_dir, corruption)
                    show_dir = osp.join(show_dir, str(corruption_severity))
                    if not osp.exists(show_dir):
                        os.makedirs(show_dir)
                outputs = single_gpu_test(model, data_loader, args.show,
                                          show_dir, args.show_score_thr)
            else:
                model = MMDistributedDataParallel(
                    model.cuda(),
                    device_ids=[torch.cuda.current_device()],
                    broadcast_buffers=False)
                outputs = multi_gpu_test(model, data_loader, args.tmpdir)
            # print info
            print(f'\nTesting {corruption} at severity {corruption_severity}')

            if args.out and rank == 0:
            eval_results = runner.test()
            if args.out:
                eval_results_filename = (
                    osp.splitext(args.out)[0] + '_results' +
                    osp.splitext(args.out)[1])
                dump(outputs, args.out)
                eval_types = args.eval
                if cfg.dataset_type == 'VOCDataset':
                    if eval_types:
                        for eval_type in eval_types:
                            if eval_type == 'bbox':
                                # TODO
                                test_dataset = mmcv.runner.obj_from_dict(
                                    cfg.data.test, datasets)
                                logger = 'print' if args.summaries else None
                                mean_ap, eval_results = \
                                    voc_eval_with_return(
                                        args.out, test_dataset,
                                        args.iou_thr, logger)
                                aggregated_results[corruption][
                                    corruption_severity] = eval_results
                            else:
                                print('\nOnly "bbox" evaluation \
                                    is supported for pascal voc')
                else:
                    if eval_types:
                        print(f'Starting evaluate {" and ".join(eval_types)}')
                        if eval_types == ['proposal_fast']:
                            result_file = args.out
                        else:
                            if not isinstance(outputs[0], dict):
                                result_files = dataset.results2json(
                                    outputs, args.out)
                            else:
                                for name in outputs[0]:
                                    print(f'\nEvaluating {name}')
                                    outputs_ = [out[name] for out in outputs]
                                    result_file = args.out + f'.{name}'
                                    result_files = dataset.results2json(
                                        outputs_, result_file)
                        eval_results = coco_eval_with_return(
                            result_files, eval_types, dataset.coco)
                        aggregated_results[corruption][
                            corruption_severity] = eval_results
                    else:
                        print('\nNo task was selected for evaluation;'
                              '\nUse --eval to select a task')

                # save results after each evaluation
                aggregated_results[corruption][
                    corruption_severity] = eval_results
                dump(aggregated_results, eval_results_filename)

    rank, _ = get_dist_info()
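
A note on the hard-coded index 1 flagged by the TODO in the hunk above: the corruption transform is spliced in right after the image-loading step of the test pipeline. With a typical mmdet 3.x test pipeline (the surrounding transform names below are illustrative assumptions, not taken from this commit) the insertion behaves like this:

# Illustration of the pipeline insertion done in the diff; the surrounding
# transforms are an assumed, typical mmdet 3.x test pipeline.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
    dict(type='PackDetInputs'),
]
corruption_trans = dict(type='Corrupt', corruption='gaussian_noise', severity=3)
# Index 1 assumes the first transform loads the image, so the corruption is
# applied to the raw image before any resizing or packing.
test_pipeline.insert(1, corruption_trans)
# test_pipeline is now:
#   LoadImageFromFile -> Corrupt -> Resize -> PackDetInputs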
