In [1]:
import numpy as np
import networkx as nx
from ipysigma import Sigma
import matplotlib.pyplot as plt

np.random.seed(0)

In [2]:
class Mutation:
    count = 0

    @staticmethod
    def new(n):
        start = Mutation.count
        Mutation.count += n
        stop = Mutation.count
        return range(start, stop)

In [3]:
class Node:
    node_count = 0
    def __init__(self, t):
        self.id = Node.node_count
        Node.node_count += 1
        
        self.t = t
        self.parent = None
        self.children = []
        self.mutated = False
        self.mutations = range(0)

    def print(self, indent=0):
        print(f"{'  ' * indent}{self.id} ({self.t:.2f}) {self.mutations}")
        for child in self.children:
            child.print(indent + 1)

    def adopt(self, child):
        self.children.append(child)
        child.parent = self

    def to_networkx(self):
        G = nx.DiGraph()
        mut_str = self.mutations if len(self.mutations) > 0 else ''
        G.add_node(self.id, label=f"{self.id} ({self.t:.2f}) {mut_str}")
        for child in self.children:
            G.add_edge(self.id, child.id)
            G = nx.compose(G, child.to_networkx())
        return G
    
    def mutate(self, rate):
        assert self.mutated is False
        assert self.parent is not None
        n = np.random.poisson(rate * (self.t - self.parent.t))
        self.mutations = Mutation.new(n)
        for child in self.children:
            child.mutate(rate)
        return n
    
    def get_mutation_set(self):
        # return set of my mutations union mutations of my parent
        if self.parent is None:
            return set(self.mutations)
        return set(self.mutations).union(self.parent.get_mutation_set())
        

In [4]:
def get_lineage(node):
    path = [node]
    while path[-1].parent is not None:
        path.append(path[-1].parent)
    return reversed(path)

def get_common_ancestor(nodes):
    lineages = [list(get_lineage(node)) for node in nodes]
    common_ancestor = None
    for ancestors in zip(*lineages):
        if len(set(ancestors)) == 1:
            common_ancestor = ancestors[0]
        else:
            break
    return common_ancestor

In [5]:
def coalesce(pop_size, sample_size, initial_t, final_t):
    t = final_t

    leaves = [Node(t) for _ in range(sample_size)]
    orphans = leaves.copy()

    while len(orphans) > 1:
        tmean = 2 * pop_size / (sample_size * (sample_size - 1))
        t -= np.random.exponential(tmean)

        if t < initial_t:
            break

        # Pick two orphans to coalesce, remove then from the list of orphas
        node1 = np.random.choice(orphans)
        orphans.remove(node1)

        node2 = np.random.choice(orphans)
        orphans.remove(node2)

        # Create a new node that is the parent of the two orphans
        parent = Node(t)
        parent.adopt(node1)
        parent.adopt(node2)

        # Add the parent to the list of orphans
        orphans.append(parent)

    return orphans, leaves

In [6]:
class Individual:
    def __init__(self, t, origin, n_pop, n_sample, mut_rate):
        self.t = t
        self.origin = origin
        self.n_pop = n_pop
        self.n_sample = n_sample
        self.mut_rate = mut_rate

        self.sample = [Node(t) for _ in range(n_sample)]
        for node in self.sample:
            origin.adopt(node)

    def evolve(self, dt):
        # Coalesce the sample and get the orphans
        orphans, leaves = coalesce(self.n_pop, self.n_sample, self.t, self.t + dt)

        # Pick the parents of the orphans
        parents = np.random.choice(self.sample, len(orphans), replace=False)

        # Adopt the orphans
        for parent, orphan in zip(parents, orphans):
            parent.adopt(orphan)

        for orphan in orphans:
            orphan.mutate(self.mut_rate)

        self.sample = leaves
        self.t += dt

    def contaminate(self):
        # return a new individual originating from a random node in the sample
        origin = np.random.choice(self.sample)
        return Individual(self.t, origin, self.n_pop, self.n_sample, self.mut_rate)
    
    def get_mutation_sets(self):
        return [node.get_mutation_set() for node in self.sample]
    
    def get_non_fixated_mutation_distribution(self):
        mutation_sets = self.get_mutation_sets()
        all_mutations = set.union(*mutation_sets)
        distribution = {m: 0 for m in all_mutations}
        for mutation_set in mutation_sets:
            for m in mutation_set:
                distribution[m] += 1
        
        # ignore fixated mutations
        distribution = {k: v for k, v in distribution.items() if v < self.n_sample}
        
        # sort by value
        return {k: v for k, v in sorted(distribution.items(), key=lambda item: item[1], reverse=True)}



In [7]:
n_pop = 1000
n_sample = 10
mut_rate = 0.1
T = 100

t = 0
origin = Node(t)
a = Individual(t, origin, n_pop, n_sample, mut_rate)
a.evolve(T)
b = a.contaminate()
a.evolve(T)
b.evolve(T)
c = b.contaminate()
a.evolve(T)
b.evolve(T)
c.evolve(T)


G = origin.to_networkx()
# set colors for the nodes in the sample
for node in a.sample:
    G.nodes[node.id]["color"] = "red"
for node in b.sample:
    G.nodes[node.id]["color"] = "blue"
for node in c.sample:
    G.nodes[node.id]["color"] = "green"
Sigma(G, hide_info_panel=True)

Sigma(nx.DiGraph with 116 nodes and 115 edges)

In [13]:
get_common_ancestor(c.sample).print()

49 (200.00) range(230, 231)
  63 (200.00) range(0, 0)
    105 (300.00) range(394, 406)
  64 (200.00) range(0, 0)
    115 (255.37) range(463, 468)
      109 (300.00) range(468, 473)
      114 (266.02) range(473, 474)
        110 (300.00) range(474, 478)
        113 (300.00) range(478, 479)
  65 (200.00) range(0, 0)
  66 (200.00) range(0, 0)
  67 (200.00) range(0, 0)
    104 (300.00) range(385, 394)
  68 (200.00) range(0, 0)
    107 (300.00) range(422, 431)
  69 (200.00) range(0, 0)
    108 (300.00) range(431, 446)
  70 (200.00) range(0, 0)
    111 (300.00) range(446, 453)
  71 (200.00) range(0, 0)
    106 (300.00) range(406, 422)
  72 (200.00) range(0, 0)
    112 (300.00) range(453, 463)
