[Feature] add DPT head #605

Merged
48 commits merged on Aug 30, 2021

Commits
46dce78
add DPT head
Jun 17, 2021
7b80fd0
[fix] fix init error
Jun 17, 2021
01b3da2
use mmcv function
Jun 18, 2021
e9df435
delete code
Jun 19, 2021
93635c0
merge upstream
Jun 19, 2021
b21ea15
remove transpose clas
Jun 19, 2021
2efb2eb
support NLC output shape
Jun 19, 2021
685644a
Merge branch 'add_vit_output_type' into dpt
Jun 19, 2021
5f877e1
Delete post_process_layer.py
Jun 22, 2021
5ce02d3
add unittest and docstring
Jun 22, 2021
7f7e4a4
Merge branch 'dpt' of https://github.com/xiexinch/mmsegmentation into…
Jun 22, 2021
de5b3a2
merge conflict
Jun 22, 2021
adbfb60
merge upstream master
Jul 5, 2021
31c42bd
rename variables
Jul 5, 2021
bf900b6
fix project error and add unittest
Jul 5, 2021
716863b
match dpt weights
Jul 6, 2021
94bf935
add configs
Jul 6, 2021
d4cd924
fix vit pos_embed bug and dpt feature fusion bug
Jul 7, 2021
ded2834
merge master
Jul 20, 2021
f147aa9
match vit output
Jul 20, 2021
0e4fb4f
fix gelu
Jul 20, 2021
6073dfa
minor change
Jul 20, 2021
1ebb558
update unitest
Jul 20, 2021
b3903ca
fix configs error
Jul 20, 2021
ef87aa5
inference test
Jul 22, 2021
9669d54
remove auxilary
Jul 22, 2021
0363746
use local pretrain
Jul 29, 2021
e1ecf6a
update training results
Aug 11, 2021
0126c24
Merge branch 'master' of https://github.com/open-mmlab/mmsegmentation…
Aug 11, 2021
7726d2b
update yml
Aug 11, 2021
c5593af
update fps and memory test
Aug 12, 2021
30aabc4
update doc
Aug 19, 2021
64e6f64
update readme
Aug 19, 2021
b749507
merge master
Aug 19, 2021
96ce175
add yml
Aug 19, 2021
fa61339
update doc
Aug 19, 2021
55bcd74
remove with_cp
Aug 19, 2021
4b33f6f
update config
Aug 19, 2021
76344cd
update docstring
Aug 19, 2021
94fb8d4
remove dpt-l
Aug 25, 2021
5e56d1b
add init_cfg and modify readme.md
Aug 25, 2021
f4ad2fa
Update dpt_vit-b16.py
Junjun2016 Aug 25, 2021
161d494
zh-n README
Aug 25, 2021
6b506ba
Merge branch 'dpt' of github.com:xiexinch/mmsegmentation into dpt
Aug 25, 2021
dca6387
solve conflict
Aug 30, 2021
a41ce05
use constructor instead of build function
Aug 30, 2021
78b56b1
prevent tensor being modified by ConvModule
Aug 30, 2021
522cdff
fix unittest
Aug 30, 2021
3 changes: 2 additions & 1 deletion mmseg/models/decode_heads/__init__.py
@@ -5,6 +5,7 @@
from .da_head import DAHead
from .dm_head import DMHead
from .dnl_head import DNLHead
from .dpt import DPTHead
from .ema_head import EMAHead
from .enc_head import EncHead
from .fcn_head import FCNHead
@@ -24,5 +25,5 @@
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead'
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'DPTHead'
]
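
Note: together with the @HEADS.register_module() decorator on the new class, this export makes the head buildable from a type='DPTHead' config entry. A quick sanity check, assuming this PR's branch of mmsegmentation is installed, might look like the snippet below (illustrative only, not part of the diff):

# Sanity check that the new head is importable and registered; assumes this
# PR's branch of mmsegmentation is installed in the environment.
from mmseg.models import HEADS
from mmseg.models.decode_heads import DPTHead

assert HEADS.get('DPTHead') is DPTHead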
199 changes: 199 additions & 0 deletions mmseg/models/decode_heads/dpt.py
@@ -0,0 +1,199 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (Conv2d, ConvModule, ConvTranspose2d,
build_activation_layer, build_norm_layer)

from ..builder import HEADS
from ..utils import Transpose, _make_readout_ops
from .decode_head import BaseDecodeHead


class ViTPostProcessBlock(nn.Module):

    """Reassemble the ViT token outputs into multi-scale feature maps."""
def __init__(self,
in_channels=768,
out_channels=[96, 192, 384, 768],
img_size=[384, 384],
readout_type='ignore',
start_index=1,
kernel_sizes=[4, 2, 1, 3],
strides=[4, 2, 1, 2],
paddings=[0, 0, 0, 1]):
super(ViTPostProcessBlock, self).__init__()

self.readout_ops = _make_readout_ops(in_channels, out_channels,
readout_type, start_index)

self.unflatten_size = torch.Size(
[img_size[0] // 16, img_size[1] // 16])

        # Use a ModuleList so the post-process layers are registered as
        # submodules (a plain Python list would hide their parameters).
        self.post_process_ops = nn.ModuleList()
        for idx, channel in enumerate(out_channels):
            self.post_process_ops.append(
                nn.Sequential(
                    self.readout_ops[idx], Transpose(1, 2),
                    nn.Unflatten(2, self.unflatten_size),
                    Conv2d(in_channels, channel, kernel_size=1),
                    ConvTranspose2d(
                        channel,
                        channel,
                        kernel_size=kernel_sizes[idx],
                        stride=strides[idx],
                        padding=paddings[idx])))

def forward(self, inputs):
assert len(inputs) == len(self.readout_ops)
for idx, x in enumerate(inputs):
inputs[idx] = self.post_process_ops[idx](x)
return inputs


class ResidualConvUnit(nn.Module):
    """Residual block with two 3x3 convolutions and optional norm layers."""

def __init__(self, in_channels, act_cfg=None, norm_cfg=None):
super(ResidualConvUnit, self).__init__()
self.channels = in_channels

self.activation = build_activation_layer(act_cfg)
self.bn = False if norm_cfg is None else True
self.bias = not self.bn

self.conv1 = Conv2d(
self.channels,
self.channels,
kernel_size=3,
padding=1,
bias=self.bias)

self.conv2 = Conv2d(
self.channels,
self.channels,
kernel_size=3,
padding=1,
bias=self.bias)

        if self.bn:
            # build_norm_layer returns a (name, module) tuple; keep the module.
            self.bn1 = build_norm_layer(norm_cfg, self.channels)[1]
            self.bn2 = build_norm_layer(norm_cfg, self.channels)[1]

def forward(self, inputs):
x = self.activation(inputs)
x = self.conv1(x)
if self.bn:
x = self.bn1(x)

x = self.activation(x)
x = self.conv2(x)
if self.bn:
x = self.bn2(x)

return x + inputs


class FeatureFusionBlock(nn.Module):

def __init__(self,
in_channels,
act_cfg=None,
norm_cfg=None,
deconv=False,
expand=False,
align_corners=True):
super(FeatureFusionBlock, self).__init__()

self.in_channels = in_channels
self.expand = expand
self.deconv = deconv
self.align_corners = align_corners

self.out_channels = in_channels
if self.expand:
self.out_channels = in_channels // 2

self.out_conv = Conv2d(
self.in_channels, self.out_channels, kernel_size=1)

self.res_conv_unit1 = ResidualConvUnit(self.in_channels, act_cfg,
norm_cfg)
self.res_conv_unit2 = ResidualConvUnit(self.in_channels, act_cfg,
norm_cfg)

def forward(self, *inputs):
x = inputs[0]
if len(inputs) == 2:
x = x + self.res_conv_unit1(inputs[1])
x = self.res_conv_unit2(x)
x = F.interpolate(
x,
scale_factor=2,
mode='bilinear',
align_corners=self.align_corners)
return self.out_conv(x)


@HEADS.register_module()
class DPTHead(BaseDecodeHead):
"""Vision Transformers for Dense Prediction.

    This head is the implementation of `DPT <https://arxiv.org/abs/2103.13413>`_.

    Args:
        img_size (list[int]): Size of the input image. Default: [384, 384].
        out_channels (list[int]): Channels of the feature maps produced by
            the post-process block. Default: [96, 192, 384, 768].
        readout_type (str): How the ViT readout (class) token is handled,
            one of 'ignore', 'add' or 'project'. Default: 'ignore'.
        patch_start_index (int): Index of the first patch token in the ViT
            output sequence. Default: 1.
        post_process_kernel_size (list[int]): Kernel sizes of the resize
            layers in the post-process block. Default: [4, 2, 1, 3].
        post_process_strides (list[int]): Strides of the resize layers in
            the post-process block. Default: [4, 2, 1, 2].
        post_process_paddings (list[int]): Paddings of the resize layers in
            the post-process block. Default: [0, 0, 0, 1].
        expand_channels (bool): Whether to expand the channels of the
            fusion stages. Default: False.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        norm_cfg (dict): Config of normalization layers.
            Default: dict(type='BN').
    """

def __init__(self,
img_size=[384, 384],
out_channels=[96, 192, 384, 768],
readout_type='ignore',
patch_start_index=1,
post_process_kernel_size=[4, 2, 1, 3],
post_process_strides=[4, 2, 1, 2],
post_process_paddings=[0, 0, 0, 1],
expand_channels=False,
act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='BN'),
                 **kwargs):
        super(DPTHead, self).__init__(**kwargs)

self.out_channels = out_channels
self.expand_channels = expand_channels
self.post_process_block = ViTPostProcessBlock(
self.channels, out_channels, img_size, readout_type,
patch_start_index, post_process_kernel_size, post_process_strides,
post_process_paddings)

out_channels = [
channel * math.pow(2, idx) if expand_channels else channel
for idx, channel in enumerate(self.out_channels)
]
        # ModuleList keeps the per-stage convolutions registered as submodules.
        self.convs = nn.ModuleList()
for idx, channel in enumerate(self.out_channels):
self.convs.append(
Conv2d(
channel, self.out_channels[idx], kernel_size=3, padding=1))

self.refinenet0 = FeatureFusionBlock(self.channels, act_cfg, norm_cfg)
self.refinenet1 = FeatureFusionBlock(self.channels, act_cfg, norm_cfg)
self.refinenet2 = FeatureFusionBlock(self.channels, act_cfg, norm_cfg)
self.refinenet3 = FeatureFusionBlock(self.channels, act_cfg, norm_cfg)

self.conv = ConvModule(
self.channels, self.channels, kernel_size=3, padding=1)

def forward(self, inputs):
x = self._transform_inputs(inputs)
x = self.post_process_block(x)

x = [self.convs[idx](feature) for idx, feature in enumerate(x)]

path_3 = self.refinenet3(x[3])
path_2 = self.refinenet2(path_3, x[2])
path_1 = self.refinenet1(path_2, x[1])
path_0 = self.refinenet0(path_1, x[0])

x = self.conv(path_0)
output = self.cls_seg(x)
return output
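
Note: for reviewers who want to try the head locally, a minimal decode_head config sketch follows. It only mirrors the constructor defaults above; the in_channels, channels, num_classes and input_transform values are illustrative assumptions, not the configs added later in this PR.

# Hypothetical decode_head config for DPTHead; all values are illustrative
# assumptions chosen to match the defaults of the class above.
decode_head = dict(
    type='DPTHead',
    in_channels=(768, 768, 768, 768),  # four ViT stages, embedding dim 768
    in_index=(0, 1, 2, 3),
    input_transform='multiple_select',
    channels=768,
    img_size=[384, 384],
    out_channels=[96, 192, 384, 768],
    readout_type='ignore',
    num_classes=150,
    norm_cfg=dict(type='BN'),
    act_cfg=dict(type='ReLU'))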
4 changes: 3 additions & 1 deletion mmseg/models/utils/__init__.py
@@ -1,6 +1,7 @@
from .drop import DropPath
from .inverted_residual import InvertedResidual, InvertedResidualV3
from .make_divisible import make_divisible
from .post_process_layer import Transpose, _make_readout_ops
from .res_layer import ResLayer
from .se_layer import SELayer
from .self_attention_block import SelfAttentionBlock
@@ -9,5 +10,6 @@

__all__ = [
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath', 'trunc_normal_'
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath',
'trunc_normal_', '_make_readout_ops', 'Transpose'
]
66 changes: 66 additions & 0 deletions mmseg/models/utils/post_process_layer.py
@@ -0,0 +1,66 @@
import torch
import torch.nn as nn


class Readout(nn.Module):

def __init__(self, start_index=1):
super(Readout, self).__init__()
self.start_index = start_index


class Slice(Readout):

def forward(self, x):
return x[:, self.start_index:]


class AddReadout(Readout):

def forward(self, x):
if self.start_index == 2:
readout = (x[:, 0] + x[:, 1]) / 2
else:
readout = x[:, 0]
return x[:, self.start_index:] + readout.unsqueeze(1)


class ProjectReadout(Readout):

def __init__(self, in_channels, start_index=1):
super().__init__(start_index=start_index)
self.project = nn.Sequential(
            nn.Linear(2 * in_channels, in_channels), nn.GELU())

def forward(self, x):
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
features = torch.cat((x[:, self.start_index:], readout), -1)
return self.project(features)


def _make_readout_ops(channels, out_channels, readout_type, start_index):
if readout_type == 'ignore':
readout_ops = [Slice(start_index) for _ in out_channels]
elif readout_type == 'add':
readout_ops = [AddReadout(start_index) for _ in out_channels]
elif readout_type == 'project':
readout_ops = [
ProjectReadout(channels, start_index) for _ in out_channels
]
    else:
        raise ValueError("unexpected readout operation type, expected "
                         f"'ignore', 'add' or 'project', but got {readout_type}")

return readout_ops


class Transpose(nn.Module):

def __init__(self, dim0, dim1):
super(Transpose, self).__init__()
self.dim0 = dim0
self.dim1 = dim1

def forward(self, x):
x = x.transpose(self.dim0, self.dim1)
return x
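
Note: to make the readout behavior concrete, here is a small standalone sketch (not part of the diff) of what the ops returned by _make_readout_ops do to the ViT class token; the batch size and token count are arbitrary assumptions.

# Standalone sketch of the readout ops; shapes are arbitrary assumptions.
import torch

from mmseg.models.utils import _make_readout_ops

# 2 images, 1 class token + 24 * 24 patch tokens, embedding dim 768.
tokens = torch.randn(2, 1 + 24 * 24, 768)
ops = _make_readout_ops(
    channels=768,
    out_channels=[96, 192, 384, 768],
    readout_type='ignore',
    start_index=1)
# 'ignore' simply slices the class token away: torch.Size([2, 576, 768]).
print(ops[0](tokens).shape)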