# Sc2ts vs Usher cophylogenies

## "Pango-representative" samples

To plot cophylogenies, we aim to find identical "representative samples" for Pango lineages that exist in both the Usher tree and the sc2ts tree, and have the same Pango assignment in both. 

However, the earliest sample of each Pango type could be an erroneous classification. To avoid this, this notebook identifies an "originating node" for each Pango lineage. An originating node of (say) B.1.1.7 is the earliest node that has > 50% of the B.1.1.7 samples as descendants, and which itself is labelled B.1.1.7. As this is an ARG, there are many trees: we count the maximum number of samples in any tree. To find a representative sample, we pick the oldest descendant sample node of the originating node which has entirely B.1.1.7 samples as descendants in a tree.

To reduce the number of tips to compare, we also remove samples which are known Pango-X recombinants (and descendants of them).

In [None]:
import collections
import heapq
import json
import os

import numpy as np
from tqdm.auto import tqdm
import tskit
import tszip

import sc2ts

ts = tszip.load("../data/sc2ts_v1_2023-02-21_pp_dels_bps_pango_dated_mmps.trees.tsz")

# Load the associated dataset so that its metadata can be joined onto the node table,
# letting us use either the "pango" or the "Viridian_pangolin" definitions
ds = sc2ts.Dataset("../data/viridian_mafft_2024-10-14_v1.vcz")

# One row per ARG node, indexed by sample_id, with the Viridian designation added
df = (
    sc2ts.node_data(ts)
    .set_index("sample_id")
    .join(ds.metadata.as_dataframe(["Viridian_pangolin"]))
)

# Set which pango designation to use:
# Use "pango" to get the pango designations for all nodes computed by postprocessing the ARG.
# Use "Viridian_pangolin" to use the sample designations provided by Viridian.
Pango = "pango"

# Map each designation to the list of sample node IDs that carry it
pango_lineage_samples = {
    lineage: list(node_ids)
    for lineage, node_ids in df[df.is_sample].groupby(Pango)['node_id']
}
print(f"ARG has {len(pango_lineage_samples)} pango designations")

In [None]:
# Fraction of a lineage's samples that must descend from a node (in some tree)
# for that node to be considered as the lineage's origination node
cutoff = 0.50

# Maps each Pango lineage name to the node ID of its chosen representative sample
pango_representative_sample = {}

def find_first_fully_tracked_sample_by_time(tree, root):
    """
    Return the oldest sample node at or below ``root`` whose descendant samples
    in ``tree`` are all tracked samples, or None if no such node exists.

    Nodes are visited oldest-first using a heap keyed on negated node time
    (parents are older than their children), so the first qualifying sample
    popped is the oldest one. Subtrees that contain no tracked samples are
    never explored. If ``root`` is ``tskit.NULL`` the search starts at the
    virtual root of the tree.
    """
    node_times = tree.tree_sequence.nodes_time
    if root == tskit.NULL:
        root = tree.virtual_root
    # The virtual root's ID is one past the end of the node table, so it cannot
    # be used to index nodes_time: treat it as infinitely old instead.
    if root == tree.virtual_root:
        root_time = float("inf")
    else:
        root_time = node_times[root]

    # Priority queue of (-time, node): the smallest entry is the oldest node
    pq = [(-root_time, root)]
    while pq:
        _, u = heapq.heappop(pq)
        num_tracked = tree.num_tracked_samples(u)
        # Skip nodes with no tracked samples
        if num_tracked == 0:
            continue
        # Accept only a sample node whose descendant samples are all tracked
        if tree.num_samples(u) == num_tracked and tree.is_sample(u):
            return u
        # Add children to priority queue (they have younger times)
        for child in tree.children(u):
            heapq.heappush(pq, (-node_times[child], child))
    return None  # No fully tracked sample found


def origination_node(simp_ts, df, pango):
    """
    Return the origination node of the lineage ``pango``, or None if there is none.

    Candidate nodes are the rows of ``df`` (indexed by node ID) whose ``Pango``
    column equals ``pango``. A candidate is valid if, in at least one tree of
    ``simp_ts``, more than ``cutoff`` of the lineage's samples descend from it.
    The valid candidate with the largest such count (its maximum over all
    trees) is returned as a single node ID.

    Reads the module-level ``pango_lineage_samples``, ``Pango`` and ``cutoff``.
    """
    samples = pango_lineage_samples[pango]
    if len(samples) == 0:
        return None

    # Both of these are the same for every tree, so compute them once
    min_tracked = len(samples) * cutoff
    candidates = df.index[df[Pango] == pango]  # Only look at the ones designated pango

    # Largest number of tracked descendants seen in any tree, per valid node
    valid = {}
    for tree in simp_ts.trees(tracked_samples=samples):
        for u in candidates:
            tnts = tree.num_tracked_samples(u)
            if tnts > min_tracked and tnts > valid.get(u, 0):
                valid[u] = tnts

    if not valid:
        return None
    # Ties are resolved in favour of the node that became valid first
    return max(valid, key=valid.get)

        
