In [2]:
import gzip
import pickle
import glob
import pandas as pd
import numpy as np
import os
from metient.util import eval_util as eutil
from metient.util.globals import *
import seaborn as sns
import matplotlib.pyplot as plt
from statannot import add_stat_annotation
import itertools

# Repository root, expressed relative to the directory the notebook runs from.
REPO_DIR = os.path.join(os.getcwd(), "../")

# Parsimony metrics; the fitted thetas are indexed in this order further down.
METRICS = ["Migration", "Comigration", "Seeding site"]

# Cohort names and their calibration output directories. The two lists are
# parallel and are zipped together below. The breast cohort directory
# (data/hoadley_breast_cancer_2016/metient_outputs/calibrate) is not part of
# this list.
DATASET_NAMES = [ "HGSOC", "Melanoma", "HR-NB", "NSCLC"]
CALIBRATE_DIRS = [
    os.path.join(REPO_DIR,"data/mcpherson_ovarian_2016/metient_outputs/","calibrate"),
    os.path.join(REPO_DIR,"data/sanborn_melanoma_2015/metient_outputs/", "calibrate"),
    os.path.join(REPO_DIR,"data/gundem_neuroblastoma_2023/metient_outputs","calibrate"),
    os.path.join(REPO_DIR,"data/tracerx_nsclc/metient_outputs/pyclone_clustered_conipher_trees_03282024/calibrate",),
]

num_bootstrap_samples = 50

# Per cohort, collect one (pickle path, number of sites) tuple per calibration
# output. No filter on the number of sites is applied (a `num_sites > 2`
# filter used to sit in front of the append and is disabled).
num_mets = []
dataset_to_pickle_files = {cohort: [] for cohort in DATASET_NAMES}
for dataset_name, calibrate_dir in zip(DATASET_NAMES, CALIBRATE_DIRS):
    matching_files = glob.glob(f'{calibrate_dir}/*pkl.gz')
    for pkl_path in matching_files:
        with gzip.open(pkl_path, 'rb') as handle:
            pkl = pickle.load(handle)
        num_sites = len(pkl[OUT_SITES_KEY])
        # Not used afterwards; the lookup still fails loudly if the key is absent.
        loss_dicts = pkl[OUT_LOSS_DICT_KEY]
        dataset_to_pickle_files[dataset_name].append((pkl_path, num_sites))
        num_mets.append(num_sites)

def convert_thetas_to_alt_metrics(thetas):
    """Re-express three thetas as (combined weight, ratio, third theta).

    Args:
        thetas: Indexable with at least three numeric entries.

    Returns:
        A tuple ``(wtot, delta, thetas[2])`` where ``wtot`` is
        ``thetas[0] + thetas[1]`` and ``delta`` is ``thetas[1] / wtot``.

    Note:
        There is no guard for ``wtot == 0``.
    """
    first, second = thetas[0], thetas[1]
    combined_weight = first + second
    return combined_weight, second / combined_weight, thetas[2]
    
# Number of calibration pickles found per cohort (shown as the cell output).
sizes = {cohort: len(entries) for cohort, entries in dataset_to_pickle_files.items()}
sizes


KeyboardInterrupt: 

### How consistent are the thetas when calibrated on random samples within the same cancer type cohort?

In [None]:
# Fit thetas on bootstrap resamples drawn within each cohort.
#
# The resampling uses a seeded generator and a sorted file list so that the
# same patients are drawn after a kernel restart (glob order depends on the
# filesystem, and the global NumPy RNG state depends on execution history).
# NOTE(review): eutil.get_max_cross_ent_thetas may itself be stochastic; the
# seed below only controls which files are resampled.
BOOTSTRAP_SEED = 0
rng = np.random.default_rng(BOOTSTRAP_SEED)

data = []
for i in range(num_bootstrap_samples):
    print(f"\n**** RUN {i+1} ***")
    for dataset, entries in dataset_to_pickle_files.items():
        matching_files = sorted(entry[0] for entry in entries)
        if not matching_files:
            # Cohort without any calibration output: nothing to resample.
            continue
        # Sample with replacement, keeping the cohort size.
        bootstrap_sample = list(rng.choice(matching_files, size=len(matching_files), replace=True))
        thetas = eutil.get_max_cross_ent_thetas(pickle_file_list=bootstrap_sample)
        # One row per parsimony metric; thetas are indexed in METRICS order.
        for midx, metric in enumerate(METRICS):
            data.append([dataset, metric, thetas[midx]])

