In [1]:
from jax_solver import batched_local_diagonal, batched_local_couplings
import jax.numpy as jnp
import jax
from cc_couplings import CC_Couplings
from rmatrix import Rmatrix_free
from cc_constants import CC_Constants
from cc_asymptotics import CC_Asymptotics
import numpy as np
from core_solver_jax import Core_Solver
from radial_interactions import woods_saxon_potential, woods_saxon_deformed_interaction


In [81]:
nbasis = 2
nchannels = 4
n_int = 2
a = 40.0 #channel radius

In [82]:

constants_class = CC_Constants(48.0, 1.0, 30.0, [0.0, 5.5])
couplings_class = CC_Couplings(48.0, 1.0, [0.0, 5.5], [0.0, 2.0], 1.0, 1.0, 10.0)
rmatrix_free_class = Rmatrix_free(nbasis)
hbar_2mu = constants_class.h2_mass

# dictionary of quantum number is matrix shape, the keys are the shapes of the matrices
quantum_numbers_dict = couplings_class.batched_dict

# dictionaries of energy, l, and k arrays, the keys are the shapes of the matrices
E_dict, l_dict, k_dict, keys = couplings_class.generate_energy_centrifugal_mom_batched(30.0) #passs the energy of the system

# dictionary of the coupling matrices, the keys are the shapes of the matrices
couplings_dict, keys = couplings_class.generate_couplings_batched(2.0) #pass the order of the deformation

Hp_dic, Hpp_dic, Hm_dic, Hmp_dic = CC_Asymptotics.generate_bessel_batched(a, l_dict, k_dict, keys)



In [83]:
#extract the homogeneous shapes for testing:
quantum_numbers_arr = quantum_numbers_dict.get(nchannels)
E_arr = E_dict.get(nchannels)
l_arr = l_dict.get(nchannels)
k_arr = k_dict.get(nchannels)
Hp_arr = Hp_dic.get(nchannels)
Hpp_arr = Hpp_dic.get(nchannels)
Hm_arr = Hm_dic.get(nchannels)
Hmp_arr = Hmp_dic.get(nchannels)
couplings_arr = couplings_dict.get(nchannels)

#generate the free matrices:
free_matrix_arr = []
for E, l in zip(E_arr, l_arr):
    free_matrix = rmatrix_free_class.free_matrix(a, l, E, hbar_2mu)
    free_matrix_arr.append(free_matrix)
free_matrix_arr = np.array(free_matrix_arr)

b_arr = rmatrix_free_class.precompute_boundaries(a)

batch, nch, nch = couplings_arr.shape
diag = np.eye(nch, dtype=np.complex128)
diag_couplings_arr = np.tile(diag, (batch, 1, 1)) 

total_couplings_arr = np.array([couplings_arr, diag_couplings_arr])

In [159]:
print(Hm_arr[0])

[[-0.99401012+0.11535611j  0.        +0.j          0.        +0.j
   0.        +0.j        ]
 [ 0.        +0.j          0.08719375+0.99619137j  0.        +0.j
   0.        +0.j        ]
 [ 0.        +0.j          0.        +0.j         -0.01643625-1.00070344j
   0.        +0.j        ]
 [ 0.        +0.j          0.        +0.j          0.        +0.j
  -0.14898574+0.99167939j]]


In [84]:
core_solver_class = Core_Solver(free_matrix_arr, b_arr, Hp_arr, Hpp_arr, Hm_arr,
                                Hmp_arr,total_couplings_arr, hbar_2mu, a, nbasis, nchannels, batch, 2)

 No GPU found. Using CPU.


In [85]:
abscissa = rmatrix_free_class.kernel.quadrature.abscissa
abscissa = jnp.array(abscissa, dtype=jnp.complex64)
interaction_sph = woods_saxon_potential(abscissa, 48.0, 4.5, 0.65)
interaction_def = woods_saxon_deformed_interaction(abscissa, 1.2, 48.0, 4.5, 0.65)

interaction_sph_matrix = jnp.diag(interaction_sph)
interaction_def_matrix = jnp.diag(interaction_def)

interaction_sph_matrix = jnp.tile(interaction_sph_matrix, (batch, 1, 1))
interaction_def_matrix = jnp.tile(interaction_def_matrix, (batch, 1, 1))

total_interaction_arr = jnp.array([interaction_sph_matrix, interaction_def_matrix])
total_interaction_arr = np.array(total_interaction_arr, dtype=jnp.complex64)

In [86]:
S_batch = core_solver_class.solver(appended_block_arr=total_interaction_arr,fn_core=core_solver_class.fn_core, fn_interaction=core_solver_class.fn_interaction)

