In [None]:
from __future__ import annotations

from pathlib import Path
import importlib
import numpy as np
import pandas as pd


# --- helpers -------------------------------------------------------------

def parse_number_series(s: pd.Series) -> pd.Series:
    """
    Rough equivalent to readr::parse_number(as.character(x)).

    Grouping commas are ignored and the *first* number found in each value is
    returned, so surrounding text (currency symbols, units, trailing words) is
    dropped: "$1,234.56" -> 1234.56, "4.5 out of 5" -> 4.5, "2008-2010" -> 2008.
    Values containing no number (including None / NaN) become NaN.
    The result is float-valued and keeps the input's index and name.
    """
    # as.character(): stringify everything so numbers, None and NaN are handled
    # uniformly; "None" / "nan" contain no digits and therefore end up NaN.
    text = s.astype(str).str.replace(",", "", regex=False)
    # Optional sign, then digits with optional fraction (or a bare fraction
    # like ".5"), then an optional exponent. The exponent keeps floats that
    # astype(str) renders in scientific notation (e.g. 1e-05, 1e+20) intact.
    first_number = text.str.extract(
        r"(-?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)", expand=False
    )
    return pd.to_numeric(first_number, errors="coerce").rename(s.name)


def last_non_na(series: pd.Series):
    """Return the final non-missing value of *series*, or NaN if it has none.

    Suitable as a groupby aggregation function; it is the pandas analogue of
    R's ``last(x[!is.na(x)])``.
    """
    present = series[series.notna()]
    if present.empty:
        return np.nan
    return present.iloc[-1]


# --- pipeline step loaders (replace with your real modules) --------------

def load_plan_data(year: int) -> pd.DataFrame:
    """Return the plan data for *year*.

    Stands in for ``source("R_code/1_plan-data.R"); fp <- final.plans`` by
    calling ``build(year)`` on the project module ``py_code.plan_data``,
    which you create separately.
    """
    return importlib.import_module("py_code.plan_data").build(year)


def load_service_area(year: int) -> pd.DataFrame:
    """Return the service-area table for *year* via ``py_code.service_area.build``."""
    return importlib.import_module("py_code.service_area").build(year)


def load_plan_characteristics(year: int) -> pd.DataFrame:
    """Return the plan-characteristics table for *year* with lower-cased states.

    The builder is year-specific (R: ``3_plan-characteristics-{y}.R``), so the
    module ``py_code.plan_characteristics_<year>`` is imported. The returned
    frame is a copy; when it has a ``state`` column, that column is converted
    to the pandas ``string`` dtype and lower-cased (R: ``str_to_lower(state)``).
    """
    builder = importlib.import_module(f"py_code.plan_characteristics_{year}")
    characteristics = builder.build(year).copy()
    if "state" not in characteristics.columns:
        return characteristics
    characteristics["state"] = characteristics["state"].astype("string").str.lower()
    return characteristics


def load_penetration(year: int) -> pd.DataFrame:
    """Return the penetration table for *year* via ``py_code.penetration.build``."""
    return importlib.import_module("py_code.penetration").build(year)


def load_star_ratings(year: int) -> pd.DataFrame:
    """Return the star-ratings table for *year*.

    The builder is year-specific: ``py_code.star_ratings_<year>.build(year)``.
    """
    module_name = f"py_code.star_ratings_{year}"
    return importlib.import_module(module_name).build(year)


def load_risk_rebates(year: int) -> pd.DataFrame:
    """Return the risk/rebate table for *year*.

    The builder is year-specific: ``py_code.risk_rebates_<year>.build(year)``.
    """
    module_name = f"py_code.risk_rebates_{year}"
    return importlib.import_module(module_name).build(year)


def load_benchmarks(year: int) -> pd.DataFrame:
    """Return the benchmarks table for *year*.

    The builder is year-specific: ``py_code.benchmarks_<year>.build(year)``.
    """
    module_name = f"py_code.benchmarks_{year}"
    return importlib.import_module(module_name).build(year)


