Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fx/profiler] provide a table of summary. #1634

Merged
merged 7 commits into from Sep 23, 2022

Conversation

super-dainiu
Copy link
Contributor

@super-dainiu super-dainiu commented Sep 23, 2022

What's new?

I create a summary method for MetaInfoProp. Hopefully, this will help @Cypher30 debug backward memory estimations.

BATCH_SIZE = 2
DIM_IN = 4
DIM_HIDDEN = 16
DIM_OUT = 16
model = torch.nn.Sequential(
          torch.nn.Linear(DIM_IN, DIM_HIDDEN), 
          torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
          )
input_sample = torch.rand(BATCH_SIZE, DIM_IN)
gm = symbolic_trace(model)
interp = MetaInfoProp(gm)
interp.run(input_sample)
print(interp.summary(format='kb'))    # don't panic if some statistics are 0.00 MB

======================================================
# output of above code is 
    Op type       Op    Forward FLOPs    Backward FLOPs    SAVE_FWD_IN    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
-----------  -------  ---------------  ----------------  -------------  ---------  ---------  ---------  ---------
placeholder  input_1          0 FLOPs           0 FLOPs          False    0.00 KB    0.00 KB    0.00 KB    0.00 KB
call_module       _0        128 FLOPs         288 FLOPs           True    0.12 KB    0.00 KB    0.34 KB    0.00 KB
call_module       _1        512 FLOPs       1,056 FLOPs           True    0.12 KB    0.00 KB    1.19 KB    0.00 KB
     output   output          0 FLOPs           0 FLOPs           True    0.00 KB    0.00 KB    0.00 KB    0.00 KB

Warnings

@Cypher30
Obviously, nn.Linear's BWD_OUT should be the same as the previous node's FWD_OUT.
image
It is fairly likely that BWD_OUT also includes a gradient for weight and bias.

"which could not be found on this machine. Run `pip "
"install tabulate` to install the library.")

assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need at least run meta once before getting this summary

@@ -236,3 +242,64 @@ def propagate(self, *args):
Any: The value returned from executing the Module
"""
return super().run(*args)

def summary(self, format: str = 'MB') -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should it be unit instead of format?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

@super-dainiu super-dainiu merged commit 04bbabe into hpcaitech:main Sep 23, 2022
@super-dainiu super-dainiu deleted the feature/meta_table branch September 24, 2022 13:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants