# Week 11 - Survival Analysis - Daniel Solis Toro
## Exercise 13-1

In NSFG Cycles 6 and 7, the variable endivorex contains the date of divorce for the respondent's first marriage, if applicable, encoded in century-months.

Compute the duration of marriages that have ended in divorce, and the duration, so far, of marriages that are ongoing. Estimate the hazard and survival function for the duration of marriage. 

Use resampling to take into account sampling weights, and plot data from several re-samples to visualize sampling error. Consider dividing the respondents into groups by decade of birth, and possibly by age at first marriage.

In [21]:
# Import libraries
import thinkstats2
import thinkplot
import numpy as np
import pandas as pd
import survival
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter, NelsonAalenFitter
from lifelines.utils import datetimes_to_durations
from sklearn.utils import resample

# Download Data sets
from os.path import basename, exists

def download(url):
    filename = basename(url)
    if not exists(filename):
        from urllib.request import urlretrieve

        local, _ = urlretrieve(url, filename)
        print("Downloaded " + local)

download("https://github.com/AllenDowney/ThinkStats2/raw/master/code/2006_2010_FemRespSetup.dct")
download("https://github.com/AllenDowney/ThinkStats2/raw/master/code/2006_2010_FemResp.dat.gz")
import nsfg

# Load Data set
resp6 = survival.ReadFemResp2002()
resp7 = survival.ReadFemResp2010()
df = [resp6, resp7]

# Convert century-months to datetime
df['marriage_date'] = pd.to_datetime((df['cmmarrhx'] - 1) * 30, origin='1900-01-01', unit='D')
df['divorce_date'] = pd.to_datetime((df['cmdivorcx'] - 1) * 30, origin='1900-01-01', unit='D')

# Calculate duration for divorced and ongoing marriages
current_date = pd.to_datetime('today')
df['duration'] = np.where(df['divorce_date'].isna(),
                          (current_date - df['marriage_date']).dt.days / 30,
                          (df['divorce_date'] - df['marriage_date']).dt.days / 30)
df['event_occurred'] = df['divorce_date'].notna().astype(int)

# Estimate survival function using Kaplan-Meier
kmf = KaplanMeierFitter()
kmf.fit(df['duration'], event_observed=df['event_occurred'], label='Kaplan Meier Estimate')
kmf.plot_survival_function()
plt.title('Survival Function of Marriage Duration')
plt.xlabel('Duration in Months')
plt.ylabel('Survival Probability')
plt.show()

# Estimate hazard function using Nelson-Aalen
naf = NelsonAalenFitter()
naf.fit(df['duration'], event_observed=df['event_occurred'], label='Nelson-Aalen Estimate')
naf.plot_hazard()
plt.title('Hazard Function of Marriage Duration')
plt.xlabel('Duration in Months')
plt.ylabel('Hazard Rate')
plt.show()

# Resampling to account for sampling weights
n_resamples = 100
resampled_survival = []

for _ in range(n_resamples):
    resampled_df = resample(df, replace=True, n_samples=len(df), random_state=_)
    kmf_resampled = KaplanMeierFitter()
    kmf_resampled.fit(resampled_df['duration'], event_observed=resampled_df['event_occurred'])
    resampled_survival.append(kmf_resampled.survival_function_)

# Plot resampled survival functions
plt.figure()
for survival in resampled_survival:
    plt.step(survival.index, survival.values.flatten(), color='gray', alpha=0.1)
kmf.plot_survival_function(ci_show=False, color='blue', linewidth=2)
plt.title('Resampled Survival Functions')
plt.xlabel('Duration in Months')
plt.ylabel('Survival Probability')
plt.show()