thetas_split_on_same_cohort = pd.DataFrame(data, columns=["dataset", "Parsimony metric", "Fit theta"])
thetas_split_on_same_cohort.to_csv('thetas_split_on_same_cohort.csv') 
thetas_split_on_same_cohort

### How consistent are thetas when calibrating on random samples amongst all cancer type cohorts?

In [None]:
# Fit thetas on bootstrap resamples drawn from all cohorts pooled together.
dataset_sizes = [len(entries) for entries in dataset_to_pickle_files.values()]
all_matching_files = [entry for entries in dataset_to_pickle_files.values() for entry in entries]
print(dataset_sizes, len(all_matching_files))

# Seeded generator + sorted paths make the resamples reproducible across
# kernel restarts (glob order depends on the filesystem).
# NOTE(review): eutil.get_max_cross_ent_thetas may itself be stochastic; the
# seed below only controls which files are resampled.
BOOTSTRAP_SEED = 0
rng = np.random.default_rng(BOOTSTRAP_SEED)

# The pooled path list does not change between runs, so build it once.
matching_files = sorted(entry[0] for entry in all_matching_files)

data = []
for i in range(num_bootstrap_samples):
    print(f"\n**** RUN {i+1} ***")
    if not matching_files:
        continue
    # Sample with replacement, keeping the pooled size.
    bootstrap_sample = list(rng.choice(matching_files, size=len(matching_files), replace=True))
    thetas = eutil.get_max_cross_ent_thetas(pickle_file_list=bootstrap_sample)
    # One row per parsimony metric; thetas are indexed in METRICS order.
    for midx, metric in enumerate(METRICS):
        data.append([metric, thetas[midx]])

thetas_split_on_random_cohort = pd.DataFrame(data, columns=["Parsimony metric", "Fit theta"])
thetas_split_on_random_cohort.to_csv('thetas_split_on_random_cohort.csv') 

thetas_split_on_random_cohort

In [None]:
# Reload the bootstrap results from the CSVs written by the cells above.
# `to_csv` stored the row index as an unnamed first column, which is dropped
# again here.
saved_index_columns = ['Unnamed: 0']
thetas_split_on_random_cohort = (
    pd.read_csv('thetas_split_on_random_cohort.csv', index_col=False)
    .drop(columns=saved_index_columns)
)
thetas_split_on_same_cohort = (
    pd.read_csv('thetas_split_on_same_cohort.csv', index_col=False)
    .drop(columns=saved_index_columns)
)

# "flare" palette without its first two colours.
flare_palette = sns.color_palette("flare")
colors = flare_palette[2:]

def point_plot(df, color, name, with_stats):
    """Draw a point plot of fitted thetas per parsimony metric and save it.

    Points are the per-metric estimate of 'Fit theta' with standard-deviation
    error bars (``errorbar='sd'``).

    Args:
        df: DataFrame with 'Parsimony metric' and 'Fit theta' columns.
        color: Currently unused -- the points are coloured with the 'viridis'
            palette. Kept in the signature so existing callers keep working.
        name: Plot title; with spaces replaced by underscores it is also the
            prefix of the output file name.
        with_stats: If True, annotate every pair of METRICS with a Welch
            t-test (star notation) and use the y-range (0, 0.9) instead of
            (0, 0.7).

    Side effects:
        Writes ``output_plots/<name>_theta_distribution.png`` (creating the
        directory if needed), then shows and closes the figure.
    """
    sns.set(style='ticks', rc={'axes.labelsize': 6, 'xtick.labelsize': 6, 'ytick.labelsize': 6, })

    plt.figure(figsize=(2, 2.5), dpi=200)
    snsfig = sns.pointplot(data=df, x='Parsimony metric', y='Fit theta', markersize=0.5, errwidth=1.7,capsize=0.1, 
                           dodge=True, markers='.', errorbar='sd', palette='viridis')

    if with_stats:
        add_stat_annotation(snsfig, data=df, x='Parsimony metric', y="Fit theta",
                            box_pairs=list(itertools.combinations(METRICS, 2)),
                            line_offset=0.2,test='t-test_welch', text_format='star', loc='outside', line_offset_to_box=0.1, 
                            text_offset=0.00001, verbose=2,  fontsize=7, comparisons_correction=None,linewidth=1.0,
                            )
        plt.ylim(0, 0.9) 
    else:
        plt.ylim(0, 0.7)
    plt.ylabel("Penalty")
    snsfig.spines['top'].set_visible(False)
    snsfig.spines['right'].set_visible(False)
    saved_name = "_".join(name.split(" "))
    plt.setp(snsfig.collections, alpha=0.8)
    plt.tight_layout(pad=1.8) 
    plt.xticks(rotation=45)  
    # savefig does not create missing directories.
    os.makedirs("output_plots", exist_ok=True)
    plt.savefig(f"output_plots/{saved_name}_theta_distribution.png", dpi=500)
    # The title is set after saving, so it only appears in the inline figure.
    plt.title(name, fontsize=8)
    plt.show()
    plt.close()