def representative_node(ts, pango, pango_root):
    """
    Return a representative sample node for ``pango`` beneath ``pango_root``.

    Each tree of ``ts`` is searched in turn, tracking the samples of the
    lineage, by descending from ``pango_root`` while avoiding subtrees that have
    no tracked samples. The result from the first tree that yields a node is
    returned; None is returned if no tree yields one.
    """
    lineage_samples = pango_lineage_samples[pango]
    per_tree_results = (
        find_first_fully_tracked_sample_by_time(tree, pango_root)
        for tree in ts.trees(tracked_samples=lineage_samples)
    )
    # The generator is lazy, so trees after the first hit are never visited
    return next((nd for nd in per_tree_results if nd is not None), None)

# For speed, simplify the ARG down to a set of subgroups of pango samples
# before looking for the sample descendants
pangos = list(pango_lineage_samples.keys())
n_batches = int(np.sqrt(len(pangos)))
poor_orig_nd_pct = {}
no_valid_orig_nd = []

pango_origin = {}

# These do not depend on the batch, so build them once rather than per iteration
df_by_node_id = df.set_index("node_id")
node_designations = df[Pango].to_numpy()

for i in tqdm(np.arange(n_batches)):
    # Interleaved batches: every n_batches-th lineage, starting from the i-th
    batch = pangos[i::n_batches]
    # Rows (indexed by node ID) for nodes designated as one of the batch lineages
    tmp_df = df_by_node_id.loc[np.isin(node_designations, batch)]
    samples = np.array([u for b in batch for u in pango_lineage_samples[b]])
    # filter_nodes=False keeps node IDs identical to those in the full ARG
    simp_ts = ts.simplify(samples, keep_unary=True, filter_nodes=False)
    for pango in tqdm(batch, leave=False):
        orig_nd = origination_node(simp_ts, tmp_df, pango)
        if orig_nd is None:
            no_valid_orig_nd.append(pango)
        else:
            # Search the full ARG, so that "all descendants are this lineage"
            # is judged against every sample, not only those in the batch
            best_nd = representative_node(ts, pango, orig_nd)
            if best_nd is not None:
                vv = df.loc[df.node_id == best_nd, Pango].values
                assert len(vv) == 1 and vv[0] == pango
                pango_representative_sample[pango] = best_nd


print(f"Found {len(pango_representative_sample)} pango sample nodes / {len(pango_lineage_samples)} pango groups")
print(f"{len(no_valid_orig_nd)} pangos rejected as no pango origination node with > {cutoff*100:.1f}% pango descendants:")
print(" " + ", ".join(no_valid_orig_nd))
sts = ts.simplify(
    [v for k, v in pango_representative_sample.items()
     if not k.startswith("X")
    ], keep_unary=True, filter_nodes=False)

print(f"Simplified ARG to representative non-pangoX samples: {sts.num_trees} trees and {sts.num_samples} samples")

In [None]:
# Are there any samples with a recombination node immediately above


def _samples_below(tree_seq, nodes):
    """
    Map each node in ``nodes`` to the union, over all trees of ``tree_seq``,
    of the samples beneath it. Nodes with no samples in any tree get no entry.
    """
    below = collections.defaultdict(set)
    for tree in tree_seq.trees():
        for u in nodes:
            if tree.num_samples(u) > 0:
                below[u].update(tree.samples(u))
    return below


sts = ts.simplify(
    [node for lineage, node in pango_representative_sample.items() if not lineage.startswith("X")],
    keep_unary=True,
    filter_nodes=False,
)

recombinants = np.flatnonzero(sts.nodes_flags & sc2ts.NODE_IS_RECOMBINANT)
## Make a dictionary, `sample_desc`, with the descendant samples of each recombinant node
sample_desc = _samples_below(sts, recombinants)

for re_node, descendants in sample_desc.items():
    print(f"Node {re_node} ({ts.node(re_node).metadata['pango']}): {len(descendants)} descendant pangos")

# Samples that are the only descendant of some recombination node
one_sample_re = {list(descendants)[0] for descendants in sample_desc.values() if len(descendants) <= 1}

# Nodes that act as a parent somewhere in the simplified ARG
internal_nodes = set(sts.edges_parent)

