
[Feature] Support Side Adapter Network #3232

Merged (67 commits into open-mmlab:dev-1.x, Sep 20, 2023)

Conversation

angiecao (Contributor) commented Jul 25, 2023

Motivation

Support SAN for open-vocabulary semantic segmentation.
Paper: Side Adapter Network for Open-Vocabulary Semantic Segmentation (https://arxiv.org/abs/2302.12242)
Official code: SAN (https://github.com/MendelXu/SAN)

Modification

  • Added the backbone ViT parameters needed to implement the CLIP image encoder.
  • Added the text encoder code.
  • Added the multimodal encoder-decoder segmentor code for open-vocabulary semantic segmentation.
  • Added the SideAdapterNetwork decode head code.
  • Added config files for training and inference.
  • Added tools for converting pretrained models.
  • Added the loss implementation for mask classification models such as SAN and MaskFormer, removing the dependency on mmdetection.
  • Added unit tests for the text encoder, multimodal encoder-decoder, SAN decode head and hungarian_assigner.

Use cases

Convert Models

pretrained SAN model
The official pretrained models can be downloaded from san_clip_vit_b_16.pth (https://huggingface.co/Mendel192/san/blob/main/san_vit_b_16.pth) and san_clip_vit_large_14.pth (https://huggingface.co/Mendel192/san/blob/main/san_vit_large_14.pth).
Use tools/model_converters/san2mmseg.py to convert the official model into mmseg style:
python tools/model_converters/san2mmseg.py <MODEL_PATH> <OUTPUT_PATH>

pretrained CLIP model
Use the CLIP model provided by OpenAI to train SAN. The CLIP models can be downloaded from ViT-B-16.pt (https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt) and ViT-L-14-336px.pt (https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt).
Use tools/model_converters/clip2mmseg.py to convert the model into mmseg style:
python tools/model_converters/clip2mmseg.py <MODEL_PATH> <OUTPUT_PATH>
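
For example, converting the downloaded checkpoints could look like this (the output paths below are arbitrary and only for illustration):

    python tools/model_converters/san2mmseg.py san_vit_b_16.pth pretrain/san_vit_b_16_mmseg.pth
    python tools/model_converters/clip2mmseg.py ViT-B-16.pt pretrain/clip_vit_b_16_mmseg.pth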

Inference

Test the san_vit-base-16 model on the COCO-Stuff164k dataset:
python tools/test.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py <TRAINED_MODEL_PATH>

Train

Train the san_vit-base-16 model on the COCO-Stuff164k dataset:
python tools/train.py ./configs/san/san-vit-b16_coco-stuff164k-640x640.py --cfg-options model.pretrained=<PRETRAINED_MODEL_PATH>

Comparison Results

Train on COCO-Stuff164k

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 41.93 | 56.73 | 67.69 |
|                 | mmseg    | 41.93 | 56.84 | 67.84 |
| san-vit-large14 | official | 45.57 | 59.52 | 69.76 |
|                 | mmseg    | 45.78 | 59.61 | 69.21 |

Evaluate on Pascal Context

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 54.05 | 72.96 | 77.77 |
|                 | mmseg    | 54.04 | 73.74 | 77.71 |
| san-vit-large14 | official | 57.53 | 77.56 | 78.89 |
|                 | mmseg    | 56.89 | 76.96 | 78.74 |

Evaluate on Voc12Aug

|                 |          | mIoU  | mAcc  | pAcc  |
| --------------- | -------- | ----- | ----- | ----- |
| san-vit-base16  | official | 93.86 | 96.61 | 97.11 |
|                 | mmseg    | 94.58 | 97.01 | 97.38 |
| san-vit-large14 | official | 95.17 | 97.61 | 97.63 |
|                 | mmseg    | 95.58 | 97.75 | 97.79 |

CLAassistant commented Jul 25, 2023

CLA assistant check
All committers have signed the CLA.

@xiexinch xiexinch changed the base branch from main to dev-1.x July 31, 2023 09:32
@angiecao angiecao changed the title [WIP] Support Side Adapter Network [Feature] Support Side Adapter Network Sep 7, 2023
Collaborator
Might remove this file.

Comment on lines 89 to 93
def all_reduce_dict(py_dict, op='sum', group=None, to_float=True):
"""Apply all reduce function for python dict object.

The code is modified from https://github.com/Megvii-
BaseDetection/YOLOX/blob/main/yolox/utils/allreduce_norm.py.
Collaborator
This method seems not in use.

Comment on lines 148 to 151
def sync_random_seed(seed=None, device='cuda'):
"""Make sure different ranks share the same seed.

All workers must call this function, otherwise it will deadlock.
Collaborator
This method seems not in use.

Comment on lines 79 to 86
@functools.lru_cache()
def _get_global_gloo_group():
"""Return a process group based on gloo backend, containing all the ranks
The result is cached."""
if dist.get_backend() == 'nccl':
return dist.new_group(backend='gloo')
else:
return dist.group.WORLD
Collaborator
This method seems not in use.

Comment on lines 37 to 41
def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
"""Allreduce gradients.

Args:
params (list[torch.Parameters]): List of parameters of a model
Collaborator
The same as other methods, it's not in use.

"""
import gzip
import html
# https://stackoverflow.com/q/62691279
Collaborator
Might remove this line.

Comment on lines 18 to 21
class CLIPTextEncoder(BaseModule):
"""A text encoder with transformer architecture to encode the label text.

Args:
Collaborator
We should add a reference link to the original implementation and add a license.


from mmseg.models.backbones.vit import TransformerEncoderLayer

# Modified from https://github.com/MendelXu/SAN/blob/main/san/model/attn_helper.py # noqa: E501
Collaborator
Should move this line to the top of this file.

Comment on lines 386 to 392
class LayerNorm(nn.Module):
"""A LayerNorm variant, popularized by Transformers, that performs point-
wise mean and variance normalization over the channel dimension for inputs
that have shape (batch_size, channels, height, width).

https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950
"""
Collaborator
Could we use nn.LayerNorm instead?

Contributor Author
Yes, this method can be replaced with nn.LayerNorm plus two torch.permute calls, but I think keeping this LayerNorm variant is the better choice.
The variant, popularized by Transformers, accepts inputs directly in the shape (B, C, H, W). The vision transformer layers also take and produce this shape, so there is no need to keep permuting inputs and outputs between network layers as there would be with nn.LayerNorm.
To distinguish it from nn.LayerNorm, I renamed it to LayerNorm2d.
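
For reference, a minimal sketch of such a channels-first LayerNorm (an illustrative re-implementation in the spirit of the ConvNeXt version, not the exact code added in this PR):

    import torch
    import torch.nn as nn

    class LayerNorm2d(nn.Module):
        """LayerNorm over the channel dimension for (B, C, H, W) inputs."""

        def __init__(self, num_channels: int, eps: float = 1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(num_channels))
            self.bias = nn.Parameter(torch.zeros(num_channels))
            self.eps = eps

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Point-wise mean/variance over the channel dimension, then affine transform.
            mean = x.mean(dim=1, keepdim=True)
            var = (x - mean).pow(2).mean(dim=1, keepdim=True)
            x = (x - mean) / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * x + self.bias[:, None, None]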

Comment on lines 293 to 294
qkv_bias (int): Whether to use bias in multihead-attention.
Default: True.
Collaborator
Suggested change:
-   qkv_bias (int): Whether to use bias in multihead-attention.
-       Default: True.
+   qkv_bias (bool): Whether to use bias in multihead-attention.
+       Default: True.

Comment on lines 376 to 379
def init_para(self):
if hasattr(self, 'sos_token'):
nn.init.normal_(self.sos_token, std=0.02)

Collaborator
It's not in use.

Comment on lines 421 to 422
def forward(self, bias, feature):
"""Forward function."""
Collaborator
Might add type hints and docstring.

Comment on lines 368 to 371
def cross_attn_layer(self: TransformerEncoderLayer, x, mem, attn_bias):
"""Implementation of transformer layer with cross attention
Args:
self (TransformerEncoderLayer): The Module of vision transformer layer.
Collaborator
It is suggested that self be renamed to something else.

Comment on lines 16 to 27
def cross_attn_with_self_bias(self, query, key, value, attn_mask=None):
"""Implementation of cross attention layer which shares the embedding
weights with self-attention.

Args:
self: self-attention layer
query, key, value: map a query and a set of key-value pairs to
an output. See "Attention Is All You Need" for more details.
attn_mask: mask that prevents attention to certain positions.
"""
return cross_attn_with_self_bias_func(
query,
Collaborator
Can the two methods be combined into one? Also, the naming of the self param is a bit misleading.

embed_dims=embed_dims,
feedforward_channels=mlp_ratio * embed_dims,
act_cfg=act_cfg),
operation_order=('norm', 'self_attn', 'norm', 'ffn')))
Collaborator
Could we use self.cross_attn to control whether to add 'cross_attn' to the operation_order?

Contributor Author
The cross-attention implementation in SAN differs from the one in mmcv: it also computes self-attention for the query.
The query's self-attention weight is concatenated with the cross-attention weights:

    self_weight = (q * q_k).sum(dim=-1, keepdim=True)
    total_attn_output_weights = torch.cat([attn_output_weights, self_weight], dim=-1)
    total_attn_output_weights = F.softmax(total_attn_output_weights, dim=-1)

and the weighted query value is added to the final attention output:

    attn_output = attn_output + self_weight * q_v
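
Putting these pieces together, a minimal self-contained sketch of the idea (assumed shapes and names; an illustration, not the PR's actual implementation):

    import torch
    import torch.nn.functional as F

    def cross_attn_with_self_weight(q, q_k, q_v, k, v):
        # q, q_k, q_v: (B, Nq, C) query and its self-attention key/value projections
        # k, v: (B, Nk, C) memory key/value projections
        scale = q.shape[-1] ** -0.5
        attn_output_weights = (q * scale) @ k.transpose(-2, -1)    # (B, Nq, Nk)
        self_weight = (q * scale * q_k).sum(dim=-1, keepdim=True)  # (B, Nq, 1)
        # Softmax over cross-attention weights and the query's self weight jointly.
        total = torch.cat([attn_output_weights, self_weight], dim=-1)
        total = F.softmax(total, dim=-1)
        cross_weight, self_weight = total[..., :-1], total[..., -1:]
        # Weighted memory values plus the query's own weighted value.
        return cross_weight @ v + self_weight * q_v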

@xiexinch xiexinch merged commit 608e319 into open-mmlab:dev-1.x Sep 20, 2023
2 checks passed
zsc1220 commented Oct 13, 2023

How to support open-vocabulary prompting at inference?

angiecao (Contributor Author)

How to support open-vocabulary prompting at inference?

  1. If you want a custom set of category names, define model.text_encoder.vocabulary in the config file and set model.text_encoder.dataset_name to None:

    model = dict(
        text_encoder=dict(dataset_name=None,
                          vocabulary=['classA', 'classB', 'classC']))

  2. If you want to set prompt templates, define a new list of templates in mmseg/utils/get_templates.py and point model.text_encoder.templates to your template name in the config file:

    PREDEFINED_TEMPLATES = {
        'custom': [
            'a photo of a {}.',
            'This is a photo of a {}',
            'There is a {} in the scene',
        ],
    }

    model = dict(text_encoder=dict(templates='custom'))

The prompt engineering is implemented in the template_encode function in mmseg/models/text_encoder/clip_text_encoder.py.
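
For context, prompt ensembling along these lines usually works roughly as follows (a simplified sketch, not the exact mmseg code; encode_text is an assumed helper standing in for a CLIP-style text encoder):

    import torch

    def template_encode_sketch(class_names, templates, encode_text):
        # encode_text: callable mapping a list of prompt strings to (N, C) embeddings
        class_embeddings = []
        for name in class_names:
            prompts = [t.format(name) for t in templates]  # fill each template
            emb = encode_text(prompts)                     # (num_templates, C)
            emb = emb / emb.norm(dim=-1, keepdim=True)     # L2-normalize
            class_embeddings.append(emb.mean(dim=0))       # average over templates
        return torch.stack(class_embeddings)               # (num_classes, C)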

zsc1220 commented Oct 13, 2023

@angiecao Wow, thank you so much! But if I just want to give some categories when predicting, and only predict those categories, how should I do that? Like this:

python predict.py test/xxx.jpg configs/xxx.py work_dirs/xxx.pth --class-names 'Oculus' 'Ukulele'  --output ./pred/xxx.jpg

emily-lin pushed a commit to emily-lin/mmsegmentation that referenced this pull request Nov 18, 2023
nahidnazifi87 pushed a commit to nahidnazifi87/mmsegmentation_playground that referenced this pull request Apr 5, 2024