Skip to content

Commit

Permalink
[SPARK-29971][CORE] Fix buffer leaks in `TransportFrameDecoder/Transp…
Browse files Browse the repository at this point in the history
…ortCipher`

- Correctly release `ByteBuf` in `TransportCipher` in all cases
- Move closing / releasing logic to `handlerRemoved(...)` so we are guaranteed that is always called.

We need to carefully manage the ownership / lifecycle of `ByteBuf` instances so we don't leak any of these. We did not correctly do this in all cases:
 - when end up in invalid cipher state.
 - when partial data was received and the channel is closed before the full frame is decoded

Fixes netty/netty#9784.

No.

Pass the newly added UTs.

Closes apache#26609 from normanmaurer/leaks_2_4.

Authored-by: Norman Maurer <norman_maurer@apple.com>
  • Loading branch information
normanmaurer committed Nov 25, 2019
1 parent 6880ccd commit e8e6e60
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 35 deletions.
Expand Up @@ -90,7 +90,8 @@ CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException
return new CryptoOutputStream(cipher, conf, ch, key, new IvParameterSpec(outIv));
}

private CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException {
@VisibleForTesting
CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException {
return new CryptoInputStream(cipher, conf, ch, key, new IvParameterSpec(inIv));
}

Expand Down Expand Up @@ -166,34 +167,45 @@ private static class DecryptionHandler extends ChannelInboundHandlerAdapter {

@Override
public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
if (!isCipherValid) {
throw new IOException("Cipher is in invalid state.");
}
byteChannel.feedData((ByteBuf) data);

byte[] decryptedData = new byte[byteChannel.readableBytes()];
int offset = 0;
while (offset < decryptedData.length) {
// SPARK-25535: workaround for CRYPTO-141.
try {
offset += cis.read(decryptedData, offset, decryptedData.length - offset);
} catch (InternalError ie) {
isCipherValid = false;
throw ie;
ByteBuf buffer = (ByteBuf) data;

try {
if (!isCipherValid) {
throw new IOException("Cipher is in invalid state.");
}
byte[] decryptedData = new byte[buffer.readableBytes()];
byteChannel.feedData(buffer);

int offset = 0;
while (offset < decryptedData.length) {
// SPARK-25535: workaround for CRYPTO-141.
try {
offset += cis.read(decryptedData, offset, decryptedData.length - offset);
} catch (InternalError ie) {
isCipherValid = false;
throw ie;
}
}
}

ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length));
ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length));
} finally {
buffer.release();
}
}

@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
// We do the closing of the stream / channel in handlerRemoved(...) as
// this method will be called in all cases:
//
// - when the Channel becomes inactive
// - when the handler is removed from the ChannelPipeline
try {
if (isCipherValid) {
cis.close();
}
} finally {
super.channelInactive(ctx);
super.handlerRemoved(ctx);
}
}
}
Expand Down
Expand Up @@ -19,44 +19,44 @@

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ReadableByteChannel;

import io.netty.buffer.ByteBuf;

public class ByteArrayReadableChannel implements ReadableByteChannel {
private ByteBuf data;
private boolean closed;

public int readableBytes() {
return data.readableBytes();
}

public void feedData(ByteBuf buf) {
public void feedData(ByteBuf buf) throws ClosedChannelException {
if (closed) {
throw new ClosedChannelException();
}
data = buf;
}

@Override
public int read(ByteBuffer dst) throws IOException {
if (closed) {
throw new ClosedChannelException();
}
int totalRead = 0;
while (data.readableBytes() > 0 && dst.remaining() > 0) {
int bytesToRead = Math.min(data.readableBytes(), dst.remaining());
dst.put(data.readSlice(bytesToRead).nioBuffer());
totalRead += bytesToRead;
}

if (data.readableBytes() == 0) {
data.release();
}

return totalRead;
}

@Override
public void close() throws IOException {
public void close() {
closed = true;
}

@Override
public boolean isOpen() {
return true;
return !closed;
}

}
Expand Up @@ -172,13 +172,9 @@ private ByteBuf nextBufferForFrame(int bytesToRead) {

@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
for (ByteBuf b : buffers) {
b.release();
}
if (interceptor != null) {
interceptor.channelInactive();
}
frameLenBuf.release();
super.channelInactive(ctx);
}

Expand All @@ -190,6 +186,20 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E
super.exceptionCaught(ctx, cause);
}

@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
// Release all buffers that are still in our ownership.
// Doing this in handlerRemoved(...) guarantees that this will happen in all cases:
// - When the Channel becomes inactive
// - When the decoder is removed from the ChannelPipeline
for (ByteBuf b : buffers) {
b.release();
}
buffers.clear();
frameLenBuf.release();
super.handlerRemoved(ctx);
}

public void setInterceptor(Interceptor interceptor) {
Preconditions.checkState(this.interceptor == null, "Already have an interceptor.");
this.interceptor = interceptor;
Expand Down
@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.crypto;

import javax.crypto.spec.SecretKeySpec;
import java.io.IOException;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import org.apache.commons.crypto.stream.CryptoInputStream;
import org.apache.commons.crypto.stream.CryptoOutputStream;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
import org.hamcrest.CoreMatchers;
import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class TransportCipherSuite {

@Test
public void testBufferNotLeaksOnInternalError() throws IOException {
String algorithm = "TestAlgorithm";
TransportConf conf = new TransportConf("Test", MapConfigProvider.EMPTY);
TransportCipher cipher = new TransportCipher(conf.cryptoConf(), conf.cipherTransformation(),
new SecretKeySpec(new byte[256], algorithm), new byte[0], new byte[0]) {

@Override
CryptoOutputStream createOutputStream(WritableByteChannel ch) {
return null;
}

@Override
CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException {
CryptoInputStream mockInputStream = mock(CryptoInputStream.class);
when(mockInputStream.read(any(byte[].class), anyInt(), anyInt()))
.thenThrow(new InternalError());
return mockInputStream;
}
};

EmbeddedChannel channel = new EmbeddedChannel();
cipher.addToChannel(channel);

ByteBuf buffer = Unpooled.wrappedBuffer(new byte[] { 1, 2 });
ByteBuf buffer2 = Unpooled.wrappedBuffer(new byte[] { 1, 2 });

try {
channel.writeInbound(buffer);
fail("Should have raised InternalError");
} catch (InternalError expected) {
// expected
assertEquals(0, buffer.refCnt());
}

try {
channel.writeInbound(buffer2);
fail("Should have raised an exception");
} catch (Throwable expected) {
assertThat(expected, CoreMatchers.instanceOf(IOException.class));
assertEquals(0, buffer2.refCnt());
}

// Simulate closing the connection
assertFalse(channel.finish());
}
}

0 comments on commit e8e6e60

Please sign in to comment.