Skip to content
Permalink
Browse files

Match input and model device

  • Loading branch information
sytelus committed Feb 5, 2020
1 parent 65577dd commit 51a198d4cd3c8d6792c9bab48f58d520d3d46afc
Showing with 4 additions and 3 deletions.
  1. +1 −0 tensorwatch/model_graph/torchstat/model_hook.py
  2. +3 −3 test/pre_train/model_stats.py
@@ -21,6 +21,7 @@ def __init__(self, model, input_size):

self._hook_model()
x = torch.rand(*self._input_size) # add module duration time
x = x.to(next(model.parameters()).device)
self._model.eval()
self._model(x)

@@ -1,9 +1,9 @@
import tensorwatch as tw
import torchvision.models

model_names = ['alexnet'] #, 'resnet18', 'resnet34','densenet121']
model_names = ['alexnet', 'resnet18', 'resnet34', 'resnet101', 'densenet121']

for model_name in model_names:
model = getattr(torchvision.models, model_name)()
df = tw.model_stats(model, [1, 3, 224, 224])
print(df)
model_stats = tw.ModelStats(model, [1, 3, 224, 224], clone_model=False)
print(f'{model_name}: flops={model_stats.Flops}, parameters={model_stats.parameters}, memory={model_stats.inference_memory}')

0 comments on commit 51a198d

Please sign in to comment.
You can’t perform that action at this time.