# A Wright-Fisher simulation implemented in C via Cython.

OMG!

We would use GSL via CythonGSL, but that would require a GPL license for this notebook, but we're doing CCBY.

In [2]:
%load_ext Cython

In [3]:
import msprime
import numpy as np

In [4]:
%%cython

import msprime
import numpy as np
cimport numpy as np
from cython.view cimport array as cvarray
from libc.stdlib cimport malloc, realloc, free
from libc.stdint cimport int32_t, uint32_t

cdef int32_t * malloc_int32_t(size_t n):
    return <int32_t*>malloc(n*sizeof(int32_t))

cdef int32_t * realloc_int32_t(void * x, size_t n):
    return <int32_t*>realloc(x,n*sizeof(int32_t))

cdef double * malloc_double(size_t n):
    return <double*>malloc(n*sizeof(double))

cdef double * realloc_double(double * x, size_t n):
    return <double*>realloc(<double *>x,n*sizeof(double))

cdef struct Mutations:
    double * pos
    int32_t * time
    size_t next_mutation, capacity
    
cdef int init_Mutations(Mutations * m):
    m.next_mutation = 0
    m.capacity = 10000
    m.pos = malloc_double(m.capacity)
    if m.pos == NULL:
        return -1
    m.time = malloc_int32_t(m.capacity)
    if m.time == NULL:
        return -1
    return 0

cdef int realloc_Mutations(Mutations * m):
    m.capacity *= 2
    m.pos = realloc_double(m.pos,
                          m.capacity)
    if m.pos == NULL:
        return -1
    m.time = realloc_int32_t(m.time,
                            m.capacity)
    if m.time == NULL:
        return -1
    return 0

cdef void free_Mutations(Mutations * m):
    free(m.pos)
    free(m.time)
    m.next_mutation = 0
    m.capacity = 10000
    
cdef int add_mutation(double pos,
                     int32_t generation,
                     Mutations * m):
    cdef int rv = 0
    if m.next_mutation+1 >= m.capacity:
        rv = realloc_Mutations(m)
        if rv != 0:
            return rv
    m.pos[m.next_mutation] = pos
    m.time[m.next_mutation] = generation
    m.next_mutation+=1
    return rv
    
cdef struct Nodes:
    double * time
    size_t next_node, capacity
    
cdef int init_Nodes(Nodes * n):
    n.next_node = 0
    n.capacity = 10000
    n.time = malloc_double(n.capacity)
    if n.time == NULL:
        return -1
    return 0

cdef int realloc_Nodes(Nodes * n):
    n.capacity *= 2
    n.time = realloc_double(n.time,
                            n.capacity)
    if n.time == NULL:
        return -1
    return 0
    
cdef void free_Nodes(Nodes * n):
    if n.time != NULL:
        free(n.time)
    n.next_node = 0
    n.capacity = 10000

cdef int add_node(double t, Nodes *n):
    cdef int rv = 0
    if n.next_node >= n.capacity:
        rv = realloc_Nodes(n)
        if rv != 0:
            return rv
    n.time[n.next_node] = t
    n.next_node+=1
    return rv
    
cdef struct Edges:
    double *left
    double *right
    int32_t *parent
    int32_t *child
    size_t next_edge, capacity
    
cdef int init_Edges(Edges * e):
    e.next_edge = 0
    e.capacity = 10000
    e.left = malloc_double(e.capacity)
    if e.left == NULL:
        return -1
    e.right = malloc_double(e.capacity)
    if e.right == NULL:
        return -1
    e.parent = malloc_int32_t(e.capacity)
    if e.parent == NULL:
        return -1
    e.child = malloc_int32_t(e.capacity)
    if e.child == NULL:
        return -1
    return 0
   
cdef int realloc_Edges(Edges * e):
    e.capacity *= 2
    e.left = realloc_double(e.left,e.capacity)
    if e.left == NULL:
        return -1
    e.right = realloc_double(e.right,e.capacity)
    if e.right == NULL:
        return -1
    e.parent = realloc_int32_t(e.parent,e.capacity)
    if e.parent == NULL:
        return -1
    e.child = realloc_int32_t(e.child,e.capacity)
    if e.child == NULL:
        return -1
    return 0

