In [None]:
cd ../

In [None]:
import cellxgene_census
from backend.wmg.data.rollup import rollup_across_cell_type_descendants
import json
import tiledb
from backend.wmg.data.ontology_labels import ontology_term_label, ontology_term_id_labels
import pandas as pd
import numpy as np
from backend.wmg.pipeline.integrated_corpus.transform import get_high_level_tissue
from pronto import Ontology

def traverse_with_counting(node):
    """Recursively convert an ontology term and its subclasses into a nested dict.

    A term reachable through several parents is visited once per path. Each visit
    gets a distinct id of the form ``"<term id>__<visit index>"``; the visit index
    comes from the module-level ``traverse_node_counter`` dict, which is updated
    here. Every generated id is also added to the module-level ``all_unique_nodes``
    set.

    Each returned dict holds ``id``, ``name`` (from ``id_to_name``, falling back to
    the term id), and ``n_cells_rollup`` / ``n_cells`` (looked up in the module-level
    ``cell_counts_df_rollup`` / ``cell_counts_df``, 0 when absent, cast to int).
    Terms with direct subclasses also get a ``children`` list built the same way.
    """
    global traverse_node_counter
    global all_unique_nodes

    term_id = node.id
    visit_index = traverse_node_counter.get(term_id, 0)
    traverse_node_counter[term_id] = visit_index + 1
    unique_id = term_id + "__" + str(visit_index)
    all_unique_nodes.add(unique_id)

    direct_subclasses = list(node.subclasses(with_self=False, distance=1))

    # Descend first so descendants' visit counters advance in the same order as before.
    sub_entries = [traverse_with_counting(sub) for sub in direct_subclasses]

    def lookup_count(counts):
        # Terms missing from the count table contribute zero cells.
        return int(counts[term_id] if term_id in counts else 0)

    entry = {
        "id": unique_id,
        "name": id_to_name[term_id] if term_id in id_to_name else term_id,
        "n_cells_rollup": lookup_count(cell_counts_df_rollup),
        "n_cells": lookup_count(cell_counts_df),
    }
    if sub_entries:
        entry["children"] = sub_entries
    return entry


def dfs(parents, end, start, node=None, path=None, all_paths=None):
    """Collect every upward path from ``end`` to ``start`` in a child -> parents map.

    Walks ``parents`` (mapping node id -> iterable of parent ids) depth-first from
    ``end``. Each complete path is appended to ``all_paths`` as a list ordered from
    ``end`` to ``start``.

    Args:
        parents: mapping of node id to its parent ids.
        end: node the walk starts from.
        start: node at which a path counts as complete.
        node, path: recursion state; leave as ``None`` on the initial call.
        all_paths: list that receives the found paths. When omitted, a fresh list
            is created so separate calls never share state.

    Returns:
        ``path`` when the current node equals ``start``; this is used by the
        recursion, and at the top level only happens when ``end == start``, in which
        case ``[end]`` is returned and nothing is appended. Otherwise ``None``; the
        results are delivered through ``all_paths``.
    """
    if all_paths is None:
        # Avoid a shared mutable default that would accumulate paths across calls.
        all_paths = []
    if path is None and node is None:
        path = [end]
        node = end

    if node == start:
        return path

    for parent in parents.get(node, []):
        full_path = dfs(parents, end, start, node=parent, path=path + [parent], all_paths=all_paths)
        if full_path:
            all_paths.append(full_path)
            
def truncate_graph(graph, valid_nodes):
    """Prune ``graph`` in place down to nodes whose ids are in ``valid_nodes``.

    Returns ``False`` and leaves the node untouched when ``graph['id']`` is not in
    ``valid_nodes``. Otherwise prunes the subtree and returns ``True``.

    Pruned children with a non-empty id are collected into one placeholder child
    (``id == ""``) that lists them in ``invalid_children_ids``. The placeholder is
    added only when at least one sibling survives. Pruned children with an empty id
    are dropped without being listed. If no child survives, the ``children`` key is
    removed.
    """
    if graph['id'] not in valid_nodes:
        return False

    kept = []
    dropped_ids = []
    for child in graph.get("children", []):
        if truncate_graph(child, valid_nodes):
            kept.append(child)
        elif child['id'] != '':
            dropped_ids.append(child['id'])

    if kept and dropped_ids:
        kept.append(
            {"id": "",
             "name": "",
             "n_cells_rollup": 0,
             "n_cells": 0,
             "invalid_children_ids": dropped_ids,
             "parent": graph['id']}
        )

    if kept:
        graph['children'] = kept
    elif 'children' in graph:
        del graph['children']
    return True


