diff --git a/fluid/SE-ResNeXt-152/README.md b/fluid/SE-ResNeXt-152/README.md
new file mode 100644
index 0000000..de00156
--- /dev/null
+++ b/fluid/SE-ResNeXt-152/README.md
@@ -0,0 +1,21 @@
+# Benchmark SE-ResNeXt-152
+
+## For single card:
+```
+env CUDA_VISIBLE_DEVICES=0 python train.py --use_parallel_mode=parallel_do --use_nccl=False --parallel=False --display_step=1
+```
+
+## For multi-card:
+### use parallel_do
+```
+env CUDA_VISIBLE_DEVICES=4,5,6,7 python train.py --use_parallel_mode=parallel_do --use_nccl=True --parallel=True --display_step=1
+```
+### use parallel_exe
+Generate the recordio file that `train.py` reads in this mode:
+```
+python generate_flowers_recordio.py
+```
+Then start training:
+```
+env CUDA_VISIBLE_DEVICES=4,5,6,7 python train.py --use_parallel_mode=parallel_exe --use_nccl=True --parallel=True --display_step=1
+```
diff --git a/fluid/SE-ResNeXt-152/generate_flowers_recordio.py b/fluid/SE-ResNeXt-152/generate_flowers_recordio.py
new file mode 100644
index 0000000..08f7e8b
--- /dev/null
+++ b/fluid/SE-ResNeXt-152/generate_flowers_recordio.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+import paddle.dataset.flowers as flowers
+
+batch_size = 12
+data_shape = [3, 224, 224]
+
+with fluid.program_guard(fluid.Program(), fluid.Program()):
+    reader = paddle.batch(flowers.train(), batch_size=batch_size)
+    feeder = fluid.DataFeeder(
+        feed_list=[  # order is image and label
+            fluid.layers.data(
+                name='image', shape=data_shape, dtype='float32'),
+            fluid.layers.data(
+                name='label', shape=[1], dtype='int64'),
+        ],
+        place=fluid.CPUPlace())
+    fluid.recordio_writer.convert_reader_to_recordio_file(
+        './flowers_bs_12_3_224_224.recordio', reader, feeder)
diff --git a/fluid/SE-ResNeXt-152/run.sh b/fluid/SE-ResNeXt-152/run.sh
deleted file mode 100644
index 4c2a7e8..0000000
--- a/fluid/SE-ResNeXt-152/run.sh
+++ /dev/null
@@ -1 +0,0 @@
-env CUDA_VISIBLE_DEVICES=4 python train.py --use_nccl=False --parallel=False
diff --git a/fluid/SE-ResNeXt-152/train.py b/fluid/SE-ResNeXt-152/train.py
index c503204..c4908d9 100644
--- a/fluid/SE-ResNeXt-152/train.py
+++ b/fluid/SE-ResNeXt-152/train.py
@@ -16,41 +16,67 @@
 import time
 import argparse
 import distutils.util
+import numpy as np
 
-import paddle.v2 as paddle
+import paddle
 import paddle.fluid as fluid
-import paddle.v2.dataset.flowers as flowers
+import paddle.dataset.flowers as flowers
 import paddle.fluid.profiler as profiler
 
+fluid.default_startup_program().random_seed = 111
+
 
 def parse_args():
     parser = argparse.ArgumentParser('SE-ResNeXt-152 parallel profile.')
