diff --git a/demo/models/__init__.py b/demo/models/__init__.py index e843697407850..c6c0b76075bc1 100644 --- a/demo/models/__init__.py +++ b/demo/models/__init__.py @@ -1,5 +1,6 @@ from .mobilenet import MobileNet from .resnet import ResNet34, ResNet50 from .mobilenet_v2 import MobileNetV2 +from .pvanet import PVANet -__all__ = ['MobileNet', 'ResNet34', 'ResNet50', 'MobileNetV2'] +__all__ = ['MobileNet', 'ResNet34', 'ResNet50', 'MobileNetV2', 'PVANet'] diff --git a/demo/models/pvanet.py b/demo/models/pvanet.py new file mode 100644 index 0000000000000..6f5024c94f334 --- /dev/null +++ b/demo/models/pvanet.py @@ -0,0 +1,505 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import paddle +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import MSRA +from paddle.fluid.param_attr import ParamAttr +import os, sys, time, math +import numpy as np +from collections import namedtuple + +BLOCK_TYPE_MCRELU = 'BLOCK_TYPE_MCRELU' +BLOCK_TYPE_INCEP = 'BLOCK_TYPE_INCEP' +BlockConfig = namedtuple('BlockConfig', + 'stride, num_outputs, preact_bn, block_type') + +__all__ = ['PVANet'] + + +class PVANet(): + def __init__(self): + pass + + def net(self, input, include_last_bn_relu=True, class_dim=1000): + conv1 = self._conv_bn_crelu(input, 16, 7, stride=2, name="conv1_1") + pool1 = fluid.layers.pool2d( + input=conv1, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max', + name='pool1') + + end_points = {} + conv2 = self._conv_stage( + pool1, + block_configs=[ + BlockConfig(1, (24, 24, 48), False, BLOCK_TYPE_MCRELU), + BlockConfig(1, (24, 24, 48), True, BLOCK_TYPE_MCRELU), + BlockConfig(1, (24, 24, 48), True, BLOCK_TYPE_MCRELU) + ], + name='conv2', + end_points=end_points) + + conv3 = self._conv_stage( + conv2, + block_configs=[ + BlockConfig(2, (48, 48, 96), True, BLOCK_TYPE_MCRELU), + BlockConfig(1, (48, 48, 96), True, BLOCK_TYPE_MCRELU), + BlockConfig(1, (48, 48, 96), True, BLOCK_TYPE_MCRELU), + BlockConfig(1, (48, 48, 96), True, BLOCK_TYPE_MCRELU) + ], + name='conv3', + end_points=end_points) + + conv4 = self._conv_stage( + conv3, + block_configs=[ + BlockConfig(2, '64 48-96 24-48-48 96 128', True, + BLOCK_TYPE_INCEP), + BlockConfig(1, '64 64-96 24-48-48 128', True, + BLOCK_TYPE_INCEP), + BlockConfig(1, '64 64-96 24-48-48 128', True, + BLOCK_TYPE_INCEP), + BlockConfig(1, '64 64-96 24-48-48 128', True, BLOCK_TYPE_INCEP) + ], + name='conv4', + end_points=end_points) + + conv5 = self._conv_stage( + conv4, + block_configs=[ + BlockConfig(2, '64 96-128 32-64-64 128 196', True, + BLOCK_TYPE_INCEP), + BlockConfig(1, '64 96-128 32-64-64 196', True, + BLOCK_TYPE_INCEP), + BlockConfig(1, '64 96-128 32-64-64 196', True, + BLOCK_TYPE_INCEP), BlockConfig( + 1, '64 96-128 32-64-64 196', True, + BLOCK_TYPE_INCEP) + ], + name='conv5', + end_points=end_points) + + if include_last_bn_relu: + conv5 = self._bn(conv5, 'relu', 'conv5_4_last_bn') + end_points['conv5'] = conv5 + + output = fluid.layers.fc(input=input, + size=class_dim, + act='softmax', + param_attr=ParamAttr( + initializer=MSRA(), name="fc_weights"), + bias_attr=ParamAttr(name="fc_offset")) + + return output + + def _conv_stage(self, input, block_configs, name, end_points): + net = input + for idx, bc in enumerate(block_configs): + if bc.block_type == BLOCK_TYPE_MCRELU: + block_scope = '{}_{}'.format(name, idx + 1) + fn = self._mCReLU + elif bc.block_type == BLOCK_TYPE_INCEP: + block_scope = '{}_{}_incep'.format(name, idx + 1) + 
fn = self._inception_block + net = fn(net, bc, block_scope) + end_points[block_scope] = net + end_points[name] = net + return net + + def _mCReLU(self, input, mc_config, name): + """ + every cReLU has at least three conv steps: + conv_bn_relu, conv_bn_crelu, conv_bn_relu + if the inputs has a different number of channels as crelu output, + an extra 1x1 conv is added before sum. + """ + if mc_config.preact_bn: + conv1_fn = self._bn_relu_conv + conv1_scope = name + '_1' + else: + conv1_fn = self._conv + conv1_scope = name + '_1_conv' + + sub_conv1 = conv1_fn(input, mc_config.num_outputs[0], 1, conv1_scope, + mc_config.stride) + + sub_conv2 = self._bn_relu_conv(sub_conv1, mc_config.num_outputs[1], 3, + name + '_2') + + sub_conv3 = self._bn_crelu_conv(sub_conv2, mc_config.num_outputs[2], 1, + name + '_3') + + if int(input.shape[1]) == mc_config.num_outputs[2]: + conv_proj = input + else: + conv_proj = self._conv(input, mc_config.num_outputs[2], 1, + name + '_proj', mc_config.stride) + + conv = sub_conv3 + conv_proj + return conv + + def _inception_block(self, input, block_config, name): + num_outputs = block_config.num_outputs.split() # e.g. 64 24-48-48 128 + num_outputs = [map(int, s.split('-')) for s in num_outputs] + inception_outputs = num_outputs[-1][0] + num_outputs = num_outputs[:-1] + stride = block_config.stride + pool_path_outputs = None + if stride > 1: + pool_path_outputs = num_outputs[-1][0] + num_outputs = num_outputs[:-1] + + scopes = [['_0']] # follow the name style of caffe pva + kernel_sizes = [[1]] + for path_idx, path_outputs in enumerate(num_outputs[1:]): + path_idx += 1 + path_scopes = ['_{}_reduce'.format(path_idx)] + path_scopes.extend([ + '_{}_{}'.format(path_idx, i - 1) + for i in range(1, len(path_outputs)) + ]) + scopes.append(path_scopes) + + path_kernel_sizes = [1, 3, 3][:len(path_outputs)] + kernel_sizes.append(path_kernel_sizes) + + paths = [] + if block_config.preact_bn: + preact = self._bn(input, 'relu', name + '_bn') + else: + preact = input + + path_params = zip(num_outputs, scopes, kernel_sizes) + for path_idx, path_param in enumerate(path_params): + path_net = preact + for conv_idx, (num_output, scope, + kernel_size) in enumerate(zip(*path_param)): + if conv_idx == 0: + conv_stride = stride + else: + conv_stride = 1 + path_net = self._conv_bn_relu(path_net, num_output, + kernel_size, name + scope, + conv_stride) + paths.append(path_net) + + if stride > 1: + path_net = fluid.layers.pool2d( + input, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max', + name=name + '_pool') + path_net = self._conv_bn_relu(path_net, pool_path_outputs, 1, + name + '_poolproj') + paths.append(path_net) + block_net = fluid.layers.concat(paths, axis=1) + block_net = self._conv(block_net, inception_outputs, 1, + name + '_out_conv') + + if int(input.shape[1]) == inception_outputs: + proj = input + else: + proj = self._conv(input, inception_outputs, 1, name + '_proj', + stride) + return block_net + proj + + def _scale(self, input, name, axis=1, num_axes=1): + assert num_axes == 1, "layer scale not support this num_axes[%d] now" % ( + num_axes) + + prefix = name + '_' + scale_shape = input.shape[axis:axis + num_axes] + param_attr = fluid.ParamAttr(name=prefix + 'gamma') + scale_param = fluid.layers.create_parameter( + shape=scale_shape, + dtype=input.dtype, + name=name, + attr=param_attr, + is_bias=True, + default_initializer=fluid.initializer.Constant(value=1.0)) + + offset_attr = fluid.ParamAttr(name=prefix + 'beta') + offset_param = fluid.layers.create_parameter( + 
shape=scale_shape, + dtype=input.dtype, + name=name, + attr=offset_attr, + is_bias=True, + default_initializer=fluid.initializer.Constant(value=0.0)) + + output = fluid.layers.elementwise_mul( + input, scale_param, axis=axis, name=prefix + 'mul') + output = fluid.layers.elementwise_add( + output, offset_param, axis=axis, name=prefix + 'add') + return output + + def _conv(self, + input, + num_filters, + filter_size, + name, + stride=1, + groups=1, + act=None): + net = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=act, + use_cudnn=True, + param_attr=ParamAttr(name=name + '_weights'), + bias_attr=ParamAttr(name=name + '_bias'), + name=name) + return net + + def _bn(self, input, act, name): + net = fluid.layers.batch_norm( + input=input, + act=act, + name=name, + moving_mean_name=name + '_mean', + moving_variance_name=name + '_variance', + param_attr=ParamAttr(name=name + '_scale'), + bias_attr=ParamAttr(name=name + '_offset')) + return net + + def _bn_relu_conv(self, + input, + num_filters, + filter_size, + name, + stride=1, + groups=1): + + net = self._bn(input, 'relu', name + '_bn') + net = self._conv(net, num_filters, filter_size, name + '_conv', stride, + groups) + return net + + def _conv_bn_relu(self, + input, + num_filters, + filter_size, + name, + stride=1, + groups=1): + net = self._conv(input, num_filters, filter_size, name + '_conv', + stride, groups) + net = self._bn(net, 'relu', name + '_bn') + return net + + def _bn_crelu(self, input, name): + net = self._bn(input, None, name + '_bn_1') + neg_net = fluid.layers.scale(net, scale=-1.0, name=name + '_neg') + net = fluid.layers.concat([net, neg_net], axis=1) + net = self._scale(net, name + '_scale') + net = fluid.layers.relu(net, name=name + '_relu') + return net + + def _conv_bn_crelu(self, + input, + num_filters, + filter_size, + name, + stride=1, + groups=1, + act=None): + net = self._conv(input, num_filters, filter_size, name + '_conv', + stride, groups) + net = self._bn_crelu(net, name) + return net + + def _bn_crelu_conv(self, + input, + num_filters, + filter_size, + name, + stride=1, + groups=1, + act=None): + net = self._bn_crelu(input, name) + net = self._conv(net, num_filters, filter_size, name + '_conv', stride, + groups) + return net + + def deconv_bn_layer(self, + input, + num_filters, + filter_size=4, + stride=2, + padding=1, + act='relu', + name=None): + """Deconv bn layer.""" + deconv = fluid.layers.conv2d_transpose( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + act=None, + param_attr=ParamAttr(name=name + '_weights'), + bias_attr=ParamAttr(name=name + '_bias'), + name=name + 'deconv') + return self._bn(deconv, act, name + '_bn') + + def conv_bn_layer(self, + input, + num_filters, + filter_size, + name, + stride=1, + groups=1): + return self._conv_bn_relu(input, num_filters, filter_size, name, + stride, groups) + + +def Fpn_Fusion(blocks, net): + f = [blocks['conv5'], blocks['conv4'], blocks['conv3'], blocks['conv2']] + num_outputs = [64] * len(f) + g = [None] * len(f) + h = [None] * len(f) + for i in range(len(f)): + h[i] = net.conv_bn_layer(f[i], num_outputs[i], 1, 'fpn_pre_' + str(i)) + + for i in range(len(f) - 1): + if i == 0: + g[i] = net.deconv_bn_layer(h[i], num_outputs[i], name='fpn_0') + else: + out = fluid.layers.elementwise_add(x=g[i - 1], y=h[i]) + out = net.conv_bn_layer(out, num_outputs[i], 1, + 'fpn_trans_' + str(i)) + g[i] = 
net.deconv_bn_layer( + out, num_outputs[i], name='fpn_' + str(i)) + + out = fluid.layers.elementwise_add(x=g[-2], y=h[-1]) + out = net.conv_bn_layer(out, num_outputs[-1], 1, 'fpn_post_0') + out = net.conv_bn_layer(out, num_outputs[-1], 3, 'fpn_post_1') + + return out + + +def Detector_Header(f_common, net, class_num): + """Detector header.""" + f_geo = net.conv_bn_layer(f_common, 64, 1, name='geo_1') + f_geo = net.conv_bn_layer(f_geo, 64, 3, name='geo_2') + f_geo = net.conv_bn_layer(f_geo, 64, 1, name='geo_3') + f_geo = fluid.layers.conv2d( + f_geo, + 8, + 1, + use_cudnn=True, + param_attr=ParamAttr(name='geo_4_conv_weights'), + bias_attr=ParamAttr(name='geo_4_conv_bias'), + name='geo_4_conv') + + name = 'score_class_num' + str(class_num + 1) + f_score = net.conv_bn_layer(f_common, 64, 1, 'score_1') + f_score = net.conv_bn_layer(f_score, 64, 3, 'score_2') + f_score = net.conv_bn_layer(f_score, 64, 1, 'score_3') + f_score = fluid.layers.conv2d( + f_score, + class_num + 1, + 1, + use_cudnn=True, + param_attr=ParamAttr(name=name + '_conv_weights'), + bias_attr=ParamAttr(name=name + '_conv_bias'), + name=name + '_conv') + + f_score = fluid.layers.transpose(f_score, perm=[0, 2, 3, 1]) + f_score = fluid.layers.reshape(f_score, shape=[-1, class_num + 1]) + f_score = fluid.layers.softmax(input=f_score) + + return f_score, f_geo + + +def east(input, class_num=31): + net = PVANet() + out = net.net(input) + blocks = [] + for i, j, k in zip(['conv2', 'conv3', 'conv4', 'conv5'], [1, 2, 4, 8], + [64, 64, 64, 64]): + if j == 1: + conv = net.conv_bn_layer( + out[i], k, 1, name='fusion_' + str(len(blocks))) + elif j <= 4: + conv = net.deconv_bn_layer( + out[i], k, 2 * j, j, j // 2, + name='fusion_' + str(len(blocks))) + else: + conv = net.deconv_bn_layer( + out[i], 32, 8, 4, 2, name='fusion_' + str(len(blocks)) + '_1') + conv = net.deconv_bn_layer( + conv, + k, + j // 2, + j // 4, + j // 8, + name='fusion_' + str(len(blocks)) + '_2') + blocks.append(conv) + conv = fluid.layers.concat(blocks, axis=1) + f_score, f_geo = Detector_Header(conv, net, class_num) + return f_score, f_geo + + +def inference(input, class_num=1, nms_thresh=0.2, score_thresh=0.5): + f_score, f_geo = east(input, class_num) + print("f_geo shape={}".format(f_geo.shape)) + print("f_score shape={}".format(f_score.shape)) + f_score = fluid.layers.transpose(f_score, perm=[1, 0]) + return f_score, f_geo + + +def loss(f_score, f_geo, l_score, l_geo, l_mask, class_num=1): + ''' + predictions: f_score: -1 x 1 x H x W; f_geo: -1 x 8 x H x W + targets: l_score: -1 x 1 x H x W; l_geo: -1 x 1 x H x W; l_mask: -1 x 1 x H x W + return: dice_loss + smooth_l1_loss + ''' + #smooth_l1_loss + channels = 8 + l_geo_split, l_short_edge = fluid.layers.split( + l_geo, num_or_sections=[channels, 1], + dim=1) #last channel is short_edge_norm + f_geo_split = fluid.layers.split(f_geo, num_or_sections=[channels], dim=1) + f_geo_split = f_geo_split[0] + + geo_diff = l_geo_split - f_geo_split + abs_geo_diff = fluid.layers.abs(geo_diff) + l_flag = l_score >= 1 + l_flag = fluid.layers.cast(x=l_flag, dtype="float32") + l_flag = fluid.layers.expand(x=l_flag, expand_times=[1, channels, 1, 1]) + + smooth_l1_sign = abs_geo_diff < l_flag + smooth_l1_sign = fluid.layers.cast(x=smooth_l1_sign, dtype="float32") + + in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + ( + abs_geo_diff - 0.5) * (1.0 - smooth_l1_sign) + l_short_edge = fluid.layers.expand( + x=l_short_edge, expand_times=[1, channels, 1, 1]) + out_loss = l_short_edge * in_loss * l_flag + out_loss = out_loss * 
l_flag
+    smooth_l1_loss = fluid.layers.reduce_mean(out_loss)
+
+    ##softmax_loss
+    l_score.stop_gradient = True
+    l_score = fluid.layers.transpose(l_score, perm=[0, 2, 3, 1])
+    l_score.stop_gradient = True
+    l_score = fluid.layers.reshape(l_score, shape=[-1, 1])
+    l_score.stop_gradient = True
+    l_score = fluid.layers.cast(x=l_score, dtype="int64")
+    l_score.stop_gradient = True
+
+    softmax_loss = fluid.layers.cross_entropy(input=f_score, label=l_score)
+    softmax_loss = fluid.layers.reduce_mean(softmax_loss)
+
+    return softmax_loss, smooth_l1_loss
diff --git a/demo/prune/train.py b/demo/prune/train.py
index a8d923b3b9cda..b9899921b4396 100644
--- a/demo/prune/train.py
+++ b/demo/prune/train.py
@@ -40,11 +40,33 @@
 model_list = [m for m in dir(models) if "__" not in m]
 
 
+def get_pruned_params(args, program):
+    params = []
+    if args.model == "MobileNet":
+        for param in program.global_block().all_parameters():
+            if "_sep_weights" in param.name:
+                params.append(param.name)
+    elif args.model == "MobileNetV2":
+        for param in program.global_block().all_parameters():
+            if "linear_weights" in param.name or "expand_weights" in param.name:
+                params.append(param.name)
+    elif args.model == "ResNet34":
+        for param in program.global_block().all_parameters():
+            if "weights" in param.name and "branch" in param.name:
+                params.append(param.name)
+    elif args.model == "PVANet":
+        for param in program.global_block().all_parameters():
+            if "conv_weights" in param.name:
+                params.append(param.name)
+    return params
+
+
 def piecewise_decay(args):
     step = int(math.ceil(float(args.total_images) / args.batch_size))
     bd = [step * e for e in args.step_epochs]
     lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
     learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
+
     optimizer = fluid.optimizer.Momentum(
         learning_rate=learning_rate,
         momentum=args.momentum_rate,
@@ -176,14 +198,11 @@ def train(epoch, program):
                        end_time - start_time))
             batch_id += 1
 
-    params = []
-    for param in fluid.default_main_program().global_block().all_parameters():
-        if "_sep_weights" in param.name:
-            params.append(param.name)
-    _logger.info("fops before pruning: {}".format(
+    params = get_pruned_params(args, fluid.default_main_program())
+    _logger.info("FLOPs before pruning: {}".format(
         flops(fluid.default_main_program())))
     pruner = Pruner()
-    pruned_val_program = pruner.prune(
+    pruned_val_program, _, _ = pruner.prune(
         val_program,
         fluid.global_scope(),
         params=params,
@@ -191,19 +210,13 @@ def train(epoch, program):
         place=place,
         only_graph=True)
 
-    pruned_program = pruner.prune(
+    pruned_program, _, _ = pruner.prune(
         fluid.default_main_program(),
         fluid.global_scope(),
         params=params,
         ratios=[0.33] * len(params),
         place=place)
-
-    for param in pruned_program[0].global_block().all_parameters():
-        if "weights" in param.name:
-            print param.name, param.shape
-    return
-    _logger.info("fops after pruning: {}".format(flops(pruned_program)))
-
+    _logger.info("FLOPs after pruning: {}".format(flops(pruned_program)))
     for i in range(args.num_epochs):
         train(i, pruned_program)
         if i % args.test_period == 0:
diff --git a/docs/docs/tutorials/pruning_demo.md b/docs/docs/tutorials/pruning_demo.md
new file mode 100755
index 0000000000000..1c97bff599d1b
--- /dev/null
+++ b/docs/docs/tutorials/pruning_demo.md
@@ -0,0 +1,42 @@
+# Convolution channel pruning example
+
+This example shows how to prune the number of channels of every convolution layer by a specified pruning ratio. By default, it automatically downloads and uses the MNIST dataset.
+
+The example currently supports the following classification models:
+
+- MobileNetV1
+- MobileNetV2
+- ResNet50
+- PVANet
+
+## Pruner API
+
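+The snippet below is a minimal usage sketch, not taken verbatim from `train.py`; `train_program`, `place` and `params` are placeholder names. It shows how the pieces fit together: collect the parameter names, call `Pruner.prune` with one ratio per parameter, then re-count FLOPs on the pruned program:
+
+```
+import paddle.fluid as fluid
+from paddleslim.prune import Pruner
+from paddleslim.analysis import flops
+
+pruner = Pruner()
+# params: list of convolution weight names, e.g. from get_pruned_params()
+pruned_program, _, _ = pruner.prune(
+    train_program,                 # fluid.Program to prune
+    fluid.global_scope(),          # scope holding the parameter tensors
+    params=params,
+    ratios=[0.33] * len(params),   # one pruning ratio per parameter
+    place=place)
+print("FLOPs after pruning: {}".format(flops(pruned_program)))
+```
+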
+This example uses the `paddleslim.Pruner` utility class. Please refer to the [API documentation](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/) for details of the user interface.
+
+## Selecting the parameters to prune
+
+Parameter names differ from model to model, so before pruning you need to determine the names of the convolution parameters to be pruned. All parameter names can be listed with:
+
+```
+for param in program.global_block().all_parameters():
+    print("param name: {}; shape: {}".format(param.name, param.shape))
+```
+
+The `train.py` script provides a `get_pruned_params` function that picks the parameters to prune according to the `--model` option.
+
+## Launching the pruning task
+
+Start the pruning task with the following commands:
+
+```
+export CUDA_VISIBLE_DEVICES=0
+python train.py
+```
+
+Run `python train.py --help` to see more options.
+
+## Notes
+
+1. In `paddleslim.Pruner.prune`, the `params` and `ratios` arguments must have the same length.
+
+
diff --git a/paddleslim/analysis/flops.py b/paddleslim/analysis/flops.py
index fed377db89e69..4e710fdc584d6 100644
--- a/paddleslim/analysis/flops.py
+++ b/paddleslim/analysis/flops.py
@@ -36,7 +36,7 @@ def flops(program, only_conv=True, detail=False):
     return _graph_flops(graph, only_conv=only_conv, detail=detail)
 
 
-def _graph_flops(graph, only_conv=False, detail=False):
+def _graph_flops(graph, only_conv=True, detail=False):
     assert isinstance(graph, GraphWrapper)
     flops = 0
     params2flops = {}
@@ -66,12 +66,13 @@ def _graph_flops(graph, only_conv=False, detail=False):
             y_shape = op.inputs("Y")[0].shape()
             if x_shape[0] == -1:
                 x_shape[0] = 1
 
             op_flops = x_shape[0] * x_shape[1] * y_shape[1]
             flops += op_flops
             params2flops[op.inputs("Y")[0].name()] = op_flops
-        elif op.type() in ['relu', 'sigmoid', 'batch_norm', 'relu6'] and not only_conv:
+        elif op.type() in ['relu', 'sigmoid', 'batch_norm', 'relu6'
+                           ] and not only_conv:
             input_shape = list(op.inputs("X")[0].shape())
             if input_shape[0] == -1:
                 input_shape[0] = 1
diff --git a/paddleslim/core/graph_wrapper.py b/paddleslim/core/graph_wrapper.py
index dc01846a10feb..7ed99d2069676 100644
--- a/paddleslim/core/graph_wrapper.py
+++ b/paddleslim/core/graph_wrapper.py
@@ -93,6 +93,8 @@ def outputs(self):
                 ops.append(op)
         return ops
 
+    def is_parameter(self):
+        return isinstance(self._var, Parameter)
 
 class OpWrapper(object):
     def __init__(self, op, graph):
diff --git a/paddleslim/prune/__init__.py b/paddleslim/prune/__init__.py
index d8c439be403ff..2ace7f600c143 100644
--- a/paddleslim/prune/__init__.py
+++ b/paddleslim/prune/__init__.py
@@ -23,6 +23,8 @@
 import sensitive_pruner
 from .sensitive import *
 import sensitive
+from prune_walker import *
+import prune_walker
 
 __all__ = []
 
@@ -32,3 +34,4 @@
 __all__ += controller_client.__all__
 __all__ += sensitive_pruner.__all__
 __all__ += sensitive.__all__
+__all__ += prune_walker.__all__
diff --git a/paddleslim/prune/prune_walker.py b/paddleslim/prune/prune_walker.py
new file mode 100644
index 0000000000000..16edbd992b022
--- /dev/null
+++ b/paddleslim/prune/prune_walker.py
@@ -0,0 +1,525 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
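+
+# Hypothetical usage sketch (comments only, not part of the module's API
+# surface): every op type below registers a PruneWorker subclass in the
+# PRUNE_WORKER registry, and a pruning decision is propagated through the
+# graph roughly like this:
+#
+#     cls = PRUNE_WORKER.get(op.type())        # look up the walker by op type
+#     walker = cls(op, pruned_params=[], visited={})
+#     walker.prune(var, pruned_axis=0, pruned_idx=[0, 1])
+#     # walker.pruned_params now holds (param, axis, idx) records to cut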
+
+import logging
+import numpy as np
+from ..core import Registry
+from ..common import get_logger
+
+__all__ = ["PRUNE_WORKER", "conv2d"]
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+PRUNE_WORKER = Registry('prune_worker')
+
+
+class PruneWorker(object):
+    def __init__(self, op, pruned_params=[], visited={}):
+        """
+        A wrapper of operator used to infer the information of all the related variables.
+
+        Args:
+            op(Operator): The operator to be pruned.
+            pruned_params(list): The list to store the information of pruning that is inferred by the walker.
+            visited(dict): The auxiliary dict to record the visited operators and variables. The key is an encoded string of operator id and variable name.
+
+        Return: An instance of PruneWorker.
+        """
+        self.op = op
+        self.pruned_params = pruned_params
+        self.visited = visited
+
+    def prune(self, var, pruned_axis, pruned_idx):
+        """
+        Infer the shape of variables related with current operator, predecessor and successor.
+        It will search the graph to find all variables related with `var` and record the information of pruning.
+        Args:
+            var(Variable): The root variable of searching. It can be the input or output of current operator.
+            pruned_axis(int): The axis to be pruned of root variable.
+            pruned_idx(list): The indexes to be pruned in `pruned_axis` of root variable.
+        """
+        key = "_".join([str(self.op.idx()), var.name()])
+        if pruned_axis not in self.visited:
+            self.visited[pruned_axis] = {}
+        if key in self.visited[pruned_axis]:
+            return
+        else:
+            self.visited[pruned_axis][key] = True
+        self._prune(var, pruned_axis, pruned_idx)
+
+    def _prune(self, var, pruned_axis, pruned_idx):
+        raise NotImplementedError('Abstract method.')
+
+    def _prune_op(self, op, var, pruned_axis, pruned_idx, visited=None):
+        if op.type().endswith("_grad"):
+            return
+        if visited is not None:
+            self.visited = visited
+        cls = PRUNE_WORKER.get(op.type())
+        assert cls is not None, "The walker of {} is not registered.".format(
+            op.type())
+        _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
+            self.op, op, pruned_axis, var.name()))
+        walker = cls(op,
+                     pruned_params=self.pruned_params,
+                     visited=self.visited)
+        walker.prune(var, pruned_axis, pruned_idx)
+
+
+@PRUNE_WORKER.register
+class conv2d(PruneWorker):
+    def __init__(self, op, pruned_params, visited={}):
+        super(conv2d, self).__init__(op, pruned_params, visited)
+
+    def _prune(self, var, pruned_axis, pruned_idx):
+        data_format = self.op.attr("data_format")
+        channel_axis = 1
+        if data_format == "NHWC":
+            channel_axis = 3
+        if var in self.op.inputs("Input"):
+            assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format(
+                pruned_axis, var.name())
+            filter_var = self.op.inputs("Filter")[0]
+            key = "_".join([str(self.op.idx()), filter_var.name()])
+            self.visited[1][key] = True
+            self.pruned_params.append((filter_var, 1, pruned_idx))
+            for op in filter_var.outputs():
+                self._prune_op(op, filter_var, 1, pruned_idx)
+
+        elif var in self.op.inputs("Filter"):
+            assert pruned_axis in [0, 1]
+
+            self.pruned_params.append((var, pruned_axis, pruned_idx))
+
+            for op in var.outputs():
+                self._prune_op(op, var, pruned_axis, pruned_idx)
+
+            if pruned_axis == 0:
+                if len(self.op.inputs("Bias")) > 0:
+                    self.pruned_params.append(
+                        (self.op.inputs("Bias")[0], channel_axis, pruned_idx))
+                output_var = self.op.outputs("Output")[0]
+                key = "_".join([str(self.op.idx()), output_var.name()])
+                self.visited[channel_axis][key] = True
+                next_ops = output_var.outputs()
+                for op in
next_ops: + self._prune_op(op, output_var, channel_axis, pruned_idx) + + elif pruned_axis == 1: + input_var = self.op.inputs("Input")[0] + key = "_".join([str(self.op.idx()), input_var.name()]) + self.visited[channel_axis][key] = True + pre_ops = input_var.inputs() + for op in pre_ops: + self._prune_op(op, input_var, channel_axis, pruned_idx) + elif var in self.op.outputs("Output"): + assert pruned_axis == channel_axis, "pruned_axis: {}; var: {}".format( + pruned_axis, var.name()) + + filter_var = self.op.inputs("Filter")[0] + key = "_".join([str(self.op.idx()), filter_var.name()]) + self.visited[0][key] = True + + self.pruned_params.append((filter_var, 0, pruned_idx)) + + for op in filter_var.outputs(): + self._prune_op(op, filter_var, 0, pruned_idx) + + if len(self.op.inputs("Bias")) > 0: + self.pruned_params.append( + (self.op.inputs("Bias")[0], channel_axis, pruned_idx)) + + output_var = self.op.outputs("Output")[0] + next_ops = output_var.outputs() + for op in next_ops: + self._prune_op(op, output_var, channel_axis, pruned_idx) + + +@PRUNE_WORKER.register +class batch_norm(PruneWorker): + def __init__(self, op, pruned_params, visited): + super(batch_norm, self).__init__(op, pruned_params, visited) + + def _prune(self, var, pruned_axis, pruned_idx): + if (var not in self.op.outputs("Y")) and ( + var not in self.op.inputs("X")): + return + + if var in self.op.outputs("Y"): + in_var = self.op.inputs("X")[0] + key = "_".join([str(self.op.idx()), in_var.name()]) + self.visited[pruned_axis][key] = True + pre_ops = in_var.inputs() + for op in pre_ops: + self._prune_op(op, in_var, pruned_axis, pruned_idx) + + for param in ["Scale", "Bias", "Mean", "Variance"]: + param_var = self.op.inputs(param)[0] + for op in param_var.outputs(): + self._prune_op(op, param_var, 0, pruned_idx) + self.pruned_params.append((param_var, 0, pruned_idx)) + + out_var = self.op.outputs("Y")[0] + key = "_".join([str(self.op.idx()), out_var.name()]) + self.visited[pruned_axis][key] = True + next_ops = out_var.outputs() + for op in next_ops: + self._prune_op(op, out_var, pruned_axis, pruned_idx) + + +class elementwise_op(PruneWorker): + def __init__(self, op, pruned_params, visited): + super(elementwise_op, self).__init__(op, pruned_params, visited) + + def _prune(self, var, pruned_axis, pruned_idx): + axis = self.op.attr("axis") + if axis == -1: # TODO + axis = 0 + if var in self.op.outputs("Out"): + for name in ["X", "Y"]: + actual_axis = pruned_axis + if name == "Y": + actual_axis = pruned_axis - axis + in_var = self.op.inputs(name)[0] + pre_ops = in_var.inputs() + for op in pre_ops: + self._prune_op(op, in_var, actual_axis, pruned_idx) + + else: + if var in self.op.inputs("X"): + in_var = self.op.inputs("Y")[0] + + if in_var.is_parameter(): + self.pruned_params.append( + (in_var, pruned_axis - axis, pruned_idx)) + pre_ops = in_var.inputs() + for op in pre_ops: + self._prune_op(op, in_var, pruned_axis - axis, pruned_idx) + elif var in self.op.inputs("Y"): + in_var = self.op.inputs("X")[0] + pre_ops = in_var.inputs() + pruned_axis = pruned_axis + axis + for op in pre_ops: + self._prune_op(op, in_var, pruned_axis, pruned_idx) + + out_var = self.op.outputs("Out")[0] + key = "_".join([str(self.op.idx()), out_var.name()]) + self.visited[pruned_axis][key] = True + next_ops = out_var.outputs() + for op in next_ops: + self._prune_op(op, out_var, pruned_axis, pruned_idx) + + +@PRUNE_WORKER.register +class elementwise_add(elementwise_op): + def __init__(self, op, pruned_params, visited): + super(elementwise_add, 
self).__init__(op, pruned_params, visited) + + +@PRUNE_WORKER.register +class elementwise_sub(elementwise_op): + def __init__(self, op, pruned_params, visited): + super(elementwise_sub, self).__init__(op, pruned_params, visited) + + +@PRUNE_WORKER.register +class elementwise_mul(elementwise_op): + def __init__(self, op, pruned_params, visited): + super(elementwise_mul, self).__init__(op, pruned_params, visited) + + +class activation(PruneWorker): + def __init__(self, op, pruned_params, visited): + super(activation, self).__init__(op, pruned_params, visited) + self.input_name = "X" + self.output_name = "Out" + + def _prune(self, var, pruned_axis, pruned_idx): + if var in self.op.outputs(self.output_name): + in_var = self.op.inputs(self.input_name)[0] + pre_ops = in_var.inputs() + for op in pre_ops: + self._prune_op(op, in_var, pruned_axis, pruned_idx) + + out_var = self.op.outputs(self.output_name)[0] + key = "_".join([str(self.op.idx()), out_var.name()]) + self.visited[pruned_axis][key] = True + next_ops = out_var.outputs() + for op in next_ops: + self._prune_op(op, out_var, pruned_axis, pruned_idx) + + +@PRUNE_WORKER.register +class uniform_random_batch_size_like(activation): + def __init__(self, op, pruned_params, visited): + super(uniform_random_batch_size_like, self).__init__(op, pruned_params, + visited) + self.input_name = "Input" + self.output_name = "Out" + + +@PRUNE_WORKER.register +class bilinear_interp(activation): + def __init__(self, op, pruned_params, visited): + super(bilinear_interp, self).__init__(op, pruned_params, visited) + + +@PRUNE_WORKER.register +class relu(activation): + def __init__(self, op, pruned_params, visited): + super(relu, self).__init__(op, pruned_params, visited) + + +@PRUNE_WORKER.register +class floor(activation): + def __init__(self, op, pruned_params, visited): + super(floor, self).__init__(op, pruned_params, visited) + + +@PRUNE_WORKER.register +class relu6(activation): + def __init__(self, op, pruned_params, visited): + super(relu6, self).__init__(op, pruned_params, visited) + + +@PRUNE_WORKER.register +class pool2d(activation): + def __init__(self, op, pruned_params, visited): + super(pool2d, self).__init__(op, pruned_params, visited) + + +@PRUNE_WORKER.register +class sum(PruneWorker): + def __init__(self, op, pruned_params, visited): + super(sum, self).__init__(op, pruned_params, visited) + + def _prune(self, var, pruned_axis, pruned_idx): + if var in self.op.outputs("Out"): + for in_var in self.op.inputs("X"): + pre_ops = in_var.inputs() + for op in pre_ops: + self._prune_op(op, in_var, pruned_axis, pruned_idx) + elif var in self.op.inputs("X"): + for in_var in self.op.inputs("X"): + if in_var != var: + pre_ops = in_var.inputs() + for op in pre_ops: + self._prune_op(op, in_var, pruned_axis, pruned_idx) + out_var = self.op.outputs("Out")[0] + key = "_".join([str(self.op.idx()), out_var.name()]) + self.visited[pruned_axis][key] = True + next_ops = out_var.outputs() + for op in next_ops: + self._prune_op(op, out_var, pruned_axis, pruned_idx) + + +@PRUNE_WORKER.register +class concat(PruneWorker): + def __init__(self, op, pruned_params, visited): + super(concat, self).__init__(op, pruned_params, visited) + + def _prune(self, var, pruned_axis, pruned_idx): + idx = [] + axis = self.op.attr("axis") + if var in self.op.outputs("Out"): + start = 0 + if axis == pruned_axis: + for _, in_var in enumerate(self.op.inputs("X")): + idx = [] + for i in pruned_idx: + r_idx = i - start + if r_idx < in_var.shape()[pruned_axis] and r_idx >= 0: + idx.append(r_idx) + 
start += in_var.shape()[pruned_axis]
+
+                    pre_ops = in_var.inputs()
+                    for op in pre_ops:
+                        self._prune_op(op, in_var, pruned_axis, idx)
+                idx = pruned_idx[:]
+            else:
+                for _, in_var in enumerate(self.op.inputs("X")):
+                    pre_ops = in_var.inputs()
+                    for op in pre_ops:
+                        self._prune_op(op, in_var, pruned_axis, pruned_idx)
+        elif var in self.op.inputs("X"):
+            if axis == pruned_axis:
+                idx = []
+                start = 0
+                for v in self.op.inputs("X"):
+                    if v.name() == var.name():
+                        idx = [i + start for i in pruned_idx]
+                    else:
+                        start += v.shape()[pruned_axis]
+
+                out_var = self.op.outputs("Out")[0]
+                key = "_".join([str(self.op.idx()), out_var.name()])
+                self.visited[pruned_axis][key] = True
+                next_ops = out_var.outputs()
+                for op in next_ops:
+                    self._prune_op(op, out_var, pruned_axis, idx, visited={})
+            else:
+                for v in self.op.inputs("X"):
+                    for op in v.inputs():
+                        self._prune_op(op, v, pruned_axis, pruned_idx)
+                out_var = self.op.outputs("Out")[0]
+                key = "_".join([str(self.op.idx()), out_var.name()])
+                self.visited[pruned_axis][key] = True
+                next_ops = out_var.outputs()
+                for op in next_ops:
+                    self._prune_op(op, out_var, pruned_axis, pruned_idx)
+
+
+@PRUNE_WORKER.register
+class depthwise_conv2d(PruneWorker):
+    def __init__(self, op, pruned_params, visited={}):
+        super(depthwise_conv2d, self).__init__(op, pruned_params, visited)
+
+    def _prune(self, var, pruned_axis, pruned_idx):
+        data_format = self.op.attr("data_format")
+        channel_axis = 1
+        if data_format == "NHWC":
+            channel_axis = 3
+        if var in self.op.inputs("Input"):
+            assert pruned_axis == channel_axis, "The Input of depthwise_conv2d can only be pruned at channel axis, but got {}".format(
+                pruned_axis)
+
+            filter_var = self.op.inputs("Filter")[0]
+            self.pruned_params.append((filter_var, 0, pruned_idx))
+            key = "_".join([str(self.op.idx()), filter_var.name()])
+            self.visited[0][key] = True
+
+            new_groups = filter_var.shape()[0] - len(pruned_idx)
+            self.op.set_attr("groups", new_groups)
+
+            for op in filter_var.outputs():
+                self._prune_op(op, filter_var, 0, pruned_idx)
+
+            output_var = self.op.outputs("Output")[0]
+            next_ops = output_var.outputs()
+            for op in next_ops:
+                self._prune_op(op, output_var, channel_axis, pruned_idx)
+
+        elif var in self.op.inputs("Filter"):
+            assert pruned_axis in [0]
+            if pruned_axis == 0:
+                if len(self.op.inputs("Bias")) > 0:
+                    self.pruned_params.append(
+                        (self.op.inputs("Bias")[0], channel_axis, pruned_idx))
+
+                self.pruned_params.append((var, 0, pruned_idx))
+                new_groups = var.shape()[0] - len(pruned_idx)
+                self.op.set_attr("groups", new_groups)
+
+                for op in var.outputs():
+                    self._prune_op(op, var, 0, pruned_idx)
+
+                output_var = self.op.outputs("Output")[0]
+                key = "_".join([str(self.op.idx()), output_var.name()])
+                self.visited[channel_axis][key] = True
+                next_ops = output_var.outputs()
+                for op in next_ops:
+                    self._prune_op(op, output_var, channel_axis, pruned_idx)
+            for op in var.outputs():
+                self._prune_op(op, var, pruned_axis, pruned_idx)
+        elif var in self.op.outputs("Output"):
+            assert pruned_axis == channel_axis
+            filter_var = self.op.inputs("Filter")[0]
+            self.pruned_params.append((filter_var, 0, pruned_idx))
+            key = "_".join([str(self.op.idx()), filter_var.name()])
+            self.visited[0][key] = True
+
+            new_groups = filter_var.shape()[0] - len(pruned_idx)
+            self.op.set_attr("groups", new_groups)
+
+            for op in filter_var.outputs():
+                self._prune_op(op, filter_var, 0, pruned_idx)
+
+            if len(self.op.inputs("Bias")) > 0:
+                self.pruned_params.append(
+                    (self.op.inputs("Bias")[0], channel_axis, pruned_idx))
+
+            in_var =
self.op.inputs("Input")[0] + key = "_".join([str(self.op.idx()), in_var.name()]) + self.visited[channel_axis][key] = True + pre_ops = in_var.inputs() + for op in pre_ops: + self._prune_op(op, in_var, channel_axis, pruned_idx) + + output_var = self.op.outputs("Output")[0] + next_ops = output_var.outputs() + for op in next_ops: + self._prune_op(op, output_var, channel_axis, pruned_idx) + + +@PRUNE_WORKER.register +class mul(PruneWorker): + def __init__(self, op, pruned_params, visited={}): + super(mul, self).__init__(op, pruned_params, visited) + + def _prune(self, var, pruned_axis, pruned_idx): + if var in self.op.inputs("X"): + assert pruned_axis == 1, "The Input of conv2d can only be pruned at axis 1, but got {}".format( + pruned_axis) + idx = [] + feature_map_size = var.shape()[2] * var.shape()[3] + range_idx = np.array(range(feature_map_size)) + for i in pruned_idx: + idx += list(range_idx + i * feature_map_size) + param_var = self.op.inputs("Y")[0] + self.pruned_params.append((param_var, 0, idx)) + + for op in param_var.outputs(): + self._prune_op(op, param_var, 0, pruned_idx) + + +@PRUNE_WORKER.register +class scale(PruneWorker): + def __init__(self, op, pruned_params, visited={}): + super(scale, self).__init__(op, pruned_params, visited) + + def _prune(self, var, pruned_axis, pruned_idx): + if var in self.op.inputs("X"): + out_var = self.op.outputs("Out")[0] + for op in out_var.outputs(): + self._prune_op(op, out_var, pruned_axis, pruned_idx) + elif var in self.op.outputs("Out"): + in_var = self.op.inputs("X")[0] + for op in in_var.inputs(): + self._prune_op(op, in_var, pruned_axis, pruned_idx) + + +@PRUNE_WORKER.register +class momentum(PruneWorker): + def __init__(self, op, pruned_params, visited={}): + super(momentum, self).__init__(op, pruned_params, visited) + + def _prune(self, var, pruned_axis, pruned_idx): + if var in self.op.inputs("Param"): + _logger.debug("pruning momentum, var:{}".format(var.name())) + velocity_var = self.op.inputs("Velocity")[0] + self.pruned_params.append((velocity_var, pruned_axis, pruned_idx)) + + +@PRUNE_WORKER.register +class adam(PruneWorker): + def __init__(self, op, pruned_params, visited={}): + super(adam, self).__init__(op, pruned_params, visited) + + def _prune(self, var, pruned_axis, pruned_idx): + if var in self.op.inputs("Param"): + _logger.debug("pruning momentum, var:{}".format(var.name())) + moment1_var = self.op.inputs("Moment1")[0] + self.pruned_params.append((moment1_var, pruned_axis, pruned_idx)) + moment2_var = self.op.inputs("Moment2")[0] + self.pruned_params.append((moment2_var, pruned_axis, pruned_idx)) diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py index 95f6774ce5a36..4f442eebf97a2 100644 --- a/paddleslim/prune/pruner.py +++ b/paddleslim/prune/pruner.py @@ -17,6 +17,7 @@ import paddle.fluid as fluid import copy from ..core import VarWrapper, OpWrapper, GraphWrapper +from .prune_walker import conv2d as conv2d_walker from ..common import get_logger __all__ = ["Pruner"] @@ -67,561 +68,60 @@ def prune(self, graph = GraphWrapper(program.clone()) param_backup = {} if param_backup else None param_shape_backup = {} if param_shape_backup else None - self._prune_parameters( - graph, - scope, - params, - ratios, - place, - lazy=lazy, - only_graph=only_graph, - param_backup=param_backup, - param_shape_backup=param_shape_backup) - for op in graph.ops(): - if op.type() == 'depthwise_conv2d' or op.type( - ) == 'depthwise_conv2d_grad': - op.set_attr('groups', op.inputs('Filter')[0].shape()[0]) - return graph.program, 
param_backup, param_shape_backup - - def _prune_filters_by_ratio(self, - scope, - params, - ratio, - place, - lazy=False, - only_graph=False, - param_shape_backup=None, - param_backup=None): - """ - Pruning filters by given ratio. - Args: - scope(fluid.core.Scope): The scope used to pruning filters. - params(list): A list of filter parameters. - ratio(float): The ratio to be pruned. - place(fluid.Place): The device place of filter parameters. - lazy(bool): True means setting the pruned elements to zero. - False means cutting down the pruned elements. - only_graph(bool): True means only modifying the graph. - False means modifying graph and variables in scope. - """ - if params[0].name() in self.pruned_list[0]: - return - - if only_graph: - pruned_num = int(round(params[0].shape()[0] * ratio)) - for param in params: - ori_shape = param.shape() - if param_backup is not None and ( - param.name() not in param_backup): - param_backup[param.name()] = copy.deepcopy(ori_shape) - new_shape = list(ori_shape) - new_shape[0] -= pruned_num - param.set_shape(new_shape) - _logger.debug("prune [{}] from {} to {}".format(param.name( - ), ori_shape, new_shape)) - self.pruned_list[0].append(param.name()) - return range(pruned_num) - - else: - - param_t = scope.find_var(params[0].name()).get_tensor() - pruned_idx = self._cal_pruned_idx( - params[0].name(), np.array(param_t), ratio, axis=0) - for param in params: - assert isinstance(param, VarWrapper) - param_t = scope.find_var(param.name()).get_tensor() - if param_backup is not None and ( - param.name() not in param_backup): - param_backup[param.name()] = copy.deepcopy( - np.array(param_t)) - try: - pruned_param = self._prune_tensor( - np.array(param_t), - pruned_idx, - pruned_axis=0, - lazy=lazy) - except IndexError as e: - _logger.error("Pruning {}, but get [{}]".format(param.name( - ), e)) - - param_t.set(pruned_param, place) - ori_shape = param.shape() - if param_shape_backup is not None and ( - param.name() not in param_shape_backup): - param_shape_backup[param.name()] = copy.deepcopy( - param.shape()) - new_shape = list(param.shape()) - new_shape[0] = pruned_param.shape[0] - param.set_shape(new_shape) - _logger.debug("prune [{}] from {} to {}".format(param.name( - ), ori_shape, new_shape)) - self.pruned_list[0].append(param.name()) - return pruned_idx - - def _prune_parameter_by_idx(self, - scope, - params, - pruned_idx, - pruned_axis, - place, - lazy=False, - only_graph=False, - param_shape_backup=None, - param_backup=None): - """ - Pruning parameters in given axis. - Args: - scope(fluid.core.Scope): The scope storing paramaters to be pruned. - params(VarWrapper): The parameter to be pruned. - pruned_idx(list): The index of elements to be pruned. - pruned_axis(int): The pruning axis. - place(fluid.Place): The device place of filter parameters. - lazy(bool): True means setting the pruned elements to zero. - False means cutting down the pruned elements. - only_graph(bool): True means only modifying the graph. - False means modifying graph and variables in scope. 
- """ - if params[0].name() in self.pruned_list[pruned_axis]: - return - if only_graph: - pruned_num = len(pruned_idx) - for param in params: - ori_shape = param.shape() - if param_backup is not None and ( - param.name() not in param_backup): - param_backup[param.name()] = copy.deepcopy(ori_shape) - new_shape = list(ori_shape) - new_shape[pruned_axis] -= pruned_num - param.set_shape(new_shape) - _logger.debug("prune [{}] from {} to {}".format(param.name( - ), ori_shape, new_shape)) - self.pruned_list[pruned_axis].append(param.name()) - - else: - for param in params: - assert isinstance(param, VarWrapper) - param_t = scope.find_var(param.name()).get_tensor() - if param_backup is not None and ( - param.name() not in param_backup): - param_backup[param.name()] = copy.deepcopy( - np.array(param_t)) - pruned_param = self._prune_tensor( - np.array(param_t), pruned_idx, pruned_axis, lazy=lazy) - param_t.set(pruned_param, place) - ori_shape = param.shape() - if param_shape_backup is not None and ( - param.name() not in param_shape_backup): - param_shape_backup[param.name()] = copy.deepcopy( - param.shape()) - new_shape = list(param.shape()) - new_shape[pruned_axis] = pruned_param.shape[pruned_axis] - param.set_shape(new_shape) - _logger.debug("prune [{}] from {} to {}".format(param.name( - ), ori_shape, new_shape)) - self.pruned_list[pruned_axis].append(param.name()) - - def _forward_search_related_op(self, graph, node): - """ - Forward search operators that will be affected by pruning of param. - Args: - graph(GraphWrapper): The graph to be searched. - node(VarWrapper|OpWrapper): The current pruned parameter or operator. - Returns: - list: A list of operators. - """ visited = {} - for op in graph.ops(): - visited[op.idx()] = False - stack = [] - visit_path = [] - if isinstance(node, VarWrapper): - for op in graph.ops(): - if (not op.is_bwd_op()) and (node in op.all_inputs()): - next_ops = self._get_next_unvisited_op(graph, visited, op) - # visit_path.append(op) - visited[op.idx()] = True - for next_op in next_ops: - if visited[next_op.idx()] == False: - stack.append(next_op) - visit_path.append(next_op) - visited[next_op.idx()] = True - elif isinstance(node, OpWrapper): - next_ops = self._get_next_unvisited_op(graph, visited, node) - for next_op in next_ops: - if visited[next_op.idx()] == False: - stack.append(next_op) - visit_path.append(next_op) - visited[next_op.idx()] = True - while len(stack) > 0: - #top_op = stack[len(stack) - 1] - top_op = stack.pop(0) - next_ops = None - if top_op.type() in ["conv2d", "deformable_conv"]: - next_ops = None - elif top_op.type() in ["mul", "concat"]: - next_ops = None - else: - next_ops = self._get_next_unvisited_op(graph, visited, top_op) - if next_ops != None: - for op in next_ops: - if visited[op.idx()] == False: - stack.append(op) - visit_path.append(op) - visited[op.idx()] = True - - return visit_path - - def _get_next_unvisited_op(self, graph, visited, top_op): - """ - Get next unvisited adjacent operators of given operators. - Args: - graph(GraphWrapper): The graph used to search. - visited(list): The ids of operators that has been visited. - top_op: The given operator. - Returns: - list: A list of operators. - """ - assert isinstance(top_op, OpWrapper) - next_ops = [] - for op in graph.next_ops(top_op): - if (visited[op.idx()] == False) and (not op.is_bwd_op()): - next_ops.append(op) - return next_ops - - def _get_accumulator(self, graph, param): - """ - Get accumulators of given parameter. The accumulator was created by optimizer. 
- Args: - graph(GraphWrapper): The graph used to search. - param(VarWrapper): The given parameter. - Returns: - list: A list of accumulators which are variables. - """ - assert isinstance(param, VarWrapper) - params = [] - for op in param.outputs(): - if op.is_opt_op(): - for out_var in op.all_outputs(): - if graph.is_persistable(out_var) and out_var.name( - ) != param.name(): - params.append(out_var) - return params - - def _forward_pruning_ralated_params(self, - graph, - scope, - param, - place, - ratio=None, - pruned_idxs=None, - lazy=False, - only_graph=False, - param_backup=None, - param_shape_backup=None): - """ - Pruning all the parameters affected by the pruning of given parameter. - Args: - graph(GraphWrapper): The graph to be searched. - scope(fluid.core.Scope): The scope storing paramaters to be pruned. - param(VarWrapper): The given parameter. - place(fluid.Place): The device place of filter parameters. - ratio(float): The target ratio to be pruned. - pruned_idx(list): The index of elements to be pruned. - lazy(bool): True means setting the pruned elements to zero. - False means cutting down the pruned elements. - only_graph(bool): True means only modifying the graph. - False means modifying graph and variables in scope. - """ - assert isinstance( - graph, - GraphWrapper), "graph must be instance of slim.core.GraphWrapper" - assert isinstance( - param, - VarWrapper), "param must be instance of slim.core.VarWrapper" - - if param.name() in self.pruned_list[0]: - return - related_ops = self._forward_search_related_op(graph, param) - for op in related_ops: - _logger.debug("relate op: {};".format(op)) - if ratio is None: - assert pruned_idxs is not None - self._prune_parameter_by_idx( - scope, [param] + self._get_accumulator(graph, param), - pruned_idxs, - pruned_axis=0, - place=place, - lazy=lazy, - only_graph=only_graph, - param_backup=param_backup, - param_shape_backup=param_shape_backup) - - else: - pruned_idxs = self._prune_filters_by_ratio( - scope, [param] + self._get_accumulator(graph, param), - ratio, - place, - lazy=lazy, - only_graph=only_graph, - param_backup=param_backup, - param_shape_backup=param_shape_backup) - self._prune_ops(related_ops, pruned_idxs, graph, scope, place, lazy, - only_graph, param_backup, param_shape_backup) - - def _prune_ops(self, ops, pruned_idxs, graph, scope, place, lazy, - only_graph, param_backup, param_shape_backup): - for idx, op in enumerate(ops): - if op.type() in ["conv2d", "deformable_conv"]: - for in_var in op.all_inputs(): - if graph.is_parameter(in_var): - conv_param = in_var - self._prune_parameter_by_idx( - scope, [conv_param] + self._get_accumulator( - graph, conv_param), - pruned_idxs, - pruned_axis=1, - place=place, - lazy=lazy, - only_graph=only_graph, - param_backup=param_backup, - param_shape_backup=param_shape_backup) - if op.type() == "depthwise_conv2d": - for in_var in op.all_inputs(): - if graph.is_parameter(in_var): - conv_param = in_var - self._prune_parameter_by_idx( - scope, [conv_param] + self._get_accumulator( - graph, conv_param), - pruned_idxs, - pruned_axis=0, - place=place, - lazy=lazy, - only_graph=only_graph, - param_backup=param_backup, - param_shape_backup=param_shape_backup) - elif op.type() == "elementwise_add": - # pruning bias - for in_var in op.all_inputs(): - if graph.is_parameter(in_var): - bias_param = in_var - self._prune_parameter_by_idx( - scope, [bias_param] + self._get_accumulator( - graph, bias_param), - pruned_idxs, - pruned_axis=0, - place=place, - lazy=lazy, - only_graph=only_graph, - 
param_backup=param_backup, - param_shape_backup=param_shape_backup) - elif op.type() == "mul": # pruning fc layer - fc_input = None - fc_param = None - for in_var in op.all_inputs(): - if graph.is_parameter(in_var): - fc_param = in_var - else: - fc_input = in_var - - idx = [] - feature_map_size = fc_input.shape()[2] * fc_input.shape()[3] - range_idx = np.array(range(feature_map_size)) - for i in pruned_idxs: - idx += list(range_idx + i * feature_map_size) - corrected_idxs = idx - self._prune_parameter_by_idx( - scope, [fc_param] + self._get_accumulator(graph, fc_param), - corrected_idxs, - pruned_axis=0, - place=place, - lazy=lazy, - only_graph=only_graph, - param_backup=param_backup, - param_shape_backup=param_shape_backup) - - elif op.type() == "concat": - concat_inputs = op.all_inputs() - last_op = ops[idx - 1] - concat_idx = None - for last_op in reversed(ops): - for out_var in last_op.all_outputs(): - if out_var in concat_inputs: - concat_idx = concat_inputs.index(out_var) - break - if concat_idx is not None: - break - offset = 0 - for ci in range(concat_idx): - offset += concat_inputs[ci].shape()[1] - corrected_idxs = [x + offset for x in pruned_idxs] - related_ops = self._forward_search_related_op(graph, op) - - for op in related_ops: - _logger.debug("concat relate op: {};".format(op)) - - self._prune_ops(related_ops, corrected_idxs, graph, scope, - place, lazy, only_graph, param_backup, - param_shape_backup) - elif op.type() == "batch_norm": - bn_inputs = op.all_inputs() - in_num = len(bn_inputs) - beta = bn_inputs[0] - mean = bn_inputs[1] - alpha = bn_inputs[2] - variance = bn_inputs[3] - self._prune_parameter_by_idx( - scope, [mean] + self._get_accumulator(graph, mean), - pruned_idxs, - pruned_axis=0, - place=place, - lazy=lazy, - only_graph=only_graph, - param_backup=param_backup, - param_shape_backup=param_shape_backup) - self._prune_parameter_by_idx( - scope, [variance] + self._get_accumulator(graph, variance), - pruned_idxs, - pruned_axis=0, - place=place, - lazy=lazy, - only_graph=only_graph, - param_backup=param_backup, - param_shape_backup=param_shape_backup) - self._prune_parameter_by_idx( - scope, [alpha] + self._get_accumulator(graph, alpha), - pruned_idxs, - pruned_axis=0, - place=place, - lazy=lazy, - only_graph=only_graph, - param_backup=param_backup, - param_shape_backup=param_shape_backup) - self._prune_parameter_by_idx( - scope, [beta] + self._get_accumulator(graph, beta), - pruned_idxs, - pruned_axis=0, - place=place, - lazy=lazy, - only_graph=only_graph, - param_backup=param_backup, - param_shape_backup=param_shape_backup) - - def _prune_parameters(self, - graph, - scope, - params, - ratios, - place, - lazy=False, - only_graph=False, - param_backup=None, - param_shape_backup=None): - """ - Pruning the given parameters. - Args: - graph(GraphWrapper): The graph to be searched. - scope(fluid.core.Scope): The scope storing paramaters to be pruned. - params(list): A list of parameter names to be pruned. - ratios(list): A list of ratios to be used to pruning parameters. - place(fluid.Place): The device place of filter parameters. - pruned_idx(list): The index of elements to be pruned. - lazy(bool): True means setting the pruned elements to zero. - False means cutting down the pruned elements. - only_graph(bool): True means only modifying the graph. - False means modifying graph and variables in scope. 
- """ - assert len(params) == len(ratios) - self.pruned_list = [[], []] + pruned_params = [] for param, ratio in zip(params, ratios): - assert isinstance(param, str) or isinstance(param, unicode) - if param in self.pruned_list[0]: - _logger.info("Skip {}".format(param)) - continue - _logger.info("pruning param: {}".format(param)) + if only_graph: + param_v = graph.var(param) + pruned_num = int(round(param_v.shape()[0] * ratio)) + pruned_idx = [0] * pruned_num + else: + param_t = np.array(scope.find_var(param).get_tensor()) + pruned_idx = self._cal_pruned_idx(param_t, ratio, axis=0) param = graph.var(param) - self._forward_pruning_ralated_params( - graph, - scope, - param, - place, - ratio=ratio, - lazy=lazy, - only_graph=only_graph, - param_backup=param_backup, - param_shape_backup=param_shape_backup) - ops = param.outputs() - for op in ops: - if op.type() in ['conv2d', 'deformable_conv']: - brother_ops = self._search_brother_ops(graph, op) - for broher in brother_ops: - _logger.debug("pruning brother: {}".format(broher)) - for p in graph.get_param_by_op(broher): - self._forward_pruning_ralated_params( - graph, - scope, - p, - place, - ratio=ratio, - lazy=lazy, - only_graph=only_graph, - param_backup=param_backup, - param_shape_backup=param_shape_backup) - - def _search_brother_ops(self, graph, op_node): - """ - Search brother operators that was affected by pruning of given operator. - Args: - graph(GraphWrapper): The graph to be searched. - op_node(OpWrapper): The start node for searching. - Returns: - list: A list of operators. - """ - _logger.debug("######################search: {}######################". - format(op_node)) - visited = [op_node.idx()] - stack = [] - brothers = [] - for op in graph.next_ops(op_node): - if ("conv2d" not in op.type()) and ( - "concat" not in op.type()) and ( - "deformable_conv" not in op.type()) and ( - op.type() != 'fc') and ( - not op.is_bwd_op()) and (not op.is_opt_op()): - stack.append(op) - visited.append(op.idx()) - while len(stack) > 0: - top_op = stack.pop() - for parent in graph.pre_ops(top_op): - if parent.idx() not in visited and ( - not parent.is_bwd_op()) and (not parent.is_opt_op()): - _logger.debug("----------go back from {} to {}----------". 
- format(top_op, parent)) - if (('conv2d' in parent.type()) or - ("deformable_conv" in parent.type()) or - (parent.type() == 'fc')): - brothers.append(parent) - else: - stack.append(parent) - visited.append(parent.idx()) - - for child in graph.next_ops(top_op): - if ('conv2d' not in child.type()) and ( - "concat" not in child.type()) and ( - 'deformable_conv' not in child.type()) and ( - child.type() != 'fc') and ( - child.idx() not in visited) and ( - not child.is_bwd_op()) and ( - not child.is_opt_op()): - stack.append(child) - visited.append(child.idx()) - _logger.debug("brothers: {}".format(brothers)) - _logger.debug( - "######################Finish search######################".format( - op_node)) - return brothers + conv_op = param.outputs()[0] + walker = conv2d_walker(conv_op,pruned_params=pruned_params, visited=visited) + walker.prune(param, pruned_axis=0, pruned_idx=pruned_idx) + + merge_pruned_params = {} + for param, pruned_axis, pruned_idx in pruned_params: + if param.name() not in merge_pruned_params: + merge_pruned_params[param.name()] = {} + if pruned_axis not in merge_pruned_params[param.name()]: + merge_pruned_params[param.name()][pruned_axis] = [] + merge_pruned_params[param.name()][pruned_axis].append(pruned_idx) + + for param_name in merge_pruned_params: + for pruned_axis in merge_pruned_params[param_name]: + pruned_idx = np.concatenate(merge_pruned_params[param_name][pruned_axis]) + param = graph.var(param_name) + _logger.debug("{}\t{}\t{}".format(param.name(), pruned_axis, len(pruned_idx))) + if param_shape_backup is not None: + origin_shape = copy.deepcopy(param.shape()) + param_shape_backup[param.name()] = origin_shape + new_shape = list(param.shape()) + new_shape[pruned_axis] -= len(pruned_idx) + param.set_shape(new_shape) + if not only_graph: + param_t = scope.find_var(param.name()).get_tensor() + if param_backup is not None and (param.name() not in param_backup): + param_backup[param.name()] = copy.deepcopy(np.array(param_t)) + try: + pruned_param = self._prune_tensor( + np.array(param_t), + pruned_idx, + pruned_axis=pruned_axis, + lazy=lazy) + except IndexError as e: + _logger.error("Pruning {}, but get [{}]".format(param.name( + ), e)) + + param_t.set(pruned_param, place) + + return graph.program, param_backup, param_shape_backup - def _cal_pruned_idx(self, name, param, ratio, axis): + def _cal_pruned_idx(self, param, ratio, axis): """ Calculate the index to be pruned on axis by given pruning ratio. Args: diff --git a/tests/test_prune.py b/tests/test_prune.py index 931cf9cf35429..60fe603ccd04e 100644 --- a/tests/test_prune.py +++ b/tests/test_prune.py @@ -15,7 +15,7 @@ sys.path.append("../") import unittest import paddle.fluid as fluid -from paddleslim.prune import Pruner +from paddleslim.prune.walk_pruner import Pruner from layers import conv_bn_layer @@ -72,6 +72,7 @@ def test_prune(self): for param in main_program.global_block().all_parameters(): if "weights" in param.name: + print("param: {}; param shape: {}".format(param.name, param.shape)) self.assertTrue(param.shape == shapes[param.name]) diff --git a/tests/test_prune_walker.py b/tests/test_prune_walker.py new file mode 100644 index 0000000000000..b80f6903904dd --- /dev/null +++ b/tests/test_prune_walker.py @@ -0,0 +1,64 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+sys.path.append("../")
+import unittest
+import paddle.fluid as fluid
+from paddleslim.prune import Pruner
+from paddleslim.core import GraphWrapper
+from paddleslim.prune import conv2d as conv2d_walker
+from layers import conv_bn_layer
+
+
+class TestPrune(unittest.TestCase):
+    def test_prune(self):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        #   X       X              O       X              O
+        # conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
+        #     |            ^ |                    ^
+        #     |____________| |____________________|
+        #
+        # X: prune output channels
+        # O: prune input channels
+        with fluid.program_guard(main_program, startup_program):
+            input = fluid.data(name="image", shape=[None, 3, 16, 16])
+            conv1 = conv_bn_layer(input, 8, 3, "conv1")
+            conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
+            sum1 = conv1 + conv2
+            conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
+            conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
+            sum2 = conv4 + sum1
+            conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
+            conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
+
+        shapes = {}
+        for param in main_program.global_block().all_parameters():
+            shapes[param.name] = param.shape
+
+        place = fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        scope = fluid.Scope()
+        exe.run(startup_program, scope=scope)
+
+        graph = GraphWrapper(main_program)
+
+        conv_op = graph.var("conv4_weights").outputs()[0]
+        walker = conv2d_walker(conv_op, [])
+        walker.prune(graph.var("conv4_weights"), pruned_axis=0, pruned_idx=[])
+        print(walker.pruned_params)
+
+
+if __name__ == '__main__':
+    unittest.main()