Skip to content

Commit

Permalink
Integrate llvm-project at dabdec1001dc368373dd581cf72f37a440873ce3 (#…
Browse files Browse the repository at this point in the history
…3300)

Co-authored-by: Jacques Pienaar <jpienaar@google.com>
  • Loading branch information
bjacob and jpienaar committed May 8, 2024
1 parent 0abc586 commit bce800a
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 32 deletions.
7 changes: 7 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLI

option(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER "Enables the ONNX C importer" OFF)

# TODO(#3299): migrate to from member x.cast<T>() to mlir::cast<T>(x).
if(MSVC)
add_compile_options(/wd4996)
else()
add_compile_options(-Wno-deprecated-declarations)
endif()

macro(torch_mlir_enable_werror)
if(TORCH_MLIR_ENABLE_WERROR_FLAG)
if(NOT MSVC)
Expand Down
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 5037 files
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op,
// This specialization is for Div op. Unlike other binary ops, it doesn't
// support floating type.
template <>
tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
Operation *op, TensorType outType,
Value lhs, Value rhs);
tosa::IntDivOp
createBinaryOpAndCast<IntDivOp>(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value lhs, Value rhs);

std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
Operation *op,
Expand Down
9 changes: 4 additions & 5 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -511,8 +511,8 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
} else {
// The output type can be different than the input types (e.g. dividing an
// int tensor results in a floating point tensor).
result = tosa::createBinaryOpAndCast<tosa::DivOp>(rewriter, op, outType,
lhs, rhsTensor)
result = tosa::createBinaryOpAndCast<tosa::IntDivOp>(
rewriter, op, outType, lhs, rhsTensor)
.getResult();
}

Expand Down Expand Up @@ -4380,16 +4380,15 @@ LogicalResult ConvertAtenOp<AtenRemainderScalarOp>::matchAndRewrite(
self = rewriter.create<tosa::CastOp>(op.getLoc(), outType, self);

auto divTensor = self;
// tosa::DivOp only supports int
if (isa<mlir::FloatType>(outElemTy)) {
auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
op.getLoc(), otherTensor.getType(), otherTensor);
divTensor = rewriter.create<tosa::MulOp>(
op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
divTensor = rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
} else {
divTensor =
rewriter.create<tosa::DivOp>(op.getLoc(), outType, self, otherTensor);
divTensor = rewriter.create<tosa::IntDivOp>(op.getLoc(), outType, self,
otherTensor);
}

auto mulTensor =
Expand Down
14 changes: 7 additions & 7 deletions lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,20 +114,20 @@ tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
}

