Skip to content

Commit

Permalink
Add squeeze(f, A, dims) for reductions to drop dims
Browse files Browse the repository at this point in the history
This simple definition makes it easier to write reductions that drops the dimensions over which they reduce. Fixes JuliaLang#16606, addresses part of the root issue in JuliaLang#22000.
  • Loading branch information
mbauman authored and nickrobinson251 committed Aug 31, 2019
1 parent 8f575c5 commit d9cbea9
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
8 changes: 8 additions & 0 deletions base/abstractarraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,16 @@ function _dropdims(A::AbstractArray, dims::Dims)
end
reshape(A, d::typeof(_sub(axes(A), dims)))
end

_dropdims(A::AbstractArray, dim::Integer) = _dropdims(A, (Int(dim),))

"""
squeeze(f, A, dims)
Compute reduction `f` over dimensions `dims` in array `A` and drop those dimensions from the result.
"""
squeeze(f, A::AbstractArray, dims::Union{Dims, Integer}) = squeeze(f(A, dims), dims)

## Unary operators ##

conj(x::AbstractArray{<:Real}) = x
Expand Down
11 changes: 11 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,17 @@ end
@test_throws ArgumentError dropdims(a, dims=4)
@test_throws ArgumentError dropdims(a, dims=6)

@test @inferred(squeeze(sum, a, 1)) == @inferred(squeeze(sum, a, (1,))) == reshape(sum(a, 1), (1, 8, 8, 1))
@test @inferred(squeeze(sum, a, 3)) == @inferred(squeeze(sum, a, (3,))) == reshape(sum(a, 3), (1, 1, 8, 1))
@test @inferred(squeeze(sum, a, 4)) == @inferred(squeeze(sum, a, (4,))) == reshape(sum(a, 4), (1, 1, 8, 1))
@test @inferred(squeeze(sum, a, (1, 5))) == squeeze(sum, a, (5, 1)) == reshape(sum(a, (5, 1)), (1, 8, 8))
@test @inferred(squeeze(sum, a, (1, 2, 5))) == squeeze(sum, a, (5, 2, 1)) == reshape(sum(a, (5, 2, 1)), (8, 8))
@test_throws ArgumentError squeeze(sum, a, 0)
@test_throws ArgumentError squeeze(sum, a, (1, 1))
@test_throws ArgumentError squeeze(sum, a, (1, 2, 1))
@test_throws ArgumentError squeeze(sum, a, (1, 1, 2))
@test_throws ArgumentError squeeze(sum, a, 6)

sz = (5,8,7)
A = reshape(1:prod(sz),sz...)
@test A[2:6] == [2:6;]
Expand Down

0 comments on commit d9cbea9

Please sign in to comment.