# In [None]:
import numpy as np
import time
import multiprocessing as mpr
import h5py
import os
import configparser
import logging
from mpmath import mp, mpc, sqrt, exp, besselj, quad,log, atan,atanh


# Global mpmath working precision: every mpf/mpc operation performed after
# this assignment is carried out with 50 significant decimal digits.
mp.dps = 50  # set precision to 50 decimal places (arbitrary precision)
# NOTE(review): eps is not referenced by any function visible in this part of
# the file -- presumably it is consumed elsewhere; confirm it is still needed.
eps = 1e-3  # relative accuracy tolerance


def psi(x, rs):
    """Evaluate the potential that enters the phase of the integrand.

    The value is::

        rs/2 * (log(x/2)**2 - atanh(sqrt(1 - x**2))**2)    for 0 < x < 1
        rs/2 * (log(x/2)**2 + atan(sqrt(x**2 - 1))**2)     for x >= 1

    and ``x == 0`` is special-cased to 0 so that ``log(0)`` is never
    evaluated.

    Parameters
    ----------
    x : mpf or float
        Radial argument; the callers in this file pass ``sqrt(2 * x)``.
    rs : mpf or float
        Overall amplitude multiplying the bracketed expression.

    Returns
    -------
    mpf or int
        The potential value; the plain integer ``0`` when ``x == 0``.
    """
    if x == 0:
        return 0

    log_sq = log(x / 2) ** 2
    # The inverse tangent functions must be the mpmath ones: sqrt() returns an
    # mpf, which NumPy cannot evaluate at mpmath precision (and the np.atanh /
    # np.atan aliases only exist in NumPy >= 2.0).
    if x < 1:
        return rs / 2 * (log_sq - atanh(sqrt(1 - x**2)) ** 2)
    return rs / 2 * (log_sq + atan(sqrt(x**2 - 1)) ** 2)
    

def func(x, w, y, rs):
    """Return ``J0(w*y*sqrt(2x)) * exp(-i*w*psi(sqrt(2x), rs))``.

    Parameters
    ----------
    x : mpf or float
        Integration variable; only ``sqrt(2 * x)`` enters the Bessel
        argument and the potential.
    w, y : mpf or float
        Factors of the Bessel-function argument; ``w`` also scales the phase.
    rs : mpf or float
        Amplitude forwarded unchanged to ``psi``.

    Returns
    -------
    mpc
        Product of the order-0 Bessel function and the complex phase factor.
    """
    radius = sqrt(2 * x)
    phase = exp(-1j * w * psi(radius, rs))
    return besselj(0, w * y * radius) * phase

def func2(x, w, y, rs):
    """Return ``func(x, w, y, rs)`` multiplied by the oscillation ``exp(i*w*x)``.

    This is the integrand handed to ``quad`` in ``NFW``.
    """
    oscillation = exp(1j * w * x)
    envelope = func(x, w, y, rs)
    return envelope * oscillation

def dfunc(x, w, y, rs):
    """Return the first-derivative term of ``func`` used by the tail correction.

    The result is the difference of two pieces:

    * ``-(w*y/sqrt(2x)) * J1(w*y*sqrt(2x)) * exp(-i*w*psi(sqrt(2x), rs))``,
      i.e. the x-derivative of the Bessel factor times the phase;
    * ``(i*w/(2x)) * func(x, w, y, rs)``.

    NOTE(review): the second piece uses the fixed factor ``1/(2x)`` and never
    evaluates a derivative of ``psi`` -- confirm this is the intended
    (asymptotic?) form for the potential implemented in ``psi``.

    Parameters are the same as for ``func``; ``x`` must be nonzero because it
    appears as a divisor.
    """
    radius = sqrt(2 * x)
    phase = exp(-1j * w * psi(radius, rs))
    bessel_piece = (-w * y / radius) * besselj(1, w * y * radius) * phase
    phase_piece = (1j * w / (2 * x)) * func(x, w, y, rs)
    return bessel_piece - phase_piece

def ddfunc(x, w, y, rs):
    """Return the second-derivative term of ``func`` used by the tail correction.

    Three contributions are summed, in this order:

    1. ``(w*y)/(2x*sqrt(2x)) * (2 + i*w) * J1(w*y*sqrt(2x)) * phase`` with
       ``phase = exp(-i*w*psi(sqrt(2x), rs))``;
    2. ``-1/(2x) * (w**2 * y**2 - i*w/x) * func(x, w, y, rs)``;
    3. ``-i*w/(2x) * dfunc(x, w, y, rs)``.

    NOTE(review): as in ``dfunc``, no derivative of ``psi`` is evaluated
    explicitly -- confirm the closed form matches the potential in ``psi``.

    Parameters are the same as for ``func``; ``x`` must be nonzero because it
    appears as a divisor.
    """
    two_x = 2 * x
    radius = sqrt(two_x)
    phase = exp(-1j * w * psi(radius, rs))

    bessel_term = (
        (w * y) / (two_x * radius)
        * (2 + 1j * w)
        * besselj(1, w * y * radius)
        * phase
    )
    func_term = -1 / two_x * (w**2 * y**2 - 1j * w / x) * func(x, w, y, rs)
    deriv_term = -1j * w / two_x * dfunc(x, w, y, rs)
    return bessel_term + func_term + deriv_term



def NFW(w, y, rs):
    """Integrate ``func2`` over a finite range and add an asymptotic tail.

    The integral of ``func2(x, w, y, rs)`` is evaluated numerically on
    ``[a, b]`` with ``mpmath.quad``; the contribution beyond ``b`` is
    approximated by three boundary terms built from ``func``, ``dfunc`` and
    ``ddfunc`` at ``x = b`` (each divided by an increasing power of ``w``).
    The sum is finally multiplied by ``-i * w * exp(0.5 * i * w * y**2)``.

    NOTE(review): presumably this is the wave-optics amplification factor of
    an NFW lens -- confirm against the calling code.

    Parameters
    ----------
    w : mpf or float
        Frequency-like parameter. It is compared against 1 and 500 and used
        as a divisor, so it must be a nonzero real number.
    y : mpf or float
        Forwarded to the integrand and used in the final phase factor.
    rs : mpf or float
        Amplitude forwarded to ``psi`` through the integrand.

    Returns
    -------
    mpc
        ``-i * w * exp(0.5 * i * w * y**2) * (integral + tail)``.
    """
    a = 0.00001

    # The upper cutoff scales as 1/w. The ranges are half-open so that
    # w == 1 is handled by the middle rule; the previous `1 < w < 500` test
    # let it fall through to the large-w rule (b = 10000).
    if w < 1:
        b = 100 / w
    elif w < 500:
        b = 1000 / w
    else:
        b = 10000 / w

    # With error=True, quad returns (value, error_estimate); only the value
    # is used here.
    zz_val, _quad_error = quad(
        lambda x: func2(x, w, y, rs), [a, b], error=True, maxdegree=10
    )

    # Tail correction terms evaluated at the upper limit b.
    boundary_phase = exp(1j * w * b)
    tail = (-func(b, w, y, rs) / (1j * w) * boundary_phase
            - dfunc(b, w, y, rs) / (w**2) * boundary_phase
            + ddfunc(b, w, y, rs) / (1j * w**3) * boundary_phase)

    zz_val += tail

    return -1j * w * exp(0.5 * 1j * w * y**2) * zz_val