Skip to content

Commit

Permalink
Add support for multiple dictionary references in streaming decompres…
Browse files Browse the repository at this point in the history
…sion.
  • Loading branch information
Alex1OPS authored and luben committed May 16, 2023
1 parent 100c434 commit d0997db
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 9 deletions.
1 change: 1 addition & 0 deletions src/main/java/com/github/luben/zstd/Zstd.java
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ public static long decompressDirectByteBufferFastDict(ByteBuffer dst, int dstOff
public static native int setCompressionWorkers(long stream, int workers);
public static native int setDecompressionLongMax(long stream, int windowLogMax);
public static native int setDecompressionMagicless(long stream, boolean useMagicless);
public static native int setRefMultipleDDicts(long stream, boolean useMultiple);

/* Utility methods */

Expand Down
11 changes: 11 additions & 0 deletions src/main/java/com/github/luben/zstd/ZstdInputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,17 @@ public ZstdInputStream setLongMax(int windowLogMax) throws IOException {
return this;
}

/**
* Enable or disable support for multiple dictionary references
*
* @param useMultiple Enables references table for DDict, so the DDict used for decompression will be
* determined per the dictId in the frame, default: false
*/
public ZstdInputStream setRefMultipleDDicts(boolean useMultiple) throws IOException {
inner.setRefMultipleDDicts(useMultiple);
return this;
}

public int read(byte[] dst, int offset, int len) throws IOException {
return inner.read(dst, offset, len);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ public synchronized ZstdInputStreamNoFinalizer setLongMax(int windowLogMax) thro
return this;
}

public synchronized ZstdInputStreamNoFinalizer setRefMultipleDDicts(boolean useMultiple) throws IOException {
int size = Zstd.setRefMultipleDDicts(stream, useMultiple);
if (Zstd.isError(size)) {
throw new ZstdIOException(size);
}
return this;
}

public synchronized int read(byte[] dst, int offset, int len) throws IOException {
// guard agains buffer overflows
if (offset < 0 || len > dst.length - offset) {
Expand Down
11 changes: 11 additions & 0 deletions src/main/native/jni_zstd.c
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,17 @@ JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_setCompressionWorkers
return ZSTD_CCtx_setParameter((ZSTD_CCtx *)(intptr_t) stream, ZSTD_c_nbWorkers, workers);
}

/*
* Class: com_github_luben_zstd_Zstd
* Method: setRefMultipleDDicts
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_setRefMultipleDDicts
(JNIEnv *env, jclass obj, jlong stream, jboolean enabled) {
ZSTD_refMultipleDDicts_e value = enabled ? ZSTD_rmd_refMultipleDDicts : ZSTD_rmd_refSingleDDict;
return ZSTD_DCtx_setParameter((ZSTD_DCtx *)(intptr_t) stream, ZSTD_d_refMultipleDDicts, value);
}

/*
* Class: com_github_luben_zstd_Zstd
* Methods: header constants access
Expand Down
93 changes: 84 additions & 9 deletions src/test/scala/ZstdDict.scala
Original file line number Diff line number Diff line change
@@ -1,22 +1,19 @@
package com.github.luben.zstd

import org.scalatest.flatspec.AnyFlatSpec

import java.io._
import java.nio._
import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
import java.nio.file.StandardOpenOption

import scala.io._
import scala.collection.mutable.WrappedArray
import scala.util.Using

class ZstdDictSpec extends AnyFlatSpec {

def source = Source.fromFile("src/test/resources/xml")(Codec.ISO8859).map{_.toByte}

def train(legacy: Boolean): Array[Byte] = {
val src = source.sliding(1024, 1024).take(1024).map(_.toArray)
val trainer = new ZstdDictTrainer(1024 * 1024, 32 * 1024)
def train(legacy: Boolean, sampleSize: Int): Array[Byte] = {
val src = source.sliding(1024, 1024).take(sampleSize).map(_.toArray)
val trainer = new ZstdDictTrainer(1024 * sampleSize, 32 * sampleSize)
for (sample <- src) {
trainer.addSample(sample)
}
Expand Down Expand Up @@ -52,7 +49,8 @@ class ZstdDictSpec extends AnyFlatSpec {
val levels = List(1)
for {
legacy <- legacyS
dict = train(legacy)
dict = train(legacy, 1024)
dict2 = train(legacy, 512)
dictInDirectByteBuffer = wrapInDirectByteBuffer(dict)
level <- levels
} {
Expand Down Expand Up @@ -282,6 +280,83 @@ class ZstdDictSpec extends AnyFlatSpec {
assert(input.toSeq == output.toSeq)
}

it should s"round-trip streaming compression/decompression with multiple fast dicts with legacy $legacy " in {
// given: compress using first one dictionary, then another
val cdict = new ZstdDictCompress(dict, 0, dict.length, 1)
val cdict2 = new ZstdDictCompress(dict2, 0, dict2.length, 1)

val compressedWithDict1 = compressWithDict(cdict)
val compressedWithDict2 = compressWithDict(cdict2)

// when: decompress with the both dictionaries configured and multiple dict references enabled
val ddict = new ZstdDictDecompress(dict)
val ddict2 = new ZstdDictDecompress(dict2)

val dicts = ddict::ddict2::Nil
val uncompressed1 = uncompressWithMultipleDicts(compressedWithDict1, dicts)
val uncompressed2 = uncompressWithMultipleDicts(compressedWithDict2, dicts)

// then: both compressed inputs decompressed successfully
assert(uncompressed1.toSeq == input.toSeq)
assert(Zstd.getDictIdFromFrame(compressedWithDict1) == Zstd.getDictIdFromDict(dict))

assert(uncompressed2.toSeq == input.toSeq)
assert(Zstd.getDictIdFromFrame(compressedWithDict2) == Zstd.getDictIdFromDict(dict2))
}

it should s"round-trip streaming compression/decompression with multiple fast dicts with legacy $legacy and disabled multiple dict references" in {
// given: compress using first one dictionary, then another
val cdict = new ZstdDictCompress(dict, 0, dict.length, 1)
val cdict2 = new ZstdDictCompress(dict2, 0, dict2.length, 1)

val compressedWithDict1 = compressWithDict(cdict)
val compressedWithDict2 = compressWithDict(cdict2)

// when: decompress with the both dictionaries configured and multiple dict references disabled
// -> should be used only the second one
val ddict = new ZstdDictDecompress(dict)
val ddict2 = new ZstdDictDecompress(dict2)

val dicts = ddict :: ddict2 :: Nil
val uncompressed2 = uncompressWithMultipleDicts(compressedWithDict2, dicts, multipleDdicts = false)

// then: decompression of compressed with the first dict should fail with dictionary mismatch,
// the second one should be decompressed successfully
val caughtException = intercept[ZstdIOException] {
uncompressWithMultipleDicts(compressedWithDict1, dicts, multipleDdicts = false)
}
assert(caughtException.getMessage == "Dictionary mismatch")

assert(uncompressed2.toSeq == input.toSeq)
assert(Zstd.getDictIdFromFrame(compressedWithDict2) == Zstd.getDictIdFromDict(dict2))
}

def compressWithDict(cdict: ZstdDictCompress): Array[Byte] = {
val os = new ByteArrayOutputStream(Zstd.compressBound(input.length.toLong).toInt)
Using(new ZstdOutputStream(os, 1)) { zos =>
zos.setDict(cdict)
zos.write(input)
}
os.toByteArray
}

def uncompressWithMultipleDicts(
compressed: Array[Byte],
dicts: List[ZstdDictDecompress],
multipleDdicts: Boolean = true
): Array[Byte] = {
Using.resources(
new ZstdInputStream(new ByteArrayInputStream(compressed))
.setRefMultipleDDicts(multipleDdicts),
new ByteArrayOutputStream()
) { (zis, os) =>
dicts.foreach(zis.setDict)

zis.transferTo(os)
os.toByteArray
}
}

it should s"round-trip streaming ByteBuffer compression/decompression with byte[] dict with legacy $legacy" in {
val size = input.length
val os = ByteBuffer.allocateDirect(Zstd.compressBound(size.toLong).toInt)
Expand Down

0 comments on commit d0997db

Please sign in to comment.