Add Detectron2 wrapper (facebookresearch#103)
alcinos committed Jun 28, 2020
1 parent f864b0b commit 0099bd1
Showing 9 changed files with 692 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -17,6 +17,8 @@ Training code follows this idea - it is not a library,
but simply a [main.py](main.py) importing model and criterion
definitions with standard training loops.

Additionally, we provide a Detectron2 wrapper in the d2/ folder. See the README there for more information.

For details see [End-to-End Object Detection with Transformers](https://ai.facebook.com/research/publications/end-to-end-object-detection-with-transformers) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, and Sergey Zagoruyko.

# Model Zoo
34 changes: 34 additions & 0 deletions d2/README.md
@@ -0,0 +1,34 @@
Detectron2 wrapper for DETR
=======

We provide a Detectron2 wrapper for DETR so that it integrates better into the existing detection ecosystem. For example, it makes it easy to leverage the datasets and backbones provided in Detectron2.

This wrapper currently supports only box detection. It is intended to stay as close as possible to the original implementation, and we have checked that it matches the original results. Some notable facts and caveats:
- The data augmentation matches DETR's original data augmentation. This required patching the RandomCrop augmentation from Detectron2, so you will need a Detectron2 build from the master branch dated June 24th, 2020 or later.
- To match DETR's original backbone initialization, we use the weights of a ResNet-50 trained on ImageNet with torchvision. This network uses different pixel mean and std values than most of the backbones available in Detectron2 by default, so extra care must be taken when switching to another one (see the sketch after this list). Note that no other torchvision models are available in Detectron2 as of now, though this may change in the future.
- The gradient clipping mode is "full_model", which is not the default in Detectron2.
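If you do switch to one of the backbones bundled with Detectron2, the pixel statistics and input format must be changed together with the weights. A minimal sketch, assuming Detectron2's Caffe2-style ResNet-50 weights (the values below are Detectron2's defaults for those weights; this is an illustration, not a validated config):

```
cfg.MODEL.WEIGHTS = "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
cfg.INPUT.FORMAT = "BGR"                            # Caffe2-style weights expect BGR input
cfg.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675]  # BGR means, Detectron2's defaults
cfg.MODEL.PIXEL_STD = [1.0, 1.0, 1.0]               # Caffe2-style models are not std-normalized
cfg.MODEL.RESNETS.STRIDE_IN_1X1 = True              # Caffe2 ResNets stride in the 1x1 conv
```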

# Usage

To install Detectron2, please follow the [official installation instructions](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md).

## Evaluating a model

For convenience, we provide a conversion script to convert models trained by the main DETR training loop into the format of this wrapper. To download and convert the main ResNet-50 model, simply run:

```
python converter.py --source_model https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth --output_model converted_model.pth
```

You can then evaluate it using:
```
python train_net.py --eval-only --config configs/detr_256_6_6_torchvision.yaml MODEL.WEIGHTS "converted_model.pth"
```
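For reference, the main DETR model zoo reports 42.0 box AP on COCO val2017 for this checkpoint; the converted model should score close to that.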


## Training

To train DETR on a single node with 8 GPUs, simply use:
```
python train_net.py --config configs/detr_256_6_6_torchvision.yaml --num-gpus 8
```
45 changes: 45 additions & 0 deletions d2/configs/detr_256_6_6_torchvision.yaml
@@ -0,0 +1,45 @@
MODEL:
  META_ARCHITECTURE: "Detr"
  WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
  MASK_ON: False
  RESNETS:
    DEPTH: 50
    STRIDE_IN_1X1: False
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
  DETR:
    GIOU_WEIGHT: 2.0
    L1_WEIGHT: 5.0
    NUM_OBJECT_QUERIES: 100
DATASETS:
  TRAIN: ("coco_2017_train",)
  TEST: ("coco_2017_val",)
SOLVER:
  IMS_PER_BATCH: 64
  BASE_LR: 0.0001
  STEPS: (369600,)
  MAX_ITER: 554400
  WARMUP_FACTOR: 1.0
  WARMUP_ITERS: 10
  WEIGHT_DECAY: 0.0001
  OPTIMIZER: "ADAMW"
  BACKBONE_MULTIPLIER: 0.1
  CLIP_GRADIENTS:
    ENABLED: True
    CLIP_TYPE: "full_model"
    CLIP_VALUE: 0.01
    NORM_TYPE: 2.0
INPUT:
  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
  CROP:
    ENABLED: True
    TYPE: "absolute_range"
    SIZE: (384, 600)
  FORMAT: "RGB"
TEST:
  EVAL_PERIOD: 4000
DATALOADER:
  FILTER_EMPTY_ANNOTATIONS: False
  NUM_WORKERS: 4
VERSION: 2
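A quick back-of-the-envelope check of this schedule: with IMS_PER_BATCH: 64 and the roughly 118k images of coco_2017_train, the iteration counts above correspond to DETR's published 300-epoch schedule with the learning-rate drop at epoch 200. A small sketch (the dataset size is an approximation):

```
ims_per_batch = 64
dataset_size = 118_000  # coco_2017_train has ~118k images (approximate)
print(554400 * ims_per_batch / dataset_size)  # ~300 epochs in total
print(369600 * ims_per_batch / dataset_size)  # LR drop around epoch 200
```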
69 changes: 69 additions & 0 deletions d2/converter.py
@@ -0,0 +1,69 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Helper script to convert models trained with the main version of DETR to be used with the Detectron2 version.
"""
import argparse

import numpy as np
import torch


def parse_args():
    parser = argparse.ArgumentParser("D2 model converter")

    parser.add_argument("--source_model", default="", type=str, help="Path or url to the DETR model to convert")
    parser.add_argument("--output_model", default="", type=str, help="Path where to save the converted model")
    return parser.parse_args()


def main():
    args = parse_args()

    # D2 expects contiguous classes, so we need to remap the 92 classes from DETR
    # fmt: off
    coco_idx = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
                27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51,
                52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77,
                78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91]
    # fmt: on

    coco_idx = np.array(coco_idx)
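    # Note: coco_idx has 81 entries: the 80 contiguous COCO categories, plus
    # index 91, which selects DETR's extra "no object" logit.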

    if args.source_model.startswith("https"):
        checkpoint = torch.hub.load_state_dict_from_url(args.source_model, map_location="cpu", check_hash=True)
    else:
        checkpoint = torch.load(args.source_model, map_location="cpu")
    model_to_convert = checkpoint["model"]

    model_converted = {}
    for k in model_to_convert.keys():
        old_k = k
        if "backbone" in k:
            k = k.replace("backbone.0.body.", "")
            if "layer" not in k:
                k = "stem." + k
            for t in [1, 2, 3, 4]:
                k = k.replace(f"layer{t}", f"res{t + 1}")
            for t in [1, 2, 3]:
                k = k.replace(f"bn{t}", f"conv{t}.norm")
            k = k.replace("downsample.0", "shortcut")
            k = k.replace("downsample.1", "shortcut.norm")
            k = "backbone.0.backbone." + k
        k = "detr." + k
        print(old_k, "->", k)
        if "class_embed" in old_k:
            v = model_to_convert[old_k].detach()
            if v.shape[0] == 92:
                shape_old = v.shape
                model_converted[k] = v[coco_idx]
                print("Head conversion: changing shape from {} to {}".format(shape_old, model_converted[k].shape))
                continue
        model_converted[k] = model_to_convert[old_k].detach()

    model_to_save = {"model": model_converted}
    torch.save(model_to_save, args.output_model)


if __name__ == "__main__":
    main()
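As a quick sanity check of the output (a hypothetical snippet, not part of the script), the converted classification head should have 81 rows: the 80 COCO classes plus the "no object" slot:

```
import torch

ckpt = torch.load("converted_model.pth", map_location="cpu")
print(ckpt["model"]["detr.class_embed.weight"].shape)  # expected: torch.Size([81, 256])
```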
4 changes: 4 additions & 0 deletions d2/detr/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .config import add_detr_config
from .detr import Detr
from .dataset_mapper import DetrDatasetMapper
32 changes: 32 additions & 0 deletions d2/detr/config.py
@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from detectron2.config import CfgNode as CN


def add_detr_config(cfg):
"""
Add config for DETR.
"""
cfg.MODEL.DETR = CN()
cfg.MODEL.DETR.NUM_CLASSES = 80

# LOSS
cfg.MODEL.DETR.GIOU_WEIGHT = 2.0
cfg.MODEL.DETR.L1_WEIGHT = 5.0
cfg.MODEL.DETR.DEEP_SUPERVISION = True
cfg.MODEL.DETR.NO_OBJECT_WEIGHT = 0.1

# TRANSFORMER
cfg.MODEL.DETR.NHEADS = 8
cfg.MODEL.DETR.DROPOUT = 0.1
cfg.MODEL.DETR.DIM_FEEDFORWARD = 2048
cfg.MODEL.DETR.ENC_LAYERS = 6
cfg.MODEL.DETR.DEC_LAYERS = 6
cfg.MODEL.DETR.PRE_NORM = False
cfg.MODEL.DETR.PASS_POS_AND_QUERY = True

cfg.MODEL.DETR.HIDDEN_DIM = 256
cfg.MODEL.DETR.NUM_OBJECT_QUERIES = 100

cfg.SOLVER.OPTIMIZER = "ADAMW"
cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
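These defaults only exist once add_detr_config has been called. A minimal sketch of composing them with the YAML config above (assumed to be run from the d2/ directory so the detr package is importable):

```
from detectron2.config import get_cfg

from detr import add_detr_config

cfg = get_cfg()
add_detr_config(cfg)  # register the DETR-specific keys before merging
cfg.merge_from_file("configs/detr_256_6_6_torchvision.yaml")
print(cfg.MODEL.DETR.NUM_OBJECT_QUERIES)  # 100
```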
122 changes: 122 additions & 0 deletions d2/detr/dataset_mapper.py
@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging

import numpy as np
import torch

from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.data.transforms import TransformGen

__all__ = ["DetrDatasetMapper"]


def build_transform_gen(cfg, is_train):
"""
Create a list of :class:`TransformGen` from config.
Returns:
list[TransformGen]
"""
if is_train:
min_size = cfg.INPUT.MIN_SIZE_TRAIN
max_size = cfg.INPUT.MAX_SIZE_TRAIN
sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
else:
min_size = cfg.INPUT.MIN_SIZE_TEST
max_size = cfg.INPUT.MAX_SIZE_TEST
sample_style = "choice"
if sample_style == "range":
assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size))

logger = logging.getLogger(__name__)
tfm_gens = []
if is_train:
tfm_gens.append(T.RandomFlip())
tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
if is_train:
logger.info("TransformGens used in training: " + str(tfm_gens))
return tfm_gens


class DetrDatasetMapper:
"""
A callable which takes a dataset dict in Detectron2 Dataset format,
and map it into a format used by DETR.
The callable currently does the following:
1. Read the image from "file_name"
2. Applies geometric transforms to the image and annotation
3. Find and applies suitable cropping to the image and annotation
4. Prepare image and annotation to Tensors
"""

    def __init__(self, cfg, is_train=True):
        if cfg.INPUT.CROP.ENABLED and is_train:
            self.crop_gen = [
                T.ResizeShortestEdge([400, 500, 600], sample_style="choice"),
                T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE),
            ]
        else:
            self.crop_gen = None

        assert not cfg.MODEL.MASK_ON, "Mask is not supported"

        self.tfm_gens = build_transform_gen(cfg, is_train)
        logging.getLogger(__name__).info(
            "Full TransformGens used in training: {}, crop: {}".format(str(self.tfm_gens), str(self.crop_gen))
        )

        self.img_format = cfg.INPUT.FORMAT
        self.is_train = is_train

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        if self.crop_gen is None:
            image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        else:
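            # Cropping is enabled: with probability 0.5 keep the plain
            # multi-scale resize, otherwise splice the resize+crop pipeline in
            # before the final resize, mirroring DETR's original augmentation.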
            if np.random.rand() > 0.5:
                image, transforms = T.apply_transform_gens(self.tfm_gens, image)
            else:
                image, transforms = T.apply_transform_gens(
                    self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image
                )

        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                anno.pop("segmentation", None)
                anno.pop("keypoints", None)

            # USER: Implement additional transformations if you have other types of data
            annos = [
                utils.transform_instance_annotations(obj, transforms, image_shape)
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = utils.annotations_to_instances(annos, image_shape)
            dataset_dict["instances"] = utils.filter_empty_instances(instances)
        return dataset_dict
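A minimal sketch of how this mapper would plug into a Detectron2 training loader (train_net.py presumably does something equivalent; this is an illustration, not a copy of it):

```
from detectron2.data import build_detection_train_loader

from detr import DetrDatasetMapper

# cfg is assumed to be a config prepared with add_detr_config (see above)
train_loader = build_detection_train_loader(cfg, mapper=DetrDatasetMapper(cfg, is_train=True))
```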