Skip to content

Commit

Permalink
add detect
Browse files Browse the repository at this point in the history
  • Loading branch information
jinopapo committed Nov 16, 2016
1 parent 873d2d7 commit 469018c
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 28 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
# chainer-SSD
uncompleted
Single Shot MultiBox Detector
chainer用
まだ調整中(殴り書き)

[元のリポジトリ](https://github.com/weiliu89/caffe/tree/ssd).
[元の論文](http://arxiv.org/abs/1512.02325).
56 changes: 41 additions & 15 deletions detect.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import argparse

import numpy as np
import cv2
#import cv2
import skimage.io
import skimage.draw
from skimage.transform import resize
import chainer
from chainer import serializers


import ssd_net


Expand All @@ -16,6 +18,7 @@
args = parser.parse_args()
mean = np.array([104,117,123])
img = skimage.img_as_float(skimage.io.imread(args.path, as_grey=False)).astype(np.float32)
moto = img
img = resize(img, (300,300))
img = img*255 - mean[::-1]
img = img.transpose(2, 0, 1)[::-1]
Expand All @@ -27,14 +30,16 @@

def nms(bboxes, scores, score_th, nms_th, top_k):
score_iter = 0
score_index = scores.argsort[::-1][:top_k]
score_index = scores.argsort()[::-1][:top_k]
indices = []
while(score_iter < len(score_index)):
idx = score_index[score_iter]
keep = True
for i in range(indices):
print(idx)
for i in range(len(indices)):
if keep:
kept_idx = indices[i]
print(bboxes[idx], bboxes[kept_idx])
overlap = IoU(bboxes[idx], bboxes[kept_idx])
keep = overlap <= nms_th
else:
Expand All @@ -43,30 +48,51 @@ def nms(bboxes, scores, score_th, nms_th, top_k):
indices.append(idx)
score_iter+=1
return indices

"""
def nms(bboxes, nms_th, top_k):
bboxes = bboxes[:top_k]
indices = []
for bbox in bboxes:
keep = True
for i in range(indices):
if keep:
kept_idx = indices[i]
overlap = IoU(bbox, bboxes[kept_idx])
keep = overlap <= nms_th
else:
break
if keep:
indices.append(idx)
return indices
"""
def IoU(a, b):
    """Return the intersection-over-union of two corner-format boxes.

    Args:
        a: box as (x1, y1, x2, y2).
        b: box as (x1, y1, x2, y2).

    Returns:
        0 when ``intersection`` reports no overlap, 1 when either box has a
        non-positive area (so degenerate boxes are always treated as
        duplicates by the caller), otherwise intersection area divided by
        union area.
    """
    # The former ``union`` helper is gone: it referenced undefined names and
    # worked on (x, y, w, h) boxes, while every caller passes corner boxes.
    I = intersection(a, b)
    if not I:
        return 0
    a_ = (a[2] - a[0]) * (a[3] - a[1])
    b_ = (b[2] - b[0]) * (b[3] - b[1])
    if a_ <= 0 or b_ <= 0:
        return 1
    i = (I[2] - I[0]) * (I[3] - I[1])
    # Union area = A + B - I.  Subtracting the intersection twice (as before)
    # overestimates the ratio and divides by zero for identical boxes.
    return float(i) / (a_ + b_ - i)

def intersection(a, b):
    """Return the overlapping rectangle of two corner-format boxes.

    Args:
        a: box as (x1, y1, x2, y2).
        b: box as (x1, y1, x2, y2).

    Returns:
        The overlap as (x1, y1, x2, y2), or an empty tuple when the boxes do
        not overlap. Boxes that merely touch yield a zero-area rectangle.
    """
    # Top-left corner is the max of the mins; bottom-right is the min of the
    # maxes.
    x1 = max(a[0], b[0])
    y1 = max(a[1], b[1])
    x2 = min(a[2], b[2])
    y2 = min(a[3], b[3])
    w = x2 - x1
    h = y2 - y1
    if w < 0 or h < 0:
        return ()
    return (x1, y1, x2, y2)

# Run the detector and outline every returned box on the original image.
a = model.detection()
for i in a:
    # Each detection row is (confidence, x1, y1, x2, y2).
    conf, x1, y1, x2, y2 = i
    # Scale the corners by the width/height of the un-resized image
    # (presumably the detector emits coordinates normalised to [0, 1]).
    img_h, img_w = moto.shape[0], moto.shape[1]
    x1, x2 = x1 * img_w, x2 * img_w
    y1, y2 = y1 * img_h, y2 * img_h
    # Rasterise the rectangle outline and paint it black.
    rr, cc = skimage.draw.polygon_perimeter(
        [y1, y2, y2, y1], [x1, x1, x2, x2], shape=moto.shape, clip=True)
    moto[rr, cc] = 0
Binary file added fish-bike.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
103 changes: 91 additions & 12 deletions ssd_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,23 @@ def __init__(self):

)
self.train = False
self.conv4_3_norm_priorbox = self.prior((38, 38), 30., 0, [2], 1, 1,(0.1, 0.1, 0.2, 0.2))
self.conv4_3_norm_mbox_priorbox = self.prior((38, 38), 30., 0, [2], 1, 1,(0.1, 0.1, 0.2, 0.2))
self.fc7_mbox_priorbox = self.prior((19, 19), 60., 114., [2, 3], 1, 1,(0.1, 0.1, 0.2, 0.2))
self.conv6_2_mbox_priorbox = self.prior((10, 10), 114., 168., [2, 3], 1, 1,(0.1, 0.1, 0.2, 0.2))
self.conv7_2_mbox_priorbox = self.prior((5, 5),168., 222., [2, 3], 1, 1,(0.1, 0.1, 0.2, 0.2))
self.conv8_2_mbox_priorbox = self.prior((3, 3), 222., 276., [2, 3], 1, 1,(0.1, 0.1, 0.2, 0.2))
self.pool6_mbox_priorbox = self.prior((1, 1), 276., 330., [2, 3], 1, 1,(0.1, 0.1, 0.2, 0.2))
self.mbox_prior = np.hstack([self.conv4_3_norm_mbox_priorbox,
self.fc7_mbox_priorbox,
self.conv6_2_mbox_priorbox,
self.conv7_2_mbox_priorbox,
self.conv8_2_mbox_priorbox,
self.pool6_mbox_priorbox])






def __call__(self, x, t):

Expand Down Expand Up @@ -166,11 +182,12 @@ def __call__(self, x, t):
self.h_conv8_2_mbox_conf_flat,
self.h_pool6_mbox_conf_flat],axis=0)

