Skip to content

Commit

Permalink
BUG: update 'fit_spectra' to use 'fobs_cents' instead of 'fobs' in li…
Browse files Browse the repository at this point in the history
…brary files. Also handle NaN fit values in PSD.
  • Loading branch information
lzkelley committed Apr 15, 2024
1 parent e0a8d27 commit d8ace3a
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions holodeck/librarian/fit_spectra.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
"""Script and methods to fit simulated GWB spectra with analytic functions.
Usage
-----
Usage (fit_spectra.py)
----------------------
For usage information, run the script with the ``-h`` or ``--help`` arguments, i.e.::
python -m holodeck.librarian.fit_spectra -h
Typically, the only argument required is the path to the folder containing the combined library
file (``sam_lib.hdf5``).
This script can be run serially (on a single processor) or in parallel. To run in parallel, MPI and
``mpi4py`` are required. For parallel runs, use::
mpirun -np <NUM_CORES> python -m holodeck.librarian.fit_spectra <ARGS>
Notes
-----
As a script, this submodule runs in parallel, with the main processor loading a holodeck library
Expand Down Expand Up @@ -88,9 +95,9 @@ def fit_library_spectra(library_path, log, recreate=False):
comm = MPI.COMM_WORLD
except Exception as err:
comm = None
holo.log.error(f"failed to load `mpi4py` in {__file__}: {err}")
holo.log.error("`mpi4py` may not be included in the standard `requirements.txt` file")
holo.log.error("Check if you have `mpi4py` installed, and if not, please install it")
log.error(f"failed to load `mpi4py` in {__file__}: {err}")
log.error("`mpi4py` may not be included in the standard `requirements.txt` file")
log.error("Check if you have `mpi4py` installed, and if not, please install it.")
raise err

# ---- setup path
Expand Down Expand Up @@ -124,7 +131,7 @@ def fit_library_spectra(library_path, log, recreate=False):
# ---- load library GWB and convert to PSD

with h5py.File(library_path, 'r') as library:
fobs = library['fobs'][()]
fobs = library['fobs_cents'][()]
psd = holo.utils.char_strain_to_psd(fobs[np.newaxis, :, np.newaxis], library['gwb'][()])

nsamps, nfreqs, nreals = psd.shape
Expand Down Expand Up @@ -203,7 +210,10 @@ def fit_library_spectra(library_path, log, recreate=False):
all_psd = all_psd[idx]

# confirm that the resorting worked correctly
assert np.all(all_psd == psd)
matches = (all_psd == psd)
# if values are NaN, then equality check will fail... skip those
skips = ~np.isfinite(all_psd)
assert np.all(matches | skips)

# reshape arrays to convert back to (Samples, Realizations, ...)
len_nbins_plaw = len(nbins_plaw)
Expand All @@ -219,7 +229,10 @@ def fit_library_spectra(library_path, log, recreate=False):
all_psd = np.moveaxis(all_psd, 1, -1)

# confirm that reshaping worked correctly
assert np.all(all_psd == psd_check)
matches = (all_psd == psd_check)
# if values are NaN, then equality check will fail... skip those
skips = ~np.isfinite(all_psd)
assert np.all(matches | skips)

# Report how many fits failed
fails = np.any(~np.isfinite(fits_plaw), axis=-1)
Expand Down

0 comments on commit d8ace3a

Please sign in to comment.