@@ -147,19 +147,19 @@ def resolve_precision(precision: str):
     return use_amp, model_dtype, data_dtype
 
 
-def profile(model, input_size=(3, 224, 224)):
+def profile(model, input_size=(3, 224, 224), detailed=False):
     batch_size = 1
     macs, params = get_model_profile(
         model=model,
         input_res=(batch_size,) + input_size,  # input shape or input to the input_constructor
         input_constructor=None,  # if specified, a constructor taking input_res is used as input to the model
-        print_profile=False,  # prints the model graph with the measured profile attached to each module
-        detailed=False,  # print the detailed profile
+        print_profile=detailed,  # prints the model graph with the measured profile attached to each module
+        detailed=detailed,  # print the detailed profile
         warm_up=10,  # the number of warm-ups before measuring the time of each module
         as_string=False,  # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k)
         output_file=None,  # path to the output file. If None, the profiler prints to stdout.
         ignore_modules=None)  # the list of modules to ignore in the profiling
-    return macs
+    return macs, params
 
 
 class BenchmarkRunner:
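For reference, a minimal sketch of driving the updated helper directly. It assumes timm and deepspeed are installed and that `profile` (the function changed above, a thin wrapper around DeepSpeed's `get_model_profile`) is in scope, e.g. run from within this module; the model name and input size are illustrative only.

# Illustrative sketch, not part of the commit.
import timm

model = timm.create_model('resnet50', pretrained=False)  # any timm model name works here
model.eval()

# detailed=True now also turns on DeepSpeed's printed per-module breakdown
macs, params = profile(model, input_size=(3, 224, 224), detailed=True)
print(f'{macs / 1e9:.2f} GMACs, {params / 1e6:.2f} M params')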
@@ -258,8 +258,8 @@ def _step():
         )
 
         if get_model_profile is not None:
-            macs = profile(self.model, self.input_size)
-            results['GMACs'] = round(macs / 1e9, 2)
+            macs, _ = profile(self.model, self.input_size)
+            results['gmacs'] = round(macs / 1e9, 2)
 
         _logger.info(
             f"Inference benchmark of {self.model_name} done. "
@@ -388,6 +388,32 @@ def _step(detail=False):
         return results
 
 
+class ProfileRunner(BenchmarkRunner):
+
+    def __init__(self, model_name, device='cuda', **kwargs):
+        super().__init__(model_name=model_name, device=device, **kwargs)
+        self.model.eval()
+
+    def run(self):
+        _logger.info(
+            f'Running profiler on {self.model_name} w/ '
+            f'input size {self.input_size} and batch size 1.')
+
+        macs, params = profile(self.model, self.input_size, detailed=True)
+
+        results = dict(
+            gmacs=round(macs / 1e9, 2),
+            img_size=self.input_size[-1],
+            param_count=round(params / 1e6, 2),
+        )
+
+        _logger.info(
+            f"Profile of {self.model_name} done. "
+            f"{results['gmacs']:.2f} GMACs, {results['param_count']:.2f} M params.")
+
+        return results
+
+
 def decay_batch_exp(batch_size, factor=0.5, divisor=16):
     out_batch_size = batch_size * factor
     if out_batch_size > divisor:
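A rough usage sketch for the new runner, constructed directly rather than through the `--bench profile` CLI path. It assumes a CUDA device is available and that the remaining BenchmarkRunner keyword arguments all have workable defaults; it is not part of the commit.

# Hypothetical direct use of ProfileRunner; the model name is illustrative.
runner = ProfileRunner('resnet50', device='cuda')
results = runner.run()
# Expected keys, per run() above: gmacs, img_size, param_count
print(results)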
@@ -436,6 +462,9 @@ def benchmark(args):
     elif args.bench == 'train':
         bench_fns = TrainBenchmarkRunner,
         prefixes = 'train',
+    elif args.bench == 'profile':
+        assert get_model_profile is not None, "deepspeed needs to be installed for profile"
+        bench_fns = ProfileRunner,
 
     model_results = OrderedDict(model=model)
     for prefix, bench_fn in zip(prefixes, bench_fns):
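One detail worth noting, based on my reading of the surrounding, unchanged benchmark() code: the 'profile' branch reassigns bench_fns but leaves prefixes at its default, so the loop above still prefixes the ProfileRunner result keys with 'infer_'. That is why the sort key added further down is 'infer_gmacs'. A sketch of that assumed key handling, with illustrative values:

# Assumption: the default prefixes value is ('infer',) and the profile branch keeps it.
prefixes = ('infer',)
run_results = {'gmacs': 4.09, 'img_size': 224, 'param_count': 25.56}  # illustrative numbers
prefix = prefixes[0]
if prefix:
    run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
print(run_results)  # {'infer_gmacs': ..., 'infer_img_size': ..., 'infer_param_count': ...}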
@@ -483,7 +512,11 @@ def main():
                 results.append(r)
         except KeyboardInterrupt as e:
             pass
-        sort_key = 'train_samples_per_sec' if 'train' in args.bench else 'infer_samples_per_sec'
+        sort_key = 'infer_samples_per_sec'
+        if 'train' in args.bench:
+            sort_key = 'train_samples_per_sec'
+        elif 'profile' in args.bench:
+            sort_key = 'infer_gmacs'
         results = sorted(results, key=lambda x: x[sort_key], reverse=True)
         if len(results):
             write_results(results_file, results)