In [10]:
import fnmatch
import pandas as pd
import os
import torch

from metient.util import eval_util as eutil
from metient.util import plotting_util as putil

from metient.util.globals import *

x = 0
k = float("inf")
LOSS_THRES = 0.0

SITES = ["m5", "m8"]
MIG_TYPES = ["mS", "M", "S", "R"]

def get_overall_pattern(pattern):
    # 3 options: single-source, multi-source, reseeding
    return pattern.split(" ")[1]

def get_ground_truth_pattern(site, mig_type, seed):
    labeling_fn = os.path.join(SIMS_DIR, site, mig_type, f"T_seed{seed}.vertex.labeling")
    tree_fn = os.path.join(SIMS_DIR, site, mig_type, f"T_seed{seed}.tree")
    true_edges, true_mig_edges, true_labeling = eutil.parse_clone_tree(tree_fn, labeling_fn)
    del true_labeling['GL']
    num_nodes = len(true_labeling)
    num_sites = len(set(list(true_labeling.values())))
    node_label_to_idx = {k:i for i,k in enumerate(list(true_labeling.keys()))}
    A = torch.zeros((num_nodes, num_nodes))
    for edge in true_edges:
        if edge[0] == "GL":
            continue
        A[node_label_to_idx[edge[0]], node_label_to_idx[edge[1]]] = 1
    
    site_label_to_idx = {k:i for i,k in enumerate(set(list(true_labeling.values())))}
    V = torch.zeros((num_sites, num_nodes))
    for node_label in true_labeling:
        site_label = true_labeling[node_label]
        V[site_label_to_idx[site_label], node_label_to_idx[node_label]] = 1
    return putil.get_verbose_seeding_pattern(V,A)

SIMS_DIR = "/data/morrisq/divyak/projects/machina/data/sims/"

def get_gt_mig_type_to_top_seeding_pattern(prediction_dir, suffix):
    gt_mig_type_to_top_seeding_pattern = { }

    for mig_type in MIG_TYPES:
        for site in SITES:
            print(site, mig_type)
            site_mig_type_dir = os.path.join(SIMS_DIR, site, mig_type)
            seeds = fnmatch.filter(os.listdir(site_mig_type_dir), 'clustering_observed_seed*.txt')
            seeds = [s.replace(".txt", "").replace("clustering_observed_seed", "") for s in seeds]
            for seed in seeds:
                predicted_site_mig_type_data_dir = os.path.join(prediction_dir, site, mig_type)
                print(os.listdir(predicted_site_mig_type_data_dir))
                tree_info = eutil.get_metient_min_loss_trees(predicted_site_mig_type_data_dir, seed, k, loss_thres=LOSS_THRES, suffix=suffix)
                for i, (loss, met_results_dict, met_tree_num) in enumerate(tree_info):
                    V = torch.tensor(met_results_dict[OUT_LABElING_KEY][met_tree_num])
                    A = torch.tensor(met_results_dict[OUT_ADJ_KEY][met_tree_num])
                    seeding_pattern = putil.get_verbose_seeding_pattern(V,A)
                    gt_pattern = get_ground_truth_pattern(site, mig_type, seed)
                    if gt_pattern not in gt_mig_type_to_top_seeding_pattern:
                        gt_mig_type_to_top_seeding_pattern[gt_pattern] = []
                    gt_mig_type_to_top_seeding_pattern[gt_pattern].append(seeding_pattern) 
    return gt_mig_type_to_top_seeding_pattern



In [3]:
calibrate_prediction_dir = "/data/morrisq/divyak/data/metient_prediction_results/predictions_bs1024_calibrate_wip_parsweight10xgen_01302024"

calibrate_gt_mig_type_to_top_seeding_pattern = get_gt_mig_type_to_top_seeding_pattern(calibrate_prediction_dir, "_calibrate")
calibrate_gt_mig_type_to_top_seeding_pattern

m5 mS
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
m8 mS
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
m5 M
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
m8 M
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
m5 S
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss trees: 1
# min loss tr

