In [None]:
import numpy as np
import os
import h5py

import fitsio
import joblib
from smatch.matcher import Matcher

In [None]:
# Directory holding the Balrog data-release HDF5 catalogs.
pth = "/global/cfs/cdirs/des/y6-balrog/DataRelease/"
# Injection (truth) catalog, read first by _extract_data_tile.
truth_fname = "fiducial_injected_sof.hdf5"
# Measured maglim catalog that is matched back onto the injections.
maglim_fname = "fiducial_matched_measured_maglim.hdf5"


def _extract_data_tile(tile_name, skip_inj=False, skip_maglim=False):
    """Read the injection and/or maglim rows of one tile and combine them.

    Parameters
    ----------
    tile_name : str
        Name of the tile to select. It is UTF-8 encoded and compared with the
        ``wide_tilename`` (injections) and ``sof_wide_tilename`` (maglim)
        columns.
    skip_inj : bool, optional
        If True, the injection (truth) file is not read.
    skip_maglim : bool, optional
        If True, the maglim file is not read.

    Returns
    -------
    dict or None
        Mapping of column name to array.

        * Both catalogs read: the injection columns, plus every maglim-only
          column expanded to one row per injection. Rows without a
          zero-distance maglim match hold NaN for floating-point columns and
          -1 for all other columns.
        * Only one catalog read: that catalog's columns.
        * Both skipped: None.

    Raises
    ------
    AssertionError
        If the match output has an unexpected shape, if the number of
        zero-distance matches differs from the number of selected maglim
        rows, or if a column disagrees between the two catalogs on the
        matched rows.
    """
    d_inj = None
    d_maglim = None
    tile_key = tile_name.encode("utf-8")

    if not skip_inj:
        # read the truth columns (explicit read-only mode so the release file
        # can never be opened for writing, whatever the h5py default is)
        with h5py.File(os.path.join(pth, truth_fname), "r") as fp:
            msk = (fp["wide_tilename"][:] == tile_key)
            msk &= (fp["flags_bad_zp"][:] == 0)
            msk &= (fp["bdf_flags"][:] == 0)
            msk &= (fp["flags"][:] == 0)

            d_inj = dict(
                id=fp["id"][msk],
                ra=fp["ra"][msk],
                dec=fp["dec"][msk],
                bdf_flux_deredden=fp["bdf_flux_deredden"][msk],
                bdf_flux_err=fp["bdf_flux_err"][msk],
                wide_tilename=fp["wide_tilename"][msk],
            )

    if not skip_maglim:
        # read the maglim/sof catalogs w/ cuts
        with h5py.File(os.path.join(pth, maglim_fname), "r") as fp:
            msk = (fp["sof_wide_tilename"][:] == tile_key)
            msk &= (fp["sof_meas_PASS_GOLD_FLAGS"][:] == 1)
            msk &= (fp["sof_flags_bad_zp"][:] == 0)
            msk &= (fp["sof_meas_EXT_MASH"][:] >= 2)
            msk &= (fp["sof_meas_flags"][:] == 0)
            msk &= (fp["sof_meas_bdf_flags"][:] == 0)

            d_maglim = dict(
                id=fp["sof_id"][msk],
                ra=fp["sof_ra"][msk],
                dec=fp["sof_dec"][msk],
                meas_ra=fp["sof_meas_ra"][msk],
                meas_dec=fp["sof_meas_dec"][msk],
                meas_flux_mag_deredden=fp["sof_meas_bdf_flux_deredden"][msk],
                meas_bdf_flux_err=fp["sof_meas_bdf_flux_err"][msk],
                wide_tilename=fp["sof_wide_tilename"][msk],
                tomo_bin=fp["tomo_bin"][msk],
                DNF_Z=fp["DNF_Z"][msk],
            )

    # only one (or neither) catalog requested: nothing to match
    if d_maglim is None or d_inj is None:
        return d_inj if d_maglim is None else d_maglim

    # match to the truth to build the combined catalog
    mtch = Matcher(d_maglim["ra"], d_maglim["dec"])
    idx_into_maglim, d_to_maglim = mtch.query_knn(
        d_inj["ra"],
        d_inj["dec"],
        return_distances=True,
    )
    assert idx_into_maglim.shape == d_inj["ra"].shape

    # An injection counts as matched only at exactly zero separation.
    # NOTE(review): this presumes the maglim ``sof_ra``/``sof_dec`` values are
    # bit-identical copies of the injection positions - confirm.
    msk = d_to_maglim == 0
    # every selected maglim row must be claimed by an injection
    assert int(np.sum(msk)) == d_maglim["ra"].shape[0]

    n_inj = d_inj["ra"].shape[0]
    for k in d_maglim:
        if k not in d_inj:
            v_maglim = d_maglim[k]

            # one row per injection, keeping any trailing (e.g. per-band) axes
            shape = (n_inj,) + tuple(v_maglim.shape[1:])

            # Sentinel for unmatched injections: NaN for floating-point
            # columns, -1 otherwise. ``dtype.kind`` is the reliable test;
            # ``"f" in dtype.descr`` is always False because ``descr`` is a
            # list of tuples, which silently filled float columns with -1.
            fill = np.nan if v_maglim.dtype.kind == "f" else -1
            v_inj = np.full(shape, fill, dtype=v_maglim.dtype)

            v_inj[msk] = v_maglim[idx_into_maglim[msk]]

            d_inj[k] = v_inj

        # columns present in both catalogs (and the ones just filled) must
        # agree on the matched rows
        assert np.array_equal(
            d_inj[k][msk],
            d_maglim[k][idx_into_maglim[msk]]
        )

    return d_inj


def _concat_dicts(dicts):
    d_ret = {}
    for k in list(dicts[0].keys()):
        d_ret[k] = np.concatenate(
            [d[k] for d in dicts],
            axis=0,
        )
        for d in dicts:
            del d[k]
    return d_ret

In [None]:
# Seeded generator so the same subset of tiles is drawn on every run.
rng = np.random.default_rng(seed=42)

# Tile names are read from a local text file; 10 are drawn without replacement.
tnames = np.loadtxt("tnames.txt", dtype="U")
tnames_to_process = rng.choice(tnames, size=10, replace=False)

# One delayed extraction call per selected tile.
jobs = list(map(joblib.delayed(_extract_data_tile), tnames_to_process))

# Run the extractions in two loky workers and merge the per-tile results.
with joblib.Parallel(n_jobs=2, backend="loky", verbose=100) as exc:
    d = _concat_dicts(exc(jobs))

In [None]:
# Write each column of the combined catalog as its own HDF5 dataset.
with h5py.File("desy6_balrog_maglim_matched_v1_10.h5", "w") as fp:
    for k, col in d.items():
        fp.create_dataset(k, data=col)