Generalize Operand Quantization in FuseQuantizeOps (#3327)
This change enables more customization of operand quantization and
generalizes the patterns QuantizeOperands and QuantizeTransposedOperands
into a single pattern, QuantizeOperandsPastCommutingOps.

This allows quantization to be passed through operations that are
functionally unaffected by quantization, such as view-like ops. The
purpose of this change is to address a myriad of quantization issues
seen in quantized ONNX models that have reshape-like operations
sandwiched between a dequant and something like a matmul (whose other
operand is immediately quantizable).
zjgarvey committed May 13, 2024
1 parent 0b7cbf5 commit 75d1d72
Showing 2 changed files with 168 additions and 97 deletions.
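As a rough sketch of the new interface (illustrative only, not part of the patch), the generalized pattern takes the target op and a maximum commuting depth as template parameters; the depth bounds how many q-commuting ops (transpose, reshape, slice) the rewrite will walk through between a dequantize and the target op. The instantiations mirror the ones added to the pass in the diff below:

// Illustrative, simplified registration; the real call is inside
// FuseQuantizedOpsPass in the diff below, where `context` is the pass's
// MLIRContext*.
RewritePatternSet patterns(context);
patterns.insert<
    // depth 0: fuse only operands produced directly by a dequantize
    QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 0>,
    // depth 2: also look through up to two commuting ops (e.g. a transpose
    // followed by a reshape) feeding a matmul operand
    QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>>(context);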
181 changes: 99 additions & 82 deletions lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -13,6 +13,7 @@
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include <stack>

using namespace mlir;
using namespace mlir::torch;
@@ -27,98 +28,112 @@ template <typename SrcOp> struct QuantInfo {
template <> struct QuantInfo<AtenReluOp> {
static constexpr unsigned operandsToQuantize[1] = {0};
};
template <typename SrcOp>
class QuantizeOperands : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;

LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
llvm::SmallVector<Value> operands(op->getOperands());

bool dequanted = false;
auto f = [&dequanted](Value operand) {
if (auto dequant = operand.getDefiningOp<AtenDequantizeTensorOp>()) {
operand = dequant.getOperand();
dequanted = true;
}
if (auto dequant = operand.getDefiningOp<AtenDequantizeSelfOp>()) {
operand = dequant.getOperand();
dequanted = true;
}
return operand;
};

for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
operands[i] = f(operands[i]);
}

if (!dequanted) {
return rewriter.notifyMatchFailure(op, "no dequantizations found");
}

rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
return success();
}
};
// A QCommutingOp is an Op satisfying:
// 1. Has at most one tensor operand at index 0
// 2. Has a single output, which is a tensor
// 3. Satisfies the commutation relation:
// [MPTQT -> Dequant -> Op(float)] = [Op(int) -> MPTQT -> Dequant]
// where MPTQT = "Aten_MakePerTensorQuantizedTensorOp"
// and Dequant = "AtenDequantizeSelfOp" or "AtenDequantizeTensorOp"
bool isQCommutingOp(mlir::Operation *op) {
// if adding a new commuting op here, be sure to add a
// RemoveUnused pattern for that op to clean up afterwards
return llvm::isa<AtenTransposeIntOp, AtenReshapeOp, AtenSliceTensorOp>(op);
}
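To make the commutation relation above concrete, here is a small standalone C++ sketch (not part of the patch, with made-up scale and zero-point values). Per-tensor dequantization is elementwise, so running it before or after an op that only rearranges elements, such as a transpose, produces the same values:

#include <cassert>
#include <cstdint>

// dq(x) = scale * (x - zeroPoint), applied independently to every element.
static float dequant(int8_t x, float scale, int32_t zeroPoint) {
  return scale * static_cast<float>(x - zeroPoint);
}

int main() {
  const float scale = 0.5f;
  const int32_t zeroPoint = 3;
  const int8_t q[2][3] = {{1, 2, 3}, {4, 5, 6}};

  float dequantThenTranspose[3][2]; // [q -> dequant -> transpose]
  int8_t transposeOnInts[3][2];     // [q -> transpose], still integer
  for (int j = 0; j < 2; ++j)
    for (int i = 0; i < 3; ++i) {
      dequantThenTranspose[i][j] = dequant(q[j][i], scale, zeroPoint);
      transposeOnInts[i][j] = q[j][i];
    }

  // [q -> transpose -> dequant] matches [q -> dequant -> transpose], so the
  // transpose can run on the integer values and the quantization ops can be
  // re-created on its result.
  for (int i = 0; i < 3; ++i)
    for (int j = 0; j < 2; ++j)
      assert(dequantThenTranspose[i][j] ==
             dequant(transposeOnInts[i][j], scale, zeroPoint));
  return 0;
}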

template <typename SrcOp>
class QuantizeTransposedOperands : public OpRewritePattern<SrcOp> {
// The following conversion takes patterns of the form [op0 -> MPTQT -> dequant
// -> Op1 -> Op2 -> ... Opk -> SrcOp] to [op0 -> Int(Op1) -> Int(Op2) -> ... ->
// Int(Opk) -> MPTQT -> SrcOp] for any sequence of q commuting ops
// {Op1,Op2,...,Opk} with k <= depth.
// With depth = 0, this conversion will simply fuse any immediately quantizable
// operands: [MPTQT -> Dequant -> SrcOp (float operands)] to [MPTQT -> SrcOp(int
// operands)]
template <typename SrcOp, unsigned depth>
class QuantizeOperandsPastCommutingOps : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;

LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {

mlir::Location loc = op.getLoc();
llvm::SmallVector<Value> operands(op->getOperands());
unsigned numOperands = operands.size();
bool dequanted = false;
for (unsigned i = 0; i < numOperands; i++) {
if (auto trans = operands[i].getDefiningOp<AtenTransposeIntOp>()) {
auto transOperands = trans.getOperands();
Value dequantOperand;
if (auto dequant =
transOperands[0].getDefiningOp<AtenDequantizeSelfOp>()) {
dequantOperand = dequant.getOperand();
if (auto quant =
dequantOperand
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
auto quantOperands = quant.getOperands();
auto qType = quantOperands[0]
.getType()
.cast<ValueTensorType>()
.getOptionalDtype();
auto torchQType =
cast<ValueTensorType>(quant.getType()).getOptionalDtype();
auto transQTy =
rewriter.getType<ValueTensorType>(trans.getResult()
.getType()
.cast<ValueTensorType>()
.getOptionalSizes(),
qType);
auto newQuantTy =
rewriter.getType<ValueTensorType>(trans.getResult()
.getType()
.cast<ValueTensorType>()
.getOptionalSizes(),
torchQType);
Value newTrans = rewriter.create<AtenTransposeIntOp>(
op.getLoc(), transQTy, quantOperands[0], transOperands[1],
transOperands[2]);
Value newQuant =
rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
op.getLoc(), newQuantTy, newTrans, quantOperands[1],
quantOperands[2]);
operands[i] = newQuant;
dequanted = true;
}

for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
Value operand = operands[i];
std::stack<mlir::Operation *> commutingOpStack;
Value dequantOpd, MPTQTOpd;
for (unsigned k = 0; k < depth + 1; k++) {
auto currOp = operand.getDefiningOp();
// Case 0 : currOp is a nullptr (e.g., operand is a block argument)
if (!currOp)
break;
// Case 1 : currOp is a q commuting op (continue loop)
if (isQCommutingOp(currOp)) {
commutingOpStack.push(currOp);
// advance operand to currOp's tensor input for the next k-iteration
operand = currOp->getOperand(0);
continue;
}
// Case 2 : currOp is a dequant op (end loop)
if (llvm::isa<AtenDequantizeSelfOp, AtenDequantizeTensorOp>(currOp)) {
dequantOpd = currOp->getOperand(0);
auto MPTQTOp =
dequantOpd.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
// guard against a dequantize whose input is not produced by MPTQT
if (MPTQTOp)
MPTQTOpd = MPTQTOp.getOperand(0);
}
// either a dequant was found or chain broken, so break loop
break;
}

// move to next operand if this trace was unsuccessful
if (!MPTQTOpd)
continue;

// a successful trace occurred, so set dequanted to true
dequanted = true;

// rewrite stack
Value oldOpd = MPTQTOpd;
Type intDType =
cast<ValueTensorType>(MPTQTOpd.getType()).getOptionalDtype();
while (!commutingOpStack.empty()) {
// get front of the commuting op stack and replace its first operand
// with oldOpd
auto currOp = commutingOpStack.top();
commutingOpStack.pop();
llvm::SmallVector<Value> currOperands(currOp->getOperands());
currOperands[0] = oldOpd;
// get new result type
auto oldType = cast<ValueTensorType>(currOp->getResultTypes()[0]);
auto intType =
rewriter.getType<ValueTensorType>(oldType.getSizes(), intDType);
// rewrite currOp to have new operands and result type
// store this as oldOpd for next loop
oldOpd = rewriter
.create(loc, (currOp->getName()).getIdentifier(),
currOperands, intType, currOp->getAttrs())
->getResult(0);
}

// stack is empty, so oldOpd is now the corrected version of the
// SrcOp's original operand
// convert operand -> SrcOp to oldOpd -> newMPTQTOp -> SrcOp
auto MPTQTOperands = dequantOpd.getDefiningOp()->getOperands();
auto qTorchType =
cast<ValueTensorType>(dequantOpd.getType()).getOptionalDtype();
auto newMPTQTType = rewriter.getType<ValueTensorType>(
cast<ValueTensorType>(operands[i].getType()).getSizes(), qTorchType);
operands[i] = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
loc, newMPTQTType, oldOpd, MPTQTOperands[1], MPTQTOperands[2]);
}

if (!dequanted) {
return rewriter.notifyMatchFailure(
op, "no dequantized transpose inputs found.");
return rewriter.notifyMatchFailure(op, "No dequantizations found.");
}

rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
return success();
}
@@ -356,11 +371,13 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
RemoveUnused<AtenDequantizeTensorOp>,
RemoveUnused<AtenQuantizePerTensorOp>,
RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>,
RemoveUnused<AtenTransposeIntOp>, QuantizeOperands<AtenConvolutionOp>,
QuantizeOperands<AtenMatmulOp>, QuantizeOperands<AtenReluOp>,
QuantizeTransposedOperands<AtenMatmulOp>,
QuantizeAccumulator<AtenMatmulOp>, QuantizeOperands<AtenMmOp>,
QuantizeTransposedOperands<AtenMmOp>, QuantizeAccumulator<AtenMmOp>,
RemoveUnused<AtenTransposeIntOp>, RemoveUnused<AtenSliceTensorOp>,
RemoveUnused<AtenReshapeOp>,
QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 0>,
QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,
QuantizeOperandsPastCommutingOps<AtenMmOp, 1>,
QuantizeAccumulator<AtenMmOp>, QuantizeAccumulator<AtenMatmulOp>,
QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
context);
