Much of the code in "engine" was not typed up by hand. Instead, we do all of the algebra in Wolfram Mathematica .wls files, have those export some Fortran code, and then convert that to JAX with this kind of script. All files required some manual tweaking in the end, but this sped up the process. Actually using this script requires a local installation of Mathematica.

In [1]:
import jax

jax.config.update("jax_enable_x64", True)

import os
import subprocess


# will use this function for all the mathematica files
def parse_fortranform(file, function_name, final_variables, black=False, variable_renames=None):
    """Turn a text file holding a Mathematica/Fortran-style expression into a Python function.

    The text of ``file`` is read, the requested variable renames are applied,
    and Mathematica function names are then swapped for their ``jnp``
    equivalents by plain string replacement. The result is written next to
    the input, with the extension replaced by ``.py``, as the body of a
    function that returns the expression.

    Args:
        file: Path of the text file to convert.
        function_name: Name of the function written to the ``.py`` file.
        final_variables: Argument names of the generated function, in order.
        black: If true, run ``python -m black`` on the generated file.
        variable_renames: Optional mapping of substrings to replace before
            the function-name substitutions (e.g. non-ASCII symbols to ASCII
            names). ``None`` means no renames.

    Returns:
        The converted expression text (without the ``def`` wrapper).
    """
    with open(file) as f:
        s = f.read()

    # ``None`` instead of a shared ``{}`` default avoids the mutable-default pitfall.
    for key, value in (variable_renames or {}).items():
        s = s.replace(key, value)

    # Order matters: the ``Arc*`` names contain ``Cos``/``Sin``/``Tan``, so they
    # must be rewritten first, otherwise ``ArcCos`` turns into ``Arcjnp.cos``.
    replacements = (
        ('Sqrt', 'jnp.sqrt'),
        ('ArcCos', 'jnp.arccos'),
        ('ArcSin', 'jnp.arcsin'),
        ('ArcTan', 'jnp.arctan2'),
        ('Cos', 'jnp.cos'),
        ('Sin', 'jnp.sin'),
        ('Pi', 'jnp.pi'),
        ('Abs', 'jnp.abs'),
        # NOTE(review): the reciprocal forms are not parenthesised, so an
        # expression such as ``a/Cot(x)`` needs a manual check afterwards.
        ('Cot', '1/jnp.tan'),
        ('Csc', '1/jnp.sin'),
        ('Sec', '1/jnp.cos'),
        ('Tan', 'jnp.tan'),
        ("0.3333333333333333", "(1/3)"),
        ("E**", "jnp.exp"),
    )
    for old, new in replacements:
        s = s.replace(old, new)

    if 'arctan' in s:
        print('WARNING: FLIP THE ARCTAN ARGUMENTS MANUALLY')

    # ``splitext`` only strips the final extension, so dots elsewhere in the
    # path (``./x.txt``, ``v1.0/x.txt``) no longer truncate the output name.
    file = os.path.splitext(file)[0] + '.py'
    with open(file, 'w') as f:
        f.write(f"def {function_name}({','.join(final_variables)}):\n")
        f.write("\treturn (\n")
        f.write(s)
        f.write("\t)")

    if black:
        subprocess.run(['python', '-m', 'black', file])
    return s

In [3]:
# Example usage: generate the functions that return the coefficients of the
# implicit surface of the star after it has been rotated onto the z-axis. They
# are needed to solve for the projected area of the planet *as seen from the
# star*, which can vary with phase for an oblate planet and affects the total
# flux reflected back to us.

# might need to run this outside of here, had some issues with the NotebookDirectory command
# subprocess.run(["wolframscript", "-file", "terminator_from_star.wls"])

# One entry per exported coefficient; every file name below is derived from it.
coeff_names = ["pxx", "pxy", "pxz", "px0", "pyy", "pyz", "py0", "pzz", "pz0", "p00"]

# Mathematica name -> Python name: an underscore goes after the first character
# ("pxx" -> "p_xx", "xc" -> "x_c").
renames = {
    name: f"{name[0]}_{name[1:]}" for name in coeff_names + ["xc", "yc", "zc"]
}
arg_names = list(renames.values())

# Convert each exported .txt expression into a standalone .py function.
for name in coeff_names:
    parse_fortranform(
        file=f"terminator_coeff_{name}.txt",
        function_name=f"_{name}",
        final_variables=arg_names,
        black=False,
        variable_renames=renames,
        # variable_renames={"ω": "omega", "Ω": "Omega"}
    )

# Gather the generated sources so they can be merged into a single module.
funcs = []
for name in coeff_names:
    with open(f"terminator_coeff_{name}.py") as src:
        funcs.append(src.read())

