In [37]:
# packages for both resampling and plotting
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import re

# packages for only resampling
from itertools import combinations_with_replacement
import random
from collections import Counter


# +
def _sorted_doublets(tree):
    """Sorts the doublets into alphabetical order (important for assigning doublet index).
    
    Args:
        tree (string): Tree in NEWICK format.
    
    Returns:
        tree (string): New tree in NEWICK format after sorted doublets alphabetically.
    """
    for i in re.findall("\(\w*,\w*\)", tree):
        i_escape = re.escape(i)
        i_split = re.split('[\(\),]', i)
        ingroup = sorted([i_split[1], i_split[2]])
        tree = re.sub(i_escape, f'({ingroup[0]},{ingroup[1]})', tree)
    return tree

def _align_triplets(tree):
    """Aligns triplets so that all of them are in the order of (outgroup, ingroup).
    
    Find all ((x,x),x) triplets, then replace them with the same triplet but in (x,(x,x)) form.
    
    Args:
        tree (string): Tree in NEWICK format. Tree should have doublets sorted already.
    
    Returns:
        tree (string): New tree in NEWICK format after aligned triplets.
    """
    for i in re.findall("\(\(\w*,\w*\),\w*\)", tree):
        j = re.findall("\w*", i)
        i_escape = re.escape(i)
        tree = re.sub(i_escape, f'({j[7]},({j[2]},{j[4]}))', tree)
    return tree

def _sorted_quartets(tree):
    """Sorts the quartets so that it is in alphabetical order (important for assigning doublet index).
    
    Args:
        tree (string): Tree in NEWICK format. Tree should have doublets sorted already.
    
    Returns:
        tree (string): New tree in NEWICK format after sorted quartets alphabetically.
    """
    for i in re.findall("\(\(\w*,\w*\),\(\w*,\w*\)\)", tree):
        i_escape = re.escape(i)
        k = sorted([i[1:6], i[7:12]])
        subtree = f"({k[0]},{k[1]})"
        tree = re.sub(i_escape, subtree, tree)
    return tree

def _align_asym_quartet(tree):
    """Aligns asymmetric quartet so that all of them are in the order of (outgroup 1, outgroup 2, ingroup).
    
    Find all ((x,(x,x)),x) quartets, then replace them with the same asymmetric quartet but in (x,(x,(x,x))) form.
    
    Args:
        tree (string): Tree in NEWICK format. Tree should have doublets sorted and triplets aligned already.
    
    Returns:
        tree (string): New tree in NEWICK format after aligned asymmetric quartet.
    """
    for i in re.findall("\(\(\w*,\(\w*,\w*\)\),\w*\)", tree):
        j = re.findall("\w*", i)
        i_escape = re.escape(i)
        tree = re.sub(i_escape, f'({j[11]},({j[2]},({j[5]},{j[7]})))', tree)
    return tree

def _align_asym_quintet(tree):
    """Aligns asymmetric quintet so that all of them are in the order of (outgroup 1, outgroup 2, ingroup).
    
    Find all ((x,(x,(x,x))),x) quintets, then replace them with the same asymmetric quintet but in (x,(x,(x,(x,x)))) form.
    
    Args:
        tree (string): Tree in NEWICK format. Tree should have doublets sorted, and triplets and asymmetric quartets aligned already.
    
    Returns:
        tree (string): New tree in NEWICK format after aligned asymmetric quintet.
    """
    for i in re.findall("\(\(\w*,\(\w*,\(\w*,\w*\)\)\),\w*\)", tree):
        j = re.findall("\w*", i)
        i_escape = re.escape(i)
        tree = re.sub(i_escape, f'({j[15]},({j[2]},({j[5]},({j[8]},{j[10]}))))', tree)
    return tree

def _align_asym_sextet(tree):
    """Aligns asymmetric sextet so that all of them are in the order of (outgroup 1, outgroup 2, ingroup).
    
    Find all ((x,(x,(x,(x,x)))),x) sextets, then replace them with the same asymmetric sextet but in (x,(x,(x,(x,(x,x))))) form.
    
    Args:
        tree (string): Tree in NEWICK format. Tree should have doublets sorted, and triplets, asymmetric quartets, 
        and asymmetric quintets aligned already.
    
    Returns:
        tree (string): New tree in NEWICK format after aligned asymmetric sextet.
    """
    for i in re.findall("\(\(\w*,\(\w*,\(\w*,\(\w*,\w*\)\)\)\),\w*\)", tree):
        j = re.findall("\w*", i)
        i_escape = re.escape(i)
        tree = re.sub(i_escape, f'({j[19]},({j[2]},({j[5]},({j[8]},({j[11]},{j[13]})))))', tree)
    return tree

