forked from coinbase/kryptology
-
Notifications
You must be signed in to change notification settings - Fork 0
/
poseidon_config.go
112 lines (99 loc) · 2 KB
/
poseidon_config.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
//
// Copyright Coinbase, Inc. All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//
package mina
import (
"github.com/berry-block/kryptology/pkg/core/curves/native/pasta/fp"
)
// SBox selects which power map the Poseidon S-box layer applies to each
// state element (see SBox.Exp).
type SBox int

// S-box selectors. The values are written out explicitly so that reordering
// the declarations cannot silently renumber them. They are untyped integer
// constants, so they compare directly against an SBox value.
const (
	Cube    = 0 // x^3
	Quint   = 1 // x^5
	Sept    = 2 // x^7
	Inverse = 3 // x^-1
)
// Exp overwrites f with a power of itself chosen by sbox:
//
//	Cube    -> f^3
//	Quint   -> f^5
//	Sept    -> f^7
//	Inverse -> f^-1
//
// Any other SBox value leaves f unchanged. The exponent choices follow
// https://eprint.iacr.org/2019/458.pdf, page 8.
func (sbox SBox) Exp(f *fp.Fp) {
	switch sbox {
	case Inverse:
		f.Invert(f)
	case Cube:
		// f^3 = f^2 * f
		sq := new(fp.Fp).Square(f)
		f.Mul(sq, f)
	case Quint:
		// f^5 = (f^2)^2 * f
		quad := new(fp.Fp).Square(f)
		quad.Square(quad)
		f.Mul(quad, f)
	case Sept:
		// f^7 = (f^2 * f^4) * f
		sq := new(fp.Fp).Square(f)
		quad := new(fp.Fp).Square(sq)
		six := new(fp.Fp).Mul(sq, quad)
		f.Mul(six, f)
	}
}
// Permutation selects which Poseidon round schedule Permute runs
// (see Permute for the schedule each value selects).
type Permutation int

// Permutation selectors, written with explicit values so that reordering the
// declarations cannot renumber them. They are untyped integer constants.
const (
	ThreeW = 0
	FiveW  = 1
	Three  = 2
)
// Permute applies the Poseidon permutation selected by p to ctx.state in place.
//
// ThreeW runs ctx.fullRounds rounds of (add round key, S-box, MDS mix) and
// then adds round-key row ctx.fullRounds. It therefore reads fullRounds+1 rows
// of ctx.roundKeys.
//
// FiveW and Three run ctx.fullRounds rounds of (S-box, MDS mix, add round key).
//
// Any other Permutation value leaves the state untouched.
func (p Permutation) Permute(ctx *Context) {
	switch p {
	case ThreeW:
		for round := 0; round < ctx.fullRounds; round++ {
			ark(ctx, round)
			sbox(ctx)
			mds(ctx)
		}
		// Closing key addition after the last round.
		ark(ctx, ctx.fullRounds)
	case FiveW, Three:
		// Full rounds only; no partial rounds are performed.
		for round := 0; round < ctx.fullRounds; round++ {
			sbox(ctx)
			mds(ctx)
			ark(ctx, round)
		}
	}
}
// ark ("add round key") adds row `round` of ctx.roundKeys element-wise, in
// place, to the first ctx.spongeWidth elements of ctx.state.
func ark(ctx *Context, round int) {
	for lane := 0; lane < ctx.spongeWidth; lane++ {
		cell := ctx.state[lane]
		cell.Add(cell, ctx.roundKeys[round][lane])
	}
}
// sbox applies the context's S-box exponentiation (ctx.sBox) in place to each
// of the first ctx.spongeWidth elements of ctx.state.
func sbox(ctx *Context) {
	width := ctx.spongeWidth
	for lane := 0; lane < width; lane++ {
		ctx.sBox.Exp(ctx.state[lane])
	}
}
// mds multiplies the state vector by the context's MDS matrix in place:
//
//	state[row] = sum_col state[col] * mdsMatrix[row][col]
//
// where row and col range over [0, ctx.spongeWidth). The new values are
// accumulated in a scratch vector first, because every output element depends
// on every input element.
//
// Like ark and sbox, only the first ctx.spongeWidth state elements are read or
// written. The scratch vector is sized by spongeWidth, not len(ctx.state).
// Otherwise any state elements past the sponge width would be overwritten with
// zero.
func mds(ctx *Context) {
	width := ctx.spongeWidth
	next := make([]*fp.Fp, width)
	for i := range next {
		next[i] = new(fp.Fp).SetZero()
	}
	// A single scratch element is reused for every product, instead of
	// allocating width*width temporaries on each call.
	prod := new(fp.Fp)
	for row := 0; row < width; row++ {
		for col := 0; col < width; col++ {
			prod.Mul(ctx.state[col], ctx.mdsMatrix[row][col])
			next[row].Add(next[row], prod)
		}
	}
	for i, v := range next {
		ctx.state[i].Set(v)
	}
}
// NetworkType selects which Mina network identifier to use.
type NetworkType int

// Network selectors, written with explicit values so that reordering the
// declarations cannot renumber them. They are untyped integer constants.
const (
	TestNet = 0
	MainNet = 1
	NullNet = 2
)