# Integration of da/dt

In [None]:
import numpy as np
import holodeck as holo
from holodeck.constants import YR, GYR, MSOL, PC
from holodeck import utils, cosmo

import tqdm

# Separation passed as `sepa_init` to every `integrate` call: 1e4 times holodeck's `PC`
# constant (presumably 10 kpc expressed in cgs -- confirm against holodeck.constants).
SEPA_INIT = 1e4 * PC

# Fixed arguments forwarded to `Fixed_Time_2PL._dadt_dedt` after the normalization.
# NOTE(review): presumably the characteristic radius and the inner/outer power-law
# indices of the two-power-law hardening model -- confirm against holodeck.hardening.
RCHAR = 10.0*PC
GAMMA_INNER = -1.0
GAMMA_OUTER = +1.5


def integrate(hard_func, mtot, mrat, sepa_init, nsteps):
    """Accumulate the time to shrink from `sepa_init` to `risco` with five quadrature rules.

    The separation range between `sepa_init` and `risco` (three times
    `utils.schwarzschild_radius(mtot)`) is split into `nsteps` intervals of equal
    width in log10(separation).  On every interval the elapsed time
    ``dt = da / (da/dt)`` is estimated five different ways and summed.

    Parameters
    ----------
    hard_func : callable
        Called as ``hard_func(mtot, mrat, sepa)``; returns da/dt at `sepa`.
        NOTE(review): the sign conventions below assume da/dt < 0 (shrinking
        orbit) so that every accumulated time is positive -- confirm.
    mtot : float
        Total mass, forwarded to `hard_func` and `utils.schwarzschild_radius`.
    mrat : float
        Mass ratio, forwarded to `hard_func`.
    sepa_init : float
        Starting (largest) separation.
    nsteps : int
        Number of log-spaced intervals; must be non-zero.

    Returns
    -------
    times : list of float
        Accumulated times, in the order given by `names`.
    names : list of str
        ``['euler_left', 'euler_right', 'euler_ave', 'trapz_loglog', 'trapz']``.

    """
    risco = 3.0 * utils.schwarzschild_radius(mtot)
    sepa_log10 = np.log10(sepa_init)
    dx_log10 = (sepa_log10 - np.log10(risco)) / nsteps
    sepa_left = 10.0 ** sepa_log10
    dadt_left = hard_func(mtot, mrat, sepa_left)
    time_euler_left = 0.0
    time_euler_right = 0.0
    time_euler_ave = 0.0
    time_trapz_loglog = 0.0
    time_trapz = 0.0
    for _ in range(nsteps):
        # Step inward: "left" is the larger separation, "right" the smaller one.
        sepa_log10 -= dx_log10
        sepa_right = 10.0 ** sepa_log10
        dx = (sepa_right - sepa_left)

        # Euler using the rate at the start of the interval.
        time_euler_left += dx / dadt_left

        # Euler using the rate at the end of the interval.
        dadt_right = hard_func(mtot, mrat, sepa_right)
        time_euler_right += dx / dadt_right

        # Euler using the arithmetic mean of the two endpoint rates.
        dadt = 0.5 * (dadt_left + dadt_right)
        time_euler_ave += dx / dadt

        # Integrate -1/(da/dt) over increasing separation.  The abscissae are
        # ascending (`sepa_right` < `sepa_left`), so the ordinates must be listed
        # in the same order: the value at `sepa_right` first.  The linear
        # trapezoid is symmetric in its two ordinates, but the log-log rule is
        # not, so a swapped pairing would bias it.
        sepa_pair = [sepa_right, sepa_left]
        integrand_pair = [-1.0/dadt_right, -1.0/dadt_left]

        dt1 = utils.trapz_loglog(integrand_pair, sepa_pair)[0]
        time_trapz_loglog += dt1

        dt2 = utils.trapz(integrand_pair, sepa_pair)[0]
        time_trapz += dt2

        # The end of this interval is the start of the next; reuse the rate.
        sepa_left = sepa_right
        dadt_left = dadt_right

    names = ['euler_left', 'euler_right', 'euler_ave', 'trapz_loglog', 'trapz']
    times = [time_euler_left, time_euler_right, time_euler_ave, time_trapz_loglog, time_trapz]

    return times, names


