Skip to content

Commit

Permalink
Spread JxlMemoryManager (#3569)
Browse files Browse the repository at this point in the history
(for PaddedBytes and its containers, most notably BitWriter)
  • Loading branch information
eustas committed May 13, 2024
1 parent 13adc43 commit 7e9f6f2
Show file tree
Hide file tree
Showing 71 changed files with 634 additions and 448 deletions.
37 changes: 23 additions & 14 deletions lib/jxl/ans_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include <stddef.h>
#include <stdint.h>
#include <jxl/memory_manager.h>

#include <cstddef>
#include <cstdint>
#include <vector>

#include "lib/jxl/ans_params.h"
Expand All @@ -15,17 +16,19 @@
#include "lib/jxl/dec_bit_reader.h"
#include "lib/jxl/enc_ans.h"
#include "lib/jxl/enc_bit_writer.h"
#include "lib/jxl/test_utils.h"
#include "lib/jxl/testing.h"

namespace jxl {
namespace {

void RoundtripTestcase(int n_histograms, int alphabet_size,
const std::vector<Token>& input_values) {
JxlMemoryManager* memory_manager = jxl::test::MemoryManager();
constexpr uint16_t kMagic1 = 0x9e33;
constexpr uint16_t kMagic2 = 0x8b04;

BitWriter writer;
BitWriter writer{memory_manager};
// Space for magic bytes.
BitWriter::Allotment allotment_magic1(&writer, 16);
writer.Write(16, kMagic1);
Expand All @@ -36,8 +39,9 @@ void RoundtripTestcase(int n_histograms, int alphabet_size,
std::vector<std::vector<Token>> input_values_vec;
input_values_vec.push_back(input_values);

BuildAndEncodeHistograms(HistogramParams(), n_histograms, input_values_vec,
&codes, &context_map, &writer, 0, nullptr);
BuildAndEncodeHistograms(memory_manager, HistogramParams(), n_histograms,
input_values_vec, &codes, &context_map, &writer, 0,
nullptr);
WriteTokens(input_values_vec[0], codes, context_map, 0, &writer, 0, nullptr);

// Magic bytes + padding
Expand All @@ -54,10 +58,11 @@ void RoundtripTestcase(int n_histograms, int alphabet_size,

std::vector<uint8_t> dec_context_map;
ANSCode decoded_codes;
ASSERT_TRUE(
DecodeHistograms(&br, n_histograms, &decoded_codes, &dec_context_map));
ASSERT_TRUE(DecodeHistograms(memory_manager, &br, n_histograms,
&decoded_codes, &dec_context_map));
ASSERT_EQ(dec_context_map, context_map);
ANSSymbolReader reader(&decoded_codes, &br);
JXL_ASSIGN_OR_DIE(ANSSymbolReader reader,
ANSSymbolReader::Create(&decoded_codes, &br));

for (const Token& symbol : input_values) {
uint32_t read_symbol =
Expand Down Expand Up @@ -156,6 +161,7 @@ TEST(ANSTest, RandomUnbalancedStreamRoundtripBig) {
}

TEST(ANSTest, UintConfigRoundtrip) {
JxlMemoryManager* memory_manager = jxl::test::MemoryManager();
for (size_t log_alpha_size = 5; log_alpha_size <= 8; log_alpha_size++) {
std::vector<HybridUintConfig> uint_config;
std::vector<HybridUintConfig> uint_config_dec;
Expand All @@ -168,7 +174,7 @@ TEST(ANSTest, UintConfigRoundtrip) {
}
uint_config.emplace_back(log_alpha_size, 0, 0);
uint_config_dec.resize(uint_config.size());
BitWriter writer;
BitWriter writer{memory_manager};
BitWriter::Allotment allotment(&writer, 10 * uint_config.size());
EncodeUintConfigs(uint_config, &writer, log_alpha_size);
allotment.ReclaimAndCharge(&writer, 0, nullptr);
Expand All @@ -185,6 +191,7 @@ TEST(ANSTest, UintConfigRoundtrip) {
}

void TestCheckpointing(bool ans, bool lz77) {
JxlMemoryManager* memory_manager = jxl::test::MemoryManager();
std::vector<std::vector<Token>> input_values(1);
for (size_t i = 0; i < 1024; i++) {
input_values[0].emplace_back(0, i % 4);
Expand All @@ -206,11 +213,11 @@ void TestCheckpointing(bool ans, bool lz77) {
: HistogramParams::LZ77Method::kNone;
params.force_huffman = !ans;

BitWriter writer;
BitWriter writer{memory_manager};
{
auto input_values_copy = input_values;
BuildAndEncodeHistograms(params, 1, input_values_copy, &codes, &context_map,
&writer, 0, nullptr);
BuildAndEncodeHistograms(memory_manager, params, 1, input_values_copy,
&codes, &context_map, &writer, 0, nullptr);
WriteTokens(input_values_copy[0], codes, context_map, 0, &writer, 0,
nullptr);
writer.ZeroPadToByte();
Expand All @@ -225,9 +232,11 @@ void TestCheckpointing(bool ans, bool lz77) {

std::vector<uint8_t> dec_context_map;
ANSCode decoded_codes;
ASSERT_TRUE(DecodeHistograms(&br, 1, &decoded_codes, &dec_context_map));
ASSERT_TRUE(DecodeHistograms(memory_manager, &br, 1, &decoded_codes,
&dec_context_map));
ASSERT_EQ(dec_context_map, context_map);
ANSSymbolReader reader(&decoded_codes, &br);
JXL_ASSIGN_OR_DIE(ANSSymbolReader reader,
ANSSymbolReader::Create(&decoded_codes, &br));

ANSSymbolReader::Checkpoint checkpoint;
size_t br_pos = 0;
Expand Down
24 changes: 13 additions & 11 deletions lib/jxl/bit_reader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include <stddef.h>
#include <stdint.h>

#include <array>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -52,12 +51,13 @@ struct Symbol {

// Reading from output gives the same values.
TEST(BitReaderTest, TestRoundTrip) {
JxlMemoryManager* memory_manager = jxl::test::MemoryManager();
test::ThreadPoolForTests pool(8);
EXPECT_TRUE(RunOnPool(
pool.get(), 0, 1000, ThreadPool::NoInit,
[](const uint32_t task, size_t /* thread */) {
[&memory_manager](const uint32_t task, size_t /* thread */) {
constexpr size_t kMaxBits = 8000;
BitWriter writer;
BitWriter writer{memory_manager};
BitWriter::Allotment allotment(&writer, kMaxBits);

std::vector<Symbol> symbols;
Expand Down Expand Up @@ -86,14 +86,15 @@ TEST(BitReaderTest, TestRoundTrip) {

// SkipBits is the same as reading that many bits.
TEST(BitReaderTest, TestSkip) {
JxlMemoryManager* memory_manager = jxl::test::MemoryManager();
test::ThreadPoolForTests pool(8);
EXPECT_TRUE(RunOnPool(
pool.get(), 0, 96, ThreadPool::NoInit,
[](const uint32_t task, size_t /* thread */) {
[&memory_manager](const uint32_t task, size_t /* thread */) {
constexpr size_t kSize = 100;

for (size_t skip = 0; skip < 128; ++skip) {
BitWriter writer;
BitWriter writer{memory_manager};
BitWriter::Allotment allotment(&writer, kSize * kBitsPerByte);
// Start with "task" 1-bits.
for (size_t i = 0; i < task; ++i) {
Expand Down Expand Up @@ -142,11 +143,12 @@ TEST(BitReaderTest, TestSkip) {

// Verifies byte order and different groupings of bits.
TEST(BitReaderTest, TestOrder) {
JxlMemoryManager* memory_manager = jxl::test::MemoryManager();
constexpr size_t kMaxBits = 16;

// u(1) - bits written into LSBs of first byte
{
BitWriter writer;
BitWriter writer{memory_manager};
BitWriter::Allotment allotment(&writer, kMaxBits);
for (size_t i = 0; i < 5; ++i) {
writer.Write(1, 1);
Expand All @@ -168,7 +170,7 @@ TEST(BitReaderTest, TestOrder) {

// u(8) - get bytes in the same order
{
BitWriter writer;
BitWriter writer{memory_manager};
BitWriter::Allotment allotment(&writer, kMaxBits);
writer.Write(8, 0xF8);
writer.Write(8, 0x3F);
Expand All @@ -183,7 +185,7 @@ TEST(BitReaderTest, TestOrder) {

// u(16) - little-endian bytes
{
BitWriter writer;
BitWriter writer{memory_manager};
BitWriter::Allotment allotment(&writer, kMaxBits);
writer.Write(16, 0xF83F);

Expand All @@ -197,7 +199,7 @@ TEST(BitReaderTest, TestOrder) {

// Non-byte-aligned, mixed sizes
{
BitWriter writer;
BitWriter writer{memory_manager};
BitWriter::Allotment allotment(&writer, kMaxBits);
writer.Write(1, 1);
writer.Write(3, 6);
Expand Down
36 changes: 18 additions & 18 deletions lib/jxl/coeff_order.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,18 @@

#include "lib/jxl/coeff_order.h"

#include <stdint.h>
#include <jxl/memory_manager.h>

#include <algorithm>
#include <cstdint>
#include <vector>

#include "lib/jxl/ans_params.h"
#include "lib/jxl/base/span.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/coeff_order_fwd.h"
#include "lib/jxl/dec_ans.h"
#include "lib/jxl/dec_bit_reader.h"
#include "lib/jxl/entropy_coder.h"
#include "lib/jxl/lehmer_code.h"
#include "lib/jxl/modular/encoding/encoding.h"
#include "lib/jxl/modular/modular_image.h"

namespace jxl {

Expand Down Expand Up @@ -57,13 +55,14 @@ Status ReadPermutation(size_t skip, size_t size, coeff_order_t* order,

} // namespace

Status DecodePermutation(size_t skip, size_t size, coeff_order_t* order,
BitReader* br) {
Status DecodePermutation(JxlMemoryManager* memory_manager, size_t skip,
size_t size, coeff_order_t* order, BitReader* br) {
std::vector<uint8_t> context_map;
ANSCode code;
JXL_RETURN_IF_ERROR(
DecodeHistograms(br, kPermutationContexts, &code, &context_map));
ANSSymbolReader reader(&code, br);
JXL_RETURN_IF_ERROR(DecodeHistograms(memory_manager, br, kPermutationContexts,
&code, &context_map));
JXL_ASSIGN_OR_RETURN(ANSSymbolReader reader,
ANSSymbolReader::Create(&code, br));
JXL_RETURN_IF_ERROR(
ReadPermutation(skip, size, order, br, &reader, context_map));
if (!reader.CheckANSFinalState()) {
Expand Down Expand Up @@ -92,18 +91,19 @@ Status DecodeCoeffOrder(AcStrategy acs, coeff_order_t* order, BitReader* br,

} // namespace

Status DecodeCoeffOrders(uint16_t used_orders, uint32_t used_acs,
coeff_order_t* order, BitReader* br) {
Status DecodeCoeffOrders(JxlMemoryManager* memory_manager, uint16_t used_orders,
uint32_t used_acs, coeff_order_t* order,
BitReader* br) {
uint16_t computed = 0;
std::vector<uint8_t> context_map;
ANSCode code;
std::unique_ptr<ANSSymbolReader> reader;
ANSSymbolReader reader;
std::vector<coeff_order_t> natural_order;
// Bitstream does not have histograms if no coefficient order is used.
if (used_orders != 0) {
JXL_RETURN_IF_ERROR(
DecodeHistograms(br, kPermutationContexts, &code, &context_map));
reader = make_unique<ANSSymbolReader>(&code, br);
JXL_RETURN_IF_ERROR(DecodeHistograms(
memory_manager, br, kPermutationContexts, &code, &context_map));
JXL_ASSIGN_OR_RETURN(reader, ANSSymbolReader::Create(&code, br));
}
uint32_t acs_mask = 0;
for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) {
Expand Down Expand Up @@ -136,12 +136,12 @@ Status DecodeCoeffOrders(uint16_t used_orders, uint32_t used_acs,
} else {
for (size_t c = 0; c < 3; c++) {
coeff_order_t* dest = used ? &order[CoeffOrderOffset(ord, c)] : nullptr;
JXL_RETURN_IF_ERROR(DecodeCoeffOrder(acs, dest, br, reader.get(),
JXL_RETURN_IF_ERROR(DecodeCoeffOrder(acs, dest, br, &reader,
natural_order, context_map));
}
}
}
if (used_orders && !reader->CheckANSFinalState()) {
if (used_orders && !reader.CheckANSFinalState()) {
return JXL_FAILURE("Invalid ANS stream");
}
return true;
Expand Down
11 changes: 7 additions & 4 deletions lib/jxl/coeff_order.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#ifndef LIB_JXL_COEFF_ORDER_H_
#define LIB_JXL_COEFF_ORDER_H_

#include <jxl/memory_manager.h>

#include <array>
#include <cstddef>
#include <cstdint>
Expand Down Expand Up @@ -53,12 +55,13 @@ constexpr JXL_MAYBE_UNUSED uint32_t kPermutationContexts = 8;

uint32_t CoeffOrderContext(uint32_t val);

Status DecodeCoeffOrders(uint16_t used_orders, uint32_t used_acs,
coeff_order_t* order, BitReader* br);

Status DecodePermutation(size_t skip, size_t size, coeff_order_t* order,
Status DecodeCoeffOrders(JxlMemoryManager* memory_manager, uint16_t used_orders,
uint32_t used_acs, coeff_order_t* order,
BitReader* br);

Status DecodePermutation(JxlMemoryManager* memory_manager, size_t skip,
size_t size, coeff_order_t* order, BitReader* br);

} // namespace jxl

#endif // LIB_JXL_COEFF_ORDER_H_
8 changes: 6 additions & 2 deletions lib/jxl/coeff_order_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#include "lib/jxl/coeff_order.h"

#include <jxl/memory_manager.h>

#include <algorithm>
#include <numeric> // iota
#include <utility>
Expand All @@ -16,21 +18,23 @@
#include "lib/jxl/coeff_order_fwd.h"
#include "lib/jxl/dec_bit_reader.h"
#include "lib/jxl/enc_coeff_order.h"
#include "lib/jxl/test_utils.h"
#include "lib/jxl/testing.h"

namespace jxl {
namespace {

void RoundtripPermutation(coeff_order_t* perm, coeff_order_t* out, size_t len,
size_t* size) {
BitWriter writer;
JxlMemoryManager* memory_manager = jxl::test::MemoryManager();
BitWriter writer{memory_manager};
EncodePermutation(perm, 0, len, &writer, 0, nullptr);
writer.ZeroPadToByte();
Status status = true;
{
BitReader reader(writer.GetSpan());
BitReaderScopedCloser closer(&reader, &status);
ASSERT_TRUE(DecodePermutation(0, len, out, &reader));
ASSERT_TRUE(DecodePermutation(memory_manager, 0, len, out, &reader));
}
ASSERT_TRUE(status);
*size = writer.GetSpan().size();
Expand Down

0 comments on commit 7e9f6f2

Please sign in to comment.