def truncate_graph_per_tissue(graph, valid_nodes, total_count, tissue_cell_counts, depth=0):
    """Prune ``graph`` in place to the nodes relevant for a single tissue.

    A child is kept only when all of the following hold:
      * its id is in ``valid_nodes``;
      * its id has not already been kept elsewhere in this tree. Kept ids are
        tracked in the module-level ``seen_nodes_per_tissue`` set, which the caller
        must initialise and which this function updates;
      * when ``depth == 1``, its tissue rollup count
        (``tissue_cell_counts[base_id]['n_cells_rollup']``, 0 if absent) is not
        below 0.1% of ``total_count``.

    Rejected children are listed in a placeholder child (``id == ""``) when at least
    one child is kept. When none is kept, the ``children`` key is removed. All
    children of a node are filtered first; only then is each kept child processed
    recursively with ``depth + 1``.
    """
    global seen_nodes_per_tissue

    children = graph.get('children', [])
    if not children:
        return

    kept = []
    rejected_ids = []
    for child in children:
        child_id = child['id']
        too_small = False
        if depth == 1:
            counts = tissue_cell_counts.get(child_id.split('__')[0], {'n_cells_rollup': 0})
            too_small = (counts['n_cells_rollup'] / total_count * 100) < 0.1
        if child_id in valid_nodes and child_id not in seen_nodes_per_tissue and not too_small:
            kept.append(child)
            seen_nodes_per_tissue.add(child_id)
        else:
            rejected_ids.append(child_id)

    if not kept:
        del graph['children']
        return

    if rejected_ids:
        kept.append(
            {"id": "",
             "name": "",
             "n_cells_rollup": 0,
             "n_cells": 0,
             "invalid_children_ids": rejected_ids,
             "parent": graph['id']}
        )
    graph['children'] = kept

    for child in kept:
        if child['id'] != '':
            truncate_graph_per_tissue(child, valid_nodes, total_count, tissue_cell_counts, depth=depth + 1)

def truncate_graph2(graph, visited_nodes_in_paths):
    """Collapse repeated occurrences of a term so its children are expanded only once.

    Walks ``graph`` in pre-order and mutates it in place. The first node (by base
    term id, i.e. the part before ``"__"``) that has a ``children`` key is recorded
    in the module-level ``nodesWithChildrenFound`` set and left as is. For every
    later node whose base id is already recorded:
      * only children whose id is in ``visited_nodes_in_paths`` are kept;
      * the ids of the other non-placeholder children are listed in a placeholder
        child (``id == ""``), appended only if at least one child was kept;
      * if no child is kept, ``children`` is removed.

    The caller must initialise ``nodesWithChildrenFound`` (e.g. to ``set()``) before
    the first call. Returns ``None``.
    """
    # Every term should show its children only once: a later occurrence of a term
    # that already showed children loses them, EXCEPT children that lie on a path
    # to the target (ids in visited_nodes_in_paths). Those are kept and the
    # remaining children are collapsed into a placeholder.
    global nodesWithChildrenFound
    if graph['id'].split("__")[0] in nodesWithChildrenFound:
        if 'children' in graph:
            children = graph['children']            
            new_children = []
            invalid_children_ids = []
            # NOTE(review): an existing placeholder child (id == '') is neither kept
            # nor recorded below, so any invalid_children_ids it carried are lost —
            # confirm this is intended.
            for child in children:
                if child['id'] in visited_nodes_in_paths:
                    new_children.append(child)
                elif child['id'] != '':
                    invalid_children_ids.append(child['id'])
                    
            if len(children) > len(new_children) and len(new_children) > 0:
                # Some children were dropped: add a placeholder listing the dropped ids.
                new_children.append(
                    {"id": "",
                    "name": "",
                    "n_cells_rollup": 0,
                    "n_cells": 0,
                     "invalid_children_ids": invalid_children_ids,
                     "parent": graph['id']
                    }        
                )
            if len(new_children) > 0:
                graph['children'] = new_children
            else:
                del graph['children']
    elif 'children' in graph:
        # First occurrence of this term that shows children: remember it, keep them.
        nodesWithChildrenFound.add(graph['id'].split("__")[0])
    
    
    # Recurse into the (possibly pruned) children, skipping placeholders.
    children = graph.get("children",[])
    for child in children:
        if child['id'] != "":
            truncate_graph2(child, visited_nodes_in_paths)    

