[fx/tuning] tune performance on rotor with meta info. #1599

Merged · 4 commits · Sep 15, 2022
83 changes: 14 additions & 69 deletions colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
@@ -1,8 +1,7 @@
from typing import List, Tuple
import torch
from torch.fx import GraphModule, Node
from torch.fx import Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import parameter_size
from colossalai.fx.profiler import activation_size, parameter_size
import math
from .linearize import linearize
from .utils import *
@@ -31,7 +30,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
# Build table
opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
## Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation

# Initialize borders of the tables for lmax-lmin = 0
for m in range(mmax + 1):
@@ -115,43 +114,6 @@ def _discretize(mem_unit, values):
return [math.ceil(value / mem_unit) for value in values]


def _compute_size(obj: torch.Tensor) -> int:
return obj.numel() * obj.element_size()


def _compute_output_size(node: List[Node]) -> int:
"""Compute the output size of a node

Args:
node (List[Node]): node, list of torch.fx.Node

Returns:
int: output size
"""

return node[-1].meta['tensor_meta'].numel * torch.tensor([],
dtype=node[-1].meta['tensor_meta'].dtype).element_size()


def _get_inplace(node: Node) -> bool:
"""Get the inplace argument from torch.fx.Node

Args:
node (Node): torch.fx.Node

Returns:
bool: indicates whether this op is inplace
"""

is_inplace = False
if node.op == "call_function":
is_inplace = node.kwargs.get("inplace", False)
elif node.op == "call_module":
is_inplace = getattr(node.graph.owning_module.get_submodule(node.target), "inplace", False)

return is_inplace


def _fwd_xbar(node: List[Node]) -> int:
"""Get the forward xbar of a node

@@ -221,61 +183,44 @@ def _get_deps_size():
for k, v in deps.items():
if v > 0:
deps_size += k.meta['bwd_mem_out']
if v == float('-inf'):
deps_size -= k.meta['fwd_mem_tmp'] + k.meta['fwd_mem_out']

return deps_size

bwd_mem_tmp = 0
deps = {}

# add all the users for last node into deps,
# as those nodes' gradient out will be stored in memory
for child in node[-1].users:
deps[child] = 1
for n in reversed(node):
deps[n] = len(n.all_input_nodes)
bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])

deps[n] = len(n.all_input_nodes)
for child in n.users:
if child in deps:
deps[child] -= 1

for key in list(deps.keys()):
if deps[key] == 0:
del deps[key]
if deps[child] <= 0:
deps[child] = float('-inf') # free

return bwd_mem_tmp


def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
def _construct_chain(node_list: List[List[Node]], input, mem_unit: int) -> Chain:

fwd_time = []
bwd_time = []

if isinstance(data, torch.Tensor):
xbar_sizes = [_compute_size(data)]
x_sizes = [_compute_size(data)]
elif isinstance(data, list) or isinstance(data, tuple):
xbar_sizes = [sum([_compute_size(obj) for obj in data])]
x_sizes = [sum([_compute_size(obj) for obj in data])]
elif isinstance(data, dict):
xbar_sizes = [sum([_compute_size(obj) for obj in data.values()])]
x_sizes = [sum([_compute_size(obj) for obj in data.values()])]

xbar_sizes = [activation_size(input)]
x_sizes = [activation_size(input)]
# currently we can't get the temp memory needed in fwd
tmp_fwd = [0] * len(node_list)
tmp_bwd = []

for idx, node in enumerate(node_list):
fwd_time.append(_fwd_time(node))
bwd_time.append(_bwd_time(node))
x_sizes.append(_compute_output_size(node))
x_sizes.append(node[-1].meta['fwd_mem_out'])
xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
tmp_bwd.append(_get_bwd_mem_tmp(node))

# if a node with only one inplace op, we need to let x_bar = 0
if len(node) == 1 and _get_inplace(node[0]):
xbar_sizes[-1] = 0

bwd_time.append(0)

# currently we view loss backward temp as zero
@@ -381,7 +326,7 @@ def solver_rotor(gm: ColoGraphModule,
mem_limit: int,
mem_slots: int = 500,
cnode: List[str] = None,
eps: float = 0.02) -> ColoGraphModule:
eps: float = 0.0) -> ColoGraphModule:
Contributor: Should eps be a very small but non-zero value? e.g. 1e-6

Contributor: So 0.0 means no memory decay?

Contributor: Yes, the default setting is 0.0.

Contributor: ok

Contributor: And the memory decay is calculated by $M(1 - \epsilon)$; maybe the variable name is not that appropriate?

Contributor: Actually the eps will be something around 0.05 or less; 1e-6 is too small, as the memory will be discretized.

Contributor (Author): So actually the decay is unnecessary if I can estimate the memory accurately. This can be removed in the future once I have tested the performance of all models.

Contributor: I think eps is ok; you can just provide the equation for memory decay in line 338 to explain how eps affects memory decay.

Contributor: I think this option could be provided for the user, as we might not be able to catch up with all the models in reality, so there might be some cases where our meta info gives bad estimations. With this option the user would be able to tune the solver if necessary.
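The decay discussed in this thread amounts to scaling the memory budget by $1 - \epsilon$ before solving. A minimal sketch of that idea, assuming the solver simply shrinks the budget once — the helper name `apply_memory_decay` is hypothetical and not part of this PR:

```python
def apply_memory_decay(mem_limit: int, eps: float = 0.0) -> int:
    """Hypothetical helper: shrink the budget M to M * (1 - eps).

    With eps = 0.0 (the new default) the budget is used as-is; a value
    around 0.05 keeps a safety margin for inaccurate memory estimation.
    """
    assert 0.0 <= eps < 1.0, "eps should be a fraction of the budget"
    return int(mem_limit * (1 - eps))


# e.g. a 2 GiB budget with a 5% safety margin
decayed_budget = apply_memory_decay(2 * 1024**3, eps=0.05)
```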

"""solver that automatically find activation checkpoint in rotor's manner

Args:
@@ -390,7 +335,7 @@ def solver_rotor(gm: ColoGraphModule,
mem_limit (int): memory budget in Byte.
mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
cnode (List[Node], optional): common node list for linearize. Defaults to None.
eps (float): epsilon for memory decay. Defaults to 0.02
eps (float): epsilon for memory decay. Defaults to 0.0

Returns:
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
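For context, the removed `_compute_size` / `_compute_output_size` helpers are superseded by `activation_size` from `colossalai.fx.profiler`, which, as used in this PR, sums the bytes of every tensor in a (possibly nested) output. A simplified stand-in for illustration only, not the library implementation:

```python
from typing import Any
import torch


def tensor_bytes(x: Any) -> int:
    """Simplified stand-in for activation_size: total bytes of all tensors
    contained in a tensor, list/tuple, or dict."""
    if isinstance(x, torch.Tensor):
        return x.numel() * x.element_size()
    if isinstance(x, (list, tuple)):
        return sum(tensor_bytes(v) for v in x)
    if isinstance(x, dict):
        return sum(tensor_bytes(v) for v in x.values())
    return 0


# e.g. a float32 activation of shape (4, 3, 224, 224)
print(tensor_bytes(torch.empty(4, 3, 224, 224)))  # 2408448 bytes
```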
6 changes: 5 additions & 1 deletion colossalai/fx/passes/algorithms/linearize.py
@@ -1,5 +1,6 @@
from typing import List, Any
from torch.fx import GraphModule, Node
from colossalai.fx.profiler import is_inplace

# Common nodes are type of nodes that could be seen as attributes and remain
# unchanged throughout the whole model, it will be used several times by
@@ -41,6 +42,9 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
Returns:
List[List[Node]]: List of list, each inside list of Node presents
the actual 'node' in linearized manner.

Remarks:
We merge the inplace ops into the previous node.
"""

def _is_sink() -> bool:
@@ -50,7 +54,7 @@ def _is_sink() -> bool:
bool
"""

return not sum([v for _, v in deps.items()])
return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
Contributor: Could you add a simple example here to show the difference between the new linearize and the older version?

Contributor (Author):
[input] 15 15 15
[conv1] 78 78 78
[bn1, relu] 78 78 78
[maxpool] 20 78 59
[layer1_0_conv1, layer1_0_bn1, layer1_0_relu, layer1_0_conv2, layer1_0_bn2, add, layer1_0_relu_1] 20 78 59
[layer1_1_conv1, layer1_1_bn1, layer1_1_relu, layer1_1_conv2, layer1_1_bn2, add_1, layer1_1_relu_1] 20 78 39
[layer2_0_conv1, layer2_0_bn1, layer2_0_relu, layer2_0_conv2, layer2_0_bn2, layer2_0_downsample_0, layer2_0_downsample_1, add_2, layer2_0_relu_1] 10 49 30
[layer2_1_conv1, layer2_1_bn1, layer2_1_relu, layer2_1_conv2, layer2_1_bn2, add_3, layer2_1_relu_1] 10 39 20
[layer3_0_conv1, layer3_0_bn1, layer3_0_relu, layer3_0_conv2, layer3_0_bn2, layer3_0_downsample_0, layer3_0_downsample_1, add_4, layer3_0_relu_1] 5 25 15
[layer3_1_conv1, layer3_1_bn1, layer3_1_relu, layer3_1_conv2, layer3_1_bn2, add_5, layer3_1_relu_1] 5 20 10
[layer4_0_conv1, layer4_0_bn1, layer4_0_relu, layer4_0_conv2, layer4_0_bn2, layer4_0_downsample_0, layer4_0_downsample_1, add_6, layer4_0_relu_1] 3 13 12
[layer4_1_conv1, layer4_1_bn1, layer4_1_relu, layer4_1_conv2, layer4_1_bn2, add_7, layer4_1_relu_1] 0 8 3
[avgpool] 0 0 1
[flatten] 1 1 1
[fc] 1 1 0

[input] 15 15 15
[conv1] 78 78 78
[bn1] 0 0 0
[relu] 78 78 78
[maxpool] 20 78 78
[layer1_0_conv1, layer1_0_bn1, layer1_0_relu, layer1_0_conv2, layer1_0_bn2, add] 0 58 0
[layer1_0_relu_1] 20 20 78
[layer1_1_conv1, layer1_1_bn1, layer1_1_relu, layer1_1_conv2, layer1_1_bn2, add_1] 0 58 0
[layer1_1_relu_1] 20 20 49
[layer2_0_conv1, layer2_0_bn1, layer2_0_relu, layer2_0_conv2, layer2_0_bn2, layer2_0_downsample_0, layer2_0_downsample_1, add_2] 0 39 0
[layer2_0_relu_1] 10 10 39
[layer2_1_conv1, layer2_1_bn1, layer2_1_relu, layer2_1_conv2, layer2_1_bn2, add_3] 0 29 0
[layer2_1_relu_1] 10 10 25
[layer3_0_conv1, layer3_0_bn1, layer3_0_relu, layer3_0_conv2, layer3_0_bn2, layer3_0_downsample_0, layer3_0_downsample_1, add_4] 0 20 0
[layer3_0_relu_1] 5 5 20
[layer3_1_conv1, layer3_1_bn1, layer3_1_relu, layer3_1_conv2, layer3_1_bn2, add_5] 0 15 0
[layer3_1_relu_1] 5 5 13
[layer4_0_conv1, layer4_0_bn1, layer4_0_relu, layer4_0_conv2, layer4_0_bn2, layer4_0_downsample_0, layer4_0_downsample_1, add_6] 0 10 0
[layer4_0_relu_1] 3 3 12
[layer4_1_conv1, layer4_1_bn1, layer4_1_relu, layer4_1_conv2, layer4_1_bn2, add_7] 0 8 0
[layer4_1_relu_1] 0 0 3
[avgpool] 0 0 1
[flatten] 1 1 1
[fc] 1 1 0
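The two listings above presumably show the grouping with and without merging inplace ops: note how `relu` and the `*_relu_1` nodes are fused into the preceding groups in the first listing. To illustrate the merging rule on a small model, here is a minimal sketch — the toy model, the grouping loop, and the printed names are illustrative only, not the linearize implementation itself:

```python
import torch.nn as nn
from torch.fx import symbolic_trace
from colossalai.fx.profiler import is_inplace

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(inplace=True))
gm = symbolic_trace(model)

# Mimic the new sink rule: a group only closes when none of the current
# node's users is an inplace op, so inplace ops stay with their producer.
groups, current = [], []
for n in gm.graph.nodes:
    if n.op in ('placeholder', 'output'):
        continue
    current.append(n)
    if not any(map(is_inplace, n.users)):
        groups.append(current)
        current = []

print([[n.name for n in g] for g in groups])
# e.g. [['_0'], ['_1', '_2']] -- the inplace ReLU is merged with the BatchNorm
```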


# make sure that item in cnode is valid
if cnode:
2 changes: 1 addition & 1 deletion colossalai/fx/profiler/__init__.py
@@ -7,4 +7,4 @@
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module

from .dataflow import GraphInfo
from .memory import parameter_size, activation_size
from .memory import parameter_size, activation_size, is_inplace
28 changes: 16 additions & 12 deletions colossalai/fx/profiler/dataflow.py
@@ -1,16 +1,17 @@
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Dict
from torch.fx import Graph, Node
from .memory import activation_size
from .memory import activation_size, is_inplace
from . import META_COMPATIBILITY
if META_COMPATIBILITY:
from .memory import NORMALIZATION_ATEN, CLONE_ATEN


class Phase(Enum):
FORWARD = 0
LOSS = 1
BACKWARD = 2
PLACEHOLDER = 3
BACKWARD = 1
PLACEHOLDER = 2


@dataclass
@@ -86,8 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
def _peak_memory(deps: Dict[Node, int]):
peak_mem = 0
for k, v in deps.items():
if v > 0:
if v > 0 and is_phase(k, Phase.BACKWARD) and not any(map(is_inplace, k.users)):
peak_mem += activation_size(k.meta['out'])
if v <= float('-inf') and is_saved(k) and (k.target not in NORMALIZATION_ATEN):
peak_mem -= activation_size(k.meta['out'])
return peak_mem

# deps is used to track all the memory dependencies of the graph.
@@ -96,7 +99,7 @@ def _peak_memory(deps: Dict[Node, int]):

for n in graph.nodes:
n: Node
if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)):
if is_saved(n) and (n.target not in NORMALIZATION_ATEN) or any(map(lambda x: x.target in CLONE_ATEN, n.users)):
# A forward tensor who is marked `save` but is not
# an input to `loss` should be saved during forward.
# If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
@@ -110,13 +113,14 @@ def _peak_memory(deps: Dict[Node, int]):
graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
elif is_phase(n, Phase.BACKWARD):
if len(n.users):
# liveness analysis is only used in backward
deps[n] = len(n.users)
graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
for input_n in n.all_input_nodes:
if input_n in deps:
deps[input_n] -= 1
else:
# TODO: some of the bwd_mem_out might be model parameters.
# basically a backward node without user is a `grad_out` node
graph_info.bwd_mem_out += activation_size(n.meta['out'])
for input_n in n.all_input_nodes:
if input_n in deps:
deps[input_n] -= 1
if deps[input_n] <= 0:
deps[input_n] = float('-inf')
return graph_info
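The `deps` bookkeeping in this pass is a reference-count style liveness analysis: each node starts with a counter equal to its number of remaining users, the counter is decremented as consumers are processed, and a counter that reaches zero is marked with `float('-inf')` so the node's output can be treated as freed when measuring peak memory. A generic, self-contained toy version of that idea — made-up node sizes, no torch.fx, and none of the forward/backward phase or inplace handling above:

```python
from typing import Dict, List, Tuple

# Toy graph in topological order: (name, output_bytes, inputs)
graph: List[Tuple[str, int, List[str]]] = [
    ('x',    16, []),
    ('conv', 64, ['x']),
    ('relu', 64, ['conv']),
    ('add',  64, ['x', 'relu']),
]

# Count how many consumers each node's output still has.
remaining_users: Dict[str, int] = {}
for _, _, inputs in graph:
    for i in inputs:
        remaining_users[i] = remaining_users.get(i, 0) + 1

live: Dict[str, int] = {}
peak = 0
for name, size, inputs in graph:
    live[name] = size                      # the output materializes
    peak = max(peak, sum(live.values()))   # track peak memory
    for i in inputs:                       # this node consumes its inputs
        remaining_users[i] -= 1
        if remaining_users[i] <= 0:
            live.pop(i, None)              # all users done -> freed

print(peak)  # 144 bytes for this toy graph
```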
36 changes: 34 additions & 2 deletions colossalai/fx/profiler/memory.py
@@ -1,9 +1,10 @@
import torch
from torch.fx import Node
from typing import Union, Dict, List, Tuple
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
from . import META_COMPATIBILITY

__all__ = ['activation_size', 'parameter_size']
__all__ = ['activation_size', 'parameter_size', 'is_inplace']

if META_COMPATIBILITY:
aten = torch.ops.aten
@@ -21,14 +22,25 @@
aten.bernoulli_.float,

# inplace reshaping
aten.copy_.default,
aten.detach.default,
aten.t.default,
aten.transpose.int,
aten.view.default,
aten._unsafe_view.default,
]

__all__ += ['INPLACE_ATEN', 'WEIRD_OPS']
NORMALIZATION_ATEN = [
aten.native_batch_norm.default,
aten.native_layer_norm.default,
# aten.max_pool2d_with_indices.default,
]

CLONE_ATEN = [
aten.clone.default,
]

__all__ += ['INPLACE_ATEN', 'WEIRD_OPS', 'NORMALIZATION_ATEN', 'CLONE_ATEN']

else:
# TODO fill out the inplace ops
@@ -106,3 +118,23 @@ def parameter_size(mod: torch.nn.Module) -> int:
for param in mod.parameters():
param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
return param_size


def is_inplace(n: Node):
"""Get the inplace argument from torch.fx.Node

Args:
node (Node): torch.fx.Node

Returns:
bool: indicates whether this op is inplace
"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
if META_COMPATIBILITY and n.target in INPLACE_ATEN:
inplace = True
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)

return inplace
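A short usage sketch for the new `is_inplace` helper — tracing a toy module that calls `F.relu` with and without `inplace=True` (the module itself is only an example):

```python
import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace
from colossalai.fx.profiler import is_inplace


class Net(torch.nn.Module):
    def forward(self, x):
        y = F.relu(x, inplace=True)   # recorded with kwargs {'inplace': True}
        return F.relu(y)              # recorded with no inplace kwarg


gm = symbolic_trace(Net())
for node in gm.graph.nodes:
    if node.op == 'call_function':
        print(node.name, is_inplace(node))
# expected: the first relu prints True, the second prints False
```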
1 change: 1 addition & 0 deletions colossalai/fx/profiler/opcount.py
@@ -222,6 +222,7 @@ def zero_flop_jit(*args):
aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
}

elementwise_flop_aten = [