sts = ts.simplify(
    [
        node
        for lineage, node in pango_representative_sample.items()
        if not (
            lineage.startswith("X")
            or node in one_sample_re
            or node in internal_nodes
            or node in sample_desc[1396207]  # Exclude descendants of XBB (some of these do not start with an X)
        )
    ],
    keep_unary=True,
    filter_nodes=False,
)

print(
    "After removing 3 singleton Pangos below a RE node, and further PangoX descendants "
    f" the ARG has {sts.num_trees} trees and {sts.num_samples} pangos")

# Recompute the descendants of each recombinant node in the reduced ARG
sample_desc = _samples_below(sts, recombinants)
for re_node, descendants in sample_desc.items():
    print(f"Node {re_node} ({ts.node(re_node).metadata['pango']}): {len(descendants)} descendant samples")

# The None key holds the samples that are not below any recombinant node
under_any_re = set().union(*sample_desc.values())
sample_desc[None] = set(sts.samples()) - under_any_re
print(len(sample_desc[None]), "samples not under a RE node")

The 5 trees correspond to 4 recombination nodes.
* 200039 is the Delta recombinant, and is tree-like under that
* 822854 is the BA.2 recombinant. The BA.5 recombinant lies under this
* 1189192 is the BA.5 recombinant.
* 1030562 is probably wrong: the recombination node is on the far RHS @ 27382, and includes only 4 sites that differ between parents there

This can be seen in the copying pattern below

In [None]:
from IPython.display import HTML
# Show the copying pattern for node 1030562, with a style rule that shrinks the table font
HTML(
    "".join([
        "<style>table.copying-table {font-size: 8px} table.copying-table .pattern td {font-size: 0.5em; width:0.2em}</style>",
        sc2ts.info.CopyingTable(ts, 1030562).html(child_label="1030562", show_bases=None),
    ])
)

In [None]:
tables = sts.dump_tables()
tables.reference_sequence.clear()

# Smallest right-hand coordinate among the edges leading into node 1030562
bp_1030562 = np.min(sts.edges_right[sts.edges_child == 1030562])
print(
    " We treat descendants of the 1030562 fake 'recombinant' as non-recombining, "
    f"by trimming away the RHS from {bp_1030562} onwards"
)

# Keep only the genome to the left of that coordinate, then trim the tables
tables.keep_intervals([[0, bp_1030562]], simplify=False)
tables.trim()
core_sc2ts = tables.tree_sequence()

# Per-sample table for the trimmed ARG, indexed by sample_id, with dataset metadata joined on
core_df = sc2ts.node_data(core_sc2ts)
core_df_samples = (
    core_df[core_df.is_sample]
    .set_index("sample_id")
    .join(ds.metadata.as_dataframe(["Date_tree", "Viridian_pangolin"]))
)

core_sc2ts_map = {}        # sample_id -> node ID in core_sc2ts
core_sc2ts_pango_map = {}  # designation (per `Pango`) -> sample_id
for sample_id, node_id, designation in zip(
    core_df_samples.index, core_df_samples["node_id"], core_df_samples[Pango]
):
    core_sc2ts_map[sample_id] = node_id
    core_sc2ts_pango_map[designation] = sample_id

assert core_sc2ts.num_trees == 4  # 3 breakpoints

In [None]:
# Load the Usher tree (stored as a tree sequence) that the sc2ts ARG is compared against
uts = tszip.load("../arg_postprocessing/usher_v1_2024-06-06-di.trees")
print(f"Using Usher tree with {uts.num_samples} samples")

In [None]:
usher_pango = collections.defaultdict(list)

# Map the 'strain' metadata value of each Usher sample to its node ID
usher_map = {}
for node_id in tqdm(uts.samples()):
    usher_map[uts.node(node_id).metadata['strain']] = node_id

In [None]:
joint_keys = core_sc2ts_map.keys() & usher_map.keys()
print("Found", len(joint_keys), "sample ids shared between the sc2ts ARG and the Usher tree")
pangos = {}  # Map the key (e.g. ERR10001879) to a pango
for k in joint_keys:
    try:
        pangos[k] = df.loc[k, Pango]
    except KeyError:
        # A label missing from the index makes .loc raise KeyError (not IndexError)
        print(k, "not found")
if len(pangos) != len(joint_keys):
    # Drop ids that were not found, so that every later lookup keyed on
    # joint_keys (pangos, pango_numbers) is guaranteed to succeed
    joint_keys = {k for k in joint_keys if k in pangos}
reverse_pangos = {v: k for k, v in pangos.items()}
# Number of ARG samples carrying the same designation as each shared sample
pango_numbers = {k: len(pango_lineage_samples[pangos[k]]) for k in joint_keys}


