# In [None]:
import os, re, glob, json
from itertools import product
import numpy as np
import pandas as pd
import geopandas as gpd
import mercantile
from shapely.geometry import Point, box, shape
from shapely import wkt
from tqdm import tqdm
import rasterio
from rasterstats import zonal_stats

# ---------------------------------------------------------------------------
# CONFIG
# ---------------------------------------------------------------------------
# Ookla input tree, laid out as year=YYYY/quarter=Q/*.parquet (see STEP 3).
# NOTE(review): this is an absolute path at the filesystem root; confirm it is
# not meant to be the relative "ookla_algeria_fixed/".
PATH_DATA = "/ookla_algeria_fixed/"

# Analysis scope.
NET_TYPE = "fixed"      # matched case-insensitively against the "network" column
ISO_CODE = "DZA"        # written into the z16 quadkey CSV
YEAR_MIN = 2019
YEAR_MAX = 2025
YEARS = [*range(YEAR_MIN, YEAR_MAX + 1)]   # inclusive of YEAR_MAX

# geoBoundaries layers: national outline, wilayas (ADM1), communes (ADM3).
PATH_ADM0 = "geoBoundaries-DZA-ADM0.geojson"
PATH_ADM1 = "geoBoundaries-DZA-ADM1.geojson"
PATH_ADM3 = "geoBoundaries-DZA-ADM3.geojson"

# WorldPop rasters and zonal-statistics tuning.
WORLDPOP_DIR = "worldpop_tifs/"
CHUNK_SIZE = 2500       # tiles per zonal_stats call
GDAL_CACHE = 128        # passed to rasterio.Env(GDAL_CACHEMAX=...)

# Every output lands here; create it up front so later writes cannot fail.
OUTPUT_DIR = "processed_data/"
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("ALGERIA OOKLA DATA PREPROCESSING PIPELINE")

# UTILITIES
def tile_to_quadkey(x: int, y: int, z: int) -> str:
    """Encode the XYZ tile ``(x, y, z)`` as a quadkey string.

    The result has one character per zoom level, most significant level
    first. At each level the digit is 1 if that bit of ``x`` is set, plus 2 if
    that bit of ``y`` is set. ``z <= 0`` yields the empty string.
    """
    digits = []
    for level in reversed(range(z)):
        bit = 1 << level
        digits.append(str((1 if x & bit else 0) + (2 if y & bit else 0)))
    return "".join(digits)


def quadkey_to_point(quadkey: str):
    """Return the midpoint of the tile named by ``quadkey`` as a lon/lat Point.

    The midpoint is taken over the tile's geographic bounds as reported by
    ``mercantile.bounds`` (averaging west/east and south/north).
    """
    west, south, east, north = mercantile.bounds(mercantile.quadkey_to_tile(quadkey))
    return Point((west + east) / 2, (south + north) / 2)


def quadkey_to_polygon(quadkey: str):
    """Return the rectangular lon/lat footprint of the tile named by ``quadkey``.

    Built with ``shapely.geometry.box`` from the (west, south, east, north)
    bounds that ``mercantile.bounds`` reports for the tile.
    """
    tile = mercantile.quadkey_to_tile(quadkey)
    return box(*mercantile.bounds(tile))


def get_country_quadkeys_at_zoom(boundary_gdf: gpd.GeoDataFrame, zoom_level: int) -> list:
    """List quadkeys at ``zoom_level`` whose tile centre lies inside the boundary.

    Candidates are all tiles mercantile enumerates over the boundary's total
    bounds. A candidate is kept only when its centre point (see
    ``quadkey_to_point``) is ``within`` one of ``boundary_gdf``'s geometries.
    Tiles straddling the border with their centre outside are therefore
    excluded. If the boundary has several overlapping features, a quadkey may
    be returned more than once.

    NOTE: the centre points are created in EPSG:4326; boundary_gdf is assumed
    to be in the same CRS.
    """
    minx, miny, maxx, maxy = boundary_gdf.total_bounds
    candidates = list(mercantile.tiles(minx, miny, maxx, maxy, zoom_level))

    keys = [
        tile_to_quadkey(tile.x, tile.y, tile.z)
        for tile in tqdm(candidates, desc=f"Processing zoom {zoom_level} tiles")
    ]
    centres = gpd.GeoDataFrame(
        {"quadkey": keys, "geometry": [quadkey_to_point(key) for key in keys]},
        crs="EPSG:4326",
    )

    inside = gpd.sjoin(centres, boundary_gdf[["geometry"]], how="inner", predicate="within")
    return inside["quadkey"].astype(str).tolist()


def get_subquadkeys(parent_quadkey: str, target_zoom: int) -> list:
    """Return every descendant of ``parent_quadkey`` at ``target_zoom``.

    A quadkey's zoom equals its length, so the result holds
    ``4 ** (target_zoom - len(parent_quadkey))`` quadkeys in lexicographic
    order. When the zooms are equal, the result is ``[parent_quadkey]``.

    Raises:
        ValueError: if ``target_zoom`` is shallower than the parent's zoom.
    """
    depth = target_zoom - len(parent_quadkey)
    if depth < 0:
        raise ValueError("Target zoom must be >= parent zoom")
    if not depth:
        return [parent_quadkey]
    return [parent_quadkey + "".join(tail) for tail in product("0123", repeat=depth)]


def ensure_tile_polygon_gdf(df: pd.DataFrame) -> gpd.GeoDataFrame:
    """Attach a geometry to Ookla rows and return them as an EPSG:4326 GeoDataFrame.

    A ``quadkey`` column takes precedence and yields tile polygons. Otherwise
    ``lat``/``lon`` columns yield Points. The input frame is not modified; a
    copy receives the ``geometry`` column.

    Raises:
        ValueError: if neither ``quadkey`` nor both ``lat`` and ``lon`` exist.
    """
    if "quadkey" in df.columns:
        geoms = df["quadkey"].astype(str).apply(quadkey_to_polygon)
    elif {"lat", "lon"}.issubset(df.columns):
        geoms = [Point(lon, lat) for lon, lat in zip(df["lon"], df["lat"])]
    else:
        raise ValueError("No geometry source (quadkey or lat/lon) found.")

    framed = df.copy()
    framed["geometry"] = geoms
    return gpd.GeoDataFrame(framed, crs="EPSG:4326")


