In [None]:
import metmhn.regularized_optimization as reg_opt
import metmhn.Utilityfunctions as utils

import pandas as pd
import warnings
# Silence pandas PerformanceWarning messages for the rest of the session.
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)

import jax.numpy as jnp
import numpy as np
import jax as jax
# Enable 64-bit (double precision) types in JAX; this flag should be set at
# startup, before any JAX computation is run.
jax.config.update("jax_enable_x64", True)

import logging
# Adapt path to where logs should be kept
# filemode='w' overwrites the log file on every run; force=True replaces any
# handlers that are already attached to the root logger.
# NOTE(review): the log file is named 'paad.log' while the data cell below
# loads the LUAD files by default -- confirm the intended log name.
logging.basicConfig(filename='../logs/paad.log', filemode='w', level=logging.INFO, force=True)

In the following example, we have information about the genotypes of Primary Tumors (PTs) and Metastases (MTs) for patients suffering from pancreatic adenocarcinomas (PAADs) and pancreatic neuroendocrine tumors (PANETs). (Note: the next cell loads the LUAD files by default; the PAAD/PANET paths are kept there as commented-out alternatives.) For some patients, we only know the status of either the PT or the MT. This is indicated in our data in the column 'paired' (1 for paired PT/MT samples, 0 otherwise). The column 'metaStatus' indicates the type of the tumor ('present', 'isMetastasis', 'absent') if only a single genotype is available; paired samples carry the value 'isPaired'.

In [None]:
# Input files: the event (mutation) table and the sample annotation table.
# The LUAD files are active; swap in the commented PAAD/PANET paths below to
# run the notebook on that cohort instead.
mut_handle = "../data/luad/G13_LUAD_PM_v2_Events_20and15_Full.csv"
annot_handle = "../data/luad/G13_LUAD_PM_v2_sampleSelection_20and15.csv"
#annot_handle = "../data/paad/G13_PAADPANET_PM_v2_sampleSelection_30and15.csv"
#mut_handle = "../data/paad/G13_PAADPANET_PM_v2_Events_30and15_Full.csv"

annot_data = pd.read_csv(annot_handle)
# The first column of the event table is unnamed in the CSV header, so pandas
# labels it "Unnamed: 0"; rename it so it can serve as the join key.
mut_data = pd.read_csv(mut_handle).rename(columns={"Unnamed: 0": "patientID"})
# Attach the 'metaStatus' annotation to every patient's events. A single join
# key is sufficient (the original listed "patientID" twice).
dat = pd.merge(mut_data, annot_data.loc[:, ['patientID', 'metaStatus']],
               on="patientID")

In [None]:
# Events included in the analysis. Each event has a primary-tumor ("P.") and a
# metastasis ("M.") column, and the columns are interleaved as P, M, P, M, ...
mut_events = [
    'TP53 (M)', 'TERT/5p (Amp)', 'MCL1/1q (Amp)', 'KRAS (M)', 'EGFR (M)',
    'TP53/17p (Del)', 'STK11/19p (Del)', 'STK11 (M)', 'KRAS/12p (Amp)',
    'KEAP1 (M)', 'SMARCA4 (M)', 'ATM (M)', 'NF1 (M)', 'PTPRD (M)',
    'PTPRT (M)', 'ARID1A (M)', 'PIK3CA (M)', 'BRAF (M)', 'SETD2 (M)',
    'EPHA3 (M)', 'FAT1 (M)',
]
muts = [f"{site}.{event}" for event in mut_events for site in ("P", "M")]

mult = dat.set_index(["paired", "metaStatus"])
# (paired, metaStatus) combinations that are kept for the analysis.
cohort_keys = [(0, "present"), (0, "absent"), (0, "isMetastasis"), (1, "isPaired")]
cleaned = mult.loc[cohort_keys, muts].sort_index()

# Add the "Seeding" column: 1 for unpaired 'present'/'isMetastasis' samples and
# for paired samples, 0 for unpaired 'absent' samples.
seeding_by_cohort = [
    ((0, ["present", "isMetastasis"]), 1),
    ((0, "absent"), 0),
    ((1, "isPaired"), 1),
]
for cohort_sel, seeding_flag in seeding_by_cohort:
    cleaned.loc[cohort_sel, "Seeding"] = seeding_flag

dat_prim_nomet, dat_met_only, dat_prim_met, dat_coupled = utils.split_data(cleaned)
# Keep only those rows of dat_prim_nomet whose row sum is positive.
has_any_event = jnp.sum(dat_prim_nomet, axis=1) > 0
dat_prim_nomet = dat_prim_nomet.at[has_any_event, :].get()
# Cell output: the largest row sum among the coupled samples.
jnp.max(dat_coupled.sum(axis=1))

Retrieve the event names and trim the PT/MT identifier:

In [None]:
# Column names are of the form "P.<event>" / "M.<event>" followed by a final
# "Seeding" column. Take every second column (skipping "Seeding") and strip
# the site prefix. Splitting only at the first "." keeps event names intact
# even if they contain a dot themselves.
events = [col.split(".", 1)[1] for col in cleaned.columns[:-1].to_list()[::2]]
events.append("Seeding")