# Write the combined module: header, the individual functions, then one
# jitted wrapper that evaluates all of them and returns a dict.
call_args = ", ".join(arg_names)
with open("planet_viewed_from_star.py", "w") as out:
    out.write("import jax\n")
    out.write('jax.config.update("jax_enable_x64", True)\n')
    out.write("import jax.numpy as jnp\n")
    for body in funcs:
        out.write(body + "\n")

    out.write("@jax.jit\n")
    out.write(f"def planet_from_star({call_args}, **kwargs):\n")
    out.write("\treturn {\n")
    for name in coeff_names:
        out.write(f"'{renames[name]}': _{name}({call_args}),\n")
    out.write("\t}\n")

# Let black tidy up the combined file.
subprocess.run(["python", "-m", "black", "planet_viewed_from_star.py"])

# Delete the per-coefficient inputs and intermediates.
for name in coeff_names:
    for ext in ("txt", "py"):
        os.remove(f"terminator_coeff_{name}.{ext}")

reformatted planet_viewed_from_star.py

All done! ✨ 🍰 ✨
1 file reformatted.


In [16]:
# might need to run this outside of here, had some issues with the NotebookDirectory command
# subprocess.run(["wolframscript", "-file", "emission_profile.wls"])

# Entries of the transform in row-major order: x_x, x_y, x_z, x_0, y_x, ..., z_0.
rows = "xyz"
cols = "xyz0"
entry_names = [f"{r}_{c}" for r in rows for c in cols]

transform_args = ["a", "e", "f", "Omega", "i", "omega", "r", "phi", "theta"]
greek_to_ascii = {"θ": "theta", "ϕ": "phi", "ω": "omega", "Ω": "Omega"}

# Convert each exported .txt expression into a standalone .py function.
for name in entry_names:
    parse_fortranform(
        file=f"{name}.txt",
        function_name=f"_{name}",
        final_variables=transform_args,
        black=False,
        variable_renames=greek_to_ascii,
    )

# Gather the generated sources so they can be merged into a single module.
funcs = []
for name in entry_names:
    with open(f"{name}.py") as src:
        funcs.append(src.read())

# Write the combined module: header, the entry functions, then one jitted
# function that assembles them into a 3x4 array.
call_args = ", ".join(transform_args)
with open("emission_profile.py", "w") as out:
    out.write("import jax\n")
    out.write('jax.config.update("jax_enable_x64", True)\n')
    out.write("import jax.numpy as jnp\n")
    for body in funcs:
        out.write(body + "\n")

    out.write("@jax.jit\n")
    out.write(f"def pre_squish_transform({call_args}):\n")
    out.write("\treturn jnp.array(\n")
    out.write("\t\t[\n")
    for r in rows:
        row = ", ".join(f"_{r}_{c}({call_args})" for c in cols)
        out.write(f"\t\t\t[{row}],\n")
    out.write("\t\t]\n")
    out.write("\t)\n")

# The profile itself. The spelling "_profle" is kept as-is because the wrapper
# written below calls the function under exactly that name.
parse_fortranform(
    file="profile.txt",
    function_name="_profle",
    final_variables=["x", "y", "z", "f1", "f2", "alpha", "beta", "kappa"],
    black=False,
    variable_renames={"α" : "alpha", "β" : "beta", "κ" : "kappa"},
)

# Append the profile function, then the public jitted wrapper that applies the
# transform to (x, y, z, 1) before evaluating the profile.
with open("emission_profile.py", "a") as out:
    with open("profile.py") as src:
        out.write(src.read() + "\n")
    out.write("@jax.jit\n")
    out.write(
        f"def emission_profile(x, y, z, {call_args}, f1, f2, alpha, beta, kappa):\n"
    )
    out.write(f"\tt = pre_squish_transform({call_args})\n")
    out.write("\tx, y, z = jnp.matmul(t, jnp.array([x,y,z,1]))\n")
    out.write("\treturn _profle(x, y, z, f1, f2, alpha, beta, kappa)\n")

# Let black tidy up the combined file.
subprocess.run(["python", "-m", "black", "emission_profile.py"])

# Delete the inputs and intermediates.
for name in entry_names + ["profile"]:
    for ext in ("txt", "py"):
        os.remove(f"{name}.{ext}")

reformatted emission_profile.py

All done! ✨ 🍰 ✨
1 file reformatted.


In [5]:
# One-off conversion of a single Mathematica expression (pasted as a string)
# using the same substitutions as parse_fortranform, plus the Subscript terms.
s = "((-(Sin(s)*Subscript(c,y1)) + Cos(s)*Subscript(c,y2))*(Pi + 6*(Cos(s)*Subscript(c,x1) + Sin(s)*Subscript(c,x2) + Subscript(c,x3))*   Sqrt(1 - (Cos(s)*Subscript(c,x1) + Sin(s)*Subscript(c,x2) + Subscript(c,x3))**2 -      (Cos(s)*Subscript(c,y1) + Sin(s)*Subscript(c,y2) + Subscript(c,y3))**2) -   6*ArcTan((Cos(s)*Subscript(c,x1) + Sin(s)*Subscript(c,x2) + Subscript(c,x3))/     Sqrt(1 - (Cos(s)*Subscript(c,x1) + Sin(s)*Subscript(c,x2) + Subscript(c,x3))**2 -        (Cos(s)*Subscript(c,y1) + Sin(s)*Subscript(c,y2) + Subscript(c,y3))**2))*   (-1 + (Cos(s)*Subscript(c,y1) + Sin(s)*Subscript(c,y2) + Subscript(c,y3))**2)))/12."

