
Add cityscapes dataset (#1037)
* added cityscapes

* updated configs

* removed wip configs

* Add initial dataset instructions

* Add cityscapes readme

* Add explanation for lr scaling

* Ensure pep8 conformity

* Add CityscapesDataset to the registry

* add benchmark

* rename config, modify README.md

* fix typo

* fix typo in config

* modify INSTALL.md

Update information on how to arrange cityscapes data.

* Add cityscapes class names
michaelisc authored and hellock committed Jul 27, 2019
1 parent f97d361 commit 1c28e66
Showing 7 changed files with 426 additions and 4 deletions.
11 changes: 11 additions & 0 deletions INSTALL.md
@@ -66,10 +66,21 @@ mmdetection
│   │   ├── train2017
│   │   ├── val2017
│   │   ├── test2017
│   ├── cityscapes
│   │   ├── annotations
│   │   ├── train
│   │   ├── val
│   ├── VOCdevkit
│   │   ├── VOC2007
│   │   ├── VOC2012
```
The Cityscapes annotations have to be converted into COCO format using the [cityscapesScripts](https://github.com/mcordts/cityscapesScripts) toolbox.
We plan to provide an easy-to-use conversion script; for the moment, we recommend following the instructions provided in the
[maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark/tree/master/maskrcnn_benchmark/data) toolbox. When using this script, all images have to be moved into the same folder. On Linux systems this can be done for the train images with:
```shell
cd data/cityscapes/
mv train/*/* train/
```
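
Once converted, the annotation file should load with the standard COCO API. The snippet below is an optional sanity check (not part of this commit), assuming `pycocotools` is installed and the train annotations sit at the path used by the cityscapes configs:

```python
# Optional sanity check for the converted annotations (assumption: the JSON
# was produced by the maskrcnn-benchmark conversion script and placed under
# data/cityscapes/annotations/ as expected by the configs in this commit).
from pycocotools.coco import COCO

ann_file = 'data/cityscapes/annotations/instancesonly_filtered_gtFine_train.json'
coco = COCO(ann_file)

# Cityscapes instance segmentation has 8 "thing" classes.
print('categories:', [c['name'] for c in coco.loadCats(coco.getCatIds())])
print('images:', len(coco.getImgIds()))
print('annotations:', len(coco.getAnnIds()))
```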

### Scripts
28 changes: 28 additions & 0 deletions configs/cityscapes/README.md
@@ -0,0 +1,28 @@
## Common settings

- All baselines were trained using 8 GPUs with a batch size of 8 (1 image per GPU), applying the [linear scaling rule](https://arxiv.org/abs/1706.02677) to scale the learning rate (see the sketch below this list).
- All models were trained on `cityscapes_train` and tested on `cityscapes_val`.
- The 1x training schedule indicates 64 epochs, which corresponds to slightly less than the 24k iterations reported for the original schedule in the [Mask R-CNN paper](https://arxiv.org/abs/1703.06870).
- All pytorch-style pretrained backbones on ImageNet are from the PyTorch model zoo.
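
The learning rate used here follows directly from that rule. The snippet below is a small illustration, not code from this commit; the base values are an assumption taken from mmdetection's default COCO configs (`lr=0.02` at a total batch size of 16, i.e. 8 GPUs with 2 images each):

```python
# Linear scaling rule: the learning rate scales proportionally with the
# total batch size. Base values are assumed (lr=0.02 at batch size 16).
def scale_lr(base_lr=0.02, base_batch_size=16, batch_size=8):
    return base_lr * batch_size / base_batch_size

print(scale_lr())  # 0.01, matching `lr=0.01` in the cityscapes configs
```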


## Baselines

More models with different backbones and training schemes will be added to the model zoo, together with their download links.


### Faster R-CNN

| Backbone | Style | Lr schd | Scale | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
| :-------------: | :-----: | :-----: | :---: | :------: | :-----------------: | :------------: | :----: | :------: |
| R-50-FPN | pytorch | 1x | 800-1024 | 4.9 | 0.345 | 8.8 | 36.0 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/cityscapes/faster_rcnn_r50_fpn_1x_city_20190727-7b9c0534.pth) |

### Mask R-CNN

| Backbone | Style | Lr schd | Scale | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | mask AP | Download |
| :-------------: | :-----: | :-----: | :------: | :------: | :-----------------: | :------------: | :----: | :-----: | :------: |
| R-50-FPN | pytorch | 1x | 800-1024 | 4.9 | 0.609 | 2.5 | 37.4 | 32.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/cityscapes/mask_rcnn_r50_fpn_1x_city_20190727-9b3c56a5.pth) |

**Notes:**
- In the original paper, the mask AP of Mask R-CNN R-50-FPN is 31.5.
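
To try the released checkpoints from the tables above, the usual mmdetection inference helpers can be used. This is a usage sketch rather than part of this commit; it assumes the high-level API shipped with mmdetection at the time (`init_detector`, `inference_detector`, `show_result`), a locally downloaded checkpoint, and a hypothetical example image `demo_city.png`:

```python
from mmdet.apis import init_detector, inference_detector, show_result

config_file = 'configs/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes.py'
# Checkpoint downloaded from the Faster R-CNN row of the table above.
checkpoint_file = 'faster_rcnn_r50_fpn_1x_city_20190727-7b9c0534.pth'

model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, 'demo_city.png')  # hypothetical test image
show_result('demo_city.png', result, model.CLASSES)
```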

175 changes: 175 additions & 0 deletions configs/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes.py
@@ -0,0 +1,175 @@
# model settings
model = dict(
    type='FasterRCNN',
    pretrained='modelzoo://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64],
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
    bbox_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    bbox_head=dict(
        type='SharedFCBBoxHead',
        num_fcs=2,
        in_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        num_classes=9,
        target_means=[0., 0., 0., 0.],
        target_stds=[0.1, 0.1, 0.2, 0.2],
        reg_class_agnostic=False,
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
# model training and testing settings
train_cfg = dict(
    rpn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.7,
            neg_iou_thr=0.3,
            min_pos_iou=0.3,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False),
        allowed_border=0,
        pos_weight=-1,
        debug=False),
    rpn_proposal=dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=2000,
        max_num=2000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.5,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=512,
            pos_fraction=0.25,
            neg_pos_ub=-1,
            add_gt_as_proposals=True),
        pos_weight=-1,
        debug=False))
test_cfg = dict(
    rpn=dict(
        nms_across_levels=False,
        nms_pre=1000,
        nms_post=1000,
        max_num=1000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
    # soft-nms is also supported for rcnn testing
    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
)
# dataset settings
dataset_type = 'CityscapesDataset'
data_root = 'data/cityscapes/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
    imgs_per_gpu=1,
    workers_per_gpu=2,
    train=dict(
        type='RepeatDataset',  # to avoid reloading datasets frequently
        times=8,
        dataset=dict(
            type=dataset_type,
            ann_file=data_root +
            'annotations/instancesonly_filtered_gtFine_train.json',
            img_prefix=data_root + 'train/',
            img_scale=[(2048, 800), (2048, 1024)],
            img_norm_cfg=img_norm_cfg,
            multiscale_mode='range',
            size_divisor=32,
            flip_ratio=0.5,
            with_mask=False,
            with_crowd=True,
            with_label=True)),
    val=dict(
        type=dataset_type,
        ann_file=data_root +
        'annotations/instancesonly_filtered_gtFine_val.json',
        img_prefix=data_root + 'val/',
        img_scale=(2048, 1024),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0,
        with_mask=False,
        with_crowd=True,
        with_label=True),
    test=dict(
        type=dataset_type,
        ann_file=data_root +
        'annotations/instancesonly_filtered_gtFine_val.json',
        img_prefix=data_root + 'val/',
        img_scale=(2048, 1024),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0,
        with_mask=False,
        with_label=False,
        test_mode=True))
# optimizer
# lr is set for a batch size of 8
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[6])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=100,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 8  # actual epoch = 8 * 8 = 64
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes'
load_from = None
resume_from = None
workflow = [('train', 1)]