In [1]:
from foundations.projectables import ResidualProjectable
from modeling.gen4.arghandling import var_encoding
from modeling.compute import create_vars
from foundations.functionals import encode_sympy
from foundations.functional_noobj import concatenate_residuals
from modeling.arghandling import Encoding, flatten_args, decode, encode, unflatten_args, EncodedFunction

In [21]:
x, y, z, a = create_vars('x y z a')
h = encode_sympy(x + y**2 - z**2)
h2 = encode_sympy(y-x*z+a)
h3 = encode_sympy(y-2*a)
H = concatenate_residuals([h, h2])

In [22]:
#P1 = ResidualProjectable(h, solvevar_encoder=var_encoding(x))
P2 = ResidualProjectable(H, solvevar_encoder=var_encoding(x,y))

In [24]:
x_F = {z: 1, a: 0}
sol = {x: -1., y: 1}

In [25]:
P2.residuals().dict_in_flat_out({**x_F, **sol})

DeviceArray([-1.,  2.], dtype=float64, weak_type=True)

In [26]:
P2.solvepar(x_initial={x:0, y:0})

In [27]:
F = P2.functional()
sol = F.dict_in_dict_out(x_F, cleanup=True)

In [28]:
sol

{x: 0.6180339887498949, y: 0.6180339887498949}

In [29]:
P2.residuals().dict_in_flat_out({**x_F, **sol})

DeviceArray([0., 0.], dtype=float64, weak_type=True)

# Adjoint Gradients

In [45]:
from jax import jacfwd, jacrev, grad, custom_jvp
import jax.numpy as jnp
import numpy as np

In [31]:
def adjoints(projectable):
    H = projectable.residuals()
    F = projectable.functional()
    g = jacfwd(H.flat_in_flat_out)
    N = sum(map(sum, F.decoder.shapes))
    def calculate(*args):
        x_F =  F.encoder.decode(args)
        x0 = {**x_F, **F.dict_out_only(*args)}
        x0_np = H.encoder.encode(x0, flatten=True)
        d = decode(flatten_args(g(x0_np).T), 
                   H.encoder.order, 
                   [(N,) for elt in H.encoder.order], unflatten=True)
        grad_h_y = np.vstack(F.decoder.encode(d)).T
        grad_h_x = np.vstack(F.encoder.encode(d)).T
        inv_grad_h_y = np.linalg.inv(grad_h_y)
        DJ = -np.dot(inv_grad_h_y, grad_h_x)
        return unflatten_args(flatten_args(DJ), [(N,) for elt in F.encoder.order])
    return EncodedFunction(calculate, F.encoder)

In [32]:
F.decoder.order

(x, y)

In [33]:
A = adjoints(P2)

In [34]:
A.dict_in_only(x_F)

(DeviceArray([0.5527864, 0.5527864], dtype=float64),
 DeviceArray([ 1.17082039, -0.4472136 ], dtype=float64))

In [35]:
v = F.encoder.encode({z:1., a:0}, flatten=True)
v

DeviceArray([1., 0.], dtype=float64, weak_type=True)

In [36]:
F.flat_in_flat_out(v), A.flat_in_only(v)

(DeviceArray([0.61803399, 0.61803399], dtype=float64),
 (DeviceArray([0.5527864, 0.5527864], dtype=float64),
  DeviceArray([ 1.17082039, -0.4472136 ], dtype=float64)))

In [42]:
@custom_jvp
def f(x):
  return F.flat_in_flat_out(x)

@f.defjvp
def f_jvp(primals, tangents):
  v, = primals
  x_dot,  = tangents
  y = f(v)
  dJ = A.flat_in_only(v)
  tangent_out = sum([dj*x_dot[idx] for idx,dj in enumerate(dJ)])
  return y, tangent_out

In [38]:
jacfwd(f)(np.array([1., 0]))

DeviceArray([[ 0.5527864 ,  1.17082039],
             [ 0.5527864 , -0.4472136 ]], dtype=float64)

In [46]:
def g(x):
    return f(x)+jnp.sin(x[1])

In [47]:
jacfwd(g)(np.array([1., 0]))

DeviceArray([[0.5527864 , 2.17082039],
             [0.5527864 , 0.5527864 ]], dtype=float64)

# Double check

In [39]:
eps = 1e-3
((flatten_args(F.f(1+eps,0))-flatten_args(F.f(1.,0)))/eps,
(flatten_args(F.f(1,0+eps))-flatten_args(F.f(1.,0)))/eps)

(DeviceArray([0.55231517, 1.17090148], dtype=float64),
 DeviceArray([ 0.55269693, -0.44730307], dtype=float64))