
[CodeCamp2023-367] Add pp_mobileseg model #3239

Merged — 24 commits merged into open-mmlab:dev-1.x on Aug 9, 2023

Conversation

Yang-Changhui
Contributor

Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

Motivation

Please describe the motivation of this PR and the goal you want to achieve through this PR.

Modification

Please briefly describe what modification is made in this PR.

BC-breaking (Optional)

Does the modification introduce changes that break the backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
  3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

mm-assistant (bot) commented Jul 31, 2023

We recommend using English or English & Chinese for pull requests so that we could have broader discussion.

CLAassistant commented Jul 31, 2023

CLA assistant check
All committers have signed the CLA.

OpenMMLab-Assistant-004 commented:

Hi @Yang-Changhui,

We'd like to express our appreciation for your valuable contributions to mmsegmentation. Your efforts have significantly aided in enhancing the project's quality.
It is our pleasure to invite you to join our community through the Discord Special Interest Group (SIG) channel. This is a great place to share your experiences, discuss ideas, and connect with other like-minded people. To become a part of the SIG channel, send a message to the moderator, OpenMMLab, briefly introduce yourself and mention your open-source contributions in the #introductions channel. Our team will gladly facilitate your entry. We eagerly await your presence. Please follow this link to join us: https://discord.gg/UjgXkPWNqA.

If you're on WeChat, we'd also love for you to join our community there. Just add our assistant using the WeChat ID: openmmlabwx. When sending the friend request, remember to include the remark "mmsig + Github ID".

Thanks again for your awesome contribution, and we're excited to have you as part of our community!

from mmengine.model import BaseModule
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict

from mmseg.models.utils.transformer_utils import DropPath

Comment on lines 17 to 20
cfg1,
cfg2,
cfg3,
cfg4,
Collaborator:

Generally, we should use meaningful parameter names.

mmseg/models/backbones/strideformer.py — outdated review comment (resolved)
Comment on lines 228 to 259
class ConvBNAct(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 groups=1,
                 conv_cfg=dict(type='Conv'),
                 norm_cfg=dict(type='BN'),
                 act_cfg=None,
                 bias_attr=False):
        super(ConvBNAct, self).__init__()

        self.conv = build_conv_layer(
            conv_cfg,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=None if bias_attr else False)
        self.act = build_activation_layer(
            act_cfg) if act_cfg is not None else nn.Identity()
        self.bn = build_norm_layer(norm_cfg, out_channels)[1] \
            if norm_cfg is not None else nn.Identity()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return x
Collaborator:

Its function is the same as ConvModule's; we might use ConvModule instead.
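A minimal sketch of that replacement, assuming ConvModule from mmcv.cnn; the parameter values below are illustrative, not taken from this PR:

from mmcv.cnn import ConvModule

# ConvModule already chains conv -> norm -> activation, which is what
# ConvBNAct re-implements above.
conv_bn_act = ConvModule(
    in_channels=64,
    out_channels=128,
    kernel_size=1,
    stride=1,
    padding=0,
    groups=1,
    bias=False,                 # mirrors bias_attr=False
    conv_cfg=dict(type='Conv'),
    norm_cfg=dict(type='BN'),
    act_cfg=dict(type='ReLU'))  # pass act_cfg=None to drop the activation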

Comment on lines 262 to 278
class Conv2DBN(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 ks=1,
                 stride=1,
                 pad=0,
                 dilation=1,
                 groups=1):
        super().__init__()
        self.conv_norm = ConvModule(
            in_channels, out_channels, ks, stride, pad, dilation, groups,
            False, norm_cfg=dict(type='BN'), act_cfg=None)

    def forward(self, inputs):
        out = self.conv_norm(inputs)
        return out
Collaborator:

There's no need to define this class.

Comment on lines 367 to 381
        self.conv1 = build_conv_layer(
            conv_cfg,
            in_channels=channel,
            out_channels=channel // reduction,
            kernel_size=1,
            stride=1,
            padding=0)
        self.relu = build_activation_layer(act_cfg)
        self.conv2 = build_conv_layer(
            conv_cfg,
            in_channels=channel // reduction,
            out_channels=channel,
            kernel_size=1,
            stride=1,
            padding=0)
Collaborator:

Might replace these with two ConvModule instances.
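A hedged sketch of that change, wrapped in a throwaway module so it is self-contained; the class name, channel values, and act_cfg default are illustrative only:

import torch.nn as nn
from mmcv.cnn import ConvModule


class SEConvSketch(nn.Module):
    """Illustrative only: folds conv1 + relu and conv2 into two ConvModules."""

    def __init__(self, channel=256, reduction=4, act_cfg=dict(type='ReLU')):
        super().__init__()
        # 1x1 conv followed by the activation, replacing the separate
        # build_conv_layer + build_activation_layer pair.
        self.conv1 = ConvModule(
            channel, channel // reduction, 1, norm_cfg=None, act_cfg=act_cfg)
        # Plain 1x1 conv with no norm and no activation.
        self.conv2 = ConvModule(
            channel // reduction, channel, 1, norm_cfg=None, act_cfg=None)

    def forward(self, x):
        return self.conv2(self.conv1(x))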

Comment on lines 485 to 486
class Sea_Attention(nn.Module):
def __init__(self,
Collaborator:

Might rename to SeaAttention.

Comment on lines 622 to 623
class Fusion_block(nn.Module):
def __init__(self, inp, oup, embed_dim, activations=None) -> None:
Collaborator:

Might rename to FusionBlock.

Collaborator:

Also, parameter names such as inp and oup are not readable; consider changing them to more readable names.

Comment on lines 743 to 753

def _create_act(act):
    if act == "hardswish":
        return nn.Hardswish()
    elif act == "relu":
        return nn.ReLU()
    elif act is None:
        return None
    else:
        raise RuntimeError(
            "The activation function is not supported: {}".format(act))
Collaborator:

Can we replace it with build_activation_layer?

Contributor (author):

It seems build_activation_layer doesn't have a class corresponding to nn.Hardswish().

Collaborator:

Then could you write a separate wrapper class and register it? That way the approach would be more generic.
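A rough sketch of that idea — define a thin wrapper, register it, and build the activation from a config dict instead of branching on strings in _create_act. The registry import and build call here are assumptions and may differ from what the PR finally does:

import torch.nn as nn
from mmseg.registry import MODELS


@MODELS.register_module()
class Hardswish(nn.Module):
    """Thin wrapper so nn.Hardswish can be built from an act_cfg dict."""

    def __init__(self, inplace=False):
        super().__init__()
        self.act = nn.Hardswish(inplace=inplace)

    def forward(self, x):
        return self.act(x)


# Call sites can then build the activation from a config dict, e.g.:
act = MODELS.build(dict(type='Hardswish', inplace=True))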

Comment on lines 756 to 765
@MODELS.register_module()
class MobileSeg_Base(StrideFormer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


@MODELS.register_module()
class MobileSeg_Tiny(StrideFormer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
Collaborator:

We might add StrideFormer to the MODELS registry and delete these two classes.
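A brief sketch of the config side once StrideFormer itself carries @MODELS.register_module(); the field values below are placeholders, not the PR's final settings:

# Variant-specific arguments (what MobileSeg_Base / MobileSeg_Tiny would
# otherwise hard-code) are then passed explicitly per config.
backbone = dict(
    type='StrideFormer',
    mobileV3_cfg=...,
    channels=...,
    embed_dims=...)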

Comment on lines 17 to 33
mobileV3_cfg,
channels,
embed_dims,
key_dims=[16, 24],
depths=[2, 2],
num_heads=8,
attn_ratios=2,
mlp_ratios=[2, 4],
drop_path_rate=0.1,
act_cfg=dict(type='ReLU'),
inj_type='AAM',
out_channels=256,
dims=(128, 160),
out_feat_chs=None,
stride_attention=True,
pretrained=None,
init_cfg=None
Collaborator:

Might add type hints
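A hedged sketch of what the type-hinted signature could look like; the exact hint choices are illustrative, and the mutable list defaults are kept as they appear above:

from typing import List, Optional, Tuple


def __init__(self,
             mobileV3_cfg: List[list],
             channels: List[int],
             embed_dims: List[int],
             key_dims: List[int] = [16, 24],
             depths: List[int] = [2, 2],
             num_heads: int = 8,
             attn_ratios: int = 2,
             mlp_ratios: List[int] = [2, 4],
             drop_path_rate: float = 0.1,
             act_cfg: dict = dict(type='ReLU'),
             inj_type: str = 'AAM',
             out_channels: int = 256,
             dims: Tuple[int, int] = (128, 160),
             out_feat_chs: Optional[List[int]] = None,
             stride_attention: bool = True,
             pretrained: Optional[str] = None,
             init_cfg: Optional[dict] = None) -> None:
    ...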

Comment on lines 305 to 306
class ResidualUnit(nn.Module):
def __init__(self,
Collaborator:

And type hints are needed.

Comment on lines 397 to 399
class SqueezeAxialPositionalEmbedding(nn.Module):
def __init__(self, dim, shape):
super().__init__()
Collaborator:

Might add a docstring.

Comment on lines 414 to 422
class SeaAttention(nn.Module):
def __init__(self,
dim,
key_dim,
num_heads,
attn_ratio=4.,
act_cfg=None,
norm_cfg=dict(type='BN'),
stride_attention=False):
Collaborator:

Might add docstring and type hints.

Comment on lines 719 to 728
class HSigmoid(nn.Module):

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU6()

    def forward(self, x):
        return self.relu(x + 3) / 6


MODELS.register_module(module=HSigmoid, name='HSigmoid')
Collaborator:

Suggested change — replace:

class HSigmoid(nn.Module):

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU6()

    def forward(self, x):
        return self.relu(x + 3) / 6


MODELS.register_module(module=HSigmoid, name='HSigmoid')

with:

@MODELS.register_module()
class HSigmoid(nn.Module):

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU6()

    def forward(self, x):
        return self.relu(x + 3) / 6

Comment on lines 731 to 740
class hardswish(nn.Module):

    def __init__(self, inplace=False):
        super().__init__()
        self.relu = nn.Hardswish(inplace=inplace)

    def forward(self, x):
        return self.relu(x)


MODELS.register_module(module=hardswish, name='hardswish')
Collaborator:

Suggested change — replace:

class hardswish(nn.Module):

    def __init__(self, inplace=False):
        super().__init__()
        self.relu = nn.Hardswish(inplace=inplace)

    def forward(self, x):
        return self.relu(x)


MODELS.register_module(module=hardswish, name='hardswish')

with:

@MODELS.register_module()
class Hardswish(nn.Module):

    def __init__(self, inplace=False):
        super().__init__()
        self.relu = nn.Hardswish(inplace=inplace)

    def forward(self, x):
        return self.relu(x)

Collaborator:

Might add a docstring to explain that it's a wrapper around torch.nn.Hardswish.

Comment on lines 743 to 753
class Hardsigmoid(nn.Module):

    def __init__(self, slope=0.2, offset=0.5, inplace=False):
        super().__init__()
        self.slope = slope
        self.offset = offset

    def forward(self, x):
        return x.mul(self.slope).add(self.offset).clamp(0, 1)


MODELS.register_module(module=Hardsigmoid, name='Hardsigmoid')
Collaborator:

Suggested change — replace:

class Hardsigmoid(nn.Module):

    def __init__(self, slope=0.2, offset=0.5, inplace=False):
        super().__init__()
        self.slope = slope
        self.offset = offset

    def forward(self, x):
        return x.mul(self.slope).add(self.offset).clamp(0, 1)


MODELS.register_module(module=Hardsigmoid, name='Hardsigmoid')

with:

@MODELS.register_module()
class Hardsigmoid(nn.Module):

    def __init__(self, slope=0.2, offset=0.5, inplace=False):
        super().__init__()
        self.slope = slope
        self.offset = offset

    def forward(self, x):
        return x.mul(self.slope).add(self.offset).clamp(0, 1)

Collaborator:

Also, we should add a docstring.

Comment on lines 12 to 22
def __init__(self,
num_classes,
in_channels,
use_dw=True,
dropout_ratio=0.1,
align_corners=False,
upsample='intepolate',
out_channels=None,
conv_cfg=dict(type='Conv'),
act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='BN')
Collaborator:

We should add a docstring and type hints.
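A hedged sketch of a documented, type-hinted version of that signature; hint choices and docstring wording are illustrative (the 'intepolate' default string is kept verbatim from the code above):

from typing import Optional


def __init__(self,
             num_classes: int,
             in_channels: int,
             use_dw: bool = True,
             dropout_ratio: float = 0.1,
             align_corners: bool = False,
             upsample: str = 'intepolate',
             out_channels: Optional[int] = None,
             conv_cfg: dict = dict(type='Conv'),
             act_cfg: dict = dict(type='ReLU'),
             norm_cfg: dict = dict(type='BN')) -> None:
    """Initialize the head (argument descriptions go in the docstring)."""
    ...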

@xiexinch changed the title from "Add pp_mobileseg model 添加 pp_mobileseg 模型" to "[CodeCamp2023-367] Add pp_mobileseg model" on Aug 7, 2023
Comment on lines 59 to 60
act_cfg(nn.Layer, optional): The activation layer of AAM.
inj_type(string, optional): The type of injection/AAM.
Collaborator:

Might add a brief introduction to 'AAM'.

Comment on lines 821 to 830
@MODELS.register_module()
class HSigmoid(nn.Module):

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU6()

    def forward(self, x):
        return self.relu(x + 3) / 6


Comment on lines 832 to 840
@MODELS.register_module()
class Hardswish(nn.Module):

    def __init__(self, inplace=False):
        super().__init__()
        self.relu = nn.Hardswish(inplace=inplace)

    def forward(self, x):
        return self.relu(x)

@xiexinch merged commit 1e93796 into open-mmlab:dev-1.x on Aug 9, 2023
7 checks passed
angiecao pushed a commit to angiecao/mmsegmentation that referenced this pull request Aug 31, 2023
emily-lin pushed a commit to emily-lin/mmsegmentation that referenced this pull request Nov 18, 2023
nahidnazifi87 pushed a commit to nahidnazifi87/mmsegmentation_playground that referenced this pull request Apr 5, 2024

4 participants