In [1]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms
from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1
from utils.timer import Timer

%matplotlib inline

In [2]:
# Keep torch on a single CPU thread for this notebook.
torch.set_num_threads(1)

# Checkpoint file-name templates, keyed by backbone name. The '%d' slot is
# filled with the training iteration count when the model path is built.
NETS = dict(
    vgg16=('vgg16_faster_rcnn_iter_%d.pth',),
    res101=('res101_faster_rcnn_iter_%d.pth',),
)

# Training-set identifiers, keyed by dataset name; used as a path component
# when locating the saved model.
DATASETS = dict(
    pascal_voc=('voc_2007_trainval',),
    pascal_voc_0712=('voc_2007_trainval+voc_2012_trainval',),
)

# Input frames and output figures, both relative to cfg.DATA_DIR.
# (IMAGE_DIRECTORY was previously also pointed at 'demo'.)
IMAGE_DIRECTORY = 'frames/gangnam'
SAVE_DIRECTORY = 'output/gangnam'

In [3]:
def vis_detections(im, class_name, dets, thresh=0.5, save_fig=False, fname=None):
    """Draw the detections of one class on an image, then show or save it.

    Args:
        im: Image array indexed as [row, col, channel]. The channel axis is
            reversed before display (assumed BGR, as loaded by cv2.imread in
            detect_person, so that imshow receives RGB).
        class_name: Label text drawn next to every box.
        dets: 2-D array with one detection per row; the first four columns
            are used as (x1, y1, x2, y2) and the last column as the score.
        thresh: Minimum score a detection needs in order to be drawn.
        save_fig: If True, write the figure to
            ``cfg.DATA_DIR/SAVE_DIRECTORY/fname`` and close it; otherwise
            display it with ``plt.show()``.
        fname: Output file name. Required when ``save_fig`` is True.

    Raises:
        ValueError: If ``save_fig`` is True and ``fname`` is None.
    """
    # Fail early, before any figure is allocated, instead of letting
    # os.path.join raise an opaque TypeError on a None component.
    if save_fig and fname is None:
        raise ValueError('fname is required when save_fig=True')

    inds = np.where(dets[:, -1] >= thresh)[0]

    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    # Operate on the figure/axes created above rather than on pyplot's
    # implicit "current figure" state.
    ax.axis('off')
    fig.tight_layout()
    if save_fig:
        save_file = os.path.join(cfg.DATA_DIR, SAVE_DIRECTORY, fname)
        out_dir = os.path.dirname(save_file)
        if out_dir:
            # A fresh checkout may not have the output directory yet.
            os.makedirs(out_dir, exist_ok=True)
        try:
            fig.savefig(save_file)
        finally:
            # Always release the figure; this function is called once per
            # frame, so a failed save must not leave figures accumulating.
            plt.close(fig)
    else:
        plt.show()


In [4]:
def detect_person(net, image_name, save_fig=False, conf_thresh=0.8, nms_thresh=0.3):
    """Run the detector on one image and visualize its 'Person' detections.

    The image is read from ``cfg.DATA_DIR/IMAGE_DIRECTORY/image_name`` and
    passed to ``im_detect``. The boxes and scores at class index 15 are then
    filtered with ``nms`` and handed to ``vis_detections``.

    Args:
        net: Network object passed to ``im_detect``.
        image_name: File name of the image inside the image directory; it is
            also forwarded as the output file name for ``vis_detections``.
        save_fig: Forwarded to ``vis_detections`` (save instead of show).
        conf_thresh: Minimum score for a detection to be drawn.
        nms_thresh: Threshold passed to ``nms``.

    Returns:
        The ``(scores, boxes)`` pair exactly as returned by ``im_detect``
        (i.e. before class selection and NMS).

    Raises:
        OSError: If the image file cannot be read.
    """
    # Load the image
    im_file = os.path.join(cfg.DATA_DIR, IMAGE_DIRECTORY, image_name)
    im = cv2.imread(im_file)
    if im is None:
        # cv2.imread signals a missing or undecodable file by returning None
        # rather than raising; surface that here instead of deep in im_detect.
        raise OSError('Could not read image: {:s}'.format(im_file))

    # Run detection and time it
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time(), boxes.shape[0]))

    # Presumably 'person' in the 21-class PASCAL VOC ordering (background at
    # index 0) -- confirm against the class list the model was trained with.
    cls_ind = 15
    cls = 'Person'

    # boxes holds four coordinates per class side by side; pick this class's
    # columns and append its score column to form (x1, y1, x2, y2, score) rows.
    cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
    cls_scores = scores[:, cls_ind]
    dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32)
    keep = nms(torch.from_numpy(dets), nms_thresh)
    dets = dets[keep.numpy(), :]
    vis_detections(im, cls, dets, thresh=conf_thresh, save_fig=save_fig, fname=image_name)
    return scores, boxes

