In [None]:
from __future__ import division

from models import *
from utils.utils import *
from utils.datasets import *

import os
import sys
import time
import datetime
import argparse

from PIL import Image

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.ticker import NullLocator


class opt:
    """Hard-coded stand-in for the command-line options of the detection script.

    Every setting is a plain class attribute and is read directly off the
    class (``opt.img_size``, ``opt.conf_thres``, ...).
    """

    # --- input data -------------------------------------------------------
    image_folder: str = "data/samples"  # directory handed to ImageFolder
    class_path: str = "config/voc.txt"  # class-name file handed to load_classes

    # --- network ----------------------------------------------------------
    model_def: str = "config/yolov3-custom.cfg"  # Darknet .cfg handed to Darknet()
    # A path ending in ".weights" goes to load_darknet_weights; anything else
    # is treated as a PyTorch state_dict checkpoint.
    weights_path: str = "checkpoints_voc/yolov3_ckpt_309.pth"

    # --- post-processing (both forwarded to non_max_suppression) ----------
    conf_thres: float = 0.8
    nms_thres: float = 0.4

    # --- inference --------------------------------------------------------
    batch_size: int = 1
    img_size: int = 416  # passed to both Darknet() and ImageFolder


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Annotated images are written here by the visualisation step.
os.makedirs("output", exist_ok=True)

# Build the network from its Darknet .cfg description on the chosen device.
model = Darknet(opt.model_def, img_size=opt.img_size).to(device)

if opt.weights_path.endswith(".weights"):
    # Original Darknet binary weight file.
    model.load_darknet_weights(opt.weights_path)
else:
    # PyTorch checkpoint holding a state_dict. map_location remaps tensors that
    # were saved on a GPU onto `device`. Without it, loading a GPU-trained
    # checkpoint on a CPU-only machine fails even though `device` is "cpu".
    model.load_state_dict(torch.load(opt.weights_path, map_location=device))

# Inference mode (affects layers such as batch-norm/dropout).
model.eval()

# Images are read in a fixed order; each batch is unpacked downstream as
# (image paths, image tensor).
dataloader = DataLoader(
    ImageFolder(opt.image_folder, img_size=opt.img_size),
    batch_size=opt.batch_size,
    shuffle=False
)

# Class names indexed by predicted class id.
classes = load_classes(opt.class_path)

# Float tensor type matching `device`; input batches are cast to it.
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

imgs = []  # image paths, in DataLoader order
img_detections = []  # per-image NMS output, index-aligned with `imgs`

print("\nPerforming object detection:")
for batch_i, (img_paths, input_imgs) in enumerate(dataloader):
    # Cast the batch to the float tensor type chosen at setup (CUDA or CPU).
    # The former torch.autograd.Variable wrapper has been a no-op since
    # PyTorch 0.4 and is not needed.
    input_imgs = input_imgs.type(Tensor)

    # Start the clock only once the batch has been loaded. The reported time then
    # covers the forward pass + NMS, not DataLoader image decoding/resizing.
    # NOTE(review): CUDA kernels run asynchronously. This wall-clock figure
    # assumes non_max_suppression forces a device sync; confirm if precise GPU
    # timings matter.
    prev_time = time.time()
    with torch.no_grad():
        detections = model(input_imgs)
        detections = non_max_suppression(detections, opt.conf_thres, opt.nms_thres)
    current_time = time.time()

    inference_time = datetime.timedelta(seconds=current_time - prev_time)
    print("\t+ Batch %d, Inference Time: %s" % (batch_i, inference_time))

    imgs.extend(img_paths)
    img_detections.extend(detections)

# Palette of 20 evenly spaced colors from the "tab20b" colormap.
# NOTE(review): `np` and `random` are not imported explicitly in this cell;
# they presumably arrive via the `from utils... import *` lines.
cmap = plt.get_cmap("tab20b")
colors = [cmap(i) for i in np.linspace(0, 1, 20)]

