In [1]:
import os
import glob
import gzip
import pickle
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from metient.util import plotting_util as putil

from metient.util.globals import *

four_dataset_colors = {"Melanoma":"#b84988","HGSOC":"#06879e","HR-NB":"#5a9e09", "NSCLC":"#d4892a"}

REPO_DIR = '/lila/data/morrisq/divyak/projects/metient/metient/'
OUTPUT_DIR = os.path.join(REPO_DIR, 'jupyter_notebooks', 'output_plots')
DATASET_NAMES = ["Breast Cancer", "HGSOC", "Melanoma", "HR-NB", "NSCLC"]
CALIBRATE_DIRS = [os.path.join(REPO_DIR,"data/hoadley_breast_cancer_2016/metient_outputs/solve_polys", "calibrate"),
                  os.path.join(REPO_DIR,"data/mcpherson_ovarian_2016/metient_outputs/solve_polys", "calibrate"),
                  os.path.join(REPO_DIR,"data/sanborn_melanoma_2015/metient_outputs_solve_polys_no_cna_pyclone_vi_orchard", "calibrate"),
                  os.path.join(REPO_DIR,"data/gundem_neuroblastoma_2023/metient_outputs/solve_polys", "calibrate"),
                  os.path.join(REPO_DIR,"data/tracerx_nsclc/metient_outputs/tracerx_trees_06142024/calibrate",)]

desired_order = ["single-source", 'multi-source', "reseeding", "primary single-source"]

In [2]:
# Helper functions
def gene_names(item, dataset):
    if dataset == 'HR-NB':
        if str.isdigit(item.split(";")[0]):
            return [item]
        return [item.split(";")[0]]
    elif dataset == 'NSCLC':
        return item.split(";")
    return [item]
    
def update_dict(muts, dct, dataset):
    for x in muts:
        for y in x:
            genes = gene_names(y, dataset)
            #print(genes)
            for g in genes:
                seeding_muts_to_count = dct[dataset]
                if g not in seeding_muts_to_count:
                    seeding_muts_to_count[g] = 0
                seeding_muts_to_count[g] += 1

In [8]:

data = []

dataset_to_seeding_muts_to_count = {d:{} for d in DATASET_NAMES}
dataset_to_nonseeding_muts_to_count = {d:{} for d in DATASET_NAMES}

for calibrate_dir,dataset in zip(CALIBRATE_DIRS, DATASET_NAMES):
    # Use glob to get the list of matching files
    matching_files = glob.glob(f'{calibrate_dir}/*pkl.gz')
    patients = [m.split("/")[-1].split("_")[0] for m in matching_files]
    print(dataset, len(patients))
    for fn in matching_files:
        with gzip.open(fn, 'rb') as f:
            
            pid = fn.split("/")[-1].split("_")[0]
            pkl = pickle.load(f)
            
            # Best calibrated tree
            V = torch.tensor(pkl[OUT_LABElING_KEY][0])
            A = torch.tensor(pkl[OUT_ADJ_KEY][0])
            U = pkl[OUT_SUB_PRES_KEY]
            losses = [l.item() for l in pkl[OUT_LOSSES_KEY]]
            seeding_clusters = putil.seeding_clusters(V, A)
            idx_to_label = pkl[OUT_IDX_LABEL_KEY][0]
            
            idx_to_label_no_leaves = {}
            for i in idx_to_label:
                if idx_to_label[i][1] == False:
                    idx_to_label_no_leaves[i] = idx_to_label[i][0]
            nonseeding_clusters = [k for k in idx_to_label_no_leaves.keys() if k not in seeding_clusters]
            
            #print(seeding_clusters, nonseeding_clusters)
            
            #print(idx_to_label)
            seeding_muts = [idx_to_label_no_leaves[x] for x in seeding_clusters]
            nonseeding_muts = [idx_to_label_no_leaves[x] for x in nonseeding_clusters]
#             print(seeding_muts, nonseeding_muts)
            #print(nonseeding_muts)
            for n in nonseeding_muts:
                for x in n:
                    if "pol" in x:
                        print(n)
            update_dict(seeding_muts, dataset_to_seeding_muts_to_count, dataset)
            update_dict(nonseeding_muts, dataset_to_nonseeding_muts_to_count, dataset)
            


