Skip to content

Commit

Permalink
Fix decoding of heterogeneous enums.
Browse files Browse the repository at this point in the history
  • Loading branch information
andreibancioiu committed May 16, 2024
1 parent 821793e commit fc4d051
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 79 deletions.
11 changes: 9 additions & 2 deletions abi/codecForEnum.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@ package abi

import (
"bytes"
"errors"
"fmt"
"io"
)

// EnumValue is an enum (discriminant and fields)
type EnumValue struct {
Discriminant uint8
Fields []Field
Discriminant uint8
Fields []Field
FieldsProvider func(uint8) []Field
}

// EncodeNested encodes the value in the nested form
Expand Down Expand Up @@ -42,13 +44,18 @@ func (value *EnumValue) EncodeTopLevel(writer io.Writer) error {

// DecodeNested decodes the value from the nested form
func (value *EnumValue) DecodeNested(reader io.Reader) error {
if value.FieldsProvider == nil {
return errors.New("cannot decode enum: fields provider is nil")
}

discriminant := &U8Value{}
err := discriminant.DecodeNested(reader)
if err != nil {
return err
}

value.Discriminant = discriminant.Value
value.Fields = value.FieldsProvider(value.Discriminant)

for _, field := range value.Fields {
err := field.Value.DecodeNested(reader)
Expand Down
149 changes: 95 additions & 54 deletions abi/codecForEnum_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package abi

import (
"encoding/hex"
"testing"

"github.com/stretchr/testify/require"
)

func TestEnumValue(t *testing.T) {
Expand Down Expand Up @@ -69,89 +72,127 @@ func TestEnumValue(t *testing.T) {
)
})

t.Run("should decode nested", func(t *testing.T) {
testDecodeNested(t, codec,
"00",
&EnumValue{},
&EnumValue{
Discriminant: 0x00,
t.Run("should decode nested (simple)", func(t *testing.T) {
data, _ := hex.DecodeString("2a")

destination := &EnumValue{
FieldsProvider: func(discriminant uint8) []Field {
return nil
},
)
}

testDecodeNested(t, codec,
"2a",
&EnumValue{},
&EnumValue{
Discriminant: 42,
err := codec.DecodeNested(data, destination)
require.NoError(t, err)
require.Equal(t, uint8(42), destination.Discriminant)
require.Empty(t, destination.Fields)
})

t.Run("should decode nested (simple, zero)", func(t *testing.T) {
data, _ := hex.DecodeString("00")

destination := &EnumValue{
FieldsProvider: func(discriminant uint8) []Field {
return nil
},
)
}

testDecodeNested(t, codec,
"01014142",
&EnumValue{
Fields: []Field{
err := codec.DecodeNested(data, destination)
require.NoError(t, err)
require.Equal(t, uint8(0), destination.Discriminant)
require.Empty(t, destination.Fields)
})

t.Run("should decode nested (heterogeneous)", func(t *testing.T) {
data, _ := hex.DecodeString("01014142")

destination := &EnumValue{
FieldsProvider: func(discriminant uint8) []Field {
return []Field{
{
Value: &U8Value{},
},
{
Value: &U16Value{},
},
},
}
},
&EnumValue{
Discriminant: 0x01,
Fields: []Field{
{
Value: &U8Value{Value: 0x01},
},
{
Value: &U16Value{Value: 0x4142},
},
}

err := codec.DecodeNested(data, destination)
require.NoError(t, err)
require.Equal(t, uint8(1), destination.Discriminant)
require.Equal(t,
[]Field{
{
Value: &U8Value{Value: 0x01},
},
{
Value: &U16Value{Value: 0x4142},
},
},
destination.Fields,
)
})

t.Run("should decode top-level", func(t *testing.T) {
testDecodeTopLevel(t, codec,
"",
&EnumValue{},
&EnumValue{
Discriminant: 0x00,
t.Run("should decode top-level (simple)", func(t *testing.T) {
data, _ := hex.DecodeString("2a")

destination := &EnumValue{
FieldsProvider: func(discriminant uint8) []Field {
return nil
},
)
}

testDecodeTopLevel(t, codec,
"2a",
&EnumValue{},
&EnumValue{
Discriminant: 42,
err := codec.DecodeTopLevel(data, destination)
require.NoError(t, err)
require.Equal(t, uint8(42), destination.Discriminant)
require.Empty(t, destination.Fields)
})

t.Run("should decode top-level (simple, zero)", func(t *testing.T) {
data, _ := hex.DecodeString("")

destination := &EnumValue{
FieldsProvider: func(discriminant uint8) []Field {
return nil
},
)
}

testDecodeTopLevel(t, codec,
"01014142",
&EnumValue{
Fields: []Field{
err := codec.DecodeTopLevel(data, destination)
require.NoError(t, err)
require.Equal(t, uint8(0), destination.Discriminant)
require.Empty(t, destination.Fields)
})

t.Run("should decode top-level (heterogeneous)", func(t *testing.T) {
data, _ := hex.DecodeString("01014142")

destination := &EnumValue{
FieldsProvider: func(discriminant uint8) []Field {
return []Field{
{
Value: &U8Value{},
},
{
Value: &U16Value{},
},
},
}
},
&EnumValue{
Discriminant: 0x01,
Fields: []Field{
{
Value: &U8Value{Value: 0x01},
},
{
Value: &U16Value{Value: 0x4142},
},
}

err := codec.DecodeNested(data, destination)
require.NoError(t, err)
require.Equal(t, uint8(1), destination.Discriminant)
require.Equal(t,
[]Field{
{
Value: &U8Value{Value: 0x01},
},
{
Value: &U16Value{Value: 0x4142},
},
},
destination.Fields,
)
})
}
52 changes: 29 additions & 23 deletions abi/serializer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -552,29 +552,35 @@ func TestSerializer_InRealWorldScenarios(t *testing.T) {
}

action := &EnumValue{
Fields: []Field{
{
Name: "to",
Value: actionTo,
},
{
Name: "egld_amount",
Value: actionEgldAmount,
},
{
Name: "opt_gas_limit",
Value: &OptionValue{
Value: actionGasLimit,
},
},
{
Name: "endpoint_name",
Value: actionEndpointName,
},
{
Name: "arguments",
Value: actionArguments,
},
FieldsProvider: func(discriminant uint8) []Field {
if discriminant == 5 {
return []Field{
{
Name: "to",
Value: actionTo,
},
{
Name: "egld_amount",
Value: actionEgldAmount,
},
{
Name: "opt_gas_limit",
Value: &OptionValue{
Value: actionGasLimit,
},
},
{
Name: "endpoint_name",
Value: actionEndpointName,
},
{
Name: "arguments",
Value: actionArguments,
},
}
}

return nil
},
}

Expand Down

0 comments on commit fc4d051

Please sign in to comment.