In [1]:
from astropy.table import Table 
from pathlib import Path

In [2]:
FITS_FILE = Path("/Users/vinland/extinction/jwst_gc_ref_table.fits")
catalog = Table.read(FITS_FILE)

In [3]:
cutoff1 = [(8, 23.9), (6.2, 21.2)] 
cutoff2 = [(8, 23.5), (6.2, 20.8)]

In [10]:
import pickle
import numpy as np 
import pandas as pd 

from pathlib import Path 
from typing import Tuple, Optional
from astropy.table import Table


Cutoffs = Tuple[Tuple[float, float], Tuple[float, float]]

class GenerateRedClump: 
    """ 
    gets red clump stars in all wavelengths from a catalog 
    uses the stars in F115W - F212N vs. F115W CMDs as a reference
    """

    def __init__( 
        self,
        cutoff1     : Cutoffs, 
        cutoff2     : Cutoffs, 
        det         : str, 
        out_dir     : Path,
        *, 
        clr_range   : Optional[Tuple[float, float]] = None, 
        expand_factor: float = 3.0, 
    ):
        self.det       = det
        self.cutoff1   = cutoff1 
        self.cutoff2   = cutoff2 
        self.clr_range = clr_range 
        self.expand    = expand_factor
        self.out_dir   = out_dir 

        self.catalog = catalog

        self.i_f115w = _epoch_idx_first(self.catalog, "F115W", det)
        self.i_f212n = _epoch_idx_first(self.catalog, "F212N", det) 
        self.i_f140m = _try_epoch(self.catalog, "F140M", det)
        
        # F182M not in NRCB1 
        self.i_f182m = None if det == "NRCB1" else _try_epoch(self.catalog, "F182M", det)

        # LW filters 
        self.i_f323n = _try_epoch(self.catalog, "F323N", "NRCB5")
        self.i_f405n = _try_epoch(self.catalog, "F405N", "NRCB5")

    def _mask_cmd(self): 
        m_vega = self.catalog['m_vega']
        me_vega = self.catalog['me_vega']

        mF115 = np.ma.MaskedArray(m_vega[:, self.i_f115w], copy=False)
        mF212 = np.ma.MaskedArray(m_vega[:, self.i_f212n], copy=False)
        meF115 = np.ma.MaskedArray(me_vega[:, self.i_f115w], copy=False)
        meF212 = np.ma.MaskedArray(me_vega[:, self.i_f212n], copy=False)

        good = (
                (~mF115.mask) & (~mF212.mask) &
                (~meF115.mask) & (~meF212.mask)
        )

        color = mF115 - mF212
        ymag  = mF115

        # parallel lines to slope, intercepts
        (x1,y1),(x2,y2) = self.cutoff1
        (_, _), (x4,y4) = self.cutoff2
        slope = (y2 - y1) / (x2 - x1)
        b1 = y2 - slope * x2
        b2 = y4 - slope * x4
        height = abs(b1 - b2)
        upper_b = max(b1, b2) + self.expand * height
        lower_b = min(b1, b2) - self.expand * height

        if not self.clr_range: 
            self.clr_range = (color.min(), color.max())

        rc_mask = (
            good &
            (ymag  <= (slope * color + upper_b)) &
            (ymag  >= (slope * color + lower_b)) &
            (color >= self.clr_range[0]) &
            (color <= self.clr_range[1])
        )
        return rc_mask, mF115, mF212

    def extract(self):
        """
        extracts the red clump stars from 
        F115W - F212N vs. F115W CMDs 

        saves a dataframe to `out_dir` with np.ndarrays
        of the idx, x, y, mF115W/.../F405N magnitudes 
        and their errors. 
        """

        rc_mask, mF115, mF212 = self._mask_cmd()
        idx = np.flatnonzero(rc_mask)

        # Base outputs (positions from F115W epoch)
        x = np.asarray(self.catalog['x'][idx, self.i_f115w], dtype=float)
        y = np.asarray(self.catalog['y'][idx, self.i_f115w], dtype=float)

        out = {
            'idx' : idx.astype(int),
            'x' : x, 
            'y' : y,
            'mF115W'  : np.asarray(mF115[idx], dtype=float),
            'mF212N'  : np.asarray(mF212[idx], dtype=float),
            'meF115W' : np.asarray(self.catalog['me_vega'][idx, self.i_f115w], dtype=float),
            'meF212N' : np.asarray(self.catalog['me_vega'][idx, self.i_f212n], dtype=float),
        }

        # wavelength lifts
        def _lift(name: str, i: Optional[int]):
            if i is None: return
            out[f'm{name}']  = np.asarray(self.catalog['m_vega'][idx, i], dtype=float)
            out[f'me{name}'] = np.asarray(self.catalog['me_vega'][idx, i], dtype=float)

        _lift('F140M', self.i_f140m)
        _lift('F182M', self.i_f182m)   
        _lift('F323N', self.i_f323n)   
        _lift('F405N', self.i_f405n)  

        df = pd.DataFrame(out)

        self.out_dir.mkdir(parents=True, exist_ok=True)
        with open(self.out_dir / f'{self.det}_red_clump_stars.pickle', 'wb') as f:
            pickle.dump(df, f)


