In [1]:
import numpy as np
import networkx as nx
from ipysigma import Sigma
# Seed NumPy's global RNG so the np.random.rand / np.random.choice draws made
# later in the notebook are reproducible on a fresh Restart & Run All.
np.random.seed(0)

In [2]:
class Node:
    """One individual in the genealogy, linked to its parent and its children.

    Each new instance takes the next sequential integer id and is appended to
    the class-wide ``all_nodes`` registry, so ``Node.all_nodes[i].id == i`` for
    as long as the two class attributes are only modified by ``__init__``.
    """

    # Class-wide state shared by every instance.
    node_count = 0  # number of nodes created so far; also the next id handed out
    all_nodes = []  # every node ever created, in creation order

    def __init__(self, mut_prob=0.0):
        """Create a parentless, childless node and register it.

        Args:
            mut_prob: threshold compared against one ``np.random.rand()`` draw;
                the node is flagged ``mutated`` when the draw falls below it.
        """
        registry = Node.all_nodes
        self.id = Node.node_count
        Node.node_count = self.id + 1
        registry.append(self)

        # Genealogy links start out empty and are filled in by adopt().
        self.parent, self.children = None, []

        draw = np.random.rand()
        self.mutated = draw < mut_prob

    def adopt(self, child):
        """Attach ``child`` directly beneath this node.

        Raises:
            AssertionError: if ``child`` is already one of this node's children
                or already has a parent.
        """
        assert child not in self.children
        assert child.parent is None
        child.parent = self
        self.children.append(child)

class WrightFisher:
    """Population of fixed size that is replaced wholesale every generation.

    The simulation starts from a single ``origin`` individual. Each step creates
    ``n_pop`` new ``Node`` objects and gives every one of them a parent drawn
    uniformly, with replacement, from the current population.
    """

    def __init__(self, origin, n_pop, mut_prob=0.0):
        """Store the parameters and start with ``origin`` as the only member.

        Args:
            origin: founding individual; the initial population is ``[origin]``.
            n_pop: number of individuals created in every later generation.
            mut_prob: mutation probability passed to each newly created ``Node``.
        """
        self.origin = origin
        self.n_pop = n_pop
        self.mut_prob = mut_prob
        self.generation = 0
        self.population = [origin]

    def next_generation(self):
        """Replace the current population with ``n_pop`` freshly created children."""
        # Offspring are created before the parents are drawn; keeping this order
        # keeps the sequence of random draws unchanged.
        offspring = []
        for _ in range(self.n_pop):
            offspring.append(Node(mut_prob=self.mut_prob))

        chosen_parents = np.random.choice(self.population, size=self.n_pop, replace=True)
        for slot, parent in enumerate(chosen_parents):
            parent.adopt(offspring[slot])

        self.population = offspring
        self.generation += 1

    def evolve(self, n_gen):
        """Advance the simulation by ``n_gen`` generations."""
        step = self.next_generation
        for _ in range(n_gen):
            step()

In [3]:
def to_networkx(nodes, root):
    """Build the genealogy that connects ``nodes`` to ``root`` as a directed graph.

    For every node the parent chain is followed upwards and each parent -> child
    link is recorded as an edge between the two ``id`` values.

    Args:
        nodes: iterable of objects exposing ``id`` and ``parent``; each must be
            ``root`` itself or a descendant of ``root``.
        root: the object at which every upward walk stops.

    Returns:
        nx.DiGraph whose nodes are ids and whose edges point from parent to
        child. The root id is always present, even when ``nodes`` is empty or
        contains only ``root``.

    Raises:
        ValueError: if a walk runs out of parents before reaching ``root``,
            i.e. the node does not belong to the subtree rooted at ``root``.
    """
    G = nx.DiGraph()
    for node in nodes:
        u = node
        while u is not root:
            v = u.parent
            if v is None:
                raise ValueError(
                    f"node {node.id} is not in the subtree rooted at node {root.id}"
                )
            if G.has_edge(v.id, u.id):
                # This link was recorded by an earlier walk, which also recorded
                # everything above it, so the rest of the path is already there.
                break
            G.add_edge(v.id, u.id)
            u = v
    # No-op when the root was reached through an edge; otherwise it guarantees
    # that callers can index the root in the result.
    G.add_node(root.id)
    return G

