Skip to content

Commit

Permalink
block methods (apache#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
cyx-6 authored and junrushao committed Jul 4, 2022
1 parent ec4ad60 commit b627697
Show file tree
Hide file tree
Showing 15 changed files with 566 additions and 106 deletions.
2 changes: 1 addition & 1 deletion python/tvm/script/builder/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from . import axis
from .base import TIRFrame
from .block_frame import block
from .block_frame import block, where, reads, writes, alloc_buffer, block_attr, init
from .for_frame import (
ForFrame,
grid,
Expand Down
16 changes: 15 additions & 1 deletion python/tvm/script/builder/tir/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

from . import _ffi_api

from typing import List


def spatial(dom, binding, dtype="int32") -> IterVar:
if not isinstance(dom, Range):
Expand All @@ -34,7 +36,19 @@ def reduce(dom, binding, dtype="int32") -> IterVar:
return _ffi_api.AxisReduce(dom, binding, dtype) # pylint: disable=no-member # type: ignore


def remap(kinds, bindings, dtype="int32") -> IterVar:
def scan(dom, binding, dtype="int32") -> IterVar:
if not isinstance(dom, Range):
dom = Range(0, dom)
return _ffi_api.AxisScan(dom, binding, dtype) # pylint: disable=no-member # type: ignore


def opaque(dom, binding, dtype="int32") -> IterVar:
if not isinstance(dom, Range):
dom = Range(0, dom)
return _ffi_api.AxisOpaque(dom, binding, dtype) # pylint: disable=no-member # type: ignore


def remap(kinds, bindings, dtype="int32") -> List[IterVar]:
return _ffi_api.AxisRemap(kinds, bindings, dtype) # pylint: disable=no-member # type: ignore


Expand Down
64 changes: 62 additions & 2 deletions python/tvm/script/builder/tir/block_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,71 @@
from . import _ffi_api
from .base import TIRFrame

from typing import List, Dict, Any, Union
from tvm.tir import Buffer, BufferLoad, BufferRegion


@_register_object("script.builder.tir.BlockFrame")
class BlockFrame(TIRFrame):
...


def block(name) -> BlockFrame:
return _ffi_api.BlockFrame(name) # pylint: disable=no-member # type: ignore
@_register_object("script.builder.tir.BlockInitFrame")
class BlockInitFrame(TIRFrame):
...


def block(name: str, no_realize: bool = False) -> BlockFrame:
return _ffi_api.BlockFrame(name, no_realize) # pylint: disable=no-member # type: ignore


def init() -> BlockInitFrame:
return _ffi_api.BlockInitFrame() # pylint: disable=no-member # type: ignore


def where(predicate) -> None:
_ffi_api.Where(predicate) # pylint: disable=no-member # type: ignore


def reads(buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None:
if not isinstance(buffer_slices, List):
buffer_slices = [buffer_slices]
_ffi_api.Reads(buffer_slices)


def writes(buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None:
if not isinstance(buffer_slices, List):
buffer_slices = [buffer_slices]
_ffi_api.Writes(buffer_slices)


def block_attr(attrs: Dict[str, Any]) -> None:
return _ffi_api.BlockAttrs(attrs) # pylint: disable=no-member # type: ignore


def alloc_buffer(
shape,
dtype="float32",
data=None,
strides=[],
elem_offset=None,
storage_scope="",
align=-1,
offset_factor=0,
buffer_type="default",
axis_separators=None,
span=None,
) -> Buffer:
return _ffi_api.AllocBuffer(
shape,
dtype,
data,
strides,
elem_offset,
storage_scope,
align,
offset_factor,
buffer_type,
axis_separators,
span,
)
10 changes: 5 additions & 5 deletions python/tvm/script/builder/tir/for_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,23 @@ def __enter__(self) -> List[Var]:
return self.vars


def serial(start, stop, annotations) -> ForFrame:
def serial(start, stop, annotations=None) -> ForFrame:
return _ffi_api.Serial(start, stop, annotations) # pylint: disable=no-member # type: ignore


def parallel(start, stop, annotations) -> ForFrame:
def parallel(start, stop, annotations=None) -> ForFrame:
return _ffi_api.Parallel(start, stop, annotations) # pylint: disable=no-member # type: ignore


def vectorized(start, stop, annotations) -> ForFrame:
def vectorized(start, stop, annotations=None) -> ForFrame:
return _ffi_api.Vectorized(start, stop, annotations) # pylint: disable=no-member # type: ignore


def unroll(start, stop, annotations) -> ForFrame:
def unroll(start, stop, annotations=None) -> ForFrame:
return _ffi_api.Unroll(start, stop, annotations) # pylint: disable=no-member # type: ignore


def thread_binding(start, stop, thread, annotations) -> ForFrame:
def thread_binding(start, stop, thread, annotations=None) -> ForFrame:
return _ffi_api.ThreadBinding( # pylint: disable=no-member # type: ignore
start, stop, thread, annotations
)
Expand Down
11 changes: 11 additions & 0 deletions src/script/builder/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class BuilderNode : public runtime::Object {
public:
template <typename TFrame>
inline Optional<TFrame> FindFrame() const;
template <typename TFrame>
inline Optional<TFrame> GetLastFrame() const;

template <typename TObjectRef>
inline TObjectRef Get() const;
Expand Down Expand Up @@ -90,6 +92,15 @@ inline Optional<TFrame> BuilderNode::FindFrame() const {
return NullOpt;
}

template <typename TFrame>
inline Optional<TFrame> BuilderNode::GetLastFrame() const {
using TFrameNode = typename TFrame::ContainerType;
if (!frames.empty() && frames.back()->IsInstance<TFrameNode>()) {
return Downcast<TFrame>(frames.back());
}
return NullOpt;
}

template <typename TObjectRef>
inline TObjectRef BuilderNode::Get() const {
using TObject = typename TObjectRef::ContainerType;
Expand Down
161 changes: 129 additions & 32 deletions src/script/builder/tir/block_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@
* specific language governing permissions and limitations
* under the License.
*/

#include "./block_frame.h"

#include <tvm/runtime/registry.h>

#include "./for_frame.h"
#include "./utils.h"
#include "./var.h"

namespace tvm {
namespace script {
namespace builder {
namespace tir {

BlockFrame Block_(String name) {
BlockFrame Block_(String name, bool no_realize) {
ObjectPtr<BlockFrameNode> n = make_object<BlockFrameNode>();
n->name = name;
n->iter_vars.clear();
Expand All @@ -39,24 +42,116 @@ BlockFrame Block_(String name) {
n->annotations.clear();
n->iter_values.clear();
n->predicate = NullOpt;
n->no_realize = no_realize;
return BlockFrame(n);
}

void BlockFrameNode::ExitWithScope() {
using namespace tvm::tir;
TIRFrameNode::ExitWithScope();
AddToParent(BlockRealize(iter_values, //
predicate.value_or(Bool(true)),
Block(iter_vars, //
reads, writes, //
name, //
AsStmt(stmts), //
init, //
alloc_buffers, //
match_buffers, //
annotations)));
Block block = Block(iter_vars, reads, writes, name, AsStmt(stmts), init, alloc_buffers,
match_buffers, annotations);
if (no_realize) {
CHECK(iter_values.empty()) << "ValueError: Block bindings are not allowed when `no_realize=True`";
CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`";
AddToParent(block);
} else {
AddToParent(BlockRealize(iter_values, predicate.value_or(Bool(true)), block));
}
}

BlockInitFrame Init() {
ObjectPtr<BlockInitFrameNode> n = make_object<BlockInitFrameNode>();
return BlockInitFrame(n);
}

void BlockInitFrameNode::EnterWithScope() {
BlockFrame frame = FindBlockFrame("T.init");
if (frame->init.defined()) {
LOG(FATAL) << "Duplicate block init declaration";
}
TIRFrameNode::EnterWithScope();
}

void BlockInitFrameNode::ExitWithScope() {
TIRFrameNode::ExitWithScope();
BlockFrame frame = FindBlockFrame("T.init");
frame->init = AsStmt(stmts);
}

BlockFrame FindBlockFrame(const String& method) {
if (Optional<BlockFrame> block_frame = Builder::Current()->GetLastFrame<BlockFrame>()) {
return block_frame.value();
} else {
LOG(FATAL) << "ValueError: Block frame not find. Please ensure '" << method
<< "' is called under T.block()";
}
throw;
}

void Where(PrimExpr predicate) {
BlockFrame frame = FindBlockFrame("T.where");
if (frame->predicate.defined()) {
LOG(FATAL) << "Duplicate block predicate declaration, previous one is "
<< frame->predicate.value();
}
frame->predicate = predicate;
}

void Reads(Array<ObjectRef> buffer_slices) {
using namespace tvm::tir;
BlockFrame frame = FindBlockFrame("T.reads");
if (!frame->reads.empty()) {
LOG(FATAL) << "Duplicate read region declaration, previous one is " << frame->reads;
}
for (const ObjectRef& obj : buffer_slices) {
if (const auto* buffer_region = obj.as<BufferRegionNode>()) {
frame->reads.push_back(GetRef<BufferRegion>(buffer_region));
} else if (const auto* buffer_load = obj.as<BufferLoadNode>()) {
frame->reads.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(buffer_load)));
} else {
LOG(FATAL) << "Invalid type for buffer reads.";
}
}
}

void Writes(Array<ObjectRef> buffer_slices) {
using namespace tvm::tir;
BlockFrame frame = FindBlockFrame("T.writes");
if (!frame->writes.empty()) {
LOG(FATAL) << "Duplicate write region declaration, previous one is " << frame->writes;
}
for (const ObjectRef& obj : buffer_slices) {
if (const auto* buffer_region = obj.as<BufferRegionNode>()) {
frame->writes.push_back(GetRef<BufferRegion>(buffer_region));
} else if (const auto* buffer_load = obj.as<BufferLoadNode>()) {
frame->writes.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(buffer_load)));
} else {
LOG(FATAL) << "Invalid type for buffer writes.";
}
}
}

void BlockAttrs(Map<String, ObjectRef> attrs) {
BlockFrame frame = FindBlockFrame("T.block_attr");
if (!frame->annotations.empty()) {
LOG(FATAL) << "Duplicate block annotations, previous one is " << frame->annotations;
}
frame->annotations = attrs;
}

tvm::tir::Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype, Optional<tvm::tir::Var> data,
Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope,
int align, int offset_factor, String buffer_type_str,
Array<IntImm> axis_separators, Span span) {
using namespace tvm::tir;
Buffer buffer = DeclBuffer(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
offset_factor, buffer_type_str, axis_separators, span);
BlockFrame frame = FindBlockFrame("T.alloc_buffer");
frame->alloc_buffers.push_back(buffer);
return buffer;
};

namespace axis {

// TODO(@junrushao1994): figure out the Block syntax without BlockRealize
Expand All @@ -72,27 +167,20 @@ tvm::tir::IterVar PushBlockVar(tvm::tir::IterVar iter_var, PrimExpr binding) {
return iter_var;
}

tvm::tir::IterVar Spatial(Range dom, PrimExpr binding, DataType dtype) {
using namespace tvm::tir;
ICHECK(dom.defined()) << "Spatial axis must have a domain";
int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()});
return PushBlockVar(IterVar(/*dom=*/dom, //
/*var=*/Var("_", dtype.with_bits(bits)), //
/*iter_type=*/IterVarType::kDataPar, //
/*thread_tag=*/""),
binding);
}

tvm::tir::IterVar Reduce(Range dom, PrimExpr binding, DataType dtype) {
using namespace tvm::tir;
ICHECK(dom.defined()) << "Reduction axis must have a domain";
int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()});
return PushBlockVar(IterVar(/*dom=*/dom, //
/*var=*/Var("_", dtype.with_bits(bits)), //
/*iter_type=*/IterVarType::kCommReduce, //
/*thread_tag=*/""),
binding);
}
#define TVM_SCRIPT_BUILDER_TIR_AXIS_CREATE(Method, Kind, Name) \
tvm::tir::IterVar Method(Range dom, PrimExpr binding, DataType dtype) { \
using namespace tvm::tir; \
ICHECK(dom.defined()) << Name << " axis must have a domain"; \
int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); \
return PushBlockVar(IterVar(/*dom=*/dom, /*var=*/Var("_", dtype.with_bits(bits)), \
/*iter_type=*/Kind, /*thread_tag=*/""), \
binding); \
}
TVM_SCRIPT_BUILDER_TIR_AXIS_CREATE(Spatial, IterVarType::kDataPar, "Spatial");
TVM_SCRIPT_BUILDER_TIR_AXIS_CREATE(Reduce, IterVarType::kCommReduce, "Reduction");
TVM_SCRIPT_BUILDER_TIR_AXIS_CREATE(Scan, IterVarType::kOrdered, "Scan");
TVM_SCRIPT_BUILDER_TIR_AXIS_CREATE(Opaque, IterVarType::kOpaque, "Opaque");
#undef TVM_SCRIPT_BUILDER_TIR_AXIS_CREATE

Array<tvm::tir::IterVar> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype) {
using namespace tvm::tir;
Expand Down Expand Up @@ -145,9 +233,18 @@ Array<tvm::tir::IterVar> Remap(String kinds, Array<PrimExpr> bindings, DataType
} // namespace axis

TVM_REGISTER_NODE_TYPE(BlockFrameNode);
TVM_REGISTER_NODE_TYPE(BlockInitFrameNode);
TVM_REGISTER_GLOBAL("script.builder.tir.BlockFrame").set_body_typed(Block_);
TVM_REGISTER_GLOBAL("script.builder.tir.BlockInitFrame").set_body_typed(Init);
TVM_REGISTER_GLOBAL("script.builder.tir.Where").set_body_typed(Where);
TVM_REGISTER_GLOBAL("script.builder.tir.Reads").set_body_typed(Reads);
TVM_REGISTER_GLOBAL("script.builder.tir.Writes").set_body_typed(Writes);
TVM_REGISTER_GLOBAL("script.builder.tir.BlockAttrs").set_body_typed(BlockAttrs);
TVM_REGISTER_GLOBAL("script.builder.tir.AllocBuffer").set_body_typed(AllocBuffer);
TVM_REGISTER_GLOBAL("script.builder.tir.AxisSpatial").set_body_typed(axis::Spatial);
TVM_REGISTER_GLOBAL("script.builder.tir.AxisReduce").set_body_typed(axis::Reduce);
TVM_REGISTER_GLOBAL("script.builder.tir.AxisScan").set_body_typed(axis::Scan);
TVM_REGISTER_GLOBAL("script.builder.tir.AxisOpaque").set_body_typed(axis::Opaque);
TVM_REGISTER_GLOBAL("script.builder.tir.AxisRemap").set_body_typed(axis::Remap);

} // namespace tir
Expand Down
Loading

0 comments on commit b627697

Please sign in to comment.