@@ -205,8 +205,10 @@ def __init__(
205205 self .num_classes = self .model .num_classes
206206 self .param_count = count_params (self .model )
207207 _logger .info ('Model %s created, param count: %d' % (model_name , self .param_count ))
208+ self .scripted = False
208209 if torchscript :
209210 self .model = torch .jit .script (self .model )
211+ self .scripted = True
210212
211213 data_config = resolve_data_config (kwargs , model = self .model , use_test_size = not use_train_size )
212214 self .input_size = data_config ['input_size' ]
@@ -275,14 +277,14 @@ def _step():
275277 img_size = self .input_size [- 1 ],
276278 param_count = round (self .param_count / 1e6 , 2 ),
277279 )
278-
279- if has_deepspeed_profiling :
280- macs , _ = profile_deepspeed (self .model , self .input_size )
281- results ['gmacs' ] = round (macs / 1e9 , 2 )
282- elif has_fvcore_profiling :
283- macs , activations = profile_fvcore (self .model , self .input_size )
284- results ['gmacs' ] = round (macs / 1e9 , 2 )
285- results ['macts' ] = round (activations / 1e6 , 2 )
280+ if not self . scripted :
281+ if has_deepspeed_profiling :
282+ macs , _ = profile_deepspeed (self .model , self .input_size )
283+ results ['gmacs' ] = round (macs / 1e9 , 2 )
284+ elif has_fvcore_profiling :
285+ macs , activations = profile_fvcore (self .model , self .input_size )
286+ results ['gmacs' ] = round (macs / 1e9 , 2 )
287+ results ['macts' ] = round (activations / 1e6 , 2 )
286288
287289 _logger .info (
288290 f"Inference benchmark of { self .model_name } done. "
0 commit comments