def _flatten_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Collapse tuple (MultiIndex) column labels into flat ``a_b_c`` strings.

    Typical input is ``groupby(...).agg({col: [funcs]}).reset_index()``:
    ``("avg_d_kbps", "mean")`` becomes ``"avg_d_kbps_mean"`` and
    ``("shapeISO", "")`` becomes ``"shapeISO"``. Parts after the first that
    are empty or otherwise falsy are skipped. A label with no remaining parts
    is replaced by its first element unchanged. Labels with any number of
    levels are supported. Non-tuple labels pass through untouched.

    The frame's columns are replaced in place, and the same frame is returned.
    """
    flat = []
    for col in df.columns:
        if not isinstance(col, tuple) or not col:
            flat.append(col)
            continue
        head, *rest = col
        tail = [part for part in rest if part]
        # f-string formatting matches the original two-level f"{a}_{b}" output.
        flat.append("_".join(f"{part}" for part in (head, *tail)) if tail else head)
    df.columns = flat
    return df


def validate_ookla_data(df: pd.DataFrame, year: int, quarter: int) -> bool:
    """Sanity-check one period of Ookla tile data before aggregation.

    The only hard failure is a missing ``avg_d_kbps``/``avg_u_kbps`` column,
    which raises ValueError. An empty frame prints a warning and returns False
    so the caller can skip the period. Negative or > 1,000,000 kbps speeds
    are reported but left in place. When ``avg_lat_ms`` exists, negative or
    > 1000 ms latencies are reported the same way. Otherwise returns True.
    """
    required_cols = ["avg_d_kbps", "avg_u_kbps"]
    missing = [name for name in required_cols if name not in df.columns]
    if missing:
        raise ValueError(f"{year}-Q{quarter}: Missing required columns: {missing}")
    if df.shape[0] == 0:
        print(f"Warning: {year}-Q{quarter} has no rows")
        return False

    down, up = df["avg_d_kbps"], df["avg_u_kbps"]

    neg_down, neg_up = (down < 0).sum(), (up < 0).sum()
    if max(neg_down, neg_up) > 0:
        print(f"Warning: {year}-Q{quarter} has {neg_down} negative download and {neg_up} negative upload values")

    extreme_down, extreme_up = (down > 1_000_000).sum(), (up > 1_000_000).sum()
    if max(extreme_down, extreme_up) > 0:
        print(f"Warning: {year}-Q{quarter} has {extreme_down} extreme download and {extreme_up} extreme upload values (>1 Gbps)")

    if "avg_lat_ms" not in df.columns:
        return True

    latency = df["avg_lat_ms"]
    neg_lat = (latency < 0).sum()
    extreme_lat = (latency > 1000).sum()
    if neg_lat > 0:
        print(f"Warning: {year}-Q{quarter} has {neg_lat} negative latency values")
    if extreme_lat > 0:
        print(f"Warning: {year}-Q{quarter} has {extreme_lat} extreme latency values (>1000ms)")
    return True


def validate_boundaries(boundary_gdf: gpd.GeoDataFrame, level: str, required_cols: list) -> bool:
    """Check an administrative-boundary layer before it is used.

    Raises ValueError listing any of ``required_cols`` absent from the frame.
    Invalid geometries are only counted and reported. The warning announces
    a ``buffer(0)`` repair, but this function does not perform it; the caller
    must. Returns True whenever it does not raise.
    """
    missing = [name for name in required_cols if name not in boundary_gdf.columns]
    if missing:
        raise ValueError(f"{level}: Missing required columns: {missing}")

    valid = boundary_gdf.geometry.is_valid
    invalid = (~valid).sum()
    if invalid > 0:
        print(f"Warning: {level} has {invalid} invalid geometries (will be fixed with buffer(0))")
    return True


def read_geojson_no_gdal(path):
    """Load a GeoJSON FeatureCollection into a GeoDataFrame without GDAL/Fiona.

    Each feature's ``properties`` become columns, and its ``geometry`` is
    converted with ``shapely.geometry.shape``. RFC 7946 allows both members to
    be ``null``. Null properties contribute no columns, and a null geometry
    becomes a missing geometry instead of raising. The CRS is set to
    EPSG:4326 without reprojection. A leading UTF-8 BOM is tolerated.

    Raises:
        KeyError: if the document has no top-level ``features`` member.
    """
    # utf-8-sig also reads plain UTF-8, and strips a BOM if one is present.
    with open(path, "r", encoding="utf-8-sig") as f:
        gj = json.load(f)

    recs = []
    for feat in gj["features"]:
        props = feat.get("properties") or {}
        geom = feat.get("geometry")
        recs.append({**props, "geometry": shape(geom) if geom else None})

    if not recs:
        # GeoDataFrame([]) has no geometry column, and attaching a CRS to such
        # a frame raises; build an explicitly empty geometry column instead.
        return gpd.GeoDataFrame({"geometry": []}, geometry="geometry", crs="EPSG:4326")
    return gpd.GeoDataFrame(recs, geometry="geometry", crs="EPSG:4326")


def aggregate_by_admin(gdf_meas: gpd.GeoDataFrame, admin_gdf: gpd.GeoDataFrame,
                       id_col: str, name_col: str, admin_level_tag: str) -> pd.DataFrame:
    """Summarise tile-level Ookla metrics per administrative unit.

    Measurement geometries are spatially joined (``intersects``) to
    ``admin_gdf`` and grouped by ``(id_col, name_col)``. Each output row
    carries:
      * ``admin_level`` (= ``admin_level_tag``), plus ``admin_code`` and
        ``admin_name`` as strings;
      * mean and median download/upload converted from kbps to Mbps;
      * ``num_tiles``, the count of non-null ``avg_d_kbps`` values;
      * only when the input has them: mean/median latency (ms), and summed
        ``tests`` / ``devices`` as ``num_tests`` / ``num_devices``.

    NOTE(review): with ``intersects``, a tile straddling a boundary joins
    every unit it touches. It is then counted, and its tests/devices summed,
    in each of those units, so unit totals can exceed national totals. Rows
    that match no unit get NaN keys and are dropped by ``groupby``.
    """
    joined = gpd.sjoin(
        gdf_meas,
        admin_gdf[[id_col, name_col, "geometry"]],
        how="left",
        predicate="intersects",
    )

    # Required speed columns, then optional metrics in a fixed order.
    spec = {
        "avg_d_kbps": ["mean", "median", "count"],
        "avg_u_kbps": ["mean", "median"],
    }
    optional = (("avg_lat_ms", ["mean", "median"]), ("tests", "sum"), ("devices", "sum"))
    for column, how in optional:
        if column in joined.columns:
            spec[column] = how

    stats = _flatten_columns(joined.groupby([id_col, name_col]).agg(spec).reset_index())
    stats = stats.rename(columns={
        "avg_d_kbps_mean": "avg_download_kbps",
        "avg_d_kbps_median": "median_download_kbps",
        "avg_d_kbps_count": "num_tiles",
        "avg_u_kbps_mean": "avg_upload_kbps",
        "avg_u_kbps_median": "median_upload_kbps",
        "tests_sum": "num_tests",
        "devices_sum": "num_devices",
        "avg_lat_ms_mean": "avg_latency_ms",
        "avg_lat_ms_median": "median_latency_ms",
    })

    stats["admin_level"] = admin_level_tag
    stats["admin_code"] = stats[id_col].astype(str)
    stats["admin_name"] = stats[name_col].astype(str)

    conversions = (
        ("avg_download_kbps", "avg_download_mbps"),
        ("avg_upload_kbps", "avg_upload_mbps"),
        ("median_download_kbps", "median_download_mbps"),
        ("median_upload_kbps", "median_upload_mbps"),
    )
    for kbps_col, mbps_col in conversions:
        if kbps_col in stats.columns:
            stats[mbps_col] = stats[kbps_col] / 1000.0

    columns = [
        "admin_level", "admin_code", "admin_name",
        "avg_download_mbps", "avg_upload_mbps",
        "median_download_mbps", "median_upload_mbps", "num_tiles",
    ]
    if "avg_latency_ms" in stats.columns:
        columns.extend(["avg_latency_ms", "median_latency_ms"])
    columns.extend(name for name in ("num_tests", "num_devices") if name in stats.columns)
    return stats[columns]


def infer_year(path):
    """Return the first ``20xx`` year found in the file name of ``path``, or None.

    Only the base name is searched, so years in directory names are ignored.
    """
    match = re.search(r"(20\d{2})", os.path.basename(path))
    if match is None:
        return None
    return int(match.group(1))


# STEP 1: LOAD BOUNDARIES
# Fail fast, before any processing, if a boundary file is missing.
for pth in [PATH_ADM0, PATH_ADM1, PATH_ADM3]:
    if not os.path.exists(pth):
        raise FileNotFoundError(pth)

boundary_national = read_geojson_no_gdal(PATH_ADM0).to_crs("EPSG:4326")
print("Loaded national boundary (ADM0)")

# Columns every admin layer must provide. Downstream code keys units on
# shapeISO and labels them with shapeName.
_ADMIN_COLS = ["shapeISO", "shapeName", "geometry"]

# Validate the raw layer *before* subsetting. Indexing with a missing column
# raises a bare KeyError, so validating afterwards could never report it.
_raw_adm1 = read_geojson_no_gdal(PATH_ADM1)
validate_boundaries(_raw_adm1, "ADM1", _ADMIN_COLS)
boundary_adm1 = _raw_adm1[_ADMIN_COLS].to_crs("EPSG:4326")
boundary_adm1["shapeName"] = boundary_adm1["shapeName"].astype(str)
print(f"Loaded ADM1: {len(boundary_adm1)} wilayas")

# NOTE(review): communes are also keyed on shapeISO downstream. Confirm the
# field is populated and unique for ADM3 features. Blank or duplicate codes
# would merge units in groupby, and null codes would drop them.
_raw_adm3 = read_geojson_no_gdal(PATH_ADM3)
validate_boundaries(_raw_adm3, "ADM3", _ADMIN_COLS)
boundary_adm3 = _raw_adm3[_ADMIN_COLS].to_crs("EPSG:4326")
boundary_adm3["shapeName"] = boundary_adm3["shapeName"].astype(str)
print(f"Loaded ADM3: {len(boundary_adm3)} communes")
del _raw_adm1, _raw_adm3

# Repair invalid geometries with the buffer(0) idiom (the repair announced by
# validate_boundaries) so later spatial predicates see clean shapes.
for gdf_fix in [boundary_national, boundary_adm1, boundary_adm3]:
    gdf_fix.geometry = gdf_fix.buffer(0)

ALGERIA_BOUNDS = boundary_national.total_bounds

# STEP 2: QUADKEY FILTERS AND Z12 GRID
print("STEP 2: Generating quadkey filter (z10 to z16) and z12 grid")


def _select_quadkeys(quadkeys, boundary_gdf, predicate):
    """Return the quadkeys whose tile polygon satisfies ``predicate`` against the boundary.

    ``predicate`` is a geopandas sjoin predicate such as "intersects" or
    "within". Duplicates (a tile matching several boundary features) are
    removed.
    """
    quadkeys = list(quadkeys)
    if not quadkeys:
        return []
    tiles = gpd.GeoDataFrame(
        {"quadkey": quadkeys, "geometry": [quadkey_to_polygon(qk) for qk in quadkeys]},
        crs="EPSG:4326",
    )
    hits = gpd.sjoin(tiles, boundary_gdf[["geometry"]], how="inner", predicate=predicate)
    return hits["quadkey"].astype(str).drop_duplicates().tolist()


# z10 parents: every z10 tile whose *polygon* touches Algeria. A centre-point
# test (as in get_country_quadkeys_at_zoom) would silently drop each z10 tile
# straddling the coast or a land border, and with it all of that tile's z16
# children that do lie inside Algeria.
_minx, _miny, _maxx, _maxy = boundary_national.total_bounds
_qk_z10_bbox = [
    tile_to_quadkey(t.x, t.y, t.z)
    for t in tqdm(list(mercantile.tiles(_minx, _miny, _maxx, _maxy, 10)), desc="Processing zoom 10 tiles")
]
parent_qk_z10 = _select_quadkeys(_qk_z10_bbox, boundary_national, "intersects")

# Parents entirely inside the boundary keep all their children. Only
# border-straddling parents need their children checked individually;
# otherwise whole z10 tiles of sea or neighbouring countries would pass.
_qk_z10_inside = set(_select_quadkeys(parent_qk_z10, boundary_national, "within"))
print(f"{len(parent_qk_z10):,} z10 tiles touch Algeria ({len(_qk_z10_inside):,} fully inside)")

print("Expanding to zoom 16 quadkeys (Ookla native resolution)")
all_qk_z16 = []
_qk_z16_edge = []
for pqk in tqdm(parent_qk_z10, desc="Expanding quadkeys"):
    _children = get_subquadkeys(pqk, 16)
    if pqk in _qk_z10_inside:
        all_qk_z16.extend(_children)
    else:
        _qk_z16_edge.extend(_children)
print(f"Checking {len(_qk_z16_edge):,} z16 tiles under border-straddling z10 parents")
all_qk_z16.extend(_select_quadkeys(_qk_z16_edge, boundary_national, "intersects"))
del _qk_z10_bbox, _qk_z16_edge, _children
print(f"Generated {len(all_qk_z16):,} z16 quadkeys for Algeria")

pd.DataFrame({"quadkey": all_qk_z16, "country": ISO_CODE}).to_csv(
    os.path.join(OUTPUT_DIR, "algeria_quadkeys_z16.csv"), index=False, encoding="utf-8"
)
QK16_SET = set(all_qk_z16)

print("Creating z12 grid (polygons) for mapping")
# NOTE: the z12 mapping grid keeps only tiles whose centre lies inside the
# boundary. z12 cells holding data but centred outside get no WKT in the wide
# time-series CSV (left merge), and the population-weighting step skips them
# (inner merge).
parent_qk_z12 = get_country_quadkeys_at_zoom(boundary_national, zoom_level=12)
gdf_grid_z12 = gpd.GeoDataFrame(
    {"quadkey": parent_qk_z12, "geometry": [quadkey_to_polygon(qk) for qk in parent_qk_z12]},
    crs="EPSG:4326",
)
gdf_grid_z12["quadkey"] = gdf_grid_z12["quadkey"].astype(str)
print(f"Created {len(gdf_grid_z12)} z12 tiles")

_gz12 = gdf_grid_z12.copy()
_gz12["wkt"] = _gz12.geometry.apply(lambda g: g.wkt)
_gz12.drop(columns="geometry").to_csv(
    os.path.join(OUTPUT_DIR, "algeria_grid_z12_wkt.csv"), index=False, encoding="utf-8"
)
print(f"Saved grid WKT: {os.path.join(OUTPUT_DIR, 'algeria_grid_z12_wkt.csv')}")

print("Building tile to admin (ADM1 and ADM3) lookup")
# Centroids are computed in Web Mercator (a projected CRS), then returned to lon/lat.
gdf_grid_z12_centroids = gdf_grid_z12.copy()
gdf_grid_z12_centroids = gdf_grid_z12_centroids.to_crs(3857)
gdf_grid_z12_centroids["geometry"] = gdf_grid_z12_centroids.geometry.centroid
gdf_grid_z12_centroids = gdf_grid_z12_centroids.to_crs(4326)

adm1_ren = boundary_adm1.rename(columns={"shapeISO": "adm1_code", "shapeName": "adm1_name"})
adm3_ren = boundary_adm3.rename(columns={"shapeISO": "adm3_code", "shapeName": "adm3_name"})

lk = gpd.sjoin(
    gdf_grid_z12_centroids[["quadkey", "geometry"]],
    adm1_ren[["adm1_code", "adm1_name", "geometry"]],
    how="left", predicate="within"
).drop(columns="index_right")
lk = gpd.sjoin(
    lk,
    adm3_ren[["adm3_code", "adm3_name", "geometry"]],
    how="left", predicate="within"
).drop(columns=["index_right", "geometry"])
# A centroid inside an overlap between neighbouring (imperfect) polygons
# matches more than one unit; keep one row per tile so joins on this lookup
# cannot fan out.
lk = lk.drop_duplicates(subset="quadkey", keep="first")

lk["quadkey_z12"] = lk["quadkey"].astype(str)
lookup = lk[["quadkey_z12", "adm1_code", "adm1_name", "adm3_code", "adm3_name"]]
lookup.to_csv(os.path.join(OUTPUT_DIR, "tile_admin_lookup_z12.csv"), index=False, encoding="utf-8")
print(f"Saved: {os.path.join(OUTPUT_DIR, 'tile_admin_lookup_z12.csv')}")

# STEP 3: DISCOVER PARQUET FILES
# Walk PATH_DATA/year=YYYY/quarter=Q/*.parquet and collect (year, quarter, path)
# tuples for years in [YEAR_MIN, YEAR_MAX]. Directories are visited in sorted
# name order, and files in sorted path order.
available = []
if not os.path.isdir(PATH_DATA):
    raise FileNotFoundError(f"Data directory not found: {PATH_DATA}")


def _hive_subdirs(parent, prefix, pattern):
    """Yield (int value, path) for sub-directories of ``parent`` named ``<prefix><digits>``.

    Names not starting with ``prefix`` are ignored, as are names where
    ``pattern`` finds no match.
    """
    names = sorted(name for name in os.listdir(parent) if os.path.isdir(os.path.join(parent, name)))
    for name in names:
        if not name.startswith(prefix):
            continue
        found = re.search(pattern, name)
        if found:
            yield int(found.group(1)), os.path.join(parent, name)


for year, year_path in _hive_subdirs(PATH_DATA, "year=", r'year\s*=\s*(\d+)'):
    if not YEAR_MIN <= year <= YEAR_MAX:
        continue
    for quarter, quarter_path in _hive_subdirs(year_path, "quarter=", r'quarter\s*=\s*(\d+)'):
        parquet_files = sorted(glob.glob(os.path.join(quarter_path, "*.parquet")))
        available.extend((year, quarter, path) for path in parquet_files)

print(f"Found {len(available)} parquet file(s) within {YEAR_MIN}-{YEAR_MAX}")

# STEP 4: PROCESS PARQUETS
# For each (year, quarter): load all parquet parts, validate, filter to NET_TYPE
# and to Algeria, then append national, ADM1/ADM3, z12-grid and z16-grid
# summaries to the accumulator lists below. Any exception skips only that period.
all_national = []
all_subnat = []
all_grid_long = []
all_grid_long_z16 = []

from collections import defaultdict
# Group discovered files by period so multi-part quarters are concatenated.
by_period = defaultdict(list)
for year, quarter, fp in available:
    by_period[(year, quarter)].append(fp)

for (year, quarter), file_paths in sorted(by_period.items()):
    print(f"Processing {year}-Q{quarter} ({len(file_paths)} file(s))")
    try:
        # Requires the fastparquet engine to be installed.
        dfs = [pd.read_parquet(fp, engine="fastparquet") for fp in file_paths]
        df = pd.concat(dfs, ignore_index=True)
        print(f"Loaded {len(df):,} rows; columns: {list(df.columns)}")

        # Raises on missing speed columns; returns False (skip) on an empty period.
        if not validate_ookla_data(df, year, quarter):
            continue

        # Network-type filter applies only when the data carries a "network"
        # column; otherwise every row is assumed to be NET_TYPE.
        if "network" in df.columns:
            before = len(df)
            df = df[df["network"].astype(str).str.lower() == NET_TYPE.lower()].copy()
            print(f"Filtered by NET_TYPE='{NET_TYPE}': {before:,} to {len(df):,}")
            if len(df) == 0:
                print(f"No {NET_TYPE} data in this period; skipping")
                continue
        else:
            print("Network column not found; proceeding (assume fixed).")

        # Country filter: exact membership in the z16 quadkey set from STEP 2
        # when quadkeys exist; otherwise a rectangular bounding-box test, which
        # can admit rows from neighbouring countries.
        if "quadkey" in df.columns and len(QK16_SET) > 0:
            df["quadkey"] = df["quadkey"].astype(str)
            df_dza = df[df["quadkey"].isin(QK16_SET)].copy()
            print(f"Quadkey filter kept {len(df_dza):,} rows")
        elif {"lat", "lon"}.issubset(df.columns):
            xmin, ymin, xmax, ymax = ALGERIA_BOUNDS
            df_dza = df[(df["lat"] >= ymin) & (df["lat"] <= ymax) & (df["lon"] >= xmin) & (df["lon"] <= xmax)].copy()
            print(f"Bounding box filter kept {len(df_dza):,} rows")
        else:
            print("No quadkey or lat/lon columns; skipping.")
            continue

        if len(df_dza) == 0:
            print("No Algeria records; skipping.")
            continue

        # National (unweighted)
        # Every tile row counts once regardless of its test volume. Optional
        # metrics (latency, tests, devices) become None when the column is absent.
        nat = {
            "year": year, "quarter": quarter, "date": f"{year}-Q{quarter}",
            "avg_download_mbps": (df_dza["avg_d_kbps"].mean() / 1000.0),
            "avg_upload_mbps": (df_dza["avg_u_kbps"].mean() / 1000.0),
            "median_download_mbps": (df_dza["avg_d_kbps"].median() / 1000.0),
            "median_upload_mbps": (df_dza["avg_u_kbps"].median() / 1000.0),
            "avg_latency_ms": (df_dza["avg_lat_ms"].mean() if "avg_lat_ms" in df_dza.columns else None),
            "median_latency_ms": (df_dza["avg_lat_ms"].median() if "avg_lat_ms" in df_dza.columns else None),
            "num_tiles": len(df_dza),
            "num_tests": int(df_dza["tests"].sum()) if "tests" in df_dza.columns else None,
            "num_devices": int(df_dza["devices"].sum()) if "devices" in df_dza.columns else None,
        }
        # Test-count-weighted means (negative counts clipped to 0); these keys
        # are only added when a tests column exists and the total weight is positive.
        if "tests" in df_dza.columns:
            w = df_dza["tests"].clip(lower=0)
            if w.sum() > 0:
                nat["wavg_download_mbps"] = (df_dza["avg_d_kbps"] / 1000.0 * w).sum() / w.sum()
                nat["wavg_upload_mbps"] = (df_dza["avg_u_kbps"] / 1000.0 * w).sum() / w.sum()
                if "avg_lat_ms" in df_dza.columns:
                    nat["wavg_latency_ms"] = (df_dza["avg_lat_ms"] * w).sum() / w.sum()
        all_national.append(nat)
        print(f"National mean: {nat['avg_download_mbps']:.2f} down / {nat['avg_upload_mbps']:.2f} up Mbps")

        # Subnational (unweighted)
        # Rows get tile polygons (or points for lat/lon data) and are joined to
        # ADM1 and ADM3 polygons by aggregate_by_admin.
        gdf_poly = ensure_tile_polygon_gdf(df_dza)
        adm1_stats = aggregate_by_admin(gdf_poly, boundary_adm1, "shapeISO", "shapeName", "ADM1")
        adm1_stats["year"] = year
        adm1_stats["quarter"] = quarter
        adm1_stats["date"] = f"{year}-Q{quarter}"
        all_subnat.append(adm1_stats)
        print(f"ADM1 aggregated: {len(adm1_stats)} rows")

        adm3_stats = aggregate_by_admin(gdf_poly, boundary_adm3, "shapeISO", "shapeName", "ADM3")
        adm3_stats["year"] = year
        adm3_stats["quarter"] = quarter
        adm3_stats["date"] = f"{year}-Q{quarter}"
        all_subnat.append(adm3_stats)
        print(f"ADM3 aggregated: {len(adm3_stats)} rows")

        # Grid z12 long
        # Roll tiles up to their z12 parent via the first 12 quadkey characters.
        # Speeds/latency are plain means of tile averages; tests are summed.
        if "quadkey" in df_dza.columns:
            df_dza["quadkey_z12"] = df_dza["quadkey"].astype(str).str[:12]
            agg_cols = {"avg_d_kbps": "mean", "avg_u_kbps": "mean"}
            if "avg_lat_ms" in df_dza.columns:
                agg_cols["avg_lat_ms"] = "mean"
            if "tests" in df_dza.columns:
                agg_cols["tests"] = "sum"

            grid_stats = df_dza.groupby("quadkey_z12").agg(agg_cols).reset_index()
            grid_stats["year"] = year
            grid_stats["quarter"] = quarter
            grid_stats["date"] = f"{year}-Q{quarter}"
            grid_stats["avg_download_mbps"] = grid_stats["avg_d_kbps"] / 1000.0
            grid_stats["avg_upload_mbps"] = grid_stats["avg_u_kbps"] / 1000.0

            keep_cols = ["quadkey_z12", "year", "quarter", "date", "avg_download_mbps", "avg_upload_mbps"]
            if "avg_lat_ms" in grid_stats.columns:
                keep_cols.append("avg_lat_ms")
            if "tests" in df_dza.columns:
                keep_cols.append("tests")

            all_grid_long.append(grid_stats[keep_cols])
            print(f"Grid z12 long rows added: {len(grid_stats)}")

        # Grid z16 long
        # Same summary at the native quadkey level (one row per distinct quadkey).
        if "quadkey" in df_dza.columns:
            agg_cols16 = {"avg_d_kbps": "mean", "avg_u_kbps": "mean"}
            if "avg_lat_ms" in df_dza.columns:
                agg_cols16["avg_lat_ms"] = "mean"
            if "tests" in df_dza.columns:
                agg_cols16["tests"] = "sum"

            grid_stats_z16 = (
                df_dza.groupby(df_dza["quadkey"].astype(str))
                .agg(agg_cols16)
                .reset_index()
                .rename(columns={"quadkey": "quadkey_z16"})
            )
            grid_stats_z16["year"] = year
            grid_stats_z16["quarter"] = quarter
            grid_stats_z16["date"] = f"{year}-Q{quarter}"
            grid_stats_z16["avg_download_mbps"] = grid_stats_z16["avg_d_kbps"] / 1000.0
            grid_stats_z16["avg_upload_mbps"] = grid_stats_z16["avg_u_kbps"] / 1000.0

            keep_z16 = ["quadkey_z16", "year", "quarter", "date", "avg_download_mbps", "avg_upload_mbps"]
            if "avg_lat_ms" in grid_stats_z16.columns:
                keep_z16.append("avg_lat_ms")
            if "tests" in df_dza.columns:
                keep_z16.append("tests")

            all_grid_long_z16.append(grid_stats_z16[keep_z16])
            print(f"Grid z16 long rows added: {len(grid_stats_z16)}")

    # Best-effort: a failing period is reported with its traceback and skipped,
    # so the remaining periods still run. Partial appends made before the
    # failure (e.g. the national row) are kept.
    except Exception as e:
        print(f"Error processing {year}-Q{quarter}: {e}")
        import traceback
        traceback.print_exc()
        continue

# National trends: one row per processed period, in chronological order.
if not all_national:
    raise ValueError("No national data processed; aborting save.")
_national_csv = os.path.join(OUTPUT_DIR, "algeria_national_trends_fixed.csv")
df_national = pd.DataFrame(all_national)
df_national = df_national.sort_values(["year", "quarter"]).reset_index(drop=True)
df_national.to_csv(_national_csv, index=False, encoding="utf-8")
print(f"National trends -> {_national_csv}")

# Subnational trends: ADM1 and ADM3 rows stacked and told apart by admin_level.
_subnational_csv = os.path.join(OUTPUT_DIR, "algeria_subnational_trends_fixed.csv")
if not all_subnat:
    df_sub = pd.DataFrame()
    print("No subnational aggregates to save.")
else:
    df_sub = (
        pd.concat(all_subnat, ignore_index=True)
        .sort_values(["admin_level", "admin_name", "year", "quarter"])
        .reset_index(drop=True)
    )
    df_sub.to_csv(_subnational_csv, index=False, encoding="utf-8")
    print(f"Subnational trends -> {_subnational_csv}")

# Grid long (z12), plus a wide one-row-per-tile time series joined to WKT geometry.
if not all_grid_long:
    print("No grid-long z12 data; skipped z12 wide build.")
else:
    _grid_long_csv = os.path.join(OUTPUT_DIR, "algeria_grid_data_long_z12_fixed.csv")
    df_grid_long = pd.concat(all_grid_long, ignore_index=True)
    df_grid_long.to_csv(_grid_long_csv, index=False, encoding="utf-8")
    print(f"Grid z12 long -> {_grid_long_csv}")

    print("Building z12 time-series WKT CSV")
    # (long-format value column, wide column prefix); each pivot yields one
    # "<prefix><date>" column per period.
    _wide_specs = [("avg_download_mbps", "download_"), ("avg_upload_mbps", "upload_")]
    if "avg_lat_ms" in df_grid_long.columns:
        _wide_specs.append(("avg_lat_ms", "latency_"))

    parts = []
    for _value_col, _prefix in _wide_specs:
        _wide = df_grid_long.pivot(index="quadkey_z12", columns="date", values=_value_col)
        _wide.columns = [f"{_prefix}{c}" for c in _wide.columns]
        parts.append(_wide)

    grid_wide = pd.concat(parts, axis=1).rename_axis("quadkey_z12").reset_index()
    grid_wide["quadkey_z12"] = grid_wide["quadkey_z12"].astype(str).str.strip()

    grid_wkt = pd.read_csv(os.path.join(OUTPUT_DIR, "algeria_grid_z12_wkt.csv"),
                           dtype={"quadkey": str}).rename(columns={"quadkey": "quadkey_z12"})
    grid_wkt["quadkey_z12"] = grid_wkt["quadkey_z12"].astype(str).str.strip()

    # Left merge: data tiles absent from the WKT grid keep their values with an empty wkt.
    grid_merged = grid_wide.merge(grid_wkt[["quadkey_z12", "wkt"]], on="quadkey_z12", how="left")
    out_wkt = os.path.join(OUTPUT_DIR, "algeria_grid_timeseries_fixed_wkt.csv")
    grid_merged.to_csv(out_wkt, index=False, encoding="utf-8")
    print(f"z12 time-series WKT CSV -> {out_wkt}")

# Grid long (z16): native-resolution rows, long format only.
if all_grid_long_z16:
    _grid_z16_csv = os.path.join(OUTPUT_DIR, "algeria_grid_data_long_z16_fixed.csv")
    df_grid_long_z16 = pd.concat(all_grid_long_z16, ignore_index=True)
    df_grid_long_z16.to_csv(_grid_z16_csv, index=False, encoding="utf-8")
    print(f"Grid z16 long -> {_grid_z16_csv}")

# STEP 5: POPULATION-WEIGHTED OUTPUTS
# For each year: take the mean z12 tile speed over that year's quarters,
# weight every tile by its WorldPop population sum, and aggregate to ADM1 and
# national level. Results are appended year by year to two CSVs.
print("Building population-weighted outputs")

import gc

GRID_WKT_CSV = os.path.join(OUTPUT_DIR, "algeria_grid_z12_wkt.csv")
GRID_LONG_Z12 = os.path.join(OUTPUT_DIR, "algeria_grid_data_long_z12_fixed.csv")

# WorldPop finder: one raster per in-scope year, with the year parsed from the
# file name by infer_year(). If several rasters map to the same year, the last
# one in sorted order wins.
tifs = sorted(glob.glob(os.path.join(WORLDPOP_DIR, "*.tif")))
year_to_tif = {infer_year(fp): fp for fp in tifs if infer_year(fp) in YEARS}

# Load tiles (WKT -> geometry)
tiles_df = pd.read_csv(GRID_WKT_CSV, dtype={"quadkey": str})
tiles_df = tiles_df.rename(columns={"quadkey": "quadkey_z12"})
tiles_df["geometry"] = tiles_df["wkt"].apply(wkt.loads)
g_tiles_master = gpd.GeoDataFrame(tiles_df[["quadkey_z12", "geometry"]].copy(), geometry="geometry", crs="EPSG:4326")
del tiles_df

# Load Ookla z12 long and tidy. That CSV is written only when this run
# produced z12 grid rows. Reading it unconditionally would crash on a fresh
# output directory, or silently reuse a stale file from an earlier run.
if all_grid_long:
    df_long = pd.read_csv(GRID_LONG_Z12, dtype={"quadkey_z12": str})
else:
    print("No z12 grid data produced in this run; population-weighted outputs will be empty.")
    df_long = pd.DataFrame(columns=["quadkey_z12", "year", "avg_download_mbps", "avg_upload_mbps"])
if "year" not in df_long.columns and "date" in df_long.columns:
    df_long["year"] = df_long["date"].astype(str).str[:4].astype(int)
df_long = df_long[df_long["year"].between(YEAR_MIN, YEAR_MAX)].copy()
df_long["avg_download_mbps"] = df_long["avg_download_mbps"].astype(float)
df_long["avg_upload_mbps"] = df_long["avg_upload_mbps"].astype(float)

# ADM1 polygons, renamed to the lookup's column names.
adm1_w84 = boundary_adm1.rename(columns={"shapeISO": "adm1_code", "shapeName": "adm1_name"})

# Prepare outputs: header-only CSVs now; each processed year appends its rows.
OUT_PW_NAT = os.path.join(OUTPUT_DIR, "algeria_pop_weighted_trends_2019_2025.csv")
OUT_PW_ADM1 = os.path.join(OUTPUT_DIR, "algeria_pop_weighted_adm1_z12_2019_2025.csv")
pd.DataFrame(columns=["year", "pw_download_mbps", "pw_upload_mbps"]).to_csv(OUT_PW_NAT, index=False, encoding="utf-8")
pd.DataFrame(columns=["year", "adm1_code", "adm1_name", "pw_download_mbps", "pw_upload_mbps"]).to_csv(
    OUT_PW_ADM1, index=False, encoding="utf-8"
)

for y in YEARS:
    tif = year_to_tif.get(y)
    if tif is None:
        print(f"Missing WorldPop {y}; skipping")
        continue

    sub = df_long[df_long["year"] == y]
    if sub.empty:
        print(f"No Ookla rows {y}; skipping")
        continue

    # Annual speed per z12 tile: plain mean over that year's quarterly values.
    tile_speed = (
        sub.groupby("quadkey_z12", as_index=False)
        .agg(download=("avg_download_mbps", "mean"),
             upload=("avg_upload_mbps", "mean"))
    )
    tile_speed["quadkey_z12"] = tile_speed["quadkey_z12"].astype(str).str.strip()

    # Only grid tiles that have speed data this year take part.
    g_tiles = g_tiles_master.merge(tile_speed[["quadkey_z12"]].drop_duplicates(),
                                   on="quadkey_z12", how="inner")
    if g_tiles.empty:
        print(f"Tiles subset empty {y}; skipping")
        continue

    with rasterio.Env(GDAL_CACHEMAX=GDAL_CACHE):
        with rasterio.open(tif) as src:
            r_crs = src.crs
            r_bounds = box(*src.bounds)

            tiles_proj = g_tiles.to_crs(r_crs)
            adm1_proj = adm1_w84.to_crs(r_crs)

            # Clip tiles to raster extent. The spatial-index query is only a
            # speed-up and may be unavailable. The exact intersects filter
            # below is authoritative either way, so a failure here is ignored.
            try:
                sidx = tiles_proj.sindex
                cand = list(sidx.intersection(r_bounds.bounds))
                tiles_proj = tiles_proj.iloc[cand]
            except Exception:
                pass
            tiles_proj = tiles_proj[tiles_proj.intersects(r_bounds)].reset_index(drop=True)
            if tiles_proj.empty:
                print(f"No overlapping tiles {y}; skipping")
                continue

            # Assign each tile to one ADM1. First try centroid-within (centroid
            # computed in EPSG:3857). Tiles left unmatched go to the ADM1 with
            # the largest overlap area.
            cent = tiles_proj.to_crs(3857).copy()
            cent["geometry"] = cent.geometry.centroid
            cent = cent.to_crs(r_crs)

            m1 = gpd.sjoin(
                cent[["quadkey_z12", "geometry"]],
                adm1_proj[["adm1_code", "adm1_name", "geometry"]],
                how="left", predicate="within"
            ).drop(columns="index_right").drop(columns="geometry")

            unmatched = m1[m1["adm1_code"].isna()][["quadkey_z12"]]
            if not unmatched.empty:
                cand_pairs = gpd.sjoin(
                    tiles_proj.merge(unmatched, on="quadkey_z12", how="inner")[["quadkey_z12", "geometry"]],
                    adm1_proj[["adm1_code", "adm1_name", "geometry"]],
                    how="inner", predicate="intersects"
                ).drop(columns="index_right")
                if not cand_pairs.empty:
                    adm_geom_map = dict(zip(adm1_proj["adm1_code"], adm1_proj["geometry"]))
                    inter_areas = []
                    for _, row in cand_pairs.iterrows():
                        a = row["geometry"].intersection(adm_geom_map[row["adm1_code"]])
                        inter_areas.append(a.area if (a is not None and not a.is_empty) else 0.0)
                    cand_pairs["area"] = np.array(inter_areas, dtype=float)
                    m2 = (
                        cand_pairs.sort_values(["quadkey_z12", "area"], ascending=[True, False])
                        .drop_duplicates(subset=["quadkey_z12"])
                    )[["quadkey_z12", "adm1_code", "adm1_name"]]
                else:
                    m2 = pd.DataFrame(columns=["quadkey_z12", "adm1_code", "adm1_name"])
            else:
                m2 = pd.DataFrame(columns=["quadkey_z12", "adm1_code", "adm1_name"])

            m1_ok = m1.dropna(subset=["adm1_code"])
            adm_map = pd.concat([m1_ok, m2], ignore_index=True).drop_duplicates("quadkey_z12")

            tiles_adm = tiles_proj.merge(adm_map, on="quadkey_z12", how="inner")[
                ["quadkey_z12", "adm1_code", "adm1_name", "geometry"]
            ]
            if tiles_adm.empty:
                print(f"Tile to ADM1 map empty {y}; skipping")
                continue

            # zonal_stats in chunks of CHUNK_SIZE tiles. NOTE: all_touched=True
            # counts an edge pixel in every tile it touches, so neighbouring
            # tiles share edge pixels. That is acceptable for relative weights
            # but does not give exact population totals.
            n = len(tiles_adm)
            pop_list = []
            for start in range(0, n, CHUNK_SIZE):
                end = min(start + CHUNK_SIZE, n)
                sub_tiles = tiles_adm.iloc[start:end]
                zs = zonal_stats(
                    list(sub_tiles["geometry"]),
                    tif, stats=["sum"],
                    nodata=src.nodata,
                    all_touched=True
                )
                pop = np.fromiter((d["sum"] if (d["sum"] is not None) else 0.0 for d in zs), dtype=np.float64)
                pop_list.append(pd.DataFrame({
                    "quadkey_z12": sub_tiles["quadkey_z12"].values,
                    "adm1_code": sub_tiles["adm1_code"].values,
                    "adm1_name": sub_tiles["adm1_name"].values,
                    "pop": pop
                }))
            pop_df = pd.concat(pop_list, ignore_index=True)
            del pop_list

    merged = pop_df.merge(tile_speed, on="quadkey_z12", how="left").fillna({"pop": 0.0})
    merged["dl_w"] = merged["download"] * merged["pop"]
    merged["ul_w"] = merged["upload"] * merged["pop"]
    # sum() skips NaN products, so each denominator must likewise exclude the
    # population of tiles whose metric is missing; otherwise those tiles would
    # drag the weighted mean toward zero.
    merged["pop_dl"] = merged["pop"].where(merged["download"].notna(), 0.0)
    merged["pop_ul"] = merged["pop"].where(merged["upload"].notna(), 0.0)

    # ADM1 weighted
    adm1_year = (
        merged.groupby(["adm1_code", "adm1_name"], as_index=False)
        .agg(num_dl=("dl_w", "sum"),
             num_ul=("ul_w", "sum"),
             den_dl=("pop_dl", "sum"),
             den_ul=("pop_ul", "sum"))
    )
    adm1_year["pw_download_mbps"] = np.where(adm1_year["den_dl"] > 0, adm1_year["num_dl"] / adm1_year["den_dl"], np.nan)
    adm1_year["pw_upload_mbps"] = np.where(adm1_year["den_ul"] > 0, adm1_year["num_ul"] / adm1_year["den_ul"], np.nan)
    adm1_year["year"] = y
    adm1_year[["year", "adm1_code", "adm1_name", "pw_download_mbps", "pw_upload_mbps"]] \
        .to_csv(OUT_PW_ADM1, mode="a", header=False, index=False, encoding="utf-8")

    # National weighted: pooled over every mapped tile (not a mean of ADM1 values).
    den_dl = adm1_year["den_dl"].sum()
    den_ul = adm1_year["den_ul"].sum()
    nat_pw_dl = adm1_year["num_dl"].sum() / den_dl if den_dl > 0 else np.nan
    nat_pw_ul = adm1_year["num_ul"].sum() / den_ul if den_ul > 0 else np.nan
    pd.DataFrame([{"year": y, "pw_download_mbps": nat_pw_dl, "pw_upload_mbps": nat_pw_ul}]) \
        .to_csv(OUT_PW_NAT, mode="a", header=False, index=False, encoding="utf-8")

    gc.collect()

print("PREPROCESSING COMPLETE")
print(f"Unweighted national: {os.path.join(OUTPUT_DIR, 'algeria_national_trends_fixed.csv')}")
print(f"Unweighted subnational: {os.path.join(OUTPUT_DIR, 'algeria_subnational_trends_fixed.csv')}")
print(f"Pop-weighted national: {OUT_PW_NAT}")
print(f"Pop-weighted ADM1: {OUT_PW_ADM1}")
