In [1]:
import msprime
import math
import collections
import itertools
import numpy as np
from IPython.core.display import HTML

In [2]:
# Markup that pulls in the x3dom JavaScript library and its stylesheet.
# html_3D_tree() places this at the very start of the HTML it returns.
header = (
    '<script src="http://www.x3dom.org/download/x3dom.js"></script>\n'
    '<link rel="stylesheet" href="http://www.x3dom.org/download/x3dom.css">'
)

In [3]:
import types

# Plain namespace standing in for a drawing object: the module-level layout
# functions below take it as their ``self`` argument and read/write its
# attributes (canvas size here; tree and coordinates are attached later).
d = types.SimpleNamespace(
    discretise_coordinates=False,
    _width=100,
    _height=100,
)

# Node id used to mean "no node", e.g. when comparing against the parent of a root.
NULL_NODE = -1

In [4]:
def identify_pruned_subtree(edges_out, edges_in):
    """Work out which node was pruned in a single SPR move between two trees.

    ``edges_out`` are the edges removed and ``edges_in`` the edges added when
    moving from one tree to the next; each edge must expose ``.parent`` and
    ``.child``. Only 3 or 4 changed edges are supported (asserted), which
    assumes bifurcating trees with infinite recombination locations.

    When the subtree is grafted back onto one of the pruned edges it is
    logically impossible to tell from the edge changes which of the two
    daughter clades was pruned. That case does not change the topology, so
    it is irrelevant for display purposes and ``None`` is returned.

    Returns the id of the pruned node, or ``None`` when no topology-changing
    move could be identified.
    """
    n_removed = len(edges_out)
    assert n_removed == len(edges_in), (edges_out, edges_in)
    assert 3 <= n_removed <= 4, (edges_out, edges_in)

    # Index the removed edges in both directions.
    children_of = collections.defaultdict(set)
    parents_of = collections.defaultdict(set)
    for edge in edges_out:
        children_of[edge.parent].add(edge.child)
        parents_of[edge.child].add(edge.parent)

    def lost_both_children(n_parents_lost):
        """Nodes that lost two child edges and ``n_parents_lost`` parent edges."""
        return [
            node
            for node, kids in children_of.items()
            if len(kids) == 2 and len(parents_of.get(node, ())) == n_parents_lost
        ]

    # The parent of the moved node is the only one that is the parent of two
    # removed edges and the child of one.
    candidates = lost_both_children(1)
    if not candidates and n_removed == 3:
        # Possibly a move at the root (the old root has no removed parent
        # edge), or a graft back onto the same branch.
        candidates = lost_both_children(0)
        if len(candidates) != 1:
            return None
        # Root case: the unmoved child becomes the new root, so it never shows
        # up as a child in the added edges; the first old child that does is
        # the one that moved.
        siblings = children_of[candidates[0]]
        for edge in edges_in:
            if edge.child in siblings:
                return edge.child
        assert False

    assert len(candidates) == 1, (edges_out, edges_in, candidates)
    siblings = children_of[candidates[0]]
    assert len(siblings) == 2

    # Internal move: the sibling that stayed put is re-attached directly to
    # the old grandparent.
    (grandparent,) = parents_of[candidates[0]]
    stayed = [
        edge.child
        for edge in edges_in
        if edge.child in siblings and edge.parent == grandparent
    ]
    if not stayed and n_removed == 3:
        # Graft onto the same branch: topology unchanged.
        return None
    assert len(stayed) == 1, (edges_out, edges_in, siblings, stayed)
    return next((node for node in siblings if node != stayed[0]), None)


def _discretise(self, x):
        """
        Discetises the specified value, if necessary.
        """
        ret = x
        if self.discretise_coordinates:
            ret = int(round(x))
        return ret

def _assign_x_coordinates(self, node, sample_order_dict=None):
    """
    Recursively fill ``self._x_coords`` for ``node`` and everything below it.

    Leaves are laid out left to right in visiting order, one ``self._x_scale``
    step apart, using and advancing the ``self._leaf_x`` counter. An internal
    node is centred between the smallest and largest x of its children.

    Children are visited in increasing order of the largest rank found among
    the samples beneath them: the rank comes from ``sample_order_dict``
    (sample id -> rank) when one is supplied, otherwise the sample id itself
    is used. This keeps samples in the requested order where possible.
    """
    tree = self._tree
    if not tree.is_internal(node):
        # Leaf: take the next free slot.
        self._x_coords[node] = _discretise(self, self._leaf_x * self._x_scale)
        self._leaf_x += 1
        return

    if sample_order_dict:
        def subtree_rank(child):
            return max([sample_order_dict[s] for s in tree.samples(child)])
    else:
        def subtree_rank(child):
            return max(tree.samples(child))

    ordered_children = sorted(tree.children(node), key=subtree_rank)
    for child in ordered_children:
        _assign_x_coordinates(self, child, sample_order_dict)

    child_xs = [self._x_coords[child] for child in ordered_children]
    leftmost = min(child_xs)
    rightmost = max(child_xs)
    self._x_coords[node] = _discretise(self, leftmost + (rightmost - leftmost) / 2)