cdef void free_Edges(Edges * e):
    free(e.left)
    free(e.right)
    free(e.parent)
    free(e.child)
    e.next_edge = 0
    e.capacity = 10000
    
cdef int add_edge(double left, double right,
             int32_t parent, int32_t child,
             Edges * edges):
    cdef int rv=0
    if edges.next_edge+1 >= edges.capacity:
        rv = realloc_Edges(edges)
        if rv != 0:
            return rv
        
    edges.left[edges.next_edge] = left
    edges.right[edges.next_edge] = right
    edges.parent[edges.next_edge] = parent
    edges.child[edges.next_edge] = child
    edges.next_edge += 1
    return rv

cdef void cleanup(Nodes * n, Edges * e, Mutations * m):
    free_Nodes(n)
    free_Edges(e)
    free_Mutations(m)
    
cdef int infsites(double mu, int32_t generation,
                  Mutations * mutations,
                  dict lookup):
    cdef unsigned nmut = np.random.poisson(mu)
    cdef unsigned i = 0
    cdef np.ndarray[double,ndim=1] pos
    cdef int rv = 0
    for i in range(nmut):
        pos = np.random.random_sample(1)
        while pos[0] in lookup:
            pos = np.random.random_sample(1)
        rv = add_mutation(pos[0],
                         generation,
                         mutations)
        if rv != 0:
            return rv
        lookup[pos[0]] = True
    return rv

cdef int poisson_recombination(double r,
                                tuple parent_indexes,
                                int32_t next_offspring_id,
                                Edges * edges):
    cdef unsigned nbreaks = np.random.poisson(r)
    cdef list b = []
    cdef unsigned i = 0
    cdef np.ndarray[double,ndim=1] x
    cdef list pgams
    cdef int rv = 0
    cdef double left,right
    cdef int32_t p
    if nbreaks == 0:
        # The parent passes the entire region onto the child
        rv = add_edge(0.0,1.0,parent_indexes[0],
                      next_offspring_id,edges)
        if rv != 0:
            return rv
    else:
        while i < nbreaks:
            x = np.random.random_sample(1)
            while x[0] in b:
                x = np.random.random_sample(1)
            b.append(x[0])
            i += 1
        b.sort()
        b.append(1.0)

        if b[0] != 0.0:
            b.insert(0,0.0)
        else:
            parent_indexes = (parent_indexes[1], parent_indexes[0])
        
        pgams = list([*tuple(i for i in parent_indexes)]*int(len(b)/2))

        for left,right,p in zip(b[0:len(b)-1],b[1:len(b)],pgams):
            rv = add_edge(left,right,p,
                          next_offspring_id,edges)
            if rv != 0:
                return rv
    return 0

