---
title: "Logistic Regression and Survival Analysis"
output-file: "04_logistic_and_survival.html"
format: html
---

# 📊 4.6 Logistic Regression and Survival Analysis

This notebook introduces logistic regression and survival analysis for nutrition research, focusing on binary outcomes and time-to-event data.

**Objectives**:
- Apply logistic regression to predict binary outcomes.
- Perform survival analysis to model time-to-event data.
- Use `vitamin_trial.csv` to analyse vitamin D trial outcomes.

**Context**: Logistic regression estimates the probability of a binary outcome, such as improved versus unchanged health, while survival analysis models the time until an event, such as response to treatment, in nutrition studies.

<details><summary>Fun Fact</summary>
Hippos may not run clinical trials, but their vitamin D data helps us model health outcomes with statistical flair! 🦛
</details>

In [None]:
# Setup for Google Colab: Fetch datasets automatically or manually
%run ../../bootstrap.py    # installs requirements + editable package

import fns_toolkit as fns

import pandas as pd
import numpy as np

## 📥 Data Preparation

We’ll load `vitamin_trial.csv`, which contains data from a vitamin D trial. The dataset includes:
- `ID`: Participant identifier.
- `Group`: Control or Treatment.
- `Vitamin_D`: Vitamin D level (ng/mL).
- `Time`: Time to outcome (months).
- `Outcome`: Normal or Improved.

Let’s load and inspect the data to prepare for analysis.

In [None]:
# Load the dataset
df = fns.get_dataset('vitamin_trial')  # Load by dataset name via the toolkit helper (this notebook lives in notebooks/04_data_analysis/)

# Display basic information about the dataset
print(f'Data shape: {df.shape}')  # (number of rows, number of columns)
first = df.iloc[0]  # First row, for a quick inspection
print(f'Sample row: ID={first["ID"]}, Group={first["Group"]}, Vitamin_D={first["Vitamin_D"]}, Time={first["Time"]}, Outcome={first["Outcome"]}')

# Display the first few rows to understand the data structure
df.head()

## 📈 Logistic Regression

We’ll use logistic regression to model the probability of the `Outcome` being "Improved" (binary outcome) based on the predictors `Vitamin_D` and `Group`. Logistic regression suits binary outcomes because it models the log-odds of the outcome as a linear function of the predictors, which keeps every predicted probability between 0 and 1.

**Steps**:
1. Encode categorical variables (`Group` and `Outcome`) as numerical values.
2. Fit a logistic regression model.
3. Interpret the coefficients.
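
To make the log-odds scale concrete, here is a minimal sketch of how a linear predictor is turned into a probability with the logistic function. The intercept and coefficients are hypothetical values chosen for illustration only; they are not fitted to `vitamin_trial.csv`, and `predicted_probability` is a helper name introduced here.

```python
import math

def predicted_probability(intercept, coefs, values):
    """Map a linear predictor (the log-odds) to a probability with the logistic function."""
    log_odds = intercept + sum(b * x for b, x in zip(coefs, values))
    return 1.0 / (1.0 + math.exp(-log_odds))

# Hypothetical coefficients, for illustration only (not estimated from the trial data)
intercept = -2.0
coefs = [0.05, 0.8]  # order: [Vitamin_D, Group_Encoded]

# Same vitamin D level (30 ng/mL), different group
p_control = predicted_probability(intercept, coefs, [30.0, 0])
p_treatment = predicted_probability(intercept, coefs, [30.0, 1])
print(f'Control:   {p_control:.3f}')
print(f'Treatment: {p_treatment:.3f}')
```

A log-odds of 0 corresponds to a probability of 0.5; positive log-odds push the probability above 0.5 and negative log-odds push it below.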

In [None]:
# Import library for logistic regression
from sklearn.linear_model import LogisticRegression  # Scikit-learn's logistic regression model
from sklearn.preprocessing import LabelEncoder      # For encoding categorical variables as numbers

# Encode categorical variables
# Convert 'Group' (Control/Treatment) to numerical values: Control=0, Treatment=1
le_group = LabelEncoder()
df['Group_Encoded'] = le_group.fit_transform(df['Group'])

# Convert 'Outcome' to a binary target: Normal=0, Improved=1
# (LabelEncoder assigns codes in alphabetical order, which would give Improved=0, Normal=1,
#  so the mapping is set explicitly here)
df['Outcome_Encoded'] = (df['Outcome'] == 'Improved').astype(int)

# Prepare features (X) and target (y) for the model
X = df[['Vitamin_D', 'Group_Encoded']]  # Features: Vitamin D level and Group
y = df['Outcome_Encoded']               # Target: Outcome (0 or 1)

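# Fit the logistic regression model
# Note: scikit-learn applies L2 regularisation by default (C=1.0), so coefficients are shrunk slightly towards zero
model = LogisticRegression(random_state=42)  # random_state only matters for solvers that use randomness; the default solver is deterministic
model.fit(X, y)                              # Train the model on the data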

# Print the coefficients
# A positive coefficient means that increasing the predictor increases the log-odds of 'Improved'
print('Logistic Regression Coefficients:')
print(f'- Vitamin_D: {model.coef_[0][0]:.3f}')  # Coefficient for Vitamin_D
print(f'- Group (Treatment): {model.coef_[0][1]:.3f}')  # Coefficient for Group (Treatment vs Control)

# Interpretation
print('\nInterpretation:')
print(f'- Holding Group constant, a 1 ng/mL increase in Vitamin_D changes the log-odds of "Improved" by {model.coef_[0][0]:.3f}.')
print(f'- Holding Vitamin_D constant, being in the Treatment group (vs Control) changes the log-odds of "Improved" by {model.coef_[0][1]:.3f}.')
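
Log-odds coefficients are often easier to read as odds ratios, obtained by exponentiating them. The sketch below uses hypothetical coefficient values for illustration only (not results from the trial data); `odds_ratio` is a helper name introduced here.

```python
import math

def odds_ratio(coef, units=1.0):
    """Odds ratio implied by a log-odds coefficient for a change of `units` in the predictor."""
    return math.exp(coef * units)

# Hypothetical coefficient values, for illustration only
coef_vitamin_d = 0.05
coef_group = 0.8

print(f'OR per 1 ng/mL Vitamin_D:  {odds_ratio(coef_vitamin_d):.3f}')
print(f'OR per 10 ng/mL Vitamin_D: {odds_ratio(coef_vitamin_d, 10):.3f}')
print(f'OR Treatment vs Control:   {odds_ratio(coef_group):.3f}')
```

An odds ratio above 1 means higher odds of the outcome and one below 1 means lower odds. With the fitted model in this notebook, the same conversion would be `np.exp(model.coef_[0])`.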

## ⏳ Survival Analysis

Survival analysis models the time to an event. Here, we’ll estimate Kaplan-Meier survival curves to model the time to `Outcome` = "Improved", stratified by `Group`. The Kaplan-Meier method is a non-parametric approach to estimate survival probabilities over time.

**Steps**:
1. Create an event indicator (1 if Outcome = Improved; 0 otherwise, which the Kaplan-Meier fitter treats as censored at `Time`).
2. Fit Kaplan-Meier curves for each group.
3. Plot the survival curves to compare groups.
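
To show what the Kaplan-Meier fit computes, here is a minimal pure-Python sketch of the product-limit estimator: at each event time, the survival estimate is multiplied by one minus the fraction of at-risk participants who had the event. The toy data are invented for illustration, and `kaplan_meier` is a hypothetical helper, not part of lifelines.

```python
from collections import Counter

def kaplan_meier(times, events):
    """Return [(t, S(t))] at each distinct event time using the product-limit formula."""
    # Number of events observed at each time (censored observations are excluded)
    events_at = Counter(t for t, e in zip(times, events) if e == 1)
    survival = 1.0
    curve = []
    for t in sorted(events_at):
        # Participants still under observation just before time t
        at_risk = sum(1 for u in times if u >= t)
        survival *= 1.0 - events_at[t] / at_risk
        curve.append((t, survival))
    return curve

# Toy data, invented for illustration: months to "Improved" (1) or censoring (0)
times = [2, 3, 3, 5, 6, 8]
events = [1, 1, 0, 1, 0, 1]

for t, s in kaplan_meier(times, events):
    print(f'S({t}) = {s:.3f}')
```

Censored participants never trigger a drop in the curve, but they still count in the at-risk denominator until their censoring time.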

In [None]:
# Import libraries for survival analysis and plotting
from lifelines import KaplanMeierFitter  # For Kaplan-Meier survival analysis
import matplotlib.pyplot as plt         # For plotting survival curves

# Create an event indicator
# Event = 1 if Outcome is 'Improved', 0 if 'Normal'
df['Event'] = (df['Outcome'] == 'Improved').astype(int)

# Initialize the Kaplan-Meier fitter
kmf = KaplanMeierFitter()

# Set up the plot
plt.figure(figsize=(8, 6))

# Fit and plot survival curves for each group
for group in ['Control', 'Treatment']:
    # Create a mask to filter data for the current group
    mask = df['Group'] == group
    # Fit the Kaplan-Meier model to the group's data
    kmf.fit(df.loc[mask, 'Time'], event_observed=df.loc[mask, 'Event'], label=group)
    # Plot the survival curve
    kmf.plot_survival_function()

# Add plot labels and title
plt.title('Kaplan-Meier Survival Curves by Group')
plt.xlabel('Time (Months)')
plt.ylabel('Survival Probability (Not Improved)')
plt.grid(True)  # Add a grid for readability
plt.show()      # Display the survival curves

# Note: In this context, 'survival' means the probability of not having the event (Outcome = Improved).
# A lower curve indicates a higher probability of 'Improved' occurring earlier.

## 🧪 Exercises: Extend the Analysis

Let’s deepen your understanding with two tasks:

1. **Extend Logistic Regression**: Add `Time` as a predictor in the logistic regression model and report the new coefficients.
2. **Survival Analysis**: Compute the median survival time for each group, that is, the time at which the estimated probability of not yet having the event falls to 50%.

**Guidance**:
- For the logistic regression, include `Time` in the feature matrix `X` and re-fit the model.
- For survival analysis, use `kmf.median_survival_time_` after fitting the Kaplan-Meier model to get the median survival time.
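
As a rough sketch of what a median survival time means, the function below scans a survival step function and returns the first time at which the estimate is at or below 0.5. The curves are toy values invented for illustration and `median_survival_time` is a hypothetical helper. Returning infinity when the curve never reaches 0.5 is an assumption about the convention; lifelines is expected to behave similarly, but check its documentation.

```python
import math

def median_survival_time(curve):
    """Smallest time at which the estimated survival probability is <= 0.5."""
    for t, s in curve:
        if s <= 0.5:
            return t
    return math.inf  # the curve never reaches 0.5 within follow-up

# Toy step functions [(time, S(time))], invented for illustration
curve_a = [(2, 0.83), (3, 0.67), (5, 0.44), (8, 0.0)]
curve_b = [(4, 0.9), (9, 0.7)]

print(median_survival_time(curve_a))
print(median_survival_time(curve_b))
```

An infinite median is not an error: it indicates that fewer than half of the group had the event during follow-up.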

**Your Answers**:

**Exercise 1: Extend Logistic Regression**  
Add `Time` as a predictor and report the coefficients.

```python
# Extend the feature matrix to include Time
X_extended = df[['Vitamin_D', 'Group_Encoded', 'Time']]
y_extended = df['Outcome_Encoded']

# Fit the extended logistic regression model
model_extended = LogisticRegression(random_state=42)
model_extended.fit(X_extended, y_extended)

# Print the coefficients
print('Extended Logistic Regression Coefficients:')
print(f'- Vitamin_D: {model_extended.coef_[0][0]:.3f}')
print(f'- Group (Treatment): {model_extended.coef_[0][1]:.3f}')
print(f'- Time: {model_extended.coef_[0][2]:.3f}')
```

**Coefficients**:
- Vitamin_D: [Your Result]
- Group (Treatment): [Your Result]
- Time: [Your Result]

**Exercise 2: Median Survival Times**  
Compute the median survival time for each group.

```python
# Compute median survival times
kmf = KaplanMeierFitter()
for group in ['Control', 'Treatment']:
    mask = df['Group'] == group
    kmf.fit(df[mask]['Time'], df[mask]['Event'], label=group)
    median_time = kmf.median_survival_time_
    print(f'Median survival time for {group}: {median_time:.1f} months')
```

**Median Survival Times**:
- Control: [Your Result] months
- Treatment: [Your Result] months

## Conclusion

You’ve applied logistic regression and survival analysis to model vitamin D trial outcomes, uncovering predictors of improvement and time-to-event patterns. These techniques are powerful for understanding health outcomes in nutrition research.

**Next Steps**: Explore clinical trial analysis in `4.7_clinical_trial_analysis.ipynb` or dive into advanced topics in `notebooks/05_advanced/`.

**Resources**:
- [Scikit-Learn Documentation](https://scikit-learn.org/)
- [Lifelines Documentation](https://lifelines.readthedocs.io/)
- Repository: [github.com/ggkuhnle/data-analysis-toolkit-FNS](https://github.com/ggkuhnle/data-analysis-toolkit-FNS)