Skip to content

Commit 71f00bf

Browse files
committed
Don't run profile if model is torchscripted
1 parent 7da1b0b commit 71f00bf

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

benchmark.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,10 @@ def __init__(
205205
self.num_classes = self.model.num_classes
206206
self.param_count = count_params(self.model)
207207
_logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
208+
self.scripted = False
208209
if torchscript:
209210
self.model = torch.jit.script(self.model)
211+
self.scripted = True
210212

211213
data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
212214
self.input_size = data_config['input_size']
@@ -275,14 +277,14 @@ def _step():
275277
img_size=self.input_size[-1],
276278
param_count=round(self.param_count / 1e6, 2),
277279
)
278-
279-
if has_deepspeed_profiling:
280-
macs, _ = profile_deepspeed(self.model, self.input_size)
281-
results['gmacs'] = round(macs / 1e9, 2)
282-
elif has_fvcore_profiling:
283-
macs, activations = profile_fvcore(self.model, self.input_size)
284-
results['gmacs'] = round(macs / 1e9, 2)
285-
results['macts'] = round(activations / 1e6, 2)
280+
if not self.scripted:
281+
if has_deepspeed_profiling:
282+
macs, _ = profile_deepspeed(self.model, self.input_size)
283+
results['gmacs'] = round(macs / 1e9, 2)
284+
elif has_fvcore_profiling:
285+
macs, activations = profile_fvcore(self.model, self.input_size)
286+
results['gmacs'] = round(macs / 1e9, 2)
287+
results['macts'] = round(activations / 1e6, 2)
286288

287289
_logger.info(
288290
f"Inference benchmark of {self.model_name} done. "

0 commit comments

Comments
 (0)