Optimizing the Hessian of xᵀ A x #255

Closed
f-dangel opened this issue Apr 12, 2023 · 2 comments
@f-dangel

Hi,

I am trying to compute the Hessian of xᵀ A x w.r.t. x, then optimize the computation graph, expecting the result to be A + Aᵀ. However, the call to optimize crashes with a ValueError (please see the MWE below).

Interestingly, if I replace xᵀ A x by xᵀ A B x, the optimization works and yields the desired result A B + (A B)ᵀ.

Best,
Felix

from autohoot import autodiff as ad
from autohoot.graph_ops import graph_transformer

dim = 3
x = ad.Variable(name="x", shape=[dim])
A = ad.Variable(name="A", shape=[dim, dim])
B = ad.Variable(name="B", shape=[dim, dim])

# ✔ Compute the Hessian of `y = xᵀ A B x` w.r.t. `x`
y = ad.einsum("i,ij,jk,k->", x, A, B, x)
Hx_y = ad.hessian(y, [x])[0][0]
print(Hx_y)
# >>> (T.einsum('ac,cb->ab',T.identity(3),T.einsum('ab,bc->ca',A,B))+T.einsum('ac,cb->ab',T.identity(3),T.einsum('ab,bc->ac',A,B)))

# ✔ Optimize the graph to get `A B + (A B)ᵀ`
Hx_y_opt = graph_transformer.optimize(Hx_y)
print(Hx_y_opt)
# >>> (T.einsum('ab,bc->ca',A,B)+T.einsum('ab,bc->ac',A,B))

# ✔ Compute the Hessian of `z = xᵀ A x` w.r.t. `x`
z = ad.einsum("i,ij,j->", x, A, x)
Hx_z = ad.hessian(z, [x])[0][0]
print(Hx_z)
# >>> (T.einsum('ac,cb->ab',T.identity(3),T.einsum('ab,bc->ca',A,B))+T.einsum('ac,cb->ab',T.identity(3),T.einsum('ab,bc->ac',A,B)))

# ❎ Optimize the graph to get `A + Aᵀ`
Hx_z_opt = graph_transformer.optimize(Hx_z)
print(Hx_z_opt)
# >>> ValueError: Output character 'd' did not appear in the input
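
The expected results above (`A + Aᵀ` and `A B + (A B)ᵀ`) can be checked independently of autohoot. The following is a minimal NumPy sketch, not part of the original report: it approximates each Hessian with central finite differences, and the helper name `numerical_hessian` is hypothetical.

```python
import numpy as np


def numerical_hessian(f, x, eps=1e-3):
    """Approximate the Hessian of a scalar function f at x (central differences)."""
    n = x.size
    H = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            e_i = np.zeros(n)
            e_j = np.zeros(n)
            e_i[i] = eps
            e_j[j] = eps
            # Second-order mixed central difference for d²f / dx_i dx_j
            H[i, j] = (
                f(x + e_i + e_j)
                - f(x + e_i - e_j)
                - f(x - e_i + e_j)
                + f(x - e_i - e_j)
            ) / (4 * eps**2)
    return H


rng = np.random.default_rng(0)
dim = 3
x = rng.standard_normal(dim)
A = rng.standard_normal((dim, dim))
B = rng.standard_normal((dim, dim))

# Same einsum strings as in the MWE, evaluated with NumPy instead of autohoot
H_y = numerical_hessian(lambda v: np.einsum("i,ij,jk,k->", v, A, B, v), x)
H_z = numerical_hessian(lambda v: np.einsum("i,ij,j->", v, A, v), x)

assert np.allclose(H_y, A @ B + (A @ B).T, atol=1e-6)
assert np.allclose(H_z, A + A.T, atol=1e-6)
```

Because both functions are quadratic in `x`, the finite-difference estimate is exact up to rounding, so a tight tolerance is reasonable here.
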
@LinjianMa
Owner

Hi Felix,

Thank you for letting us know! This issue should now be fixed by #256. With the fix, you should get the following output:

(T.einsum('ab->ba',A)+A)
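
As a quick sanity check that this expression matches the expected `A + Aᵀ`, here is a small NumPy sketch. It assumes that `T.einsum` in the printed graph follows NumPy's `einsum` semantics, which the thread does not state explicitly.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))

# 'ab->ba' swaps the two axes, i.e. it transposes A
optimized = np.einsum("ab->ba", A) + A

assert np.allclose(optimized, A + A.T)
```
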

Best,
Linjian

@f-dangel
Author

Hi Linjian,

just installed and verified the fixed version! Thanks a lot for your prompt reply :)
