diff --git a/mmdet/core/evaluation/coco_utils.py b/mmdet/core/evaluation/coco_utils.py index f6f5ac0a5dc..ef44940366e 100644 --- a/mmdet/core/evaluation/coco_utils.py +++ b/mmdet/core/evaluation/coco_utils.py @@ -1,12 +1,19 @@ +import itertools + import mmcv import numpy as np from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval +from terminaltables import AsciiTable from .recall import eval_recalls -def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000)): +def coco_eval(result_files, + result_types, + coco, + max_dets=(100, 300, 1000), + classwise=False): for res_type in result_types: assert res_type in [ 'proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints' @@ -43,6 +50,36 @@ def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000)): cocoEval.accumulate() cocoEval.summarize() + if classwise: + # Compute per-category AP + # from https://github.com/facebookresearch/detectron2/blob/03064eb5bafe4a3e5750cc7a16672daf5afe8435/detectron2/evaluation/coco_evaluation.py#L259-L283 # noqa + precisions = cocoEval.eval['precision'] + catIds = coco.getCatIds() + # precision has dims (iou, recall, cls, area range, max dets) + assert len(catIds) == precisions.shape[2] + + results_per_category = [] + for idx, catId in enumerate(catIds): + # area range index 0: all area ranges + # max dets index -1: typically 100 per image + nm = coco.loadCats(catId)[0] + precision = precisions[:, :, idx, 0, -1] + precision = precision[precision > -1] + ap = np.mean(precision) if precision.size else float('nan') + results_per_category.append( + ('{}'.format(nm['name']), + '{:0.3f}'.format(float(ap * 100)))) + + N_COLS = min(6, len(results_per_category) * 2) + results_flatten = list(itertools.chain(*results_per_category)) + headers = ['category', 'AP'] * (N_COLS // 2) + results_2d = itertools.zip_longest( + *[results_flatten[i::N_COLS] for i in range(N_COLS)]) + table_data = [headers] + table_data += [result for result in results_2d] + table = AsciiTable(table_data) + print(table.table) + def fast_eval_recall(results, coco, diff --git a/tools/coco_eval.py b/tools/coco_eval.py index 65e114ca280..bc3c96b3cfb 100644 --- a/tools/coco_eval.py +++ b/tools/coco_eval.py @@ -20,8 +20,10 @@ def main(): nargs='+', default=[100, 300, 1000], help='proposal numbers, only used for recall evaluation') + parser.add_argument( + '--classwise', action='store_true', help='whether eval class wise ap') args = parser.parse_args() - coco_eval(args.result, args.types, args.ann, args.max_dets) + coco_eval(args.result, args.types, args.ann, args.max_dets, args.classwise) if __name__ == '__main__':