Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cityscapes dataset #1037

Merged
merged 15 commits into from
Jul 27, 2019
11 changes: 11 additions & 0 deletions INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,21 @@ mmdetection
│ │ ├── train2017
│ │ ├── val2017
│ │ ├── test2017
│ ├── cityscapes
│ │ ├── annotations
│ │ ├── train
│ │ ├── val
│ ├── VOCdevkit
│ │ ├── VOC2007
│ │ ├── VOC2012

```
The cityscapes annotations have to be converted into the coco format using the [cityscapesScripts](https://github.com/mcordts/cityscapesScripts) toolbox.
We plan to provide an easy to use conversion script. For the moment we recommend following the instructions provided in the
[maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark/tree/master/maskrcnn_benchmark/data) toolbox. When using this script all images have to be moved into the same folder. On linux systems this can e.g. be done for the train images with:
```shell
cd data/cityscapes/
mv train/*/* train/
```

### Scripts
Expand Down
28 changes: 28 additions & 0 deletions configs/cityscapes/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
## Common settings

- All baselines were trained using 8 GPU with a batch size of 8 (1 images per GPU) using the [linear scaling rule](https://arxiv.org/abs/1706.02677) to scale the learning rate.
- All models were trained on `cityscapes_train`, and tested on `cityscapes_val`.
- 1x training schedule indicates 64 epochs which corresponds to slightly less than the 24k iterations reported in the original schedule from the [Mask R-CNN paper](https://arxiv.org/abs/1703.06870)
- All pytorch-style pretrained backbones on ImageNet are from PyTorch model zoo.


## Baselines

Download links and more models with different backbones and training schemes will be added to the model zoo.


### Faster R-CNN

| Backbone | Style | Lr schd | Scale | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
| :-------------: | :-----: | :-----: | :---: | :------: | :-----------------: | :------------: | :----: | :------: |
| R-50-FPN | pytorch | 1x | 800-1024 | 4.9 | 0.345 | 8.8 | 36.0 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/cityscapes/faster_rcnn_r50_fpn_1x_city_20190727-7b9c0534.pth) |

### Mask R-CNN

| Backbone | Style | Lr schd | Scale | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | mask AP | Download |
| :-------------: | :-----: | :-----: | :------: | :------: | :-----------------: | :------------: | :----: | :-----: | :------: |
| R-50-FPN | pytorch | 1x | 800-1024 | 4.9 | 0.609 | 2.5 | 37.4 | 32.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/cityscapes/mask_rcnn_r50_fpn_1x_city_20190727-9b3c56a5.pth) |

**Notes:**
- In the original paper, the mask AP of Mask R-CNN R-50-FPN is 31.5.

175 changes: 175 additions & 0 deletions configs/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# model settings
model = dict(
type='FasterRCNN',
pretrained='modelzoo://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_scales=[8],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='SharedFCBBoxHead',
num_fcs=2,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=9,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
# model training and testing settings
train_cfg = dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_across_levels=False,
nms_pre=2000,
nms_post=2000,
max_num=2000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False))
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
nms_pre=1000,
nms_post=1000,
max_num=1000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
# soft-nms is also supported for rcnn testing
# e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
)
# dataset settings
dataset_type = 'CityscapesDataset'
data_root = 'data/cityscapes/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
imgs_per_gpu=1,
workers_per_gpu=2,
train=dict(
type='RepeatDataset', # to avoid reloading datasets frequently
times=8,
dataset=dict(
type=dataset_type,
ann_file=data_root +
'annotations/instancesonly_filtered_gtFine_train.json',
img_prefix=data_root + 'train/',
img_scale=[(2048, 800), (2048, 1024)],
img_norm_cfg=img_norm_cfg,
multiscale_mode='range',
size_divisor=32,
flip_ratio=0.5,
with_mask=False,
with_crowd=True,
with_label=True)),
val=dict(
type=dataset_type,
ann_file=data_root +
'annotations/instancesonly_filtered_gtFine_val.json',
img_prefix=data_root + 'val/',
img_scale=(2048, 1024),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_crowd=True,
with_label=True),
test=dict(
type=dataset_type,
ann_file=data_root +
'annotations/instancesonly_filtered_gtFine_val.json',
img_prefix=data_root + 'val/',
img_scale=(2048, 1024),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_label=False,
test_mode=True))
# optimizer
# lr is set for a batch size of 8
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[6])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=100,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 8 # actual epoch = 8 * 8 = 64
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes'
load_from = None
resume_from = None
workflow = [('train', 1)]
Loading