Skip to content

Commit

Permalink
Merge pull request #16 from multiversx/composite-enums
Browse files Browse the repository at this point in the history
Refactor enums. Add "fields provider" for heterogeneous enums
  • Loading branch information
andreibancioiu committed May 30, 2024
2 parents f8168d1 + 1e8992a commit 8426314
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 104 deletions.
38 changes: 26 additions & 12 deletions abi/enumValue.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,28 @@ package abi

import (
"bytes"
"errors"
"fmt"
"io"
)

type codecForEnum struct {
generalCodec generalCodec
// EnumValue is an enum (discriminant and fields)
type EnumValue struct {
Discriminant uint8
Fields []Field
FieldsProvider func(uint8) []Field
}

func (c *codecForEnum) encodeNested(writer io.Writer, value EnumValue) error {
err := c.generalCodec.doEncodeNested(writer, U8Value{Value: value.Discriminant})
// EncodeNested encodes the value in the nested form
func (value *EnumValue) EncodeNested(writer io.Writer) error {
discriminant := U8Value{Value: value.Discriminant}
err := discriminant.EncodeNested(writer)
if err != nil {
return err
}

for _, field := range value.Fields {
err := c.generalCodec.doEncodeNested(writer, field.Value)
err := field.Value.EncodeNested(writer)
if err != nil {
return fmt.Errorf("cannot encode field '%s' of enum, because of: %w", field.Name, err)
}
Expand All @@ -26,26 +32,33 @@ func (c *codecForEnum) encodeNested(writer io.Writer, value EnumValue) error {
return nil
}

func (c *codecForEnum) encodeTopLevel(writer io.Writer, value EnumValue) error {
// EncodeTopLevel encodes the value in the top-level form
func (value *EnumValue) EncodeTopLevel(writer io.Writer) error {
if value.Discriminant == 0 && len(value.Fields) == 0 {
// Write nothing
return nil
}

return c.encodeNested(writer, value)
return value.EncodeNested(writer)
}

func (c *codecForEnum) decodeNested(reader io.Reader, value *EnumValue) error {
// DecodeNested decodes the value from the nested form
func (value *EnumValue) DecodeNested(reader io.Reader) error {
if value.FieldsProvider == nil {
return errors.New("cannot decode enum: fields provider is nil")
}

discriminant := &U8Value{}
err := c.generalCodec.doDecodeNested(reader, discriminant)
err := discriminant.DecodeNested(reader)
if err != nil {
return err
}

value.Discriminant = discriminant.Value
value.Fields = value.FieldsProvider(value.Discriminant)

for _, field := range value.Fields {
err := c.generalCodec.doDecodeNested(reader, field.Value)
err := field.Value.DecodeNested(reader)
if err != nil {
return fmt.Errorf("cannot decode field '%s' of enum, because of: %w", field.Name, err)
}
Expand All @@ -54,12 +67,13 @@ func (c *codecForEnum) decodeNested(reader io.Reader, value *EnumValue) error {
return nil
}

func (c *codecForEnum) decodeTopLevel(data []byte, value *EnumValue) error {
// DecodeTopLevel decodes the value from the top-level form
func (value *EnumValue) DecodeTopLevel(data []byte) error {
if len(data) == 0 {
value.Discriminant = 0
return nil
}

reader := bytes.NewReader(data)
return c.decodeNested(reader, value)
return value.DecodeNested(reader)
}
175 changes: 107 additions & 68 deletions abi/enumValue_test.go
Original file line number Diff line number Diff line change
@@ -1,38 +1,39 @@
package abi

import (
"encoding/hex"
"testing"

"github.com/stretchr/testify/require"
)

func TestCodecForEnum(t *testing.T) {
codec, _ := newCodec(argsNewCodec{
pubKeyLength: 32,
})
func TestEnumValue(t *testing.T) {
codec := &codec{}

t.Run("should encode nested", func(t *testing.T) {
testEncodeNested(t, codec,
EnumValue{
&EnumValue{
Discriminant: 0,
},
"00",
)

testEncodeNested(t, codec,
EnumValue{
&EnumValue{
Discriminant: 42,
},
"2a",
)

testEncodeNested(t, codec,
EnumValue{
&EnumValue{
Discriminant: 42,
Fields: []Field{
{
Value: U8Value{Value: 0x01},
Value: &U8Value{Value: 0x01},
},
{
Value: U16Value{Value: 0x4142},
Value: &U16Value{Value: 0x4142},
},
},
},
Expand All @@ -42,118 +43,156 @@ func TestCodecForEnum(t *testing.T) {

t.Run("should encode top-level", func(t *testing.T) {
testEncodeTopLevel(t, codec,
EnumValue{
&EnumValue{
Discriminant: 0,
},
"",
)

testEncodeTopLevel(t, codec,
EnumValue{
&EnumValue{
Discriminant: 42,
},
"2a",
)

testEncodeTopLevel(t, codec,
EnumValue{
&EnumValue{
Discriminant: 42,
Fields: []Field{
{
Value: U8Value{Value: 0x01},
Value: &U8Value{Value: 0x01},
},
{
Value: U16Value{Value: 0x4142},
Value: &U16Value{Value: 0x4142},
},
},
},
"2a014142",
)
})

t.Run("should decode nested", func(t *testing.T) {
testDecodeNested(t, codec,
"00",
&EnumValue{},
&EnumValue{
Discriminant: 0x00,
t.Run("should decode nested (simple)", func(t *testing.T) {
data, _ := hex.DecodeString("2a")

destination := &EnumValue{
FieldsProvider: func(discriminant uint8) []Field {
return nil
},
)
}

testDecodeNested(t, codec,
"2a",
&EnumValue{},
&EnumValue{
Discriminant: 42,
err := codec.DecodeNested(data, destination)
require.NoError(t, err)
require.Equal(t, uint8(42), destination.Discriminant)
require.Empty(t, destination.Fields)
})

t.Run("should decode nested (simple, zero)", func(t *testing.T) {
data, _ := hex.DecodeString("00")

destination := &EnumValue{
FieldsProvider: func(discriminant uint8) []Field {
return nil
},
)
}

testDecodeNested(t, codec,
"01014142",
&EnumValue{
Fields: []Field{
err := codec.DecodeNested(data, destination)
require.NoError(t, err)
require.Equal(t, uint8(0), destination.Discriminant)
require.Empty(t, destination.Fields)
})

t.Run("should decode nested (heterogeneous)", func(t *testing.T) {
data, _ := hex.DecodeString("01014142")

destination := &EnumValue{
FieldsProvider: func(discriminant uint8) []Field {
return []Field{
{
Value: &U8Value{},
},
{
Value: &U16Value{},
},
},
}
},
&EnumValue{
Discriminant: 0x01,
Fields: []Field{
{
Value: &U8Value{Value: 0x01},
},
{
Value: &U16Value{Value: 0x4142},
},
}

err := codec.DecodeNested(data, destination)
require.NoError(t, err)
require.Equal(t, uint8(1), destination.Discriminant)
require.Equal(t,
[]Field{
{
Value: &U8Value{Value: 0x01},
},
{
Value: &U16Value{Value: 0x4142},
},
},
destination.Fields,
)
})

t.Run("should decode top-level", func(t *testing.T) {
testDecodeTopLevel(t, codec,
"",
&EnumValue{},
&EnumValue{
Discriminant: 0x00,
t.Run("should decode top-level (simple)", func(t *testing.T) {
data, _ := hex.DecodeString("2a")

destination := &EnumValue{
FieldsProvider: func(discriminant uint8) []Field {
return nil
},
)
}

testDecodeTopLevel(t, codec,
"2a",
&EnumValue{},
&EnumValue{
Discriminant: 42,
err := codec.DecodeTopLevel(data, destination)
require.NoError(t, err)
require.Equal(t, uint8(42), destination.Discriminant)
require.Empty(t, destination.Fields)
})

t.Run("should decode top-level (simple, zero)", func(t *testing.T) {
data, _ := hex.DecodeString("")

destination := &EnumValue{
FieldsProvider: func(discriminant uint8) []Field {
return nil
},
)
}

testDecodeTopLevel(t, codec,
"01014142",
&EnumValue{
Fields: []Field{
err := codec.DecodeTopLevel(data, destination)
require.NoError(t, err)
require.Equal(t, uint8(0), destination.Discriminant)
require.Empty(t, destination.Fields)
})

t.Run("should decode top-level (heterogeneous)", func(t *testing.T) {
data, _ := hex.DecodeString("01014142")

destination := &EnumValue{
FieldsProvider: func(discriminant uint8) []Field {
return []Field{
{
Value: &U8Value{},
},
{
Value: &U16Value{},
},
},
}
},
&EnumValue{
Discriminant: 0x01,
Fields: []Field{
{
Value: &U8Value{Value: 0x01},
},
{
Value: &U16Value{Value: 0x4142},
},
}

err := codec.DecodeNested(data, destination)
require.NoError(t, err)
require.Equal(t, uint8(1), destination.Discriminant)
require.Equal(t,
[]Field{
{
Value: &U8Value{Value: 0x01},
},
{
Value: &U16Value{Value: 0x4142},
},
},
destination.Fields,
)
})
}
Loading

0 comments on commit 8426314

Please sign in to comment.