Skip to content

Commit

Permalink
Add Packstream message size sanity check () ()
Browse files Browse the repository at this point in the history
Add check to prevent OOMs occurring when bolt messages are received which state they contain more bytes than they do.
  • Loading branch information
gjmwoods committed Feb 25, 2021
1 parent d40f29c commit 31ba1c4
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 83 deletions.
Expand Up @@ -90,4 +90,10 @@ private void assertNotStarted()
throw new IllegalStateException( "Already started" );
}
}

@Override
public int readableBytes()
{
return buf.readableBytes();
}
}
Expand Up @@ -33,7 +33,6 @@
import org.neo4j.kernel.api.exceptions.Status;
import org.neo4j.values.AnyValue;
import org.neo4j.values.AnyValueWriter;
import org.neo4j.values.VirtualValue;
import org.neo4j.values.storable.CoordinateReferenceSystem;
import org.neo4j.values.storable.TextArray;
import org.neo4j.values.storable.TextValue;
Expand All @@ -46,7 +45,6 @@
import org.neo4j.values.virtual.RelationshipValue;
import org.neo4j.values.virtual.VirtualValues;

import static org.neo4j.bolt.packstream.PackStream.UNKNOWN_SIZE;
import static org.neo4j.values.storable.Values.byteArray;
import static org.neo4j.values.virtual.VirtualValues.EMPTY_MAP;

Expand Down Expand Up @@ -509,34 +507,15 @@ ListValue unpackList() throws IOException
{
return VirtualValues.EMPTY_LIST;
}
else if ( size == UNKNOWN_SIZE )
{
ListValueBuilder builder = ListValueBuilder.newListBuilder();
boolean more = true;
while ( more )
{
PackType keyType = peekNextType();
if ( keyType == PackType.END_OF_STREAM )
{
unpack();
more = false;
}
else
{
builder.add( unpack() );
}
}
return builder.build();
}
else

sizeSanityCheck( size, in );

ListValueBuilder builder = ListValueBuilder.newListBuilder( size );
for ( int i = 0; i < size; i++ )
{
ListValueBuilder builder = ListValueBuilder.newListBuilder( size );
for ( int i = 0; i < size; i++ )
{
builder.add( unpack() );
}
return builder.build();
builder.add( unpack() );
}
return builder.build();
}

protected AnyValue unpackStruct( char signature, long size ) throws IOException
Expand All @@ -560,60 +539,29 @@ public MapValue unpackMap() throws IOException
{
return EMPTY_MAP;
}
MapValueBuilder map;
if ( size == UNKNOWN_SIZE )

sizeSanityCheck( size, in );

MapValueBuilder map = new MapValueBuilder( size );
for ( int i = 0; i < size; i++ )
{
map = new MapValueBuilder();
boolean more = true;
while ( more )
PackType keyType = peekNextType();
String key;
switch ( keyType )
{
PackType keyType = peekNextType();
String key;
AnyValue val;
switch ( keyType )
{
case END_OF_STREAM:
unpack();
more = false;
break;
case STRING:
key = unpackString();
val = unpack();
if ( map.add( key, val ) != null )
{
throw new BoltIOException( Status.Request.Invalid, "Duplicate map key `" + key + "`." );
}
break;
case NULL:
throw new BoltIOException( Status.Request.Invalid, "Value `null` is not supported as key in maps, must be a non-nullable string." );
default:
throw new BoltIOException( Status.Request.InvalidFormat, "Bad key type: " + keyType );
}
case NULL:
throw new BoltIOException( Status.Request.Invalid, "Value `null` is not supported as key in maps, must be a non-nullable string." );
case STRING:
key = unpackString();
break;
default:
throw new BoltIOException( Status.Request.InvalidFormat, "Bad key type: " + keyType );
}
}
else
{
map = new MapValueBuilder( size );
for ( int i = 0; i < size; i++ )
{
PackType keyType = peekNextType();
String key;
switch ( keyType )
{
case NULL:
throw new BoltIOException( Status.Request.Invalid, "Value `null` is not supported as key in maps, must be a non-nullable string." );
case STRING:
key = unpackString();
break;
default:
throw new BoltIOException( Status.Request.InvalidFormat, "Bad key type: " + keyType );
}

AnyValue val = unpack();
if ( map.add( key, val ) != null )
{
throw new BoltIOException( Status.Request.Invalid, "Duplicate map key `" + key + "`." );
}
AnyValue val = unpack();
if ( map.add( key, val ) != null )
{
throw new BoltIOException( Status.Request.Invalid, "Duplicate map key `" + key + "`." );
}
}
return map.build();
Expand Down
Expand Up @@ -47,4 +47,7 @@ public interface PackInput

/** Get the next byte without forwarding the internal pointer */
byte peekByte() throws IOException;

/** Remaining Readable bytes */
int readableBytes();
}
Expand Up @@ -23,8 +23,10 @@
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import org.neo4j.bolt.messaging.BoltIOException;
import org.neo4j.bolt.messaging.StructType;
import org.neo4j.bolt.packstream.utf8.UTF8Encoder;
import org.neo4j.kernel.api.exceptions.Status;