self.mbox_reahpe = F.reshape(self.mbox_conf,(1,7308,21))
self.mbox_reahpe_softmax = F.softmax(self.mbox_reahpe)
self.loss = self.loss_func(h, t)
self.accuracy = self.loss
return self.loss
self.mbox_conf_reahpe = F.reshape(self.mbox_conf,(7308,21))
self.mbox_conf_softmax = F.softmax(self.mbox_conf_reahpe)
if self.train:
self.loss = self.loss_func(h, t)
self.accuracy = self.loss
return self.loss

def prior(self, h, min_size, max_size,aspect, flip, clip, variance):
aspect_ratio = [1.]
Expand All @@ -181,7 +198,7 @@ def prior(self, h, min_size, max_size,aspect, flip, clip, variance):
img_width = img_height = 300.
step_x = img_width / float(width)
step_y = img_width / float(height)
top_data=np.zeros(height * width * len(aspect_ratio) * 4)
top_data=np.zeros(height * width * (len(aspect_ratio) + bool(max_size))* 4 )
idx=0
for h in range(height):
for w in range(width):
Expand Down Expand Up @@ -228,7 +245,7 @@ def prior(self, h, min_size, max_size,aspect, flip, clip, variance):
top_data[i] = 1
elif top_data[i] < 0:
top_data[i] = 0
val_data=np.zeros(height * width * len(aspect_ratio) * 4)
val_data=np.zeros(height * width * (len(aspect_ratio) + bool(max_size))* 4 )
if len(variance)==1:
pass
else:
Expand All @@ -241,7 +258,23 @@ def prior(self, h, min_size, max_size,aspect, flip, clip, variance):
count+=1
return np.vstack([top_data, val_data])

