Skip to content

Commit

Permalink
Handle one-sided constants in insert-rotate
Browse files Browse the repository at this point in the history
This PR ports the gx_kernel evaluation artifact from HECO, and supports it by doing two things:

- Adding a canonicalize pass before insert-rotate but after full-loop-unroll, to ensure that ND tensors of constants are materialized to constants of tensors.
- Adding two patterns to splat constants into tensors during rotations insertion.

This leaves unanswered an important question: how should we detect and handle plaintext types that are not constants (say, function inputs or function inputs that are modified by some IR-internal ops). I will file a followup issue on that topic.

Note canonicalize slows the box_blur_64x64 test to a crawl (> 15m), so I converted it to "enormous" size so that it is skipped in CI.

Part of #571

PiperOrigin-RevId: 621321174
  • Loading branch information
j2kun authored and Copybara-Service committed Apr 2, 2024
1 parent 54043a3 commit 4d84f04
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 6 deletions.
48 changes: 43 additions & 5 deletions include/Dialect/TensorExt/Transforms/InsertRotate.td
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ include "include/DRR/Utils.td"
include "include/Dialect/TensorExt/IR/TensorExtOps.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/Constraints.td"
include "mlir/IR/PatternBase.td"

// Get the target_slot attribute from an op, if it exists, or else
Expand All @@ -14,17 +16,18 @@ def GetTargetSlotAttr : NativeCodeCall<
" ? llvm::cast<mlir::IntegerAttr>($0.getOwner()->getAttr(\"target_slot\"))"
" : $_builder.getIndexAttr(0)">;

def CreateSplatOp : NativeCodeCall<
"$_builder.create<tensor::SplatOp>($2.getLoc(), $0, $1.getType())">;

// The patterns in this file are intended to align with the automatic-SIMD
// batching heuristics from the HECO project. See section 4.4 of
// https://arxiv.org/abs/2202.01649 and the hir2hir passes in
// https://github.com/MarbleHE/HECO/blob/main/src/Passes/hir2hir/

// Match an arith op that extracts scalar values from two tensors, and replace
// it with rotations to align slots and apply the same op in SIMD. Other
// patterns in this file will find better alignment of adjacent rotations, and
// canonicalization patterns will remove duplicated rotations.
// Match an arith op that extracts scalar values from one or more tensors, and
// replace it with rotations to align slots and apply the same op in SIMD.
foreach ArithOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in {
def InsertRotations_#ArithOp : Pattern<
def InsertRotations_TwoTensorArgs_#ArithOp : Pattern<
(ArithOp:$arithOp
(Tensor_ExtractOp $t1, (variadic $i1)),
(Tensor_ExtractOp $t2, (variadic $i2)),
Expand All @@ -41,6 +44,41 @@ foreach ArithOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in {
(Arith_ConstantOp (GetTargetSlotAttr $arithOp)))),
]
>;

// In this and the next pattern, the non-tensor arg must be elevated to a tensor
// by repeating it across the right dimension.
// TODO(#586): support more than just constant operands
def InsertRotations_SplatRHSToTensor_#ArithOp : Pattern<
(ArithOp:$arithOp
(Tensor_ExtractOp $t1, (variadic $i1)),
(Arith_ConstantOp:$nonExtractedArg $value),
$overflow),
[
(TensorExt_RotateOp:$r1 $t1,
(Arith_SubIOp $i1, (Arith_ConstantOp (GetTargetSlotAttr $arithOp)), DefOverflow)),
(ArithOp:$opResult $r1, (CreateSplatOp $nonExtractedArg, $t1, $arithOp), $overflow),
(Tensor_ExtractOp
$opResult,
(MakeSingleResultVariadic
(Arith_ConstantOp (GetTargetSlotAttr $arithOp)))),
]
>;

def InsertRotations_SplatLHSToTensor_#ArithOp : Pattern<
(ArithOp:$arithOp
(Arith_ConstantOp:$nonExtractedArg $value),
(Tensor_ExtractOp $t1, (variadic $i1)),
$overflow),
[
(TensorExt_RotateOp:$r1 $t1,
(Arith_SubIOp $i1, (Arith_ConstantOp (GetTargetSlotAttr $arithOp)), DefOverflow)),
(ArithOp:$opResult (CreateSplatOp $nonExtractedArg, $t1, $arithOp), $r1, $overflow),
(Tensor_ExtractOp
$opResult,
(MakeSingleResultVariadic
(Arith_ConstantOp (GetTargetSlotAttr $arithOp)))),
]
>;
}


Expand Down
8 changes: 7 additions & 1 deletion tests/heir_simd_vectorizer/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@ glob_lit_tests(
data = ["@heir//tests:test_utilities"],
driver = "@heir//tests:run_lit.sh",
size_override = {
"box_blur_64x64.mlir": "large",
"box_blur_64x64.mlir": "enormous",
},
tags_override = {
"box_blur_64x64.mlir": [
"nofastbuild",
"manual",
],
},
test_file_exts = ["mlir"],
)
48 changes: 48 additions & 0 deletions tests/heir_simd_vectorizer/gx_kernel_8x8.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// RUN: heir-opt --secretize=entry-function=gx_kernel --wrap-generic --canonicalize --cse \
// RUN: --heir-simd-vectorizer %s | FileCheck %s

