In [1]:
%matplotlib inline

In [2]:
%run proof_setup

In [17]:
import numpy as np
import sympy as sm


def rotate(sinw, cosw, sini, cosi, x2, y2, z2):
    """Apply two inverse rotations to the column vector (x2, y2, z2).

    The vector is first multiplied by the inverse rotation about the first
    axis built from ``sini``/``cosi`` and then by the inverse rotation about
    the third axis built from ``sinw``/``cosw``.

    Parameters
    ----------
    sinw, cosw : sine and cosine used for the rotation about the third axis.
    sini, cosi : sine and cosine used for the rotation about the first axis.
    x2, y2, z2 : components of the vector to transform.

    Returns
    -------
    tuple
        ``(v0, v2)``: the simplified transformed 3x1 matrix and the
        simplified input 3x1 matrix.
    """
    about_first_axis = sm.Matrix(
        [
            [1, 0, 0],
            [0, cosi, sini],
            [0, -sini, cosi],
        ]
    )
    about_third_axis = sm.Matrix(
        [
            [cosw, sinw, 0],
            [-sinw, cosw, 0],
            [0, 0, 1],
        ]
    )
    column = sm.Matrix([[x2], [y2], [z2]])

    # Same left-to-right product as Rw^-1 * Ri^-1 * v.
    combined = about_third_axis * about_first_axis
    transformed = combined * column
    return sm.simplify(transformed), sm.simplify(column)


def get_quadratic_eqs(circular=False, edge=False, printer=None, wcase=False):
    """Derive, print and return the coefficients of the conic in (x, z).

    The point ``(x, y, z)`` with ``y = z * cos(i) / sin(i)`` is passed through
    ``rotate`` to obtain ``(x0, y0, z0)``. These are substituted into

        (x0 - a*e)**2 + y0**2 / (1 - e**2) - a**2

    and the result is expanded as a polynomial in ``x`` and ``z``. Each
    coefficient is multiplied by ``e**2 - 1`` and reported as

        A*x**2 + B*x*z + C*z**2 + D*x + E*z + F.

    Parameters
    ----------
    circular : bool
        Substitute ``e = 0``.
    edge : bool
        Substitute ``cos(i) = 0`` and ``sin(i) = 1``.
    printer : callable, optional
        Applied to every expression before printing (e.g. ``sm.latex``).
        Defaults to the identity.
    wcase : bool
        Substitute ``sin(omega) = 0`` and ``cos(omega) = 1``.

    Returns
    -------
    tuple
        The simplified coefficients ``(A, B, C, D, E, F)``.
    """
    if printer is None:
        printer = lambda x: x

    # Only the symbols that are actually used are created; y is derived from z.
    semimajor, ecc, w, incl, x, z = sm.symbols("a, e, omega, i, x, z")
    sinw = sm.sin(w)
    cosw = sm.cos(w)
    sini = sm.sin(incl)
    cosi = sm.cos(incl)

    if edge:
        cosi = 0
        sini = 1

    y = z * cosi / sini

    if wcase:
        sinw = 0
        cosw = 1

    if circular:
        ecc = 0

    v0, v2 = rotate(sinw, cosw, sini, cosi, x, y, z)

    print("x0 =", printer(v0[0]))
    print("y0 =", printer(v0[1]))
    print("z0 =", printer(v0[2]))
    print()

    eq = (v0[0] - semimajor * ecc) ** 2 + v0[1] ** 2 / (1 - ecc ** 2) - semimajor ** 2

    eq1 = sm.poly(eq, x, z)

    denom = ecc ** 2 - 1

    # Extract and simplify every coefficient exactly once; the same objects
    # are printed and returned.
    labels = ("A", "B", "C", "D", "E", "F")
    monomials = (x ** 2, x * z, z ** 2, x, z, 1)
    scaled = [denom * eq1.coeff_monomial(monomial) for monomial in monomials]
    coeffs = tuple(sm.simplify(term) for term in scaled)

    for label, term, coeff in zip(labels, scaled, coeffs):
        # B has always been *displayed* via cancel() while the returned value
        # uses simplify(); keep that so the printed listing is unchanged.
        shown = sm.cancel(term) if label == "B" else coeff
        print("{0} =".format(label), printer(shown))

    return coeffs

