
Rewrite rotation analysis to support dot product port from HECO #575

Merged 1 commit into google:main on Apr 1, 2024

Conversation

@j2kun (Collaborator) commented Mar 28, 2024

Ports:

- dot_product
- linear_polynomial
- quadratic_polynomial

Part of #571

The main thing this PR does is rewrite the rotation analysis used in rotate-and-reduce. The original analysis did not support any reduction that started from two separate tensors (e.g., the dot_product example ported in this PR), because the lattice framework considered that overdetermined. After thinking about this, I realized that what I really wanted was to "reset" the lattice after finding an overdetermined state, re-initializing it at the result of the operation that made it overdetermined. E.g.,

```mlir
%0 = arith.muli %arg1, %arg2    // <-- overdetermined b/c two different source tensors
%1 = tensor_ext.rotate %0, %c1  // <-- valid start of a reduction with base tensor %0
%2 = arith.addi %0, %1
...
```

The lattice framework doesn't support this because "resetting" is not a monotonic operation, which is required of a semijoin lattice. I suspect that subclassing a higher level of the dataflow framework might have worked, but I felt it was simpler to write a "generic" analysis with no such constraints.

The new analysis does a single walk over the IR and cumulatively builds up a PartialReduction struct corresponding to the set of indices visited so far, along with extra information: the operation being processed, the source tensor being reduced, and the "root" tensor, i.e., the SSA value that may ultimately be replaced by a logarithmic number of rotations in the pass and that serves as the extension point for the accumulation. (A rough sketch of this shape follows below.)
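
To make that concrete, here is a minimal sketch of the struct and the reset-on-overdetermined walk, assuming MLIR's C++ API; every name below (PartialReduction, visit, the overdetermined flag) is an illustrative stand-in, not an identifier actually used in this PR:

```cpp
#include <cstdint>
#include <set>

#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"

// Hypothetical sketch, not the PR's actual code.
struct PartialReduction {
  mlir::Operation *op;  // the operation currently being processed
  mlir::Value tensor;   // the source tensor being reduced
  mlir::Value root;     // the SSA value that may be replaced by a logarithmic
                        // number of rotations; the extension point for the
                        // accumulation
  std::set<int64_t> accessedIndices;  // rotation indices visited so far
};

// The key difference from the lattice-based analysis: when an op mixes two
// distinct source tensors ("overdetermined"), the walk starts a fresh
// PartialReduction rooted at that op's result instead of giving up.
PartialReduction visit(mlir::Operation *op, PartialReduction incoming,
                       bool overdetermined) {
  if (overdetermined) {
    PartialReduction fresh;
    fresh.op = op;
    fresh.tensor = op->getResult(0);
    fresh.root = op->getResult(0);
    fresh.accessedIndices = {0};  // an untouched tensor reads slot 0
    return fresh;
  }
  // ...otherwise, extend `incoming` with whatever index this op accesses.
  return incoming;
}
```

In the dot_product example above, this reset is what lets %0 = arith.muli %arg1, %arg2 become the base of a new reduction rather than a dead end.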

The logic should be much clearer now, and the pass itself no longer requires any extra hacks to ensure consistency. If this looks good to everyone else, I think it would make sense to mirror this flow for the tryReplaceExtractions part of rotate-and-reduce.

@asraa (Collaborator) left a comment

Thanks!

```cpp
// packed ciphertext is the constant term, i.e., the first element of the
// tensor. So a tensor by itself is always considered a reduction by that
// first element.
reduction.addRotation(0);
```
@asraa (Collaborator):

I was a little confused about this one and went back and forth trying to understand it. If there are no operations on it at this point, why would it be assumed to be accessed at 0? For example, what if the only ops are a tensor.rotate %1 and then a tensor.extract?

@j2kun (Collaborator, Author) commented Mar 30, 2024

The way I understand it (which may not be entirely correct) is that scalars in BGV/BFV/CKKS are interpreted as lying in the constant term of a ciphertext, zero-padded to fill up the rest of the tensor. Similarly, if we weren't aligning rotations, the "default" way to get a scalar addition of two values in a tensor would be to rotate them both to the zeroth slot and SIMD-add there. I think this is more of a convention than a deduction, but with that mental model, if the scalar you want is already in slot 0, then doing nothing is the same as accessing that index.

> what if the only op is a tensor.rotate %1 and then a tensor.extract?

I think this would be incorrect if you think of this analysis as "what indices are accessed by ANY op ending at a particular SSA value," but we're restricting it to (semantic) fold ops by specific binary operations.

More pragmatically, rotations by zero are deleted by the canonicalizer, and because they're missing from the IR, the analysis can't recover them and improperly marks a full reduction as incomplete. If I removed this and instead had the check assume zero is always included (I'm not sure what the alternative would be), that would be incorrect when the original untouched tensor is not part of the reduction (e.g., a sum of all but the 0th entry).
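
For concreteness, here is a minimal sketch of the completeness check being discussed, with hypothetical names (the pass's real check differs in its details):

```cpp
#include <cstdint>
#include <set>

// A reduction over a tensor with `tensorSize` entries is "full" only if every
// index was accessed. Seeding the set with 0 for the untouched base tensor is
// what lets canonicalized IR, where rotate-by-zero ops have been deleted,
// still satisfy this check.
bool isCompleteReduction(const std::set<int64_t> &accessedIndices,
                         int64_t tensorSize) {
  for (int64_t i = 0; i < tensorSize; ++i) {
    if (accessedIndices.count(i) == 0) return false;
  }
  return true;
}
```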

@j2kun added the pull_ready label (indicates whether a PR is ready to pull; the copybara worker will import for internal testing) on Mar 30, 2024
Commit message:

This rewrite still has some efficiency issues, but it removes the hacky
workarounds from the previous lattice-based dataflow analysis, and supports
the previously unsupported examples ported in this PR.

Ports:

 - dot_product
 - linear_polynomial
 - quadratic_polynomial
@copybara-service (bot) merged commit 81e45d8 into google:main on Apr 1, 2024
8 of 9 checks passed