with_stats = True

# One figure per cohort, in the order the cohorts were loaded.
for cohort_idx, cohort in enumerate(dataset_to_pickle_files):
    cohort_rows = thetas_split_on_same_cohort.loc[thetas_split_on_same_cohort['dataset'] == cohort]
    point_plot(cohort_rows, colors[cohort_idx], cohort, with_stats)

# Thetas fitted on bootstrap samples pooled over all cohorts.
point_plot(thetas_split_on_random_cohort, 'lightseagreen', 'random', with_stats)


### Make one plot per metric with datasets side by side 

In [None]:
from matplotlib.colors import to_rgb
from statannot import add_stat_annotation
# Default colours and x-axis order for the cohort plots; the two lists are
# matched by position.
dataset_colors = ["#b84988","#06879e","#5a9e09","#d4892a", "#939598"]

dataset_order = ["Melanoma", "HGSOC", "HR-NB", "NSCLC", "Combined"]
def plot_all_metrics(df, metric, with_stats, ylim=None, shape=(1.5,1.5), colors=None, name=None):
    """Plot one parsimony metric with the groups in ``df['dataset']`` side by side.

    Points are the per-group estimate of 'Fit theta' with standard-error bars
    (``errorbar='se'``).

    Args:
        df: DataFrame with 'dataset', 'Parsimony metric' and 'Fit theta' columns.
        metric: Value of 'Parsimony metric' to plot. It is also used for the
            y-label ("<metric> penalty", or the metric itself when its name
            contains "index") and for the output file name.
        with_stats: If True, annotate every pair of groups except "Combined"
            with a Welch t-test (star notation) and make the figure 0.3 in taller.
        ylim: Optional (low, high) y-range. When None the automatic limits
            are kept.
        shape: Figure size (width, height) in inches.
        colors: Optional list of colours, one per plotted group. Defaults to
            ``dataset_colors``.
        name: Optional prefix for the output file name, so that plots of
            different groupings do not overwrite each other.

    Side effects:
        Writes ``output_plots/[<name>_]<metric>_theta_distribution_stat.png``
        (creating the directory if needed), then shows and closes the figure.
    """
    sns.set(style='ticks', rc={'axes.labelsize': 6, 'xtick.labelsize': 6, 'ytick.labelsize': 6,  })
    subset = df[df['Parsimony metric']==metric]
    if with_stats:
        shape = (shape[0], shape[1]+0.3)
    plt.figure(figsize=(shape[0], shape[1]), dpi=200)

    groups = list(subset['dataset'].unique())

    # Keep the fixed cohort order whenever every group is a known cohort.
    # Groups outside `dataset_order` (e.g. NSCLC subtypes) would otherwise
    # never be drawn, so in that case plot the groups that are present.
    unknown_groups = [g for g in groups if g not in dataset_order]
    if unknown_groups:
        order = [d for d in dataset_order if d in groups] + unknown_groups
    else:
        order = dataset_order
    palette = dataset_colors if colors is None else list(colors)

    snsfig = sns.pointplot(data=subset, x='dataset', y='Fit theta', s=5, errwidth=1.7, order=order,
                           dodge=True, markers='.', errorbar='se', palette=palette, capsize=0.2)

    if with_stats:
        # "Combined" is left out of the pairwise comparisons.
        stat_groups = [g for g in groups if g != "Combined"]

        line_offset = 0.17 if metric == 'Comigration' else 0.07
        add_stat_annotation(snsfig, data=subset, x='dataset', y="Fit theta",
                            box_pairs=list(itertools.combinations(stat_groups, 2)),order=order, line_offset=line_offset,
                            test='t-test_welch', text_format='star', loc='outside', line_offset_to_box=0.1, 
                            text_offset=0.00001, verbose=2,  fontsize=7, comparisons_correction=None,linewidth=1.0,
                            )
    if ylim is not None:
        plt.ylim(ylim[0], ylim[1]) 
            
    snsfig.set_xticklabels(snsfig.get_xticklabels(), rotation=45, horizontalalignment='right')
    
    plt.xlabel("Dataset", fontsize=8)
    ylabel = f"{metric} penalty" if "index" not in metric else metric
    plt.ylabel(ylabel, fontsize=8)
    snsfig.spines['top'].set_visible(False)
    snsfig.spines['right'].set_visible(False)
    saved_name = "_".join(metric.split(" "))
    if name:
        saved_name = f"{name}_{saved_name}"
    plt.setp(snsfig.collections, alpha=0.7)
    # savefig does not create missing directories.
    os.makedirs("output_plots", exist_ok=True)
    plt.savefig(f"output_plots/{saved_name}_theta_distribution_stat.png", dpi=500,  bbox_inches='tight', pad_inches=0.05)
    plt.show()
    plt.close()
    
