Skip to content

Commit

Permalink
[ckks] : fixed bug in MultiplyByDiagMatrixBSGS that was modifying inp…
Browse files Browse the repository at this point in the history
…uts (#159)
  • Loading branch information
Pro7ech committed Nov 25, 2021
1 parent bd2ac2f commit c761a16
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 20 deletions.
8 changes: 4 additions & 4 deletions ckks/ckks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ func testLinearTransform(testContext *testParams, t *testing.T) {

eval := testContext.evaluator.WithKey(rlwe.EvaluationKey{Rlk: testContext.rlk, Rtks: rotKey})

res := eval.LinearTransformNew(ciphertext1, ptDiagMatrix)[0]
eval.LinearTransform(ciphertext1, ptDiagMatrix, []*Ciphertext{ciphertext1})

tmp := make([]complex128, params.Slots())
copy(tmp, values1)
Expand All @@ -1133,7 +1133,7 @@ func testLinearTransform(testContext *testParams, t *testing.T) {
values1[i] += tmp[(i+15)%params.Slots()]
}

verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values1, res, testContext.params.LogSlots(), 0, t)
verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t)
})

t.Run(GetTestName(testContext.params, "LinearTransform/Naive/"), func(t *testing.T) {
Expand All @@ -1160,7 +1160,7 @@ func testLinearTransform(testContext *testParams, t *testing.T) {

eval := testContext.evaluator.WithKey(rlwe.EvaluationKey{Rlk: testContext.rlk, Rtks: rotKey})

res := eval.LinearTransformNew(ciphertext1, ptDiagMatrix)[0]
eval.LinearTransform(ciphertext1, ptDiagMatrix, []*Ciphertext{ciphertext1})

tmp := make([]complex128, params.Slots())
copy(tmp, values1)
Expand All @@ -1169,7 +1169,7 @@ func testLinearTransform(testContext *testParams, t *testing.T) {
values1[i] += tmp[(i-1+params.Slots())%params.Slots()]
}

verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values1, res, testContext.params.LogSlots(), 0, t)
verifyTestVectors(testContext.params, testContext.encoder, testContext.decryptor, values1, ciphertext1, testContext.params.LogSlots(), 0, t)
})
}

Expand Down
22 changes: 6 additions & 16 deletions ckks/linear_transform.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,14 +395,9 @@ func (eval *evaluator) MultiplyByDiagMatrix(ctIn *Ciphertext, matrix PtDiagMatri
ksRes0QP := eval.Pool[3]
ksRes1QP := eval.Pool[4]

var ctInTmp0, ctInTmp1 *ring.Poly
if ctIn != ctOut {
ring.CopyValuesLvl(levelQ, ctIn.Value[0], eval.ctxpool.Value[0])
ring.CopyValuesLvl(levelQ, ctIn.Value[1], eval.ctxpool.Value[1])
ctInTmp0, ctInTmp1 = eval.ctxpool.Value[0], eval.ctxpool.Value[1]
} else {
ctInTmp0, ctInTmp1 = ctIn.Value[0], ctIn.Value[1]
}
ring.CopyValuesLvl(levelQ, ctIn.Value[0], eval.ctxpool.Value[0])
ring.CopyValuesLvl(levelQ, ctIn.Value[1], eval.ctxpool.Value[1])
ctInTmp0, ctInTmp1 := eval.ctxpool.Value[0], eval.ctxpool.Value[1]

ringQ.MulScalarBigintLvl(levelQ, ctInTmp0, ringP.ModulusBigint, ct0TimesP) // P*c0

Expand Down Expand Up @@ -499,14 +494,9 @@ func (eval *evaluator) MultiplyByDiagMatrixBSGS(ctIn *Ciphertext, matrix PtDiagM

index, rotations := bsgsIndex(matrix.Vec, 1<<matrix.LogSlots, matrix.N1)

var ctInTmp0, ctInTmp1 *ring.Poly
if ctIn == ctOut {
ring.CopyValuesLvl(levelQ, ctIn.Value[0], eval.ctxpool.Value[0])
ring.CopyValuesLvl(levelQ, ctIn.Value[1], eval.ctxpool.Value[1])
ctInTmp0, ctInTmp1 = eval.ctxpool.Value[0], eval.ctxpool.Value[1]
} else {
ctInTmp0, ctInTmp1 = ctIn.Value[0], ctIn.Value[1]
}
ring.CopyValuesLvl(levelQ, ctIn.Value[0], eval.ctxpool.Value[0])
ring.CopyValuesLvl(levelQ, ctIn.Value[1], eval.ctxpool.Value[1])
ctInTmp0, ctInTmp1 := eval.ctxpool.Value[0], eval.ctxpool.Value[1]

// Pre-rotates ciphertext for the baby-step giant-step algorithm, does not divide by P yet
ctInRotQP := eval.RotateHoistedNoModDownNew(levelQ, rotations, ctInTmp0, eval.PoolDecompQP)
Expand Down

0 comments on commit c761a16

Please sign in to comment.