In [2]:
from backend.wmg.data.rollup import rollup_across_cell_type_descendants
import owlready2
import json
import tiledb
from backend.wmg.data.ontology_labels import ontology_term_label, ontology_term_id_labels
import pandas as pd




# Build ontology tree JSON

In [29]:
def traverse(node):
    """Recursively convert an ontology class and its subclasses into nested dicts.

    The lookup key is ``node.name`` with underscores replaced by colons, while
    the emitted ``"id"`` is ``node.name`` unchanged. Display name and cell
    counts are read from the module-level ``id_to_name``, ``cell_counts_df``,
    ``cell_counts_df_norm``, ``cell_counts_df_rollup`` and
    ``cell_counts_df_rollup_norm`` lookups.

    Args:
        node: object exposing ``.name`` and ``.subclasses()``.

    Returns:
        A dict with keys ``id``, ``name``, ``n_cells_rollup``,
        ``n_cells_rollup_normalized``, ``n_cells`` and ``n_cells_normalized``.
        A ``children`` list (one dict per subclass, built recursively) is added
        only when the node has at least one subclass.
    """
    term_id = node.name.replace('_', ':')

    def count_for(table, cast):
        # Ids missing from a lookup count as zero instead of raising.
        return cast(table[term_id] if term_id in table else 0)

    # Descend first, then describe this node.
    nested = [traverse(subclass) for subclass in node.subclasses()]

    entry = {
        "id": node.name,
        # Fall back to the colon-form id when no label is known.
        "name": id_to_name[term_id] if term_id in id_to_name else term_id,
        "n_cells_rollup": count_for(cell_counts_df_rollup, int),
        "n_cells_rollup_normalized": count_for(cell_counts_df_rollup_norm, float),
        "n_cells": count_for(cell_counts_df, int),
        "n_cells_normalized": count_for(cell_counts_df_norm, float),
    }
    if nested:
        # Leaves carry no "children" key at all.
        entry["children"] = nested
    return entry

# Read the cell-count array from the snapshot. The context manager closes the
# TileDB array as soon as its contents are materialised into a DataFrame.
with tiledb.open('prod-snapshot/cell_counts') as cell_counts:
    cell_counts_df = cell_counts.df[:]
# Collapse to one row per cell type by summing the numeric columns.
cell_counts_df = cell_counts_df.groupby('cell_type_ontology_term_id').sum(numeric_only=True).reset_index()

# Every term id starting with "CL:" known to the label lookup.
all_cell_types = [{k: ontology_term_label(k)} for k in ontology_term_id_labels if k.startswith('CL:')]
all_cell_types_ids = [list(i.keys())[0] for i in all_cell_types]

# Cell types absent from the snapshot get an explicit zero-count row so that
# every id in all_cell_types_ids is present in the count lookups below.
# A set makes each membership test O(1) instead of a scan over the column.
observed_ids = set(cell_counts_df['cell_type_ontology_term_id'])
to_attach = pd.DataFrame()
to_attach['cell_type_ontology_term_id'] = [i for i in all_cell_types_ids if i not in observed_ids]
to_attach['n_cells'] = 0

cell_counts_df = pd.concat([cell_counts_df, to_attach], axis=0)
# NOTE(review): presumably the rollup adds each type's descendants' counts to
# its own, as the helper's name suggests -- confirm against backend.wmg.data.rollup.
cell_counts_df_rollup = rollup_across_cell_type_descendants(cell_counts_df).set_index('cell_type_ontology_term_id')['n_cells']
cell_counts_df = cell_counts_df.set_index('cell_type_ontology_term_id')['n_cells']

# Scale both series by their own maximum.
cell_counts_df_rollup_norm = cell_counts_df_rollup / cell_counts_df_rollup.max()
cell_counts_df_norm = cell_counts_df / cell_counts_df.max()
# Term id -> label, indexed like the count series.
id_to_name = pd.Series(index=cell_counts_df.index, data=[ontology_term_label(i) for i in cell_counts_df.index])


# NOTE: this fetches the *latest* release over the network, so the resulting
# tree can change between runs.
ontology = owlready2.get_ontology("https://github.com/obophenotype/cell-ontology/releases/latest/download/cl-basic.owl")
ontology.load()

root_node = ontology.world["http://purl.obolibrary.org/obo/CL_0000000"]

a = traverse(root_node)

# Write through a context manager so the file is flushed and closed
# deterministically instead of whenever the handle is garbage-collected.
with open('ontologyRawTree.json', 'w') as tree_file:
    json.dump(a, tree_file)

