Rewrite rotation analysis to support dot product port from HECO #575
Conversation
// packed ciphertext is the constant term, i.e., the first element of the
// tensor. So a tensor by itself is always considered a reduction by that
// first element.
reduction.addRotation(0);
I was a little confused about this one and went back and forth trying to understand it. If there are no operations on it at this point, why would it be assumed to be accessed at 0? E.g., what if the only ops are a tensor.rotate %1 and then a tensor.extract?
The way I understand it (which may not be entirely correct) is that scalars in BGV/BFV/CKKS are interpreted as lying in the constant term of a ciphertext, zero-padded to fill up the rest of the tensor. Similarly, if we weren't aligning rotations, then the "default" way to get a scalar addition of two values in a tensor is to rotate them both to the zero-th slot and SIMD-add there. I think this is more of a convention than a deduction, but with that mental model, if the scalar you want is already in slot 0 then doing nothing is the same as accessing that index.
what if the only op is a tensor.rotate %1 and then a tensor.extract?
I think this would be incorrect if you think of this analysis as "what indices are accessed by ANY op ending at a particular SSA value," but we're restricting it to (semantic) fold ops by specific binary operations.
More pragmatically, rotations by zero are deleted by the canonicalizer; because they're missing from the IR, the analysis can't recover them and would improperly mark a full reduction as incomplete. If I removed this and had the check assume zero is always included (I'm not sure what the alternative is), that would be incorrect when the original untouched tensor is not part of the reduction (e.g., summing all but the 0th entry).
This rewrite still has some efficiency issues, but it removes the hacks from the previous lattice-based dataflow analysis and supports the previously unsupported examples from this PR.

Ports:
- dot_product
- linear_polynomial
- quadratic_polynomial

Part of #571
The main thing this PR does is rewrite the rotation analysis used in rotate-and-reduce. It originally did not support any reduction that started from two separate tensors (e.g., the dot_product example ported in this PR) because the lattice framework considered that overdetermined. After thinking about this, I realized that what I really wanted to do was "reset" the lattice after I found an overdetermined state, re-initializing it at the result of the operation that made it overdetermined. The lattice framework doesn't support this because "resetting" is not a monotonic operation, which is required of a semijoin lattice. I suspect that subclassing a higher level of the dataflow framework may have worked, but I felt it was simpler to write a "generic" analysis with no constraints.
The new analysis does a single walk over the IR, cumulatively building up a PartialReduction struct corresponding to the set of indices visited, along with extra information about the operation being processed, the source tensor being reduced, and the "root" tensor: the SSA value that may ultimately be replaced by a log-number of rotations in the pass, and which serves as the extension point for the accumulation.

The logic should be much clearer now, and the pass itself no longer requires any extra hacks to ensure consistency. If this looks good to everyone else, I think it would make sense to mirror this flow for the tryReplaceExtractions part of rotate-and-reduce.