[fx] Add concrete info prop #1677

Merged
merged 44 commits on Oct 4, 2022
Changes from 43 commits
Commits
44 commits
04e5272
Merge pull request #1 from hpcaitech/main
Cypher30 Jul 14, 2022
75618b3
Merge pull request #2 from hpcaitech/main
Cypher30 Jul 15, 2022
3e4620c
Merge pull request #3 from hpcaitech/main
Cypher30 Jul 20, 2022
cf24049
Merge remote-tracking branch 'upstream/main' into main
Jul 20, 2022
3d223b6
Merge remote-tracking branch 'upstream/main' into main
Jul 21, 2022
644115c
Merge branch 'hpcaitech:main' into main
Cypher30 Jul 22, 2022
d995ade
Merge branch 'hpcaitech:main' into main
Cypher30 Jul 25, 2022
bba2dbe
Merge branch 'hpcaitech:main' into main
Cypher30 Jul 26, 2022
05ca628
Merge branch 'hpcaitech:main' into main
Cypher30 Jul 26, 2022
0a967da
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 6, 2022
0637c0d
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 8, 2022
74a6227
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 10, 2022
e550490
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 10, 2022
2d7f5d9
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 11, 2022
b62e870
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 12, 2022
b4b0974
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 15, 2022
65c20de
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 16, 2022
1660bfc
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 17, 2022
6eb0ad0
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 20, 2022
56df059
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 26, 2022
480e932
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 30, 2022
0fa66ee
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 30, 2022
1d013b0
Merge branch 'hpcaitech:main' into main
Cypher30 Aug 31, 2022
5774db2
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 5, 2022
e8ff699
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 6, 2022
855c728
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 7, 2022
2c113ea
Merge branch 'main' of github.com:Cypher30/ColossalAI into main
Sep 8, 2022
838ba70
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 13, 2022
cacec2b
Merge branch 'main' of github.com:Cypher30/ColossalAI into main
Sep 13, 2022
5ed6ef0
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 14, 2022
668af30
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 14, 2022
df79772
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 15, 2022
7b6a0fc
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 20, 2022
c30022e
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 23, 2022
df20f4d
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 26, 2022
2d5a6a0
Merge branch 'hpcaitech:main' into main
Cypher30 Sep 29, 2022
9612ed6
[fx] concreteinfoprop
Sep 29, 2022
5b6aec8
[fx] add concreteinfoprop
Oct 3, 2022
943b59a
Merge branch 'hpcaitech:main' into feature/add_concrete_info_prop
Cypher30 Oct 3, 2022
aac7858
[fx] modify docstring of ConcreteInfoProp
Oct 3, 2022
2f89249
Merge branch 'feature/add_concrete_info_prop' of github.com:Cypher30/…
Oct 3, 2022
95e1354
[fx] fix device error
Oct 3, 2022
23e6f29
[fx] modify parameter calculation
Oct 3, 2022
a7c26d0
[fx] modify parameters calculation
Oct 4, 2022
1 change: 1 addition & 0 deletions colossalai/fx/passes/__init__.py
@@ -1,3 +1,4 @@
from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
from .meta_info_prop import MetaInfoProp
from .concrete_info_prop import ConcreteInfoProp
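With this re-export in place, the new pass can be imported next to the existing MetaInfoProp. A minimal sketch (illustrative only; the surrounding script is an assumption, not part of this diff):

# illustrative: import path enabled by the new re-export in __init__.py
from colossalai.fx.passes import ConcreteInfoProp, MetaInfoProp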
290 changes: 290 additions & 0 deletions colossalai/fx/passes/concrete_info_prop.py
@@ -0,0 +1,290 @@
from dataclasses import asdict
from colossalai.fx.profiler import GraphInfo
import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from torch.utils._pytree import tree_flatten
from typing import Any, List, Tuple, NamedTuple, Dict, Optional
from torch.fx._compatibility import compatibility
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
from torch.fx.graph_module import GraphModule


@compatibility(is_backward_compatible=True)
class ConcreteInfoProp(torch.fx.Interpreter):
"""
Execute an FX graph node-by-node with concrete tensors and record the memory
usage, forward and backward execution time, and the type of the result into
the corresponding node.

Usage:
BATCH_SIZE = 2
DIM_IN = 4
DIM_HIDDEN = 16
DIM_OUT = 16
model = torch.nn.Sequential(
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
).cuda()
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="cuda")
gm = symbolic_trace(model)
interp = ConcreteInfoProp(gm)
interp.run(input_sample)
print(interp.summary(unit='kb'))


The output of the above code is:
Op type Op Forward time Backward time SAVE_FWD_IN FWD_OUT FWD_TMP BWD_OUT BWD_TMP
----------- ------- ----------------------- ------------------------ ------------- --------- --------- --------- ---------
placeholder input_1 0.0 s 0.0 s False 0.00 KB 0.00 KB 0.00 KB 0.00 KB
call_module _0 0.0003993511199951172 s 0.00706791877746582 s False 0.50 KB 0.00 KB 0.03 KB 0.66 KB
call_module _1 6.29425048828125e-05 s 0.00018286705017089844 s False 0.50 KB 0.00 KB 0.12 KB 0.81 KB
output output 0.0 s 0.0 s True 0.00 KB 0.00 KB 0.00 KB 0.00 KB
Args:
module (GraphModule): The module to be executed

"""

_is_proped: bool = False

def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
"""Customized run for ConcreteInfoProp
We need to store the device in self.device

Args:
*args: The arguments to the Module to run, in positional order
initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
This is a dict mapping `Node` to any value. This can be used, for example, to
pre-populate results for certain `Nodes` so as to do only partial evaluation within
the interpreter.
enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
process_outputs function first before using them.

Returns:
Any: The value returned from executing the Module
"""

flatten_args, _ = tree_flatten(args)
self.device = next(item for item in flatten_args if hasattr(item, "device")).device
return super().run(*args, initial_env=initial_env, enable_io_processing=enable_io_processing)

@compatibility(is_backward_compatible=True)
def run_node(self, n: Node) -> Any:
"""
Run a specific node ``n`` and return the result.
Calls into placeholder, get_attr, call_function,
call_method, call_module, or output depending
on ``node.op``

Args:
n (Node): The Node to execute

Returns:
Any: The result of executing ``n``
"""
self._is_proped = True
result, meta_info = super().run_node(n)

n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
n.meta['type'] = type(result)

# clear gradients accumulated by the backward profiling pass to free memory
for param in self.module.parameters():
param.grad = None

return result

# Main Node running APIs
@compatibility(is_backward_compatible=True)
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
arguments passed to ``run`` and this method returns
next() on that iterator.

Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation

Returns:
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and forward & backward time.
"""
return super().placeholder(target, args, kwargs), GraphInfo()

@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.

Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation

Return:
result (Any): The attribute value that was retrieved
meta_info (MetaInfo): The memory cost and forward & backward time.
"""
return super().get_attr(target, args, kwargs), GraphInfo()

@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.

Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation

Return:
result (Any): The value returned by the function invocation
meta_info (MetaInfo): The memory cost and forward & backward time.
"""
assert not isinstance(target, str)
return profile_function(target, self.device)(*args, **kwargs)

@compatibility(is_backward_compatible=True)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.

Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation

Return:
result (Any): The value returned by the method invocation
meta_info (MetaInfo): The memory cost and forward & backward time.
"""
return profile_method(target, self.device)(*args, **kwargs)

@compatibility(is_backward_compatible=True)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.

Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation

Return:
result (Any): The value returned by the module invocation
meta_info (MetaInfo): The memory cost and forward & backward time.
"""
# Retrieve executed args and kwargs values from the environment
# Execute the method and return the result
assert isinstance(target, str)
submod = self.fetch_attr(target)
return profile_module(submod, self.device)(*args, **kwargs)

@compatibility(is_backward_compatible=True)
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.

Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation

Return:
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and forward & backward time.
"""
return args[0], GraphInfo(save_fwd_in=True)

def propagate(self, *args):
"""
Run `module` via interpretation and return the result and
record the shape and type of each node.

Args:
*args (Tensor): the sample input.

Returns:
Any: The value returned from executing the Module
"""
return super().run(*args)

def summary(self, unit: str = 'MB') -> str:
"""
Summarizes the memory and timing statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
to be installed.
"""
# https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
try:
from tabulate import tabulate
except ImportError:
print("`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library.")
raise

assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."

# Build up a list of summary information for each node
node_summaries: List[List[Any]] = []

def mem_repr(mem: int) -> str:
unit_divisor_map = {
'kb': 1024,
'mb': 1024**2,
'gb': 1024**3,
'tb': 1024**4,
}
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"

def time_repr(time: float):
return f"{time:,} s"

for node in self.module.graph.nodes:
node: Node
node_summaries.append([
node.op,
str(node),
time_repr(node.meta['fwd_time']),
time_repr(node.meta['bwd_time']),
node.meta['save_fwd_in'],
mem_repr(node.meta['fwd_mem_out']),
mem_repr(node.meta['fwd_mem_tmp']),
mem_repr(node.meta['bwd_mem_out']),
mem_repr(node.meta['bwd_mem_tmp']),
])

# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
'Op type',
'Op',
'Forward time',
'Backward time',
'SAVE_FWD_IN',
'FWD_OUT',
'FWD_TMP',
'BWD_OUT',
'BWD_TMP',
]

return tabulate(node_summaries, headers=headers, stralign='right')
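For reviewers who want to inspect the per-node metadata directly rather than the summary table, here is a minimal sketch based on the docstring example above. The toy model, tracing via torch.fx.symbolic_trace, and the CUDA device are assumptions for illustration and not part of this diff:

import torch
from torch.fx import symbolic_trace
from colossalai.fx.passes import ConcreteInfoProp

# toy model mirroring the docstring example
model = torch.nn.Sequential(
    torch.nn.Linear(4, 16),
    torch.nn.Linear(16, 16),
).cuda()
gm = symbolic_trace(model)

interp = ConcreteInfoProp(gm)
interp.run(torch.rand(2, 4, device="cuda"))

# run_node() merges asdict(GraphInfo) into node.meta, so the timing and
# memory fields can be read back per node after propagation
for node in gm.graph.nodes:
    print(node.name, node.meta.get('fwd_time'), node.meta.get('fwd_mem_out'))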
7 changes: 6 additions & 1 deletion colossalai/fx/profiler/dataflow.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Dict
from torch.fx import Graph, Node
from .memory import activation_size, is_inplace
@@ -33,16 +34,20 @@ class GraphInfo:
-------------------------------
============================================================================
Attributes:
fwd_flop (int): The forward FLOPs of a certain node.
fwd_time (float): The real forward time (s) of a certain node.
bwd_flop (int): The backward FLOPs of a certain node.
bwd_time (float): The real backward time (s) of a certain node.
save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
fwd_mem_tmp (int): See the above illustration.
fwd_mem_out (int): See the above illustration.
bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration.
"""
fwd_flop: int = 0
fwd_time: float = 0.0
bwd_flop: int = 0
bwd_time: float = 0.0
save_fwd_in: bool = False
fwd_mem_tmp: int = 0
fwd_mem_out: int = 0
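As a quick illustration of how the new timing fields travel from this dataclass into node.meta (via dataclasses.asdict in ConcreteInfoProp.run_node above); the concrete values below are made up:

from dataclasses import asdict
from colossalai.fx.profiler import GraphInfo

# hypothetical timings for a single node; fields not passed keep their defaults
info = GraphInfo(fwd_time=4.0e-4, bwd_time=7.1e-3, save_fwd_in=False)
meta = asdict(info)
# meta now holds the keys that ConcreteInfoProp.summary() reads,
# e.g. meta['fwd_time'] == 0.0004 and meta['bwd_time'] == 0.0071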