template <>
tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
Operation *op, TensorType outType,
Value lhs, Value rhs) {
tosa::IntDivOp
createBinaryOpAndCast<IntDivOp>(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value lhs, Value rhs) {
auto lhsElemTy = cast<TensorType>(lhs.getType()).getElementType();
auto rhsElemTy = cast<TensorType>(rhs.getType()).getElementType();
if (isa<mlir::FloatType>(lhsElemTy) || isa<mlir::FloatType>(rhsElemTy)) {
(void)rewriter.notifyMatchFailure(op,
"tosa.div only supports integer type");
(void)rewriter.notifyMatchFailure(
op, "tosa.int_div only supports integer type");
}

lhs = promoteType(rewriter, lhs, outType);
rhs = promoteType(rewriter, rhs, outType);
return tosa::CreateOpAndInfer<tosa::DivOp>(rewriter, op->getLoc(), outType,
lhs, rhs);
return tosa::CreateOpAndInfer<tosa::IntDivOp>(rewriter, op->getLoc(), outType,
lhs, rhs);
}

std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/TorchToLinalg/flatten.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func.func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3
// CHECK-LABEL: func.func @torch.aten.flatten.using_ints$rank0(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[COLLAPSED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>

Expand Down
10 changes: 5 additions & 5 deletions test/Conversion/TorchToLinalg/unsqueeze.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
// CHECK-LABEL: func.func @torch.aten.unsqueeze$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
Expand All @@ -18,7 +18,7 @@ func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.v
// CHECK-LABEL: func.func @torch.aten.unsqueeze$basic_negative(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func.func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
Expand All @@ -30,7 +30,7 @@ func.func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) ->
// CHECK-LABEL: func.func @torch.aten.unsqueeze$higher_rank_front(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3]] : tensor<2x3x4xf32> into tensor<1x2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3]] output_shape [1, 2, 3, 4] : tensor<2x3x4xf32> into tensor<1x2x3x4xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,3,4],f32>
func.func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
Expand All @@ -42,7 +42,7 @@ func.func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],
// CHECK-LABEL: func.func @torch.aten.unsqueeze$higher_rank_back(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x4x1xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] output_shape [2, 3, 4, 1] : tensor<2x3x4xf32> into tensor<2x3x4x1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x4x1xf32> -> !torch.vtensor<[2,3,4,1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4,1],f32>
func.func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
Expand All @@ -54,7 +54,7 @@ func.func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f
// CHECK-LABEL: func.func @torch.aten.unsqueeze$higher_rank_middle(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x1x4xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] output_shape [2, 3, 1, 4] : tensor<2x3x4xf32> into tensor<2x3x1x4xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x1x4xf32> -> !torch.vtensor<[2,3,1,4],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,1,4],f32>
func.func @torch.aten.unsqueeze$higher_rank_middle(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
Expand Down
18 changes: 11 additions & 7 deletions test/Conversion/TorchToLinalg/view.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1]] : tensor<3x2xf32> into tensor<6xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1]] output_shape [2, 3] : tensor<6xf32> into tensor<2x3xf32>
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,3],f32>

Expand Down Expand Up @@ -64,7 +64,7 @@ func.func @torch.aten.view$dynamictest2(%arg0: !torch.vtensor<[?,6,?],f32>) -> !
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[1,?,128],f32> -> tensor<1x?x128xf32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<1x?x128xf32> to tensor<1x16x128xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED]] {{\[\[}}0, 1], [2]] : tensor<1x16x128xf32> into tensor<16x128xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0], [1, 2]] : tensor<16x128xf32> into tensor<16x1x128xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0], [1, 2]] output_shape [16, 1, 128] : tensor<16x128xf32> into tensor<16x1x128xf32>
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<16x1x128xf32> -> !torch.vtensor<[16,1,128],f32>
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[16,1,128],f32>

Expand All @@ -83,7 +83,7 @@ func.func @torch.aten.view$dynamicVal(%arg0: !torch.vtensor<[1,?,128],f32>) -> !
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[8,1,?,1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2]] : tensor<4x5x6xf32> into tensor<120xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2, 3]] : tensor<120xf32> into tensor<8x1x15x1xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2, 3]] output_shape [8, 1, 15, 1] : tensor<120xf32> into tensor<8x1x15x1xf32>
// CHECK: %[[CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<8x1x15x1xf32> to tensor<8x1x?x1xf32>
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<8x1x?x1xf32> -> !torch.vtensor<[8,1,?,1],f32>
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[8,1,?,1],f32>
Expand All @@ -103,7 +103,7 @@ func.func @torch.aten$dynamicValOutput(%arg0: !torch.vtensor<[4,5,6],f32>) -> !t
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[2,1,2,3,?],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2]] : tensor<4x5x6xf32> into tensor<4x30xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2], [3, 4]] : tensor<4x30xf32> into tensor<2x1x2x3x10xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [2, 1, 2, 3, 10] : tensor<4x30xf32> into tensor<2x1x2x3x10xf32>
// CHECK: %[[CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<2x1x2x3x10xf32> to tensor<2x1x2x3x?xf32>
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<2x1x2x3x?xf32> -> !torch.vtensor<[2,1,2,3,?],f32>
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,1,2,3,?],f32>
Expand All @@ -125,7 +125,7 @@ func.func @torch.aten$dynamicValOutput2(%arg0: !torch.vtensor<[4,5,6],f32>) -> !
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,6],f32> -> tensor<2x6xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1]] : tensor<2x6xf32> into tensor<12xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] : tensor<12xf32> into tensor<3x2x2xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] output_shape [3, 2, 2] : tensor<12xf32> into tensor<3x2x2xf32>
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<3x2x2xf32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[3,2,2],f32>

