In [7]:
from tesseract_core import Tesseract
from tesseract_jax import apply_tesseract

import jax
import jax.numpy as jnp
from jax.numpy import array
jax.config.update("jax_enable_x64", True)

import sys
sys.path.append(".")
sys.path.append("src")

import tesseracts.sheaths.tanh_sheath.tesseract_api as tx_api

tx = Tesseract.from_tesseract_api(tx_api)

In [37]:
def Ip_of_V(y):
    V = y[0]
    Ip = apply_tesseract(tx, {"N": array(1e18), 
                             "n": array(1.1e22), 
                             "T": array(2e1), 
                             "Vp": array(V), 
                             "Lz": array(2.5) })["Ip"]
    return jnp.array([Ip])


jac_f = lambda y: jnp.atleast_2d(tx.jacobian({"N": array(1e18), 
                             "n": array(1.1e22), 
                             "T": array(2e1), 
                             "Vp": array(y[0]), 
                             "Lz": array(2.5) }, ["Vp"], ["Ip"])["Ip"]["Vp"])


def newton_solve(f, x0, jac):
    def newton_step(x):
        print(x)
        f_val, f_vjp = jax.vjp(f, x)
        J = jnp.stack([f_vjp(jnp.array([1.]))[0]], axis=1).T
        return x - jnp.linalg.solve(J, f_val[:, None]).flatten()
    return jax.lax.while_loop(lambda x: jnp.all(jnp.abs(f(x)) > 1e-10),
                              newton_step,
                              x0)
    
f = lambda V: Ip_of_V(V) + 43.0

In [38]:
newton_solve(f, jnp.array([0.0]), jac_f)

Traced<float64[1]>with<DynamicJaxprTrace>


Array([0.13434265], dtype=float64)

In [40]:
def my_tangent_solve(g, y):
    # Jg = jac_f(y)
    Jg = jnp.atleast_2d(g(y) - g(jnp.array([0.0])) / y)
    return jnp.linalg.solve(Jg, y[:, None]).flatten()
        
    
jax.lax.custom_root(f, jnp.array([0.0]),
                    lambda f, x: newton_solve(f, x, jac_f),
                    my_tangent_solve)

Traced<float64[1]>with<DynamicJaxprTrace>


Array([0.13434265], dtype=float64)

In [41]:
def test(target_Ip):
    f = lambda V: Ip_of_V(V) - target_Ip
    solve = lambda f, x: newton_solve(f, x, jac_f)
    def my_tangent_solve(g, y):
        Jg = jnp.atleast_2d(g(y) - g(jnp.array([0.0])) / y)
        return jnp.linalg.solve(Jg, y[:, None]).flatten()
        
    return jax.lax.custom_root(f, jnp.array([0.0]), solve, my_tangent_solve)[0]
    
jax.grad(test)(-30.0)

Traced<float64[1]>with<DynamicJaxprTrace>


TypeError: Argument 'UndefinedPrimal(float64[1,1])' of type '<class 'jax._src.interpreters.ad.UndefinedPrimal'>' is not a valid JAX type

In [10]:
def newton_solve(f, x0):
    def newton_step(x):
        f_val, f_vjp = jax.vjp(f, x)
        Jf = jnp.stack([f_vjp(jnp.array([1.]))[0]], axis=1).T
        return x - jnp.linalg.solve(Jf, f(x)[:, None]).flatten()

    return jax.lax.while_loop(lambda x: jnp.all(jnp.abs(f(x)) > 1e-10),
                              newton_step,
                              x0)
    

with Tesseract.from_tesseract_api(tx_api) as tx:
    def Ip_of_V(y):
        V = y[0]
        Ip = apply_tesseract(tx, {"N": array(1e18), 
                                             "n": array(1.1e22), 
                                             "T": array(2e1), 
                                             "Vp": array(V), 
                                             "Lz": array(2.5) })["Ip"]
        return jnp.array([Ip])

    def test(target_Ip):
        f = lambda V: Ip_of_V(V) - target_Ip
        jax.jacobian(Ip_of_V)(jnp.array([target_Ip]))
                
    Vp = test(-50.0)

NotImplementedError: Batching rule for 'tesseract_dispatch' not implemented

In [13]:
with Tesseract.from_tesseract_api(tx_api) as tx:
    def Ip_of_V(y):
        V = y[0]
        Ip = apply_tesseract(tx, {"N": array(1e18), 
                                 "n": array(1.1e22), 
                                 "T": array(2e1), 
                                 "Vp": array(V), 
                                 "Lz": array(2.5) })["Ip"]
        return jnp.array([Ip])

    jax.jacrev(Ip_of_V)(jnp.array([0.0]))

NotImplementedError: Batching rule for 'tesseract_dispatch' not implemented