# Visualize phylogenetic trees

Import Python packages:

In [None]:
import collections
import itertools
import math
import os

import Bio.SeqIO

import numpy

import pandas as pd

import ete3

Get variables from `snakemake`:

In [None]:
# Pull inputs and parameters off the `snakemake` object. That object is not
# defined anywhere in this notebook; presumably Snakemake injects it when it
# runs the notebook -- TODO confirm.

# --- input files ---
tree_files = snakemake.input.trees  # tree files, loaded below with `ete3.Tree(..., format=1)`
state_files = snakemake.input.states  # tab-separated state files with Node / Site / State columns
all_csv = snakemake.input.all_csv  # CSV read into `all_df`
alignment = snakemake.input.alignment  # FASTA alignment of the tip sequences
deleted_csv = snakemake.input.deleted_csv  # CSV read into `deleted_df`
comparator_map_csv = snakemake.input.comparator_map  # CSV read into `outgroup_map`

# --- parameters ---
site_offset = snakemake.params.site_offset  # added to every site number when building the state table
progenitors = snakemake.params.progenitors  # candidate progenitor names; same length as tree/state files
outgroups = snakemake.params.outgroups  # outgroup names; each is used as a column of `outgroup_map`
region_of_interest = snakemake.params.region_of_interest  # mapping with 'start' and 'end' keys
cat_colors = snakemake.params.cat_colors  # mapping of category name -> color
subcat_colors = snakemake.params.subcat_colors  # mapping of sub-category name -> color
wuhan_hu_1_add_muts = snakemake.params.wuhan_hu_1_add_muts  # mutation strings such as "C123T" (parsed below)

Read data frames:

In [None]:
# `na_filter=None` is falsy; presumably it is meant as `na_filter=False` so
# that empty strings (e.g. an empty `substitutions` field) stay strings
# instead of becoming NaN -- TODO confirm against the pandas version in use.
all_df = pd.read_csv(all_csv, na_filter=None)

# Deleted sequences; read with the same NA handling as `all_df`.
deleted_df = pd.read_csv(deleted_csv, na_filter=None)

# Has a `site` column plus one column per outgroup (see how it is used when
# `subs_to_outgroup` is built).
outgroup_map = pd.read_csv(comparator_map_csv)

Read the reconstructed states and add the tip states, considering only sites where there are multiple identities.
Now we make a matrix of substitutions of each node relative to each other.
Note that in the general case this could get **really slow**, although for the small-ish alignment of similar sequences used here it's fine:

In [None]:
# There must be exactly one state file and one tree file per progenitor.
assert len(state_files) == len(tree_files) == len(progenitors)

# Map each tip (FASTA record id) to its upper-cased sequence as a list of
# single characters, so that each alignment column becomes a data frame column.
tip_to_seq = {s.id: list(str(s.seq).upper())
              for s in Bio.SeqIO.parse(alignment, 'fasta')}

# Long-format table of tip states with the same Node / Site / State columns
# as the reconstructed-state files. Column labels are 0-based positions, so
# add 1 to get 1-based sites.
tip_states = (
    pd.DataFrame.from_dict(tip_to_seq, orient='index')
    .rename_axis('Node')
    .reset_index()
    .melt(id_vars='Node',
          var_name='Site',
          value_name='State',
          )
    .assign(Site=lambda x: x['Site'] + 1)
    )

# progenitor -> {(node1, node2): "A123T, C456G, ..."}
subs_matrices = {}

for progenitor, state_file in zip(progenitors, state_files):
    internal_states = (
        pd.read_csv(state_file,
                    sep='\t',
                    comment='#',
                    usecols=['Node', 'Site', 'State'])
        )

    # `DataFrame.append` was removed in pandas 2.0; `pd.concat` with default
    # arguments stacks the two frames in exactly the same way.
    states = (
        pd.concat([internal_states, tip_states])
        .assign(Site=lambda x: x['Site'] + site_offset,
                n_states_at_site=lambda x: x.groupby('Site')['State'].transform('nunique'),
                )
        # keep only sites at which more than one state is observed
        .query('n_states_at_site > 1')
        .drop(columns='n_states_at_site')
        )

    states_dict = states.set_index(['Node', 'Site'])['State'].to_dict()

    nodes = states['Node'].unique().tolist()
    sites = sorted(states['Site'].unique())

    # All ordered pairs of nodes are compared, so this is quadratic in the
    # number of nodes times the number of variable sites.
    subs_matrix = {}  # keyed by (parent, descendant)
    for n1, n2 in itertools.product(nodes, nodes):
        subs = []
        for site in sites:
            nt1 = states_dict[(n1, site)]
            nt2 = states_dict[(n2, site)]
            if nt1 != nt2:
                # only report differences between two unambiguous nucleotides
                if nt1 in {'A', 'C', 'G', 'T'} and nt2 in {'A', 'C', 'G', 'T'}:
                    subs.append(f"{nt1}{site}{nt2}")
        subs_matrix[(n1, n2)] = ', '.join(subs)
    subs_matrices[progenitor] = subs_matrix

Get annotations of which mutations are to an outgroup:

In [None]:
# For each outgroup, collect the set of "<site><identity>" strings built by
# concatenating the `site` column (as text) with that outgroup's column.
subs_to_outgroup = {
    outgroup: set(outgroup_map['site'].astype(str) + outgroup_map[outgroup])
    for outgroup in outgroups
}

List information on the possible progenitors:

In [None]:
# Allow long cell contents to be shown in full when the table is displayed.
pd.set_option('display.max_colwidth', 1000)

# One row per (progenitor sequence, strain): the `all_strains_dates` field is
# a ', '-separated list of entries of the form "strain (date)", which is
# exploded and then split into separate `strain` and `date` columns.
progenitor_info = (
    all_df
    .query('representative_strain in @progenitors')
    .assign(all_strains_dates=lambda x: x['all_strains_dates'].str.split(', '))
    [['substitutions', 'representative_strain', 'nstrains', 'all_strains_dates']]
    .set_index(['substitutions', 'representative_strain', 'nstrains'])
    .explode('all_strains_dates')
    # Raw strings are needed here: `'\('` in a normal string literal is an
    # invalid escape sequence that newer Python versions warn about.
    .assign(strain=lambda x: x['all_strains_dates'].str.split(r' \(').str[0],
            # the date is what follows " (" with the closing ")" stripped off
            date=lambda x: x['all_strains_dates'].str.split(r' \(').str[1].str[: -1],
            )
    .drop(columns='all_strains_dates')
    )

display(progenitor_info)

Set up labels on possible progenitors:

In [None]:
# strain -> date and strain -> substitution string for the possible progenitors
progenitor_dates = progenitor_info.set_index('strain')['date'].to_dict()
progenitor_subs = progenitor_info.reset_index().set_index('strain')['substitutions'].to_dict()

# site -> (first character, last character) of each entry of
# `wuhan_hu_1_add_muts`, e.g. "C123T" -> {123: ('C', 'T')}
wu_hu_1_subs = {int(s[1: -1]): (s[0], s[-1]) for s in wuhan_hu_1_add_muts}

# progenitor -> [name with date, mutations relative to proCoV2,
#                mutations relative to Wuhan-Hu-1]
node_labels = {}
for progenitor in progenitors:
    subs = [s for s in progenitor_subs[progenitor].split(',') if s]

    # Re-express the progenitor's substitutions using `wu_hu_1_subs`.
    subs_wu_1 = []
    sites_added = set()
    for s in subs:
        site, wt, mut = int(s[1: -1]), s[0], s[-1]
        if site not in wu_hu_1_subs:
            # site not listed in `wuhan_hu_1_add_muts`: keep as is
            subs_wu_1.append((site, wt, mut))
        else:
            # site is listed: use the listed mutant identity instead and keep
            # the entry only if that still differs from `wt`
            mut = wu_hu_1_subs[site][-1]
            if mut != wt:
                subs_wu_1.append((site, wt, mut))
        sites_added.add(site)
    # listed sites at which the progenitor has no substitution of its own
    for site, (wt, mut) in wu_hu_1_subs.items():
        if site not in sites_added:
            subs_wu_1.append((site, wt, mut))
    subs_wu_1 = [f"{wt}{site}{mut}" for site, wt, mut in sorted(subs_wu_1)]

    node_labels[progenitor] = [
            progenitor.replace('hCoV-19/', '') + f" ({progenitor_dates[progenitor]})",
            'mutations from proCoV2 (Kumar et al): ' + (', '.join(subs) if subs else 'none'),
            # Decide on "none" from `subs_wu_1` itself rather than from `subs`:
            # the two lists can be empty independently of each other.
            'mutations from Wuhan-Hu-1: ' + (', '.join(subs_wu_1) if subs_wu_1 else 'none'),
            ]

This next line enables Jupyter notebook rendering of trees in `ete3`:

In [None]:
# Select Qt's 'offscreen' platform via the environment so that `ete3` trees
# can be rendered inline in the Jupyter notebook (presumably because no
# display is available where the notebook runs -- TODO confirm).
os.environ['QT_QPA_PLATFORM'] = 'offscreen'

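Add the deleted sequences at all "plausible" locations: each group of deleted sequences is distributed across the sequences it is compatible with (identical substitutions within the region of interest), in proportion to the number of strains each compatible sequence already has. The sub-categories are used as the category labels: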

In [None]:
# Bounds (inclusive) of the region covered by the deleted sequences.
start, end = region_of_interest['start'], region_of_interest['end']

def is_compatible(all_subs, subs):
    """Return whether `all_subs` matches `subs` within the region of interest.

    `all_subs` is a comma-separated substitution string (e.g. "C123T,A456G")
    and `subs` is a set of substitution strings. Substitutions of `all_subs`
    whose site lies outside `start`..`end` are ignored; the remaining set must
    be exactly equal to `subs`.
    """
    all_subs = {s for s in all_subs.split(',')
                if s and start <= int(s[1 : -1]) <= end}
    return subs == all_subs

# Start from `all_df` with a zeroed `deleted` column, and strip the `subcat_`
# prefix from the sub-category columns. `split('_', 1)` keeps the whole
# remainder, so names that themselves contain underscores are not truncated.
all_plus_deleted_df = (
    all_df
    .assign(deleted=0)
    .rename(columns={c: c.split('_', 1)[1] if c.startswith('subcat_') else c
                     for c in all_df.columns})
    )
for tup in deleted_df.itertuples():
    subs_str, n = tup.substitutions, tup.nstrains
    subs = {s for s in subs_str.split(',') if s}
    # flag the sequences whose in-region substitutions equal those of this
    # group of deleted sequences
    all_plus_deleted_df = (
        all_plus_deleted_df
        .assign(compatible=lambda x: x['substitutions'].apply(is_compatible, args=(subs,)))
        )
    compatible_df = all_plus_deleted_df[['substitutions', 'nstrains', 'compatible']].query('compatible')
    n_tot = compatible_df['nstrains'].sum()
    print(f"There are {len(compatible_df)} sequences comprising {n_tot} strains compatible "
          f"with the {n} deleted sequences with the following substitutions: {subs_str}")
    if n_tot == 0:
        # nothing to distribute onto (also avoids dividing by zero below)
        continue
    # spread the `n` deleted strains over the compatible sequences in
    # proportion to each one's share of `n_tot`
    all_plus_deleted_df = (
        all_plus_deleted_df
        .assign(deleted=lambda x: x['deleted'] + x['nstrains'] * n / n_tot * x['compatible'].astype(int))
        )
all_plus_deleted_df = (
    all_plus_deleted_df
    .assign(nstrains=lambda x: x['nstrains'] + x['deleted'])
    .rename(columns={'deleted': 'deleted early Wuhan'})
    # `compatible` only exists if the loop above ran at least once
    .drop(columns='compatible', errors='ignore')
    )

Modify `all_df` to just have categories as columns, and also to split the `Wuhan` category into `Huanan Seafood Market` and `other Wuhan`:

In [None]:
# category -> the sub-categories that replace it
convert_cats_to_subcats = {'Wuhan': ['Huanan Seafood Market', 'other Wuhan']}

# Start from `all_df` and apply every conversion on top of the previous one,
# so that all entries of `convert_cats_to_subcats` take effect (and the frame
# is defined even if the mapping is empty).
all_df_cat_cols = all_df
for cat, subcats in convert_cats_to_subcats.items():
    # the sub-category counts must add up to the category count in every row
    assert all(all_df[f"cat_{cat}"] == all_df[[f"subcat_{subcat}" for subcat in subcats]].sum(axis=1))
    # Drop the category column and expose each of its sub-category columns
    # under the bare sub-category name (the `replace` leaves a name unchanged
    # unless it contains the text 'subcat_').
    all_df_cat_cols = (
        all_df_cat_cols
        .drop(columns=f"cat_{cat}")
        .rename(columns={f"subcat_{subcat}": subcat.replace('subcat_', 'cat_')
                         for subcat in subcats}
                )
        )
    # Rebuild the color mapping, substituting the sub-category colors at the
    # position of the replaced category so the ordering is preserved.
    new_cat_colors = {}
    for key, val in cat_colors.items():
        if key == cat:
            for subcat in subcats:
                new_cat_colors[subcat] = subcat_colors[subcat]
        else:
            new_cat_colors[key] = val
    cat_colors = new_cat_colors

# Strip the `cat_` prefix from the remaining category columns; `split('_', 1)`
# keeps the whole remainder so names containing underscores are not truncated.
all_df_cat_cols = all_df_cat_cols.rename(columns={c: c.split('_', 1)[1] if c.startswith('cat_') else c
                                                  for c in all_df.columns})

Draw the trees:

In [None]:
def get_PieChartFace(node_annotation, nodesizescale, locations):
    """Build an `ete3.PieChartFace` showing a node's breakdown by location.

    `node_annotation` is a mapping holding `'nstrains'` plus one count per key
    of `locations`; `locations` maps location name to slice color. The counts
    must add up to `'nstrains'` (checked to within 1e-3). The face's width and
    height are both `nodesizescale * sqrt(nstrains)`.
    """
    total = node_annotation['nstrains']
    summed = sum(node_annotation[name] for name in locations)
    assert numpy.allclose(total, summed, atol=1e-3), (
        f"nstrains={total!r}\nloc_sum={summed!r}\n"
        f"locations={locations!r}\nnode_annotation={node_annotation!r}"
    )
    size = nodesizescale * math.sqrt(node_annotation['nstrains'])
    # one slice per location, in the iteration order of `locations`
    slice_percents = [100 * node_annotation[name] / total
                      for name, _ in locations.items()]
    slice_colors = [shade for _, shade in locations.items()]
    return ete3.PieChartFace(percents=slice_percents,
                             colors=slice_colors,
                             width=size,
                             height=size,
                             )

def get_pretty_tree(treefile,
                    df,
                    subs_matrix,
                    progenitor,
                    to_outgroup_muts,
                    locations,
                    widthscale=1350,
                    heightscale=1,
                    nodesizescale=35,
                    label_nodes=False,
                    label_fontsize=25,
                    branch_linewidth=4,
                    mut_label_color='black',
                    mut_to_outgroup_label_color='#CC79A7',
                    collapse_below={'C28144T': ['colapsed clade B', 'Wuhan-Hu-1']},
                    node_label_color='#999999',
                    node_labels={},
                    title_fontsize=35,
                    draw_legend=True,
                    ):
    """Returns `(tree, tree_style)`.

    Loads a tree, roots it so that `progenitor` sits at the base, decorates
    the nodes with pie charts and the branches with substitution labels, and
    builds the matching `ete3.TreeStyle`.

    Parameters
    ----------
    treefile
        Tree passed to `ete3.Tree(treefile, format=1)`.
    df
        Data frame with a `representative_strain` column (matched against node
        names), an `nstrains` column, one column per key of `locations`, and
        a `substitutions` column (only read if `label_nodes` is true).
    subs_matrix
        Dict keyed by `(parent_name, node_name)` giving the ', '-separated
        substitutions from the first node to the second.
    progenitor
        Name of the node to place at the root of the drawn tree.
    to_outgroup_muts
        Container of "<site><nucleotide>" strings; a substitution whose site
        and final nucleotide are in it is drawn bold in
        `mut_to_outgroup_label_color`.
    locations
        Dict mapping location (column of `df`) to pie-slice color.
    widthscale, heightscale
        `widthscale` divided by the distance to the farthest node is the tree
        scale; `heightscale` is the vertical margin between branches.
    nodesizescale
        Scale for pie charts and legend symbols.
    label_nodes
        If true, write each node's `substitutions` next to its pie chart.
    label_fontsize, title_fontsize
        Font sizes of branch / node labels and of title / legend text.
    branch_linewidth
        Width of branch lines.
    mut_label_color, node_label_color
        Colors of ordinary substitution labels and of node text labels.
    collapse_below
        Dict mapping a substitution string to a list of text labels. The one
        node whose incoming branch has exactly that substitution string is
        collapsed: its descendants are removed and it gets a single pie chart
        summed over the leaves at or below it, plus the labels.
        NOTE(review): the default label 'colapsed clade B' is misspelled; it
        is left as is because it is text rendered in the figure.
    node_labels
        Dict mapping node name to a list of text labels. For `progenitor` all
        labels go into the title; for other nodes only the first is drawn.
    draw_legend
        Whether to add a legend of `locations` at the bottom left.

    The dict defaults are only read, never modified, by this function.

    Raises
    ------
    ValueError
        If a location is missing from the annotations, or if a key of
        `collapse_below` does not match exactly one node.
    """
    annotations = df.set_index('representative_strain').to_dict(orient='index')
    # sanity check (on the first annotation only) that all locations exist
    for loc in locations:
        if loc not in list(annotations.values())[0]:
            raise ValueError(f"annotations missing {loc}:\n{list(annotations.values())[0]}")

    t = ete3.Tree(treefile, format=1)

    ts = ete3.TreeStyle()
    ts.show_leaf_name = False  # add tip names manually

    # style shared by all regular nodes: black branches, no node marker
    nstyle_dict = {'hz_line_width': branch_linewidth,
                   'vt_line_width': branch_linewidth,
                   'hz_line_color': 'black',
                   'vt_line_color': 'black',
                   'size': 0}

    # label nodes
    for n in t.traverse():
        if n != t:
            # substitutions on the branch leading to this node, one text face
            # per substitution laid out in consecutive columns
            subs = subs_matrix[(n.up.name, n.name)]
            icol = 0
            for sub in subs.split(', '):
                if sub:
                    to_outgroup = sub[1:] in to_outgroup_muts
                    n.add_face(ete3.TextFace(f"{sub}  ",
                                             fsize=label_fontsize,
                                             fgcolor=(mut_to_outgroup_label_color
                                                      if to_outgroup else
                                                      mut_label_color),
                                             bold=to_outgroup,
                                             ),
                               column=icol,
                               position='branch-top',
                               )
                    icol += 1
        nstyle = ete3.NodeStyle(**nstyle_dict)
        if n.is_leaf() or n.name == progenitor:
            n.add_face(get_PieChartFace(annotations[n.name], nodesizescale, locations),
                       column=0,
                       position='branch-right',
                       )
            if label_nodes:
                n.add_face(ete3.TextFace(annotations[n.name]['substitutions'],
                                         tight_text=True,
                                         # was the undefined name `labelfontsize`
                                         fsize=label_fontsize,
                                         ),
                           column=0,
                           position='branch-right',
                           )
        n.set_style(nstyle)

    # set dummy node for root, this allows us to put
    # progenitor at base of root branch
    progenitor_n = t.search_nodes(name=progenitor)
    assert len(progenitor_n) == 1
    progenitor_n = progenitor_n[0]
    dummy_outgroup = progenitor_n.add_child(name='dummy',
                                            dist=1e-5)
    t.set_outgroup(dummy_outgroup)
    assert len(progenitor_n.children) == 1
    progenitor_child = progenitor_n.children[0]
    nstyle = ete3.NodeStyle(**nstyle_dict)
    progenitor_child.set_style(nstyle)
    if progenitor_n.faces:
        # Move the substitution label from the progenitor onto its child,
        # written in the opposite direction (original and new nucleotide
        # swapped), then discard the faces added to the progenitor so far.
        sub = subs_matrix[(progenitor_n.up.name, progenitor)]
        if sub:
            assert len(sub.split(',')) == 1
            sub_flipped = f"{sub[-1]}{sub[1: -1]}{sub[0]}"
            to_outgroup = sub_flipped[1:] in to_outgroup_muts
            progenitor_child.add_face(ete3.TextFace(f"{sub_flipped}  ",
                                                    fsize=label_fontsize,
                                                    fgcolor=(mut_to_outgroup_label_color
                                                             if to_outgroup else
                                                             mut_label_color),
                                                    bold=to_outgroup,
                                                    ),
                                      column=0,
                                      position='branch-top',
                                      )
        delattr(progenitor_n, '_faces')
    # hide the dummy outgroup and its parent by drawing them in white
    empty_nstyle = ete3.NodeStyle(hz_line_color='white',
                                  vt_line_color='white',
                                  size=0)
    dummy_outgroup.set_style(empty_nstyle)
    dummy_outgroup.up.set_style(empty_nstyle)
    # (re-)add the progenitor's pie chart
    progenitor_n.add_face(get_PieChartFace(annotations[progenitor], nodesizescale, locations),
                          column=0,
                          position='branch-right',
                          )

    t.ladderize()

    # collapse tree at indicated nodes
    for below_subs, txt in collapse_below.items():
        collapse_n = [n for n in t.traverse() if n.name and n.up and n.up.name and
                      n.name != 'dummy' and subs_matrix[(n.up.name, n.name)] == below_subs]
        if len(collapse_n) > 1:
            raise ValueError(f"more than one node to collapse for {below_subs}")
        elif len(collapse_n) == 0:
            raise ValueError(f"no nodes to collapse for {below_subs}")
        collapse_n = collapse_n[0]
        # sum the annotations over all leaves at or below the collapsed node,
        # detaching every descendant along the way
        collapse_annotations = {key: 0 for key in ['nstrains', *locations]}
        for n in [collapse_n] + collapse_n.get_descendants():
            if n.is_leaf():
                collapse_annotations['nstrains'] += annotations[n.name]['nstrains']
                for location in locations:
                    collapse_annotations[location] += annotations[n.name][location]
            if n != collapse_n:
                n.detach()
        nstyle = ete3.NodeStyle(**nstyle_dict)
        collapse_n.set_style(nstyle)
        collapse_n.add_face(get_PieChartFace(collapse_annotations, nodesizescale, locations),
                            column=0,
                            position='branch-right',
                            )
        for face_txt in txt:
            collapse_n.add_face(ete3.TextFace(f" {face_txt}",
                                              fsize=label_fontsize,
                                              fgcolor=node_label_color,
                                              bold=True,
                                              ),
                                column=1,
                                position='branch-right',
                                )

    # label specified non-progenitor nodes
    for name, labels in node_labels.items():
        if name == progenitor:
            continue
        n = t.search_nodes(name=name)
        assert len(n) == 1
        n = n[0]
        # only the first label is drawn next to the node
        for face_txt in labels[: 1]:
            n.add_face(ete3.TextFace(f" {face_txt}",
                                     fsize=label_fontsize,
                                     fgcolor=node_label_color,
                                     bold=True,
                                     ),
                       column=1,
                       position='branch-right',
                       )

    # add title
    if progenitor in node_labels:
        title_labels = node_labels[progenitor]
    else:
        title_labels = [progenitor]
    for i, face_txt in enumerate(title_labels):
        if i == 0:
            prefix = 'progenitor as '
        else:
            prefix = ''
        ts.title.add_face(ete3.TextFace(prefix + face_txt,
                                        fsize=title_fontsize,
                                        fgcolor='black'),
                          column=0,
                          )
    # empty face at the end of the title, acting as a spacer line
    ts.title.add_face(ete3.TextFace('', fsize=title_fontsize),
                      column=0)

    ts.show_scale = False
    # scale so the farthest node is `widthscale` units from the root
    height = t.get_farthest_node()[1]
    ts.scale = widthscale / height
    ts.branch_vertical_margin = heightscale

    if draw_legend:
        for loc, color in locations.items():
            # add padding
            for col in [0, 1]:
                ts.legend.add_face(ete3.RectFace(0.5 * nodesizescale,
                                                 0.5 * nodesizescale,
                                                 'white', 'white'),
                                   column=col)
            # add legend
            ts.legend.add_face(ete3.CircleFace(1.2 * nodesizescale,
                                               color),
                               column=0)
            ts.legend.add_face(ete3.TextFace(' ' + loc,
                                             fsize=title_fontsize),
                               column=1)
        ts.legend_position = 3  # 3 is bottom left

    return t, ts

# Render every (outgroup, progenitor) combination twice: once without and
# once with the deleted sequences. Each tree is shown inline and saved as PDF.
assert len(progenitors) == len(tree_files)
dirname = 'results/phylogenetics/tree_images'
os.makedirs(dirname, exist_ok=True)
for outgroup in outgroups:
    for progenitor, tree_file in zip(progenitors, tree_files):
        # '/' cannot appear in a file name, so swap it for '%'
        progenitor_str = progenitor.replace('/', '%')

        # arguments common to both versions of the tree
        shared_kwargs = dict(subs_matrix=subs_matrices[progenitor],
                             progenitor=progenitor,
                             to_outgroup_muts=subs_to_outgroup[outgroup],
                             node_labels=node_labels,
                             )

        # tree colored by category, without the deleted sequences
        t, ts = get_pretty_tree(tree_file,
                                df=all_df_cat_cols,
                                locations=cat_colors,
                                **shared_kwargs)
        display(t.render('%%inline', tree_style=ts, w=300))
        treefile = os.path.join(
            dirname, f"{progenitor_str}_{outgroup}_without_deleted_seqs.pdf")
        print(f"Saving to {treefile}\n")
        t.render(treefile, tree_style=ts)

        # tree colored by sub-category, with the deleted sequences added
        t_deleted, ts_deleted = get_pretty_tree(tree_file,
                                                df=all_plus_deleted_df,
                                                locations=subcat_colors,
                                                **shared_kwargs)
        display(t_deleted.render('%%inline', tree_style=ts_deleted, w=300))
        treefile = os.path.join(
            dirname, f"{progenitor_str}_{outgroup}_with_deleted_seqs.pdf")
        print(f"Saving to {treefile}\n")
        t_deleted.render(treefile, tree_style=ts_deleted)