Skip to content

Commit

Permalink
[autoparallel] fix C version rotor inconsistency (#1691)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cypher30 committed Oct 12, 2022
1 parent 363fc28 commit 31d2f03
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 22 deletions.
30 changes: 25 additions & 5 deletions colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.logging import get_dist_logger

# global variable to indicate whether the solver has failed
SOLVER_FAILED = False


# this is the python compute table code from rotor
# https://gitlab.inria.fr/hiepacs/rotor
Expand Down Expand Up @@ -87,9 +90,17 @@ def _rec(chain: Chain, lmin, lmax, cmem, opt_table):
opt, what = opt_table
sequence = Sequence(Function("Persistent", lmax - lmin, cmem))
if opt[cmem][lmin][lmax] == float("inf"):
raise ValueError("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin,
lmax=lmax,
cmem=cmem))
# use the logger to announce that the solver has failed
logger = get_dist_logger()
logger.info("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin,
lmax=lmax,
cmem=cmem))

# set global indicator SOLVER_FAILED to True
global SOLVER_FAILED
SOLVER_FAILED = True
return sequence

if lmin == lmax:
if lmin == chain.length:
sequence.insert(Loss())
Expand Down Expand Up @@ -406,9 +417,18 @@ def solver_rotor(gm: ColoGraphModule,

# found sequence
sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
_annotate_from_sequence(sequence, node_list)

# if solver failed, we don't need to annotate the graph
if not SOLVER_FAILED:
_annotate_from_sequence(sequence, node_list)

# set __sequence__ attribute to GraphModule
setattr(gm, "__sequence__", sequence)
if SOLVER_FAILED:
setattr(gm, "__sequence__", None)
else:
setattr(gm, "__sequence__", sequence)

# set __opttable__ attribute to GraphModule
setattr(gm, "__opttable__", opt_table[0])
gm.recompile()
return gm
33 changes: 18 additions & 15 deletions colossalai/fx/passes/algorithms/dynamic_programs.c
Original file line number Diff line number Diff line change
Expand Up @@ -94,24 +94,27 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
OPT(m, i, i) = INFINITY;

for (long m = 0; m <= mmax; ++m)
for (long i = 0; i <= chain_length; ++i) {
long maxCostFWD = 0;
for (long l = i + 1; l <= chain_length; ++l) {
long mmin = cw[l + 1] + cw[i + 1] + fwd_tmp[i];
if (l > i + 1) {
maxCostFWD = fmaxl(maxCostFWD, cw[l - 1] + cw[l] + fwd_tmp[l - 1]);
mmin = fmaxl(mmin, cw[l + 1] + maxCostFWD);
for (long d = 1; d <= chain_length; ++d) {
for (long i = 0; i <= chain_length - d; ++i) {
long idx = i + d;
long mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i];
if (idx > i + 1) {
long maxCostFWD = 0;
for (long j = i + 1; j < idx; j++) {
maxCostFWD = fmaxl(maxCostFWD, cw[j] + cw[j + 1] + fwd_tmp[j]);
}
mmin = fmaxl(mmin, cw[idx + 1] + maxCostFWD);
}
if ((m >= mmin)) {
long bestLeaf = -1;
double sumFw = 0;
double bestLeafCost = INFINITY;
/// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j =
/// i+1
for (long j = i + 1; j <= l; ++j) {
for (long j = i + 1; j <= idx; ++j) {
sumFw += fw[j - 1];
if (m >= cw[j]) {
double cost = sumFw + OPT(m - cw[j], j, l) + OPT(m, i, j - 1);
double cost = sumFw + OPT(m - cw[j], j, idx) + OPT(m, i, j - 1);
if (cost < bestLeafCost) {
bestLeafCost = cost;
bestLeaf = j;
Expand All @@ -120,16 +123,16 @@ static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
}
double chainCost = INFINITY;
if (m >= cbw[i + 1])
chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, l);
chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, idx);
if (bestLeafCost <= chainCost) {
OPT(m, i, l) = bestLeafCost;
WHAT(m, i, l) = bestLeaf;
OPT(m, i, idx) = bestLeafCost;
WHAT(m, i, idx) = bestLeaf;
} else {
OPT(m, i, l) = chainCost;
WHAT(m, i, l) = -1;
OPT(m, i, idx) = chainCost;
WHAT(m, i, idx) = -1;
}
} else
OPT(m, i, l) = INFINITY;
OPT(m, i, idx) = INFINITY;
}
}

Expand Down
13 changes: 11 additions & 2 deletions tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
def _run_C_solver_consistency_test(rank=0):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

for M, mem_budget in [(tm.resnet18, 2000), (tm.resnet50, 8000)]:
for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]:
model = M()
data = torch.rand(128, 3, 224, 224, device='meta')

Expand All @@ -41,15 +41,24 @@ def _run_C_solver_consistency_test(rank=0):
# python solver
gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024, force_python=True)
sequence_python: Sequence = copy.deepcopy(gm.__sequence__)
opt_python = copy.deepcopy(gm.__opttable__)

# C solver
gm = solver_rotor(gm, data_meta, mem_budget * 1024 * 1024)
sequence_C: Sequence = copy.deepcopy(gm.__sequence__)
opt_C = copy.deepcopy(gm.__opttable__)

# make sure the opt_tables are the same
for m in range(len(opt_python)):
for d in range(1, len(opt_python[0])):
for i in range(len(opt_python[0]) - d):
assert opt_python[m][i][i + d] == opt_C[m][i][i + d], \
f"item ({m}, {i}, {i + d}) is not consistent with python version!\npython version: {opt_python[m][i][i + d]}\nC version: {opt_C[m][i][i + d]}"

sequence_python = sequence_python.list_operations()
sequence_C = sequence_C.list_operations()

# make sure the solutions are the same
# make sure the sequences are the same
assert len(sequence_python) == len(sequence_C) and \
all(python_op.__repr__() == C_op.__repr__() for (python_op, C_op) in zip(sequence_python, sequence_C))

Expand Down

0 comments on commit 31d2f03

Please sign in to comment.