In [164]:
%%javascript
// Tell the notebook's RequireJS loader where the D3 v4 build lives.
require.config({
    paths: {
        d3: 'https://d3js.org/d3.v4.min'
    }
});

// Load D3 through RequireJS and publish it on `window` so that scripts
// injected later can refer to the global `d3` directly.
require(["d3"], function (d3Module) {
    window.d3 = d3Module;
});

<IPython.core.display.Javascript object>

In [187]:
import msprime
import numpy as np
import random
from string import Template
from IPython.display import HTML, Javascript, display
import time


def send_to_D3(arg):
    """
    Render the D3 visualisation for ``arg`` in the notebook output area.

    The contents of ``visualizer.js`` (read from the current working
    directory) are treated as a ``string.Template``; every key of ``arg`` is
    substituted into it, together with a generated ``divnum`` that identifies
    the ``<div>`` the script is paired with. The resulting ``<div>`` +
    ``<script>`` snippet is then displayed as HTML.

    Parameters
    ----------
    arg : dict
        Template values for ``visualizer.js``, e.g.
        ``{"arg": graph, "width": width, "height": height}``.
        The caller's dictionary is not modified.

    Returns
    -------
    The return value of ``IPython.display.display``.

    This function is adapted from https://github.com/stitchfix/d3-jupyter-tutorial
    """
    # Random suffix so several plots in one notebook get distinct div ids.
    divnum = random.randrange(10_000_000_000)

    # Work on a copy: adding 'divnum' must not leak into the caller's dict.
    substitutions = dict(arg)
    substitutions['divnum'] = divnum

    # Read the script with a context manager so the file handle is always closed.
    with open("visualizer.js", 'r') as js_file:
        main_text_template = Template(js_file.read())

    # safe_substitute leaves any '$' placeholder that has no matching key
    # untouched instead of raising.
    main_text = main_text_template.safe_substitute(substitutions)
    JS_text = Template("<div id='arg_${divnum}'></div><script>$main_text</script>")
    return display(HTML(JS_text.safe_substitute({'divnum': divnum, 'main_text': main_text})))

