In [None]:
# ==========================================
#   Diet Patterns vs Cardiovascular Mortality
# ==========================================

from pathlib import Path
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# sns.set is a legacy alias of set_theme (seaborn marks it for possible removal).
sns.set_theme(style="whitegrid")

# ------------------------------------------------
# 1. Auto-detect project root
# ------------------------------------------------
def find_project_root(marker="README.md", start=None):
    """Return the nearest directory at or above ``start`` that contains ``marker``.

    Args:
        marker: File or directory name whose presence identifies the project root.
        start: Directory to begin the upward search from. Defaults to the
            current working directory.

    Returns:
        The resolved ``Path`` of the first directory (walking upward) that
        contains ``marker``.

    Raises:
        RuntimeError: If neither ``start`` nor any ancestor, up to and
            including the filesystem root, contains ``marker``.
    """
    cur = Path(start if start is not None else Path.cwd()).resolve()
    # (cur, *cur.parents) also yields the filesystem root, which a
    # `while cur != cur.parent` loop never inspects.
    for candidate in (cur, *cur.parents):
        if (candidate / marker).exists():
            return candidate
    raise RuntimeError("Project root not found.")

ROOT = find_project_root()
DB_PATH = ROOT / "data" / "db" / "global_health_nutrition.db"

print("PROJECT ROOT:", ROOT)
print("DATABASE PATH:", DB_PATH)
print("DB Exists:", DB_PATH.exists())

# Fail fast: sqlite3.connect() on a missing path creates a new empty database
# file, and the load cell would then fail later with a confusing
# "no such table" error.
if not DB_PATH.exists():
    raise FileNotFoundError(f"Database not found at {DB_PATH}")

In [None]:
# ------------------------------------------------
# 2. Load diet + health data using SQL
# ------------------------------------------------
query = """
SELECT
    h.country,
    h.year,
    h.life_expectancy,
    h.infant_mortality_rate,
    h.under5_mortality_rate,
    h.uhc_coverage,
    d.animal_kcal,
    d.plant_kcal,
    d.fat_kcal,
    d.carb_kcal,
    d.total_fruit_consumption,
    g.pct_death_cardiovascular
FROM v_health_access h
LEFT JOIN v_diet_profiles d
    ON h.country = d.country AND h.year = d.year
LEFT JOIN v_core_clean g
    ON h.country = g.country AND h.year = g.year AND g.gender = 'Both sexes'
"""

# Open read-only via a URI, so a wrong path raises instead of sqlite3 silently
# creating an empty database file. as_uri() needs an absolute path (ROOT is
# resolved). try/finally releases the connection even if the query fails.
conn = sqlite3.connect(f"{DB_PATH.as_uri()}?mode=ro", uri=True)
try:
    df = pd.read_sql_query(query, conn)
finally:
    conn.close()

print("Rows loaded:", len(df))

# Duplicate (country, year) keys would mean one of the views holds several rows
# per key. The joins then fan out and later per-country stats double-count.
n_dup_keys = int(df.duplicated(subset=["country", "year"]).sum())
if n_dup_keys:
    print(f"WARNING: {n_dup_keys} duplicate (country, year) rows after the joins.")

display(df.head())

In [None]:
# ------------------------------------------------
# 3. Basic dataset overview
# ------------------------------------------------
# Quick shape of the merged frame: schema, time span, country coverage.
year_min, year_max = df["year"].min(), df["year"].max()
print(f"Columns: {df.columns.tolist()}")
print(f"Years: {year_min} → {year_max}")
print(f"Countries: {df['country'].nunique()}")

# Numeric summary, one row per column, capped at the first 20 columns.
numeric_summary = df.describe().transpose()
display(numeric_summary.iloc[:20])

In [None]:
# ------------------------------------------------
# 4. Latest year slice
# ------------------------------------------------
# Analyze the most recent year that actually has diet AND CVD values. The load
# query LEFT JOINs the diet and CVD views onto the health view, so the overall
# max year can be present with those columns entirely NaN. That would leave the
# plots below empty and the correlation matrix all-NaN.
required_cols = [c for c in ("animal_kcal", "pct_death_cardiovascular") if c in df.columns]
years_with_data = df.dropna(subset=required_cols)["year"] if required_cols else df["year"]
max_year = df["year"].max()
latest_year = years_with_data.max() if not years_with_data.empty else max_year
df_latest = df[df["year"] == latest_year].copy()

