diff --git a/README.md b/README.md index ac3c0c7f1c9..296e5f1949e 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,10 @@ English | [简体中文](README_zh-CN.md) +
+ +
+ ## Introduction MMDetection is an open source object detection toolbox based on PyTorch. It is diff --git a/README_zh-CN.md b/README_zh-CN.md index 9ed79d347dd..4ee964f4b21 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -61,6 +61,10 @@ +
+ +
+ ## 简介 MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [OpenMMLab](https://openmmlab.com/) 项目的一部分。 diff --git a/configs/_base_/datasets/ade20k_instance.py b/configs/_base_/datasets/ade20k_instance.py new file mode 100644 index 00000000000..57f657aa67f --- /dev/null +++ b/configs/_base_/datasets/ade20k_instance.py @@ -0,0 +1,53 @@ +# dataset settings +dataset_type = 'ADE20KInstanceDataset' +data_root = 'data/ADEChallengeData2016/' + +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/ADEChallengeData2016/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +test_pipeline = [ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='Resize', scale=(2560, 640), keep_ratio=True), + # If you don't have a gt annotation, delete the pipeline + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='ade20k_instance_val.json', + data_prefix=dict(img='images/validation'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type='CocoMetric', + ann_file=data_root + 'ade20k_instance_val.json', + metric=['bbox', 'segm'], + format_only=False, + backend_args=backend_args) +test_evaluator = val_evaluator diff --git a/configs/_base_/datasets/ade20k_panoptic.py b/configs/_base_/datasets/ade20k_panoptic.py index 7672d5d99fc..7be5ddd7f07 100644 --- a/configs/_base_/datasets/ade20k_panoptic.py +++ b/configs/_base_/datasets/ade20k_panoptic.py @@ -4,18 +4,9 @@ backend_args = None -train_pipeline = [ - dict(type='LoadImageFromFile', backend_args=backend_args), - dict(type='LoadPanopticAnnotations', backend_args=backend_args), - # TODO: the performance of `FixScaleResize` need to check. - dict(type='FixScaleResize', scale=(2560, 640), backend_args=backend_args), - dict(type='RandomCrop', crop_size=(640, 640), crop_type='absolute'), - dict(type='RandomFlip', prob=0.5), - dict(type='PackDetInputs') -] test_pipeline = [ dict(type='LoadImageFromFile', backend_args=backend_args), - dict(type='Resize', scale=(640, 640), keep_ratio=True), + dict(type='Resize', scale=(2560, 640), keep_ratio=True), dict(type='LoadPanopticAnnotations', backend_args=backend_args), dict( type='PackDetInputs', @@ -23,24 +14,10 @@ 'scale_factor')) ] -train_dataloader = dict( - batch_size=4, - num_workers=2, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=True), - batch_sampler=dict(type='AspectRatioBatchSampler'), - dataset=dict( - type=dataset_type, - data_root=data_root, - ann_file='ade20k_panoptic_train.json', - data_prefix=dict(img='images/training/', seg='ade20k_panoptic_train/'), - filter_cfg=dict(filter_empty_gt=True, min_size=32), - pipeline=train_pipeline, - backend_args=backend_args)) val_dataloader = dict( batch_size=1, - num_workers=2, - persistent_workers=True, + num_workers=0, + persistent_workers=False, drop_last=False, sampler=dict(type='DefaultSampler', shuffle=False), dataset=dict( diff --git a/configs/_base_/datasets/ade20k_semantic.py b/configs/_base_/datasets/ade20k_semantic.py new file mode 100644 index 00000000000..522a7757041 --- /dev/null +++ b/configs/_base_/datasets/ade20k_semantic.py @@ -0,0 +1,48 @@ +dataset_type = 'ADE20KSegDataset' +data_root = 'data/ADEChallengeData2016/' + +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/ADEChallengeData2016/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +test_pipeline = [ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='Resize', scale=(2048, 512), keep_ratio=True), + dict( + type='LoadAnnotations', + with_bbox=False, + with_mask=False, + with_seg=True, + reduce_zero_label=True), + dict( + type='PackDetInputs', meta_keys=('img_path', 'ori_shape', 'img_shape')) +] + +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='images/validation', + seg_map_path='annotations/validation'), + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict(type='SemSegMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator diff --git a/configs/_base_/datasets/coco_caption.py b/configs/_base_/datasets/coco_caption.py index 95ec03075b9..a1bd8983139 100644 --- a/configs/_base_/datasets/coco_caption.py +++ b/configs/_base_/datasets/coco_caption.py @@ -1,6 +1,6 @@ # data settings -dataset_type = 'COCOCaptionDataset' +dataset_type = 'CocoCaptionDataset' data_root = 'data/coco/' # Example to use different file client diff --git a/configs/_base_/datasets/coco_semantic.py b/configs/_base_/datasets/coco_semantic.py new file mode 100644 index 00000000000..944bbbaeaeb --- /dev/null +++ b/configs/_base_/datasets/coco_semantic.py @@ -0,0 +1,78 @@ +# dataset settings +dataset_type = 'CocoSegDataset' +data_root = 'data/coco/' + +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict( + type='LoadAnnotations', + with_bbox=False, + with_label=False, + with_seg=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='PackDetInputs') +] + +test_pipeline = [ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict( + type='LoadAnnotations', + with_bbox=False, + with_label=False, + with_seg=True), + dict( + type='PackDetInputs', + meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor')) +] + +# For stuffthingmaps_semseg, please refer to +# `docs/en/user_guides/dataset_prepare.md` +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + batch_sampler=dict(type='AspectRatioBatchSampler'), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='train2017/', + seg_map_path='stuffthingmaps_semseg/train2017/'), + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='val2017/', + seg_map_path='stuffthingmaps_semseg/val2017/'), + pipeline=test_pipeline)) + +test_dataloader = val_dataloader + +val_evaluator = dict(type='SemSegMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator diff --git a/configs/_base_/datasets/refcoco+.py b/configs/_base_/datasets/refcoco+.py index caa8369ba19..ae0278ddf6c 100644 --- a/configs/_base_/datasets/refcoco+.py +++ b/configs/_base_/datasets/refcoco+.py @@ -1,44 +1,24 @@ # dataset settings -dataset_type = 'RefCOCODataset' -data_root = 'data/refcoco/' +dataset_type = 'RefCocoDataset' +data_root = 'data/coco/' backend_args = None -train_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='Resize', scale=(1333, 800), keep_ratio=True), - dict(type='RandomFlip', prob=0.5), - dict( - type='PackDetInputs', - meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', - 'scale_factor', 'text', 'image_id')) -] - test_pipeline = [ - dict(type='LoadImageFromFile'), + dict(type='LoadImageFromFile', backend_args=backend_args), dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict( + type='LoadAnnotations', + with_mask=True, + with_bbox=False, + with_seg=False, + with_label=False), dict( type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', - 'scale_factor', 'text', 'image_id')) + 'scale_factor', 'gt_masks', 'text')) ] -train_dataloader = dict( - batch_size=2, - num_workers=2, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=True), - batch_sampler=dict(type='AspectRatioBatchSampler'), - dataset=dict( - type=dataset_type, - data_root=data_root, - data_prefix=dict(img='train2014/'), - ann_file='refcoco+/instances.json', - split_file='refcoco+/refs(unc).p', - split='train', - pipeline=train_pipeline, - backend_args=backend_args)) - val_dataloader = dict( batch_size=1, num_workers=2, @@ -48,12 +28,12 @@ dataset=dict( type=dataset_type, data_root=data_root, - data_prefix=dict(img='train2014/'), + data_prefix=dict(img_path='train2014/'), ann_file='refcoco+/instances.json', split_file='refcoco+/refs(unc).p', split='val', - pipeline=test_pipeline, - backend_args=backend_args)) + text_mode='select_first', + pipeline=test_pipeline)) test_dataloader = dict( batch_size=1, @@ -64,11 +44,12 @@ dataset=dict( type=dataset_type, data_root=data_root, - data_prefix=dict(img='train2014/'), + data_prefix=dict(img_path='train2014/'), ann_file='refcoco+/instances.json', split_file='refcoco+/refs(unc).p', split='testA', # or 'testB' - pipeline=test_pipeline, - backend_args=backend_args)) + text_mode='select_first', + pipeline=test_pipeline)) -# TODO: set the metrics +val_evaluator = dict(type='RefSegMetric', metric=['cIoU', 'mIoU']) +test_evaluator = val_evaluator diff --git a/configs/_base_/datasets/refcoco.py b/configs/_base_/datasets/refcoco.py index c98ee8017d4..7b6caefa9a4 100644 --- a/configs/_base_/datasets/refcoco.py +++ b/configs/_base_/datasets/refcoco.py @@ -1,44 +1,24 @@ # dataset settings -dataset_type = 'RefCOCODataset' -data_root = 'data/refcoco/' +dataset_type = 'RefCocoDataset' +data_root = 'data/coco/' backend_args = None -train_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='Resize', scale=(1333, 800), keep_ratio=True), - dict(type='RandomFlip', prob=0.5), - dict( - type='PackDetInputs', - meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', - 'scale_factor', 'text', 'image_id')) -] - test_pipeline = [ - dict(type='LoadImageFromFile'), + dict(type='LoadImageFromFile', backend_args=backend_args), dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict( + type='LoadAnnotations', + with_mask=True, + with_bbox=False, + with_seg=False, + with_label=False), dict( type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', - 'scale_factor', 'text', 'image_id')) + 'scale_factor', 'gt_masks', 'text')) ] -train_dataloader = dict( - batch_size=2, - num_workers=2, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=True), - batch_sampler=dict(type='AspectRatioBatchSampler'), - dataset=dict( - type=dataset_type, - data_root=data_root, - data_prefix=dict(img='train2014/'), - ann_file='refcoco/instances.json', - split_file='refcoco/refs(unc).p', - split='train', - pipeline=train_pipeline, - backend_args=backend_args)) - val_dataloader = dict( batch_size=1, num_workers=2, @@ -48,12 +28,12 @@ dataset=dict( type=dataset_type, data_root=data_root, - data_prefix=dict(img='train2014/'), + data_prefix=dict(img_path='train2014/'), ann_file='refcoco/instances.json', split_file='refcoco/refs(unc).p', split='val', - pipeline=test_pipeline, - backend_args=backend_args)) + text_mode='select_first', + pipeline=test_pipeline)) test_dataloader = dict( batch_size=1, @@ -64,11 +44,12 @@ dataset=dict( type=dataset_type, data_root=data_root, - data_prefix=dict(img='train2014/'), + data_prefix=dict(img_path='train2014/'), ann_file='refcoco/instances.json', split_file='refcoco/refs(unc).p', split='testA', # or 'testB' - pipeline=test_pipeline, - backend_args=backend_args)) + text_mode='select_first', + pipeline=test_pipeline)) -# TODO: set the metrics +val_evaluator = dict(type='RefSegMetric', metric=['cIoU', 'mIoU']) +test_evaluator = val_evaluator diff --git a/configs/_base_/datasets/refcocog.py b/configs/_base_/datasets/refcocog.py index 9a2a45ff8a6..19dbeef1cde 100644 --- a/configs/_base_/datasets/refcocog.py +++ b/configs/_base_/datasets/refcocog.py @@ -1,44 +1,24 @@ # dataset settings -dataset_type = 'RefCOCODataset' -data_root = 'data/refcoco/' +dataset_type = 'RefCocoDataset' +data_root = 'data/coco/' backend_args = None -train_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='Resize', scale=(1333, 800), keep_ratio=True), - dict(type='RandomFlip', prob=0.5), - dict( - type='PackDetInputs', - meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', - 'scale_factor', 'text', 'image_id')) -] - test_pipeline = [ - dict(type='LoadImageFromFile'), + dict(type='LoadImageFromFile', backend_args=backend_args), dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict( + type='LoadAnnotations', + with_mask=True, + with_bbox=False, + with_seg=False, + with_label=False), dict( type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', - 'scale_factor', 'text', 'image_id')) + 'scale_factor', 'gt_masks', 'text')) ] -train_dataloader = dict( - batch_size=2, - num_workers=2, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=True), - batch_sampler=dict(type='AspectRatioBatchSampler'), - dataset=dict( - type=dataset_type, - data_root=data_root, - data_prefix=dict(img='train2014/'), - ann_file='refcocog/instances.json', - split_file='refcocog/refs(umd).p', - split='train', - pipeline=train_pipeline, - backend_args=backend_args)) - val_dataloader = dict( batch_size=1, num_workers=2, @@ -48,12 +28,12 @@ dataset=dict( type=dataset_type, data_root=data_root, - data_prefix=dict(img='train2014/'), + data_prefix=dict(img_path='train2014/'), ann_file='refcocog/instances.json', split_file='refcocog/refs(umd).p', split='val', - pipeline=test_pipeline, - backend_args=backend_args)) + text_mode='select_first', + pipeline=test_pipeline)) test_dataloader = dict( batch_size=1, @@ -64,11 +44,12 @@ dataset=dict( type=dataset_type, data_root=data_root, - data_prefix=dict(img='train2014/'), + data_prefix=dict(img_path='train2014/'), ann_file='refcocog/instances.json', split_file='refcocog/refs(umd).p', split='test', - pipeline=test_pipeline, - backend_args=backend_args)) + text_mode='select_first', + pipeline=test_pipeline)) -# TODO: set the metrics +val_evaluator = dict(type='RefSegMetric', metric=['cIoU', 'mIoU']) +test_evaluator = val_evaluator diff --git a/configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py b/configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py index 9be797f8482..34a818caefc 100644 --- a/configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py +++ b/configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py @@ -82,9 +82,9 @@ dict( type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', - 'scale_factor', 'caption', 'custom_entities')) + 'scale_factor', 'text', 'custom_entities')) ] val_dataloader = dict( - dataset=dict(pipeline=test_pipeline, return_caption=True)) + dataset=dict(pipeline=test_pipeline, return_classes=True)) test_dataloader = val_dataloader diff --git a/demo/image_demo.py b/demo/image_demo.py index 4c9163dc8dd..2e2c27adbf2 100644 --- a/demo/image_demo.py +++ b/demo/image_demo.py @@ -14,6 +14,20 @@ configs/rtmdet/rtmdet_s_8xb32-300e_coco.py \ --weights rtmdet_s_8xb32-300e_coco_20220905_161602-387a891e.pth + python demo/image_demo.py demo/demo.jpg \ + glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365 --texts bench + + python demo/image_demo.py demo/demo.jpg \ + glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365 --texts 'bench . car .' + + python demo/image_demo.py demo/demo.jpg \ + glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365 + --texts 'bench . car .' -c + + python demo/image_demo.py demo/demo.jpg \ + glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365 \ + --texts 'There are a lot of cars here.' + Visualize prediction results:: python demo/image_demo.py demo/demo.jpg rtmdet-ins-s --show @@ -46,6 +60,7 @@ def parse_args(): type=str, default='outputs', help='Output directory of images or prediction results.') + parser.add_argument('--texts', help='text prompt') parser.add_argument( '--device', default='cuda:0', help='Device used for inference') parser.add_argument( @@ -76,6 +91,14 @@ def parse_args(): default='none', choices=['coco', 'voc', 'citys', 'random', 'none'], help='Color palette used for visualization') + # only for GLIP + parser.add_argument( + '--custom-entities', + '-c', + action='store_true', + help='Whether to customize entity names? ' + 'If so, the input text should be ' + '"cls_name1 . cls_name2 . cls_name3 ." format') call_args = vars(parser.parse_args()) diff --git a/demo/multimodal_demo.py b/demo/multimodal_demo.py deleted file mode 100644 index 2dec7367135..00000000000 --- a/demo/multimodal_demo.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -"""MultiModal Demo. - -Example: - python demo/multimodal_demo.py demo/demo.jpg bench \ - configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py \ - https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_a_mmdet-b3654169.pth - - python demo/multimodal_demo.py demo/demo.jpg "bench . car . " \ - configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py \ - https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_a_mmdet-b3654169.pth - - python demo/multimodal_demo.py demo/demo.jpg "bench . car . " -c \ - configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py \ - https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_a_mmdet-b3654169.pth - - python demo/multimodal_demo.py demo/demo.jpg \ - "There are a lot of cars here." \ - configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py \ - https://download.openmmlab.com/mmdetection/v3.0/glip/glip_tiny_a_mmdet-b3654169.pth -""" - -import os.path as osp -from argparse import ArgumentParser - -import mmcv -from mmengine.utils import path - -from mmdet.apis import inference_detector, init_detector -from mmdet.registry import VISUALIZERS - - -def parse_args(): - parser = ArgumentParser() - parser.add_argument('img', help='Image path, include image file and URL.') - parser.add_argument('text', help='text prompt') - parser.add_argument('config', help='Config file') - parser.add_argument('checkpoint', help='Checkpoint file') - parser.add_argument( - '--out-dir', default='./output', help='Path to output file') - parser.add_argument( - '--device', default='cuda:0', help='Device used for inference') - parser.add_argument( - '--show', action='store_true', help='Show the detection results') - parser.add_argument( - '--score-thr', type=float, default=0.5, help='Bbox score threshold') - parser.add_argument( - '--custom-entities', - '-c', - action='store_true', - help='Whether to customize entity names? ' - 'If so, the input text should be ' - '"cls_name1 . cls_name2 . cls_name3 ." format') - args = parser.parse_args() - return args - - -def main(): - args = parse_args() - - # build the model from a config file and a checkpoint file - model = init_detector(args.config, args.checkpoint, device=args.device) - - result = inference_detector( - model, - args.img, - text_prompt=args.text, - custom_entities=args.custom_entities) - - visualizer = VISUALIZERS.build(model.cfg.visualizer) - - img = mmcv.imread(args.img) - img = mmcv.imconvert(img, 'bgr', 'rgb') - - out_file = None - if not args.show: - path.mkdir_or_exist(args.out_dir) - out_file = osp.join(args.out_dir, osp.basename(args.img)) - - visualizer.add_datasample( - 'results', - img, - data_sample=result, - draw_gt=False, - show=args.show, - wait_time=0, - out_file=out_file, - pred_score_thr=args.score_thr) - - if out_file: - print(f'\nResults have been saved at {osp.abspath(out_file)}') - - -if __name__ == '__main__': - main() diff --git a/docs/en/user_guides/dataset_prepare.md b/docs/en/user_guides/dataset_prepare.md index 7d960ba18ec..a3a33d11249 100644 --- a/docs/en/user_guides/dataset_prepare.md +++ b/docs/en/user_guides/dataset_prepare.md @@ -1,5 +1,7 @@ # Dataset Prepare +### Basic Detection Dataset Preparation + MMDetection supports multiple public datasets including COCO, Pascal VOC, CityScapes, and [more](../../../configs/_base_/datasets). Public datasets like [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/index.html) or mirror and [COCO](https://cocodataset.org/#download) are available from official websites or mirrors. Note: In the detection task, Pascal VOC 2012 is an extension of Pascal VOC 2007 without overlap, and we usually use them together. @@ -75,18 +77,127 @@ python tools/dataset_converters/cityscapes.py \ --out-dir ./data/cityscapes/annotations ``` +### COCO Caption Dataset Preparation + +COCO Caption uses the COCO2014 dataset image and uses the annotation of karpathy. + +At first, you need to download the COCO2014 dataset. + +```shell +python tools/misc/download_dataset.py --dataset-name coco2014 --unzip +``` + +The dataset will be downloaded to `data/coco` under the current path. Then download the annotation of karpathy. + +```shell +cd data/coco/annotations +wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json +wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json +wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json +wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json +wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json +``` + +The final directory structure of the dataset folder that can be directly used for training and testing is as follows: + +```text +mmdetection +├── data +│ ├── coco +│ │ ├── annotations +│ │ │ ├── coco_karpathy_train.json +│ │ │ ├── coco_karpathy_test.json +│ │ │ ├── coco_karpathy_val.json +│ │ │ ├── coco_karpathy_val_gt.json +│ │ │ ├── coco_karpathy_test_gt.json +│ │ ├── train2014 +│ │ ├── val2014 +│ │ ├── test2014 +``` + +### COCO Semantic Dataset Preparation + +There are two types of annotations for COCO semantic segmentation, which differ mainly in the definition of category names, so there are two ways to handle them. The first is to directly use the stuffthingmaps dataset, and the second is to use the panoptic dataset. + +**(1) Use stuffthingmaps dataset** + +The download link for this dataset is [stuffthingmaps_trainval2017](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip). Please download and extract it to the `data/coco` folder. + +```text +mmdetection +├── data +│ ├── coco +│ │ ├── annotations +│ │ ├── train2017 +│ │ ├── val2017 +│ │ ├── test2017 +│ │ ├── stuffthingmaps +``` + +This dataset is different from the standard COCO category annotation in that it includes 172 classes: 80 "thing" classes, 91 "stuff" classes, and 1 "unlabeled" class. The description of each class can be found at https://github.com/nightrome/cocostuff/blob/master/labels.md. + +Although only 172 categories are annotated, the maximum label ID in `stuffthingmaps` is 182, and some categories in the middle are not annotated. In addition, the "unlabeled" category of class 0 is removed. Therefore, the relationship between the value at each position in the final `stuffthingmaps` image can be found at https://github.com/kazuto1011/deeplab-pytorch/blob/master/data/datasets/cocostuff/labels.txt. + +To train efficiently and conveniently for users, we need to remove 12 unannotated classes before starting training or evaluation. The names of these 12 classes are: `street sign, hat, shoe, eye glasses, plate, mirror, window, desk, door, blender, hair brush`. The category information that can be used for training and evaluation can be found in `mmdet/datasets/coco_semantic.py`. + +You can use `tools/dataset_converters/coco_stuff164k.py` to convert the downloaded `stuffthingmaps` to a dataset that can be directly used for training and evaluation. The directory structure of the converted dataset is as follows: + +```text +mmdetection +├── data +│ ├── coco +│ │ ├── annotations +│ │ ├── train2017 +│ │ ├── val2017 +│ │ ├── test2017 +│ │ ├── stuffthingmaps +│ │ ├── stuffthingmaps_semseg +``` + +`stuffthingmaps_semseg` is the newly generated COCO semantic segmentation dataset that can be directly used for training and testing. + +**(2) use panoptic dataset** + +The number of categories in the semantic segmentation dataset generated through panoptic annotation will be less than that generated using the `stuffthingmaps` dataset. First, you need to prepare the panoptic segmentation annotations, and then use the following script to complete the conversion. + +```shell +python tools/dataset_converters/prepare_coco_semantic_annos_from_panoptic_annos.py data/coco +``` + +The directory structure of the converted dataset is as follows: + +```text +mmdetection +├── data +│ ├── coco +│ │ ├── annotations +│ │ │ ├── panoptic_train2017.json +│ │ │ ├── panoptic_train2017 +│ │ │ ├── panoptic_val2017.json +│ │ │ ├── panoptic_val2017 +│ │ │ ├── panoptic_semseg_train2017 +│ │ │ ├── panoptic_semseg_val2017 +│ │ ├── train2017 +│ │ ├── val2017 +│ │ ├── test2017 +``` + +`panoptic_semseg_train2017` and `panoptic_semseg_val2017` are the newly generated COCO semantic segmentation datasets that can be directly used for training and testing. Note that their category information is the same as that of COCO panoptic segmentation, including both "thing" and "stuff" categories. + +### RefCOCO Dataset Preparation + The images and annotations of [RefCOCO](https://github.com/lichengunc/refer) series datasets can be download by running `tools/misc/download_dataset.py`: ```shell -python tools/misc/download_dataset.py --dataset-name refcoco --save-dir data/refcoco --unzip +python tools/misc/download_dataset.py --dataset-name refcoco --save-dir data/coco --unzip ``` -Then the directory should be like this. +Then the directory should be like this: ```text data -├── refcoco -│   ├── refcoco +├── coco +│ ├── refcoco │   │   ├── instances.json │   │   ├── refs(google).p │   │   └── refs(unc).p @@ -99,3 +210,73 @@ data │   │   └── refs(umd).p | |── train2014 ``` + +### ADE20K 2016 Dataset Preparation + +The images and annotations of [ADE20K](https://groups.csail.mit.edu/vision/datasets/ADE20K/) dataset can be download by running `tools/misc/download_dataset.py`: + +```shell +python tools/misc/download_dataset.py --dataset-name ade20k_2016 --save-dir data --unzip +``` + +Then move the annotations to the `data/ADEChallengeData2016` directory and run the preprocess script to produce the coco format annotations: + +```shell +mv data/annotations_instance data/ADEChallengeData2016/ +mv data/categoryMapping.txt data/ADEChallengeData2016/ +mv data/imgCatIds.json data/ADEChallengeData2016/ +python tools/dataset_converters/ade20k2coco.py data/ADEChallengeData2016 --task panoptic +python tools/dataset_converters/ade20k2coco.py data/ADEChallengeData2016 --task instance +``` + +The directory should be like this. + +```text +data +├── ADEChallengeData2016 +│   ├── ade20k_instance_train.json +│   ├── ade20k_instance_val.json +│   ├── ade20k_panoptic_train +| | ├── ADE_train_00000001.png +| | ├── ADE_train_00000002.png +| | ├── ... +│   ├── ade20k_panoptic_train.json +│   ├── ade20k_panoptic_val +| | ├── ADE_val_00000001.png +| | ├── ADE_val_00000002.png +| | ├── ... +│   ├── ade20k_panoptic_val.json +│   ├── annotations +| | ├── training +| | | ├── ADE_train_00000001.png +| | | ├── ADE_train_00000002.png +| | | ├── ... +| | ├── validation +| | | ├── ADE_val_00000001.png +| | | ├── ADE_val_00000002.png +| | | ├── ... +│   ├── annotations_instance +| | ├── training +| | | ├── ADE_train_00000001.png +| | | ├── ADE_train_00000002.png +| | | ├── ... +| | ├── validation +| | | ├── ADE_val_00000001.png +| | | ├── ADE_val_00000002.png +| | | ├── ... +│   ├── categoryMapping.txt +│   ├── images +│   | ├── training +| | | ├── ADE_train_00000001.jpg +| | | ├── ADE_train_00000002.jpg +| | | ├── ... +| | ├── validation +| | | ├── ADE_val_00000001.jpg +| | | ├── ADE_val_00000002.jpg +| | | ├── ... +│   ├── imgCatIds.json +│   ├── objectInfo150.txt +| |── sceneCategories.txt +``` + +The above folders include all data of ADE20K's semantic segmentation, instance segmentation, and panoptic segmentation. diff --git a/docs/zh_cn/user_guides/dataset_prepare.md b/docs/zh_cn/user_guides/dataset_prepare.md index b33ec3bd309..376008bfee2 100644 --- a/docs/zh_cn/user_guides/dataset_prepare.md +++ b/docs/zh_cn/user_guides/dataset_prepare.md @@ -1,5 +1,7 @@ ## 数据集准备 +### 基础检测数据集准备 + MMDetection 支持多个公共数据集,包括 [COCO](https://cocodataset.org/), [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC), [Cityscapes](https://www.cityscapes-dataset.com/) 和 [其他更多数据集](https://github.com/open-mmlab/mmdetection/tree/main/configs/_base_/datasets)。 一些公共数据集,比如 Pascal VOC 及其镜像数据集,或者 COCO 等数据集都可以从官方网站或者镜像网站获取。注意:在检测任务中,Pascal VOC 2012 是 Pascal VOC 2007 的无交集扩展,我们通常将两者一起使用。 我们建议将数据集下载,然后解压到项目外部的某个文件夹内,然后通过符号链接的方式,将数据集根目录链接到 `$MMDETECTION/data` 文件夹下, 如果你的文件夹结构和下方不同的话,你需要在配置文件中改变对应的路径。 @@ -71,3 +73,207 @@ python tools/dataset_converters/cityscapes.py \ --nproc 8 \ --out-dir ./data/cityscapes/annotations ``` + +### COCO Caption 数据集准备 + +COCO Caption 采用的是 COCO2014 数据集作为图片,并且使用了 karpathy 的标注, + +首先你需要下载 COCO2014 数据集 + +```shell +python tools/misc/download_dataset.py --dataset-name coco2014 --unzip +``` + +数据集会下载到当前路径的 `data/coco` 下。然后下载 karpathy 的标注 + +```shell +cd data/coco/annotations +wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json +wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json +wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json +wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json +wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json +``` + +最终直接可用于训练和测试的数据集文件夹结构如下: + +```text +mmdetection +├── data +│ ├── coco +│ │ ├── annotations +│ │ │ ├── coco_karpathy_train.json +│ │ │ ├── coco_karpathy_test.json +│ │ │ ├── coco_karpathy_val.json +│ │ │ ├── coco_karpathy_val_gt.json +│ │ │ ├── coco_karpathy_test_gt.json +│ │ ├── train2014 +│ │ ├── val2014 +│ │ ├── test2014 +``` + +### COCO semantic 数据集准备 + +COCO 语义分割有两种类型标注,主要差别在于类别名定义不一样,因此处理方式也有两种,第一种是直接使用 stuffthingmaps 数据集,第二种是使用 panoptic 数据集。 + +**(1) 使用 stuffthingmaps 数据集** + +该数据集的下载地址为 [stuffthingmaps_trainval2017](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip),请下载后解压到 `data/coco` 文件夹下。 + +```text +mmdetection +├── data +│ ├── coco +│ │ ├── annotations +│ │ ├── train2017 +│ │ ├── val2017 +│ │ ├── test2017 +│ │ ├── stuffthingmaps +``` + +该数据集不同于标准的 COCO 类别标注,其包括 172 个类: 80 thing 类、91 stuff 类和 1 个 'unlabeled',其每个类别的说明见 https://github.com/nightrome/cocostuff/blob/master/labels.md + +虽然只标注了 172 个类别,但是 `stuffthingmaps` 中最大标签 id 是 182,中间有些类别是没有标注的,并且第 0 类的 `unlabeled` 类别被移除。因此最终的 `stuffthingmaps` 图片中每个位置的值对应的类别关系见 https://github.com/kazuto1011/deeplab-pytorch/blob/master/data/datasets/cocostuff/labels.txt + +考虑到训练高效和方便用户,在开启训练或者评估前,我们需要将没有标注的 12 个类移除,这 12 个类的名字为: `street sign、hat、shoe、eye glasses、plate、mirror、window、desk、door、blender、hair brush`,最终可用于训练和评估的类别信息见 `mmdet/datasets/coco_semantic.py` + +你可以使用 `tools/dataset_converters/coco_stuff164k.py` 来完成将下载的 `stuffthingmaps` 转换为直接可以训练和评估的数据集,转换后的数据集文件夹结构如下: + +```text +mmdetection +├── data +│ ├── coco +│ │ ├── annotations +│ │ ├── train2017 +│ │ ├── val2017 +│ │ ├── test2017 +│ │ ├── stuffthingmaps +│ │ ├── stuffthingmaps_semseg +``` + +`stuffthingmaps_semseg` 即为新生成的可以直接训练和测试的 COCO 语义分割数据集。 + +**(2) 使用 panoptic 数据集** + +通过 panoptic 标注生成的语义分割数据集类别数相比使用 `stuffthingmaps` 数据集生成的会少一些。首先你需要准备全景分割标注,然后使用如下脚本完成转换 + +```shell +python tools/dataset_converters/prepare_coco_semantic_annos_from_panoptic_annos.py data/coco +``` + +转换后的数据集文件夹结构如下: + +```text +mmdetection +├── data +│ ├── coco +│ │ ├── annotations +│ │ │ ├── panoptic_train2017.json +│ │ │ ├── panoptic_train2017 +│ │ │ ├── panoptic_val2017.json +│ │ │ ├── panoptic_val2017 +│ │ │ ├── panoptic_semseg_train2017 +│ │ │ ├── panoptic_semseg_val2017 +│ │ ├── train2017 +│ │ ├── val2017 +│ │ ├── test2017 +``` + +`panoptic_semseg_train2017` 和 `panoptic_semseg_val2017` 即为新生成的可以直接训练和测试的 COCO 语义分割数据集。注意其类别信息就是 COCO 全景分割的类别信息,包括 thing 和 stuff。 + +### RefCOCO 数据集准备 + +[RefCOCO](https://github.com/lichengunc/refer)系列数据集的图像和注释可以通过运行 `tools/misc/download_dataset.py` 下载: + +```shell +python tools/misc/download_dataset.py --dataset-name refcoco --save-dir data/coco --unzip +``` + +然后,目录应该是这样的: + +```text +data +├── coco +│ ├── refcoco +│   │   ├── instances.json +│   │   ├── refs(google).p +│   │   └── refs(unc).p +│   ├── refcoco+ +│   │   ├── instances.json +│   │   └── refs(unc).p +│   ├── refcocog +│   │   ├── instances.json +│   │   ├── refs(google).p +│   │   └── refs(umd).p +| |── train2014 +``` + +### ADE20K 数据集准备 + +[ADE20K](http://groups.csail.mit.edu/vision/datasets/ADE20K/)数据集的图像和注释可以通过运行 `tools/misc/download_dataset.py` 下载: + +```shell +python tools/misc/download_dataset.py --dataset-name ade20k_2016 --save-dir data --unzip +``` + +然后将注释移至`data/ADEChallengeData2016`目录,并运行预处理脚本以产生coco格式注释: + +```shell +mv data/annotations_instance data/ADEChallengeData2016/ +mv data/categoryMapping.txt data/ADEChallengeData2016/ +mv data/imgCatIds.json data/ADEChallengeData2016/ +python tools/dataset_converters/ade20k2coco.py data/ADEChallengeData2016 --task panoptic +python tools/dataset_converters/ade20k2coco.py data/ADEChallengeData2016 --task instance +``` + +然后,目录应该是这样的: + +```text +data +├── ADEChallengeData2016 +│   ├── ade20k_instance_train.json +│   ├── ade20k_instance_val.json +│   ├── ade20k_panoptic_train +| | ├── ADE_train_00000001.png +| | ├── ADE_train_00000002.png +| | ├── ... +│   ├── ade20k_panoptic_train.json +│   ├── ade20k_panoptic_val +| | ├── ADE_val_00000001.png +| | ├── ADE_val_00000002.png +| | ├── ... +│   ├── ade20k_panoptic_val.json +│   ├── annotations +| | ├── training +| | | ├── ADE_train_00000001.png +| | | ├── ADE_train_00000002.png +| | | ├── ... +| | ├── validation +| | | ├── ADE_val_00000001.png +| | | ├── ADE_val_00000002.png +| | | ├── ... +│   ├── annotations_instance +| | ├── training +| | | ├── ADE_train_00000001.png +| | | ├── ADE_train_00000002.png +| | | ├── ... +| | ├── validation +| | | ├── ADE_val_00000001.png +| | | ├── ADE_val_00000002.png +| | | ├── ... +│   ├── categoryMapping.txt +│   ├── images +│   | ├── training +| | | ├── ADE_train_00000001.jpg +| | | ├── ADE_train_00000002.jpg +| | | ├── ... +| | ├── validation +| | | ├── ADE_val_00000001.jpg +| | | ├── ADE_val_00000002.jpg +| | | ├── ... +│   ├── imgCatIds.json +│   ├── objectInfo150.txt +| |── sceneCategories.txt +``` + +上述文件夹包括ADE20K的语义分割、实例分割和泛在分割的所有数据。 diff --git a/mmdet/apis/det_inferencer.py b/mmdet/apis/det_inferencer.py index da4ad171283..b0af7b753e5 100644 --- a/mmdet/apis/det_inferencer.py +++ b/mmdet/apis/det_inferencer.py @@ -270,7 +270,16 @@ def _get_chunk_data(self, inputs: Iterable, chunk_size: int): chunk_data = [] for _ in range(chunk_size): inputs_ = next(inputs_iter) - chunk_data.append((inputs_, self.pipeline(inputs_))) + if isinstance(inputs_, dict): + if 'img' in inputs_: + ori_inputs_ = inputs_['img'] + else: + ori_inputs_ = inputs_['img_path'] + chunk_data.append( + (ori_inputs_, + self.pipeline(copy.deepcopy(inputs_)))) + else: + chunk_data.append((inputs_, self.pipeline(inputs_))) yield chunk_data except StopIteration: if chunk_data: @@ -280,20 +289,27 @@ def _get_chunk_data(self, inputs: Iterable, chunk_size: int): # TODO: Video and Webcam are currently not supported and # may consume too much memory if your input folder has a lot of images. # We will be optimized later. - def __call__(self, - inputs: InputsType, - batch_size: int = 1, - return_vis: bool = False, - show: bool = False, - wait_time: int = 0, - no_save_vis: bool = False, - draw_pred: bool = True, - pred_score_thr: float = 0.3, - return_datasample: bool = False, - print_result: bool = False, - no_save_pred: bool = True, - out_dir: str = '', - **kwargs) -> dict: + def __call__( + self, + inputs: InputsType, + batch_size: int = 1, + return_vis: bool = False, + show: bool = False, + wait_time: int = 0, + no_save_vis: bool = False, + draw_pred: bool = True, + pred_score_thr: float = 0.3, + return_datasample: bool = False, + print_result: bool = False, + no_save_pred: bool = True, + out_dir: str = '', + # by open image task + texts: Optional[Union[str, list]] = None, + # by open panoptic task + stuff_texts: Optional[Union[str, list]] = None, + # by GLIP + custom_entities: bool = False, + **kwargs) -> dict: """Call the inferencer. Args: @@ -317,7 +333,11 @@ def __call__(self, out_file: Dir to save the inference results or visualization. If left as empty, no file will be saved. Defaults to ''. - + texts (str | list[str]): Text prompts. Defaults to None. + stuff_texts (str | list[str]): Stuff text prompts of open + panoptic task. Defaults to None. + custom_entities (bool): Whether to use custom entities. + Defaults to False. Only used in GLIP. **kwargs: Other keyword arguments passed to :meth:`preprocess`, :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. Each key in kwargs should be in the corresponding set of @@ -335,14 +355,39 @@ def __call__(self, ) = self._dispatch_kwargs(**kwargs) ori_inputs = self._inputs_to_list(inputs) + + if texts is not None and isinstance(texts, str): + texts = [texts] * len(ori_inputs) + if stuff_texts is not None and isinstance(stuff_texts, str): + stuff_texts = [stuff_texts] * len(ori_inputs) + if texts is not None: + assert len(texts) == len(ori_inputs) + for i in range(len(texts)): + if isinstance(ori_inputs[i], str): + ori_inputs[i] = { + 'text': texts[i], + 'img_path': ori_inputs[i], + 'custom_entities': custom_entities + } + else: + ori_inputs[i] = { + 'text': texts[i], + 'img': ori_inputs[i], + 'custom_entities': custom_entities + } + if stuff_texts is not None: + assert len(stuff_texts) == len(ori_inputs) + for i in range(len(stuff_texts)): + ori_inputs[i]['stuff_text'] = stuff_texts[i] + inputs = self.preprocess( ori_inputs, batch_size=batch_size, **preprocess_kwargs) results_dict = {'predictions': [], 'visualization': []} - for ori_inputs, data in track(inputs, description='Inference'): + for ori_imgs, data in track(inputs, description='Inference'): preds = self.forward(data, **forward_kwargs) visualization = self.visualize( - ori_inputs, + ori_imgs, preds, return_vis=return_vis, show=show, @@ -551,12 +596,14 @@ def pred2dict(self, masks = data_sample.pred_instances.get('masks') pred_instances = data_sample.pred_instances.numpy() result = { - 'bboxes': pred_instances.bboxes.tolist(), 'labels': pred_instances.labels.tolist(), 'scores': pred_instances.scores.tolist() } + if 'bboxes' in pred_instances: + result['bboxes'] = pred_instances.bboxes.tolist() if masks is not None: - if pred_instances.bboxes.sum() == 0: + if 'bboxes' not in pred_instances or pred_instances.bboxes.sum( + ) == 0: # Fake bbox, such as the SOLO. bboxes = mask2bbox(masks.cpu()).numpy().tolist() result['bboxes'] = bboxes diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py index 7d347ae4ad9..5f398c08a3a 100644 --- a/mmdet/apis/inference.py +++ b/mmdet/apis/inference.py @@ -172,7 +172,7 @@ def inference_detector( data_ = dict(img_path=img, img_id=0) if text_prompt: - data_['caption'] = text_prompt + data_['text'] = text_prompt data_['custom_entities'] = custom_entities # build the data pipeline diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index 78074823d6f..303ea81a32b 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -1,12 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .ade20k import ADE20KDataset, ADE20KPanopticDataset +from .ade20k import (ADE20KInstanceDataset, ADE20KPanopticDataset, + ADE20KSegDataset) from .base_det_dataset import BaseDetDataset from .base_semseg_dataset import BaseSegDataset from .base_video_dataset import BaseVideoDataset from .cityscapes import CityscapesDataset from .coco import CocoDataset -from .coco_caption import COCOCaptionDataset +from .coco_caption import CocoCaptionDataset from .coco_panoptic import CocoPanopticDataset +from .coco_semantic import CocoSegDataset from .crowdhuman import CrowdHumanDataset from .dataset_wrappers import MultiImageMixDataset from .deepfashion import DeepFashionDataset @@ -15,7 +17,7 @@ from .mot_challenge_dataset import MOTChallengeDataset from .objects365 import Objects365V1Dataset, Objects365V2Dataset from .openimages import OpenImagesChallengeDataset, OpenImagesDataset -from .refcoco import RefCOCODataset +from .refcoco import RefCocoDataset from .reid_dataset import ReIDDataset from .samplers import (AspectRatioBatchSampler, ClassAwareSampler, GroupMultiSourceSampler, MultiSourceSampler, @@ -36,6 +38,7 @@ 'Objects365V1Dataset', 'Objects365V2Dataset', 'DSDLDetDataset', 'BaseVideoDataset', 'MOTChallengeDataset', 'TrackImgSampler', 'ReIDDataset', 'YouTubeVISDataset', 'TrackAspectRatioBatchSampler', - 'ADE20KPanopticDataset', 'COCOCaptionDataset', 'RefCOCODataset', - 'BaseSegDataset', 'ADE20KDataset' + 'ADE20KPanopticDataset', 'CocoCaptionDataset', 'RefCocoDataset', + 'BaseSegDataset', 'ADE20KSegDataset', 'CocoSegDataset', + 'ADE20KInstanceDataset' ] diff --git a/mmdet/datasets/ade20k.py b/mmdet/datasets/ade20k.py index dd49481a55e..573271cb5d0 100644 --- a/mmdet/datasets/ade20k.py +++ b/mmdet/datasets/ade20k.py @@ -6,54 +6,93 @@ from mmdet.registry import DATASETS from .base_semseg_dataset import BaseSegDataset +from .coco import CocoDataset from .coco_panoptic import CocoPanopticDataset +ADE_PALETTE = [(120, 120, 120), (180, 120, 120), (6, 230, 230), (80, 50, 50), + (4, 200, 3), (120, 120, 80), (140, 140, 140), (204, 5, 255), + (230, 230, 230), (4, 250, 7), (224, 5, 255), (235, 255, 7), + (150, 5, 61), (120, 120, 70), (8, 255, 51), (255, 6, 82), + (143, 255, 140), (204, 255, 4), (255, 51, 7), (204, 70, 3), + (0, 102, 200), (61, 230, 250), (255, 6, 51), (11, 102, 255), + (255, 7, 71), (255, 9, 224), (9, 7, 230), (220, 220, 220), + (255, 9, 92), (112, 9, 255), (8, 255, 214), (7, 255, 224), + (255, 184, 6), (10, 255, 71), (255, 41, 10), (7, 255, 255), + (224, 255, 8), (102, 8, 255), (255, 61, 6), (255, 194, 7), + (255, 122, 8), (0, 255, 20), (255, 8, 41), (255, 5, 153), + (6, 51, 255), (235, 12, 255), (160, 150, 20), (0, 163, 255), + (140, 140, 140), (250, 10, 15), (20, 255, 0), (31, 255, 0), + (255, 31, 0), (255, 224, 0), (153, 255, 0), (0, 0, 255), + (255, 71, 0), (0, 235, 255), (0, 173, 255), (31, 0, 255), + (11, 200, 200), (255, 82, 0), (0, 255, 245), (0, 61, 255), + (0, 255, 112), (0, 255, 133), (255, 0, 0), (255, 163, 0), + (255, 102, 0), (194, 255, 0), (0, 143, 255), (51, 255, 0), + (0, 82, 255), (0, 255, 41), (0, 255, 173), (10, 0, 255), + (173, 255, 0), (0, 255, 153), (255, 92, 0), (255, 0, 255), + (255, 0, 245), (255, 0, 102), (255, 173, 0), (255, 0, 20), + (255, 184, 184), (0, 31, 255), (0, 255, 61), (0, 71, 255), + (255, 0, 204), (0, 255, 194), (0, 255, 82), (0, 10, 255), + (0, 112, 255), (51, 0, 255), (0, 194, 255), (0, 122, 255), + (0, 255, 163), (255, 153, 0), (0, 255, 10), (255, 112, 0), + (143, 255, 0), (82, 0, 255), (163, 255, 0), (255, 235, 0), + (8, 184, 170), (133, 0, 255), (0, 255, 92), (184, 0, 255), + (255, 0, 31), (0, 184, 255), (0, 214, 255), (255, 0, 112), + (92, 255, 0), (0, 224, 255), (112, 224, 255), (70, 184, 160), + (163, 0, 255), (153, 0, 255), (71, 255, 0), (255, 0, 163), + (255, 204, 0), (255, 0, 143), (0, 255, 235), (133, 255, 0), + (255, 0, 235), (245, 0, 255), (255, 0, 122), (255, 245, 0), + (10, 190, 212), (214, 255, 0), (0, 204, 255), (20, 0, 255), + (255, 255, 0), (0, 153, 255), (0, 41, 255), (0, 255, 204), + (41, 0, 255), (41, 255, 0), (173, 0, 255), (0, 245, 255), + (71, 0, 255), (122, 0, 255), (0, 255, 184), (0, 92, 255), + (184, 255, 0), (0, 133, 255), (255, 214, 0), (25, 194, 194), + (102, 255, 0), (92, 0, 255)] + @DATASETS.register_module() class ADE20KPanopticDataset(CocoPanopticDataset): METAINFO = { 'classes': - ('wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road, route', - 'bed', 'window ', 'grass', 'cabinet', 'sidewalk, pavement', 'person', - 'earth, ground', 'door', 'table', 'mountain, mount', 'plant', - 'curtain', 'chair', 'car', 'water', 'painting, picture', 'sofa', - 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', 'seat', - 'fence', 'desk', 'rock, stone', 'wardrobe, closet, press', 'lamp', - 'tub', 'rail', 'cushion', 'base, pedestal, stand', 'box', - 'column, pillar', 'signboard, sign', - 'chest of drawers, chest, bureau, dresser', 'counter', 'sand', 'sink', - 'skyscraper', 'fireplace', 'refrigerator, icebox', - 'grandstand, covered stand', 'path', 'stairs', 'runway', + ('bed', 'window', 'cabinet', 'person', 'door', 'table', 'curtain', + 'chair', 'car', 'painting, picture', 'sofa', 'shelf', 'mirror', + 'armchair', 'seat', 'fence', 'desk', 'wardrobe, closet, press', + 'lamp', 'tub', 'rail', 'cushion', 'box', 'column, pillar', + 'signboard, sign', 'chest of drawers, chest, bureau, dresser', + 'counter', 'sink', 'fireplace', 'refrigerator, icebox', 'stairs', 'case, display case, showcase, vitrine', 'pool table, billiard table, snooker table', 'pillow', - 'screen door, screen', 'stairway, staircase', 'river', 'bridge, span', - 'bookcase', 'blind, screen', 'coffee table', + 'screen door, screen', 'bookcase', 'coffee table', 'toilet, can, commode, crapper, pot, potty, stool, throne', 'flower', - 'book', 'hill', 'bench', 'countertop', 'stove', 'palm, palm tree', - 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', - 'arcade machine', 'hovel, hut, hutch, shack, shanty', 'bus', 'towel', - 'light', 'truck', 'tower', 'chandelier', 'awning, sunshade, sunblind', - 'street lamp', 'booth', 'tv', 'plane', 'dirt track', 'clothes', - 'pole', 'land, ground, soil', + 'book', 'bench', 'countertop', 'stove', 'palm, palm tree', + 'kitchen island', 'computer', 'swivel chair', 'boat', + 'arcade machine', 'bus', 'towel', 'light', 'truck', 'chandelier', + 'awning, sunshade, sunblind', 'street lamp', 'booth', 'tv', + 'airplane', 'clothes', 'pole', 'bannister, banister, balustrade, balusters, handrail', + 'ottoman, pouf, pouffe, puff, hassock', 'bottle', 'van', 'ship', + 'fountain', 'washer, automatic washer, washing machine', + 'plaything, toy', 'stool', 'barrel, cask', 'basket, handbasket', + 'bag', 'minibike, motorbike', 'oven', 'ball', 'food, solid food', + 'step, stair', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', + 'dishwasher', 'screen', 'sculpture', 'hood, exhaust hood', 'sconce', + 'vase', 'traffic light', 'tray', 'trash can', 'fan', 'plate', + 'monitor', 'bulletin board', 'radiator', 'glass, drinking glass', + 'clock', 'flag', 'wall', 'building', 'sky', 'floor', 'tree', + 'ceiling', 'road, route', 'grass', 'sidewalk, pavement', + 'earth, ground', 'mountain, mount', 'plant', 'water', 'house', 'sea', + 'rug', 'field', 'rock, stone', 'base, pedestal, stand', 'sand', + 'skyscraper', 'grandstand, covered stand', 'path', 'runway', + 'stairway, staircase', 'river', 'bridge, span', 'blind, screen', + 'hill', 'bar', 'hovel, hut, hutch, shack, shanty', 'tower', + 'dirt track', 'land, ground, soil', 'escalator, moving staircase, moving stairway', - 'ottoman, pouf, pouffe, puff, hassock', 'bottle', 'buffet, counter, sideboard', - 'poster, posting, placard, notice, bill, card', 'stage', 'van', - 'ship', 'fountain', - 'conveyor belt, conveyor belt, conveyor, conveyor, transporter', - 'canopy', 'washer, automatic washer, washing machine', - 'plaything, toy', 'pool', 'stool', 'barrel, cask', - 'basket, handbasket', 'falls', 'tent', 'bag', 'minibike, motorbike', - 'cradle', 'oven', 'ball', 'food, solid food', 'step, stair', - 'tank, storage tank', 'trade name', 'microwave', 'pot', 'animal', - 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket, cover', - 'sculpture', 'hood, exhaust hood', 'sconce', 'vase', 'traffic light', - 'tray', 'trash can', 'fan', 'pier', 'crt screen', 'plate', 'monitor', - 'bulletin board', 'shower', 'radiator', 'glass, drinking glass', - 'clock', 'flag'), + 'poster, posting, placard, notice, bill, card', 'stage', + 'conveyer belt, conveyor belt, conveyer, conveyor, transporter', + 'canopy', 'pool', 'falls', 'tent', 'cradle', 'tank, storage tank', + 'lake', 'blanket, cover', 'pier', 'crt screen', 'shower'), 'thing_classes': - ('bed', 'window ', 'cabinet', 'person', 'door', 'table', 'curtain', + ('bed', 'window', 'cabinet', 'person', 'door', 'table', 'curtain', 'chair', 'car', 'painting, picture', 'sofa', 'shelf', 'mirror', 'armchair', 'seat', 'fence', 'desk', 'wardrobe, closet, press', 'lamp', 'tub', 'rail', 'cushion', 'box', 'column, pillar', @@ -66,8 +105,8 @@ class ADE20KPanopticDataset(CocoPanopticDataset): 'book', 'bench', 'countertop', 'stove', 'palm, palm tree', 'kitchen island', 'computer', 'swivel chair', 'boat', 'arcade machine', 'bus', 'towel', 'light', 'truck', 'chandelier', - 'awning, sunshade, sunblind', 'street lamp', 'booth', 'tv', 'plane', - 'clothes', 'pole', + 'awning, sunshade, sunblind', 'street lamp', 'booth', 'tv', + 'airplane', 'clothes', 'pole', 'bannister, banister, balustrade, balusters, handrail', 'ottoman, pouf, pouffe, puff, hassock', 'bottle', 'van', 'ship', 'fountain', 'washer, automatic washer, washing machine', @@ -89,55 +128,66 @@ class ADE20KPanopticDataset(CocoPanopticDataset): 'land, ground, soil', 'escalator, moving staircase, moving stairway', 'buffet, counter, sideboard', 'poster, posting, placard, notice, bill, card', 'stage', - 'conveyor belt, conveyor belt, conveyor, conveyor, transporter', + 'conveyer belt, conveyor belt, conveyer, conveyor, transporter', 'canopy', 'pool', 'falls', 'tent', 'cradle', 'tank, storage tank', 'lake', 'blanket, cover', 'pier', 'crt screen', 'shower'), - 'palette': [[120, 120, 120], [180, 120, 120], [6, 230, 230], - [80, 50, 50], [4, 200, 3], [120, 120, 80], [140, 140, 140], - [204, 5, 255], [230, 230, 230], [4, 250, 7], [224, 5, 255], - [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51], - [255, 6, 82], [143, 255, 140], [204, 255, 4], [255, 51, 7], - [204, 70, 3], [0, 102, 200], [61, 230, 250], [255, 6, 51], - [11, 102, 255], [255, 7, 71], [255, 9, 224], [9, 7, 230], - [220, 220, 220], [255, 9, 92], - [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6], - [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], - [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8], - [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], - [235, 12, 255], [160, 150, 20], [0, 163, 255], - [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], - [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], - [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], - [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], - [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], - [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], - [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], - [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], - [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], - [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], - [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], - [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], - [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], - [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], - [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], - [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], - [92, 255, 0], [0, 224, 255], [112, 224, - 255], [70, 184, 160], - [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], - [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], - [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], - [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], - [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], - [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], - [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], - [184, 255, 0], [0, 133, 255], [255, 214, - 0], [25, 194, 194], - [102, 255, 0], [92, 0, 255]] + 'palette': + ADE_PALETTE + } + + +@DATASETS.register_module() +class ADE20KInstanceDataset(CocoDataset): + METAINFO = { + 'classes': + ('bed', 'windowpane', 'cabinet', 'person', 'door', 'table', 'curtain', + 'chair', 'car', 'painting', 'sofa', 'shelf', 'mirror', 'armchair', + 'seat', 'fence', 'desk', 'wardrobe', 'lamp', 'bathtub', 'railing', + 'cushion', 'box', 'column', 'signboard', 'chest of drawers', + 'counter', 'sink', 'fireplace', 'refrigerator', 'stairs', 'case', + 'pool table', 'pillow', 'screen door', 'bookcase', 'coffee table', + 'toilet', 'flower', 'book', 'bench', 'countertop', 'stove', 'palm', + 'kitchen island', 'computer', 'swivel chair', 'boat', + 'arcade machine', 'bus', 'towel', 'light', 'truck', 'chandelier', + 'awning', 'streetlight', 'booth', 'television receiver', 'airplane', + 'apparel', 'pole', 'bannister', 'ottoman', 'bottle', 'van', 'ship', + 'fountain', 'washer', 'plaything', 'stool', 'barrel', 'basket', 'bag', + 'minibike', 'oven', 'ball', 'food', 'step', 'trade name', 'microwave', + 'pot', 'animal', 'bicycle', 'dishwasher', 'screen', 'sculpture', + 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'ashcan', 'fan', + 'plate', 'monitor', 'bulletin board', 'radiator', 'glass', 'clock', + 'flag'), + 'palette': [(204, 5, 255), (230, 230, 230), (224, 5, 255), + (150, 5, 61), (8, 255, 51), (255, 6, 82), (255, 51, 7), + (204, 70, 3), (0, 102, 200), (255, 6, 51), (11, 102, 255), + (255, 7, 71), (220, 220, 220), (8, 255, 214), + (7, 255, 224), (255, 184, 6), (10, 255, 71), (7, 255, 255), + (224, 255, 8), (102, 8, 255), (255, 61, 6), (255, 194, 7), + (0, 255, 20), (255, 8, 41), (255, 5, 153), (6, 51, 255), + (235, 12, 255), (0, 163, 255), (250, 10, 15), (20, 255, 0), + (255, 224, 0), (0, 0, 255), (255, 71, 0), (0, 235, 255), + (0, 173, 255), (0, 255, 245), (0, 255, 112), (0, 255, 133), + (255, 0, 0), (255, 163, 0), (194, 255, 0), (0, 143, 255), + (51, 255, 0), (0, 82, 255), (0, 255, 41), (0, 255, 173), + (10, 0, 255), (173, 255, 0), (255, 92, 0), (255, 0, 245), + (255, 0, 102), (255, 173, 0), (255, 0, 20), (0, 31, 255), + (0, 255, 61), (0, 71, 255), (255, 0, 204), (0, 255, 194), + (0, 255, 82), (0, 112, 255), (51, 0, 255), (0, 122, 255), + (255, 153, 0), (0, 255, 10), (163, 255, 0), (255, 235, 0), + (8, 184, 170), (184, 0, 255), (255, 0, 31), (0, 214, 255), + (255, 0, 112), (92, 255, 0), (70, 184, 160), (163, 0, 255), + (71, 255, 0), (255, 0, 163), (255, 204, 0), (255, 0, 143), + (133, 255, 0), (255, 0, 235), (245, 0, 255), (255, 0, 122), + (255, 245, 0), (214, 255, 0), (0, 204, 255), (255, 255, 0), + (0, 153, 255), (0, 41, 255), (0, 255, 204), (41, 0, 255), + (41, 255, 0), (173, 0, 255), (0, 245, 255), (0, 255, 184), + (0, 92, 255), (184, 255, 0), (255, 214, 0), (25, 194, 194), + (102, 255, 0), (92, 0, 255)], } @DATASETS.register_module() -class ADE20KDataset(BaseSegDataset): +class ADE20KSegDataset(BaseSegDataset): """ADE20K dataset. In segmentation map annotation for ADE20K, 0 stands for background, which @@ -173,44 +223,7 @@ class ADE20KDataset(BaseSegDataset): 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag'), - palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], - [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], - [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], - [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], - [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], - [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], - [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], - [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], - [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], - [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], - [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], - [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], - [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], - [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], - [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], - [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], - [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], - [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], - [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], - [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], - [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], - [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], - [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], - [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], - [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], - [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], - [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], - [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], - [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], - [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], - [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], - [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], - [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], - [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], - [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], - [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], - [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], - [102, 255, 0], [92, 0, 255]]) + palette=ADE_PALETTE) def __init__(self, img_suffix='.jpg', @@ -241,7 +254,6 @@ def load_data_list(self) -> List[dict]: seg_map = img.replace(self.img_suffix, self.seg_map_suffix) data_info['seg_map_path'] = osp.join(ann_dir, seg_map) data_info['label_map'] = self.label_map - data_info['seg_fields'] = [] if self.return_classes: data_info['text'] = list(self._metainfo['classes']) data_list.append(data_info) diff --git a/mmdet/datasets/base_det_dataset.py b/mmdet/datasets/base_det_dataset.py index cf110bc7a02..57bc7098387 100644 --- a/mmdet/datasets/base_det_dataset.py +++ b/mmdet/datasets/base_det_dataset.py @@ -19,6 +19,8 @@ class BaseDetDataset(BaseDataset): corresponding backend in mmdet <= 3.0.0rc6. Defaults to None. backend_args (dict, optional): Arguments to instantiate the corresponding backend. Defaults to None. + return_classes (bool): Whether to return class information + for open vocabulary-based algorithms. Defaults to False. """ def __init__(self, @@ -27,12 +29,12 @@ def __init__(self, proposal_file: Optional[str] = None, file_client_args: dict = None, backend_args: dict = None, - return_caption: Optional[bool] = False, + return_classes: bool = False, **kwargs) -> None: self.seg_map_suffix = seg_map_suffix self.proposal_file = proposal_file self.backend_args = backend_args - self.return_caption = return_caption + self.return_classes = return_classes if file_client_args is not None: raise RuntimeError( 'The `file_client_args` is deprecated, ' diff --git a/mmdet/datasets/base_semseg_dataset.py b/mmdet/datasets/base_semseg_dataset.py index e0ef56f043d..d10f762a21a 100644 --- a/mmdet/datasets/base_semseg_dataset.py +++ b/mmdet/datasets/base_semseg_dataset.py @@ -67,13 +67,15 @@ class BaseSegDataset(BaseDataset): information of the dataset is needed, which is not necessary to load annotation file. ``Basedataset`` can skip load annotations to save time by set ``lazy_init=True``. Defaults to False. + use_label_map (bool, optional): Whether to use label map. + Defaults to False. max_refetch (int, optional): If ``Basedataset.prepare_data`` get a None img. The maximum extra number of cycles to get a valid image. Defaults to 1000. backend_args (dict, Optional): Arguments to instantiate a file backend. See https://mmengine.readthedocs.io/en/latest/api/fileio.htm for details. Defaults to None. - Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + Notes: mmcv>=2.0.0rc4 required. """ METAINFO: dict = dict() @@ -90,6 +92,7 @@ def __init__(self, pipeline: List[Union[dict, Callable]] = [], test_mode: bool = False, lazy_init: bool = False, + use_label_map: bool = False, max_refetch: int = 1000, backend_args: Optional[dict] = None) -> None: @@ -113,7 +116,8 @@ def __init__(self, # Get label map for custom classes new_classes = self._metainfo.get('classes', None) - self.label_map = self.get_label_map(new_classes) + self.label_map = self.get_label_map( + new_classes) if use_label_map else None self._metainfo.update(dict(label_map=self.label_map)) # Update palette based on label map or generate palette @@ -213,6 +217,9 @@ def _update_palette(self) -> list: if new_id != 0: new_palette.append(palette[old_id]) new_palette = type(palette)(new_palette) + elif len(palette) >= len(classes): + # Allow palette length is greater than classes. + return palette else: raise ValueError('palette does not match classes ' f'as metainfo is {self._metainfo}.') diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py index 1e6205473b7..277b75988da 100644 --- a/mmdet/datasets/coco.py +++ b/mmdet/datasets/coco.py @@ -127,8 +127,8 @@ def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]: data_info['height'] = img_info['height'] data_info['width'] = img_info['width'] - if self.return_caption: - data_info['caption'] = self.metainfo['classes'] + if self.return_classes: + data_info['text'] = self.metainfo['classes'] data_info['custom_entities'] = True instances = [] diff --git a/mmdet/datasets/coco_caption.py b/mmdet/datasets/coco_caption.py index e5af1ec59a6..ee695fe9a76 100644 --- a/mmdet/datasets/coco_caption.py +++ b/mmdet/datasets/coco_caption.py @@ -10,18 +10,8 @@ @DATASETS.register_module() -class COCOCaptionDataset(BaseDataset): - """COCO Caption dataset. - - Args: - data_root (str): The root directory for ``data_prefix`` and - ``ann_file``.. - ann_file (str): Annotation file path. - data_prefix (dict): Prefix for data field. Defaults to - ``dict(img_path='')``. - pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. - **kwargs: Other keyword arguments in :class:`BaseDataset`. - """ +class CocoCaptionDataset(BaseDataset): + """COCO2014 Caption dataset.""" def load_data_list(self) -> List[dict]: """Load data list.""" diff --git a/mmdet/datasets/coco_panoptic.py b/mmdet/datasets/coco_panoptic.py index 33d4189e6c4..d5ca7855509 100644 --- a/mmdet/datasets/coco_panoptic.py +++ b/mmdet/datasets/coco_panoptic.py @@ -217,6 +217,11 @@ def parse_data_info(self, raw_data_info: dict) -> dict: data_info['height'] = img_info['height'] data_info['width'] = img_info['width'] + if self.return_classes: + data_info['text'] = self.metainfo['thing_classes'] + data_info['stuff_text'] = self.metainfo['stuff_classes'] + data_info['custom_entities'] = True # no important + instances = [] segments_info = [] for ann in ann_info: diff --git a/mmdet/datasets/coco_semantic.py b/mmdet/datasets/coco_semantic.py new file mode 100644 index 00000000000..75256845445 --- /dev/null +++ b/mmdet/datasets/coco_semantic.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import DATASETS +from .ade20k import ADE20KSegDataset + + +@DATASETS.register_module() +class CocoSegDataset(ADE20KSegDataset): + """COCO dataset. + + In segmentation map annotation for COCO. The ``img_suffix`` is fixed to + '.jpg', and ``seg_map_suffix`` is fixed to '.png'. + """ + + METAINFO = dict( + classes=( + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', + 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', + 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', + 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', + 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet', + 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile', + 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain', + 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble', + 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower', + 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel', + 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal', + 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', + 'paper', 'pavement', 'pillow', 'plant-other', 'plastic', + 'platform', 'playingfield', 'railing', 'railroad', 'river', 'road', + 'rock', 'roof', 'rug', 'salad', 'sand', 'sea', 'shelf', + 'sky-other', 'skyscraper', 'snow', 'solid-other', 'stairs', + 'stone', 'straw', 'structural-other', 'table', 'tent', + 'textile-other', 'towel', 'tree', 'vegetable', 'wall-brick', + 'wall-concrete', 'wall-other', 'wall-panel', 'wall-stone', + 'wall-tile', 'wall-wood', 'water-other', 'waterdrops', + 'window-blind', 'window-other', 'wood'), + palette=[(120, 120, 120), (180, 120, 120), (6, 230, 230), (80, 50, 50), + (4, 200, 3), (120, 120, 80), (140, 140, 140), (204, 5, 255), + (230, 230, 230), (4, 250, 7), (224, 5, 255), (235, 255, 7), + (150, 5, 61), (120, 120, 70), (8, 255, 51), (255, 6, 82), + (143, 255, 140), (204, 255, 4), (255, 51, 7), (204, 70, 3), + (0, 102, 200), (61, 230, 250), (255, 6, 51), (11, 102, 255), + (255, 7, 71), (255, 9, 224), (9, 7, 230), (220, 220, 220), + (255, 9, 92), (112, 9, 255), (8, 255, 214), (7, 255, 224), + (255, 184, 6), (10, 255, 71), (255, 41, 10), (7, 255, 255), + (224, 255, 8), (102, 8, 255), (255, 61, 6), (255, 194, 7), + (255, 122, 8), (0, 255, 20), (255, 8, 41), (255, 5, 153), + (6, 51, 255), (235, 12, 255), (160, 150, 20), (0, 163, 255), + (140, 140, 140), (250, 10, 15), (20, 255, 0), (31, 255, 0), + (255, 31, 0), (255, 224, 0), (153, 255, 0), (0, 0, 255), + (255, 71, 0), (0, 235, 255), (0, 173, 255), (31, 0, 255), + (11, 200, 200), (255, 82, 0), (0, 255, 245), (0, 61, 255), + (0, 255, 112), (0, 255, 133), (255, 0, 0), (255, 163, 0), + (255, 102, 0), (194, 255, 0), (0, 143, 255), (51, 255, 0), + (0, 82, 255), (0, 255, 41), (0, 255, 173), (10, 0, 255), + (173, 255, 0), (0, 255, 153), (255, 92, 0), (255, 0, 255), + (255, 0, 245), (255, 0, 102), (255, 173, 0), (255, 0, 20), + (255, 184, 184), (0, 31, 255), (0, 255, 61), (0, 71, 255), + (255, 0, 204), (0, 255, 194), (0, 255, 82), (0, 10, 255), + (0, 112, 255), (51, 0, 255), (0, 194, 255), (0, 122, 255), + (0, 255, 163), (255, 153, 0), (0, 255, 10), (255, 112, 0), + (143, 255, 0), (82, 0, 255), (163, 255, 0), (255, 235, 0), + (8, 184, 170), (133, 0, 255), (0, 255, 92), (184, 0, 255), + (255, 0, 31), (0, 184, 255), (0, 214, 255), (255, 0, 112), + (92, 255, 0), (0, 224, 255), (112, 224, 255), (70, 184, 160), + (163, 0, 255), (153, 0, 255), (71, 255, 0), (255, 0, 163), + (255, 204, 0), (255, 0, 143), (0, 255, 235), (133, 255, 0), + (255, 0, 235), (245, 0, 255), (255, 0, 122), (255, 245, 0), + (10, 190, 212), (214, 255, 0), (0, 204, 255), (20, 0, 255), + (255, 255, 0), (0, 153, 255), (0, 41, 255), (0, 255, 204), + (41, 0, 255), (41, 255, 0), (173, 0, 255), (0, 245, 255), + (71, 0, 255), (122, 0, 255), (0, 255, 184), (0, 92, 255), + (184, 255, 0), (0, 133, 255), (255, 214, 0), (25, 194, 194), + (102, 255, 0), (92, 0, 255), (107, 255, 200), (58, 41, 149), + (183, 121, 142), (255, 73, 97), (107, 142, 35), + (190, 153, 153), (146, 139, 141), (70, 130, 180), + (134, 199, 156), (209, 226, 140), (96, 36, 108), (96, 96, 96), + (64, 170, 64), (152, 251, 152), (208, 229, 228), + (206, 186, 171), (152, 161, 64), (116, 112, 0), (0, 114, 143), + (102, 102, 156), (250, 141, 255)]) diff --git a/mmdet/datasets/refcoco.py b/mmdet/datasets/refcoco.py index ce95e04e171..0dae75fd547 100644 --- a/mmdet/datasets/refcoco.py +++ b/mmdet/datasets/refcoco.py @@ -1,17 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. +import collections import os.path as osp -from typing import List +import random +from typing import Dict, List import mmengine -import numpy as np from mmengine.dataset import BaseDataset -from pycocotools.coco import COCO from mmdet.registry import DATASETS @DATASETS.register_module() -class RefCOCODataset(BaseDataset): +class RefCocoDataset(BaseDataset): """RefCOCO dataset. The `Refcoco` and `Refcoco+` dataset is based on @@ -29,19 +29,23 @@ class RefCOCODataset(BaseDataset): data_prefix (str): Prefix for training data. split_file (str): Split file path. split (str): Split name. Defaults to 'train'. + text_mode (str): Text mode. Defaults to 'random'. **kwargs: Other keyword arguments in :class:`BaseDataset`. """ def __init__(self, - data_root, - ann_file, - data_prefix, - split_file, - split='train', + data_root: str, + ann_file: str, + split_file: str, + data_prefix: Dict, + split: str = 'train', + text_mode: str = 'random', **kwargs): self.split_file = split_file self.split = split + assert text_mode in ['original', 'random', 'concat', 'select_first'] + self.text_mode = text_mode super().__init__( data_root=data_root, data_prefix=data_prefix, @@ -55,36 +59,103 @@ def _join_prefix(self): return super()._join_prefix() + def _init_refs(self): + """Initialize the refs for RefCOCO.""" + anns, imgs = {}, {} + for ann in self.instances['annotations']: + anns[ann['id']] = ann + for img in self.instances['images']: + imgs[img['id']] = img + + refs, ref_to_ann = {}, {} + for ref in self.splits: + # ids + ref_id = ref['ref_id'] + ann_id = ref['ann_id'] + # add mapping related to ref + refs[ref_id] = ref + ref_to_ann[ref_id] = anns[ann_id] + + self.refs = refs + self.ref_to_ann = ref_to_ann + def load_data_list(self) -> List[dict]: """Load data list.""" - with mmengine.get_local_path(self.ann_file) as ann_file: - coco = COCO(ann_file) - splits = mmengine.load(self.split_file, file_format='pkl') + self.splits = mmengine.load(self.split_file, file_format='pkl') + self.instances = mmengine.load(self.ann_file, file_format='json') + self._init_refs() img_prefix = self.data_prefix['img_path'] + ref_ids = [ + ref['ref_id'] for ref in self.splits if ref['split'] == self.split + ] + full_anno = [] + for ref_id in ref_ids: + ref = self.refs[ref_id] + ann = self.ref_to_ann[ref_id] + ann.update(ref) + full_anno.append(ann) + + image_id_list = [] + final_anno = {} + for anno in full_anno: + image_id_list.append(anno['image_id']) + final_anno[anno['ann_id']] = anno + annotations = [value for key, value in final_anno.items()] + + coco_train_id = [] + image_annot = {} + for i in range(len(self.instances['images'])): + coco_train_id.append(self.instances['images'][i]['id']) + image_annot[self.instances['images'][i] + ['id']] = self.instances['images'][i] + + images = [] + for image_id in list(set(image_id_list)): + images += [image_annot[image_id]] + data_list = [] + + grounding_dict = collections.defaultdict(list) + for anno in annotations: + image_id = int(anno['image_id']) + grounding_dict[image_id].append(anno) + join_path = mmengine.fileio.get_file_backend(img_prefix).join_path - for refer in splits: - if refer['split'] != self.split: - continue - - ann = coco.anns[refer['ann_id']] - img = coco.imgs[ann['image_id']] - sentences = refer['sentences'] - bbox = np.array(ann['bbox'], dtype=np.float32) - bbox[2:4] = bbox[0:2] + bbox[2:4] # XYWH -> XYXY - mask = np.array(ann['segmentation'], dtype=np.float32) - - for sent in sentences: - data_info = { - 'img_path': join_path(img_prefix, img['file_name']), - 'image_id': ann['image_id'], - 'ann_id': ann['id'], - 'text': sent['sent'], - 'gt_bboxes': bbox[None, :], - 'gt_masks': mask[None, :], - } - data_list.append(data_info) + for image in images: + img_id = image['id'] + instances = [] + sentences = [] + for grounding_anno in grounding_dict[img_id]: + texts = [x['raw'].lower() for x in grounding_anno['sentences']] + # random select one text + if self.text_mode == 'random': + idx = random.randint(0, len(texts) - 1) + text = [texts[idx]] + # concat all texts + elif self.text_mode == 'concat': + text = [''.join(texts)] + # select the first text + elif self.text_mode == 'select_first': + text = [texts[0]] + # use all texts + elif self.text_mode == 'original': + text = texts + else: + raise ValueError(f'Invalid text mode "{self.text_mode}".') + ins = [{ + 'mask': grounding_anno['segmentation'], + 'ignore_flag': 0 + }] * len(text) + instances.extend(ins) + sentences.extend(text) + data_info = { + 'img_path': join_path(img_prefix, image['file_name']), + 'img_id': img_id, + 'instances': instances, + 'text': sentences + } + data_list.append(data_info) if len(data_list) == 0: raise ValueError(f'No sample in split "{self.split}".') diff --git a/mmdet/datasets/transforms/__init__.py b/mmdet/datasets/transforms/__init__.py index 9892f61891f..b5ab3758382 100644 --- a/mmdet/datasets/transforms/__init__.py +++ b/mmdet/datasets/transforms/__init__.py @@ -12,15 +12,14 @@ from .loading import (FilterAnnotations, InferencerLoader, LoadAnnotations, LoadEmptyAnnotations, LoadImageFromNDArray, LoadMultiChannelImageFromFiles, LoadPanopticAnnotations, - LoadProposals, LoadSemSegAnnotations, - LoadTrackAnnotations) + LoadProposals, LoadTrackAnnotations) from .transforms import (Albu, CachedMixUp, CachedMosaic, CopyPaste, CutOut, Expand, FixScaleResize, FixShapeResize, MinIoURandomCrop, MixUp, Mosaic, Pad, PhotoMetricDistortion, RandomAffine, RandomCenterCropPad, RandomCrop, RandomErasing, - RandomFlip, RandomShift, Resize, SegRescale, - YOLOXHSVRandomAug) + RandomFlip, RandomShift, Resize, ResizeShortestEdge, + SegRescale, YOLOXHSVRandomAug) from .wrappers import MultiBranch, ProposalBroadcaster, RandomOrder __all__ = [ @@ -38,6 +37,5 @@ 'LoadEmptyAnnotations', 'RandomOrder', 'CachedMosaic', 'CachedMixUp', 'FixShapeResize', 'ProposalBroadcaster', 'InferencerLoader', 'LoadTrackAnnotations', 'BaseFrameSample', 'UniformRefFrameSample', - 'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize', - 'LoadSemSegAnnotations' + 'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize', 'ResizeShortestEdge' ] diff --git a/mmdet/datasets/transforms/formatting.py b/mmdet/datasets/transforms/formatting.py index 58d0b612f92..83fada30b1f 100644 --- a/mmdet/datasets/transforms/formatting.py +++ b/mmdet/datasets/transforms/formatting.py @@ -125,7 +125,11 @@ def transform(self, results: dict) -> dict: if 'gt_seg_map' in results: gt_sem_seg_data = dict( sem_seg=to_tensor(results['gt_seg_map'][None, ...].copy())) - data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data) + gt_sem_seg_data = PixelData(**gt_sem_seg_data) + if 'ignore_index' in results: + metainfo = dict(ignore_index=results['ignore_index']) + gt_sem_seg_data.set_metainfo(metainfo) + data_sample.gt_sem_seg = gt_sem_seg_data img_meta = {} for key in self.meta_keys: diff --git a/mmdet/datasets/transforms/loading.py b/mmdet/datasets/transforms/loading.py index c7db404f1e3..95945a82d88 100644 --- a/mmdet/datasets/transforms/loading.py +++ b/mmdet/datasets/transforms/loading.py @@ -239,6 +239,11 @@ class LoadAnnotations(MMCV_LoadAnnotations): poly2mask (bool): Whether to convert mask to bitmap. Default: True. box_type (str): The box type used to wrap the bboxes. If ``box_type`` is None, gt_bboxes will keep being np.ndarray. Defaults to 'hbox'. + reduce_zero_label (bool): Whether reduce all label value + by 1. Usually used for datasets where 0 is background label. + Defaults to False. + ignore_index (int): The label index to be ignored. + Valid only if reduce_zero_label is true. Defaults is 255. imdecode_backend (str): The image decoding backend type. The backend argument for :func:``mmcv.imfrombytes``. See :fun:``mmcv.imfrombytes`` for details. @@ -247,15 +252,21 @@ class LoadAnnotations(MMCV_LoadAnnotations): corresponding backend. Defaults to None. """ - def __init__(self, - with_mask: bool = False, - poly2mask: bool = True, - box_type: str = 'hbox', - **kwargs) -> None: + def __init__( + self, + with_mask: bool = False, + poly2mask: bool = True, + box_type: str = 'hbox', + # use for semseg + reduce_zero_label: bool = False, + ignore_index: int = 255, + **kwargs) -> None: super(LoadAnnotations, self).__init__(**kwargs) self.with_mask = with_mask self.poly2mask = poly2mask self.box_type = box_type + self.reduce_zero_label = reduce_zero_label + self.ignore_index = ignore_index def _load_bboxes(self, results: dict) -> None: """Private function to load bounding box annotations. @@ -381,6 +392,42 @@ def _load_masks(self, results: dict) -> None: gt_masks = PolygonMasks([mask for mask in gt_masks], h, w) results['gt_masks'] = gt_masks + def _load_seg_map(self, results: dict) -> None: + """Private function to load semantic segmentation annotations. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. + """ + if results.get('seg_map_path', None) is None: + return + + img_bytes = get( + results['seg_map_path'], backend_args=self.backend_args) + gt_semantic_seg = mmcv.imfrombytes( + img_bytes, flag='unchanged', + backend=self.imdecode_backend).squeeze() + + if self.reduce_zero_label: + # avoid using underflow conversion + gt_semantic_seg[gt_semantic_seg == 0] = self.ignore_index + gt_semantic_seg = gt_semantic_seg - 1 + gt_semantic_seg[gt_semantic_seg == self.ignore_index - + 1] = self.ignore_index + + # modify if custom classes + if results.get('label_map', None) is not None: + # Add deep copy to solve bug of repeatedly + # replace `gt_semantic_seg`, which is reported in + # https://github.com/open-mmlab/mmsegmentation/pull/1445/ + gt_semantic_seg_copy = gt_semantic_seg.copy() + for old_id, new_id in results['label_map'].items(): + gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id + results['gt_seg_map'] = gt_semantic_seg + results['ignore_index'] = self.ignore_index + def transform(self, results: dict) -> dict: """Function to load multiple types annotations. @@ -600,72 +647,6 @@ def transform(self, results: dict) -> dict: return results -@TRANSFORMS.register_module() -class LoadSemSegAnnotations(LoadAnnotations): - """Load annotations for semantic segmentation provided by dataset. - - The annotation format is as the following: - - .. code-block:: python - - { - # Filename of semantic segmentation ground truth file. - 'seg_map_path': 'a/b/c' - } - - After this module, the annotation has been changed to the format below: - - .. code-block:: python - - { - # In uint8 type. - 'gt_seg_map': np.ndarray (H, W) - } - - Required Keys: - - - seg_map_path (str): Path of semantic segmentation ground truth file. - - Added Keys: - - - gt_seg_map (np.uint8) - """ - - def __init__(self, **kwargs) -> None: - super().__init__( - with_bbox=False, - with_label=False, - with_seg=True, - with_keypoints=False, - **kwargs) - - def _load_seg_map(self, results: dict) -> None: - """Private function to load semantic segmentation annotations. - - Args: - results (dict): Result dict from :obj:``mmcv.BaseDataset``. - - Returns: - dict: The dict contains loaded semantic segmentation annotations. - """ - - img_bytes = get( - results['seg_map_path'], backend_args=self.backend_args) - gt_semantic_seg = mmcv.imfrombytes( - img_bytes, flag='unchanged', - backend=self.imdecode_backend).squeeze().astype(np.uint8) - - # modify if custom classes - if results.get('label_map', None) is not None: - # Add deep copy to solve bug of repeatedly - # replace `gt_semantic_seg`, which is reported in - # https://github.com/open-mmlab/mmsegmentation/pull/1445/ - gt_semantic_seg_copy = gt_semantic_seg.copy() - for old_id, new_id in results['label_map'].items(): - gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id - results['gt_seg_map'] = gt_semantic_seg - - @TRANSFORMS.register_module() class LoadProposals(BaseTransform): """Load proposal pipeline. diff --git a/mmdet/datasets/transforms/transforms.py b/mmdet/datasets/transforms/transforms.py index d85a39561b6..018c15ea585 100644 --- a/mmdet/datasets/transforms/transforms.py +++ b/mmdet/datasets/transforms/transforms.py @@ -6,6 +6,7 @@ import cv2 import mmcv +import numpy import numpy as np from mmcv.image import imresize from mmcv.image.geometric import _scale_size @@ -277,6 +278,83 @@ def _resize_img(self, results): results['keep_ratio'] = self.keep_ratio +@TRANSFORMS.register_module() +class ResizeShortestEdge(BaseTransform): + """Resize the image and mask while keeping the aspect ratio unchanged. + + Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py#L130 # noqa:E501 + + This transform attempts to scale the shorter edge to the given + `scale`, as long as the longer edge does not exceed `max_size`. + If `max_size` is reached, then downscale so that the longer + edge does not exceed `max_size`. + + Required Keys: + - img + - gt_seg_map (optional) + Modified Keys: + - img + - img_shape + - gt_seg_map (optional)) + Added Keys: + - scale + - scale_factor + - keep_ratio + + Args: + scale (Union[int, Tuple[int, int]]): The target short edge length. + If it's tuple, will select the min value as the short edge length. + max_size (int): The maximum allowed longest edge length. + """ + + def __init__(self, + scale: Union[int, Tuple[int, int]], + max_size: Optional[int] = None, + resize_type: str = 'Resize', + **resize_kwargs) -> None: + super().__init__() + self.scale = scale + self.max_size = max_size + + self.resize_cfg = dict(type=resize_type, **resize_kwargs) + self.resize = TRANSFORMS.build({'scale': 0, **self.resize_cfg}) + + def _get_output_shape( + self, img: np.ndarray, + short_edge_length: Union[int, Tuple[int, int]]) -> Tuple[int, int]: + """Compute the target image shape with the given `short_edge_length`. + + Args: + img (np.ndarray): The input image. + short_edge_length (Union[int, Tuple[int, int]]): The target short + edge length. If it's tuple, will select the min value as the + short edge length. + """ + h, w = img.shape[:2] + if isinstance(short_edge_length, int): + size = short_edge_length * 1.0 + elif isinstance(short_edge_length, tuple): + size = min(short_edge_length) * 1.0 + scale = size / min(h, w) + if h < w: + new_h, new_w = size, scale * w + else: + new_h, new_w = scale * h, size + + if self.max_size and max(new_h, new_w) > self.max_size: + scale = self.max_size * 1.0 / max(new_h, new_w) + new_h *= scale + new_w *= scale + + new_h = int(new_h + 0.5) + new_w = int(new_w + 0.5) + return new_w, new_h + + def transform(self, results: dict) -> dict: + self.resize.scale = self._get_output_shape(results['img'], self.scale) + return self.resize(results) + + @TRANSFORMS.register_module() class FixShapeResize(Resize): """Resize images & bbox & seg to the specified size. diff --git a/mmdet/evaluation/metrics/__init__.py b/mmdet/evaluation/metrics/__init__.py index df73bb329dc..e1ec0e46250 100644 --- a/mmdet/evaluation/metrics/__init__.py +++ b/mmdet/evaluation/metrics/__init__.py @@ -12,6 +12,7 @@ from .lvis_metric import LVISMetric from .mot_challenge_metric import MOTChallengeMetric from .openimages_metric import OpenImagesMetric +from .refseg_metric import RefSegMetric from .reid_metric import ReIDMetrics from .semseg_metric import SemSegMetric from .voc_metric import VOCMetric @@ -22,5 +23,5 @@ 'VOCMetric', 'LVISMetric', 'CrowdHumanMetric', 'DumpProposals', 'CocoOccludedSeparatedMetric', 'DumpDetResults', 'BaseVideoMetric', 'MOTChallengeMetric', 'CocoVideoMetric', 'ReIDMetrics', 'YouTubeVISMetric', - 'COCOCaptionMetric', 'SemSegMetric' + 'COCOCaptionMetric', 'SemSegMetric', 'RefSegMetric' ] diff --git a/mmdet/evaluation/metrics/coco_caption_metric.py b/mmdet/evaluation/metrics/coco_caption_metric.py index ab05d91424e..d8c7350150f 100644 --- a/mmdet/evaluation/metrics/coco_caption_metric.py +++ b/mmdet/evaluation/metrics/coco_caption_metric.py @@ -63,7 +63,7 @@ def process(self, data_batch, data_samples): result = dict() result['caption'] = data_sample['pred_caption'] - result['image_id'] = data_sample['img_id'] + result['image_id'] = int(data_sample['img_id']) # Save the result to `self.results`. self.results.append(result) @@ -85,7 +85,7 @@ def compute_metrics(self, results: List): eval_result_file = save_result( result=results, result_dir=temp_dir, - filename='m4-caption_pred', + filename='caption_pred', remove_duplicate='image_id', ) diff --git a/mmdet/evaluation/metrics/coco_panoptic_metric.py b/mmdet/evaluation/metrics/coco_panoptic_metric.py index 475e51dbc19..1554c0908d1 100644 --- a/mmdet/evaluation/metrics/coco_panoptic_metric.py +++ b/mmdet/evaluation/metrics/coco_panoptic_metric.py @@ -268,12 +268,16 @@ def _parse_predictions(self, result['img_id'] = img_id # shape (1, H, W) -> (H, W) pan = pred['pred_panoptic_seg']['sem_seg'].cpu().numpy()[0] + ignore_index = pred['pred_panoptic_seg'].get( + 'ignore_index', len(self.dataset_meta['classes'])) pan_labels = np.unique(pan) segments_info = [] for pan_label in pan_labels: sem_label = pan_label % INSTANCE_OFFSET - # We reserve the length of dataset_meta['classes'] for VOID label - if sem_label == len(self.dataset_meta['classes']): + # We reserve the length of dataset_meta['classes'] + # and ignore_index for VOID label + if sem_label == len( + self.dataset_meta['classes']) or sem_label == ignore_index: continue mask = pan == pan_label area = mask.sum() @@ -290,6 +294,8 @@ def _parse_predictions(self, }) # evaluation script uses 0 for VOID label. pan[pan % INSTANCE_OFFSET == len(self.dataset_meta['classes'])] = VOID + pan[pan % INSTANCE_OFFSET == ignore_index] = VOID + pan = id2rgb(pan).astype(np.uint8) mmcv.imwrite(pan[:, :, ::-1], osp.join(self.seg_out_dir, segm_file)) result = { diff --git a/mmdet/evaluation/metrics/refseg_metric.py b/mmdet/evaluation/metrics/refseg_metric.py new file mode 100644 index 00000000000..0faee07007e --- /dev/null +++ b/mmdet/evaluation/metrics/refseg_metric.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch +from mmengine.evaluator import BaseMetric + +from mmdet.registry import METRICS + + +@METRICS.register_module() +class RefSegMetric(BaseMetric): + """Referring Expression Segmentation Metric.""" + + def __init__(self, metric: Sequence = ('cIoU', 'mIoU'), **kwargs): + super().__init__(**kwargs) + assert set(metric).issubset(['cIoU', 'mIoU']), \ + f'Only support cIoU and mIoU, but got {metric}' + assert len(metric) > 0, 'metrics should not be empty' + self.metrics = metric + + def compute_iou(self, pred_seg: torch.Tensor, + gt_seg: torch.Tensor) -> tuple: + overlap = pred_seg & gt_seg + union = pred_seg | gt_seg + return overlap, union + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data and data_samples. + + The processed results should be stored in ``self.results``, which will + be used to compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + pred_label = data_sample['pred_instances']['masks'].bool() + label = data_sample['gt_masks'].to_tensor( + pred_label.dtype, pred_label.device).bool() + # calculate iou + overlap, union = self.compute_iou(pred_label, label) + + bs = len(pred_label) + iou = overlap.reshape(bs, -1).sum(-1) * 1.0 / union.reshape( + bs, -1).sum(-1) + iou = torch.nan_to_num_(iou, nan=0.0) + self.results.append((overlap.sum(), union.sum(), iou.sum(), bs)) + + def compute_metrics(self, results: list) -> dict: + results = tuple(zip(*results)) + assert len(results) == 4 + cum_i = sum(results[0]) + cum_u = sum(results[1]) + iou = sum(results[2]) + seg_total = sum(results[3]) + + metrics = {} + if 'cIoU' in self.metrics: + metrics['cIoU'] = cum_i * 100 / cum_u + if 'mIoU' in self.metrics: + metrics['mIoU'] = iou * 100 / seg_total + return metrics diff --git a/mmdet/evaluation/metrics/semseg_metric.py b/mmdet/evaluation/metrics/semseg_metric.py index 6b12d4a0b0b..3215f6788a6 100644 --- a/mmdet/evaluation/metrics/semseg_metric.py +++ b/mmdet/evaluation/metrics/semseg_metric.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp from collections import OrderedDict -from typing import Dict, List, Optional, Sequence, Union +from typing import Dict, Optional, Sequence, Union import numpy as np import torch @@ -47,20 +47,20 @@ class SemSegMetric(BaseMetric): """ def __init__(self, - iou_metrics: List[str] = ['mIoU'], + iou_metrics: Sequence[str] = ['mIoU'], beta: int = 1, collect_device: str = 'cpu', output_dir: Optional[str] = None, format_only: bool = False, backend_args: dict = None, - prefix: Optional[str] = None, - **kwargs) -> None: + prefix: Optional[str] = None) -> None: super().__init__(collect_device=collect_device, prefix=prefix) if isinstance(iou_metrics, str): iou_metrics = [iou_metrics] if not set(iou_metrics).issubset(set(['mIoU', 'mDice', 'mFscore'])): - raise KeyError(f'metrics {iou_metrics} is not supported') + raise KeyError(f'metrics {iou_metrics} is not supported. ' + f'Only supports mIoU/mDice/mFscore.') self.metrics = iou_metrics self.beta = beta self.output_dir = output_dir @@ -86,8 +86,12 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: if not self.format_only: label = data_sample['gt_sem_seg']['sem_seg'].squeeze().to( pred_label) + ignore_index = data_sample['pred_sem_seg'].get( + 'ignore_index', 255) self.results.append( - self._compute_pred_stats(pred_label, label, num_classes)) + self._compute_pred_stats(pred_label, label, num_classes, + ignore_index)) + # format_result if self.output_dir is not None: basename = osp.splitext(osp.basename( @@ -134,7 +138,8 @@ def compute_metrics(self, results: list) -> Dict[str, float]: return metrics def _compute_pred_stats(self, pred_label: torch.tensor, - label: torch.tensor, num_classes: int): + label: torch.tensor, num_classes: int, + ignore_index: int): """Parse semantic segmentation predictions. Args: @@ -149,20 +154,20 @@ def _compute_pred_stats(self, pred_label: torch.tensor, histogram on all classes. torch.Tensor: The union of prediction and ground truth histogram on all classes. - torch.Tens6or: The prediction histogram on all classes. + torch.Tensor: The prediction histogram on all classes. torch.Tensor: The ground truth histogram on all classes. """ assert pred_label.shape == label.shape - # 0 is background - mask = label != 0 - pred_label = (pred_label + 1) * mask + mask = label != ignore_index + label, pred_label = label[mask], pred_label[mask] + intersect = pred_label[pred_label == label] area_intersect = torch.histc( - intersect.float(), bins=(num_classes), min=1, max=num_classes) + intersect.float(), bins=num_classes, min=0, max=num_classes - 1) area_pred_label = torch.histc( - pred_label.float(), bins=(num_classes), min=1, max=num_classes) + pred_label.float(), bins=num_classes, min=0, max=num_classes - 1) area_label = torch.histc( - label.float(), bins=(num_classes), min=1, max=num_classes) + label.float(), bins=num_classes, min=0, max=num_classes - 1) area_union = area_pred_label + area_label - area_intersect result = dict( area_intersect=area_intersect, diff --git a/mmdet/models/detectors/glip.py b/mmdet/models/detectors/glip.py index f39c8e9fe76..7951e3ecb15 100644 --- a/mmdet/models/detectors/glip.py +++ b/mmdet/models/detectors/glip.py @@ -224,7 +224,7 @@ def predict(self, the last dimension 4 arrange as (x1, y1, x2, y2). """ text_prompts = [ - data_samples.caption for data_samples in batch_data_samples + data_samples.text for data_samples in batch_data_samples ] if 'custom_entities' in batch_data_samples[0]: diff --git a/mmdet/testing/_utils.py b/mmdet/testing/_utils.py index 4f5a761ea28..c4d3a86deab 100644 --- a/mmdet/testing/_utils.py +++ b/mmdet/testing/_utils.py @@ -96,7 +96,7 @@ def demo_mm_inputs(batch_size=2, with_semantic=False, use_box_type=False, device='cpu', - captions=None, + texts=None, custom_entities=False): """Create a superset of inputs needed to run test or train batches. @@ -124,8 +124,8 @@ def demo_mm_inputs(batch_size=2, if isinstance(num_items, list): assert len(num_items) == batch_size - if captions is not None: - assert batch_size == len(captions) + if texts is not None: + assert batch_size == len(texts) packed_inputs = [] for idx in range(batch_size): @@ -148,8 +148,8 @@ def demo_mm_inputs(batch_size=2, 'border': [1, 1, 1, 1] # Only used by CenterNet } - if captions: - img_meta['caption'] = captions[idx] + if texts: + img_meta['text'] = texts[idx] img_meta['custom_entities'] = custom_entities data_sample = DetDataSample() diff --git a/mmdet/visualization/local_visualizer.py b/mmdet/visualization/local_visualizer.py index 30645b7eedc..cc6521c56eb 100644 --- a/mmdet/visualization/local_visualizer.py +++ b/mmdet/visualization/local_visualizer.py @@ -123,7 +123,7 @@ def _draw_instances(self, image: np.ndarray, instances: ['InstanceData'], """ self.set_image(image) - if 'bboxes' in instances: + if 'bboxes' in instances and instances.bboxes.sum() > 0: bboxes = instances.bboxes labels = instances.labels @@ -211,8 +211,11 @@ def _draw_instances(self, image: np.ndarray, instances: ['InstanceData'], scales = _get_adaptive_scales(areas) for i, (pos, label) in enumerate(zip(positions, labels)): - label_text = classes[ - label] if classes is not None else f'class {label}' + if 'label_names' in instances: + label_text = instances.label_names[i] + else: + label_text = classes[ + label] if classes is not None else f'class {label}' if 'scores' in instances: score = round(float(instances.scores[i]) * 100, 1) label_text += f': {score}' @@ -233,7 +236,8 @@ def _draw_instances(self, image: np.ndarray, instances: ['InstanceData'], def _draw_panoptic_seg(self, image: np.ndarray, panoptic_seg: ['PixelData'], - classes: Optional[List[str]]) -> np.ndarray: + classes: Optional[List[str]], + palette: Optional[List]) -> np.ndarray: """Draw panoptic seg of GT or prediction. Args: @@ -248,16 +252,28 @@ def _draw_panoptic_seg(self, image: np.ndarray, # TODO: Is there a way to bypass? num_classes = len(classes) - panoptic_seg = panoptic_seg.sem_seg[0] - ids = np.unique(panoptic_seg)[::-1] - legal_indices = ids != num_classes # for VOID label - ids = ids[legal_indices] + panoptic_seg_data = panoptic_seg.sem_seg[0] + + ids = np.unique(panoptic_seg_data)[::-1] + + if 'label_names' in panoptic_seg: + # open set panoptic segmentation + classes = panoptic_seg.metainfo['label_names'] + ignore_index = panoptic_seg.metainfo.get('ignore_index', + len(classes)) + ids = ids[ids != ignore_index] + else: + # for VOID label + ids = ids[ids != num_classes] labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64) - segms = (panoptic_seg[None] == ids[:, None, None]) + segms = (panoptic_seg_data[None] == ids[:, None, None]) max_label = int(max(labels) if len(labels) > 0 else 0) - mask_palette = get_palette(self.mask_color, max_label + 1) + + mask_color = palette if self.mask_color is None \ + else self.mask_color + mask_palette = get_palette(mask_color, max_label + 1) colors = [mask_palette[label] for label in labels] self.set_image(image) @@ -302,6 +318,77 @@ def _draw_panoptic_seg(self, image: np.ndarray, horizontal_alignments='center') return self.get_image() + def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData, + classes: Optional[List], + palette: Optional[List]) -> np.ndarray: + """Draw semantic seg of GT or prediction. + + Args: + image (np.ndarray): The image to draw. + sem_seg (:obj:`PixelData`): Data structure for pixel-level + annotations or predictions. + classes (list, optional): Input classes for result rendering, as + the prediction of segmentation model is a segment map with + label indices, `classes` is a list which includes items + responding to the label indices. If classes is not defined, + visualizer will take `cityscapes` classes by default. + Defaults to None. + palette (list, optional): Input palette for result rendering, which + is a list of color palette responding to the classes. + Defaults to None. + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + sem_seg_data = sem_seg.sem_seg + if isinstance(sem_seg_data, torch.Tensor): + sem_seg_data = sem_seg_data.numpy() + + # 0 ~ num_class, the value 0 means background + ids = np.unique(sem_seg_data) + ignore_index = sem_seg.metainfo.get('ignore_index', 255) + ids = ids[ids != ignore_index] + + if 'label_names' in sem_seg: + # open set semseg + label_names = sem_seg.metainfo['label_names'] + else: + label_names = classes + + labels = np.array(ids, dtype=np.int64) + colors = [palette[label] for label in labels] + + self.set_image(image) + + # draw semantic masks + for i, (label, color) in enumerate(zip(labels, colors)): + masks = sem_seg_data == label + self.draw_binary_masks(masks, colors=[color], alphas=self.alpha) + label_text = label_names[label] + _, _, stats, centroids = cv2.connectedComponentsWithStats( + masks[0].astype(np.uint8), connectivity=8) + if stats.shape[0] > 1: + largest_id = np.argmax(stats[1:, -1]) + 1 + centroids = centroids[largest_id] + + areas = stats[largest_id, -1] + scales = _get_adaptive_scales(areas) + + self.draw_texts( + label_text, + centroids, + colors=(255, 255, 255), + font_sizes=int(13 * scales), + horizontal_alignments='center', + bboxes=[{ + 'facecolor': 'black', + 'alpha': 0.8, + 'pad': 0.7, + 'edgecolor': 'none' + }]) + + return self.get_image() + @master_only def add_datasample( self, @@ -359,6 +446,10 @@ def add_datasample( gt_img_data = self._draw_instances(image, data_sample.gt_instances, classes, palette) + if 'gt_sem_seg' in data_sample: + gt_img_data = self._draw_sem_seg(gt_img_data, + data_sample.gt_sem_seg, + classes, palette) if 'gt_panoptic_seg' in data_sample: assert classes is not None, 'class information is ' \ @@ -366,7 +457,7 @@ def add_datasample( 'visualizing panoptic ' \ 'segmentation results.' gt_img_data = self._draw_panoptic_seg( - gt_img_data, data_sample.gt_panoptic_seg, classes) + gt_img_data, data_sample.gt_panoptic_seg, classes, palette) if draw_pred and data_sample is not None: pred_img_data = image @@ -376,6 +467,12 @@ def add_datasample( pred_instances.scores > pred_score_thr] pred_img_data = self._draw_instances(image, pred_instances, classes, palette) + + if 'pred_sem_seg' in data_sample: + pred_img_data = self._draw_sem_seg(pred_img_data, + data_sample.pred_sem_seg, + classes, palette) + if 'pred_panoptic_seg' in data_sample: assert classes is not None, 'class information is ' \ 'not provided when ' \ @@ -383,7 +480,7 @@ def add_datasample( 'segmentation results.' pred_img_data = self._draw_panoptic_seg( pred_img_data, data_sample.pred_panoptic_seg.numpy(), - classes) + classes, palette) if gt_img_data is not None and pred_img_data is not None: drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) diff --git a/projects/XDecoder/README.md b/projects/XDecoder/README.md new file mode 100644 index 00000000000..b739fdfa92d --- /dev/null +++ b/projects/XDecoder/README.md @@ -0,0 +1,245 @@ +# X-Decoder + +> [X-Decoder: Generalized Decoding for Pixel, Image, and Language](https://arxiv.org/pdf/2212.11270.pdf) + + + +## Abstract + +We present X-Decoder, a generalized decoding model that can predict pixel-level segmentation and language tokens seamlessly. X-Decodert takes as input two types of queries: (i) generic non-semantic queries and (ii) semantic queries induced from text inputs, to decode different pixel-level and token-level outputs in the same semantic space. With such a novel design, X-Decoder is the first work that provides a unified way to support all types of image segmentation and a variety of vision-language (VL) tasks. Further, our design enables seamless interactions across tasks at different granularities and brings mutual benefits by learning a common and rich pixel-level visual-semantic understanding space, without any pseudo-labeling. After pretraining on a mixed set of a limited amount of segmentation data and millions of image-text pairs, X-Decoder exhibits strong transferability to a wide range of downstream tasks in both zero-shot and finetuning settings. Notably, it achieves (1) state-of-the-art results on open-vocabulary segmentation and referring segmentation on eight datasets; (2) better or competitive finetuned performance to other generalist and specialist models on segmentation and VL tasks; and (3) flexibility for efficient finetuning and novel task composition (e.g., referring captioning and image editing). + +
+ +
+ +## Installation + +```shell +# if source +pip install -r requirements/multimodal.txt + +# if wheel +mim install mmdet[multimodal] +``` + +## How to use it? + +For convenience, you can download the weights to the `mmdetection` root dir + +```shell +wget https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt +wget https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_best_openseg.pt +``` + +The above two weights are directly copied from the official website without any modification. The specific source is https://github.com/microsoft/X-Decoder + +For convenience of demonstration, please download [the folder](https://github.com/microsoft/X-Decoder/tree/main/images) and place it in the root directory of mmdetection. + +**(1) Open Vocabulary Semantic Segmentation** + +```shell +cd projects/XDecoder +python demo.py ../../images/animals.png configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py --weights ../../xdecoder_focalt_last_novg.pt --texts zebra.giraffe +``` + +
+ +
+ +**(2) Open Vocabulary Instance Segmentation** + +```shell +cd projects/XDecoder +python demo.py ../../images/owls.jpeg configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py --weights ../../xdecoder_focalt_last_novg.pt --texts owl +``` + +
+ +
+ +**(3) Open Vocabulary Panoptic Segmentation** + +```shell +cd projects/XDecoder +python demo.py ../../images/street.jpg configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_coco.py --weights ../../xdecoder_focalt_last_novg.pt --text car.person --stuff-text tree.sky +``` + +
+ +
+ +**(4) Referring Expression Segmentation** + +```shell +cd projects/XDecoder +python demo.py ../../images/fruit.jpg configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcocog.py --weights ../../xdecoder_focalt_last_novg.pt --text "The larger watermelon. The front white flower. White tea pot." +``` + +
+ +
+ +**(5) Image Caption** + +```shell +cd projects/XDecoder +python demo.py ../../images/penguin.jpeg configs/xdecoder-tiny_zeroshot_caption_coco2014.py --weights ../../xdecoder_focalt_last_novg.pt +``` + +
+ +
+ +**(6) Referring Expression Image Caption** + +```shell +cd projects/XDecoder +python demo.py ../../images/fruit.jpg configs/xdecoder-tiny_zeroshot_ref-caption.py --weights ../../xdecoder_focalt_last_novg.pt --text 'White tea pot' +``` + +
+ +
+ +**(7) Text Image Region Retrieval** + +```shell +cd projects/XDecoder +python demo.py ../../images/coco configs/xdecoder-tiny_zeroshot_text-image-retrieval.py --weights ../../xdecoder_focalt_last_novg.pt --text 'pizza on the plate' +``` + +```text +The image that best matches the given text is ../../images/coco/000.jpg and probability is 0.998 +``` + +
+ +
+ +We have also prepared a gradio program in the `projects/gradio_demo` directory, which you can run interactively all the inference supported by mmdetection in your browser. + +## Models and results + +### Semantic segmentation on ADE20K + +Prepare your dataset according to the [docs](../../docs/en/user_guides/dataset_prepare.md#ade20k-2016-dataset-preparation). + +**Test Command** + +Since semantic segmentation is a pixel-level task, we don't need to use a threshold to filter out low-confidence predictions. So we set `model.test_cfg.use_thr_for_mc=False` in the test command. + +```shell +./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_ade20k.py xdecoder_focalt_best_openseg.pt 8 --cfg-options model.test_cfg.use_thr_for_mc=False +``` + +| Model | mIoU | mIOU(official) | Config | +| :-------------------------------- | :---: | :------------: | :------------------------------------------------------------------: | +| `xdecoder_focalt_best_openseg.pt` | 25.24 | 25.13 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-semseg_ade20k.py) | + +### Instance segmentation on ADE20K + +Prepare your dataset according to the [docs](../../docs/en/user_guides/dataset_prepare.md#ade20k-2016-dataset-preparation). + +```shell +./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_ade20k.py xdecoder_focalt_best_openseg.pt 8 +``` + +| Model | mIoU | mIOU(official) | Config | +| :-------------------------------- | :--: | :------------: | :--------------------------------------------------------------------: | +| `xdecoder_focalt_best_openseg.pt` | 10.1 | 10.1 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-instance_ade20k.py) | + +### Panoptic segmentation on ADE20K + +Prepare your dataset according to the [docs](../../docs/en/user_guides/dataset_prepare.md#ade20k-2016-dataset-preparation). + +```shell +./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_ade20k.py xdecoder_focalt_best_openseg.pt 8 +``` + +| Model | mIoU | mIOU(official) | Config | +| :-------------------------------- | :---: | :------------: | :--------------------------------------------------------------------: | +| `xdecoder_focalt_best_openseg.pt` | 19.11 | 18.97 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_ade20k.py) | + +### Semantic segmentation on COCO2017 + +Prepare your dataset according to the [docs](../../docs/en/user_guides/dataset_prepare.md#coco-semantic-dataset-preparation) of `(2) use panoptic dataset` part. + +```shell +./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py xdecoder_focalt_last_novg.pt 8 --cfg-options model.test_cfg.use_thr_for_mc=False +``` + +| Model | mIOU | mIOU(official) | Config | +| :---------------------------------------------- | :--: | :------------: | :----------------------------------------------------------------: | +| `xdecoder-tiny_zeroshot_open-vocab-semseg_coco` | 62.1 | 62.1 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py) | + +### Instance segmentation on COCO2017 + +Prepare your dataset according to the [docs](../../docs/en/user_guides/dataset_prepare.md#basic-detection-dataset-preparation). + +```shell +./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py xdecoder_focalt_last_novg.pt 8 +``` + +| Model | Mask mAP | Mask mAP(official) | Config | +| :------------------------------------------------ | :------: | :----------------: | :------------------------------------------------------------------: | +| `xdecoder-tiny_zeroshot_open-vocab-instance_coco` | 39.8 | 39.7 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py) | + +### Panoptic segmentation on COCO2017 + +Prepare your dataset according to the [docs](../../docs/en/user_guides/dataset_prepare.md#basic-detection-dataset-preparation). + +```shell +./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_coco.py xdecoder_focalt_last_novg.pt 8 +``` + +| Model | PQ | PQ(official) | Config | +| :------------------------------------------------ | :---: | :----------: | :------------------------------------------------------------------: | +| `xdecoder-tiny_zeroshot_open-vocab-panoptic_coco` | 51.42 | 51.16 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_coco.py) | + +### Referring segmentation on RefCOCO + +Prepare your dataset according to the [docs](../../docs/en/user_guides/dataset_prepare.md#refcoco-dataset-preparation). + +```shell +./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcocog.py xdecoder_focalt_last_novg.pt 8 --cfg-options test_dataloader.dataset.split='val' +``` + +| Model | text mode | cIoU | cIOU(official) | Config | +| :----------------------------- | :----------: | :-----: | :------------: | :---------------------------------------------------------------------: | +| `xdecoder_focalt_last_novg.pt` | select first | 58.8415 | 57.85 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcocog.py) | +| `xdecoder_focalt_last_novg.pt` | original | 60.0321 | - | [config](configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcocog.py) | +| `xdecoder_focalt_last_novg.pt` | concat | 60.3551 | - | [config](configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcocog.py) | + +**Note:** + +1. If you set the scale of `Resize` to (1024, 512), the result will be `57.69`. +2. `text mode` is the `RefCoCoDataset` parameter in MMDetection, it determines the texts loaded to the data list. It can be set to `select_first`, `original`, `concat` and `random`. + - `select_first`: select the first text in the text list as the description to an instance. + - `original`: use all texts in the text list as the description to an instance. + - `concat`: concatenate all texts in the text list as the description to an instance. + - `random`: randomly select one text in the text list as the description to an instance, usually used for training. + +### Image Caption on COCO2014 + +Prepare your dataset according to the [docs](../../docs/en/user_guides/dataset_prepare.md#coco-caption-dataset-preparation). + +Before testing, you need to install jdk 1.8, otherwise it will prompt that java does not exist during the evaluation process + +``` +./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_caption_coco2014.py xdecoder_focalt_last_novg.pt 8 +``` + +| Model | BLEU-4 | CIDER | Config | +| :---------------------------------------- | :----: | :----: | :----------------------------------------------------------: | +| `xdecoder-tiny_zeroshot_caption_coco2014` | 35.26 | 116.81 | [config](configs/xdecoder-tiny_zeroshot_caption_coco2014.py) | + +## Citation + +```latex +@article{zou2022xdecoder, + author = {Zou*, Xueyan and Dou*, Zi-Yi and Yang*, Jianwei and Gan, Zhe and Li, Linjie and Li, Chunyuan and Dai, Xiyang and Wang, Jianfeng and Yuan, Lu and Peng, Nanyun and Wang, Lijuan and Lee*, Yong Jae and Gao*, Jianfeng}, + title = {Generalized Decoding for Pixel, Image and Language}, + publisher = {arXiv}, + year = {2022}, +} +``` diff --git a/projects/XDecoder/configs/_base_/xdecoder-tiny_caption.py b/projects/XDecoder/configs/_base_/xdecoder-tiny_caption.py new file mode 100644 index 00000000000..16b16465939 --- /dev/null +++ b/projects/XDecoder/configs/_base_/xdecoder-tiny_caption.py @@ -0,0 +1,3 @@ +_base_ = 'xdecoder-tiny_open-vocab-semseg.py' + +model = dict(head=dict(task='caption')) diff --git a/projects/XDecoder/configs/_base_/xdecoder-tiny_open-vocab-instance.py b/projects/XDecoder/configs/_base_/xdecoder-tiny_open-vocab-instance.py new file mode 100644 index 00000000000..ca2cb3e3ac1 --- /dev/null +++ b/projects/XDecoder/configs/_base_/xdecoder-tiny_open-vocab-instance.py @@ -0,0 +1,3 @@ +_base_ = 'xdecoder-tiny_open-vocab-semseg.py' + +model = dict(head=dict(task='instance'), test_cfg=dict(max_per_img=100)) diff --git a/projects/XDecoder/configs/_base_/xdecoder-tiny_open-vocab-panoptic.py b/projects/XDecoder/configs/_base_/xdecoder-tiny_open-vocab-panoptic.py new file mode 100644 index 00000000000..0eaac442289 --- /dev/null +++ b/projects/XDecoder/configs/_base_/xdecoder-tiny_open-vocab-panoptic.py @@ -0,0 +1,4 @@ +_base_ = 'xdecoder-tiny_open-vocab-semseg.py' + +model = dict( + head=dict(task='panoptic'), test_cfg=dict(mask_thr=0.8, overlap_thr=0.8)) diff --git a/projects/XDecoder/configs/_base_/xdecoder-tiny_open-vocab-semseg.py b/projects/XDecoder/configs/_base_/xdecoder-tiny_open-vocab-semseg.py new file mode 100644 index 00000000000..0ffef0f8d99 --- /dev/null +++ b/projects/XDecoder/configs/_base_/xdecoder-tiny_open-vocab-semseg.py @@ -0,0 +1,29 @@ +_base_ = 'mmdet::_base_/default_runtime.py' + +custom_imports = dict( + imports=['projects.XDecoder.xdecoder'], allow_failed_imports=False) + +model = dict( + type='XDecoder', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + backbone=dict(type='FocalNet'), + head=dict( + type='XDecoderUnifiedhead', + in_channels=(96, 192, 384, 768), + pixel_decoder=dict(type='XTransformerEncoderPixelDecoder'), + transformer_decoder=dict(type='XDecoderTransformerDecoder'), + task='semseg', + ), + # use_thr_for_mc=True means use threshold for multi-class + # This parameter is only used in semantic segmentation task and + # referring semantic segmentation task. + test_cfg=dict(mask_thr=0.5, use_thr_for_mc=True, ignore_index=255), +) + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') diff --git a/projects/XDecoder/configs/_base_/xdecoder-tiny_ref-seg.py b/projects/XDecoder/configs/_base_/xdecoder-tiny_ref-seg.py new file mode 100644 index 00000000000..6101474b8e1 --- /dev/null +++ b/projects/XDecoder/configs/_base_/xdecoder-tiny_ref-seg.py @@ -0,0 +1,3 @@ +_base_ = 'xdecoder-tiny_open-vocab-semseg.py' + +model = dict(head=dict(task='ref-seg')) diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_caption_coco2014.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_caption_coco2014.py new file mode 100644 index 00000000000..963c7c61e09 --- /dev/null +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_caption_coco2014.py @@ -0,0 +1,18 @@ +_base_ = [ + '_base_/xdecoder-tiny_caption.py', 'mmdet::_base_/datasets/coco_caption.py' +] + +test_pipeline = [ + dict( + type='LoadImageFromFile', + imdecode_backend='pillow', + backend_args=_base_.backend_args), + dict(type='ResizeShortestEdge', scale=224, backend='pillow'), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_ade20k.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_ade20k.py new file mode 100644 index 00000000000..4f61ae6e337 --- /dev/null +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_ade20k.py @@ -0,0 +1,20 @@ +_base_ = [ + '_base_/xdecoder-tiny_open-vocab-instance.py', + 'mmdet::_base_/datasets/ade20k_instance.py' +] + +test_pipeline = [ + dict(type='LoadImageFromFile', backend_args=_base_.backend_args), + dict(type='Resize', scale=(2560, 640), keep_ratio=True), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'text')) +] + +val_dataloader = dict( + dataset=dict(return_classes=True, pipeline=test_pipeline)) +test_dataloader = val_dataloader + +test_evaluator = dict(metric=['segm']) diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py new file mode 100644 index 00000000000..d978cf2fa8e --- /dev/null +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py @@ -0,0 +1,27 @@ +_base_ = [ + '_base_/xdecoder-tiny_open-vocab-instance.py', + 'mmdet::_base_/datasets/coco_instance.py' +] + +test_pipeline = [ + dict( + type='LoadImageFromFile', + imdecode_backend='pillow', + backend_args=_base_.backend_args), + dict( + type='ResizeShortestEdge', scale=800, max_size=1333, backend='pillow'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'text')) +] + +val_dataloader = dict( + dataset=dict(pipeline=test_pipeline, return_classes=True)) +test_dataloader = val_dataloader + +val_evaluator = dict(metric='segm') +test_evaluator = val_evaluator + +train_dataloader = None diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_ade20k.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_ade20k.py new file mode 100644 index 00000000000..7c97045a989 --- /dev/null +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_ade20k.py @@ -0,0 +1,51 @@ +_base_ = [ + '_base_/xdecoder-tiny_open-vocab-panoptic.py', + 'mmdet::_base_/datasets/ade20k_panoptic.py' +] + +model = dict(test_cfg=dict(mask_thr=0.4)) + +test_pipeline = [ + dict(type='LoadImageFromFile', backend_args=_base_.backend_args), + dict(type='Resize', scale=(2560, 640), keep_ratio=True), + dict(type='LoadPanopticAnnotations', backend_args=_base_.backend_args), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'text', 'stuff_text')) +] + +x_decoder_ade20k_thing_classes = ( + 'bed', 'window', 'cabinet', 'person', 'door', 'table', 'curtain', 'chair', + 'car', 'painting', 'sofa', 'shelf', 'mirror', 'armchair', 'seat', 'fence', + 'desk', 'wardrobe', 'lamp', 'tub', 'rail', 'cushion', 'box', 'column', + 'signboard', 'chest of drawers', 'counter', 'sink', 'fireplace', + 'refrigerator', 'stairs', 'case', 'pool table', 'pillow', 'screen door', + 'bookcase', 'coffee table', 'toilet', 'flower', 'book', 'bench', + 'countertop', 'stove', 'palm', 'kitchen island', 'computer', + 'swivel chair', 'boat', 'arcade machine', 'bus', 'towel', 'light', 'truck', + 'chandelier', 'awning', 'street lamp', 'booth', 'tv', 'airplane', + 'clothes', 'pole', 'bannister', 'ottoman', 'bottle', 'van', 'ship', + 'fountain', 'washer', 'plaything', 'stool', 'barrel', 'basket', 'bag', + 'minibike', 'oven', 'ball', 'food', 'step', 'trade name', 'microwave', + 'pot', 'animal', 'bicycle', 'dishwasher', 'screen', 'sculpture', 'hood', + 'sconce', 'vase', 'traffic light', 'tray', 'trash can', 'fan', 'plate', + 'monitor', 'bulletin board', 'radiator', 'glass', 'clock', 'flag') + +x_decoder_ade20k_stuff_classes = ( + 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'grass', + 'sidewalk', 'earth', 'mountain', 'plant', 'water', 'house', 'sea', 'rug', + 'field', 'rock', 'base', 'sand', 'skyscraper', 'grandstand', 'path', + 'runway', 'stairway', 'river', 'bridge', 'blind', 'hill', 'bar', 'hovel', + 'tower', 'dirt track', 'land', 'escalator', 'buffet', 'poster', 'stage', + 'conveyer belt', 'canopy', 'pool', 'falls', 'tent', 'cradle', 'tank', + 'lake', 'blanket', 'pier', 'crt screen', 'shower') + +val_dataloader = dict( + dataset=dict( + metainfo=dict( + thing_classes=x_decoder_ade20k_thing_classes, + stuff_classes=x_decoder_ade20k_stuff_classes), + return_classes=True, + pipeline=test_pipeline)) +test_dataloader = val_dataloader diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_coco.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_coco.py new file mode 100644 index 00000000000..025e54beb14 --- /dev/null +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_coco.py @@ -0,0 +1,27 @@ +_base_ = [ + '_base_/xdecoder-tiny_open-vocab-panoptic.py', + 'mmdet::_base_/datasets/coco_panoptic.py' +] + +model = dict(test_cfg=dict(mask_thr=0.4)) + +test_pipeline = [ + dict( + type='LoadImageFromFile', + imdecode_backend='pillow', + backend_args=_base_.backend_args), + dict( + type='ResizeShortestEdge', scale=800, max_size=1333, backend='pillow'), + dict(type='LoadPanopticAnnotations', backend_args=_base_.backend_args), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'text', 'stuff_text')) +] + +val_dataloader = dict( + dataset=dict(pipeline=test_pipeline, return_classes=True)) + +test_dataloader = val_dataloader + +train_dataloader = None diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcoco+.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcoco+.py new file mode 100644 index 00000000000..948c9d72c9a --- /dev/null +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcoco+.py @@ -0,0 +1,3 @@ +_base_ = [ + '_base_/xdecoder-tiny_ref-seg.py', 'mmdet::_base_/datasets/refcoco+.py' +] diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcoco.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcoco.py new file mode 100644 index 00000000000..e6215758a15 --- /dev/null +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcoco.py @@ -0,0 +1,3 @@ +_base_ = [ + '_base_/xdecoder-tiny_ref-seg.py', 'mmdet::_base_/datasets/refcoco.py' +] diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcocog.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcocog.py new file mode 100644 index 00000000000..eb7474efa52 --- /dev/null +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcocog.py @@ -0,0 +1,3 @@ +_base_ = [ + '_base_/xdecoder-tiny_ref-seg.py', 'mmdet::_base_/datasets/refcocog.py' +] diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_ade20k.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_ade20k.py new file mode 100644 index 00000000000..1fe990b42d4 --- /dev/null +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_ade20k.py @@ -0,0 +1,50 @@ +_base_ = [ + '_base_/xdecoder-tiny_open-vocab-semseg.py', + 'mmdet::_base_/datasets/ade20k_semantic.py' +] + +test_pipeline = [ + dict(type='LoadImageFromFile', backend_args=_base_.backend_args), + dict(type='Resize', scale=(2560, 640), keep_ratio=True), + dict( + type='LoadAnnotations', + with_bbox=False, + with_mask=False, + with_seg=True, + reduce_zero_label=True), + dict( + type='PackDetInputs', + meta_keys=('img_path', 'ori_shape', 'img_shape', 'text')) +] + +x_decoder_ade20k_classes = ( + 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed', + 'window', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', 'door', + 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water', + 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', + 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', 'tub', + 'rail', 'cushion', 'base', 'box', 'column', 'signboard', + 'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace', + 'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case', + 'pool table', 'pillow', 'screen door', 'stairway', 'river', 'bridge', + 'bookcase', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', + 'bench', 'countertop', 'stove', 'palm', 'kitchen island', 'computer', + 'swivel chair', 'boat', 'bar', 'arcade machine', 'hovel', 'bus', 'towel', + 'light', 'truck', 'tower', 'chandelier', 'awning', 'street lamp', 'booth', + 'tv', 'airplane', 'dirt track', 'clothes', 'pole', 'land', 'bannister', + 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', + 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', + 'pool', 'stool', 'barrel', 'basket', 'falls', 'tent', 'bag', 'minibike', + 'cradle', 'oven', 'ball', 'food', 'step', 'tank', 'trade name', + 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher', 'screen', + 'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray', + 'trash can', 'fan', 'pier', 'crt screen', 'plate', 'monitor', + 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag') + +val_dataloader = dict( + dataset=dict( + metainfo=dict(classes=x_decoder_ade20k_classes), + return_classes=True, + use_label_map=False, + pipeline=test_pipeline)) +test_dataloader = val_dataloader diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py new file mode 100644 index 00000000000..cd9a7eccfe6 --- /dev/null +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py @@ -0,0 +1,68 @@ +_base_ = '_base_/xdecoder-tiny_open-vocab-semseg.py' + +dataset_type = 'CocoSegDataset' +data_root = 'data/coco/' + +test_pipeline = [ + dict( + type='LoadImageFromFile', imdecode_backend='pillow', + backend_args=None), + dict( + type='ResizeShortestEdge', scale=800, max_size=1333, backend='pillow'), + dict( + type='LoadAnnotations', + with_bbox=False, + with_label=False, + with_seg=True), + dict( + type='PackDetInputs', + meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor', + 'text')) +] + +x_decoder_coco2017_semseg_classes = ( + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', + 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', + 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', + 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', + 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', + 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', + 'hair drier', 'toothbrush', 'banner', 'blanket', 'bridge', 'cardboard', + 'counter', 'curtain', 'door-stuff', 'floor-wood', 'flower', 'fruit', + 'gravel', 'house', 'light', 'mirror-stuff', 'net', 'pillow', 'platform', + 'playingfield', 'railroad', 'river', 'road', 'roof', 'sand', 'sea', + 'shelf', 'snow', 'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', + 'wall-tile', 'wall-wood', 'water-other', 'window-blind', 'window-other', + 'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged', + 'cabinet-merged', 'table-merged', 'floor-other-merged', 'pavement-merged', + 'mountain-merged', 'grass-merged', 'dirt-merged', 'paper-merged', + 'food-other-merged', 'building-other-merged', 'rock-merged', + 'wall-other-merged', 'rug-merged') + +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + metainfo=dict(classes=x_decoder_coco2017_semseg_classes), + use_label_map=False, + data_prefix=dict( + img_path='val2017/', + seg_map_path='annotations/panoptic_semseg_val2017/'), + pipeline=test_pipeline, + return_classes=True)) + +test_dataloader = val_dataloader + +val_evaluator = dict(type='SemSegMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-caption.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-caption.py new file mode 100644 index 00000000000..fc81af198f9 --- /dev/null +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-caption.py @@ -0,0 +1,17 @@ +_base_ = 'xdecoder-tiny_zeroshot_caption_coco2014.py' + +model = dict(head=dict(task='ref-caption')) + +grounding_scale = 512 + +test_pipeline = [ + dict(type='LoadImageFromFile', imdecode_backend='pillow'), + dict(type='ResizeShortestEdge', scale=224, backend='pillow'), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'text')) +] + +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_text-image-retrieval.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_text-image-retrieval.py new file mode 100644 index 00000000000..7523e045273 --- /dev/null +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_text-image-retrieval.py @@ -0,0 +1,24 @@ +_base_ = 'xdecoder-tiny_zeroshot_caption_coco2014.py' + +model = dict(head=dict(task='retrieval')) + +grounding_scale = 512 + +test_pipeline = [ + dict( + type='LoadImageFromFile', + imdecode_backend='pillow', + backend_args=_base_.backend_args), + dict( + type='ResizeShortestEdge', + scale=224, + backend='pillow', + interpolation='bicubic'), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'text')) +] + +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader diff --git a/projects/XDecoder/demo.py b/projects/XDecoder/demo.py new file mode 100644 index 00000000000..fb281c85f1e --- /dev/null +++ b/projects/XDecoder/demo.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from argparse import ArgumentParser + +from mmengine.config import Config +from mmengine.logging import print_log + +from mmdet.apis import DetInferencer +from projects.XDecoder.xdecoder.inference import ( + ImageCaptionInferencer, RefImageCaptionInferencer, + TextToImageRegionRetrievalInferencer) + +TASKINFOS = { + 'semseg': DetInferencer, + 'ref-seg': DetInferencer, + 'instance': DetInferencer, + 'panoptic': DetInferencer, + 'caption': ImageCaptionInferencer, + 'ref-caption': RefImageCaptionInferencer, + 'retrieval': TextToImageRegionRetrievalInferencer, +} + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + 'inputs', type=str, help='Input image file or folder path.') + parser.add_argument('model', type=str, help='Config file name') + parser.add_argument('--weights', help='Checkpoint file') + parser.add_argument('--texts', help='text prompt') + parser.add_argument( + '--out-dir', + type=str, + default='outputs', + help='Output directory of images or prediction results.') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + parser.add_argument( + '--show', + action='store_true', + help='Display the image in a popup window.') + parser.add_argument( + '--no-save-vis', + action='store_true', + help='Do not save detection vis results') + parser.add_argument( + '--palette', + default='none', + choices=['ade20k', 'coco', 'voc', 'citys', 'random', 'none'], + help='Color palette used for visualization') + + # only for instance segmentation + parser.add_argument( + '--pred-score-thr', + type=float, + default=0.5, + help='bbox score threshold') + # only for panoptic segmentation + parser.add_argument( + '--stuff-texts', + help='text prompt for stuff name in panoptic segmentation') + + call_args = vars(parser.parse_args()) + if call_args['no_save_vis']: + call_args['out_dir'] = '' + + init_kws = ['model', 'weights', 'device', 'palette'] + init_args = {} + for init_kw in init_kws: + init_args[init_kw] = call_args.pop(init_kw) + + return init_args, call_args + + +def main(): + init_args, call_args = parse_args() + + cfg = Config.fromfile(init_args['model']) + task = cfg.model.head.task + assert task in TASKINFOS + + inferencer = TASKINFOS[task](**init_args) + + if task != 'caption': + assert call_args[ + 'texts'] is not None, f'text prompts is required for {task}' + if task != 'panoptic': + call_args.pop('stuff_texts') + else: + call_args.pop('texts') + call_args.pop('stuff_texts') + + inferencer(**call_args) + + if call_args['out_dir'] != '' and not call_args['no_save_vis']: + print_log(f'results have been saved at {call_args["out_dir"]}') + + +if __name__ == '__main__': + main() diff --git a/projects/XDecoder/xdecoder/__init__.py b/projects/XDecoder/xdecoder/__init__.py new file mode 100644 index 00000000000..d343c8f8ddb --- /dev/null +++ b/projects/XDecoder/xdecoder/__init__.py @@ -0,0 +1,10 @@ +from .focalnet import FocalNet +from .pixel_decoder import XTransformerEncoderPixelDecoder +from .transformer_decoder import XDecoderTransformerDecoder +from .unified_head import XDecoderUnifiedhead +from .xdecoder import XDecoder + +__all__ = [ + 'XDecoder', 'FocalNet', 'XDecoderUnifiedhead', + 'XTransformerEncoderPixelDecoder', 'XDecoderTransformerDecoder' +] diff --git a/projects/XDecoder/xdecoder/focalnet.py b/projects/XDecoder/xdecoder/focalnet.py new file mode 100644 index 00000000000..b85178f45ca --- /dev/null +++ b/projects/XDecoder/xdecoder/focalnet.py @@ -0,0 +1,522 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from mmcv.cnn.bricks import DropPath + +from mmdet.registry import MODELS + +# modified from https://github.com/microsoft/X-Decoder/blob/main/xdecoder/backbone/focal_dw.py # noqa + + +@MODELS.register_module() +class FocalNet(nn.Module): + + def __init__( + self, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.3, + norm_layer=nn.LayerNorm, + patch_norm=True, + out_indices=[0, 1, 2, 3], + frozen_stages=-1, + focal_levels=[3, 3, 3, 3], + focal_windows=[3, 3, 3, 3], + use_pre_norms=[False, False, False, False], + use_conv_embed=True, + use_postln=True, + use_postln_in_modulation=False, + scaling_modulator=True, + use_layerscale=True, + use_checkpoint=False, + ): + super().__init__() + + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + use_conv_embed=use_conv_embed, + is_stem=True, + use_pre_norm=False) + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] + + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchEmbed if + (i_layer < self.num_layers - 1) else None, + focal_window=focal_windows[i_layer], + focal_level=focal_levels[i_layer], + use_pre_norm=use_pre_norms[i_layer], + use_conv_embed=use_conv_embed, + use_postln=use_postln, + use_postln_in_modulation=use_postln_in_modulation, + scaling_modulator=scaling_modulator, + use_layerscale=use_layerscale, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in self.out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + def forward(self, x): + x = self.patch_embed(x) + Wh, Ww = x.size(2), x.size(3) + + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = {} + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs['res{}'.format(i + 2)] = out + return outs + + +class Mlp(nn.Module): + """Multilayer perceptron.""" + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class FocalModulation(nn.Module): + """Focal Modulation. + + Args: + dim (int): Number of input channels. + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + focal_level (int): Number of focal levels + focal_window (int): Focal window size at focal level 1 + focal_factor (int, default=2): Step to increase the focal window + """ + + def __init__(self, + dim, + proj_drop=0., + focal_level=2, + focal_window=7, + focal_factor=2, + use_postln_in_modulation=False, + scaling_modulator=False): + + super().__init__() + self.dim = dim + + self.focal_level = focal_level + self.focal_window = focal_window + self.focal_factor = focal_factor + self.use_postln_in_modulation = use_postln_in_modulation + self.scaling_modulator = scaling_modulator + + self.f = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=True) + self.h = nn.Conv2d( + dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True) + + self.act = nn.GELU() + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.focal_layers = nn.ModuleList() + + if self.use_postln_in_modulation: + self.ln = nn.LayerNorm(dim) + + for k in range(self.focal_level): + kernel_size = self.focal_factor * k + self.focal_window + self.focal_layers.append( + nn.Sequential( + nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + stride=1, + groups=dim, + padding=kernel_size // 2, + bias=False), + nn.GELU(), + )) + + def forward(self, x): + """Forward function. + + Args: + x: input features with shape of (B, H, W, C) + """ + B, nH, nW, C = x.shape + x = self.f(x) + x = x.permute(0, 3, 1, 2).contiguous() + q, ctx, gates = torch.split(x, (C, C, self.focal_level + 1), 1) + + ctx_all = 0 + for level in range(self.focal_level): + ctx = self.focal_layers[level](ctx) + ctx_all = ctx_all + ctx * gates[:, level:level + 1] + ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True)) + ctx_all = ctx_all + ctx_global * gates[:, self.focal_level:] + + if self.scaling_modulator: + ctx_all = ctx_all / (self.focal_level + 1) + + x_out = q * self.h(ctx_all) + x_out = x_out.permute(0, 2, 3, 1).contiguous() + if self.use_postln_in_modulation: + x_out = self.ln(x_out) + x_out = self.proj(x_out) + x_out = self.proj_drop(x_out) + return x_out + + +class FocalModulationBlock(nn.Module): + """Focal Modulation Block. + + Args: + dim (int): Number of input channels. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. + Default: nn.LayerNorm + focal_level (int): number of focal levels + focal_window (int): focal kernel size at level 1 + """ + + def __init__(self, + dim, + mlp_ratio=4., + drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + focal_level=2, + focal_window=9, + use_postln=False, + use_postln_in_modulation=False, + scaling_modulator=False, + use_layerscale=False, + layerscale_value=1e-4): + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.focal_window = focal_window + self.focal_level = focal_level + self.use_postln = use_postln + self.use_layerscale = use_layerscale + + self.dw1 = nn.Conv2d( + dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) + self.norm1 = norm_layer(dim) + self.modulation = FocalModulation( + dim, + focal_window=self.focal_window, + focal_level=self.focal_level, + proj_drop=drop, + use_postln_in_modulation=use_postln_in_modulation, + scaling_modulator=scaling_modulator) + + self.dw2 = nn.Conv2d( + dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + self.H = None + self.W = None + + self.gamma_1 = 1.0 + self.gamma_2 = 1.0 + if self.use_layerscale: + self.gamma_1 = nn.Parameter( + layerscale_value * torch.ones(dim), requires_grad=True) + self.gamma_2 = nn.Parameter( + layerscale_value * torch.ones(dim), requires_grad=True) + + def forward(self, x): + """Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous() + x = x + self.dw1(x) + x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C) + + shortcut = x + if not self.use_postln: + x = self.norm1(x) + x = x.view(B, H, W, C) + + # FM + x = self.modulation(x).view(B, H * W, C) + x = shortcut + self.drop_path(self.gamma_1 * x) + if self.use_postln: + x = self.norm1(x) + + x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous() + x = x + self.dw2(x) + x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C) + + if not self.use_postln: + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_2 * self.mlp(x)) + x = self.norm2(x) + + return x + + +class BasicLayer(nn.Module): + """A basic focal modulation layer for one stage. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + Default: 4. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. + Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. + Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the + end of the layer. Default: None + focal_level (int): Number of focal levels + focal_window (int): Focal window size at focal level 1 + use_conv_embed (bool): Use overlapped convolution for patch + embedding or now. Default: False + use_checkpoint (bool): Whether to use checkpointing to save memory. + Default: False + """ + + def __init__( + self, + dim, + depth, + mlp_ratio=4., + drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + focal_window=9, + focal_level=2, + use_conv_embed=False, + use_postln=False, + use_postln_in_modulation=False, + scaling_modulator=False, + use_layerscale=False, + use_checkpoint=False, + use_pre_norm=False, + ): + super().__init__() + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + FocalModulationBlock( + dim=dim, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + focal_window=focal_window, + focal_level=focal_level, + use_postln=use_postln, + use_postln_in_modulation=use_postln_in_modulation, + scaling_modulator=scaling_modulator, + use_layerscale=use_layerscale, + norm_layer=norm_layer) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + patch_size=2, + in_chans=dim, + embed_dim=2 * dim, + use_conv_embed=use_conv_embed, + norm_layer=norm_layer, + is_stem=False, + use_pre_norm=use_pre_norm) + + else: + self.downsample = None + + def forward(self, x, H, W): + """Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W) + x_down = self.downsample(x_reshaped) + x_down = x_down.flatten(2).transpose(1, 2) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding. + + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. + Default: 96. + norm_layer (nn.Module, optional): Normalization layer. + Default: None + use_conv_embed (bool): Whether use overlapped convolution for + patch embedding. Default: False + is_stem (bool): Is the stem block or not. + """ + + def __init__(self, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None, + use_conv_embed=False, + is_stem=False, + use_pre_norm=False): + super().__init__() + patch_size = (patch_size, patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + self.use_pre_norm = use_pre_norm + + if use_conv_embed: + # if we choose to use conv embedding, + # then we treat the stem and non-stem differently + if is_stem: + kernel_size = 7 + padding = 3 + stride = 4 + else: + kernel_size = 3 + padding = 1 + stride = 2 + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=kernel_size, + stride=stride, + padding=padding) + else: + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + if self.use_pre_norm: + if norm_layer is not None: + self.norm = norm_layer(in_chans) + else: + self.norm = None + else: + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + B, C, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, + (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + if self.use_pre_norm: + if self.norm is not None: + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + x = self.norm(x).transpose(1, 2).view(B, C, H, W) + x = self.proj(x) + else: + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x diff --git a/projects/XDecoder/xdecoder/inference/__init__.py b/projects/XDecoder/xdecoder/inference/__init__.py new file mode 100644 index 00000000000..5ebf6f04bf4 --- /dev/null +++ b/projects/XDecoder/xdecoder/inference/__init__.py @@ -0,0 +1,8 @@ +from .image_caption import ImageCaptionInferencer, RefImageCaptionInferencer +from .texttoimage_regionretrieval_inferencer import \ + TextToImageRegionRetrievalInferencer + +__all__ = [ + 'ImageCaptionInferencer', 'RefImageCaptionInferencer', + 'TextToImageRegionRetrievalInferencer' +] diff --git a/projects/XDecoder/xdecoder/inference/image_caption.py b/projects/XDecoder/xdecoder/inference/image_caption.py new file mode 100644 index 00000000000..f22551efdf3 --- /dev/null +++ b/projects/XDecoder/xdecoder/inference/image_caption.py @@ -0,0 +1,308 @@ +import copy +import os.path as osp +from typing import Iterable, List, Optional, Tuple, Union + +import mmcv +import mmengine +import numpy as np +import torch +from mmengine.dataset import Compose +from rich.progress import track + +from mmdet.apis.det_inferencer import DetInferencer, InputsType, PredType +from mmdet.utils import ConfigType + + +def get_adaptive_scale(img_shape: Tuple[int, int], + min_scale: float = 0.3, + max_scale: float = 3.0) -> float: + """Get adaptive scale according to image shape. + + The target scale depends on the the short edge length of the image. If the + short edge length equals 224, the output is 1.0. And output linear scales + according the short edge length. + + You can also specify the minimum scale and the maximum scale to limit the + linear scale. + + Args: + img_shape (Tuple[int, int]): The shape of the canvas image. + min_scale (float): The minimum scale. Defaults to 0.3. + max_scale (float): The maximum scale. Defaults to 3.0. + + Returns: + int: The adaptive scale. + """ + short_edge_length = min(img_shape) + scale = short_edge_length / 224. + return min(max(scale, min_scale), max_scale) + + +class ImageCaptionInferencer(DetInferencer): + DEFAULT_TEXT_CFG = { + 'font_families': 'monospace', + 'colors': 'white', + 'bboxes': dict(facecolor='black', alpha=0.5, boxstyle='Round'), + 'vertical_alignments': 'top', + 'horizontal_alignments': 'left', + } + + def visualize(self, + inputs: InputsType, + preds: PredType, + return_vis: bool = False, + show: bool = False, + wait_time: int = 0, + draw_pred: bool = True, + pred_score_thr: float = 0.3, + no_save_vis: bool = False, + img_out_dir: str = '', + **kwargs) -> Union[List[np.ndarray], None]: + + if no_save_vis is True: + img_out_dir = '' + + if not show and img_out_dir == '' and not return_vis: + return None + + if self.visualizer is None: + raise ValueError('Visualization needs the "visualizer" term' + 'defined in the config, but got None.') + + results = [] + + text_cfg = self.DEFAULT_TEXT_CFG + + for single_input, pred in zip(inputs, preds): + if isinstance(single_input, str): + img_bytes = mmengine.fileio.get(single_input) + img = mmcv.imfrombytes(img_bytes) + img = img[:, :, ::-1] + img_name = osp.basename(single_input) + elif isinstance(single_input, np.ndarray): + img = single_input.copy() + img_num = str(self.num_visualized_imgs).zfill(8) + img_name = f'{img_num}.jpg' + else: + raise ValueError('Unsupported input type: ' + f'{type(single_input)}') + + out_file = osp.join(img_out_dir, 'vis', + img_name) if img_out_dir != '' else None + + self.visualizer.set_image(img) + + img_scale = get_adaptive_scale(img.shape[:2]) + text_cfg['font_sizes'] = int(img_scale * 7) + + self.visualizer.draw_texts( + pred.pred_caption, torch.tensor([img_scale * 5, + img_scale * 5]), **text_cfg) + drawn_img = self.visualizer.get_image() + + self.visualizer.add_datasample( + img_name, + drawn_img, + pred, + show=show, + wait_time=wait_time, + draw_gt=False, + draw_pred=draw_pred, + pred_score_thr=pred_score_thr, + out_file=out_file, + ) + results.append(self.visualizer.get_image()) + self.num_visualized_imgs += 1 + + return results + + +class RefImageCaptionInferencer(ImageCaptionInferencer): + + def _init_pipeline(self, cfg: ConfigType) -> Compose: + """Initialize the test pipeline.""" + pipeline_cfg = cfg.test_dataloader.dataset.pipeline + + # For inference, the key of ``img_id`` is not used. + if 'meta_keys' in pipeline_cfg[-1]: + pipeline_cfg[-1]['meta_keys'] = tuple( + meta_key for meta_key in pipeline_cfg[-1]['meta_keys'] + if meta_key != 'img_id') + + load_img_idx = self._get_transform_idx(pipeline_cfg, + 'LoadImageFromFile') + if load_img_idx == -1: + raise ValueError( + 'LoadImageFromFile is not found in the test pipeline') + pipeline_cfg[load_img_idx]['type'] = 'mmdet.InferencerLoader' + + caption_pipeline = Compose(pipeline_cfg) + + grounding_pipeline_cp = copy.deepcopy(pipeline_cfg) + grounding_pipeline_cp[1].scale = cfg.grounding_scale + grounding_pipeline = Compose(grounding_pipeline_cp) + + return { + 'grounding_pipeline': grounding_pipeline, + 'caption_pipeline': caption_pipeline + } + + def _get_chunk_data(self, inputs: Iterable, chunk_size: int): + """Get batch data from inputs. + + Args: + inputs (Iterable): An iterable dataset. + chunk_size (int): Equivalent to batch size. + + Yields: + list: batch data. + """ + inputs_iter = iter(inputs) + while True: + try: + chunk_data = [] + for _ in range(chunk_size): + inputs_ = next(inputs_iter) + if 'img' in inputs_: + ori_inputs_ = inputs_['img'] + else: + ori_inputs_ = inputs_['img_path'] + chunk_data.append( + (ori_inputs_, self.pipeline['grounding_pipeline']( + copy.deepcopy(inputs_)), + self.pipeline['caption_pipeline']( + copy.deepcopy(inputs_)))) + yield chunk_data + except StopIteration: + if chunk_data: + yield chunk_data + break + + def __call__( + self, + inputs: InputsType, + batch_size: int = 1, + return_vis: bool = False, + show: bool = False, + wait_time: int = 0, + no_save_vis: bool = False, + draw_pred: bool = True, + pred_score_thr: float = 0.3, + return_datasample: bool = False, + print_result: bool = False, + no_save_pred: bool = True, + out_dir: str = '', + texts: Optional[Union[str, list]] = None, + # by open panoptic task + stuff_texts: Optional[Union[str, list]] = None, + custom_entities: bool = False, # by GLIP + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (InputsType): Inputs for the inferencer. + batch_size (int): Inference batch size. Defaults to 1. + show (bool): Whether to display the visualization results in a + popup window. Defaults to False. + wait_time (float): The interval of show (s). Defaults to 0. + no_save_vis (bool): Whether to force not to save prediction + vis results. Defaults to False. + draw_pred (bool): Whether to draw predicted bounding boxes. + Defaults to True. + pred_score_thr (float): Minimum score of bboxes to draw. + Defaults to 0.3. + return_datasample (bool): Whether to return results as + :obj:`DetDataSample`. Defaults to False. + print_result (bool): Whether to print the inference result w/o + visualization to the console. Defaults to False. + no_save_pred (bool): Whether to force not to save prediction + results. Defaults to True. + out_file: Dir to save the inference results or + visualization. If left as empty, no file will be saved. + Defaults to ''. + + **kwargs: Other keyword arguments passed to :meth:`preprocess`, + :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. + Each key in kwargs should be in the corresponding set of + ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` + and ``postprocess_kwargs``. + + Returns: + dict: Inference and visualization results. + """ + assert batch_size == 1 + ( + preprocess_kwargs, + forward_kwargs, + visualize_kwargs, + postprocess_kwargs, + ) = self._dispatch_kwargs(**kwargs) + + ori_inputs = self._inputs_to_list(inputs) + + if isinstance(texts, str): + texts = [texts] * len(ori_inputs) + + for i in range(len(texts)): + if isinstance(ori_inputs[i], str): + ori_inputs[i] = { + 'text': texts[i], + 'img_path': ori_inputs[i], + 'custom_entities': custom_entities + } + else: + ori_inputs[i] = { + 'text': texts[i], + 'img': ori_inputs[i], + 'custom_entities': custom_entities + } + inputs = self.preprocess( + ori_inputs, batch_size=batch_size, **preprocess_kwargs) + + results_dict = {'predictions': [], 'visualization': []} + for ori_inputs, grounding_data, caption_data in track( + inputs, description='Inference'): + + self.model.sem_seg_head.task = 'ref-seg' + self.model.sem_seg_head.predictor.task = 'ref-seg' + preds = self.forward(grounding_data, **forward_kwargs) + + for data_sample, pred_datasmaple in zip( + caption_data['data_samples'], preds): + data_sample.pred_instances = pred_datasmaple.pred_instances + data_sample.set_metainfo({ + 'grounding_img_shape': + pred_datasmaple.metainfo['img_shape'] + }) + + self.model.sem_seg_head.task = 'caption' + self.model.sem_seg_head.predictor.task = 'caption' + + preds = self.forward(caption_data, **forward_kwargs) + + if isinstance(ori_inputs, dict): + ori_inputs = ori_inputs['img_path'] + + visualization = self.visualize( + ori_inputs, + preds, + return_vis=return_vis, + show=show, + wait_time=wait_time, + draw_pred=draw_pred, + pred_score_thr=pred_score_thr, + no_save_vis=no_save_vis, + img_out_dir=out_dir, + **visualize_kwargs) + results = self.postprocess( + preds, + visualization, + return_datasample=return_datasample, + print_result=print_result, + no_save_pred=no_save_pred, + pred_out_dir=out_dir, + **postprocess_kwargs) + results_dict['predictions'].extend(results['predictions']) + if results['visualization'] is not None: + results_dict['visualization'].extend(results['visualization']) + return results_dict diff --git a/projects/XDecoder/xdecoder/inference/texttoimage_regionretrieval_inferencer.py b/projects/XDecoder/xdecoder/inference/texttoimage_regionretrieval_inferencer.py new file mode 100644 index 00000000000..0aa091bbb24 --- /dev/null +++ b/projects/XDecoder/xdecoder/inference/texttoimage_regionretrieval_inferencer.py @@ -0,0 +1,226 @@ +import copy +from typing import Iterable, Optional, Union + +import torch +from mmengine.dataset import Compose +from rich.progress import track + +from mmdet.apis.det_inferencer import DetInferencer, InputsType +from mmdet.utils import ConfigType + + +class TextToImageRegionRetrievalInferencer(DetInferencer): + + def _init_pipeline(self, cfg: ConfigType) -> Compose: + """Initialize the test pipeline.""" + pipeline_cfg = cfg.test_dataloader.dataset.pipeline + + # For inference, the key of ``img_id`` is not used. + if 'meta_keys' in pipeline_cfg[-1]: + pipeline_cfg[-1]['meta_keys'] = tuple( + meta_key for meta_key in pipeline_cfg[-1]['meta_keys'] + if meta_key != 'img_id') + + load_img_idx = self._get_transform_idx(pipeline_cfg, + 'LoadImageFromFile') + if load_img_idx == -1: + raise ValueError( + 'LoadImageFromFile is not found in the test pipeline') + pipeline_cfg[load_img_idx]['type'] = 'mmdet.InferencerLoader' + + retrieval_pipeline = Compose(pipeline_cfg) + + grounding_pipeline_cp = copy.deepcopy(pipeline_cfg) + grounding_pipeline_cp[1].scale = cfg.grounding_scale + grounding_pipeline = Compose(grounding_pipeline_cp) + + return { + 'grounding_pipeline': grounding_pipeline, + 'retrieval_pipeline': retrieval_pipeline + } + + def _get_chunk_data(self, inputs: Iterable, pipeline, chunk_size: int): + """Get batch data from inputs. + + Args: + inputs (Iterable): An iterable dataset. + chunk_size (int): Equivalent to batch size. + + Yields: + list: batch data. + """ + inputs_iter = iter(inputs) + while True: + try: + chunk_data = [] + for _ in range(chunk_size): + inputs_ = next(inputs_iter) + chunk_data.append( + (inputs_, pipeline(copy.deepcopy(inputs_)))) + yield chunk_data + except StopIteration: + if chunk_data: + yield chunk_data + break + + def preprocess(self, + inputs: InputsType, + pipeline, + batch_size: int = 1, + **kwargs): + """Process the inputs into a model-feedable format. + + Customize your preprocess by overriding this method. Preprocess should + return an iterable object, of which each item will be used as the + input of ``model.test_step``. + + ``BaseInferencer.preprocess`` will return an iterable chunked data, + which will be used in __call__ like this: + + .. code-block:: python + + def __call__(self, inputs, batch_size=1, **kwargs): + chunked_data = self.preprocess(inputs, batch_size, **kwargs) + for batch in chunked_data: + preds = self.forward(batch, **kwargs) + + Args: + inputs (InputsType): Inputs given by user. + batch_size (int): batch size. Defaults to 1. + + Yields: + Any: Data processed by the ``pipeline`` and ``collate_fn``. + """ + chunked_data = self._get_chunk_data(inputs, pipeline, batch_size) + yield from map(self.collate_fn, chunked_data) + + def __call__( + self, + inputs: InputsType, + batch_size: int = 1, + return_vis: bool = False, + show: bool = False, + wait_time: int = 0, + no_save_vis: bool = False, + draw_pred: bool = True, + pred_score_thr: float = 0.3, + return_datasample: bool = False, + print_result: bool = False, + no_save_pred: bool = True, + out_dir: str = '', + texts: Optional[Union[str, list]] = None, + # by open panoptic task + stuff_texts: Optional[Union[str, list]] = None, + custom_entities: bool = False, # by GLIP + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (InputsType): Inputs for the inferencer. + batch_size (int): Inference batch size. Defaults to 1. + show (bool): Whether to display the visualization results in a + popup window. Defaults to False. + wait_time (float): The interval of show (s). Defaults to 0. + no_save_vis (bool): Whether to force not to save prediction + vis results. Defaults to False. + draw_pred (bool): Whether to draw predicted bounding boxes. + Defaults to True. + pred_score_thr (float): Minimum score of bboxes to draw. + Defaults to 0.3. + return_datasample (bool): Whether to return results as + :obj:`DetDataSample`. Defaults to False. + print_result (bool): Whether to print the inference result w/o + visualization to the console. Defaults to False. + no_save_pred (bool): Whether to force not to save prediction + results. Defaults to True. + out_file: Dir to save the inference results or + visualization. If left as empty, no file will be saved. + Defaults to ''. + + **kwargs: Other keyword arguments passed to :meth:`preprocess`, + :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. + Each key in kwargs should be in the corresponding set of + ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` + and ``postprocess_kwargs``. + + Returns: + dict: Inference and visualization results. + """ + ( + preprocess_kwargs, + forward_kwargs, + visualize_kwargs, + postprocess_kwargs, + ) = self._dispatch_kwargs(**kwargs) + + ori_inputs = self._inputs_to_list(inputs) + + if isinstance(texts, str): + texts = [texts] * len(ori_inputs) + + for i in range(len(texts)): + ori_inputs[i] = { + 'img_path': ori_inputs[i], + 'text': texts[i], + 'custom_entities': False + } + inputs = self.preprocess( + ori_inputs, + pipeline=self.pipeline['retrieval_pipeline'], + batch_size=batch_size, + **preprocess_kwargs) + + self.model.sem_seg_head._force_not_use_cache = True + + pred_scores = [] + for _, retrieval_data in track(inputs, description='Inference'): + preds = self.forward(retrieval_data, **forward_kwargs) + pred_scores.append(preds[0].pred_score) + + pred_score = torch.cat(pred_scores) + pred_score = torch.softmax(pred_score, dim=0) + max_id = torch.argmax(pred_score) + retrieval_ori_input = ori_inputs[max_id.item()] + max_prob = round(pred_score[max_id].item(), 3) + print( + 'The image that best matches the given text is ' + f"{retrieval_ori_input['img_path']} and probability is {max_prob}") + + inputs = self.preprocess([retrieval_ori_input], + pipeline=self.pipeline['grounding_pipeline'], + batch_size=1, + **preprocess_kwargs) + + self.model.task = 'ref-seg' + self.model.sem_seg_head.task = 'ref-seg' + self.model.sem_seg_head.predictor.task = 'ref-seg' + + ori_inputs, grounding_data = next(inputs) + + if isinstance(ori_inputs, dict): + ori_inputs = ori_inputs['img_path'] + + preds = self.forward(grounding_data, **forward_kwargs) + + visualization = self.visualize( + ori_inputs, + preds, + return_vis=return_vis, + show=show, + wait_time=wait_time, + draw_pred=draw_pred, + pred_score_thr=pred_score_thr, + no_save_vis=no_save_vis, + img_out_dir=out_dir, + **visualize_kwargs) + results = self.postprocess( + preds, + visualization, + return_datasample=return_datasample, + print_result=print_result, + no_save_pred=no_save_pred, + pred_out_dir=out_dir, + **postprocess_kwargs) + if results['visualization'] is not None: + results['visualization'] = results['visualization'] + return results diff --git a/projects/XDecoder/xdecoder/language_model.py b/projects/XDecoder/xdecoder/language_model.py new file mode 100644 index 00000000000..effe321825a --- /dev/null +++ b/projects/XDecoder/xdecoder/language_model.py @@ -0,0 +1,251 @@ +import os +from collections import OrderedDict + +import torch +from mmcv.cnn.bricks import DropPath +from torch import nn +from transformers import CLIPTokenizer + +from .utils import get_prompt_templates + +# modified from https://github.com/microsoft/X-Decoder/blob/main/xdecoder/language/vlpencoder.py # noqa + + +class LanguageEncoder(nn.Module): + + def __init__( + self, + tokenizer='openai/clip-vit-base-patch32', + dim_lang=512, + dim_projection=512, + ): + super().__init__() + + os.environ['TOKENIZERS_PARALLELISM'] = 'true' + self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer) + self.tokenizer.add_special_tokens( + {'cls_token': self.tokenizer.eos_token}) + + max_token_num = self.tokenizer.model_max_length + self.lang_encoder = Transformer(max_token_num, + self.tokenizer.vocab_size, dim_lang) + + self.lang_proj = nn.Parameter(torch.empty(dim_lang, dim_projection)) + self.max_token_num = max_token_num + self.logit_scale = nn.Parameter(torch.ones([])) + + @torch.no_grad() + def get_mean_embeds(self, class_names, name='default'): + + def extract_mean_emb(txts): + tokens = self.tokenizer( + txts, + padding='max_length', + truncation=True, + max_length=self.max_token_num, + return_tensors='pt') + clss_embedding, _ = self.forward_language( + (tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), + norm=True, + with_token_embed=False) + clss_embedding = clss_embedding.mean(dim=0) + clss_embedding /= clss_embedding.norm() + return clss_embedding + + templates = get_prompt_templates() + + clss_embeddings = [] + for clss in class_names: + txts = [ + template.format( + clss.replace('-other', + '').replace('-merged', + '').replace('-stuff', '')) + for template in templates + ] + clss_embeddings.append(extract_mean_emb(txts)) + + text_emb = torch.stack(clss_embeddings, dim=0) + setattr(self, '{}_text_embeddings'.format(name), text_emb) + + def get_text_embeds(self, txts, name='grounding', norm=False): + tokens = self.tokenizer( + txts, + padding='max_length', + truncation=True, + max_length=self.max_token_num, + return_tensors='pt') + tokens = {key: value.cuda() for key, value in tokens.items()} + class_emb, token_emb = self.forward_language( + (tokens['input_ids'], tokens['attention_mask']), norm=norm) + ret = { + 'tokens': tokens, + 'token_emb': token_emb, + 'class_emb': class_emb, + } + setattr(self, '{}_token_embeddings'.format(name), ret) + return ret + + def get_sot_token(self, device): + # 49406: CLIP SOT token <|startoftext|> + # 77: CLIP context_length + return torch.tensor([[49406] * 77], device=device) + + def compute_similarity(self, v_emb, name='default'): + v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7) + t_emb = getattr(self, '{}_text_embeddings'.format(name)) + output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose( + 1, 2) + return output + + def forward_language(self, + texts, + norm=False, + with_token_embed=True, + with_cls_embed=True): + x = self.lang_encoder(*texts) + hidden_x = x['last_hidden_state'] + + class_embed = None + if with_cls_embed: + class_embed = hidden_x[torch.arange(hidden_x.size(0)), + texts[0].argmax(dim=-1)] + + class_embed = class_embed @ self.lang_proj + if norm: + class_embed = class_embed / ( + class_embed.norm(dim=-1, keepdim=True) + 1e-7) + + hidden_embed = None + if with_token_embed: + hidden_embed = hidden_x @ self.lang_proj + if norm: + hidden_embed = hidden_embed / ( + hidden_embed.norm(dim=-1, keepdim=True) + 1e-7) + + return class_embed, hidden_embed + + +class Transformer(nn.Module): + + def __init__(self, + context_length, + vocab_size, + width, + layers: int = 12, + heads: int = 8, + drop_path: float = 0.0, + autogressive: bool = True): + super().__init__() + + self.token_embedding = nn.Embedding(vocab_size, width) + + self.context_length = context_length + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, width)) + + self.width = width + self.layers = layers + self.autogressive = autogressive + attn_mask = self.build_attention_mask() if autogressive else None + dpr = [x.item() for x in torch.linspace(0, drop_path, layers) + ] # stochastic depth decay rule + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock(width, heads, attn_mask, dpr[i]) + for i in range(layers) + ]) + + self.ln_final = LayerNorm(width) + + @property + def dim_out(self): + return self.width + + def build_attention_mask(self): + # lazily create causal attention mask, + # with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, input_ids, attention_mask=None): + key_padding_mask = (attention_mask == 0) if ( + not self.autogressive and attention_mask is not None) else None + x = self.token_embedding(input_ids) # [batch_size, n_ctx, d_model] + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + for block in self.resblocks: + x = block(x, key_padding_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_final(x) + + return {'last_hidden_state': x} + + +class LayerNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the + square root).""" + super(LayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + pdtype = x.dtype + x = x.float() + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x.to(pdtype) + self.bias + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None, + drop_path: float = 0.0): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def attention(self, + x: torch.Tensor, + key_padding_mask: torch.Tensor = None): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \ + if self.attn_mask is not None else None + + return self.attn( + x, + x, + x, + key_padding_mask=key_padding_mask, + need_weights=False, + attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None): + x = x + self.drop_path( + self.attention(self.ln_1(x), key_padding_mask=key_padding_mask)) + x = x + self.drop_path(self.mlp(self.ln_2(x))) + return x diff --git a/projects/XDecoder/xdecoder/pixel_decoder.py b/projects/XDecoder/xdecoder/pixel_decoder.py new file mode 100644 index 00000000000..79312ed7fce --- /dev/null +++ b/projects/XDecoder/xdecoder/pixel_decoder.py @@ -0,0 +1,214 @@ +from typing import Callable, Optional, Union + +from torch import nn +from torch.nn import functional as F + +from mmdet.registry import MODELS +from .transformer_blocks import (Conv2d, PositionEmbeddingSine, + TransformerEncoder, TransformerEncoderLayer, + get_norm) + +# modified from https://github.com/microsoft/X-Decoder/blob/main/xdecoder/body/encoder/transformer_encoder_fpn.py # noqa + + +class TransformerEncoderOnly(nn.Module): + + def __init__(self, + d_model=512, + nhead=8, + num_encoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation='relu', + normalize_before=False): + super().__init__() + + encoder_layer = TransformerEncoderLayer(d_model, nhead, + dim_feedforward, dropout, + activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, + encoder_norm) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, pos_embed): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + if mask is not None: + mask = mask.flatten(1) + + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + return memory.permute(1, 2, 0).view(bs, c, h, w) + + +class BasePixelDecoder(nn.Module): + + def __init__( + self, + in_channels, + conv_dim: int, + mask_dim: int, + mask_on: bool, + norm: Optional[Union[str, Callable]] = None, + ): + super().__init__() + + lateral_convs = [] + output_convs = [] + + use_bias = norm == '' + for idx, in_channel in enumerate(in_channels): + if idx == len(in_channels) - 1: + output_norm = get_norm(norm, conv_dim) + output_conv = Conv2d( + in_channel, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + self.add_module('layer_{}'.format(idx + 1), output_conv) + + lateral_convs.append(None) + output_convs.append(output_conv) + else: + lateral_norm = get_norm(norm, conv_dim) + output_norm = get_norm(norm, conv_dim) + + lateral_conv = Conv2d( + in_channel, + conv_dim, + kernel_size=1, + bias=use_bias, + norm=lateral_norm) + output_conv = Conv2d( + conv_dim, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + self.add_module('adapter_{}'.format(idx + 1), lateral_conv) + self.add_module('layer_{}'.format(idx + 1), output_conv) + + lateral_convs.append(lateral_conv) + output_convs.append(output_conv) + # Place convs into top-down order (from low to high resolution) + # to make the top-down computation in forward clearer. + self.lateral_convs = lateral_convs[::-1] + self.output_convs = output_convs[::-1] + + self.mask_on = mask_on + if self.mask_on: + self.mask_dim = mask_dim + self.mask_features = Conv2d( + conv_dim, + mask_dim, + kernel_size=3, + stride=1, + padding=1, + ) + self.maskformer_num_feature_levels = 3 + + +# To prevent conflicts with TransformerEncoderPixelDecoder in mask2former, +# we change the name to XTransformerEncoderPixelDecoder +@MODELS.register_module() +class XTransformerEncoderPixelDecoder(BasePixelDecoder): + + def __init__( + self, + in_channels, + transformer_dropout: float = 0.0, + transformer_nheads: int = 8, + transformer_dim_feedforward: int = 2048, + transformer_enc_layers: int = 6, + transformer_pre_norm: bool = False, + conv_dim: int = 512, + mask_dim: int = 512, + norm: Optional[Union[str, Callable]] = 'GN', + ): + + super().__init__( + in_channels, + conv_dim=conv_dim, + mask_dim=mask_dim, + norm=norm, + mask_on=True) + + self.in_features = ['res2', 'res3', 'res4', 'res5'] + feature_channels = in_channels + + in_channels = feature_channels[len(in_channels) - 1] + self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1) + self.transformer = TransformerEncoderOnly( + d_model=conv_dim, + dropout=transformer_dropout, + nhead=transformer_nheads, + dim_feedforward=transformer_dim_feedforward, + num_encoder_layers=transformer_enc_layers, + normalize_before=transformer_pre_norm, + ) + self.pe_layer = PositionEmbeddingSine(conv_dim // 2, normalize=True) + + # update layer + use_bias = norm == '' + output_norm = get_norm(norm, conv_dim) + output_conv = Conv2d( + conv_dim, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + delattr(self, 'layer_{}'.format(len(self.in_features))) + self.add_module('layer_{}'.format(len(self.in_features)), output_conv) + self.output_convs[0] = output_conv + + def forward(self, features): + multi_scale_features = [] + num_cur_levels = 0 + + # Reverse feature maps into top-down order + # (from low to high resolution) + for idx, f in enumerate(self.in_features[::-1]): + x = features[f] + lateral_conv = self.lateral_convs[idx] + output_conv = self.output_convs[idx] + if lateral_conv is None: + transformer = self.input_proj(x) + pos = self.pe_layer(x) + transformer = self.transformer(transformer, None, pos) + y = output_conv(transformer) + else: + cur_fpn = lateral_conv(x) + # Following FPN implementation, we use nearest upsampling here + y = cur_fpn + F.interpolate( + y, size=cur_fpn.shape[-2:], mode='nearest') + y = output_conv(y) + if num_cur_levels < self.maskformer_num_feature_levels: + multi_scale_features.append(y) + num_cur_levels += 1 + + mask_features = self.mask_features(y) + return mask_features, multi_scale_features diff --git a/projects/XDecoder/xdecoder/transformer_blocks.py b/projects/XDecoder/xdecoder/transformer_blocks.py new file mode 100755 index 00000000000..4e6861d643a --- /dev/null +++ b/projects/XDecoder/xdecoder/transformer_blocks.py @@ -0,0 +1,473 @@ +import copy +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +# modified from https://github.com/microsoft/X-Decoder/blob/main/xdecoder/body/transformer_blocks.py # noqa +"""Transformer class. + +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" + + +class Conv2d(torch.nn.Conv2d): + """A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and + more features.""" + + def __init__(self, *args, **kwargs): + """Extra keyword arguments supported in addition to those in + `torch.nn.Conv2d`: + + Args: + norm (nn.Module, optional): a normalization layer + activation (callable(Tensor) -> Tensor): a callable + activation function + + It assumes that norm layer is used before activation. + """ + norm = kwargs.pop('norm', None) + activation = kwargs.pop('activation', None) + super().__init__(*args, **kwargs) + + self.norm = norm + self.activation = activation + + def forward(self, x): + x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups) + if self.norm is not None: + x = self.norm(x) + if self.activation is not None: + x = self.activation(x) + return x + + +class PositionEmbeddingSine(nn.Module): + """This is a more standard version of the position embedding, very similar + to the one used by the Attention is all you need paper, generalized to work + on images.""" + + def __init__(self, + num_pos_feats=64, + temperature=10000, + normalize=False, + scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError('normalize should be True if scale is passed') + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask=None): + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), + device=x.device, + dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=x.dtype) + x_embed = not_mask.cumsum(2, dtype=x.dtype) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange( + self.num_pos_feats, dtype=x.dtype, device=x.device) + dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self, _repr_indent=4): + head = 'Positional encoding ' + self.__class__.__name__ + body = [ + 'num_pos_feats: {}'.format(self.num_pos_feats), + 'temperature: {}'.format(self.temperature), + 'normalize: {}'.format(self.normalize), + 'scale: {}'.format(self.scale), + ] + # _repr_indent = 4 + lines = [head] + [' ' * _repr_indent + line for line in body] + return '\n'.join(lines) + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + output = src + + for layer in self.layers: + output = layer( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerEncoderLayer(nn.Module): + + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation='relu', + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(src, pos) + + src2 = self.self_attn( + q, + k, + value=src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn( + q, + k, + value=src2, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class SelfAttentionLayer(nn.Module): + + def __init__(self, + d_model, + nhead, + dropout=0.0, + activation='relu', + normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn( + q, + k, + value=tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + + return tgt + + def forward_pre(self, + tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn( + q, + k, + value=tgt2, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + + return tgt + + def forward(self, + tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, tgt_mask, tgt_key_padding_mask, + query_pos) + return self.forward_post(tgt, tgt_mask, tgt_key_padding_mask, + query_pos) + + +class CrossAttentionLayer(nn.Module): + + def __init__(self, + d_model, + nhead, + dropout=0.0, + activation='relu', + normalize_before=False): + super().__init__() + self.multihead_attn = nn.MultiheadAttention( + d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + tgt, + memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2, avg_attn = self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt, avg_attn + + def forward_pre(self, + tgt, + memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm(tgt) + tgt2, avg_attn = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + tgt = tgt + self.dropout(tgt2) + + return tgt, avg_attn + + def forward(self, + tgt, + memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, memory_mask, + memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, memory_mask, + memory_key_padding_mask, pos, query_pos) + + +class FFNLayer(nn.Module): + + def __init__(self, + d_model, + dim_feedforward=2048, + dropout=0.0, + activation='relu', + normalize_before=False): + super().__init__() + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm = nn.LayerNorm(d_model) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt): + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt + + def forward_pre(self, tgt): + tgt2 = self.norm(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout(tgt2) + return tgt + + def forward(self, tgt): + if self.normalize_before: + return self.forward_pre(tgt) + return self.forward_post(tgt) + + +class MLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +def get_norm(norm, out_channels): + """ + Args: + norm (str or callable): either one of BN, SyncBN, FrozenBN, GN; + or a callable that takes a channel number and returns + the normalization layer as a nn.Module. + + Returns: + nn.Module or None: the normalization layer + """ + if norm is None: + return None + if isinstance(norm, str): + if len(norm) == 0: + return None + norm = { + 'BN': nn.BatchNorm2d, + 'GN': lambda channels: nn.GroupNorm(32, channels), + }[norm] + return norm(out_channels) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string.""" + if activation == 'relu': + return F.relu + if activation == 'gelu': + return F.gelu + if activation == 'glu': + return F.glu + raise RuntimeError(f'activation should be relu/gelu, not {activation}.') diff --git a/projects/XDecoder/xdecoder/transformer_decoder.py b/projects/XDecoder/xdecoder/transformer_decoder.py new file mode 100644 index 00000000000..4c1165b0e6e --- /dev/null +++ b/projects/XDecoder/xdecoder/transformer_decoder.py @@ -0,0 +1,439 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from mmdet.registry import MODELS +from .language_model import LanguageEncoder +from .transformer_blocks import (MLP, Conv2d, CrossAttentionLayer, FFNLayer, + PositionEmbeddingSine, SelfAttentionLayer) +from .utils import is_lower_torch_version + + +def vl_similarity(image_feat, text_feat, temperature=1): + logits = torch.matmul(image_feat, text_feat.t()) + logits = temperature.exp().clamp(max=100) * logits + return logits + + +@MODELS.register_module() +class XDecoderTransformerDecoder(nn.Module): + + def __init__( + self, + in_channels=512, + hidden_dim: int = 512, + dim_proj: int = 512, + num_queries: int = 101, + max_token_num: int = 77, + nheads: int = 8, + dim_feedforward: int = 2048, + decoder_layers: int = 9, + pre_norm: bool = False, + mask_dim: int = 512, + task: str = 'semseg', + captioning_step: int = 50, + ): + super().__init__() + + # positional encoding + self.pe_layer = PositionEmbeddingSine(hidden_dim // 2, normalize=True) + + # define transformer decoder here + self.num_heads = nheads + self.num_layers = decoder_layers + self.max_token_num = max_token_num + self.transformer_self_attention_layers = nn.ModuleList() + self.transformer_cross_attention_layers = nn.ModuleList() + self.transformer_ffn_layers = nn.ModuleList() + + for _ in range(self.num_layers): + self.transformer_self_attention_layers.append( + SelfAttentionLayer( + d_model=hidden_dim, + nhead=nheads, + dropout=0.0, + normalize_before=pre_norm, + )) + + self.transformer_cross_attention_layers.append( + CrossAttentionLayer( + d_model=hidden_dim, + nhead=nheads, + dropout=0.0, + normalize_before=pre_norm, + )) + + self.transformer_ffn_layers.append( + FFNLayer( + d_model=hidden_dim, + dim_feedforward=dim_feedforward, + dropout=0.0, + normalize_before=pre_norm, + )) + + self.decoder_norm = nn.LayerNorm(hidden_dim) + + self.num_queries = num_queries + # learnable query features + self.query_feat = nn.Embedding(num_queries, hidden_dim) + # learnable query p.e. + self.query_embed = nn.Embedding(num_queries, hidden_dim) + + # level embedding (always use 3 scales) + self.num_feature_levels = 3 + self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim) + self.input_proj = nn.ModuleList() + + for _ in range(self.num_feature_levels): + if in_channels != hidden_dim: + self.input_proj.append( + Conv2d(in_channels, hidden_dim, kernel_size=1)) + else: + self.input_proj.append(nn.Sequential()) + + self.task = task + + # output FFNs + self.lang_encoder = LanguageEncoder() + + self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) + self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj)) + + # for caption and ref-caption + self.caping_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj)) + self.pos_embed_caping = nn.Embedding(max_token_num, hidden_dim) + self.captioning_step = captioning_step + + # register self_attn_mask to avoid information leakage, + # it includes interaction between object query, class query and + # caption query + self_attn_mask = torch.zeros((1, num_queries + max_token_num, + num_queries + max_token_num)).bool() + # object+class query does not attend with caption query. + self_attn_mask[:, :num_queries, num_queries:] = True + # caption query only attend with previous token. + self_attn_mask[:, num_queries:, num_queries:] = torch.triu( + torch.ones((1, max_token_num, max_token_num)), diagonal=1).bool() + # object query does not attend with class query. + self_attn_mask[:, :num_queries - 1, num_queries - 1:num_queries] = True + # class query does not attend with object query. + self_attn_mask[:, num_queries - 1:num_queries, :num_queries - 1] = True + self.register_buffer('self_attn_mask', self_attn_mask) + + def forward(self, x, mask_features, extra=None): + if self.task == 'caption': + return self.forward_caption(x, mask_features, extra) + + assert len(x) == self.num_feature_levels + src = [] + pos = [] + size_list = [] + + for i in range(self.num_feature_levels): + size_list.append(x[i].shape[-2:]) + pos.append(self.pe_layer(x[i], None).flatten(2)) + src.append(self.input_proj[i](x[i]).flatten(2) + + self.level_embed.weight[i][None, :, None]) + + # flatten NxCxHxW to HWxNxC + pos[-1] = pos[-1].permute(2, 0, 1) + src[-1] = src[-1].permute(2, 0, 1) + + _, bs, _ = src[0].shape + + query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1) + output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1) + + predictions_mask = [] + predictions_class_embed = [] + + if self.task == 'ref-seg': + self_tgt_mask = self.self_attn_mask[:, :self.num_queries, :self. + num_queries].repeat( + output.shape[1] * + self.num_heads, 1, 1) + grounding_tokens = extra['grounding_tokens'] + _grounding_tokens = grounding_tokens.detach().clone() + # initialize with negative attention at the beginning. + pad_tgt_mask = torch.ones( + (1, self.num_queries + (self.num_queries - 1) + + len(grounding_tokens), self.num_queries + + (self.num_queries - 1) + len(grounding_tokens)), + device=self_tgt_mask.device).bool().repeat( + output.shape[1] * self.num_heads, 1, 1) + pad_tgt_mask[:, :self.num_queries, :self. + num_queries] = self_tgt_mask + # grounding tokens could attend with eatch other + pad_tgt_mask[:, self.num_queries:, self.num_queries:] = False + self_tgt_mask = pad_tgt_mask + output = torch.cat((output, output[:-1]), dim=0) + # also pad language embdding to fix embedding + query_embed = torch.cat((query_embed, query_embed[:-1]), dim=0) + else: + self_tgt_mask = self.self_attn_mask[:, :self.num_queries, :self. + num_queries].repeat( + output.shape[1] * + self.num_heads, 1, 1) + + results = self.forward_prediction_heads( + output, mask_features, attn_mask_target_size=size_list[0]) + attn_mask = results['attn_mask'] + predictions_class_embed.append(results['class_embed']) + predictions_mask.append(results['outputs_mask']) + + for i in range(self.num_layers): + level_index = i % self.num_feature_levels + attn_mask[torch.where( + attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # attention: cross-attention first + output, avg_attn = self.transformer_cross_attention_layers[i]( + output, + src[level_index], + memory_mask=attn_mask, + # here we do not apply masking on padded region + memory_key_padding_mask=None, + pos=pos[level_index], + query_pos=query_embed) + + if self.task == 'ref-seg': + output = torch.cat((output, _grounding_tokens), dim=0) + query_embed = torch.cat((query_embed, grounding_tokens), dim=0) + + output = self.transformer_self_attention_layers[i]( + output, + tgt_mask=self_tgt_mask, + tgt_key_padding_mask=None, + query_pos=query_embed) + + output = self.transformer_ffn_layers[i](output) + + if self.task == 'ref-seg': + _grounding_tokens = output[-len(_grounding_tokens):] + output = output[:-len(_grounding_tokens)] + query_embed = query_embed[:-len(_grounding_tokens)] + + results = self.forward_prediction_heads( + output, + mask_features, + attn_mask_target_size=size_list[(i + 1) % + self.num_feature_levels]) + attn_mask = results['attn_mask'] + predictions_mask.append(results['outputs_mask']) + predictions_class_embed.append(results['class_embed']) + + out = { + 'pred_masks': predictions_mask[-1], + 'pred_class_embed': predictions_class_embed[-1], + } + + if self.task == 'ref-seg': + mask_pred_results = [] + outputs_class = [] + for idx in range(mask_features.shape[0]): # batch size + pred_gmasks = out['pred_masks'][idx, self.num_queries:2 * + self.num_queries - 1] + v_emb = predictions_class_embed[-1][idx, self.num_queries:2 * + self.num_queries - 1] + t_emb = extra['class_emb'] + + t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7) + v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7) + + temperature = self.lang_encoder.logit_scale + out_prob = vl_similarity(v_emb, t_emb, temperature=temperature) + + matched_id = out_prob.max(0)[1] + mask_pred_results += [pred_gmasks[matched_id, :, :]] + outputs_class += [out_prob[matched_id, :]] + out['pred_masks'] = mask_pred_results + out['pred_logits'] = outputs_class + elif self.task == 'retrieval': + t_emb = extra['class_emb'] + temperature = self.lang_encoder.logit_scale + v_emb = out['pred_class_embed'][:, -1, :] + v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7) + logits = vl_similarity(v_emb, t_emb, temperature) + out['pred_logits'] = logits + elif self.task in ['semseg', 'instance', 'panoptic']: + outputs_class = self.lang_encoder.compute_similarity( + out['pred_class_embed']) + out['pred_logits'] = outputs_class + return out + + def forward_caption(self, x, mask_features, extra=None): + assert len(x) == self.num_feature_levels + src = [] + pos = [] + size_list = [] + + for i in range(self.num_feature_levels): + size_list.append(x[i].shape[-2:]) + pos.append(self.pe_layer(x[i], None).flatten(2)) + src.append(self.input_proj[i](x[i]).flatten(2) + + self.level_embed.weight[i][None, :, None]) + + # flatten NxCxHxW to HWxNxC + pos[-1] = pos[-1].permute(2, 0, 1) + src[-1] = src[-1].permute(2, 0, 1) + + _, bs, _ = src[0].shape + + # QxNxC + query_embed_ = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1) + query_feat = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1) + lang_token = extra['start_token'].repeat(bs, 1) + pos_embed = self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1) + + # prepare token embedding for evaluation + token_embs = self.lang_encoder.lang_encoder.token_embedding.weight + + for cap_idx in range(0, self.captioning_step): + lang_embed = self.lang_encoder.forward_language( + (lang_token, ), with_cls_embed=False)[1].transpose(0, 1) + # concat object query, class token and caption token. + output = torch.cat((query_feat, lang_embed), dim=0) + lang_embed += pos_embed + query_embed = torch.cat((query_embed_, lang_embed), dim=0) + + # prediction heads on learnable query features + results = self.forward_prediction_heads( + output, mask_features, attn_mask_target_size=size_list[0]) + attn_mask = results['attn_mask'] + + for i in range(self.num_layers): + level_index = i % self.num_feature_levels + attn_mask[torch.where( + attn_mask.sum(-1) == attn_mask.shape[-1])] = False + attn_mask = torch.cat( + (attn_mask, + torch.zeros_like(attn_mask[:, :self.max_token_num, :])), + dim=1) + self_tgt_mask = self.self_attn_mask.repeat( + output.shape[1] * self.num_heads, 1, 1) + + if 'grounding_mask' in extra: + bs, nq, wh = attn_mask.shape + assert bs == self.num_heads, 'Only support single ' \ + 'image referring captioning.' + grounding_mask = extra['grounding_mask'] + attn_mask = attn_mask.reshape(bs, nq, size_list[i % 3][0], + size_list[i % 3][1]) + grounding_mask = F.interpolate( + grounding_mask.float(), + size_list[i % 3], + mode='nearest').bool()[0, 0] + attn_mask[:, self.num_queries:, grounding_mask] = True + attn_mask = attn_mask.reshape(bs, nq, wh) + + # attention: cross-attention first + output, avg_attn = self.transformer_cross_attention_layers[i]( + output, + src[level_index], + memory_mask=attn_mask, + # here we do not apply masking on padded region + memory_key_padding_mask=None, + pos=pos[level_index], + query_pos=query_embed) + + output = self.transformer_self_attention_layers[i]( + output, + tgt_mask=self_tgt_mask, + tgt_key_padding_mask=None, + query_pos=query_embed) + + output = self.transformer_ffn_layers[i](output) + + results = self.forward_prediction_heads( + output, + mask_features, + attn_mask_target_size=size_list[(i + 1) % + self.num_feature_levels]) + attn_mask = results['attn_mask'] + + pred_captions = results['outputs_caption'] + pred_captions = pred_captions @ token_embs.t() + lang_token[:, cap_idx + 1] = pred_captions[:, cap_idx].max(-1)[1] + + texts = self.lang_encoder.tokenizer.batch_decode( + lang_token, skip_special_tokens=False) + texts_new = [] + + for x in texts: + x = x.split('<|endoftext|>')[0] + x = x.replace('<|endoftext|>', '') + x = x.replace('<|startoftext|>', '') + x = x.strip() + texts_new.append(x) + + out = {'pred_caption': texts_new} + return out + + def forward_prediction_heads(self, output, mask_features, + attn_mask_target_size): + decoder_output = self.decoder_norm(output) + decoder_output = decoder_output.transpose(0, 1) + + if self.task == 'caption': + outputs_caption = decoder_output[:, self. + num_queries:] @ self.caping_embed + + # recompute class token output. + norm_decoder_output = decoder_output / ( + decoder_output.norm(dim=-1, keepdim=True) + 1e-7) + obj_token = norm_decoder_output[:, :self.num_queries - 1] + cls_token = norm_decoder_output[:, + self.num_queries - 1:self.num_queries] + + sim = (cls_token @ obj_token.transpose(1, 2)).softmax(-1)[:, 0, :, + None] + cls_token = (sim * decoder_output[:, :self.num_queries - 1]).sum( + dim=1, keepdim=True) + + if self.task == 'ref-seg': + decoder_output = torch.cat( + (decoder_output[:, :self.num_queries - 1], cls_token, + decoder_output[:, self.num_queries:2 * self.num_queries - 1]), + dim=1) + else: + decoder_output = torch.cat( + (decoder_output[:, :self.num_queries - 1], cls_token), dim=1) + + mask_embed = self.mask_embed(decoder_output) + outputs_mask = torch.einsum('bqc,bchw->bqhw', mask_embed, + mask_features) + + if is_lower_torch_version(): + attn_mask = F.interpolate( + outputs_mask, + size=attn_mask_target_size, + mode='bicubic', + align_corners=False) + else: + attn_mask = F.interpolate( + outputs_mask, + size=attn_mask_target_size, + mode='bicubic', + align_corners=False, + antialias=True) + + attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat( + 1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool() + attn_mask = attn_mask.detach() + + attn_mask[:, self.num_queries:self.num_queries + 1].fill_(False) + + if self.task == 'caption': + results = { + 'attn_mask': attn_mask, + 'outputs_caption': outputs_caption, + } + return results + else: + class_embed = decoder_output @ self.class_embed + results = { + 'outputs_mask': outputs_mask, + 'attn_mask': attn_mask, + 'class_embed': class_embed, + } + return results diff --git a/projects/XDecoder/xdecoder/unified_head.py b/projects/XDecoder/xdecoder/unified_head.py new file mode 100644 index 00000000000..ec852b1d0df --- /dev/null +++ b/projects/XDecoder/xdecoder/unified_head.py @@ -0,0 +1,363 @@ +import copy +from typing import Sequence + +import torch +from mmengine.structures import InstanceData, PixelData +from torch import nn +from torch.nn import functional as F + +from mmdet.evaluation.functional import INSTANCE_OFFSET +from mmdet.registry import MODELS +from .utils import (is_lower_torch_version, retry_if_cuda_oom, + sem_seg_postprocess) + + +@MODELS.register_module() +class XDecoderUnifiedhead(nn.Module): + + def __init__(self, + in_channels: int, + pixel_decoder: nn.Module, + transformer_decoder: nn.Module, + task: str = 'semseg', + test_cfg=None): + super().__init__() + self.task = task + self.test_cfg = test_cfg + + pixel_decoder_ = copy.deepcopy(pixel_decoder) + pixel_decoder_.update(in_channels=in_channels) + self.pixel_decoder = MODELS.build(pixel_decoder_) + + transformer_decoder_ = copy.deepcopy(transformer_decoder) + transformer_decoder_.update(task=task) + self.predictor = MODELS.build(transformer_decoder_) + + self.return_inter_mask = False + if self.task == 'ref-caption': + # ref-caption = ref-seg + caption, + # so we need to return the intermediate mask + self.return_inter_mask = True + + self._all_text_prompts = None + self._extra = None + # TODO: Very trick, for retrieval task + self._force_not_use_cache = False + + def pre_process(self, batch_data_samples, device): + extra = {} + if self.task != 'caption': + # have text + all_text_prompts = [] + num_thing_class = 0 + for data_samples in batch_data_samples: + if isinstance(data_samples.text, str): + text = data_samples.text.split('.') + elif isinstance(data_samples.text, Sequence): + text = data_samples.text + else: + raise TypeError( + 'Type pf data_sample.text must be sequence or str') + text = list(filter(lambda x: len(x) > 0, text)) + all_text_prompts.append(text) + num_thing_class = len(text) + # for panoptic + if 'stuff_text' in data_samples: + if isinstance(data_samples.stuff_text, str): + text = data_samples.stuff_text.split('.') + elif isinstance(data_samples.stuff_text, Sequence): + text = data_samples.stuff_text + else: + raise TypeError('Type pf data_sample.stuff_text ' + 'must be sequence or str') + text = list(filter(lambda x: len(x) > 0, text)) + all_text_prompts[-1].extend(text) + + # TODO: support batch + all_text_prompts = all_text_prompts[0] + + if all_text_prompts != self._all_text_prompts \ + or self._force_not_use_cache: + # avoid redundant computation + self._all_text_prompts = all_text_prompts + if self.task in ['semseg', 'instance', 'panoptic']: + self.predictor.lang_encoder.get_mean_embeds( + all_text_prompts + ['background']) + elif self.task == 'ref-seg': + token_info = self.predictor.lang_encoder.get_text_embeds( + all_text_prompts, norm=False) + token_emb = token_info['token_emb'] + tokens = token_info['tokens'] + query_emb = token_emb[tokens['attention_mask'].bool()] + extra['grounding_tokens'] = query_emb[:, None] + extra['class_emb'] = token_info['class_emb'] + elif self.task == 'retrieval': + token_info = self.predictor.lang_encoder.get_text_embeds( + all_text_prompts, norm=True) + extra['class_emb'] = token_info['class_emb'] + self._extra = extra + return extra, all_text_prompts, num_thing_class + else: + return self._extra, all_text_prompts, num_thing_class + else: + if not hasattr(self, 'start_token'): + self.start_token = self.predictor.lang_encoder. \ + get_sot_token(device=device) + extra['start_token'] = self.start_token + return extra, None, None + + def predict(self, features, batch_data_samples): + # multi scale feature + mask_features, multi_scale_features = self.pixel_decoder(features) + + # pre process + extra, all_text_prompts, num_thing_class = self.pre_process( + batch_data_samples, mask_features.device) + + # transformer decoder forward + predictions = self.predictor( + multi_scale_features, mask_features, extra=extra) + + # post process + return self.post_process(predictions, batch_data_samples, + all_text_prompts, num_thing_class) + + def post_process(self, predictions, batch_data_samples, all_text_prompts, + num_thing_class): + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + batch_input_shape = batch_data_samples[0].metainfo['batch_input_shape'] + + if self.task == 'caption': + for text, data_samples in zip(predictions['pred_caption'], + batch_data_samples): + data_samples.pred_caption = text + + if 'pred_instances' in batch_data_samples[0]: + for img_metas, data_samples in zip(batch_img_metas, + batch_data_samples): + original_caption = data_samples.text.split('.') + text_prompts = list( + filter(lambda x: len(x) > 0, original_caption)) + + height = img_metas['ori_shape'][0] + width = img_metas['ori_shape'][1] + image_size = img_metas['grounding_img_shape'][:2] + + mask_pred_result = data_samples.pred_instances.masks.float( + ) + mask_cls_result = data_samples.pred_instances.scores.float( + ) + + mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( + mask_pred_result, image_size, height, width) + + pred_instances = retry_if_cuda_oom( + self._instance_inference)(mask_cls_result, + mask_pred_result, + text_prompts) + data_samples.pred_instances = pred_instances + + elif self.task in ['semseg', 'instance', 'panoptic']: + mask_pred_results = predictions['pred_masks'] + mask_cls_results = predictions['pred_logits'] + if is_lower_torch_version(): + mask_pred_results = F.interpolate( + mask_pred_results, + size=(batch_input_shape[-2], batch_input_shape[-1]), + mode='bicubic', + align_corners=False) + else: + mask_pred_results = F.interpolate( + mask_pred_results, + size=(batch_input_shape[-2], batch_input_shape[-1]), + mode='bicubic', + align_corners=False, + antialias=True) + + # for batch + for mask_cls_result, \ + mask_pred_result, \ + img_metas, \ + data_samples in zip( + mask_cls_results, + mask_pred_results, + batch_img_metas, + batch_data_samples): + height = img_metas['ori_shape'][0] + width = img_metas['ori_shape'][1] + image_size = img_metas['img_shape'][:2] + mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( + mask_pred_result, image_size, height, width) + mask_cls_result = mask_cls_result.to(mask_pred_result) + + if self.task == 'semseg': + pred_sem_seg = retry_if_cuda_oom(self._semantic_inference)( + mask_cls_result, mask_pred_result, all_text_prompts) + data_samples.pred_sem_seg = pred_sem_seg + elif self.task == 'instance': + pred_instances = retry_if_cuda_oom( + self._instance_inference)(mask_cls_result, + mask_pred_result, + all_text_prompts) + data_samples.pred_instances = pred_instances + elif self.task == 'panoptic': + pred_panoptic_seg = retry_if_cuda_oom( + self._panoptic_inference)(mask_cls_result, + mask_pred_result, + all_text_prompts, + num_thing_class) + data_samples.pred_panoptic_seg = pred_panoptic_seg + elif self.task == 'ref-seg': + mask_pred_results = predictions['pred_masks'] + mask_cls_results = predictions['pred_logits'] + results_ = zip(mask_pred_results, mask_cls_results, + batch_img_metas, batch_data_samples) + for mask_pred_result, mask_cls_result, \ + img_metas, data_samples in results_: + if is_lower_torch_version(): + mask_pred_result = F.interpolate( + mask_pred_result[None], + size=(batch_input_shape[-2], batch_input_shape[-1]), + mode='bicubic', + align_corners=False)[0] + else: + mask_pred_result = F.interpolate( + mask_pred_result[None], + size=(batch_input_shape[-2], batch_input_shape[-1]), + mode='bicubic', + align_corners=False, + antialias=True)[0] + + if self.return_inter_mask: + mask = mask_pred_result > 0 + pred_instances = InstanceData() + pred_instances.masks = mask + pred_instances.scores = mask_cls_result + data_samples.pred_instances = pred_instances + continue + + height = img_metas['ori_shape'][0] + width = img_metas['ori_shape'][1] + image_size = img_metas['img_shape'][:2] + mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( + mask_pred_result, image_size, height, width) + + pred_instances = retry_if_cuda_oom(self._instance_inference)( + mask_cls_result, mask_pred_result, all_text_prompts) + data_samples.pred_instances = pred_instances + elif self.task == 'retrieval': + batch_data_samples[0].pred_score = predictions['pred_logits'] + return batch_data_samples + + def _instance_inference(self, mask_cls, mask_pred, text_prompts): + num_class = len(text_prompts) + + if self.task in ['ref-seg', 'caption']: + scores = F.softmax(mask_cls, dim=-1) + scores_per_image = scores.max(dim=-1)[0] + labels_per_image = torch.arange(num_class) + else: + scores = F.softmax(mask_cls, dim=-1)[:, :-1] + + labels = torch.arange( + num_class, + device=scores.device).unsqueeze(0).repeat(scores.shape[0], + 1).flatten(0, 1) + scores_per_image, topk_indices = scores.flatten(0, 1).topk( + self.test_cfg.get('max_per_img', 100), sorted=False) + + labels_per_image = labels[topk_indices] + topk_indices = (topk_indices // num_class) + mask_pred = mask_pred[topk_indices] + + result = InstanceData() + mask_pred = mask_pred.sigmoid() + result.masks = (mask_pred > self.test_cfg.mask_thr).float() + + # calculate average mask prob + mask_scores_per_image = (mask_pred.flatten(1) * + result.masks.flatten(1)).sum(1) / ( + result.masks.flatten(1).sum(1) + 1e-6) + result.scores = scores_per_image * mask_scores_per_image + result.labels = labels_per_image + result.label_names = [ + text_prompts[label] for label in labels_per_image + ] + result.bboxes = result.scores.new_zeros(len(result.scores), 4) + return result + + def _semantic_inference(self, mask_cls, mask_pred, text_prompts): + mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + sem_seg = torch.einsum('qc,qhw->chw', mask_cls, mask_pred) + + if sem_seg.shape[0] == 1: + # 0 is foreground, ignore_index is background + sem_seg = (sem_seg.squeeze(0) <= self.test_cfg.mask_thr).int() + sem_seg[sem_seg == 1] = self.test_cfg.get('ignore_index', 255) + else: + # 0 is foreground, ignore_index is background + if self.test_cfg.use_thr_for_mc: + foreground_flag = sem_seg > self.test_cfg.mask_thr + sem_seg = sem_seg.max(0)[1] + sem_seg[foreground_flag.sum(0) == 0] = self.test_cfg.get( + 'ignore_index', 255) + else: + sem_seg = sem_seg.max(0)[1] + pred_sem_seg = PixelData( + sem_seg=sem_seg[None], + metainfo={ + 'label_names': text_prompts, + 'ignore_index': self.test_cfg.get('ignore_index', 255) + }) + return pred_sem_seg + + def _panoptic_inference(self, mask_cls, mask_pred, all_text_prompts, + num_thing_class): + scores, labels = F.softmax(mask_cls, dim=-1).max(-1) + mask_pred = mask_pred.sigmoid() + + keep = labels.ne(len(all_text_prompts)) & ( + scores > self.test_cfg.mask_thr) + cur_scores = scores[keep] + cur_classes = labels[keep] + cur_masks = mask_pred[keep] + cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks + + h, w = cur_masks.shape[-2:] + panoptic_seg = torch.full((h, w), + self.test_cfg.get('ignore_index', 255), + dtype=torch.int32, + device=cur_masks.device) + instance_id = 1 + + if cur_masks.shape[0] > 0: + cur_mask_ids = cur_prob_masks.argmax(0) + for k in range(cur_classes.shape[0]): + pred_class = cur_classes[k].item() + isthing = int(pred_class) < num_thing_class + mask_area = (cur_mask_ids == k).sum().item() + original_area = (cur_masks[k] >= 0.5).sum().item() + mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5) + + if mask_area > 0 and original_area > 0 and mask.sum().item( + ) > 0: + if mask_area / original_area < self.test_cfg.overlap_thr: + continue + # merge stuff regions + if not isthing: + panoptic_seg[mask] = int(pred_class) + else: + panoptic_seg[mask] = int( + pred_class) + instance_id * INSTANCE_OFFSET + instance_id += 1 + + panoptic_seg = PixelData( + sem_seg=panoptic_seg[None], + metainfo={ + 'label_names': all_text_prompts, + 'ignore_index': self.test_cfg.get('ignore_index', 255) + }) + return panoptic_seg diff --git a/projects/XDecoder/xdecoder/utils.py b/projects/XDecoder/xdecoder/utils.py new file mode 100644 index 00000000000..5cbf1760d6a --- /dev/null +++ b/projects/XDecoder/xdecoder/utils.py @@ -0,0 +1,215 @@ +import logging +from contextlib import contextmanager +from functools import wraps + +import torch +from mmcv.cnn.bricks.wrappers import obsolete_torch_version +from torch.nn import functional as F + +TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2]) + + +def is_lower_torch_version(version=(1, 10)): + """Check if the pytorch version is lower than "version.""" + return obsolete_torch_version(TORCH_VERSION, version) + + +@contextmanager +def _ignore_torch_cuda_oom(): + """A context which ignores CUDA OOM exception from pytorch.""" + try: + yield + except RuntimeError as e: + if 'CUDA out of memory. ' in str(e): + pass + else: + raise + + +def retry_if_cuda_oom(func): + """Makes a function retry itself after encountering pytorch's CUDA OOM + error. It will first retry after calling `torch.cuda.empty_cache()`. + + If that still fails, it will then retry by trying to convert inputs + to CPUs. In this case, it expects the function to dispatch to CPU + implementation. The return values may become CPU tensors as well + and it's user's responsibility to convert it back to CUDA tensor + if needed. + + Args: + func: a stateless callable that takes tensor-like objects as arguments + + Returns: + a callable which retries `func` if OOM is encountered. + + Examples: + :: + output = retry_if_cuda_oom(some_torch_function)(input1, input2) + # output may be on CPU even if inputs are on GPU + + Note: + 1. When converting inputs to CPU, it will only + look at each argument and check if it has `.device` + and `.to` for conversion. Nested structures of tensors + are not supported. + + 2. Since the function might be called more than once, it has to be + stateless. + """ + + def maybe_to_cpu(x): + try: + like_gpu_tensor = x.device.type == 'cuda' and hasattr(x, 'to') + except AttributeError: + like_gpu_tensor = False + if like_gpu_tensor: + return x.to(device='cpu') + else: + return x + + @wraps(func) + def wrapped(*args, **kwargs): + with _ignore_torch_cuda_oom(): + return func(*args, **kwargs) + + # Clear cache and retry + torch.cuda.empty_cache() + with _ignore_torch_cuda_oom(): + return func(*args, **kwargs) + + # Try on CPU. This slows down the code significantly, + # therefore print a notice. + logger = logging.getLogger(__name__) + logger.info( + 'Attempting to copy inputs of {} to CPU due to CUDA OOM'.format( + str(func)[0:5])) + new_args = (maybe_to_cpu(x) for x in args) + new_kwargs = {k: maybe_to_cpu(v) for k, v in kwargs.items()} + return func(*new_args, **new_kwargs) + + return wrapped + + +def sem_seg_postprocess(result, img_size, output_height, output_width): + """Return semantic segmentation predictions in the original resolution. + + The input images are often resized when entering semantic segmentor. + Moreover, in same cases, they also padded inside segmentor to be + divisible by maximum network stride. As a result, we often need + the predictions of the segmentor in a different resolution from + its inputs. + + Args: + result (Tensor): semantic segmentation prediction logits. + A tensor of shape (C, H, W), where C is the number of classes, + and H, W are the height and width of the prediction. + img_size (tuple): image size that segmentor is taking as input. + output_height, output_width: the desired output resolution. + + Returns: + semantic segmentation prediction (Tensor): A tensor of the shape + (C, output_height, output_width) that contains per-pixel + soft predictions. + """ + result = result[:, :img_size[0], :img_size[1]].expand(1, -1, -1, -1) + if is_lower_torch_version(): + result = F.interpolate( + result, + size=(output_height, output_width), + mode='bicubic', + align_corners=False)[0] + else: + result = F.interpolate( + result, + size=(output_height, output_width), + mode='bicubic', + align_corners=False, + antialias=True)[0] + return result + + +def get_prompt_templates(): + prompt_templates = [ + '{}.', + 'a photo of a {}.', + 'a bad photo of a {}.', + 'a photo of many {}.', + 'a sculpture of a {}.', + 'a photo of the hard to see {}.', + 'a low resolution photo of the {}.', + 'a rendering of a {}.', + 'graffiti of a {}.', + 'a bad photo of the {}.', + 'a cropped photo of the {}.', + 'a tattoo of a {}.', + 'the embroidered {}.', + 'a photo of a hard to see {}.', + 'a bright photo of a {}.', + 'a photo of a clean {}.', + 'a photo of a dirty {}.', + 'a dark photo of the {}.', + 'a drawing of a {}.', + 'a photo of my {}.', + 'the plastic {}.', + 'a photo of the cool {}.', + 'a close-up photo of a {}.', + 'a black and white photo of the {}.', + 'a painting of the {}.', + 'a painting of a {}.', + 'a pixelated photo of the {}.', + 'a sculpture of the {}.', + 'a bright photo of the {}.', + 'a cropped photo of a {}.', + 'a plastic {}.', + 'a photo of the dirty {}.', + 'a jpeg corrupted photo of a {}.', + 'a blurry photo of the {}.', + 'a photo of the {}.', + 'a good photo of the {}.', + 'a rendering of the {}.', + 'a {} in a video game.', + 'a photo of one {}.', + 'a doodle of a {}.', + 'a close-up photo of the {}.', + 'the origami {}.', + 'the {} in a video game.', + 'a sketch of a {}.', + 'a doodle of the {}.', + 'a origami {}.', + 'a low resolution photo of a {}.', + 'the toy {}.', + 'a rendition of the {}.', + 'a photo of the clean {}.', + 'a photo of a large {}.', + 'a rendition of a {}.', + 'a photo of a nice {}.', + 'a photo of a weird {}.', + 'a blurry photo of a {}.', + 'a cartoon {}.', + 'art of a {}.', + 'a sketch of the {}.', + 'a embroidered {}.', + 'a pixelated photo of a {}.', + 'itap of the {}.', + 'a jpeg corrupted photo of the {}.', + 'a good photo of a {}.', + 'a plushie {}.', + 'a photo of the nice {}.', + 'a photo of the small {}.', + 'a photo of the weird {}.', + 'the cartoon {}.', + 'art of the {}.', + 'a drawing of the {}.', + 'a photo of the large {}.', + 'a black and white photo of a {}.', + 'the plushie {}.', + 'a dark photo of a {}.', + 'itap of a {}.', + 'graffiti of the {}.', + 'a toy {}.', + 'itap of my {}.', + 'a photo of a cool {}.', + 'a photo of a small {}.', + 'a tattoo of the {}.', + ] + return prompt_templates diff --git a/projects/XDecoder/xdecoder/xdecoder.py b/projects/XDecoder/xdecoder/xdecoder.py new file mode 100644 index 00000000000..893a07dcfe4 --- /dev/null +++ b/projects/XDecoder/xdecoder/xdecoder.py @@ -0,0 +1,36 @@ +from torch import Tensor + +from mmdet.models.detectors.single_stage import SingleStageDetector +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig + + +@MODELS.register_module() +class XDecoder(SingleStageDetector): + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + head: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super(SingleStageDetector, self).__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + + head_ = head.deepcopy() + head_.update(test_cfg=test_cfg) + self.sem_seg_head = MODELS.build(head_) # TODO: sem_seg_head -> head + + def predict(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> SampleList: + visual_features = self.extract_feat(batch_inputs) + outputs = self.sem_seg_head.predict(visual_features, + batch_data_samples) + return outputs diff --git a/projects/gradio_demo/README.md b/projects/gradio_demo/README.md new file mode 100644 index 00000000000..e2e1a965863 --- /dev/null +++ b/projects/gradio_demo/README.md @@ -0,0 +1,49 @@ +# MMDetection Gradio Demo + +Here is a gradio demo for MMDetection supported inference tasks. + +Currently supported tasks: + +- Object Detection +- Instance Segmentation +- Panoptic Segmentation +- Grounding Object Detection +- Open Vocabulary Object Detection +- Open Vocabulary Instance Segmentation +- Open Vocabulary Semantic Segmentation +- Open Vocabulary Panoptic Segmentation +- Referring Expression Segmentation +- Image Caption +- Referring Expression Image Caption +- Text-To-Image Retrieval + +## Preview + + + +## Requirements + +To run the demo, you need to install MMDetection at first. And please install with the extra multi-modality +dependencies to enable multi-modality tasks. + +```shell +# At the MMDetection root folder +pip install -e ".[multimodal]" +``` + +And then install the latest gradio package. + +```shell +pip install "gradio>=3.31.0" +``` + +## Start + +Then, you can start the gradio server on the local machine by: + +```shell +cd mmdetection +python projects/gradio_demo/launch.py +``` + +The demo will start a local server `http://127.0.0.1:7860` and you can browse it by your browser. diff --git a/projects/gradio_demo/launch.py b/projects/gradio_demo/launch.py new file mode 100644 index 00000000000..5d9694237b5 --- /dev/null +++ b/projects/gradio_demo/launch.py @@ -0,0 +1,623 @@ +# Modified from MMPretrain +import gradio as gr +import torch +from mmengine.logging import MMLogger + +from mmdet.apis import DetInferencer +from projects.XDecoder.xdecoder.inference import ( + ImageCaptionInferencer, RefImageCaptionInferencer, + TextToImageRegionRetrievalInferencer) + +logger = MMLogger('mmdetection', logger_name='mmdet') +if torch.cuda.is_available(): + gpus = [ + torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count()) + ] + logger.info(f'Available GPUs: {len(gpus)}') +else: + gpus = None + logger.info('No available GPU.') + + +def get_free_device(): + if gpus is None: + return torch.device('cpu') + if hasattr(torch.cuda, 'mem_get_info'): + free = [torch.cuda.mem_get_info(gpu)[0] for gpu in gpus] + select = max(zip(free, range(len(free))))[1] + else: + import random + select = random.randint(0, len(gpus) - 1) + return gpus[select] + + +class ObjectDetectionTab: + model_list = [ + 'retinanet_r50-caffe_fpn_1x_coco', + 'faster-rcnn_r50-caffe_fpn_1x_coco', + 'dino-5scale_swin-l_8xb2-12e_coco.py', + ] + + def __init__(self) -> None: + self.create_ui() + + def create_ui(self): + with gr.Row(): + with gr.Column(): + select_model = gr.Dropdown( + label='Choose a model', + elem_id='od_models', + elem_classes='select_model', + choices=self.model_list, + value=self.model_list[0], + ) + with gr.Column(): + image_input = gr.Image( + label='Image', + source='upload', + elem_classes='input_image', + type='filepath', + interactive=True, + tool='editor', + ) + output = gr.Image( + label='Result', + source='upload', + interactive=False, + elem_classes='result', + ) + run_button = gr.Button( + 'Run', + elem_classes='run_button', + ) + run_button.click( + self.inference, + inputs=[select_model, image_input], + outputs=output, + ) + + with gr.Row(): + example_images = gr.Dataset( + components=[image_input], samples=[['demo/demo.jpg']]) + example_images.click( + fn=lambda x: gr.Image.update(value=x[0]), + inputs=example_images, + outputs=image_input) + + def inference(self, model, image): + det_inferencer = DetInferencer( + model, scope='mmdet', device=get_free_device()) + results_dict = det_inferencer(image, return_vis=True, no_save_vis=True) + vis = results_dict['visualization'][0] + return vis + + +class InstanceSegTab(ObjectDetectionTab): + model_list = ['mask-rcnn_r50-caffe_fpn_1x_coco', 'solov2_r50_fpn_1x_coco'] + + +class PanopticSegTab(ObjectDetectionTab): + model_list = [ + 'panoptic_fpn_r50_fpn_1x_coco', + 'mask2former_swin-s-p4-w7-224_8xb2-lsj-50e_coco-panoptic' + ] + + +class OpenVocabObjectDetectionTab: + model_list = ['glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365'] + + def __init__(self) -> None: + self.create_ui() + + def create_ui(self): + with gr.Row(): + with gr.Column(): + select_model = gr.Dropdown( + label='Choose a model', + elem_id='od_models', + elem_classes='select_model', + choices=self.model_list, + value=self.model_list[0], + ) + with gr.Column(): + image_input = gr.Image( + label='Image', + source='upload', + elem_classes='input_image', + type='filepath', + interactive=True, + tool='editor', + ) + text_input = gr.Textbox( + label='text prompt', + elem_classes='input_text', + interactive=True, + ) + output = gr.Image( + label='Result', + source='upload', + interactive=False, + elem_classes='result', + ) + run_button = gr.Button( + 'Run', + elem_classes='run_button', + ) + run_button.click( + self.inference, + inputs=[select_model, image_input, text_input], + outputs=output, + ) + + with gr.Row(): + example_images = gr.Dataset( + components=[image_input, text_input], + samples=[['demo/demo.jpg', 'bench . car .']]) + example_images.click( + fn=self.update, + inputs=example_images, + outputs=[image_input, text_input]) + + def update(self, example): + return gr.Image.update(value=example[0]), gr.Textbox.update( + value=example[1]) + + def inference(self, model, image, text): + det_inferencer = DetInferencer( + model, scope='mmdet', device=get_free_device()) + results_dict = det_inferencer( + image, + texts=text, + custom_entities=True, + pred_score_thr=0.5, + return_vis=True, + no_save_vis=True) + vis = results_dict['visualization'][0] + return vis + + +class GroundingDetectionTab(OpenVocabObjectDetectionTab): + model_list = ['glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365'] + + def create_ui(self): + with gr.Row(): + with gr.Column(): + select_model = gr.Dropdown( + label='Choose a model', + elem_id='od_models', + elem_classes='select_model', + choices=self.model_list, + value=self.model_list[0], + ) + with gr.Column(): + image_input = gr.Image( + label='Image', + source='upload', + elem_classes='input_image', + type='filepath', + interactive=True, + tool='editor', + ) + text_input = gr.Textbox( + label='text prompt', + elem_classes='input_text', + interactive=True, + ) + output = gr.Image( + label='Result', + source='upload', + interactive=False, + elem_classes='result', + ) + run_button = gr.Button( + 'Run', + elem_classes='run_button', + ) + run_button.click( + self.inference, + inputs=[select_model, image_input, text_input], + outputs=output, + ) + + with gr.Row(): + example_images = gr.Dataset( + components=[image_input, text_input], + samples=[['demo/demo.jpg', 'There are a lot of cars here.']]) + example_images.click( + fn=self.update, + inputs=example_images, + outputs=[image_input, text_input]) + + def inference(self, model, image, text): + det_inferencer = DetInferencer( + model, scope='mmdet', device=get_free_device()) + results_dict = det_inferencer( + image, + texts=text, + custom_entities=False, + pred_score_thr=0.5, + return_vis=True, + no_save_vis=True) + vis = results_dict['visualization'][0] + return vis + + +class OpenVocabInstanceSegTab(OpenVocabObjectDetectionTab): + model_list = ['xdecoder-tiny'] + + model_info = { + 'xdecoder-tiny': { + 'model': + 'projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py', # noqa + 'weights': + 'https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt' # noqa + } + } + + def inference(self, model, image, text): + det_inferencer = DetInferencer( + **self.model_info[model], scope='mmdet', device=get_free_device()) + results_dict = det_inferencer( + image, texts=text, return_vis=True, no_save_vis=True) + vis = results_dict['visualization'][0] + return vis + + +class OpenVocabPanopticSegTab(OpenVocabObjectDetectionTab): + model_list = ['xdecoder-tiny'] + + model_info = { + 'xdecoder-tiny': { + 'model': + 'projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_coco.py', # noqa + 'weights': + 'https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt' # noqa + } + } + + def create_ui(self): + with gr.Row(): + with gr.Column(): + select_model = gr.Dropdown( + label='Choose a model', + elem_id='od_models', + elem_classes='select_model', + choices=self.model_list, + value=self.model_list[0], + ) + with gr.Column(): + image_input = gr.Image( + label='Image', + source='upload', + elem_classes='input_image', + type='filepath', + interactive=True, + tool='editor', + ) + text_input = gr.Textbox( + label='thing text prompt', + elem_classes='input_text_thing', + interactive=True, + ) + stuff_text_input = gr.Textbox( + label='stuff text prompt', + elem_classes='input_text_stuff', + interactive=True, + ) + output = gr.Image( + label='Result', + source='upload', + interactive=False, + elem_classes='result', + ) + run_button = gr.Button( + 'Run', + elem_classes='run_button', + ) + run_button.click( + self.inference, + inputs=[ + select_model, image_input, text_input, stuff_text_input + ], + outputs=output, + ) + with gr.Row(): + example_images = gr.Dataset( + components=[image_input, text_input, stuff_text_input], + samples=[['demo/demo.jpg', 'bench.car', 'tree']]) + example_images.click( + fn=self.update, + inputs=example_images, + outputs=[image_input, text_input, stuff_text_input]) + + def update(self, example): + return gr.Image.update(value=example[0]), \ + gr.Textbox.update(label='thing text prompt', value=example[1]), \ + gr.Textbox.update(label='stuff text prompt', value=example[2]) + + def inference(self, model, image, text, stuff_text): + det_inferencer = DetInferencer( + **self.model_info[model], scope='mmdet', device=get_free_device()) + results_dict = det_inferencer( + image, + texts=text, + stuff_texts=stuff_text, + return_vis=True, + no_save_vis=True) + vis = results_dict['visualization'][0] + return vis + + +class OpenVocabSemSegTab(OpenVocabInstanceSegTab): + model_list = ['xdecoder-tiny'] + + model_info = { + 'xdecoder-tiny': { + 'model': + 'projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py', # noqa + 'weights': + 'https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt' # noqa + } + } + + +class ReferSegTab(OpenVocabInstanceSegTab): + model_list = ['xdecoder-tiny'] + + model_info = { + 'xdecoder-tiny': { + 'model': + 'projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcocog.py', # noqa + 'weights': + 'https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt' # noqa + } + } + + +class ImageCaptionTab: + model_list = ['xdecoder-tiny'] + + model_info = { + 'xdecoder-tiny': { + 'model': + 'projects/XDecoder/configs/xdecoder-tiny_zeroshot_caption_coco2014.py', # noqa + 'weights': + 'https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt' # noqa + } + } + + def __init__(self) -> None: + self.create_ui() + + def create_ui(self): + with gr.Row(): + with gr.Column(): + select_model = gr.Dropdown( + label='Choose a model', + elem_id='image_caption_models', + elem_classes='select_model', + choices=self.model_list, + value=self.model_list[0], + ) + with gr.Column(): + image_input = gr.Image( + label='Input', + source='upload', + elem_classes='input_image', + interactive=True, + tool='editor', + ) + caption_output = gr.Textbox( + label='Result', + lines=2, + elem_classes='caption_result', + interactive=False, + ) + run_button = gr.Button( + 'Run', + elem_classes='run_button', + ) + run_button.click( + self.inference, + inputs=[select_model, image_input], + outputs=caption_output, + ) + + with gr.Row(): + example_images = gr.Dataset( + components=[image_input], samples=[['demo/demo.jpg']]) + example_images.click( + fn=lambda x: gr.Image.update(value=x[0]), + inputs=example_images, + outputs=image_input) + + def inference(self, model, image): + ic_inferencer = ImageCaptionInferencer( + **self.model_info[model], scope='mmdet', device=get_free_device()) + results_dict = ic_inferencer( + image, return_vis=False, no_save_vis=True, return_datasample=True) + return results_dict['predictions'][0].pred_caption + + +class ReferImageCaptionTab(OpenVocabInstanceSegTab): + model_list = ['xdecoder-tiny'] + + model_info = { + 'xdecoder-tiny': { + 'model': + 'projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-caption.py', # noqa + 'weights': + 'https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt' # noqa + } + } + + def create_ui(self): + with gr.Row(): + with gr.Column(): + select_model = gr.Dropdown( + label='Choose a model', + elem_id='image_caption_models', + elem_classes='select_model', + choices=self.model_list, + value=self.model_list[0], + ) + with gr.Column(): + image_input = gr.Image( + label='Input', + source='upload', + elem_classes='input_image', + type='filepath', + interactive=True, + tool='editor', + ) + text_input = gr.Textbox( + label='text prompt', + elem_classes='input_text', + interactive=True, + ) + output = gr.Image( + label='Result', + source='upload', + interactive=False, + elem_classes='result', + ) + run_button = gr.Button( + 'Run', + elem_classes='run_button', + ) + run_button.click( + self.inference, + inputs=[select_model, image_input, text_input], + outputs=output, + ) + + with gr.Row(): + example_images = gr.Dataset( + components=[image_input, text_input], + samples=[['demo/demo.jpg', 'tree']]) + example_images.click( + fn=self.update, + inputs=example_images, + outputs=[image_input, text_input]) + + def update(self, example): + return gr.Image.update(value=example[0]), gr.Textbox.update( + value=example[1]) + + def inference(self, model, image, text): + ric_inferencer = RefImageCaptionInferencer( + **self.model_info[model], scope='mmdet', device=get_free_device()) + results_dict = ric_inferencer( + image, texts=text, return_vis=True, no_save_vis=True) + vis = results_dict['visualization'][0] + return vis + + +class TextToImageRetrievalTab: + model_list = ['xdecoder-tiny'] + + model_info = { + 'xdecoder-tiny': { + 'model': + 'projects/XDecoder/configs/xdecoder-tiny_zeroshot_text-image-retrieval.py', # noqa + 'weights': + 'https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt' # noqa + } + } + + def __init__(self) -> None: + self.create_ui() + + def create_ui(self): + with gr.Row(): + with gr.Column(): + select_model = gr.Dropdown( + label='Choose a model', + elem_id='t2i_retri_models', + elem_classes='select_model', + choices=self.model_list, + value=self.model_list[0], + ) + with gr.Column(): + prototype = gr.File( + file_count='multiple', file_types=['image']) + text_input = gr.Textbox( + label='Query', + elem_classes='input_text', + interactive=True, + ) + retri_output = gr.Image( + label='Result', + source='upload', + interactive=False, + elem_classes='result', + ) + + run_button = gr.Button( + 'Run', + elem_classes='run_button', + ) + run_button.click( + self.inference, + inputs=[select_model, prototype, text_input], + outputs=retri_output, + ) + + def inference(self, model, prototype, text): + inputs = [file.name for file in prototype] + retri_inferencer = TextToImageRegionRetrievalInferencer( + **self.model_info[model], scope='mmdet', device=get_free_device()) + results_dict = retri_inferencer( + inputs, texts=text, return_vis=True, no_save_vis=True) + vis = results_dict['visualization'][0] + return vis + + +if __name__ == '__main__': + title = 'MMDetection Inference Demo' + + DESCRIPTION = '''#
MMDetection Inference Demo
+
+ +
+ + #### This is an official demo for MMDet. \n + + - The first time running requires downloading the weights, + please wait a moment. \n + - OV is mean Open Vocabulary \n + - Refer Seg is mean Referring Expression Segmentation \n + - In Text-Image Region Retrieval, you need to provide n images and + a query text, and the model will predict the most matching image and + its corresponding grounding mask. + ''' + + with gr.Blocks(analytics_enabled=False, title=title) as demo: + gr.Markdown(DESCRIPTION) + with gr.Tabs(): + with gr.TabItem('Detection'): + ObjectDetectionTab() + with gr.TabItem('Instance'): + InstanceSegTab() + with gr.TabItem('Panoptic'): + PanopticSegTab() + with gr.TabItem('Grounding Detection'): + GroundingDetectionTab() + with gr.TabItem('OV Detection'): + OpenVocabObjectDetectionTab() + with gr.TabItem('OV Instance'): + OpenVocabInstanceSegTab() + with gr.TabItem('OV Panoptic'): + OpenVocabPanopticSegTab() + with gr.TabItem('OV SemSeg'): + OpenVocabSemSegTab() + with gr.TabItem('Refer Seg'): + ReferSegTab() + with gr.TabItem('Image Caption'): + ImageCaptionTab() + with gr.TabItem('Refer Caption'): + ReferImageCaptionTab() + with gr.TabItem('Text-Image Region Retrieval'): + TextToImageRetrievalTab() + demo.queue().launch(share=True) diff --git a/requirements/multimodal.txt b/requirements/multimodal.txt index 579f70fcfb4..5abdb4fdbff 100644 --- a/requirements/multimodal.txt +++ b/requirements/multimodal.txt @@ -1,2 +1,3 @@ nltk +pycocoevalcap transformers diff --git a/setup.cfg b/setup.cfg index 70dd621c8f5..a3878cf1071 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,4 +18,4 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true [codespell] skip = *.ipynb quiet-level = 3 -ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,DOTA,dota +ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,DOTA,dota,conveyer diff --git a/tests/test_apis/test_inference.py b/tests/test_apis/test_inference.py index c68e4459896..e42f86c64e8 100644 --- a/tests/test_apis/test_inference.py +++ b/tests/test_apis/test_inference.py @@ -62,8 +62,8 @@ def test_inference_detector(config, devices): # test init_detector with config_file: str and cfg_options rng = np.random.RandomState(0) - img1 = rng.randint(0, 255, (100, 100, 3), dtype=np.uint8) - img2 = rng.randint(0, 255, (100, 100, 3), dtype=np.uint8) + img1 = rng.randint(0, 255, (32, 32, 3), dtype=np.uint8) + img2 = rng.randint(0, 255, (32, 32, 3), dtype=np.uint8) for device in devices: if device == 'cuda' and not torch.cuda.is_available(): diff --git a/tests/test_datasets/test_transforms/test_loading.py b/tests/test_datasets/test_transforms/test_loading.py index 1993fae43da..840ad51c4ed 100644 --- a/tests/test_datasets/test_transforms/test_loading.py +++ b/tests/test_datasets/test_transforms/test_loading.py @@ -110,6 +110,28 @@ def test_load_mask_poly2mask(self): self.assertEqual(len(results['gt_masks']), 3) self.assertIsInstance(results['gt_masks'], BitmapMasks) + def test_load_semseg(self): + transform = LoadAnnotations( + with_bbox=False, with_label=False, with_seg=True, with_mask=False) + results = transform(copy.deepcopy(self.results)) + self.assertIn('gt_seg_map', results) + self.assertIn('ignore_index', results) + self.assertEqual(results['gt_seg_map'].shape, (288, 512)) + + # test reduce_zero_label and ignore_index + transform = LoadAnnotations( + with_bbox=False, + with_label=False, + with_seg=True, + with_mask=False, + reduce_zero_label=True, + ignore_index=10) + results = transform(copy.deepcopy(self.results)) + self.assertIn('gt_seg_map', results) + self.assertIn('ignore_index', results) + self.assertEqual(results['ignore_index'], 10) + self.assertEqual(results['gt_seg_map'].shape, (288, 512)) + def test_repr(self): transform = LoadAnnotations( with_bbox=True, diff --git a/tests/test_datasets/test_transforms/test_transforms.py b/tests/test_datasets/test_transforms/test_transforms.py index e064e299518..e36f518aa8b 100644 --- a/tests/test_datasets/test_transforms/test_transforms.py +++ b/tests/test_datasets/test_transforms/test_transforms.py @@ -15,7 +15,8 @@ PhotoMetricDistortion, RandomAffine, RandomCenterCropPad, RandomCrop, RandomErasing, RandomFlip, RandomShift, - Resize, SegRescale, YOLOXHSVRandomAug) + Resize, ResizeShortestEdge, SegRescale, + YOLOXHSVRandomAug) # yapf:enable from mmdet.evaluation import bbox_overlaps from mmdet.registry import TRANSFORMS @@ -42,38 +43,38 @@ def setUp(self): """ rng = np.random.RandomState(0) self.data_info1 = dict( - img=np.random.random((1333, 800, 3)), - gt_seg_map=np.random.random((1333, 800, 3)), + img=np.random.random((400, 500, 3)), + gt_seg_map=np.random.random((400, 500, 3)), gt_bboxes=np.array([[0, 0, 112, 112]], dtype=np.float32), - gt_masks=BitmapMasks( - rng.rand(1, 1333, 800), height=1333, width=800)) + gt_masks=BitmapMasks(rng.rand(1, 400, 500), height=400, width=500)) self.data_info2 = dict( - img=np.random.random((300, 400, 3)), - gt_bboxes=np.array([[200, 150, 600, 450]], dtype=np.float32), + img=np.random.random((200, 100, 3)), + gt_bboxes=np.array([[20, 15, 60, 45]], dtype=np.float32), dtype=np.float32) - self.data_info3 = dict(img=np.random.random((300, 400, 3))) + self.data_info3 = dict(img=np.random.random((200, 100, 3))) def test_resize(self): # test keep_ratio is True - transform = Resize(scale=(2000, 2000), keep_ratio=True) + transform = Resize(scale=(100, 100), keep_ratio=True) results = transform(copy.deepcopy(self.data_info1)) - self.assertEqual(results['img_shape'], (2000, 1200)) - self.assertEqual(results['scale_factor'], (1200 / 800, 2000 / 1333)) + self.assertEqual(results['img_shape'], (80, 100)) + self.assertEqual(results['scale_factor'], (80 / 400, 100 / 500)) # test resize_bboxes/seg/masks transform = Resize(scale_factor=(1.5, 2)) results = transform(copy.deepcopy(self.data_info1)) - self.assertTrue((results['gt_bboxes'] == np.array([[0, 0, 168, - 224]])).all()) - self.assertEqual(results['gt_masks'].height, 2666) - self.assertEqual(results['gt_masks'].width, 1200) - self.assertEqual(results['gt_seg_map'].shape[:2], (2666, 1200)) + self.assertTrue( + (results['gt_bboxes'] == np.array([[0., 0., 168., 224.]])).all()) + self.assertEqual(results['gt_masks'].height, 800) + self.assertEqual(results['gt_masks'].width, 750) + self.assertEqual(results['gt_seg_map'].shape[:2], (800, 750)) # test clip_object_border = False transform = Resize(scale=(200, 150), clip_object_border=False) results = transform(self.data_info2) - self.assertTrue((results['gt_bboxes'] == np.array([100, 75, 300, - 225])).all()) + self.assertTrue( + (results['gt_bboxes'] == np.array([40., 11.25, 120., + 33.75])).all()) # test only with image transform = Resize(scale=(200, 150), clip_object_border=False) @@ -93,10 +94,10 @@ def test_resize_use_box_type(self): data_info2 = copy.deepcopy(self.data_info2) data_info2['gt_bboxes'] = HorizontalBoxes(data_info2['gt_bboxes']) # test keep_ratio is True - transform = Resize(scale=(2000, 2000), keep_ratio=True) + transform = Resize(scale=(100, 150), keep_ratio=True) results = transform(copy.deepcopy(data_info1)) - self.assertEqual(results['img_shape'], (2000, 1200)) - self.assertEqual(results['scale_factor'], (1200 / 800, 2000 / 1333)) + self.assertEqual(results['img_shape'], (100, 125)) + self.assertEqual(results['scale_factor'], (100 / 400, 125 / 500)) # test resize_bboxes/seg/masks transform = Resize(scale_factor=(1.5, 2)) @@ -104,16 +105,15 @@ def test_resize_use_box_type(self): self.assertTrue( (results['gt_bboxes'].numpy() == np.array([[0, 0, 168, 224]])).all()) - self.assertEqual(results['gt_masks'].height, 2666) - self.assertEqual(results['gt_masks'].width, 1200) - self.assertEqual(results['gt_seg_map'].shape[:2], (2666, 1200)) + self.assertEqual(results['gt_masks'].height, 800) + self.assertEqual(results['gt_masks'].width, 750) + self.assertEqual(results['gt_seg_map'].shape[:2], (800, 750)) # test clip_object_border = False transform = Resize(scale=(200, 150), clip_object_border=False) results = transform(data_info2) - self.assertTrue( - (results['gt_bboxes'].numpy() == np.array([100, 75, 300, - 225])).all()) + self.assertTrue((results['gt_bboxes'].numpy() == np.array( + [40., 11.25, 120., 33.75])).all()) # test geometric transformation with homography matrix transform = Resize(scale_factor=(1.5, 2)) @@ -124,9 +124,9 @@ def test_resize_use_box_type(self): ).all()) def test_repr(self): - transform = Resize(scale=(2000, 2000), keep_ratio=True) + transform = Resize(scale=(100, 100), keep_ratio=True) self.assertEqual( - repr(transform), ('Resize(scale=(2000, 2000), ' + repr(transform), ('Resize(scale=(100, 100), ' 'scale_factor=None, keep_ratio=True, ' 'clip_object_border=True), backend=cv2), ' 'interpolation=bilinear)')) @@ -142,23 +142,17 @@ def setUp(self): """ rng = np.random.RandomState(0) self.data_info1 = dict( - img=np.random.random((1333, 800, 3)), - gt_seg_map=np.random.random((1333, 800, 3)), + img=np.random.random((200, 300, 3)), + gt_seg_map=np.random.random((200, 300, 3)), gt_bboxes=np.array([[0, 0, 112, 112]], dtype=np.float32), - gt_masks=BitmapMasks( - rng.rand(1, 1333, 800), height=1333, width=800)) - self.data_info2 = dict( - img=np.random.random((300, 400, 3)), - gt_bboxes=np.array([[200, 150, 600, 450]], dtype=np.float32), - dtype=np.float32) - self.data_info3 = dict(img=np.random.random((300, 400, 3))) + gt_masks=BitmapMasks(rng.rand(1, 200, 300), height=200, width=300)) def test_resize(self): # test keep_ratio is True - transform = FixScaleResize(scale=(2001, 2002), keep_ratio=True) + transform = FixScaleResize(scale=(101, 201), keep_ratio=True) results = transform(copy.deepcopy(self.data_info1)) - self.assertEqual(results['img_shape'], (2002, 1201)) - self.assertEqual(results['scale_factor'], (1201 / 800, 2002 / 1333)) + self.assertEqual(results['img_shape'], (101, 151)) + self.assertEqual(results['scale_factor'], (151 / 300, 101 / 200)) class TestFixShapeResize(unittest.TestCase): @@ -171,35 +165,32 @@ def setUp(self): """ rng = np.random.RandomState(0) self.data_info1 = dict( - img=np.random.random((1333, 800, 3)), - gt_seg_map=np.random.random((1333, 800, 3)), - gt_bboxes=np.array([[0, 0, 112, 1333]], dtype=np.float32), - gt_masks=BitmapMasks( - rng.rand(1, 1333, 800), height=1333, width=800)) + img=np.random.random((200, 300, 3)), + gt_seg_map=np.random.random((200, 300, 3)), + gt_bboxes=np.array([[0, 0, 112, 133]], dtype=np.float32), + gt_masks=BitmapMasks(rng.rand(1, 200, 300), height=200, width=300)) self.data_info2 = dict( img=np.random.random((300, 400, 3)), gt_bboxes=np.array([[200, 150, 600, 450]], dtype=np.float32), dtype=np.float32) self.data_info3 = dict(img=np.random.random((300, 400, 3))) self.data_info4 = dict( - img=np.random.random((600, 800, 3)), + img=np.random.random((400, 450, 3)), gt_bboxes=np.array([[200, 150, 300, 400]], dtype=np.float32), dtype=np.float32) def test_resize(self): # test keep_ratio is True - transform = FixShapeResize(width=2000, height=800, keep_ratio=True) + transform = FixShapeResize(width=100, height=50, keep_ratio=True) results = transform(copy.deepcopy(self.data_info1)) - self.assertEqual(results['img_shape'], (800, 2000)) - self.assertEqual(results['scale_factor'], (800 / 1333, 800 / 1333)) + self.assertEqual(results['img_shape'], (50, 100)) + self.assertEqual(results['scale_factor'], (50 / 200, 50 / 200)) # test resize_bboxes/seg/masks - transform = FixShapeResize(width=2000, height=800, keep_ratio=False) + transform = FixShapeResize(width=120, height=100, keep_ratio=False) results = transform(copy.deepcopy(self.data_info1)) - self.assertTrue((results['gt_bboxes'] == np.array([[0, 0, 280, - 800]])).all()) - self.assertEqual(results['gt_masks'].height, 800) - self.assertEqual(results['gt_masks'].width, 2000) - self.assertEqual(results['gt_seg_map'].shape[:2], (800, 2000)) + self.assertEqual(results['gt_masks'].height, 100) + self.assertEqual(results['gt_masks'].width, 120) + self.assertEqual(results['gt_seg_map'].shape[:2], (100, 120)) # test clip_object_border = False transform = FixShapeResize( @@ -229,20 +220,20 @@ def test_resize_with_boxlist(self): data_info4 = copy.deepcopy(self.data_info4) data_info4['gt_bboxes'] = HorizontalBoxes(data_info4['gt_bboxes']) # test keep_ratio is True - transform = FixShapeResize(width=2000, height=800, keep_ratio=True) + transform = FixShapeResize(width=100, height=200, keep_ratio=True) results = transform(copy.deepcopy(data_info1)) - self.assertEqual(results['img_shape'], (800, 2000)) - self.assertEqual(results['scale_factor'], (800 / 1333, 800 / 1333)) + self.assertEqual(results['img_shape'], (200, 100)) + self.assertEqual(results['scale_factor'], (100 / 300, 100 / 300)) # test resize_bboxes/seg/masks - transform = FixShapeResize(width=2000, height=800, keep_ratio=False) + transform = FixShapeResize(width=150, height=200, keep_ratio=False) results = transform(copy.deepcopy(data_info1)) self.assertTrue( - (results['gt_bboxes'].numpy() == np.array([[0, 0, 280, - 800]])).all()) - self.assertEqual(results['gt_masks'].height, 800) - self.assertEqual(results['gt_masks'].width, 2000) - self.assertEqual(results['gt_seg_map'].shape[:2], (800, 2000)) + (results['gt_bboxes'].numpy() == np.array([[0, 0, 56, + 133]])).all()) + self.assertEqual(results['gt_masks'].height, 200) + self.assertEqual(results['gt_masks'].width, 150) + self.assertEqual(results['gt_seg_map'].shape[:2], (200, 150)) # test clip_object_border = False transform = FixShapeResize( @@ -267,9 +258,9 @@ def test_resize_with_boxlist(self): ).all()) def test_repr(self): - transform = FixShapeResize(width=2000, height=2000, keep_ratio=True) + transform = FixShapeResize(width=100, height=50, keep_ratio=True) self.assertEqual( - repr(transform), ('FixShapeResize(width=2000, height=2000, ' + repr(transform), ('FixShapeResize(width=100, height=50, ' 'keep_ratio=True, ' 'clip_object_border=True), backend=cv2), ' 'interpolation=bilinear)')) @@ -381,41 +372,41 @@ def setUp(self): """ rng = np.random.RandomState(0) self.results = { - 'img': np.random.random((1333, 800, 3)), + 'img': np.random.random((100, 80, 3)), 'gt_masks': - BitmapMasks(rng.rand(4, 1333, 800), height=1333, width=800) + BitmapMasks(rng.rand(4, 100, 80), height=100, width=80) } def test_transform(self): # test pad img/gt_masks with size - transform = Pad(size=(1200, 2000)) + transform = Pad(size=(120, 110)) results = transform(copy.deepcopy(self.results)) - self.assertEqual(results['img'].shape[:2], (2000, 1200)) - self.assertEqual(results['gt_masks'].masks.shape[1:], (2000, 1200)) + self.assertEqual(results['img'].shape[:2], (110, 120)) + self.assertEqual(results['gt_masks'].masks.shape[1:], (110, 120)) # test pad img/gt_masks with size_divisor transform = Pad(size_divisor=11) results = transform(copy.deepcopy(self.results)) - self.assertEqual(results['img'].shape[:2], (1342, 803)) - self.assertEqual(results['gt_masks'].masks.shape[1:], (1342, 803)) + self.assertEqual(results['img'].shape[:2], (110, 88)) + self.assertEqual(results['gt_masks'].masks.shape[1:], (110, 88)) # test pad img/gt_masks with pad_to_square transform = Pad(pad_to_square=True) results = transform(copy.deepcopy(self.results)) - self.assertEqual(results['img'].shape[:2], (1333, 1333)) - self.assertEqual(results['gt_masks'].masks.shape[1:], (1333, 1333)) + self.assertEqual(results['img'].shape[:2], (100, 100)) + self.assertEqual(results['gt_masks'].masks.shape[1:], (100, 100)) # test pad img/gt_masks with pad_to_square and size_divisor transform = Pad(pad_to_square=True, size_divisor=11) results = transform(copy.deepcopy(self.results)) - self.assertEqual(results['img'].shape[:2], (1342, 1342)) - self.assertEqual(results['gt_masks'].masks.shape[1:], (1342, 1342)) + self.assertEqual(results['img'].shape[:2], (110, 110)) + self.assertEqual(results['gt_masks'].masks.shape[1:], (110, 110)) # test pad img/gt_masks with pad_to_square and size_divisor transform = Pad(pad_to_square=True, size_divisor=11) results = transform(copy.deepcopy(self.results)) - self.assertEqual(results['img'].shape[:2], (1342, 1342)) - self.assertEqual(results['gt_masks'].masks.shape[1:], (1342, 1342)) + self.assertEqual(results['img'].shape[:2], (110, 110)) + self.assertEqual(results['gt_masks'].masks.shape[1:], (110, 110)) def test_repr(self): transform = Pad( @@ -1744,3 +1735,35 @@ def test_repr(self): 'img_border_value=128, ' 'mask_border_value=0, ' 'seg_ignore_label=255)')) + + +class TestResizeShortestEdge(unittest.TestCase): + + def setUp(self): + """Setup the model and optimizer which are used in every test method. + + TestCase calls functions in this order: setUp() -> testMethod() + -> tearDown() -> cleanUp() + """ + rng = np.random.RandomState(0) + self.data_info = dict( + img=np.random.random((220, 100, 3)), + gt_seg_map=np.random.random((220, 100, 3)), + gt_bboxes=np.array([[0, 0, 112, 12]], dtype=np.float32), + gt_masks=BitmapMasks(rng.rand(1, 220, 100), height=220, width=100)) + + def test_resize(self): + transform = ResizeShortestEdge(scale=200) + results = transform(copy.deepcopy(self.data_info)) + self.assertEqual(results['img_shape'], (440, 200)) + self.assertEqual(results['scale_factor'], (200 / 100, 440 / 220)) + + transform = ResizeShortestEdge(scale=200, max_size=301) + results = transform(copy.deepcopy(self.data_info)) + self.assertEqual(results['img_shape'], (301, 137)) + self.assertEqual(results['scale_factor'], (137 / 100, 301 / 220)) + + transform = ResizeShortestEdge(scale=201, keep_ratio=True) + results = transform(copy.deepcopy(self.data_info)) + self.assertEqual(results['img_shape'], (442, 201)) + self.assertEqual(results['scale_factor'], (201 / 100, 442 / 220)) diff --git a/tests/test_models/test_detectors/test_glip.py b/tests/test_models/test_detectors/test_glip.py index fca05ac2648..8be3d8d719f 100644 --- a/tests/test_models/test_detectors/test_glip.py +++ b/tests/test_models/test_detectors/test_glip.py @@ -50,7 +50,7 @@ def test_glip_forward_predict_mode(self, cfg_file, devices): # test custom_entities is True packed_inputs = demo_mm_inputs( 2, [[3, 128, 128], [3, 125, 130]], - captions=['a', 'b'], + texts=['a', 'b'], custom_entities=True) data = detector.data_preprocessor(packed_inputs, False) # Test forward test @@ -63,7 +63,7 @@ def test_glip_forward_predict_mode(self, cfg_file, devices): # test custom_entities is False packed_inputs = demo_mm_inputs( 2, [[3, 128, 128], [3, 125, 130]], - captions=['a', 'b'], + texts=['a', 'b'], custom_entities=False) data = detector.data_preprocessor(packed_inputs, False) # Test forward test diff --git a/tests/test_models/test_detectors/test_single_stage.py b/tests/test_models/test_detectors/test_single_stage.py index 1ed3c7c0f7c..22dbd1a98cb 100644 --- a/tests/test_models/test_detectors/test_single_stage.py +++ b/tests/test_models/test_detectors/test_single_stage.py @@ -39,11 +39,8 @@ def test_init(self, cfg_file): ('retinanet/retinanet_r18_fpn_1x_coco.py', ('cpu', 'cuda')), ('centernet/centernet_r18_8xb16-crop512-140e_coco.py', ('cpu', 'cuda')), - ('fsaf/fsaf_r50_fpn_1x_coco.py', ('cpu', 'cuda')), ('yolox/yolox_tiny_8xb8-300e_coco.py', ('cpu', 'cuda')), ('yolo/yolov3_mobilenetv2_8xb24-320-300e_coco.py', ('cpu', 'cuda')), - ('reppoints/reppoints-minmax_r50_fpn-gn_head-gn_1x_coco.py', ('cpu', - 'cuda')), ]) def test_single_stage_forward_loss_mode(self, cfg_file, devices): message_hub = MessageHub.get_instance( @@ -74,11 +71,8 @@ def test_single_stage_forward_loss_mode(self, cfg_file, devices): ('retinanet/retinanet_r18_fpn_1x_coco.py', ('cpu', 'cuda')), ('centernet/centernet_r18_8xb16-crop512-140e_coco.py', ('cpu', 'cuda')), - ('fsaf/fsaf_r50_fpn_1x_coco.py', ('cpu', 'cuda')), ('yolox/yolox_tiny_8xb8-300e_coco.py', ('cpu', 'cuda')), ('yolo/yolov3_mobilenetv2_8xb24-320-300e_coco.py', ('cpu', 'cuda')), - ('reppoints/reppoints-minmax_r50_fpn-gn_head-gn_1x_coco.py', ('cpu', - 'cuda')), ]) def test_single_stage_forward_predict_mode(self, cfg_file, devices): model = get_detector_cfg(cfg_file) @@ -108,11 +102,8 @@ def test_single_stage_forward_predict_mode(self, cfg_file, devices): ('retinanet/retinanet_r18_fpn_1x_coco.py', ('cpu', 'cuda')), ('centernet/centernet_r18_8xb16-crop512-140e_coco.py', ('cpu', 'cuda')), - ('fsaf/fsaf_r50_fpn_1x_coco.py', ('cpu', 'cuda')), ('yolox/yolox_tiny_8xb8-300e_coco.py', ('cpu', 'cuda')), ('yolo/yolov3_mobilenetv2_8xb24-320-300e_coco.py', ('cpu', 'cuda')), - ('reppoints/reppoints-minmax_r50_fpn-gn_head-gn_1x_coco.py', ('cpu', - 'cuda')), ]) def test_single_stage_forward_tensor_mode(self, cfg_file, devices): model = get_detector_cfg(cfg_file) diff --git a/tests/test_models/test_detectors/test_single_stage_instance_seg.py b/tests/test_models/test_detectors/test_single_stage_instance_seg.py index 3b761c9b0bd..51530341241 100644 --- a/tests/test_models/test_detectors/test_single_stage_instance_seg.py +++ b/tests/test_models/test_detectors/test_single_stage_instance_seg.py @@ -17,10 +17,7 @@ def setUp(self): @parameterized.expand([ 'solo/solo_r50_fpn_1x_coco.py', - 'solo/decoupled-solo_r50_fpn_1x_coco.py', - 'solo/decoupled-solo-light_r50_fpn_3x_coco.py', 'solov2/solov2_r50_fpn_1x_coco.py', - 'solov2/solov2-light_r18_fpn_ms-3x_coco.py', 'yolact/yolact_r50_1xb8-55e_coco.py', ]) def test_init(self, cfg_file): @@ -37,9 +34,6 @@ def test_init(self, cfg_file): @parameterized.expand([ ('solo/solo_r50_fpn_1x_coco.py', ('cpu', 'cuda')), - ('solo/decoupled-solo_r50_fpn_1x_coco.py', ('cpu', 'cuda')), - ('solo/decoupled-solo-light_r50_fpn_3x_coco.py', ('cpu', 'cuda')), - ('solov2/solov2_r50_fpn_1x_coco.py', ('cpu', 'cuda')), ('solov2/solov2-light_r18_fpn_ms-3x_coco.py', ('cpu', 'cuda')), ('yolact/yolact_r50_1xb8-55e_coco.py', ('cpu', 'cuda')), ]) @@ -69,11 +63,7 @@ def test_single_stage_forward_loss_mode(self, cfg_file, devices): self.assertIsInstance(losses, dict) @parameterized.expand([ - ('solo/solo_r50_fpn_1x_coco.py', ('cpu', 'cuda')), - ('solo/decoupled-solo_r50_fpn_1x_coco.py', ('cpu', 'cuda')), ('solo/decoupled-solo-light_r50_fpn_3x_coco.py', ('cpu', 'cuda')), - ('solov2/solov2_r50_fpn_1x_coco.py', ('cpu', 'cuda')), - ('solov2/solov2-light_r18_fpn_ms-3x_coco.py', ('cpu', 'cuda')), ('yolact/yolact_r50_1xb8-55e_coco.py', ('cpu', 'cuda')), ]) def test_single_stage_forward_predict_mode(self, cfg_file, devices): @@ -106,10 +96,7 @@ def test_single_stage_forward_predict_mode(self, cfg_file, devices): @parameterized.expand([ ('solo/solo_r50_fpn_1x_coco.py', ('cpu', 'cuda')), - ('solo/decoupled-solo_r50_fpn_1x_coco.py', ('cpu', 'cuda')), - ('solo/decoupled-solo-light_r50_fpn_3x_coco.py', ('cpu', 'cuda')), ('solov2/solov2_r50_fpn_1x_coco.py', ('cpu', 'cuda')), - ('solov2/solov2-light_r18_fpn_ms-3x_coco.py', ('cpu', 'cuda')), ('yolact/yolact_r50_1xb8-55e_coco.py', ('cpu', 'cuda')), ]) def test_single_stage_forward_tensor_mode(self, cfg_file, devices): diff --git a/tools/dataset_converters/ade20k2coco.py b/tools/dataset_converters/ade20k2coco.py index 3ae92325c28..e0b5ce86da8 100644 --- a/tools/dataset_converters/ade20k2coco.py +++ b/tools/dataset_converters/ade20k2coco.py @@ -1,23 +1,161 @@ import argparse +import json import os from pathlib import Path import numpy as np +import pycocotools.mask as mask_util from mmengine.utils import ProgressBar, mkdir_or_exist from panopticapi.utils import IdGenerator, save_json from PIL import Image from mmdet.datasets.ade20k import ADE20KPanopticDataset +ORIGINAL_CATEGORIES = [ + 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road, route', + 'bed', 'window', 'grass', 'cabinet', 'sidewalk, pavement', 'person', + 'earth, ground', 'door', 'table', 'mountain, mount', 'plant', 'curtain', + 'chair', 'car', 'water', 'painting, picture', 'sofa', 'shelf', 'house', + 'sea', 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk', + 'rock, stone', 'wardrobe, closet, press', 'lamp', 'tub', 'rail', 'cushion', + 'base, pedestal, stand', 'box', 'column, pillar', 'signboard, sign', + 'chest of drawers, chest, bureau, dresser', 'counter', 'sand', 'sink', + 'skyscraper', 'fireplace', 'refrigerator, icebox', + 'grandstand, covered stand', 'path', 'stairs', 'runway', + 'case, display case, showcase, vitrine', + 'pool table, billiard table, snooker table', 'pillow', + 'screen door, screen', 'stairway, staircase', 'river', 'bridge, span', + 'bookcase', 'blind, screen', 'coffee table', + 'toilet, can, commode, crapper, pot, potty, stool, throne', 'flower', + 'book', 'hill', 'bench', 'countertop', 'stove', 'palm, palm tree', + 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', + 'arcade machine', 'hovel, hut, hutch, shack, shanty', 'bus', 'towel', + 'light', 'truck', 'tower', 'chandelier', 'awning, sunshade, sunblind', + 'street lamp', 'booth', 'tv', 'airplane', 'dirt track', 'clothes', 'pole', + 'land, ground, soil', + 'bannister, banister, balustrade, balusters, handrail', + 'escalator, moving staircase, moving stairway', + 'ottoman, pouf, pouffe, puff, hassock', 'bottle', + 'buffet, counter, sideboard', + 'poster, posting, placard, notice, bill, card', 'stage', 'van', 'ship', + 'fountain', + 'conveyer belt, conveyor belt, conveyer, conveyor, transporter', 'canopy', + 'washer, automatic washer, washing machine', 'plaything, toy', 'pool', + 'stool', 'barrel, cask', 'basket, handbasket', 'falls', 'tent', 'bag', + 'minibike, motorbike', 'cradle', 'oven', 'ball', 'food, solid food', + 'step, stair', 'tank, storage tank', 'trade name', 'microwave', 'pot', + 'animal', 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket, cover', + 'sculpture', 'hood, exhaust hood', 'sconce', 'vase', 'traffic light', + 'tray', 'trash can', 'fan', 'pier', 'crt screen', 'plate', 'monitor', + 'bulletin board', 'shower', 'radiator', 'glass, drinking glass', 'clock', + 'flag' +] + def parse_args(): parser = argparse.ArgumentParser( description='Convert ADE20K annotations to COCO format') parser.add_argument('src', help='ade20k data path') + parser.add_argument('--task', help='task name', default='panoptic') args = parser.parse_args() return args +def prepare_instance_annotations(dataset_dir: str): + dataset_dir = Path(dataset_dir) + for name, dirname in [('train', 'training'), ('val', 'validation')]: + image_dir = dataset_dir / 'images' / dirname + instance_dir = dataset_dir / 'annotations_instance' / dirname + + ann_id = 0 + + # json + out_file = dataset_dir / f'ade20k_instance_{name}.json' + + # json config + instance_config_file = dataset_dir / 'imgCatIds.json' + with open(instance_config_file, 'r') as f: + category_dict = json.load(f)['categories'] + + # catid mapping + mapping_file = dataset_dir / 'categoryMapping.txt' + with open(mapping_file, 'r') as f: + map_id = {} + for i, line in enumerate(f.readlines()): + if i == 0: + continue + ins_id, sem_id, _ = line.strip().split() + map_id[int(ins_id)] = int(sem_id) - 1 + + for cat in category_dict: + cat['id'] = map_id[cat['id']] + + filenames = sorted(list(image_dir.iterdir())) + + ann_dict = {} + images = [] + annotations = [] + + progressbar = ProgressBar(len(filenames)) + for filename in filenames: + image = {} + image_id = filename.stem + + image['id'] = image_id + image['file_name'] = filename.name + + original_format = np.array(Image.open(filename)) + image['height'] = original_format.shape[0] + image['width'] = original_format.shape[1] + + images.append(image) + + instance_file = instance_dir / f'{image_id}.png' + ins_seg = np.array(Image.open(instance_file)) + assert ins_seg.dtype == np.uint8 + + instance_cat_ids = ins_seg[..., 0] + instance_ins_ids = ins_seg[..., 1] + + for thing_id in np.unique(instance_ins_ids): + if thing_id == 0: + continue + mask = instance_ins_ids == thing_id + instance_cat_id = np.unique(instance_cat_ids[mask]) + assert len(instance_cat_id) == 1 + + anno = {} + anno['id'] = ann_id + ann_id += 1 + anno['image_id'] = image['id'] + anno['iscrowd'] = int(0) + anno['category_id'] = int(map_id[instance_cat_id[0]]) + + inds = np.nonzero(mask) + ymin, ymax = inds[0].min(), inds[0].max() + xmin, xmax = inds[1].min(), inds[1].max() + anno['bbox'] = [ + int(xmin), + int(ymin), + int(xmax - xmin + 1), + int(ymax - ymin + 1) + ] + + rle = mask_util.encode( + np.array(mask[:, :, np.newaxis], order='F', + dtype='uint8'))[0] + rle['counts'] = rle['counts'].decode('utf-8') + anno['segmentation'] = rle + anno['area'] = int(mask_util.area(rle)) + annotations.append(anno) + progressbar.update() + + ann_dict['images'] = images + ann_dict['categories'] = category_dict + ann_dict['annotations'] = annotations + save_json(ann_dict, out_file) + + def prepare_panoptic_annotations(dataset_dir: str): dataset_dir = Path(dataset_dir) @@ -34,32 +172,34 @@ def prepare_panoptic_annotations(dataset_dir: str): mkdir_or_exist(out_folder) # catid mapping - mapping_file = dataset_dir / 'categoryMapping.txt' - with open(mapping_file, 'r') as f: - map_id = {} - for i, line in enumerate(f.readlines()): - if i == 0: - continue - ins_id, sem_id, _ = line.strip().split() - map_id[int(ins_id) - 1] = int(sem_id) - 1 - - ADE20K_150_CATEGORIES = [] - ADE20K_SEM_SEG_CATEGORIES = ADE20KPanopticDataset.METAINFO['classes'] - PALETTE = ADE20KPanopticDataset.METAINFO['palette'] - for cat_id, cat_name in enumerate(ADE20K_SEM_SEG_CATEGORIES): - ADE20K_150_CATEGORIES.append({ - 'id': - cat_id, - 'name': - cat_name, - 'isthing': - int(cat_id in map_id.values()), - 'color': - PALETTE[cat_id] + neworder_categories = [] + all_classes = ORIGINAL_CATEGORIES + thing_classes = ADE20KPanopticDataset.METAINFO['thing_classes'] + stuff_classes = ADE20KPanopticDataset.METAINFO['stuff_classes'] + palette = ADE20KPanopticDataset.METAINFO['palette'] + + old_2_new_mapping = {} + new_2_old_mapping = {} + for i, t in enumerate(thing_classes): + j = list(all_classes).index(t) + old_2_new_mapping[j] = i + new_2_old_mapping[i] = j + + for i, t in enumerate(stuff_classes): + j = list(all_classes).index(t) + old_2_new_mapping[j] = i + len(thing_classes) + new_2_old_mapping[i + len(thing_classes)] = j + + for old, new in old_2_new_mapping.items(): + neworder_categories.append({ + 'id': new, + 'name': all_classes[old], + 'isthing': int(new < len(thing_classes)), + 'color': palette[new] }) - categories_dict = {cat['id']: cat for cat in ADE20K_150_CATEGORIES} + categories_dict = {cat['id']: cat for cat in neworder_categories} - panoptic_json_categories = ADE20K_150_CATEGORIES[:] + panoptic_json_categories = neworder_categories[:] panoptic_json_images = [] panoptic_json_annotations = [] @@ -103,14 +243,15 @@ def prepare_panoptic_annotations(dataset_dir: str): for semantic_cat_id in np.unique(semantic_cat_ids): if semantic_cat_id == 255: continue - if categories_dict[semantic_cat_id]['isthing'] == 1: + if categories_dict[old_2_new_mapping[int( + semantic_cat_id)]]['isthing'] == 1: continue mask = semantic_cat_ids == semantic_cat_id # should not have any overlap assert pan_seg[mask].sum() == 0 segment_id, color = id_generator.get_id_and_color( - semantic_cat_id) + old_2_new_mapping[int(semantic_cat_id)]) pan_seg[mask] = color area = np.sum(mask) @@ -126,11 +267,16 @@ def prepare_panoptic_annotations(dataset_dir: str): bbox = [int(x), int(y), int(width), int(height)] segm_info.append({ - 'id': int(segment_id), - 'category_id': int(semantic_cat_id), - 'area': int(area), - 'bbox': bbox, - 'iscrowd': 0 + 'id': + int(segment_id), + 'category_id': + old_2_new_mapping[int(semantic_cat_id)], + 'area': + int(area), + 'bbox': + bbox, + 'iscrowd': + 0 }) # process things @@ -138,13 +284,12 @@ def prepare_panoptic_annotations(dataset_dir: str): if thing_id == 0: continue mask = instance_ins_ids == thing_id + instance_cat_id = np.unique(instance_cat_ids[mask]) assert len(instance_cat_id) == 1 - id_ = instance_cat_id[0] - semantic_cat_id = map_id[id_] segment_id, color = id_generator.get_id_and_color( - semantic_cat_id) + instance_cat_id[0]) pan_seg[mask] = color area = np.sum(mask) @@ -161,7 +306,7 @@ def prepare_panoptic_annotations(dataset_dir: str): segm_info.append({ 'id': int(segment_id), - 'category_id': int(semantic_cat_id), + 'category_id': int(instance_cat_id[0]), 'area': int(area), 'bbox': bbox, 'iscrowd': 0 @@ -190,17 +335,32 @@ def prepare_panoptic_annotations(dataset_dir: str): def main(): args = parse_args() + assert args.task in ['panoptic', 'instance'] src = args.src - annotation_train_path = f'{src}/ade20k_panoptic_train' - annotation_val_path = f'{src}/ade20k_panoptic_val' - print('Preparing ADE20K panoptic annotations ...') - print( - f'Creating panoptic annotations to {annotation_train_path} and {annotation_val_path} ...' # noqa - ) - if os.path.exists(annotation_train_path) or os.path.exists( - annotation_val_path): - raise RuntimeError('Panoptic annotations already exist.') - prepare_panoptic_annotations(src) + if args.task == 'panoptic': + annotation_train_path = f'{src}/ade20k_panoptic_train' + annotation_val_path = f'{src}/ade20k_panoptic_val' + print('Preparing ADE20K panoptic annotations ...') + print( + f'Creating panoptic annotations to {annotation_train_path} and {annotation_val_path} ...' # noqa + ) + if os.path.exists(annotation_train_path) or os.path.exists( + annotation_val_path): + raise RuntimeError('Panoptic annotations already exist.') + prepare_panoptic_annotations(src) + print('Done.') + else: + annotation_train_path = f'{src}/ade20k_instance_train' + annotation_val_path = f'{src}/ade20k_instance_val' + print('Preparing ADE20K instance annotations ...') + print( + f'Creating instance annotations to {annotation_train_path} and {annotation_val_path} ...' # noqa + ) + if os.path.exists(annotation_train_path) or os.path.exists( + annotation_val_path): + raise RuntimeError('Instance annotations already exist.') + prepare_instance_annotations(src) + print('Done.') if __name__ == '__main__': diff --git a/tools/dataset_converters/coco_stuff164k.py b/tools/dataset_converters/coco_stuff164k.py new file mode 100644 index 00000000000..fe1ff9f6b43 --- /dev/null +++ b/tools/dataset_converters/coco_stuff164k.py @@ -0,0 +1,254 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from functools import partial +from glob import glob + +import numpy as np +from mmengine.utils import (mkdir_or_exist, track_parallel_progress, + track_progress) +from PIL import Image + +COCO_LEN = 123287 + +clsID_to_trID = { + 0: 0, + 1: 1, + 2: 2, + 3: 3, + 4: 4, + 5: 5, + 6: 6, + 7: 7, + 8: 8, + 9: 9, + 10: 10, + 12: 11, + 13: 12, + 14: 13, + 15: 14, + 16: 15, + 17: 16, + 18: 17, + 19: 18, + 20: 19, + 21: 20, + 22: 21, + 23: 22, + 24: 23, + 26: 24, + 27: 25, + 30: 26, + 31: 27, + 32: 28, + 33: 29, + 34: 30, + 35: 31, + 36: 32, + 37: 33, + 38: 34, + 39: 35, + 40: 36, + 41: 37, + 42: 38, + 43: 39, + 45: 40, + 46: 41, + 47: 42, + 48: 43, + 49: 44, + 50: 45, + 51: 46, + 52: 47, + 53: 48, + 54: 49, + 55: 50, + 56: 51, + 57: 52, + 58: 53, + 59: 54, + 60: 55, + 61: 56, + 62: 57, + 63: 58, + 64: 59, + 66: 60, + 69: 61, + 71: 62, + 72: 63, + 73: 64, + 74: 65, + 75: 66, + 76: 67, + 77: 68, + 78: 69, + 79: 70, + 80: 71, + 81: 72, + 83: 73, + 84: 74, + 85: 75, + 86: 76, + 87: 77, + 88: 78, + 89: 79, + 91: 80, + 92: 81, + 93: 82, + 94: 83, + 95: 84, + 96: 85, + 97: 86, + 98: 87, + 99: 88, + 100: 89, + 101: 90, + 102: 91, + 103: 92, + 104: 93, + 105: 94, + 106: 95, + 107: 96, + 108: 97, + 109: 98, + 110: 99, + 111: 100, + 112: 101, + 113: 102, + 114: 103, + 115: 104, + 116: 105, + 117: 106, + 118: 107, + 119: 108, + 120: 109, + 121: 110, + 122: 111, + 123: 112, + 124: 113, + 125: 114, + 126: 115, + 127: 116, + 128: 117, + 129: 118, + 130: 119, + 131: 120, + 132: 121, + 133: 122, + 134: 123, + 135: 124, + 136: 125, + 137: 126, + 138: 127, + 139: 128, + 140: 129, + 141: 130, + 142: 131, + 143: 132, + 144: 133, + 145: 134, + 146: 135, + 147: 136, + 148: 137, + 149: 138, + 150: 139, + 151: 140, + 152: 141, + 153: 142, + 154: 143, + 155: 144, + 156: 145, + 157: 146, + 158: 147, + 159: 148, + 160: 149, + 161: 150, + 162: 151, + 163: 152, + 164: 153, + 165: 154, + 166: 155, + 167: 156, + 168: 157, + 169: 158, + 170: 159, + 171: 160, + 172: 161, + 173: 162, + 174: 163, + 175: 164, + 176: 165, + 177: 166, + 178: 167, + 179: 168, + 180: 169, + 181: 170, + 255: 255 +} + + +def convert_to_trainID(maskpath, out_mask_dir, is_train): + mask = np.array(Image.open(maskpath)) + mask_copy = mask.copy() + for clsID, trID in clsID_to_trID.items(): + mask_copy[mask == clsID] = trID + seg_filename = osp.join(out_mask_dir, 'train2017', + osp.basename(maskpath)) if is_train else osp.join( + out_mask_dir, 'val2017', + osp.basename(maskpath)) + Image.fromarray(mask_copy).save(seg_filename, 'PNG') + + +def parse_args(): + parser = argparse.ArgumentParser( + description=\ + 'Convert COCO Stuff 164k annotations to mmdet format') # noqa + parser.add_argument('coco_path', help='coco stuff path') + parser.add_argument( + '--out-dir-name', + '-o', + default='stuffthingmaps_semseg', + help='output path') + parser.add_argument( + '--nproc', default=16, type=int, help='number of process') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + coco_path = args.coco_path + out_dir = osp.join(coco_path, args.out_dir_name) + nproc = args.nproc + + mkdir_or_exist(osp.join(out_dir, 'train2017')) + mkdir_or_exist(osp.join(out_dir, 'val2017')) + + train_list = glob(osp.join(coco_path, 'stuffthingmaps/train2017', '*.png')) + val_list = glob(osp.join(coco_path, 'stuffthingmaps/val2017', '*.png')) + assert (len(train_list) + + len(val_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format( + len(train_list), len(val_list)) + + if args.nproc > 1: + track_parallel_progress( + partial(convert_to_trainID, out_mask_dir=out_dir, is_train=True), + train_list, + nproc=nproc) + track_parallel_progress( + partial(convert_to_trainID, out_mask_dir=out_dir, is_train=False), + val_list, + nproc=nproc) + else: + track_progress( + partial(convert_to_trainID, out_mask_dir=out_dir, is_train=True), + train_list) + track_progress( + partial(convert_to_trainID, out_mask_dir=out_dir, is_train=False), + val_list) + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/tools/dataset_converters/prepare_coco_semantic_annos_from_panoptic_annos.py b/tools/dataset_converters/prepare_coco_semantic_annos_from_panoptic_annos.py new file mode 100644 index 00000000000..2b9ee592cb3 --- /dev/null +++ b/tools/dataset_converters/prepare_coco_semantic_annos_from_panoptic_annos.py @@ -0,0 +1,899 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/facebookresearch/Mask2Former/blob/main/datasets/prepare_coco_semantic_annos_from_panoptic_annos.py # noqa + +import argparse +import functools +import json +import multiprocessing as mp +import os +import time + +import numpy as np +from panopticapi.utils import rgb2id +from PIL import Image + +COCO_CATEGORIES = [ + { + 'color': [220, 20, 60], + 'isthing': 1, + 'id': 1, + 'name': 'person' + }, + { + 'color': [119, 11, 32], + 'isthing': 1, + 'id': 2, + 'name': 'bicycle' + }, + { + 'color': [0, 0, 142], + 'isthing': 1, + 'id': 3, + 'name': 'car' + }, + { + 'color': [0, 0, 230], + 'isthing': 1, + 'id': 4, + 'name': 'motorcycle' + }, + { + 'color': [106, 0, 228], + 'isthing': 1, + 'id': 5, + 'name': 'airplane' + }, + { + 'color': [0, 60, 100], + 'isthing': 1, + 'id': 6, + 'name': 'bus' + }, + { + 'color': [0, 80, 100], + 'isthing': 1, + 'id': 7, + 'name': 'train' + }, + { + 'color': [0, 0, 70], + 'isthing': 1, + 'id': 8, + 'name': 'truck' + }, + { + 'color': [0, 0, 192], + 'isthing': 1, + 'id': 9, + 'name': 'boat' + }, + { + 'color': [250, 170, 30], + 'isthing': 1, + 'id': 10, + 'name': 'traffic light' + }, + { + 'color': [100, 170, 30], + 'isthing': 1, + 'id': 11, + 'name': 'fire hydrant' + }, + { + 'color': [220, 220, 0], + 'isthing': 1, + 'id': 13, + 'name': 'stop sign' + }, + { + 'color': [175, 116, 175], + 'isthing': 1, + 'id': 14, + 'name': 'parking meter' + }, + { + 'color': [250, 0, 30], + 'isthing': 1, + 'id': 15, + 'name': 'bench' + }, + { + 'color': [165, 42, 42], + 'isthing': 1, + 'id': 16, + 'name': 'bird' + }, + { + 'color': [255, 77, 255], + 'isthing': 1, + 'id': 17, + 'name': 'cat' + }, + { + 'color': [0, 226, 252], + 'isthing': 1, + 'id': 18, + 'name': 'dog' + }, + { + 'color': [182, 182, 255], + 'isthing': 1, + 'id': 19, + 'name': 'horse' + }, + { + 'color': [0, 82, 0], + 'isthing': 1, + 'id': 20, + 'name': 'sheep' + }, + { + 'color': [120, 166, 157], + 'isthing': 1, + 'id': 21, + 'name': 'cow' + }, + { + 'color': [110, 76, 0], + 'isthing': 1, + 'id': 22, + 'name': 'elephant' + }, + { + 'color': [174, 57, 255], + 'isthing': 1, + 'id': 23, + 'name': 'bear' + }, + { + 'color': [199, 100, 0], + 'isthing': 1, + 'id': 24, + 'name': 'zebra' + }, + { + 'color': [72, 0, 118], + 'isthing': 1, + 'id': 25, + 'name': 'giraffe' + }, + { + 'color': [255, 179, 240], + 'isthing': 1, + 'id': 27, + 'name': 'backpack' + }, + { + 'color': [0, 125, 92], + 'isthing': 1, + 'id': 28, + 'name': 'umbrella' + }, + { + 'color': [209, 0, 151], + 'isthing': 1, + 'id': 31, + 'name': 'handbag' + }, + { + 'color': [188, 208, 182], + 'isthing': 1, + 'id': 32, + 'name': 'tie' + }, + { + 'color': [0, 220, 176], + 'isthing': 1, + 'id': 33, + 'name': 'suitcase' + }, + { + 'color': [255, 99, 164], + 'isthing': 1, + 'id': 34, + 'name': 'frisbee' + }, + { + 'color': [92, 0, 73], + 'isthing': 1, + 'id': 35, + 'name': 'skis' + }, + { + 'color': [133, 129, 255], + 'isthing': 1, + 'id': 36, + 'name': 'snowboard' + }, + { + 'color': [78, 180, 255], + 'isthing': 1, + 'id': 37, + 'name': 'sports ball' + }, + { + 'color': [0, 228, 0], + 'isthing': 1, + 'id': 38, + 'name': 'kite' + }, + { + 'color': [174, 255, 243], + 'isthing': 1, + 'id': 39, + 'name': 'baseball bat' + }, + { + 'color': [45, 89, 255], + 'isthing': 1, + 'id': 40, + 'name': 'baseball glove' + }, + { + 'color': [134, 134, 103], + 'isthing': 1, + 'id': 41, + 'name': 'skateboard' + }, + { + 'color': [145, 148, 174], + 'isthing': 1, + 'id': 42, + 'name': 'surfboard' + }, + { + 'color': [255, 208, 186], + 'isthing': 1, + 'id': 43, + 'name': 'tennis racket' + }, + { + 'color': [197, 226, 255], + 'isthing': 1, + 'id': 44, + 'name': 'bottle' + }, + { + 'color': [171, 134, 1], + 'isthing': 1, + 'id': 46, + 'name': 'wine glass' + }, + { + 'color': [109, 63, 54], + 'isthing': 1, + 'id': 47, + 'name': 'cup' + }, + { + 'color': [207, 138, 255], + 'isthing': 1, + 'id': 48, + 'name': 'fork' + }, + { + 'color': [151, 0, 95], + 'isthing': 1, + 'id': 49, + 'name': 'knife' + }, + { + 'color': [9, 80, 61], + 'isthing': 1, + 'id': 50, + 'name': 'spoon' + }, + { + 'color': [84, 105, 51], + 'isthing': 1, + 'id': 51, + 'name': 'bowl' + }, + { + 'color': [74, 65, 105], + 'isthing': 1, + 'id': 52, + 'name': 'banana' + }, + { + 'color': [166, 196, 102], + 'isthing': 1, + 'id': 53, + 'name': 'apple' + }, + { + 'color': [208, 195, 210], + 'isthing': 1, + 'id': 54, + 'name': 'sandwich' + }, + { + 'color': [255, 109, 65], + 'isthing': 1, + 'id': 55, + 'name': 'orange' + }, + { + 'color': [0, 143, 149], + 'isthing': 1, + 'id': 56, + 'name': 'broccoli' + }, + { + 'color': [179, 0, 194], + 'isthing': 1, + 'id': 57, + 'name': 'carrot' + }, + { + 'color': [209, 99, 106], + 'isthing': 1, + 'id': 58, + 'name': 'hot dog' + }, + { + 'color': [5, 121, 0], + 'isthing': 1, + 'id': 59, + 'name': 'pizza' + }, + { + 'color': [227, 255, 205], + 'isthing': 1, + 'id': 60, + 'name': 'donut' + }, + { + 'color': [147, 186, 208], + 'isthing': 1, + 'id': 61, + 'name': 'cake' + }, + { + 'color': [153, 69, 1], + 'isthing': 1, + 'id': 62, + 'name': 'chair' + }, + { + 'color': [3, 95, 161], + 'isthing': 1, + 'id': 63, + 'name': 'couch' + }, + { + 'color': [163, 255, 0], + 'isthing': 1, + 'id': 64, + 'name': 'potted plant' + }, + { + 'color': [119, 0, 170], + 'isthing': 1, + 'id': 65, + 'name': 'bed' + }, + { + 'color': [0, 182, 199], + 'isthing': 1, + 'id': 67, + 'name': 'dining table' + }, + { + 'color': [0, 165, 120], + 'isthing': 1, + 'id': 70, + 'name': 'toilet' + }, + { + 'color': [183, 130, 88], + 'isthing': 1, + 'id': 72, + 'name': 'tv' + }, + { + 'color': [95, 32, 0], + 'isthing': 1, + 'id': 73, + 'name': 'laptop' + }, + { + 'color': [130, 114, 135], + 'isthing': 1, + 'id': 74, + 'name': 'mouse' + }, + { + 'color': [110, 129, 133], + 'isthing': 1, + 'id': 75, + 'name': 'remote' + }, + { + 'color': [166, 74, 118], + 'isthing': 1, + 'id': 76, + 'name': 'keyboard' + }, + { + 'color': [219, 142, 185], + 'isthing': 1, + 'id': 77, + 'name': 'cell phone' + }, + { + 'color': [79, 210, 114], + 'isthing': 1, + 'id': 78, + 'name': 'microwave' + }, + { + 'color': [178, 90, 62], + 'isthing': 1, + 'id': 79, + 'name': 'oven' + }, + { + 'color': [65, 70, 15], + 'isthing': 1, + 'id': 80, + 'name': 'toaster' + }, + { + 'color': [127, 167, 115], + 'isthing': 1, + 'id': 81, + 'name': 'sink' + }, + { + 'color': [59, 105, 106], + 'isthing': 1, + 'id': 82, + 'name': 'refrigerator' + }, + { + 'color': [142, 108, 45], + 'isthing': 1, + 'id': 84, + 'name': 'book' + }, + { + 'color': [196, 172, 0], + 'isthing': 1, + 'id': 85, + 'name': 'clock' + }, + { + 'color': [95, 54, 80], + 'isthing': 1, + 'id': 86, + 'name': 'vase' + }, + { + 'color': [128, 76, 255], + 'isthing': 1, + 'id': 87, + 'name': 'scissors' + }, + { + 'color': [201, 57, 1], + 'isthing': 1, + 'id': 88, + 'name': 'teddy bear' + }, + { + 'color': [246, 0, 122], + 'isthing': 1, + 'id': 89, + 'name': 'hair drier' + }, + { + 'color': [191, 162, 208], + 'isthing': 1, + 'id': 90, + 'name': 'toothbrush' + }, + { + 'color': [255, 255, 128], + 'isthing': 0, + 'id': 92, + 'name': 'banner' + }, + { + 'color': [147, 211, 203], + 'isthing': 0, + 'id': 93, + 'name': 'blanket' + }, + { + 'color': [150, 100, 100], + 'isthing': 0, + 'id': 95, + 'name': 'bridge' + }, + { + 'color': [168, 171, 172], + 'isthing': 0, + 'id': 100, + 'name': 'cardboard' + }, + { + 'color': [146, 112, 198], + 'isthing': 0, + 'id': 107, + 'name': 'counter' + }, + { + 'color': [210, 170, 100], + 'isthing': 0, + 'id': 109, + 'name': 'curtain' + }, + { + 'color': [92, 136, 89], + 'isthing': 0, + 'id': 112, + 'name': 'door-stuff' + }, + { + 'color': [218, 88, 184], + 'isthing': 0, + 'id': 118, + 'name': 'floor-wood' + }, + { + 'color': [241, 129, 0], + 'isthing': 0, + 'id': 119, + 'name': 'flower' + }, + { + 'color': [217, 17, 255], + 'isthing': 0, + 'id': 122, + 'name': 'fruit' + }, + { + 'color': [124, 74, 181], + 'isthing': 0, + 'id': 125, + 'name': 'gravel' + }, + { + 'color': [70, 70, 70], + 'isthing': 0, + 'id': 128, + 'name': 'house' + }, + { + 'color': [255, 228, 255], + 'isthing': 0, + 'id': 130, + 'name': 'light' + }, + { + 'color': [154, 208, 0], + 'isthing': 0, + 'id': 133, + 'name': 'mirror-stuff' + }, + { + 'color': [193, 0, 92], + 'isthing': 0, + 'id': 138, + 'name': 'net' + }, + { + 'color': [76, 91, 113], + 'isthing': 0, + 'id': 141, + 'name': 'pillow' + }, + { + 'color': [255, 180, 195], + 'isthing': 0, + 'id': 144, + 'name': 'platform' + }, + { + 'color': [106, 154, 176], + 'isthing': 0, + 'id': 145, + 'name': 'playingfield' + }, + { + 'color': [230, 150, 140], + 'isthing': 0, + 'id': 147, + 'name': 'railroad' + }, + { + 'color': [60, 143, 255], + 'isthing': 0, + 'id': 148, + 'name': 'river' + }, + { + 'color': [128, 64, 128], + 'isthing': 0, + 'id': 149, + 'name': 'road' + }, + { + 'color': [92, 82, 55], + 'isthing': 0, + 'id': 151, + 'name': 'roof' + }, + { + 'color': [254, 212, 124], + 'isthing': 0, + 'id': 154, + 'name': 'sand' + }, + { + 'color': [73, 77, 174], + 'isthing': 0, + 'id': 155, + 'name': 'sea' + }, + { + 'color': [255, 160, 98], + 'isthing': 0, + 'id': 156, + 'name': 'shelf' + }, + { + 'color': [255, 255, 255], + 'isthing': 0, + 'id': 159, + 'name': 'snow' + }, + { + 'color': [104, 84, 109], + 'isthing': 0, + 'id': 161, + 'name': 'stairs' + }, + { + 'color': [169, 164, 131], + 'isthing': 0, + 'id': 166, + 'name': 'tent' + }, + { + 'color': [225, 199, 255], + 'isthing': 0, + 'id': 168, + 'name': 'towel' + }, + { + 'color': [137, 54, 74], + 'isthing': 0, + 'id': 171, + 'name': 'wall-brick' + }, + { + 'color': [135, 158, 223], + 'isthing': 0, + 'id': 175, + 'name': 'wall-stone' + }, + { + 'color': [7, 246, 231], + 'isthing': 0, + 'id': 176, + 'name': 'wall-tile' + }, + { + 'color': [107, 255, 200], + 'isthing': 0, + 'id': 177, + 'name': 'wall-wood' + }, + { + 'color': [58, 41, 149], + 'isthing': 0, + 'id': 178, + 'name': 'water-other' + }, + { + 'color': [183, 121, 142], + 'isthing': 0, + 'id': 180, + 'name': 'window-blind' + }, + { + 'color': [255, 73, 97], + 'isthing': 0, + 'id': 181, + 'name': 'window-other' + }, + { + 'color': [107, 142, 35], + 'isthing': 0, + 'id': 184, + 'name': 'tree-merged' + }, + { + 'color': [190, 153, 153], + 'isthing': 0, + 'id': 185, + 'name': 'fence-merged' + }, + { + 'color': [146, 139, 141], + 'isthing': 0, + 'id': 186, + 'name': 'ceiling-merged' + }, + { + 'color': [70, 130, 180], + 'isthing': 0, + 'id': 187, + 'name': 'sky-other-merged' + }, + { + 'color': [134, 199, 156], + 'isthing': 0, + 'id': 188, + 'name': 'cabinet-merged' + }, + { + 'color': [209, 226, 140], + 'isthing': 0, + 'id': 189, + 'name': 'table-merged' + }, + { + 'color': [96, 36, 108], + 'isthing': 0, + 'id': 190, + 'name': 'floor-other-merged' + }, + { + 'color': [96, 96, 96], + 'isthing': 0, + 'id': 191, + 'name': 'pavement-merged' + }, + { + 'color': [64, 170, 64], + 'isthing': 0, + 'id': 192, + 'name': 'mountain-merged' + }, + { + 'color': [152, 251, 152], + 'isthing': 0, + 'id': 193, + 'name': 'grass-merged' + }, + { + 'color': [208, 229, 228], + 'isthing': 0, + 'id': 194, + 'name': 'dirt-merged' + }, + { + 'color': [206, 186, 171], + 'isthing': 0, + 'id': 195, + 'name': 'paper-merged' + }, + { + 'color': [152, 161, 64], + 'isthing': 0, + 'id': 196, + 'name': 'food-other-merged' + }, + { + 'color': [116, 112, 0], + 'isthing': 0, + 'id': 197, + 'name': 'building-other-merged' + }, + { + 'color': [0, 114, 143], + 'isthing': 0, + 'id': 198, + 'name': 'rock-merged' + }, + { + 'color': [102, 102, 156], + 'isthing': 0, + 'id': 199, + 'name': 'wall-other-merged' + }, + { + 'color': [250, 141, 255], + 'isthing': 0, + 'id': 200, + 'name': 'rug-merged' + }, +] + + +def _process_panoptic_to_semantic(input_panoptic, output_semantic, segments, + id_map): + panoptic = np.asarray(Image.open(input_panoptic), dtype=np.uint32) + panoptic = rgb2id(panoptic) + output = np.zeros_like(panoptic, dtype=np.uint8) + 255 + for seg in segments: + cat_id = seg['category_id'] + new_cat_id = id_map[cat_id] + output[panoptic == seg['id']] = new_cat_id + Image.fromarray(output).save(output_semantic) + + +def separate_coco_semantic_from_panoptic(panoptic_json, panoptic_root, + sem_seg_root, categories): + """Create semantic segmentation annotations from panoptic segmentation + annotations, to be used by PanopticFPN. + + It maps all thing categories to class 0, and maps all + unlabeled pixels to class 255. + It maps all stuff categories to contiguous ids starting from 1. + Args: + panoptic_json (str): path to the panoptic json file, in COCO's format. + panoptic_root (str): a directory with panoptic annotation files, in + COCO's format. + sem_seg_root (str): a directory to output semantic annotation files + categories (list[dict]): category metadata. Each dict needs to have: + "id": corresponds to the "category_id" in the json annotations + "isthing": 0 or 1 + """ + os.makedirs(sem_seg_root, exist_ok=True) + + id_map = {} # map from category id to id in the output semantic annotation + assert len(categories) <= 254 + for i, k in enumerate(categories): + id_map[k['id']] = i + # what is id = 0? + # id_map[0] = 255 + print(id_map) + + with open(panoptic_json) as f: + obj = json.load(f) + + pool = mp.Pool(processes=max(mp.cpu_count() // 2, 4)) + + def iter_annotations(): + for anno in obj['annotations']: + file_name = anno['file_name'] + segments = anno['segments_info'] + input = os.path.join(panoptic_root, file_name) + output = os.path.join(sem_seg_root, file_name) + yield input, output, segments + + print('Start writing to {} ...'.format(sem_seg_root)) + start = time.time() + pool.starmap( + functools.partial(_process_panoptic_to_semantic, id_map=id_map), + iter_annotations(), + chunksize=100, + ) + print('Finished. time: {:.2f}s'.format(time.time() - start)) + + +def parse_args(): + parser = argparse.ArgumentParser( + description=\ + 'Convert COCO Stuff 164k annotations to mmdet format') # noqa + parser.add_argument('coco_path', help='coco stuff path') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + dataset_dir = args.coco_path + for s in ['val2017', 'train2017']: + separate_coco_semantic_from_panoptic( + os.path.join(dataset_dir, + 'annotations/panoptic_{}.json'.format(s)), + os.path.join(dataset_dir, 'annotations/panoptic_{}'.format(s)), + os.path.join(dataset_dir, + 'annotations/panoptic_semseg_{}'.format(s)), + COCO_CATEGORIES, + ) diff --git a/tools/misc/download_dataset.py b/tools/misc/download_dataset.py index 3d57fb728df..5d801d208c4 100644 --- a/tools/misc/download_dataset.py +++ b/tools/misc/download_dataset.py @@ -188,7 +188,7 @@ def main(): # training images and semantic segmentation annotations 'http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip', # noqa # instance segmentation annotations - 'http://sceneparsing.csail.mit.edu/data/ChallengeData2017/annotations_instance.tar' # noqa + 'http://sceneparsing.csail.mit.edu/data/ChallengeData2017/annotations_instance.tar', # noqa # img categories ids 'https://raw.githubusercontent.com/CSAILVision/placeschallenge/master/instancesegmentation/imgCatIds.json', # noqa # category mapping @@ -206,7 +206,8 @@ def main(): ]) url = data2url.get(args.dataset_name, None) if url is None: - print('Only support COCO, VOC, LVIS, balloon, and Objects365v2 now!') + print('Only support ADE20K, COCO, RefCOCO, VOC, LVIS, ' + 'balloon, and Objects365v2 now!') return if args.dataset_name == 'objects365v2': download_objects365v2(