+    parser.add_argument(
+        '--class_number', type=int, default=1000, help='the class number')
+    parser.add_argument(
+        '--use_parallel_mode',
+        type=str,
+        default='parallel_exe',
+        choices=['parallel_do', 'parallel_exe'],
+        help='The parallel mode("parallel_do" or "parallel_exe").')
+    parser.add_argument('--batch_size', type=int, default=12, help='batch size')
     parser.add_argument('--per_gpu_batch_size', type=int, default=12, help='')
     parser.add_argument(
-        '--skip_first_steps',
+        '--use_mem_opt',
+        type=distutils.util.strtobool,
+        default=True,
+        help='use memory optimize or not.')
+    parser.add_argument(
+        '--do_profile',
+        type=distutils.util.strtobool,
+        default=True,
+        help='do profile or not.')
+    parser.add_argument(
+        '--number_iteration',
         type=int,
-        default=2,
-        help='The first num of steps to skip, for better performance profile')
+        default=50,
+        help='total batch num for per_gpu_batch_size.')
+    parser.add_argument('--display_step', type=int, default=1, help='')
     parser.add_argument(
-        '--total_batch_num',
+        '--skip_first_steps',
         type=int,
-        default=40,
-        help='total batch num for per_gpu_batch_size')
+        default=2,
+        help='The first num of steps to skip, for better performance profile.')
     parser.add_argument(
         '--parallel',
         type=distutils.util.strtobool,
         default=True,
-        help='use parallel_do')
+        help='It is valid only when parallel_mode is parallel_do.')
     parser.add_argument(
         '--use_nccl',
         type=distutils.util.strtobool,
-        default=False,
-        help='use_nccl')
+        default=True,
+        help='It is valid only when parallel_mode is parallel_do.')
     parser.add_argument(
         '--use_python_reader',
         type=distutils.util.strtobool,
         default=True,
-        help='use python reader to feed data')
+        help='It is valid only when parallel_mode is parallel_do. '
+        'If use_python_reader is True, python reader is used to feed data, '
+        'the process includes data transfer from CPU to GPU. Otherwise, '
+        'the data which will be needed for training is in GPU side constantly.')
 
     args = parser.parse_args()
     return args
@@ -58,8 +84,13 @@ def parse_args():
 
 def print_arguments(args):
     print('----------- Configuration Arguments -----------')
-    for arg, value in sorted(vars(args).iteritems()):
-        print('%s=%s' % (arg, value))
+    if args.use_parallel_mode == "parallel_do":
+        for arg, value in sorted(vars(args).iteritems()):
+            print('%s=%s' % (arg, value))
+    else:
+        args.use_nccl = True
+        for arg, value in sorted(vars(args).iteritems()):
+            print('%s=%s' % (arg, value))
 
 
 def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1,
@@ -119,24 +150,47 @@ def bottleneck_block(input, num_filters, stride, cardinality, reduction_ratio):
         reduction_ratio=reduction_ratio)
 
     short = shortcut(input, num_filters * 2, stride)
-
     return fluid.layers.elementwise_add(x=short, y=scale, act='relu')
 
 
