Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] add SFSegNet head #733

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions configs/_base_/datasets/cityscapes_pd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Dataset settings: Cityscapes, normalised with mean = std = 127.5 per channel.
dataset_type = 'CityscapesDataset'
data_root = 'data/cityscapes/'
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5],
    std=[127.5, 127.5, 127.5],
    to_rgb=True,
)
crop_size = (512, 1024)

# Training: load -> random rescale -> crop -> flip -> colour jitter ->
# normalise -> pad to the crop size -> pack into tensors.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='Resize',
        img_scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
    ),
    dict(
        type='RandomCrop',
        crop_size=crop_size,
        cat_max_ratio=0.75,
    ),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(
        type='Pad',
        size=crop_size,
        pad_val=0,
        seg_pad_val=255,
    ),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]

# Testing: single scale, no flipping.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 1024),
        # Multi-scale testing is disabled; to enable it, add
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75] here.
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ],
    ),
]

# The `val` and `test` splits both read the Cityscapes validation set.
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='leftImg8bit/train',
        ann_dir='gtFine/train',
        pipeline=train_pipeline,
    ),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='leftImg8bit/val',
        ann_dir='gtFine/val',
        pipeline=test_pipeline,
    ),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='leftImg8bit/val',
        ann_dir='gtFine/val',
        pipeline=test_pipeline,
    ),
)
35 changes: 35 additions & 0 deletions configs/_base_/models/sfnet_r50-d8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Model settings: SFNet head on a ResNet-50 (v1c) backbone.
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        # All four stages are exported; the decode head consumes every one.
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        # NOTE(review): this file is named `d8`, yet stages 2-4 all use
        # stride 2 and the configs that inherit it are named `d32`;
        # presumably the output stride is 32 - confirm the file name and
        # whether dilation is intended together with these strides.
        strides=(1, 2, 2, 2),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=False),
    decode_head=dict(
        type='SFNetHead',
        in_channels=2048,
        in_index=3,
        channels=256,
        pool_scales=(1, 2, 3, 6),
        # Channel widths of the four backbone stages, shallowest first.
        fpn_inplanes=[256, 512, 1024, 2048],
        fpn_dim=256,
        dropout_ratio=0,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),

    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
32 changes: 32 additions & 0 deletions configs/_base_/models/sfnet_r50-d8_pd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Model settings: SFNet head on a ResNet-50 (v1d) backbone, trained without
# pretrained backbone weights.
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='ResNetV1d',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=False,
    ),
    # `fpn_inplanes` and `fpn_dim` are not set here.
    decode_head=dict(
        type='SFNetHead',
        in_channels=2048,
        in_index=3,
        channels=256,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=1.0,
        ),
    ),
    # Training and testing settings.
    train_cfg=dict(),
    test_cfg=dict(mode='whole'),
)
11 changes: 11 additions & 0 deletions configs/sfnet/sfnet_r18-d32_512x1024_80k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# ResNet-18 variant of the SFNet Cityscapes config: only the keys that differ
# from the ResNet-50 base config are overridden.
_base_ = './sfnet_r50-d32_512x1024_80k_cityscapes.py'
model = dict(
    pretrained='open-mmlab://resnet18_v1c',
    backbone=dict(depth=18),
    decode_head=dict(
        in_channels=512,
        channels=128,
        fpn_inplanes=[64, 128, 256, 512],
        fpn_dim=128))
4 changes: 4 additions & 0 deletions configs/sfnet/sfnet_r50-d32_512x1024_80k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Compose the run from the base model, dataset, runtime and schedule configs.
_base_ = [
    '../_base_/models/sfnet_r50-d8.py',
    '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_80k.py',
]
5 changes: 5 additions & 0 deletions configs/sfnet/sfnet_temp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Compose the run from the `_pd` model and dataset variants plus the default
# runtime and the 80k-iteration schedule.
_base_ = [
    '../_base_/models/sfnet_r50-d8_pd.py',
    '../_base_/datasets/cityscapes_pd.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_80k.py',
]
11 changes: 11 additions & 0 deletions configs/sfnet/sfnet_temp18.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# ResNet-18 override of `sfnet_temp.py`: swaps the backbone type, depth and
# strides, and shrinks the decode head to the narrower feature widths.
_base_ = './sfnet_temp.py'
model = dict(
    pretrained=None,
    backbone=dict(
        type='ResNetV1c',
        depth=18,
        strides=(1, 2, 2, 2)),
    decode_head=dict(
        in_channels=512,
        channels=128,
        fpn_inplanes=[64, 128, 256, 512],
        fpn_dim=128))
4 changes: 3 additions & 1 deletion mmseg/models/decode_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
from .sep_fcn_head import DepthwiseSeparableFCNHead
from .setr_mla_head import SETRMLAHead
from .setr_up_head import SETRUPHead
from .sfnet_head import SFNetHead
from .uper_head import UPerHead