In [None]:
print("Aligning times between the Usher tree and the sc2ts tree")
# Node IDs of the shared samples in each tree sequence, in matching order
core_sc2ts_nodes = np.array([core_sc2ts_map[sample_id] for sample_id in joint_keys])
usher_nodes = np.array([usher_map[sample_id] for sample_id in joint_keys])

# Mean (Usher - sc2ts) time difference over the shared samples, rounded to 5 decimal places
time_diff = np.round(
    np.mean(uts.nodes_time[usher_nodes] - core_sc2ts.nodes_time[core_sc2ts_nodes]),
    5,
)

# Shift all Usher node and mutation times by that offset
tables = uts.dump_tables()
tables.nodes.time = tables.nodes.time - time_diff
tables.mutations.time = tables.mutations.time - time_diff
usher_ts = tables.tree_sequence()

In [None]:
def sc2ts_tanglegram(
    ts,
    size=(800, 800),
    time_scale="rank",
    x_ticks=None,
    line_gap=40,
    separation=100,
    style="",
    label="",
    omit_sites=None,
    sample_fontsize="8px",
    **kwargs
):
    """
    Draw a tanglegram of ``ts`` via ``nb_utils.tanglegram`` and return its result.

    The two sides are titled "Sc2ts <label>" and "Usher <label>", and sample
    nodes are labelled using the module-level ``pangos`` mapping, keyed on each
    sample's 'sample_id' metadata. ``omit_sites=None`` is treated as True.

    ``x_ticks``, if given, maps a time to a tick label. With
    ``time_scale="rank"`` the ticks are converted separately for the first and
    last tree of ``ts``; otherwise the same mapping is used for both sides.
    Remaining keyword arguments are passed through to ``nb_utils.tanglegram``.
    """

    def rank_ticks(tree, ticks_by_time):
        # Rank of each distinct node time among the nodes of this tree
        tree_nodes = np.array(list(tree.nodes()))
        rank_of_time = {
            t: rank for rank, t in enumerate(np.unique(ts.nodes_time[tree_nodes]))
        }
        ranked = {}
        for wanted_time, lab in ticks_by_time.items():
            # slight hack - just find the nearest time
            nearest = min(rank_of_time.keys(), key=lambda t: abs(t - wanted_time))
            ranked[rank_of_time[nearest]] = lab
        return ranked

    if omit_sites is None:
        omit_sites = True

    if x_ticks is not None:
        if time_scale == "rank":
            # rescale the X ticks on each side
            left_ticks = rank_ticks(ts.first(), x_ticks)
            right_ticks = rank_ticks(ts.last(), x_ticks)
            x_ticks = (left_ticks, right_ticks)
        else:
            x_ticks = (x_ticks, x_ticks)

    sample_labels = {u: pangos[ts.node(u).metadata['sample_id']] for u in ts.samples()}
    css = "g.tangle_lines line {stroke: lightgrey} .sample .lab {font-size: %s}" % sample_fontsize + style
    return nb_utils.tanglegram(
        ts,
        size=size,
        omit_sites=omit_sites,
        time_scale=time_scale,
        line_gap=line_gap,
        titles=(f"Sc2ts {label}", f"Usher {label}"),
        style=css,
        node_labels=sample_labels,
        separation=separation,
        x_ticks=x_ticks,
        **kwargs
    )

In [None]:
# Find the recombinant parents in the recombinant backbone


def _recombinant_parents(child):
    """
    Return {sample_id: pango} for the parents of ``child`` in ``core_sc2ts``,
    ordered left-to-right along the genome (by the right coordinate of the
    edge joining each parent to ``child``).

    Raises KeyError if a parent's pango has no entry in ``reverse_pangos``.
    """
    is_child_edge = core_sc2ts.edges_child == child
    parent_by_right = dict(
        zip(core_sc2ts.edges_right[is_child_edge], core_sc2ts.edges_parent[is_child_edge])
    )
    parents = [parent_by_right[right] for right in sorted(parent_by_right)]
    rows = df.loc[np.isin(df.node_id, parents)]
    pango_of_node = dict(zip(rows["node_id"], rows["pango"]))
    # Look the parents up one by one so the genomic order is preserved: taking
    # the matching rows directly would return them in table order instead, and
    # the order decides which parent is drawn as "left" and which as "right"
    ordered_pangos = [pango_of_node[p] for p in parents if p in pango_of_node]
    return {reverse_pangos[pango]: pango for pango in ordered_pangos}


delta_parents = _recombinant_parents(200039)
ba2_parents = _recombinant_parents(822854)

