diff --git a/pkg/io/binaryWriter.go b/pkg/io/binaryWriter.go index 5ad499988f..5ec3e42fa9 100644 --- a/pkg/io/binaryWriter.go +++ b/pkg/io/binaryWriter.go @@ -11,6 +11,7 @@ import ( // from a struct with many fields. type BinWriter struct { w io.Writer + uv []byte u64 []byte u32 []byte u16 []byte @@ -20,11 +21,12 @@ type BinWriter struct { // NewBinWriterFromIO makes a BinWriter from io.Writer. func NewBinWriterFromIO(iow io.Writer) *BinWriter { - u64 := make([]byte, 8) + uv := make([]byte, 9) + u64 := uv[:8] u32 := u64[:4] u16 := u64[:2] u8 := u64[:1] - return &BinWriter{w: iow, u64: u64, u32: u32, u16: u16, u8: u8} + return &BinWriter{w: iow, uv: uv, u64: u64, u32: u32, u16: u16, u8: u8} } // WriteU64LE writes an uint64 value into the underlying io.Writer in @@ -106,23 +108,31 @@ func (w *BinWriter) WriteVarUint(val uint64) { return } + n := PutVarUint(w.uv, val) + w.WriteBytes(w.uv[:n]) +} + +// PutVarUint puts val in varint form to the pre-allocated buffer. +func PutVarUint(data []byte, val uint64) int { + _ = data[8] if val < 0xfd { - w.WriteB(byte(val)) - return + data[0] = byte(val) + return 1 } if val < 0xFFFF { - w.WriteB(byte(0xfd)) - w.WriteU16LE(uint16(val)) - return + data[0] = byte(0xfd) + binary.LittleEndian.PutUint16(data[1:], uint16(val)) + return 3 } if val < 0xFFFFFFFF { - w.WriteB(byte(0xfe)) - w.WriteU32LE(uint32(val)) - return + data[0] = byte(0xfe) + binary.LittleEndian.PutUint32(data[1:], uint32(val)) + return 5 } - w.WriteB(byte(0xff)) - w.WriteU64LE(val) + data[0] = byte(0xff) + binary.LittleEndian.PutUint64(data[1:], val) + return 9 } // WriteBytes writes a variable byte into the underlying io.Writer without prefix. diff --git a/pkg/vm/stackitem/json.go b/pkg/vm/stackitem/json.go index b930344ba5..2f6088f1ec 100644 --- a/pkg/vm/stackitem/json.go +++ b/pkg/vm/stackitem/json.go @@ -9,8 +9,6 @@ import ( gio "io" "math" "math/big" - - "github.com/nspcc-dev/neo-go/pkg/io" ) // decoder is a wrapper around json.Decoder helping to mimic C# json decoder behaviour. @@ -43,87 +41,112 @@ var ErrTooDeep = errors.New("too deep") // Array, Struct -> array // Map -> map with keys as UTF-8 bytes func ToJSON(item Item) ([]byte, error) { - buf := io.NewBufBinWriter() - toJSON(buf, item) - if buf.Err != nil { - return nil, buf.Err - } - return buf.Bytes(), nil + seen := make(map[Item]sliceNoPointer) + return toJSON(nil, seen, item) +} + +// sliceNoPointer represents sub-slice of a known slice. +// It doesn't contain pointer and uses less memory than `[]byte`. +type sliceNoPointer struct { + start, end int } -func toJSON(buf *io.BufBinWriter, item Item) { - w := buf.BinWriter - if w.Err != nil { - return - } else if buf.Len() > MaxSize { - w.Err = errTooBigSize +func toJSON(data []byte, seen map[Item]sliceNoPointer, item Item) ([]byte, error) { + if len(data) > MaxSize { + return nil, errTooBigSize + } + + if old, ok := seen[item]; ok { + if len(data)+old.end-old.start > MaxSize { + return nil, errTooBigSize + } + return append(data, data[old.start:old.end]...), nil } + + start := len(data) + var err error + switch it := item.(type) { case *Array, *Struct: - w.WriteB('[') - items := it.Value().([]Item) + var items []Item + if a, ok := it.(*Array); ok { + items = a.value + } else { + items = it.(*Struct).value + } + + data = append(data, '[') for i, v := range items { - toJSON(buf, v) + data, err = toJSON(data, seen, v) + if err != nil { + return nil, err + } if i < len(items)-1 { - w.WriteB(',') + data = append(data, ',') } } - w.WriteB(']') + data = append(data, ']') + seen[item] = sliceNoPointer{start, len(data)} case *Map: - w.WriteB('{') + data = append(data, '{') for i := range it.value { // map key can always be converted to []byte // but are not always a valid UTF-8. - writeJSONString(buf.BinWriter, it.value[i].Key) - w.WriteBytes([]byte(`:`)) - toJSON(buf, it.value[i].Value) + raw, err := itemToJSONString(it.value[i].Key) + if err != nil { + return nil, err + } + data = append(data, raw...) + data = append(data, ':') + data, err = toJSON(data, seen, it.value[i].Value) + if err != nil { + return nil, err + } if i < len(it.value)-1 { - w.WriteB(',') + data = append(data, ',') } } - w.WriteB('}') + data = append(data, '}') + seen[item] = sliceNoPointer{start, len(data)} case *BigInteger: if it.value.CmpAbs(big.NewInt(MaxAllowedInteger)) == 1 { - w.Err = fmt.Errorf("%w (MaxAllowedInteger)", ErrInvalidValue) - return + return nil, fmt.Errorf("%w (MaxAllowedInteger)", ErrInvalidValue) } - w.WriteBytes([]byte(it.value.String())) + data = append(data, it.value.String()...) case *ByteArray, *Buffer: - writeJSONString(w, it) + raw, err := itemToJSONString(it) + if err != nil { + return nil, err + } + data = append(data, raw...) case *Bool: if it.value { - w.WriteBytes([]byte("true")) + data = append(data, "true"...) } else { - w.WriteBytes([]byte("false")) + data = append(data, "false"...) } case Null: - w.WriteBytes([]byte("null")) + data = append(data, "null"...) default: - w.Err = fmt.Errorf("%w: %s", ErrUnserializable, it.String()) - return + return nil, fmt.Errorf("%w: %s", ErrUnserializable, it.String()) } - if w.Err == nil && buf.Len() > MaxSize { - w.Err = errTooBigSize + if len(data) > MaxSize { + return nil, errTooBigSize } + return data, nil } -// writeJSONString converts it to string and writes it to w as JSON value +// itemToJSONString converts it to string // surrounded in quotes with control characters escaped. -func writeJSONString(w *io.BinWriter, it Item) { - if w.Err != nil { - return - } +func itemToJSONString(it Item) ([]byte, error) { s, err := ToString(it) if err != nil { - w.Err = err - return + return nil, err } data, _ := json.Marshal(s) // error never occurs because `ToString` checks for validity // ref https://github.com/neo-project/neo-modules/issues/375 and https://github.com/dotnet/runtime/issues/35281 - data = bytes.Replace(data, []byte{'+'}, []byte("\\u002B"), -1) - - w.WriteBytes(data) + return bytes.Replace(data, []byte{'+'}, []byte("\\u002B"), -1), nil } // FromJSON decodes Item from JSON. diff --git a/pkg/vm/stackitem/json_test.go b/pkg/vm/stackitem/json_test.go index 9536d403e7..a6e661115e 100644 --- a/pkg/vm/stackitem/json_test.go +++ b/pkg/vm/stackitem/json_test.go @@ -1,6 +1,7 @@ package stackitem import ( + "errors" "math/big" "testing" @@ -105,6 +106,64 @@ func TestFromToJSON(t *testing.T) { }) } +func testToJSON(t *testing.T, expectedErr error, item Item) { + data, err := ToJSON(item) + if expectedErr != nil { + require.True(t, errors.Is(err, expectedErr), err) + return + } + require.NoError(t, err) + + actual, err := FromJSON(data) + require.NoError(t, err) + require.Equal(t, item, actual) +} + +func TestToJSONCornerCases(t *testing.T) { + // base64 encoding increases size by a factor of ~256/64 = 4 + const maxSize = MaxSize / 4 + + bigByteArray := NewByteArray(make([]byte, maxSize/2)) + smallByteArray := NewByteArray(make([]byte, maxSize/4)) + t.Run("Array", func(t *testing.T) { + arr := NewArray([]Item{bigByteArray}) + testToJSON(t, ErrTooBig, NewArray([]Item{arr, arr})) + + arr.value[0] = smallByteArray + testToJSON(t, nil, NewArray([]Item{arr, arr})) + }) + t.Run("big ByteArray", func(t *testing.T) { + testToJSON(t, ErrTooBig, NewByteArray(make([]byte, maxSize+4))) + }) + t.Run("invalid Map key", func(t *testing.T) { + m := NewMap() + m.Add(Make([]byte{0xe9}), Make(true)) + testToJSON(t, ErrInvalidValue, m) + }) +} + +// getBigArray returns array takes up a lot of storage when serialized. +func getBigArray(depth int) *Array { + arr := NewArray([]Item{}) + for i := 0; i < depth; i++ { + arr = NewArray([]Item{arr, arr}) + } + return arr +} + +func BenchmarkToJSON(b *testing.B) { + arr := getBigArray(15) + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, err := ToJSON(arr) + if err != nil { + b.FailNow() + } + } +} + // This test is taken from the C# code // https://github.com/neo-project/neo/blob/master/tests/neo.UnitTests/VM/UT_Helper.cs#L30 func TestToJSONWithTypeCompat(t *testing.T) { diff --git a/pkg/vm/stackitem/serialization.go b/pkg/vm/stackitem/serialization.go index cbc54b4b00..4fb5c2b56c 100644 --- a/pkg/vm/stackitem/serialization.go +++ b/pkg/vm/stackitem/serialization.go @@ -19,38 +19,35 @@ var ErrUnserializable = errors.New("unserializable") // serContext is an internal serialization context. type serContext struct { - *io.BinWriter - buf *io.BufBinWriter + uv [9]byte + data []byte allowInvalid bool - seen map[Item]bool + seen map[Item]sliceNoPointer } // Serialize encodes given Item into the byte slice. func Serialize(item Item) ([]byte, error) { - w := io.NewBufBinWriter() sc := serContext{ - BinWriter: w.BinWriter, - buf: w, allowInvalid: false, - seen: make(map[Item]bool), + seen: make(map[Item]sliceNoPointer), } - sc.serialize(item) - if w.Err != nil { - return nil, w.Err + err := sc.serialize(item) + if err != nil { + return nil, err } - return w.Bytes(), nil + return sc.data, nil } // EncodeBinary encodes given Item into the given BinWriter. It's // similar to io.Serializable's EncodeBinary, but works with Item // interface. func EncodeBinary(item Item, w *io.BinWriter) { - sc := serContext{ - BinWriter: w, - allowInvalid: false, - seen: make(map[Item]bool), + data, err := Serialize(item) + if err != nil { + w.Err = err + return } - sc.serialize(item) + w.WriteBytes(data) } // EncodeBinaryProtected encodes given Item into the given BinWriter. It's @@ -59,88 +56,112 @@ func EncodeBinary(item Item, w *io.BinWriter) { // (like recursive array) is encountered it just writes special InvalidT // type of element to w. func EncodeBinaryProtected(item Item, w *io.BinWriter) { - bw := io.NewBufBinWriter() sc := serContext{ - BinWriter: bw.BinWriter, - buf: bw, allowInvalid: true, - seen: make(map[Item]bool), + seen: make(map[Item]sliceNoPointer), } - sc.serialize(item) - if bw.Err != nil { + err := sc.serialize(item) + if err != nil { w.WriteBytes([]byte{byte(InvalidT)}) return } - w.WriteBytes(bw.Bytes()) + w.WriteBytes(sc.data) } -func (w *serContext) serialize(item Item) { - if w.Err != nil { - return - } - if w.seen[item] { - w.Err = ErrRecursive - return +func (w *serContext) serialize(item Item) error { + if v, ok := w.seen[item]; ok { + if v.start == v.end { + return ErrRecursive + } + if len(w.data)+v.end-v.start > MaxSize { + return ErrTooBig + } + w.data = append(w.data, w.data[v.start:v.end]...) + return nil } + start := len(w.data) switch t := item.(type) { case *ByteArray: - w.WriteBytes([]byte{byte(ByteArrayT)}) - w.WriteVarBytes(t.Value().([]byte)) + w.data = append(w.data, byte(ByteArrayT)) + data := t.Value().([]byte) + w.appendVarUint(uint64(len(data))) + w.data = append(w.data, data...) case *Buffer: - w.WriteBytes([]byte{byte(BufferT)}) - w.WriteVarBytes(t.Value().([]byte)) + w.data = append(w.data, byte(BufferT)) + data := t.Value().([]byte) + w.appendVarUint(uint64(len(data))) + w.data = append(w.data, data...) case *Bool: - w.WriteBytes([]byte{byte(BooleanT)}) - w.WriteBool(t.Value().(bool)) + w.data = append(w.data, byte(BooleanT)) + if t.Value().(bool) { + w.data = append(w.data, 1) + } else { + w.data = append(w.data, 0) + } case *BigInteger: - w.WriteBytes([]byte{byte(IntegerT)}) - w.WriteVarBytes(bigint.ToBytes(t.Value().(*big.Int))) + w.data = append(w.data, byte(IntegerT)) + data := bigint.ToBytes(t.Value().(*big.Int)) + w.appendVarUint(uint64(len(data))) + w.data = append(w.data, data...) case *Interop: if w.allowInvalid { - w.WriteBytes([]byte{byte(InteropT)}) + w.data = append(w.data, byte(InteropT)) } else { - w.Err = fmt.Errorf("%w: Interop", ErrUnserializable) + return fmt.Errorf("%w: Interop", ErrUnserializable) } case *Array, *Struct: - w.seen[item] = true + w.seen[item] = sliceNoPointer{} _, isArray := t.(*Array) if isArray { - w.WriteBytes([]byte{byte(ArrayT)}) + w.data = append(w.data, byte(ArrayT)) } else { - w.WriteBytes([]byte{byte(StructT)}) + w.data = append(w.data, byte(StructT)) } arr := t.Value().([]Item) - w.WriteVarUint(uint64(len(arr))) + w.appendVarUint(uint64(len(arr))) for i := range arr { - w.serialize(arr[i]) + if err := w.serialize(arr[i]); err != nil { + return err + } } - delete(w.seen, item) + w.seen[item] = sliceNoPointer{start, len(w.data)} case *Map: - w.seen[item] = true - - w.WriteBytes([]byte{byte(MapT)}) - w.WriteVarUint(uint64(len(t.Value().([]MapElement)))) - for i := range t.Value().([]MapElement) { - w.serialize(t.Value().([]MapElement)[i].Key) - w.serialize(t.Value().([]MapElement)[i].Value) + w.seen[item] = sliceNoPointer{} + + elems := t.Value().([]MapElement) + w.data = append(w.data, byte(MapT)) + w.appendVarUint(uint64(len(elems))) + for i := range elems { + if err := w.serialize(elems[i].Key); err != nil { + return err + } + if err := w.serialize(elems[i].Value); err != nil { + return err + } } - delete(w.seen, item) + w.seen[item] = sliceNoPointer{start, len(w.data)} case Null: - w.WriteB(byte(AnyT)) + w.data = append(w.data, byte(AnyT)) case nil: if w.allowInvalid { - w.WriteBytes([]byte{byte(InvalidT)}) + w.data = append(w.data, byte(InvalidT)) } else { - w.Err = fmt.Errorf("%w: nil", ErrUnserializable) + return fmt.Errorf("%w: nil", ErrUnserializable) } } - if w.Err == nil && w.buf != nil && w.buf.Len() > MaxSize { - w.Err = errTooBigSize + if len(w.data) > MaxSize { + return errTooBigSize } + return nil +} + +func (w *serContext) appendVarUint(val uint64) { + n := io.PutVarUint(w.uv[:], val) + w.data = append(w.data, w.uv[:n]...) } // Deserialize decodes Item from the given byte slice. diff --git a/pkg/vm/stackitem/serialization_test.go b/pkg/vm/stackitem/serialization_test.go index 02b07cb242..b5adff6515 100644 --- a/pkg/vm/stackitem/serialization_test.go +++ b/pkg/vm/stackitem/serialization_test.go @@ -4,6 +4,7 @@ import ( "errors" "testing" + "github.com/nspcc-dev/neo-go/pkg/io" "github.com/stretchr/testify/require" ) @@ -21,3 +22,125 @@ func TestSerializationMaxErr(t *testing.T) { _, err = Serialize(aitem) require.True(t, errors.Is(err, ErrTooBig), err) } + +func testSerialize(t *testing.T, expectedErr error, item Item) { + data, err := Serialize(item) + if expectedErr != nil { + require.True(t, errors.Is(err, expectedErr), err) + return + } + require.NoError(t, err) + + actual, err := Deserialize(data) + require.NoError(t, err) + require.Equal(t, item, actual) +} + +func TestSerialize(t *testing.T) { + bigByteArray := NewByteArray(make([]byte, MaxSize/2)) + smallByteArray := NewByteArray(make([]byte, MaxSize/4)) + testArray := func(t *testing.T, newItem func([]Item) Item) { + arr := newItem([]Item{bigByteArray}) + testSerialize(t, nil, arr) + testSerialize(t, ErrTooBig, newItem([]Item{bigByteArray, bigByteArray})) + testSerialize(t, ErrTooBig, newItem([]Item{arr, arr})) + + arr.Value().([]Item)[0] = smallByteArray + testSerialize(t, nil, newItem([]Item{arr, arr})) + + arr.Value().([]Item)[0] = arr + testSerialize(t, ErrRecursive, arr) + } + t.Run("array", func(t *testing.T) { + testArray(t, func(items []Item) Item { return NewArray(items) }) + }) + t.Run("struct", func(t *testing.T) { + testArray(t, func(items []Item) Item { return NewStruct(items) }) + }) + t.Run("buffer", func(t *testing.T) { + testSerialize(t, nil, NewBuffer(make([]byte, MaxSize/2))) + testSerialize(t, errTooBigSize, NewBuffer(make([]byte, MaxSize))) + }) + t.Run("invalid", func(t *testing.T) { + testSerialize(t, ErrUnserializable, NewInterop(42)) + testSerialize(t, ErrUnserializable, nil) + + t.Run("protected interop", func(t *testing.T) { + w := io.NewBufBinWriter() + EncodeBinaryProtected(NewInterop(42), w.BinWriter) + require.NoError(t, w.Err) + + data := w.Bytes() + r := io.NewBinReaderFromBuf(data) + DecodeBinary(r) + require.Error(t, r.Err) + + r = io.NewBinReaderFromBuf(data) + item := DecodeBinaryProtected(r) + require.NoError(t, r.Err) + require.IsType(t, (*Interop)(nil), item) + }) + t.Run("protected nil", func(t *testing.T) { + w := io.NewBufBinWriter() + EncodeBinaryProtected(nil, w.BinWriter) + require.NoError(t, w.Err) + + data := w.Bytes() + r := io.NewBinReaderFromBuf(data) + DecodeBinary(r) + require.Error(t, r.Err) + + r = io.NewBinReaderFromBuf(data) + item := DecodeBinaryProtected(r) + require.NoError(t, r.Err) + require.Nil(t, item) + }) + }) + t.Run("bool", func(t *testing.T) { + testSerialize(t, nil, NewBool(true)) + testSerialize(t, nil, NewBool(false)) + }) + t.Run("null", func(t *testing.T) { + testSerialize(t, nil, Null{}) + }) + t.Run("integer", func(t *testing.T) { + testSerialize(t, nil, Make(0xF)) // 1-byte + testSerialize(t, nil, Make(0xFAB)) // 2-byte + testSerialize(t, nil, Make(0xFABCD)) // 4-byte + testSerialize(t, nil, Make(0xFABCDEFEDC)) // 8-byte + }) + t.Run("map", func(t *testing.T) { + one := Make(1) + m := NewMap() + m.Add(one, m) + testSerialize(t, ErrRecursive, m) + + m.Add(one, bigByteArray) + testSerialize(t, nil, m) + + m.Add(Make(2), bigByteArray) + testSerialize(t, ErrTooBig, m) + + // Cover code path when result becomes too big after key encode. + m = NewMap() + m.Add(Make(0), NewByteArray(make([]byte, MaxSize-MaxKeySize))) + m.Add(NewByteArray(make([]byte, MaxKeySize)), Make(1)) + testSerialize(t, ErrTooBig, m) + }) +} + +func BenchmarkEncodeBinary(b *testing.B) { + arr := getBigArray(15) + + w := io.NewBufBinWriter() + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + w.Reset() + EncodeBinary(arr, w.BinWriter) + if w.Err != nil { + b.FailNow() + } + } +}