Skip to content

Commit

Permalink
Buffer overrun detected with asan (pytorch#2218)
Browse files Browse the repository at this point in the history
Also changed [i] to .at(i)
  • Loading branch information
naoyam committed Nov 28, 2022
1 parent 3146e89 commit 7dfb553
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions torch/csrc/jit/codegen/cuda/transform_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ class AnalyzeViewTransformation {
int64_t original_view_index = 0;
int64_t transform_view_index = 0;
int64_t new_view_index = 0;
int64_t current_size = original_view_[0];
int64_t current_size = original_view_.at(0);

// Safety counters to make sure we don't end up in an infinite loop.
int64_t prev_original_view_index = std::numeric_limits<int64_t>::max();
Expand Down Expand Up @@ -508,10 +508,10 @@ class AnalyzeViewTransformation {
"View is complete, but there's still some elements to distribute.");
}

if ((new_view_index == new_view_.size() ||
(new_view_[new_view_index + 1] != 1)) &&
if ((new_view_index + 1 >= new_view_.size() ||
(new_view_.at(new_view_index + 1) != 1)) &&
original_view_index + 1 < original_view_.size() &&
original_view_[original_view_index + 1] == 1 &&
original_view_.at(original_view_index + 1) == 1 &&
!isImplicitBroadcast(original_view_index + 1)) {
// Next index in original_view is runtime size 1 and next new view is
// not, merge the size 1 into the current view before moving on. Even if
Expand All @@ -525,7 +525,7 @@ class AnalyzeViewTransformation {

if (new_view_index < new_view_.size() &&
// Still new dimensions to resolve and current size does resolve it.
current_size == new_view_[new_view_index]) {
current_size == new_view_.at(new_view_index)) {
// Keep this dimension, it's good to go, we hit a boundary where there's
// a multiple of original dims, that matches a multiple of view dims.
// Increment state and keep going.
Expand All @@ -536,7 +536,7 @@ class AnalyzeViewTransformation {

// Update current_size with the next size in original view
if (original_view_index < original_view_.size()) {
current_size = original_view_[original_view_index];
current_size = original_view_.at(original_view_index);
} else {
current_size = 0;
}
Expand All @@ -547,7 +547,8 @@ class AnalyzeViewTransformation {
// view. Insert broadcast and increment new_view. Size 1 dimensions in
// new_view that don't match up with runtime size 1's in original view are
// assumed to be broadcast (not a split from a runtime domain).
if (new_view_index < new_view_.size() && new_view_[new_view_index] == 1) {
if (new_view_index < new_view_.size() &&
new_view_.at(new_view_index) == 1) {
broadcast_transforms_.push_back(
std::make_shared<BroadcastTransform>(new_view_index));
++new_view_index;
Expand All @@ -571,7 +572,7 @@ class AnalyzeViewTransformation {

// Update original position and current size.
if (original_view_index < original_view_.size()) {
current_size = original_view_[original_view_index];
current_size = original_view_.at(original_view_index);
} else {
current_size = 0;
}
Expand All @@ -597,11 +598,11 @@ class AnalyzeViewTransformation {
"Expecting to still have new dimensions to work on in view, but none left.");

if (new_view_index < new_view_.size() &&
current_size % new_view_[new_view_index] == 0) {
current_size % new_view_.at(new_view_index) == 0) {
// Insert split to generate the next new_view domain.
view_transforms_.push_back(std::make_shared<SplitTransform>(
transform_view_index, new_view_[new_view_index]));
current_size /= new_view_[new_view_index];
transform_view_index, new_view_.at(new_view_index)));
current_size /= new_view_.at(new_view_index);
TORCH_INTERNAL_ASSERT(current_size > 1, "This should be unreachable.");
// Update transform and new since a split doesn't increment from the
// original domain we're working on.
Expand All @@ -618,7 +619,7 @@ class AnalyzeViewTransformation {

view_transforms_.push_back(
std::make_shared<MergeTransform>(transform_view_index));
current_size *= original_view_[++original_view_index];
current_size *= original_view_.at(++original_view_index);
}
}

Expand Down Expand Up @@ -655,7 +656,7 @@ TensorDomain* createViewDomain(
// Apply squeeze.
for (auto id_i : c10::irange(orig_root_domain.size())) {
if (!view_analysis.squeeze_axes.at(id_i)) {
auto id = orig_root_domain[id_i];
auto id = orig_root_domain.at(id_i);
new_root_domain.push_back(id->cloneWithoutRFactor());
continue;
}
Expand Down Expand Up @@ -695,21 +696,21 @@ std::pair<std::vector<int64_t>, std::vector<int64_t>> inferViewShapes(
int64_t dynamic_index = -1;
int64_t new_size_num_elements = 1;
for (int64_t idx = 0; idx < new_sizes.size(); ++idx) {
if (new_sizes[idx] == -1) {
if (new_sizes.at(idx) == -1) {
TORCH_INTERNAL_ASSERT(
dynamic_index == -1, "Only one dimension can by inferred.")
dynamic_index = idx;
} else {
TORCH_INTERNAL_ASSERT(new_sizes[idx] > 0);
new_size_num_elements *= new_sizes[idx];
new_view[idx] = new_sizes[idx];
TORCH_INTERNAL_ASSERT(new_sizes.at(idx) > 0);
new_size_num_elements *= new_sizes.at(idx);
new_view.at(idx) = new_sizes.at(idx);
}
}

const int64_t kNumElements = std::accumulate(
original_view.begin(), original_view.end(), 1, std::multiplies<>());
if (dynamic_index != -1) {
new_view[dynamic_index] = kNumElements / new_size_num_elements;
new_view.at(dynamic_index) = kNumElements / new_size_num_elements;
}

return {original_view, new_view};
Expand Down

0 comments on commit 7dfb553

Please sign in to comment.