diff --git a/hazelcast/src/main/java/com/hazelcast/client/impl/clientside/ClientExceptionFactory.java b/hazelcast/src/main/java/com/hazelcast/client/impl/clientside/ClientExceptionFactory.java index 1b14ccc5f3bd..94ca6df09d81 100644 --- a/hazelcast/src/main/java/com/hazelcast/client/impl/clientside/ClientExceptionFactory.java +++ b/hazelcast/src/main/java/com/hazelcast/client/impl/clientside/ClientExceptionFactory.java @@ -536,7 +536,7 @@ public Throwable createException(String message, Throwable cause) { register(ClientProtocolErrorCodes.MAX_MESSAGE_SIZE_EXCEEDED, MaxMessageSizeExceeded.class, new ExceptionFactory() { @Override public Throwable createException(String message, Throwable cause) { - return new MaxMessageSizeExceeded(); + return new MaxMessageSizeExceeded(message); } }); register(ClientProtocolErrorCodes.WAN_REPLICATION_QUEUE_FULL, WanReplicationQueueFullException.class, new ExceptionFactory() { diff --git a/hazelcast/src/main/java/com/hazelcast/client/impl/connection/nio/ClientPlainChannelInitializer.java b/hazelcast/src/main/java/com/hazelcast/client/impl/connection/nio/ClientPlainChannelInitializer.java index 8ad1351ec390..b690a5686c47 100644 --- a/hazelcast/src/main/java/com/hazelcast/client/impl/connection/nio/ClientPlainChannelInitializer.java +++ b/hazelcast/src/main/java/com/hazelcast/client/impl/connection/nio/ClientPlainChannelInitializer.java @@ -71,7 +71,7 @@ public void initChannel(Channel channel) { public void accept(ClientMessage message) { connection.handleClientMessage(message); } - }); + }, null); channel.inboundPipeline().addLast(decoder); channel.outboundPipeline().addLast(new ClientMessageEncoder()); diff --git a/hazelcast/src/main/java/com/hazelcast/client/impl/protocol/ClientMessageReader.java b/hazelcast/src/main/java/com/hazelcast/client/impl/protocol/ClientMessageReader.java index 466eebee955a..ffe2b7698fc7 100644 --- a/hazelcast/src/main/java/com/hazelcast/client/impl/protocol/ClientMessageReader.java +++ b/hazelcast/src/main/java/com/hazelcast/client/impl/protocol/ClientMessageReader.java @@ -16,24 +16,33 @@ package com.hazelcast.client.impl.protocol; +import com.hazelcast.client.impl.protocol.exception.MaxMessageSizeExceeded; import com.hazelcast.internal.nio.Bits; +import static java.lang.String.format; + import java.nio.ByteBuffer; import java.util.LinkedList; import static com.hazelcast.client.impl.protocol.ClientMessage.IS_FINAL_FLAG; import static com.hazelcast.client.impl.protocol.ClientMessage.SIZE_OF_FRAME_LENGTH_AND_FLAGS; -public class ClientMessageReader { +public final class ClientMessageReader { private static final int INT_MASK = 0xffff; private int readIndex; private int readOffset = -1; + private int sumUntrustedMessageLength; + private final int maxMessageLength; private LinkedList frames = new LinkedList<>(); - public boolean readFrom(ByteBuffer src) { + public ClientMessageReader(int maxMessageLenth) { + this.maxMessageLength = maxMessageLenth > 0 ? maxMessageLenth : Integer.MAX_VALUE; + } + + public boolean readFrom(ByteBuffer src, boolean trusted) { for (; ; ) { - if (readFrame(src)) { + if (readFrame(src, trusted)) { if (ClientMessage.isFlagSet(frames.get(readIndex).flags, IS_FINAL_FLAG)) { return true; } @@ -50,13 +59,7 @@ public LinkedList getFrames() { return frames; } - public void reset() { - readIndex = 0; - readOffset = -1; - frames = new LinkedList<>(); - } - - private boolean readFrame(ByteBuffer src) { + private boolean readFrame(ByteBuffer src, boolean trusted) { // init internal buffer int remaining = src.remaining(); if (remaining < SIZE_OF_FRAME_LENGTH_AND_FLAGS) { @@ -65,6 +68,22 @@ private boolean readFrame(ByteBuffer src) { } if (readOffset == -1) { int frameLength = Bits.readIntL(src, src.position()); + if (frameLength < SIZE_OF_FRAME_LENGTH_AND_FLAGS) { + throw new IllegalArgumentException(format( + "The client message frame reported illegal length (%d bytes)." + + " Minimal length is the size of frame header (%d bytes).", + frameLength, SIZE_OF_FRAME_LENGTH_AND_FLAGS)); + } + if (!trusted) { + // check the message size overflow and message size limit + if (Integer.MAX_VALUE - frameLength < sumUntrustedMessageLength + || sumUntrustedMessageLength + frameLength > maxMessageLength) { + throw new MaxMessageSizeExceeded( + format("The client message size (%d + %d) exceededs the maximum allowed length (%d)", + sumUntrustedMessageLength, frameLength, maxMessageLength)); + } + sumUntrustedMessageLength += frameLength; + } src.position(src.position() + Bits.INT_SIZE_IN_BYTES); int flags = Bits.readShortL(src, src.position()) & INT_MASK; src.position(src.position() + Bits.SHORT_SIZE_IN_BYTES); diff --git a/hazelcast/src/main/java/com/hazelcast/client/impl/protocol/exception/MaxMessageSizeExceeded.java b/hazelcast/src/main/java/com/hazelcast/client/impl/protocol/exception/MaxMessageSizeExceeded.java index 8edfb1b6d297..141703228cb7 100644 --- a/hazelcast/src/main/java/com/hazelcast/client/impl/protocol/exception/MaxMessageSizeExceeded.java +++ b/hazelcast/src/main/java/com/hazelcast/client/impl/protocol/exception/MaxMessageSizeExceeded.java @@ -24,6 +24,9 @@ public class MaxMessageSizeExceeded extends HazelcastException { public MaxMessageSizeExceeded() { - super("The size of the message exceeds the maximum value of " + Integer.MAX_VALUE + " bytes."); + } + + public MaxMessageSizeExceeded(String message) { + super(message); } } diff --git a/hazelcast/src/main/java/com/hazelcast/client/impl/protocol/util/ClientMessageDecoder.java b/hazelcast/src/main/java/com/hazelcast/client/impl/protocol/util/ClientMessageDecoder.java index 7569057831a3..56f061764b57 100644 --- a/hazelcast/src/main/java/com/hazelcast/client/impl/protocol/util/ClientMessageDecoder.java +++ b/hazelcast/src/main/java/com/hazelcast/client/impl/protocol/util/ClientMessageDecoder.java @@ -16,6 +16,9 @@ package com.hazelcast.client.impl.protocol.util; +import com.hazelcast.client.impl.ClientEndpoint; +import com.hazelcast.client.impl.ClientEndpointManager; +import com.hazelcast.client.impl.ClientEngine; import com.hazelcast.client.impl.protocol.ClientMessage; import com.hazelcast.client.impl.protocol.ClientMessageReader; import com.hazelcast.internal.networking.HandlerStatus; @@ -23,9 +26,12 @@ import com.hazelcast.internal.nio.Bits; import com.hazelcast.internal.nio.Connection; import com.hazelcast.internal.util.collection.Long2ObjectHashMap; +import com.hazelcast.spi.properties.GroupProperty; +import com.hazelcast.spi.properties.HazelcastProperties; import java.nio.ByteBuffer; import java.util.LinkedList; +import java.util.Properties; import java.util.function.Consumer; import static com.hazelcast.client.impl.protocol.ClientMessage.BEGIN_FRAGMENT_FLAG; @@ -44,10 +50,20 @@ public class ClientMessageDecoder extends InboundHandlerWithCounters builderBySessionIdMap = new Long2ObjectHashMap<>(); - private ClientMessageReader activeReader = new ClientMessageReader(); + private ClientMessageReader activeReader; - public ClientMessageDecoder(Connection connection, Consumer dst) { + private boolean clientIsTrusted; + private final int maxMessageLength; + private final ClientEndpointManager clientEndpointManager; + + public ClientMessageDecoder(Connection connection, Consumer dst, HazelcastProperties properties) { dst(dst); + if (properties == null) { + properties = new HazelcastProperties((Properties) null); + } + clientEndpointManager = dst instanceof ClientEngine ? ((ClientEngine) dst).getEndpointManager() : null; + maxMessageLength = properties.getInteger(GroupProperty.CLIENT_PROTOCOL_UNVERIFIED_MESSAGE_BYTES); + activeReader = new ClientMessageReader(maxMessageLength); this.connection = connection; } @@ -61,7 +77,8 @@ public HandlerStatus onRead() { src.flip(); try { while (src.hasRemaining()) { - boolean complete = activeReader.readFrom(src); + boolean trusted = isEndpointTrusted(); + boolean complete = activeReader.readFrom(src, trusted); if (!complete) { break; } @@ -70,6 +87,9 @@ public HandlerStatus onRead() { int flags = firstFrame.flags; if (ClientMessage.isFlagSet(flags, UNFRAGMENTED_MESSAGE)) { handleMessage(activeReader); + } else if (!trusted) { + throw new IllegalStateException( + "Fragmented client messages are not allowed before the client is authenticated."); } else { //remove the fragmentationFrame activeReader.getFrames().removeFirst(); @@ -87,7 +107,7 @@ public HandlerStatus onRead() { } } - activeReader = new ClientMessageReader(); + activeReader = new ClientMessageReader(maxMessageLength); } return CLEAN; @@ -96,6 +116,15 @@ public HandlerStatus onRead() { } } + private boolean isEndpointTrusted() { + if (clientEndpointManager == null || clientIsTrusted) { + return true; + } + ClientEndpoint endpoint = clientEndpointManager.getEndpoint(connection); + clientIsTrusted = endpoint != null && endpoint.isAuthenticated(); + return clientIsTrusted; + } + private void handleMessage(ClientMessageReader clientMessageReader) { LinkedList frames = clientMessageReader.getFrames(); ClientMessage clientMessage = ClientMessage.createForDecode(frames); diff --git a/hazelcast/src/main/java/com/hazelcast/internal/nio/tcp/ClientChannelInitializer.java b/hazelcast/src/main/java/com/hazelcast/internal/nio/tcp/ClientChannelInitializer.java index 7286616a146f..66d77d0c5d1d 100644 --- a/hazelcast/src/main/java/com/hazelcast/internal/nio/tcp/ClientChannelInitializer.java +++ b/hazelcast/src/main/java/com/hazelcast/internal/nio/tcp/ClientChannelInitializer.java @@ -35,7 +35,7 @@ public class ClientChannelInitializer public void initChannel(Channel channel) { TcpIpConnection connection = (TcpIpConnection) channel.attributeMap().get(TcpIpConnection.class); SingleProtocolDecoder protocolDecoder = new SingleProtocolDecoder(CLIENT, - new ClientMessageDecoder(connection, ioService.getClientEngine())); + new ClientMessageDecoder(connection, ioService.getClientEngine(), ioService.properties())); channel.outboundPipeline().addLast(new ClientMessageEncoder()); channel.inboundPipeline().addLast(protocolDecoder); diff --git a/hazelcast/src/main/java/com/hazelcast/internal/nio/tcp/UnifiedProtocolDecoder.java b/hazelcast/src/main/java/com/hazelcast/internal/nio/tcp/UnifiedProtocolDecoder.java index b1b123f97497..c1d9536b815e 100644 --- a/hazelcast/src/main/java/com/hazelcast/internal/nio/tcp/UnifiedProtocolDecoder.java +++ b/hazelcast/src/main/java/com/hazelcast/internal/nio/tcp/UnifiedProtocolDecoder.java @@ -152,7 +152,7 @@ private void initChannelForClient() { .setOption(DIRECT_BUF, false); TcpIpConnection connection = (TcpIpConnection) channel.attributeMap().get(TcpIpConnection.class); - channel.inboundPipeline().replace(this, new ClientMessageDecoder(connection, ioService.getClientEngine())); + channel.inboundPipeline().replace(this, new ClientMessageDecoder(connection, ioService.getClientEngine(), props)); } private void initChannelForText(String protocol, boolean restApi) { diff --git a/hazelcast/src/main/java/com/hazelcast/spi/properties/GroupProperty.java b/hazelcast/src/main/java/com/hazelcast/spi/properties/GroupProperty.java index fafe15f24409..ca2bbd95253f 100644 --- a/hazelcast/src/main/java/com/hazelcast/spi/properties/GroupProperty.java +++ b/hazelcast/src/main/java/com/hazelcast/spi/properties/GroupProperty.java @@ -1085,6 +1085,12 @@ private int getWhenNoSSLDetected() { public static final HazelcastProperty MOBY_NAMING_ENABLED = new HazelcastProperty("hazelcast.member.naming.moby.enabled", true); + /** + * Client protocol message size limit (in bytes) for unverified connections (i.e. maximal length of authentication message). + */ + public static final HazelcastProperty CLIENT_PROTOCOL_UNVERIFIED_MESSAGE_BYTES = + new HazelcastProperty("hazelcast.client.protocol.max.message.bytes", 1024); + public static final HazelcastProperty AUDIT_LOG_ENABLED = new HazelcastProperty("hazelcast.auditlog.enabled", false); private GroupProperty() { diff --git a/hazelcast/src/test/java/com/hazelcast/client/impl/protocol/ClientExceptionFactoryTest.java b/hazelcast/src/test/java/com/hazelcast/client/impl/protocol/ClientExceptionFactoryTest.java index ade15b29a44a..e9a3e0c0fd15 100644 --- a/hazelcast/src/test/java/com/hazelcast/client/impl/protocol/ClientExceptionFactoryTest.java +++ b/hazelcast/src/test/java/com/hazelcast/client/impl/protocol/ClientExceptionFactoryTest.java @@ -247,7 +247,7 @@ public static Iterable parameters() { }, new Object[]{new NoDataMemberInClusterException(randomString())}, new Object[]{new ReplicatedMapCantBeCreatedOnLiteMemberException(randomString())}, - new Object[]{new MaxMessageSizeExceeded()}, + new Object[]{new MaxMessageSizeExceeded(randomString())}, new Object[]{new WanReplicationQueueFullException(randomString())}, new Object[]{new AssertionError(randomString())}, new Object[]{new OutOfMemoryError(randomString())}, diff --git a/hazelcast/src/test/java/com/hazelcast/client/impl/protocol/util/ClientMessageEncoderDecoderTest.java b/hazelcast/src/test/java/com/hazelcast/client/impl/protocol/util/ClientMessageEncoderDecoderTest.java index e500e9c7ac75..93ef11397635 100644 --- a/hazelcast/src/test/java/com/hazelcast/client/impl/protocol/util/ClientMessageEncoderDecoderTest.java +++ b/hazelcast/src/test/java/com/hazelcast/client/impl/protocol/util/ClientMessageEncoderDecoderTest.java @@ -74,7 +74,7 @@ public void test() { assertEquals(CLEAN, result); AtomicReference resultingMessage = new AtomicReference<>(); - ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set); + ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set, null); decoder.setNormalPacketsRead(SwCounter.newSwCounter()); buffer.position(buffer.limit()); @@ -107,7 +107,7 @@ public void testPut() { assertEquals(CLEAN, result); AtomicReference resultingMessage = new AtomicReference<>(); - ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set); + ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set, null); decoder.setNormalPacketsRead(SwCounter.newSwCounter()); buffer.position(buffer.limit()); @@ -151,7 +151,7 @@ public void testAuthenticationRequest() { assertEquals(CLEAN, result); AtomicReference resultingMessage = new AtomicReference<>(); - ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set); + ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set, null); decoder.setNormalPacketsRead(SwCounter.newSwCounter()); buffer.position(buffer.limit()); @@ -208,7 +208,7 @@ public void testAuthenticationResponse() throws UnknownHostException { assertEquals(CLEAN, result); AtomicReference resultingMessage = new AtomicReference<>(); - ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set); + ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set, null); decoder.setNormalPacketsRead(SwCounter.newSwCounter()); buffer.position(buffer.limit()); @@ -275,7 +275,7 @@ public void testEvent() throws UnknownHostException { assertEquals(CLEAN, result); AtomicReference resultingMessage = new AtomicReference<>(); - ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set); + ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set, null); decoder.setNormalPacketsRead(SwCounter.newSwCounter()); buffer.position(buffer.limit()); diff --git a/hazelcast/src/test/java/com/hazelcast/client/protocol/ClientMessageProtectionTest.java b/hazelcast/src/test/java/com/hazelcast/client/protocol/ClientMessageProtectionTest.java new file mode 100644 index 000000000000..96afc8aac0c1 --- /dev/null +++ b/hazelcast/src/test/java/com/hazelcast/client/protocol/ClientMessageProtectionTest.java @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2008-2019, Hazelcast, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hazelcast.client.protocol; + +import static com.hazelcast.client.impl.protocol.ClientMessage.IS_FINAL_FLAG; +import static com.hazelcast.client.impl.protocol.ClientMessage.SIZE_OF_FRAME_LENGTH_AND_FLAGS; +import static com.hazelcast.internal.nio.IOUtil.readFully; +import static com.hazelcast.internal.nio.Protocols.CLIENT_BINARY_NEW; +import static com.hazelcast.internal.util.StringUtil.UTF8_CHARSET; +import static com.hazelcast.test.HazelcastTestSupport.getNode; +import static com.hazelcast.test.HazelcastTestSupport.smallInstanceConfig; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import org.junit.After; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; + +import com.hazelcast.client.impl.protocol.AuthenticationStatus; +import com.hazelcast.client.impl.protocol.ClientMessage; +import com.hazelcast.client.impl.protocol.ClientMessage.Frame; +import com.hazelcast.client.impl.protocol.codec.ClientAuthenticationCodec; +import com.hazelcast.client.impl.protocol.util.ClientMessageSplitter; +import com.hazelcast.config.Config; +import com.hazelcast.core.HazelcastInstance; +import com.hazelcast.instance.EndpointQualifier; +import com.hazelcast.spi.properties.GroupProperty; +import com.hazelcast.test.HazelcastParallelClassRunner; +import com.hazelcast.test.TestAwareInstanceFactory; +import com.hazelcast.test.annotation.QuickTest; + +/** + * This class verifies that client protocol protection is able to filter large and fragmented messages for untrusted + * connections. + */ +@RunWith(HazelcastParallelClassRunner.class) +@Category({ QuickTest.class }) +public class ClientMessageProtectionTest { + + private final TestAwareInstanceFactory factory = new TestAwareInstanceFactory(); + + @Rule + public ExpectedException expected = ExpectedException.none(); + + @After + public void after() { + factory.terminateAll(); + } + + @Test + public void testLimitsRemovedAfterAValidAuthentication() throws IOException { + Config config = smallInstanceConfig(); + HazelcastInstance hz = factory.newHazelcastInstance(config); + ClientMessage clientMessage = createAuthenticationMessage(hz, createPassword(3)); + + InetSocketAddress address = getNode(hz).getLocalMember().getSocketAddress(EndpointQualifier.CLIENT); + try (Socket socket = new Socket(address.getAddress(), address.getPort())) { + socket.setSoTimeout(5000); + try (OutputStream os = socket.getOutputStream(); InputStream is = socket.getInputStream()) { + os.write(CLIENT_BINARY_NEW.getBytes(UTF8_CHARSET)); + writeClientMessage(os, clientMessage); + ClientMessage respMessage = readResponse(is); + assertEquals(ClientAuthenticationCodec.RESPONSE_MESSAGE_TYPE, respMessage.getMessageType()); + ClientAuthenticationCodec.ResponseParameters authnResponse = ClientAuthenticationCodec + .decodeResponse(respMessage); + assertEquals(AuthenticationStatus.AUTHENTICATED, AuthenticationStatus.getById(authnResponse.status)); + + // the connection is now trusted, lets try bigger and fragmented messages + ClientMessage authenticationMessage = createAuthenticationMessage(hz, createPassword(1024)); + writeClientMessage(os, authenticationMessage); + respMessage = readResponse(is); + assertEquals(ClientAuthenticationCodec.RESPONSE_MESSAGE_TYPE, respMessage.getMessageType()); + authnResponse = ClientAuthenticationCodec.decodeResponse(respMessage); + assertEquals(AuthenticationStatus.AUTHENTICATED, AuthenticationStatus.getById(authnResponse.status)); + + List subFrames = ClientMessageSplitter.getFragments(50, clientMessage); + assertTrue(subFrames.size() > 1); + for (ClientMessage frame : subFrames) { + writeClientMessage(os, frame); + } + respMessage = readResponse(is); + assertEquals(ClientAuthenticationCodec.RESPONSE_MESSAGE_TYPE, respMessage.getMessageType()); + authnResponse = ClientAuthenticationCodec.decodeResponse(respMessage); + assertEquals(AuthenticationStatus.AUTHENTICATED, AuthenticationStatus.getById(authnResponse.status)); + } + } + } + + @Test + public void testMessageFraming() throws IOException { + Config config = smallInstanceConfig(); + HazelcastInstance hz = factory.newHazelcastInstance(config); + ClientMessage clientMessage = createAuthenticationMessage(hz, createPassword(200)); + InetSocketAddress address = getNode(hz).getLocalMember().getSocketAddress(EndpointQualifier.CLIENT); + try (Socket socket = new Socket(address.getAddress(), address.getPort())) { + socket.setSoTimeout(5000); + try (OutputStream os = socket.getOutputStream(); InputStream is = socket.getInputStream()) { + os.write(CLIENT_BINARY_NEW.getBytes(UTF8_CHARSET)); + List subFrames = ClientMessageSplitter.getFragments(50, clientMessage); + assertTrue(subFrames.size() > 1); + writeClientMessage(os, subFrames.get(0)); + expected.expect(EOFException.class); + readResponse(is); + } + } + } + + @Test + public void testExceededMessageSize() throws IOException { + Config config = smallInstanceConfig(); + int limit = 800; + config.setProperty(GroupProperty.CLIENT_PROTOCOL_UNVERIFIED_MESSAGE_BYTES.getName(), Integer.toString(limit)); + HazelcastInstance hz = factory.newHazelcastInstance(config); + String password = createPassword(limit); + ClientMessage clientMessage = createAuthenticationMessage(hz, password); + InetSocketAddress address = getNode(hz).getLocalMember().getSocketAddress(EndpointQualifier.CLIENT); + try (Socket socket = new Socket(address.getAddress(), address.getPort())) { + socket.setSoTimeout(5000); + try (OutputStream os = socket.getOutputStream(); InputStream is = socket.getInputStream()) { + os.write(CLIENT_BINARY_NEW.getBytes(UTF8_CHARSET)); + writeClientMessage(os, clientMessage); + expected.expect(EOFException.class); + readResponse(is); + } + } + } + + @Test + public void testNegativeFrameLength() throws IOException { + Config config = smallInstanceConfig(); + HazelcastInstance hz = factory.newHazelcastInstance(config); + ClientMessage clientMessage = createAuthenticationMessage(hz, ""); + InetSocketAddress address = getNode(hz).getLocalMember().getSocketAddress(EndpointQualifier.CLIENT); + try (Socket socket = new Socket(address.getAddress(), address.getPort())) { + socket.setSoTimeout(5000); + try (OutputStream os = socket.getOutputStream(); InputStream is = socket.getInputStream()) { + os.write(CLIENT_BINARY_NEW.getBytes(UTF8_CHARSET)); + ByteBuffer buffer = ByteBuffer.allocateDirect(1024 * 1024); + buffer.order(ByteOrder.LITTLE_ENDIAN); + // it should be enough to write just the first frame + Frame frame = clientMessage.get(0); + buffer.putInt(Integer.MIN_VALUE); + buffer.putShort((short) (frame.flags)); + buffer.put(frame.content); + os.write(byteBufferToBytes(buffer)); + os.flush(); + expected.expect(EOFException.class); + readResponse(is); + } + } + } + + @Test + public void testAccumulatedMessageSizeOverflow() throws IOException { + Config config = smallInstanceConfig(); + HazelcastInstance hz = factory.newHazelcastInstance(config); + ClientMessage clientMessage = createAuthenticationMessage(hz, ""); + InetSocketAddress address = getNode(hz).getLocalMember().getSocketAddress(EndpointQualifier.CLIENT); + try (Socket socket = new Socket(address.getAddress(), address.getPort())) { + socket.setSoTimeout(5000); + try (OutputStream os = socket.getOutputStream(); InputStream is = socket.getInputStream()) { + os.write(CLIENT_BINARY_NEW.getBytes(UTF8_CHARSET)); + // it should be enough to write just the first frame + byte[] firstFrameBytes = frameAsBytes(clientMessage.get(0), false); + os.write(firstFrameBytes); + ByteBuffer buffer = ByteBuffer.allocateDirect(SIZE_OF_FRAME_LENGTH_AND_FLAGS); + buffer.order(ByteOrder.LITTLE_ENDIAN); + // try to cause the size accumulator overflow + buffer.putInt(Integer.MAX_VALUE - firstFrameBytes.length + 1); + Frame frame = clientMessage.get(1); + buffer.putShort((short) frame.flags); + os.write(byteBufferToBytes(buffer)); + os.flush(); + expected.expect(EOFException.class); + readResponse(is); + } + } + } + + private String createPassword(int pwdLength) { + return new String(new char[pwdLength]).replace('\0', 'a'); + } + + private ClientMessage createAuthenticationMessage(HazelcastInstance hz, String passwd) { + return ClientAuthenticationCodec.encodeRequest(hz.getConfig().getClusterName(), passwd, null, null, true, + "FOO", (byte) 1, "abc", "xxx", new ArrayList<>(), -1, null); + } + + private ClientMessage readResponse(InputStream is) throws IOException, EOFException { + LinkedList frames = new LinkedList<>(); + while (true) { + ByteBuffer frameSizeBuffer = ByteBuffer.allocate(SIZE_OF_FRAME_LENGTH_AND_FLAGS); + frameSizeBuffer.order(ByteOrder.LITTLE_ENDIAN); + readFully(is, frameSizeBuffer.array()); + int frameSize = frameSizeBuffer.getInt(); + int flags = frameSizeBuffer.getShort() & 0xffff; + byte[] content = new byte[frameSize - SIZE_OF_FRAME_LENGTH_AND_FLAGS]; + readFully(is, content); + frames.add(new ClientMessage.Frame(content, flags)); + if (ClientMessage.isFlagSet(flags, IS_FINAL_FLAG)) { + break; + } + } + ClientMessage respMessage = ClientMessage.createForDecode(frames); + return respMessage; + } + + private void writeClientMessage(OutputStream os, final ClientMessage clientMessage) throws IOException { + for (Iterator it = clientMessage.iterator(); it.hasNext();) { + ClientMessage.Frame frame = it.next(); + os.write(frameAsBytes(frame, it.hasNext())); + } + os.flush(); + } + + private byte[] frameAsBytes(ClientMessage.Frame frame, boolean isLastFrame) { + byte[] content = frame.content != null ? frame.content : new byte[0]; + int frameSize = content.length + SIZE_OF_FRAME_LENGTH_AND_FLAGS; + ByteBuffer buffer = ByteBuffer.allocateDirect(frameSize); + buffer.order(ByteOrder.LITTLE_ENDIAN); + buffer.putInt(frameSize); + if (isLastFrame) { + buffer.putShort((short) frame.flags); + } else { + buffer.putShort((short) (frame.flags | IS_FINAL_FLAG)); + } + buffer.put(content); + return byteBufferToBytes(buffer); + } + + private static byte[] byteBufferToBytes(ByteBuffer buffer) { + buffer.flip(); + byte[] requestBytes = new byte[buffer.limit()]; + buffer.get(requestBytes); + return requestBytes; + } + +} diff --git a/hazelcast/src/test/java/com/hazelcast/client/protocol/ClientMessageSplitAndBuildTest.java b/hazelcast/src/test/java/com/hazelcast/client/protocol/ClientMessageSplitAndBuildTest.java index 71418d250d01..3fb4b061477e 100644 --- a/hazelcast/src/test/java/com/hazelcast/client/protocol/ClientMessageSplitAndBuildTest.java +++ b/hazelcast/src/test/java/com/hazelcast/client/protocol/ClientMessageSplitAndBuildTest.java @@ -99,7 +99,7 @@ public void splitAndBuild() { Assert.assertEquals(CLEAN, result); AtomicReference resultingMessage = new AtomicReference<>(); - ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set); + ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set, null); decoder.setNormalPacketsRead(SwCounter.newSwCounter()); buffer.position(buffer.limit()); @@ -135,7 +135,7 @@ public void splitAndBuild_multipleMessages() { Assert.assertEquals(CLEAN, result); Queue inputQueue = new ConcurrentLinkedQueue<>(); - ClientMessageDecoder decoder = new ClientMessageDecoder(null, inputQueue::offer); + ClientMessageDecoder decoder = new ClientMessageDecoder(null, inputQueue::offer, null); decoder.setNormalPacketsRead(SwCounter.newSwCounter()); buffer.position(buffer.limit()); @@ -170,7 +170,7 @@ public void splitAndBuild_whenMessageIsAlreadySmallerThanFrameSize() { Assert.assertEquals(CLEAN, result); AtomicReference resultingMessage = new AtomicReference<>(); - ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set); + ClientMessageDecoder decoder = new ClientMessageDecoder(null, resultingMessage::set, null); decoder.setNormalPacketsRead(SwCounter.newSwCounter()); buffer.position(buffer.limit());