In [None]:
import os
import warnings

import pandas as pd

import metmhn.regularized_optimization as reg_opt
import metmhn.Utilityfunctions as utils
# Silence pandas PerformanceWarnings so they do not flood the notebook output.
warnings.simplefilter(action='ignore',
                      category=pd.errors.PerformanceWarning)

import jax.numpy as jnp
import numpy as np
import jax
# Use 64-bit (double) precision in JAX; JAX defaults to 32-bit floats.
jax.config.update("jax_enable_x64", True)

import logging
# Adapt LOG_PATH to where logs should be kept. The file is opened with
# filemode='w', so every run overwrites it: name it after the analysed cohort
# so that this LUAD run does not clobber the log of another cohort's run.
LOG_PATH = '../logs/luad.log'
# The file handler cannot create missing directories itself.
os.makedirs(os.path.dirname(LOG_PATH) or '.', exist_ok=True)
# force=True replaces handlers left over from a previous execution of this
# cell, so re-running it reconfigures logging instead of silently doing nothing.
logging.basicConfig(filename=LOG_PATH,
                    format='%(asctime)s %(levelname)-8s %(message)s',
                    filemode='w',
                    level=logging.INFO,
                    force=True,
                    datefmt='%Y-%m-%d %H:%M:%S'
                    )

In the following example, we have information about the genotypes of primary tumors (PTs) and metastases (MTs) for patients with lung adenocarcinoma (LUAD). The commented-out paths in the next cell load the analogous datasets for pancreatic adenocarcinomas (PAADs) and pancreatic neuroendocrine tumors (PANETs), as well as the PRAD and BLCA cohorts. For some patients we only know the status of either the PT or the MT; this is indicated in the column 'isPaired'. If only a single genotype is available, the column 'metaStatus' indicates the type of the tumor ('present', 'isMetastasis', 'absent').

In [None]:
# Input files of the cohort being modelled. Pairs of paths for the other
# cohorts analysed with this notebook are kept below; swap in a matching pair
# (and the matching event list in the next cell) to switch cohorts.
#annot_handle = "../data/prad/G13_PRAD_PM_onlyMuts_sampleSelection_50.csv"
#mut_handle = "../data/prad/G13_PRAD_PM_onlyMuts_Events_50.csv"
#annot_handle = "../data/blca/G13_BLCA_PM_onlyMuts_sampleSelection_33.csv"
#mut_handle = "../data/blca/G13_BLCA_PM_onlyMuts_Events_33.csv"
#mut_handle = "../data/paad/G13_PAADPANET_PM_v2_Events_30and15_Full.csv"
#annot_handle = "../data/paad/G13_PAADPANET_PM_v2_sampleSelection_30and15.csv"
mut_handle = "../data/luad/G13_LUAD_PM_v2_Events_20and15_Full.csv"
annot_handle = "../data/luad/G13_LUAD_PM_v2_sampleSelection_20and15.csv"

annot_data = pd.read_csv(annot_handle)
# pandas labels a header-less first column "Unnamed: 0"; in the events file
# that column holds the patient identifiers.
mut_data = pd.read_csv(mut_handle).rename(columns={"Unnamed: 0": "patientID"})

# Attach each patient's 'metaStatus' annotation (inner join on patientID).
dat = pd.merge(mut_data, annot_data.loc[:, ["patientID", "metaStatus"]],
               on="patientID")

# Candidate event columns, displayed for reference when choosing `muts` in the
# next cell: skips 'patientID' and the last four columns ('metaStatus' is the
# last one; the other three are presumably non-event columns — TODO confirm).
muts = dat.columns[1:-4].to_list()
muts

