### First import everything! ###

In [None]:
import numpy as np
import xarray as xr
import re

### Take the wavelengths ###

In [17]:
def parse_wavelengths(lines):
    """
    Read the wavelength grid found at the top of a file.

    The first line holds the number of wavelengths; the following lines are
    consumed whole, one after another, until at least that many values have
    been collected.

    Returns:
        tuple: (wavelength values as a numpy array, index of the first line
        that was not consumed).
    """
    expected_count = int(lines[0].strip())
    collected = []
    cursor = 1
    while len(collected) < expected_count:
        # split() with no argument already ignores surrounding whitespace.
        row = [float(token) for token in lines[cursor].split()]
        collected.extend(row)
        cursor += 1
    return np.array(collected), cursor

### Take the temperature ###

In [18]:
def parse_temperature_array(lines, start_index):
    """
    Extracts the temperature values listed before the flux blocks.

    Starting at ``start_index``, lines that contain no digit are skipped.
    Then every consecutive line made up only of numeric characters
    (digits, spaces, ``.``, ``e``/``E``, ``+``, ``-``) is read as
    whitespace-separated floats.

    Both scans stop at the end of ``lines``, so a file that ends during or
    right after the temperature list no longer raises IndexError.

    Input:
        lines (list of str): The lines of the file.
        start_index (int): Index of the line at which to start scanning.

    Returns:
        tuple: (temperature values as a numpy array, index of the first line
        that was not consumed). The array is empty when no numeric line is
        found.
    """
    has_digit = re.compile(r'\d')
    numeric_line = re.compile(r'^([ \d.eE+-]+)$')
    total = len(lines)

    i = start_index
    # Skip blank / text-only lines that precede the temperature list.
    while i < total and not has_digit.search(lines[i]):
        i += 1

    temperature_values = []
    while i < total:
        stripped = lines[i].strip()
        if not numeric_line.search(stripped):
            break
        temperature_values.extend(float(x) for x in stripped.split())
        i += 1
    return np.array(temperature_values), i


### Take the flux blocks ###

In [None]:
def parse_flux_blocks(lines, start_index, num_wavelengths):
    """
    Extracts the lengthy flux blocks from the file alongside their assigned temperature.

    Each block starts with a line containing ``Effective temperature = <value>``
    and is followed by ``num_wavelengths`` flux values, written as
    whitespace-separated floats spread over any number of lines. Blank lines
    inside a block are skipped. Lines outside a block are ignored.

    Input:
        lines (list of str): The lines of the file.
        start_index (int): Index of the line at which to start scanning.
        num_wavelengths (int): Number of flux values expected in every block.

    Returns:
        tuple: (temperatures, flux_array) where ``temperatures`` is a 1-D
        numpy array with one entry per block and ``flux_array`` has shape
        (num_wavelengths, number of blocks).

    Raises:
        ValueError: If a temperature cannot be parsed from a header line, if
            a block does not hold exactly ``num_wavelengths`` values, or if
            no block is found at all.
    """
    # Accept plain decimals ("12000", "12000.", ".5") and exponent notation
    # ("1.2E+04"); a bare [\d.]+ would silently truncate the latter to "1.2".
    temperature_pattern = re.compile(
        r"Effective temperature\s*=\s*((?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
    )
    total = len(lines)
    i = start_index
    temperatures = []
    flux_columns = []

    while i < total:
        line = lines[i]
        i += 1
        if "Effective temperature" not in line:
            continue

        temp_match = temperature_pattern.search(line)
        if not temp_match:
            raise ValueError(f"Could not parse temperature from line: {line}")
        temp = float(temp_match.group(1))

        # Read flux values for this temperature block.
        flux_block = []
        while i < total and len(flux_block) < num_wavelengths:
            current_line = lines[i].strip()
            if current_line:
                try:
                    values = [float(x) for x in current_line.split()]
                except ValueError:
                    # Not numeric flux data: leave the line for the outer
                    # loop, since it may be the next block's header.
                    break
                flux_block.extend(values)
            i += 1

        if len(flux_block) != num_wavelengths:
            raise ValueError(f"Expected {num_wavelengths} flux values, got {len(flux_block)} for temperature {temp}")

        # Record the temperature only once its block has been validated.
        temperatures.append(temp)
        flux_columns.append(flux_block)

    if not flux_columns:
        raise ValueError("No 'Effective temperature' blocks found in the input lines")

    flux_array = np.column_stack(flux_columns)  # shape (num_wavelengths, num_temperatures)
    return np.array(temperatures), flux_array