print("\nSaving images:")
for img_i, (path, detections) in enumerate(zip(imgs, img_detections)):

    print("(%d) Image: '%s'" % (img_i, path))

    # Load the original image. The context manager closes the file handle that
    # Image.open would otherwise keep open.
    with Image.open(path) as pil_img:
        img = np.array(pil_img)

    # Exactly one figure per image. A preceding plt.figure() would create a
    # second, never-closed empty figure on every iteration.
    fig, ax = plt.subplots(1)
    ax.imshow(img)

    if detections is not None:
        # rescale_boxes (project helper) presumably maps boxes from the
        # opt.img_size network input back to the original image size.
        detections = rescale_boxes(detections, opt.img_size, img.shape[:2])
        unique_labels = detections[:, -1].cpu().unique()
        n_cls_preds = len(unique_labels)
        # Pick a random distinct color per predicted class. There are only
        # len(colors) colors, so reuse them cyclically if more classes appear;
        # random.sample would raise otherwise.
        bbox_colors = random.sample(colors, min(n_cls_preds, len(colors)))
        label_colors = {
            int(label): bbox_colors[i % len(bbox_colors)]
            for i, label in enumerate(unique_labels)
        }
        # Each detection row: box corners, confidence, class confidence, class id.
        for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
            label = int(cls_pred)

            print("\t+ Label: %s, Conf: %.5f" % (classes[label], cls_conf.item()))

            box_w = x2 - x1
            box_h = y2 - y1

            color = label_colors[label]
            bbox = patches.Rectangle((x1, y1), box_w, box_h, linewidth=2, edgecolor=color, facecolor="none")
            ax.add_patch(bbox)

            # Class name drawn at the box's top-left corner on a colored background.
            ax.text(
                x1,
                y1,
                s=classes[label],
                color="white",
                verticalalignment="top",
                bbox={"color": color, "pad": 0},
            )

    # Hide axes and ticks so only the annotated image is saved.
    ax.axis("off")
    ax.xaxis.set_major_locator(NullLocator())
    ax.yaxis.set_major_locator(NullLocator())
    # Keep the full stem ("a.b.jpg" -> "a.b") and handle OS-specific path
    # separators. Splitting on "/" and the first "." truncated such names and
    # could make different inputs overwrite the same output file.
    filename = os.path.splitext(os.path.basename(path))[0]
    fig.savefig(f"output/{filename}.png", bbox_inches="tight", pad_inches=0.0)
    plt.show()
    plt.close(fig)



Performing object detection:
	+ Batch 0, Inference Time: 0:00:00.047660
	+ Batch 1, Inference Time: 0:00:00.037917
	+ Batch 2, Inference Time: 0:00:00.067192
	+ Batch 3, Inference Time: 0:00:00.038191
	+ Batch 4, Inference Time: 0:00:00.053498
	+ Batch 5, Inference Time: 0:00:00.033658
	+ Batch 6, Inference Time: 0:00:00.065203
	+ Batch 7, Inference Time: 0:00:00.045006
	+ Batch 8, Inference Time: 0:00:00.059300
	+ Batch 9, Inference Time: 0:00:00.040758
	+ Batch 10, Inference Time: 0:00:00.061528
	+ Batch 11, Inference Time: 0:00:00.082270
	+ Batch 12, Inference Time: 0:00:00.044135
	+ Batch 13, Inference Time: 0:00:00.076204
	+ Batch 14, Inference Time: 0:00:00.032284
	+ Batch 15, Inference Time: 0:00:00.076737
	+ Batch 16, Inference Time: 0:00:00.036421
	+ Batch 17, Inference Time: 0:00:00.057090
	+ Batch 18, Inference Time: 0:00:00.035594
	+ Batch 19, Inference Time: 0:00:00.059728
	+ Batch 20, Inference Time: 0:00:00.034189
	+ Batch 21, Inference Time: 0:00:00.057678
	+ Batch 22,