# Adding your own custom high-fidelity solver

Advanced users may want to replace the default solvers in `rose` with their own custom ones. This tutorial walks through that process using the scattering of 35 MeV protons on $^{27}$Al, with the [Koning-Delaroche](https://www.sciencedirect.com/science/article/pii/S0375947402013210?casa_token=qS1v6U4xDQEAAAAA:NIi9D5LpP3f05AMwRnvbQ6or8hSvXoEIgKBV56KA4l9aObCOVDAndmuCeIH77iuzoXMOOlAMyw) optical potential for the proton-nucleus interaction. 

We will add [this solver](https://pypi.org/project/lagrange-rmatrix/), which uses the calculable R-matrix method on a Lagrange-Legendre mesh. In principle, this method is very precise, and it can also handle non-local and coupled-channel potentials. 

In [1]:
import rose
import numpy as np
from matplotlib import pyplot as plt

In [2]:
# !pip install lagrange-rmatrix
from lagrange_rmatrix import (
    ProjectileTargetSystem,       # defines channel-independent data for the system
    RadialSEChannel,              # channel description
    LagrangeRMatrix,              # solver
    woods_saxon_potential,        # short-range nuclear interaction
    coulomb_charged_sphere,       # long-range EM interaction
    delta,                        # function to calculate phase shift 
)


In [3]:
# Kinematics for p + 27-Al, plus the default Koning-Delaroche (KD) global parameters.
from rose.koning_delaroche import KDGlobal, Projectile

# target nucleus: 27-Al
Z = 13  # target proton number
A = 27  # target mass number

Elab = 35  # lab-frame bombarding energy, MeV

# get_params yields (mu, Ecom, k, eta, R_C) together with the KD parameter vector
omp = KDGlobal(Projectile.proton)
kinematics, parameters = omp.get_params(A, Z, Elab)
mu, Ecom, k, eta, R_C = kinematics

  return float(df[mask]["BINDING_ENERGY/A"]) * A / 1e3
  return float(df[mask]["BINDING_ENERGY/A"]) * A / 1e3


In [4]:
# Create an interaction space for the partial waves, using the KD_simple
# potential with KD_simple_so as the spin-orbit term.
# Charges: projectile is a proton (Z_1=1); the target charge reuses Z so it
# cannot drift out of sync with the kinematics set up for the same target.
interactions = rose.InteractionSpace(
    rose.koning_delaroche.KD_simple,
    len(parameters),  # number of KD parameters returned by get_params
    mu,
    Ecom,
    is_complex=True,  # optical potential has an imaginary part
    spin_orbit_potential=rose.koning_delaroche.KD_simple_so,
    Z_1=1,
    Z_2=Z,
    R_C=R_C,
)

To use the `lagrange-rmatrix` package as a custom solver in `rose`, we need to write a subclass of `rose.SchroedingerEquation`, as follows:

In [5]:
# First, set up the channel-independent system for the Lagrange R-matrix solver.
# The channel radius is taken as nodes_within_radius multiples of 2*pi.
# NOTE(review): presumably this is in the dimensionless s = k*r coordinate, i.e.
# roughly nodes_within_radius wavelengths -- confirm against the package docs.
# NOTE: the name `sys` shadows the standard-library module of the same name.
nodes_within_radius = 6
sys = ProjectileTargetSystem(
    channel_radius=nodes_within_radius * (2 * np.pi),
    reduced_mass=mu,
    incident_energy=Ecom,
    Zproj=1,
    Ztarget=Z,
)

In [6]:
class LagrangeRmatrixSolver(rose.SchroedingerEquation):
    """Adapter exposing a ``lagrange_rmatrix`` solver through rose's
    ``SchroedingerEquation`` interface.

    Parameters
    ----------
    interaction : rose.Interaction
        Interaction for a single partial wave; its ``ell`` attribute sets the
        orbital angular momentum of the channel.
    sys : ProjectileTargetSystem
        Channel-independent system data passed to the channel and the solver.
    Nbasis : int
        Basis size handed to ``LagrangeRMatrix``.
    """

    def __init__(
        self, interaction: rose.Interaction, sys: ProjectileTargetSystem, Nbasis: int
    ):
        self.interaction = interaction
        # Keep the constructor arguments so clone_for_new_interaction can rebuild
        # an equivalent solver without depending on attribute names of the
        # third-party RadialSEChannel / LagrangeRMatrix objects.
        self._system = sys
        self._nbasis = Nbasis
        self.se = RadialSEChannel(
            l=interaction.ell,
            system=sys,
            # TODO(review): no potential is passed yet, so the rose interaction
            # is not wired into the channel.
            interaction=None,
        )
        self.solver = LagrangeRMatrix(Nbasis, sys, self.se)

    def clone_for_new_interaction(self, interaction: rose.Interaction):
        """Return a new solver for ``interaction`` with the same system and basis size."""
        return LagrangeRmatrixSolver(interaction, self._system, self._nbasis)

    def phi(self, alpha: np.array, s_mesh: np.array):
        """Solve for the channel wave function (incomplete).

        TODO(review): the solution ``u`` is computed but is not yet mapped onto
        ``s_mesh`` or returned, and ``alpha`` is unused, so this currently
        returns None.
        """
        _, _, u = self.solver.solve_wavefunction()

In [7]:
                                                        
import numpy as np                                                                                      
from time import perf_counter                                                                           


In [8]:
from numba import jit, njit, config, __version__, errors
from numba.extending import overload
import numba
import numpy as np
# Require numba >= 0.46, comparing only the major and minor version components.
numba_major_minor = tuple(map(int, __version__.split('.')[:2]))
assert numba_major_minor >= (0, 46)
print(__version__)

0.57.1


In [9]:
from scipy import special
import scipy
print(scipy.__version__)
# Smoke test: evaluate a few scipy.special functions inside numba-compiled code.
# NOTE(review): plain numba does not support scipy.special in nopython mode; this
# appears to rely on an extension such as numba-scipy being installed -- confirm.
@njit
def call_scipy_in_jitted_code():
    """Print special.beta, special.j0 and special.eval_legendre values from an @njit function."""
    print("special.beta(1.2, 3.4)", special.beta(1.2, 3.4))
    print("special.j0(5.6)       ", special.j0(5.6))
    print("special.eval_legendre(3,0.1)  ", special.eval_legendre(3,0.1))

    
# First call triggers JIT compilation, then runs the prints above.
call_scipy_in_jitted_code()

1.7.3
special.beta(1.2, 3.4) 0.20455811064350188
special.j0(5.6)        0.026970884685114372
special.eval_legendre(3,0.1)        -0.1474999999999999


In [None]:
# Echo the installed SciPy version (the notebook displays the last expression's value).
scipy.__version__