In [18]:
# Print the conic coefficients as LaTeX for four cases: the general one,
# circular (e = 0), the sin(omega) = 0 / cos(omega) = 1 special case, and
# cos(i) = 0 / sin(i) = 1. The bare print() calls only separate the listings;
# the trailing semicolon hides the returned tuple in the notebook.
get_quadratic_eqs(printer=sm.latex)
print()
print()
get_quadratic_eqs(circular=True, printer=sm.latex)
print()
print()
get_quadratic_eqs(wcase=True, printer=sm.latex)
print()
print()
get_quadratic_eqs(edge=True, printer=sm.latex);

x0 = x \cos{\left (\omega \right )} + \frac{z \sin{\left (\omega \right )}}{\sin{\left (i \right )}}
y0 = - x \sin{\left (\omega \right )} + \frac{z \cos{\left (\omega \right )}}{\sin{\left (i \right )}}
z0 = 0

A = e^{2} \cos^{2}{\left (\omega \right )} - 1
B = \frac{2 e^{2} \sin{\left (\omega \right )} \cos{\left (\omega \right )}}{\sin{\left (i \right )}}
C = \frac{e^{2} \sin^{2}{\left (\omega \right )} - 1}{\sin^{2}{\left (i \right )}}
D = 2 a e \left(- e^{2} + 1\right) \cos{\left (\omega \right )}
E = - \frac{2 a e \left(e^{2} - 1\right) \sin{\left (\omega \right )}}{\sin{\left (i \right )}}
F = a^{2} \left(e^{2} - 1\right)^{2}


x0 = x \cos{\left (\omega \right )} + \frac{z \sin{\left (\omega \right )}}{\sin{\left (i \right )}}
y0 = - x \sin{\left (\omega \right )} + \frac{z \cos{\left (\omega \right )}}{\sin{\left (i \right )}}
z0 = 0

A = -1
B = 0
C = - \frac{1}{\sin^{2}{\left (i \right )}}
D = 0
E = 0
F = a^{2}


x0 = x
y0 = \frac{z}{\sin{\left (i \right )}}
z0 = 0

A = e^{2} 

In [9]:
def get_quartic_expr(circular=False, edge=False, printer=None, wcase=False):
    """Print the quartic in ``x`` obtained by eliminating ``z``.

    Two quadratics in ``z`` are combined:

        p(z) = T*z**2 + 0*z + (x**2 - L**2)
        q(z) = C*z**2 + (B*x + E)*z + (A*x**2 + D*x + F)

    The expression
    ``(p0*q2 - p2*q0)**2 - (p0*q1 - p1*q0)*(p1*q2 - p2*q1)`` has the form of
    the resultant of these two quadratics with respect to ``z``; it is
    collected as a polynomial in ``x``.

    Parameters
    ----------
    circular : bool
        Substitute ``A = -1, B = 0, D = 0, E = 0`` into each coefficient.
        Takes precedence over ``wcase``.
    edge : bool
        Replace the symbols ``A..F`` by the expressions returned from
        ``get_quadratic_eqs(edge=True)`` (which also prints its own listing).
    printer : callable, optional
        Applied to every expression before printing (e.g. ``sm.latex``).
        Defaults to the identity.
    wcase : bool
        Substitute ``B = 0, E = 0`` and print the whole quartic in factored
        form instead of the individual coefficients.

    Returns
    -------
    None
        Results are only printed.
    """
    if printer is None:
        printer = lambda x: x

    A, B, C, D, E, F, T, L, x = sm.symbols("A, B, C, D, E, F, T, L, x", real=True)
    if edge:
        A, B, C, D, E, F = get_quadratic_eqs(edge=True)

    # Coefficients of p(z), highest power of z first.
    p0 = T
    p1 = 0
    p2 = x ** 2 - L ** 2

    # Coefficients of q(z), highest power of z first.
    q0 = C
    q1 = B * x + E
    q2 = A * x ** 2 + D * x + F

    quartic = sm.Poly(
        (p0 * q2 - p2 * q0) ** 2 - (p0 * q1 - p1 * q0) * (p1 * q2 - p2 * q1), x
    )

    if circular:
        args = {A: -1, B: 0, D: 0, E: 0}
    elif wcase:
        args = {B: 0, E: 0}
        quartic = sm.factor(quartic.subs(args))
        # Honour the requested printer here as well; previously this branch
        # always used the plain string form even when printer=sm.latex.
        print(printer(quartic))
        return
    else:
        args = {}

    for i in range(5):
        print(
            "a_{0} =".format(i),
            printer(sm.factor(sm.simplify(quartic.coeff_monomial(x ** i).subs(args)))),
        )

