Skip to content

Commit

Permalink
756: Fixed EvalSumColsKeyGen() bug (#768)
Browse files Browse the repository at this point in the history
* Some cleanup for EvalSumCols() and EvalSumRows().

* Fixed a bug in EvalSumCols()

* Fixed EvalSumRowsKeyGen()/EvalSumColsKeyGen() incorrect behavior

* Fixed EvalSumRowsKeyGen()/EvalSumColsKeyGen() incorrect behavior(2)

* Added 2 functions to serialize/deserialize a selected set of eval keys

* Added unittests for EvalSumRows and EvalSumCols

* Removed a call to EvalSumKeyGen() before calls to EvalSumRowsKeyGen/EvalSumColsKeyGen

* Fixes to private functions generating indices for automorphism keys

* Made GetPartialEvalAutomorphismKeyMapPtr static

* Changed the unit tests

---------

Co-authored-by: Dmitriy Suponitskiy <dsuponitskiy@dualitytech.com>
  • Loading branch information
dsuponitskiy and dsuponitskiy-duality committed Jun 14, 2024
1 parent 4760cb6 commit 4543012
Show file tree
Hide file tree
Showing 11 changed files with 416 additions and 266 deletions.
83 changes: 76 additions & 7 deletions src/pke/include/cryptocontext.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,19 @@ class CryptoContextImpl : public Serializable {
value2);
}

/**
* Get automorphism keys for a specific secret key tag
*/
static std::shared_ptr<std::map<usint, EvalKey<Element>>> GetEvalAutomorphismKeyMapPtr(const std::string& keyID);
/**
* @brief Get automorphism keys for a specific secret key tag and an array of specific indices
* @param keyID - secret key tag
* @param indexList - array of specific indices to retrieve key for
* @return shared_ptr to std::map where the map key/data pair is index/automorphism key
*/
static std::shared_ptr<std::map<usint, EvalKey<Element>>> GetPartialEvalAutomorphismKeyMapPtr(
const std::string& keyID, const std::vector<uint32_t>& indexList);

// cached evalmult keys, by secret key UID
static std::map<std::string, std::vector<EvalKey<Element>>> s_evalMultKeyMap;
// cached evalautomorphism keys, by secret key UID
Expand Down Expand Up @@ -811,6 +824,65 @@ class CryptoContextImpl : public Serializable {
return true;
}

/**
* @brief Serialize automorphism keys for an array of specific indices within a specific secret key tag
* @param ser - stream to serialize to
* @param sertype - type of serialization
* @param keyID - secret key tag
* @param indexList - array of specific indices to serialize key for
* @return true on success
*/
template <typename ST>
static bool SerializeEvalAutomorphismKey(std::ostream& ser, const ST& sertype, const std::string& keyID,
const std::vector<uint32_t>& indexList) {
std::map<std::string, std::shared_ptr<std::map<usint, EvalKey<Element>>>> keyMap = {
{keyID, CryptoContextImpl<Element>::GetPartialEvalAutomorphismKeyMapPtr(keyID, indexList)}};

Serial::Serialize(keyMap, ser, sertype);
return true;
}

/**
* @brief Deserialize automorphism keys for an array of specific indices within a specific secret key tag
* @param ser - stream to serialize from
* @param sertype - type of serialization
* @param keyID - secret key tag
* @param indexList - array of specific indices to serialize key for
* @return true on success
*/
template <typename ST>
static bool DeserializeEvalAutomorphismKey(std::ostream& ser, const ST& sertype, const std::string& keyID,
const std::vector<uint32_t>& indexList) {
if (!indexList.size())
OPENFHE_THROW("indexList may not be empty");
if (keyID.empty())
OPENFHE_THROW("keyID may not be empty");

std::map<std::string, std::shared_ptr<std::map<usint, EvalKey<Element>>>> allDeserKeys;
Serial::Deserialize(allDeserKeys, ser, sertype);

const auto& keyMapIt = allDeserKeys.find(keyID);
if (keyMapIt == allDeserKeys.end()) {
OPENFHE_THROW("Deserialized automorphism keys are not generated for ID [" + keyID + "].");
}

// create a new map with evalkeys for the specified indices
std::map<usint, EvalKey<Element>> newMap;
for (const uint32_t indx : indexList) {
const auto& key = keyMapIt->find(indx);
if (key == keyMapIt->end()) {
OPENFHE_THROW("No automorphism key generated for index [" + std::to_string(indx) + "] within keyID [" +
keyID + "].");
}
newMap[indx] = key->second;
}

CryptoContextImpl<Element>::InsertEvalAutomorphismKey(
std::make_shared<std::map<uint32_t, EvalKey<Element>>>(newMap), keyID);

return true;
}

