In [1]:
from jax_solver import batched_local_diagonal, batched_local_couplings
import jax.numpy as jnp
import jax
from rmatrix import Rmatrix_free
from cc_constants import CC_Constants
from core_solver_jax import Core_Solver
from radial_interactions import woods_saxon_potential, woods_saxon_deformed_interaction

In [2]:
from mpmath import coulombf, coulombg
import numpy as np
from scipy.special import eval_legendre

class CoulombAsymptotics:
    """
    Scalar Coulomb wave functions evaluated with mpmath.

    Both methods take ``(s, l, eta)`` and return an ``np.complex128``.
    Note that mpmath itself expects the argument order ``(l, eta, z)``.
    """

    @staticmethod
    def F(s, l, eta):
        """
        Coulomb function of the first kind, F_l(eta, s).
        """
        value = coulombf(l, eta, s)
        return np.complex128(value)

    @staticmethod
    def G(s, l, eta):
        """
        Coulomb function of the second kind, G_l(eta, s).
        """
        value = coulombg(l, eta, s)
        return np.complex128(value)


def H_plus(s, l, eta, asym=CoulombAsymptotics):
    """
    Outgoing Coulomb-Hankel function H+_l(eta, s) = G_l(eta, s) + i * F_l(eta, s).

    ``asym`` supplies the Coulomb functions ``F`` and ``G`` (called as
    ``asym.F(s, l, eta)`` / ``asym.G(s, l, eta)``).
    """
    second_kind = asym.G(s, l, eta)
    first_kind = asym.F(s, l, eta)
    return second_kind + 1j * first_kind


def H_minus(s, l, eta, asym=CoulombAsymptotics):
    """
    Incoming Coulomb-Hankel function H-_l(eta, s) = G_l(eta, s) - i * F_l(eta, s).

    ``asym`` supplies the Coulomb functions ``F`` and ``G`` (called as
    ``asym.F(s, l, eta)`` / ``asym.G(s, l, eta)``).
    """
    second_kind = asym.G(s, l, eta)
    first_kind = asym.F(s, l, eta)
    return second_kind - 1j * first_kind


def coulomb_func_deriv(func, s, l, eta):
    """
    Derivative with respect to s of a Coulomb-type function X_l(eta, s).

    Applies to F, G, H+ and H- via the recurrence of DLMF §33.4
    (https://dlmf.nist.gov/33.4):

        X_l' = S_{l+1} X_l - R_{l+1} X_{l+1},
        S_{l+1} = (l + 1)/s + eta/(l + 1),
        R_{l+1} = sqrt(1 + eta^2/(l + 1)^2).

    Parameters
    ----------
    func : callable
        Called as ``func(s, l, eta)``; evaluated at ``l`` and ``l + 1``.
    s : float
        Radial argument.
    l : int
        Orbital angular momentum (``l + 1`` must be non-zero).
    eta : float
        Coulomb (Sommerfeld) parameter.
    """
    lp1 = l + 1
    r_coef = np.sqrt(1 + eta**2 / lp1**2)
    s_coef = lp1 / s + eta / lp1
    x_l = func(s, l, eta)
    x_lp1 = func(s, lp1, eta)
    return s_coef * x_l - r_coef * x_lp1


def H_plus_prime(s, l, eta, asym=CoulombAsymptotics):
    """
    Derivative of the outgoing Coulomb-Hankel function H+_l(eta, s) with respect to s.

    Parameters
    ----------
    s : float
        Radial argument (the notebook calls this with k * a).
    l : int
        Orbital angular momentum.
    eta : float
        Coulomb (Sommerfeld) parameter.
    asym : class, optional
        Provider of the Coulomb functions ``F`` and ``G``; forwarded to ``H_plus``
        so the derivative is built from the same functions as H+ itself.

    Returns
    -------
    complex
        dH+_l/ds evaluated through the recurrence in ``coulomb_func_deriv``.
    """
    # Forward ``asym``: previously it was accepted but ignored, so H_l and
    # H_{l+1} in the recurrence always came from the default provider.
    return coulomb_func_deriv(
        lambda s_, l_, eta_: H_plus(s_, l_, eta_, asym=asym), s, l, eta
    )


def H_minus_prime(s, l, eta, dx=1e-6, asym=CoulombAsymptotics):
    """
    Derivative of the incoming Coulomb-Hankel function H-_l(eta, s) with respect to s.

    Parameters
    ----------
    s : float
        Radial argument (the notebook calls this with k * a).
    l : int
        Orbital angular momentum.
    eta : float
        Coulomb (Sommerfeld) parameter.
    dx : float, optional
        Not used: the derivative comes from an analytic recurrence. Kept only so
        existing callers that pass it positionally or by keyword keep working.
    asym : class, optional
        Provider of the Coulomb functions ``F`` and ``G``; forwarded to ``H_minus``
        so the derivative is built from the same functions as H- itself.

    Returns
    -------
    complex
        dH-_l/ds evaluated through the recurrence in ``coulomb_func_deriv``.
    """
    # Forward ``asym``: previously it was accepted but ignored.
    return coulomb_func_deriv(
        lambda s_, l_, eta_: H_minus(s_, l_, eta_, asym=asym), s, l, eta
    )

