
[fx/profiler] tuned the calculation of memory estimation #1619

Merged
merged 12 commits into hpcaitech:main from super-dainiu:tuning/meta_backward on Sep 23, 2022

Conversation

super-dainiu (Contributor)

What's modified?

To run MetaInfoProp on arbitrary physical devices, we wrap a torch.Tensor with MetaTensor, which carries a fake_device property. fake_device is used by __torch_dispatch__ to pick a torch.ops.aten overload for execution, while during the actual execution the tensor is moved back to device='meta'. So now you can run MetaInfoProp for a tm.resnet18 on CPU with the following script.

model = tm.resnet18()
gm = torch.fx.symbolic_trace(model)    # assumed tracing step: the original snippet uses `gm` without defining it
input = MetaTensor(torch.rand(10000, 3, 224, 224, device='meta'), fake_device='cpu')
MetaInfoProp(gm).run(input)

And on GPU with the following script.

model = tm.resnet18().cuda()    # don't forget to move your model to cuda as well!
gm = torch.fx.symbolic_trace(model)    # assumed tracing step, as above
input = MetaTensor(torch.rand(10000, 3, 224, 224, device='meta'), fake_device='cuda')
MetaInfoProp(gm).run(input)

Now you may observe different patterns in the estimated memory, because torch.autograd behaves differently on different devices.
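
For intuition, here is a minimal sketch (an assumption, not the PR's exact implementation) of how such a MetaTensor wrapper can use fake_device inside __torch_dispatch__: arguments are unwrapped to plain meta tensors, the aten op runs on them without allocating real memory, and outputs are re-wrapped with the remembered fake_device.

import torch
from torch.utils._pytree import tree_map

class MetaTensor(torch.Tensor):
    _tensor: torch.Tensor

    @staticmethod
    def __new__(cls, elem, fake_device=None):
        # pretend to live on `fake_device` while the payload stays on 'meta'
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), strides=elem.stride(), dtype=elem.dtype,
            device=fake_device or elem.device, requires_grad=elem.requires_grad)
        r._tensor = elem.to('meta')
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        fake_device = None

        def unwrap(x):
            nonlocal fake_device
            if isinstance(x, MetaTensor):
                fake_device = x.device    # remember the pretended device
                return x._tensor
            return x

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs or {})
        out = func(*args, **kwargs)    # executes with device='meta', no real allocation
        wrap = lambda x: MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
        return tree_map(wrap, out)

With a wrapper of this shape, the resnet18 scripts above run entirely on meta storage while dispatching as if on 'cpu' or 'cuda'.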

Improvements

  1. The computation graph can be completely different when executing on 'cpu' and on 'cuda'.

  2. Tensor dtype may change during execution.

  3. Not every tensor that passes through torch.autograd.graph.saved_tensors_hooks is actually saved on the device. tensor.data_ptr() marks its physical address, so a tensor with a duplicated tensor.data_ptr() should only be counted once (see the sketch after this list).

    1. One annoying thing is that we always have tensor.data_ptr() == 0 on device='meta'.
    2. Luckily, in the Python frontend we can still track the same tensor through the unique identifier of its tensor.data_ptr bound method (a torch._C function object).
  4. See the figures below for illustration. (Two memory-trace images were attached to the original PR.)

    1. Whether fwd_out should be saved is determined by the node's users: a node keeps fwd_out if and only if at least one of its users saves its fwd_in.
  5. nn.MultiheadAttention has (x, x, x) as input, but only one copy of x should be saved for backward.
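
A hedged sketch of the de-duplication idea in item 3 (the helper name and signature are mine, not the PR's): since data_ptr() is always 0 on the meta device, the bound method object tensor.data_ptr serves as a per-tensor-object key.

import torch

def deduplicated_saved_size(saved_tensors) -> int:
    # count each distinct tensor object once; `t.data_ptr` (the bound method,
    # not its 0 return value on 'meta') is unique per tensor object
    seen, total = set(), 0
    for t in saved_tensors:
        key = t.data_ptr
        if key not in seen:
            seen.add(key)
            total += t.numel() * t.element_size()
    return total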

Concerns

Backward memory remains very hard to estimate accurately.

Comment on lines +219 to 252
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.cudnn_batch_norm.default)
def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
    n_input = input.size(1)

    output = torch.empty_like(input)
    running_mean = torch.empty((n_input), device='meta')
    running_var = torch.empty((n_input), device='meta')
    reserve = torch.empty((0), dtype=torch.uint8, device='meta')
    return output, running_mean, running_var, reserve


# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
# NB: CuDNN only implements the backward algorithm for batchnorm
# in training mode (evaluation mode batchnorm has a different algorithm),
# which is why this doesn't accept a 'training' parameter.
@register_meta(aten.cudnn_batch_norm_backward.default)
def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
                           save_mean, save_invstd, eps, reserve):
    dX = torch.empty_like(input)
    dgamma = torch.empty_like(weight)
    dbeta = torch.empty_like(weight)
    return dX, dgamma, dbeta


@register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
    bs = input.size(0)
    n_input = input.size(1)

    output = torch.empty_like(input)
    running_mean = torch.empty((bs, n_input, 1), device='meta')
    running_var = torch.empty((bs, n_input, 1), device='meta')
    return output, running_mean, running_var
super-dainiu (Contributor Author):

A sad truth is that autograd on CUDA uses cudnn_batch_norm.
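
For context, a plausible shape for the register_meta decorator used above (hypothetical; the PR's own decorator may differ): it simply maps an aten overload to a shape-only implementation that can be called in place of a real CUDA kernel.

import torch

aten = torch.ops.aten
meta_table = {}

def register_meta(op):
    # record `fn` as the shape/dtype-only implementation for aten overload `op`
    def wrapper(fn):
        meta_table[op] = fn
        return fn
    return wrapper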

Comment on lines +381 to +391
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
    # notice that mask is bool
    output = torch.empty_like(input)
    mask = torch.empty_like(input, dtype=torch.bool)
    return output, mask


# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout_backward.default)
def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
    return torch.empty_like(grad)
super-dainiu (Contributor Author):

Another sad truth is that autograd on CUDA uses native_dropout.
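
To see why this matters for the estimate (a worked example of mine, reusing the input shape from the scripts above and assuming float32 activations): native_dropout saves a bool mask of the same shape as the input for backward, adding one byte per element.

import torch

x = torch.empty(10000, 3, 224, 224, device='meta')
mask_bytes = x.numel() * torch.empty(0, dtype=torch.bool).element_size()    # 1 byte per element
print(f"dropout mask adds {mask_bytes / 2**20:.0f} MiB of saved memory")    # ~1436 MiB for this input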

Comment on lines +127 to +129
xbar += n.meta['fwd_mem_tmp']
if any(map(lambda x: x.meta['save_fwd_in'], n.users)):
    xbar += n.meta['fwd_mem_out']
super-dainiu (Contributor Author):

Not every node's fwd_mem_out needs counting; only nodes with a user that saves fwd_in contribute it.

@@ -224,7 +222,7 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str,
             result (Any): The argument value that was retrieved
             meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
         """
-        return args[0], GraphInfo(fwd_mem_in=activation_size(args[0]))
+        return args[0], GraphInfo(save_fwd_in=True)
super-dainiu (Contributor Author):

The output node always saves the last fwd_mem_out.

elif is_phase(n, Phase.BACKWARD):
    if len(n.users):
        graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
    else:
        # TODO: some of the bwd_mem_out might be model parameters.
        # basically a backward node without user is a `grad_out` node
        graph_info.bwd_mem_out += activation_size(n.meta['out'])
        graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor'])
    for input_n in n.all_input_nodes:
        if input_n in deps:
            deps[input_n] -= 1
super-dainiu (Contributor Author):

These analyses are still naive, so sad.
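
A hypothetical reading of the _peak_memory helper used above (my guess from context, not the PR's code): while a node still has unresolved users in deps, its saved tensors remain alive, so the step's peak is the total saved size of the live nodes.

def _peak_memory(deps) -> int:
    # `deps` maps each producing node to its count of not-yet-consumed users;
    # a node with a positive count still holds its saved tensors alive
    return sum(activation_size(n.meta['saved_tensor']) for n, count in deps.items() if count > 0)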

Comment on lines -9 to -84
if META_COMPATIBILITY:
    aten = torch.ops.aten

    WEIRD_OPS = [
        torch.where,
    ]

    INPLACE_ATEN = [
        aten.add_.Tensor,
        aten.sub_.Tensor,
        aten.div_.Tensor,
        aten.div_.Scalar,
        aten.mul_.Tensor,
        aten.bernoulli_.float,

        # inplace reshaping
        aten.copy_.default,
        aten.detach.default,
        aten.t.default,
        aten.transpose.int,
        aten.view.default,
        aten._unsafe_view.default,
    ]

    NORMALIZATION_ATEN = [
        aten.native_batch_norm.default,
        aten.native_layer_norm.default,
        # aten.max_pool2d_with_indices.default,
    ]

    CLONE_ATEN = [
        aten.clone.default,
    ]

    __all__ += ['INPLACE_ATEN', 'WEIRD_OPS', 'NORMALIZATION_ATEN', 'CLONE_ATEN']

else:
    # TODO fill out the inplace ops
    INPLACE_OPS = [
        add,
        sub,
        mul,
        floordiv,
        neg,
        pos,
        getitem,
        setitem,
        getattr,
        torch.Tensor.cpu,
    ]

    # TODO: list all call_methods that are inplace here
    INPLACE_METHOD = [
        'transpose',
        'permute',
        # TODO: reshape may return a copy of the data if the data is not contiguous
        'reshape',
        'dim',
        'flatten',
        'size',
        'view',
        'unsqueeze',
        'to',
        'type',
    ]

    # TODO: list all call_methods that are not inplace here
    NON_INPLACE_METHOD = [
        'chunk',
        'contiguous',
        'expand',
        'mean',
        'split',
    ]
    __all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
super-dainiu (Contributor Author):

Moved these to constant.py.

@@ -201,6 +202,8 @@ def zero_flop_jit(*args):
     # normalization
     aten.native_batch_norm.default: batchnorm_flop_jit,
     aten.native_batch_norm_backward.default: batchnorm_flop_jit,
+    aten.cudnn_batch_norm.default: batchnorm_flop_jit,
+    aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
super-dainiu (Contributor Author):

Note that cudnn_batch_norm_backward is handled differently because it is only invoked when training=True.

Comment on lines +30 to +32
# super-dainiu:
# x.detach() will change the unique identifier of data_ptr
# we need to handle this in a stupid way
super-dainiu (Contributor Author):

x.detach() will change the unique identifier of data_ptr, so we need to handle this in a stupid way.
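
A small demonstration of the problem (my reproduction, not the PR's code): detach() returns a new Python tensor object, so the data_ptr bound method used as the identity key no longer matches, even though both tensors alias the same meta storage.

import torch

x = torch.rand(4, device='meta')
y = x.detach()
# data_ptr() itself is useless as a key here: on 'meta' it is always 0
print(x.data_ptr == x.data_ptr)    # True  - bound methods of the same object compare equal
print(x.data_ptr == y.data_ptr)    # False - detach() produced a distinct tensor object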

from .tensor import MetaTensor
from .opcount import flop_mapping

__all__ = ['profile_function', 'profile_module', 'profile_method']

# super-dainiu: this cache should be global, otherwise it cannot
# track duplicated tensors between nodes
cache = set()
super-dainiu (Contributor Author):

This cache should be global; otherwise it cannot track duplicated tensors between nodes.

Comment on lines +280 to +283
# still run the profiling but discard some results regarding `module`.
inplace = getattr(module, 'inplace', False)
if inplace:
    module.inplace = False
super-dainiu (Contributor Author):

Profiling is no longer skipped when inplace=True; the flag is temporarily set to False so the module can still be profiled.
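
The surrounding flow presumably looks like the following sketch (assumed; profile_module's exact signature is not shown in this hunk), with the user's inplace flag handed back once profiling is done.

# still run the profiling but discard some results regarding `module`.
inplace = getattr(module, 'inplace', False)
if inplace:
    module.inplace = False
try:
    meta_info = profile_module(module, *args)    # profile with out-of-place semantics
finally:
    if inplace:
        module.inplace = True    # restore the user's setting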

super-dainiu (Contributor Author): (result image attached to the original PR, omitted here)

@FrankLeeeee merged commit d967779 into hpcaitech:main on Sep 23, 2022.
@super-dainiu deleted the tuning/meta_backward branch on September 23, 2022 at 06:04.