In [11]:
# Print the quartic coefficients a_0..a_4 as LaTeX for the general and the
# circular case, then the factored quartic for the B = E = 0 special case.
get_quartic_expr(printer=sm.latex)
print()
print()
get_quartic_expr(circular=True, printer=sm.latex)
print()
print()
get_quartic_expr(wcase=True, printer=sm.latex)

a_0 = C^{2} L^{4} + 2 C F L^{2} T - E^{2} L^{2} T + F^{2} T^{2}
a_1 = - 2 T \left(B E L^{2} - C D L^{2} - D F T\right)
a_2 = 2 A C L^{2} T + 2 A F T^{2} - B^{2} L^{2} T - 2 C^{2} L^{2} - 2 C F T + D^{2} T^{2} + E^{2} T
a_3 = 2 T \left(A D T + B E - C D\right)
a_4 = A^{2} T^{2} - 2 A C T + B^{2} T + C^{2}


a_0 = \left(C L^{2} + F T\right)^{2}
a_1 = 0
a_2 = - 2 \left(C + T\right) \left(C L^{2} + F T\right)
a_3 = 0
a_4 = \left(C + T\right)^{2}


(A*T*x**2 + C*L**2 - C*x**2 + D*T*x + F*T)**2


In [326]:
def balance_companion_matrix(companion_matrix):
    """Balance a square matrix in place using power-of-two scalings.

    For every index ``i`` the 1-norms of the off-diagonal parts of row ``i``
    and column ``i`` are compared. If scaling row ``i`` by ``2**-k`` and
    column ``i`` by ``2**k`` lowers their summed norm by more than a factor
    ``gamma``, the scaling is applied. Sweeps repeat until no scaling is
    accepted. The diagonal is set aside during balancing and restored
    afterwards, so it is never modified.

    Parameters
    ----------
    companion_matrix : numpy.ndarray
        Square floating-point array. It is modified in place.

    Returns
    -------
    numpy.ndarray
        The same array object that was passed in, after balancing.
    """
    diag = np.array(np.diag(companion_matrix))
    companion_matrix[np.diag_indices_from(companion_matrix)] = 0.0
    degree = len(diag)

    # gamma <= 1 controls how much a change in the scaling has to
    # lower the 1-norm of the companion matrix to be accepted.
    #
    # gamma = 1 seems to lead to cycles (numerical issues?), so
    # we set it slightly lower.
    gamma = 0.9

    scaling_has_changed = True
    while scaling_has_changed:
        scaling_has_changed = False

        for i in range(degree):
            row_norm = np.sum(np.abs(companion_matrix[i]))
            col_norm = np.sum(np.abs(companion_matrix[:, i]))

            # A vanishing row or column cannot be balanced. Skipping it
            # explicitly avoids a 0-division (inf/nan ratio) whose frexp
            # exponent is not well defined.
            if row_norm == 0 or col_norm == 0:
                continue

            # Decompose row_norm/col_norm into mantissa * 2^exponent,
            # where 0.5 <= mantissa < 1. Discard mantissa (return value
            # of frexp), as only the exponent is needed.
            _, exponent = np.frexp(row_norm / col_norm)
            exponent = exponent // 2

            if exponent != 0:
                scaled_col_norm = np.ldexp(col_norm, exponent)
                scaled_row_norm = np.ldexp(row_norm, -exponent)
                if scaled_col_norm + scaled_row_norm < gamma * (col_norm + row_norm):
                    # Accept the new scaling. (Multiplication by powers of 2 should not
                    # introduce rounding errors (ignoring non-normalized numbers and
                    # over- or underflow))
                    scaling_has_changed = True
                    companion_matrix[i] *= np.ldexp(1.0, -exponent)
                    companion_matrix[:, i] *= np.ldexp(1.0, exponent)

    companion_matrix[np.diag_indices_from(companion_matrix)] = diag
    return companion_matrix


