Skip to content

Commit

Permalink
Fix roundtrip failure for images with dots near the border. (libjxl#3367
Browse files Browse the repository at this point in the history
)
  • Loading branch information
szabadka authored Mar 4, 2024
1 parent 3499925 commit 5c6b57f
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 50 deletions.
89 changes: 45 additions & 44 deletions lib/jxl/enc_detect_dots.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ const size_t kMaxCCSize = 1000;

// Extracts a connected component from a Binary image where seed is part
// of the component
bool ExtractComponent(ImageF* img, std::vector<Pixel>* pixels,
bool ExtractComponent(const Rect& rect, ImageF* img, std::vector<Pixel>* pixels,
const Pixel& seed, double threshold) {
static const std::vector<Pixel> neighbors{{1, -1}, {1, 0}, {1, 1}, {0, -1},
{0, 1}, {-1, -1}, {-1, 1}, {1, 0}};
Expand All @@ -190,9 +190,9 @@ bool ExtractComponent(ImageF* img, std::vector<Pixel>* pixels,
if (pixels->size() > kMaxCCSize) return false;
for (const Pixel& delta : neighbors) {
Pixel child = current + delta;
if (child.x >= 0 && static_cast<size_t>(child.x) < img->xsize() &&
child.y >= 0 && static_cast<size_t>(child.y) < img->ysize()) {
float* value = &img->Row(child.y)[child.x];
if (child.x >= 0 && static_cast<size_t>(child.x) < rect.xsize() &&
child.y >= 0 && static_cast<size_t>(child.y) < rect.ysize()) {
float* value = &rect.Row(img, child.y)[child.x];
if (*value > threshold) {
*value = 0.0;
q.push_back(child);
Expand Down Expand Up @@ -223,7 +223,7 @@ struct ConnectedComponent {
float score;
Pixel mode;

void CompStats(const ImageF& energy, int extra) {
void CompStats(const ImageF& energy, const Rect& rect, int extra) {
maxEnergy = 0.0;
meanEnergy = 0.0;
varEnergy = 0.0;
Expand All @@ -236,12 +236,12 @@ struct ConnectedComponent {
for (int sy = -extra; sy < (static_cast<int>(bounds.ysize()) + extra);
sy++) {
int y = sy + static_cast<int>(bounds.y0());
if (y < 0 || static_cast<size_t>(y) >= energy.ysize()) continue;
const float* JXL_RESTRICT erow = energy.ConstRow(y);
if (y < 0 || static_cast<size_t>(y) >= rect.ysize()) continue;
const float* JXL_RESTRICT erow = rect.ConstRow(energy, y);
for (int sx = -extra; sx < (static_cast<int>(bounds.xsize()) + extra);
sx++) {
int x = sx + static_cast<int>(bounds.x0());
if (x < 0 || static_cast<size_t>(x) >= energy.xsize()) continue;
if (x < 0 || static_cast<size_t>(x) >= rect.xsize()) continue;
if (erow[x] > maxEnergy) {
maxEnergy = erow[x];
mode.x = x;
Expand Down Expand Up @@ -284,23 +284,24 @@ Rect BoundingRectangle(const std::vector<Pixel>& pixels) {
}

StatusOr<std::vector<ConnectedComponent>> FindCC(const ImageF& energy,
double t_low, double t_high,
const Rect& rect, double t_low,
double t_high,
uint32_t maxWindow,
double minScore) {
const int kExtraRect = 4;
JXL_ASSIGN_OR_RETURN(ImageF img,
ImageF::Create(energy.xsize(), energy.ysize()));
CopyImageTo(energy, &img);
std::vector<ConnectedComponent> ans;
for (size_t y = 0; y < img.ysize(); y++) {
float* JXL_RESTRICT row = img.Row(y);
for (size_t x = 0; x < img.xsize(); x++) {
for (size_t y = 0; y < rect.ysize(); y++) {
float* JXL_RESTRICT row = rect.Row(&img, y);
for (size_t x = 0; x < rect.xsize(); x++) {
if (row[x] > t_high) {
std::vector<Pixel> pixels;
row[x] = 0.0;
bool success = ExtractComponent(
&img, &pixels, Pixel{static_cast<int>(x), static_cast<int>(y)},
t_low);
rect, &img, &pixels,
Pixel{static_cast<int>(x), static_cast<int>(y)}, t_low);
if (!success) continue;
#if JXL_DEBUG_DOT_DETECT
for (size_t i = 0; i < pixels.size(); i++) {
Expand All @@ -311,7 +312,7 @@ StatusOr<std::vector<ConnectedComponent>> FindCC(const ImageF& energy,
Rect bounds = BoundingRectangle(pixels);
if (bounds.xsize() < maxWindow && bounds.ysize() < maxWindow) {
ConnectedComponent cc{bounds, std::move(pixels)};
cc.CompStats(energy, kExtraRect);
cc.CompStats(energy, rect, kExtraRect);
if (cc.score < minScore) continue;
JXL_DEBUG(JXL_DEBUG_DOT_DETECT,
"cc mode: (%d,%d), max: %f, bgMean: %f bgVar: "
Expand All @@ -330,7 +331,8 @@ StatusOr<std::vector<ConnectedComponent>> FindCC(const ImageF& energy,
// TODO(sggonzalez): Adapt this function for the different color spaces or
// remove it if the color space with the best performance does not need it
void ComputeDotLosses(GaussianEllipse* ellipse, const ConnectedComponent& cc,
const Image3F& img, const Image3F& background) {
const Rect& rect, const Image3F& img,
const Image3F& background) {
const int rectBounds = 2;
const double kIntensityR = 0.0; // 0.015;
const double kSigmaR = 0.0; // 0.01;
Expand All @@ -350,15 +352,15 @@ void ComputeDotLosses(GaussianEllipse* ellipse, const ConnectedComponent& cc,
for (int sy = -rectBounds;
sy < (static_cast<int>(cc.bounds.ysize()) + rectBounds); sy++) {
int y = sy + cc.bounds.y0();
if (y < 0 || static_cast<size_t>(y) >= img.ysize()) continue;
const float* JXL_RESTRICT row = img.ConstPlaneRow(c, y);
if (y < 0 || static_cast<size_t>(y) >= rect.ysize()) continue;
const float* JXL_RESTRICT row = rect.ConstPlaneRow(img, c, y);
// bgrow is only used if kOptimizeBackground is false.
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
const float* JXL_RESTRICT bgrow = background.ConstPlaneRow(c, y);
const float* JXL_RESTRICT bgrow = rect.ConstPlaneRow(background, c, y);
for (int sx = -rectBounds;
sx < (static_cast<int>(cc.bounds.xsize()) + rectBounds); sx++) {
int x = sx + cc.bounds.x0();
if (x < 0 || static_cast<size_t>(x) >= img.xsize()) continue;
if (x < 0 || static_cast<size_t>(x) >= rect.xsize()) continue;
double target = row[x];
double dotDelta = DotGaussianModel(
x - ellipse->x, y - ellipse->y, ct, st, ellipse->sigma_x,
Expand Down Expand Up @@ -393,9 +395,8 @@ void ComputeDotLosses(GaussianEllipse* ellipse, const ConnectedComponent& cc,
ellipse->ridge_loss = ellipse->l2_loss + ridgeTerm;
}

GaussianEllipse FitGaussianFast(const ConnectedComponent& cc,
const ImageF& energy, const Image3F& img,
const Image3F& background) {
GaussianEllipse FitGaussianFast(const ConnectedComponent& cc, const Rect& rect,
const Image3F& img, const Image3F& background) {
constexpr bool leastSqIntensity = true;
constexpr double kEpsilon = 1e-6;
GaussianEllipse ans;
Expand All @@ -413,18 +414,18 @@ GaussianEllipse FitGaussianFast(const ConnectedComponent& cc,
"%" PRIuS " %" PRIuS " %" PRIuS " %" PRIuS "\n", cc.bounds.x0(),
cc.bounds.y0(), cc.bounds.xsize(), cc.bounds.ysize());
for (int c = 0; c < 3; c++) {
color[c] = img.ConstPlaneRow(c, cc.mode.y)[cc.mode.x] -
background.ConstPlaneRow(c, cc.mode.y)[cc.mode.x];
color[c] = rect.ConstPlaneRow(img, c, cc.mode.y)[cc.mode.x] -
rect.ConstPlaneRow(background, c, cc.mode.y)[cc.mode.x];
}
double sign = (color[1] > 0) ? 1 : -1;
for (int sy = -kRectBounds; sy <= kRectBounds; sy++) {
int y = sy + cc.mode.y;
if (y < 0 || static_cast<size_t>(y) >= energy.ysize()) continue;
const float* JXL_RESTRICT row = img.ConstPlaneRow(1, y);
const float* JXL_RESTRICT bgrow = background.ConstPlaneRow(1, y);
if (y < 0 || static_cast<size_t>(y) >= rect.ysize()) continue;
const float* JXL_RESTRICT row = rect.ConstPlaneRow(img, 1, y);
const float* JXL_RESTRICT bgrow = rect.ConstPlaneRow(background, 1, y);
for (int sx = -kRectBounds; sx <= kRectBounds; sx++) {
int x = sx + cc.mode.x;
if (x < 0 || static_cast<size_t>(x) >= energy.xsize()) continue;
if (x < 0 || static_cast<size_t>(x) >= rect.xsize()) continue;
double w = std::max(kEpsilon, sign * (row[x] - bgrow[x]));
sum += w;

Expand All @@ -434,7 +435,7 @@ GaussianEllipse FitGaussianFast(const ConnectedComponent& cc,
m2[1] += w * x * y;
m2[2] += w * y * y;
for (int c = 0; c < 3; c++) {
bgColor[c] += background.ConstPlaneRow(c, y)[x];
bgColor[c] += rect.ConstPlaneRow(background, c, y)[x];
}
N++;
}
Expand Down Expand Up @@ -485,11 +486,11 @@ GaussianEllipse FitGaussianFast(const ConnectedComponent& cc,
int yc = static_cast<int>(cc.mode.y);
int xc = static_cast<int>(cc.mode.x);
for (int y = yc - kRectBounds; y <= yc + kRectBounds; y++) {
if (y < 0 || static_cast<size_t>(y) >= img.ysize()) continue;
const float* JXL_RESTRICT row = img.ConstPlaneRow(c, y);
const float* JXL_RESTRICT bgrow = background.ConstPlaneRow(c, y);
if (y < 0 || static_cast<size_t>(y) >= rect.ysize()) continue;
const float* JXL_RESTRICT row = rect.ConstPlaneRow(img, c, y);
const float* JXL_RESTRICT bgrow = rect.ConstPlaneRow(background, c, y);
for (int x = xc - kRectBounds; x <= xc + kRectBounds; x++) {
if (x < 0 || static_cast<size_t>(x) >= img.xsize()) continue;
if (x < 0 || static_cast<size_t>(x) >= rect.xsize()) continue;
double target = row[x] - bgrow[x];
double gaussian =
DotGaussianModel(x - ellipse->x, y - ellipse->y, ct, st,
Expand All @@ -501,13 +502,13 @@ GaussianEllipse FitGaussianFast(const ConnectedComponent& cc,
ans.intensity[c] = gd / (gg + 1e-6); // Regularized least squares
}
}
ComputeDotLosses(&ans, cc, img, background);
ComputeDotLosses(&ans, cc, rect, img, background);
return ans;
}

GaussianEllipse FitGaussian(const ConnectedComponent& cc, const ImageF& energy,
GaussianEllipse FitGaussian(const ConnectedComponent& cc, const Rect& rect,
const Image3F& img, const Image3F& background) {
auto ellipse = FitGaussianFast(cc, energy, img, background);
auto ellipse = FitGaussianFast(cc, rect, img, background);
if (ellipse.sigma_x < ellipse.sigma_y) {
std::swap(ellipse.sigma_x, ellipse.sigma_y);
ellipse.angle += kPi / 2.0;
Expand All @@ -534,14 +535,14 @@ GaussianEllipse FitGaussian(const ConnectedComponent& cc, const ImageF& energy,
} // namespace

StatusOr<std::vector<PatchInfo>> DetectGaussianEllipses(
const Image3F& opsin, const GaussianDetectParams& params,
const Image3F& opsin, const Rect& rect, const GaussianDetectParams& params,
const EllipseQuantParams& qParams, ThreadPool* pool) {
std::vector<PatchInfo> dots;
JXL_ASSIGN_OR_RETURN(Image3F smooth,
Image3F::Create(opsin.xsize(), opsin.ysize()));
JXL_ASSIGN_OR_RETURN(ImageF energy, ComputeEnergyImage(opsin, &smooth, pool));
JXL_ASSIGN_OR_RETURN(std::vector<ConnectedComponent> components,
FindCC(energy, params.t_low, params.t_high,
FindCC(energy, rect, params.t_low, params.t_high,
params.maxWinSize, params.minScore));
size_t numCC =
std::min(params.maxCC, (components.size() * params.percCC) / 100);
Expand All @@ -554,11 +555,11 @@ StatusOr<std::vector<PatchInfo>> DetectGaussianEllipses(
components.erase(components.begin() + numCC, components.end());
}
for (const auto& cc : components) {
GaussianEllipse ellipse = FitGaussian(cc, energy, opsin, smooth);
GaussianEllipse ellipse = FitGaussian(cc, rect, opsin, smooth);
if (ellipse.x < 0.0 ||
std::ceil(ellipse.x) >= static_cast<double>(opsin.xsize()) ||
std::ceil(ellipse.x) >= static_cast<double>(rect.xsize()) ||
ellipse.y < 0.0 ||
std::ceil(ellipse.y) >= static_cast<double>(opsin.ysize())) {
std::ceil(ellipse.y) >= static_cast<double>(rect.ysize())) {
continue;
}
if (ellipse.neg_pixels > params.maxNegPixels) continue;
Expand Down Expand Up @@ -586,8 +587,8 @@ StatusOr<std::vector<PatchInfo>> DetectGaussianEllipses(
for (size_t x = 0; x < patch.xsize; x++) {
for (size_t c = 0; c < 3; c++) {
patch.fpixels[c][y * patch.xsize + x] =
opsin.ConstPlaneRow(c, y0 + y)[x0 + x] -
smooth.ConstPlaneRow(c, y0 + y)[x0 + x];
rect.ConstPlaneRow(opsin, c, y0 + y)[x0 + x] -
rect.ConstPlaneRow(smooth, c, y0 + y)[x0 + x];
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion lib/jxl/enc_detect_dots.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct EllipseQuantParams {

// Detects dots in XYB image.
StatusOr<std::vector<PatchInfo>> DetectGaussianEllipses(
const Image3F& opsin, const GaussianDetectParams& params,
const Image3F& opsin, const Rect& rect, const GaussianDetectParams& params,
const EllipseQuantParams& qParams, ThreadPool* pool);

} // namespace jxl
Expand Down
6 changes: 3 additions & 3 deletions lib/jxl/enc_dot_dictionary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ const std::array<size_t, 3> kEllipseIntensityQ{{10, 36, 10}};
} // namespace

StatusOr<std::vector<PatchInfo>> FindDotDictionary(
const CompressParams& cparams, const Image3F& opsin,
const CompressParams& cparams, const Image3F& opsin, const Rect& rect,
const ColorCorrelationMap& cmap, ThreadPool* pool) {
if (ApplyOverride(cparams.dots,
cparams.butteraugli_distance >= kMinButteraugliForDots)) {
Expand All @@ -52,13 +52,13 @@ StatusOr<std::vector<PatchInfo>> FindDotDictionary(
ellipse_params.maxCC = 100;
ellipse_params.percCC = 100;
EllipseQuantParams qParams{
opsin.xsize(), opsin.ysize(), kEllipsePosQ,
rect.xsize(), rect.ysize(), kEllipsePosQ,
kEllipseMinSigma, kEllipseMaxSigma, kEllipseSigmaQ,
kEllipseAngleQ, kEllipseMinIntensity, kEllipseMaxIntensity,
kEllipseIntensityQ, kEllipsePosQ <= 5, cmap.YtoXRatio(0),
cmap.YtoBRatio(0)};

return DetectGaussianEllipses(opsin, ellipse_params, qParams, pool);
return DetectGaussianEllipses(opsin, rect, ellipse_params, qParams, pool);
}
std::vector<PatchInfo> nothing;
return nothing;
Expand Down
2 changes: 1 addition & 1 deletion lib/jxl/enc_dot_dictionary.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
namespace jxl {

StatusOr<std::vector<PatchInfo>> FindDotDictionary(
const CompressParams& cparams, const Image3F& opsin,
const CompressParams& cparams, const Image3F& opsin, const Rect& rect,
const ColorCorrelationMap& cmap, ThreadPool* pool);

} // namespace jxl
Expand Down
7 changes: 6 additions & 1 deletion lib/jxl/enc_patch_dictionary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "lib/jxl/base/common.h"
#include "lib/jxl/base/compiler_specific.h"
#include "lib/jxl/base/override.h"
#include "lib/jxl/base/printf_macros.h"
#include "lib/jxl/base/random.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/dec_cache.h"
Expand Down Expand Up @@ -586,7 +587,9 @@ Status FindBestPatchDictionary(const Image3F& opsin,
state->cparams.dots,
state->cparams.speed_tier <= SpeedTier::kSquirrel &&
state->cparams.butteraugli_distance >= kMinButteraugliForDots)) {
JXL_ASSIGN_OR_RETURN(info, FindDotDictionary(state->cparams, opsin,
Rect rect(0, 0, state->shared.frame_dim.xsize,
state->shared.frame_dim.ysize);
JXL_ASSIGN_OR_RETURN(info, FindDotDictionary(state->cparams, opsin, rect,
state->shared.cmap, pool));
}

Expand Down Expand Up @@ -716,6 +719,8 @@ Status FindBestPatchDictionary(const Image3F& opsin,
}
}
for (const auto& pos : info[i].second) {
JXL_DEBUG_V(4, "Patch %" PRIuS "x%" PRIuS " at position %u,%u",
ref_pos.xsize, ref_pos.ysize, pos.first, pos.second);
positions.emplace_back(
PatchPosition{pos.first, pos.second, pref_positions.size()});
// Add blending for color channels, ignore other channels.
Expand Down

0 comments on commit 5c6b57f

Please sign in to comment.