In [None]:
# Select JAX/XLA's "platform" GPU allocator (allocate/free on demand instead of
# preallocating a large memory pool). Must be set before jax is imported in the next cell.
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform

In [None]:
import pandas as pd
import warnings
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)
import scipy.optimize as opt
import regularized_optimization as reg_opt
import Utilityfunctions as utils
import jax.numpy as jnp
import numpy as np
import jax as jax
jax.config.update("jax_enable_x64", True)
import matplotlib.pyplot as plt
import logging
logging.basicConfig(filename='example.log', filemode='w', level=logging.INFO, force=True)

In [None]:
# Input files: per-sample event table (mut_handle) and sample annotations (annot_handle).
# Alternative PAAD inputs, kept for reference:
#mut_handle = "../data/paad/G12_PAADPANET_PM_z10_EventsAEP.csv"
#mut_handle = "../data/paad/G12_PAADPANET_PM_z10_Events_Mut50_Full.csv"
#mut_handle = "../data/paad/G12_PAADPANET_PM_z10_Events_30and5_Full.csv"
#annot_handle = "../data/paad/sampleSelection.txt"
mut_handle = "../data/luad/G13_LUAD_PM_v2_Events_20and15_Full.csv"
annot_handle = "../data/luad/G13_LUAD_PM_v2_sampleSelection_20and15.csv"
annot_data = pd.read_csv(annot_handle)
# The unnamed first column of the event table holds the patient identifier.
mut_data = pd.read_csv(mut_handle).rename(columns={"Unnamed: 0": "patientID"})
# Attach each sample's metaStatus; the join key is the single column "patientID".
dat = pd.merge(mut_data, annot_data.loc[:, ["patientID", "metaStatus"]], on="patientID")

# Remove datapoints that consist solely of NaNs, and those whose events sum to zero.
# Columns 1:-3 skip the leading patientID and the last three columns
# (NOTE(review): presumably metadata such as "paired"/"metaStatus" — confirm).
event_block = dat.iloc[:, 1:-3]
has_data = ~event_block.isna().all(axis=1)
has_event = event_block.sum(axis=1) > 0
dat = dat.loc[has_data & has_event, :]
dat.columns

In [None]:
# Event columns to model, selected by label range (pandas label slices include both ends).
start = 'P.TP53 (M)'
stop = 'M.PTPRD/9p (Amp)'
# Sample groups to keep, as (paired, metaStatus) index tuples.
sample_groups = [(0, "present"), (0, "absent"), (0, "isMetastasis"), (1, "isPaired")]
mult = dat.set_index(["paired", "metaStatus"])
cleaned = mult.loc[sample_groups, start:stop].sort_index()
# Seeding indicator: 1 for every kept group except (paired=0, metaStatus="absent").
# The index is sorted above, which the partial MultiIndex selections below rely on.
cleaned.loc[(0, ["present", "isMetastasis"]), "Seeding"] = 1
cleaned.loc[(0, "absent"), "Seeding"] = 0
cleaned.loc[(1, "isPaired"), "Seeding"] = 1
dat_prim_nomet, dat_met_only, dat_prim_met, dat_coupled = utils.split_data(cleaned)
print(dat_prim_nomet.shape[0], dat_prim_met.shape[0], dat_coupled.shape[0], dat_met_only.shape[0])

In [None]:
def _event_label(column):
    """Map a column name of `cleaned` to its display label.

    Names with three or more dot-separated parts give "<part2> (M)" when the second
    part is 'Mut', otherwise "<part1> (<part2>)". Two-part names give their second
    part. A name without any dot (the Seeding column) gives "Seeding".
    """
    parts = column.split(".")
    if len(parts) == 2:
        return parts[1]
    if len(parts) < 2:
        return "Seeding"
    if parts[1] == 'Mut':
        return parts[2] + " (M)"
    return parts[1] + " (" + parts[2] + ")"


# One label per event: every other column of `cleaned`, starting with the first
# (presumably the primary-tumour column of each pair, plus the trailing Seeding column).
events = [_event_label(column) for column in cleaned.columns.to_list()[::2]]

In [None]:
# Per-sample row sums of event columns for the unpaired groups:
# - metastasis-only samples use columns 1, 3, 5, ...
#   (presumably the metastasis column of each pair);
# - primary groups use columns 0, 2, 4, ..., excluding the trailing Seeding column.
event_burden = {
    "EM_Met": cleaned.loc[(0, "isMetastasis")].iloc[:, 1::2].sum(axis=1),
    "EM_Prim": cleaned.loc[(0, "present")].iloc[:, 0:-1:2].sum(axis=1),
    "NM": cleaned.loc[(0, "absent")].iloc[:, 0:-1:2].sum(axis=1),
}

fig, ax = plt.subplots()
ax.boxplot(list(event_burden.values()))
ax.set_xticklabels(list(event_burden.keys()))
ax.set_xlabel("Sample group")
ax.set_ylabel("Events per sample (row sum)")
ax.set_title("Event burden by sample group")
plt.show()

