In [1]:
from legwork import utils
import numpy as np
import astropy.units as u
from astropy.visualization import quantity_support
quantity_support()
from astropy.constants import G
from astropy.constants import c
from scipy.integrate import quad
from joblib import Parallel, delayed
from numba import njit
import os
import math
from math import gamma as math_gamma
# Thread/buffer settings passed through environment variables.
# NOTE(review): these are assigned after numpy (and the other numeric
# libraries) have already been imported earlier in this cell; libraries that
# read such settings at import time may not see them. Confirm they take effect,
# or move the assignments above the imports.
# NOTE(review): confirm that NumPy actually reads NPY_NUM_BUFSIZE and
# NPY_NUM_THREADS; verify the names against the NumPy documentation.
os.environ["NPY_NUM_BUFSIZE"] = "8192"   # enlarge the NumPy buffer
os.environ["NPY_NUM_THREADS"] = "1"      # disable NumPy's internal multithreading
os.environ["MKL_NUM_THREADS"] = "1"      # disable MKL multithreading

In [2]:
# Masses of the two bodies as astropy Quantities in solar masses.
# m_2 is the mass used as the central mass in the orbital-velocity
# expressions of the functions defined in the next cell.
m_1 = 40 * u.Msun
m_2 =  4.3*10**6 * u.Msun
mu = m_1 * m_2 / (m_1 + m_2)  # reduced mass
M = m_1 + m_2  # total mass
m_c = utils.chirp_mass(m_1, m_2)  # legwork helper; presumably the chirp mass

# Initial orbit: eccentricity and the length passed to legwork as `a`
# (presumably the semi-major axis), plus the matching orbital frequency.
ecc_i = 0.9995
a_i = 0.01 * u.pc
f_orb_i = utils.get_f_orb_from_a(a =a_i, m_1=m_1, m_2=m_2)

# Source distance and observation time; not used by the cells visible here.
dist = 8 * u.kpc
t_obs = 1 * u.yr

# Other derived parameters from legwork.
# NOTE(review): presumably the Peters (1964) beta and c_0 constants; confirm
# against the legwork documentation.
beta = utils.beta(m_1=m_1, m_2=m_2)
c_0 = utils.c_0(a_i, ecc_i)

# Plain-float SI values of one solar mass and one parsec.
M_sun = (1 * u.Msun).si.value      
pc = (1 * u.pc).si.value

# Parameters consumed by rho(): density = rhos * (r / rs)**(-gam).
gam = 3.5          # power-law exponent (also passed to f_df)
rs = 0.01 * pc                 # reference radius, SI
rhos = 15246512 * M_sun / pc**3 # density at r == rs, SI

In [3]:

@njit(nogil=True, fastmath=True)
def f_df(v_dm, v_c, gam):
    """Evaluate the velocity distribution used inside the friction integrals.

    Parameters
    ----------
    v_dm : float
        Speed at which the distribution is evaluated.
    v_c : float
        Velocity scale; the result is zero once ``v_dm**2 >= 2 * v_c**2``.
    gam : float
        Exponent parameter. It enters through the gamma-function ratio,
        the ``2**gam * v_c**(2 * gam)`` normalisation and the power
        ``gam - 1.5``.

    Returns
    -------
    float
        ``Gamma(gam + 1) / Gamma(gam - 0.5)
        / (2**gam * pi**1.5 * v_c**(2 * gam))
        * (2 * v_c**2 - v_dm**2)**(gam - 1.5)``,
        or ``0.0`` when ``2 * v_c**2 - v_dm**2`` is not positive.
    """
    gamma_ratio = math_gamma(gam + 1) / math_gamma(gam - 0.5)
    normalisation = (2**gam) * (np.pi**1.5) * v_c**(2 * gam)

    # Bracket that is raised to the power (gam - 1.5); non-positive values
    # are mapped to zero instead of being exponentiated.
    bracket = 2 * v_c**2 - v_dm**2
    if bracket <= 0:
        return 0.0

    return (gamma_ratio / normalisation) * bracket**(gam - 1.5)
@njit(nogil=True, fastmath=True)
def rho(gam, r, rs, rhos):
    """Return the power-law density ``rhos * (r / rs)**(-gam)``.

    ``rhos`` is the value at ``r == rs`` and ``gam`` is the (positive-slope)
    exponent of the fall-off with ``r``.
    """
    radius_ratio = r / rs
    return rhos * radius_ratio ** (-gam)
