<a href="https://colab.research.google.com/github/profteachkids/StemUnleashed/blob/main/RubberBandBroyden.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax.numpy as jnp
import jax
from jax.config import config
from scipy.optimize import minimize
config.update("jax_enable_x64", True)
from plotly.subplots import make_subplots

In [23]:
# Problem data: one entry per anchor (four anchors in 3-D).
# `anchor` holds the fixed end points, one (x, y, z) row each.
anchor = jnp.array(
    [
        [-1.0, -1.0, 1.0],
        [1.0, 1.0, 1.0],
        [1.0, -1.0, -1.0],
        [-1.0, 1.0, -1.0],
    ]
)

# Per-anchor multiplier applied to the extension (distance - length) in `fun`;
# presumably the stiffness of each rubber band -- confirm units with the author.
k = jnp.array([0.1] * 4)

# Per-anchor length subtracted from the anchor distance in `fun`;
# presumably the unstretched length of each band.
length = jnp.array((0.5, 0.3, 0.5, 0.4))

def fun(pc):
    """Return the net force vector acting on the point ``pc``.

    For every row ``anchor[i]`` the contribution is
    ``k[i] * (d_i - length[i]) * u_i``, where ``d_i`` is the distance from
    ``pc`` to the anchor and ``u_i`` the unit vector pointing from ``pc``
    towards it. The contributions are summed over all anchors, so a root of
    this function is a point where they cancel.

    Reads the module-level arrays ``anchor``, ``k`` and ``length``.
    The division by the distance is undefined if ``pc`` coincides with an
    anchor.
    """
    to_anchor = anchor - pc
    dist = jnp.linalg.norm(to_anchor, axis=1)
    # Relative extension times the per-anchor multiplier; scaling the
    # (unnormalised) offset by this gives the force along each band.
    weight = (dist - length) / dist * k
    per_band = weight[:, None] * to_anchor
    return per_band.sum(axis=0)

In [26]:
def broyden(f, x0, jac, maxiter=5, tol=1e-10):
    """Find a root of ``f`` with Broyden's ("good") quasi-Newton method.

    The inverse Jacobian is formed once from ``jac(x0)`` and then kept current
    with rank-one Sherman-Morrison secant updates, so ``jac`` is normally not
    called again. Every step is safeguarded by a backtracking line search that
    only accepts a point whose residual norm does not increase.

    Parameters
    ----------
    f : callable
        Residual function; takes a 1-D array of length n and returns a 1-D
        array of length n.
    x0 : array
        Initial guess (1-D).
    jac : callable
        Returns the n-by-n Jacobian of ``f`` at a point. Used for the initial
        inverse and as a fallback when the line search fails.
    maxiter : int, optional
        Maximum number of quasi-Newton iterations.
    tol : float, optional
        Convergence threshold on the 2-norm of the residual.

    Returns
    -------
    (x, fx)
        Last accepted point and its residual. The result is not guaranteed to
        satisfy ``tol`` if ``maxiter`` is exhausted or the line search stalls.

    Raises
    ------
    numpy.linalg.LinAlgError
        If a Jacobian that has to be inverted is singular.
    """
    shrink = 0.75      # backtracking factor for the step length
    alpha_min = 1e-7   # below this the current direction is considered useless

    f0 = f(x0)
    norm0 = float(np.linalg.norm(f0))
    # Bound up front so maxiter=0 (or an already converged start) returns the
    # starting point instead of raising UnboundLocalError.
    x1, f1 = x0, f0
    if norm0 < tol:
        return x1, f1

    Jinv = np.linalg.inv(jac(x0))
    for _ in range(maxiter):
        step = -(Jinv @ f0)
        alpha = 1.0
        refreshed = False
        stalled = False

        x1 = x0 + step
        f1 = f(x1)
        norm1 = float(np.linalg.norm(f1))
        # "not <=" instead of ">" so that a NaN/inf residual is rejected too.
        while not norm1 <= norm0:
            alpha *= shrink
            if alpha < alpha_min:
                if refreshed:
                    # Even the exact Newton direction gives no decrease.
                    stalled = True
                    break
                # The secant approximation has gone bad: rebuild it from the
                # true Jacobian once and restart the line search at full step.
                Jinv = np.linalg.inv(jac(x0))
                step = -(Jinv @ f0)
                alpha = 1.0
                refreshed = True
            x1 = x0 + alpha * step
            f1 = f(x1)
            norm1 = float(np.linalg.norm(f1))

        if stalled:
            # Return the best point found rather than a worse trial point.
            return x0, f0
        if norm1 < tol:
            break

        # Sherman-Morrison form of Broyden's update:
        #   Jinv += (s - Jinv y) (s^T Jinv) / (s^T Jinv y)
        s, y = x1 - x0, f1 - f0
        Jy = Jinv @ y
        sJ = s @ Jinv
        den = float(s @ Jy)
        # Skip the update when the denominator is degenerate relative to the
        # size of its factors; dividing by it would corrupt Jinv.
        floor = 1e-14 * float(np.linalg.norm(s)) * float(np.linalg.norm(Jy))
        if np.isfinite(den) and abs(den) > floor:
            # np.outer is essential: for 1-D vectors "@" is an inner product
            # and would add one scalar to every entry of Jinv.
            Jinv = Jinv + np.outer(s - Jy, sJ) / den

        x0, f0, norm0 = x1, f1, norm1
    return x1, f1

In [27]:
# Start the search at the origin and let JAX differentiate `fun` to provide
# the Jacobian that `broyden` needs; the result is shown as the cell output.
pc0 = jnp.zeros(3)
broyden(fun, pc0, jac=jax.jacobian(fun), maxiter=50)

(DeviceArray([0.0177683 , 0.05189884, 0.0177683 ], dtype=float64),
 DeviceArray([-3.86414234e-11, -4.72381856e-12, -3.86414373e-11], dtype=float64))