# Phylo experiments
- We'll simulate some trees, then use PHYLIP mix to infer the tree.
- We can compare an inferred tree to the true tree from the simulation using the Robinson-Foulds distance.
- *A specific question will be how alignment parameters affect tree accuracy.*
This follows up on alignment experiments that suggested non-default Needleman-Wunsch parameters might decrease rates of erroneous event splitting and fusing.

## import necessary classes

In [None]:
import numpy as np
from barcode import Barcode
from cell_state import CellTypeTree, CellState
from cell_state_simulator import CellTypeSimulator
from clt_simulator import CLTSimulator
from barcode_simulator import BarcodeSimulator
from alignment import AlignerNW
from clt_observer import ObservedAlignedSeq, CLTObserver
from clt_estimator import CLTParsimonyEstimator
from collapsed_tree import CollapsedTree
from alignment import AlignerNW
from IPython.display import display
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, Counter
from itertools import product
from random import shuffle

## define simulation parameters


In [None]:
# ---- simulation horizon ----
# time to simulate tree for
time = 5

# ---- barcode editing model (one entry per target, 10 targets) ----
# poisson rates for DSBs on targets
target_lambdas = np.array([1] * 10)
# poisson rates for NHEJ on targets
repair_lambdas = np.array([150] * 10)
# probability of imperfect repair
indel_probability = 0.1
# left and right average deletion
left_deletion_mu = 5
right_deletion_mu = 5
# average insertion length and dispersion
insertion_mu = 1
insertion_alpha = 10  # large dispersion for insertions

# ---- cell branching process ----
birth_lambda = 2
death_lambda = 1

# ---- observation model ----
# observed base mismatch rate from e.g. sequencing error
error_rate = 0.01
# fraction of simulated leaves that we sample
sampling_rate = 0.1

# ---- Needleman-Wunsch aligners being compared ----
# constructed with no arguments, so it uses AlignerNW's default NW params
aligner_default = AlignerNW()
# params found to be optimal in the earlier alignment experiments
aligner_new = AlignerNW(
    mismatch=-3.68,
    gap_open=-7.36,
    gap_extend=-0.46,
)

## define a trivial cell-type tree with 1 type (we're not modeling cell types in this analysis)

In [None]:
# Single-node cell-type tree (cell_type=None, rate=1): cell types are not
# modeled in this analysis, but the simulators below still take a tree object.
cell_type_tree = CellTypeTree(cell_type=None, rate=1)

## instantiate barcode and tree simulators, leaf observer, and parsimony estimator

In [None]:
# Barcode editing simulator, configured from the parameters defined above.
bcode_simulator = BarcodeSimulator(
    target_lambdas=np.array(target_lambdas),
    repair_rates=np.array(repair_lambdas),
    indel_probability=indel_probability,
    left_del_mu=left_deletion_mu,
    right_del_mu=right_deletion_mu,
    insertion_mu=insertion_mu,
    insertion_alpha=insertion_alpha,
)
cell_state_simulator = CellTypeSimulator(cell_type_tree)

# The cell lineage tree (CLT) simulator ties together the branching
# parameters, the cell state simulator, and the barcode simulator.
clt_simulator = CLTSimulator(
    birth_lambda,
    death_lambda,
    cell_state_simulator,
    bcode_simulator,
)

# Observers return the leaves of the tree with some error; the two observers
# are identical except for the aligner they are given.
observer_settings = dict(sampling_rate=sampling_rate, error_rate=error_rate)
observer_default = CLTObserver(aligner=aligner_default, **observer_settings)
observer_new = CLTObserver(aligner=aligner_new, **observer_settings)

# PHYLIP Mix tree estimator
clt_estimator = CLTParsimonyEstimator()

## simulate a cell lineage tree (clt)

In [None]:
# Rejection-sample a lineage tree: keep simulating until the number of leaves
# falls in [n_leaves, 2*n_leaves], and give up (raise) after max_trials
# consecutive failures instead of silently keeping an out-of-range tree.
n_leaves = 500
max_trials = 1000
for trial in range(1, max_trials + 1):
    simulated_clt = clt_simulator.simulate(Barcode(),
                                           CellState(categorical=cell_type_tree),
                                           time)
    # '\r' keeps the progress report on a single, continuously rewritten line
    print('try {}, {} leaves'.format(trial, len(simulated_clt)), end='\r')
    if n_leaves <= len(simulated_clt) <= 2*n_leaves:
        break
