diff --git a/configs/faster/rpn_res101_mx_bn1.yml b/configs/faster/rpn_res101_mx_bn1.yml
new file mode 100644
index 000000000..8ef798836
--- /dev/null
+++ b/configs/faster/rpn_res101_mx_bn1.yml
@@ -0,0 +1,154 @@
+---
+MXNET_VERSION: "mxnet"
+output_path: "./output/chips_rpn_resnet101_mx_bn"
+proposal_path: "proposals"
+symbol: resnet_mx_101_rpn
+gpus: '0,1,2,3,4,5,6,7'
+CLASS_AGNOSTIC: true
+startMSTR: 9
+IS_DPN: false
+SCALES:
+- !!python/tuple [800,1280]
+- !!python/tuple [800,1280]
+- !!python/tuple [800,1280]
+default:
+  frequent: 100
+  kvstore: device
+network:
+  deform: true
+  pretrained: "./data/pretrained_model/resnet_mx_101"
+  pretrained_epoch: 0
+  PIXEL_MEANS:
+  - 103.939
+  - 116.779
+  - 123.68
+  IMAGE_STRIDE: 0
+  RCNN_FEAT_STRIDE: 16
+  RPN_FEAT_STRIDE: 16
+  FIXED_PARAMS:
+  - conv0
+  - bn0
+  - stage1
+
+  FIXED_PARAMS_SHARED:
+  - conv0
+  - bn0
+  - stage1
+
+  ANCHOR_RATIOS:
+  - 0.5
+  - 1
+  - 2
+  ANCHOR_SCALES:
+  - 2
+  - 4
+  - 7
+  - 10
+  - 13
+  - 16
+  - 24
+  # 7 scales x 3 ratios; must match the anchor grid hardcoded in HelperV3.py (A = 21)
+  NUM_ANCHORS: 21
+dataset:
+  NUM_CLASSES: 81
+  dataset: coco
+  dataset_path: "./data/coco"
+  image_set: minival2014 #train2014+valminusminival2014
+  root_path: "./data"
+  test_image_set: test-dev2015
+  proposal: rpn
+TRAIN:
+  lr: 0.015 #0.002 #0.0005
+  lr_step: '4'
+  warmup: true
+  fp16: true
+  warmup_lr: 0.0005 #0.00005
+  wd: 0.0001
+  scale: 100.0
+  warmup_step: 9000 #4000 #1000
+  begin_epoch: 0
+  end_epoch: 5 #9
+  model_prefix: 'rcnn'
+  # whether resume training
+  RESUME: false
+  # whether flip image
+  FLIP: true
+  # whether shuffle image
+  SHUFFLE: true
+  # whether use OHEM
+  ENABLE_OHEM: true
+  # size of images for each device, 2 for rcnn, 1 for rpn and e2e
+  BATCH_IMAGES: 16
+  # e2e changes behavior of anchor loader and metric
+  END2END: false
+  # group images with similar aspect ratio
+  ASPECT_GROUPING: true
+  # R-CNN
+  # rcnn rois batch size
+  BATCH_ROIS: -1
+  BATCH_ROIS_OHEM: 256
+  # rcnn rois sampling params
+  FG_FRACTION: 0.25
+  FG_THRESH: 0.5
+  BG_THRESH_HI: 0.5
+  BG_THRESH_LO: 0.0
+  # rcnn bounding box regression params
+  BBOX_REGRESSION_THRESH: 0.5
+  BBOX_WEIGHTS:
+  - 1.0
+  - 1.0
+  - 1.0
+  - 1.0
+
+  # RPN anchor loader
+  # rpn anchors batch size
+  RPN_BATCH_SIZE: 256
+  # rpn anchors sampling params
+  RPN_FG_FRACTION: 0.5
+  RPN_POSITIVE_OVERLAP: 0.5
+  RPN_NEGATIVE_OVERLAP: 0.4
+  RPN_CLOBBER_POSITIVES: false
+  # rpn bounding box regression params
+  RPN_BBOX_WEIGHTS:
+  - 1.0
+  - 1.0
+  - 1.0
+  - 1.0
+  RPN_POSITIVE_WEIGHT: -1.0
+  # used for end2end training
+  # RPN proposal
+  CXX_PROPOSAL: false
+  RPN_NMS_THRESH: 0.7
+  RPN_PRE_NMS_TOP_N: 6000
+  RPN_POST_NMS_TOP_N: 300
+  RPN_MIN_SIZE: 0
+  # approximate bounding box regression
+  BBOX_NORMALIZATION_PRECOMPUTED: true
+  BBOX_MEANS:
+  - 0.0
+  - 0.0
+  - 0.0
+  - 0.0
+  BBOX_STDS:
+  - 0.1
+  - 0.1
+  - 0.2
+  - 0.2
+TEST:
+  # use rpn to generate proposal
+  HAS_RPN: false
+  # size of images for each device
+  BATCH_IMAGES: 1
+  # RPN proposal
+  CXX_PROPOSAL: false
+  RPN_NMS_THRESH: 0.7
+  RPN_PRE_NMS_TOP_N: 6000
+  RPN_POST_NMS_TOP_N: 300
+  RPN_MIN_SIZE: 0
+  # RPN generate proposal
+  PROPOSAL_NMS_THRESH: 0.7
+  PROPOSAL_PRE_NMS_TOP_N: 20000
+  PROPOSAL_POST_NMS_TOP_N: 2000
+  PROPOSAL_MIN_SIZE: 0
+  # RCNN nms
+  NMS: 0.45
+  test_epoch: 7
\ No newline at end of file
diff --git a/lib/iterators/HelperV3.py b/lib/iterators/HelperV3.py
index c5af614c0..070003ba3 100644
--- a/lib/iterators/HelperV3.py
+++ b/lib/iterators/HelperV3.py
@@ -1,52 +1,59 @@
 import cv2
 import mxnet as mx
 import numpy as np
+import numpy.random as npr
 from bbox.bbox_transform import *
 from bbox.bbox_regression import
expand_bbox_regression_targets -class im_worker(object): - def __init__(self,cfg,crop_size): - self.cfg = cfg - self.crop_size = crop_size - - def worker(self,data): - imp = data[0] - crop = data[1] - flipped = data[2] - crop_size = self.crop_size - pixel_means = self.cfg.network.PIXEL_MEANS - - im = cv2.imread(imp, cv2.IMREAD_COLOR) - - # Crop the image - crop_scale = crop[1] - if flipped: - im = im[:, ::-1, :] - - origim = im[int(crop[0][1]):int(crop[0][3]),int(crop[0][0]):int(crop[0][2]),:] - - # Scale the image - crop_scale = crop[1] - - # Resize the crop - if int(origim.shape[0]*0.625)==0 or int(origim.shape[1]*0.625)==0: - print 'Something wrong3' - try: - im = cv2.resize(origim, None, None, fx=crop_scale, fy=crop_scale, interpolation=cv2.INTER_LINEAR) - except: - print 'Something wrong4' - - rim = np.zeros((3, crop_size, crop_size), dtype=np.float32) - d1m = min(im.shape[0], crop_size) - d2m = min(im.shape[1], crop_size) - if not self.cfg.IS_DPN: - for j in range(3): - rim[j, :d1m, :d2m] = im[:d1m, :d2m, 2-j] - pixel_means[2-j] - else: - for j in range(3): - rim[j, :d1m, :d2m] = (im[:d1m, :d2m, 2-j] - pixel_means[2-j]) * 0.0167 - - - return mx.nd.array(rim, dtype='float32') +from generate_anchor import generate_anchors + +scales = np.array([2, 4, 7, 10, 13, 16, 24], dtype=np.float32) +ratios = (0.5, 1, 2) +feat_stride = 16 +base_anchors = generate_anchors(base_size=feat_stride, ratios=list(ratios), scales=list(scales)) +num_anchors = base_anchors.shape[0] +feat_width = 32 +feat_height = 32 +shift_x = np.arange(0, feat_width) * feat_stride +shift_y = np.arange(0, feat_height) * feat_stride +shift_x, shift_y = np.meshgrid(shift_x, shift_y) +shifts = np.vstack((shift_x.ravel(), shift_y.ravel(), shift_x.ravel(), shift_y.ravel())).transpose() +A = num_anchors +K = shifts.shape[0] +all_anchors = base_anchors.reshape((1, A, 4)) + shifts.reshape((1, K, 4)).transpose((1, 0, 2)) +all_anchors = all_anchors.reshape((K * A, 4)) + +def im_worker(data): + imp = data[0] + crop = data[1] + flipped = data[2] + crop_size = data[3] + im = cv2.imread(imp, cv2.IMREAD_COLOR) + + # Crop the image + crop_scale = crop[1] + if flipped: + im = im[:, ::-1, :] + + origim = im[int(crop[0][1]):int(crop[0][3]),int(crop[0][0]):int(crop[0][2]),:] + + # Scale the image + crop_scale = crop[1] + + # Resize the crop + if int(origim.shape[0]*0.625)==0 or int(origim.shape[1]*0.625)==0: + print 'Something wrong3' + try: + im = cv2.resize(origim, None, None, fx=crop_scale, fy=crop_scale, interpolation=cv2.INTER_LINEAR) + except: + print 'Something wrong4' + + rim = np.zeros((3, crop_size, crop_size), dtype=np.float32) + d1m = min(im.shape[0], crop_size) + d2m = min(im.shape[1], crop_size) + rim[0, :d1m, :d2m] = im[:d1m, :d2m, 2] - 123.15 + rim[1, :d1m, :d2m] = im[:d1m, :d2m, 1] - 115.90 + rim[2, :d1m, :d2m] = im[:d1m, :d2m, 0] - 103.06 + return mx.nd.array(rim, dtype='float32') def sample_rois(rois, fg_rois_per_image, rois_per_image, num_classes, @@ -118,6 +125,145 @@ def sample_rois(rois, fg_rois_per_image, rois_per_image, num_classes, return rois, labels, bbox_targets, bbox_weights +def roidb_anchor_worker(data): + im_info = data[0] + cur_crop = data[1] + im_scale = data[2] + nids = data[3] + gtids = data[4] + gt_boxes = data[5] + boxes = data[6] + + anchors = all_anchors.copy() + inds_inside = np.where((anchors[:, 0] >= -32) & + (anchors[:, 1] >= -32) & + (anchors[:, 2] < im_info[0]+32) & + (anchors[:, 3] < im_info[1]+32))[0] + + anchors = anchors[inds_inside, :] + labels = np.empty((len(inds_inside),), dtype=np.float32) 
+    labels.fill(-1)
+    total_anchors = int(K * A)
+
+    gt_boxes[:, 0] = gt_boxes[:, 0] - cur_crop[0]
+    gt_boxes[:, 2] = gt_boxes[:, 2] - cur_crop[0]
+    gt_boxes[:, 1] = gt_boxes[:, 1] - cur_crop[1]
+    gt_boxes[:, 3] = gt_boxes[:, 3] - cur_crop[1]
+
+    vgt_boxes = boxes[np.intersect1d(gtids, nids)]
+
+    vgt_boxes[:, 0] = vgt_boxes[:, 0] - cur_crop[0]
+    vgt_boxes[:, 2] = vgt_boxes[:, 2] - cur_crop[0]
+    vgt_boxes[:, 1] = vgt_boxes[:, 1] - cur_crop[1]
+    vgt_boxes[:, 3] = vgt_boxes[:, 3] - cur_crop[1]
+
+    gt_boxes = clip_boxes(np.round(gt_boxes * im_scale), im_info[:2])
+    vgt_boxes = clip_boxes(np.round(vgt_boxes * im_scale), im_info[:2])
+
+    ids = filter_boxes(gt_boxes, 10)
+    if len(ids) > 0:
+        gt_boxes = gt_boxes[ids]
+    else:
+        gt_boxes = np.zeros((0, 4))
+
+    ids = filter_boxes(vgt_boxes, 10)
+    if len(ids) > 0:
+        vgt_boxes = vgt_boxes[ids]
+    else:
+        vgt_boxes = np.zeros((0, 4))
+
+    if len(vgt_boxes) > 0:
+        ov = bbox_overlaps(np.ascontiguousarray(gt_boxes).astype(float), np.ascontiguousarray(vgt_boxes).astype(float))
+        mov = np.max(ov, axis=1)
+    else:
+        mov = np.zeros((len(gt_boxes)))
+
+    invalid_gtids = np.where(mov < 1)[0]
+    valid_gtids = np.where(mov == 1)[0]
+    invalid_boxes = gt_boxes[invalid_gtids, :]
+    gt_boxes = gt_boxes[valid_gtids, :]
+
+    def _unmap(data, count, inds, fill=0):
+        """unmap a subset inds of data into original data of size count"""
+        if len(data.shape) == 1:
+            ret = np.empty((count,), dtype=np.float32)
+            ret.fill(fill)
+            ret[inds] = data
+        else:
+            ret = np.empty((count,) + data.shape[1:], dtype=np.float32)
+            ret.fill(fill)
+            ret[inds, :] = data
+        return ret
+
+    if gt_boxes.size > 0:
+        # overlap between the anchors and the gt boxes
+        # overlaps (ex, gt)
+        overlaps = bbox_overlaps(anchors.astype(np.float), gt_boxes.astype(np.float))
+        if invalid_boxes is not None:
+            if len(invalid_boxes) > 0:
+                overlapsn = bbox_overlaps(anchors.astype(np.float), invalid_boxes.astype(np.float))
+                argmax_overlapsn = overlapsn.argmax(axis=1)
+                max_overlapsn = overlapsn[np.arange(len(inds_inside)), argmax_overlapsn]
+        argmax_overlaps = overlaps.argmax(axis=1)
+        max_overlaps = overlaps[np.arange(len(inds_inside)), argmax_overlaps]
+        gt_argmax_overlaps = overlaps.argmax(axis=0)
+        gt_max_overlaps = overlaps[gt_argmax_overlaps, np.arange(overlaps.shape[1])]
+        gt_argmax_overlaps = np.where(overlaps == gt_max_overlaps)[0]
+
+        labels[max_overlaps < 0.4] = 0
+        labels[gt_argmax_overlaps] = 1
+
+        # fg label: above threshold IoU
+        labels[max_overlaps >= 0.5] = 1
+
+        if invalid_boxes is not None:
+            if len(invalid_boxes) > 0:
+                labels[max_overlapsn > 0.3] = -1
+    else:
+        labels[:] = 0
+        if len(invalid_boxes) > 0:
+            overlapsn = bbox_overlaps(anchors.astype(np.float), invalid_boxes.astype(np.float))
+            argmax_overlapsn = overlapsn.argmax(axis=1)
+            max_overlapsn = overlapsn[np.arange(len(inds_inside)), argmax_overlapsn]
+        if len(invalid_boxes) > 0:
+            labels[max_overlapsn > 0.3] = -1
+
+    # subsample positive labels if we have too many
+    num_fg = 128
+    fg_inds = np.where(labels == 1)[0]
+    if len(fg_inds) > num_fg:
+        disable_inds = npr.choice(fg_inds, size=(len(fg_inds) - num_fg), replace=False)
+        labels[disable_inds] = -1
+
+    # subsample negative labels if we have too many
+    num_bg = 256 - np.sum(labels == 1)
+    bg_inds = np.where(labels == 0)[0]
+    if len(bg_inds) > num_bg:
+        disable_inds = npr.choice(bg_inds, size=(len(bg_inds) - num_bg), replace=False)
+        labels[disable_inds] = -1
+
+    bbox_targets = np.zeros((len(inds_inside), 4), dtype=np.float32)
+    if gt_boxes.size > 0:
+        bbox_targets[:] = bbox_transform(anchors,
gt_boxes[argmax_overlaps, :4]) + + bbox_weights = np.zeros((len(inds_inside), 4), dtype=np.float32) + bbox_weights[labels == 1, :] = np.array([1.0, 1.0, 1.0, 1.0]) + + # map up to original set of anchors + labels = _unmap(labels, total_anchors, inds_inside, fill=-1) + bbox_targets = _unmap(bbox_targets, total_anchors, inds_inside, fill=0) + bbox_weights = _unmap(bbox_weights, total_anchors, inds_inside, fill=0) + + labels = labels.reshape((1, feat_height, feat_width, A)).transpose(0, 3, 1, 2) + labels = labels.reshape((1, A * feat_height * feat_width)).astype(np.float16) + bbox_targets = bbox_targets.reshape((feat_height, feat_width, A * 4)).transpose(2, 0, 1) + bbox_weights = bbox_weights.reshape((feat_height, feat_width, A * 4)).transpose((2, 0, 1)) + pids = np.where(bbox_weights == 1) + bbox_targets = bbox_targets[pids] + + rval = [mx.nd.array(labels, dtype='float16'), bbox_targets, mx.nd.array(pids)] + return rval + def roidb_worker(data): im_i = data[0] im_info = data[1] @@ -137,6 +283,7 @@ def roidb_worker(data): gt_boxes = clip_boxes(np.round(gt_boxes * im_scale), im_info[:2]) ids = filter_boxes(gt_boxes, 10) + if len(ids)>0: gt_boxes = gt_boxes[ids] gt_labs = gt_labs[ids] diff --git a/lib/iterators/MNIteratorBase.py b/lib/iterators/MNIteratorBase.py index 32ba4398f..634718fa1 100644 --- a/lib/iterators/MNIteratorBase.py +++ b/lib/iterators/MNIteratorBase.py @@ -1,6 +1,8 @@ import numpy as np import mxnet as mx from multiprocessing.pool import ThreadPool +from concurrent.futures import ThreadPoolExecutor + class MNIteratorBase(mx.io.DataIter): def __init__(self, roidb, config, batch_size, threads, nGPUs, pad_rois_to, single_size_change): @@ -14,6 +16,7 @@ def __init__(self, roidb, config, batch_size, threads, nGPUs, pad_rois_to, sing self.pixel_mean = config.network.PIXEL_MEANS self.thread_pool = ThreadPool(threads) + self.executor_pool = ThreadPoolExecutor(threads) self.n_per_gpu = batch_size / nGPUs self.batch = None diff --git a/lib/iterators/MNIteratorChipsRPN.py b/lib/iterators/MNIteratorChipsRPN.py new file mode 100644 index 000000000..5b5f77c97 --- /dev/null +++ b/lib/iterators/MNIteratorChipsRPN.py @@ -0,0 +1,602 @@ +import matplotlib + +#matplotlib.use('Agg') +import matplotlib.pyplot as plt +import mxnet as mx +import cv2 +import numpy as np +import math +from bbox.bbox_regression import expand_bbox_regression_targets +from MNIteratorBase import MNIteratorBase +from bbox.bbox_transform import bbox_overlaps, bbox_pred, bbox_transform, clip_boxes, filter_boxes, ignore_overlaps +from bbox.bbox_regression import compute_bbox_regression_targets +from chips import genchips +from multiprocessing import Pool +import time +from HelperV3 import im_worker, roidb_worker, roidb_anchor_worker + + +def clip_boxes_with_chip(boxes, chip): + boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], chip[2] - 1), chip[0]) + # y1 >= 0 + boxes[:, 1::4] = np.maximum(np.minimum(boxes[:, 1::4], chip[3] - 1), chip[1]) + # x2 < im_shape[1] + boxes[:, 2::4] = np.maximum(np.minimum(boxes[:, 2::4], chip[2] - 1), chip[0]) + # y2 < im_shape[0] + boxes[:, 3::4] = np.maximum(np.minimum(boxes[:, 3::4], chip[3] - 1), chip[1]) + return boxes + + +def chip_worker(r): + width = r['width'] + height = r['height'] + im_size_max = max(width, height) + im_scale_1 = 3 + im_scale_2 = 1.667 + im_scale_3 = 512.0 / float(im_size_max) + + gt_boxes = r['boxes'][np.where(r['max_overlaps'] == 1)[0], :] + + ws = (gt_boxes[:, 2] - gt_boxes[:, 0]).astype(np.int32) + hs = (gt_boxes[:, 3] - gt_boxes[:, 1]).astype(np.int32) + area = 
np.sqrt(ws * hs) + ms = np.maximum(ws, hs) + + ids1 = np.where((area < 80) & (ms < 450.0/im_scale_1) & (ws >= 2) & (hs >= 2))[0] + ids2 = np.where((area >= 32) & (area < 150) & (ms < 450.0/im_scale_2))[0] + ids3 = np.where((area >= 120))[0] + + chips1 = genchips(int(r['width'] * im_scale_1), int(r['height'] * im_scale_1), gt_boxes[ids1, :] * im_scale_1, 512) + chips2 = genchips(int(r['width'] * im_scale_2), int(r['height'] * im_scale_2), gt_boxes[ids2, :] * im_scale_2, 512) + chips3 = genchips(int(r['width'] * im_scale_3), int(r['height'] * im_scale_3), gt_boxes[ids3, :] * im_scale_3, 512) + chips1 = np.array(chips1) / im_scale_1 + chips2 = np.array(chips2) / im_scale_2 + chips3 = np.array(chips3) / im_scale_3 + + chip_ar = [] + for chip in chips1: + chip_ar.append([chip, im_scale_1]) + for chip in chips2: + chip_ar.append([chip, im_scale_2]) + for chip in chips3: + chip_ar.append([chip, im_scale_3]) + + return chip_ar + + # return (np.array(chips1),np.array(chips2).np.array(chips3),im_scale_3) + + +def chip_worker_two_scales(r): + width = r['width'] + height = r['height'] + im_size_max = max(width, height) + im_size_min = min(width, height) + im_scale_2 = 800.0 / im_size_min + im_scale_3 = 512.0 / float(im_size_max) + + gt_boxes = r['boxes'][np.where(r['max_overlaps'] == 1)[0], :] + + ws = (gt_boxes[:, 2] - gt_boxes[:, 0]).astype(np.int32) + hs = (gt_boxes[:, 3] - gt_boxes[:, 1]).astype(np.int32) + area = np.sqrt(ws * hs) + ms = np.maximum(ws, hs) + + ids2 = np.where((area < 150) & (ms < 450.0/im_scale_2))[0] + ids3 = np.where((area >= 120))[0] + + chips2 = genchips(int(r['width'] * im_scale_2), int(r['height'] * im_scale_2), gt_boxes[ids2, :] * im_scale_2, 512) + chips3 = genchips(int(r['width'] * im_scale_3), int(r['height'] * im_scale_3), gt_boxes[ids3, :] * im_scale_3, 512) + chips2 = np.array(chips2) / im_scale_2 + chips3 = np.array(chips3) / im_scale_3 + + chip_ar = [] + for chip in chips2: + chip_ar.append([chip, im_scale_2]) + for chip in chips3: + chip_ar.append([chip, im_scale_3]) + + return chip_ar + +def props_in_chip_worker(r): + props_in_chips = [[] for _ in range(len(r['crops']))] + widths = (r['boxes'][:, 2] - r['boxes'][:, 0]).astype(np.int32) + heights = (r['boxes'][:, 3] - r['boxes'][:, 1]).astype(np.int32) + max_sizes = np.maximum(widths, heights) + + width = r['width'] + height = r['height'] + im_size_max = max(width, height) + im_size_min = min(width, height) + + im_scale_1 = 3 + im_scale_2 = 1.667 + im_scale_3 = 512.0 / float(im_size_max) + + area = np.sqrt(widths * heights) + + sids = np.where((area < 80) & (max_sizes < 450.0/im_scale_1) & (widths >= 2) & (heights >= 2))[0] + mids = np.where((area >= 32) & (area < 150) & (max_sizes < 450.0/im_scale_2))[0] + bids = np.where((area >= 120))[0] + + chips1, chips2, chips3 = [], [], [] + chip_ids1, chip_ids2, chip_ids3 = [], [], [] + for ci, crop in enumerate(r['crops']): + if crop[1] == im_scale_1: + chips1.append(crop[0]) + chip_ids1.append(ci) + elif crop[1] == im_scale_2: + chips2.append(crop[0]) + chip_ids2.append(ci) + else: + chips3.append(crop[0]) + chip_ids3.append(ci) + + chips1 = np.array(chips1, dtype=np.float) + chips2 = np.array(chips2, dtype=np.float) + chips3 = np.array(chips3, dtype=np.float) + chip_ids1 = np.array(chip_ids1) + chip_ids2 = np.array(chip_ids2) + chip_ids3 = np.array(chip_ids3) + + small_boxes = r['boxes'][sids].astype(np.float) + med_boxes = r['boxes'][mids].astype(np.float) + big_boxes = r['boxes'][bids].astype(np.float) + + small_covered = np.zeros(small_boxes.shape[0], dtype=bool) + 
med_covered = np.zeros(med_boxes.shape[0], dtype=bool) + big_covered = np.zeros(big_boxes.shape[0], dtype=bool) + + if chips1.shape[0] > 0: + overlaps = ignore_overlaps(chips1, small_boxes) + max_ids = overlaps.argmax(axis=0) + for pi, cid in enumerate(max_ids): + cur_chip = chips1[cid] + cur_box = small_boxes[pi] + x1 = max(cur_chip[0], cur_box[0]) + x2 = min(cur_chip[2], cur_box[2]) + y1 = max(cur_chip[1], cur_box[1]) + y2 = min(cur_chip[3], cur_box[3]) + area = math.sqrt(abs((x2-x1)*(y2-y1))) + if (x2 - x1 >= 1 and y2 - y1 >= 1 and area < 80): + props_in_chips[chip_ids1[cid]].append(sids[pi]) + small_covered[pi] = True + #else: + # print ('quack') + + if chips2.shape[0] > 0: + overlaps = ignore_overlaps(chips2, med_boxes) + max_ids = overlaps.argmax(axis=0) + for pi, cid in enumerate(max_ids): + cur_chip = chips2[cid] + cur_box = med_boxes[pi] + x1 = max(cur_chip[0], cur_box[0]) + x2 = min(cur_chip[2], cur_box[2]) + y1 = max(cur_chip[1], cur_box[1]) + y2 = min(cur_chip[3], cur_box[3]) + area = math.sqrt(abs((x2-x1)*(y2-y1))) + if (x2 - x1 >= 1 and y2 - y1 >= 1 and area >= 32 and area <= 150): + props_in_chips[chip_ids2[cid]].append(mids[pi]) + med_covered[pi] = True + #else: + # print ('quack 2') + + if chips3.shape[0] > 0: + overlaps = ignore_overlaps(chips3, big_boxes) + max_ids = overlaps.argmax(axis=0) + for pi, cid in enumerate(max_ids): + cur_chip = chips3[cid] + cur_box = big_boxes[pi] + x1 = max(cur_chip[0], cur_box[0]) + x2 = min(cur_chip[2], cur_box[2]) + y1 = max(cur_chip[1], cur_box[1]) + y2 = min(cur_chip[3], cur_box[3]) + area = math.sqrt(abs((x2-x1)*(y2-y1))) + if (x2 - x1 >= 1 and y2 - y1 >= 1 and area >= 120): + props_in_chips[chip_ids3[cid]].append(bids[pi]) + big_covered[pi] = True + #else: + # print ('quack 3') + + rem_small_boxes = small_boxes[np.where(small_covered == False)[0]] + neg_sids = sids[np.where(small_covered == False)[0]] + rem_med_boxes = med_boxes[np.where(med_covered == False)[0]] + neg_mids = mids[np.where(med_covered == False)[0]] + rem_big_boxes = big_boxes[np.where(big_covered == False)[0]] + neg_bids = bids[np.where(big_covered == False)[0]] + + neg_chips1 = genchips(int(r['width'] * im_scale_1), int(r['height'] * im_scale_1), rem_small_boxes * im_scale_1, 512) + neg_chips1 = np.array(neg_chips1, dtype=np.float) / im_scale_1 + chip_ids1 = np.arange(0, len(neg_chips1)) + neg_chips2 = genchips(int(r['width'] * im_scale_2), int(r['height'] * im_scale_2), rem_med_boxes * im_scale_2, 512) + neg_chips2 = np.array(neg_chips2, dtype=np.float) / im_scale_2 + chip_ids2 = np.arange(len(neg_chips1), len(neg_chips2) + len(neg_chips1)) + neg_chips3 = genchips(int(r['width'] * im_scale_3), int(r['height'] * im_scale_3), rem_big_boxes * im_scale_3, 512) + neg_chips3 = np.array(neg_chips3, dtype=np.float) / im_scale_3 + chip_ids3 = np.arange(len(neg_chips2) + len(neg_chips1), len(neg_chips1) + len(neg_chips2) + len(neg_chips3)) + + neg_props_in_chips = [[] for _ in range(len(neg_chips1) + len(neg_chips2) + len(neg_chips3))] + + if neg_chips1.shape[0] > 0: + overlaps = ignore_overlaps(neg_chips1, rem_small_boxes) + max_ids = overlaps.argmax(axis=0) + for pi, cid in enumerate(max_ids): + cur_chip = neg_chips1[cid] + cur_box = rem_small_boxes[pi] + x1 = max(cur_chip[0], cur_box[0]) + x2 = min(cur_chip[2], cur_box[2]) + y1 = max(cur_chip[1], cur_box[1]) + y2 = min(cur_chip[3], cur_box[3]) + area = math.sqrt(abs((x2-x1)*(y2-y1))) + if (x2 - x1 >= 1 and y2 - y1 >= 1 and area < 80): + neg_props_in_chips[chip_ids1[cid]].append(neg_sids[pi]) + + if neg_chips2.shape[0] > 
0: + overlaps = ignore_overlaps(neg_chips2, rem_med_boxes) + max_ids = overlaps.argmax(axis=0) + for pi, cid in enumerate(max_ids): + cur_chip = neg_chips2[cid] + cur_box = rem_med_boxes[pi] + x1 = max(cur_chip[0], cur_box[0]) + x2 = min(cur_chip[2], cur_box[2]) + y1 = max(cur_chip[1], cur_box[1]) + y2 = min(cur_chip[3], cur_box[3]) + area = math.sqrt(abs((x2-x1)*(y2-y1))) + if (x2 - x1 >= 1 and y2 - y1 >= 1 and area >= 32 and area < 150): + neg_props_in_chips[chip_ids2[cid]].append(neg_mids[pi]) + + if neg_chips3.shape[0] > 0: + overlaps = ignore_overlaps(neg_chips3, rem_big_boxes) + max_ids = overlaps.argmax(axis=0) + for pi, cid in enumerate(max_ids): + cur_chip = neg_chips3[cid] + cur_box = rem_big_boxes[pi] + x1 = max(cur_chip[0], cur_box[0]) + x2 = min(cur_chip[2], cur_box[2]) + y1 = max(cur_chip[1], cur_box[1]) + y2 = min(cur_chip[3], cur_box[3]) + area = math.sqrt(abs((x2-x1)*(y2-y1))) + if (x2 - x1 >= 1 and y2 - y1 >= 1 and area >= 120): + neg_props_in_chips[chip_ids3[cid]].append(neg_bids[pi]) + + neg_chips = [] + final_neg_props_in_chips = [] + chip_counter = 0 + for chips, cscale in zip([neg_chips1, neg_chips2, neg_chips3], [im_scale_1, im_scale_2, im_scale_3]): + for chip in chips: + if len(neg_props_in_chips[chip_counter]) > 40: + final_neg_props_in_chips.append(np.array(neg_props_in_chips[chip_counter], dtype=int)) + neg_chips.append([chip, cscale]) + chip_counter += 1 + + # import pdb;pdb.set_trace() + r['neg_chips'] = neg_chips + r['neg_props_in_chips'] = final_neg_props_in_chips + + for j in range(len(props_in_chips)): + props_in_chips[j] = np.array(props_in_chips[j], dtype=np.int32) + + return props_in_chips,neg_chips,final_neg_props_in_chips + + + +def props_in_chip_worker_two_scales(r): + props_in_chips = [[] for _ in range(len(r['crops']))] + widths = (r['boxes'][:, 2] - r['boxes'][:, 0]).astype(np.int32) + heights = (r['boxes'][:, 3] - r['boxes'][:, 1]).astype(np.int32) + max_sizes = np.maximum(widths, heights) + + width = r['width'] + height = r['height'] + im_size_max = max(width, height) + im_size_min = min(width, height) + + im_scale_2 = 800.0 / im_size_min + im_scale_3 = 512.0 / float(im_size_max) + + area = np.sqrt(widths * heights) + + mids = np.where((area < 150) & (max_sizes < 450.0/im_scale_2))[0] + bids = np.where((area >= 120))[0] + + chips2, chips3 = [], [] + chip_ids2, chip_ids3 = [], [] + for ci, crop in enumerate(r['crops']): + if crop[1] == im_scale_2: + chips2.append(crop[0]) + chip_ids2.append(ci) + else: + chips3.append(crop[0]) + chip_ids3.append(ci) + + chips2 = np.array(chips2, dtype=np.float) + chips3 = np.array(chips3, dtype=np.float) + chip_ids2 = np.array(chip_ids2) + chip_ids3 = np.array(chip_ids3) + + med_boxes = r['boxes'][mids].astype(np.float) + big_boxes = r['boxes'][bids].astype(np.float) + + med_covered = np.zeros(med_boxes.shape[0], dtype=bool) + big_covered = np.zeros(big_boxes.shape[0], dtype=bool) + + if chips2.shape[0] > 0: + overlaps = ignore_overlaps(chips2, med_boxes) + max_ids = overlaps.argmax(axis=0) + for pi, cid in enumerate(max_ids): + cur_chip = chips2[cid] + cur_box = med_boxes[pi] + x1 = max(cur_chip[0], cur_box[0]) + x2 = min(cur_chip[2], cur_box[2]) + y1 = max(cur_chip[1], cur_box[1]) + y2 = min(cur_chip[3], cur_box[3]) + area = math.sqrt(abs((x2-x1)*(y2-y1))) + if (x2 - x1 >= 1 and y2 - y1 >= 1 and area <= 150): + props_in_chips[chip_ids2[cid]].append(mids[pi]) + med_covered[pi] = True + #else: + # print ('quack 2') + + if chips3.shape[0] > 0: + overlaps = ignore_overlaps(chips3, big_boxes) + max_ids = 
overlaps.argmax(axis=0) + for pi, cid in enumerate(max_ids): + cur_chip = chips3[cid] + cur_box = big_boxes[pi] + x1 = max(cur_chip[0], cur_box[0]) + x2 = min(cur_chip[2], cur_box[2]) + y1 = max(cur_chip[1], cur_box[1]) + y2 = min(cur_chip[3], cur_box[3]) + area = math.sqrt(abs((x2-x1)*(y2-y1))) + if (x2 - x1 >= 1 and y2 - y1 >= 1 and area >= 120): + props_in_chips[chip_ids3[cid]].append(bids[pi]) + big_covered[pi] = True + #else: + # print ('quack 3') + + rem_med_boxes = med_boxes[np.where(med_covered == False)[0]] + neg_mids = mids[np.where(med_covered == False)[0]] + rem_big_boxes = big_boxes[np.where(big_covered == False)[0]] + neg_bids = bids[np.where(big_covered == False)[0]] + + neg_chips2 = genchips(int(r['width'] * im_scale_2), int(r['height'] * im_scale_2), rem_med_boxes * im_scale_2, 512) + neg_chips2 = np.array(neg_chips2, dtype=np.float) / im_scale_2 + chip_ids2 = np.arange(len(neg_chips2)) + neg_chips3 = genchips(int(r['width'] * im_scale_3), int(r['height'] * im_scale_3), rem_big_boxes * im_scale_3, 512) + neg_chips3 = np.array(neg_chips3, dtype=np.float) / im_scale_3 + chip_ids3 = np.arange(len(neg_chips2), len(neg_chips2) + len(neg_chips3)) + + neg_props_in_chips = [[] for _ in range(len(neg_chips2) + len(neg_chips3))] + + if neg_chips2.shape[0] > 0: + overlaps = ignore_overlaps(neg_chips2, rem_med_boxes) + max_ids = overlaps.argmax(axis=0) + for pi, cid in enumerate(max_ids): + cur_chip = neg_chips2[cid] + cur_box = rem_med_boxes[pi] + x1 = max(cur_chip[0], cur_box[0]) + x2 = min(cur_chip[2], cur_box[2]) + y1 = max(cur_chip[1], cur_box[1]) + y2 = min(cur_chip[3], cur_box[3]) + area = math.sqrt(abs((x2-x1)*(y2-y1))) + if (x2 - x1 >= 1 and y2 - y1 >= 1 and area < 150): + neg_props_in_chips[chip_ids2[cid]].append(neg_mids[pi]) + + if neg_chips3.shape[0] > 0: + overlaps = ignore_overlaps(neg_chips3, rem_big_boxes) + max_ids = overlaps.argmax(axis=0) + for pi, cid in enumerate(max_ids): + cur_chip = neg_chips3[cid] + cur_box = rem_big_boxes[pi] + x1 = max(cur_chip[0], cur_box[0]) + x2 = min(cur_chip[2], cur_box[2]) + y1 = max(cur_chip[1], cur_box[1]) + y2 = min(cur_chip[3], cur_box[3]) + area = math.sqrt(abs((x2-x1)*(y2-y1))) + if (x2 - x1 >= 1 and y2 - y1 >= 1 and area >= 120): + neg_props_in_chips[chip_ids3[cid]].append(neg_bids[pi]) + + neg_chips = [] + final_neg_props_in_chips = [] + chip_counter = 0 + for chips, cscale in zip([neg_chips2, neg_chips3], [im_scale_2, im_scale_3]): + for chip in chips: + if len(neg_props_in_chips[chip_counter]) > 50: + final_neg_props_in_chips.append(np.array(neg_props_in_chips[chip_counter], dtype=int)) + neg_chips.append([chip, cscale]) + chip_counter += 1 + + # import pdb;pdb.set_trace() + r['neg_chips'] = neg_chips + r['neg_props_in_chips'] = final_neg_props_in_chips + + for j in range(len(props_in_chips)): + props_in_chips[j] = np.array(props_in_chips[j], dtype=np.int32) + + return props_in_chips,neg_chips,final_neg_props_in_chips + +class MNIteratorChips(MNIteratorBase): + def __init__(self, roidb, config, batch_size=4, threads=8, nGPUs=1, pad_rois_to=400, crop_size=(512, 512)): + self.crop_size = crop_size + self.num_classes = roidb[0]['gt_overlaps'].shape[1] + self.bbox_means = np.tile(np.array(config.TRAIN.BBOX_MEANS), (self.num_classes, 1)) + self.bbox_stds = np.tile(np.array(config.TRAIN.BBOX_STDS), (self.num_classes, 1)) + self.data_name = ['data'] + self.label_name = ['label', 'bbox_target', 'bbox_weight'] + self.pool = Pool(32) + self.context_size = 320 + self.epiter = 0 + super(MNIteratorChips, self).__init__(roidb, config, 
batch_size, threads, nGPUs, pad_rois_to, False) + + def reset(self): + self.cur_i = 0 + self.n_neg_per_im = 2 + self.crop_idx = [0] * len(self.roidb) + chips = self.pool.map(chip_worker, self.roidb) + #chipindex = [] + chip_count = 0 + for i, r in enumerate(self.roidb): + cs = chips[i] + chip_count += len(cs) + r['crops'] = cs + + all_props_in_chips = self.pool.map(props_in_chip_worker, self.roidb) + + for (props_in_chips, neg_chips, neg_props_in_chips), cur_roidb in zip(all_props_in_chips, self.roidb): + cur_roidb['props_in_chips'] = props_in_chips + cur_roidb['neg_crops'] = neg_chips + cur_roidb['neg_props_in_chips'] = neg_props_in_chips + + # Append negative chips + chipindex = [] + for i, r in enumerate(self.roidb): + cs = r['neg_crops'] + if len(cs) > 0: + sel_inds = np.arange(len(cs)) + if len(cs) > self.n_neg_per_im: + sel_inds = np.random.permutation(sel_inds)[0:self.n_neg_per_im] + for ind in sel_inds: + chip_count = chip_count + 1 + r['crops'].append(r['neg_crops'][ind]) + r['props_in_chips'].append(r['neg_props_in_chips'][ind].astype(np.int32)) + all_crops = r['crops'] + for j in range(len(all_crops)): + chipindex.append(i) + + + print('quack N chips: {}'.format(chip_count)) + + blocksize = self.batch_size + chipindex = np.array(chipindex) + if chipindex.shape[0] % blocksize > 0: + extra = blocksize - (chipindex.shape[0] % blocksize) + chipindex = np.hstack((chipindex, chipindex[0:extra])) + allinds = np.random.permutation(chipindex) + self.inds = np.array(allinds, dtype=int) + for r in self.roidb: + r['chip_order'] = np.random.permutation(np.arange(len(r['crops']))) + + self.epiter = self.epiter + 1 + self.size = len(self.inds) + print 'Done!' + + def get_batch(self): + if self.cur_i >= self.size: + return False + + # cur_roidbs = [self.roidb[self.inds[i%self.size]] for i in range(self.cur_i, self.cur_i+self.batch_size)] + + # Process cur roidb + self.batch = self._get_batch() + + self.cur_i += self.batch_size + return True + + def _get_batch(self): + """ + return a dict of multiple images + :param roidb: a list of dict, whose length controls batch size + ['images', 'flipped'] + ['gt_boxes', 'boxes', 'gt_overlap'] => ['bbox_targets'] + :return: data, label + """ + import time + t1 = time.time() + + cur_from = self.cur_i + cur_to = self.cur_i + self.batch_size + roidb = [self.roidb[self.inds[i]] for i in range(cur_from, cur_to)] + # num_images = len(roidb) + cropids = [self.roidb[self.inds[i]]['chip_order'][self.crop_idx[self.inds[i]]%len(self.roidb[self.inds[i]]['chip_order'])] for i in range(cur_from, cur_to)] + n_batch = len(roidb) + ims = [] + for i in range(n_batch): + ims.append([roidb[i]['image'], roidb[i]['crops'][cropids[i]], roidb[i]['flipped'], self.crop_size[0]]) + + for i in range(cur_from, cur_to): + self.crop_idx[self.inds[i]] = self.crop_idx[self.inds[i]] + 1 + + # im_tensor, roidb = self.im_process(roidb,cropids) + processed_roidb = [] + for i in range(len(roidb)): + tmp = roidb[i].copy() + scale = roidb[i]['crops'][cropids[i]][1] + tmp['im_info'] = [self.crop_size[0], self.crop_size[1], scale] + processed_roidb.append(tmp) + + processed_list = self.thread_pool.map_async(im_worker, ims) + + worker_data = [] + for i in range(len(processed_roidb)): + cropid = cropids[i] + nids = processed_roidb[i]['props_in_chips'][cropid] + gtids = np.where(processed_roidb[i]['max_overlaps'] == 1)[0] + gt_boxes = processed_roidb[i]['boxes'][gtids, :] + boxes = processed_roidb[i]['boxes'].copy() + cur_crop = processed_roidb[i]['crops'][cropid][0] + im_scale = 
processed_roidb[i]['crops'][cropid][1] + + argw = [processed_roidb[i]['im_info'], cur_crop, im_scale, nids, gtids, gt_boxes, boxes] + worker_data.append(argw) + + t2 = time.time() + + #print 'q1 ' + str(t2 - t1) + all_labels = self.pool.map(roidb_anchor_worker, worker_data) + t3 = time.time() + #print 'q2 ' + str(t3 - t2) + A = 21 + feat_height = 32 + feat_width = 32 + labels = mx.nd.zeros((n_batch, A*feat_height*feat_width), mx.cpu(0)) + bbox_targets = mx.nd.zeros((n_batch, A*4, feat_height, feat_width), mx.cpu(0)) + bbox_weights = mx.nd.zeros((n_batch, A*4, feat_height, feat_width), mx.cpu(0)) + + for i in range(len(all_labels)): + labels[i] = all_labels[i][0][0] + pids = all_labels[i][2] + if len(pids[0]) > 0: + bbox_targets[i][pids[0], pids[1], pids[2]] = all_labels[i][1] + bbox_weights[i][pids[0], pids[1], pids[2]] = 1.0 + t4 = time.time() + #print 'q3 ' + str(t4 - t3) + + im_tensor = mx.nd.zeros((n_batch, 3, self.crop_size[0], self.crop_size[1]), dtype=np.float32) + processed_list = processed_list.get() + for i in range(len(processed_list)): + im_tensor[i] = processed_list[i] + t5 = time.time() + #print 'q4 ' + str(t5 - t4) + #self.visualize(im_tensor, rois, labels) + self.data = [im_tensor] + self.label = [labels, bbox_targets, bbox_weights] + t6 = time.time() + #print 'convert ' + str(t6 - t5) + return mx.io.DataBatch(data=self.data, label=self.label, pad=self.getpad(), index=self.getindex(), + provide_data=self.provide_data, provide_label=self.provide_label) + + + def visualize(self, im_tensor, boxes, labels): # , bbox_targets, bbox_weights): + # import pdb;pdb.set_trace() + im_tensor = im_tensor.asnumpy() + boxes = boxes.asnumpy() + + for imi in range(im_tensor.shape[0]): + im = np.zeros((im_tensor.shape[2], im_tensor.shape[3], 3), dtype=np.uint8) + for i in range(3): + im[:, :, i] = im_tensor[imi, i, :, :] + self.pixel_mean[2 - i] + # Visualize positives + plt.imshow(im) + pos_ids = np.where(labels[imi].asnumpy() > 0)[0] + cboxes = boxes[imi][pos_ids, 1:5] + # cboxes = boxes[imi][:, 0:4] + for box in cboxes: + rect = plt.Rectangle((box[0], box[1]), + box[2] - box[0], + box[3] - box[1], fill=False, + edgecolor='green', linewidth=3.5) + plt.gca().add_patch(rect) + num = np.random.randint(100000) + #plt.show() + plt.savefig('debug/visualization/test_{}_pos.png'.format(num)) + plt.cla() + plt.clf() + plt.close() diff --git a/lib/iterators/generate_anchor.py b/lib/iterators/generate_anchor.py new file mode 100644 index 000000000..00e883b0c --- /dev/null +++ b/lib/iterators/generate_anchor.py @@ -0,0 +1,77 @@ +""" +Generate base anchors on index 0 +""" + +import numpy as np + + +def generate_anchors(base_size=16, ratios=[0.5, 1, 2], + scales=2 ** np.arange(3, 6)): + """ + Generate anchor (reference) windows by enumerating aspect ratios X + scales wrt a reference (0, 0, 15, 15) window. + """ + + base_anchor = np.array([1, 1, base_size, base_size]) - 1 + ratio_anchors = _ratio_enum(base_anchor, ratios) + anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales) + for i in xrange(ratio_anchors.shape[0])]) + return anchors + + +def _whctrs(anchor): + """ + Return width, height, x center, and y center for an anchor (window). + """ + + w = anchor[2] - anchor[0] + 1 + h = anchor[3] - anchor[1] + 1 + x_ctr = anchor[0] + 0.5 * (w - 1) + y_ctr = anchor[1] + 0.5 * (h - 1) + return w, h, x_ctr, y_ctr + + +def _mkanchors(ws, hs, x_ctr, y_ctr): + """ + Given a vector of widths (ws) and heights (hs) around a center + (x_ctr, y_ctr), output a set of anchors (windows). 
+ """ + + ws = ws[:, np.newaxis] + hs = hs[:, np.newaxis] + anchors = np.hstack((x_ctr - 0.5 * (ws - 1), + y_ctr - 0.5 * (hs - 1), + x_ctr + 0.5 * (ws - 1), + y_ctr + 0.5 * (hs - 1))) + return anchors + + +def _ratio_enum(anchor, ratios): + """ + Enumerate a set of anchors for each aspect ratio wrt an anchor. + """ + + w, h, x_ctr, y_ctr = _whctrs(anchor) + size = w * h + size_ratios = size / ratios + ws = np.round(np.sqrt(size_ratios)) + hs = np.round(ws * ratios) + anchors = _mkanchors(ws, hs, x_ctr, y_ctr) + return anchors + + +def _scale_enum(anchor, scales): + """ + Enumerate a set of anchors for each scale wrt an anchor. + """ + + w, h, x_ctr, y_ctr = _whctrs(anchor) + ws = [] + for i in range(len(scales)): + ws.append(w*scales[i]) + hs = [] + for i in range(len(scales)): + hs.append(h*scales[i]) + + anchors = _mkanchors(np.array(ws), np.array(hs), x_ctr, y_ctr) + return anchors diff --git a/main.py b/main.py deleted file mode 100644 index a722137f4..000000000 --- a/main.py +++ /dev/null @@ -1,109 +0,0 @@ -import os -os.environ['PYTHONUNBUFFERED'] = '1' -os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '0' -os.environ['MXNET_ENABLE_GPU_P2P'] = '0' -import init -from iterators.MNIteratorFaster import MNIteratorFaster -from load_model import load_param -import sys -sys.path.insert(0,'lib') -from symbols.faster.resnet_v1_50_fast import resnet_v1_50_fast,checkpoint_callback -from configs.faster.default_configs import config,update_config,get_opt_params -import mxnet as mx -import metric,callback -import numpy as np -from general_utils import get_optim_params,get_fixed_param_names,create_logger - -from iterators.PrefetchingIter import PrefetchingIter -from load_data import load_proposal_roidb,merge_roidb,filter_roidb -from bbox.bbox_regression import add_bbox_regression_targets -from argparse import ArgumentParser - -def parser(): - arg_parser = ArgumentParser('Faster R-CNN training module') - arg_parser.add_argument('--cfg',dest='cfg',help='Path to the config file', - default='configs/faster/res50_coco.yml',type=str) - arg_parser.add_argument('--display',dest='display',help='Number of epochs between displaying loss info', - default=100,type=int) - arg_parser.add_argument('--save_prefix',dest='save_prefix',help='Prefix used for snapshotting the network', - default='CRCNN',type=str) - - return arg_parser.parse_args() - -if __name__=='__main__': - args = parser() - update_config(args.cfg) - context=[mx.gpu(int(gpu)) for gpu in config.gpus.split(',')] - nGPUs = len(context) - batch_size = nGPUs * config.TRAIN.BATCH_IMAGES - - if not os.path.isdir(config.output_path): - os.mkdir(config.output_path) - - - # Create roidb - image_sets = [iset for iset in config.dataset.image_set.split('+')] - roidbs = [load_proposal_roidb(config.dataset.dataset, image_set, config.dataset.root_path, config.dataset.dataset_path, - proposal=config.dataset.proposal, append_gt=True, flip=True, result_path=config.output_path) - for image_set in image_sets] - roidb = merge_roidb(roidbs) - roidb = filter_roidb(roidb, config) - bbox_means, bbox_stds = add_bbox_regression_targets(roidb, config) - - # Creating the iterator - print('Creating Iterator with {} Images'.format(len(roidb))) - train_iter = MNIteratorFaster(roidb=roidb,config=config,batch_size=batch_size,nGPUs=nGPUs,threads=batch_size) - - # Creating the module - print('Initializing the model...') - sym_inst = resnet_v1_50_fast() - sym = sym_inst.get_symbol_rcnn(config) - - # Creating the Logger - logger, output_path = create_logger(config.output_path, args.cfg, 
config.dataset.image_set)
-
-    # get list of fixed parameters
-    fixed_param_names = get_fixed_param_names(config.network.FIXED_PARAMS,sym)
-
-    # Creating the module
-    mod = mx.mod.Module(symbol=sym,
-                        context=context,
-                        data_names=[k[0] for k in train_iter.provide_data_single],
-                        label_names=[k[0] for k in train_iter.provide_label_single],
-                        fixed_param_names=fixed_param_names)
-    shape_dict = dict(train_iter.provide_data_single+train_iter.provide_label_single)
-    sym_inst.infer_shape(shape_dict)
-    arg_params, aux_params = load_param(config.network.pretrained,config.network.pretrained_epoch,convert=True)
-    sym_inst.init_weight_rcnn(config,arg_params,aux_params)
-
-
-    # Creating the metrics
-    eval_metric = metric.RCNNAccMetric(config)
-    cls_metric = metric.RCNNLogLossMetric(config)
-    bbox_metric = metric.RCNNL1LossMetric(config)
-    eval_metrics = mx.metric.CompositeEvalMetric()
-    eval_metrics.add(eval_metric)
-    eval_metrics.add(cls_metric)
-    eval_metrics.add(bbox_metric)
-
-
-    eval_metrics = mx.metric.CompositeEvalMetric()
-    eval_metrics.add(eval_metric)
-    eval_metrics.add(cls_metric)
-    eval_metrics.add(bbox_metric)
-
-    optimizer_params = get_optim_params(config,len(roidb),batch_size)
-    print ('Optimizer params: {}'.format(optimizer_params))
-
-    # Checkpointing
-    prefix = os.path.join(output_path,args.save_prefix)
-    batch_end_callback = mx.callback.Speedometer(batch_size, args.display)
-    epoch_end_callback = [mx.callback.module_checkpoint(mod, prefix, period=1, save_optimizer_states=True),
-                          checkpoint_callback(sym_inst.get_bbox_param_names(),prefix, bbox_means, bbox_stds)]
-
-    train_iter = PrefetchingIter(train_iter)
-    mod.fit(train_iter,optimizer='sgd',optimizer_params=optimizer_params,
-            eval_metric=eval_metrics,num_epoch=config.TRAIN.end_epoch,kvstore=config.default.kvstore,
-            batch_end_callback=batch_end_callback,
-            epoch_end_callback=epoch_end_callback, arg_params=arg_params,aux_params=aux_params)
-
\ No newline at end of file
diff --git a/main_chips_rpn.py b/main_chips_rpn.py
new file mode 100644
index 000000000..2f607d873
--- /dev/null
+++ b/main_chips_rpn.py
@@ -0,0 +1,138 @@
+import os
+
+os.environ['PYTHONUNBUFFERED'] = '1'
+os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '2'
+# os.environ['MXNET_ENABLE_GPU_P2P'] = '0'
+import init
+from load_model import load_param
+import sys
+
+sys.path.insert(0, 'lib')
+from symbols.faster.resnet_mx_101_rpn import resnet_mx_101_rpn, checkpoint_callback
+#from symbols.faster.symbol_dpn_98_cls import symbol_dpn_98_cls, checkpoint_callback
+from configs.faster.default_configs import config, update_config, get_opt_params
+import mxnet as mx
+import metric, callback
+import numpy as np
+from general_utils import get_optim_params, get_fixed_param_names, create_logger
+from iterators.PrefetchingIter import PrefetchingIter
+from iterators.MNIteratorChipsRPN import MNIteratorChips
+from load_data import load_proposal_roidb, merge_roidb, filter_roidb, add_chip_data, remove_small_boxes
+from bbox.bbox_regression import add_bbox_regression_targets
+from argparse import ArgumentParser
+import cPickle
+
+
+def parser():
+    arg_parser = ArgumentParser('Faster R-CNN training module')
+    arg_parser.add_argument('--cfg', dest='cfg', help='Path to the config file',
+                            #default='configs/faster/dpn98_coco_chips.yml', type=str)
+                            default='configs/faster/rpn_res101_mx_bn1.yml', type=str)
+    arg_parser.add_argument('--display', dest='display', help='Number of batches between displaying loss info',
+                            default=100, type=int)
+    arg_parser.add_argument('--momentum', dest='momentum', help='BN momentum', default=0.995, type=float)
+    arg_parser.add_argument('--save_prefix', dest='save_prefix', help='Prefix used for snapshotting the network',
+                            default='CRCNN', type=str)
+    arg_parser.add_argument('--threadid', dest='threadid', help='Thread id',
+                            type=int)
+
+    return arg_parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parser()
+    update_config(args.cfg)
+    context = [mx.gpu(int(gpu)) for gpu in config.gpus.split(',')]
+    nGPUs = len(context)
+    batch_size = nGPUs * config.TRAIN.BATCH_IMAGES
+
+    if not os.path.isdir(config.output_path):
+        os.mkdir(config.output_path)
+
+    # Create roidb
+    config.debug = False
+    if config.debug == False:
+        image_sets = [iset for iset in config.dataset.image_set.split('+')]
+        roidbs = [load_proposal_roidb(config.dataset.dataset, image_set, config.dataset.root_path,
+                                      config.dataset.dataset_path,
+                                      proposal=config.dataset.proposal, append_gt=True, flip=True,
+                                      result_path=config.output_path,
+                                      proposal_path=config.proposal_path)
+                  for image_set in image_sets]
+
+        roidb = merge_roidb(roidbs)
+        # roidb = remove_small_boxes(roidb,max_scale=3,min_size=2)
+        roidb = filter_roidb(roidb, config)
+        bbox_means, bbox_stds = add_bbox_regression_targets(roidb, config)
+    else:
+        args.display = 20
+        with open('/home/ubuntu/bigminival2014.pkl', 'rb') as f:
+            roidb = cPickle.load(f)
+        bbox_means, bbox_stds = add_bbox_regression_targets(roidb, config)
+
+    print('Creating Iterator with {} Images'.format(len(roidb)))
+    train_iter = MNIteratorChips(roidb=roidb, config=config, batch_size=batch_size, nGPUs=nGPUs, threads=32,
+                                 pad_rois_to=400)
+    print('The Iterator has {} samples!'.format(len(train_iter)))
+
+    #for data in train_iter:
+    #    print 'Yes'
+    #import time
+    #t1 = time.time()
+    #for i,batch in enumerate(train_iter):
+    #    t2 = time.time() - t1
+    #    print 128.0 / t2
+    #    t1 = time.time()
+    #    exit(0)
+
+
+    print('Initializing the model...')
+    sym_inst = resnet_mx_101_rpn(n_proposals=400)
+    sym = sym_inst.get_symbol_rcnn(config)
+
+    # Creating the Logger
+    logger, output_path = create_logger(config.output_path, args.cfg, config.dataset.image_set)
+
+    # get list of fixed parameters
+    fixed_param_names = get_fixed_param_names(config.network.FIXED_PARAMS, sym)
+
+    # Creating the module
+    mod = mx.mod.Module(symbol=sym,
+                        context=context,
+                        data_names=[k[0] for k in train_iter.provide_data_single],
+                        label_names=[k[0] for k in train_iter.provide_label_single],
+                        fixed_param_names=fixed_param_names)
+
+    shape_dict = dict(train_iter.provide_data_single + train_iter.provide_label_single)
+    sym_inst.infer_shape(shape_dict)
+    arg_params, aux_params = load_param(config.network.pretrained, config.network.pretrained_epoch, convert=True)
+
+    sym_inst.init_weight_rcnn(config, arg_params, aux_params)
+
+    # Creating the metrics
+    eval_metric = metric.RPNAccMetric()
+    cls_metric = metric.RPNLogLossMetric()
+    bbox_metric = metric.RPNL1LossMetric()
+
+    eval_metrics = mx.metric.CompositeEvalMetric()
+
+    eval_metrics.add(eval_metric)
+    eval_metrics.add(cls_metric)
+    eval_metrics.add(bbox_metric)
+    # eval_metrics.add(vis_metric)
+
+    optimizer_params = get_optim_params(config, len(train_iter), batch_size)
+    print ('Optimizer params: {}'.format(optimizer_params))
+
+    # Checkpointing
+    prefix = os.path.join(output_path, args.save_prefix)
+    batch_end_callback = mx.callback.Speedometer(batch_size, args.display)
+    epoch_end_callback = [mx.callback.module_checkpoint(mod, prefix,
period=1, save_optimizer_states=True), + checkpoint_callback(sym_inst.get_bbox_param_names(), prefix, bbox_means, bbox_stds)] + + train_iter = PrefetchingIter(train_iter) + mod.fit(train_iter, optimizer='sgd', optimizer_params=optimizer_params, + eval_metric=eval_metrics, num_epoch=config.TRAIN.end_epoch, kvstore=config.default.kvstore, + batch_end_callback=batch_end_callback, + epoch_end_callback=epoch_end_callback, arg_params=arg_params, aux_params=aux_params) diff --git a/symbols/faster/resnet_mx_101_rpn.py b/symbols/faster/resnet_mx_101_rpn.py new file mode 100644 index 000000000..5807f0390 --- /dev/null +++ b/symbols/faster/resnet_mx_101_rpn.py @@ -0,0 +1,285 @@ +import cPickle +import mxnet as mx +from lib.symbol import Symbol +# from operator_py.debug import * +from operator_py.box_annotator_ohem import * +from operator_py.debug_data import * +import numpy as np + +def checkpoint_callback(bbox_param_names, prefix, means, stds): + def _callback(iter_no, sym, arg, aux): + weight = arg[bbox_param_names[0]] + bias = arg[bbox_param_names[1]] + arg[bbox_param_names[0] + '_test'] = (weight.T * mx.nd.array(stds)).T + arg[bbox_param_names[1] + '_test'] = bias * mx.nd.array(stds) + mx.nd.array(means) + mx.model.save_checkpoint(prefix, iter_no + 1, sym, arg, aux) + arg.pop(bbox_param_names[0] + '_test') + arg.pop(bbox_param_names[1] + '_test') + + return _callback + + +class resnet_mx_101_rpn(Symbol): + def __init__(self, n_proposals=400, momentum=0.95, fix_bn=False): + """ + Use __init__ to define parameter network needs + """ + self.momentum = momentum + self.use_global_stats = True + self.workspace = 512 + self.units = (3, 4, 23, 3) # use for 101 + self.filter_list = [64, 256, 512, 1024, 2048] + self.fix_bn = fix_bn + + def get_bbox_param_names(self): + return ['bbox_pred_weight', 'bbox_pred_bias'] + + def residual_unit(self, data, num_filter, stride, dim_match, name, bn_mom=0.9, workspace=512, memonger=False, + fix_bn=False): + if fix_bn or self.fix_bn: + bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, use_global_stats=True, name=name + '_bn1') + else: + bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=self.momentum, name=name + '_bn1') + act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1') + conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter * 0.25), kernel=(1, 1), stride=(1, 1), + pad=(0, 0), + no_bias=True, workspace=workspace, name=name + '_conv1') + if fix_bn or self.fix_bn: + bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, use_global_stats=True, name=name + '_bn2') + else: + bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=self.momentum, name=name + '_bn2') + act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2') + conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter * 0.25), kernel=(3, 3), stride=stride, + pad=(1, 1), + no_bias=True, workspace=workspace, name=name + '_conv2') + if fix_bn or self.fix_bn: + bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, use_global_stats=True, name=name + '_bn3') + else: + bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=self.momentum, name=name + '_bn3') + act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3') + conv3 = mx.sym.Convolution(data=act3, num_filter=num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), + no_bias=True, + workspace=workspace, name=name + '_conv3') + if dim_match: + shortcut = data + else: + shortcut = mx.sym.Convolution(data=act1, 
num_filter=num_filter, kernel=(1, 1), stride=stride, no_bias=True, + workspace=workspace, name=name + '_sc') + if memonger: + shortcut._set_attr(mirror_stage='True') + return conv3 + shortcut + + def residual_unit_dilate(self, data, num_filter, stride, dim_match, name, bn_mom=0.9, workspace=512, + memonger=False): + if self.fix_bn: + bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, use_global_stats=True, name=name + '_bn1') + else: + bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=self.momentum, name=name + '_bn1') + act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1') + conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter * 0.25), kernel=(1, 1), stride=(1, 1), + pad=(0, 0), + no_bias=True, workspace=workspace, name=name + '_conv1') + if self.fix_bn: + bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, use_global_stats=True, name=name + '_bn2') + else: + bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=self.momentum, name=name + '_bn2') + act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2') + conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter * 0.25), kernel=(3, 3), dilate=(2, 2), + stride=stride, pad=(2, 2), + no_bias=True, workspace=workspace, name=name + '_conv2') + if self.fix_bn: + bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, use_global_stats=True, name=name + '_bn3') + else: + bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=self.momentum, name=name + '_bn3') + act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3') + conv3 = mx.sym.Convolution(data=act3, num_filter=num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), + no_bias=True, + workspace=workspace, name=name + '_conv3') + if dim_match: + shortcut = data + else: + shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1, 1), stride=stride, no_bias=True, + workspace=workspace, name=name + '_sc') + if memonger: + shortcut._set_attr(mirror_stage='True') + return conv3 + shortcut + + def residual_unit_deform(self, data, num_filter, stride, dim_match, name, bn_mom=0.9, workspace=512, + memonger=False): + if self.fix_bn: + bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, use_global_stats=True, name=name + '_bn1') + else: + bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=self.momentum, name=name + '_bn1') + act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1') + conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter * 0.25), kernel=(1, 1), stride=(1, 1), + pad=(0, 0), + no_bias=True, workspace=workspace, name=name + '_conv1') + if self.fix_bn: + bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, use_global_stats=True, name=name + '_bn2') + else: + bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=self.momentum, name=name + '_bn2') + act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2') + offset = mx.symbol.Convolution(name=name + '_offset', data=act2, + num_filter=72, pad=(2, 2), kernel=(3, 3), stride=(1, 1), + dilate=(2, 2), cudnn_off=True) + conv2 = mx.contrib.symbol.DeformableConvolution(name=name + '_conv2', data=act2, + offset=offset, + num_filter=512, pad=(2, 2), kernel=(3, 3), + num_deformable_group=4, + stride=(1, 1), dilate=(2, 2), no_bias=True) + if self.fix_bn: + bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, use_global_stats=True, name=name + '_bn3') + else: + bn3 = 
mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=self.momentum, name=name + '_bn3')
+
+        act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3')
+        conv3 = mx.sym.Convolution(data=act3, num_filter=num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0),
+                                   no_bias=True,
+                                   workspace=workspace, name=name + '_conv3')
+        if dim_match:
+            shortcut = data
+        else:
+            shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1, 1), stride=stride, no_bias=True,
+                                          workspace=workspace, name=name + '_sc')
+        if memonger:
+            shortcut._set_attr(mirror_stage='True')
+        return conv3 + shortcut
+
+    def get_rpn(self, conv_feat, num_anchors):
+        rpn_conv = mx.sym.Convolution(
+            data=conv_feat, kernel=(3, 3), pad=(1, 1), num_filter=512, name="rpn_conv_3x3")
+        rpn_relu = mx.sym.Activation(data=rpn_conv, act_type="relu", name="rpn_relu")
+        rpn_cls_score = mx.sym.Convolution(
+            data=rpn_relu, kernel=(1, 1), pad=(0, 0), num_filter=2 * num_anchors, name="rpn_cls_score")
+        rpn_bbox_pred = mx.sym.Convolution(
+            data=rpn_relu, kernel=(1, 1), pad=(0, 0), num_filter=4 * num_anchors, name="rpn_bbox_pred")
+        return rpn_cls_score, rpn_bbox_pred
+
+    def get_symbol_rcnn(self, cfg, is_train=True):
+        num_anchors = cfg.network.NUM_ANCHORS
+
+        # input init
+        if is_train:
+            data = mx.sym.Variable(name="data")
+            rpn_label = mx.sym.Variable(name='label')
+            rpn_bbox_target = mx.sym.Variable(name='bbox_target')
+            rpn_bbox_weight = mx.sym.Variable(name='bbox_weight')
+        else:
+            data = mx.sym.Variable(name="data")
+            im_info = mx.sym.Variable(name="im_info")  # consumed by the proposal op below
+
+        # shared convolutional layers
+        conv_feat = self.resnetc4(data, fp16=cfg.TRAIN.fp16)
+        # res5
+        relut = self.resnetc5(conv_feat, deform=True)
+        relu1 = mx.symbol.Concat(*[conv_feat, relut], name='cat4')
+        if cfg.TRAIN.fp16:
+            relu1 = mx.sym.Cast(data=relu1, dtype=np.float32)
+
+        rpn_cls_score, rpn_bbox_pred = self.get_rpn(relu1, num_anchors)
+
+        if is_train:
+            # prepare rpn data
+            rpn_cls_score_reshape = mx.sym.Reshape(data=rpn_cls_score, shape=(0, 2, -1, 0),
+                                                   name="rpn_cls_score_reshape")
+            # classification
+            rpn_cls_prob = mx.sym.SoftmaxOutput(data=rpn_cls_score_reshape, label=rpn_label, multi_output=True,
+                                                normalization='valid', use_ignore=True, ignore_label=-1,
+                                                name="rpn_cls_prob")
+
+            # bounding box regression
+            rpn_bbox_loss_ = rpn_bbox_weight * mx.sym.smooth_l1(name='rpn_bbox_loss_', scalar=1.0,
+                                                                data=(rpn_bbox_pred - rpn_bbox_target))
+            rpn_bbox_loss = mx.sym.MakeLoss(name='rpn_bbox_loss', data=rpn_bbox_loss_,
+                                            grad_scale=1.0 / float(cfg.TRAIN.BATCH_IMAGES*cfg.TRAIN.RPN_BATCH_SIZE))
+            group = mx.sym.Group([rpn_cls_prob, rpn_bbox_loss])
+        else:
+            # ROI Proposal
+            rpn_cls_score_reshape = mx.sym.Reshape(
+                data=rpn_cls_score, shape=(0, 2, -1, 0), name="rpn_cls_score_reshape")
+            rpn_cls_prob = mx.sym.SoftmaxActivation(
+                data=rpn_cls_score_reshape, mode="channel", name="rpn_cls_prob")
+            rpn_cls_prob_reshape = mx.sym.Reshape(
+                data=rpn_cls_prob, shape=(0, 2 * num_anchors, -1, 0), name='rpn_cls_prob_reshape')
+            rois = mx.sym.Custom(
+                cls_prob=rpn_cls_prob_reshape, bbox_pred=rpn_bbox_pred, im_info=im_info, name='rois',
+                op_type='proposal', feat_stride=cfg.network.RPN_FEAT_STRIDE,
+                scales=tuple(cfg.network.ANCHOR_SCALES), ratios=tuple(cfg.network.ANCHOR_RATIOS), output_score='True',
+                rpn_pre_nms_top_n=cfg.TEST.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=cfg.TEST.RPN_POST_NMS_TOP_N,
+                threshold=cfg.TEST.RPN_NMS_THRESH, rpn_min_size=cfg.TEST.RPN_MIN_SIZE)
+            group = mx.sym.Group([rois])
+        self.sym = group
+        return group
+
+    def resnetc4(self, data, fp16=False):
+        units = self.units
+
filter_list = self.filter_list + bn_mom = self.momentum + workspace = self.workspace + num_stage = len(units) + memonger = False + + data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, use_global_stats=True, name='bn_data') + body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2, 2), pad=(3, 3), + no_bias=True, name="conv0", workspace=workspace) + if fp16: + body = mx.sym.Cast(data=body, dtype=np.float16) + body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, use_global_stats=True, name='bn0') + body = mx.sym.Activation(data=body, act_type='relu', name='relu0') + body = mx.symbol.Pooling(data=body, kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max') + + for i in range(num_stage - 1): + body = self.residual_unit(body, filter_list[i + 1], (1 if i == 0 else 2, 1 if i == 0 else 2), False, + name='stage%d_unit%d' % (i + 1, 1), workspace=workspace, + memonger=memonger, fix_bn=(i == 0)) + for j in range(units[i] - 1): + body = self.residual_unit(body, filter_list[i + 1], (1, 1), True, + name='stage%d_unit%d' % (i + 1, j + 2), + workspace=workspace, memonger=memonger, fix_bn=(i == 0)) + + return body + + def resnetc5(self, body, deform): + units = self.units + filter_list = self.filter_list + workspace = self.workspace + num_stage = len(units) + memonger = False + + i = num_stage - 1 + if deform: + body = self.residual_unit_deform(body, filter_list[i + 1], (1, 1), False, + name='stage%d_unit%d' % (i + 1, 1), workspace=workspace, + memonger=memonger) + else: + body = self.residual_unit_dilate(body, filter_list[i + 1], (1, 1), False, + name='stage%d_unit%d' % (i + 1, 1), workspace=workspace, + memonger=memonger) + for j in range(units[i] - 1): + if deform: + body = self.residual_unit_deform(body, filter_list[i + 1], (1, 1), True, + name='stage%d_unit%d' % (i + 1, j + 2), + workspace=workspace, memonger=memonger) + else: + body = self.residual_unit_dilate(body, filter_list[i + 1], (1, 1), True, + name='stage%d_unit%d' % (i + 1, j + 2), + workspace=workspace, memonger=memonger) + + return body + + def init_weight_rcnn(self, cfg, arg_params, aux_params): + arg_params['stage4_unit1_offset_weight'] = mx.nd.zeros(shape=self.arg_shape_dict['stage4_unit1_offset_weight']) + arg_params['stage4_unit1_offset_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['stage4_unit1_offset_bias']) + arg_params['stage4_unit2_offset_weight'] = mx.nd.zeros(shape=self.arg_shape_dict['stage4_unit2_offset_weight']) + arg_params['stage4_unit2_offset_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['stage4_unit2_offset_bias']) + arg_params['stage4_unit3_offset_weight'] = mx.nd.zeros(shape=self.arg_shape_dict['stage4_unit3_offset_weight']) + arg_params['stage4_unit3_offset_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['stage4_unit3_offset_bias']) + + arg_params['rpn_conv_3x3_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['rpn_conv_3x3_weight']) + arg_params['rpn_conv_3x3_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['rpn_conv_3x3_bias']) + arg_params['rpn_cls_score_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['rpn_cls_score_weight']) + arg_params['rpn_cls_score_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['rpn_cls_score_bias']) + arg_params['rpn_bbox_pred_weight'] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict['rpn_bbox_pred_weight']) + arg_params['rpn_bbox_pred_bias'] = mx.nd.zeros(shape=self.arg_shape_dict['rpn_bbox_pred_bias']) + + def init_weight(self, cfg, arg_params, aux_params): + self.init_weight_rcnn(cfg, arg_params, 
aux_params)
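
Review note: the anchor grid is hardcoded in three separate places that can silently drift apart, namely the module-level constants in `HelperV3.py`, `A = 21` and `feat_height = feat_width = 32` in `MNIteratorChipsRPN._get_batch`, and `NUM_ANCHORS` in the config. Below is a minimal consistency sketch, assuming the values from `rpn_res101_mx_bn1.yml` and the 512x512 chips produced by `genchips`; the variable names are illustrative, not from the repo.

```python
# Values taken from configs/faster/rpn_res101_mx_bn1.yml
anchor_scales = [2, 4, 7, 10, 13, 16, 24]   # ANCHOR_SCALES
anchor_ratios = [0.5, 1, 2]                 # ANCHOR_RATIOS
feat_stride = 16                            # RPN_FEAT_STRIDE
chip_size = 512                             # chip size passed to genchips

# One anchor per (scale, ratio) pair at every feature-map cell
num_anchors = len(anchor_scales) * len(anchor_ratios)
feat_height = feat_width = chip_size // feat_stride

assert num_anchors == 21                        # matches A = 21 in _get_batch
assert (feat_height, feat_width) == (32, 32)    # matches HelperV3's hardcoded grid
# roidb_anchor_worker emits a label blob of shape (1, A * H * W)
assert num_anchors * feat_height * feat_width == 21504
print('NUM_ANCHORS = %d, label length = %d' %
      (num_anchors, num_anchors * feat_height * feat_width))
```

This relation is also why `NUM_ANCHORS` in the config must be 21 rather than 7: `rpn_cls_score` is built with `2 * NUM_ANCHORS` channels, and the `SoftmaxOutput` label coming from the iterator has `21 * 32 * 32` entries per image, so the two only line up when `NUM_ANCHORS` equals scales times ratios.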