In [46]:
from functools import partial
import inspect
from typing import Callable
import jax
from jax import Array
import jax.numpy as jnp

def lennard_jones(r, sigma, epsilon):
    return 1 * epsilon * ( (sigma/r) ** 12 - (sigma/r) ** 6)

def parse_jaxpr(fn: Callable, args: tuple[str] | None = None, symbols: dict = {}):

    placeholders = []
    
    if args is None:
        sig = inspect.signature(fn)
        for key, val in sig.parameters.items():
            if val.default == inspect._empty:
                placeholders.append(key)

    jaxpr_kwargs = {
        key: jax.ShapedArray((), dtype=jnp.float32)
        for key in placeholders
    }
    jaxpr = jax.make_jaxpr(fn)(**jaxpr_kwargs)
    symbols = {
        str(sym): val if val not in symbols else symbols[val]
        for (sym, val) in zip(jaxpr.jaxpr.invars, placeholders)
    }

    def fetch_symbol(arg):
        if isinstance(arg, jax.core.Var):
            return symbols[str(arg)]
        elif isinstance(arg, jax.core.Literal):
            return str(arg)
        else:
            raise ValueError()
    
    for eqn in jaxpr.eqns:
        match eqn.primitive.name:
            case "mul":
                fst, snd = map(fetch_symbol, eqn.invars)
                out = str(eqn.outvars[0])
                symbols[out] = f"({fst} * {snd})"
            case "div":
                fst, snd = map(fetch_symbol, eqn.invars)
                out = str(eqn.outvars[0])
                symbols[out] = f"({fst} / {snd})"
            case "integer_pow":
                fst, = map(fetch_symbol, eqn.invars)
                exp = str(eqn.params['y'])
                out = str(eqn.outvars[0])
                symbols[out] = f"{fst}^{exp}"
            case "sub":
                fst, snd = map(fetch_symbol, eqn.invars)
                out = str(eqn.outvars[0])
                symbols[out] = f"({fst} - {snd})"
            case "add":
                fst, snd = map(fetch_symbol, eqn.invars)
                out = str(eqn.outvars[0])
                symbols[out] = f"({fst} + {snd})"
            case "convert_element_type":
                fst, = map(fetch_symbol, eqn.invars)
                out = str(eqn.outvars[0])
                symbols[out] = f"{fst}"
            case _:
                raise ValueError(eqn.primitive.name)

    return tuple(map(fetch_symbol, jaxpr.jaxpr.outvars))

def approx_with_square(original, cutoff: Array, slope: float):
    y0, dyx0 = jax.value_and_grad(original)(cutoff)
    a = slope
    b = dyx0 - 2 * a * cutoff
    c = y0 - a * cutoff ** 2 - b * cutoff

    def approx(r):
        return a * r ** 2 + b * r + c

    return approx

def approx_with_linear(original, cutoff: Array):
    y0, dyx0 = jax.value_and_grad(original)(cutoff)
    a = dyx0
    b = y0 - a * cutoff

    def approx(r):
        return a * r + b 

    return approx

def approx(fun, approximation, cutoff, **kwargs):
    original = parse_jaxpr(fun)[0]
    fun_ = partial(fun, **kwargs)
    approx = parse_jaxpr(approximation(fun_, cutoff))[0]
    sig = inspect.signature(fun)
    inp = tuple(sig.parameters.keys())[0]
    filter = f"step({inp} - {cutoff})"
    return f"select({filter}, {original}, {approx})"

# parse_jaxpr(approx_with_linear(lennard_jones, cutoff=0.1, sigma0=0.1, epsilon0=0.1))[0]
approx(partial(lennard_jones, sigma=0.316435, epsilon=0.680946), partial(approx_with_square, slope=10.), 0.8)

'select(step(r - 0.8), (0.6809459924697876 * ((0.3164350092411041 / r)^12 - (0.3164350092411041 / r)^6)), (((10.0 * r^2) + (-15.9805908203125 * r)) + 6.3818745613098145))'

In [24]:
sig = inspect.signature(partial(lennard_jones, sigma=0.3, epsilon=4.))
ist(sig.parameters.values())[0].default
    

inspect._empty

In [6]:
approx(lennard_jones, approx_with_linear, 0.08,  sigma=0.34, epsilon=4.)

'select(step(r - 0.08), ((4.0 * epsilon) * ((sigma / r)^12 - (sigma / r)^6)), ((-83337707520.0 * r) + 7222554112.0))'

In [41]:
parse_jaxpr(lennard_jones)

('((4.0 * r) * ((epsilon / sigma)^12 - (epsilon / sigma)^6))',)