def solve_companion_matrix(poly):
    """Return the roots of a polynomial via its balanced companion matrix.

    Parameters
    ----------
    poly : array_like
        Coefficients ordered from the constant term up to the leading term
        (the last entry is used as the normalising coefficient).

    Returns
    -------
    numpy.ndarray
        Eigenvalues of the balanced companion matrix.
    """
    coeffs = np.atleast_1d(poly)
    size = len(coeffs) - 1

    # Ones on the first sub-diagonal, normalised negated coefficients in the
    # last column.
    companion = np.eye(size, k=-1)
    companion[:, -1] = -coeffs[:-1] / coeffs[-1]

    balanced = balance_companion_matrix(companion)
    return np.linalg.eigvals(balanced)


def _get_quadratic(a, e, cosw, sinw, cosi, sini):
    e2 = e * e
    e2mo = e2 - 1
    return (
        (e2 * cosw * cosw - 1),
        2 * e2 * sinw * cosw / sini,
        (e2mo - e2 * cosw * cosw) / (sini * sini),
        -2 * a * e * e2mo * cosw,
        -2 * a * e * e2mo * sinw / sini,
        a ** 2 * e2mo * e2mo,
    )


def _get_quartic(A, B, C, D, E, F, T, L):
    A2 = A * A
    B2 = B * B
    C2 = C * C
    D2 = D * D
    E2 = E * E
    F2 = F * F
    T2 = T * T
    L2 = L * L
    return (
        C2 * L2 * L2 + 2 * C * F * L2 * T - E2 * L2 * T + F2 * T2,
        -2 * T * (B * E * L2 - C * D * L2 - D * F * T),
        2 * A * C * L2 * T
        + 2 * A * F * T2
        - B2 * L2 * T
        - 2 * C2 * L2
        - 2 * C * F * T
        + D2 * T2
        + E2 * T,
        2 * T * (A * D * T + B * E - C * D),
        A2 * T2 - 2 * A * C * T + B2 * T + C2,
    )