def prune_node_distinguishers(graph):
    """Strip the ``"__<n>"`` occurrence suffix from every id in ``graph``, in place.

    For example ``"CL:0000548__0"`` becomes ``"CL:0000548"``; ids without a suffix
    are unchanged. Nodes are visited in pre-order through their ``children`` lists.
    Returns ``None``.
    """
    pending = [graph]
    while pending:
        current = pending.pop()
        current['id'] = current['id'].split('__')[0]
        # Push children reversed so they are popped in their original order.
        pending.extend(reversed(current.get('children', [])))

def delete_unknown_terms(graph):
    """Recursively remove children whose ``name`` starts with ``'CL:'``, in place.

    Such names presumably mark terms that had no label and fell back to their raw
    id. A removed child takes its whole subtree with it. When nothing is left, the
    ``children`` key is deleted. Surviving children are then processed the same way.
    """
    labelled = [child for child in graph.get('children', []) if not child['name'].startswith('CL:')]

    if labelled:
        graph['children'] = labelled
    elif 'children' in graph:
        del graph['children']

    for child in labelled:
        delete_unknown_terms(child)
        
def truncate_graph_one_target(graph, target):
    """Collapse every occurrence of ``target``'s term after the target itself.

    Walks ``graph`` in pre-order and mutates it in place. Once the node whose id
    equals ``target`` has been visited, the module-level ``targetFound`` flag is set.
    Any later node with the same base term id (the part before ``"__"``) then loses
    its ``children``.

    The caller must set ``targetFound = False`` at module level before the first
    call; the flag is left ``True`` once the target is found.
    """
    global targetFound
    if targetFound and graph['id'].split("__")[0] == target.split("__")[0]:
        # Leaf occurrences have no "children" key; pop() avoids the KeyError that
        # ``del graph['children']`` raised for them.
        graph.pop('children', None)
    elif graph['id'] == target:
        targetFound = True

    for child in graph.get("children", []):
        truncate_graph_one_target(child, target)

def build_children(graph):
    """Record, for every node of ``graph``, the ids of its direct children.

    Fills the module-level ``all_children`` dict (node id -> list of child ids,
    empty for leaves) in pre-order. Returns ``None``.
    """
    global all_children
    kids = graph.get('children', [])
    all_children[graph['id']] = [kid['id'] for kid in kids]
    for kid in kids:
        build_children(kid)

def build_parents(graph):
    """Record, for every non-root node of ``graph``, its parent id.

    Fills the module-level ``all_parents`` dict (child id -> ``[parent id]``, a
    one-element list) in pre-order. The root itself gets no entry. Returns ``None``.
    """
    global all_parents
    for child in graph.get('children', []):
        child_id = child['id']
        all_parents[child_id] = [graph['id']]
        build_parents(child)
        
def getExpandedData(graph):
    """Append to ``isExpandedNodes`` the id of every node that has a ``children`` key.

    ``isExpandedNodes`` is a module-level list the caller must initialise. The
    traversal is pre-order and only descends through nodes that have children.
    Returns ``None``.
    """
    global isExpandedNodes
    if 'children' not in graph:
        return
    isExpandedNodes.append(graph['id'])
    for child in graph['children']:
        getExpandedData(child)
                
        
def getShownData(graph):
    """Collect the hidden-children lists carried by placeholder nodes.

    Walks ``graph`` recursively. A child with ``id == ""`` is a placeholder. If its
    ``invalid_children_ids`` list is non-empty, ``{parent id: deduplicated ids}`` is
    appended to the module-level ``notShownWhenExpandedNodes`` list; the order of the
    ids is unspecified. Every other child is visited recursively. Returns ``None``.
    """
    global notShownWhenExpandedNodes

    for child in graph.get('children', []):
        if child['id'] != "":
            getShownData(child)
            continue
        hidden_ids = child["invalid_children_ids"]
        if len(hidden_ids) > 0:
            notShownWhenExpandedNodes.append({child['parent']: list(set(hidden_ids))})
        