def _assign_coordinates(self, log_y=False, best_tip_order=None, return_tip_order=True, pruned_node=None):
    """
    Assign geometrical coordinates to the nodes of ``self._tree``.

    Populates ``self._x_coords``, ``self._y_coords``, ``self._x_scale``,
    ``self._y_scale``, ``self._leaf_x`` and ``self._mutations`` (a list of
    ``((x, y), mutation)`` pairs).

    Parameters
    ----------
    log_y : bool
        If true, flip the y coordinates and rescale them logarithmically.
    best_tip_order : sequence of tip ids, optional
        Preferred left-to-right tip order, e.g. the order used for the
        previous tree. Tips are kept in (or near) this order.
    return_tip_order : bool
        If true, return the resulting tip order; otherwise return ``None``.
    pruned_node : node id, optional
        If given together with ``best_tip_order``, the tree is assumed to be
        a single SPR move away from that order, and the layout tries to
        differ from it only by moving ``pruned_node`` and its tips. Node id
        0 is a valid value here.

    Returns
    -------
    collections.OrderedDict or None
        Mapping of tip id to x position, ordered by increasing x.
    """
    self._x_coords = {}
    self._y_coords = {}

    y_padding = 20
    t = self._tree.tree_sequence.max_root_time
    # In pathological cases all the roots are at time 0; fall back to a unit
    # scale rather than dividing by zero below.
    if t == 0:
        t = 1
    # Do we have any mutations over a root?
    mutations_over_root = any(
        self._tree.parent(mut.node) == NULL_NODE for mut in self._tree.mutations())
    root_branch_length = 0
    if mutations_over_root:
        # Allocate a fixed amount of space to show the mutations on the
        # 'root branch'
        root_branch_length = self._height / 10
    self._y_scale = (self._height - root_branch_length - 2 * y_padding) / t
    # Pseudo-coordinate for "the parent of a root", used when placing
    # mutations that sit above a root.
    self._y_coords[NULL_NODE] = y_padding
    for u in self._tree.nodes():
        scaled_t = self._tree.get_time(u) * self._y_scale
        self._y_coords[u] = self._height - scaled_t - y_padding
    self._x_scale = self._width / (self._num_leaves + 2)
    self._leaf_x = 1

    # Build the tip -> rank map once; it does not depend on the root.
    tip_order_dict = None
    if best_tip_order:
        # NB: compare against None, because node id 0 is a legitimate pruned node.
        if pruned_node is not None:
            pruned_tips = set(self._tree.leaves(pruned_node))
            # Give the pruned tips strictly negative ranks (keeping their
            # relative order) so they are never sorted in amongst the other
            # tips, whose ranks are all >= 1.
            offset = len(best_tip_order)
            tip_order_dict = {
                x: (n - offset if x in pruned_tips else n + 1)
                for n, x in enumerate(best_tip_order)}
        else:
            tip_order_dict = {x: n for n, x in enumerate(best_tip_order)}
    for root in self._tree.roots:
        _assign_x_coordinates(self, root, tip_order_dict)

    self._mutations = []
    node_mutations = collections.defaultdict(list)
    for site in self._tree.sites():
        for mutation in site.mutations:
            node_mutations[mutation.node].append(mutation)
    for child, mutations in node_mutations.items():
        n = len(mutations)
        parent = self._tree.parent(child)
        # Ignore any mutations that are above non-roots that are
        # not in the current tree.
        if child in self._x_coords:
            x = self._x_coords[child]
            y1 = self._y_coords[child]
            y2 = self._y_coords[parent]
            # Spread the n mutations evenly along the branch.
            chunk = (y2 - y1) / (n + 1)
            for k, mutation in enumerate(mutations):
                z = x, _discretise(self, y1 + (k + 1) * chunk)
                self._mutations.append((z, mutation))

    if log_y:
        # NOTE(review): the mutation positions above were computed from the
        # linear y coordinates and are not transformed here - confirm whether
        # that is intended before relying on log_y=True.
        max_y = max(self._y_coords.values())
        min_y = min(self._y_coords.values())

        # Flip the axis within [min_y, max_y] ...
        for k, y in self._y_coords.items():
            self._y_coords[k] = max_y-y+min_y

        # ... then take logs, rescaled so that max_y maps onto itself.
        remax = max_y/math.log(max_y)
        for k, y in self._y_coords.items():
            self._y_coords[k] = math.log(y) * remax

    if return_tip_order:
        tips = {i:self._x_coords[i] for i in self._tree.leaves()}
        return collections.OrderedDict(sorted(tips.items(), key=lambda t: t[1]))

