Skip to content

Commit

Permalink
# Update python scrypt
Browse files Browse the repository at this point in the history
For multi scale training
  • Loading branch information
SnowMasaya committed Nov 12, 2017
1 parent dc6b10e commit 4e4eb96
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 30 deletions.
22 changes: 22 additions & 0 deletions cfgs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,28 @@ def mkdir(path, max_depth=3):

# input and output size
############################
multi_scale_inp_size = [np.array([320, 320], dtype=np.int),
np.array([352, 352], dtype=np.int),
np.array([384, 384], dtype=np.int),
np.array([416, 416], dtype=np.int),
np.array([448, 448], dtype=np.int),
np.array([480, 480], dtype=np.int),
np.array([512, 512], dtype=np.int),
np.array([544, 544], dtype=np.int),
np.array([576, 576], dtype=np.int),
# np.array([608, 608], dtype=np.int),
] # w, h
multi_scale_out_size = [multi_scale_inp_size[0] / 32,
multi_scale_inp_size[1] / 32,
multi_scale_inp_size[2] / 32,
multi_scale_inp_size[3] / 32,
multi_scale_inp_size[4] / 32,
multi_scale_inp_size[5] / 32,
multi_scale_inp_size[6] / 32,
multi_scale_inp_size[7] / 32,
multi_scale_inp_size[8] / 32,
# multi_scale_inp_size[9] / 32,
] # w, h
inp_size = np.array([416, 416], dtype=np.int) # w, h
out_size = inp_size / 32

Expand Down
39 changes: 24 additions & 15 deletions darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
import utils.network as net_utils
import cfgs.config as cfg
from layers.reorg.reorg_layer import ReorgLayer
from utils.cython_bbox import bbox_ious, bbox_intersections, bbox_overlaps, anchor_intersections
from utils.cython_bbox import bbox_ious, anchor_intersections
from utils.cython_yolo import yolo_to_bbox
from functools import partial

from multiprocessing import Pool
import multiprocessing


def _make_layers(in_channels, net_cfg):
Expand All @@ -25,17 +27,21 @@ def _make_layers(in_channels, net_cfg):
layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
else:
out_channels, ksize = item
layers.append(net_utils.Conv2d_BatchNorm(in_channels, out_channels, ksize, same_padding=True))
# layers.append(net_utils.Conv2d(in_channels, out_channels, ksize, same_padding=True))
layers.append(net_utils.Conv2d_BatchNorm(in_channels,
out_channels,
ksize,
same_padding=True))
# layers.append(net_utils.Conv2d(in_channels, out_channels,
# ksize, same_padding=True))
in_channels = out_channels

return nn.Sequential(*layers), in_channels


def _process_batch(data):
W, H = cfg.out_size
inp_size = cfg.inp_size
out_size = cfg.out_size
def _process_batch(data, size_index):
W, H = cfg.multi_scale_out_size[size_index]
inp_size = cfg.multi_scale_inp_size[size_index]
out_size = cfg.multi_scale_out_size[size_index]

bbox_pred_np, gt_boxes, gt_classes, dontcares, iou_pred_np = data

Expand Down Expand Up @@ -105,7 +111,7 @@ def _process_batch(data):
ious_reshaped = np.reshape(ious, [hw, num_anchors, len(cell_inds)])
for i, cell_ind in enumerate(cell_inds):
if cell_ind >= hw or cell_ind < 0:
print(cell_ind)
print('cell over {} hw {}'.format(cell_ind, hw))
continue
a = anchor_inds[i]

Expand Down Expand Up @@ -154,7 +160,8 @@ def __init__(self):
self.conv3, c3 = _make_layers(c2, net_cfgs[6])

stride = 2
self.reorg = ReorgLayer(stride=2) # stride*stride times the channels of conv1s
# stride*stride times the channels of conv1s
self.reorg = ReorgLayer(stride=2)
# cat [conv1s, conv3]
self.conv4, c4 = _make_layers((c1*(stride*stride) + c3), net_cfgs[7])

