Skip to content

Commit

Permalink
Add Client protocol frame size limit
Browse files Browse the repository at this point in the history
  • Loading branch information
kwart committed Oct 1, 2019
1 parent 4c93d14 commit 9b1c2ce
Show file tree
Hide file tree
Showing 12 changed files with 353 additions and 28 deletions.
Expand Up @@ -536,7 +536,7 @@ public Throwable createException(String message, Throwable cause) {
register(ClientProtocolErrorCodes.MAX_MESSAGE_SIZE_EXCEEDED, MaxMessageSizeExceeded.class, new ExceptionFactory() {
@Override
public Throwable createException(String message, Throwable cause) {
return new MaxMessageSizeExceeded();
return new MaxMessageSizeExceeded(message);
}
});
register(ClientProtocolErrorCodes.WAN_REPLICATION_QUEUE_FULL, WanReplicationQueueFullException.class, new ExceptionFactory() {
Expand Down
Expand Up @@ -71,7 +71,7 @@ public void initChannel(Channel channel) {
public void accept(ClientMessage message) {
connection.handleClientMessage(message);
}
});
}, null);
channel.inboundPipeline().addLast(decoder);

channel.outboundPipeline().addLast(new ClientMessageEncoder());
Expand Down
Expand Up @@ -16,24 +16,33 @@

package com.hazelcast.client.impl.protocol;

import com.hazelcast.client.impl.protocol.exception.MaxMessageSizeExceeded;
import com.hazelcast.internal.nio.Bits;

import static java.lang.String.format;

import java.nio.ByteBuffer;
import java.util.LinkedList;

import static com.hazelcast.client.impl.protocol.ClientMessage.IS_FINAL_FLAG;
import static com.hazelcast.client.impl.protocol.ClientMessage.SIZE_OF_FRAME_LENGTH_AND_FLAGS;