def evolve(int N, int ngens, double theta, double rho, int gc):
    nodes = msprime.NodeTable()
    edges = msprime.EdgeTable()
    mutations = msprime.MutationTable()
    
    cdef double mu = theta/<double>(4*N)
    cdef double r = rho/<double>(4*N)
    
    cdef int rv
    cdef size_t i, generation
    cdef Nodes temp_nodes
    cdef Edges temp_edges
    cdef Mutations temp_mutations
    rv = init_Nodes(&temp_nodes)
    if rv != 0:
        cleanup(&temp_nodes,&temp_edges,&temp_mutations)
        raise RuntimeError("could not initialize temp_nodes")
    rv = init_Edges(&temp_edges)
    if rv != 0:
        cleanup(&temp_nodes,&temp_edges,&temp_mutations)
        raise RuntimeError("could not initialize temp_edges")
    rv = init_Mutations(&temp_mutations)
    if rv != 0:
        cleanup(&temp_nodes,&temp_edges,&temp_mutations)
        raise RuntimeError("could not initialize temp_mutations")
        
    for i in range(2*<size_t>N):
        nodes.add_row(time=0.0,
                      flags=msprime.NODE_IS_SAMPLE)
        
    
    cdef int32_t next_offspring_index, first_parental_index
    next_offspring_index = len(nodes)
    first_parental_index = 0
    PARENT_DTYPE = np.int32
    cdef np.ndarray[int32_t,ndim=1] parents
    cdef double mendel[2]
    cdef size_t parent1, parent2,pindex
    cdef int32_t[:] pview
    cdef int32_t p1g1, p1g2, p2g1, p2g2
    cdef dict lookup = {}
    for generation in range(1,<size_t>(ngens+1)):
        parents = np.random.randint(0, N, 2*N, dtype=PARENT_DTYPE)
        pview = parents
        for pindex in range(0,2*N,2):
            parent1=pview[pindex]
            parent2=pview[pindex+1]
            p1g1 = first_parental_index + 2*parent1
            p1g2 = p1g1 + 1
            p2g1 = first_parental_index + 2*parent2
            p2g2 = p2g1 + 1
            
            mendel = np.random.random_sample(2)
            if mendel[0] < 0.5:
                p1g1, p1g2 = p1g2, p1g1
            if mendel[1] < 0.5:
                p2g1, p2g2 = p2g2, p2g1
                
            rv = poisson_recombination(r,(p1g1,p1g2),
                                      next_offspring_index,
                                      &temp_edges)
            if rv != 0:
                cleanup(&temp_nodes,&temp_edges,
                       &temp_mutations)
                raise RuntimeError("error during recombination")
                
            rv = infsites(mu,generation,
                         &temp_mutations,lookup)
            
            if rv != 0:
                cleanup(&temp_nodes,&temp_edges,
                       &temp_mutations)
                raise RuntimeError("error during mutation")
                
            rv = add_node(<double>generation, &temp_nodes)
            if rv != 0:
                cleanup(&temp_nodes,&temp_edges,
                       &temp_mutations)
                raise RuntimeError("error during adding nodes")
                
            next_offspring_index += 1
                
            rv = poisson_recombination(r,(p2g1,p2g2),
                                      next_offspring_index,
                                      &temp_edges)
            if rv != 0:
                cleanup(&temp_nodes,&temp_edges,
                       &temp_mutations)
                raise RuntimeError("error during recombination")
                
            rv = infsites(mu,generation,
                         &temp_mutations,lookup)
            
            if rv != 0:
                cleanup(&temp_nodes,&temp_edges,
                       &temp_mutations)
                raise RuntimeError("error during mutation")
                
            rv = add_node(<double>generation, &temp_nodes)
            if rv != 0:
                cleanup(&temp_nodes,&temp_edges,
                       &temp_mutations)
                raise RuntimeError("error during adding nodes")
                
            next_offspring_index += 1
        first_parental_index += 2*N
    
    # Push first nodes times further back
    nodes.set_columns(time=nodes.time + ngens +1,
                     flags=nodes.flags)
            
    # Add our data to the tables
    cdef double[:] timeview = <double[:temp_nodes.next_node]>temp_nodes.time
    time=np.asarray(timeview,np.float)
    time-=time.max()
    time*=-1.0
    nodes.append_columns(time=time,
                         flags=np.ones(temp_nodes.next_node,np.uint32))
    edges.append_columns(left=np.asarray(<double[:temp_edges.next_edge]>temp_edges.left),
                        right=np.asarray(<double[:temp_edges.next_edge]>temp_edges.right),
                        parent=np.asarray(<int32_t[:temp_edges.next_edge]>temp_edges.parent),
                        child=np.asarray(<int32_t[:temp_edges.next_edge]>temp_edges.child))
    
    print(nodes.time.min(),nodes.time.max())

    msprime.sort_tables(nodes=nodes,edges=edges)
    
    samples = np.where(nodes.time==0)[0]
    
    print(samples)
    
    msprime.simplify_tables(samples=samples.tolist(),
                           nodes=nodes,
                           edges=edges)
                
    cleanup(&temp_nodes,&temp_edges,&temp_mutations)
    
    return msprime.load_tables(nodes=nodes,edges=edges)
    
    
