# Survival analysis

In [None]:
from typing import List, Iterable
from functools import partial
from collections import Counter

In [None]:
import sys
sys.path.append('./lib')

In [None]:
import numpy as np
import pandas as pd

In [None]:
import matplotlib.pyplot as plt
from IPython.core.pylabtools import figsize
import seaborn as sns

In [None]:
sns.set_theme()
figsize(11, 6)

In [None]:
import nsfg
import compstats
from cdf import Cdf, resample_rows_weighted

In [None]:
r3 = partial(np.round, decimals=3)
r2 = partial(np.round, decimals=2)

## Survival Curves

The fundamental concept in survival analysis is the survival curve, S(t), which is a function that maps from a duration, t, to the probability of surviving longer than t. If you know the distribution of durations, or “lifetimes”, finding the survival curve is easy; it’s just the complement of the CDF:

$$
S(t) = 1 - CDF(t)
$$

where $CDF(t)$ is the probability that a lifetime is less than or equal to t.
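Here is a minimal sketch of this identity on a small made-up sample of lifetimes (not NSFG data). The empirical CDF at t is the fraction of lifetimes less than or equal to t, and the empirical survival function is the fraction strictly greater than t, so the two sum to 1 at every t. The function names are illustrative only.

```python
import numpy as np

# Hypothetical lifetimes, for illustration only
lifetimes = np.array([3, 5, 5, 7, 9, 9, 9, 12])

def empirical_cdf(t, xs=lifetimes):
    # fraction of lifetimes less than or equal to t
    return np.mean(xs <= t)

def empirical_survival(t, xs=lifetimes):
    # fraction of lifetimes strictly greater than t
    return np.mean(xs > t)

for t in [0, 5, 9, 12]:
    print(t, empirical_cdf(t), empirical_survival(t))
```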

For example, in the NSFG dataset, we know the duration of 9038 complete pregnancies. We can read this data and compute the CDF:

In [None]:
preg = nsfg.read_fem_preg()
complete = preg.query('outcome in [1, 3, 4]').prglngth

The outcome codes 1, 3, 4 indicate live birth, stillbirth, and miscarriage.

For this analysis I am excluding induced abortions, ectopic pregnancies, and pregnancies that were in progress when the respondent was interviewed.

To represent the survival curve, I define an object that wraps a Cdf and adapts the interface:

In [None]:
class SurvivalFunction:
    
    def __init__(self, cdf: Cdf):
        self.cdf = cdf
        
    @property
    def ts(self):
        '''
        The durations
        '''
        return self.cdf.xs
    
    @property
    def ss(self):
        '''
        The probabilities of surviving longer than t
        '''
        return 1 - self.cdf.ps
    
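    def __getitem__(self, t):
        return self.prob(t)

    def prob(self, t):
        '''
        The probability of surviving longer than a single duration t
        '''
        return 1 - self.cdf.prob(t)

    def probs(self, ts: np.ndarray):
        '''
        The probabilities of surviving longer than each duration in ts
        '''
        return 1 - self.cdf.probs(ts)

    def __call__(self, t):
        # accept either a single duration or a sequence of durations
        if isinstance(t, Iterable):
            return self.probs(t)
        return self.prob(t)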
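`SurvivalFunction` relies on four things from `Cdf`: the attributes `xs` and `ps` and the methods `prob` and `probs`. The local `lib/cdf.py` module isn't shown, so the following is only a hypothetical stand-in (`SketchCdf` is not the real class). It assumes `ps` holds the cumulative probabilities $P(T \le x)$ at the sorted values `xs`, and that `prob(t)` is a step function that is 0 below the smallest value and 1 at or above the largest.

```python
import numpy as np

class SketchCdf:
    '''Hypothetical stand-in for lib/cdf.Cdf, for illustration only.'''

    def __init__(self, xs, ps):
        self.xs = np.asarray(xs)  # sorted distinct values
        self.ps = np.asarray(ps)  # cumulative probabilities P(T <= x)

    @classmethod
    def from_values(cls, values):
        xs, counts = np.unique(values, return_counts=True)
        return cls(xs, np.cumsum(counts) / counts.sum())

    def probs(self, ts):
        # index of the last x that is <= t (-1 if t is below every x)
        idx = np.searchsorted(self.xs, np.asarray(ts), side='right') - 1
        # prepend 0 so that idx == -1 maps to probability 0
        return np.concatenate([[0.0], self.ps])[idx + 1]

    def prob(self, t):
        return float(self.probs([t])[0])

# Hypothetical pregnancy lengths in weeks
toy_cdf = SketchCdf.from_values([39, 39, 40, 38, 41, 39, 13])
# survival is the complement of the CDF, as in SurvivalFunction.prob
print(1 - toy_cdf.prob(39))
```

With that interface, `1 - toy_cdf.prob(t)` is exactly what `SurvivalFunction.prob` computes.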

In [None]:
p = sns.histplot(
    x = complete,
    binwidth=1
);
p.set(
    xlabel = 'Pregnancy length (weeks)'
);

In [None]:
cdf = Cdf.from_series(complete)
sf = SurvivalFunction(cdf)

In [None]:
sns.ecdfplot(
    x = complete,
    label='CDF'
);
plt.plot(
    sf.ts,
    sf.ss,
    label='Survival'
);
plt.xlabel('Pregnancy length (in weeks)')
plt.ylabel('Probability')
plt.legend(loc='center right');

The curve is nearly flat between 13 and 26 weeks, which shows that few pregnancies end in the second trimester. The curve is steepest around 39 weeks, the most common pregnancy length.

For example, $sf(13)$ is the fraction of pregnancies that proceed past the first trimester:

In [None]:
r3(sf(13))

In [None]:
r3(cdf(13))

About 85% of pregnancies proceed past the first trimester; about 14% do not.
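Because $S(t)$ is the fraction of lifetimes longer than t, the fraction that end after week a but no later than week b is $S(a) - S(b)$; a flat stretch of the survival curve means this difference is small. A quick check on made-up pregnancy lengths (not the NSFG data):

```python
import numpy as np

# Hypothetical pregnancy lengths in weeks, for illustration only
toy_weeks = np.array([8, 10, 12, 20, 38, 39, 39, 39, 40, 41])

def toy_survival(t):
    # fraction of lengths strictly greater than t
    return np.mean(toy_weeks > t)

# fraction ending after week 13 but no later than week 26
second_trimester = toy_survival(13) - toy_survival(26)
# the same fraction counted directly
direct = np.mean((toy_weeks > 13) & (toy_weeks <= 26))
print(second_trimester, direct)
```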

### Hazard function

From the survival curve we can derive the hazard function. For pregnancy lengths, the hazard function maps from a time, t, to the fraction of pregnancies that continue until t and then end at t, where t is measured in discrete units such as weeks or days. To be more precise:

$$
\lambda(t) = \frac{S(t-1) - S(t)}{S(t-1)}
$$

The numerator, $S(t-1) - S(t)$, is the fraction of lifetimes that end at t, which is also $PMF(t)$. The denominator, $S(t-1)$, is the fraction of lifetimes that continue until t, with $S(t-1)$ taken to be 1 at the first time step.
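The definition can be made concrete with counts: of the lifetimes that continue until t, the hazard is the fraction that end at t. Here is a minimal sketch on made-up pregnancy lengths; `hazard_from_lifetimes` is a hypothetical helper that assumes every lifetime is fully observed (no censoring).

```python
from collections import Counter

def hazard_from_lifetimes(lifetimes):
    '''Hypothetical helper: of the lifetimes that reach t, the fraction that end at t.'''
    counts = Counter(lifetimes)
    at_risk = len(lifetimes)  # every lifetime reaches the first time step
    lams = {}
    for t in sorted(counts):
        lams[t] = counts[t] / at_risk
        at_risk -= counts[t]  # lifetimes that end at t are no longer at risk
    return lams

# Hypothetical lengths in weeks, for illustration only
print(hazard_from_lifetimes([38, 39, 39, 39, 40, 40, 41, 42]))
```

The last observed duration always gets a hazard of 1, since every lifetime that reaches it ends there; with few cases near the end, the estimate can jump around.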

In [None]:
class HazardFunction:
    
    @classmethod
    def from_dict(cls, data: dict, sort=False):
        series = pd.Series(data)
        if sort:
            series.sort_index(inplace=True)
        return cls(series)
    
    def __init__(self, series: pd.Series):
        # assumes the series is indexed by duration, in increasing order
        self.series = series
        
    def __call__(self, t):
        return self.series[t]
    
    @property
    def ts(self):
        '''
        Returns the time durations
        '''
        return self.series.index
    
    @property
    def ss(self):
        '''
        Returns the hazard at each duration
        '''
        return self.series.values
        
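def make_hazard(sf: SurvivalFunction) -> HazardFunction:
    lams = {}
    # S(t-1), the fraction of lifetimes that continue until t;
    # before the first duration, every lifetime is still ongoing
    prev = 1.0
    for t, s in zip(sf.ts, sf.ss):
        # e.g. for t = 39: (S(38) - S(39)) / S(38),
        # the fraction of pregnancies reaching week 39 that end in week 39
        lams[t] = (prev - s) / prev
        prev = s
    return HazardFunction.from_dict(lams)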

In [None]:
hf = make_hazard(sf)

In [None]:
r3(hf(39))

So `hf(39)` is the fraction of pregnancies that proceed until week 39 and then end during week 39.

In [None]:
plt.plot(
    hf.ts,
    hf.ss,
    label='hazard'
);
plt.xlabel('t (weeks)');
plt.ylabel('hazard');
plt.xticks(np.arange(0, 55, 5));
plt.title('Hazard function for pregnancy lengths');

A plot of the hazard function for pregnancy lengths. For times after week 42, the hazard function is erratic because it is based on a small number of cases. Other than that, the shape of the curve is as expected: it is highest around 39 weeks, and a little higher in the first trimester than in the second.

The hazard function is useful in its own right, but it is also an important tool for estimating survival curves.

### Inferring survival curves

If someone gives you the CDF of lifetimes, it is easy to compute the survival and hazard functions. But in many real-world scenarios, we can’t measure the distribution of lifetimes directly. We have to infer it.

For example, suppose you are following a group of patients to see how long they survive after diagnosis. Not all patients are diagnosed on the same day, so at any point in time, some patients have survived longer than others. If some patients have died, we know their survival times.

For patients who are still alive, we don’t know survival times, but we have a lower bound. If we wait until all patients are dead, we can compute the survival curve, but if we are evaluating the effectiveness of a new treatment, we can’t wait that long! We need a way to estimate survival curves using incomplete information.


In [None]:
def estimate_hazard(complete: Iterable[float], ongoing: Iterable[float]) -> HazardFunction:
    '''
    Estimate a hazard function from complete and censored observations
    (the Kaplan-Meier estimate of the hazard)

    complete: durations of cases whose outcome is known
        (e.g. the ages at which respondents got married)
    ongoing: durations of cases still in progress, i.e. censored observations
        (e.g. the ages of unmarried women when they were interviewed)
    '''
    complete = list(complete)
    ongoing = list(ongoing)
    # duration => number of cases that ended at that duration
    # e.g. age => number of respondents who married at that age
    hist_complete = Counter(complete)
    # duration => number of cases censored at that duration
    # e.g. age => number of unmarried respondents interviewed at that age
    hist_ongoing = Counter(ongoing)
    # every duration that appears in either set
    ts = set(hist_complete) | set(hist_ongoing)
    # number of cases still at risk; initially, all of them
    at_risk = len(complete) + len(ongoing)
    # duration => estimated hazard at that duration
    lams = {}
    # loop through durations in increasing order
    for t in sorted(ts):
        # cases that ended at t (e.g. respondents who married at age t)
        ended = hist_complete[t]
        # cases censored at t (e.g. unmarried respondents interviewed at age t)
        censored = hist_ongoing[t]
        # the estimate is the fraction of cases at risk that end at t
        lams[t] = ended / at_risk
        # cases that ended or were censored at t are no longer at risk
        at_risk -= ended + censored
    return HazardFunction.from_dict(lams)
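To see how censored cases enter the estimate, here is a self-contained restatement of the same loop on a tiny made-up example. `km_hazard_sketch` is a hypothetical name; it returns a plain dict instead of a `HazardFunction` so the block runs on its own, and it follows the same convention as `estimate_hazard`: a case censored at t still counts as at risk at t.

```python
from collections import Counter

def km_hazard_sketch(complete, ongoing):
    '''Hypothetical restatement of estimate_hazard that returns a plain dict.'''
    ended_at = Counter(complete)
    censored_at = Counter(ongoing)
    at_risk = len(complete) + len(ongoing)
    lams = {}
    for t in sorted(set(ended_at) | set(censored_at)):
        lams[t] = ended_at[t] / at_risk
        # censored cases count as at risk at t, then leave the risk set
        at_risk -= ended_at[t] + censored_at[t]
    return lams

# Hypothetical data: four known durations, two cases still ongoing
km_complete = [1, 2, 2, 3]
km_ongoing = [2, 3]
toy_lams = km_hazard_sketch(km_complete, km_ongoing)

# survival is the running product of (1 - hazard)
toy_surv = {}
running = 1.0
for t, lam in toy_lams.items():
    running *= 1 - lam
    toy_surv[t] = running
print(toy_lams)
print(toy_surv)
```

Here the hazards are 1/6, 2/5 and 1/2 at t = 1, 2, 3, and the survival estimates are 5/6, 1/2 and 1/4. The ongoing cases never appear in a numerator, but they keep the denominator larger up to the time they were censored, which is how the partial information is used.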

### The marriage curve

In [None]:
resp = nsfg.read_fem_resp().rename(columns={
    'cmmarrhx': 'date_married',
    'cmbirth': 'dob',
    'cmintvw': 'interview_date',
    'evrmarry': 'is_married'
})

In [None]:
resp = resp.assign(
    # respondent's age when married, in years
    agemarry = lambda df: (df.date_married - df.dob) / 12,
    # respondent's age when interviewed, in years
    age = lambda df: (df.interview_date - df.dob) / 12
)
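The `cm*` columns appear to follow the NSFG's century-month convention, with months counted so that January 1900 is month 1, which is why the code divides differences by 12 to get years. Assuming that convention, a hypothetical helper makes the arithmetic explicit:

```python
def to_century_month(year: int, month: int) -> int:
    '''Hypothetical helper: months counted so that January 1900 is month 1.'''
    return 12 * (year - 1900) + month

def age_in_years(event_cm: int, birth_cm: int) -> float:
    # same arithmetic as agemarry and age above: months elapsed / 12
    return (event_cm - birth_cm) / 12

born = to_century_month(1960, 3)
married = to_century_month(1985, 9)
print(age_in_years(married, born))  # 25.5
```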

In [None]:
cols_of_interest = [
    'caseid',
    'date_married',
    'dob',
    'interview_date',
    'is_married',
    'agemarry',
    'age',
    'finalwgt'
]
resp = resp.loc[:, cols_of_interest]

In [None]:
resp.isna().sum()

Next we extract `complete`, which is the age at marriage for women who have been married, and `ongoing`, which is the age at interview for women who have not:

In [None]:
complete = r2(resp[resp.is_married==1].agemarry.dropna())
ongoing = r2(resp[resp.is_married==0].age.dropna())
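The split can be illustrated on a toy frame with the same column names and made-up values. Because `is_married` is the renamed `evrmarry` flag, a 1 means the respondent has ever been married; in this toy frame, the never-married rows have no `agemarry`.

```python
import pandas as pd

# Hypothetical respondents, for illustration only
toy_resp = pd.DataFrame({
    'is_married': [1, 0, 1, 0],
    'agemarry':   [24.5, float('nan'), 31.25, float('nan')],
    'age':        [30.0, 22.75, 40.5, 35.0],
})

# married women contribute a complete duration: their age at marriage
toy_complete = toy_resp[toy_resp.is_married == 1].agemarry.dropna()
# unmarried women contribute a censored duration: their age at interview
toy_ongoing = toy_resp[toy_resp.is_married == 0].age.dropna()
print(toy_complete.tolist(), toy_ongoing.tolist())
```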

Finally, we compute the hazard function:

In [None]:
hf = estimate_hazard(complete, ongoing)

In [None]:
plt.plot(
    hf.ts,
    hf.ss
);
plt.xlabel('age (years)');
plt.ylabel('hazard');
plt.title('Hazard function for age at first marriage');

### Estimating the survival curve

Once we have the hazard function, we can estimate the survival curve. The chance of surviving past time t is the chance of surviving all times up through t, which is the cumulative product of the complementary hazard function:

$$
S(t) = [1-\lambda(0)][1-\lambda(1)]\cdots[1-\lambda(t)]
$$
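The product formula can be checked numerically on a small made-up PMF. This sketch computes the hazard as the fraction of lifetimes still ongoing at the start of t that end at t, multiplies the complements with `np.cumprod`, and compares the result with $S(t) = 1 - CDF(t)$.

```python
import numpy as np

# Hypothetical PMF over durations 0..4, for illustration only
pmf = np.array([0.1, 0.2, 0.3, 0.25, 0.15])

# S(t) = P(T > t) = 1 - CDF(t)
survival = 1 - np.cumsum(pmf)
# fraction still ongoing at the start of t, i.e. S(t-1), with 1 at t = 0
at_start = np.concatenate([[1.0], survival[:-1]])
hazard = pmf / at_start

# S(t) = [1 - lambda(0)][1 - lambda(1)]...[1 - lambda(t)]
rebuilt = np.cumprod(1 - hazard)
print(np.allclose(rebuilt, survival))  # True
```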

In [None]:
def make_survival(hf: HazardFunction) -> SurvivalFunction:
    ts = hf.series.index
    # cumulative product of the complementary hazard function gives S(t)
    ss = (1 - hf.series.values).cumprod()
    # the complement of S(t) is the CDF, which SurvivalFunction wraps
    return SurvivalFunction(Cdf(ts, 1-ss))

In [None]:
sf = make_survival(hf)

In [None]:
plt.plot(
    sf.ts,
    sf.ss
);
plt.xlabel('age (years)');
plt.ylabel('prob unmarried');
plt.yticks(np.linspace(0, 1, 11));
plt.title('Survival curve for age at first marriage');

The survival curve is steepest between 25 and 35, when most women get married. Between 35 and 45, the curve is nearly flat, indicating that women who do not marry before age 35 are unlikely to get married.

## Confidence intervals

Kaplan-Meier analysis, which is what `estimate_hazard` and `make_survival` implement, yields a single estimate of the survival curve, but it is also important to quantify the uncertainty of that estimate. As usual, there are three possible sources of error: measurement error, sampling error, and modeling error.

In this example, measurement error is probably small. People generally know when they were born, whether they’ve been married, and when. And they can be expected to report this information accurately.

We can quantify sampling error by resampling.
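`resample_rows_weighted` comes from the local `cdf` module, which isn't shown. A plausible equivalent, assuming it draws as many rows as the original, with replacement and with probability proportional to a weight column such as `finalwgt`, can be written with pandas' `DataFrame.sample`. `resample_rows_weighted_sketch` is a hypothetical name, not the library function.

```python
import pandas as pd

def resample_rows_weighted_sketch(df: pd.DataFrame, column: str = 'finalwgt', seed=None) -> pd.DataFrame:
    '''Hypothetical stand-in: bootstrap rows with probability proportional to df[column].'''
    return df.sample(n=len(df), replace=True, weights=column, random_state=seed)

# Hypothetical respondents with made-up weights
toy = pd.DataFrame({'agemarry': [22.0, 25.5, 30.0], 'finalwgt': [1.0, 1.0, 8.0]})
boot = resample_rows_weighted_sketch(toy, seed=0)
print(len(boot))  # 3
```

Rows with larger weights show up more often in each resample, so every bootstrap sample reflects the survey weights as well as sampling variability.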

In [None]:
def estimate_survival(resp: pd.DataFrame) -> SurvivalFunction:
    """Estimates the survival curve for age at first marriage.
    resp: DataFrame of respondents
    returns: SurvivalFunction
    """
    complete = resp[resp.is_married == 1].agemarry.dropna()
    ongoing = resp[resp.is_married == 0].age.dropna()

    hf = estimate_hazard(complete, ongoing)
    return make_survival(hf)

In [None]:
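def resample_survival(resp: pd.DataFrame, iters: int = 101):
    '''Evaluates the survival curve on `iters` weighted resamples of resp.
    returns: array of ages, array of shape (iters, len(ages))
    '''
    low, high = resp.agemarry.min(), resp.agemarry.max()
    # evaluate every curve on the same grid of ages, one month apart
    ts = np.arange(low, high, 1/12)
    # sequence of evaluated survival curves
    ss_seq = []
    for _ in range(iters):
        # weighted resample of the rows, using the sampling weights in finalwgt
        sample = resample_rows_weighted(resp, 'finalwgt')
        sf = estimate_survival(sample)
        ss_seq.append(sf(ts))
    return ts, np.array(ss_seq)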

In [None]:
# survival curves estimated from resampled data
ts, s_curves = resample_survival(resp)

In [None]:
# survival curve estimated from the full (unresampled) data
sf = estimate_survival(resp)

In [None]:
s_curves.shape

In [None]:
lows, highs = np.percentile(s_curves, [5, 95], axis=0)
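`np.percentile` with `axis=0` works column by column: for each age in `ts`, it takes the 5th and 95th percentiles across the resampled curves, giving a 90% interval at every age. A shape check on random stand-in curves (not survey data):

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for s_curves: 101 resampled curves evaluated at 4 ages
fake_curves = rng.uniform(0.2, 0.8, size=(101, 4))

# one row of percentiles per requested q, one column per age
fake_lows, fake_highs = np.percentile(fake_curves, [5, 95], axis=0)
print(fake_lows.shape, fake_highs.shape)  # (4,) (4,)
```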

In [None]:
plt.plot(
    sf.ts,
    sf.ss,
    color='darkred'
);
plt.fill_between(
    ts,
    lows,
    highs,
    color='gray',
    alpha=0.3
);

plt.xlabel('age (years)');
plt.ylabel('prob unmarried');
plt.yticks(np.linspace(0, 1, 11));
plt.title('Survival curve for age at first marriage with 90% interval');