In [1]:
import toytree
import toyplot
import numpy as np
from ipcoal.smc.smc5 import tree_unch_prob_bt,topo_unch_prob_bt
import pandas as pd

## Set relative node heights from divergence times measured with a ruler on the figure, then scale them (arbitrarily) so that the species-tree root sits at 1e6 generations.

In [2]:
# Generations per ruler unit measured on the figure. The species-tree root is
# 4 ruler units tall, so this places it at 4 * 250_000 = 1e6 generations.
ruler_scaler = 1_000_000 // 4

### Species tree:

In [3]:
# Species-tree topology ((a,b),c),d. The tip edges are passed through
# edges_extend_tips_to_align (presumably so all tips line up), then the tree is
# drawn with tree style "s".
st = toytree.tree("(((a,b),c),d);")
st = st.mod.edges_extend_tips_to_align()
st.draw(ts="s");

In [4]:
# Species-tree divergence times in generations (ruler units x ruler_scaler).
# Node idx 4, 5, 6 are presumably the (a,b), ((a,b),c) and root nodes;
# every node not in the mapping (the tips) gets height 0.
wab, wabc, wabcd = (units * ruler_scaler for units in (1, 3, 4))
st = st.set_node_data(
    "height",
    mapping=dict(zip((4, 5, 6), (wab, wabc, wabcd))),
    default=0,
)

In [5]:
# Effective population size (Ne) for each of the 7 species-tree nodes, keyed by
# node idx 0..6. Every node is listed, so default=0 is never applied.
st = st.set_node_data(
    "Ne",
    mapping=dict(enumerate([70000, 70000, 50000, 70000, 100000, 120000, 140000])),
    default=0,
)

### Genealogy:

In [6]:
# Genealogy topology with tips a_1-a_3, b_1-b_2, c_1, d_1, tip edges passed
# through edges_extend_tips_to_align, drawn with tree style "s".
gt = toytree.tree("(((a_1,a_2),(a_3,(b_1,b_2))),(c_1,d_1));").mod.edges_extend_tips_to_align()
gt.draw(ts="s");

In [7]:
# Genealogy coalescence times for internal nodes 7..12, given in ruler units
# and scaled by ruler_scaler; nodes outside the mapping (the tips) get height 0.
gt = gt.set_node_data(
    "height",
    mapping={
        node_idx: units * ruler_scaler
        for node_idx, units in zip(range(7, 13), (0.5, 0.7, 2.1, 3.4, 4.3, 4.8))
    },
    default=0,
)
gt.draw();

## Map each genealogy tip to its species-tree population.

In [8]:
# Species-tree population name -> genealogy tip names sampled from it.
imap = {
    "a": ["a_1", "a_2", "a_3"],
    "b": ["b_1", "b_2"],
    "c": ["c_1"],
    "d": ["d_1"],
}

## Plot sample probability curves for a single branch.

In [9]:
# Preview the curves along the branch above genealogy node idx 2 (presumably
# tip a_3). Times run from 1 generation above the node to 1 generation below
# its parent (node.height + node.dist).
idx = 2
node = gt[idx]
times = np.linspace(node.height + 1, node.height + node.dist - 1, 200)
vals = [topo_unch_prob_bt(node, t, st, gt, imap) for t in times]
trvals = [tree_unch_prob_bt(node, t, st, gt, imap) for t in times]

canvas = toyplot.Canvas(400, 400)
# Rows, top to bottom: tree_unch_prob_bt, 1 - tree_unch_prob_bt,
# 1 - topo_unch_prob_bt.
for row, yvalues in enumerate([trvals, 1 - np.array(trvals), 1 - np.array(vals)]):
    axes = canvas.cartesian(grid=(3, 1, row), ymax=1, ymin=0, margin=25, padding=10)
    axes.x.ticks.show = True
    axes.x.domain.show = False
    axes.y.ticks.show = True
    axes.y.domain.show = False
    mark = axes.plot(times, yvalues, stroke_width=5, color='#d7191c')


## Replicate the plot for three branches (one per column) to build the final figure.

