In [1]:
# Load packages (CUDA_VISIBLE_DEVICES is set before TensorFlow is imported so only GPU 0 is visible)
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import cv2
import numpy as np

import tensorflow as tf
from yolov3.yolov4 import Create_Yolo
from yolov3.utils import image_preprocess, postprocess_boxes, draw_bbox, nms
from yolov3.configs import *

# Network input resolution; must match yolo_input_size in config.py.
IMG_SIZE = 416

# Class-names file and weight path for the custom MNIST detector.
MNIST_CLASS = 'mnist/mnist.names'
WEIGHT_PATH = './model_data/mnist/yolov3_mnist_custom'

# Build the detector and restore the trained weights from WEIGHT_PATH.
yolo = Create_Yolo(
    input_size=IMG_SIZE,
    CLASSES=MNIST_CLASS,
)
yolo.load_weights(WEIGHT_PATH)

# Keep a snapshot of the loaded weights; later cells re-apply it via
# yolo.set_weights(weights) before running inference.
weights = yolo.get_weights()

In [15]:
# Test on a single image file.
# Re-apply the weight snapshot taken right after loading the model.
yolo.set_weights(weights)

IMAGE_PATH = 'mnist_test.jpg'
img = cv2.imread(IMAGE_PATH)
if img is None:
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise surface as an opaque cvtColor error below.
    raise FileNotFoundError(f'could not read image: {IMAGE_PATH}')

# OpenCV loads images as BGR; the detection pipeline below works on RGB.
image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

image_data = image_preprocess(np.copy(image), [IMG_SIZE, IMG_SIZE])  # (416, 416, 3)
image_data = image_data[np.newaxis, ...].astype(np.float32)  # (1, 416, 416, 3)

# Flatten every output of the model to (-1, last_dim) and stack them into one
# tensor of candidate boxes.
pred_bbox = yolo.predict(image_data)
pred_bbox = [tf.reshape(x, (-1, tf.shape(x)[-1])) for x in pred_bbox]
pred_bbox = tf.concat(pred_bbox, axis=0)

# 0.3 is forwarded to postprocess_boxes (presumably the minimum score to keep
# a box -- confirm against yolov3.utils).
bboxes = postprocess_boxes(pred_bbox, image, IMG_SIZE, 0.3)

# Plain NMS is used here (method='nms'). Soft-NMS is described in
# https://arxiv.org/pdf/1704.04503.pdf (2017); whether nms() supports it as an
# alternative method should be checked in yolov3.utils.
bboxes = nms(bboxes, iou_threshold=0.45, method='nms')  # bboxes = (xmin, ymin, xmax, ymax, score, class)
image = draw_bbox(image, bboxes, CLASSES=MNIST_CLASS, rectangle_colors=(255, 0, 0))

# cv2.imshow interprets arrays as BGR, so convert the annotated RGB frame back
# before displaying; otherwise red and blue are swapped on screen.
cv2.imshow('Image', cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

# Block until any key is pressed, then always close the window so it is not
# left dangling when a key other than ESC was used.
cv2.waitKey()
cv2.destroyAllWindows()

In [13]:
# Test on a live camera stream (device 0). Press ESC in the window to stop.
cap = cv2.VideoCapture(0)

try:
    if cap.isOpened():
        # Re-apply the weight snapshot once; inference does not modify the
        # weights, so doing this on every frame is wasted work.
        yolo.set_weights(weights)

        while True:
            ret, img = cap.read()
            if not ret:
                print('no frame')
                break

            # Camera frames arrive as BGR; the detection pipeline works on RGB.
            image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            image_data = image_preprocess(np.copy(image), [IMG_SIZE, IMG_SIZE])  # (416, 416, 3)
            image_data = image_data[np.newaxis, ...].astype(np.float32)  # (1, 416, 416, 3)

            # Flatten every model output to (-1, last_dim) and stack them into
            # one tensor of candidate boxes.
            pred_bbox = yolo.predict(image_data)
            pred_bbox = [tf.reshape(x, (-1, tf.shape(x)[-1])) for x in pred_bbox]
            pred_bbox = tf.concat(pred_bbox, axis=0)

            # 0.3 is forwarded to postprocess_boxes (presumably the minimum
            # score to keep a box -- confirm against yolov3.utils).
            bboxes = postprocess_boxes(pred_bbox, image, IMG_SIZE, 0.3)
            # Plain NMS is used here (method='nms'); Soft-NMS is described in
            # https://arxiv.org/pdf/1704.04503.pdf (2017).
            bboxes = nms(bboxes, iou_threshold=0.45, method='nms')  # bboxes = (xmin, ymin, xmax, ymax, score, class)
            image = draw_bbox(image, bboxes, CLASSES=MNIST_CLASS, rectangle_colors=(255, 0, 0))

            # cv2.imshow interprets arrays as BGR, so convert the annotated RGB
            # frame back before displaying; otherwise red and blue are swapped.
            cv2.imshow('Image', cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
            if cv2.waitKey(1) & 0xFF == 27:  # 27 == ESC
                break
    else:
        print('camera not opened')
finally:
    # Always free the camera and close the window, including when the loop
    # raises or the kernel is interrupted.
    cap.release()
    cv2.destroyAllWindows()