Skip to content

Commit

Permalink
code release
Browse files Browse the repository at this point in the history
  • Loading branch information
kshmelkov committed Jun 22, 2018
1 parent 8ad1689 commit f829412
Show file tree
Hide file tree
Showing 16 changed files with 2,554 additions and 0 deletions.
36 changes: 36 additions & 0 deletions README.md
@@ -0,0 +1,36 @@
# Incremental Learning of Object Detectors without Catastrophic Forgetting

This is code release for our paper ["Incremental Learning of Object Detectors without Catastrophic Forgetting"]( https://arxiv.org/abs/1708.06977) published on ICCV 2017.

## Requirements

Code is written for Python 3.5 and TensorFlow 1.5 (might require minor modifications for more recent versions). You are also expected to have normal scientific stack installed: NumPy, SciPy, Matplotlib, OpenCV.
If you don't like OpenCV, you can replace it with something that can read and resize images. SciPy is used only for interaction with Matlab.

You also need a checkpoint of [pre-trained ResNet-50](https://drive.google.com/drive/folders/1Xxs6jK_adXdr1asyyqxJiV3a3yUU2G9N?usp=sharing) to initialize an object detector. Put it in the directory `./resnet`. These weights are obtained from official Microsoft release, but slightly changed to correspond better to TF minor differences. This checkpoint is different from the one released by Google for TF-Slim.

## Datasets

All experiments were done on [PASCAL VOC 2007]*(http://host.robots.ox.ac.uk/pascal/VOC/*) and [Microsoft COCO]*(http://cocodataset.org/).
To use COCO you also need [pycocotools]*(https://github.com/cocodataset/cocoapi) installed.

## Experiments

To train and evaluate a normal FastRCNN on VOC 2007 launch the following command:

```
python3 frcnn.py sigmoid --run_name=resnet_sigmoid_20 --num_classes=20 --dataset=voc07 --max_iterations=40000 --action=train,eval --eval_first_n=5000 --eval_ckpts=40k --learning_rate=0.001 --sigmoid
```

To train 10 classes network and then extend it for 10 more classes:

```
python3 frcnn.py sigmoid --run_name=resnet_sigmoid_10 --num_classes=10 --dataset=voc07 --max_iterations=40000 --action=train,eval --eval_ckpts=40k --learning_rate=0.001 --lr_decay 30000 --sigmoid
python3 frcnn.py sigmoid --run_name=resnet_sigmoid_10_ext10 --num_classes=10 --extend=10 --dataset=voc07 --max_iterations=40000 --action=train,eval --eval_ckpts=40k --learning_rate=0.0001 --sigmoid --pretrained_net=resnet_sigmoid_10 --distillation --bias_distillation
```

The same way to train a COCO model on all classes:

```
python3 frcnn.py --run_name=resnet_coco_80 --num_classes=80 --dataset=coco --max_iterations=500000 --lr_decay_step=250000 --weight_decay=0.00005 --eval_first_n=5000 --eval_ckpts=500000 --action=train,eval --sigmoid"
```
157 changes: 157 additions & 0 deletions coco_loader.py
@@ -0,0 +1,157 @@
from scipy.io import loadmat
from loader import Loader, DATASETS_ROOT

from pycocotools.coco import COCO
from pycocotools import mask

import cv2
import numpy as np

COCO_VOC_CATS = ['__background__', 'airplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'dining table',
'dog', 'horse', 'motorcycle', 'person', 'potted plant',
'sheep', 'couch', 'train', 'tv']

COCO_NONVOC_CATS = ['apple', 'backpack', 'banana', 'baseball bat',
'baseball glove', 'bear', 'bed', 'bench', 'book', 'bowl',
'broccoli', 'cake', 'carrot', 'cell phone', 'clock', 'cup',
'donut', 'elephant', 'fire hydrant', 'fork', 'frisbee',
'giraffe', 'hair drier', 'handbag', 'hot dog', 'keyboard',
'kite', 'knife', 'laptop', 'microwave', 'mouse', 'orange',
'oven', 'parking meter', 'pizza', 'refrigerator', 'remote',
'sandwich', 'scissors', 'sink', 'skateboard', 'skis',
'snowboard', 'spoon', 'sports ball', 'stop sign',
'suitcase', 'surfboard', 'teddy bear', 'tennis racket',
'tie', 'toaster', 'toilet', 'toothbrush', 'traffic light',
'truck', 'umbrella', 'vase', 'wine glass', 'zebra']

COCO_CATS = COCO_VOC_CATS+COCO_NONVOC_CATS

coco_ids = {'airplane': 5, 'apple': 53, 'backpack': 27, 'banana': 52,
'baseball bat': 39, 'baseball glove': 40, 'bear': 23, 'bed': 65,
'bench': 15, 'bicycle': 2, 'bird': 16, 'boat': 9, 'book': 84,
'bottle': 44, 'bowl': 51, 'broccoli': 56, 'bus': 6, 'cake': 61,
'car': 3, 'carrot': 57, 'cat': 17, 'cell phone': 77, 'chair': 62,
'clock': 85, 'couch': 63, 'cow': 21, 'cup': 47, 'dining table':
67, 'dog': 18, 'donut': 60, 'elephant': 22, 'fire hydrant': 11,
'fork': 48, 'frisbee': 34, 'giraffe': 25, 'hair drier': 89,
'handbag': 31, 'horse': 19, 'hot dog': 58, 'keyboard': 76, 'kite':
38, 'knife': 49, 'laptop': 73, 'microwave': 78, 'motorcycle': 4,
'mouse': 74, 'orange': 55, 'oven': 79, 'parking meter': 14,
'person': 1, 'pizza': 59, 'potted plant': 64, 'refrigerator': 82,
'remote': 75, 'sandwich': 54, 'scissors': 87, 'sheep': 20, 'sink':
81, 'skateboard': 41, 'skis': 35, 'snowboard': 36, 'spoon': 50,
'sports ball': 37, 'stop sign': 13, 'suitcase': 33, 'surfboard':
42, 'teddy bear': 88, 'tennis racket': 43, 'tie': 32, 'toaster':
80, 'toilet': 70, 'toothbrush': 90, 'traffic light': 10, 'train':
7, 'truck': 8, 'tv': 72, 'umbrella': 28, 'vase': 86, 'wine glass':
46, 'zebra': 24}
coco_ids_to_cats = dict(map(reversed, list(coco_ids.items())))


class COCOLoader(Loader):
cats_to_ids = dict(map(reversed, enumerate(COCO_CATS)))
ids_to_cats = dict(enumerate(COCO_CATS))
num_classes = len(COCO_CATS)
categories = COCO_CATS[1:]

def __init__(self, year, proposals, split, num_proposals=2000, excluded=[], cats=COCO_CATS):
super().__init__()
# TODO support cat reshuffling
self.dataset = 'coco'
self.coco_ids_to_internal = {k: self.cats_to_ids[v] for k, v in coco_ids_to_cats.items()}
self.ids_to_coco_ids = dict(map(reversed, self.coco_ids_to_internal.items()))
self.split = split + year
assert self.split in ['train2014', 'val2014', 'test2014', 'test2015']
self.root = DATASETS_ROOT + 'coco/'
assert proposals in ['mcg', 'edgeboxes']
self.proposals = proposals
self.num_proposals = num_proposals
assert num_proposals >= 0
if excluded == []:
self.included_coco_ids = list(coco_ids.values())
else:
included_internal_ids = [i for i in self.coco_ids_to_internal.values()
if i not in excluded]
self.included_coco_ids = [coco_ids[COCOLoader.ids_to_cats[i]]
for i in included_internal_ids]

self.coco = COCO('%s/annotations/instances_%s.json' % (self.root, self.split))

def load_image(self, img_id, resize=True):
img = self.coco.loadImgs(img_id)[0]
im = cv2.imread('%simages/%s/%s' % (self.root, self.split, img['file_name']))
return self.convert_and_maybe_resize(im, resize)

def read_proposals(self, img_id):
img = self.coco.loadImgs(img_id)[0]
name = img['file_name'][:-4]
if self.proposals == 'edgeboxes':
mat = loadmat('%sEdgeBoxesProposals/%s/%s.mat' % (self.root, self.split, name))
# mat = loadmat('%sEdgeBoxesProposalsSmall/%s/%s.mat' % (self.root, self.split, name))
bboxes = mat['bbs'][:, :4]
if self.proposals == 'selective_search':
raise NotImplementedError
if self.proposals == 'mcg':
mat = loadmat('%sMCGProposals/MCG-COCO-%s-boxes/%s.mat' % (self.root, self.split, name))
bboxes = mat['boxes']
# (y1, x1, y2, x2) -> (x1, y1, w, h)
y1 = bboxes[:, 0]
x1 = bboxes[:, 1]
y2 = bboxes[:, 2]
x2 = bboxes[:, 3]
bboxes = np.stack([x1, y1, x2-x1, y2-y1], axis=1)
# print(bboxes.shape)
if self.num_proposals == 0:
return bboxes
else:
return bboxes[:self.num_proposals]

def get_filenames(self):
# strictly speaking those are not filenames,
# but the usage is consistent in this class
return self.coco.getImgIds()

def get_coco_annotations(self, img_id):
anns = self.coco.loadAnns(self.coco.getAnnIds(
imgIds=img_id, catIds=self.included_coco_ids, iscrowd=False))
return anns

# TODO should I filter categories here?
def read_annotations(self, img_id):
anns = self.get_coco_annotations(img_id)

bboxes = [ann['bbox'] for ann in anns]
cats = [coco_ids_to_cats[ann['category_id']] for ann in anns]
labels = [COCOLoader.cats_to_ids[cat_name] for cat_name in cats]

img = self.coco.loadImgs(img_id)[0]

return np.round(bboxes).astype(np.int32).reshape((-1, 4)), np.array(labels),\
img['width'], img['height'], np.zeros_like(labels, dtype=np.bool)

def _read_segmentation(self, ann, H, W):
s = ann['segmentation']
s = s if type(s) == list else [s]
return mask.decode(mask.frPyObjects(s, H, W)).max(axis=2)


def print_classes_stats(ids, title=""):
imgs = set()
for s in ids[:20]:
imgs = imgs | s
print("{}: {}".format(title, len(imgs)))


if __name__ == '__main__':
root = DATASETS_ROOT + 'coco/'
split = 'train2014'
coco = COCO('%s/annotations/instances_%s.json' % (root, split))
ids = []
for cat in COCO_CATS[1:]:
imgs = coco.getImgIds(catIds=[coco_ids[cat]])
print("{}: {}".format(cat, len(imgs)))
ids.append(set(imgs))
print_classes_stats(ids[:20], "VOC")
print_classes_stats(ids[20:], "NONVOC")
print_classes_stats(ids[20:30], "10 cats after VOC")
43 changes: 43 additions & 0 deletions compute_edgeboxes.m
@@ -0,0 +1,43 @@
addpath(genpath('~/scratch/nn/edges/'));
addpath(genpath('~/scratch/nn/toolbox/'));

% split = 'val2014'
split = 'train2014'
coco_dir = '/scratch2/clear/kshmelko/datasets/coco/';
proposals_dir = [coco_dir 'EdgeBoxesProposalsSmall/' split '/'];
image_dir = [coco_dir 'images/' split '/'];

% voc_dir = '/scratch2/clear/kshmelko/datasets/voc/VOCdevkit/VOC2012/';
% % proposals_dir = [voc_dir 'EBP/'];
% % proposals_dir = [voc_dir 'EdgeBoxesProposals/'];
% image_dir = [voc_dir 'JPEGImages/'];

content = dir([image_dir '*.jpg']);

model = load('~/scratch/nn/edges/models/forest/modelBsds');
model = model.model;
model.opts.multiscale = 1;
model.opts.sharpen = 2;
model.opts.nThreads = 4;

opts = edgeBoxes;
opts.alpha = .65; % step size of sliding window search
opts.beta = .75; % nms threshold for object proposals
opts.minScore = .01; % min score of boxes to detect
opts.maxBoxes = 2000; % max number of boxes to detect
opts.minBoxArea = 50; % min box area
% opts.minBoxArea = 1000; % min box area

for i = 1 : length(content)
if rem(i, 10) == 0
disp(i)
end
img_name = content(i).name;
I = imread([image_dir img_name]);
if length(size(I)) == 2
I = cat(3, I, I, I);
end
% disp(size(I))
bbs = edgeBoxes(I, model, opts);
save([proposals_dir img_name(1:end-3) 'mat'], 'bbs');
end
94 changes: 94 additions & 0 deletions config.py
@@ -0,0 +1,94 @@
import argparse

parser = argparse.ArgumentParser(description='Train or eval a FastRCNN trained on VOC or COCO.')
parser.add_argument("--run_name", type=str, required=True, help='Name of the current run to properly store logs and checkpoints')
parser.add_argument("--ckpt", default=0, type=int, help='Resume training from this checkpoint. Use the most recent one if 0')

parser.add_argument("--dataset", default='voc07', choices=['voc07', 'voc12', 'coco'])
parser.add_argument("--proposals", default='', choices=['mcg', 'edgeboxes', ''], help='Which proposals to use? Empty string means default per-dataset choice (MCG for COCO, EdgeBoxes for VOC).')
parser.add_argument("--num_classes", required=True, type=int, help='Train on this number of classes (first N).')
parser.add_argument("--extend", default=0, type=int, help='Extend existing network by this number of classes incrementally and train on them.')
parser.add_argument("--num_layers", default=56, type=int, help='Number of ResNet layers')
parser.add_argument("--action", required=True, type=str, 'Comma-separated list of actions. Implemented actions: train, eval.')
parser.add_argument("--data_format", default='NHWC', choices=['NHWC', 'NCHW'], help='Data format for conv2d. Using of NCHW gives more cudnn acceleration')
parser.add_argument("--sigmoid", default=False, action='store_true', help='Use sigmoid instead of softmax on the last layer.')
parser.add_argument("--print_step", default=10, type=int, help='Print training logs every N iterations')

# EVALUATION OPTIONS
parser.add_argument("--conf_thresh", default=0.5, type=float, help='Threshold detections with this confidence level.')
parser.add_argument("--nms_thresh", default=0.3, type=float, help='Do NMS on FastFRCNN output with this IoU threshold')
parser.add_argument("--eval_first_n", default=1000000, type=int, help='Only evaluate on first N images from dataset. Useful for COCO, for example')
parser.add_argument("--eval_ckpts", default='', type=str, help='Comma-separated list of checkpoints to evaluate. Supports k as suffix for thousands.')

# TRAINING OPTIONS
parser.add_argument("--batch_size", default=64, type=int, help='Number of proposals per batch')
parser.add_argument("--num_images", default=2, type=int, help='Number of images per batch')
parser.add_argument("--num_positives_in_batch", default=16, type=int, help='Number of positive proposals in the batch.')
parser.add_argument("--pretrained_net", default='', type=str, help='Run name for network we use to extend incrementally')
parser.add_argument("--train_vars", default='', type=str, help='Comma-separated list of substrings. If variable name contains any of them, it is going to be trained. Empty list disables this filtering.')
parser.add_argument("--optimizer", default='nesterov', choices=['adam', 'nesterov', 'sgd', 'momentum'])
parser.add_argument("--weight_decay", default=5e-5, type=float)
parser.add_argument("--learning_rate", default=1e-3, type=float)
parser.add_argument("--lr_decay", default=[], nargs='+', type=int, help='Space-separated list of steps where learning rate decays by factor of 10.')
parser.add_argument("--max_iterations", default=1000000, type=int, help='Total number of SGD steps.')
parser.add_argument("--reset_slots", default=True, type=bool, help='Should we clear out optimizer slots (momentum and Adam stuff) when we extend network?')

# DISTILLATION
# Lambda coefficients balancing each loss term
parser.add_argument("--frcnn_loss_coef", default=1.0, type=float)
parser.add_argument("--class_distillation_loss_coef", default=1.0, type=float)
parser.add_argument("--bbox_distillation_loss_coef", default=1.0, type=float)

parser.add_argument("--distillation", default=False, action='store_true', help='Boolean flag activating distillation')
# TODO make it default?
parser.add_argument("--bias_distillation", default=False, action='store_true', help='Boolean flag activating biased distillation. Requires --distillation flag to work.')
parser.add_argument("--crossentropy", default=False, action='store_true', help='Boolean flag to use crossentropy distillation instead of L2 distillation of logits')
parser.add_argument("--smooth_bbox_distillation", default=True, action='store_true', help='Boolean flag to use smooth L1 bounding box loss for distillation instead of just L2')

# Data loading and preprocessing threads.
parser.add_argument("--num_dataset_readers", default=2, type=int)
parser.add_argument("--num_prep_threads", default=4, type=int)

# deprecated flags, don't do anything in current version
parser.add_argument("--filter_proposals", default=False, action='store_true')
parser.add_argument("--prefetch_all", default=False, action='store_true')

args = parser.parse_args()

LOGS = './logs/'
CKPT_ROOT = './checkpoints/'


def get_logging_config(run):
return {
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'standard': {
'format': '%(asctime)s [%(levelname)s]: %(message)s'
},
'short': {
'format': '[%(levelname)s]: %(message)s'
},
},
'handlers': {
'default': {
'level': 'INFO',
'formatter': 'short',
'class': 'logging.StreamHandler',
},
'file': {
'level': 'DEBUG',
'formatter': 'standard',
'class': 'logging.FileHandler',
'filename': LOGS+run+'.log'
},
},
'loggers': {
'': {
'handlers': ['default', 'file'],
'level': 'DEBUG',
'propagate': True
},
}
}

0 comments on commit f829412

Please sign in to comment.