[hotfix/rotor] fix variable names #1597

Merged: 27 commits, Sep 14, 2022
Commits (27):
044adf8
[fx] add some comment and docstrings.
super-dainiu Sep 8, 2022
a42ab22
[fx] add dataflow analysis for an autograd graph.
super-dainiu Sep 8, 2022
f8e1c1c
Merge branch 'main' of https://github.com/super-dainiu/ColossalAI int…
super-dainiu Sep 8, 2022
0d55f26
add interpretation for graph analysis.
super-dainiu Sep 9, 2022
42c6e8c
[fx] before doing save_tensor_hooks.
super-dainiu Sep 12, 2022
dafbfcf
Merge branch 'hpcaitech:main' into feature/better_flop_tensor
super-dainiu Sep 12, 2022
5f25d6e
[fx] provide an accurate estimation of memory except for GPT-2.
super-dainiu Sep 12, 2022
9739876
Merge branch 'hpcaitech:main' into feature/better_flop_tensor
super-dainiu Sep 12, 2022
3745c5f
[fx] provide an accurate estimation of memory except for GPT-2.
super-dainiu Sep 12, 2022
504c607
[fx] provide an accurate estimation of memory except for GPT-2.
super-dainiu Sep 12, 2022
5d72a52
[fx] a very accurate version on GPT-2.
super-dainiu Sep 13, 2022
6bdeb29
[fx] refactor code.
super-dainiu Sep 13, 2022
fafb7d0
[fx] remove redundant inplace=True.
super-dainiu Sep 13, 2022
d3c3690
[fx] refactor code.
super-dainiu Sep 13, 2022
de2be8f
[fx] refactor code.
super-dainiu Sep 13, 2022
bc727c2
[fx] refactor code.
super-dainiu Sep 13, 2022
3d284af
Merge branch 'main' of https://github.com/super-dainiu/ColossalAI int…
super-dainiu Sep 13, 2022
13b1d58
[fx] dive into backward memory.
super-dainiu Sep 13, 2022
7df14a2
[fx] fix variable names in ckpt_solvers and unskip tests.
super-dainiu Sep 13, 2022
8025fb8
Merge branch 'main' of https://github.com/super-dainiu/ColossalAI int…
super-dainiu Sep 13, 2022
ce7b1a9
[fx] commit my changes.
super-dainiu Sep 14, 2022
4dd1b27
[fx] commit my changes.
super-dainiu Sep 14, 2022
5277b72
[fx] restore skips.
super-dainiu Sep 14, 2022
594ec0b
[fx] restore skips.
super-dainiu Sep 14, 2022
3144f8a
[fx] change stage into phase.
super-dainiu Sep 14, 2022
31d6bbf
[fx] change stage into phase.
super-dainiu Sep 14, 2022
aead7e1
[fx] change stage into phase.
super-dainiu Sep 14, 2022
5 changes: 3 additions & 2 deletions colossalai/fx/passes/algorithms/ckpt_solver_chen.py
@@ -73,10 +73,11 @@ def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
y = 0
prev_idx = 2
for (idx, n) in enumerate(gm.graph.nodes):
temp += getattr(n, 'fwd_out')
n: Node
temp += n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp']
y = max(y, temp)
if temp > b and n in ckpt_nodes:
x += getattr(n, 'fwd_out')
x += n.meta['fwd_mem_out']
FrankLeeeee marked this conversation as resolved.
temp = 0
ckpt_intv.append((prev_idx, idx + 1))
prev_idx = idx + 1
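As a reading aid for this hunk, here is a minimal standalone sketch of the budget check the updated lines perform, with hypothetical per-node memory numbers rather than real profiler output; it is only the greedy accumulation shown above, not the full Chen solver.

```python
# Minimal sketch of the greedy accumulation above (hypothetical numbers, not
# profiler output). Each node adds its forward output plus temporary memory;
# once the running total exceeds the budget b at a checkpointable node, that
# node's output is charged to the checkpoint cost and the running total resets.
nodes = [
    {"name": "conv1", "fwd_mem_out": 4, "fwd_mem_tmp": 1, "ckpt_ok": True},
    {"name": "relu1", "fwd_mem_out": 4, "fwd_mem_tmp": 0, "ckpt_ok": False},
    {"name": "conv2", "fwd_mem_out": 8, "fwd_mem_tmp": 2, "ckpt_ok": True},
]

def greedy_ckpt(nodes, b):
    temp, x, y, ckpt = 0, 0, 0, []
    for n in nodes:
        temp += n["fwd_mem_out"] + n["fwd_mem_tmp"]   # running activation memory
        y = max(y, temp)                              # peak seen so far
        if temp > b and n["ckpt_ok"]:
            x += n["fwd_mem_out"]                     # output kept for the checkpoint
            ckpt.append(n["name"])
            temp = 0                                  # segment will be recomputed later
    return ckpt, x, y

print(greedy_ckpt(nodes, b=10))   # (['conv2'], 8, 19)
```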
48 changes: 26 additions & 22 deletions colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
@@ -1,11 +1,11 @@
from typing import List, Set, Tuple, Dict
from typing import List, Tuple
import torch
from torch.fx import GraphModule, Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import parameter_size
import math
from .linearize import linearize
from .utils import *
from colossalai.fx.profiler import profile_function, profile_module
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions

@@ -25,8 +25,8 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
bw = chain.bweight ## backward time, not used
cw = chain.cweight + [0] ## size of x (and of y)
cbw = chain.cbweight + [0] ## size of xbar
fwd_tmp = chain.fwd_tmp + [0]
bwd_tmp = chain.bwd_tmp + [0]
fwd_mem_tmp = chain.fwd_mem_tmp + [0]
bwd_mem_tmp = chain.bwd_mem_tmp + [0]

# Build table
opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
@@ -37,7 +37,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
for m in range(mmax + 1):
for i in range(chain.length + 1):
#lmax-lmin = 0
limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i])
if m >= limit: ## Equation (1)
opt[m][i][i] = fw[i] + bw[i]
else:
@@ -49,9 +49,9 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
for i in range(chain.length + 1 - d):
# for idx in range(i+1, chain.length + 1):
idx = i + d
mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i]
mmin = cw[idx + 1] + cw[i + 1] + fwd_mem_tmp[i]
if idx > i + 1:
mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_tmp[j] for j in range(i + 1, idx)))
mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_mem_tmp[j] for j in range(i + 1, idx)))
if m < mmin:
opt[m][i][idx] = float("inf")
else:
@@ -165,7 +165,7 @@ def _fwd_xbar(node: List[Node]) -> int:

xbar = 0
for n in node:
xbar += n.fwd_tmp + n.fwd_out
xbar += n.meta['fwd_mem_tmp'] + n.meta['fwd_mem_out']
return xbar


@@ -183,7 +183,7 @@ def _fwd_time(node: List[Node]) -> int:
fwd_time = 0
for n in node:
# minimum flop count is needed
fwd_time += max(n.fwd_flop, 1)
fwd_time += max(n.meta['fwd_flop'], 1)
return fwd_time


@@ -201,11 +201,11 @@ def _bwd_time(node: List[Node]) -> int:
bwd_time = 0
for n in node:
# minimum flop count is needed
bwd_time += max(n.bwd_flop, 1)
bwd_time += max(n.meta['bwd_flop'], 1)
return bwd_time


def _get_bwd_tmp(node: List[Node]) -> int:
def _get_bwd_mem_tmp(node: List[Node]) -> int:
"""Get the backward temp memory of a node

Args:
@@ -218,29 +218,32 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:

def _get_deps_size():
deps_size = 0
for key in deps.keys():
deps_size += key.bwd_out
for k, v in deps.items():
if v > 0:
deps_size += k.meta['bwd_mem_out']

return deps_size

bwd_tmp = 0
bwd_mem_tmp = 0
deps = {}

# add all the users for last node into deps,
# as those nodes' gradient out will be stored in memory
for son in node[-1].users:
deps[son] = 1
for child in node[-1].users:
deps[child] = 1
for n in reversed(node):
bwd_tmp = max(bwd_tmp, _get_deps_size() + n.bwd_tmp)
deps[n] = len(n._input_nodes)
for son in n.users:
deps[son] -= 1
bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])

deps[n] = len(n.all_input_nodes)
for child in n.users:
if child in deps:
deps[child] -= 1

for key in list(deps.keys()):
if deps[key] == 0:
del deps[key]

return bwd_tmp
return bwd_mem_tmp
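As an aside, the liveness bookkeeping in the renamed `_get_bwd_mem_tmp` can be illustrated with plain dictionaries instead of torch.fx nodes; the memory numbers below are hypothetical, and the `users`/`inputs` fields stand in for `Node.users` and `Node.all_input_nodes`.

```python
# Toy illustration of the backward liveness count above: a node's gradient
# stays live until every consumer has used it, and the peak of live gradients
# plus the node's own bwd_mem_tmp is the segment's backward temp memory.
segment = [
    {"name": "a", "users": ["b"], "inputs": [],    "bwd_mem_out": 2, "bwd_mem_tmp": 1},
    {"name": "b", "users": ["c"], "inputs": ["a"], "bwd_mem_out": 4, "bwd_mem_tmp": 2},
    {"name": "c", "users": [],    "inputs": ["b"], "bwd_mem_out": 3, "bwd_mem_tmp": 1},
]
by_name = {n["name"]: n for n in segment}

def bwd_mem_tmp(segment):
    deps, peak = {}, 0

    def deps_size():
        return sum(by_name[k]["bwd_mem_out"] for k, v in deps.items() if v > 0)

    for child in segment[-1]["users"]:   # grads leaving the segment stay resident
        deps[child] = 1
    for n in reversed(segment):
        peak = max(peak, deps_size() + n["bwd_mem_tmp"])
        deps[n["name"]] = len(n["inputs"])
        for child in n["users"]:
            if child in deps:
                deps[child] -= 1
    return peak

print(bwd_mem_tmp(segment))   # 5
```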


def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
@@ -267,7 +270,7 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
bwd_time.append(_bwd_time(node))
x_sizes.append(_compute_output_size(node))
xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
tmp_bwd.append(_get_bwd_tmp(node))
tmp_bwd.append(_get_bwd_mem_tmp(node))

# if a node with only one inplace op, we need to let x_bar = 0
if len(node) == 1 and _get_inplace(node[0]):
@@ -394,6 +397,7 @@ def solver_rotor(gm: ColoGraphModule,
"""

node_list = linearize(gm, cnode)
mem_limit -= parameter_size(gm)
mem_unit = mem_limit * (1.0 - eps) // mem_slots
MetaInfoProp(gm).run(data)
chain: Chain = _construct_chain(node_list, data, mem_unit)
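The new `mem_limit -= parameter_size(gm)` line can be read in isolation as follows: parameters stay resident for the whole run, so only what remains of the device budget is sliced into memory slots for the rotor table. A sketch with made-up numbers, not values from this PR:

```python
# Made-up numbers illustrating the budget adjustment in solver_rotor: parameter
# memory is subtracted first, and only the remainder is divided into slots.
GiB = 1024 ** 3
mem_limit = 8 * GiB          # example device budget
param_mem = int(1.5 * GiB)   # what parameter_size(gm) might return for some model
mem_slots, eps = 500, 0.02

mem_limit -= param_mem
mem_unit = mem_limit * (1.0 - eps) // mem_slots
print(f"per-slot activation budget: {mem_unit / 1024**2:.1f} MiB")   # ~13.0 MiB
```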
14 changes: 7 additions & 7 deletions colossalai/fx/passes/algorithms/utils.py
@@ -5,24 +5,24 @@ def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True):
self.bweight = bw
self.cweight = cw
self.cbweight = cbw
self.fwd_tmp = ftmp
self.bwd_tmp = btmp
self.fwd_mem_tmp = ftmp
self.bwd_mem_tmp = btmp
self.length = len(fw)
if check and not self.check_lengths():
raise AttributeError("In Chain, input lists do not have consistent lengths")

def check_lengths(self):
return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1)
and (len(self.cweight) == self.length + 1) and (len(self.fwd_tmp) == self.length)
and (len(self.bwd_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))
and (len(self.cweight) == self.length + 1) and (len(self.fwd_mem_tmp) == self.length)
and (len(self.bwd_mem_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))

def __repr__(self):
chain_list = []
for i in range(self.length):
chain_list.append(
(self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_tmp[i], self.bwd_tmp[i]))
chain_list.append((self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_mem_tmp[i],
self.bwd_mem_tmp[i]))
FrankLeeeee marked this conversation as resolved.
i = self.length
chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_tmp[i]))
chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i]))
return chain_list.__repr__()


5 changes: 2 additions & 3 deletions colossalai/fx/passes/meta_info_prop.py
@@ -94,12 +94,11 @@ def extract_tensor_meta(obj):

tensor_meta = tree_map(extract_tensor_meta, result)
n.meta['tensor_meta'] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`

n.meta = {**n.meta, **asdict(meta_info), 'fwd_mem_out': 0} # extend MetaInfo to `n.meta`
Comment from the PR author: avoid a doubled MetaInfoProp run that would introduce a doubled fwd_mem_out.
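A toy illustration of the point above (not the actual MetaInfoProp code): `fwd_mem_out` is filled in afterwards from the children's `fwd_mem_in`, so each node must start from 0 when its meta dict is rebuilt, otherwise a second propagation pass over the same graph would fold the old value in again.

```python
# Toy dicts only; the real code works on torch.fx Node.meta.
stale_meta = {'fwd_flop': 10, 'fwd_mem_out': 32}     # left over from a previous pass
fresh_info = {'fwd_flop': 10, 'fwd_mem_in': 16}      # freshly profiled MetaInfo

# The reset in the diff: merge the new info but zero fwd_mem_out before the
# children re-populate it.
meta = {**stale_meta, **fresh_info, 'fwd_mem_out': 0}
assert meta['fwd_mem_out'] == 0    # the stale 32 can no longer be double counted
```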

# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
for par in n.all_input_nodes:
par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0)
par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0))
Comment from the PR author: max is more plausible than summation for this calculation.
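One way to read the comment above, with toy numbers that are not from the PR: when a parent's saved output feeds several children, `max` counts that tensor once, whereas the previous `+=` counted it once per consumer.

```python
# Three children each report the same 16-byte saved input from one parent.
children_fwd_mem_in = [16, 16, 16]

fwd_mem_out_sum = 0   # previous behaviour
fwd_mem_out_max = 0   # behaviour after this change
for mem_in in children_fwd_mem_in:
    fwd_mem_out_sum += mem_in
    fwd_mem_out_max = max(fwd_mem_out_max, mem_in)

print(fwd_mem_out_sum, fwd_mem_out_max)   # 48 16
```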

n.meta['type'] = type(result)

# retain the autograd graph
Expand Down
50 changes: 18 additions & 32 deletions colossalai/fx/profiler/dataflow.py
@@ -1,11 +1,12 @@
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Dict
from torch.fx import Graph, Node
from .memory import activation_size


class Stage(Enum):
class Phase(Enum):
Comment from the PR author: Stage should be renamed to Phase, consistent with the RPC notion of a phase.
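For context, a self-contained sketch of how the renamed enum and the single parameterized check are used; the `FakeNode` class and the `PLACEHOLDER` value are assumptions made for the sketch, since only FORWARD/LOSS/BACKWARD are visible in this hunk.

```python
from enum import Enum
from functools import partial

class Phase(Enum):
    FORWARD = 0
    LOSS = 1
    BACKWARD = 2
    PLACEHOLDER = 3   # assumed member; it is referenced later in the diff

class FakeNode:
    """Stand-in for torch.fx.Node with only the meta dict used here."""
    def __init__(self, phase):
        self.meta = {'phase': phase}

def is_phase(n, phase: Phase) -> bool:
    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
    return n.meta['phase'] == phase

n = FakeNode(Phase.FORWARD)
print(is_phase(n, Phase.FORWARD))                           # True
print(any(map(partial(is_phase, phase=Phase.LOSS), [n])))   # False
```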

FORWARD = 0
LOSS = 1
BACKWARD = 2
@@ -48,24 +49,9 @@ class GraphInfo:
bwd_mem_out: int = 0


def is_forward(n: Node):
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
return n.meta['stage'] == Stage.FORWARD


def is_loss(n: Node):
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
return n.meta['stage'] == Stage.LOSS


def is_placeholder(n: Node):
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
return n.meta['stage'] == Stage.PLACEHOLDER


def is_backward(n: Node):
assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
return n.meta['stage'] == Stage.BACKWARD
def is_phase(n: Node, phase: Phase) -> bool:
assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
return n.meta['phase'] == phase


def is_saved(n: Node):
@@ -74,7 +60,7 @@ def is_saved(n: Node):

def autograd_graph_analysis(graph: Graph) -> GraphInfo:
"""Analyze the autograd node dependencies and find out the memory usage.
Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`.
Basically the input graph should have all nodes marked for keyword `phase`.
Nodes should have attribute `out` indicating the output of each node.
============================================================================
Placeholder ----> p o <---- We need to keep track of grad out
@@ -91,38 +77,38 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
l
=============================================================================
Args:
graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`.
graph (Graph): The autograd graph with nodes marked for keyword `phase`.

Returns:
graph_info (GraphInfo): Meta information for the dataflow.
"""

def _peak_memory(deps: Dict[Node, int]):
bwd_tmp = 0
peak_mem = 0
for k, v in deps.items():
if v > 0:
bwd_tmp += activation_size(k.meta['out'])
return bwd_tmp
peak_mem += activation_size(k.meta['out'])
return peak_mem

# deps is used to track all the memory dependencies of the graph.
deps = {}
graph_info = GraphInfo()

for n in graph.nodes:
n: Node
if is_saved(n) and not any(map(is_loss, n.users)):
if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)):
# A forward tensor who is marked `save` but is not
# an input to `loss` should be saved during forward.
# If the tensor is a placeholder, then it belongs to `fwd_in`.
# Any `fwd_in` should be kept in memory even this function
# If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
# Any `fwd_mem_in` should be kept in memory even this function
# is checkpointed.
# Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint
# the node, `fwd_tmp` can be freed.
if is_placeholder(n):
# Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
# the node, `fwd_mem_tmp` can be freed.
if is_phase(n, Phase.PLACEHOLDER):
graph_info.fwd_mem_in += activation_size(n.meta['out'])
if is_forward(n):
if is_phase(n, Phase.FORWARD):
graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
elif is_backward(n):
elif is_phase(n, Phase.BACKWARD):
if len(n.users):
# liveness analysis is only used in backward
deps[n] = len(n.users)