Skip to content
48 changes: 29 additions & 19 deletions deepspeed/profiling/flops_profiler/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,20 +1180,19 @@ def get_module_duration(module):
return duration


def get_model_profile(
model,
input_shape=None,
args=[],
kwargs={},
print_profile=True,
detailed=True,
module_depth=-1,
top_modules=1,
warm_up=1,
as_string=True,
output_file=None,
ignore_modules=None,
):
def get_model_profile(model,
input_shape=None,
args=[],
kwargs={},
print_profile=True,
detailed=True,
module_depth=-1,
top_modules=1,
warm_up=1,
as_string=True,
output_file=None,
ignore_modules=None,
mode='forward'):
"""Returns the total floating-point operations, MACs, and parameters of a model.

Example:
Expand Down Expand Up @@ -1239,18 +1238,29 @@ def get_model_profile(

args = [input]
assert (len(args) > 0) or (len(kwargs) > 0), "args and/or kwargs must be specified if input_shape is None"

for _ in range(warm_up):
if kwargs:
_ = model(*args, **kwargs)
if mode == 'forward':
_ = model(*args, **kwargs)
if mode == 'generate':
_ = model.generate(*args, **kwargs)
else:
_ = model(*args)
if mode == 'forward':
_ = model(*args)
if mode == 'generate':
_ = model.generate(*args)
prof.start_profile(ignore_list=ignore_modules)

if kwargs:
_ = model(*args, **kwargs)
if mode == 'forward':
_ = model(*args, **kwargs)
if mode == 'generate':
_ = model.generate(*args, **kwargs)
else:
_ = model(*args)
if mode == 'forward':
_ = model(*args)
if mode == 'generate':
_ = model.generate(*args)

flops = prof.get_total_flops()
macs = prof.get_total_macs()
Expand Down