In [1]:
# Test migration history reconstruction on simulated data
#
# Setup cell: moves the working directory to the project checkout, imports the
# project modules used by the rest of the notebook, and defines the root folder
# of the MACHINA data.

import sys
import os

# NOTE(review): both paths in this cell are absolute and machine-specific;
# anyone else running the notebook has to edit them first.
repo_dir = "/Users/divyakoyyalagunta/Desktop/Cornell_Research/Morris_Lab/metastatic_history_reconstruction_git/"
# Change the working directory to the checkout before the project imports run --
# presumably so the package below resolves from the repo root (TODO confirm).
os.chdir(repo_dir)
# Also put the util folder itself on the module search path.
sys.path.append(os.path.join(repo_dir, "metastatic_history_reconstruction/util"))
from metastatic_history_reconstruction.lib import vertex_labeling

from metastatic_history_reconstruction.util import machina_data_extraction_util as mach_util
from metastatic_history_reconstruction.util import vertex_labeling_util as vert_util
import matplotlib
import torch

# Root of the MACHINA data; later cells join "sims/<num_sites>/<mig_type>" onto it.
MACHINA_DATA_DIR = '/Users/divyakoyyalagunta/Desktop/Cornell_Research/Morris_Lab/machina/data/'


In [2]:
# Hex colour strings passed as `custom_colors` to gumbel_softmax_optimization
# in find_labeling. The last four entries are all 'black'.
custom_colors = [
    matplotlib.colors.to_hex(color_name)
    for color_name in (
        'limegreen', 'royalblue', 'hotpink', 'grey',
        'saddlebrown', 'darkorange', 'purple', 'red',
        'black', 'black', 'black', 'black',
    )
]

def find_labeling(cluster_fn, tree_fn, ref_var_fn):
    """Run the Gumbel-softmax vertex labeling on one simulated MACHINA dataset.

    Reads the cluster labels, the tree and the ref/var data through the
    ``mach_util`` helpers, derives the mutation matrix with ``vert_util``,
    one-hot encodes the site named ``'P'`` and passes everything to
    ``vertex_labeling.gumbel_softmax_optimization``. The cluster-label mapping
    and the list of sites are printed along the way.

    Args:
        cluster_fn: path given to ``mach_util.get_cluster_label_to_idx``.
        tree_fn: path given to ``mach_util.get_adj_matrix_from_machina_tree``.
        ref_var_fn: path given to
            ``mach_util.get_ref_var_matrices_from_machina_sim_data``.

    Returns:
        None. Whatever the optimization call returns is discarded.

    Raises:
        ValueError: presumably, when ``'P'`` is not among the returned sites
            (assumes ``unique_sites`` is a list -- TODO confirm).

    NOTE(review): the primary site is located via the label ``'P'``, yet the
    optimizer receives ``primary="prostate"``; confirm these are meant to agree.
    """
    label_to_idx = mach_util.get_cluster_label_to_idx(cluster_fn, ignore_polytomies=True)
    print(label_to_idx)
    # Reverse lookup: node index -> cluster label.
    idx_to_label = dict((idx, label) for label, idx in label_to_idx.items())

    adjacency = mach_util.get_adj_matrix_from_machina_tree(label_to_idx, tree_fn, skip_polytomies=True)
    T = torch.tensor(adjacency, dtype=torch.float32)
    B = vert_util.get_mutation_matrix_tensor(T)
    ref_matrix, var_matrix, unique_sites = mach_util.get_ref_var_matrices_from_machina_sim_data(
        ref_var_fn,
        cluster_label_to_idx=label_to_idx,
        T=T,
    )

    print(unique_sites)
    primary_idx = unique_sites.index('P')
    # one_hot on a length-1 index tensor yields a (1, num_sites) row; the
    # transpose turns it into a (num_sites, 1) column.
    primary_row = torch.nn.functional.one_hot(
        torch.tensor([primary_idx]),
        num_classes=len(unique_sites),
    )
    r = primary_row.T

    vertex_labeling.gumbel_softmax_optimization(
        T, ref_matrix, var_matrix, B,
        ordered_sites=unique_sites,
        p=r,
        node_idx_to_label=idx_to_label,
        w_e=0.01, w_l=3, w_m=10,
        max_iter=100, batch_size=128,
        custom_colors=custom_colors,
        primary="prostate",
    )

import pydot
from IPython.display import Image, display

def visualize_ground_truth(tree_fn, graph_fn):
    """Render two DOT files inline, ``tree_fn`` first and ``graph_fn`` second.

    Each file is loaded with ``pydot.graph_from_dot_file``, converted to PNG
    and shown through IPython's ``display``.

    Args:
        tree_fn: path to the first DOT file to show.
        graph_fn: path to the second DOT file to show.

    Returns:
        None.

    The single-element unpacking below fails unless the loader returns
    exactly one graph for a file.
    """
    for dot_path in (tree_fn, graph_fn):
        (dot_graph,) = pydot.graph_from_dot_file(dot_path)
        display(Image(dot_graph.create_png()))
        
        

