Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

poly: add lowering of poly.add #134

Merged
merged 7 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
74 changes: 68 additions & 6 deletions lib/Conversion/PolyToStandard/PolyToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
#include "include/Dialect/Poly/IR/PolyOps.h"
#include "include/Dialect/Poly/IR/PolyTypes.h"
#include "lib/Conversion/Utils.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
Expand All @@ -28,6 +29,9 @@ class PolyToStandardTypeConverter : public TypeConverter {
IntegerType elementTy =
IntegerType::get(ctx, attr.coefficientModulus().getBitWidth(),
IntegerType::SignednessSemantics::Signless);
// We must remove the ring attribute on the tensor, since the
// unrealized_conversion_casts cannot carry the poly.ring attribute
// through.
return RankedTensorType::get({idealDegree}, elementTy);
});

Expand Down Expand Up @@ -109,11 +113,66 @@ struct ConvertAdd : public OpConversionPattern<AddOp> {

using OpConversionPattern::OpConversionPattern;

// Lowers a poly.add operation to arith operations. A poly.add operation is
// defined within the polynomial ring: coefficients are added element-wise as
// elements of the ring, i.e. modulo the coefficient modulus.
//
// To perform modular addition, let `cmod` be the coefficient modulus of the
// ring and `N` the bitwidth used to store the ring elements. `N` may be much
// larger than `log_2(cmod)`.
//
// Let `x` and `y` be the inputs to modular addition, then:
//     c1, n1 = addui_extended(x, y)
// where `c1` is the sum truncated to `N` bits and `n1` is the overflow bit.
// If the coefficient modulus divides `2^N`, the result is
//     c0 = c1 % cmod
// Otherwise, the overflow has to be folded back in:
//     c0 = ((c1 % cmod) + (n1 * 2^N % cmod)) % cmod
//
// Note that `(c1 % cmod) + (2^N % cmod)` cannot overflow mod `2^N`.
// Both terms are smaller than `cmod`, so an overflow would require
// `cmod > (2^N) / 2`, which implies `2^N % cmod = 2^N - cmod`.
// If the sum overflowed, we would then have
//     (c1 % cmod) + (2^N % cmod)  >= 2^N
//     (c1 % cmod) + (2^N - cmod)  >= 2^N
//     (c1 % cmod)                 >= cmod
// which is a contradiction.
//
// Always replaces `op` and returns success().
LogicalResult matchAndRewrite(
    AddOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const override {
  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  // The constants below are splats built with the type of the adaptor's lhs
  // operand, so they can be combined element-wise with the operands.
  auto type = adaptor.getLhs().getType();

  APInt mod =
      cast<PolyType>(op.getResult().getType()).getRing().coefficientModulus();
  auto cmod = b.create<arith::ConstantOp>(
      DenseElementsAttr::get(cast<ShapedType>(type), {mod}));

  // Result 0 is the wrapped sum (c1), result 1 is the overflow flag (n1).
  auto addExtendedOp =
      b.create<arith::AddUIExtendedOp>(adaptor.getLhs(), adaptor.getRhs());
  auto c1ModOp = b.create<arith::RemUIOp>(addExtendedOp->getResult(0), cmod);

  // If mod divides 2^N, the wrap-around is itself a multiple of mod, so
  // c1ModOp is already the result.
  if (mod.isPowerOf2()) {
    rewriter.replaceOp(op, c1ModOp.getResult());
    return success();
  }

  // Otherwise compute 2^N % cmod at compile time. 2^N does not fit in N bits,
  // so the division is carried out in N + 1 bits and truncated afterwards
  // (the remainder is smaller than mod and therefore fits in N bits).
  APInt quotient, remainder;
  APInt bigMod = APInt(mod.getBitWidth() + 1, 2) << (mod.getBitWidth() - 1);
  APInt::udivrem(bigMod, mod.zext(bigMod.getBitWidth()), quotient, remainder);
  remainder = remainder.trunc(mod.getBitWidth());

  auto overflowAdjust = b.create<arith::ConstantOp>(
      DenseElementsAttr::get(cast<ShapedType>(type), {remainder}));
  auto adjustOp = b.create<arith::AddIOp>(c1ModOp, overflowAdjust);

  // arith.select yields its first value operand when the condition is true:
  // the adjusted sum must be chosen exactly where the addition overflowed,
  // and the plain remainder everywhere else.
  auto selectOp = b.create<arith::SelectOp>(addExtendedOp.getResult(1),
                                            adjustOp, c1ModOp);
  // Mod the final result, since the adjusted sum may be >= cmod.
  rewriter.replaceOp(op, b.create<arith::RemUIOp>(selectOp, cmod));

  return success();
}
};

Expand All @@ -140,16 +199,19 @@ struct PolyToStandard : impl::PolyToStandardBase<PolyToStandard> {
ConversionTarget target(*context);
PolyToStandardTypeConverter typeConverter(context);

target.addLegalDialect<arith::ArithDialect>();

// target.addIllegalDialect<PolyDialect>();
target.addIllegalOp<FromTensorOp, ToTensorOp>();
target.addIllegalOp<FromTensorOp, ToTensorOp, AddOp>();
// target.addIllegalOp<AddOp>();
// target.addIllegalOp<MulOp>();

RewritePatternSet patterns(context);
patterns.add<ConvertFromTensor, ConvertToTensor>(typeConverter, context);

patterns.add<ConvertFromTensor, ConvertToTensor, ConvertAdd>(typeConverter,
context);
addStructuralConversionPatterns(typeConverter, patterns, target);

// TODO(https://github.com/google/heir/issues/143): Handle tensor of polys.
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Poly/IR/PolyOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ LogicalResult FromTensorOp::verify() {
}

APInt coefficientModulus = ring.coefficientModulus();
unsigned cmodBitWidth = coefficientModulus.logBase2();
unsigned cmodBitWidth = coefficientModulus.ceilLogBase2();
j2kun marked this conversation as resolved.
Show resolved Hide resolved
unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();

if (inputBitWidth > cmodBitWidth) {
Expand Down
54 changes: 54 additions & 0 deletions tests/poly/lower_poly.mlir
asraa marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#cycl_2048 = #poly.polynomial<1 + x**1024>
#ring = #poly.ring<cmod=4294967296, ideal=#cycl_2048>
#ring_prime = #poly.ring<cmod=4294967291, ideal=#cycl_2048>
module {
// CHECK-label: test_lower_from_tensor
func.func @test_lower_from_tensor() {
Expand Down Expand Up @@ -59,4 +60,57 @@ module {
func.call @f0(%arg) : (!poly.poly<#ring>) -> !poly.poly<#ring>
return
}

// Lowering of poly.add in a ring whose coefficient modulus (4294967296 = 2^32)
// is a power of two: the expected output is a single arith.remui applied to
// the wrapped sum of arith.addui_extended, with no overflow adjustment.
func.func @test_lower_add_power_of_two_cmod() -> !poly.poly<#ring> {
// 2 + 2x + 2x^2 + ... + 2x^{1023}
// CHECK: [[X:%.+]] = arith.constant dense<2> : [[T:tensor<1024xi32>]]
%coeffs1 = arith.constant dense<2> : tensor<1024xi32>
// CHECK: [[Y:%.+]] = arith.constant dense<3> : [[T]]
%coeffs2 = arith.constant dense<3> : tensor<1024xi32>
// The i32 coefficient tensors are expected to be zero-extended to the i64
// tensor type used for the lowered polynomial.
// CHECK-NOT: poly.from_tensor
// CHECK: [[XEXT:%.+]] = arith.extui [[X]] : [[T]] to [[TPOLY:tensor<1024xi64>]]
// CHECK: [[YEXT:%.+]] = arith.extui [[Y]] : [[T]] to [[TPOLY:tensor<1024xi64>]]
%poly0 = poly.from_tensor %coeffs1 : tensor<1024xi32> -> !poly.poly<#ring>
%poly1 = poly.from_tensor %coeffs2 : tensor<1024xi32> -> !poly.poly<#ring>
// CHECK: [[MOD:%.+]] = arith.constant dense<4294967296> : [[TPOLY]]
// CHECK-NEXT: [[ADD:%.+]], [[OVERFLOW:%.+]] = arith.addui_extended [[XEXT]], [[YEXT]] : [[TPOLY]], tensor<1024xi1>
// CHECK-NEXT: [[REM:%.+]] = arith.remui [[ADD]], [[MOD]] : [[TPOLY]]
%poly2 = poly.add(%poly0, %poly1) {ring = #ring} : !poly.poly<#ring>
// CHECK: return [[REM]] : [[TPOLY]]
return %poly2 : !poly.poly<#ring>
}

// Lowering of poly.add in a ring whose coefficient modulus (4294967291) is not
// a power of two. Besides the remainder of the wrapped sum, the lowering must
// fold the carry back in: wherever the addition overflowed, 2^64 mod cmod is
// added before the final reduction. Since 2^32 = cmod + 5, 2^64 mod cmod = 25.
func.func @test_lower_add_prime_cmod() -> !poly.poly<#ring_prime> {
  // CHECK: [[X:%.+]] = arith.constant dense<2> : [[TCOEFF:tensor<1024xi31>]]
  %coeffs1 = arith.constant dense<2> : tensor<1024xi31>
  // CHECK: [[Y:%.+]] = arith.constant dense<3> : [[TCOEFF]]
  %coeffs2 = arith.constant dense<3> : tensor<1024xi31>
  // CHECK-NOT: poly.from_tensor
  // CHECK: [[XEXT:%.+]] = arith.extui [[X]] : [[TCOEFF]] to [[T:tensor<1024xi64>]]
  // CHECK: [[YEXT:%.+]] = arith.extui [[Y]] : [[TCOEFF]] to [[T:tensor<1024xi64>]]
  %poly0 = poly.from_tensor %coeffs1 : tensor<1024xi31> -> !poly.poly<#ring_prime>
  %poly1 = poly.from_tensor %coeffs2 : tensor<1024xi31> -> !poly.poly<#ring_prime>
  // The select must pick the adjusted value (REMPLUS2N) where the overflow
  // flag is set and the plain remainder (REM) elsewhere; arith.select yields
  // its first value operand when the condition is true.
  // CHECK: [[MOD:%.+]] = arith.constant dense<4294967291> : [[T]]
  // CHECK-NEXT: [[ADD:%.+]], [[OVERFLOW:%.+]] = arith.addui_extended [[XEXT]], [[YEXT]] : [[T]], tensor<1024xi1>
  // CHECK-NEXT: [[REM:%.+]] = arith.remui [[ADD]], [[MOD]] : [[T]]
  // CHECK-NEXT: [[NMOD:%.+]] = arith.constant dense<25> : [[T]]
  // CHECK-NEXT: [[REMPLUS2N:%.+]] = arith.addi [[REM]], [[NMOD]] : [[T]]
  // CHECK-NEXT: [[RES:%.+]] = arith.select [[OVERFLOW]], [[REMPLUS2N]], [[REM]] : tensor<1024xi1>, [[T]]
  // CHECK-NEXT: [[RESMOD:%.+]] = arith.remui [[RES]], [[MOD]] : [[T]]
  %poly2 = poly.add(%poly0, %poly1) {ring = #ring_prime} : !poly.poly<#ring_prime>
  // CHECK: return [[RESMOD]] : [[T]]
  return %poly2 : !poly.poly<#ring_prime>
}

// Builds polynomials in the cmod=4294967291 ring from i32 coefficient tensors
// and expects poly.from_tensor to be lowered away.
// NOTE(review): presumably this guards the from_tensor verifier's bit-width
// bound (an i32 input against a modulus that needs 32 bits) - confirm.
func.func @test_i32_coeff_with_i32_mod() -> () {
// CHECK: [[X:%.+]] = arith.constant dense<2> : [[TCOEFF:tensor<1024xi32>]]
%coeffs1 = arith.constant dense<2> : tensor<1024xi32>
// CHECK: [[Y:%.+]] = arith.constant dense<3> : [[TCOEFF]]
%coeffs2 = arith.constant dense<3> : tensor<1024xi32>
// CHECK-NOT: poly.from_tensor
%poly0 = poly.from_tensor %coeffs1 : tensor<1024xi32> -> !poly.poly<#ring_prime>
%poly1 = poly.from_tensor %coeffs2 : tensor<1024xi32> -> !poly.poly<#ring_prime>
// CHECK: return
return
}
}