In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import random
random.seed(1100038344)
import survivalstan
import numpy as np
import pandas as pd
from stancache import stancache
from matplotlib import pyplot as plt

In [None]:
# Stan program text bundled with survivalstan. Judging by the name, this is a
# piecewise-exponential (PEM) survival model with a random-walk baseline -- confirm.
model_code = survivalstan.models.pem_survival_model_randomwalk

In [None]:
# Show the Stan source. print() is used deliberately: assuming model_code is a
# plain string, this renders its newlines instead of a quoted repr.
print(model_code)

In [None]:
# Simulate 100 subjects, censored at t=20, with an event rate driven by sex
# (rate_form '1 + sex'); presumably exponential event times, per the function name.
#
# The setup cell seeds only Python's `random` module. The simulator presumably
# draws from NumPy (TODO confirm), so NumPy is seeded here as well. This keeps the
# simulated data identical on every Restart & Run All, which presumably also keeps
# the stancache cache key for the fit below stable instead of forcing a refit.
np.random.seed(1100038344)
d = survivalstan.sim.sim_data_exp_correlated(N=100,
                                             censor_time=20,
                                             rate_form='1 + sex',
                                             rate_coefs=[-3, 0.5])
# Mean-centre age so the model intercept refers to a subject of average age.
d['age_centered'] = d['age'] - d['age'].mean()
d.head()

In [None]:
# Observed survival curves in the simulated data, one per sex
# (female drawn first, then male, as in the original cell).
for sex_value in ['female', 'male']:
    survivalstan.utils.plot_observed_survival(df=d[d['sex'] == sex_value],
                                              event_col='event', time_col='t',
                                              label=sex_value)
# Trailing `;` keeps the Legend object's repr (with its memory address) out of the output.
plt.legend();

In [None]:
# Reshape the one-row-per-subject data into the long format the fit expects.
# The fit cell reads its 'end_time' and 'end_failure' columns.
dlong = survivalstan.prep_data_long_surv(
    df=d,
    event_col='event',
    time_col='t',
)

In [None]:
# First five rows of the long-format data.
dlong.head(n=5)

In [None]:
# Sampler settings, named so they are easy to find and tune.
N_ITER = 1000
N_CHAINS = 4
STAN_SEED = 9001

# Fit the model to the long-format data, adjusting for centred age and sex.
# stancache.cached_stan_fit is passed as the fitting function so results can be cached.
testfit = survivalstan.fit_stan_survival_model(
    model_cohort='test model',
    model_code=model_code,
    df=dlong,
    sample_col='index',
    timepoint_end_col='end_time',
    event_col='end_failure',
    formula='~ age_centered + sex',
    iter=N_ITER,
    chains=N_CHAINS,
    seed=STAN_SEED,
    FIT_FUN=stancache.cached_stan_fit,
)


In [None]:
# Posterior summary of Stan's built-in 'lp__' quantity for the fit.
survivalstan.utils.print_stan_summary([testfit], pars='lp__')

In [None]:
# Posterior summary of the 'log_baseline_raw' parameters
# (presumably the log baseline hazard per time interval -- confirm against the model code).
survivalstan.utils.print_stan_summary([testfit], pars='log_baseline_raw')

In [None]:
# Graphical version of the 'log_baseline_raw' summary from the previous cell.
survivalstan.utils.plot_stan_summary([testfit], pars='log_baseline_raw')

In [None]:
# Plot the fit's estimates for the 'baseline' element.
survivalstan.utils.plot_coefs([testfit], element='baseline')

In [None]:
# Plot estimates for the default element. Presumably these are the regression
# coefficients from the formula '~ age_centered + sex'; confirm.
survivalstan.utils.plot_coefs([testfit])

In [None]:
# Posterior-predictive survival for the whole sample (fill=False; presumably this
# turns off a shaded interval band), overlaid with the observed survival in green.
survivalstan.utils.plot_pp_survival([testfit], fill=False)
survivalstan.utils.plot_observed_survival(df=d, event_col='event', time_col='t', color='green', label='observed')
# Trailing `;` keeps the Legend object's repr out of the cell output.
plt.legend();

In [None]:
# Posterior-predictive survival data split by sex. It must contain a 'sex'
# column, because the custom plot below filters on it.
ppsurv = survivalstan.utils.prep_pp_survival_data([testfit], by='sex')

In [None]:
# First five rows of the posterior-predictive survival data.
ppsurv.head(n=5)

In [None]:
# Posterior-predictive vs. observed survival by sex, on one shared set of axes.
# A single sex -> colour mapping keeps each sex's predicted and observed curves
# in matching colours, instead of repeating the colour in four separate calls.
SEX_COLORS = {'male': 'blue', 'female': 'red'}

# plt.subplots returns a (fig, ax) tuple; it is passed as-is to the helper via `subplot=`.
subplot = plt.subplots(1, 1)
for sex_value in ['male', 'female']:
    survivalstan.utils._plot_pp_survival_data(ppsurv[ppsurv['sex'] == sex_value].copy(),
                                              subplot=subplot,
                                              color=SEX_COLORS[sex_value], alpha=0.5)
# plot_observed_survival is given no axes here. Presumably it draws on the current
# pyplot axes, i.e. the ones created above. Female is drawn first to keep the
# original legend order.
for sex_value in ['female', 'male']:
    survivalstan.utils.plot_observed_survival(df=d[d['sex'] == sex_value], event_col='event', time_col='t',
                                              color=SEX_COLORS[sex_value], label=sex_value)
# Trailing `;` keeps the Legend object's repr out of the cell output.
plt.legend();

In [None]:
# Library's built-in grouped posterior-predictive plot; compare with the hand-built figure above.
survivalstan.utils.plot_pp_survival([testfit], by='sex')

In [None]:
# Same plot with an explicit palette. NOTE(review): this matches the female=red /
# male=blue colouring above only if groups are drawn in sorted order
# (female, male) -- confirm.
survivalstan.utils.plot_pp_survival([testfit], by='sex', pal=['red', 'blue'])