# tissue

In [1]:
import logging
import typing as t
from functools import lru_cache
from logging import log

import networkx as nx
import owlready2
from scipy import sparse as sp

import numpy as np
import torch
import pandas as pd
import pandas_gbq
from google.cloud import bigquery

In [2]:
# Query template format
query_format = """
    select {column_name}, count(*) as num_cells 
    from `cas_2024_05_16_dataset.human_cellariumgpt_extract__extract_cell_info`
    group by {column_name}
    order by num_cells desc
"""

In [3]:
def get_ancestors_csr_matrix(graph, names_to_idx_map: t.Dict[str, int]) -> sp.csr_matrix:
    """Returns a sparse matrix representation of ancestors.

    .. note:
        The matrix element (i, j) = 1 iff j is an ancetor of i.
    """
    n_nodes = len(graph.nodes)

    row = []
    col = []
    data = []

    for name, self_idx in names_to_idx_map.items():
        row.append(self_idx)
        col.append(self_idx)
        data.append(1)
        for ancestor_name in nx.ancestors(graph, name):
            ancestor_idx = names_to_idx_map[ancestor_name]
            row.append(self_idx)
            col.append(ancestor_idx)
            data.append(1)

    ancestors_csr_matrix = sp.csr_matrix((data, (row, col)), shape=(n_nodes, n_nodes))
    return ancestors_csr_matrix

In [4]:
def get_shortest_distances_matrix(graph, names_to_idx_map: t.Dict[str, int]) -> np.ndarray:
    """Returns a sparse matrix representation of shortest distances.

    .. note:
        The matrix element (i, j) = d iff d is the shortest distance between i and j.
    """
    n_nodes = len(graph.nodes)

    distance_matrix = np.full((n_nodes, n_nodes), np.inf)

    for target, value in dict(nx.all_pairs_shortest_path_length(graph)).items():
        for source, distance in value.items():
            source_idx = names_to_idx_map[source]
            target_idx = names_to_idx_map[target]
            distance_matrix[source_idx, target_idx] = distance
            distance_matrix[target_idx, source_idx] = distance

    return distance_matrix

## extract