{'monoclonal primary single-source seeding': ['monoclonal primary single-source seeding',
  'polyclonal primary single-source seeding',
  'polyclonal primary single-source seeding',
  'polyclonal primary single-source seeding',
  'polyclonal primary single-source seeding',
  'polyclonal primary single-source seeding',
  'polyclonal single-source seeding',
  'polyclonal primary single-source seeding',
  'polyclonal primary single-source seeding',
  'polyclonal primary single-source seeding'],
 'monoclonal single-source seeding': ['monoclonal primary single-source seeding',
  'monoclonal primary single-source seeding',
  'polyclonal primary single-source seeding',
  'polyclonal primary single-source seeding',
  'polyclonal primary single-source seeding',
  'polyclonal primary single-source seeding',
  'polyclonal primary single-source seeding',
  'polyclonal primary single-source seeding',
  'polyclonal single-source seeding',
  'polyclonal primary single-source seeding'],
 'polyclonal m

In [11]:
evaluate_prediction_dir = "/data/morrisq/divyak/data/metient_prediction_results/predictions_bs1024_evaluate_wip_01292024"

evaluate_gt_mig_type_to_top_seeding_pattern = get_gt_mig_type_to_top_seeding_pattern(evaluate_prediction_dir, "_evaluate")
evaluate_gt_mig_type_to_top_seeding_pattern

m5 mS
['tree9_seed8_evaluate.mig_graph.dot', 'tree0_seed10_evaluate.mig_graph.dot', 'tree5_seed8_evaluate.mig_graph.dot', 'G_tree2_seed4_evaluate.predicted.tree', 'T_tree8_seed8_evaluate.predicted.vertex.labeling', 'tree3_seed3_evaluate.mig_graph.dot', 'T_tree3_seed2_evaluate.predicted.vertex.labeling', 'tree0_seed5_evaluate.tree.dot', 'tree1_seed3_evaluate.tree.dot', 'tree1_seed10_evaluate.mig_graph.dot', 'T_tree11_seed8_evaluate.predicted.vertex.labeling', 'G_tree7_seed8_evaluate.predicted.tree', 'tree2_seed4_evaluate.pkl.gz', 'T_tree1_seed7_evaluate.predicted.vertex.labeling', 'tree1_seed0_evaluate.mig_graph.dot', 'tree2_seed3_evaluate.mig_graph.dot', 'G_tree2_seed2_evaluate.predicted.tree', 'G_tree11_seed8_evaluate.predicted.tree', 'tree0_seed2_evaluate.mig_graph.dot', 'T_tree0_seed0_evaluate.predicted.tree', 'tree10_seed8_evaluate.pkl.gz', 'tree4_seed8_evaluate.pkl.gz', 'tree8_seed8_evaluate.tree.dot', 'tree0_seed10_evaluate.pkl.gz', 'T_tree1_seed8_evaluate.predicted.tree', 'T_tre

IndexError: list index out of range

In [41]:
def get_percent_right(dct):
    for gt_pattern in dct:
        pct_right = (dct[gt_pattern].count(gt_pattern))/len(dct[gt_pattern])
        print(gt_pattern, pct_right)
print('evaluate')
get_percent_right(evaluate_gt_mig_type_to_topk_seeding_pattern)

print('calibrate')
get_percent_right(calibrate_gt_mig_type_to_topk_seeding_pattern)

evaluate
monoclonal single-source seeding 0.21739130434782608
polyclonal multi-source seeding 0.8571428571428571
polyclonal single-source seeding 0.8571428571428571
polyclonal reseeding 0.2
calibrate
monoclonal single-source seeding 0.23809523809523808
polyclonal multi-source seeding 0.45
polyclonal single-source seeding 0.88
polyclonal reseeding 0.3333333333333333


In [45]:
print('calibrate')
get_percent_right(calibrate_gt_mig_type_to_topk_seeding_pattern)

calibrate
monoclonal single-source seeding 0.23809523809523808
polyclonal multi-source seeding 0.45
polyclonal single-source seeding 0.88
polyclonal reseeding 0.26666666666666666


### What is the confusion matrix for the top seeding pattern in calibrate vs evaluate?

In [1]:
def confusion_matrix(cm, title, output_name):
    plt.figure(figsize=(5, 4))
    sns.heatmap(cm, annot=True, fmt=".0f", cmap="Blues", cbar=True,)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(title)
    plt.show()

In [None]:
import numpy as np
patterns = [k for k in mig_type_to_top_pattern]
print(patterns)
cm = np.zeros((len(patterns), len(patterns)))
for i,gt_pattern in enumerate(patterns):
    print("ground truth:", gt_pattern)
    pred_df = pd.DataFrame(mig_type_to_topk_seeding_pattern_df[gt_pattern]['Overall Pattern'], columns=["Overall Pattern"])
    #pred_df = pd.DataFrame(mig_type_to_top_pattern[gt_pattern], columns=["Seeding pattern"])
    #pred_df['Overall Pattern'] = pred_df.apply(lambda row: get_overall_pattern(row), axis=1)
    counts_df = pred_df['Overall Pattern'].value_counts()
    print(counts_df)
    for j,pred_pattern in enumerate(patterns):
        cm[i,j] = counts_df[pred_pattern]
confusion_matrix(cm, "Calibrate mode: \nGenetic Distance Classification", f"top_seeding_pattern_by_gd_eval_loss_{PARAMS}")
                 