/**
* DeserializeEvalAutomorphismKey deserialize all keys in the serialization
* deserialized keys silently replace any existing matching keys
Expand Down Expand Up @@ -984,12 +1056,9 @@ class CryptoContextImpl : public Serializable {
/**
* Get automorphism keys for a specific secret key tag
*/
static std::shared_ptr<std::map<usint, EvalKey<Element>>> GetEvalAutomorphismKeyMapPtr(const std::string& keyID);

static std::map<usint, EvalKey<Element>>& GetEvalAutomorphismKeyMap(const std::string& keyID) {
return *(CryptoContextImpl<Element>::GetEvalAutomorphismKeyMapPtr(keyID));
}

/**
* Get a map of summation keys (each is composed of several automorphism keys) for all secret keys
*/
Expand Down Expand Up @@ -2702,25 +2771,25 @@ class CryptoContextImpl : public Serializable {
* encoding
*
* @param ciphertext the input ciphertext.
* @param rowSize size of rows in the matrix
* @param numRows number of rows in the matrix
* @param &evalSumKeyMap - reference to the map of evaluation keys generated by
* @param subringDim the current cyclotomic order/subring dimension. If set to
* 0, we use the full cyclotomic order.
* @return resulting ciphertext
*/
Ciphertext<Element> EvalSumRows(ConstCiphertext<Element> ciphertext, usint rowSize,
Ciphertext<Element> EvalSumRows(ConstCiphertext<Element> ciphertext, usint numRows,
const std::map<usint, EvalKey<Element>>& evalSumKeyMap, usint subringDim = 0) const;

/**
* Sums all elements over column-vectors in a matrix - works only with packed
* encoding
*
* @param ciphertext the input ciphertext.
* @param rowSize size of rows in the matrix
* @param numCols number of columns in the matrix
* @param &evalSumKeyMap - reference to the map of evaluation keys generated by
* @return resulting ciphertext
*/
Ciphertext<Element> EvalSumCols(ConstCiphertext<Element> ciphertext, usint rowSize,
Ciphertext<Element> EvalSumCols(ConstCiphertext<Element> ciphertext, usint numCols,
const std::map<usint, EvalKey<Element>>& evalSumKeyMap) const;

//------------------------------------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions src/pke/include/encoding/plaintextfactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class PlaintextFactory {
// Check if plaintext has got enough slots for data (value)
usint ringDim = vp->GetRingDimension();
size_t valueSize = value.size();
if (SCHEME::CKKSRNS_SCHEME == schemeID && valueSize > ringDim / 2) {
if (isCKKS(schemeID) && valueSize > ringDim / 2) {
OPENFHE_THROW("The size [" + std::to_string(valueSize) +
"] of the vector with values should not be greater than ringDim/2 [" +
std::to_string(ringDim / 2) + "] if the scheme is CKKS");
Expand Down Expand Up @@ -111,7 +111,7 @@ class PlaintextFactory {
// Check if plaintext has got enough slots for data (value)
usint ringDim = vp->GetRingDimension();
size_t valueSize = value.size();
if (SCHEME::CKKSRNS_SCHEME == schemeID && valueSize > ringDim / 2) {
if (isCKKS(schemeID) && valueSize > ringDim / 2) {
OPENFHE_THROW("The size [" + std::to_string(valueSize) +
"] of the vector with values should not be greater than ringDim/2 [" +
std::to_string(ringDim / 2) + "] if the scheme is CKKS");
Expand Down
2 changes: 1 addition & 1 deletion src/pke/include/scheme/gen-cryptocontext-params.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ class Params {
}

// setters
// They all(must be virtual, so any of them can be disabled in the derived class
// They all must be virtual, so any of them can be disabled in the derived class
virtual void SetPlaintextModulus(PlaintextModulus ptModulus0) {
ptModulus = ptModulus0;
}
Expand Down
117 changes: 53 additions & 64 deletions src/pke/include/schemebase/base-advancedshe.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include <vector>
#include <string>
#include <map>
#include <unordered_set>

/**
* @namespace lbcrypto
Expand Down Expand Up @@ -224,8 +225,8 @@ class AdvancedSHEBase {
* @return returns the evaluation keys
*/
virtual std::shared_ptr<std::map<usint, EvalKey<Element>>> EvalSumRowsKeyGen(const PrivateKey<Element> privateKey,
const PublicKey<Element> publicKey,
usint rowSize, usint subringDim) const;
usint rowSize, usint subringDim,
std::vector<usint>& indices) const;

/**
* Virtual function to generate the automorphism keys for EvalSumCols; works
Expand All @@ -235,84 +236,70 @@ class AdvancedSHEBase {
* @param publicKey public key.
* @return returns the evaluation keys
*/
virtual std::shared_ptr<std::map<usint, EvalKey<Element>>> EvalSumColsKeyGen(
const PrivateKey<Element> privateKey, const PublicKey<Element> publicKey) const;
virtual std::shared_ptr<std::map<usint, EvalKey<Element>>> EvalSumColsKeyGen(const PrivateKey<Element> privateKey,
std::vector<usint>& indices) const;

/**
* Sums all elements in log (batch size) time - works only with packed
* encoding
*
* @param ciphertext the input ciphertext.
* @param batchSize size of the batch to be summed up
* @param &evalKeys - reference to the map of evaluation keys generated by
* EvalAutomorphismKeyGen.
* @return resulting ciphertext
*/
* @brief Sums all elements in log (batch size) time - works only with packedvencoding
* @param ciphertext the input ciphertext.
* @param batchSize size of the batch to be summed up
* @param evalKeys - reference to the map of evaluation keys generated by EvalAutomorphismKeyGen.
* @return resulting ciphertext
*/
virtual Ciphertext<Element> EvalSum(ConstCiphertext<Element> ciphertext, usint batchSize,
const std::map<usint, EvalKey<Element>>& evalSumKeyMap) const;

/**
* Sums all elements over row-vectors in a matrix - works only with packed
* encoding
*
* @param ciphertext the input ciphertext.
* @param rowSize size of rows in the matrix
* @param &evalKeys - reference to the map of evaluation keys generated by
* @param subringDim the current cyclotomic order/subring dimension. If set to
* 0, we use the full cyclotomic order. EvalAutomorphismKeyGen.
* @return resulting ciphertext
*/
virtual Ciphertext<Element> EvalSumRows(ConstCiphertext<Element> ciphertext, usint rowSize,
const std::map<usint, EvalKey<Element>>& evalSumRowsKeyMap,
usint subringDim) const;
* @brief Sums all elements over row-vectors in a matrix - works only with packed encoding.
* @param ciphertext the input ciphertext.
* @param numRows number of rows in the matrix
* @param evalSumKeys - reference to the map of evaluation keys generated by EvalAutomorphismKeyGen.
* @param subringDim the current cyclotomic order/subring dimension. If set to 0, we use the full cyclotomic order.
* @return resulting ciphertext
*/
virtual Ciphertext<Element> EvalSumRows(ConstCiphertext<Element> ciphertext, uint32_t numRows,
const std::map<uint32_t, EvalKey<Element>>& evalSumKeys,
uint32_t subringDim) const;

/**
* Sums all elements over column-vectors in a matrix - works only with
* packed encoding
*
* @param ciphertext the input ciphertext.
* @param rowSize size of rows in the matrixs
* @param &evalKeys - reference to the map of evaluation keys generated by
* EvalAutomorphismKeyGen.
* @return resulting ciphertext
*/
virtual Ciphertext<Element> EvalSumCols(ConstCiphertext<Element> ciphertext, usint batchSize,
const std::map<usint, EvalKey<Element>>& evalSumColsKeyMap,
const std::map<usint, EvalKey<Element>>& rightEvalKeys) const;
* @brief Sums all elements over column-vectors in a matrix - works only with packed encoding. The code is
* implemented according to the specifications in https://eprint.iacr.org/2018/662.pdf
* @param ciphertext the input ciphertext.
* @param numCols number of columns in the matrixs
* @param evalSumKeys - reference to the map of evaluation keys generated by EvalAutomorphismKeyGen.
* @param rightEvalKeys - reference to the map of
* @return resulting ciphertext
*/
virtual Ciphertext<Element> EvalSumCols(ConstCiphertext<Element> ciphertext, uint32_t numCols,
const std::map<uint32_t, EvalKey<Element>>& evalSumKeys,
const std::map<uint32_t, EvalKey<Element>>& rightEvalKeys) const;

//------------------------------------------------------------------------------
// Advanced SHE EVAL INNER PRODUCT
//------------------------------------------------------------------------------

/**
* Evaluates inner product in batched encoding
*
* @param ciphertext1 first vector.
* @param ciphertext2 second vector.
* @param batchSize size of the batch to be summed up
* @param &evalSumKeys - reference to the map of evaluation keys generated
* by EvalAutomorphismKeyGen.
* @param &evalMultKey - reference to the evaluation key generated by
* EvalMultKeyGen.
* @return resulting ciphertext
*/
* @brief Evaluates inner product in batched encoding
* @param ciphertext1 first vector.
* @param ciphertext2 second vector.
* @param batchSize size of the batch to be summed up
* @param evalSumKeys - reference to the map of evaluation keys generated by EvalAutomorphismKeyGen.
* @param evalMultKey - reference to the evaluation key generated by EvalMultKeyGen.
* @return resulting ciphertext
*/
virtual Ciphertext<Element> EvalInnerProduct(ConstCiphertext<Element> ciphertext1,
ConstCiphertext<Element> ciphertext2, usint batchSize,
const std::map<usint, EvalKey<Element>>& evalKeyMap,
const EvalKey<Element> evalMultKey) const;

/**
* Evaluates inner product in batched encoding
*
* @param ciphertext1 first vector.
* @param plaintext plaintext.
* @param batchSize size of the batch to be summed up
* @param &evalSumKeys - reference to the map of evaluation keys generated
* by EvalAutomorphismKeyGen.
* @param &evalMultKey - reference to the evaluation key generated by
* EvalMultKeyGen.
* @return resulting ciphertext
*/
* @brief Evaluates inner product in batched encoding
* @param ciphertext first vector.
* @param plaintext plaintext.
* @param batchSize size of the batch to be summed up
* @param evalSumKeys - reference to the map of evaluation keys generated by EvalAutomorphismKeyGen.
* @return resulting ciphertext
*/
virtual Ciphertext<Element> EvalInnerProduct(ConstCiphertext<Element> ciphertext, ConstPlaintext plaintext,
usint batchSize,
const std::map<usint, EvalKey<Element>>& evalKeyMap) const;
Expand Down Expand Up @@ -348,13 +335,15 @@ class AdvancedSHEBase {
//------------------------------------------------------------------------------

protected:
std::vector<usint> GenerateIndices_2n(usint batchSize, usint m) const;
std::unordered_set<uint32_t> GenerateIndices_2n(usint batchSize, usint m) const;

std::unordered_set<uint32_t> GenerateIndices2nComplex(usint batchSize, usint m) const;

std::vector<usint> GenerateIndices2nComplex(usint batchSize, usint m) const;
std::unordered_set<uint32_t> GenerateIndices2nComplexRows(usint rowSize, usint m) const;

std::vector<usint> GenerateIndices2nComplexRows(usint rowSize, usint m) const;
std::unordered_set<uint32_t> GenerateIndices2nComplexCols(usint batchSize, usint m) const;

std::vector<usint> GenerateIndices2nComplexCols(usint batchSize, usint m) const;
std::unordered_set<uint32_t> GenerateIndexListForEvalSum(const PrivateKey<Element>& privateKey) const;

Ciphertext<Element> EvalSum_2n(ConstCiphertext<Element> ciphertext, usint batchSize, usint m,
const std::map<usint, EvalKey<Element>>& evalKeyMap) const;
Expand Down
6 changes: 3 additions & 3 deletions src/pke/include/schemebase/base-scheme.h
Original file line number Diff line number Diff line change
Expand Up @@ -1188,11 +1188,11 @@ class SchemeBase {
const PublicKey<Element> publicKey) const;

virtual std::shared_ptr<std::map<usint, EvalKey<Element>>> EvalSumRowsKeyGen(const PrivateKey<Element> privateKey,
const PublicKey<Element> publicKey,
usint rowSize, usint subringDim) const;
usint rowSize, usint subringDim,
std::vector<usint>& indices) const;

virtual std::shared_ptr<std::map<usint, EvalKey<Element>>> EvalSumColsKeyGen(
const PrivateKey<Element> privateKey, const PublicKey<Element> publicKey) const;
const PrivateKey<Element> privateKey, std::vector<usint>& indices) const;

virtual Ciphertext<Element> EvalSum(ConstCiphertext<Element> ciphertext, usint batchSize,
const std::map<usint, EvalKey<Element>>& evalKeyMap) const {
Expand Down
Loading

0 comments on commit 4543012

Please sign in to comment.