# Auspice Trimming

This notebook demonstrates how to load an auspice v2 JSON, trim its tree down to a subset of samples, and write the result to a new JSON file.

In [None]:
import json
from enum import Enum
import copy

In [None]:
class NodeType(Enum):
    """Classification assigned to each tree node when it is loaded."""

    # Internal node: had at least one child when it was loaded.
    NODE = 0
    # Terminal node: had no children when it was loaded.
    LEAF = 1

class Node:
    """A single node of an auspice tree, linked to its parent and children.

    Attributes:
        parent: The parent ``Node``, or ``None`` for a root / detached node.
        children: List of child ``Node`` objects (empty for a leaf).
        branch_attrs: Dict of branch attributes taken from the source dict.
        node_attrs: Dict of node attributes taken from the source dict.
        name: The node's name.
        type: ``NodeType.NODE`` or ``NodeType.LEAF``; assigned by
            ``from_dict`` and not recomputed when ``children`` changes later.
    """

    def __init__(self, data_dict=None):
        """Create a blank node, or populate it from ``data_dict`` if given."""
        self.parent = None
        self.children = []
        self.branch_attrs = {}
        self.node_attrs = {}
        self.name = None
        self.type = None

        # A missing or empty dict leaves the node blank.
        if data_dict:
            self.from_dict(data_dict)

    def to_dict(self):
        """Serialise this node and its whole subtree to nested dicts.

        The ``'children'`` key is only emitted when the node has children.
        """
        d = {
            'branch_attrs': self.branch_attrs,
            'node_attrs': self.node_attrs,
            'name': self.name,
        }
        if self.children:
            d['children'] = [child.to_dict() for child in self.children]
        return d

    def from_dict(self, d):
        """Populate this node, recursively, from a nested node dict.

        ``d`` must contain ``'name'``. ``'branch_attrs'``, ``'node_attrs'``
        and ``'children'`` are optional and default to empty. The attribute
        dicts are stored by reference, not copied, so later ``set_attr``
        calls also modify the dicts inside ``d``.
        """
        self.branch_attrs = d.get('branch_attrs', {})
        self.node_attrs = d.get('node_attrs', {})
        self.name = d['name']

        self.children = [Node(c) for c in d.get('children') or []]
        for c in self.children:
            c.parent = self

        self.type = NodeType.NODE if self.children else NodeType.LEAF

    def descendents(self):
        """Return every node below this one.

        Order: direct children first, then each child's descendants in turn.
        """
        return self.children + [node for c in self.children for node in c.descendents()]

    @staticmethod
    def _unwrap(value):
        """Return ``value['value']`` for a ``{'value': ...}`` wrapper, else ``value``."""
        # Some attributes are dicts without a 'value' key (e.g. 'mutations');
        # those are returned unchanged instead of raising KeyError.
        if isinstance(value, dict) and 'value' in value:
            return value['value']
        return value

    def get_attr(self, attr):
        """Look up ``attr`` in ``branch_attrs`` first, then ``node_attrs``.

        Returns the unwrapped value (see ``_unwrap``), or ``None`` when the
        attribute is present in neither dict.
        """
        for attrs in (self.branch_attrs, self.node_attrs):
            if attr in attrs:
                return self._unwrap(attrs[attr])
        return None

    def set_attr(self, attr, value, attr_type='node'):
        """Store ``value`` wrapped as ``{'value': value}``.

        Args:
            attr: Attribute name.
            value: Value to store; ``None`` is ignored and nothing is written.
            attr_type: ``'node'`` writes to ``node_attrs``; any other value
                writes to ``branch_attrs``.
        """
        if value is None:
            return
        target = self.node_attrs if attr_type == 'node' else self.branch_attrs
        target[attr] = {'value': value}
    