### Parse a complete file ###

In [27]:
def parse_flux_file(filepath):
    """
    Parse one flux file from disk.

    The file is read in full, then handed to the three section parsers in
    order: wavelengths, the temperature list, and the flux blocks.

    Returns:
        tuple: (wavelengths, temperatures taken from the flux-block headers,
        flux array).
    """
    with open(filepath, 'r') as handle:
        lines = list(handle)

    wavelengths, cursor = parse_wavelengths(lines)
    # The temperature list is parsed only to advance the cursor past it; the
    # returned temperatures are the ones read from the flux-block headers.
    _listed_temperatures, cursor = parse_temperature_array(lines, cursor)
    block_result = parse_flux_blocks(lines, cursor, len(wavelengths))

    return (wavelengths, *block_result)


### Stack the per-gravity fluxes into a 3D cube ###

In [28]:
def load_flux_cube(file_gravity_pairs):
    """
    Loads a cube of flux data across multiple gravities.

    Every file is parsed with ``parse_flux_file`` and the resulting 2-D flux
    arrays are stacked along a new leading "gravity" axis.

    Input:
        file_gravity_pairs (iterable of (str, number)): Pairs of file path
            and the gravity value used to label that file's slice.

    Returns:
        xr.DataArray: Flux with dims ("gravity", "wavelength", "temperature").
        The wavelength and temperature coordinates are taken from the last
        file parsed.

    Raises:
        ValueError: If ``file_gravity_pairs`` is empty, or if the per-file
            flux arrays do not all have the same shape.

    Warns:
        UserWarning: If a file's wavelength or temperature grid differs from
            the previous file's, because the cube's coordinates would then
            not describe every slice.
    """
    import warnings

    flux_slices = []
    gravities = []
    wavelengths = temps = None
    previous_grid = None

    for filename, gravity in file_gravity_pairs:
        print(f"Loading: {filename}")
        wavelengths, temps, file_flux = parse_flux_file(filename)
        if previous_grid is not None and not (
            np.array_equal(wavelengths, previous_grid[0])
            and np.array_equal(temps, previous_grid[1])
        ):
            warnings.warn(
                f"{filename}: wavelength/temperature grid differs from the "
                "previous file; cube coordinates come from the last file only"
            )
        previous_grid = (wavelengths, temps)
        flux_slices.append(file_flux)
        gravities.append(gravity)

    if not flux_slices:
        raise ValueError("file_gravity_pairs is empty; there is nothing to load")

    # np.stack adds the new leading gravity axis.
    stacked = np.stack(flux_slices, axis=0)

    cube = xr.DataArray(
        stacked,
        dims=["gravity", "wavelength", "temperature"],
        coords={
            "gravity": gravities,
            "wavelength": wavelengths,
            "temperature": temps,
        },
        name="flux",
        attrs={"units": "erg cm^-2 s^-1 Hz^-1"},
    )
    # "\u00c5" is the angstrom symbol; the literal had been mis-encoded.
    cube.coords["wavelength"].attrs["units"] = "\u00c5"        # or "nm", etc.
    cube.coords["temperature"].attrs["units"] = "K"
    # NOTE(review): the file lists in this notebook pass gravity values such
    # as 500-900, which look like 100 * log(g) -- confirm this label.
    cube.coords["gravity"].attrs["units"] = "log(g) [cm s^-2]"
    return cube