Expand All @@ -172,7 +179,7 @@ def __init__(self):
def loss(self):
return self.bbox_loss + self.iou_loss + self.cls_loss

def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None):
def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None, size_index=0):
conv1s = self.conv1s(im_data)
conv2 = self.conv2(conv1s)
conv3 = self.conv3(conv2)
Expand Down Expand Up @@ -201,7 +208,7 @@ def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None):
bbox_pred_np = bbox_pred.data.cpu().numpy()
iou_pred_np = iou_pred.data.cpu().numpy()
_boxes, _ious, _classes, _box_mask, _iou_mask, _class_mask = self._build_target(
bbox_pred_np, gt_boxes, gt_classes, dontcare, iou_pred_np)
bbox_pred_np, gt_boxes, gt_classes, dontcare, iou_pred_np, size_index)

_boxes = net_utils.np_to_variable(_boxes)
_ious = net_utils.np_to_variable(_ious)
Expand All @@ -223,14 +230,16 @@ def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None):

return bbox_pred, iou_pred, prob_pred

def _build_target(self, bbox_pred_np, gt_boxes, gt_classes, dontcare, iou_pred_np):
def _build_target(self, bbox_pred_np, gt_boxes, gt_classes, dontcare, iou_pred_np, size_index):
"""
:param bbox_pred: shape: (bsize, h x w, num_anchors, 4) : (sig(tx), sig(ty), exp(tw), exp(th))
"""

bsize = bbox_pred_np.shape[0]

targets = self.pool.map(_process_batch, ((bbox_pred_np[b], gt_boxes[b], gt_classes[b], dontcare[b], iou_pred_np[b]) for b in range(bsize)))
targets = self.pool.map(partial(_process_batch, size_index=size_index),
((bbox_pred_np[b], gt_boxes[b], gt_classes[b], dontcare[b], iou_pred_np[b])
for b in range(bsize)))

_boxes = np.stack(tuple((row[0] for row in targets)))
_ious = np.stack(tuple((row[1] for row in targets)))
Expand All @@ -250,7 +259,7 @@ def load_from_npz(self, fname, num_conv=None):
keys = list(own_dict.keys())

for i, start in enumerate(range(0, len(keys), 5)):
if num_conv is not None and i>= num_conv:
if num_conv is not None and i >= num_conv:
break
end = min(start+5, len(keys))
for key in keys[start:end]:
Expand All @@ -263,8 +272,8 @@ def load_from_npz(self, fname, num_conv=None):
param = param.permute(3, 2, 0, 1)
own_dict[key].copy_(param)


if __name__ == '__main__':
net = Darknet19()
# net.load_from_npz('models/yolo-voc.weights.npz')
net.load_from_npz('models/darknet19.weights.npz', num_conv=18)

44 changes: 29 additions & 15 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import os
import cv2
import torch
import numpy as np
import datetime
from torch.multiprocessing import Pool

from darknet import Darknet19

Expand All @@ -12,6 +9,7 @@
import utils.network as net_utils
from utils.timer import Timer
import cfgs.config as cfg
from random import randint

try:
from pycrayon import CrayonClient
Expand All @@ -21,12 +19,15 @@

# data loader
imdb = VOCDataset(cfg.imdb_train, cfg.DATA_DIR, cfg.train_batch_size,
yolo_utils.preprocess_train, processes=2, shuffle=True, dst_size=cfg.inp_size)
yolo_utils.preprocess_train, processes=2, shuffle=True,
dst_size=cfg.multi_scale_inp_size)
# dst_size=cfg.inp_size)
print('load data succ...')

