diff --git a/tests/simd/simple_sum.mlir b/tests/simd/simple_sum.mlir index c61c146a6..a7c35c927 100644 --- a/tests/simd/simple_sum.mlir +++ b/tests/simd/simple_sum.mlir @@ -1,8 +1,12 @@ // RUN: heir-opt --secretize=entry-function=simple_sum --wrap-generic --canonicalize --cse \ -// RUN: --full-loop-unroll --cse --canonicalize --insert-rotate \ -// RUN: %s | FileCheck %s +// RUN: --full-loop-unroll --insert-rotate --cse --canonicalize \ +// RUN: --rotate-and-reduce --canonicalize \ +// RUN: %s | FileCheck %s // CHECK-LABEL: @simple_sum +// CHECK: secret.generic +// CHECK-COUNT-5: tensor_ext.rotate +// CHECK-NOT: tensor_ext.rotate func.func @simple_sum(%arg0: tensor<32xi16> {secret.secret}) -> i16 { %c0 = arith.constant 0 : index %c0_si16 = arith.constant 0 : i16 diff --git a/tests/tensor_ext/rotate_and_reduce.mlir b/tests/tensor_ext/rotate_and_reduce.mlir index 126f70577..0694e6f43 100644 --- a/tests/tensor_ext/rotate_and_reduce.mlir +++ b/tests/tensor_ext/rotate_and_reduce.mlir @@ -305,3 +305,67 @@ func.func @not_supported_non_tensor_operands(%arg0: tensor<8xi32>) -> i32 { %15 = arith.addi %14, %2 : i32 return %15 : i32 } + +// CHECK-LABEL: @sum_of_linear_rotates +// CHECK-COUNT-5: tensor_ext.rotate +// CHECK-NOT: tensor_ext.rotate +func.func @sum_of_linear_rotates(%arg0: !secret.secret>) -> !secret.secret { + %c30 = arith.constant 30 : index + %c29 = arith.constant 29 : index + %c31 = arith.constant 31 : index + %c1 = arith.constant 1 : index + %0 = secret.generic ins(%arg0 : !secret.secret>) { + ^bb0(%arg1: tensor<32xi16>): + %1 = tensor_ext.rotate %arg1, %c1 : tensor<32xi16>, index + %2 = arith.addi %1, %arg1 : tensor<32xi16> + %3 = tensor_ext.rotate %arg1, %c31 : tensor<32xi16>, index + %4 = tensor_ext.rotate %2, %c29 : tensor<32xi16>, index + %5 = arith.addi %3, %4 : tensor<32xi16> + %6 = arith.addi %5, %arg1 : tensor<32xi16> + %7 = tensor_ext.rotate %6, %c30 : tensor<32xi16>, index + %8 = arith.addi %3, %7 : tensor<32xi16> + %9 = arith.addi %8, %arg1 : tensor<32xi16> + %10 = tensor_ext.rotate %9, %c30 : tensor<32xi16>, index + %11 = arith.addi %3, %10 : tensor<32xi16> + %12 = arith.addi %11, %arg1 : tensor<32xi16> + %13 = tensor_ext.rotate %12, %c30 : tensor<32xi16>, index + %14 = arith.addi %3, %13 : tensor<32xi16> + %15 = arith.addi %14, %arg1 : tensor<32xi16> + %16 = tensor_ext.rotate %15, %c30 : tensor<32xi16>, index + %17 = arith.addi %3, %16 : tensor<32xi16> + %18 = arith.addi %17, %arg1 : tensor<32xi16> + %19 = tensor_ext.rotate %18, %c30 : tensor<32xi16>, index + %20 = arith.addi %3, %19 : tensor<32xi16> + %21 = arith.addi %20, %arg1 : tensor<32xi16> + %22 = tensor_ext.rotate %21, %c30 : tensor<32xi16>, index + %23 = arith.addi %3, %22 : tensor<32xi16> + %24 = arith.addi %23, %arg1 : tensor<32xi16> + %25 = tensor_ext.rotate %24, %c30 : tensor<32xi16>, index + %26 = arith.addi %3, %25 : tensor<32xi16> + %27 = arith.addi %26, %arg1 : tensor<32xi16> + %28 = tensor_ext.rotate %27, %c30 : tensor<32xi16>, index + %29 = arith.addi %3, %28 : tensor<32xi16> + %30 = arith.addi %29, %arg1 : tensor<32xi16> + %31 = tensor_ext.rotate %30, %c30 : tensor<32xi16>, index + %32 = arith.addi %3, %31 : tensor<32xi16> + %33 = arith.addi %32, %arg1 : tensor<32xi16> + %34 = tensor_ext.rotate %33, %c30 : tensor<32xi16>, index + %35 = arith.addi %3, %34 : tensor<32xi16> + %36 = arith.addi %35, %arg1 : tensor<32xi16> + %37 = tensor_ext.rotate %36, %c30 : tensor<32xi16>, index + %38 = arith.addi %3, %37 : tensor<32xi16> + %39 = arith.addi %38, %arg1 : tensor<32xi16> + %40 = tensor_ext.rotate %39, %c30 : tensor<32xi16>, index + %41 = arith.addi %3, %40 : tensor<32xi16> + %42 = arith.addi %41, %arg1 : tensor<32xi16> + %43 = tensor_ext.rotate %42, %c30 : tensor<32xi16>, index + %44 = arith.addi %3, %43 : tensor<32xi16> + %45 = arith.addi %44, %arg1 : tensor<32xi16> + %46 = tensor_ext.rotate %45, %c30 : tensor<32xi16>, index + %47 = arith.addi %3, %46 : tensor<32xi16> + %48 = arith.addi %47, %arg1 : tensor<32xi16> + %extracted = tensor.extract %48[%c31] : tensor<32xi16> + secret.yield %extracted : i16 + } -> !secret.secret + return %0 : !secret.secret +}