The purpose of this notebook is to inspect the domains of the X (physical
parameter) and Y (waveform component value) data, so that transformers can
be designed to improve the deep learning part of the offline phase.

In [None]:
import numpy as np

from romgw.config.env import PROJECT_ROOT
from romgw.typing.core import RealArray, ComplexArray, BBHSpinType, ModeType, ComponentType
from romgw.typing.utils import validate_literal
from romgw.waveform.dataset import ComponentWaveformDataset

In [None]:
# ----- Notebook parameters: which dataset slice to inspect -----
# These three values are used as path segments under PROJECT_ROOT / "data".
bbh_spin: BBHSpinType = "NS"
mode: ModeType = "2,2"
component: ComponentType = "phase"
# Plain string literal: nothing is interpolated, so no f-prefix is needed.
model_name = "NonLinearRegression"

In [None]:
# ----- Validate literals -----
# Each notebook parameter is passed through validate_literal together with its
# type alias, and the name is rebound to whatever the helper returns.
# NOTE(review): presumably validate_literal raises when the value is not one
# of the alias' allowed options -- confirm in romgw.typing.utils.
bbh_spin = validate_literal(bbh_spin, BBHSpinType)
mode = validate_literal(mode, ModeType)
component = validate_literal(component, ComponentType)

In [None]:
# ----- Training-data directory for the selected spin / mode / component -----
data_dir = (
    PROJECT_ROOT / "data" / bbh_spin / "train" / mode / component
)

In [None]:
def load_raw_data(
    bbh_spin: BBHSpinType,
    mode: ModeType,
    component: ComponentType,
) -> tuple[RealArray, RealArray]:
    """Load the raw X and Y arrays for one training waveform component.

    The data is read from
    ``PROJECT_ROOT / "data" / bbh_spin / "train" / mode / component``:

    * the waveform dataset comes from the ``raw`` sub-directory;
    * the empirical time nodes come from
      ``empirical_interpolation/empirical_time_nodes.npy``.

    Parameters
    ----------
    bbh_spin, mode, component
        Selectors for the dataset. Each one is passed through
        ``validate_literal`` before it is used as a path segment.

    Returns
    -------
    tuple
        ``(X_raw, Y_raw)`` where ``X_raw`` is the dataset's ``params_array``
        and ``Y_raw`` is the dataset's ``array`` restricted, along its second
        axis, to the entries selected by the empirical time nodes.

    Notes
    -----
    NOTE(review): the nodes are used directly as an index, so the ``.npy``
    file presumably stores integer indices -- confirm against the code that
    writes it. The exception raised for an invalid selector is whatever
    ``validate_literal`` raises; it is not visible from here.
    """
    # Reject unsupported selector values before any filesystem access.
    bbh_spin = validate_literal(bbh_spin, BBHSpinType)
    mode = validate_literal(mode, ModeType)
    component = validate_literal(component, ComponentType)

    # Everything below is read relative to this directory.
    component_root = (
        PROJECT_ROOT / "data" / bbh_spin / "train" / mode / component
    )

    # Waveform dataset stored under "raw".
    dataset = ComponentWaveformDataset.from_directory(
        component_root / "raw",
        component=component,
    )

    # Empirical time nodes; pickled payloads are refused on load.
    nodes_path = (
        component_root / "empirical_interpolation" / "empirical_time_nodes.npy"
    )
    node_indices = np.load(nodes_path, allow_pickle=False)

    # X: parameter array; Y: waveform values at the empirical time nodes.
    return dataset.params_array, dataset.array[:, node_indices]