In [1]:
import sys
sys.path.append('..')
from mtDNAsim import *
import pandas as pd
import numpy as np
from copy import deepcopy
from collections import Counter
import argparse
import pickle

  from tqdm.autonotebook import tqdm


In [2]:
# --- Simulation configuration ------------------------------------------------
# Output directory for every file written by this notebook.
data_path = '/data3/wangkun/mtsim_res/20240903/test/'
# simid = parser.parse_args().i
nrm = 100          # target number of distinct mutations in the root cell (see the retry loop in the next cell)
bn = 'mid'         # key selecting the copy-number schedule from `mt_cn`
mt_mutrate = 0.8   # divided by n_mts before being passed to mtmutation as `mut_rate`
n_mts = 500        # passed to mtmutation as `nmts`


imr = nrm / 1000   # passed to mtmutation as `init_mut_rate`
selection = 0.6    # passed to ncell_division_with_mt1 as `s`

# --- Lineage simulation ------------------------------------------------------
# The stochastic simulation may fail; it is retried until one run completes.
# Only ordinary exceptions are retried, so Ctrl-C / kernel interrupt still
# stops the loop, and each failure is reported instead of silently discarded.
num_elements = 1
success = 0
while not success:
    try:
        system = Gillespie(
            num_elements,
            inits=[1],
            max_cell_num=20000
        )

        p0 = lambda t: 0.8
        dr = lambda t: 0.33
        system.add_reaction(p0, [1], [2], index=0) # 0 self renew
        # NOTE(review): the original comment read "3 -> 4 differentiation" but the
        # arguments are [1] -> [0] with index=13 -- confirm the intended reaction.
        system.add_reaction(dr, [1], [0], index=13)
        system.evolute(1000000)
        success = 1
    except Exception as err:
        print(f"Lineage simulation failed ({err!r}); retrying")

# Flatten the per-key cell collections of the final state into a single list.
curr_cells = [cell for cells in system.curr_cells.values() for cell in cells]

sim_utils.wirte_lineage_info(
    f"{data_path}/lineage_info.csv", system.anc_cells, curr_cells, system.t[-1]
)