In [5]:
demonet = 'res101' # Network to use [vgg16 res101]
dataset = 'pascal_voc_0712' # Trained dataset [pascal_voc pascal_voc_0712]

cfg.TEST.HAS_RPN = True  # Use RPN for proposals
cfg.TEST.RPN_POST_NMS_TOP_N = 300 # Paper uses 2000 region proposals

# Checkpoint path: ../output/<net>/<train set>/default/<template % iterations>.
# The iteration count is 70000 for 'pascal_voc' and 110000 otherwise.
saved_model = os.path.join('../output', demonet, DATASETS[dataset][0], 'default',
                          NETS[demonet][0] %(70000 if dataset == 'pascal_voc' else 110000))

# Fail with a clear message before building the network if the checkpoint
# is missing.
if not os.path.isfile(saved_model):
    raise IOError('{:s} not found.'.format(saved_model))

# Build the architecture that matches `demonet`, so the weights loaded below
# belong to the same backbone that the checkpoint path was built for.
if demonet == 'vgg16':
    net = vgg16()
elif demonet == 'res101':
    net = resnetv1(num_layers=101)
else:
    raise NotImplementedError('Unsupported network: {:s}'.format(demonet))
net.create_architecture(21, tag='default', anchor_scales=[8, 16, 32])

# map_location keeps every tensor on the storage it was deserialized to (CPU),
# so a checkpoint saved from GPU can be loaded on a CPU-only machine.
net.load_state_dict(torch.load(saved_model, map_location=lambda storage, loc: storage))

net.eval()
net._device = 'cpu'
net.to(net._device)

print('Loaded network {:s}'.format(saved_model))


Loaded network ../output/res101/voc_2007_trainval+voc_2012_trainval/default/res101_faster_rcnn_iter_110000.pth


In [8]:
# Process only every `frame_skip`-th file of the frame directory.
frame_skip = 12

# List the frame directory, order the names, then take a strided slice that
# starts at the ninth entry (index 8).
im_names = os.listdir(os.path.join(cfg.DATA_DIR, IMAGE_DIRECTORY))
im_names = sorted(im_names)
im_names = im_names[8::frame_skip]

# Detect and save one annotated figure per selected frame; the name is
# printed first so the timing line that follows can be attributed to it.
for im_name in im_names:
    print(im_name)
    scores, boxes = detect_person(net, im_name, save_fig=True)

frame_000009.jpg
Detection took 25.459s for 300 object proposals
frame_000021.jpg
Detection took 31.218s for 300 object proposals
frame_000033.jpg
Detection took 30.742s for 300 object proposals
frame_000045.jpg
Detection took 31.278s for 300 object proposals
frame_000057.jpg
Detection took 28.747s for 265 object proposals
frame_000069.jpg
Detection took 25.946s for 229 object proposals
frame_000081.jpg
Detection took 24.823s for 213 object proposals
frame_000093.jpg
Detection took 24.263s for 195 object proposals
frame_000105.jpg
Detection took 25.569s for 225 object proposals
frame_000117.jpg
Detection took 29.420s for 275 object proposals
frame_000129.jpg
Detection took 30.478s for 291 object proposals
frame_000141.jpg
Detection took 19.718s for 141 object proposals
frame_000153.jpg
Detection took 17.986s for 130 object proposals
frame_000165.jpg
Detection took 16.183s for 109 object proposals
frame_000177.jpg
Detection took 16.844s for 109 object proposals
frame_000189.jpg
Detectio

Detection took 29.995s for 300 object proposals
frame_001533.jpg
Detection took 30.472s for 300 object proposals
frame_001545.jpg
Detection took 30.454s for 300 object proposals
frame_001557.jpg
Detection took 29.544s for 300 object proposals
frame_001569.jpg
Detection took 29.880s for 300 object proposals
frame_001581.jpg
Detection took 30.310s for 300 object proposals
frame_001593.jpg
Detection took 30.912s for 300 object proposals
frame_001605.jpg
Detection took 30.539s for 300 object proposals
frame_001617.jpg
Detection took 29.973s for 300 object proposals
frame_001629.jpg
Detection took 29.853s for 300 object proposals
frame_001641.jpg
Detection took 31.814s for 300 object proposals
frame_001653.jpg
Detection took 36.447s for 300 object proposals
frame_001665.jpg
Detection took 37.786s for 297 object proposals
frame_001677.jpg
Detection took 35.355s for 267 object proposals
frame_001689.jpg
Detection took 26.100s for 223 object proposals
frame_001701.jpg
Detection took 16.995s fo

KeyboardInterrupt: 