### File Lists ###

In [41]:
# Each entry is (path, gravity code); the code appears in the file name and is
# also used as the gravity coordinate value.
# NOTE(review): codes such as 500 presumably mean log g = 5.00 -- confirm.
files_below_7_no_H2 = [
    (f"grid_3d_new_v3.tar/{code}_3D", code) for code in range(500, 700, 50)
]

files_7_and_above_no_H2 = [
    (f"grid_3d_new_v3.tar/{code}_3D", code) for code in range(700, 950, 50)
]

files_H2 = [
    (f"grid_3d_h2lines.tar/{code}_3D_H2lines", code) for code in range(700, 950, 50)
]
# Build one flux cube per file group. Each call parses every file in the group
# and prints a "Loading: <path>" line for it.
flux_cube_below_7_no_H2 = load_flux_cube(files_below_7_no_H2)
flux_cube_7_and_above_no_H2 = load_flux_cube(files_7_and_above_no_H2)

# Same gravity codes as the 7-and-above group, read from the H2-lines grid.
flux_cube_H2 = load_flux_cube(files_H2)

Loading: grid_3d_new_v3.tar/500_3D
Loading: grid_3d_new_v3.tar/550_3D
Loading: grid_3d_new_v3.tar/600_3D
Loading: grid_3d_new_v3.tar/650_3D
Loading: grid_3d_new_v3.tar/700_3D
Loading: grid_3d_new_v3.tar/750_3D
Loading: grid_3d_new_v3.tar/800_3D
Loading: grid_3d_new_v3.tar/850_3D
Loading: grid_3d_new_v3.tar/900_3D
Loading: grid_3d_h2lines.tar/700_3D_H2lines
Loading: grid_3d_h2lines.tar/750_3D_H2lines
Loading: grid_3d_h2lines.tar/800_3D_H2lines
Loading: grid_3d_h2lines.tar/850_3D_H2lines
Loading: grid_3d_h2lines.tar/900_3D_H2lines


In [None]:
# cut_dataset = dataset.sel(dimension=slice(bottom, top))


In [None]:
# Access flux at specific gravity, wavelength, temperature.
# Template cell: `flux_cube` is not defined in the cells shown here -- substitute
# one of the loaded cubes (e.g. flux_cube_H2) and replace each None with the
# value to look up. method="nearest" picks the closest coordinate on each axis.
flux_cube.sel(gravity=None, temperature=None, wavelength=None, method="nearest")

In [None]:
# Plot the spectrum (flux vs. wavelength) at a given gravity and temperature.
# Template cell: `flux_cube` is not defined in the cells shown here -- substitute
# one of the loaded cubes and replace each None with the value to select.
flux_cube.sel(gravity=None, temperature=None, method="nearest").plot()

### Conversion to NetCDF (HDF5-backed .h5 files) ###

In [None]:
def save_to_netcdf_h5netcdf(dataset, output_path):
    """
    Write an xarray object to disk as NetCDF through the h5netcdf engine.

    Input:
        dataset (xr.Dataset or xr.DataArray): Object whose ``to_netcdf``
            method is called.
        output_path (str): Destination file path.

    Returns:
        str: ``output_path``, unchanged, so the call can be chained or echoed.
    """
    engine_name = "h5netcdf"
    dataset.to_netcdf(output_path, engine=engine_name)
    return output_path

In [45]:
# Write each cube to its own file. As the last expression in the cell, only the
# final call's return value (the output path) is echoed below.
save_to_netcdf_h5netcdf(flux_cube_below_7_no_H2, "flux_model_3D_A.h5")
save_to_netcdf_h5netcdf(flux_cube_7_and_above_no_H2, "flux_model_3D_B.h5")
save_to_netcdf_h5netcdf(flux_cube_H2, "flux_model_3D_H2.h5")

'flux_model_3D_H2.h5'