Skip to content

Commit

Permalink
Support viewing AP for each class (#1549)
Browse files Browse the repository at this point in the history
* also support viewing AP

* change string format

* eval class_wise in coco_eval

* reformat

* class_wise API from detectron

* reformat

* change code source

* reformat, use terminaltable
  • Loading branch information
ZwwWayne authored and hellock committed Oct 25, 2019
1 parent 1fe3e7d commit 1f3e273
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
39 changes: 38 additions & 1 deletion mmdet/core/evaluation/coco_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import itertools

import mmcv
import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from terminaltables import AsciiTable

from .recall import eval_recalls


def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000)):
def coco_eval(result_files,
result_types,
coco,
max_dets=(100, 300, 1000),
classwise=False):
for res_type in result_types:
assert res_type in [
'proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'
Expand Down Expand Up @@ -43,6 +50,36 @@ def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000)):
cocoEval.accumulate()
cocoEval.summarize()

if classwise:
# Compute per-category AP
# from https://github.com/facebookresearch/detectron2/blob/03064eb5bafe4a3e5750cc7a16672daf5afe8435/detectron2/evaluation/coco_evaluation.py#L259-L283 # noqa
precisions = cocoEval.eval['precision']
catIds = coco.getCatIds()
# precision has dims (iou, recall, cls, area range, max dets)
assert len(catIds) == precisions.shape[2]

results_per_category = []
for idx, catId in enumerate(catIds):
# area range index 0: all area ranges
# max dets index -1: typically 100 per image
nm = coco.loadCats(catId)[0]
precision = precisions[:, :, idx, 0, -1]
precision = precision[precision > -1]
ap = np.mean(precision) if precision.size else float('nan')
results_per_category.append(
('{}'.format(nm['name']),
'{:0.3f}'.format(float(ap * 100))))

N_COLS = min(6, len(results_per_category) * 2)
results_flatten = list(itertools.chain(*results_per_category))
headers = ['category', 'AP'] * (N_COLS // 2)
results_2d = itertools.zip_longest(
*[results_flatten[i::N_COLS] for i in range(N_COLS)])
table_data = [headers]
table_data += [result for result in results_2d]
table = AsciiTable(table_data)
print(table.table)


def fast_eval_recall(results,
coco,
Expand Down
4 changes: 3 additions & 1 deletion tools/coco_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ def main():
nargs='+',
default=[100, 300, 1000],
help='proposal numbers, only used for recall evaluation')
parser.add_argument(
'--classwise', action='store_true', help='whether eval class wise ap')
args = parser.parse_args()
coco_eval(args.result, args.types, args.ann, args.max_dets)
coco_eval(args.result, args.types, args.ann, args.max_dets, args.classwise)


if __name__ == '__main__':
Expand Down

0 comments on commit 1f3e273

Please sign in to comment.