Skip to content

Commit

Permalink
No repeated primes in the limb moduli for CKKS (#793)
Browse files Browse the repository at this point in the history
* No repeated primes in the limb moduli for CKKS

* Moved error handling from the end of ParamsGenCKKSRNS() up to its middle

* Addressed PR review comments

* Fixed some bugs and added unit tests

* Addressed PR review comments (2)

* Addressed PR review comments (3)

---------

Co-authored-by: Dmitriy Suponitskiy <dsuponitskiy@dualitytech.com>
  • Loading branch information
dsuponitskiy and dsuponitskiy-duality committed Jun 17, 2024
1 parent 4543012 commit 6b801f8
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 53 deletions.
125 changes: 72 additions & 53 deletions src/pke/lib/scheme/ckksrns/ckksrns-parametergeneration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete
usint cyclOrder, usint numPrimes, usint scalingModSize,
usint firstModSize, uint32_t numPartQ,
COMPRESSION_LEVEL mPIntBootCiphertextCompressionLevel) const {
const auto cryptoParamsCKKSRNS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cryptoParams);
// the "const" modifier for cryptoParamsCKKSRNS and encodingParams below doesn't mean that the objects those 2 pointers
// point to are const (not changeable). it means that the pointers themselves are const only.
const auto cryptoParamsCKKSRNS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cryptoParams);
const EncodingParams encodingParams = cryptoParamsCKKSRNS->GetEncodingParams();

KeySwitchTechnique ksTech = cryptoParamsCKKSRNS->GetKeySwitchTechnique();
ScalingTechnique scalTech = cryptoParamsCKKSRNS->GetScalingTechnique();
Expand All @@ -65,11 +68,8 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete
OPENFHE_THROW(s.str());
}

usint extraModSize = 0;
if (scalTech == FLEXIBLEAUTOEXT) {
// TODO: Allow the user to specify this?
extraModSize = DCRT_MODULUS::DEFAULT_EXTRA_MOD_SIZE;
}
// TODO: Allow the user to specify this?
uint32_t extraModSize = (scalTech == FLEXIBLEAUTOEXT) ? DCRT_MODULUS::DEFAULT_EXTRA_MOD_SIZE : 0;

//// HE Standards compliance logic/check
SecurityLevel stdLevel = cryptoParamsCKKSRNS->GetStdLevel();
Expand All @@ -86,7 +86,7 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete

// GAUSSIAN security constraint
DistributionType distType = (cryptoParamsCKKSRNS->GetSecretKeyDist() == GAUSSIAN) ? HEStd_error : HEStd_ternary;
auto nRLWE = [&](usint q) -> uint32_t {
auto nRLWE = [&](uint32_t q) -> uint32_t {
return StdLatticeParm::FindRingDim(distType, stdLevel, q);
};

Expand All @@ -113,9 +113,15 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete
else if (n == 0) {
OPENFHE_THROW("Please specify the ring dimension or desired security level.");
}

if (encodingParams->GetBatchSize() > n / 2)
OPENFHE_THROW("The batch size cannot be larger than ring dimension / 2.");

if (encodingParams->GetBatchSize() & (encodingParams->GetBatchSize() - 1))
OPENFHE_THROW("The batch size can only be set to zero (for full packing) or a power of two.");
//// End HE Standards compliance logic/check

usint dcrtBits = scalingModSize;
uint32_t dcrtBits = scalingModSize;

uint32_t vecSize = (extraModSize == 0) ? numPrimes : numPrimes + 1;
std::vector<NativeInteger> moduliQ(vecSize);
Expand All @@ -125,107 +131,120 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete
moduliQ[numPrimes - 1] = q;
rootsQ[numPrimes - 1] = RootOfUnity(cyclOrder, moduliQ[numPrimes - 1]);

NativeInteger qNext = q;
NativeInteger qPrev = q;
NativeInteger maxPrime{q};
NativeInteger minPrime{q};
if (numPrimes > 1) {
if (scalTech != FLEXIBLEAUTO && scalTech != FLEXIBLEAUTOEXT) {
uint32_t cnt = 0;
for (usint i = numPrimes - 2; i >= 1; i--) {
NativeInteger qPrev = q;
NativeInteger qNext = q;
for (size_t i = numPrimes - 2, cnt = 0; i >= 1; --i, ++cnt) {
if ((cnt % 2) == 0) {
qPrev = PreviousPrime(qPrev, cyclOrder);
q = qPrev;
qPrev = PreviousPrime(qPrev, cyclOrder);
moduliQ[i] = qPrev;
}
else {
qNext = NextPrime(qNext, cyclOrder);
q = qNext;
qNext = NextPrime(qNext, cyclOrder);
moduliQ[i] = qNext;
}

moduliQ[i] = q;
rootsQ[i] = RootOfUnity(cyclOrder, moduliQ[i]);
cnt++;
if (moduliQ[i] > maxPrime)
maxPrime = moduliQ[i];
else if (moduliQ[i] < minPrime)
minPrime = moduliQ[i];

rootsQ[i] = RootOfUnity(cyclOrder, moduliQ[i]);
}
}
else { // FLEXIBLEAUTO
/* Scaling factors in FLEXIBLEAUTO are a bit fragile,
* in the sense that once one scaling factor gets far enough from the
* original scaling factor, subsequent level scaling factors quickly
* diverge to either 0 or infinity. To mitigate this problem to a certain
* extend, we have a special prime selection process in place. The goal is
* to maintain the scaling factor of all levels as close to the original
* scale factor of level 0 as possible.
*/

double sf = moduliQ[numPrimes - 1].ConvertToDouble();
uint32_t cnt = 0;
for (usint i = numPrimes - 2; i >= 1; i--) {
sf = static_cast<double>(pow(sf, 2) / moduliQ[i + 1].ConvertToDouble());
* in the sense that once one scaling factor gets far enough from the
* original scaling factor, subsequent level scaling factors quickly
* diverge to either 0 or infinity. To mitigate this problem to a certain
* extend, we have a special prime selection process in place. The goal is
* to maintain the scaling factor of all levels as close to the original
* scale factor of level 0 as possible.
*/
double sf = moduliQ[numPrimes - 1].ConvertToDouble();
for (size_t i = numPrimes - 2, cnt = 0; i >= 1; --i, ++cnt) {
sf = static_cast<double>(pow(sf, 2) / moduliQ[i + 1].ConvertToDouble());
NativeInteger sfInt = std::llround(sf);
NativeInteger sfRem = sfInt.Mod(cyclOrder);
bool hasSameMod = true;
if ((cnt % 2) == 0) {
NativeInteger sfInt = std::llround(sf);
NativeInteger sfRem = sfInt.Mod(cyclOrder);
NativeInteger qPrev = sfInt - NativeInteger(cyclOrder) - sfRem + NativeInteger(1);

bool hasSameMod = true;
while (hasSameMod) {
hasSameMod = false;
qPrev = PreviousPrime(qPrev, cyclOrder);
for (uint32_t j = i + 1; j < numPrimes; j++) {
for (size_t j = i + 1; j < numPrimes; j++) {
if (qPrev == moduliQ[j]) {
hasSameMod = true;
break;
}
}
}
moduliQ[i] = qPrev;
}
else {
NativeInteger sfInt = std::llround(sf);
NativeInteger sfRem = sfInt.Mod(cyclOrder);
NativeInteger qNext = sfInt + NativeInteger(cyclOrder) - sfRem + NativeInteger(1);
bool hasSameMod = true;
while (hasSameMod) {
hasSameMod = false;
qNext = NextPrime(qNext, cyclOrder);
for (uint32_t j = i + 1; j < numPrimes; j++) {
for (size_t j = i + 1; j < numPrimes; j++) {
if (qNext == moduliQ[j]) {
hasSameMod = true;
break;
}
}
}
moduliQ[i] = qNext;
}
if (moduliQ[i] > maxPrime)
maxPrime = moduliQ[i];
else if (moduliQ[i] < minPrime)
minPrime = moduliQ[i];

rootsQ[i] = RootOfUnity(cyclOrder, moduliQ[i]);
cnt++;
}
}
}

if (firstModSize == dcrtBits) { // this requires dcrtBits < 60
moduliQ[0] = PreviousPrime<NativeInteger>(qPrev, cyclOrder);
moduliQ[0] = NextPrime<NativeInteger>(maxPrime, cyclOrder);
}
else {
moduliQ[0] = LastPrime<NativeInteger>(firstModSize, cyclOrder);

// find if the value of moduliQ[0] is already in the vector starting with moduliQ[1] and
// if there is, then get another prime for moduliQ[0]
const auto pos = std::find(moduliQ.begin() + 1, moduliQ.end(), moduliQ[0]);
if (pos != moduliQ.end()) {
moduliQ[0] = NextPrime<NativeInteger>(maxPrime, cyclOrder);
}
}
if (moduliQ[0] > maxPrime)
maxPrime = moduliQ[0];

rootsQ[0] = RootOfUnity(cyclOrder, moduliQ[0]);

if (scalTech == FLEXIBLEAUTOEXT) {
// moduliQ[numPrimes] must still be 0, so it has to be populated now

// no need for extra checking as extraModSize is automatically chosen by the library
moduliQ[numPrimes] = FirstPrime<NativeInteger>(extraModSize - 1, cyclOrder);
rootsQ[numPrimes] = RootOfUnity(cyclOrder, moduliQ[numPrimes]);
auto tempMod = FirstPrime<NativeInteger>(extraModSize - 1, cyclOrder);
// check if tempMod has a duplicate in the vector (exclude moduliQ[numPrimes] from this operation):
const auto endPos = moduliQ.end() - 1;
auto pos = std::find(moduliQ.begin(), endPos, tempMod);
// if there is a duplicate, then we call NextPrime()
moduliQ[numPrimes] = (pos != endPos) ? NextPrime<NativeInteger>(maxPrime, cyclOrder) : tempMod;

rootsQ[numPrimes] = RootOfUnity(cyclOrder, moduliQ[numPrimes]);
}

auto paramsDCRT = std::make_shared<ILDCRTParams<BigInteger>>(cyclOrder, moduliQ, rootsQ);

cryptoParamsCKKSRNS->SetElementParams(paramsDCRT);

const EncodingParams encodingParams = cryptoParamsCKKSRNS->GetEncodingParams();
if (encodingParams->GetBatchSize() > n / 2)
OPENFHE_THROW("The batch size cannot be larger than ring dimension / 2.");

if (encodingParams->GetBatchSize() & (encodingParams->GetBatchSize() - 1))
OPENFHE_THROW("The batch size can only be set to zero (for full packing) or a power of two.");

// if no batch size was specified, we set batchSize = n/2 by default (for full
// packing)
// if no batch size was specified, we set batchSize = n/2 by default (for full packing)
if (encodingParams->GetBatchSize() == 0) {
uint32_t batchSize = n / 2;
EncodingParams encodingParamsNew(
Expand Down
34 changes: 34 additions & 0 deletions src/pke/unittest/utckksrns/UnitTestCKKSrns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ enum TEST_CASE_TYPE {
ADD_PACKED_PRECISION,
MULT_PACKED_PRECISION,
EVALSQUARE,
SMALL_SCALING_MOD_SIZE,
};

static std::ostream& operator<<(std::ostream& os, const TEST_CASE_TYPE& type) {
Expand Down Expand Up @@ -114,6 +115,9 @@ static std::ostream& operator<<(std::ostream& os, const TEST_CASE_TYPE& type) {
case EVALSQUARE:
typeName = "EVALSQUARE";
break;
case SMALL_SCALING_MOD_SIZE:
typeName = "SMALL_SCALING_MOD_SIZE";
break;
default:
typeName = "UNKNOWN";
break;
Expand Down Expand Up @@ -517,6 +521,10 @@ static std::vector<TEST_CASE_UTCKKSRNS> testCases = {
{ EVALSQUARE, "07", {CKKSRNS_SCHEME, RING_DIM, 7, DFLT, DSIZE, BATCH, DFLT, DFLT, DFLT, HEStd_NotSet, BV, FLEXIBLEAUTOEXT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, },
{ EVALSQUARE, "08", {CKKSRNS_SCHEME, RING_DIM, 7, DFLT, DSIZE, BATCH, DFLT, DFLT, DFLT, HEStd_NotSet, HYBRID, FLEXIBLEAUTOEXT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, },
#endif
// ==========================================
// TestType, Descr, Scheme, RDim, MultDepth, SModSize, DSize, BatchSz, SecKeyDist, MaxRelinSkDeg, FModSize, SecLvl, KSTech, ScalTech, LDigits, PtMod, StdDev, EvalAddCt, KSCt, MultTech, EncTech, PREMode
{ SMALL_SCALING_MOD_SIZE, "01", {CKKSRNS_SCHEME, 32768, 19, 22, DFLT, DFLT, DFLT, DFLT, 23, DFLT, DFLT, FIXEDMANUAL, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, },
{ SMALL_SCALING_MOD_SIZE, "02", {CKKSRNS_SCHEME, 32768, 16, 50, DFLT, DFLT, DFLT, DFLT, 50, HEStd_NotSet, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, },
#endif
// ==========================================
};
Expand Down Expand Up @@ -2143,6 +2151,28 @@ class UTCKKSRNS : public ::testing::TestWithParam<TEST_CASE_UTCKKSRNS> {
std::string name("EMSCRIPTEN_UNKNOWN");
#else
std::string name(demangle(__cxxabiv1::__cxa_current_exception_type()->name()));
#endif
std::cerr << "Unknown exception of type \"" << name << "\" thrown from " << __func__ << "()" << std::endl;
// make it fail
EXPECT_TRUE(0 == 1) << failmsg;
}
}

void UnitTest_Small_ScalingModSize(const TEST_CASE_UTCKKSRNS& testData,
const std::string& failmsg = std::string()) {
try {
CryptoContext<Element> cc(UnitTestGenerateContext(testData.params));
}
catch (std::exception& e) {
std::cerr << "Exception thrown from " << __func__ << "(): " << e.what() << std::endl;
// make it fail
EXPECT_TRUE(0 == 1) << failmsg;
}
catch (...) {
#if defined EMSCRIPTEN
std::string name("EMSCRIPTEN_UNKNOWN");
#else
std::string name(demangle(__cxxabiv1::__cxa_current_exception_type()->name()));
#endif
std::cerr << "Unknown exception of type \"" << name << "\" thrown from " << __func__ << "()" << std::endl;
// make it fail
Expand Down Expand Up @@ -2226,6 +2256,10 @@ TEST_P(UTCKKSRNS, CKKSRNS) {
break;
case EVALSQUARE:
UnitTest_EvalSquare(test, test.buildTestName());
break;
case SMALL_SCALING_MOD_SIZE:
UnitTest_Small_ScalingModSize(test, test.buildTestName());
break;
default:
break;
}
Expand Down

0 comments on commit 6b801f8

Please sign in to comment.