In [None]:
from lumfunc import Config, Bandpasses, SED, EmissionLineModel, library_dir, fig_dir

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Load the default run configuration, its filter curves, and the emission-line model.
config = Config(library_dir / "default" / "config.yaml")
bandpasses = Bandpasses(config.filter_files)
line_model = EmissionLineModel()

# Accumulators: one list of equivalent widths per emission line (one entry per SED,
# in SED order), and one row of network input colors per SED.
ew_lists = {name: [] for name in line_model.lines}
color_rows = []

for sed_idx in range(7, 31):
    # Wrap the template, then keep only the underlying galsim SED that the
    # line model consumes.
    galsim_sed = SED(config.sed_list[sed_idx], bandpasses).sed_intrinsic

    # Predicted equivalent widths for this SED, keyed by line name.
    predicted = line_model.predict(galsim_sed)
    for name in predicted:
        ew_lists[name].append(predicted[name])

    # Record the colors the jax net was fed for this SED.
    color_rows.append(line_model.jn_colors(galsim_sed))

# Freeze each per-line list into an array (dict order follows line_model.lines).
lines = {name: np.array(widths) for name, widths in ew_lists.items()}

# Transpose so colors[k] is the k-th color across all SEDs
# (assumes jn_colors returns a 1-D vector per SED -- TODO confirm).
colors = np.array(color_rows).T

In [None]:
fig, ax = plt.subplots(dpi=150)

# Equivalent width of every emission line against colors[1] (the second color
# fed to the jax net, labelled "UV color"), drawn in order of increasing color.
# `cs` stays bound at notebook level because the next cell tabulates it.
cs = colors[1]
order = np.argsort(cs)
cs_sorted = cs[order]  # same x values for every line, so sort once
for line_name, widths in lines.items():
    ax.plot(cs_sorted, widths[order], label=line_name.replace("_", " "))

ax.legend()
ax.set_xlabel("UV color")
ax.set_ylabel(r"Equivalent width ($\AA$)")
ax.set_ylim(0, 120)

# Thick grey marker at the color cut that the next cell also uses;
# zorder=0 keeps it behind the curves.
ax.axvline(0.125, c="silver", lw=5, zorder=0)

fig.savefig(fig_dir / "emission_lines_vs_color.pdf", bbox_inches="tight")

In [None]:
# Tabulate each SED's index in config.sed_list, its UV color, and whether that
# color lies below the cut drawn in the plotting cell.
# NOTE: both constants duplicate values hard-coded in other cells of this
# notebook and must be kept in sync by hand:
#   FIRST_SED_INDEX -- start of the range(...) over config.sed_list in the EW cell
#   UV_COLOR_CUT    -- x position of the axvline in the plotting cell
FIRST_SED_INDEX = 7
UV_COLOR_CUT = 0.125

for offset, uv_color in enumerate(cs):
    print(f"{offset + FIRST_SED_INDEX:<3} {uv_color:>6.3f}   {uv_color < UV_COLOR_CUT}")