def epsilon(a, e, f, gam):
    """Evaluate the friction term at one point of the orbit.

    Reads the module-level ``G``, ``m_1``, ``m_2``, ``rs``, ``rhos`` and
    ``pc`` and calls ``f_df`` and ``rho``.

    Parameters
    ----------
    a : float
        Orbit size in SI metres (callers pass ``ad.si.value``).
    e : float
        Eccentricity.
    f : float
        Orbital angle in radians (integrated over ``[0, 2*pi]`` by the
        callers; presumably the true anomaly).
    gam : float
        Exponent forwarded to ``f_df`` and ``rho``.

    Returns
    -------
    float
        ``-4*pi*G**2 * rho * m_1 * (log_term*alpha + beta + delta)``, where
        ``alpha``, ``beta`` and ``delta`` are the three ``quad`` integrals
        below; ``0.0`` if the orbital speed is not below ``sqrt(2) * v_c``.
    """
    grav_param = G.si.value * m_2.si.value

    cos_f = np.cos(f)
    cos_u = (cos_f + e) / (1 + e * cos_f)
    r = a * (1 - e * cos_u)

    # Orbital speed, local circular speed and the sqrt(2)*v_c cut-off.
    v = np.sqrt(grav_param * (1 + e * cos_u) / r)
    v_c = np.sqrt(grav_param / r)
    v_esc = np.sqrt(2) * v_c
    if v >= v_esc:
        return 0.0

    def slower_weight(x):
        # Integrand for speeds below v.
        return 4 * np.pi * f_df(x, v_c, gam) * x**2

    def faster_log_weight(x):
        # Integrand for speeds above v; the log diverges at x == v, which is
        # the lower integration limit.
        return 4 * np.pi * f_df(x, v_c, gam) * x**2 * np.log((x + v)/(x - v))

    def faster_linear_weight(x):
        return -8 * np.pi * v * f_df(x, v_c, gam) * x

    tolerances = dict(epsabs=1e-5, epsrel=1e-4)
    alpha, _ = quad(slower_weight, 0, v, **tolerances)
    beta, _ = quad(faster_log_weight, v, v_esc, **tolerances)
    delta, _ = quad(faster_linear_weight, v, v_esc, **tolerances)

    rho_val = rho(gam, r, rs, rhos)
    log_term = np.log(pc * v_c**2 / grav_param)
    bracket = log_term * alpha + beta + delta
    return -4 * np.pi * G.si.value**2 * rho_val * m_1.si.value * bracket
def da_dt_integrand(f, a_elem, e_elem, gam):
    """Integrand over the orbital angle ``f`` used by ``da_dt_compute_element``.

    Returns ``epsilon(a_elem, e_elem, f, gam)`` divided by
    ``(1 + e*cos f)**2 * sqrt(1 + e**2 + 2*e*cos f)``.
    """
    friction = epsilon(a_elem, e_elem, f, gam)
    e_cos_f = e_elem * np.cos(f)
    weight = (1 + e_cos_f)**2 * np.sqrt(1 + e_elem**2 + 2*e_cos_f)
    return friction / weight
def da_dt_compute_element(a_elem, e_elem):
    """Compute the ``a``-rate term for a single ``(a, e)`` pair.

    Integrates ``da_dt_integrand`` over ``f`` in ``[0, 2*pi]`` (using the
    module-level ``gam``) and scales the result by
    ``(1 - e**2)**2 / (pi * n**3 * a**2)`` with ``n = sqrt(G*m_2 / a**3)``.

    ``a_elem`` is a plain float in SI metres; ``e_elem`` is the eccentricity.
    """
    mean_motion = math.sqrt(G.si.value * m_2.si.value / a_elem**3)
    scale = (1 - e_elem**2)**2 / (np.pi * mean_motion**3 * a_elem**2)

    orbit_integral, _abserr = quad(
        da_dt_integrand,
        0, 2*np.pi,
        args=(a_elem, e_elem, gam),
        epsabs=1e-5,
        epsrel=1e-4,
        limit=100,
    )
    return scale * orbit_integral

