In [None]:
import pandas as pd
import scipy.optimize as opt
import regularized_optimization as reg_opt
import Utilityfunctions as utils
import jax.numpy as jnp
import numpy as np
import jax as jax
import matplotlib.pyplot as plt
jax.config.update("jax_enable_x64", True)

In [None]:
# Load the per-sample event table and attach each sample's metastasis status
# from the sample annotation.
mut_handle = "../data/G12_PAADPANET_PM_z10_EventsAEP.csv"
#mut_handle = "../data/G12_PAADPANET_PM_z10_Events_Mut50_OnlyPaired.csv"
annot_handle = "../data/sampleSelection.txt"
annot_data = pd.read_csv(annot_handle, sep="\t")
# The unnamed first column of the event table is the patient ID.
mut_data = pd.read_csv(mut_handle).rename(columns={"Unnamed: 0": "patientID"})
# Inner join (pandas default): samples missing from the annotation are dropped.
dat = pd.merge(mut_data, annot_data.loc[:, ["patientID", "metaStatus"]], on="patientID")
# Keep only samples with at least one observed event. The event columns are all
# columns except patientID (first) and the last three. Rows that are entirely NaN,
# or whose event columns sum to zero or less, are dropped.
event_cols = dat.iloc[:, 1:-3]
has_data = event_cols.notna().any(axis=1)
has_event = event_cols.sum(axis=1) > 0
dat = dat.loc[has_data & has_event, :]
dat2 = dat.copy()
dat

In [None]:
# "Seeding" indicator, as an int64 column:
#   paired samples                                         -> 1
#   unpaired samples with metaStatus "present"/"isMetastasis" -> 1
#   everything else (including unpaired "absent")          -> 0
seeded = dat2["paired"].eq(True) | (
    dat2["paired"].eq(False) & dat2["metaStatus"].isin(["present", "isMetastasis"])
)
dat2["Seeding"] = seeded.astype("int64")

In [None]:
def co_oc_mat(dat, events):
    """Relative pairwise co-occurrence frequencies of events.

    Parameters
    ----------
    dat : pandas.DataFrame
        One row per sample. An event counts as present in a row when its
        column equals 1; NaN and any other value count as absent.
    events : sequence of str
        Column names of ``dat`` to compare (numpy array or list).

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(len(events), len(events))``. For ``j <= i``,
        entry ``[i, j]`` is the fraction of rows of ``dat`` in which both
        ``events[i]`` and ``events[j]`` equal 1. The diagonal is therefore each
        event's own frequency. Entries above the diagonal are NaN, so they can
        be masked in plots. If ``dat`` has no rows, every entry is NaN.
    """
    events = list(events)
    n_events = len(events)
    co_oc = np.full((n_events, n_events), np.nan)
    n_rows = dat.shape[0]
    if n_rows == 0:
        # No samples: frequencies are undefined (avoids a 0/0 warning).
        return co_oc
    # 0/1 indicator matrix (rows x events). present.T @ present counts, for each
    # pair of events, the rows in which both are present.
    present = (dat[events] == 1).to_numpy(dtype=float)
    counts = present.T @ present
    lower = np.tril_indices(n_events)
    co_oc[lower] = counts[lower] / n_rows
    return co_oc

In [None]:
# Event labels. Every second column starting at position 1 is a primary-tumour
# event and every second column starting at position 2 is a metastasis event;
# the last five columns are excluded. "Seeding" is appended to both lists.
prim_cols = dat2.columns[1:-5:2].to_numpy()
met_cols = dat2.columns[2:-5:2].to_numpy()
events_prim = np.append(prim_cols, ["Seeding"])
events_met = np.append(met_cols, ["Seeding"])

# Sample groups by metastasis status.
dat_ps = dat2.loc[dat2.metaStatus == "absent", :]        # primaries without metastasis
dat_ps_ms = dat2.loc[dat2.metaStatus == "present", :]    # primaries with metastasis
dat_mt = dat2.loc[dat2.metaStatus == "isMetastasis", :]  # metastasis samples

# Lower-triangular co-occurrence frequencies per group.
co_oc_ps = co_oc_mat(dat_ps, events_prim)
co_oc_ms = co_oc_mat(dat_ps_ms, events_prim)
co_oc_mt = co_oc_mat(dat_mt, events_met)
co_oc_ms.shape

In [None]:
# Metastases: pairwise co-occurrence frequencies in metastasis samples.
# Only the lower triangle is filled; co_oc_mat leaves the upper triangle NaN,
# and NaN cells are drawn red via the colormap's "bad" colour.
# The bad colour is set on an explicit copy of the default colormap that is
# passed to matshow. Calling set_bad on plt.cm.get_cmap() after plotting does
# not reliably reach the image, and matplotlib.cm.get_cmap was removed in 3.9.
cmap = plt.get_cmap().with_extremes(bad="red")
f, ax = plt.subplots(figsize=(19, 15))
ax.matshow(co_oc_mt, cmap=cmap)
ax.set_xticks(range(co_oc_mt.shape[1]), events_met, fontsize=14, rotation=90)
ax.set_yticks(range(co_oc_mt.shape[1]), events_met, fontsize=14)
ax.set_title("Co-occurrence frequencies: metastases", fontsize=16)
for i in range(events_met.shape[0]):
    for j in range(i + 1):
        ax.text(j, i, str(np.round(co_oc_mt[i, j], 2)), va="center", ha="center")