def run_integration_test(mtot, mrat, norm):
    """Run `integrate` over a fixed ladder of step counts for one parameter set.

    Parameters
    ----------
    mtot : float
        Total mass handed to `integrate`.
    mrat : float
        Mass ratio handed to `integrate`.
    norm : float
        Passed as the fourth argument of `Fixed_Time_2PL._dadt_dedt`, ahead of
        `RCHAR`, `GAMMA_INNER` and `GAMMA_OUTER`.

    Returns
    -------
    steps_list : list of int
        The step counts that were tried.
    all_times : ndarray
        Results transposed so that rows follow `names` and columns follow `steps_list`.
    names : list of str
        Method labels as returned by the last `integrate` call.

    """
    step_counts = [10, 20, 30, 50, 100, 200, 500, 1000, 2000, 5000, 10000]

    def hard_func(_mtot, _mrat, _sepa):
        # Only the first element of the returned pair (da/dt) is needed here.
        rates = holo.hardening.Fixed_Time_2PL._dadt_dedt(
            _mtot, _mrat, _sepa, norm, RCHAR, GAMMA_INNER, GAMMA_OUTER
        )
        return rates[0]

    per_run = []
    for nsteps in tqdm.tqdm(step_counts):
        run_times, names = integrate(hard_func, mtot, mrat, SEPA_INIT, nsteps)
        per_run.append(run_times)

    # One row per run so far; flip to one row per integration method.
    by_method = np.transpose(np.asarray(per_run))
    return step_counts, by_method, names


In [None]:
def plot_test(mtot, mrat, norm, steps, times, names):
    """Plot each method's fractional deviation from a reference time versus step count.

    The reference is the mean, over all rows of `times`, of the last column.
    The signed deviation is drawn on a linear axis and its absolute value on a
    twinned log axis.  The reference, divided by `GYR`, is also printed.

    Parameters
    ----------
    mtot, mrat, norm : float
        Only used to build the figure title.
    steps : sequence of int
        Step counts, used as the x values.
    times : ndarray
        2D array indexed as ``[method, step count]``.
    names : sequence of str
        Legend label for each row of `times`.

    Returns
    -------
    fig : the figure created by `plt.subplots()`.

    """
    # NOTE(review): `plt` is not imported in the visible cells -- this needs
    # `matplotlib.pyplot` bound to `plt` somewhere before it is called.
    truth = times[:, -1].mean()
    print(f"{truth/GYR=:.4e}")

    fig, ax = plt.subplots()
    ax.set(
        xscale='log', xlabel='steps', yscale='linear', ylabel='time',
        title=f"M={np.log10(mtot/MSOL):.4f}, q={mrat:.4f}, A={norm:.4e}",
    )
    ax.grid(True, alpha=0.25)
    tw = ax.twinx()
    tw.set(yscale='log')

    # Cycle through three (linestyle, linewidth) pairs so overlapping curves stay visible.
    line_styles = (('-', 1.5), ('--', 2.0), (':', 3.0))
    for idx, series in enumerate(times):
        ls, lw = line_styles[idx % 3]
        frac_err = (series - truth)/truth
        ax.plot(steps, frac_err, label=names[idx], alpha=0.5, ls=ls, lw=lw)
        tw.plot(steps, np.fabs(frac_err), alpha=0.25, ls=ls, lw=lw)

    ax.legend()
    return fig

In [None]:
# Case 1: MTOT = 1e12 * MSOL, MRAT = 0.3, NORM = 1e6.
# Run the step-count sweep for these parameters, plot the result and display it.
MTOT = 1e12 * MSOL
MRAT = 0.3
NORM = 1e6

steps, times, names = run_integration_test(MTOT, MRAT, NORM)
plot_test(MTOT, MRAT, NORM, steps, times, names)
plt.show()

In [None]:
# Case 2: same MTOT and MRAT as the 1e12 * MSOL / NORM = 1e6 case, but NORM = 1e8.
# Run the step-count sweep for these parameters, plot the result and display it.
MTOT = 1e12 * MSOL
MRAT = 0.3
NORM = 1e8

steps, times, names = run_integration_test(MTOT, MRAT, NORM)
plot_test(MTOT, MRAT, NORM, steps, times, names)
plt.show()

In [None]:
# Case 3: MTOT = 1e6 * MSOL, MRAT = 0.3, NORM = 1e8.
# Run the step-count sweep for these parameters, plot the result and display it.
MTOT = 1e6 * MSOL
MRAT = 0.3
NORM = 1e8

steps, times, names = run_integration_test(MTOT, MRAT, NORM)
plot_test(MTOT, MRAT, NORM, steps, times, names)
plt.show()

In [None]:
# Case 4: MTOT = 1e6 * MSOL, MRAT = 0.3, NORM = 1e6.
# Run the step-count sweep for these parameters, plot the result and display it.
MTOT = 1e6 * MSOL
MRAT = 0.3
NORM = 1e6

steps, times, names = run_integration_test(MTOT, MRAT, NORM)
plot_test(MTOT, MRAT, NORM, steps, times, names)
plt.show()