# Tag the pooled-cohort fits so they can be plotted next to the per-cohort fits.
thetas_split_on_random_cohort['dataset'] = "Combined"
print(thetas_split_on_random_cohort)
print(thetas_split_on_same_cohort)
combined_df = (
    pd.concat([thetas_split_on_same_cohort, thetas_split_on_random_cohort], axis=0)
    .loc[lambda frame: frame['dataset'] != 'Breast Cancer']
    .reset_index(drop=True)
)

# This is to plot the alternative metrics (Effective migration #, polyclonality index)
import math
group_size = 3      # rows per fit: one per parsimony metric
tolerance = 1e-4    # slack allowed when checking that a fit's thetas sum to 1

# Rows are expected in consecutive triples (one bootstrap fit each), so
# integer-dividing the fresh 0..n-1 index by `group_size` labels each fit.
rows_to_add = []
fit_ids = combined_df.index // group_size
for _, group in combined_df.groupby(fit_ids):
    assert math.isclose(group['Fit theta'].sum(), 1.0, rel_tol=tolerance, abs_tol=tolerance)
    dataset = group['dataset'].unique().item()

    metric_col = group['Parsimony metric']
    mig = float(group.loc[metric_col == "Migration", "Fit theta"].item())
    comig = float(group.loc[metric_col == "Comigration", "Fit theta"].item())
    ss = float(group.loc[metric_col == "Seeding site", "Fit theta"].item())
    rows_to_add.extend([
        [dataset, "Migration # index", 1-(mig)],
        [dataset, "Polyclonality index", 1-(comig/(mig+comig))],
        [dataset, "Comigration # index", 1-(comig)],
        [dataset, "Seeding site index", 1-ss],
    ])

derived_rows = pd.DataFrame(rows_to_add, columns=combined_df.columns)
combined_df = pd.concat([combined_df, derived_rows], axis=0)

with_stats = True
# (metric, y-range) for each figure.
metric_ylims = (
    ('Migration', (0.3,0.8)),
    ('Comigration', (0.1,0.4)),
    ("Seeding site", (0.1,0.4)),
    ('Polyclonality index', (0.5,0.9)),
)
for metric_name, y_range in metric_ylims:
    plot_all_metrics(combined_df, metric_name, with_stats, y_range)



### How consistent are thetas across NSCLC subtypes?

In [None]:
# Split NSCLC lung cancer patients into their subtype (LUAD and LUSC)
import pyreadr
import re

