From 3ce67e75399e827dfd47f9d531dc95adfa8186d9 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 28 Mar 2018 14:38:20 +0800 Subject: [PATCH 1/7] add SE-ResNeXt-152_parallel_exe --- fluid/SE-ResNeXt-152/run.sh | 2 +- .../{train.py => train_parallel_do.py} | 0 .../SE-ResNeXt-152/train_parallel_executor.py | 237 ++++++++++++++++++ 3 files changed, 238 insertions(+), 1 deletion(-) rename fluid/SE-ResNeXt-152/{train.py => train_parallel_do.py} (100%) create mode 100644 fluid/SE-ResNeXt-152/train_parallel_executor.py diff --git a/fluid/SE-ResNeXt-152/run.sh b/fluid/SE-ResNeXt-152/run.sh index 4c2a7e8..053076c 100644 --- a/fluid/SE-ResNeXt-152/run.sh +++ b/fluid/SE-ResNeXt-152/run.sh @@ -1 +1 @@ -env CUDA_VISIBLE_DEVICES=4 python train.py --use_nccl=False --parallel=False +env CUDA_VISIBLE_DEVICES=4 python train_parallel_do.py --use_nccl=False --parallel=False diff --git a/fluid/SE-ResNeXt-152/train.py b/fluid/SE-ResNeXt-152/train_parallel_do.py similarity index 100% rename from fluid/SE-ResNeXt-152/train.py rename to fluid/SE-ResNeXt-152/train_parallel_do.py diff --git a/fluid/SE-ResNeXt-152/train_parallel_executor.py b/fluid/SE-ResNeXt-152/train_parallel_executor.py new file mode 100644 index 0000000..3536b4a --- /dev/null +++ b/fluid/SE-ResNeXt-152/train_parallel_executor.py @@ -0,0 +1,237 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import argparse +import distutils.util +import numpy as np + +import paddle.v2 as paddle +import paddle.fluid as fluid +import paddle.v2.dataset.flowers as flowers +import paddle.fluid.profiler as profiler + + +def parse_args(): + parser = argparse.ArgumentParser( + 'SE-ResNeXt-152 parallel-executor profile.') + parser.add_argument( + '--use_mem_opt', + type=distutils.util.strtobool, + default=True, + help='use memory optimize') + parser.add_argument('--per_gpu_batch_size', type=int, default=12, help='') + parser.add_argument( + '--total_batch_num', + type=int, + default=40, + help='total batch num for per_gpu_batch_size') + parser.add_argument( + '--number_iteration', + type=int, + default=10, + help='total batch num for per_gpu_batch_size') + + args = parser.parse_args() + return args + + +def print_arguments(args): + print('----------- Configuration Arguments -----------') + for arg, value in sorted(vars(args).iteritems()): + print('%s=%s' % (arg, value)) + + +def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1, + act=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) / 2, + groups=groups, + act=None, + bias_attr=False) + return fluid.layers.batch_norm(input=conv, act=act, momentum=0.1) + + +def squeeze_excitation(input, num_channels, reduction_ratio): + #pool = fluid.layers.pool2d( + # input=input, pool_size=0, pool_type='avg', global_pooling=True) + conv = input + shape = conv.shape + reshape = fluid.layers.reshape( + x=conv, shape=[-1, shape[1], shape[2] * shape[3]]) + pool = fluid.layers.reduce_mean(input=reshape, dim=2) + + squeeze = fluid.layers.fc(input=pool, + size=num_channels / reduction_ratio, + act='relu') + excitation = fluid.layers.fc(input=squeeze, + size=num_channels, + act='sigmoid') + scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) + return scale + + +def shortcut(input, ch_out, stride): + ch_in = input.shape[1] + if ch_in != ch_out: + if stride == 1: + filter_size = 1 + else: + filter_size = 3 + return conv_bn_layer(input, ch_out, filter_size, stride) + else: + return input + + +def bottleneck_block(input, num_filters, stride, cardinality, reduction_ratio): + # The number of first 1x1 convolutional channels for each bottleneck build block + # was halved to reduce the compution cost. + conv0 = conv_bn_layer( + input=input, num_filters=num_filters, filter_size=1, act='relu') + conv1 = conv_bn_layer( + input=conv0, + num_filters=num_filters * 2, + filter_size=3, + stride=stride, + groups=cardinality, + act='relu') + conv2 = conv_bn_layer( + input=conv1, num_filters=num_filters * 2, filter_size=1, act=None) + scale = squeeze_excitation( + input=conv2, + num_channels=num_filters * 2, + reduction_ratio=reduction_ratio) + + short = shortcut(input, num_filters * 2, stride) + + return fluid.layers.elementwise_add(x=short, y=scale, act='relu') + + +def SE_ResNeXt152(input, class_dim): + cardinality = 64 + reduction_ratio = 16 + depth = [3, 8, 36, 3] + num_filters = [128, 256, 512, 1024] + + conv = conv_bn_layer( + input=input, num_filters=64, filter_size=3, stride=2, act='relu') + conv = conv_bn_layer( + input=conv, num_filters=64, filter_size=3, stride=1, act='relu') + conv = conv_bn_layer( + input=conv, num_filters=128, filter_size=3, stride=1, act='relu') + conv = fluid.layers.pool2d( + input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max') + + for block in range(len(depth)): + for i in range(depth[block]): + conv = bottleneck_block( + input=conv, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + cardinality=cardinality, + reduction_ratio=reduction_ratio) + + #pool = fluid.layers.pool2d( + # input=conv, pool_size=0, pool_type='avg', global_pooling=True) + shape = conv.shape + reshape = fluid.layers.reshape( + x=conv, shape=[-1, shape[1], shape[2] * shape[3]]) + pool = fluid.layers.reduce_mean(input=reshape, dim=2) + #yancanxiang: A drop out layer(with a drop ratio of 0.2) was inserted before the classifier layer. + dropout = fluid.layers.dropout(x=pool, dropout_prob=0.2) + # Classifier layer: + out = fluid.layers.fc(input=dropout, size=class_dim, act='softmax') + return out + + +def net_conf(image, label, class_dim): + out = SE_ResNeXt152(input=image, class_dim=class_dim) + cost = fluid.layers.cross_entropy(input=out, label=label) + avg_cost = fluid.layers.mean(x=cost) + #accuracy = fluid.evaluator.Accuracy(input=out, label=label) + #accuracy5 = fluid.evaluator.Accuracy(input=out, label=label, k=5) + accuracy = fluid.layers.accuracy(input=out, label=label) + accuracy5 = fluid.layers.accuracy(input=out, label=label, k=5) + return out, avg_cost, accuracy, accuracy5 + + +def train(): + args = parse_args() + + cards = os.getenv("CUDA_VISIBLE_DEVICES") or "" + cards_num = len(cards.split(",")) + batch_size = args.per_gpu_batch_size * cards_num + + print_arguments(args) + print("cards_num=" + str(cards_num)) + print("batch_size=" + str(batch_size)) + + class_dim = 1000 + image_shape = [3, 224, 224] + + main = fluid.Program() + startup = fluid.Program() + + with fluid.program_guard(main, startup): + data_file = fluid.layers.open_recordio_file( + filename='./resnet_152.recordio_batch_size_12_3_224_224', # ./resnet_152.recordio_batch_size_2 + shapes=[[-1, 3, 224, 224], [-1, 1]], + lod_levels=[0, 0], + dtypes=['float32', 'int64']) + image, label = fluid.layers.read_file(data_file) + + prediction, avg_cost, accuracy, accuracy5 = net_conf(image, label, + class_dim) + + optimizer = fluid.optimizer.Momentum( + learning_rate=fluid.layers.piecewise_decay( + boundaries=[100], values=[0.1, 0.2]), + momentum=0.9, + regularization=fluid.regularizer.L2Decay(1e-4)) + opts = optimizer.minimize(avg_cost) + + if args.use_mem_opt: + fluid.memory_optimize(fluid.default_main_program()) + + exe = fluid.ParallelExecutor(loss_name=avg_cost.name, use_cuda=True) + + batch_id = 0 + time_record = [] + # with profiler.profiler('All', 'total', '/tmp/profile') as prof: + for i in xrange(args.number_iteration): + t1 = time.time() + exe.run([avg_cost.name] if batch_id % 10 == 0 else []) + t2 = time.time() + period = t2 - t1 + time_record.append(period) + + if batch_id % 10 == 0: + print("trainbatch {0}, time{1}".format(batch_id, + "%2.2f sec" % period)) + batch_id += 1 + + del time_record[0] + for ele in time_record: + print ele + + print("average time:{0}".format(np.mean(time_record))) + + +if __name__ == '__main__': + train() From daea2998533011223d3d80dc12aae5db6872e6ad Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 28 Mar 2018 16:01:00 +0800 Subject: [PATCH 2/7] fix SE-ResNeXt-152_parallel_exe --- .../SE-ResNeXt-152/train_parallel_executor.py | 94 ++++++++++--------- 1 file changed, 50 insertions(+), 44 deletions(-) diff --git a/fluid/SE-ResNeXt-152/train_parallel_executor.py b/fluid/SE-ResNeXt-152/train_parallel_executor.py index 3536b4a..0b5696f 100644 --- a/fluid/SE-ResNeXt-152/train_parallel_executor.py +++ b/fluid/SE-ResNeXt-152/train_parallel_executor.py @@ -25,19 +25,13 @@ def parse_args(): - parser = argparse.ArgumentParser( - 'SE-ResNeXt-152 parallel-executor profile.') + parser = argparse.ArgumentParser('SE-ResNeXt-152 parallel-executor model.') parser.add_argument( '--use_mem_opt', type=distutils.util.strtobool, default=True, help='use memory optimize') parser.add_argument('--per_gpu_batch_size', type=int, default=12, help='') - parser.add_argument( - '--total_batch_num', - type=int, - default=40, - help='total batch num for per_gpu_batch_size') parser.add_argument( '--number_iteration', type=int, @@ -65,18 +59,12 @@ def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1, groups=groups, act=None, bias_attr=False) - return fluid.layers.batch_norm(input=conv, act=act, momentum=0.1) + return fluid.layers.batch_norm(input=conv, act=act) def squeeze_excitation(input, num_channels, reduction_ratio): - #pool = fluid.layers.pool2d( - # input=input, pool_size=0, pool_type='avg', global_pooling=True) - conv = input - shape = conv.shape - reshape = fluid.layers.reshape( - x=conv, shape=[-1, shape[1], shape[2] * shape[3]]) - pool = fluid.layers.reduce_mean(input=reshape, dim=2) - + pool = fluid.layers.pool2d( + input=input, pool_size=0, pool_type='avg', global_pooling=True) squeeze = fluid.layers.fc(input=pool, size=num_channels / reduction_ratio, act='relu') @@ -100,13 +88,11 @@ def shortcut(input, ch_out, stride): def bottleneck_block(input, num_filters, stride, cardinality, reduction_ratio): - # The number of first 1x1 convolutional channels for each bottleneck build block - # was halved to reduce the compution cost. conv0 = conv_bn_layer( input=input, num_filters=num_filters, filter_size=1, act='relu') conv1 = conv_bn_layer( input=conv0, - num_filters=num_filters * 2, + num_filters=num_filters, filter_size=3, stride=stride, groups=cardinality, @@ -119,24 +105,47 @@ def bottleneck_block(input, num_filters, stride, cardinality, reduction_ratio): reduction_ratio=reduction_ratio) short = shortcut(input, num_filters * 2, stride) - return fluid.layers.elementwise_add(x=short, y=scale, act='relu') -def SE_ResNeXt152(input, class_dim): - cardinality = 64 - reduction_ratio = 16 - depth = [3, 8, 36, 3] - num_filters = [128, 256, 512, 1024] - - conv = conv_bn_layer( - input=input, num_filters=64, filter_size=3, stride=2, act='relu') - conv = conv_bn_layer( - input=conv, num_filters=64, filter_size=3, stride=1, act='relu') - conv = conv_bn_layer( - input=conv, num_filters=128, filter_size=3, stride=1, act='relu') - conv = fluid.layers.pool2d( - input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max') +def SE_ResNeXt(input, class_dim, infer=False, layers=152): + supported_layers = [50, 152] + if layers not in supported_layers: + print("supported layers are", supported_layers, "but input layer is ", + layers) + exit() + if layers == 50: + cardinality = 32 + reduction_ratio = 16 + depth = [3, 4, 6, 3] + num_filters = [128, 256, 512, 1024] + + conv = conv_bn_layer( + input=input, num_filters=64, filter_size=7, stride=2, act='relu') + conv = fluid.layers.pool2d( + input=conv, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + elif layers == 152: + cardinality = 64 + reduction_ratio = 16 + depth = [3, 8, 36, 3] + num_filters = [128, 256, 512, 1024] + + conv = conv_bn_layer( + input=input, num_filters=64, filter_size=3, stride=2, act='relu') + conv = conv_bn_layer( + input=conv, num_filters=64, filter_size=3, stride=1, act='relu') + conv = conv_bn_layer( + input=conv, num_filters=128, filter_size=3, stride=1, act='relu') + conv = fluid.layers.pool2d( + input=conv, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') for block in range(len(depth)): for i in range(depth[block]): @@ -147,16 +156,13 @@ def SE_ResNeXt152(input, class_dim): cardinality=cardinality, reduction_ratio=reduction_ratio) - #pool = fluid.layers.pool2d( - # input=conv, pool_size=0, pool_type='avg', global_pooling=True) - shape = conv.shape - reshape = fluid.layers.reshape( - x=conv, shape=[-1, shape[1], shape[2] * shape[3]]) - pool = fluid.layers.reduce_mean(input=reshape, dim=2) - #yancanxiang: A drop out layer(with a drop ratio of 0.2) was inserted before the classifier layer. - dropout = fluid.layers.dropout(x=pool, dropout_prob=0.2) - # Classifier layer: - out = fluid.layers.fc(input=dropout, size=class_dim, act='softmax') + pool = fluid.layers.pool2d( + input=conv, pool_size=0, pool_type='avg', global_pooling=True) + if not infer: + drop = fluid.layers.dropout(x=pool, dropout_prob=0.2) + else: + drop = pool + out = fluid.layers.fc(input=drop, size=class_dim, act='softmax') return out From b9dd7762a96070808ce5a93ba06d16280cd973b8 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 28 Mar 2018 16:13:06 +0800 Subject: [PATCH 3/7] add mem_opt argument --- fluid/SE-ResNeXt-152/train_parallel_do.py | 49 +++++++++++++---------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/fluid/SE-ResNeXt-152/train_parallel_do.py b/fluid/SE-ResNeXt-152/train_parallel_do.py index c503204..6e32e55 100644 --- a/fluid/SE-ResNeXt-152/train_parallel_do.py +++ b/fluid/SE-ResNeXt-152/train_parallel_do.py @@ -26,6 +26,11 @@ def parse_args(): parser = argparse.ArgumentParser('SE-ResNeXt-152 parallel profile.') parser.add_argument('--per_gpu_batch_size', type=int, default=12, help='') + parser.add_argument( + '--use_mem_opt', + type=distutils.util.strtobool, + default=True, + help='use memory optimize') parser.add_argument( '--skip_first_steps', type=int, @@ -212,7 +217,8 @@ def train(): regularization=fluid.regularizer.L2Decay(1e-4)) opts = optimizer.minimize(avg_cost) - fluid.memory_optimize(fluid.default_main_program()) + if args.use_mem_opt: + fluid.memory_optimize(fluid.default_main_program()) place = fluid.CUDAPlace(0) # place = fluid.CPUPlace() @@ -227,27 +233,26 @@ def train(): feed_dict = feeder.feed(data) for pass_id in range(1): - with profiler.profiler('All', 'total', '/tmp/profile') as prof: - train_time = 0.0 - - for step_id in range(step_num): - train_start = time.time() - exe.run(fluid.default_main_program(), - feed=feeder.feed(train_reader_iter.next()) - if args.use_python_reader else feed_dict, - fetch_list=[], - use_program_cache=True) - train_stop = time.time() - step_time = train_stop - train_start - if step_id >= args.skip_first_steps: - train_time += step_time - print("step_id=" + str(step_id) + " step_time=" + str( - step_time)) - print("\n\n\n") - calc_step_num = step_num - args.skip_first_steps - print("calc_step_num=" + str(calc_step_num) + " total_train_time=" + - str(train_time) + " ave_step_time=" + str( - float(train_time) / calc_step_num)) + #with profiler.profiler('All', 'total', '/tmp/profile') as prof: + train_time = 0.0 + + for step_id in range(step_num): + train_start = time.time() + exe.run(fluid.default_main_program(), + feed=feeder.feed(train_reader_iter.next()) + if args.use_python_reader else feed_dict, + fetch_list=[], + use_program_cache=True) + train_stop = time.time() + step_time = train_stop - train_start + if step_id >= args.skip_first_steps: + train_time += step_time + print("step_id=" + str(step_id) + " step_time=" + str(step_time)) + print("\n\n\n") + calc_step_num = step_num - args.skip_first_steps + print("calc_step_num=" + str(calc_step_num) + " total_train_time=" + + str(train_time) + " ave_step_time=" + str( + float(train_time) / calc_step_num)) if __name__ == '__main__': From e8f4ff5cde739df0c8025fb3eadf12a6777708d9 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 2 Apr 2018 11:06:26 +0800 Subject: [PATCH 4/7] code refine --- .../SE-ResNeXt-152/train_parallel_executor.py | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/fluid/SE-ResNeXt-152/train_parallel_executor.py b/fluid/SE-ResNeXt-152/train_parallel_executor.py index 0b5696f..9c3e36d 100644 --- a/fluid/SE-ResNeXt-152/train_parallel_executor.py +++ b/fluid/SE-ResNeXt-152/train_parallel_executor.py @@ -35,8 +35,9 @@ def parse_args(): parser.add_argument( '--number_iteration', type=int, - default=10, + default=100, help='total batch num for per_gpu_batch_size') + parser.add_argument('--display_step', type=int, default=1, help='') args = parser.parse_args() return args @@ -167,7 +168,7 @@ def SE_ResNeXt(input, class_dim, infer=False, layers=152): def net_conf(image, label, class_dim): - out = SE_ResNeXt152(input=image, class_dim=class_dim) + out = SE_ResNeXt(input=image, class_dim=class_dim) cost = fluid.layers.cross_entropy(input=out, label=label) avg_cost = fluid.layers.mean(x=cost) #accuracy = fluid.evaluator.Accuracy(input=out, label=label) @@ -195,16 +196,18 @@ def train(): startup = fluid.Program() with fluid.program_guard(main, startup): - data_file = fluid.layers.open_recordio_file( - filename='./resnet_152.recordio_batch_size_12_3_224_224', # ./resnet_152.recordio_batch_size_2 + reader = fluid.layers.open_recordio_file( + filename='./flowers.recordio', shapes=[[-1, 3, 224, 224], [-1, 1]], lod_levels=[0, 0], dtypes=['float32', 'int64']) - image, label = fluid.layers.read_file(data_file) + # currently, double buffer only supports one device. + #data_file = fluid.layers.create_double_buffer_reader(reader=data_file, place='CUDA:0') + image, label = fluid.layers.read_file(reader) prediction, avg_cost, accuracy, accuracy5 = net_conf(image, label, class_dim) - + #optimizer = fluid.optimizer.SGD(learning_rate=0.002) optimizer = fluid.optimizer.Momentum( learning_rate=fluid.layers.piecewise_decay( boundaries=[100], values=[0.1, 0.2]), @@ -217,20 +220,26 @@ def train(): exe = fluid.ParallelExecutor(loss_name=avg_cost.name, use_cuda=True) - batch_id = 0 + batch_id = -1 time_record = [] - # with profiler.profiler('All', 'total', '/tmp/profile') as prof: + for i in xrange(args.number_iteration): + batch_id += 1 + if batch_id >= 5 and batch_id < 7: + with profiler.profiler('All', 'total', '/tmp/profile') as prof: + exe.run([]) + continue + t1 = time.time() - exe.run([avg_cost.name] if batch_id % 10 == 0 else []) + cost_val = exe.run([avg_cost.name] + if batch_id % args.display_step == 0 else []) t2 = time.time() period = t2 - t1 time_record.append(period) - if batch_id % 10 == 0: - print("trainbatch {0}, time{1}".format(batch_id, - "%2.2f sec" % period)) - batch_id += 1 + if batch_id % args.display_step == 0: + print("iter=%d, elapse=%f, cost=%s" % + (batch_id, period, np.array(cost_val[0]))) del time_record[0] for ele in time_record: From 4f4c4c9ed49269ef72f9143bff016935bd1fc3b3 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 2 Apr 2018 14:44:09 +0800 Subject: [PATCH 5/7] follow comment --- fluid/SE-ResNeXt-152/readme.md | 16 ++ fluid/SE-ResNeXt-152/run.sh | 1 - .../{train_parallel_executor.py => train.py} | 191 +++++++++++-- fluid/SE-ResNeXt-152/train_parallel_do.py | 259 ------------------ 4 files changed, 183 insertions(+), 284 deletions(-) create mode 100644 fluid/SE-ResNeXt-152/readme.md delete mode 100644 fluid/SE-ResNeXt-152/run.sh rename fluid/SE-ResNeXt-152/{train_parallel_executor.py => train.py} (56%) delete mode 100644 fluid/SE-ResNeXt-152/train_parallel_do.py diff --git a/fluid/SE-ResNeXt-152/readme.md b/fluid/SE-ResNeXt-152/readme.md new file mode 100644 index 0000000..8c11c89 --- /dev/null +++ b/fluid/SE-ResNeXt-152/readme.md @@ -0,0 +1,16 @@ +# Benchmark SE-ResNeXt-152 + +## For single card: +``` +env CUDA_VISIBLE_DEVICES=4 python train.py --use_parallel_mode=parallel_do --use_nccl=False --parallel=False --display_step=1 +``` + +## For multi-card: +### use parallel_do +``` +env CUDA_VISIBLE_DEVICES=4,5,6,7 python train.py --use_parallel_mode=parallel_do --use_nccl=True --parallel=True --display_step=1 +``` +### use parallel_exe +``` +env CUDA_VISIBLE_DEVICES=4,5,6,7 python train.py --use_parallel_mode=parallel_exe --use_nccl=True --parallel=True --display_step=1 +``` diff --git a/fluid/SE-ResNeXt-152/run.sh b/fluid/SE-ResNeXt-152/run.sh deleted file mode 100644 index 053076c..0000000 --- a/fluid/SE-ResNeXt-152/run.sh +++ /dev/null @@ -1 +0,0 @@ -env CUDA_VISIBLE_DEVICES=4 python train_parallel_do.py --use_nccl=False --parallel=False diff --git a/fluid/SE-ResNeXt-152/train_parallel_executor.py b/fluid/SE-ResNeXt-152/train.py similarity index 56% rename from fluid/SE-ResNeXt-152/train_parallel_executor.py rename to fluid/SE-ResNeXt-152/train.py index 9c3e36d..13fc951 100644 --- a/fluid/SE-ResNeXt-152/train_parallel_executor.py +++ b/fluid/SE-ResNeXt-152/train.py @@ -18,26 +18,65 @@ import distutils.util import numpy as np -import paddle.v2 as paddle +import paddle import paddle.fluid as fluid -import paddle.v2.dataset.flowers as flowers +import paddle.dataset.flowers as flowers import paddle.fluid.profiler as profiler +fluid.default_startup_program().random_seed = 111 + def parse_args(): - parser = argparse.ArgumentParser('SE-ResNeXt-152 parallel-executor model.') + parser = argparse.ArgumentParser('SE-ResNeXt-152 parallel profile.') + parser.add_argument( + '--class_number', type=int, default=1000, help='the class number') + parser.add_argument( + '--use_parallel_mode', + type=str, + default='parallel_exe', + choices=['parallel_do', 'parallel_exe'], + help='The parallel mode("parallel_do" or "parallel_exe").') + parser.add_argument('--batch_size', type=int, default=12, help='batch size') + parser.add_argument('--per_gpu_batch_size', type=int, default=12, help='') parser.add_argument( '--use_mem_opt', type=distutils.util.strtobool, default=True, - help='use memory optimize') - parser.add_argument('--per_gpu_batch_size', type=int, default=12, help='') + help='use memory optimize or not.') + parser.add_argument( + '--do_profile', + type=distutils.util.strtobool, + default=True, + help='do profile or not.') parser.add_argument( '--number_iteration', type=int, - default=100, - help='total batch num for per_gpu_batch_size') + default=50, + help='total batch num for per_gpu_batch_size.') parser.add_argument('--display_step', type=int, default=1, help='') + parser.add_argument( + '--skip_first_steps', + type=int, + default=2, + help='The first num of steps to skip, for better performance profile.') + parser.add_argument( + '--parallel', + type=distutils.util.strtobool, + default=True, + help='It is valid only when parallel_mode is parallel_do.') + parser.add_argument( + '--use_nccl', + type=distutils.util.strtobool, + default=True, + help='It is valid only when parallel_mode is parallel_do.') + parser.add_argument( + '--use_python_reader', + type=distutils.util.strtobool, + default=True, + help='It is valid only when parallel_mode is parallel_do.' + 'If use_python_reader is True, python reader is used to feeding data,' + 'the process includes data transfer from CPU to GPU. Otherwise, ' + 'the data which will be needed for training is in GPU side constantly.') args = parser.parse_args() return args @@ -45,8 +84,13 @@ def parse_args(): def print_arguments(args): print('----------- Configuration Arguments -----------') - for arg, value in sorted(vars(args).iteritems()): - print('%s=%s' % (arg, value)) + if args.use_parallel_mode == "parallel_do": + for arg, value in sorted(vars(args).iteritems()): + print('%s=%s' % (arg, value)) + else: + args.use_nccl = True + for arg, value in sorted(vars(args).iteritems()): + print('%s=%s' % (arg, value)) def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1, @@ -178,16 +222,99 @@ def net_conf(image, label, class_dim): return out, avg_cost, accuracy, accuracy5 -def train(): - args = parse_args() +def train_parallel_do(args): - cards = os.getenv("CUDA_VISIBLE_DEVICES") or "" - cards_num = len(cards.split(",")) - batch_size = args.per_gpu_batch_size * cards_num + class_dim = 1000 + image_shape = [3, 224, 224] - print_arguments(args) - print("cards_num=" + str(cards_num)) - print("batch_size=" + str(batch_size)) + image = fluid.layers.data(name='image', shape=image_shape, dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + if args.parallel: + places = fluid.layers.get_places() + pd = fluid.layers.ParallelDo(places, use_nccl=args.use_nccl) + + with pd.do(): + image_ = pd.read_input(image) + label_ = pd.read_input(label) + out = SE_ResNeXt(input=image_, class_dim=class_dim) + cost = fluid.layers.cross_entropy(input=out, label=label_) + avg_cost = fluid.layers.mean(x=cost) + accuracy = fluid.layers.accuracy(input=out, label=label_) + pd.write_output(avg_cost) + pd.write_output(accuracy) + + avg_cost, accuracy = pd() + avg_cost = fluid.layers.mean(x=avg_cost) + accuracy = fluid.layers.mean(x=accuracy) + else: + out = SE_ResNeXt(input=image, class_dim=class_dim) + cost = fluid.layers.cross_entropy(input=out, label=label) + avg_cost = fluid.layers.mean(x=cost) + accuracy = fluid.layers.accuracy(input=out, label=label) + + #optimizer = fluid.optimizer.SGD(learning_rate=0.002) + optimizer = fluid.optimizer.Momentum( + learning_rate=fluid.layers.piecewise_decay( + boundaries=[100], values=[0.1, 0.2]), + momentum=0.9, + regularization=fluid.regularizer.L2Decay(1e-4)) + opts = optimizer.minimize(avg_cost) + + if args.use_mem_opt: + fluid.memory_optimize(fluid.default_main_program()) + + place = fluid.CUDAPlace(0) + # place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + train_reader = paddle.batch(flowers.train(), batch_size=args.batch_size) + + feeder = fluid.DataFeeder(place=place, feed_list=[image, label]) + train_reader_iter = train_reader() + if not args.use_python_reader: + data = train_reader_iter.next() + feed_dict = feeder.feed(data) + + time_record = [] + + for batch_id in range(args.number_iteration): + if args.do_profile and batch_id >= 5 and batch_id < 8: + with profiler.profiler('All', 'total', + '/tmp/profile_parallel_do') as prof: + exe.run(fluid.default_main_program(), + feed=feeder.feed(train_reader_iter.next()) + if args.use_python_reader else feed_dict, + fetch_list=[], + use_program_cache=True) + continue + + train_start = time.time() + cost_val = exe.run(fluid.default_main_program(), + feed=feeder.feed(train_reader_iter.next()) + if args.use_python_reader else feed_dict, + fetch_list=[avg_cost.name] + if batch_id % args.display_step == 0 else [], + use_program_cache=True) + train_stop = time.time() + step_time = train_stop - train_start + time_record.append(step_time) + + if batch_id % args.display_step == 0: + print("iter=%d, elapse=%f, cost=%s" % + (batch_id, step_time, np.array(cost_val[0]))) + + for _ in range(args.skip_first_steps): + del time_record[0] + + for ele in time_record: + print ele + + print("average time:{0}".format(np.mean(time_record))) + + +def train_parallel_exe(args): class_dim = 1000 image_shape = [3, 224, 224] @@ -201,12 +328,14 @@ def train(): shapes=[[-1, 3, 224, 224], [-1, 1]], lod_levels=[0, 0], dtypes=['float32', 'int64']) + # currently, double buffer only supports one device. #data_file = fluid.layers.create_double_buffer_reader(reader=data_file, place='CUDA:0') image, label = fluid.layers.read_file(reader) prediction, avg_cost, accuracy, accuracy5 = net_conf(image, label, class_dim) + #optimizer = fluid.optimizer.SGD(learning_rate=0.002) optimizer = fluid.optimizer.Momentum( learning_rate=fluid.layers.piecewise_decay( @@ -220,13 +349,13 @@ def train(): exe = fluid.ParallelExecutor(loss_name=avg_cost.name, use_cuda=True) - batch_id = -1 time_record = [] - for i in xrange(args.number_iteration): - batch_id += 1 - if batch_id >= 5 and batch_id < 7: - with profiler.profiler('All', 'total', '/tmp/profile') as prof: + for batch_id in xrange(args.number_iteration): + + if args.do_profile and batch_id >= 5 and batch_id < 8: + with profiler.profiler('All', 'total', + '/tmp/profile_parallel_exe') as prof: exe.run([]) continue @@ -241,7 +370,9 @@ def train(): print("iter=%d, elapse=%f, cost=%s" % (batch_id, period, np.array(cost_val[0]))) - del time_record[0] + for _ in range(args.skip_first_steps): + del time_record[0] + for ele in time_record: print ele @@ -249,4 +380,16 @@ def train(): if __name__ == '__main__': - train() + args = parse_args() + + cards = os.getenv("CUDA_VISIBLE_DEVICES") or "" + cards_num = len(cards.split(",")) + args.batch_size = args.per_gpu_batch_size * cards_num + + print_arguments(args) + print("cards_num=" + str(cards_num)) + + if args.use_parallel_mode == "parallel_do": + train_parallel_do(args) + else: + train_parallel_exe(args) diff --git a/fluid/SE-ResNeXt-152/train_parallel_do.py b/fluid/SE-ResNeXt-152/train_parallel_do.py deleted file mode 100644 index 6e32e55..0000000 --- a/fluid/SE-ResNeXt-152/train_parallel_do.py +++ /dev/null @@ -1,259 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import time -import argparse -import distutils.util - -import paddle.v2 as paddle -import paddle.fluid as fluid -import paddle.v2.dataset.flowers as flowers -import paddle.fluid.profiler as profiler - - -def parse_args(): - parser = argparse.ArgumentParser('SE-ResNeXt-152 parallel profile.') - parser.add_argument('--per_gpu_batch_size', type=int, default=12, help='') - parser.add_argument( - '--use_mem_opt', - type=distutils.util.strtobool, - default=True, - help='use memory optimize') - parser.add_argument( - '--skip_first_steps', - type=int, - default=2, - help='The first num of steps to skip, for better performance profile') - parser.add_argument( - '--total_batch_num', - type=int, - default=40, - help='total batch num for per_gpu_batch_size') - parser.add_argument( - '--parallel', - type=distutils.util.strtobool, - default=True, - help='use parallel_do') - parser.add_argument( - '--use_nccl', - type=distutils.util.strtobool, - default=False, - help='use_nccl') - parser.add_argument( - '--use_python_reader', - type=distutils.util.strtobool, - default=True, - help='use python reader to feed data') - - args = parser.parse_args() - return args - - -def print_arguments(args): - print('----------- Configuration Arguments -----------') - for arg, value in sorted(vars(args).iteritems()): - print('%s=%s' % (arg, value)) - - -def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1, - act=None): - conv = fluid.layers.conv2d( - input=input, - num_filters=num_filters, - filter_size=filter_size, - stride=stride, - padding=(filter_size - 1) / 2, - groups=groups, - act=None, - bias_attr=False) - return fluid.layers.batch_norm(input=conv, act=act) - - -def squeeze_excitation(input, num_channels, reduction_ratio): - pool = fluid.layers.pool2d( - input=input, pool_size=0, pool_type='avg', global_pooling=True) - squeeze = fluid.layers.fc(input=pool, - size=num_channels / reduction_ratio, - act='relu') - excitation = fluid.layers.fc(input=squeeze, - size=num_channels, - act='sigmoid') - scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) - return scale - - -def shortcut(input, ch_out, stride): - ch_in = input.shape[1] - if ch_in != ch_out: - if stride == 1: - filter_size = 1 - else: - filter_size = 3 - return conv_bn_layer(input, ch_out, filter_size, stride) - else: - return input - - -def bottleneck_block(input, num_filters, stride, cardinality, reduction_ratio): - conv0 = conv_bn_layer( - input=input, num_filters=num_filters, filter_size=1, act='relu') - conv1 = conv_bn_layer( - input=conv0, - num_filters=num_filters, - filter_size=3, - stride=stride, - groups=cardinality, - act='relu') - conv2 = conv_bn_layer( - input=conv1, num_filters=num_filters * 2, filter_size=1, act=None) - scale = squeeze_excitation( - input=conv2, - num_channels=num_filters * 2, - reduction_ratio=reduction_ratio) - - short = shortcut(input, num_filters * 2, stride) - - return fluid.layers.elementwise_add(x=short, y=scale, act='relu') - - -def SE_ResNeXt(input, class_dim, infer=False): - cardinality = 64 - reduction_ratio = 16 - depth = [3, 8, 36, 3] - num_filters = [128, 256, 512, 1024] - - conv = conv_bn_layer( - input=input, num_filters=64, filter_size=3, stride=2, act='relu') - conv = conv_bn_layer( - input=conv, num_filters=64, filter_size=3, stride=1, act='relu') - conv = conv_bn_layer( - input=conv, num_filters=128, filter_size=3, stride=1, act='relu') - conv = fluid.layers.pool2d( - input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max') - - for block in range(len(depth)): - for i in range(depth[block]): - conv = bottleneck_block( - input=conv, - num_filters=num_filters[block], - stride=2 if i == 0 and block != 0 else 1, - cardinality=cardinality, - reduction_ratio=reduction_ratio) - - pool = fluid.layers.pool2d( - input=conv, pool_size=0, pool_type='avg', global_pooling=True) - if not infer: - drop = fluid.layers.dropout(x=pool, dropout_prob=0.2) - else: - drop = pool - out = fluid.layers.fc(input=drop, size=class_dim, act='softmax') - return out - - -def time_stamp(): - return int(round(time.time() * 1000)) - - -def train(): - args = parse_args() - - cards = os.getenv("CUDA_VISIBLE_DEVICES") or "" - cards_num = len(cards.split(",")) - step_num = args.total_batch_num / cards_num - batch_size = args.per_gpu_batch_size * cards_num - - print_arguments(args) - print("cards_num=" + str(cards_num)) - print("batch_size=" + str(batch_size)) - print("total_batch_num=" + str(args.total_batch_num)) - print("step_num=" + str(step_num)) - - class_dim = 1000 - image_shape = [3, 224, 224] - - image = fluid.layers.data(name='image', shape=image_shape, dtype='float32') - label = fluid.layers.data(name='label', shape=[1], dtype='int64') - - if args.parallel: - places = fluid.layers.get_places() - pd = fluid.layers.ParallelDo(places, use_nccl=args.use_nccl) - - with pd.do(): - image_ = pd.read_input(image) - label_ = pd.read_input(label) - out = SE_ResNeXt(input=image_, class_dim=class_dim) - cost = fluid.layers.cross_entropy(input=out, label=label_) - avg_cost = fluid.layers.mean(x=cost) - accuracy = fluid.layers.accuracy(input=out, label=label_) - pd.write_output(avg_cost) - pd.write_output(accuracy) - - avg_cost, accuracy = pd() - avg_cost = fluid.layers.mean(x=avg_cost) - accuracy = fluid.layers.mean(x=accuracy) - else: - out = SE_ResNeXt(input=image, class_dim=class_dim) - cost = fluid.layers.cross_entropy(input=out, label=label) - avg_cost = fluid.layers.mean(x=cost) - accuracy = fluid.layers.accuracy(input=out, label=label) - - #optimizer = fluid.optimizer.SGD(learning_rate=0.002) - optimizer = fluid.optimizer.Momentum( - learning_rate=fluid.layers.piecewise_decay( - boundaries=[100], values=[0.1, 0.2]), - momentum=0.9, - regularization=fluid.regularizer.L2Decay(1e-4)) - opts = optimizer.minimize(avg_cost) - - if args.use_mem_opt: - fluid.memory_optimize(fluid.default_main_program()) - - place = fluid.CUDAPlace(0) - # place = fluid.CPUPlace() - exe = fluid.Executor(place) - exe.run(fluid.default_startup_program()) - - train_reader = paddle.batch(flowers.train(), batch_size=batch_size) - test_reader = paddle.batch(flowers.test(), batch_size=batch_size) - feeder = fluid.DataFeeder(place=place, feed_list=[image, label]) - train_reader_iter = train_reader() - data = train_reader_iter.next() - feed_dict = feeder.feed(data) - - for pass_id in range(1): - #with profiler.profiler('All', 'total', '/tmp/profile') as prof: - train_time = 0.0 - - for step_id in range(step_num): - train_start = time.time() - exe.run(fluid.default_main_program(), - feed=feeder.feed(train_reader_iter.next()) - if args.use_python_reader else feed_dict, - fetch_list=[], - use_program_cache=True) - train_stop = time.time() - step_time = train_stop - train_start - if step_id >= args.skip_first_steps: - train_time += step_time - print("step_id=" + str(step_id) + " step_time=" + str(step_time)) - print("\n\n\n") - calc_step_num = step_num - args.skip_first_steps - print("calc_step_num=" + str(calc_step_num) + " total_train_time=" + - str(train_time) + " ave_step_time=" + str( - float(train_time) / calc_step_num)) - - -if __name__ == '__main__': - train() From b105ee0fe25dc16cddf39910470da7742a24c2e0 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 2 Apr 2018 16:01:24 +0800 Subject: [PATCH 6/7] add generate flowers recordio --- .../generate_flowers_recordio.py | 34 +++++++++++++++++++ fluid/SE-ResNeXt-152/readme.md | 2 +- fluid/SE-ResNeXt-152/train.py | 34 +++++++++---------- 3 files changed, 51 insertions(+), 19 deletions(-) create mode 100644 fluid/SE-ResNeXt-152/generate_flowers_recordio.py diff --git a/fluid/SE-ResNeXt-152/generate_flowers_recordio.py b/fluid/SE-ResNeXt-152/generate_flowers_recordio.py new file mode 100644 index 0000000..08f7e8b --- /dev/null +++ b/fluid/SE-ResNeXt-152/generate_flowers_recordio.py @@ -0,0 +1,34 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.dataset.flowers as flowers + +batch_size = 12 +data_shape = [3, 224, 224] + +with fluid.program_guard(fluid.Program(), fluid.Program()): + reader = paddle.batch(flowers.train(), batch_size=batch_size) + feeder = fluid.DataFeeder( + feed_list=[ # order is image and label + fluid.layers.data( + name='image', shape=data_shape, dtype='float32'), + fluid.layers.data( + name='label', shape=[1], dtype='int64'), + ], + place=fluid.CPUPlace()) + fluid.recordio_writer.convert_reader_to_recordio_file( + './flowers_bs_12_3_224_224.recordio', reader, feeder) diff --git a/fluid/SE-ResNeXt-152/readme.md b/fluid/SE-ResNeXt-152/readme.md index 8c11c89..de00156 100644 --- a/fluid/SE-ResNeXt-152/readme.md +++ b/fluid/SE-ResNeXt-152/readme.md @@ -2,7 +2,7 @@ ## For single card: ``` -env CUDA_VISIBLE_DEVICES=4 python train.py --use_parallel_mode=parallel_do --use_nccl=False --parallel=False --display_step=1 +env CUDA_VISIBLE_DEVICES=0 python train.py --use_parallel_mode=parallel_do --use_nccl=False --parallel=False --display_step=1 ``` ## For multi-card: diff --git a/fluid/SE-ResNeXt-152/train.py b/fluid/SE-ResNeXt-152/train.py index 13fc951..c4908d9 100644 --- a/fluid/SE-ResNeXt-152/train.py +++ b/fluid/SE-ResNeXt-152/train.py @@ -222,6 +222,19 @@ def net_conf(image, label, class_dim): return out, avg_cost, accuracy, accuracy5 +def add_optimizer(args, avg_cost): + #optimizer = fluid.optimizer.SGD(learning_rate=0.002) + optimizer = fluid.optimizer.Momentum( + learning_rate=fluid.layers.piecewise_decay( + boundaries=[100], values=[0.1, 0.2]), + momentum=0.9, + regularization=fluid.regularizer.L2Decay(1e-4)) + optimizer.minimize(avg_cost) + + if args.use_mem_opt: + fluid.memory_optimize(fluid.default_main_program()) + + def train_parallel_do(args): class_dim = 1000 @@ -253,16 +266,7 @@ def train_parallel_do(args): avg_cost = fluid.layers.mean(x=cost) accuracy = fluid.layers.accuracy(input=out, label=label) - #optimizer = fluid.optimizer.SGD(learning_rate=0.002) - optimizer = fluid.optimizer.Momentum( - learning_rate=fluid.layers.piecewise_decay( - boundaries=[100], values=[0.1, 0.2]), - momentum=0.9, - regularization=fluid.regularizer.L2Decay(1e-4)) - opts = optimizer.minimize(avg_cost) - - if args.use_mem_opt: - fluid.memory_optimize(fluid.default_main_program()) + add_optimizer(args, avg_cost) place = fluid.CUDAPlace(0) # place = fluid.CPUPlace() @@ -324,7 +328,7 @@ def train_parallel_exe(args): with fluid.program_guard(main, startup): reader = fluid.layers.open_recordio_file( - filename='./flowers.recordio', + filename='./flowers_bs_12_3_224_224.recordio', shapes=[[-1, 3, 224, 224], [-1, 1]], lod_levels=[0, 0], dtypes=['float32', 'int64']) @@ -336,13 +340,7 @@ def train_parallel_exe(args): prediction, avg_cost, accuracy, accuracy5 = net_conf(image, label, class_dim) - #optimizer = fluid.optimizer.SGD(learning_rate=0.002) - optimizer = fluid.optimizer.Momentum( - learning_rate=fluid.layers.piecewise_decay( - boundaries=[100], values=[0.1, 0.2]), - momentum=0.9, - regularization=fluid.regularizer.L2Decay(1e-4)) - opts = optimizer.minimize(avg_cost) + add_optimizer(args, avg_cost) if args.use_mem_opt: fluid.memory_optimize(fluid.default_main_program()) From 1d05993713f3d08dc56b284ff8f45e73fcb7c9f3 Mon Sep 17 00:00:00 2001 From: chengduo Date: Mon, 2 Apr 2018 19:22:19 +0800 Subject: [PATCH 7/7] Rename readme.md to README.md --- fluid/SE-ResNeXt-152/{readme.md => README.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename fluid/SE-ResNeXt-152/{readme.md => README.md} (100%) diff --git a/fluid/SE-ResNeXt-152/readme.md b/fluid/SE-ResNeXt-152/README.md similarity index 100% rename from fluid/SE-ResNeXt-152/readme.md rename to fluid/SE-ResNeXt-152/README.md