Skip to content

Commit c2c6f71

Browse files
committed
repack linear transformation
1 parent f0a1a98 commit c2c6f71

File tree

10 files changed

+1304
-19
lines changed

10 files changed

+1304
-19
lines changed
Lines changed: 363 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,363 @@
1+
package linear_transformation
2+
3+
import (
4+
"fmt"
5+
"sort"
6+
7+
"github.com/tuneinsight/lattigo/v5/core/rlwe"
8+
"github.com/tuneinsight/lattigo/v5/ring"
9+
"github.com/tuneinsight/lattigo/v5/ring/ringqp"
10+
"github.com/tuneinsight/lattigo/v5/schemes"
11+
"github.com/tuneinsight/lattigo/v5/utils"
12+
)
13+
14+
// LinearTransformationParameters is a struct storing the parameterization of a
15+
// linear transformation.
16+
//
17+
// A homomorphic linear transformations on a ciphertext acts as evaluating:
18+
//
19+
// Ciphertext([1 x n] vector) <- Ciphertext([1 x n] vector) x Plaintext([n x n] matrix)
20+
//
21+
// where n is the number of plaintext slots.
22+
//
23+
// The diagonal representation of a linear transformations is defined by first expressing
24+
// the linear transformation through its nxn matrix and then traversing the matrix diagonally.
25+
//
26+
// For example, the following nxn for n=4 matrix:
27+
//
28+
// 0 1 2 3 (diagonal index)
29+
// | 1 2 3 0 |
30+
// | 0 1 2 3 |
31+
// | 3 0 1 2 |
32+
// | 2 3 0 1 |
33+
//
34+
// its diagonal traversal representation is comprised of 3 non-zero diagonals at indexes [0, 1, 2]:
35+
// 0: [1, 1, 1, 1]
36+
// 1: [2, 2, 2, 2]
37+
// 2: [3, 3, 3, 3]
38+
// 3: [0, 0, 0, 0] -> this diagonal is omitted as it is composed only of zero values.
39+
//
40+
// Note that negative indexes can be used and will be interpreted modulo the matrix dimension.
41+
//
42+
// The diagonal representation is well suited for two reasons:
43+
// 1. It is the effective format used during the homomorphic evaluation.
44+
// 2. It enables on average a more compact and efficient representation of linear transformations
45+
// than their matrix representation by being able to only store the non-zero diagonals.
46+
//
47+
// Finally, some metrics about the time and storage complexity of homomorphic linear transformations:
48+
// - Storage: #diagonals polynomials mod Q * P
49+
// - Evaluation: #diagonals multiplications and 2sqrt(#diagonals) ciphertexts rotations.
50+
type LinearTransformationParameters struct {
51+
// DiagonalsIndexList is the list of the non-zero diagonals of the square matrix.
52+
// A non zero diagonals is a diagonal with a least one non-zero element.
53+
DiagonalsIndexList []int
54+
55+
// LevelQ is the level at which to encode the linear transformation.
56+
LevelQ int
57+
58+
// LevelP is the level of the auxliary prime used during the automorphisms
59+
// User must ensure that this value is the same as the one used to generate
60+
// the evaluation keys used to perform the automorphisms.
61+
LevelP int
62+
63+
// Scale is the plaintext scale at which to encode the linear transformation.
64+
Scale rlwe.Scale
65+
66+
// LogDimensions is the log2 dimensions of the matrix that can be SIMD packed
67+
// in a single plaintext polynomial.
68+
// This method is equivalent to params.PlaintextDimensions().
69+
// Note that the linear transformation is evaluated independently on each rows of
70+
// the SIMD packed matrix.
71+
LogDimensions ring.Dimensions
72+
73+
// LogBabyStepGianStepRatio is the log2 of the ratio n1/n2 for n = n1 * n2 and
74+
// n is the dimension of the linear transformation. The number of Galois keys required
75+
// is minimized when this value is 0 but the overall complexity of the homomorphic evaluation
76+
// can be reduced by increasing the ratio (at the expanse of increasing the number of keys required).
77+
// If the value returned is negative, then the baby-step giant-step algorithm is not used
78+
// and the evaluation complexity (as well as the number of keys) becomes O(n) instead of O(sqrt(n)).
79+
LogBabyStepGianStepRatio int
80+
}
81+
82+
type Diagonals[T any] map[int][]T
83+
84+
// DiagonalsIndexList returns the list of the non-zero diagonals of the square matrix.
85+
// A non zero diagonals is a diagonal with a least one non-zero element.
86+
func (m Diagonals[T]) DiagonalsIndexList() (indexes []int) {
87+
indexes = make([]int, 0, len(m))
88+
for k := range m {
89+
indexes = append(indexes, k)
90+
}
91+
return indexes
92+
}
93+
94+
// At returns the i-th non-zero diagonal.
95+
// Method accepts negative values with the equivalency -i = n - i.
96+
func (m Diagonals[T]) At(i, slots int) ([]T, error) {
97+
98+
v, ok := m[i]
99+
100+
if !ok {
101+
102+
var j int
103+
if i > 0 {
104+
j = i - slots
105+
} else if j < 0 {
106+
j = i + slots
107+
} else {
108+
return nil, fmt.Errorf("cannot At[0]: diagonal does not exist")
109+
}
110+
111+
v, ok := m[j]
112+
113+
if !ok {
114+
return nil, fmt.Errorf("cannot At[%d or %d]: diagonal does not exist", i, j)
115+
}
116+
117+
return v, nil
118+
}
119+
120+
return v, nil
121+
}
122+
123+
// LinearTransformation is a type for linear transformations on ciphertexts.
124+
// It stores a plaintext matrix in diagonal form and can be evaluated on a
125+
// ciphertext using a LinearTransformationEvaluator.
126+
type LinearTransformation struct {
127+
*rlwe.MetaData
128+
LogBabyStepGianStepRatio int
129+
N1 int
130+
LevelQ int
131+
LevelP int
132+
Vec map[int]ringqp.Poly
133+
}
134+
135+
// GaloisElements returns the list of Galois elements needed for the evaluation of the linear transformation.
136+
func (lt LinearTransformation) GaloisElements(params rlwe.ParameterProvider) (galEls []uint64) {
137+
return GaloisElementsForLinearTransformation(params, utils.GetKeys(lt.Vec), 1<<lt.LogDimensions.Cols, lt.LogBabyStepGianStepRatio)
138+
}
139+
140+
// BSGSIndex returns the BSGSIndex of the target linear transformation.
141+
func (lt LinearTransformation) BSGSIndex() (index map[int][]int, n1, n2 []int) {
142+
return BSGSIndex(utils.GetKeys(lt.Vec), 1<<lt.LogDimensions.Cols, lt.N1)
143+
}
144+
145+
// NewLinearTransformation allocates a new LinearTransformation with zero values according to the parameters specified by the LinearTransformationParameters.
146+
func NewLinearTransformation(params rlwe.ParameterProvider, ltparams LinearTransformationParameters) LinearTransformation {
147+
148+
p := params.GetRLWEParameters()
149+
150+
vec := make(map[int]ringqp.Poly)
151+
cols := 1 << ltparams.LogDimensions.Cols
152+
logBabyStepGianStepRatio := ltparams.LogBabyStepGianStepRatio
153+
levelQ := ltparams.LevelQ
154+
levelP := ltparams.LevelP
155+
ringQP := p.RingQP().AtLevel(levelQ, levelP)
156+
157+
diagslislt := ltparams.DiagonalsIndexList
158+
159+
var N1 int
160+
if logBabyStepGianStepRatio < 0 {
161+
N1 = 0
162+
for _, i := range diagslislt {
163+
idx := i
164+
if idx < 0 {
165+
idx += cols
166+
}
167+
vec[idx] = ringQP.NewPoly()
168+
}
169+
} else {
170+
N1 = FindBestBSGSRatio(diagslislt, cols, logBabyStepGianStepRatio)
171+
index, _, _ := BSGSIndex(diagslislt, cols, N1)
172+
for j := range index {
173+
for _, i := range index[j] {
174+
vec[j+i] = ringQP.NewPoly()
175+
}
176+
}
177+
}
178+
179+
metadata := &rlwe.MetaData{
180+
PlaintextMetaData: rlwe.PlaintextMetaData{
181+
LogDimensions: ltparams.LogDimensions,
182+
Scale: ltparams.Scale,
183+
IsBatched: true,
184+
},
185+
CiphertextMetaData: rlwe.CiphertextMetaData{
186+
IsNTT: true,
187+
IsMontgomery: true,
188+
},
189+
}
190+
191+
return LinearTransformation{
192+
MetaData: metadata,
193+
LogBabyStepGianStepRatio: logBabyStepGianStepRatio,
194+
N1: N1,
195+
LevelQ: levelQ,
196+
LevelP: levelP,
197+
Vec: vec,
198+
}
199+
}
200+
201+
// EncodeLinearTransformation encodes on a pre-allocated LinearTransformation a set of non-zero diagonaes of a matrix representing a linear transformation.
202+
func EncodeLinearTransformation[T any](encoder schemes.Encoder, diagonals Diagonals[T], allocated LinearTransformation) (err error) {
203+
204+
rows := 1 << allocated.LogDimensions.Rows
205+
cols := 1 << allocated.LogDimensions.Cols
206+
N1 := allocated.N1
207+
208+
diags := diagonals.DiagonalsIndexList()
209+
210+
buf := make([]T, rows*cols)
211+
212+
metaData := allocated.MetaData
213+
214+
metaData.Scale = allocated.Scale
215+
216+
var v []T
217+
218+
if N1 == 0 {
219+
for _, i := range diags {
220+
221+
idx := i
222+
if idx < 0 {
223+
idx += cols
224+
}
225+
226+
if vec, ok := allocated.Vec[idx]; !ok {
227+
return fmt.Errorf("cannot EncodeLinearTransformation: error encoding on LinearTransformation: plaintext diagonal [%d] does not exist", idx)
228+
} else {
229+
230+
if v, err = diagonals.At(i, cols); err != nil {
231+
return fmt.Errorf("cannot EncodeLinearTransformation: %w", err)
232+
}
233+
234+
if err = rotateAndEncodeDiagonal(v, encoder, 0, metaData, buf, vec); err != nil {
235+
return
236+
}
237+
}
238+
}
239+
} else {
240+
241+
index, _, _ := allocated.BSGSIndex()
242+
243+
for j := range index {
244+
245+
rot := -j & (cols - 1)
246+
247+
for _, i := range index[j] {
248+
249+
if vec, ok := allocated.Vec[i+j]; !ok {
250+
return fmt.Errorf("cannot Encode: error encoding on LinearTransformation BSGS: input does not match the same non-zero diagonals")
251+
} else {
252+
253+
if v, err = diagonals.At(i+j, cols); err != nil {
254+
return fmt.Errorf("cannot EncodeLinearTransformation: %w", err)
255+
}
256+
257+
if err = rotateAndEncodeDiagonal(v, encoder, rot, metaData, buf, vec); err != nil {
258+
return
259+
}
260+
}
261+
}
262+
}
263+
}
264+
265+
return
266+
}
267+
268+
func rotateAndEncodeDiagonal[T any](v []T, encoder schemes.Encoder, rot int, metaData *rlwe.MetaData, buf []T, poly ringqp.Poly) (err error) {
269+
270+
rows := 1 << metaData.LogDimensions.Rows
271+
cols := 1 << metaData.LogDimensions.Cols
272+
273+
rot &= (cols - 1)
274+
275+
var values []T
276+
if rot != 0 {
277+
278+
values = buf
279+
280+
for i := 0; i < rows; i++ {
281+
utils.RotateSliceAllocFree(v[i*cols:(i+1)*cols], rot, values[i*cols:(i+1)*cols])
282+
}
283+
284+
} else {
285+
values = v
286+
}
287+
288+
return encoder.Embed(values, metaData, poly)
289+
}
290+
291+
// GaloisElementsForLinearTransformation returns the list of Galois elements needed for the evaluation of a linear transformation
292+
// given the index of its non-zero diagonals, the number of slots in the plaintext and the LogBabyStepGianStepRatio (see LinearTransformationParameters).
293+
func GaloisElementsForLinearTransformation(params rlwe.ParameterProvider, diags []int, slots, logBabyStepGianStepRatio int) (galEls []uint64) {
294+
295+
p := params.GetRLWEParameters()
296+
297+
if logBabyStepGianStepRatio < 0 {
298+
299+
_, _, rotN2 := BSGSIndex(diags, slots, slots)
300+
301+
galEls = make([]uint64, len(rotN2))
302+
for i := range rotN2 {
303+
galEls[i] = p.GaloisElement(rotN2[i])
304+
}
305+
306+
return
307+
}
308+
309+
N1 := FindBestBSGSRatio(diags, slots, logBabyStepGianStepRatio)
310+
311+
_, rotN1, rotN2 := BSGSIndex(diags, slots, N1)
312+
313+
return p.GaloisElements(utils.GetDistincts(append(rotN1, rotN2...)))
314+
}
315+
316+
// FindBestBSGSRatio finds the best N1*N2 = N for the baby-step giant-step algorithm for matrix multiplication.
317+
func FindBestBSGSRatio(nonZeroDiags []int, maxN int, logMaxRatio int) (minN int) {
318+
319+
maxRatio := float64(int(1 << logMaxRatio))
320+
321+
for N1 := 1; N1 < maxN; N1 <<= 1 {
322+
323+
_, rotN1, rotN2 := BSGSIndex(nonZeroDiags, maxN, N1)
324+
325+
nbN1, nbN2 := len(rotN1)-1, len(rotN2)-1
326+
327+
if float64(nbN2)/float64(nbN1) == maxRatio {
328+
return N1
329+
}
330+
331+
if float64(nbN2)/float64(nbN1) > maxRatio {
332+
return N1 / 2
333+
}
334+
}
335+
336+
return 1
337+
}
338+
339+
// BSGSIndex returns the index map and needed rotation for the BSGS matrix-vector multiplication algorithm.
340+
func BSGSIndex(nonZeroDiags []int, slots, N1 int) (index map[int][]int, rotN1, rotN2 []int) {
341+
index = make(map[int][]int)
342+
rotN1Map := make(map[int]bool)
343+
rotN2Map := make(map[int]bool)
344+
345+
for _, rot := range nonZeroDiags {
346+
rot &= (slots - 1)
347+
idxN1 := ((rot / N1) * N1) & (slots - 1)
348+
idxN2 := rot & (N1 - 1)
349+
if index[idxN1] == nil {
350+
index[idxN1] = []int{idxN2}
351+
} else {
352+
index[idxN1] = append(index[idxN1], idxN2)
353+
}
354+
rotN1Map[idxN1] = true
355+
rotN2Map[idxN2] = true
356+
}
357+
358+
for k := range index {
359+
sort.Ints(index[k])
360+
}
361+
362+
return index, utils.GetSortedKeys(rotN1Map), utils.GetSortedKeys(rotN2Map)
363+
}

0 commit comments

Comments
 (0)