def load_ffs_costs(year: int) -> pd.DataFrame:
    """Return the fee-for-service cost table for *year*.

    The builder is year-specific: ``py_code.ffs_costs_<year>.build(year)``.
    """
    module_name = f"py_code.ffs_costs_{year}"
    return importlib.import_module(module_name).build(year)


# --- main builder --------------------------------------------------------

# Derived-column helpers used by build_year_ma.

def _col(df: pd.DataFrame, name: str) -> pd.Series:
    """Return ``df[name]``, or an all-NaN float column when *name* is absent.

    These inputs were originally fetched with ``DataFrame.get``, which
    tolerates missing columns. Several loaders are year-specific, so
    presumably not every column exists in every year (TODO confirm).
    Returning NaN rather than ``None`` keeps the derived columns numeric.
    """
    if name in df.columns:
        return df[name]
    return pd.Series(np.nan, index=df.index, dtype="float64")


def _as_bool(cond: pd.Series) -> np.ndarray:
    """Turn a possibly-nullable boolean Series into a plain numpy bool array.

    Missing results count as False, which is how dplyr's ``case_when`` treats
    an NA condition. ``np.select`` cannot consume pandas NA directly.
    """
    return np.asarray(cond.fillna(False), dtype=bool)


def _star_rating(df: pd.DataFrame) -> np.ndarray:
    """Pick each row's star rating. The first matching rule wins.

    partd == "No"                          -> partc_score
    partd == "Yes" and partcd_score is NA  -> partc_score
    partd == "Yes" and partcd_score not NA -> partcd_score
    otherwise                              -> NaN
    """
    partd = _col(df, "partd")
    partc_score = _col(df, "partc_score")
    partcd_score = _col(df, "partcd_score")
    no_partd = _as_bool(partd == "No")
    has_partd = _as_bool(partd == "Yes")
    has_cd = _as_bool(partcd_score.notna())
    return np.select(
        [no_partd, has_partd & ~has_cd, has_partd & has_cd],
        [partc_score, partc_score, partcd_score],
        default=np.nan,
    )


def _ma_rate(df: pd.DataFrame) -> np.ndarray:
    """Pick the ``risk_*`` rate column that applies to each row.

    The choice depends on year and Star_Rating, and the first matching rule wins:
    year < 2012 uses ``risk_ab``. For 2012-2014, the star-specific column is used
    (ratings below 3 use ``risk_star25``; unrated rows use ``risk_star35``).
    From 2015, ratings of 4 or more use ``risk_bonus5``, lower ratings use
    ``risk_bonus0``, and unrated rows use ``risk_bonus35``. Anything else is NaN.
    """
    year = df["year"]
    star = df["Star_Rating"]
    before = year < 2012
    window = (year >= 2012) & (year < 2015)
    after = year >= 2015
    rules = [
        (before, "risk_ab"),
        (window & (star == 5), "risk_star5"),
        (window & (star == 4.5), "risk_star45"),
        (window & (star == 4), "risk_star4"),
        (window & (star == 3.5), "risk_star35"),
        (window & (star == 3), "risk_star3"),
        (window & (star < 3), "risk_star25"),
        (window & star.isna(), "risk_star35"),
        (after & (star >= 4), "risk_bonus5"),
        (after & (star < 4), "risk_bonus0"),
        (after & star.isna(), "risk_bonus35"),
    ]
    return np.select(
        [_as_bool(cond) for cond, _ in rules],
        [_col(df, column) for _, column in rules],
        default=np.nan,
    )


def _basic_premium(df: pd.DataFrame) -> np.ndarray:
    """Compute the basic premium. The first matching rule wins.

    rebate_partc > 0                                        -> 0
    partd == "No", premium present, premium_partc missing   -> premium
    otherwise                                               -> premium_partc
    """
    rebate_partc = _col(df, "rebate_partc")
    partd = _col(df, "partd")
    premium = _col(df, "premium")
    premium_partc = _col(df, "premium_partc")
    return np.select(
        [
            _as_bool(rebate_partc.gt(0)),
            _as_bool((partd == "No") & premium.notna() & premium_partc.isna()),
        ],
        [0, premium],
        default=premium_partc,
    )