# NOTE(review): absolute, machine-specific path -- the cell only runs where
# this file is mounted.
tracerx_patient_info = pyreadr.read_r(os.path.join('/data/morrisq/divyak/data/tracerx_nsclc_2023/20221109_TRACERx421_all_patient_df.rds'))[None]
tracerx_patient_info['histology_multi_full_genomically.confirmed'].value_counts()

nsclc_subtype_to_pickle_files = {"LUSC":[], "LUAD":[]}
# The patient id is the "CRUK..." token in the pickle path, up to the next underscore.
pattern = re.compile(r'CRUK[^_]+')
nsclc_entries = dataset_to_pickle_files['NSCLC']
tracerx_pids = [pattern.search(entry[0]).group() for entry in nsclc_entries]
print(len(tracerx_pids))
for pid, (pkl_path, n_sites) in zip(tracerx_pids, nsclc_entries):
    is_patient = tracerx_patient_info['cruk_id'] == pid
    # .item() requires exactly one matching patient row.
    subtype = tracerx_patient_info.loc[is_patient, 'histology_multi_full_genomically.confirmed'].item()
    # Any histology label containing "LUAD" is counted as LUAD.
    if "LUAD" in subtype:
        subtype = "LUAD"
    if subtype == "Other":
        print("Not LUAD or LUSC subtype", pid)
        continue
    nsclc_subtype_to_pickle_files[subtype].append((pkl_path, n_sites))
num_lusc = len(nsclc_subtype_to_pickle_files["LUSC"])
num_luad = len(nsclc_subtype_to_pickle_files["LUAD"])
print(f"LUSC: {num_lusc} patients, LUAD: {num_luad} patients")
    

In [None]:
# Fit thetas to bootstrap samples within the same subtype
data = []
num_bootstrap_samples = 50

# Seeded generator + sorted paths make the resamples reproducible across
# kernel restarts (the file order ultimately comes from glob, which is
# filesystem-dependent).
# NOTE(review): eutil.get_max_cross_ent_thetas may itself be stochastic; the
# seed below only controls which files are resampled.
BOOTSTRAP_SEED = 0
rng = np.random.default_rng(BOOTSTRAP_SEED)

# Bootstrap sampling
for i in range(num_bootstrap_samples):
    print(f"\n**** RUN {i+1} ***")
    for subtype, entries in nsclc_subtype_to_pickle_files.items():
        matching_files = sorted(entry[0] for entry in entries)
        if not matching_files:
            # Subtype without any patients: nothing to resample.
            continue
        # Sample with replacement, keeping the subtype cohort size.
        bootstrap_sample = list(rng.choice(matching_files, size=len(matching_files), replace=True))
        thetas = eutil.get_max_cross_ent_thetas(pickle_file_list=bootstrap_sample)
        # One row per parsimony metric; thetas are indexed in METRICS order.
        for midx, metric in enumerate(METRICS):
            data.append([subtype, metric, thetas[midx]])

# NOTE: this rebinds `thetas_split_on_same_cohort`, which earlier cells use
# for the per-cancer-type fits.
thetas_split_on_same_cohort = pd.DataFrame(data, columns=["dataset", "Parsimony metric", "Fit theta"])
thetas_split_on_same_cohort

In [None]:
# Derive the "1 - theta" indices for every bootstrap fit. Rows come in
# consecutive triples (one per parsimony metric), so integer-dividing the
# index by `group_size` labels each fit.
# Uses `math`, `group_size` and `tolerance` from an earlier cell.
rows_to_add = []
fit_ids = thetas_split_on_same_cohort.index // group_size
for _, group in thetas_split_on_same_cohort.groupby(fit_ids):
    assert math.isclose(group['Fit theta'].sum(), 1.0, rel_tol=tolerance, abs_tol=tolerance)
    dataset = group['dataset'].unique().item()

    metric_col = group['Parsimony metric']
    mig = float(group.loc[metric_col == "Migration", "Fit theta"].item())
    comig = float(group.loc[metric_col == "Comigration", "Fit theta"].item())
    ss = float(group.loc[metric_col == "Seeding site", "Fit theta"].item())
    rows_to_add += [
        [dataset, "Migration # index", 1-(mig)],
        [dataset, "Comigration # index", 1-(comig)],
        [dataset, "Seeding site index", 1-ss],
    ]