# -

def sort_align_tree(tree):
    """Bring a tree into the canonical sorted/aligned form.

    The normalisation steps are applied one after another, each on the output
    of the previous one: triplets, doublets, quartets, asymmetric quartets,
    asymmetric quintets, asymmetric sextets.

    Args:
        tree (string): Tree in NEWICK format.

    Returns:
        tree (string): Tree in NEWICK format.
            Trees are sorted to have all asymmetric sextets in (x,(x,(x,(x,(x,x))))) format, asymmetric quintets in (x,(x,(x,(x,x)))) format,
            asymmetric quartets in (x,(x,(x,x))) format, triplets in (x,(x,x)) format, and all doublets/quartets in alphabetical order.
    """
    normalisation_steps = (
        _align_triplets,
        _sorted_doublets,
        _sorted_quartets,
        _align_asym_quartet,
        _align_asym_quintet,
        _align_asym_sextet,
    )
    for step in normalisation_steps:
        tree = step(tree)
    return tree

def read_dataset(path):
    """Reads dataset txt file located at `path`.

    The whole file is read and split on semi-colons. Surrounding whitespace
    (including line breaks) is stripped from each tree, and empty entries, such as
    the one after a trailing semi-colon, are skipped. An empty file gives an empty list.

    Args:
        path (string): Path to txt file of dataset. txt file should be formatted as NEWICK trees
            separated with semi-colons and no spaces.

    Returns:
        all_trees_sorted (list): List where each entry is a string representing a tree in NEWICK format.
            Trees are sorted to have all asymmetric sextets in (x,(x,(x,(x,(x,x))))) format, asymmetric quintets in (x,(x,(x,(x,x)))) format,
            asymmetric quartets in (x,(x,(x,x))) format, triplets in (x,(x,x)) format, and all doublets/quartets in alphabetical order.
    """
    with open(path) as f:
        contents = f.read()

    # Strip before filtering so that a lone "\n" after the last ';' is not kept as a tree.
    all_trees_unsorted = [chunk.strip() for chunk in contents.split(';')]
    all_trees_sorted = [sort_align_tree(tree) for tree in all_trees_unsorted if tree]
    return all_trees_sorted


# +
def make_cell_dict(cell_fates):
    """Map every cell fate to its position in `cell_fates`.

    Args:
        cell_fates (list): List with each entry as a cell fate.

    Returns:
        cell_dict (dict): Keys are cell types, values are integers
            (the position of the cell type in `cell_fates`).
    """
    return {fate: position for position, fate in enumerate(cell_fates)}

### for doublet analysis

def make_doublet_dict(cell_fates):
    """Makes a dictionary of all possible doublets.

    Every unordered pair of cell fates (a fate paired with itself included) gets
    one entry. The two fates of a pair are sorted alphabetically and formatted as
    ``"(x,y)"``. Pairs are numbered in the order produced by
    ``combinations_with_replacement(cell_fates, 2)``. Any number of cell fates is
    supported.

    Args:
        cell_fates (list): List with each entry as a cell fate.

    Returns:
        doublet_dict (dict): Keys are doublets, values are integers.
    """
    doublet_dict = {}
    # Iterate the fates themselves instead of slicing a ten-character digit string,
    # which silently dropped every fate after the tenth.
    for position, pair in enumerate(combinations_with_replacement(cell_fates, 2)):
        first, second = sorted(pair)
        doublet_dict[f"({first},{second})"] = position
    return doublet_dict

# returns relevant subtrees
def _flatten_doublets(all_trees_sorted):
    """Makes a list of all doublets in set of trees.
    
    Args:
        all_trees_sorted (list): List where each entry is a string representing a tree in NEWICK format. 
            Trees are sorted to have all asymmetric quartets in (x,(x,(x,x))) format, triplets in (x,(x,x)) format, 
            and all doublets/quartets in alphabetical order.
    
    Returns:
        doublets (list): List with each entry as a doublet (string).
    """
    doublets = []
    for i in all_trees_sorted:
        doublets.extend(re.findall("\(\w*,\w*\)", i))
    return doublets


def _flatten_all_cells(all_trees_sorted):
    """Makes a list of all cells in the set of trees.
    
    Args:
        all_trees_sorted (list): List where each entry is a string representing a tree in NEWICK format. 
            Trees are sorted to have all asymmetric quartets in (x,(x,(x,x))) format, triplets in (x,(x,x)) format, 
            and all doublets/quartets in alphabetical order.
    
    Returns:
        all_cells (list): List with each entry as a cell (string).
    """
    
    all_cells = []
    for i in all_trees_sorted:
        for j in re.findall("[A-Za-z0-9]+", i):
            all_cells.extend(j)
    return all_cells


