From e0adee8481623613933551e00adcd9ddea18d889 Mon Sep 17 00:00:00 2001 From: Peter Klausler Date: Sun, 29 May 2022 14:18:51 -0700 Subject: [PATCH] [flang] Correct folding of CSHIFT and EOSHIFT for DIM>1 The algorithm was wrong for higher dimensions, and so were the expected test results. Rework. Differential Revision: https://reviews.llvm.org/D127018 --- flang/lib/Evaluate/fold-implementation.h | 112 ++++++++++++++--------- flang/test/Evaluate/folding23.f90 | 4 +- flang/test/Evaluate/folding27.f90 | 4 +- 3 files changed, 71 insertions(+), 49 deletions(-) diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h index 97abfc688d6398..c0dee020f8fb85 100644 --- a/flang/lib/Evaluate/fold-implementation.h +++ b/flang/lib/Evaluate/fold-implementation.h @@ -613,26 +613,33 @@ template Expr Folder::CSHIFT(FunctionRef &&funcRef) { } if (ok) { std::vector> resultElements; - ConstantSubscripts arrayAt{array->lbounds()}; - ConstantSubscript dimLB{arrayAt[zbDim]}; + ConstantSubscripts arrayLB{array->lbounds()}; + ConstantSubscripts arrayAt{arrayLB}; + ConstantSubscript &dimIndex{arrayAt[zbDim]}; + ConstantSubscript dimLB{dimIndex}; // initial value ConstantSubscript dimExtent{array->shape()[zbDim]}; - ConstantSubscripts shiftAt{shift->lbounds()}; - for (auto n{GetSize(array->shape())}; n > 0; n -= dimExtent) { - ConstantSubscript shiftCount{shift->At(shiftAt).ToInt64()}; - ConstantSubscript zbDimIndex{shiftCount % dimExtent}; - if (zbDimIndex < 0) { - zbDimIndex += dimExtent; - } - for (ConstantSubscript j{0}; j < dimExtent; ++j) { - arrayAt[zbDim] = dimLB + zbDimIndex; - resultElements.push_back(array->At(arrayAt)); - if (++zbDimIndex == dimExtent) { - zbDimIndex = 0; + ConstantSubscripts shiftLB{shift->lbounds()}; + for (auto n{GetSize(array->shape())}; n > 0; --n) { + ConstantSubscript origDimIndex{dimIndex}; + ConstantSubscripts shiftAt; + if (shift->Rank() > 0) { + int k{0}; + for (int j{0}; j < rank; ++j) { + if (j != zbDim) { + shiftAt.emplace_back(shiftLB[k++] + arrayAt[j] - arrayLB[j]); + } } } - arrayAt[zbDim] = dimLB + std::max(dimExtent, 1) - 1; + ConstantSubscript shiftCount{shift->At(shiftAt).ToInt64()}; + dimIndex = dimLB + ((dimIndex - dimLB + shiftCount) % dimExtent); + if (dimIndex < dimLB) { + dimIndex += dimExtent; + } else if (dimIndex >= dimLB + dimExtent) { + dimIndex -= dimExtent; + } + resultElements.push_back(array->At(arrayAt)); + dimIndex = origDimIndex; array->IncrementSubscripts(arrayAt); - shift->IncrementSubscripts(shiftAt); } return Expr{PackageConstant( std::move(resultElements), *array, array->shape())}; @@ -714,42 +721,57 @@ template Expr Folder::EOSHIFT(FunctionRef &&funcRef) { } if (ok) { std::vector> resultElements; - ConstantSubscripts arrayAt{array->lbounds()}; - ConstantSubscript dimLB{arrayAt[zbDim]}; + ConstantSubscripts arrayLB{array->lbounds()}; + ConstantSubscripts arrayAt{arrayLB}; + ConstantSubscript &dimIndex{arrayAt[zbDim]}; + ConstantSubscript dimLB{dimIndex}; // initial value ConstantSubscript dimExtent{array->shape()[zbDim]}; - ConstantSubscripts shiftAt{shift->lbounds()}; - ConstantSubscripts boundaryAt; + ConstantSubscripts shiftLB{shift->lbounds()}; + ConstantSubscripts boundaryLB; if (boundary) { - boundaryAt = boundary->lbounds(); + boundaryLB = boundary->lbounds(); } - for (auto n{GetSize(array->shape())}; n > 0; n -= dimExtent) { + for (auto n{GetSize(array->shape())}; n > 0; --n) { + ConstantSubscript origDimIndex{dimIndex}; + ConstantSubscripts shiftAt; + if (shift->Rank() > 0) { + int k{0}; + for (int j{0}; j < rank; ++j) { + if (j != zbDim) { + shiftAt.emplace_back(shiftLB[k++] + arrayAt[j] - arrayLB[j]); + } + } + } ConstantSubscript shiftCount{shift->At(shiftAt).ToInt64()}; - for (ConstantSubscript j{0}; j < dimExtent; ++j) { - ConstantSubscript zbAt{shiftCount + j}; - if (zbAt >= 0 && zbAt < dimExtent) { - arrayAt[zbDim] = dimLB + zbAt; - resultElements.push_back(array->At(arrayAt)); - } else if (boundary) { - resultElements.push_back(boundary->At(boundaryAt)); - } else if constexpr (T::category == TypeCategory::Integer || - T::category == TypeCategory::Real || - T::category == TypeCategory::Complex || - T::category == TypeCategory::Logical) { - resultElements.emplace_back(); - } else if constexpr (T::category == TypeCategory::Character) { - auto len{static_cast(array->LEN())}; - typename Scalar::value_type space{' '}; - resultElements.emplace_back(len, space); - } else { - DIE("no derived type boundary"); + dimIndex += shiftCount; + if (dimIndex >= dimLB && dimIndex < dimLB + dimExtent) { + resultElements.push_back(array->At(arrayAt)); + } else if (boundary) { + ConstantSubscripts boundaryAt; + if (boundary->Rank() > 0) { + for (int j{0}; j < rank; ++j) { + int k{0}; + if (j != zbDim) { + boundaryAt.emplace_back( + boundaryLB[k++] + arrayAt[j] - arrayLB[j]); + } + } } + resultElements.push_back(boundary->At(boundaryAt)); + } else if constexpr (T::category == TypeCategory::Integer || + T::category == TypeCategory::Real || + T::category == TypeCategory::Complex || + T::category == TypeCategory::Logical) { + resultElements.emplace_back(); + } else if constexpr (T::category == TypeCategory::Character) { + auto len{static_cast(array->LEN())}; + typename Scalar::value_type space{' '}; + resultElements.emplace_back(len, space); + } else { + DIE("no derived type boundary"); } - arrayAt[zbDim] = dimLB + std::max(dimExtent, 1) - 1; + dimIndex = origDimIndex; array->IncrementSubscripts(arrayAt); - shift->IncrementSubscripts(shiftAt); - if (boundary) { - boundary->IncrementSubscripts(boundaryAt); - } } return Expr{PackageConstant( std::move(resultElements), *array, array->shape())}; diff --git a/flang/test/Evaluate/folding23.f90 b/flang/test/Evaluate/folding23.f90 index f31478ed3c5e5b..c25d2fc9398281 100644 --- a/flang/test/Evaluate/folding23.f90 +++ b/flang/test/Evaluate/folding23.f90 @@ -9,7 +9,7 @@ module m logical, parameter :: test_eoshift_3 = all(eoshift([1., 2., 3.], 1) == [2., 3., 0.]) logical, parameter :: test_eoshift_4 = all(eoshift(['ab', 'cd', 'ef'], -1, 'x') == ['x ', 'ab', 'cd']) logical, parameter :: test_eoshift_5 = all([eoshift(arr, 1, dim=1)] == [2, 0, 4, 0, 6, 0]) - logical, parameter :: test_eoshift_6 = all([eoshift(arr, 1, dim=2)] == [3, 5, 0, 4, 6, 0]) + logical, parameter :: test_eoshift_6 = all([eoshift(arr, 1, dim=2)] == [3, 4, 5, 6, 0, 0]) logical, parameter :: test_eoshift_7 = all([eoshift(arr, [1, -1, 0])] == [2, 0, 0, 3, 5, 6]) - logical, parameter :: test_eoshift_8 = all([eoshift(arr, [1, -1], dim=2)] == [3, 5, 0, 0, 2, 4]) + logical, parameter :: test_eoshift_8 = all([eoshift(arr, [1, -1], dim=2)] == [3, 0, 5, 2, 0, 4]) end module diff --git a/flang/test/Evaluate/folding27.f90 b/flang/test/Evaluate/folding27.f90 index 0d3d333c0f1000..43699184f31aa4 100644 --- a/flang/test/Evaluate/folding27.f90 +++ b/flang/test/Evaluate/folding27.f90 @@ -9,7 +9,7 @@ module m logical, parameter :: test_cshift_3 = all(cshift([1, 2, 3], 4) == [2, 3, 1]) logical, parameter :: test_cshift_4 = all(cshift([1, 2, 3], -1) == [3, 1, 2]) logical, parameter :: test_cshift_5 = all([cshift(arr, 1, dim=1)] == [2, 1, 4, 3, 6, 5]) - logical, parameter :: test_cshift_6 = all([cshift(arr, 1, dim=2)] == [3, 5, 1, 4, 6, 2]) + logical, parameter :: test_cshift_6 = all([cshift(arr, 1, dim=2)] == [3, 4, 5, 6, 1, 2]) logical, parameter :: test_cshift_7 = all([cshift(arr, [1, 2, 3])] == [2, 1, 3, 4, 6, 5]) - logical, parameter :: test_cshift_8 = all([cshift(arr, [1, 2], dim=2)] == [3, 5, 1, 6, 2, 4]) + logical, parameter :: test_cshift_8 = all([cshift(arr, [1, 2], dim=2)] == [3, 6, 5, 2, 1, 4]) end module