# Survival analysis with *lifelines*
[https://lifelines.readthedocs.io/](https://lifelines.readthedocs.io/en/latest/Survival%20Regression.html)

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
# Switch the working directory to the lung dataset folder; the relative file
# names passed to pd.read_csv in the cells below are resolved against it.
os.chdir("/home/jovyan/work/phd/datasets/cancers/lung")

In [None]:
# Load the per-file metadata table, index it by file name and keep only the
# columns that have at least 5 non-missing values.
# `how` is intentionally not passed: pandas >= 2.0 raises a TypeError when both
# `how` and `thresh` are given, and older versions ignored `how` whenever
# `thresh` was set, so the selected columns are unchanged.
df_files = pd.read_csv("files.txt", sep="\t").set_index("file_name").dropna(thresh=5, axis=1)

#df_files = pd.read_csv("files.dat", sep=",",index_col=0).dropna(how="all", thresh=100, axis=1)
#df_files = df_files[df_files["dataset"]=="tcga"]
df_files.info()

In [None]:
# List the distinct values found in the last-known-disease-status column.
df_files["cases.0.diagnoses.0.last_known_disease_status"].unique()

In [None]:
# Discretise age at diagnosis: missing ages are replaced by the column mean,
# the value is divided by 365 (presumably days -> years; confirm against the
# data dictionary) and each row is labelled with the centre of its bin.
bins = np.linspace(0, 100, 20)
bin_centers = (bins[1:] + bins[:-1]) / 2
age_raw = df_files["cases.0.diagnoses.0.age_at_diagnosis"]
avg = age_raw.mean(skipna=True)
age_binned = pd.cut(age_raw.fillna(avg) / 365., bins=bins, labels=bin_centers)
# Store plain label values rather than a categorical column.
df_files["age_at_diagnosis"] = np.asarray(age_binned)

In [None]:
# Binary covariate: 1 when the recorded gender is "male", 0 for anything else
# (including missing values).
is_male = df_files["cases.0.demographic.gender"].eq("male")
df_files["gender"] = is_male.astype(int)

In [None]:
# Collapse the "a"/"b" sub-stages onto their parent stage (e.g. "stage iia" ->
# "stage ii"). The replacement is applied to every column of the frame, not
# only to the tumor-stage column.
base_stages = ["stage i", "stage ii", "stage iii", "stage iv"]
for letter in ("a", "b"):
    for base in base_stages:
        df_files.replace(base + letter, base, inplace=True)
# Show which stage labels remain after the collapse.
df_files["cases.0.diagnoses.0.tumor_stage"].unique()

In [None]:
# Encode the tumor stage as an integer covariate ("stage i" -> 1, ...).
# Values that are not in the mapping (e.g. "not reported", NaN) are left as-is.
# The result is assigned back to the frame instead of calling
# `replace(..., inplace=True)` on `df_files["tumor_stage"]`: that chained
# in-place call operates on a temporary object under pandas Copy-on-Write and
# would leave the column unchanged.
stage_codes = {
    stage: code
    for code, stage in enumerate(
        ["stage i", "stage ii", "stage iii", "stage iv", "stage v"], start=1
    )
}
df_files["tumor_stage"] = df_files["cases.0.diagnoses.0.tumor_stage"].replace(stage_codes)

In [None]:
# Event indicator: 1 = "Dead" (death observed), 0 = any other recorded status.
df_files["vital_status"]=(df_files["cases.0.demographic.vital_status"]=="Dead").astype(int)

In [None]:
# Quick look at the data: the distinct disease-status values, then a histogram
# of age at diagnosis divided by 365.
status_values = df_files["cases.0.diagnoses.0.last_known_disease_status"].unique()
print(status_values)
age_over_365 = df_files["cases.0.diagnoses.0.age_at_diagnosis"] / 365
age_over_365.hist()

In [None]:
def get_survival(case):
    """Return the survival / follow-up time recorded for one case.

    Parameters
    ----------
    case : mapping (e.g. a DataFrame row)
        Must provide the keys ``cases.0.demographic.vital_status``,
        ``cases.0.demographic.days_to_death`` and
        ``cases.0.diagnoses.0.days_to_last_follow_up``.

    Returns
    -------
    The ``days_to_death`` value when the case is dead, otherwise the
    ``days_to_last_follow_up`` value.
    """
    status = case["cases.0.demographic.vital_status"]
    # The raw column holds the string "Dead" (see how the integer
    # `vital_status` flag is derived from it); comparing it to 1 alone never
    # matched, so deaths were reported with their follow-up time. 1 is still
    # accepted for callers that pass an already-encoded flag.
    if status in ("Dead", 1):
        return case["cases.0.demographic.days_to_death"]
    return case["cases.0.diagnoses.0.days_to_last_follow_up"]

# Evaluate get_survival on every row (axis=1 passes rows, not columns).
df_files["days_survival"] = df_files.apply(get_survival, axis=1)

In [None]:
# Discretise the years smoked into bins and label every row with the centre of
# its bin; rows without a bin (missing or out-of-range values) get 0.
bins = np.linspace(0, 100, 10)
bin_centers = (bins[1:] + bins[:-1]) / 2
years_binned = pd.cut(df_files["cases.0.exposures.0.years_smoked"], bins=bins, labels=bin_centers)
# Fill the gaps before assigning: `df_files["smoke"].fillna(0, inplace=True)`
# is a chained in-place call that has no effect under pandas Copy-on-Write.
df_files["smoke"] = pd.Series(np.asarray(years_binned), index=df_files.index).fillna(0)

In [None]:
# Histogram of the survival / follow-up time divided by 365
# (presumably days -> years; confirm the unit of the source columns).
(df_files["days_survival"]/365).hist()

In [None]:
# Random sample of 80 cases that have a survival time, used for the lifetimes plot.
subset = df_files.dropna(subset=["days_survival"]).sample(n=80)

In [None]:
from lifelines.plotting import plot_lifetimes

# Observation cut-off, in the same unit as the lifetimes below (days / 365).
CURRENT_TIME = 5

actual_lifetimes = subset["days_survival"].to_numpy()/365
# Lifetimes longer than the cut-off are truncated at the cut-off.
observed_lifetimes = np.minimum(actual_lifetimes, CURRENT_TIME)
# A death counts as observed only if the case is actually recorded as dead AND
# it happened before the cut-off. Using the duration alone would draw every
# censored case with a short follow-up as a death.
died = subset["vital_status"].to_numpy() == 1
death_observed = died & (actual_lifetimes < CURRENT_TIME)

ax = plot_lifetimes(observed_lifetimes, event_observed=death_observed, figsize=(18,15))

ax.set_xlim(0, CURRENT_TIME*1.1)
# Draw the cut-off marker across all plotted subjects rather than a fixed 30.
ax.vlines(CURRENT_TIME, 0, len(observed_lifetimes), lw=2, linestyles='--')
ax.set_xlabel("time (years)", fontsize=35)
ax.tick_params(labelsize=35)
# Report the cut-off that is actually used instead of a hard-coded value.
ax.set_title("Births and deaths of our population, at $t=%d$" % CURRENT_TIME, fontsize=35)
plt.tight_layout()
#print("Observed lifetimes at time %d:\n" % (CURRENT_TIME), observed_lifetimes)

In [None]:
# All cases with a known survival time, plus the three series used by the
# survival fits: duration (days / 365), the death indicator, and an "entry"
# value derived from days_to_birth.
subset = df_files[df_files["days_survival"].notna()]
data = {
    "duration": subset["days_survival"] / 365,
    "observed": subset["vital_status"],
    "entry": subset["cases.0.demographic.days_to_birth"] / 365,
}

In [None]:
from lifelines import KaplanMeierFitter
# Single Kaplan-Meier estimator instance; the cells below call fit() on it
# again for every group they plot.
kmf = KaplanMeierFitter()

In [None]:
# Number of rows flagged as dead (vital_status == 1).
df_files["vital_status"].sum()

In [None]:
# Short aliases: T = durations, E = event indicator, entry = entry values.
T, E, entry = (data[name] for name in ("duration", "observed", "entry"))

In [None]:
# Kaplan-Meier estimate over all cases, plotted for the first 15 time units.
label_size = 35
ax = kmf.fit(T, event_observed=E).plot()

ax.set_title('Survival function of Lung', fontsize=label_size)
for set_axis_label, text in ((ax.set_xlabel, "time (years)"), (ax.set_ylabel, "Survival(t)")):
    set_axis_label(text, fontsize=label_size)
ax.tick_params(labelsize=30)
ax.set_xlim(0,15)

In [None]:
from lifelines.utils import median_survival_times
# Median survival time of the last fit, printed together with the medians
# computed from its confidence-interval curves.
median_time = kmf.median_survival_time_
median_ci = median_survival_times(kmf.confidence_interval_)
print("{} +- {}".format(median_time, median_ci))

In [None]:
# Kaplan-Meier curves for male and female cases on a shared axis.
ax = plt.subplot(111)

gender_values = subset["cases.0.demographic.gender"]
mask = (gender_values=="male")
# Select females explicitly: `~mask` would also put cases with a missing or
# different gender value into the curve labelled "female".
female_mask = (gender_values=="female")

kmf.fit(T[mask], event_observed=E[mask], label="male")
kmf.plot(ax=ax)

kmf.fit(T[female_mask], event_observed=E[female_mask], label="female")
kmf.plot(ax=ax)

plt.ylim(0, 1)
plt.xlim(0,15)
# The two curves are split by gender, not by cancer type.
plt.title("Survival by gender")

In [None]:
# One Kaplan-Meier panel per distinct value of `key`.
key = 'cases.0.demographic.gender'
# Drop missing values first: np.sort cannot order NaN against strings.
labels = subset[key].dropna().unique()

# Keep the 3x3 layout, adding rows only when there are more than 9 groups
# (a fixed 3x3 grid rejects a 10th panel).
n_cols = 3
n_rows = max(3, -(-len(labels) // n_cols))

for i, label in enumerate(np.sort(labels)):
    ax = plt.subplot(n_rows, n_cols, i + 1)

    ix = subset[key] == label
    kmf.fit(T[ix], E[ix], label=label)
    kmf.plot(ax=ax, legend=False)

    plt.title(label)
    plt.xlim(0, 15)

    if i==0:
        # Durations are days_survival / 365, so the time axis is in years.
        plt.ylabel('Frac. alive after $n$ years')

plt.tight_layout()

In [None]:
import importlib, survival
importlib.reload(survival)
from survival import fit_cox, add_group_to_subset, save_plot

In [None]:
# Topic-distribution table used by the Cox loop below, indexed by the CSV's
# second column with the "i_doc" column removed.
# Only the LDA table is loaded: the topSBM table used to be read first and was
# immediately overwritten. Swap the commented line in to use topSBM instead.
# `drop(columns=...)` replaces the positional axis argument, which pandas 2.0
# no longer accepts.
df_clusters = pd.read_csv("lda/lda_level_2_topic-dist.csv", index_col=1).drop(columns="i_doc")

#df_clusters = pd.read_csv("topsbm/topsbm_level_1_topic-dist.csv",index_col=1).drop(columns="i_doc")
#df_clusters = df_clusters[df_clusters.index.isin(filter(lambda doc: "GTEX" not in doc,df_clusters.index))]

In [None]:
# Preview the first two rows, including the derived columns added above.
df_files.head(2)

In [None]:
# For each project and each topic column: build the grouped subset, fit a Cox
# model through the project helper, print the last summary row when its
# -log2(p) exceeds 3, and save the returned plot.
# NOTE(review): add_group_to_subset / fit_cox / save_plot come from the local
# `survival` module; their exact contracts are not visible here.
cox_columns = ["days_survival", "vital_status", "gender", "tumor_stage", "age_at_diagnosis"]
has_survival = df_files["days_survival"].notna()
stage_reported = df_files["cases.0.diagnoses.0.tumor_stage"] != "not reported"

for dataset in ("TCGA-LUAD", "TCGA-LUSC"):
    mask = has_survival & stage_reported & (df_files["cases.0.project.project_id"] == dataset)
    subset = df_files.loc[mask, cox_columns]
    for topic in df_clusters.columns:
        top_set = add_group_to_subset(topic, subset, df_clusters, 0.5)
        print(top_set["group"].sum())
        summary, _, ax = fit_cox(top_set, topic)
        if summary is not None:
            last_row = summary.index[-1]
            if summary.at[last_row, "-log2(p)"] > 3:
                print(dataset, ": ", topic, "\n", summary.loc[last_row, ["coef", "p"]], "\n")
        if ax is not None:
            prefixed_title = dataset + " " + ax.title.get_text()
            ax.set_title(prefixed_title, fontsize=35)
            save_plot(ax, dataset, topic)

In [None]:
# Fit and report a single Cox model for "Topic 3".
# NOTE(review): `top_set` is whatever the last iteration of the loop above
# left behind (last dataset, last topic), yet the plot is saved under "all" -
# confirm this is the intended input.
topic_name = "Topic 3"
report_columns = ["coef", "exp(coef)", "p", "-log2(p)", "corrected_p", "-log2(corrected_p)"]
summary, _, ax = fit_cox(top_set, topic_name)
print(summary[report_columns])
save_plot(ax, "all", topic_name)

In [None]:
# Print the first 15 characters of every entry of "Topic 18".
# The path is relative to the lung dataset directory set by os.chdir at the
# top of the notebook, like the other read_csv calls; the previous
# "datasets/cancers/lung/" prefix pointed below that directory.
# Missing cells are dropped because they are floats (NaN) and cannot be sliced.
topic_genes = pd.read_csv("lda/lda_level_2_topics.csv", index_col=1)["Topic 18"].dropna()
for g in topic_genes.values:
    print(g[:15])

In [None]:
# Per column: does any cell equal the gene id exactly?
# The path is relative to the lung dataset directory set by os.chdir at the
# top of the notebook; the previous "datasets/cancers/lung/" prefix pointed
# below that directory.
# NOTE(review): the comparison is an exact string match - if the table stores
# ids with a suffix (another cell trims entries to 15 characters), nothing
# will match; confirm the id format.
(pd.read_csv("topsbm/topsbm_level_1_topics.csv", index_col=1) == "ENSG00000121552").any()