Skip to content

Commit

Permalink
Do 8-bit ordered dithering when decoding to 8-bit. (#3090)
Browse files Browse the repository at this point in the history
This should help mitigating banding.
  • Loading branch information
veluca93 authored Jan 3, 2024
1 parent 34b43a0 commit 815858b
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 25 deletions.
95 changes: 70 additions & 25 deletions lib/jxl/render_pipeline/stage_write.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#include "lib/jxl/render_pipeline/stage_write.h"

#include <type_traits>

#include "lib/jxl/alpha.h"
#include "lib/jxl/base/common.h"
#include "lib/jxl/dec_cache.h"
Expand All @@ -21,6 +23,7 @@ namespace jxl {
namespace HWY_NAMESPACE {

// These templates are not found via ADL.
using hwy::HWY_NAMESPACE::Add;
using hwy::HWY_NAMESPACE::Clamp;
using hwy::HWY_NAMESPACE::Div;
using hwy::HWY_NAMESPACE::Max;
Expand All @@ -30,6 +33,53 @@ using hwy::HWY_NAMESPACE::Or;
using hwy::HWY_NAMESPACE::Rebind;
using hwy::HWY_NAMESPACE::ShiftLeftSame;
using hwy::HWY_NAMESPACE::ShiftRightSame;
using hwy::HWY_NAMESPACE::VFromD;

// 8x8 ordered dithering pattern from
// https://en.wikipedia.org/wiki/Ordered_dithering
// scaled to have an average of 0 and be fully contained in (-0.5, 0.5).
const float kDither[64] = {
-0.4921875, 0.0078125, -0.3671875, 0.1328125, //
-0.4609375, 0.0390625, -0.3359375, 0.1640625, //
0.2578125, -0.2421875, 0.3828125, -0.1171875, //
0.2890625, -0.2109375, 0.4140625, -0.0859375, //
-0.3046875, 0.1953125, -0.4296875, 0.0703125, //
-0.2734375, 0.2265625, -0.3984375, 0.1015625, //
0.4453125, -0.0546875, 0.3203125, -0.1796875, //
0.4765625, -0.0234375, 0.3515625, -0.1484375, //
-0.4453125, 0.0546875, -0.3203125, 0.1796875, //
-0.4765625, 0.0234375, -0.3515625, 0.1484375, //
0.3046875, -0.1953125, 0.4296875, -0.0703125, //
0.2734375, -0.2265625, 0.3984375, -0.1015625, //
-0.2578125, 0.2421875, -0.3828125, 0.1171875, //
-0.2890625, 0.2109375, -0.4140625, 0.0859375, //
0.4921875, -0.0078125, 0.3671875, -0.1328125, //
0.4609375, -0.0390625, 0.3359375, -0.1640625, //
};

using DF = HWY_FULL(float);

// Converts `v` to an appropriate value for the given unsigned type.
// If the unsigned type is an 8-bit type, performs ordered dithering.
template <typename T>
VFromD<Rebind<T, DF>> MakeUnsigned(VFromD<DF> v, size_t x0, size_t y0,
VFromD<DF> mul) {
static_assert(std::is_unsigned<T>::value, "T must be an unsigned type");
using DU = Rebind<T, DF>;
v = Mul(v, mul);
// TODO(veluca): if constexpr with C++17
if (sizeof(T) == 1) {
size_t pos = (y0 % 8) * 8 + (x0 % 8);
#if HWY_TARGET != HWY_SCALAR
auto dither = LoadDup128(DF(), kDither + pos);
#else
auto dither = LoadU(DF(), kDither + pos);
#endif
v = Add(v, dither);
}
v = Clamp(Zero(DF()), v, mul);
return DemoteTo(DU(), NearestInt(v));
}

class WriteToOutputStage : public RenderPipelineStage {
public:
Expand Down Expand Up @@ -229,14 +279,14 @@ class WriteToOutputStage : public RenderPipelineStage {
if (out.data_type_ == JXL_TYPE_UINT8) {
uint8_t* JXL_RESTRICT temp =
reinterpret_cast<uint8_t*>(temp_out_[thread_id].get());
StoreUnsignedRow(out, input, len, temp);
StoreUnsignedRow(out, input, len, temp, xstart, ypos);
WriteToOutput(out, thread_id, ypos, xstart, len, temp);
} else if (out.data_type_ == JXL_TYPE_UINT16 ||
out.data_type_ == JXL_TYPE_FLOAT16) {
uint16_t* JXL_RESTRICT temp =
reinterpret_cast<uint16_t*>(temp_out_[thread_id].get());
if (out.data_type_ == JXL_TYPE_UINT16) {
StoreUnsignedRow(out, input, len, temp);
StoreUnsignedRow(out, input, len, temp, xstart, ypos);
} else {
StoreFloat16Row(out, input, len, temp);
}
Expand Down Expand Up @@ -289,10 +339,8 @@ class WriteToOutputStage : public RenderPipelineStage {

template <typename T>
void StoreUnsignedRow(const Output& out, const float* input[4], size_t len,
T* output) const {
T* output, size_t xstart, size_t ypos) const {
const HWY_FULL(float) d;
auto zero = Zero(d);
auto one = Set(d, 1.0f);
auto mul = Set(d, (1u << (out.bits_per_sample_)) - 1);
const Rebind<T, decltype(d)> du;
const size_t padding = RoundUpTo(len, Lanes(d)) - len;
Expand All @@ -301,35 +349,32 @@ class WriteToOutputStage : public RenderPipelineStage {
}
if (out.num_channels_ == 1) {
for (size_t i = 0; i < len; i += Lanes(d)) {
auto v0 = Mul(Clamp(zero, LoadU(d, &input[0][i]), one), mul);
StoreU(DemoteTo(du, NearestInt(v0)), du, &output[i]);
StoreU(MakeUnsigned<T>(LoadU(d, &input[0][i]), xstart + i, ypos, mul),
du, &output[i]);
}
} else if (out.num_channels_ == 2) {
for (size_t i = 0; i < len; i += Lanes(d)) {
auto v0 = Mul(Clamp(zero, LoadU(d, &input[0][i]), one), mul);
auto v1 = Mul(Clamp(zero, LoadU(d, &input[1][i]), one), mul);
StoreInterleaved2(DemoteTo(du, NearestInt(v0)),
DemoteTo(du, NearestInt(v1)), du, &output[2 * i]);
StoreInterleaved2(
MakeUnsigned<T>(LoadU(d, &input[0][i]), xstart + i, ypos, mul),
MakeUnsigned<T>(LoadU(d, &input[1][i]), xstart + i, ypos, mul), du,
&output[2 * i]);
}
} else if (out.num_channels_ == 3) {
for (size_t i = 0; i < len; i += Lanes(d)) {
auto v0 = Mul(Clamp(zero, LoadU(d, &input[0][i]), one), mul);
auto v1 = Mul(Clamp(zero, LoadU(d, &input[1][i]), one), mul);
auto v2 = Mul(Clamp(zero, LoadU(d, &input[2][i]), one), mul);
StoreInterleaved3(DemoteTo(du, NearestInt(v0)),
DemoteTo(du, NearestInt(v1)),
DemoteTo(du, NearestInt(v2)), du, &output[3 * i]);
StoreInterleaved3(
MakeUnsigned<T>(LoadU(d, &input[0][i]), xstart + i, ypos, mul),
MakeUnsigned<T>(LoadU(d, &input[1][i]), xstart + i, ypos, mul),
MakeUnsigned<T>(LoadU(d, &input[2][i]), xstart + i, ypos, mul), du,
&output[3 * i]);
}
} else if (out.num_channels_ == 4) {
for (size_t i = 0; i < len; i += Lanes(d)) {
auto v0 = Mul(Clamp(zero, LoadU(d, &input[0][i]), one), mul);
auto v1 = Mul(Clamp(zero, LoadU(d, &input[1][i]), one), mul);
auto v2 = Mul(Clamp(zero, LoadU(d, &input[2][i]), one), mul);
auto v3 = Mul(Clamp(zero, LoadU(d, &input[3][i]), one), mul);
StoreInterleaved4(DemoteTo(du, NearestInt(v0)),
DemoteTo(du, NearestInt(v1)),
DemoteTo(du, NearestInt(v2)),
DemoteTo(du, NearestInt(v3)), du, &output[4 * i]);
StoreInterleaved4(
MakeUnsigned<T>(LoadU(d, &input[0][i]), xstart + i, ypos, mul),
MakeUnsigned<T>(LoadU(d, &input[1][i]), xstart + i, ypos, mul),
MakeUnsigned<T>(LoadU(d, &input[2][i]), xstart + i, ypos, mul),
MakeUnsigned<T>(LoadU(d, &input[3][i]), xstart + i, ypos, mul), du,
&output[4 * i]);
}
}
msan::PoisonMemory(output + out.num_channels_ * len,
Expand Down
5 changes: 5 additions & 0 deletions lib/jxl/test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <jxl/cms.h>
#include <jxl/cms_interface.h>
#include <jxl/types.h>

#include <cstddef>
#include <fstream>
Expand Down Expand Up @@ -487,6 +488,10 @@ size_t ComparePixels(const uint8_t* a, const uint8_t* b, size_t xsize,
// TODO(lode): Set the required precision back to 11 bits when possible.
precision = 0.5 * threshold_multiplier / ((1ull << (bits - 1)) - 1ull);
}
if (format_b.data_type == JXL_TYPE_UINT8) {
// Increase the threshold by the maximum difference introduced by dithering.
precision += 63.0 / 128.0;
}
size_t numdiff = 0;
for (size_t y = 0; y < ysize; y++) {
for (size_t x = 0; x < xsize; x++) {
Expand Down

0 comments on commit 815858b

Please sign in to comment.