reconstruct(f"{data_path}/lineage_info.csv", output=f"{data_path}/gt_tree.nwk", num=1000, is_balance=True)

  0%|          | 0/20000 [00:00<?, ?it/s]

  avg_generation = np.dot(
  A /= A0
  t0 = -np.log(np.random.random()) / A0


  0%|          | 0/20000 [00:00<?, ?it/s]

In [3]:
# Reload the reconstructed tree, force every branch length to 1 and write it
# back over the same Newick file.
tree = loadtree(f'{data_path}/gt_tree.nwk')[0]
for clade in tree.get_nonterminals():
    clade.branch_length = 1
for clade in tree.get_terminals():
    clade.branch_length = 1
Phylo.write(tree, f'{data_path}/gt_tree.nwk', format='newick')


# Copy-number schedules selectable through `bn`.
# NOTE(review): the meaning of the lambda argument `x` is defined by mtmutation
# (presumably a generation index) -- confirm.
mt_cn = {
    'mid':lambda x: 1.52 if x <= 10 else (2.85 if x <= 20 else 2),
    'const':lambda x: 2 
}

# Re-run the mutation simulation until the number of distinct mutations in the
# root cell '<0_0>' is within 70% of the target `nrm` (any result is accepted
# when nrm == 0). Only ordinary exceptions trigger a retry, so a kernel
# interrupt still stops the loop, and failures are reported rather than hidden.
success = 0
while not success:
    try:
        mt_muts, mutid = mtmutation(tree, mut_rate=mt_mutrate/n_mts, init_mut_rate=imr, mt_copynumber=mt_cn[bn], nmts=n_mts)
        n_root_muts = len({mut for mt in mt_muts['<0_0>'] for mut in mt})
        if nrm == 0 or np.abs(n_root_muts-nrm)/nrm <= 0.7:
            success = 1
    except Exception as err:
        print(f"mtDNA mutation simulation failed ({err!r}); retrying")

# Use a context manager so the pickle file is flushed and closed deterministically.
with open(f"{data_path}/mt_allmuts_30.pkl", 'wb') as pkl_file:
    pickle.dump(mt_muts, pkl_file)

  0%|          | 0/499 [00:00<?, ?it/s]

Simulating MT mutation::   0%|          | 0/9014 [00:00<?, ?it/s]

In [4]:
# Mutation frequencies per cell, restricted to the leaves of the saved tree.
mt_freq = sparse_freq(mt_muts)
tree_gt = Phylo.read(f'{data_path}/gt_tree.nwk', format='newick')
leaf_names = [leaf.name for leaf in tree_gt.get_terminals()]
mt_freq_leave = mt_freq.loc[leaf_names]

# Presence/absence is encoded as a pseudo-alignment: 'A' = mutation present
# (frequency above the cutoff), 'G' = absent. One file is written per cutoff.
translation_table = str.maketrans({'1': 'A', '0': 'G'})
for cutoff in (0, 0.01):
    muts = mt_freq_leave > cutoff
    # Keep only the mutations carried by at least one leaf at this cutoff.
    informative = np.flatnonzero((muts.sum(0) > 0).to_numpy())
    muts = muts.iloc[:, informative].astype(int).astype(str)
    n_cells, n_sites = muts.shape
    records = [
        f'{cell_name} ' + ''.join(calls).translate(translation_table) + '\n'
        for cell_name, calls in zip(muts.index, muts.to_numpy())
    ]
    seqs = f'{n_cells} {n_sites}\n' + ''.join(records)
    with open(f'{data_path}/mt_allmuts_{cutoff}.phy', 'w') as f:
        f.write(seqs)

# Initial state for the expansion phase: every leaf starts as a one-cell
# lineage holding a private copy of its mitochondrial mutation sets.
sel_cells = [leaf.name for leaf in tree.get_terminals()]
max_mut_id = max(
    max([max([*mt, 0]) for mt in mt_muts[name]] + [0])
    for name in sel_cells
)
new_mts_1 = {
    leaf.name: [{leaf.name: deepcopy(mt_muts[leaf.name])}]
    for leaf in tree.get_terminals()
}

  0%|          | 0/9014 [00:00<?, ?it/s]

In [5]:
from io import StringIO

In [6]:
def get_gt_tree(gt_tree30, mts, save_path = None):
    """Build the ground-truth tree of an expanded population.

    The tree `gt_tree30` is pruned to the leaves that still have descendants
    in `mts`, and below each surviving leaf a subtree is grafted that is
    reconstructed from the descendants' names.

    NOTE: a later cell defines `get_gt_tree` again with an extra `collapse`
    argument; whichever definition was executed last is the one in effect.

    Parameters
    ----------
    gt_tree30 : Bio.Phylo tree
        Tree whose leaf names end with '>' (e.g. '<0_0>'). It is only read,
        never modified.
    mts : mapping
        Input accepted by `sparse_freq`. Every cell name in the resulting
        index must contain exactly one '>': the part up to and including '>'
        names the ancestral leaf, the remainder is the lineage string, of
        which every prefix becomes one node.
    save_path : str or None
        If given, the resulting tree is written there in Newick format.

    Returns
    -------
    Bio.Phylo tree with all branch lengths set to 1.
    """
    # mts = rs_cvt(mts)
    mt_freq = sparse_freq(mts)
    alive_cells = [i+'>' for i in set([i.split('>')[0] for i in mt_freq.index])]

    # Every clade on the root-to-leaf path of a surviving leaf is kept.
    keep_cells = set()
    for i in alive_cells:
        for c in gt_tree30.get_path(gt_tree30.find_any(i)):
            keep_cells.add(c.name)
    keep_cells = list(keep_cells)

    # Prune through ete3 (round-trip via a Newick string), then return to Bio.Phylo.
    tree_nwk = StringIO()
    Phylo.write(gt_tree30, tree_nwk, 'newick')
    tree_nwk = tree_nwk.getvalue()
    tree100_gt = ete3.Tree(tree_nwk.replace('\n', ';'), format=1)
    tree100_gt.prune(keep_cells)
    tree100_gt = Phylo.read(StringIO(tree100_gt.write()), format='newick')
    for clade in tree100_gt.find_clades():
        clade.branch_length = 1

    # Reconstruct one expansion subtree per ancestral leaf: each prefix of the
    # lineage string is a node whose parent is the next shorter prefix.
    expansion_clades = dict()
    rec_cells = dict()
    for c in tqdm(mt_freq.index):
        anc_name, lin_info = c.split('>')
        anc_name = f'{anc_name}>'
        if anc_name not in expansion_clades:
            expansion_clades[anc_name] = Phylo.BaseTree.Clade(branch_length=1, name=anc_name)
            rec_cells[anc_name] = set()
        for li in range(len(lin_info)):
            node_name = f'{anc_name}{lin_info[:li+1]}'
            if node_name in rec_cells[anc_name]:
                continue
            anc_t = expansion_clades[anc_name].find_any(f'{anc_name}{lin_info[:li]}')
            anc_t.clades.append(Phylo.BaseTree.Clade(branch_length=1, name=node_name))
            rec_cells[anc_name].add(node_name)

    # Graft the children of each expansion root under the matching leaf.
    # `clades` must be a list: assigning the Clade object itself only works by
    # accident for iteration, makes childless expansions look non-terminal and
    # breaks list-based operations such as Tree.collapse.
    for i in expansion_clades:
        tree100_gt.find_any(i).clades = expansion_clades[i].clades
    # Phylo.write(tree100_gt, f'{data_path}/gt_tree100.nwk', format='newick')
    if save_path is not None:
        Phylo.write(tree100_gt, save_path, format='newick')
    return tree100_gt




In [68]:
def get_gt_tree(gt_tree30, mts, save_path = None, collapse=True):
    """Build the ground-truth tree of an expanded population.

    The tree `gt_tree30` is pruned to the leaves that still have descendants
    in `mts`; below each surviving leaf an expansion subtree, reconstructed
    from the descendants' names, is attached as its single child.

    Parameters
    ----------
    gt_tree30 : Bio.Phylo tree
        Tree whose leaf names end with '>' (e.g. '<0_0>'). It is only read,
        never modified.
    mts : mapping
        Input accepted by `sparse_freq`. Every cell name in the resulting
        index must contain exactly one '>': the part up to and including '>'
        names the ancestral leaf, the remainder is the lineage string, of
        which every prefix becomes one node.
    save_path : str or None
        If given, the resulting tree is written there in Newick format.
    collapse : bool
        If True, every non-root clade with exactly one child is collapsed
        (its child is re-attached to its parent; Tree.collapse adds the
        removed branch length to the child).

    Returns
    -------
    Bio.Phylo tree.
    """
    # mts = rs_cvt(mts)
    mt_freq = sparse_freq(mts)
    alive_cells = [i+'>' for i in set([i.split('>')[0] for i in mt_freq.index])]

    # Every clade on the root-to-leaf path of a surviving leaf is kept.
    keep_cells = set()
    for i in alive_cells:
        for c in gt_tree30.get_path(gt_tree30.find_any(i)):
            keep_cells.add(c.name)
    keep_cells = list(keep_cells)

    # Prune through ete3 (round-trip via a Newick string), then return to Bio.Phylo.
    tree_nwk = StringIO()
    Phylo.write(gt_tree30, tree_nwk, 'newick')
    tree_nwk = tree_nwk.getvalue()
    tree100_gt = ete3.Tree(tree_nwk.replace('\n', ';'), format=1)
    tree100_gt.prune(keep_cells)
    tree100_gt = Phylo.read(StringIO(tree100_gt.write()), format='newick')
    for clade in tree100_gt.find_clades():
        clade.branch_length = 1

    # Reconstruct one expansion subtree per ancestral leaf: each prefix of the
    # lineage string is a node whose parent is the next shorter prefix.
    expansion_clades = dict()
    rec_cells = dict()
    for c in tqdm(mt_freq.index):
        anc_name, lin_info = c.split('>')
        anc_name = f'{anc_name}>'
        if anc_name not in expansion_clades:
            expansion_clades[anc_name] = Phylo.BaseTree.Clade(branch_length=1, name=anc_name)
            rec_cells[anc_name] = set()
        for li in range(len(lin_info)):
            node_name = f'{anc_name}{lin_info[:li+1]}'
            if node_name in rec_cells[anc_name]:
                continue
            anc_t = expansion_clades[anc_name].find_any(f'{anc_name}{lin_info[:li]}')
            anc_t.clades.append(Phylo.BaseTree.Clade(branch_length=1, name=node_name))
            rec_cells[anc_name].add(node_name)

    for i in expansion_clades:
        tree100_gt.find_any(i).clades = [expansion_clades[i]]
    # Phylo.write(tree100_gt, f'{data_path}/gt_tree100.nwk', format='newick')
    if collapse:
        # Operate on the tree built here (the previous code referenced the
        # notebook global `tree330`, which does not exist on a fresh kernel).
        # The unary clades are collected first: collapsing a unary clade never
        # changes any clade's child count, so one pass over this snapshot
        # removes all of them, whereas collapsing while traversing skips nodes.
        # The root is excluded because Tree.collapse cannot remove it.
        unary_clades = [
            clade for clade in tree100_gt.get_nonterminals()
            if clade is not tree100_gt.root and len(clade.clades) == 1
        ]
        for clade in unary_clades:
            tree100_gt.collapse(clade)
    if save_path is not None:
        Phylo.write(tree100_gt, save_path, format='newick')
    return tree100_gt

In [7]:
   
def _advance_generations(lineages, next_mut_id, n_generations, pbar,
                         upper=1200, lower=800):
    """Run `n_generations` rounds of ncell_division_with_mt1 on every lineage.

    `lineages` maps each leaf name of `tree` to its list of cells and is
    updated in place. Before each round the total number of cells is counted
    and the value passed as `p` is chosen from it: 0.433 above `upper`,
    0.6 below `lower`, 0.5 otherwise.
    NOTE(review): presumably `p` is a division/survival probability used to
    hold the population near 1000 cells -- confirm against ncell_division_with_mt1.

    Returns the updated mutation-id counter (last element of the value
    returned by ncell_division_with_mt1).
    """
    for _ in range(n_generations):
        cell_number = np.sum([len(cells) for cells in lineages.values()])
        if cell_number > upper:
            p = 0.433
        elif cell_number < lower:
            p = 0.6
        else:
            p = 0.5
        for cell in tree.get_terminals():
            tmp = ncell_division_with_mt1(lineages[cell.name], next_mut_id, mt_mutrate, p=p, s=selection)
            next_mut_id = tmp[-1]
            lineages[cell.name] = tmp[0]
        pbar.update(1)
    return next_mut_id


def _save_snapshot(lineages, tag):
    """Convert the current state with rs_cvt, then write its tree and pickle.

    Writes gt_tree{tag}.nwk and mt_allmuts_{tag}.pkl under `data_path` and
    returns the converted state.
    """
    snapshot = rs_cvt(lineages)
    get_gt_tree(tree_gt, snapshot, save_path = f"{data_path}/gt_tree{tag}.nwk")
    # Context manager: the pickle file is closed even if dumping fails.
    with open(f"{data_path}/mt_allmuts_{tag}.pkl", 'wb') as pkl_file:
        pickle.dump(snapshot, pkl_file)
    return snapshot


# 100 generations -> snapshot "130", then 200 more -> snapshot "330".
# NOTE: this cell updates `new_mts_1` and `max_mut_id` in place, so running it
# again continues from the current state instead of starting over.
with tqdm(total=300) as pbar:
    max_mut_id = _advance_generations(new_mts_1, max_mut_id, 100, pbar)
    new_mts_11 = _save_snapshot(new_mts_1, 130)

    max_mut_id = _advance_generations(new_mts_1, max_mut_id, 200, pbar)
    new_mts_11 = _save_snapshot(new_mts_1, 330)

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/911 [00:00<?, ?it/s]

  0%|          | 0/911 [00:00<?, ?it/s]

  0%|          | 0/900 [00:00<?, ?it/s]

  0%|          | 0/900 [00:00<?, ?it/s]

In [69]:
# Rebuild the ground-truth tree from the latest `new_mts_11` snapshot; no
# save_path is passed, so nothing is written to disk by this call.
tree330 = get_gt_tree(tree_gt, new_mts_11)

  0%|          | 0/900 [00:00<?, ?it/s]

  0%|          | 0/900 [00:00<?, ?it/s]

In [71]:
# Walk from each leaf towards the root (the leaf itself excluded) and collapse
# every ancestor that has exactly one child.
for leaf in tree330.get_terminals():
    ancestors = tree330.get_path(leaf)[:-1]
    for node in reversed(ancestors):
        if len(node.clades) == 1:
            tree330.collapse(node)

  0%|          | 0/900 [00:00<?, ?it/s]

In [72]:
# Number of non-terminal (internal) clades currently in tree330.
len(tree330.get_nonterminals())

899

In [18]:
# Show the direct children of the root clade.
tree330.root.clades

[Clade(branch_length=1, confidence=1), Clade(branch_length=1, confidence=1)]

In [56]:
# Collapse every non-root clade that has exactly one child.
# The candidates are collected before the tree is modified: collapsing while
# traversing skips nodes (which is why the traversal used to be run twice), and
# collapsing a unary clade never changes any clade's child count, so a single
# pass over this snapshot removes all of them. The root is excluded because
# Tree.collapse cannot remove it.
unary_clades = [
    clade for clade in tree330.get_nonterminals()
    if clade is not tree330.root and len(clade.clades) == 1
]
for clade in unary_clades:
    tree330.collapse(clade)

In [57]:
# Number of non-terminal (internal) clades currently in tree330.
len(tree330.get_nonterminals())

1049

In [47]:
# Write the current state of tree330 to tree300_long.nwk (Newick).
# NOTE(review): the absolute path repeats the directory stored in `data_path`,
# and which state of tree330 this file holds depends on cell execution order.
Phylo.write(tree330, '/data3/wangkun/mtsim_res/20240903/test/tree300_long.nwk', format='newick')

1

In [73]:
# Write the current state of tree330 to tree300_short.nwk (Newick).
# NOTE(review): the absolute path repeats the directory stored in `data_path`,
# and which state of tree330 this file holds depends on cell execution order.
Phylo.write(tree330, '/data3/wangkun/mtsim_res/20240903/test/tree300_short.nwk', format='newick')

1

In [74]:
def sparse_freq(cells, df=True, count=False):
    """Tabulate per-cell mutation frequencies (and optionally raw counts).

    Parameters
    ----------
    cells : mapping
        Cell name -> collection of mitochondria, each mitochondrion being an
        iterable of mutation ids. The ids are used as column indices, so they
        must be non-negative integers.
    df : bool
        If True, return pandas DataFrames (rows: cell names in the order of
        `cells`, columns: mutation ids); otherwise return scipy CSR matrices
        whose columns follow the same selection.
    count : bool
        If True, additionally return the number of mitochondria per cell that
        carry each mutation.

    Returns
    -------
    freq, or (freq, counts) when `count` is True. freq[cell, mut] is the
    fraction of the cell's mitochondria carrying `mut`. Only mutations present
    in at least one cell are kept as columns.
    """
    cell_names = list(cells.keys())

    _row, _col, _data, _data_cnt = [], [], [], []
    max_mut_id = 0
    for ind, cell in enumerate(tqdm(cell_names)):
        cell_mts = cells[cell]
        nmts = len(cell_mts)
        # Number of mitochondria of this cell carrying each mutation.
        cnt = Counter(mut for mt in cell_mts for mut in mt)
        for mut, n_carriers in cnt.items():
            _row.append(ind)
            _col.append(mut)
            _data.append(n_carriers / nmts)
            _data_cnt.append(n_carriers)
            if mut > max_mut_id:
                max_mut_id = mut

    # The shape is given explicitly: when it is inferred from the largest
    # row index, trailing cells without any mutation are dropped and the
    # matrix no longer matches `cell_names`.
    shape = (len(cell_names), max_mut_id + 1)
    freq = coo_matrix((_data, (_row, _col)), shape=shape, dtype=float).tocsr()
    # Keep only the mutation ids observed in at least one cell.
    sel = np.array(freq.sum(axis=0) != 0).flatten()
    mut_id = np.arange(shape[1])[sel]
    freq = freq[:, sel]
    if df:
        freq = pd.DataFrame(freq.toarray(), index=cell_names, columns=mut_id)
    if not count:
        return freq

    # A frequency is non-zero exactly when the count is, so the same column
    # selection applies to the count matrix.
    counts = coo_matrix((_data_cnt, (_row, _col)), shape=shape, dtype=np.int64).tocsr()
    counts = counts[:, sel]
    if df:
        counts = pd.DataFrame(counts.toarray(), index=cell_names, columns=mut_id)
    return freq, counts

In [76]:
# Frequency and count tables for the cells in `mt_muts`.
freq, count = sparse_freq(mt_muts, count=True)

  0%|          | 0/9014 [00:00<?, ?it/s]

In [144]:
from scipy.stats import nbinom, binom
def sequence_sim(f, coverage, n=2.5):
    """Simulate sequencing noise on a table of mutation frequencies.

    For every entry a read depth is drawn from a negative binomial
    distribution with parameters n and p = n / (n + coverage), whose mean is
    n * (1 - p) / p = coverage. The number of reads supporting the mutation is
    then drawn from Binomial(depth, f), and the observed frequency is
    reads / depth.

    Parameters
    ----------
    f : pandas.DataFrame
        True frequencies, each in [0, 1].
    coverage : float
        Mean read depth.
    n : float
        Negative-binomial `n` parameter of the depth distribution.

    Returns
    -------
    pandas.DataFrame with the index and columns of `f`; entries whose sampled
    depth is 0 are reported as 0.
    """
    depth = nbinom(p=n/(n+coverage), n=n).rvs(size=f.shape)
    read_cnt = np.asarray(binom(n=depth, p=np.asarray(f, dtype=float)).rvs())
    # Divide only where there is coverage; zero-depth entries stay 0 without
    # triggering a 0/0 RuntimeWarning.
    freq_samp = np.divide(
        read_cnt, depth, out=np.zeros(f.shape, dtype=float), where=depth > 0
    )
    freq_samp = pd.DataFrame(freq_samp, index=f.index, columns=f.columns)
    return freq_samp

In [145]:
# Simulate sequencing of the frequency table with coverage=50 and n=2.5.
freq_samp = sequence_sim(freq, 50, n=2.5)

  freq_samp = read_cnt/depth
