Skip to content

Commit

Permalink
API refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
mlin committed Jun 27, 2024
1 parent ee6c184 commit a1e4daa
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
import tiledb.vector_search as vs
import tiledbsoma as soma

from .._experiment import _get_experiment_name
from .._open import DEFAULT_TILEDB_CONFIGURATION, open_soma
from .._release_directory import CensusMirror, _get_census_mirrors
from .._util import _uri_join
from ._embedding import get_embedding_metadata_by_name


class NeighborObs(NamedTuple):
Expand All @@ -32,20 +34,23 @@ class NeighborObs(NamedTuple):


def find_nearest_obs(
embedding_metadata: Dict[str, Any],
embedding_name: str,
organism: str,
census_version: str,
query: ad.AnnData,
*,
k: int = 10,
nprobe: int = 100,
memory_GiB: int = 4,
mirror: Optional[str] = None,
embedding_metadata: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> NeighborObs:
"""Search Census for similar obs (cells) based on nearest neighbors in embedding space.
Args:
embedding_metadata:
Information about the embedding to search, as found by
:func:`get_embedding_metadata_by_name`.
embedding_name, organism, census_version:
Identify the embedding to search, as in :func:`get_embedding_metadata_by_name`.
query:
AnnData object with an obsm layer embedding the query cells. The obsm layer name
matches ``embedding_metadata["embedding_name"]`` (e.g. scvi, geneformer). The layer
Expand All @@ -57,8 +62,15 @@ def find_nearest_obs(
cells) for a thorough search. Decrease for faster but less accurate search.
memory_GiB:
Memory budget for the search index, in gibibytes; defaults to 4 GiB.
mirror:
Name of the Census mirror to use for the search.
embedding_metadata:
The result of `get_embedding_metadata_by_name(embedding_name, organism, census_version)`.
Supplying this saves a network request for repeated searches.
"""
embedding_name = embedding_metadata["embedding_name"]
if embedding_metadata is None:
embedding_metadata = get_embedding_metadata_by_name(embedding_name, organism, census_version)
assert embedding_metadata["embedding_name"] == embedding_name
n_features = embedding_metadata["n_features"]

# validate query (expected obsm layer exists with the expected dimensionality)
Expand Down Expand Up @@ -100,17 +112,17 @@ def _resolve_embedding_index(


def predict_obs_metadata(
embedding_metadata: Dict[str, Any],
organism: str,
census_version: str,
neighbors: NeighborObs,
column_names: Sequence[str],
experiment: Optional[soma.Experiment] = None,
) -> pd.DataFrame:
"""Predict obs metadata attributes for the query cells based on the embedding nearest neighbors.
Args:
embedding_metadata:
Information about the embedding searched, as found by
:func:`get_embedding_metadata_by_name`.
organism, census_version:
Embedding information as supplied to :func:`find_nearest_obs`.
neighbors:
Results of a :func:`find_nearest_obs` search.
column_names:
Expand All @@ -129,8 +141,8 @@ def predict_obs_metadata(
with ExitStack() as cleanup:
if experiment is None:
# open Census transiently
census = cleanup.enter_context(open_soma(census_version=embedding_metadata["census_version"]))
experiment = census["census_data"][embedding_metadata["experiment_name"]]
census = cleanup.enter_context(open_soma(census_version=census_version))
experiment = census["census_data"][_get_experiment_name(organism)]

# fetch the desired obs metadata for all of the found neighbors
neighbor_obs = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from cellxgene_census.experimental import (
NeighborObs,
find_nearest_obs,
get_embedding_metadata_by_name,
predict_obs_metadata,
)

Expand Down Expand Up @@ -42,30 +41,38 @@ def test_embeddings_search(true_neighbors: Dict[str, Any], query_result: Neighbo

@pytest.mark.experimental
@pytest.mark.live_corpus
def test_embeddings_search_errors(emb_metadata: Dict[str, Any], query_anndata: ad.AnnData) -> None:
# no index for the embedding
emb_metadata2 = emb_metadata.copy()
emb_metadata2["indexes"] = []
with pytest.raises(ValueError, match="No suitable embedding index"):
find_nearest_obs(emb_metadata2, query_anndata)
def test_embeddings_search_errors(query_anndata: ad.AnnData) -> None:
# bogus embedding name
with pytest.raises(ValueError, match="No embeddings found"):
find_nearest_obs(
"bogus123", TRUE_NEAREST_NEIGHBORS_ORGANISM, TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION, query_anndata
)
# query anndata missing the embedding layer
bogus_ad = query_anndata.copy()
bogus_ad.obsm.pop(TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME)
with pytest.raises(ValueError, match="Query does not have"):
find_nearest_obs(emb_metadata, bogus_ad)
find_nearest_obs(
TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME,
TRUE_NEAREST_NEIGHBORS_ORGANISM,
TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION,
bogus_ad,
)
# embedding layer has wrong number of features
bogus_ad = query_anndata.copy()
bogus_ad.obsm[TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME] = np.zeros((len(bogus_ad), 42))
with pytest.raises(ValueError, match="features, expected"):
find_nearest_obs(emb_metadata, bogus_ad)
find_nearest_obs(
TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME,
TRUE_NEAREST_NEIGHBORS_ORGANISM,
TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION,
bogus_ad,
)
return


@pytest.mark.experimental
@pytest.mark.live_corpus
def test_predict_obs_metadata(
emb_metadata: Dict[str, Any], query_anndata: ad.AnnData, query_result: NeighborObs
) -> None:
def test_predict_obs_metadata(query_anndata: ad.AnnData, query_result: NeighborObs) -> None:
columns = ["cell_type", "tissue_general"]

with cellxgene_census.open_soma(census_version=TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION) as census:
Expand All @@ -76,7 +83,9 @@ def test_predict_obs_metadata(
.to_pandas()
)

pred_df = predict_obs_metadata(emb_metadata, query_result, columns)
pred_df = predict_obs_metadata(
TRUE_NEAREST_NEIGHBORS_ORGANISM, TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION, query_result, columns
)
assert len(pred_df) == len(query_anndata.obs)

for col in columns:
Expand All @@ -86,16 +95,6 @@ def test_predict_obs_metadata(
assert accuracy > 0.75, f"Accuracy for {col} is {accuracy}"


@pytest.fixture(scope="module")
def emb_metadata() -> Dict[str, Any]:
return get_embedding_metadata_by_name(
TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME,
TRUE_NEAREST_NEIGHBORS_ORGANISM,
TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION,
"obs_embedding",
)


@pytest.fixture(scope="module")
def true_neighbors() -> Dict[int, List[Dict[str, Any]]]:
ans = {}
Expand All @@ -117,9 +116,11 @@ def query_anndata(true_neighbors: Dict[str, Any]) -> ad.AnnData:


@pytest.fixture(scope="module")
def query_result(emb_metadata: Dict[str, Any], query_anndata: ad.AnnData) -> NeighborObs:
def query_result(query_anndata: ad.AnnData) -> NeighborObs:
return find_nearest_obs(
emb_metadata,
TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME,
TRUE_NEAREST_NEIGHBORS_ORGANISM,
TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION,
query_anndata,
k=TRUE_NEAREST_NEIGHBORS_K,
nprobe=25,
Expand Down

0 comments on commit a1e4daa

Please sign in to comment.