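## COCO val2017 evaluation of a pretrained Faster R-CNN

Runs torchvision's COCO-pretrained Faster R-CNN (ResNet-50 FPN) over COCO val2017 and scores the detections with pycocotools. It computes per-category precision, recall and F1 at IoU 0.5, then saves images from the weakest categories in which no ground-truth instance was detected.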

In [3]:
import json
from pathlib import Path

import numpy as np
import torch
import tqdm
import torchvision.datasets as dset
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, fasterrcnn_resnet50_fpn
from torchvision.transforms import ToTensor, Compose
from torchvision.datasets import CocoDetection
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

In [4]:
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Ground truth for COCO val2017. ToTensor converts each PIL image into a float
# C x H x W tensor, which is what the detection model consumes.
coco_val = CocoDetection(
    root="../data/coco/val2017/",
    annFile="../data/coco/annotations/instances_val2017.json",
    transform=ToTensor(),
)

loading annotations into memory...
Done (t=0.64s)
creating index...
index created!


In [6]:
# COCO-pretrained Faster R-CNN with a ResNet-50 FPN backbone, placed on `device`.
model = fasterrcnn_resnet50_fpn(pretrained=True)
model = model.to(device)  # nn.Module.to returns the same (moved) module
# NOTE(review): `optimizer` is not referenced by any other visible cell; this
# notebook only evaluates, so it is likely leftover from a training setup.
params = list(filter(lambda param: param.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(params, momentum=0.9, lr=0.005, weight_decay=0.0005)

In [14]:
# The default collate_fn cannot stack differently-sized images into one tensor, so
# load a single image per batch. validation_loop also takes image_id from y[0] and
# therefore assumes batch_size=1.
coco_val_dl = torch.utils.data.DataLoader(coco_val, num_workers=1, batch_size=1)

In [15]:
def validation_loop(coco_dataloader, model, device=None):
    """Run `model` over a COCO dataloader and collect detections in COCO results format.

    Args:
        coco_dataloader: DataLoader over a torchvision ``CocoDetection`` dataset.
            Because ``image_id``/``category_id`` are read with ``.item()`` and the
            image id is taken from the first target, it is expected to use
            ``batch_size=1`` with the default collate_fn.
        model: detection model returning, per image, a dict with ``labels``,
            ``boxes`` (x1, y1, x2, y2) and ``scores``. It is put in eval mode and
            called under ``torch.no_grad()``.
        device: device the images are moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``(results, counts)``:
        - ``results`` is a list of dicts with ``image_id``, ``category_id``,
          ``bbox`` as ``[x, y, width, height]`` and ``score``, i.e. the layout
          ``COCO.loadRes`` accepts.
        - ``counts`` maps every category id in the dataset's COCO index to the
          number of ground-truth annotations seen for it.

    Images whose target list is empty are skipped *without* running the model, so
    no detections from them appear in ``results``.
    """
    if device is None:
        # Follow the model instead of relying on a global `device` variable.
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # One ground-truth instance counter per category id in the COCO index.
    counts = {cid: 0 for cid in coco_dataloader.dataset.coco.cats}
    results = []
    model.eval()
    with torch.no_grad():
        for X, y in tqdm.tqdm(coco_dataloader):
            # An empty target list has no image_id to attach predictions to, so
            # skip it before paying for a forward pass.
            if not y:
                continue
            image_id = y[0]['image_id'].item()
            for gt in y:
                counts[gt['category_id'].item()] += 1
            for p in model(X.to(device)):
                for label, box, score in zip(p['labels'].tolist(),
                                             p['boxes'].tolist(),
                                             p['scores'].tolist()):
                    xmin, ymin, xmax, ymax = box
                    results.append({
                        'image_id': image_id,
                        'category_id': label,
                        # Corner format -> COCO's [x, y, width, height].
                        'bbox': [xmin, ymin, xmax - xmin, ymax - ymin],
                        'score': score,
                    })
    return results, counts

In [17]:
#results, counts = validation_loop(coco_val_dl, model)

In [12]:
#with open("results.json", "w") as f:
#    json.dump(results, f)
#with open("counts.json", "w") as f:
#    json.dump(counts, f)

In [18]:
# Running validation_loop over all of val2017 is slow, so its output is cached on
# disk: reuse both cache files when present, otherwise compute and write them.
results_path = Path("results.json")
counts_path = Path("counts.json")
if results_path.exists() and counts_path.exists():
    with results_path.open("r") as f:
        results = json.load(f)
    with counts_path.open("r") as f:
        counts = json.load(f)
else:
    results, counts = validation_loop(coco_val_dl, model)
    with results_path.open("w") as f:
        json.dump(results, f)
    with counts_path.open("w") as f:
        json.dump(counts, f)
    # JSON stores object keys as strings; normalise so lookups such as
    # counts[str(cid)] behave the same whether or not the cache was hit.
    counts = {str(cid): n for cid, n in counts.items()}

In [19]:
# Distinct image ids and category ids that received at least one detection.
img_ids = {res['image_id'] for res in results}
cat_ids = {res['category_id'] for res in results}

In [20]:
# Wrap the cached detections in a COCO object built against the val2017 ground truth,
# so COCOeval can compare the two.
coco_res = coco_val.coco.loadRes(resFile="results.json")

Loading and preparing results...
DONE (t=1.69s)
creating index...
index created!


In [21]:
# Box-level evaluator: ground truth first, detections second, IoU computed on bboxes.
coco_eval = COCOeval(coco_val.coco, coco_res, 'bbox')

In [22]:
# Evaluate on every ground-truth image, not only those that received a detection.
# Restricting to images with detections silently drops ground-truth objects on
# images where the model found nothing, which inflates recall. This assumes
# results.json was produced over the whole val2017 split.
coco_eval.params.imgIds = sorted(coco_val.coco.getImgIds())

In [24]:
def coco_summarize(coco_eval, ap=1, iouThr=None, areaRng='all', maxDets=100):
    """Return one summary number from an accumulated ``COCOeval``.

    Mirrors the selection logic of pycocotools' ``COCOeval.summarize`` but returns
    the value instead of printing it.

    Args:
        coco_eval: a ``COCOeval`` on which ``accumulate()`` has been run.
        ap: 1 for mean precision (over recall thresholds, i.e. AP); anything else
            for mean recall.
        iouThr: restrict to this IoU threshold. ``None`` averages over all of
            ``coco_eval.params.iouThrs``. Matching uses ``np.isclose``, so it works
            whether ``iouThrs`` is a list or an ndarray and tolerates linspace
            rounding.
        areaRng: label in ``coco_eval.params.areaRngLbl`` to select.
        maxDets: value in ``coco_eval.params.maxDets`` to select.

    Returns:
        The mean over all entries > -1, or -1 when there are none (e.g. no ground
        truth for the selection, or no matching area label / maxDets value).
    """
    p = coco_eval.params
    aind = [i for i, lbl in enumerate(p.areaRngLbl) if lbl == areaRng]
    mind = [i for i, m in enumerate(p.maxDets) if m == maxDets]
    # precision: [T x R x K x A x M]; recall: [T x K x A x M]
    s = coco_eval.eval['precision'] if ap == 1 else coco_eval.eval['recall']
    if iouThr is not None:
        iou_thrs = np.asarray(p.iouThrs, dtype=float)
        s = s[np.where(np.isclose(iou_thrs, iouThr))[0]]
    # Area range and max-detections are the last two axes of both arrays.
    s = s[..., aind, mind]
    valid = s[s > -1]
    if valid.size == 0:
        return -1
    return np.mean(valid)

In [None]:
# Evaluate all detected categories in one pass with a single area range covering
# every object, up to 1000 detections per image, at IoU 0.5.
coco_eval.params.areaRng = [[0, 1e8]]
coco_eval.params.areaRngLbl = ['all']
coco_eval.params.maxDets = [1000]
# Keep iouThrs an ndarray, matching pycocotools' default and the later cells, so
# element-wise threshold comparisons against it keep working.
coco_eval.params.iouThrs = np.array([0.5])
coco_eval.params.catIds = list(cat_ids)
coco_eval.evaluate()
coco_eval.accumulate()

In [26]:
%%capture
# Per-category precision (AP averaged over recall thresholds) and recall at IoU 0.5
# with at most 100 detections per image. Each category is evaluated on its own, so
# the summaries describe just that class.
coco_eval.params.iouThrs = np.array([0.5])
coco_eval.params.maxDets = [100]
precisions, recalls = [], []
for cid in cat_ids:
    coco_eval.params.catIds = [cid]
    coco_eval.evaluate()
    coco_eval.accumulate()
    precisions.append(coco_summarize(coco_eval, ap=1))
    recalls.append(coco_summarize(coco_eval, ap=0))
precisions, recalls = np.array(precisions), np.array(recalls)

In [27]:
# Harmonic mean of each category's precision summary and recall. Degenerate cases
# are handled explicitly instead of producing nan/inf with a RuntimeWarning:
#   * precision + recall == 0 -> F1 = 0 (the usual convention)
#   * either value is -1 (coco_summarize's "no data" sentinel) -> F1 = nan, which
#     np.argsort places last, so such categories never rank as the worst.
f1_scores = np.full(precisions.shape, np.nan)
defined = (precisions >= 0) & (recalls >= 0)
f1_scores[defined] = 0.0
pr_sum = precisions + recalls
np.divide(2 * precisions * recalls, pr_sum, out=f1_scores, where=defined & (pr_sum > 0))

In [394]:
k = 20
# Indices (into f1_scores) of the k lowest-F1 categories; argsort is ascending.
k_lowest = np.argsort(f1_scores)[:k]
# cat_ids is the same, unmodified set iterated by the per-category loop, so
# list(cat_ids) lines up index-for-index with f1_scores.
bad_cats = np.array(list(cat_ids))[k_lowest]
bad_cat_dict = dict((coco_val.coco.cats[cid]['name'], cid) for cid in bad_cats)
header = f"{'Category': >16}\t{'F1': >5}   \tInstances\n"
print(header)
for low, cid in zip(k_lowest, bad_cats):
    name = coco_val.coco.cats[cid]['name']
    # counts came from JSON, so its keys are strings.
    print(f"{name: >16}", f"\t{f1_scores[low]:0.04f}  \t{counts[str(cid)]}")

        Category	   F1   	Instances

      hair drier 	0.1164  	11
           spoon 	0.3449  	253
         handbag 	0.3670  	540
           apple 	0.3797  	239
           knife 	0.3811  	326
        backpack 	0.3846  	371
        scissors 	0.4088  	36
            book 	0.4106  	1161
           bench 	0.4299  	413
      toothbrush 	0.4658  	57
          carrot 	0.4996  	371
         hot dog 	0.5134  	127
         toaster 	0.5216  	9
          banana 	0.5232  	379
          orange 	0.5259  	287
           chair 	0.5295  	1791
    dining table 	0.5377  	697
            skis 	0.5421  	241
        broccoli 	0.5444  	316
          remote 	0.5457  	283


In [83]:
%%capture
# For each low-F1 category, collect (sorted) ids of images that produced an
# evaluation record, i.e. a non-None entry in evalImgs, at IoU 0.5 with at most 100
# detections per image.
coco_eval.params.iouThrs = np.array([0.5])
coco_eval.params.maxDets = [100]

bad_cat_imgs = []
for cid in bad_cats:
    coco_eval.params.catIds = [cid]
    coco_eval.evaluate()
    coco_eval.accumulate()
    cat_img_ids = {record['image_id'] for record in coco_eval.evalImgs if record}
    bad_cat_imgs.append(sorted(cat_img_ids))

In [100]:
# NOTE(review): this aliases annotation id 0 to annotation 1768 inside the
# ground-truth index, mutating shared GT state in place. Only the `anns` dict is
# touched; no other COCO index is updated. Any later code reading
# coco_val.coco.anns sees the alias. The purpose is not evident from this notebook;
# confirm it is still needed or remove it.
coco_val.coco.anns[0] = coco_val.coco.anns[1768]

In [177]:
def get_img_from_id(iid, dataset=None):
    """Fetch an image and its annotations by COCO image id.

    Args:
        iid: COCO image id; must be present in ``dataset.ids``.
        dataset: CocoDetection-style dataset exposing ``ids`` and ``__getitem__``
            that returns ``(image_tensor, annotations)`` with a C x H x W image.
            Defaults to the notebook's ``coco_val``.

    Returns:
        ``(img, anns)``, where ``img`` is a contiguous, writable H x W x C NumPy copy
        of the image (safe to draw on) and ``anns`` is the dataset's target.

    Raises:
        ValueError: if ``iid`` is not in ``dataset.ids``.
    """
    if dataset is None:
        dataset = coco_val
    # list.index avoids rebuilding a NumPy array of every id on each call.
    ind = list(dataset.ids).index(iid)
    img, ann = dataset[ind]
    # CHW -> HWC for plotting/OpenCV. No squeeze(): dataset items have no batch dim,
    # and squeezing would also drop a height or width of 1.
    return img.permute(1, 2, 0).numpy().copy(), ann

In [178]:
# Every ground-truth image containing each low-F1 category, deduplicated (catToImgs
# may list an image once per annotation). This overwrites the evalImgs-based
# bad_cat_imgs computed in an earlier cell.
bad_cat_imgs = []
for cat in bad_cats:
    bad_cat_imgs.append(list(set(coco_val.coco.catToImgs[cat])))

In [179]:
# Group detections by image id ({image_id: [result, ...]}), keeping their original order.
results_by_img = {}
for res in results:
    results_by_img.setdefault(res['image_id'], []).append(res)

In [395]:
def _draw_category_boxes(img, anns, cid, color, thickness):
    """Draw each [x, y, w, h] bbox in `anns` whose category_id equals `cid` onto `img`."""
    for ann in anns:
        if ann['category_id'] == cid:
            bx, by, w, h = np.array(ann['bbox'], dtype=int)
            img = cv.rectangle(img, (bx, by), (bx + w, by + h), color=color, thickness=thickness)
    return img


def write_bad_cat_images(cat_name):
    """Show and save val images where no ground-truth `cat_name` instance was matched.

    For every val image containing `cat_name`, this draws ground-truth boxes of the
    category ([1, 0, 0], thickness 3) and predicted boxes of the category
    ([0, 1, 1], thickness 2). Images for which ``coco_eval.evaluateImg`` reports any
    ground-truth match are skipped. Each remaining image is displayed and written to
    ``../data/coco_val_results/<cat_name>/<image_id>.jpg``.

    Relies on notebook globals ``coco_val``, ``coco_eval``, ``results_by_img``,
    ``get_img_from_id``, ``cv`` (OpenCV) and ``plt`` (matplotlib.pyplot).
    NOTE(review): ``cv`` and ``plt`` are not imported in the imports cell; add them there.

    Side effect: re-runs ``coco_eval.evaluate()`` restricted to this category, which
    leaves ``coco_eval.params.catIds == [cid]``.

    Raises:
        KeyError: if `cat_name` is not a category name in the COCO index.
    """
    # Resolve against the full COCO category index, not just the low-F1 subset, so
    # categories outside that table work too.
    name_to_cid = {info['name']: c for c, info in coco_val.coco.cats.items()}
    cid = name_to_cid[cat_name]

    outdir = Path(f"../data/coco_val_results/{cat_name}")
    # parents=True: the results root may not exist yet.
    outdir.mkdir(parents=True, exist_ok=True)

    # evaluateImg reads per-(image, category) state built by the most recent
    # evaluate() call; make sure that call covered this category.
    coco_eval.params.catIds = [cid]
    coco_eval.evaluate()

    for iid in sorted(set(coco_val.coco.catToImgs[cid])):
        metrics = coco_eval.evaluateImg(iid, cid, [0, 1e6], 1000)
        # None means no GT and no detections for (iid, cid) in the evaluated set
        # (e.g. iid is outside coco_eval.params.imgIds), so there is nothing to judge.
        if metrics is None or metrics['gtMatches'].any():
            continue
        img, anns = get_img_from_id(iid)
        img = _draw_category_boxes(img, anns, cid, color=[1, 0, 0], thickness=3)
        # Images without any detection have no entry in results_by_img.
        img = _draw_category_boxes(img, results_by_img.get(iid, []), cid,
                                   color=[0, 1, 1], thickness=2)
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.imshow(img)
        ax.axis("off")
        plt.show()
        plt.close(fig)
        img_u8 = np.round(img * 255).astype(np.uint8)
        # img is RGB; OpenCV's imwrite expects BGR channel order.
        cv.imwrite(str(outdir / f"{iid}.jpg"), cv.cvtColor(img_u8, cv.COLOR_RGB2BGR))

In [397]:
# Dump visualisations of fully-missed instances for a hand-picked set of categories.
# NOTE(review): "couch" is not in the low-F1 table printed above (bad_cat_dict), so
# write_bad_cat_images must be able to resolve categories outside that table.
cats_to_inspect = (
    "apple", "banana", "book", "couch", "dining table",
    "handbag", "backpack", "bench", "chair", "hair drier",
)
for cat_name in cats_to_inspect:
    write_bad_cat_images(cat_name)