In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
from astropy import constants as const
from astropy import units as u

# Bare numeric values of the astropy constants (`.value` drops the unit object).
# NOTE(review): presumably SI units, as astropy's defaults are — confirm.
c, G = const.c.value, const.G.value                      # speed of light, gravitational constant
M_sun, M_earth = const.M_sun.value, const.M_earth.value  # solar and Earth masses

In [2]:
def rel_len(x_1, x_2):
    """Return the Euclidean length of the difference vector ``x_1 - x_2``."""
    separation = x_1 - x_2
    # |d| = sqrt(d . d)
    return np.sqrt(np.dot(separation, separation))

def findrCM(x_1, x_2, m_1, m_2):
    """Return the mass-weighted mean of the positions ``x_1`` and ``x_2``.

    Computes ``(m_1*x_1 + m_2*x_2)/(m_1 + m_2)``, i.e. the centre of mass of
    two bodies with masses ``m_1`` and ``m_2``.
    """
    total_mass = m_1 + m_2
    mass_weighted_sum = m_1*x_1 + m_2*x_2
    return mass_weighted_sum/total_mass

def γ(v, DL):
    """Return the Lorentz factor ``1/sqrt(1 - v.v/c**2)`` for velocity ``v``.

    When ``DL == True`` the speed of light is taken as 1 (dimensionless
    units); otherwise ``const.c.value`` from astropy is used.
    """
    light_speed = 1 if DL == True else const.c.value
    beta_squared = np.dot(v, v)/light_speed**2
    return 1/np.sqrt(1 - beta_squared)

In [3]:
def accelerationEIH(r_vec, v_vec, p_vec, m_1, m_2, DL):
    """Evaluate the acceleration vector ``a`` from position, velocity and momentum.

    The result is a linear combination of ``r_vec``, ``v_vec`` and ``p_vec``
    with coefficients built from ``m_1``, ``m_2``, ``G``, ``c`` and the norms /
    dot products of the input vectors.

    NOTE(review): judging by the name this is presumably a first
    post-Newtonian (Einstein-Infeld-Hoffmann) two-body acceleration; the
    formulas below have not been checked against a reference derivation.

    Parameters
    ----------
    r_vec, v_vec, p_vec : array-like vectors
        Combined through ``np.dot`` / ``np.linalg.norm``. Presumably the
        relative separation, relative velocity and momentum -- verify
        against the caller.
    m_1, m_2 : scalars
        The two masses entering the coefficients.
    DL : bool
        When ``DL == True`` the function uses ``G = c = 1``; otherwise the
        astropy values ``const.G.value`` and ``const.c.value``.

    Returns
    -------
    a
        The vector computed on the last assignment below.

    NOTE(review): there is no guard for ``r == 0``; every term divides by a
    power of ``r``.
    """
    # Pick the unit system: dimensionless (G = c = 1) or astropy constants.
    if DL == True:
        G, c = 1, 1
    else:
        c = const.c.value
        G = const.G.value

    r = np.linalg.norm(r_vec)
    # NOTE(review): `v` is computed but never used below.
    v = np.linalg.norm(v_vec)
    
    # Component of v_vec along r_vec, i.e. the rate of change of r.
    r_dot = np.dot(r_vec, v_vec)/r

    # Open question from the original author: should a Lorentz factor γ enter
    # here? No γ is applied anywhere in this function.
    p = np.linalg.norm(p_vec)

    # Rate of change of the momentum; both terms are directed along r_vec.
    p_vec_dot = G**2*m_1*m_2*(m_1+m_2)/(c**2*r**4)*r_vec - G/r**3*( 1 + (1/2 + 3/2*(m_1 + m_2)**2/(m_1*m_2)*p**2/c**2) + np.dot(p_vec, r_vec)**2/(c*r)**2 )*r_vec

    # Velocity expression kept for reference (not evaluated); `a` below looks
    # like its time derivative -- presumably, verify.
    # q_vec_dot = (1/m_1 + 1/m_2)*p_vec - 1/2*(1/m_1**3 + 1/m_2**3)*p**2*p_vec - G/r*( (1 + 3*(m_1 + m_2)**2/(m_1*m_2))/c**2*p_vec_dot + np.dot(p_vec, r_vec)/(c*r)**2*r_vec )

    # NOTE(review): inside the `-r_dot/r*( ... )` bracket the dot product
    # np.dot(p_vec, r_vec) is squared, whereas the matching term in the
    # commented-out q_vec_dot (and the later terms of this line) use it
    # unsquared -- confirm which is intended.
    a = (1/m_1 + 1/m_2)*p_vec_dot - 3/2*(1/m_1**3 + 1/m_2**3)*p**2*p_vec_dot - G/r*( -r_dot/r*((1 + 3*(m_1 + m_2)**2/(m_1*m_2))/c**2*p_vec + np.dot(p_vec, r_vec)**2/(c*r)**2*r_vec ) + (1 + 3*(m_1 + m_2)**2/(m_1*m_2))/c**2*p_vec_dot + ( (np.dot(p_vec_dot, r_vec) + np.dot(p_vec, v_vec))*r_vec + np.dot(p_vec, r_vec)*v_vec  )/(c*r)**2 - 2*np.dot(p_vec, r_vec)/(c**2*r**3)*r_vec*r_dot )

    return a

