diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/FlushOperation.java b/libs/nio/src/main/java/org/elasticsearch/nio/FlushOperation.java index de0318a941af6..7a1696483db06 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/FlushOperation.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/FlushOperation.java @@ -25,8 +25,6 @@ public class FlushOperation { - private static final ByteBuffer[] EMPTY_ARRAY = new ByteBuffer[0]; - private final BiConsumer listener; private final ByteBuffer[] buffers; private final int[] offsets; @@ -63,38 +61,19 @@ public void incrementIndex(int delta) { } public ByteBuffer[] getBuffersToWrite() { - return getBuffersToWrite(length); - } - - public ByteBuffer[] getBuffersToWrite(int maxBytes) { final int index = Arrays.binarySearch(offsets, internalIndex); - final int offsetIndex = index < 0 ? (-(index + 1)) - 1 : index; - final int finalIndex = Arrays.binarySearch(offsets, Math.min(internalIndex + maxBytes, length)); - final int finalOffsetIndex = finalIndex < 0 ? (-(finalIndex + 1)) - 1 : finalIndex; + int offsetIndex = index < 0 ? (-(index + 1)) - 1 : index; - int nBuffers = (finalOffsetIndex - offsetIndex) + 1; + ByteBuffer[] postIndexBuffers = new ByteBuffer[buffers.length - offsetIndex]; - int firstBufferPosition = internalIndex - offsets[offsetIndex]; ByteBuffer firstBuffer = buffers[offsetIndex].duplicate(); - firstBuffer.position(firstBufferPosition); - if (nBuffers == 1 && firstBuffer.remaining() == 0) { - return EMPTY_ARRAY; - } - - ByteBuffer[] postIndexBuffers = new ByteBuffer[nBuffers]; + firstBuffer.position(internalIndex - offsets[offsetIndex]); postIndexBuffers[0] = firstBuffer; - int finalOffset = offsetIndex + nBuffers; - int nBytes = firstBuffer.remaining(); int j = 1; - for (int i = (offsetIndex + 1); i < finalOffset; ++i) { - ByteBuffer buffer = buffers[i].duplicate(); - nBytes += buffer.remaining(); - postIndexBuffers[j++] = buffer; + for (int i = (offsetIndex + 1); i < buffers.length; ++i) { + postIndexBuffers[j++] = buffers[i].duplicate(); } - int excessBytes = Math.max(0, nBytes - maxBytes); - ByteBuffer lastBuffer = postIndexBuffers[postIndexBuffers.length - 1]; - lastBuffer.limit(lastBuffer.limit() - excessBytes); return postIndexBuffers; } } diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/FlushReadyWrite.java b/libs/nio/src/main/java/org/elasticsearch/nio/FlushReadyWrite.java index 4855e0cbade9c..61c997603ff97 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/FlushReadyWrite.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/FlushReadyWrite.java @@ -27,7 +27,7 @@ public class FlushReadyWrite extends FlushOperation implements WriteOperation { private final SocketChannelContext channelContext; private final ByteBuffer[] buffers; - public FlushReadyWrite(SocketChannelContext channelContext, ByteBuffer[] buffers, BiConsumer listener) { + FlushReadyWrite(SocketChannelContext channelContext, ByteBuffer[] buffers, BiConsumer listener) { super(buffers, listener); this.channelContext = channelContext; this.buffers = buffers; diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/InboundChannelBuffer.java b/libs/nio/src/main/java/org/elasticsearch/nio/InboundChannelBuffer.java index 2dfd53d27e109..f7e6fbb768728 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/InboundChannelBuffer.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/InboundChannelBuffer.java @@ -19,6 +19,7 @@ package org.elasticsearch.nio; +import org.elasticsearch.common.util.concurrent.AbstractRefCounted; import org.elasticsearch.nio.utils.ExceptionsHelper; import java.nio.ByteBuffer; @@ -139,11 +140,11 @@ public ByteBuffer[] sliceBuffersTo(long to) { ByteBuffer[] buffers = new ByteBuffer[pageCount]; Iterator pageIterator = pages.iterator(); - ByteBuffer firstBuffer = pageIterator.next().byteBuffer().duplicate(); + ByteBuffer firstBuffer = pageIterator.next().byteBuffer.duplicate(); firstBuffer.position(firstBuffer.position() + offset); buffers[0] = firstBuffer; for (int i = 1; i < buffers.length; i++) { - buffers[i] = pageIterator.next().byteBuffer().duplicate(); + buffers[i] = pageIterator.next().byteBuffer.duplicate(); } if (finalLimit != 0) { buffers[buffers.length - 1].limit(finalLimit); @@ -179,14 +180,14 @@ public Page[] sliceAndRetainPagesTo(long to) { Page[] pages = new Page[pageCount]; Iterator pageIterator = this.pages.iterator(); Page firstPage = pageIterator.next().duplicate(); - ByteBuffer firstBuffer = firstPage.byteBuffer(); + ByteBuffer firstBuffer = firstPage.byteBuffer; firstBuffer.position(firstBuffer.position() + offset); pages[0] = firstPage; for (int i = 1; i < pages.length; i++) { pages[i] = pageIterator.next().duplicate(); } if (finalLimit != 0) { - pages[pages.length - 1].byteBuffer().limit(finalLimit); + pages[pages.length - 1].byteBuffer.limit(finalLimit); } return pages; @@ -216,9 +217,9 @@ public ByteBuffer[] sliceBuffersFrom(long from) { ByteBuffer[] buffers = new ByteBuffer[pages.size() - pageIndex]; Iterator pageIterator = pages.descendingIterator(); for (int i = buffers.length - 1; i > 0; --i) { - buffers[i] = pageIterator.next().byteBuffer().duplicate(); + buffers[i] = pageIterator.next().byteBuffer.duplicate(); } - ByteBuffer firstPostIndexBuffer = pageIterator.next().byteBuffer().duplicate(); + ByteBuffer firstPostIndexBuffer = pageIterator.next().byteBuffer.duplicate(); firstPostIndexBuffer.position(firstPostIndexBuffer.position() + indexInPage); buffers[0] = firstPostIndexBuffer; @@ -267,4 +268,53 @@ private int pageIndex(long index) { private int indexInPage(long index) { return (int) (index & PAGE_MASK); } + + public static class Page implements AutoCloseable { + + private final ByteBuffer byteBuffer; + // This is reference counted as some implementations want to retain the byte pages by calling + // sliceAndRetainPagesTo. With reference counting we can increment the reference count, return the + // pages, and safely close them when this channel buffer is done with them. The reference count + // would be 1 at that point, meaning that the pages will remain until the implementation closes + // theirs. + private final RefCountedCloseable refCountedCloseable; + + public Page(ByteBuffer byteBuffer, Runnable closeable) { + this(byteBuffer, new RefCountedCloseable(closeable)); + } + + private Page(ByteBuffer byteBuffer, RefCountedCloseable refCountedCloseable) { + this.byteBuffer = byteBuffer; + this.refCountedCloseable = refCountedCloseable; + } + + private Page duplicate() { + refCountedCloseable.incRef(); + return new Page(byteBuffer.duplicate(), refCountedCloseable); + } + + public ByteBuffer getByteBuffer() { + return byteBuffer; + } + + @Override + public void close() { + refCountedCloseable.decRef(); + } + + private static class RefCountedCloseable extends AbstractRefCounted { + + private final Runnable closeable; + + private RefCountedCloseable(Runnable closeable) { + super("byte array page"); + this.closeable = closeable; + } + + @Override + protected void closeInternal() { + closeable.run(); + } + } + } } diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/Page.java b/libs/nio/src/main/java/org/elasticsearch/nio/Page.java deleted file mode 100644 index b60c1c0127919..0000000000000 --- a/libs/nio/src/main/java/org/elasticsearch/nio/Page.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.nio; - -import org.elasticsearch.common.util.concurrent.AbstractRefCounted; - -import java.io.Closeable; -import java.nio.ByteBuffer; - -public class Page implements Closeable { - - private final ByteBuffer byteBuffer; - // This is reference counted as some implementations want to retain the byte pages by calling - // duplicate. With reference counting we can increment the reference count, return a new page, - // and safely close the pages independently. The closeable will not be called until each page is - // released. - private final RefCountedCloseable refCountedCloseable; - - public Page(ByteBuffer byteBuffer) { - this(byteBuffer, () -> {}); - } - - public Page(ByteBuffer byteBuffer, Runnable closeable) { - this(byteBuffer, new RefCountedCloseable(closeable)); - } - - private Page(ByteBuffer byteBuffer, RefCountedCloseable refCountedCloseable) { - this.byteBuffer = byteBuffer; - this.refCountedCloseable = refCountedCloseable; - } - - /** - * Duplicates this page and increments the reference count. The new page must be closed independently - * of the original page. - * - * @return the new page - */ - public Page duplicate() { - refCountedCloseable.incRef(); - return new Page(byteBuffer.duplicate(), refCountedCloseable); - } - - /** - * Returns the {@link ByteBuffer} for this page. Modifications to the limits, positions, etc of the - * buffer will also mutate this page. Call {@link ByteBuffer#duplicate()} to avoid mutating the page. - * - * @return the byte buffer - */ - public ByteBuffer byteBuffer() { - return byteBuffer; - } - - @Override - public void close() { - refCountedCloseable.decRef(); - } - - private static class RefCountedCloseable extends AbstractRefCounted { - - private final Runnable closeable; - - private RefCountedCloseable(Runnable closeable) { - super("byte array page"); - this.closeable = closeable; - } - - @Override - protected void closeInternal() { - closeable.run(); - } - } -} diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java b/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java index 1444422f7a7f6..816f4adc8cbb1 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java @@ -325,7 +325,7 @@ protected int flushToChannel(FlushOperation flushOperation) throws IOException { ioBuffer.clear(); ioBuffer.limit(Math.min(WRITE_LIMIT, ioBuffer.limit())); int j = 0; - ByteBuffer[] buffers = flushOperation.getBuffersToWrite(WRITE_LIMIT); + ByteBuffer[] buffers = flushOperation.getBuffersToWrite(); while (j < buffers.length && ioBuffer.remaining() > 0) { ByteBuffer buffer = buffers[j++]; copyBytes(buffer, ioBuffer); diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java index c98e7dc8dfb29..0591abdd69a97 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java @@ -31,7 +31,6 @@ import java.util.function.Consumer; import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -169,7 +168,7 @@ public void testQueuedWriteIsFlushedInFlushCall() throws Exception { assertTrue(context.readyForFlush()); - when(flushOperation.getBuffersToWrite(anyInt())).thenReturn(buffers); + when(flushOperation.getBuffersToWrite()).thenReturn(buffers); when(flushOperation.isFullyFlushed()).thenReturn(false, true); when(flushOperation.getListener()).thenReturn(listener); context.flushChannel(); @@ -188,7 +187,7 @@ public void testPartialFlush() throws IOException { assertTrue(context.readyForFlush()); when(flushOperation.isFullyFlushed()).thenReturn(false); - when(flushOperation.getBuffersToWrite(anyInt())).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)}); + when(flushOperation.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)}); context.flushChannel(); verify(listener, times(0)).accept(null, null); @@ -202,8 +201,8 @@ public void testMultipleWritesPartialFlushes() throws IOException { BiConsumer listener2 = mock(BiConsumer.class); FlushReadyWrite flushOperation1 = mock(FlushReadyWrite.class); FlushReadyWrite flushOperation2 = mock(FlushReadyWrite.class); - when(flushOperation1.getBuffersToWrite(anyInt())).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)}); - when(flushOperation2.getBuffersToWrite(anyInt())).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)}); + when(flushOperation1.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)}); + when(flushOperation2.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)}); when(flushOperation1.getListener()).thenReturn(listener); when(flushOperation2.getListener()).thenReturn(listener2); @@ -238,7 +237,7 @@ public void testWhenIOExceptionThrownListenerIsCalled() throws IOException { assertTrue(context.readyForFlush()); IOException exception = new IOException(); - when(flushOperation.getBuffersToWrite(anyInt())).thenReturn(buffers); + when(flushOperation.getBuffersToWrite()).thenReturn(buffers); when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception); when(flushOperation.getListener()).thenReturn(listener); expectThrows(IOException.class, () -> context.flushChannel()); @@ -253,7 +252,7 @@ public void testWriteIOExceptionMeansChannelReadyToClose() throws IOException { context.queueWriteOperation(flushOperation); IOException exception = new IOException(); - when(flushOperation.getBuffersToWrite(anyInt())).thenReturn(buffers); + when(flushOperation.getBuffersToWrite()).thenReturn(buffers); when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception); assertFalse(context.selectorShouldClose()); diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/FlushOperationTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/FlushOperationTests.java index 73dba34cc30f7..4f2a320ad583d 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/FlushOperationTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/FlushOperationTests.java @@ -65,45 +65,29 @@ public void testMultipleFlushesWithCompositeBuffer() throws IOException { ByteBuffer[] byteBuffers = writeOp.getBuffersToWrite(); assertEquals(3, byteBuffers.length); assertEquals(5, byteBuffers[0].remaining()); - ByteBuffer[] byteBuffersWithLimit = writeOp.getBuffersToWrite(10); - assertEquals(2, byteBuffersWithLimit.length); - assertEquals(5, byteBuffersWithLimit[0].remaining()); - assertEquals(5, byteBuffersWithLimit[1].remaining()); writeOp.incrementIndex(5); assertFalse(writeOp.isFullyFlushed()); byteBuffers = writeOp.getBuffersToWrite(); assertEquals(2, byteBuffers.length); assertEquals(15, byteBuffers[0].remaining()); - assertEquals(3, byteBuffers[1].remaining()); - byteBuffersWithLimit = writeOp.getBuffersToWrite(10); - assertEquals(1, byteBuffersWithLimit.length); - assertEquals(10, byteBuffersWithLimit[0].remaining()); writeOp.incrementIndex(2); assertFalse(writeOp.isFullyFlushed()); byteBuffers = writeOp.getBuffersToWrite(); assertEquals(2, byteBuffers.length); assertEquals(13, byteBuffers[0].remaining()); - assertEquals(3, byteBuffers[1].remaining()); - byteBuffersWithLimit = writeOp.getBuffersToWrite(10); - assertEquals(1, byteBuffersWithLimit.length); - assertEquals(10, byteBuffersWithLimit[0].remaining()); writeOp.incrementIndex(15); assertFalse(writeOp.isFullyFlushed()); byteBuffers = writeOp.getBuffersToWrite(); assertEquals(1, byteBuffers.length); assertEquals(1, byteBuffers[0].remaining()); - byteBuffersWithLimit = writeOp.getBuffersToWrite(10); - assertEquals(1, byteBuffersWithLimit.length); - assertEquals(1, byteBuffersWithLimit[0].remaining()); writeOp.incrementIndex(1); assertTrue(writeOp.isFullyFlushed()); byteBuffers = writeOp.getBuffersToWrite(); - assertEquals(0, byteBuffers.length); - byteBuffersWithLimit = writeOp.getBuffersToWrite(10); - assertEquals(0, byteBuffersWithLimit.length); + assertEquals(1, byteBuffers.length); + assertEquals(0, byteBuffers[0].remaining()); } } diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/InboundChannelBufferTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/InboundChannelBufferTests.java index f558043095372..8917bec39f17e 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/InboundChannelBufferTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/InboundChannelBufferTests.java @@ -30,8 +30,8 @@ public class InboundChannelBufferTests extends ESTestCase { private static final int PAGE_SIZE = PageCacheRecycler.PAGE_SIZE_IN_BYTES; - private final Supplier defaultPageSupplier = () -> - new Page(ByteBuffer.allocate(PageCacheRecycler.BYTE_PAGE_SIZE), () -> { + private final Supplier defaultPageSupplier = () -> + new InboundChannelBuffer.Page(ByteBuffer.allocate(PageCacheRecycler.BYTE_PAGE_SIZE), () -> { }); public void testNewBufferNoPages() { @@ -126,10 +126,10 @@ public void testIncrementIndexWithOffset() { public void testReleaseClosesPages() { ConcurrentLinkedQueue queue = new ConcurrentLinkedQueue<>(); - Supplier supplier = () -> { + Supplier supplier = () -> { AtomicBoolean atomicBoolean = new AtomicBoolean(); queue.add(atomicBoolean); - return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true)); + return new InboundChannelBuffer.Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true)); }; InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier); channelBuffer.ensureCapacity(PAGE_SIZE * 4); @@ -153,10 +153,10 @@ public void testReleaseClosesPages() { public void testClose() { ConcurrentLinkedQueue queue = new ConcurrentLinkedQueue<>(); - Supplier supplier = () -> { + Supplier supplier = () -> { AtomicBoolean atomicBoolean = new AtomicBoolean(); queue.add(atomicBoolean); - return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true)); + return new InboundChannelBuffer.Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true)); }; InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier); channelBuffer.ensureCapacity(PAGE_SIZE * 4); @@ -178,10 +178,10 @@ public void testClose() { public void testCloseRetainedPages() { ConcurrentLinkedQueue queue = new ConcurrentLinkedQueue<>(); - Supplier supplier = () -> { + Supplier supplier = () -> { AtomicBoolean atomicBoolean = new AtomicBoolean(); queue.add(atomicBoolean); - return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true)); + return new InboundChannelBuffer.Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true)); }; InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier); channelBuffer.ensureCapacity(PAGE_SIZE * 4); @@ -192,7 +192,7 @@ public void testCloseRetainedPages() { assertFalse(closedRef.get()); } - Page[] pages = channelBuffer.sliceAndRetainPagesTo(PAGE_SIZE * 2); + InboundChannelBuffer.Page[] pages = channelBuffer.sliceAndRetainPagesTo(PAGE_SIZE * 2); pages[1].close(); diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java index baf7abac79d1b..345c5197c76b8 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java @@ -285,7 +285,7 @@ public void testCloseClosesChannelBuffer() throws IOException { when(channel.getRawChannel()).thenReturn(realChannel); when(channel.isOpen()).thenReturn(true); Runnable closer = mock(Runnable.class); - Supplier pageSupplier = () -> new Page(ByteBuffer.allocate(1 << 14), closer); + Supplier pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer); InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); buffer.ensureCapacity(1); TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, buffer); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NettyAdaptor.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NettyAdaptor.java index 96db559e60333..c221fdf1378d7 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NettyAdaptor.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NettyAdaptor.java @@ -29,7 +29,7 @@ import io.netty.channel.embedded.EmbeddedChannel; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.nio.FlushOperation; -import org.elasticsearch.nio.Page; +import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.WriteOperation; import java.nio.ByteBuffer; @@ -97,7 +97,7 @@ public int read(ByteBuffer[] buffers) { return byteBuf.readerIndex() - initialReaderIndex; } - public int read(Page[] pages) { + public int read(InboundChannelBuffer.Page[] pages) { ByteBuf byteBuf = PagedByteBuf.byteBufFromPages(pages); int readableBytes = byteBuf.readableBytes(); nettyChannel.writeInbound(byteBuf); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java index 57936ff70c628..a5f274c7ccd34 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java @@ -43,7 +43,6 @@ import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.rest.RestUtils; @@ -206,9 +205,9 @@ private HttpChannelFactory() { @Override public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { NioHttpChannel httpChannel = new NioHttpChannel(channel); - java.util.function.Supplier pageSupplier = () -> { + java.util.function.Supplier pageSupplier = () -> { Recycler.V bytes = pageCacheRecycler.bytePage(false); - return new Page(ByteBuffer.wrap(bytes.v()), bytes::close); + return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this, handlingSettings, corsConfig); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/PagedByteBuf.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/PagedByteBuf.java index 359926d43f9a7..40f3aeecfbc94 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/PagedByteBuf.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/PagedByteBuf.java @@ -24,7 +24,7 @@ import io.netty.buffer.Unpooled; import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.buffer.UnpooledHeapByteBuf; -import org.elasticsearch.nio.Page; +import org.elasticsearch.nio.InboundChannelBuffer; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -39,7 +39,7 @@ private PagedByteBuf(byte[] array, Runnable releasable) { this.releasable = releasable; } - static ByteBuf byteBufFromPages(Page[] pages) { + static ByteBuf byteBufFromPages(InboundChannelBuffer.Page[] pages) { int componentCount = pages.length; if (componentCount == 0) { return Unpooled.EMPTY_BUFFER; @@ -48,15 +48,15 @@ static ByteBuf byteBufFromPages(Page[] pages) { } else { int maxComponents = Math.max(16, componentCount); final List components = new ArrayList<>(componentCount); - for (Page page : pages) { + for (InboundChannelBuffer.Page page : pages) { components.add(byteBufFromPage(page)); } return new CompositeByteBuf(UnpooledByteBufAllocator.DEFAULT, false, maxComponents, components); } } - private static ByteBuf byteBufFromPage(Page page) { - ByteBuffer buffer = page.byteBuffer(); + private static ByteBuf byteBufFromPage(InboundChannelBuffer.Page page) { + ByteBuffer buffer = page.getByteBuffer(); assert buffer.isDirect() == false && buffer.hasArray() : "Must be a heap buffer with an array"; int offset = buffer.arrayOffset() + buffer.position(); PagedByteBuf newByteBuf = new PagedByteBuf(buffer.array(), page::close); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java index 17dc6c41baac7..2f30143f18968 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -36,7 +36,6 @@ import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TcpTransport; @@ -158,9 +157,9 @@ private TcpChannelFactoryImpl(ProfileSettings profileSettings, boolean isClient) @Override public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) { NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel); - Supplier pageSupplier = () -> { + Supplier pageSupplier = () -> { Recycler.V bytes = pageCacheRecycler.bytePage(false); - return new Page(ByteBuffer.wrap(bytes.v()), bytes::close); + return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, NioTransport.this); Consumer exceptionHandler = (e) -> onException(nioChannel, e); diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/PagedByteBufTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/PagedByteBufTests.java index df4bf3274b3bc..15bd18ecf6959 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/PagedByteBufTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/PagedByteBufTests.java @@ -20,7 +20,7 @@ package org.elasticsearch.http.nio; import io.netty.buffer.ByteBuf; -import org.elasticsearch.nio.Page; +import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.test.ESTestCase; import java.nio.ByteBuffer; @@ -32,12 +32,12 @@ public class PagedByteBufTests extends ESTestCase { public void testReleasingPage() { AtomicInteger integer = new AtomicInteger(0); int pageCount = randomInt(10) + 1; - ArrayList pages = new ArrayList<>(); + ArrayList pages = new ArrayList<>(); for (int i = 0; i < pageCount; ++i) { - pages.add(new Page(ByteBuffer.allocate(10), integer::incrementAndGet)); + pages.add(new InboundChannelBuffer.Page(ByteBuffer.allocate(10), integer::incrementAndGet)); } - ByteBuf byteBuf = PagedByteBuf.byteBufFromPages(pages.toArray(new Page[0])); + ByteBuf byteBuf = PagedByteBuf.byteBufFromPages(pages.toArray(new InboundChannelBuffer.Page[0])); assertEquals(0, integer.get()); byteBuf.retain(); @@ -62,9 +62,9 @@ public void testBytesAreUsed() { bytes2[i - 10] = (byte) i; } - Page[] pages = new Page[2]; - pages[0] = new Page(ByteBuffer.wrap(bytes1), () -> {}); - pages[1] = new Page(ByteBuffer.wrap(bytes2), () -> {}); + InboundChannelBuffer.Page[] pages = new InboundChannelBuffer.Page[2]; + pages[0] = new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes1), () -> {}); + pages[1] = new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes2), () -> {}); ByteBuf byteBuf = PagedByteBuf.byteBufFromPages(pages); assertEquals(20, byteBuf.readableBytes()); @@ -73,13 +73,13 @@ public void testBytesAreUsed() { assertEquals((byte) i, byteBuf.getByte(i)); } - Page[] pages2 = new Page[2]; + InboundChannelBuffer.Page[] pages2 = new InboundChannelBuffer.Page[2]; ByteBuffer firstBuffer = ByteBuffer.wrap(bytes1); firstBuffer.position(2); ByteBuffer secondBuffer = ByteBuffer.wrap(bytes2); secondBuffer.limit(8); - pages2[0] = new Page(firstBuffer, () -> {}); - pages2[1] = new Page(secondBuffer, () -> {}); + pages2[0] = new InboundChannelBuffer.Page(firstBuffer, () -> {}); + pages2[1] = new InboundChannelBuffer.Page(secondBuffer, () -> {}); ByteBuf byteBuf2 = PagedByteBuf.byteBufFromPages(pages2); assertEquals(16, byteBuf2.readableBytes()); diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java index abb92979f8d11..537bfd3aefd21 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java @@ -41,7 +41,6 @@ import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioServerSocketChannel; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.ConnectionProfile; @@ -192,9 +191,9 @@ private MockTcpChannelFactory(boolean isClient, ProfileSettings profileSettings, @Override public MockSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { MockSocketChannel nioChannel = new MockSocketChannel(isClient == false, profileName, channel); - Supplier pageSupplier = () -> { + Supplier pageSupplier = () -> { Recycler.V bytes = pageCacheRecycler.bytePage(false); - return new Page(ByteBuffer.wrap(bytes.v()), bytes::close); + return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; MockTcpReadWriteHandler readWriteHandler = new MockTcpReadWriteHandler(nioChannel, MockNioTransport.this); BytesChannelContext context = new BytesChannelContext(nioChannel, selector, (e) -> exceptionCaught(nioChannel, e), diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java index 2c00dd7092950..b5d5db2166c1f 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java @@ -10,7 +10,6 @@ import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ReadWriteHandler; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.NioSelector; @@ -18,8 +17,6 @@ import javax.net.ssl.SSLEngine; import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.ClosedChannelException; import java.util.concurrent.TimeUnit; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -37,8 +34,6 @@ public final class SSLChannelContext extends SocketChannelContext { private static final Runnable DEFAULT_TIMEOUT_CANCELLER = () -> {}; private final SSLDriver sslDriver; - private final SSLOutboundBuffer outboundBuffer; - private FlushOperation encryptedFlush; private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER; SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, SSLDriver sslDriver, @@ -51,8 +46,6 @@ public final class SSLChannelContext extends SocketChannelContext { Predicate allowChannelPredicate) { super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate); this.sslDriver = sslDriver; - // TODO: When the bytes are actually recycled, we need to test that they are released on context close - this.outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); } @Override @@ -79,32 +72,34 @@ public void flushChannel() throws IOException { return; } // If there is currently data in the outbound write buffer, flush the buffer. - if (pendingChannelFlush()) { + if (sslDriver.hasFlushPending()) { // If the data is not completely flushed, exit. We cannot produce new write data until the // existing data has been fully flushed. - flushEncryptedOperation(); - if (pendingChannelFlush()) { + flushToChannel(sslDriver.getNetworkWriteBuffer()); + if (sslDriver.hasFlushPending()) { return; } } // If the driver is ready for application writes, we can attempt to proceed with any queued writes. if (sslDriver.readyForApplicationWrites()) { - FlushOperation unencryptedFlush; - while (pendingChannelFlush() == false && (unencryptedFlush = getPendingFlush()) != null) { - if (unencryptedFlush.isFullyFlushed()) { + FlushOperation currentFlush; + while (sslDriver.hasFlushPending() == false && (currentFlush = getPendingFlush()) != null) { + // If the current operation has been fully consumed (encrypted) we now know that it has been + // sent (as we only get to this point if the write buffer has been fully flushed). + if (currentFlush.isFullyFlushed()) { currentFlushOperationComplete(); } else { try { // Attempt to encrypt application write data. The encrypted data ends up in the // outbound write buffer. - sslDriver.write(unencryptedFlush, outboundBuffer); - if (outboundBuffer.hasEncryptedBytesToFlush() == false) { + int bytesEncrypted = sslDriver.applicationWrite(currentFlush.getBuffersToWrite()); + if (bytesEncrypted == 0) { break; } - encryptedFlush = outboundBuffer.buildNetworkFlushOperation(); + currentFlush.incrementIndex(bytesEncrypted); // Flush the write buffer to the channel - flushEncryptedOperation(); + flushToChannel(sslDriver.getNetworkWriteBuffer()); } catch (IOException e) { currentFlushOperationFailed(e); throw e; @@ -114,38 +109,23 @@ public void flushChannel() throws IOException { } else { // We are not ready for application writes, check if the driver has non-application writes. We // only want to continue producing new writes if the outbound write buffer is fully flushed. - while (pendingChannelFlush() == false && sslDriver.needsNonApplicationWrite()) { - sslDriver.nonApplicationWrite(outboundBuffer); + while (sslDriver.hasFlushPending() == false && sslDriver.needsNonApplicationWrite()) { + sslDriver.nonApplicationWrite(); // If non-application writes were produced, flush the outbound write buffer. - if (outboundBuffer.hasEncryptedBytesToFlush()) { - encryptedFlush = outboundBuffer.buildNetworkFlushOperation(); - flushEncryptedOperation(); + if (sslDriver.hasFlushPending()) { + flushToChannel(sslDriver.getNetworkWriteBuffer()); } } } } - private void flushEncryptedOperation() throws IOException { - try { - flushToChannel(encryptedFlush); - if (encryptedFlush.isFullyFlushed()) { - getSelector().executeListener(encryptedFlush.getListener(), null); - encryptedFlush = null; - } - } catch (IOException e) { - getSelector().executeFailedListener(encryptedFlush.getListener(), e); - encryptedFlush = null; - throw e; - } - } - @Override public boolean readyForFlush() { getSelector().assertOnSelectorThread(); if (sslDriver.readyForApplicationWrites()) { - return pendingChannelFlush() || super.readyForFlush(); + return sslDriver.hasFlushPending() || super.readyForFlush(); } else { - return pendingChannelFlush() || sslDriver.needsNonApplicationWrite(); + return sslDriver.hasFlushPending() || sslDriver.needsNonApplicationWrite(); } } @@ -169,7 +149,7 @@ public int read() throws IOException { @Override public boolean selectorShouldClose() { - return closeNow() || (sslDriver.isClosed() && pendingChannelFlush() == false); + return closeNow() || sslDriver.isClosed(); } @Override @@ -190,10 +170,7 @@ public void closeFromSelector() throws IOException { getSelector().assertOnSelectorThread(); if (channel.isOpen()) { closeTimeoutCanceller.run(); - if (encryptedFlush != null) { - getSelector().executeFailedListener(encryptedFlush.getListener(), new ClosedChannelException()); - } - IOUtils.close(super::closeFromSelector, outboundBuffer::close, sslDriver::close); + IOUtils.close(super::closeFromSelector, sslDriver::close); } } @@ -207,14 +184,9 @@ private void channelCloseTimeout() { getSelector().queueChannelClose(channel); } - private boolean pendingChannelFlush() { - return encryptedFlush != null; - } - private static class CloseNotifyOperation implements WriteOperation { - private static final BiConsumer LISTENER = (v, t) -> { - }; + private static final BiConsumer LISTENER = (v, t) -> {}; private static final Object WRITE_OBJECT = new Object(); private final SocketChannelContext channelContext; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java index 4dbf1d1f03fdf..93978bcc6a359 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.security.transport.nio; -import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.utils.ExceptionsHelper; @@ -30,17 +29,19 @@ * the buffer passed as an argument. Otherwise, it will be consumed internally and advance the SSL/TLS close * or handshake process. * - * Producing writes for a channel is more complicated. The method {@link #needsNonApplicationWrite()} can be - * called to determine if this driver needs to produce more data to advance the handshake or close process. - * If that method returns true, {@link #nonApplicationWrite(SSLOutboundBuffer)} should be called (and the - * data produced then flushed to the channel) until no further non-application writes are needed. + * Producing writes for a channel is more complicated. If there is existing data in the outbound write buffer + * as indicated by {@link #hasFlushPending()}, that data must be written to the channel before more outbound + * data can be produced. If no flushes are pending, {@link #needsNonApplicationWrite()} can be called to + * determine if this driver needs to produce more data to advance the handshake or close process. If that + * method returns true, {@link #nonApplicationWrite()} should be called (and the data produced then flushed + * to the channel) until no further non-application writes are needed. * * If no non-application writes are needed, {@link #readyForApplicationWrites()} can be called to determine * if the driver is ready to consume application data. (Note: It is possible that * {@link #readyForApplicationWrites()} and {@link #needsNonApplicationWrite()} can both return false if the * driver is waiting on non-application data from the peer.) If the driver indicates it is ready for - * application writes, {@link #write(FlushOperation, SSLOutboundBuffer)} can be called. This method will - * encrypt flush operation application data and place it in the outbound buffer for flushing to a channel. + * application writes, {@link #applicationWrite(ByteBuffer[])} can be called. This method will encrypt + * application data and place it in the write buffer for flushing to a channel. * * If you are ready to close the channel {@link #initiateClose()} should be called. After that is called, the * driver will start producing non-application writes related to notifying the peer connection that this @@ -49,23 +50,23 @@ */ public class SSLDriver implements AutoCloseable { - private static final ByteBuffer[] EMPTY_BUFFERS = {ByteBuffer.allocate(0)}; - private static final FlushOperation EMPTY_FLUSH_OPERATION = new FlushOperation(EMPTY_BUFFERS, (r, t) -> {}); + private static final ByteBuffer[] EMPTY_BUFFER_ARRAY = new ByteBuffer[0]; private final SSLEngine engine; private final boolean isClientMode; // This should only be accessed by the network thread associated with this channel, so nothing needs to // be volatile. private Mode currentMode = new HandshakeMode(); + private ByteBuffer networkWriteBuffer; private ByteBuffer networkReadBuffer; - private int packetSize; public SSLDriver(SSLEngine engine, boolean isClientMode) { this.engine = engine; this.isClientMode = isClientMode; SSLSession session = engine.getSession(); - packetSize = session.getPacketBufferSize(); - this.networkReadBuffer = ByteBuffer.allocate(packetSize); + this.networkReadBuffer = ByteBuffer.allocate(session.getPacketBufferSize()); + this.networkWriteBuffer = ByteBuffer.allocate(session.getPacketBufferSize()); + this.networkWriteBuffer.position(this.networkWriteBuffer.limit()); } public void init() throws SSLException { @@ -99,10 +100,18 @@ public SSLEngine getSSLEngine() { return engine; } + public boolean hasFlushPending() { + return networkWriteBuffer.hasRemaining(); + } + public boolean isHandshaking() { return currentMode.isHandshake(); } + public ByteBuffer getNetworkWriteBuffer() { + return networkWriteBuffer; + } + public ByteBuffer getNetworkReadBuffer() { return networkReadBuffer; } @@ -125,14 +134,15 @@ public boolean needsNonApplicationWrite() { return currentMode.needsNonApplicationWrite(); } - public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException { - return currentMode.write(applicationBytes, outboundBuffer); + public int applicationWrite(ByteBuffer[] buffers) throws SSLException { + assert readyForApplicationWrites() : "Should not be called if driver is not ready for application writes"; + return currentMode.write(buffers); } - public void nonApplicationWrite(SSLOutboundBuffer outboundBuffer) throws SSLException { + public void nonApplicationWrite() throws SSLException { assert currentMode.isApplication() == false : "Should not be called if driver is in application mode"; if (currentMode.isApplication() == false) { - currentMode.write(EMPTY_FLUSH_OPERATION, outboundBuffer); + currentMode.write(EMPTY_BUFFER_ARRAY); } else { throw new AssertionError("Attempted to non-application write from invalid mode: " + currentMode.modeName()); } @@ -195,36 +205,45 @@ private SSLEngineResult unwrap(InboundChannelBuffer buffer) throws SSLException } } - private SSLEngineResult wrap(SSLOutboundBuffer outboundBuffer) throws SSLException { - return wrap(outboundBuffer, EMPTY_FLUSH_OPERATION); - } + private SSLEngineResult wrap(ByteBuffer[] buffers) throws SSLException { + assert hasFlushPending() == false : "Should never called with pending writes"; - private SSLEngineResult wrap(SSLOutboundBuffer outboundBuffer, FlushOperation applicationBytes) throws SSLException { - ByteBuffer[] buffers = applicationBytes.getBuffersToWrite(engine.getSession().getApplicationBufferSize()); + networkWriteBuffer.clear(); while (true) { SSLEngineResult result; - ByteBuffer networkBuffer = outboundBuffer.nextWriteBuffer(packetSize); try { - result = engine.wrap(buffers, networkBuffer); + if (buffers.length == 1) { + result = engine.wrap(buffers[0], networkWriteBuffer); + } else { + result = engine.wrap(buffers, networkWriteBuffer); + } } catch (SSLException e) { - outboundBuffer.incrementEncryptedBytes(0); + networkWriteBuffer.position(networkWriteBuffer.limit()); throw e; } - outboundBuffer.incrementEncryptedBytes(result.bytesProduced()); - applicationBytes.incrementIndex(result.bytesConsumed()); switch (result.getStatus()) { case OK: + networkWriteBuffer.flip(); return result; case BUFFER_UNDERFLOW: throw new IllegalStateException("Should not receive BUFFER_UNDERFLOW on WRAP"); case BUFFER_OVERFLOW: - packetSize = engine.getSession().getPacketBufferSize(); - // There is not enough space in the network buffer for an entire SSL packet. We will - // allocate a buffer with the correct packet size the next time through the loop. + // There is not enough space in the network buffer for an entire SSL packet. Expand the + // buffer if it's smaller than the current session packet size. Otherwise return and wait + // for existing data to be flushed. + int currentCapacity = networkWriteBuffer.capacity(); + ensureNetworkWriteBufferSize(); + if (currentCapacity == networkWriteBuffer.capacity()) { + return result; + } break; case CLOSED: - assert result.bytesProduced() > 0 : "WRAP during close processing should produce close message."; + if (result.bytesProduced() > 0) { + networkWriteBuffer.flip(); + } else { + assert false : "WRAP during close processing should produce close message."; + } return result; default: throw new IllegalStateException("Unexpected WRAP result: " + result.getStatus()); @@ -246,12 +265,23 @@ private void ensureApplicationBufferSize(InboundChannelBuffer applicationBuffer) } } + private void ensureNetworkWriteBufferSize() { + networkWriteBuffer = ensureNetBufferSize(networkWriteBuffer); + } + private void ensureNetworkReadBufferSize() { - packetSize = engine.getSession().getPacketBufferSize(); - if (networkReadBuffer.capacity() < packetSize) { - ByteBuffer newBuffer = ByteBuffer.allocate(packetSize); - networkReadBuffer.flip(); - newBuffer.put(networkReadBuffer); + networkReadBuffer = ensureNetBufferSize(networkReadBuffer); + } + + private ByteBuffer ensureNetBufferSize(ByteBuffer current) { + int networkPacketSize = engine.getSession().getPacketBufferSize(); + if (current.capacity() < networkPacketSize) { + ByteBuffer newBuffer = ByteBuffer.allocate(networkPacketSize); + current.flip(); + newBuffer.put(current); + return newBuffer; + } else { + return current; } } @@ -276,7 +306,7 @@ private interface Mode { void read(InboundChannelBuffer buffer) throws SSLException; - int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException; + int write(ByteBuffer[] buffers) throws SSLException; boolean needsNonApplicationWrite(); @@ -299,7 +329,7 @@ private void startHandshake() throws SSLException { if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_WRAP) { try { - handshake(null); + handshake(); } catch (SSLException e) { closingInternal(); throw e; @@ -307,7 +337,7 @@ private void startHandshake() throws SSLException { } } - private void handshake(SSLOutboundBuffer outboundBuffer) throws SSLException { + private void handshake() throws SSLException { boolean continueHandshaking = true; while (continueHandshaking) { switch (handshakeStatus) { @@ -316,13 +346,11 @@ private void handshake(SSLOutboundBuffer outboundBuffer) throws SSLException { continueHandshaking = false; break; case NEED_WRAP: - if (outboundBuffer != null) { - handshakeStatus = wrap(outboundBuffer).getHandshakeStatus(); - // If we need NEED_TASK we should run the tasks immediately - if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_TASK) { - continueHandshaking = false; - } - } else { + if (hasFlushPending() == false) { + handshakeStatus = wrap(EMPTY_BUFFER_ARRAY).getHandshakeStatus(); + } + // If we need NEED_TASK we should run the tasks immediately + if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_TASK) { continueHandshaking = false; } break; @@ -351,7 +379,7 @@ public void read(InboundChannelBuffer buffer) throws SSLException { try { SSLEngineResult result = unwrap(buffer); handshakeStatus = result.getHandshakeStatus(); - handshake(null); + handshake(); // If we are done handshaking we should exit the handshake read continueUnwrap = result.bytesConsumed() > 0 && currentMode.isHandshake(); } catch (SSLException e) { @@ -362,9 +390,9 @@ public void read(InboundChannelBuffer buffer) throws SSLException { } @Override - public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException { + public int write(ByteBuffer[] buffers) throws SSLException { try { - handshake(outboundBuffer); + handshake(); } catch (SSLException e) { closingInternal(); throw e; @@ -417,7 +445,8 @@ private void maybeFinishHandshake() { String message = "Expected to be in handshaking/closed mode. Instead in application mode."; throw new AssertionError(message); } - } else { + } else if (hasFlushPending() == false) { + // We only acknowledge that we are done handshaking if there are no bytes that need to be written if (currentMode.isHandshake()) { currentMode = new ApplicationMode(); } else { @@ -444,17 +473,10 @@ public void read(InboundChannelBuffer buffer) throws SSLException { } @Override - public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException { - boolean continueWrap = true; - int totalBytesProduced = 0; - while (continueWrap && applicationBytes.isFullyFlushed() == false) { - SSLEngineResult result = wrap(outboundBuffer, applicationBytes); - int bytesProduced = result.bytesProduced(); - totalBytesProduced += bytesProduced; - boolean renegotiationRequested = maybeRenegotiation(result.getHandshakeStatus()); - continueWrap = bytesProduced > 0 && renegotiationRequested == false; - } - return totalBytesProduced; + public int write(ByteBuffer[] buffers) throws SSLException { + SSLEngineResult result = wrap(buffers); + maybeRenegotiation(result.getHandshakeStatus()); + return result.bytesConsumed(); } private boolean maybeRenegotiation(SSLEngineResult.HandshakeStatus newStatus) throws SSLException { @@ -538,19 +560,18 @@ public void read(InboundChannelBuffer buffer) throws SSLException { } @Override - public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException { - int bytesProduced = 0; - if (engine.isOutboundDone() == false) { - bytesProduced += wrap(outboundBuffer).bytesProduced(); - if (engine.isOutboundDone()) { - needToSendClose = false; - // Close inbound if it is still open and we have decided not to wait for response. - if (needToReceiveClose == false && engine.isInboundDone() == false) { - closeInboundAndSwallowPeerDidNotCloseException(); - } + public int write(ByteBuffer[] buffers) throws SSLException { + if (hasFlushPending() == false && engine.isOutboundDone()) { + needToSendClose = false; + // Close inbound if it is still open and we have decided not to wait for response. + if (needToReceiveClose == false && engine.isInboundDone() == false) { + closeInboundAndSwallowPeerDidNotCloseException(); } + } else { + wrap(EMPTY_BUFFER_ARRAY); + assert hasFlushPending() : "Should have produced close message"; } - return bytesProduced; + return 0; } @Override diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLOutboundBuffer.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLOutboundBuffer.java deleted file mode 100644 index 2cd28f7d7dc32..0000000000000 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLOutboundBuffer.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.security.transport.nio; - -import org.elasticsearch.core.internal.io.IOUtils; -import org.elasticsearch.nio.FlushOperation; -import org.elasticsearch.nio.Page; - -import java.nio.ByteBuffer; -import java.util.ArrayDeque; -import java.util.function.IntFunction; - -public class SSLOutboundBuffer implements AutoCloseable { - - private final ArrayDeque pages; - private final IntFunction pageSupplier; - - private Page currentPage; - - SSLOutboundBuffer(IntFunction pageSupplier) { - this.pages = new ArrayDeque<>(); - this.pageSupplier = pageSupplier; - } - - void incrementEncryptedBytes(int encryptedBytesProduced) { - if (encryptedBytesProduced != 0) { - currentPage.byteBuffer().limit(encryptedBytesProduced); - pages.addLast(currentPage); - } - currentPage = null; - } - - ByteBuffer nextWriteBuffer(int networkBufferSize) { - if (currentPage != null) { - // If there is an existing page, close it as it wasn't large enough to accommodate the SSLEngine. - currentPage.close(); - } - - Page newPage = pageSupplier.apply(networkBufferSize); - currentPage = newPage; - return newPage.byteBuffer().duplicate(); - } - - FlushOperation buildNetworkFlushOperation() { - int pageCount = pages.size(); - ByteBuffer[] byteBuffers = new ByteBuffer[pageCount]; - Page[] pagesToClose = new Page[pageCount]; - for (int i = 0; i < pageCount; ++i) { - Page page = pages.removeFirst(); - pagesToClose[i] = page; - byteBuffers[i] = page.byteBuffer(); - } - - return new FlushOperation(byteBuffers, (r, e) -> IOUtils.closeWhileHandlingException(pagesToClose)); - } - - boolean hasEncryptedBytesToFlush() { - return pages.isEmpty() == false; - } - - @Override - public void close() { - IOUtils.closeWhileHandlingException(pages); - } -} diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java index 8ecba16fa460d..9e0da2518835d 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java @@ -22,7 +22,6 @@ import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.threadpool.ThreadPool; @@ -93,9 +92,9 @@ private SecurityHttpChannelFactory() { @Override public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { NioHttpChannel httpChannel = new NioHttpChannel(channel); - Supplier pageSupplier = () -> { + Supplier pageSupplier = () -> { Recycler.V bytes = pageCacheRecycler.bytePage(false); - return new Page(ByteBuffer.wrap(bytes.v()), bytes::close); + return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this, handlingSettings, corsConfig); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java index d3f92a2575f6d..78c93ffb73cfd 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java @@ -21,7 +21,6 @@ import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.threadpool.ThreadPool; @@ -156,9 +155,9 @@ private SecurityTcpChannelFactory(RawChannelFactory rawChannelFactory, String pr @Override public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel); - Supplier pageSupplier = () -> { + Supplier pageSupplier = () -> { Recycler.V bytes = pageCacheRecycler.bytePage(false); - return new Page(ByteBuffer.wrap(bytes.v()), bytes::close); + return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close); }; TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, SecurityNioTransport.this); InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java index 893af2140b9b0..0870124022850 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java @@ -8,7 +8,6 @@ import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.nio.BytesWriteHandler; -import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.FlushReadyWrite; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSelector; @@ -29,7 +28,6 @@ import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; -import static org.mockito.Matchers.eq; import static org.mockito.Matchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -51,6 +49,7 @@ public class SSLChannelContextTests extends ESTestCase { private Consumer exceptionHandler; private SSLDriver sslDriver; private ByteBuffer readBuffer = ByteBuffer.allocate(1 << 14); + private ByteBuffer writeBuffer = ByteBuffer.allocate(1 << 14); private int messageLength; @Before @@ -74,6 +73,7 @@ public void init() { when(selector.isOnCurrentThread()).thenReturn(true); when(selector.getTaskScheduler()).thenReturn(nioTimer); when(sslDriver.getNetworkReadBuffer()).thenReturn(readBuffer); + when(sslDriver.getNetworkWriteBuffer()).thenReturn(writeBuffer); ByteBuffer buffer = ByteBuffer.allocate(1 << 14); when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> { buffer.clear(); @@ -85,7 +85,7 @@ public void testSuccessfulRead() throws IOException { byte[] bytes = createMessage(messageLength); when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length); - doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); + doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, 0); @@ -100,7 +100,7 @@ public void testMultipleReadsConsumed() throws IOException { byte[] bytes = createMessage(messageLength * 2); when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length); - doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); + doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, messageLength, 0); @@ -115,7 +115,7 @@ public void testPartialRead() throws IOException { byte[] bytes = createMessage(messageLength); when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length); - doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); + doAnswer(getAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer); when(readConsumer.apply(channelBuffer)).thenReturn(0); @@ -173,6 +173,7 @@ public void testSSLDriverClosedOnClose() throws IOException { public void testQueuedWritesAreIgnoredWhenNotReadyForAppWrites() { when(sslDriver.readyForApplicationWrites()).thenReturn(false); + when(sslDriver.hasFlushPending()).thenReturn(false); when(sslDriver.needsNonApplicationWrite()).thenReturn(false); context.queueWriteOperation(mock(FlushReadyWrite.class)); @@ -180,25 +181,25 @@ public void testQueuedWritesAreIgnoredWhenNotReadyForAppWrites() { assertFalse(context.readyForFlush()); } - public void testPendingEncryptedFlushMeansWriteInterested() throws Exception { - when(sslDriver.readyForApplicationWrites()).thenReturn(false); - when(sslDriver.needsNonApplicationWrite()).thenReturn(true, false); - doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); + public void testPendingFlushMeansWriteInterested() { + when(sslDriver.readyForApplicationWrites()).thenReturn(randomBoolean()); + when(sslDriver.hasFlushPending()).thenReturn(true); + when(sslDriver.needsNonApplicationWrite()).thenReturn(false); - // Call will put bytes in buffer to flush - context.flushChannel(); assertTrue(context.readyForFlush()); } public void testNeedsNonAppWritesMeansWriteInterested() { when(sslDriver.readyForApplicationWrites()).thenReturn(false); + when(sslDriver.hasFlushPending()).thenReturn(false); when(sslDriver.needsNonApplicationWrite()).thenReturn(true); assertTrue(context.readyForFlush()); } - public void testNoNonAppWriteInterestInAppMode() { + public void testNotWritesInterestInAppMode() { when(sslDriver.readyForApplicationWrites()).thenReturn(true); + when(sslDriver.hasFlushPending()).thenReturn(false); assertFalse(context.readyForFlush()); @@ -206,68 +207,66 @@ public void testNoNonAppWriteInterestInAppMode() { } public void testFirstFlushMustFinishForWriteToContinue() throws Exception { + when(sslDriver.hasFlushPending()).thenReturn(true, true); when(sslDriver.readyForApplicationWrites()).thenReturn(false); - when(sslDriver.needsNonApplicationWrite()).thenReturn(true); - doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); - // First call will put bytes in buffer to flush - context.flushChannel(); - assertTrue(context.readyForFlush()); - // Second call will will not continue generating non-app bytes because they still need to be flushed context.flushChannel(); - assertTrue(context.readyForFlush()); - verify(sslDriver, times(1)).nonApplicationWrite(any(SSLOutboundBuffer.class)); + verify(sslDriver, times(0)).nonApplicationWrite(); } public void testNonAppWrites() throws Exception { + when(sslDriver.hasFlushPending()).thenReturn(false, false, true, false, true); when(sslDriver.needsNonApplicationWrite()).thenReturn(true, true, false); when(sslDriver.readyForApplicationWrites()).thenReturn(false); - doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); - when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(1); context.flushChannel(); - verify(sslDriver, times(2)).nonApplicationWrite(any(SSLOutboundBuffer.class)); + verify(sslDriver, times(2)).nonApplicationWrite(); verify(rawChannel, times(2)).write(same(selector.getIoBuffer())); } public void testNonAppWritesStopIfBufferNotFullyFlushed() throws Exception { - when(sslDriver.needsNonApplicationWrite()).thenReturn(true); + when(sslDriver.hasFlushPending()).thenReturn(false, false, true, true); + when(sslDriver.needsNonApplicationWrite()).thenReturn(true, true, true, true); when(sslDriver.readyForApplicationWrites()).thenReturn(false); - doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); - when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(0); context.flushChannel(); - verify(sslDriver, times(1)).nonApplicationWrite(any(SSLOutboundBuffer.class)); + verify(sslDriver, times(1)).nonApplicationWrite(); verify(rawChannel, times(1)).write(same(selector.getIoBuffer())); } public void testQueuedWriteIsFlushedInFlushCall() throws Exception { ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; - FlushReadyWrite flushOperation = new FlushReadyWrite(context, buffers, listener); + FlushReadyWrite flushOperation = mock(FlushReadyWrite.class); context.queueWriteOperation(flushOperation); + when(flushOperation.getBuffersToWrite()).thenReturn(buffers); + when(flushOperation.getListener()).thenReturn(listener); + when(sslDriver.hasFlushPending()).thenReturn(false, false, false, false); when(sslDriver.readyForApplicationWrites()).thenReturn(true); - doAnswer(getWriteAnswer(10, true)).when(sslDriver).write(eq(flushOperation), any(SSLOutboundBuffer.class)); - - when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(10); + when(sslDriver.applicationWrite(buffers)).thenReturn(10); + when(flushOperation.isFullyFlushed()).thenReturn(false,true); context.flushChannel(); + verify(flushOperation).incrementIndex(10); verify(rawChannel, times(1)).write(same(selector.getIoBuffer())); verify(selector).executeListener(listener, null); assertFalse(context.readyForFlush()); } public void testPartialFlush() throws IOException { - ByteBuffer[] buffers = {ByteBuffer.allocate(5)}; - FlushReadyWrite flushOperation = new FlushReadyWrite(context, buffers, listener); + ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; + FlushReadyWrite flushOperation = mock(FlushReadyWrite.class); context.queueWriteOperation(flushOperation); + when(flushOperation.getBuffersToWrite()).thenReturn(buffers); + when(flushOperation.getListener()).thenReturn(listener); + when(sslDriver.hasFlushPending()).thenReturn(false, false, true); when(sslDriver.readyForApplicationWrites()).thenReturn(true); - doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(eq(flushOperation), any(SSLOutboundBuffer.class)); - when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(4); + when(sslDriver.applicationWrite(buffers)).thenReturn(5); + when(flushOperation.isFullyFlushed()).thenReturn(false, false); context.flushChannel(); verify(rawChannel, times(1)).write(same(selector.getIoBuffer())); @@ -280,16 +279,24 @@ public void testMultipleWritesPartialFlushes() throws IOException { BiConsumer listener2 = mock(BiConsumer.class); ByteBuffer[] buffers1 = {ByteBuffer.allocate(10)}; ByteBuffer[] buffers2 = {ByteBuffer.allocate(5)}; - FlushReadyWrite flushOperation1 = new FlushReadyWrite(context, buffers1, listener); - FlushReadyWrite flushOperation2 = new FlushReadyWrite(context, buffers2, listener2); + FlushReadyWrite flushOperation1 = mock(FlushReadyWrite.class); + FlushReadyWrite flushOperation2 = mock(FlushReadyWrite.class); + when(flushOperation1.getBuffersToWrite()).thenReturn(buffers1); + when(flushOperation2.getBuffersToWrite()).thenReturn(buffers2); + when(flushOperation1.getListener()).thenReturn(listener); + when(flushOperation2.getListener()).thenReturn(listener2); context.queueWriteOperation(flushOperation1); context.queueWriteOperation(flushOperation2); + when(sslDriver.hasFlushPending()).thenReturn(false, false, false, false, false, true); when(sslDriver.readyForApplicationWrites()).thenReturn(true); - doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(any(FlushOperation.class), any(SSLOutboundBuffer.class)); - when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(5, 5, 2); + when(sslDriver.applicationWrite(buffers1)).thenReturn(5, 5); + when(sslDriver.applicationWrite(buffers2)).thenReturn(3); + when(flushOperation1.isFullyFlushed()).thenReturn(false, false, true); + when(flushOperation2.isFullyFlushed()).thenReturn(false); context.flushChannel(); + verify(flushOperation1, times(2)).incrementIndex(5); verify(rawChannel, times(3)).write(same(selector.getIoBuffer())); verify(selector).executeListener(listener, null); verify(selector, times(0)).executeListener(listener2, null); @@ -297,27 +304,29 @@ public void testMultipleWritesPartialFlushes() throws IOException { } public void testWhenIOExceptionThrownListenerIsCalled() throws IOException { - ByteBuffer[] buffers = {ByteBuffer.allocate(5)}; - FlushReadyWrite flushOperation = new FlushReadyWrite(context, buffers, listener); + ByteBuffer[] buffers = {ByteBuffer.allocate(10)}; + FlushReadyWrite flushOperation = mock(FlushReadyWrite.class); context.queueWriteOperation(flushOperation); IOException exception = new IOException(); + when(flushOperation.getBuffersToWrite()).thenReturn(buffers); + when(flushOperation.getListener()).thenReturn(listener); + when(sslDriver.hasFlushPending()).thenReturn(false, false); when(sslDriver.readyForApplicationWrites()).thenReturn(true); - doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(eq(flushOperation), any(SSLOutboundBuffer.class)); + when(sslDriver.applicationWrite(buffers)).thenReturn(5); when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception); + when(flushOperation.isFullyFlushed()).thenReturn(false); expectThrows(IOException.class, () -> context.flushChannel()); + verify(flushOperation).incrementIndex(5); verify(selector).executeFailedListener(listener, exception); assertFalse(context.readyForFlush()); } public void testWriteIOExceptionMeansChannelReadyToClose() throws Exception { - when(sslDriver.readyForApplicationWrites()).thenReturn(false); + when(sslDriver.hasFlushPending()).thenReturn(true); when(sslDriver.needsNonApplicationWrite()).thenReturn(true); - doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); - - context.flushChannel(); - + when(sslDriver.readyForApplicationWrites()).thenReturn(false); when(rawChannel.write(any(ByteBuffer.class))).thenThrow(new IOException()); assertFalse(context.selectorShouldClose()); @@ -404,27 +413,7 @@ public void testRegisterInitiatesDriver() throws IOException { } } - private Answer getWriteAnswer(int bytesToEncrypt, boolean isApp) { - return invocationOnMock -> { - SSLOutboundBuffer outboundBuffer; - if (isApp) { - outboundBuffer = (SSLOutboundBuffer) invocationOnMock.getArguments()[1]; - } else { - outboundBuffer = (SSLOutboundBuffer) invocationOnMock.getArguments()[0]; - } - ByteBuffer byteBuffer = outboundBuffer.nextWriteBuffer(bytesToEncrypt + 1); - for (int i = 0; i < bytesToEncrypt; ++i) { - byteBuffer.put((byte) i); - } - outboundBuffer.incrementEncryptedBytes(bytesToEncrypt); - if (isApp) { - ((FlushOperation) invocationOnMock.getArguments()[0]).incrementIndex(bytesToEncrypt); - } - return bytesToEncrypt; - }; - } - - private Answer getReadAnswerForBytes(byte[] bytes) { + private Answer getAnswerForBytes(byte[] bytes) { return invocationOnMock -> { InboundChannelBuffer buffer = (InboundChannelBuffer) invocationOnMock.getArguments()[0]; buffer.ensureCapacity(buffer.getIndex() + bytes.length); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java index 4b86d3223b061..b1d39ddc6ac9f 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java @@ -6,9 +6,7 @@ package org.elasticsearch.xpack.security.transport.nio; import org.elasticsearch.bootstrap.JavaVersion; -import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; -import org.elasticsearch.nio.Page; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ssl.CertParsingUtils; import org.elasticsearch.xpack.core.ssl.PemUtils; @@ -30,7 +28,8 @@ public class SSLDriverTests extends ESTestCase { - private final Supplier pageSupplier = () -> new Page(ByteBuffer.allocate(1 << 14), () -> {}); + private final Supplier pageSupplier = + () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), () -> {}); private InboundChannelBuffer serverBuffer = new InboundChannelBuffer(pageSupplier); private InboundChannelBuffer clientBuffer = new InboundChannelBuffer(pageSupplier); private InboundChannelBuffer genericBuffer = new InboundChannelBuffer(pageSupplier); @@ -142,6 +141,10 @@ public void testHandshakeFailureBecauseProtocolMismatch() throws Exception { boolean expectedMessage = oldExpected.equals(sslException.getMessage()) || jdk11Expected.equals(sslException.getMessage()); assertTrue("Unexpected exception message: " + sslException.getMessage(), expectedMessage); + // In JDK11 we need an non-application write + if (serverDriver.needsNonApplicationWrite()) { + serverDriver.nonApplicationWrite(); + } // Prior to JDK11 we still need to send a close alert if (serverDriver.isClosed() == false) { failedCloseAlert(serverDriver, clientDriver, Arrays.asList("Received fatal alert: protocol_version", @@ -163,7 +166,10 @@ public void testHandshakeFailureBecauseNoCiphers() throws Exception { SSLDriver serverDriver = getDriver(serverEngine, false); expectThrows(SSLException.class, () -> handshake(clientDriver, serverDriver)); - + // In JDK11 we need an non-application write + if (serverDriver.needsNonApplicationWrite()) { + serverDriver.nonApplicationWrite(); + } // Prior to JDK11 we still need to send a close alert if (serverDriver.isClosed() == false) { List messages = Arrays.asList("Received fatal alert: handshake_failure", @@ -186,6 +192,8 @@ public void testCloseDuringHandshakeJDK11() throws Exception { sendHandshakeMessages(clientDriver, serverDriver); sendHandshakeMessages(serverDriver, clientDriver); + sendData(clientDriver, serverDriver); + assertTrue(clientDriver.isHandshaking()); assertTrue(serverDriver.isHandshaking()); @@ -219,6 +227,8 @@ public void testCloseDuringHandshakePreJDK11() throws Exception { sendHandshakeMessages(clientDriver, serverDriver); sendHandshakeMessages(serverDriver, clientDriver); + sendData(clientDriver, serverDriver); + assertTrue(clientDriver.isHandshaking()); assertTrue(serverDriver.isHandshaking()); @@ -296,12 +306,12 @@ private void normalClose(SSLDriver sendDriver, SSLDriver receiveDriver) throws I } private void sendNonApplicationWrites(SSLDriver sendDriver, SSLDriver receiveDriver) throws SSLException { - SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); - while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) { - if (outboundBuffer.hasEncryptedBytesToFlush()) { - sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver); - } else { - sendDriver.nonApplicationWrite(outboundBuffer); + while (sendDriver.needsNonApplicationWrite() || sendDriver.hasFlushPending()) { + if (sendDriver.hasFlushPending() == false) { + sendDriver.nonApplicationWrite(); + } + if (sendDriver.hasFlushPending()) { + sendData(sendDriver, receiveDriver, true); } } } @@ -316,7 +326,7 @@ private void handshake(SSLDriver clientDriver, SSLDriver serverDriver, boolean i serverDriver.init(); } - assertTrue(clientDriver.needsNonApplicationWrite()); + assertTrue(clientDriver.needsNonApplicationWrite() || clientDriver.hasFlushPending()); assertFalse(serverDriver.needsNonApplicationWrite()); sendHandshakeMessages(clientDriver, serverDriver); @@ -340,51 +350,58 @@ private void handshake(SSLDriver clientDriver, SSLDriver serverDriver, boolean i } private void sendHandshakeMessages(SSLDriver sendDriver, SSLDriver receiveDriver) throws IOException { - assertTrue(sendDriver.needsNonApplicationWrite()); - - SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); + assertTrue(sendDriver.needsNonApplicationWrite() || sendDriver.hasFlushPending()); - while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) { - if (outboundBuffer.hasEncryptedBytesToFlush()) { - sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver); + while (sendDriver.needsNonApplicationWrite() || sendDriver.hasFlushPending()) { + if (sendDriver.hasFlushPending() == false) { + sendDriver.nonApplicationWrite(); + } + if (sendDriver.isHandshaking()) { + assertTrue(sendDriver.hasFlushPending()); + sendData(sendDriver, receiveDriver); + assertFalse(sendDriver.hasFlushPending()); receiveDriver.read(genericBuffer); - } else { - sendDriver.nonApplicationWrite(outboundBuffer); } } if (receiveDriver.isHandshaking()) { - assertTrue(receiveDriver.needsNonApplicationWrite()); + assertTrue(receiveDriver.needsNonApplicationWrite() || receiveDriver.hasFlushPending()); } } private void sendAppData(SSLDriver sendDriver, SSLDriver receiveDriver, ByteBuffer[] message) throws IOException { + assertFalse(sendDriver.needsNonApplicationWrite()); int bytesToEncrypt = Arrays.stream(message).mapToInt(Buffer::remaining).sum(); - SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); - FlushOperation flushOperation = new FlushOperation(message, (r, l) -> {}); int bytesEncrypted = 0; while (bytesToEncrypt > bytesEncrypted) { - bytesEncrypted += sendDriver.write(flushOperation, outboundBuffer); - sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver); + bytesEncrypted += sendDriver.applicationWrite(message); + sendData(sendDriver, receiveDriver); } } - private void sendData(FlushOperation flushOperation, SSLDriver receiveDriver) { + private void sendData(SSLDriver sendDriver, SSLDriver receiveDriver) { + sendData(sendDriver, receiveDriver, randomBoolean()); + } + + private void sendData(SSLDriver sendDriver, SSLDriver receiveDriver, boolean partial) { + ByteBuffer writeBuffer = sendDriver.getNetworkWriteBuffer(); ByteBuffer readBuffer = receiveDriver.getNetworkReadBuffer(); - ByteBuffer[] writeBuffers = flushOperation.getBuffersToWrite(); - int bytesToEncrypt = Arrays.stream(writeBuffers).mapToInt(Buffer::remaining).sum(); - assert bytesToEncrypt < readBuffer.capacity() : "Flush operation must be less that read buffer"; - assert writeBuffers.length > 0 : "No write buffers"; + if (partial) { + int initialLimit = writeBuffer.limit(); + int bytesToWrite = writeBuffer.remaining() / (randomInt(2) + 2); + writeBuffer.limit(writeBuffer.position() + bytesToWrite); + readBuffer.put(writeBuffer); + writeBuffer.limit(initialLimit); + assertTrue(sendDriver.hasFlushPending()); + readBuffer.put(writeBuffer); + assertFalse(sendDriver.hasFlushPending()); - for (ByteBuffer writeBuffer : writeBuffers) { - int written = writeBuffer.remaining(); + } else { readBuffer.put(writeBuffer); - flushOperation.incrementIndex(written); + assertFalse(sendDriver.hasFlushPending()); } - - assertTrue(flushOperation.isFullyFlushed()); } private SSLDriver getDriver(SSLEngine engine, boolean isClient) {