From 3b94aed08f0d5c6fa1dcdb7332509cba61ce9b1f Mon Sep 17 00:00:00 2001
From: yongfeng-nv <49211903+yongfeng-nv@users.noreply.github.com>
Date: Thu, 9 Apr 2020 21:49:37 -0400
Subject: [PATCH] Create loops according to storage scope and thread hierarchies (#5190)

* Set IterVar index to 0 for local thread-bound IterVars.

* Lint fix

* Use rank instead of scope name in the predicate. Add tests.

* Handle cases other than local/threadIdx.

* Revert warp to the old behavior.

* Modify test to cover global/blockIdx.

* Fix a typo.

* Update test_te_schedule_ops.py with more testing coverage in
  test_local_stage_predicate; remove test_schedule_schedule_ops.py,
  which was added by mistake.
---
 src/te/operation/op_util.cc                   |  9 +-
 tests/python/unittest/test_te_schedule_ops.py | 88 +++++++++++++++++++
 2 files changed, 96 insertions(+), 1 deletion(-)

diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc
index 3714f439bd2b6..4ecfe9472901c 100644
--- a/src/te/operation/op_util.cc
+++ b/src/te/operation/op_util.cc
@@ -29,6 +29,7 @@
 #include "op_util.h"
 #include "../schedule/message_passing.h"
 #include "../../arith/compute_expr.h"
+#include "../../runtime/thread_storage_scope.h"
 
 namespace tvm {
 namespace te {
@@ -162,7 +163,13 @@ MakeLoopNest(const Stage& stage,
       if (!debug_keep_trivial_loop && is_one(dom->extent)) {
         value_map[iv] = dom->min;
       } else {
-        value_map[iv] = var;
+        runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag);
+        if (stage->scope == "" || stage->scope == "warp" ||
+            static_cast<int>(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) {
+          value_map[iv] = var;
+        } else {
+          value_map[iv] = dom->min;
+        }
       }
     }
     // annotate the extent of the IterVar
diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py
index 8d10ceea0b48a..4e27ad3f2a58f 100644
--- a/tests/python/unittest/test_te_schedule_ops.py
+++ b/tests/python/unittest/test_te_schedule_ops.py
@@ -482,6 +482,92 @@ def _compute(*index) :
     bounds = tvm.te.schedule.InferBound(s)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
+
+def test_local_stage_predicate():
+    m = 1
+    n = 3
+    p = 2
+    A = tvm.te.placeholder((m, n, p), name='A')
+    B = tvm.te.compute((m, n, p), lambda bi, bj, bk: A[bi, bj, bk], name="B")
+    C = tvm.te.compute((m, n, p), lambda ci, cj, ck: B[ci, cj, ck], name="C")
+    by = tvm.te.thread_axis("blockIdx.y")
+    tx = tvm.te.thread_axis("threadIdx.x")
+    vx = tvm.te.thread_axis("vthread")
+
+    def schedule(thread_tag, mem_scope):
+        s = tvm.te.create_schedule(C.op)
+        s[B].compute_at(s[C], s[C].op.axis[0])
+        s[B].set_scope(mem_scope)
+        bno, bni = s[B].split(s[B].op.axis[1], n)
+        bx = tvm.te.thread_axis("blockIdx.x")
+        s[C].bind(s[C].op.axis[0], bx)
+        s[C].bind(s[C].op.axis[1], thread_tag)
+        s[B].bind(bni, thread_tag)
+        return s
+
+    def collect_visit(stmt, f):
+        ret = []
+        tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
+        return ret
+    # local vs. threadIdx
+    s = schedule(tx, "local")
+    lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+    assert (not any(
+        collect_visit(lowered_body,
+                      lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    # local vs. vthread
+    s = schedule(vx, "local")
+    lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+    assert (not any(
+        collect_visit(lowered_body,
+                      lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    # shared vs. blockIdx
+    s = schedule(by, "shared")
+    lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+    assert (not any(
+        collect_visit(lowered_body,
+                      lambda x: isinstance(x, tvm.tir.IfThenElse))))
+
+def test_local_stage_predicate2():
+    A = tvm.te.placeholder((128, ), name="A")
+    B = tvm.te.compute((128, ), lambda bi: A[bi] + 1, name="B")
+    C = tvm.te.compute((128, ), lambda ci: B[ci] + 2, name="C")
+    s = tvm.te.create_schedule(C.op)
+    AA = s.cache_read(A, "local", [B])
+    s[B].set_scope("shared")
+    block_x = tvm.te.thread_axis("blockIdx.x")
+    thread_x = tvm.te.thread_axis((0, 32), "threadIdx.x")
+    oc, ic = s[C].split(s[C].op.axis[0], factor=64)
+    ooc, ioc = s[C].split(oc, factor=2)
+    oic, iic = s[C].split(ic, factor=32)
+    s[C].bind(ooc, block_x)
+    s[C].bind(iic, thread_x)
+    s[B].compute_at(s[C], ioc)
+    ob, ib = s[B].split(s[B].op.axis[0], factor=32)
+    s[B].bind(ib, thread_x)
+    s[AA].compute_root()
+    s[AA].compute_at(s[C], ooc)
+    oaa, iaa = s[AA].split(s[AA].op.axis[0], factor=32)
+    s[AA].bind(iaa, thread_x)
+    lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+
+    def collect_visit(stmt, f):
+        ret = []
+        tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
+        return ret
+
+    def visit_stmt(op):
+        print(op)
+        if (isinstance(op, tvm.tir.Allocate)):
+            return op.extents[0].value == 97
+        return False
+
+    assert (not any(
+        collect_visit(lowered_body,
+                      lambda x: isinstance(x, tvm.tir.IfThenElse))))
+    assert (any(collect_visit(lowered_body, visit_stmt)))
+
+
 if __name__ == "__main__":
     test_loop_dep_reduce()
     test_loop_dep_reduce_cache_write()
@@ -506,3 +592,5 @@ def _compute(*index) :
     test_schedule_tensor_compute3()
     test_reduction_and_dummy_fuse_split()
     test_schedule_compute_inline()
+    test_local_stage_predicate()
+    test_local_stage_predicate2()
\ No newline at end of file
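
Usage note (outside the patch): below is a minimal sketch of the behavior this
change enables, written against the same era's TVM Python API the tests above
use. The tensor names and shapes are illustrative, not taken from the patch.
When a stage in "local" scope is bound to threadIdx.x, its storage rank is
deeper than the thread rank, so MakeLoopNest now pins the bound loop index to
dom->min instead of the thread variable and no IfThenElse guard is emitted.

import tvm

# Hypothetical tensors; any compute chain with an intermediate stage works.
A = tvm.te.placeholder((4, 32), name="A")
B = tvm.te.compute((4, 32), lambda i, j: A[i, j] + 1, name="B")
C = tvm.te.compute((4, 32), lambda i, j: B[i, j] * 2, name="C")

s = tvm.te.create_schedule(C.op)
bx = tvm.te.thread_axis("blockIdx.x")
tx = tvm.te.thread_axis("threadIdx.x")
s[B].compute_at(s[C], s[C].op.axis[0])
s[B].set_scope("local")          # "local" storage rank exceeds threadIdx's rank
s[C].bind(s[C].op.axis[0], bx)
s[C].bind(s[C].op.axis[1], tx)
s[B].bind(s[B].op.axis[1], tx)   # thread-bound loop over a local-scope stage

# With this patch the lowered body is straight-line code; before it, B's
# thread-bound loop carried an IfThenElse predicate.
print(tvm.lower(s, [A, C], simple_mode=True))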