Merge branch 'master' into cvs_ts
fantes committed Feb 5, 2019
2 parents 9c39873 + 8af1f17 commit 8fcad37
Showing 49 changed files with 192,191 additions and 270 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
@@ -218,6 +218,9 @@ if (USE_CAFFE2)
${CAFFE2_PATCHES}/detectron/0002-compiled.patch
${CAFFE2_PATCHES}/detectron/0003-ops.patch
${CAFFE2_PATCHES}/detectron/0004-import.patch
${CAFFE2_PATCHES}/detectron/0005-visual_genome.patch
${CAFFE2_PATCHES}/detectron/0006-pkl_cache.patch
${CAFFE2_PATCHES}/detectron/0007-weight_transfer.patch
)

list(APPEND CAFFE2_OPS
11 changes: 9 additions & 2 deletions examples/caffe2/detectron/README.md
@@ -29,6 +29,8 @@ Download the weights (model_final.pkl) from the [model zoo](https://github.com/f

And find the corresponding training [configuration file](https://github.com/facebookresearch/Detectron/tree/master/configs) (*.yaml)

(You can also find additional models in the ["Learning to Segment Every Thing"](https://github.com/ronghanghu/seg_every_thing#inference) repository.)

Then convert them using the [tool](https://github.com/facebookresearch/Detectron/blob/master/tools/convert_pkl_to_pb.py) provided in the Detectron repository

And finally, place the .pb files into the model repository:
@@ -109,9 +111,14 @@ rmdir $WORKSPACE

- #### With its class labels

If you do not have access to the training dataset or the classes that were used during the training, it may imply that the default 80-classes COCO dataset was used. Its classes are listed [here](https://github.com/facebookresearch/Detectron/blob/master/detectron/datasets/dummy_datasets.py)
If you do not have access to the training dataset or to the classes that were used during training, it is likely that one of the following default datasets was used:

- the 80-class COCO dataset
- the 3000-class Visual Genome dataset

Their classes are listed [here](https://github.com/ronghanghu/seg_every_thing/blob/master/lib/datasets/dummy_datasets.py).

You can also generate a 'corresp.txt' file when converting the model (using the ```--coco``` flag of [this](convert_pkl_to_pb.py) script).
You can also generate a 'corresp.txt' file when converting the model (using the ```--corresp=XXX``` flag of [this](convert_pkl_to_pb.py) script).
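For reference, a minimal sketch (not part of the diff) of the `corresp.txt` format this flag produces: one `index label` pair per line, 81 entries for COCO or 3002 for Visual Genome 3K. The three-entry class list below is a stand-in for the real one from `dummy_datasets`:

```python
# Stand-in class list; the real one comes from detectron's dummy_datasets module.
classes = ['__background__', 'person', 'bicycle']

corresp = '\n'.join('{} {}'.format(i, name) for i, name in enumerate(classes))
print(corresp)
# 0 __background__
# 1 person
# 2 bicycle
```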

### Create a service

63 changes: 50 additions & 13 deletions examples/caffe2/detectron/convert_pkl_to_pb.py
@@ -6,7 +6,6 @@
import sys
import copy
import numpy
import cPickle
import argparse
from caffe2.python import core
from caffe2.proto import caffe2_pb2
@@ -16,10 +15,32 @@
from detectron.core.config import assert_and_infer_cfg
import detectron.core.test_engine as infer_engine
import detectron.utils.c2 as c2_utils
import detectron.utils.io as io_utils
import detectron.utils.model_convert_utils as mutils
import detectron.datasets.dummy_datasets as dummy_datasets
import tools.convert_pkl_to_pb as convert_tools

#############################################################
# Supposed to be used for "seg-every-thing models" training #
# Set for compatibility but currently ignored #
from detectron.utils.collections import AttrDict
cfg.MRCNN.BBOX2MASK = AttrDict()
cfg.MRCNN.BBOX2MASK.BBOX2MASK_ON = False
cfg.MRCNN.BBOX2MASK.TYPE = b''
cfg.MRCNN.BBOX2MASK.USE_PRETRAINED_EMBED = False
cfg.MRCNN.BBOX2MASK.PRETRAINED_EMBED_NAME = b''
cfg.MRCNN.BBOX2MASK.PRETRAINED_EMBED_DIM = -1
cfg.MRCNN.BBOX2MASK.STOP_DET_W_GRAD = True
cfg.MRCNN.BBOX2MASK.INCLUDE_CLS_SCORE = True
cfg.MRCNN.BBOX2MASK.INCLUDE_BBOX_PRED = False
cfg.MRCNN.BBOX2MASK.USE_LEAKYRELU = True
cfg.MRCNN.JOINT_FCN_MLP_HEAD = False
cfg.MRCNN.MLP_MASK_BRANCH_TYPE = b''
cfg.TRAIN.TRAIN_MASK_HEAD_ONLY = False
cfg.TRAIN.MRCNN_FILTER_LABELS = False
cfg.TRAIN.MRCNN_LABELS_TO_KEEP = ()
#############################################################

# Hardcoded values
class Constants:

@@ -36,6 +57,14 @@ class Constants:
    @staticmethod
    def fpn_level_suffix(level): return '_fpn' + str(level)

    ### Defined by Learning to Segment Every Thing

    # In seg_every_thing/lib/modeling/mask_rcnn_heads.py (see bbox2mask_weight_transfer)
    mask_w = 'mask_fcn_logits_w'
    mask_w_flat = 'mask_fcn_logits_w_flat'
    mask_w_flat_inputs = (mask_w_flat + '_w', mask_w_flat + '_b')
    mask_w_size = 3002

    ### Defined by Deepdetect

    # In deepdetect/src/backends/caffe2/nettools/internal.h
@@ -78,8 +107,9 @@ def parse_args():
    parser.add_argument('--cfg', required=True, help='cfg model file')
    parser.add_argument('--out_dir', required=True, help='output directory')
    parser.add_argument('--mask_dir', type=str, help='mask extension directory')
    parser.add_argument('--coco', action='store_true',
                        help='generate a corresp.txt file containing the 81 coco classes')
    parser.add_argument('--corresp', default=None, choices=['coco', 'vg3k'],
                        help='generate a corresp.txt file containing the classes '
                             '(81 for the COCO dataset, 3002 for Visual Genome 3K)')
    parser.add_argument('--net_name', default='detectron',
                        type=str, help='optional name for the net')
    parser.add_argument('--fuse_af', default=1, type=int, help='1 to fuse_af')
@@ -107,9 +137,9 @@ def save_model(net, init_net, path):

def convert_main_net(args, main_net, blobs):
    net = core.Net('')
    net.Proto().op.extend(copy.deepcopy(main_net.Proto().op))
    net.Proto().external_input.extend(copy.deepcopy(main_net.Proto().external_input))
    net.Proto().external_output.extend(copy.deepcopy(main_net.Proto().external_output))
    net.Proto().op.extend(copy.deepcopy(main_net.op))
    net.Proto().external_input.extend(copy.deepcopy(main_net.external_input))
    net.Proto().external_output.extend(copy.deepcopy(main_net.external_output))
    net.Proto().type = args.net_execution_type
    net.Proto().num_workers = 1 if args.net_execution_type == 'simple' else 4
    convert_tools.convert_net(args, net.Proto(), blobs)
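These paired old/new lines reflect an API shift rather than a behavior change: `convert_main_net` (and `convert_mask_net` below) now receive a raw `caffe2_pb2.NetDef` instead of a `core.Net` wrapper, so the inner `.Proto()` calls are dropped and `main()` calls `.Proto()` at the call site. A minimal sketch of the distinction, assuming the stock caffe2 Python bindings:

```python
from caffe2.python import core

net = core.Net('example')  # Python wrapper: ops are reached through net.Proto().op
net_def = net.Proto()      # raw caffe2_pb2.NetDef: ops are reached directly as net_def.op
assert list(net_def.op) == list(net.Proto().op)
```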
@@ -132,8 +162,8 @@ def convert_mask_net(args, mask_net):
    # Initialization net
    init_net = caffe2_pb2.NetDef()
    net = caffe2_pb2.NetDef()
    blobs = cPickle.load(open(args.wts))['blobs']
    externals = set(c2_utils.UnscopeName(inp) for inp in mask_net.Proto().external_input)
    blobs = io_utils.load_object(args.wts)['blobs']
    externals = set(c2_utils.UnscopeName(inp) for inp in mask_net.external_input)
    for name in set(blobs.keys()).intersection(externals):
        blob = blobs[name]
        add_custom_op(init_net, 'GivenTensorFill', [], [name],
@@ -155,7 +185,7 @@
                  canon_level = cfg.FPN.ROI_CANONICAL_LEVEL)

    # Generate the masks
    net.op.extend(mask_net.Proto().op)
    net.op.extend(mask_net.op)

    # Post-process the masks
    add_custom_op(net, 'SegmentMask',
@@ -167,17 +197,24 @@
    init_net.name = args.net_name + '_mask_init'
    save_model(net, init_net, args.mask_dir)

def create_corresp_file(args, dataset):
    classes = dataset.classes
    corresp = '\n'.join('{} {}'.format(i, classes[i]) for i, _ in enumerate(classes))
    with open(args.out_dir + '/corresp.txt', 'w') as f:
        f.write(corresp)

def main():
    args = parse_args()
    merge_cfg_from_file(args.cfg)
    merge_cfg_from_list(args.opts)
    assert_and_infer_cfg()
    model, blobs = convert_tools.load_model(args)
    convert_main_net(args, model.net, blobs)

    convert_main_net(args, model.net.Proto(), blobs)
    if args.mask_dir:
        convert_mask_net(args, model.mask_net)
    if args.coco:
        classes = dummy_datasets.get_coco_dataset().classes
        convert_mask_net(args, model.mask_net.Proto())
    if args.corresp:
        classes = getattr(dummy_datasets, 'get_{}_dataset'.format(args.corresp))().classes
        corresp = '\n'.join('{} {}'.format(i, classes[i]) for i, _ in enumerate(classes))
        with open(args.out_dir + '/corresp.txt', 'w') as f:
            f.write(corresp)
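The `--corresp` branch of `main()` above resolves the dataset getter by name. A sketch of that dispatch, assuming the patched `dummy_datasets` module exposes both `get_coco_dataset` and `get_vg3k_dataset` (the latter coming from the "Learning to Segment Every Thing" code):

```python
import detectron.datasets.dummy_datasets as dummy_datasets

def get_classes(corresp):
    # '--corresp=coco' -> get_coco_dataset(); '--corresp=vg3k' -> get_vg3k_dataset()
    return getattr(dummy_datasets, 'get_{}_dataset'.format(corresp))().classes
```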
2 changes: 1 addition & 1 deletion examples/caffe2/detectron/masks.md
@@ -47,7 +47,7 @@ python $DD_REPO/examples/caffe2/detectron/convert_pkl_to_pb.py \
    --mask_dir deepdetect_model/mask \
    --cfg detectron_model/config.yaml \
    --wts detectron_model/weights.pkl \
    --coco
    --corresp=coco
# Register the service
6 changes: 3 additions & 3 deletions patches/caffe2/detectron/0001-dependencies.patch
@@ -1,7 +1,7 @@
From 90161fb5990829dd3e1315cd8393c46832166bc7 Mon Sep 17 00:00:00 2001
From 61ccadc4ea7d5b2d6513b4c7a56aaee25ce94d70 Mon Sep 17 00:00:00 2001
From: Julien CHICHA <julien.chicha@epitech.eu>
Date: Sat, 17 Nov 2018 17:38:25 +0100
Subject: [PATCH 1/4] dependencies
Subject: [PATCH 1/7] dependencies

---
detectron/core/test.py | 1 -
@@ -99,5 +99,5 @@ index d3c8833..ff63a01 100644
from detectron.utils.colormap import colormap
import detectron.utils.env as envu
--
2.19.1
2.20.1

6 changes: 3 additions & 3 deletions patches/caffe2/detectron/0002-compiled.patch
@@ -1,7 +1,7 @@
From 9b7e79c07ad4a1d93a65524c3e1c9f566bef0572 Mon Sep 17 00:00:00 2001
From 10cbbf3224eeaef4dcaf61216ed82cb239046b1a Mon Sep 17 00:00:00 2001
From: Julien CHICHA <julien.chicha@epitech.eu>
Date: Sat, 17 Nov 2018 17:39:11 +0100
Subject: [PATCH 2/4] compiled
Subject: [PATCH 2/7] compiled

---
detectron/core/test.py | 1 -
@@ -189,5 +189,5 @@ index e8e4637..aac99d8 100644

def parse_args():
--
2.19.1
2.20.1

6 changes: 3 additions & 3 deletions patches/caffe2/detectron/0003-ops.patch
@@ -1,7 +1,7 @@
From 6a3dc72684694d491cfc02fabe77e0d0baa62e4e Mon Sep 17 00:00:00 2001
From f750757337e12f2af7e4b3c7d242eaf431c03d90 Mon Sep 17 00:00:00 2001
From: Julien CHICHA <julien.chicha@epitech.eu>
Date: Sat, 17 Nov 2018 17:40:02 +0100
Subject: [PATCH 3/4] ops
Subject: [PATCH 3/7] ops

---
detectron/tests/test_batch_permutation_op.py | 1 -
@@ -124,5 +124,5 @@ index 9e757b5..afc6e28 100755
# OpenCL may be enabled by default in OpenCV3; disable it because it's not
# thread safe and causes unwanted GPU memory allocations.
--
2.19.1
2.20.1

6 changes: 3 additions & 3 deletions patches/caffe2/detectron/0004-import.patch
@@ -1,7 +1,7 @@
From 73c447c6ebbd06a09f1eb566075d18c3a1f5f104 Mon Sep 17 00:00:00 2001
From 5528725330a7a0b27a374ce8db47361bf60b0c18 Mon Sep 17 00:00:00 2001
From: Julien CHICHA <julien.chicha@epitech.eu>
Date: Sat, 17 Nov 2018 17:51:40 +0100
Subject: [PATCH 4/4] import
Subject: [PATCH 4/7] import

---
detectron/__init__.py | 3 +++
@@ -21,5 +21,5 @@ diff --git a/tools/__init__.py b/tools/__init__.py
new file mode 100644
index 0000000..e69de29
--
2.19.1
2.20.1

28 changes: 28 additions & 0 deletions patches/caffe2/detectron/0005-visual_genome.patch

Large diffs are not rendered by default.

60 changes: 60 additions & 0 deletions patches/caffe2/detectron/0006-pkl_cache.patch
@@ -0,0 +1,60 @@
From e3c030adf444c98c45e072b991de34ab9ef73ed7 Mon Sep 17 00:00:00 2001
From: Julien CHICHA <julien.chicha@epitech.eu>
Date: Wed, 23 Jan 2019 19:15:46 +0100
Subject: [PATCH 6/7] pkl_cache

---
detectron/utils/io.py | 20 ++++++++++++++++++--
1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/detectron/utils/io.py b/detectron/utils/io.py
index 2dbc8b1..0fe5c30 100644
--- a/detectron/utils/io.py
+++ b/detectron/utils/io.py
@@ -33,6 +33,9 @@ logger = logging.getLogger(__name__)

 _DETECTRON_S3_BASE_URL = 'https://s3-us-west-2.amazonaws.com/detectron'
 
+# Keep a reference to pickled objects as a save/load can take several minutes in some cases
+# (e.g. weights from https://github.com/ronghanghu/seg_every_thing)
+_PKL_CACHE = {}
 
 def save_object(obj, file_name, pickle_format=2):
     """Save a Python object by pickling it.
@@ -47,8 +50,17 @@ file is manifested or used, external to the system.
     with open(file_name, 'wb') as f:
         pickle.dump(obj, f, pickle_format)
 
+    # Save in cache
+    _PKL_CACHE[file_name] = obj
 
 def load_object(file_name):
+
+    # Fetch from cache
+    file_name = os.path.abspath(file_name)
+    obj = _PKL_CACHE.get(file_name)
+    if obj is not None:
+        return obj
+
     with open(file_name, 'rb') as f:
         # The default encoding used while unpickling is 7-bit (ASCII.) However,
         # the blobs are arbitrary 8-bit bytes which don't agree. The absolute
@@ -57,9 +69,13 @@ def load_object(file_name):
         # reasonable fix, however, is to treat it the encoding as 8-bit latin1
         # (which agrees with the first 256 characters of Unicode anyway.)
         if six.PY2:
-            return pickle.load(f)
+            obj = pickle.load(f)
         else:
-            return pickle.load(f, encoding='latin1')
+            obj = pickle.load(f, encoding='latin1')
+
+        # Save in cache
+        _PKL_CACHE[file_name] = obj
+        return obj
 
 
 def cache_url(url_or_file, cache_dir):
--
2.20.1
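A condensed, standalone sketch (Python 3 path only; the patch itself also keeps the Python 2 branch via `six`) of the memoization pattern this patch adds:

```python
import os
import pickle

# Pickled objects are memoized by absolute path, so a multi-minute load of a very
# large weight file (e.g. the seg_every_thing weights) only happens once per process.
_PKL_CACHE = {}

def load_object(file_name):
    file_name = os.path.abspath(file_name)
    obj = _PKL_CACHE.get(file_name)
    if obj is None:
        with open(file_name, 'rb') as f:
            # latin1 keeps arbitrary 8-bit blob bytes intact when unpickling
            obj = pickle.load(f, encoding='latin1')
        _PKL_CACHE[file_name] = obj
    return obj
```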