def _bid(df: pd.DataFrame) -> np.ndarray:
    """Compute the risk-normalized bid. The first matching rule wins.

    rebate_partc == 0 and basic_premium > 0 -> (payment_partc + basic_premium) / riskscore_partc
    rebate_partc > 0 or basic_premium == 0  -> payment_partc / riskscore_partc
    otherwise                               -> NaN
    """
    rebate_partc = _col(df, "rebate_partc")
    payment_partc = _col(df, "payment_partc")
    riskscore_partc = _col(df, "riskscore_partc")
    basic_premium = df["basic_premium"]
    return np.select(
        [
            _as_bool(rebate_partc.eq(0) & basic_premium.gt(0)),
            _as_bool(rebate_partc.gt(0) | basic_premium.eq(0)),
        ],
        [
            (payment_partc + basic_premium) / riskscore_partc,
            payment_partc / riskscore_partc,
        ],
        default=np.nan,
    )


def _avg_ffscost(df: pd.DataFrame) -> np.ndarray:
    """Compute per-enrollee FFS reimbursement.

    The result sums reimb/enroll over the Parts (A, B) with positive enrollment.
    It is 0 when both enrollments are 0, and NaN in every other case (e.g.
    missing or negative enrollment).
    """
    parta_enroll = _col(df, "parta_enroll")
    partb_enroll = _col(df, "partb_enroll")
    parta_reimb = _col(df, "parta_reimb")
    partb_reimb = _col(df, "partb_reimb")
    return np.select(
        [
            _as_bool((parta_enroll == 0) & (partb_enroll == 0)),
            _as_bool((parta_enroll == 0) & (partb_enroll > 0)),
            _as_bool((parta_enroll > 0) & (partb_enroll == 0)),
            _as_bool((parta_enroll > 0) & (partb_enroll > 0)),
        ],
        [
            0,
            partb_reimb / partb_enroll,
            parta_reimb / parta_enroll,
            (parta_reimb / parta_enroll) + (partb_reimb / partb_enroll),
        ],
        default=np.nan,
    )


