Skip to content

Commit

Permalink
[mlir][tosa] Replace StructAttrs with AttrDefs
Browse files Browse the repository at this point in the history
Depends on D127352

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D127370
  • Loading branch information
Mogball committed Jun 9, 2022
1 parent d7ef488 commit f1182bd
Show file tree
Hide file tree
Showing 17 changed files with 117 additions and 119 deletions.
7 changes: 3 additions & 4 deletions mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
Expand Up @@ -3,7 +3,6 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
add_mlir_interface(TosaInterfaces)

set(LLVM_TARGET_DEFINITIONS TosaOps.td)
mlir_tablegen(TosaStructs.h.inc -gen-struct-attr-decls)
mlir_tablegen(TosaStructs.cpp.inc -gen-struct-attr-defs)
add_public_tablegen_target(MLIRTosaStructsIncGen)

mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRTosaAttributesIncGen)
51 changes: 30 additions & 21 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Expand Up @@ -10,13 +10,16 @@
//
//===----------------------------------------------------------------------===//


#ifndef TOSA_OP_BASE
#define TOSA_OP_BASE

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// The TOSA Dialect.
//===----------------------------------------------------------------------===//

def Tosa_Dialect : Dialect {
let name = "tosa";

Expand All @@ -41,6 +44,16 @@ def Tosa_Dialect : Dialect {

let cppNamespace = "mlir::tosa";
let hasConstantMaterializer = 1;
let useDefaultAttributePrinterParser = 1;
}

//===----------------------------------------------------------------------===//
// TOSA Attributes.
//===----------------------------------------------------------------------===//

// Common base class for attributes defined in the TOSA dialect.
//
// Template arguments:
//   attrName     - forwarded to AttrDef as the attribute name.
//   attrMnemonic - assigned to the `mnemonic` field of the attribute.
//   traits       - optional list of traits forwarded to AttrDef; defaults to
//                  an empty list.
//
// The attribute is always bound to Tosa_Dialect.
class Tosa_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
: AttrDef<Tosa_Dialect, attrName, traits> {
// The mnemonic is taken verbatim from the second template argument.
let mnemonic = attrMnemonic;
}

//===----------------------------------------------------------------------===//
Expand All @@ -51,7 +64,7 @@ def Tosa_Dialect : Dialect {
// feed numerical precision parameters to the functional implementation of TOSA
// operators.
// The functional behavior is defined in the TOSA specification maintained at
// https://developer.mlplatform.org/w/tosa/ . TOSA leverages MLIR's built in
// https://developer.mlplatform.org/w/tosa/. TOSA leverages MLIR's built in
// quantization support: https://mlir.llvm.org/docs/Quantization/, and supports
// uniform quantization. Depending on datatype, asymmetric and symmetric
// quantization are supported. The types themselves are described in
Expand All @@ -60,12 +73,11 @@ def Tosa_Dialect : Dialect {
// This quantization attribute expresses numerical behavior of operators where
// the operator has a numerical relationship between a single input and output.
// For example: tosa.negate.
def Tosa_UnaryOpQuantizationAttr : StructAttr<"UnaryOpQuantizationAttr",
Tosa_Dialect, [
StructFieldAttr<"input_zp", I32Attr>,
StructFieldAttr<"output_zp", I32Attr>
]> {
def Tosa_UnaryOpQuantizationAttr
: Tosa_Attr<"UnaryOpQuantization", "unary_quant"> {
let summary = "Attribute for UnaryOp quantization information.";
let parameters = (ins "int64_t":$input_zp, "int64_t":$output_zp);
let assemblyFormat = "`<` struct(params) `>`";
}

// There is no explicit BinaryOpQuantizationAttr for 2-input/1-output ops. In
Expand All @@ -79,31 +91,28 @@ def Tosa_UnaryOpQuantizationAttr : StructAttr<"UnaryOpQuantizationAttr",
// the inputs.
// The scaling of their accumulator output is done using an explicit
// tosa.rescale operator that scales the accumulator result to output scale.
def Tosa_ConvOpQuantizationAttr : StructAttr<"ConvOpQuantizationAttr",
Tosa_Dialect, [
StructFieldAttr<"input_zp", I32Attr>,
StructFieldAttr<"weight_zp", I32Attr>
]> {
def Tosa_ConvOpQuantizationAttr
: Tosa_Attr<"ConvOpQuantization", "conv_quant"> {
let summary = "Attribute for Conv type op quantization information.";
let parameters = (ins "int64_t":$input_zp, "int64_t":$weight_zp);
let assemblyFormat = "`<` struct(params) `>`";
}

def Tosa_MatMulOpQuantizationAttr : StructAttr<"MatMulOpQuantizationAttr",
Tosa_Dialect, [
StructFieldAttr<"a_zp", I32Attr>,
StructFieldAttr<"b_zp", I32Attr>
]> {
def Tosa_MatMulOpQuantizationAttr
: Tosa_Attr< "MatMulOpQuantization", "matmul_quant"> {
let summary = "Attribute for MatMulOp quantization information.";
let parameters = (ins "int64_t":$a_zp, "int64_t":$b_zp);
let assemblyFormat = "`<` struct(params) `>`";
}

// This attribute holds input zero point correction applied to the padding
// zeros to ensure numerical accuracy in the subsequent TOSA operations.
// Its functional application is described in the tosa.pad() operator
// description in the specification.
def Tosa_PadOpQuantizationAttr : StructAttr<"PadOpQuantizationAttr",
Tosa_Dialect, [
StructFieldAttr<"input_zp", I32Attr>
]> {
def Tosa_PadOpQuantizationAttr : Tosa_Attr<"PadOpQuantization", "pad_quant"> {
let summary = "Attribute for PadOp quantization information.";
let parameters = (ins "int64_t":$input_zp);
let assemblyFormat = "`<` struct(params) `>`";
}

//===----------------------------------------------------------------------===//
Expand Down
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
Expand Up @@ -21,8 +21,8 @@
//===----------------------------------------------------------------------===//
// TOSA dialect and structs includes.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.h.inc"
#include "mlir/Dialect/Tosa/IR/TosaStructs.h.inc"

namespace mlir {
class PatternRewriter;
Expand All @@ -45,6 +45,9 @@ void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
} // namespace tosa
} // namespace mlir

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"

Expand Down
16 changes: 7 additions & 9 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Expand Up @@ -147,10 +147,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
cast<tosa::NegateOp>(op).quantization_info()) {
auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
int64_t inZp =
quantizationInfo.getValue().input_zp().getValue().getSExtValue();
int64_t outZp =
quantizationInfo.getValue().output_zp().getValue().getSExtValue();
int64_t inZp = quantizationInfo.getValue().getInput_zp();
int64_t outZp = quantizationInfo.getValue().getOutput_zp();

// Compute the maximum value that can occur in the intermediate buffer.
int64_t zpAdd = inZp + outZp;
Expand Down Expand Up @@ -1844,13 +1842,13 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
loc, padOp.pad_const(), ValueRange({}));
} else {
Attribute constantAttr;
if (elementTy.isa<FloatType>())
if (elementTy.isa<FloatType>()) {
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
} else if (elementTy.isa<IntegerType>() && !padOp.quantization_info()) {
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
auto value = padOp.quantization_info().getValue().input_zp().getValue();
constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
} else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
int64_t value = padOp.quantization_info().getValue().getInput_zp();
constantAttr = rewriter.getIntegerAttr(elementTy, value);
}
if (constantAttr)
padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
Expand Down
34 changes: 14 additions & 20 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Expand Up @@ -202,7 +202,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
if (isQuantized) {
auto quantizationInfo =
op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
auto iZp = quantizationInfo.input_zp().getValue().getSExtValue();
int64_t iZp = quantizationInfo.getInput_zp();

int64_t intMin =
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
Expand Down Expand Up @@ -274,10 +274,8 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
if (isQuantized) {
auto quantizationInfo =
op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
auto iZp = rewriter.getI32IntegerAttr(
quantizationInfo.input_zp().getValue().getSExtValue());
auto kZp = rewriter.getI32IntegerAttr(
quantizationInfo.weight_zp().getValue().getSExtValue());
auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp());
auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp());

auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
Expand Down Expand Up @@ -368,10 +366,8 @@ class DepthwiseConvConverter
if (isQuantized) {
auto quantizationInfo =
op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
iZp = rewriter.getI32IntegerAttr(
quantizationInfo.input_zp().getValue().getSExtValue());
kZp = rewriter.getI32IntegerAttr(
quantizationInfo.weight_zp().getValue().getSExtValue());
iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp());
kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp());
}

auto weightShape = weightTy.getShape();
Expand All @@ -382,7 +378,7 @@ class DepthwiseConvConverter
if (isQuantized) {
auto quantizationInfo =
op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
auto iZp = quantizationInfo.input_zp().getValue().getSExtValue();
int64_t iZp = quantizationInfo.getInput_zp();

int64_t intMin =
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
Expand Down Expand Up @@ -546,11 +542,9 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {

auto quantizationInfo = op.quantization_info().getValue();
auto aZp = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(
quantizationInfo.a_zp().getValue().getSExtValue()));
loc, rewriter.getI32IntegerAttr(quantizationInfo.getA_zp()));
auto bZp = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(
quantizationInfo.b_zp().getValue().getSExtValue()));
loc, rewriter.getI32IntegerAttr(quantizationInfo.getB_zp()));
rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
op, TypeRange{op.getType()},
ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor);
Expand Down Expand Up @@ -658,11 +652,9 @@ class FullyConnectedConverter

