Skip to content

Commit

Permalink
Create loops according to storage scope and thread hierarchies (apach…
Browse files Browse the repository at this point in the history
…e#5190)

* Set IterVar index to 0 for local thread bound IterVars.

* Lint fix

* Use rank instead of scope name to predicate.  Add tests.

* Handle cases other than local/threadIdx.

* Turn warp to the old behavior.

* Modify test to cover global/blockIdx.

* Fix a typo.

* Update test_te_schedule_ops.py with more testing coverage in test_local_stage_predicate; remove test_schedule_schedule_ops.py which was added by mistake.
  • Loading branch information
yongfeng-nv authored and dpankratz committed Apr 24, 2020
1 parent 5cae0f6 commit 3b94aed
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/te/operation/op_util.cc
Expand Up @@ -29,6 +29,7 @@
#include "op_util.h"
#include "../schedule/message_passing.h"
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"

namespace tvm {
namespace te {
Expand Down Expand Up @@ -162,7 +163,13 @@ MakeLoopNest(const Stage& stage,
if (!debug_keep_trivial_loop && is_one(dom->extent)) {
value_map[iv] = dom->min;
} else {
value_map[iv] = var;
runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag);
if (stage->scope == "" || stage->scope == "warp" ||
static_cast<int>(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) {
value_map[iv] = var;
} else {
value_map[iv] = dom->min;
}
}
}
// annotate the extent of the IterVar
Expand Down
88 changes: 88 additions & 0 deletions tests/python/unittest/test_te_schedule_ops.py
Expand Up @@ -482,6 +482,92 @@ def _compute(*index) :
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)


def test_local_stage_predicate():
m = 1
n = 3
p = 2
A = tvm.te.placeholder((m, n, p), name='A')
B = tvm.te.compute((m, n, p), lambda bi, bj, bk: A[bi, bj, bk], name="B")
C = tvm.te.compute((m, n, p), lambda ci, cj, ck: B[ci, cj, ck], name="C")
by = tvm.te.thread_axis("blockIdx.y")
tx = tvm.te.thread_axis("threadIdx.x")
vx = tvm.te.thread_axis("vthread")

def schedule(thread_tag, mem_scope) :
s = tvm.te.create_schedule(C.op)
s[B].compute_at(s[C], s[C].op.axis[0])
s[B].set_scope(mem_scope)
bno, bni = s[B].split(s[B].op.axis[1], n)
bx = tvm.te.thread_axis("blockIdx.x")
s[C].bind(s[C].op.axis[0], bx)
s[C].bind(s[C].op.axis[1], thread_tag)
s[B].bind(bni, thread_tag)
return s

def collect_visit(stmt, f):
ret = []
tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
return ret
# local vs. threadIdx
s = schedule(tx, "local")
lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
assert (not any(
collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse))))
# local vs. vthread
s = schedule(vx, "local")
lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
assert (not any(
collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse))))
# shared vs. blockIdx
s = schedule(by, "shared")
lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
assert (not any(
collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse))))

def test_local_stage_predicate2():
A = tvm.te.placeholder((128, ), name="A")
B = tvm.te.compute((128, ), lambda bi: A[bi] + 1, name="B")
C = tvm.te.compute((128, ), lambda ci: B[ci] + 2, name="C")
s = tvm.te.create_schedule(C.op)
AA = s.cache_read(A, "local", [B])
s[B].set_scope("shared")
block_x = tvm.te.thread_axis("blockIdx.x")
thread_x = tvm.te.thread_axis((0, 32), "threadIdx.x")
oc, ic = s[C].split(s[C].op.axis[0], factor=64)
ooc, ioc = s[C].split(oc, factor=2)
oic, iic = s[C].split(ic, factor=32)
s[C].bind(ooc, block_x)
s[C].bind(iic, thread_x)
s[B].compute_at(s[C], ioc)
ob, ib = s[B].split(s[B].op.axis[0], factor=32)
s[B].bind(ib, thread_x)
s[AA].compute_root()
s[AA].compute_at(s[C], ooc)
oaa, iaa = s[AA].split(s[AA].op.axis[0], factor=32)
s[AA].bind(iaa, thread_x)
lowered_body = tvm.lower(s, [A, C], simple_mode=True).body

def collect_visit(stmt, f):
ret = []
tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
return ret

def visit_stmt(op):
print(op)
if (isinstance(op, tvm.tir.Allocate)):
return op.extents[0].value == 97
return False

assert (not any(
collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse))))
assert (any(collect_visit(lowered_body, visit_stmt)))


if __name__ == "__main__":
test_loop_dep_reduce()
test_loop_dep_reduce_cache_write()
Expand All @@ -506,3 +592,5 @@ def _compute(*index) :
test_schedule_tensor_compute3()
test_reduction_and_dummy_fuse_split()
test_schedule_compute_inline()
test_local_stage_predicate()
test_local_stage_predicate2()

0 comments on commit 3b94aed

Please sign in to comment.