## Label nodes by clade
In this notebook, I'm going to label nodes in the tree using the sample clade assignments.
Each node has a set of clade labels.
First, I'll do an upward traversal.
Each tip's set is initialized with its own clade.
If len(set) == 1, pass the clade label up to the parent.

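Then I'll do a downward traversal.
If a child has multiple labels, it takes its parent's clade.
(In the code below, conflicts are already resolved during the upward pass by keeping the oldest clade according to the order of `nextstrain_clades`. Every label set therefore ends up with a single clade, and the downward pass, `resolve_labels`, is not currently run.)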

This algorithm will probably be confusing, so let's start with 21H (Mu) and then label the tree based on Pango lineages.

In [7]:
import os
from Bio import Phylo
import pandas as pd
import numpy as np

In [8]:
# Project root. All input/output paths below are relative to it.
# The default is the original author's checkout; set the LONG_DELETIONS_DIR
# environment variable to run the notebook on another machine.
PROJECT_DIR = os.environ.get(
    'LONG_DELETIONS_DIR', '/Users/cwagner2/Work/projects/covid/long-deletions/'
)
os.chdir(PROJECT_DIR)

In [18]:
# Input paths, relative to the current working directory (set by the os.chdir cell).
TREE_PATH = 'usher/trimmed/subset.nwk'
CLADES_PATH = 'usher/trimmed/subset_sample_clades.tsv'

# Subset tree in Newick format, then its tab-separated per-sample clade annotations.
with open(TREE_PATH, 'r') as newick_handle:
    tree = Phylo.read(newick_handle, 'newick')
with open(CLADES_PATH, 'r') as tsv_handle:
    df = pd.read_csv(tsv_handle, sep='\t')

In [19]:
df.head(n=5)  # preview the sample -> clade annotation table

Unnamed: 0,sample,annotation_1,annotation_2
0,England/LIVE-9C50A/2020|2020-03-21,19A,B.15
1,England/PHEC-1E01E/2020|2020-04-03,19A,B.15
2,USA/CA-CZB-11879/2020|MW276507.1|2020-04-08,19A,B.20
3,USA/CA-CZB-1085/2020|MT449676.1|2020-04-08,19A,B.20
4,England/BIRM-61AC3/2020|2020-03-23,19A,B.52


In [61]:
def tabulate_names(tree):
    """Build a name -> clade lookup for every node yielded by ``tree.find_clades()``.

    Side effect: any clade whose ``name`` is falsy (``None`` or ``''``) is
    renamed in place to ``str(index)``, where ``index`` is its position in
    the ``find_clades()`` iteration. Nodes that end up sharing a name
    overwrite each other in the returned dict (last one wins).
    """
    lookup = {}
    for position, node in enumerate(tree.find_clades()):
        if not node.name:
            # Give internal/unnamed nodes a stable key so they can be referenced later.
            node.name = str(position)
        lookup[node.name] = node
    return lookup

def all_parents(tree):
    """Map each child node's name to its parent clade.

    Visits clades via ``tree.find_clades(order="level")`` and records
    ``child.name -> clade`` for every child of each visited clade. Only
    children are recorded, so the root gets no entry of its own. Keys are
    node names, so nodes sharing a name (including ``None``) overwrite each
    other; presumably ``tabulate_names`` is meant to run first to fill in
    missing names.
    """
    return {
        child.name: node
        for node in tree.find_clades(order="level")
        for child in node
    }

# Nextstrain clade names ordered by designation (year, then letter).
# get_mapping turns each entry's position into a rank and get_oldest picks the
# lowest rank as the "oldest" clade, so the order of this list matters.
nextstrain_clades = (
    # 2019
    ['19A', '19B']
    # 2020
    + ['20A', '20B', '20C', '20D', '20E (EU1)', '20F', '20G',
       '20H (Beta,V2)', '20I (Alpha,V1)', '20J (Gamma,V3)']
    # 2021
    + ['21A (Delta)', '21B (Kappa)', '21C (Epsilon)', '21D (Eta)',
       '21E (Theta)', '21F (Iota)', '21G (Lambda)', '21H (Mu)',
       '21I (Delta)', '21J (Delta)', '21K (Omicron)', '21L (Omicron)',
       '21M (Omicron)']
    # 2022
    + ['22A (Omicron)', '22B (Omicron)', '22C (Omicron)', '22D (Omicron)',
       '22E (Omicron)', '22F (Omicron)']
)

def get_mapping(clades):
    """Map each clade name to its position in ``clades``.

    If a name appears more than once, its last position wins.
    """
    rank = {}
    for position, name in enumerate(clades):
        rank[name] = position
    return rank

def get_oldest(lineages, mapping):
    """Return a one-element set holding the lineage with the smallest rank.

    ``mapping`` gives each lineage a rank (e.g. the output of ``get_mapping``);
    the lowest rank is treated as the oldest. On ties, the first lineage in
    iteration order wins.

    Raises
    ------
    KeyError
        If a lineage is missing from ``mapping``.
    ValueError
        If ``lineages`` is empty.
    """
    oldest = min(lineages, key=lambda lineage: mapping[lineage])
    return {oldest}

def label_up(child, parents, labels, mapped):
    """Propagate ``child``'s clade label set up towards the root.

    Starting at ``child``, repeatedly looks up the parent in ``parents`` and
    updates the parent's entry in ``labels``:

    * parent not labelled yet -> it receives the node's label set;
    * node's labels not a subset of the parent's -> the union of both is
      collapsed to a single clade with ``get_oldest``;
    * otherwise the parent already covers the node's labels and the walk stops.

    The walk also stops at a node with no entry in ``parents`` (the root).
    This is a loop rather than recursion because the recursive form used one
    stack frame per tree level and raised RecursionError on deep trees.

    Parameters
    ----------
    child : clade-like
        Node whose label was just set; must already have an entry in ``labels``.
    parents : dict
        Node name -> parent node (as built by ``all_parents``).
    labels : dict
        Node name -> set of clade names; updated in place.
    mapped : dict
        Clade name -> rank, forwarded to ``get_oldest`` (lowest rank = oldest).

    Returns
    -------
    None
    """
    node = child
    while node.name in parents:
        parent = parents[node.name]
        node_labels = labels[node.name]
        if parent.name not in labels:
            # The set object is shared with the child, not copied; label sets
            # are only ever replaced, never mutated in place.
            labels[parent.name] = node_labels
        elif not node_labels.issubset(labels[parent.name]):
            summed = labels[parent.name].union(node_labels)
            labels[parent.name] = get_oldest(summed, mapped)
        else:
            # Parent already covers this node's labels: stop propagating.
            break
        node = parent
            

def tabulate_labels(tree, parents, clades, mapped):
    """Seed every tip with its own clade and push the labels up the tree.

    Parameters
    ----------
    tree : Bio.Phylo tree
        Tree whose terminals (tips) are labelled.
    parents : dict
        Node name -> parent node, as built by ``all_parents``.
    clades : dict
        Tip name -> clade name; every tip name must be a key (KeyError otherwise).
    mapped : dict
        Clade name -> rank, forwarded to ``label_up``.

    Returns
    -------
    dict
        Node name -> set of clade names for every tip and every ancestor reached.
    """
    node_labels = {}
    for tip in tree.get_terminals():
        node_labels[tip.name] = {clades[tip.name]}
        label_up(tip, parents, node_labels, mapped)
    return node_labels

def resolve_labels(tree, labels, parents):
    """Downward pass: give ambiguous nodes their parent's label set.

    Visits nodes in ``tree.find_clades()`` order. Any node whose label set
    holds more than one clade and that has an entry in ``parents`` takes its
    parent's set (the same set object, not a copy). Nodes with at most one
    label, and nodes without a parent entry (e.g. the root), are left as they
    are. Assumes ``find_clades()`` yields parents before their children, so
    resolved labels cascade downwards — TODO confirm for the traversal order used.

    Returns
    -------
    dict
        ``labels`` itself, updated in place.
    """
    for node in tree.find_clades():
        own = labels[node.name]
        if len(own) <= 1 or node.name not in parents:
            continue
        labels[node.name] = labels[parents[node.name].name]
    return labels


In [62]:
# Prepare the labelling inputs: node names, parent links,
# the tip -> clade lookup, and the clade -> rank mapping.
named = tabulate_names(tree)  # runs first: fills in missing node names used as keys below
parents = all_parents(tree)
clades = dict(zip(df['sample'], df['annotation_1']))
mapped = get_mapping(nextstrain_clades)

In [63]:
# Every tip starts with a one-clade set, and label_up collapses any conflict to
# a single clade via get_oldest. Every label set is therefore a singleton, and
# the downward pass (resolve_labels, which only acts on sets with more than one
# clade) would have nothing to resolve, so it is not run here.
labelled = tabulate_labels(tree, parents, clades, mapped)

In [64]:
def to_df(labelled):
    """Build a node -> clade table from the unambiguous labels.

    Only nodes whose label set holds exactly one clade are kept; nodes with
    zero or several labels are dropped.

    Parameters
    ----------
    labelled : dict
        Node name -> set of clade names.

    Returns
    -------
    pandas.DataFrame
        Columns ``node`` and ``clade``, one row per kept node, in the
        iteration order of ``labelled``.
    """
    single = [
        (node, clade_set)
        for node, clade_set in labelled.items()
        if len(clade_set) == 1
    ]
    return pd.DataFrame({
        'node': [node for node, _ in single],
        'clade': [next(iter(clade_set)) for _, clade_set in single],
    })

In [65]:
# Export the unambiguous node -> clade assignments as a TSV.
new_df = to_df(labelled)
new_df.to_csv(
    'usher/trimmed/subset_nodes_clades.tsv',
    sep='\t',
    index=False,
)