diff --git a/tools/models/metrics/ontology_mapper.py b/tools/models/metrics/ontology_mapper.py new file mode 100644 index 000000000..76a323d32 --- /dev/null +++ b/tools/models/metrics/ontology_mapper.py @@ -0,0 +1,543 @@ +# ruff: noqa +# type: ignore + +""" +Provides classes to recreate cell type and tissue mappings as used in CELLxGENE Discover + +- OntologyMapper abstract class to create other mappers +- SystemMapper to map any tissue to a System +- OrganMapper to map any tissue to an Organ +- TissueGeneralMapper to map any tissue to another tissue as shown in Gene Expression and Census +- CellClassMapper to map any cell type to a Cell Class +- CellSubclassMapper to map any cell type to a Cell Subclass + +""" + +import os +from abc import ABC, abstractmethod +from typing import List, Union + +import owlready2 + + +class OntologyMapper(ABC): + # Terms to ignore when mapping + BLOCK_LIST = [ + "BFO_0000004", + "CARO_0000000", + "CARO_0030000", + "CARO_0000003", + "NCBITaxon_6072", + "Thing", + "unknown", + ] + + def __init__( + self, + high_level_ontology_term_ids: List[str], + ontology_owl_path: Union[str, os.PathLike], + root_ontology_term_id: str, + ): + self._cached_high_level_terms = {} + self._cached_labels = {} + self.high_level_terms = high_level_ontology_term_ids + self.root_ontology_term_id = root_ontology_term_id + + # TODO improve this. First time it loads it raises a TypeError for CL. But redoing it loads it correctly + # The type error is + # 'http://purl.obolibrary.org/obo/IAO_0000028' belongs to more than one entity + # types (cannot be both a property and a class/an individual)! + # So we retry only once + try: + self._ontology = owlready2.get_ontology(ontology_owl_path).load() + except TypeError: + self._ontology = owlready2.get_ontology(ontology_owl_path).load() + + def get_high_level_terms(self, ontology_term_id: str) -> List[str]: + """ + Returns the associated high-level ontology term IDs from any other ID + """ + + if ontology_term_id == "unknown": + return ["unknown"] + + ontology_term_id = self.reformat_ontology_term_id(ontology_term_id, to_writable=False) + + if ontology_term_id in self._cached_high_level_terms: + return self._cached_high_level_terms[ontology_term_id] + + owl_entity = self._get_entity_from_id(ontology_term_id) + + # If not found as an ontology ID raise + if not owl_entity: + raise ValueError("ID not found in the ontology.") + + # List ancestors for this entity, including itself if it is in the list of high level terms + ancestors = [owl_entity.name] if ontology_term_id in self.high_level_terms else [] + + branch_ancestors = self._get_branch_ancestors(owl_entity) + # Ignore branch ancestors if they are not under the root node + if branch_ancestors: + if self.root_ontology_term_id in branch_ancestors: + ancestors.extend(branch_ancestors) + + # Check if there's at least one top-level entity in the list of ancestors, and add them to + # the return list of high level term. Always include itself + resulting_high_level_terms = [] + for high_level_term in self.high_level_terms: + if high_level_term in ancestors: + resulting_high_level_terms.append(high_level_term) + + # If no valid high level terms return None + if len(resulting_high_level_terms) == 0: + resulting_high_level_terms.append(None) + + resulting_high_level_terms = [ + self.reformat_ontology_term_id(i, to_writable=True) for i in resulting_high_level_terms + ] + self._cached_high_level_terms[ontology_term_id] = resulting_high_level_terms + + return resulting_high_level_terms + + def get_top_high_level_term(self, ontology_term_id: str) -> str: + """ + Return the top high level term + """ + + return self.get_high_level_terms(ontology_term_id)[0] + + @abstractmethod + def _get_branch_ancestors(self, owl_entity): + """ + Gets ALL ancestors from an owl entity. What's defined as an ancestor depends on the mapper type, for + example CL ancestors are likely to just include is_a relationship + """ + + def get_label_from_id(self, ontology_term_id: str): + """ + Returns the label from and ontology term id that is in writable form + Example: "UBERON:0002048" returns "lung" + Example: "UBERON_0002048" raises ValueError because the ID is not in writable form + """ + + if ontology_term_id == "unknown": + return "unknown" + + if ontology_term_id in self._cached_labels: + return self._cached_labels[ontology_term_id] + + if ontology_term_id is None: + return None + + entity = self._get_entity_from_id(self.reformat_ontology_term_id(ontology_term_id, to_writable=False)) + if entity: + result = entity.label[0] + else: + result = ontology_term_id + + self._cached_labels[ontology_term_id] = result + return result + + @staticmethod + def reformat_ontology_term_id(ontology_term_id: str, to_writable: bool = True): + """ + Converts ontology term id string between two formats: + - `to_writable == True`: from "UBERON_0002048" to "UBERON:0002048" + - `to_writable == False`: from "UBERON:0002048" to "UBERON_0002048" + """ + + if ontology_term_id is None: + return None + + if to_writable: + if ontology_term_id.count("_") != 1: + raise ValueError(f"{ontology_term_id} is an invalid ontology term id, it must contain exactly one '_'") + return ontology_term_id.replace("_", ":") + else: + if ontology_term_id.count(":") != 1: + raise ValueError(f"{ontology_term_id} is an invalid ontology term id, it must contain exactly one ':'") + return ontology_term_id.replace(":", "_") + + def _list_ancestors(self, entity: owlready2.entity.ThingClass, ancestors: List[str] = []) -> List[str]: + """ + Recursive function that given an entity of an ontology, it traverses the ontology and returns + a list of all ancestors associated with the entity. + """ + + if self._is_restriction(entity): + # Entity is a restriction, check for part_of relationship + + prop = entity.property.name + if prop != "BFO_0000050": + # BFO_0000050 is "part of" + return ancestors + ancestors.append(entity.value.name.replace("obo.", "")) + + # Check for ancestors of restriction + self._list_ancestors(entity.value, ancestors) + return ancestors + + elif self._is_entity(entity) and not self._is_and_object(entity): + # Entity is a superclass, check for is_a relationships + + if entity.name in self.BLOCK_LIST: + return ancestors + ancestors.append(entity.name) + + # Check for ancestors of superclass + for super_entity in entity.is_a: + self._list_ancestors(super_entity, ancestors) + return ancestors + + def _get_entity_from_id(self, ontology_term_id: str) -> owlready2.entity.ThingClass: + """ + Given a readable ontology term id (e.g. "UBERON_0002048"), it returns the associated ontology entity + """ + return self._ontology.search_one(iri=f"http://purl.obolibrary.org/obo/{ontology_term_id}") + + @staticmethod + def _is_restriction(entity: owlready2.entity.ThingClass) -> bool: + return hasattr(entity, "value") + + @staticmethod + def _is_entity(entity: owlready2.entity.ThingClass) -> bool: + return hasattr(entity, "name") + + @staticmethod + def _is_and_object(entity: owlready2.entity.ThingClass) -> bool: + return hasattr(entity, "Classes") + + +class CellMapper(OntologyMapper): + # From schema 5.0.0 https://github.com/chanzuckerberg/single-cell-curation/blob/main/schema/5.0.0/schema.md + CXG_CL_ONTOLOGY_URL = "https://github.com/obophenotype/cell-ontology/releases/download/v2024-01-04/cl.owl" + # Only look up ancestors under Cell + ROOT_NODE = "CL_0000000" + + def __init__(self, cell_type_high_level_ontology_term_ids: List[str]): + super(CellMapper, self).__init__( + high_level_ontology_term_ids=cell_type_high_level_ontology_term_ids, + ontology_owl_path=self.CXG_CL_ONTOLOGY_URL, + root_ontology_term_id=self.ROOT_NODE, + ) + + def _get_branch_ancestors(self, owl_entity): + branch_ancestors = [] + for is_a in self._get_is_a_for_cl(owl_entity): + branch_ancestors = self._list_ancestors(is_a, branch_ancestors) + + return set(branch_ancestors) + + @staticmethod + def _get_is_a_for_cl(owl_entity): + # TODO make this a recurrent function instead of 2-level for nested loop + result = [] + for is_a in owl_entity.is_a: + if CellMapper._is_entity(is_a): + result.append(is_a) + elif CellMapper._is_and_object(is_a): + for is_a_2 in is_a.get_Classes(): + if CellMapper._is_entity(is_a_2): + result.append(is_a_2) + + return result + + +class TissueMapper(OntologyMapper): + # From schema 5.0.0 https://github.com/chanzuckerberg/single-cell-curation/blob/main/schema/5.0.0/schema.md + CXG_UBERON_ONTOLOGY_URL = "https://github.com/obophenotype/uberon/releases/download/v2024-01-18/uberon.owl" + + # Only look up ancestors under anatomical entity + ROOT_NODE = "UBERON_0001062" + + def __init__(self, tissue_high_level_ontology_term_ids: List[str]): + self.cell_type_high_level_ontology_term_ids = tissue_high_level_ontology_term_ids + super(TissueMapper, self).__init__( + high_level_ontology_term_ids=tissue_high_level_ontology_term_ids, + ontology_owl_path=self.CXG_UBERON_ONTOLOGY_URL, + root_ontology_term_id=self.ROOT_NODE, + ) + + def _get_branch_ancestors(self, owl_entity): + branch_ancestors = [] + for is_a in owl_entity.is_a: + branch_ancestors = self._list_ancestors(is_a, branch_ancestors) + + return set(branch_ancestors) + + +class OrganMapper(TissueMapper): + # List of tissue classes, ORDER MATTERS. If for a given cell type there are multiple cell classes associated + # then `self.get_top_high_level_term()` returns the one that appears first in th this list + ORGANS = [ + "UBERON_0000992", # ovary + "UBERON_0000029", # lymph node + "UBERON_0002048", # lung + "UBERON_0002110", # gallbladder + "UBERON_0001043", # esophagus + "UBERON_0003889", # fallopian tube + "UBERON_0018707", # bladder organ + "UBERON_0000178", # blood + "UBERON_0002371", # bone marrow + "UBERON_0000955", # brain + "UBERON_0000310", # breast + "UBERON_0000970", # eye + "UBERON_0000948", # heart + "UBERON_0000160", # intestine + "UBERON_0002113", # kidney + "UBERON_0002107", # liver + "UBERON_0000004", # nose + "UBERON_0001264", # pancreas + "UBERON_0001987", # placenta + "UBERON_0002097", # skin of body + "UBERON_0002240", # spinal cord + "UBERON_0002106", # spleen + "UBERON_0000945", # stomach + "UBERON_0002370", # thymus + "UBERON_0002046", # thyroid gland + "UBERON_0001723", # tongue + "UBERON_0000995", # uterus + "UBERON_0001013", # adipose tissue + ] + + def __init__(self): + super().__init__(tissue_high_level_ontology_term_ids=self.ORGANS) + + +class SystemMapper(TissueMapper): + # List of tissue classes, ORDER MATTERS. If for a given cell type there are multiple cell classes associated + # then `self.get_top_high_level_term()` returns the one that appears first in th this list + SYSTEMS = [ + "UBERON_0001017", # central nervous system + "UBERON_0000010", # peripheral nervous system + "UBERON_0001016", # nervous system + "UBERON_0001009", # circulatory system + "UBERON_0002390", # hematopoietic system + "UBERON_0004535", # cardiovascular system + "UBERON_0001004", # respiratory system + "UBERON_0001007", # digestive system + "UBERON_0000922", # embryo + "UBERON_0000949", # endocrine system + "UBERON_0002330", # exocrine system + "UBERON_0002405", # immune system + "UBERON_0001434", # skeletal system + "UBERON_0000383", # musculature of body + "UBERON_0001008", # renal system + "UBERON_0000990", # reproductive system + "UBERON_0001032", # sensory system + ] + + def __init__(self): + super().__init__(tissue_high_level_ontology_term_ids=self.SYSTEMS) + + +class TissueGeneralMapper(TissueMapper): + # List of tissue classes, ORDER MATTERS. If for a given cell type there are multiple cell classes associated + # then `self.get_top_high_level_term()` returns the one that appears first in th this list + TISSUE_GENERAL = [ + "UBERON_0000178", # blood + "UBERON_0002048", # lung + "UBERON_0002106", # spleen + "UBERON_0002371", # bone marrow + "UBERON_0002107", # liver + "UBERON_0002113", # kidney + "UBERON_0000955", # brain + "UBERON_0002240", # spinal cord + "UBERON_0000310", # breast + "UBERON_0000948", # heart + "UBERON_0002097", # skin of body + "UBERON_0000970", # eye + "UBERON_0001264", # pancreas + "UBERON_0001043", # esophagus + "UBERON_0001155", # colon + "UBERON_0000059", # large intestine + "UBERON_0002108", # small intestine + "UBERON_0000160", # intestine + "UBERON_0000945", # stomach + "UBERON_0001836", # saliva + "UBERON_0001723", # tongue + "UBERON_0001013", # adipose tissue + "UBERON_0000473", # testis + "UBERON_0002367", # prostate gland + "UBERON_0000057", # urethra + "UBERON_0000056", # ureter + "UBERON_0003889", # fallopian tube + "UBERON_0000995", # uterus + "UBERON_0000992", # ovary + "UBERON_0002110", # gall bladder + "UBERON_0001255", # urinary bladder + "UBERON_0018707", # bladder organ + "UBERON_0000922", # embryo + "UBERON_0004023", # ganglionic eminence --> this a part of the embryo, remove in case generality is desired + "UBERON_0001987", # placenta + "UBERON_0007106", # chorionic villus + "UBERON_0002369", # adrenal gland + "UBERON_0002368", # endocrine gland + "UBERON_0002365", # exocrine gland + "UBERON_0000030", # lamina propria + "UBERON_0000029", # lymph node + "UBERON_0004536", # lymph vasculature + "UBERON_0001015", # musculature + "UBERON_0000004", # nose + "UBERON_0003688", # omentum + "UBERON_0000977", # pleura + "UBERON_0002370", # thymus + "UBERON_0002049", # vasculature + "UBERON_0009472", # axilla + "UBERON_0001087", # pleural fluid + "UBERON_0000344", # mucosa + "UBERON_0001434", # skeletal system + "UBERON_0002228", # rib + "UBERON_0003129", # skull + "UBERON_0004537", # blood vasculature + "UBERON_0002405", # immune system + "UBERON_0001009", # circulatory system + "UBERON_0001007", # digestive system + "UBERON_0001017", # central nervous system + "UBERON_0001008", # renal system + "UBERON_0000990", # reproductive system + "UBERON_0001004", # respiratory system + "UBERON_0000010", # peripheral nervous system + "UBERON_0001032", # sensory system + "UBERON_0002046", # thyroid gland + "UBERON_0004535", # cardiovascular system + "UBERON_0000949", # endocrine system + "UBERON_0002330", # exocrine system + "UBERON_0002390", # hematopoietic system + "UBERON_0000383", # musculature of body + "UBERON_0001465", # knee + "UBERON_0001016", # nervous system + "UBERON_0001348", # brown adipose tissue + "UBERON_0015143", # mesenteric fat pad + "UBERON_0000175", # pleural effusion + "UBERON_0001416", # skin of abdomen + "UBERON_0001868", # skin of chest + "UBERON_0001511", # skin of leg + "UBERON_0002190", # subcutaneous adipose tissue + "UBERON_0000014", # zone of skin + "UBERON_0000916", # abdomen + ] + + def __init__(self): + super().__init__(tissue_high_level_ontology_term_ids=self.TISSUE_GENERAL) + + +class CellClassMapper(CellMapper): + # List of cell classes, ORDER MATTERS. If for a given cell type there are multiple cell classes associated + # then `self.get_top_high_level_term()` returns the one that appears first in th this list + CELL_CLASS = [ + "CL_0002494", # cardiocyte + "CL_0002320", # connective tissue cell + "CL_0000473", # defensive cell + "CL_0000066", # epithelial cell + "CL_0000988", # hematopoietic cell + "CL_0002319", # neural cell + "CL_0011115", # precursor cell + "CL_0000151", # secretory cell + "CL_0000039", # NEW germ cell line + "CL_0000064", # NEW ciliated cell + "CL_0000183", # NEW contractile cell + "CL_0000188", # NEW cell of skeletal muscle + "CL_0000219", # NEW motile cell + "CL_0000325", # NEW stuff accumulating cell + "CL_0000349", # NEW extraembryonic cell + "CL_0000586", # NEW germ cell + "CL_0000630", # NEW supporting cell + "CL_0001035", # NEW bone cell + "CL_0001061", # NEW abnormal cell + "CL_0002321", # NEW embryonic cell (metazoa) + "CL_0009010", # NEW transit amplifying cell + "CL_1000600", # NEW lower urinary tract cell + "CL_4033054", # NEW perivascular cell + ] + + def __init__(self): + super().__init__(cell_type_high_level_ontology_term_ids=self.CELL_CLASS) + + +class CellSubclassMapper(CellMapper): + # List of cell classes, ORDER MATTERS. If for a given cell type there are multiple cell classes associated + # then `self.get_top_high_level_term()` returns the one that appears first in th this list + CELL_SUB_CLASS = [ + "CL_0002494", # cardiocyte + "CL_0000624", # CD4-positive, alpha-beta T cell + "CL_0000625", # CD8-positive, alpha-beta T cell + "CL_0000084", # T cell + "CL_0000236", # B cell + "CL_0000451", # dendritic cell + "CL_0000576", # monocyte + "CL_0000235", # macrophage + "CL_0000542", # lymphocyte + "CL_0000738", # leukocyte + "CL_0000763", # myeloid cell + "CL_0008001", # hematopoietic precursor cell + "CL_0000234", # phagocyte + "CL_0000679", # glutamatergic neuron + "CL_0000617", # GABAergic neuron + "CL_0000099", # interneuron + "CL_0000125", # glial cell + "CL_0000101", # sensory neuron + "CL_0000100", # motor neuron + "CL_0000117", # CNS neuron (sensu Vertebrata) + "CL_0000540", # neuron + "CL_0000669", # pericyte + "CL_0000499", # stromal cell + "CL_0000057", # fibroblast + "CL_0000152", # exocrine cell + "CL_0000163", # endocrine cell + "CL_0000115", # endothelial cell + "CL_0002076", # endo-epithelial cell + "CL_0002078", # meso-epithelial cell + "CL_0011026", # progenitor cell + "CL_0000015", # NEW male germ cell + "CL_0000021", # NEW female germ cell + "CL_0000034", # NEW stem cell + "CL_0000055", # NEW non-terminally differentiated cell + "CL_0000068", # NEW duct epithelial cell + "CL_0000075", # NEW columnar/cuboidal epithelial cell + "CL_0000076", # NEW squamous epithelial cell + "CL_0000079", # NEW stratified epithelial cell + "CL_0000082", # NEW epithelial cell of lung + "CL_0000083", # NEW epithelial cell of pancreas + "CL_0000095", # NEW neuron associated cell + "CL_0000098", # NEW sensory epithelial cell + "CL_0000136", # NEW fat cell + "CL_0000147", # NEW pigment cell + "CL_0000150", # NEW glandular epithelial cell + "CL_0000159", # NEW seromucus secreting cell + "CL_0000182", # NEW hepatocyte + "CL_0000186", # NEW myofibroblast cell + "CL_0000187", # NEW muscle cell + "CL_0000221", # NEW ectodermal cell + "CL_0000222", # NEW mesodermal cell + "CL_0000244", # NEW urothelial cell + "CL_0000351", # NEW trophoblast cell + "CL_0000584", # NEW enterocyte + "CL_0000586", # NEW germ cell + "CL_0000670", # NEW primordial germ cell + "CL_0000680", # NEW muscle precursor cell + "CL_0001063", # NEW neoplastic cell + "CL_0002077", # NEW ecto-epithelial cell + "CL_0002222", # NEW vertebrate lens cell + "CL_0002327", # NEW mammary gland epithelial cell + "CL_0002503", # NEW adventitial cell + "CL_0002518", # NEW kidney epithelial cell + "CL_0002535", # NEW epithelial cell of cervix + "CL_0002536", # NEW epithelial cell of amnion + "CL_0005006", # NEW ionocyte + "CL_0008019", # NEW mesenchymal cell + "CL_0008034", # NEW mural cell + "CL_0009010", # NEW transit amplifying cell + "CL_1000296", # NEW epithelial cell of urethra + "CL_1000497", # NEW kidney cell + "CL_2000004", # NEW pituitary gland cell + "CL_2000064", # NEW ovarian surface epithelial cell + "CL_4030031", # NEW interstitial cell + ] + + def __init__(self, map_orphans_to_class: bool = False): + if map_orphans_to_class: + cell_type_high_level = self.CELL_SUB_CLASS + CellClassMapper.CELL_CLASS + else: + cell_type_high_level = self.CELL_SUB_CLASS + super().__init__(cell_type_high_level_ontology_term_ids=cell_type_high_level) diff --git a/tools/models/metrics/requirements.txt b/tools/models/metrics/requirements.txt new file mode 100644 index 000000000..a3dd411e6 --- /dev/null +++ b/tools/models/metrics/requirements.txt @@ -0,0 +1,3 @@ +owlready2 +scib-metrics==0.5.1 +pyyaml \ No newline at end of file diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py new file mode 100644 index 000000000..be761b962 --- /dev/null +++ b/tools/models/metrics/run-scib.py @@ -0,0 +1,407 @@ +# ruff: noqa +# type: ignore + +import datetime +import functools +import itertools +import pickle +import sys +import warnings + +import cellxgene_census +import numpy as np +import ontology_mapper +import pandas as pd +import scanpy as sc +import scib_metrics +import tiledbsoma as soma +import yaml +from sklearn import svm +from sklearn.ensemble import RandomForestClassifier +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import accuracy_score, roc_auc_score +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import LabelEncoder + +warnings.filterwarnings("ignore") + + +class CensusClassifierMetrics: + def __init__(self): + self._default_metric = "accuracy" + + def lr_labels(self, X, labels, metric=None): + return self._base_accuracy(X, labels, LogisticRegression, metric=metric) + + def svm_svc_labels(self, X, labels, metric=None): + return self._base_accuracy(X, labels, svm.SVC, metric=metric) + + def random_forest_labels(self, X, labels, metric=None, n_jobs=8): + return self._base_accuracy(X, labels, RandomForestClassifier, metric=metric, n_jobs=n_jobs) + + def lr_batch(self, X, batch, metric=None): + return 1 - self._base_accuracy(X, batch, LogisticRegression, metric=metric) + + def svm_svc_batch(self, X, batch, metric=None): + return 1 - self._base_accuracy(X, batch, svm.SVC, metric=metric) + + def random_forest_batch(self, X, batch, metric=None, n_jobs=8): + return 1 - self._base_accuracy(X, batch, RandomForestClassifier, metric=metric, n_jobs=n_jobs) + + def _base_accuracy(self, X, y, model, metric, test_size=0.4, **kwargs): + """Train LogisticRegression on X with labels y and return classifier accuracy score""" + y_encoded = LabelEncoder().fit_transform(y) + X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=test_size, random_state=42) + model = model(**kwargs).fit(X_train, y_train) + + if metric == None: + metric = self._default_metric + + if metric == "roc_auc": + # return y_test + # return model.predict_proba(X_test) + return roc_auc_score(y_test, model.predict_proba(X_test), multi_class="ovo", average="macro") + elif metric == "accuracy": + return accuracy_score(y_test, model.predict(X_test)) + else: + raise ValueError("Only {'accuracy', 'roc_auc'} are supported as a metric") + + +def safelog(a): + return np.log(a, out=np.zeros_like(a), where=(a != 0)) + + +def nearest_neighbors_hnsw(x, ef=200, M=48, n_neighbors=100): + import hnswlib + + labels = np.arange(x.shape[0]) + p = hnswlib.Index(space="l2", dim=x.shape[1]) + p.init_index(max_elements=x.shape[0], ef_construction=ef, M=M) + p.add_items(x, labels) + p.set_ef(ef) + idx, dist = p.knn_query(x, k=n_neighbors) + return idx, dist + + +def compute_entropy_per_cell(adata, obsm_key, batch_key): + indices, dist = nearest_neighbors_hnsw(adata.obsm[obsm_key], n_neighbors=200) + + batch_labels = np.array(list(adata.obs[batch_key])) + unique_batch_labels = np.unique(batch_labels) + + indices_batch = batch_labels[indices] + + label_counts_per_cell = np.vstack([(indices_batch == label).sum(1) for label in unique_batch_labels]).T + label_counts_per_cell_normed = label_counts_per_cell / label_counts_per_cell.sum(1)[:, None] + return (-label_counts_per_cell_normed * safelog(label_counts_per_cell_normed)).sum(1) + + +if __name__ == "__main__": + try: + file = sys.argv[1] + except IndexError: + file = "scib-metrics-config.yaml" + + with open(file) as f: + config = yaml.safe_load(f) + + census_config = config.get("census") + embedding_config = config.get("embeddings") + metrics_config = config.get("metrics") + + census_version = census_config.get("version") + experiment_name = census_config.get("organism") + + # These are embeddings hosted in the Census + embeddings_census = embedding_config.get("census") or [] + + # Raw embeddings (external) + embeddings_raw = embedding_config.get("raw") or dict() + + # All embedding names + embs = list(embeddings_census) + list(embeddings_raw.keys()) + + print("Embeddings to use: ", embs) + + census = cellxgene_census.open_soma(census_version=census_version) + + def subclass_mapper(): + mapper = ontology_mapper.CellSubclassMapper(map_orphans_to_class=True) + cell_types = ( + census["census_data"]["homo_sapiens"] + .obs.read(column_names=["cell_type_ontology_term_id"], value_filter="is_primary_data == True") + .concat() + .to_pandas() + ) + cell_types = cell_types["cell_type_ontology_term_id"].drop_duplicates() + subclass_dict = {i: mapper.get_label_from_id(mapper.get_top_high_level_term(i)) for i in cell_types} + return subclass_dict + + def class_mapper(): + mapper = ontology_mapper.CellClassMapper() + cell_types = ( + census["census_data"]["homo_sapiens"] + .obs.read(column_names=["cell_type_ontology_term_id"], value_filter="is_primary_data == True") + .concat() + .to_pandas() + ) + cell_types = cell_types["cell_type_ontology_term_id"].drop_duplicates() + class_dict = {i: mapper.get_label_from_id(mapper.get_top_high_level_term(i)) for i in cell_types} + return class_dict + + class_dict = class_mapper() + subclass_dict = subclass_mapper() + + def build_anndata_with_embeddings( + embedding_names: list[str], + embeddings_raw: dict, + coords: list[int] = None, + obs_value_filter: str = None, + column_names=dict, + census_version: str = None, + experiment_name: str = None, + ): + """For a given set of Census cell coordinates (soma_joinids) + fetch embeddings with TileDBSoma and return the corresponding + AnnData with embeddings slotted in. + + `embedding_names` is a list with embedding names included in Census. + `embeddings_raw` are embeddings provided in raw format (npy) on a local drive + + + Assume that all embeddings provided are coming from the same experiment. + """ + with cellxgene_census.open_soma(census_version=census_version) as census: + print("Getting anndata with Census embeddings: ", embedding_names) + + ad = cellxgene_census.get_anndata( + census, + organism=experiment_name, + measurement_name="RNA", + obs_value_filter=obs_value_filter, + obs_coords=coords, + obs_embeddings=embedding_names, + column_names=column_names, + ) + + obs_soma_joinids = ad.obs["soma_joinid"].to_numpy() + + # For these, we need to extract the right cells via soma_joinid + for key, val in embeddings_raw.items(): + print("Getting raw embedding:", key) + # Alternative approach: set type in the config file + try: + # Assume it's a numpy ndarray + emb = np.load(val["uri"]) + emb_idx = np.load(val["idx"]) + obs_indexer = pd.Index(emb_idx) + idx = obs_indexer.get_indexer(obs_soma_joinids) + ad.obsm[key] = emb[idx] + except Exception: + from scipy.sparse import vstack + + # Assume it's a TileDBSoma URI + all_embs = [] + with soma.open(val["uri"]) as E: + for mat in E.read(coords=(obs_soma_joinids,)).blockwise(axis=0).scipy(): + all_embs.append(mat[0]) + ad.obsm[key] = vstack(all_embs).toarray() + print("DIM:", ad.obsm[key].shape) + + # Embeddings with missing data contain all NaN, + # so we must find the intersection of non-NaN rows in the fetched embeddings + # and subset the AnnData accordingly. + filt = np.ones(ad.shape[0], dtype="bool") + for key in ad.obsm.keys(): + nan_row_sums = np.sum(np.isnan(ad.obsm[key]), axis=1) + total_columns = ad.obsm[key].shape[1] + filt = filt & (nan_row_sums != total_columns) + ad = ad[filt].copy() + + return ad + + column_names = { + "obs": ["cell_type_ontology_term_id", "cell_type", "assay", "suspension_type", "dataset_id", "soma_joinid"] + } + umap_plot_labels = ["cell_subclass", "cell_class", "cell_type", "dataset_id"] + + block_cell_types = ["native cell", "animal cell", "eukaryotic cell", "unknown"] + + all_bio = {} + all_batch = {} + + tissues = metrics_config.get("tissues") + + bio_metrics = metrics_config["bio"] + batch_metrics = metrics_config["batch"] + + for tissue_node in tissues: + tissue = tissue_node["name"] + query = tissue_node.get("query") or f"tissue_general == '{tissue}' and is_primary_data == True" + + print("Tissue", tissue, " getting Anndata") + + # Getting anddata + adata_metrics = build_anndata_with_embeddings( + embedding_names=embeddings_census, + embeddings_raw=embeddings_raw, + obs_value_filter=query, + census_version=census_version, + experiment_name="homo_sapiens", + column_names=column_names, + ) + + for column in adata_metrics.obs.columns: + if adata_metrics.obs[column].dtype.name == "category": + adata_metrics.obs[column] = adata_metrics.obs[column].astype(str) + + # Create batch variable + adata_metrics.obs["batch"] = ( + adata_metrics.obs["assay"] + adata_metrics.obs["dataset_id"] + adata_metrics.obs["suspension_type"] + ) + + # Get cell subclass + adata_metrics.obs["cell_subclass"] = adata_metrics.obs["cell_type_ontology_term_id"].replace(subclass_dict) + adata_metrics = adata_metrics[~adata_metrics.obs["cell_subclass"].isna(),] + + # Get cell class + adata_metrics.obs["cell_class"] = adata_metrics.obs["cell_type_ontology_term_id"].replace(class_dict) + adata_metrics = adata_metrics[~adata_metrics.obs["cell_class"].isna(),] + + # Remove cells in block list of cell types + adata_metrics[~adata_metrics.obs["cell_type"].isin(block_cell_types),] + + print("Tissue", tissue, "cells", adata_metrics.n_obs) + + # Calculate neighbors + for emb_name in embs: + print(datetime.datetime.now(), "Getting neighbors", emb_name) + sc.pp.neighbors(adata_metrics, use_rep=emb_name, key_added=emb_name) + # Only necessary + if "ilisi_knn_batch" in metrics_config["batch"]: + sc.pp.neighbors(adata_metrics, n_neighbors=90, use_rep=emb_name, key_added=emb_name + "_90") + sc.tl.umap(adata_metrics, neighbors_key=emb_name) + adata_metrics.obsm["X_umap_" + emb_name] = adata_metrics.obsm["X_umap"].copy() + del adata_metrics.obsm["X_umap"] + + # Save a few UMAPS + print(datetime.datetime.now(), "Saving UMAP plots") + for emb_name in embs: + for label in umap_plot_labels: + title = "_".join(["UMAP", tissue, emb_name, label]) + sc.pl.embedding( + adata_metrics, basis="X_umap_" + emb_name, color=label, title=title, save=title + ".png" + ) + + bio_labels = ["cell_subclass", "cell_class"] + batch_labels = ["batch", "assay", "dataset_id", "suspension_type"] + + # Initialize results + metric_bio_results = { + "embedding": [], + "bio_label": [], + } + metric_batch_results = { + "embedding": [], + "batch_label": [], + } + + for metric in bio_metrics: + metric_bio_results[metric] = [] + + for metric in batch_metrics: + metric_batch_results[metric] = [] + + # Calculate metrics + for bio_label, emb in itertools.product(bio_labels, embs): + print("\n\nSTART", bio_label, emb) + + metric_bio_results["embedding"].append(emb) + metric_bio_results["bio_label"].append(bio_label) + + print(datetime.datetime.now(), "Calculating ARI Leiden") + + class NN: + def __init__(self, conn): + self.knn_graph_connectivities = conn + + X = NN(adata_metrics.obsp[emb + "_connectivities"]) + + if "leiden_nmi" in bio_metrics and "leiden_ari" in bio_metrics: + this_metric = scib_metrics.nmi_ari_cluster_labels_leiden( + X=X, + labels=adata_metrics.obs[bio_label], + optimize_resolution=True, + resolution=1.0, + n_jobs=64, + ) + metric_bio_results["leiden_nmi"].append(this_metric["nmi"]) + metric_bio_results["leiden_ari"].append(this_metric["ari"]) + + if "silhouette_label" in bio_metrics: + print(datetime.datetime.now(), "Calculating silhouette labels") + + this_metric = scib_metrics.silhouette_label( + X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label], rescale=True, chunk_size=512 + ) + metric_bio_results["silhouette_label"].append(this_metric) + + if "classifier" in bio_metrics: + metrics = CensusClassifierMetrics() + + m1 = metrics.lr_labels(X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label]) + m2 = metrics.svm_svc_labels(X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label]) + m3 = metrics.random_forest_labels(X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label]) + + metric_bio_results["classifier"].append({"lr": m1, "svm": m2, "random_forest": m3}) + + for batch_label, emb in itertools.product(batch_labels, embs): + print("\n\nSTART", batch_label, emb) + + metric_batch_results["embedding"].append(emb) + metric_batch_results["batch_label"].append(batch_label) + + if "silhouette_batch" in batch_metrics: + print(datetime.datetime.now(), "Calculating silhouette batch") + + this_metric = scib_metrics.silhouette_batch( + X=adata_metrics.obsm[emb], + labels=adata_metrics.obs[bio_label], + batch=adata_metrics.obs[batch_label], + rescale=True, + chunk_size=512, + ) + metric_batch_results["silhouette_batch"].append(this_metric) + + if "ilisi_knn_batch" in batch_metrics: + print(datetime.datetime.now(), "Calculating ilisi knn batch") + + ilisi_metric = scib_metrics.ilisi_knn( + X=adata_metrics.obsp[f"{emb}_90_distances"], + batches=adata_metrics.obs[batch_label], + scale=True, + ) + + metric_batch_results["ilisi_knn_batch"].append(ilisi_metric) + + if "classifier" in batch_metrics: + metrics = CensusClassifierMetrics() + + m4 = metrics.lr_batch(X=adata_metrics.obsm[emb], batch=adata_metrics.obs[batch_label]) + m5 = metrics.random_forest_batch(X=adata_metrics.obsm[emb], batch=adata_metrics.obs[batch_label]) + m6 = metrics.svm_svc_batch(X=adata_metrics.obsm[emb], batch=adata_metrics.obs[batch_label]) + metric_batch_results["classifier"].append({"lr": m4, "random_forest": m5, "svm": m6}) + + if "entropy" in batch_metrics: + print(datetime.datetime.now(), "Calculating entropy") + + entropy = compute_entropy_per_cell(adata_metrics, emb, batch_label) + e_mean = entropy.mean() + metric_batch_results["entropy"].append(e_mean) + + filename = f"metrics.{tissue}.pickle".replace(" ", "-").lower() + + with open(filename, "wb") as fp: + pickle.dump( + {"bio": metric_bio_results, "batch": metric_batch_results}, fp, protocol=pickle.HIGHEST_PROTOCOL + ) diff --git a/tools/models/metrics/scib-metrics-config.yaml b/tools/models/metrics/scib-metrics-config.yaml new file mode 100644 index 000000000..2ba7d30eb --- /dev/null +++ b/tools/models/metrics/scib-metrics-config.yaml @@ -0,0 +1,18 @@ +census: + version: + "2023-12-15" + organism: + "homo_sapiens" +embeddings: + census: + [scvi, geneformer] +metrics: + tissues: + - name: "adipose tissue" + - name: "spinal cord" + - name: "heart" + query: 'tissue in ["cardiac ventricle", "heart left ventricle", "heart right ventricle"] and dataset_id in ["53d208b0-2cfd-4366-9866-c3c6114081bc", "d567b692-c374-4628-a508-8008f6778f22", "f15e263b-6544-46cb-a46e-e33ab7ce8347", "d4e69e01-3ba2-4d6b-a15d-e7048f78f22e"] and is_primary_data==True' + bio: + ["leiden_nmi", "leiden_ari", "silhouette_label", "classifier"] + batch: + ["silhouette_batch", "ilisi_knn_batch", "classifier", "entropy"] \ No newline at end of file