plt.show()


In [None]:
# Primaries with metastases (metaStatus "present"): pairwise co-occurrence
# frequencies. Only the lower triangle is filled; co_oc_mat leaves the upper
# triangle NaN, and NaN cells are drawn red via the colormap's "bad" colour.
# The bad colour is set on an explicit copy of the default colormap that is
# passed to matshow. Calling set_bad on plt.cm.get_cmap() after plotting does
# not reliably reach the image, and matplotlib.cm.get_cmap was removed in 3.9.
cmap = plt.get_cmap().with_extremes(bad="red")
f, ax = plt.subplots(figsize=(19, 15))
ax.matshow(co_oc_ms, cmap=cmap)
ax.set_xticks(range(co_oc_ms.shape[1]), events_prim, fontsize=14, rotation=90)
ax.set_yticks(range(co_oc_ms.shape[1]), events_prim, fontsize=14)
ax.set_title("Co-occurrence frequencies: primaries with metastases", fontsize=16)
for i in range(events_prim.shape[0]):
    for j in range(i + 1):
        ax.text(j, i, str(np.round(co_oc_ms[i, j], 2)), va="center", ha="center")
plt.show()

In [None]:
# Primaries without metastases (metaStatus "absent"): pairwise co-occurrence
# frequencies. Only the lower triangle is filled; co_oc_mat leaves the upper
# triangle NaN, and NaN cells are drawn red via the colormap's "bad" colour.
# The bad colour is set on an explicit copy of the default colormap that is
# passed to matshow. Calling set_bad on plt.cm.get_cmap() after plotting does
# not reliably reach the image, and matplotlib.cm.get_cmap was removed in 3.9.
cmap = plt.get_cmap().with_extremes(bad="red")
f, ax = plt.subplots(figsize=(19, 15))
ax.matshow(co_oc_ps, cmap=cmap)
ax.set_xticks(range(co_oc_ps.shape[1]), events_prim, fontsize=14, rotation=90)
ax.set_yticks(range(co_oc_ps.shape[1]), events_prim, fontsize=14)
ax.set_title("Co-occurrence frequencies: primaries without metastases", fontsize=16)
for i in range(events_prim.shape[0]):
    for j in range(i + 1):
        ax.text(j, i, str(np.round(co_oc_ps[i, j], 2)), va="center", ha="center")
plt.show()

In [None]:
# Model input: one integer matrix per sample group. Each matrix has the columns
# 'P.Mut.KRAS' .. 'M.Mut.KMT2D' plus a constant "Seeding" column, and the
# groups are stacked into `dat`.
# `dat` is rebound to the stacked jax array at the end of this cell. Keep the
# sample DataFrame under its own name so that re-running this cell still works.
if isinstance(dat, pd.DataFrame):
    dat_df = dat


def _group_matrix(mask, seeding):
    """Rows of dat_df selected by `mask` (event columns + Seeding) as a jax int array."""
    # NOTE(review): to_numpy(dtype=int) is not guarded against NaN in the
    # selected columns - confirm these groups contain no missing values.
    block = dat_df.loc[mask, 'P.Mut.KRAS':'M.Mut.KMT2D'].assign(Seeding=seeding)
    return jnp.array(block.to_numpy(dtype=int))


paired_col = dat_df.paired
meta_col = dat_df.metaStatus

# paired (coupled) samples
dat_coupled = _group_matrix(paired_col.eq(True), 1)
# unpaired primary, no metastasis generated (metaStatus "absent")
dat_prim_nomet = _group_matrix(paired_col.eq(False) & meta_col.eq("absent"), 0)
# unpaired primary, metastasis present but not sequenced (metaStatus "present")
dat_prim_met = _group_matrix(paired_col.eq(False) & meta_col.eq("present"), 1)
# unpaired metastasis-only sample (metaStatus "isMetastasis")
dat_met_only = _group_matrix(paired_col.eq(False) & meta_col.eq("isMetastasis"), 1)

events = list(dat_df.columns[1:-4:2])
events.append("Seeding")
dat = jnp.vstack((dat_prim_nomet, dat_prim_met, dat_coupled, dat_met_only))
print(dat_prim_nomet.shape[0], dat_prim_met.shape[0], dat_coupled.shape[0], dat_met_only.shape[0])

In [None]:
# Starting point for the optimiser.
# n = number of model events: half the non-"Seeding" columns of dat, plus one
# for "Seeding".
n = (dat.shape[1] - 1) // 2 + 1

