Skip to content

Commit

Permalink
CodeCamp open-mmlab#151[Feature] Support HieraSeg on cityscapes (open…
Browse files Browse the repository at this point in the history
…-mmlab#2444)

## Support `HieraSeg` interface on `cityscapes`

## Motivation

Support `HieraSeg` interface on cityscapes dataset  
Paper link : https://ieeexplore.ieee.org/document/9878466/

```
@Article{li2022deep,
  title={Deep Hierarchical Semantic Segmentation},
  author={Li, Liulei and Zhou, Tianfei and Wang, Wenguan and Li, Jianwu and Yang, Yi},
  journal={CVPR},
  year={2022}
}
```

## Modification

Add `HieraSeg_Projects` on `projects/`
Add `sep_aspp_contrast_head` decoder head.
Add `HieraSeg` config.
Add `hiera_loss`, `hiera_triplet_loss_cityscape`, `tree_triplet_loss`
  • Loading branch information
AI-Tianlong committed Jan 12, 2023
1 parent 6521e36 commit 933d77a
Show file tree
Hide file tree
Showing 11 changed files with 765 additions and 0 deletions.
93 changes: 93 additions & 0 deletions projects/HieraSeg/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# HieraSeg

Support `Deep Hierarchical Semantic Segmentation` interface on `cityscapes`

## Description

Author: AI-Tianlong

This project implements `HieraSeg` inference in the `cityscapes` dataset

## Usage

### Prerequisites

- Python 3.8
- PyTorch 1.6 or higher
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc3
- mmcv v2.0.0rc3
- mmengine

### Dataset preparing

preparing `cityscapes` dataset like this [structure](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#prepare-datasets)

### Testing commands

please put [`hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth`](https://download.openmmlab.com/mmsegmentation/v0.5/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth) to `mmsegmentation/checkpoints`

#### Multi-GPUs Test

```bash
# --tta optional, multi-scale test, need mmengine >=0.4.0
bash tools/dist_test.sh [configs] [model weights] [number of gpu] --tta
```

#### Example

```shell
bash tools/dist_test.sh projects/HieraSeg_project/configs/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80l_cityscapes-512x1024.py checkpoints/hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth 2 --tta
```

## Results

### Cityscapes

| Method | Backbone | Crop Size | mIoU | mIoU (ms+flip) | config | model |
| :--------: | :------: | :-------: | :---: | :------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------: |
| DeeplabV3+ | R-101-D8 | 512x1024 | 81.61 | 82.71 | [config](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/HieraSeg/configs/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80l_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth) |

<img src="https://user-images.githubusercontent.com/50650583/210488953-e3e35ade-1132-47e1-9dfd-cf12b357ae80.png" width="50%"><img src="https://user-images.githubusercontent.com/50650583/210489746-e35ee229-3234-4292-a649-a8cd85f312ad.png" width="50%">

## Citation

This project is modified from [qhanghu/HSSN_pytorch](https://github.com/qhanghu/HSSN_pytorch)

```bibtex
@article{li2022deep,
title={Deep Hierarchical Semantic Segmentation},
author={Li, Liulei and Zhou, Tianfei and Wang, Wenguan and Li, Jianwu and Yang, Yi},
journal={CVPR},
year={2022}
}
```

## Checklist

- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.

- [x] Finish the code

- [x] Basic docstrings & proper citation

- [x] Test-time correctness

- [x] A full README

- [ ] Milestone 2: Indicates a successful model implementation.

- [ ] Training-time correctness

- [ ] Milestone 3: Good to be a part of our core package!

- [ ] Type hints and docstrings

- [ ] Unit tests

- [ ] Code polishing

- [ ] Metafile.yml

- [ ] Move your modules into the core package following the codebase's file hierarchy structure.

- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
67 changes: 67 additions & 0 deletions projects/HieraSeg/configs/_base_/datasets/cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# dataset settings
dataset_type = 'CityscapesDataset'
data_root = 'data/cityscapes/'
crop_size = (512, 1024)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomResize',
scale=(2048, 1024),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict(
batch_size=2,
num_workers=2,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='leftImg8bit/train', seg_map_path='gtFine/train'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='leftImg8bit/val', seg_map_path='gtFine/val'),
pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
15 changes: 15 additions & 0 deletions projects/HieraSeg/configs/_base_/default_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
default_scope = 'mmseg'
env_cfg = dict(
cudnn_benchmark=True,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(by_epoch=False)
log_level = 'INFO'
load_from = None
resume = False

tta_model = dict(type='SegTTAModel')
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255)
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
pretrained=None,
backbone=dict(
type='ResNetV1d',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='DepthwiseSeparableASPPContrastHead',
in_channels=2048,
in_index=3,
channels=512,
dilations=(1, 12, 24, 36),
c1_in_channels=256,
c1_channels=48,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
proj='convmlp',
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
24 changes: 24 additions & 0 deletions projects/HieraSeg/configs/_base_/schedules/schedule_80k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
dict(
type='PolyLR',
eta_min=1e-4,
power=0.9,
begin=0,
end=80000,
by_epoch=False)
]
# training schedule for 80k
train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=8000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=8000),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
_base_ = [
'../_base_/models/deeplabv3plus_r50-d8_vd_contrast.py',
'../_base_/datasets/cityscapes.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_80k.py'
]

custom_imports = dict(imports=[
'projects.HieraSeg.decode_head.sep_aspp_contrast_head',
'projects.HieraSeg.losses.hiera_triplet_loss_cityscape'
])

model = dict(
pretrained=None,
backbone=dict(depth=101),
decode_head=dict(
num_classes=26,
loss_decode=dict(
type='HieraTripletLossCityscape', num_classes=19,
loss_weight=1.0)),
auxiliary_head=dict(num_classes=19),
test_cfg=dict(mode='whole', is_hiera=True, hiera_num_classes=7))
4 changes: 4 additions & 0 deletions projects/HieraSeg/decode_head/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .sep_aspp_contrast_head import DepthwiseSeparableASPPContrastHead

__all__ = ['DepthwiseSeparableASPPContrastHead']
Loading

0 comments on commit 933d77a

Please sign in to comment.