# Substitutions are applied top to bottom.
# NOTE(review): 'Cos'/'Sin' run before 'ArcCos'/'ArcSin', so those two rules
# can never match; harmless here because the expression only uses ArcTan.
for old, new in (
    ('Sqrt', 'jnp.sqrt'),
    ('Cos', 'jnp.cos'),
    ('Sin', 'jnp.sin'),
    ('Pi', 'jnp.pi'),
    ('Abs', 'jnp.abs'),
    ('ArcCos', 'jnp.arccos'),
    ('ArcSin', 'jnp.arcsin'),
    ('ArcTan', 'jnp.arctan2'),
    ('Cot', '1/jnp.tan'),
    ('Csc', '1/jnp.sin'),
    ('Sec', '1/jnp.cos'),
    ('Tan', 'jnp.tan'),
    ("0.3333333333333333", "(1/3)"),
    ("E**", "jnp.exp"),
):
    s = s.replace(old, new)

# Subscript(c,x1) -> c_x1, and likewise for the other five coefficients.
for tag in ("x1", "x2", "x3", "y1", "y2", "y3"):
    s = s.replace(f"Subscript(c,{tag})", f"c_{tag}")

print(s)

((-(jnp.sin(s)*c_y1) + jnp.cos(s)*c_y2)*(jnp.pi + 6*(jnp.cos(s)*c_x1 + jnp.sin(s)*c_x2 + c_x3)*   jnp.sqrt(1 - (jnp.cos(s)*c_x1 + jnp.sin(s)*c_x2 + c_x3)**2 -      (jnp.cos(s)*c_y1 + jnp.sin(s)*c_y2 + c_y3)**2) -   6*jnp.arctan2((jnp.cos(s)*c_x1 + jnp.sin(s)*c_x2 + c_x3)/     jnp.sqrt(1 - (jnp.cos(s)*c_x1 + jnp.sin(s)*c_x2 + c_x3)**2 -        (jnp.cos(s)*c_y1 + jnp.sin(s)*c_y2 + c_y3)**2))*   (-1 + (jnp.cos(s)*c_y1 + jnp.sin(s)*c_y2 + c_y3)**2)))/12.


In [None]:
# (
#     (-(jnp.sin(s) * c_y1) + jnp.cos(s) * c_y2)
#     * (
#         jnp.pi
#         + 6
#         * (jnp.cos(s) * c_x1 + jnp.sin(s) * c_x2 + c_x3)
#         * jnp.sqrt(
#             1
#             - (jnp.cos(s) * c_x1 + jnp.sin(s) * c_x2 + c_x3) ** 2
#             - (jnp.cos(s) * c_y1 + jnp.sin(s) * c_y2 + c_y3) ** 2
#         )
#         - 6
#         * jnp.arctan2(
#             (jnp.cos(s) * c_x1 + jnp.sin(s) * c_x2 + c_x3)
#             / jnp.sqrt(
#                 1
#                 - (jnp.cos(s) * c_x1 + jnp.sin(s) * c_x2 + c_x3) ** 2
#                 - (jnp.cos(s) * c_y1 + jnp.sin(s) * c_y2 + c_y3) ** 2
#             )
#         )
#         * (-1 + (jnp.cos(s) * c_y1 + jnp.sin(s) * c_y2 + c_y3) ** 2)
#     )
# ) / 12.0

In [None]:
# (
#     (-(jnp.sin(s) * c_y1) + jnp.cos(s) * c_y2)
#     * (
#         jnp.pi
#         + 6
#         * (jnp.cos(s) * c_x1 + jnp.sin(s) * c_x2 + c_x3)
#         * jnp.sqrt(
#             1
#             - (jnp.cos(s) * c_x1 + jnp.sin(s) * c_x2 + c_x3) ** 2
#             - (jnp.cos(s) * c_y1 + jnp.sin(s) * c_y2 + c_y3) ** 2
#         )
#         - 6
#         * jnp.arctan(
#             (jnp.cos(s) * c_x1 + jnp.sin(s) * c_x2 + c_x3)
#             / jnp.sqrt(
#                 1
#                 - (jnp.cos(s) * c_x1 + jnp.sin(s) * c_x2 + c_x3) ** 2
#                 - (jnp.cos(s) * c_y1 + jnp.sin(s) * c_y2 + c_y3) ** 2
#             )
#         )
#         * (-1 + (jnp.cos(s) * c_y1 + jnp.sin(s) * c_y2 + c_y3) ** 2)
#     )
# ) / 12.0