def make_df_doublets(all_trees_sorted, doublet_dict, resample, labels_bool=False):
    """Count every doublet in the given trees and return the counts as a one-column DataFrame.

    Args:
        all_trees_sorted (list): List where each entry is a string representing a tree in NEWICK format.
            Trees are sorted to have all asymmetric quartets in (x,(x,(x,x))) format, triplets in (x,(x,x)) format,
            and all doublets/quartets in alphabetical order.
        doublet_dict (dict): Keys are doublets, values are integers.
        resample (int): Resample number; its string form becomes the column name.
        labels_bool (bool, optional): If True, then index of resulting DataFrame uses `doublet_dict` keys.

    Returns:
        df_doublets (DataFrame): Rows are doublets, column is resample number.
            Only doublets that occur at least once get a row.
    """
    # Translate each doublet to its integer index while counting.
    index_counts = Counter(
        doublet_dict[doublet] for doublet in _flatten_doublets(all_trees_sorted)
    )
    column_name = f"{resample}"
    df_doublets = pd.DataFrame.from_dict(index_counts, orient='index', columns=[column_name])
    if labels_bool == True:
        # Swap the integer index back to the doublet strings.
        index_to_doublet = {index: doublet for doublet, index in doublet_dict.items()}
        df_doublets = df_doublets.rename(index_to_doublet)
    return df_doublets


def make_df_all_cells(all_trees_sorted, cell_dict, resample, labels_bool=False):
    """Count every cell in the given trees and return the counts as a one-column DataFrame.

    Args:
        all_trees_sorted (list): List where each entry is a string representing a tree in NEWICK format.
            Trees are sorted to have all asymmetric quartets in (x,(x,(x,x))) format, triplets in (x,(x,x)) format,
            and all doublets/quartets in alphabetical order.
        cell_dict (dict): Keys are cell types, values are integers.
        resample (int): Resample number; its string form becomes the column name.
        labels_bool (bool, optional): If True, then index of resulting DataFrame uses `cell_dict` keys.

    Returns:
        df_all_cells (DataFrame): Rows are cell types, column is resample number.
            Only cell types that occur at least once get a row.
    """
    # Translate each cell to its integer index while counting.
    index_counts = Counter(
        cell_dict[cell] for cell in _flatten_all_cells(all_trees_sorted)
    )
    column_name = f"{resample}"
    df_all_cells = pd.DataFrame.from_dict(index_counts, orient='index', columns=[column_name])
    if labels_bool == True:
        # Swap the integer index back to the cell-type names.
        index_to_cell = {index: cell for cell, index in cell_dict.items()}
        df_all_cells = df_all_cells.rename(index_to_cell)
    return df_all_cells


# Replace all leaves drawing from repl_list
def _replace_all(tree, repl_list, replacement_bool):
    """Swap every cell in `tree` for a cell drawn from `repl_list`, then re-sort the tree.

    Without replacement, cells are taken off the end of `repl_list` with ``pop()``,
    so the caller's list shrinks by one entry per cell in the tree. With
    replacement, each cell is an independent ``random.choice`` and `repl_list`
    is left untouched.

    Args:
        tree (string): Tree in NEWICK format.
        repl_list (list): List of all cells.
        replacement_bool (bool): Draw with or without replacement from `repl_list`.

    Returns:
        new_tree_sorted (string): tree in NEWICK format, passed through `sort_align_tree`.
    """
    if replacement_bool == False:
        def draw_cell(_match):
            # consume the pool: each entry can be used once only
            return repl_list.pop()
    elif replacement_bool == True:
        def draw_cell(_match):
            return random.choice(repl_list)

    relabelled_tree = re.sub("[A-Za-z0-9]+", draw_cell, tree)
    return sort_align_tree(relabelled_tree)