In [None]:
# Event selections for the cohorts analysed with this notebook. Every event
# appears as a pair of columns: 'P.<event>' (primary tumor) followed by
# 'M.<event>' (metastasis). Later cells rely on this alternating order.
# PAAD/PANET cohort:
muts_paad = [
    'P.KRAS (M)', 'M.KRAS (M)', 'P.TP53 (M)', 'M.TP53 (M)',
    'P.SMAD4/18q (Del)', 'M.SMAD4/18q (Del)', 'P.MYC/8q (Amp)', 'M.MYC/8q (Amp)',
    'P.TP53/17p (Del)', 'M.TP53/17p (Del)', 'P.SMAD4 (M)', 'M.SMAD4 (M)',
    'P.KRAS/12p (Amp)', 'M.KRAS/12p (Amp)', 'P.SETD2/3p (Del)', 'M.SETD2/3p (Del)',
    'P.MEN1/11q (Del)', 'M.MEN1/11q (Del)',
    'P.RNF43 (M)', 'M.RNF43 (M)', 'P.MEN1 (M)', 'M.MEN1 (M)',
    'P.GNAS (M)', 'M.GNAS (M)', 'P.DAXX (M)', 'M.DAXX (M)', 'P.KMT2C (M)',
    'M.KMT2C (M)', 'P.ATM (M)', 'M.ATM (M)', 'P.RBM10 (M)', 'M.RBM10 (M)',
    'P.BRCA2 (M)', 'M.BRCA2 (M)', 'P.ATRX (M)', 'M.ATRX (M)',
    'P.PIK3CA (M)', 'M.PIK3CA (M)', 'P.SETD2 (M)', 'M.SETD2 (M)',
    'P.TGFBR1 (M)', 'M.TGFBR1 (M)', 'P.RB1 (M)', 'M.RB1 (M)',
    'P.SMARCA4 (M)', 'M.SMARCA4 (M)', 'P.SMAD3 (M)', 'M.SMAD3 (M)',
    'P.MAP2K4 (M)', 'M.MAP2K4 (M)', 'P.TSC2 (M)', 'M.TSC2 (M)',
    'P.BCOR (M)', 'M.BCOR (M)', 'P.STK11 (M)', 'M.STK11 (M)', 'P.U2AF1 (M)',
    'M.U2AF1 (M)', 'P.SF3B1 (M)', 'M.SF3B1 (M)', 'P.PTPRT (M)',
    'M.PTPRT (M)',
]
# LUAD cohort:
muts_luad = [
    'P.TP53 (M)', 'M.TP53 (M)', 'P.TERT/5p (Amp)', 'M.TERT/5p (Amp)',
    'P.MCL1/1q (Amp)', 'M.MCL1/1q (Amp)', 'P.KRAS (M)', 'M.KRAS (M)',
    'P.EGFR (M)', 'M.EGFR (M)',
    'P.TP53/17p (Del)', 'M.TP53/17p (Del)',
    'P.STK11/19p (Del)', 'M.STK11/19p (Del)', 'P.STK11 (M)', 'M.STK11 (M)',
    'P.KRAS/12p (Amp)', 'M.KRAS/12p (Amp)', 'P.KEAP1 (M)', 'M.KEAP1 (M)',
    'P.SMARCA4 (M)', 'M.SMARCA4 (M)', 'P.ATM (M)', 'M.ATM (M)', 'P.NF1 (M)',
    'M.NF1 (M)', 'P.PTPRD (M)', 'M.PTPRD (M)', 'P.PTPRT (M)', 'M.PTPRT (M)',
    'P.ARID1A (M)', 'M.ARID1A (M)', 'P.PIK3CA (M)', 'M.PIK3CA (M)',
    'P.BRAF (M)', 'M.BRAF (M)', 'P.SETD2 (M)', 'M.SETD2 (M)', 'P.EPHA3 (M)',
    'M.EPHA3 (M)', 'P.FAT1 (M)', 'M.FAT1 (M)',
]
# Select the list matching the data files loaded in the previous cell.
muts = muts_luad
# Guard the P./M. pairing that the event-name extraction below depends on.
assert all(p.startswith("P.") and m == "M." + p[2:]
           for p, m in zip(muts[::2], muts[1::2])) and len(muts) % 2 == 0, \
    "muts must alternate 'P.<event>', 'M.<event>'"

# Label each datapoint with a numeric value according to its sequencing type
dat["type"] = dat.apply(utils.categorize, axis=1)
dat["Seeding"] = dat.apply(utils.add_seeding, axis=1)
events_data = muts + ["Seeding"]

# Only use datapoints where the state of the metastasis is known (types 0-3)
cleaned = dat.loc[dat["type"].isin([0, 1, 2, 3]), muts + ["Seeding", "type"]]
# Drop datapoints without any event (all event and 'Seeding' columns sum to
# < 1). A boolean filter replaces the in-place drop on a .loc slice, so the
# cell is idempotent and cannot trigger SettingWithCopy issues.
has_event = cleaned.iloc[:, :-1].sum(axis=1) >= 1
cleaned = cleaned.loc[has_event]
dat_prim_nomet, dat_prim_met, dat_met_only, dat_coupled = utils.split_data(cleaned, events_data)

