diff --git a/community/bolt/src/main/java/org/neo4j/bolt/v1/messaging/Neo4jPackV1.java b/community/bolt/src/main/java/org/neo4j/bolt/v1/messaging/Neo4jPackV1.java index 0424465e12fb..7c9c3e31e0d5 100644 --- a/community/bolt/src/main/java/org/neo4j/bolt/v1/messaging/Neo4jPackV1.java +++ b/community/bolt/src/main/java/org/neo4j/bolt/v1/messaging/Neo4jPackV1.java @@ -553,7 +553,7 @@ public MapValue unpackMap() throws IOException MapValueBuilder map; if ( size == UNKNOWN_SIZE ) { - map = new MapValueBuilder( ); + map = new MapValueBuilder(); boolean more = true; while ( more ) { diff --git a/community/bolt/src/main/java/org/neo4j/bolt/v1/packstream/PackStream.java b/community/bolt/src/main/java/org/neo4j/bolt/v1/packstream/PackStream.java index 05abcc3e2056..9874fd7ba77e 100644 --- a/community/bolt/src/main/java/org/neo4j/bolt/v1/packstream/PackStream.java +++ b/community/bolt/src/main/java/org/neo4j/bolt/v1/packstream/PackStream.java @@ -153,6 +153,8 @@ public class PackStream private static final long PLUS_2_TO_THE_31 = 2147483648L; private static final long PLUS_2_TO_THE_15 = 32768L; + private static final long PLUS_2_TO_THE_16 = 65536L; + private static final long PLUS_2_TO_THE_8 = 256L; private static final long PLUS_2_TO_THE_7 = 128L; private static final long MINUS_2_TO_THE_4 = -16L; private static final long MINUS_2_TO_THE_7 = -128L; @@ -351,58 +353,17 @@ public void packUTF8( byte[] bytes, int offset, int length ) throws IOException protected void packBytesHeader( int size ) throws IOException { - if ( size <= Byte.MAX_VALUE ) - { - out.writeShort( (short) (BYTES_8 << 8 | size) ); - } - else if ( size <= Short.MAX_VALUE ) - { - out.writeByte( BYTES_16 ).writeShort( (short) size ); - } - else - { - out.writeByte( BYTES_32 ).writeInt( size ); - } + packHeader( size, BYTES_8, BYTES_16, BYTES_32 ); } - private void packStringHeader( int size ) throws IOException + void packStringHeader( int size ) throws IOException { - if ( size < 0x10 ) - { - out.writeByte( (byte) (TINY_STRING | size) ); - } - else if ( size <= Byte.MAX_VALUE ) - { - out.writeShort( (short) (STRING_8 << 8 | size) ); - } - else if ( size <= Short.MAX_VALUE ) - { - out.writeByte( STRING_16 ).writeShort( (short) size ); - } - else - { - out.writeByte( STRING_32 ).writeInt( size ); - } + packHeader( size, TINY_STRING, STRING_8, STRING_16, STRING_32 ); } public void packListHeader( int size ) throws IOException { - if ( size < 0x10 ) - { - out.writeByte( (byte) (TINY_LIST | size) ); - } - else if ( size <= Byte.MAX_VALUE ) - { - out.writeShort( (short) (LIST_8 << 8 | size) ); - } - else if ( size <= Short.MAX_VALUE ) - { - out.writeByte( LIST_16 ).writeShort( (short) size ); - } - else - { - out.writeByte( LIST_32 ).writeInt( size ); - } + packHeader( size, TINY_LIST, LIST_8, LIST_16, LIST_32 ); } public void packListStreamHeader() throws IOException @@ -412,22 +373,7 @@ public void packListStreamHeader() throws IOException public void packMapHeader( int size ) throws IOException { - if ( size < 0x10 ) - { - out.writeByte( (byte) (TINY_MAP | size) ); - } - else if ( size <= Byte.MAX_VALUE ) - { - out.writeShort( (short) (MAP_8 << 8 | size) ); - } - else if ( size <= Short.MAX_VALUE ) - { - out.writeByte( MAP_16 ).writeShort( (short) size ); - } - else - { - out.writeByte( MAP_32 ).writeInt( size ); - } + packHeader( size, TINY_MAP, MAP_8, MAP_16, MAP_32 ); } public void packMapStreamHeader() throws IOException @@ -441,11 +387,11 @@ public void packStructHeader( int size, byte signature ) throws IOException { out.writeShort( (short) ((byte) (TINY_STRUCT | size) << 8 | (signature & 0xFF)) ); } - else if ( size <= Byte.MAX_VALUE ) + else if ( size < PLUS_2_TO_THE_8 ) { out.writeByte( STRUCT_8 ).writeByte( (byte) size ).writeByte( signature ); } - else if ( size <= Short.MAX_VALUE ) + else if ( size < PLUS_2_TO_THE_16 ) { out.writeByte( STRUCT_16 ).writeShort( (short) size ).writeByte( signature ); } @@ -460,6 +406,33 @@ public void packEndOfStream() throws IOException out.writeByte( END_OF_STREAM ); } + private void packHeader( int size, byte marker8, byte marker16, byte marker32 ) throws IOException + { + if ( size < PLUS_2_TO_THE_8 ) + { + out.writeShort( (short) (marker8 << 8 | size) ); + } + else if ( size < PLUS_2_TO_THE_16 ) + { + out.writeByte( marker16 ).writeShort( (short) size ); + } + else + { + out.writeByte( marker32 ).writeInt( size ); + } + } + + private void packHeader( int size, byte marker4, byte marker8, byte marker16, byte marker32 ) throws IOException + { + if ( size < 0x10 ) + { + out.writeByte( (byte) (marker4 | size) ); + } + else + { + packHeader( size, marker8, marker16, marker32 ); + } + } } public static class Unpacker @@ -519,7 +492,7 @@ public long unpackListHeader() throws IOException case LIST_16: return unpackUINT16(); case LIST_32: - return unpackUINT32(); + return unpackUINT32( PackType.LIST ); case LIST_STREAM: return UNKNOWN_SIZE; default: @@ -544,7 +517,7 @@ public long unpackMapHeader() throws IOException case MAP_16: return unpackUINT16(); case MAP_32: - return unpackUINT32(); + return unpackUINT32( PackType.MAP ); case MAP_STREAM: return UNKNOWN_SIZE; default: @@ -631,15 +604,7 @@ public int unpackBytesHeader() throws IOException break; case BYTES_32: { - long longSize = unpackUINT32(); - if ( longSize <= Integer.MAX_VALUE ) - { - size = (int) longSize; - } - else - { - throw new Overflow( "BYTES_32 too long for Java" ); - } + size = unpackUINT32( PackType.BYTES ); break; } default: @@ -672,15 +637,7 @@ public int unpackStringHeader() throws IOException break; case STRING_32: { - long longSize = unpackUINT32(); - if ( longSize <= Integer.MAX_VALUE ) - { - size = (int) longSize; - } - else - { - throw new Overflow( "STRING_32 too long for Java" ); - } + size = unpackUINT32( PackType.STRING ); break; } default: @@ -732,6 +689,19 @@ private long unpackUINT32() throws IOException return in.readInt() & 0xFFFFFFFFL; } + private int unpackUINT32( PackType type ) throws IOException + { + long longSize = unpackUINT32(); + if ( longSize <= Integer.MAX_VALUE ) + { + return (int) longSize; + } + else + { + throw new Overflow( String.format( "%s_32 too long for Java", type ) ); + } + } + public void unpackEndOfStream() throws IOException { final byte markerByte = in.readByte(); diff --git a/community/bolt/src/test/java/org/neo4j/bolt/v1/packstream/PackStreamTest.java b/community/bolt/src/test/java/org/neo4j/bolt/v1/packstream/PackStreamTest.java index 5a2de3942298..ae01bfeeb2c5 100644 --- a/community/bolt/src/test/java/org/neo4j/bolt/v1/packstream/PackStreamTest.java +++ b/community/bolt/src/test/java/org/neo4j/bolt/v1/packstream/PackStreamTest.java @@ -24,6 +24,7 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.channels.Channels; import java.nio.channels.ReadableByteChannel; import java.nio.channels.WritableByteChannel; @@ -62,7 +63,6 @@ public static Map asMap( Object... keysAndValues ) private static class Machine { - private final ByteArrayOutputStream output; private final WritableByteChannel writable; private final PackStream.Packer packer; @@ -95,7 +95,61 @@ public PackStream.Packer packer() { return packer; } + } + + private static class MachineClient + { + private final PackStream.Unpacker unpacker; + private final ResetableReadableByteChannel readable; + MachineClient( int capacity ) + { + readable = new ResetableReadableByteChannel(); + BufferedChannelInput input = new BufferedChannelInput( capacity ).reset( readable ); + unpacker = new PackStream.Unpacker( input ); + } + + public void reset( byte[] input ) + { + readable.reset( input ); + } + + public PackStream.Unpacker unpacker() + { + return this.unpacker; + } + } + + private static class ResetableReadableByteChannel implements ReadableByteChannel + { + private byte[] bytes; + private int pos; + + public void reset( byte[] input ) + { + bytes = input; + pos = 0; + } + + @Override + public int read( ByteBuffer dst ) throws IOException + { + dst.put( bytes ); + int read = bytes.length; + pos += read; + return read; + } + + @Override + public boolean isOpen() + { + return pos < bytes.length; + } + + @Override + public void close() throws IOException + { + } } private PackStream.Unpacker newUnpacker( byte[] bytes ) @@ -765,6 +819,124 @@ public void testCanPeekOnNextType() throws Throwable assertPeekType( PackType.MAP, asMap( "l", 3 ) ); } + @Test + public void shouldPackBytesHeaderWithMinimalBuffer() throws Throwable + { + Machine machine = new Machine(); + PackStream.Packer packer = machine.packer(); + + MachineClient client = new MachineClient( 8 ); + PackStream.Unpacker unpacker = client.unpacker(); + + for ( int size = 0; size <= Math.pow( 2, 16 ); size++ ) + { + machine.reset(); + packer.packBytesHeader( size ); + packer.flush(); + + // Then + int bufferSize = computeOutputBufferSize( size, false ); + byte[] output = machine.output(); + assertThat( output.length, equalTo( bufferSize ) ); + + client.reset( output ); + int value = unpacker.unpackBytesHeader(); + assertThat( value, equalTo( size ) ); + } + } + + @Test + public void shouldPackStringHeaderWithMinimalBuffer() throws Throwable + { + shouldPackHeaderWithMinimalBuffer( PackType.STRING ); + } + + @Test + public void shouldPackMapHeaderWithMinimalBuffer() throws Throwable + { + shouldPackHeaderWithMinimalBuffer( PackType.MAP ); + } + + @Test + public void shouldPackListHeaderWithMinimalBuffer() throws Throwable + { + shouldPackHeaderWithMinimalBuffer( PackType.LIST ); + } + + private void shouldPackHeaderWithMinimalBuffer( PackType type ) throws Throwable + { + Machine machine = new Machine(); + PackStream.Packer packer = machine.packer(); + + MachineClient client = new MachineClient( 8 ); + PackStream.Unpacker unpacker = client.unpacker(); + + for ( int size = 0; size <= Math.pow( 2, 16 ); size++ ) + { + machine.reset(); + switch ( type ) + { + case MAP: + packer.packMapHeader( size ); + break; + case LIST: + packer.packListHeader( size ); + break; + case STRING: + packer.packStringHeader( size ); + break; + default: + throw new IllegalArgumentException( "Unsupported type: " + type + "." ); + } + packer.flush(); + + int bufferSize = computeOutputBufferSize( size, true ); + byte[] output = machine.output(); + assertThat( output.length, equalTo( bufferSize ) ); + + client.reset( output ); + int value = 0; + switch ( type ) + { + case MAP: + value = (int) unpacker.unpackMapHeader(); + break; + case LIST: + value = (int) unpacker.unpackListHeader(); + break; + case STRING: + value = unpacker.unpackStringHeader(); + break; + default: + throw new IllegalArgumentException( "Unsupported type: " + type + "." ); + } + + assertThat( value, equalTo( size ) ); + } + } + + private int computeOutputBufferSize( int size, boolean withMarker8 ) + { + int bufferSize; + if ( withMarker8 && size < 16 ) + { + bufferSize = 1; + } + else if ( size < 256 ) + { + bufferSize = 2; + } + else if ( size < 65536 ) + { + bufferSize = 1 + 2; + } + else + { + bufferSize = 1 + 4; + } + return bufferSize; + } + private void assertPeekType( PackType type, Object value ) throws IOException { // Given