# Analyze MERS recombination networks from Muller et al.

Experimental work to parse and interpret MERS recombination networks produced by [Muller et al.](https://bedford.io/papers/muller-cov-recombination/). These networks were stored in extended NEXUS format and we need to parse them into a corresponding Python network data structure. This notebook shows how to identify recombination donor/recipient pairs in the NEXUS data structure. These pairs and the tree topology can then be used to create a network data structure.

## Imports

In [None]:
from augur.utils import annotate_parents_for_tree, read_node_data
import Bio.Phylo.NexusIO
from collections import defaultdict
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns

%matplotlib inline

In [None]:
sns.set_style("white")

## Inputs

In [None]:
mers_clades_path = "data/manual_multihost_clades.json"

In [None]:
mers_network_path = "data/mers_all.tree"

## Outputs

In [None]:
recombination_groups_path = "recombination_groups.tsv"

In [None]:
clade_recombination_groups_path = "clade_recombination_groups.tsv"

## Load clade annotations

In [None]:
clades = read_node_data(mers_clades_path)

## Load phylogenetic network

In [None]:
trees = list(Bio.Phylo.NexusIO.parse(mers_network_path))
tree = trees[0]

In [None]:
tree = annotate_parents_for_tree(tree)

In [None]:
# Make a single pass through the tree in postorder to store a set of all
# terminals descending from each node. This uses more memory, but it allows
# faster identification of MRCAs between any pair of tips in the tree and
# speeds up pairwise distance calculations by orders of magnitude.
for node in tree.find_clades(order="postorder"):
    node.terminals = set()
    for child in node.clades:
        if child.is_terminal() and not child.name.startswith("#"):
            node.terminals.add(child.name)
        else:
            node.terminals.update(child.terminals)

## Identify recombination pairs

Recombination events in the network appear as nodes in the tree named like `#HXX` where `XX` is an integer id for the event. There should be two nodes in the "tree" with the same name, one with no children (a terminal, indicating the donor of the recombination event) and another with at least one child (the recipient of the recombination event).

In [None]:
recombination_pairs = defaultdict(list)
for n in tree.find_clades():
    if n.name is None and n.confidence is not None:
        n.name = n.confidence
        n.confidence = None

    if n.name is not None and n.name.startswith("#"):
        recombination_pairs[n.name].append(n)

In [None]:
len(recombination_pairs)

## Parse lengths of recombination events

Each recombination pair should have a donor node (one with no children) that has a "length" annotation in the node's `comment` field. The parsing below could be more robust, but it was sufficient for this initial experimental work.

In [None]:
length_by_donor = {}
lengths = []
for name, pairs in recombination_pairs.items():
    print(f"event name: {name}")
    for node in pairs:
        print(f"node name: {node.name}")
        print(f"children: {len(node.clades)}")
        length = [
            int(piece.split("=")[1].replace("]", ""))
            for piece in node.comment.split(",")
            if "length=" in piece
        ][0]
        node.length = length
        print(f"length: {length}")
        
        # Donor nodes have no children.
        if len(node.clades) == 0:
            node.is_donor = True
            lengths.append(length)
            
            length_by_donor[node.name] = length
        else:
            node.is_donor = False
    print()

In [None]:
length_by_donor
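
The list-comprehension parsing above fails with an exception when a node has no comment or no "length" annotation. A more defensive variant is sketched below. It assumes the annotation always appears as `length=<integer>` somewhere in the comment string; the helper name `parse_event_length` and the example comment strings are hypothetical and not taken from the MERS file.

```python
import re

# Assumed annotation format: "length=<integer>". The lookbehind keeps the
# pattern from matching other keys that merely end in "length".
LENGTH_PATTERN = re.compile(r"(?<![A-Za-z_])length=(\d+)")


def parse_event_length(comment):
    """Return the integer "length" annotation from a node comment, or None."""
    if not comment:
        return None

    match = LENGTH_PATTERN.search(comment)
    return int(match.group(1)) if match else None


# Hypothetical comment strings, for illustration only.
print(parse_event_length("&region={0-4999},length=5000]"))
print(parse_event_length(None))
```

Returning `None` instead of raising lets the calling loop decide whether a missing length is an error or something to skip.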

## Summarize recombination event lengths

Use the distribution of recombination event lengths to reason about the minimum size an event should be to define "subtrees" in the network that share that same event.

In [None]:
median_length = np.median(lengths)

In [None]:
median_length

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=150)
ax.hist(lengths)

ax.axvline(x=median_length, color="red", label=f"median length={median_length} bp")

ax.set_xlabel("event length (bp)")
ax.set_ylabel("number of events")

ax.legend(frameon=False)

sns.despine()

Set the threshold for minimum recombination donor size in base pairs. This threshold is somewhat arbitrary, like any threshold, but we want a value that excludes most small events that are unlikely to substantially alter the recipient genome.

In [None]:
length_threshold = 5000

In [None]:
large_events = (np.array(lengths) >= length_threshold).sum()

In [None]:
large_events

In [None]:
total_events = len(lengths)

In [None]:
total_events

In [None]:
large_events / total_events
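
Because the threshold is somewhat arbitrary, it may help to see how the retained fraction of events changes across a few candidate values. The sketch below uses made-up event lengths and a hypothetical helper `fraction_at_least`; in the notebook, the real `lengths` list would take the place of `example_lengths`.

```python
def fraction_at_least(event_lengths, threshold):
    """Return the fraction of events whose length is >= threshold."""
    if not event_lengths:
        return 0.0

    kept = sum(length >= threshold for length in event_lengths)
    return kept / len(event_lengths)


# Hypothetical event lengths in bp, not values from the MERS network.
example_lengths = [120, 800, 2500, 5000, 7400, 12000]

for candidate_threshold in (1000, 5000, 10000):
    print(candidate_threshold, fraction_at_least(example_lengths, candidate_threshold))
```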

## Inspect some recombination pairs

Inspect the values associated with nodes in a given recombination pair, to better understand what data are available to us. This was important for understanding the directionality of recombination events and the associated length of the donor sequences.

In [None]:
recombination_pairs["#H59"][0].__dict__

In [None]:
recombination_pairs["#H59"][1].__dict__

In [None]:
recombination_pairs["#H52"][0].__dict__

In [None]:
recombination_pairs["#H52"][1].__dict__

## Calculate number of recombination events between tips

Calculate the number of recombination events on the path between two tips in the network that were manually inspected ahead of time with [IcyTree](https://icytree.org/). If we parsed the network properly, we should find 9 events between the two tips (2 on the path from one tip to its MRCA with the other tip and 7 on the path from the other tip to their MRCA).

The logic below is based on [a script from the seasonal flu Nextstrain build](https://github.com/nextstrain/seasonal-flu/blob/87740b36b4c4b11b9b450b7ad1e558390d6a1289/flu-forecasting/scripts/pairwise_titer_tree_distances.py#L12-L36) that needs to make similar pairwise tip distance calculations.

In [None]:
tip_a = [tip for tip in tree.find_clades(terminal=True) if "Riyadh_9_2013" in tip.name][0]

In [None]:
tip_a

In [None]:
tip_b = [tip for tip in tree.find_clades(terminal=True) if "Hafr-Al-Batin_4_2013" in tip.name][0]

In [None]:
tip_b

In [None]:
# Find MRCA of tips from one tip up. Sum the distance of interest
# while walking up to the MRCA, to avoid an additional pass later. The loop
# below stops when the past node is found in the list of the candidate
# MRCA's terminals. This test should always evaluate to true when the MRCA
# is the root node, so we should not have to worry about trying to find the
# parent of the root.
current_node_branch_sum = 0.0
mrca = tip_a
while tip_b.name not in mrca.terminals:
    if hasattr(mrca, "is_donor") and not mrca.is_donor:
        # Count the number of recombination events.
        current_node_branch_sum += 1
    
    # Print debugging info, so we understand the network traversal.
    print(f"Changing MRCA from {mrca} to {mrca.parent}")
    mrca = mrca.parent

In [None]:
current_node_branch_sum

In [None]:
mrca

In [None]:
# Sum the node weights for the other tip from the bottom up until we reach
# the MRCA. The value of the MRCA is intentionally excluded here, as it
# would represent the branch leading to the MRCA and would be outside the
# path between the two tips.
past_node_branch_sum = 0.0
current_node = tip_b
while current_node != mrca:
    if hasattr(current_node, "is_donor") and not current_node.is_donor:
        # Count the number of recombination events on the path from
        # the second tip to its MRCA with the first tip.
        past_node_branch_sum += 1
        
    current_node = current_node.parent

In [None]:
past_node_branch_sum

In [None]:
final_sum = past_node_branch_sum + current_node_branch_sum

In [None]:
final_sum
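
The two traversals above can be checked on a toy tree where the expected count is easy to verify by hand. The sketch below is a simplified stand-in, not the notebook's implementation: `ToyNode`, `ancestors`, and `count_events_between` are hypothetical names, and recipient status is stored as a plain boolean instead of the `is_donor` attribute. Like the cells above, it excludes the MRCA itself from the count.

```python
class ToyNode:
    """Minimal stand-in for a tree node with a parent pointer."""

    def __init__(self, name, parent=None, is_recipient=False):
        self.name = name
        self.parent = parent
        self.is_recipient = is_recipient


def ancestors(node):
    """Return the ancestors of a node, from its parent up to the root."""
    path = []
    while node.parent is not None:
        node = node.parent
        path.append(node)
    return path


def count_events_between(tip_a, tip_b):
    """Count recipient nodes on the path between two tips, excluding the MRCA."""
    lineage_of_b = {id(node) for node in [tip_b] + ancestors(tip_b)}

    # Walk up from the first tip until reaching a node on the second tip's
    # lineage; that node is the MRCA.
    events = 0
    mrca = tip_a
    while id(mrca) not in lineage_of_b:
        events += int(mrca.is_recipient)
        mrca = mrca.parent

    # Walk up from the second tip until reaching the MRCA.
    node = tip_b
    while node is not mrca:
        events += int(node.is_recipient)
        node = node.parent

    return events


# Toy tree: one recipient above tip "x" and two above tip "y".
toy_root = ToyNode("root")
toy_h1 = ToyNode("#H1", parent=toy_root, is_recipient=True)
toy_x = ToyNode("x", parent=ToyNode("left", parent=toy_h1))
toy_h2 = ToyNode("#H2", parent=toy_root, is_recipient=True)
toy_y = ToyNode("y", parent=ToyNode("#H3", parent=toy_h2, is_recipient=True))

print(count_events_between(toy_x, toy_y))
```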

## Assign tips to groups based on the most recent shared recombination event of a minimum size

In the first attempt to assign tips to "subtrees" akin to clades or reassortment clusters (MCCs) from TreeKnit, we traverse the tree in postorder to find the most recent recombination event of a minimum size and then assign all tips descending from that event to a group named after the recombination event.

Importantly, we look for nodes that are _recipients_ of the recombination event and then check the length of the sequence contributed by the donor event (located elsewhere in the network). The precalculated data structure with length by donor helps us do this.

In [None]:
for node in tree.find_clades(terminal=True):
    node.group = None

In [None]:
for node in tree.find_clades(order="postorder"):
    # Find recipient nodes for recombination events with events longer than the threshold.
    if hasattr(node, "is_donor") and not node.is_donor and length_by_donor[node.name] >= length_threshold:
        # Assign all unassigned tips to a group named after the current node.
        print(f"Node '{node.name}' defines a cluster")
        for tip in node.find_clades(terminal=True):
            # Donor recombination events appear as terminal nodes in the "tree",
            # so we need to omit these from our group assignment by ignoring nodes
            # with recombination event names like "#HXX".
            if not tip.name.startswith("#") and tip.group is None:
                tip.group = node.name
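
To see why the postorder traversal assigns each tip to its most recent qualifying event, the same logic can be run on a small hand-built tree. This is only a sketch: it uses nested dictionaries instead of Bio.Phylo clades, and `make_node`, `assign_groups`, the toy tree, and the toy lengths are all hypothetical.

```python
def make_node(name, children=()):
    """Build a toy node; unnamed internal nodes use an empty string."""
    return {"name": name, "children": list(children)}


def postorder(node):
    """Yield nodes so that every child is visited before its parent."""
    for child in node["children"]:
        yield from postorder(child)
    yield node


def tips_below(node):
    """Yield the terminal nodes descending from the given node."""
    if not node["children"]:
        yield node
    for child in node["children"]:
        yield from tips_below(child)


def assign_groups(root, length_by_event, min_length):
    """Map each tip name to the most recent large-enough recipient above it."""
    assigned = {}
    for node in postorder(root):
        is_recipient = node["name"].startswith("#") and bool(node["children"])
        if is_recipient and length_by_event[node["name"]] >= min_length:
            for tip in tips_below(node):
                # Skip donor placeholders and keep any earlier assignment,
                # which came from a more recent event.
                if not tip["name"].startswith("#"):
                    assigned.setdefault(tip["name"], node["name"])
    return assigned


toy_tree = make_node("", [
    make_node("#H1", [
        make_node("A"),
        make_node("#H2", [make_node("B"), make_node("C")]),
    ]),
    make_node("D"),
    make_node("#H2"),  # donor placeholder for event #H2
])
toy_lengths = {"#H1": 8000, "#H2": 6000}

print(assign_groups(toy_tree, toy_lengths, 5000))
print(assign_groups(toy_tree, toy_lengths, 7000))
```

Raising the threshold above the length of the nested event makes its tips fall through to the enclosing event, which is the behavior to expect when iterating on `length_threshold`.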

## Inspect the set of all groups

There should be no more groups than the number of recombination events with a length greater than or equal to the threshold. However, there can be fewer groups, when all tips have been assigned to groups based on later recombination events and the earlier recombination events have no tips to be assigned.

In [None]:
(np.array(lengths) >= length_threshold).sum()

In [None]:
groups = {
    tip.group
    for tip in tree.find_clades(terminal=True)
    if not tip.name.startswith("#") and tip.group is not None
}

In [None]:
len(groups)

In [None]:
groups

## Create and export a data frame of recombination groups

The resulting TSV file can be dragged onto Auspice to quickly inspect the groups and iterate with different length thresholds, as desired.

In [None]:
recombination_records = []
for tip in tree.find_clades(terminal=True):
    if tip.group is not None:
        recombination_records.append({
            "strain": tip.name,
            "group": tip.group[1:],
        })

In [None]:
recombination_groups = pd.DataFrame(recombination_records)

In [None]:
recombination_groups

In [None]:
recombination_groups.to_csv(recombination_groups_path, sep="\t", index=False)

## Annotate groups based on coalescent and recombination events

Use previously defined multi-host clade annotations to represent shared coalescent events for groups of tips and combine these annotations with the recombination group annotations. The joint annotation of these event types should reflect the most evolutionarily related tips in the network for inspection in different embedding spaces.

In [None]:
clades_df = pd.DataFrame([
    {"strain": key, "clade": values["multihost_clade_membership"]}
    for key, values in clades["nodes"].items()
])

In [None]:
clades_df

In [None]:
clade_recombination_groups = recombination_groups.merge(clades_df, on="strain")

In [None]:
clade_recombination_groups["cluster"] = clade_recombination_groups["clade"] + "|" + clade_recombination_groups["group"]

In [None]:
clade_recombination_groups["cluster"].value_counts()

In [None]:
clade_recombination_groups["cluster"].drop_duplicates().shape

As above, export the data frame of clade/recombination groups to a TSV file that we can drag onto Auspice and inspect in the context of the tree and different embeddings.

In [None]:
clade_recombination_groups.to_csv(clade_recombination_groups_path, sep="\t", index=False)

## Create a network data structure from the tree data structure

To simplify thinking about distances between tips in the network, create a network data structure with [networkx](https://networkx.org/) based on our knowledge of how the recombination network is encoded in a tree data structure. This network structure should also simplify plotting of the network and inspection in other visualization tools that accept dot files.

The network should be a directed acyclic graph (DAG) with edges between internal nodes and their descendants and edges between donor and recipient recombination events. Edge weights should be calculated from the branch lengths annotated to the corresponding internal or donor nodes. We may want to scale edge weights for recombination events by the proportion of the genome contributed from the donor (e.g., `branch_length * (event_length / genome_length)`). This is up for discussion, though.

In [None]:
# Code TBD. Use a directed graph, since the network should be a DAG.
network = nx.DiGraph()
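
One possible way to derive the edges is sketched below, using nested dictionaries in place of Bio.Phylo clades. It rests on an assumption that is still up for discussion: because a donor placeholder and its recipient share the same `#HXX` name, using that name as the node id merges them into a single node with two incoming edges. The names `make_edge_node` and `build_network_edges`, the generated `internal_N` ids, and the toy tree are hypothetical.

```python
from itertools import count


def make_edge_node(name, branch_length, children=()):
    """Build a toy node; unnamed internal nodes use an empty string."""
    return {"name": name, "branch_length": branch_length, "children": list(children)}


def build_network_edges(root):
    """Return (source, target, weight) edges for a tree with "#H" event nodes.

    Edge weights are the branch lengths of the child nodes. Donor placeholders
    and recipients with the same name map to the same target node.
    """
    unnamed_ids = count()
    ids = {}
    edges = []

    def node_id(node):
        # Named nodes keep their name; unnamed nodes get a generated id.
        if id(node) not in ids:
            ids[id(node)] = node["name"] or f"internal_{next(unnamed_ids)}"
        return ids[id(node)]

    stack = [root]
    while stack:
        parent = stack.pop()
        for child in parent["children"]:
            edges.append((node_id(parent), node_id(child), child["branch_length"]))
            stack.append(child)

    return edges


toy_network_tree = make_edge_node("", 0.0, [
    make_edge_node("", 1.0, [
        make_edge_node("A", 2.0),
        make_edge_node("#H1", 0.5),  # donor placeholder
    ]),
    make_edge_node("#H1", 1.5, [make_edge_node("B", 1.0)]),  # recipient
])
toy_edges = build_network_edges(toy_network_tree)

for edge in toy_edges:
    print(edge)
```

An edge list in this shape can be loaded into networkx with `add_weighted_edges_from`. Any scaling of recombination edge weights by event length would need to be applied before that step.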

## Calculate pairwise tip distances on the network

Given the network data structure defined above, calculate the distance between all pairs of tips in the network. Start by implementing a function that calculates the distance between any given pair of tips and testing this function with a small subset to confirm that the implementation works. The number of pairwise comparisons scales quadratically with the number of tips in the network (n tips yield n(n - 1)/2 pairs), so expect it to take a while.

In [None]:
# Code TBD.
def get_network_distance(network, tip_a, tip_b):
    """Return the distance between the given tips on the given network.
    
    Arguments
    ---------
    network : nx.Graph
        a graph representation of a phylogenetic network with weighted edges between
        ancestral nodes of coalescent events and their children and weighted edges
        between donor/recipient pairs in recombination events.
        
    tip_a : string
        name of a tip in the given network to calculate distance to the second tip.

    tip_b : string
        name of a tip in the given network to calculate distance to the first tip.
        
    Returns
    -------
    float :
        distance between tips in the network
    
    """
    pass
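
As a starting point for the stub above, the sketch below computes a shortest weighted path with Dijkstra's algorithm over a plain edge list. It assumes that "distance" means the shortest path when edges are treated as undirected, which is one candidate definition and not a settled choice. The function `shortest_distance` and the toy edge list are hypothetical. With a networkx graph, `nx.shortest_path_length(network, tip_a, tip_b, weight="weight")` offers similar functionality.

```python
import heapq
from collections import defaultdict
from itertools import combinations


def shortest_distance(edges, source, target):
    """Return the shortest weighted path length between two nodes.

    Edges are treated as undirected, so a path may climb to an ancestor and
    descend again, or cross a recombination edge in either direction.
    """
    neighbors = defaultdict(list)
    for a, b, weight in edges:
        neighbors[a].append((b, weight))
        neighbors[b].append((a, weight))

    best = {source: 0.0}
    queue = [(0.0, source)]
    while queue:
        distance, node = heapq.heappop(queue)
        if node == target:
            return distance
        if distance > best.get(node, float("inf")):
            # Stale queue entry; a shorter path was already found.
            continue
        for neighbor, weight in neighbors[node]:
            candidate = distance + weight
            if candidate < best.get(neighbor, float("inf")):
                best[neighbor] = candidate
                heapq.heappush(queue, (candidate, neighbor))

    return float("inf")


# Hypothetical edge list with one recombination event "#H1".
toy_edge_list = [
    ("root", "n1", 1.0),
    ("n1", "A", 2.0),
    ("n1", "#H1", 0.5),
    ("root", "#H1", 1.5),
    ("#H1", "B", 1.0),
]

toy_tips = ["A", "B"]
toy_distances = {
    (a, b): shortest_distance(toy_edge_list, a, b)
    for a, b in combinations(toy_tips, 2)
}
print(toy_distances)
```

Iterating over `combinations(tips, 2)` visits each unordered pair once, which is where the n(n - 1)/2 cost comes from.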