In [None]:
from jax_solver import batched_local_diagonal, batched_local_couplings
import jax.numpy as jnp
import jax
from cc_couplings import CC_Couplings
from rmatrix import Rmatrix_free
from cc_constants import CC_Constants
from cc_asymptotics import CC_Asymptotics
import numpy as np
from core_solver_jax import Core_Solver
from radial_interactions import woods_saxon_potential, woods_saxon_deformed_interaction


In [None]:
# Basis / channel configuration used by every later cell:
#   nbasis    - size of the R-matrix basis (passed to Rmatrix_free and Core_Solver)
#   nchannels - channel count of the homogeneous block extracted for testing
#   n_int     - presumably the number of interaction terms stacked later (TODO confirm)
#   a         - channel radius
nbasis, nchannels, n_int, a = 50, 4, 2, 40.0

In [None]:

# Build the constants, couplings and free R-matrix objects, then the batched inputs.
# NOTE(review): the meanings of the positional arguments below (48.0, 1.0, [0.0, 5.5],
# [0.0, 2.0], 1.0, 1.0, 10.0) are not visible here — confirm against the
# CC_Constants / CC_Couplings signatures.
energy_system = 30.0     # energy of the system; previously hard-coded in two places
deformation_order = 2.0  # order of the deformation passed to the coupling generator

constants_class = CC_Constants(48.0, 1.0, energy_system, [0.0, 5.5])
couplings_class = CC_Couplings(48.0, 1.0, [0.0, 5.5], [0.0, 2.0], 1.0, 1.0, 10.0)
rmatrix_free_class = Rmatrix_free(nbasis)
hbar_2mu = constants_class.h2_mass

# dictionary of quantum numbers, keyed by matrix shape (the number of channels)
quantum_numbers_dict = couplings_class.batched_dict

# dictionaries of energy, l, and k arrays, keyed by matrix shape
# NOTE(review): the saved output of this cell shows a warning echoing
# `E_ch = float(E_com[idx])` from package code — check whether a complex energy is
# being truncated to a real float there.
E_dict, l_dict, k_dict, energy_keys = couplings_class.generate_energy_centrifugal_mom_batched(energy_system)

# dictionary of the coupling matrices, keyed by matrix shape
couplings_dict, keys = couplings_class.generate_couplings_batched(deformation_order)

# The asymptotic (Bessel) quantities are generated for the coupling keys.
# NOTE(review): energy_keys is assumed to be the same key set as keys — confirm.
Hp_dic, Hpp_dic, Hm_dic, Hmp_dic = CC_Asymptotics.generate_bessel_batched(a, l_dict, k_dict, keys)



  E_ch = float(E_com[idx])


In [None]:
# Extract the homogeneous blocks (exactly `nchannels` channels) for testing.
# Every *_dict is keyed by channel count, so `.get(nchannels)` selects one batch.
quantum_numbers_arr = quantum_numbers_dict.get(nchannels)
E_arr = E_dict.get(nchannels)
l_arr = l_dict.get(nchannels)
k_arr = k_dict.get(nchannels)
Hp_arr = Hp_dic.get(nchannels)
Hpp_arr = Hpp_dic.get(nchannels)
Hm_arr = Hm_dic.get(nchannels)
Hmp_arr = Hmp_dic.get(nchannels)
couplings_arr = couplings_dict.get(nchannels)

# `.get` yields None for a missing channel count. Stop here with a clear message
# rather than failing later inside the free-matrix loop or inside Core_Solver.
_missing = [
    name
    for name, block in (
        ("quantum_numbers", quantum_numbers_arr),
        ("E", E_arr),
        ("l", l_arr),
        ("k", k_arr),
        ("Hp", Hp_arr),
        ("Hpp", Hpp_arr),
        ("Hm", Hm_arr),
        ("Hmp", Hmp_arr),
        ("couplings", couplings_arr),
    )
    if block is None
]
if _missing:
    raise KeyError(f"no {nchannels}-channel block found for: {', '.join(_missing)}")