def simplify(graph, root):
    """Collapse every unbranched chain below ``root`` into a single weighted edge.

    Starting from each child of a kept node, the chain of nodes that have
    exactly one child is followed until a node with zero or several children is
    reached. That end node is kept and linked to the kept node above it.

    Args:
        graph: directed graph whose node labels index into ``Node.all_nodes``.
        root: label of the node at which the simplified tree is rooted.

    Returns:
        nx.DiGraph containing ``root`` and every chain end. Each edge carries
        ``weight``, the number of original edges it replaces. Each chain end
        carries ``n_mutations``, the number of mutated nodes on its collapsed
        chain (the end node included) as a string, and ``color`` set to
        ``'#ff0000'`` when that count is positive. ``root`` is present even
        when it has no children.
    """
    # A single result graph is filled in place; composing per-branch graphs
    # would copy everything accumulated so far once per branch.
    ans = nx.DiGraph()
    ans.add_node(root)

    # Explicit depth-first stack of (kept node, iterator over its children).
    # It visits branches in the same pre-order as a recursive walk would, but
    # is not bounded by the interpreter's recursion limit.
    pending = [(root, iter(graph.successors(root)))]
    while pending:
        anchor, children = pending[-1]
        try:
            v = next(children)
        except StopIteration:
            pending.pop()
            continue

        # find first node u that has number of children != 1
        u = v
        weight = 1
        mut_count = 1 if Node.all_nodes[u].mutated else 0
        while graph.out_degree(u) == 1:
            u = next(iter(graph.successors(u)))
            mut_count += 1 if Node.all_nodes[u].mutated else 0
            weight += 1

        # add edge from the kept ancestor to u
        ans.add_edge(anchor, u, weight=weight)

        # add mutation count to u
        ans.nodes[u]['n_mutations'] = str(mut_count)
        if mut_count > 0:
            ans.nodes[u]['color'] = '#ff0000'

        # descend into the subtree below u before returning to anchor's other children
        pending.append((u, iter(graph.successors(u))))
    return ans

In [4]:
# Simulation parameters: one size `n` sets the number of generations, the
# population size and the sample size.
n = 100
n_gen = n
n_pop = n
n_sample = n_pop  # as large as the population, so the sample covers the whole final generation
mut_prob = 0.5  # mutation probability handed to every Node the simulation creates

# The founder uses Node's default mut_prob of 0.0, so it is never flagged as mutated.
origin = Node()
wf = WrightFisher(origin, n_pop, mut_prob)
wf.evolve(n_gen)

# Draw n_sample distinct individuals from the final generation (no replacement).
sample = np.random.choice(wf.population, size=n_sample, replace=False)

In [5]:
# Genealogy linking the sampled individuals back to the origin.
G = to_networkx(sample, origin)

# `sample_ids` stays a list; the set gives constant-time membership tests
# inside the loop over all graph nodes.
sample_ids = [node.id for node in sample]
sample_id_set = set(sample_ids)

# Color nodes in one pass: sampled individuals are green (this takes
# precedence over mutation status), other mutated individuals are red.
for node in G.nodes:
    if node in sample_id_set:
        G.nodes[node]['color'] = "green"
    elif Node.all_nodes[node].mutated:
        G.nodes[node]['color'] = "red"

# make origin node black
G.nodes[origin.id]['color'] = "black"

Sigma(G)

Sigma(nx.DiGraph with 844 nodes and 843 edges)

In [6]:
# Collapse every unbranched lineage of the sampled genealogy into one weighted edge.
simplified_G = simplify(G, origin.id)

# color origin node blue
simplified_G.nodes[origin.id]["color"] = "#0000ff"

# Label edges with the `weight` attribute and nodes with the `n_mutations`
# attribute, both of which are written by simplify().
Sigma(simplified_G, hide_info_panel=True, edge_label="weight", node_label="n_mutations")

Sigma(nx.DiGraph with 182 nodes and 181 edges)