In [5]:
query_name = "tissue_ontology_term_id"
names_counts = pandas_gbq.read_gbq(
    query_format.format(column_name=query_name),
    project_id="dsp-cell-annotation-service",
)
names_counts.head()

  record_batch = self.to_arrow(


Unnamed: 0,tissue_ontology_term_id,num_cells
0,UBERON:0000178,9446186
1,UBERON:0002048,2412727
2,UBERON:0000310,2269902
3,UBERON:0009834,1591509
4,UBERON:0002771,1317262


In [6]:
query_label = "tissue"
labels_counts = pandas_gbq.read_gbq(
    query_format.format(column_name=query_label),
    project_id="dsp-cell-annotation-service",
)
labels_counts

  record_batch = self.to_arrow(


Unnamed: 0,tissue,num_cells
0,blood,9446186
1,lung,2412727
2,breast,2269902
3,dorsolateral prefrontal cortex,1591509
4,middle temporal gyrus,1317262
...,...,...
259,left parietal lobe,163
260,vault of skull,123
261,kidney blood vessel,47
262,skin of face,37


How many different ontologies are used for `tissue`?

In [7]:
set([value[0] for value in names_counts[query_name].str.split(":")])

{'CL', 'UBERON'}

## ontology

In [31]:
# Used in CZ CELLxGENE schema v5:
# https://github.com/chanzuckerberg/single-cell-curation/blob/main/schema/5.0.0/schema.md
OWL_PATH = "https://github.com/obophenotype/uberon/releases/download/v2024-01-18/uberon.owl"

# only keep nodes with the following prefix when parsing CL ontology
PREFIX = "UBERON_"

# the 'cell' node
LIFE_CYCLE_NODE = "UBERON:0000104"

# the 'eukaryotic cell' node
ANATOMICAL_ENTITY_NODE = "UBERON:0001062"

# relationships we need
PARTOF_RELATIONSHIP = "BFO_0000050"  # part_of

In [32]:
ontology = owlready2.get_ontology(OWL_PATH).load()

In [33]:
all_classes = list(ontology.classes())

In [34]:
# only keep CL classes with a singleton label
classes = list(
    _class for _class in all_classes if _class.name.startswith(PREFIX) and len(_class.label) == 1
)

names = [_class.name.replace("_", ":") for _class in classes]
labels = [str(_class.label[0]) for _class in classes]
assert len(set(names)) == len(classes)
# assert len(set(labels)) == len(classes)

In [35]:
classes_set = set(classes)
names_to_labels_map = {name: label for name, label in zip(names, labels)}
names_to_idx_map = {name: idx for idx, name in enumerate(names)}
# labels_to_names_map = {label: name for name, label in zip(names, labels)}
# labels_to_idx_map = {label: idx for idx, label in enumerate(labels)}
idx_to_names_map = {idx: name for idx, name in enumerate(names)}
idx_to_labels_map = {idx: label for idx, label in enumerate(labels)}

Are all `tissue`s covered by the ontology tree?

In [36]:
set(labels_counts[query_label]) - set(labels)

{'embryonic stem cell'}

In [37]:
set(names_counts[query_name]) - set(names)

{'CL:0002322'}

In [38]:
len(labels)

15567

In [39]:
# build a networkx graph from CL
graph = nx.DiGraph(name="CL graph")

for _class in classes:
    graph.add_node(_class.name.replace("_", ":"))

for self_class in classes:
    # parents
    for parent_class in ontology.get_parents_of(self_class):
        if parent_class not in classes_set:
            continue
        graph.add_edge(parent_class.name.replace("_", ":"), self_class.name.replace("_", ":"))
    # children
    for child_class in ontology.get_children_of(self_class):
        if child_class not in classes_set:
            continue
        graph.add_edge(self_class.name.replace("_", ":"), child_class.name.replace("_", ":"))
    # part of
    for prop in self_class.get_class_properties():
        if PARTOF_RELATIONSHIP in prop.name:
            for related_term in prop[self_class]:
                if related_term.name.startswith(PREFIX):
                    graph.add_edge(related_term.name.replace("_", ":"), self_class.name.replace("_", ":"))

    # deprecated terms (WHY???!!)
    if "deprecated" in [prop.name for prop in self_class.get_class_properties()]:
        for prop in self_class.get_class_properties():
            if "consider" in prop.name:
                for substitute in prop[self_class]:
                    if substitute.startswith(PREFIX):
                        graph.add_edge(substitute, self_class.name.replace("_", ":"))

In [40]:
ancestors_matrix = get_ancestors_csr_matrix(graph, names_to_idx_map).toarray()
ancestors_matrix

array([[1, 0, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       [0, 1, 1, ..., 0, 0, 0],
       ...,
       [0, 1, 0, ..., 1, 0, 0],
       [0, 1, 0, ..., 0, 1, 0],
       [0, 1, 0, ..., 0, 0, 1]])

In [48]:
shortest_distances_matrix = get_shortest_distances_matrix(graph, names_to_idx_map)
shortest_distances_matrix = np.where(ancestors_matrix, shortest_distances_matrix, np.inf)
shortest_distances_matrix

array([[ 0., inf, inf, ..., inf, inf, inf],
       [inf,  0., inf, ..., inf, inf, inf],
       [inf,  4.,  0., ..., inf, inf, inf],
       ...,
       [inf,  6., inf, ...,  0., inf, inf],
       [inf,  6., inf, ..., inf,  0., inf],
       [inf,  5., inf, ..., inf, inf,  0.]])

Only leave nodes that exist in the data and their ancestors.

In [49]:
extract_names_set = set(names_counts[query_name]) - set(["CL:0002322"])
extract_idx = [idx for name, idx in names_to_idx_map.items() if name in extract_names_set]
assert len(extract_idx) == len(extract_names_set)
len(extract_idx)

263

In [50]:
new_extract_idx = ancestors_matrix[extract_idx].any(axis=0).nonzero()[0].tolist()
new_extract_idx = [idx for idx in new_extract_idx if idx_to_names_map[idx] not in {LIFE_CYCLE_NODE, ANATOMICAL_ENTITY_NODE}]
len(new_extract_idx)

822

In [51]:
new_names = [idx_to_names_map[idx] for idx in new_extract_idx]
new_labels = [idx_to_labels_map[idx] for idx in new_extract_idx]

In [52]:
new_shortest_distances_matrix = shortest_distances_matrix[new_extract_idx][:, new_extract_idx]
# new_longest_distances_matrix = longest_distances_matrix[new_extract_idx][:, new_extract_idx]
new_ancestors_matrix = ancestors_matrix[new_extract_idx][:, new_extract_idx]

In [53]:
torch.save({
        "names": new_names,
        "labels": new_labels,
        "shortest_distances_matrix": torch.tensor(new_shortest_distances_matrix, dtype=torch.float32),
        # "longest_distances_matrix": torch.tensor(new_longest_distances_matrix, dtype=torch.float32),
        "ancestors_matrix": torch.tensor(new_ancestors_matrix, dtype=torch.int32),
    },
    "tissue_ontology_data.pt"
)