def test_infsites():
    cdef Mutations m
    init_Mutations(&m)
    cdef dict lookup = {}
    cdef int rv = infsites(100000,1,&m,lookup)
    print(m.next_mutation,m.capacity)
    free_Mutations(&m)
    print(m.next_mutation,m.capacity)
    print("done!")
    
def test_add_edges():
    cdef Edges e
    cdef int rv = init_Edges(&e)
    print(e.next_edge,e.capacity)
    print(e.left == NULL)
    print(e.right == NULL)
    print(e.parent == NULL)
    print(e.child == NULL)
    print(rv)
    for i in range(20000):
        rv=add_edge(0,1,0,1,&e)
        #print(e.next_edge,e.capacity)
        # print(i,rv)
        if rv != 0:
            raise RuntimeError("error adding edges")
    print(e.next_edge,e.capacity)

In [5]:
%%time
evolve(1000, 10000, 100.0, 100.0, 10)

-0.0 10001.0
[20000000 20000001 20000002 ..., 20001997 20001998 20001999]
CPU times: user 2min 25s, sys: 4.7 s, total: 2min 29s
Wall time: 3min 18s


<msprime.trees.TreeSequence at 0x7f3160a57fd0>

In [6]:
test_infsites()

99859 160000
0 10000
done!


In [6]:
test_add_edges()

0 10000
False
False
False
False
0
20000 40000


## Literal Alg W with Cython

In [16]:
import random

def wright_fisher1(N, T, simplify_interval=1):    
    L = 1
    edges = msprime.EdgeTable()
    nodes = msprime.NodeTable()
    P = [j for j in range(N)]
    for j in range(N):
        nodes.add_row(time=T, flags=1)
    t = T    
    while t > 0:
        t -= 1
        Pp = [P[j] for j in range(N)]
        for j in range(N):
            n = len(nodes)
            nodes.add_row(time=t, flags=1)
            Pp[j] = n
            a = random.randint(0, N - 1)
            b = random.randint(0, N - 1)
            x = random.uniform(0, L)
            edges.add_row(0, x, P[a], n)
            edges.add_row(x, L, P[b], n)
        P = Pp
        if t % simplify_interval == 0:
            msprime.sort_tables(nodes=nodes, edges=edges)
            msprime.simplify_tables(Pp, nodes, edges)
            P = list(range(N))       
    # We will always simplify at t = 0, so no need for special case at the end
    return msprime.load_tables(nodes=nodes, edges=edges)


In [32]:
%%time 
ts1 = wright_fisher1(1000, 1000, 1000)

CPU times: user 11.8 s, sys: 48 ms, total: 11.8 s
Wall time: 11.8 s


In [30]:
%%cython
import msprime
import random

def wright_fisher2(N, T, simplify_interval=1):    
    L = 1
    edges = msprime.EdgeTable()
    nodes = msprime.NodeTable()
    P = [j for j in range(N)]
    for j in range(N):
        nodes.add_row(time=T, flags=1)
    t = T    
    while t > 0:
        t -= 1
        Pp = [P[j] for j in range(N)]
        for j in range(N):
            n = len(nodes)
            nodes.add_row(time=t, flags=1)
            Pp[j] = n
            a = random.randint(0, N - 1)
            b = random.randint(0, N - 1)
            x = random.uniform(0, L)
            edges.add_row(0, x, P[a], n)
            edges.add_row(x, L, P[b], n)
        P = Pp
        if t % simplify_interval == 0:
            msprime.sort_tables(nodes=nodes, edges=edges)
            msprime.simplify_tables(Pp, nodes, edges)
            P = list(range(N))       
    # We will always simplify at t = 0, so no need for special case at the end
    return msprime.load_tables(nodes=nodes, edges=edges)


In [33]:
%%time 
ts2 = wright_fisher2(1000, 1000, 1000)

CPU times: user 11.4 s, sys: 48 ms, total: 11.4 s
Wall time: 11.4 s


