Skip to content

Commit

Permalink
[mlir][Python][linalg] Support OpDSL extensions in C++.
Browse files Browse the repository at this point in the history
The patch extends the yaml code generation to support the following new OpDSL constructs:
- captures
- constants
- iteration index accesses
- predefined types
These changes have been introduced by revision
https://reviews.llvm.org/D101364.

Differential Revision: https://reviews.llvm.org/D102075
  • Loading branch information
Tobias Gysi committed May 19, 2021
1 parent 0bab7b2 commit 9a2769d
Show file tree
Hide file tree
Showing 16 changed files with 881 additions and 246 deletions.
166 changes: 161 additions & 5 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1,7 +1,7 @@
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul
cpp_op_name: MatmulOp
cpp_class_name: MatmulOp
doc: |-
Performs a matrix multiplication of two 2D inputs.
Expand Down Expand Up @@ -63,7 +63,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: batch_matmul
cpp_op_name: BatchMatmulOp
cpp_class_name: BatchMatmulOp
doc: |-
Performs a batched matrix multiplication of two 3D inputs.
Expand Down Expand Up @@ -126,7 +126,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matvec
cpp_op_name: MatvecOp
cpp_class_name: MatvecOp
doc: |-
Performs a matrix-vector multiplication.
Expand Down Expand Up @@ -187,7 +187,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: vecmat
cpp_op_name: VecmatOp
cpp_class_name: VecmatOp
doc: |-
Performs a vector-matrix multiplication.
Expand Down Expand Up @@ -248,7 +248,7 @@ structured_op: !LinalgStructuredOpConfig
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: dot
cpp_op_name: DotOp
cpp_class_name: DotOp
doc: |-
Performs a dot product of two vectors to a scalar result.
Expand Down Expand Up @@ -305,4 +305,160 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: B
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: fill_rng_2d
cpp_class_name: FillRng2DOp
doc: |-
Fills the output tensor with pseudo random numbers.
The operation generates pseudo random numbers using a linear congruential
generator. It provides no guarantees regarding the distribution of the
generated random numbers. Instead of generating the random numbers
sequentially, it instantiates one random number generator per data element
and runs them in parallel. The seed operand and the indices of the data
element seed the random number generation. The min and max operands limit
the range of the generated random numbers.
Note: The captures are hard-coded until there is capture support on the C++
side.
structured_op: !LinalgStructuredOpConfig
args:
- !<LinalgTensorDef>
name: O
usage: output
shape: affine_map<()[s0, s1] -> (s0, s1)>
element_type_var: T
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
iterator_types:
- parallel
- parallel
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
symbolic_cast:
type_var: T
operands:
- !ScalarExpression
scalar_apply:
fn_name: add
operands:
- !ScalarExpression
scalar_apply:
fn_name: mul
operands:
- !ScalarExpression
scalar_apply:
fn_name: add
operands:
- !ScalarExpression
symbolic_cast:
type_var: F64
operands:
- !ScalarExpression
scalar_const: '2147483647 : i64'
- !ScalarExpression
symbolic_cast:
type_var: F64
operands:
- !ScalarExpression
scalar_apply:
fn_name: add
operands:
- !ScalarExpression
scalar_apply:
fn_name: mul
operands:
- !ScalarExpression
scalar_apply:
fn_name: add
operands:
- !ScalarExpression
symbolic_cast:
type_var: I32
operands:
- !ScalarExpression
scalar_index: 1
- !ScalarExpression
scalar_apply:
fn_name: add
operands:
- !ScalarExpression
scalar_apply:
fn_name: mul
operands:
- !ScalarExpression
scalar_apply:
fn_name: add
operands:
- !ScalarExpression
symbolic_cast:
type_var: I32
operands:
- !ScalarExpression
scalar_index: 0
- !ScalarExpression
symbolic_cast:
type_var: I32
operands:
- !ScalarExpression
scalar_const: '42 : i64'
- !ScalarExpression
symbolic_cast:
type_var: I32
operands:
- !ScalarExpression
scalar_const: '1103515245 : i64'
- !ScalarExpression
symbolic_cast:
type_var: I32
operands:
- !ScalarExpression
scalar_const: '12345 : i64'
- !ScalarExpression
symbolic_cast:
type_var: I32
operands:
- !ScalarExpression
scalar_const: '1103515245 : i64'
- !ScalarExpression
symbolic_cast:
type_var: I32
operands:
- !ScalarExpression
scalar_const: '12345 : i64'
- !ScalarExpression
scalar_apply:
fn_name: mul
operands:
- !ScalarExpression
scalar_apply:
fn_name: sub
operands:
- !ScalarExpression
symbolic_cast:
type_var: F64
operands:
- !ScalarExpression
scalar_const: '1000 : i64'
- !ScalarExpression
symbolic_cast:
type_var: F64
operands:
- !ScalarExpression
scalar_const: '-1000 : i64'
- !ScalarExpression
symbolic_cast:
type_var: F64
operands:
- !ScalarExpression
scalar_const: '2.3283063999999999E-10 : f64'
- !ScalarExpression
symbolic_cast:
type_var: F64
operands:
- !ScalarExpression
scalar_const: '-1000 : i64'
57 changes: 45 additions & 12 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Expand Up @@ -220,14 +220,15 @@ namespace {

class RegionBuilderHelper {
public:
RegionBuilderHelper(Block &block) : block(block) {}
RegionBuilderHelper(MLIRContext *context, Block &block)
: context(context), block(block) {}

// Generates operations to cast the given operand to a specified type.
// If the cast cannot be performed, a warning will be issued and the
// operand returned as-is (which will presumably yield a verification
// issue downstream).
Value cast(Type toType, Value operand) {
OpBuilder builder = getBuilder(operand);
OpBuilder builder = getBuilder();
auto loc = operand.getLoc();

if (operand.getType() == toType)
Expand All @@ -236,11 +237,14 @@ class RegionBuilderHelper {
// If operand is floating point, cast directly to the int type.
if (operand.getType().isa<FloatType>())
return builder.create<FPToSIOp>(loc, toType, operand);
// Cast index operands directly to the int type.
if (operand.getType().isIndex())
return builder.create<IndexCastOp>(loc, toType, operand);
if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
// Either sign extend or truncate.
if (toIntType.getWidth() > fromIntType.getWidth())
return builder.create<SignExtendIOp>(loc, toType, operand);
else if (toIntType.getWidth() < fromIntType.getWidth())
if (toIntType.getWidth() < fromIntType.getWidth())
return builder.create<TruncateIOp>(loc, toType, operand);
}
} else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
Expand All @@ -251,7 +255,7 @@ class RegionBuilderHelper {
if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
if (toFloatType.getWidth() > fromFloatType.getWidth())
return builder.create<FPExtOp>(loc, toFloatType, operand);
else if (toFloatType.getWidth() < fromFloatType.getWidth())
if (toFloatType.getWidth() < fromFloatType.getWidth())
return builder.create<FPTruncOp>(loc, toFloatType, operand);
}
}
Expand All @@ -262,19 +266,28 @@ class RegionBuilderHelper {
}

Value applyfn__add(Value lhs, Value rhs) {
OpBuilder builder = getBuilder(lhs);
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<AddFOp>(lhs.getLoc(), lhs, rhs);
else if (isInteger(lhs))
if (isInteger(lhs))
return builder.create<AddIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}

Value applyfn__sub(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<SubFOp>(lhs.getLoc(), lhs, rhs);
if (isInteger(lhs))
return builder.create<SubIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}

Value applyfn__mul(Value lhs, Value rhs) {
OpBuilder builder = getBuilder(lhs);
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
return builder.create<MulFOp>(lhs.getLoc(), lhs, rhs);
else if (isInteger(lhs))
if (isInteger(lhs))
return builder.create<MulIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}
Expand All @@ -284,18 +297,39 @@ class RegionBuilderHelper {
if (values.empty())
return;
Value first = values.front();
OpBuilder builder = getBuilder(first);
OpBuilder builder = getBuilder();
builder.create<YieldOp>(first.getLoc(), values);
}

Value constant(std::string value) {
OpBuilder builder = getBuilder();
Location loc = builder.getUnknownLoc();
Attribute valueAttr = parseAttribute(value, builder.getContext());
return builder.create<ConstantOp>(loc, valueAttr.getType(), valueAttr);
}

Value index(int64_t dim) {
OpBuilder builder = getBuilder();
return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
}

Type getIntegerType(unsigned width) {
return IntegerType::get(context, width);
}

Type getFloat32Type() { return Float32Type::get(context); }

Type getFloat64Type() { return Float64Type::get(context); }

private:
MLIRContext *context;
Block &block;

bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }

OpBuilder getBuilder(Value value) {
OpBuilder builder(value.getContext());
OpBuilder getBuilder() {
OpBuilder builder(context);
builder.setInsertionPointToEnd(&block);
return builder;
}
Expand Down Expand Up @@ -1476,7 +1510,6 @@ computeReshapeCollapsedType(MemRefType type,
MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
}


template <typename AffineExprTy>
unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
unsigned pos = 0;
Expand Down

0 comments on commit 9a2769d

Please sign in to comment.