In [None]:
# Model size: one parameter row per column pair of `cleaned`, plus one for the
# trailing Seeding column.
n = (cleaned.shape[1] - 1) // 2 + 1
# Starting values for the two log-rate parameters appended after the event parameters.
# Both currently start at log(30/87).
# NOTE(review): an earlier comment cited observed mean times of 87 and 162 days, yet both
# values use 87 — confirm whether lam2_start was meant to be np.log(30/162).
lam1_start = np.log(30/87)
lam2_start = np.log(30/87)
indep = utils.indep(jnp.vstack((dat_met_only, dat_prim_met, dat_prim_nomet)), dat_coupled)
start_params = np.append(indep, [lam1_start, lam2_start])

In [None]:
# Per-event presence frequencies for each sample group, one column per event.
# Rows 0-2: share of coupled samples where the event is PT-private / MT-private / shared.
# Rows 3-5: share of NM primaries, EM primaries and EM metastases carrying the event.
n_mod = n-1
# Encode each coupled column pair as first + 2*second, stored in the pair's first column:
# 1 = PT-private, 2 = MT-private, 3 = shared. Only even columns of `arr` are read below.
arr = dat_coupled * np.array([1,2]*n_mod+[1])
arr = arr @ (np.diag([1,0]*n_mod+[1]) + np.diag([1,0]*n_mod, -1))
last_col = 2*n - 2  # index of the trailing Seeding column (arrays have 2n-1 columns)
counts = np.zeros((6, n))
for i_h, i in enumerate(range(0, 2*n, 2)):
    for j in range(1,4):
        counts[j-1, i_h] = np.count_nonzero(arr[:,i]==j)/dat_coupled.shape[0]
    counts[3, i_h] = np.sum(dat_prim_nomet[:, i], axis=0)/dat_prim_nomet.shape[0]
    counts[4, i_h] = np.sum(dat_prim_met[:, i], axis=0)/dat_prim_met.shape[0]
    # Metastasis-only samples use the pair's second column (i + 1). Seeding has no partner
    # column, so i + 1 would be out of bounds there (NumPy raises, JAX silently clamps to
    # the last column); use column i itself instead.
    met_col = i + 1 if i < last_col else i
    counts[5, i_h] = np.sum(dat_met_only[:, met_col], axis=0)/dat_met_only.shape[0]

labels = [["Coupled ("+str(dat_coupled.shape[0])+")"]*3 +\
          ["NM ("+str(dat_prim_nomet.shape[0])+")"] +\
          ["EM-PT ("+str(dat_prim_met.shape[0])+")"] +\
          ["EM-MT ("+str(dat_met_only.shape[0])+")"],
          ["PT-Private", "MT-Private", "Shared"] + ["Present"]*3]

inds = pd.MultiIndex.from_tuples(list(zip(*labels)))
event_freqs = pd.DataFrame(np.around(counts, 2), columns=events, index=inds).T
#event_freqs.to_latex("luad_samples.tex")
event_freqs

In [None]:
#utils.cross_val(cleaned, np.linspace(0.0001, 0.01, 5), 5, start_params, 0.65, n)

In [None]:
# Scratch check of JAX reduction semantics: summing over axis 0 collapses the rows,
# leaving one total per column of the 4 x 5 matrix holding 0..19.
z = jnp.arange(20).reshape(4, 5)
print(z)
jnp.sum(z, axis=0)

In [10]:
# Evaluate the penalized objective (value and gradient) once at the starting parameters.
penal1 = 1/900  # L1 penalty on off-diagonals
penal2 = 0.00  # L2 penalty on diagonals
m_p_corr = 0.65
objective_args = (dat_prim_nomet, dat_coupled, dat_prim_met, dat_met_only,
                  n-1, penal1, penal2, m_p_corr)
reg_opt.value_grad(start_params, *objective_args)

In [None]:
# Fit the model with L-BFGS-B. value_grad returns (value, gradient), hence jac=True.
penal1 = 1/900  # L1 penalty on off-diagonals
penal2 = 0.00  # L2 penalty on diagonals
m_p_corr = 0.65
x = opt.minimize(reg_opt.value_grad, x0 = start_params, args = (dat_prim_nomet, dat_coupled, dat_prim_met, dat_met_only, n-1, penal1, penal2,  m_p_corr), 
                method = "L-BFGS-B", jac = True, options={"maxiter":10000, "disp":True, "ftol":1e-04})
# The fitted parameters are exported later in the notebook; make a non-converged run
# visible instead of silently using its result.
if not x.success:
    warnings.warn(f"L-BFGS-B did not converge: {x.message}")

In [None]:
# Inverse of exp of the first appended log-rate parameter.
print(1/jnp.exp(x.x[-2]))

# Split the solution: n*n event parameters followed by the two appended log-rates.
theta_flat, log_rates = x.x[:-2], x.x[-2:]
df2 = pd.DataFrame(theta_flat.reshape((n, n)), columns=events, index=events)
theta = df2.copy()  # snapshot of the bare n x n table before the extra column is added
# Extra column: the two log-rates, padded with zeros to length n.
df2["Sampling"] = np.append(np.array(log_rates), np.zeros(n - 2))
df2.to_csv("../results/luad/luad_paired_only_20_15_0011.csv")
df2.round(3)

In [None]:
# Plot the fitted parameter table df2 (including its "Sampling" column).
# NOTE(review): the meaning of 0.4 is defined by utils.plot_theta (presumably a display
# threshold) — confirm.
utils.plot_theta(df2, 0.4)