Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions cli/tests/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,34 @@ custom_unittest(
],
)

custom_unittest(
name = "serial_segmentation_default_test",
command = [
"$(location :integration_test_bin)",
"$(location ..:zli)",
"SerialSegmentationTest.test_serial_default_chunk_size",
],
type = "simple",
deps = [
"..:zli",
":integration_test_bin",
],
)

custom_unittest(
name = "serial_segmentation_chunk_size_test",
command = [
"$(location :integration_test_bin)",
"$(location ..:zli)",
"SerialSegmentationTest.test_serial_with_chunk_size",
],
type = "simple",
deps = [
"..:zli",
":integration_test_bin",
],
)

custom_unittest(
name = "strict_mode_permissive_test",
command = [
Expand Down
57 changes: 57 additions & 0 deletions cli/tests/cli_integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,63 @@ def test_numeric_profiles_with_chunk_size(self) -> None:
)


class SerialSegmentationTest(unittest.TestCase):
"""
Test case for the serial profile's auto-segmentation via the CLI.

Generates raw byte data, compresses with the `serial` profile, decompresses,
and verifies round-trip correctness with and without --chunk-size.
"""

def setUp(self) -> None:
import shutil
import tempfile

self.tmpdir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(self.tmpdir, True))

# Generate ~2MB of data so --chunk-size 1M triggers multi-chunk
# segmentation (2 chunks).
target_bytes = 2 * 1000 * 1000 # 2 MB
data = bytes((i % 256) for i in range(target_bytes))
self.input_path: str = os.path.join(self.tmpdir, "serial.bin")
with open(self.input_path, "wb") as f:
f.write(data)

def _round_trip(self, extra_args: str | None = None) -> None:
compressed_path = self.input_path + ".zl"
decompressed_path = self.input_path + ".rt"

compressor_info = CompressorInfo(
compressor_str="serial",
compressor_type=CompressorType.PROFILE,
)
execute_compress(
file_to_compress_path=self.input_path,
compressor_info=compressor_info,
compressed_file_path=compressed_path,
extra_args=extra_args,
)
execute_decompress(
compressed_file_path=compressed_path,
decompressed_file_path=decompressed_path,
)
from file_utils import file_contents_match

self.assertTrue(
file_contents_match(self.input_path, decompressed_path),
f"Round-trip failed for serial profile on {self.input_path}",
)

def test_serial_default_chunk_size(self) -> None:
"""Default chunk size (16 MiB) on 2MB data → single-chunk segmentation."""
self._round_trip()

def test_serial_with_chunk_size(self) -> None:
"""--chunk-size 1M on 2MB data forces multi-chunk segmentation."""
self._round_trip(extra_args="--chunk-size 1M")


