diff --git a/compact-u16.go b/compact-u16.go index 2ab6474..f9a5695 100644 --- a/compact-u16.go +++ b/compact-u16.go @@ -15,41 +15,51 @@ package bin import ( + "fmt" "io" + "math" ) // EncodeCompactU16Length encodes a "Compact-u16" length into the provided slice pointer. // See https://docs.solana.com/developing/programming-model/transactions#compact-u16-format // See https://github.com/solana-labs/solana/blob/2ef2b6daa05a7cff057e9d3ef95134cee3e4045d/web3.js/src/util/shortvec-encoding.ts -func EncodeCompactU16Length(bytes *[]byte, ln int) { +func EncodeCompactU16Length(buf *[]byte, ln int) error { + if ln < 0 || ln > math.MaxUint16 { + return fmt.Errorf("length %d out of range", ln) + } rem_len := ln for { - elem := rem_len & 0x7f + elem := uint8(rem_len & 0x7f) rem_len >>= 7 if rem_len == 0 { - *bytes = append(*bytes, byte(elem)) + *buf = append(*buf, elem) break } else { elem |= 0x80 - *bytes = append(*bytes, byte(elem)) + *buf = append(*buf, elem) } } + return nil } -// DecodeCompactU16Length decodes a "Compact-u16" length from the provided byte slice. -func DecodeCompactU16Length(bytes []byte) int { - v, _, _ := DecodeCompactU16(bytes) - return v -} +const _MAX_COMPACTU16_ENCODING_LENGTH = 3 func DecodeCompactU16(bytes []byte) (int, int, error) { ln := 0 size := 0 - for { + for nth_byte := 0; nth_byte < _MAX_COMPACTU16_ENCODING_LENGTH; nth_byte++ { if len(bytes) == 0 { return 0, 0, io.ErrUnexpectedEOF } elem := int(bytes[0]) + if elem == 0 && nth_byte != 0 { + return 0, 0, fmt.Errorf("alias") + } + if nth_byte >= _MAX_COMPACTU16_ENCODING_LENGTH { + return 0, 0, fmt.Errorf("too long: %d", nth_byte+1) + } else if nth_byte == _MAX_COMPACTU16_ENCODING_LENGTH-1 && (elem&0x80) != 0 { + return 0, 0, fmt.Errorf("byte three continues") + } bytes = bytes[1:] ln |= (elem & 0x7f) << (size * 7) size += 1 @@ -57,6 +67,14 @@ func DecodeCompactU16(bytes []byte) (int, int, error) { break } } + // check for non-valid sizes + if size == 0 || size > _MAX_COMPACTU16_ENCODING_LENGTH { + return 0, 0, fmt.Errorf("invalid size: %d", size) + } + // check for non-valid lengths + if ln < 0 || ln > math.MaxUint16 { + return 0, 0, fmt.Errorf("invalid length: %d", ln) + } return ln, size, nil } @@ -64,17 +82,33 @@ func DecodeCompactU16(bytes []byte) (int, int, error) { func DecodeCompactU16LengthFromByteReader(reader io.ByteReader) (int, error) { ln := 0 size := 0 - for { + for nth_byte := 0; nth_byte < _MAX_COMPACTU16_ENCODING_LENGTH; nth_byte++ { elemByte, err := reader.ReadByte() if err != nil { return 0, err } elem := int(elemByte) + if elem == 0 && nth_byte != 0 { + return 0, fmt.Errorf("alias") + } + if nth_byte >= _MAX_COMPACTU16_ENCODING_LENGTH { + return 0, fmt.Errorf("too long: %d", nth_byte+1) + } else if nth_byte == _MAX_COMPACTU16_ENCODING_LENGTH-1 && (elem&0x80) != 0 { + return 0, fmt.Errorf("byte three continues") + } ln |= (elem & 0x7f) << (size * 7) size += 1 if (elem & 0x80) == 0 { break } } + // check for non-valid sizes + if size == 0 || size > _MAX_COMPACTU16_ENCODING_LENGTH { + return 0, fmt.Errorf("invalid size: %d", size) + } + // check for non-valid lengths + if ln < 0 || ln > math.MaxUint16 { + return 0, fmt.Errorf("invalid length: %d", ln) + } return ln, nil } diff --git a/compact-u16_test.go b/compact-u16_test.go index 7ffe565..077b88d 100644 --- a/compact-u16_test.go +++ b/compact-u16_test.go @@ -23,13 +23,17 @@ import ( ) func TestCompactU16(t *testing.T) { - candidates := []int{3, 0x7f, 0x7f + 1, 0x3fff, 0x3fff + 1} + candidates := []int{0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 100, 1000, 10000, math.MaxUint16 - 1, math.MaxUint16} for _, val := range candidates { + if val < 0 || val > math.MaxUint16 { + panic("value too large") + } buf := make([]byte, 0) - EncodeCompactU16Length(&buf, val) + require.NoError(t, EncodeCompactU16Length(&buf, val)) buf = append(buf, []byte("hello world")...) - decoded := DecodeCompactU16Length(buf) + decoded, _, err := DecodeCompactU16(buf) + require.NoError(t, err) require.Equal(t, val, decoded) } @@ -40,19 +44,34 @@ func TestCompactU16(t *testing.T) { buf = append(buf, []byte("hello world")...) { decoded, err := DecodeCompactU16LengthFromByteReader(bytes.NewReader(buf)) - if err != nil { - panic(err) - } + require.NoError(t, err) require.Equal(t, val, decoded) } { decoded, _, err := DecodeCompactU16(buf) - if err != nil { - panic(err) - } + require.NoError(t, err) require.Equal(t, val, decoded) } } + { + // now test all from 0 to 0xffff + for i := 0; i < math.MaxUint16; i++ { + buf := make([]byte, 0) + EncodeCompactU16Length(&buf, i) + + buf = append(buf, []byte("hello world")...) + { + decoded, err := DecodeCompactU16LengthFromByteReader(bytes.NewReader(buf)) + require.NoError(t, err) + require.Equal(t, i, decoded) + } + { + decoded, _, err := DecodeCompactU16(buf) + require.NoError(t, err) + require.Equal(t, i, decoded) + } + } + } } func BenchmarkCompactU16(b *testing.B) { @@ -102,3 +121,101 @@ func BenchmarkCompactU16Reader(b *testing.B) { reader.SetPosition(0) } } + +func encode_len(len uint16) []byte { + buf := make([]byte, 0) + err := EncodeCompactU16Length(&buf, int(len)) + if err != nil { + panic(err) + } + return buf +} + +func assert_len_encoding(t *testing.T, len uint16, buf []byte) { + require.Equal(t, encode_len(len), buf, "unexpected usize encoding") + decoded, _, err := DecodeCompactU16(buf) + require.NoError(t, err) + require.Equal(t, int(len), decoded) + { + // now try with a reader + reader := bytes.NewReader(buf) + out, _ := DecodeCompactU16LengthFromByteReader(reader) + require.Equal(t, int(len), out) + } +} + +func TestShortVecEncodeLen(t *testing.T) { + assert_len_encoding(t, 0x0, []byte{0x0}) + assert_len_encoding(t, 0x7f, []byte{0x7f}) + assert_len_encoding(t, 0x80, []byte{0x80, 0x01}) + assert_len_encoding(t, 0xff, []byte{0xff, 0x01}) + assert_len_encoding(t, 0x100, []byte{0x80, 0x02}) + assert_len_encoding(t, 0x7fff, []byte{0xff, 0xff, 0x01}) + assert_len_encoding(t, 0xffff, []byte{0xff, 0xff, 0x03}) +} + +func assert_good_deserialized_value(t *testing.T, value uint16, buf []byte) { + decoded, _, err := DecodeCompactU16(buf) + require.NoError(t, err) + require.Equal(t, int(value), decoded) + { + // now try with a reader + reader := bytes.NewReader(buf) + out, _ := DecodeCompactU16LengthFromByteReader(reader) + require.Equal(t, int(value), out) + } +} + +func assert_bad_deserialized_value(t *testing.T, buf []byte) { + _, _, err := DecodeCompactU16(buf) + require.Error(t, err, "expected an error for bytes: %v", buf) + { + // now try with a reader + reader := bytes.NewReader(buf) + _, err := DecodeCompactU16LengthFromByteReader(reader) + require.Error(t, err, "expected an error for bytes: %v", buf) + } +} + +func TestDeserialize(t *testing.T) { + assert_good_deserialized_value(t, 0x0000, []byte{0x00}) + assert_good_deserialized_value(t, 0x007f, []byte{0x7f}) + assert_good_deserialized_value(t, 0x0080, []byte{0x80, 0x01}) + assert_good_deserialized_value(t, 0x00ff, []byte{0xff, 0x01}) + assert_good_deserialized_value(t, 0x0100, []byte{0x80, 0x02}) + assert_good_deserialized_value(t, 0x07ff, []byte{0xff, 0x0f}) + assert_good_deserialized_value(t, 0x3fff, []byte{0xff, 0x7f}) + assert_good_deserialized_value(t, 0x4000, []byte{0x80, 0x80, 0x01}) + assert_good_deserialized_value(t, 0xffff, []byte{0xff, 0xff, 0x03}) + + // aliases + // 0x0000 + assert_bad_deserialized_value(t, []byte{0x80, 0x00}) + assert_bad_deserialized_value(t, []byte{0x80, 0x80, 0x00}) + // 0x007f + assert_bad_deserialized_value(t, []byte{0xff, 0x00}) + assert_bad_deserialized_value(t, []byte{0xff, 0x80, 0x00}) + // 0x0080 + assert_bad_deserialized_value(t, []byte{0x80, 0x81, 0x00}) + // 0x00ff + assert_bad_deserialized_value(t, []byte{0xff, 0x81, 0x00}) + // 0x0100 + assert_bad_deserialized_value(t, []byte{0x80, 0x82, 0x00}) + // 0x07ff + assert_bad_deserialized_value(t, []byte{0xff, 0x8f, 0x00}) + // 0x3fff + assert_bad_deserialized_value(t, []byte{0xff, 0xff, 0x00}) + + // too short + assert_bad_deserialized_value(t, []byte{}) + assert_bad_deserialized_value(t, []byte{0x80}) + + // too long + assert_bad_deserialized_value(t, []byte{0x80, 0x80, 0x80, 0x00}) + + // too large + // 0x0001_0000 + assert_bad_deserialized_value(t, []byte{0x80, 0x80, 0x04}) + // 0x0001_8000 + assert_bad_deserialized_value(t, []byte{0x80, 0x80, 0x06}) +}