try:
    ba5_parents = _recombinant_parents(1189192)
except KeyError:
    # No representative sample exists for at least one of the BA.5 parents
    ba5_parents = {}



## Restricting to important pangos


In [None]:
import nb_utils
min_samples = 10
# Shared samples whose lineage has more than min_samples samples in the ARG
use = [k for k in joint_keys if pango_numbers[k] > min_samples]
print(f"Plotting Pango lineages that have over {min_samples} sequenced samples")

# Sample ids to plot in each tanglegram
subtree_samples = {
    "base tree": (
        # Samples not under any RE node, plus the Delta and BA.2 recombinant parents
        list({k for k in use if core_sc2ts_map[k] in sample_desc[None]} | delta_parents.keys() | ba2_parents.keys())
    ),
    "Delta subtree": [k for k in use if core_sc2ts_map[k] in sample_desc[200039]],
    "BA.2 subtree": [k for k in use if (core_sc2ts_map[k] in sample_desc[822854] and core_sc2ts_map[k] not in sample_desc[1189192])],
    "BA.5 subtree": [k for k in use if core_sc2ts_map[k] in sample_desc[1189192]],
}

two_tree_ts = {}
total = 0
for subtree, sample_names in subtree_samples.items():
    # Join the sc2ts and Usher trees, each simplified to the same set of samples
    jts = nb_utils.make_joint_ts(
        core_sc2ts.simplify([core_sc2ts_map[k] for k in sample_names], keep_input_roots=True),
        usher_ts.simplify([usher_map[k] for k in sample_names], keep_input_roots=True),
        "sample_id", "strain"
    )
    print(f"Joint tree sequence of {subtree} has {jts.num_samples} samples, classified as:")
    scorpio_counts = collections.Counter(
        [jts.node(u).metadata['scorpio'] for u in jts.samples()]
    )
    # Only single quotes inside the f-string: reusing the enclosing double
    # quote is a syntax error before Python 3.12
    print("\n".join([f"\t{'unclassified' if k == 'nan' else k}: {v} samples" for k, v in scorpio_counts.items()]))
    two_tree_ts[subtree] = jts
    total += jts.num_samples

print(f"Total number of plotted Pangos={total}")
    

In [None]:
import datetime

# Date from which times in core_sc2ts are counted (days before this date)
zero = datetime.date.fromisoformat(core_sc2ts.metadata["time_zero_date"])
months = list(range(1, 13))  # 1 to 12

# First day of each month, newest first
target_dates = [
    datetime.date(2023, 2, 1),  # 2023: Jan, Feb only
    datetime.date(2023, 1, 1),
]
for year in [2022, 2021, 2020, 2019]:
    target_dates.extend(datetime.date(year, month, 1) for month in reversed(months))

# Map "days before time zero" to a "YYYY-MM" tick label
x_ticks = {
    (zero - target_date).days: target_date.strftime("%Y-%m")
    for target_date in target_dates
}

## Untangling algorithm

We use _dendroscope_ (Huson and Scornavacca, DOI:10.1093/sysbio/sys062) to untangle the trees. This is still the most effective software for untangling trees with polytomies.

In [None]:
# Path of the Dendroscope executable launched by run_dendroscope_untangle
# (this is a macOS application path: adjust for other installations)
dendroscope_binary = "/Applications/Dendroscope/Dendroscope.app/Contents/MacOS/JavaApplicationStub"

import tempfile
import subprocess
import re

def run_dendroscope_untangle(ts_2_trees):
    """
    Untangle the two trees of ``ts_2_trees`` with Dendroscope and return the
    resulting tip orders.

    The trees are written as newick to a temporary file, Dendroscope is run on a
    command file that computes a tanglegram and overwrites the newick file, and
    the node order is then read back from it. Returns a list with one entry per
    line of the output file, each a list of the integer node IDs in the order
    they appear in that line.

    Uses the Neighbor-net heuristic algorithm, which works well with polytomies.

    Raises subprocess.CalledProcessError if Dendroscope exits with a non-zero
    status: otherwise the untouched input file would be read back and the
    original order silently returned as if it were the untangled one.
    """
    assert ts_2_trees.num_trees == 2
    with tempfile.TemporaryDirectory() as tmpdirname:
        newick_path = os.path.join(tmpdirname, "cophylo.nwk")
        command_path = os.path.join(tmpdirname, "commands.txt")
        with open(newick_path, "wt") as file:
            for tree in ts_2_trees.trees():
                print(tree.as_newick(), file=file)
        with open(command_path, "wt") as file:
            print(f"open file='{newick_path}';", file=file)
            print("compute tanglegram method=nnet", file=file)
            print(
                f"save format=newick file='{newick_path}'", file=file
            )  # overwrite
            print("quit;", file=file)
        subprocess.run([dendroscope_binary, "-g", "-c", command_path], check=True)
        order = []
        with open(newick_path, "rt") as newicks:
            for line in newicks:
                # hack: use the order of the `nX` labels encoded in the string
                order.append([int(n[1:]) for n in re.findall(r"n\d+", line)])
    return order

