[fx/profiler] assigned UUID to each unrecorded tensor/ improved performance on GPT-2 #1679
Conversation
# TODO(super-dainiu): remove redundant items; currently all of them are necessary for development

fwd_flop: int = 0
fwd_time: float = 0.0
bwd_flop: int = 0
bwd_time: float = 0.0
save_fwd_in: bool = False
fwd_in: List = field(default_factory=list)
fwd_tmp: List = field(default_factory=list)
fwd_out: List = field(default_factory=list)
fwd_mem_tmp: int = 0
fwd_mem_out: int = 0
bwd_mem_tmp: int = 0
Just for compatibility across different InfoProp versions.
Now all fwd items of MetaInfoProp (fwd_in, fwd_tmp and fwd_out) are saved in Lists.
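Since these fields are now plain Python lists of tensors, byte counts have to be derived by summing tensor storage. A minimal sketch of that idea (activation_size_sketch is a hypothetical stand-in for the profiler's activation_size helper):

import torch
from typing import List

def activation_size_sketch(tensors: List[torch.Tensor]) -> int:
    # sum the storage, in bytes, of every tensor in the list
    return sum(t.numel() * t.element_size() for t in tensors)

# e.g. two fp32 activations of shape (4, 4) saved as fwd_tmp:
fwd_tmp = [torch.empty(4, 4), torch.empty(4, 4)]
print(activation_size_sketch(fwd_tmp))    # 128 (= 2 * 16 elements * 4 bytes)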
# If there is an argument that this `call_function` is inplace, we should
# still run the profiling but discard some results regarding `target`
global do_not_cache
This happens if we replace inplace=True with inplace=False: some tensors might be saved for backward in the inplace=False run, and that result would be cached.
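A minimal sketch of the guard described above (every name other than do_not_cache is hypothetical): the flag suppresses caching for exactly one profiling call, so a result polluted by saved-for-backward tensors never enters the cache.

_profile_cache = {}
do_not_cache = False

def profile_cached(target, profile_fn, *args, **kwargs):
    global do_not_cache
    if target in _profile_cache and not do_not_cache:
        return _profile_cache[target]
    result = profile_fn(*args, **kwargs)
    if not do_not_cache:
        _profile_cache[target] = result    # safe to reuse next time
    do_not_cache = False                   # the guard covers a single call
    return result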
def set_uuid(x):
    if isinstance(x, torch.Tensor):
        if not hasattr(x, 'uuid'):
            setattr(x, 'uuid', uuid.uuid4())
A uuid is set if and only if the tensor does not already have one.
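Because plain torch.Tensor objects accept arbitrary Python attributes, the tag survives as long as the tensor object does. A small usage sketch mirroring the diff above (tree_map is torch's pytree helper; the nested output structure is just an example):

import uuid
import torch
from torch.utils._pytree import tree_map

def set_uuid(x):
    # tag only untagged tensors, so a tensor seen twice keeps one identity
    if isinstance(x, torch.Tensor):
        if not hasattr(x, 'uuid'):
            setattr(x, 'uuid', uuid.uuid4())

# tag every tensor in an arbitrarily nested output structure
out = (torch.ones(2), {'h': torch.zeros(3)})
tree_map(set_uuid, out)
print(out[0].uuid != out[1]['h'].uuid)    # True: distinct tensors, distinct ids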
@@ -15,7 +19,7 @@
with_codegen = False

-@pytest.mark.skip(reason='TODO: modify calculations in rotor')
+@pytest.mark.skip(reason='TODO: modify the logger')
See error 3
codegen = ActivationCheckpointCodeGen()
gm.graph.set_codegen(codegen)
if solver == solver_rotor:
-    gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500)
+    gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500, force_python=True)
See error 1
I will check the logger problem. The build warning from CodeFactor might be because of shell=True; I will dive into this problem.
Currently I cannot reproduce this problem.
# for PyTorch 1.11 compatibility
import torch
from torch.fx import Node, GraphModule
from typing import Union, Dict, List, Tuple

__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]


def calculate_fwd_in(n: Node) -> bool:
    """A helper function to calculate `fwd_in`

    Args:
        n (Node): a node from the graph

    Returns:
        save_fwd_in (bool): the result of `save_fwd_in`
    """
    return n.meta['save_fwd_in']


def calculate_fwd_tmp(n: Node) -> int:
    """A helper function to calculate `fwd_tmp`

    Args:
        n (Node): a node from the graph

    Returns:
        fwd_tmp (int): the result of `fwd_tmp`, in bytes
    """
    return n.meta["fwd_mem_tmp"]


def calculate_fwd_out(n: Node) -> int:
    """A helper function to calculate `fwd_out`

    Args:
        n (Node): a node from the graph

    Returns:
        fwd_out (int): the result of `fwd_out`, in bytes
    """
    return n.meta['fwd_mem_out']
This is for compatibility, because MetaInfoProp for torch 1.11 doesn't save tensors with a uuid.
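A usage sketch reusing the imports and helpers above (total_fwd_activation is a hypothetical name): downstream solvers can sum the aggregated byte counts without caring whether the torch 1.11 path or the uuid-based path filled node.meta.

def total_fwd_activation(gm: GraphModule) -> int:
    # sum forward output + temp memory over all nodes, in bytes
    return sum(calculate_fwd_out(n) + calculate_fwd_tmp(n) for n in gm.graph.nodes)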
def _get_fwd_mem_tmp(node: List[Node]) -> int:
    """Get the forward temp memory of a node.
    This can be computed by subtracting the saved activations from all outputs of the node.

    Args:
        node (List[Node]): list of torch.fx Nodes,
            representing a node in the linearized graph

    Returns:
        int: forward temp memory, in bytes
    """
    n = node[-1]
    return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
Now any fwd_out that is not saved for the next node may be regarded as fwd_tmp.
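A worked sketch of that rule with hypothetical numbers: if a node produces 3 MiB of outputs but only 2 MiB of them are kept for the next stage, the leftover 1 MiB counts as fwd_tmp and can be released after the forward pass.

all_outputs = 3 * 1024 * 1024       # activation_size(n.meta['fwd_out'])
saved_for_next = 2 * 1024 * 1024    # calculate_fwd_out(n)
fwd_mem_tmp = all_outputs - saved_for_next
assert fwd_mem_tmp == 1 * 1024 * 1024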
Great Work! I will approve this PR first and check the C version issue.
What's new?

Assigned a UUID to each unrecorded tensor, and the saved activations are now recorded as fwd_in, fwd_tmp and fwd_out. Hopefully, this will avoid a lot of misunderstandings.

Bugs in existing code

C version results: [screenshot]
Force Python results: [screenshot]
What is wrong with this? ✔
Logger always requires a launch; some modifications should be made either on the logger or on test_linearize.py. cc @Cypher30

Tests

Passed all the tests/test_fx tests locally.
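For reference, a local run of that suite would look like the following (the exact pytest invocation is an assumption):

pytest tests/test_fx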