diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
index 0c352d31ca8..7870ff63cb4 100644
--- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
+++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -13,6 +13,7 @@
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include <stack>
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -27,98 +28,112 @@ template <typename SrcOp> struct QuantInfo {
 template <> struct QuantInfo<AtenReluOp> {
   static constexpr unsigned operandsToQuantize[1] = {0};
 };
-template <typename SrcOp>
-class QuantizeOperands : public OpRewritePattern<SrcOp> {
-public:
-  using OpRewritePattern<SrcOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(SrcOp op,
-                                PatternRewriter &rewriter) const override {
-    llvm::SmallVector<Value> operands(op->getOperands());
-
-    bool dequanted = false;
-    auto f = [&dequanted](Value operand) {
-      if (auto dequant = operand.getDefiningOp<AtenDequantizeTensorOp>()) {
-        operand = dequant.getOperand();
-        dequanted = true;
-      }
-      if (auto dequant = operand.getDefiningOp<AtenDequantizeSelfOp>()) {
-        operand = dequant.getOperand();
-        dequanted = true;
-      }
-      return operand;
-    };
-
-    for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
-      operands[i] = f(operands[i]);
-    }
-    if (!dequanted) {
-      return rewriter.notifyMatchFailure(op, "no dequantizations found");
-    }
-
-    rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
-    return success();
-  }
-};
+// A QCommutingOp is an Op satisfying:
+// 1. Has at most one tensor operand at index 0
+// 2. Has a single output, which is a tensor
+// 3. Satisfies the commutation relation:
+//      [MPTQT -> Dequant -> Op(float)] = [Op(int) -> MPTQT -> Dequant]
+// where MPTQT = "Aten_MakePerTensorQuantizedTensorOp"
+// and Dequant = "AtenDequantizeSelfOp" or "AtenDequantizeTensorOp"
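+// For example, since dequantization is elementwise, slicing an int8 tensor
+// and then dequantizing it yields the same values as dequantizing first and
+// slicing the float result; transpose and reshape commute in the same way,
+// since they only rearrange elements.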
+bool isQCommutingOp(mlir::Operation *op) {
+  // if adding a new commuting op here, be sure to add a
+  // RemoveUnused pattern for that op to clean up afterwards
+  return llvm::isa<AtenTransposeIntOp, AtenReshapeOp, AtenSliceTensorOp>(op);
+}
 
-template <typename SrcOp>
-class QuantizeTransposedOperands : public OpRewritePattern<SrcOp> {
+// The following conversion takes patterns of the form [op0 -> MPTQT -> dequant
+// -> Op1 -> Op2 -> ... Opk -> SrcOp] to [op0 -> Int(Op1) -> Int(Op2) -> ... ->
+// Int(Opk) -> MPTQT -> SrcOp] for any sequence of q commuting ops
+// {Op1,Op2,...,Opk} with k <= depth.
+// With depth = 0, this conversion will simply fuse any immediately quantizable
+// operands: [MPTQT -> Dequant -> SrcOp (float operands)] to
+// [MPTQT -> SrcOp (int operands)]
+template <typename SrcOp, unsigned depth>
+class QuantizeOperandsPastCommutingOps : public OpRewritePattern<SrcOp> {
 public:
   using OpRewritePattern<SrcOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(SrcOp op,
                                 PatternRewriter &rewriter) const override {
+    mlir::Location loc = op.getLoc();
     llvm::SmallVector<Value> operands(op->getOperands());
-    unsigned numOperands = operands.size();
     bool dequanted = false;
-    for (unsigned i = 0; i < numOperands; i++) {
-      if (auto trans = operands[i].getDefiningOp<AtenTransposeIntOp>()) {
-        auto transOperands = trans.getOperands();
-        Value dequantOperand;
-        if (auto dequant =
-                transOperands[0].getDefiningOp<AtenDequantizeSelfOp>()) {
-          dequantOperand = dequant.getOperand();
-          if (auto quant =
-                  dequantOperand
-                      .getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
-            auto quantOperands = quant.getOperands();
-            auto qType = quantOperands[0]
-                             .getType()
-                             .cast<ValueTensorType>()
-                             .getOptionalDtype();
-            auto torchQType =
-                cast<ValueTensorType>(quant.getType()).getOptionalDtype();
-            auto transQTy =
-                rewriter.getType<ValueTensorType>(trans.getResult()
-                                                      .getType()
-                                                      .cast<ValueTensorType>()
-                                                      .getOptionalSizes(),
-                                                  qType);
-            auto newQuantTy =
-                rewriter.getType<ValueTensorType>(trans.getResult()
-                                                      .getType()
-                                                      .cast<ValueTensorType>()
-                                                      .getOptionalSizes(),
-                                                  torchQType);
-            Value newTrans = rewriter.create<AtenTransposeIntOp>(
-                op.getLoc(), transQTy, quantOperands[0], transOperands[1],
-                transOperands[2]);
-            Value newQuant =
-                rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
-                    op.getLoc(), newQuantTy, newTrans, quantOperands[1],
-                    quantOperands[2]);
-            operands[i] = newQuant;
-            dequanted = true;
-          }
+
+    for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
+      Value operand = operands[i];
+      std::stack<mlir::Operation *> commutingOpStack;
+      Value dequantOpd, MPTQTOpd;
+      for (unsigned k = 0; k < depth + 1; k++) {
+        auto currOp = operand.getDefiningOp();
+        // Case 0 : currOp is a nullptr (e.g., operand is a block argument)
+        if (!currOp)
+          break;
+        // Case 1 : currOp is a q commuting op (continue loop)
+        if (isQCommutingOp(currOp)) {
+          commutingOpStack.push(currOp);
+          // set operand to currOp for next k-iteration
+          operand = currOp->getOperand(0);
+          continue;
+        }
+        // Case 2 : currOp is a dequant op (end loop)
+        if (llvm::isa<AtenDequantizeSelfOp, AtenDequantizeTensorOp>(currOp)) {
+          dequantOpd = currOp->getOperand(0);
+          auto MPTQTOp =
+              dequantOpd.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
+          MPTQTOpd = MPTQTOp.getOperand(0);
         }
+        // either a dequant was found or chain broken, so break loop
+        break;
+      }
+
+      // move to next operand if this trace was unsuccessful
+      if (!MPTQTOpd)
+        continue;
+
+      // a successful trace occurred, so set dequanted to true
+      dequanted = true;
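+
+      // note: ops were pushed while walking backward from the SrcOp, so
+      // popping the stack below replays them in original program order
+      // (e.g., for [dequant -> reshape -> transpose -> SrcOp], the reshape
+      // is popped and rewritten first)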
+
+      // rewrite stack
+      Value oldOpd = MPTQTOpd;
+      Type intDType =
+          cast<ValueTensorType>(MPTQTOpd.getType()).getOptionalDtype();
+      while (!commutingOpStack.empty()) {
+        // get front of the commuting op stack and replace its first operand
+        // with oldOpd
+        auto currOp = commutingOpStack.top();
+        commutingOpStack.pop();
+        llvm::SmallVector<Value> currOperands(currOp->getOperands());
+        currOperands[0] = oldOpd;
+        // get new result type
+        auto oldType = cast<ValueTensorType>(currOp->getResultTypes()[0]);
+        auto intType =
+            rewriter.getType<ValueTensorType>(oldType.getSizes(), intDType);
+        // rewrite currOp to have new operands and result type
+        // store this as oldOpd for next loop
+        oldOpd = rewriter
+                     .create(loc, (currOp->getName()).getIdentifier(),
+                             currOperands, intType, currOp->getAttrs())
+                     ->getResult(0);
+      }
+
+      // stack is empty, so oldOpd is now the corrected version of the
+      // SrcOp's original operand
+      // convert operand -> SrcOp to oldOpd -> newMPTQTOp -> SrcOp
+      auto MPTQTOperands = dequantOpd.getDefiningOp()->getOperands();
+      auto qTorchType =
+          cast<ValueTensorType>(dequantOpd.getType()).getOptionalDtype();
+      auto newMPTQTType = rewriter.getType<ValueTensorType>(
+          cast<ValueTensorType>(operands[i].getType()).getSizes(), qTorchType);
+      operands[i] = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
+          loc, newMPTQTType, oldOpd, MPTQTOperands[1], MPTQTOperands[2]);
-      }
     }
+
     if (!dequanted) {
-      return rewriter.notifyMatchFailure(
-          op, "no dequantized transpose inputs found.");
+      return rewriter.notifyMatchFailure(op, "No dequantizations found.");
     }
+
     rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
     return success();
   }
@@ -356,11 +371,13 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
         RemoveUnused<AtenDequantizeSelfOp>,
         RemoveUnused<AtenDequantizeTensorOp>,
         RemoveUnused<AtenQuantizePerTensorOp>,
-        RemoveUnused<AtenTransposeIntOp>, QuantizeOperands<AtenConvolutionOp>,
-        QuantizeOperands<AtenMatmulOp>, QuantizeOperands<AtenMmOp>,
-        QuantizeTransposedOperands<AtenMatmulOp>,
-        QuantizeAccumulator<AtenMatmulOp>, QuantizeOperands<AtenReluOp>,
-        QuantizeTransposedOperands<AtenMmOp>, QuantizeAccumulator<AtenMmOp>,
+        RemoveUnused<AtenTransposeIntOp>, RemoveUnused<AtenReshapeOp>,
+        RemoveUnused<AtenSliceTensorOp>,
+        QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 0>,
+        QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
+        QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,
+        QuantizeOperandsPastCommutingOps<AtenMmOp, 1>,
+        QuantizeAccumulator<AtenMmOp>, QuantizeAccumulator<AtenMatmulOp>,
         QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
         context);
diff --git a/test/Dialect/Torch/fuse-quantized-ops.mlir b/test/Dialect/Torch/fuse-quantized-ops.mlir
index f98cb842f5d..594295d4e86 100644
--- a/test/Dialect/Torch/fuse-quantized-ops.mlir
+++ b/test/Dialect/Torch/fuse-quantized-ops.mlir
@@ -28,6 +28,60 @@ func.func @mm(%arg0: !torch.vtensor<[4, 4],si8>, %arg1: !torch.vtensor<[4, 4],si8>) -> !torch.vtensor<[4, 4],f32> {
 
 // -----
 
+// CHECK-LABEL: @matmul_commuting
+func.func @matmul_commuting(%arg0: !torch.vtensor<[2,128,32,32],si8>) -> !torch.vtensor<[1,1024,1024],f32> {
+  %float5.000000e-01 = torch.constant.float 5.000000e-01
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int-128 = torch.constant.int -128
+  %int2 = torch.constant.int 2
+  %int128 = torch.constant.int 128
+  %int1024 = torch.constant.int 1024
+  %int12 = torch.constant.int 12
+  %0 = torch.aten._make_per_tensor_quantized_tensor %arg0, %float5.000000e-01, %int-128 : !torch.vtensor<[2,128,32,32],si8>, !torch.float, !torch.int -> !torch.vtensor<[2,128,32,32],!torch.qint8>
+  %1 = torch.aten.dequantize.self %0 : !torch.vtensor<[2,128,32,32],!torch.qint8> -> !torch.vtensor<[2,128,32,32],f32>
+  %2 = torch.aten.slice.Tensor %1, %int0, %int0, %int1, %int1 : !torch.vtensor<[2,128,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],f32>
+  %3 = torch.aten.slice.Tensor %1, %int0, %int1, %int2, %int1 : !torch.vtensor<[2,128,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],f32>
+  %4 = torch.prim.ListConstruct %int1, %int128, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.aten.reshape %2, %4 : !torch.vtensor<[1,128,32,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,1024],f32>
+  %6 = torch.aten.reshape %3, %4 : !torch.vtensor<[1,128,32,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,1024],f32>
+  %7 = torch.aten.transpose.int %5, %int1, %int2 : !torch.vtensor<[1,128,1024],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],f32>
+  %8 = torch.aten.quantize_per_tensor %7, %float5.000000e-01, %int0, %int12 : !torch.vtensor<[1,1024,128],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
+  %9 = torch.aten.int_repr %8 : !torch.vtensor<[1,1024,128],!torch.qint8> -> !torch.vtensor<[1,1024,128],si8>
+  %10 = torch.aten._make_per_tensor_quantized_tensor %9, %float5.000000e-01, %int0 : !torch.vtensor<[1,1024,128],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
+  %11 = torch.aten.dequantize.self %10 : !torch.vtensor<[1,1024,128],!torch.qint8> -> !torch.vtensor<[1,1024,128],f32>
+  %12 = torch.aten.matmul %11, %6 : !torch.vtensor<[1,1024,128],f32>, !torch.vtensor<[1,128,1024],f32> -> !torch.vtensor<[1,1024,1024],f32>
+
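+  // Operand 0 of the matmul reaches its dequantize directly, while operand 1
+  // traces back through a reshape and a slice (two commuting ops, matching
+  // depth 2 for AtenMatmulOp in the pass), so both chains are rewritten on
+  // integer tensors below.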
+  // CHECK-DAG: %[[QUARTER:.+]] = torch.constant.float 2.500000e-01
+  // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[IN128:.+]] = torch.constant.int -128
+  // CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
+  // CHECK-DAG: %[[I128:.+]] = torch.constant.int 128
+  // CHECK-DAG: %[[I1024:.+]] = torch.constant.int 1024
+  // CHECK-DAG: %[[I12:.+]] = torch.constant.int 12
+  // CHECK-DAG: %[[MPTQT0:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[IN128]] : !torch.vtensor<[2,128,32,32],si8>, !torch.float, !torch.int -> !torch.vtensor<[2,128,32,32],!torch.qint8>
+  // CHECK-DAG: %[[DQ0:.+]] = torch.aten.dequantize.self %[[MPTQT0]] : !torch.vtensor<[2,128,32,32],!torch.qint8> -> !torch.vtensor<[2,128,32,32],f32>
+  // CHECK-DAG: %[[SLICE0:.+]] = torch.aten.slice.Tensor %[[DQ0]], %[[I0]], %[[I0]], %[[I1]], %[[I1]] : !torch.vtensor<[2,128,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],f32>
+  // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[I1]], %[[I128]], %[[I1024]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[RESHAPE0:.+]] = torch.aten.reshape %[[SLICE0]], %[[LIST]] : !torch.vtensor<[1,128,32,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,1024],f32>
+  // CHECK-DAG: %[[TR0:.+]] = torch.aten.transpose.int %[[RESHAPE0]], %[[I1]], %[[I2]] : !torch.vtensor<[1,128,1024],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],f32>
+  // CHECK-DAG: %[[Q0:.+]] = torch.aten.quantize_per_tensor %[[TR0]], %[[HALF]], %[[I0]], %[[I12]] : !torch.vtensor<[1,1024,128],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
+  // CHECK-DAG: %[[IR0:.+]] = torch.aten.int_repr %[[Q0]] : !torch.vtensor<[1,1024,128],!torch.qint8> -> !torch.vtensor<[1,1024,128],si8>
+  // CHECK-DAG: %[[MPTQT1:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[IR0]], %[[HALF]], %[[I0]] : !torch.vtensor<[1,1024,128],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
+  // CHECK-DAG: %[[SLICE1:.+]] = torch.aten.slice.Tensor %arg0, %[[I0]], %[[I1]], %[[I2]], %[[I1]] : !torch.vtensor<[2,128,32,32],si8>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],si8>
+  // CHECK-DAG: %[[RESHAPE1:.+]] = torch.aten.reshape %[[SLICE1]], %[[LIST]] : !torch.vtensor<[1,128,32,32],si8>, !torch.list<int> -> !torch.vtensor<[1,128,1024],si8>
+  // CHECK-DAG: %[[MPTQT2:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[RESHAPE1]], %[[HALF]], %[[IN128]] : !torch.vtensor<[1,128,1024],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1024],!torch.qint8>
+  // CHECK-DAG: %[[MATMUL:.+]] = torch.aten.matmul %[[MPTQT1]], %[[MPTQT2]] : !torch.vtensor<[1,1024,128],!torch.qint8>, !torch.vtensor<[1,128,1024],!torch.qint8> -> !torch.vtensor<[1,1024,1024],!torch.qint32>
+  // CHECK-DAG: %[[IR1:.+]] = torch.aten.int_repr %[[MATMUL]] : !torch.vtensor<[1,1024,1024],!torch.qint32> -> !torch.vtensor<[1,1024,1024],si32>
+  // CHECK-DAG: %[[MPTQT3:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[IR1]], %[[QUARTER]], %[[I0]] : !torch.vtensor<[1,1024,1024],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,1024,1024],!torch.qint32>
+  // CHECK-DAG: %[[DQ1:.+]] = torch.aten.dequantize.tensor %[[MPTQT3]] : !torch.vtensor<[1,1024,1024],!torch.qint32> -> !torch.vtensor<[1,1024,1024],f32>
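+  // (CHECK-DAG is used since the relative order in which the rewrite emits
+  // these ops is not guaranteed)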
+  return %12 : !torch.vtensor<[1,1024,1024],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @convolution_bias
 func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
   %scale = torch.constant.float 0.5
@@ -43,21 +97,21 @@ func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
   %15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list<int>
   %16 = torch.aten.convolution %7, %13, %arg2, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.vtensor<[3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],f32>
 
-  // CHECK: %[[DTYPE:.+]] = torch.constant.int 14
-  // CHECK: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01
-  // CHECK: %[[HALF:.+]] = torch.constant.float 5.000000e-01
-  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
-  // CHECK: %[[ZERO:.+]] = torch.constant.int 0
-  // CHECK: %[[ONE:.+]] = torch.constant.int 1
-  // CHECK: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
-  // CHECK: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
-  // CHECK: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32>
-  // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QBIAS]] : !torch.vtensor<[3],!torch.qint32> -> !torch.vtensor<[3],si32>
-  // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[INT]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],si32>
-  // CHECK: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
-  // CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
+  // CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 14
+  // CHECK-DAG: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01
+  // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
+  // CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
+  // CHECK-DAG: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32>
+  // CHECK-DAG: %[[INT:.+]] = torch.aten.int_repr %[[QBIAS]] : !torch.vtensor<[3],!torch.qint32> -> !torch.vtensor<[3],si32>
+  // CHECK-DAG: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[INT]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],si32>
+  // CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
+  // CHECK-DAG: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
 
   return %16 : !torch.vtensor<[1,3,7,7],f32>
 }