<a href="https://colab.research.google.com/github/eshaanraj/cdc_2025/blob/main/cdc_data_cleaning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
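Clean the dataset output by the previous script: pivot the long-format sector
forecasts into a wide Industry-by-year table and, optionally, append the new
forecast years to the plain CDC dataset.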
"""

import pandas as pd
from pathlib import Path


# Long-format forecast input with Industry / Year / Forecast columns
# (read as CSV when the suffix is .csv, otherwise as Excel).
INPUT_FORECAST = "sector_forecasts.csv"
# Existing wide dataset; only its first sheet is read.
INPUT_PLAIN    = "plain_cdc_dataset.xlsx"
# When True, also build and save the plain dataset extended with forecast years.
APPEND_TO_PLAIN = True

# Output workbooks: the wide forecasts alone (sheet "ForecastsWide") and the
# extended plain dataset (sheet "Extended").
OUT_WIDE_FORECASTS = "plain_format_forecasts.xlsx"
OUT_WIDE_EXTENDED  = "plain_format_extended.xlsx"


def load_forecasts(path: str) -> pd.DataFrame:
    """Load long-format sector forecasts from a CSV or Excel file.

    Only the ``Industry``, ``Year`` and ``Forecast`` columns are kept.
    ``Year`` and ``Forecast`` are coerced to numbers. Cells that cannot be
    parsed, and non-whole years such as ``2025.5``, become missing. Every row
    with a missing Industry, Year or Forecast is dropped.

    Args:
        path: File to read. A ``.csv`` suffix (any case) is read with
            ``pd.read_csv``; anything else with ``pd.read_excel`` (first sheet).

    Returns:
        DataFrame with columns ``Industry``, ``Year`` (nullable ``Int64``) and
        ``Forecast`` (numeric).

    Raises:
        KeyError: if any of the three required columns is missing.
    """
    p = Path(path)
    if p.suffix.lower() == ".csv":
        df = pd.read_csv(p)
    else:
        df = pd.read_excel(p)

    # Keep only the columns the rest of the pipeline needs.
    df = df[["Industry", "Year", "Forecast"]].copy()

    years = pd.to_numeric(df["Year"], errors="coerce")
    # Blank out fractional years: casting them to Int64 would raise TypeError
    # and abort the whole load instead of just discarding the bad rows.
    years = years.where(years.isna() | (years % 1 == 0))
    df["Year"] = years.astype("Int64")

    # Forecast must be numeric too; a stray text cell would otherwise leave the
    # column as object dtype and break mean aggregation downstream.
    df["Forecast"] = pd.to_numeric(df["Forecast"], errors="coerce")

    return df.dropna(subset=["Industry", "Year", "Forecast"])

def forecasts_to_wide(df_fc: pd.DataFrame) -> pd.DataFrame:
    wide = (
        df_fc
        .pivot_table(index="Industry", columns="Year", values="Forecast", aggfunc="mean")
        .reset_index()
    )
    year_cols = sorted([c for c in wide.columns if isinstance(c, (int, float, pd.Int64Dtype)) or str(c).isdigit()],
                       key=lambda x: int(str(x)))
    cols = ["Industry"] + year_cols
    wide = wide[cols]
    wide.columns = ["Industry"] + [int(str(c)) for c in wide.columns[1:]]
    return wide

def append_to_plain(wide_fc: pd.DataFrame, plain_path) -> pd.DataFrame:
    """Extend the plain dataset with forecast years it does not already contain.

    Args:
        wide_fc: Wide forecast table with an ``Industry`` column plus one column
            per year (the shape ``main`` passes in from ``forecasts_to_wide``).
        plain_path: Path to the plain Excel workbook (its first sheet is read),
            or an already-loaded DataFrame, which is copied and not mutated.

    Returns:
        DataFrame with ``Industry`` followed by every year column in ascending
        order. Rows are the outer join of both inputs on ``Industry``; a year
        present in only one input is NaN for industries from the other.

    Notes:
        Plain year headers (numeric ``2019`` or text ``"2019"``) are normalised
        to ``int`` and their values coerced to numbers (unparseable cells become
        NaN). The industry column is an existing ``Industry`` column if present,
        else the first non-year text column, which is renamed to ``Industry``.
        Forecast years already in the plain sheet are NOT overwritten, and other
        non-year columns of the plain sheet are dropped from the output.
    """
    if isinstance(plain_path, pd.DataFrame):
        plain = plain_path.copy()
    else:
        plain = pd.read_excel(plain_path, sheet_name=0)

    # read_excel yields numeric header cells as numbers and text header cells
    # as strings; normalise both to int so year columns can be indexed uniformly
    # and compared with the int year labels of wide_fc.
    year_labels = {c: int(str(c)) for c in plain.columns if str(c).isdecimal()}
    plain = plain.rename(columns=year_labels)
    plain_years = set(year_labels.values())

    for year in plain_years:
        plain[year] = pd.to_numeric(plain[year], errors="coerce")

    # Pick the industry column among non-year columns. Accept both object and
    # pandas "string" dtypes, and prefer an existing "Industry" column so the
    # rename cannot create a duplicate label.
    other_cols = [c for c in plain.columns if c not in plain_years]
    text_cols = [
        c for c in other_cols
        if plain[c].dtype == object or pd.api.types.is_string_dtype(plain[c])
    ]
    if "Industry" in other_cols:
        industry_col = "Industry"
    elif text_cols:
        industry_col = text_cols[0]
    elif other_cols:
        industry_col = other_cols[0]
    else:
        industry_col = plain.columns[0]
    plain = plain.rename(columns={industry_col: "Industry"})

    fc_years = [c for c in wide_fc.columns if c != "Industry"]
    new_years = [y for y in fc_years if y not in plain_years]

    merged = plain.merge(wide_fc[["Industry"] + new_years], on="Industry", how="outer")

    out_years = sorted(
        (c for c in merged.columns if c != "Industry" and str(c).isdecimal()),
        key=int,
    )
    return merged[["Industry"] + out_years]

def main():
    """Run the pipeline end to end.

    Reads INPUT_FORECAST and writes its wide form to OUT_WIDE_FORECASTS. When
    APPEND_TO_PLAIN is set, also merges the forecasts into INPUT_PLAIN and
    writes OUT_WIDE_EXTENDED. Saved paths are printed only after all writes.
    """
    wide_fc = forecasts_to_wide(load_forecasts(INPUT_FORECAST))
    saved = [OUT_WIDE_FORECASTS]
    with pd.ExcelWriter(OUT_WIDE_FORECASTS, engine="xlsxwriter") as writer:
        wide_fc.to_excel(writer, index=False, sheet_name="ForecastsWide")

    if APPEND_TO_PLAIN:
        extended = append_to_plain(wide_fc, INPUT_PLAIN)
        with pd.ExcelWriter(OUT_WIDE_EXTENDED, engine="xlsxwriter") as writer:
            extended.to_excel(writer, index=False, sheet_name="Extended")
        saved.append(OUT_WIDE_EXTENDED)

    for out_path in saved:
        print("Saved:", out_path)

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()