def _to_dict(a, b):
    """
    convert a flat key array (a) and a value array (b) into a dictionary with values grouped by keys
    """
    a = np.array(a)
    b = np.array(b)
    idx = np.argsort(a)
    a = a[idx]
    b = b[idx]
    bounds = np.where(a[:-1] != a[1:])[0] + 1
    bounds = np.append(np.append(0, bounds), a.size)
    bounds_left = bounds[:-1]
    bounds_right = bounds[1:]
    slists = [b[bounds_left[i] : bounds_right[i]] for i in range(bounds_left.size)]
    d = dict(zip(np.unique(a), [list(set(x)) for x in slists]))
    return d

# Build ontology tree JSON

In [None]:
# Load the latest Cell Ontology release (cl-basic.obo, downloaded from GitHub)
# and index its non-obsolete "CL:" terms.
ontology = Ontology("https://github.com/obophenotype/cell-ontology/releases/latest/download/cl-basic.obo")


all_cell_types = []  # [{"label": term name, "id": term id}, ...]
classes = [i for i in ontology if i.startswith('CL:')]  # ontology keys with the 'CL:' prefix
all_cell_type_owl_descriptions = {}  # term id -> definition text ('' when there is none)
id_to_name = {}  # term id -> term name
for c in classes :
    c = ontology[c]
    # NOTE(review): presumably redundant, since `classes` is already filtered on the
    # 'CL:' prefix — confirm c.id always equals the key used above.
    if not c.id.startswith("CL:"):
        continue
    if c.obsolete :
        continue
    all_cell_types.append(
    {
        "label": c.name,
        "id": c.id
    }
    )
    id_to_name[c.id] = c.name
    
    # str(None) == 'None', so a term without a definition is stored as ''.
    all_cell_type_owl_descriptions[c.id] = str(c.definition) if str(c.definition) != 'None' else ''

In [None]:
# Read the cell-count table from the local TileDB snapshot. It has (at least) the
# columns tissue_ontology_term_id, cell_type_ontology_term_id and n_cells.
# The `with` block closes the array once its contents are in a DataFrame.
with tiledb.open('prod-snapshot/cell_counts') as X:
    cc = X.df[:]
# Total cells per cell type across all tissues (indexed by cell type id).
cell_counts_df = cc.groupby('cell_type_ontology_term_id').sum(numeric_only=True)[['n_cells']]
# tissue id -> list of distinct cell type ids that occur with that tissue.
uberon_by_celltype = _to_dict(cc['tissue_ontology_term_id'].values,cc['cell_type_ontology_term_id'].values)

In [None]:
# Move cell_type_ontology_term_id from the index back into a regular column.
cell_counts_df=cell_counts_df.reset_index()

In [None]:
# Give every non-obsolete ontology cell type a row, with 0 cells for types absent
# from the snapshot, then roll counts up via the project helper.
all_cell_types_ids = [i["id"] for i in all_cell_types]
# Build the membership set once instead of re-reading the column and scanning it
# for every id.
observed_ids = set(cell_counts_df['cell_type_ontology_term_id'].values)
to_attach = pd.DataFrame()
to_attach['cell_type_ontology_term_id'] = [i for i in all_cell_types_ids if i not in observed_ids]
to_attach['n_cells'] = 0

cell_counts_df = pd.concat([cell_counts_df, to_attach], axis=0)
# Both results are Series of n_cells indexed by cell type id.
cell_counts_df_rollup = rollup_across_cell_type_descendants(cell_counts_df).set_index('cell_type_ontology_term_id')['n_cells']
cell_counts_df = cell_counts_df.set_index('cell_type_ontology_term_id')['n_cells']

In [None]:
# Expand the ontology below CL:0000548 into a nested-dict tree. A term reached
# through several parent chains appears once per chain, with ids suffixed
# "__0", "__1", ...
root_node = ontology['CL:0000548']

traverse_node_counter = {}  # term id -> number of times it appears in the tree
all_unique_nodes = set()
a = traverse_with_counting(root_node) 
all_unique_nodes = list(all_unique_nodes)
# Highest number of occurrences of any single term in the tree.
print(max(traverse_node_counter.values()))

