In [None]:
import numpy as np
import matplotlib.pyplot as plt

from scipy.integrate import quad
from scipy.special import gamma
from scipy.special import hyp2f1
from scipy.optimize import fsolve
from scipy.optimize import brentq
from scipy.integrate import solve_ivp

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from scipy.optimize import brentq

from astropy.constants import G,c
import astropy.units as u

import pandas as pd

from math import gamma as math_gamma

from joblib import Parallel, delayed, Memory
from numba import njit
import os
import math
# Thread / buffer settings for the numerical libraries.
# NOTE(review): MKL (like most BLAS backends) reads its thread-count variable
# when the library is loaded, i.e. at `import numpy` above, so setting it here
# probably has no effect on the running kernel -- move these lines above the
# numpy import if they matter. "NPY_NUM_BUFSIZE" / "NPY_NUM_THREADS" do not
# appear to be standard NumPy variables -- TODO confirm they are read anywhere.
os.environ.update({
    "NPY_NUM_BUFSIZE": "8192",   # intended: enlarge the NumPy buffer
    "NPY_NUM_THREADS": "1",      # intended: disable NumPy-internal multithreading
    "MKL_NUM_THREADS": "1",      # disable MKL multithreading
})

## Initial conditions: physical constants and system parameters

In [None]:
# Physical constants in SI units.
# Re-bind from the astropy module instead of the imported `G`/`c` names so this
# cell is idempotent: `G = G.si.value` replaced the astropy Constant with a
# float, and re-running the cell then raised AttributeError.
import astropy.constants as _astropy_constants

G = _astropy_constants.G.si.value   # gravitational constant [m^3 kg^-1 s^-2]
c = _astropy_constants.c.si.value   # speed of light [m/s]
M_sun = (1 * u.Msun).si.value       # solar mass [kg]
pc = (1 * u.pc).si.value            # parsec [m]
n_cores = os.cpu_count()

# Masses [kg]: m_1 is the orbiting body, m_2 the central mass (4.3e6 Msun).
m_1 = 40 * M_sun
m_2 = 4.3e6 * M_sun
m_bh = 40 * M_sun   # mass used in the relaxation-time estimate (equal to m_1 here)

# DM power-law density profile rho(r) = rhos * (r / rs)**(-gam).
gam = 3.5                            # density slope
rs = 0.01 * pc                       # reference radius [m]
rhos = 34219232 * M_sun / pc**3      # density at rs [kg/m^3]

## Time derivatives of the semi-major axis $a$ and eccentricity $e$

### Dark-matter dynamical friction (DM)

In [None]:
@njit(nogil=True, fastmath=True)
def f_df(v_dm, v_c, gam):
    """Isotropic DM speed distribution around the central mass (Numba-compatible).

    f(v) = Gamma(gam + 1) / Gamma(gam - 1/2)
           / (2**gam * pi**1.5 * v_c**(2*gam)) * (2*v_c**2 - v**2)**(gam - 1.5)

    and 0 for speeds at or above sqrt(2) * v_c. With this prefactor,
    4*pi*v**2*f(v) integrates to 1 over [0, sqrt(2)*v_c].

    Parameters
    ----------
    v_dm : speed of the DM particle.
    v_c : local circular speed; sqrt(2)*v_c is the cut-off speed.
    gam : density slope (Gamma(gam - 1/2) is assumed finite).
    """
    # Gamma-function normalisation, evaluated before the cut-off test exactly
    # as in the original ordering.
    gamma_ratio = math_gamma(gam + 1) / math_gamma(gam - 0.5)
    normaliser = (2**gam) * (np.pi**1.5) * v_c**(2 * gam)

    # 2*v_c**2 - v**2 is positive only below the cut-off speed sqrt(2)*v_c.
    margin = 2 * v_c**2 - v_dm**2
    if margin <= 0:
        return 0.0

    return (gamma_ratio / normaliser) * margin**(gam - 1.5)
@njit(nogil=True, fastmath=True)
def rho(gam, r, rs, rhos):
    """Power-law density profile rho(r) = rhos * (r / rs)**(-gam) (Numba-compiled).

    rhos is the density at the reference radius rs.
    """
    scaled_radius = r / rs
    return rhos * scaled_radius**-gam
