[fx/profiler] assigned UUID to each unrecorded tensor/ improved performance on GPT-2 #1679
First changed file:

```diff
@@ -1,7 +1,7 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from functools import partial
-from typing import Dict
+from typing import Dict, List
 from torch.fx import Graph, Node
 from .memory import activation_size, is_inplace
 
@@ -39,16 +39,25 @@ class GraphInfo:
     bwd_flop (int): The backward FLOPs of a certain node.
     bwd_time (float): The real backward time (s) of a certain node.
     save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
+    fwd_in (List): See the above illustration.
+    fwd_tmp (List): See the above illustration.
+    fwd_out (List): See the above illustration.
     fwd_mem_tmp (int): See the above illustration.
     fwd_mem_out (int): See the above illustration.
     bwd_mem_tmp (int): See the above illustration.
     bwd_mem_out (int): See the above illustration.
     """
+
+    # TODO(super-dainiu): remove redundant items; currently all of them are necessary for development
+
     fwd_flop: int = 0
     fwd_time: float = 0.0
     bwd_flop: int = 0
     bwd_time: float = 0.0
     save_fwd_in: bool = False
+    fwd_in: List = field(default_factory=list)
+    fwd_tmp: List = field(default_factory=list)
+    fwd_out: List = field(default_factory=list)
     fwd_mem_tmp: int = 0
     fwd_mem_out: int = 0
     bwd_mem_tmp: int = 0
```
Review comments on lines +51 to 63:

> Just for compatibility of different …

> Now all …
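An aside on the `field(default_factory=list)` pattern adopted in the `GraphInfo` hunk above — a minimal, self-contained sketch (names are illustrative, not from the PR): mutable defaults on dataclass fields must go through `default_factory`, since a bare `= []` would be shared across instances (and is in fact rejected by `dataclass`).

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Info:
    fwd_mem_tmp: int = 0                         # immutable default: fine
    fwd_tmp: List = field(default_factory=list)  # fresh list per instance
    # fwd_tmp: List = []  would raise "ValueError: mutable default ..."

a, b = Info(), Info()
a.fwd_tmp.append("t0")
assert b.fwd_tmp == []  # each instance got its own list
```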
```diff
@@ -60,10 +69,6 @@ def is_phase(n: Node, phase: Phase) -> bool:
     return n.meta['phase'] == phase
 
 
-def is_saved(n: Node):
-    return len(n.meta['saved_tensor'])
-
-
 def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     """Analyze the autograd node dependencies and find out the memory usage.
     Basically the input graph should have all nodes marked for keyword `phase`.
 
@@ -113,9 +118,9 @@ def _peak_memory(deps: Dict[Node, int]):
         # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
         # the node, `fwd_mem_tmp` can be freed.
         if is_phase(n, Phase.PLACEHOLDER):
-            graph_info.save_fwd_in |= activation_size(n.meta['saved_tensor']) > 0
+            graph_info.fwd_in += n.meta['saved_tensor']
         if is_phase(n, Phase.FORWARD):
-            graph_info.fwd_mem_tmp += activation_size(n.meta['saved_tensor'])
+            graph_info.fwd_tmp += n.meta['saved_tensor']
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
```
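Since the analysis now accumulates the saved tensors themselves (`fwd_in`, `fwd_tmp`) instead of summing `activation_size` immediately, a later pass can deduplicate them by the UUID this PR assigns before measuring memory. A hypothetical sketch of that reduction (not the repo's code):

```python
def deduped_activation_size(tensors) -> int:
    """Sum tensor sizes in bytes, counting each distinct tensor once.

    Assumes each saved tensor carries the `uuid` attribute this PR
    attaches; falls back to Python object identity otherwise.
    """
    seen, total = set(), 0
    for t in tensors:
        key = getattr(t, 'uuid', id(t))
        if key not in seen:
            seen.add(key)
            total += t.numel() * t.element_size()
    return total
```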
Second changed file:

```diff
@@ -1,4 +1,5 @@
 from .registry import meta_profiler_function, meta_profiler_module
+from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
 from .profiler_function import *
 from .profiler_module import *
 from .profiler import profile_function, profile_method, profile_module
```
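A hypothetical usage of the helpers re-exported here (the import path and the pre-annotated `gm` are assumptions, not shown in the PR):

```python
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out

# `gm` is assumed to be a torch.fx GraphModule whose nodes were already
# annotated by the profiler's meta-info propagation pass.
for node in gm.graph.nodes:
    print(node.name,
          calculate_fwd_in(node),   # bool: whether the fwd input is saved
          calculate_fwd_tmp(node),  # int: temporary forward memory
          calculate_fwd_out(node))  # int: forward output memory
```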
Third file (newly added):

```diff
@@ -0,0 +1,42 @@
+# for PyTorch 1.11 compatibility uses
+import torch
+from torch.fx import Node, GraphModule
+from typing import Union, Dict, List, Tuple
+
+__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
+
+
+def calculate_fwd_in(n: Node) -> bool:
+    """A helper function to calculate `fwd_in`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        save_fwd_in (bool): the result of `save_fwd_in`
+    """
+    return n.meta['save_fwd_in']
+
+
+def calculate_fwd_tmp(n: Node) -> int:
+    """A helper function to calculate `fwd_tmp`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_tmp (int): the result of `fwd_tmp`
+    """
+    return n.meta["fwd_mem_tmp"]
+
+
+def calculate_fwd_out(n: Node) -> int:
+    """A helper function to calculate `fwd_out`
+
+    Args:
+        n (Node): a node from the graph
+
+    Returns:
+        fwd_out (int): the result of `fwd_out`
+    """
+    return n.meta['fwd_mem_out']
```
Review comments on lines +1 to +42:

> This is for compatibility use, because MetaInfoProp on torch 1.11 doesn't save tensors with a UUID.

> Now `fwd_out` tensors that are not saved for the next node may be regarded as `fwd_tmp`.
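The second comment can be pictured with a small hypothetical helper (every name here is illustrative): outputs that no successor node keeps are classified as temporary forward memory rather than forward output.

```python
def classify_outputs(node_outputs, saved_uuids):
    """Split a node's outputs by whether the next node keeps them.

    `node_outputs` are tensors carrying the PR's hypothetical `uuid`
    attribute; `saved_uuids` are the UUIDs the successor node saves.
    """
    fwd_out = [t for t in node_outputs if t.uuid in saved_uuids]
    fwd_tmp = [t for t in node_outputs if t.uuid not in saved_uuids]
    return fwd_out, fwd_tmp
```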