In [10]:
# Final 3x3 figure: one column per genealogy branch, one row per curve.
# Columns: node idx 2 (presumably tip a_3), idx 5 (presumably tip c_1) and
# internal node idx 7. Rows, top to bottom: tree_unch_prob_bt,
# 1 - tree_unch_prob_bt, 1 - topo_unch_prob_bt. Only the bottom row gets
# x-tick labels.


def _branch_times(node, npoints=200):
    """Return `npoints` evenly spaced times on the branch above `node`.

    The branch runs from `node.height` to `node.height + node.dist`. Samples
    start 1 generation above the lower end and stop 1 generation below the
    upper end, presumably to avoid evaluating exactly at the node heights.
    """
    return np.linspace(node.height + 1, node.height + node.dist - 1, npoints)


def _branch_curves(idx, species_tree, genealogy, imap, npoints=200):
    """Evaluate the unchanged-probabilities along the branch above `genealogy[idx]`.

    Returns
    -------
    times : ndarray
        Sample times from `_branch_times`.
    tree_unch : ndarray
        `tree_unch_prob_bt(node, t, species_tree, genealogy, imap)` at each time.
    topo_unch : ndarray
        `topo_unch_prob_bt(node, t, species_tree, genealogy, imap)` at each time.
    """
    node = genealogy[idx]
    times = _branch_times(node, npoints)
    topo_unch = np.array(
        [topo_unch_prob_bt(node, t, species_tree, genealogy, imap) for t in times])
    tree_unch = np.array(
        [tree_unch_prob_bt(node, t, species_tree, genealogy, imap) for t in times])
    return times, tree_unch, topo_unch


def _plot_panel(canvas, cell, times, yvalues, locations, labels):
    """Draw one red curve in grid cell `cell` (row-major) of a 3x3 layout.

    The y-axis is fixed to [0, 1]. X ticks are placed explicitly at
    `locations` with `labels`; empty strings leave a tick unlabelled.
    Returns the created axes.
    """
    axes = canvas.cartesian(grid=(3, 3, cell), ymax=1, ymin=0, margin=30, padding=10)
    axes.x.ticks.show = True
    axes.x.ticks.locator = toyplot.locator.Explicit(locations=locations, labels=labels)
    axes.x.domain.show = False
    axes.y.ticks.show = True
    axes.y.domain.show = False
    axes.plot(times, yvalues, stroke_width=5, color='#d7191c')
    return axes


# Per column: (genealogy node idx, x-tick positions in ruler units,
# bottom-row tick labels). Positions are species-tree divergence times (W)
# and genealogy node heights (t) set in earlier cells. The 0.5-unit tick is
# the height of genealogy node 7, so it is labelled t_7 in both column 1
# and column 3.
FIGURE_COLUMNS = [
    (2, [0, 0.5, 1, 2.1],
     ['0', 't<sub>7</sub>', 'W<sub>ab</sub>', 't<sub>9</sub>']),
    (5, [0, 3, 3.4, 4, 4.3],
     ['0', 'W<sub>abc</sub>', 't<sub>10</sub>', 'W<sub>abcd</sub>', 't<sub>11</sub>']),
    (7, [0.5, 1, 2.1, 3, 3.4],
     ['t<sub>7</sub>', 'W<sub>ab</sub>', 't<sub>9</sub>', 'W<sub>abc</sub>', 't<sub>10</sub>']),
]

canvas = toyplot.Canvas(600, 400)
for col, (idx, tick_units, bottom_labels) in enumerate(FIGURE_COLUMNS):
    times, tree_unch, topo_unch = _branch_curves(idx, st, gt, imap)
    locations = [units * ruler_scaler for units in tick_units]
    blank_labels = [''] * len(locations)
    panels = [
        (tree_unch, blank_labels),        # row 0: tree_unch_prob_bt
        (1 - tree_unch, blank_labels),    # row 1: 1 - tree_unch_prob_bt
        (1 - topo_unch, bottom_labels),   # row 2: 1 - topo_unch_prob_bt (labelled)
    ]
    for row, (yvalues, labels) in enumerate(panels):
        _plot_panel(canvas, 3 * row + col, times, yvalues, locations, labels)