else:
    # for/else: reached only when no trial hit the `break` above
    raise RuntimeError(
        'no simulated tree with between {} and {} leaves after {} tries'.format(
            n_leaves, 2*n_leaves, max_trials))

## Sample the simulated tree
- we'll create sample data sets using both default and new alignment params
- branch lengths are proportional to time in this tree

In [None]:
# observe_leaves returns the unique sampled sequences together with the tree
# with unobserved lineages pruned. Both observers receive the same seed
# (presumably so they draw the same sample of leaves -- verify in CLTObserver).
seed = np.random.randint(100)
obs_leaves_default, pruned_clt_default = observer_default.observe_leaves(
    simulated_clt, seed=seed)
obs_leaves_new, pruned_clt_new = observer_new.observe_leaves(
    simulated_clt, seed=seed)

# Summarize and plot the abundance distribution for each alignment setting.
for obs_leaves in (obs_leaves_default, obs_leaves_new):
    abundances = [leaf.abundance for leaf in obs_leaves]
    top_abundance = max(abundances)

    print('cells sampled: {}'.format(sum(abundances)))
    print('unique barcode sequences: {}'.format(len(obs_leaves)))
    print('maximum abundance: {}'.format(top_abundance))
    print('abundance distribution:')

    # bin edges at half-integers so each integer abundance gets its own bar
    half_integer_edges = np.arange(.5, top_abundance + 1.5)
    sns.distplot(abundances,
                 color='grey',
                 kde=False,
                 bins=half_integer_edges,
                 hist_kws={'edgecolor': 'k', 'lw': 2})
    plt.xlabel('cell abundance')
    plt.ylabel('unique barcodes')
    sns.despine()
    plt.show()
    # display(pruned_clt.savefig("%%inline"))

