
[Fix] Fix MaskFormer and Mask2Former of MMDetection #9515

Merged

merged 19 commits into open-mmlab:refactor-detr from fix-maskformers on Jan 13, 2023

Conversation

Li-Qingyun
Contributor

@Li-Qingyun Li-Qingyun commented Dec 19, 2022

Motivation

The DETR-related modules were refactored in #8763, which broke MaskFormer and Mask2Former. Their unit tests were deleted in #9089. This PR adds the unit tests back and fixes the breakages. Note that this PR only fixes the bugs and does not refactor the two detectors; their refactoring requires new PRs and more time to design the scheme.

BC-breaking

The modifications may break the MaskFormer and Mask2Former newly supported in open-mmlab/mmsegmentation#2215 and open-mmlab/mmsegmentation#2255 of mmsegmentation v1.0.0rc2.
(The weights in the model zoo should be converted; a conversion script is provided below.)

@Li-Qingyun Li-Qingyun changed the title [WIP] Fix maskformers [WIP] Fix MaskFormer and Mask2Former Dec 19, 2022
@Li-Qingyun
Contributor Author

@Czm369 This work is directed by @jshilong.

@Li-Qingyun
Contributor Author

Li-Qingyun commented Dec 20, 2022

Align Inference Accuracy

configs/maskformer/maskformer_r50_ms-16xb1-75e_coco.py
develop/new_maskformer_r50_mstrain_16x1_75e_coco_20220221_141956-bc2699cb.pth (needs to be converted)
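
As a quick sanity check (a minimal sketch, assuming the standard `mmdet.apis.init_detector` entry point; the paths are the ones listed above), a converted checkpoint should load without missing or unexpected keys:

```Python
from mmdet.apis import init_detector
from mmdet.utils import register_all_modules

register_all_modules(init_default_scope=True)

# A properly converted checkpoint should load without
# missing/unexpected-key warnings.
model = init_detector(
    'configs/maskformer/maskformer_r50_ms-16xb1-75e_coco.py',
    'develop/new_maskformer_r50_mstrain_16x1_75e_coco_20220221_141956-bc2699cb.pth',
    device='cpu')
```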

MaskFormer

+--------+--------+--------+--------+------------+
|        | PQ     | SQ     | RQ     | categories |
+--------+--------+--------+--------+------------+
| All    | 46.854 | 80.617 | 57.085 | 133        |
| Things | 51.089 | 81.510 | 61.853 | 80         |
| Stuff  | 40.463 | 79.269 | 49.888 | 53         |
+--------+--------+--------+--------+------------+
12/21 00:55:23 - mmengine - INFO - Epoch(test) [5000/5000]  coco_panoptic/PQ: 46.8544  coco_panoptic/SQ: 80.6174  coco_panoptic/RQ: 57.0852  coco_panoptic/PQ_th: 51.0888  coco_panoptic/SQ_th: 81.5105  coco_panoptic/RQ_th: 61.8532  coco_panoptic/PQ_st: 40.4628  coco_panoptic/SQ_st: 79.2693  coco_panoptic/RQ_st: 49.8884

Mask2Former

Instance
.\configs\mask2former\mask2former_r50_8xb2-lsj-50e_coco.py
.\develop\new_mask2former_r50_lsj_8x2_50e_coco_20220506_191028-8e96e88b.pth

12/21 02:38:20 - mmengine - INFO - Iter(test) [5000/5000]  coco/bbox_mAP: 0.4570  coco/bbox_mAP_50: 0.6510  coco/bbox_mAP_75: 0.4910  coco/bbox_mAP_s: 0.2680  coco/bbox_mAP_m: 0.4850  coco/bbox_mAP_l: 0.6260  coco/segm_mAP: 0.4290  coco/segm_mAP_50: 0.6530  coco/segm_mAP_75: 0.4600  coco/segm_mAP_s: 0.2200  coco/segm_mAP_m: 0.4630  coco/segm_mAP_l: 0.6470

Panoptic
.\configs\mask2former\mask2former_r50_8xb2-lsj-50e_coco-panoptic.py
.\develop\new_mask2former_r50_lsj_8x2_50e_coco-panoptic_20220326_224516-11a44721.pth (needs to be converted)

+--------+--------+--------+--------+------------+
|        | PQ     | SQ     | RQ     | categories |
+--------+--------+--------+--------+------------+
| All    | 51.865 | 83.071 | 61.591 | 133        |
| Things | 57.737 | 84.043 | 68.129 | 80         |
| Stuff  | 43.003 | 81.604 | 51.722 | 53         |
+--------+--------+--------+--------+------------+
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.448  (bbox)
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.419  (segm)
12/20 23:29:29 - mmengine - INFO - Iter(test) [5000/5000]  coco_panoptic/PQ: 51.8654  coco_panoptic/SQ: 83.0710  coco_panoptic/RQ: 61.5908  coco_panoptic/PQ_th: 57.7368  coco_panoptic/SQ_th: 84.0430  coco_panoptic/RQ_th: 68.1286  coco_panoptic/PQ_st: 43.0029  coco_panoptic/SQ_st: 81.6037  coco_panoptic/RQ_st: 51.7223  coco/bbox_mAP: 0.4480  coco/bbox_mAP_50: 0.6440  coco/bbox_mAP_75: 0.4800  coco/bbox_mAP_s: 0.2650  coco/bbox_mAP_m: 0.4800  coco/bbox_mAP_l: 0.6250  coco/segm_mAP: 0.4190  coco/segm_mAP_50: 0.6460  coco/segm_mAP_75: 0.4480  coco/segm_mAP_s: 0.2150  coco/segm_mAP_m: 0.4590  coco/segm_mAP_l: 0.6340

The checkpoint should be converted with the following script:

```Python
import json
from collections import OrderedDict

import torch

from mmengine.config import Config
from mmdet.models import build_detector
from mmdet.utils import register_all_modules
register_all_modules(init_default_scope=True)


def get_new_name(old_name: str):
    new_name = old_name

    if 'encoder.layers' in new_name:
        new_name = new_name.replace('attentions.0', 'self_attn')

    new_name = new_name.replace('ffns.0', 'ffn')

    if 'decoder.layers' in new_name:
        # for Mask2Former (this checkpoint)
        new_name = new_name.replace('attentions.0', 'cross_attn')
        new_name = new_name.replace('attentions.1', 'self_attn')
        # for MaskFormer, use the swapped mapping instead:
        # new_name = new_name.replace('attentions.0', 'self_attn')
        # new_name = new_name.replace('attentions.1', 'cross_attn')

    return new_name


def cvt_sd(old_sd: OrderedDict):
    new_sd = OrderedDict()
    for name, param in old_sd.items():
        new_name = get_new_name(name)
        assert new_name not in new_sd
        new_sd[new_name] = param
    assert len(new_sd) == len(old_sd)
    return new_sd


if __name__ == '__main__':

    CFG_FILE = 'configs/mask2former/mask2former_r50_8xb2-lsj-50e_coco-panoptic.py'
    OLD_CKPT_FILENAME = 'mask2former_r50_lsj_8x2_50e_coco-panoptic_20220326_224516-11a44721.pth'
    OLD_CKPT_FILEPATH = 'develop/' + OLD_CKPT_FILENAME
    NEW_CKPT_FILEPATH = 'develop/new_' + OLD_CKPT_FILENAME

    cfg = Config.fromfile(CFG_FILE)
    model_cfg = cfg.model

    detector = build_detector(model_cfg)

    refer_sd = detector.state_dict()
    # load on CPU so the conversion does not require a GPU
    old_sd = torch.load(OLD_CKPT_FILEPATH, map_location='cpu')['state_dict']

    new_sd = cvt_sd(old_sd)

    new_names = sorted(list(refer_sd.keys()))
    cvt_names = sorted(list(new_sd.keys()))
    old_names = sorted(list(old_sd.keys()))

    # dump the sorted key lists for manual diffing:
    # cvt_names should end up identical to new_names
    json.dump(new_names, open(r'./develop/new_names.json', 'w'), indent='\n')
    json.dump(cvt_names, open(r'./develop/cvt_names.json', 'w'), indent='\n')
    json.dump(old_names, open(r'./develop/old_names.json', 'w'), indent='\n')

    new_ckpt = dict(state_dict=new_sd)
    torch.save(new_ckpt, NEW_CKPT_FILEPATH)
    print(f'{NEW_CKPT_FILEPATH} has been saved!')
```
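
For illustration, here is what the renaming does to a typical decoder key (the key below is a hypothetical example following the old naming pattern, not copied from a real checkpoint):

```Python
# Hypothetical old-style Mask2Former decoder key:
old_key = ('panoptic_head.transformer_decoder.layers.0.'
           'attentions.0.attn.in_proj_weight')
print(get_new_name(old_key))
# panoptic_head.transformer_decoder.layers.0.cross_attn.attn.in_proj_weight
```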

@Li-Qingyun Li-Qingyun changed the title [WIP] Fix MaskFormer and Mask2Former [Fix] Fix MaskFormer and Mask2Former Dec 20, 2022
@jshilong jshilong assigned jshilong and unassigned Czm369 Dec 22, 2022
@jshilong
Collaborator

please merge the refactor-detr and resolve the conflict

@Li-Qingyun
Contributor Author

> please merge the refactor-detr and resolve the conflict

@jshilong the merge has been finished and the conflicts have been resolved.

@ZwwWayne ZwwWayne added this to the 3.0.0rc6 milestone Jan 3, 2023
@jshilong jshilong left a comment (Collaborator)

LGTM

@ZwwWayne ZwwWayne merged commit ade888f into open-mmlab:refactor-detr Jan 13, 2023
@Li-Qingyun Li-Qingyun deleted the fix-maskformers branch January 13, 2023 11:19
jshilong pushed a commit that referenced this pull request Jan 19, 2023
Co-authored-by: Kei-Chi Tse <109070650+KeiChiTse@users.noreply.github.com>
@Li-Qingyun Li-Qingyun changed the title [Fix] Fix MaskFormer and Mask2Former [Fix] Fix MaskFormer and Mask2Former of MMDetection Jan 30, 2023
MeowZheng added a commit to open-mmlab/mmsegmentation that referenced this pull request Feb 1, 2023
## Motivation

The DETR-related modules were refactored in
open-mmlab/mmdetection#8763, which broke MaskFormer and
Mask2Former in both MMDetection (already fixed in
open-mmlab/mmdetection#9515) and MMSegmentation. This PR fixes the bugs in
MMSegmentation.

### TO-DO List

- [x] update configs
- [x] check and modify data flow
- [x] fix unit test
- [x] aligning inference
- [x] write a ckpt converter
- [x] write ckpt update script
- [x] update model zoo
- [x] update model link in readme
- [x] update [faq.md](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/notes/faq.md#installation)

## Tips for fixing other implementations based on MaskXFormer of mmseg

1. The Transformer modules should be built directly; the original registry-based building has been refactored away.
2. The config needs to be modified: delete `type` and modify several keys, according to the modifications in this PR.
3. `batch_first` is set to `True` uniformly in the new implementations. Hence the data flow needs to be transposed and the `batch_first` config needs to be modified (see the sketch after this list).
4. Checkpoints trained on the old implementation should be converted before being used with the new one.
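
As a minimal sketch of point 3 (the shapes and variable names are illustrative, not the exact MMSeg internals): with `batch_first=True`, tensors that the old modules consumed as `(num_queries, batch_size, embed_dims)` must be transposed to `(batch_size, num_queries, embed_dims)` at the module boundary:

```Python
import torch

# Hypothetical query tensor in the old sequence-first layout:
# (num_queries, batch_size, embed_dims)
query = torch.rand(100, 2, 256)

# The refactored layers use batch_first=True, i.e. they expect
# (batch_size, num_queries, embed_dims), so the data flow is transposed:
query = query.transpose(0, 1)  # -> (2, 100, 256)
```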

### Convert script

```Python
import argparse
from copy import deepcopy
from collections import OrderedDict

import torch

from mmengine.config import Config
from mmseg.models import build_segmentor
from mmseg.utils import register_all_modules
register_all_modules(init_default_scope=True)


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMSeg convert MaskXFormer model, by Li-Qingyun')
    parser.add_argument('Mask_what_former', type=int,
                        help='Mask what former, can be a `1` or `2`',
                        choices=[1, 2])
    parser.add_argument('CFG_FILE', help='config file path')
    parser.add_argument('OLD_CKPT_FILEPATH', help='old ckpt file path')
    parser.add_argument('NEW_CKPT_FILEPATH', help='new ckpt file path')
    args = parser.parse_args()
    return args


args = parse_args()

def get_new_name(old_name: str):
    new_name = old_name

    if 'encoder.layers' in new_name:
        new_name = new_name.replace('attentions.0', 'self_attn')

    new_name = new_name.replace('ffns.0', 'ffn')

    if 'decoder.layers' in new_name:

        if args.Mask_what_former == 2:
            # for Mask2Former
            new_name = new_name.replace('attentions.0', 'cross_attn')
            new_name = new_name.replace('attentions.1', 'self_attn')
        else:
            # for MaskFormer
            new_name = new_name.replace('attentions.0', 'self_attn')
            new_name = new_name.replace('attentions.1', 'cross_attn')

    return new_name
    
def cvt_sd(old_sd: OrderedDict):
    new_sd = OrderedDict()
    for name, param in old_sd.items():
        new_name = get_new_name(name)
        assert new_name not in new_sd
        new_sd[new_name] = param
    assert len(new_sd) == len(old_sd)
    return new_sd
    
if __name__ == '__main__':
    cfg = Config.fromfile(args.CFG_FILE)
    model_cfg = cfg.model

    segmentor = build_segmentor(model_cfg)

    refer_sd = segmentor.state_dict()
    # load on CPU so the conversion does not require a GPU
    old_ckpt = torch.load(args.OLD_CKPT_FILEPATH, map_location='cpu')
    old_sd = old_ckpt['state_dict']

    new_sd = cvt_sd(old_sd)
    print(segmentor.load_state_dict(new_sd))

    new_ckpt = deepcopy(old_ckpt)
    new_ckpt['state_dict'] = new_sd
    torch.save(new_ckpt, args.NEW_CKPT_FILEPATH)
    print(f'{args.NEW_CKPT_FILEPATH} has been saved!')
```
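
Note that the `print(segmentor.load_state_dict(new_sd))` line doubles as a verification step: with PyTorch's default `strict=True`, a mismatched checkpoint raises an error, while a clean conversion prints `<All keys matched successfully>`.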

Usage:
```bash
# for example
python ckpt4pr2532.py 1 configs/maskformer/maskformer_r50-d32_8xb2-160k_ade20k-512x512.py original_ckpts/maskformer_r50-d32_8xb2-160k_ade20k-512x512_20221030_182724-cbd39cc1.pth cvt_outputs/maskformer_r50-d32_8xb2-160k_ade20k-512x512_20221030_182724.pth
python ckpt4pr2532.py 2 configs/mask2former/mask2former_r50_8xb2-160k_ade20k-512x512.py original_ckpts/mask2former_r50_8xb2-160k_ade20k-512x512_20221204_000055-4c62652d.pth cvt_outputs/mask2former_r50_8xb2-160k_ade20k-512x512_20221204_000055.pth
```
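
The first argument selects the attention-name mapping: `1` applies the MaskFormer mapping (`attentions.0 -> self_attn`), while `2` applies the Mask2Former mapping (`attentions.0 -> cross_attn`).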

---------

Co-authored-by: MeowZheng <meowzheng@outlook.com>
yumion pushed a commit to yumion/mmdetection that referenced this pull request Jan 31, 2024
Co-authored-by: Kei-Chi Tse <109070650+KeiChiTse@users.noreply.github.com>