def draw_full_arg(ts, width=600, height=600, tree_highlighting=True):
    """
    Draw the full ancestral recombination graph held in ``ts`` with D3.

    Node, link and breakpoint records are built from the tree sequence tables
    and handed to ``send_to_D3`` together with the plot dimensions.

    Parameters
    ----------
    ts : tree sequence
        Expected to have been simulated with ``record_full_arg=True`` so that
        recombination nodes are flagged and stored as consecutive ID pairs
        (the code merges every second recombination node into the one before it).
    width, height : number
        Size of the plot in pixels. A 50 pixel margin is kept on each side.
    tree_highlighting : bool
        When true, one record per local tree interval is produced and the
        height passed to D3 is increased by 100.

    Returns
    -------
    Whatever ``send_to_D3`` returns.
    """
    # Node flag values used in the node table.
    sample_flag = 1
    recombination_flag = 131072    # 1 << 17, msprime.NODE_IS_RE_EVENT
    common_ancestor_flag = 262144  # 1 << 18, msprime.NODE_IS_CA_EVENT

    # Fetch the tables and the columns we need once instead of on every loop iteration.
    tables = ts.tables
    node_flags = tables.nodes.flags
    edge_parent = tables.edges.parent
    edge_child = tables.edges.child
    edge_left = tables.edges.left
    edge_right = tables.edges.right

    recombination_ids = np.where(node_flags == recombination_flag)[0]

    # Parameters for the dimensions of the D3 plot. Eventually want to handle this entirely in JS.
    # Denominators are clamped to 1 so a single sample / no internal nodes cannot divide by zero.
    w_spacing = (width-100) / max(ts.num_samples - 1, 1)
    h_spacing = (height-100) / max(ts.num_nodes - ts.num_samples - len(recombination_ids)/2, 1)

    if w_spacing < 40 or h_spacing < 40:
        print("WARNING, this plot may be too small to properly handle pathing. Try a larger size!")

    # Ordering of sample nodes is the same as the first tree in the sequence
    ordered_nodes = [
        node
        for node in ts.first().nodes(order="minlex_postorder")
        if node < ts.num_samples
    ]

    # Determines the rank (y position) of each time point
    time_rank = {float(t): rank for rank, t in enumerate(np.unique(tables.nodes.time))}

    # The two recombination node IDs will be merged together in this visualization:
    # every second recombination node is dropped and represented by the ID before it.
    recombination_nodes_to_merge = set(recombination_ids[1::2].tolist())

    # Builds the nodes json. A "reference" is the id of another node that is used to determine a property in the
    # graph. Example: recombination nodes should have the same x position as their child, unless their child is
    # also a recombination node. This isn't yet implemented automatically in the layout as it breaks the force
    # layout.
    nodes = []
    for ID, node in enumerate(tables.nodes):
        if node.flags == recombination_flag and ID in recombination_nodes_to_merge:
            continue
        info = {
            "id": ID,
            "flag": node.flags,
            "time": node.time,
            "fy": height-(time_rank[node.time]*h_spacing)-50 #fixed y position, property of force layout
        }
        label = ID
        if node.flags == sample_flag:
            info["fx"] = ordered_nodes.index(ID)*w_spacing+50 #sample nodes have a fixed x position
        elif node.flags == recombination_flag:
            label = str(ID)+"/"+str(ID+1)
            # int(...) keeps numpy scalars out of the dict that is later stringified for JS
            info["x_pos_reference"] = int(edge_child[edge_parent == ID][0])
        elif node.flags == common_ancestor_flag:
            info["x_pos_reference"] = int(edge_child[edge_parent == ID][0])
        info["label"] = label #label which is either the node ID or two node IDs for recombination nodes
        nodes.append(info)

    # Builds the edges json. For recombination nodes, replaces the larger number with the smaller. The direction
    # that the edge should go relates to the positions of not just the nodes connected by that edge, but also the
    # other edges connected to the child. See the JS for all of the different scenarios; still working through
    # that.
    links = []
    for edge in tables.edges:
        if edge.parent in recombination_nodes_to_merge:
            # Edges above a merged-away recombination node are folded into its partner's edge.
            continue
        child = edge.child
        alternative_child = ""
        alternative_parent = ""
        left = edge.left
        right = edge.right
        if node_flags[edge.parent] != recombination_flag:
            # First other child of this parent.
            children = edge_child[edge_parent == edge.parent]
            alternative_child = int(children[children != edge.child][0])
            if alternative_child in recombination_nodes_to_merge:
                alternative_child -= 1
        else:
            # Widen the interval so it also covers the first edge above the partner node (ID + 1).
            partner_rows = edge_parent == edge.parent + 1
            left = min(left, float(edge_left[partner_rows][0]))
            right = max(right, float(edge_right[partner_rows][0]))
        if node_flags[edge.child] == recombination_flag:
            if edge.child in recombination_nodes_to_merge:
                alt_id = edge.child - 1
            else:
                alt_id = edge.child + 1
            alternative_parent = int(edge_parent[edge_child == alt_id][0])
        if edge.child in recombination_nodes_to_merge:
            child = edge.child - 1
        links.append({
            "source": edge.parent,
            "target": child,
            "left": left,
            "right": right,
            "alt_parent": alternative_parent, #recombination nodes have an alternative parent
            "alt_child": alternative_child
        })

    # One record per local tree interval, positioned proportionally along the plot width.
    breakpoints = []
    if tree_highlighting:
        height += 100
        start = 0
        for bp in ts.breakpoints():
            bp = float(bp)
            if bp != 0:
                breakpoints.append({
                    "start": start,
                    "stop": bp,
                    "x_pos":(start/ts.sequence_length)*width,
                    "width":((bp - start)/ts.sequence_length)*width
                })
                start = bp

    return send_to_D3({
        "arg":{
            "nodes":nodes,
            "links":links,
            "breakpoints": breakpoints
        },
        "width":width,
        "height":height,
        "tree_highlighting":str(tree_highlighting).lower()
    })
    
# Simulate a small ancestry with a freshly drawn seed and plot it.
# record_full_arg=True is required so that recombination nodes are marked in the output.
rs = random.randint(0, 10000)
ts = msprime.sim_ancestry(
    samples=5,
    sequence_length=2_000,
    population_size=10_000,
    recombination_rate=1e-8,
    random_seed=rs,
    record_full_arg=True,
)
# Report the seed so an interesting graph can be reproduced later.
print("random seed:", rs)
# print(ts.draw_text())  # optional plain-text rendering of the trees
draw_full_arg(ts)




random seed: 2020


In [25]:
# Generate a tree sequence from a fixed seed (reproducible example) with record_full_arg=True
# so that you get marked recombination nodes.
# The seed is defined once and used for both the simulation and the printed message.
rs = 9203
ts = msprime.sim_ancestry(
    samples=2,
    recombination_rate=1e-8,
    sequence_length=2_000,
    population_size=10_000,
    record_full_arg=True,
    random_seed=rs
)
print("random seed:", rs)
#print(ts.draw_text())
draw_full_arg(ts)

random seed: 9203
