
[fx/profiler] assigned UUID to each unrecorded tensor/ improved performance on GPT-2 #1679

Merged
merged 33 commits into hpcaitech:main from the speedup branch on Oct 11, 2022

Conversation

super-dainiu
Contributor

@super-dainiu super-dainiu commented Oct 5, 2022

What's new?

  1. Provided a function to calculate `fwd_in`, `fwd_tmp` and `fwd_out`. Hopefully, this will avoid a lot of misunderstanding.
  2. Assigned a UUID to each unrecorded tensor.
  3. Improved performance on GPT-2.

Bugs in existing code

  1. The C version of the rotor solver is not correct. cc @Cypher30

C version results

| Model | mem_limit | real memory consumption | train step time | solver time |
|---|---|---|---|---|
| densenet121 | None | 15996.476 MB | 171.204 MS | — |
| densenet121 | 5300.0 MB | 5300.000 MB | inf MS | 0.000 MS |
| densenet121 | 6680.0 MB | 6680.000 MB | inf MS | 0.000 MS |
| densenet121 | 8060.0 MB | 8060.000 MB | inf MS | 0.000 MS |
| densenet121 | 9440.0 MB | 9440.000 MB | inf MS | 0.000 MS |
| densenet121 | 10820.0 MB | 10820.000 MB | inf MS | 0.000 MS |
| densenet121 | 12200.0 MB | 12200.000 MB | inf MS | 0.000 MS |
| densenet121 | 13580.0 MB | 13580.000 MB | inf MS | 0.000 MS |
| densenet121 | 14960.0 MB | 14960.000 MB | inf MS | 0.000 MS |
| densenet121 | 16340.0 MB | 15800.507 MB | 171.508 MS | 2038.141 MS |
| densenet121 | 17720.0 MB | 15800.476 MB | 171.513 MS | 2205.862 MS |
| densenet121 | 19100.0 MB | 15800.476 MB | 171.499 MS | 2239.843 MS |

Force Python results

| Model | mem_limit | real memory consumption | train step time | solver time |
|---|---|---|---|---|
| densenet121 | None | 16027.334 MB | 171.072 MS | — |
| densenet121 | 5300.0 MB | 5300.000 MB | inf MS | 0.000 MS |
| densenet121 | 6690.0 MB | 6313.646 MB | 208.779 MS | 1582.991 MS |
| densenet121 | 8080.0 MB | 7292.647 MB | 204.590 MS | 1605.730 MS |
| densenet121 | 9470.0 MB | 9500.286 MB | 195.120 MS | 1625.514 MS |
| densenet121 | 10860.0 MB | 10728.410 MB | 190.821 MS | 1571.043 MS |
| densenet121 | 12250.0 MB | 11897.869 MB | 186.108 MS | 1640.221 MS |
| densenet121 | 13640.0 MB | 12093.869 MB | 185.596 MS | 1590.731 MS |
| densenet121 | 15030.0 MB | 14851.113 MB | 175.736 MS | 1633.711 MS |
| densenet121 | 16420.0 MB | 16027.334 MB | 171.635 MS | 1585.685 MS |
| densenet121 | 17810.0 MB | 16027.334 MB | 171.443 MS | 1645.066 MS |
  2. What is wrong with this? ✔
     (screenshot omitted)

  3. The logger always requires a launch; some modifications should be made either to the logger or to test_linearize.py. cc @Cypher30

Tests

Passed all tests in tests/test_fx locally. (screenshot of passing tests omitted)

@super-dainiu super-dainiu changed the title Speedup [fx/profiler] assigned UUID to each unrecorded tensor/ improved performance on GPT-2 Oct 5, 2022
Comment on lines +51 to 63
# TODO(super-dainiu): remove redundant items; currently all of them are necessary for development

fwd_flop: int = 0
fwd_time: float = 0.0
bwd_flop: int = 0
bwd_time: float = 0.0
save_fwd_in: bool = False
fwd_in: List = field(default_factory=list)
fwd_tmp: List = field(default_factory=list)
fwd_out: List = field(default_factory=list)
fwd_mem_tmp: int = 0
fwd_mem_out: int = 0
bwd_mem_tmp: int = 0
Contributor Author

This is just for compatibility across different variants of InfoProp.

Contributor Author

Now all forward results of MetaInfoProp are saved in a List.

# If any argument indicates that this `call_function` is inplace, we should
# still run the profiling but discard some results regarding `target`
global do_not_cache
Contributor Author

This happens when we replace inplace=True with inplace=False: some tensors might be saved for backward when inplace=False, and that result would then be cached.
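A minimal sketch (not the profiler's actual code; the `profile` function, `'relu'` key, and result dict below are made up for illustration) of why a result cached with inplace=False must not be reused for an inplace op:

```python
# A stale cache entry from inplace=False carries tensors saved for backward
# that an inplace op would never keep, so caching must be skippable.
cache = {}

def profile(target, inplace, do_not_cache=False):
    # do_not_cache plays the role of the global flag in the PR's diff
    if not do_not_cache and target in cache:
        return cache[target]
    result = {'saved_for_backward': not inplace}  # stand-in for real profiling
    if not do_not_cache:
        cache[target] = result
    return result

first = profile('relu', inplace=False)                     # cached, saves tensors
wrong = profile('relu', inplace=True)                      # cache hit: stale result
right = profile('relu', inplace=True, do_not_cache=True)   # bypass the cache
assert wrong['saved_for_backward'] is True                 # demonstrates the bug
assert right['saved_for_backward'] is False                # bypassing fixes it
```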

Comment on lines +12 to +15
def set_uuid(x):
    if isinstance(x, torch.Tensor):
        if not hasattr(x, 'uuid'):
            setattr(x, 'uuid', uuid.uuid4())
Contributor Author

A uuid is assigned if and only if the tensor does not already have one.
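The guard can be sanity-checked without torch; `FakeTensor` below is a stand-in for torch.Tensor, since only the attribute check matters here:

```python
import uuid

class FakeTensor:
    """Stand-in for torch.Tensor in this sketch."""

def set_uuid(x):
    # assign a uuid only if none exists yet, mirroring the diff above
    if isinstance(x, FakeTensor):
        if not hasattr(x, 'uuid'):
            setattr(x, 'uuid', uuid.uuid4())

t = FakeTensor()
set_uuid(t)
first = t.uuid
set_uuid(t)  # a second call must not overwrite the existing uuid
assert t.uuid == first
```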

@@ -15,7 +19,7 @@
with_codegen = False


-@pytest.mark.skip(reason='TODO: modify calculations in rotor')
+@pytest.mark.skip(reason='TODO: modify the logger')
Contributor Author

See error 3

    codegen = ActivationCheckpointCodeGen()
    gm.graph.set_codegen(codegen)
    if solver == solver_rotor:
-       gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500)
+       gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500, force_python=True)
Contributor Author

See error 1

Contributor

@Cypher30 Cypher30 left a comment

I will check the logger problem. The build warning from CodeFactor might be because of shell=True; I will dive into it.

    codegen = ActivationCheckpointCodeGen()
    gm.graph.set_codegen(codegen)
    if solver == solver_rotor:
-       gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500)
+       gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500, force_python=True)
Contributor

Currently I cannot reproduce this problem.

Comment on lines +1 to +42
# for PyTorch 1.11 compatibility
import torch
from torch.fx import Node, GraphModule
from typing import Union, Dict, List, Tuple

__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]


def calculate_fwd_in(n: Node) -> bool:
    """A helper function to calculate `fwd_in`

    Args:
        n (Node): a node from the graph

    Returns:
        save_fwd_in (bool): the result of `save_fwd_in`
    """
    return n.meta['save_fwd_in']


def calculate_fwd_tmp(n: Node) -> int:
    """A helper function to calculate `fwd_tmp`

    Args:
        n (Node): a node from the graph

    Returns:
        fwd_tmp (int): the result of `fwd_tmp`
    """
    return n.meta["fwd_mem_tmp"]


def calculate_fwd_out(n: Node) -> int:
    """A helper function to calculate `fwd_out`

    Args:
        n (Node): a node from the graph

    Returns:
        fwd_out (int): the result of `fwd_out`
    """
    return n.meta['fwd_mem_out']
Contributor Author

This is for compatibility, because MetaInfoProp for torch 1.11 does not save tensors with a uuid.
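A toy usage sketch of the helpers above; `FakeNode` and its meta values are made up for illustration, and the two one-line helpers repeat the lookups from the snippet so the sketch is self-contained:

```python
class FakeNode:
    """Hypothetical stand-in for a torch.fx Node carrying profiler metadata."""
    def __init__(self, meta):
        self.meta = meta

def calculate_fwd_tmp(n):
    # same lookup as the helper in the snippet above
    return n.meta['fwd_mem_tmp']

def calculate_fwd_out(n):
    # same lookup as the helper in the snippet above
    return n.meta['fwd_mem_out']

nodes = [FakeNode({'fwd_mem_tmp': 100, 'fwd_mem_out': 50}),
         FakeNode({'fwd_mem_tmp': 0, 'fwd_mem_out': 200})]
assert max(calculate_fwd_tmp(n) for n in nodes) == 100   # peak temp memory
assert sum(calculate_fwd_out(n) for n in nodes) == 250   # total saved outputs
```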

Comment on lines +168 to +180
def _get_fwd_mem_tmp(node: List[Node]) -> int:
    """Get the forward temp memory of a node.
    This can be computed by subtracting the saved activation from all outputs of the node.

    Args:
        node (List[Node]): list of torch.fx Nodes,
            indicating a node in the linearized graph

    Returns:
        int: forward temp memory, in bytes
    """
    n = node[-1]
    return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
Contributor Author

Now fwd_out tensors that are not saved for the next node may be regarded as fwd_tmp.
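The subtraction can be illustrated with plain numbers; the `meta` dict and `activation_size` below are simplified stand-ins for the node's metadata and the profiler's real helper:

```python
def activation_size(tensors):
    # stand-in for the profiler's helper: total bytes of the given tensors,
    # each represented here as a (numel, element_size) pair
    return sum(numel * elem_size for numel, elem_size in tensors)

meta = {
    'fwd_out': [(1024, 4), (256, 4)],   # all forward outputs: 4096 + 1024 bytes
    'fwd_mem_out': 4096,                # bytes actually saved for the next node
}

# outputs not saved for the next node count as temporary forward memory
fwd_mem_tmp = activation_size(meta['fwd_out']) - meta['fwd_mem_out']
assert fwd_mem_tmp == 1024
```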

Contributor

@Cypher30 Cypher30 left a comment

Great work! I will approve this PR first and check the C version issue.

@super-dainiu super-dainiu merged commit 3dd6994 into hpcaitech:main Oct 11, 2022
@super-dainiu super-dainiu deleted the speedup branch October 11, 2022 06:41