if latest_year != max_year:
    print(f"Note: {max_year} has no diet + CVD data; using {latest_year} instead.")
print(f"\nLatest year dataset: {latest_year}, rows: {len(df_latest)}")
display(df_latest.head())

In [None]:
# ------------------------------------------------
# 5. Scatterplots — Diet vs CVD Mortality
# ------------------------------------------------
def safe_scatter(x, y, title, data=None, year=None):
    """Scatter column ``y`` against column ``x`` for a one-year country slice.

    Args:
        x: Column name for the x axis; the axis label is derived from it.
        y: Column name for the y axis; the axis label is derived from it.
        title: Title prefix; the year and plotted point count are appended.
        data: DataFrame to plot. Defaults to the notebook-level ``df_latest``
            (looked up at call time).
        year: Year shown in the title. Defaults to the notebook-level
            ``latest_year``.

    Prints a skip message instead of plotting when either column is missing.
    """
    if data is None:
        data = df_latest
    if year is None:
        year = latest_year
    if not {x, y}.issubset(data.columns):
        print(f"Skipping: {x} vs {y} (missing column)")
        return
    # Rows with NaN in either column cannot be placed on the plot; report how
    # many points are actually drawn.
    n_points = len(data[[x, y]].dropna())
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.scatterplot(data=data, x=x, y=y, ax=ax)
    ax.set_xlabel(x.replace("_", " ").title())
    ax.set_ylabel(y.replace("_", " ").title())
    ax.set_title(f"{title} — {year} (n={n_points})")
    fig.tight_layout()
    plt.show()

for diet_col, label in [
    ("animal_kcal", "Animal kcal"),
    ("plant_kcal", "Plant kcal"),
    ("fat_kcal", "Fat kcal"),
    ("carb_kcal", "Carb kcal"),
]:
    safe_scatter(diet_col, "pct_death_cardiovascular", f"{label} vs CVD Death Share")

In [None]:
# ------------------------------------------------
# 6. Diet composition distribution
# ------------------------------------------------
# Distribution of each available diet column across countries in the slice.
diet_cols = [
    col
    for col in ("animal_kcal", "plant_kcal", "fat_kcal", "carb_kcal")
    if col in df_latest.columns
]

if not diet_cols:
    print("No diet columns available for boxplot.")
else:
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.boxplot(data=df_latest[diet_cols], ax=ax)
    ax.set_title(f"Diet Composition — {latest_year}")
    ax.set_ylabel("kcal per person per day")
    plt.setp(ax.get_xticklabels(), rotation=20)
    fig.tight_layout()
    plt.show()

In [None]:
# ------------------------------------------------
# 7. Correlation matrix
# ------------------------------------------------
corr_cols = [
    "life_expectancy",
    "infant_mortality_rate",
    "under5_mortality_rate",
    "animal_kcal",
    "plant_kcal",
    "fat_kcal",
    "carb_kcal",
    "total_fruit_consumption",
    "pct_death_cardiovascular",
]

# Keep columns that exist AND have at least one value in this year's slice.
# An all-NaN column would only add a blank row/column to the heatmap.
corr_cols = [c for c in corr_cols if c in df_latest.columns and df_latest[c].notna().any()]
# Pairwise Pearson correlation (pandas default). NA values are excluded per
# pair, so different cells may rest on different numbers of countries.
corr = df_latest[corr_cols].corr()

if len(corr_cols) >= 2:
    fig, ax = plt.subplots(figsize=(10, 8))
    # Fix the color scale to [-1, 1] so shades mean the same thing every run.
    sns.heatmap(corr, annot=True, fmt=".2f", cmap="coolwarm", center=0, vmin=-1, vmax=1, ax=ax)
    ax.set_title(f"Correlation Matrix — {latest_year}")
    fig.tight_layout()
    plt.show()
else:
    print("Need at least two populated columns for a correlation matrix.")

