### How often do we find more parsimonious solutions?

In [1]:
import os
import pickle
from src.lib.vertex_labeling import *
from src.util import eval_util as eutil
from src.util.globals import *
from tqdm import tqdm
import fnmatch
import seaborn as sns

# Treat the parent of the current working directory as the repository root
# and make it the working directory for the rest of the notebook.
# NOTE(review): the path is derived from os.getcwd(), so running this cell a
# second time moves up one more level -- confirm it is only executed once.
REPO_DIR = os.path.join(os.getcwd(), "../")
os.chdir(REPO_DIR)
# DATE is embedded in the prediction directory names and the plot file names;
# PARAMS is embedded in the prediction directory names.
DATE = "10092023"
PARAMS = "m10x8_s1x10_delt0.8_gd1.0"

# Root of the simulated inputs read below (ground-truth trees and labelings,
# observed clusterings, mutation-tree files).
SIMS_DIR = os.path.join(REPO_DIR, "data/machina_sims")

def get_num_mut_trees(mut_tree_fn):
    """Return the number of mutation trees declared in a mutation-tree file.

    The file is scanned line by line for the first line containing the marker
    ``#trees`` (for example ``3 #trees``); the first whitespace-separated
    token of that line is parsed as the count.

    Args:
        mut_tree_fn: Path to the mutation-tree text file.

    Returns:
        The declared number of trees as an ``int``.

    Raises:
        ValueError: If no line contains ``#trees``, or if the first token of
            the matching line is not an integer.
    """
    with open(mut_tree_fn, 'r') as f:
        for line in f:
            if "#trees" in line:
                return int(line.split()[0])
    # Fail loudly here instead of returning None, which would only surface
    # later as an opaque TypeError in range(num_trees).
    raise ValueError(f"No '#trees' line found in {mut_tree_fn}")
            
def get_pckl_info(bs, run, site, mig_type, seed, tree_num):
    """Load one saved prediction pickle and return its labelings, adjacency and losses.

    The pickle is read from
    ``REPO_DIR/test/machina_simulated_data/batch_experiments_{DATE}/
    predictions_bs{bs}_{PARAMS}_r{run}_{DATE}/{site}/{mig_type}/
    tree{tree_num}_seed{seed}.pickle``.

    Args:
        bs: Batch size tag used in the prediction directory name.
        run: Run number used in the prediction directory name.
        site: Site sub-directory name (e.g. ``"m5"``).
        mig_type: Migration-type sub-directory name (e.g. ``"mS"``).
        seed: Seed identifier used in the pickle file name.
        tree_num: Tree index used in the pickle file name.

    Returns:
        A tuple ``(Vs, A, losses)`` holding the values stored under
        ``OUT_LABElING_KEY``, ``OUT_ADJ_KEY`` and ``OUT_LOSSES_KEY``.

    Raises:
        FileNotFoundError: If the pickle file does not exist.
        KeyError: If one of the expected keys is missing from the pickle.
    """
    prediction_dir = os.path.join(
        REPO_DIR,
        'test',
        'machina_simulated_data',
        f'batch_experiments_{DATE}',
        f'predictions_bs{bs}_{PARAMS}_r{run}_{DATE}',
    )
    pickle_fn = os.path.join(prediction_dir, site, mig_type, f"tree{tree_num}_seed{seed}.pickle")
    # NOTE: unpickling can execute arbitrary code; only load files that were
    # produced by our own experiment runs.
    # The context manager closes the handle; this function is called many
    # times from nested loops, so leaked handles would pile up otherwise.
    with open(pickle_fn, "rb") as metient_pickle:
        pckl = pickle.load(metient_pickle)
    return pckl[OUT_LABElING_KEY], pckl[OUT_ADJ_KEY], pckl[OUT_LOSSES_KEY]


