In [5]:
import re

In [19]:
def _align_asym_septet(tree):
    """Aligns asymmetric septets so that the single outgroup leaf comes first.

    Finds every ((a,(b,(c,(d,(e,f))))),g) septet -- an asymmetric sextet with one
    extra leaf attached at its root -- and rewrites it as the equivalent
    (g,(a,(b,(c,(d,(e,f)))))), so every asymmetric septet shares one orientation.
    A leaf is any run of word characters (letters, digits, underscore).

    Args:
        tree (string): Tree in NEWICK format. Tree should have doublets sorted, and
            triplets and asymmetric quartets aligned already.

    Returns:
        tree (string): New tree in NEWICK format after aligning asymmetric septets.
    """
    # Raw string: "\(" in a normal literal is an invalid escape sequence.
    # One capture group per leaf, in the order a, b, c, d, e, f, g.
    pattern = r"\(\((\w*),\((\w*),\((\w*),\((\w*),\((\w*),(\w*)\)\)\)\)\),(\w*)\)"

    def _reorder(match):
        first, second, third, fourth, fifth, sixth, outgroup = match.groups()
        return f"({outgroup},({first},({second},({third},({fourth},({fifth},{sixth}))))))"

    # Matches are disjoint balanced groups, so one substitution pass rewrites
    # every septet without re-scanning the tree per match.
    return re.sub(pattern, _reorder, tree)

In [18]:
# Sanity check for _align_asym_septet: a 7-leaf asymmetric septet written with
# its extra leaf last should come back with that leaf moved to the front.
tree = "((x,(x,(x,(x,(x,x))))),x)"

_align_asym_septet(tree)

# Notebook output captured for the call above:
'(x,(x,(x,(x,(x,(x,x))))))'

In [2]:
# Scratch check: when the first leaf is a single character, index 1 is that leaf
# and [3:-1] is the remaining nested subtree (dropping the leading "(A," and the
# final closing parenthesis).
tree = "(A,(A,(A,(A,A))))"

tree[1], tree[3:-1]

# Notebook output captured for the expression above:
('A', '(A,(A,(A,A)))')

In [None]:
def _zero_fill_missing_ids(df, n_ids):
    """Return `df` with an all-zero row for every integer id in range(n_ids) missing from its index."""
    # reindex keeps the frame's own column labels, so no column name is assumed.
    return df.reindex(df.index.union(pd.RangeIndex(n_ids)), fill_value=0)


def _doublet_cells(key):
    """Split a doublet key such as "(A,B)" into its two cell-type names."""
    parts = key.split(',')
    # Strip the surrounding parentheses/spaces so multi-character names work too.
    return parts[0].strip(' ('), parts[1].strip(' )')


def _process_dfs_doublet(df_doublet_true, dfs_doublet_new, num_resamples, doublet_dict, cell_dict, df_all_cells_true):
    """Arranges observed counts for each doublet in all resamples and original trees into a combined DataFrame.

    The last column, ``expected``, is the analytically solved expected number of each
    doublet: ``N * p1 * p2`` (doubled when the two cell types differ). ``N`` is the total
    of the first combined column, and ``p1``/``p2`` are the relative frequencies of the
    two cell types in `df_all_cells_true`.

    Args:
        df_doublet_true (DataFrame): DataFrame with number of each doublet in original trees, indexed by `doublet_dict`.
        dfs_doublet_new (list): List with each entry as DataFrame of number of each doublet in each set
            of resampled trees, indexed by doublet_dict.
        num_resamples (int): Number of resample datasets.
        doublet_dict (dict): Keys are doublets, parsed as "(cell_1,cell_2)"; values are integer ids.
        cell_dict (dict): Keys are cell types, values are integer ids.
        df_all_cells_true (DataFrame): DataFrame with number of each cell fate in original trees, indexed by `cell_dict`.

    Returns:
        dfs_c (DataFrame): Indexed by values from `doublet_dict`; doublets never observed get a zero row.
            Last column (``expected``) is analytically solved expected number of each doublet.
            Second to last column is observed number of occurrences in the original dataset.
            Rest of columns are the observed number of occurrences in the resampled sets.
    """
    # One column per resampled dataset, followed by the original-tree counts.
    dfs_list = [dfs_doublet_new[i] for i in range(num_resamples)] + [df_doublet_true]
    dfs_c = pd.concat(dfs_list, axis=1, sort=False).fillna(0)
    # Every doublet id 0..len(doublet_dict)-1 gets a row, zero if never observed.
    dfs_c = _zero_fill_missing_ids(dfs_c, len(doublet_dict)).sort_index()

    # Cell types absent from the original trees count as zero, then convert the
    # counts into relative frequencies indexed by cell-type name.
    df_all_cells_true = _zero_fill_missing_ids(df_all_cells_true, len(cell_dict))
    df_all_cells_true_norm = df_all_cells_true / df_all_cells_true.sum()
    df_all_cells_true_norm = df_all_cells_true_norm.rename({v: k for k, v in cell_dict.items()})

    # Loop-invariant: total of the first column, taken by position.
    total_doublets = dfs_c.sum().iloc[0]

    expected_by_id = {}
    for key, doublet_id in doublet_dict.items():
        cell_1, cell_2 = _doublet_cells(key)
        p_cell_1 = df_all_cells_true_norm.loc[cell_1].values[0]
        p_cell_2 = df_all_cells_true_norm.loc[cell_2].values[0]
        expected = total_doublets * p_cell_1 * p_cell_2
        # Two distinct cell types can pair up in either order.
        if cell_1 != cell_2:
            expected *= 2
        expected_by_id[doublet_id] = expected

    # Align by doublet id rather than by dict iteration order.
    dfs_c['expected'] = pd.Series(expected_by_id, dtype=float)
    # Zero cell totals (or ids outside doublet_dict) yield NaN; report them as 0.
    return dfs_c.fillna(0)