|
| 1 | +package linear_transformation |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "sort" |
| 6 | + |
| 7 | + "github.com/tuneinsight/lattigo/v5/core/rlwe" |
| 8 | + "github.com/tuneinsight/lattigo/v5/ring" |
| 9 | + "github.com/tuneinsight/lattigo/v5/ring/ringqp" |
| 10 | + "github.com/tuneinsight/lattigo/v5/schemes" |
| 11 | + "github.com/tuneinsight/lattigo/v5/utils" |
| 12 | +) |
| 13 | + |
| 14 | +// LinearTransformationParameters is a struct storing the parameterization of a |
| 15 | +// linear transformation. |
| 16 | +// |
| 17 | +// A homomorphic linear transformations on a ciphertext acts as evaluating: |
| 18 | +// |
| 19 | +// Ciphertext([1 x n] vector) <- Ciphertext([1 x n] vector) x Plaintext([n x n] matrix) |
| 20 | +// |
| 21 | +// where n is the number of plaintext slots. |
| 22 | +// |
| 23 | +// The diagonal representation of a linear transformations is defined by first expressing |
| 24 | +// the linear transformation through its nxn matrix and then traversing the matrix diagonally. |
| 25 | +// |
| 26 | +// For example, the following nxn for n=4 matrix: |
| 27 | +// |
| 28 | +// 0 1 2 3 (diagonal index) |
| 29 | +// | 1 2 3 0 | |
| 30 | +// | 0 1 2 3 | |
| 31 | +// | 3 0 1 2 | |
| 32 | +// | 2 3 0 1 | |
| 33 | +// |
| 34 | +// its diagonal traversal representation is comprised of 3 non-zero diagonals at indexes [0, 1, 2]: |
| 35 | +// 0: [1, 1, 1, 1] |
| 36 | +// 1: [2, 2, 2, 2] |
| 37 | +// 2: [3, 3, 3, 3] |
| 38 | +// 3: [0, 0, 0, 0] -> this diagonal is omitted as it is composed only of zero values. |
| 39 | +// |
| 40 | +// Note that negative indexes can be used and will be interpreted modulo the matrix dimension. |
| 41 | +// |
| 42 | +// The diagonal representation is well suited for two reasons: |
| 43 | +// 1. It is the effective format used during the homomorphic evaluation. |
| 44 | +// 2. It enables on average a more compact and efficient representation of linear transformations |
| 45 | +// than their matrix representation by being able to only store the non-zero diagonals. |
| 46 | +// |
| 47 | +// Finally, some metrics about the time and storage complexity of homomorphic linear transformations: |
| 48 | +// - Storage: #diagonals polynomials mod Q * P |
| 49 | +// - Evaluation: #diagonals multiplications and 2sqrt(#diagonals) ciphertexts rotations. |
| 50 | +type LinearTransformationParameters struct { |
| 51 | + // DiagonalsIndexList is the list of the non-zero diagonals of the square matrix. |
| 52 | + // A non zero diagonals is a diagonal with a least one non-zero element. |
| 53 | + DiagonalsIndexList []int |
| 54 | + |
| 55 | + // LevelQ is the level at which to encode the linear transformation. |
| 56 | + LevelQ int |
| 57 | + |
| 58 | + // LevelP is the level of the auxliary prime used during the automorphisms |
| 59 | + // User must ensure that this value is the same as the one used to generate |
| 60 | + // the evaluation keys used to perform the automorphisms. |
| 61 | + LevelP int |
| 62 | + |
| 63 | + // Scale is the plaintext scale at which to encode the linear transformation. |
| 64 | + Scale rlwe.Scale |
| 65 | + |
| 66 | + // LogDimensions is the log2 dimensions of the matrix that can be SIMD packed |
| 67 | + // in a single plaintext polynomial. |
| 68 | + // This method is equivalent to params.PlaintextDimensions(). |
| 69 | + // Note that the linear transformation is evaluated independently on each rows of |
| 70 | + // the SIMD packed matrix. |
| 71 | + LogDimensions ring.Dimensions |
| 72 | + |
| 73 | + // LogBabyStepGianStepRatio is the log2 of the ratio n1/n2 for n = n1 * n2 and |
| 74 | + // n is the dimension of the linear transformation. The number of Galois keys required |
| 75 | + // is minimized when this value is 0 but the overall complexity of the homomorphic evaluation |
| 76 | + // can be reduced by increasing the ratio (at the expanse of increasing the number of keys required). |
| 77 | + // If the value returned is negative, then the baby-step giant-step algorithm is not used |
| 78 | + // and the evaluation complexity (as well as the number of keys) becomes O(n) instead of O(sqrt(n)). |
| 79 | + LogBabyStepGianStepRatio int |
| 80 | +} |
| 81 | + |
| 82 | +type Diagonals[T any] map[int][]T |
| 83 | + |
| 84 | +// DiagonalsIndexList returns the list of the non-zero diagonals of the square matrix. |
| 85 | +// A non zero diagonals is a diagonal with a least one non-zero element. |
| 86 | +func (m Diagonals[T]) DiagonalsIndexList() (indexes []int) { |
| 87 | + indexes = make([]int, 0, len(m)) |
| 88 | + for k := range m { |
| 89 | + indexes = append(indexes, k) |
| 90 | + } |
| 91 | + return indexes |
| 92 | +} |
| 93 | + |
| 94 | +// At returns the i-th non-zero diagonal. |
| 95 | +// Method accepts negative values with the equivalency -i = n - i. |
| 96 | +func (m Diagonals[T]) At(i, slots int) ([]T, error) { |
| 97 | + |
| 98 | + v, ok := m[i] |
| 99 | + |
| 100 | + if !ok { |
| 101 | + |
| 102 | + var j int |
| 103 | + if i > 0 { |
| 104 | + j = i - slots |
| 105 | + } else if j < 0 { |
| 106 | + j = i + slots |
| 107 | + } else { |
| 108 | + return nil, fmt.Errorf("cannot At[0]: diagonal does not exist") |
| 109 | + } |
| 110 | + |
| 111 | + v, ok := m[j] |
| 112 | + |
| 113 | + if !ok { |
| 114 | + return nil, fmt.Errorf("cannot At[%d or %d]: diagonal does not exist", i, j) |
| 115 | + } |
| 116 | + |
| 117 | + return v, nil |
| 118 | + } |
| 119 | + |
| 120 | + return v, nil |
| 121 | +} |
| 122 | + |
| 123 | +// LinearTransformation is a type for linear transformations on ciphertexts. |
| 124 | +// It stores a plaintext matrix in diagonal form and can be evaluated on a |
| 125 | +// ciphertext using a LinearTransformationEvaluator. |
| 126 | +type LinearTransformation struct { |
| 127 | + *rlwe.MetaData |
| 128 | + LogBabyStepGianStepRatio int |
| 129 | + N1 int |
| 130 | + LevelQ int |
| 131 | + LevelP int |
| 132 | + Vec map[int]ringqp.Poly |
| 133 | +} |
| 134 | + |
| 135 | +// GaloisElements returns the list of Galois elements needed for the evaluation of the linear transformation. |
| 136 | +func (lt LinearTransformation) GaloisElements(params rlwe.ParameterProvider) (galEls []uint64) { |
| 137 | + return GaloisElementsForLinearTransformation(params, utils.GetKeys(lt.Vec), 1<<lt.LogDimensions.Cols, lt.LogBabyStepGianStepRatio) |
| 138 | +} |
| 139 | + |
| 140 | +// BSGSIndex returns the BSGSIndex of the target linear transformation. |
| 141 | +func (lt LinearTransformation) BSGSIndex() (index map[int][]int, n1, n2 []int) { |
| 142 | + return BSGSIndex(utils.GetKeys(lt.Vec), 1<<lt.LogDimensions.Cols, lt.N1) |
| 143 | +} |
| 144 | + |
| 145 | +// NewLinearTransformation allocates a new LinearTransformation with zero values according to the parameters specified by the LinearTransformationParameters. |
| 146 | +func NewLinearTransformation(params rlwe.ParameterProvider, ltparams LinearTransformationParameters) LinearTransformation { |
| 147 | + |
| 148 | + p := params.GetRLWEParameters() |
| 149 | + |
| 150 | + vec := make(map[int]ringqp.Poly) |
| 151 | + cols := 1 << ltparams.LogDimensions.Cols |
| 152 | + logBabyStepGianStepRatio := ltparams.LogBabyStepGianStepRatio |
| 153 | + levelQ := ltparams.LevelQ |
| 154 | + levelP := ltparams.LevelP |
| 155 | + ringQP := p.RingQP().AtLevel(levelQ, levelP) |
| 156 | + |
| 157 | + diagslislt := ltparams.DiagonalsIndexList |
| 158 | + |
| 159 | + var N1 int |
| 160 | + if logBabyStepGianStepRatio < 0 { |
| 161 | + N1 = 0 |
| 162 | + for _, i := range diagslislt { |
| 163 | + idx := i |
| 164 | + if idx < 0 { |
| 165 | + idx += cols |
| 166 | + } |
| 167 | + vec[idx] = ringQP.NewPoly() |
| 168 | + } |
| 169 | + } else { |
| 170 | + N1 = FindBestBSGSRatio(diagslislt, cols, logBabyStepGianStepRatio) |
| 171 | + index, _, _ := BSGSIndex(diagslislt, cols, N1) |
| 172 | + for j := range index { |
| 173 | + for _, i := range index[j] { |
| 174 | + vec[j+i] = ringQP.NewPoly() |
| 175 | + } |
| 176 | + } |
| 177 | + } |
| 178 | + |
| 179 | + metadata := &rlwe.MetaData{ |
| 180 | + PlaintextMetaData: rlwe.PlaintextMetaData{ |
| 181 | + LogDimensions: ltparams.LogDimensions, |
| 182 | + Scale: ltparams.Scale, |
| 183 | + IsBatched: true, |
| 184 | + }, |
| 185 | + CiphertextMetaData: rlwe.CiphertextMetaData{ |
| 186 | + IsNTT: true, |
| 187 | + IsMontgomery: true, |
| 188 | + }, |
| 189 | + } |
| 190 | + |
| 191 | + return LinearTransformation{ |
| 192 | + MetaData: metadata, |
| 193 | + LogBabyStepGianStepRatio: logBabyStepGianStepRatio, |
| 194 | + N1: N1, |
| 195 | + LevelQ: levelQ, |
| 196 | + LevelP: levelP, |
| 197 | + Vec: vec, |
| 198 | + } |
| 199 | +} |
| 200 | + |
| 201 | +// EncodeLinearTransformation encodes on a pre-allocated LinearTransformation a set of non-zero diagonaes of a matrix representing a linear transformation. |
| 202 | +func EncodeLinearTransformation[T any](encoder schemes.Encoder, diagonals Diagonals[T], allocated LinearTransformation) (err error) { |
| 203 | + |
| 204 | + rows := 1 << allocated.LogDimensions.Rows |
| 205 | + cols := 1 << allocated.LogDimensions.Cols |
| 206 | + N1 := allocated.N1 |
| 207 | + |
| 208 | + diags := diagonals.DiagonalsIndexList() |
| 209 | + |
| 210 | + buf := make([]T, rows*cols) |
| 211 | + |
| 212 | + metaData := allocated.MetaData |
| 213 | + |
| 214 | + metaData.Scale = allocated.Scale |
| 215 | + |
| 216 | + var v []T |
| 217 | + |
| 218 | + if N1 == 0 { |
| 219 | + for _, i := range diags { |
| 220 | + |
| 221 | + idx := i |
| 222 | + if idx < 0 { |
| 223 | + idx += cols |
| 224 | + } |
| 225 | + |
| 226 | + if vec, ok := allocated.Vec[idx]; !ok { |
| 227 | + return fmt.Errorf("cannot EncodeLinearTransformation: error encoding on LinearTransformation: plaintext diagonal [%d] does not exist", idx) |
| 228 | + } else { |
| 229 | + |
| 230 | + if v, err = diagonals.At(i, cols); err != nil { |
| 231 | + return fmt.Errorf("cannot EncodeLinearTransformation: %w", err) |
| 232 | + } |
| 233 | + |
| 234 | + if err = rotateAndEncodeDiagonal(v, encoder, 0, metaData, buf, vec); err != nil { |
| 235 | + return |
| 236 | + } |
| 237 | + } |
| 238 | + } |
| 239 | + } else { |
| 240 | + |
| 241 | + index, _, _ := allocated.BSGSIndex() |
| 242 | + |
| 243 | + for j := range index { |
| 244 | + |
| 245 | + rot := -j & (cols - 1) |
| 246 | + |
| 247 | + for _, i := range index[j] { |
| 248 | + |
| 249 | + if vec, ok := allocated.Vec[i+j]; !ok { |
| 250 | + return fmt.Errorf("cannot Encode: error encoding on LinearTransformation BSGS: input does not match the same non-zero diagonals") |
| 251 | + } else { |
| 252 | + |
| 253 | + if v, err = diagonals.At(i+j, cols); err != nil { |
| 254 | + return fmt.Errorf("cannot EncodeLinearTransformation: %w", err) |
| 255 | + } |
| 256 | + |
| 257 | + if err = rotateAndEncodeDiagonal(v, encoder, rot, metaData, buf, vec); err != nil { |
| 258 | + return |
| 259 | + } |
| 260 | + } |
| 261 | + } |
| 262 | + } |
| 263 | + } |
| 264 | + |
| 265 | + return |
| 266 | +} |
| 267 | + |
| 268 | +func rotateAndEncodeDiagonal[T any](v []T, encoder schemes.Encoder, rot int, metaData *rlwe.MetaData, buf []T, poly ringqp.Poly) (err error) { |
| 269 | + |
| 270 | + rows := 1 << metaData.LogDimensions.Rows |
| 271 | + cols := 1 << metaData.LogDimensions.Cols |
| 272 | + |
| 273 | + rot &= (cols - 1) |
| 274 | + |
| 275 | + var values []T |
| 276 | + if rot != 0 { |
| 277 | + |
| 278 | + values = buf |
| 279 | + |
| 280 | + for i := 0; i < rows; i++ { |
| 281 | + utils.RotateSliceAllocFree(v[i*cols:(i+1)*cols], rot, values[i*cols:(i+1)*cols]) |
| 282 | + } |
| 283 | + |
| 284 | + } else { |
| 285 | + values = v |
| 286 | + } |
| 287 | + |
| 288 | + return encoder.Embed(values, metaData, poly) |
| 289 | +} |
| 290 | + |
| 291 | +// GaloisElementsForLinearTransformation returns the list of Galois elements needed for the evaluation of a linear transformation |
| 292 | +// given the index of its non-zero diagonals, the number of slots in the plaintext and the LogBabyStepGianStepRatio (see LinearTransformationParameters). |
| 293 | +func GaloisElementsForLinearTransformation(params rlwe.ParameterProvider, diags []int, slots, logBabyStepGianStepRatio int) (galEls []uint64) { |
| 294 | + |
| 295 | + p := params.GetRLWEParameters() |
| 296 | + |
| 297 | + if logBabyStepGianStepRatio < 0 { |
| 298 | + |
| 299 | + _, _, rotN2 := BSGSIndex(diags, slots, slots) |
| 300 | + |
| 301 | + galEls = make([]uint64, len(rotN2)) |
| 302 | + for i := range rotN2 { |
| 303 | + galEls[i] = p.GaloisElement(rotN2[i]) |
| 304 | + } |
| 305 | + |
| 306 | + return |
| 307 | + } |
| 308 | + |
| 309 | + N1 := FindBestBSGSRatio(diags, slots, logBabyStepGianStepRatio) |
| 310 | + |
| 311 | + _, rotN1, rotN2 := BSGSIndex(diags, slots, N1) |
| 312 | + |
| 313 | + return p.GaloisElements(utils.GetDistincts(append(rotN1, rotN2...))) |
| 314 | +} |
| 315 | + |
| 316 | +// FindBestBSGSRatio finds the best N1*N2 = N for the baby-step giant-step algorithm for matrix multiplication. |
| 317 | +func FindBestBSGSRatio(nonZeroDiags []int, maxN int, logMaxRatio int) (minN int) { |
| 318 | + |
| 319 | + maxRatio := float64(int(1 << logMaxRatio)) |
| 320 | + |
| 321 | + for N1 := 1; N1 < maxN; N1 <<= 1 { |
| 322 | + |
| 323 | + _, rotN1, rotN2 := BSGSIndex(nonZeroDiags, maxN, N1) |
| 324 | + |
| 325 | + nbN1, nbN2 := len(rotN1)-1, len(rotN2)-1 |
| 326 | + |
| 327 | + if float64(nbN2)/float64(nbN1) == maxRatio { |
| 328 | + return N1 |
| 329 | + } |
| 330 | + |
| 331 | + if float64(nbN2)/float64(nbN1) > maxRatio { |
| 332 | + return N1 / 2 |
| 333 | + } |
| 334 | + } |
| 335 | + |
| 336 | + return 1 |
| 337 | +} |
| 338 | + |
| 339 | +// BSGSIndex returns the index map and needed rotation for the BSGS matrix-vector multiplication algorithm. |
| 340 | +func BSGSIndex(nonZeroDiags []int, slots, N1 int) (index map[int][]int, rotN1, rotN2 []int) { |
| 341 | + index = make(map[int][]int) |
| 342 | + rotN1Map := make(map[int]bool) |
| 343 | + rotN2Map := make(map[int]bool) |
| 344 | + |
| 345 | + for _, rot := range nonZeroDiags { |
| 346 | + rot &= (slots - 1) |
| 347 | + idxN1 := ((rot / N1) * N1) & (slots - 1) |
| 348 | + idxN2 := rot & (N1 - 1) |
| 349 | + if index[idxN1] == nil { |
| 350 | + index[idxN1] = []int{idxN2} |
| 351 | + } else { |
| 352 | + index[idxN1] = append(index[idxN1], idxN2) |
| 353 | + } |
| 354 | + rotN1Map[idxN1] = true |
| 355 | + rotN2Map[idxN2] = true |
| 356 | + } |
| 357 | + |
| 358 | + for k := range index { |
| 359 | + sort.Ints(index[k]) |
| 360 | + } |
| 361 | + |
| 362 | + return index, utils.GetSortedKeys(rotN1Map), utils.GetSortedKeys(rotN2Map) |
| 363 | +} |
0 commit comments