Skip to content

Commit

Permalink
Let OpenSslEngine.wrap(...) / OpenSslEngine.unwrap(...) behave like s…
Browse files Browse the repository at this point in the history
…tated in the javadocs.

Motivation:

OpenSslEngine.wrap(...) and OpenSslEngie.unwrap(...) may consume bytes even if an BUFFER_OVERFLOW / BUFFER_UNDERFLOW is detected. This is not correct as it should only consume bytes if it can process them without storing data between unwrap(...) / wrap(...) calls. Beside this it also should only process one record at a time.

Modifications:

- Correctly detect BUFFER_OVERFLOW / BUFFER_UNDERFLOW and only consume bytes if non of them is detected.
- Only process one record per call.

Result:

OpenSslEngine behaves like stated in the javadocs of SSLEngine.
  • Loading branch information
normanmaurer committed Nov 11, 2016
1 parent e47da7b commit fc3c9c9
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 29 deletions.
Expand Up @@ -67,6 +67,7 @@
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
import static javax.net.ssl.SSLEngineResult.Status.BUFFER_UNDERFLOW;
import static javax.net.ssl.SSLEngineResult.Status.CLOSED;
import static javax.net.ssl.SSLEngineResult.Status.OK;

Expand Down Expand Up @@ -416,9 +417,8 @@ private int writePlaintextData(final ByteBuffer src) {
/**
* Write encrypted data to the OpenSSL network BIO.
*/
private int writeEncryptedData(final ByteBuffer src) {
private int writeEncryptedData(final ByteBuffer src, int len) {
final int pos = src.position();
final int len = src.remaining();
final int netWrote;
if (src.isDirect()) {
final long addr = Buffer.address(src) + pos;
Expand All @@ -430,8 +430,12 @@ private int writeEncryptedData(final ByteBuffer src) {
final ByteBuf buf = alloc.directBuffer(len);
try {
final long addr = memoryAddress(buf);

buf.setBytes(0, src);
int newLimit = pos + len;
if (newLimit != src.remaining()) {
buf.setBytes(0, (ByteBuffer) src.duplicate().position(pos).limit(newLimit));
} else {
buf.setBytes(0, src);
}

netWrote = SSL.writeToBIO(networkBIO, addr, len);
if (netWrote >= 0) {
Expand Down Expand Up @@ -601,6 +605,11 @@ public final SSLEngineResult wrap(
}
}

if (dst.remaining() < MAX_ENCRYPTED_PACKET_LENGTH) {
// Can not hold the maximum packet so we need to tell the caller to use a bigger destination
// buffer.
return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0);
}
// There was no pending data in the network BIO -- encrypt any application data
int bytesProduced = 0;
int bytesConsumed = 0;
Expand Down Expand Up @@ -775,9 +784,29 @@ public final SSLEngineResult unwrap(
}
}

// Write encrypted data to network BIO
if (len < SslUtils.SSL_RECORD_HEADER_LENGTH) {
return new SSLEngineResult(BUFFER_UNDERFLOW, getHandshakeStatus(), 0, 0);
}

int packetLength = SslUtils.getEncryptedPacketLength(srcs, srcsOffset);
if (packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH > capacity) {
// No enough space in the destination buffer so signal the caller
// that the buffer needs to be increased.
return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0);
}

if (len < packetLength) {
// We either have no enough data to read the packet length at all or not enough for reading
// the whole packet.
return new SSLEngineResult(BUFFER_UNDERFLOW, getHandshakeStatus(), 0, 0);
}

int bytesConsumed = 0;
if (srcsOffset < srcsEndOffset) {

// Write encrypted data to network BIO
int packetLengthRemaining = packetLength;

do {
ByteBuffer src = srcs[srcsOffset];
int remaining = src.remaining();
Expand All @@ -787,9 +816,15 @@ public final SSLEngineResult unwrap(
srcsOffset++;
continue;
}
int written = writeEncryptedData(src);
// Write more encrypted data into the BIO. Ensure we only read one packet at a time as
// stated in the SSLEngine javadocs.
int written = writeEncryptedData(src, Math.min(packetLengthRemaining, src.remaining()));
if (written > 0) {
bytesConsumed += written;
packetLengthRemaining -= written;
if (packetLengthRemaining == 0) {
// A whole packet has been consumed.
break;
}

if (written == remaining) {
srcsOffset++;
Expand All @@ -808,6 +843,7 @@ public final SSLEngineResult unwrap(
break;
}
} while (srcsOffset < srcsEndOffset);
bytesConsumed = packetLength - packetLengthRemaining;
}

// Number of produced bytes
Expand Down
23 changes: 4 additions & 19 deletions handler/src/main/java/io/netty/handler/ssl/SslHandler.java
Expand Up @@ -197,14 +197,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
* {@code true} if and only if {@link SSLEngine} expects a direct buffer.
*/
private final boolean wantsDirectBuffer;
/**
* {@code true} if and only if {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} requires the output buffer
* to be always as large as {@link #maxPacketBufferSize} even if the input buffer contains small amount of data.
* <p>
* If this flag is {@code false}, we allocate a smaller output buffer.
* </p>
*/
private final boolean wantsLargeOutboundNetworkBuffer;

// END Platform-dependent flags

Expand Down Expand Up @@ -283,7 +275,6 @@ public SslHandler(SSLEngine engine, boolean startTls, Executor delegatedTaskExec

boolean opensslEngine = engine instanceof OpenSslEngine;
wantsDirectBuffer = opensslEngine;
wantsLargeOutboundNetworkBuffer = !opensslEngine;

/**
* When using JDK {@link SSLEngine}, we use {@link #MERGE_CUMULATOR} because it works only with
Expand Down Expand Up @@ -516,7 +507,7 @@ private void wrap(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLExcepti

ByteBuf buf = (ByteBuf) msg;
if (out == null) {
out = allocateOutNetBuf(ctx, buf.readableBytes());
out = allocateOutNetBuf(ctx);
}

SSLEngineResult result = wrap(alloc, engine, buf, out);
Expand Down Expand Up @@ -599,7 +590,7 @@ private void wrapNonAppData(ChannelHandlerContext ctx, boolean inUnwrap) throws
// See https://github.com/netty/netty/issues/5860
while (!ctx.isRemoved()) {
if (out == null) {
out = allocateOutNetBuf(ctx, 0);
out = allocateOutNetBuf(ctx);
}
SSLEngineResult result = wrap(alloc, engine, Unpooled.EMPTY_BUFFER, out);

Expand Down Expand Up @@ -1477,14 +1468,8 @@ private ByteBuf allocate(ChannelHandlerContext ctx, int capacity) {
* Allocates an outbound network buffer for {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} which can encrypt
* the specified amount of pending bytes.
*/
private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes) {
if (wantsLargeOutboundNetworkBuffer) {
return allocate(ctx, maxPacketBufferSize);
} else {
return allocate(ctx, Math.min(
pendingBytes + OpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH,
maxPacketBufferSize));
}
private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx) {
return allocate(ctx, maxPacketBufferSize);
}

private final class LazyChannelPromise extends DefaultPromise<Channel> {
Expand Down
88 changes: 88 additions & 0 deletions handler/src/main/java/io/netty/handler/ssl/SslUtils.java
Expand Up @@ -21,6 +21,8 @@
import io.netty.handler.codec.base64.Base64;
import io.netty.handler.codec.base64.Base64Dialect;

import java.nio.ByteBuffer;

/**
* Constants for SSL packets.
*/
Expand Down Expand Up @@ -120,6 +122,92 @@ static int getEncryptedPacketLength(ByteBuf buffer, int offset) {
return packetLength;
}

private static short unsignedByte(byte b) {
return (short) (b & 0xFF);
}

private static int unsignedShort(short s) {
return s & 0xFFFF;
}

static int getEncryptedPacketLength(ByteBuffer[] buffers, int offset) {
ByteBuffer buffer = buffers[offset];

// Check if everything we need is in one ByteBuffer. If so we can make use of the fast-path.
if (buffer.remaining() >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
return getEncryptedPacketLength(buffer);
}

// We need to copy 5 bytes into a temporary buffer so we can parse out the packet length easily.
ByteBuffer tmp = ByteBuffer.allocate(5);

do {
buffer = buffers[offset++].duplicate();
if (buffer.remaining() > tmp.remaining()) {
buffer.limit(buffer.position() + tmp.remaining());
}
tmp.put(buffer);
} while (tmp.hasRemaining());

// Done, flip the buffer so we can read from it.
tmp.flip();
return getEncryptedPacketLength(tmp);
}

private static int getEncryptedPacketLength(ByteBuffer buffer) {
int packetLength = 0;
int pos = buffer.position();
// SSLv3 or TLS - Check ContentType
boolean tls;
switch (unsignedByte(buffer.get(pos))) {
case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
case SslUtils.SSL_CONTENT_TYPE_ALERT:
case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
case SslUtils.SSL_CONTENT_TYPE_APPLICATION_DATA:
tls = true;
break;
default:
// SSLv2 or bad data
tls = false;
}

if (tls) {
// SSLv3 or TLS - Check ProtocolVersion
int majorVersion = unsignedByte(buffer.get(pos + 1));
if (majorVersion == 3) {
// SSLv3 or TLS
packetLength = unsignedShort(buffer.getShort(pos + 3)) + SslUtils.SSL_RECORD_HEADER_LENGTH;
if (packetLength <= SslUtils.SSL_RECORD_HEADER_LENGTH) {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false;
}
} else {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false;
}
}

if (!tls) {
// SSLv2 or bad data - Check the version
int headerLength = (unsignedByte(buffer.get(pos)) & 0x80) != 0 ? 2 : 3;
int majorVersion = unsignedByte(buffer.get(pos + headerLength + 1));
if (majorVersion == 2 || majorVersion == 3) {
// SSLv2
if (headerLength == 2) {
packetLength = (buffer.getShort(pos) & 0x7FFF) + 2;
} else {
packetLength = (buffer.getShort(pos) & 0x3FFF) + 3;
}
if (packetLength <= headerLength) {
return -1;
}
} else {
return -1;
}
}
return packetLength;
}

static void notifyHandshakeFailure(ChannelHandlerContext ctx, Throwable cause) {
// We have may haven written some parts of data before an exception was thrown so ensure we always flush.
// See https://github.com/netty/netty/issues/3900#issuecomment-172481830
Expand Down
85 changes: 82 additions & 3 deletions handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java
Expand Up @@ -678,9 +678,8 @@ protected void testEnablingAnAlreadyDisabledSslProtocol(String[] protocols1, Str
}

protected static void handshake(SSLEngine clientEngine, SSLEngine serverEngine) throws SSLException {
int netBufferSize = 17 * 1024;
ByteBuffer cTOs = ByteBuffer.allocateDirect(netBufferSize);
ByteBuffer sTOc = ByteBuffer.allocateDirect(netBufferSize);
ByteBuffer cTOs = ByteBuffer.allocateDirect(clientEngine.getSession().getPacketBufferSize());
ByteBuffer sTOc = ByteBuffer.allocateDirect(serverEngine.getSession().getPacketBufferSize());

ByteBuffer serverAppReadBuffer = ByteBuffer.allocateDirect(
serverEngine.getSession().getApplicationBufferSize());
Expand Down Expand Up @@ -915,4 +914,84 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc

promise.syncUninterruptibly();
}

@Test
public void testUnwrapBehavior() throws Exception {
SelfSignedCertificate cert = new SelfSignedCertificate();

clientSslCtx = SslContextBuilder
.forClient()
.trustManager(cert.cert())
.sslProvider(sslClientProvider())
.build();
SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT);

serverSslCtx = SslContextBuilder
.forServer(cert.certificate(), cert.privateKey())
.sslProvider(sslServerProvider())
.build();
SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT);

byte[] bytes = "Hello World".getBytes(CharsetUtil.US_ASCII);

try {
ByteBuffer plainClientOut = ByteBuffer.allocate(client.getSession().getApplicationBufferSize());
ByteBuffer encryptedClientToServer = ByteBuffer.allocate(server.getSession().getPacketBufferSize() * 2);
ByteBuffer plainServerIn = ByteBuffer.allocate(server.getSession().getApplicationBufferSize());

handshake(client, server);

// create two TLS frames

// first frame
plainClientOut.put(bytes, 0, 5);
plainClientOut.flip();

SSLEngineResult result = client.wrap(plainClientOut, encryptedClientToServer);
assertEquals(SSLEngineResult.Status.OK, result.getStatus());
assertEquals(5, result.bytesConsumed());
assertTrue(result.bytesProduced() > 0);

assertFalse(plainClientOut.hasRemaining());

// second frame
plainClientOut.clear();
plainClientOut.put(bytes, 5, 6);
plainClientOut.flip();

result = client.wrap(plainClientOut, encryptedClientToServer);
assertEquals(SSLEngineResult.Status.OK, result.getStatus());
assertEquals(6, result.bytesConsumed());
assertTrue(result.bytesProduced() > 0);

// send over to server
encryptedClientToServer.flip();

// try with too small output buffer first (to check BUFFER_OVERFLOW case)
int remaining = encryptedClientToServer.remaining();
ByteBuffer small = ByteBuffer.allocate(3);
result = server.unwrap(encryptedClientToServer, small);
assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus());
assertEquals(remaining, encryptedClientToServer.remaining());

// now with big enough buffer
result = server.unwrap(encryptedClientToServer, plainServerIn);
assertEquals(SSLEngineResult.Status.OK, result.getStatus());

assertEquals(5, result.bytesProduced());
assertTrue(encryptedClientToServer.hasRemaining());

result = server.unwrap(encryptedClientToServer, plainServerIn);
assertEquals(SSLEngineResult.Status.OK, result.getStatus());
assertEquals(6, result.bytesProduced());
assertFalse(encryptedClientToServer.hasRemaining());

plainServerIn.flip();

assertEquals(ByteBuffer.wrap(bytes), plainServerIn);
} finally {
cleanupClientSslEngine(client);
cleanupServerSslEngine(server);
}
}
}

0 comments on commit fc3c9c9

Please sign in to comment.