diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py index 55da69d7a4af..d5a02ac4d579 100644 --- a/deepspeed/profiling/flops_profiler/profiler.py +++ b/deepspeed/profiling/flops_profiler/profiler.py @@ -1180,20 +1180,19 @@ def get_module_duration(module): return duration -def get_model_profile( - model, - input_shape=None, - args=[], - kwargs={}, - print_profile=True, - detailed=True, - module_depth=-1, - top_modules=1, - warm_up=1, - as_string=True, - output_file=None, - ignore_modules=None, -): +def get_model_profile(model, + input_shape=None, + args=[], + kwargs={}, + print_profile=True, + detailed=True, + module_depth=-1, + top_modules=1, + warm_up=1, + as_string=True, + output_file=None, + ignore_modules=None, + mode='forward'): """Returns the total floating-point operations, MACs, and parameters of a model. Example: @@ -1239,18 +1238,29 @@ def get_model_profile( args = [input] assert (len(args) > 0) or (len(kwargs) > 0), "args and/or kwargs must be specified if input_shape is None" - for _ in range(warm_up): if kwargs: - _ = model(*args, **kwargs) + if mode == 'forward': + _ = model(*args, **kwargs) + if mode == 'generate': + _ = model.generate(*args, **kwargs) else: - _ = model(*args) + if mode == 'forward': + _ = model(*args) + if mode == 'generate': + _ = model.generate(*args) prof.start_profile(ignore_list=ignore_modules) if kwargs: - _ = model(*args, **kwargs) + if mode == 'forward': + _ = model(*args, **kwargs) + if mode == 'generate': + _ = model.generate(*args, **kwargs) else: - _ = model(*args) + if mode == 'forward': + _ = model(*args) + if mode == 'generate': + _ = model.generate(*args) flops = prof.get_total_flops() macs = prof.get_total_macs()