Skip to content

Commit

Permalink
Use CloseableByteBuffer in compress/decompress signatures (#106724)
Browse files Browse the repository at this point in the history
CloseableByteBuffer is backed by native memory segments, but the
interfaces for compress and decompress methods of zstd take ByteBuffer.
Although both Jna and the Jdk can deal with turning the native
ByteBuffer back into an address to pass to the native method, the jdk
may have a more significant cost to that action.

This commit changes the signature of compress and decompress to take in
CloseableByteBuffer so that each implementation can do its own
unwrapping to get the appropriate native address.

relates #103374
  • Loading branch information
rjernst committed Mar 25, 2024
1 parent 78115fb commit 96230f7
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import java.nio.ByteBuffer;

class JnaCloseableByteBuffer implements CloseableByteBuffer {
private final Memory memory;
final Memory memory;
private final ByteBuffer bufferView;

JnaCloseableByteBuffer(int len) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,23 @@

import com.sun.jna.Library;
import com.sun.jna.Native;
import com.sun.jna.Pointer;

import org.elasticsearch.nativeaccess.CloseableByteBuffer;
import org.elasticsearch.nativeaccess.lib.ZstdLibrary;

import java.nio.ByteBuffer;

class JnaZstdLibrary implements ZstdLibrary {

private interface NativeFunctions extends Library {
long ZSTD_compressBound(int scrLen);

long ZSTD_compress(ByteBuffer dst, int dstLen, ByteBuffer src, int srcLen, int compressionLevel);
long ZSTD_compress(Pointer dst, int dstLen, Pointer src, int srcLen, int compressionLevel);

boolean ZSTD_isError(long code);

String ZSTD_getErrorName(long code);

long ZSTD_decompress(ByteBuffer dst, int dstLen, ByteBuffer src, int srcLen);
long ZSTD_decompress(Pointer dst, int dstLen, Pointer src, int srcLen);
}

private final NativeFunctions functions;
Expand All @@ -41,8 +41,18 @@ public long compressBound(int scrLen) {
}

@Override
public long compress(ByteBuffer dst, ByteBuffer src, int compressionLevel) {
return functions.ZSTD_compress(dst, dst.remaining(), src, src.remaining(), compressionLevel);
public long compress(CloseableByteBuffer dst, CloseableByteBuffer src, int compressionLevel) {
assert dst instanceof JnaCloseableByteBuffer;
assert src instanceof JnaCloseableByteBuffer;
var nativeDst = (JnaCloseableByteBuffer) dst;
var nativeSrc = (JnaCloseableByteBuffer) src;
return functions.ZSTD_compress(
nativeDst.memory.share(dst.buffer().position()),
dst.buffer().remaining(),
nativeSrc.memory.share(src.buffer().position()),
src.buffer().remaining(),
compressionLevel
);
}

@Override
Expand All @@ -56,7 +66,16 @@ public String getErrorName(long code) {
}

@Override
public long decompress(ByteBuffer dst, ByteBuffer src) {
return functions.ZSTD_decompress(dst, dst.remaining(), src, src.remaining());
public long decompress(CloseableByteBuffer dst, CloseableByteBuffer src) {
assert dst instanceof JnaCloseableByteBuffer;
assert src instanceof JnaCloseableByteBuffer;
var nativeDst = (JnaCloseableByteBuffer) dst;
var nativeSrc = (JnaCloseableByteBuffer) src;
return functions.ZSTD_decompress(
nativeDst.memory.share(dst.buffer().position()),
dst.buffer().remaining(),
nativeSrc.memory.share(src.buffer().position()),
src.buffer().remaining()
);
}
}
12 changes: 2 additions & 10 deletions libs/native/src/main/java/org/elasticsearch/nativeaccess/Zstd.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,9 @@ public final class Zstd {
* Compress the content of {@code src} into {@code dst} at compression level {@code level}, and return the number of compressed bytes.
* {@link ByteBuffer#position()} and {@link ByteBuffer#limit()} of both {@link ByteBuffer}s are left unmodified.
*/
public int compress(ByteBuffer dst, ByteBuffer src, int level) {
public int compress(CloseableByteBuffer dst, CloseableByteBuffer src, int level) {
Objects.requireNonNull(dst, "Null destination buffer");
Objects.requireNonNull(src, "Null source buffer");
assert dst.isDirect();
assert dst.isReadOnly() == false;
assert src.isDirect();
assert src.isReadOnly() == false;
long ret = zstdLib.compress(dst, src, level);
if (zstdLib.isError(ret)) {
throw new IllegalArgumentException(zstdLib.getErrorName(ret));
Expand All @@ -45,13 +41,9 @@ public int compress(ByteBuffer dst, ByteBuffer src, int level) {
* Compress the content of {@code src} into {@code dst}, and return the number of decompressed bytes. {@link ByteBuffer#position()} and
* {@link ByteBuffer#limit()} of both {@link ByteBuffer}s are left unmodified.
*/
public int decompress(ByteBuffer dst, ByteBuffer src) {
public int decompress(CloseableByteBuffer dst, CloseableByteBuffer src) {
Objects.requireNonNull(dst, "Null destination buffer");
Objects.requireNonNull(src, "Null source buffer");
assert dst.isDirect();
assert dst.isReadOnly() == false;
assert src.isDirect();
assert src.isReadOnly() == false;
long ret = zstdLib.decompress(dst, src);
if (zstdLib.isError(ret)) {
throw new IllegalArgumentException(zstdLib.getErrorName(ret));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@

package org.elasticsearch.nativeaccess.lib;

import java.nio.ByteBuffer;
import org.elasticsearch.nativeaccess.CloseableByteBuffer;

public non-sealed interface ZstdLibrary extends NativeLibrary {

long compressBound(int scrLen);

long compress(ByteBuffer dst, ByteBuffer src, int compressionLevel);
long compress(CloseableByteBuffer dst, CloseableByteBuffer src, int compressionLevel);

boolean isError(long code);

String getErrorName(long code);

long decompress(ByteBuffer dst, ByteBuffer src);
long decompress(CloseableByteBuffer dst, CloseableByteBuffer src);
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@
import org.elasticsearch.nativeaccess.CloseableByteBuffer;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.ByteBuffer;

class JdkCloseableByteBuffer implements CloseableByteBuffer {
private final Arena arena;
final MemorySegment segment;
private final ByteBuffer bufferView;

JdkCloseableByteBuffer(int len) {
this.arena = Arena.ofConfined();
this.bufferView = this.arena.allocate(len).asByteBuffer();
this.segment = arena.allocate(len);
this.bufferView = segment.asByteBuffer();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

package org.elasticsearch.nativeaccess.jdk;

import org.elasticsearch.nativeaccess.CloseableByteBuffer;
import org.elasticsearch.nativeaccess.lib.ZstdLibrary;

import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.MemorySegment;
import java.lang.invoke.MethodHandle;
import java.nio.ByteBuffer;

import static java.lang.foreign.ValueLayout.ADDRESS;
import static java.lang.foreign.ValueLayout.JAVA_BOOLEAN;
Expand Down Expand Up @@ -49,11 +49,17 @@ public long compressBound(int srcLen) {
}

@Override
public long compress(ByteBuffer dst, ByteBuffer src, int compressionLevel) {
var nativeDst = MemorySegment.ofBuffer(dst);
var nativeSrc = MemorySegment.ofBuffer(src);
public long compress(CloseableByteBuffer dst, CloseableByteBuffer src, int compressionLevel) {
assert dst instanceof JdkCloseableByteBuffer;
assert src instanceof JdkCloseableByteBuffer;
var nativeDst = (JdkCloseableByteBuffer) dst;
var nativeSrc = (JdkCloseableByteBuffer) src;
var dstSize = dst.buffer().remaining();
var srcSize = src.buffer().remaining();
var segmentDst = nativeDst.segment.asSlice(dst.buffer().position(), dstSize);
var segmentSrc = nativeSrc.segment.asSlice(src.buffer().position(), srcSize);
try {
return (long) compress$mh.invokeExact(nativeDst, dst.remaining(), nativeSrc, src.remaining(), compressionLevel);
return (long) compress$mh.invokeExact(segmentDst, dstSize, segmentSrc, srcSize, compressionLevel);
} catch (Throwable t) {
throw new AssertionError(t);
}
Expand All @@ -79,11 +85,17 @@ public String getErrorName(long code) {
}

@Override
public long decompress(ByteBuffer dst, ByteBuffer src) {
var nativeDst = MemorySegment.ofBuffer(dst);
var nativeSrc = MemorySegment.ofBuffer(src);
public long decompress(CloseableByteBuffer dst, CloseableByteBuffer src) {
assert dst instanceof JdkCloseableByteBuffer;
assert src instanceof JdkCloseableByteBuffer;
var nativeDst = (JdkCloseableByteBuffer) dst;
var nativeSrc = (JdkCloseableByteBuffer) src;
var dstSize = dst.buffer().remaining();
var srcSize = src.buffer().remaining();
var segmentDst = nativeDst.segment.asSlice(dst.buffer().position(), dstSize);
var segmentSrc = nativeSrc.segment.asSlice(src.buffer().position(), srcSize);
try {
return (long) decompress$mh.invokeExact(nativeDst, dst.remaining(), nativeSrc, src.remaining());
return (long) decompress$mh.invokeExact(segmentDst, dstSize, segmentSrc, srcSize);
} catch (Throwable t) {
throw new AssertionError(t);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ public void testCompressValidation() {
var srcBuf = src.buffer();
var dstBuf = dst.buffer();

var npe1 = expectThrows(NullPointerException.class, () -> zstd.compress(null, srcBuf, 0));
var npe1 = expectThrows(NullPointerException.class, () -> zstd.compress(null, src, 0));
assertThat(npe1.getMessage(), equalTo("Null destination buffer"));
var npe2 = expectThrows(NullPointerException.class, () -> zstd.compress(dstBuf, null, 0));
var npe2 = expectThrows(NullPointerException.class, () -> zstd.compress(dst, null, 0));
assertThat(npe2.getMessage(), equalTo("Null source buffer"));

// dst capacity too low
for (int i = 0; i < srcBuf.remaining(); ++i) {
srcBuf.put(i, randomByte());
}
var e = expectThrows(IllegalArgumentException.class, () -> zstd.compress(dstBuf, srcBuf, 0));
var e = expectThrows(IllegalArgumentException.class, () -> zstd.compress(dst, src, 0));
assertThat(e.getMessage(), equalTo("Destination buffer is too small"));
}
}
Expand All @@ -64,21 +64,21 @@ public void testDecompressValidation() {
var originalBuf = original.buffer();
var compressedBuf = compressed.buffer();

var npe1 = expectThrows(NullPointerException.class, () -> zstd.decompress(null, originalBuf));
var npe1 = expectThrows(NullPointerException.class, () -> zstd.decompress(null, original));
assertThat(npe1.getMessage(), equalTo("Null destination buffer"));
var npe2 = expectThrows(NullPointerException.class, () -> zstd.decompress(compressedBuf, null));
var npe2 = expectThrows(NullPointerException.class, () -> zstd.decompress(compressed, null));
assertThat(npe2.getMessage(), equalTo("Null source buffer"));

// Invalid compressed format
for (int i = 0; i < originalBuf.remaining(); ++i) {
originalBuf.put(i, (byte) i);
}
var e = expectThrows(IllegalArgumentException.class, () -> zstd.decompress(compressedBuf, originalBuf));
var e = expectThrows(IllegalArgumentException.class, () -> zstd.decompress(compressed, original));
assertThat(e.getMessage(), equalTo("Unknown frame descriptor"));

int compressedLength = zstd.compress(compressedBuf, originalBuf, 0);
int compressedLength = zstd.compress(compressed, original, 0);
compressedBuf.limit(compressedLength);
e = expectThrows(IllegalArgumentException.class, () -> zstd.decompress(restored.buffer(), compressedBuf));
e = expectThrows(IllegalArgumentException.class, () -> zstd.decompress(restored, compressed));
assertThat(e.getMessage(), equalTo("Destination buffer is too small"));

}
Expand Down Expand Up @@ -109,9 +109,9 @@ private void doTestRoundtrip(byte[] data) {
var restored = nativeAccess.newBuffer(data.length)
) {
original.buffer().put(0, data);
int compressedLength = zstd.compress(compressed.buffer(), original.buffer(), randomIntBetween(-3, 9));
int compressedLength = zstd.compress(compressed, original, randomIntBetween(-3, 9));
compressed.buffer().limit(compressedLength);
int decompressedLength = zstd.decompress(restored.buffer(), compressed.buffer());
int decompressedLength = zstd.decompress(restored, compressed);
assertThat(restored.buffer(), equalTo(original.buffer()));
assertThat(decompressedLength, equalTo(data.length));
}
Expand All @@ -127,15 +127,15 @@ private void doTestRoundtrip(byte[] data) {
original.buffer().put(decompressedOffset, data);
original.buffer().position(decompressedOffset);
compressed.buffer().position(compressedOffset);
int compressedLength = zstd.compress(compressed.buffer(), original.buffer(), randomIntBetween(-3, 9));
int compressedLength = zstd.compress(compressed, original, randomIntBetween(-3, 9));
compressed.buffer().limit(compressedOffset + compressedLength);
restored.buffer().position(decompressedOffset);
int decompressedLength = zstd.decompress(restored.buffer(), compressed.buffer());
int decompressedLength = zstd.decompress(restored, compressed);
assertThat(decompressedLength, equalTo(data.length));
assertThat(
restored.buffer().slice(decompressedOffset, data.length),
equalTo(original.buffer().slice(decompressedOffset, data.length))
);
assertThat(decompressedLength, equalTo(data.length));
}
}
}

0 comments on commit 96230f7

Please sign in to comment.