
Differentiation with np.select (piecewise functions) #36

Closed
richardotis opened this issue Jul 25, 2015 · 8 comments

Comments

@richardotis
Contributor

You're doing some great work here! I've run into a slight issue with computing gradients of piecewise functions using np.select. I've put together a minimal (trivial) example (Python 3.3):

import autograd.numpy as anp
from autograd import grad

def func(x):
    return anp.select([True], [x[0]*(anp.log(x[1]**x[1]) + anp.log(x[2]**x[2]))])

gradient = grad(func)(anp.array([300., 0.5, 0.5]))
print(gradient)
Output:

[ 0.  0.  0.]
[...]/autograd/core.py:39: UserWarning: Output seems independent of input. Returning zero gradient.
  warnings.warn("Output seems independent of input. Returning zero gradient.")

select accepts a list of conditions and a list of choices; for each element, it returns the entry from the choice list corresponding to the first condition that evaluates to True. The gradient of a piecewise-defined function should just be the gradient of whichever choice was selected, but I haven't yet wrapped my head around how constructing a closure for defgrad would work in this context.
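To illustrate what I mean, here is a plain-NumPy sketch (no autograd involved, derivatives written by hand): the gradient of a piecewise function is just the same piecewise selection applied to each branch's gradient.

```python
import numpy as np

# np.select picks, elementwise, the choice whose condition is first True.
x = np.array([-2.0, -0.5, 0.5, 2.0])
f = np.select([x < 0, x >= 0], [x**2, x**3])

# The derivative of a piecewise function is itself piecewise: select the
# derivative of whichever branch was active (derivatives computed by hand).
df = np.select([x < 0, x >= 0], [2 * x, 3 * x**2])

print(f)   # x**2 on the negative entries, x**3 on the rest
print(df)  # 2*x on the negative entries, 3*x**2 on the rest
```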

@mattjj
Contributor

mattjj commented Oct 17, 2015

I think the problem was that np.select was returning a 0-dimensional array (like np.array(FloatNode)) and those are a pain to handle. I think we just have to call wrap_if_nodes_inside. (My first fix unpacked the 0d array into a scalar type, but I realized that constitutes a change in behavior from the unwrapped numpy function.)
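For reference, a quick demonstration of the 0-dimensional output (plain NumPy, independent of autograd):

```python
import numpy as np

# With scalar-like conditions and choices, np.select returns a
# 0-dimensional ndarray rather than a Python scalar -- the shape
# that reportedly tripped up autograd here.
out = np.select([np.array(True)], [np.array(3.0)])

print(out.ndim, out.shape)  # 0-dimensional: ndim is 0, shape is ()
print(type(out))            # an ndarray, not a float
```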

Reopen if still busted!

@duvenaud
Contributor

@mattjj Are you sure you don't want to add a test? A while ago I added a systematic test for np.select but left it commented out:
https://github.com/HIPS/autograd/blob/master/tests/test_systematic.py#L176

@mattjj
Contributor

mattjj commented Oct 17, 2015

Yeah yeah okay :)

That test is actually for the gradient of select, right? My fix didn't implement a gradient for select so much as it made select work as control flow (as it would have done automatically if not for the zero-dimensional array issue, afaict).

@mattjj
Contributor

mattjj commented Oct 17, 2015

Except it only works in the scalar case (i.e. when condlist and choicelist are lists of scalars)... reopening!

@mattjj mattjj reopened this Oct 17, 2015
mattjj added a commit that referenced this issue Oct 17, 2015
@mattjj
Contributor

mattjj commented Oct 17, 2015

Okay, I undid the change because making np.select actually work in general seems like a bit of a mess and I can't think of a quick fix. Any ideas?

@mattjj
Contributor

mattjj commented Oct 17, 2015

My quick-fix thinking was that we could just use the fact that np.select should work on object ndarrays, then just wrap its output back to an ArrayNode. I'm kind of confused at the moment as to why that's not working.

@mattjj
Contributor

mattjj commented Oct 18, 2015

Okay, took another quick stab at this one. The basic strategy is for autograd.numpy.select to be a non-primitive function that unboxes its arguments (from ListNode to list, or just keeps them as lists if they come in as lists), uses the underlying np.select on the enclosed object ndarrays (containing FloatNodes and stuff), and then reboxes the result.
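A rough sketch of that strategy in isolation (plain NumPy; `Box` is a hypothetical stand-in for autograd's node type, not the real class):

```python
import numpy as np

# Hypothetical sketch of the unbox/select/rebox strategy described above.
# `Box` stands in for autograd's FloatNode; np.select never needs to
# understand the wrapper because it operates on object-dtype ndarrays,
# treating each boxed entry as an opaque Python object.
class Box:
    def __init__(self, value):
        self.value = value

def select_boxed(condlist, choicelist):
    # Put each (possibly boxed) choice into an object ndarray so that
    # np.select shuffles the entries around without inspecting them.
    obj_choices = [np.array(c, dtype=object) for c in choicelist]
    return np.select(condlist, obj_choices)  # object array of Box instances

conds = [np.array([True, False]), np.array([False, True])]
choices = [[Box(1.0), Box(2.0)], [Box(10.0), Box(20.0)]]
result = select_boxed(conds, choices)
print([b.value for b in result])
```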

Maybe this is shady! But it passes both David's test and the OP's test case.

Since I haven't thought too hard about how shady this might be, I'm leaving it in a branch named issue36 for now.

@mattjj
Contributor

mattjj commented Oct 19, 2015

The consensus is that 2ae6bed is a good fix!