class Tree:
    """A rooted tree of ``Node`` objects plus a flat list of its nodes.

    Attributes:
        root: The root ``Node``, or ``None`` until a tree is loaded.
        nodes: Flat list of the nodes currently considered part of the tree.
    """

    def __init__(self, data_dict=None):
        """Create an empty tree, or load it from ``data_dict`` if given."""
        self.root = None
        self.nodes = []
        if data_dict:
            self.from_dict(data_dict)

    def to_dict(self):
        """Serialise the tree by delegating to the root node's ``to_dict``."""
        return self.root.to_dict()

    def from_dict(self, data_dict):
        """Build the node structure from a nested node dict and index all nodes."""
        self.root = Node(data_dict)
        self.nodes = [self.root] + self.root.descendents()

    def set_node_attr(self, attr, state):
        """Set ``node_attrs[attr] = state`` on every node in ``self.nodes``.

        ``state`` is stored as-is (no ``{'value': ...}`` wrapping) and the
        same object is shared by all nodes.
        """
        for node in self.nodes:
            # Write to each node's dict; Tree itself has no node_attrs.
            node.node_attrs[attr] = state

    def subset_tree(self, nodes_to_keep):
        """Restrict the tree to ``nodes_to_keep``.

        Every current node has its ``children`` filtered down to the kept
        nodes, and ``self.nodes`` is replaced by the kept nodes. ``root`` and
        each node's ``parent``/``type`` are left untouched.

        Args:
            nodes_to_keep: Iterable of nodes to retain (nodes must be hashable).
        """
        # Materialise once so a one-shot iterable is not exhausted by the
        # first membership test, and use a set for O(1) lookups.
        kept = list(nodes_to_keep)
        keep = set(kept)
        for node in self.nodes:
            node.children = [c for c in node.children if c in keep]
        self.nodes = kept

    def trim_terminal_nodes(self):
        """Drop internal nodes that currently have no children.

        Nodes typed ``NodeType.LEAF`` are always kept. This is a single pass:
        a parent left childless by this call is not itself removed.
        """
        nodes_to_keep = [
            node for node in self.nodes
            if node.type == NodeType.LEAF or len(node.children) > 0
        ]
        self.subset_tree(nodes_to_keep)

In [None]:
def walk_to_root(nodes):
    """Return ``nodes`` together with all of their ancestors.

    Each node's ``parent`` chain is followed until a node without a parent is
    reached. An ancestor is added at most once, however many of the input
    nodes descend from it. The input is not modified.

    Args:
        nodes: Iterable of nodes exposing a ``parent`` attribute; nodes must
            be hashable.

    Returns:
        List of nodes in the order they were visited (depth-first, starting
        from the last input node).
    """
    stack = list(nodes)
    # Everything that is on the stack or already emitted; replaces the two
    # linear list scans per node with one O(1) set lookup.
    seen = set(stack)
    done = []
    while stack:
        node = stack.pop()
        parent = node.parent
        if parent is not None and parent not in seen:
            seen.add(parent)
            stack.append(parent)
        done.append(node)
    return done

def walk_down(nodes, mode="steps", depth=1):
    """Expand ``nodes`` with descendants within ``depth`` of them.

    Args:
        nodes: Starting nodes, each exposing a ``children`` list. In
            ``'mutations'`` mode the nodes must be hashable.
        mode: ``'steps'`` measures distance in tree edges; ``'mutations'``
            measures it as the cumulative ``num_mutations`` along the path.
        depth: Maximum distance (inclusive) from a starting node.

    Returns:
        List beginning with the starting nodes, followed by the descendants
        found. In ``'steps'`` mode a descendant reachable from several
        starting nodes appears once per path; in ``'mutations'`` mode each
        node appears once.

    Raises:
        ValueError: If ``mode`` is neither ``'steps'`` nor ``'mutations'``.
    """
    if mode == 'steps':
        levels = [list(nodes)]
        for _ in range(depth):
            levels.append([child for node in levels[-1] for child in node.children])
        return [node for level in levels for node in level]

    if mode == 'mutations':
        # One bucket per cumulative mutation count 0..depth. Each bucket must
        # be its own list: `[[]] * depth` would repeat a single shared list,
        # merging all non-zero distances into one bucket.
        levels = [list(nodes)] + [[] for _ in range(depth)]
        done = list(nodes)
        seen = set(done)
        for i in range(depth + 1):
            level = levels[i]
            # Walk by index: zero-mutation children are appended to the
            # bucket currently being scanned and must be visited too.
            j = 0
            while j < len(level):
                for c in level[j].children:
                    distance = i + num_mutations(c)
                    if distance <= depth and c not in seen:
                        levels[distance].append(c)
                        done.append(c)
                        seen.add(c)
                j += 1
        return done

    raise ValueError("unknown mode: {!r}".format(mode))