derived_rows = pd.DataFrame(rows_to_add, columns=thetas_split_on_same_cohort.columns)
full_df = pd.concat([thetas_split_on_same_cohort, derived_rows], axis=0)

# NOTE(review): these calls pass `name=` and omit `ylim`; confirm that the
# plot_all_metrics signature in scope accepts both.
for index_name in ("Migration # index", "Comigration # index", "Seeding site index"):
    plot_all_metrics(full_df, index_name, False, shape=(2,2.3), name="nsclc_subtype", colors=['tab:blue', 'tab:orange'])


### Histogram of the number of metastases each patient has (in all cancer cohorts)

In [None]:
# Reload all patient data (this time don't exclude patients with <2 mets, and include breast)
#
# list.insert() mutates the list and returns None, so its return value must
# not be assigned back to the name. New lists are built instead, and the
# membership check keeps the cell idempotent when it is re-run.
BREAST_CALIBRATE_DIR = os.path.join(REPO_DIR,"data/hoadley_breast_cancer_2016/metient_outputs/calibrate")
if "Breast" not in DATASET_NAMES:
    DATASET_NAMES = ["Breast"] + DATASET_NAMES
    CALIBRATE_DIRS = [BREAST_CALIBRATE_DIR] + CALIBRATE_DIRS

num_mets = []
# Per cohort, one (pickle path, number of sites) tuple per calibration output.
dataset_to_all_pickle_files = {dataset:[] for dataset in DATASET_NAMES}
for dataset_name, calibrate_dir in zip(DATASET_NAMES, CALIBRATE_DIRS):
    matching_files = glob.glob(f'{calibrate_dir}/*pkl.gz')
    for fn in matching_files:
        with gzip.open(fn, 'rb') as f:
            pkl = pickle.load(f)
            num_sites = len(pkl[OUT_SITES_KEY])
            # Not used afterwards; the lookup still fails loudly if the key is absent.
            loss_dicts = pkl[OUT_LOSS_DICT_KEY]
            dataset_to_all_pickle_files[dataset_name].append((fn, num_sites))
            
# Do the same with nsclc patients for each nsclc subtype
# (uses `re` and `tracerx_patient_info` from the NSCLC subtype cell above).
nsclc_subtype_to_all_pickle_files = {"LUSC":[], "LUAD":[]}
# The patient id is the "CRUK..." token in the pickle path, up to the next underscore.
pattern = re.compile(r'CRUK[^_]+')
all_nsclc_entries = dataset_to_all_pickle_files['NSCLC']
tracerx_pids = [pattern.search(entry[0]).group() for entry in all_nsclc_entries]
print(len(tracerx_pids))
for pid, (pkl_path, n_sites) in zip(tracerx_pids, all_nsclc_entries):
    is_patient = tracerx_patient_info['cruk_id'] == pid
    # .item() requires exactly one matching patient row.
    subtype = tracerx_patient_info.loc[is_patient, 'histology_multi_full_genomically.confirmed'].item()
    # Any histology label containing "LUAD" is counted as LUAD.
    if "LUAD" in subtype:
        subtype = "LUAD"
    if subtype == "Other":
        print("Not LUAD or LUSC subtype", pid)
        continue
    nsclc_subtype_to_all_pickle_files[subtype].append((pkl_path, n_sites))

