
[fx/profiler] assigned UUID to each unrecorded tensor / improved performance on GPT-2 #1679

Merged: 33 commits, Oct 11, 2022
Commits (33):
8a42a8c
[fx/profiling] provide summary for MetaInfoProp.
super-dainiu Sep 23, 2022
905eece
[fx/profiler] provide a table of summary.
super-dainiu Sep 23, 2022
95fcbcc
[fx/profiler] provide a table of summary.
super-dainiu Sep 23, 2022
4b326df
[fx/profiler] provide a table of summary.
super-dainiu Sep 23, 2022
6056be1
[fx/profiler] provide a table of summary.
super-dainiu Sep 23, 2022
58cd60b
[fx] optimize table repr.
super-dainiu Sep 23, 2022
bdec890
[fx] optimize table repr.
super-dainiu Sep 23, 2022
2f24d4b
[fx] refactor code for profiler.
super-dainiu Sep 26, 2022
754f64a
Merge branch 'hpcaitech:main' into speedup
super-dainiu Sep 26, 2022
4d13492
[fx] add docstring.
super-dainiu Sep 26, 2022
ff93220
Merge branch 'speedup' of https://github.com/super-dainiu/ColossalAI …
super-dainiu Sep 26, 2022
71b93e4
[fx] add docstring.
super-dainiu Sep 26, 2022
2b1bd26
[fx] skip test.
super-dainiu Sep 26, 2022
e57fbc4
[fx] redo.
super-dainiu Sep 26, 2022
a6210d2
[fx] redo.
super-dainiu Sep 26, 2022
66e8546
[fx] fix import error for torch11.
super-dainiu Sep 26, 2022
0e37fc5
[fx] fix import error for torch11.
super-dainiu Sep 26, 2022
3e25419
[hotfix] fix singledispatchmethod incompatibility.
super-dainiu Sep 27, 2022
7993aff
Merge branch 'hpcaitech:main' into main
super-dainiu Sep 27, 2022
d358df5
[hotfix] fix singledispatchmethod incompatibility.
super-dainiu Sep 27, 2022
3ec8061
[fx/profiler] modify data_ptr into uuid for all tensors.
super-dainiu Oct 3, 2022
ff8e832
[fx] modify uuid.
super-dainiu Oct 4, 2022
db0c211
Merge branch 'hpcaitech:main' into speedup
super-dainiu Oct 4, 2022
9f206d7
Merge branch 'speedup' of https://github.com/super-dainiu/ColossalAI …
super-dainiu Oct 4, 2022
ab96506
Merge branch 'hpcaitech:main' into speedup
super-dainiu Oct 5, 2022
fab23e3
[fx/profiler] tune performance on GPT-2.
super-dainiu Oct 5, 2022
49defef
[fx] merge upstreams.
super-dainiu Oct 5, 2022
771c02e
[fx] updates.
super-dainiu Oct 5, 2022
016a94b
[fx] debug.
super-dainiu Oct 6, 2022
abe66ef
[fx] debug.
super-dainiu Oct 6, 2022
283d121
Merge branch 'hpcaitech:main' into speedup
super-dainiu Oct 6, 2022
47c8381
Merge branch 'speedup' of https://github.com/super-dainiu/ColossalAI …
super-dainiu Oct 6, 2022
a405f60
[fx] cuda.
super-dainiu Oct 10, 2022
5 changes: 3 additions & 2 deletions colossalai/fx/passes/algorithms/ckpt_solver_chen.py
@@ -2,6 +2,7 @@
import torch
from torch.fx import GraphModule, Node
import math
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp

__all__ = ['chen_greedy']
CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
@@ -74,10 +75,10 @@ def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
prev_idx = 2
for (idx, n) in enumerate(gm.graph.nodes):
n: Node
temp += n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp']
temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
y = max(y, temp)
if temp > b and n in ckpt_nodes:
x += n.meta['fwd_mem_out']
x += calculate_fwd_in(n)
temp = 0
ckpt_intv.append((prev_idx, idx + 1))
prev_idx = idx + 1
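The accumulate-and-cut pattern in run_chen_greedy above can be illustrated with a simplified, hypothetical sketch (plain byte counts stand in for calculate_fwd_in(n) + calculate_fwd_tmp(n); the real pass additionally tracks the checkpointable node set and the peak memory y):

def sketch_chen_greedy(mem_per_node, budget):
    # Accumulate per-node forward memory and cut a checkpoint interval
    # whenever the running total exceeds the budget.
    intervals, temp, prev = [], 0, 0
    for idx, m in enumerate(mem_per_node):
        temp += m
        if temp > budget:
            intervals.append((prev, idx + 1))
            temp, prev = 0, idx + 1
    return intervals

print(sketch_chen_greedy([2, 3, 1, 4, 2], budget=5))  # [(0, 3), (3, 5)]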
32 changes: 22 additions & 10 deletions colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
@@ -1,8 +1,9 @@
import sys
from typing import List, Tuple
from colossalai.fx.profiler.memory import calculate_fwd_in
from torch.fx import Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, parameter_size
from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp
import math
from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
@@ -124,9 +125,7 @@ def _fwd_xbar(node: List[Node]) -> int:

xbar = 0
for n in node:
xbar += n.meta['fwd_mem_tmp']
if any(map(lambda x: x.meta['save_fwd_in'], n.users)):
xbar += n.meta['fwd_mem_out']
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
return xbar


@@ -166,6 +165,21 @@ def _bwd_time(node: List[Node]) -> int:
return bwd_time


def _get_fwd_mem_tmp(node: List[Node]) -> int:
"""Get the forward temp memory of a node
This is computed by subtracting the saved activations from all outputs of the node

Args:
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph

Returns:
int: forward temp memory, unit Byte
"""
n = node[-1]
return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
Comment on lines +168 to +180 (Contributor Author):
fwd_out tensors that are not saved for the next node may now be regarded as fwd_tmp.
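A minimal, hypothetical sketch of that relationship (plain byte counts stand in for the activation_size results used by _get_fwd_mem_tmp):

def sketch_fwd_mem_tmp(total_fwd_out_bytes, saved_fwd_out_bytes):
    # Outputs that no user saves as its fwd_in are temporary: they can be
    # freed after the forward pass, so they count as fwd_tmp, not fwd_out.
    return total_fwd_out_bytes - saved_fwd_out_bytes

assert sketch_fwd_mem_tmp(total_fwd_out_bytes=4096, saved_fwd_out_bytes=1024) == 3072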



def _get_bwd_mem_tmp(node: List[Node]) -> int:
"""Get the backward temp memory of a node

@@ -184,9 +198,7 @@ def _get_deps_size():
if v > 0:
deps_size += k.meta['bwd_mem_out']
if v == float('-inf'):
deps_size -= k.meta['fwd_mem_tmp']
if any(map(lambda x: x.meta['save_fwd_in'], k.users)):
deps_size -= k.meta['fwd_mem_out']
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)

return deps_size

@@ -212,15 +224,15 @@ def _construct_chain(node_list: List[List[Node]], input) -> Chain:
bwd_time = []
xbar_sizes = [activation_size(input)]
x_sizes = [activation_size(input)]
# currently we can't get the temp memory needed in fwd
tmp_fwd = [0] * len(node_list)
tmp_fwd = []
tmp_bwd = []

for idx, node in enumerate(node_list):
fwd_time.append(_fwd_time(node))
bwd_time.append(_bwd_time(node))
x_sizes.append(node[-1].meta['fwd_mem_out'])
x_sizes.append(calculate_fwd_out(node[-1]))
xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
tmp_fwd.append(_get_fwd_mem_tmp(node))
tmp_bwd.append(_get_bwd_mem_tmp(node))

bwd_time.append(0)
27 changes: 14 additions & 13 deletions colossalai/fx/passes/meta_info_prop.py
@@ -1,12 +1,11 @@
from dataclasses import asdict
from colossalai.fx.profiler import GraphInfo
import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from torch.utils._pytree import tree_map
from typing import Any, List, Tuple, NamedTuple, Dict
from torch.fx._compatibility import compatibility
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
from colossalai.fx.profiler import GraphInfo, profile_function, profile_module, profile_method, activation_size, calculate_fwd_out, calculate_fwd_tmp, calculate_fwd_in


@compatibility(is_backward_compatible=True)
@@ -62,12 +61,12 @@ class MetaInfoProp(torch.fx.Interpreter):


# output of above code is
Op type Op Forward FLOPs Backward FLOPs SAVE_FWD_IN FWD_OUT FWD_TMP BWD_OUT BWD_TMP
----------- ------- --------------- ---------------- ------------- --------- --------- --------- ---------
placeholder input_1 0 FLOPs 0 FLOPs False 0.00 KB 0.00 KB 0.00 KB 0.00 KB
call_module _0 128 FLOPs 288 FLOPs True 0.12 KB 0.00 KB 0.34 KB 0.00 KB
call_module _1 512 FLOPs 1,056 FLOPs True 0.12 KB 0.00 KB 1.19 KB 0.00 KB
output output 0 FLOPs 0 FLOPs True 0.00 KB 0.00 KB 0.00 KB 0.00 KB
Op type Op Forward FLOPs Backward FLOPs FWD_OUT FWD_TMP BWD_OUT BWD_TMP
----------- ------- --------------- ---------------- --------- --------- --------- ---------
placeholder input_1 0 FLOPs 0 FLOPs 0.00 KB 0.00 KB 0.00 KB 0.00 KB
call_module _0 128 FLOPs 288 FLOPs 0.12 KB 0.00 KB 0.34 KB 0.00 KB
call_module _1 512 FLOPs 1,056 FLOPs 0.12 KB 0.00 KB 1.19 KB 0.00 KB
output output 0 FLOPs 0 FLOPs 0.00 KB 0.00 KB 0.00 KB 0.00 KB
Args:
module (GraphModule): The module to be executed

@@ -102,7 +101,7 @@ def extract_tensor_meta(obj):
n.meta['tensor_meta'] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
setattr(n, 'node_size', activation_size(n.meta.get('fwd_in', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
n.meta['type'] = type(result)

# retain the autograd graph
@@ -228,6 +227,8 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str,
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
if hasattr(args[0], '_tensor'):
return args[0], GraphInfo(fwd_in=[args[0]._tensor])
return args[0], GraphInfo(save_fwd_in=True)

def propagate(self, *args):
@@ -281,9 +282,9 @@ def flops_repr(flop: int) -> str:
str(node),
flops_repr(node.meta['fwd_flop']),
flops_repr(node.meta['bwd_flop']),
node.meta['save_fwd_in'],
mem_repr(node.meta['fwd_mem_out']),
mem_repr(node.meta['fwd_mem_tmp']),
mem_repr(calculate_fwd_in(node)),
mem_repr(calculate_fwd_out(node)),
mem_repr(calculate_fwd_tmp(node)),
mem_repr(node.meta['bwd_mem_out']),
mem_repr(node.meta['bwd_mem_tmp']),
])
@@ -295,7 +296,7 @@ def flops_repr(flop: int) -> str:
'Op',
'Forward FLOPs',
'Backward FLOPs',
'SAVE_FWD_IN',
'FWD_IN',
'FWD_OUT',
'FWD_TMP',
'BWD_OUT',
3 changes: 2 additions & 1 deletion colossalai/fx/profiler/__init__.py
@@ -3,8 +3,9 @@
from .opcount import flop_mapping
from .tensor import MetaTensor
from .profiler import profile_function, profile_method, profile_module
from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out

from .dataflow import GraphInfo
from .memory import parameter_size, activation_size, is_inplace
21 changes: 13 additions & 8 deletions colossalai/fx/profiler/dataflow.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from typing import Dict
from typing import Dict, List
from torch.fx import Graph, Node
from .memory import activation_size, is_inplace

@@ -39,16 +39,25 @@ class GraphInfo:
bwd_flop (int): The backward FLOPs of a certain node.
bwd_time (float): The real backward time (s) of a certain node.
save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
fwd_in (List): See the above illustration.
fwd_tmp (List): See the above illustration.
fwd_out (List): See the above illustration.
fwd_mem_tmp (int): See the above illustration.
fwd_mem_out (int): See the above illustration.
bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration.
"""

# TODO(super-dainiu): remove redundant items; currently all of them are necessary for development

fwd_flop: int = 0
fwd_time: float = 0.0
bwd_flop: int = 0
bwd_time: float = 0.0
save_fwd_in: bool = False
fwd_in: List = field(default_factory=list)
fwd_tmp: List = field(default_factory=list)
fwd_out: List = field(default_factory=list)
fwd_mem_tmp: int = 0
fwd_mem_out: int = 0
bwd_mem_tmp: int = 0
Comment on lines +51 to +63 (Contributor Author):
These fields are kept just for compatibility across the different InfoProp passes.

Comment on lines +51 to +63 (Contributor Author):
All forward tensors recorded by MetaInfoProp are now saved in a List.
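A minimal usage sketch of the list-based fields, assuming ColossalAI at this revision and torch >= 1.12 are installed; byte counts are now derived from the saved tensor lists on demand rather than stored as scalars:

import torch
from colossalai.fx.profiler import GraphInfo, activation_size

# fwd_in / fwd_tmp / fwd_out now hold the saved tensors themselves.
info = GraphInfo(fwd_in=[torch.empty(4, 4)], fwd_tmp=[], fwd_out=[torch.empty(4, 4)])
print(activation_size(info.fwd_in), activation_size(info.fwd_out))  # 64 bytes each for a 4x4 float32 tensor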

@@ -60,10 +69,6 @@ def is_phase(n: Node, phase: Phase) -> bool:
return n.meta['phase'] == phase


def is_saved(n: Node):
return len(n.meta['saved_tensor'])


def autograd_graph_analysis(graph: Graph) -> GraphInfo:
"""Analyze the autograd node dependencies and find out the memory usage.
Basically the input graph should have all nodes marked for keyword `phase`.
@@ -113,9 +118,9 @@ def _peak_memory(deps: Dict[Node, int]):
# Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
# the node, `fwd_mem_tmp` can be freed.
if is_phase(n, Phase.PLACEHOLDER):
graph_info.save_fwd_in |= activation_size(n.meta['saved_tensor']) > 0
graph_info.fwd_in += n.meta['saved_tensor']
if is_phase(n, Phase.FORWARD):
graph_info.fwd_mem_tmp += activation_size(n.meta['saved_tensor'])
graph_info.fwd_tmp += n.meta['saved_tensor']
elif is_phase(n, Phase.BACKWARD):
if len(n.users):
graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
1 change: 1 addition & 0 deletions colossalai/fx/profiler/experimental/__init__.py
@@ -1,4 +1,5 @@
from .registry import meta_profiler_function, meta_profiler_module
from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
from .profiler_function import *
from .profiler_module import *
from .profiler import profile_function, profile_method, profile_module
42 changes: 42 additions & 0 deletions colossalai/fx/profiler/experimental/memory.py
@@ -0,0 +1,42 @@
# for PyTorch 1.11 compatibility uses
import torch
from torch.fx import Node, GraphModule
from typing import Union, Dict, List, Tuple

__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]


def calculate_fwd_in(n: Node) -> bool:
"""A helper function to calculate `fwd_in`

Args:
n (Node): a node from the graph

Returns:
save_fwd_in (bool): the result of `save_fwd_in`
"""
return n.meta['save_fwd_in']


def calculate_fwd_tmp(n: Node) -> int:
"""A helper function to calculate `fwd_tmp`

Args:
n (Node): a node from the graph

Returns:
fwd_tmp (int): the result of `fwd_tmp`
"""
return n.meta["fwd_mem_tmp"]


def calculate_fwd_out(n: Node) -> int:
"""A helper function to calculate `fwd_out`

Args:
n (Node): a node from the graph

Returns:
fwd_out (int): the result of `fwd_out`
"""
return n.meta['fwd_mem_out']
Comment on lines +1 to +42 (Contributor Author):
This module exists for compatibility: MetaInfoProp on torch 1.11 does not save tensors with UUIDs.

63 changes: 60 additions & 3 deletions colossalai/fx/profiler/memory.py
@@ -1,9 +1,11 @@
import torch
from torch.fx import Node
from torch.fx import Node, GraphModule
from typing import Union, Dict, List, Tuple
from . import META_COMPATIBILITY

__all__ = ['activation_size', 'parameter_size', 'is_inplace']
__all__ = [
'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
]


def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
@@ -21,7 +23,7 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
elif isinstance(out, dict):
value_list = [v for _, v in out.items()]
act_size += activation_size(value_list)
elif isinstance(out, tuple) or isinstance(out, list):
elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
for element in out:
act_size += activation_size(element)
return act_size
@@ -42,6 +44,61 @@ def parameter_size(mod: torch.nn.Module) -> int:
return param_size


def calculate_fwd_in(n: Node) -> int:
"""A helper function to calculate `fwd_in`

Args:
n (Node): a node from the graph

Returns:
fwd_in (int): the result of `fwd_in`
"""
return activation_size(n.meta["fwd_in"])


def calculate_fwd_tmp(n: Node) -> int:
"""A helper function to calculate `fwd_tmp`
Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.

Args:
n (Node): a node from the graph

Returns:
fwd_tmp (int): the result of `fwd_tmp`
"""

def is_relu_node(n: Node) -> bool:
if n.op == 'call_function':
return n.target in [torch.nn.functional.relu]
elif n.op == 'call_module':
return type(n.graph.owning_module.get_submodule(n.target)) in [torch.nn.ReLU]
return False

if not is_relu_node(n):
return activation_size(n.meta["fwd_tmp"])
return 0


def calculate_fwd_out(n: Node) -> int:
"""A helper function to calculate `fwd_out`

Args:
n (Node): a node from the graph

Returns:
fwd_out (int): the result of `fwd_out`
"""

def intersect(a, b):
return {k: a[k] for k in a if k in b}

fwd_in = dict()
for u in n.users:
fwd_in.update({x.uuid: x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')})
fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
return activation_size(intersect(fwd_in, fwd_out))


def is_inplace(n: Node):
"""Get the inplace argument from torch.fx.Node

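The UUID-intersection idea behind calculate_fwd_out can be illustrated with a small, hypothetical sketch; simple objects stand in for MetaTensor here, whereas in the PR real tensors carry a uuid attribute assigned by the profiler:

class FakeTensor:
    def __init__(self, uuid, nbytes):
        self.uuid = uuid
        self.nbytes = nbytes

def sketch_fwd_out(node_fwd_out, users_fwd_in):
    # An output only counts toward fwd_out if some user records the same
    # uuid in its fwd_in, i.e. the activation is actually saved downstream.
    saved_ids = {t.uuid for fwd_in in users_fwd_in for t in fwd_in}
    return sum(t.nbytes for t in node_fwd_out if t.uuid in saved_ids)

a, b = FakeTensor(uuid=1, nbytes=4096), FakeTensor(uuid=2, nbytes=1024)
print(sketch_fwd_out([a, b], users_fwd_in=[[a]]))  # 4096: only `a` is saved by a user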
3 changes: 3 additions & 0 deletions colossalai/fx/profiler/opcount.py
@@ -226,6 +226,7 @@ def zero_flop_jit(*args):
aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
aten.embedding.default: elementwise_flop_counter(1, 0),
}

elementwise_flop_aten = [
@@ -304,10 +305,12 @@ def zero_flop_jit(*args):
aten.transpose.int,
aten._to_copy.default,
aten.unsqueeze.default,
aten.unbind.int,
aten._unsafe_view.default,
aten.view.default,
aten.where.self,
aten.zero_.default,
aten.zeros_like.default,
]

for op in zero_flop_aten:
Expand Down