Skip to content

Commit

Permalink
Create sequence producer tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Coder-256 authored and luben committed Nov 30, 2023
1 parent e4ad211 commit ae1ad52
Show file tree
Hide file tree
Showing 4 changed files with 263 additions and 3 deletions.
4 changes: 4 additions & 0 deletions src/main/java/com/github/luben/zstd/Zstd.java
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,9 @@ public static long decompressDirectByteBufferFastDict(ByteBuffer dst, int dstOff
public static native int loadDictCompress(long stream, byte[] dict, int dict_size);
public static native int loadFastDictCompress(long stream, ZstdDictCompress dict);
public static native void registerSequenceProducer(long stream, long seqProdState, long seqProdFunction);
public static native void generateSequences(long stream, long outSeqs, long outSeqsSize, long src, long srcSize);
static native long getBuiltinSequenceProducer(); // Used in tests
static native long getStubSequenceProducer(); // Used in tests
public static native int setCompressionChecksums(long stream, boolean useChecksums);
public static native int setCompressionMagicless(long stream, boolean useMagicless);
public static native int setCompressionLevel(long stream, int level);
Expand All @@ -578,6 +581,7 @@ public static long decompressDirectByteBufferFastDict(ByteBuffer dst, int dstOff
public static native int setDecompressionLongMax(long stream, int windowLogMax);
public static native int setDecompressionMagicless(long stream, boolean useMagicless);
public static native int setRefMultipleDDicts(long stream, boolean useMultiple);
public static native int setValidateSequences(long stream, boolean validateSequences);

/* Utility methods */
/**
Expand Down
18 changes: 18 additions & 0 deletions src/main/java/com/github/luben/zstd/ZstdCompressCtx.java
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,24 @@ public ZstdCompressCtx setSequenceProducerFallback(boolean fallbackFlag){
}
private static native void setSequenceProducerFallback0(long ptr, boolean fallbackFlag);

public ZstdCompressCtx setValidateSequences(boolean validateSequences) {
ensureOpen();
acquireSharedLock();
try {
long result = Zstd.setValidateSequences(nativePtr, validateSequences);
if (Zstd.isError(result)) {
throw new ZstdException(result);
}
} finally {
releaseSharedLock();
}
return this;
}

// Used in tests
long getNativePtr() {
return nativePtr;
}

/**
* Load compression dictionary to be used for subsequently compressed frames.
Expand Down
61 changes: 61 additions & 0 deletions src/main/native/jni_zstd.c
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,57 @@ JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_loadFastDictCompress
return ZSTD_CCtx_refCDict((ZSTD_CCtx *)(intptr_t) stream, cdict);
}

size_t builtinSequenceProducer(
void* sequenceProducerState,
ZSTD_Sequence* outSeqs, size_t outSeqsCapacity,
const void* src, size_t srcSize,
const void* dict, size_t dictSize,
int compressionLevel,
size_t windowSize
) {
ZSTD_CCtx *zc = (ZSTD_CCtx *)sequenceProducerState;
int windowLog = 0;
while (windowSize > 1) {
windowLog++;
windowSize >>= 1;
}
ZSTD_CCtx_setParameter(zc, ZSTD_c_compressionLevel, compressionLevel);
ZSTD_CCtx_setParameter(zc, ZSTD_c_windowLog, windowSize);
size_t numSeqs = ZSTD_generateSequences((ZSTD_CCtx *)sequenceProducerState, outSeqs, outSeqsCapacity, src, srcSize);
return ZSTD_isError(numSeqs) ? ZSTD_SEQUENCE_PRODUCER_ERROR : numSeqs;
}

size_t stubSequenceProducer(
void* sequenceProducerState,
ZSTD_Sequence* outSeqs, size_t outSeqsCapacity,
const void* src, size_t srcSize,
const void* dict, size_t dictSize,
int compressionLevel,
size_t windowSize
) {
return ZSTD_SEQUENCE_PRODUCER_ERROR;
}

/*
* Class: com_github_luben_zstd_Zstd
* Method: getBuiltinSequenceProducer
* Signature: ()J
*/
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_Zstd_getBuiltinSequenceProducer
(JNIEnv *env, jclass obj) {
return (jlong)(intptr_t)&builtinSequenceProducer;
}

/*
* Class: com_github_luben_zstd_Zstd
* Method: getBuiltinSequenceProducer
* Signature: ()J
*/
JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_Zstd_getStubSequenceProducer
(JNIEnv *env, jclass obj) {
return (jlong)(intptr_t)&stubSequenceProducer;
}

/*
* Class: com_github_luben_zstd_Zstd
* Method: registerSequenceProducer
Expand Down Expand Up @@ -489,6 +540,16 @@ JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_setRefMultipleDDicts
return ZSTD_DCtx_setParameter((ZSTD_DCtx *)(intptr_t) stream, ZSTD_d_refMultipleDDicts, value);
}

/*
* Class: com_github_luben_zstd_Zstd
* Method: setValidateSequences
* Signature: (JZ)I
*/
JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_setValidateSequences
(JNIEnv *env, jclass obj, jlong stream, jboolean validateSequences) {
return ZSTD_CCtx_setParameter((ZSTD_CCtx *)(intptr_t) stream, ZSTD_c_validateSequences, validateSequences);
}

/*
* Class: com_github_luben_zstd_Zstd
* Methods: header constants access
Expand Down
183 changes: 180 additions & 3 deletions src/test/scala/Zstd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
import java.nio.charset.Charset
import java.nio.file.StandardOpenOption
import scala.io._
import scala.annotation.unused
import scala.collection.mutable.WrappedArray
import scala.io._
import scala.util.Using

class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
Expand Down Expand Up @@ -1105,7 +1106,7 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
}
}

"streaming compressiong and decompression" should "roundtrip" in {
"streaming compression and decompression" should "roundtrip" in {
Using.Manager { use =>
val cctx = use(new ZstdCompressCtx())
val dctx = use(new ZstdDecompressCtx())
Expand Down Expand Up @@ -1149,7 +1150,7 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
decompressedBuffer.flip()

val comparison = inputBuffer.compareTo(decompressedBuffer)
comparison == 0 && Zstd.decompressedSize(compressedBuffer) == size && Zstd.getFrameContentSize(compressedBuffer) == size
assert(comparison == 0 && Zstd.decompressedSize(compressedBuffer) == size && Zstd.getFrameContentSize(compressedBuffer) == size)
}
}
}.get
Expand Down Expand Up @@ -1211,4 +1212,180 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
}
}
}.get

it should "be able to use a sequence producer" in {
Using.Manager { use =>
val cctx = use(new ZstdCompressCtx())
val cctx2 = use(new ZstdCompressCtx())
val dctx = use(new ZstdDecompressCtx())

forAll { input: Array[Byte] =>
{
val size = input.length
val inputBuffer = ByteBuffer.allocateDirect(size)
inputBuffer.put(input)
inputBuffer.flip()
cctx.reset()
cctx.setLevel(9)
val seqProd = new SequenceProducer {
def getFunctionPointer(): Long = {
Zstd.getBuiltinSequenceProducer()
}

def createState(): Long = {
cctx2.getNativePtr()
}

def freeState(@unused state: Long) = {}
}
cctx.registerSequenceProducer(seqProd)
cctx.setValidateSequences(true)
cctx.setSequenceProducerFallback(false)
cctx.setPledgedSrcSize(size)
val compressedBuffer = ByteBuffer.allocateDirect(Zstd.compressBound(size).toInt)
while (inputBuffer.hasRemaining) {
compressedBuffer.limit(compressedBuffer.position() + 1)
cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.CONTINUE)
}

var frameProgression = cctx.getFrameProgression()
assert(frameProgression.getIngested() == size)
assert(frameProgression.getFlushed() == compressedBuffer.position())

compressedBuffer.limit(compressedBuffer.capacity())
val done = cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.END)
assert(done)

frameProgression = cctx.getFrameProgression()
assert(frameProgression.getConsumed() == size)

compressedBuffer.flip()
val decompressedBuffer = ByteBuffer.allocateDirect(size)
dctx.reset()
while (compressedBuffer.hasRemaining) {
if (decompressedBuffer.limit() < decompressedBuffer.position()) {
decompressedBuffer.limit(compressedBuffer.position() + 1)
}
dctx.decompressDirectByteBufferStream(decompressedBuffer, compressedBuffer)
}

inputBuffer.rewind()
compressedBuffer.rewind()
decompressedBuffer.flip()

val comparison = inputBuffer.compareTo(decompressedBuffer)
assert(comparison == 0 && Zstd.decompressedSize(compressedBuffer) == size && Zstd.getFrameContentSize(compressedBuffer) == size)
}
}
}.get
}

it should "fail with a stub sequence producer" in {
Using.Manager { use =>
val cctx = use(new ZstdCompressCtx())

forAll(minSize(32)) { input: Array[Byte] =>
{
val size = input.length
val inputBuffer = ByteBuffer.allocateDirect(size)
inputBuffer.put(input)
inputBuffer.flip()
cctx.reset()
cctx.setLevel(9)

val seqProd = new SequenceProducer {
def getFunctionPointer(): Long = {
Zstd.getStubSequenceProducer()
}

def createState(): Long = { 0 }
def freeState(@unused state: Long) = { 0 }
}

cctx.registerSequenceProducer(seqProd)
cctx.setValidateSequences(true)
cctx.setSequenceProducerFallback(false)
cctx.setPledgedSrcSize(size)

val compressedBuffer = ByteBuffer.allocateDirect(Zstd.compressBound(size).toInt)
try {
while (inputBuffer.hasRemaining) {
compressedBuffer.limit(compressedBuffer.position() + 1)
cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.CONTINUE)
}
cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.END)
fail("compression succeeded, but should have failed")
} catch {
case _: ZstdException => // compression should throw a ZstdException
}
}
}
}.get
}

it should "succeed with a stub sequence producer and software fallback" in {
Using.Manager { use =>
val cctx = use(new ZstdCompressCtx())
val dctx = use(new ZstdDecompressCtx())

forAll { input: Array[Byte] =>
{
val size = input.length
val inputBuffer = ByteBuffer.allocateDirect(size)
inputBuffer.put(input)
inputBuffer.flip()
cctx.reset()
cctx.setLevel(9)

val seqProd = new SequenceProducer {
def getFunctionPointer(): Long = {
Zstd.getStubSequenceProducer()
}

def createState(): Long = { 0 }
def freeState(@unused state: Long) = { 0 }
}

cctx.registerSequenceProducer(seqProd)
cctx.setValidateSequences(true)
cctx.setSequenceProducerFallback(true) // !!
cctx.setPledgedSrcSize(size)

val compressedBuffer = ByteBuffer.allocateDirect(Zstd.compressBound(size).toInt)
while (inputBuffer.hasRemaining) {
compressedBuffer.limit(compressedBuffer.position() + 1)
cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.CONTINUE)
}

var frameProgression = cctx.getFrameProgression()
assert(frameProgression.getIngested() == size)
assert(frameProgression.getFlushed() == compressedBuffer.position())

compressedBuffer.limit(compressedBuffer.capacity())
val done = cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.END)
assert(done)

frameProgression = cctx.getFrameProgression()
assert(frameProgression.getConsumed() == size)

compressedBuffer.flip()
val decompressedBuffer = ByteBuffer.allocateDirect(size)
dctx.reset()
while (compressedBuffer.hasRemaining) {
if (decompressedBuffer.limit() < decompressedBuffer.position()) {
decompressedBuffer.limit(compressedBuffer.position() + 1)
}
dctx.decompressDirectByteBufferStream(decompressedBuffer, compressedBuffer)
}

inputBuffer.rewind()
compressedBuffer.rewind()
decompressedBuffer.flip()

val comparison = inputBuffer.compareTo(decompressedBuffer)
assert(comparison == 0 && Zstd.decompressedSize(compressedBuffer) == size && Zstd.getFrameContentSize(compressedBuffer) == size)
}
}
}.get
}
}

0 comments on commit ae1ad52

Please sign in to comment.