In [None]:
def plot_hist(df, hue, bin_edges, colors, shape=(3.2,2)):
    """Show a dodged count histogram of 'Number of metastases', split by `hue`.

    Args:
        df: DataFrame with a 'Number of metastases' column and the `hue` column.
        hue: Name of the column whose values get separate, side-by-side bars.
        bin_edges: Sequence of bin edges handed to seaborn. It is printed, and
            the x-axis runs from floor(first edge) to ceil(last edge) with a
            tick at every integer.
        colors: Palette handed to seaborn.
        shape: Figure size (width, height) in inches.

    The figure is shown and closed; nothing is written to disk.
    """
    rc_overrides = {'axes.labelsize': 8, 'xtick.labelsize': 8, 'ytick.labelsize': 8, 'axes.linewidth': 1.0}
    sns.set(style='ticks', rc=rc_overrides)
    plt.figure(figsize=(shape[0],shape[1]), dpi=500)

    print(bin_edges)
    ax = sns.histplot(
        df,
        hue=hue,
        x="Number of metastases",
        kde=False,
        alpha=0.6,
        palette=colors,
        bins=bin_edges,
        legend=False,
        multiple='dodge',
        shrink=0.85,
        stat="count",
    )

    # Integer x-range covering all bins.
    x_low = math.floor(bin_edges[0])
    x_high = math.ceil(bin_edges[-1])
    ax.set_xlim(x_low, x_high)
    plt.xticks(ticks=list(range(x_low, x_high + 1)))

    plt.xlabel('Number of metastases')
    plt.ylabel('Number of patients')
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    plt.setp(ax.collections, alpha=0.7)
    plt.tight_layout(pad=1.6) 
    plt.show()
    plt.close()
    

In [None]:
colors = ["tab:blue", "tab:orange"]

# Second element of each entry is the patient's number of sites.
lusc_met_counts = [entry[1] for entry in nsclc_subtype_to_all_pickle_files['LUSC']]
luad_met_counts = [entry[1] for entry in nsclc_subtype_to_all_pickle_files['LUAD']]
print(lusc_met_counts)
print(luad_met_counts)

# One row per patient; the plotted value is the number of sites minus one.
data = [
    [subtype_name, site_count - 1]
    for subtype_name, site_counts in (("LUSC", lusc_met_counts), ("LUAD", luad_met_counts))
    for site_count in site_counts
]
df = pd.DataFrame(data, columns=["Subtype", "Number of metastases"])

# Unit-width bins centred on the integer values of (sites - 1).
fewest_sites = min(min(lusc_met_counts), min(luad_met_counts))
most_sites = max(max(lusc_met_counts), max(luad_met_counts))
bin_edges = np.arange(fewest_sites-1.5, most_sites, 1)
plot_hist(df, "Subtype", bin_edges, colors)


In [None]:
# Make a dataframe of the number of mets for each patient
# (one row per patient; the value is the patient's number of sites minus one).
data = [
    [cohort, entry[1] - 1]
    for cohort, entries in dataset_to_all_pickle_files.items()
    for entry in entries
]

df = pd.DataFrame(data, columns=["Cancer type", "Number of metastases"])
print(df)  

# One histogram per cohort, showing the fraction of patients per metastasis count.
hist_rc = {'axes.labelsize': 6, 'xtick.labelsize': 6, 'ytick.labelsize': 6, 'xtick.major.width': 0.8, 'ytick.major.width': 0.8, 'axes.linewidth': 0.8, }
for dataset in DATASET_NAMES:
    sns.set(style='ticks', rc=hist_rc)
    plt.figure(figsize=(2,1.5), dpi=200)
    # Unit-width bins centred on the integers 1..9.
    bin_edges = np.arange(0.5, 10.5, 1)

    print(dataset)

    cohort_df = df.loc[df['Cancer type'] == dataset]
    ax = sns.histplot(
        data=cohort_df,
        x='Number of metastases',
        legend=False,
        stat="probability",
        hue='Cancer type',
        alpha=0.7,
        color="#808285",
        bins=bin_edges,
    )
    ax.set_xlim(0,8)
    ax.set_ylim(0,1.0)
    plt.xticks(ticks=list(range(0,9)))
    plt.xlabel('Number of metastases')
    plt.ylabel('Fraction of cohort')
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    n_patients = len(dataset_to_all_pickle_files[dataset])
    plt.title(f"{dataset} (n={n_patients})", fontsize=6)
    plt.setp(ax.collections, alpha=0.7)
    plt.tight_layout(pad=1.6) 
    plt.savefig(f"output_plots/{dataset}_num_mets_histogram.png", dpi=500)
    plt.show()
    plt.close()


In [None]:
# Count how often each distinct (cancer type, number of metastases) row occurs
# among the NSCLC patients.
df[df['Cancer type']=="NSCLC"].value_counts()

In [None]:
# Display the (pickle path, number of sites) entries collected for the LUAD subtype.
nsclc_subtype_to_pickle_files['LUAD']