
[Unity] Add support for AXIS_SEPARATOR in AlterOpImpl Pass (apache#15315)

* [Unity] Add support for AXIS_SEPARATOR in AlterOpImpl Pass

Enable support for AXIS_SEPARATOR to handle non-flat buffers.
Modified the pass to handle AXIS_SEPARATOR.

* Fix LINT errors.
abhikran-quic authored and junrushao committed Jul 27, 2023
1 parent 0488cd3 commit 39d8805
Showing 8 changed files with 161 additions and 22 deletions.
9 changes: 9 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
@@ -59,12 +59,21 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
// pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This
// needs to be revisited in case PrimValue is evolved to represent symbolic expression in future.
Optional<PrimValue> pad_value;
/*!
* axis_separators between input axes when generating flattened output axes. For buffers
* representing flat 1-d memory (e.g. any buffer in RAM), this should be an empty array.
* For buffers representing non-flat memory, each entry in axis_separators should be the
* first input axis that is part of a new flattened axis.
*/
Optional<Array<IntImm>> axis_separators;

TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
TVM_ATTR_FIELD(pad_value).describe(
"The specific value to be used to pad if the layout transform would result in implicit "
"padding. If not specified, the compiler is free to choose any value.");
TVM_ATTR_FIELD(axis_separators)
.describe("The separators between input axes when generating flat output axes");
}
}; // struct LayoutTransformAttrs

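For context, each separator value marks the first output axis of a new flattened axis, mirroring the TIR Buffer convention quoted in the comment above. A minimal sketch (not part of this commit; the layout and tiling factor are arbitrary illustration values) of deriving an index map together with its separators via the existing tvm.tir.IndexMap.from_func_with_separators helper:

```python
# Sketch only: tile the channel axis by 4 and request a 2-d physical buffer.
from tvm import tir

index_map, axis_separators = tir.IndexMap.from_func_with_separators(
    lambda n, c, h, w: [
        n,
        c // 4,
        tir.IndexMap.AXIS_SEPARATOR,  # start a new flattened physical axis here
        h,
        w,
        c % 4,
    ]
)
# axis_separators is [2] here: output axes (n, c//4) flatten into one
# physical axis, (h, w, c%4) into a second, i.e. a non-flat 2-d buffer.
```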
4 changes: 3 additions & 1 deletion include/tvm/relax/transform.h
@@ -464,10 +464,12 @@ TVM_DLL Pass DecomposeOpsForTraining(Optional<String> func_name);
* \param op_impl_map Map from kOperatorName attr (e.g., relax.conv2d) to replacement PrimFunc
* \param op_buffer_transforms Map from kOperatorName attr to layout transformations on each of the
* PrimFunc i/o buffers.
* \param axis_separators Map from kOperatorName attr to axis_separators for each of the buffer transforms
* \return The Pass.
*/
TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<tir::IndexMap>>& op_buffer_transforms);
const Map<String, Array<tir::IndexMap>>& op_buffer_transforms,
const Map<String, Array<Array<IntImm>>>& axis_separators);

/*!
* \brief Layout conversion pass.
10 changes: 9 additions & 1 deletion python/tvm/relax/op/manipulate.py
@@ -114,6 +114,7 @@ def layout_transform(
x: Expr,
index_map: Union[Callable, IndexMap],
pad_value: Optional[Union[int, float, PrimValue]] = None,
axis_separators: Optional[Union[int, IndexMap.AXIS_SEPARATOR]] = None,
):
"""Modifies the layout of a tensor.
@@ -129,6 +130,9 @@ def layout_transform(
The value used for padding if the transformation results in implicit padding.
If not specified, any value can be used.
axis_separators : Optional[Union[int, IndexMap.AXIS_SEPARATOR]]
The axis_separators for index_map to create non-flat buffers.
Returns
-------
result : relax.Expr
@@ -150,7 +154,11 @@
elif "float" in x_dtype and (isinstance(pad_value, (int, float))):
pad_value = FloatImm(x_dtype, float(pad_value))
pad_value = PrimValue(pad_value)
return _ffi_api.layout_transform(x, index_map, pad_value) # type: ignore

if axis_separators is None:
axis_separators = []

return _ffi_api.layout_transform(x, index_map, pad_value, axis_separators) # type: ignore


def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr:
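A hypothetical usage sketch of the extended layout_transform Python API above (the tensor shape, tile factor, and variable names are illustrative, not from this commit):

```python
# Sketch only: tile a 2-d tensor 4x4 and keep the result non-flat by
# splitting the transformed axes into two physical axes at position 2.
from tvm import relax
from tvm.script import relax as R

x = relax.Var("x", R.Tensor((16, 16), "float32"))
y = relax.op.layout_transform(
    x,
    index_map=lambda i, j: [i // 4, j // 4, i % 4, j % 4],
    axis_separators=[2],  # (i//4, j//4) and (i%4, j%4) flatten separately
)
```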
8 changes: 7 additions & 1 deletion python/tvm/relax/transform/transform.py
@@ -903,6 +903,7 @@ def DecomposeOpsForTraining(func_name: Optional[str] = None) -> tvm.ir.transform
def AlterOpImpl(
op_impl_map: Dict[str, PrimFunc],
op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]],
op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]],
):
"""Replace all PrimFunc's which have matching 'operator_name' attribute, with replacement
PrimFunc that could possibly have different layouts on i/o buffers. The layout
@@ -916,6 +917,9 @@ def AlterOpImpl(
op_kind to PrimFunc map
op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]]
op_kind to layout transformation map for each of the buffers
op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]]
op_kind to axis_separators for each index_map
Returns
-------
ret: tvm.ir.transform.Pass
@@ -928,7 +932,9 @@
l.append(transform)
op_buffer_transforms[operator_name] = l

return _ffi_api.AlterOpImpl(op_impl_map, op_buffer_transforms) # type: ignore
return _ffi_api.AlterOpImpl(
op_impl_map, op_buffer_transforms, op_buffer_axis_separators
) # type: ignore


def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pass:
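A hypothetical end-to-end sketch of the updated AlterOpImpl invocation; `tiled_add` (a replacement PrimFunc whose operator_name attr is "relax.add") and `mod` are placeholders, not part of this commit:

```python
# Sketch only: one IndexMap and one axis_separators list per i/o buffer of
# the replacement PrimFunc (two inputs and one output for an add).
from tvm import relax, tir

transform = tir.IndexMap.from_func(lambda i: [i // 4, i % 4])

alter = relax.transform.AlterOpImpl(
    op_impl_map={"relax.add": tiled_add},
    op_buffer_transforms={"relax.add": [transform, transform, transform]},
    op_buffer_axis_separators={"relax.add": [[1], [1], [1]]},
)
mod = alter(mod)  # outputs are transformed back to the original layout
```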
4 changes: 3 additions & 1 deletion src/relax/op/tensor/manipulate.cc
@@ -425,10 +425,12 @@ TVM_REGISTER_OP("relax.flatten")
/* relax.layout_transform */
TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);

Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value) {
Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value,
Optional<Array<IntImm>> axis_separators) {
ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
attrs->index_map = std::move(index_map);
attrs->pad_value = std::move(pad_value);
attrs->axis_separators = std::move(axis_separators);

static const Op& op = Op::Get("relax.layout_transform");
return Call(op, {std::move(x)}, Attrs{attrs}, {});
5 changes: 4 additions & 1 deletion src/relax/op/tensor/manipulate.h
@@ -65,9 +65,12 @@ Expr flatten(Expr x);
* \param index_map The transformation to apply.
* \param pad_value The value used for padding if the transformation results in implicit padding. If
* not specified, any value can be used.
* \param axis_separators Array of values to differentiate between input axes
* when generating flattened output axes.
* \return The transformed result.
*/
Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value);
Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value,
Optional<Array<IntImm>> axis_separators);

/*!
* \brief Permutes the dimensions of an array.
55 changes: 41 additions & 14 deletions src/relax/transform/alter_op_impl.cc
@@ -76,11 +76,13 @@ bool IsTransformBijective(const Expr& expr, const IndexMap& transform) {
class AlterOpImplMutator : public ExprMutator {
public:
AlterOpImplMutator(const IRModule& mod, const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<IndexMap>>& op_buffer_transforms_)
const Map<String, Array<IndexMap>>& op_buffer_transforms_,
const Map<String, Array<Array<IntImm>>>& axis_separators_)
: ExprMutator(mod),
mod_(mod),
op_impl_map_(op_impl_map),
op_buffer_transforms__(op_buffer_transforms_) {}
op_buffer_transforms__(op_buffer_transforms_),
op_buffer_axis_separators__(axis_separators_) {}

IRModule Run() {
for (const auto& [gv, func] : mod_->functions) {
@@ -119,7 +121,10 @@ class AlterOpImplMutator : public ExprMutator {
const auto& replacement_func = op_impl_map_[op_kind];

Array<IndexMap> buffer_transforms;
Optional<Array<Array<IntImm>>> axis_separators;
if (op_buffer_transforms__.count(op_kind)) buffer_transforms = op_buffer_transforms__[op_kind];
if (op_buffer_axis_separators__.count(op_kind))
axis_separators = op_buffer_axis_separators__[op_kind];

ICHECK(buffer_transforms.empty() || buffer_transforms.size() == replacement_func->params.size())
<< "Either the i/o buffers do not require any transformations or transformations for each "
@@ -130,15 +135,15 @@ class AlterOpImplMutator : public ExprMutator {
GlobalVar replacement_gv = GetOrCreateGlobalVarForFunc(replacement_func, op_kind);

auto call_tir_inputs_tuple = GetRef<Tuple>(call->args[1].as<TupleNode>());
Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms);
Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators);

ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir sinfo_args.size() is expected to be 1";
StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0], buffer_transforms);
auto updated_call = builder_->Normalize(
Call(call_tir_op_, {replacement_gv, updated_inputs}, call->attrs, {updated_ret_sinfo}));

// Now transform each of the outputs to previous layout.
return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0]);
return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0], axis_separators);
}

Array<TensorStructInfo> GetTensorStructInfoPerOutput(const StructInfo& output_sinfo) {
@@ -157,17 +162,20 @@ class AlterOpImplMutator : public ExprMutator {
return arr_tensor_sinfo;
}

Expr TransformLayout(const Expr& expr, const IndexMap& index_map) {
Expr TransformLayout(const Expr& expr, const IndexMap& index_map,
const Array<IntImm> axis_separators) {
ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
// We want to avoid two layout_transform ops to share the same index map even if they are
// identical. The scope of vars used in index map initial indices is local to the op. Not doing
// so would confuse the structural equality check.
attrs->index_map = std::move(DeepCopyIndexMap(index_map));
attrs->axis_separators = std::move(axis_separators);
return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {});
}

Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map,
const TensorStructInfo& old_tensor_sinfo) {
const TensorStructInfo& old_tensor_sinfo,
const Array<IntImm>& axis_separator) {
Array<PrimExpr> old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo);
Array<Range> initial_ranges = ConstructRangeFromShape(old_shape);
arith::Analyzer analyzer;
@@ -177,7 +185,7 @@ class AlterOpImplMutator : public ExprMutator {
<< "Only bijective transformations on input/output buffers are supported, but found "
"padding predicate "
<< padding_predicate << " on initial range " << initial_ranges;
return TransformLayout(expr, inverse_index_map);
return TransformLayout(expr, inverse_index_map, axis_separator);
}

/*!
@@ -202,16 +210,22 @@ class AlterOpImplMutator : public ExprMutator {
/*!
* \brief Updates call inputs with layout transformed inputs
*/
Tuple UpdateInputs(const Tuple& inputs, const Array<IndexMap>& transforms) {
Tuple UpdateInputs(const Tuple& inputs, const Array<IndexMap>& transforms,
const Optional<Array<Array<IntImm>>>& axis_separators) {
if (transforms.empty()) return inputs;

Array<Expr> updated_inputs;
int index = 0;
for (const auto& input : inputs->fields) {
Array<IntImm> axis_separator;
if (axis_separators.defined()) {
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
axis_separator = axis_separators_value[index];
}
auto transform = transforms[index++];
ICHECK(IsTransformBijective(input, transform))
<< "Non bijective transforms on input and output buffers are not supported.";
updated_inputs.push_back(TransformLayout(input, transform));
updated_inputs.push_back(TransformLayout(input, transform, axis_separator));
}
return Tuple(updated_inputs);
}
@@ -254,29 +268,39 @@ class AlterOpImplMutator : public ExprMutator {
}

Expr TransformOutputs(const Expr& expr, const Array<IndexMap>& buffer_transforms,
const StructInfo& old_struct_info) {
const StructInfo& old_struct_info,
const Optional<Array<Array<IntImm>>>& axis_separators) {
if (buffer_transforms.empty()) return expr;

Array<TensorStructInfo> old_output_sinfo = GetTensorStructInfoPerOutput(old_struct_info);

Array<IntImm> axis_sep;
size_t num_outputs = old_output_sinfo.size();
if (num_outputs == 0) return expr;

size_t first_output_index = buffer_transforms.size() - num_outputs;
// If there is a single output, return the transformed output.
if (num_outputs == 1) {
IndexMap output_map = buffer_transforms[first_output_index];
return TransformLayoutInverse(expr, output_map, old_output_sinfo[0]);
if (axis_separators.defined()) {
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
axis_sep = axis_separators_value[first_output_index];
}
return TransformLayoutInverse(expr, output_map, old_output_sinfo[0], axis_sep);
}

// In case of more than one output, we would have to get each item of the output tuple,
// transform it and return a tuple of all transformed outputs.
Array<Expr> transformed_outputs;
for (size_t i = 0; i + first_output_index < buffer_transforms.size(); ++i) {
const auto& output_map = buffer_transforms[i + first_output_index];
if (axis_separators.defined()) {
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
axis_sep = axis_separators_value[i + first_output_index];
}
auto output = builder_->Normalize(TupleGetItem(expr, static_cast<int>(i)));
transformed_outputs.push_back(
TransformLayoutInverse(output, output_map, old_output_sinfo[i]));
TransformLayoutInverse(output, output_map, old_output_sinfo[i], axis_sep));
}
return Tuple(transformed_outputs);
}
@@ -290,6 +314,8 @@ class AlterOpImplMutator : public ExprMutator {
const Map<String, PrimFunc>& op_impl_map_;
/*! \brief Map from kOperatorName attribute to the layout transforms on i/o buffers */
const Map<String, Array<IndexMap>>& op_buffer_transforms__;
/*! \brief Map from kOperatorName attribute to the axis separators on i/o buffers */
const Map<String, Array<Array<IntImm>>>& op_buffer_axis_separators__;

const Op& call_tir_op_ = Op::Get("relax.call_tir");
const Op& layout_transform_op_ = Op::Get("relax.layout_transform");
@@ -298,10 +324,11 @@
namespace transform {

Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<IndexMap>>& op_buffer_transforms_) {
const Map<String, Array<IndexMap>>& op_buffer_transforms_,
const Map<String, Array<Array<IntImm>>>& axis_separators_) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
PassContext pc) {
return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_).Run();
return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_, axis_separators_).Run();
};
return CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
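One detail worth noting from TransformOutputs above: each per-op list is ordered over the replacement PrimFunc's parameters, inputs first and then outputs, so the output entries sit at the tail of the list. A toy illustration of that indexing (plain Python, not a TVM API):

```python
# Toy model of first_output_index = buffer_transforms.size() - num_outputs.
transforms = ["in0_map", "in1_map", "out0_map"]  # inputs first, outputs last
num_outputs = 1
first_output_index = len(transforms) - num_outputs
assert transforms[first_output_index] == "out0_map"
```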
