Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No repeated primes in the limb moduli for CKKS #793

Merged
merged 6 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
127 changes: 74 additions & 53 deletions src/pke/lib/scheme/ckksrns/ckksrns-parametergeneration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete
usint cyclOrder, usint numPrimes, usint scalingModSize,
usint firstModSize, uint32_t numPartQ,
COMPRESSION_LEVEL mPIntBootCiphertextCompressionLevel) const {
const auto cryptoParamsCKKSRNS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cryptoParams);
// the "const" modifier for cryptoParamsCKKSRNS and encodingParams below doesn't mean that the objects those 2 pointers
// point to are const (not changeable). it means that the pointers themselves are const only.
const auto cryptoParamsCKKSRNS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cryptoParams);
const EncodingParams encodingParams = cryptoParamsCKKSRNS->GetEncodingParams();

KeySwitchTechnique ksTech = cryptoParamsCKKSRNS->GetKeySwitchTechnique();
ScalingTechnique scalTech = cryptoParamsCKKSRNS->GetScalingTechnique();
Expand All @@ -65,11 +68,8 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete
OPENFHE_THROW(s.str());
}

usint extraModSize = 0;
if (scalTech == FLEXIBLEAUTOEXT) {
// TODO: Allow the user to specify this?
extraModSize = DCRT_MODULUS::DEFAULT_EXTRA_MOD_SIZE;
}
// TODO: Allow the user to specify this?
uint32_t extraModSize = (scalTech == FLEXIBLEAUTOEXT) ? DCRT_MODULUS::DEFAULT_EXTRA_MOD_SIZE : 0;

//// HE Standards compliance logic/check
SecurityLevel stdLevel = cryptoParamsCKKSRNS->GetStdLevel();
Expand All @@ -86,7 +86,7 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete

