Hi,

I am trying to compute the Hessian of `xᵀ A x` w.r.t. `x`, then optimize the computation graph, expecting the result to be `A + Aᵀ`. However, the call to `optimize` crashes with a `ValueError` (please see the MWE below).
Interestingly, if I replace `xᵀ A x` by `xᵀ A B x`, the optimization works and yields the desired result `A B + (A B)ᵀ`.
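For reference, both expected results are instances of the same identity for a quadratic form with an arbitrary square matrix $M$:

$$
\nabla_x \left(x^\top M x\right) = \left(M + M^\top\right) x,
\qquad
\nabla_x^2 \left(x^\top M x\right) = M + M^\top,
$$

so $M = A$ gives $A + A^\top$ and $M = AB$ gives $AB + (AB)^\top$.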
Best,
Felix
```python
from autohoot import autodiff as ad
from autohoot.graph_ops import graph_transformer

dim = 3
x = ad.Variable(name="x", shape=[dim])
A = ad.Variable(name="A", shape=[dim, dim])
B = ad.Variable(name="B", shape=[dim, dim])

# ✔ Compute the Hessian of `y = xᵀ A B x` w.r.t. `x`
y = ad.einsum("i,ij,jk,k->", x, A, B, x)
Hx_y = ad.hessian(y, [x])[0][0]
print(Hx_y)
# >>> (T.einsum('ac,cb->ab',T.identity(3),T.einsum('ab,bc->ca',A,B))+T.einsum('ac,cb->ab',T.identity(3),T.einsum('ab,bc->ac',A,B)))

# ✔ Optimize the graph to get `A B + (A B)ᵀ`
Hx_y_opt = graph_transformer.optimize(Hx_y)
print(Hx_y_opt)
# >>> (T.einsum('ab,bc->ca',A,B)+T.einsum('ab,bc->ac',A,B))

# ✔ Compute the Hessian of `z = xᵀ A x` w.r.t. `x`
z = ad.einsum("i,ij,j->", x, A, x)
Hx_z = ad.hessian(z, [x])[0][0]
print(Hx_z)
# >>> (T.einsum('ac,cb->ab',T.identity(3),T.einsum('ab,bc->ca',A,B))+T.einsum('ac,cb->ab',T.identity(3),T.einsum('ab,bc->ac',A,B)))

# ❎ Optimize the graph to get `A + Aᵀ`
Hx_z_opt = graph_transformer.optimize(Hx_z)
print(Hx_z_opt)
# >>> ValueError: Output character 'd' did not appear in the input
```
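As a sanity check that does not depend on autohoot, the expected Hessians can be confirmed numerically. The sketch below uses plain Python and a central finite-difference Hessian (exact up to rounding for quadratic functions); the helpers `quad_form`, `matmul`, `fd_hessian`, `sym` and `close` are hypothetical names introduced only for this illustration and are not part of autohoot.

```python
import random


def quad_form(M, x):
    """Evaluate the scalar xᵀ M x for a square matrix M (list of lists)."""
    n = len(x)
    return sum(x[i] * M[i][j] * x[j] for i in range(n) for j in range(n))


def matmul(P, Q):
    """Plain matrix product P Q for square matrices."""
    n = len(P)
    return [[sum(P[i][k] * Q[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


def fd_hessian(f, x, h=1e-3):
    """Central finite-difference Hessian of a scalar function f at point x."""
    n = len(x)
    H = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            def shifted(di, dj):
                # Perturb coordinate i by di and coordinate j by dj (same slot if i == j).
                y = list(x)
                y[i] += di
                y[j] += dj
                return f(y)

            H[i][j] = (shifted(h, h) - shifted(h, -h)
                       - shifted(-h, h) + shifted(-h, -h)) / (4 * h * h)
    return H


def sym(M):
    """Return M + Mᵀ."""
    n = len(M)
    return [[M[i][j] + M[j][i] for j in range(n)] for i in range(n)]


def close(P, Q, tol=1e-6):
    """Element-wise comparison of two matrices within an absolute tolerance."""
    return all(abs(p - q) <= tol for rp, rq in zip(P, Q) for p, q in zip(rp, rq))


random.seed(0)
dim = 3
A = [[random.uniform(-1, 1) for _ in range(dim)] for _ in range(dim)]
B = [[random.uniform(-1, 1) for _ in range(dim)] for _ in range(dim)]
x0 = [random.uniform(-1, 1) for _ in range(dim)]

H_z = fd_hessian(lambda v: quad_form(A, v), x0)             # Hessian of xᵀ A x
H_y = fd_hessian(lambda v: quad_form(matmul(A, B), v), x0)  # Hessian of xᵀ A B x

print(close(H_z, sym(A)))             # True: A + Aᵀ
print(close(H_y, sym(matmul(A, B))))  # True: A B + (A B)ᵀ
```

Because the Hessian of a quadratic form does not depend on `x0`, any evaluation point gives the same matrices, which is why a single random point suffices here.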