Skip to content

Commit

Permalink
cat can grow ndims. Fixes JuliaGPU#72, JuliaGPU#73
Browse files Browse the repository at this point in the history
Unnecessary copies are created when growing ndims.
  • Loading branch information
gustafsson committed Apr 25, 2018
1 parent bb414a4 commit af6e3a1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,22 @@ end
x, ntuple(n -> n == dim ? i : I[n], Val{N})
end

function growdims(dim, x)
if ndims(x) >= dim
x
else
x[fill(Colon(), dim)...]
end
end

function _cat(dim, dest, xs...)
function kernel(dim, dest, xs)
I = @cuindex dest
@inbounds n, I′ = catindex(dim, Int.(I), size.(xs))
@inbounds dest[I...] = xs[n][I′...]
return
end
xs = growdims.(dim, xs)
blk, thr = cudims(dest)
@cuda (blk, thr) kernel(dim, dest, xs)
return dest
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ end
@test testf(vcat, ones(5), zeros(5))
@test testf(hcat, rand(3, 3), rand(3, 3))
@test testf(vcat, rand(3, 3), rand(3, 3))
@test testf(hcat, rand(3), rand(3))
@test testf(cat, 4, rand(3, 4), rand(3, 4))
end

@testset "Broadcast" begin
Expand Down

0 comments on commit af6e3a1

Please sign in to comment.