class StrictModeTest(_CompressDecompressBaseTest):
"""
Test case for strict mode behavior.
Expand Down
12 changes: 9 additions & 3 deletions cli/utils/compress_profiles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,10 +319,16 @@ compressProfiles()
mp[kSerialName] = std::make_shared<CompressProfile>(
kSerialName,
"Serial data (aka raw bytes)",
[](ZL_Compressor* compressor, void*, const ProfileArgs&) {
return ZL_Compressor_buildACEGraphWithDefault(
[](ZL_Compressor* compressor, void*, const ProfileArgs& args) {
ZL_GraphID inner = ZL_Compressor_buildACEGraphWithDefault(
compressor, ZL_GRAPH_LZ);
});
size_t chunkSize = args.chunkSize().value_or(
ZL_DEFAULT_SEGMENTER_CHUNK_BYTE_SIZE);
return ZL_Compressor_buildSerialSegmenter(
compressor, chunkSize, inner);
},
nullptr,
true);

std::string kPytorchName = "pytorch";
mp[kPytorchName] = std::make_shared<CompressProfile>(
Expand Down
1 change: 1 addition & 0 deletions cpp/include/openzl/cpp/Codecs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "openzl/cpp/codecs/RangePack.hpp" // IWYU pragma: export
#include "openzl/cpp/codecs/SDDL.hpp" // IWYU pragma: export
#include "openzl/cpp/codecs/SDDL2.hpp" // IWYU pragma: export
#include "openzl/cpp/codecs/SegmentSerial.hpp" // IWYU pragma: export
#include "openzl/cpp/codecs/Sentinel.hpp" // IWYU pragma: export
#include "openzl/cpp/codecs/Split.hpp" // IWYU pragma: export
#include "openzl/cpp/codecs/SplitByStruct.hpp" // IWYU pragma: export
Expand Down
74 changes: 74 additions & 0 deletions cpp/include/openzl/cpp/codecs/SegmentSerial.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.

#pragma once

#include "openzl/codecs/zl_segmenters.h"
#include "openzl/cpp/Compressor.hpp"
#include "openzl/cpp/codecs/Graph.hpp"
#include "openzl/cpp/codecs/Metadata.hpp"

namespace openzl {
namespace graphs {

class SegmentSerial : public Graph {
public:
static constexpr GraphID graph = ZL_SEGMENT_SERIAL;

static constexpr GraphMetadata<1> metadata = {
.inputs = { InputMetadata{ .typeMask = TypeMask::Serial } },
.description =
"Auto-segmenting graph for serial inputs. Chunks the input into "
"independently compressed segments and forwards each chunk to a "
"successor graph (defaults to ZL_GRAPH_COMPRESS_GENERIC). "
"Defaults to 16 MiB chunking when no chunk-size hint is provided; "
"treats a 0 hint as the default chunk size.",
};

explicit SegmentSerial(GraphID successor, size_t chunkByteSize = 0)
: successor_(successor), chunkByteSize_(chunkByteSize)
{
}

~SegmentSerial() override = default;

SegmentSerial(const SegmentSerial&) = default;
SegmentSerial& operator=(const SegmentSerial&) = default;
SegmentSerial(SegmentSerial&&) = default;
SegmentSerial& operator=(SegmentSerial&&) = default;

GraphID baseGraph() const override
{
return graph;
}

/// Delegates to ZL_Compressor_buildSerialSegmenter2 so the C builder
/// is the single source of truth for chunkByteSize validation, default
/// substitution, and the resulting parameter_invalid error code.
GraphID parameterize(Compressor& compressor) const override
{
return compressor.unwrap(ZL_Compressor_buildSerialSegmenter2(
compressor.get(), chunkByteSize_, successor_));
}

/// Used by setMultiInputDestination in FunctionGraph contexts. Mirrors
/// the C builder's wire-level parameterization but without validation —
/// callers in that path must pass a chunkByteSize that fits in int.
poly::optional<GraphParameters> parameters() const override
{
LocalParams lp;
if (chunkByteSize_ != 0) {
lp.addIntParam(
ZL_SEGMENT_SERIAL_CHUNK_BYTE_SIZE_PARAM,
static_cast<int>(chunkByteSize_));
}
return GraphParameters{ .customGraphs = { { successor_ } },
.localParams = std::move(lp) };
}

private:
GraphID successor_;
size_t chunkByteSize_;
};

} // namespace graphs
} // namespace openzl
33 changes: 33 additions & 0 deletions cpp/tests/TestCodecs.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.

#include <climits>

#include <gtest/gtest.h>

#include "cpp/tests/TestUtils.hpp"
Expand Down Expand Up @@ -45,4 +47,35 @@ TEST_F(TestCodecs, lz4_hc)
compressor_.selectStartingGraph(graph);
auto compressed = testRoundTrip(compressor_, Input::refSerial(data));
}

TEST_F(TestCodecs, segmentSerial_defaultChunkSize)
{
/* chunkByteSize = 0 sentinel: use the segmenter's built-in default. */
std::string data(4096, 'a');
auto graph = graphs::SegmentSerial(ZL_GRAPH_COMPRESS_GENERIC)
.parameterize(compressor_);
compressor_.selectStartingGraph(graph);
auto compressed = testRoundTrip(compressor_, Input::refSerial(data));
}

TEST_F(TestCodecs, segmentSerial_explicitChunkSize)
{
/* Explicit chunk size at the minimum threshold must round-trip. */
std::string data(ZL_MIN_CHUNK_SIZE * 2, 'a');
auto graph =
graphs::SegmentSerial(ZL_GRAPH_COMPRESS_GENERIC, ZL_MIN_CHUNK_SIZE)
.parameterize(compressor_);
compressor_.selectStartingGraph(graph);
auto compressed = testRoundTrip(compressor_, Input::refSerial(data));
}

TEST_F(TestCodecs, segmentSerial_chunkSizeOverflowThrows)
{
/* The C builder rejects chunk sizes that would not fit in int; the C++
* wrapper surfaces that rejection as a typed Exception at parameterize
* time (construction itself does not validate). */
graphs::SegmentSerial wrapper(
ZL_GRAPH_COMPRESS_GENERIC, static_cast<size_t>(INT_MAX) + 1);
EXPECT_THROW(wrapper.parameterize(compressor_), Exception);
}
} // namespace openzl
55 changes: 53 additions & 2 deletions include/openzl/codecs/zl_segmenters.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
extern "C" {
#endif

/** Pass as successorGraph to use the built-in default pipeline
* (interpret serial as numeric + compress). */
/** Pass as successorGraph to let the segmenter substitute its built-in
* default successor. Each segmenter family resolves this to its own
* default — see the corresponding builder's documentation
* (e.g. ZL_Compressor_buildNumFromSerialSegmenter resolves to
* ZL_GRAPH_INTERPRET_NUMxx_COMPRESS; ZL_Compressor_buildSerialSegmenter
* resolves to ZL_GRAPH_COMPRESS_GENERIC). */
#define ZL_SEGMENTER_DEFAULT_SUCCESSOR ZL_GRAPH_ILLEGAL

// Numeric segmenter (numeric input)
Expand All @@ -41,6 +45,22 @@ extern "C" {
#define ZL_SEGMENT_NUM64_FROM_SERIAL \
ZL_MAKE_GRAPH_ID(ZL_StandardGraphID_segment_num64_from_serial)

// Serial-input segmenter (no element-width interpretation)
// Input : 1 stream of serial data
// Result : chunks the serial input using default chunk size,
// and forwards each chunk to the default successor
// ZL_GRAPH_COMPRESS_GENERIC.
// Both chunk size and successor can be overridden via
// ZL_Compressor_buildSerialSegmenter().
#define ZL_SEGMENT_SERIAL ZL_MAKE_GRAPH_ID(ZL_StandardGraphID_segment_serial)

/**
* Local int parameter ID for the serial segmenter's chunk byte size.
* When omitted, ZL_SEGMENT_SERIAL falls back to
* ZL_DEFAULT_SEGMENTER_CHUNK_BYTE_SIZE.
*/
#define ZL_SEGMENT_SERIAL_CHUNK_BYTE_SIZE_PARAM 7700

/**
* @brief Register a serial-numeric segmenter.
*
Expand Down Expand Up @@ -78,6 +98,37 @@ ZL_Compressor_buildNumFromSerialSegmenter2(
size_t chunkByteSize,
ZL_GraphID successorGraph);

/**
* @brief Register a serial segmenter.
*
* Creates a parameterized segmenter that accepts serial input, chunks it
* by @p chunkByteSize, and forwards each chunk to @p successorGraph.
*
* @param compressor The compressor to register with
* @param chunkByteSize Maximum chunk size in bytes. Pass 0 to use the
* built-in default (ZL_DEFAULT_SEGMENTER_CHUNK_BYTE_SIZE). Otherwise
* must be in [ZL_MIN_CHUNK_SIZE, INT_MAX]; smaller positive values
* are rejected because ZL_compressBound() assumes chunks of at least
* ZL_MIN_CHUNK_SIZE bytes.
* @param successorGraph The graph to process each chunk, or
* ZL_SEGMENTER_DEFAULT_SUCCESSOR to use ZL_GRAPH_COMPRESS_GENERIC
* @return The registered segmenter graph ID, or ZL_GRAPH_ILLEGAL on error
*/
ZL_GraphID ZL_Compressor_buildSerialSegmenter(
ZL_Compressor* compressor,
size_t chunkByteSize,
ZL_GraphID successorGraph);

/**
* Same as ZL_Compressor_buildSerialSegmenter(), but returns a
* ZL_RESULT_OF(ZL_GraphID) for richer error reporting.
*/
ZL_RESULT_OF(ZL_GraphID)
ZL_Compressor_buildSerialSegmenter2(
ZL_Compressor* compressor,
size_t chunkByteSize,
ZL_GraphID successorGraph);

#if defined(__cplusplus)
}
#endif
Expand Down
2 changes: 2 additions & 0 deletions include/openzl/zl_graphs.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ typedef enum {

ZL_StandardGraphID_lz,

ZL_StandardGraphID_segment_serial,

ZL_StandardGraphID_public_end // last id, used to detect end of public
// range
} ZL_StandardGraphID;
Expand Down
2 changes: 2 additions & 0 deletions src/openzl/compress/graph_registry.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "openzl/compress/implicit_conversion.h" // ICONV_isCompatible for type checking
#include "openzl/compress/private_nodes.h" // ZL_PrivateStandardGraphID_end, private node ID definitions
#include "openzl/compress/segmenters/segmenter_numeric.h" // SEGM_numeric_desc
#include "openzl/compress/segmenters/segmenter_serial.h" // SEGM_serial_desc
#include "openzl/compress/selector.h" // SelectorCtx, ZL_SelectorFn, SelCtx_* functions
#include "openzl/compress/selectors/ml/ml_selector_graph.h" // ZL_MLSel_dynGraph
#include "openzl/compress/selectors/selector_compress.h" // SI_selector_compress, SI_selector_compress_* functions
Expand Down Expand Up @@ -149,6 +150,7 @@ const InternalGraphDesc GR_standardGraphs[ZL_PrivateStandardGraphID_end] = {
REGISTER_SEGMENTER(ZL_StandardGraphID_segment_num16_from_serial, SEGM_NUM_FROM_SERIAL_DESC(2, 16, ZL_PrivateStandardGraphID_interpret_num16_compress)),
REGISTER_SEGMENTER(ZL_StandardGraphID_segment_num32_from_serial, SEGM_NUM_FROM_SERIAL_DESC(4, 32, ZL_PrivateStandardGraphID_interpret_num32_compress)),
REGISTER_SEGMENTER(ZL_StandardGraphID_segment_num64_from_serial, SEGM_NUM_FROM_SERIAL_DESC(8, 64, ZL_PrivateStandardGraphID_interpret_num64_compress)),
REGISTER_SEGMENTER(ZL_StandardGraphID_segment_serial, SEGM_SERIAL_DESC),
REGISTER_SELECTOR(ZL_StandardGraphID_select_numeric, "!zl.select_numeric", SI_selector_numeric, ZL_Type_numeric),
REGISTER_MIGRAPH(ZL_StandardGraphID_clustering, MIGRAPH_CLUSTERING),
REGISTER_DYNAMIC_GRAPH(ZL_StandardGraphID_simple_data_description_language, "!zl.sddl", ZL_Type_serial, ZL_SDDL_dynGraph),
Expand Down
Loading
Loading