In [1]:
from sympy import var, exp
import numpy as np
from coupledgrad import pure_functions, coupled_functions
from coupledgrad import solver, compose, eliminate
from coupledgrad import jacfwd

### Sympy Model

In [2]:
x,y1,y2,z1,z2,a,b = var('x y1 y2 z1 z2 a b')
f1 = z1+z2-0.2*y2 # a
r1 = a - f1
f2 = x**2+a # y1
r2 = y1 - f2
f3 = abs(y1)**.5+z1+z2 # y2
r3 = y2-f3
f4 = x**2+y1+z2 # b
obj = b-y2**2
ineq = y1 - 3.16

In [3]:
objd, ineqd, r1d, r2d, r3d = pure_functions(
    [obj, ineq, r1, r2, r3])
f4d, = coupled_functions([(f4, b)])

In [5]:
Sdata = solver([r1d, r2d, r3d], [a, y1, y2], 
                      x0_pointer=[[1.,1.,1.]])

## Validate solver

In [16]:
FS,_,Sins=Sdata
y = [0.1,0.1,1.9]
x = FS(y)
x
gFs = jacfwd(FS)
gFs(np.array(y))

DeviceArray([[-0.01571348,  0.737146  ,  0.737146  ],
             [ 0.18428652,  0.737146  ,  0.737146  ],
             [ 0.07856742,  1.3142697 ,  1.3142697 ]], dtype=float32)

In [17]:
Sins

[x, z2, z1]

In [18]:
eps = 1e-3
yd = np.array(y)+eps*np.eye(len(y))[2]
(FS(yd)-FS(y))/eps

array([0.7371584 , 0.7371584 , 1.31420801])

## Composition

In [19]:
Fdata, Rdata = compose([Sdata, f4d], 
                       forelimin=[objd])

In [20]:
Rdata[0](Fdata[0](y))

DeviceArray([-8.581169], dtype=float32)

In [21]:
Relim = eliminate(Rdata, phi=Fdata)

In [22]:
Relim(y)

DeviceArray([-8.581169], dtype=float32)

In [23]:
gRelim = jacfwd(Relim)
gRelim(np.array(y))

DeviceArray([[-0.1142697, -6.6026635, -7.6026635]], dtype=float32)

# Done

In [None]:

Fs1 = [(F1, (a,y1,y2), yother),
      (lambda x: objl(*x), (), (x,y1,y2,z2))]

In [271]:
n = 2
rx = lambda y0,y1,x0,x1: [x0+y0-3,x1*x0-x0-y1]
g = jacfwd(lambda x: rx(*x)) # x will be an array
y = [1,2] # example y

$x_0 = 3-y_0$ so we get $(-1,0)$

$x_1=1+\frac{y_1}{3-y_0}$ so we get $(y_1/(3-y_0)^2,1/(3-y_0)) = (0.5,0.5)$

In [295]:
def solver(eqs, solvefor, x0_pointer):
    # x0_pointer is a hack to use a list
    # lambdify eqs
    all_vars = {vr for eq in eqs for vr in eq.free_symbols}
    yother = list(all_vars-set(solvefor))
    rx = lambdify(solvefor+yother, eqs, modules=jnp)
    gradf = jacfwd(lambda x: rx(*x))
    n = len(solvefor)
    @custom_jvp
    def f(y):
        return fsolve(lambda x: rx(*x,*y), 
                      x0=x0_pointer[0])
    @f.defjvp
    def f_jvp(primals, tangents):
        y, = primals
        x_dot,  = tangents
        fval = f(y)
        x_full = np.hstack((fval,y))
        print(rx(*x_full))
        grad_val = np.vstack(gradf(x_full)).T
        grad_hy, grad_hx = grad_val[:n,:],grad_val[n:,:]
        inv_grad_hy = np.linalg.inv(grad_hy.T)
        dJ = -np.dot(inv_grad_hy, grad_hx.T).T
        tangent_out = sum([dj*x_dot[idx] for idx,dj in enumerate(dJ)])
        return fval, tangent_out
    return f, yother, rx

In [296]:
#rx = lambda y0,y1,x0,x1: [x0+y0-3,x1*x0-x0-y1]
eqs = [z1+y1-3,z2*z1-z1-y2]
fs,yother,_ = solver(eqs, solvefor=[z1,z2], x0_pointer=[[1,1]])
jacfwd(fs)(np.array([1., 2]))

[0.0, 0.0]


DeviceArray([[-1. ,  0. ],
             [ 0.5,  0.5]], dtype=float32)

In [297]:
x0arr=jnp.array([1., 2.])
eps = 1e-3
#fs(x0arr),fs(x0arr.at[1].set((1+eps)*x0arr[1]))
(fs([1.,2+eps])-fs([1.,2]))/eps

array([0. , 0.5])

In [298]:
fs(x0arr), fs(x0arr.at[1].set((1+eps)*x0arr[1]))

(array([2.00000004, 1.99999991]), array([2.00000005, 2.00099996]))

In [299]:
F1,yother,rx = solver([r1, r2, r3], [a, y1, y2], x0_pointer=[[1.,1.,1.]])

