Skip to content

Commit

Permalink
ViTDet README and COCO configs
Browse files Browse the repository at this point in the history
Reviewed By: rbgirshick, wat3rBro, HannaMao

Differential Revision: D36117941

fbshipit-source-id: 9608b390b958f2471fbdedfb5f97ae0a3c23e006
  • Loading branch information
lyttonhao authored and facebook-github-bot committed Jun 9, 2022
1 parent b01e0e9 commit 333efcb
Show file tree
Hide file tree
Showing 28 changed files with 776 additions and 26 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Explain Like I’m 5: Detectron2 | Using Machine Learning with Detec

## What's New
* Includes new capabilities such as panoptic segmentation, Densepose, Cascade R-CNN, rotated bounding boxes, PointRend,
DeepLab, etc.
DeepLab, ViTDet, etc.
* Used as a library to support building [research projects](projects/) on top of it.
* Models can be exported to TorchScript format or Caffe2 format for deployment.
* It [trains much faster](https://detectron2.readthedocs.io/notes/benchmarks.html).
Expand Down
5 changes: 3 additions & 2 deletions configs/Misc/mmdet_mask_rcnn_R_50_FPN_1x.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ..common.coco_schedule import lr_multiplier_1x as lr_multiplier
from ..common.optim import SGD as optimizer
from ..common.train import train
from ..common.data.constants import constants

from detectron2.modeling.mmdet_wrapper import MMDetDetector
from detectron2.config import LazyCall as L
Expand Down Expand Up @@ -143,8 +144,8 @@
),
),
),
pixel_mean=[123.675, 116.280, 103.530],
pixel_std=[58.395, 57.120, 57.375],
pixel_mean=constants.imagenet_rgb256_mean,
pixel_std=constants.imagenet_rgb256_std,
)

dataloader.train.mapper.image_format = "RGB" # torchvision pretrained model
Expand Down
9 changes: 9 additions & 0 deletions configs/common/data/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
constants = dict(
imagenet_rgb256_mean=[123.675, 116.28, 103.53],
imagenet_rgb256_std=[58.395, 57.12, 57.375],
imagenet_bgr256_mean=[103.530, 116.280, 123.675],
# When using pre-trained models in Detectron1 or any MSRA models,
# std has been absorbed into its conv1 weights, so the std needs to be set 1.
# Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)
imagenet_bgr256_std=[1.0, 1.0, 1.0],
)
6 changes: 4 additions & 2 deletions configs/common/models/mask_rcnn_c4.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
Res5ROIHeads,
)

from ..data.constants import constants

model = L(GeneralizedRCNN)(
backbone=L(ResNet)(
stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
Expand Down Expand Up @@ -82,7 +84,7 @@
conv_dims=[256],
),
),
pixel_mean=[103.530, 116.280, 123.675],
pixel_std=[1.0, 1.0, 1.0],
pixel_mean=constants.imagenet_bgr256_mean,
pixel_std=constants.imagenet_bgr256_std,
input_format="BGR",
)
6 changes: 4 additions & 2 deletions configs/common/models/mask_rcnn_fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
FastRCNNConvFCHead,
)

from ..data.constants import constants

model = L(GeneralizedRCNN)(
backbone=L(FPN)(
bottom_up=L(ResNet)(
Expand Down Expand Up @@ -87,7 +89,7 @@
conv_dims=[256, 256, 256, 256, 256],
),
),
pixel_mean=[103.530, 116.280, 123.675],
pixel_std=[1.0, 1.0, 1.0],
pixel_mean=constants.imagenet_bgr256_mean,
pixel_std=constants.imagenet_bgr256_std,
input_format="BGR",
)
59 changes: 59 additions & 0 deletions configs/common/models/mask_rcnn_vitdet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from functools import partial
import torch.nn as nn
from detectron2.config import LazyCall as L
from detectron2.modeling import ViT, SimpleFeaturePyramid
from detectron2.modeling.backbone.fpn import LastLevelMaxPool

from .mask_rcnn_fpn import model
from ..data.constants import constants

model.pixel_mean = constants.imagenet_rgb256_mean
model.pixel_std = constants.imagenet_rgb256_std
model.input_format = "RGB"

