diff --git a/README.md b/README.md
index 6fd329e7..cfcbbd50 100644
--- a/README.md
+++ b/README.md
@@ -1,344 +1,69 @@
-**Use this instead: https://github.com/facebookresearch/maskrcnn-benchmark**
+# Path Aggregation Network for Instance Segmentation
-# A Pytorch Implementation of Detectron
+by [Shu Liu](http://shuliu.me), Lu Qi, Haifang Qin, [Jianping Shi](https://shijianping.me/), and [Jiaya Jia](http://jiaya.me/).
-[](https://travis-ci.com/roytseng-tw/Detectron.pytorch)
+### Introduction
-
+This repository is for the CVPR 2018 Spotlight paper '[Path Aggregation Network for Instance Segmentation](https://arxiv.org/abs/1803.01534)', which won 1st place in the [COCO Instance Segmentation Challenge 2017](http://cocodataset.org/#detections-leaderboard), 2nd place in the [COCO Detection Challenge 2017](http://cocodataset.org/#detections-leaderboard) (Team Name: [UCenter](https://places-coco2017.github.io/#winners)), and 1st place in the 2018 [Scene Understanding Challenge for Autonomous Navigation in Unstructured Environments](http://cvit.iiit.ac.in/scene-understanding-challenge-2018/benchmarks.php#instance) (Team Name: TUTU).
-

+### Citation
-
Example output of e2e_mask_rcnn-R-101-FPN_2x using Detectron pretrained weight.
+If PANet is useful for your research, please consider citing:
-

+ @inproceedings{liu2018path,
+ author = {Shu Liu and
+ Lu Qi and
+ Haifang Qin and
+ Jianping Shi and
+ Jiaya Jia},
+ title = {Path Aggregation Network for Instance Segmentation},
+ booktitle = {Proceedings of IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+ year = {2018}
+ }
-
Corresponding example output from Detectron.
-

+### Disclaimer
-
Example output of e2e_keypoint_rcnn-R-50-FPN_s1x using Detectron pretrained weight.
+- The original code was implemented on a modified version of Caffe maintained by SenseTime Research. For several reasons, we are unable to release that original code.
+- This repository provides our re-implementation of PANet based on PyTorch. Note that our code is heavily based on [Detectron.pytorch](https://github.com/roytseng-tw/Detectron.pytorch). Thanks [Roy](https://github.com/roytseng-tw) for his great work!
+- Several details in [Detectron](https://github.com/facebookresearch/Detectron), e.g., weight initialization and joint RPN training, are fairly different from our original implementation. In this repository, we simply follow Detectron because it achieves a better baseline than the codebase used in our paper.
+- In this repository, we train with the BN layers in the backbone fixed and use GN in the other parts. We expect better performance when using Synchronized Batch Normalization and training all parameter layers, as we did in our paper. Because of these differences and the much stronger baseline, the improvement over the baseline is **not** the same as the one we reported, but the absolute performance is **better** than our original implementation.
+- We trained with an image batch size of 16 on 8 P40 GPUs. Performance should be similar with batch size 8.
-
+### Installation
-**This code follows the implementation architecture of Detectron.** Only part of the functionality is supported. Check [this section](#supported-network-modules) for more information.
+For environment requirements, data preparation and compilation, please refer to [Detectron.pytorch](https://github.com/roytseng-tw/Detectron.pytorch).
-With this code, you can...
+WARNING: PyTorch 0.4.1 is broken (see https://github.com/pytorch/pytorch/issues/8483); use PyTorch 0.4.0.
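+
+One way to pin the working version (a sketch, assuming a pip-based environment; adjust for conda or CUDA-specific wheels):
+
+```shell
+pip install torch==0.4.0 torchvision==0.2.1
+```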
-1. **Train your model from scratch.**
-2. **Inference using the pretrained weight file (*.pkl) from Detectron.**
+### Usage
-This repository is originally built on [jwyang/faster-rcnn.pytorch](https://github.com/jwyang/faster-rcnn.pytorch). However, after many modifications, the structure changes a lot and it's now more similar to [Detectron](https://github.com/facebookresearch/Detectron). I deliberately make everything similar or identical to Detectron's implementation, so as to reproduce the result directly from official pretrained weight files.
+Training and testing follow the same procedure as [Detectron.pytorch](https://github.com/roytseng-tw/Detectron.pytorch). To train and test PANet, simply use the corresponding config files. For example, to train PANet on COCO:
-This implementation has the following features:
-
-- **It is pure Pytorch code**. Of course, there are some CUDA code.
-
-- **It supports multi-image batch training**.
-
-- **It supports multiple GPUs training**.
-
-- **It supports three pooling methods**. Notice that only **roi align** is revised to match the implementation in Caffe2. So, use it.
-
-- **It is memory efficient**. For data batching, there are two techiniques available to reduce memory usage: 1) *Aspect grouping*: group images with similar aspect ratio in a batch 2) *Aspect cropping*: crop images that are too long. Aspect grouping is implemented in Detectron, so it's used for default. Aspect cropping is the idea from [jwyang/faster-rcnn.pytorch](https://github.com/jwyang/faster-rcnn.pytorch), and it's not used for default.
-
- Besides of that, I implement a customized `nn.DataParallel ` module which enables different batch blob size on different gpus. Check [My nn.DataParallel](#my-nndataparallel) section for more details about this.
-
-## News
-
-- (2018/05/25) Support ResNeXt backbones.
-- (2018/05/22) Add group normalization baselines.
-- (2018/05/15) PyTorch0.4 is supported now !
-
-## Getting Started
-Clone the repo:
-
-```
-git clone https://github.com/roytseng-tw/mask-rcnn.pytorch.git
-```
-
-### Requirements
-
-Tested under python3.
-
-- python packages
- - pytorch>=0.3.1
- - torchvision>=0.2.0
- - cython
- - matplotlib
- - numpy
- - scipy
- - opencv
- - pyyaml
- - packaging
- - [pycocotools](https://github.com/cocodataset/cocoapi) — for COCO dataset, also available from pip.
- - tensorboardX — for logging the losses in Tensorboard
-- An NVIDAI GPU and CUDA 8.0 or higher. Some operations only have gpu implementation.
-- **NOTICE**: different versions of Pytorch package have different memory usages.
-
-### Compilation
-
-Compile the CUDA code:
-
-```
-cd lib # please change to this directory
-sh make.sh
-```
-
-If your are using Volta GPUs, uncomment this [line](https://github.com/roytseng-tw/mask-rcnn.pytorch/tree/master/lib/make.sh#L15) in `lib/mask.sh` and remember to postpend a backslash at the line above. `CUDA_PATH` defaults to `/usr/loca/cuda`. If you want to use a CUDA library on different path, change this [line](https://github.com/roytseng-tw/mask-rcnn.pytorch/tree/master/lib/make.sh#L3) accordingly.
-
-It will compile all the modules you need, including NMS, ROI_Pooing, ROI_Crop and ROI_Align. (Actually gpu nms is never used ...)
-
-Note that, If you use `CUDA_VISIBLE_DEVICES` to set gpus, **make sure at least one gpu is visible when compile the code.**
-
-### Data Preparation
-
-Create a data folder under the repo,
-
-```
-cd {repo_root}
-mkdir data
+```shell
+python tools/train_net_step.py --dataset coco2017 --cfg configs/panet/e2e_panet_R-50-FPN_2x_mask.yaml
```
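+
+The training options from Detectron.pytorch (`CUDA_VISIBLE_DEVICES`, `--bs`, `--nw`, `--iter_size`, `--use_tfboard`, ...) should apply here as well. For example, a sketch of training on 4 GPUs while keeping the effective batch size at 16 (adjust to your hardware):
+
+```shell
+# 8 images per step x iter_size 2 = effective batch size 16
+CUDA_VISIBLE_DEVICES=0,1,2,3 python tools/train_net_step.py --dataset coco2017 \
+    --cfg configs/panet/e2e_panet_R-50-FPN_2x_mask.yaml --bs 8 --iter_size 2 --use_tfboard
+```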
-- **COCO**:
- Download the coco images and annotations from [coco website](http://cocodataset.org/#download).
+To evaluate the model, simply use:
- And make sure to put the files as the following structure:
- ```
- coco
- ├── annotations
- | ├── instances_minival2014.json
- │ ├── instances_train2014.json
- │ ├── instances_train2017.json
- │ ├── instances_val2014.json
- │ ├── instances_val2017.json
- │ ├── instances_valminusminival2014.json
- │ ├── ...
- |
- └── images
- ├── train2014
- ├── train2017
- ├── val2014
- ├──val2017
- ├── ...
- ```
- Download coco mini annotations from [here](https://s3-us-west-2.amazonaws.com/detectron/coco/coco_annotations_minival.tgz).
- Please note that minival is exactly equivalent to the recently defined 2017 val set. Similarly, the union of valminusminival and the 2014 train is exactly equivalent to the 2017 train set.
-
- Feel free to put the dataset at any place you want, and then soft link the dataset under the `data/` folder:
-
- ```
- ln -s path/to/coco data/coco
- ```
-
- Recommend to put the images on a SSD for possible better training performance
-
-### Pretrained Model
-
-I use ImageNet pretrained weights from Caffe for the backbone networks.
-
-- [ResNet50](https://drive.google.com/open?id=1wHSvusQ1CiEMc5Nx5R8adqoHQjIDWXl1), [ResNet101](https://drive.google.com/open?id=1x2fTMqLrn63EMW0VuK4GEa2eQKzvJ_7l), [ResNet152](https://drive.google.com/open?id=1NSCycOb7pU0KzluH326zmyMFUU55JslF)
-- [VGG16](https://drive.google.com/open?id=19UphT53C0Ua9JAtICnw84PPTa3sZZ_9k) (vgg backbone is not implemented yet)
-
-Download them and put them into the `{repo_root}/data/pretrained_model`.
-
-You can the following command to download them all:
-
-- extra required packages: `argparse_color_formater`, `colorama`, `requests`
-
-```
-python tools/download_imagenet_weights.py
+```shell
+python tools/test_net.py --dataset coco2017 --cfg configs/panet/e2e_panet_R-50-FPN_2x_mask.yaml --load_ckpt {path/to/your/checkpoint}
```
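+
+`--load_detectron` (to evaluate a Detectron checkpoint), `--multi-gpu-testing` and `--output_dir` should also be available, as in Detectron.pytorch. For example:
+
+```shell
+# Evaluate on all visible GPUs; results go to --output_dir
+# (defaults to {the/parent/dir/of/checkpoint}/test).
+python tools/test_net.py --dataset coco2017 --cfg configs/panet/e2e_panet_R-50-FPN_2x_mask.yaml \
+    --load_ckpt {path/to/your/checkpoint} --multi-gpu-testing --output_dir {path/to/output}
+```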
-**NOTE**: Caffe pretrained weights have slightly better performance than Pytorch pretrained. Suggest to use Caffe pretrained models from the above link to reproduce the results. By the way, Detectron also use pretrained weights from Caffe.
-
-**If you want to use pytorch pre-trained models, please remember to transpose images from BGR to RGB, and also use the same data preprocessing (minus mean and normalize) as used in Pytorch pretrained model.**
-
-#### ImageNet Pretrained Model provided by Detectron
-
-- [R-50.pkl](https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/MSRA/R-50.pkl)
-- [R-101.pkl](https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/MSRA/R-101.pkl)
-- [R-50-GN.pkl](https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/47261647/R-50-GN.pkl)
-- [R-101-GN.pkl](https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/47592356/R-101-GN.pkl)
-- [X-101-32x8d.pkl](https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/20171220/X-101-32x8d.pkl)
-- [X-101-64x4d.pkl](https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/FBResNeXt/X-101-64x4d.pkl)
-- [X-152-32x8d-IN5k.pkl](https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/25093814/X-152-32x8d-IN5k.pkl)
-
-Besides of using the pretrained weights for ResNet above, you can also use the weights from Detectron by changing the corresponding line in model config file as follows:
-```
-RESNETS:
- IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_model/R-50.pkl'
-```
-
-R-50-GN.pkl and R-101-GN.pkl are required for gn_baselines.
-
-X-101-32x8d.pkl, X-101-64x4d.pkl and X-152-32x8d-IN5k.pkl are required for ResNeXt backbones.
-
-## Training
-
-**DO NOT CHANGE anything in the provided config files(configs/\*\*/xxxx.yml) unless you know what you are doing**
-
-Use the environment variable `CUDA_VISIBLE_DEVICES` to control which GPUs to use.
-
-### Adapative config adjustment
-
-#### Let's define some terms first
-
- batch_size: `NUM_GPUS` x `TRAIN.IMS_PER_BATCH`
- effective_batch_size: batch_size x `iter_size`
- change of somethining: `new value of something / old value of something`
-
-Following config options will be adjusted **automatically** according to actual training setups: 1) number of GPUs `NUM_GPUS`, 2) batch size per GPU `TRAIN.IMS_PER_BATCH`, 3) update period `iter_size`
-
-- `SOLVER.BASE_LR`: adjust directly propotional to the change of batch_size.
-- `SOLVER.STEPS`, `SOLVER.MAX_ITER`: adjust inversely propotional to the change of effective_batch_size.
-
-### Train from scratch
-Take mask-rcnn with res50 backbone for example.
-```
-python tools/train_net_step.py --dataset coco2017 --cfg configs/baselines/e2e_mask_rcnn_R-50-C4.yml --use_tfboard --bs {batch_size} --nw {num_workers}
-```
-
-Use `--bs` to overwrite the default batch size to a proper value that fits into your GPUs. Simliar for `--nw`, number of data loader threads defaults to 4 in config.py.
-
-Specify `—-use_tfboard` to log the losses on Tensorboard.
-
-**NOTE**: use `--dataset keypoints_coco2017` when training for keypoint-rcnn.
-
-### The use of `--iter_size`
-As in Caffe, update network once (`optimizer.step()`) every `iter_size` iterations (forward + backward). This way to have a larger effective batch size for training. Notice that, step count is only increased after network update.
-
-```
-python tools/train_net_step.py --dataset coco2017 --cfg configs/baselines/e2e_mask_rcnn_R-50-C4.yml --bs 4 --iter_size 4
-```
-`iter_size` defaults to 1.
-
-### Finetune from a pretrained checkpoint
-```
-python tools/train_net_step.py ... --load_ckpt {path/to/the/checkpoint}
-```
-or using Detectron's checkpoint file
-```
-python tools/train_net_step.py ... --load_detectron {path/to/the/checkpoint}
-```
-
-### Resume training with the same dataset and batch size
-```
-python tools/train_net_step.py ... --load_ckpt {path/to/the/checkpoint} --resume
-```
-When resume the training, **step count** and **optimizer state** will also be restored from the checkpoint. For SGD optimizer, optimizer state contains the momentum for each trainable parameter.
-
-**NOTE**: `--resume` is not yet supported for `--load_detectron`
-
-### Set config options in command line
-```
- python tools/train_net_step.py ... --no_save --set {config.name1} {value1} {config.name2} {value2} ...
-```
-- For Example, run for debugging.
- ```
- python tools/train_net_step.py ... --no_save --set DEBUG True
- ```
- Load less annotations to accelarate training progress. Add `--no_save` to avoid saving any checkpoint or logging.
-
-### Show command line help messages
-```
-python train_net_step.py --help
-```
-
-### Two Training Scripts
-
-In short, use `train_net_step.py`.
-
-In `train_net_step.py`:
-- `SOLVER.LR_POLICY: steps_with_decay` is supported.
-- Training warm up in [Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour](https://arxiv.org/abs/1706.02677) is supported.
-
-(Deprecated) In `train_net.py` some config options have no effects and worth noticing:
-
- - `SOLVER.LR_POLICY`, `SOLVER.MAX_ITER`, `SOLVER.STEPS`,`SOLVER.LRS`:
- For now, the training policy is controlled by these command line arguments:
-
- - **`--epochs`**: How many epochs to train. One epoch means one travel through the whole training sets. Defaults to 6.
- - **`--lr_decay_epochs `**: Epochs to decay the learning rate on. Decay happens on the beginning of a epoch. Epoch is 0-indexed. Defaults to [4, 5].
-
- For more command line arguments, please refer to `python train_net.py --help`
-
-- `SOLVER.WARM_UP_ITERS`, `SOLVER.WARM_UP_FACTOR`, `SOLVER.WARM_UP_METHOD`:
- Training warm up is not supported.
-
-## Inference
-
-### Evaluate the training results
-For example, test mask-rcnn on coco2017 val set
-```
-python tools/test_net.py --dataset coco2017 --cfg config/baselines/e2e_mask_rcnn_R-50-FPN_1x.yaml --load_ckpt {path/to/your/checkpoint}
-```
-Use `--load_detectron` to load Detectron's checkpoint. If multiple gpus are available, add `--multi-gpu-testing`.
-
-Specify a different output directry, use `--output_dir {...}`. Defaults to `{the/parent/dir/of/checkpoint}/test`
-
-### Visualize the training results on images
-```
-python tools/infer_simple.py --dataset coco --cfg cfgs/baselines/e2e_mask_rcnn_R-50-C4.yml --load_ckpt {path/to/your/checkpoint} --image_dir {dir/of/input/images} --output_dir {dir/to/save/visualizations}
-```
-`--output_dir` defaults to `infer_outputs`.
-
-## Supported Network modules
-
-- Backbone:
- - ResNet:
- `ResNet50_conv4_body`,`ResNet50_conv5_body`,
- `ResNet101_Conv4_Body`,`ResNet101_Conv5_Body`,
- `ResNet152_Conv5_Body`
- - ResNeXt:
- `[fpn_]ResNet101_Conv4_Body`,`[fpn_]ResNet101_Conv5_Body`, `[fpn_]ResNet152_Conv5_Body`
- - FPN:
- `fpn_ResNet50_conv5_body`,`fpn_ResNet50_conv5_P2only_body`,
- `fpn_ResNet101_conv5_body`,`fpn_ResNet101_conv5_P2only_body`,`fpn_ResNet152_conv5_body`,`fpn_ResNet152_conv5_P2only_body`
-
-- Box head:
- `ResNet_roi_conv5_head`,`roi_2mlp_head`, `roi_Xconv1fc_head`, `roi_Xconv1fc_gn_head`
-
-- Mask head:
- `mask_rcnn_fcn_head_v0upshare`,`mask_rcnn_fcn_head_v0up`, `mask_rcnn_fcn_head_v1up`, `mask_rcnn_fcn_head_v1up4convs`, `mask_rcnn_fcn_head_v1up4convs_gn`
-
-- Keypoints head:
- `roi_pose_head_v1convX`
-
-**NOTE**: the naming is similar to the one used in Detectron. Just remove any prepending `add_`.
-
-## Supported Datasets
-
-Only COCO is supported for now. However, the whole dataset library implementation is almost identical to Detectron's, so it should be easy to add more datasets supported by Detectron.
-
-## Configuration Options
-
-Architecture specific configuration files are put under [configs](configs/). The general configuration file [lib/core/config.py](lib/core/config.py) **has almost all the options with same default values as in Detectron's**, so it's effortless to transform the architecture specific configs from Detectron.
-
-**Some options from Detectron are not used** because the corresponding functionalities are not implemented yet. For example, data augmentation on testing.
-
-### Extra options
-- `MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS = True`: Whether to load ImageNet pretrained weights.
- - `RESNETS.IMAGENET_PRETRAINED_WEIGHTS = ''`: Path to pretrained residual network weights. If start with `'/'`, then it is treated as a absolute path. Otherwise, treat as a relative path to `ROOT_DIR`.
-- `TRAIN.ASPECT_CROPPING = False`, `TRAIN.ASPECT_HI = 2`, `TRAIN.ASPECT_LO = 0.5`: Options for aspect cropping to restrict image aspect ratio range.
-- `RPN.OUT_DIM_AS_IN_DIM = True`, `RPN.OUT_DIM = 512`, `RPN.CLS_ACTIVATION = 'sigmoid'`: Official implement of RPN has same input and output feature channels and use sigmoid as the activation function for fg/bg class prediction. In [jwyang's implementation](https://github.com/jwyang/faster-rcnn.pytorch/blob/master/lib/model/rpn/rpn.py#L28), it fix output channel number to 512 and use softmax as activation function.
+### Main Results
-### How to transform configuration files from Detectron
-1. Remove `MODEL.NUM_CLASSES`. It will be set according to the dataset specified by `--dataset`.
-2. Remove `TRAIN.WEIGHTS`, `TRAIN.DATASETS` and `TEST.DATASETS`
-3. For module type options (e.g `MODEL.CONV_BODY`, `FAST_RCNN.ROI_BOX_HEAD` ...), remove `add_` in the string if exists.
-4. If want to load ImageNet pretrained weights for the model, add `RESNETS.IMAGENET_PRETRAINED_WEIGHTS` pointing to the pretrained weight file. If not, set `MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS` to `False`.
-5. [Optional] Delete `OUTPUT_DIR: .` at the last line
-6. Do **NOT** change the option `NUM_GPUS` in the config file. It's used to infer the original batch size for training, and learning rate will be linearly scaled according to batch size change. Proper learning rate adjustment is important for training with different batch size.
-7. For group normalization baselines, add `RESNETS.USE_GN: True`.
+ Backbone | Type | Batch Size | LR Schedules | Box AP | Mask AP | Download Links
+ :------------: |:------------: |:------------: |:------: | :-------: | :--------------:| :--------------:
+ R-50-PANet (paper) | Faster | 16 | 1x | 39.2 | - | -
+ R-50-PANet | Faster | 16 | 1x | **39.8** | - | [model](https://drive.google.com/file/d/1_ahNQHY3D4mbsMWHR2FwmItBkLwYOrS4/view?usp=sharing)
+ R-50-PANet-2fc (paper) | Faster | 16 | 1x | 39.0 | - | -
+ R-50-PANet-2fc | Faster | 16 | 1x | **39.6** | - | [model](https://drive.google.com/file/d/1s-xm8GxHbmnt5M3gOMacXIRMvCGaDeRR/view?usp=sharing)
+ R-50-PANet (paper) | Mask| 16 | 2x | 42.1 | 37.8 | -
+ R-50-PANet | Mask | 16| 2x | **43.1** | **38.3** | [model](https://drive.google.com/file/d/1-pVZQ3GR6Aj7KJzH9nWoRQ-Lts8IcdMS/view?usp=sharing)
-## My nn.DataParallel
+Results on the COCO 2017 *val* subset produced by this repository. In our paper, we applied Synchronized Batch Normalization after all parameter layers. In this repository, we instead fix the BN layers in the backbone and use GN layers in the other parts. With the same set of hyper-parameters (e.g., multi-scale training), this repository produces better performance than our original paper. We expect an even better performance with Synchronized Batch Normalization.
-- **Keep certain keyword inputs on cpu**
- Official DataParallel will broadcast all the input Variables to GPUs. However, many rpn related computations are done in CPU, and it's unnecessary to put those related inputs on GPUs.
-- **Allow Different blob size for different GPU**
- To save gpu memory, images are padded seperately for each gpu.
-- **Work with returned value of dictionary type**
+### Questions
-## Benchmark
-[BENCHMARK.md](BENCHMARK.md)
+Please contact liushuhust@gmail.com.
diff --git a/configs/panet/e2e_panet_R-50-FPN_1x_det.yaml b/configs/panet/e2e_panet_R-50-FPN_1x_det.yaml
new file mode 100644
index 00000000..8fcf6334
--- /dev/null
+++ b/configs/panet/e2e_panet_R-50-FPN_1x_det.yaml
@@ -0,0 +1,35 @@
+MODEL:
+ TYPE: generalized_rcnn
+ CONV_BODY: FPN.fpn_ResNet50_conv5_body_bup
+ FASTER_RCNN: True
+NUM_GPUS: 8
+RESNETS:
+ IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_model/resnet50_caffe.pth'
+SOLVER:
+ WEIGHT_DECAY: 0.0001
+ LR_POLICY: steps_with_decay
+ BASE_LR: 0.02
+ GAMMA: 0.1
+ MAX_ITER: 90000
+ STEPS: [0, 60000, 80000]
+FPN:
+ FPN_ON: True
+ MULTILEVEL_ROIS: True
+ MULTILEVEL_RPN: True
+ USE_GN: True # Note: use GN on the FPN-specific layers
+FAST_RCNN:
+ ROI_BOX_HEAD: fast_rcnn_heads.roi_Xconv1fc_gn_head_panet # Note: this is a Conv GN head
+ ROI_XFORM_METHOD: RoIAlign
+ ROI_XFORM_RESOLUTION: 7
+ ROI_XFORM_SAMPLING_RATIO: 2
+TRAIN:
+ SCALES: (1200,1200,1000,800,600,400)
+ MAX_SIZE: 1400
+ BATCH_SIZE_PER_IM: 512
+ RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
+TEST:
+ SCALE: 1000
+ MAX_SIZE: 1400
+ NMS: 0.5
+ RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
+ RPN_POST_NMS_TOP_N: 1000
diff --git a/configs/panet/e2e_panet_R-50-FPN_1x_det_2fc.yaml b/configs/panet/e2e_panet_R-50-FPN_1x_det_2fc.yaml
new file mode 100644
index 00000000..eedf5422
--- /dev/null
+++ b/configs/panet/e2e_panet_R-50-FPN_1x_det_2fc.yaml
@@ -0,0 +1,35 @@
+MODEL:
+ TYPE: generalized_rcnn
+ CONV_BODY: FPN.fpn_ResNet50_conv5_body_bup
+ FASTER_RCNN: True
+NUM_GPUS: 8
+RESNETS:
+ IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_model/resnet50_caffe.pth'
+SOLVER:
+ WEIGHT_DECAY: 0.0001
+ LR_POLICY: steps_with_decay
+ BASE_LR: 0.02
+ GAMMA: 0.1
+ MAX_ITER: 90000
+ STEPS: [0, 60000, 80000]
+FPN:
+ FPN_ON: True
+ MULTILEVEL_ROIS: True
+ MULTILEVEL_RPN: True
+ USE_GN: True # Note: use GN on the FPN-specific layers
+FAST_RCNN:
+  ROI_BOX_HEAD: fast_rcnn_heads.roi_2mlp_head_gn_panet # Note: this is a 2-fc (MLP) GN head
+ ROI_XFORM_METHOD: RoIAlign
+ ROI_XFORM_RESOLUTION: 7
+ ROI_XFORM_SAMPLING_RATIO: 2
+TRAIN:
+ SCALES: (1200,1200,1000,800,600,400)
+ MAX_SIZE: 1400
+ BATCH_SIZE_PER_IM: 512
+ RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
+TEST:
+ SCALE: 1000
+ MAX_SIZE: 1400
+ NMS: 0.5
+ RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
+ RPN_POST_NMS_TOP_N: 1000
diff --git a/configs/panet/e2e_panet_R-50-FPN_2x_mask.yaml b/configs/panet/e2e_panet_R-50-FPN_2x_mask.yaml
new file mode 100644
index 00000000..7f8700ba
--- /dev/null
+++ b/configs/panet/e2e_panet_R-50-FPN_2x_mask.yaml
@@ -0,0 +1,45 @@
+MODEL:
+ TYPE: generalized_rcnn
+ CONV_BODY: FPN.fpn_ResNet50_conv5_body_bup
+ FASTER_RCNN: True
+ MASK_ON: True
+NUM_GPUS: 8
+SOLVER:
+ WEIGHT_DECAY: 0.0001
+ LR_POLICY: steps_with_decay
+ BASE_LR: 0.02
+ GAMMA: 0.1
+ MAX_ITER: 180000
+ STEPS: [0, 120000, 160000]
+FPN:
+ FPN_ON: True
+ MULTILEVEL_ROIS: True
+ MULTILEVEL_RPN: True
+ USE_GN: True # Note: use GN on the FPN-specific layers
+RESNETS:
+ IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_model/resnet50_caffe.pth'
+FAST_RCNN:
+ ROI_BOX_HEAD: fast_rcnn_heads.roi_Xconv1fc_gn_head_panet # Note: this is a Conv GN head
+ ROI_XFORM_METHOD: RoIAlign
+ ROI_XFORM_RESOLUTION: 7
+ ROI_XFORM_SAMPLING_RATIO: 2
+MRCNN:
+ ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs_gn_adp_ff # Note: this is a GN mask head
+ RESOLUTION: 28 # (output mask resolution) default 14
+ ROI_XFORM_METHOD: RoIAlign
+ ROI_XFORM_RESOLUTION: 14 # default 7
+ ROI_XFORM_SAMPLING_RATIO: 2 # default 0
+ DILATION: 1 # default 2
+ CONV_INIT: MSRAFill # default GaussianFill
+TRAIN:
+ SCALES: (1200, 1200, 1000, 800, 600, 400)
+ MAX_SIZE: 1400
+ BATCH_SIZE_PER_IM: 512
+ RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
+TEST:
+ SCALE: 1000
+ MAX_SIZE: 1400
+ NMS: 0.5
+ RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
+ RPN_POST_NMS_TOP_N: 1000
+
diff --git a/lib/modeling/FPN.py b/lib/modeling/FPN.py
index 03cd60bc..6c3ca8d4 100644
--- a/lib/modeling/FPN.py
+++ b/lib/modeling/FPN.py
@@ -30,6 +30,12 @@ def fpn_ResNet50_conv5_body():
ResNet.ResNet50_conv5_body, fpn_level_info_ResNet50_conv5()
)
+def fpn_ResNet50_conv5_body_bup():
+ return fpn(
+ ResNet.ResNet50_conv5_body, fpn_level_info_ResNet50_conv5(),
+ panet_buttomup=True
+ )
+
def fpn_ResNet50_conv5_P2only_body():
return fpn(
@@ -77,10 +83,11 @@ class fpn(nn.Module):
similarly for fpn_level_info.dims: e.g [2048, 1024, 512, 256]
similarly for spatial_scale: e.g [1/32, 1/16, 1/8, 1/4]
"""
- def __init__(self, conv_body_func, fpn_level_info, P2only=False):
+ def __init__(self, conv_body_func, fpn_level_info, P2only=False, panet_buttomup=False):
super().__init__()
self.fpn_level_info = fpn_level_info
self.P2only = P2only
+ self.panet_buttomup = panet_buttomup
self.dim_out = fpn_dim = cfg.FPN.DIM
min_level, max_level = get_min_max_levels()
@@ -125,6 +132,35 @@ def __init__(self, conv_body_func, fpn_level_info, P2only=False):
self.spatial_scale.append(fpn_level_info.spatial_scales[i])
+        # PANet bottom-up path augmentation modules: conv1 downsamples a level,
+        # conv2 refines it after the lateral feature is added
+ if self.panet_buttomup:
+ self.panet_buttomup_conv1_modules = nn.ModuleList()
+ self.panet_buttomup_conv2_modules = nn.ModuleList()
+ for i in range(self.num_backbone_stages - 1):
+ if cfg.FPN.USE_GN:
+ self.panet_buttomup_conv1_modules.append(nn.Sequential(
+ nn.Conv2d(fpn_dim, fpn_dim, 3, 2, 1, bias=True),
+ nn.GroupNorm(net_utils.get_group_gn(fpn_dim), fpn_dim,
+ eps=cfg.GROUP_NORM.EPSILON),
+ nn.ReLU(inplace=True)
+ ))
+ self.panet_buttomup_conv2_modules.append(nn.Sequential(
+ nn.Conv2d(fpn_dim, fpn_dim, 3, 1, 1, bias=True),
+ nn.GroupNorm(net_utils.get_group_gn(fpn_dim), fpn_dim,
+ eps=cfg.GROUP_NORM.EPSILON),
+ nn.ReLU(inplace=True)
+ ))
+ else:
+ self.panet_buttomup_conv1_modules.append(
+ nn.Conv2d(fpn_dim, fpn_dim, 3, 2, 1)
+ )
+ self.panet_buttomup_conv2_modules.append(
+ nn.Conv2d(fpn_dim, fpn_dim, 3, 1, 1)
+ )
+
+ #self.spatial_scale.append(fpn_level_info.spatial_scales[i])
+
+
#
# Step 2: build up starting from the coarsest backbone level
#
@@ -160,6 +196,7 @@ def _init_weights(self):
def init_func(m):
if isinstance(m, nn.Conv2d):
mynn.init.XavierFill(m.weight)
+ #mynn.init.MSRAFill(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
@@ -236,10 +273,25 @@ def forward(self, x):
self.topdown_lateral_modules[i](fpn_inner_blobs[-1], conv_body_blobs[-(i+2)])
)
fpn_output_blobs = []
+ if self.panet_buttomup:
+ fpn_middle_blobs = []
for i in range(self.num_backbone_stages):
- fpn_output_blobs.append(
- self.posthoc_modules[i](fpn_inner_blobs[i])
- )
+ if not self.panet_buttomup:
+ fpn_output_blobs.append(
+ self.posthoc_modules[i](fpn_inner_blobs[i])
+ )
+ else:
+ fpn_middle_blobs.append(
+ self.posthoc_modules[i](fpn_inner_blobs[i])
+ )
+ if self.panet_buttomup:
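+            # Bottom-up path augmentation: start from the finest top-down output
+            # (N2 = P2); at each step, downsample the newest level with a stride-2
+            # conv, add the next coarser lateral feature, and refine with a 3x3
+            # conv. fpn_output_blobs ends up ordered coarsest-first, matching the
+            # ordering of the top-down outputs.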
+ fpn_output_blobs.append(fpn_middle_blobs[-1])
+ for i in range(2, self.num_backbone_stages + 1):
+ fpn_tmp = self.panet_buttomup_conv1_modules[i - 2](fpn_output_blobs[0])
+ fpn_tmp = fpn_tmp + fpn_middle_blobs[self.num_backbone_stages - i]
+ fpn_tmp = self.panet_buttomup_conv2_modules[i - 2](fpn_tmp)
+ fpn_output_blobs.insert(0, fpn_tmp)
if hasattr(self, 'maxpool_p6'):
fpn_output_blobs.insert(0, self.maxpool_p6(fpn_output_blobs[0]))
diff --git a/lib/modeling/fast_rcnn_heads.py b/lib/modeling/fast_rcnn_heads.py
index 3b386ed3..81c273c4 100644
--- a/lib/modeling/fast_rcnn_heads.py
+++ b/lib/modeling/fast_rcnn_heads.py
@@ -57,7 +57,7 @@ def fast_rcnn_losses(cls_score, bbox_pred, label_int32, bbox_targets,
bbox_inside_weights = Variable(torch.from_numpy(bbox_inside_weights)).cuda(device_id)
bbox_outside_weights = Variable(torch.from_numpy(bbox_outside_weights)).cuda(device_id)
loss_bbox = net_utils.smooth_l1_loss(
- bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_weights)
+ bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_weights, beta=1/3)
# class accuracy
cls_preds = cls_score.max(dim=1)[1].type_as(rois_label)
@@ -240,3 +240,202 @@ def forward(self, x, rpn_ret):
x = self.convs(x)
x = F.relu(self.fc(x.view(batch_size, -1)), inplace=True)
return x
+
+
+class roi_Xconv1fc_gn_head_panet(nn.Module):
+ """Add a X conv + 1fc head, with GroupNorm"""
+ def __init__(self, dim_in, roi_xform_func, spatial_scale):
+ super().__init__()
+ self.dim_in = dim_in
+ self.roi_xform = roi_xform_func
+ self.spatial_scale = spatial_scale
+
+ hidden_dim = cfg.FAST_RCNN.CONV_HEAD_DIM
+ module_list = []
+ for i in range(cfg.FAST_RCNN.NUM_STACKED_CONVS - 1):
+ module_list.extend([
+ nn.Conv2d(dim_in, hidden_dim, 3, 1, 1, bias=False),
+ nn.GroupNorm(net_utils.get_group_gn(hidden_dim), hidden_dim,
+ eps=cfg.GROUP_NORM.EPSILON),
+ nn.ReLU(inplace=True)
+ ])
+ dim_in = hidden_dim
+ self.convs = nn.Sequential(*module_list)
+
+ self.dim_out = fc_dim = cfg.FAST_RCNN.MLP_HEAD_DIM
+ roi_size = cfg.FAST_RCNN.ROI_XFORM_RESOLUTION
+ self.fc = nn.Linear(dim_in * roi_size * roi_size, fc_dim)
+ num_levels = cfg.FPN.ROI_MAX_LEVEL - cfg.FPN.ROI_MIN_LEVEL + 1
+ self.conv1_head = nn.ModuleList()
+ for i in range(num_levels):
+ self.conv1_head.append(nn.Sequential(
+ nn.Conv2d(dim_in, hidden_dim, 3, 1, 1, bias=False),
+ nn.GroupNorm(net_utils.get_group_gn(hidden_dim), hidden_dim,
+ eps=cfg.GROUP_NORM.EPSILON),
+ nn.ReLU(inplace=True)
+ ))
+
+ self._init_weights()
+
+ def _init_weights(self):
+ def _init(m):
+ if isinstance(m, nn.Conv2d):
+ mynn.init.MSRAFill(m.weight)
+ elif isinstance(m, nn.Linear):
+ mynn.init.XavierFill(m.weight)
+ init.constant_(m.bias, 0)
+ self.apply(_init)
+
+ def detectron_weight_mapping(self):
+ mapping = {}
+ for i in range(cfg.FAST_RCNN.NUM_STACKED_CONVS):
+ mapping.update({
+ 'convs.%d.weight' % (i*3): 'head_conv%d_w' % (i+1),
+ 'convs.%d.weight' % (i*3+1): 'head_conv%d_gn_s' % (i+1),
+ 'convs.%d.bias' % (i*3+1): 'head_conv%d_gn_b' % (i+1)
+ })
+ mapping.update({
+ 'fc.weight': 'fc6_w',
+ 'fc.bias': 'fc6_b'
+ })
+ return mapping, []
+
+ def forward(self, x, rpn_ret):
+ x = self.roi_xform(
+ x, rpn_ret,
+ blob_rois='rois',
+ method=cfg.FAST_RCNN.ROI_XFORM_METHOD,
+ resolution=cfg.FAST_RCNN.ROI_XFORM_RESOLUTION,
+ spatial_scale=self.spatial_scale,
+ sampling_ratio=cfg.FAST_RCNN.ROI_XFORM_SAMPLING_RATIO,
+ panet=True
+ )
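+        # Adaptive feature pooling: x is a list with one pooled feature per FPN
+        # level for every RoI; apply a per-level conv, then fuse the levels with
+        # an element-wise max before the shared convs and fc.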
+ for i in range(len(x)):
+ x[i] = self.conv1_head[i](x[i])
+ for i in range(1, len(x)):
+ x[0] = torch.max(x[0], x[i])
+ x = x[0]
+ batch_size = x.size(0)
+
+ x = self.convs(x)
+ x = F.relu(self.fc(x.view(batch_size, -1)), inplace=True)
+ return x
+
+
+class roi_2mlp_head_gn(nn.Module):
+ """Add a ReLU MLP with two hidden layers."""
+ def __init__(self, dim_in, roi_xform_func, spatial_scale):
+ super().__init__()
+ self.dim_in = dim_in
+ self.roi_xform = roi_xform_func
+ self.spatial_scale = spatial_scale
+ self.dim_out = hidden_dim = cfg.FAST_RCNN.MLP_HEAD_DIM
+
+ roi_size = cfg.FAST_RCNN.ROI_XFORM_RESOLUTION
+ self.fc1 = nn.Sequential(nn.Linear(dim_in * roi_size**2, hidden_dim), nn.GroupNorm(net_utils.get_group_gn(hidden_dim), hidden_dim,
+ eps=cfg.GROUP_NORM.EPSILON))
+ self.fc2 = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.GroupNorm(net_utils.get_group_gn(hidden_dim), hidden_dim,
+ eps=cfg.GROUP_NORM.EPSILON))
+
+ self._init_weights()
+
+ def _init_weights(self):
+ def _init(m):
+ if isinstance(m, nn.Conv2d):
+ mynn.init.MSRAFill(m.weight)
+ elif isinstance(m, nn.Linear):
+ mynn.init.XavierFill(m.weight)
+ init.constant_(m.bias, 0)
+ self.apply(_init)
+
+
+ def detectron_weight_mapping(self):
+ detectron_weight_mapping = {
+ 'fc1.weight': 'fc6_w',
+ 'fc1.bias': 'fc6_b',
+ 'fc2.weight': 'fc7_w',
+ 'fc2.bias': 'fc7_b'
+ }
+ return detectron_weight_mapping, []
+
+ def forward(self, x, rpn_ret):
+ x = self.roi_xform(
+ x, rpn_ret,
+ blob_rois='rois',
+ method=cfg.FAST_RCNN.ROI_XFORM_METHOD,
+ resolution=cfg.FAST_RCNN.ROI_XFORM_RESOLUTION,
+ spatial_scale=self.spatial_scale,
+ sampling_ratio=cfg.FAST_RCNN.ROI_XFORM_SAMPLING_RATIO
+ )
+ batch_size = x.size(0)
+ x = F.relu(self.fc1(x.view(batch_size, -1)), inplace=True)
+ x = F.relu(self.fc2(x), inplace=True)
+
+ return x
+
+class roi_2mlp_head_gn_panet(nn.Module):
+ """Add a ReLU MLP with two hidden layers."""
+ def __init__(self, dim_in, roi_xform_func, spatial_scale):
+ super().__init__()
+ self.dim_in = dim_in
+ self.roi_xform = roi_xform_func
+ self.spatial_scale = spatial_scale
+ self.dim_out = hidden_dim = cfg.FAST_RCNN.MLP_HEAD_DIM
+
+ roi_size = cfg.FAST_RCNN.ROI_XFORM_RESOLUTION
+ num_levels = cfg.FPN.ROI_MAX_LEVEL - cfg.FPN.ROI_MIN_LEVEL + 1
+ self.fc1 = nn.ModuleList()
+ for i in range(num_levels):
+ self.fc1.append(nn.Sequential(
+ nn.Linear(dim_in * roi_size**2, hidden_dim),
+ nn.GroupNorm(net_utils.get_group_gn(hidden_dim), hidden_dim,
+ eps=cfg.GROUP_NORM.EPSILON),
+ nn.ReLU(inplace=True)
+ ))
+ #self.fc1 = nn.Sequential(nn.Linear(dim_in * roi_size**2, hidden_dim), nn.GroupNorm(net_utils.get_group_gn(hidden_dim), hidden_dim,
+ # eps=cfg.GROUP_NORM.EPSILON))
+ self.fc2 = nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
+ nn.GroupNorm(net_utils.get_group_gn(hidden_dim), hidden_dim,
+ eps=cfg.GROUP_NORM.EPSILON),
+ nn.ReLU(inplace=True))
+
+ self._init_weights()
+
+ def _init_weights(self):
+ def _init(m):
+ if isinstance(m, nn.Conv2d):
+ mynn.init.MSRAFill(m.weight)
+ elif isinstance(m, nn.Linear):
+ mynn.init.XavierFill(m.weight)
+ init.constant_(m.bias, 0)
+ self.apply(_init)
+
+
+ def detectron_weight_mapping(self):
+ detectron_weight_mapping = {
+ 'fc1.weight': 'fc6_w',
+ 'fc1.bias': 'fc6_b',
+ 'fc2.weight': 'fc7_w',
+ 'fc2.bias': 'fc7_b'
+ }
+ return detectron_weight_mapping, []
+
+ def forward(self, x, rpn_ret):
+ x = self.roi_xform(
+ x, rpn_ret,
+ blob_rois='rois',
+ method=cfg.FAST_RCNN.ROI_XFORM_METHOD,
+ resolution=cfg.FAST_RCNN.ROI_XFORM_RESOLUTION,
+ spatial_scale=self.spatial_scale,
+ sampling_ratio=cfg.FAST_RCNN.ROI_XFORM_SAMPLING_RATIO,
+ panet=True
+ )
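+        # Adaptive feature pooling: run each level through its own fc1, fuse the
+        # levels with an element-wise max, then apply the shared fc2.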
+ batch_size = x[0].size(0)
+ for i in range(len(x)):
+ x[i] = self.fc1[i](x[i].view(batch_size, -1))
+ for i in range(1, len(x)):
+ x[0] = torch.max(x[0], x[i])
+ x = x[0]
+ x = self.fc2(x)
+
+ return x
diff --git a/lib/modeling/mask_rcnn_heads.py b/lib/modeling/mask_rcnn_heads.py
index e71be164..b058a6e4 100644
--- a/lib/modeling/mask_rcnn_heads.py
+++ b/lib/modeling/mask_rcnn_heads.py
@@ -60,14 +60,18 @@ def detectron_weight_mapping(self):
return mapping, orphan_in_detectron
def forward(self, x):
- x = self.classify(x)
+ if not isinstance(x, list):
+ x = self.classify(x)
+ else:
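+            # PANet fully-connected fusion: x = [per-class FCN logits, class-agnostic
+            # fc mask]; reshape the fc mask, broadcast it over classes, and add it
+            # to the FCN logits.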
+ x[0] = self.classify(x[0])
+ x[1] = x[1].view(-1, 1, cfg.MRCNN.RESOLUTION, cfg.MRCNN.RESOLUTION)
+ x[1] = x[1].repeat(1, cfg.MODEL.NUM_CLASSES, 1, 1)
+ x = x[0] + x[1]
if cfg.MRCNN.UPSAMPLE_RATIO > 1:
x = self.upsample(x)
if not self.training:
x = F.sigmoid(x)
return x
-
-
# def mask_rcnn_losses(mask_pred, rois_mask, rois_label, weight):
# n_rois, n_classes, _, _ = mask_pred.size()
# rois_mask_label = rois_label[weight.data.nonzero().view(-1)]
@@ -115,7 +119,18 @@ def mask_rcnn_fcn_head_v1up4convs_gn(dim_in, roi_xform_func, spatial_scale):
return mask_rcnn_fcn_head_v1upXconvs_gn(
dim_in, roi_xform_func, spatial_scale, 4
)
+
+def mask_rcnn_fcn_head_v1up4convs_gn_adp(dim_in, roi_xform_func, spatial_scale):
+ """v1up design: 4 * (conv 3x3), convT 2x2, with GroupNorm"""
+ return mask_rcnn_fcn_head_v1upXconvs_gn_adp(
+ dim_in, roi_xform_func, spatial_scale, 4
+ )
+
+def mask_rcnn_fcn_head_v1up4convs_gn_adp_ff(dim_in, roi_xform_func, spatial_scale):
+    """v1up design: 4 * (conv 3x3), convT 2x2, with GroupNorm, adaptive feature
+    pooling and fully-connected fusion"""
+ return mask_rcnn_fcn_head_v1upXconvs_gn_adp_ff(
+ dim_in, roi_xform_func, spatial_scale, 4
+ )
def mask_rcnn_fcn_head_v1up(dim_in, roi_xform_func, spatial_scale):
"""v1up design: 2 * (conv 3x3), convT 2x2."""
@@ -383,3 +398,200 @@ def ResNet_roi_conv5_head_for_masks(dim_in):
stride_init = cfg.MRCNN.ROI_XFORM_RESOLUTION // 7 # by default: 2
module, dim_out = ResNet.add_stage(dim_in, 2048, 512, 3, dilation, stride_init)
return module, dim_out
+
+
+
+
+
+class mask_rcnn_fcn_head_v1upXconvs_gn_adp(nn.Module):
+ """v1upXconvs design: X * (conv 3x3), convT 2x2, with GroupNorm"""
+ def __init__(self, dim_in, roi_xform_func, spatial_scale, num_convs):
+ super().__init__()
+ self.dim_in = dim_in
+ self.roi_xform = roi_xform_func
+ self.spatial_scale = spatial_scale
+ self.num_convs = num_convs
+
+ dilation = cfg.MRCNN.DILATION
+ dim_inner = cfg.MRCNN.DIM_REDUCED
+ self.dim_out = dim_inner
+
+ module_list = []
+ for i in range(num_convs - 1):
+ module_list.extend([
+ nn.Conv2d(dim_in, dim_inner, 3, 1, padding=1*dilation, dilation=dilation, bias=False),
+ nn.GroupNorm(net_utils.get_group_gn(dim_inner), dim_inner, eps=cfg.GROUP_NORM.EPSILON),
+ nn.ReLU(inplace=True)
+ ])
+ dim_in = dim_inner
+ self.conv_fcn = nn.Sequential(*module_list)
+
+ self.mask_conv1 = nn.ModuleList()
+ num_levels = cfg.FPN.ROI_MAX_LEVEL - cfg.FPN.ROI_MIN_LEVEL + 1
+ for i in range(num_levels):
+ self.mask_conv1.append(nn.Sequential(
+ nn.Conv2d(dim_in, dim_inner, 3, 1, padding=1*dilation, dilation=dilation, bias=False),
+ nn.GroupNorm(net_utils.get_group_gn(dim_inner), dim_inner, eps=cfg.GROUP_NORM.EPSILON),
+ nn.ReLU(inplace=True)
+ ))
+
+
+ # upsample layer
+ self.upconv = nn.ConvTranspose2d(dim_inner, dim_inner, 2, 2, 0)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+ if cfg.MRCNN.CONV_INIT == 'GaussianFill':
+ init.normal_(m.weight, std=0.001)
+ elif cfg.MRCNN.CONV_INIT == 'MSRAFill':
+ mynn.init.MSRAFill(m.weight)
+ else:
+ raise ValueError
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+ def detectron_weight_mapping(self):
+ mapping_to_detectron = {}
+ for i in range(self.num_convs):
+ mapping_to_detectron.update({
+ 'conv_fcn.%d.weight' % (3*i): '_mask_fcn%d_w' % (i+1),
+ 'conv_fcn.%d.weight' % (3*i+1): '_mask_fcn%d_gn_s' % (i+1),
+ 'conv_fcn.%d.bias' % (3*i+1): '_mask_fcn%d_gn_b' % (i+1)
+ })
+ mapping_to_detectron.update({
+ 'upconv.weight': 'conv5_mask_w',
+ 'upconv.bias': 'conv5_mask_b'
+ })
+
+ return mapping_to_detectron, []
+
+ def forward(self, x, rpn_ret):
+ x = self.roi_xform(
+ x, rpn_ret,
+ blob_rois='mask_rois',
+ method=cfg.MRCNN.ROI_XFORM_METHOD,
+ resolution=cfg.MRCNN.ROI_XFORM_RESOLUTION,
+ spatial_scale=self.spatial_scale,
+ sampling_ratio=cfg.MRCNN.ROI_XFORM_SAMPLING_RATIO,
+ panet=True
+ )
+ for i in range(len(x)):
+ x[i] = self.mask_conv1[i](x[i])
+ for i in range(1, len(x)):
+ x[0] = torch.max(x[0], x[i])
+ x = x[0]
+ x = self.conv_fcn(x)
+ return F.relu(self.upconv(x), inplace=True)
+
+
+
+class mask_rcnn_fcn_head_v1upXconvs_gn_adp_ff(nn.Module):
+ """v1upXconvs design: X * (conv 3x3), convT 2x2, with GroupNorm"""
+ def __init__(self, dim_in, roi_xform_func, spatial_scale, num_convs):
+ super().__init__()
+ self.dim_in = dim_in
+ self.roi_xform = roi_xform_func
+ self.spatial_scale = spatial_scale
+ self.num_convs = num_convs
+
+ dilation = cfg.MRCNN.DILATION
+ dim_inner = cfg.MRCNN.DIM_REDUCED
+ self.dim_out = dim_inner
+
+ module_list = []
+ for i in range(2):
+ module_list.extend([
+ nn.Conv2d(dim_in, dim_inner, 3, 1, padding=1*dilation, dilation=dilation, bias=False),
+ nn.GroupNorm(net_utils.get_group_gn(dim_inner), dim_inner, eps=cfg.GROUP_NORM.EPSILON),
+ nn.ReLU(inplace=True)
+ ])
+ dim_in = dim_inner
+ self.conv_fcn = nn.Sequential(*module_list)
+
+ self.mask_conv1 = nn.ModuleList()
+ num_levels = cfg.FPN.ROI_MAX_LEVEL - cfg.FPN.ROI_MIN_LEVEL + 1
+ for i in range(num_levels):
+ self.mask_conv1.append(nn.Sequential(
+ nn.Conv2d(dim_in, dim_inner, 3, 1, padding=1*dilation, dilation=dilation, bias=False),
+ nn.GroupNorm(net_utils.get_group_gn(dim_inner), dim_inner, eps=cfg.GROUP_NORM.EPSILON),
+ nn.ReLU(inplace=True)
+ ))
+
+ self.mask_conv4 = nn.Sequential(
+ nn.Conv2d(dim_in, dim_inner, 3, 1, padding=1*dilation, dilation=dilation, bias=False),
+ nn.GroupNorm(net_utils.get_group_gn(dim_inner), dim_inner, eps=cfg.GROUP_NORM.EPSILON),
+ nn.ReLU(inplace=True))
+
+ self.mask_conv4_fc = nn.Sequential(
+ nn.Conv2d(dim_in, dim_inner, 3, 1, padding=1*dilation, dilation=dilation, bias=False),
+ nn.GroupNorm(net_utils.get_group_gn(dim_inner), dim_inner, eps=cfg.GROUP_NORM.EPSILON),
+ nn.ReLU(inplace=True))
+
+ self.mask_conv5_fc = nn.Sequential(
+ nn.Conv2d(dim_in, int(dim_inner / 2), 3, 1, padding=1*dilation, dilation=dilation, bias=False),
+ nn.GroupNorm(net_utils.get_group_gn(dim_inner), int(dim_inner / 2), eps=cfg.GROUP_NORM.EPSILON),
+ nn.ReLU(inplace=True))
+
+ self.mask_fc = nn.Sequential(
+ nn.Linear(int(dim_inner / 2) * (cfg.MRCNN.ROI_XFORM_RESOLUTION) ** 2, cfg.MRCNN.RESOLUTION ** 2, bias=True),
+ nn.ReLU(inplace=True))
+
+
+
+ # upsample layer
+ self.upconv = nn.ConvTranspose2d(dim_inner, dim_inner, 2, 2, 0)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+ if cfg.MRCNN.CONV_INIT == 'GaussianFill':
+ init.normal_(m.weight, std=0.001)
+ elif cfg.MRCNN.CONV_INIT == 'MSRAFill':
+ mynn.init.MSRAFill(m.weight)
+ else:
+ raise ValueError
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.01)
+ init.constant_(m.bias, 0)
+
+ def detectron_weight_mapping(self):
+ mapping_to_detectron = {}
+ for i in range(self.num_convs):
+ mapping_to_detectron.update({
+ 'conv_fcn.%d.weight' % (3*i): '_mask_fcn%d_w' % (i+1),
+ 'conv_fcn.%d.weight' % (3*i+1): '_mask_fcn%d_gn_s' % (i+1),
+ 'conv_fcn.%d.bias' % (3*i+1): '_mask_fcn%d_gn_b' % (i+1)
+ })
+ mapping_to_detectron.update({
+ 'upconv.weight': 'conv5_mask_w',
+ 'upconv.bias': 'conv5_mask_b'
+ })
+
+ return mapping_to_detectron, []
+
+ def forward(self, x, rpn_ret):
+ x = self.roi_xform(
+ x, rpn_ret,
+ blob_rois='mask_rois',
+ method=cfg.MRCNN.ROI_XFORM_METHOD,
+ resolution=cfg.MRCNN.ROI_XFORM_RESOLUTION,
+ spatial_scale=self.spatial_scale,
+ sampling_ratio=cfg.MRCNN.ROI_XFORM_SAMPLING_RATIO,
+ panet=True
+ )
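+        # Adaptive feature pooling: per-level conv, then element-wise max fusion
+        # across FPN levels.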
+ for i in range(len(x)):
+ x[i] = self.mask_conv1[i](x[i])
+ for i in range(1, len(x)):
+ x[0] = torch.max(x[0], x[i])
+ x = x[0]
+ x = self.conv_fcn(x)
+ batch_size = x.size(0)
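+        # Two branches: the FCN branch (mask_conv4 + upconv) predicts per-class
+        # masks, while the fc branch (mask_conv4_fc, mask_conv5_fc, mask_fc) predicts
+        # a single class-agnostic mask that is later added to the FCN logits.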
+ x_fcn = F.relu(self.upconv(self.mask_conv4(x)), inplace=True)
+ x_ff = self.mask_fc(self.mask_conv5_fc(self.mask_conv4_fc(x)).view(batch_size, -1))
+
+ return [x_fcn, x_ff]
diff --git a/lib/modeling/model_builder.py b/lib/modeling/model_builder.py
index 0c7f1f49..7f1d03db 100644
--- a/lib/modeling/model_builder.py
+++ b/lib/modeling/model_builder.py
@@ -250,7 +250,7 @@ def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
return return_dict
def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoIPoolF',
- resolution=7, spatial_scale=1. / 16., sampling_ratio=0):
+ resolution=7, spatial_scale=1. / 16., sampling_ratio=0, panet=False):
"""Add the specified RoI pooling method. The sampling_ratio argument
is supported for some, but not all, RoI transform methods.
@@ -271,7 +271,10 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI
for lvl in range(k_min, k_max + 1):
bl_in = blobs_in[k_max - lvl] # blobs_in is in reversed order
sc = spatial_scale[k_max - lvl] # in reversed order
- bl_rois = blob_rois + '_fpn' + str(lvl)
+ if not panet:
+ bl_rois = blob_rois + '_fpn' + str(lvl)
+ else:
+ bl_rois = blob_rois
if len(rpn_ret[bl_rois]):
rois = Variable(torch.from_numpy(rpn_ret[bl_rois])).cuda(device_id)
if method == 'RoIPoolF':
@@ -290,17 +293,19 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI
xform_out = RoIAlignFunction(
resolution, resolution, sc, sampling_ratio)(bl_in, rois)
bl_out_list.append(xform_out)
-
- # The pooled features from all levels are concatenated along the
- # batch dimension into a single 4D tensor.
- xform_shuffled = torch.cat(bl_out_list, dim=0)
-
- # Unshuffle to match rois from dataloader
- device_id = xform_shuffled.get_device()
- restore_bl = rpn_ret[blob_rois + '_idx_restore_int32']
- restore_bl = Variable(
- torch.from_numpy(restore_bl.astype('int64', copy=False))).cuda(device_id)
- xform_out = xform_shuffled[restore_bl]
+ if not panet:
+ # The pooled features from all levels are concatenated along the
+ # batch dimension into a single 4D tensor.
+ xform_shuffled = torch.cat(bl_out_list, dim=0)
+
+ # Unshuffle to match rois from dataloader
+ device_id = xform_shuffled.get_device()
+ restore_bl = rpn_ret[blob_rois + '_idx_restore_int32']
+ restore_bl = Variable(
+ torch.from_numpy(restore_bl.astype('int64', copy=False))).cuda(device_id)
+ xform_out = xform_shuffled[restore_bl]
+ else:
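+                # PANet adaptive feature pooling: return the per-FPN-level pooled
+                # features as a list instead of concatenating and unshuffling them.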
+ return bl_out_list
else:
# Single feature level
# rois: holds R regions of interest, each is a 5-tuple