In [None]:
# ------------------------------------------------
# 8. Create a simple diet profile feature
# ------------------------------------------------
def diet_profile(row):
    """Classify one country-year row into a coarse diet-profile label.

    Compares ``animal_kcal`` with ``plant_kcal``. For animal-dominant rows it
    also compares ``fat_kcal`` with ``carb_kcal``.

    Args:
        row: A pandas Series, e.g. one row passed by ``DataFrame.apply(axis=1)``.

    Returns:
        One of:
        - "Animal & Fat Heavy": animal > plant and fat > carb.
        - "Animal Protein Heavy": animal > plant otherwise.
        - "Plant Protein Heavy": plant > animal.
        - "Mixed": animal == plant.
        - ``None``: ``animal_kcal`` or ``plant_kcal`` is missing/NaN. Rows
          without diet data are then excluded from group statistics instead
          of being labelled "Mixed".
    """
    animal = row.get("animal_kcal")
    plant = row.get("plant_kcal")
    # Series.get returns NaN (not the default) for a present-but-NaN value.
    # NaN compares False against everything, so without this guard rows with
    # no diet data fell through every branch to "Mixed".
    if pd.isna(animal) or pd.isna(plant):
        return None
    # NOTE(review): labels mention "Protein" but the comparison is on kcal.
    if animal > plant:
        fat, carb = row.get("fat_kcal"), row.get("carb_kcal")
        # Missing fat/carb means "not fat-heavy": fall back to the animal label.
        if not (pd.isna(fat) or pd.isna(carb)) and fat > carb:
            return "Animal & Fat Heavy"
        return "Animal Protein Heavy"
    if plant > animal:
        return "Plant Protein Heavy"
    return "Mixed"

# Label every country in the slice. dropna=False also counts rows left without
# a profile (None/NaN), if diet_profile produces any, so gaps stay visible.
df_latest["diet_profile"] = df_latest.apply(diet_profile, axis=1)
display(df_latest["diet_profile"].value_counts(dropna=False))

# Summary stats by diet profile: group means plus a per-group country count
profile_aggs = {
    "life_expectancy": ("life_expectancy", "mean"),
    "pct_death_cardiovascular": ("pct_death_cardiovascular", "mean"),
    "n": ("country", "count"),
}
grouped_by_profile = df_latest.groupby("diet_profile")
profile_stats = grouped_by_profile.agg(**profile_aggs).reset_index()

display(profile_stats)

# One bar chart per metric. Each bar shows the group mean already computed in
# profile_stats (one value per diet profile).
for metric, label in [
    ("life_expectancy", "Life Expectancy"),
    ("pct_death_cardiovascular", "CVD Death Share"),
]:
    fig, ax = plt.subplots(figsize=(8, 5))
    sns.barplot(data=profile_stats, x="diet_profile", y=metric, ax=ax)
    ax.set_title(f"{label} by Diet Profile — {latest_year}")
    ax.set_xlabel("Diet profile")
    ax.set_ylabel(f"Mean {metric.replace('_', ' ')}")
    plt.setp(ax.get_xticklabels(), rotation=20)
    fig.tight_layout()
    plt.show()

In [None]:
# ------------------------------------------------
# 9. Country Trend Example
# ------------------------------------------------
example_country = "Germany"  # choose any

# Boolean mask of this country's rows across all years.
country_mask = df["country"].eq(example_country)

if not country_mask.any():
    print(f"Country '{example_country}' not found in dataset.")
else:
    df_country = df.loc[country_mask].sort_values("year")

    print(f"\nDiet trends for: {example_country}")
    display(df_country.head())

    diet_cols_ts = [
        col
        for col in ("animal_kcal", "plant_kcal", "fat_kcal", "carb_kcal")
        if col in df_country.columns
    ]

    if diet_cols_ts:
        fig, ax = plt.subplots(figsize=(10, 6))
        for col in diet_cols_ts:
            ax.plot(df_country["year"], df_country[col], label=col)
        ax.set_xlabel("Year")
        ax.set_ylabel("Calories per person per day")
        ax.set_title(f"Diet Composition Over Time — {example_country}")
        ax.legend()
        fig.tight_layout()
        plt.show()