-
Notifications
You must be signed in to change notification settings - Fork 4.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[fx/profiler] provide a table of summary. #1634
Merged
Merged
Changes from 3 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
8a42a8c
[fx/profiling] provide summary for MetaInfoProp.
super-dainiu 905eece
[fx/profiler] provide a table of summary.
super-dainiu 95fcbcc
[fx/profiler] provide a table of summary.
super-dainiu 4b326df
[fx/profiler] provide a table of summary.
super-dainiu 6056be1
[fx/profiler] provide a table of summary.
super-dainiu 58cd60b
[fx] optimize table repr.
super-dainiu bdec890
[fx] optimize table repr.
super-dainiu File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass | ||
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass | ||
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass | ||
from .meta_info_prop import MetaInfoProp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
import torch.fx | ||
from torch.fx.node import Node, Argument, Target | ||
from torch.utils._pytree import tree_map | ||
from typing import Any, Tuple, NamedTuple, Dict | ||
from typing import Any, List, Tuple, NamedTuple, Dict | ||
from torch.fx._compatibility import compatibility | ||
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size | ||
|
||
|
@@ -48,28 +48,33 @@ class MetaInfoProp(torch.fx.Interpreter): | |
Usage: | ||
BATCH_SIZE = 2 | ||
DIM_IN = 4 | ||
DIM_HIDDEN = 16 | ||
DIM_OUT = 16 | ||
model = torch.nn.Linear(DIM_IN, DIM_OUT) | ||
model = torch.nn.Sequential( | ||
torch.nn.Linear(DIM_IN, DIM_HIDDEN), | ||
torch.nn.Linear(DIM_HIDDEN, DIM_OUT), | ||
) | ||
input_sample = torch.rand(BATCH_SIZE, DIM_IN) | ||
orig_output = model(input_sample) | ||
gm = symbolic_trace(model) | ||
MetaInfoProp(gm).run(input_sample) | ||
|
||
for node in gm.graph.nodes: | ||
print(node.name, node.meta['tensor_meta'].dtype, | ||
node.meta['tensor_meta'].shape, node.meta['tensor_meta'].numel) | ||
interp = MetaInfoProp(gm) | ||
interp.run(input_sample) | ||
print(interp.summary(format='kb')) # don't panic if some statistics are 0.00 MB | ||
|
||
|
||
# output of above code is | ||
# input_1 torch.float32 torch.Size([2, 4]) 8 | ||
# weight torch.float32 torch.Size([16, 4]) 64 | ||
# bias torch.float32 torch.Size([16]) 16 | ||
# linear torch.float32 torch.Size([2, 16]) 32 | ||
# output torch.float32 torch.Size([2, 16]) 32 | ||
Op type Op Forward FLOPs Backward FLOPs SAVE_FWD_IN FWD_OUT FWD_TMP BWD_OUT BWD_TMP | ||
----------- ------- --------------- ---------------- ------------- --------- --------- --------- --------- | ||
placeholder input_1 0.00e+00 FLOPs 0.00e+00 FLOPs False 0.00 KB 0.00 KB 0.00 KB 0.00 KB | ||
call_module _0 1.28e+02 FLOPs 2.88e+02 FLOPs True 0.12 KB 0.00 KB 0.34 KB 0.00 KB | ||
call_module _1 5.12e+02 FLOPs 1.06e+03 FLOPs True 0.12 KB 0.00 KB 1.19 KB 0.00 KB | ||
output output 0.00e+00 FLOPs 0.00e+00 FLOPs True 0.00 KB 0.00 KB 0.00 KB 0.00 KB | ||
Args: | ||
module (GraphModule): The module to be executed | ||
|
||
""" | ||
|
||
_is_proped: bool = False | ||
|
||
@compatibility(is_backward_compatible=True) | ||
def run_node(self, n: Node) -> Any: | ||
""" | ||
|
@@ -84,6 +89,7 @@ def run_node(self, n: Node) -> Any: | |
Returns: | ||
Any: The result of executing ``n`` | ||
""" | ||
self._is_proped = True | ||
result, meta_info = super().run_node(n) | ||
|
||
def extract_tensor_meta(obj): | ||
|
@@ -236,3 +242,64 @@ def propagate(self, *args): | |
Any: The value returned from executing the Module | ||
""" | ||
return super().run(*args) | ||
|
||
def summary(self, format: str = 'MB') -> str:
    """
    Summarize the memory and FLOPs statistics of the traced ``GraphModule``
    in tabular form.

    Requires the third-party ``tabulate`` module, and requires that
    ``interp.run(input)`` has been called first so every node's ``meta``
    dict has been populated by ``run_node``.

    Args:
        format (str): memory unit used for the table columns; one of
            'KB', 'MB', 'GB', 'TB' (case-insensitive). Defaults to 'MB'.
            NOTE(review): ``unit`` would be a clearer name, but renaming
            would break keyword callers, and ``format`` already shadows
            the builtin — kept as-is for backward compatibility.

    Returns:
        str: the formatted summary table.

    Raises:
        ImportError: if ``tabulate`` is not installed.
        ValueError: if ``format`` is not a recognized memory unit.
    """
    # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
    try:
        from tabulate import tabulate
    except ImportError:
        print("`summary` relies on the library `tabulate`, "
              "which could not be found on this machine. Run `pip "
              "install tabulate` to install the library.")
        # Fix: the original fell through here and crashed later with a
        # confusing NameError on `tabulate`; re-raise the real cause.
        raise

    # We need at least one meta-propagation run before a summary is meaningful.
    assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."

    # Validate the unit up front so a typo raises a clear ValueError
    # instead of a bare KeyError deep inside the row formatter.
    unit_divisor_map = {
        'kb': 1024,
        'mb': 1024**2,
        'gb': 1024**3,
        'tb': 1024**4,
    }
    unit_key = format.lower()
    if unit_key not in unit_divisor_map:
        raise ValueError(f"Unrecognized memory unit {format!r}; expected one of "
                         f"{sorted(unit_divisor_map)} (case-insensitive).")
    divisor = unit_divisor_map[unit_key]

    def mem_repr(mem: int) -> str:
        # ' .2f' (leading space) preserved to keep column output identical.
        return f"{mem / divisor: .2f} {format.upper()}"

    def flops_repr(flop: int) -> str:
        return f"{flop:.2e} FLOPs"

    # Build up a list of summary information for each node.
    node_summaries: List[List[Any]] = []
    for node in self.module.graph.nodes:
        node: Node
        node_summaries.append([
            node.op,
            str(node),
            flops_repr(node.meta['fwd_flop']),
            flops_repr(node.meta['bwd_flop']),
            node.meta['save_fwd_in'],
            mem_repr(node.meta['fwd_mem_out']),
            mem_repr(node.meta['fwd_mem_tmp']),
            mem_repr(node.meta['bwd_mem_out']),
            mem_repr(node.meta['bwd_mem_tmp']),
        ])

    # Use the ``tabulate`` library to create a well-formatted table
    # presenting our summary information.
    headers: List[str] = [
        'Op type',
        'Op',
        'Forward FLOPs',
        'Backward FLOPs',
        'SAVE_FWD_IN',
        'FWD_OUT',
        'FWD_TMP',
        'BWD_OUT',
        'BWD_TMP',
    ]

    return tabulate(node_summaries, headers=headers)
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should it be `unit` instead of `format`?

There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.

yes