In [None]:
%%timeit
S_batch = core_solver_class.solver(appended_block_arr=total_interaction_arr,fn_core=core_solver_class.fn_core, fn_interaction=core_solver_class.fn_interaction)

1.66 ms ± 10.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [49]:
print(total_interaction_arr[0, None, :, None, :]*total_couplings_arr[0, :, None, :, None])

ValueError: operands could not be broadcast together with shapes (1,9,1,30,30) (9,1,4,1,4) 

In [54]:
total_interaction_arr[:, None, :, None, :].shape

(2, 1, 9, 1, 30, 30)

In [101]:
arr_couplings_seq = np.array([[1, 2, 3, 4],[5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.complex64)
arr_seq = np.tile(arr_couplings_seq, (batch, 1, 1))
arr_diag = np.array([[1, 0, 0, 0],[0, 1, 0, 0],[0, 0, 1, 0],[0, 0, 0, 1]], dtype=np.complex64)
arr_diag = np.tile(arr_diag, (batch, 1, 1))



interaction_sph_matrix = jnp.eye(nbasis)
interaction_def_matrix = jnp.eye(nbasis)

interaction_sph_matrix = jnp.tile(interaction_sph_matrix, (batch, 1, 1))
interaction_def_matrix = jnp.tile(interaction_def_matrix, (batch, 1, 1))

total_interaction_arr = jnp.array([interaction_sph_matrix, interaction_def_matrix])
total_interaction_arr = np.array(total_interaction_arr, dtype=jnp.complex64)

total_couplings_arr = jnp.array([arr_seq, arr_diag])

In [105]:
f_kernel = Core_Solver.make_coupling_kernel(nchannels, nbasis)
int_builder = Core_Solver.generate_coupling_interaction


total_interaction_arr = jnp.array(total_interaction_arr, dtype=jnp.complex64)
total_couplings_arr = jnp.array(total_couplings_arr, dtype=jnp.complex64)

interaction = int_builder(appended_couplings_jax = total_couplings_arr, 
                          appended_block_jax = total_interaction_arr, 
                          precomp_fill_fn = f_kernel)

In [108]:
print(jnp.array(interaction[0], dtype=jnp.float32))

[[ 2.  0.  2.  0.  3.  0.  4.  0.]
 [ 0.  2.  0.  2.  0.  3.  0.  4.]
 [ 5.  0.  7.  0.  7.  0.  8.  0.]
 [ 0.  5.  0.  7.  0.  7.  0.  8.]
 [ 9.  0. 10.  0. 12.  0. 12.  0.]
 [ 0.  9.  0. 10.  0. 12.  0. 12.]
 [13.  0. 14.  0. 15.  0. 17.  0.]
 [ 0. 13.  0. 14.  0. 15.  0. 17.]]


  out_array: Array = lax_internal._convert_element_type(


In [153]:
import jax.numpy as jnp
import numpy as np

# Parameters
batch_size = 10
matrix_size = 10
nbasis = 5
nchannels = 2

# Random batch of invertible matrices (with added identity to ensure invertibility)
key = jnp.array(np.random.default_rng(0).normal(size=(batch_size, matrix_size, matrix_size)).astype(np.complex64))
A_batch = key + jnp.eye(matrix_size, dtype=jnp.complex64)[None, :, :] * 5.0

# JAX batched inverse
C_jax = jnp.linalg.inv(A_batch)

# NumPy loop inverse
A_np = np.array(A_batch, dtype=np.complex64)
C_np = np.stack([np.linalg.inv(A_np[i]) for i in range(batch_size)], axis=0)

# Comparison
print("Allclose:", np.allclose(C_jax, C_np, atol=1e-7))

Allclose: True


In [154]:
b = jnp.array(np.random.default_rng(0).normal(size=(nbasis)).astype(np.complex64))
b_np = np.array(b, dtype=np.complex64)

In [155]:
C_jax = C_jax.reshape(batch_size, nchannels, nbasis, nchannels, nbasis)
R_ij_jax = jnp.einsum('m,bimjn,n->bij', b, C_jax, b)

In [156]:
R_ij_np = np.zeros((batch_size, nchannels, nchannels), dtype=np.complex64)
for b in range(batch_size):
    for j in range(nchannels):
        for m in range(nchannels):
            C_block = C_np[b, j * nbasis : (j + 1) * nbasis, m * nbasis : (m + 1) * nbasis]
            R_ij_np[b, j, m] = b_np.T @ C_block @ b_np
            


In [158]:
print("Allclose:", np.allclose(R_ij_jax, R_ij_np, atol=1e-8))

Allclose: True