def _process_dfs_doublet(df_doublet_true, dfs_doublet_new, num_resamples, doublet_dict, cell_dict, df_all_cells_true):
    """Arranges observed counts for each doublet in all resamples and original trees into a combined DataFrame.
    
    Last column is analytically solved expected number of each doublet.
        
    Args:
        df_doublet_true (DataFrame): DataFrame with number of each doublet in original trees, indexed by `doublet_dict`.
        dfs_doublet_new (list): List with each entry as DataFrame of number of each doublet in each set
            of resampled trees, indexed by doublet_dict.
        num_resamples (int): Number of resample datasets.
        doublet_dict (dict): Keys are doublets, values are integers.
        cell_dict (dict): Keys are cell types, values are integers.
        df_all_cells_true (DataFrame): DataFrame with number of each cell fate in original trees, indexed by `cell_dict`.
    
    Returns:
        dfs_c (DataFrame): Indexed by values from `doublet_dict`.
            Last column is analytically solved expected number of each doublet.
            Second to last column is observed number of occurences in the original dataset.
            Rest of columns are the observed number of occurences in the resampled sets.
    
    """
    
    dfs_list = [dfs_doublet_new[i] for i in range(num_resamples)] + [df_doublet_true]
    dfs_c = pd.concat(dfs_list, axis=1, sort=False)
    
    dfs_c.fillna(0, inplace=True)

    # for doublet df
    empty_indices = [i for i in range(0,len(doublet_dict)) if i not in dfs_c.index]
    df_to_append_list = []
    for i in empty_indices:
        num_zeros = num_resamples+1
        index_to_append = {i: [0]*num_zeros}
        df_to_append = pd.DataFrame(index_to_append)
        df_to_append = df_to_append.transpose()
        df_to_append.columns = dfs_c.columns
        df_to_append_list.append(df_to_append)
    dfs_c = pd.concat([dfs_c]+df_to_append_list, axis=0)
    dfs_c.sort_index(inplace=True)
    
    # for all cells df
    empty_indices = [i for i in range(0,len(cell_dict)) if i not in df_all_cells_true.index]
    for i in empty_indices:
        df_to_append = pd.DataFrame([0], index=[i], columns=[f'{num_resamples}'])
        df_all_cells_true = pd.concat([df_all_cells_true, df_to_append], axis=0)
    
    df_all_cells_true_norm = df_all_cells_true/df_all_cells_true.sum()
    df_all_cells_true_norm = df_all_cells_true_norm.rename({v: k for k, v in cell_dict.items()})
    
    expected_list = []
    for key in doublet_dict.keys():
        split = key.split(',')
        cell_1 = split[0][-1]
        cell_2 = split[1][0]
        #print(cell_1, cell_2)
        p_cell_1 = df_all_cells_true_norm.loc[cell_1].values[0]
        p_cell_2 = df_all_cells_true_norm.loc[cell_2].values[0]
        #print(p_cell_1, p_cell_2)
        expected = dfs_c.sum()[0]*p_cell_1*p_cell_2
        if cell_1 != cell_2:
            expected *= 2
        #print(expected)
        expected_list.append(expected)
        
    dfs_c = dfs_c.copy()
    dfs_c['expected'] = expected_list
    dfs_c.fillna(0, inplace=True)
    
    return dfs_c


def resample_trees_doublets(all_trees_sorted, 
                            num_resamples=10000, 
                            replacement_bool=True, 
                            cell_fates='auto'
                            ):
    """Performs resampling of trees, drawing with or without replacement, returning subtree dictionary and DataFrame containing
    number of doublets across all resamples, the original trees, and the expected number (solved analytically).
    
    Resampling is done by replacing each cell fate with a randomly chosen cell fate across all trees.
    If `cell_fates` not explicitly provided, use automatically determined cell fates based on tree dataset
    (every distinct upper-case letter A-Z found in the trees).
    
    
    Args:
        all_trees_sorted (list): List where each entry is a string representing a tree in NEWICK format. 
            Trees are sorted to have all asymmetric quartets in (x,(x,(x,x))) format, triplets in (x,(x,x)) format, 
            and all doublets/quartets in alphabetical order.
        num_resamples (int, optional): Number of resample datasets.
        replacement_bool (bool, optional): Sample cells with or without replacement drawing from the pool of all cells.
        cell_fates (string or list, optional): If 'auto' (i.e. not provided by user), automatically determined 
            based on tree dataset. User can also provide list where each entry is a string representing a cell fate.
    
    Returns:
        (tuple): Contains the following variables.
        - doublet_dict (dict): Keys are doublets, values are integers.
        - cell_fates (list): List where each entry is a string representing a cell fate.
        - dfs_c (DataFrame): Indexed by values from `doublet_dict`.
            Last column is analytically solved expected number of each doublet.
            Second to last column is observed number of occurrences in the original dataset.
            Rest of columns are the observed number of occurrences in the resampled sets.


    """
    # automatically determine cell fates if not explicitly provided
    # (the isinstance guard keeps list/array inputs from being compared element-wise to 'auto')
    if isinstance(cell_fates, str) and cell_fates == 'auto':
        cell_fates = sorted(np.unique(re.findall('[A-Z]', ''.join(all_trees_sorted))))
    
    # make_subtree_dict functions can only handle 10 cell fates max
    if len(cell_fates)>10:
        print('warning!')
        
    doublet_dict = make_doublet_dict(cell_fates)
    cell_dict = make_cell_dict(cell_fates)
    
    # store result for each rearrangement in dfs list
    dfs_doublets_new = []
    df_doublets_true = make_df_doublets(all_trees_sorted, doublet_dict, 'observed', False)
    df_all_cells_true = make_df_all_cells(all_trees_sorted, cell_dict, 'observed', False)

    # The pool of observed cells does not change between resamples: build it once.
    all_cells_true = _flatten_all_cells(all_trees_sorted)

    # rearrange leaves num_resamples times
    for resample in tqdm(range(0, num_resamples)):
        if replacement_bool==False:
            # drawing without replacement consumes the pool, so work on a shuffled copy
            cell_pool = list(all_cells_true)
            random.shuffle(cell_pool)
        else:
            # drawing with replacement only reads from the pool
            cell_pool = all_cells_true
            
        new_trees = [_replace_all(i, cell_pool, replacement_bool) for i in all_trees_sorted]
        df_doublets_new = make_df_doublets(new_trees, doublet_dict, resample, False)
        dfs_doublets_new.append(df_doublets_new)
        
    dfs_c = _process_dfs_doublet(df_doublets_true, dfs_doublets_new, num_resamples, doublet_dict, cell_dict, df_all_cells_true)
    
    return (doublet_dict, cell_fates, dfs_c)