In [None]:
# Use the run_dendroscope_untangle above: this can take hours, so we save the results

orders = {}
for subtree, ts2 in two_tree_ts.items():
    # One cache file per subtree, designation type and sample-count threshold
    path = f"{subtree} order-{Pango}{min_samples}.json"
    if os.path.exists(path):
        # Reuse a previously saved result
        with open(path, "rt") as f:
            orders[subtree] = json.load(f)
    else:
        orders[subtree] = run_dendroscope_untangle(ts2)
        with open(path, "wt") as f:
            json.dump(orders[subtree], f)


In [None]:
from hashlib import blake2b

def plot_tanglegram(ts, name, orders, delta_parents=None, ba2_parents=None, legend_func=None, **kwargs):
    """
    Draw the tanglegram for the two-tree tree sequence ``ts`` and return the drawing.

    Edges and sample symbols are coloured by the 'scorpio' metadata value of
    each sample ("nan" is shown as "unclassified"). Internal nodes whose set of
    descendant samples is identical in the first and last tree are marked in
    magenta. If ``delta_parents`` / ``ba2_parents`` (dicts keyed by sample_id)
    are given, those samples are enlarged and coloured on the left-hand tree.

    ``orders`` is passed to ``sc2ts_tanglegram`` as ``order``; ``legend_func``,
    if given, is called with the colour list and the scorpio names and its
    result is appended to the preamble. Other keyword arguments are passed on
    to ``sc2ts_tanglegram``.
    """

    def clade_digest(tree, node):
        # Digest of the sorted sample IDs below `node`: equal digests mean equal clades
        members = " ".join(str(s) for s in sorted(tree.samples(node)))
        return blake2b(members.encode(), digest_size=20).digest()

    def internal_clades(tree):
        return {clade_digest(tree, u): u for u in tree.nodes() if not tree.is_sample(u)}

    def parent_marker_css(node, colour, sym_transform):
        # Style for one recombinant-parent sample on the left-hand tree
        return (
            f".lft_tree > .tree .node.n{lft_map[node]} .sym "
            f"{{transform: {sym_transform}; stroke: black; stroke-width: 0.5px; fill: {colour};}}"
            f".lft_tree > .tree .node.n{lft_map[node]} .lab "
            "{text-anchor: start; transform: rotate(90deg) translate(10px);}"
        )

    # Group the sample nodes by scorpio call, and index them by sample_id
    scorpios = collections.defaultdict(list)
    sample_id_to_node_id = {}
    for u in ts.samples():
        metadata = ts.node(u).metadata
        scorpio = metadata['scorpio']
        if scorpio == "nan":
            scorpio = "unclassified"
        scorpios[scorpio].append(u)
        sample_id_to_node_id[metadata["sample_id"]] = u

    l_hashes = internal_clades(ts.first())
    r_hashes = internal_clades(ts.last())
    joint_hashes = l_hashes.keys() & r_hashes.keys()

    colours = ["#77AADD", "#EE8866", "#EEDD88", "#FFAABB", "#44BB99", "#BBCC33", "#AAAA00", "#99DDFF",]

    tg, lft_map, rgt_map = sc2ts_tanglegram(ts, label=name, order=orders, symbol_size=3, x_axis=True, **kwargs)
    styles = ['.y-axis .ticks .lab {font-size: 10px; font-family: "Arial Narrow"}']
    # Add styles for edges and shared nodes
    sides = (
        (lft_map, "lft_tree", l_hashes),
        (rgt_map, "rgt_tree", r_hashes),
    )
    for nd_map, cls, hashes in sides:
        for colour, nodes in zip(colours, scorpios.values()):
            edge_selectors = ",".join([f".{cls} > .tree .n{nd_map[u]} .edge" for u in nodes])
            symbol_selectors = ",".join([f".{cls} > .tree .n{nd_map[u]}.sample .sym" for u in nodes])
            styles.append(
                edge_selectors + f"{{stroke: {colour}}}" + symbol_selectors + f"{{fill: {colour}}}"
            )
        shared_selectors = ",".join(
            [f".{cls} > .tree .n{nd_map[hashes[hsh]]} > .sym" for hsh in joint_hashes]
        )
        styles.append(shared_selectors + "{r: 2px; fill: magenta; stroke: black;}")

    legend = legend_func(colours, scorpios.keys()) if legend_func is not None else ""

    parent_colours = ("#9CDB90", "#76A8D8")
    if delta_parents:
        delta_node_ids = [sample_id_to_node_id[k] for k in delta_parents.keys()]
        for u, colour in zip(delta_node_ids, parent_colours):
            styles.append(parent_marker_css(u, colour, "translate(0px, 3px) scale(2.5)"))
    if ba2_parents:
        ba2_node_ids = [sample_id_to_node_id[k] for k in ba2_parents.keys()]
        for u, colour in zip(ba2_node_ids, parent_colours):
            styles.append(parent_marker_css(u, colour, "translate(0px, 3px) scale(2.5) rotate(45deg)"))

    tg.preamble = "<style>" + "".join(styles) + "</style>" + tg.preamble + legend
    return tg.draw()