# ================== epsilon (dynamical-friction) module ==================
def epsilon(a, e, f, gam):
    """Dynamical-friction term at true anomaly ``f`` of an (a, e) orbit.

    Combines three speed-space integrals of the DM distribution ``f_df``,
    apparently Chandrasekhar's formula extended to particles faster than
    the orbiter:

        -4*pi*G**2 * rho(r) * m_1 * (log_term * alpha + beta + delta)

    alpha = int_0^v     4 pi f(x) x^2 dx
    beta  = int_v^vesc  4 pi f(x) x^2 ln((x + v) / (x - v)) dx
    delta = int_v^vesc  -8 pi v f(x) x dx
    log_term = ln(pc * v_c**2 / (G * m_2))   (presumably the Coulomb logarithm)

    Returns 0.0 when the orbital speed reaches the local escape speed.
    Reads module-level G, m_1, m_2, rs, rhos and pc.
    """
    cos_f = np.cos(f)
    # Eccentric anomaly from the true anomaly: cos E = (cos f + e) / (1 + e cos f).
    cos_ecc = (cos_f + e) / (1 + e * cos_f)
    separation = a * (1 - e * cos_ecc)

    # Vis-viva speed, v^2 = G m_2 (1 + e cos E) / (a (1 - e cos E)).
    speed = np.sqrt(G * m_2 * (1 + e * cos_ecc) / (a * (1 - e * cos_ecc)))
    v_circ = np.sqrt(G * m_2 / separation)
    v_escape = np.sqrt(2) * v_circ
    if speed >= v_escape:
        return 0.0

    tolerances = {"epsabs": 1e-5, "epsrel": 1e-4}

    def slower_weight(x):
        # Weight of DM particles slower than the orbiter.
        return 4 * np.pi * f_df(x, v_circ, gam) * x**2

    def faster_log_weight(x):
        # Logarithmic contribution of DM particles faster than the orbiter.
        return 4 * np.pi * f_df(x, v_circ, gam) * x**2 * np.log((x + speed)/(x - speed))

    def faster_linear_weight(x):
        # Linear correction for DM particles faster than the orbiter.
        return -8 * np.pi * speed * f_df(x, v_circ, gam) * x

    alpha = quad(slower_weight, 0, speed, **tolerances)[0]
    beta = quad(faster_log_weight, speed, v_escape, **tolerances)[0]
    delta = quad(faster_linear_weight, speed, v_escape, **tolerances)[0]

    local_density = rho(gam, separation, rs, rhos)
    # Since v_c**2 = G*m_2/r, the log argument reduces analytically to pc / r.
    log_term = np.log(pc * v_circ**2 / (G * m_2))
    return -4 * np.pi * G**2 * local_density * m_1 * (log_term * alpha + beta + delta)
# ================== orbit-evolution module ==================
def da_dt_integrand(f, a_elem, e_elem, gam):
    """Integrand (in true anomaly f) of the orbit-averaged DM da/dt.

    Returns epsilon(a, e, f, gam) / ((1 + e cos f)^2 * sqrt(1 + e^2 + 2 e cos f)).
    """
    e_cos_f = e_elem * np.cos(f)
    weight = (1 + e_cos_f)**2 * np.sqrt(1 + e_elem**2 + 2*e_cos_f)
    return epsilon(a_elem, e_elem, f, gam) / weight
def da_dt_compute_element(a_elem, e_elem):
    """Orbit-averaged DM contribution to da/dt for a single (a, e) pair.

    Integrates ``da_dt_integrand`` over f in [0, 2*pi] using the module-level
    density slope ``gam``. The result is multiplied by
    (1 - e^2)^2 / (pi n^3 a^2), where n = sqrt(G m_2 / a^3) is the mean motion.
    """
    mean_motion = math.sqrt(G * m_2 / a_elem**3)
    integral = quad(
        da_dt_integrand, 0, 2*math.pi,
        args=(a_elem, e_elem, gam),
        epsabs=1e-5, epsrel=1e-4, limit=100,
    )[0]
    scale = (1 - e_elem**2)**2 / (math.pi * mean_motion**3 * a_elem**2)
    return scale * integral

def de_dt_integrand(f, a_elem, e_elem, gam):
    """Integrand (in true anomaly f) of the orbit-averaged DM de/dt.

    Returns (e + cos f) / ((1 + e^2 + 2 e cos f)^1.5 (1 + e cos f)^2) * epsilon.
    Plain Python function (the JIT decorator was removed).
    """
    cos_f = np.cos(f)
    e_cos_f = e_elem * cos_f
    weight = (1 + e_elem**2 + 2*e_cos_f)**1.5 * (1 + e_cos_f)**2
    return (e_elem + cos_f) / weight * epsilon(a_elem, e_elem, f, gam)
