# Table S2. Count epitope mutations by trunk status for natural populations

For a given tree, classify each node as trunk or not and count the number of epitope and non-epitope mutations. Finally, summarize the number of mutations by category of trunk and mutation.

In [None]:
# Define inputs.
# NOTE(review): `snakemake` is not defined in this notebook; it is presumably
# injected by Snakemake when the notebook runs as part of a workflow rule.
# Tree JSON whose nodes carry `attr["num_date"]` and `aa_muts` (used below).
full_tree_json = snakemake.input.full_tree_json
# Distance map whose "map" -> "HA1" keys are the zero-based epitope sites.
epitope_sites_distance_map = snakemake.input.epitope_sites_distance_map

# Define outputs.
# Path of the LaTeX table summarizing mutation counts by branch type.
output_table = snakemake.output.table

In [None]:
from augur.distance import read_distance_map
from augur.utils import json_to_tree
import Bio.Phylo
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

%matplotlib inline

## Load tree data

In [None]:
# Read the tree JSON. JSON text is UTF-8 (RFC 8259), so decode it explicitly
# instead of relying on the platform's locale-dependent default encoding.
with open(full_tree_json, "r", encoding="utf-8") as fh:
    tree_json = json.load(fh)

In [None]:
# Convert the JSON dictionary into a Bio.Phylo tree. Later cells read `parent`,
# `attr`, and `aa_muts` from each node; presumably json_to_tree sets these from
# the JSON fields -- TODO confirm if the JSON format changes.
tree = json_to_tree(tree_json)

In [None]:
# Display the parsed tree for a quick sanity check.
tree

## Load distance map

In [None]:
# Load the distance map; its "map" -> "HA1" entries define the epitope sites.
distance_map = read_distance_map(epitope_sites_distance_map)

In [None]:
# Extract all epitope sites from the distance map, readjusting to one-based coordinates
# for comparison with one-based coordinates of amino acid mutations annotated on trees.
# Sort the sites so the displayed list (and the sentence built from it) is in site
# order rather than in whatever key order the distance map file happens to use.
epitope_sites = sorted(site + 1 for site in distance_map["map"]["HA1"])

In [None]:
# Show the one-based epitope sites as an array for compact display.
np.array(epitope_sites)

In [None]:
", ".join([str(site) for site in epitope_sites[:-1]]) + ", and " + str(epitope_sites[-1])

## Annotate number of epitope and non-epitope mutations per node

In [None]:
# Count epitope and non-epitope amino acid mutations annotated on each node.
# Only HA1 mutations at epitope sites count as epitope mutations; every other
# mutation (including mutations in other genes) counts as non-epitope.
# A set gives O(1) membership tests instead of scanning the site list per mutation.
epitope_site_set = set(epitope_sites)

for node in tree.find_clades():
    epitope_mutations = 0
    nonepitope_mutations = 0

    # NOTE(review): assumes mutation strings look like "K160T" (ancestral residue,
    # one-based position, derived residue), so mut[1:-1] is the position.
    for gene, muts in node.aa_muts.items():
        for mut in muts:
            if gene == "HA1" and int(mut[1:-1]) in epitope_site_set:
                epitope_mutations += 1
            else:
                nonepitope_mutations += 1

    node.epitope_mutations = epitope_mutations
    node.nonepitope_mutations = nonepitope_mutations

In [None]:
# Show the distinct nonzero per-node epitope mutation counts as a sanity check.
{node.epitope_mutations for node in tree.find_clades() if node.epitope_mutations > 0}

## Assign trunk status

[Bedford et al. 2015](https://doi.org/10.1038/nature14460) defines the trunk as "all branches ancestral to viruses
sampled within 1 year of the most recent sample". The algorithm for finding the trunk based on this definition is then:

  1. Select all nodes in the last year
  1. Select the parent of each selected node until the root
  1. Create a unique set of nodes
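  1. Omit all nodes from the last year since resolution of the trunk is limited (note: this step is not implemented exactly as stated below; instead, only ancestral nodes sampled before the most recent common ancestor of the recent samples are labeled as trunk)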

Note that this definition was based on 12 years of flu data from 2000 to 2012.

In [None]:
# Latest `num_date` among all tips; defines the end of the "recent" window.
max_date = max(tip.attr["num_date"] for tip in tree.get_terminals())

In [None]:
# Display the most recent tip date.
max_date

In [None]:
# Find all tips of the tree sampled within a year of the most recent sample in the tree.
# Collect the tips sampled within one year of the most recent sample, i.e. tips
# whose date is strictly later than `max_date - 1`.
recent_nodes = [
    tip
    for tip in tree.get_terminals()
    if tip.attr["num_date"] > max_date - 1
]

In [None]:
# Number of recent tips.
len(recent_nodes)

In [None]:
# Find the most recent common ancestor (MRCA) of all recent nodes. Its
# `num_date` is the cutoff for trunk labeling: only ancestors sampled before
# the MRCA are labeled as trunk.
mrca = tree.common_ancestor(recent_nodes)

In [None]:
# Display the MRCA of the recent tips.
mrca

In [None]:
# Date of the MRCA; ancestors sampled strictly before this date are trunk candidates.
mrca.attr["num_date"]

In [None]:
# Start from a clean slate: no node is on the trunk or is a side branch
# ancestor until the labeling cells set these flags.
for node in tree.find_clades():
    node.is_trunk = node.is_side_branch_ancestor = False

In [None]:
# Find all nodes that are ancestral to recent nodes and label those sampled
# before the MRCA of all recent nodes as part of the "trunk". The walk stops
# at the tree's root, so the root itself is never labeled.
#
# Each ancestor needs to be visited only once. A walk stops either at the root
# or at an already-visited node, so when a later walk reaches a visited node,
# every node between it and the root has already been handled. Stopping early
# avoids re-walking to the root for every recent tip.
mrca_date = mrca.attr["num_date"]
visited_node_ids = set()

for recent_node in recent_nodes:
    current_node = recent_node.parent

    # Traverse from the current node toward the tree's root (exclusive).
    while (
        current_node is not None
        and current_node is not tree.root
        and id(current_node) not in visited_node_ids
    ):
        visited_node_ids.add(id(current_node))

        # Mark a node as part of the trunk if it was sampled
        # before the MRCA of all recent nodes.
        if current_node.attr["num_date"] < mrca_date:
            current_node.is_trunk = True

        current_node = current_node.parent

In [None]:
def is_side_branch_ancestor(node):
    """Return whether `node` hangs directly off the trunk.

    Only the parent is inspected: the result is True when `node` has a parent
    and that parent is labeled as trunk, and False for a parentless node. The
    caller is responsible for passing only non-trunk ("side branch") nodes.
    """
    parent = node.parent
    if parent is None:
        return False
    return parent.is_trunk

In [None]:
# Internal nodes labeled as trunk.
trunk_path = [clade for clade in tree.find_clades(terminal=False) if clade.is_trunk]

In [None]:
# Side branch nodes are internal nodes that are not on the trunk and were
# sampled before the MRCA of the recent nodes. Nodes at or after the MRCA
# are excluded.
side_branch_nodes = [
    clade
    for clade in tree.find_clades(terminal=False)
    if not clade.is_trunk
    if clade.attr["num_date"] < mrca.attr["num_date"]
]

In [None]:
# Number of internal trunk nodes.
len(trunk_path)

In [None]:
# Number of side branch nodes.
len(side_branch_nodes)

In [None]:
# Keep the side branch nodes whose immediate parent is on the trunk (in their
# original order), then flag each of them as a side branch ancestor.
side_branch_ancestors = [node for node in side_branch_nodes if is_side_branch_ancestor(node)]
for node in side_branch_ancestors:
    node.is_side_branch_ancestor = True

In [None]:
# Number of side branch nodes whose parent is on the trunk.
len(side_branch_ancestors)

In [None]:
# Color each node by status. Trunk nodes are green, side branch ancestors are
# orange, and everything else is black; trunk status takes precedence.
for node in tree.find_clades():
    node.color = (
        "green" if node.is_trunk
        else "orange" if node.is_side_branch_ancestor
        else "black"
    )

In [None]:
# Draw the tree using the node colors, with every node label suppressed.
fig, ax = plt.subplots(figsize=(12, 8))
Bio.Phylo.draw(tree, axes=ax, label_func=lambda _clade: "")

## Annotate mutations by trunk status

In [None]:
# One record per internal node that is on the trunk or that was sampled prior
# to the MRCA of recent nodes (i.e., side branch nodes), with its mutation counts.
records = [
    {
        "node": node.name,
        "branch type": "trunk" if node.is_trunk else "side branch",
        "epitope mutations": node.epitope_mutations,
        "non-epitope mutations": node.nonepitope_mutations,
    }
    for node in tree.find_clades(terminal=False)
    if node.is_trunk or node.attr["num_date"] < mrca.attr["num_date"]
]

In [None]:
# Tabulate the records: one row per trunk or side branch internal node.
df = pd.DataFrame(records)

In [None]:
# Preview the first rows of the per-node table.
df.head()

In [None]:
# Total epitope and non-epitope mutations per branch type (trunk vs. side branch).
counts_by_trunk_status = (
    df.groupby("branch type")[["epitope mutations", "non-epitope mutations"]]
    .sum()
)

In [None]:
# Ratio of epitope to non-epitope mutations per branch type, rounded to two decimals.
counts_by_trunk_status["epitope-to-non-epitope ratio"] = (
    counts_by_trunk_status["epitope mutations"]
    .div(counts_by_trunk_status["non-epitope mutations"])
    .round(2)
)

In [None]:
# Display the summary table by branch type.
counts_by_trunk_status

In [None]:
# Export the summary table as LaTeX with escaping disabled, writing the
# rendered text to the output path unchanged.
counts_by_trunk_status_table = counts_by_trunk_status.to_latex(escape=False)
with open(output_table, mode="w") as table_file:
    print(counts_by_trunk_status_table, file=table_file, end="")