In [None]:
import jax
from jax import numpy as jnp

In [None]:
def taylor_extension(n_terms=30, loc=0.):
    """Decorator lifting a scalar function to square matrices via its Taylor series.

    The decorated scalar function ``g`` is replaced by a wrapper computing

        g(A) ~= sum_{k=0}^{n_terms-1} g^{(k)}(loc) / k! * (A - loc*I)^k

    where each scalar derivative ``g^{(k)}`` is obtained by nesting ``jax.grad``.

    NOTE(review): a later cell redefines ``taylor_extension``, which shadows this
    definition after a full run; consider keeping only one of them.

    Args:
        n_terms: number of series terms (orders 0 .. n_terms-1). ``n_terms <= 0``
            yields a zero matrix.
        loc: scalar expansion point. It is passed directly to the (nested)
            ``jax.grad`` functions, so it should be a float.

    Returns:
        A decorator turning a scalar -> scalar function into a function of one
        square matrix argument.

    Raises:
        AssertionError: if the input is not a 2-D square matrix.
    """
    def decorator(function):
        def wrapper(input):
            input = jnp.asarray(input)
            # Ensure input is a square matrix (ndim check first so 1-D input fails cleanly)
            assert input.ndim == 2 and input.shape[0] == input.shape[1], "Input must be a square matrix"

            eye = jnp.eye(input.shape[0])
            # The series is expanded around `loc`, so the powers are of (A - loc*I), not A.
            shifted = input - loc * eye
            output = jnp.zeros_like(shifted)

            grad_fn = function
            # term_mat holds (A - loc*I)^k / k!. Dividing by k at each step avoids
            # forming k! explicitly, which overflows float32 beyond k = 34.
            term_mat = eye
            for term in range(n_terms):
                if term > 0:
                    # Each additional term adds one more level of nested jax.grad.
                    grad_fn = jax.grad(grad_fn)
                    term_mat = term_mat @ shifted / term

                output = output + grad_fn(loc) * term_mat

            return output

        return wrapper
    return decorator

In [None]:
import jax
import jax.numpy as jnp

def taylor_extension(n_terms=30, loc=0.0):
    """Decorator lifting a scalar function to square matrices via its Taylor series.

    The decorated scalar function ``g`` is replaced by a wrapper computing

        g(A) ~= sum_{k=0}^{n_terms-1} g^{(k)}(loc) / k! * (A - loc*I)^k

    where the derivatives ``g, g', g'', ...`` are built once, at decoration time,
    by nesting ``jax.grad``.

    Args:
        n_terms: number of series terms (orders 0 .. n_terms-1). ``n_terms <= 0``
            yields a zero matrix.
        loc: scalar expansion point. It is passed directly to the derivative
            functions, so it should be a float for ``jax.grad`` to accept it.

    Returns:
        A decorator turning a scalar -> scalar function into a function of one
        square matrix argument.

    Raises:
        AssertionError: if the input is not a 2-D square matrix.
    """
    def compute_taylor(input, grad_fns):
        """Sum the Taylor series of the derivative chain ``grad_fns`` at matrix ``input``."""
        input = jnp.asarray(input)
        assert input.ndim == 2 and input.shape[0] == input.shape[1], "Input must be a square matrix"

        eye = jnp.eye(input.shape[0])
        # The series is expanded around `loc`, so the powers are of (A - loc*I), not A.
        shifted = input - loc * eye
        output = jnp.zeros_like(shifted)

        # term_mat holds (A - loc*I)^k / k!. Dividing by k at each step avoids an
        # explicit factorial, which overflows float32 beyond k = 34.
        term_mat = eye
        for term, grad_fn in enumerate(grad_fns):
            if term > 0:
                term_mat = term_mat @ shifted / term

            output = output + grad_fn(loc) * term_mat

        return output

    def decorator(function):
        # Build exactly n_terms derivative functions [g, g', g'', ...] once per
        # decorated function instead of on every call.
        grad_fns = []
        for term in range(n_terms):
            grad_fns.append(function if term == 0 else jax.grad(grad_fns[-1]))

        def wrapper(input):
            return compute_taylor(input, grad_fns)
        return wrapper

    return decorator

In [None]:
@taylor_extension(n_terms=50)
def f(x):
    """Normalized sinc, sin(pi*x)/(pi*x), lifted to matrices by a 50-term series about 0."""
    sinc_value = jnp.sinc(x)
    return sinc_value

In [None]:
@taylor_extension(n_terms=50, loc=0.0)
def example_function(x):
    """Scalar sine; the decorator evaluates it on matrices via its Taylor series about 0."""
    return jnp.sin(x)

input_matrix = jnp.eye(3)
result = example_function(input_matrix)

# Sanity check: for the identity matrix every power is I, so the series sums to sin(1) * I.
assert jnp.allclose(result, jnp.sin(1.0) * jnp.eye(3), atol=1e-5), "Taylor series of sin disagrees with jnp.sin"
result

In [None]:
# Evaluate the sinc series surrogate `f` at the 3x3 identity.
# Since sinc(1) = sin(pi)/pi = 0, the displayed matrix should be close to all zeros.
input_matrix = jnp.identity(3)

f(input_matrix)