Skip to content

Commit

Permalink
[flang] Correct folding of SPREAD() for higher ranks
Browse files Browse the repository at this point in the history
The construction of the dimension order vector used to populate the
result array was incorrect, leading to a scrambled-looking result
for rank-3 and higher results.  Fix, and extend tests.

Differential Revision: https://reviews.llvm.org/D125113
  • Loading branch information
klausler committed May 9, 2022
1 parent 9641b9b commit 85fdbc1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 2 additions & 2 deletions flang/lib/Evaluate/fold-implementation.h
Expand Up @@ -890,9 +890,9 @@ template <typename T> Expr<T> Folder<T>::SPREAD(FunctionRef<T> &&funcRef) {
Constant<T> spread{source->Reshape(std::move(shape))};
std::vector<int> dimOrder;
for (int j{0}; j < sourceRank; ++j) {
dimOrder.push_back(j);
dimOrder.push_back(j < *dim - 1 ? j : j + 1);
}
dimOrder.insert(dimOrder.begin() + *dim - 1, sourceRank);
dimOrder.push_back(*dim - 1);
ConstantSubscripts at{spread.lbounds()}; // all 1
spread.CopyFrom(*source, TotalElementCount(spread.shape()), at, &dimOrder);
return Expr<T>{std::move(spread)};
Expand Down
4 changes: 3 additions & 1 deletion flang/test/Evaluate/fold-spread.f90
Expand Up @@ -5,9 +5,11 @@ module m1
logical, parameter :: test_stov = all(spread(1, 1, 2) == [1, 1])
logical, parameter :: test_vtom1 = all(spread([1, 2], 1, 3) == reshape([1, 1, 1, 2, 2, 2], [3, 2]))
logical, parameter :: test_vtom2 = all(spread([1, 2], 2, 3) == reshape([1, 2, 1, 2, 1, 2], [2, 3]))
logical, parameter :: test_vtom3 = all(spread([1, 2], 2, 3) == reshape([1, 2, 1, 2, 1, 2], [2, 3]))
logical, parameter :: test_vtom3 = all(spread([1, 2], 1, 3) == reshape([1, 1, 1, 2, 2, 2], [3, 2]))
logical, parameter :: test_log1 = all(all(spread([.false., .true.], 1, 2), dim=2) .eqv. [.false., .false.])
logical, parameter :: test_log2 = all(all(spread([.false., .true.], 2, 2), dim=2) .eqv. [.false., .true.])
logical, parameter :: test_log3 = all(any(spread([.false., .true.], 1, 2), dim=2) .eqv. [.true., .true.])
logical, parameter :: test_log4 = all(any(spread([.false., .true.], 2, 2), dim=2) .eqv. [.false., .true.])
logical, parameter :: test_m2toa3 = all(spread(reshape([(j,j=1,6)],[2,3]),1,4) == &
reshape([((j,k=1,4),j=1,6)],[4,2,3]))
end module

0 comments on commit 85fdbc1

Please sign in to comment.