[fx/profiler] tuned the calculation of memory estimation (#1619)
* [fx] tuned the meta info and rotor solver.

* [fx] remove import.

* [fx] remove import.

* [fx] remove import.

* [fx] tune the meta calculations.

* [fx] polish comments.

* [fx] remove assertions.

* [fx] modify test cases.

* [fx] modify test cases.

* [fx] optimize import.

* [fx
super-dainiu committed Sep 23, 2022
1 parent f7f2248 commit d967779
Showing 16 changed files with 413 additions and 207 deletions.
56 changes: 52 additions & 4 deletions colossalai/_meta_registrations.py
@@ -175,6 +175,11 @@ def meta_hardswish(input: torch.Tensor):
return torch.empty_like(input)


@register_meta(aten.hardtanh.default)
def meta_hardtanh(input: torch.Tensor, min, max):
return torch.empty_like(input)


@register_meta(aten.hardswish_backward.default)
def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor):
grad_in = torch.empty_like(input)
@@ -189,7 +194,7 @@ def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val:

@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
return torch.empty_like(input)
return input


@register_meta(aten.native_batch_norm.default)
@@ -211,13 +216,39 @@ def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
return dX, dgamma, dbeta


@register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.cudnn_batch_norm.default)
def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
n_input = input.size(1)

output = torch.empty_like(input)
running_mean = torch.empty((n_input), device='meta')
running_var = torch.empty((n_input), device='meta')
reserve = torch.empty((0), dtype=torch.uint8, device='meta')
return output, running_mean, running_var, reserve


# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
# NB: CuDNN only implements the backward algorithm for batchnorm
# in training mode (evaluation mode batchnorm has a different algorithm),
# which is why this doesn't accept a 'training' parameter.
@register_meta(aten.cudnn_batch_norm_backward.default)
def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
save_mean, save_invstd, eps, reserve):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight)
return dX, dgamma, dbeta


@register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
bs = input.size(0)
n_input = input.size(1)

output = torch.empty_like(input)
running_mean = torch.empty((bs, n_input, 1), device='meta')
running_var = torch.empty((bs, n_input, 1), device='meta')
return output, running_mean, running_var


@@ -338,6 +369,23 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens
layout=grad_output.layout)


# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
return torch.empty_like(condition)
result_type = torch.result_type(self, other)
return torch.empty_like(self, dtype=result_type)


# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout.default)
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
# notice that mask is bool
output = torch.empty_like(input)
mask = torch.empty_like(input, dtype=torch.bool)
return output, mask


# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout_backward.default)
def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
return torch.empty_like(grad)
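The registrations added above all follow one pattern: each aten overload is mapped to a function that fabricates outputs with the right shapes and dtypes on the meta device, so the profiler can infer memory without running the real kernel. A minimal, self-contained sketch of that pattern (the `meta_table` registry below is hypothetical; the real decorator lives in `colossalai/_meta_registrations.py`):

import torch

aten = torch.ops.aten
meta_table = {}    # hypothetical registry standing in for the real one

def register_meta(op):
    def wrapper(fn):
        meta_table[op] = fn
        return fn
    return wrapper

@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
    # roll never changes shape or dtype, so the input's metadata can stand in for the output
    return input

x = torch.empty(4, 8, device='meta')                       # no real storage is allocated
out = meta_table[aten.roll.default](x, shifts=1, dims=0)   # shape inference only
print(out.shape, out.device)                               # torch.Size([4, 8]) meta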
12 changes: 9 additions & 3 deletions colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
@@ -2,6 +2,7 @@
from torch.fx import Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, parameter_size
from colossalai.fx.profiler.tensor import MetaTensor
import math
from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
@@ -123,7 +124,9 @@ def _fwd_xbar(node: List[Node]) -> int:

xbar = 0
for n in node:
xbar += n.meta['fwd_mem_tmp'] + n.meta['fwd_mem_out']
xbar += n.meta['fwd_mem_tmp']
if any(map(lambda x: x.meta['save_fwd_in'], n.users)):
xbar += n.meta['fwd_mem_out']
return xbar


@@ -177,10 +180,13 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
def _get_deps_size():
deps_size = 0
for k, v in deps.items():
k: Node
if v > 0:
deps_size += k.meta['bwd_mem_out']
if v == float('-inf'):
deps_size -= k.meta['fwd_mem_tmp'] + k.meta['fwd_mem_out']
deps_size -= k.meta['fwd_mem_tmp']
if any(map(lambda x: x.meta['save_fwd_in'], k.users)):
deps_size -= k.meta['fwd_mem_out']

return deps_size

@@ -333,8 +339,8 @@ def solver_rotor(gm: ColoGraphModule,
"""

node_list = linearize(gm, cnode)
mem_limit -= parameter_size(gm)
mem_unit = mem_limit * (1.0 - eps) // mem_slots
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(data)

chain: Chain = _construct_chain(node_list, data)
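The key addition to `solver_rotor` is the meta-tensor preparation step: the example batch is wrapped in `MetaTensor` with the model's real device recorded as `fake_device`, and `MetaInfoProp` is run so that every node carries the `fwd_mem_tmp`, `fwd_mem_out`, and `save_fwd_in` fields that the tuned `_fwd_xbar` and `_get_bwd_mem_tmp` read. A hedged sketch of that step in isolation (the `MetaInfoProp` import path is assumed from the file changed below; `gm` is an already-traced `ColoGraphModule` and `data` a sample batch):

import torch
from colossalai.fx.profiler.tensor import MetaTensor
from colossalai.fx.passes.meta_info_prop import MetaInfoProp   # assumed import path

def annotate_memory_metadata(gm, data: torch.Tensor):
    # wrap the batch so it lives on the meta device but remembers the real device,
    # then propagate shapes/FLOPs/memory so n.meta holds what the solver reads
    data = MetaTensor(data, fake_device=next(gm.parameters()).device)
    MetaInfoProp(gm).run(data)
    return gm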
6 changes: 2 additions & 4 deletions colossalai/fx/passes/meta_info_prop.py
@@ -94,11 +94,9 @@ def extract_tensor_meta(obj):

tensor_meta = tree_map(extract_tensor_meta, result)
n.meta['tensor_meta'] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info), 'fwd_mem_out': 0} # extend MetaInfo to `n.meta`
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
for par in n.all_input_nodes:
par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0))
n.meta['type'] = type(result)

# retain the autograd graph
@@ -224,7 +222,7 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str,
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
return args[0], GraphInfo(fwd_mem_in=activation_size(args[0]))
return args[0], GraphInfo(save_fwd_in=True)

def propagate(self, *args):
"""
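The `output` handler now returns `GraphInfo(save_fwd_in=True)` instead of a byte count: rather than recording `fwd_mem_in` on the output node and back-propagating a size onto parents' `fwd_mem_out`, a node's output memory is charged only when some consumer flags it as a saved forward input. A minimal sketch of that charging rule on hypothetical node objects (not the ColossalAI API; this mirrors the rule `_fwd_xbar` applies in the rotor solver above):

def charged_fwd_out(node) -> int:
    # fwd_mem_out counts toward the estimate only if at least one user needs the
    # tensor kept alive, i.e. its save_fwd_in flag is True
    if any(user.meta.get('save_fwd_in', False) for user in node.users):
        return node.meta.get('fwd_mem_out', 0)
    return 0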
78 changes: 78 additions & 0 deletions colossalai/fx/profiler/constant.py
@@ -0,0 +1,78 @@
import torch
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
from . import META_COMPATIBILITY

__all__ = []

if META_COMPATIBILITY:
aten = torch.ops.aten

ALIAS_ATEN = [
# inplace reshaping
aten.detach.default,
aten.t.default,
aten.transpose.int,
aten.view.default,
aten._unsafe_view.default,
]

INPLACE_NEW = [
aten.empty_like.default,
aten.new_empty_strided.default,
]

INPLACE_MATH_ATEN = [
aten.add_.Tensor,
aten.sub_.Tensor,
aten.div_.Tensor,
aten.div_.Scalar,
aten.mul_.Tensor,
aten.bernoulli_.float,
]

CLONE_ATEN = [
aten.clone.default,
]

__all__ += ['INPLACE_ATEN', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']

else:
# TODO fill out the inplace ops
INPLACE_OPS = [
add,
sub,
mul,
floordiv,
neg,
pos,
getitem,
setitem,
getattr,
torch.Tensor.cpu,
]

# TODO: list all call_methods that are inplace here
INPLACE_METHOD = [
'transpose',
'permute',
# TODO: reshape may return a copy of the data if the data is not contiguous
'reshape',
'dim',
'flatten',
'size',
'view',
'unsqueeze',
'to',
'type',
'flatten',
]

# TODO: list all call_methods that are not inplace here
NON_INPLACE_METHOD = [
'chunk',
'contiguous',
'expand',
'mean',
'split',
]
__all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
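These tables group ops by their memory behaviour: `ALIAS_ATEN` ops return views that share the input's storage, `INPLACE_MATH_ATEN` ops mutate their input, and `CLONE_ATEN` allocates a fresh copy, so only the last needs to be charged as new activation memory. A quick plain-PyTorch illustration of the distinction (not part of the patch):

import torch

x = torch.randn(4, 8)
v = x.view(8, 4)      # alias: same storage, no extra activation memory
c = x.clone()         # clone: a brand-new allocation of the same size

print(v.data_ptr() == x.data_ptr())   # True  -> nothing new to account for
print(c.data_ptr() == x.data_ptr())   # False -> charge x.numel() * x.element_size() bytes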
54 changes: 27 additions & 27 deletions colossalai/fx/profiler/dataflow.py
@@ -3,9 +3,6 @@
from typing import Dict
from torch.fx import Graph, Node
from .memory import activation_size, is_inplace
from . import META_COMPATIBILITY
if META_COMPATIBILITY:
from .memory import NORMALIZATION_ATEN, CLONE_ATEN


class Phase(Enum):
@@ -23,29 +20,32 @@ class GraphInfo:
============================================================================
-------------------------------
| Node |
[fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] is marks the memory for `grad_out`
[fwd_in] are ---> | [fwd_in] [bwd_out] | <----- [bwd_out] is marks the memory for `grad_out`.
placeholders saved for | | \__________ | |
backward. | | \ | |
| [fwd_tmp] ------> [bwd_tmp] | <-----
| | \_________ | | [bwd_tmp] marks the peak memory
| / \ \ | | in backward pass.
[x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <-----
in [fwd_tmp] because | | | \_____ | |
it is not saved for | | | \ | |
backward. -------------------------------
in [fwd_tmp] because | | \_____ | |
it is not saved for | | \ | |
backward. | [fwd_out] \ | | <----- [fwd_out] is [fwd_in] for the next node.
-------------------------------
============================================================================
Attributes:
fwd_flop (int): The forward FLOPs of a certain node
bwd_flop (int): The backward FLOPs of a certain node.
fwd_mem_in (int): See the above illustration.
save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
fwd_mem_tmp (int): See the above illustration.
fwd_mem_out (int): See the above illustration.
bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration.
"""
fwd_flop: int = 0
bwd_flop: int = 0
fwd_mem_in: int = 0
save_fwd_in: bool = False
fwd_mem_tmp: int = 0
fwd_mem_out: int = 0
bwd_mem_tmp: int = 0
bwd_mem_out: int = 0

@@ -56,7 +56,7 @@ def is_phase(n: Node, phase: Phase) -> bool:


def is_saved(n: Node):
return n.meta.get('saved', False)
return len(n.meta['saved_tensor'])


def autograd_graph_analysis(graph: Graph) -> GraphInfo:
@@ -87,10 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
def _peak_memory(deps: Dict[Node, int]):
peak_mem = 0
for k, v in deps.items():
if v > 0 and is_phase(k, Phase.BACKWARD) and not any(map(is_inplace, k.users)):
peak_mem += activation_size(k.meta['out'])
if v <= float('-inf') and is_saved(k) and (k.target not in NORMALIZATION_ATEN):
peak_mem -= activation_size(k.meta['out'])
if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k):
peak_mem += activation_size(k.meta['saved_tensor'])
if v <= float('-inf') and is_phase(k, Phase.FORWARD):
peak_mem -= activation_size(k.meta['saved_tensor'])
return peak_mem

# deps is used to track all the memory dependencies of the graph.
@@ -99,25 +99,25 @@ def _peak_memory(deps: Dict[Node, int]):

for n in graph.nodes:
n: Node
if is_saved(n) and (n.target not in NORMALIZATION_ATEN) or any(map(lambda x: x.target in CLONE_ATEN, n.users)):
# A forward tensor who is marked `save` but is not
# an input to `loss` should be saved during forward.
# If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
# Any `fwd_mem_in` should be kept in memory even this function
# is checkpointed.
# Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
# the node, `fwd_mem_tmp` can be freed.
if is_phase(n, Phase.PLACEHOLDER):
graph_info.fwd_mem_in += activation_size(n.meta['out'])
if is_phase(n, Phase.FORWARD):
graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
deps[n] = len(n.users)
# A forward tensor who is marked `save` but is also
# an input to `Phase.FORWARD` should be saved during forward.
# If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
# Any `fwd_mem_in` should be kept in memory even this function
# is checkpointed.
# Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
# the node, `fwd_mem_tmp` can be freed.
if is_phase(n, Phase.PLACEHOLDER):
graph_info.save_fwd_in |= activation_size(n.meta['saved_tensor']) > 0
if is_phase(n, Phase.FORWARD):
graph_info.fwd_mem_tmp += activation_size(n.meta['saved_tensor'])
elif is_phase(n, Phase.BACKWARD):
if len(n.users):
graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
else:
# TODO: some of the bwd_mem_out might be model parameters.
# basically a backward node without user is a `grad_out` node
graph_info.bwd_mem_out += activation_size(n.meta['out'])
graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor'])
for input_n in n.all_input_nodes:
if input_n in deps:
deps[input_n] -= 1
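`autograd_graph_analysis` now sizes everything from `n.meta['saved_tensor']` and tracks liveness with a per-node reference count (`deps`): a saved tensor is charged while backward users still need it and released once its last consumer has run, and `bwd_mem_tmp` records the peak of that running total. A toy, self-contained sketch of the same bookkeeping on stand-in node objects (not the real `torch.fx` types, and simplified relative to the pass above):

from dataclasses import dataclass, field
from typing import List

@dataclass(eq=False)                      # identity hash so nodes can be dict keys
class FakeNode:                           # stand-in for torch.fx.Node, illustration only
    size: int                             # bytes of tensors this node saves
    all_input_nodes: List['FakeNode'] = field(default_factory=list)
    users: List['FakeNode'] = field(default_factory=list)

def peak_saved_memory(nodes):
    deps, live, peak = {}, 0, 0
    for n in nodes:                       # nodes visited in execution order
        deps[n] = len(n.users)
        live += n.size                    # a saved tensor becomes live when produced
        peak = max(peak, live)
        for parent in n.all_input_nodes:
            if parent in deps:
                deps[parent] -= 1
                if deps[parent] <= 0:     # last consumer seen: memory can be released
                    live -= parent.size
    return peak

a = FakeNode(size=4096)
b = FakeNode(size=2048, all_input_nodes=[a]); a.users.append(b)
c = FakeNode(size=1024, all_input_nodes=[b]); b.users.append(c)
print(peak_saved_memory([a, b, c]))       # 6144: a and b are alive at the same time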
3 changes: 2 additions & 1 deletion colossalai/fx/profiler/experimental/profiler.py
@@ -3,7 +3,8 @@
import torch
from torch.fx.node import Argument, Target
from . import meta_profiler_function, meta_profiler_module
from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS
from ..memory import activation_size
from ..constant import INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS

__all__ = ['profile_function', 'profile_module', 'profile_method']

