Skip to content

Commit

Permalink
# Update python scrypt
Browse files Browse the repository at this point in the history
For multi scale training
  • Loading branch information
SnowMasaya committed Nov 12, 2017
1 parent 4e4eb96 commit 7fa25e1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
3 changes: 3 additions & 0 deletions darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def _process_batch(data, size_index):
ious_reshaped = np.reshape(ious, [hw, num_anchors, len(cell_inds)])
for i, cell_ind in enumerate(cell_inds):
if cell_ind >= hw or cell_ind < 0:
print('cell inds size {}'.format(len(cell_inds)))
print('cell over {} hw {}'.format(cell_ind, hw))
continue
a = anchor_inds[i]
Expand Down Expand Up @@ -168,6 +169,7 @@ def __init__(self):
# linear
out_channels = cfg.num_anchors * (cfg.num_classes + 5)
self.conv5 = net_utils.Conv2d(c4, out_channels, 1, 1, relu=False)
self.global_average_pool = nn.AvgPool2d((1,1))

# train
self.bbox_loss = None
Expand All @@ -187,6 +189,7 @@ def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None, size_i
cat_1_3 = torch.cat([conv1s_reorg, conv3], 1)
conv4 = self.conv4(cat_1_3)
conv5 = self.conv5(conv4) # batch_size, out_channels, h, w
conv5 = self.global_average_pool(conv5)

# for detection
# bsize, c, h, w -> bsize, h, w, c -> bsize, h x w, num_anchors, 5+num_classes
Expand Down
14 changes: 12 additions & 2 deletions datasets/imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import PIL
import numpy as np
from multiprocessing import Pool
from functools import partial
import cfgs.config as cfg
import cv2


def mkdir(path, max_depth=3):
Expand All @@ -13,6 +16,12 @@ def mkdir(path, max_depth=3):
os.mkdir(path)


def image_resize(im, size_index):
w, h = cfg.multi_scale_inp_size[size_index]
im = cv2.resize(im, (w, h))
return im


class ImageDataset(object):
def __init__(self, name, datadir, batch_size, im_processor, processes=3, shuffle=True, dst_size=None):
self._name = name
Expand All @@ -38,12 +47,13 @@ def __init__(self, name, datadir, batch_size, im_processor, processes=3, shuffle
self.gen = None
self._im_processor = im_processor

def next_batch(self):
def next_batch(self, size_index):
batch = {'images': [], 'gt_boxes': [], 'gt_classes': [], 'dontcare': [], 'origin_im': []}
i = 0
while i < self.batch_size:
try:
images, gt_boxes, classes, dontcare, origin_im = next(self.gen)
images = image_resize(images, size_index)
batch['images'].append(images)
batch['gt_boxes'].append(gt_boxes)
batch['gt_classes'].append(classes)
Expand All @@ -54,7 +64,7 @@ def next_batch(self):
indexes = np.arange(len(self.image_names), dtype=np.int)
if self._shuffle:
np.random.shuffle(indexes)
self.gen = self.pool.imap(self._im_processor,
self.gen = self.pool.imap(partial(self._im_processor, size_index=size_index),
([self.image_names[i], self.get_annotation(i), self.dst_size] for i in indexes),
chunksize=self.batch_size)
self._epoch += 1
Expand Down

0 comments on commit 7fa25e1

Please sign in to comment.