Breast Cancer 0
HGSOC 13
['2pol6']
['1pol11']
['1pol11']
['7pol12']
['3pol11']
['2pol7']
['2pol6']
['2pol7']
Melanoma 7
['5pol9']
['3pol7']
['5pol8']
HR-NB 30
['0pol5']
['0pol5']
['0pol5']
['2pol28']
['11pol29']
['17pol30']
['15pol16']
['2pol15']
['6pol16']
['12pol18']
['6pol13']
['3pol35']
['18pol36']
['18pol37']
['7pol14']
['23pol37']
['23pol38']
['31pol40']
['34pol42']
['6pol10']
['0pol9']
['5pol8']
['0pol4']
['0pol16']
['0pol5']
NSCLC 128
['16pol24']
['17pol25']
['15pol25']
['0pol6']
['6pol16']
['2pol8']
['6pol13']
['7pol20']
['7pol21']
['15pol25']
['2pol11']
['25pol32']
['2pol17']
['10pol34']
['14pol35']
['1pol5']
['1pol14']
['6pol17']
['11pol18']
['5pol12']
['6pol13']
['3pol7']
['2pol15']
['2pol16']
['1pol7']
['7pol13']
['1pol5']
['13pol21']
['19pol27']
['19pol31']
['5pol13']
['6pol17']
['10pol19']
['12pol21']
['0pol2']
['3pol9']
['6pol30']
['28pol34']
['1pol15']
['12pol15']
['12pol16']
['5pol19']
['17pol20']
['18pol22']
['11pol20']
['3pol5']
['6pol8']
['4pol10']
['3pol13']
['3po

In [4]:
dataset_to_nonseeding_muts_to_count

{'Breast Cancer': {},
 'HGSOC': {'2pol6': 2,
  'E': 10,
  '1pol11': 2,
  'C': 6,
  'F': 4,
  'G': 6,
  'I': 2,
  'D': 4,
  '7pol12': 1,
  '3pol11': 1,
  'A': 2,
  'B': 3,
  '2pol7': 2,
  'H': 2},
 'Melanoma': {'ABLIM2': 1,
  'SHISA3': 1,
  'SDS': 1,
  'EPS8L3': 1,
  'CHRDL2': 1,
  'AIFM3': 1,
  'ZNF829': 1,
  'MLL4': 2,
  'PROM2': 1,
  'GUSB': 1,
  'FBXL7': 2,
  'SHANK1': 2,
  'SLC5A1': 1,
  'SLC4A9': 1,
  'SPPL2B': 2,
  'MESDC1': 1,
  'NRAS': 2,
  'ZNF629': 1,
  'DNAJC6': 1,
  'RP1L1': 1,
  'LILRA2': 1,
  'ISG20': 1,
  'SLC18A1': 1,
  'B3GNT6': 1,
  'ATP10B': 1,
  'SBSN': 1,
  'TRPM1': 1,
  'ZNF865': 1,
  'NCOA2': 1,
  'RERGL': 1,
  'EMG1': 1,
  'EXOSC10': 1,
  'ERCC8': 1,
  'GAlymph node metastasis, left groinT10': 1,
  'UTP15': 1,
  'LPPR5': 1,
  'HADHA': 1,
  'ARID1A': 1,
  'FBN2': 1,
  'KDM2B': 1,
  'MYO15A': 1,
  'ZZZ3': 1,
  'CDH9': 1,
  'PTPRC': 1,
  'MYT1L': 1,
  'MAP3K10': 1,
  'DALRD3': 1,
  'ZNF267': 1,
  'CCDC24': 1,
  'UBE2O': 1,
  'FSIP2': 1,
  'FZD1': 2,
  'TCERG1': 1,


In [5]:
top_genes_per_dataset = {}

for dataset, genes in dataset_to_seeding_muts_to_count.items():
    # Sort genes by count in descending order and get the top 10
    sorted_genes = sorted(genes.items(), key=lambda item: item[1], reverse=True)[:10]
    top_genes_per_dataset[dataset] = dict(sorted_genes)

print(top_genes_per_dataset)

{'Breast Cancer': {}, 'HGSOC': {'A': 11, 'B': 10, 'C': 7, 'D': 7, 'F': 4, 'H': 2, 'E': 1}, 'Melanoma': {'TTN': 6, 'MUC16': 5, 'PLCH1': 4, 'APOB': 4, 'DCC': 4, 'FLG': 4, 'RAD51AP2': 4, 'LRRIQ4': 3, 'DNAH6': 3, 'MRVI1': 3}, 'HR-NB': {'DPP10': 61, 'LRP1B': 60, 'EYS': 58, 'NKAIN2': 53, 'CSMD3': 50, 'CNTNAP2': 49, 'NAALADL2': 44, 'ERBB4': 43, 'PTPRD': 41, 'CTNNA3': 36}, 'NSCLC': {'TP53': 15, 'COL5A1': 11, 'PIK3CA': 8, 'HGF': 7, 'KIF1A': 6, 'KEAP1': 6, 'PDGFRA': 6, 'ATR': 5, 'NF2': 4, 'NIPBL': 4}}


In [6]:
top_nonseedinggenes_per_dataset = {}

for dataset, genes in dataset_to_nonseeding_muts_to_count.items():
    # Sort genes by count in descending order and get the top 10
    sorted_genes = sorted(genes.items(), key=lambda item: item[1], reverse=True)[:10]
    top_nonseedinggenes_per_dataset[dataset] = dict(sorted_genes)

print(top_nonseedinggenes_per_dataset)

{'Breast Cancer': {}, 'HGSOC': {'E': 10, 'C': 6, 'G': 6, 'F': 4, 'D': 4, 'B': 3, '2pol6': 2, '1pol11': 2, 'I': 2, 'A': 2}, 'Melanoma': {'MLL4': 2, 'FBXL7': 2, 'SHANK1': 2, 'SPPL2B': 2, 'NRAS': 2, 'FZD1': 2, 'ALPK3': 2, 'FLG': 2, 'DLL4': 2, 'OLFR959': 2}, 'HR-NB': {'LRP1B': 304, 'DPP10': 301, 'CSMD3': 282, 'CNTNAP2': 240, 'EYS': 230, 'PCDH15': 221, 'PTPRD': 220, 'CTNNA3': 218, 'AC007682.1': 213, 'NKAIN2': 204}, 'NSCLC': {'COL5A1': 36, 'SPTA1': 28, 'TP53': 24, 'ATR': 18, 'MACF1': 18, 'TTN': 18, 'KEAP1': 18, 'NIPBL': 17, 'PLXNB2': 15, 'PDGFRA': 15}}
