In [1]:
import argparse
import os
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
import scipy.io
import scipy.misc
import numpy as np
import pandas as pd
import PIL
import tensorflow as tf
from keras import backend as K
from keras.layers import Input, Lambda, Conv2D
from keras.models import load_model, Model

from yad2k.models.keras_yolo import yolo_head, yolo_boxes_to_corners, preprocess_true_boxes, yolo_loss, yolo_body
import yolo_utils

Using TensorFlow backend.


In [2]:
def yolo_filter_boxes(box_confidence, boxes, box_class_probs, threshold=.6):
	"""
	Filters YOLO boxes by thresholding on object and class confidence
	:param box_confidence: tensor of shape(19, 19, 5, 1)
	:param boxes: tensor of shape(19, 19, 5, 4)
	:param box_class_prob: tensor of shape(19,19,5,80)
	:param threshold: real value
	:return: 
	
	"""
	#step1 compute box scores
	box_scores = box_confidence * box_class_probs
	
	#step2: Find the box_classes thanks to to max box_scores, keep track of the corresponding score
	box_classes = K.argmax(box_scores, axis=-1)
	box_class_scores = K.max(box_scores, axis=-1, keepdims=False)
	
	#step3: Create a filter mask based on "box_class_scores" by using "threshold". The mask should
	# have the same dimension as box_class_scores, and be True for the boxex you want to keep
	filtering_mask = box_class_scores >= threshold
	
	# Step4: Apply the mask to socres, boxes and classes
	scores = tf.boolean_mask(box_class_scores, filtering_mask)
	boxes = tf.boolean_mask(boxes, filtering_mask)
	classes = tf.boolean_mask(box_classes, filtering_mask)
	
	return scores, boxes, classes
	

In [7]:
with tf.Session() as sess:
	box_confidence = tf.random_normal([19,19,5,1], mean=1, stddev=4, seed=1)
	boxes = tf.random_normal([19,19,5,4], mean=1, stddev=4, seed=1)
	box_class_probs = tf.random_normal([19,19,5,80], mean=1, stddev=4, seed=1)
	scores, boxes, classes = yolo_filter_boxes(box_confidence, boxes, box_class_probs,threshold=0.5)
	print(str(scores[2].eval()))
	print(str(boxes[2].eval()))
	print(str(classes[2].eval()))
	print(str(scores.shape))
	print(str(boxes.shape))
	print(str(classes.shape))
	
	sess.close()

10.750582
[ 8.426533   3.2713668 -0.5313436 -4.9413733]
7
(?,)
(?, 4)
(?,)
