In [None]:
import warnings
from pathlib import Path

import jax as jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import scipy.optimize as opt

import regularized_optimization as reg_opt
import Utilityfunctions as utils
# Enable 64-bit (double-precision) arrays in JAX, which defaults to 32-bit.
# This must run before any JAX arrays are created in the notebook.
jax.config.update("jax_enable_x64", True)

In [None]:
# Input files: per-patient mutation events and the tab-separated sample annotation table.
mut_handle = "../data/G12_PAADPANET_PM_z10_Events.csv"
annot_handle = "../data/sampleSelection.txt"
annot_data = pd.read_csv(annot_handle, sep="\t")
# The unnamed first CSV column holds the patient identifier.
mut_data = pd.read_csv(mut_handle).rename(columns={"Unnamed: 0": "patientID"})
# Attach each patient's metastasis status (inner join, the pandas default).
dat = pd.merge(mut_data, annot_data.loc[:, ["patientID", "metaStatus"]], on="patientID")

# Columns inspected by the row filters below (positional slice).
# NOTE(review): this covers columns 1:-3, whereas the event labels are built from
# the 'P.Mut.KRAS':'M.Mut.KMT2D' span; confirm column -4 is meant to be included here.
filter_cols = dat.iloc[:, 1:-3]
# Keep rows that are not entirely NaN AND have at least one positive entry,
# i.e. drop patients without any observed event. (An all-NaN row also sums to 0,
# so the second condition already implies the first; both are kept for clarity.)
has_data = ~filter_cols.isna().all(axis=1)
has_event = filter_cols.sum(axis=1) > 0
dat = dat.loc[has_data & has_event, :]

In [None]:
def _model_matrix(frame: pd.DataFrame, seeding: int) -> jnp.ndarray:
    """Convert patient rows into the integer matrix consumed by the model.

    Keeps the contiguous mutation columns 'P.Mut.KRAS' through 'M.Mut.KMT2D',
    appends a constant 'Seeding' column filled with ``seeding``, and returns
    the result as an integer JAX array with one row per patient.
    """
    mutations = frame.loc[:, "P.Mut.KRAS":"M.Mut.KMT2D"]
    # assign() returns a new frame, so we never write into a slice of `dat`
    # (which is what triggered SettingWithCopyWarning before).
    return jnp.array(mutations.assign(Seeding=seeding).to_numpy(dtype=int))


# Explicit == comparisons preserve the original semantics even if `paired`
# is not a clean boolean column.
is_paired = dat.paired == True  # noqa: E712
is_unpaired = dat.paired == False  # noqa: E712

# Coupled (paired) datapoints: seeding observed.
dat_coupled = _model_matrix(dat.loc[is_paired], seeding=1)
# Primary only, no metastasis generated: no seeding.
dat_prim_nomet = _model_matrix(dat.loc[is_unpaired & (dat.metaStatus == "absent")], seeding=0)
# Primary only, metastasis present but not sequenced: seeding observed.
dat_prim_met = _model_matrix(dat.loc[is_unpaired & (dat.metaStatus == "present")], seeding=1)
# Metastasis only.
dat_met_only = _model_matrix(dat.loc[is_unpaired & (dat.metaStatus == "isMetastasis")], seeding=1)
# NOTE(review): unpaired rows with any other metaStatus value are silently excluded.

# Event labels: one name per P/M column pair (the P.Mut.* column), plus "Seeding".
# Built from the same label span as the matrices, so labels and columns always line up.
events = list(dat.loc[:, "P.Mut.KRAS":"M.Mut.KMT2D"].columns[::2]) + ["Seeding"]
# From here on `dat` is the stacked model matrix (a JAX array), no longer the DataFrame.
dat = jnp.vstack((dat_prim_nomet, dat_prim_met, dat_coupled, dat_met_only))

In [None]:
# Number of model events: `dat` holds 2*k mutation columns (P/M pairs) followed
# by one Seeding column, so n = k mutation events + 1 Seeding event.
n = (dat.shape[1] - 1) // 2 + 1
assert n == len(events), "event labels do not match the data matrix layout"

# Initial log-rates for the two diagnosis parameters appended after the n x n matrix.
# 391 days is the observed mean time to second diagnosis; the factor 30 presumably
# sets the time unit (~one month) — TODO confirm.
lam1_start = np.log(30 / 391)
lam2_start = lam1_start
indep = utils.indep(dat)
# Shift the diagonal (presumably the per-event base rates) by the diagnosis
# log-rate, assuming diagnosis and progression rates are on the same scale.
indep = indep.at[np.diag_indices(n)].add(lam2_start)
# Flatten the n x n matrix and append the two diagnosis log-rates.
start_params = np.append(indep, [lam1_start, lam2_start])

In [None]:
# Sanity check: evaluate the objective (the function minimised below) at the
# starting parameters with penalty 0.0. No `weights` argument is passed here,
# unlike in the optimisation call — presumably log_lik has a default; TODO confirm.
reg_opt.log_lik(start_params, dat_prim_met, dat_prim_nomet, dat_coupled, dat_met_only, 0.0)

In [None]:
# Sanity check: gradient of the objective (used as `jac` in the optimisation)
# at the starting parameters with penalty 0.0 and no explicit weights.
res = reg_opt.grad(start_params, dat_prim_met, dat_prim_nomet, dat_coupled, dat_met_only, 0.0)

In [None]:
# Weights forwarded to log_lik: a primary-tumour share and a metastasis share
# split evenly three ways. 36 + 41 = 77 — presumably sample counts; TODO document
# their source. How log_lik maps the four weights onto the datasets is not visible here.
w_prim = 36 / 77
w_met = 41 / 77
weights = np.array([w_prim, w_met / 3, w_met / 3, w_met / 3])
# Regularisation strength; also used in the output file name.
penal = 0.01
x = opt.minimize(
    reg_opt.log_lik,
    x0=start_params,
    args=(dat_prim_met, dat_prim_nomet, dat_coupled, dat_met_only, penal, weights),
    method="L-BFGS-B",
    jac=reg_opt.grad,
    options={"maxiter": 100, "disp": True, "ftol": 1e-06},
)
# With only 100 iterations L-BFGS-B may stop before converging; make that visible
# so the exported parameters are not mistaken for a converged fit.
if not x.success:
    warnings.warn(f"L-BFGS-B did not converge: {x.message}")

In [None]:
# Inverse of the fitted lam1 rate (the first of the two log-rates appended after
# the n x n matrix), i.e. its implied mean waiting time in the model's time unit.
print(1 / jnp.exp(x.x[-2]))
# The first n*n parameters are the fitted n x n matrix, labelled by event on both axes.
df = pd.DataFrame(x.x[:-2].reshape((n, n)), columns=events, index=events)
results_dir = Path("../results")
# to_csv does not create missing directories, so make sure the target exists.
results_dir.mkdir(parents=True, exist_ok=True)
# NOTE(review): the "prad" prefix does not match the PAAD/PANET input data — confirm the intended name.
df.to_csv(results_dir / f"prad_{penal}.csv")