Skip to content

Commit

Permalink
Check height limit in modular trees. (libjxl#3943)
Browse files Browse the repository at this point in the history
Also rewrite the implementation to use iterative checking instead of
recursive checking of tree property values, to ensure stack usage is
low.

Before, it was possible for appropriately-crafted files to use a
significant amount of stack (in the order of hundreds of MB).

(cherry picked from commit bf4781a)
  • Loading branch information
veluca93 authored and mo271 committed Nov 26, 2024
1 parent 04aee7f commit 5480afe
Showing 1 changed file with 45 additions and 21 deletions.
66 changes: 45 additions & 21 deletions lib/jxl/modular/encoding/dec_ma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "lib/jxl/modular/encoding/dec_ma.h"

#include <limits>
#include <vector>

#include "lib/jxl/base/printf_macros.h"
#include "lib/jxl/dec_ans.h"
Expand All @@ -17,23 +18,49 @@ namespace jxl {

namespace {

Status ValidateTree(
const Tree &tree,
const std::vector<std::pair<pixel_type, pixel_type>> &prop_bounds,
size_t root) {
if (tree[root].property == -1) return true;
size_t p = tree[root].property;
int val = tree[root].splitval;
if (prop_bounds[p].first > val) return JXL_FAILURE("Invalid tree");
// Splitting at max value makes no sense: left range will be exactly same
// as parent, right range will be invalid (min > max).
if (prop_bounds[p].second <= val) return JXL_FAILURE("Invalid tree");
auto new_bounds = prop_bounds;
new_bounds[p].first = val + 1;
JXL_RETURN_IF_ERROR(ValidateTree(tree, new_bounds, tree[root].lchild));
new_bounds[p] = prop_bounds[p];
new_bounds[p].second = val;
return ValidateTree(tree, new_bounds, tree[root].rchild);
Status ValidateTree(const Tree &tree) {
int num_properties = 0;
for (auto node : tree) {
if (node.property >= num_properties) {
num_properties = node.property + 1;
}
}
std::vector<int> height(tree.size());
std::vector<std::pair<pixel_type, pixel_type>> property_ranges(
num_properties * tree.size());
for (int i = 0; i < num_properties; i++) {
property_ranges[i].first = std::numeric_limits<pixel_type>::min();
property_ranges[i].second = std::numeric_limits<pixel_type>::max();
}
const int kHeightLimit = 2048;
for (size_t i = 0; i < tree.size(); i++) {
if (height[i] > kHeightLimit) {
return JXL_FAILURE("Tree too tall: %d", height[i]);
}
if (tree[i].property == -1) continue;
height[tree[i].lchild] = height[i] + 1;
height[tree[i].rchild] = height[i] + 1;
for (size_t p = 0; p < static_cast<size_t>(num_properties); p++) {
if (p == static_cast<size_t>(tree[i].property)) {
pixel_type l = property_ranges[i * num_properties + p].first;
pixel_type u = property_ranges[i * num_properties + p].second;
pixel_type val = tree[i].splitval;
if (l > val || u <= val) {
return JXL_FAILURE("Invalid tree");
}
property_ranges[tree[i].lchild * num_properties + p] =
std::make_pair(val + 1, u);
property_ranges[tree[i].rchild * num_properties + p] =
std::make_pair(l, val);
} else {
property_ranges[tree[i].lchild * num_properties + p] =
property_ranges[i * num_properties + p];
property_ranges[tree[i].rchild * num_properties + p] =
property_ranges[i * num_properties + p];
}
}
}
return true;
}

Status DecodeTree(BitReader *br, ANSSymbolReader *reader,
Expand Down Expand Up @@ -82,10 +109,7 @@ Status DecodeTree(BitReader *br, ANSSymbolReader *reader,
tree->size() + to_decode + 2, Predictor::Zero, 0, 1);
to_decode += 2;
}
std::vector<std::pair<pixel_type, pixel_type>> prop_bounds;
prop_bounds.resize(256, {std::numeric_limits<pixel_type>::min(),
std::numeric_limits<pixel_type>::max()});
return ValidateTree(*tree, prop_bounds, 0);
return ValidateTree(*tree);
}
} // namespace

Expand Down

0 comments on commit 5480afe

Please sign in to comment.