auto quantizationInfo = op.quantization_info().getValue();
auto inputZp = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(
quantizationInfo.input_zp().getValue().getSExtValue()));
loc, rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp()));
auto outputZp = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(
quantizationInfo.weight_zp().getValue().getSExtValue()));
loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp()));
Value matmul =
rewriter
.create<linalg::QuantizedMatmulOp>(
Expand Down Expand Up @@ -900,7 +892,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
if (op.quantization_info()) {
auto quantizationInfo = op.quantization_info().getValue();
auto inputZp = rewriter.create<arith::ConstantOp>(
loc, quantizationInfo.input_zp());
loc,
b.getIntegerAttr(accETy, quantizationInfo.getInput_zp()));
Value offset =
rewriter.create<arith::MulIOp>(loc, accETy, countI, inputZp);
poolVal =
Expand Down Expand Up @@ -936,7 +929,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
if (op.quantization_info()) {
auto quantizationInfo = op.quantization_info().getValue();
auto outputZp = rewriter.create<arith::ConstantOp>(
loc, quantizationInfo.output_zp());
loc, b.getIntegerAttr(scaled.getType(),
quantizationInfo.getOutput_zp()));
scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
.getResult();
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Tosa/CMakeLists.txt
Expand Up @@ -7,8 +7,8 @@ add_mlir_dialect_library(MLIRTosa
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa

DEPENDS
MLIRTosaAttributesIncGen
MLIRTosaOpsIncGen
MLIRTosaStructsIncGen
MLIRTosaInterfacesIncGen

LINK_LIBS PUBLIC
Expand Down
25 changes: 19 additions & 6 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Expand Up @@ -18,12 +18,14 @@
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::tosa;
Expand All @@ -33,8 +35,8 @@ using namespace mlir::tosa;
//===----------------------------------------------------------------------===//
// Tosa dialect structs and interface includes.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc"

namespace {
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -78,6 +80,10 @@ void TosaDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
>();
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
>();
addInterfaces<TosaInlinerInterface>();
}

Expand Down Expand Up @@ -336,13 +342,13 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
Type elementTy = inputTy.getElementType();

Attribute constantAttr;
if (elementTy.isa<FloatType>())
if (elementTy.isa<FloatType>()) {
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
else if (elementTy.isa<IntegerType>() && !op.quantization_info())
} else if (elementTy.isa<IntegerType>() && !op.quantization_info()) {
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
else if (elementTy.isa<IntegerType>() && op.quantization_info()) {
auto value = op.quantization_info().getValue().input_zp().getValue();
constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
} else if (elementTy.isa<IntegerType>() && op.quantization_info()) {
auto value = op.quantization_info().getValue().getInput_zp();
constantAttr = rewriter.getIntegerAttr(elementTy, value);
}

if (!constantAttr) {
Expand Down Expand Up @@ -1925,6 +1931,13 @@ LogicalResult WhileOp::inferReturnTypeComponents(
return success();
}

//===----------------------------------------------------------------------===//
// TOSA Attribute Definitions.
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

//===----------------------------------------------------------------------===//
// TOSA Operator Definitions.
//===----------------------------------------------------------------------===//
Expand Down
Expand Up @@ -214,8 +214,7 @@ class TransposeConvStridedConverter
weight = createOpAndInfer<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
weightPaddingVal, nullptr,
PadOpQuantizationAttr::get(quantInfo.weight_zp(),
rewriter.getContext()));
rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeight_zp()));

} else {
weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
Expand Down Expand Up @@ -279,8 +278,7 @@ class TransposeConvStridedConverter
input = createOpAndInfer<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(inputETy), input,
inputPaddingVal, nullptr,
PadOpQuantizationAttr::get(quantInfo.input_zp(),
rewriter.getContext()));
rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInput_zp()));
} else {
input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
UnrankedTensorType::get(inputETy),
Expand Down

0 comments on commit f1182bd

Please sign in to comment.