Skip to content

Commit

Permalink
Fixed some bugs and added unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dsuponitskiy-duality committed Jun 13, 2024
1 parent 151ef85 commit 9101e87
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 6 deletions.
21 changes: 15 additions & 6 deletions src/pke/lib/scheme/ckksrns/ckksrns-parametergeneration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,22 +219,31 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete
const auto pos = std::find(moduliQ.begin() + 1, moduliQ.end(), moduliQ[0]);
if (pos != moduliQ.end()) {
moduliQ[0] = NextPrime<NativeInteger>(maxPrime, cyclOrder);
maxPrime = moduliQ[0];
}
}
if (moduliQ[0] > maxPrime)
maxPrime = moduliQ[0];

rootsQ[0] = RootOfUnity(cyclOrder, moduliQ[0]);

if (scalTech == FLEXIBLEAUTOEXT) {
// find if the value of moduliQ[0] is already in the vector starting with moduliQ[1] and
// if there is, then get another prime for moduliQ[0]
const auto pos = std::find(moduliQ.begin(), moduliQ.end(), moduliQ[numPrimes]);
if (pos == moduliQ.end()) {
// find if the value of moduliQ[numPrimes] is already in the vector.
// in this case we must redefine the position of "end()" as "end()-1"
// currently "moduliQ[numPrimes] == 0" as it is not set up anywhere in this code yet
const auto endPos = moduliQ.end() - 1;
auto pos = std::find(moduliQ.begin(), endPos, moduliQ[numPrimes]);
if (pos == endPos) {
// no need for extra checking as extraModSize is automatically chosen by the library
moduliQ[numPrimes] = FirstPrime<NativeInteger>(extraModSize - 1, cyclOrder);
}
else {
moduliQ[numPrimes] = NextPrime<NativeInteger>(maxPrime, cyclOrder);
// maxPrime = moduliQ[numPrimes];
maxPrime = moduliQ[numPrimes];
}

pos = std::find(moduliQ.begin(), endPos, moduliQ[numPrimes]);
if (pos != endPos) {
moduliQ[numPrimes] = NextPrime<NativeInteger>(maxPrime, cyclOrder);
}

rootsQ[numPrimes] = RootOfUnity(cyclOrder, moduliQ[numPrimes]);
Expand Down
34 changes: 34 additions & 0 deletions src/pke/unittest/utckksrns/UnitTestCKKSrns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ enum TEST_CASE_TYPE {
ADD_PACKED_PRECISION,
MULT_PACKED_PRECISION,
EVALSQUARE,
SMALL_SCALING_MOD_SIZE,
};

static std::ostream& operator<<(std::ostream& os, const TEST_CASE_TYPE& type) {
Expand Down Expand Up @@ -114,6 +115,9 @@ static std::ostream& operator<<(std::ostream& os, const TEST_CASE_TYPE& type) {
case EVALSQUARE:
typeName = "EVALSQUARE";
break;
case SMALL_SCALING_MOD_SIZE:
typeName = "SMALL_SCALING_MOD_SIZE";
break;
default:
typeName = "UNKNOWN";
break;
Expand Down Expand Up @@ -517,6 +521,10 @@ static std::vector<TEST_CASE_UTCKKSRNS> testCases = {
{ EVALSQUARE, "07", {CKKSRNS_SCHEME, RING_DIM, 7, DFLT, DSIZE, BATCH, DFLT, DFLT, DFLT, HEStd_NotSet, BV, FLEXIBLEAUTOEXT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, },
{ EVALSQUARE, "08", {CKKSRNS_SCHEME, RING_DIM, 7, DFLT, DSIZE, BATCH, DFLT, DFLT, DFLT, HEStd_NotSet, HYBRID, FLEXIBLEAUTOEXT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, },
#endif
// ==========================================
// TestType, Descr, Scheme, RDim, MultDepth, SModSize, DSize, BatchSz, SecKeyDist, MaxRelinSkDeg, FModSize, SecLvl, KSTech, ScalTech, LDigits, PtMod, StdDev, EvalAddCt, KSCt, MultTech, EncTech, PREMode
{ SMALL_SCALING_MOD_SIZE, "01", {CKKSRNS_SCHEME, 32768, 19, 22, DFLT, DFLT, DFLT, DFLT, 23, DFLT, DFLT, FIXEDMANUAL, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, },
{ SMALL_SCALING_MOD_SIZE, "02", {CKKSRNS_SCHEME, 32768, 16, 50, DFLT, DFLT, DFLT, DFLT, 50, HEStd_NotSet, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, },
#endif
// ==========================================
};
Expand Down Expand Up @@ -2143,6 +2151,28 @@ class UTCKKSRNS : public ::testing::TestWithParam<TEST_CASE_UTCKKSRNS> {
std::string name("EMSCRIPTEN_UNKNOWN");
#else
std::string name(demangle(__cxxabiv1::__cxa_current_exception_type()->name()));
#endif
std::cerr << "Unknown exception of type \"" << name << "\" thrown from " << __func__ << "()" << std::endl;
// make it fail
EXPECT_TRUE(0 == 1) << failmsg;
}
}

void UnitTest_Small_ScalingModSize(const TEST_CASE_UTCKKSRNS& testData,
const std::string& failmsg = std::string()) {
try {
CryptoContext<Element> cc(UnitTestGenerateContext(testData.params));
}
catch (std::exception& e) {
std::cerr << "Exception thrown from " << __func__ << "(): " << e.what() << std::endl;
// make it fail
EXPECT_TRUE(0 == 1) << failmsg;
}
catch (...) {
#if defined EMSCRIPTEN
std::string name("EMSCRIPTEN_UNKNOWN");
#else
std::string name(demangle(__cxxabiv1::__cxa_current_exception_type()->name()));
#endif
std::cerr << "Unknown exception of type \"" << name << "\" thrown from " << __func__ << "()" << std::endl;
// make it fail
Expand Down Expand Up @@ -2226,6 +2256,10 @@ TEST_P(UTCKKSRNS, CKKSRNS) {
break;
case EVALSQUARE:
UnitTest_EvalSquare(test, test.buildTestName());
break;
case SMALL_SCALING_MOD_SIZE:
UnitTest_Small_ScalingModSize(test, test.buildTestName());
break;
default:
break;
}
Expand Down

0 comments on commit 9101e87

Please sign in to comment.