# Alignment experiments

## import necessary classes

In [None]:
import numpy as np
from barcode import Barcode
from cell_state import CellTypeTree, CellState
from cell_state_simulator import CellTypeSimulator
from clt_simulator import CLTSimulator
from barcode_simulator import BarcodeSimulator
from alignment import AlignerNW
from clt_observer import CLTObserver
from IPython.display import display
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from ternary.helpers import simplex_iterator
from matplotlib.colors import Normalize

## define simulation parameters


In [None]:
# --- lineage tree ---
# duration of the tree simulation
time = 3.5
# birth and death rates of the cell branching process
birth_lambda = 1.5
death_lambda = 1e-6

# --- barcode editing (10 targets) ---
# per-target poisson rates for DSBs
target_lambdas = np.full(10, .5)
# per-target poisson rates for NHEJ
repair_lambdas = np.full(10, 100)
# probability that a repair is imperfect
indel_probability = .1
# average deletion length to the left and to the right
left_deletion_mu = 5
right_deletion_mu = 5
# insertion length: average and dispersion (large dispersion for insertions)
insertion_mu = 1
insertion_alpha = 10

# --- observation ---
# base mismatch rate from e.g. sequencing error
error_rate = .01

## define a simple cell-type tree with 3 types to parameterize how cell types (tree node colors) can transition
### here the root cell type can transition to one of two descendant cell types

In [None]:
# root node has cell type None; its two children have cell types 0 and 1
cell_type_tree = CellTypeTree(cell_type=None, rate=0.9)
for child_type, child_rate in ((0, 0.5), (1, 0.2)):
    cell_type_tree.add_child(CellTypeTree(cell_type=child_type, rate=child_rate))
print(cell_type_tree)

## instantiate barcode and tree simulators, and leaf observer

In [None]:
# barcode simulator, configured from the simulation parameters defined above
barcode_simulator_kwargs = dict(
    target_lambdas=np.array(target_lambdas),
    repair_rates=np.array(repair_lambdas),
    indel_probability=indel_probability,
    left_del_mu=left_deletion_mu,
    right_del_mu=right_deletion_mu,
    insertion_mu=insertion_mu,
    insertion_alpha=insertion_alpha,
)
bcode_simulator = BarcodeSimulator(**barcode_simulator_kwargs)

# cell state simulator driven by the cell-type tree
cell_state_simulator = CellTypeSimulator(cell_type_tree)

# cell lineage tree (CLT) simulator: branching parameters + cell state simulator + barcode simulator
clt_simulator = CLTSimulator(birth_lambda, death_lambda, cell_state_simulator, bcode_simulator)

# observer that returns the leaves of a tree, with some observation error
observer = CLTObserver(sampling_rate=1, error_rate=error_rate)

## simulate a cell lineage tree

In [None]:
# keep simulating (for a bounded number of attempts) until we get a tree
# with at least `min_leaves` observed leaves
min_leaves = 100
max_trials = 100
for trial in range(1, max_trials + 1):
    simulated_clt = clt_simulator.simulate(Barcode(),
                                           CellState(categorical=cell_type_tree),
                                           time)
    obs_leaves, pruned_clt = observer.observe_leaves(simulated_clt)
    print('try {}, {} leaves'.format(trial, len(obs_leaves)), end='\r')
    if len(obs_leaves) >= min_leaves:
        break
else:
    # every attempt fell short: carry on with the last tree (best effort),
    # but say so instead of silently using an undersized tree
    print('\nwarning: no tree with >= {} leaves after {} tries; '
          'using the last one ({} leaves)'.format(min_leaves, max_trials, len(obs_leaves)))
# plot the editing profile as in Aaron et al.
# pruned_clt.editing_profile("foo.profile.pdf") # uncomment to save image
pruned_clt.editing_profile()
plt.show()
# show the tree with alignment
display(pruned_clt.savefig("%%inline"))
# pruned_clt.savefig("foo.pdf") # uncomment to save image
# save the barcode objects on the leaves for later
barcodes_tree = [leaf.barcode for leaf in obs_leaves]

## simulate a list of independent barcodes, each with 1 edit
### and plot the joint and marginal distributions of event start/end positions

In [None]:
# rejection-sample simulated barcodes until 500 of them carry exactly one
# event, recording each event's (start, end) positions as we go
barcodes_1event = []
start_end = []
while len(barcodes_1event) < 500:
    candidate = bcode_simulator.simulate(Barcode(), 1)
    events = candidate.get_events()
    if len(events) != 1:
        continue
    barcodes_1event.append(candidate)
    start, end, _ = events[0]
    start_end.append([start, end])

# joint and marginal distributions of event start/end positions
# NOTE(review): `stat_func` is not accepted by newer seaborn releases — confirm the installed version
df = pd.DataFrame(start_end, columns=['event start position', 'event end position'])
sns.jointplot(x='event start position', y='event end position', data=df,
              stat_func=None, joint_kws=dict(alpha=.2), marginal_kws=dict(bins=30))

## similarly, simulate a list of independent barcodes, each with 2 edits

In [None]:
# rejection-sample simulated barcodes until 500 of them carry exactly two events
barcodes_2events = []
while len(barcodes_2events) < 500:
    candidate = bcode_simulator.simulate(Barcode(), 2)
    if len(candidate.get_events()) != 2:
        continue
    barcodes_2events.append(candidate)

## ASSESSING ALIGNMENT-BASED EVENT INFERENCE
- a function that takes a list of barcodes and assesses event identification performance over a simplex of alignment penalty parameters
- shows a 3D scatter plot of the event miscount rate over the parameter simplex, and a bar plot of the ranked parameter sets, with the default NW parameters shown as a red bar
- note that we only look at a half simplex: only parameter sets whose gap open penalty is strictly greater than the gap extension penalty are evaluated

