From aac5478cdf44076141ad8ab1a7f1491f42728ac6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrei=20B=C4=83ncioiu?= Date: Thu, 30 May 2024 13:34:22 +0300 Subject: [PATCH 1/2] Refactor enums. Add "fields provider" for heterogeneous enums. --- abi/enumValue.go | 38 ++++++--- abi/enumValue_test.go | 175 ++++++++++++++++++++++++++---------------- 2 files changed, 133 insertions(+), 80 deletions(-) diff --git a/abi/enumValue.go b/abi/enumValue.go index 5bfd29d..f4ab90f 100644 --- a/abi/enumValue.go +++ b/abi/enumValue.go @@ -2,22 +2,28 @@ package abi import ( "bytes" + "errors" "fmt" "io" ) -type codecForEnum struct { - generalCodec generalCodec +// EnumValue is an enum (discriminant and fields) +type EnumValue struct { + Discriminant uint8 + Fields []Field + FieldsProvider func(uint8) []Field } -func (c *codecForEnum) encodeNested(writer io.Writer, value EnumValue) error { - err := c.generalCodec.doEncodeNested(writer, U8Value{Value: value.Discriminant}) +// EncodeNested encodes the value in the nested form +func (value *EnumValue) EncodeNested(writer io.Writer) error { + discriminant := U8Value{Value: value.Discriminant} + err := discriminant.EncodeNested(writer) if err != nil { return err } for _, field := range value.Fields { - err := c.generalCodec.doEncodeNested(writer, field.Value) + err := field.Value.EncodeNested(writer) if err != nil { return fmt.Errorf("cannot encode field '%s' of enum, because of: %w", field.Name, err) } @@ -26,26 +32,33 @@ func (c *codecForEnum) encodeNested(writer io.Writer, value EnumValue) error { return nil } -func (c *codecForEnum) encodeTopLevel(writer io.Writer, value EnumValue) error { +// EncodeTopLevel encodes the value in the top-level form +func (value *EnumValue) EncodeTopLevel(writer io.Writer) error { if value.Discriminant == 0 && len(value.Fields) == 0 { // Write nothing return nil } - return c.encodeNested(writer, value) + return value.EncodeNested(writer) } -func (c *codecForEnum) decodeNested(reader io.Reader, value *EnumValue) error { +// DecodeNested decodes the value from the nested form +func (value *EnumValue) DecodeNested(reader io.Reader) error { + if value.FieldsProvider == nil { + return errors.New("cannot decode enum: fields provider is nil") + } + discriminant := &U8Value{} - err := c.generalCodec.doDecodeNested(reader, discriminant) + err := discriminant.DecodeNested(reader) if err != nil { return err } value.Discriminant = discriminant.Value + value.Fields = value.FieldsProvider(value.Discriminant) for _, field := range value.Fields { - err := c.generalCodec.doDecodeNested(reader, field.Value) + err := field.Value.DecodeNested(reader) if err != nil { return fmt.Errorf("cannot decode field '%s' of enum, because of: %w", field.Name, err) } @@ -54,12 +67,13 @@ func (c *codecForEnum) decodeNested(reader io.Reader, value *EnumValue) error { return nil } -func (c *codecForEnum) decodeTopLevel(data []byte, value *EnumValue) error { +// DecodeTopLevel decodes the value from the top-level form +func (value *EnumValue) DecodeTopLevel(data []byte) error { if len(data) == 0 { value.Discriminant = 0 return nil } reader := bytes.NewReader(data) - return c.decodeNested(reader, value) + return value.DecodeNested(reader) } diff --git a/abi/enumValue_test.go b/abi/enumValue_test.go index f4a3dea..b298ab4 100644 --- a/abi/enumValue_test.go +++ b/abi/enumValue_test.go @@ -1,38 +1,39 @@ package abi import ( + "encoding/hex" "testing" + + "github.com/stretchr/testify/require" ) -func TestCodecForEnum(t *testing.T) { - codec, _ := newCodec(argsNewCodec{ - pubKeyLength: 32, - }) +func TestEnumValue(t *testing.T) { + codec := &codec{} t.Run("should encode nested", func(t *testing.T) { testEncodeNested(t, codec, - EnumValue{ + &EnumValue{ Discriminant: 0, }, "00", ) testEncodeNested(t, codec, - EnumValue{ + &EnumValue{ Discriminant: 42, }, "2a", ) testEncodeNested(t, codec, - EnumValue{ + &EnumValue{ Discriminant: 42, Fields: []Field{ { - Value: U8Value{Value: 0x01}, + Value: &U8Value{Value: 0x01}, }, { - Value: U16Value{Value: 0x4142}, + Value: &U16Value{Value: 0x4142}, }, }, }, @@ -42,28 +43,28 @@ func TestCodecForEnum(t *testing.T) { t.Run("should encode top-level", func(t *testing.T) { testEncodeTopLevel(t, codec, - EnumValue{ + &EnumValue{ Discriminant: 0, }, "", ) testEncodeTopLevel(t, codec, - EnumValue{ + &EnumValue{ Discriminant: 42, }, "2a", ) testEncodeTopLevel(t, codec, - EnumValue{ + &EnumValue{ Discriminant: 42, Fields: []Field{ { - Value: U8Value{Value: 0x01}, + Value: &U8Value{Value: 0x01}, }, { - Value: U16Value{Value: 0x4142}, + Value: &U16Value{Value: 0x4142}, }, }, }, @@ -71,89 +72,127 @@ func TestCodecForEnum(t *testing.T) { ) }) - t.Run("should decode nested", func(t *testing.T) { - testDecodeNested(t, codec, - "00", - &EnumValue{}, - &EnumValue{ - Discriminant: 0x00, + t.Run("should decode nested (simple)", func(t *testing.T) { + data, _ := hex.DecodeString("2a") + + destination := &EnumValue{ + FieldsProvider: func(discriminant uint8) []Field { + return nil }, - ) + } - testDecodeNested(t, codec, - "2a", - &EnumValue{}, - &EnumValue{ - Discriminant: 42, + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, uint8(42), destination.Discriminant) + require.Empty(t, destination.Fields) + }) + + t.Run("should decode nested (simple, zero)", func(t *testing.T) { + data, _ := hex.DecodeString("00") + + destination := &EnumValue{ + FieldsProvider: func(discriminant uint8) []Field { + return nil }, - ) + } - testDecodeNested(t, codec, - "01014142", - &EnumValue{ - Fields: []Field{ + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, uint8(0), destination.Discriminant) + require.Empty(t, destination.Fields) + }) + + t.Run("should decode nested (heterogeneous)", func(t *testing.T) { + data, _ := hex.DecodeString("01014142") + + destination := &EnumValue{ + FieldsProvider: func(discriminant uint8) []Field { + return []Field{ { Value: &U8Value{}, }, { Value: &U16Value{}, }, - }, + } }, - &EnumValue{ - Discriminant: 0x01, - Fields: []Field{ - { - Value: &U8Value{Value: 0x01}, - }, - { - Value: &U16Value{Value: 0x4142}, - }, + } + + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, uint8(1), destination.Discriminant) + require.Equal(t, + []Field{ + { + Value: &U8Value{Value: 0x01}, + }, + { + Value: &U16Value{Value: 0x4142}, }, }, + destination.Fields, ) }) - t.Run("should decode top-level", func(t *testing.T) { - testDecodeTopLevel(t, codec, - "", - &EnumValue{}, - &EnumValue{ - Discriminant: 0x00, + t.Run("should decode top-level (simple)", func(t *testing.T) { + data, _ := hex.DecodeString("2a") + + destination := &EnumValue{ + FieldsProvider: func(discriminant uint8) []Field { + return nil }, - ) + } - testDecodeTopLevel(t, codec, - "2a", - &EnumValue{}, - &EnumValue{ - Discriminant: 42, + err := codec.DecodeTopLevel(data, destination) + require.NoError(t, err) + require.Equal(t, uint8(42), destination.Discriminant) + require.Empty(t, destination.Fields) + }) + + t.Run("should decode top-level (simple, zero)", func(t *testing.T) { + data, _ := hex.DecodeString("") + + destination := &EnumValue{ + FieldsProvider: func(discriminant uint8) []Field { + return nil }, - ) + } - testDecodeTopLevel(t, codec, - "01014142", - &EnumValue{ - Fields: []Field{ + err := codec.DecodeTopLevel(data, destination) + require.NoError(t, err) + require.Equal(t, uint8(0), destination.Discriminant) + require.Empty(t, destination.Fields) + }) + + t.Run("should decode top-level (heterogeneous)", func(t *testing.T) { + data, _ := hex.DecodeString("01014142") + + destination := &EnumValue{ + FieldsProvider: func(discriminant uint8) []Field { + return []Field{ { Value: &U8Value{}, }, { Value: &U16Value{}, }, - }, + } }, - &EnumValue{ - Discriminant: 0x01, - Fields: []Field{ - { - Value: &U8Value{Value: 0x01}, - }, - { - Value: &U16Value{Value: 0x4142}, - }, + } + + err := codec.DecodeNested(data, destination) + require.NoError(t, err) + require.Equal(t, uint8(1), destination.Discriminant) + require.Equal(t, + []Field{ + { + Value: &U8Value{Value: 0x01}, + }, + { + Value: &U16Value{Value: 0x4142}, }, }, + destination.Fields, ) }) } From 1e8992a2043471e0d38f181963833158f530fa8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrei=20B=C4=83ncioiu?= Date: Thu, 30 May 2024 13:37:29 +0300 Subject: [PATCH 2/2] Adjust test. --- abi/serializer_test.go | 54 +++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/abi/serializer_test.go b/abi/serializer_test.go index 184a68d..f526ab1 100644 --- a/abi/serializer_test.go +++ b/abi/serializer_test.go @@ -551,30 +551,36 @@ func TestSerializer_InRealWorldScenarios(t *testing.T) { }, } - action := EnumValue{ - Fields: []Field{ - { - Name: "to", - Value: &actionTo, - }, - { - Name: "egld_amount", - Value: &actionEgldAmount, - }, - { - Name: "opt_gas_limit", - Value: &OptionValue{ - Value: &actionGasLimit, - }, - }, - { - Name: "endpoint_name", - Value: &actionEndpointName, - }, - { - Name: "arguments", - Value: &actionArguments, - }, + action := &EnumValue{ + FieldsProvider: func(discriminant uint8) []Field { + if discriminant == 5 { + return []Field{ + { + Name: "to", + Value: actionTo, + }, + { + Name: "egld_amount", + Value: actionEgldAmount, + }, + { + Name: "opt_gas_limit", + Value: &OptionValue{ + Value: actionGasLimit, + }, + }, + { + Name: "endpoint_name", + Value: actionEndpointName, + }, + { + Name: "arguments", + Value: actionArguments, + }, + } + } + + return nil }, }