/
curve.go
120 lines (98 loc) · 2.21 KB
/
curve.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
package elliptic
import (
el "crypto/elliptic"
"crypto/rand"
"github.com/jtejido/spake2plus/internal/suite"
"math/big"
)
// curve bundles a standard-library elliptic curve with the extra data the
// methods in this file read from it.
type curve struct {
	// Curve is the embedded crypto/elliptic curve implementation.
	el.Curve

	// n is an encoded point; N() decodes it with unmarshalCompressed.
	n []byte
	// m is an encoded point, presumably the counterpart of n for M().
	// NOTE(review): confirm against the constructor that fills this struct.
	m []byte

	// p holds the curve parameters; methods in this file read its Name, N
	// (group order) and BitSize fields.
	p *el.CurveParams
}
// String returns the name recorded in the curve parameters.
func (c curve) String() string {
	name := c.p.Name
	return name
}
// ScalarLen returns the number of bytes needed to hold a value as wide as the
// group order, i.e. the order's bit length rounded up to whole bytes.
func (c curve) ScalarLen() int {
	orderBits := c.p.N.BitLen()
	return (orderBits + 7) / 8
}
// Scalar returns a new scalar built by newScalar from the value 0 and the
// group order (presumably a zero scalar modulo the order — see newScalar).
func (c curve) Scalar() suite.Scalar {
	order := c.p.N
	return newScalar(0, order)
}
// coordLen returns the curve's BitSize rounded up to whole bytes, which is
// the size used here for a single encoded coordinate.
func (c curve) coordLen() int {
	fieldBits := c.p.BitSize
	return (fieldBits + 7) / 8
}
// ElementLen returns the encoded size of an element: one leading byte
// followed by two coordinates of coordLen() bytes each.
func (c curve) ElementLen() int {
	coord := c.coordLen()
	return 2*coord + 1
}
// Element returns a fresh curvePoint whose curve reference points at this
// method's own copy of the receiver; no other field of the point is set.
func (c curve) Element() suite.Element {
	return &curvePoint{c: &c}
}
// M returns the protocol constant M as a curve element, decoded from c.m.
//
// If every byte of c.m after the leading byte is zero, the element is returned
// with both coordinates set to zero (presumably this package's representation
// of the point at infinity — confirm against curvePoint). Otherwise c.m is
// decoded with unmarshalCompressed as an encoding of 1+coordLen() bytes.
//
// M panics with "invalid elliptic curve point" if decoding fails or the
// decoded point does not pass valid().
//
// TO-DO: 5. Per-User M and N in SPAKE2 RFC
func (c curve) M() suite.Element {
	point := c.Element().(*curvePoint)

	// OR together everything after the leading byte to detect an all-zero
	// encoding. This must inspect c.m, not c.n: M and N are distinct
	// constants, and reading c.n here would make M() return the same point
	// as N().
	var ch byte
	for _, b := range c.m[1:] {
		ch |= b
	}

	if ch == 0 {
		point.x = big.NewInt(0)
		point.y = big.NewInt(0)
		return point
	}

	point.x, point.y = point.unmarshalCompressed(c.m, 1+c.coordLen())
	if point.x == nil || !point.valid() {
		panic("invalid elliptic curve point")
	}
	return point
}
// N returns the protocol constant N as a curve element, decoded from c.n.
//
// When all bytes of c.n after the leading byte are zero, the element is
// returned with both coordinates set to zero. Otherwise c.n is decoded with
// unmarshalCompressed as an encoding of 1+coordLen() bytes, and the method
// panics with "invalid elliptic curve point" if decoding fails or the
// resulting point does not pass valid().
//
// TO-DO: 5. Per-User M and N in SPAKE2 RFC
func (c curve) N() suite.Element {
	pt := c.Element().(*curvePoint)

	// Fold the bytes after the leading one into a single byte; it stays
	// zero only if every one of them is zero.
	payload := c.n[1:]
	var acc byte
	for i := range payload {
		acc |= payload[i]
	}

	if acc == 0 {
		pt.x = big.NewInt(0)
		pt.y = big.NewInt(0)
		return pt
	}

	encLen := 1 + c.coordLen()
	pt.x, pt.y = pt.unmarshalCompressed(c.n, encLen)
	if pt.x == nil || !pt.valid() {
		panic("invalid elliptic curve point")
	}
	return pt
}
// mask is indexed by (bit length % 8). Entry k keeps the low k bits of a
// byte, except entry 0, which keeps all eight. RandomScalar uses it to clear
// the excess high bits of the most significant byte of a candidate scalar.
var mask = []byte{
	0xff, // 0: bit length is a whole number of bytes, keep everything
	0x01, // 1
	0x03, // 2
	0x07, // 3
	0x0f, // 4
	0x1f, // 5
	0x3f, // 6
	0x7f, // 7
}
// RandomElement draws a scalar with RandomScalar and returns the element
// produced by ScalarMult(sc, nil) on a fresh element (presumably nil selects
// the curve's base point — confirm against curvePoint.ScalarMult).
//
// If RandomScalar fails, its error is returned together with a nil element;
// the scalar is never used in that case.
func (c curve) RandomElement() (suite.Element, error) {
	sc, err := c.RandomScalar()
	if err != nil {
		// Do not feed a missing or invalid scalar into ScalarMult.
		return nil, err
	}
	return c.Element().ScalarMult(sc, nil), nil
}
// RandomScalar returns a scalar drawn by rejection sampling: ScalarLen()
// random bytes are read from crypto/rand, adjusted, and retried until their
// big-endian value is below the group order. The accepted bytes are then
// loaded into a new scalar with FromBytes.
//
// An error from the random source is returned with a nil scalar; an error
// from FromBytes is returned alongside the scalar it was called on.
func (c curve) RandomScalar() (suite.Scalar, error) {
	order := c.Order()
	topMask := mask[c.p.N.BitLen()%8]
	raw := make([]byte, c.ScalarLen())

	for accepted := false; !accepted; {
		if _, err := rand.Read(raw); err != nil {
			return nil, err
		}
		// Drop bits above the order's bit length so candidates are not
		// rejected needlessly when that length is not a multiple of 8.
		raw[0] &= topMask
		// Same constant tweak as the standard library's legacy elliptic
		// key generation; presumably it keeps an all-zero random source
		// (as used in tests) from yielding the zero scalar.
		raw[1] ^= 0x42

		accepted = new(big.Int).SetBytes(raw).Cmp(order) < 0
	}

	out := c.Scalar()
	if err := out.FromBytes(raw); err != nil {
		return out, err
	}
	return out, nil
}
// Order returns the N field of the curve parameters (the group order). The
// returned *big.Int is the parameters' own value, not a copy.
func (c curve) Order() *big.Int {
	params := c.p
	return params.N
}
// CofactorScalar returns the scalar built by newScalar from the value 1 and
// the group order, i.e. the cofactor used by this curve type is 1.
func (c curve) CofactorScalar() suite.Scalar {
	order := c.p.N
	return newScalar(1, order)
}
// ClearCofactor calls ScalarMult on a fresh element with CofactorScalar() and
// elem, and returns the result (presumably elem multiplied by the cofactor).
func (c curve) ClearCofactor(elem suite.Element) suite.Element {
	dst := c.Element()
	h := c.CofactorScalar()
	return dst.ScalarMult(h, elem)
}