# 01 - Data Exploration

This notebook explores maternal health data from the Demographic and Health Surveys (DHS) Program and prepares it for fairness analysis.

## Objectives
1. Load DHS data from multiple African countries
2. Explore demographic distributions
3. Create the high-risk pregnancy indicator
4. Examine outcome distributions across demographic groups

In [None]:
# Import libraries
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Import our custom modules
sys.path.append('..')
from src.data_loader import (
    load_dhs_data, 
    load_multiple_countries,
    create_high_risk_indicator,
    create_demographic_groups
)

# Plot settings
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')
%matplotlib inline

## 1. Load Data

**Note:** Before running this section, ensure you have:
1. Registered at [dhsprogram.com](https://dhsprogram.com/data/)
2. Downloaded the approved datasets to `../data/raw/`

DHS files follow naming conventions like:
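- `NGIR7BFL.DTA` - Nigeria (`NG`), Individual Recode (`IR`), DHS-VII release B (`7B`), flat-file layout (`FL`), Stata format (`.DTA`)
- `KEIR8AFL.DTA` - Kenya (`KE`), Individual Recode (`IR`), DHS-VIII release A (`8A`), flat-file layout (`FL`), Stata format (`.DTA`)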

In [None]:
# Folder that holds the downloaded raw survey files (relative to this notebook).
DATA_DIR = '../data/raw/'

# Two-letter country code -> file name expected inside DATA_DIR.
# Update the values with your actual filenames.
COUNTRY_FILES = dict(
    NG='NGIR7BFL.DTA',  # Nigeria 2018
    KE='KEIR8AFL.DTA',  # Kenya 2022
    GH='GHIR8AFL.DTA',  # Ghana 2022
    UG='UGIR7BFL.DTA',  # Uganda 2016
    TZ='TZIR8AFL.DTA',  # Tanzania 2022
)

In [None]:
# Load data (uncomment when you have the files)
# df = load_multiple_countries(DATA_DIR, COUNTRY_FILES)

# Until the real files are available, build a synthetic stand-in frame.
print("Creating synthetic demonstration data...")
print("(Replace with actual DHS data loading when files are available)")

# Fixed seed so the synthetic sample is identical on every run.
np.random.seed(42)
n_samples = 5000

# The columns are drawn one after another from the global NumPy RNG, so the
# order of the statements below determines the values: do not reorder them.
synthetic_columns = {}
synthetic_columns['country'] = np.random.choice(
    ['NG', 'KE', 'GH', 'UG', 'TZ'], size=n_samples, p=[0.35, 0.2, 0.15, 0.15, 0.15]
)
# Ages are drawn from a normal distribution, bounded to 15-49, then truncated to int.
synthetic_columns['age'] = np.clip(
    np.random.normal(loc=28, scale=7, size=n_samples), 15, 49
).astype(int)
# 1=urban, 2=rural
synthetic_columns['residence_type'] = np.random.choice(
    [1, 2], size=n_samples, p=[0.35, 0.65]
)
synthetic_columns['education_level'] = np.random.choice(
    [0, 1, 2, 3], size=n_samples, p=[0.15, 0.35, 0.35, 0.15]
)
# No probabilities given, so the five wealth codes are equally likely.
synthetic_columns['wealth_index'] = np.random.choice([1, 2, 3, 4, 5], size=n_samples)
synthetic_columns['total_children_born'] = np.random.poisson(lam=2.5, size=n_samples)
synthetic_columns['anemia_level'] = np.random.choice(
    [1, 2, 3, 4], size=n_samples, p=[0.05, 0.15, 0.25, 0.55]
)
# Exponential draw truncated to whole months (values of 0 are possible).
synthetic_columns['birth_interval_months'] = np.random.exponential(
    scale=30, size=n_samples
).astype(int)

df = pd.DataFrame(synthetic_columns)

print(f"\nDataset shape: {df.shape}")
df.head()

## 2. Create Demographic Groups

In [None]:
# Derive the demographic columns via the project helper.
# NOTE(review): the helper lives in src.data_loader and is not visible here;
# the plots below rely on it providing 'wealth_group', 'residence' and
# 'education_group' - confirm against its implementation.
df = create_demographic_groups(df)

# Report columns by name pattern: the exact name 'residence', or any name
# containing 'binary' or 'group'. This is a name filter, not a true before/after diff.
added_columns = [
    name
    for name in df.columns
    if name == 'residence' or 'binary' in name or 'group' in name
]
print("Columns added:")
print(added_columns)

In [None]:
# Examine distributions: one bar chart of raw sample counts per demographic variable.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# One entry per panel, filled row by row:
# (column, explicit bar order or None to keep value_counts order, title, bar colour)
panels = [
    ('country', None, 'Samples by Country', 'steelblue'),
    ('wealth_group', ['Poorest', 'Poorer', 'Middle', 'Richer', 'Richest'],
     'Wealth Quintile Distribution', 'coral'),
    ('residence', None, 'Urban vs Rural Residence', 'seagreen'),
    ('education_group', ['None', 'Primary', 'Secondary', 'Higher'],
     'Education Level Distribution', 'mediumpurple'),
]

for ax, (column, order, title, color) in zip(axes.flat, panels):
    counts = df[column].value_counts()
    if order is not None:
        # Ordinal variables read better in their listed order than sorted by frequency.
        counts = counts.reindex(order)
    counts.plot(kind='bar', ax=ax, color=color)
    ax.set_title(title)
    ax.set_xlabel('')
    ax.set_ylabel('Number of samples')

plt.tight_layout()
# savefig does not create missing folders; make sure the target exists on a fresh checkout.
os.makedirs('../results/figures', exist_ok=True)
plt.savefig('../results/figures/demographic_distributions.png', dpi=150, bbox_inches='tight')
plt.show()

## 3. Create High-Risk Pregnancy Indicator

In [None]:
# Build the target column with the project helper and attach it to the frame.
# NOTE(review): the rule that defines "high risk" lives in src.data_loader and
# is not visible in this notebook.
high_risk_flags = create_high_risk_indicator(df)
df['high_risk'] = high_risk_flags

# Share of rows per target value, rounded to three decimals.
risk_shares = df['high_risk'].value_counts(normalize=True).round(3)
print("High-risk pregnancy distribution:")
print(risk_shares)

In [None]:
# Examine high-risk rates across groups
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Overall high-risk rate, drawn as the same dashed reference line on every panel.
overall_rate = df['high_risk'].mean()

# By wealth (reindexed so the bars follow the listed quintile order)
risk_by_wealth = df.groupby('wealth_group')['high_risk'].mean().reindex(
    ['Poorest', 'Poorer', 'Middle', 'Richer', 'Richest']
)

# By residence
risk_by_residence = df.groupby('residence')['high_risk'].mean()

# By country (highest rate first)
risk_by_country = df.groupby('country')['high_risk'].mean().sort_values(ascending=False)

panels = [
    (risk_by_wealth, 'High-Risk Rate by Wealth Quintile'),
    (risk_by_residence, 'High-Risk Rate by Residence'),
    (risk_by_country, 'High-Risk Rate by Country'),
]

for ax, (rates, title) in zip(axes, panels):
    rates.plot(kind='bar', ax=ax, color='indianred')
    ax.set_title(title)
    ax.set_ylabel('Proportion High-Risk')
    # Shared 0-1 scale so the three panels are directly comparable.
    ax.set_ylim(0, 1)
    ax.axhline(y=overall_rate, color='black', linestyle='--', label='Overall')
    # The reference line is labelled on every panel, so every panel gets a legend.
    ax.legend()

plt.tight_layout()
# savefig does not create missing folders; make sure the target exists on a fresh checkout.
os.makedirs('../results/figures', exist_ok=True)
plt.savefig('../results/figures/risk_by_demographics.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Key Observations for Fairness Analysis

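Before building models, we note the following (when running on the synthetic demonstration data these points are only illustrative; re-check them once real DHS data is loaded):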

1. **Base rate differences**: High-risk rates vary across demographic groups
2. **Sample size imbalances**: Some groups have fewer samples
3. **Intersectionality**: Combinations (e.g., poor + rural) may show compounded effects

These baseline differences will inform our fairness evaluation.

In [None]:
# Summary statistics table: one row per wealth group.
quintile_order = ['Poorest', 'Poorer', 'Middle', 'Richer', 'Richest']

# Named aggregation gives the output columns their display names directly.
# The labels contain spaces and hyphens, so they are passed as an unpacked dict.
summary = (
    df.groupby('wealth_group')
    .agg(**{
        'N': ('high_risk', 'count'),            # number of non-null target values
        'High-Risk Rate': ('high_risk', 'mean'),
        'Mean Age': ('age', 'mean'),
        'Mean Children': ('total_children_born', 'mean'),
    })
    .round(3)
    .reindex(quintile_order)  # rows in the listed quintile order
)

print("Summary by Wealth Quintile:")
summary

In [None]:
# Save processed data
# (left disabled; uncomment the next line to persist `df` as CSV)
# df.to_csv('../data/processed/maternal_health_combined.csv', index=False)

# Closing banner pointing to the follow-up notebook.
closing_lines = ("\nData exploration complete!", "Next: 02_feature_engineering.ipynb")
print(*closing_lines, sep="\n")