Skip to content

Commit

Permalink
Merge pull request #19 from bytemare/fix-nist-scalar-copy
Browse files Browse the repository at this point in the history
Fix NIST scalar and element Copy() and Set() methods
  • Loading branch information
bytemare committed Aug 22, 2022
2 parents 19786a1 + 3ba956e commit 0f414c4
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 19 deletions.
4 changes: 3 additions & 1 deletion internal/nist/element.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ func (e *Element[P]) Set(element internal.Element) internal.Element {
panic(internal.ErrCastElement)
}

return e.set(ec)
e.p.Set(ec.p)

return e
}

// Copy returns a copy of the receiver.
Expand Down
7 changes: 4 additions & 3 deletions internal/nist/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,23 +140,24 @@ func (s *Scalar) set(scalar *Scalar) *Scalar {
return s
}

// Set sets the receiver to the argument scalar, and returns the receiver.
// Set sets the receiver to the value of the argument scalar, and returns the receiver.
func (s *Scalar) Set(scalar internal.Scalar) internal.Scalar {
if scalar == nil {
return s.set(nil)
}

ec := s.assert(scalar)
s.s.Set(&ec.s)

return s.set(ec)
return s
}

// Copy returns a copy of the Scalar.
func (s *Scalar) Copy() internal.Scalar {
cpy := &Scalar{field: s.field}
cpy.s.Set(&s.s)

return s
return cpy
}

// Encode returns the compressed byte encoding of the scalar.
Expand Down
5 changes: 3 additions & 2 deletions internal/ristretto/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,16 @@ func (s *Scalar) set(scalar *Scalar) *Scalar {
return s
}

// Set sets the receiver to the argument, and returns the receiver.
// Set sets the receiver to the value of the argument scalar, and returns the receiver.
func (s *Scalar) Set(scalar internal.Scalar) internal.Scalar {
if scalar == nil {
return s.set(nil)
}

ec := assert(scalar)
s.scalar = ec.scalar

return s.set(ec)
return s
}

// Copy returns a copy of the receiver.
Expand Down
2 changes: 1 addition & 1 deletion internal/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type Scalar interface {
// IsZero returns whether the scalar is 0.
IsZero() bool

// Set sets the receiver to the argument, and returns the receiver.
// Set sets the receiver to the value of the argument scalar, and returns the receiver.
Set(Scalar) Scalar

// Copy returns a copy of the receiver.
Expand Down
2 changes: 1 addition & 1 deletion scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (s *Scalar) IsZero() bool {
return s.Scalar.IsZero()
}

// Set sets the receiver to the argument, and returns the receiver.
// Set sets the receiver to the value of the argument scalar, and returns the receiver.
func (s *Scalar) Set(scalar *Scalar) *Scalar {
s.Scalar.Set(scalar.Scalar)
return s
Expand Down
45 changes: 34 additions & 11 deletions tests/element_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,43 @@ const (
wrongGroup = "wrong group"
)

func TestElement_SetCopy(t *testing.T) {
testAll(t, func(t2 *testing.T, group *testGroup) {
g := group.id
base := g.Base()
func testElementCopySet(t *testing.T, element, other *crypto.Element) {
// Verify they don't point to the same thing
if &element == &other {
t.Fatalf("Pointer to the same scalar")
}

set := g.NewElement().Set(base)
if set.Equal(base) != 1 {
t.Fatal(expectedEquality)
}
// Verify whether they are equivalent
if element.Equal(other) != 1 {
t.Fatalf("Expected equality")
}

// Verify than operations on one don't affect the other
element.Add(element)
if element.Equal(other) == 1 {
t.Fatalf("Unexpected equality")
}

other.Double().Double()
if element.Equal(other) == 1 {
t.Fatalf("Unexpected equality")
}
}

func TestElementCopy(t *testing.T) {
testAll(t, func(t2 *testing.T, group *testGroup) {
base := group.id.Base()
cpy := base.Copy()
if cpy.Equal(base) != 1 {
t.Fatal(expectedEquality)
}
testElementCopySet(t, base, cpy)
})
}

func TestElementSet(t *testing.T) {
testAll(t, func(t2 *testing.T, group *testGroup) {
base := group.id.Base()
other := group.id.NewElement()
other.Set(base)
testElementCopySet(t, base, other)
})
}

Expand Down
40 changes: 40 additions & 0 deletions tests/scalar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,46 @@ func TestScalar_WrongInput(t *testing.T) {
})
}

func testScalarCopySet(t *testing.T, scalar, other *crypto.Scalar) {
// Verify they don't point to the same thing
if &scalar == &other {
t.Fatalf("Pointer to the same scalar")
}

// Verify whether they are equivalent
if scalar.Equal(other) != 1 {
t.Fatalf("Expected equality")
}

// Verify than operations on one don't affect the other
scalar.Add(scalar)
if scalar.Equal(other) == 1 {
t.Fatalf("Unexpected equality")
}

other.Invert()
if scalar.Equal(other) == 1 {
t.Fatalf("Unexpected equality")
}
}

func TestScalarCopy(t *testing.T) {
testAll(t, func(t2 *testing.T, group *testGroup) {
random := group.id.NewScalar().Random()
cpy := random.Copy()
testScalarCopySet(t, random, cpy)
})
}

func TestScalarSet(t *testing.T) {
testAll(t, func(t2 *testing.T, group *testGroup) {
random := group.id.NewScalar().Random()
other := group.id.NewScalar()
other.Set(random)
testScalarCopySet(t, random, other)
})
}

func TestScalar_Arithmetic(t *testing.T) {
testAll(t, func(t2 *testing.T, group *testGroup) {
scalarTestZero(t, group.id)
Expand Down
1 change: 1 addition & 0 deletions tests/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ func TestEncoding(t *testing.T) {
scalar := group.id.NewScalar().Random()
testEncoding(t, scalar, group.id.NewScalar())

scalar = group.id.NewScalar().Random()
element := group.id.Base().Multiply(scalar)
testEncoding(t, element, group.id.NewElement())
})
Expand Down

0 comments on commit 0f414c4

Please sign in to comment.