-def SE_ResNeXt(input, class_dim, infer=False):
-    cardinality = 64
-    reduction_ratio = 16
-    depth = [3, 8, 36, 3]
-    num_filters = [128, 256, 512, 1024]
-
-    conv = conv_bn_layer(
-        input=input, num_filters=64, filter_size=3, stride=2, act='relu')
-    conv = conv_bn_layer(
-        input=conv, num_filters=64, filter_size=3, stride=1, act='relu')
-    conv = conv_bn_layer(
-        input=conv, num_filters=128, filter_size=3, stride=1, act='relu')
-    conv = fluid.layers.pool2d(
-        input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
+def SE_ResNeXt(input, class_dim, infer=False, layers=152):
+    supported_layers = [50, 152]
+    if layers not in supported_layers:
+        # Fail loudly: a bare exit() would end the process with status 0.
+        raise ValueError("supported layers are %s but input layer is %s" %
+                         (supported_layers, layers))
+    if layers == 50:
+        cardinality = 32
+        reduction_ratio = 16
+        depth = [3, 4, 6, 3]
+        num_filters = [128, 256, 512, 1024]
+
+        conv = conv_bn_layer(
+            input=input, num_filters=64, filter_size=7, stride=2, act='relu')
+        conv = fluid.layers.pool2d(
+            input=conv,
+            pool_size=3,
+            pool_stride=2,
+            pool_padding=1,
+            pool_type='max')
+    elif layers == 152:
+        cardinality = 64
+        reduction_ratio = 16
+        depth = [3, 8, 36, 3]
+        num_filters = [128, 256, 512, 1024]
+
+        conv = conv_bn_layer(
+            input=input, num_filters=64, filter_size=3, stride=2, act='relu')
+        conv = conv_bn_layer(
+            input=conv, num_filters=64, filter_size=3, stride=1, act='relu')
+        conv = conv_bn_layer(
+            input=conv, num_filters=128, filter_size=3, stride=1, act='relu')
+        conv = fluid.layers.pool2d(
+            input=conv,
+            pool_size=3,
+            pool_stride=2,
+            pool_padding=1,
+            pool_type='max')
 
     for block in range(len(depth)):
         for i in range(depth[block]):
@@ -157,23 +211,31 @@ def SE_ResNeXt(input, class_dim, infer=False):
     return out
 
 
-def time_stamp():
-    return int(round(time.time() * 1000))
+def net_conf(image, label, class_dim):
+    out = SE_ResNeXt(input=image, class_dim=class_dim)
+    cost = fluid.layers.cross_entropy(input=out, label=label)
+    avg_cost = fluid.layers.mean(x=cost)
+    #accuracy = fluid.evaluator.Accuracy(input=out, label=label)
+    #accuracy5 = fluid.evaluator.Accuracy(input=out, label=label, k=5)
+    accuracy = fluid.layers.accuracy(input=out, label=label)
+    accuracy5 = fluid.layers.accuracy(input=out, label=label, k=5)
+    return out, avg_cost, accuracy, accuracy5
 
 
-def train():
-    args = parse_args()
+def add_optimizer(args, avg_cost):
+    #optimizer = fluid.optimizer.SGD(learning_rate=0.002)
+    optimizer = fluid.optimizer.Momentum(
+        learning_rate=fluid.layers.piecewise_decay(
+            boundaries=[100], values=[0.1, 0.2]),
+        momentum=0.9,
+        regularization=fluid.regularizer.L2Decay(1e-4))
+    optimizer.minimize(avg_cost)
 
-    cards = os.getenv("CUDA_VISIBLE_DEVICES") or ""
-    cards_num = len(cards.split(","))
-    step_num = args.total_batch_num / cards_num
-    batch_size = args.per_gpu_batch_size * cards_num
+    if args.use_mem_opt:
+        fluid.memory_optimize(fluid.default_main_program())
 
-    print_arguments(args)
-    print("cards_num=" + str(cards_num))
-    print("batch_size=" + str(batch_size))
-    print("total_batch_num=" + str(args.total_batch_num))
-    print("step_num=" + str(step_num))
+
+def train_parallel_do(args):
 
-    class_dim = 1000
+    class_dim = args.class_number
     image_shape = [3, 224, 224]
@@ -204,51 +266,128 @@ def train():
         avg_cost = fluid.layers.mean(x=cost)
         accuracy = fluid.layers.accuracy(input=out, label=label)
 
-    #optimizer = fluid.optimizer.SGD(learning_rate=0.002)
-    optimizer = fluid.optimizer.Momentum(
-        learning_rate=fluid.layers.piecewise_decay(
-            boundaries=[100], values=[0.1, 0.2]),
-        momentum=0.9,
-        regularization=fluid.regularizer.L2Decay(1e-4))
-    opts = optimizer.minimize(avg_cost)
-
-    fluid.memory_optimize(fluid.default_main_program())
+    add_optimizer(args, avg_cost)
 
     place = fluid.CUDAPlace(0)
     # place = fluid.CPUPlace()
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
 
-    train_reader = paddle.batch(flowers.train(), batch_size=batch_size)
-    test_reader = paddle.batch(flowers.test(), batch_size=batch_size)
+    train_reader = paddle.batch(flowers.train(), batch_size=args.batch_size)
+
     feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
     train_reader_iter = train_reader()
-    data = train_reader_iter.next()
-    feed_dict = feeder.feed(data)
+    if not args.use_python_reader:
+        data = train_reader_iter.next()
+        feed_dict = feeder.feed(data)
 
-    for pass_id in range(1):
-        with profiler.profiler('All', 'total', '/tmp/profile') as prof:
-            train_time = 0.0
+    time_record = []
 
-            for step_id in range(step_num):
-                train_start = time.time()
+    for batch_id in range(args.number_iteration):
+        if args.do_profile and batch_id >= 5 and batch_id < 8:
+            with profiler.profiler('All', 'total',
+                                   '/tmp/profile_parallel_do') as prof:
                 exe.run(fluid.default_main_program(),
                         feed=feeder.feed(train_reader_iter.next())
                         if args.use_python_reader else feed_dict,
                         fetch_list=[],
                         use_program_cache=True)
-                train_stop = time.time()
-                step_time = train_stop - train_start
-                if step_id >= args.skip_first_steps:
-                    train_time += step_time
-                print("step_id=" + str(step_id) + " step_time=" + str(
-                    step_time))
-            print("\n\n\n")
-            calc_step_num = step_num - args.skip_first_steps
-            print("calc_step_num=" + str(calc_step_num) + " total_train_time=" +
-                  str(train_time) + " ave_step_time=" + str(
-                      float(train_time) / calc_step_num))
+            continue
+
+        train_start = time.time()
+        cost_val = exe.run(fluid.default_main_program(),
+                           feed=feeder.feed(train_reader_iter.next())
+                           if args.use_python_reader else feed_dict,
+                           fetch_list=[avg_cost.name]
+                           if batch_id % args.display_step == 0 else [],
+                           use_program_cache=True)
+        train_stop = time.time()
+        step_time = train_stop - train_start
+        time_record.append(step_time)
+
+        if batch_id % args.display_step == 0:
+            print("iter=%d, elapse=%f, cost=%s" %
+                  (batch_id, step_time, np.array(cost_val[0])))
+
+    # Slice deletion never raises, even if fewer steps were timed than skipped.
+    del time_record[:max(args.skip_first_steps, 0)]
+
+    for ele in time_record:
+        print(ele)
+
+    print("average time:{0}".format(np.mean(time_record)))
+
+
+def train_parallel_exe(args):
+
+    class_dim = args.class_number
+    image_shape = [3, 224, 224]
+
+    main = fluid.Program()
+    startup = fluid.Program()
+
+    with fluid.program_guard(main, startup):
+        reader = fluid.layers.open_recordio_file(
+            filename='./flowers_bs_12_3_224_224.recordio',
+            shapes=[[-1, 3, 224, 224], [-1, 1]],
+            lod_levels=[0, 0],
+            dtypes=['float32', 'int64'])
+
+        # currently, double buffer only supports one device.
+        #data_file = fluid.layers.create_double_buffer_reader(reader=data_file, place='CUDA:0')
+        image, label = fluid.layers.read_file(reader)
+
+        prediction, avg_cost, accuracy, accuracy5 = net_conf(image, label,
+                                                             class_dim)
+
+        add_optimizer(args, avg_cost)
+
+        # add_optimizer() already runs fluid.memory_optimize when
+        # args.use_mem_opt is set, so it is not repeated here.
+
+        exe = fluid.ParallelExecutor(loss_name=avg_cost.name, use_cuda=True)
+
+        time_record = []
+
+        for batch_id in xrange(args.number_iteration):
+
+            if args.do_profile and batch_id >= 5 and batch_id < 8:
+                with profiler.profiler('All', 'total',
+                                       '/tmp/profile_parallel_exe') as prof:
+                    exe.run([])
+                continue
+
+            t1 = time.time()
+            cost_val = exe.run([avg_cost.name]
+                               if batch_id % args.display_step == 0 else [])
+            t2 = time.time()
+            period = t2 - t1
+            time_record.append(period)
+
+            if batch_id % args.display_step == 0:
+                print("iter=%d, elapse=%f, cost=%s" %
+                      (batch_id, period, np.array(cost_val[0])))
+
+        # Slice deletion never raises, even if fewer steps were timed than skipped.
+        del time_record[:max(args.skip_first_steps, 0)]
+
+        for ele in time_record:
+            print(ele)
+
+        print("average time:{0}".format(np.mean(time_record)))
 
 
 if __name__ == '__main__':
-    train()
+    args = parse_args()
+
+    cards = os.getenv("CUDA_VISIBLE_DEVICES") or ""
+    cards_num = len(cards.split(","))
+    args.batch_size = args.per_gpu_batch_size * cards_num
+
+    print_arguments(args)
+    print("cards_num=" + str(cards_num))
+
+    if args.use_parallel_mode == "parallel_do":
+        train_parallel_do(args)
+    else:
+        train_parallel_exe(args)