diff --git a/src/pke/lib/scheme/ckksrns/ckksrns-parametergeneration.cpp b/src/pke/lib/scheme/ckksrns/ckksrns-parametergeneration.cpp index 7933fe01..4c86c732 100644 --- a/src/pke/lib/scheme/ckksrns/ckksrns-parametergeneration.cpp +++ b/src/pke/lib/scheme/ckksrns/ckksrns-parametergeneration.cpp @@ -219,22 +219,31 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr(maxPrime, cyclOrder); - maxPrime = moduliQ[0]; } } + if (moduliQ[0] > maxPrime) + maxPrime = moduliQ[0]; + rootsQ[0] = RootOfUnity(cyclOrder, moduliQ[0]); if (scalTech == FLEXIBLEAUTOEXT) { - // find if the value of moduliQ[0] is already in the vector starting with moduliQ[1] and - // if there is, then get another prime for moduliQ[0] - const auto pos = std::find(moduliQ.begin(), moduliQ.end(), moduliQ[numPrimes]); - if (pos == moduliQ.end()) { + // find if the value of moduliQ[numPrimes] is already in the vector. + // in this case we must redefine the position of "end()" as "end()-1" + // currently "moduliQ[numPrimes] == 0" as it is not set up anywhere in this code yet + const auto endPos = moduliQ.end() - 1; + auto pos = std::find(moduliQ.begin(), endPos, moduliQ[numPrimes]); + if (pos == endPos) { // no need for extra checking as extraModSize is automatically chosen by the library moduliQ[numPrimes] = FirstPrime(extraModSize - 1, cyclOrder); } else { moduliQ[numPrimes] = NextPrime(maxPrime, cyclOrder); - // maxPrime = moduliQ[numPrimes]; + maxPrime = moduliQ[numPrimes]; + } + + pos = std::find(moduliQ.begin(), endPos, moduliQ[numPrimes]); + if (pos != endPos) { + moduliQ[numPrimes] = NextPrime(maxPrime, cyclOrder); } rootsQ[numPrimes] = RootOfUnity(cyclOrder, moduliQ[numPrimes]); diff --git a/src/pke/unittest/utckksrns/UnitTestCKKSrns.cpp b/src/pke/unittest/utckksrns/UnitTestCKKSrns.cpp index 156f3af0..a8c7fa94 100644 --- a/src/pke/unittest/utckksrns/UnitTestCKKSrns.cpp +++ b/src/pke/unittest/utckksrns/UnitTestCKKSrns.cpp @@ -64,6 +64,7 @@ enum TEST_CASE_TYPE { ADD_PACKED_PRECISION, MULT_PACKED_PRECISION, EVALSQUARE, + SMALL_SCALING_MOD_SIZE, }; static std::ostream& operator<<(std::ostream& os, const TEST_CASE_TYPE& type) { @@ -114,6 +115,9 @@ static std::ostream& operator<<(std::ostream& os, const TEST_CASE_TYPE& type) { case EVALSQUARE: typeName = "EVALSQUARE"; break; + case SMALL_SCALING_MOD_SIZE: + typeName = "SMALL_SCALING_MOD_SIZE"; + break; default: typeName = "UNKNOWN"; break; @@ -517,6 +521,10 @@ static std::vector testCases = { { EVALSQUARE, "07", {CKKSRNS_SCHEME, RING_DIM, 7, DFLT, DSIZE, BATCH, DFLT, DFLT, DFLT, HEStd_NotSet, BV, FLEXIBLEAUTOEXT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, }, { EVALSQUARE, "08", {CKKSRNS_SCHEME, RING_DIM, 7, DFLT, DSIZE, BATCH, DFLT, DFLT, DFLT, HEStd_NotSet, HYBRID, FLEXIBLEAUTOEXT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, }, #endif + // ========================================== + // TestType, Descr, Scheme, RDim, MultDepth, SModSize, DSize, BatchSz, SecKeyDist, MaxRelinSkDeg, FModSize, SecLvl, KSTech, ScalTech, LDigits, PtMod, StdDev, EvalAddCt, KSCt, MultTech, EncTech, PREMode + { SMALL_SCALING_MOD_SIZE, "01", {CKKSRNS_SCHEME, 32768, 19, 22, DFLT, DFLT, DFLT, DFLT, 23, DFLT, DFLT, FIXEDMANUAL, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, }, + { SMALL_SCALING_MOD_SIZE, "02", {CKKSRNS_SCHEME, 32768, 16, 50, DFLT, DFLT, DFLT, DFLT, 50, HEStd_NotSet, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, }, #endif // ========================================== }; @@ -2143,6 +2151,28 @@ class UTCKKSRNS : public ::testing::TestWithParam { std::string name("EMSCRIPTEN_UNKNOWN"); #else std::string name(demangle(__cxxabiv1::__cxa_current_exception_type()->name())); +#endif + std::cerr << "Unknown exception of type \"" << name << "\" thrown from " << __func__ << "()" << std::endl; + // make it fail + EXPECT_TRUE(0 == 1) << failmsg; + } + } + + void UnitTest_Small_ScalingModSize(const TEST_CASE_UTCKKSRNS& testData, + const std::string& failmsg = std::string()) { + try { + CryptoContext cc(UnitTestGenerateContext(testData.params)); + } + catch (std::exception& e) { + std::cerr << "Exception thrown from " << __func__ << "(): " << e.what() << std::endl; + // make it fail + EXPECT_TRUE(0 == 1) << failmsg; + } + catch (...) { +#if defined EMSCRIPTEN + std::string name("EMSCRIPTEN_UNKNOWN"); +#else + std::string name(demangle(__cxxabiv1::__cxa_current_exception_type()->name())); #endif std::cerr << "Unknown exception of type \"" << name << "\" thrown from " << __func__ << "()" << std::endl; // make it fail @@ -2226,6 +2256,10 @@ TEST_P(UTCKKSRNS, CKKSRNS) { break; case EVALSQUARE: UnitTest_EvalSquare(test, test.buildTestName()); + break; + case SMALL_SCALING_MOD_SIZE: + UnitTest_Small_ScalingModSize(test, test.buildTestName()); + break; default: break; }