In [None]:
# Index the tree by unique node id: node -> child ids, and child -> [parent id].
all_children={}
all_parents={}    
build_children(a)
build_parents(a) 

In [None]:
# For every cell type, precompute the tree state shown when that cell type is the
# target: which nodes are expanded and which children are hidden under each node.
start_node = 'CL:0000548__0'

all_states_per_cell_type = {}
for idx, end_node in enumerate(all_cell_types_ids):
    if idx % 100 == 0:
        print(idx)  # progress
    if end_node not in traverse_node_counter:
        # Never reached while building the tree, so there is nothing to show.
        continue

    # Every occurrence of the target in the tree: "<id>__0", "<id>__1", ...
    occurrence_ids = [end_node + "__" + str(k) for k in range(traverse_node_counter[end_node])]

    # One root-to-target path per occurrence: the first path dfs reports, reversed
    # to run root -> target, or just the node itself when no path is found.
    all_paths = []
    for occurrence_id in occurrence_ids:
        paths = []
        dfs(all_parents, occurrence_id, start_node, all_paths=paths)
        if paths:
            all_paths.append(paths[0][::-1])
        else:
            all_paths.append([occurrence_id])

    ### RULES ###
    # 1. We only want to show terms that are CHILDREN, GRANDCHILDREN, SIBLINGS OF TARGET, or IN A PATH TO TARGET
    # Sets are used because the truncation helpers only test membership.
    visited_nodes_in_paths = {node_id for path in all_paths for node_id in path}

    children1 = all_children.get(end_node + "__0", [])  # children
    children2 = [grandchild for child in children1 for grandchild in all_children.get(child, [])]  # grandchildren
    # Children of any parent of any occurrence (this includes the occurrences themselves).
    siblings = {
        sibling
        for occurrence_id in occurrence_ids
        for parent in all_parents.get(occurrence_id, [])
        for sibling in all_children.get(parent, [])
    }

    valid_nodes = visited_nodes_in_paths | set(children1) | set(children2) | siblings

    a_copy = json.loads(json.dumps(a))  # deep copy of the full tree
    truncate_graph(a_copy, valid_nodes)

    nodesWithChildrenFound = set()  # state read/updated by truncate_graph2
    truncate_graph2(a_copy, visited_nodes_in_paths)
    delete_unknown_terms(a_copy)

    # now, given this graph, populate what you need - specifically, we need "notShownWhenExpanded" and "isExpanded"
    notShownWhenExpandedNodes = []
    isExpandedNodes = []

    getExpandedData(a_copy)
    getShownData(a_copy)

    # Each parent must contribute at most one hidden-children entry.
    assert len({next(iter(entry)) for entry in notShownWhenExpandedNodes}) == len(notShownWhenExpandedNodes)

    notShownWhenExpanded = {}
    for entry in notShownWhenExpandedNodes:
        notShownWhenExpanded.update(entry)

    all_states_per_cell_type[end_node] = {'isExpandedNodes': list(set(isExpandedNodes)), 'notShownWhenExpandedNodes': notShownWhenExpanded}

In [None]:
# Keep only cell types whose rolled-up cell count (from cell_counts_df_rollup) is positive.
all_cell_types_final = [ct for ct in all_cell_types if cell_counts_df_rollup[ct['id']] > 0]

In [None]:
# Load the UBERON anatomy ontology (downloaded from the OBO PURL) to resolve tissue terms.
uberon = Ontology("http://purl.obolibrary.org/obo/uberon.obo")

In [None]:
# UBERON terms used by the per-tissue loop: a tissue with any of these among its
# superclasses is exempt from the immune-cell (CL:0000988) filter.
hemolymphoid_system = 'UBERON:0002193'
hematopoietic_system = 'UBERON:0002390'
blood = 'UBERON:0000178'
immune_organ = 'UBERON:0005057'
blacklist = [hemolymphoid_system, hematopoietic_system, blood, immune_organ]

In [None]:
# For every tissue, precompute the tree state shown when that tissue is selected,
# together with the per-cell-type counts for the tissue.
start_node = 'CL:0000548__0'
tissue_counts = cc.groupby('tissue_ontology_term_id').sum(numeric_only=True)['n_cells']
# Per-(tissue, cell type) totals do not depend on the loop variable; compute them
# once instead of regrouping the whole table for every tissue.
tissue_ct_all = cc.groupby(['tissue_ontology_term_id','cell_type_ontology_term_id']).sum(numeric_only=True).reset_index()