def _try_epoch(catalog: Table, filt: str, det: str) -> Optional[int]:
    try:
        return _epoch_idx_first(catalog, filt, det)
    except Exception:
        return None

In [11]:
import numpy as np
from astropy.table import Table

# predicted `det` mapping
_DET_FALLBACK = np.array(
    [
        "NRCB1","NRCB2","NRCB3","NRCB4", 
        "NRCB1","NRCB2","NRCB3","NRCB4",
        "NRCB1","NRCB2","NRCB3","NRCB4","NRCB5","NRCB5",
        "NRCB1","NRCB2","NRCB3","NRCB4","NRCB5",
        "NRCB1","NRCB2","NRCB3","NRCB4","NRCB5",
        "NRCB1","NRCB2","NRCB3","NRCB4",
        "NRCB1","NRCB2","NRCB3","NRCB4",
        "NRCB1","NRCB2","NRCB3","NRCB4","NRCB5"
    ], 
    dtype="U10"
)


def load_catalog(path: str): 
    return Table.read(path)


def _s(x): 
    # decode bytes to str 
    return x.decode("utf-8") if isinstance(x, (bytes, np.bytes_)) else str(x)


def _epoch_meta(catalog):
    n_epochs = catalog["filt"].shape[1]
    filt_row = np.empty(n_epochs, dtype="U10") 

    # get all filters 
    for i in range(n_epochs): 
        # unmasked values
        good = np.where(~catalog['x'][:, i].mask)[0]
        if good.size > 0: 
            filt_row[i] = _s(catalog["filt"][good[0], i]) 
        else: 
            filt_row[i] = "None" 

    # get all detectors 
    if "det" in catalog.colnames: 
        det_row = np.empty(n_epochs, dtype="U10")

        for j in range(n_epochs): 
            # unmasked values 
            good = np.where(~catalog['x'][:, j].mask)[0]
            if good.size > 0: 
                det_row[j] = _s(catalog["det"][good[0], j]) 
            else: 
                det_row[j] = "None" 
    else: 
        det_row = _DET_FALLBACK 
        if det_row.size != n_epochs: 
            raise ValueError(
                "detector fallback length mismatch: "
                f"{det_row.size} != n_epochs {n_epochs}"
            )

    return filt_row, det_row

def _epoch_idx_first(catalog, filt: str, det: str) -> int: 
    """
    return the first epoch index for (filt, det)
    """

    n_epochs = catalog["filt"].shape[1]
    filt_row = np.empty(n_epochs, dtype="U10")

    for i in range(n_epochs):
        xcol_i = catalog['x'][:, i]  
        good   = np.where(~xcol_i.mask)[0]

        filt_row[i] = _s(catalog["filt"][good[0], i]) if good.size else "None"

    if "det" in catalog.colnames:
        det_row = np.empty(n_epochs, dtype="U10")
        for j in range(n_epochs):
            xcol_j = catalog['x'][:, j]
            good   = np.where(~xcol_j.mask)[0]

            det_row[j] = _s(catalog["det"][good[0], j]) if good.size else "None"
    else:
        det_row = _DET_FALLBACK
        if det_row.size != n_epochs:
            raise ValueError(
                "detector fallback length mismatch: "
                f"{det_row.size} != n_epochs {n_epochs}"
            )

    idx = np.where((filt_row == _s(filt)) & (det_row == _s(det)))[0]
    if idx.size == 0:
        raise ValueError(f"No epoch found for ({filt}, {det})")
    return int(idx[0])  


def get_matches(catalog, filt1: str, det1: str, filt2: str, det2: str): 
    filt_row, det_row = _epoch_meta(catalog) 

    # first epoch 
    idx1_arr = np.where((filt_row == filt1) & (det_row == det1))[0]
    idx2_arr = np.where((filt_row == filt2) & (det_row == det2))[0]
    if idx1_arr.size == 0 or idx2_arr.size == 0:
        raise ValueError(
            "No epoch found for either " 
            f"({filt1}, {det1}) or ({filt2}, {det2})"
        )
    idx1 = int(idx1_arr[0])
    idx2 = int(idx2_arr[0])


    m1  = np.ma.MaskedArray(catalog["m_vega"][:, idx1],  copy=False)
    m2  = np.ma.MaskedArray(catalog["m_vega"][:, idx2],  copy=False)
    me1 = np.ma.MaskedArray(catalog["me_vega"][:, idx1], copy=False)
    me2 = np.ma.MaskedArray(catalog["me_vega"][:, idx2], copy=False)

    is_finite = ( 
        (~m1.mask) & (~m2.mask) & (~me1.mask) & (~me2.mask) & 
        np.isfinite(m1.data)  & np.isfinite(m2.data) & 
        np.isfinite(me1.data) & np.isfinite(me2.data) 
    )

    return m1[is_finite], m2[is_finite], me1[is_finite], me2[is_finite]


In [15]:
rc = GenerateRedClump(
    det="NRCB1",
    cutoff1=cutoff1,
    cutoff2=cutoff2,
    out_dir=Path.home() / "extinction/",
    expand_factor=3.0,  
)

df = rc.extract()

TypeError: extract() got an unexpected keyword argument 'export_basename'