/**
* PackStream is a messaging serialisation format heavily inspired by MessagePack.
Expand Down Expand Up @@ -756,6 +758,15 @@ public static void ensureCorrectStructSize( StructType structType, int expected,
structType.description(), expected, structType.description(), actual ) );
}
}

public static void sizeSanityCheck( int reportedSize, PackInput packInput ) throws BoltIOException
{
// We check if the stated size of the read item makes sense with respect to the remaining available bytes.
if ( reportedSize > packInput.readableBytes() )
{
throw new BoltIOException( Status.Request.Invalid, "Collection size exceeds message capacity." );
}
}
}

public static class PackStreamException extends IOException
Expand Down
Expand Up @@ -81,4 +81,10 @@ public byte peekByte() throws IOException
data.reset();
return value;
}

@Override
public int readableBytes()
{
return bytes.available();
}
}
Expand Up @@ -140,6 +140,12 @@ public byte peekByte() throws IOException
return buffer.get( buffer.position() );
}

@Override
public int readableBytes()
{
return buffer.remaining();
}

private void ensure( int numBytes ) throws IOException
{
if ( !attempt( numBytes ) )
Expand Down
Expand Up @@ -213,6 +213,55 @@ void shouldThrowOnUnpackingMapWithUnsupportedKeyType() throws IOException
}
}

@Test
void shouldThrowOnUnpackingMapWithOversizeDeclaredSize() throws IOException
{
// Given
PackedOutputArray output = new PackedOutputArray();
Neo4jPack.Packer packer = neo4jPack.newPacker( output );
packer.packMapHeader( Integer.MAX_VALUE );
packer.pack( "key" );
packer.pack( intValue( 1 ) );

// When
try
{
PackedInputArray input = new PackedInputArray( output.bytes() );
Neo4jPack.Unpacker unpacker = neo4jPack.newUnpacker( input );
unpacker.unpack();

fail( "exception expected" );
}
catch ( BoltIOException ex )
{
assertEquals( Neo4jError.from( Status.Request.Invalid, "Collection size exceeds message capacity." ), Neo4jError.from( ex ) );
}
}

@Test
void shouldThrowOnUnpackingListWithOversizeDeclaredSize() throws IOException
{
// Given
PackedOutputArray output = new PackedOutputArray();
Neo4jPack.Packer packer = neo4jPack.newPacker( output );
packer.packListHeader( Integer.MAX_VALUE );
packer.pack( intValue( 1 ) );

// When
try
{
PackedInputArray input = new PackedInputArray( output.bytes() );
Neo4jPack.Unpacker unpacker = neo4jPack.newUnpacker( input );
unpacker.unpack();

fail( "exception expected" );
}
catch ( BoltIOException ex )
{
assertEquals( Neo4jError.from( Status.Request.Invalid, "Collection size exceeds message capacity." ), Neo4jError.from( ex ) );
}
}

@Test
void shouldNotBeAbleToUnpackNode()
{
Expand Down
Expand Up @@ -19,6 +19,7 @@
*/
package org.neo4j.bolt.testing;

import io.netty.buffer.Unpooled;
import org.assertj.core.api.Condition;

import java.io.IOException;
Expand All @@ -36,8 +37,8 @@
import org.neo4j.bolt.messaging.RecordingByteChannel;
import org.neo4j.bolt.messaging.RequestMessage;
import org.neo4j.bolt.messaging.ResponseMessage;
import org.neo4j.bolt.packstream.BufferedChannelInput;
import org.neo4j.bolt.packstream.BufferedChannelOutput;
import org.neo4j.bolt.packstream.ByteBufInput;
import org.neo4j.bolt.packstream.Neo4jPack;
import org.neo4j.bolt.v3.messaging.BoltResponseMessageWriterV3;
import org.neo4j.bolt.v3.messaging.response.FailureMessage;
Expand Down Expand Up @@ -275,8 +276,8 @@ public static ResponseMessage responseMessage( Neo4jPack neo4jPack, byte[] bytes

private static BoltResponseMessageReader responseReader( Neo4jPack neo4jPack, byte[] bytes )
{
BufferedChannelInput input = new BufferedChannelInput( 128 );
input.reset( new ArrayByteChannel( bytes ) );
return new BoltResponseMessageReader( neo4jPack.newUnpacker( input ) );
ByteBufInput byteBufInput = new ByteBufInput();
byteBufInput.start( Unpooled.wrappedBuffer( bytes ) );
return new BoltResponseMessageReader( neo4jPack.newUnpacker( byteBufInput ) );
}
}

0 comments on commit 31ba1c4

Please sign in to comment.