Retrieve the event names and strip the PT/MT prefix ('P.'/'M.'):

In [None]:
# Columns of `cleaned` are the alternating 'P.<event>'/'M.<event>' pairs from
# `muts`, followed by 'Seeding' and 'type'. Take every primary-tumor column and
# strip its 'P.' prefix. Splitting only on the first '.' keeps event names
# that themselves contain a '.' intact (a plain split(".")[1] truncated them).
events_plot = [col.split(".", 1)[1] for col in cleaned.columns[:-2][::2]]
events_plot.append("Seeding")

Enumerate the frequencies of SNVs and CNVs in all subgroups. 'NM'/'EM' refer to never-metastasizing/ever-metastasizing tumors for which only a single genotype is known. A mutation is called 'MT-private' or 'PT-private' if it occurs exclusively in the MT or the PT, respectively; otherwise it is called 'shared':

In [None]:
# `cleaned` holds one P./M. column pair per event plus 'Seeding' and 'type':
# n_mut counts the paired events, n_tot additionally counts 'Seeding'.
n_mut = (cleaned.shape[1] - 1) // 2
n_tot = n_mut + 1
utils.marg_frequs(dat_prim_nomet, dat_prim_met, dat_met_only, dat_coupled, events_plot)

Optional: We use a sparsity-promoting L1 penalty. The weight of the penalty can be determined by k-fold cross-validation:

In [None]:
# Optional, slow: set RUN_CV = True to evaluate a log-spaced grid of L1
# weights (10**-4 ... 10**-2) with utils.cross_val. The `5` is presumably the
# number of folds — TODO confirm against metmhn.
RUN_CV = False
if RUN_CV:
    log_lams = np.linspace(-4, -2, 5)
    lams = 10**log_lams
    print(lams)
    # NOTE(review): 0.83 differs from m_p_corr = 0.65 used for training below;
    # if this argument is the same parameter, confirm which value is intended.
    utils.cross_val(cleaned.copy(), events_data, lams, 5, 0.83)

Train an MHN on the full dataset:

In [None]:
# Hyperparameters of the fit.
penal = 0.0028      # L1 penalty on off-diagonals
m_p_corr = 0.65
# Equal weights for the two weighted data groups. An earlier variant used
# three inverse-sample-size weights (prim_met, met_only, coupled).
weights = jnp.array([1.0, 1.0])

# Starting values from utils.indep on the full event matrix (event columns
# plus 'Seeding'); the second argument is the number of coupled PT/MT samples.
event_matrix = jnp.array(cleaned[events_data].to_numpy())
th_init, fd_init, sd_init = utils.indep(event_matrix, dat_coupled.shape[0])

# NOTE(review): dat_met_only is not passed to learn_mhn — confirm intended.
theta, fd_effects, sd_effects = reg_opt.learn_mhn(
    th_init, fd_init, sd_init,
    dat_prim_nomet, dat_prim_met, dat_coupled,
    m_p_corr, weights, penal,
)

In [None]:
# Stack the learned parameters into one matrix for plotting and export:
# row 0 = fd_effects, row 1 = sd_effects, remaining rows = theta.
# np.vstack replaces np.row_stack, which is a deprecated alias since NumPy 2.0.
th_plot = np.vstack((fd_effects.reshape((1, -1)),
                     sd_effects.reshape((1, -1)),
                     theta))

Visualize the results:

In [None]:
# Plot the stacked parameter matrix, labelled by event. The third argument is
# passed through unchanged; its meaning is defined by metmhn (presumably a
# display cutoff — TODO confirm).
plot_cutoff = 0.1
utils.plot_theta(th_plot, events_plot, plot_cutoff)

In [None]:
# Export the stacked parameter matrix with one column per event.
# NOTE(review): the file name says "25_muts", but `muts` in this notebook
# selects 21 P/M event pairs (plus 'Seeding'), and "0028" mirrors
# penal = 0.0028 by hand — confirm. The name is kept unchanged so that
# downstream consumers still find the file.
theta_out_path = "../results/luad/luad_25_muts_0028.csv"
# to_csv does not create missing directories.
os.makedirs(os.path.dirname(theta_out_path) or ".", exist_ok=True)
df2 = pd.DataFrame(th_plot, columns=events_plot)
df2.to_csv(theta_out_path)