In [1]:
from src.v1.symbolic import create_vars
from src.v4.projectables import ResidualProjectable
from src.v4.arghandling import var_encoding
from src.v4.functionals import encode_sympy, EncodedFunction
from src.v4.functional_noobj import concatenate_residuals
from src.v4.arghandling import flatten_args, decode, unflatten_args

In [2]:
x, y, z, a = create_vars('x y z a')
h = encode_sympy(x + y**2 - z**2)
h2 = encode_sympy(y-x*z+a)
h3 = encode_sympy(y-2*a)
H = concatenate_residuals([h, h2])

In [3]:
#P1 = ResidualProjectable(h, solvevar_encoder=var_encoding(x))
P2 = ResidualProjectable(H, solvevar_encoder=var_encoding(x,y))

In [4]:
x_F = {z: 1, a: 0}
sol = {x: -1., y: 1}

In [5]:
P2.residuals().dict_in_flat_out({**x_F, **sol})

DeviceArray([-1.,  2.], dtype=float32, weak_type=True)

In [6]:
P2.solvepar(x_initial={x:0, y:0})

In [7]:
F = P2.functional()
sol = F.dict_in_dict_out(x_F, cleanup=True)

In [8]:
sol

{x: 0.6180340051651001, y: 0.6180340051651001}

In [9]:
P2.residuals().dict_in_flat_out({**x_F, **sol})

DeviceArray([3.6705515e-08, 0.0000000e+00], dtype=float32, weak_type=True)

# Adjoint Gradients

In [10]:
from jax import jacfwd, jacrev, grad, custom_jvp
import jax.numpy as jnp
import numpy as np

In [11]:
def adjoints(projectable):
    H = projectable.residuals()
    F = projectable.functional()
    g = jacfwd(H.flat_in_flat_out)
    N = sum(map(sum, F.decoder.shapes))
    def calculate(*args):
        x_F =  F.encoder.decode(args)
        x0 = {**x_F, **F.dict_out_only(*args)}
        x0_np = H.encoder.encode(x0, flatten=True)
        d = decode(flatten_args(g(x0_np).T), 
                   H.encoder.order, 
                   [(N,) for elt in H.encoder.order], unflatten=True)
        grad_h_y = np.vstack(F.decoder.encode(d)).T
        grad_h_x = np.vstack(F.encoder.encode(d)).T
        inv_grad_h_y = np.linalg.inv(grad_h_y)
        DJ = -np.dot(inv_grad_h_y, grad_h_x)
        return unflatten_args(flatten_args(DJ), [(N,) for elt in F.encoder.order])
    return EncodedFunction(calculate, F.encoder)

In [12]:
F.decoder.order

(x, y)

In [13]:
A = adjoints(P2)

In [14]:
A.dict_in_only(x_F)

(DeviceArray([0.55278635, 0.5527864 ], dtype=float32),
 DeviceArray([ 1.1708204, -0.4472136], dtype=float32))

In [15]:
v = F.encoder.encode({z:1., a:0}, flatten=True)
v

DeviceArray([1., 0.], dtype=float32, weak_type=True)

In [16]:
F.flat_in_flat_out(v), A.flat_in_only(v)

(DeviceArray([0.618034, 0.618034], dtype=float32),
 (DeviceArray([0.55278635, 0.5527864 ], dtype=float32),
  DeviceArray([ 1.1708204, -0.4472136], dtype=float32)))

In [17]:
@custom_jvp
def f(x):
  return F.flat_in_flat_out(x)

@f.defjvp
def f_jvp(primals, tangents):
  v, = primals
  x_dot,  = tangents
  y = f(v)
  dJ = A.flat_in_only(v)
  tangent_out = sum([dj*x_dot[idx] for idx,dj in enumerate(dJ)])
  return y, tangent_out

In [18]:
jacfwd(f)(np.array([1., 0]))

DeviceArray([[ 0.55278635,  1.1708204 ],
             [ 0.5527864 , -0.4472136 ]], dtype=float32)

In [19]:
def g(x):
    return f(x)+jnp.sin(x[1])

In [20]:
jacfwd(g)(np.array([1., 0]))

DeviceArray([[0.55278635, 2.1708202 ],
             [0.5527864 , 0.5527864 ]], dtype=float32)

# Double check

In [21]:
eps = 1e-3
((flatten_args(F.f(1+eps,0))-flatten_args(F.f(1.,0)))/eps,
(flatten_args(F.f(1,0+eps))-flatten_args(F.f(1.,0)))/eps)

(DeviceArray([0.55229664, 1.1708736 ], dtype=float32),
 DeviceArray([ 0.55265427, -0.44733283], dtype=float32))