# Survival Analysis for Churn Prediction

This notebook demonstrates how to use survival analysis (Kaplan-Meier curves and a Cox proportional hazards model) for churn prediction.

## Key Differences from Classification:
- **One row per user** (instead of one row per payment)
- **Time-to-event** (duration in days until churn)
- **Censoring** (users who are still active are right-censored: we only know they have survived at least `duration_days` so far)
- **Hazard ratios** (interpretable risk factors)
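
As a rough illustration of the first three points, the sketch below collapses a per-payment table into one row per user with a duration and an event flag. The column names `user_id` and `payment_date`, the snapshot date, and the 45-day inactivity rule are all assumptions made for this example. The actual logic behind the `survival_input` table is not shown in this notebook.

```python
import pandas as pd

# Hypothetical per-payment data: one row per payment (assumed schema)
payments = pd.DataFrame({
    "user_id": [1, 1, 1, 2, 2, 3, 3],
    "payment_date": pd.to_datetime([
        "2024-01-05", "2024-02-05", "2024-03-05",  # user 1: last paid in March
        "2024-04-10", "2024-05-10",                # user 2: paid recently
        "2024-01-20", "2024-02-20",                # user 3: last paid in February
    ]),
})

snapshot = pd.Timestamp("2024-06-01")  # assumed end of the observation window
grace_days = 45                        # assumed inactivity threshold for "churned"

# One row per user: first and last payment dates
per_user = (
    payments.groupby("user_id")["payment_date"]
    .agg(first_payment="min", last_payment="max")
    .reset_index()
)

# Event flag: 1 if the user has been inactive for longer than the grace period
inactive_days = (snapshot - per_user["last_payment"]).dt.days
per_user["churned"] = (inactive_days > grace_days).astype(int)

# Duration: churned users are followed until their last payment,
# active users are censored at the snapshot date
end_date = per_user["last_payment"].where(per_user["churned"] == 1, snapshot)
per_user["duration_days"] = (end_date - per_user["first_payment"]).dt.days

print(per_user[["user_id", "duration_days", "churned"]])
```

User 2 is still active at the snapshot, so the row records only a lower bound on that user's lifetime. This is exactly the case a plain binary label cannot express.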

In [None]:
import pandas as pd
import numpy as np
from google.cloud import bigquery
import matplotlib.pyplot as plt
import seaborn as sns

# Survival analysis libraries
from lifelines import CoxPHFitter, KaplanMeierFitter
from lifelines.statistics import logrank_test
from lifelines.utils import median_survival_times

sns.set_style('whitegrid')
%matplotlib inline

## 1. Load Survival Data from BigQuery

In [None]:
client = bigquery.Client(project='lily-demo-ml')
query = "SELECT * FROM `lily-demo-ml.churn.survival_input`"
df = client.query(query).to_dataframe()

print(f"Total users: {len(df)}")
print(f"Churned: {df['churned'].sum()} ({df['churned'].mean():.1%})")
print(f"Censored: {(1-df['churned']).sum()} ({(1-df['churned']).mean():.1%})")
df.head()

## 2. Exploratory Data Analysis

In [None]:
# Distribution of duration, split by outcome
fig, ax = plt.subplots(figsize=(10, 5))

# One histogram per outcome, labelled explicitly so the legend always matches the data
for flag, label in [(0, 'Censored'), (1, 'Churned')]:
    df.loc[df['churned'] == flag, 'duration_days'].hist(bins=50, ax=ax, alpha=0.7, label=label)
ax.set_xlabel('Duration (days)')
ax.set_ylabel('Count')
ax.set_title('Distribution of Time-to-Event by Status')
ax.legend()

# Summary statistics
summary = df.groupby('status')['duration_days'].describe()
print("\nDuration Statistics by Status:")
print(summary)

## 3. Kaplan-Meier Survival Curves

In [None]:
# Overall survival curve
kmf = KaplanMeierFitter()
kmf.fit(df['duration_days'], df['churned'], label='All Users')

plt.figure(figsize=(10, 6))
kmf.plot_survival_function()
plt.title('Kaplan-Meier Survival Curve')
plt.ylabel('Probability of Retention')
plt.xlabel('Days Since First Payment')
plt.grid(True, alpha=0.3)

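# Median survival time (infinite if the curve never drops to 50% retention)
median_survival = kmf.median_survival_time_
if np.isinf(median_survival):
    print("\nMedian customer lifetime: not reached within the observation window")
else:
    print(f"\nMedian customer lifetime: {median_survival:.0f} days ({median_survival/30:.1f} months)")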
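
To make the role of censoring in the Kaplan-Meier estimate concrete, here is a minimal pure-Python sketch of the product-limit calculation on toy data. The helper `kaplan_meier` and the numbers are illustrative only. lifelines' `KaplanMeierFitter` remains the tool used in this notebook.

```python
def kaplan_meier(durations, events):
    """Return [(t, S(t))] at each distinct churn time (product-limit estimator)."""
    event_times = sorted({t for t, e in zip(durations, events) if e == 1})
    survival, curve = 1.0, []
    for t in event_times:
        # Users still under observation just before time t
        at_risk = sum(1 for d in durations if d >= t)
        # Users who churned exactly at time t
        churned = sum(1 for d, e in zip(durations, events) if d == t and e == 1)
        survival *= 1 - churned / at_risk
        curve.append((t, survival))
    return curve

# Toy data: 1 = churned, 0 = censored (still active when observation ended)
durations = [30, 60, 60, 90, 120]
events    = [1,  1,  0,  1,  0]

curve = kaplan_meier(durations, events)
for t, s in curve:
    print(f"day {t:3d}: S(t) = {s:.2f}")
```

The user censored at day 60 counts toward the risk set at day 60 but never triggers a drop in the curve. Dropping that user, or labelling them as churned, would bias the estimate.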

## 4. Survival Curves by Feature Groups

In [None]:
# Create feature groups: pd.cut splits the range of f_0 into three equal-width bins
# (group sizes may be uneven; pd.qcut would give equal-sized groups instead)
df['f_0_group'] = pd.cut(df['f_0'], bins=3, labels=['Low', 'Medium', 'High'])

fig, ax = plt.subplots(figsize=(10, 6))
for group in ['Low', 'Medium', 'High']:
    mask = df['f_0_group'] == group
    kmf_group = KaplanMeierFitter()
    kmf_group.fit(df.loc[mask, 'duration_days'], df.loc[mask, 'churned'], label=f'f_0: {group}')
    kmf_group.plot_survival_function(ax=ax)

plt.title('Survival Curves by f_0 Feature Group')
plt.ylabel('Probability of Retention')
plt.xlabel('Days Since First Payment')
plt.grid(True, alpha=0.3)
plt.legend()

# Log-rank test comparing the two extreme groups (Low vs High)
low = df[df['f_0_group'] == 'Low']
high = df[df['f_0_group'] == 'High']
results = logrank_test(low['duration_days'], high['duration_days'],
                       event_observed_A=low['churned'], event_observed_B=high['churned'])
print(f"\nLog-rank test p-value (Low vs High): {results.p_value:.4f}")

## 5. Cox Proportional Hazards Model

In [None]:
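# Prepare features
# NOTE: if total_payments counts payments over the whole observed lifetime, it grows with
# duration_days by construction and can leak the outcome; treat its coefficient with caution.
features = ['f_0', 'f_1', 'f_2', 'f_3', 'f_4', 'total_payments', 'signup_month']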

# Train/test split (time-based)
df_sorted = df.sort_values('first_payment_date')
train_size = int(0.7 * len(df_sorted))
df_train = df_sorted.iloc[:train_size].copy()
df_test = df_sorted.iloc[train_size:].copy()

print(f"Train: {len(df_train)} users, Test: {len(df_test)} users")

# Fit Cox model
cph = CoxPHFitter(penalizer=0.1)
cph.fit(df_train[features + ['duration_days', 'churned']], 
        duration_col='duration_days', 
        event_col='churned')

print("\n" + "="*60)
print("Cox Proportional Hazards Model Summary")
print("="*60)
cph.print_summary()
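
For reference when reading the summary, the Cox model has the standard form below. The symbols $h_0$, $\beta_j$ and $x_j$ are notation chosen for this note: $h_0(t)$ is the baseline hazard shared by all users, and $x_j$ are the entries of `features`.

$$
h(t \mid x) = h_0(t)\,\exp\!\Big(\sum_j \beta_j x_j\Big)
$$

Each `coef` in the summary is an estimate of $\beta_j$, and `exp(coef)` is the corresponding hazard ratio $e^{\beta_j}$. Because `penalizer=0.1` shrinks the coefficients toward zero, the reported hazard ratios are likely somewhat closer to 1 than an unpenalized fit would give.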

## 6. Hazard Ratios Interpretation

In [None]:
# Plot hazard ratios (cph.plot() shows log hazard ratios unless hazard_ratios=True)
fig, ax = plt.subplots(figsize=(10, 6))
cph.plot(hazard_ratios=True, ax=ax)
ax.set_title('Hazard Ratios (95% CI)')
ax.axvline(1, color='red', linestyle='--', alpha=0.5, label='No effect')
ax.set_xlabel('Hazard Ratio')
ax.legend()
plt.tight_layout()

print("\nHazard Ratio Interpretation (per one-unit increase, other features held fixed):")
print("-" * 60)
for feature in features:
    hr = np.exp(cph.params_[feature])
    if hr > 1:
        print(f"{feature:20s}: HR={hr:.3f} → {(hr-1)*100:.1f}% higher churn hazard")
    else:
        print(f"{feature:20s}: HR={hr:.3f} → {(1-hr)*100:.1f}% lower churn hazard")
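
A hazard ratio describes a one-unit change in a feature, and the effect compounds multiplicatively over larger changes. The short sketch below shows the arithmetic. The coefficient value and the helper `hazard_ratio` are hypothetical and not taken from the fitted model.

```python
import math

def hazard_ratio(beta, delta=1.0):
    """Hazard ratio implied by a Cox coefficient `beta` for a change of `delta` units."""
    return math.exp(beta * delta)

beta = 0.18  # hypothetical coefficient (log hazard ratio) for one feature

hr_1 = hazard_ratio(beta)        # +1 unit
hr_3 = hazard_ratio(beta, 3)     # +3 units, equal to hr_1 ** 3

print(f"HR for +1 unit:  {hr_1:.3f} ({(hr_1 - 1) * 100:+.1f}% hazard)")
print(f"HR for +3 units: {hr_3:.3f} ({(hr_3 - 1) * 100:+.1f}% hazard)")
```

This is why features on very different scales (for example `signup_month` versus `total_payments`) cannot be compared by hazard ratio alone without considering what one unit means for each.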

## 7. Model Evaluation

In [None]:
# Concordance index (C-index)
train_ci = cph.concordance_index_
test_ci = cph.score(df_test[features + ['duration_days', 'churned']], scoring_method='concordance_index')

print("\nModel Performance:")
print(f"Train C-index: {train_ci:.4f}")
print(f"Test C-index:  {test_ci:.4f}")
print(f"\nC-index interpretation: for {test_ci:.1%} of comparable user pairs, the model assigns")
print("the higher risk to the user who churned first (0.5 = random ranking, 1.0 = perfect)")
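
The notion of a "comparable pair" is easiest to see in a small hand-rolled version of the concordance index. The sketch below is a simplified take on Harrell's C that skips pairs with tied durations. lifelines handles ties and edge cases more carefully, so treat this as an illustration rather than a replacement for `cph.score`.

```python
from itertools import combinations

def concordance_index(durations, events, risk_scores):
    """Share of comparable pairs ranked correctly (ties in risk count as 0.5)."""
    concordant, comparable = 0.0, 0
    for i, j in combinations(range(len(durations)), 2):
        # Order the pair so that `a` is the user with the shorter observed duration
        a, b = (i, j) if durations[i] <= durations[j] else (j, i)
        # A pair is comparable only if the shorter duration ended in an observed churn
        if durations[a] == durations[b] or events[a] == 0:
            continue
        comparable += 1
        if risk_scores[a] > risk_scores[b]:
            concordant += 1.0
        elif risk_scores[a] == risk_scores[b]:
            concordant += 0.5
    return concordant / comparable

# Toy data: higher risk score should mean earlier churn
durations   = [30, 60, 90, 120]
events      = [1,  0,  1,  0]     # 1 = churned, 0 = censored
risk_scores = [0.9, 0.8, 0.3, 0.4]

c_index = concordance_index(durations, events, risk_scores)
print(f"C-index: {c_index:.2f}")
```

Pairs led by the censored user at day 60 are skipped, because it is unknown whether that user would have churned before or after the others.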

## 8. Survival Predictions for Individual Users

In [None]:
# Predict survival functions for a few test users (one column per user)
sample_users = df_test[features].head(5)
surv_funcs = cph.predict_survival_function(sample_users)

plt.figure(figsize=(12, 6))
for user_idx in surv_funcs.columns:
    plt.plot(surv_funcs.index, surv_funcs[user_idx], label=f"User {user_idx}")

plt.xlabel('Days Since First Payment')
plt.ylabel('Survival Probability')
plt.title('Predicted Survival Curves for Sample Users')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
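
A predicted survival curve can be reduced to a single number per user, such as the median lifetime: the first time at which the survival probability falls to 50% or below. The sketch below does this on made-up curves. The helper `median_from_curve` and the values are hypothetical, and lifelines also offers `cph.predict_median` for the fitted model.

```python
def median_from_curve(times, survival_probs):
    """First time at which S(t) <= 0.5, or None if the curve never gets there."""
    for t, s in zip(times, survival_probs):
        if s <= 0.5:
            return t
    return None

# Made-up predicted curves on a shared time grid (days since first payment)
times  = [0, 30, 60, 90, 120]
user_a = [1.0, 0.85, 0.62, 0.48, 0.35]   # crosses 50% within the grid
user_b = [1.0, 0.95, 0.90, 0.84, 0.79]   # never crosses 50% within the grid

print("User A median lifetime:", median_from_curve(times, user_a))
print("User B median lifetime:", median_from_curve(times, user_b))
```

A result of `None` means the median is not reached within the observed time range, which is common for low-risk users and should not be read as "never churns".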

## 9. Comparison: Survival vs Classification

| Aspect | Survival Analysis | Classification (Current) |
|--------|------------------|-------------------------|
| **Data structure** | One row per user | One row per payment |
| **Target** | Time to churn + event | Binary churn (0/1) |
| **Censoring** | Handles naturally | Not modelled |
| **Interpretation** | Hazard ratios, median survival | Churn probability |
| **Sample size** | ~5K users | ~190K observations |
| **Use case** | Lifetime value, risk factors | Monthly intervention |
| **Time-varying** | Static covariates here; needs long-format data (e.g. `CoxTimeVaryingFitter`) | Native (months_since_signup) |
| **Actionability** | Strategic (long-term) | Tactical (immediate) |

### Recommendation:
- Use **Classification** for: Monthly churn prediction, targeted interventions
- Use **Survival Analysis** for: Understanding customer lifetime, strategic planning, identifying risk factors
- **Best approach**: Use both, since they answer different questions (who is at risk this month vs. how long customers last and why)