Skip to content

Commit

Permalink
Add support for unpooling import from caffe and export to pytorch (#776)
Browse files Browse the repository at this point in the history
Signed-off-by: Sebastien ESKENAZI <s.eskenaz@pixelz.com>
  • Loading branch information
SebastienEske authored and rainLiuplus committed Jan 6, 2020
1 parent 55139e6 commit 9f75477
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 6 deletions.
28 changes: 27 additions & 1 deletion mmdnn/conversion/caffe/caffe_emitter.py
Expand Up @@ -172,6 +172,9 @@ def _get_symmetric_padding(self, IR_node):
pad_w = pads[2] + (0 if pads[2] == pads[6] else stride_w)
else:
pad_w = 0
elif IR_node.type == "Unpool":
pad_h = 0
pad_w = 0
else:
pad_h = pads[1] + (0 if pads[1] == pads[5] else stride_h)
pad_w = pads[2] + (0 if pads[2] == pads[6] else stride_w)
Expand Down Expand Up @@ -245,6 +248,9 @@ def compute_output_shape(self, IR_node, kernel_h, kernel_w):
if IR_node.type == 'Pool':
h_o = (h_i + 2 * pad_h - kernel_h + stride_h - 1) // stride_h + 1
w_o = (w_i + 2 * pad_w - kernel_w + stride_w - 1) // stride_w + 1
elif IR_node.type == 'Unpool':
h_o = (h_i - 2 * pad_h - kernel_h + stride_h) * stride_h
w_o = (w_i - 2 * pad_w - kernel_w + stride_w) * stride_w
else:
h_o = (h_i + 2 * pad_h - kernel_h) // stride_h + 1
w_o = (w_i + 2 * pad_w - kernel_w) // stride_w + 1
Expand All @@ -260,7 +266,7 @@ def check_if_need_crop(self, IR_node):
ir_wo = shape[2]
if ir_ho <0 or ir_wo<0:
return
if IR_node.type == 'Pool':
if IR_node.type == 'Pool' or IR_node.type == 'Unpool':
k_h = IR_node.get_attr('kernel_shape')[1]
k_w = IR_node.get_attr('kernel_shape')[2]
else:
Expand Down Expand Up @@ -331,6 +337,26 @@ def emit_Pool(self, IR_node):
# check if need crop output shape
self.check_if_need_crop(IR_node)

def emit_Unpool(self, IR_node):
    """Emit pycaffe NetSpec code for a caffe `Unpooling` layer.

    Chooses between the square-kernel form (`kernel_size=`) and the
    rectangular form (`kernel_h=`/`kernel_w=`), mirroring emit_Pool.

    Note: no padding argument is emitted. The previously computed
    `pad_h, pad_w = self._get_symmetric_padding(IR_node)` was dead code —
    the values were never used, and for Unpool nodes that helper returns
    (0, 0) anyway — so it has been removed.
    """
    # kernel_shape is NHWC-style [1, k_h, k_w, 1]; take the spatial pair.
    pool_size = IR_node.get_attr('kernel_shape')[1:3]
    if pool_size[0] != pool_size[1]:
        # Rectangular kernel: caffe needs kernel_h / kernel_w spelled out.
        self.add_body(1, "n.{:<15} = L.Unpooling(n.{}, kernel_h={}, kernel_w={}, stride={}, ntop=1)".format(
            IR_node.variable_name,
            self.parent_variable_name(IR_node),
            pool_size[0],
            pool_size[1],
            IR_node.get_attr('strides')[1]))
    else:
        self.add_body(1, "n.{:<15} = L.Unpooling(n.{}, kernel_size={}, stride={}, ntop=1)".format(
            IR_node.variable_name,
            self.parent_variable_name(IR_node),
            pool_size[0],
            IR_node.get_attr('strides')[1]))

    # caffe may compute a different output shape than the IR expects;
    # emit a crop layer if so (same post-step as emit_Pool).
    self.check_if_need_crop(IR_node)

def emit_ResizeBilinear(self, IR_node):
shape = IR_node.get_attr("_output_shapes")[0]
shape = shape_to_list(shape)
Expand Down
3 changes: 2 additions & 1 deletion mmdnn/conversion/caffe/graph.py
Expand Up @@ -89,6 +89,7 @@
'MultinomialLogisticLoss': shape_scalar,
'MVN': shape_not_implemented,
'Pooling': shape_pool,
'Unpooling': shape_unpool,
'Power': shape_identity,
'ReLU': shape_identity,
'Scale': shape_identity,
Expand Down Expand Up @@ -191,7 +192,7 @@ def get_kernel_value(scalar, repeated, idx, default=None):

@property
def kernel_parameters(self):
assert self.kind in (NodeKind.Convolution, NodeKind.Pooling, NodeKind.Deconvolution)
assert self.kind in (NodeKind.Convolution, NodeKind.Pooling, NodeKind.Unpooling, NodeKind.Deconvolution)
params = self.parameters
global_pooling = hasattr(params, 'global_pooling') and params.global_pooling
if not global_pooling:
Expand Down
12 changes: 11 additions & 1 deletion mmdnn/conversion/caffe/mapper.py
Expand Up @@ -67,7 +67,7 @@ def get_kernel_params(cls, node, input_shape):
else:
o_h_tf = (input_shape.height + node.kernel_parameters.p_h * 2 - ko_h + 1) // node.kernel_parameters.s_h
o_w_tf = (input_shape.width + node.kernel_parameters.p_w * 2 - ko_w + 1) // node.kernel_parameters.s_w

kwargs['pads'] = [0, node.kernel_parameters.p_h, node.kernel_parameters.p_w, 0] + \
[0, node.kernel_parameters.p_h + o_h_caffe - o_h_tf, node.kernel_parameters.p_w + o_w_caffe - o_w_tf, 0]

Expand Down Expand Up @@ -183,6 +183,16 @@ def map_pooling(cls, node):
return Node.create('Pool', **kwargs)


@classmethod
def map_unpooling(cls, node):
    """Map a caffe `Unpooling` layer onto an IR 'Unpool' node.

    Packs the kernel/pad/stride scalars into NHWC-ordered 4-lists
    ([1, h, w, 1] for kernel/strides, [0, p_h, p_w, 0] for pads),
    the same layout map_pooling uses, then attaches the output shape.
    """
    kp = node.kernel_parameters
    kwargs = {
        'kernel_shape': [1, kp.k_h, kp.k_w, 1],
        'pads': [0, kp.p_h, kp.p_w, 0],
        'strides': [1, kp.s_h, kp.s_w, 1],
    }
    cls._convert_output_shape(kwargs, node)
    return Node.create('Unpool', **kwargs)


@classmethod
def _add_flatten_layer(cls, node):
shape = TensorShape()
Expand Down
4 changes: 4 additions & 0 deletions mmdnn/conversion/caffe/network.py
Expand Up @@ -69,6 +69,10 @@ def sigmoid(self, input, name):
def max_pool(self, input, k_h, k_w, s_h, s_w, p_h, p_w, name):
raise NotImplementedError('Must be implemented by the subclass')

@layer
def max_unpool(self, input, k_h, k_w, s_h, s_w, p_h, p_w, name):
    """Abstract max-unpooling layer.

    Signature mirrors max_pool above: k_h/k_w kernel size, s_h/s_w
    strides, p_h/p_w padding, plus the layer name. Concrete backend
    subclasses must override this.
    """
    raise NotImplementedError('Must be implemented by the subclass')

@layer
def avg_pool(self, input, k_h, k_w, s_h, s_w, p_h, p_w, name):
raise NotImplementedError('Must be implemented by the subclass')
Expand Down
2 changes: 2 additions & 0 deletions mmdnn/conversion/caffe/shape.py
Expand Up @@ -115,6 +115,8 @@ def shape_pool(node):
return shape_global_pooling(node)
return get_strided_kernel_output_shape(node, math.ceil)

def shape_unpool(node):
    # Delegates to the same strided-kernel shape helper as shape_pool,
    # with ceil rounding.
    # NOTE(review): unpooling normally *increases* spatial dimensions,
    # while this helper is written for pooling-style (reducing) output
    # shapes — confirm get_strided_kernel_output_shape special-cases the
    # Unpooling kind, otherwise the inferred shape will be wrong.
    return get_strided_kernel_output_shape(node, math.ceil)

def shape_inner_product(node):
input_shape = node.get_only_parent()[0].output_shape
Expand Down
2 changes: 1 addition & 1 deletion mmdnn/conversion/caffe/transformer.py
Expand Up @@ -123,7 +123,7 @@ def map(self, node_kind):
raise ConversionError('Ordering not found for node kind: {}'.format(node_kind))

def _is_image_data(self, node):
    """Return truthy when some child consumes this node as spatial/image data.

    A node is treated as image data if any child is a Convolution,
    Pooling, or Unpooling layer. Previously this built a full list and
    returned its length (an int used as a boolean); `any()` expresses
    the predicate directly and short-circuits on the first match.
    """
    image_kinds = (NodeKind.Convolution, NodeKind.Pooling, NodeKind.Unpooling)
    return any(child.kind in image_kinds for child in node.children)

def __call__(self, graph):
for node in graph.nodes:
Expand Down
57 changes: 57 additions & 0 deletions mmdnn/conversion/common/IR/ops.pbtxt
Expand Up @@ -594,6 +594,63 @@ op {
description: "tf.nn.pool defined"
}

op {
  name: "Unpool"
  attr {
    name: "kernel_shape"
    # Typo fix: "wide" -> "width".
    description: "Shape `[1, depth, height, width, 1]`."
    type: "list(int)"
  }
  attr {
    name: "strides"
    type: "list(int)"
    description: "1-D tensor of length N. [1, stride_deep, stride_height, stride_width, 1]"
  }
  attr {
    # Typo fix: was "audo_pad" — the IR/ONNX attribute name read elsewhere
    # is "auto_pad", matching the other pooling-style ops in this registry.
    name: "auto_pad"
    type: "string"
    description: "The type of padding algorithm to use."
    allowed_values {
      list {
        s: "SAME_UPPER"
        s: "SAME_LOWER"
        s: "VALID"
      }
    }
  }
  attr {
    name: "pads"
    type: "list(int)"
    description: "1-D tensor of length N*2. [x1_begin, x2_begin...x1_end, x2_end,...]"
    allowed_values {
      list {
        i: 0
        f: 0.0
      }
    }
  }
  attr {
    name: "data_format"
    type: "string"
    default_value {
      s: "NHWC"
    }
    allowed_values {
      list {
        s: "NC"
        s: "NWC"
        s: "NCW"
        s: "NHWC"
        s: "NCHW"
        s: "NDHWC"
        s: "NCDHW"
      }
    }
  }
  summary: "Performs an N-D unpooling operation."
  description: "Max unpooling (inverse of max pooling); there is no tf.nn.unpool equivalent."
}

op {
name: "Mul"
summary: "Returns x * y element-wise."
Expand Down
27 changes: 25 additions & 2 deletions mmdnn/conversion/pytorch/pytorch_emitter.py
Expand Up @@ -231,14 +231,16 @@ def emit_Pool(self, IR_node):
pool_size = IR_node.get_attr('kernel_shape')[1:-1]
strides = IR_node.get_attr('strides')[1:-1]

code = "{:<15} = F.{}({}, kernel_size={}, stride={}, padding={}, ceil_mode={})".format(
code = "{}, {}_idx = F.{}({}, kernel_size={}, stride={}, padding={}, ceil_mode={}, return_indices={})".format(
IR_node.variable_name,
IR_node.variable_name,
pool_name,
input_node,
tuple(pool_size),
tuple(strides),
0,
False
False,
True
)
return code

Expand Down Expand Up @@ -267,6 +269,27 @@ def emit_Pool(self, IR_node):
else:
raise ValueError()

def emit_Unpool(self, IR_node):
    """Emit PyTorch code for a max-unpooling op.

    The second IR input (index 1) is the pooling node whose indices are
    consumed; emit_Pool names that tensor `<pool_var>_idx` (it emits
    `return_indices=True`), and this code references it by that name.
    """
    # strides are NHWC-style [1, s_1..s_d, 1]; subtract batch + channel.
    spatial_rank = len(IR_node.get_attr('strides')) - 2
    unpool_fn = "max_unpool{}d".format(spatial_rank)

    data_input = self.parent_variable_name(IR_node)
    indices_owner = self.parent_variable_name(IR_node, [1])

    kernel = tuple(IR_node.get_attr('kernel_shape')[1:-1])
    stride = tuple(IR_node.get_attr('strides')[1:-1])

    return "{:<15} = F.{}({},{}_idx, kernel_size={}, stride={}, padding={})".format(
        IR_node.variable_name,
        unpool_fn,
        data_input,
        indices_owner,
        kernel,
        stride,
        0)


def emit_UNKNOWN(self, IR_node):
print(IR_node.name)
Expand Down

0 comments on commit 9f75477

Please sign in to comment.