Enumerate the frequencies of SNVs and CNVs in all subgroups. 'NM/EM' refer to Never Metastasizing/Ever Metastasizing tumors, where only a single genotype is known. A mutation is referred to as 'MT/PT-private' if it occurs exclusively in the MT or PT; otherwise it is called 'shared':

In [None]:
# Number of events including "Seeding": every mutation occupies one PT and one
# MT column of `cleaned`, "Seeding" occupies the single last column.
n_tot = (cleaned.shape[1] - 1) // 2 + 1
n_mut = n_tot - 1

# Collapse every PT/MT column pair of the coupled samples into one code that is
# stored in the even (PT) column: 1 = PT only, 2 = MT only, 3 = both.
# Odd columns become 0; the last ("Seeding") column is carried over unchanged.
pair_weights = np.array([1, 2] * n_mut + [1])
pair_sum = np.diag([1, 0] * n_mut + [1]) + np.diag([1, 0] * n_mut, -1)
arr = dat_coupled * pair_weights
arr = arr @ pair_sum

# Rows: PT-private, MT-private, shared (coupled), NM, EM-PT, EM-MT.
counts = np.zeros((6, n_tot))
for i_h in range(n_tot):
    i = 2 * i_h  # PT column of event i_h (the "Seeding" column for the last event)
    # MT column of event i_h. "Seeding" has no separate MT column, so its own
    # column is used instead of indexing one past the end (which JAX would
    # silently clamp and NumPy would reject).
    # Assumes dat_met_only has the same column layout as `cleaned` -- TODO confirm.
    i_mt = i + 1 if i_h < n_mut else i
    for j in range(1, 4):
        counts[j - 1, i_h] = np.count_nonzero(arr[:, i] == j) / dat_coupled.shape[0]
    counts[3, i_h] = np.sum(dat_prim_nomet[:, i], axis=0) / dat_prim_nomet.shape[0]
    counts[4, i_h] = np.sum(dat_prim_met[:, i], axis=0) / dat_prim_met.shape[0]
    counts[5, i_h] = np.sum(dat_met_only[:, i_mt], axis=0) / dat_met_only.shape[0]

# Two-level row labels: subgroup (with its sample count) and mutation category.
labels = [
    [f"Coupled ({dat_coupled.shape[0]})"] * 3
    + [f"NM ({dat_prim_nomet.shape[0]})",
       f"EM-PT ({dat_prim_met.shape[0]})",
       f"EM-MT ({dat_met_only.shape[0]})"],
    ["PT-Private", "MT-Private", "Shared"] + ["Present"] * 3,
]

inds = pd.MultiIndex.from_tuples(list(zip(*labels)))
# Transpose so that events are rows and subgroups are columns.
counts = pd.DataFrame(np.around(counts, 2), columns=events, index=inds).T
counts

Optional: We use a sparsity-promoting L1 penalty. The weight of the penalization can be determined in a k-fold cross-validation:

In [None]:
# Candidate penalty weights: 8 values whose base-10 logarithms are evenly
# spaced between -4.5 and -2.
log_lams = np.linspace(start=-4.5, stop=-2, num=8)
lams = np.power(10.0, log_lams)
# NOTE(review): the positional arguments 3 and 0.65 are presumably the number
# of folds and the same correction factor as `m_p_corr` in the learning cell --
# confirm against utils.cross_val.
utils.cross_val(cleaned, lams, 3, 0.65)

Learn the MHN on the full dataset:

In [None]:
penal1 = 0.005  # L1 penalty on off-diagonals
# NOTE(review): the meaning of m_p_corr is not visible in this notebook;
# confirm against reg_opt.learn_mhn.
m_p_corr = 0.65

# Initial parameters, computed from the complete data matrix and the number of
# coupled samples.
full_data = jnp.array(cleaned.to_numpy())
n_coupled = dat_coupled.shape[0]
th_init, fd_init, sd_init = utils.indep(full_data, n_coupled)

# Fit the model on all four data subsets, starting from the initial parameters.
theta, fd_effects, sd_effects = reg_opt.learn_mhn(
    th_init, fd_init, sd_init,
    dat_prim_nomet, dat_prim_met, dat_met_only, dat_coupled,
    m_p_corr, penal1,
)

In [None]:
# Stack the two effect vectors (as single rows) on top of the theta matrix so
# that all learned parameters can be plotted and exported as one 2-D array.
# np.vstack replaces np.row_stack, which is a deprecated alias of it.
th_plot = np.vstack((fd_effects.reshape((1, -1)),
                     sd_effects.reshape((1, -1)),
                     theta))

In [None]:
# Plot the stacked parameter array, labelled with the event names.
# NOTE(review): the meaning of the third argument (0.1) is not visible here --
# presumably a display threshold; confirm against utils.plot_theta.
utils.plot_theta(th_plot, events, 0.1)

In [None]:
# Write the stacked parameter array to disk, one column per event.
# Rows follow the stacking order of th_plot: fd_effects, sd_effects, then theta.
results_path = "../results/luad/paired_20_8_001.csv"
df2 = pd.DataFrame(data=th_plot, columns=events)
df2.to_csv(results_path)