diff --git a/lib/jxl/ans_test.cc b/lib/jxl/ans_test.cc index 83a2e732f8d..94a99ab05f9 100644 --- a/lib/jxl/ans_test.cc +++ b/lib/jxl/ans_test.cc @@ -3,9 +3,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -#include -#include +#include +#include +#include #include #include "lib/jxl/ans_params.h" @@ -15,6 +16,7 @@ #include "lib/jxl/dec_bit_reader.h" #include "lib/jxl/enc_ans.h" #include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/test_utils.h" #include "lib/jxl/testing.h" namespace jxl { @@ -22,10 +24,11 @@ namespace { void RoundtripTestcase(int n_histograms, int alphabet_size, const std::vector& input_values) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); constexpr uint16_t kMagic1 = 0x9e33; constexpr uint16_t kMagic2 = 0x8b04; - BitWriter writer; + BitWriter writer{memory_manager}; // Space for magic bytes. BitWriter::Allotment allotment_magic1(&writer, 16); writer.Write(16, kMagic1); @@ -36,8 +39,9 @@ void RoundtripTestcase(int n_histograms, int alphabet_size, std::vector> input_values_vec; input_values_vec.push_back(input_values); - BuildAndEncodeHistograms(HistogramParams(), n_histograms, input_values_vec, - &codes, &context_map, &writer, 0, nullptr); + BuildAndEncodeHistograms(memory_manager, HistogramParams(), n_histograms, + input_values_vec, &codes, &context_map, &writer, 0, + nullptr); WriteTokens(input_values_vec[0], codes, context_map, 0, &writer, 0, nullptr); // Magic bytes + padding @@ -54,10 +58,11 @@ void RoundtripTestcase(int n_histograms, int alphabet_size, std::vector dec_context_map; ANSCode decoded_codes; - ASSERT_TRUE( - DecodeHistograms(&br, n_histograms, &decoded_codes, &dec_context_map)); + ASSERT_TRUE(DecodeHistograms(memory_manager, &br, n_histograms, + &decoded_codes, &dec_context_map)); ASSERT_EQ(dec_context_map, context_map); - ANSSymbolReader reader(&decoded_codes, &br); + JXL_ASSIGN_OR_DIE(ANSSymbolReader reader, + ANSSymbolReader::Create(&decoded_codes, &br)); for (const Token& symbol : input_values) { uint32_t read_symbol = @@ -156,6 +161,7 @@ TEST(ANSTest, RandomUnbalancedStreamRoundtripBig) { } TEST(ANSTest, UintConfigRoundtrip) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); for (size_t log_alpha_size = 5; log_alpha_size <= 8; log_alpha_size++) { std::vector uint_config; std::vector uint_config_dec; @@ -168,7 +174,7 @@ TEST(ANSTest, UintConfigRoundtrip) { } uint_config.emplace_back(log_alpha_size, 0, 0); uint_config_dec.resize(uint_config.size()); - BitWriter writer; + BitWriter writer{memory_manager}; BitWriter::Allotment allotment(&writer, 10 * uint_config.size()); EncodeUintConfigs(uint_config, &writer, log_alpha_size); allotment.ReclaimAndCharge(&writer, 0, nullptr); @@ -185,6 +191,7 @@ TEST(ANSTest, UintConfigRoundtrip) { } void TestCheckpointing(bool ans, bool lz77) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); std::vector> input_values(1); for (size_t i = 0; i < 1024; i++) { input_values[0].emplace_back(0, i % 4); @@ -206,11 +213,11 @@ void TestCheckpointing(bool ans, bool lz77) { : HistogramParams::LZ77Method::kNone; params.force_huffman = !ans; - BitWriter writer; + BitWriter writer{memory_manager}; { auto input_values_copy = input_values; - BuildAndEncodeHistograms(params, 1, input_values_copy, &codes, &context_map, - &writer, 0, nullptr); + BuildAndEncodeHistograms(memory_manager, params, 1, input_values_copy, + &codes, &context_map, &writer, 0, nullptr); WriteTokens(input_values_copy[0], codes, context_map, 0, &writer, 0, nullptr); writer.ZeroPadToByte(); @@ -225,9 +232,11 @@ void TestCheckpointing(bool ans, bool lz77) { std::vector dec_context_map; ANSCode decoded_codes; - ASSERT_TRUE(DecodeHistograms(&br, 1, &decoded_codes, &dec_context_map)); + ASSERT_TRUE(DecodeHistograms(memory_manager, &br, 1, &decoded_codes, + &dec_context_map)); ASSERT_EQ(dec_context_map, context_map); - ANSSymbolReader reader(&decoded_codes, &br); + JXL_ASSIGN_OR_DIE(ANSSymbolReader reader, + ANSSymbolReader::Create(&decoded_codes, &br)); ANSSymbolReader::Checkpoint checkpoint; size_t br_pos = 0; diff --git a/lib/jxl/bit_reader_test.cc b/lib/jxl/bit_reader_test.cc index b2d5773d157..74466941617 100644 --- a/lib/jxl/bit_reader_test.cc +++ b/lib/jxl/bit_reader_test.cc @@ -3,10 +3,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -#include -#include - #include +#include +#include #include #include @@ -52,12 +51,13 @@ struct Symbol { // Reading from output gives the same values. TEST(BitReaderTest, TestRoundTrip) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); test::ThreadPoolForTests pool(8); EXPECT_TRUE(RunOnPool( pool.get(), 0, 1000, ThreadPool::NoInit, - [](const uint32_t task, size_t /* thread */) { + [&memory_manager](const uint32_t task, size_t /* thread */) { constexpr size_t kMaxBits = 8000; - BitWriter writer; + BitWriter writer{memory_manager}; BitWriter::Allotment allotment(&writer, kMaxBits); std::vector symbols; @@ -86,14 +86,15 @@ TEST(BitReaderTest, TestRoundTrip) { // SkipBits is the same as reading that many bits. TEST(BitReaderTest, TestSkip) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); test::ThreadPoolForTests pool(8); EXPECT_TRUE(RunOnPool( pool.get(), 0, 96, ThreadPool::NoInit, - [](const uint32_t task, size_t /* thread */) { + [&memory_manager](const uint32_t task, size_t /* thread */) { constexpr size_t kSize = 100; for (size_t skip = 0; skip < 128; ++skip) { - BitWriter writer; + BitWriter writer{memory_manager}; BitWriter::Allotment allotment(&writer, kSize * kBitsPerByte); // Start with "task" 1-bits. for (size_t i = 0; i < task; ++i) { @@ -142,11 +143,12 @@ TEST(BitReaderTest, TestSkip) { // Verifies byte order and different groupings of bits. TEST(BitReaderTest, TestOrder) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); constexpr size_t kMaxBits = 16; // u(1) - bits written into LSBs of first byte { - BitWriter writer; + BitWriter writer{memory_manager}; BitWriter::Allotment allotment(&writer, kMaxBits); for (size_t i = 0; i < 5; ++i) { writer.Write(1, 1); @@ -168,7 +170,7 @@ TEST(BitReaderTest, TestOrder) { // u(8) - get bytes in the same order { - BitWriter writer; + BitWriter writer{memory_manager}; BitWriter::Allotment allotment(&writer, kMaxBits); writer.Write(8, 0xF8); writer.Write(8, 0x3F); @@ -183,7 +185,7 @@ TEST(BitReaderTest, TestOrder) { // u(16) - little-endian bytes { - BitWriter writer; + BitWriter writer{memory_manager}; BitWriter::Allotment allotment(&writer, kMaxBits); writer.Write(16, 0xF83F); @@ -197,7 +199,7 @@ TEST(BitReaderTest, TestOrder) { // Non-byte-aligned, mixed sizes { - BitWriter writer; + BitWriter writer{memory_manager}; BitWriter::Allotment allotment(&writer, kMaxBits); writer.Write(1, 1); writer.Write(3, 6); diff --git a/lib/jxl/coeff_order.cc b/lib/jxl/coeff_order.cc index 296a7cb2f0d..d98caec3192 100644 --- a/lib/jxl/coeff_order.cc +++ b/lib/jxl/coeff_order.cc @@ -5,20 +5,18 @@ #include "lib/jxl/coeff_order.h" -#include +#include #include +#include #include -#include "lib/jxl/ans_params.h" -#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" #include "lib/jxl/coeff_order_fwd.h" #include "lib/jxl/dec_ans.h" #include "lib/jxl/dec_bit_reader.h" -#include "lib/jxl/entropy_coder.h" #include "lib/jxl/lehmer_code.h" #include "lib/jxl/modular/encoding/encoding.h" -#include "lib/jxl/modular/modular_image.h" namespace jxl { @@ -57,13 +55,14 @@ Status ReadPermutation(size_t skip, size_t size, coeff_order_t* order, } // namespace -Status DecodePermutation(size_t skip, size_t size, coeff_order_t* order, - BitReader* br) { +Status DecodePermutation(JxlMemoryManager* memory_manager, size_t skip, + size_t size, coeff_order_t* order, BitReader* br) { std::vector context_map; ANSCode code; - JXL_RETURN_IF_ERROR( - DecodeHistograms(br, kPermutationContexts, &code, &context_map)); - ANSSymbolReader reader(&code, br); + JXL_RETURN_IF_ERROR(DecodeHistograms(memory_manager, br, kPermutationContexts, + &code, &context_map)); + JXL_ASSIGN_OR_RETURN(ANSSymbolReader reader, + ANSSymbolReader::Create(&code, br)); JXL_RETURN_IF_ERROR( ReadPermutation(skip, size, order, br, &reader, context_map)); if (!reader.CheckANSFinalState()) { @@ -92,18 +91,19 @@ Status DecodeCoeffOrder(AcStrategy acs, coeff_order_t* order, BitReader* br, } // namespace -Status DecodeCoeffOrders(uint16_t used_orders, uint32_t used_acs, - coeff_order_t* order, BitReader* br) { +Status DecodeCoeffOrders(JxlMemoryManager* memory_manager, uint16_t used_orders, + uint32_t used_acs, coeff_order_t* order, + BitReader* br) { uint16_t computed = 0; std::vector context_map; ANSCode code; - std::unique_ptr reader; + ANSSymbolReader reader; std::vector natural_order; // Bitstream does not have histograms if no coefficient order is used. if (used_orders != 0) { - JXL_RETURN_IF_ERROR( - DecodeHistograms(br, kPermutationContexts, &code, &context_map)); - reader = make_unique(&code, br); + JXL_RETURN_IF_ERROR(DecodeHistograms( + memory_manager, br, kPermutationContexts, &code, &context_map)); + JXL_ASSIGN_OR_RETURN(reader, ANSSymbolReader::Create(&code, br)); } uint32_t acs_mask = 0; for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) { @@ -136,12 +136,12 @@ Status DecodeCoeffOrders(uint16_t used_orders, uint32_t used_acs, } else { for (size_t c = 0; c < 3; c++) { coeff_order_t* dest = used ? &order[CoeffOrderOffset(ord, c)] : nullptr; - JXL_RETURN_IF_ERROR(DecodeCoeffOrder(acs, dest, br, reader.get(), + JXL_RETURN_IF_ERROR(DecodeCoeffOrder(acs, dest, br, &reader, natural_order, context_map)); } } } - if (used_orders && !reader->CheckANSFinalState()) { + if (used_orders && !reader.CheckANSFinalState()) { return JXL_FAILURE("Invalid ANS stream"); } return true; diff --git a/lib/jxl/coeff_order.h b/lib/jxl/coeff_order.h index 395e2966420..1c1cec3ff57 100644 --- a/lib/jxl/coeff_order.h +++ b/lib/jxl/coeff_order.h @@ -6,6 +6,8 @@ #ifndef LIB_JXL_COEFF_ORDER_H_ #define LIB_JXL_COEFF_ORDER_H_ +#include + #include #include #include @@ -53,12 +55,13 @@ constexpr JXL_MAYBE_UNUSED uint32_t kPermutationContexts = 8; uint32_t CoeffOrderContext(uint32_t val); -Status DecodeCoeffOrders(uint16_t used_orders, uint32_t used_acs, - coeff_order_t* order, BitReader* br); - -Status DecodePermutation(size_t skip, size_t size, coeff_order_t* order, +Status DecodeCoeffOrders(JxlMemoryManager* memory_manager, uint16_t used_orders, + uint32_t used_acs, coeff_order_t* order, BitReader* br); +Status DecodePermutation(JxlMemoryManager* memory_manager, size_t skip, + size_t size, coeff_order_t* order, BitReader* br); + } // namespace jxl #endif // LIB_JXL_COEFF_ORDER_H_ diff --git a/lib/jxl/coeff_order_test.cc b/lib/jxl/coeff_order_test.cc index a88dcfa2746..ed0131f92b8 100644 --- a/lib/jxl/coeff_order_test.cc +++ b/lib/jxl/coeff_order_test.cc @@ -5,6 +5,8 @@ #include "lib/jxl/coeff_order.h" +#include + #include #include // iota #include @@ -16,6 +18,7 @@ #include "lib/jxl/coeff_order_fwd.h" #include "lib/jxl/dec_bit_reader.h" #include "lib/jxl/enc_coeff_order.h" +#include "lib/jxl/test_utils.h" #include "lib/jxl/testing.h" namespace jxl { @@ -23,14 +26,15 @@ namespace { void RoundtripPermutation(coeff_order_t* perm, coeff_order_t* out, size_t len, size_t* size) { - BitWriter writer; + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); + BitWriter writer{memory_manager}; EncodePermutation(perm, 0, len, &writer, 0, nullptr); writer.ZeroPadToByte(); Status status = true; { BitReader reader(writer.GetSpan()); BitReaderScopedCloser closer(&reader, &status); - ASSERT_TRUE(DecodePermutation(0, len, out, &reader)); + ASSERT_TRUE(DecodePermutation(memory_manager, 0, len, out, &reader)); } ASSERT_TRUE(status); *size = writer.GetSpan().size(); diff --git a/lib/jxl/dec_ans.cc b/lib/jxl/dec_ans.cc index 8b7b54ce910..088d0a26bfd 100644 --- a/lib/jxl/dec_ans.cc +++ b/lib/jxl/dec_ans.cc @@ -5,8 +5,9 @@ #include "lib/jxl/dec_ans.h" -#include +#include +#include #include #include "lib/jxl/ans_common.h" @@ -182,9 +183,11 @@ Status ReadHistogram(int precision_bits, std::vector* counts, } // namespace -Status DecodeANSCodes(const size_t num_histograms, +Status DecodeANSCodes(JxlMemoryManager* memory_manager, + const size_t num_histograms, const size_t max_alphabet_size, BitReader* in, ANSCode* result) { + result->memory_manager = memory_manager; result->degenerate_symbols.resize(num_histograms, -1); if (result->use_prefix_code) { JXL_ASSERT(max_alphabet_size <= 1 << PREFIX_MAX_BITS); @@ -220,9 +223,9 @@ Status DecodeANSCodes(const size_t num_histograms, } } else { JXL_ASSERT(max_alphabet_size <= ANS_MAX_ALPHABET_SIZE); - result->alias_tables = - AllocateArray(num_histograms * (1 << result->log_alpha_size) * - sizeof(AliasTable::Entry)); + size_t alloc_size = num_histograms * (1 << result->log_alpha_size) * + sizeof(AliasTable::Entry); + result->alias_tables = AllocateArray(alloc_size); AliasTable::Entry* alias_tables = reinterpret_cast(result->alias_tables.get()); for (size_t c = 0; c < num_histograms; ++c) { @@ -325,7 +328,8 @@ void ANSCode::UpdateMaxNumBits(size_t ctx, size_t symbol) { max_num_bits = std::max(max_num_bits, total_bits); } -Status DecodeHistograms(BitReader* br, size_t num_contexts, ANSCode* code, +Status DecodeHistograms(JxlMemoryManager* memory_manager, BitReader* br, + size_t num_contexts, ANSCode* code, std::vector* context_map, bool disallow_lz77) { JXL_RETURN_IF_ERROR(Bundle::Read(br, &code->lz77)); if (code->lz77.enabled) { @@ -339,7 +343,8 @@ Status DecodeHistograms(BitReader* br, size_t num_contexts, ANSCode* code, size_t num_histograms = 1; context_map->resize(num_contexts); if (num_contexts > 1) { - JXL_RETURN_IF_ERROR(DecodeContextMap(context_map, &num_histograms, br)); + JXL_RETURN_IF_ERROR( + DecodeContextMap(memory_manager, context_map, &num_histograms, br)); } JXL_DEBUG_V( 4, "Decoded context map of size %" PRIuS " and %" PRIuS " histograms", @@ -355,9 +360,44 @@ Status DecodeHistograms(BitReader* br, size_t num_contexts, ANSCode* code, JXL_RETURN_IF_ERROR( DecodeUintConfigs(code->log_alpha_size, &code->uint_config, br)); const size_t max_alphabet_size = 1 << code->log_alpha_size; - JXL_RETURN_IF_ERROR( - DecodeANSCodes(num_histograms, max_alphabet_size, br, code)); + JXL_RETURN_IF_ERROR(DecodeANSCodes(memory_manager, num_histograms, + max_alphabet_size, br, code)); return true; } +StatusOr ANSSymbolReader::Create(const ANSCode* code, + BitReader* JXL_RESTRICT br, + size_t distance_multiplier) { + return ANSSymbolReader(code, br, distance_multiplier); +} + +ANSSymbolReader::ANSSymbolReader(const ANSCode* code, + BitReader* JXL_RESTRICT br, + size_t distance_multiplier) + : alias_tables_( + reinterpret_cast(code->alias_tables.get())), + huffman_data_(code->huffman_data.data()), + use_prefix_code_(code->use_prefix_code), + configs(code->uint_config.data()) { + if (!use_prefix_code_) { + state_ = static_cast(br->ReadFixedBits<32>()); + log_alpha_size_ = code->log_alpha_size; + log_entry_size_ = ANS_LOG_TAB_SIZE - code->log_alpha_size; + entry_size_minus_1_ = (1 << log_entry_size_) - 1; + } else { + state_ = (ANS_SIGNATURE << 16u); + } + if (!code->lz77.enabled) return; + lz77_window_storage_ = AllocateArray(kWindowSize * sizeof(uint32_t)); + lz77_window_ = reinterpret_cast(lz77_window_storage_.get()); + lz77_ctx_ = code->lz77.nonserialized_distance_context; + lz77_length_uint_ = code->lz77.length_uint_config; + lz77_threshold_ = code->lz77.min_symbol; + lz77_min_length_ = code->lz77.min_length; + num_special_distances_ = distance_multiplier == 0 ? 0 : kNumSpecialDistances; + for (size_t i = 0; i < num_special_distances_; i++) { + special_distances_[i] = SpecialDistance(i, distance_multiplier); + } +} + } // namespace jxl diff --git a/lib/jxl/dec_ans.h b/lib/jxl/dec_ans.h index dc56df12165..abdab7364d3 100644 --- a/lib/jxl/dec_ans.h +++ b/lib/jxl/dec_ans.h @@ -9,6 +9,7 @@ // Library to decode the ANS population counts from the bit-stream and build a // decoding table from them. +#include #include #include @@ -152,6 +153,7 @@ struct ANSCode { // Maximum number of bits necessary to represent the result of a // ReadHybridUint call done with this ANSCode. size_t max_num_bits = 0; + JxlMemoryManager* memory_manager; void UpdateMaxNumBits(size_t ctx, size_t symbol); }; @@ -159,36 +161,9 @@ class ANSSymbolReader { public: // Invalid symbol reader, to be overwritten. ANSSymbolReader() = default; - ANSSymbolReader(const ANSCode* code, BitReader* JXL_RESTRICT br, - size_t distance_multiplier = 0) - : alias_tables_( - reinterpret_cast(code->alias_tables.get())), - huffman_data_(code->huffman_data.data()), - use_prefix_code_(code->use_prefix_code), - configs(code->uint_config.data()) { - if (!use_prefix_code_) { - state_ = static_cast(br->ReadFixedBits<32>()); - log_alpha_size_ = code->log_alpha_size; - log_entry_size_ = ANS_LOG_TAB_SIZE - code->log_alpha_size; - entry_size_minus_1_ = (1 << log_entry_size_) - 1; - } else { - state_ = (ANS_SIGNATURE << 16u); - } - if (!code->lz77.enabled) return; - // a std::vector incurs unacceptable decoding speed loss because of - // initialization. - lz77_window_storage_ = AllocateArray(kWindowSize * sizeof(uint32_t)); - lz77_window_ = reinterpret_cast(lz77_window_storage_.get()); - lz77_ctx_ = code->lz77.nonserialized_distance_context; - lz77_length_uint_ = code->lz77.length_uint_config; - lz77_threshold_ = code->lz77.min_symbol; - lz77_min_length_ = code->lz77.min_length; - num_special_distances_ = - distance_multiplier == 0 ? 0 : kNumSpecialDistances; - for (size_t i = 0; i < num_special_distances_; i++) { - special_distances_[i] = SpecialDistance(i, distance_multiplier); - } - } + static StatusOr Create(const ANSCode* code, + BitReader* JXL_RESTRICT br, + size_t distance_multiplier = 0); JXL_INLINE size_t ReadSymbolANSWithoutRefill(const size_t histo_idx, BitReader* JXL_RESTRICT br) { @@ -471,6 +446,9 @@ class ANSSymbolReader { } private: + ANSSymbolReader(const ANSCode* code, BitReader* JXL_RESTRICT br, + size_t distance_multiplier); + const AliasTable::Entry* JXL_RESTRICT alias_tables_; // not owned const HuffmanDecodingData* huffman_data_; bool use_prefix_code_; @@ -482,6 +460,8 @@ class ANSSymbolReader { // LZ77 structures and constants. static constexpr size_t kWindowMask = kWindowSize - 1; + // a std::vector incurs unacceptable decoding speed loss because of + // initialization. CacheAlignedUniquePtr lz77_window_storage_; uint32_t* lz77_window_ = nullptr; uint32_t num_decoded_ = 0; @@ -495,7 +475,8 @@ class ANSSymbolReader { uint32_t num_special_distances_{}; }; -Status DecodeHistograms(BitReader* br, size_t num_contexts, ANSCode* code, +Status DecodeHistograms(JxlMemoryManager* memory_manager, BitReader* br, + size_t num_contexts, ANSCode* code, std::vector* context_map, bool disallow_lz77 = false); diff --git a/lib/jxl/dec_cache.cc b/lib/jxl/dec_cache.cc index c701b37fb42..2e2bef91317 100644 --- a/lib/jxl/dec_cache.cc +++ b/lib/jxl/dec_cache.cc @@ -252,9 +252,9 @@ Status PassesDecoderState::PreparePipeline(const FrameHeader& frame_header, (void)linear; if (main_output.callback.IsPresent() || main_output.buffer) { - builder.AddStage(GetWriteToOutputStage(main_output, width, height, - has_alpha, unpremul_alpha, alpha_c, - undo_orientation, extra_output)); + builder.AddStage(GetWriteToOutputStage( + main_output, width, height, has_alpha, unpremul_alpha, alpha_c, + undo_orientation, extra_output, memory_manager)); } else { builder.AddStage( GetWriteToImageBundleStage(decoded, output_encoding_info)); diff --git a/lib/jxl/dec_context_map.cc b/lib/jxl/dec_context_map.cc index baff87fa493..c97eb8ec65c 100644 --- a/lib/jxl/dec_context_map.cc +++ b/lib/jxl/dec_context_map.cc @@ -5,14 +5,14 @@ #include "lib/jxl/dec_context_map.h" +#include + #include #include #include -#include "lib/jxl/ans_params.h" #include "lib/jxl/base/status.h" #include "lib/jxl/dec_ans.h" -#include "lib/jxl/entropy_coder.h" #include "lib/jxl/inverse_mtf-inl.h" namespace jxl { @@ -40,7 +40,8 @@ Status VerifyContextMap(const std::vector& context_map, } // namespace -Status DecodeContextMap(std::vector* context_map, size_t* num_htrees, +Status DecodeContextMap(JxlMemoryManager* memory_manager, + std::vector* context_map, size_t* num_htrees, BitReader* input) { bool is_simple = static_cast(input->ReadFixedBits<1>()); if (is_simple) { @@ -61,9 +62,10 @@ Status DecodeContextMap(std::vector* context_map, size_t* num_htrees, // in malicious bitstreams by making every context map require its own // context map. JXL_RETURN_IF_ERROR( - DecodeHistograms(input, 1, &code, &sink_ctx_map, + DecodeHistograms(memory_manager, input, 1, &code, &sink_ctx_map, /*disallow_lz77=*/context_map->size() <= 2)); - ANSSymbolReader reader(&code, input); + JXL_ASSIGN_OR_RETURN(ANSSymbolReader reader, + ANSSymbolReader::Create(&code, input)); size_t i = 0; uint32_t maxsym = 0; while (i < context_map->size()) { diff --git a/lib/jxl/dec_context_map.h b/lib/jxl/dec_context_map.h index 95b8a0ca92f..8edad38a3a6 100644 --- a/lib/jxl/dec_context_map.h +++ b/lib/jxl/dec_context_map.h @@ -6,9 +6,10 @@ #ifndef LIB_JXL_DEC_CONTEXT_MAP_H_ #define LIB_JXL_DEC_CONTEXT_MAP_H_ -#include -#include +#include +#include +#include #include #include "lib/jxl/dec_bit_reader.h" @@ -22,7 +23,8 @@ constexpr size_t kMaxClusters = 256; // context_map->size() must be the number of possible context ids. // Sets *num_htrees to the number of different histogram ids in // *context_map. -Status DecodeContextMap(std::vector* context_map, size_t* num_htrees, +Status DecodeContextMap(JxlMemoryManager* memory_manager, + std::vector* context_map, size_t* num_htrees, BitReader* input); } // namespace jxl diff --git a/lib/jxl/dec_frame.cc b/lib/jxl/dec_frame.cc index 2ded8290926..c3511d3b8cf 100644 --- a/lib/jxl/dec_frame.cc +++ b/lib/jxl/dec_frame.cc @@ -64,8 +64,8 @@ Status DecodeGlobalDCInfo(BitReader* reader, bool is_jpeg, PassesDecoderState* state, ThreadPool* pool) { JXL_RETURN_IF_ERROR(state->shared_storage.quantizer.Decode(reader)); - JXL_RETURN_IF_ERROR( - DecodeBlockCtxMap(reader, &state->shared_storage.block_ctx_map)); + JXL_RETURN_IF_ERROR(DecodeBlockCtxMap(state->memory_manager(), reader, + &state->shared_storage.block_ctx_map)); JXL_RETURN_IF_ERROR(state->shared_storage.cmap.DecodeDC(reader)); @@ -136,6 +136,7 @@ Status FrameDecoder::InitFrame(BitReader* JXL_RESTRICT br, ImageBundle* decoded, bool is_preview) { decoded_ = decoded; JXL_ASSERT(is_finalized_); + JxlMemoryManager* memory_manager = decoded_->memory_manager(); // Reset the dequantization matrices to their default values. dec_state_->shared_storage.matrices = DequantMatrices(); @@ -171,7 +172,8 @@ Status FrameDecoder::InitFrame(BitReader* JXL_RESTRICT br, ImageBundle* decoded, NumTocEntries(num_groups, frame_dim_.num_dc_groups, num_passes); std::vector sizes; std::vector permutation; - JXL_RETURN_IF_ERROR(ReadToc(toc_entries, br, &sizes, &permutation)); + JXL_RETURN_IF_ERROR( + ReadToc(memory_manager, toc_entries, br, &sizes, &permutation)); bool have_permutation = !permutation.empty(); toc_.resize(toc_entries); section_sizes_sum_ = 0; @@ -265,10 +267,11 @@ Status FrameDecoder::InitFrameOutput() { Status FrameDecoder::ProcessDCGlobal(BitReader* br) { PassesSharedState& shared = dec_state_->shared_storage; + JxlMemoryManager* memory_manager = shared.memory_manager; if (frame_header_.flags & FrameHeader::kPatches) { bool uses_extra_channels = false; JXL_RETURN_IF_ERROR(shared.image_features.patches.Decode( - br, frame_dim_.xsize_padded, frame_dim_.ysize_padded, + memory_manager, br, frame_dim_.xsize_padded, frame_dim_.ysize_padded, &uses_extra_channels)); if (uses_extra_channels && frame_header_.upsampling != 1) { for (size_t ecups : frame_header_.extra_channel_upsampling) { @@ -285,7 +288,7 @@ Status FrameDecoder::ProcessDCGlobal(BitReader* br) { shared.image_features.splines.Clear(); if (frame_header_.flags & FrameHeader::kSplines) { JXL_RETURN_IF_ERROR(shared.image_features.splines.Decode( - br, frame_dim_.xsize * frame_dim_.ysize)); + memory_manager, br, frame_dim_.xsize * frame_dim_.ysize)); } if (frame_header_.flags & FrameHeader::kNoise) { JXL_RETURN_IF_ERROR(DecodeNoise(br, &shared.image_features.noise_params)); @@ -392,15 +395,16 @@ Status FrameDecoder::ProcessACGlobal(BitReader* br) { for (size_t i = 0; i < frame_header_.passes.num_passes; i++) { uint16_t used_orders = U32Coder::Read(kOrderEnc, br); JXL_RETURN_IF_ERROR(DecodeCoeffOrders( - used_orders, dec_state_->used_acs, + memory_manager, used_orders, dec_state_->used_acs, &dec_state_->shared_storage .coeff_orders[i * dec_state_->shared_storage.coeff_order_size], br)); size_t num_contexts = dec_state_->shared->num_histograms * dec_state_->shared_storage.block_ctx_map.NumACContexts(); - JXL_RETURN_IF_ERROR(DecodeHistograms( - br, num_contexts, &dec_state_->code[i], &dec_state_->context_map[i])); + JXL_RETURN_IF_ERROR(DecodeHistograms(memory_manager, br, num_contexts, + &dec_state_->code[i], + &dec_state_->context_map[i])); // Add extra values to enable the cheat in hot loop of DecodeACVarBlock. dec_state_->context_map[i].resize( num_contexts + kZeroDensityContextLimit - kZeroDensityContextCount); diff --git a/lib/jxl/dec_group.cc b/lib/jxl/dec_group.cc index c091e6b1dd3..e189b566491 100644 --- a/lib/jxl/dec_group.cc +++ b/lib/jxl/dec_group.cc @@ -606,8 +606,10 @@ struct GetBlockFromBitstream : public GetBlock { } ctx_offset[pass] = cur_histogram * block_ctx_map->NumACContexts(); - decoders[pass] = - ANSSymbolReader(&dec_state->code[pass + first_pass], readers[pass]); + JXL_ASSIGN_OR_RETURN( + decoders[pass], + ANSSymbolReader::Create(&dec_state->code[pass + first_pass], + readers[pass])); } nzeros_stride = group_dec_cache->num_nzeroes[0].PixelsPerRow(); for (size_t i = 0; i < num_passes; i++) { diff --git a/lib/jxl/dec_modular.cc b/lib/jxl/dec_modular.cc index 38283c08174..4c67e5d0ce4 100644 --- a/lib/jxl/dec_modular.cc +++ b/lib/jxl/dec_modular.cc @@ -197,9 +197,10 @@ Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader, std::min(static_cast(1 << 22), 1024 + frame_dim.xsize * frame_dim.ysize * (nb_chans + nb_extra) / 16); - JXL_RETURN_IF_ERROR(DecodeTree(reader, &tree, tree_size_limit)); JXL_RETURN_IF_ERROR( - DecodeHistograms(reader, (tree.size() + 1) / 2, &code, &context_map)); + DecodeTree(memory_manager, reader, &tree, tree_size_limit)); + JXL_RETURN_IF_ERROR(DecodeHistograms( + memory_manager, reader, (tree.size() + 1) / 2, &code, &context_map)); } } if (!do_color) nb_chans = 0; diff --git a/lib/jxl/dec_patch_dictionary.cc b/lib/jxl/dec_patch_dictionary.cc index c728e99d0a1..605818de43e 100644 --- a/lib/jxl/dec_patch_dictionary.cc +++ b/lib/jxl/dec_patch_dictionary.cc @@ -5,11 +5,12 @@ #include "lib/jxl/dec_patch_dictionary.h" -#include -#include +#include #include #include +#include +#include #include #include @@ -25,14 +26,16 @@ namespace jxl { -Status PatchDictionary::Decode(BitReader* br, size_t xsize, size_t ysize, +Status PatchDictionary::Decode(JxlMemoryManager* memory_manager, BitReader* br, + size_t xsize, size_t ysize, bool* uses_extra_channels) { positions_.clear(); std::vector context_map; ANSCode code; - JXL_RETURN_IF_ERROR( - DecodeHistograms(br, kNumPatchDictionaryContexts, &code, &context_map)); - ANSSymbolReader decoder(&code, br); + JXL_RETURN_IF_ERROR(DecodeHistograms( + memory_manager, br, kNumPatchDictionaryContexts, &code, &context_map)); + JXL_ASSIGN_OR_RETURN(ANSSymbolReader decoder, + ANSSymbolReader::Create(&code, br)); auto read_num = [&](size_t context) { size_t r = decoder.ReadHybridUint(context, br, context_map); diff --git a/lib/jxl/dec_patch_dictionary.h b/lib/jxl/dec_patch_dictionary.h index ad0f5d6a04b..0efdce56d10 100644 --- a/lib/jxl/dec_patch_dictionary.h +++ b/lib/jxl/dec_patch_dictionary.h @@ -8,6 +8,7 @@ // Chooses reference patches, and avoids encoding them once per occurrence. +#include #include #include @@ -99,8 +100,8 @@ class PatchDictionary { bool HasAny() const { return !positions_.empty(); } - Status Decode(BitReader* br, size_t xsize, size_t ysize, - bool* uses_extra_channels); + Status Decode(JxlMemoryManager* memory_manager, BitReader* br, size_t xsize, + size_t ysize, bool* uses_extra_channels); void Clear() { positions_.clear(); diff --git a/lib/jxl/decode.cc b/lib/jxl/decode.cc index 82603d1abc4..a7e0237fd80 100644 --- a/lib/jxl/decode.cc +++ b/lib/jxl/decode.cc @@ -26,9 +26,7 @@ #if JPEGXL_ENABLE_BOXES || JPEGXL_ENABLE_TRANSCODE_JPEG #include "lib/jxl/box_content_decoder.h" #endif -#include "lib/jxl/dec_external_image.h" #include "lib/jxl/dec_frame.h" -#include "lib/jxl/dec_modular.h" #if JPEGXL_ENABLE_TRANSCODE_JPEG #include "lib/jxl/decode_to_jpeg.h" #endif @@ -39,7 +37,6 @@ #include "lib/jxl/icc_codec.h" #include "lib/jxl/image_bundle.h" #include "lib/jxl/memory_manager_internal.h" -#include "lib/jxl/toc.h" namespace { @@ -362,7 +359,7 @@ struct JxlDecoderStruct { bool got_transform_data; // To skip everything before ICC. bool got_all_headers; // Codestream metadata headers. bool post_headers; // Already decoding pixels. - jxl::ICCReader icc_reader; + std::unique_ptr icc_reader; jxl::JxlDecoderFrameIndexBox frame_index_box; // This means either we actually got the preview image, or determined we // cannot get it or there is none. @@ -685,7 +682,7 @@ void JxlDecoderRewindDecodingState(JxlDecoder* dec) { dec->got_transform_data = false; dec->got_all_headers = false; dec->post_headers = false; - dec->icc_reader.Reset(); + if (dec->icc_reader) dec->icc_reader->Reset(); dec->got_preview_image = false; dec->preview_frame = false; dec->file_pos = 0; @@ -1048,7 +1045,7 @@ JxlDecoderStatus JxlDecoderReadAllHeaders(JxlDecoder* dec) { if (dec->metadata.m.color_encoding.WantICC()) { jxl::Status status = - dec->icc_reader.Init(reader.get(), dec->memory_limit_base); + dec->icc_reader->Init(reader.get(), dec->memory_limit_base); // Always check AllReadsWithinBounds, not all the C++ decoder implementation // handles reader out of bounds correctly yet (e.g. context map). Not // checking AllReadsWithinBounds can cause reader->Close() to trigger an @@ -1061,8 +1058,8 @@ JxlDecoderStatus JxlDecoderReadAllHeaders(JxlDecoder* dec) { // Other non-successful status is an error return JXL_DEC_ERROR; } - PaddedBytes decoded_icc; - status = dec->icc_reader.Process(reader.get(), &decoded_icc); + PaddedBytes decoded_icc{&dec->memory_manager}; + status = dec->icc_reader->Process(reader.get(), &decoded_icc); if (status.code() == StatusCode::kNotEnoughBytes) { return dec->RequestMoreInput(); } @@ -1186,6 +1183,10 @@ JxlDecoderStatus JxlDecoderProcessCodestream(JxlDecoder* dec) { return JXL_DEC_SUCCESS; } + if (!dec->icc_reader) { + dec->icc_reader.reset(new ICCReader(&dec->memory_manager)); + } + if (!dec->got_all_headers) { JxlDecoderStatus status = JxlDecoderReadAllHeaders(dec); if (status != JXL_DEC_SUCCESS) return status; diff --git a/lib/jxl/decode_test.cc b/lib/jxl/decode_test.cc index 7b26fe6750c..1758d8376b7 100644 --- a/lib/jxl/decode_test.cc +++ b/lib/jxl/decode_test.cc @@ -259,6 +259,7 @@ struct TestCodestreamParams { std::vector CreateTestJXLCodestream( Span pixels, size_t xsize, size_t ysize, size_t num_channels, const TestCodestreamParams& params) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); // Compress the pixels with JPEG XL. bool grayscale = (num_channels <= 2); bool have_alpha = ((num_channels & 1) == 0); @@ -317,8 +318,8 @@ std::vector CreateTestJXLCodestream( Bytes(jpeg_bytes).AppendTo(*params.jpeg_codestream); EXPECT_TRUE(jxl::jpeg::DecodeImageJPG( jxl::Bytes(jpeg_bytes.data(), jpeg_bytes.size()), &io)); - EXPECT_TRUE( - EncodeJPEGData(*io.Main().jpeg_data, &jpeg_data, params.cparams)); + EXPECT_TRUE(EncodeJPEGData(memory_manager, *io.Main().jpeg_data, + &jpeg_data, params.cparams)); io.metadata.m.xyb_encoded = false; } else { JXL_ABORT( @@ -722,7 +723,8 @@ std::vector GetTestHeader(size_t xsize, size_t ysize, bool have_container, bool metadata_default, bool insert_extra_box, const jxl::IccBytes& icc_profile) { - jxl::BitWriter writer; + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); + jxl::BitWriter writer{memory_manager}; jxl::BitWriter::Allotment allotment(&writer, 65536); // Large enough if (have_container) { @@ -3854,6 +3856,7 @@ struct StreamPositions { void AnalyzeCodestream(const std::vector& data, StreamPositions* streampos) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); // Unbox data to codestream and mark where it is broken up by boxes. std::vector codestream; std::vector> breakpoints; @@ -3939,8 +3942,9 @@ void AnalyzeCodestream(const std::vector& data, frame_header.passes.num_passes); std::vector section_offsets; std::vector section_sizes; - ASSERT_TRUE(ReadGroupOffsets(toc_entries, &br, §ion_offsets, - §ion_sizes, &groups_total_size)); + ASSERT_TRUE(ReadGroupOffsets(memory_manager, toc_entries, &br, + §ion_offsets, §ion_sizes, + &groups_total_size)); EXPECT_EQ(br.TotalBitsConsumed() % jxl::kBitsPerByte, 0); size_t sections_start = br.TotalBitsConsumed() / jxl::kBitsPerByte; p.toc_end = add_offset(sections_start); @@ -5015,7 +5019,7 @@ JXL_TRANSCODE_JPEG_TEST(DecodeTest, JPEGReconstructionTest) { jxl::CodecInOut orig_io{memory_manager}; ASSERT_TRUE(jxl::jpeg::DecodeImageJPG(jxl::Bytes(orig), &orig_io)); orig_io.metadata.m.xyb_encoded = false; - jxl::BitWriter writer; + jxl::BitWriter writer{memory_manager}; ASSERT_TRUE(WriteCodestreamHeaders(&orig_io.metadata, &writer, nullptr)); writer.ZeroPadToByte(); jxl::CompressParams cparams; @@ -5027,8 +5031,8 @@ JXL_TRANSCODE_JPEG_TEST(DecodeTest, JPEGReconstructionTest) { /*aux_out=*/nullptr)); std::vector jpeg_data; - ASSERT_TRUE( - EncodeJPEGData(*orig_io.Main().jpeg_data.get(), &jpeg_data, cparams)); + ASSERT_TRUE(EncodeJPEGData(memory_manager, *orig_io.Main().jpeg_data.get(), + &jpeg_data, cparams)); std::vector container; jxl::Bytes(jxl::kContainerHeader).AppendTo(container); jxl::AppendBoxHeader(jxl::MakeBoxType("jbrd"), jpeg_data.size(), false, diff --git a/lib/jxl/enc_ans.cc b/lib/jxl/enc_ans.cc index 5e59790b1eb..be2668ac3af 100644 --- a/lib/jxl/enc_ans.cc +++ b/lib/jxl/enc_ans.cc @@ -5,16 +5,14 @@ #include "lib/jxl/enc_ans.h" +#include #include -#include #include #include #include #include #include -#include -#include #include #include #include @@ -439,6 +437,7 @@ uint32_t ComputeBestMethod( // Returns an estimate of the cost of encoding this histogram and the // corresponding data. size_t BuildAndStoreANSEncodingData( + JxlMemoryManager* memory_manager, HistogramParams::ANSHistogramStrategy ans_histogram_strategy, const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size, bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) { @@ -454,7 +453,7 @@ size_t BuildAndStoreANSEncodingData( std::vector depths(alphabet_size); std::vector bits(alphabet_size); if (writer == nullptr) { - BitWriter tmp_writer; + BitWriter tmp_writer{memory_manager}; BitWriter::Allotment allotment( &tmp_writer, 8 * alphabet_size + 8); // safe upper bound BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(), @@ -715,7 +714,7 @@ class HistogramBuilder { // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge. size_t BuildAndStoreEntropyCodes( - const HistogramParams& params, + JxlMemoryManager* memory_manager, const HistogramParams& params, const std::vector>& tokens, EntropyEncodingData* codes, std::vector* context_map, BitWriter* writer, size_t layer, AuxOut* aux_out) const { @@ -810,14 +809,15 @@ class HistogramBuilder { codes->encoding_info.back().resize(alphabet_size); BitWriter* histo_writer = writer; if (params.streaming_mode) { - codes->encoded_histograms.emplace_back(); + codes->encoded_histograms.emplace_back(memory_manager); histo_writer = &codes->encoded_histograms.back(); } BitWriter::Allotment allotment(histo_writer, 256 + alphabet_size * 24); cost += BuildAndStoreANSEncodingData( - params.ans_histogram_strategy, clustered_histograms[c].data_.data(), - alphabet_size, log_alpha_size, codes->use_prefix_code, - codes->encoding_info.back().data(), histo_writer); + memory_manager, params.ans_histogram_strategy, + clustered_histograms[c].data_.data(), alphabet_size, log_alpha_size, + codes->use_prefix_code, codes->encoding_info.back().data(), + histo_writer); allotment.FinishedHistogram(histo_writer); allotment.ReclaimAndCharge(histo_writer, layer, aux_out); if (params.streaming_mode) { @@ -1535,13 +1535,11 @@ void EncodeHistograms(const std::vector& context_map, allotment.ReclaimAndCharge(writer, layer, aux_out); } -size_t BuildAndEncodeHistograms(const HistogramParams& params, - size_t num_contexts, - std::vector>& tokens, - EntropyEncodingData* codes, - std::vector* context_map, - BitWriter* writer, size_t layer, - AuxOut* aux_out) { +size_t BuildAndEncodeHistograms( + JxlMemoryManager* memory_manager, const HistogramParams& params, + size_t num_contexts, std::vector>& tokens, + EntropyEncodingData* codes, std::vector* context_map, + BitWriter* writer, size_t layer, AuxOut* aux_out) { size_t total_bits = 0; codes->lz77.nonserialized_distance_context = num_contexts; std::vector> tokens_lz77; @@ -1655,19 +1653,20 @@ size_t BuildAndEncodeHistograms(const HistogramParams& params, CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE); codes->encoding_info.emplace_back(); codes->encoding_info.back().resize(alphabet_size); - codes->encoded_histograms.emplace_back(); + codes->encoded_histograms.emplace_back(memory_manager); BitWriter* histo_writer = &codes->encoded_histograms.back(); BitWriter::Allotment allotment(histo_writer, 256 + alphabet_size * 24); BuildAndStoreANSEncodingData( - params.ans_histogram_strategy, counts.data(), alphabet_size, - log_alpha_size, codes->use_prefix_code, + memory_manager, params.ans_histogram_strategy, counts.data(), + alphabet_size, log_alpha_size, codes->use_prefix_code, codes->encoding_info.back().data(), histo_writer); allotment.ReclaimAndCharge(histo_writer, 0, nullptr); } // Encode histograms. - total_bits += builder.BuildAndStoreEntropyCodes( - params, tokens, codes, context_map, writer, layer, aux_out); + total_bits += + builder.BuildAndStoreEntropyCodes(memory_manager, params, tokens, codes, + context_map, writer, layer, aux_out); allotment.FinishedHistogram(writer); allotment.ReclaimAndCharge(writer, layer, aux_out); diff --git a/lib/jxl/enc_ans.h b/lib/jxl/enc_ans.h index ae4d955a563..50df1c4f103 100644 --- a/lib/jxl/enc_ans.h +++ b/lib/jxl/enc_ans.h @@ -9,6 +9,8 @@ // Library to encode the ANS population counts to the bit-stream and encode // symbols based on the respective distributions. +#include + #include #include #include @@ -106,13 +108,11 @@ void EncodeHistograms(const std::vector& context_map, // estimate of the total bits used for encoding the stream. If `writer` == // nullptr, the bit estimate will not take into account the context map (which // does not get written if `num_contexts` == 1). -size_t BuildAndEncodeHistograms(const HistogramParams& params, - size_t num_contexts, - std::vector>& tokens, - EntropyEncodingData* codes, - std::vector* context_map, - BitWriter* writer, size_t layer, - AuxOut* aux_out); +size_t BuildAndEncodeHistograms( + JxlMemoryManager* memory_manager, const HistogramParams& params, + size_t num_contexts, std::vector>& tokens, + EntropyEncodingData* codes, std::vector* context_map, + BitWriter* writer, size_t layer, AuxOut* aux_out); // Write the tokens to a string. void WriteTokens(const std::vector& tokens, diff --git a/lib/jxl/enc_bit_writer.cc b/lib/jxl/enc_bit_writer.cc index 6e6e6a353c3..8c537964290 100644 --- a/lib/jxl/enc_bit_writer.cc +++ b/lib/jxl/enc_bit_writer.cc @@ -6,11 +6,11 @@ #include "lib/jxl/enc_bit_writer.h" #include -#include // memcpy + +#include // memcpy #include "lib/jxl/base/byte_order.h" #include "lib/jxl/base/printf_macros.h" -#include "lib/jxl/dec_bit_reader.h" #include "lib/jxl/enc_aux_out.h" namespace jxl { diff --git a/lib/jxl/enc_bit_writer.h b/lib/jxl/enc_bit_writer.h index 6f4865077d2..019be801622 100644 --- a/lib/jxl/enc_bit_writer.h +++ b/lib/jxl/enc_bit_writer.h @@ -8,9 +8,10 @@ // BitWriter class: unbuffered writes using unaligned 64-bit stores. -#include -#include +#include +#include +#include #include #include @@ -32,7 +33,8 @@ struct BitWriter { // yet zero-initialized). static constexpr size_t kMaxBitsPerCall = 56; - BitWriter() : bits_written_(0) {} + explicit BitWriter(JxlMemoryManager* memory_manager) + : bits_written_(0), storage_(memory_manager) {} // Disallow copying - may lead to bugs. BitWriter(const BitWriter&) = delete; @@ -42,6 +44,8 @@ struct BitWriter { size_t BitsWritten() const { return bits_written_; } + JxlMemoryManager* memory_manager() const { return storage_.memory_manager(); } + Span GetSpan() const { // Callers must ensure byte alignment to avoid uninitialized bits. JXL_ASSERT(bits_written_ % kBitsPerByte == 0); diff --git a/lib/jxl/enc_cache.cc b/lib/jxl/enc_cache.cc index 562d4f96d4d..b339b6dcb58 100644 --- a/lib/jxl/enc_cache.cc +++ b/lib/jxl/enc_cache.cc @@ -89,7 +89,7 @@ Status InitializePassesEncoder(const FrameHeader& frame_header, if (enc_state->initialize_global_state) { float scale = shared.quantizer.ScaleGlobalScale(enc_state->cparams.quant_ac_rescale); - DequantMatricesScaleDC(&shared.matrices, scale); + DequantMatricesScaleDC(memory_manager, &shared.matrices, scale); shared.quantizer.RecomputeFromGlobalScale(); } @@ -161,7 +161,8 @@ Status InitializePassesEncoder(const FrameHeader& frame_header, } ib.SetExtraChannels(std::move(extra_channels)); } - auto special_frame = std::unique_ptr(new BitWriter()); + auto special_frame = + std::unique_ptr(new BitWriter(memory_manager)); FrameInfo dc_frame_info; dc_frame_info.frame_type = FrameType::kDCFrame; dc_frame_info.dc_level = frame_header.dc_level + 1; diff --git a/lib/jxl/enc_coeff_order.cc b/lib/jxl/enc_coeff_order.cc index abe404ce9e7..6e2a5600b75 100644 --- a/lib/jxl/enc_coeff_order.cc +++ b/lib/jxl/enc_coeff_order.cc @@ -3,6 +3,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +#include + #include #include #include @@ -234,12 +236,14 @@ void TokenizePermutation(const coeff_order_t* JXL_RESTRICT order, size_t skip, void EncodePermutation(const coeff_order_t* JXL_RESTRICT order, size_t skip, size_t size, BitWriter* writer, int layer, AuxOut* aux_out) { + JxlMemoryManager* memory_manager = writer->memory_manager(); std::vector> tokens(1); TokenizePermutation(order, skip, size, tokens.data()); std::vector context_map; EntropyEncodingData codes; - BuildAndEncodeHistograms(HistogramParams(), kPermutationContexts, tokens, - &codes, &context_map, writer, layer, aux_out); + BuildAndEncodeHistograms(memory_manager, HistogramParams(), + kPermutationContexts, tokens, &codes, &context_map, + writer, layer, aux_out); WriteTokens(tokens[0], codes, context_map, 0, writer, layer, aux_out); } @@ -260,6 +264,7 @@ void EncodeCoeffOrders(uint16_t used_orders, const coeff_order_t* JXL_RESTRICT order, BitWriter* writer, size_t layer, AuxOut* JXL_RESTRICT aux_out) { + JxlMemoryManager* memory_manager = writer->memory_manager(); auto mem = hwy::AllocateAligned(AcStrategy::kMaxCoeffArea); uint16_t computed = 0; std::vector> tokens(1); @@ -283,8 +288,9 @@ void EncodeCoeffOrders(uint16_t used_orders, if (used_orders != 0) { std::vector context_map; EntropyEncodingData codes; - BuildAndEncodeHistograms(HistogramParams(), kPermutationContexts, tokens, - &codes, &context_map, writer, layer, aux_out); + BuildAndEncodeHistograms(memory_manager, HistogramParams(), + kPermutationContexts, tokens, &codes, &context_map, + writer, layer, aux_out); WriteTokens(tokens[0], codes, context_map, 0, writer, layer, aux_out); } } diff --git a/lib/jxl/enc_context_map.cc b/lib/jxl/enc_context_map.cc index 36efc4e6494..739179105a6 100644 --- a/lib/jxl/enc_context_map.cc +++ b/lib/jxl/enc_context_map.cc @@ -7,6 +7,7 @@ #include "lib/jxl/enc_context_map.h" +#include #include #include @@ -70,6 +71,7 @@ void EncodeContextMap(const std::vector& context_map, return; } + JxlMemoryManager* memory_manager = writer->memory_manager(); std::vector transformed_symbols = MoveToFrontTransform(context_map); std::vector> tokens(1); std::vector> mtf_tokens(1); @@ -86,14 +88,16 @@ void EncodeContextMap(const std::vector& context_map, { EntropyEncodingData codes; std::vector sink_context_map; - ans_cost = BuildAndEncodeHistograms(params, 1, tokens, &codes, - &sink_context_map, nullptr, 0, nullptr); + ans_cost = + BuildAndEncodeHistograms(memory_manager, params, 1, tokens, &codes, + &sink_context_map, nullptr, 0, nullptr); } { EntropyEncodingData codes; std::vector sink_context_map; - mtf_cost = BuildAndEncodeHistograms(params, 1, mtf_tokens, &codes, - &sink_context_map, nullptr, 0, nullptr); + mtf_cost = + BuildAndEncodeHistograms(memory_manager, params, 1, mtf_tokens, &codes, + &sink_context_map, nullptr, 0, nullptr); } bool use_mtf = mtf_cost < ans_cost; // Rebuild token list. @@ -118,8 +122,8 @@ void EncodeContextMap(const std::vector& context_map, writer->Write(1, TO_JXL_BOOL(use_mtf)); // Use/don't use MTF. EntropyEncodingData codes; std::vector sink_context_map; - BuildAndEncodeHistograms(params, 1, tokens, &codes, &sink_context_map, - writer, layer, aux_out); + BuildAndEncodeHistograms(memory_manager, params, 1, tokens, &codes, + &sink_context_map, writer, layer, aux_out); WriteTokens(tokens[0], codes, sink_context_map, 0, writer); allotment.ReclaimAndCharge(writer, layer, aux_out); } diff --git a/lib/jxl/enc_context_map.h b/lib/jxl/enc_context_map.h index 041e71de7af..a18bc7646c9 100644 --- a/lib/jxl/enc_context_map.h +++ b/lib/jxl/enc_context_map.h @@ -6,9 +6,8 @@ #ifndef LIB_JXL_ENC_CONTEXT_MAP_H_ #define LIB_JXL_ENC_CONTEXT_MAP_H_ -#include -#include - +#include +#include #include #include "lib/jxl/ac_context.h" diff --git a/lib/jxl/enc_frame.cc b/lib/jxl/enc_frame.cc index a6f1613bf12..ab4c2100c9c 100644 --- a/lib/jxl/enc_frame.cc +++ b/lib/jxl/enc_frame.cc @@ -782,7 +782,7 @@ Status ComputeJPEGTranscodingData(const jpeg::JPEGData& jpeg_data, } } } - DequantMatricesSetCustomDC(&shared.matrices, dcquantization); + DequantMatricesSetCustomDC(memory_manager, &shared.matrices, dcquantization); float dcquantization_r[3] = {1.0f / dcquantization[0], 1.0f / dcquantization[1], 1.0f / dcquantization[2]}; @@ -1240,7 +1240,7 @@ Status EncodeGlobalACInfo(PassesEncoderState* enc_state, BitWriter* writer, hist_params.streaming_mode = enc_state->streaming_mode; hist_params.initialize_global_state = enc_state->initialize_global_state; BuildAndEncodeHistograms( - hist_params, + memory_manager, hist_params, num_histogram_groups * shared.block_ctx_map.NumACContexts(), enc_state->passes[i].ac_tokens, &enc_state->passes[i].codes, &enc_state->passes[i].context_map, writer, kLayerAC, aux_out); @@ -1252,8 +1252,10 @@ Status EncodeGlobalACInfo(PassesEncoderState* enc_state, BitWriter* writer, Status EncodeGroups(const FrameHeader& frame_header, PassesEncoderState* enc_state, ModularFrameEncoder* enc_modular, ThreadPool* pool, - std::vector* group_codes, AuxOut* aux_out) { + std::vector>* group_codes, + AuxOut* aux_out) { const PassesSharedState& shared = enc_state->shared; + JxlMemoryManager* memory_manager = shared.memory_manager; const FrameDimensions& frame_dim = shared.frame_dim; const size_t num_groups = frame_dim.num_groups; const size_t num_passes = enc_state->progressive_splitter.GetNumPasses(); @@ -1264,10 +1266,14 @@ Status EncodeGroups(const FrameHeader& frame_header, is_small_image ? 1 : AcGroupIndex(0, 0, num_groups, frame_dim.num_dc_groups) + num_groups * num_passes; - group_codes->resize(num_toc_entries); + JXL_ASSERT(group_codes->empty()); + group_codes->reserve(num_toc_entries); + for (size_t i = 0; i < num_toc_entries; ++i) { + group_codes->emplace_back(jxl::make_unique(memory_manager)); + } - const auto get_output = [&](const size_t index) { - return &(*group_codes)[is_small_image ? 0 : index]; + const auto get_output = [&](const size_t index) -> BitWriter* { + return (*group_codes)[is_small_image ? 0 : index].get(); }; auto ac_group_code = [&](size_t pass, size_t group) { return get_output(AcGroupIndex(pass, group, frame_dim.num_groups, @@ -1404,10 +1410,10 @@ Status EncodeGroups(const FrameHeader& frame_header, // Resizing aux_outs to 0 also Assimilates the array. static_cast(resize_aux_outs(0)); - for (BitWriter& bw : *group_codes) { - BitWriter::Allotment allotment(&bw, 8); - bw.ZeroPadToByte(); // end of group. - allotment.ReclaimAndCharge(&bw, kLayerAC, aux_out); + for (std::unique_ptr& bw : *group_codes) { + BitWriter::Allotment allotment(bw.get(), 8); + bw->ZeroPadToByte(); // end of group. + allotment.ReclaimAndCharge(bw.get(), kLayerAC, aux_out); } return true; } @@ -1418,8 +1424,8 @@ Status ComputeEncodingData( const jpeg::JPEGData* jpeg_data, size_t x0, size_t y0, size_t xsize, size_t ysize, const JxlCmsInterface& cms, ThreadPool* pool, FrameHeader& mutable_frame_header, ModularFrameEncoder& enc_modular, - PassesEncoderState& enc_state, std::vector* group_codes, - AuxOut* aux_out) { + PassesEncoderState& enc_state, + std::vector>* group_codes, AuxOut* aux_out) { JXL_ASSERT(x0 + xsize <= frame_data.xsize); JXL_ASSERT(y0 + ysize <= frame_data.ysize); JxlMemoryManager* memory_manager = enc_state.memory_manager(); @@ -1630,7 +1636,7 @@ Status ComputeEncodingData( Status PermuteGroups(const CompressParams& cparams, const FrameDimensions& frame_dim, size_t num_passes, std::vector* permutation, - std::vector* group_codes) { + std::vector>* group_codes) { const size_t num_groups = frame_dim.num_groups; if (!cparams.centerfirst || (num_passes == 1 && num_groups == 1)) { return true; @@ -1699,11 +1705,11 @@ Status PermuteGroups(const CompressParams& cparams, permutation->push_back(pass_start + v); } } - std::vector new_group_codes(group_codes->size()); + std::vector> new_group_codes(group_codes->size()); for (size_t i = 0; i < permutation->size(); i++) { new_group_codes[(*permutation)[i]] = std::move((*group_codes)[i]); } - *group_codes = std::move(new_group_codes); + group_codes->swap(new_group_codes); return true; } @@ -1824,8 +1830,9 @@ size_t TOCSize(const std::vector& group_sizes) { return (toc_bits + 7) / 8; } -PaddedBytes EncodeTOC(const std::vector& group_sizes, AuxOut* aux_out) { - BitWriter writer; +PaddedBytes EncodeTOC(JxlMemoryManager* memory_manager, + const std::vector& group_sizes, AuxOut* aux_out) { + BitWriter writer{memory_manager}; BitWriter::Allotment allotment(&writer, 32 * group_sizes.size()); for (size_t group_size : group_sizes) { JXL_CHECK(U32Coder::Write(kTocDist, group_size, &writer)); @@ -1865,17 +1872,17 @@ size_t ComputeDcGlobalPadding(const std::vector& group_sizes, return group_data_offset - actual_offset; } -Status OutputGroups(std::vector&& group_codes, +Status OutputGroups(std::vector>&& group_codes, std::vector* group_sizes, JxlEncoderOutputProcessorWrapper* output_processor) { JXL_ASSERT(group_codes.size() >= 4); { - PaddedBytes dc_group = std::move(group_codes[1]).TakeBytes(); + PaddedBytes dc_group = std::move(*group_codes[1]).TakeBytes(); group_sizes->push_back(dc_group.size()); JXL_RETURN_IF_ERROR(AppendData(*output_processor, dc_group)); } for (size_t i = 3; i < group_codes.size(); ++i) { - PaddedBytes ac_group = std::move(group_codes[i]).TakeBytes(); + PaddedBytes ac_group = std::move(*group_codes[i]).TakeBytes(); group_sizes->push_back(ac_group.size()); JXL_RETURN_IF_ERROR(AppendData(*output_processor, ac_group)); } @@ -1913,7 +1920,8 @@ Status OutputAcGlobal(PassesEncoderState& enc_state, JxlEncoderOutputProcessorWrapper* output_processor, AuxOut* aux_out) { JXL_ASSERT(frame_dim.num_groups > 1); - BitWriter writer; + JxlMemoryManager* memory_manager = enc_state.memory_manager(); + BitWriter writer{memory_manager}; { size_t num_histo_bits = CeilLog2Nonzero(frame_dim.num_groups); BitWriter::Allotment allotment(&writer, num_histo_bits + 1); @@ -1982,8 +1990,8 @@ Status EncodeFrameStreaming(JxlMemoryManager* memory_manager, size_t dc_group_xsize = DivCeil(frame_data.xsize, dc_group_size); size_t min_dc_global_size = 0; size_t group_data_offset = 0; - PaddedBytes frame_header_bytes; - PaddedBytes dc_global_bytes; + PaddedBytes frame_header_bytes{memory_manager}; + PaddedBytes dc_global_bytes{memory_manager}; std::vector group_sizes; size_t start_pos = output_processor->CurrentPosition(); for (size_t i = 0; i < dc_group_order.size(); ++i) { @@ -2005,14 +2013,14 @@ Status EncodeFrameStreaming(JxlMemoryManager* memory_manager, enc_state.initialize_global_state = (i == 0); enc_state.dc_group_index = dc_ix; enc_state.histogram_idx = std::vector(group_xsize * group_ysize, i); - std::vector group_codes; + std::vector> group_codes; JXL_RETURN_IF_ERROR(ComputeEncodingData( cparams, frame_info, metadata, frame_data, jpeg_data.get(), x0, y0, xsize, ysize, cms, pool, frame_header, enc_modular, enc_state, &group_codes, aux_out)); JXL_ASSERT(enc_state.special_frames.empty()); if (i == 0) { - BitWriter writer; + BitWriter writer{memory_manager}; JXL_RETURN_IF_ERROR(WriteFrameHeader(frame_header, &writer, aux_out)); BitWriter::Allotment allotment(&writer, 8); writer.Write(1, 1); // write permutation @@ -2021,7 +2029,7 @@ Status EncodeFrameStreaming(JxlMemoryManager* memory_manager, writer.ZeroPadToByte(); allotment.ReclaimAndCharge(&writer, kLayerHeader, aux_out); frame_header_bytes = std::move(writer).TakeBytes(); - dc_global_bytes = std::move(group_codes[0]).TakeBytes(); + dc_global_bytes = std::move(*group_codes[0]).TakeBytes(); ComputeGroupDataOffset(frame_header_bytes.size(), dc_global_bytes.size(), permutation.size(), min_dc_global_size, group_data_offset); @@ -2050,7 +2058,7 @@ Status EncodeFrameStreaming(JxlMemoryManager* memory_manager, ComputeDcGlobalPadding(group_sizes, frame_header_bytes.size(), group_data_offset, min_dc_global_size); group_sizes[0] += padding_size; - PaddedBytes toc_bytes = EncodeTOC(group_sizes, aux_out); + PaddedBytes toc_bytes = EncodeTOC(memory_manager, group_sizes, aux_out); std::vector padding_bytes(padding_size); JXL_RETURN_IF_ERROR(AppendData(*output_processor, frame_header_bytes)); JXL_RETURN_IF_ERROR(AppendData(*output_processor, toc_bytes)); @@ -2074,7 +2082,6 @@ Status EncodeFrameOneShot(JxlMemoryManager* memory_manager, AuxOut* aux_out) { PassesEncoderState enc_state{memory_manager}; SetProgressiveMode(cparams, &enc_state.progressive_splitter); - std::vector group_codes; FrameHeader frame_header(metadata); std::unique_ptr jpeg_data; if (frame_data.IsJPEG()) { @@ -2086,12 +2093,13 @@ Status EncodeFrameOneShot(JxlMemoryManager* memory_manager, &frame_header)); const size_t num_passes = enc_state.progressive_splitter.GetNumPasses(); ModularFrameEncoder enc_modular(memory_manager, frame_header, cparams, false); + std::vector> group_codes; JXL_RETURN_IF_ERROR(ComputeEncodingData( cparams, frame_info, metadata, frame_data, jpeg_data.get(), 0, 0, frame_data.xsize, frame_data.ysize, cms, pool, frame_header, enc_modular, enc_state, &group_codes, aux_out)); - BitWriter writer; + BitWriter writer{memory_manager}; writer.AppendByteAligned(enc_state.special_frames); JXL_RETURN_IF_ERROR(WriteFrameHeader(frame_header, &writer, aux_out)); diff --git a/lib/jxl/enc_heuristics.cc b/lib/jxl/enc_heuristics.cc index 1c45e016703..13daaad1cc6 100644 --- a/lib/jxl/enc_heuristics.cc +++ b/lib/jxl/enc_heuristics.cc @@ -196,7 +196,8 @@ void FindBestBlockEntropyModel(const CompressParams& cparams, const ImageI& rqf, namespace { -Status FindBestDequantMatrices(const CompressParams& cparams, +Status FindBestDequantMatrices(JxlMemoryManager* memory_manager, + const CompressParams& cparams, ModularFrameEncoder* modular_frame_encoder, DequantMatrices* dequant_matrices) { // TODO(veluca): quant matrices for no-gaborish. @@ -215,7 +216,7 @@ Status FindBestDequantMatrices(const CompressParams& cparams, JXL_RETURN_IF_ERROR(DequantMatricesSetCustom(dequant_matrices, encodings, modular_frame_encoder)); float dc_weights[3] = {1.0f / wp[0], 1.0f / wp[1], 1.0f / wp[2]}; - DequantMatricesSetCustomDC(dequant_matrices, dc_weights); + DequantMatricesSetCustomDC(memory_manager, dequant_matrices, dc_weights); } return true; } @@ -1088,8 +1089,8 @@ Status LossyFrameHeuristics(const FrameHeader& frame_header, } if (initialize_global_state) { - JXL_RETURN_IF_ERROR( - FindBestDequantMatrices(cparams, modular_frame_encoder, &matrices)); + JXL_RETURN_IF_ERROR(FindBestDequantMatrices( + memory_manager, cparams, modular_frame_encoder, &matrices)); } JXL_RETURN_IF_ERROR(cfl_heuristics.Init(memory_manager, rect)); diff --git a/lib/jxl/enc_icc_codec.cc b/lib/jxl/enc_icc_codec.cc index c3899600e51..1e40f32b307 100644 --- a/lib/jxl/enc_icc_codec.cc +++ b/lib/jxl/enc_icc_codec.cc @@ -5,14 +5,13 @@ #include "lib/jxl/enc_icc_codec.h" -#include +#include +#include #include #include -#include #include -#include "lib/jxl/base/byte_order.h" #include "lib/jxl/color_encoding_internal.h" #include "lib/jxl/enc_ans.h" #include "lib/jxl/enc_aux_out.h" @@ -33,9 +32,10 @@ namespace { // elements at the bottom of the rightmost column. The input is the input matrix // in scanline order, the output is the result matrix in scanline order, with // missing elements skipped over (this may occur at multiple positions). -void Unshuffle(uint8_t* data, size_t size, size_t width) { +void Unshuffle(JxlMemoryManager* memory_manager, uint8_t* data, size_t size, + size_t width) { size_t height = (size + width - 1) / width; // amount of rows of input - PaddedBytes result(size); + PaddedBytes result(memory_manager, size); // i = input index, j output index size_t s = 0; size_t j = 0; @@ -57,6 +57,7 @@ Status PredictAndShuffle(size_t stride, size_t width, int order, size_t num, const uint8_t* data, size_t size, size_t* pos, PaddedBytes* result) { JXL_RETURN_IF_ERROR(CheckOutOfBounds(*pos, num, size)); + JxlMemoryManager* memory_manager = result->memory_manager(); // Required by the specification, see decoder. stride * 4 must be < *pos. if (!*pos || ((*pos - 1u) >> 2u) < stride) { return JXL_FAILURE("Invalid stride"); @@ -69,7 +70,7 @@ Status PredictAndShuffle(size_t stride, size_t width, int order, size_t num, result->push_back(data[*pos + i] - predicted); } *pos += num; - if (width > 1) Unshuffle(result->data() + start, num, width); + if (width > 1) Unshuffle(memory_manager, result->data() + start, num, width); return true; } @@ -105,8 +106,9 @@ constexpr size_t kSizeLimit = std::numeric_limits::max() >> 2; // form that is easier to compress (more zeroes, ...) and will compress better // with brotli. Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result) { - PaddedBytes commands; - PaddedBytes data; + JxlMemoryManager* memory_manager = result->memory_manager(); + PaddedBytes commands{memory_manager}; + PaddedBytes data{memory_manager}; static_assert(sizeof(size_t) >= 4, "size_t is too short"); // Fuzzer expects that PredictICC can accept any input, @@ -118,7 +120,7 @@ Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result) { EncodeVarInt(size, result); // Header - PaddedBytes header; + PaddedBytes header{memory_manager}; header.append(ICCInitialHeaderPrediction()); EncodeUint32(0, size, &header); for (size_t i = 0; i < kICCHeaderSize && i < size; i++) { @@ -256,8 +258,8 @@ Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result) { // but will not predict as well. while (pos <= size) { size_t last1 = pos; - PaddedBytes commands_add; - PaddedBytes data_add; + PaddedBytes commands_add{memory_manager}; + PaddedBytes data_add{memory_manager}; // This means the loop brought the position beyond the tag end. // If tagsize is nonsensical, any pos looks "ok-ish". @@ -285,7 +287,7 @@ Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result) { data_add.push_back(icc[pos]); pos++; } - Unshuffle(data_add.data() + start, num, 2); + Unshuffle(memory_manager, data_add.data() + start, num, 2); } if (tag == kCurvTag && tag_sane() && pos + tagsize <= size && @@ -430,7 +432,8 @@ Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result) { Status WriteICC(const IccBytes& icc, BitWriter* JXL_RESTRICT writer, size_t layer, AuxOut* JXL_RESTRICT aux_out) { if (icc.empty()) return JXL_FAILURE("ICC must be non-empty"); - PaddedBytes enc; + JxlMemoryManager* memory_manager = writer->memory_manager(); + PaddedBytes enc{memory_manager}; JXL_RETURN_IF_ERROR(PredictICC(icc.data(), icc.size(), &enc)); std::vector> tokens(1); BitWriter::Allotment allotment(writer, 128); @@ -448,8 +451,8 @@ Status WriteICC(const IccBytes& icc, BitWriter* JXL_RESTRICT writer, EntropyEncodingData code; std::vector context_map; params.force_huffman = true; - BuildAndEncodeHistograms(params, kNumICCContexts, tokens, &code, &context_map, - writer, layer, aux_out); + BuildAndEncodeHistograms(memory_manager, params, kNumICCContexts, tokens, + &code, &context_map, writer, layer, aux_out); WriteTokens(tokens[0], code, context_map, 0, writer, layer, aux_out); return true; } diff --git a/lib/jxl/enc_icc_codec.h b/lib/jxl/enc_icc_codec.h index a99e11b19cd..956a33ff4d6 100644 --- a/lib/jxl/enc_icc_codec.h +++ b/lib/jxl/enc_icc_codec.h @@ -8,9 +8,8 @@ // Compressed representation of ICC profiles. -#include -#include - +#include +#include #include #include "lib/jxl/base/compiler_specific.h" diff --git a/lib/jxl/enc_modular.cc b/lib/jxl/enc_modular.cc index 4fbc98ce21f..d935bcaa929 100644 --- a/lib/jxl/enc_modular.cc +++ b/lib/jxl/enc_modular.cc @@ -704,11 +704,12 @@ Status ModularFrameEncoder::ComputeEncodingData( cparams_.butteraugli_distance = 0; } if (cparams_.manual_xyb_factors.size() == 3) { - DequantMatricesSetCustomDC(&enc_state->shared.matrices, + DequantMatricesSetCustomDC(memory_manager, &enc_state->shared.matrices, cparams_.manual_xyb_factors.data()); // TODO(jon): update max_bitdepth in this case } else { - DequantMatricesSetCustomDC(&enc_state->shared.matrices, enc_factors); + DequantMatricesSetCustomDC(memory_manager, &enc_state->shared.matrices, + enc_factors); max_bitdepth = 12; } } @@ -1246,6 +1247,7 @@ Status ModularFrameEncoder::ComputeTokens(ThreadPool* pool) { Status ModularFrameEncoder::EncodeGlobalInfo(bool streaming_mode, BitWriter* writer, AuxOut* aux_out) { + JxlMemoryManager* memory_manager = writer->memory_manager(); BitWriter::Allotment allotment(writer, 1); // If we are using brotli, or not using modular mode. if (tree_tokens_.empty() || tree_tokens_[0].empty()) { @@ -1262,9 +1264,9 @@ Status ModularFrameEncoder::EncodeGlobalInfo(bool streaming_mode, { EntropyEncodingData tree_code; std::vector tree_context_map; - BuildAndEncodeHistograms(params, kNumTreeContexts, tree_tokens_, &tree_code, - &tree_context_map, writer, kLayerModularTree, - aux_out); + BuildAndEncodeHistograms(memory_manager, params, kNumTreeContexts, + tree_tokens_, &tree_code, &tree_context_map, + writer, kLayerModularTree, aux_out); WriteTokens(tree_tokens_[0], tree_code, tree_context_map, 0, writer, kLayerModularTree, aux_out); } @@ -1272,8 +1274,9 @@ Status ModularFrameEncoder::EncodeGlobalInfo(bool streaming_mode, params.add_missing_symbols = streaming_mode; params.image_widths = image_widths_; // Write histograms. - BuildAndEncodeHistograms(params, (tree_.size() + 1) / 2, tokens_, &code_, - &context_map_, writer, kLayerModularGlobal, aux_out); + BuildAndEncodeHistograms(memory_manager, params, (tree_.size() + 1) / 2, + tokens_, &code_, &context_map_, writer, + kLayerModularGlobal, aux_out); return true; } diff --git a/lib/jxl/enc_noise.cc b/lib/jxl/enc_noise.cc index 80b90eed2c7..446a770c683 100644 --- a/lib/jxl/enc_noise.cc +++ b/lib/jxl/enc_noise.cc @@ -5,19 +5,14 @@ #include "lib/jxl/enc_noise.h" -#include -#include - #include +#include +#include #include #include -#include "lib/jxl/base/compiler_specific.h" -#include "lib/jxl/chroma_from_luma.h" -#include "lib/jxl/convolve.h" #include "lib/jxl/enc_aux_out.h" #include "lib/jxl/enc_optimize.h" -#include "lib/jxl/image_ops.h" namespace jxl { namespace { diff --git a/lib/jxl/enc_patch_dictionary.cc b/lib/jxl/enc_patch_dictionary.cc index 2aa280b19d1..507e82caad0 100644 --- a/lib/jxl/enc_patch_dictionary.cc +++ b/lib/jxl/enc_patch_dictionary.cc @@ -47,6 +47,7 @@ void PatchDictionaryEncoder::Encode(const PatchDictionary& pdic, BitWriter* writer, size_t layer, AuxOut* aux_out) { JXL_ASSERT(pdic.HasAny()); + JxlMemoryManager* memory_manager = writer->memory_manager(); std::vector> tokens(1); size_t num_ec = pdic.shared_->metadata->m.num_extra_channels; @@ -107,9 +108,9 @@ void PatchDictionaryEncoder::Encode(const PatchDictionary& pdic, EntropyEncodingData codes; std::vector context_map; - BuildAndEncodeHistograms(HistogramParams(), kNumPatchDictionaryContexts, - tokens, &codes, &context_map, writer, layer, - aux_out); + BuildAndEncodeHistograms(memory_manager, HistogramParams(), + kNumPatchDictionaryContexts, tokens, &codes, + &context_map, writer, layer, aux_out); WriteTokens(tokens[0], codes, context_map, 0, writer, layer, aux_out); } @@ -804,7 +805,8 @@ Status RoundtripPatchFrame(Image3F* reference_frame, } ib.SetExtraChannels(std::move(extra_channels)); } - auto special_frame = std::unique_ptr(new BitWriter()); + auto special_frame = + std::unique_ptr(new BitWriter(memory_manager)); AuxOut patch_aux_out; JXL_CHECK(EncodeFrame( memory_manager, cparams, patch_frame_info, state->shared.metadata, ib, diff --git a/lib/jxl/enc_quant_weights.cc b/lib/jxl/enc_quant_weights.cc index 14b1f9f645c..eefdecc8988 100644 --- a/lib/jxl/enc_quant_weights.cc +++ b/lib/jxl/enc_quant_weights.cc @@ -161,10 +161,11 @@ Status DequantMatricesEncodeDC(const DequantMatrices& matrices, return true; } -void DequantMatricesSetCustomDC(DequantMatrices* matrices, const float* dc) { +void DequantMatricesSetCustomDC(JxlMemoryManager* memory_manager, + DequantMatrices* matrices, const float* dc) { matrices->SetDCQuant(dc); // Roundtrip encode/decode DC to ensure same values as decoder. - BitWriter writer; + BitWriter writer{memory_manager}; JXL_CHECK(DequantMatricesEncodeDC(*matrices, &writer, 0, nullptr)); writer.ZeroPadToByte(); BitReader br(writer.GetSpan()); @@ -173,19 +174,20 @@ void DequantMatricesSetCustomDC(DequantMatrices* matrices, const float* dc) { JXL_CHECK(br.Close()); } -void DequantMatricesScaleDC(DequantMatrices* matrices, const float scale) { +void DequantMatricesScaleDC(JxlMemoryManager* memory_manager, + DequantMatrices* matrices, const float scale) { float dc[3]; for (size_t c = 0; c < 3; ++c) { dc[c] = matrices->InvDCQuant(c) * (1.0f / scale); } - DequantMatricesSetCustomDC(matrices, dc); + DequantMatricesSetCustomDC(memory_manager, matrices, dc); } void DequantMatricesRoundtrip(JxlMemoryManager* memory_manager, DequantMatrices* matrices) { // Do not pass modular en/decoder, as they only change entropy and not // values. - BitWriter writer; + BitWriter writer{memory_manager}; JXL_CHECK( DequantMatricesEncode(memory_manager, *matrices, &writer, 0, nullptr)); writer.ZeroPadToByte(); diff --git a/lib/jxl/enc_quant_weights.h b/lib/jxl/enc_quant_weights.h index bdd2980d00a..e11bcf52cca 100644 --- a/lib/jxl/enc_quant_weights.h +++ b/lib/jxl/enc_quant_weights.h @@ -28,9 +28,11 @@ Status DequantMatricesEncodeDC(const DequantMatrices& matrices, AuxOut* aux_out); // For consistency with QuantEncoding, higher values correspond to more // precision. -void DequantMatricesSetCustomDC(DequantMatrices* matrices, const float* dc); +void DequantMatricesSetCustomDC(JxlMemoryManager* memory_manager, + DequantMatrices* matrices, const float* dc); -void DequantMatricesScaleDC(DequantMatrices* matrices, float scale); +void DequantMatricesScaleDC(JxlMemoryManager* memory_manager, + DequantMatrices* matrices, float scale); Status DequantMatricesSetCustom(DequantMatrices* matrices, const std::vector& encodings, diff --git a/lib/jxl/enc_splines.cc b/lib/jxl/enc_splines.cc index 186f19da934..ec35accd60c 100644 --- a/lib/jxl/enc_splines.cc +++ b/lib/jxl/enc_splines.cc @@ -3,14 +3,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -#include - -#include "lib/jxl/ans_params.h" #include "lib/jxl/base/status.h" -#include "lib/jxl/chroma_from_luma.h" -#include "lib/jxl/dct_scales.h" #include "lib/jxl/enc_ans.h" -#include "lib/jxl/entropy_coder.h" #include "lib/jxl/pack_signed.h" #include "lib/jxl/splines.h" @@ -84,8 +78,9 @@ void EncodeSplines(const Splines& splines, BitWriter* writer, EntropyEncodingData codes; std::vector context_map; - BuildAndEncodeHistograms(histogram_params, kNumSplineContexts, tokens, &codes, - &context_map, writer, layer, aux_out); + BuildAndEncodeHistograms(writer->memory_manager(), histogram_params, + kNumSplineContexts, tokens, &codes, &context_map, + writer, layer, aux_out); WriteTokens(tokens[0], codes, context_map, 0, writer, layer, aux_out); } diff --git a/lib/jxl/enc_toc.cc b/lib/jxl/enc_toc.cc index 4ecba8fdb10..19b6d7023f2 100644 --- a/lib/jxl/enc_toc.cc +++ b/lib/jxl/enc_toc.cc @@ -5,10 +5,9 @@ #include "lib/jxl/enc_toc.h" -#include +#include #include "lib/jxl/base/common.h" -#include "lib/jxl/coeff_order.h" #include "lib/jxl/enc_aux_out.h" #include "lib/jxl/enc_coeff_order.h" #include "lib/jxl/field_encodings.h" @@ -16,9 +15,10 @@ #include "lib/jxl/toc.h" namespace jxl { -Status WriteGroupOffsets(const std::vector& group_codes, - const std::vector& permutation, - BitWriter* JXL_RESTRICT writer, AuxOut* aux_out) { +Status WriteGroupOffsets( + const std::vector>& group_codes, + const std::vector& permutation, + BitWriter* JXL_RESTRICT writer, AuxOut* aux_out) { BitWriter::Allotment allotment(writer, MaxBits(group_codes.size())); if (!permutation.empty() && !group_codes.empty()) { // Don't write a permutation at all for an empty group_codes. @@ -33,8 +33,8 @@ Status WriteGroupOffsets(const std::vector& group_codes, writer->ZeroPadToByte(); // before TOC entries for (const auto& bw : group_codes) { - JXL_ASSERT(bw.BitsWritten() % kBitsPerByte == 0); - const size_t group_size = bw.BitsWritten() / kBitsPerByte; + JXL_ASSERT(bw->BitsWritten() % kBitsPerByte == 0); + const size_t group_size = bw->BitsWritten() / kBitsPerByte; JXL_RETURN_IF_ERROR(U32Coder::Write(kTocDist, group_size, writer)); } writer->ZeroPadToByte(); // before first group diff --git a/lib/jxl/enc_toc.h b/lib/jxl/enc_toc.h index aa222141bec..574439c6ced 100644 --- a/lib/jxl/enc_toc.h +++ b/lib/jxl/enc_toc.h @@ -6,9 +6,7 @@ #ifndef LIB_JXL_ENC_TOC_H_ #define LIB_JXL_ENC_TOC_H_ -#include -#include - +#include #include #include "lib/jxl/base/compiler_specific.h" @@ -22,9 +20,10 @@ struct AuxOut; // Writes the group offsets. If the permutation vector is empty, the identity // permutation will be used. -Status WriteGroupOffsets(const std::vector& group_codes, - const std::vector& permutation, - BitWriter* JXL_RESTRICT writer, AuxOut* aux_out); +Status WriteGroupOffsets( + const std::vector>& group_codes, + const std::vector& permutation, + BitWriter* JXL_RESTRICT writer, AuxOut* aux_out); } // namespace jxl diff --git a/lib/jxl/encode.cc b/lib/jxl/encode.cc index 857bb3a09d7..79f666e4302 100644 --- a/lib/jxl/encode.cc +++ b/lib/jxl/encode.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -122,14 +123,14 @@ JxlEncoderOutputProcessorWrapper::GetBuffer(size_t min_size, external_output_processor_->release_buffer( external_output_processor_->opaque, 0); } else { - internal_buffers_.emplace(position_, InternalBuffer()); + internal_buffers_.emplace(position_, InternalBuffer(memory_manager_)); has_buffer_ = true; return JxlOutputProcessorBuffer(user_buffer, size, 0, this); } } } else { if (min_size + additional_size < *avail_out_) { - internal_buffers_.emplace(position_, InternalBuffer()); + internal_buffers_.emplace(position_, InternalBuffer(memory_manager_)); has_buffer_ = true; return JxlOutputProcessorBuffer(*next_out_ + additional_size, *avail_out_ - additional_size, 0, this); @@ -137,7 +138,9 @@ JxlEncoderOutputProcessorWrapper::GetBuffer(size_t min_size, } // Otherwise, we need to allocate our own buffer. - auto it = internal_buffers_.emplace(position_, InternalBuffer()).first; + auto it = + internal_buffers_.emplace(position_, InternalBuffer(memory_manager_)) + .first; InternalBuffer& buffer = it->second; size_t alloc_size = requested_size; it++; @@ -462,6 +465,7 @@ void QueueBox(JxlEncoder* enc, // TODO(lode): share this code and the Brotli compression code in enc_jpeg_data JxlEncoderStatus BrotliCompress(int quality, const uint8_t* in, size_t in_size, jxl::PaddedBytes* out) { + JxlMemoryManager* memory_manager = out->memory_manager(); std::unique_ptr enc(BrotliEncoderCreateInstance(nullptr, nullptr, nullptr), BrotliEncoderDestroyInstance); @@ -471,7 +475,7 @@ JxlEncoderStatus BrotliCompress(int quality, const uint8_t* in, size_t in_size, BrotliEncoderSetParameter(enc.get(), BROTLI_PARAM_SIZE_HINT, in_size); constexpr size_t kBufferSize = 128 * 1024; - jxl::PaddedBytes temp_buffer(kBufferSize); + jxl::PaddedBytes temp_buffer(memory_manager, kBufferSize); size_t avail_in = in_size; const uint8_t* next_in = in; @@ -693,7 +697,7 @@ bool EncodeFrameIndexBox(const jxl::JxlEncoderFrameIndexBox& frame_index_box, } // namespace jxl::Status JxlEncoderStruct::ProcessOneEnqueuedInput() { - jxl::PaddedBytes header_bytes; + jxl::PaddedBytes header_bytes{&memory_manager}; jxl::JxlEncoderQueuedInput& input = input_queue[0]; @@ -729,7 +733,7 @@ jxl::Status JxlEncoderStruct::ProcessOneEnqueuedInput() { } jxl::AuxOut* aux_out = input.frame ? input.frame->option_values.aux_out : nullptr; - jxl::BitWriter writer; + jxl::BitWriter writer{&memory_manager}; if (!WriteCodestreamHeaders(&metadata, &writer, aux_out)) { return JXL_API_ERROR(this, JXL_ENC_ERR_GENERIC, "Failed to write codestream header"); @@ -1023,7 +1027,7 @@ jxl::Status JxlEncoderStruct::ProcessOneEnqueuedInput() { num_queued_boxes--; if (box->compress_box) { - jxl::PaddedBytes compressed(4); + jxl::PaddedBytes compressed(&memory_manager, 4); // Prepend the original box type in the brob box contents for (size_t i = 0; i < 4; i++) { compressed[i] = static_cast(box->type[i]); @@ -2127,7 +2131,8 @@ JxlEncoderStatus JxlEncoderAddJPEGFrame( } jxl::jpeg::JPEGData data_in = *io.Main().jpeg_data; std::vector jpeg_data; - if (!jxl::jpeg::EncodeJPEGData(data_in, &jpeg_data, + if (!jxl::jpeg::EncodeJPEGData(&frame_settings->enc->memory_manager, + data_in, &jpeg_data, frame_settings->values.cparams)) { return JXL_API_ERROR( frame_settings->enc, JXL_ENC_ERR_JBRD, @@ -2534,7 +2539,8 @@ JXL_EXPORT JxlEncoderStatus JxlEncoderSetOutputProcessor( return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, "Missing output processor functions"); } - enc->output_processor = JxlEncoderOutputProcessorWrapper(output_processor); + enc->output_processor = + JxlEncoderOutputProcessorWrapper(&enc->memory_manager, output_processor); return JxlErrorOrStatus::Success(); } diff --git a/lib/jxl/encode_internal.h b/lib/jxl/encode_internal.h index 8030c8314ab..d2154c2d9a8 100644 --- a/lib/jxl/encode_internal.h +++ b/lib/jxl/encode_internal.h @@ -431,9 +431,11 @@ class JxlEncoderOutputProcessorWrapper { friend class JxlOutputProcessorBuffer; public: - JxlEncoderOutputProcessorWrapper() = default; - explicit JxlEncoderOutputProcessorWrapper(JxlEncoderOutputProcessor processor) - : external_output_processor_( + JxlEncoderOutputProcessorWrapper() : memory_manager_(nullptr) {} + JxlEncoderOutputProcessorWrapper(JxlMemoryManager* memory_manager, + JxlEncoderOutputProcessor processor) + : memory_manager_(memory_manager), + external_output_processor_( jxl::make_unique(processor)) {} bool HasAvailOut() const { return avail_out_ != nullptr; } @@ -472,6 +474,8 @@ class JxlEncoderOutputProcessorWrapper { bool AppendBufferToExternalProcessor(void* data, size_t count); struct InternalBuffer { + explicit InternalBuffer(JxlMemoryManager* memory_manager) + : owned_data(memory_manager) {} // Bytes in the range `[output_position_ - start_of_the_buffer, // written_bytes)` need to be flushed out. size_t written_bytes = 0; @@ -496,6 +500,7 @@ class JxlEncoderOutputProcessorWrapper { bool stop_requested_ = false; bool has_buffer_ = false; + JxlMemoryManager* memory_manager_; std::unique_ptr external_output_processor_; }; diff --git a/lib/jxl/entropy_coder.cc b/lib/jxl/entropy_coder.cc index 5dc101b36fc..eb78aa6190a 100644 --- a/lib/jxl/entropy_coder.cc +++ b/lib/jxl/entropy_coder.cc @@ -5,31 +5,24 @@ #include "lib/jxl/entropy_coder.h" -#include -#include +#include -#include -#include +#include #include #include "lib/jxl/ac_context.h" -#include "lib/jxl/ac_strategy.h" -#include "lib/jxl/base/bits.h" -#include "lib/jxl/base/compiler_specific.h" #include "lib/jxl/base/status.h" #include "lib/jxl/coeff_order.h" #include "lib/jxl/coeff_order_fwd.h" -#include "lib/jxl/dec_ans.h" #include "lib/jxl/dec_bit_reader.h" #include "lib/jxl/dec_context_map.h" -#include "lib/jxl/epf.h" -#include "lib/jxl/image.h" -#include "lib/jxl/image_ops.h" +#include "lib/jxl/fields.h" #include "lib/jxl/pack_signed.h" namespace jxl { -Status DecodeBlockCtxMap(BitReader* br, BlockCtxMap* block_ctx_map) { +Status DecodeBlockCtxMap(JxlMemoryManager* memory_manager, BitReader* br, + BlockCtxMap* block_ctx_map) { auto& dct = block_ctx_map->dc_thresholds; auto& qft = block_ctx_map->qf_thresholds; auto& ctx_map = block_ctx_map->ctx_map; @@ -57,7 +50,8 @@ Status DecodeBlockCtxMap(BitReader* br, BlockCtxMap* block_ctx_map) { ctx_map.resize(3 * kNumOrders * block_ctx_map->num_dc_ctxs * (qft.size() + 1)); - JXL_RETURN_IF_ERROR(DecodeContextMap(&ctx_map, &block_ctx_map->num_ctxs, br)); + JXL_RETURN_IF_ERROR( + DecodeContextMap(memory_manager, &ctx_map, &block_ctx_map->num_ctxs, br)); if (block_ctx_map->num_ctxs > 16) { return JXL_FAILURE("Invalid block context map: too many distinct contexts"); } diff --git a/lib/jxl/entropy_coder.h b/lib/jxl/entropy_coder.h index e4afa7a6314..2ce9f53f878 100644 --- a/lib/jxl/entropy_coder.h +++ b/lib/jxl/entropy_coder.h @@ -6,8 +6,10 @@ #ifndef LIB_JXL_ENTROPY_CODER_H_ #define LIB_JXL_ENTROPY_CODER_H_ -#include -#include +#include + +#include +#include #include "lib/jxl/ac_context.h" #include "lib/jxl/base/compiler_specific.h" @@ -38,7 +40,8 @@ static constexpr U32Enc kDCThresholdDist(Bits(4), BitsOffset(8, 16), static constexpr U32Enc kQFThresholdDist(Bits(2), BitsOffset(3, 4), BitsOffset(5, 12), BitsOffset(8, 44)); -Status DecodeBlockCtxMap(BitReader* br, BlockCtxMap* block_ctx_map); +Status DecodeBlockCtxMap(JxlMemoryManager* memory_manager, BitReader* br, + BlockCtxMap* block_ctx_map); } // namespace jxl diff --git a/lib/jxl/fields_test.cc b/lib/jxl/fields_test.cc index 5af68d2d5fe..fd0e23465b7 100644 --- a/lib/jxl/fields_test.cc +++ b/lib/jxl/fields_test.cc @@ -5,6 +5,8 @@ #include "lib/jxl/fields.h" +#include + #include #include #include @@ -20,6 +22,7 @@ #include "lib/jxl/frame_header.h" #include "lib/jxl/headers.h" #include "lib/jxl/image_metadata.h" +#include "lib/jxl/test_utils.h" #include "lib/jxl/testing.h" namespace jxl { @@ -27,9 +30,10 @@ namespace { // Ensures `value` round-trips and in exactly `expected_bits_written`. void TestU32Coder(const uint32_t value, const size_t expected_bits_written) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); const U32Enc enc(Val(0), Bits(4), Val(0x7FFFFFFF), Bits(32)); - BitWriter writer; + BitWriter writer{memory_manager}; BitWriter::Allotment allotment( &writer, RoundUpBitsToByteMultiple(U32Coder::MaxEncodedBits(enc))); @@ -60,7 +64,8 @@ TEST(FieldsTest, U32CoderTest) { } void TestU64Coder(const uint64_t value, const size_t expected_bits_written) { - BitWriter writer; + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); + BitWriter writer{memory_manager}; BitWriter::Allotment allotment( &writer, RoundUpBitsToByteMultiple(U64Coder::MaxEncodedBits())); @@ -160,12 +165,13 @@ TEST(FieldsTest, U64CoderTest) { } Status TestF16Coder(const float value) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); size_t max_encoded_bits; // It is not a fatal error if it can't be encoded. if (!F16Coder::CanEncode(value, &max_encoded_bits)) return false; EXPECT_EQ(F16Coder::MaxEncodedBits(), max_encoded_bits); - BitWriter writer; + BitWriter writer{memory_manager}; BitWriter::Allotment allotment(&writer, RoundUpBitsToByteMultiple(max_encoded_bits)); @@ -199,6 +205,7 @@ TEST(FieldsTest, F16CoderTest) { // Ensures Read(Write()) returns the same fields. TEST(FieldsTest, TestRoundtripSize) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); for (int i = 0; i < 8; i++) { SizeHeader size; ASSERT_TRUE(size.Set(123 + 77 * i, 7 + i)); @@ -208,7 +215,7 @@ TEST(FieldsTest, TestRoundtripSize) { ASSERT_TRUE(Bundle::CanEncode(size, &extension_bits, &total_bits)); EXPECT_EQ(0u, extension_bits); - BitWriter writer; + BitWriter writer{memory_manager}; ASSERT_TRUE(WriteSizeHeader(size, &writer, 0, nullptr)); EXPECT_EQ(total_bits, writer.BitsWritten()); writer.ZeroPadToByte(); @@ -256,6 +263,7 @@ TEST(FieldsTest, TestPreview) { // Ensures Read(Write()) returns the same fields. TEST(FieldsTest, TestRoundtripFrame) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); CodecMetadata metadata; FrameHeader h(&metadata); h.extensions = 0x800; @@ -264,7 +272,7 @@ TEST(FieldsTest, TestRoundtripFrame) { size_t total_bits = 999; // Initialize as garbage. ASSERT_TRUE(Bundle::CanEncode(h, &extension_bits, &total_bits)); EXPECT_EQ(0u, extension_bits); - BitWriter writer; + BitWriter writer{memory_manager}; ASSERT_TRUE(WriteFrameHeader(h, &writer, nullptr)); EXPECT_EQ(total_bits, writer.BitsWritten()); writer.ZeroPadToByte(); @@ -348,6 +356,7 @@ struct NewBundle : public Fields { }; TEST(FieldsTest, TestNewDecoderOldData) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); OldBundle old_bundle; old_bundle.old_large = 123; old_bundle.old_f = 3.75f; @@ -355,7 +364,7 @@ TEST(FieldsTest, TestNewDecoderOldData) { // Write to bit stream const size_t kMaxOutBytes = 999; - BitWriter writer; + BitWriter writer{memory_manager}; // Make sure values are initialized by code under test. size_t extension_bits = 12345; size_t total_bits = 12345; @@ -392,6 +401,7 @@ TEST(FieldsTest, TestNewDecoderOldData) { } TEST(FieldsTest, TestOldDecoderNewData) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); NewBundle new_bundle; new_bundle.old_large = 123; new_bundle.extensions = 3; @@ -400,7 +410,7 @@ TEST(FieldsTest, TestOldDecoderNewData) { // Write to bit stream constexpr size_t kMaxOutBytes = 999; - BitWriter writer; + BitWriter writer{memory_manager}; // Make sure values are initialized by code under test. size_t extension_bits = 12345; size_t total_bits = 12345; diff --git a/lib/jxl/icc_codec.cc b/lib/jxl/icc_codec.cc index 8501c684ac3..51fb5a44e80 100644 --- a/lib/jxl/icc_codec.cc +++ b/lib/jxl/icc_codec.cc @@ -5,13 +5,11 @@ #include "lib/jxl/icc_codec.h" -#include +#include -#include -#include -#include +#include -#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/status.h" #include "lib/jxl/dec_ans.h" #include "lib/jxl/fields.h" #include "lib/jxl/icc_codec_common.h" @@ -29,9 +27,10 @@ namespace { // scanline order but with missing elements skipped (which may occur in multiple // locations), the output is the result matrix in scanline order (with // no need to skip missing elements as they are past the end of the data). -void Shuffle(uint8_t* data, size_t size, size_t width) { +void Shuffle(JxlMemoryManager* memory_manager, uint8_t* data, size_t size, + size_t width) { size_t height = (size + width - 1) / width; // amount of rows of output - PaddedBytes result(size); + PaddedBytes result(memory_manager, size); // i = output index, j input index size_t s = 0; size_t j = 0; @@ -92,6 +91,7 @@ Status CheckPreamble(const PaddedBytes& data, size_t enc_size, // Decodes the result of PredictICC back to a valid ICC profile. Status UnpredictICC(const uint8_t* enc, size_t size, PaddedBytes* result) { if (!result->empty()) return JXL_FAILURE("result must be empty initially"); + JxlMemoryManager* memory_manager = result->memory_manager(); size_t pos = 0; // TODO(lode): technically speaking we need to check that the entire varint // decoding never goes out of bounds, not just the first byte. This requires @@ -111,7 +111,7 @@ Status UnpredictICC(const uint8_t* enc, size_t size, PaddedBytes* result) { pos = commands_end; // pos in data stream // Header - PaddedBytes header; + PaddedBytes header{memory_manager}; header.append(ICCInitialHeaderPrediction()); EncodeUint32(0, osize, &header); for (size_t i = 0; i <= kICCHeaderSize; i++) { @@ -225,14 +225,14 @@ Status UnpredictICC(const uint8_t* enc, size_t size, PaddedBytes* result) { if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); uint64_t num = DecodeVarInt(enc, size, &cpos); JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size)); - PaddedBytes shuffled(num); + PaddedBytes shuffled(memory_manager, num); for (size_t i = 0; i < num; i++) { shuffled[i] = enc[pos + i]; } if (command == kCommandShuffle2) { - Shuffle(shuffled.data(), num, 2); + Shuffle(memory_manager, shuffled.data(), num, 2); } else if (command == kCommandShuffle4) { - Shuffle(shuffled.data(), num, 4); + Shuffle(memory_manager, shuffled.data(), num, 4); } for (size_t i = 0; i < num; i++) { result->push_back(shuffled[i]); @@ -269,11 +269,11 @@ Status UnpredictICC(const uint8_t* enc, size_t size, PaddedBytes* result) { uint64_t num = DecodeVarInt(enc, size, &cpos); // in bytes JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size)); - PaddedBytes shuffled(num); + PaddedBytes shuffled(memory_manager, num); for (size_t i = 0; i < num; i++) { shuffled[i] = enc[pos + i]; } - if (width > 1) Shuffle(shuffled.data(), num, width); + if (width > 1) Shuffle(memory_manager, shuffled.data(), num, width); size_t start = result->size(); for (size_t i = 0; i < num; i++) { @@ -308,6 +308,7 @@ Status UnpredictICC(const uint8_t* enc, size_t size, PaddedBytes* result) { Status ICCReader::Init(BitReader* reader, size_t output_limit) { JXL_RETURN_IF_ERROR(CheckEOI(reader)); + JxlMemoryManager* memory_manager = decompressed_.memory_manager(); used_bits_base_ = reader->TotalBitsConsumed(); if (bits_to_skip_ == 0) { enc_size_ = U64Coder::Read(reader); @@ -315,9 +316,9 @@ Status ICCReader::Init(BitReader* reader, size_t output_limit) { // Avoid too large memory allocation for invalid file. return JXL_FAILURE("Too large encoded profile"); } - JXL_RETURN_IF_ERROR( - DecodeHistograms(reader, kNumICCContexts, &code_, &context_map_)); - ans_reader_ = ANSSymbolReader(&code_, reader); + JXL_RETURN_IF_ERROR(DecodeHistograms( + memory_manager, reader, kNumICCContexts, &code_, &context_map_)); + JXL_ASSIGN_OR_RETURN(ans_reader_, ANSSymbolReader::Create(&code_, reader)); i_ = 0; decompressed_.resize(std::min(i_ + 0x400, enc_size_)); for (; i_ < std::min(2, enc_size_); i_++) { diff --git a/lib/jxl/icc_codec.h b/lib/jxl/icc_codec.h index e57018b4c35..d531c0a736e 100644 --- a/lib/jxl/icc_codec.h +++ b/lib/jxl/icc_codec.h @@ -8,6 +8,8 @@ // Compressed representation of ICC profiles. +#include + #include #include #include @@ -20,6 +22,9 @@ namespace jxl { struct ICCReader { + explicit ICCReader(JxlMemoryManager* memory_manager) + : decompressed_(memory_manager) {} + Status Init(BitReader* reader, size_t output_limit); Status Process(BitReader* reader, PaddedBytes* icc); void Reset() { diff --git a/lib/jxl/icc_codec_common.cc b/lib/jxl/icc_codec_common.cc index 1cb45426874..4bea0716ff7 100644 --- a/lib/jxl/icc_codec_common.cc +++ b/lib/jxl/icc_codec_common.cc @@ -5,14 +5,9 @@ #include "lib/jxl/icc_codec_common.h" -#include - -#include -#include -#include +#include #include "lib/jxl/base/byte_order.h" -#include "lib/jxl/fields.h" #include "lib/jxl/padded_bytes.h" namespace jxl { diff --git a/lib/jxl/icc_codec_test.cc b/lib/jxl/icc_codec_test.cc index 175b4768a08..b563bda2b39 100644 --- a/lib/jxl/icc_codec_test.cc +++ b/lib/jxl/icc_codec_test.cc @@ -3,7 +3,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -#include "lib/jxl/icc_codec.h" +#include #include #include @@ -19,7 +19,8 @@ namespace jxl { namespace { void TestProfile(const IccBytes& icc) { - BitWriter writer; + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); + BitWriter writer{memory_manager}; ASSERT_TRUE(WriteICC(icc, &writer, 0, nullptr)); writer.ZeroPadToByte(); std::vector dec; diff --git a/lib/jxl/image_bundle_test.cc b/lib/jxl/image_bundle_test.cc index 1a10598fe2d..cf94507ac59 100644 --- a/lib/jxl/image_bundle_test.cc +++ b/lib/jxl/image_bundle_test.cc @@ -3,18 +3,20 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -#include "lib/jxl/image_bundle.h" +#include #include "lib/jxl/enc_aux_out.h" #include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/test_utils.h" #include "lib/jxl/testing.h" namespace jxl { namespace { TEST(ImageBundleTest, ExtraChannelName) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); AuxOut aux_out; - BitWriter writer; + BitWriter writer{memory_manager}; BitWriter::Allotment allotment(&writer, 99); ImageMetadata metadata; diff --git a/lib/jxl/jpeg/enc_jpeg_data.cc b/lib/jxl/jpeg/enc_jpeg_data.cc index 92b9eea1ca8..01b6976a719 100644 --- a/lib/jxl/jpeg/enc_jpeg_data.cc +++ b/lib/jxl/jpeg/enc_jpeg_data.cc @@ -303,7 +303,8 @@ Status SetColorTransformFromJpegData(const JPEGData& jpg, return true; } -Status EncodeJPEGData(JPEGData& jpeg_data, std::vector* bytes, +Status EncodeJPEGData(JxlMemoryManager* memory_manager, JPEGData& jpeg_data, + std::vector* bytes, const CompressParams& cparams) { bytes->clear(); jpeg_data.app_marker_type.resize(jpeg_data.app_data.size(), @@ -327,7 +328,7 @@ Status EncodeJPEGData(JPEGData& jpeg_data, std::vector* bytes, total_data += jpeg_data.tail_data.size(); size_t brotli_capacity = BrotliEncoderMaxCompressedSize(total_data); - BitWriter writer; + BitWriter writer{memory_manager}; JXL_RETURN_IF_ERROR(Bundle::Write(jpeg_data, &writer, 0, nullptr)); writer.ZeroPadToByte(); { diff --git a/lib/jxl/jpeg/enc_jpeg_data.h b/lib/jxl/jpeg/enc_jpeg_data.h index f9a3a95e236..3313c1eea19 100644 --- a/lib/jxl/jpeg/enc_jpeg_data.h +++ b/lib/jxl/jpeg/enc_jpeg_data.h @@ -6,6 +6,8 @@ #ifndef LIB_JXL_JPEG_ENC_JPEG_DATA_H_ #define LIB_JXL_JPEG_ENC_JPEG_DATA_H_ +#include + #include #include @@ -20,7 +22,8 @@ namespace jxl { class CodecInOut; namespace jpeg { -Status EncodeJPEGData(JPEGData& jpeg_data, std::vector* bytes, +Status EncodeJPEGData(JxlMemoryManager* memory_manager, JPEGData& jpeg_data, + std::vector* bytes, const CompressParams& cparams); void SetColorEncodingFromJpegData(const jpeg::JPEGData& jpg, diff --git a/lib/jxl/modular/encoding/dec_ma.cc b/lib/jxl/modular/encoding/dec_ma.cc index b53b9a91032..08b3dcac77d 100644 --- a/lib/jxl/modular/encoding/dec_ma.cc +++ b/lib/jxl/modular/encoding/dec_ma.cc @@ -5,9 +5,12 @@ #include "lib/jxl/modular/encoding/dec_ma.h" +#include + #include #include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" #include "lib/jxl/dec_ans.h" #include "lib/jxl/modular/encoding/ma_common.h" #include "lib/jxl/modular/modular_image.h" @@ -89,16 +92,18 @@ Status DecodeTree(BitReader *br, ANSSymbolReader *reader, } } // namespace -Status DecodeTree(BitReader *br, Tree *tree, size_t tree_size_limit) { +Status DecodeTree(JxlMemoryManager *memory_manager, BitReader *br, Tree *tree, + size_t tree_size_limit) { std::vector tree_context_map; ANSCode tree_code; - JXL_RETURN_IF_ERROR( - DecodeHistograms(br, kNumTreeContexts, &tree_code, &tree_context_map)); + JXL_RETURN_IF_ERROR(DecodeHistograms(memory_manager, br, kNumTreeContexts, + &tree_code, &tree_context_map)); // TODO(eustas): investigate more infinite tree cases. if (tree_code.degenerate_symbols[tree_context_map[kPropertyContext]] > 0) { return JXL_FAILURE("Infinite tree"); } - ANSSymbolReader reader(&tree_code, br); + JXL_ASSIGN_OR_RETURN(ANSSymbolReader reader, + ANSSymbolReader::Create(&tree_code, br)); JXL_RETURN_IF_ERROR(DecodeTree(br, &reader, tree_context_map, tree, std::min(tree_size_limit, kMaxTreeSize))); if (!reader.CheckANSFinalState()) { diff --git a/lib/jxl/modular/encoding/dec_ma.h b/lib/jxl/modular/encoding/dec_ma.h index a910c4deb1f..1bce21b32b8 100644 --- a/lib/jxl/modular/encoding/dec_ma.h +++ b/lib/jxl/modular/encoding/dec_ma.h @@ -6,9 +6,10 @@ #ifndef LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ #define LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ -#include -#include +#include +#include +#include #include #include "lib/jxl/base/status.h" @@ -59,7 +60,8 @@ struct PropertyDecisionNode { using Tree = std::vector; -Status DecodeTree(BitReader *br, Tree *tree, size_t tree_size_limit); +Status DecodeTree(JxlMemoryManager *memory_manager, BitReader *br, Tree *tree, + size_t tree_size_limit); } // namespace jxl diff --git a/lib/jxl/modular/encoding/enc_encoding.cc b/lib/jxl/modular/encoding/enc_encoding.cc index ae8ab4fe07b..86b501849ef 100644 --- a/lib/jxl/modular/encoding/enc_encoding.cc +++ b/lib/jxl/modular/encoding/enc_encoding.cc @@ -535,6 +535,7 @@ Status ModularEncode(const Image &image, const ModularOptions &options, GroupHeader *header, std::vector *tokens, size_t *width) { if (image.error) return JXL_FAILURE("Invalid image"); + JxlMemoryManager *memory_manager = image.memory_manager(); size_t nb_channels = image.channel.size(); JXL_DEBUG_V( 2, "Encoding %" PRIuS "-channel, %i-bit, %" PRIuS "x%" PRIuS " image.", @@ -641,9 +642,9 @@ Status ModularEncode(const Image &image, const ModularOptions &options, } */ // Write tree - BuildAndEncodeHistograms(options.histogram_params, kNumTreeContexts, - tree_tokens, &code, &context_map, writer, - kLayerModularTree, aux_out); + BuildAndEncodeHistograms(memory_manager, options.histogram_params, + kNumTreeContexts, tree_tokens, &code, &context_map, + writer, kLayerModularTree, aux_out); WriteTokens(tree_tokens[0], code, context_map, 0, writer, kLayerModularTree, aux_out); } @@ -690,9 +691,9 @@ Status ModularEncode(const Image &image, const ModularOptions &options, std::vector context_map; HistogramParams histo_params = options.histogram_params; histo_params.image_widths.push_back(image_width); - BuildAndEncodeHistograms(histo_params, (tree->size() + 1) / 2, - tokens_storage, &code, &context_map, writer, layer, - aux_out); + BuildAndEncodeHistograms(memory_manager, histo_params, + (tree->size() + 1) / 2, tokens_storage, &code, + &context_map, writer, layer, aux_out); WriteTokens(tokens_storage[0], code, context_map, 0, writer, layer, aux_out); } else { diff --git a/lib/jxl/modular/encoding/encoding.cc b/lib/jxl/modular/encoding/encoding.cc index ce6001914d2..5bd56b11892 100644 --- a/lib/jxl/modular/encoding/encoding.cc +++ b/lib/jxl/modular/encoding/encoding.cc @@ -13,6 +13,7 @@ #include "lib/jxl/base/printf_macros.h" #include "lib/jxl/base/scope_guard.h" +#include "lib/jxl/base/status.h" #include "lib/jxl/dec_ans.h" #include "lib/jxl/dec_bit_reader.h" #include "lib/jxl/frame_dimensions.h" @@ -533,6 +534,7 @@ Status ModularDecode(BitReader *br, Image &image, GroupHeader &header, const std::vector *global_ctx_map, const bool allow_truncated_group) { if (image.channel.empty()) return true; + JxlMemoryManager *memory_manager = image.memory_manager(); // decode transforms Status status = Bundle::Read(br, &header); @@ -607,8 +609,10 @@ Status ModularDecode(BitReader *br, Image &image, GroupHeader &header, max_tree_size += pixels; } max_tree_size = std::min(static_cast(1 << 20), max_tree_size); - JXL_RETURN_IF_ERROR(DecodeTree(br, &tree_storage, max_tree_size)); - JXL_RETURN_IF_ERROR(DecodeHistograms(br, (tree_storage.size() + 1) / 2, + JXL_RETURN_IF_ERROR( + DecodeTree(memory_manager, br, &tree_storage, max_tree_size)); + JXL_RETURN_IF_ERROR(DecodeHistograms(memory_manager, br, + (tree_storage.size() + 1) / 2, &code_storage, &context_map_storage)); } else { if (!global_tree || !global_code || !global_ctx_map || @@ -621,7 +625,8 @@ Status ModularDecode(BitReader *br, Image &image, GroupHeader &header, } // Read channels - ANSSymbolReader reader(code, br, distance_multiplier); + JXL_ASSIGN_OR_RETURN(ANSSymbolReader reader, + ANSSymbolReader::Create(code, br, distance_multiplier)); auto tree_lut = jxl::make_unique>(); for (; next_channel < nb_channels; next_channel++) { Channel &channel = image.channel[next_channel]; diff --git a/lib/jxl/modular_test.cc b/lib/jxl/modular_test.cc index 27473ac57b6..cdb9c30ac90 100644 --- a/lib/jxl/modular_test.cc +++ b/lib/jxl/modular_test.cc @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -21,6 +22,7 @@ #include "lib/extras/enc/jxl.h" #include "lib/extras/metrics.h" #include "lib/extras/packed_image.h" +#include "lib/jxl/base/common.h" #include "lib/jxl/base/compiler_specific.h" #include "lib/jxl/base/data_parallel.h" #include "lib/jxl/base/random.h" @@ -234,7 +236,7 @@ TEST(ModularTest, RoundtripExtraProperties) { } } ZeroFillImage(&image.channel[1].plane); - BitWriter writer; + BitWriter writer{memory_manager}; ASSERT_TRUE(ModularGenericCompress(image, options, &writer)); writer.ZeroPadToByte(); JXL_ASSIGN_OR_DIE(Image decoded, @@ -458,24 +460,26 @@ void WriteHistograms(BitWriter* writer) { } TEST(ModularTest, PredictorIntegerOverflow) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); const size_t xsize = 1; const size_t ysize = 1; - BitWriter writer; + BitWriter writer{memory_manager}; WriteHeaders(&writer, xsize, ysize); - std::vector group_codes(1); + std::vector> group_codes; + group_codes.emplace_back(jxl::make_unique(memory_manager)); { - BitWriter* bw = group_codes.data(); - BitWriter::Allotment allotment(bw, 1 << 20); - WriteHistograms(bw); + std::unique_ptr& bw = group_codes[0]; + BitWriter::Allotment allotment(bw.get(), 1 << 20); + WriteHistograms(bw.get()); GroupHeader header; header.use_global_tree = true; - EXPECT_TRUE(Bundle::Write(header, bw, 0, nullptr)); + EXPECT_TRUE(Bundle::Write(header, bw.get(), 0, nullptr)); // After UnpackSigned this becomes (1 << 31) - 1, the largest pixel_type, // and after adding the offset we get -(1 << 31). bw->Write(8, 119); bw->Write(28, 0xfffffff); bw->ZeroPadToByte(); - allotment.ReclaimAndCharge(bw, 0, nullptr); + allotment.ReclaimAndCharge(bw.get(), 0, nullptr); } EXPECT_TRUE(WriteGroupOffsets(group_codes, {}, &writer, nullptr)); writer.AppendByteAligned(group_codes); @@ -493,16 +497,18 @@ TEST(ModularTest, PredictorIntegerOverflow) { } TEST(ModularTest, UnsqueezeIntegerOverflow) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); // Image width is 9 so we can test both the SIMD and non-vector code paths. const size_t xsize = 9; const size_t ysize = 2; - BitWriter writer; + BitWriter writer{memory_manager}; WriteHeaders(&writer, xsize, ysize); - std::vector group_codes(1); + std::vector> group_codes; + group_codes.emplace_back(jxl::make_unique(memory_manager)); { - BitWriter* bw = group_codes.data(); - BitWriter::Allotment allotment(bw, 1 << 20); - WriteHistograms(bw); + std::unique_ptr& bw = group_codes[0]; + BitWriter::Allotment allotment(bw.get(), 1 << 20); + WriteHistograms(bw.get()); GroupHeader header; header.use_global_tree = true; header.transforms.emplace_back(); @@ -513,7 +519,7 @@ TEST(ModularTest, UnsqueezeIntegerOverflow) { params.begin_c = 0; params.num_c = 1; header.transforms[0].squeezes.emplace_back(params); - EXPECT_TRUE(Bundle::Write(header, bw, 0, nullptr)); + EXPECT_TRUE(Bundle::Write(header, bw.get(), 0, nullptr)); for (size_t i = 0; i < xsize * ysize; ++i) { // After UnpackSigned and adding offset, this becomes (1 << 31) - 1, both // in the image and in the residual channels, and unsqueeze makes them @@ -523,7 +529,7 @@ TEST(ModularTest, UnsqueezeIntegerOverflow) { bw->Write(28, 0xffffffe); } bw->ZeroPadToByte(); - allotment.ReclaimAndCharge(bw, 0, nullptr); + allotment.ReclaimAndCharge(bw.get(), 0, nullptr); } EXPECT_TRUE(WriteGroupOffsets(group_codes, {}, &writer, nullptr)); writer.AppendByteAligned(group_codes); diff --git a/lib/jxl/padded_bytes.h b/lib/jxl/padded_bytes.h index 38167ed4080..321fb6cd9a7 100644 --- a/lib/jxl/padded_bytes.h +++ b/lib/jxl/padded_bytes.h @@ -8,11 +8,12 @@ // std::vector replacement with padding to reduce bounds checks in WriteBits -#include -#include -#include // memcpy +#include #include // max +#include +#include +#include // memcpy #include #include // swap @@ -29,25 +30,31 @@ namespace jxl { class PaddedBytes { public: // Required for output params. - PaddedBytes() : size_(0), capacity_(0) {} + explicit PaddedBytes(JxlMemoryManager* memory_manager) + : memory_manager_(memory_manager), size_(0), capacity_(0) {} - explicit PaddedBytes(size_t size) : size_(size), capacity_(0) { + PaddedBytes(JxlMemoryManager* memory_manager, size_t size) + : memory_manager_(memory_manager), size_(size), capacity_(0) { reserve(size); } - PaddedBytes(size_t size, uint8_t value) : size_(size), capacity_(0) { + PaddedBytes(JxlMemoryManager* memory_manager, size_t size, uint8_t value) + : memory_manager_(memory_manager), size_(size), capacity_(0) { reserve(size); if (size_ != 0) { memset(data(), value, size); } } - PaddedBytes(const PaddedBytes& other) : size_(other.size_), capacity_(0) { + PaddedBytes(const PaddedBytes& other) + : memory_manager_(other.memory_manager_), + size_(other.size_), + capacity_(0) { reserve(size_); if (data() != nullptr) memcpy(data(), other.data(), size_); } PaddedBytes& operator=(const PaddedBytes& other) { - // Self-assignment is safe. + if (this == &other) return *this; resize(other.size()); if (data() != nullptr) memmove(data(), other.data(), size_); return *this; @@ -55,12 +62,14 @@ class PaddedBytes { // default is not OK - need to set other.size_ to 0! PaddedBytes(PaddedBytes&& other) noexcept - : size_(other.size_), + : memory_manager_(other.memory_manager_), + size_(other.size_), capacity_(other.capacity_), data_(std::move(other.data_)) { other.size_ = other.capacity_ = 0; } PaddedBytes& operator=(PaddedBytes&& other) noexcept { + memory_manager_ = other.memory_manager_; size_ = other.size_; capacity_ = other.capacity_; data_ = std::move(other.data_); @@ -71,7 +80,10 @@ class PaddedBytes { return *this; } + JxlMemoryManager* memory_manager() const { return memory_manager_; } + void swap(PaddedBytes& other) noexcept { + std::swap(memory_manager_, other.memory_manager_); std::swap(size_, other.size_); std::swap(capacity_, other.capacity_); std::swap(data_, other.data_); @@ -198,6 +210,7 @@ class PaddedBytes { JXL_ASSERT(i <= size()); } + JxlMemoryManager* memory_manager_; size_t size_; size_t capacity_; CacheAlignedUniquePtr data_; diff --git a/lib/jxl/padded_bytes_test.cc b/lib/jxl/padded_bytes_test.cc index 83d1da9c254..0791964d9f1 100644 --- a/lib/jxl/padded_bytes_test.cc +++ b/lib/jxl/padded_bytes_test.cc @@ -5,16 +5,17 @@ #include "lib/jxl/padded_bytes.h" -#include // iota -#include +#include +#include "lib/jxl/test_utils.h" #include "lib/jxl/testing.h" namespace jxl { namespace { TEST(PaddedBytesTest, TestNonEmptyFirstByteZero) { - PaddedBytes pb(1); + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); + PaddedBytes pb(memory_manager, 1); EXPECT_EQ(0, pb[0]); // Even after resizing.. pb.resize(20); @@ -25,14 +26,16 @@ TEST(PaddedBytesTest, TestNonEmptyFirstByteZero) { } TEST(PaddedBytesTest, TestEmptyFirstByteZero) { - PaddedBytes pb(0); + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); + PaddedBytes pb(memory_manager, 0); // After resizing - new zero is written despite there being nothing to copy. pb.resize(20); EXPECT_EQ(0, pb[0]); } TEST(PaddedBytesTest, TestFillWithoutReserve) { - PaddedBytes pb; + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); + PaddedBytes pb{memory_manager}; for (size_t i = 0; i < 170u; ++i) { pb.push_back(i); } @@ -41,7 +44,8 @@ TEST(PaddedBytesTest, TestFillWithoutReserve) { } TEST(PaddedBytesTest, TestFillWithExactReserve) { - PaddedBytes pb; + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); + PaddedBytes pb{memory_manager}; pb.reserve(170); for (size_t i = 0; i < 170u; ++i) { pb.push_back(i); @@ -51,7 +55,8 @@ TEST(PaddedBytesTest, TestFillWithExactReserve) { } TEST(PaddedBytesTest, TestFillWithMoreReserve) { - PaddedBytes pb; + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); + PaddedBytes pb{memory_manager}; pb.reserve(171); for (size_t i = 0; i < 170u; ++i) { pb.push_back(i); diff --git a/lib/jxl/quant_weights_test.cc b/lib/jxl/quant_weights_test.cc index 086eacaf98c..28ebc14e566 100644 --- a/lib/jxl/quant_weights_test.cc +++ b/lib/jxl/quant_weights_test.cc @@ -4,6 +4,8 @@ // license that can be found in the LICENSE file. #include "lib/jxl/quant_weights.h" +#include + #include #include #include @@ -56,9 +58,10 @@ void CheckSimilar(float a, float b) { } TEST(QuantWeightsTest, DC) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); DequantMatrices mat; float dc_quant[3] = {1e+5, 1e+3, 1e+1}; - DequantMatricesSetCustomDC(&mat, dc_quant); + DequantMatricesSetCustomDC(memory_manager, &mat, dc_quant); for (size_t c = 0; c < 3; c++) { CheckSimilar(mat.InvDCQuant(c), dc_quant[c]); } @@ -182,6 +185,7 @@ class QuantWeightsTargetTest : public hwy::TestWithParamTarget {}; HWY_TARGET_INSTANTIATE_TEST_SUITE_P(QuantWeightsTargetTest); TEST_P(QuantWeightsTargetTest, DCTUniform) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); constexpr float kUniformQuant = 4; float weights[3][2] = {{1.0f / kUniformQuant, 0}, {1.0f / kUniformQuant, 0}, @@ -199,7 +203,7 @@ TEST_P(QuantWeightsTargetTest, DCTUniform) { const float dc_quant[3] = {1.0f / kUniformQuant, 1.0f / kUniformQuant, 1.0f / kUniformQuant}; - DequantMatricesSetCustomDC(&dequant_matrices, dc_quant); + DequantMatricesSetCustomDC(memory_manager, &dequant_matrices, dc_quant); HWY_ALIGN_MAX float scratch_space[16 * 16 * 5]; diff --git a/lib/jxl/quantizer_test.cc b/lib/jxl/quantizer_test.cc index 0a2e3c9fc6d..9057da7a29e 100644 --- a/lib/jxl/quantizer_test.cc +++ b/lib/jxl/quantizer_test.cc @@ -42,6 +42,7 @@ TEST(QuantizerTest, QuantizerParams) { } TEST(QuantizerTest, BitStreamRoundtripSameQuant) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); const int qxsize = 8; const int qysize = 8; DequantMatrices dequant; @@ -49,7 +50,7 @@ TEST(QuantizerTest, BitStreamRoundtripSameQuant) { JXL_ASSIGN_OR_DIE(ImageI raw_quant_field, ImageI::Create(jxl::test::MemoryManager(), qxsize, qysize)); quantizer1.SetQuant(0.17f, 0.17f, &raw_quant_field); - BitWriter writer; + BitWriter writer{memory_manager}; QuantizerParams params = quantizer1.GetParams(); EXPECT_TRUE(WriteQuantizerParams(params, &writer, 0, nullptr)); writer.ZeroPadToByte(); @@ -76,7 +77,7 @@ TEST(QuantizerTest, BitStreamRoundtripRandomQuant) { JXL_ASSIGN_OR_DIE(ImageF qf, ImageF::Create(memory_manager, qxsize, qysize)); RandomFillImage(&qf, 0.0f, 1.0f); quantizer1.SetQuantField(quant_dc, qf, &raw_quant_field); - BitWriter writer; + BitWriter writer{memory_manager}; QuantizerParams params = quantizer1.GetParams(); EXPECT_TRUE(WriteQuantizerParams(params, &writer, 0, nullptr)); writer.ZeroPadToByte(); diff --git a/lib/jxl/render_pipeline/stage_write.cc b/lib/jxl/render_pipeline/stage_write.cc index be2ed021150..3687dab70fa 100644 --- a/lib/jxl/render_pipeline/stage_write.cc +++ b/lib/jxl/render_pipeline/stage_write.cc @@ -7,6 +7,7 @@ #include +#include #include #include "lib/jxl/alpha.h" @@ -116,7 +117,8 @@ class WriteToOutputStage : public RenderPipelineStage { WriteToOutputStage(const ImageOutput& main_output, size_t width, size_t height, bool has_alpha, bool unpremul_alpha, size_t alpha_c, Orientation undo_orientation, - const std::vector& extra_output) + const std::vector& extra_output, + JxlMemoryManager* memory_manager) : RenderPipelineStage(RenderPipelineStage::Settings()), width_(width), height_(height), @@ -539,10 +541,10 @@ constexpr size_t WriteToOutputStage::kMaxPixelsPerCall; std::unique_ptr GetWriteToOutputStage( const ImageOutput& main_output, size_t width, size_t height, bool has_alpha, bool unpremul_alpha, size_t alpha_c, Orientation undo_orientation, - std::vector& extra_output) { + std::vector& extra_output, JxlMemoryManager* memory_manager) { return jxl::make_unique( main_output, width, height, has_alpha, unpremul_alpha, alpha_c, - undo_orientation, extra_output); + undo_orientation, extra_output, memory_manager); } // NOLINTNEXTLINE(google-readability-namespace-comments) @@ -681,10 +683,10 @@ std::unique_ptr GetWriteToImage3FStage( std::unique_ptr GetWriteToOutputStage( const ImageOutput& main_output, size_t width, size_t height, bool has_alpha, bool unpremul_alpha, size_t alpha_c, Orientation undo_orientation, - std::vector& extra_output) { + std::vector& extra_output, JxlMemoryManager* memory_manager) { return HWY_DYNAMIC_DISPATCH(GetWriteToOutputStage)( main_output, width, height, has_alpha, unpremul_alpha, alpha_c, - undo_orientation, extra_output); + undo_orientation, extra_output, memory_manager); } } // namespace jxl diff --git a/lib/jxl/render_pipeline/stage_write.h b/lib/jxl/render_pipeline/stage_write.h index 4b1db442cc0..9a7f2dbb104 100644 --- a/lib/jxl/render_pipeline/stage_write.h +++ b/lib/jxl/render_pipeline/stage_write.h @@ -32,7 +32,7 @@ std::unique_ptr GetWriteToImage3FStage( std::unique_ptr GetWriteToOutputStage( const ImageOutput& main_output, size_t width, size_t height, bool has_alpha, bool unpremul_alpha, size_t alpha_c, Orientation undo_orientation, - std::vector& extra_output); + std::vector& extra_output, JxlMemoryManager* memory_manager); } // namespace jxl diff --git a/lib/jxl/splines.cc b/lib/jxl/splines.cc index fd68c15493c..538859565b7 100644 --- a/lib/jxl/splines.cc +++ b/lib/jxl/splines.cc @@ -5,6 +5,8 @@ #include "lib/jxl/splines.h" +#include + #include #include // PRIu64 #include @@ -567,12 +569,14 @@ void Splines::Clear() { segment_y_start_.clear(); } -Status Splines::Decode(jxl::BitReader* br, const size_t num_pixels) { +Status Splines::Decode(JxlMemoryManager* memory_manager, jxl::BitReader* br, + const size_t num_pixels) { std::vector context_map; ANSCode code; - JXL_RETURN_IF_ERROR( - DecodeHistograms(br, kNumSplineContexts, &code, &context_map)); - ANSSymbolReader decoder(&code, br); + JXL_RETURN_IF_ERROR(DecodeHistograms(memory_manager, br, kNumSplineContexts, + &code, &context_map)); + JXL_ASSIGN_OR_RETURN(ANSSymbolReader decoder, + ANSSymbolReader::Create(&code, br)); size_t num_splines = decoder.ReadHybridUint(kNumSplinesContext, br, context_map); size_t max_control_points = std::min( diff --git a/lib/jxl/splines.h b/lib/jxl/splines.h index b292d6952b5..90dc2260e54 100644 --- a/lib/jxl/splines.h +++ b/lib/jxl/splines.h @@ -6,6 +6,8 @@ #ifndef LIB_JXL_SPLINES_H_ #define LIB_JXL_SPLINES_H_ +#include + #include #include #include @@ -108,7 +110,8 @@ class Splines { void Clear(); - Status Decode(BitReader* br, size_t num_pixels); + Status Decode(JxlMemoryManager* memory_manager, BitReader* br, + size_t num_pixels); void AddTo(Image3F* opsin, const Rect& opsin_rect, const Rect& image_rect) const; diff --git a/lib/jxl/splines_test.cc b/lib/jxl/splines_test.cc index 8e78b13acdf..843005f0df5 100644 --- a/lib/jxl/splines_test.cc +++ b/lib/jxl/splines_test.cc @@ -73,6 +73,7 @@ std::vector DequantizeSplines(const Splines& splines) { } // namespace TEST(SplinesTest, Serialization) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); Spline spline1{ /*control_points=*/{ {109, 54}, {218, 159}, {80, 3}, {110, 274}, {94, 185}, {17, 277}}, @@ -166,7 +167,7 @@ TEST(SplinesTest, Serialization) { } } - BitWriter writer; + BitWriter writer{memory_manager}; EncodeSplines(splines, &writer, kLayerSplines, HistogramParams(), nullptr); writer.ZeroPadToByte(); const size_t bits_written = writer.BitsWritten(); @@ -175,7 +176,8 @@ TEST(SplinesTest, Serialization) { BitReader reader(writer.GetSpan()); Splines decoded_splines; - ASSERT_TRUE(decoded_splines.Decode(&reader, /*num_pixels=*/1000)); + ASSERT_TRUE( + decoded_splines.Decode(memory_manager, &reader, /*num_pixels=*/1000)); ASSERT_TRUE(reader.JumpToByteBoundary()); EXPECT_EQ(reader.TotalBitsConsumed(), bits_written); ASSERT_TRUE(reader.Close()); @@ -211,6 +213,7 @@ TEST(SplinesTest, DISABLED_TooManySplinesTest) { #else TEST(SplinesTest, TooManySplinesTest) { #endif + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); // This is more than the limit for 1000 pixels. const size_t kNumSplines = 300; @@ -229,14 +232,15 @@ TEST(SplinesTest, TooManySplinesTest) { Splines splines(kQuantizationAdjustment, std::move(quantized_splines), std::move(starting_points)); - BitWriter writer; + BitWriter writer{memory_manager}; EncodeSplines(splines, &writer, kLayerSplines, HistogramParams(SpeedTier::kFalcon, 1), nullptr); writer.ZeroPadToByte(); // Re-read splines. BitReader reader(writer.GetSpan()); Splines decoded_splines; - EXPECT_FALSE(decoded_splines.Decode(&reader, /*num_pixels=*/1000)); + EXPECT_FALSE( + decoded_splines.Decode(memory_manager, &reader, /*num_pixels=*/1000)); EXPECT_TRUE(reader.Close()); } diff --git a/lib/jxl/test_utils.cc b/lib/jxl/test_utils.cc index 51aa653c3c9..66124eae8c0 100644 --- a/lib/jxl/test_utils.cc +++ b/lib/jxl/test_utils.cc @@ -721,9 +721,10 @@ bool SamePixels(const extras::PackedPixelFile& a, Status ReadICC(BitReader* JXL_RESTRICT reader, std::vector* JXL_RESTRICT icc, size_t output_limit) { + JxlMemoryManager* memort_manager = jxl::test::MemoryManager(); icc->clear(); - ICCReader icc_reader; - PaddedBytes icc_buffer; + ICCReader icc_reader{memort_manager}; + PaddedBytes icc_buffer{memort_manager}; JXL_RETURN_IF_ERROR(icc_reader.Init(reader, output_limit)); JXL_RETURN_IF_ERROR(icc_reader.Process(reader, &icc_buffer)); Bytes(icc_buffer).AppendTo(*icc); @@ -759,7 +760,7 @@ Status EncodePreview(const CompressParams& cparams, const ImageBundle& ib, const CodecMetadata* metadata, const JxlCmsInterface& cms, ThreadPool* pool, BitWriter* JXL_RESTRICT writer) { JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); - BitWriter preview_writer; + BitWriter preview_writer{memory_manager}; // TODO(janwas): also support generating preview by downsampling if (ib.HasColor()) { AuxOut aux_out; @@ -790,7 +791,7 @@ Status EncodeFile(const CompressParams& params, const CodecInOut* io, compressed->clear(); const JxlCmsInterface& cms = *JxlGetDefaultCms(); io->CheckMetadata(); - BitWriter writer; + BitWriter writer{memory_manager}; CompressParams cparams = params; if (io->Main().color_transform != ColorTransform::kNone) { diff --git a/lib/jxl/toc.cc b/lib/jxl/toc.cc index 72c8ac01cd7..feeab99cfb7 100644 --- a/lib/jxl/toc.cc +++ b/lib/jxl/toc.cc @@ -5,6 +5,7 @@ #include "lib/jxl/toc.h" +#include #include #include "lib/jxl/base/common.h" @@ -19,7 +20,8 @@ size_t MaxBits(const size_t num_sizes) { return 1 + kBitsPerByte + entry_bits + kBitsPerByte; } -Status ReadToc(size_t toc_entries, BitReader* JXL_RESTRICT reader, +Status ReadToc(JxlMemoryManager* memory_manager, size_t toc_entries, + BitReader* JXL_RESTRICT reader, std::vector* JXL_RESTRICT sizes, std::vector* JXL_RESTRICT permutation) { if (toc_entries > 65536) { @@ -51,8 +53,8 @@ Status ReadToc(size_t toc_entries, BitReader* JXL_RESTRICT reader, if (reader->ReadFixedBits<1>() == 1) { JXL_RETURN_IF_ERROR(check_bit_budget(toc_entries)); permutation->resize(toc_entries); - JXL_RETURN_IF_ERROR(DecodePermutation(/*skip=*/0, toc_entries, - permutation->data(), reader)); + JXL_RETURN_IF_ERROR(DecodePermutation( + memory_manager, /*skip=*/0, toc_entries, permutation->data(), reader)); } JXL_RETURN_IF_ERROR(reader->JumpToByteBoundary()); JXL_RETURN_IF_ERROR(check_bit_budget(toc_entries)); @@ -64,12 +66,14 @@ Status ReadToc(size_t toc_entries, BitReader* JXL_RESTRICT reader, return true; } -Status ReadGroupOffsets(size_t toc_entries, BitReader* JXL_RESTRICT reader, +Status ReadGroupOffsets(JxlMemoryManager* memory_manager, size_t toc_entries, + BitReader* JXL_RESTRICT reader, std::vector* JXL_RESTRICT offsets, std::vector* JXL_RESTRICT sizes, uint64_t* total_size) { std::vector permutation; - JXL_RETURN_IF_ERROR(ReadToc(toc_entries, reader, sizes, &permutation)); + JXL_RETURN_IF_ERROR( + ReadToc(memory_manager, toc_entries, reader, sizes, &permutation)); offsets->clear(); offsets->resize(toc_entries); diff --git a/lib/jxl/toc.h b/lib/jxl/toc.h index f5b9c65763d..bbc0549e67a 100644 --- a/lib/jxl/toc.h +++ b/lib/jxl/toc.h @@ -6,6 +6,7 @@ #ifndef LIB_JXL_TOC_H_ #define LIB_JXL_TOC_H_ +#include #include #include @@ -39,11 +40,13 @@ static JXL_INLINE size_t NumTocEntries(size_t num_groups, size_t num_dc_groups, num_groups * num_passes; } -Status ReadToc(size_t toc_entries, BitReader* JXL_RESTRICT reader, +Status ReadToc(JxlMemoryManager* memory_manager, size_t toc_entries, + BitReader* JXL_RESTRICT reader, std::vector* JXL_RESTRICT sizes, std::vector* JXL_RESTRICT permutation); -Status ReadGroupOffsets(size_t toc_entries, BitReader* JXL_RESTRICT reader, +Status ReadGroupOffsets(JxlMemoryManager* memory_manager, size_t toc_entries, + BitReader* JXL_RESTRICT reader, std::vector* JXL_RESTRICT offsets, std::vector* JXL_RESTRICT sizes, uint64_t* total_size); diff --git a/lib/jxl/toc_test.cc b/lib/jxl/toc_test.cc index 8c95f8bc269..eaefc97bfd8 100644 --- a/lib/jxl/toc_test.cc +++ b/lib/jxl/toc_test.cc @@ -5,20 +5,25 @@ #include "lib/jxl/toc.h" +#include + +#include #include #include "lib/jxl/base/common.h" #include "lib/jxl/base/random.h" -#include "lib/jxl/base/span.h" #include "lib/jxl/coeff_order_fwd.h" #include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_bit_writer.h" #include "lib/jxl/enc_toc.h" +#include "lib/jxl/test_utils.h" #include "lib/jxl/testing.h" namespace jxl { namespace { void Roundtrip(size_t num_entries, bool permute, Rng* rng) { + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); // Generate a random permutation. std::vector permutation; std::vector inv_permutation(num_entries); @@ -37,24 +42,28 @@ void Roundtrip(size_t num_entries, bool permute, Rng* rng) { } // Generate num_entries groups of random (byte-aligned) length - std::vector group_codes(num_entries); - for (BitWriter& writer : group_codes) { + std::vector> group_codes; + group_codes.reserve(num_entries); + for (size_t i = 0; i < num_entries; ++i) { + group_codes.emplace_back(jxl::make_unique(memory_manager)); + } + for (std::unique_ptr& writer : group_codes) { const size_t max_bits = (*rng)() & 0xFFF; - BitWriter::Allotment allotment(&writer, max_bits + kBitsPerByte); + BitWriter::Allotment allotment(writer.get(), max_bits + kBitsPerByte); size_t i = 0; for (; i + BitWriter::kMaxBitsPerCall < max_bits; i += BitWriter::kMaxBitsPerCall) { - writer.Write(BitWriter::kMaxBitsPerCall, 0); + writer->Write(BitWriter::kMaxBitsPerCall, 0); } for (; i < max_bits; i += 1) { - writer.Write(/*n_bits=*/1, 0); + writer->Write(/*n_bits=*/1, 0); } - writer.ZeroPadToByte(); + writer->ZeroPadToByte(); AuxOut aux_out; - allotment.ReclaimAndCharge(&writer, 0, &aux_out); + allotment.ReclaimAndCharge(writer.get(), 0, &aux_out); } - BitWriter writer; + BitWriter writer{memory_manager}; AuxOut aux_out; ASSERT_TRUE(WriteGroupOffsets(group_codes, permutation, &writer, &aux_out)); @@ -62,8 +71,8 @@ void Roundtrip(size_t num_entries, bool permute, Rng* rng) { std::vector group_offsets; std::vector group_sizes; uint64_t total_size; - ASSERT_TRUE(ReadGroupOffsets(num_entries, &reader, &group_offsets, - &group_sizes, &total_size)); + ASSERT_TRUE(ReadGroupOffsets(memory_manager, num_entries, &reader, + &group_offsets, &group_sizes, &total_size)); ASSERT_EQ(num_entries, group_offsets.size()); ASSERT_EQ(num_entries, group_sizes.size()); EXPECT_TRUE(reader.Close()); @@ -72,8 +81,8 @@ void Roundtrip(size_t num_entries, bool permute, Rng* rng) { for (size_t i = 0; i < num_entries; ++i) { EXPECT_EQ(prefix_sum, group_offsets[inv_permutation[i]]); - EXPECT_EQ(0u, group_codes[i].BitsWritten() % kBitsPerByte); - prefix_sum += group_codes[i].BitsWritten() / kBitsPerByte; + EXPECT_EQ(0u, group_codes[i]->BitsWritten() % kBitsPerByte); + prefix_sum += group_codes[i]->BitsWritten() / kBitsPerByte; if (i + 1 < num_entries) { EXPECT_EQ( diff --git a/tools/djxl_fuzzer_corpus.cc b/tools/djxl_fuzzer_corpus.cc index 7e6516d2a5b..bae87009ccc 100644 --- a/tools/djxl_fuzzer_corpus.cc +++ b/tools/djxl_fuzzer_corpus.cc @@ -201,7 +201,7 @@ bool GenerateFile(const char* output_dir, const ImageSpec& spec, std::mt19937 mt(spec.seed); // Compress the image. - jxl::PaddedBytes compressed; + jxl::PaddedBytes compressed{memory_manager}; std::uniform_int_distribution<> dis(1, 6); PixelGenerator gen = [&]() -> uint8_t { return dis(mt); }; @@ -264,8 +264,8 @@ bool GenerateFile(const char* output_dir, const ImageSpec& spec, JXL_RETURN_IF_ERROR(jxl::jpeg::DecodeImageJPG( jxl::Bytes(jpeg_bytes.data(), jpeg_bytes.size()), &io)); std::vector jpeg_data; - JXL_RETURN_IF_ERROR( - EncodeJPEGData(*io.Main().jpeg_data, &jpeg_data, params)); + JXL_RETURN_IF_ERROR(EncodeJPEGData(memory_manager, *io.Main().jpeg_data, + &jpeg_data, params)); std::vector header; header.insert(header.end(), jxl::kContainerHeader.begin(), jxl::kContainerHeader.end()); diff --git a/tools/icc_codec_fuzzer.cc b/tools/icc_codec_fuzzer.cc index 331db17275a..a682d4050fe 100644 --- a/tools/icc_codec_fuzzer.cc +++ b/tools/icc_codec_fuzzer.cc @@ -3,6 +3,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +#include + #include #include #include @@ -44,6 +46,7 @@ int DoTestOneInput(const uint8_t* data, size_t size) { data++; size--; #endif + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); #ifdef JXL_ICC_FUZZER_SLOW_TEST // Including JPEG XL LZ77 and ANS compression. These are already fuzzed @@ -57,9 +60,9 @@ int DoTestOneInput(const uint8_t* data, size_t size) { (void)br.Close(); } else { // Writing parses the original ICC profile. - PaddedBytes icc; + PaddedBytes icc{memory_manager}; icc.assign(data, data + size); - BitWriter writer; + BitWriter writer{memory_manager}; // Writing should support any random bytestream so must succeed, make // fuzzer fail if not. JXL_ASSERT(jxl::WriteICC(icc, &writer, 0, nullptr)); @@ -67,15 +70,15 @@ int DoTestOneInput(const uint8_t* data, size_t size) { #else // JXL_ICC_FUZZER_SLOW_TEST if (read) { // Reading (unpredicting) parses the compressed format. - PaddedBytes result; + PaddedBytes result{memory_manager}; (void)jxl::UnpredictICC(data, size, &result); } else { // Writing (predicting) parses the original ICC profile. - PaddedBytes result; + PaddedBytes result{memory_manager}; // Writing should support any random bytestream so must succeed, make // fuzzer fail if not. JXL_ASSERT(jxl::PredictICC(data, size, &result)); - PaddedBytes reconstructed; + PaddedBytes reconstructed{memory_manager}; JXL_ASSERT(jxl::UnpredictICC(result.data(), result.size(), &reconstructed)); JXL_ASSERT(reconstructed.size() == size); JXL_ASSERT(memcmp(data, reconstructed.data(), size) == 0); diff --git a/tools/jxl_from_tree.cc b/tools/jxl_from_tree.cc index 1ca7aa018cc..899fb5bf4e6 100644 --- a/tools/jxl_from_tree.cc +++ b/tools/jxl_from_tree.cc @@ -492,10 +492,10 @@ int JxlFromTree(const char* in, const char* out, const char* tree_out) { cparams.already_downsampled = true; cparams.custom_fixed_tree = tree; cparams.custom_splines = SplinesFromSplineData(spline_data); - PaddedBytes compressed; + PaddedBytes compressed{memory_manager}; io.CheckMetadata(); - BitWriter writer; + BitWriter writer{memory_manager}; std::unique_ptr metadata = jxl::make_unique(); *metadata = io.metadata; diff --git a/tools/rans_fuzzer.cc b/tools/rans_fuzzer.cc index b8339303856..526fc1426fa 100644 --- a/tools/rans_fuzzer.cc +++ b/tools/rans_fuzzer.cc @@ -3,6 +3,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +#include + #include #include @@ -29,14 +31,16 @@ int DoTestOneInput(const uint8_t* data, size_t size) { size -= 2; std::vector context_map; + JxlMemoryManager* memory_manager = jxl::test::MemoryManager(); Status ret = true; { BitReader br(Bytes(data, size)); BitReaderScopedCloser br_closer(&br, &ret); ANSCode code; - JXL_RETURN_IF_ERROR( - DecodeHistograms(&br, numContexts, &code, &context_map)); - ANSSymbolReader ansreader(&code, &br); + JXL_RETURN_IF_ERROR(DecodeHistograms(memory_manager, &br, numContexts, + &code, &context_map)); + JXL_ASSIGN_OR_DIE(ANSSymbolReader ansreader, + ANSSymbolReader::Create(&code, &br)); // Limit the maximum amount of reads to avoid (valid) infinite loops. const size_t maxreads = size * 8;