def boost(v_vec, a, dt):
    """Advance the velocity by ``a*dt`` and return it.

    The update uses ``+=``, so a NumPy array passed as ``v_vec`` is modified
    in place and the same object is returned.
    """
    velocity_change = a*dt
    v_vec += velocity_change
    return v_vec

def move(r_vec, v_vec, dt):
    """Advance the position by ``v_vec*dt`` and return it.

    The update uses ``+=``, so a NumPy array passed as ``r_vec`` is modified
    in place and the same object is returned.
    """
    displacement = v_vec*dt
    r_vec += displacement
    return r_vec

In [4]:
def findMaxPos(pos):
    """Return ``(xmax, xmin, ymax, ymin)`` for a track.

    ``pos[0]`` holds the x samples and ``pos[1]`` the y samples.
    """
    xs = pos[0]
    ys = pos[1]
    return max(xs), min(xs), max(ys), min(ys)

def plotLims(pos):
    """Return padded ``(xlim, ylim)`` tuples for a single track.

    The limits are the track's x/y extremes widened on every side by 10 % of
    the largest absolute coordinate found at the track's first or last sample.
    """
    xs, ys = pos[0], pos[1]
    first_point = np.array([xs[0], ys[0]])
    last_point = np.array([xs[-1], ys[-1]])

    xmax, xmin, ymax, ymin = findMaxPos(pos)

    # Padding scale: largest |coordinate| among the two endpoints.
    largest_first = max(abs(first_point))
    largest_last = max(abs(last_point))
    max_pos = max(largest_first, largest_last)
    margin = 0.1*max_pos

    return (xmin-margin, xmax+margin), (ymin-margin, ymax+margin)

def plotLimsTwoBody(pos1, pos2):
    """Return padded ``(xlim, ylim)`` tuples enclosing two tracks.

    The limits span the combined x/y extremes of both tracks, widened on
    every side by 10 % of the largest absolute coordinate found at the first
    or last sample of either track.
    """
    def endpoint_scale(track):
        # Largest |coordinate| among the first and last sample of one track.
        first_point = np.array([track[0][0], track[1][0]])
        last_point = np.array([track[0][-1], track[1][-1]])
        return max(max(abs(first_point)), max(abs(last_point)))

    x1max, x1min, y1max, y1min = findMaxPos(pos1)
    x2max, x2min, y2max, y2min = findMaxPos(pos2)

    # Combined bounding box of both tracks.
    xmax, xmin = max(x1max, x2max), min(x1min, x2min)
    ymax, ymin = max(y1max, y2max), min(y1min, y2min)

    max_pos = max(endpoint_scale(pos1), endpoint_scale(pos2))
    margin = 0.1*max_pos

    return (xmin-margin, xmax+margin), (ymin-margin, ymax+margin)