def _get_roots_general(a, e, omega, i, L, tol=1e-8):
    """Find the two angles at which the quartic/conic system is satisfied.

    The quartic in ``x`` built by ``_get_quadratic`` and ``_get_quartic`` is
    solved with ``solve_companion_matrix``; each accepted ``x`` root is turned
    into ``z`` values via the conic, and every ``(x, z)`` pair with ``z <= 0``
    is converted to an angle.

    Parameters
    ----------
    a, e, omega, i : float
        Passed to ``_get_quadratic`` (``omega`` and ``i`` through their sine
        and cosine). Presumably semi-major axis, eccentricity, argument of
        periastron and inclination -- confirm against the caller.
    L : float
        Bound used in the quartic and for clipping the ``x`` roots to
        ``[-L, L]``.
    tol : float
        A root counts as real when ``abs(imag) < tol``.

    Returns
    -------
    numpy.ndarray
        An empty array when fewer than two real ``x`` roots remain; otherwise
        two angles. If the number of angles found is not exactly two, the
        fallback ``[-pi, pi] + f0`` is returned.
    """
    cosw = np.cos(omega)
    sinw = np.sin(omega)
    cosi = np.cos(i)
    sini = np.sin(i)

    # Tangent half-angle form of pi/2 - omega (wrapped to [-pi, pi]).
    # NOTE(review): presumably the reference angle at mid-transit -- confirm.
    f0 = 2 * np.arctan2(cosw, 1 + sinw)

    quad = _get_quadratic(a, e, cosw, sinw, cosi, sini)
    A, B, C, D, E, F = quad
    # T = (cos(i) / sin(i))**2
    T = cosi / sini
    T *= T
    quartic = _get_quartic(A, B, C, D, E, F, T, L)

    # Four roots of the quartic, ordered by real part.
    roots = solve_companion_matrix(quartic)
    roots = roots[np.argsort(np.real(roots))]

    # Deal with multiplicity: from the two roots with the smallest real part
    # and from the two with the largest real part, keep the one whose
    # imaginary part is smallest in magnitude. The upper pair is scanned in
    # reverse so that a tie favours the last root.
    roots[0] = roots[:2][np.argmin(np.abs(np.imag(roots[:2])))]
    roots[1] = roots[2:][::-1][np.argmin(np.abs(np.imag(roots[2:])[::-1]))]
    roots = roots[:2]

    # Only select real roots, and clip their real parts to [-L, L]
    roots = np.clip(np.real(roots[np.abs(np.imag(roots)) < tol]), -L, L)
    if len(roots) < 2:
        return np.empty(0)

    angles = []
    for x in roots:
        # Solve b2*z**2 + b1*z + b0 = 0 for z at this x.
        b0 = A * x * x + D * x + F
        b1 = B * x + E
        b2 = C
        z1 = -0.5 * b1 / b2
        arg = b1 * b1 - 4 * b0 * b2
        if arg < 0:
            # No real z for this x.
            continue
        z2 = 0.5 * np.sqrt(arg) / b2
        for sgn in [-1, 1]:
            z = z1 + sgn * z2
            # Only solutions with z <= 0 are kept.
            if z > 0:
                continue
            # NOTE(review): y is computed but never used below.
            y = z * cosi / sini

            x0 = x * cosw + z * sinw / sini
            y0 = -x * sinw + z * cosw / sini
            # Shift by pi and wrap back into [-pi, pi).
            angle = np.arctan2(y0, x0) - np.pi
            if angle < -np.pi:
                angle += 2 * np.pi
            # Store the angle relative to f0.
            angles.append(angle - f0)

    angles = np.sort(angles)

    # Wrap the roots properly to span the transit: when both relative angles
    # have the same sign, one of them is shifted by 2*pi so that the pair
    # straddles zero. Any count other than two falls back to [-pi, pi].
    if len(angles) == 2:
        if np.all(angles > 0):
            angles = np.array([angles[1] - 2 * np.pi, angles[0]])
        if np.all(angles < 0):
            angles = np.array([angles[1], angles[0] + 2 * np.pi])
    else:
        angles = np.array([-np.pi, np.pi])

    # Undo the f0 offset.
    return angles + f0


