Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,8 @@ extensionComplianceMap = {
{{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T},
SpecificationVersion::V_1_1_DRAFT},
{{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T},
SpecificationVersion::V_1_1_DRAFT},
{{mxint8T, fp8ue8m0T, mxint8T, fp8ue8m0T, fp32T},
SpecificationVersion::V_1_1_DRAFT}}}}},
{"tosa.max_pool2d",
{{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
Expand Down Expand Up @@ -870,27 +872,31 @@ extensionComplianceMap = {
{{fp6e2m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
{{fp6e3m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
{{fp8e4m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
{{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
{{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
{{mxint8T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
allOf},
{{Extension::mxfp},
{{{fp4e2m1T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
{{fp6e2m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
{{fp6e3m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
{{fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
{{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
{{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
{{mxint8T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
{"tosa.cast_to_block_scaled",
{{{Extension::mxfp},
{{{bf16T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{fp32T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{fp32T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{fp32T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{fp32T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}},
{{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{fp32T, mxint8T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::bf16, Extension::mxfp},
{{{bf16T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{bf16T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{bf16T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}},
{{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{bf16T, mxint8T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}},
allOf}}},
{"tosa.rescale",
{{{Extension::int16},
Expand All @@ -908,7 +914,8 @@ extensionComplianceMap = {
{{{fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT},
{{fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT},
{{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT}}}}},
{{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT},
{{mxint8T}, SpecificationVersion::V_1_1_DRAFT}}}}},
{"tosa.identity",
{{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}},
{{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}},
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ Value createPadConstTensor(OpBuilder &builder, Location loc, Value src,
// returns type of variable op
RankedTensorType getVariableType(VariableOp variableOp);

// Returns the bitwidth of a TOSA tensor element type. In addition to the
// types handled by Type::getIntOrFloatBitWidth(), this accepts the custom
// !tosa.mxint8 element type, for which it returns 8.
unsigned getBitWidth(Type type);

} // namespace tosa
} // namespace mlir

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class ProfileInfoDepot {

private:
TypeInfo convertTypeToInfo(Type type) {
return {type.getTypeID(), type.getIntOrFloatBitWidth()};
return {type.getTypeID(), tosa::getBitWidth(type)};
}

TypeInfo convertValueToInfo(Value value) {
Expand Down
33 changes: 21 additions & 12 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
// Tosa Type Definitions.
//===----------------------------------------------------------------------===//

// The base class for Tosa dialect types. It is a TypeDef bound to
// Tosa_Dialect whose mnemonic is taken from `typeMnemonic`; `traits` is
// forwarded unchanged and defaults to an empty list.
class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<Tosa_Dialect, name, traits> {
let mnemonic = typeMnemonic;
}

// The base class of a quantized type.
// Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end].
// Where low and high ends are 0,255 when unsigned, -128,127 when signed, for
Expand Down Expand Up @@ -78,13 +84,26 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
Tosa_QuantizedType<"int16", [16, 0], 1>,
Tosa_QuantizedType<"int32", [32, 0], 1>]>;

//===----------------------------------------------------------------------===//
// Custom TOSA element types.
//===----------------------------------------------------------------------===//

// MLIR doesn't have a builtin type for mxint8 yet. For now it is declared as
// a custom TOSA type. This may change in the future.
def Tosa_MXInt8 : Tosa_Type<"mxint8", "mxint8"> {
let summary = "INT8 type as defined by OCP-MX";
let description = [{
8-bit integer format with an implicit 1/64 scale defined by OCP-MX.
}];
}

//===----------------------------------------------------------------------===//
// Multi-category types.
//===----------------------------------------------------------------------===//
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat, Tosa_MXInt8],
"number">;

def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN],
def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN, Tosa_MXInt8],
"micro-scaling format number">;
def Tosa_MXFPScaleNumber : AnyTypeOf<[F8E8M0FNU], "micro-scaling format scale number">;

Expand Down Expand Up @@ -265,16 +284,6 @@ def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>;
def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>;
def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>;

//===----------------------------------------------------------------------===//
// Tosa Type Definitions.
//===----------------------------------------------------------------------===//

// The base class for Tosa dialect types.
class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<Tosa_Dialect, name, traits> {
let mnemonic = typeMnemonic;
}

//===----------------------------------------------------------------------===//
// ShapeType
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,12 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
}

/// Returns the bit width of a TOSA tensor element type.
///
/// The custom `!tosa.mxint8` type is a TOSA dialect type rather than a builtin
/// integer or float type, so it is special-cased here and reported as 8 bits.
/// Every other type is forwarded to Type::getIntOrFloatBitWidth().
unsigned mlir::tosa::getBitWidth(Type type) {
  // Only the type test is needed, not the casted value, so use isa<> instead
  // of dyn_cast<>.
  if (isa<tosa::mxint8Type>(type))
    return 8;
  return type.getIntOrFloatBitWidth();
}

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ TosaProfileCompliance::TosaProfileCompliance() {
const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
const TypeInfo mxint8T = {mlir::tosa::mxint8Type::getTypeID(), 8};

// The profile-based compliance content below is auto-generated by a script
// in https://git.mlplatform.org/tosa/specification.git
Expand Down Expand Up @@ -625,6 +626,8 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
return {"fp4e2m1"};
} else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) {
return {"fp8e8m0"};
} else if (typeInfo.typeID == tosa::mxint8Type::getTypeID()) {
return {"mxint8"};
}
llvm_unreachable("unknown type");
}
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op,
<< " shape dimension cannot be dynamic";
}

int64_t element_bits = type.getElementTypeBitWidth();
int64_t element_bits = tosa::getBitWidth(getElementTypeOrSelf(type));
int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
int64_t size = element_bytes * type.getNumElements();

Expand Down Expand Up @@ -1217,9 +1217,10 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
return true;
}
}
} else if (mlir::isa<tosa::shapeType>(type)) {
} else if (mlir::isa<tosa::shapeType>(type))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: this type check for tosa::shapeType is inconsistent with the rest of the function, which uses isa<...>(type)

return true;
else if (isa<tosa::mxint8Type>(type))
return true;
}
return false;
}

Expand Down
21 changes: 21 additions & 0 deletions mlir/test/Dialect/Tosa/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1269,6 +1269,13 @@ func.func @test_matmul_t_block_scaled_broadcast(%arg0: tensor<?x8x32xf8E4M3FN>,
return %0 : tensor<4x8x16xf32>
}

// -----
// Exercises tosa.matmul_t_block_scaled with !tosa.mxint8 data operands and
// f8E8M0FNU scale operands, producing an f32 result.
// CHECK-LABEL: test_matmul_t_block_scaled_mxint8
func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32x!tosa.mxint8>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x32x!tosa.mxint8>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32x!tosa.mxint8>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
return %0 : tensor<4x8x16xf32>
}

// -----
// CHECK-LABEL: test_cast_from_block_scaled_static
func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
Expand Down Expand Up @@ -1296,3 +1303,17 @@ func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
}

// -----
// Exercises tosa.cast_to_block_scaled from an f32 input to a !tosa.mxint8
// data result plus an f8E8M0FNU scale result, with stochastic_round = false.
// CHECK-LABEL: test_cast_to_block_scaled_mxint8
func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) {
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
}

// -----
// Exercises tosa.const with a two-element !tosa.mxint8 tensor whose values
// are given as a single hex string literal.
// NOTE(review): %arg0 is not used by the body - confirm whether it is needed.
// CHECK-LABEL: test_const_mxint8
func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> {
%0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8>
return %0 : tensor<2x!tosa.mxint8>
}
52 changes: 50 additions & 2 deletions mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ func.func @test_argmax_int64(%arg0: tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64
// -----

// CHECK-LABEL: test_const_i64
func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
func.func @test_const_i64() -> tensor<4xi64> {
%0 = "tosa.const"() {values = dense<[3, 0, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
return %0 : tensor<4xi64>
}

// -----

// CHECK-LABEL: test_const_fp6e3m2
func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> {
func.func @test_const_fp6e3m2() -> tensor<4xf6E3M2FN> {
%0 = "tosa.const"() {values = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf6E3M2FN>} : () -> tensor<4xf6E3M2FN>
return %0 : tensor<4xf6E3M2FN>
}
Expand Down Expand Up @@ -82,3 +82,51 @@ func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
}

// -----

// Exercises tosa.cast_to_block_scaled from an f32 input to a !tosa.mxint8
// data result plus an f8E8M0FNU scale result, with stochastic_round = false.
// CHECK-LABEL: test_cast_to_block_scaled_mxint8
func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) {
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
}

// -----

// Exercises tosa.const with a four-element f6E3M2FN tensor of zeros.
// NOTE(review): a case with the same name and body appears earlier in this
// file - confirm whether this second copy is intentional.
// CHECK-LABEL: test_const_fp6e3m2
func.func @test_const_fp6e3m2() -> tensor<4xf6E3M2FN> {
%0 = "tosa.const"() {values = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf6E3M2FN>} : () -> tensor<4xf6E3M2FN>
return %0 : tensor<4xf6E3M2FN>
}

// -----

// Exercises tosa.const with a two-element !tosa.mxint8 tensor whose values
// are given as a list of per-element string literals.
// CHECK-LABEL: test_const_mxint8
func.func @test_const_mxint8() -> tensor<2x!tosa.mxint8> {
%0 = "tosa.const"() {values = dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8>
return %0 : tensor<2x!tosa.mxint8>
}

// -----

// Exercises tosa.cast from an f4E2M1FN tensor to a bf16 tensor of the same
// shape.
// CHECK-LABEL: test_cast_f4e2m1
func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> {
%0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16>
return %0 : tensor<13x21x3xbf16>
}

// -----

// Exercises tosa.matmul_t_block_scaled with !tosa.mxint8 data operands and
// f8E8M0FNU scale operands, producing an f32 result.
// CHECK-LABEL: test_matmul_t_block_scaled_mxint8
func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32x!tosa.mxint8>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32x!tosa.mxint8>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32x!tosa.mxint8>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
return %0 : tensor<4x8x16xf32>
}

// -----

// Exercises tosa.cast_to_block_scaled from an f32 input to a !tosa.mxint8
// data result plus an f8E8M0FNU scale result, with stochastic_round = false.
// NOTE(review): a case with the same name and body appears earlier in this
// file - confirm whether this second copy is intentional.
// CHECK-LABEL: test_cast_to_block_scaled_mxint8
func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) {
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
}