Expand All @@ -144,7 +144,9 @@ func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) -
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[10,3,?,2,3],f32>) -> !torch.vtensor<[2,3,5,?,6],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[10,3,?,2,3],f32> -> tensor<10x3x?x2x3xf32>
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3, 4]] : tensor<10x3x?x2x3xf32> into tensor<30x?x6xf32>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1, 2], [3], [4]] : tensor<30x?x6xf32> into tensor<2x3x5x?x6xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[COLLAPSE]], %[[C1]] : tensor<30x?x6xf32>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1, 2], [3], [4]] output_shape [2, 3, 5, %[[DIM]], 6] : tensor<30x?x6xf32> into tensor<2x3x5x?x6xf32>
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<2x3x5x?x6xf32> -> !torch.vtensor<[2,3,5,?,6],f32>
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,3,5,?,6],f32>

Expand Down Expand Up @@ -241,7 +243,9 @@ func.func @torch.aten.view$castingView (%arg0 : !torch.vtensor<[?,?,?], f32>) ->
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[10,?,2,3],f32>) -> !torch.vtensor<[2,5,?,6],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[10,?,2,3],f32> -> tensor<10x?x2x3xf32>
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<10x?x2x3xf32> into tensor<10x?x6xf32>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1], [2], [3]] : tensor<10x?x6xf32> into tensor<2x5x?x6xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[COLLAPSE]], %[[C1]] : tensor<10x?x6xf32>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1], [2], [3]] output_shape [2, 5, %[[DIM]], 6] : tensor<10x?x6xf32> into tensor<2x5x?x6xf32>
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<2x5x?x6xf32> -> !torch.vtensor<[2,5,?,6],f32>
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,5,?,6],f32>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi16>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>
// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_0]] : (tensor<?x?xi16>) -> tensor<?x?xi32>
// CHECK: %[[VAL_3:.*]] = tosa.div %[[VAL_2]], %[[VAL_1]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
// CHECK: %[[VAL_3:.*]] = tosa.int_div %[[VAL_2]], %[[VAL_1]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],si32> {
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32>
return %0 : !torch.vtensor<[?, ?],si32>
Expand Down
4 changes: 2 additions & 2 deletions test/Dialect/TorchConversion/convert-custom-quant-op.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ func.func @forward(%arg0: !torch.vtensor<[1,1,2],f16>) -> !torch.vtensor<[1,1,2]
// CHECK: %[[SCALES:.*]] = torch_c.to_builtin_tensor %[[TENSOR2]] : !torch.vtensor<[2,1,1],f16> -> tensor<2x1x1xf16>
// CHECK: %[[TENSOR3:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
// CHECK: %[[ZPS:.*]] = torch_c.to_builtin_tensor %[[TENSOR3]] : !torch.vtensor<[2,1,1],f16> -> tensor<2x1x1xf16>
// CHECK: %[[EXPANDED_LHS:.*]] = tensor.expand_shape %[[LHS]] {{\[\[}}0], [1], [2, 3]] : tensor<1x1x2xf16> into tensor<1x1x1x2xf16>
// CHECK: %[[EXPANDED_RHS:.*]] = tensor.expand_shape %[[QUANT_RHS]] {{\[\[}}0], [1, 2]] : tensor<2x2xi8> into tensor<2x1x2xi8>
// CHECK: %[[EXPANDED_LHS:.*]] = tensor.expand_shape %[[LHS]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 1, 1, 2] : tensor<1x1x2xf16> into tensor<1x1x1x2xf16>
// CHECK: %[[EXPANDED_RHS:.*]] = tensor.expand_shape %[[QUANT_RHS]] {{\[\[}}0], [1, 2]] output_shape [2, 1, 2] : tensor<2x2xi8> into tensor<2x1x2xi8>
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f16
// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<2x1x2xf16>
// CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<1x1x2xf16>
Expand Down

0 comments on commit bce800a

Please sign in to comment.