# Get initial open/closed state of nodes per cell type

In [240]:
def _children(cell_type):
    """Return the colon-form ids of the direct subclasses of ``cell_type``.

    ``cell_type`` is a colon-form id; it is converted to underscore form and
    looked up in the module-level ``ontology`` under the
    ``http://purl.obolibrary.org/obo/`` prefix. If the lookup yields nothing,
    the result is ``[cell_type]``.
    """
    iri = "http://purl.obolibrary.org/obo/" + cell_type.replace(":", "_")
    entity = ontology.search_one(iri=iri)
    if not entity:
        # Unknown term: fall back to the input id itself.
        return [cell_type]
    return [subclass.name.replace("_", ":") for subclass in entity.subclasses()]

def _parents(cell_type):
    """Return the colon-form ids of the direct named parents of ``cell_type``.

    ``cell_type`` is a colon-form id; it is converted to underscore form and
    looked up in the module-level ``ontology`` under the
    ``http://purl.obolibrary.org/obo/`` prefix.

    Returns:
        One id per entry of the entity's ``is_a`` list that is an
        ``owlready2.ThingClass`` and is not named ``"Thing"``. An empty list
        is returned when the ontology has no entity for the id, rather than
        failing on ``None.is_a``.
    """
    cell_type_iri = cell_type.replace(":", "_")
    entity = ontology.search_one(iri=f"http://purl.obolibrary.org/obo/{cell_type_iri}")
    if entity is None:
        # Term not present in the loaded ontology: it has no known parents.
        return []
    return [
        parent.name.replace("_", ":")
        for parent in entity.is_a
        # Keep only ThingClass parents and drop the one named "Thing".
        if isinstance(parent, owlready2.ThingClass) and parent.name != "Thing"
    ]

def dfs(parents, end, start, node=None, path=None, all_paths=None):
    """Collect every path from ``end`` up to ``start`` by following parent links.

    Args:
        parents: mapping of node -> iterable of its parent nodes. Nodes missing
            from the mapping are treated as having no parents.
        end: node the walk begins at.
        start: node at which a path is considered complete.
        node: recursion state (current node); leave as ``None`` on the initial call.
        path: recursion state (nodes visited so far, beginning with ``end``);
            leave as ``None`` on the initial call.
        all_paths: list that receives each complete path, ordered from ``end``
            to ``start``. Pass your own list to obtain the results. When
            omitted, a fresh list is used for this call, so nothing is shared
            between separate invocations.

    Returns:
        ``path`` when ``node == start``, otherwise ``None``. On the initial call
        this means ``[end]`` is returned when ``end == start`` (and nothing is
        appended to ``all_paths``), and ``None`` in every other case; complete
        paths are delivered through ``all_paths``.
    """
    if all_paths is None:
        # A literal [] default would be one list shared by every call.
        all_paths = []

    if path is None and node is None:
        path = [end]
        node = end

    if node == start:
        return path

    for parent in parents.get(node, []):
        full_path = dfs(parents, end, start, node=parent, path=path + [parent], all_paths=all_paths)
        # Only a call that landed exactly on `start` returns a path; deeper
        # completions were already recorded by the recursive call itself.
        if full_path:
            all_paths.append(full_path)

    return None
    

In [241]:
# Direct-subclass ids for every known cell type id.
all_children = {term_id: _children(term_id) for term_id in all_cell_types_ids}
# Direct-parent ids for every known cell type id.
all_parents = {term_id: _parents(term_id) for term_id in all_cell_types_ids}

In [248]:
# For every cell type, record the set of nodes lying on any parent-link path
# from that cell type up to the root term.
start_node = 'CL:0000000'
open_nodes_dict = {}
for end_node in all_cell_types_ids:
    # A fresh list per cell type; dfs appends each complete path to it.
    all_paths = []
    dfs(all_parents, end_node, start_node, all_paths=all_paths)
    # Union of all nodes on all paths. Sorting makes the written JSON
    # reproducible across runs (plain set iteration order is not), and the
    # set union avoids the quadratic list concatenation of sum(all_paths, []).
    visited_nodes = sorted(set().union(*all_paths))
    open_nodes_dict[end_node] = visited_nodes

# Write through a context manager so the file is flushed and closed
# deterministically.
with open('initialNodeExpandedState.json', 'w') as state_file:
    json.dump(open_nodes_dict, state_file)