# Both extra rate parameters start at log(30/391). The observed mean time to
# second diagnosis is 391 days.
# NOTE(review): the meaning/unit of the factor 30 is not stated here - confirm.
lam1_start = lam2_start = np.log(30 / 391)

# Base matrix from utils.indep (presumably an independence-model estimate -
# confirm), with lam2_start added on its diagonal. This assumes diagnosis and
# progression rates are on the same scale.
diag_idx = np.diag_indices(n)
indep = utils.indep(dat).at[diag_idx].add(lam2_start)
start_params = np.append(indep, [lam1_start, lam2_start])

In [None]:
# Objective at the starting parameters (penalty 0.01, m_p_corr 0.8).
# Data arguments follow the order used in the opt.minimize call:
# prim_nomet, prim_met, coupled, met_only.
# NOTE(review): confirm this order against reg_opt.log_lik's signature.
reg_opt.log_lik(start_params, dat_prim_nomet, dat_prim_met, dat_coupled, dat_met_only, 0.01, 0.8)

In [None]:
# Gradient at the starting parameters (penalty 0.01, m_p_corr 0.8).
# Data arguments follow the order used in the opt.minimize call:
# prim_nomet, prim_met, coupled, met_only.
# NOTE(review): confirm this order against reg_opt.grad's signature.
res = reg_opt.grad(start_params, dat_prim_nomet, dat_prim_met, dat_coupled, dat_met_only, 0.01, 0.8)
res

In [None]:
# Penalised fit with L-BFGS-B: minimises reg_opt.log_lik, with reg_opt.grad as
# the gradient (jac), starting from start_params.
penal = 0.012     # penalty weight, passed through to reg_opt (semantics defined there)
m_p_corr = 0.8    # passed through to reg_opt (semantics defined there)
fit_args = (dat_prim_nomet, dat_prim_met, dat_coupled, dat_met_only, penal, m_p_corr)
lbfgs_options = {"maxiter": 10000, "disp": True, "ftol": 1e-05}
x = opt.minimize(
    reg_opt.log_lik,
    x0=start_params,
    args=fit_args,
    method="L-BFGS-B",
    jac=reg_opt.grad,
    options=lbfgs_options,
)

In [None]:
# Fitted parameters: x.x[:-2] is the flattened (events x events) theta matrix.
# x.x[-2] and x.x[-1] are the parameters started at lam1_start and lam2_start.
print(1/jnp.exp(x.x[-2]))  # 1 / exp(lam1 parameter)
n_events = len(events)
df2 = pd.DataFrame(x.x[:-2].reshape((n_events, n_events)), columns=events, index=events)
# Derive the file tag from the fit settings, so it follows penal / m_p_corr
# instead of going stale: 0.8, 0.012 -> "aep_08_0012".
run_tag = "aep_" + str(m_p_corr).replace(".", "") + "_" + str(penal).replace(".", "")
# NOTE(review): the "prad_" prefix does not match the PAAD/PANET input data - confirm the intended name.
df2.to_csv("../results/prad_" + run_tag + ".csv")
df = df2.copy()
df2.round(3)

In [None]:
# Heatmap of the fitted theta matrix: off-diagonal entries on the left, the
# diagonal as a separate column on the right. Entries with |theta| < 0.1 are
# hidden (NaN), drawn black and left unlabelled.
# Built from the unmasked df2 rather than by mutating df in place, so re-running
# this cell reproduces the same figure. df ends up holding the masked
# off-diagonal table, as before.
masked = df2.where(df2.abs() >= 0.1)
theta_diag = np.diag(masked.to_numpy()).reshape((-1, 1))
off_diag = masked.to_numpy(copy=True)
np.fill_diagonal(off_diag, np.nan)
df = pd.DataFrame(off_diag, index=masked.index, columns=masked.columns)

plt.style.use("default")
# The bad colour must be set on the colormap actually passed to matshow.
cmap = plt.get_cmap("coolwarm").with_extremes(bad="black")
n_rows = df.shape[0]

f, (ax, ax2) = plt.subplots(1, 2, figsize=(19, 15), gridspec_kw={'width_ratios': [6, 1]})
ax.matshow(df, cmap=cmap)
ax2.matshow(theta_diag, cmap=cmap)
ax.set_xticks(range(df.shape[1]), events, fontsize=14, rotation=90)
ax.set_yticks(range(n_rows), events, fontsize=14)
ax2.set_yticks(range(n_rows), events, fontsize=14)
ax2.yaxis.tick_right()
ax2.yaxis.set_label_position("right")
ax2.set_xticks([])

for i in range(n_rows):
    for j in range(df.shape[1]):
        val = df.iloc[i, j]
        label = "" if np.isnan(val) else str(np.round(val, 2))
        ax.text(j, i, label, va='center', ha='center')
    diag_val = theta_diag[i, 0]
    diag_label = "" if np.isnan(diag_val) else str(np.round(diag_val, 3))
    ax2.text(0, i, diag_label, va='center', ha='center')
# Lay out after all artists and tick labels exist.
f.tight_layout()
plt.show()