# Basis function backends
The TFC module provides both C++ and Python versions of the basis function backends. For the `utfc` and `mtfc` classes, the backend is selected with the `backend` keyword. The `C++` backend can only handle the `double` type, whereas the `Python` backend is more flexible; for example, it can handle complex values. The `C++` backend is the only one that supports the JAX `jit` transform directly, but the Python backend can still be JIT-compiled with `pejit`, provided the basis functions can be cached: see the [pejit tutorial](pejit.ipynb).

For the vast majority of applications, `double` is the type used. In addition, the regular JAX `jit` transform is easier for new users than `pejit`, so `C++` is the default backend. However, more advanced applications, e.g., solving complex differential equations, need a more general version of the basis functions, which is why the Python backend exists.
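To illustrate why a pure-Python (here NumPy) basis implementation can be type-generic, the sketch below evaluates Chebyshev polynomials with the standard three-term recurrence. The helper `cheb_basis` is hypothetical and is not TFC's implementation; it only shows that code built from generic array arithmetic carries a complex input dtype through to the result, rather than being fixed to `double`.

```python
import numpy as np

def cheb_basis(z, deg):
    """Hypothetical helper (not part of TFC): T_0..T_deg evaluated at z.

    Uses the recurrence T_{n+1}(z) = 2 z T_n(z) - T_{n-1}(z). Only generic
    arithmetic on z is used, so float64 input gives float64 output and
    complex128 input gives complex128 output.
    """
    z = np.asarray(z)
    # Allocate with the promoted dtype of z (at least float64).
    T = np.empty(z.shape + (deg + 1,), dtype=np.result_type(z, np.float64))
    T[..., 0] = 1.0
    if deg >= 1:
        T[..., 1] = z
    for n in range(1, deg):
        T[..., n + 1] = 2.0 * z * T[..., n] - T[..., n - 1]
    return T

real_z = np.linspace(-1.0, 1.0, 3)
complex_z = real_z + 0.5j

print(cheb_basis(real_z, 2).dtype)
print(cheb_basis(complex_z, 2).dtype)
```

The first `print` reports `float64` and the second `complex128`; a kernel compiled only for `double` has no such flexibility.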

## Using the backends
Apart from the JIT transform, the two backends behave identically: they share the same API and are used in the same way.

```python
import jax.numpy as np
from tfc import utfc
from tfc.utils import egrad

# Create two versions of the utfc class. One with a C++ backend and the other with a Python backend.
cppBackend = utfc(6, 0, 2, x0=0.0, xf=1.0)  # C++ is the default backend
pythonBackend = utfc(6, 0, 2, x0=0.0, xf=1.0, backend="Python")

# Get H and x
x = cppBackend.x
Hcpp = cppBackend.H
Hpython = pythonBackend.H

# Take a derivative and print the result
dHcpp = egrad(Hcpp)
dHpython = egrad(Hpython)

print("C++ result:")
print(dHcpp(x))
print("\nPython result:")
print(dHpython(x))
```

```
C++ result:
[[ 0.          2.         -8.        ]
 [ 0.          2.         -6.47213595]
 [ 0.          2.         -2.47213595]
 [ 0.          2.          2.47213595]
 [ 0.          2.          6.47213595]
 [ 0.          2.          8.        ]]

Python result:
[[ 0.          2.         -8.        ]
 [ 0.          2.         -6.47213595]
 [ 0.          2.         -2.47213595]
 [ 0.          2.          2.47213595]
 [ 0.          2.          6.47213595]
 [ 0.          2.          8.        ]]
```
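As an independent cross-check of these numbers, the sketch below recomputes the derivative matrix with NumPy's Chebyshev module. It rests on assumptions that are not stated in this tutorial: that the default `utfc` basis is the Chebyshev polynomials $T_0, T_1, T_2$, that the 6 points are Chebyshev-Gauss-Lobatto points, and that $x \in [0, 1]$ is mapped linearly to $z \in [-1, 1]$, so $dz/dx = 2$ enters through the chain rule.

```python
import numpy as np
from numpy.polynomial import chebyshev as C

# Assumed setup: 6 Chebyshev-Gauss-Lobatto points in ascending order on [-1, 1].
N, deg = 6, 2
z = -np.cos(np.pi * np.arange(N) / (N - 1))

# Assumed linear map from x in [x0, xf] to z in [-1, 1].
x0, xf = 0.0, 1.0
dz_dx = (1.0 - (-1.0)) / (xf - x0)

columns = []
for n in range(deg + 1):
    coeffs = np.zeros(deg + 1)
    coeffs[n] = 1.0  # coefficient vector selecting T_n
    # chebder's scl argument multiplies the derivative by dz/dx (chain rule).
    columns.append(C.chebval(z, C.chebder(coeffs, scl=dz_dx)))
dH = np.stack(columns, axis=1)
print(dH)
```

Under these assumptions the columns are $0$, $2$, and $8z$, which reproduces the matrix printed by both backends.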


When compiling with JAX's JIT, only the `C++` backend can be compiled natively. To compile a function that uses the Python backend, the basis function values must be cached as compile-time constants using `pejit`.

```python
from jax import jit
from tfc.utils import pejit

# Define xi for use in f
xi = np.ones(Hcpp(x).shape[1])

# Define the functions to be JITed
cpp_f = lambda x,xi: np.dot(dHcpp(x),xi)
python_f = lambda x,xi: np.dot(dHpython(x),xi)

# JIT the functions
cpp_f_jit = jit(cpp_f)
# x (argument 0) is cached as a compile-time constant, so the compiled function only takes xi
python_f_jit = pejit(x, xi, constant_arg_nums=[0])(python_f)

# Print the results
print("C++ backend result:")
print(cpp_f_jit(x,xi))
print("\nPython backend result:")
print(python_f_jit(xi))
```

```
C++ backend result:
[-6.         -4.47213595 -0.47213595  4.47213595  8.47213595 10.        ]

Python backend result:
[-6.         -4.47213595 -0.47213595  4.47213595  8.47213595 10.        ]
```


In other words, a function can only be compiled with the Python backend if the basis functions do not need to be evaluated at run time: the function must be set up so that the output of the Python basis functions is known, and can therefore be cached, at compile time. This is the case when solving differential equations (see, e.g., the [complex ODE tutorial](Complex_ODE.ipynb)) and in many other optimization problems.
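The caching idea can be sketched without TFC at all. The NumPy example below is a conceptual illustration only, not how `pejit` is implemented: because the collocation points never change, the basis derivative matrix is evaluated once, and the function of `xi` reduces to a matrix-vector product. The helper `basis_derivative` is hypothetical, and the Chebyshev basis, Chebyshev-Gauss-Lobatto points, and $z = 2x - 1$ map are assumptions chosen to mirror the example above.

```python
import numpy as np

# Assumed setup: degree-2 Chebyshev basis on 6 Chebyshev-Gauss-Lobatto points,
# with x in [0, 1] mapped to z = 2x - 1, so dz/dx = 2.
N, deg = 6, 2
z = -np.cos(np.pi * np.arange(N) / (N - 1))
dz_dx = 2.0

def basis_derivative(z):
    """Hypothetical helper: d/dx of T_0, T_1, T_2 at the points z."""
    return np.stack([np.zeros_like(z), np.ones_like(z), 4.0 * z], axis=1) * dz_dx

# "Compile-time" step: the points are fixed, so evaluate the basis once and keep it.
dH_cached = basis_derivative(z)

# "Run-time" step: only xi varies, so no basis evaluation is needed here.
def f(xi):
    return dH_cached @ xi

print(f(np.ones(deg + 1)))
```

With `xi` set to ones this gives the same values as both backends in the previous cell. If the points themselves changed between calls, the cached matrix would be stale, which is exactly the situation in which the Python backend cannot be compiled this way.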