# Alignment experiments
## summary:
- simulate a forest of trees with the same simulation parameters
- observe the barcodes from the leaves of the trees
- use Needleman-Wunsch alignment to infer the indel events on these leaves
- assess how many barcodes in each simulation have their events correctly inferred

## SIMULATIONS
### import necessary classes

In [None]:
import numpy as np

from barcode import Barcode
from cell_state import CellTypeTree, CellState
from cell_state_simulator import CellTypeSimulator
from clt_simulator import CLTSimulator
from barcode_simulator import BarcodeSimulator
from alignment import AlignerNW
from clt_observer import CLTObserver
from IPython.display import display
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns

### define simulation parameters


In [None]:
# --- forest and observation settings ---
n_trees = 20        # number of trees simulated with identical parameters
time = 3            # handed to CLTSimulator.simulate for every tree
error_rate = .01    # handed to CLTObserver

# --- barcode editing parameters (forwarded to BarcodeSimulator) ---
target_lambdas = np.full(10, 0.1)   # ten entries, all 0.1
repair_lambdas = np.full(10, 1)     # ten entries, all 1 (integer dtype)
indel_probability = .5
left_deletion_mu = 3
right_deletion_mu = 3
insertion_mu = 3
insertion_alpha = 10  # large dispersion for insertions

# --- rates handed to CLTSimulator ---
birth_lambda = 1.25
death_lambda = 0.01

### define a cell-type tree to parameterize how cell types can transition

In [None]:
# The root node has no cell type of its own (cell_type=None); its two
# children are cell types 0 and 1, each created with its own rate.
cell_type_tree = CellTypeTree(cell_type=None, rate=0.9)
for child_type, child_rate in ((0, 0.5), (1, 0.2)):
    cell_type_tree.add_child(
        CellTypeTree(cell_type=child_type, rate=child_rate))
print(cell_type_tree)

### instantiate barcode and tree simulators, and leaf observer

In [None]:
# Barcode simulator: every keyword below is filled from the parameter cell.
# The two rate arrays are passed through np.array, so the simulator receives
# its own copies.
bcode_simulator = BarcodeSimulator(
    target_lambdas=np.array(target_lambdas),
    repair_rates=np.array(repair_lambdas),
    indel_probability=indel_probability,
    left_del_mu=left_deletion_mu,
    right_del_mu=right_deletion_mu,
    insertion_mu=insertion_mu,
    insertion_alpha=insertion_alpha,
)

# Cell-state simulator built from the cell-type tree defined above.
cell_state_simulator = CellTypeSimulator(cell_type_tree)

# Tree simulator combining the birth/death rates with the two simulators.
clt_simulator = CLTSimulator(
    birth_lambda,
    death_lambda,
    cell_state_simulator,
    bcode_simulator,
)

# Leaf observer configured with sampling_rate=1 and the chosen error rate.
observer = CLTObserver(sampling_rate=1, error_rate=error_rate)

### simulate a forest of observed leaves and trees pruned to the observed lineages

In [None]:
# Simulate n_trees trees, each from a fresh Barcode and the same cell state,
# and observe each one. observe_leaves returns a pair: the first item is the
# collection of observed leaves, the second is unpacked as the pruned tree.
simulated_forest = [
    observer.observe_leaves(
        clt_simulator.simulate(Barcode(),
                               CellState(categorical=cell_type_tree),
                               time))
    for _ in range(n_trees)
]
# Rank the simulations by number of observed leaves, smallest first
# (list.sort is stable, so ties keep their simulation order).
simulated_forest.sort(key=lambda observed: len(observed[0]))
forest_obs_leaves, forest_pruned_clt = zip(*simulated_forest)

# display every pruned tree, numbered from 1 in ranked order
for tree_idx, pruned_clt in enumerate(forest_pruned_clt, 1):
    print('tree {}'.format(tree_idx))
    display(pruned_clt.savefig("%%inline"))

## ALIGNMENT-BASED EVENT INFERENCE
### define a function that takes a list of alignment parameter dictionaries and makes two plots: the first shows the ranked number of simulated barcodes (dashed line) and the number of incorrectly annotated barcodes for each parameter set (colored bars); the second shows, for each parameter set, the fraction of inferred events that do not overlap any cut site

In [None]:
def left_align_indel(barcode, event):
    """Placeholder stub: always raises ``NotImplementedError``.

    Neither ``barcode`` nor ``event`` is used yet. Judging by the name, this
    is meant to left-align an indel event within a barcode -- TODO confirm
    the intended contract when implementing.
    """
    raise NotImplementedError

In [None]:
def _count_orphan_events(events_per_barcode, cut_sites):
    """Count events that cover no cut site.

    An event ``evt`` covers ``cut_site`` when
    ``evt[0] <= cut_site <= evt[1] - 1``.
    """
    return sum(
        not any(evt[0] <= cut_site <= evt[1] - 1 for cut_site in cut_sites)
        for evts in events_per_barcode
        for evt in evts
    )