In [16]:
from lineage_motif.simulate import *

In [17]:
# Parameters of the 12x12 transition matrix.
# NOTE(review): rows/columns presumably follow the state order of the
# 'abcdefABCDEF' labels passed to simulate_tree in a later cell — confirm.
base = 0.485
inv_base = 0.8 - base
increment = -0.02

transition_matrix = np.array(
    # Rows 0-4: 0.2 on the diagonal, `inv_base + increment*k` one column to the
    # right, and `base - increment*k` six columns to the right.
    [
        [0] * k + [0.2, inv_base + increment * k] + [0] * (4 - k)
        + [0] * k + [base - increment * k] + [0] * (5 - k)
        for k in range(5)
    ]
    # Row 5: 0.2 on the diagonal and 0.8 in the last column.
    + [[0] * 5 + [0.2] + [0] * 5 + [0.8]]
    # Rows 6-11: all zeros.
    + [[0] * 12 for _ in range(6)]
)

In [38]:
# Simulate one tree from the transition matrix, starting at progenitor 'a',
# then bring it into the sorted/aligned form used by the resampling code.
all_trees_unsorted = [
    simulate_tree(transition_matrix, starting_progenitor='a', labels='abcdefABCDEF')
    for _ in range(1)
]
all_trees_sorted = list(map(sort_align_tree, all_trees_unsorted))
all_trees_sorted[:10]

['(A,((((((((((E,E),(F,F)),E),D),C),((B,(((((E,E),(D,(D,(D,(D,(D,D)))))),((D,D),((E,(F,F)),(D,(D,(E,E)))))),(C,(D,D))),(B,(B,B)))),(B,(C,C)))),B),(C,((E,(E,E)),(D,(E,(F,F)))))),A),A))']

In [39]:
# Resample the simulated trees 10 times, drawing cells with replacement;
# cell fates are detected from the trees themselves.
doublet_dict, cell_fates, dfs_c = resample_trees_doublets(
    all_trees_sorted,
    num_resamples=10,
    replacement_bool=True,
    cell_fates='auto',
)

  0%|          | 0/10 [00:00<?, ?it/s]

[0, 5]
   0  1  2  3  4  5  6  7  8  9  observed
0  0  0  0  0  0  0  0  0  0  0         0
   0  1  2  3  4  5  6  7  8  9  observed
5  0  0  0  0  0  0  0  0  0  0         0


In [40]:
# Show the combined doublet table: one column per resample, then 'observed' and 'expected'.
dfs_c

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,observed,expected
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.05104
1,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.204159
2,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.170132
3,0.0,0.0,0.0,2.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.476371
4,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,0.408318
5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.204159
6,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.204159
7,0.0,2.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.340265
8,2.0,0.0,0.0,2.0,0.0,0.0,2.0,1.0,2.0,3.0,0.0,0.952741
9,0.0,1.0,4.0,1.0,2.0,2.0,1.0,1.0,0.0,1.0,0.0,0.816635