# Base
embed_dim, depth, num_heads, dp = 768, 12, 12, 0.1
# Creates Simple Feature Pyramid from ViT backbone
model.backbone = L(SimpleFeaturePyramid)(
net=L(ViT)( # Single-scale ViT backbone
img_size=1024,
patch_size=16,
embed_dim=embed_dim,
depth=depth,
num_heads=num_heads,
drop_path_rate=dp,
window_size=14,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
window_block_indexes=[
# 2, 5, 8 11 for global attention
0,
1,
3,
4,
6,
7,
9,
10,
],
residual_block_indexes=[],
use_rel_pos=True,
out_feature="last_feat",
),
in_feature="${.net.out_feature}",
out_channels=256,
scale_factors=(4.0, 2.0, 1.0, 0.5),
top_block=L(LastLevelMaxPool)(),
norm="LN",
square_pad=1024,
)

model.roi_heads.box_head.conv_norm = model.roi_heads.mask_head.conv_norm = "LN"

# 2conv in RPN:
model.proposal_generator.head.conv_dims = [-1, -1]

# 4conv1fc box head
model.roi_heads.box_head.conv_dims = [256, 256, 256, 256]
model.roi_heads.box_head.fc_dims = [1024]
6 changes: 4 additions & 2 deletions configs/common/models/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from detectron2.modeling.matcher import Matcher
from detectron2.modeling.meta_arch.retinanet import RetinaNetHead

from ..data.constants import constants

model = L(RetinaNet)(
backbone=L(FPN)(
bottom_up=L(ResNet)(
Expand Down Expand Up @@ -47,7 +49,7 @@
head_in_features=["p3", "p4", "p5", "p6", "p7"],
focal_loss_alpha=0.25,
focal_loss_gamma=2.0,
pixel_mean=[103.530, 116.280, 123.675],
pixel_std=[1.0, 1.0, 1.0],
pixel_mean=constants.imagenet_bgr256_mean,
pixel_std=constants.imagenet_bgr256_std,
input_format="BGR",
)
13 changes: 13 additions & 0 deletions configs/common/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,16 @@
momentum=0.9,
weight_decay=1e-4,
)


AdamW = L(torch.optim.AdamW)(
params=L(get_default_optimizer_params)(
# params.model is meant to be set to the model object, before instantiating
# the optimizer.
base_lr="${..lr}",
weight_decay_norm=0.0,
),
lr=1e-4,
betas=(0.9, 0.999),
weight_decay=0.1,
)
1 change: 1 addition & 0 deletions detectron2/modeling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
make_stage,
ViT,
SimpleFeaturePyramid,
get_vit_lr_decay_rate,
MViT,
SwinTransformer,
)
Expand Down
2 changes: 1 addition & 1 deletion detectron2/modeling/backbone/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
make_stage,
BottleneckBlock,
)
from .vit import ViT, SimpleFeaturePyramid
from .vit import ViT, SimpleFeaturePyramid, get_vit_lr_decay_rate
from .mvit import MViT
from .swin import SwinTransformer

Expand Down
6 changes: 3 additions & 3 deletions detectron2/modeling/backbone/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ def padding_constraints(self) -> Dict[str, int]:
in :paper:vitdet). `padding_constraints` contains these optional items like:
{
"size_divisibility": int,
"square": int,
"square_size": int,
# Future options are possible
}
`size_divisibility` will read from here if presented and `square` indicates if requiring
inputs to be padded to square. Set to None if no specific padding constraints.
`size_divisibility` will read from here if presented and `square_size` indicates the
square padding size if `square_size` > 0.
TODO: use type of Dict[str, int] to avoid torchscipt issues. The type of padding_constraints
could be generalized as TypedDict (Python 3.8+) to support more types in the future.
Expand Down
5 changes: 3 additions & 2 deletions detectron2/modeling/backbone/fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(
norm="",
top_block=None,
fuse_type="sum",
square_pad=False,
square_pad=0,
):
"""
Args:
Expand All @@ -54,6 +54,7 @@ def __init__(
fuse_type (str): types for fusing the top down features and the lateral
ones. It can be "sum" (default), which sums up element-wise; or "avg",
which takes the element-wise mean of the two.
square_pad (int): If > 0, require input images to be padded to specific square size.
"""
super(FPN, self).__init__()
assert isinstance(bottom_up, Backbone)
Expand Down Expand Up @@ -120,7 +121,7 @@ def size_divisibility(self):