## Collapsed trees (deduplicate repeated sister taxa)
- branch lengths in collapsed tree correspond to event set difference rather than time
- genotype abundance indicated by number and bars on right
- if we end up with homoplasy (repeated genotypes that aren't sisters) these will still be repeated in the tree
- again we do it twice: once using default alignment params, and once with new params

In [None]:
def _collapse_by_events(pruned_clt):
    """
    Return a collapsed copy of a pruned lineage tree.

    Every non-root node of the copy gets branch length 0 when its set of
    barcode events equals its parent's set and 1 otherwise; the relabelled
    copy is then passed to CollapsedTree.collapse with
    deduplicate_sisters=True. The input tree is not modified (only its copy is).
    """
    relabelled = pruned_clt.copy()
    for node in relabelled.iter_descendants():
        unchanged = set(node.barcode_events.events) == set(node.up.barcode_events.events)
        node.dist = 0 if unchanged else 1
    return CollapsedTree.collapse(relabelled, deduplicate_sisters=True)


# same procedure for both alignment settings
collapsed_clt_default = _collapse_by_events(pruned_clt_default)
collapsed_clt_new = _collapse_by_events(pruned_clt_new)

# plot the editing profile as in Aaron et al.
for collapsed_clt in (collapsed_clt_default, collapsed_clt_new):
    collapsed_clt.editing_profile() # add file name argument (e.g. 'profile.pdf') if you want it saved
    plt.show()
    collapsed_clt.ladderize()
    # show the tree with alignment
    display(collapsed_clt.savefig("%%inline"))

# Now we estimate the tree and compare to the truth

## Use PHYLIP mix estimator to get set of maximally parsimonious trees
- this could take a while for larger trees
- we collapse zero-length branches in the binary trees to generate unique multifurcating trees

In [None]:
# Infer parsimony trees from the leaves observed with the default aligner and
# show the first tree in the returned collection.
parsimony_clts_default = clt_estimator.estimate(obs_leaves_default, encode_hidden=False)
print('inferred tree, default alignment parameters')
first_tree_default = parsimony_clts_default[0]
display(first_tree_default.savefig("%%inline"))

# Same again for the leaves observed with the new alignment parameters.
parsimony_clts_new = clt_estimator.estimate(obs_leaves_new, encode_hidden=False)
print('inferred tree, new alignment parameters')
first_tree_new = parsimony_clts_new[0]
display(first_tree_new.savefig("%%inline"))

## define a custom RF distance function

In [None]:
def my_rf(tree1s, tree2, truncate=None):
    """
    Custom Robinson-Foulds tree distance that aggregates over resolutions of
    repeated genotypes.

    A "resolution" keeps exactly one leaf of tree2 for each distinct
    barcode_events value and deletes the other leaves carrying that value.

    tree1s:   list of trees with no repeats
    tree2:    tree that may have repeats (leaves sharing barcode_events);
              it is copied for every resolution and never modified
    truncate: if not None, a non-negative int; only the first `truncate`
              resolutions (in itertools.product order) are evaluated
    returns:  array of shape (n_resolutions, len(tree1s)) holding the first
              element of tree1.robinson_foulds(...) for each resolution (rows)
              and each tree in tree1s (columns); row order is shuffled
    """
    # local import: the file-level imports only bring in `product`
    from itertools import islice

    # group leaf names of tree2 by genotype
    repeats = defaultdict(list)
    for leaf in tree2:
        repeats[leaf.barcode_events].append(leaf.name)

    # The number of resolutions is the product of the group sizes, so cut the
    # lazy product *before* materializing it rather than slicing a full list.
    # NOTE(review): truncation happens before the shuffle, so a truncated run
    # always evaluates the same leading resolutions, not a random subset --
    # confirm that this is the intended sampling scheme.
    resolutions = product(*repeats.values())
    if truncate is not None:
        resolutions = islice(resolutions, truncate)
    choices = list(resolutions)
    shuffle(choices)

    RFs = np.zeros((len(choices), len(tree1s)))
    for i, choice in enumerate(choices):
        kept_names = set(choice)
        tree2_copy = tree2.copy()
        # snapshot the leaves first so that deleting nodes cannot disturb the
        # traversal that is producing them
        for leaf in list(tree2_copy):
            if leaf.name not in kept_names:
                leaf.delete()
        for j, tree1 in enumerate(tree1s):
            rf = tree1.robinson_foulds(tree2_copy,
                                       unrooted_trees=True,
                                       attr_t1='barcode_events',
                                       attr_t2='barcode_events')
            RFs[i, j] = rf[0]
    return RFs

## distance between true tree and each of the parsimony trees

In [None]:
# truncate=None: my_rf evaluates every resolution of the repeated genotypes.
truncate = None

# For each alignment setting, report the largest RF value over all resolutions
# (max down the rows) for the first parsimony tree (column 0).
rf_rows_default = my_rf([parsimony_clts_default[0]], collapsed_clt_default, truncate)
print('RF using default alignment parameters:', np.max(rf_rows_default, axis=0)[0])

rf_rows_new = my_rf([parsimony_clts_new[0]], collapsed_clt_new, truncate)
print('    RF using new alignment parameters:', np.max(rf_rows_new, axis=0)[0])

# Now a bigger experiment repeating this many times and aggregating results

In [None]:
    # Repeat the simulate -> sample -> collapse -> infer -> score pipeline
    # n_experiments times, recording one RF distance per alignment setting.
    agg_data = []
    n_experiments = 100
    # loop-invariant settings
    n_leaves = 500      # accepted trees have between n_leaves and 2*n_leaves leaves
    max_trials = 1000   # retries allowed before an experiment is aborted
    truncate = 20       # cap on repeat resolutions evaluated by my_rf
    for exp_i in range(n_experiments):
        # simulate a cell lineage tree (clt)
        # keep simulating until the leaf count is in range, give up (raise) if
        # we fail max_trials times in a row
        for trial in range(1, max_trials + 1):
            simulated_clt = clt_simulator.simulate(Barcode(),
                                                   CellState(categorical=cell_type_tree),
                                                   time)
            print('simulation {}, try {}, {} leaves   '.format(exp_i + 1, trial, len(simulated_clt)),
                   end='\r', flush=True)
            if n_leaves <= len(simulated_clt) <= 2*n_leaves:
                break
        else:
            # for/else: reached only when no trial hit the `break` above
            raise RuntimeError(
                'simulation {}: no tree with between {} and {} leaves after {} tries'.format(
                    exp_i + 1, n_leaves, 2*n_leaves, max_trials))
        print()

        # Sample the simulated tree
        # branch lengths are proportional to time in this tree
        # this returns the unique sampled sequences, and the tree with unobserved lineages pruned
        # resample until both alignment settings yield at least 3 observed sequences
        for trial in range(1, max_trials + 1):
            seed = np.random.randint(100)
            obs_leaves_default, pruned_clt_default = observer_default.observe_leaves(simulated_clt, seed=seed)
            obs_leaves_new, pruned_clt_new = observer_new.observe_leaves(simulated_clt, seed=seed)
            if len(obs_leaves_default) >= 3 and len(obs_leaves_new) >= 3:
                break
        else:
            raise RuntimeError(
                'simulation {}: fewer than 3 observed sequences in each of {} sampling tries'.format(
                    exp_i + 1, max_trials))

        # Collapsed tree (deduplicate repeated sister taxa)
        # - branch lengths in collapsed tree correspond to event set difference rather than time
        # - genotype abundance indicated by number and bars on right
        # - if we end up with homoplasy (repeated genotypes that aren't sisters) these will still be repeated in the tree

        collapsed_clt_default = pruned_clt_default.copy()
        for node in collapsed_clt_default.iter_descendants():
            node.dist = 0 if set(node.barcode_events.events) == set(node.up.barcode_events.events) else 1
        collapsed_clt_default = CollapsedTree.collapse(collapsed_clt_default, deduplicate_sisters=True)

        collapsed_clt_new = pruned_clt_new.copy()
        for node in collapsed_clt_new.iter_descendants():
            node.dist = 0 if set(node.barcode_events.events) == set(node.up.barcode_events.events) else 1
        collapsed_clt_new = CollapsedTree.collapse(collapsed_clt_new, deduplicate_sisters=True)

        # Now we estimate the tree and compare to the truth
        #
        # Use PHYLIP mix estimator to get set of maximally parsimonious trees
        # - this could take a while
        # - we collapse zero-length branches in the binary trees to generate unique multifurcating trees

        parsimony_clts_default = clt_estimator.estimate(obs_leaves_default, encode_hidden=False)
        parsimony_clts_new = clt_estimator.estimate(obs_leaves_new, encode_hidden=False)

        # just use first
        parsimony_clts_default = [parsimony_clts_default[0]]
        parsimony_clts_new = [parsimony_clts_new[0]]

        # Distance between true tree and the (first) parsimony tree
        # we use the maximum over the repeat resolutions evaluated by my_rf
        # (max down the rows, then column 0 for the single tree passed in)

        rf_default = np.max(my_rf(parsimony_clts_default, collapsed_clt_default, truncate), 0)[0]
        rf_new     = np.max(my_rf(parsimony_clts_new,     collapsed_clt_new,     truncate), 0)[0]

        agg_data.append([exp_i, rf_default, rf_new])

    # Tabulate one row per simulation run, then compare the two RF columns in
    # a joint scatter/histogram plot with shared axis limits.
    col_default = 'RF distance to true tree\ndefault alignment parameters'
    col_new = 'RF distance to true tree\nnew alignment parameters'
    df = pd.DataFrame(agg_data, columns=('simulation run', col_default, col_new))

    # common upper axis limit: one past the largest distance in either column
    lim = max(df[col_default].max(), df[col_new].max()) + 1
    axis_range = [0, lim]

    # NOTE(review): `stat_func` and `size` look like keyword names from older
    # seaborn releases -- confirm against the seaborn version in use.
    g = sns.jointplot(
        x=col_new,
        y=col_default,
        data=df,
        color='black',
        stat_func=None,
        joint_kws={'alpha': .3, 'clip_on': False},
        marginal_kws={'bins': np.arange(.5, lim + 1.5),
                      'hist_kws': {'edgecolor': 'k'}},
        xlim=axis_range,
        ylim=axis_range,
        size=4,
        space=0,
    )
    # dashed y = x reference line, drawn behind the points
    g.ax_joint.plot(axis_range, axis_range, c='grey', ls='--', zorder=0)
    plt.tight_layout()
    plt.savefig('phylo_experiments.pdf')
    plt.show()


# Questions for Aaron
- How are you collapsing zero-length branches in trees from mix, given that there are "?" states in the ancestors and some edges are "maybe" for having steps?
- Are you randomly resolving the unknown states to determine definite zero-length branches?

# Next things to try
- add an aligner with affine gap penalty, like the more current GESTALT
- try a new aligner with a gap penalty function that is cut-site aware