In [300]:
#jacfwd(F1)(np.array([1., 1., 1.]))

In [301]:
from itertools import chain

In [302]:
f1.free_symbols,f2.free_symbols,obj.free_symbols

({y2, z1, z2}, {a, x}, {x, y1, y2, z2})

In [303]:
f1l = lambdify((y2,z1,z2),f1)
f2l = lambdify((a,x),f2)

In [304]:
objl = lambdify((x,y1,y2,z2),obj,modules=jnp)
objf = (lambda x: objl(*x), (x,y1,y2,z2))
r1l = lambdify((a, y2, z1, z2), r1)
r1f = (lambda x: r1l(*x), (a, y2, z1, z2))

In [305]:
Fs = [(lambda x: f1l(*x), (a,),(y2,z1,z2)),
      (lambda x: f2l(*x), (y1,),(a,x)),
      (lambda x: objl(*x), (), (x,y1,y2,z2))]

In [306]:
Fs1 = [(F1, (a,y1,y2), yother),
      (lambda x: objl(*x), (), (x,y1,y2,z2))]

In [307]:
def structure(Fs):
    _,allout,allin = zip(*Fs)
    allvars = tuple({vr for vrs in chain(allout,allin) 
                for vr in vrs})
    lookup_table = []
    for Fi,youts,yins in Fs:
        lookup_table.append((Fi,
            jnp.array(tuple(allvars.index(yout) for yout in youts)),
            jnp.array(tuple(allvars.index(yin) for yin in yins))))
    coupled_in = ({vr for vrs in allin for vr in vrs}
              -{vr for vrs in allout for vr in vrs})
    indices_coupled_in=jnp.array(tuple(allvars.index(vr) 
                    for vr in coupled_in))
    return indices_coupled_in, lookup_table, allvars
structure(Fs1)

(DeviceArray([1, 2, 5], dtype=int32),
 [(<jax._src.custom_derivatives.custom_jvp at 0x1954c85ba30>,
   DeviceArray([4, 0, 3], dtype=int32),
   DeviceArray([1, 2, 5], dtype=int32)),
  (<function __main__.<lambda>(x)>,
   DeviceArray([], dtype=float32),
   DeviceArray([2, 0, 3, 1], dtype=int32))],
 (y1, z2, x, y2, a, z1))

In [308]:
def sequential_evaluation(indices_coupled_in, 
                           lookup_table, allvars):
    def f(x):
        xout = jnp.zeros(len(allvars))
        xout = xout.at[indices_coupled_in].set(x)
        for fi, indexout, indexin in lookup_table:
            xout = xout.at[indexout].set(fi(xout[indexin]))
        return xout
    return f

In [309]:
def sequential_elimination(f_with_in, phi):
    f_to_reduce, indexin = f_with_in
    def f(x):
        xout = phi(x)
        return f_to_reduce(xout[indexin])
    return f

In [310]:
x0 = (0.1,1.9,0.1)
indices_coupled_in, lookup_table, allvars = structure(Fs1)
phi = sequential_evaluation(indices_coupled_in, 
                            lookup_table[:-1], allvars)
testf = objf
testf_lookup = (testf[0], np.array([allvars.index(yin) 
                        for yin in testf[1]]))
testf_reduced = sequential_elimination(testf_lookup, 
                                       phi=phi)

In [311]:
#np.hstack((x0arr[0], F1(x0arr)[1:], x0arr[2])),testf[1]

In [312]:
# xobj_test = np.hstack((x0arr[0], F1(x0arr)[1:], x0arr[2]))
# eps = 1e-4
# xobjdelta = xobj_test+eps*np.eye(len(xobj_test))[3]
# (testf[0](xobjdelta)-testf[0](xobj_test))/eps

In [313]:
#jacfwd(testf[0])(xobj_test)

In [314]:
[allvars[idx] for idx in indices_coupled_in]
 

[z2, x, z1]

In [317]:
x0arr=jnp.array(x0, dtype=np.float32)
eps = 1e-3
x0arrdelta = x0arr+eps*np.eye(len(x0arr))[0]
(testf_reduced(x0arrdelta)-testf_reduced(x0arr))/eps

DeviceArray(-3.1707284, dtype=float32)

In [316]:
gradf = jacfwd(testf_reduced)
gradf(x0arr)

[-3.6994451668714845e-09, -3.0504533565789416e-08, DeviceArray(0., dtype=float32)]


DeviceArray([-3.1707401,  3.4090974, -4.17074  ], dtype=float32)

In [294]:
obj

x**2 + y1 - y2**2 + z2

In [165]:
dict(zip(allvars, xout))

{z2: 1.0, y2: 1.0, y1: 2.8, z1: 1.0, x: 1.0, a: 1.8}

In [125]:
def evaluation_sequence(Fs):
    _,allout,allin = zip(*Fs)
    allvars = {vr for vrs in chain(allout,allin) 
               for vr in vrs}
    for _,you,yin in Fs:
        lookup_table.append(())


In [12]:
groups = {r1,r2}, {r3}

In [None]:
phi = [F1,F2]
obj, dobj = reformulate(objective, eliminate=phi)