In [None]:
def base_legend(colours, labels):
    """
    Return SVG markup for the base-tree legend: a "Scorpio classification" box
    with one coloured line per (colour, label) pair, followed by a fixed
    "Node types" box describing the Delta, BA.2 and identical-clade symbols.
    """
    scorpio_rows = []
    for i, (c, label) in enumerate(zip(colours, labels)):
        scorpio_rows.append(
            f'<line x1="5" x2="25" y1="{30 + 15*i}" y2="{30 + 15*i}" stroke="{c}" stroke-width="2" />'
            f'<text x="30" y="{34 + 15 * i}" font-size="12px">{label}</text>'
        )

    scorpio_box = "".join([
        '<g transform="translate(447 20)">',
        '<rect x="0" y="0" width="140" height="145" fill="white" stroke="black" />',
        '<text x="10" y="17" font-style="italic" font-size="14px">Scorpio classification</text>',
        *scorpio_rows,
        '</g>',
    ])

    node_types_box = "".join([
        '<g transform="translate(467 200)">',
        '<rect x="0" y="0" width="100" height="130" fill="white" stroke="black" />',
        '<text x="20" y="17" font-style="italic" font-size="14px">Node types</text>',
        '<rect x="10" height="8" y="30" width="8" stroke="black" stroke-width="1.5" fill="#9CDB90"/>',
        '<text x="25" y="38" font-size="12px">Delta (left)</text>',
        '<rect x="10" height="8" y="45" width="8" stroke="black" stroke-width="1.5" fill="#76A8D8" />',
        '<text x="25" y="53" font-size="12px">Delta (right)</text>',
        '<rect height="8" width="8" stroke="black" stroke-width="1.5" transform="translate(14, 65) rotate(45)" fill="#9CDB90" />',
        '<text x="25" y="75" font-size="12px">BA.2 (left)</text>',
        '<rect height="8" width="8" stroke="black" stroke-width="1.5" transform="translate(14, 80) rotate(45)" fill="#76A8D8" />',
        '<text x="25" y="90" font-size="12px">BA.2 (right)</text>',
        '<circle cx="14" cy="115" r="3" stroke="black" stroke-width="1.5" fill="magenta" />',
        '<text x="25" y="120" font-size="12px">Identical clade</text>',
        '</g>',
    ])
    return scorpio_box + node_types_box

name = "base tree"

# Plot the base tree, highlighting the Delta and BA.2 recombinant parents,
# with date ticks restricted to those between 400 and 1200 days before time zero
plot_tanglegram(
    two_tree_ts[name],
    name,
    orders[name],
    size=(1000, 1450),
    x_ticks={days: tick for days, tick in x_ticks.items() if 400 < days < 1200},
    delta_parents=delta_parents,
    ba2_parents=ba2_parents,
    legend_func=base_legend,
    x_label="",
    sample_fontsize="6px",
    line_gap=25,
)

