From 75d1d72059cf2731ddfd5e44f8646cd8cb6ebe66 Mon Sep 17 00:00:00 2001
From: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
Date: Sun, 12 May 2024 22:49:59 -0500
Subject: [PATCH] Generalize Operand Quantization in FuseQuantizeOps (#3327)

This change enables more customization with operand quantization, and
generalizes the patterns QuantizeOperands and QuantizeTransposedOperands
to QuantizeOperandsPastCommutingOps.

This allows quantization to be passed through operations which are
functionally unaffected by quantization, such as view-like ops. The
purpose of this change is to address a number of quantization issues
seen in quantized onnx models that have some reshape-like operations
sandwiched between a dequant and something like a matmul (whose other
operand is immediately quantizable).
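For illustration, here is a sketch of the intended rewrite. This is
hand-written IR (the names %int_data, %scale, %zp, %shape, %other and the
shapes are made up for the example), not a case from the test suite; see
test/Dialect/Torch/fuse-quantized-ops.mlir for the checked-in tests.
Before the pass, a reshape blocks fusion of the quantized operand into
the matmul:

  %q  = torch.aten._make_per_tensor_quantized_tensor %int_data, %scale, %zp : !torch.vtensor<[4,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4,4],!torch.qint8>
  %dq = torch.aten.dequantize.self %q : !torch.vtensor<[4,4],!torch.qint8> -> !torch.vtensor<[4,4],f32>
  %r  = torch.aten.reshape %dq, %shape : !torch.vtensor<[4,4],f32>, !torch.list<int> -> !torch.vtensor<[2,8],f32>
  %mm = torch.aten.matmul %r, %other : !torch.vtensor<[2,8],f32>, !torch.vtensor<[8,2],f32> -> !torch.vtensor<[2,2],f32>

After QuantizeOperandsPastCommutingOps (with depth >= 1), the reshape is
applied to the integer data instead, and the re-made quantized tensor is
handed to the matmul directly (%other is handled the same way if it is
also fed by a dequant; later patterns such as QuantizeAccumulator then
quantize the result):

  %r  = torch.aten.reshape %int_data, %shape : !torch.vtensor<[4,4],si8>, !torch.list<int> -> !torch.vtensor<[2,8],si8>
  %q  = torch.aten._make_per_tensor_quantized_tensor %r, %scale, %zp : !torch.vtensor<[2,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[2,8],!torch.qint8>
  %mm = torch.aten.matmul %q, %other : !torch.vtensor<[2,8],!torch.qint8>, !torch.vtensor<[8,2],f32> -> !torch.vtensor<[2,2],f32>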
---
 .../Torch/Transforms/FuseQuantizedOps.cpp  | 181 ++++++++++--------
 test/Dialect/Torch/fuse-quantized-ops.mlir |  84 ++++++--
 2 files changed, 168 insertions(+), 97 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
index 0c352d31ca8..7870ff63cb4 100644
--- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
+++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -13,6 +13,7 @@
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include <stack>
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -27,98 +28,112 @@ template <typename SrcOp> struct QuantInfo {
 template <> struct QuantInfo<AtenReluOp> {
   static constexpr unsigned operandsToQuantize[1] = {0};
 };
 
-template <typename SrcOp>
-class QuantizeOperands : public OpRewritePattern<SrcOp> {
-public:
-  using OpRewritePattern<SrcOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(SrcOp op,
-                                PatternRewriter &rewriter) const override {
-    llvm::SmallVector<Value> operands(op->getOperands());
-
-    bool dequanted = false;
-    auto f = [&dequanted](Value operand) {
-      if (auto dequant = operand.getDefiningOp<AtenDequantizeTensorOp>()) {
-        operand = dequant.getOperand();
-        dequanted = true;
-      }
-      if (auto dequant = operand.getDefiningOp<AtenDequantizeSelfOp>()) {
-        operand = dequant.getOperand();
-        dequanted = true;
-      }
-      return operand;
-    };
-
-    for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
-      operands[i] = f(operands[i]);
-    }
-    if (!dequanted) {
-      return rewriter.notifyMatchFailure(op, "no dequantizations found");
-    }
-
-    rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
-    return success();
-  }
-};
+// A QCommutingOp is an Op satisfying:
+// 1. Has at most one tensor operand, located at index 0
+// 2. Has a single output, which is a tensor
+// 3. Satisfies the commutation relation:
+//      [MPTQT -> Dequant -> Op(float)] = [Op(int) -> MPTQT -> Dequant]
+// where MPTQT = "Aten_MakePerTensorQuantizedTensorOp"
+// and Dequant = "AtenDequantizeSelfOp" or "AtenDequantizeTensorOp"
+bool isQCommutingOp(mlir::Operation *op) {
+  // if adding a new commuting op here, be sure to add a
+  // RemoveUnused pattern for that op to clean up afterwards
+  return llvm::isa<AtenTransposeIntOp, AtenReshapeOp, AtenSliceTensorOp>(op);
+}
 
-template <typename SrcOp>
-class QuantizeTransposedOperands : public OpRewritePattern<SrcOp> {
+// The following conversion takes patterns of the form [op0 -> MPTQT -> dequant
+// -> Op1 -> Op2 -> ... Opk -> SrcOp] to [op0 -> Int(Op1) -> Int(Op2) -> ... ->
+// Int(Opk) -> MPTQT -> SrcOp] for any sequence of q-commuting ops
+// {Op1, Op2, ..., Opk} with k <= depth.
+// With depth = 0, this conversion simply fuses any immediately quantizable
+// operands: [MPTQT -> Dequant -> SrcOp (float operands)] to
+// [MPTQT -> SrcOp (int operands)].
+template <typename SrcOp, unsigned depth>
+class QuantizeOperandsPastCommutingOps : public OpRewritePattern<SrcOp> {
 public:
   using OpRewritePattern<SrcOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(SrcOp op,
                                 PatternRewriter &rewriter) const override {
+    mlir::Location loc = op.getLoc();
     llvm::SmallVector<Value> operands(op->getOperands());
-    unsigned numOperands = operands.size();
     bool dequanted = false;
-    for (unsigned i = 0; i < numOperands; i++) {
-      if (auto trans = operands[i].getDefiningOp<AtenTransposeIntOp>()) {
-        auto transOperands = trans.getOperands();
-        Value dequantOperand;
-        if (auto dequant =
-                transOperands[0].getDefiningOp<AtenDequantizeSelfOp>()) {
-          dequantOperand = dequant.getOperand();
-          if (auto quant =
-                  dequantOperand
-                      .getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
-            auto quantOperands = quant.getOperands();
-            auto qType = quantOperands[0]
-                             .getType()
-                             .cast<ValueTensorType>()
-                             .getOptionalDtype();
-            auto torchQType =
-                cast<ValueTensorType>(quant.getType()).getOptionalDtype();
-            auto transQTy =
-                rewriter.getType<ValueTensorType>(trans.getResult()
-                                                      .getType()
-                                                      .cast<ValueTensorType>()
-                                                      .getOptionalSizes(),
-                                                  qType);
-            auto newQuantTy =
-                rewriter.getType<ValueTensorType>(trans.getResult()
-                                                      .getType()
-                                                      .cast<ValueTensorType>()
-                                                      .getOptionalSizes(),
-                                                  torchQType);
-            Value newTrans = rewriter.create<AtenTransposeIntOp>(
-                op.getLoc(), transQTy, quantOperands[0], transOperands[1],
-                transOperands[2]);
-            Value newQuant =
-                rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
-                    op.getLoc(), newQuantTy, newTrans, quantOperands[1],
-                    quantOperands[2]);
-            operands[i] = newQuant;
-            dequanted = true;
-          }
-        }
-      }
+
+    for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
+      Value operand = operands[i];
+      std::stack<mlir::Operation *> commutingOpStack;
+      Value dequantOpd, MPTQTOpd;
+      for (unsigned k = 0; k < depth + 1; k++) {
+        auto currOp = operand.getDefiningOp();
+        // Case 0 : currOp is a nullptr (e.g., operand is a block argument)
+        if (!currOp)
+          break;
+        // Case 1 : currOp is a q-commuting op (continue loop)
+        if (isQCommutingOp(currOp)) {
+          commutingOpStack.push(currOp);
+          // set operand to currOp's input for the next k-iteration
+          operand = currOp->getOperand(0);
+          continue;
+        }
+        // Case 2 : currOp is a dequant op (end loop)
+        if (llvm::isa<AtenDequantizeSelfOp, AtenDequantizeTensorOp>(currOp)) {
+          dequantOpd = currOp->getOperand(0);
+          auto MPTQTOp =
+              dequantOpd.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
+          MPTQTOpd = MPTQTOp.getOperand(0);
+        }
+        // either a dequant was found or the chain was broken, so break loop
+        break;
+      }
+
+      // move to the next operand if this trace was unsuccessful
+      if (!MPTQTOpd)
+        continue;
+
+      // a successful trace occurred, so set dequanted to true
+      dequanted = true;
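+
+      // Illustrative trace (hypothetical chain, not tied to any test): for
+      // x -> MPTQT -> dequant -> reshape -> transpose -> SrcOp, the walk
+      // above pushed the transpose first and the reshape second, so the
+      // reshape sits on top of the stack; the loop below therefore rebuilds
+      // the ops in their original order, x -> reshape(int) -> transpose(int),
+      // before re-quantizing.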
+
+      // rewrite stack
+      Value oldOpd = MPTQTOpd;
+      Type intDType =
+          cast<ValueTensorType>(MPTQTOpd.getType()).getOptionalDtype();
+      while (!commutingOpStack.empty()) {
+        // get the top of the commuting op stack and replace its first
+        // operand with oldOpd
+        auto currOp = commutingOpStack.top();
+        commutingOpStack.pop();
+        llvm::SmallVector<Value> currOperands(currOp->getOperands());
+        currOperands[0] = oldOpd;
+        // get new result type
+        auto oldType = cast<ValueTensorType>(currOp->getResultTypes()[0]);
+        auto intType =
+            rewriter.getType<ValueTensorType>(oldType.getSizes(), intDType);
+        // rewrite currOp to have new operands and result type;
+        // store this as oldOpd for the next iteration
+        oldOpd = rewriter
                     .create(loc, (currOp->getName()).getIdentifier(),
+                             currOperands, intType, currOp->getAttrs())
+                     ->getResult(0);
+      }
+
+      // the stack is empty, so oldOpd is now the corrected version of the
+      // SrcOp's original operand
+      // convert operand -> SrcOp to oldOpd -> newMPTQTOp -> SrcOp
+      auto MPTQTOperands = dequantOpd.getDefiningOp()->getOperands();
+      auto qTorchType =
+          cast<ValueTensorType>(dequantOpd.getType()).getOptionalDtype();
+      auto newMPTQTType = rewriter.getType<ValueTensorType>(
+          cast<ValueTensorType>(operands[i].getType()).getSizes(), qTorchType);
+      operands[i] = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
+          loc, newMPTQTType, oldOpd, MPTQTOperands[1], MPTQTOperands[2]);
     }
+
     if (!dequanted) {
-      return rewriter.notifyMatchFailure(
-          op, "no dequantized transpose inputs found.");
+      return rewriter.notifyMatchFailure(op, "No dequantizations found.");
     }
+
     rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
     return success();
   }
 };
@@ -356,11 +371,13 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
         RemoveUnused<AtenDequantizeSelfOp>,
         RemoveUnused<AtenDequantizeTensorOp>,
         RemoveUnused<AtenQuantizePerTensorOp>,
-        RemoveUnused<AtenIntReprOp>, QuantizeOperands<AtenConvolutionOp>,
-        QuantizeOperands<AtenMatmulOp>, QuantizeOperands<AtenMmOp>,
-        QuantizeTransposedOperands<AtenMatmulOp>,
-        QuantizeAccumulator<AtenMatmulOp>, QuantizeOperands<AtenReluOp>,
-        QuantizeTransposedOperands<AtenMmOp>, QuantizeAccumulator<AtenMmOp>,
+        RemoveUnused<AtenIntReprOp>, RemoveUnused<AtenTransposeIntOp>,
+        RemoveUnused<AtenSliceTensorOp>, RemoveUnused<AtenReshapeOp>,
+        QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 0>,
+        QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
+        QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,
+        QuantizeOperandsPastCommutingOps<AtenMmOp, 1>,
+        QuantizeAccumulator<AtenMmOp>, QuantizeAccumulator<AtenMatmulOp>,
         QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
             context);
diff --git a/test/Dialect/Torch/fuse-quantized-ops.mlir b/test/Dialect/Torch/fuse-quantized-ops.mlir
index f98cb842f5d..594295d4e86 100644
--- a/test/Dialect/Torch/fuse-quantized-ops.mlir
+++ b/test/Dialect/Torch/fuse-quantized-ops.mlir
@@ -28,6 +28,60 @@ func.func @mm(%arg0: !torch.vtensor<[4, 4],si8>, %arg1: !torch.vtensor<[4, 4],si
 
 // -----
 
+// CHECK-LABEL: @matmul_commuting
+func.func @matmul_commuting(%arg0: !torch.vtensor<[2,128,32,32],si8>) -> !torch.vtensor<[1,1024,1024],f32> {
+  %float5.000000e-01 = torch.constant.float 5.000000e-01
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int-128 = torch.constant.int -128
+  %int2 = torch.constant.int 2
+  %int128 = torch.constant.int 128
+  %int1024 = torch.constant.int 1024
+  %int12 = torch.constant.int 12
+  %0 = torch.aten._make_per_tensor_quantized_tensor %arg0, %float5.000000e-01, %int-128 : !torch.vtensor<[2,128,32,32],si8>, !torch.float, !torch.int -> !torch.vtensor<[2,128,32,32],!torch.qint8>
+  %1 = torch.aten.dequantize.self %0 : !torch.vtensor<[2,128,32,32],!torch.qint8> -> !torch.vtensor<[2,128,32,32],f32>
+  %2 = torch.aten.slice.Tensor %1, %int0, %int0, %int1, %int1 : !torch.vtensor<[2,128,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],f32>
+  %3 = torch.aten.slice.Tensor %1, %int0, %int1, %int2, %int1 : !torch.vtensor<[2,128,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],f32>
+  %4 = torch.prim.ListConstruct %int1, %int128, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.aten.reshape %2, %4 : !torch.vtensor<[1,128,32,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,1024],f32>
+  %6 = torch.aten.reshape %3, %4 : !torch.vtensor<[1,128,32,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,1024],f32>
+  %7 = torch.aten.transpose.int %5, %int1, %int2 : !torch.vtensor<[1,128,1024],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],f32>
+  %8 = torch.aten.quantize_per_tensor %7, %float5.000000e-01, %int0, %int12 : !torch.vtensor<[1,1024,128],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
+  %9 = torch.aten.int_repr %8 : !torch.vtensor<[1,1024,128],!torch.qint8> -> !torch.vtensor<[1,1024,128],si8>
+  %10 = torch.aten._make_per_tensor_quantized_tensor %9, %float5.000000e-01, %int0 : !torch.vtensor<[1,1024,128],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
+  %11 = torch.aten.dequantize.self %10 : !torch.vtensor<[1,1024,128],!torch.qint8> -> !torch.vtensor<[1,1024,128],f32>
+  %12 = torch.aten.matmul %11, %6 : !torch.vtensor<[1,1024,128],f32>, !torch.vtensor<[1,128,1024],f32> -> !torch.vtensor<[1,1024,1024],f32>
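+
+  // Note: both matmul operands carry scale 5.000000e-01; QuantizeAccumulator
+  // gives the si32 accumulator the product of the operand scales, which is
+  // where the 2.500000e-01 ([[QUARTER]]) below comes from.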
+
+  // CHECK-DAG: %[[QUARTER:.+]] = torch.constant.float 2.500000e-01
+  // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[IN128:.+]] = torch.constant.int -128
+  // CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
+  // CHECK-DAG: %[[I128:.+]] = torch.constant.int 128
+  // CHECK-DAG: %[[I1024:.+]] = torch.constant.int 1024
+  // CHECK-DAG: %[[I12:.+]] = torch.constant.int 12
+  // CHECK-DAG: %[[MPTQT0:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[IN128]] : !torch.vtensor<[2,128,32,32],si8>, !torch.float, !torch.int -> !torch.vtensor<[2,128,32,32],!torch.qint8>
+  // CHECK-DAG: %[[DQ0:.+]] = torch.aten.dequantize.self %[[MPTQT0]] : !torch.vtensor<[2,128,32,32],!torch.qint8> -> !torch.vtensor<[2,128,32,32],f32>
+  // CHECK-DAG: %[[SLICE0:.+]] = torch.aten.slice.Tensor %[[DQ0]], %[[I0]], %[[I0]], %[[I1]], %[[I1]] : !torch.vtensor<[2,128,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],f32>
+  // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[I1]], %[[I128]], %[[I1024]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[RESHAPE0:.+]] = torch.aten.reshape %[[SLICE0]], %[[LIST]] : !torch.vtensor<[1,128,32,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,1024],f32>
+  // CHECK-DAG: %[[TR0:.+]] = torch.aten.transpose.int %[[RESHAPE0]], %[[I1]], %[[I2]] : !torch.vtensor<[1,128,1024],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],f32>
+  // CHECK-DAG: %[[Q0:.+]] = torch.aten.quantize_per_tensor %[[TR0]], %[[HALF]], %[[I0]], %[[I12]] : !torch.vtensor<[1,1024,128],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
+  // CHECK-DAG: %[[IR0:.+]] = torch.aten.int_repr %[[Q0]] : !torch.vtensor<[1,1024,128],!torch.qint8> -> !torch.vtensor<[1,1024,128],si8>
+  // CHECK-DAG: %[[MPTQT1:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[IR0]], %[[HALF]], %[[I0]] : !torch.vtensor<[1,1024,128],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
+  // CHECK-DAG: %[[SLICE1:.+]] = torch.aten.slice.Tensor %arg0, %[[I0]], %[[I1]], %[[I2]], %[[I1]] : !torch.vtensor<[2,128,32,32],si8>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],si8>
+  // CHECK-DAG: %[[RESHAPE1:.+]] = torch.aten.reshape %[[SLICE1]], %[[LIST]] : !torch.vtensor<[1,128,32,32],si8>, !torch.list<int> -> !torch.vtensor<[1,128,1024],si8>
+  // CHECK-DAG: %[[MPTQT2:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[RESHAPE1]], %[[HALF]], %[[IN128]] : !torch.vtensor<[1,128,1024],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1024],!torch.qint8>
+  // CHECK-DAG: %[[MATMUL:.+]] = torch.aten.matmul %[[MPTQT1]], %[[MPTQT2]] : !torch.vtensor<[1,1024,128],!torch.qint8>, !torch.vtensor<[1,128,1024],!torch.qint8> -> !torch.vtensor<[1,1024,1024],!torch.qint32>
+  // CHECK-DAG: %[[IR1:.+]] = torch.aten.int_repr %[[MATMUL]] : !torch.vtensor<[1,1024,1024],!torch.qint32> -> !torch.vtensor<[1,1024,1024],si32>
+  // CHECK-DAG: %[[MPTQT3:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[IR1]], %[[QUARTER]], %[[I0]] : !torch.vtensor<[1,1024,1024],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,1024,1024],!torch.qint32>
+  // CHECK-DAG: %[[DQ1:.+]] = torch.aten.dequantize.tensor %[[MPTQT3]] : !torch.vtensor<[1,1024,1024],!torch.qint32> -> !torch.vtensor<[1,1024,1024],f32>
+  return %12 : !torch.vtensor<[1,1024,1024],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @convolution_bias
 func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
   %scale = torch.constant.float 0.5
@@ -43,21 +97,21 @@ func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.
   %15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list<int>
   %16 = torch.aten.convolution %7, %13, %arg2, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.vtensor<[3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],f32>
 
-  // CHECK: %[[DTYPE:.+]] = torch.constant.int 14
-  // CHECK: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01
-  // CHECK: %[[HALF:.+]] = torch.constant.float 5.000000e-01
-  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
-  // CHECK: %[[ZERO:.+]] = torch.constant.int 0
-  // CHECK: %[[ONE:.+]] = torch.constant.int 1
-  // CHECK: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
-  // CHECK: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
-  // CHECK: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32>
-  // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QBIAS]] : !torch.vtensor<[3],!torch.qint32> -> !torch.vtensor<[3],si32>
-  // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[INT]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],si32>
-  // CHECK: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
-  // CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
+  // CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 14
+  // CHECK-DAG: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01
+  // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
+  // CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
+  // CHECK-DAG: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32>
+  // CHECK-DAG: %[[INT:.+]] = torch.aten.int_repr %[[QBIAS]] : !torch.vtensor<[3],!torch.qint32> -> !torch.vtensor<[3],si32>
+  // CHECK-DAG: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[INT]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],si32>
+  // CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
+  // CHECK-DAG: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
   return %16 : !torch.vtensor<[1,3,7,7],f32>
 }