### For simulated data, we run the same clustering and tree-inference algorithms as MACHINA so that the comparison is fair. MACHINA's tree inference sometimes drops clusters; in those cases, we fix up the TSVs with the correct mutation-to-cluster-index assignments so that each TSV corresponds to its inferred tree.

In [1]:
import os

# Resolve the repository root relative to the notebook's working directory
# (presumably the notebook is run from a directory one level below the root),
# then the MACHINA simulated-data directory inside it.
_cwd = os.getcwd()
repo_dir = os.path.join(_cwd, "../")
machina_sims_data_dir = os.path.join(repo_dir, "data", "machina_sims")

In [54]:
import fnmatch
from metient.util import data_extraction_util as dutil
import pandas as pd
import numpy as np

sites = ["m8", "m5"]
mig_types = ["M", "mS", "R", "S"]

# Columns kept in each corrected TSV.
# NOTE(review): `final_cols` is not defined anywhere in this notebook; it came
# from an earlier, since-removed cell. If it is absent, every column is written
# instead of raising NameError. TODO: confirm the exact schema metient expects
# and define it explicitly here.
output_cols = globals().get("final_cols")

# (site, mig_type, seed, tree_num) whose intermediate mappings are printed for
# inspection; set to None to disable.
DEBUG_KEY = ("m8", "M", "19", 0)


def _load_reads_tsv(ref_var_fn):
    """Read a simulated reads TSV and add the columns that do not depend on the tree."""
    df = pd.read_csv(ref_var_fn, sep="\t", skiprows=3)
    df['var_read_prob'] = 0.5
    # Samples labelled "P" are the primary; every other site is a metastasis.
    df['site_category'] = np.where(df['anatomical_site_label'] == "P", 'primary', 'metastasis')
    return df


def _mut_name_to_cluster_idx(pruned_idx_to_clstr_label):
    """Invert {cluster_idx: "mut;mut;..."} into {int(mut): int(cluster_idx)}."""
    return {
        int(mut): int(clstr_idx)
        for clstr_idx, label in pruned_idx_to_clstr_label.items()
        for mut in label.split(";")
    }


for site in sites:
    corrected_dir = os.path.join(machina_sims_data_dir, f"{site}_clustered_input_corrected")
    os.makedirs(corrected_dir, exist_ok=True)

    for mig_type in mig_types:
        site_mig_data_dir = os.path.join(machina_sims_data_dir, site, mig_type)
        seed_fns = fnmatch.filter(os.listdir(site_mig_data_dir), 'reads_seed*.tsv')
        seeds = sorted(s.replace(".tsv", "").replace("reads_seed", "") for s in seed_fns)
        for seed in seeds:
            cluster_fn = os.path.join(machina_sims_data_dir, f"{site}_clustered_input", f"cluster_{mig_type}_seed{seed}.txt")
            all_mut_trees_fn = os.path.join(machina_sims_data_dir, f"{site}_mut_trees", f"mut_trees_{mig_type}_seed{seed}.txt")
            ref_var_fn = os.path.join(site_mig_data_dir, f"reads_seed{seed}.tsv")

            idx_to_cluster_label = dutil.get_idx_to_cluster_label(cluster_fn, ignore_polytomies=True)
            data = dutil.get_adj_matrices_from_spruce_mutation_trees(all_mut_trees_fn, idx_to_cluster_label, is_sim_data=True)

            # The reads TSV is identical for every tree of this seed: read it once.
            reads_df = _load_reads_tsv(ref_var_fn)

            for tree_num, (_adj_matrix, pruned_idx_to_clstr_label) in enumerate(data):
                # Map each mutation to the (possibly re-indexed) cluster it belongs
                # to in this tree's pruned clustering.
                mut_name_to_cluster_idx = _mut_name_to_cluster_idx(pruned_idx_to_clstr_label)

                df = reads_df.copy()
                # Mutations whose cluster was dropped by tree inference get no
                # cluster index (NaN) and are removed.
                df['cluster_index'] = df['character_label'].map(mut_name_to_cluster_idx)
                df = df.dropna(subset=['cluster_index'])
                if output_cols is not None:
                    df = df[output_cols]

                if DEBUG_KEY == (site, mig_type, seed, tree_num):
                    print(site, mig_type, seed, tree_num)
                    print(idx_to_cluster_label)
                    print(pruned_idx_to_clstr_label)
                    print(mut_name_to_cluster_idx)

                df.to_csv(os.path.join(corrected_dir, f"cluster_{mig_type}_seed{seed}_tree{tree_num}.tsv"), sep="\t")
                
                