// GAUSSIAN security constraint
DistributionType distType = (cryptoParamsCKKSRNS->GetSecretKeyDist() == GAUSSIAN) ? HEStd_error : HEStd_ternary;
auto nRLWE = [&](usint q) -> uint32_t {
auto nRLWE = [&](uint32_t q) -> uint32_t {
return StdLatticeParm::FindRingDim(distType, stdLevel, q);
};

Expand All @@ -113,9 +113,15 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete
else if (n == 0) {
OPENFHE_THROW("Please specify the ring dimension or desired security level.");
}

if (encodingParams->GetBatchSize() > n / 2)
OPENFHE_THROW("The batch size cannot be larger than ring dimension / 2.");

if (encodingParams->GetBatchSize() & (encodingParams->GetBatchSize() - 1))
OPENFHE_THROW("The batch size can only be set to zero (for full packing) or a power of two.");
//// End HE Standards compliance logic/check

usint dcrtBits = scalingModSize;
uint32_t dcrtBits = scalingModSize;

uint32_t vecSize = (extraModSize == 0) ? numPrimes : numPrimes + 1;
std::vector<NativeInteger> moduliQ(vecSize);
Expand All @@ -125,107 +131,122 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete
moduliQ[numPrimes - 1] = q;
rootsQ[numPrimes - 1] = RootOfUnity(cyclOrder, moduliQ[numPrimes - 1]);

NativeInteger qNext = q;
NativeInteger qPrev = q;
NativeInteger maxPrime{q};
NativeInteger minPrime{q};
if (numPrimes > 1) {
if (scalTech != FLEXIBLEAUTO && scalTech != FLEXIBLEAUTOEXT) {
uint32_t cnt = 0;
for (usint i = numPrimes - 2; i >= 1; i--) {
NativeInteger qPrev = q;
NativeInteger qNext = q;
for (size_t i = numPrimes - 2, cnt = 0; i >= 1; --i, ++cnt) {
if ((cnt % 2) == 0) {
qPrev = PreviousPrime(qPrev, cyclOrder);
q = qPrev;
qPrev = PreviousPrime(qPrev, cyclOrder);
moduliQ[i] = qPrev;
}
else {
qNext = NextPrime(qNext, cyclOrder);
q = qNext;
qNext = NextPrime(qNext, cyclOrder);
moduliQ[i] = qNext;
}

moduliQ[i] = q;
rootsQ[i] = RootOfUnity(cyclOrder, moduliQ[i]);
cnt++;
if (moduliQ[i] > maxPrime)
maxPrime = moduliQ[i];
else if (moduliQ[i] < minPrime)
minPrime = moduliQ[i];

rootsQ[i] = RootOfUnity(cyclOrder, moduliQ[i]);
}
}
else { // FLEXIBLEAUTO
/* Scaling factors in FLEXIBLEAUTO are a bit fragile,
* in the sense that once one scaling factor gets far enough from the
* original scaling factor, subsequent level scaling factors quickly
* diverge to either 0 or infinity. To mitigate this problem to a certain
* extend, we have a special prime selection process in place. The goal is
* to maintain the scaling factor of all levels as close to the original
* scale factor of level 0 as possible.
*/

double sf = moduliQ[numPrimes - 1].ConvertToDouble();
uint32_t cnt = 0;
for (usint i = numPrimes - 2; i >= 1; i--) {
sf = static_cast<double>(pow(sf, 2) / moduliQ[i + 1].ConvertToDouble());
* in the sense that once one scaling factor gets far enough from the
* original scaling factor, subsequent level scaling factors quickly
* diverge to either 0 or infinity. To mitigate this problem to a certain
* extend, we have a special prime selection process in place. The goal is
* to maintain the scaling factor of all levels as close to the original
* scale factor of level 0 as possible.
*/
double sf = moduliQ[numPrimes - 1].ConvertToDouble();
for (size_t i = numPrimes - 2, cnt = 0; i >= 1; --i, ++cnt) {
sf = static_cast<double>(pow(sf, 2) / moduliQ[i + 1].ConvertToDouble());
NativeInteger sfInt = std::llround(sf);
NativeInteger sfRem = sfInt.Mod(cyclOrder);
bool hasSameMod = true;
if ((cnt % 2) == 0) {
NativeInteger sfInt = std::llround(sf);
NativeInteger sfRem = sfInt.Mod(cyclOrder);
NativeInteger qPrev = sfInt - NativeInteger(cyclOrder) - sfRem + NativeInteger(1);

bool hasSameMod = true;
while (hasSameMod) {
hasSameMod = false;
qPrev = PreviousPrime(qPrev, cyclOrder);
for (uint32_t j = i + 1; j < numPrimes; j++) {
for (size_t j = i + 1; j < numPrimes; j++) {
if (qPrev == moduliQ[j]) {
hasSameMod = true;
break;
}
}
}
moduliQ[i] = qPrev;
}
else {
NativeInteger sfInt = std::llround(sf);
NativeInteger sfRem = sfInt.Mod(cyclOrder);
NativeInteger qNext = sfInt + NativeInteger(cyclOrder) - sfRem + NativeInteger(1);
bool hasSameMod = true;
while (hasSameMod) {
hasSameMod = false;
qNext = NextPrime(qNext, cyclOrder);
for (uint32_t j = i + 1; j < numPrimes; j++) {
for (size_t j = i + 1; j < numPrimes; j++) {
if (qNext == moduliQ[j]) {
hasSameMod = true;
break;
}
}
}
moduliQ[i] = qNext;
}
if (moduliQ[i] > maxPrime)
maxPrime = moduliQ[i];
else if (moduliQ[i] < minPrime)
minPrime = moduliQ[i];

rootsQ[i] = RootOfUnity(cyclOrder, moduliQ[i]);
cnt++;
}
}
}

if (firstModSize == dcrtBits) { // this requires dcrtBits < 60
moduliQ[0] = PreviousPrime<NativeInteger>(qPrev, cyclOrder);
moduliQ[0] = NextPrime<NativeInteger>(maxPrime, cyclOrder);
}
else {
moduliQ[0] = LastPrime<NativeInteger>(firstModSize, cyclOrder);

// find if the value of moduliQ[0] is already in the vector starting with moduliQ[1] and
// if there is, then get another prime for moduliQ[0]
const auto pos = std::find(moduliQ.begin() + 1, moduliQ.end(), moduliQ[0]);
if (pos != moduliQ.end()) {
moduliQ[0] = NextPrime<NativeInteger>(maxPrime, cyclOrder);
}
}
if (moduliQ[0] > maxPrime)
maxPrime = moduliQ[0];

rootsQ[0] = RootOfUnity(cyclOrder, moduliQ[0]);

if (scalTech == FLEXIBLEAUTOEXT) {
// moduliQ[numPrimes] must still be 0, so it has to be populated now

// no need for extra checking as extraModSize is automatically chosen by the library
moduliQ[numPrimes] = FirstPrime<NativeInteger>(extraModSize - 1, cyclOrder);
rootsQ[numPrimes] = RootOfUnity(cyclOrder, moduliQ[numPrimes]);
auto tempMod = FirstPrime<NativeInteger>(extraModSize - 1, cyclOrder);
if (tempMod > maxPrime)
yspolyakov marked this conversation as resolved.
Show resolved Hide resolved
maxPrime = tempMod;
// check if tempMod has a duplicate in the vector (exclude moduliQ[numPrimes] from this operation):
const auto endPos = moduliQ.end() - 1;
auto pos = std::find(moduliQ.begin(), endPos, tempMod);
// if there is a duplicate, then we call NextPrime()
moduliQ[numPrimes] = (pos != endPos) ? NextPrime<NativeInteger>(maxPrime, cyclOrder) : tempMod;

rootsQ[numPrimes] = RootOfUnity(cyclOrder, moduliQ[numPrimes]);
}

auto paramsDCRT = std::make_shared<ILDCRTParams<BigInteger>>(cyclOrder, moduliQ, rootsQ);

cryptoParamsCKKSRNS->SetElementParams(paramsDCRT);

const EncodingParams encodingParams = cryptoParamsCKKSRNS->GetEncodingParams();
if (encodingParams->GetBatchSize() > n / 2)
OPENFHE_THROW("The batch size cannot be larger than ring dimension / 2.");

if (encodingParams->GetBatchSize() & (encodingParams->GetBatchSize() - 1))
OPENFHE_THROW("The batch size can only be set to zero (for full packing) or a power of two.");

// if no batch size was specified, we set batchSize = n/2 by default (for full
// packing)
// if no batch size was specified, we set batchSize = n/2 by default (for full packing)
if (encodingParams->GetBatchSize() == 0) {
uint32_t batchSize = n / 2;
EncodingParams encodingParamsNew(
Expand Down
34 changes: 34 additions & 0 deletions src/pke/unittest/utckksrns/UnitTestCKKSrns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ enum TEST_CASE_TYPE {
ADD_PACKED_PRECISION,
MULT_PACKED_PRECISION,
EVALSQUARE,
SMALL_SCALING_MOD_SIZE,
};

static std::ostream& operator<<(std::ostream& os, const TEST_CASE_TYPE& type) {
Expand Down Expand Up @@ -114,6 +115,9 @@ static std::ostream& operator<<(std::ostream& os, const TEST_CASE_TYPE& type) {
case EVALSQUARE:
typeName = "EVALSQUARE";
break;
case SMALL_SCALING_MOD_SIZE:
typeName = "SMALL_SCALING_MOD_SIZE";
break;
default:
typeName = "UNKNOWN";
break;
Expand Down Expand Up @@ -517,6 +521,10 @@ static std::vector<TEST_CASE_UTCKKSRNS> testCases = {
{ EVALSQUARE, "07", {CKKSRNS_SCHEME, RING_DIM, 7, DFLT, DSIZE, BATCH, DFLT, DFLT, DFLT, HEStd_NotSet, BV, FLEXIBLEAUTOEXT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, },
{ EVALSQUARE, "08", {CKKSRNS_SCHEME, RING_DIM, 7, DFLT, DSIZE, BATCH, DFLT, DFLT, DFLT, HEStd_NotSet, HYBRID, FLEXIBLEAUTOEXT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, },
#endif
// ==========================================
// TestType, Descr, Scheme, RDim, MultDepth, SModSize, DSize, BatchSz, SecKeyDist, MaxRelinSkDeg, FModSize, SecLvl, KSTech, ScalTech, LDigits, PtMod, StdDev, EvalAddCt, KSCt, MultTech, EncTech, PREMode
{ SMALL_SCALING_MOD_SIZE, "01", {CKKSRNS_SCHEME, 32768, 19, 22, DFLT, DFLT, DFLT, DFLT, 23, DFLT, DFLT, FIXEDMANUAL, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, },
{ SMALL_SCALING_MOD_SIZE, "02", {CKKSRNS_SCHEME, 32768, 16, 50, DFLT, DFLT, DFLT, DFLT, 50, HEStd_NotSet, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, },
#endif
// ==========================================
};
Expand Down Expand Up @@ -2143,6 +2151,28 @@ class UTCKKSRNS : public ::testing::TestWithParam<TEST_CASE_UTCKKSRNS> {
std::string name("EMSCRIPTEN_UNKNOWN");
#else
std::string name(demangle(__cxxabiv1::__cxa_current_exception_type()->name()));
#endif
std::cerr << "Unknown exception of type \"" << name << "\" thrown from " << __func__ << "()" << std::endl;
// make it fail
EXPECT_TRUE(0 == 1) << failmsg;
}
}

void UnitTest_Small_ScalingModSize(const TEST_CASE_UTCKKSRNS& testData,
const std::string& failmsg = std::string()) {
try {
CryptoContext<Element> cc(UnitTestGenerateContext(testData.params));
}
catch (std::exception& e) {
std::cerr << "Exception thrown from " << __func__ << "(): " << e.what() << std::endl;
// make it fail
EXPECT_TRUE(0 == 1) << failmsg;
}
catch (...) {
#if defined EMSCRIPTEN
std::string name("EMSCRIPTEN_UNKNOWN");
#else
std::string name(demangle(__cxxabiv1::__cxa_current_exception_type()->name()));
#endif
std::cerr << "Unknown exception of type \"" << name << "\" thrown from " << __func__ << "()" << std::endl;
// make it fail
Expand Down Expand Up @@ -2226,6 +2256,10 @@ TEST_P(UTCKKSRNS, CKKSRNS) {
break;
case EVALSQUARE:
UnitTest_EvalSquare(test, test.buildTestName());
break;
case SMALL_SCALING_MOD_SIZE:
UnitTest_Small_ScalingModSize(test, test.buildTestName());
break;
default:
break;
}
Expand Down