# Survival Analysis: A Simple Example

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from survival import MIRModel

## Load data

In [None]:
# Read the SEER thyroid survival table; the relative path is resolved
# against the current working directory.
df = pd.read_csv(
    filepath_or_buffer='../data/SEER_thyroid_survival_data_20200228.csv'
)

In [None]:
# Keep only the rows and columns the fit needs:
#   rows    -> 'Summary Interval' equal to 0 and a non-zero 'Year'
#   columns -> the four listed below, with 'b' renamed to 'other_mortality'
selected_columns = ['Year', 'Deaths', 'Cases', 'b']
df_sub = (
    df.loc[(df['Summary Interval'] == 0) & (df['Year'] != 0), selected_columns]
    .rename(columns={'b': 'other_mortality'})
)

## Compute survival rate

In [None]:
# Forwarded to MIRModel as its `disease_period` argument
# (presumably measured in years -- confirm against the model's docs).
disease_period: int = 7

# Horizon passed to `get_survival_rate`; the comparison cell also matches it
# against the 'interval' column of the raw data.
num_years: int = 5

In [None]:
# Mortality-to-incidence ratio: deaths divided by cases, row by row.
mi_ratio = df_sub['Deaths'] / df_sub['Cases']

model = MIRModel(
    mi_ratio,
    df_sub['other_mortality'],
    disease_period=disease_period,
)

# Must run before reading `model.excess_mortality` in the next cell
# (presumably this call is what populates it -- confirm in MIRModel).
model.compute_excess_mortality()
survival_rate = model.get_survival_rate(num_years=num_years)

In [None]:
# Attach the model outputs to the fitting table as new columns.
df_sub['excess_mortality'] = model.excess_mortality

# `survival_rate` is indexed by 'abs' and 'rel'; each entry becomes its own
# '<kind>_survival_rate' column, in that order.
for rate_kind in ('abs', 'rel'):
    df_sub[f'{rate_kind}_survival_rate'] = survival_rate[rate_kind]

# Order rows by year, in place, after the columns have been attached.
df_sub.sort_values(by='Year', inplace=True)

In [None]:
# Preview the first five rows of the augmented table.
df_sub.head(n=5)

## Visualize result

In [None]:
# Subset the data for comparison: rows whose 'interval' equals `num_years`
# and whose 'Year' is non-zero, keeping only 'Year' and 'Observed'.
# NOTE(review): the fitting cell filters on 'Summary Interval' while this one
# filters on 'interval' -- confirm both columns exist in the CSV.
# `.copy()` makes this an independent frame so the in-place edits below do not
# act on a view of `df`.
df_compare = df.loc[
    (df['interval'] == num_years) & (df['Year'] != 0),
    ['Year', 'Observed'],
].copy()

# '.' is treated as a missing observed value. Blank out only the 'Observed'
# cell: assigning through a bare boolean mask would overwrite every column of
# the matching rows, wiping 'Year' too and pushing those rows to the end of
# the sort below.
df_compare.loc[df_compare['Observed'] == '.', 'Observed'] = np.nan
df_compare.sort_values('Year', inplace=True)

In [None]:
fig, ax = plt.subplots(1, figsize=(10, 10))

# 'Observed' is divided by 100 to put it on the same scale as the predicted
# rate (presumably it is stored as a percentage -- confirm against the data).
# Builtin `float` replaces `np.float`, an alias NumPy removed in 1.24.
true = df_compare['Observed'].to_numpy().astype(float) / 100.0
predicted = df_sub['abs_survival_rate'].to_numpy()

# Points are paired purely by position, so `df_compare` and `df_sub` must both
# already be sorted by 'Year' and cover the same years.
ax.scatter(true, predicted, marker='.')
for year, x_val, y_val in zip(df_compare['Year'].to_numpy(), true, predicted):
    ax.annotate(year, (x_val, y_val), fontsize=7)

# plot settings
ax.set_xlabel(f'true {num_years} year survival rate')
ax.set_ylabel(f'predicted {num_years} year survival rate from M/I')
ax.set_aspect('equal')

# Use one NaN-ignoring range for both axes, padded by 10% on each side, so the
# identity line runs corner to corner of the data.
r_min = min(np.nanmin(true), np.nanmin(predicted))
r_max = max(np.nanmax(true), np.nanmax(predicted))
r_len = r_max - r_min
pad = r_len * 0.1
ax.set_xlim(r_min - pad, r_max + pad)
ax.set_ylim(r_min - pad, r_max + pad)

# Identity line: points lying on it have predicted == true.
ax.plot([r_min, r_max], [r_min, r_max], '--k', linewidth=0.7)