public class ClientMessageReader {
public final class ClientMessageReader {

private static final int INT_MASK = 0xffff;
private int readIndex;
private int readOffset = -1;
private int sumUntrustedMessageLength;
private final int maxMessageLength;
private LinkedList<ClientMessage.Frame> frames = new LinkedList<>();

public boolean readFrom(ByteBuffer src) {
public ClientMessageReader(int maxMessageLenth) {
this.maxMessageLength = maxMessageLenth > 0 ? maxMessageLenth : Integer.MAX_VALUE;
}

public boolean readFrom(ByteBuffer src, boolean trusted) {
for (; ; ) {
if (readFrame(src)) {
if (readFrame(src, trusted)) {
if (ClientMessage.isFlagSet(frames.get(readIndex).flags, IS_FINAL_FLAG)) {
return true;
}
Expand All @@ -50,13 +59,7 @@ public LinkedList<ClientMessage.Frame> getFrames() {
return frames;
}

public void reset() {
readIndex = 0;
readOffset = -1;
frames = new LinkedList<>();
}

private boolean readFrame(ByteBuffer src) {
private boolean readFrame(ByteBuffer src, boolean trusted) {
// init internal buffer
int remaining = src.remaining();
if (remaining < SIZE_OF_FRAME_LENGTH_AND_FLAGS) {
Expand All @@ -65,6 +68,22 @@ private boolean readFrame(ByteBuffer src) {
}
if (readOffset == -1) {
int frameLength = Bits.readIntL(src, src.position());
if (frameLength < SIZE_OF_FRAME_LENGTH_AND_FLAGS) {
throw new IllegalArgumentException(format(
"The client message frame reported illegal length (%d bytes)."
+ " Minimal length is the size of frame header (%d bytes).",
frameLength, SIZE_OF_FRAME_LENGTH_AND_FLAGS));
}
if (!trusted) {
// check the message size overflow and message size limit
if (Integer.MAX_VALUE - frameLength < sumUntrustedMessageLength
|| sumUntrustedMessageLength + frameLength > maxMessageLength) {
throw new MaxMessageSizeExceeded(
format("The client message size (%d + %d) exceededs the maximum allowed length (%d)",
sumUntrustedMessageLength, frameLength, maxMessageLength));
}
sumUntrustedMessageLength += frameLength;
}
src.position(src.position() + Bits.INT_SIZE_IN_BYTES);
int flags = Bits.readShortL(src, src.position()) & INT_MASK;
src.position(src.position() + Bits.SHORT_SIZE_IN_BYTES);
Expand Down
Expand Up @@ -24,6 +24,9 @@
public class MaxMessageSizeExceeded
extends HazelcastException {
public MaxMessageSizeExceeded() {
super("The size of the message exceeds the maximum value of " + Integer.MAX_VALUE + " bytes.");
}

public MaxMessageSizeExceeded(String message) {
super(message);
}
}
Expand Up @@ -16,16 +16,22 @@

package com.hazelcast.client.impl.protocol.util;

import com.hazelcast.client.impl.ClientEndpoint;
import com.hazelcast.client.impl.ClientEndpointManager;
import com.hazelcast.client.impl.ClientEngine;
import com.hazelcast.client.impl.protocol.ClientMessage;
import com.hazelcast.client.impl.protocol.ClientMessageReader;
import com.hazelcast.internal.networking.HandlerStatus;
import com.hazelcast.internal.networking.nio.InboundHandlerWithCounters;
import com.hazelcast.internal.nio.Bits;
import com.hazelcast.internal.nio.Connection;
import com.hazelcast.internal.util.collection.Long2ObjectHashMap;
import com.hazelcast.spi.properties.GroupProperty;
import com.hazelcast.spi.properties.HazelcastProperties;

import java.nio.ByteBuffer;
import java.util.LinkedList;
import java.util.Properties;
import java.util.function.Consumer;

import static com.hazelcast.client.impl.protocol.ClientMessage.BEGIN_FRAGMENT_FLAG;
Expand All @@ -44,10 +50,20 @@ public class ClientMessageDecoder extends InboundHandlerWithCounters<ByteBuffer,

private final Connection connection;
private final Long2ObjectHashMap<ClientMessageReader> builderBySessionIdMap = new Long2ObjectHashMap<>();
private ClientMessageReader activeReader = new ClientMessageReader();
private ClientMessageReader activeReader;

public ClientMessageDecoder(Connection connection, Consumer<ClientMessage> dst) {
private boolean clientIsTrusted;
private final int maxMessageLength;
private final ClientEndpointManager clientEndpointManager;

public ClientMessageDecoder(Connection connection, Consumer<ClientMessage> dst, HazelcastProperties properties) {
dst(dst);
if (properties == null) {
properties = new HazelcastProperties((Properties) null);
}
clientEndpointManager = dst instanceof ClientEngine ? ((ClientEngine) dst).getEndpointManager() : null;
maxMessageLength = properties.getInteger(GroupProperty.CLIENT_PROTOCOL_UNVERIFIED_MESSAGE_BYTES);
activeReader = new ClientMessageReader(maxMessageLength);
this.connection = connection;
}

Expand All @@ -61,7 +77,8 @@ public HandlerStatus onRead() {
src.flip();
try {
while (src.hasRemaining()) {
boolean complete = activeReader.readFrom(src);
boolean trusted = isEndpointTrusted();
boolean complete = activeReader.readFrom(src, trusted);
if (!complete) {
break;
}
Expand All @@ -70,6 +87,9 @@ public HandlerStatus onRead() {
int flags = firstFrame.flags;
if (ClientMessage.isFlagSet(flags, UNFRAGMENTED_MESSAGE)) {
handleMessage(activeReader);
} else if (!trusted) {
throw new IllegalStateException(
"Fragmented client messages are not allowed before the client is authenticated.");
} else {
//remove the fragmentationFrame
activeReader.getFrames().removeFirst();
Expand All @@ -87,7 +107,7 @@ public HandlerStatus onRead() {
}
}

activeReader = new ClientMessageReader();
activeReader = new ClientMessageReader(maxMessageLength);
}

return CLEAN;
Expand All @@ -96,6 +116,15 @@ public HandlerStatus onRead() {
}
}

private boolean isEndpointTrusted() {
if (clientEndpointManager == null || clientIsTrusted) {
return true;
}
ClientEndpoint endpoint = clientEndpointManager.getEndpoint(connection);
clientIsTrusted = endpoint != null && endpoint.isAuthenticated();
return clientIsTrusted;
}

private void handleMessage(ClientMessageReader clientMessageReader) {
LinkedList<ClientMessage.Frame> frames = clientMessageReader.getFrames();
ClientMessage clientMessage = ClientMessage.createForDecode(frames);
Expand Down
Expand Up @@ -35,7 +35,7 @@ public class ClientChannelInitializer
public void initChannel(Channel channel) {
TcpIpConnection connection = (TcpIpConnection) channel.attributeMap().get(TcpIpConnection.class);
SingleProtocolDecoder protocolDecoder = new SingleProtocolDecoder(CLIENT,
new ClientMessageDecoder(connection, ioService.getClientEngine()));
new ClientMessageDecoder(connection, ioService.getClientEngine(), ioService.properties()));

channel.outboundPipeline().addLast(new ClientMessageEncoder());
channel.inboundPipeline().addLast(protocolDecoder);
Expand Down
Expand Up @@ -152,7 +152,7 @@ private void initChannelForClient() {
.setOption(DIRECT_BUF, false);

TcpIpConnection connection = (TcpIpConnection) channel.attributeMap().get(TcpIpConnection.class);
channel.inboundPipeline().replace(this, new ClientMessageDecoder(connection, ioService.getClientEngine()));
channel.inboundPipeline().replace(this, new ClientMessageDecoder(connection, ioService.getClientEngine(), props));
}

private void initChannelForText(String protocol, boolean restApi) {
Expand Down
Expand Up @@ -1085,6 +1085,12 @@ private int getWhenNoSSLDetected() {
public static final HazelcastProperty MOBY_NAMING_ENABLED
= new HazelcastProperty("hazelcast.member.naming.moby.enabled", true);

/**
* Client protocol message size limit (in bytes) for unverified connections (i.e. maximal length of authentication message).
*/
public static final HazelcastProperty CLIENT_PROTOCOL_UNVERIFIED_MESSAGE_BYTES =
new HazelcastProperty("hazelcast.client.protocol.max.message.bytes", 1024);

public static final HazelcastProperty AUDIT_LOG_ENABLED = new HazelcastProperty("hazelcast.auditlog.enabled", false);

private GroupProperty() {
Expand Down
Expand Up @@ -247,7 +247,7 @@ public static Iterable<Object[]> parameters() {
},
new Object[]{new NoDataMemberInClusterException(randomString())},
new Object[]{new ReplicatedMapCantBeCreatedOnLiteMemberException(randomString())},
new Object[]{new MaxMessageSizeExceeded()},
new Object[]{new MaxMessageSizeExceeded(randomString())},
new Object[]{new WanReplicationQueueFullException(randomString())},
new Object[]{new AssertionError(randomString())},
new Object[]{new OutOfMemoryError(randomString())},
Expand Down
Expand Up @@ -74,7 +74,7 @@ public void test() {
assertEquals(CLEAN, result);

AtomicReference<ClientMessage> resultingMessage = new AtomicReference<>();
ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set);
ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set, null);
decoder.setNormalPacketsRead(SwCounter.newSwCounter());

buffer.position(buffer.limit());
Expand Down Expand Up @@ -107,7 +107,7 @@ public void testPut() {
assertEquals(CLEAN, result);

AtomicReference<ClientMessage> resultingMessage = new AtomicReference<>();
ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set);
ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set, null);
decoder.setNormalPacketsRead(SwCounter.newSwCounter());

buffer.position(buffer.limit());
Expand Down Expand Up @@ -151,7 +151,7 @@ public void testAuthenticationRequest() {
assertEquals(CLEAN, result);

AtomicReference<ClientMessage> resultingMessage = new AtomicReference<>();
ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set);
ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set, null);
decoder.setNormalPacketsRead(SwCounter.newSwCounter());

buffer.position(buffer.limit());
Expand Down Expand Up @@ -208,7 +208,7 @@ public void testAuthenticationResponse() throws UnknownHostException {
assertEquals(CLEAN, result);

AtomicReference<ClientMessage> resultingMessage = new AtomicReference<>();
ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set);
ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set, null);
decoder.setNormalPacketsRead(SwCounter.newSwCounter());

buffer.position(buffer.limit());
Expand Down Expand Up @@ -275,7 +275,7 @@ public void testEvent() throws UnknownHostException {
assertEquals(CLEAN, result);

AtomicReference<ClientMessage> resultingMessage = new AtomicReference<>();
ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set);
ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set, null);
decoder.setNormalPacketsRead(SwCounter.newSwCounter());

buffer.position(buffer.limit());
Expand Down

0 comments on commit 9b1c2ce

Please sign in to comment.