def de_dt_integrand(f, a_elem, e_elem, gam):
    """Integrand over the orbital angle ``f`` used by ``de_dt_compute_element``.

    Returns ``(e + cos f) / ((1 + e**2 + 2*e*cos f)**1.5 * (1 + e*cos f)**2)``
    multiplied by ``epsilon(a_elem, e_elem, f, gam)``.
    """
    friction = epsilon(a_elem, e_elem, f, gam)
    cos_f = np.cos(f)
    e_cos_f = e_elem * cos_f
    weight = (1 + e_elem**2 + 2*e_cos_f)**1.5 * (1 + e_cos_f)**2
    return (e_elem + cos_f) / weight * friction
def de_dt_compute_element(a_elem, e_elem):
    """Compute the ``e``-rate term for a single ``(a, e)`` pair.

    Integrates ``de_dt_integrand`` over ``f`` in ``[0, 2*pi]`` (using the
    module-level ``gam``) and scales the result by
    ``(1 - e**2)**3 / (pi * n**3 * a**3)`` with ``n = sqrt(G*m_2 / a**3)``.

    Note the looser tolerances and lower subdivision limit compared with
    ``da_dt_compute_element``.
    """
    mean_motion = math.sqrt(G.si.value * m_2.si.value / a_elem**3)
    scale = (1 - e_elem**2)**3 / (np.pi * mean_motion**3 * a_elem**3)

    orbit_integral, _abserr = quad(
        de_dt_integrand,
        0, 2*np.pi,
        args=(a_elem, e_elem, gam),
        epsabs=1e-4,
        epsrel=1e-3,
        limit=50,
    )
    return scale * orbit_integral

def dynamic_batch_size(n_elements):
    """Choose a joblib batch size from the workload size and the CPU count.

    Parameters
    ----------
    n_elements : int
        Number of tasks that will be dispatched.

    Returns
    -------
    int
        ``n_elements // (2 * n_cores)``, but never less than 4.
    """
    # os.cpu_count() returns None when the number of CPUs cannot be
    # determined; treat that as a single core instead of failing on None * 2.
    n_cores = os.cpu_count() or 1
    return max(4, n_elements // (n_cores * 2))
def parallel_wrapper(func, a, e):
    """Apply ``func(a_i, e_i)`` element-wise, using a joblib thread pool.

    Parameters
    ----------
    func : callable
        Function of two scalars returning a scalar.
    a, e : array-like or scalar
        Inputs converted to ``float64`` arrays; they must have equal shapes.

    Returns
    -------
    Whatever ``func`` returns for scalar (0-d) inputs, otherwise a NumPy
    array of results with the same shape as ``a``.

    Raises
    ------
    ValueError
        If the two inputs do not have the same shape.
    """
    a_arr = np.asarray(a, dtype=np.float64)
    e_arr = np.asarray(e, dtype=np.float64)

    if a_arr.shape != e_arr.shape:
        raise ValueError("Input arrays must have the same shape")

    # Shapes are equal at this point, so a 0-d `a` implies a 0-d `e`:
    # evaluate directly without starting a pool.
    if a_arr.ndim == 0:
        return func(a_arr.item(), e_arr.item())

    pool = Parallel(
        n_jobs=-1,
        backend="threading",
        batch_size=dynamic_batch_size(a_arr.size),
    )
    element_pairs = zip(a_arr.ravel(), e_arr.ravel())
    flat_results = pool(delayed(func)(ai, ei) for ai, ei in element_pairs)

    return np.array(flat_results).reshape(a_arr.shape)

def da_dt_df(ad, e):
    """Evaluate ``da_dt_compute_element`` over matching arrays of ``a`` and ``e``.

    ``ad`` must expose ``.si.value`` (e.g. an astropy Quantity); its SI
    numeric value is passed on together with ``e`` to ``parallel_wrapper``.
    """
    return parallel_wrapper(da_dt_compute_element, ad.si.value, e)
def de_dt_df(ad, e):
    """Evaluate ``de_dt_compute_element`` over matching arrays of ``a`` and ``e``.

    ``ad`` must expose ``.si.value`` (e.g. an astropy Quantity); its SI
    numeric value is passed on together with ``e`` to ``parallel_wrapper``.
    """
    return parallel_wrapper(de_dt_compute_element, ad.si.value, e)