-
Notifications
You must be signed in to change notification settings - Fork 9
/
poseidon.go
156 lines (134 loc) · 3.76 KB
/
poseidon.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
package hash
import (
"github.com/consensys/gnark-crypto/ecc/bn254/fr"
)
// Package-level hasher instances, one per supported state width.
// They hold zero values until initPoseidon assigns them.
var (
	// PoseidonT2 is a hasher with T = 2.
	PoseidonT2 PoseidonHasher
	// PoseidonT4 is a hasher with T = 4.
	PoseidonT4 PoseidonHasher
	// PoseidonT8 is a hasher with T = 8.
	PoseidonT8 PoseidonHasher
)
// initPoseidon assigns the package-level hashers PoseidonT2, PoseidonT4 and
// PoseidonT8. All three share nRoundsF = 8 and differ in width and in the
// number of partial rounds (82, 83 and 84 respectively).
// The parameters are set according to
// https://eprint.iacr.org/2020/179.pdf
func initPoseidon() {
	const fullRounds = 8
	PoseidonT2 = newPoseidonHasher(2, fullRounds, 82)
	PoseidonT4 = newPoseidonHasher(4, fullRounds, 83)
	PoseidonT8 = newPoseidonHasher(8, fullRounds, 84)
}
// PoseidonHasher bundles everything needed to run one instance of the
// poseidon hash: the state width, its MDS (Cauchy) matrix and the round
// counts.
type PoseidonHasher struct {
	// t is the state width: the dimension of cauchy and the chunk size
	// Hash absorbs on each call to Update.
	t int
	// cauchy is the t×t matrix multiplied into the state after every round.
	cauchy [][]fr.Element
	// nRoundsF counts the full rounds run before, and again after, the
	// partial rounds; nRoundsP counts the partial rounds.
	nRoundsF, nRoundsP int
}
// newPoseidonHasher builds a PoseidonHasher of width t. The hasher runs
// nRoundsF full rounds on each side of nRoundsP partial rounds. Its t×t MDS
// matrix is precomputed once here via GenerateMDSMatrix.
func newPoseidonHasher(t, nRoundsF, nRoundsP int) PoseidonHasher {
	return PoseidonHasher{
		t:        t,
		cauchy:   GenerateMDSMatrix(t),
		nRoundsF: nRoundsF,
		nRoundsP: nRoundsP,
	}
}
// Hash hashes a full message by absorbing it p.t elements at a time.
//
// The message is cut into chunks of p.t elements, and the last chunk is
// zero-padded to p.t. Starting from an all-zero state, each chunk is fed to
// Update. The first element of the final state is returned. An empty message
// therefore hashes to the zero element.
//
// NOTE(review): padding is plain zero-padding with no length encoding.
// Messages that differ only by trailing zero elements (e.g. [a] and [a, 0])
// produce identical blocks and therefore identical digests. Callers that hash
// variable-length inputs should encode the length themselves.
func (p *PoseidonHasher) Hash(msg []fr.Element) fr.Element {
	state := make([]fr.Element, p.t)
	for i := 0; i < len(msg); i += p.t {
		// copy moves min(p.t, len(msg)-i) elements, so a short final chunk
		// is zero-padded by the fresh allocation.
		block := make([]fr.Element, p.t)
		copy(block, msg[i:])
		p.Update(state, block)
	}
	return state[0]
}
// Update absorbs block into state using the poseidon permutation in a
// Miyaguchi–Preneel construction:
//
//	state <- E_block(state) + state + block
//
// E_block runs p.nRoundsF full rounds, then p.nRoundsP partial rounds, then
// p.nRoundsF more full rounds. Every round does three things:
//   - adds block and that round's entry of Arks to each state element;
//   - applies the S-box (to every element in a full round, to state[0] only
//     in a partial round);
//   - multiplies the state by p.cauchy.
//
// https://en.wikipedia.org/wiki/One-way_compression_function#Miyaguchi.E2.80.93Preneel
//
// state is updated in place. state and block are expected to have length p.t,
// the dimension of p.cauchy. Arks must hold at least 2*p.nRoundsF+p.nRoundsP
// entries.
func (p *PoseidonHasher) Update(state, block []fr.Element) {
	// Snapshot the input state for the final feed-forward.
	oldState := append([]fr.Element{}, state...)

	// Each MDS product is copied back into the caller's slice instead of being
	// assigned to `state`. Reassigning would only rebind this function's local
	// slice header. The caller would then see just the operations performed
	// before the first matrix product, and none of the remaining rounds or
	// the feed-forward.
	for i := 0; i < p.nRoundsF; i++ {
		AddArkAndKeysInplace(state, block, Arks[i])
		FullRoundInPlace(state)
		copy(state, MatrixMultiplication(p.cauchy, state))
	}
	for i := p.nRoundsF; i < p.nRoundsF+p.nRoundsP; i++ {
		AddArkAndKeysInplace(state, block, Arks[i])
		PartialRoundInplace(state)
		copy(state, MatrixMultiplication(p.cauchy, state))
	}
	for i := p.nRoundsF + p.nRoundsP; i < 2*p.nRoundsF+p.nRoundsP; i++ {
		AddArkAndKeysInplace(state, block, Arks[i])
		FullRoundInPlace(state)
		copy(state, MatrixMultiplication(p.cauchy, state))
	}

	// Feed-forward: combine the permutation output with the old state and
	// the block.
	for i := range state {
		state[i].Add(&state[i], &oldState[i])
		state[i].Add(&state[i], &block[i])
	}
}
// GenerateMDSMatrix builds the t×t Cauchy matrix whose (row, col) entry is
// 1 / (xArr[row] + yArr[col]).
// It reads the package-level xArr and yArr, so t must not exceed either
// slice's length.
func GenerateMDSMatrix(t int) [][]fr.Element {
	mds := make([][]fr.Element, t)
	for row := 0; row < t; row++ {
		mds[row] = make([]fr.Element, t)
		for col := 0; col < t; col++ {
			entry := &mds[row][col]
			entry.Add(&xArr[row], &yArr[col])
			entry.Inverse(entry)
		}
	}
	return mds
}
// MatrixMultiplication returns the matrix-vector product mat · vec in a newly
// allocated slice. For mat of size k×n and vec of length n, the result has
// length k, with out[r] = Σ_c mat[r][c] · vec[c].
func MatrixMultiplication(mat [][]fr.Element, vec []fr.Element) []fr.Element {
	out := make([]fr.Element, len(mat))
	var prod fr.Element
	for r := range mat {
		row := mat[r]
		acc := &out[r]
		for c := range row {
			prod.Mul(&vec[c], &row[c])
			acc.Add(acc, &prod)
		}
	}
	return out
}
// SBoxInplace raises *x to the 7th power, overwriting it (the x^7 S-box).
func SBoxInplace(x *fr.Element) {
	var x2, x6 fr.Element
	x2.Square(x)   // x^2
	x6.Mul(&x2, x) // x^3
	x6.Square(&x6) // x^6
	x.Mul(x, &x6)  // x^7
}
// FullRoundInPlace runs the S-box over every element of state, modifying
// the slice in place.
func FullRoundInPlace(state []fr.Element) {
	for k := 0; k < len(state); k++ {
		SBoxInplace(&state[k])
	}
}
// AddArkAndKeysInplace adds the keys and the round constant ark to state in
// place: state[i] += keys[i] + ark for every index i of state.
//
// keys must be at least as long as state; otherwise indexing panics. The
// same ark is added to every element. Update passes the message block as
// keys and one entry of Arks as ark.
func AddArkAndKeysInplace(state []fr.Element, keys []fr.Element, ark fr.Element) {
	for i := range state {
		state[i].Add(&state[i], &keys[i])
		state[i].Add(&state[i], &ark)
	}
}
// PartialRoundInplace runs the S-box on state[0] only and leaves every other
// element untouched. It panics if state is empty.
func PartialRoundInplace(state []fr.Element) {
	first := &state[0]
	SBoxInplace(first)
}