# Public decode heads re-exported by this package. Every entry is followed by
# a comma so adjacent string literals are never implicitly concatenated.
__all__ = [
    'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
    'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
    'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
    'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
    'SETRMLAHead', 'SFNetHead'
]
220 changes: 220 additions & 0 deletions mmseg/models/decode_heads/sfnet_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner.base_module import BaseModule

from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
from .psp_head import PPM


@HEADS.register_module()
class SFNetHead(BaseDecodeHead):
    """Semantic Flow for Fast and Accurate Scene Parsing.

    This head is the implementation of
    `SFSegNet <https://arxiv.org/pdf/2002.10120>`_.

    A pyramid pooling module is applied to the selected backbone feature.
    Its output is then propagated top-down through the remaining backbone
    features; at each level the coarser feature is warped onto the finer one
    by an :class:`AlignedModule` before the two are summed. All pyramid
    levels are finally upsampled to the finest resolution, concatenated and
    classified.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module. Default: (1, 2, 3, 6).
        fpn_inplanes (list[int] | tuple[int]): Channel numbers of the
            feature maps coming from the backbone, one per stage. The last
            entry is not given a lateral conv.
            Default: (256, 512, 1024, 2048).
        fpn_dim (int, optional): The input channels of FAM module.
            Default: 256 for ResNet50, 128 for ResNet18.
    """

    def __init__(self,
                 pool_scales=(1, 2, 3, 6),
                 fpn_inplanes=(256, 512, 1024, 2048),
                 fpn_dim=256,
                 **kwargs):
        super(SFNetHead, self).__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        # Copy so the head never aliases (or mutates) the caller's sequence.
        self.fpn_inplanes = list(fpn_inplanes)
        self.fpn_dim = fpn_dim

        # NOTE(review): `align_corners=True` is hard-coded for the PPM here
        # and for the fusion upsampling in `forward`, so an `align_corners`
        # value passed through **kwargs does not reach these calls - confirm
        # this is intended.
        ppm_channels = self.in_channels // 4
        self.psp_modules = PPM(
            self.pool_scales,
            self.in_channels,
            ppm_channels,
            bias=True,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=True)
        # The bottleneck sees the input feature concatenated with one PPM
        # branch per pooling scale. With four scales and `in_channels`
        # divisible by 4 this equals `2 * in_channels`.
        self.bottleneck = ConvModule(
            self.in_channels + len(self.pool_scales) * ppm_channels,
            self.channels,
            3,
            padding=1,
            bias=True,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        # 1x1 lateral convs mapping every backbone stage except the last to
        # `fpn_dim` channels.
        self.fpn_in = nn.ModuleList([
            ConvModule(
                fpn_inplane,
                self.fpn_dim,
                kernel_size=1,
                bias=True,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=False) for fpn_inplane in self.fpn_inplanes[:-1]
        ])

        # One 3x3 output conv and one flow alignment module per lateral
        # level. They are built in a single loop so the order in which
        # modules are constructed is unchanged.
        fpn_out = []
        fpn_out_align = []
        for _ in range(len(self.fpn_inplanes) - 1):
            fpn_out.append(
                ConvModule(
                    self.fpn_dim,
                    self.fpn_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    inplace=True))
            fpn_out_align.append(
                AlignedModule(
                    inplane=self.fpn_dim, outplane=self.fpn_dim // 2))
        self.fpn_out = nn.ModuleList(fpn_out)
        self.fpn_out_align = nn.ModuleList(fpn_out_align)

        # Fuses the concatenation of all pyramid levels.
        self.conv_last = ConvModule(
            len(self.fpn_inplanes) * self.fpn_dim,
            self.fpn_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            inplace=True)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (list[Tensor]): Backbone feature maps. Every entry except
                the last is fed to the matching lateral conv in
                ``self.fpn_in``; the feature returned by
                ``self._transform_inputs`` goes through the pyramid pooling
                module.

        Returns:
            Tensor: The output of ``self.cls_seg``, computed at the spatial
                size of the finest pyramid level.
        """
        x = self._transform_inputs(inputs)
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x)[::-1])
        psp_outs = torch.cat(psp_outs, dim=1)
        psp_out = self.bottleneck(psp_outs)

        # Top-down pathway: warp the running coarse feature `f` onto each
        # finer lateral feature, then sum the two.
        f = psp_out
        fpn_feature_list = [psp_out]
        for i in reversed(range(len(inputs) - 1)):
            conv_x = self.fpn_in[i](inputs[i])
            f = self.fpn_out_align[i]([conv_x, f])
            f = conv_x + f
            fpn_feature_list.append(self.fpn_out[i](f))

        fpn_feature_list.reverse()  # [P2 - P5]
        output_size = fpn_feature_list[0].size()[2:]
        # Bring every coarser level up to the finest resolution before
        # concatenating along the channel axis.
        fusion_list = [fpn_feature_list[0]] + [
            nn.functional.interpolate(
                feature, output_size, mode='bilinear', align_corners=True)
            for feature in fpn_feature_list[1:]
        ]

        fusion_out = torch.cat(fusion_list, 1)
        x = self.conv_last(fusion_out)
        output = self.cls_seg(x)

        return output


class AlignedModule(BaseModule):
    """The implementation of Flow Alignment Module (FAM).

    Given a finer feature map and a coarser one, the module predicts a
    2-channel flow field at the finer resolution and uses it to warp the
    coarser feature map onto the finer grid.

    Args:
        inplane (int): The number of FAM input channels.
        outplane (int): The number of FAM output channels, i.e. the width
            both inputs are reduced to before the flow is predicted.
        kernel_size (int): Kernel size of the conv that predicts the flow
            field. Should be odd so that the padding keeps the spatial
            size unchanged. Default: 3.
    """

    def __init__(self, inplane, outplane, kernel_size=3):
        super(AlignedModule, self).__init__()
        self.down_h = nn.Conv2d(inplane, outplane, 1, bias=False)
        self.down_l = nn.Conv2d(inplane, outplane, 1, bias=False)
        # Padding follows the kernel size so the flow field always has the
        # same height/width as the features it is predicted from.
        self.flow_make = nn.Conv2d(
            outplane * 2,
            2,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=False)

    def forward(self, x):
        """Warp the coarser feature onto the finer feature's grid.

        Args:
            x (list[Tensor] | tuple[Tensor]): ``(low_feature, h_feature)``.
                ``low_feature`` defines the output height and width;
                ``h_feature`` is the feature map that gets warped.

        Returns:
            Tensor: ``h_feature`` warped to the height and width of
                ``low_feature``.
        """
        low_feature, h_feature = x
        h_feature_origin = h_feature
        size = tuple(low_feature.size()[2:])
        low_feature = self.down_l(low_feature)
        h_feature = self.down_h(h_feature)
        h_feature = resize(
            h_feature, size=size, mode='bilinear', align_corners=True)
        flow = self.flow_make(torch.cat([h_feature, low_feature], 1))
        # The un-reduced feature is warped, not its channel-reduced copy.
        return self.flow_warp(h_feature_origin, flow, size=size)

    def flow_warp(self, input, flow, size):
        """Implementation of Warp Procedure in Fig 3(b) of original paper,
        which is between the Flow Field and the feature map to be warped.

        Args:
            input (Tensor): Feature map to be warped, of shape (N, C, h, w).
            flow (Tensor): Semantic Flow Field of shape (N, 2, H, W) that
                gives a dynamic indication about how to align the two
                feature maps effectively. Channel 0 is added to the x
                coordinate and channel 1 to the y coordinate.
            size (Tuple): Height and width (H, W) of the output.

        Returns:
            output (Tensor): Feature map of shape (N, C, H, W) after the
                warped offset and bilinear interpolation.

        For example, in cityscapes 1024x2048 dataset with ResNet18 config,
        feature map from backbone is:
            [[1, 64, 256, 512],
            [1, 128, 128, 256],
            [1, 256, 64, 128],
            [1, 512, 32, 64]]

        Thus, its inverse shape of [input, flow, size] is:
            [[1, 128, 32, 64], [1, 2, 64, 128], (64, 128)],
            [[1, 128, 64, 128], [1, 2, 128, 256], (128, 256)], and
            [[1, 128, 128, 256], [1, 2, 256, 512], (256, 512)], respectively.

        The final output is:
            [[1, 128, 64, 128],
            [1, 128, 128, 256],
            [1, 128, 256, 512]], respectively.
        """
        out_h, out_w = size

        # Per-axis divisor (x, y) applied to the predicted flow; created
        # with the dtype and device of `input`.
        norm = input.new_tensor([[[[out_w, out_h]]]])

        # Base sampling grid in normalised coordinates, from -1 to 1, built
        # on the input's device. The last dimension is ordered (x, y).
        ys = torch.linspace(-1.0, 1.0, out_h, device=input.device)
        xs = torch.linspace(-1.0, 1.0, out_w, device=input.device)
        grid = torch.stack(
            (xs.expand(out_h, out_w), ys.view(-1, 1).expand(out_h, out_w)),
            dim=2)
        # Shape (1, H, W, 2); broadcasts over the batch dimension below.
        grid = grid.unsqueeze(0).type_as(input)

        # Warped grid which is corrected by the flow offset.
        grid = grid + flow.permute(0, 2, 3, 1) / norm

        # Sampling mechanism interpolates the values of the 4-neighbors
        # (top-left, top-right, bottom-left, and bottom-right) of input.
        output = nn.functional.grid_sample(input, grid, align_corners=True)
        return output
Empty file added pth2kvlist.py
Empty file.
Loading