def orbPlotter(positions, positionsN = 0, xlim = 0, ylim = 0, slice = slice(0, -1, 1), aspect = 1, filename='', CM = True, save = False, show = True, DL = False, N = False, figsize=(8,8)):
    """Plot the x-y trajectories of two bodies and, optionally, their centre of mass.

    Parameters
    ----------
    positions : array
        Indexed as ``positions[0:3, :, slice]`` and unpacked into the tracks
        of body 1, body 2 and the centre of mass; each track is then read as
        ``track[0]`` (x) and ``track[1]`` (y).
        Presumably shaped ``(>=3, >=2, n_steps)`` -- verify against caller.
    positionsN, N
        Accepted for backward compatibility; currently unused (the Newtonian
        overlay below is disabled).
    xlim, ylim : tuple or list, optional
        Explicit axis limits. Any other value (the default ``0``) means the
        limits are derived from the data with ``plotLimsTwoBody``.
    slice : slice
        Selection applied along the last axis of ``positions``. Note that the
        default ``slice(0, -1, 1)`` leaves out the final sample.
    aspect
        Passed to ``ax.set_aspect``.
    filename : str
        Used as the plot title and as the PNG file name when saving.
    CM : bool
        When ``CM == True`` both tracks are drawn relative to the centre of
        mass; when ``CM == False`` the centre-of-mass track is drawn as well.
    save : bool
        When ``save == True`` the figure is written to
        ``./Plots/<filename>.png`` (the directory is created if missing).
    show : bool
        When ``show == False`` the figure is closed after any saving so that
        it is not displayed.
    DL : bool
        When ``DL == False`` the coordinates are divided by the
        metres-per-AU factor before plotting; otherwise they are used as is.
    figsize : tuple
        Passed to ``plt.subplots``.

    Returns
    -------
    None
    """
    # Rounded metres-per-AU factor, kept at the value the notebook has always
    # used so that existing plots are reproduced unchanged.
    metres_per_au = 149597871000

    tracks = positions[0:3,:,slice]
    if DL == False:
        tracks = tracks/metres_per_au
    x_1, x_2, x_cm = tracks

    if CM == True:
        # Shift into the centre-of-mass frame.
        x_1 = x_1 - x_cm
        x_2 = x_2 - x_cm
        x_cm = np.zeros_like(x_1)

    fig, ax = plt.subplots(figsize=figsize)

    # Accept lists as well as tuples for explicit limits, and compute the
    # automatic limits at most once.
    xlim_given = isinstance(xlim, (tuple, list))
    ylim_given = isinstance(ylim, (tuple, list))
    if not (xlim_given and ylim_given):
        auto_xlim, auto_ylim = plotLimsTwoBody(x_1, x_2)
        if not xlim_given:
            xlim = auto_xlim
        if not ylim_given:
            ylim = auto_ylim

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    # Disabled Newtonian overlay (would use `positionsN`):
    # if type(positionsN) == np.ndarray:
    #     if DL == False:
    #         x_1N, x_2N, x_cmN = positionsN[0:3,:,slice]/149597871000
    #     else:
    #         x_1N, x_2N, x_cmN = positionsN[0:3,:,slice]
    #     if CM == True:
    #         x_1N = x_1N - x_cmN
    #         x_2N = x_2N - x_cmN
    #     ax.plot(x_1N[0], x_1N[1], 'y', lw=0.5, label='Newtonian motion')
    #     ax.plot(x_2N[0], x_2N[1], 'y', lw=0.5, label='Newtonian motion')

    if CM == False:
        ax.plot(x_cm[0], x_cm[1], 'g:',label = 'CM')
        ax.plot(x_cm[0][0], x_cm[1][0], 'gx',markersize=8)
        ax.plot(x_cm[0][-1], x_cm[1][-1], 'g.',markersize=8)

    # Each body: full path, a cross at the first sample, a dot at the last.
    ax.plot(x_1[0], x_1[1],'b')
    ax.plot(x_1[0][0], x_1[1][0], 'bx', label = 'm_1 start', markersize=15)
    ax.plot(x_1[0][-1], x_1[1][-1], 'b.', label = 'm_1 stop',markersize=15)

    ax.plot(x_2[0], x_2[1],'r')
    ax.plot(x_2[0][0], x_2[1][0], 'rx', label = 'm_2 start', markersize=15)
    ax.plot(x_2[0][-1], x_2[1][-1], 'r.', label = 'm_2 stop',markersize=15)

    # Disabled axis labels:
    # if DL == True:
    #     ax.set_xlabel('$R_S$', fontsize = 15)
    #     ax.set_ylabel('$R_S$', fontsize = 15)
    # else:
    #     ax.set_xlabel('$x \ [\mathrm{AU}]$', fontsize = 15)
    #     ax.set_ylabel('$y \ [\mathrm{AU}]$', fontsize = 15)
    ax.grid(c='grey', alpha=0.2, ls ='--')
    ax.set_aspect(aspect)
    # Single title call on the Axes itself; this is the final state the
    # previous set_title + plt.title pair produced.
    ax.set_title(f'{filename}', fontsize=25, y=1.08)
    ax.legend(facecolor='grey', fontsize = 12, bbox_to_anchor=(1.01,1), loc='upper left')

    fig.patch.set_facecolor('white')
    fig.tight_layout()

    # Save before closing, and make sure the target directory exists.
    if save == True:
        out_path = f'./Plots/{filename}.png'
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path, dpi=300, transparent = False)
    if show == False:
        plt.close(fig)
    return