In [1]:
#!/usr/bin/env python3
# dhs_marital_status_eastern.py

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import math
import json

# ------------------------------------------------------------------
# 1. CONFIGURATION
# ------------------------------------------------------------------
DATA_DIR = "data"

# Categories Mapping
STATUS_MAP = {
    0: "Never in union",
    1: "Married",
    2: "Living with partner",
    3: "Widowed",
    4: "Divorced",
    5: "Separated"
}

# Plot Order (Bottom to Top)
PLOT_ORDER = [
    "Never in union", 
    "Married", 
    "Living with partner", 
    "Widowed", 
    "Divorced", 
    "Separated"
]

COLORS = [
    "#0070C0", # Blue
    "#C0504D", # Red
    "#00B050", # Green
    "#7030A0", # Purple
    "#4BACC6", # Light Blue
    "#F79646"  # Orange
]

DATASETS = [
    {
        "label": "Women",
        "file":  "RWIR81FL.DTA",
        "max_age": 49,
        "vars": {"wt": "v005", "reg": "v024", "dist": "sdistrict", "status": "v501", "age": "v012"}
    },
    {
        "label": "Men",
        "file":  "RWMR81FL.DTA",
        "max_age": 59,
        "vars": {"wt": "mv005", "reg": "mv024", "dist": "smdistrict", "status": "mv501", "age": "mv012"}
    }
]

# Eastern Province District Mapping
DIST_MAP = {
    51: 'Rwamagana', 
    52: 'Nyagatare', 
    53: 'Gatsibo',
    54: 'Kayonza',
    55: 'Kirehe',
    56: 'Ngoma',
    57: 'Bugesera'
}

# ------------------------------------------------------------------
# 2. CALCULATION
# ------------------------------------------------------------------
def standard_round(n):
    return int(math.floor(n + 0.5))

def get_distribution(df, status_col, wt_col):
    if df.empty or df[wt_col].sum() == 0:
        return pd.Series({k: 0 for k in PLOT_ORDER})
    
    counts = df.groupby(status_col)[wt_col].sum()
    total_wt = counts.sum()
    pcts = (counts / total_wt) * 100
    
    pcts.index = pcts.index.map(STATUS_MAP)
    pcts = pcts.reindex(PLOT_ORDER, fill_value=0)
    return pcts.apply(standard_round)

# ------------------------------------------------------------------
# 3. ANALYSIS & PLOTTING
# ------------------------------------------------------------------
def analyze_dataset(config):
    label = config['label']
    file_path = os.path.join(DATA_DIR, config['file'])
    v = config['vars']
    max_age = config['max_age']

    print(f"\n--- Processing: {label} (15-{max_age}) ---")
    
    if not os.path.exists(file_path):
        print(f"❌ Error: {file_path} not found.")
        return

    try:
        df = pd.read_stata(file_path, convert_categoricals=False)
        df.columns = df.columns.str.lower()
    except Exception as e:
        print(f"❌ Error loading file: {e}")
        return

    # --- AGE FILTER ---
    df = df[(df[v['age']] >= 15) & (df[v['age']] <= max_age)]
    df["w"] = df[v['wt']] / 1000000.0

    # Aggregation
    results = {}

    # A. Districts (Eastern Province = Region 5)
    df_east = df[df[v['reg']] == 5]
    if v['dist'] in df_east.columns:
        for dist_code, dist_name in DIST_MAP.items():
            subset = df_east[df_east[v['dist']] == dist_code]
            results[dist_name] = get_distribution(subset, v['status'], "w")
    
    # B. Province & National
    results["Eastern Province"] = get_distribution(df_east, v['status'], "w")
    results["Rwanda (National)"] = get_distribution(df, v['status'], "w")

    final_df = pd.DataFrame(results).T
    final_df = final_df[PLOT_ORDER]
    
    # --- JSON OUTPUT ---
    json_name = f"Marital_Status_{label}_Eastern.json"
    output_dict = {
        "indicator": f"Percentage distribution of {label} (15-{max_age}) by current marital status",
        "unit": "Percentage (%)",
        "data": final_df.to_dict(orient='index')
    }
    with open(json_name, "w") as f:
        json.dump(output_dict, f, indent=4)
    print(f"✅ JSON saved: {json_name}")

    # --- PLOTTING ---
    ax = final_df.plot(kind="bar", stacked=True, color=COLORS, figsize=(14, 8), 
                       width=0.7, edgecolor="white", linewidth=0.5)

    plt.title(f"Percentage distribution of {label} (15-{max_age}) by marital status\n(Eastern Province & National)", 
              fontsize=16, fontweight="bold", pad=20)
    
    # Rotation for 7 districts
    plt.xticks(rotation=45, ha='right', fontsize=11)
    plt.grid(axis="y", ls="--", alpha=0.3)
    
    ax.yaxis.set_visible(False)
    for s in ["top", "right", "left"]: 
        ax.spines[s].set_visible(False)

    plt.legend(ncol=3, loc="upper center", bbox_to_anchor=(0.5, -0.15), 
               frameon=False, fontsize=11)

    # Labels (Only >= 3% for readability in stacked bars)
    for c in ax.containers:
        labels = [f"{int(h)}" if h >= 3 else "" for h in c.datavalues]
        ax.bar_label(c, labels=labels, label_type='center', 
                     fontsize=10, fontweight="bold", color="white")

    plt.tight_layout()
    png_name = f"Marital_Status_{label}_Eastern.png"
    plt.savefig(png_name, dpi=300)
    plt.close()
    print(f"✅ Graph saved: {png_name}")

if __name__ == "__main__":
    for ds in DATASETS:
        analyze_dataset(ds)
    print("\n✅ Analysis Complete.")


--- Processing: Women (15-49) ---


  df["w"] = df[v['wt']] / 1000000.0


✅ JSON saved: Marital_Status_Women_Eastern.json
✅ Graph saved: Marital_Status_Women_Eastern.png

--- Processing: Men (15-59) ---


  df["w"] = df[v['wt']] / 1000000.0


✅ JSON saved: Marital_Status_Men_Eastern.json
✅ Graph saved: Marital_Status_Men_Eastern.png

✅ Analysis Complete.
