Skip to content

Commit

Permalink
Merge pull request #24 from bytemare/add-scalar-setint
Browse files Browse the repository at this point in the history
Add Scalar.SetInt()
  • Loading branch information
bytemare committed Dec 29, 2022
2 parents 39c00b2 + 2cfc772 commit 55455f7
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 1 deletion.
3 changes: 3 additions & 0 deletions internal/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ var (

// ErrIdentity indicates that the identity point (or point at infinity) has been encountered.
ErrIdentity = errors.New("infinity/identity point")

// ErrBigIntConversion reports an error in converting to a *big.int.
ErrBigIntConversion = errors.New("conversion error")
)

// RandomBytes returns random bytes of length len (wrapper for crypto/rand).
Expand Down
8 changes: 8 additions & 0 deletions internal/nist/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,14 @@ func (s *Scalar) Set(scalar internal.Scalar) internal.Scalar {
return s
}

// SetInt sets s to i modulo the field order, and returns an error if one occurs.
func (s *Scalar) SetInt(i *big.Int) error {
s.s.Set(i)
s.field.mod(&s.s)

return nil
}

// Copy returns a copy of the Scalar.
func (s *Scalar) Copy() internal.Scalar {
cpy := newScalar(s.field)
Expand Down
20 changes: 20 additions & 0 deletions internal/ristretto/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package ristretto

import (
"fmt"
"math/big"

"github.com/gtank/ristretto255"

Expand Down Expand Up @@ -191,6 +192,25 @@ func (s *Scalar) Set(scalar internal.Scalar) internal.Scalar {
return s
}

// SetInt sets s to i modulo the field order, and returns an error if one occurs.
func (s *Scalar) SetInt(i *big.Int) error {
a := new(big.Int).Set(i)

order, ok := new(big.Int).SetString(orderPrime, 10)
if !ok {
return internal.ErrBigIntConversion
}

bytes := make([]byte, 32)
bytes = a.Mod(a, order).FillBytes(bytes)

for j, k := 0, len(bytes)-1; j < k; j, k = j+1, k-1 {
bytes[j], bytes[k] = bytes[k], bytes[j]
}

return s.Decode(bytes)
}

// Copy returns a copy of the receiver.
func (s *Scalar) Copy() internal.Scalar {
return &Scalar{*ristretto255.NewScalar().Add(ristretto255.NewScalar(), &s.scalar)}
Expand Down
8 changes: 7 additions & 1 deletion internal/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
// Package internal defines simple and abstract APIs to group Elements and Scalars.
package internal

import "encoding"
import (
"encoding"
"math/big"
)

// Scalar interface abstracts common operations on scalars in a prime-order Group.
type Scalar interface {
Expand Down Expand Up @@ -50,6 +53,9 @@ type Scalar interface {
// Set sets the receiver to the value of the argument scalar, and returns the receiver.
Set(Scalar) Scalar

// SetInt sets s to i modulo the field order, and returns an error if one occurs.
SetInt(i *big.Int) error

// Copy returns a copy of the receiver.
Copy() Scalar

Expand Down
10 changes: 10 additions & 0 deletions scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package crypto

import (
"fmt"
"math/big"

"github.com/bytemare/crypto/internal"
)
Expand Down Expand Up @@ -122,6 +123,15 @@ func (s *Scalar) Set(scalar *Scalar) *Scalar {
return s
}

// SetInt sets s to i modulo the field order, and returns an error if one occurs.
func (s *Scalar) SetInt(i *big.Int) error {
if err := s.Scalar.SetInt(i); err != nil {
return fmt.Errorf("scalar: %w", err)
}

return nil
}

// Copy returns a copy of the receiver.
func (s *Scalar) Copy() *Scalar {
return &Scalar{s.Scalar.Copy()}
Expand Down
38 changes: 38 additions & 0 deletions tests/scalar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package group_test

import (
"encoding/hex"
"math/big"
"testing"

"github.com/bytemare/crypto"
Expand Down Expand Up @@ -104,6 +105,43 @@ func TestScalarSet(t *testing.T) {
})
}

func TestScalarSetInt(t *testing.T) {
testAll(t, func(t2 *testing.T, group *testGroup) {
i := big.NewInt(0)

s := group.id.NewScalar()
if err := s.SetInt(i); err != nil {
t.Fatal(err)
}

if !s.IsZero() {
t.Fatal("expected 0")
}

i = big.NewInt(1)
if err := s.SetInt(i); err != nil {
t.Fatal(err)
}

if s.Equal(group.id.NewScalar().One()) != 1 {
t.Fatal("expected 1")
}

order, ok := new(big.Int).SetString(group.id.Order(), 10)
if !ok {
t.Fatal("conversion error")
}

if err := s.SetInt(order); err != nil {
t.Fatal(err)
}

if !s.IsZero() {
t.Fatal("expected 0")
}
})
}

func TestScalar_EncodedLength(t *testing.T) {
testAll(t, func(t2 *testing.T, group *testGroup) {
encodedScalar := group.id.NewScalar().Random().Encode()
Expand Down

0 comments on commit 55455f7

Please sign in to comment.