Skip to content

Commit

Permalink
Refactoring: move codec logic into struct values (composite pattern w…
Browse files Browse the repository at this point in the history
…ip).
  • Loading branch information
andreibancioiu committed May 15, 2024
1 parent 3e3e0ff commit b41b39b
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 152 deletions.
74 changes: 24 additions & 50 deletions abi/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,7 @@ import (
)

type codec struct {
codecForBool *codecForBool
codecForSmallInt *codecForSmallInt
codecForBigInt *codecForBigInt
codecForAddress *codecForAddress
codecForString *codecForString
codecForBytes *codecForBytes
codecForStruct *codeForStruct
codecForEnum *codecForEnum
codecForOption *codecForOption
Expand All @@ -34,14 +29,7 @@ func newCodec(args argsNewCodec) (*codec, error) {
}

codec := &codec{
codecForBool: &codecForBool{},
codecForSmallInt: &codecForSmallInt{},
codecForBigInt: &codecForBigInt{},
codecForAddress: &codecForAddress{
pubKeyLength: args.pubKeyLength,
},
codecForString: &codecForString{},
codecForBytes: &codecForBytes{},
}

codec.codecForStruct = &codeForStruct{
Expand Down Expand Up @@ -77,7 +65,7 @@ func (c *codec) EncodeNested(value any) ([]byte, error) {
func (c *codec) doEncodeNested(writer io.Writer, value any) error {
switch value := value.(type) {
case BoolValue:
return c.codecForBool.encodeNested(writer, value)
return value.encodeNested(writer)
case U8Value:
return c.codecForSmallInt.encodeNested(writer, value.Value, 1)
case U16Value:
Expand All @@ -95,15 +83,15 @@ func (c *codec) doEncodeNested(writer io.Writer, value any) error {
case I64Value:
return c.codecForSmallInt.encodeNested(writer, value.Value, 8)
case BigUIntValue:
return c.codecForBigInt.encodeNestedUnsigned(writer, value.Value)
return value.encodeNested(writer)
case BigIntValue:
return c.codecForBigInt.encodeNestedSigned(writer, value.Value)
return value.encodeNested(writer)
case AddressValue:
return c.codecForAddress.encodeNested(writer, value)
return value.encodeNested(writer)
case StringValue:
return c.codecForString.encodeNested(writer, value)
return value.encodeNested(writer)
case BytesValue:
return c.codecForBytes.encodeNested(writer, value)
return value.encodeNested(writer)
case StructValue:
return c.codecForStruct.encodeNested(writer, value)
case EnumValue:
Expand Down Expand Up @@ -131,7 +119,7 @@ func (c *codec) EncodeTopLevel(value any) ([]byte, error) {
func (c *codec) doEncodeTopLevel(writer io.Writer, value any) error {
switch value := value.(type) {
case BoolValue:
return c.codecForBool.encodeTopLevel(writer, value)
return value.encodeTopLevel(writer)
case U8Value:
return c.codecForSmallInt.encodeTopLevelUnsigned(writer, uint64(value.Value))
case U16Value:
Expand All @@ -149,15 +137,15 @@ func (c *codec) doEncodeTopLevel(writer io.Writer, value any) error {
case I64Value:
return c.codecForSmallInt.encodeTopLevelSigned(writer, value.Value)
case BigUIntValue:
return c.codecForBigInt.encodeTopLevelUnsigned(writer, value.Value)
return value.encodeTopLevel(writer)
case BigIntValue:
return c.codecForBigInt.encodeTopLevelSigned(writer, value.Value)
return value.encodeTopLevel(writer)
case AddressValue:
return c.codecForAddress.encodeTopLevel(writer, value)
return value.encodeTopLevel(writer)
case StringValue:
return c.codecForString.encodeTopLevel(writer, value)
return value.encodeTopLevel(writer)
case BytesValue:
return c.codecForBytes.encodeTopLevel(writer, value)
return value.encodeTopLevel(writer)
case StructValue:
return c.codecForStruct.encodeTopLevel(writer, value)
case EnumValue:
Expand Down Expand Up @@ -185,7 +173,7 @@ func (c *codec) DecodeNested(data []byte, value any) error {
func (c *codec) doDecodeNested(reader io.Reader, value any) error {
switch value := value.(type) {
case *BoolValue:
return c.codecForBool.decodeNested(reader, value)
return value.decodeNested(reader)
case *U8Value:
return c.codecForSmallInt.decodeNested(reader, &value.Value, 1)
case *U16Value:
Expand All @@ -203,27 +191,15 @@ func (c *codec) doDecodeNested(reader io.Reader, value any) error {
case *I64Value:
return c.codecForSmallInt.decodeNested(reader, &value.Value, 8)
case *BigUIntValue:
n, err := c.codecForBigInt.decodeNestedUnsigned(reader)
if err != nil {
return err
}

value.Value = n
return nil
return value.decodeNested(reader)
case *BigIntValue:
n, err := c.codecForBigInt.decodeNestedSigned(reader)
if err != nil {
return err
}

value.Value = n
return nil
return value.decodeNested(reader)
case *AddressValue:
return c.codecForAddress.decodeNested(reader, value)
return value.decodeNested(reader)
case *StringValue:
return c.codecForString.decodeNested(reader, value)
return value.decodeNested(reader)
case *BytesValue:
return c.codecForBytes.decodeNested(reader, value)
return value.decodeNested(reader)
case *StructValue:
return c.codecForStruct.decodeNested(reader, value)
case *EnumValue:
Expand All @@ -250,7 +226,7 @@ func (c *codec) DecodeTopLevel(data []byte, value any) error {
func (c *codec) doDecodeTopLevel(data []byte, value any) error {
switch value := value.(type) {
case *BoolValue:
return c.codecForBool.decodeTopLevel(data, value)
return value.decodeTopLevel(data)
case *U8Value:
n, err := c.codecForSmallInt.decodeTopLevelUnsigned(data, math.MaxUint8)
if err != nil {
Expand Down Expand Up @@ -309,17 +285,15 @@ func (c *codec) doDecodeTopLevel(data []byte, value any) error {

value.Value = int64(n)
case *BigUIntValue:
n := c.codecForBigInt.decodeTopLevelUnsigned(data)
value.Value = n
value.decodeTopLevel(data)
case *BigIntValue:
n := c.codecForBigInt.decodeTopLevelSigned(data)
value.Value = n
value.decodeTopLevel(data)
case *AddressValue:
return c.codecForAddress.decodeTopLevel(data, value)
return value.decodeTopLevel(data)
case *StringValue:
return c.codecForString.decodeTopLevel(data, value)
return value.decodeTopLevel(data)
case *BytesValue:
return c.codecForBytes.decodeTopLevel(data, value)
return value.decodeTopLevel(data)
case *StructValue:
return c.codecForStruct.decodeTopLevel(data, value)
case *EnumValue:
Expand Down
25 changes: 13 additions & 12 deletions abi/codecForAddress.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import (
"io"
)

type codecForAddress struct {
pubKeyLength int
// AddressValue is a wrapper for an address
type AddressValue struct {
Value []byte
}

func (c *codecForAddress) encodeNested(writer io.Writer, value AddressValue) error {
err := c.checkPubKeyLength(value.Value)
func (value *AddressValue) encodeNested(writer io.Writer) error {
err := value.checkPubKeyLength(value.Value)
if err != nil {
return err
}
Expand All @@ -19,12 +20,12 @@ func (c *codecForAddress) encodeNested(writer io.Writer, value AddressValue) err
return err
}

func (c *codecForAddress) encodeTopLevel(writer io.Writer, value AddressValue) error {
return c.encodeNested(writer, value)
func (value *AddressValue) encodeTopLevel(writer io.Writer) error {
return value.encodeNested(writer)
}

func (c *codecForAddress) decodeNested(reader io.Reader, value *AddressValue) error {
data, err := readBytesExactly(reader, c.pubKeyLength)
func (value *AddressValue) decodeNested(reader io.Reader) error {
data, err := readBytesExactly(reader, pubKeyLength)
if err != nil {
return err
}
Expand All @@ -33,8 +34,8 @@ func (c *codecForAddress) decodeNested(reader io.Reader, value *AddressValue) er
return nil
}

func (c *codecForAddress) decodeTopLevel(data []byte, value *AddressValue) error {
err := c.checkPubKeyLength(data)
func (value *AddressValue) decodeTopLevel(data []byte) error {
err := value.checkPubKeyLength(data)
if err != nil {
return err
}
Expand All @@ -43,8 +44,8 @@ func (c *codecForAddress) decodeTopLevel(data []byte, value *AddressValue) error
return nil
}

func (c *codecForAddress) checkPubKeyLength(pubkey []byte) error {
if len(pubkey) != c.pubKeyLength {
func (value *AddressValue) checkPubKeyLength(pubkey []byte) error {
if len(pubkey) != pubKeyLength {
return fmt.Errorf("public key (address) has invalid length: %d", len(pubkey))
}

Expand Down
89 changes: 49 additions & 40 deletions abi/codecForBigInt.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ import (
twos "github.com/multiversx/mx-components-big-int/twos-complement"
)

type codecForBigInt struct {
// BigUIntValue is a wrapper for a big integer (unsigned)
type BigUIntValue struct {
Value *big.Int
}

func (c *codecForBigInt) encodeNestedUnsigned(writer io.Writer, value *big.Int) error {
data := value.Bytes()
func (value *BigUIntValue) encodeNested(writer io.Writer) error {
data := value.Value.Bytes()
dataLength := len(data)

// Write the length of the payload
Expand All @@ -29,81 +31,88 @@ func (c *codecForBigInt) encodeNestedUnsigned(writer io.Writer, value *big.Int)
return nil
}

func (c *codecForBigInt) encodeNestedSigned(writer io.Writer, value *big.Int) error {
data := twos.ToBytes(value)
dataLength := len(data)
func (value *BigUIntValue) encodeTopLevel(writer io.Writer) error {
data := value.Value.Bytes()
_, err := writer.Write(data)
if err != nil {
return err
}

// Write the length of the payload
err := encodeLength(writer, uint32(dataLength))
return nil
}

func (value *BigUIntValue) decodeNested(reader io.Reader) error {
// Read the length of the payload
length, err := decodeLength(reader)
if err != nil {
return err
}

// Write the payload
_, err = writer.Write(data)
// Read the payload
data, err := readBytesExactly(reader, int(length))
if err != nil {
return err
}

value.Value = big.NewInt(0).SetBytes(data)
return nil
}

func (c *codecForBigInt) encodeTopLevelUnsigned(writer io.Writer, value *big.Int) error {
data := value.Bytes()
_, err := writer.Write(data)
func (value *BigUIntValue) decodeTopLevel(data []byte) {
value.Value = big.NewInt(0).SetBytes(data)
}

// BigIntValue is a wrapper for a big integer (signed)
type BigIntValue struct {
Value *big.Int
}

func (value *BigIntValue) encodeNested(writer io.Writer) error {
data := twos.ToBytes(value.Value)
dataLength := len(data)

// Write the length of the payload
err := encodeLength(writer, uint32(dataLength))
if err != nil {
return err
}

return nil
}

func (c *codecForBigInt) encodeTopLevelSigned(writer io.Writer, value *big.Int) error {
data := twos.ToBytes(value)
_, err := writer.Write(data)
// Write the payload
_, err = writer.Write(data)
if err != nil {
return err
}

return nil
}

func (c *codecForBigInt) decodeNestedUnsigned(reader io.Reader) (*big.Int, error) {
// Read the length of the payload
length, err := decodeLength(reader)
if err != nil {
return nil, err
}

// Read the payload
data, err := readBytesExactly(reader, int(length))
func (value *BigIntValue) encodeTopLevel(writer io.Writer) error {
data := twos.ToBytes(value.Value)
_, err := writer.Write(data)
if err != nil {
return nil, err
return err
}

return big.NewInt(0).SetBytes(data), nil
return nil
}

func (c *codecForBigInt) decodeNestedSigned(reader io.Reader) (*big.Int, error) {
func (value *BigIntValue) decodeNested(reader io.Reader) error {
// Read the length of the payload
length, err := decodeLength(reader)
if err != nil {
return nil, err
return err
}

// Read the payload
data, err := readBytesExactly(reader, int(length))
if err != nil {
return nil, err
return err
}

return twos.FromBytes(data), nil
}

func (c *codecForBigInt) decodeTopLevelUnsigned(data []byte) *big.Int {
return big.NewInt(0).SetBytes(data)
value.Value = twos.FromBytes(data)
return nil
}

func (c *codecForBigInt) decodeTopLevelSigned(data []byte) *big.Int {
return twos.FromBytes(data)
func (value *BigIntValue) decodeTopLevel(data []byte) {
value.Value = twos.FromBytes(data)
}
Loading

0 comments on commit b41b39b

Please sign in to comment.