@property
def padding_constraints(self):
return {"square": int(self._square_pad)}
return {"square_size": self._square_pad}

def forward(self, x):
"""
Expand Down
29 changes: 25 additions & 4 deletions detectron2/modeling/backbone/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
logger = logging.getLogger(__name__)


__all__ = ["ViT", "SimpleFeaturePyramid"]
__all__ = ["ViT", "SimpleFeaturePyramid", "get_vit_lr_decay_rate"]


class Attention(nn.Module):
Expand Down Expand Up @@ -372,7 +372,7 @@ def __init__(
scale_factors,
top_block=None,
norm="LN",
square_pad=False,
square_pad=0,
):
"""
Args:
Expand All @@ -391,7 +391,7 @@ def __init__(
this block, and "in_feature", which is a string representing
its input feature (e.g., p5).
norm (str): the normalization to use.
square_pad (bool): If true, require input images to be padded to square.
square_pad (int): If > 0, require input images to be padded to specific square size.
"""
super(SimpleFeaturePyramid, self).__init__()
assert isinstance(net, Backbone)
Expand Down Expand Up @@ -469,7 +469,7 @@ def __init__(
def padding_constraints(self):
return {
"size_divisiblity": self._size_divisibility,
"square": int(self._square_pad),
"square_size": self._square_pad,
}

def forward(self, x):
Expand Down Expand Up @@ -499,3 +499,24 @@ def forward(self, x):
results.extend(self.top_block(top_block_in_feature))
assert len(self._out_features) == len(results)
return {f: res for f, res in zip(self._out_features, results)}


def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
"""
Calculate lr decay rate for different ViT blocks.
Args:
name (string): parameter name.
lr_decay_rate (float): base lr decay rate.
num_layers (int): number of ViT blocks.
Returns:
lr decay rate for the given parameter.
"""
layer_id = num_layers + 1
if name.startswith("backbone"):
if ".pos_embed" in name or ".patch_embed" in name:
layer_id = 0
elif ".blocks." in name and ".residual." not in name:
layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1

return lr_decay_rate ** (num_layers + 1 - layer_id)
12 changes: 6 additions & 6 deletions detectron2/structures/image_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,9 @@ def from_tensors(
This depends on the model and many models need a divisibility of 32.
pad_value (float): value to pad.
padding_constraints (optional[Dict]): If given, it would follow the format as
{"size_divisibility": int, "square": int}, where `size_divisibility` will overwrite
the above one if presented and `square` indicates if require inputs to be padded to
square.
{"size_divisibility": int, "square_size": int}, where `size_divisibility` will
overwrite the above one if presented and `square_size` indicates the
square padding size if `square_size` > 0.
Returns:
an `ImageList`.
"""
Expand All @@ -90,9 +89,10 @@ def from_tensors(
max_size = torch.stack(image_sizes_tensor).max(0).values

if padding_constraints is not None:
if padding_constraints.get("square", 0) > 0:
square_size = padding_constraints.get("square_size", 0)
if square_size > 0:
# pad to square.
max_size[0] = max_size[1] = max_size.max()
max_size[0] = max_size[1] = square_size
if "size_divisibility" in padding_constraints:
size_divisibility = padding_constraints["size_divisibility"]
if size_divisibility > 1:
Expand Down
2 changes: 1 addition & 1 deletion projects/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ of support or stability as detectron2.
+ [Unbiased Teacher for Semi-Supervised Object Detection](https://github.com/facebookresearch/unbiased-teacher)
+ [Rethinking "Batch" in BatchNorm](Rethinking-BatchNorm/)
+ [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://github.com/facebookresearch/MaskFormer)
+ [Exploring Plain Vision Transformer Backbones for Object Detection](ViTDet/)


## External Projects
Expand All @@ -45,4 +46,3 @@ External projects in the community that use detectron2:
+ [Sparse R-CNN](https://github.com/PeizeSun/SparseR-CNN)
+ [BCNet](https://github.com/lkeab/BCNet), a bilayer decoupling instance segmentation method.
+ [DD3D](https://github.com/TRI-ML/dd3d), A fully convolutional 3D detector.

0 comments on commit 333efcb

Please sign in to comment.