m8 M 19 0
OrderedDict([(0, '0'), (1, '1'), (2, '2'), (3, '3;4'), (4, '5;10'), (5, '19;21'), (6, '13;25;29'), (7, '11;12;14;31;32'), (8, '22;26;27;28;30;36;45'), (9, '6;7;9;15;18;20;35;42;43;46'), (10, '8;38;41;44;48;49;51;54;58;59'), (11, '52;61'), (12, '62'), (13, '47;50;56;57;64'), (14, '74'), (15, '75')])
{0: '0', 1: '1', 2: '2', 3: '3;4', 4: '5;10', 5: '19;21', 6: '13;25;29', 7: '11;12;14;31;32', 8: '22;26;27;28;30;36;45', 9: '6;7;9;15;18;20;35;42;43;46', 10: '8;38;41;44;48;49;51;54;58;59', 11: '52;61', 12: '62', 13: '47;50;56;57;64', 14: '74', 15: '75'}
{0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 4, 10: 4, 19: 5, 21: 5, 13: 6, 25: 6, 29: 6, 11: 7, 12: 7, 14: 7, 31: 7, 32: 7, 22: 8, 26: 8, 27: 8, 28: 8, 30: 8, 36: 8, 45: 8, 6: 9, 7: 9, 9: 9, 15: 9, 18: 9, 20: 9, 35: 9, 42: 9, 43: 9, 46: 9, 8: 10, 38: 10, 41: 10, 44: 10, 48: 10, 49: 10, 51: 10, 54: 10, 58: 10, 59: 10, 52: 11, 61: 11, 62: 12, 47: 13, 50: 13, 56: 13, 57: 13, 64: 13, 74: 14, 75: 15}


In [60]:

# Compare the cluster-index-to-mutation-label mapping implied by each corrected TSV
# with the one used to build the corresponding adjacency matrix, and confirm they match.

def get_index_to_cluster_label_from_corrected_sim_tsv(ref_var_fn):
    """Rebuild {cluster_index: "label;label;..."} from a corrected simulation TSV.

    Clusters appear in order of first appearance in the file, and labels within
    a cluster are joined with ";" in order of first appearance.

    Raises:
        ValueError: if a single character label is assigned to more than one
            cluster index (``.item()`` on a multi-element array).
    """
    table = pd.read_csv(ref_var_fn, sep="\t")
    members_by_cluster = {}
    for character in table['character_label'].unique():
        assigned = table.loc[table['character_label'] == character, 'cluster_index']
        cluster = int(assigned.unique().item())
        members_by_cluster.setdefault(cluster, []).append(str(character))
    return {cluster: ";".join(members) for cluster, members in members_by_cluster.items()}
        
# (site, mig_type, seed, tree_num) whose adjacency matrix is printed for
# inspection; set to None to disable.
CHECK_DEBUG_KEY = ("m8", "mS", "4", 0)

for site in sites:
    corrected_dir = os.path.join(machina_sims_data_dir, f"{site}_clustered_input_corrected")

    for mig_type in mig_types:
        site_mig_data_dir = os.path.join(machina_sims_data_dir, site, mig_type)
        seed_fns = fnmatch.filter(os.listdir(site_mig_data_dir), 'reads_seed*.tsv')
        seeds = sorted(s.replace(".tsv", "").replace("reads_seed", "") for s in seed_fns)
        for seed in seeds:
            cluster_fn = os.path.join(machina_sims_data_dir, f"{site}_clustered_input", f"cluster_{mig_type}_seed{seed}.txt")
            all_mut_trees_fn = os.path.join(machina_sims_data_dir, f"{site}_mut_trees", f"mut_trees_{mig_type}_seed{seed}.txt")
            tree_fns = fnmatch.filter(os.listdir(corrected_dir), f"cluster_{mig_type}_seed{seed}_tree*.tsv")

            idx_to_cluster_label = dutil.get_idx_to_cluster_label(cluster_fn, ignore_polytomies=True)
            # Parse the mutation trees once per seed (it does not depend on tree_num).
            data = dutil.get_adj_matrices_from_spruce_mutation_trees(all_mut_trees_fn, idx_to_cluster_label, is_sim_data=True)

            # The writer cell emits exactly one corrected TSV per inferred tree;
            # a mismatch means stale or missing files.
            if len(tree_fns) != len(data):
                raise ValueError(
                    f"{site} {mig_type} seed {seed}: found {len(tree_fns)} corrected TSVs "
                    f"but {len(data)} inferred trees"
                )

            for tree_num, (tree, pruned_idx_to_clstr_label) in enumerate(data):
                ref_var_fn = os.path.join(corrected_dir, f"cluster_{mig_type}_seed{seed}_tree{tree_num}.tsv")
                corrected_idx_to_cluster_label = get_index_to_cluster_label_from_corrected_sim_tsv(ref_var_fn)

                # Explicit check (not `assert`) so it still runs under `python -O`.
                if pruned_idx_to_clstr_label != corrected_idx_to_cluster_label:
                    raise AssertionError(
                        f"{site} {mig_type} seed {seed} tree {tree_num}: corrected TSV mapping "
                        f"{corrected_idx_to_cluster_label} != tree mapping {pruned_idx_to_clstr_label}"
                    )

                if CHECK_DEBUG_KEY == (site, mig_type, seed, tree_num):
                    print(tree)
                    print(idx_to_cluster_label)
                    print(idx_to_cluster_label)
                
                
                

[[0. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
OrderedDict([(0, '0'), (1, '1'), (2, '2'), (3, '4;7;8;11;12;17;18'), (4, '19'), (