In [5]:
class X3DOM_tree_seq:
    """
    Accumulates x3dom markup for the nodes, mutations and edges of a tree
    sequence.

    Each kind of shape is fully defined (``DEF``) only the first time it is
    added; every later occurrence is emitted as a lightweight ``USE``
    reference to that definition. The markup built so far is available from
    :meth:`as_string` (and the ``x3dom_string`` attribute).
    """

    def __init__(self):
        self.branch_width = 0.4
        # We only need to "draw" a single node, mutation and edge for the
        # entire tree seq and after that we just use an x3dom copy. The
        # following booleans record whether each has been drawn yet.
        self.node_drawn = False
        self.mutation_drawn = False
        self.edge_drawn = False
        self.x3dom_string = ''

    def add_node(self, x, y, z=0):
        """Append a node sphere translated to ``(x, y, z)``."""
        self.x3dom_string += '<Transform translation="{x} {y} {z}">'.format(x=x, y=y, z=z)
        if not self.node_drawn:
            self.x3dom_string += (
                '<Shape DEF="node">'
                  '<Appearance>'
                    '<Material diffuseColor="0 0 1" specularColor=".5 .5 .5" />'
                  '</Appearance>'
                  '<Sphere radius="0.5" />'
                '</Shape>')
            # Assignment, not comparison: later nodes must reuse this definition.
            self.node_drawn = True
        else:
            self.x3dom_string += '<Shape USE="node"/>'
        self.x3dom_string += '</Transform>'

    def add_mutation(self, x, y, z=0):
        """Append a mutation sphere translated to ``(x, y, z)``."""
        self.x3dom_string += '<Transform translation="{x} {y} {z}">'.format(x=x, y=y, z=z)
        # Track mutations with their own flag, independent of whether any node
        # has been drawn, so the first mutation always carries the definition.
        if not self.mutation_drawn:
            self.x3dom_string += (
                '<Shape DEF="mutation">'
                  '<Appearance>'
                    '<Material diffuseColor="1 0.1 0" specularColor="1 1 1" />'
                  '</Appearance>'
                  '<Sphere radius="1" />'
                '</Shape>')
            self.mutation_drawn = True
        else:
            self.x3dom_string += '<Shape USE="mutation"/>'
        self.x3dom_string += '</Transform>'

    def add_edge(self, from_xy, to_xy, z=0):
        """
        Append a right-angled edge between two ``(x, y)`` points in plane ``z``.

        The edge is drawn from the corner ``(from_xy[0], to_xy[1])``: one unit
        line scaled by the y difference, plus a second unit line rotated by a
        quarter turn and scaled by the absolute x difference, with a spherical
        cap at the corner.
        """
        dy = from_xy[1] - to_xy[1]
        dx = from_xy[0] - to_xy[0]

        self.x3dom_string += '<Transform translation="{x} {y} {z}">'.format(x=from_xy[0],y=to_xy[1],z=z)
        if not self.edge_drawn:
            # First edge: define the cap, the shared material and the unit line.
            self.x3dom_string += (
                '<Shape DEF="linecap">'
                    '<Appearance>'
                        '<Material diffuseColor="0 0 1" specularColor=".5 .5 .5" DEF="edgecolour" />'
                    '</Appearance>'
                    '<Sphere radius="{w}" />'
                '</Shape>'
                '<Transform scale="1 {l} 1">'
                    '<Group DEF="line">'
                        '<Transform translation="0 0.5 0">'
                            '<Shape>'
                                '<Appearance>'
                                    '<Material USE="edgecolour"/>'
                                '</Appearance>'
                                '<Cylinder radius="{w}" height="1.0"/>'
                            '</Shape>'
                        '</Transform>'
                    '</Group>'
                '</Transform>').format(w=self.branch_width, l=dy)
            self.edge_drawn = True
        else:
            self.x3dom_string += (
                '<Shape USE="linecap"/>'
                '<Transform scale="1 {l} 1">'
                  '<Group USE="line"/>'
                '</Transform>').format(l=dy)

        if dx < 0:
            # branches to right
            self.x3dom_string += (
                '<Transform scale="1 {l} 1" rotation="0.0 0.0 1.0 -1.570796">'
                    '<Group USE="line"/>'
                '</Transform>').format(l=-dx)
        else:
            self.x3dom_string += (
                '<Transform scale="1 {l} 1" rotation="0.0 0.0 1.0 1.570796">'
                    '<Group USE="line"/>'
                '</Transform>').format(l=dx)

        self.x3dom_string += '</Transform>'

    def as_string(self):
        """Return all the markup accumulated so far."""
        return self.x3dom_string