all_states_per_tissue = {}
tissue_by_cell_type = []
for tissue in uberon_by_celltype:
    if " (" in tissue:
        continue
    print(tissue)
    tissueId = tissue
    tissue_term = uberon[tissueId]
    tissue_label = tissue_term.name

    end_nodes = uberon_by_celltype[tissue]
    # Unless the tissue lies under one of the `blacklist` terms, drop cell types that
    # have CL:0000988 among their superclasses, provided some cell types remain.
    uberon_ancestors = [i.id for i in tissue_term.superclasses()]
    if len(list(set(blacklist).intersection(uberon_ancestors))) == 0:
        end_nodes2 = [e for e in end_nodes if 'CL:0000988' not in [i.id for i in ontology[e].superclasses()]]
        if len(end_nodes2) == 0:
            print("Not filtering out immune cell for", tissue_label)
        else:
            end_nodes = end_nodes2
    else:
        print("Not filtering out immune cell for", tissue_label)

    tissue_ct_df = tissue_ct_all[tissue_ct_all['tissue_ontology_term_id'] == tissue]
    df = tissue_ct_df[['cell_type_ontology_term_id','n_cells']]

    # Every ontology cell type gets a row (0 cells if unseen in this tissue).
    present_ids = set(df['cell_type_ontology_term_id'].values)
    to_attach = pd.DataFrame()
    to_attach['cell_type_ontology_term_id'] = [i for i in all_cell_types_ids if i not in present_ids]
    to_attach['n_cells'] = 0

    df = pd.concat([df, to_attach], axis=0)
    df['n_cells_rollup'] = df['n_cells']
    df_rollup = rollup_across_cell_type_descendants(df, ignore_cols=['n_cells'])
    df_rollup = df_rollup[df_rollup['n_cells_rollup'] > 0]

    # cell type id -> {"n_cells": ..., "n_cells_rollup": ...} for this tissue
    res = dict(zip(df_rollup['cell_type_ontology_term_id'], df_rollup[['n_cells','n_cells_rollup']].to_dict(orient='records')))

    tissue_by_cell_type.append({"id": tissue, "label": tissue_label})

    # One root-to-node path per cell type in the tissue, using only the first
    # occurrence ("__0") of each node.
    all_paths = []
    for end_node in end_nodes:
        if end_node not in traverse_node_counter:
            continue
        first_id = end_node + "__0"
        paths = []
        dfs(all_parents, first_id, start_node, all_paths=paths)
        if paths:
            all_paths.append(paths[0][::-1])
        else:
            all_paths.append([first_id])

    ### RULES ###
    # Only nodes on a path to a cell type present in this tissue are shown.
    # Sets are used because the truncation helper only tests membership.
    visited_nodes_in_paths = {node_id for path in all_paths for node_id in path}
    valid_nodes = set(visited_nodes_in_paths)

    a_copy = json.loads(json.dumps(a))  # deep copy of the full tree
    seen_nodes_per_tissue = set()  # state read/updated by truncate_graph_per_tissue
    truncate_graph_per_tissue(a_copy, valid_nodes, tissue_counts[tissue], res)

    delete_unknown_terms(a_copy)

    # now, given this graph, populate what you need - specifically, we need "notShownWhenExpanded" and "isExpanded"
    notShownWhenExpandedNodes = []
    isExpandedNodes = []

    getExpandedData(a_copy)
    getShownData(a_copy)

    # Each parent must contribute at most one hidden-children entry.
    assert len({next(iter(entry)) for entry in notShownWhenExpandedNodes}) == len(notShownWhenExpandedNodes)

    notShownWhenExpanded = {}
    for entry in notShownWhenExpandedNodes:
        notShownWhenExpanded.update(entry)

    all_states_per_tissue[tissue] = {'isExpandedNodes': list(set(isExpandedNodes)), 'notShownWhenExpandedNodes': notShownWhenExpanded, "tissueCounts": res}

In [None]:
# Remove, in place, children of the full tree whose name starts with 'CL:' (unlabelled terms).
delete_unknown_terms(a)