In [1]:
# python imports
import jax.numpy as jnp

# custom imports
from vnn_srg.potentials import Potential
from vnn_srg.srg import compute_srg_transformation
from vnn_srg.integration import unattach_weights_from_matrix

In [2]:
channel_arg = '1S0'
coupled_channel = False
kvnn = 6
kmax = 15
kmid = 3
ntot = 120
generator = 'T'
lamb = 2

potential = Potential(kvnn, channel_arg, kmax, kmid, ntot)
k_array, k_weights = potential.load_mesh()

In [3]:
H_initial = potential.load_hamiltonian(method = 'initial')
H_final = potential.load_hamiltonian(lamb = lamb, generator = generator, method = 'srg')

In [4]:
U_matrix_weights = compute_srg_transformation(H_initial, H_final)

# Check if Unitary
print(jnp.diagonal(U_matrix_weights @ U_matrix_weights.T))

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


In [5]:
U_matrix = unattach_weights_from_matrix(k_array, k_weights, U_matrix_weights, coupled_channel)

# cut off momentum k > Lambda, k' < Lambda
Lambda = 2.5 #[fm^-1]

indicies = jnp.where(k_array > Lambda)
cutoff_index = indicies[0][0]

# Setting all values where k > Lambda to 0
U_matrix[cutoff_index:,:] = 0

# Setting all values where q < Lambda to 0
U_matrix[:,:cutoff_index] = 0 

In [6]:
U, S, Vh = jnp.linalg.svd(U_matrix)

In [7]:
print(f'My numbers: d1 = {S[0]}')
print(f'Paper Numbers: d1 = 0.763')
print('')
print(f'My numbers: d2 = {S[1]}')
print(f'Paper Numbers: d2 = 0.033')

My numbers: d1 = 1.6197949647903442
Paper Numbers: d1 = 0.763

My numbers: d2 = 0.09073500335216522
Paper Numbers: d2 = 0.033