In [None]:
def delta_legend(colours, labels):
    """
    Return SVG markup for the Delta-subtree legend: a "Scorpio classification"
    box with one coloured line per (colour, label) pair, followed by a fixed
    "Node types" box.

    To keep the legend narrow, a label containing "+" is wrapped onto two rows
    at its first "+": the second row starts with "   +", and the coloured line
    is drawn once, vertically centred between the two rows.
    """
    line_template = '<line x1="5" x2="25" y1="{y}" y2="{y}" stroke="{c}" stroke-width="2" />'
    # One (line markup or "", text) pair per legend row, in display order
    rows = []
    for c, label in zip(colours, labels):
        head, plus, tail = label.partition("+")
        if not plus:
            rows.append((line_template.format(c=c, y=30 + 15 * len(rows)), label))
        else:
            # Split on the first "+" only, so that a label with several "+"
            # still takes two rows and gets a single coloured line
            rows.append((line_template.format(c=c, y=30 + 15 * (len(rows) + 0.5)), head))
            rows.append(("", "   +" + tail))
    return (
        '<g transform="translate(450 370)">' +
        '<rect x="0" y="0" width="140" height="100" fill="white" stroke="black" />' +
        '<text x="10" y="17" font-style="italic" font-size="14px">Scorpio classification</text>' +
        "".join([
            f'{svgline}<text xml:space="preserve" x="30" dy="{34 + 15 * i}" font-size="12px">{text}</text>'
            for i, (svgline, text) in enumerate(rows)
        ]) +
        '</g>' +
        '<g transform="translate(455 500)">' +
        '<rect x="0" y="0" width="120" height="40" fill="white" stroke="black" />' +
        '<text x="20" y="17" font-style="italic">Node types</text>' +
        '<circle cx="14" cy="30" r="3" stroke="black" stroke-width="1" fill="magenta" />' +
        '<text x="30" y="35" font-size="14px">Identical clade</text>' +
        '</g>'
    )

name = "Delta subtree"

# Plot the Delta subtree, with date ticks restricted to those
# between 450 and 1200 days before time zero
plot_tanglegram(
    two_tree_ts[name],
    name,
    orders[name],
    size=(1000, 600),
    x_ticks={days: tick for days, tick in x_ticks.items() if 450 < days < 1200},
    time_scale="rank",
    line_gap=30,
    tweak_rh_lab=-2.5,
    legend_func=delta_legend,
    x_label="",
)

In [None]:
def ba2_legend(colours, labels):
    """
    Return SVG markup for the BA.2-subtree legend: a "Scorpio classification"
    box with one coloured line per (colour, label) pair, followed by a fixed
    "Node types" box.
    """
    scorpio_rows = []
    for i, (c, label) in enumerate(zip(colours, labels)):
        scorpio_rows.append(
            f'<line x1="5" x2="25" y1="{30 + 15*i}" y2="{30 + 15*i}" stroke="{c}" stroke-width="2" />'
            f'<text x="30" y="{34 + 15 * i}" font-size="12px">{label}</text>'
        )

    scorpio_box = "".join([
        '<g transform="translate(430 20)">',
        '<rect x="0" y="0" width="170" height="60" fill="white" stroke="black" />',
        '<text x="20" y="17" font-style="italic" font-size="14px">Scorpio classification</text>',
        *scorpio_rows,
        '</g>',
    ])

    node_types_box = "".join([
        '<g transform="translate(455 85)">',
        '<rect x="0" y="0" width="120" height="40" fill="white" stroke="black" />',
        '<text x="22" y="17" font-style="italic" font-size="14px">Node types</text>',
        '<circle cx="14" cy="30" r="4" stroke="black" stroke-width="1" fill="magenta" />',
        '<text x="30" y="35" font-size="12px">Identical clade</text>',
        '</g>',
    ])
    return scorpio_box + node_types_box

name = "BA.2 subtree"

# Plot the BA.2 subtree, with date ticks restricted to those
# between 40 and 450 days before time zero
plot_tanglegram(
    two_tree_ts[name],
    name,
    orders[name],
    size=(1000, 650),
    x_ticks={days: tick for days, tick in x_ticks.items() if 40 < days < 450},
    time_scale="rank",
    line_gap=30,
    legend_func=ba2_legend,
)


In [None]:
# Key used below to look up this subtree's joint tree sequence and saved tip order
name = "BA.5 subtree"

def ba5_legend(colours, labels):
    """
    Return SVG markup for the BA.5-subtree legend, which consists only of a
    fixed "Node types" box. Both arguments are accepted for a uniform
    legend-function signature but are not used.
    """
    parts = [
        '<g transform="translate(455 485)">',
        '<rect x="0" y="0" width="120" height="40" fill="white" stroke="black" />',
        '<text x="22" y="17" font-style="italic" font-size="14px">Node types</text>',
        '<circle cx="14" cy="30" r="4" stroke="black" stroke-width="1" fill="magenta" />',
        '<text x="30" y="35" font-size="12px">Identical clade</text>',
        '</g>',
    ]
    return "".join(parts)

# Plot the BA.5 subtree, with date ticks restricted to those
# between 0 and 1100 days before time zero
plot_tanglegram(
    two_tree_ts[name],
    name,
    orders[name],
    size=(1000, 1000),
    x_ticks={days: tick for days, tick in x_ticks.items() if 0 < days < 1100},
    time_scale="rank",
    legend_func=ba5_legend,
    line_gap=30,
)