In [2]:
import torch
def get_V_A_from_ground_truth_tree(site, mig_type, seed):
    """Build the labeling matrix V and adjacency matrix A of a simulated ground-truth tree.

    Reads ``T_seed{seed}.tree`` and ``T_seed{seed}.vertex.labeling`` from
    ``SIMS_DIR/{site}/{mig_type}`` via ``eutil.parse_clone_tree``. The ``'GL'``
    node is dropped from the labeling and edges whose parent is ``'GL'`` are
    skipped.

    Node indices follow the iteration order of the parsed labeling. Site
    indices follow the sorted order of the distinct site labels, so the row
    order of V is reproducible across interpreter runs (iterating a ``set`` of
    strings is hash-seed dependent).

    Args:
        site: Site sub-directory name under ``SIMS_DIR``.
        mig_type: Migration-type sub-directory name.
        seed: Seed identifier used in the file names.

    Returns:
        A tuple ``(V, A)`` of torch tensors. ``V`` has shape
        ``(num_sites, num_nodes)`` with ``V[s, n] = 1`` when node ``n`` is
        labeled with site ``s``. ``A`` has shape ``(num_nodes, num_nodes)``
        with ``A[p, c] = 1`` for every parsed edge ``p -> c``.

    Raises:
        KeyError: If the parsed labeling has no ``'GL'`` entry, or an edge
            references a node that is absent from the labeling.
    """
    data_dir = os.path.join(SIMS_DIR, site, mig_type)
    labeling_fn = os.path.join(data_dir, f"T_seed{seed}.vertex.labeling")
    tree_fn = os.path.join(data_dir, f"T_seed{seed}.tree")
    true_edges, _, true_labeling = eutil.parse_clone_tree(tree_fn, labeling_fn)
    del true_labeling['GL']

    node_label_to_idx = {label: i for i, label in enumerate(true_labeling)}
    # Sort once so the site -> row mapping is deterministic.
    site_labels = sorted(set(true_labeling.values()))
    site_label_to_idx = {label: i for i, label in enumerate(site_labels)}
    num_nodes = len(node_label_to_idx)
    num_sites = len(site_label_to_idx)

    A = torch.zeros((num_nodes, num_nodes))
    for edge in true_edges:
        parent, child = edge[0], edge[1]
        if parent == "GL":
            # 'GL' was removed from the labeling, so it has no row in A.
            continue
        A[node_label_to_idx[parent], node_label_to_idx[child]] = 1

    V = torch.zeros((num_sites, num_nodes))
    for node_label, site_label in true_labeling.items():
        V[site_label_to_idx[site_label], node_label_to_idx[node_label]] = 1
    return V, A


In [3]:

SITES = ["m5", "m8"]
MIG_TYPES = ["mS", "M", "S", "R"]
BATCH_SIZES = ['64', '256', '1024', '4096', '8192']

num_runs = 10
num_less = 0  # not used in this cell; kept in case other cells read it

# site -> batch size -> list with one counter per run. Each counter is the
# number of (seed, tree) predictions for which at least one predicted labeling
# has a smaller first metric than the ground-truth tree.
site_to_bs_to_num_more_pars = {
    site_name: {batch_size: [0] * num_runs for batch_size in BATCH_SIZES}
    for site_name in SITES
}
for bs in BATCH_SIZES:
    print(bs)
    for site in SITES:
        for mig_type in MIG_TYPES:
            print(site, mig_type)
            # Get all seeds for mig_type + site combo
            site_mig_type_dir = os.path.join(SIMS_DIR, site, mig_type)
            seeds = fnmatch.filter(os.listdir(site_mig_type_dir), 'clustering_observed_seed*.txt')
            seeds = [s.replace(".txt", "").replace("clustering_observed_seed", "") for s in seeds]

            for seed in seeds:
                # Get all the clone trees for this seed
                num_trees = get_num_mut_trees(os.path.join(SIMS_DIR, f"{site}_mut_trees", f"mut_trees_{mig_type}_seed{seed}.txt"))

                # Metrics of the ground-truth tree; only the first one is
                # compared against the predictions below.
                true_V, true_A = get_V_A_from_ground_truth_tree(site, mig_type, seed)
                true_m, true_c, true_s, _, _ = get_ancestral_labeling_metrics(true_V.reshape(1, true_V.shape[0], true_V.shape[1]), true_A, None, None, None)
                for run in range(1, num_runs + 1):
                    for tree_num in range(num_trees):
                        Vs, A, losses = get_pckl_info(bs, run, site, mig_type, seed, tree_num)
                        for V in Vs:
                            pred_m, pred_c, pred_s, _, _ = get_ancestral_labeling_metrics(V.reshape(1, V.shape[0], V.shape[1]), A, None, None, None)
                            if pred_m < true_m:
                                site_to_bs_to_num_more_pars[site][bs][run - 1] += 1
                                # Count each tree at most once. The previous
                                # `continue` here was a no-op, so every
                                # qualifying V was counted.
                                break
