[Feature] Support Side Adapter Network #3232
Conversation
Might remove this file.
mmseg/utils/dist_utils.py
def all_reduce_dict(py_dict, op='sum', group=None, to_float=True):
    """Apply all reduce function for python dict object.

    The code is modified from https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/allreduce_norm.py.
This method doesn't seem to be in use.
mmseg/utils/dist_utils.py
def sync_random_seed(seed=None, device='cuda'):
    """Make sure different ranks share the same seed.

    All workers must call this function, otherwise it will deadlock.
This method doesn't seem to be in use.
mmseg/utils/dist_utils.py
@functools.lru_cache()
def _get_global_gloo_group():
    """Return a process group based on gloo backend, containing all the ranks
    The result is cached."""
    if dist.get_backend() == 'nccl':
        return dist.new_group(backend='gloo')
    else:
        return dist.group.WORLD
This method doesn't seem to be in use.
mmseg/utils/dist_utils.py
def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
    """Allreduce gradients.

    Args:
        params (list[torch.Parameters]): List of parameters of a model
Same as the other methods above, this one is not in use.
mmseg/utils/tokenizer.py
"""
import gzip
import html

# https://stackoverflow.com/q/62691279
Might remove this line.
class CLIPTextEncoder(BaseModule):
    """A text encoder with transformer architecture to encode the label text.

    Args:
We should add a reference link to the original implementation and add a license.
mmseg/models/utils/san_layers.py
from mmseg.models.backbones.vit import TransformerEncoderLayer

# Modified from https://github.com/MendelXu/SAN/blob/main/san/model/attn_helper.py  # noqa: E501
Should move this line to the top of this file.
mmseg/models/utils/san_layers.py
class LayerNorm(nn.Module):
    """A LayerNorm variant, popularized by Transformers, that performs point-
    wise mean and variance normalization over the channel dimension for inputs
    that have shape (batch_size, channels, height, width).

    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa B950
    """
Could we use nn.LayerNorm instead?
Yes, this method could be replaced by nn.LayerNorm plus two torch.permute calls. But I think keeping this LayerNorm variant is the better choice. The variant was popularized by Transformers and accepts input directly in (B, C, H, W) shape. The inputs and outputs of the vision transformer layers are also in this shape, so there is no need to constantly permute tensors between network layers, as would be required with nn.LayerNorm.
To distinguish it from nn.LayerNorm, I renamed it to LayerNorm2d.
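For reference, a minimal sketch of such a channel-wise LayerNorm for (B, C, H, W) inputs, following the ConvNeXt-style implementation linked above (names and defaults are illustrative, not necessarily the final code in this PR):

```python
import torch
import torch.nn as nn


class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dim of (B, C, H, W) tensors (sketch)."""

    def __init__(self, channels: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Point-wise mean/variance over the channel dimension only.
        mean = x.mean(dim=1, keepdim=True)
        var = (x - mean).pow(2).mean(dim=1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]


# The same result with nn.LayerNorm requires two permutes:
# y = nn.LayerNorm(C)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
```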
qkv_bias (int): Whether to use bias in multihead-attention.
    Default: True.
Suggested change:
qkv_bias (bool): Whether to use bias in multihead-attention.
    Default: True.
def init_para(self):
    if hasattr(self, 'sos_token'):
        nn.init.normal_(self.sos_token, std=0.02)
It's not in use.
def forward(self, bias, feature):
    """Forward function."""
Might add type hints and docstring.
mmseg/models/utils/san_layers.py
def cross_attn_layer(self: TransformerEncoderLayer, x, mem, attn_bias):
    """Implementation of transformer layer with cross attention.

    Args:
        self (TransformerEncoderLayer): The Module of vision transformer layer.
It is suggested that `self` be renamed to something else.
mmseg/models/utils/san_layers.py
def cross_attn_with_self_bias(self, query, key, value, attn_mask=None):
    """Implementation of cross attention layer which shares the embedding
    weights with self-attention.

    Args:
        self: self-attention layer
        query, key, value: map a query and a set of key-value pairs to
            an output. See "Attention Is All You Need" for more details.
        attn_mask: mask that prevents attention to certain positions.
    """
    return cross_attn_with_self_bias_func(
        query,
Can the two methods be combined into one? Also, the naming of the `self` param is a bit misleading.
        embed_dims=embed_dims,
        feedforward_channels=mlp_ratio * embed_dims,
        act_cfg=act_cfg),
    operation_order=('norm', 'self_attn', 'norm', 'ffn')))
Could we use self.cross_attn to control whether to add 'cross_attn' to the operation_order?
The implementation of cross attention in SAN differs from the one in mmcv: it also computes a self-attention term for the query. The self-attention weight of the query is concatenated with the cross-attention weights and they are normalized together:

    self_weight = (q * q_k).sum(dim=-1, keepdim=True)
    total_attn_output_weights = torch.cat([attn_output_weights, self_weight], dim=-1)
    total_attn_output_weights = F.softmax(total_attn_output_weights, dim=-1)

and the weighted query value is added to the final attention output:

    attn_output = attn_output + self_weight * q_v
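A self-contained sketch of this computation (single head, batch-first; tensor names follow the snippet above, but the helper and its signature are illustrative rather than the PR's exact cross_attn_with_self_bias_func):

```python
import torch
import torch.nn.functional as F


def cross_attn_with_self_bias_sketch(q, k, v, q_k, q_v):
    """Cross attention whose softmax also includes a self term for the query.

    q:   (B, Nq, C) projected (and scaled) queries
    k:   (B, Nk, C) projected memory keys
    v:   (B, Nk, C) projected memory values
    q_k: (B, Nq, C) queries passed through the *key* projection
    q_v: (B, Nq, C) queries passed through the *value* projection
    """
    # Regular cross-attention logits between queries and memory keys.
    attn_output_weights = torch.bmm(q, k.transpose(1, 2))      # (B, Nq, Nk)
    # Self-attention logit of each query against itself.
    self_weight = (q * q_k).sum(dim=-1, keepdim=True)          # (B, Nq, 1)
    # Joint softmax so the self term competes with the cross terms.
    total = F.softmax(
        torch.cat([attn_output_weights, self_weight], dim=-1), dim=-1)
    attn_output_weights, self_weight = total[..., :-1], total[..., -1:]
    # Weighted sum over memory values plus the weighted query value.
    return torch.bmm(attn_output_weights, v) + self_weight * q_v
```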
How to support open-vocabulary prompting inference?
The prompt engineering is implemented in the template_encode function in mmseg/models/text_encoder/clip_text_encoder.py.
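For context, a schematic sketch of CLIP-style prompt-template encoding: each class name is embedded into a set of prompt templates and the normalized text embeddings are averaged. The template list and the `encode_text` callable below are placeholders, not the exact `template_encode` implementation:

```python
import torch
import torch.nn.functional as F

# Placeholder templates; the actual list used by template_encode may differ.
TEMPLATES = ['a photo of a {}.', 'a photo of the {}.', 'there is a {} in the scene.']


def encode_class_names(class_names, encode_text):
    """Return one averaged text embedding per class name.

    `encode_text(list_of_strings) -> (num_prompts, embed_dim)` stands in for
    the CLIP text encoder used in this PR.
    """
    class_embeddings = []
    for name in class_names:
        prompts = [t.format(name) for t in TEMPLATES]
        emb = F.normalize(encode_text(prompts), dim=-1)  # (num_prompts, D)
        emb = F.normalize(emb.mean(dim=0), dim=-1)       # average over prompts
        class_embeddings.append(emb)
    return torch.stack(class_embeddings)                 # (num_classes, D)
```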
@angiecao Wow, thank you so much! But if I just want to give some categories when predicting, and only predict those categories, how should I do that? For example: `python predict.py test/xxx.jpg configs/xxx.py work_dirs/xxx.pth --class-names 'Oculus' 'Ukulele' --output ./pred/xxx.jpg`
## Motivation

Support SAN for Open-Vocabulary Semantic Segmentation.

Paper: [Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)
Official code: [SAN](https://github.com/MendelXu/SAN)

## Modification

- Added the parameters of the ViT backbone for implementing the image encoder of CLIP.
- Added text encoder code.
- Added multimodal encoder-decoder segmentor code for open-vocabulary semantic segmentation.
- Added SideAdapterNetwork decode head code.
- Added config files for training and inference.
- Added tools for converting pretrained models.
- Added loss implementations for mask classification models such as SAN and MaskFormer, and removed the dependency on mmdetection.
- Added unit tests for the text encoder, multimodal encoder-decoder, SAN decode head and hungarian_assigner.

## Use cases

### Convert Models

**Pretrained SAN model**

The official pretrained models can be downloaded from [san_clip_vit_b_16.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_b_16.pth) and [san_clip_vit_large_14.pth](https://huggingface.co/Mendel192/san/blob/main/san_vit_large_14.pth).

Use tools/model_converters/san2mmseg.py to convert an official model into mmseg style.

`python tools/model_converters/san2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

**Pretrained CLIP model**

Use the CLIP models provided by OpenAI to train SAN. The CLIP models can be downloaded from [ViT-B-16.pt](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt) and [ViT-L-14-336px.pt](https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt).

Use tools/model_converters/clip2mmseg.py to convert a model into mmseg style.

`python tools/model_converters/clip2mmseg.py <MODEL_PATH> <OUTPUT_PATH>`

### Inference

Test the san_vit-base-16 model on the COCO-Stuff164k dataset:

`python tools/test.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py <TRAINED_MODEL_PATH>`

### Train

Train the san_vit-base-16 model on the COCO-Stuff164k dataset:

`python tools/train.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py --cfg-options model.pretrained=<PRETRAINED_MODEL_PATH>`

## Comparison Results

### Train on COCO-Stuff164k

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 41.93 | 56.73 | 67.69 |
|                 | mmseg    | 41.93 | 56.84 | 67.84 |
| san-vit-large14 | official | 45.57 | 59.52 | 69.76 |
|                 | mmseg    | 45.78 | 59.61 | 69.21 |

### Evaluate on Pascal Context

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 54.05 | 72.96 | 77.77 |
|                 | mmseg    | 54.04 | 73.74 | 77.71 |
| san-vit-large14 | official | 57.53 | 77.56 | 78.89 |
|                 | mmseg    | 56.89 | 76.96 | 78.74 |

### Evaluate on Voc12Aug

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 93.86 | 96.61 | 97.11 |
|                 | mmseg    | 94.58 | 97.01 | 97.38 |
| san-vit-large14 | official | 95.17 | 97.61 | 97.63 |
|                 | mmseg    | 95.58 | 97.75 | 97.79 |

---------

Co-authored-by: CastleDream <35064479+CastleDream@users.noreply.github.com>
Co-authored-by: yeedrag <46050186+yeedrag@users.noreply.github.com>
Co-authored-by: Yang-ChangHui <71805205+Yang-Changhui@users.noreply.github.com>
Co-authored-by: Xu CAO <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: xiexinch <xiexinch@outlook.com>
Co-authored-by: 小飞猪 <106524776+ooooo-create@users.noreply.github.com>