In [None]:
import sys
import os
from metient.util.globals import *
from metient.metient import *
import matplotlib

# --- Plot appearance ---------------------------------------------------------
# Every figure in this notebook defaults to 3 x 3.
matplotlib.rcParams['figure.figsize'] = [3, 3]

# 17 colours in total: the same eight-colour palette twice, followed by one
# extra "grey" entry.
custom_colors = 2 * ["#6aa84f", "#c27ba0", "#be5742e1", "#6fa8dc", "#e69138", "#9e9e9e", "grey", "black"] + ["grey"]

# --- Input locations ---------------------------------------------------------
# NOTE: DATA_DIR is an absolute, machine-specific path; adjust it when running
# this notebook elsewhere.
DATA_DIR = '/data/morrisq/divyak/data/gundem_neuroblastoma_2023/'
TREE_DIR = os.path.join(DATA_DIR, 'orchard_trees')
TSV_DIR = os.path.join(DATA_DIR, 'clustered_tsvs')

# --- Repository-relative locations -------------------------------------------
# The repository root is taken to be the parent of the current working directory.
repo_dir = os.path.join(os.getcwd(), "../")
MSK_MET_FN = os.path.join(repo_dir, 'data/msk_met/msk_met_freq_by_cancer_type.csv')
OUTPUT_DIR = os.path.join(repo_dir, "data", "gundem_neuroblastoma_2023", "metient_outputs")

# --- Output / visualisation options passed to calibrate() and evaluate() -----
print_config = PrintConfig(k_best_trees=6, visualize=True)

# Derive patient IDs from the clustered-SNV file names in TSV_DIR.
# Only entries that actually end with the expected suffix are kept, so stray
# directory entries (hidden files, checkpoint folders, ...) are not mistaken for
# patients. os.listdir() returns entries in arbitrary order, so the result is
# sorted to make it reproducible between runs.
TSV_SUFFIX = "_clustered_SNVs.tsv"
patient_ids = sorted(
    fn[: -len(TSV_SUFFIX)] for fn in os.listdir(TSV_DIR) if fn.endswith(TSV_SUFFIX)
)
len(patient_ids)

### Run calibrate

In [None]:
# Restrict this run to a single patient (replaces any previously built list).
patient_ids = ["H103207"]

# Per-patient inputs and run labels, all derived from the patient ID.
mut_trees_fns = [os.path.join(TREE_DIR, f"{patient_id}.results.npz") for patient_id in patient_ids]
ref_var_fns = [os.path.join(TSV_DIR, f"{patient_id}_clustered_SNVs.tsv") for patient_id in patient_ids]
run_names = [f"{pid}_calibrate" for pid in patient_ids]

# Keep only the first element of each entry returned by the loader.
trees = [entry[0] for entry in get_adj_matrices_from_pairtree_results(mut_trees_fns)]

calibrate(
    trees,
    ref_var_fns,
    print_config,
    OUTPUT_DIR,
    run_names,
    bias_weights=True,
    custom_colors=custom_colors,
    solve_polytomies=False,
)


### Run evaluate

In [None]:
from metient.util import data_extraction_util as dutil

def run_evaluate(mut_trees_fn, ref_var_fn, weights, run_name):
    """Load trees from a results file and run ``evaluate`` on the first one.

    Args:
        mut_trees_fn: Path handed to ``get_adj_matrices_from_pairtree_results``.
        ref_var_fn: Path forwarded unchanged to ``evaluate``.
        weights: Weights object forwarded unchanged to ``evaluate``.
        run_name: Prefix of the run label; ``_tree<N>`` (1-based) is appended.

    The number of loaded trees, the tree index and the matrix shape are printed.
    Relies on the notebook-level ``print_config``, ``OUTPUT_DIR`` and
    ``custom_colors``.

    NOTE(review): the calibrate cell indexes each loader entry with ``[0]``,
    whereas here entries are used directly as matrices -- confirm the loader's
    return structure.
    """
    trees = get_adj_matrices_from_pairtree_results(mut_trees_fn)
    print("num trees:", len(trees))

    # Only the first loaded tree is evaluated; numbering starts at 1.
    for tree_num, adj_matrix in enumerate(trees[:1], start=1):
        print(f"\nTREE {tree_num}")
        print(adj_matrix.shape)

        evaluate(
            adj_matrix, ref_var_fn, weights, print_config, OUTPUT_DIR, f"{run_name}_tree{tree_num}",
            batch_size=6000,
            O=None,
            bias_weights=True,
            custom_colors=custom_colors,
            solve_polytomies=False,
        )


In [None]:
for patient_id in ["H103207"]:
    # Input files for this patient.
    mut_trees_fn = os.path.join(TREE_DIR, f"{patient_id}.results.npz")
    ref_var_fn = os.path.join(TSV_DIR, f"{patient_id}_clustered_SNVs.tsv")

    # Fixed weights used for this evaluation run.
    weights = Weights(
        mig=100.0,
        comig=5.0,
        seed_site=1.0,
        gen_dist=0.0,
        organotrop=0.0,
    )

    print(patient_id)
    run_evaluate(
        mut_trees_fn=mut_trees_fn,
        ref_var_fn=ref_var_fn,
        weights=weights,
        run_name=f"{patient_id}_evaluate",
    )


In [None]:
def dfs_linear_chains(adj_matrix, node, visited, current_chain=[], linear_chains=[]):
    visited[node] = True
    current_chain.append(node)

    # Check if the current node has exactly one neighbor (child)
    neighbors = [i for i, connected in enumerate(adj_matrix[node]) if connected]
    if len(neighbors) == 1:
        dfs_linear_chains(adj_matrix, neighbors[0], visited, current_chain, linear_chains)
    else:
        # Check if the current chain is a linear chain (more than one node in the chain)
        if len(current_chain) > 1:
            linear_chains.append(current_chain)

        # Continue DFS for each neighbor
        for neighbor in neighbors:
            if not visited[neighbor]:
                dfs_linear_chains(adj_matrix, neighbor, visited, [neighbor], linear_chains)

    return linear_chains

def get_linear_chain_starting_nodes()
num_nodes = len(adjacency_matrix)
visited = [False] * num_nodes

# Perform DFS to detect and get starting nodes of linear chains
result_linear_chains = []
for start_node in range(num_nodes):
    if not visited[start_node]:
        linear_chains_for_start = dfs_linear_chains(adjacency_matrix, start_node, visited)
        result_linear_chains.extend(linear_chains_for_start)

print("Starting Nodes of Linear Chains:", [chain[0] for chain in result_linear_chains])


In [None]:
# Inputs to the sample-size heuristic.
# NOTE(review): `sites` is not defined in the visible cells of this notebook; it
# must already exist in the kernel for this cell to run.
num_nodes = 35
num_sites = len(sites)

# Start from a base of 256 and grow it with the number of nodes and sites.
if num_nodes > 15:
    # Large trees: scale with the node count, plus a per-site term once there
    # are more than three sites.
    min_size = 256 + 1280 * (num_nodes // 2) + (512 * (num_sites) if num_sites > 3 else 0)
elif num_sites > 4:
    # Small trees with many sites: scale with the site count only.
    min_size = 256 + 256 * (num_sites // 2)
else:
    min_size = 256

# cap this to a reasonably high sample size
if min_size > 60000:
    min_size = 60000