def check_roots(a, e, omega, i, L, tol=1e-8):
    """Print a consistency check for the angles from ``_get_roots_general``.

    ``L`` is rescaled by ``a`` and ``a`` is then set to 1 before the roots
    are computed. For every returned angle ``f`` the quantity ``b2`` is
    evaluated from the closed-form expression below and printed next to
    ``L**2`` for comparison. Finally the array of roots is printed.
    """
    L /= a
    a = 1.0
    roots = _get_roots_general(a, e, omega, i, L, tol=tol)

    # Quantities that do not depend on the individual root.
    cos_incl = np.cos(i)
    cos_omega = np.cos(omega)
    sin_omega = np.sin(omega)
    prefactor = a ** 2 * (e ** 2 - 1) ** 2

    for f in roots:
        cos_f = np.cos(f)
        sin_f = np.sin(f)
        sin_sum = cos_omega * sin_f + sin_omega * cos_f
        cos_sum = cos_omega * cos_f - sin_omega * sin_f
        b2 = (
            prefactor
            * (cos_incl ** 2 * sin_sum ** 2 + cos_sum ** 2)
            / (e * cos_f + 1) ** 2
        )
        print("b2 = ", b2, " L2 = ", L ** 2)

    print(roots)

In [334]:
# Eccentric example: a = 10, e = 0.5, omega = -0.15, i = pi/2 - 0.01, L = 1
# (check_roots divides L by a before solving).
check_roots(10.0, 0.5, -0.15, 0.5 * np.pi - 0.01, 1.0)

b2 =  0.010014687059509328  L2 =  0.010000000000000002
b2 =  0.010019159137027812  L2 =  0.010000000000000002
[1.58853526 1.83656275]


In [331]:
# Circular example: a = 100, e = 0, omega = pi, i = pi/2, L = 1.5.
check_roots(100.0, 0.0, np.pi, 0.5 * np.pi, 1.5)

b2 =  0.00022499999999999343  L2 =  0.000225
b2 =  0.000225000000000003  L2 =  0.000225
[-1.58579689 -1.55579576]


In [280]:
# Same derivations as above with the default (identity) printer, i.e. plain
# SymPy string output; the trailing semicolon suppresses the cell's Out value.
get_quadratic_eqs()
print()
get_quartic_expr();

x0 = x*cos(omega) + z*sin(omega)/sin(i)
y0 = -x*sin(omega) + z*cos(omega)/sin(i)
z0 = 0

A = e**2*cos(omega)**2 - 1
B = 2*e**2*sin(omega)*cos(omega)/sin(i)
C = (e**2*sin(omega)**2 - 1)/sin(i)**2
D = 2*a*e*(-e**2 + 1)*cos(omega)
E = -2*a*e*(e**2 - 1)*sin(omega)/sin(i)
F = a**2*(e**2 - 1)**2

a_0 = C**2*L**4 + 2*C*F*L**2*T - E**2*L**2*T + F**2*T**2
a_1 = -2*T*(B*E*L**2 - C*D*L**2 - D*F*T)
a_2 = 2*A*C*L**2*T + 2*A*F*T**2 - B**2*L**2*T - 2*C**2*L**2 - 2*C*F*T + D**2*T**2 + E**2*T
a_3 = 2*T*(A*D*T + B*E - C*D)
a_4 = A**2*T**2 - 2*A*C*T + B**2*T + C**2


In [128]:
# cos(i) = 0 / sin(i) = 1 case, first for general e with plain output, then
# combined with e = 0 and printed as LaTeX.
get_quadratic_eqs(edge=True)
print()
print()
get_quadratic_eqs(circular=True, printer=sm.latex, edge=True);

x0 = x*cos(omega) + z*sin(omega)
y0 = -x*sin(omega) + z*cos(omega)
z0 = -y

b_0 = (a**2*e**4 - 2*a**2*e**2 + a**2 - 2*a*e**3*x*cos(omega) + 2*a*e*x*cos(omega) + e**2*x**2*cos(omega)**2 - x**2)/(e**2 - 1)
b_1 = (-2*a*e**3*sin(omega) + 2*a*e*sin(omega) + 2*e**2*x*sin(omega)*cos(omega))/(e**2 - 1)
b_2 = (e**2*sin(omega)**2 - 1)/(e**2 - 1)


x0 = x
y0 = z
z0 = - y

b_0 = - a^{2} + x^{2}
b_1 = 0
b_2 = 1
