Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix roundtrip failure for images with dots near the border. #3367

Merged
merged 1 commit into from Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
89 changes: 45 additions & 44 deletions lib/jxl/enc_detect_dots.cc
Expand Up @@ -178,7 +178,7 @@ const size_t kMaxCCSize = 1000;

// Extracts a connected component from a Binary image where seed is part
// of the component
bool ExtractComponent(ImageF* img, std::vector<Pixel>* pixels,
bool ExtractComponent(const Rect& rect, ImageF* img, std::vector<Pixel>* pixels,
const Pixel& seed, double threshold) {
static const std::vector<Pixel> neighbors{{1, -1}, {1, 0}, {1, 1}, {0, -1},
{0, 1}, {-1, -1}, {-1, 1}, {1, 0}};
Expand All @@ -190,9 +190,9 @@ bool ExtractComponent(ImageF* img, std::vector<Pixel>* pixels,
if (pixels->size() > kMaxCCSize) return false;
for (const Pixel& delta : neighbors) {
Pixel child = current + delta;
if (child.x >= 0 && static_cast<size_t>(child.x) < img->xsize() &&
child.y >= 0 && static_cast<size_t>(child.y) < img->ysize()) {
float* value = &img->Row(child.y)[child.x];
if (child.x >= 0 && static_cast<size_t>(child.x) < rect.xsize() &&
child.y >= 0 && static_cast<size_t>(child.y) < rect.ysize()) {
float* value = &rect.Row(img, child.y)[child.x];
if (*value > threshold) {
*value = 0.0;
q.push_back(child);
Expand Down Expand Up @@ -223,7 +223,7 @@ struct ConnectedComponent {
float score;
Pixel mode;

void CompStats(const ImageF& energy, int extra) {
void CompStats(const ImageF& energy, const Rect& rect, int extra) {
maxEnergy = 0.0;
meanEnergy = 0.0;
varEnergy = 0.0;
Expand All @@ -236,12 +236,12 @@ struct ConnectedComponent {
for (int sy = -extra; sy < (static_cast<int>(bounds.ysize()) + extra);
sy++) {
int y = sy + static_cast<int>(bounds.y0());
if (y < 0 || static_cast<size_t>(y) >= energy.ysize()) continue;
const float* JXL_RESTRICT erow = energy.ConstRow(y);
if (y < 0 || static_cast<size_t>(y) >= rect.ysize()) continue;
const float* JXL_RESTRICT erow = rect.ConstRow(energy, y);
for (int sx = -extra; sx < (static_cast<int>(bounds.xsize()) + extra);
sx++) {
int x = sx + static_cast<int>(bounds.x0());
if (x < 0 || static_cast<size_t>(x) >= energy.xsize()) continue;
if (x < 0 || static_cast<size_t>(x) >= rect.xsize()) continue;
if (erow[x] > maxEnergy) {
maxEnergy = erow[x];
mode.x = x;
Expand Down Expand Up @@ -284,23 +284,24 @@ Rect BoundingRectangle(const std::vector<Pixel>& pixels) {
}

StatusOr<std::vector<ConnectedComponent>> FindCC(const ImageF& energy,
double t_low, double t_high,
const Rect& rect, double t_low,
double t_high,
uint32_t maxWindow,
double minScore) {
const int kExtraRect = 4;
JXL_ASSIGN_OR_RETURN(ImageF img,
ImageF::Create(energy.xsize(), energy.ysize()));
CopyImageTo(energy, &img);
std::vector<ConnectedComponent> ans;
for (size_t y = 0; y < img.ysize(); y++) {
float* JXL_RESTRICT row = img.Row(y);
for (size_t x = 0; x < img.xsize(); x++) {
for (size_t y = 0; y < rect.ysize(); y++) {
float* JXL_RESTRICT row = rect.Row(&img, y);
for (size_t x = 0; x < rect.xsize(); x++) {
if (row[x] > t_high) {
std::vector<Pixel> pixels;
row[x] = 0.0;
bool success = ExtractComponent(
&img, &pixels, Pixel{static_cast<int>(x), static_cast<int>(y)},
t_low);
rect, &img, &pixels,
Pixel{static_cast<int>(x), static_cast<int>(y)}, t_low);
if (!success) continue;
#if JXL_DEBUG_DOT_DETECT
for (size_t i = 0; i < pixels.size(); i++) {
Expand All @@ -311,7 +312,7 @@ StatusOr<std::vector<ConnectedComponent>> FindCC(const ImageF& energy,
Rect bounds = BoundingRectangle(pixels);
if (bounds.xsize() < maxWindow && bounds.ysize() < maxWindow) {
ConnectedComponent cc{bounds, std::move(pixels)};
cc.CompStats(energy, kExtraRect);
cc.CompStats(energy, rect, kExtraRect);
if (cc.score < minScore) continue;
JXL_DEBUG(JXL_DEBUG_DOT_DETECT,
"cc mode: (%d,%d), max: %f, bgMean: %f bgVar: "
Expand All @@ -330,7 +331,8 @@ StatusOr<std::vector<ConnectedComponent>> FindCC(const ImageF& energy,
// TODO(sggonzalez): Adapt this function for the different color spaces or
// remove it if the color space with the best performance does not need it
void ComputeDotLosses(GaussianEllipse* ellipse, const ConnectedComponent& cc,
const Image3F& img, const Image3F& background) {
const Rect& rect, const Image3F& img,
const Image3F& background) {
const int rectBounds = 2;
const double kIntensityR = 0.0; // 0.015;
const double kSigmaR = 0.0; // 0.01;
Expand All @@ -350,15 +352,15 @@ void ComputeDotLosses(GaussianEllipse* ellipse, const ConnectedComponent& cc,
for (int sy = -rectBounds;
sy < (static_cast<int>(cc.bounds.ysize()) + rectBounds); sy++) {
int y = sy + cc.bounds.y0();
if (y < 0 || static_cast<size_t>(y) >= img.ysize()) continue;
const float* JXL_RESTRICT row = img.ConstPlaneRow(c, y);
if (y < 0 || static_cast<size_t>(y) >= rect.ysize()) continue;
const float* JXL_RESTRICT row = rect.ConstPlaneRow(img, c, y);
// bgrow is only used if kOptimizeBackground is false.
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
const float* JXL_RESTRICT bgrow = background.ConstPlaneRow(c, y);
const float* JXL_RESTRICT bgrow = rect.ConstPlaneRow(background, c, y);
for (int sx = -rectBounds;
sx < (static_cast<int>(cc.bounds.xsize()) + rectBounds); sx++) {
int x = sx + cc.bounds.x0();
if (x < 0 || static_cast<size_t>(x) >= img.xsize()) continue;
if (x < 0 || static_cast<size_t>(x) >= rect.xsize()) continue;
double target = row[x];
double dotDelta = DotGaussianModel(
x - ellipse->x, y - ellipse->y, ct, st, ellipse->sigma_x,
Expand Down Expand Up @@ -393,9 +395,8 @@ void ComputeDotLosses(GaussianEllipse* ellipse, const ConnectedComponent& cc,
ellipse->ridge_loss = ellipse->l2_loss + ridgeTerm;
}

GaussianEllipse FitGaussianFast(const ConnectedComponent& cc,
const ImageF& energy, const Image3F& img,
const Image3F& background) {
GaussianEllipse FitGaussianFast(const ConnectedComponent& cc, const Rect& rect,
const Image3F& img, const Image3F& background) {
constexpr bool leastSqIntensity = true;
constexpr double kEpsilon = 1e-6;
GaussianEllipse ans;
Expand All @@ -413,18 +414,18 @@ GaussianEllipse FitGaussianFast(const ConnectedComponent& cc,
"%" PRIuS " %" PRIuS " %" PRIuS " %" PRIuS "\n", cc.bounds.x0(),
cc.bounds.y0(), cc.bounds.xsize(), cc.bounds.ysize());
for (int c = 0; c < 3; c++) {
color[c] = img.ConstPlaneRow(c, cc.mode.y)[cc.mode.x] -
background.ConstPlaneRow(c, cc.mode.y)[cc.mode.x];
color[c] = rect.ConstPlaneRow(img, c, cc.mode.y)[cc.mode.x] -
rect.ConstPlaneRow(background, c, cc.mode.y)[cc.mode.x];
}
double sign = (color[1] > 0) ? 1 : -1;
for (int sy = -kRectBounds; sy <= kRectBounds; sy++) {
int y = sy + cc.mode.y;
if (y < 0 || static_cast<size_t>(y) >= energy.ysize()) continue;
const float* JXL_RESTRICT row = img.ConstPlaneRow(1, y);
const float* JXL_RESTRICT bgrow = background.ConstPlaneRow(1, y);
if (y < 0 || static_cast<size_t>(y) >= rect.ysize()) continue;
const float* JXL_RESTRICT row = rect.ConstPlaneRow(img, 1, y);
const float* JXL_RESTRICT bgrow = rect.ConstPlaneRow(background, 1, y);
for (int sx = -kRectBounds; sx <= kRectBounds; sx++) {
int x = sx + cc.mode.x;
if (x < 0 || static_cast<size_t>(x) >= energy.xsize()) continue;
if (x < 0 || static_cast<size_t>(x) >= rect.xsize()) continue;
double w = std::max(kEpsilon, sign * (row[x] - bgrow[x]));
sum += w;

Expand All @@ -434,7 +435,7 @@ GaussianEllipse FitGaussianFast(const ConnectedComponent& cc,
m2[1] += w * x * y;
m2[2] += w * y * y;
for (int c = 0; c < 3; c++) {
bgColor[c] += background.ConstPlaneRow(c, y)[x];
bgColor[c] += rect.ConstPlaneRow(background, c, y)[x];
}
N++;
}
Expand Down Expand Up @@ -485,11 +486,11 @@ GaussianEllipse FitGaussianFast(const ConnectedComponent& cc,
int yc = static_cast<int>(cc.mode.y);
int xc = static_cast<int>(cc.mode.x);
for (int y = yc - kRectBounds; y <= yc + kRectBounds; y++) {
if (y < 0 || static_cast<size_t>(y) >= img.ysize()) continue;
const float* JXL_RESTRICT row = img.ConstPlaneRow(c, y);
const float* JXL_RESTRICT bgrow = background.ConstPlaneRow(c, y);
if (y < 0 || static_cast<size_t>(y) >= rect.ysize()) continue;
const float* JXL_RESTRICT row = rect.ConstPlaneRow(img, c, y);
const float* JXL_RESTRICT bgrow = rect.ConstPlaneRow(background, c, y);
for (int x = xc - kRectBounds; x <= xc + kRectBounds; x++) {
if (x < 0 || static_cast<size_t>(x) >= img.xsize()) continue;
if (x < 0 || static_cast<size_t>(x) >= rect.xsize()) continue;
double target = row[x] - bgrow[x];
double gaussian =
DotGaussianModel(x - ellipse->x, y - ellipse->y, ct, st,
Expand All @@ -501,13 +502,13 @@ GaussianEllipse FitGaussianFast(const ConnectedComponent& cc,
ans.intensity[c] = gd / (gg + 1e-6); // Regularized least squares
}
}
ComputeDotLosses(&ans, cc, img, background);
ComputeDotLosses(&ans, cc, rect, img, background);
return ans;
}

GaussianEllipse FitGaussian(const ConnectedComponent& cc, const ImageF& energy,
GaussianEllipse FitGaussian(const ConnectedComponent& cc, const Rect& rect,
const Image3F& img, const Image3F& background) {
auto ellipse = FitGaussianFast(cc, energy, img, background);
auto ellipse = FitGaussianFast(cc, rect, img, background);
if (ellipse.sigma_x < ellipse.sigma_y) {
std::swap(ellipse.sigma_x, ellipse.sigma_y);
ellipse.angle += kPi / 2.0;
Expand All @@ -534,14 +535,14 @@ GaussianEllipse FitGaussian(const ConnectedComponent& cc, const ImageF& energy,
} // namespace

StatusOr<std::vector<PatchInfo>> DetectGaussianEllipses(
const Image3F& opsin, const GaussianDetectParams& params,
const Image3F& opsin, const Rect& rect, const GaussianDetectParams& params,
const EllipseQuantParams& qParams, ThreadPool* pool) {
std::vector<PatchInfo> dots;
JXL_ASSIGN_OR_RETURN(Image3F smooth,
Image3F::Create(opsin.xsize(), opsin.ysize()));
JXL_ASSIGN_OR_RETURN(ImageF energy, ComputeEnergyImage(opsin, &smooth, pool));
JXL_ASSIGN_OR_RETURN(std::vector<ConnectedComponent> components,
FindCC(energy, params.t_low, params.t_high,
FindCC(energy, rect, params.t_low, params.t_high,
params.maxWinSize, params.minScore));
size_t numCC =
std::min(params.maxCC, (components.size() * params.percCC) / 100);
Expand All @@ -554,11 +555,11 @@ StatusOr<std::vector<PatchInfo>> DetectGaussianEllipses(
components.erase(components.begin() + numCC, components.end());
}
for (const auto& cc : components) {
GaussianEllipse ellipse = FitGaussian(cc, energy, opsin, smooth);
GaussianEllipse ellipse = FitGaussian(cc, rect, opsin, smooth);
if (ellipse.x < 0.0 ||
std::ceil(ellipse.x) >= static_cast<double>(opsin.xsize()) ||
std::ceil(ellipse.x) >= static_cast<double>(rect.xsize()) ||
ellipse.y < 0.0 ||
std::ceil(ellipse.y) >= static_cast<double>(opsin.ysize())) {
std::ceil(ellipse.y) >= static_cast<double>(rect.ysize())) {
continue;
}
if (ellipse.neg_pixels > params.maxNegPixels) continue;
Expand Down Expand Up @@ -586,8 +587,8 @@ StatusOr<std::vector<PatchInfo>> DetectGaussianEllipses(
for (size_t x = 0; x < patch.xsize; x++) {
for (size_t c = 0; c < 3; c++) {
patch.fpixels[c][y * patch.xsize + x] =
opsin.ConstPlaneRow(c, y0 + y)[x0 + x] -
smooth.ConstPlaneRow(c, y0 + y)[x0 + x];
rect.ConstPlaneRow(opsin, c, y0 + y)[x0 + x] -
rect.ConstPlaneRow(smooth, c, y0 + y)[x0 + x];
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion lib/jxl/enc_detect_dots.h
Expand Up @@ -58,7 +58,7 @@ struct EllipseQuantParams {

// Detects dots in XYB image.
StatusOr<std::vector<PatchInfo>> DetectGaussianEllipses(
const Image3F& opsin, const GaussianDetectParams& params,
const Image3F& opsin, const Rect& rect, const GaussianDetectParams& params,
const EllipseQuantParams& qParams, ThreadPool* pool);

} // namespace jxl
Expand Down
6 changes: 3 additions & 3 deletions lib/jxl/enc_dot_dictionary.cc
Expand Up @@ -35,7 +35,7 @@ const std::array<size_t, 3> kEllipseIntensityQ{{10, 36, 10}};
} // namespace

StatusOr<std::vector<PatchInfo>> FindDotDictionary(
const CompressParams& cparams, const Image3F& opsin,
const CompressParams& cparams, const Image3F& opsin, const Rect& rect,
const ColorCorrelationMap& cmap, ThreadPool* pool) {
if (ApplyOverride(cparams.dots,
cparams.butteraugli_distance >= kMinButteraugliForDots)) {
Expand All @@ -52,13 +52,13 @@ StatusOr<std::vector<PatchInfo>> FindDotDictionary(
ellipse_params.maxCC = 100;
ellipse_params.percCC = 100;
EllipseQuantParams qParams{
opsin.xsize(), opsin.ysize(), kEllipsePosQ,
rect.xsize(), rect.ysize(), kEllipsePosQ,
kEllipseMinSigma, kEllipseMaxSigma, kEllipseSigmaQ,
kEllipseAngleQ, kEllipseMinIntensity, kEllipseMaxIntensity,
kEllipseIntensityQ, kEllipsePosQ <= 5, cmap.YtoXRatio(0),
cmap.YtoBRatio(0)};

return DetectGaussianEllipses(opsin, ellipse_params, qParams, pool);
return DetectGaussianEllipses(opsin, rect, ellipse_params, qParams, pool);
}
std::vector<PatchInfo> nothing;
return nothing;
Expand Down
2 changes: 1 addition & 1 deletion lib/jxl/enc_dot_dictionary.h
Expand Up @@ -22,7 +22,7 @@
namespace jxl {

StatusOr<std::vector<PatchInfo>> FindDotDictionary(
const CompressParams& cparams, const Image3F& opsin,
const CompressParams& cparams, const Image3F& opsin, const Rect& rect,
const ColorCorrelationMap& cmap, ThreadPool* pool);

} // namespace jxl
Expand Down
7 changes: 6 additions & 1 deletion lib/jxl/enc_patch_dictionary.cc
Expand Up @@ -17,6 +17,7 @@
#include "lib/jxl/base/common.h"
#include "lib/jxl/base/compiler_specific.h"
#include "lib/jxl/base/override.h"
#include "lib/jxl/base/printf_macros.h"
#include "lib/jxl/base/random.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/dec_cache.h"
Expand Down Expand Up @@ -586,7 +587,9 @@ Status FindBestPatchDictionary(const Image3F& opsin,
state->cparams.dots,
state->cparams.speed_tier <= SpeedTier::kSquirrel &&
state->cparams.butteraugli_distance >= kMinButteraugliForDots)) {
JXL_ASSIGN_OR_RETURN(info, FindDotDictionary(state->cparams, opsin,
Rect rect(0, 0, state->shared.frame_dim.xsize,
state->shared.frame_dim.ysize);
JXL_ASSIGN_OR_RETURN(info, FindDotDictionary(state->cparams, opsin, rect,
state->shared.cmap, pool));
}

Expand Down Expand Up @@ -716,6 +719,8 @@ Status FindBestPatchDictionary(const Image3F& opsin,
}
}
for (const auto& pos : info[i].second) {
JXL_DEBUG_V(4, "Patch %" PRIuS "x%" PRIuS " at position %u,%u",
ref_pos.xsize, ref_pos.ysize, pos.first, pos.second);
positions.emplace_back(
PatchPosition{pos.first, pos.second, pref_positions.size()});
// Add blending for color channels, ignore other channels.
Expand Down