In [6]:
def add_tree(self, x3dom_ts, plane=0):
    """
    Add the current tree (``self._tree``) to an x3dom drawer.

    For every node, calls ``x3dom_ts.add_node`` at the node's assigned
    coordinates, and ``x3dom_ts.add_edge`` from the node to its parent when it
    has one. Then calls ``x3dom_ts.add_mutation`` for every entry in
    ``self._mutations``. ``plane`` is passed through as the z coordinate of
    everything drawn.

    Requires ``self._x_coords``, ``self._y_coords`` and ``self._mutations`` to
    have been assigned already. Node and mutation labels are not rendered.
    """
    tree = self._tree
    for u in tree.nodes():
        p = self._x_coords[u], self._y_coords[u]
        x3dom_ts.add_node(p[0], p[1], plane)
        v = tree.parent(u)
        if v != NULL_NODE:
            # Not a root: connect the node to its parent.
            q = self._x_coords[v], self._y_coords[v]
            x3dom_ts.add_edge(p, q, plane)

    for (x, y), _mutation in self._mutations:
        x3dom_ts.add_mutation(x, y, plane)
        


In [7]:
def _append_pinched_strip_points(strip_coords, tip_positions, moved_tips, z_plot, half_width, shrink_factor):
    """
    Append 'pinched' strip points for the run of moved tips in ``tip_positions``.

    ``tip_positions`` is an ordered mapping of tip id -> x position. The tips
    in ``moved_tips`` must form a single contiguous run in that ordering
    (asserted). Their x positions are shrunk towards the run's mean by
    ``shrink_factor``, the strip width is narrowed by the same factor, and one
    left/right point per tip is appended to ``strip_coords`` at height -1.0
    and depth ``z_plot``.
    """
    run_seen = False
    for is_moved, run in itertools.groupby(tip_positions.items(), lambda item: item[0] in moved_tips):
        if not is_moved:
            continue
        assert not run_seen  # There should only be one continuous run of moved tips
        run_seen = True
        run_tips = collections.OrderedDict(run)
        xs = np.array(list(run_tips.values()))
        centre = np.mean(xs)
        xs = (xs - centre) / shrink_factor + centre
        for tip, x in zip(run_tips, xs):
            strip_coords[tip]['l'].append((x - half_width / shrink_factor, -1.0, z_plot))
            strip_coords[tip]['r'].append((x + half_width / shrink_factor, -1.0, z_plot))


