In [2]:
#| default_exp k_select

In [6]:
#| export
from __future__ import annotations
import jax, jax.numpy as jnp
from jax import lax

In [8]:
#| export
#chewc/k_select.py
def select_topk(scores: jnp.ndarray, k: int) -> jnp.ndarray:
    """Positions of the ``k`` highest entries in ``scores``.

    Args:
        scores: 1-D array of selection scores, shape (N,).
        k: how many indices to return; forwarded unchanged to ``lax.top_k``.

    Returns:
        Integer index array of shape (k,). The matching score values computed
        by ``lax.top_k`` are discarded.
    """
    return lax.top_k(scores, k)[1]

def score_from_bv_pheno(
    bv: jnp.ndarray,    # (N, T) or None
    pheno: jnp.ndarray, # (N, T) or None
    is_active: jnp.ndarray,  # (N,)
    use="pheno",        # "pheno" or "bv"
    trait_idx=0,
) -> jnp.ndarray:
    """Build selection scores with masking of inactive as -inf.

    Args:
        bv: breeding values, shape (N, T), or None.
        pheno: phenotypes, shape (N, T), or None.
        is_active: boolean mask of shape (N,). Positions where it is False
            receive a score of -inf, so they rank below every active one.
        use: "bv" to score on breeding values, "pheno" to score on phenotypes.
            If "bv" is requested but ``bv`` is None, phenotypes are used instead.
        trait_idx: column (trait) of the chosen source to score on.

    Returns:
        Array of shape (N,). Integer or bool scores are promoted to float32,
        because -inf is only representable in inexact dtypes.

    Raises:
        ValueError: if ``use`` is not "pheno" or "bv", or if the source that
            would be scored on (``pheno``) is None.
    """
    if use not in ("pheno", "bv"):
        raise ValueError(f"use must be 'pheno' or 'bv', got {use!r}")
    # Deliberate fallback: score on phenotypes when breeding values are unavailable.
    source = bv if (use == "bv" and bv is not None) else pheno
    if source is None:
        raise ValueError(
            f"use={use!r} needs pheno (bv unavailable or not requested), but pheno is None"
        )
    raw = source[:, trait_idx]
    # Casting -inf to an integer dtype fails, and casting it to bool yields True.
    # Promote such scores to float so the mask sentinel is a true -inf.
    if not jnp.issubdtype(raw.dtype, jnp.inexact):
        raw = raw.astype(jnp.float32)
    # mask inactives
    neg_inf = jnp.array(-jnp.inf, raw.dtype)
    return jnp.where(is_active, raw, neg_inf)


In [10]:
x = jnp.array([0.1, 0.9, -1.0, 0.5])
mask = jnp.array([1, 1, 0, 1], dtype=bool)
scores = score_from_bv_pheno(bv=None, pheno=x[:, None], is_active=mask, use="pheno")
# The inactive individual (index 2) must carry the -inf sentinel.
assert bool(jnp.isneginf(scores[2]))
# The top-2 scores among the active individuals [0.1, 0.9, 0.5] are at indices 1 and 3.
assert set(select_topk(scores, 2).tolist()) == {1, 3}

# Index 2 already has the lowest raw score, so the check above would pass
# even if masking were ignored. Deactivate the best raw scorer (index 1)
# instead: masking must override its raw score.
mask_best_off = jnp.array([1, 0, 1, 1], dtype=bool)
scores_best_off = score_from_bv_pheno(
    bv=None, pheno=x[:, None], is_active=mask_best_off, use="pheno"
)
assert bool(jnp.isneginf(scores_best_off[1]))
assert set(select_topk(scores_best_off, 2).tolist()) == {0, 3}


In [5]:
#| hide
# Run nbdev's export step (nbdev_export) for this project.
import nbdev

nbdev.nbdev_export()