Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ profileComplianceMap = {
{{Profile::pro_fp},
{{{fp16T, fp16T}, SpecificationVersion::V_1_0},
{{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.max_pool2d_adaptive",
{{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_1_DRAFT}}},
{{Profile::pro_fp},
{{{fp16T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
{{fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
{"tosa.transpose_conv2d",
{{{Profile::pro_int},
{{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
Expand Down Expand Up @@ -657,6 +662,14 @@ extensionComplianceMap = {
{{Extension::fp8e5m2},
{{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.max_pool2d_adaptive",
{{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e4m3},
{{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
{{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::bf16},
{{{bf16T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}}}},
{"tosa.rfft2d",
{{{Extension::fft},
{{{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
Expand Down
38 changes: 37 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d", [NoMemoryEffect]> {
//===----------------------------------------------------------------------===//
// Operator: avg_pool2d_adaptive
//===----------------------------------------------------------------------===//
def Tosa_AvgPool2dAdaptiveOp : Tosa_InferShapedTypeOp<"avg_pool2d_adaptive"> {
def Tosa_AvgPool2dAdaptiveOp
: Tosa_InferShapedTypeOp<"avg_pool2d_adaptive", [NoMemoryEffect]> {
let summary = "Performs average pooling on the input with shape operands.";

let description = [{
Expand Down Expand Up @@ -524,6 +525,41 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d", [Pure]> {
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// Operator: max_pool2d_adaptive
//===----------------------------------------------------------------------===//
def Tosa_MaxPool2dAdaptiveOp
: Tosa_InferShapedTypeOp<"max_pool2d_adaptive", [Pure]> {
let summary = "Performs max pooling on the input.";

let description = [{
This performs a max pooling over the given input tensor. A sliding window of
size given by <kernel size> is passed over the input tensor, with the
maximum value being placed in the output tensor.
Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride,
pad arguments as inputs rather than attributes.
}];

let arguments =
(ins Tosa_Tensor4D:$input, Rank2TosaShape:$kernel, Rank2TosaShape:$stride,
Rank4TosaShape:$pad,

DefaultValuedAttr<
Tosa_NanPropagationModeAttr,
"::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode);

let results = (outs Tosa_Tensor4D:$output);

list<Availability> availability =
[Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2,
Tosa_EXT_BF16]>,
];

let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// Operator: rfft2d
//===----------------------------------------------------------------------===//
Expand Down
56 changes: 53 additions & 3 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,15 @@ void MaxPool2dOp::print(OpAsmPrinter &parser) {
printWithNanPropagationHandling(parser, *this);
}

ParseResult MaxPool2dAdaptiveOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void MaxPool2dAdaptiveOp::print(OpAsmPrinter &parser) {
printWithNanPropagationHandling(parser, *this);
}

ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}
Expand Down Expand Up @@ -1228,9 +1237,8 @@ struct AdaptivePoolingConstShapeValues {

template <typename T>
static constexpr bool IsSupportedAdaptivePoolConstShapeVerifyOp =
std::is_same_v<T, tosa::AvgPool2dAdaptiveOp>
// || std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>
;
std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;

template <typename T,
typename std::enable_if<IsSupportedAdaptivePoolConstShapeVerifyOp<T>,
Expand Down Expand Up @@ -4085,6 +4093,33 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
inferredReturnShapes);
}

LogicalResult MaxPool2dAdaptiveOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
MaxPool2dAdaptiveOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape(adaptor.getInput().getType());

llvm::SmallVector<int64_t> kernelValues;
llvm::SmallVector<int64_t> strideValues;
llvm::SmallVector<int64_t> padValues;
if (tosa::getConstShapeValues(adaptor.getKernel().getDefiningOp(),
kernelValues) &&
tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
strideValues) &&
tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues)) {
return poolingInferReturnTypes(inputShape, kernelValues, strideValues,
padValues, inferredReturnShapes);
}

llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
if (inputShape.hasRank()) {
outputShape[0] = inputShape.getDimSize(0);
outputShape[3] = inputShape.getDimSize(3);
}
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}

LogicalResult MaxPool2dOp::verify() {
if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
/* outType = */ getOutput().getType())))
Expand All @@ -4096,6 +4131,21 @@ LogicalResult MaxPool2dOp::verify() {
return success();
}

LogicalResult MaxPool2dAdaptiveOp::verify() {
if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
/* outType = */ getOutput().getType())))
return failure();

AdaptivePoolingConstShapeValues values;
extractAdaptivePoolingConstShapeOperands(*this, values);

if (failed(verifyPoolingOpImpl(getOperation(), values.kernel, values.stride,
values.pad, getInput(), getOutput())))
return failure();

return success();
}

LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
DepthwiseConv2DOp::Adaptor adaptor,
Expand Down
9 changes: 9 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dAdaptiveOp op) {
return success();
}

template <>
LogicalResult
ProfileInfoDepot::populateProfileInfo(tosa::MaxPool2dAdaptiveOp op) {
addValue(op.getInput());
addValue(op.getOutput());
return success();
}

template <typename T>
LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
addValue(op.getInput());
Expand Down Expand Up @@ -288,6 +296,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_CUSTOM(Variable)
POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
POPULATE_PROFILE_INFO_CUSTOM(Dim)
POPULATE_PROFILE_INFO_CUSTOM(MaxPool2dAdaptive)

// For the most of tosa operators, all operands are profile/extension related
// and hence are all considered in this profile-based compilance check.
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {

template <typename T>
static constexpr bool IsSupportedAdaptivePoolOp =
std::is_same_v<T, tosa::AvgPool2dAdaptiveOp>
// || std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>
;
std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;

template <typename T, typename std::enable_if<IsSupportedAdaptivePoolOp<T>,
int>::type = 0>
Expand Down Expand Up @@ -817,6 +816,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_SIZES(MatMul);
CHECK_SIZES(MatmulTBlockScaled);
CHECK_SIZES(MaxPool2d);
CHECK_SIZES(MaxPool2dAdaptive);
CHECK_SIZES(RFFT2d);
// Scatter/Gather Operators
CHECK_SIZES(Gather);
Expand Down Expand Up @@ -918,6 +918,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
failed(levelCheckAdaptivePool<tosa::MaxPool2dAdaptiveOp>(op)) ||
failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op)) ||
failed(levelCheckConv2DBlockScaled(op))) {
Expand Down
116 changes: 116 additions & 0 deletions mlir/test/Dialect/Tosa/level_check.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,122 @@ func.func @test_maxpool2d_stride_x(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x3

// -----

func.func @test_maxpool2d_adaptive_kernel_y(%arg0: tensor<1x8194x32x8xf32>) -> tensor<1x2x32x8xf32> {
%kernel = tosa.const_shape {values = dense<[8193, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: kernel <= MAX_KERNEL (8192), got 8193}}
%0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
(tensor<1x8194x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x2x32x8xf32>
return %0 : tensor<1x2x32x8xf32>
}

// -----

func.func @test_maxpool2d_adaptive_kernel_x(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
%kernel = tosa.const_shape {values = dense<[1, 8193]> : tensor<2xindex>} : () -> !tosa.shape<2>
%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: kernel <= MAX_KERNEL (8192), got 8193}}
%0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
(tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
return %0 : tensor<1x32x2x8xf32>
}

// -----

func.func @test_maxpool2d_adaptive_stride_y(%arg0: tensor<1x8194x32x8xf32>) -> tensor<1x2x32x8xf32> {
%kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%stride = tosa.const_shape {values = dense<[8193, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: stride <= MAX_STRIDE (8192), got 8193}}
%0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
(tensor<1x8194x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x2x32x8xf32>
return %0 : tensor<1x2x32x8xf32>
}

// -----

func.func @test_maxpool2d_adaptive_stride_x(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
%kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%stride = tosa.const_shape {values = dense<[1, 8193]> : tensor<2xindex>} : () -> !tosa.shape<2>
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: stride <= MAX_STRIDE (8192), got 8193}}
%0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
(tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
return %0 : tensor<1x32x2x8xf32>
}

// -----

func.func @test_maxpool2d_adaptive_pad_first(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
// If the source of the kernel passed to max_pool2d_adaptive is a const_shape then pad < kernel check is applied.
// This is a workaround for the above so that we can level check the padding.
%a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>

%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%pad = tosa.const_shape {values = dense<[8193, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
%0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
(tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
return %0 : tensor<1x32x2x8xf32>
}

// -----

func.func @test_maxpool2d_adaptive_pad_second(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
// If the source of the kernel passed to max_pool2d_adaptive is a const_shape then pad < kernel check is applied.
// This is a workaround for the above so that we can level check the padding.
%a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>

%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%pad = tosa.const_shape {values = dense<[0, 8193, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
%0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
(tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
return %0 : tensor<1x32x2x8xf32>
}

// -----

func.func @test_maxpool2d_adaptive_pad_third(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
// If the source of the kernel passed to max_pool2d_adaptive is a const_shape then pad < kernel check is applied.
// This is a workaround for the above so that we can level check the padding.
%a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>

%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%pad = tosa.const_shape {values = dense<[0, 0, 8193, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
%0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
(tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
return %0 : tensor<1x32x2x8xf32>
}

// -----

func.func @test_maxpool2d_adaptive_pad_forth(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
// If the source of the kernel passed to max_pool2d_adaptive is a const_shape then pad < kernel check is applied.
// This is a workaround for the above so that we can level check the padding.
%a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>

%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%pad = tosa.const_shape {values = dense<[0, 0, 0, 8193]> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
%0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
(tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
return %0 : tensor<1x32x2x8xf32>
}

// -----

func.func @test_rfft2d_input_h(%arg0: tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>) {
// expected-error@+1 {{'tosa.rfft2d' op failed level check: H <= MAX_KERNEL (8192), got 16384}}
%0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>)
Expand Down
47 changes: 47 additions & 0 deletions mlir/test/Dialect/Tosa/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,53 @@ func.func @test_max_pool2d_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8
return %0 : tensor<1x32x32x8xf16>
}

// CHECK-LABEL: max_pool2d_adaptive_f32
func.func @test_max_pool2d_adaptive_f32(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
%kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
%0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad : (tensor<1x32x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xf32>
return %0 : tensor<1x32x32x8xf32>
}

// -----
// CHECK-LABEL: max_pool2d_adaptive_bf16
func.func @test_max_pool2d_adaptive_bf16(%arg0: tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16> {
%kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
%0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad : (tensor<1x32x32x8xbf16>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xbf16>
return %0 : tensor<1x32x32x8xbf16>
}

// -----
// CHECK-LABEL: max_pool2d_adaptive_f16
func.func @test_max_pool2d_adaptive_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16> {
%kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
%0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad : (tensor<1x32x32x8xf16>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xf16>
return %0 : tensor<1x32x32x8xf16>
}

// CHECK-LABEL: dynamic_max_pool2d_adaptive_f16
func.func @test_dynamic_max_pool2d_adaptive_f16(%arg0: tensor<1x?x?x8xf16>) -> tensor<1x?x?x8xf16> {
%kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
%0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad : (tensor<1x?x?x8xf16>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x?x?x8xf16>
return %0 : tensor<1x?x?x8xf16>
}

// CHECK-LABEL: unranked_max_pool2d_adaptive_f16
func.func @test_unranked_max_pool2d_adaptive_f16(%arg0: tensor<*xf16>) -> tensor<*xf16> {
%kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
%0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad : (tensor<*xf16>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<*xf16>
return %0 : tensor<*xf16>
}

// -----
// CHECK-LABEL: rfft2d
func.func @test_rfft2d(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
Expand Down
Loading