def html_3D_tree(ts):
    """
    Build an HTML/x3dom string showing every tree of ``ts`` in 3D.

    Each tree is drawn in its own plane, positioned along z at the midpoint of
    its genomic interval (scaled by 0.1). Below the trees, every sample gets a
    strip running along z that follows the sample's x position from tree to
    tree; where a single SPR move reorders the tips, the affected strips are
    pinched together and dipped while they cross over.

    The layout state is kept on the module-level namespace ``d``, which is
    modified as a side effect. The returned string starts with ``header``.
    """
    z_scale = 0.1  # reduce the genome coord by this much for plotting
    relative_strip_width = 0.7
    # how much graphical space to allow for the diagonal swapping lines
    # (this is a maximum value because when the tree span is small, we will need to shrink it)
    max_z_swap_dist = 100  # base pairs / genomic distance
    SPR_shrink_factor = 5

    # Per sample: the left and right edge points of its strip, in z order.
    strip_coords = {tip_id: {'l': [], 'r': []} for tip_id in ts.samples()}

    drawer = X3DOM_tree_seq()
    first_order = None
    last_order = None
    for tree, (interval, e_out, e_in) in zip(ts.trees(sample_lists=True), ts.edge_diffs()):
        # Plot for each tree (and remember which were clumped in an SPR)
        d._tree = tree
        leaves = list(tree.leaves())
        d._num_leaves = len(leaves)
        if last_order is None:
            # this is the first tree, get some initial coords and set up constants
            order = _assign_coordinates(d, log_y=False)
            first_order = order.copy()
            dist_between_tips = np.mean(np.diff(np.array(list(order.values())))).item()
            strip_half_width = dist_between_tips / 2 * relative_strip_width

            # set up the first points
            z_plot = interval[0] * z_scale
            for n, x in order.items():
                strip_coords[n]['l'].append((x - strip_half_width, 0.0, z_plot))
                strip_coords[n]['r'].append((x + strip_half_width, 0.0, z_plot))
        else:
            # we plot per tree - at a minimum we should have 1/4 of the space allocated for the diagonals
            z_swap_dist = min(max_z_swap_dist, last_span / 4)

            # Need to find potential SPR to this tree
            assert 2 <= len(e_out) <= 4
            previous_tips = list(last_order.keys())
            tips_moved_to_make_this_tree = set()
            if len(e_out) == 2:
                # must be the same order
                order = _assign_coordinates(d, log_y=False, best_tip_order=previous_tips)
                for o1, o2 in zip(last_order, order):
                    assert o1 == o2
            else:
                # 3 or 4 edges changed: here we calculate the SPR
                moved_node = identify_pruned_subtree(e_out, e_in)
                order = _assign_coordinates(
                    d, log_y=False, best_tip_order=previous_tips, pruned_node=moved_node)
                if moved_node is not None:
                    tips_moved_to_make_this_tree = set(tree.leaves(moved_node))

            # If no tips moved, everything stays the same - nothing to add.
            if tips_moved_to_make_this_tree:
                # Strips that need new points: the SPR tips plus any tip whose
                # x position changed as a consequence.
                shifted = [
                    l for l in leaves
                    if l in tips_moved_to_make_this_tree or last_order[l] != order[l]]

                # Anchor the old positions just before the tree boundary.
                z_plot = (interval[0] - z_swap_dist) * z_scale
                for l in shifted:
                    strip_coords[l]['l'].append((last_order[l] - strip_half_width, 0.0, z_plot))
                    strip_coords[l]['r'].append((last_order[l] + strip_half_width, 0.0, z_plot))

                # Do the constriction (only needed for SPR tips)
                _append_pinched_strip_points(
                    strip_coords, last_order, tips_moved_to_make_this_tree,
                    (interval[0] - z_swap_dist * 0.5) * z_scale,
                    strip_half_width, SPR_shrink_factor)

                # Do the expansion, now limited by the span of the new tree
                z_swap_dist = min(max_z_swap_dist, tree.span / 4)
                _append_pinched_strip_points(
                    strip_coords, order, tips_moved_to_make_this_tree,
                    (interval[0] + z_swap_dist * 0.5) * z_scale,
                    strip_half_width, SPR_shrink_factor)

                # Anchor the new positions just after the tree boundary.
                z_plot = (interval[0] + z_swap_dist) * z_scale
                for l in shifted:
                    strip_coords[l]['l'].append((order[l] - strip_half_width, 0.0, z_plot))
                    strip_coords[l]['r'].append((order[l] + strip_half_width, 0.0, z_plot))

        add_tree(d, drawer, (interval[0] + interval[1]) / 2 * z_scale)

        last_order = order
        last_span = tree.span

    # add the last points for the strips - cheat by tailing them off into the distance using +10000
    z_plot = (ts.sequence_length + 10000) * z_scale
    for l, p in order.items():
        strip_coords[l]['l'].append((p - strip_half_width, 0.0, z_plot))
        strip_coords[l]['r'].append((p + strip_half_width, 0.0, z_plot))

    # Collect the markup in a list and join once at the end.
    parts = [
        header,
        "<x3d width='1600px' height='1600px'>",
        "<Scene>",
        '<Viewpoint position="19.88623 6.59745 4.10251" orientation="-0.41812 0.88937 0.18492 0.98590" zNear="12.46871" zFar="45.50975" description="defaultX3DViewpointNode"></Viewpoint>',
        # alternative viewpoint:
        # '<Viewpoint position="17.58460 16.28339 7.16874" orientation="-0.68793 0.63583 0.34996 1.06339" zNear="12.46871" zFar="45.50975" description="defaultX3DViewpointNode"></Viewpoint>'
        "<Transform scale ='0.1 0.1 0.1' translation='0 5 0.5' rotation='1 0 0 3.1415'>",
        drawer.as_string(),
    ]

    def corner(p):
        # Strip heights are stretched (x2) and shifted (+80) when plotted.
        return "{:.3f} {:.3f} {:.3f}, ".format(p[0], p[1]*2+80, p[2])

    max_order = max(first_order.keys())
    for o, tip in enumerate(reversed(first_order.keys())):
        left = strip_coords[tip]['l']
        right = strip_coords[tip]['r']
        # Each strip gets its own colour offset, by position in the first tree's order.
        shade = (o / max_order) * 0.2
        for i in range(len(left) - 1):
            # One quad per consecutive pair of points along the strip.
            parts.append("<Shape>")
            parts.append("<IndexedFaceSet coordIndex='0 1 2 3' solid='false'><Coordinate point='")
            parts.append(corner(left[i]))
            parts.append(corner(left[i + 1]))
            parts.append(corner(right[i + 1]))
            parts.append(corner(right[i]))
            parts.append(corner(left[i]))
            parts.append("' /></IndexedFaceSet>")
            parts.append(
                "<Appearance><Material emissiveColor='{r} {g} {b}' specularColor='0.5 0.5 0.5' /></Appearance>".format(
                    r=0.8 + shade,
                    # slightly less green on segments that start at a dipped (non-zero height) point
                    g=(0.2 if left[i][1] else 0.3) + shade,
                    b=shade))
            parts.append("</Shape>")
    parts.append("</Transform></Scene>")
    parts.append("</x3d>")
    return "".join(parts)