In [None]:
def simplex_map(barcodes, scale=10):
    """Assess alignment-based event counting over a simplex of NW penalties.

    For every (mismatch, gap open, gap extend) penalty triple yielded by
    ``simplex_iterator(scale, boundary=False)`` with gap open > gap extend,
    plus the default triple (1, 10, .5) rescaled to sum to ``scale``, each
    barcode is re-aligned with ``AlignerNW`` and the number of called events
    is compared with the number of true events.

    Two figures are drawn: a 3D scatter of the miscount rate over the
    penalty triples (default marked with a black '+'), and a bar plot of
    the triples ranked by miscount rate (default in red). The best triple,
    rescaled to the same sum as the unscaled defaults, is printed.

    Parameters
    ----------
    barcodes : sequence
        Objects exposing ``get_events()`` (true events) and
        ``get_events(aligner=...)`` (aligner-called events).
    scale : int
        Resolution of the simplex passed to ``simplex_iterator``.

    Raises
    ------
    ValueError
        If ``barcodes`` is empty (the miscount rate would be undefined).
    """
    n_barcodes = len(barcodes)
    if n_barcodes == 0:
        raise ValueError('simplex_map requires at least one barcode')

    # default NW penalties (mismatch, gap open, gap extend), rescaled so they sum to `scale`
    default0 = np.array([1, 10, .5])
    default = tuple(scale*default0/default0.sum())

    # number of true events in each barcode
    counts_true = [len(barcode.get_events()) for barcode in barcodes]

    def miscounted(n_true, n_called):
        # one true event: only over-calling (a split event) is an error
        if n_true == 1:
            return n_called > n_true
        # two true events: only under-calling (fused events) is an error
        if n_true == 2:
            return n_called < n_true
        # otherwise any disagreement is an error
        return n_called != n_true

    # parameter sets to evaluate: the default, then the half simplex with gap open > gap extend
    ijks = [default]
    ijks.extend((i, j, k) for i, j, k in simplex_iterator(scale, boundary=False) if j > k)

    count_error_rate = dict()
    for ct, (i, j, k) in enumerate(ijks, 1):
        params = dict(return_all=True, mismatch=-i, gap_open=-j, gap_extend=-k)
        events_NW = [barcode.get_events(aligner=AlignerNW(**params)) for barcode in barcodes]
        # with return_all=True the result is indexed; the first entry is the one counted
        # (presumably the first of several equally good alignments — TODO confirm)
        count_NW = [len(events[0]) for events in events_NW]
        count_error_rate[(i, j, k)] = sum(miscounted(x, y) for x, y in zip(counts_true, count_NW))/n_barcodes
        print('{:.2%} complete\r'.format(ct/len(ijks)), end='')

    # --- 3D scatter of miscount rate over the penalty triples ---
    fig = plt.figure(figsize=(7,6))
    ax = fig.add_subplot(111, projection='3d')

    grid = [ijk for ijk in count_error_rate if ijk != default]
    mappable = None
    if grid:
        xs, ys, zs = zip(*grid)
        # a single scatter call shares one color normalization across all points;
        # depthshade is disabled so point colors match the colorbar exactly
        mappable = ax.scatter(xs, ys, zs,
                              c=[count_error_rate[ijk] for ijk in grid],
                              norm=Normalize(vmin=0, vmax=max(count_error_rate.values())),
                              marker='o', cmap='cool', depthshade=False)
    # default parameters are marked separately and do not take part in the color map
    ax.scatter(*default, c='k', marker='+')

    ax.set_xlim([0, scale])
    ax.set_ylim([0, scale])
    ax.set_zlim([0, scale])
    ax.set_xlabel('mismatch penalty')
    ax.set_ylabel('gap open penalty')
    ax.set_zlabel('gap extend penalty')
    ax.view_init(elev=45., azim=45.)
    if mappable is not None:
        fig.colorbar(mappable)
    plt.tight_layout()
    plt.show()

    # --- bar plot of parameter sets ranked by miscount rate (best first) ---
    ranked = sorted(count_error_rate, key=count_error_rate.get)
    df = pd.DataFrame([[params, count_error_rate[params], params == default] for params in ranked],
                      columns=('parameters', 'miscount rate', 'default'))
    # bar positions are passed positionally: the `left` keyword is not accepted by current matplotlib
    plt.bar(range(len(df)), df['miscount rate'], color=['red' if x else 'gray' for x in df.default])
    plt.xlabel('rank')
    plt.ylabel('miscount rate')

    best = ranked[0]
    best_rate = count_error_rate[best]
    # rescale the best triple so it sums to the same total as the unscaled defaults
    best = default0.sum() * np.array(best) / sum(best)
    print('best alignment: {}, count error rate: {}'.format(best, best_rate))
    plt.show()

## using the single-event barcodes, assess how often events are split into more than one

In [None]:
# single-event barcodes, penalty simplex at scale 25
simplex_map(barcodes_1event, scale=25)

## using the double-event barcodes, assess how often the two events are fused into one

In [None]:
# double-event barcodes, penalty simplex at scale 25
simplex_map(barcodes_2events, scale=25)

## using the barcodes from the leaves of the tree, assess how often the number of events is incorrect

In [None]:
# barcodes observed on the leaves of the simulated tree, penalty simplex at scale 25
simplex_map(barcodes_tree, scale=25)

## to do:
- include flanking sequence from Aaron
- try cut-site aware gap penalties