def de_dt_compute_element(a_elem, e_elem):
    """Orbit-averaged DM contribution to de/dt for a single (a, e) pair.

    Integrates ``de_dt_integrand`` over f in [0, 2*pi] using the module-level
    ``gam``, with looser tolerances than the da/dt counterpart. The result
    is multiplied by (1 - e^2)^3 / (pi n^3 a^3), where n = sqrt(G m_2 / a^3).
    """
    mean_motion = math.sqrt(G * m_2 / a_elem**3)
    scale = (1 - e_elem**2)**3 / (math.pi * mean_motion**3 * a_elem**3)
    integral = quad(
        de_dt_integrand, 0, 2*math.pi,
        args=(a_elem, e_elem, gam),
        epsabs=1e-4, epsrel=1e-3, limit=50,
    )[0]
    return scale * integral

# ================== parallel computation framework ==================
def dynamic_batch_size(n_elements):
    """Choose a joblib batch size of about n_elements / (2 * cores), at least 4.

    ``os.cpu_count()`` returns None when the core count cannot be determined;
    in that case fall back to a single core instead of raising TypeError.
    """
    n_cores = os.cpu_count() or 1
    return max(4, n_elements // (n_cores * 2))
def parallel_wrapper(func, a, e):
    """Evaluate ``func(a_i, e_i)`` element-wise, in parallel for array input.

    Parameters
    ----------
    func : callable taking two floats and returning a scalar.
    a, e : scalars or array-likes. They are broadcast against each other, so a
        scalar can be combined with an array (equal shapes work as before).

    Returns
    -------
    The scalar result of ``func`` for 0-d input; otherwise an array with the
    broadcast shape (empty input returns an empty float array).

    Raises
    ------
    ValueError
        If ``a`` and ``e`` cannot be broadcast together.
    """
    try:
        a, e = np.broadcast_arrays(np.asarray(a, dtype=np.float64),
                                   np.asarray(e, dtype=np.float64))
    except ValueError as err:
        raise ValueError("Input arrays must have broadcast-compatible shapes") from err

    if a.ndim == 0:
        return func(a.item(), e.item())

    if a.size == 0:
        # Nothing to evaluate; avoid spinning up a worker pool.
        return np.empty(a.shape, dtype=np.float64)

    batch_size = dynamic_batch_size(a.size)

    # NOTE(review): the threading backend shares the GIL; the work here is
    # largely pure-Python quad callbacks, so the speed-up may be limited.
    results = Parallel(n_jobs=-1, backend="threading", batch_size=batch_size)(
        delayed(func)(ai, ei) for ai, ei in zip(a.ravel(), e.ravel())
    )

    return np.array(results).reshape(a.shape)
# ================== public interface ==================
def da_dt_df(a, e):
    """DM dynamical-friction contribution to da/dt for scalar or array (a, e)."""
    per_element = da_dt_compute_element
    return parallel_wrapper(per_element, a, e)
def de_dt_df(a, e):
    """DM dynamical-friction contribution to de/dt for scalar or array (a, e)."""
    per_element = de_dt_compute_element
    return parallel_wrapper(per_element, a, e)


### Gravitational-wave emission (GW)

In [None]:
def da_dt_gw(a, e):
    """Orbit-averaged GW decay rate of the semi-major axis (Peters 1964 form).

    da/dt = -(64/5) G^3 mu M^2 / (c^5 a^3) / (1 - e^2)^{7/2}
            * (1 + 73/24 e^2 + 37/96 e^4),
    with M = m_1 + m_2 and mu = m_1 m_2 / M taken from module level.
    Element-wise for array-like a and e.
    """
    total_mass = m_1 + m_2
    reduced_mass = m_1 * m_2 / total_mass

    amplitude = -(64 / 5) * (G**3 * reduced_mass * total_mass**2) / (c**5 * a**3)
    inverse_ecc_factor = 1 / (1 - e**2)**(7/2)
    polynomial = 1 + (73/24) * e**2 + (37/96) * e**4
    return amplitude * inverse_ecc_factor * polynomial

def de_dt_gw(a, e):
    """Orbit-averaged GW rate of change of the eccentricity (Peters 1964 form).

    de/dt = -(304/15) G^3 mu M^2 / (c^5 a^4) * e / (1 - e^2)^{5/2}
            * (1 + 121/304 e^2),
    with M = m_1 + m_2 and mu = m_1 m_2 / M taken from module level.
    Vanishes for e = 0. Element-wise for array-like a and e.
    """
    total_mass = m_1 + m_2
    reduced_mass = m_1 * m_2 / total_mass

    amplitude = -(304 / 15) * (G**3 * reduced_mass * total_mass**2) / (c**5 * a**4)
    ecc_factor = e / (1 - e**2)**(5/2)
    polynomial = 1 + (121/304) * e**2
    return amplitude * ecc_factor * polynomial

### Total rates (GW + DM)

In [None]:
def da_dt(a, e):
    """Total da/dt: gravitational-wave emission plus DM dynamical friction."""
    gw_part = da_dt_gw(a, e)
    friction_part = da_dt_df(a, e)
    return gw_part + friction_part


def de_dt(a, e):
    """Total de/dt: gravitational-wave emission plus DM dynamical friction."""
    gw_part = de_dt_gw(a, e)
    friction_part = de_dt_df(a, e)
    return gw_part + friction_part

## Characteristic timescales

In [None]:
def _eccentricity_timescale(e, de_dt_value):
    """Return (1 - e) / |de/dt| element-wise.

    A vanishing rate (e.g. e = 0, where the GW term is exactly zero) gives +inf
    (nan for 0/0). Plain Python floats no longer raise ZeroDivisionError.
    Works for scalars and numpy arrays (and pandas objects via ufunc dispatch).
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.true_divide(1 - e, np.abs(de_dt_value))


def tgw(a, e): #total
    """Timescale (1 - e) / |de/dt| using the total (GW + DM) de/dt."""
    return _eccentricity_timescale(e, de_dt(a, e))


def tgw0(a, e): #GW
    """Timescale (1 - e) / |de/dt| using the gravitational-wave de/dt only."""
    return _eccentricity_timescale(e, de_dt_gw(a, e))


def tgwdf(a, e): #DM
    """Timescale (1 - e) / |de/dt| using the DM dynamical-friction de/dt only."""
    return _eccentricity_timescale(e, de_dt_df(a, e))

def trlx(a, e): # relaxation time
    """Relaxation-driven timescale t0 * sqrt(a / r0) * (1 - e).

    t0 = 4.26 / 3^{3/2} * sqrt(r0^3 / (G m_2)) / (ln(m_2 / m_bh) * n0)
         * (m_2 / m_bh)^2,
    with hard-coded r0 = 1 pc and n0 = 2e4.
    """
    reference_radius = pc
    number_norm = 2e4
    mass_ratio = m_2 / m_bh

    prefactor = 4.26 / (3)**(3/2)
    dynamical = np.sqrt(reference_radius**3 * (G * m_2)**(-1)) / (np.log(mass_ratio) * number_norm)
    t0 = prefactor * dynamical * mass_ratio**2

    return t0 * (a / reference_radius)**(1/2) * (1 - e)
def eq1(a,e):
    """Residual tgw(a, e) - trlx(a, e); its root is where the two timescales match."""
    t_ecc = tgw(a, e)
    t_relax = trlx(a, e)
    return t_ecc - t_relax

def a_for_plu(e, w=0.26):
    """Semi-major axis of the plunge boundary at eccentricity e.

    Returns a = w * 8 G m_2 / (c^2 (1 - e)), i.e. the orbit whose pericentre
    a (1 - e) equals w * 8 G m_2 / c^2.

    Parameters
    ----------
    e : eccentricity (scalar or array, < 1).
    w : dimensionless pericentre scaling. The default 0.26 keeps the
        previously hard-coded value.
    """
    return w * 8 * G * m_2 / (c**2 * (1 - e))

## EMRI formation calculation

In [None]:
# 100 log-spaced values of 1 - e in [1e-6, 1] and the plunge-boundary
# semi-major axis for each eccentricity.
one_minus_e = np.logspace(start=-6, stop=0, num=100)
e0 = 1 - one_minus_e
a1 = a_for_plu(e0)

In [None]:

def find_root_with_scan(eq1, ei, bounds, num_points=100, epsilon=0.0):
    """Find a root of ``eq1(x, ei)`` in ``bounds`` by a coarse scan plus brentq.

    Parameters
    ----------
    eq1 : callable ``(x, ei) -> float``.
    ei : extra argument forwarded to ``eq1`` (here the eccentricity).
    bounds : ``(lower, upper)`` search interval in x.
    num_points : number of linearly spaced scan points.
    epsilon : absolute tolerance; a scan point with ``|eq1| < epsilon`` is
        returned directly. The previous default bound the module-level
        ``epsilon`` *function*, so omitting this argument raised TypeError.
        ``0.0`` disables the shortcut (exact zeros are still accepted).

    Returns
    -------
    The first root found when scanning upward from ``lower``.

    Raises
    ------
    ValueError
        If no near-zero value or sign change is found.
    """
    lower, upper = bounds
    # NOTE(review): linear spacing over a range spanning several decades makes
    # the first interval very wide; multiple roots near `lower` could be missed.
    x_values = np.linspace(lower, upper, num_points)
    f_values = [eq1(x, ei) for x in x_values]

    def _is_root(value):
        # Exact zeros always count, even when the tolerance is 0.
        return value == 0 or abs(value) < epsilon

    for i in range(len(f_values) - 1):
        if _is_root(f_values[i]):
            return x_values[i]
        if _is_root(f_values[i + 1]):
            return x_values[i + 1]
        if f_values[i] * f_values[i + 1] < 0:
            return brentq(eq1, x_values[i], x_values[i + 1], args=(ei,))

    raise ValueError(f"No root found within bounds for e={ei}")

In [None]:

def find_root_wrapper(ei):
    """Semi-major axis where tgw == trlx at eccentricity ``ei``.

    Scans a in [1e-6 pc, 1 pc] with a one-year (365-day) tolerance on eq1.
    Returns 0 when no root is found (0 serves as a "no solution" sentinel).
    """
    seconds_per_year = 365*24*60*60*1
    search_bounds = (1e-6 * pc, 1 * pc)
    try:
        return find_root_with_scan(eq1, ei, bounds=search_bounds, epsilon=seconds_per_year)
    except ValueError:
        return 0

# Critical semi-major axis for every eccentricity in e0 (0 marks "no root").
# NOTE(review): each task calls de_dt -> parallel_wrapper, which starts its own
# threaded Parallel(n_jobs=-1). Nested inside this pool (joblib's default
# process backend), that likely oversubscribes the CPUs.
a2 = Parallel(n_jobs=-1)(map(delayed(find_root_wrapper), e0))

print(a2)

In [None]:
# Convert to plain ndarrays, then keep only eccentricities with a genuine root
# (find_root_wrapper returns 0 when none is found, below this threshold).
one_minus_e, a1, a2 = np.array(one_minus_e), np.array(a1), np.array(a2)

mask = a2 > pc*10**(-6)
filtered_one_minus_e, filtered_a2 = one_minus_e[mask], a2[mask]

In [None]:

# Critical ("decay-dominated") semi-major axis, in pc, versus 1 - e.
emri_form_df = pd.DataFrame(
    data={
        '1-e': filtered_one_minus_e,
        'a_decay_dominate': filtered_a2 / pc,
    }
)
emri_form_df.to_csv('data/10_3.5_1200_emri_form_df0.26.csv', index=False)

In [None]:

# Interpolants versus e: f1 = plunge boundary a1, f2 = critical orbit a2,
# f3 = inverse e(a2).
# Only eccentricities with a genuine root enter the a2 interpolants. Entries
# where find_root_wrapper returned its 0 "no root" sentinel would otherwise drag
# the linear interpolant to zero and could create spurious intersections.
e_with_root = 1 - filtered_one_minus_e

f1 = interp1d(e0, a1, kind='linear', fill_value="extrapolate")
f2 = interp1d(e_with_root, filtered_a2, kind='linear', fill_value="extrapolate")
f3 = interp1d(filtered_a2, e_with_root, kind='linear', fill_value="extrapolate")


def diff(e):
    """Plunge-boundary minus critical-orbit semi-major axis at eccentricity e."""
    return f1(e) - f2(e)


# Search bracket in e, ordered so that e_min < e_max.
e_min, e_max = 0.999, 0.9999999
# NaN defaults let later cells detect "no intersection" instead of hitting a
# NameError when brentq fails.
intersection_e = intersection_a = np.nan
try:
    intersection_e = brentq(diff, e_min, e_max)
    intersection_a = f1(intersection_e)

    print(f"交点：e = {intersection_e}, a = {intersection_a/pc}pc")

except ValueError:
    print("没有找到交点，可能是因为曲线没有交汇点或需要更大的区间。")

In [None]:
# 80 x 80 grid: rows follow 1 - e in [1e-6, 1], columns follow a in
# [1e-6, 10] pc. Expensive: tgw evaluates the DM de/dt integral at every point.
A, ONE_MINUS_E = np.meshgrid(np.logspace(-6, 1, 80) * pc, np.logspace(-6, 0, 80))
E = 1 - ONE_MINUS_E
trlx_values = trlx(A, E)
tgw_values = tgw(A, E)

In [None]:
# Formation diagram in the (1 - e, a) plane with shaded regions:
#   red   -- trlx < tgw
#   black -- tgw longer than 1e10 yr (365-day years)
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(one_minus_e, a1 / pc, label="plunge orbit")
ax.plot(filtered_one_minus_e, filtered_a2 / pc, label="critical orbit")

condition1 = trlx_values < tgw_values
condition2 = 1e10 * 365 * 24 * 60 * 60 < tgw_values

for region, colour in ((condition1, 'red'), (condition2, 'black')):
    ax.contourf(ONE_MINUS_E, A / pc, region, levels=[0.5, 1], colors=[colour], alpha=0.3)

# Mark the curve intersection only if the intersection search succeeded;
# previously a failed search made this line raise NameError.
if "intersection_e" in globals() and np.isfinite(intersection_e):
    ax.scatter(1 - intersection_e, intersection_a / pc)

ax.set_xscale('log')
ax.set_yscale('log')
ax.set_ylim(1e-6, 1e-1)
ax.set_xlim(1e-6, 1e-1)
ax.set_xlabel("1-e (log scale)")
ax.set_ylabel("a (pc, log scale)")
ax.legend()
ax.grid(True, which="both", ls="--")
plt.show()

## Comparison of characteristic timescales

In [None]:
# Selected EMRI snapshots: semi-major axis (column in pc) and eccentricity,
# then the characteristic timescales along them.
# NOTE(review): tgw re-evaluates de_dt_df, which tgwdf has already computed,
# so the expensive DM integral is done twice here.
df40_35_1200 = pd.read_csv("data/selected_emri_moments40_3.5_1200.csv")
a40_35_1200, ecc40_35_1200 = df40_35_1200['a (pc)'], df40_35_1200['e']

a, e = a40_35_1200 * pc, ecc40_35_1200

t_gw, t_df = tgw0(a, e), tgwdf(a, e)
t_rlx, t_r = trlx(a, e), tgw(a, e)

In [None]:
# Publication-style figure defaults for the cells below.
plt.rcParams.update({"figure.figsize": [4.5, 3], "figure.dpi": 300})
plt.rc('font', size=10)

In [None]:
# Spot check at a = 0.01 pc, e = 0.9995: positive when trlx exceeds tgw
# (this is the negative of eq1 at that point).
trlx(0.01*pc,0.9995)-tgw(0.01*pc,0.9995)

In [None]:
# Characteristic timescales along the selected EMRI snapshots, in years.
SECONDS_PER_YEAR = 365 * 24 * 60 * 60
timescales = [
    (t_gw, 'gravitational time'),
    (t_df, 'DM dynamical friction time'),
    (t_rlx, 'relaxation time'),
    (t_r, 'circularization time'),
]
for t_values, label in timescales:
    plt.plot(a / pc, t_values / SECONDS_PER_YEAR, label=label)

plt.xscale('log')
plt.yscale('log')

plt.xlabel(r'$\rm{a(pc)}$', fontdict={'family' : 'Times New Roman'})
plt.ylabel(r'Characteristic Time(yr)', fontdict={'family' : 'Times New Roman'})

plt.legend(frameon=False,fontsize=9)
# Save relative to the working directory rather than a machine-specific
# absolute path (previously 'D:/pyfile/25DMEMRIs/output/...').
output_dir = 'output'
os.makedirs(output_dir, exist_ok=True)
plt.savefig(os.path.join(output_dir, 'CharacteristicTime.pdf'), bbox_inches='tight')
plt.show()