def decoder(prior, loc, prior_data):
def detection(self):
    """Turn the last forward pass into a list of detections.

    Reads ``self.mbox_prior``, ``self.mbox_loc.data`` and
    ``self.mbox_conf_softmax.data``. For every label except 0 (presumably
    the background class) it decodes the boxes whose confidence exceeds 0.1,
    runs non-maximum suppression on them, and finally keeps only detections
    with confidence >= 0.6.

    Returns:
        A 2-D array with one row per detection laid out as
        (confidence, xmin, ymin, xmax, ymax); shape (0, 5) when nothing
        passes the thresholds.
    """
    # -1 lets the number of priors follow the data instead of a fixed 7308.
    prior = np.reshape(self.mbox_prior, (2, -1, 4))
    loc = np.reshape(self.mbox_loc.data, (-1, 4))
    conf = self.mbox_conf_softmax.data
    cand = []
    for label in range(1, conf.shape[1]):
        scores = conf[:, label]
        # Ascending score order, restricted to confident priors.
        order = scores.argsort()
        order = order[scores[order] > 0.1]
        if len(order) == 0:
            continue
        label_cand = np.array([
            np.hstack([scores[i],
                       self.decoder(prior[0, i], loc[i], prior[1, i])])
            for i in order
        ])
        keep = self.nms(label_cand[:, 1:], label_cand[:, 0], 0.1, 0.45, 200)
        cand.extend(label_cand[i] for i in keep)
    if not cand:
        # np.array([]) is 1-D, so the column indexing below would raise
        # IndexError; hand back a well-formed empty result instead.
        return np.empty((0, 5), dtype=conf.dtype)
    cand = np.array(cand)
    return cand[cand[:, 0] >= 0.6]

def decoder(self, prior, loc, prior_data):
bbox_data = np.array([0]*4,dtype=np.float32)
p_xmin, p_ymin, p_xmax, p_ymax= prior
xmin, ymin, xmax, ymax= loc
Expand All @@ -254,7 +287,53 @@ def decoder(prior, loc, prior_data):
decode_bbox_width = np.exp(prior_data[2] * xmax) * prior_width;
decode_bbox_height = np.exp(prior_data[3] * ymax) * prior_height;
bbox_data[0] = decode_bbox_center_x - decode_bbox_width / 2.
bbox_data[1] = decode_bbox_center_y - decode_bbox_height / 2.;
bbox_data[2] = decode_bbox_center_x + decode_bbox_width / 2.;
bbox_data[3] = decode_bbox_center_y + decode_bbox_height / 2.;
bbox_data[1] = decode_bbox_center_y - decode_bbox_height / 2.
bbox_data[2] = decode_bbox_center_x + decode_bbox_width / 2.
bbox_data[3] = decode_bbox_center_y + decode_bbox_height / 2.
return bbox_data

def nms(self, bboxes, scores, score_th, nms_th, top_k):
    """Greedy non-maximum suppression.

    Args:
        bboxes: indexable collection of (x1, y1, x2, y2) boxes.
        scores: NumPy array with one score per box.
        score_th: boxes whose score is not above this value are ignored.
        nms_th: a box is dropped when its IoU with an already kept box
            exceeds this value.
        top_k: only the ``top_k`` highest-scoring boxes are considered.

    Returns:
        List of indices into ``bboxes`` of the kept boxes, best score first.
    """
    order = scores.argsort()[::-1][:top_k]
    indices = []
    for idx in order:
        if scores[idx] <= score_th:
            # Scores are visited in descending order, so nothing after this
            # one can pass the threshold either.
            break
        box = bboxes[idx]
        # Zero-width or zero-height boxes are never kept.
        if box[0] == box[2] or box[1] == box[3]:
            continue
        # all() stops at the first kept box that overlaps too much.
        if all(self.IoU(box, bboxes[kept]) <= nms_th for kept in indices):
            indices.append(idx)
    return indices

def IoU(self, a, b):
    """Return the intersection-over-union of two corner-format boxes.

    Args:
        a: box as (x1, y1, x2, y2).
        b: box as (x1, y1, x2, y2).

    Returns:
        0 when ``self.intersection`` reports no overlap, 1 when either box
        has a non-positive area (so degenerate boxes always count as
        duplicates), otherwise intersection area divided by union area.
    """
    I = self.intersection(a, b)
    if not I:
        return 0
    a_ = (a[2] - a[0]) * (a[3] - a[1])
    b_ = (b[2] - b[0]) * (b[3] - b[1])
    if a_ <= 0 or b_ <= 0:
        return 1
    i = (I[2] - I[0]) * (I[3] - I[1])
    # Union area = A + B - I.  Subtracting the intersection twice (as before)
    # overestimates the ratio and divides by zero for identical boxes.
    return float(i) / (a_ + b_ - i)

def intersection(self, a, b):
    """Compute the rectangle shared by two corner-format boxes.

    Both ``a`` and ``b`` are (x1, y1, x2, y2). The result uses the same
    layout; an empty tuple is returned when the boxes do not overlap.
    """
    left, top = max(a[0], b[0]), max(a[1], b[1])
    right, bottom = min(a[2], b[2]), min(a[3], b[3])
    # A negative extent on either axis means the boxes are disjoint.
    if (right - left) < 0 or (bottom - top) < 0:
        return ()
    return (left, top, right, bottom)

0 comments on commit 469018c

Please sign in to comment.