In [8]:
# get a nice example, with evenly spaced trees and nice SPRs
#
# "Nice" here means: exactly 3 trees, the shortest span more than 0.4 of the
# longest, and 4-edge SPRs that move both a 2-tip subtree and a subtree with
# more than 2 tips.
#
# The search starts from a pinned seed so that the notebook is reproducible.
# If that seed ever stops qualifying (e.g. under a different simulator
# version) we fall back to drawing fresh seeds, rather than looping forever
# on the same one.
# Seed values left in a comment by the original author (significance not
# recorded): [39059, 4831264, 1091849950, 620444574]
n=0
done = 0
seed = 3763020057
while done != 0b11:
    done = 0
    n += 1
    if n > 1:
        # Draw from [1, 2**32 - 2]. The explicit 64-bit dtype avoids overflow
        # where numpy's default integer is 32 bits wide, and starting at 1
        # avoids a zero seed.
        seed = int(np.random.randint(1, 2**32-1, dtype=np.int64))

    ts = msprime.simulate(8, length=0.2e4, Ne=1e4, recombination_rate=1e-8, random_seed=seed)
    samples = set(ts.samples())
    if ts.num_trees == 3:
        spans = [t.span for t in ts.trees()]

        if min(spans)/max(spans) > 0.4:
            for tree, (interval, e_out, e_in) in zip(ts.trees(), ts.edge_diffs()):
                if len(e_out)==4:
                    moved_node = identify_pruned_subtree(e_out, e_in)
                    n_moved_tips = len(list(tree.leaves(moved_node)))
                    print(n_moved_tips, end=" ")
                    if n_moved_tips == 2:
                        done = done | 1
                    if n_moved_tips > 2:
                        done = done | 2
print(seed)

3 2 3763020057


In [12]:
# Simulate the tree sequence from the chosen seed, then add mutations to it.
# The result, `ts`, is what the plotting cells below display.

ts_without_mutations = msprime.simulate(
    8,
    length=0.2e4,
    Ne=1e4,
    recombination_rate=1e-8,
    random_seed=3763020057,
)
ts = msprime.mutate(ts_without_mutations, 5e-8, random_seed=1234)

In [13]:
# Flat SVG rendering of the same tree sequence, for comparison with the 3D view.
print("SVG tree")
svg_markup = ts.draw_svg()
display(HTML(svg_markup))

SVG tree


In [14]:
print("3D rotatable tree in browser")
# html_3D_tree() already begins its output with `header` (the x3dom script and
# stylesheet tags), so it is not prepended again here; doing so would include
# the x3dom script twice.
display(HTML(html_3D_tree(ts)))

3D rotatable tree in browser