site_to_bs_to_num_more_pars


64
m5 mS
m5 M
m5 S
m5 R
m8 mS
m8 M
m8 S
m8 R
256
m5 mS
m5 M
m5 S
m5 R
m8 mS
m8 M
m8 S



KeyboardInterrupt



In [None]:
# Keep an independent snapshot of the counts. A plain assignment would only
# alias the same nested dicts/lists, so in-place changes to the original
# would also change the "copy".
import copy
site_to_bs_to_num_more_pars_cop = copy.deepcopy(site_to_bs_to_num_more_pars)

In [None]:
# Average the per-run counters: site -> batch size -> mean count over runs.
# Every batch size starts at 0 for both sites and is then overwritten with the
# mean of whatever per-run list was recorded for it.
site_to_bs_to_avg_more_pars = {
    'm5': dict.fromkeys(BATCH_SIZES, 0),
    'm8': dict.fromkeys(BATCH_SIZES, 0),
}
for site, bs_to_run_counts in site_to_bs_to_num_more_pars.items():
    for bs, run_counts in bs_to_run_counts.items():
        site_to_bs_to_avg_more_pars[site][bs] = sum(run_counts) / len(run_counts)
site_to_bs_to_avg_more_pars
        

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Column names of the per-site data frame; they double as the axis labels.
X_LABEL = "Metient Sample Size"
Y_LABEL = "Number of times a more \nparsimonious solution is found"

# One "crest" palette color per batch size.
colors = sns.color_palette("crest")
colors = colors[:len(BATCH_SIZES)]
sns.set(font_scale=1.0)
# NOTE(review): the second set_style call presumably supersedes the first --
# confirm and drop one of them.
sns.set_style("whitegrid")
sns.set_style("ticks")
sns.set_palette(sns.color_palette(colors))

# Make sure the output directory exists so savefig cannot fail on a fresh checkout.
output_plots_dir = os.path.join(REPO_DIR, "test/output_plots/")
os.makedirs(output_plots_dir, exist_ok=True)

# One bar plot per site: average count (over runs) for each batch size.
for site in SITES:
    print(site_to_bs_to_avg_more_pars[site])
    df = pd.DataFrame(list(site_to_bs_to_avg_more_pars[site].items()), columns=[X_LABEL, Y_LABEL])
    print(df)

    fig = plt.figure(figsize=(5, 4), dpi=300)
    # Pass the frame by keyword: older seaborn releases treat the first
    # positional argument of barplot as `x`, not `data`.
    ax = sns.barplot(data=df, x=X_LABEL, y=Y_LABEL)
    # Remove the top and right spines once the axes exist; calling despine()
    # before plotting has no axes to act on.
    sns.despine(ax=ax)
    fig.savefig(os.path.join(output_plots_dir, f"{site}_more_pars_{DATE}.png"), dpi=600, bbox_inches='tight', pad_inches=0.5)


In [None]:
# Display the (batch size, average count) pairs for whichever site the
# module-level loop variable `site` was last bound to by an earlier cell.
site_to_bs_to_avg_more_pars[site].items()
