## Find optimal anchor boxes for YOLOv2

Extracts the widths and heights of all bounding boxes in the COCO training dataset (normalised by the image size) and runs k-means clustering with an IoU-based distance to find the optimal anchor boxes.

In [1]:
%matplotlib inline
import matplotlib as mpl
from matplotlib import pyplot as plt
import json
from pycocotools.coco import COCO
from anchor_boxes import *

In [6]:
def load_bbox_wh_coco(data_dir="../../cocodata/",data_type = "train2017"):
    """Return the normalised width and height of every bounding box in a COCO split.

    Each row of the returned array is ``(width, height)`` of one annotation's
    bounding box divided by the width and height of the image it belongs to,
    so both values are measured in units of the image size.

    :arg data_dir: directory containing the COCO data; the annotation file is
        read from ``<data_dir>/annotations/instances_<data_type>.json``
    :arg data_type: type of data, e.g. train2017
    :returns: float array of shape ``(n_annotations, 2)``; shape ``(0, 2)``
        when the split contains no annotations
    """
    ann_file = f"{data_dir}/annotations/instances_{data_type}.json"
    coco = COCO(ann_file)
    bbox_wh = []
    # No category filter: every image of the split is used.
    for img_id in coco.getImgIds():
        img_info = coco.loadImgs([img_id])[0]
        img_w, img_h = img_info['width'], img_info['height']
        # Load all annotations of this image in a single call.
        for annotation in coco.loadAnns(coco.getAnnIds(imgIds=[img_id])):
            bbox = annotation['bbox']
            # bbox[2] / bbox[3] are the box width / height in pixels.
            bbox_wh.append((bbox[2] / img_w, bbox[3] / img_h))

    # reshape keeps the (n, 2) layout even when no annotations were found.
    return np.asarray(bbox_wh, dtype=float).reshape(-1, 2)

def visualise_anchor_boxes(anchor_boxes):
    """Print out anchor boxes and plot them centred on a common origin.

    :arg anchor_boxes: sequence of ``(width, height)`` pairs, one per anchor box,
        in units of the image size (the plot spans -0.5..0.5 on both axes)
    """
    print ("*** Anchor boxes*** ")
    for j, anchor_box in enumerate(anchor_boxes):
        print(f"{j} : {anchor_box[0]:5.2f} x {anchor_box[1]:5.2f}")

    # Create exactly one figure. Calling plt.clf() before plt.figure() would
    # clear (or, if none is open, create) a separate figure that the inline
    # backend then shows as an extra empty plot.
    _, ax = plt.subplots(figsize=(10, 10))
    ax.set_xlim(-0.5, +0.5)
    ax.set_ylim(-0.5, +0.5)
    for j, anchor_box in enumerate(anchor_boxes):
        w, h = anchor_box[0], anchor_box[1]
        # Centre every box on the origin so that their shapes can be compared.
        x, y = -0.5 * w, -0.5 * h
        ax.add_patch(mpl.patches.Rectangle((x, y), w, h, linewidth=2,
                                           edgecolor='red', facecolor='blue',
                                           alpha=0.2))
        ax.text(x + 0.002, y + 0.002, f"{j} : {w:5.2f} x {h:5.2f}")
    plt.show()

def save_anchor_boxes(anchor_boxes, filename):
    """Save anchor boxes to a JSON file as a list of ``{'width', 'height'}`` dicts.

    :arg anchor_boxes: iterable of ``(width, height)`` pairs, one per anchor box;
        the values may be Python numbers or NumPy scalars
    :arg filename: name of json file to write (overwritten if it exists)
    """
    # float() turns NumPy scalars (e.g. np.float32, which json cannot
    # serialise) into plain Python floats.
    anchor_box_list = [
        {'width': float(anchor_box[0]), 'height': float(anchor_box[1])}
        for anchor_box in anchor_boxes
    ]
    with open(filename, 'w', encoding='utf8') as f:
        json.dump(anchor_box_list, f)

##############################################################################################

# number of anchor boxes (= number of centroids for k-means clustering)
n_anchors=5

# normalised (width, height) of every bounding box, keyed by dataset name
bbox_wh = {'coco':load_bbox_wh_coco(data_dir="../../cocodata/",data_type = "train2017")}

for dataset, wh in bbox_wh.items():
    print (f"======== {dataset} ========")

    # k-means clustering using IoU distance
    kmeans = KMeans(distanceIoU,n_centroid=n_anchors)
    assignments, anchor_boxes = kmeans.cluster(wh,maxiter=200)

    # visualise anchor boxes
    visualise_anchor_boxes(anchor_boxes)

    # visualise clustered data
    plot_kmeans_data(wh, assignments, anchor_boxes,label_x='width',label_y='height')

    # save anchor boxes to json file; keep the plain file name for a single
    # dataset, but write one file per dataset when there are several so that
    # later datasets do not overwrite the results of earlier ones
    if len(bbox_wh) == 1:
        out_file = "anchor_boxes.json"
    else:
        out_file = f"anchor_boxes_{dataset}.json"
    save_anchor_boxes(anchor_boxes, out_file)

loading annotations into memory...
Done (t=11.71s)
creating index...
index created!


KeyboardInterrupt: 