# Group analysis by decade of birth and age at first marriage
df['decade_of_birth'] = (df['birth_year'] // 10) * 10
df['age_at_first_marriage'] = (df['marriage_date'].dt.year - df['birth_year'])

# Example: Plot survival functions by decade of birth
plt.figure()
for decade in df['decade_of_birth'].unique():
    mask = df['decade_of_birth'] == decade
    kmf.fit(df[mask]['duration'], event_observed=df[mask]['event_occurred'], label=f'Decade {decade}')
    kmf.plot_survival_function()
plt.title('Survival Function by Decade of Birth')
plt.xlabel('Duration in Months')
plt.ylabel('Survival Probability')
plt.legend()
plt.show()



The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  resp.cmmarrhx.replace([9997, 9998, 9999], np.nan, inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  resp.cmmarrhx.replace([9997, 9998, 9999], np.nan, inplace=True)


TypeError: list indices must be integers or slices, not str

In [44]:
#Turn off panda 3.0 future warnings 
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [46]:
from __future__ import print_function

import pandas
import numpy as np

import thinkplot
import thinkstats2
import survival

def CleanData(resp):
    """Cleans respondent data.

    resp: DataFrame
    """
    # Replace invalid values in cmdivorcx with NaN
    resp.cmdivorcx.replace([9998, 9999], np.nan, inplace=True)

    # Create a flag for ongoing marriages
    resp['notdivorced'] = resp.cmdivorcx.isnull().astype(int)

    # Calculate duration for divorced marriages
    resp['duration'] = (resp.cmdivorcx - resp.cmmarrhx) / 12.0

    # Calculate duration so far for ongoing marriages
    resp['durationsofar'] = (resp.cmintvw - resp.cmmarrhx) / 12.0

    # Handle NaNs in durationsofar (ongoing marriages)
    # Drop rows where durationsofar is NaN (invalid data)
    resp.dropna(subset=['durationsofar'], inplace=True)

    # Add decade of birth
    month0 = pandas.to_datetime('1899-12-15')
    dates = [month0 + pandas.DateOffset(months=cm) 
             for cm in resp.cmbirth]
    resp['decade'] = (pandas.DatetimeIndex(dates).year - 1900) // 10

def ResampleDivorceCurve(resps):
    """Plots divorce curves based on resampled data.

    resps: list of respondent DataFrames
    """
    for _ in range(41):
        samples = [thinkstats2.ResampleRowsWeighted(resp) 
                   for resp in resps]
        sample = pandas.concat(samples, ignore_index=True)
        PlotDivorceCurveByDecade(sample, color='#225EA8', alpha=0.1)

    thinkplot.Show(xlabel='years',
                   axis=[0, 28, 0, 1])

def ResampleDivorceCurveByDecade(resps):
    """Plots divorce curves for each birth cohort.

    resps: list of respondent DataFrames    
    """
    for i in range(41):
        samples = [thinkstats2.ResampleRowsWeighted(resp) 
                   for resp in resps]
        sample = pandas.concat(samples, ignore_index=True)
        groups = sample.groupby('decade')
        if i == 0:
            survival.AddLabelsByDecade(groups, alpha=0.7)

        EstimateSurvivalByDecade(groups, alpha=0.1)

    thinkplot.Save(root='survival7',
                   xlabel='years',
                   axis=[0, 28, 0, 1])

def EstimateSurvivalByDecade(groups, **options):
    """Groups respondents by decade and plots survival curves.

    groups: GroupBy object
    """
    thinkplot.PrePlot(len(groups))
    for name, group in groups:
        print(name, len(group))
        _, sf = EstimateSurvival(group)
        thinkplot.Plot(sf, **options)

def EstimateSurvival(resp):
    """Estimates the survival curve.

    resp: DataFrame of respondents

    returns: pair of HazardFunction, SurvivalFunction
    """
    complete = resp[resp.notdivorced == 0].duration
    ongoing = resp[resp.notdivorced == 1].durationsofar

    # Ensure no NaNs in ongoing durations
    if ongoing.isnull().any():
        raise ValueError("ongoing durations still contain NaNs after cleaning")

    hf = survival.EstimateHazardFunction(complete, ongoing)
    sf = hf.MakeSurvival()

    return hf, sf

def main():
    resp6 = survival.ReadFemResp2002()
    CleanData(resp6)
    married6 = resp6[resp6.evrmarry==1]

    resp7 = survival.ReadFemResp2010()
    CleanData(resp7)
    married7 = resp7[resp7.evrmarry==1]

    ResampleDivorceCurveByDecade([married6, married7])

if __name__ == '__main__':
    main()

5 483
6 4186
7 3755
8 1174
9 9
5 516
6 4205
7 3721
8 1157
9 8
5 477
6 4208
7 3739
8 1172
9 11
5 507
6 4080
7 3793
8 1222
9 5
5 488
6 4210
7 3782
8 1121
9 6
5 509
6 4135
7 3800
8 1162
9 1
5 529
6 4101
7 3810
8 1160
9 7
5 462
6 4214
7 3691
8 1232
9 8
5 486
6 4166
7 3806
8 1136
9 13
5 533
6 4165
7 3775
8 1125
9 9
5 518
6 4260
7 3709
8 1113
9 7
5 510
6 4134
7 3758
8 1193
9 12
5 476
6 4156
7 3782
8 1185
9 8
5 519
6 4201
7 3718
8 1163
9 6
5 503
6 4182
7 3763
8 1147
9 12
5 489
6 4146
7 3723
8 1243
9 6
5 539
6 4212
7 3714
8 1130
9 12
5 520
6 4143
7 3719
8 1212
9 13
5 525
6 4202
7 3675
8 1196
9 9
5 494
6 4134
7 3788
8 1183
9 8
5 519
6 4137
7 3738
8 1210
9 3
5 530
6 4092
7 3777
8 1192
9 16
5 528
6 4161
7 3772
8 1135
9 11
5 504
6 4075
7 3786
8 1234
9 8
5 521
6 4132
7 3749
8 1200
9 5
5 488
6 4187
7 3827
8 1095
9 10
5 508
6 4172
7 3729
8 1191
9 7
5 504
6 4172
7 3784
8 1146
9 1
5 546
6 4174
7 3740
8 1131
9 16
5 511
6 4187
7 3748
8 1154
9 7
5 492
6 4278
7 3659
8 1169
9 9
5 524
6 4134
7 3714
8 1226
9 

<Figure size 800x600 with 0 Axes>