28 changes: 28 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -864,6 +864,34 @@ extensionComplianceMap = {
{{bf16T, fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT},
{{bf16T, fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT}},
allOf}}},
{"tosa.cast_from_block_scaled",
{{{Extension::bf16, Extension::mxfp},
{{{fp4e2m1T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
{{fp6e2m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
{{fp6e3m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
{{fp8e4m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
{{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
allOf},
{{Extension::mxfp},
{{{fp4e2m1T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
{{fp6e2m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
{{fp6e3m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
{{fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
{{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
{"tosa.cast_to_block_scaled",
{{{Extension::mxfp},
{{{bf16T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{fp32T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{fp32T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{fp32T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{fp32T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::bf16, Extension::mxfp},
{{{bf16T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{bf16T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{bf16T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
{{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}},
allOf}}},
{"tosa.rescale",
{{{Extension::int16},
{{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0},
63 changes: 63 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2470,6 +2470,69 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// Operator: cast_from_block_scaled
//===----------------------------------------------------------------------===//
def Tosa_CastFromBlockScaledOp: Tosa_InferShapedTypeOp<"cast_from_block_scaled"> {
let summary = "Apply scales from a scale tensor to the values in a value tensor";

let description = [{
Apply the scales from a scale tensor to the values in a value tensor, casting
the result to the output type. The block dimension must be the last dimension
of the tensor.
}];

let arguments = (ins
Tosa_MXFPDataTensorAtLeast1D:$input_data,
Tosa_MXFPScaleTensorAtLeast1D:$input_scale,
Tosa_BlockSizeAttr:$block_size
);

let results = (outs
Tosa_TensorAtLeast1D:$output_data
);

list<Availability> availability = [
Profile<[Tosa_PRO_FP]>,
Extension<[Tosa_EXT_BF16, Tosa_EXT_MXFP]>,
];

let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
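
As a point of reference for the description above, the following is a minimal scalar sketch of the cast_from_block_scaled semantics for a rank-1 input, using plain float as a stand-in for the MXFP data, E8M0 scale, and output element types. The real op's element types, rounding, and special-value handling are defined by the TOSA specification, not by this sketch.

```cpp
#include <cstddef>
#include <vector>

// Sketch only: output element i is input element i multiplied by the one
// scale that covers its block; the block dimension is the last (and here
// the only) dimension.
std::vector<float> castFromBlockScaledRef(const std::vector<float> &inputData,
                                          const std::vector<float> &inputScale,
                                          std::size_t blockSize) {
  std::vector<float> output(inputData.size());
  for (std::size_t i = 0; i < inputData.size(); ++i)
    output[i] = inputData[i] * inputScale[i / blockSize];
  return output;
}
```

For higher-rank tensors the same pairing applies independently along every slice of the last dimension, which is why the verifier added in TosaOps.cpp below requires matching leading dimensions and relates only the last dimensions through block_size.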

//===----------------------------------------------------------------------===//
// Operator: cast_to_block_scaled
//===----------------------------------------------------------------------===//
def Tosa_CastToBlockScaledOp : Tosa_InferShapedTypeOp<"cast_to_block_scaled"> {
let summary = "Calculate scale tensor values per block, output to separate scale and data tensors.";

let description = [{
Calculate a scale value per block of input values and use that to calculate
scaled data values from an input tensor. The output tensors are cast to the
specified scale and value types. The block dimension will be the last dimension
of the tensor.
}];

let arguments = (ins
Tosa_TensorAtLeast1D:$input_data,
Tosa_BlockSizeAttr:$block_size
);

let results = (outs
Tosa_MXFPDataTensorAtLeast1D:$output_data,
Tosa_MXFPScaleTensorAtLeast1D:$output_scale
);

list<Availability> availability = [
Profile<[Tosa_PRO_FP]>,
Extension<[Tosa_EXT_BF16, Tosa_EXT_MXFP]>
];

let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
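
Mirroring the sketch above, here is a hedged rank-1 illustration of cast_to_block_scaled: one scale is derived per block and the block's values are rescaled by it before being cast to the narrow output types. The power-of-two-from-max-magnitude choice is an assumption made for illustration; the normative scale selection, rounding, and saturation rules are those of the TOSA specification.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Sketch only: returns {output_data, output_scale} for a rank-1 input whose
// length is a multiple of blockSize. Values stay in float instead of being
// cast to the MXFP data and E8M0 scale types.
std::pair<std::vector<float>, std::vector<float>>
castToBlockScaledRef(const std::vector<float> &inputData,
                     std::size_t blockSize) {
  std::vector<float> outputData(inputData.size());
  std::vector<float> outputScale(inputData.size() / blockSize);
  for (std::size_t b = 0; b < outputScale.size(); ++b) {
    float maxAbs = 0.0f;
    for (std::size_t i = b * blockSize; i < (b + 1) * blockSize; ++i)
      maxAbs = std::max(maxAbs, std::fabs(inputData[i]));
    // Assumed convention: a power-of-two scale taken from the block maximum.
    const float scale =
        maxAbs > 0.0f ? std::exp2(std::floor(std::log2(maxAbs))) : 1.0f;
    outputScale[b] = scale;
    for (std::size_t i = b * blockSize; i < (b + 1) * blockSize; ++i)
      outputData[i] = inputData[i] / scale;
  }
  return {outputData, outputScale};
}
```

Note that the scale output has one element per block, which is the shape relationship the shape inference and verifier below encode.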

//===----------------------------------------------------------------------===//
// Operator: rescale
//===----------------------------------------------------------------------===//
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -79,7 +79,8 @@ class ProfileInfoDepot {

LogicalResult populatationDispatch(Operation *op);

- LogicalResult populateProfileInfo(ValueRange operands, Value output);
+ // Add input operands and output results to the profile type info list
+ LogicalResult populateProfileInfo(ValueRange operands, ValueRange results);

// Base
template <typename T>
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -199,6 +199,16 @@ def Tosa_MXFPScaleTensor3D : AnyTypeOf<[
TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]>
]>;
def Tosa_MXFPDataTensorAtLeast1D : AnyTypeOf<[
TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
TosaRankedTensorOf<[Tosa_MXFPNumber], [AtLeastRankOne]>],
"tosa-conformant tensor of at least rank 1", "::mlir::TensorType"
>;
def Tosa_MXFPScaleTensorAtLeast1D : AnyTypeOf<[
TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
TosaRankedTensorOf<[Tosa_MXFPScaleNumber], [AtLeastRankOne]>],
"tosa-conformant tensor of at least rank 1", "::mlir::TensorType"
>;

//===----------------------------------------------------------------------===//
// Generic scalar, vector, or tensor of a particular type.
159 changes: 158 additions & 1 deletion mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -370,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
result.operands)))
return failure();

- result.addTypes(fnTy.getResult(0));
+ result.addTypes(fnTy.getResults());
result.addAttributes(attrs);

return success();
Expand Down Expand Up @@ -532,6 +532,24 @@ void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
printWithEnumHandling(parser, *this);
}

ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseWithEnumHandling<tosa::BlockSize>(parser, result);
}

void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
printWithEnumHandling(parser, *this);
}

ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseWithEnumHandling<tosa::BlockSize>(parser, result);
}

void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
printWithEnumHandling(parser, *this);
}

//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -3944,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents(
return success();
}

LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
CastFromBlockScaledOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
const ShapeAdaptor inputShape(adaptor.getInputData().getType());
inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
return success();
}

LogicalResult CastFromBlockScaledOp::verify() {
const Type inputDataType = getInputData().getType();
const Type outputDataType = getResult().getType();
if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
return emitOpError() << "require compatible shapes for input_data ("
<< inputDataType << ") and "
<< "output_data (" << outputDataType << ")";

const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);

if (inputDataShape.hasRank()) {
const unsigned int blockSize =
BlockSizeAttr::getBlockSizeValue(getBlockSize());
const int64_t inputDataLastDim =
inputDataShape.getDimSize(inputDataShape.getRank() - 1);
if (inputDataLastDim % blockSize != 0)
return emitOpError() << "expect last dimension of input_data ("
<< inputDataLastDim
<< ") to be divisible by block_size (" << blockSize
<< ")";

const Type inputScaleType = getInputScale().getType();
const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);

if (inputScaleShape.hasRank()) {
SmallVector<int64_t> inputDataDims, inputScaleDims;
inputDataShape.getDims(inputDataDims);
inputScaleShape.getDims(inputScaleDims);

if (inputDataDims.size() != inputScaleDims.size() ||
failed(verifyCompatibleShape(
ArrayRef<int64_t>(inputDataDims).drop_back(1),
ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
return emitOpError() << "require compatible shapes for input_data ("
<< inputDataType << ") and "
<< "input_scale (" << inputScaleType
<< ") except for the last dimension";

const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
inputScaleDims.back()};
if (ShapedType::isStatic(inputDataLastDim) &&
failed(verifyCompatibleDims(dimsToCheck)))
return emitOpError()
<< "expect last dimension of input_scale ("
<< inputScaleDims.back()
<< ") to be equal to last dimension of input_data / block_size ("
<< inputDataDims.back() / blockSize << ")";
}
}

return success();
}
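
To make the shape contract enforced above concrete, here is a small hypothetical helper (not part of the patch) that computes the scale shape expected for a given data shape. A dynamic dimension is represented as -1 in this sketch and passed through unchanged, mirroring the ShapedType::isStatic guards in the verifier.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: for a data tensor of shape [d0, ..., dn] and block
// size B, the matching scale tensor has shape [d0, ..., dn / B]. For
// example, 2x3x64 data with block_size 32 pairs with a 2x3x2 scale tensor.
std::vector<int64_t> expectedScaleShape(std::vector<int64_t> dataShape,
                                        unsigned blockSize) {
  assert(!dataShape.empty() && blockSize != 0);
  int64_t &last = dataShape.back();
  if (last >= 0) { // -1 stands in for a dynamic dimension; leave it dynamic.
    assert(last % blockSize == 0 && "last dim must be divisible by block_size");
    last /= blockSize;
  }
  return dataShape;
}
```

The same relation, applied to output_data and output_scale instead of the inputs, is what CastToBlockScaledOp's shape inference and verifier below enforce.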

LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
CastToBlockScaledOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
const ShapeAdaptor inputShape(adaptor.getInputData().getType());
inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
if (!inputShape.hasRank())
return success();

// Calculate output_scale shape if ranked input provided
SmallVector<int64_t> outputScaleShape;
inputShape.getDims(outputScaleShape);
const int64_t lastDimLoc = inputShape.getRank() - 1;
const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
if (ShapedType::isStatic(lastDimSize)) {
const unsigned int blockSize =
BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
}
inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
return success();
}

LogicalResult CastToBlockScaledOp::verify() {
const Type inputDataType = getInputData().getType();
const Type outputDataType = getResult(0).getType();
if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
return emitOpError() << "require compatible shapes for input_data ("
<< inputDataType << ") and "
<< "output_data (" << outputDataType << ")";

const unsigned int blockSize =
BlockSizeAttr::getBlockSizeValue(getBlockSize());
const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
if (inputDataShape.hasRank()) {
const int64_t inputDataLastDim =
inputDataShape.getDimSize(inputDataShape.getRank() - 1);
if (ShapedType::isStatic(inputDataLastDim) &&
inputDataLastDim % blockSize != 0)
return emitOpError() << "expect last dimension of input_data ("
<< inputDataLastDim
<< ") to be divisible by block_size (" << blockSize
<< ")";
}

const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
const Type outputScaleType = getResult(1).getType();
const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
SmallVector<int64_t> outputDataDims, outputScaleDims;
outputDataShape.getDims(outputDataDims);
outputScaleShape.getDims(outputScaleDims);

if (outputDataDims.size() != outputScaleDims.size() ||
failed(verifyCompatibleShape(
ArrayRef<int64_t>(outputDataDims).drop_back(1),
ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
return emitOpError() << "require compatible shapes for output_data ("
<< outputDataType << ") and "
<< "output_scale (" << outputScaleType
<< ") except for the last dimension";

const int64_t outputDataLastDim = outputDataDims.back();
const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
outputScaleDims.back()};
if (ShapedType::isStatic(outputDataLastDim) &&
failed(verifyCompatibleDims(dimsToCheck)))
return emitOpError()
<< "expect last dimension of output_scale ("
<< outputScaleDims.back()
<< ") to be equal to last dimension of output_data / block_size ("
<< outputDataDims.back() / blockSize << ")";
}

return success();
}

LogicalResult IfOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
IfOp::Adaptor adaptor,
32 changes: 9 additions & 23 deletions mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -51,10 +51,11 @@ TosaProfileCompliance::getProfileComplianceMap() {

// Base populating function
LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
- Value output) {
- for (auto operand : operands)
+ ValueRange results) {
+ for (const auto &operand : operands)
addValue(operand);
- addValue(output);
+ for (const auto &result : results)
+ addValue(result);
return success();
}

Expand Down Expand Up @@ -176,23 +177,6 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
return success();
}

- template <>
- LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
- addValue(op.getInputReal());
- addValue(op.getInputImag());
- addValue(op.getOutputReal());
- addValue(op.getOutputImag());
- return success();
- }
-
- template <>
- LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
- addValue(op.getInputReal());
- addValue(op.getOutputReal());
- addValue(op.getOutputImag());
- return success();
- }
-
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
addValue(op.getOnTrue());
Expand Down Expand Up @@ -246,7 +230,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// This helper function populates the info for all operands.
#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
if (isa<tosa::tosaOp##Op>(op)) { \
- return populateProfileInfo(op->getOperands(), op->getResult(0)); \
+ return populateProfileInfo(op->getOperands(), op->getResults()); \
}
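
Because the common populating path now forwards every result, multi-result operators no longer need a custom specialization; this is why the FFT2d and RFFT2d specializations are deleted above and both ops, along with the new block-scaled casts, are dispatched through the COMMON macro further down. As a rough hand-written expansion (for illustration; the macro itself is authoritative), the CastToBlockScaled dispatch line becomes:

```cpp
// Hand-expanded POPULATE_PROFILE_INFO_COMMON(CastToBlockScaled), in the
// context of populatationDispatch: all operands plus both results
// (output_data and output_scale) are added to the profile type info list.
if (isa<tosa::CastToBlockScaledOp>(op)) {
  return populateProfileInfo(op->getOperands(), op->getResults());
}
```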

// Skip irrelevant operands when they are independent and not tied to any
@@ -257,8 +241,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
POPULATE_PROFILE_INFO_CUSTOM(Mul)
- POPULATE_PROFILE_INFO_CUSTOM(FFT2d)
- POPULATE_PROFILE_INFO_CUSTOM(RFFT2d)
POPULATE_PROFILE_INFO_CUSTOM(Concat)
POPULATE_PROFILE_INFO_CUSTOM(Pad)
POPULATE_PROFILE_INFO_CUSTOM(Reshape)
@@ -277,7 +259,11 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// For most of the TOSA operators, all operands are profile/extension related
// and hence are all considered in this profile-based compliance check.
POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled)
+ POPULATE_PROFILE_INFO_COMMON(FFT2d)
+ POPULATE_PROFILE_INFO_COMMON(RFFT2d)
POPULATE_PROFILE_INFO_COMMON(Cast)
+ POPULATE_PROFILE_INFO_COMMON(CastFromBlockScaled)
+ POPULATE_PROFILE_INFO_COMMON(CastToBlockScaled)
POPULATE_PROFILE_INFO_COMMON(Const)
POPULATE_PROFILE_INFO_COMMON(ArgMax)
POPULATE_PROFILE_INFO_COMMON(Sub)
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -635,6 +635,8 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_RANKS_AND_SIZES(Transpose);
// Type Conversion
CHECK_RANKS_AND_SIZES(Cast);
CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
CHECK_RANKS_AND_SIZES(CastToBlockScaled);
CHECK_RANKS_AND_SIZES(Rescale);
// Control Flow Operators
CHECK_RANKS_AND_SIZES(If);