Vector valued outputs #14

Closed
JasperSnoek opened this issue Apr 15, 2015 · 10 comments

Comments

@JasperSnoek
Contributor

Please... :-)

@mattjj
Contributor

mattjj commented Apr 15, 2015

This is a bit messy but it's a quick wrapper:

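import autograd.numpy as np
from autograd import grad

def D(f, outdim):
    # f_i(i) is the scalar function picking out the i-th output component of f.
    def f_i(i):
        return lambda *args, **kwargs: f(*args, **kwargs)[i]
    # One reverse pass per output component (w.r.t. the first argument);
    # stacking the gradients row by row gives shape (outdim,) + input shape.
    def deriv(*args, **kwargs):
        return np.concatenate(
            [grad(f_i(i))(*args, **kwargs)[None, ...] for i in range(outdim)])
    return deriv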

It can be used like

A = np.random.randn(3, 3)
def test(v):
    return np.dot(A, v)
# The Jacobian of v -> A v is A itself, so the two printed matrices should match.
print(D(test, 3)(np.ones(3)))
print()
print(A)

I'm sure it can be improved, but I think it reflects the best general strategy (for reverse mode): one reverse pass per output component. EDIT: it would be easy to change outdim to outshape, too.

@dougalm
Contributor

dougalm commented Apr 15, 2015

Jasper, do you want the full Jacobian of a vector-to-vector function, or just its diagonal? If it's the Jacobian itself, then you have to loop over the gradients of each output component with respect to the input vector, as Matt shows, which takes one reverse pass per output. If you want just the diagonal, and the off-diagonal elements are zero (our conversations make me think that's what you're talking about), then you can use the gradient of the sum of the outputs, and the calculation happens in a single pass. For example:

>>> import autograd.numpy as np
>>> from autograd import grad
>>> def jac_diag(fun):
...     return grad(lambda x : np.sum(fun(x)))
... 
>>> x = np.linspace(-3, 3, 5)
>>> jac_diag(np.sin)(x)
array([-0.9899925,  0.0707372,  1.       ,  0.0707372, -0.9899925])
>>> np.cos(x)
array([-0.9899925,  0.0707372,  1.       ,  0.0707372, -0.9899925])
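Written out, the reasoning behind the sum trick: with Jacobian entries J_ij = ∂f_i/∂x_j, the gradient of the summed output is the vector of column sums of J, which equals the diagonal exactly when the off-diagonal entries vanish. For `np.sin`, J = diag(cos x), which is why the single pass reproduces `np.cos(x)`.

```latex
J_{ij} = \frac{\partial f_i}{\partial x_j}, \qquad
\frac{\partial}{\partial x_j} \sum_{i} f_i(x) = \sum_{i} J_{ij}
\;\Longrightarrow\;
\nabla_x \Big( \sum_i f_i(x) \Big) = J^{\top} \mathbf{1}
\\
J_{ij} = 0 \ (i \neq j) \;\Longrightarrow\; \big( J^{\top} \mathbf{1} \big)_j = J_{jj},
\qquad \text{e.g. } f = \sin:\; J = \operatorname{diag}(\cos x),\; J^{\top}\mathbf{1} = \cos x
```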

I could add the wrapper functions jacobian (Matt's D) and jac_diag if you think they'd be useful...

@kswersky

Those wrappers would be useful!

@JasperSnoek
Contributor Author

That's really neat, @mattjj. I already used your solution to quickly put together a Kayak module :-) I took the diagonal of the output of Matt's solution, but yes, taking the sum is much simpler. Those wrappers are tremendously useful, but I'm not sure where exactly they'd fit into autograd.

@dougalm
Contributor

dougalm commented Apr 15, 2015

Ok, I'll put them in autograd.util for now

@mattjj
Contributor

mattjj commented Apr 15, 2015

Sounds like there was side channel information about what Jasper really wanted!

Support for general derivatives (of maps from R^n to R^m), a.k.a. Jacobians, would be a nice feature even if it's not the main thrust of the library. Then autograd could be used for easy implementations of, e.g., extended Kalman filters and smoothers (unless I'm missing something).

Maybe the jacobian function could avoid the outdim (or outshape) argument if it ran a single forward pass the first time it was called and inspected (and cached) the shape of the result.

@mattjj
Contributor

mattjj commented Apr 15, 2015

@dougalm I don't think that jac_diag function returns the diagonal of the Jacobian in general:

def test2(x):
    return np.array([np.sum(x), np.sum(x**2), np.sum(x**3)])
print(D(test2, 3)(np.ones(3)))

def jac_diag(fun):
    # Gradient of the summed output: one reverse pass.
    return grad(lambda x: np.sum(fun(x)))
print(jac_diag(test2)(np.ones(3)))

# prints:
# [[1. 1. 1.]
#  [2. 2. 2.]
#  [3. 3. 3.]]
# [6. 6. 6.]

EDIT: oh, you said "if the off-diagonal elements are zero", of course!

@dougalm
Contributor

dougalm commented Apr 15, 2015

Exactly. It's a common case that people seem interested in. They have a scalar-to-scalar function and they want its gradient at a number of places. Mike Gelbart and Jon Malmaud were both cross that grad doesn't automatically do this when you give it a vector-to-vector function.

But perhaps jac_diag is a misleading name. Maybe elementwise_grad? Or map_grad?

@mattjj
Contributor

mattjj commented Apr 15, 2015

Or maybe diag_jac. I'm probably parsing too much here, but that could be slightly more suggestive that it computes a diagonal Jacobian (represented by its diagonal elements) rather than the Jacobian's diagonal (in the general case). On the other hand, jac_diag probably conveys pretty much the same thing, and the docstring can spell out the constraints :).

@dougalm
Contributor

dougalm commented Apr 15, 2015

Done.

dougalm closed this as completed Apr 15, 2015

4 participants