diff --git a/python/tvm/script/builder/tir/__init__.py b/python/tvm/script/builder/tir/__init__.py index d74b73341758..48046f0e9175 100644 --- a/python/tvm/script/builder/tir/__init__.py +++ b/python/tvm/script/builder/tir/__init__.py @@ -19,7 +19,7 @@ from . import axis from .base import TIRFrame -from .block_frame import block +from .block_frame import block, where, reads, writes, alloc_buffer, block_attr, init from .for_frame import ( ForFrame, grid, diff --git a/python/tvm/script/builder/tir/axis.py b/python/tvm/script/builder/tir/axis.py index 7be7cd42aba2..acbcda108dde 100644 --- a/python/tvm/script/builder/tir/axis.py +++ b/python/tvm/script/builder/tir/axis.py @@ -21,6 +21,8 @@ from . import _ffi_api +from typing import List + def spatial(dom, binding, dtype="int32") -> IterVar: if not isinstance(dom, Range): @@ -34,7 +36,19 @@ def reduce(dom, binding, dtype="int32") -> IterVar: return _ffi_api.AxisReduce(dom, binding, dtype) # pylint: disable=no-member # type: ignore -def remap(kinds, bindings, dtype="int32") -> IterVar: +def scan(dom, binding, dtype="int32") -> IterVar: + if not isinstance(dom, Range): + dom = Range(0, dom) + return _ffi_api.AxisScan(dom, binding, dtype) # pylint: disable=no-member # type: ignore + + +def opaque(dom, binding, dtype="int32") -> IterVar: + if not isinstance(dom, Range): + dom = Range(0, dom) + return _ffi_api.AxisOpaque(dom, binding, dtype) # pylint: disable=no-member # type: ignore + + +def remap(kinds, bindings, dtype="int32") -> List[IterVar]: return _ffi_api.AxisRemap(kinds, bindings, dtype) # pylint: disable=no-member # type: ignore diff --git a/python/tvm/script/builder/tir/block_frame.py b/python/tvm/script/builder/tir/block_frame.py index aa447a409e89..890b3c44eaa1 100644 --- a/python/tvm/script/builder/tir/block_frame.py +++ b/python/tvm/script/builder/tir/block_frame.py @@ -20,11 +20,71 @@ from . import _ffi_api from .base import TIRFrame +from typing import List, Dict, Any, Union +from tvm.tir import Buffer, BufferLoad, BufferRegion + @_register_object("script.builder.tir.BlockFrame") class BlockFrame(TIRFrame): ... -def block(name) -> BlockFrame: - return _ffi_api.BlockFrame(name) # pylint: disable=no-member # type: ignore +@_register_object("script.builder.tir.BlockInitFrame") +class BlockInitFrame(TIRFrame): + ... + + +def block(name: str, no_realize: bool = False) -> BlockFrame: + return _ffi_api.BlockFrame(name, no_realize) # pylint: disable=no-member # type: ignore + + +def init() -> BlockInitFrame: + return _ffi_api.BlockInitFrame() # pylint: disable=no-member # type: ignore + + +def where(predicate) -> None: + _ffi_api.Where(predicate) # pylint: disable=no-member # type: ignore + + +def reads(buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None: + if not isinstance(buffer_slices, List): + buffer_slices = [buffer_slices] + _ffi_api.Reads(buffer_slices) + + +def writes(buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None: + if not isinstance(buffer_slices, List): + buffer_slices = [buffer_slices] + _ffi_api.Writes(buffer_slices) + + +def block_attr(attrs: Dict[str, Any]) -> None: + return _ffi_api.BlockAttrs(attrs) # pylint: disable=no-member # type: ignore + + +def alloc_buffer( + shape, + dtype="float32", + data=None, + strides=[], + elem_offset=None, + storage_scope="", + align=-1, + offset_factor=0, + buffer_type="default", + axis_separators=None, + span=None, +) -> Buffer: + return _ffi_api.AllocBuffer( + shape, + dtype, + data, + strides, + elem_offset, + storage_scope, + align, + offset_factor, + buffer_type, + axis_separators, + span, + ) diff --git a/python/tvm/script/builder/tir/for_frame.py b/python/tvm/script/builder/tir/for_frame.py index 051565492882..9f81ecccff5a 100644 --- a/python/tvm/script/builder/tir/for_frame.py +++ b/python/tvm/script/builder/tir/for_frame.py @@ -32,23 +32,23 @@ def __enter__(self) -> List[Var]: return self.vars -def serial(start, stop, annotations) -> ForFrame: +def serial(start, stop, annotations=None) -> ForFrame: return _ffi_api.Serial(start, stop, annotations) # pylint: disable=no-member # type: ignore -def parallel(start, stop, annotations) -> ForFrame: +def parallel(start, stop, annotations=None) -> ForFrame: return _ffi_api.Parallel(start, stop, annotations) # pylint: disable=no-member # type: ignore -def vectorized(start, stop, annotations) -> ForFrame: +def vectorized(start, stop, annotations=None) -> ForFrame: return _ffi_api.Vectorized(start, stop, annotations) # pylint: disable=no-member # type: ignore -def unroll(start, stop, annotations) -> ForFrame: +def unroll(start, stop, annotations=None) -> ForFrame: return _ffi_api.Unroll(start, stop, annotations) # pylint: disable=no-member # type: ignore -def thread_binding(start, stop, thread, annotations) -> ForFrame: +def thread_binding(start, stop, thread, annotations=None) -> ForFrame: return _ffi_api.ThreadBinding( # pylint: disable=no-member # type: ignore start, stop, thread, annotations ) diff --git a/src/script/builder/builder.h b/src/script/builder/builder.h index 0bbfee9688e5..996848c5a1ab 100644 --- a/src/script/builder/builder.h +++ b/src/script/builder/builder.h @@ -43,6 +43,8 @@ class BuilderNode : public runtime::Object { public: template inline Optional FindFrame() const; + template + inline Optional GetLastFrame() const; template inline TObjectRef Get() const; @@ -90,6 +92,15 @@ inline Optional BuilderNode::FindFrame() const { return NullOpt; } +template +inline Optional BuilderNode::GetLastFrame() const { + using TFrameNode = typename TFrame::ContainerType; + if (!frames.empty() && frames.back()->IsInstance()) { + return Downcast(frames.back()); + } + return NullOpt; +} + template inline TObjectRef BuilderNode::Get() const { using TObject = typename TObjectRef::ContainerType; diff --git a/src/script/builder/tir/block_frame.cc b/src/script/builder/tir/block_frame.cc index f2167f557589..663017624e09 100644 --- a/src/script/builder/tir/block_frame.cc +++ b/src/script/builder/tir/block_frame.cc @@ -16,18 +16,21 @@ * specific language governing permissions and limitations * under the License. */ + #include "./block_frame.h" #include #include "./for_frame.h" +#include "./utils.h" +#include "./var.h" namespace tvm { namespace script { namespace builder { namespace tir { -BlockFrame Block_(String name) { +BlockFrame Block_(String name, bool no_realize) { ObjectPtr n = make_object(); n->name = name; n->iter_vars.clear(); @@ -39,24 +42,116 @@ BlockFrame Block_(String name) { n->annotations.clear(); n->iter_values.clear(); n->predicate = NullOpt; + n->no_realize = no_realize; return BlockFrame(n); } void BlockFrameNode::ExitWithScope() { using namespace tvm::tir; TIRFrameNode::ExitWithScope(); - AddToParent(BlockRealize(iter_values, // - predicate.value_or(Bool(true)), - Block(iter_vars, // - reads, writes, // - name, // - AsStmt(stmts), // - init, // - alloc_buffers, // - match_buffers, // - annotations))); + Block block = Block(iter_vars, reads, writes, name, AsStmt(stmts), init, alloc_buffers, + match_buffers, annotations); + if (no_realize) { + CHECK(iter_values.empty()) << "ValueError: Block bindings are not allowed when `no_realize=True`"; + CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`"; + AddToParent(block); + } else { + AddToParent(BlockRealize(iter_values, predicate.value_or(Bool(true)), block)); + } +} + +BlockInitFrame Init() { + ObjectPtr n = make_object(); + return BlockInitFrame(n); +} + +void BlockInitFrameNode::EnterWithScope() { + BlockFrame frame = FindBlockFrame("T.init"); + if (frame->init.defined()) { + LOG(FATAL) << "Duplicate block init declaration"; + } + TIRFrameNode::EnterWithScope(); +} + +void BlockInitFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + BlockFrame frame = FindBlockFrame("T.init"); + frame->init = AsStmt(stmts); } +BlockFrame FindBlockFrame(const String& method) { + if (Optional block_frame = Builder::Current()->GetLastFrame()) { + return block_frame.value(); + } else { + LOG(FATAL) << "ValueError: Block frame not find. Please ensure '" << method + << "' is called under T.block()"; + } + throw; +} + +void Where(PrimExpr predicate) { + BlockFrame frame = FindBlockFrame("T.where"); + if (frame->predicate.defined()) { + LOG(FATAL) << "Duplicate block predicate declaration, previous one is " + << frame->predicate.value(); + } + frame->predicate = predicate; +} + +void Reads(Array buffer_slices) { + using namespace tvm::tir; + BlockFrame frame = FindBlockFrame("T.reads"); + if (!frame->reads.empty()) { + LOG(FATAL) << "Duplicate read region declaration, previous one is " << frame->reads; + } + for (const ObjectRef& obj : buffer_slices) { + if (const auto* buffer_region = obj.as()) { + frame->reads.push_back(GetRef(buffer_region)); + } else if (const auto* buffer_load = obj.as()) { + frame->reads.push_back(BufferRegionFromLoad(GetRef(buffer_load))); + } else { + LOG(FATAL) << "Invalid type for buffer reads."; + } + } +} + +void Writes(Array buffer_slices) { + using namespace tvm::tir; + BlockFrame frame = FindBlockFrame("T.writes"); + if (!frame->writes.empty()) { + LOG(FATAL) << "Duplicate write region declaration, previous one is " << frame->writes; + } + for (const ObjectRef& obj : buffer_slices) { + if (const auto* buffer_region = obj.as()) { + frame->writes.push_back(GetRef(buffer_region)); + } else if (const auto* buffer_load = obj.as()) { + frame->writes.push_back(BufferRegionFromLoad(GetRef(buffer_load))); + } else { + LOG(FATAL) << "Invalid type for buffer writes."; + } + } +} + +void BlockAttrs(Map attrs) { + BlockFrame frame = FindBlockFrame("T.block_attr"); + if (!frame->annotations.empty()) { + LOG(FATAL) << "Duplicate block annotations, previous one is " << frame->annotations; + } + frame->annotations = attrs; +} + +tvm::tir::Buffer AllocBuffer(Array shape, DataType dtype, Optional data, + Array strides, PrimExpr elem_offset, String storage_scope, + int align, int offset_factor, String buffer_type_str, + Array axis_separators, Span span) { + using namespace tvm::tir; + Buffer buffer = DeclBuffer(shape, dtype, "", data, strides, elem_offset, storage_scope, align, + offset_factor, buffer_type_str, axis_separators, span); + BlockFrame frame = FindBlockFrame("T.alloc_buffer"); + frame->alloc_buffers.push_back(buffer); + return buffer; +}; + namespace axis { // TODO(@junrushao1994): figure out the Block syntax without BlockRealize @@ -72,27 +167,20 @@ tvm::tir::IterVar PushBlockVar(tvm::tir::IterVar iter_var, PrimExpr binding) { return iter_var; } -tvm::tir::IterVar Spatial(Range dom, PrimExpr binding, DataType dtype) { - using namespace tvm::tir; - ICHECK(dom.defined()) << "Spatial axis must have a domain"; - int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); - return PushBlockVar(IterVar(/*dom=*/dom, // - /*var=*/Var("_", dtype.with_bits(bits)), // - /*iter_type=*/IterVarType::kDataPar, // - /*thread_tag=*/""), - binding); -} - -tvm::tir::IterVar Reduce(Range dom, PrimExpr binding, DataType dtype) { - using namespace tvm::tir; - ICHECK(dom.defined()) << "Reduction axis must have a domain"; - int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); - return PushBlockVar(IterVar(/*dom=*/dom, // - /*var=*/Var("_", dtype.with_bits(bits)), // - /*iter_type=*/IterVarType::kCommReduce, // - /*thread_tag=*/""), - binding); -} +#define TVM_SCRIPT_BUILDER_TIR_AXIS_CREATE(Method, Kind, Name) \ + tvm::tir::IterVar Method(Range dom, PrimExpr binding, DataType dtype) { \ + using namespace tvm::tir; \ + ICHECK(dom.defined()) << Name << " axis must have a domain"; \ + int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); \ + return PushBlockVar(IterVar(/*dom=*/dom, /*var=*/Var("_", dtype.with_bits(bits)), \ + /*iter_type=*/Kind, /*thread_tag=*/""), \ + binding); \ + } +TVM_SCRIPT_BUILDER_TIR_AXIS_CREATE(Spatial, IterVarType::kDataPar, "Spatial"); +TVM_SCRIPT_BUILDER_TIR_AXIS_CREATE(Reduce, IterVarType::kCommReduce, "Reduction"); +TVM_SCRIPT_BUILDER_TIR_AXIS_CREATE(Scan, IterVarType::kOrdered, "Scan"); +TVM_SCRIPT_BUILDER_TIR_AXIS_CREATE(Opaque, IterVarType::kOpaque, "Opaque"); +#undef TVM_SCRIPT_BUILDER_TIR_AXIS_CREATE Array Remap(String kinds, Array bindings, DataType dtype) { using namespace tvm::tir; @@ -145,9 +233,18 @@ Array Remap(String kinds, Array bindings, DataType } // namespace axis TVM_REGISTER_NODE_TYPE(BlockFrameNode); +TVM_REGISTER_NODE_TYPE(BlockInitFrameNode); TVM_REGISTER_GLOBAL("script.builder.tir.BlockFrame").set_body_typed(Block_); +TVM_REGISTER_GLOBAL("script.builder.tir.BlockInitFrame").set_body_typed(Init); +TVM_REGISTER_GLOBAL("script.builder.tir.Where").set_body_typed(Where); +TVM_REGISTER_GLOBAL("script.builder.tir.Reads").set_body_typed(Reads); +TVM_REGISTER_GLOBAL("script.builder.tir.Writes").set_body_typed(Writes); +TVM_REGISTER_GLOBAL("script.builder.tir.BlockAttrs").set_body_typed(BlockAttrs); +TVM_REGISTER_GLOBAL("script.builder.tir.AllocBuffer").set_body_typed(AllocBuffer); TVM_REGISTER_GLOBAL("script.builder.tir.AxisSpatial").set_body_typed(axis::Spatial); TVM_REGISTER_GLOBAL("script.builder.tir.AxisReduce").set_body_typed(axis::Reduce); +TVM_REGISTER_GLOBAL("script.builder.tir.AxisScan").set_body_typed(axis::Scan); +TVM_REGISTER_GLOBAL("script.builder.tir.AxisOpaque").set_body_typed(axis::Opaque); TVM_REGISTER_GLOBAL("script.builder.tir.AxisRemap").set_body_typed(axis::Remap); } // namespace tir diff --git a/src/script/builder/tir/block_frame.h b/src/script/builder/tir/block_frame.h index 05e1969e5a54..7137a9a2bfca 100644 --- a/src/script/builder/tir/block_frame.h +++ b/src/script/builder/tir/block_frame.h @@ -39,6 +39,7 @@ class BlockFrameNode : public TIRFrameNode { Array iter_values; Optional predicate; + bool no_realize; void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); @@ -52,6 +53,7 @@ class BlockFrameNode : public TIRFrameNode { v->Visit("annotations", &annotations); v->Visit("iter_values", &iter_values); v->Visit("predicate", &predicate); + v->Visit("no_realize", &no_realize); } static constexpr const char* _type_key = "script.builder.tir.BlockFrame"; @@ -66,11 +68,43 @@ class BlockFrame : public TIRFrame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode); }; -BlockFrame Block_(String name); +BlockFrame Block_(String name, bool no_realize = false); + +class BlockInitFrameNode : public TIRFrameNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); } + + static constexpr const char* _type_key = "script.builder.tir.BlockInitFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockInitFrameNode, TIRFrameNode); + + public: + void EnterWithScope() final; + void ExitWithScope() final; +}; + +class BlockInitFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockInitFrame, TIRFrame, BlockInitFrameNode); +}; + +BlockInitFrame Init(); +BlockFrame FindBlockFrame(const String& method); +void Where(PrimExpr predicate); +void Reads(Array buffer_slices); +void Writes(Array buffer_slices); +void BlockAttrs(Map attrs); +tvm::tir::Buffer AllocBuffer(Array shape, DataType dtype = DataType::Float(32), + Optional data = NullOpt, Array strides = {}, + PrimExpr elem_offset = PrimExpr(), String storage_scope = "", + int align = -1, int offset_factor = 0, + String buffer_type_str = "default", Array axis_separators = {}, + Span span = Span()); namespace axis { tvm::tir::IterVar Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); tvm::tir::IterVar Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); +tvm::tir::IterVar Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); +tvm::tir::IterVar Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); Array Remap(String kinds, Array bindings, DataType dtype = DataType::Int(32)); } // namespace axis diff --git a/src/script/builder/tir/for_frame.cc b/src/script/builder/tir/for_frame.cc index 0b6d289e9fc9..9d1e5a63eec1 100644 --- a/src/script/builder/tir/for_frame.cc +++ b/src/script/builder/tir/for_frame.cc @@ -31,21 +31,22 @@ void ForFrameNode::ExitWithScope() { AddToParent(f_make_for_loop(vars, doms, AsStmt(stmts))); } -#define TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Method, Kind) \ - ForFrame Method(PrimExpr start, PrimExpr stop, Map annotations) { \ - using namespace tvm::tir; \ - PrimExpr min = start; \ - PrimExpr extent = arith::Analyzer().Simplify(stop - start); \ - ObjectPtr n = make_object(); \ - int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ - n->vars = {Var("v", DataType::Int(bits))}; \ - n->doms = {Range::FromMinExtent(min, extent)}; \ - n->f_make_for_loop = [annotations](Array vars, Array doms, Stmt body) { \ - ICHECK_EQ(vars.size(), 1); \ - ICHECK_EQ(doms.size(), 1); \ - return For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt, annotations); \ - }; \ - return ForFrame(n); \ +#define TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Method, Kind) \ + ForFrame Method(PrimExpr start, PrimExpr stop, Optional> annotations) { \ + using namespace tvm::tir; \ + PrimExpr min = start; \ + PrimExpr extent = arith::Analyzer().Simplify(stop - start); \ + ObjectPtr n = make_object(); \ + int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ + n->vars = {Var("v", DataType::Int(bits))}; \ + n->doms = {Range::FromMinExtent(min, extent)}; \ + n->f_make_for_loop = [annotations](Array vars, Array doms, Stmt body) { \ + ICHECK_EQ(vars.size(), 1); \ + ICHECK_EQ(doms.size(), 1); \ + return For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt, \ + annotations.value_or(Map())); \ + }; \ + return ForFrame(n); \ } TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Serial, tvm::tir::ForKind::kSerial); @@ -56,7 +57,7 @@ TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Unroll, tvm::tir::ForKind::kUnrolled); #undef TVM_SCRIPT_BUILDER_TIR_FOR_CREATE ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, - Map annotations) { + Optional> annotations) { using namespace tvm::tir; PrimExpr min = start; PrimExpr extent = arith::Analyzer().Simplify(stop - start); @@ -69,7 +70,7 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, ICHECK_EQ(doms.size(), 1); IterVar iter_var(Range(nullptr), NullValue(), IterVarType::kThreadIndex, thread); return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var, - annotations); + annotations.value_or(Map())); }; return ForFrame(n); } diff --git a/src/script/builder/tir/for_frame.h b/src/script/builder/tir/for_frame.h index e4d87cd7572a..7bb8c6ec3c5f 100644 --- a/src/script/builder/tir/for_frame.h +++ b/src/script/builder/tir/for_frame.h @@ -59,12 +59,16 @@ class ForFrame : public TIRFrame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode); }; -ForFrame Serial(PrimExpr start, PrimExpr stop, Map annotations); -ForFrame Parallel(PrimExpr start, PrimExpr stop, Map annotations); -ForFrame Vectorized(PrimExpr start, PrimExpr stop, Map annotations); -ForFrame Unroll(PrimExpr start, PrimExpr stop, Map annotations); +ForFrame Serial(PrimExpr start, PrimExpr stop, + Optional> annotations = NullOpt); +ForFrame Parallel(PrimExpr start, PrimExpr stop, + Optional> annotations = NullOpt); +ForFrame Vectorized(PrimExpr start, PrimExpr stop, + Optional> annotations = NullOpt); +ForFrame Unroll(PrimExpr start, PrimExpr stop, + Optional> annotations = NullOpt); ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, - Map annotations); + Optional> annotations = NullOpt); ForFrame Grid(Array extents); } // namespace tir diff --git a/src/script/builder/tir/prim_func_frame.cc b/src/script/builder/tir/prim_func_frame.cc index ecc6f97d663e..5d1d2ae9defb 100644 --- a/src/script/builder/tir/prim_func_frame.cc +++ b/src/script/builder/tir/prim_func_frame.cc @@ -22,19 +22,37 @@ #include #include "./block_frame.h" +#include "./var.h" namespace tvm { namespace script { namespace builder { namespace tir { +void PrimFuncFrameNode::EnterWithScope() { + TIRFrameNode::EnterWithScope(); + // add implicit root block + root_block_frame->EnterWithScope(); +} + void PrimFuncFrameNode::ExitWithScope() { using namespace tvm::tir; + root_block_frame->ExitWithScope(); TIRFrameNode::ExitWithScope(); Builder builder = Builder::Current(); + if (!(stmts.size() == 1 && stmts[0]->IsInstance())) { + LOG(FATAL) << "ValueError: PrimFuncFrame shoulde have one and only one root block."; + } + BlockRealize root_block_realize = Downcast(stmts[0]); + Block root_block = root_block_realize->block; + // remove redundant implicit root block + if (root_block->alloc_buffers.empty() && root_block->body->IsInstance() && + root_block->annotations.empty() && root_block->reads.empty() && root_block->writes.empty()) { + stmts.Set(0, root_block->body); + } PrimFunc func(/*params=*/args, /*body=*/AsStmt(stmts), - /*ret_type=*/ret_type, + /*ret_type=*/ret_type.value_or(TupleType::Empty()), /*buffer_map=*/buffer_map, /*preflattened_buffer_map=*/preflattened_buffer_map, /*attrs=*/DictAttrs(attrs)); @@ -43,7 +61,7 @@ void PrimFuncFrameNode::ExitWithScope() { builder->result = func; } else if (Optional opt_frame = builder->FindFrame()) { IRModuleFrame frame = opt_frame.value(); - frame->global_vars.push_back(GlobalVar(name)); + frame->global_vars.push_back(GlobalVar(name.value_or(""))); frame->functions.push_back(func); } else { LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc"; @@ -52,26 +70,43 @@ void PrimFuncFrameNode::ExitWithScope() { PrimFuncFrame PrimFunc_() { ObjectPtr n = make_object(); - n->name = ""; + n->name = NullOpt; n->args.clear(); - n->ret_type = TupleType::Empty(); + n->ret_type = NullOpt; n->buffer_map.clear(); n->preflattened_buffer_map.clear(); n->attrs.clear(); + n->root_block_frame = Block_("root"); return PrimFuncFrame(n); } +PrimFuncFrame FindPrimFuncFrame(const String& method) { + Builder builder = Builder::Current(); + if (Optional prim_func_frame = builder->FindFrame()) { + if (Optional block_frame = builder->GetLastFrame()) { + if (prim_func_frame.value()->root_block_frame.get() == block_frame.get()) { + return prim_func_frame.value(); + } + } + } else { + LOG(FATAL) << "ValueError: PrimFunc frame not find. Please ensure '" << method + << "' is called under T.prim_func()"; + } + LOG(FATAL) << "ValueError: '" << method << "' must be called immediately under T.prim_func()"; + throw; +} + tvm::tir::Var Arg(String name, tvm::tir::Var var) { + PrimFuncFrame frame = FindPrimFuncFrame("T.Arg"); Namer::Name(var, name); - PrimFuncFrame frame = Builder::Current()->FindFrame().value(); frame->args.push_back(var); return var; } tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer) { using namespace tvm::tir; + PrimFuncFrame frame = FindPrimFuncFrame("T.Arg"); Namer::Name(buffer, name); - PrimFuncFrame frame = Builder::Current()->FindFrame().value(); Var handle(buffer->name + "_handle", DataType::Handle()); frame->args.push_back(handle); frame->buffer_map.Set(handle, buffer); @@ -79,18 +114,27 @@ tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer) { } void FuncName(String name) { - PrimFuncFrame frame = Builder::Current()->FindFrame().value(); + PrimFuncFrame frame = FindPrimFuncFrame("T.func_name"); + if (frame->name.defined()) { + LOG(FATAL) << "Duplicate prim func name, previous one is " << frame->name.value(); + } frame->name = name; } void FuncAttrs(Map attrs) { using namespace tvm::tir; - PrimFuncFrame frame = Builder::Current()->FindFrame().value(); + PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr"); + if (!frame->attrs.empty()) { + LOG(FATAL) << "Duplicate prim func annotations, previous one is " << frame->attrs; + } frame->attrs = attrs; } tvm::Type FuncRet(tvm::Type ret_type) { - PrimFuncFrame frame = Builder::Current()->FindFrame().value(); + PrimFuncFrame frame = FindPrimFuncFrame("T.ret_type"); + if (frame->ret_type.defined()) { + LOG(FATAL) << "Duplicate prim func return type, previous one is " << frame->ret_type.value(); + } frame->ret_type = ret_type; return ret_type; } @@ -101,21 +145,10 @@ tvm::tir::Buffer MatchBuffer(ObjectRef param, Array shape, DataType dt int offset_factor, String buffer_type_str, Array axis_separators, Span span) { using namespace tvm::tir; - Var buffer_data; - if (!data.defined()) { - DataType storage_dtype = dtype; - if (storage_dtype == DataType::Bool()) { - storage_dtype = DataType::Int(8); - } - buffer_data = Var("", PointerType(PrimType(storage_dtype), storage_scope), span); - } else { - buffer_data = data.value(); - } - BufferType buffer_type = (buffer_type_str == "auto_broadcast") ? kAutoBroadcast : kDefault; - Buffer buffer(buffer_data, dtype, shape, strides, elem_offset, "", align, offset_factor, - buffer_type, axis_separators, span); - PrimFuncFrame frame = Builder::Current()->FindFrame().value(); + Buffer buffer = DeclBuffer(shape, dtype, "", data, strides, elem_offset, storage_scope, align, + offset_factor, buffer_type_str, axis_separators, span); if (const auto* var = param.as()) { + PrimFuncFrame frame = FindPrimFuncFrame("T.match_buffer"); Var v = GetRef(var); for (auto const& arg : frame->args) { if (arg.same_as(v)) { @@ -125,9 +158,8 @@ tvm::tir::Buffer MatchBuffer(ObjectRef param, Array shape, DataType dt } LOG(FATAL) << "ValueError: Can not bind non-input param to buffer."; } else if (const auto* buffer_region = param.as()) { - BlockFrame block_frame = Builder::Current()->FindFrame().value(); - block_frame->match_buffers.push_back( - MatchBufferRegion(buffer, GetRef(buffer_region))); + BlockFrame frame = FindBlockFrame("T.match_buffer"); + frame->match_buffers.push_back(MatchBufferRegion(buffer, GetRef(buffer_region))); } else { LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer."; } @@ -139,14 +171,13 @@ void PreflattenedBuffer(tvm::tir::Buffer postflattened_buffer, Array s PrimExpr elem_offset, String storage_scope, int align, int offset_factor, String buffer_type_str, Array axis_separators, Span span) { using namespace tvm::tir; - PrimFuncFrame frame = Builder::Current()->FindFrame().value(); + PrimFuncFrame frame = FindPrimFuncFrame("T.preflattened_buffer"); for (auto const& p : frame->buffer_map) { if (p.second.same_as(postflattened_buffer)) { - Var buffer_data = (data.defined()) ? data.value() : frame->buffer_map.at(p.first)->data; String buffer_name(postflattened_buffer->name + "_preflatten"); - BufferType buffer_type = (buffer_type_str == "auto_broadcast") ? kAutoBroadcast : kDefault; - Buffer buffer(buffer_data, dtype, shape, strides, elem_offset, buffer_name, align, - offset_factor, buffer_type, axis_separators, span); + Buffer buffer = + DeclBuffer(shape, dtype, buffer_name, data, strides, elem_offset, storage_scope, align, + offset_factor, buffer_type_str, axis_separators, span); Namer::Name(buffer, buffer_name); frame->preflattened_buffer_map.Set(p.first, buffer); return; diff --git a/src/script/builder/tir/prim_func_frame.h b/src/script/builder/tir/prim_func_frame.h index 519603343259..a696b4ef84d1 100644 --- a/src/script/builder/tir/prim_func_frame.h +++ b/src/script/builder/tir/prim_func_frame.h @@ -20,6 +20,7 @@ #define TVM_SCRIPT_BUILDER_TIR_PRIM_FUNC_FRAME_H_ #include "./base.h" +#include "./block_frame.h" namespace tvm { namespace script { @@ -28,12 +29,13 @@ namespace tir { class PrimFuncFrameNode : public TIRFrameNode { public: - String name; + Optional name; Array args; - Type ret_type; + Optional ret_type; Map buffer_map; Map preflattened_buffer_map; Map attrs; + BlockFrame root_block_frame{nullptr}; void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); @@ -43,12 +45,14 @@ class PrimFuncFrameNode : public TIRFrameNode { v->Visit("buffer_map", &buffer_map); v->Visit("preflattened_buffer_map", &preflattened_buffer_map); v->Visit("attrs", &attrs); + v->Visit("root_block_frame", &root_block_frame); } static constexpr const char* _type_key = "script.builder.tir.PrimFuncFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode); public: + void EnterWithScope() final; void ExitWithScope() final; }; @@ -58,6 +62,7 @@ class PrimFuncFrame : public TIRFrame { }; PrimFuncFrame PrimFunc_(); +PrimFuncFrame FindPrimFuncFrame(const String& method); tvm::tir::Var Arg(String name, tvm::tir::Var var); tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer); void FuncName(String name); diff --git a/src/script/builder/tir/utils.h b/src/script/builder/tir/utils.h new file mode 100644 index 000000000000..0027ee49d9e8 --- /dev/null +++ b/src/script/builder/tir/utils.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_BUILDER_TIR_UTILS_H_ +#define TVM_SCRIPT_BUILDER_TIR_UTILS_H_ + +#include +#include + +namespace tvm { +namespace script { +namespace builder { +namespace tir { + +tvm::tir::BufferRegion BufferRegionFromLoad(tvm::tir::BufferLoad buffer_load) { + using namespace tvm::tir; + Array ranges; + for (const PrimExpr& index : buffer_load->indices) { + ranges.push_back(Range::FromMinExtent(index, 1)); + } + return BufferRegion(buffer_load->buffer, ranges); +} + +} // namespace tir +} // namespace builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_BUILDER_TIR_UTILS_H_ diff --git a/src/script/builder/tir/var.cc b/src/script/builder/tir/var.cc index e3d9c2367b66..e91bc20d37fe 100644 --- a/src/script/builder/tir/var.cc +++ b/src/script/builder/tir/var.cc @@ -27,6 +27,27 @@ tvm::tir::Buffer Buffer_(Array shape, DataType dtype, String name, Str return tvm::tir::decl_buffer(shape, dtype, name, storage_scope); } +tvm::tir::Buffer DeclBuffer(Array shape, DataType dtype, String buffer_name, + Optional data, Array strides, + PrimExpr elem_offset, String storage_scope, int align, + int offset_factor, String buffer_type_str, + Array axis_separators, Span span) { + using namespace tvm::tir; + Var buffer_data; + if (!data.defined()) { + DataType storage_dtype = dtype; + if (storage_dtype == DataType::Bool()) { + storage_dtype = DataType::Int(8); + } + buffer_data = Var(buffer_name, PointerType(PrimType(storage_dtype), storage_scope), span); + } else { + buffer_data = data.value(); + } + BufferType buffer_type = (buffer_type_str == "auto_broadcast") ? kAutoBroadcast : kDefault; + return Buffer(buffer_data, dtype, shape, strides, elem_offset, buffer_name, align, offset_factor, + buffer_type, axis_separators, span); +} + TVM_STATIC_IR_FUNCTOR(Namer, vtable) .set_dispatch([](const ObjectRef& node, String name) -> void { using namespace tvm::tir; diff --git a/src/script/builder/tir/var.h b/src/script/builder/tir/var.h index 81120cadb892..433018c0037d 100644 --- a/src/script/builder/tir/var.h +++ b/src/script/builder/tir/var.h @@ -31,7 +31,13 @@ tvm::tir::Buffer Buffer_(Array shape, // String name = "buffer", // String storage_scope = ""); -} +tvm::tir::Buffer DeclBuffer(Array shape, DataType dtype, String buffer_name, + Optional data, Array strides, + PrimExpr elem_offset, String storage_scope, int align, + int offset_factor, String buffer_type_str, + Array axis_separators, Span span); + +} // namespace tir } // namespace builder } // namespace script } // namespace tvm diff --git a/tests/python/tvmscript/test_builder_basic.py b/tests/python/tvmscript/test_builder_basic.py index c22265119ed9..035c5034b0ca 100644 --- a/tests/python/tvmscript/test_builder_basic.py +++ b/tests/python/tvmscript/test_builder_basic.py @@ -18,32 +18,164 @@ import tvm from tvm.script.builder import Builder, def_, def_many from tvm.script.builder import tir as T +from tvm.tir import BufferLoad, BufferRegion +from tvm.ir import Range -def test_builder_basic(): +def test_builder_root_block(): + print("test_builder_root_block") + # impilict root block + with Builder() as b0: + with T.prim_func(): + T.func_name("main") + T.func_attr({"key": "value"}) + with T.block(name="block"): + pass + print(b0.get().script()) + with Builder() as b1: + with T.prim_func(): + T.func_name("main") + T.func_attr({"key": "value"}) + A = def_("A", T.alloc_buffer((128,))) + with T.block(name="block"): + pass + print(b1.get().script()) + with Builder() as b2: + with T.prim_func(): + T.func_name("main") + T.func_attr({"key": "value"}) + A = def_("A", T.alloc_buffer((128,))) + with T.block(name="block0"): + pass + with T.block(name="block1"): + pass + print(b2.get().script()) + # expilict root block + with Builder() as b0_r: + with T.prim_func(): + T.func_name("main") + T.func_attr({"key": "value"}) + with T.block(name="root"): + with T.block(name="block"): + pass + print(b0_r.get().script()) + with Builder() as b1_r: + with T.prim_func(): + T.func_name("main") + T.func_attr({"key": "value"}) + with T.block(name="root"): + A = def_("A", T.alloc_buffer((128,))) + with T.block(name="block"): + pass + print(b1_r.get().script()) + with Builder() as b2_r: + with T.prim_func(): + T.func_name("main") + T.func_attr({"key": "value"}) + with T.block(name="root"): + A = def_("A", T.alloc_buffer((128,))) + with T.block(name="block0"): + pass + with T.block(name="block1"): + pass + print(b2_r.get().script()) + + +def test_builder_axis(): + print("test_builder_axis") + with Builder() as b: + with T.prim_func(): + T.func_name("main") + with T.grid(128, 128, 128, 128, 128) as (i, j, k, m, n): + def_many(["i", "j", "k", "m", "n"], [i, j, k, m, n]) + with T.block(name="block"): + vi = def_("vi", T.axis.spatial(128, i)) + vj = def_("vj", T.axis.spatial(128, j)) + vk = def_("vk", T.axis.reduce(128, k)) + vm = def_("vm", T.axis.scan(128, m)) + vn = def_("vn", T.axis.opaque(128, n)) + x, y, z = def_many(["x", "y", "z"], T.axis.remap("SSR", [i, j, k])) + print(b.get().script()) + + +def test_builder_prim_func(): + print("test_builder_prim_func") with Builder() as b: with T.prim_func(): T.func_name("main") T.func_attr({"global_symbol": "main"}) - T.func_ret(tvm.ir.PrimType("int8")) arg_a = T.arg("a", T.handle()) arg_b = T.arg("b", T.handle()) buffer_c = T.Buffer((128,), "float32") buffer_d = T.Buffer((128,), "float32") arg_c = T.arg("c", buffer_c) arg_d = T.arg("d", buffer_d) - A = def_("A", T.match_buffer(arg_a, (128, 128, 128))) - B = def_("B", T.match_buffer(arg_b, (128, 128, 128))) + T.func_ret(tvm.ir.PrimType("int8")) + A = def_("A", T.match_buffer(arg_a, (128, 128, 128), "int32")) + B = def_("B", T.match_buffer(arg_b, (128, 128, 128), "int32")) T.preflattened_buffer(buffer_c, (128,), data=buffer_c.data) T.preflattened_buffer(buffer_d, (128,), data=buffer_d.data) + print(b.get().script()) + + +def test_builder_block(): + print("test_builder_block") + with Builder() as b: + with T.prim_func(): + arg_a = T.arg("a", T.handle()) + arg_b = T.arg("b", T.handle()) + A = def_("A", T.match_buffer(arg_a, (128, 128, 128), "int32")) + B = def_("B", T.match_buffer(arg_b, (128, 128, 128), "int32")) with T.grid(128, 128, 128) as (i, j, k): def_many(["i", "j", "k"], [i, j, k]) with T.block(name="block"): - vi = def_("vi", T.axis.spatial(128, i)) - vj = def_("vj", T.axis.spatial(128, j)) - vk = def_("vk", T.axis.reduce(128, k)) + T.block_attr({"axis": 1}) + T.where(i > 1) + with T.init(): + pass + vi, vj, vk = def_many(["vi", "vj", "vk"], T.axis.remap("SSR", [i, j, k])) + T.reads( + BufferRegion( + A, + [ + Range(vi, vi + 1), + Range.from_min_extent(vj, 2), + Range(vk, vk + BufferLoad(B, [1, 2, BufferLoad(A, [3, 4, 5])])), + ], + ) + ) + T.writes([BufferLoad(A, [100, BufferLoad(A, [50, 51, 52]), 102])]) + E = def_("E", T.alloc_buffer((128, 128))) + F = def_("F", T.alloc_buffer((128, 128))) + print(b.get().script()) + + +def test_builder_for(): + print("test_builder_for") + with Builder() as b: + with T.prim_func(): + with T.grid(128, 128, 128) as (i, j, k): + def_many(["i", "j", "k"], [i, j, k]) + with T.serial(0, 128) as w: + w = def_("w", w) + with T.parallel(0, 128) as x: + x = def_("x", x) + with T.vectorized(0, 128) as y: + y = def_("y", y) + with T.unroll(0, 128) as z: + z = def_("z", z) + with T.thread_binding(0, 32, thread="blockIdx.x") as bx: + bx = def_("bx", bx) + with T.thread_binding(0, 2, thread="vthread.y") as vy: + vy = def_("vy", vy) + with T.thread_binding(0, 8, thread="threadIdx.z") as tz: + tz = def_("tz", tz) print(b.get().script()) if __name__ == "__main__": - test_builder_basic() + test_builder_root_block() + test_builder_axis() + test_builder_prim_func() + test_builder_block() + test_builder_for()