Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement encoding.Binary{Marshaler, Unmarshaler} #41

Merged
merged 1 commit into from Aug 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 24 additions & 0 deletions int.go
@@ -1,6 +1,7 @@
package safenum

import (
"errors"
"math/big"
"math/bits"
)
Expand Down Expand Up @@ -33,6 +34,29 @@ func (z *Int) SetBytes(data []byte) *Int {
return z
}

// MarshalBinary implements encoding.BinaryMarshaler.
// The retrned byte slice is always of length 1 + len(i.Abs().Bytes()),
// where the first byte encodes the sign.
func (i *Int) MarshalBinary() ([]byte, error) {
length := 1 + (i.abs.announced+7)/8
out := make([]byte, length)
out[0] = byte(i.sign)
i.abs.FillBytes(out[1:])
return out, nil
}

// UnmarshalBinary implements encoding.BinaryUnmarshaler.
// Returns an error when the length of data is 0,
// since we always expect the first byte to encode the sign.
func (i *Int) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return errors.New("data must contain a sign byte")
}
i.abs.SetBytes(data[1:])
i.sign = Choice(data[0] & 1)
return nil
}

// SetUint64 sets the value of z to x.
//
// This number will be positive.
Expand Down
42 changes: 42 additions & 0 deletions int_test.go
@@ -1,6 +1,7 @@
package safenum

import (
"bytes"
"math/rand"
"reflect"
"testing"
Expand Down Expand Up @@ -156,3 +157,44 @@ func TestIntAddExamples(t *testing.T) {
t.Errorf("%+v != %+v", expected, actual)
}
}

func testIntMarshalBinaryRoundTrip(x *Int) bool {
out, err := x.MarshalBinary()
if err != nil {
return false
}
y := new(Int)
err = y.UnmarshalBinary(out)
if err != nil {
return false
}
return x.Eq(y) == 1
}

func TestIntMarshalBinaryRoundTrip(t *testing.T) {
err := quick.Check(testIntMarshalBinaryRoundTrip, &quick.Config{})
if err != nil {
t.Error(err)
}
}

func testInvalidInt(expected []byte) bool {
x := new(Int)
err := x.UnmarshalBinary(expected)
// empty slice is invalid, so we expect an error
if len(expected) == 0 {
return err != nil
}
expectedBytes := expected[1:]
expectedSign := Choice(expected[0]) & 1
actualBytes := x.Abs().Bytes()
actualSign := x.sign
return (expectedSign == actualSign) && bytes.Equal(expectedBytes, actualBytes)
}

func TestInvalidInt(t *testing.T) {
err := quick.Check(testInvalidInt, &quick.Config{})
if err != nil {
t.Error(err)
}
}
25 changes: 25 additions & 0 deletions num.go
Expand Up @@ -377,6 +377,19 @@ func (z *Nat) Bytes() []byte {
return z.FillBytes(out)
}

// MarshalBinary implements encoding.BinaryMarshaler.
// Returns the same value as Bytes().
func (i *Nat) MarshalBinary() ([]byte, error) {
return i.Bytes(), nil
}

// UnmarshalBinary implements encoding.BinaryUnmarshaler.
// Wraps SetBytes
func (i *Nat) UnmarshalBinary(data []byte) error {
i.SetBytes(data)
return nil
}

// convert a 4 bit value into an ASCII value in constant time
func nibbletoASCII(nibble byte) byte {
w := Word(nibble)
Expand Down Expand Up @@ -665,6 +678,18 @@ func (m *Modulus) Bytes() []byte {
return m.nat.Bytes()
}

// MarshalBinary implements encoding.BinaryMarshaler.
func (i *Modulus) MarshalBinary() ([]byte, error) {
return i.nat.Bytes(), nil
}

// UnmarshalBinary implements encoding.BinaryUnmarshaler.
func (i *Modulus) UnmarshalBinary(data []byte) error {
i.nat.SetBytes(data)
i.precomputeValues()
return nil
}

// Big returns the value of this Modulus as a big.Int
func (m *Modulus) Big() *big.Int {
return m.nat.Big()
Expand Down
41 changes: 41 additions & 0 deletions num_test.go
Expand Up @@ -79,6 +79,47 @@ func TestSetBytesRoundTrip(t *testing.T) {
}
}

func testNatMarshalBinaryRoundTrip(x Nat) bool {
out, err := x.MarshalBinary()
if err != nil {
return false
}
y := new(Nat)
err = y.UnmarshalBinary(out)
if err != nil {
return false
}
return x.Eq(y) == 1
}

func TestNatMarshalBinaryRoundTrip(t *testing.T) {
err := quick.Check(testNatMarshalBinaryRoundTrip, &quick.Config{})
if err != nil {
t.Error(err)
}
}

func testModulusMarshalBinaryRoundTrip(x Modulus) bool {
out, err := x.MarshalBinary()
if err != nil {
return false
}
y := new(Modulus)
err = y.UnmarshalBinary(out)
if err != nil {
return false
}
_, eq, _ := x.Cmp(y)
return eq == 1
}

func TestModulusMarshalBinaryRoundTrip(t *testing.T) {
err := quick.Check(testModulusMarshalBinaryRoundTrip, &quick.Config{})
if err != nil {
t.Error(err)
}
}

func testAddZeroIdentity(n Nat) bool {
if !n.checkInvariants() {
return false
Expand Down