# Free-particle R-matrix block for every (E, l) pair of the batch. The comprehension
# keeps the loop variables out of the notebook namespace.
free_matrix_arr = np.array(
    [rmatrix_free_class.free_matrix(a, l, E, hbar_2mu) for E, l in zip(E_arr, l_arr)]
)

# Boundary quantities at the channel radius (computed by the R-matrix helper).
b_arr = rmatrix_free_class.precompute_boundaries(a)

# Coupling matrices come stacked as (batch, nch, nch); make sure they are square.
batch, nch, nch_cols = couplings_arr.shape
if nch != nch_cols:
    raise ValueError(f"coupling matrices must be square, got shape {couplings_arr.shape}")

# Identity in channel space, replicated once per batch entry.
diag = np.eye(nch, dtype=np.complex128)
diag_couplings_arr = np.tile(diag, (batch, 1, 1))

# Stack order: index 0 = couplings_arr, index 1 = identity.
# NOTE(review): the interaction cell stacks [spherical, deformed]. If Core_Solver pairs
# the two stacks by index, the spherical potential gets couplings_arr and the deformed
# one the identity, which looks reversed — confirm the intended order.
total_couplings_arr = np.array([couplings_arr, diag_couplings_arr])

(9, 4)


In [None]:
# Assemble the batched solver for the homogeneous nchannels block.
# The final argument used to be a literal 2. It equals n_int from the config cell and the
# number of stacked coupling terms in total_couplings_arr; presumably it is the number of
# interaction terms (TODO confirm against Core_Solver).
assert len(total_couplings_arr) == n_int, "expected one coupling stack per interaction term"
core_solver_class = Core_Solver(
    free_matrix_arr, b_arr, Hp_arr, Hpp_arr, Hm_arr, Hmp_arr,
    total_couplings_arr, hbar_2mu, a, nbasis, nchannels, batch, n_int,
)

In [None]:
# Radial interactions evaluated on the R-matrix quadrature abscissae, turned into diagonal
# matrices and replicated over the batch.
# NOTE(review): this cell works in complex64, while the identity couplings built earlier
# are complex128. Confirm single precision is intended, and whether jax_enable_x64 is set
# elsewhere.
abscissa = jnp.array(rmatrix_free_class.kernel.quadrature.abscissa, dtype=jnp.complex64)

# Woods-Saxon parameters shared by the spherical and deformed terms. They are defined once
# so the two terms cannot drift apart.
# NOTE(review): order presumably (depth, radius, diffuseness) — confirm against
# radial_interactions.
ws_params = (48.0, 4.5, 0.65)
interaction_sph = woods_saxon_potential(abscissa, *ws_params)
# 1.2 is the extra leading parameter of the deformed form (its meaning is not visible here).
interaction_def = woods_saxon_deformed_interaction(abscissa, 1.2, *ws_params)

# Presumably 1-D values over the abscissae -> diagonal matrices, tiled to (batch, n, n).
interaction_sph_matrix = jnp.tile(jnp.diag(interaction_sph), (batch, 1, 1))
interaction_def_matrix = jnp.tile(jnp.diag(interaction_def), (batch, 1, 1))

# Stack order: index 0 = spherical, index 1 = deformed.
# Handed to the solver as a writable NumPy copy.
total_interaction_arr = np.array(
    jnp.stack([interaction_sph_matrix, interaction_def_matrix]), dtype=np.complex64
)

In [None]:
# Run the batched solve: pass the stacked interaction matrices plus the solver's own
# core and interaction callbacks. S_batch is presumably the batch of S-matrices
# (TODO confirm against Core_Solver.solver).
S_batch = core_solver_class.solver(
    appended_block_arr=total_interaction_arr,
    fn_core=core_solver_class.fn_core,
    fn_interaction=core_solver_class.fn_interaction,
)