So, we're not any faster with a straight Cythonize. What can we do better?

In [36]:
%%prun -l 10 -s cumulative
ts2 = wright_fisher2(100, 1000, 1000)

 

   ```
    1956014 function calls in 1.618 seconds

   Ordered by: cumulative time
   List reduced from 24 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    1.618    1.618 {built-in method builtins.exec}
        1    0.000    0.000    1.618    1.618 <string>:2(<module>)
        1    0.441    0.441    1.618    1.618 {_cython_magic_90e4263cb4787d6ae8fae8befdafe5a8.wright_fisher2}
   200000    0.105    0.000    0.533    0.000 random.py:214(randint)
   200000    0.190    0.000    0.429    0.000 random.py:170(randrange)
   200000    0.170    0.000    0.285    0.000 tables.py:349(add_row)
   200000    0.171    0.000    0.238    0.000 random.py:220(_randbelow)
   100100    0.086    0.000    0.156    0.000 tables.py:174(add_row)
   200000    0.115    0.000    0.115    0.000 {function EdgeTable.add_row at 0x7f3135bf8f28}
   100000    0.077    0.000    0.087    0.000 random.py:342(uniform)
   ```

In [43]:
import msprime
import numpy as np

def wright_fisher3(N, T, simplify_interval=1):    
    L = 1
    edges = msprime.EdgeTable()
    nodes = msprime.NodeTable()
    P = np.arange(N, dtype=np.int32)    
    for j in range(N):
        nodes.add_row(time=T, flags=1)
    t = T    
    while t > 0:
        t -= 1
        parent_a = P[np.random.randint(0, N, N)]
        parent_b = P[np.random.randint(0, N, N)]  
        breakpoint = np.random.uniform(0, L, N)
        for j in range(N):
            child = nodes.add_row(time=t, flags=1)
            P[j] = child                 
            edges.add_row(0, breakpoint[j], parent_a[j], child)
            edges.add_row(breakpoint[j], L, parent_b[j], child)        
        if t % simplify_interval == 0:
            msprime.sort_tables(nodes=nodes, edges=edges)
            msprime.simplify_tables(P, nodes, edges)
            P[:] = np.arange(N, dtype=np.int32)       
    # We will always simplify at t = 0, so no need for special case at the end
    return msprime.load_tables(nodes=nodes, edges=edges)


In [44]:
%%time 
ts3 = wright_fisher3(1000, 1000, 1000)

CPU times: user 5.78 s, sys: 48 ms, total: 5.82 s
Wall time: 5.84 s


In [51]:
%%cython

import msprime
import numpy as np

def wright_fisher4(N, T, simplify_interval=1):    
    L = 1
    edges = msprime.EdgeTable()
    nodes = msprime.NodeTable()
    P = np.arange(N, dtype=np.int32)    
    for j in range(N):
        nodes.add_row(time=T, flags=1)
    t = T    
    while t > 0:
        t -= 1
        parent_a = P[np.random.randint(0, N, N)]
        parent_b = P[np.random.randint(0, N, N)]  
        breakpoint = np.random.uniform(0, L, N)
        for j in range(N):
            child = nodes.add_row(time=t, flags=1)
            P[j] = child                 
            edges.add_row(0, breakpoint[j], parent_a[j], child)
            edges.add_row(breakpoint[j], L, parent_b[j], child)        
        if t % simplify_interval == 0:
            msprime.sort_tables(nodes=nodes, edges=edges)
            msprime.simplify_tables(P, nodes, edges)
            P[:] = np.arange(N, dtype=np.int32)       
    # We will always simplify at t = 0, so no need for special case at the end
    return msprime.load_tables(nodes=nodes, edges=edges)


In [52]:
%%time 
ts4 = wright_fisher4(1000, 1000, 1000)

CPU times: user 6.58 s, sys: 48 ms, total: 6.62 s
Wall time: 6.63 s


In [54]:
%%prun -l 10 -s cumulative
ts4 = wright_fisher4(1000, 1000, 1000)

 

