Skip to content

Add fold-plaintext-masks for simple cleanup after implement-shift-network#2343

Merged
copybara-service[bot] merged 1 commit intogoogle:mainfrom
j2kun:shift-network-cleanup
Oct 24, 2025
Merged

Add fold-plaintext-masks for simple cleanup after implement-shift-network#2343
copybara-service[bot] merged 1 commit intogoogle:mainfrom
j2kun:shift-network-cleanup

Conversation

@j2kun
Copy link
Collaborator

@j2kun j2kun commented Oct 23, 2025

Fixes #2337

The sole idea here is to take repeated applications of plaintext one-hot masks and combine them via bitwise-and'ing the entries of the masks.

The example test still produces some inefficiencies:

$ python ~/fhe/heir/scripts/lit_to_bazel.py --run=True tests/Regressionissue_2337.mlir
module {
  func.func @trivial_insert(%arg0: !secret.secret<tensor<2x32xi32>>) -> !secret.secret<tensor<2x32xi32>> {
    %cst = arith.constant dense<[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<32xi32>
    %cst_0 = arith.constant dense<0> : tensor<32xi32>
    %cst_1 = arith.constant dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]> : tensor<32xi32>
    %c16 = arith.constant 16 : index
    %0 = tensor.empty() : tensor<2x32xi32>
    %1 = secret.generic(%arg0: !secret.secret<tensor<2x32xi32>>) {
    ^body(%input0: tensor<2x32xi32>):
      %extracted_slice = tensor.extract_slice %input0[1, 0] [1, 32] [1, 1] : tensor<2x32xi32> to tensor<32xi32>
      %2 = arith.muli %extracted_slice, %cst : tensor<32xi32>
      %3 = tensor_ext.rotate %2, %c16 : tensor<32xi32>, index
      %4 = arith.muli %3, %cst_1 : tensor<32xi32>
      %5 = arith.addi %4, %2 : tensor<32xi32>
      %6 = arith.muli %extracted_slice, %cst_1 : tensor<32xi32>
      %7 = tensor_ext.rotate %6, %c16 : tensor<32xi32>, index
      %8 = arith.addi %6, %7 : tensor<32xi32>
      %9 = arith.addi %5, %8 : tensor<32xi32>
      %inserted_slice = tensor.insert_slice %9 into %0[0, 0] [1, 32] [1, 1] : tensor<32xi32> into tensor<2x32xi32>
      %inserted_slice_2 = tensor.insert_slice %cst_0 into %inserted_slice[1, 0] [1, 32] [1, 1] : tensor<32xi32> into tensor<2x32xi32>
      secret.yield %inserted_slice_2 : tensor<2x32xi32>
    } -> !secret.secret<tensor<2x32xi32>>
    return %1 : !secret.secret<tensor<2x32xi32>>
  }
}

These three lines show the problem

%2 = arith.muli %extracted_slice, %cst : tensor<32xi32>
%3 = tensor_ext.rotate %2, %c16 : tensor<32xi32>, index
%4 = arith.muli %3, %cst_1 : tensor<32xi32>

It extracts the first half, rotates it to the second half, then masks the second half (a no-op) before adding it to a copy of the first half. That second mask can be eliminated because the other half of the entries can be inferred to be zero. Something for a future slot analysis to handle...

@j2kun j2kun requested a review from asraa October 23, 2025 06:39
Copy link
Collaborator

@asraa asraa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you!

@j2kun j2kun force-pushed the shift-network-cleanup branch from feff1c0 to a5a3cf3 Compare October 23, 2025 16:33
@j2kun j2kun added the pull_ready Indicates whether a PR is ready to pull. The copybara worker will import for internal testing label Oct 23, 2025
@j2kun j2kun force-pushed the shift-network-cleanup branch from a5a3cf3 to 776db9f Compare October 23, 2025 17:34
@copybara-service copybara-service bot merged commit 38a67f9 into google:main Oct 24, 2025
5 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pull_ready Indicates whether a PR is ready to pull. The copybara worker will import for internal testing

Projects

None yet

Development

Successfully merging this pull request may close these issues.

implement-shift-network extraneous code output

2 participants