In [None]:
def num_mutations(node):
    """Count the nucleotide mutations recorded on the branch leading to ``node``.

    Reads ``node.branch_attrs['mutations']['nuc']`` and returns its length,
    or 0 when either key is absent.
    """
    branch = node.branch_attrs
    if 'mutations' not in branch:
        return 0
    mutations = branch['mutations']
    return len(mutations['nuc']) if 'nuc' in mutations else 0

In [None]:
import pandas as pd
import numpy as np

def get_county(node):
    """Derive a county label for ``node`` from its metadata.

    Rules, in priority order:
      1. Division ``'Grand Princess'`` or location
         ``'Grand Princess cruise ship'`` -> ``'Grand Princess Cruise Ship'``.
      2. Originating lab name containing ``'Santa Clara'`` -> ``'Santa Clara'``.
      3. Division ``'California'`` -> the location with a trailing
         ``' County'`` removed (may be ``None`` if no location is set).
      4. Otherwise ``None``.

    Args:
        node: Object exposing ``get_attr(name)``.

    Returns:
        The county label, or ``None`` when no rule applies.
    """
    division = node.get_attr('division')
    location = node.get_attr('location')
    # Read through get_attr like the other fields so a {'value': ...} wrapper
    # is unwrapped; testing `in` against the raw wrapper dict would compare
    # with its keys and never match the lab name.
    originating_lab = node.get_attr('originating_lab')

    if division == 'Grand Princess' or location == 'Grand Princess cruise ship':
        return 'Grand Princess Cruise Ship'

    if isinstance(originating_lab, str) and 'Santa Clara' in originating_lab:
        return 'Santa Clara'

    if division == 'California':
        suffix = ' County'
        if isinstance(location, str) and location.endswith(suffix):
            return location[:-len(suffix)]
        return location

    return None

In [None]:
# Load the auspice JSON to trim. The path is specific to the author's machine;
# point it at your own export. JSON text is read explicitly as UTF-8 so the
# result does not depend on the platform's default encoding.
with open("/Users/josh/Desktop/ncov_rr68.json", 'r', encoding='utf-8') as fp:
    js = json.load(fp)

In [None]:
# Build the Node/Tree structure from the 'tree' entry of the loaded JSON.
t = Tree(js['tree'])

In [None]:
# Number of nodes in the tree before any trimming.
len(t.nodes)

In [None]:
# Set County
# Only leaves are annotated. set_attr ignores None, so leaves for which
# get_county finds no county get no 'county' attribute at all.
for node in t.nodes:
    if node.type == NodeType.LEAF:
        county = get_county(node)
        node.set_attr('county', county)

In [None]:
# Seed set: every node annotated with county 'Santa Clara'.
nodes_to_keep = [node for node in t.nodes if 
                 node.get_attr('county') == 'Santa Clara']

In [None]:
# Number of Santa Clara nodes selected as the seed set.
len(nodes_to_keep)