def _collect_alignment_stats(params_list):
    """Build one row of event-calling statistics per (parameter set, simulation).

    Reads the notebook-level ``forest_obs_leaves``. Simulations are numbered
    from 1 in the order they appear there.
    """
    # Take the cut sites from the first observed leaf available. The forest is
    # sorted by leaf count, so the leading trees may have no leaves at all.
    # NOTE(review): assumes every barcode shares the same cut sites -- confirm.
    cut_sites = next(
        (leaves[0].barcode.abs_cut_sites
         for leaves in forest_obs_leaves if len(leaves) > 0),
        ())

    rows = []
    for params in params_list:
        # One aligner per parameter set rather than one per leaf.
        # NOTE(review): assumes AlignerNW keeps no per-call state -- confirm.
        aligner = AlignerNW(**params)
        label = str(params)
        for sim, obs_leaves in enumerate(forest_obs_leaves, 1):
            # Filled leaf by leaf so a simulation without observed leaves
            # simply yields two empty lists instead of failing to unpack.
            events_true = []
            events_NW = []
            for leaf in obs_leaves:
                events_true.append(leaf.barcode.get_events())
                events_NW.append(leaf.barcode.get_events(aligner=aligner))

            # A barcode is miscalled when its aligner-derived events differ
            # from the events reported without an aligner, ignoring order.
            miscalled_barcodes = sum(
                sorted(x) != sorted(y) for x, y in zip(events_true, events_NW))
            events_NW_total = sum(len(evts) for evts in events_NW)
            events_NW_orphans = _count_orphan_events(events_NW, cut_sites)
            rows.append([sim,
                         miscalled_barcodes,
                         len(obs_leaves),
                         label,
                         events_NW_total,
                         events_NW_orphans])

    df = pd.DataFrame(data=rows,
                      columns=('simulation',
                               'incorrectly annotated barcodes',
                               'simulated barcodes',
                               'NW parameters',
                               'total events',
                               'split events'))
    # 0/0 (no events called at all) becomes NaN here rather than raising.
    df['fraction split'] = df['split events'] / df['total events']
    return df


def _plot_alignment_stats(df):
    """Draw the two summary figures for a table from ``_collect_alignment_stats``."""
    # Figure 1: barcodes per simulation (gray dashed points) with the number
    # of incorrectly annotated barcodes per parameter set as bars.
    plt.figure(figsize=(10, 5))
    sns.pointplot(y='simulated barcodes', x='simulation', data=df,
                  clip_on=False, color='gray', linestyles='--')
    sns.barplot(y='incorrectly annotated barcodes', x='simulation',
                hue='NW parameters', data=df)
    plt.tight_layout()
    plt.show()

    # Figure 2: fraction of aligner-called events that cover no cut site.
    plt.figure(figsize=(10, 5))
    sns.barplot(y='fraction split', x='simulation', hue='NW parameters', data=df)
    plt.tight_layout()
    plt.show()


def alignment_experiment(params_list):
    """Compare event calls made with different ``AlignerNW`` parameter sets.

    For every parameter dictionary in ``params_list`` and every simulation in
    the notebook-level ``forest_obs_leaves``, the events obtained from
    ``barcode.get_events()`` are compared with those obtained from
    ``barcode.get_events(aligner=AlignerNW(**params))``. Two figures are shown:
    the number of incorrectly annotated barcodes per simulation, and the
    fraction of aligner-called events that cover no cut site.

    Args:
        params_list: iterable of dicts, each passed as keyword arguments to
            ``AlignerNW``; ``str(params)`` is used as the plot legend label.

    Returns:
        None. The figures are displayed as a side effect.
    """
    _plot_alignment_stats(_collect_alignment_stats(params_list))

### experiment 1: different gap open penalties

In [None]:
# Sweep the gap-open penalty; mismatch and gap-extension penalties stay fixed.
params_list = [
    {'mismatch': -1, 'gap_open': open_penalty, 'gap_extend': -.5}
    for open_penalty in (-15, -10, -5, -.5)
]
alignment_experiment(params_list)

### experiment 2: different gap extension penalties

In [None]:
# Sweep the gap-extension penalty; mismatch and gap-open penalties stay fixed.
params_list = [
    {'mismatch': -1, 'gap_open': -10, 'gap_extend': extend_penalty}
    for extend_penalty in (-10, -5, -1, -.5, -.1, -.01)
]
alignment_experiment(params_list)

### experiment 3: different mismatch penalties

In [None]:
# Sweep the mismatch penalty; gap-open and gap-extension penalties stay fixed.
params_list = [
    {'mismatch': mismatch_penalty, 'gap_open': -10, 'gap_extend': -0.5}
    for mismatch_penalty in (-3, -2, -1, -.5)
]
alignment_experiment(params_list)

## to do:
- be soft on left alignment of events, maybe add an option to `get_events`
- get flanking sequence from Aaron