def build_year_ma(year: int) -> pd.DataFrame:
    """Build, save and return the finalized MA dataset for one *year*.

    Python port of the R pipeline. Every source table is built by its
    ``load_*`` loader. The plan data is then restricted to contract/fips/year
    keys present in the service-area data, filtered (excluding territories,
    SNPs, and plan ids 800-899), and left-joined with star ratings,
    penetration, plan characteristics, risk/rebates, benchmarks and FFS
    costs. Finally the derived columns ``Star_Rating``, ``ma_rate``,
    ``basic_premium``, ``bid`` and ``avg_ffscost`` are added.

    Parameters
    ----------
    year:
        Calendar year passed to every loader.

    Returns
    -------
    pandas.DataFrame
        The joined, filtered and enriched data for *year*.

    Side effects
    ------------
    Prints progress messages and writes the result, tab-separated and without
    an index, to ``data/output/ma_data_{year}.txt``. The directory is created
    if needed.
    """
    print("building plan data")
    fp = load_plan_data(year)

    print("building service area data")
    fsa = load_service_area(year)

    print("building plan characteristics data")
    fls = load_plan_characteristics(year)

    print("building penetration data")
    fpen = load_penetration(year)

    print("building star ratings data")
    fsr = load_star_ratings(year)

    print("building risk rebates data")
    frr = load_risk_rebates(year)

    print("building benchmarks data")
    fbm = load_benchmarks(year)

    print("building ffs costs data")
    ffs = load_ffs_costs(year)

    print(f"finalizing ma data for year {year}")

    # R: inner_join(fsa %>% select(contractid, fips, year), by=c("contractid","fips","year"))
    fsa_key = fsa[["contractid", "fips", "year"]]
    final_ma = fp.merge(fsa_key, on=["contractid", "fips", "year"], how="inner")

    # R: filter(!state %in% c("VI","PR","MP","GU","AS",""), snp=="No",
    #           (planid < 800 | planid >= 900), !is.na(planid), !is.na(fips))
    final_ma = final_ma[
        (~final_ma["state"].isin(["VI", "PR", "MP", "GU", "AS", ""])) &
        (final_ma["snp"] == "No") &
        ((final_ma["planid"] < 800) | (final_ma["planid"] >= 900)) &
        (final_ma["planid"].notna()) &
        (final_ma["fips"].notna())
    ].copy()

    # R: left_join(fsr, by=c("contractid","year"))
    final_ma = final_ma.merge(fsr, on=["contractid", "year"], how="left")

    # R: left_join(fpen %>% rename(state_long=state, county_long=county)
    #                   %>% mutate(state_long=str_to_lower(state_long)), by=c("fips","year"))
    fpen2 = fpen.rename(columns={"state": "state_long", "county": "county_long"})
    if "state_long" in fpen2.columns:
        fpen2["state_long"] = fpen2["state_long"].astype("string").str.lower()
    final_ma = final_ma.merge(fpen2, on=["fips", "year"], how="left")

    # R: group_by(state) %>% summarize(state_name = last(state_long[!is.na(state_long)]))
    # dropna=False keeps an NA-state group, as dplyr's group_by does. The
    # merge below then matches NaN keys to each other, like dplyr's left_join.
    final_state = final_ma.groupby("state", as_index=False, dropna=False).agg(
        state_name=("state_long", last_non_na)
    )
    final_ma = final_ma.merge(final_state, on="state", how="left")

    # R: left_join(fls, by=c("contractid","planid","state_name"="state","county","year"))
    fls2 = fls.rename(columns={"state": "state_name"}) if "state" in fls.columns else fls
    final_ma = final_ma.merge(
        fls2,
        on=["contractid", "planid", "state_name", "county", "year"],
        how="left",
    )

    # R: left_join(frr %>% select(-contract_name, -plan_type), by=c("contractid","planid","year"))
    frr2 = frr.drop(columns=[c for c in ["contract_name", "plan_type"] if c in frr.columns])
    final_ma = final_ma.merge(frr2, on=["contractid", "planid", "year"], how="left")

    # R: left_join(fbm %>% mutate(ssa=as.numeric(ssa)), by=c("ssa","year"))
    # Both sides are coerced to numeric so the join keys share a dtype.
    fbm2 = fbm.copy()
    fbm2["ssa"] = pd.to_numeric(fbm2["ssa"], errors="coerce")
    final_ma["ssa"] = pd.to_numeric(final_ma["ssa"], errors="coerce")
    final_ma = final_ma.merge(fbm2, on=["ssa", "year"], how="left")

    # Derived columns (vectorized case_when equivalents); order matters
    # because ma_rate reads Star_Rating and bid reads basic_premium.
    final_ma["Star_Rating"] = _star_rating(final_ma)
    final_ma["ma_rate"] = _ma_rate(final_ma)
    final_ma["basic_premium"] = _basic_premium(final_ma)
    final_ma["bid"] = _bid(final_ma)

    # R: left_join(ffs %>% select(-state), by=c("ssa","year"))
    ffs2 = ffs.drop(columns=[c for c in ["state"] if c in ffs.columns]).copy()
    ffs2["ssa"] = pd.to_numeric(ffs2["ssa"], errors="coerce")
    final_ma = final_ma.merge(ffs2, on=["ssa", "year"], how="left")

    final_ma["avg_ffscost"] = _avg_ffscost(final_ma)

    # R: write_tsv(final.ma, ...)
    out_path = Path(f"data/output/ma_data_{year}.txt")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    final_ma.to_csv(out_path, sep="\t", index=False)

    return final_ma


# --- run all years -------------------------------------------------------

# Build each year in turn, then stack the yearly frames into a single panel
# and save it alongside the per-year files that build_year_ma writes.
years = range(2008, 2024)
frames: list[pd.DataFrame] = []
for y in years:
    print(f"Building MA data for year: {y}")
    year_frame = build_year_ma(y)
    frames.append(year_frame)

final_ma_full = pd.concat(frames, ignore_index=True)

# build_year_ma already creates this directory; creating it here as well keeps
# the final write independent of that side effect.
Path("data/output").mkdir(exist_ok=True, parents=True)
final_ma_full.to_csv(
    "data/output/ma_data_full.txt",
    index=False,
    sep="\t",
)