In [None]:
# Pick the simulated dataset; these values select the directory and file names below.
num_sites, mig_type, SEED = "m8", "M", "243"

SIM_DATA_DIR = os.path.join(MACHINA_DATA_DIR, "sims", num_sites, mig_type)

# Arguments in order: cluster_fn, tree_fn, ref_var_fn.
find_labeling(*[
    os.path.join(SIM_DATA_DIR, file_name)
    for file_name in (
        f"clustering_observed_seed{SEED}.txt",
        f"T_seed{SEED}.tree",
        f"seed{SEED}_0.95.tsv",
    )
])

print("Ground truth")
# Arguments in order: tree_fn, graph_fn.
visualize_ground_truth(*[
    os.path.join(SIM_DATA_DIR, file_name)
    for file_name in (f"T_seed{SEED}.dot", f"G_seed{SEED}.dot")
])


{'0': 0, '2;3': 1, '8;12;13;19;40': 2, '45': 3, '50': 4, '66;69': 5, '73': 6, '25;26': 7, '30;34;41;44;60;63;74': 8, '48': 9, '36': 10, '38': 11, '78': 12, '86': 13, '87': 14, '67': 15, '33;47;56;58': 16, '68;75': 17, '54;62;64;65': 18, '49;55;72': 19, '24;39;42;52;53;57;81': 20, '70;79': 21, '59': 22, '27;35': 23, '5;7;10;11;17;18;20;22;23': 24, '1;4;6;14;15;16;21;28;31;32': 25, '9': 26}
child_to_parent_map {'2;3': '0', '24;39;42;52;53;57;81': '0', '5;7;10;11;17;18;20;22;23': '0', '1;4;6;14;15;16;21;28;31;32': '0', '9': '0', '8;12;13;19;40': '2;3', '27;35': '2;3', '45': '8;12;13;19;40', '50': '8;12;13;19;40', '54;62;64;65': '45', '49;55;72': '45', '70;79': '45', '59': '45', '66;69': '50', 'M5_1': '50', '68;75': '50', '73': '66;69', '67': '38', '86': '67', '87': '54;62;64;65', '78': 'M5_1', '48': '27;35', '36': '5;7;10;11;17;18;20;22;23', '38': '5;7;10;11;17;18;20;22;23', '25;26': '9', '30;34;41;44;60;63;74': '9', '33;47;56;58': '9'}
['P', 'M1', 'M2', 'M3', 'M4', 'M5', 'M6', 'M7', 'M8']

In [None]:
# Show the ground-truth DOT files again, using SIM_DATA_DIR and SEED as left by an earlier cell.
visualize_ground_truth(*[
    os.path.join(SIM_DATA_DIR, file_name)
    for file_name in (f"T_seed{SEED}.dot", f"G_seed{SEED}.dot")
])


In [None]:
# Same dataset selection as the earlier m8 / M / seed 243 cell, run again.
num_sites, mig_type, SEED = "m8", "M", "243"

SIM_DATA_DIR = os.path.join(MACHINA_DATA_DIR, "sims", num_sites, mig_type)

# Arguments in order: cluster_fn, tree_fn, ref_var_fn.
find_labeling(*[
    os.path.join(SIM_DATA_DIR, file_name)
    for file_name in (
        f"clustering_observed_seed{SEED}.txt",
        f"T_seed{SEED}.tree",
        f"seed{SEED}_0.95.tsv",
    )
])

print("Ground truth")
# Arguments in order: tree_fn, graph_fn.
visualize_ground_truth(*[
    os.path.join(SIM_DATA_DIR, file_name)
    for file_name in (f"T_seed{SEED}.dot", f"G_seed{SEED}.dot")
])

In [None]:
# Pick a different simulated dataset; these values select the directory and file names below.
num_sites, mig_type, SEED = "m5", "mS", "5"

SIM_DATA_DIR = os.path.join(MACHINA_DATA_DIR, "sims", num_sites, mig_type)

# Arguments in order: cluster_fn, tree_fn, ref_var_fn.
find_labeling(*[
    os.path.join(SIM_DATA_DIR, file_name)
    for file_name in (
        f"clustering_observed_seed{SEED}.txt",
        f"T_seed{SEED}.tree",
        f"seed{SEED}_0.95.tsv",
    )
])

print("Ground truth")
# Arguments in order: tree_fn, graph_fn.
visualize_ground_truth(*[
    os.path.join(SIM_DATA_DIR, file_name)
    for file_name in (f"T_seed{SEED}.dot", f"G_seed{SEED}.dot")
])


In [None]:
# Switch to seed 2 (an int here, unlike the string seeds in earlier cells) and show its
# ground-truth DOT files. SIM_DATA_DIR is whatever an earlier cell last set.
SEED = 2
visualize_ground_truth(*[
    os.path.join(SIM_DATA_DIR, file_name)
    for file_name in (f"T_seed{SEED}.dot", f"G_seed{SEED}.dot")
])