// CHECK-LABEL: @gx_kernel
// CHECK: secret.generic
// CHECK-COUNT-6: tensor_ext.rotate
// CHECK-NOT: tensor_ext.rotate
func.func @gx_kernel(%arg0: tensor<64xi16>) -> tensor<64xi16> {
%c64 = arith.constant 64 : index
%c8 = arith.constant 8 : index
%c1_index = arith.constant 1 : index
%c0_si16 = arith.constant 0 : i16
%c0 = arith.constant 0 : i16
%c1 = arith.constant 1 : i16
%c2 = arith.constant 2 : i16
%cm1= arith.constant -1 : i16
%cm2 = arith.constant -2 : i16
%weight_matrix = tensor.from_elements %c1, %cm1, %c2, %cm2, %c1, %cm1, %c0, %c0, %c0 : tensor<3x3xi16>
%0 = affine.for %x = 0 to 8 iter_args(%arg0_x = %arg0) -> (tensor<64xi16>) {
%1 = affine.for %y = 0 to 8 iter_args(%arg0_y = %arg0_x) -> (tensor<64xi16>) {
%2 = affine.for %j = -1 to 2 iter_args(%value_j = %c0_si16) -> (i16) {
%6 = affine.for %i = -1 to 2 iter_args(%value_i = %value_j) -> (i16) {
%7 = arith.addi %x, %i : index
%8 = arith.muli %7, %c8 : index
%9 = arith.addi %y, %j : index
%10 = arith.addi %8, %9 : index
%11 = arith.remui %10, %c64 : index
%12 = tensor.extract %arg0[%11] : tensor<64xi16>
// Get the weight from the weight matrix!
%ip = arith.addi %i,%c1_index : index
%jp = arith.addi %j,%c1_index : index
%w = tensor.extract %weight_matrix[%ip,%jp] : tensor<3x3xi16>
%mul = arith.muli %12, %w : i16
%13 = arith.addi %value_i, %mul : i16
affine.yield %13 : i16
}
affine.yield %6 : i16
}
%3 = arith.muli %c8, %x : index
%4 = arith.addi %3, %y : index
%5 = arith.remui %4, %c64 : index
%6 = tensor.insert %2 into %arg0_y[%5] : tensor<64xi16>
affine.yield %6 : tensor<64xi16>
}
affine.yield %1 : tensor<64xi16>
}
return %0 : tensor<64xi16>
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,29 @@ func.func @test_insert_rotation_for_add(%arg1: tensor<16xi32>) -> i32 {
// CHECK-NEXT: %[[v3:.*]] = arith.addi %[[v2]], %[[arg0]] : tensor<16xi32>
// CHECK-NEXT: %[[extracted:.*]] = tensor.extract %[[v3]][%[[c4]]] : tensor<16xi32>
// CHECK-NEXT: return

// This test should be applying the new patterns InsertRotations_SplatRHSToTensor,
// but the splats themselves should be elided.
// CHECK-LABEL: @test_splat
// CHECK-NOT: tensor_ext.splat
func.func @test_splat(%arg1: tensor<16xi32>) -> tensor<16xi32> {
%c0 = arith.constant 1 : index
%c1 = arith.constant 1 : index
%c11 = arith.constant 11 : index
%c12 = arith.constant 12 : index
%c15 = arith.constant 15 : index
%cst = arith.constant dense<[[2, 3], [4, 5]]> : tensor<2x2xi32>

%extracted = tensor.extract %arg1[%c11] : tensor<16xi32>
%extracted_0 = tensor.extract %cst[%c1, %c1] : tensor<2x2xi32>
%1 = arith.muli %extracted, %extracted_0 : i32

%extracted_1 = tensor.extract %arg1[%c12] : tensor<16xi32>
%extracted_2 = tensor.extract %cst[%c0, %c1] : tensor<2x2xi32>
%2 = arith.muli %extracted_1, %extracted_2 : i32

%3 = arith.addi %1, %2 : i32

%inserted = tensor.insert %3 into %arg1[%c1] : tensor<16xi32>
return %inserted : tensor<16xi32>
}
13 changes: 13 additions & 0 deletions tools/heir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,21 @@ void polynomialToLLVMPipelineBuilder(OpPassManager &manager) {
void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager) {
// For now we unroll loops to enable insert-rotate, but we would like to be
// smarter about this and do an affine loop analysis.
// TODO(#589): avoid unrolling loops
manager.addPass(createFullLoopUnroll());

// This canonicalize is required in this position for a relatively nuanced
// reason. insert-rotate doesn't have general match support. In particular,
// if a tensor extract from a secret is combined with a tensor extract from
// a constant 2D tensor (e.g., the weight matrix of a convolution), then
// insert-rotate won't be able to tell the difference and understand that
// the extracted value from the 2D tensor should be splatted. This
// canonicalize pass converts a constant weight matrix into the underlying
// arith.constant values, which are supported as a splattable non-tensor
// input in insert-rotate.
// TODO(#586): find a more robust solution
manager.addPass(createCanonicalizerPass());

// Insert rotations aligned to slot targets. Future work should provide
// alternative methods to optimally align rotations, and allow the user to
// configure this via pipeline options.
Expand Down

0 comments on commit 4d84f04

Please sign in to comment.