```
         6002012 function calls in 8.019 seconds

   Ordered by: cumulative time
   List reduced from 16 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    8.019    8.019 {built-in method builtins.exec}
        1    0.000    0.000    8.018    8.018 <string>:2(<module>)
        1    2.824    2.824    8.018    8.018 {_cython_magic_e74d95b05791c2c9d265cd84a8026655.wright_fisher4}
  2000000    1.518    0.000    2.795    0.000 tables.py:349(add_row)
  1001000    0.729    0.000    1.316    0.000 tables.py:174(add_row)
  2000000    1.277    0.000    1.277    0.000 {function EdgeTable.add_row at 0x7f3135bf8f28}
        1    0.000    0.000    0.757    0.757 tables.py:1172(sort_tables)
        1    0.757    0.757    0.757    0.757 {built-in method _msprime.sort_tables}
  1001000    0.587    0.000    0.587    0.000 {function NodeTable.add_row at 0x7f3135bf8840}
        1    0.000    0.000    0.304    0.304 tables.py:1232(simplify_tables)
```

Even with Cython, a lot of the running time is down to add_row, which is probably due to the Python overhead of these functions. So, let's try one where we replace the 'generation loop' with numpy:

In [1]:
import msprime
import numpy as np

def wright_fisher5(N, T, simplify_interval=1):    
    L = 1
    edges = msprime.EdgeTable()
    nodes = msprime.NodeTable()
    num_sorted_edges = 0
    P = np.arange(N, dtype=np.int32)    
    flags = np.zeros(N, dtype=np.uint32) + msprime.NODE_IS_SAMPLE
    nodes.set_columns(time=np.zeros(N) + T, flags=flags)
    t = T    
    left = np.zeros(N)
    right = np.zeros(N) + L    
    while t > 0:
        t -= 1
        parent_a = P[np.random.randint(0, N, N)]
        parent_b = P[np.random.randint(0, N, N)]  
        breakpoint = np.random.uniform(0, L, N)        
        child = np.arange(N, dtype=np.int32) + len(nodes)
        edges.append_columns(left, breakpoint, parent_a, child)
        edges.append_columns(breakpoint, right, parent_b, child)
        nodes.append_columns(time=np.zeros(N) + t, flags=flags)
        P = child       
        if t % simplify_interval == 0:            
            msprime.sort_tables(nodes=nodes, edges=edges)
            msprime.simplify_tables(P, nodes, edges)
            num_sorted_edges = len(edges)
            P[:] = np.arange(N, dtype=np.int32)       
    # We will always simplify at t = 0, so no need for special case at the end
    return msprime.load_tables(nodes=nodes, edges=edges)


In [2]:
%%time 
ts5 = wright_fisher5(1000, 1000, 1000)

CPU times: user 1.15 s, sys: 112 ms, total: 1.26 s
Wall time: 1.26 s


In [3]:
%%prun -l 10 -s cumulative
wright_fisher5(1000, 1000, 1000)

 

```
   26020 function calls in 6.034 seconds

   Ordered by: cumulative time
   List reduced from 24 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    6.034    6.034 {built-in method builtins.exec}
        1    0.011    0.011    6.034    6.034 <string>:2(<module>)
        1    0.087    0.087    6.023    6.023 <ipython-input-83-a048ba58c6db>:4(wright_fisher5)
        1    0.000    0.000    3.305    3.305 tables.py:1172(sort_tables)
        1    3.305    3.305    3.305    3.305 {built-in method _msprime.sort_tables}
        1    0.000    0.000    1.168    1.168 tables.py:1232(simplify_tables)
        1    1.168    1.168    1.168    1.168 {built-in method _msprime.simplify_tables}
     4000    0.009    0.000    1.077    0.000 tables.py:382(append_columns)
     4000    1.068    0.000    1.068    0.000 {function EdgeTable.append_columns at 0x7f3135bfa0d0}
     2000    0.007    0.000    0.128    0.000 tables.py:223(append_columns)
```
