From d9cbea92934bbcb08f15b365a0906950a180135a Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Tue, 29 Aug 2017 15:11:16 -0400 Subject: [PATCH] Add squeeze(f, A, dims) for reductions to drop dims This simple definition makes it easier to write reductions that drops the dimensions over which they reduce. Fixes #16606, addresses part of the root issue in #22000. --- base/abstractarraymath.jl | 8 ++++++++ test/arrayops.jl | 11 +++++++++++ 2 files changed, 19 insertions(+) diff --git a/base/abstractarraymath.jl b/base/abstractarraymath.jl index 9c7b098ff0d42..b7acf141bc4e3 100644 --- a/base/abstractarraymath.jl +++ b/base/abstractarraymath.jl @@ -84,8 +84,16 @@ function _dropdims(A::AbstractArray, dims::Dims) end reshape(A, d::typeof(_sub(axes(A), dims))) end + _dropdims(A::AbstractArray, dim::Integer) = _dropdims(A, (Int(dim),)) +""" + squeeze(f, A, dims) + +Compute reduction `f` over dimensions `dims` in array `A` and drop those dimensions from the result. +""" +squeeze(f, A::AbstractArray, dims::Union{Dims, Integer}) = squeeze(f(A, dims), dims) + ## Unary operators ## conj(x::AbstractArray{<:Real}) = x diff --git a/test/arrayops.jl b/test/arrayops.jl index 7a2fa864f543c..65e31028b644e 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -303,6 +303,17 @@ end @test_throws ArgumentError dropdims(a, dims=4) @test_throws ArgumentError dropdims(a, dims=6) + @test @inferred(squeeze(sum, a, 1)) == @inferred(squeeze(sum, a, (1,))) == reshape(sum(a, 1), (1, 8, 8, 1)) + @test @inferred(squeeze(sum, a, 3)) == @inferred(squeeze(sum, a, (3,))) == reshape(sum(a, 3), (1, 1, 8, 1)) + @test @inferred(squeeze(sum, a, 4)) == @inferred(squeeze(sum, a, (4,))) == reshape(sum(a, 4), (1, 1, 8, 1)) + @test @inferred(squeeze(sum, a, (1, 5))) == squeeze(sum, a, (5, 1)) == reshape(sum(a, (5, 1)), (1, 8, 8)) + @test @inferred(squeeze(sum, a, (1, 2, 5))) == squeeze(sum, a, (5, 2, 1)) == reshape(sum(a, (5, 2, 1)), (8, 8)) + @test_throws ArgumentError squeeze(sum, a, 0) + @test_throws ArgumentError squeeze(sum, a, (1, 1)) + @test_throws ArgumentError squeeze(sum, a, (1, 2, 1)) + @test_throws ArgumentError squeeze(sum, a, (1, 1, 2)) + @test_throws ArgumentError squeeze(sum, a, 6) + sz = (5,8,7) A = reshape(1:prod(sz),sz...) @test A[2:6] == [2:6;]