### Test migration history reconstruction on ovarian cancer data

In [1]:
import sys
import os
from metient.util.globals import *
from metient.metient import *
import matplotlib

# Default figure size for matplotlib plots drawn in this notebook.
matplotlib.rcParams['figure.figsize'] = [3, 3]

# Palette forwarded to evaluate()/calibrate() through `custom_colors`.
custom_colors = [
    "#6aa84f",
    "#c27ba0",
    "#be5742e1",  # NOTE(review): eight hex digits (RGBA form), unlike the six-digit entries - confirm intended
    "#6fa8dc",
    "#e69138",
    "#9e9e9e",
]

# Every path hangs off the parent of the current working directory, so the
# notebook expects to be launched from a directory one level below the repo root.
repo_dir = os.path.join(os.getcwd(), "../")
MSK_MET_FN = os.path.join(repo_dir, 'data/msk_met/msk_met_freq_by_cancer_type.csv')

# McPherson ovarian dataset: tree results, clustered SNV tables, and the
# directory handed to metient as its output location.
MCPHERSON_DATA_DIR = os.path.join(repo_dir, 'data', 'mcpherson_ovarian_2016')
TREE_DIR = os.path.join(MCPHERSON_DATA_DIR, 'orchard_trees')
TSV_DIR = os.path.join(MCPHERSON_DATA_DIR, 'pyclone_clustered_tsvs')
OUTPUT_DIR = os.path.join(MCPHERSON_DATA_DIR, "metient_outputs")

# Patient numbers used to build the per-patient file names in later cells.
PATIENT_IDS = [1, 2, 3, 4, 7, 9]

# Shared printing/visualisation settings passed to evaluate() and calibrate().
print_config = PrintConfig(visualize=True, k_best_trees=6)


CUDA GPU: False


In [2]:
from metient.util import data_extraction_util as dutil

def run_evaluate(mut_trees_fn, ref_var_fn, weights, run_name, num_trees=1):
    """Run metient's ``evaluate`` on the leading trees of one results file.

    Prints the number of trees loaded and the sites found in ``ref_var_fn``,
    then calls ``evaluate`` once per selected tree.

    Args:
        mut_trees_fn: Path handed to ``get_adj_matrices_from_pairtree_results``.
        ref_var_fn: Path handed to ``dutil.get_ref_var_omega_matrices`` and
            forwarded to ``evaluate``.
        weights: Forwarded unchanged to ``evaluate``.
        run_name: Prefix of the per-tree run name ``{run_name}_tree{k}``,
            with ``k`` counting from 1.
        num_trees: How many trees, taken from the front of the loaded list,
            to evaluate. The default of 1 reproduces the previously
            hard-coded ``trees[:1]``; pass ``None`` to evaluate every tree.

    Note:
        Reads the notebook-level globals ``print_config``, ``OUTPUT_DIR`` and
        ``custom_colors``, so the configuration cell must have run first.
    """
    # Only the fourth of the six returned items (the site labels) is used here.
    _, _, _, unique_sites, _, _ = dutil.get_ref_var_omega_matrices(ref_var_fn)
    trees = get_adj_matrices_from_pairtree_results(mut_trees_fn)
    print("num trees:", len(trees))
    print(unique_sites)

    # enumerate(start=1) keeps the 1-based tree numbering used in run names.
    for tree_num, adj_matrix in enumerate(trees[:num_trees], start=1):
        print(f"\nTREE {tree_num}")
        print(adj_matrix.shape)

        evaluate(adj_matrix, ref_var_fn, weights, print_config, OUTPUT_DIR, f"{run_name}_tree{tree_num}",
                 O=None, batch_size=4096, bias_weights=True,
                 custom_colors=custom_colors, solve_polytomies=False)


### Run all patients in evaluate mode

In [3]:
# One evaluate run per patient; every patient gets the same weight values.
for patient_id in PATIENT_IDS:
    # Built fresh on each pass so every run starts from its own Weights object.
    weights = Weights(mig=10.0, comig=5.0, seed_site=1.0, gen_dist=0.0, organotrop=0.0)

    # Per-patient inputs: clustered SNV table and the matching tree results.
    ref_var_fn = os.path.join(TSV_DIR, f"patient{patient_id}_clustered_SNVs.tsv")
    mut_trees_fn = os.path.join(TREE_DIR, f"patient{patient_id}.results.npz")

    run_evaluate(
        mut_trees_fn,
        ref_var_fn,
        weights,
        f"patient{patient_id}_evaluate",
    )


num trees: 48
['omentum_site', 'right_ovary_site', 'small_bowel_site']

TREE 1
torch.Size([30, 30])


TypeError: evaluate() got an unexpected keyword argument 'weight_init_primary'

### Run all patients in calibrate mode

In [None]:
# calibrate() takes parallel lists: one tree, one clustered-SNV table and one
# run name per patient, all in PATIENT_IDS order.
mut_trees_fns = [os.path.join(TREE_DIR, f"patient{patient_id}.results.npz") for patient_id in PATIENT_IDS]

# Load each results file on its own and keep its first tree, mirroring how
# run_evaluate calls the helper with a single path and evaluates trees[:1].
# Passing the whole list of paths and indexing each element with [0] did not
# match that usage.
# NOTE(review): assumes the helper accepts one path and returns a list of
# adjacency matrices, as run_evaluate relies on - confirm against the
# installed metient version.
trees = [get_adj_matrices_from_pairtree_results(trees_fn)[0] for trees_fn in mut_trees_fns]

ref_var_fns = [os.path.join(TSV_DIR, f"patient{patient_id}_clustered_SNVs.tsv") for patient_id in PATIENT_IDS]
run_names = [f"{pid}_calibrate" for pid in PATIENT_IDS]

# Guard against a silent misalignment between the three parallel lists.
assert len(trees) == len(ref_var_fns) == len(run_names), "expected one tree, TSV and run name per patient"

calibrate(trees, ref_var_fns, print_config, OUTPUT_DIR, run_names,
          bias_weights=True, custom_colors=custom_colors, solve_polytomies=False)