slicemap on TrackedArrays produces Array{TrackedReal} #3

Open
baggepinnen opened this issue Nov 21, 2019 · 3 comments

@baggepinnen

This is a problem because the code below will not run on the GPU (unless one allows scalar operations, which is not ideal):

julia> using SliceMap, Flux

julia> slicemap(norm, Flux.param(randn(2,2,2,2)), dims=(1,2))
2×2 Array{Tracker.TrackedReal{Float64},2}:
 1.91925  2.69252
 1.26406  1.22966
@baggepinnen
Author

Oops, I see now that this function is really only supported for Zygote. Feel free to close this issue if appropriate.

@mcabbott
Owner

mcabbott commented Nov 21, 2019

That’s right. I don’t recall exactly why, but I couldn’t make the gradients work in general for Tracker. (For mapcols I explicitly handle the pullback functions myself, rather than getting Tracker to keep track of them.)

The simplest version of this was the @grad function gluecol here:
https://github.com/mcabbott/SliceMap.jl/blob/master/src/SliceMap.jl#L197
Since norm returns a scalar, you don’t actually need to glue anything, but the slicing stage also seems to have problems.

@mcabbott
Owner

However, you can do this:

julia> reshape(mapcols(norm, reshape(Tracker.param(randn(2,2,2,3)), 4,6)), 2,3)
Tracked 2×3 Array{Float64,2}:
 1.2732   1.34758  2.0512 
 1.42643  1.55332  1.70304

I guess anything with dims=1:M can be handled like this. It works directly for norm, since norm(vec(m)) == norm(m); for other functions you could insert a reshape inside the mapped function too.
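To check the equivalence without Tracker or a GPU, here is a minimal sketch in plain Julia. It assumes only the LinearAlgebra standard library, uses Base's mapslices as a stand-in reference for slicemap, and the array sizes and the choice of det as the "other function" are arbitrary illustrations. It compares slicing over dims=(1,2) with the reshape-then-map-over-columns pattern, once for norm (no inner reshape needed) and once for det (each column reshaped back to 2×2 first).

```julia
using LinearAlgebra  # provides norm and det

A = randn(2, 2, 2, 3)

# Reference: slice over dims (1,2) with Base.mapslices (standing in for slicemap).
# The result is 1×1×2×3, so drop the two singleton dimensions to get a 2×3 matrix.
ref_norm = dropdims(mapslices(norm, A; dims=(1, 2)); dims=(1, 2))

# Reshape trick: each 2×2 slice becomes one column of a 4×6 matrix.
# Julia arrays are column-major, so column k + 2(l-1) holds vec(A[:, :, k, l]).
cols = reshape(A, 4, 6)
via_cols_norm = reshape([norm(c) for c in eachcol(cols)], 2, 3)

# norm(vec(m)) == norm(m) (Frobenius norm), so no inner reshape is needed here.
@assert via_cols_norm ≈ ref_norm

# For a function that needs the matrix shape (det here), reshape each column
# back into a 2×2 matrix inside the mapped function.
ref_det = dropdims(mapslices(det, A; dims=(1, 2)); dims=(1, 2))
via_cols_det = reshape([det(reshape(collect(c), 2, 2)) for c in eachcol(cols)], 2, 3)
@assert via_cols_det ≈ ref_det
```

With Tracker the structure is the same as the reshape(mapcols(norm, reshape(...)), 2, 3) call above; only the function mapped over the columns changes.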
