diff --git a/lib/jxl/compressed_image_test.cc b/lib/jxl/compressed_image_test.cc index e324ff4b0ad..85357038568 100644 --- a/lib/jxl/compressed_image_test.cc +++ b/lib/jxl/compressed_image_test.cc @@ -77,6 +77,8 @@ void RunRGBRoundTrip(float distance, bool fast) { PassesEncoderState enc_state; JXL_CHECK(InitializePassesSharedState(frame_header, &enc_state.shared)); + JXL_CHECK(enc_state.shared.matrices.EnsureComputed(~0u)); + enc_state.shared.quantizer.SetQuant(4.0f, 4.0f, &enc_state.shared.raw_quant_field); enc_state.shared.ac_strategy.FillDCT8(); diff --git a/lib/jxl/dec_frame.cc b/lib/jxl/dec_frame.cc index a435e3eb15b..fbc3c49b1e3 100644 --- a/lib/jxl/dec_frame.cc +++ b/lib/jxl/dec_frame.cc @@ -545,6 +545,8 @@ Status FrameDecoder::ProcessACGlobal(BitReader* br) { if (frame_header_.encoding == FrameEncoding::kVarDCT) { JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.Decode( br, &modular_frame_decoder_)); + JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.EnsureComputed( + dec_state_->used_acs)); size_t num_histo_bits = CeilLog2Nonzero(dec_state_->shared->frame_dim.num_groups); diff --git a/lib/jxl/dec_group.cc b/lib/jxl/dec_group.cc index 88d3fb9f02a..8ae23c1a02c 100644 --- a/lib/jxl/dec_group.cc +++ b/lib/jxl/dec_group.cc @@ -97,15 +97,14 @@ void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) { template void DequantLane(Vec scaled_dequant_x, Vec scaled_dequant_y, Vec scaled_dequant_b, - const float* JXL_RESTRICT dequant_matrices, size_t dq_ofs, - size_t size, size_t k, Vec x_cc_mul, Vec b_cc_mul, + const float* JXL_RESTRICT dequant_matrices, size_t size, + size_t k, Vec x_cc_mul, Vec b_cc_mul, const float* JXL_RESTRICT biases, ACPtr qblock[3], float* JXL_RESTRICT block) { - const auto x_mul = Load(d, dequant_matrices + dq_ofs + k) * scaled_dequant_x; - const auto y_mul = - Load(d, dequant_matrices + dq_ofs + size + k) * scaled_dequant_y; + const auto x_mul = Load(d, dequant_matrices + k) * scaled_dequant_x; + const auto y_mul = Load(d, dequant_matrices + size + k) * scaled_dequant_y; const auto b_mul = - Load(d, dequant_matrices + dq_ofs + 2 * size + k) * scaled_dequant_b; + Load(d, dequant_matrices + 2 * size + k) * scaled_dequant_b; Vec quantized_x_int; Vec quantized_y_int; @@ -139,9 +138,8 @@ template void DequantBlock(const AcStrategy& acs, float inv_global_scale, int quant, float x_dm_multiplier, float b_dm_multiplier, Vec x_cc_mul, Vec b_cc_mul, size_t kind, size_t size, - const Quantizer& quantizer, - const float* JXL_RESTRICT dequant_matrices, - size_t covered_blocks, const size_t* sbx, + const Quantizer& quantizer, size_t covered_blocks, + const size_t* sbx, const float* JXL_RESTRICT* JXL_RESTRICT dc_row, size_t dc_stride, const float* JXL_RESTRICT biases, ACPtr qblock[3], float* JXL_RESTRICT block) { @@ -153,12 +151,12 @@ void DequantBlock(const AcStrategy& acs, float inv_global_scale, int quant, const auto scaled_dequant_y = Set(d, scaled_dequant_s); const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier); - const size_t dq_ofs = quantizer.DequantMatrixOffset(kind, 0); + const float* dequant_matrices = quantizer.DequantMatrix(kind, 0); for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) { DequantLane(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b, - dequant_matrices, dq_ofs, size, k, x_cc_mul, b_cc_mul, - biases, qblock, block); + dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases, + qblock, block); } for (size_t c = 0; c < 3; c++) { LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride, @@ -186,8 +184,6 @@ Status DecodeGroupImpl(GetBlock* JXL_RESTRICT get_block, const size_t dc_stride = dec_state->shared->dc->PixelsPerRow(); const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale(); - const float* JXL_RESTRICT dequant_matrices = - dec_state->shared->quantizer.DequantMatrix(0, 0); const YCbCrChromaSubsampling& cs = dec_state->shared->frame_header.chroma_subsampling; @@ -428,7 +424,7 @@ Status DecodeGroupImpl(GetBlock* JXL_RESTRICT get_block, dequant_block( acs, inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier, dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.RawStrategy(), - size, dec_state->shared->quantizer, dequant_matrices, + size, dec_state->shared->quantizer, acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows, dc_stride, dec_state->output_encoding_info.opsin_params.quant_biases, qblock, diff --git a/lib/jxl/enc_ac_strategy.cc b/lib/jxl/enc_ac_strategy.cc index c0ed68fde1c..3ef97924284 100644 --- a/lib/jxl/enc_ac_strategy.cc +++ b/lib/jxl/enc_ac_strategy.cc @@ -1008,6 +1008,17 @@ void AcStrategyHeuristics::Init(const Image3F& src, const CompressParams& cparams = enc_state->cparams; const float butteraugli_target = cparams.butteraugli_distance; + if (cparams.speed_tier >= SpeedTier::kCheetah) { + JXL_CHECK(enc_state->shared.matrices.EnsureComputed(1)); // DCT8 only + } else { + uint32_t acs_mask = 0; + // All transforms up to 64x64. + for (size_t i = 0; i < AcStrategy::DCT128X128; i++) { + acs_mask |= (1 << i); + } + JXL_CHECK(enc_state->shared.matrices.EnsureComputed(acs_mask)); + } + // Image row pointers and strides. config.quant_field_row = enc_state->initial_quant_field.Row(0); config.quant_field_stride = enc_state->initial_quant_field.PixelsPerRow(); diff --git a/lib/jxl/enc_heuristics.cc b/lib/jxl/enc_heuristics.cc index cc4af55aaaa..9de389cd5a1 100644 --- a/lib/jxl/enc_heuristics.cc +++ b/lib/jxl/enc_heuristics.cc @@ -866,6 +866,9 @@ Status DefaultEncoderHeuristics::LossyFrameHeuristics( GaborishInverse(opsin, 0.9908511000000001f, pool); } + FindBestDequantMatrices(cparams, *opsin, modular_frame_encoder, + &enc_state->shared.matrices); + cfl_heuristics.Init(*opsin); acs_heuristics.Init(*opsin, enc_state); @@ -934,9 +937,6 @@ Status DefaultEncoderHeuristics::LossyFrameHeuristics( &enc_state->shared.cmap); } - FindBestDequantMatrices(cparams, *opsin, modular_frame_encoder, - &enc_state->shared.matrices); - // Refine quantization levels. FindBestQuantizer(original_pixels, *opsin, enc_state, cms, pool, aux_out); diff --git a/lib/jxl/quant_weights.cc b/lib/jxl/quant_weights.cc index 399a559ab93..2acd639b058 100644 --- a/lib/jxl/quant_weights.cc +++ b/lib/jxl/quant_weights.cc @@ -21,14 +21,21 @@ #include "lib/jxl/fields.h" #include "lib/jxl/image.h" +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/quant_weights.cc" +#include +#include + +#include "lib/jxl/fast_math-inl.h" + +HWY_BEFORE_NAMESPACE(); namespace jxl { +namespace HWY_NAMESPACE { // kQuantWeights[N * N * c + N * y + x] is the relative weight of the (x, y) // coefficient in component c. Higher weights correspond to finer quantization // intervals and more bits spent in encoding. -namespace { - static constexpr const float kAlmostZero = 1e-8f; void GetQuantWeightsDCT2(const QuantEncoding::DCT2Weights& dct2weights, @@ -75,33 +82,47 @@ void GetQuantWeightsIdentity(const QuantEncoding::IdWeights& idweights, } } -float Mult(float v) { - if (v > 0) return 1 + v; - return 1 / (1 - v); -} - float Interpolate(float pos, float max, const float* array, size_t len) { float scaled_pos = pos * (len - 1) / max; size_t idx = scaled_pos; - JXL_ASSERT(idx + 1 < len); + JXL_DASSERT(idx + 1 < len); float a = array[idx]; float b = array[idx + 1]; - return a * pow(b / a, scaled_pos - idx); + return a * FastPowf(b / a, scaled_pos - idx); +} + +float Mult(float v) { + if (v > 0.0f) return 1.0f + v; + return 1.0f / (1.0f - v); +} + +using DF4 = HWY_CAPPED(float, 4); + +hwy::HWY_NAMESPACE::Vec InterpolateVec( + hwy::HWY_NAMESPACE::Vec scaled_pos, const float* array) { + HWY_CAPPED(int32_t, 4) di; + + auto idx = ConvertTo(di, scaled_pos); + + auto frac = scaled_pos - ConvertTo(DF4(), idx); + + // TODO(veluca): in theory, this could be done with 8 TableLookupBytes, but + // it's probably slower. + auto a = GatherIndex(DF4(), array, idx); + auto b = GatherIndex(DF4(), array + 1, idx); + + return a * FastPowf(DF4(), b / a, frac); } // Computes quant weights for a COLS*ROWS-sized transform, using num_bands // eccentricity bands and num_ebands eccentricity bands. If print_mode is 1, // prints the resulting matrix; if print_mode is 2, prints the matrix in a // format suitable for a 3d plot with gnuplot. -template Status GetQuantWeights( size_t ROWS, size_t COLS, const DctQuantWeightParams::DistanceBandsArray& distance_bands, size_t num_bands, float* out) { for (size_t c = 0; c < 3; c++) { - if (print_mode) { - fprintf(stderr, "Channel %" PRIuS "\n", c); - } float bands[DctQuantWeightParams::kMaxDistanceBands] = { distance_bands[c][0]}; if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid distance bands"); @@ -109,159 +130,38 @@ Status GetQuantWeights( bands[i] = bands[i - 1] * Mult(distance_bands[c][i]); if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid distance bands"); } - for (size_t y = 0; y < ROWS; y++) { - for (size_t x = 0; x < COLS; x++) { - float dx = 1.0f * x / (COLS - 1); - float dy = 1.0f * y / (ROWS - 1); - float distance = std::sqrt(dx * dx + dy * dy); - float weight = - num_bands == 1 - ? bands[0] - : Interpolate(distance, std::sqrt(2) + 1e-6f, bands, num_bands); - - if (print_mode == 1) { - fprintf(stderr, "%15.12f, ", weight); - } - if (print_mode == 2) { - fprintf(stderr, "%" PRIuS " %" PRIuS " %15.12f\n", x, y, weight); - } - out[c * COLS * ROWS + y * COLS + x] = weight; + float scale = (num_bands - 1) / (kSqrt2 + 1e-6f); + float rcpcol = scale / (COLS - 1); + float rcprow = scale / (ROWS - 1); + JXL_ASSERT(COLS >= Lanes(DF4())); + HWY_ALIGN float l0123[4] = {0, 1, 2, 3}; + for (uint32_t y = 0; y < ROWS; y++) { + float dy = y * rcprow; + float dy2 = dy * dy; + for (uint32_t x = 0; x < COLS; x += Lanes(DF4())) { + auto dx = (Set(DF4(), x) + Load(DF4(), l0123)) * Set(DF4(), rcpcol); + auto scaled_distance = Sqrt(MulAdd(dx, dx, Set(DF4(), dy2))); + auto weight = num_bands == 1 ? Set(DF4(), bands[0]) + : InterpolateVec(scaled_distance, bands); + StoreU(weight, DF4(), out + c * COLS * ROWS + y * COLS + x); } - if (print_mode) fprintf(stderr, "\n"); - if (print_mode == 1) fprintf(stderr, "\n"); - } - if (print_mode) fprintf(stderr, "\n"); - } - return true; -} - -Status DecodeDctParams(BitReader* br, DctQuantWeightParams* params) { - params->num_distance_bands = - br->ReadFixedBits() + 1; - for (size_t c = 0; c < 3; c++) { - for (size_t i = 0; i < params->num_distance_bands; i++) { - JXL_RETURN_IF_ERROR(F16Coder::Read(br, ¶ms->distance_bands[c][i])); } - if (params->distance_bands[c][0] < kAlmostZero) { - return JXL_FAILURE("Distance band seed is too small"); - } - params->distance_bands[c][0] *= 64.0f; } return true; } -Status Decode(BitReader* br, QuantEncoding* encoding, size_t required_size_x, - size_t required_size_y, size_t idx, - ModularFrameDecoder* modular_frame_decoder) { - size_t required_size = required_size_x * required_size_y; - required_size_x *= kBlockDim; - required_size_y *= kBlockDim; - int mode = br->ReadFixedBits(); - switch (mode) { - case QuantEncoding::kQuantModeLibrary: { - encoding->predefined = br->ReadFixedBits(); - if (encoding->predefined >= kNumPredefinedTables) { - return JXL_FAILURE("Invalid predefined table"); - } - break; - } - case QuantEncoding::kQuantModeID: { - if (required_size != 1) return JXL_FAILURE("Invalid mode"); - for (size_t c = 0; c < 3; c++) { - for (size_t i = 0; i < 3; i++) { - JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->idweights[c][i])); - if (std::abs(encoding->idweights[c][i]) < kAlmostZero) { - return JXL_FAILURE("ID Quantizer is too small"); - } - encoding->idweights[c][i] *= 64; - } - } - break; - } - case QuantEncoding::kQuantModeDCT2: { - if (required_size != 1) return JXL_FAILURE("Invalid mode"); - for (size_t c = 0; c < 3; c++) { - for (size_t i = 0; i < 6; i++) { - JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->dct2weights[c][i])); - if (std::abs(encoding->dct2weights[c][i]) < kAlmostZero) { - return JXL_FAILURE("Quantizer is too small"); - } - encoding->dct2weights[c][i] *= 64; - } - } - break; - } - case QuantEncoding::kQuantModeDCT4X8: { - if (required_size != 1) return JXL_FAILURE("Invalid mode"); - for (size_t c = 0; c < 3; c++) { - JXL_RETURN_IF_ERROR( - F16Coder::Read(br, &encoding->dct4x8multipliers[c])); - if (std::abs(encoding->dct4x8multipliers[c]) < kAlmostZero) { - return JXL_FAILURE("DCT4X8 multiplier is too small"); - } - } - JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); - break; - } - case QuantEncoding::kQuantModeDCT4: { - if (required_size != 1) return JXL_FAILURE("Invalid mode"); - for (size_t c = 0; c < 3; c++) { - for (size_t i = 0; i < 2; i++) { - JXL_RETURN_IF_ERROR( - F16Coder::Read(br, &encoding->dct4multipliers[c][i])); - if (std::abs(encoding->dct4multipliers[c][i]) < kAlmostZero) { - return JXL_FAILURE("DCT4 multiplier is too small"); - } - } - } - JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); - break; - } - case QuantEncoding::kQuantModeAFV: { - if (required_size != 1) return JXL_FAILURE("Invalid mode"); - for (size_t c = 0; c < 3; c++) { - for (size_t i = 0; i < 9; i++) { - JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->afv_weights[c][i])); - } - for (size_t i = 0; i < 6; i++) { - encoding->afv_weights[c][i] *= 64; - } - JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); - JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params_afv_4x4)); - } - break; - } - case QuantEncoding::kQuantModeDCT: { - JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); - break; - } - case QuantEncoding::kQuantModeRAW: { - // Set mode early, to avoid mem-leak. - encoding->mode = QuantEncoding::kQuantModeRAW; - JXL_RETURN_IF_ERROR(ModularFrameDecoder::DecodeQuantTable( - required_size_x, required_size_y, br, encoding, idx, - modular_frame_decoder)); - break; - } - default: - return JXL_FAILURE("Invalid quantization table encoding"); - } - encoding->mode = QuantEncoding::Mode(mode); - return true; -} - // TODO(veluca): SIMD-fy. With 256x256, this is actually slow. Status ComputeQuantTable(const QuantEncoding& encoding, float* JXL_RESTRICT table, float* JXL_RESTRICT inv_table, size_t table_num, DequantMatrices::QuantTable kind, size_t* pos) { - std::vector weights(3 * kMaxQuantTableSize); - constexpr size_t N = kBlockDim; size_t wrows = 8 * DequantMatrices::required_size_x[kind], wcols = 8 * DequantMatrices::required_size_y[kind]; size_t num = wrows * wcols; + std::vector weights(3 * num); + switch (encoding.mode) { case QuantEncoding::kQuantModeLibrary: { // Library and copy quant encoding should get replaced by the actual @@ -363,7 +263,7 @@ Status ComputeQuantTable(const QuantEncoding& encoding, encoding.dct_params_afv_4x4.num_distance_bands, weights4x4))); constexpr float lo = 0.8517778890324296; - constexpr float hi = 12.97166202570235 - lo + 1e-6; + constexpr float hi = 12.97166202570235f - lo + 1e-6f; for (size_t c = 0; c < 3; c++) { float bands[4]; bands[0] = encoding.afv_weights[c][5]; @@ -415,18 +315,19 @@ Status ComputeQuantTable(const QuantEncoding& encoding, } } size_t prev_pos = *pos; - for (size_t c = 0; c < 3; c++) { - for (size_t i = 0; i < num; i++) { - float inv_val = weights[c * num + i]; - if (inv_val > 1.0f / kAlmostZero || inv_val < kAlmostZero) { - return JXL_FAILURE("Invalid quantization table"); - } - float val = 1.0f / inv_val; - table[*pos] = val; - inv_table[*pos] = inv_val; - (*pos)++; + HWY_CAPPED(float, 64) d; + for (size_t i = 0; i < num * 3; i += Lanes(d)) { + auto inv_val = LoadU(d, weights.data() + i); + if (JXL_UNLIKELY(!AllFalse(inv_val >= Set(d, 1.0f / kAlmostZero)) | + !AllFalse(inv_val < Set(d, kAlmostZero)))) { + return JXL_FAILURE("Invalid quantization table"); } + auto val = Set(d, 1.0f) / inv_val; + StoreU(val, d, table + *pos + i); + StoreU(inv_val, d, inv_table + *pos + i); } + (*pos) += 3 * num; + // Ensure that the lowest frequencies have a 0 inverse table. // This does not affect en/decoding, but allows AC strategy selection to be // slightly simpler. @@ -444,6 +345,135 @@ Status ComputeQuantTable(const QuantEncoding& encoding, return true; } +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace jxl { +namespace { + +HWY_EXPORT(ComputeQuantTable); + +static constexpr const float kAlmostZero = 1e-8f; + +Status DecodeDctParams(BitReader* br, DctQuantWeightParams* params) { + params->num_distance_bands = + br->ReadFixedBits() + 1; + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < params->num_distance_bands; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, ¶ms->distance_bands[c][i])); + } + if (params->distance_bands[c][0] < kAlmostZero) { + return JXL_FAILURE("Distance band seed is too small"); + } + params->distance_bands[c][0] *= 64.0f; + } + return true; +} + +Status Decode(BitReader* br, QuantEncoding* encoding, size_t required_size_x, + size_t required_size_y, size_t idx, + ModularFrameDecoder* modular_frame_decoder) { + size_t required_size = required_size_x * required_size_y; + required_size_x *= kBlockDim; + required_size_y *= kBlockDim; + int mode = br->ReadFixedBits(); + switch (mode) { + case QuantEncoding::kQuantModeLibrary: { + encoding->predefined = br->ReadFixedBits(); + if (encoding->predefined >= kNumPredefinedTables) { + return JXL_FAILURE("Invalid predefined table"); + } + break; + } + case QuantEncoding::kQuantModeID: { + if (required_size != 1) return JXL_FAILURE("Invalid mode"); + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 3; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->idweights[c][i])); + if (std::abs(encoding->idweights[c][i]) < kAlmostZero) { + return JXL_FAILURE("ID Quantizer is too small"); + } + encoding->idweights[c][i] *= 64; + } + } + break; + } + case QuantEncoding::kQuantModeDCT2: { + if (required_size != 1) return JXL_FAILURE("Invalid mode"); + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 6; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->dct2weights[c][i])); + if (std::abs(encoding->dct2weights[c][i]) < kAlmostZero) { + return JXL_FAILURE("Quantizer is too small"); + } + encoding->dct2weights[c][i] *= 64; + } + } + break; + } + case QuantEncoding::kQuantModeDCT4X8: { + if (required_size != 1) return JXL_FAILURE("Invalid mode"); + for (size_t c = 0; c < 3; c++) { + JXL_RETURN_IF_ERROR( + F16Coder::Read(br, &encoding->dct4x8multipliers[c])); + if (std::abs(encoding->dct4x8multipliers[c]) < kAlmostZero) { + return JXL_FAILURE("DCT4X8 multiplier is too small"); + } + } + JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); + break; + } + case QuantEncoding::kQuantModeDCT4: { + if (required_size != 1) return JXL_FAILURE("Invalid mode"); + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 2; i++) { + JXL_RETURN_IF_ERROR( + F16Coder::Read(br, &encoding->dct4multipliers[c][i])); + if (std::abs(encoding->dct4multipliers[c][i]) < kAlmostZero) { + return JXL_FAILURE("DCT4 multiplier is too small"); + } + } + } + JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); + break; + } + case QuantEncoding::kQuantModeAFV: { + if (required_size != 1) return JXL_FAILURE("Invalid mode"); + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 9; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->afv_weights[c][i])); + } + for (size_t i = 0; i < 6; i++) { + encoding->afv_weights[c][i] *= 64; + } + JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); + JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params_afv_4x4)); + } + break; + } + case QuantEncoding::kQuantModeDCT: { + JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); + break; + } + case QuantEncoding::kQuantModeRAW: { + // Set mode early, to avoid mem-leak. + encoding->mode = QuantEncoding::kQuantModeRAW; + JXL_RETURN_IF_ERROR(ModularFrameDecoder::DecodeQuantTable( + required_size_x, required_size_y, br, encoding, idx, + modular_frame_decoder)); + break; + } + default: + return JXL_FAILURE("Invalid quantization table encoding"); + } + encoding->mode = QuantEncoding::Mode(mode); + return true; +} + } // namespace // These definitions are needed before C++17. @@ -463,7 +493,8 @@ Status DequantMatrices::Decode(BitReader* br, jxl::Decode(br, &encodings_[i], required_size_x[i % kNum], required_size_y[i % kNum], i, modular_frame_decoder)); } - return DequantMatrices::Compute(); + computed_mask_ = 0; + return true; } Status DequantMatrices::DecodeDC(BitReader* br) { @@ -1126,61 +1157,78 @@ const QuantEncoding* DequantMatrices::Library() { return reinterpret_cast(kDequantLibrary.data()); } -Status DequantMatrices::Compute() { +DequantMatrices::DequantMatrices() { + encodings_.resize(size_t(QuantTable::kNum), QuantEncoding::Library(0)); size_t pos = 0; - - struct DefaultMatrices { - DefaultMatrices() { - const QuantEncoding* library = Library(); - size_t pos = 0; - for (size_t i = 0; i < kNum; i++) { - JXL_CHECK(ComputeQuantTable(library[i], table, inv_table, i, - QuantTable(i), &pos)); - } - JXL_CHECK(pos == kTotalTableSize); + size_t offsets[kNum * 3]; + for (size_t i = 0; i < size_t(QuantTable::kNum); i++) { + size_t num = required_size_[i] * kDCTBlockSize; + for (size_t c = 0; c < 3; c++) { + offsets[3 * i + c] = pos + c * num; } - HWY_ALIGN_MAX float table[kTotalTableSize]; - HWY_ALIGN_MAX float inv_table[kTotalTableSize]; - }; - - static const DefaultMatrices& default_matrices = - *hwy::MakeUniqueAligned().release(); - - JXL_ASSERT(encodings_.size() == kNum); - - bool has_nondefault_matrix = false; - for (const auto& enc : encodings_) { - if (enc.mode != QuantEncoding::kQuantModeLibrary) { - has_nondefault_matrix = true; + pos += 3 * num; + } + for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { + for (size_t c = 0; c < 3; c++) { + table_offsets_[i * 3 + c] = offsets[kQuantTable[i] * 3 + c]; } } - if (has_nondefault_matrix) { +} + +Status DequantMatrices::EnsureComputed(uint32_t acs_mask) { + const QuantEncoding* library = Library(); + + if (!table_storage_) { table_storage_ = hwy::AllocateAligned(2 * kTotalTableSize); table_ = table_storage_.get(); inv_table_ = table_storage_.get() + kTotalTableSize; - for (size_t table = 0; table < kNum; table++) { - size_t prev_pos = pos; - if (encodings_[table].mode == QuantEncoding::kQuantModeLibrary) { - size_t num = required_size_[table] * kDCTBlockSize; - memcpy(table_storage_.get() + prev_pos, - default_matrices.table + prev_pos, num * sizeof(float) * 3); - memcpy(table_storage_.get() + kTotalTableSize + prev_pos, - default_matrices.inv_table + prev_pos, num * sizeof(float) * 3); - pos += num * 3; - } else { - JXL_RETURN_IF_ERROR( - ComputeQuantTable(encodings_[table], table_storage_.get(), - table_storage_.get() + kTotalTableSize, table, - QuantTable(table), &pos)); - } + } + + size_t offsets[kNum * 3 + 1]; + size_t pos = 0; + for (size_t i = 0; i < kNum; i++) { + size_t num = required_size_[i] * kDCTBlockSize; + for (size_t c = 0; c < 3; c++) { + offsets[3 * i + c] = pos + c * num; + } + pos += 3 * num; + } + offsets[kNum * 3] = pos; + JXL_ASSERT(pos == kTotalTableSize); + + uint32_t kind_mask = 0; + for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { + if (acs_mask & (1u << i)) { + kind_mask |= 1u << kQuantTable[i]; + } + } + uint32_t computed_kind_mask = 0; + for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { + if (computed_mask_ & (1u << i)) { + computed_kind_mask |= 1u << kQuantTable[i]; + } + } + for (size_t table = 0; table < kNum; table++) { + if ((1 << table) & computed_kind_mask) continue; + if ((1 << table) & ~kind_mask) continue; + size_t pos = offsets[table * 3]; + if (encodings_[table].mode == QuantEncoding::kQuantModeLibrary) { + JXL_CHECK(HWY_DYNAMIC_DISPATCH(ComputeQuantTable)( + library[table], table_storage_.get(), + table_storage_.get() + kTotalTableSize, table, QuantTable(table), + &pos)); + } else { + JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(ComputeQuantTable)( + encodings_[table], table_storage_.get(), + table_storage_.get() + kTotalTableSize, table, QuantTable(table), + &pos)); } - JXL_ASSERT(pos == kTotalTableSize); - } else { - table_ = default_matrices.table; - inv_table_ = default_matrices.inv_table; + JXL_ASSERT(pos == offsets[table * 3 + 3]); } + computed_mask_ |= acs_mask; return true; } } // namespace jxl +#endif diff --git a/lib/jxl/quant_weights.h b/lib/jxl/quant_weights.h index 816362f81c8..b235dc754ed 100644 --- a/lib/jxl/quant_weights.h +++ b/lib/jxl/quant_weights.h @@ -363,26 +363,7 @@ class DequantMatrices { sizeof(kQuantTable) / sizeof *kQuantTable, "Update this array when adding or removing AC strategies."); - DequantMatrices() { - encodings_.resize(size_t(QuantTable::kNum), QuantEncoding::Library(0)); - size_t pos = 0; - size_t offsets[kNum * 3]; - for (size_t i = 0; i < size_t(QuantTable::kNum); i++) { - encodings_[i] = QuantEncoding::Library(0); - size_t num = required_size_[i] * kDCTBlockSize; - for (size_t c = 0; c < 3; c++) { - offsets[3 * i + c] = pos + c * num; - } - pos += 3 * num; - } - for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { - for (size_t c = 0; c < 3; c++) { - table_offsets_[i * 3 + c] = offsets[kQuantTable[i] * 3 + c]; - } - } - // Default quantization tables need to be valid. - JXL_CHECK(Compute()); - } + DequantMatrices(); static const QuantEncoding* Library(); @@ -393,20 +374,17 @@ class DequantMatrices { // .cc file. static const DequantLibraryInternal LibraryInit(); - JXL_INLINE size_t MatrixOffset(size_t quant_kind, size_t c) const { - JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); - return table_offsets_[quant_kind * 3 + c]; - } - // Returns aligned memory. JXL_INLINE const float* Matrix(size_t quant_kind, size_t c) const { JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); - return &table_[MatrixOffset(quant_kind, c)]; + JXL_DASSERT((1 << quant_kind) & computed_mask_); + return &table_[table_offsets_[quant_kind * 3 + c]]; } JXL_INLINE const float* InvMatrix(size_t quant_kind, size_t c) const { JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); - return &inv_table_[MatrixOffset(quant_kind, c)]; + JXL_DASSERT((1 << quant_kind) & computed_mask_); + return &inv_table_[table_offsets_[quant_kind * 3 + c]]; } // DC quants are used in modular mode for XYB multipliers. @@ -418,6 +396,7 @@ class DequantMatrices { // For encoder. void SetEncodings(const std::vector& encodings) { encodings_ = encodings; + computed_mask_ = 0; } // For encoder. @@ -444,9 +423,9 @@ class DequantMatrices { static_assert(kNum == sizeof(required_size_y) / sizeof(*required_size_y), "Update this array when adding or removing quant tables."); - private: - Status Compute(); + Status EnsureComputed(uint32_t kind_mask); + private: static constexpr size_t required_size_[] = { 1, 1, 1, 1, 4, 16, 2, 4, 8, 1, 1, 64, 32, 256, 128, 1024, 512}; static_assert(kNum == sizeof(required_size_) / sizeof(*required_size_), @@ -454,6 +433,7 @@ class DequantMatrices { static constexpr size_t kTotalTableSize = ArraySum(required_size_) * kDCTBlockSize * 3; + uint32_t computed_mask_ = 0; // kTotalTableSize entries followed by kTotalTableSize for inv_table hwy::AlignedFreeUniquePtr table_storage_; const float* table_; diff --git a/lib/jxl/quant_weights_test.cc b/lib/jxl/quant_weights_test.cc index a700d178c45..f0497948a7f 100644 --- a/lib/jxl/quant_weights_test.cc +++ b/lib/jxl/quant_weights_test.cc @@ -173,6 +173,7 @@ TEST_P(QuantWeightsTargetTest, DCTUniform) { FrameHeader frame_header(&metadata); ModularFrameEncoder encoder(frame_header, CompressParams{}); DequantMatricesSetCustom(&dequant_matrices, encodings, &encoder); + JXL_CHECK(dequant_matrices.EnsureComputed(~0u)); const float dc_quant[3] = {1.0f / kUniformQuant, 1.0f / kUniformQuant, 1.0f / kUniformQuant}; diff --git a/lib/jxl/quantizer.h b/lib/jxl/quantizer.h index 8d9a2347901..1ff593e5c18 100644 --- a/lib/jxl/quantizer.h +++ b/lib/jxl/quantizer.h @@ -123,10 +123,6 @@ class Quantizer { return dequant_->InvMatrix(quant_kind, c); } - JXL_INLINE size_t DequantMatrixOffset(size_t quant_kind, size_t c) const { - return dequant_->MatrixOffset(quant_kind, c); - } - // Calculates DC quantization step. JXL_INLINE float GetDcStep(size_t c) const { return inv_quant_dc_ * dequant_->DCQuant(c);