In [3]:
# Partial waves l = 0 .. l_max are solved together: one batch entry per l.
batch_size, nbasis, nchannels = 11, 30, 1
l_max = batch_size - 1
# NOTE(review): nbasis is presumably the size of the R-matrix basis / quadrature mesh — confirm.

# NOTE(review): CC_Constants is called positionally; the arguments are presumably
# (target mass, projectile mass, E_lab, channel energies) — confirm against cc_constants.
constants_class = CC_Constants(40.0, 1.0, 14.0, [0.0, 2.2])
E_com_arr = constants_class.E_lab_to_COM()
E_com = E_com_arr[0]                # first entry: the only channel used below (nchannels = 1)
hbar_2mu = constants_class.h2_mass  # presumably hbar^2 / (2 mu) — confirm units

In [4]:
a = 40.0  # channel radius

# One row per partial wave l = 0 .. l_max; every row shares the same energy and
# wave number (first entry of the constants, single channel).
E_arr = np.tile(np.array([E_com]), (batch_size, 1))
l_arr = np.array([[l] for l in range(l_max + 1)])
k_arr = np.tile(np.array([constants_class.k()[0]]), (batch_size, 1))


def _hankel_at_channel_radius(fn):
    """
    Evaluate ``fn(k * a, l, 0)`` for every (l, k) row, returned with shape (batch_size, 1, 1).

    ``fn`` is one of H_plus, H_plus_prime, H_minus, H_minus_prime. eta = 0 means
    no Coulomb interaction is included in the boundary functions.
    """
    values = [fn(k_row[0] * a, l_row[0], 0) for l_row, k_row in zip(l_arr, k_arr)]
    return np.array(values).reshape(batch_size, 1, 1)


Hp_arr = _hankel_at_channel_radius(H_plus)
Hpp_arr = _hankel_at_channel_radius(H_plus_prime)
Hm_arr = _hankel_at_channel_radius(H_minus)
Hmp_arr = _hankel_at_channel_radius(H_minus_prime)

In [5]:
rmatrix_free_class = Rmatrix_free(nbasis)

# One free matrix per partial wave (row of E_arr / l_arr). A comprehension is used
# so the loop variables (E, l, free_matrix) do not leak into the notebook's global
# namespace, and free_matrix_arr is never a list under the same name.
free_matrix_arr = np.array(
    [rmatrix_free_class.free_matrix(a, l, E, hbar_2mu) for E, l in zip(E_arr, l_arr)]
)
assert free_matrix_arr.shape[0] == batch_size, "expected one free matrix per partial wave"

b_arr = rmatrix_free_class.precompute_boundaries(a)

In [6]:
abscissa = rmatrix_free_class.kernel.quadrature.abscissa

# Woods-Saxon potential sampled at r = a * abscissa and placed on a diagonal, then
# copied once per partial wave (the same potential is used for every l).
# NOTE(review): 48.9, 1.19 * 40**(1/3) and 0.67 are presumably depth, radius and
# diffuseness, with 40 the target mass number — confirm the argument order in
# radial_interactions.woods_saxon_potential.
diag_potential = np.tile(
    np.diag(
        np.array(
            woods_saxon_potential(a * abscissa, 48.9, 40 ** (1 / 3) * 1.19, 0.67),
            dtype=np.complex64,
        )
    ),
    (batch_size, 1, 1),
)
couplings = np.tile(np.eye(1), (batch_size, 1, 1))  # 1x1 identity per batch entry

# Extra leading axis of length 1 — presumably one interaction term; confirm against Core_Solver.
# NOTE(review): these are complex64 while the Hankel boundary arrays are complex128 — confirm intended.
total_couplings_arr = np.array([couplings], dtype=np.complex64)
total_potential_arr = np.array([diag_potential], dtype=np.complex64)

In [7]:
core_solver_class = Core_Solver(
    free_matrix_arr,
    b_arr,
    Hp_arr,
    Hpp_arr,
    Hm_arr,
    Hmp_arr,
    total_couplings_arr,
    hbar_2mu,
    a,
    nbasis,
    nchannels,
    batch_size,
    1,  # NOTE(review): meaning of this trailing positional argument is not visible here — confirm
)

 No GPU found. Using CPU.


In [9]:
%%timeit
# Benchmark one batched solve over all partial waves. `.block_until_ready()` waits
# for JAX's asynchronous dispatch to finish, so the timing covers the computation.
# NOTE: names assigned inside a %%timeit cell (S_batch here) are not kept in the
# notebook namespace — call the solver again outside %%timeit to inspect the result.
S_batch = core_solver_class.solver(appended_block_arr=total_potential_arr,fn_core=core_solver_class.fn_core, fn_interaction=core_solver_class.fn_interaction).block_until_ready()


301 μs ± 1.23 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