net = Darknet19()
# net_utils.load_net(cfg.trained_model, net)
# pretrained_model = os.path.join(cfg.train_output_dir, 'darknet19_voc07trainval_exp1_63.h5')
# pretrained_model = os.path.join(cfg.train_output_dir,
# 'darknet19_voc07trainval_exp1_63.h5')
# pretrained_model = cfg.trained_model
# net_utils.load_net(pretrained_model, net)
net.load_from_npz(cfg.pretrained_model, num_conv=18)
Expand All @@ -37,7 +38,8 @@
# optimizer
start_epoch = 0
lr = cfg.init_learning_rate
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=cfg.momentum,
weight_decay=cfg.weight_decay)

# tensorboad
use_tensorboard = cfg.use_tensorboard and CrayonClient is not None
Expand All @@ -63,19 +65,23 @@
cnt = 0
t = Timer()
step_cnt = 0
for step in range(start_epoch * imdb.batch_per_epoch, cfg.max_epoch * imdb.batch_per_epoch):
size_index = 0
for step in range(start_epoch * imdb.batch_per_epoch,
cfg.max_epoch * imdb.batch_per_epoch):
t.tic()
# batch
batch = imdb.next_batch()
batch = imdb.next_batch(size_index)
im = batch['images']
gt_boxes = batch['gt_boxes']
gt_classes = batch['gt_classes']
dontcare = batch['dontcare']
orgin_im = batch['origin_im']

# forward
im_data = net_utils.np_to_variable(im, is_cuda=True, volatile=False).permute(0, 3, 1, 2)
net(im_data, gt_boxes, gt_classes, dontcare)
im_data = net_utils.np_to_variable(im,
is_cuda=True,
volatile=False).permute(0, 3, 1, 2)
net(im_data, gt_boxes, gt_classes, dontcare, size_index)

# backward
loss = net.loss
Expand All @@ -94,9 +100,12 @@
bbox_loss /= cnt
iou_loss /= cnt
cls_loss /= cnt
print(('epoch %d[%d/%d], loss: %.3f, bbox_loss: %.3f, iou_loss: %.3f, cls_loss: %.3f (%.2f s/batch, rest:%s)' % (
imdb.epoch, step_cnt, batch_per_epoch, train_loss, bbox_loss, iou_loss, cls_loss, duration,
str(datetime.timedelta(seconds=int((batch_per_epoch - step_cnt) * duration))))))
print(('epoch %d[%d/%d], loss: %.3f, bbox_loss: %.3f, iou_loss: %.3f, '
'cls_loss: %.3f (%.2f s/batch, rest:%s)' % (
imdb.epoch, step_cnt, batch_per_epoch, train_loss, bbox_loss,
iou_loss, cls_loss, duration,
str(datetime.timedelta(seconds=int((batch_per_epoch - step_cnt)
* duration))))))

if use_tensorboard and step % cfg.log_interval == 0:
exp.add_scalar_value('loss_train', train_loss, step=step)
Expand All @@ -109,13 +118,18 @@
bbox_loss, iou_loss, cls_loss = 0., 0., 0.
cnt = 0
t.clear()
size_index = randint(0, len(cfg.multi_scale_inp_size) - 1)
print("image_size {}".format(cfg.multi_scale_inp_size[size_index]))

if step > 0 and (step % imdb.batch_per_epoch == 0):
if imdb.epoch in cfg.lr_decay_epochs:
lr *= cfg.lr_decay
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
optimizer = torch.optim.SGD(net.parameters(), lr=lr,
momentum=cfg.momentum,
weight_decay=cfg.weight_decay)

save_name = os.path.join(cfg.train_output_dir, '{}_{}.h5'.format(cfg.exp_name, imdb.epoch))
save_name = os.path.join(cfg.train_output_dir,
'{}_{}.h5'.format(cfg.exp_name, imdb.epoch))
net_utils.save_net(save_name, net)
print(('save model: {}'.format(save_name)))
step_cnt = 0
Expand Down

0 comments on commit 4e4eb96

Please sign in to comment.