In [None]:
# Get ancestors of the Santa Clara nodes (every node on a path to the root)
nodes_to_keep = walk_to_root(nodes_to_keep)
# Add descendants reachable from those nodes through branches with 0 nucleotide mutations
nodes_to_keep = walk_down(nodes_to_keep, mode='mutations', depth=0)
# Keep only the subtree determined by leaves (i.e., filter out internal nodes with no kept descendants)
nodes_to_keep = walk_to_root([n for n in nodes_to_keep if n.type == NodeType.LEAF])

In [None]:
# Prune the tree in place: children outside nodes_to_keep are detached.
t.subset_tree(nodes_to_keep)

In [None]:
# Number of nodes remaining after trimming.
len(t.nodes)

# Translate Names

In [None]:
# Sample-ID lookup table; the columns used below are CZB_ID, Supplier_ID and
# gisaid_name. The path is specific to the author's machine.
scc_ids = pd.read_csv('/Users/josh/Downloads/scc_sample_ids.csv')

# Maps either a CZB_ID or a gisaid_name to the corresponding Supplier_ID.
local_translator = dict(zip(scc_ids['CZB_ID'], scc_ids['Supplier_ID']))
local_translator.update(dict(zip(scc_ids['gisaid_name'], scc_ids['Supplier_ID'])))

# Maps a CZB_ID to the corresponding gisaid_name.
gisaid_translator = dict(zip(scc_ids['CZB_ID'], scc_ids['gisaid_name']))

In [None]:
# Choose which naming scheme to translate into; swap the two lines below to
# switch between GISAID names and Supplier IDs.
# id_translator = gisaid_translator
id_translator = local_translator

for node in t.nodes:
    # Names containing 'RR0' are cut down to their first two
    # underscore-separated fields before the lookup.
    if 'RR0' in node.name:
        node.name = "_".join(node.name.split('_')[:2])
    # Names with no entry in the translator are left unchanged.
    if node.name in id_translator:
        node.name = id_translator[node.name]

# Add Metadata

In [None]:
# Register 'county' as a categorical coloring, placed first in the list.
# Raises KeyError if js['meta'] has no 'colorings' entry.
county_coloring = {'key': 'county', 'title': 'County', 'type': 'categorical'}
js['meta']['colorings'].insert(0, county_coloring)

In [None]:
# Put 'county' first in the filter list. Re-running this cell inserts a
# duplicate entry, since the list is modified in place.
js['meta']['filters'].insert(0, 'county')

In [None]:
# Maintainer entries (name + URL) to write into the dataset metadata.
maintainers = [{'name': 'Chan Zuckerberg Biohub', 'url': 'https://www.czbiohub.org'},
               {'name': 'Santa Clara DPH', 'url': 'https://www.sccgov.org/sites/phd/Pages/phd.aspx'}]

In [None]:
# Replace any existing maintainers list in the metadata.
js['meta']['maintainers'] = maintainers

In [None]:
# Default view settings: color by county, geographic resolution 'division'.
# Raises KeyError if js['meta'] has no 'display_defaults' entry.
js['meta']['display_defaults']['color_by'] = 'county'
js['meta']['display_defaults']['geo_resolution'] = 'division'

In [None]:
# Read the Markdown description for the dataset. The path is specific to the
# author's machine. The file is read explicitly as UTF-8 so non-ASCII text
# does not depend on the platform's default encoding.
with open('/Users/josh/Desktop/scc_description.md', 'r', encoding='utf-8') as fp:
    description = fp.read()

In [None]:
# Store the Markdown text as the dataset description.
js['meta']['description'] = description

In [None]:
# Write the trimmed dataset: the updated metadata, the original version
# field, and the pruned tree serialised from the Tree object. The output path
# is specific to the author's machine.
with open('/Users/josh/Desktop/scc.json', 'w') as fp:
    json.dump(
        {"meta": js['meta'],
         "version": js['version'],
        "tree": t.to_dict()},
        fp,
        indent=2)

# Misc

- Tree Library: https://github.com/caesar0301/treelib
- Baltic: https://github.com/evogytis/baltic