### Data loading

In [1]:
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Literal, Union
import pandas as pd

# ---------------------------------------------------------------------------
# Module-level docs & type aliases
# ---------------------------------------------------------------------------
# NOTE: the triple-quoted string below follows the import statements, so
# Python does not treat it as the module docstring; it is evaluated as a no-op
# expression and serves purely as in-source documentation.

"""
Generic loaders & converters for regional-timeseries data.

Canonical representation
------------------------
    CanonicalData = dict[str, pd.DataFrame]
        key   – region name
        value – DataFrame indexed by datetime; columns = features

Disk layouts we support
-----------------------
    "multi" : MultiIndex columns (level-0 = region, level-1 = feature)
    "long"  : Normal DataFrame with a *region* column
    "wide"  : Wide DataFrame whose column names are "<region><sep><feature>"
"""

# In-memory canonical form: region name -> that region's feature DataFrame.
CanonicalData = Dict[str, pd.DataFrame]
# Tag naming one of the three supported on-disk layouts described above.
Layout        = Literal["multi", "long", "wide"]

# ---------------------------------------------------------------------------
# ─── Converters: DISK ➜ CANONICAL ────────────────────────────────────────────
# ---------------------------------------------------------------------------

def _multi_to_dict(
    df: pd.DataFrame,
    *,
    cols: List[str] | None = None
) -> CanonicalData:
    """
    Convert a Multi-Index column DataFrame to the canonical dict form.

    Parameters
    ----------
    df   : DataFrame with two column levels (region, feature)
    cols : optional list of feature names to *keep* (others are dropped)

    Returns
    -------
    Dict[str, DataFrame]  mapping region → feature-matrix
    """
    out: CanonicalData = {}
    for region in df.columns.get_level_values(0).unique():
        sub = df.xs(region, axis=1, level=0)
        if cols is not None:
            sub = sub[[c for c in cols if c in sub.columns]]
        out[region] = sub.copy()
    return out


def _long_to_dict(
    df: pd.DataFrame,
    *,
    region_col: str = "region",
    cols: List[str] | None = None,
) -> CanonicalData:
    """
    Convert a ‘long’ (tidy) DataFrame to canonical dict form.

    A *long* frame must contain a `region_col` column; all other columns
    are interpreted as features.

    Parameters
    ----------
    df         : tidy DataFrame with duplicate timestamps per region
    region_col : column that identifies the region
    cols       : optional feature filter

    Returns
    -------
    Canonical dict keyed by region
    """
    if region_col not in df.columns:
        raise KeyError(f"Expected column '{region_col}' in long-layout dataframe.")

    features = [c for c in df.columns if c != region_col]
    if cols is not None:
        features = [c for c in features if c in cols]

    out: CanonicalData = {}
    for region, grp in df.groupby(region_col):
        frame = grp.drop(columns=region_col)
        out[region] = frame[features].copy()
    return out


def _wide_to_dict(
    df: pd.DataFrame,
    *,
    sep: str = "-",
    cols: List[str] | None = None,
) -> CanonicalData:
    """
    Convert a wide DataFrame with ‘region-feature’ column names into canon dict.

    Parameters
    ----------
    df   : DataFrame whose columns look like "tokyo-demand"
    sep  : separator between region and feature
    cols : optional feature filter

    Returns
    -------
    Canonical dict keyed by region
    """
    parts = df.columns.to_series().str.split(sep, n=1, expand=True)
    if parts.isna().any().any():
        raise ValueError(
            f"Column names do not all match <region>{sep}<feature> pattern."
        )

    out: CanonicalData = {}
    for region in parts[0].unique():
        mask      = parts[0] == region
        sub_cols  = df.columns[mask]
        features  = parts[1][mask]
        sub_frame = df[sub_cols].copy()
        sub_frame.columns = features
        if cols is not None:
            sub_frame = sub_frame[[c for c in cols if c in sub_frame.columns]]
        out[region] = sub_frame
    return out

# ---------------------------------------------------------------------------
# ─── Converters: CANONICAL ➜ DISK ────────────────────────────────────────────
# ---------------------------------------------------------------------------

def _dict_to_multi(data: CanonicalData) -> pd.DataFrame:
    """
    Combine a canonical dict into a Multi-Index column DataFrame.
    """
    frames = []
    for region, df in data.items():
        tmp = df.copy()
        tmp.columns = pd.MultiIndex.from_product([[region], tmp.columns])
        frames.append(tmp)
    return pd.concat(frames, axis=1).sort_index(axis=1)


def _dict_to_long(
    data: CanonicalData,
    *,
    region_col: str = "region"
) -> pd.DataFrame:
    """
    Combine a canonical dict into a long (tidy) DataFrame.
    """
    frames = []
    for region, df in data.items():
        tmp = df.copy()
        tmp[region_col] = region
        frames.append(tmp)
    combined = pd.concat(frames)
    order = [c for c in combined.columns if c != region_col] + [region_col]
    return combined[order]


def _dict_to_wide(
    data: CanonicalData,
    *,
    sep: str = "-"
) -> pd.DataFrame:
    """
    Combine a canonical dict into wide ‘region-feature’ columns.
    """
    frames = []
    for region, df in data.items():
        tmp = df.copy()
        tmp.columns = [f"{region}{sep}{c}" for c in tmp.columns]
        frames.append(tmp)
    return pd.concat(frames, axis=1)

# ---------------------------------------------------------------------------
# ─── Public helpers: one-file load / convert ────────────────────────────────
# ---------------------------------------------------------------------------

def load_parquet_as_canonical(
    path: str | Path,
    *,
    layout: Layout,
    region_col: str = "region",
    sep: str = "-",
    cols: List[str] | None = None,
) -> CanonicalData:
    """
    Read a parquet file of known *layout* and return canonical dict form.

    Parameters
    ----------
    path       : file path
    layout     : "multi", "long", or "wide" (must match the file)
    region_col : name of the region column for *long* layout
    sep        : region-feature separator for *wide* layout
    cols       : optional feature subset

    Returns
    -------
    CanonicalData

    Raises
    ------
    ValueError : `layout` is not one of the supported names. This is checked
                 *before* the file is read, so a typo does not cost a full
                 parquet read (or surface as an unrelated I/O error).
    """
    if layout not in ("multi", "long", "wide"):
        raise ValueError(f"Unsupported layout '{layout}'.")

    df = pd.read_parquet(path)
    if layout == "multi":
        return _multi_to_dict(df, cols=cols)
    if layout == "long":
        return _long_to_dict(df, region_col=region_col, cols=cols)
    return _wide_to_dict(df, sep=sep, cols=cols)


def canonical_to_layout(
    data: CanonicalData,
    layout: Layout,
    *,
    region_col: str = "region",
    sep: str = "-",
) -> pd.DataFrame:
    """
    Render a canonical dict in one of the on-disk layouts.

    Parameters
    ----------
    data       : canonical dict (region -> feature DataFrame)
    layout     : target layout, one of "multi", "long", "wide"
    region_col : region column name used by the long layout
    sep        : region/feature separator used by the wide layout

    Returns
    -------
    DataFrame in the requested layout.

    Raises
    ------
    ValueError : `layout` is not a supported name.
    """
    # Builders are zero-arg callables so only the selected converter runs.
    builders = (
        ("multi", lambda: _dict_to_multi(data)),
        ("long", lambda: _dict_to_long(data, region_col=region_col)),
        ("wide", lambda: _dict_to_wide(data, sep=sep)),
    )
    for name, build in builders:
        if layout == name:
            return build()
    raise ValueError(f"Unsupported target layout '{layout}'.")

# ---------------------------------------------------------------------------
# ─── Helpers for multi-file ingestion ───────────────────────────────────────
# ---------------------------------------------------------------------------

def _collect_by_layout(
    paths: List[str] | None,
    layout: Layout,
    *,
    cols: List[str] | None,
    region_col: str,
    sep: str,
) -> Dict[str, List[pd.DataFrame]]:
    """
    Internal helper: load every file in `paths` (given its layout)
    and collate DataFrames by region.

    Returns
    -------
    Dict[region, list[DataFrame]]
    """
    bucket: Dict[str, List[pd.DataFrame]] = {}
    for p in paths or []:
        data = load_parquet_as_canonical(
            p, layout=layout, region_col=region_col, sep=sep, cols=cols
        )
        for region, df in data.items():
            bucket.setdefault(region, []).append(df)
    return bucket


def load_region_dataset(
    *,
    region: str,
    multi_paths: List[str] | None = None,
    long_paths: List[str] | None = None,
    wide_paths: List[str] | None = None,
    cols: List[str] | None = None,
    region_col: str = "region",
    sep: str = "-",
) -> pd.DataFrame:
    """
    Aggregate **one** region’s data from any number of parquet files.

    Files are read in the order multi → long → wide (list order within each
    layout); the frames found for `region` are concatenated column-wise.

    Parameters
    ----------
    region      : name of the region to extract
    multi_paths : list of files in *multi* layout
    long_paths  : list of files in *long* layout
    wide_paths  : list of files in *wide* layout
    cols        : optional feature subset
    region_col  : region column (long layout)
    sep         : separator (wide layout)

    Returns
    -------
    DataFrame with union of timestamps & features for that region, sorted
    by index. Duplicate column names are de-duplicated, keeping the
    last-read version. The dropped copies are discarded entirely (they do
    not fill NaNs in the kept column), so if two files cover different time
    spans the kept column is NaN where its own file has no rows.

    Raises
    ------
    KeyError : `region` was not found in any supplied file.
    """
    frames: List[pd.DataFrame] = []
    for layout, paths in (
        ("multi", multi_paths),
        ("long", long_paths),
        ("wide", wide_paths),
    ):
        for p in paths or []:
            data = load_parquet_as_canonical(
                p, layout=layout, region_col=region_col, sep=sep, cols=cols
            )
            if region in data:
                frames.append(data[region])

    if not frames:
        raise KeyError(f"Region '{region}' not found in supplied paths.")

    combined = pd.concat(frames, axis=1)
    # duplicated(keep="last") flags every occurrence except the last, so the
    # negated mask keeps the last-read copy (the default keep="first" would
    # keep the first-read copy, contradicting the documented contract).
    combined = combined.loc[:, ~combined.columns.duplicated(keep="last")]
    combined.sort_index(inplace=True)
    return combined


def load_all_regions_dataset(
    *,
    multi_paths: List[str] | None = None,
    long_paths: List[str] | None = None,
    wide_paths: List[str] | None = None,
    regions: List[str] | None = None,
    cols: List[str] | None = None,
    region_col: str = "region",
    sep: str = "-",
) -> CanonicalData:
    """
    Load **every** region (or a specified subset) from the supplied file lists.

    Parameters
    ----------
    multi_paths : parquet files in *multi* layout
    long_paths  : parquet files in *long* layout
    wide_paths  : parquet files in *wide* layout
    regions     : optional subset of region names to return
    cols        : optional feature subset
    region_col  : region column (long layout)
    sep         : separator (wide layout)

    Returns
    -------
    CanonicalData containing one DataFrame per region. Each frame is the
    column-wise concatenation of that region's frames across files (union of
    timestamps, sorted by index); duplicate column names keep the last-read
    version, matching `load_region_dataset`. When `regions` is given, the
    keys follow its order.

    Raises
    ------
    KeyError : some name in `regions` was not found in any file.
    """
    # Read & collate per-layout (multi → long → wide, as in load_region_dataset)
    bucket: Dict[str, List[pd.DataFrame]] = {}

    for paths, layout in (
        (multi_paths, "multi"),
        (long_paths, "long"),
        (wide_paths, "wide"),
    ):
        part = _collect_by_layout(
            paths,
            layout,
            cols=cols,
            region_col=region_col,
            sep=sep,
        )
        for region, frames in part.items():
            bucket.setdefault(region, []).extend(frames)

    # Filter to requested subset (if any), in the caller's order
    if regions is not None:
        bucket = {r: bucket[r] for r in regions if r in bucket}

    # Merge frames per region; keep the last-read copy of duplicate columns
    out: CanonicalData = {}
    for region, frames in bucket.items():
        df = pd.concat(frames, axis=1)
        df = df.loc[:, ~df.columns.duplicated(keep="last")]
        df.sort_index(inplace=True)
        out[region] = df

    # Explicitly requested regions that never appeared are an error
    if regions is not None:
        missing = set(regions) - set(out)
        if missing:
            raise KeyError(
                f"The following requested regions were not found in any file: "
                f"{', '.join(sorted(missing))}"
            )

    return out
# ═════════════════════════════════════════════════════════════════════════════
# Revised loaders: same converters, plus a per-file pre-processing pass
# ═════════════════════════════════════════════════════════════════════════════
# NOTE: `from __future__ import annotations` is already the first statement of
# this cell. A `__future__` import that follows other statements is a
# SyntaxError when this code runs as a plain .py module, so it is not repeated.
from pathlib import Path
from typing import Dict, List, Literal
import pandas as pd

# -------------  Canonical type & layout tags  --------------------------------
CanonicalData = Dict[str, pd.DataFrame]
Layout        = Literal["multi", "long", "wide"]

# ----------------------------------------------------------------------------- 
# ═════════════════════════════ I/O CONVERSION ════════════════════════════════
# The converters defined earlier in this cell are reused unchanged (they must
# stay defined above this point):
#     _multi_to_dict, _long_to_dict, _wide_to_dict,
#     _dict_to_multi,  _dict_to_long,  _dict_to_wide,
#     load_parquet_as_canonical, canonical_to_layout
# -----------------------------------------------------------------------------


# ----------------------------------------------------------------------------- 
# ══════════════════════════ PRE-MERGE FEATURE PASS ═══════════════════════════
# ----------------------------------------------------------------------------- 

from jp_da_imb.utils.time import construct_time_features   # ← your existing util


def preprocess_region_df(
    df: pd.DataFrame,
    *,
    start_date: str | pd.Timestamp | None = None,
    end_date:   str | pd.Timestamp | None = None,
    freq: str | None = "30min",
    na_removal: bool = True,
    add_time_feats: bool = True,
) -> pd.DataFrame:
    """
    Apply *per-file* hygiene & feature engineering **before** merging with other
    DataFrames for the same region.

    Steps (all optional, controlled by kwargs)
    -----------------------------------------
    1. Clip to [start_date, end_date] (if provided; both ends inclusive)
    2. Resample to `freq` (mean aggregation)       — skip if `freq=None`
    3. Drop rows with any NA                       — only if `na_removal=True`
    4. Add calendar/time features (and cast them
       to `category`) using your `construct_time_features`
       — only if `add_time_feats=True`

    Parameters
    ----------
    df             : input frame; its index is compared with timestamps and
                     resampled, so it must be datetime-like (DatetimeIndex)
    start_date,
    end_date       : inclusive bounds parsed with `pd.to_datetime`. A date-only
                     string such as "2024-01-31" means midnight, so the rest of
                     that day is excluded by `end_date`.
    freq           : pandas offset alias; default "30min" (identical to the
                     legacy "30T", whose "T" alias pandas has deprecated)
    na_removal     : drop every row containing any NA after resampling
    add_time_feats : call `construct_time_features` and cast its known
                     columns to `category`

    Returns
    -------
    A new DataFrame; `df` itself is never modified.
    """
    out = df.copy()

    # 1. Clip date range -------------------------------------------------------
    if start_date is not None:
        out = out.loc[out.index >= pd.to_datetime(start_date)]
    if end_date is not None:
        out = out.loc[out.index <= pd.to_datetime(end_date)]

    # 2. Resample --------------------------------------------------------------
    if freq is not None:
        out = out.resample(freq).mean()

    # 3. Remove NAs ------------------------------------------------------------
    if na_removal:
        out = out.dropna()

    # 4. Calendar / categorical time features ---------------------------------
    if add_time_feats:
        # Return value is ignored: assumed to add columns to `out` in place.
        construct_time_features(out)
        time_cols = (
            "weekday", "hour", "month", "quarter",
            "koma", "koma_week", "is_holiday",
            "is_peak", "is_weekend",
        )
        for c in time_cols:
            # Only cast the columns the helper actually produced.
            # isinstance(..., CategoricalDtype) replaces the deprecated
            # pd.api.types.is_categorical_dtype.
            if c in out.columns and not isinstance(out[c].dtype, pd.CategoricalDtype):
                out[c] = pd.Categorical(out[c])

    return out


# ----------------------------------------------------------------------------- 
# ═══════════════  MULTI-FILE INGEST, NOW WITH PRE-PASS  ══════════════════════
# ----------------------------------------------------------------------------- 

def _collect_by_layout(
    paths: List[str] | None,
    layout: Layout,
    *,
    cols: List[str] | None,
    region_col: str,
    sep: str,
    preprocess_kwargs: dict,
) -> Dict[str, List[pd.DataFrame]]:
    """
    Load every parquet file in `paths` -> canonical dict,
    run `preprocess_region_df` on each region DataFrame,
    and bucket them by region.
    """
    bucket: Dict[str, List[pd.DataFrame]] = {}
    for p in paths or []:
        data = load_parquet_as_canonical(
            p, layout=layout, region_col=region_col, sep=sep, cols=cols
        )
        for region, df in data.items():
            df_proc = preprocess_region_df(df, **preprocess_kwargs)
            bucket.setdefault(region, []).append(df_proc)
    return bucket


def load_region_dataset(
    *,
    region: str,
    multi_paths: List[str] | None = None,
    long_paths:  List[str] | None = None,
    wide_paths:  List[str] | None = None,
    cols:       List[str] | None = None,
    region_col: str = "region",
    sep: str = "-",
    # --- pre-processing options ----------------------------------------------
    start_date: str | pd.Timestamp | None = None,
    end_date:   str | pd.Timestamp | None = None,
    freq: str | None = "30min",
    na_removal: bool = True,
    add_time_feats: bool = True,
) -> pd.DataFrame:
    """
    Build a **single-region** DataFrame by reading any combination of
    *multi*, *long*, and *wide* parquet files, applying the pre-processing
    steps to each file **before** they are merged together.

    Files are read multi → long → wide (list order within each layout). The
    per-file frames are concatenated column-wise (union of timestamps, sorted
    by index) and duplicate column names keep the last-read version. Because
    every file gets its own time-feature columns, those duplicates are
    resolved the same way; the kept copy is NaN wherever its file has no
    rows (the dropped copies are not used to fill it).

    Parameters
    ----------
    region, multi_paths, long_paths, wide_paths, cols, region_col, sep :
        as in the earlier version.
    start_date, end_date, freq, na_removal, add_time_feats :
        forwarded to `preprocess_region_df` for every file. `freq` defaults
        to "30min" (same meaning as the deprecated "30T" alias).

    Raises
    ------
    KeyError : `region` was not found in any supplied file.
    """
    preprocess_kwargs = dict(
        start_date=start_date,
        end_date=end_date,
        freq=freq,
        na_removal=na_removal,
        add_time_feats=add_time_feats,
    )

    frames: List[pd.DataFrame] = []
    for layout, paths in (
        ("multi", multi_paths),
        ("long",  long_paths),
        ("wide",  wide_paths),
    ):
        for p in paths or []:
            data = load_parquet_as_canonical(
                p, layout=layout, region_col=region_col, sep=sep, cols=cols
            )
            if region in data:
                frames.append(preprocess_region_df(data[region], **preprocess_kwargs))

    if not frames:
        raise KeyError(f"Region '{region}' not found in supplied paths.")

    combined = pd.concat(frames, axis=1)
    # keep="last": retain the last-read copy of each duplicated column name
    # (the default keep="first" would retain the first one instead).
    combined = combined.loc[:, ~combined.columns.duplicated(keep="last")]
    combined.sort_index(inplace=True)
    return combined


def load_all_regions_dataset(
    *,
    multi_paths: List[str] | None = None,
    long_paths:  List[str] | None = None,
    wide_paths:  List[str] | None = None,
    regions:     List[str] | None = None,
    cols:       List[str] | None = None,
    region_col: str = "region",
    sep: str = "-",
    # --- pre-processing options ----------------------------------------------
    start_date: str | pd.Timestamp | None = None,
    end_date:   str | pd.Timestamp | None = None,
    freq: str | None = "30min",
    na_removal: bool = True,
    add_time_feats: bool = True,
) -> CanonicalData:
    """
    Load **all** regions (or a given subset) from the provided file lists,
    applying the pre-merge feature-creation steps to every individual file.

    Each region's frames are concatenated column-wise (union of timestamps,
    sorted by index); duplicate column names keep the last-read version,
    matching `load_region_dataset`. When `regions` is given, the result keys
    follow its order.

    Parameters
    ----------
    multi_paths, long_paths, wide_paths, regions, cols, region_col, sep :
        as in the earlier version.
    start_date, end_date, freq, na_removal, add_time_feats :
        forwarded to `preprocess_region_df` for every file. `freq` defaults
        to "30min" (same meaning as the deprecated "30T" alias).

    Raises
    ------
    KeyError : some name in `regions` was not found in any file.
    """
    preprocess_kwargs = dict(
        start_date=start_date,
        end_date=end_date,
        freq=freq,
        na_removal=na_removal,
        add_time_feats=add_time_feats,
    )

    bucket: Dict[str, List[pd.DataFrame]] = {}

    for paths, layout in (
        (multi_paths, "multi"),
        (long_paths,  "long"),
        (wide_paths,  "wide"),
    ):
        part = _collect_by_layout(
            paths,
            layout,
            cols=cols,
            region_col=region_col,
            sep=sep,
            preprocess_kwargs=preprocess_kwargs,
        )
        for region, frames in part.items():
            bucket.setdefault(region, []).extend(frames)

    # subset filter (keeps the caller's order)
    if regions is not None:
        bucket = {r: bucket[r] for r in regions if r in bucket}

    # merge frames per region; keep the last-read copy of duplicate columns
    out: CanonicalData = {}
    for region, frames in bucket.items():
        df = pd.concat(frames, axis=1)
        df = df.loc[:, ~df.columns.duplicated(keep="last")]
        df.sort_index(inplace=True)
        out[region] = df

    # raise on missing explicit requests
    if regions is not None:
        missing = set(regions) - set(out)
        if missing:
            raise KeyError(
                "The following requested regions were not found in any file: "
                f"{', '.join(sorted(missing))}"
            )

    return out


In [None]:
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Literal
import pandas as pd
from jp_da_imb.utils.time import construct_time_features   # ← your util

# -----------------------------------------------------------
# Canonical type & layout tags
# -----------------------------------------------------------
CanonicalData = Dict[str, pd.DataFrame]   # region name -> feature DataFrame
Layout        = Literal["multi", "long", "wide"]   # supported on-disk layouts

# -----------------------------------------------------------
#  DISK ➜ CANONICAL converters (unchanged from earlier)
#  _multi_to_dict, _long_to_dict, _wide_to_dict
#  CANONICAL ➜ DISK converters
#  _dict_to_multi, _dict_to_long, _dict_to_wide
#  load_parquet_as_canonical, canonical_to_layout
#  (omitted here for brevity but keep exactly as before)
#  NOTE: these are NOT defined in this cell. The loaders below call
#  `load_parquet_as_canonical`, so an earlier cell defining it must have
#  been run first (or the converters pasted in here).
# -----------------------------------------------------------


# -----------------------------------------------------------
#  Per-file pre-merge feature pass (unchanged)
# -----------------------------------------------------------
def preprocess_region_df(
    df: pd.DataFrame,
    *,
    start_date: str | pd.Timestamp | None = None,
    end_date:   str | pd.Timestamp | None = None,
    freq: str | None = "30min",
    na_removal: bool = True,
    add_time_feats: bool = True,
) -> pd.DataFrame:
    """
    Per-file hygiene & feature pass, applied before frames are merged.

    Steps (each optional)
    ---------------------
    1. Clip to [start_date, end_date] (both ends inclusive)
    2. Resample to `freq` with mean aggregation     — skipped if `freq=None`
    3. Drop rows containing any NA                  — only if `na_removal=True`
    4. Call `construct_time_features` and cast its known columns to
       `category`                                   — only if `add_time_feats=True`

    Parameters
    ----------
    df             : input frame; its index is compared with timestamps and
                     resampled, so it must be datetime-like (DatetimeIndex)
    start_date,
    end_date       : inclusive bounds parsed with `pd.to_datetime`; a date-only
                     string means midnight of that date
    freq           : pandas offset alias; default "30min" (identical to the
                     legacy "30T", whose "T" alias pandas has deprecated)
    na_removal     : drop every row containing any NA after resampling
    add_time_feats : add calendar features via `construct_time_features`

    Returns
    -------
    A new DataFrame; `df` itself is never modified.
    """
    out = df.copy()

    if start_date is not None:
        out = out.loc[out.index >= pd.to_datetime(start_date)]
    if end_date is not None:
        out = out.loc[out.index <= pd.to_datetime(end_date)]

    if freq is not None:
        out = out.resample(freq).mean()

    if na_removal:
        out = out.dropna()

    if add_time_feats:
        # Return value is ignored: assumed to add columns to `out` in place.
        construct_time_features(out)
        for c in (
            "weekday", "hour", "month", "quarter",
            "koma", "koma_week", "is_holiday",
            "is_peak", "is_weekend",
        ):
            # Only cast columns the helper actually produced; the isinstance
            # check replaces the deprecated pd.api.types.is_categorical_dtype.
            if c in out.columns and not isinstance(out[c].dtype, pd.CategoricalDtype):
                out[c] = pd.Categorical(out[c])

    return out


# -----------------------------------------------------------
#  SINGLE-REGION LOADER  (long files can differ in region_col)
# -----------------------------------------------------------
def load_region_dataset(
    *,
    region: str,
    multi_paths: List[str] | None = None,
    long_paths:  List[str] | None = None,
    wide_paths:  List[str] | None = None,
    long_region_cols: List[str] | str | None = None,
    cols:       List[str] | None = None,
    sep: str = "-",
    # --- pre-processing options
    start_date: str | pd.Timestamp | None = None,
    end_date:   str | pd.Timestamp | None = None,
    freq: str | None = "30min",
    na_removal: bool = True,
    add_time_feats: bool = True,
) -> pd.DataFrame:
    """
    Aggregate *one* region’s data from any mix of multi/long/wide parquet files.

    Files are read multi → wide → long (list order within each layout). Each
    file's frame for `region` goes through `preprocess_region_df` and the
    results are concatenated column-wise (union of timestamps, sorted by
    index). Duplicate column names, including the time features every file
    receives, keep the last-read version; dropped copies are not used to
    fill NaNs in the kept one.

    Parameters
    ----------
    region            : which region to return
    long_region_cols  : list parallel to `long_paths` giving the column
                        name that holds the region label *for each* file,
                        or a single string used for every long file.
                        Omit to default every long file to "region".
    freq              : default "30min" (same as the deprecated "30T")
    All other args are as before.

    Raises
    ------
    ValueError : `long_region_cols` is a list whose length differs from
                 `long_paths` (checked before any file is read).
    KeyError   : `region` was not found in any supplied file.
    """
    preprocess_kwargs = dict(
        start_date=start_date,
        end_date=end_date,
        freq=freq,
        na_removal=na_removal,
        add_time_feats=add_time_feats,
    )

    # ---- resolve per-file region columns up front (fail fast) ---------------
    if long_paths:
        if long_region_cols is None:
            long_region_cols = ["region"] * len(long_paths)
        elif isinstance(long_region_cols, str):
            # Broadcast a single name; zipping a str would pair each path
            # with one *character* of it.
            long_region_cols = [long_region_cols] * len(long_paths)
        if len(long_region_cols) != len(long_paths):
            raise ValueError(
                "long_region_cols must be the same length as long_paths"
            )

    frames: List[pd.DataFrame] = []

    # ---- multi layout (single region column level 0) ------------------------
    for p in multi_paths or []:
        data = load_parquet_as_canonical(p, layout="multi", cols=cols)
        if region in data:
            frames.append(preprocess_region_df(data[region], **preprocess_kwargs))

    # ---- wide layout --------------------------------------------------------
    for p in wide_paths or []:
        data = load_parquet_as_canonical(p, layout="wide", sep=sep, cols=cols)
        if region in data:
            frames.append(preprocess_region_df(data[region], **preprocess_kwargs))

    # ---- long layout  (per-file region_col) ---------------------------------
    for p, reg_col in zip(long_paths or [], long_region_cols or []):
        data = load_parquet_as_canonical(
            p, layout="long", region_col=reg_col, cols=cols
        )
        if region in data:
            frames.append(preprocess_region_df(data[region], **preprocess_kwargs))

    if not frames:
        raise KeyError(f"Region '{region}' not found in supplied paths.")

    combined = pd.concat(frames, axis=1)
    # keep="last": retain the last-read copy of each duplicated column name.
    combined = combined.loc[:, ~combined.columns.duplicated(keep="last")]
    combined.sort_index(inplace=True)
    return combined


# -----------------------------------------------------------
#  ALL-REGIONS LOADER  (long paths paired with region_cols)
# -----------------------------------------------------------
def load_all_regions_dataset(
    *,
    multi_paths: List[str] | None = None,
    long_paths:  List[str] | None = None,
    wide_paths:  List[str] | None = None,
    long_region_cols: List[str] | str | None = None,
    regions:     List[str] | None = None,
    cols:        List[str] | None = None,
    sep: str = "-",
    # --- pre-processing options
    start_date: str | pd.Timestamp | None = None,
    end_date:   str | pd.Timestamp | None = None,
    freq: str | None = "30min",
    na_removal: bool = True,
    add_time_feats: bool = True,
) -> CanonicalData:
    """
    Load every region (or the ones in `regions`) applying the pre-merge
    feature pass. Long-format files can each declare their own region column
    via `long_region_cols` (a list parallel to `long_paths`, or one string
    for all of them; default "region").

    Regions not listed in `regions` are skipped *before* preprocessing, so no
    work is spent on frames that would be discarded. Each region's frames are
    concatenated column-wise (union of timestamps, sorted by index) and
    duplicate column names keep the last-read version, matching
    `load_region_dataset`. When `regions` is given, keys follow its order.

    Raises
    ------
    ValueError : `long_region_cols` length differs from `long_paths`
                 (checked before any file is read).
    KeyError   : some name in `regions` was not found in any file.
    """
    preprocess_kwargs = dict(
        start_date=start_date,
        end_date=end_date,
        freq=freq,
        na_removal=na_removal,
        add_time_feats=add_time_feats,
    )

    # ---- resolve per-file region columns up front (fail fast) ------------
    if long_paths:
        if long_region_cols is None:
            long_region_cols = ["region"] * len(long_paths)
        elif isinstance(long_region_cols, str):
            # Broadcast a single name instead of zipping its characters.
            long_region_cols = [long_region_cols] * len(long_paths)
        if len(long_region_cols) != len(long_paths):
            raise ValueError(
                "long_region_cols must be the same length as long_paths"
            )

    wanted = None if regions is None else set(regions)
    bucket: Dict[str, List[pd.DataFrame]] = {}

    # -------- helper to merge into bucket -----------------
    def _collect(data_dict):
        for r, df in data_dict.items():
            if wanted is not None and r not in wanted:
                continue  # not requested: skip the preprocessing cost
            bucket.setdefault(r, []).append(
                preprocess_region_df(df, **preprocess_kwargs)
            )

    # ---- multi files -------------------------------------
    for p in multi_paths or []:
        _collect(load_parquet_as_canonical(p, layout="multi", cols=cols))

    # ---- wide files --------------------------------------
    for p in wide_paths or []:
        _collect(load_parquet_as_canonical(p, layout="wide", sep=sep, cols=cols))

    # ---- long files (per-file region_col) ----------------
    for p, reg_col in zip(long_paths or [], long_region_cols or []):
        _collect(load_parquet_as_canonical(
            p, layout="long", region_col=reg_col, cols=cols
        ))

    # ---- order by request / merge ------------------------
    if regions is not None:
        bucket = {r: bucket[r] for r in regions if r in bucket}

    out: CanonicalData = {}
    for r, frames in bucket.items():
        df = pd.concat(frames, axis=1)
        df = df.loc[:, ~df.columns.duplicated(keep="last")]
        df.sort_index(inplace=True)
        out[r] = df

    if regions is not None:
        missing = set(regions) - set(out)
        if missing:
            raise KeyError(
                "The following requested regions were not found in any file: "
                + ", ".join(sorted(missing))
            )

    return out


### Target Creation 

In [None]:
import pandas as pd
from typing import Dict

# ------------------------------------------------------------------
# Helper 1 ── single DataFrame
# ------------------------------------------------------------------
def add_target_column(
    df: pd.DataFrame,
    minuend: str,
    subtrahend: str,
    *,
    target_name: str = "target",
    trim_to_valid: bool = False,
    inplace: bool = False,
) -> pd.DataFrame:
    """
    Add `target_name = df[minuend] - df[subtrahend]` to **one** DataFrame.

    Parameters
    ----------
    minuend, subtrahend : column names to subtract
    target_name         : name of new column (default "target")
    trim_to_valid       : if True, cut the DataFrame to the row range
                          [first_non_NA, last_non_NA] of the new column
                          (by position, so duplicate/unsorted index labels
                          are handled correctly)
    inplace             : True → add the column to the original frame,
                          False → work on a copy. Trimming always returns a
                          new row slice; even with inplace=True the original
                          frame keeps all its rows, so use the return value.

    Returns
    -------
    DataFrame  (original or copy, according to `inplace`; trimmed if requested)

    Raises
    ------
    KeyError   : `minuend` or `subtrahend` is not a column of `df`.
    ValueError : `trim_to_valid=True` but the new column is entirely NA.
    """
    # Validate first so a bad call does not pay for a full copy.
    if minuend not in df.columns or subtrahend not in df.columns:
        raise KeyError(f"Missing '{minuend}' or '{subtrahend}' in columns.")

    if not inplace:
        df = df.copy()

    df[target_name] = df[minuend] - df[subtrahend]

    if trim_to_valid:
        # Positional trim: label slicing `df.loc[first:last]` raises or
        # over-selects when the index has duplicate or unsorted labels.
        valid = df[target_name].notna().to_numpy()
        if not valid.any():
            raise ValueError(f"`{target_name}` has no non-NA values to trim to.")
        start = int(valid.argmax())                      # first True
        stop = len(valid) - int(valid[::-1].argmax())    # one past last True
        df = df.iloc[start:stop]

    return df


# ------------------------------------------------------------------
# Helper 2 ── canonical dict (ℹ️ deep-copies unless inplace=True)
# ------------------------------------------------------------------
CanonicalData = Dict[str, pd.DataFrame]

def add_target_to_canonical(
    data: CanonicalData,
    minuend: str,
    subtrahend: str,
    *,
    target_name: str = "target",
    trim_to_valid: bool = False,
    inplace: bool = False,
) -> CanonicalData:
    """
    Add the `target_name` column to **each** DataFrame in a canonical dict.

    Parameters
    ----------
    data                : dict[str, DataFrame]
    minuend, subtrahend : columns to subtract
    target_name         : new column name
    trim_to_valid       : if True, slice every DataFrame (independently, by
                          position) to the first/last non-NA row of the new
                          column
    inplace             : True → mutate the dict & frames in place

    Returns
    -------
    CanonicalData  (original dict or a new dict of deep-copied frames,
    per `inplace`)

    Raises
    ------
    KeyError   : some region lacks `minuend` or `subtrahend`. All regions are
                 checked before anything is modified, so an inplace call
                 never leaves the dict half-updated on this error.
    ValueError : `trim_to_valid=True` and a region's target is entirely NA.
    """
    for region, df in data.items():
        if minuend not in df.columns or subtrahend not in df.columns:
            raise KeyError(
                f"Region '{region}' is missing '{minuend}' or '{subtrahend}'."
            )

    out = data if inplace else {r: df.copy() for r, df in data.items()}

    for region, df in out.items():
        df[target_name] = df[minuend] - df[subtrahend]

        if trim_to_valid:
            # Positional trim, robust to duplicate or unsorted index labels.
            valid = df[target_name].notna().to_numpy()
            if not valid.any():
                raise ValueError(
                    f"Region '{region}': `{target_name}` has no non-NA values to trim to."
                )
            start = int(valid.argmax())
            stop = len(valid) - int(valid[::-1].argmax())
            # Replacing a value for an existing key is safe during iteration.
            out[region] = df.iloc[start:stop]

    return out
# How to use them:
# Single DataFrame example
# (assumes `tokyo_df` already exists and holds both price columns below).
# Reassigning the result matters: with trim_to_valid=True the function returns
# a sliced frame even when inplace=True.
tokyo_df = add_target_column(
    tokyo_df,
    minuend="pri_imb_down_¥_kwh_jst_min30_a",
    subtrahend="pri_spot_jepx_¥_kwh_jst_min30_a",
    target_name="ImbalanceMinusSpot",
    trim_to_valid=True,     # ← crops to first/last non-NA
    inplace=True,
)

# Canonical dict example
# (assumes `data` is a region -> DataFrame dict whose frames all contain
# "demand" and "supply" columns).
data = add_target_to_canonical(
    data,
    minuend="demand",
    subtrahend="supply",
    target_name="net_demand",
    trim_to_valid=True,     # each region trimmed independently
    inplace=False,          # get a new dict back
)
# With trim_to_valid=True, both functions return frames that start at the first
# timestamp where the target is defined and end at the last one, which is handy
# for modelling pipelines that cannot handle leading/trailing NaNs.

### Counting connected partitions (connectivity states) of a graph

In [None]:
import networkx as nx
from functools import lru_cache

def connected_subsets_containing(G, start, available):
    """
    Yield every connected vertex subset that contains `start` and otherwise
    uses only vertices from `available`, each subset exactly once.

    Subsets are grown depth-first, one frontier vertex at a time.  The same
    subset can be reached through several insertion orders (e.g. {a, b, c}
    via b-then-c and via c-then-b), so subsets already reached are recorded
    in `seen` and never expanded twice.  This is exact because the frontier
    of a subset S is always N(S) ∩ available − S, i.e. it depends only on S,
    not on the path that produced it.

    Parameters
    ----------
    G         : graph exposing ``neighbors(v)`` (e.g. a ``networkx.Graph``)
    start     : vertex that every yielded subset contains
    available : iterable of vertices the subsets may use besides `start`

    Yields
    ------
    frozenset
        A connected subset containing `start`.
    """
    available = set(available)
    first = frozenset([start])
    seen = {first}
    # stack entries: (subset, its frontier = neighbours usable to grow it)
    stack = [(first, (set(G.neighbors(start)) & available) - first)]
    while stack:
        current, frontier = stack.pop()
        yield current
        for v in frontier:                # expand by one frontier vertex
            grown = current | {v}
            if grown in seen:             # already reached via another order
                continue
            seen.add(grown)
            grown_frontier = (frontier | (set(G.neighbors(v)) & available)) - grown
            stack.append((grown, grown_frontier))

def all_connected_partitions(G):
    """
    Enumerate every partition of G's vertex set into connected blocks.

    Each partition is a tuple of frozensets ordered by their sorted members;
    the function returns the set of all such tuples.  Vertices must be
    mutually comparable, because the smallest remaining vertex anchors each
    recursion step (so every partition is produced from a single canonical
    choice) and blocks are ordered by their sorted contents.
    """

    @lru_cache(maxsize=None)
    def _partitions_of(pool):
        # The empty vertex pool has exactly one partition: the empty one.
        if not pool:
            return {()}
        anchor = min(pool)
        found = set()
        for block in connected_subsets_containing(G, anchor, pool):
            leftover = pool - block
            found.update(
                tuple(sorted((block, *rest), key=sorted))
                for rest in _partitions_of(leftover)
            )
        return found

    return _partitions_of(frozenset(G.nodes))

# -----------------  example  -----------------
G = nx.Graph()
G.add_nodes_from(range(9))
G.add_edges_from([(0,1),(1,2),(2,3),(3,4),    # a 5-node path
                  (5,6),(6,7),(7,5),          # a triangle
                  (4,5)])                     # a bridge 4-5
# node 8 is isolated

parts = all_connected_partitions(G)
print("number of possible connectivity states:", len(parts))
print("some examples:")
# `parts` is a set, so slicing `list(parts)` picks an arbitrary 10; and tuples
# of frozensets compare with `<` (proper-subset), which is not a total order.
# Sort every partition by its blocks' sorted member lists instead, then show
# the first 10 of that deterministic order.
ordered = sorted(parts, key=lambda part: [sorted(block) for block in part])
for p in ordered[:10]:
    print("  ", [sorted(block) for block in p])

### Combining regions 

In [None]:
from typing import Dict, Sequence
import pandas as pd

CanonicalData = Dict[str, pd.DataFrame]

def combine_regions(
    data: CanonicalData,
    regions: Sequence[str],
    *,
    new_region: str | None = None,          # ← now optional
    add: Sequence[str] | None = None,
    average: Sequence[str] | None = None,
    drop: Sequence[str] | None = None,
    keep_first: Sequence[str] | None = None,
    remove_originals: bool = False,
    inplace: bool = False,
) -> CanonicalData:
    """
    Collapse multiple regional DataFrames into a single aggregate.

    Parameters
    ----------
    data, regions            : see earlier version
    new_region               : dict key for the combined frame; if omitted,
                               defaults to \"region1_region2_...\" (order-preserving)
    add, average, drop,
    keep_first, remove_originals, inplace
                              : same semantics as before

    Returns
    -------
    CanonicalData
    """
    if not regions:
        raise ValueError("`regions` must contain at least one region name.")

    missing = [r for r in regions if r not in data]
    if missing:
        raise KeyError("Regions not found: " + ", ".join(missing))

    # auto-generate name if not supplied
    if new_region is None:
        new_region = "_".join(regions)

    add      = set(add or [])
    average  = set(average or [])
    drop_set = set(drop or [])
    keep1    = set(keep_first or [])

    # guard against overlaps
    overlap = (add & average) | (add & keep1) | (average & keep1)
    overlap |= (add | average | keep1) & drop_set
    if overlap:
        raise ValueError(
            "Columns appear in multiple operation lists: "
            + ", ".join(sorted(overlap))
        )

    out: CanonicalData = data if inplace else {r: df.copy() for r, df in data.items()}

    # build wide MultiIndex frame: level-0 = region
    wide = pd.concat([out[r] for r in regions], axis=1, keys=regions)
    features = wide.columns.get_level_values(1).unique()

    combined_cols = {}
    for col in features:
        if col in drop_set:
            continue
        if col in add:
            combined_cols[col] = wide.xs(col, level=1, axis=1).sum(axis=1)
        elif col in average:
            combined_cols[col] = wide.xs(col, level=1, axis=1).mean(axis=1)
        elif col in keep1:
            combined_cols[col] = out[regions[0]][col]
        else:
            raise ValueError(
                f"No rule provided for column '{col}'. "
                "Put it in add/average/keep_first/drop."
            )

    combined_df = pd.concat(combined_cols, axis=1).sort_index()

    out[new_region] = combined_df

    if remove_originals:
        for r in regions:
            out.pop(r, None)

    return out

### Combining regions with weighted averages

In [None]:
from typing import Dict, Sequence, Iterable
import pandas as pd
import numpy as np

CanonicalData = Dict[str, pd.DataFrame]

def combine_regions(
    data: CanonicalData,
    regions: Sequence[str],
    *,
    new_region: str | None = None,
    add: Sequence[str] | None = None,
    average: Sequence[str] | None = None,
    drop: Sequence[str] | None = None,
    keep_first: Sequence[str] | None = None,
    weights: Iterable[float] | None = None,          # ← NEW
    remove_originals: bool = False,
    inplace: bool = False,
) -> CanonicalData:
    """
    Collapse several regions into a single aggregate DataFrame.

    Parameters
    ----------
    data               : canonical dict (region → DataFrame)
    regions            : list / tuple of region keys to combine
    new_region         : key for combined frame; defaults to \"r1_r2_...\" if None
    add                : columns to *sum* across regions
    average            : columns to (weighted) *average* across regions
    drop               : columns to discard
    keep_first         : columns copied from the first region in `regions`
    weights            : iterable of weights (same length as `regions`) used
                         for columns in `average`. If None → simple mean
    remove_originals   : delete source regions after combining
    inplace            : modify `data` in place

    Returns
    -------
    CanonicalData (mutated or copied, depending on `inplace`)
    """
    if not regions:
        raise ValueError("`regions` must contain at least one region name.")

    # Validate weights --------------------------------------------------------
    if weights is not None:
        weights = list(weights)
        if len(weights) != len(regions):
            raise ValueError("`weights` length must match `regions` length.")
        w = np.array(weights, dtype=float)
        if (w < 0).any():
            raise ValueError("`weights` must be non-negative.")
        if w.sum() == 0:
            raise ValueError("Sum of `weights` must be > 0.")
        w = w / w.sum()                          # normalise once
    else:
        w = None

    # Missing-region check
    missing = [r for r in regions if r not in data]
    if missing:
        raise KeyError("Regions not found: " + ", ".join(missing))

    # Auto name if needed
    if new_region is None:
        new_region = "_".join(regions)

    # Sets for quick lookup
    add      = set(add or [])
    average  = set(average or [])
    drop_set = set(drop or [])
    keep1    = set(keep_first or [])

    # Overlap guard
    overlap = (add & average) | (add & keep1) | (average & keep1)
    overlap |= (add | average | keep1) & drop_set
    if overlap:
        raise ValueError(
            "Columns appear in multiple operation lists: "
            + ", ".join(sorted(overlap))
        )

    # Work on copy unless requested in place
    out: CanonicalData = data if inplace else {r: df.copy() for r, df in data.items()}

    # Concatenate wide frame  (lvl-0 = region, lvl-1 = feature)
    wide = pd.concat([out[r] for r in regions], axis=1, keys=regions)
    features = wide.columns.get_level_values(1).unique()

    combined_cols = {}
    for col in features:
        if col in drop_set:
            continue

        if col in add:
            combined_cols[col] = wide.xs(col, level=1, axis=1).sum(axis=1)

        elif col in average:
            block = wide.xs(col, level=1, axis=1)
            if w is None:
                combined_cols[col] = block.mean(axis=1)
            else:
                combined_cols[col] = (block.values * w).sum(axis=1)

        elif col in keep1:
            combined_cols[col] = out[regions[0]][col]

        else:
            raise ValueError(
                f"No rule specified for column '{col}'. "
                "Put it in add, average, keep_first, or drop."
            )

    combined_df = pd.concat(combined_cols, axis=1).sort_index()

    out[new_region] = combined_df

    if remove_originals:
        for r in regions:
            out.pop(r, None)

    return out

### Dummy Columns 

In [None]:
from typing import Dict
import pandas as pd

CanonicalData = Dict[str, pd.DataFrame]

def add_wide_area_coupling(
    data: CanonicalData,
    *,
    category_col: str = "wide_area_category",
    prefix: str = "is_same_wide_area_",
    inplace: bool = False,
) -> CanonicalData:
    """
    For every DataFrame in *canonical* `data`, append 0/1 indicator columns
    that flag whether its `category_col` value equals the corresponding value
    of each other region **row-by-row** (matched on the index).

    New columns are named  f"{prefix}{other_region}".  The self column (e.g.
    "is_same_wide_area_tokyo" in Tokyo's frame) is 1 wherever the region's own
    category is non-NaN, giving the "base region" flag.  NaNs never match, so
    timestamps missing from another region give 0 for that region.

    Each region keeps exactly its own rows.  Indicator columns left over from
    a previous call are replaced instead of being duplicated.

    Parameters
    ----------
    data         : dict[region, DataFrame]
    category_col : column containing the wide-area ID in *each* frame
    prefix       : column-name prefix for the indicators
    inplace      : if True, replace the entries of `data`; otherwise work on
                   a deep copy

    Returns
    -------
    CanonicalData  (same object or a deep copy, according to `inplace`)

    Raises
    ------
    KeyError : if any frame lacks `category_col`
    """
    # ----------------------------------------------------------------- checks
    missing = [r for r, df in data.items() if category_col not in df.columns]
    if missing:
        raise KeyError(
            f"Column '{category_col}' not found in: " + ", ".join(missing)
        )

    out = data if inplace else {r: df.copy() for r, df in data.items()}

    # ----------------------------------------------------------------- gather
    # Wide frame: columns = region, index = union of all timestamps
    cat_df = pd.concat(
        {r: df[category_col] for r, df in out.items()}, axis=1
    )  # outer join → NaNs where region missing

    # ----------------------------------------------------------------- expand indicators
    for region in out:
        # Equality test row-wise (NaNs never match)
        same = cat_df.eq(cat_df[region], axis=0).astype(int)

        # Rename columns, keep *all* regions incl. self
        same = same.rename(columns=lambda r: f"{prefix}{r}")

        # `cat_df` spans the union of every region's timestamps; restrict to
        # this region's own index so no foreign rows get inserted.
        same = same.reindex(out[region].index)

        # Replace indicators from an earlier call rather than duplicating them.
        base = out[region].drop(columns=same.columns, errors="ignore")
        out[region] = pd.concat([base, same], axis=1)

    return out


### EWMA

In [None]:
import pandas as pd
import numpy as np
from typing import Dict

CanonicalData = Dict[str, pd.DataFrame]

# ---------------------------------------------------------------------------
# 1)  Single-DataFrame EWMA scaler
# ---------------------------------------------------------------------------
def scale_df_ewm(
    df: pd.DataFrame,
    *,
    halflife: int = 1_000,
    burnin_steps: int = 1_000,
    remove_na: bool = True,
) -> pd.DataFrame:
    """
    Standardise every numeric column of `df` with an exponentially weighted
    running mean and standard deviation (``adjust=False``).

    Non-numeric columns are passed through unscaled.  The statistics of the
    first `burnin_steps` rows are discarded and every leading gap is then
    back-filled, so those rows are scaled with statistics taken from later
    rows.  1e-8 is added to the std to avoid division by zero.

    Parameters
    ----------
    df           : input frame (not modified)
    halflife     : EWMA halflife, in rows
    burnin_steps : number of leading rows whose statistics are discarded
    remove_na    : drop rows that still contain NaN after scaling

    Returns
    -------
    DataFrame in the original column order with every column cast to float
    (non-numeric columns must therefore be float-convertible).
    """
    frame = df.copy()

    numeric = [name for name in frame.columns
               if pd.api.types.is_numeric_dtype(frame[name])]
    passthrough = [name for name in frame.columns if name not in numeric]

    if not numeric:
        standardised = pd.DataFrame(index=frame.index)
    else:
        window = frame[numeric].ewm(halflife=halflife, adjust=False)
        centre = window.mean()
        spread = window.std() + 1e-8
        for stat in (centre, spread):
            stat.iloc[:burnin_steps, :] = np.nan
        standardised = (frame[numeric] - centre.bfill()) / spread.bfill()

    result = pd.concat([standardised, frame[passthrough]], axis=1)
    result = result[frame.columns].astype(float)
    return result.dropna() if remove_na else result


# ---------------------------------------------------------------------------
# 2)  Canonical-dict EWMA scaler
# ---------------------------------------------------------------------------
def scale_canonical_ewm(
    data: CanonicalData,
    *,
    halflife: int = 1_000,
    burnin_steps: int = 1_000,
    remove_na: bool = True,
    inplace: bool = False,
) -> CanonicalData:
    """
    Apply `scale_df_ewm` to **every** DataFrame in a canonical dict.

    `scale_df_ewm` copies its input before working on it, so the frames are
    passed straight through; no extra defensive copy of each frame is made.

    Parameters
    ----------
    halflife, burnin_steps, remove_na : forwarded to `scale_df_ewm`
    inplace : if True, replace the entries of `data` with the scaled frames
              and return `data`; if False, return a new dict and leave `data`
              untouched.  Input frames are never modified either way.

    Returns
    -------
    CanonicalData  (`data` itself or a new dict, per `inplace`)
    """
    scaled = {
        region: scale_df_ewm(
            df,
            halflife=halflife,
            burnin_steps=burnin_steps,
            remove_na=remove_na,
        )
        for region, df in data.items()
    }
    if inplace:
        # Only values change; keys and their order are preserved.
        data.update(scaled)
        return data
    return scaled


### Expanding mean and standard deviation

In [None]:
import pandas as pd
import numpy as np
from typing import Dict

CanonicalData = Dict[str, pd.DataFrame]

# ---------------------------------------------------------------------------
# 1)  Single-DataFrame EXPANDING scaler
# ---------------------------------------------------------------------------
def scale_df_expanding(
    df: pd.DataFrame,
    *,
    burnin_steps: int = 1_000,
    remove_na: bool = True,
) -> pd.DataFrame:
    """
    Standardise every numeric column of `df` with its cumulative (expanding)
    mean and standard deviation.

    Non-numeric columns pass through unscaled.  The statistics of the first
    `burnin_steps` rows are discarded and leading gaps are back-filled, so
    those rows are scaled with statistics from later rows.  The std needs at
    least two observations, and 1e-8 is added to it to avoid division by zero.

    Parameters
    ----------
    df           : input frame (not modified)
    burnin_steps : number of leading rows whose statistics are discarded
    remove_na    : drop rows that still contain NaN after scaling

    Returns
    -------
    DataFrame in the original column order, every column cast to float.
    """
    frame = df.copy()

    numeric = [name for name in frame.columns
               if pd.api.types.is_numeric_dtype(frame[name])]
    passthrough = [name for name in frame.columns if name not in numeric]

    if not numeric:
        standardised = pd.DataFrame(index=frame.index)
    else:
        values = frame[numeric]
        centre = values.expanding(min_periods=1).mean()
        spread = values.expanding(min_periods=2).std() + 1e-8   # std needs ≥2 points
        for stat in (centre, spread):
            stat.iloc[:burnin_steps, :] = np.nan
        standardised = (values - centre.bfill()) / spread.bfill()

    result = pd.concat([standardised, frame[passthrough]], axis=1)
    result = result[frame.columns].astype(float)
    return result.dropna() if remove_na else result


# ---------------------------------------------------------------------------
# 2)  Canonical-dict EXPANDING scaler
# ---------------------------------------------------------------------------
def scale_canonical_expanding(
    data: CanonicalData,
    *,
    burnin_steps: int = 1_000,
    remove_na: bool = True,
    inplace: bool = False,
) -> CanonicalData:
    """
    Apply `scale_df_expanding` to every DataFrame in a canonical dict.

    `scale_df_expanding` copies its input before working on it, so the frames
    are passed straight through; no extra defensive copy is made here.

    Parameters
    ----------
    burnin_steps, remove_na : forwarded to `scale_df_expanding`
    inplace : if True, replace the entries of `data` with the scaled frames
              and return `data`; if False, return a new dict and leave `data`
              untouched.  Input frames are never modified either way.

    Returns
    -------
    CanonicalData
    """
    scaled = {
        region: scale_df_expanding(
            df,
            burnin_steps=burnin_steps,
            remove_na=remove_na,
        )
        for region, df in data.items()
    }
    if inplace:
        # Only values change; keys and their order are preserved.
        data.update(scaled)
        return data
    return scaled

In [None]:
import pandas as pd
import numpy as np
from typing import Dict, Sequence

# ---------------------------------------------------------------------------
# Common alias
# ---------------------------------------------------------------------------
CanonicalData = Dict[str, pd.DataFrame]

# ---------------------------------------------------------------------------
# Utility used by both scalers
# ---------------------------------------------------------------------------
def _prep_cols(df: pd.DataFrame, target: Sequence[str] | str | None):
    """
    Split columns into:
        numeric_not_target, categorical, target_cols
    Returns each list plus the original column order.
    """
    if target is None:
        target = []
    elif isinstance(target, str):
        target = [target]

    missing = [c for c in target if c not in df.columns]
    if missing:
        raise KeyError("Target column(s) not found: " + ", ".join(missing))

    target_set = set(target)

    numeric_cols  = [
        c for c in df.columns
        if pd.api.types.is_numeric_dtype(df[c]) and c not in target_set
    ]
    cat_cols      = [c for c in df.columns if c not in numeric_cols and c not in target_set]
    target_cols   = list(target)

    return numeric_cols, cat_cols, target_cols, list(df.columns)

# ---------------------------------------------------------------------------
# 1)  SINGLE-DATAFRAME  ▸  EWMA
# ---------------------------------------------------------------------------
def scale_df_ewm(
    df: pd.DataFrame,
    *,
    halflife: int = 1_000,
    burnin_steps: int = 1_000,
    remove_na: bool = True,
    target: Sequence[str] | str | None = None,
) -> pd.DataFrame:
    """
    Standardise the numeric, non-target columns of `df` with an exponentially
    weighted running mean/std (``adjust=False``); `target` column(s) and
    non-numeric columns are left unscaled.

    The first `burnin_steps` rows of statistics are discarded and back-filled
    from later rows; 1e-8 is added to the std.  Every column is cast to float
    on return, in the original column order; with `remove_na`, rows that still
    contain NaN are dropped.
    """
    frame = df.copy()
    to_scale, passthrough, targets, column_order = _prep_cols(frame, target)

    if to_scale:
        window = frame[to_scale].ewm(halflife=halflife, adjust=False)
        centre, spread = window.mean(), window.std() + 1e-8
        centre.iloc[:burnin_steps, :] = np.nan
        spread.iloc[:burnin_steps, :] = np.nan
        scaled = (frame[to_scale] - centre.bfill()) / spread.bfill()
    else:
        scaled = pd.DataFrame(index=frame.index)

    combined = pd.concat([scaled, frame[passthrough], frame[targets]], axis=1)
    combined = combined[column_order].astype(float)
    if remove_na:
        combined = combined.dropna()
    return combined

# ---------------------------------------------------------------------------
# 2)  CANONICAL DICT  ▸  EWMA
# ---------------------------------------------------------------------------
def scale_canonical_ewm(
    data: CanonicalData,
    *,
    halflife: int = 1_000,
    burnin_steps: int = 1_000,
    remove_na: bool = True,
    target: Sequence[str] | str | None = None,
    inplace: bool = False,
) -> CanonicalData:
    """
    Apply `scale_df_ewm` to every DataFrame in a canonical dict.

    `scale_df_ewm` copies its input before working on it, so the frames are
    passed straight through; no extra defensive copy is made here.

    Parameters
    ----------
    halflife, burnin_steps, remove_na, target : forwarded to `scale_df_ewm`
    inplace : if True, replace the entries of `data` with the scaled frames
              and return `data`; if False, return a new dict and leave `data`
              untouched.  Input frames are never modified either way.

    Returns
    -------
    CanonicalData
    """
    scaled = {
        r: scale_df_ewm(
            df,
            halflife=halflife,
            burnin_steps=burnin_steps,
            remove_na=remove_na,
            target=target,
        )
        for r, df in data.items()
    }
    if inplace:
        # Only values change; keys and their order are preserved.
        data.update(scaled)
        return data
    return scaled

# ---------------------------------------------------------------------------
# 3)  SINGLE-DATAFRAME  ▸  EXPANDING
# ---------------------------------------------------------------------------
def scale_df_expanding(
    df: pd.DataFrame,
    *,
    burnin_steps: int = 1_000,
    remove_na: bool = True,
    target: Sequence[str] | str | None = None,
) -> pd.DataFrame:
    """
    Standardise the numeric, non-target columns of `df` with their cumulative
    (expanding) mean/std; `target` column(s) and non-numeric columns are left
    unscaled.

    The std needs at least two observations and gets 1e-8 added.  Statistics
    for the first `burnin_steps` rows are discarded and back-filled from
    later rows.  All columns are returned as float in the original order;
    with `remove_na`, rows still containing NaN are dropped.
    """
    frame = df.copy()
    to_scale, passthrough, targets, column_order = _prep_cols(frame, target)

    if to_scale:
        values = frame[to_scale]
        centre = values.expanding(min_periods=1).mean()
        spread = values.expanding(min_periods=2).std() + 1e-8
        centre.iloc[:burnin_steps, :] = np.nan
        spread.iloc[:burnin_steps, :] = np.nan
        scaled = (values - centre.bfill()) / spread.bfill()
    else:
        scaled = pd.DataFrame(index=frame.index)

    combined = pd.concat([scaled, frame[passthrough], frame[targets]], axis=1)
    combined = combined[column_order].astype(float)
    if remove_na:
        combined = combined.dropna()
    return combined

# ---------------------------------------------------------------------------
# 4)  CANONICAL DICT  ▸  EXPANDING
# ---------------------------------------------------------------------------
def scale_canonical_expanding(
    data: CanonicalData,
    *,
    burnin_steps: int = 1_000,
    remove_na: bool = True,
    target: Sequence[str] | str | None = None,
    inplace: bool = False,
) -> CanonicalData:
    """
    Apply `scale_df_expanding` to every DataFrame in a canonical dict.

    `scale_df_expanding` copies its input before working on it, so the frames
    are passed straight through; no extra defensive copy is made here.

    Parameters
    ----------
    burnin_steps, remove_na, target : forwarded to `scale_df_expanding`
    inplace : if True, replace the entries of `data` with the scaled frames
              and return `data`; if False, return a new dict and leave `data`
              untouched.  Input frames are never modified either way.

    Returns
    -------
    CanonicalData
    """
    scaled = {
        r: scale_df_expanding(
            df,
            burnin_steps=burnin_steps,
            remove_na=remove_na,
            target=target,
        )
        for r, df in data.items()
    }
    if inplace:
        # Only values change; keys and their order are preserved.
        data.update(scaled)
        return data
    return scaled


### Plotting function 

In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from statsmodels.nonparametric.smoothers_lowess import lowess


def animated_lowess_trend(
    df,
    x_col: str,
    y_col: str,
    period: str = "W",          # "D", "W", "M", "A" (year) … any Pandas offset alias
    lowess_frac: float = 0.3,   # smoothing span for LOWESS (0-1)
    nbins_hist: int = 40,       # #bins for the marginal histogram
    show_scatter: bool = True,  # plot raw points?
    scatter_opacity: float = 0.2,
    hist_height: float = 0.22,  # fraction of fig height reserved for histogram
):
    """
    Show an animated LOWESS trend of ``y_col`` vs ``x_col``, one animation
    frame per time period, with a histogram of ``x_col`` underneath.

    * ``df`` must have a DatetimeIndex; rows are grouped by
      ``index.to_period(period)``.
    * ``x_col`` and ``y_col`` must be numeric.

    Axis ranges are fixed to each column's global min/max so they do not jump
    between frames.  The figure is displayed via ``fig.show()``; nothing is
    returned.

    Raises
    ------
    ValueError : if ``df`` has no rows (there would be no frame to show).
    """
    if df.empty:
        raise ValueError("`df` has no rows to plot.")

    # ------------------------------------------------------------------ #
    # Prep: label rows by time period
    # ------------------------------------------------------------------ #
    df2 = df.copy()
    df2["period"] = df2.index.to_series().dt.to_period(period).astype(str)
    periods = sorted(df2["period"].unique())

    # Axis ranges (so they stay fixed across frames)
    x_range = (df2[x_col].min(), df2[x_col].max())
    y_range = (df2[y_col].min(), df2[y_col].max())

    # ------------------------------------------------------------------ #
    # Build frames
    # ------------------------------------------------------------------ #
    frames = []
    for p in periods:
        frame_df = df2[df2["period"] == p]

        # LOWESS smoothing – returns sorted (x, ŷ)
        smooth_xy = lowess(
            frame_df[y_col].values,
            frame_df[x_col].values,
            frac=lowess_frac,
            return_sorted=True,
        )
        x_smooth, y_smooth = smooth_xy[:, 0], smooth_xy[:, 1]

        # -- trend line
        trend_trace = go.Scatter(
            x=x_smooth,
            y=y_smooth,
            mode="lines",
            line=dict(width=2),
            name="LOWESS trend",
            showlegend=(p == periods[0]),
        )

        # -- optional raw points (low opacity)
        scatter_trace = go.Scatter(
            x=frame_df[x_col],
            y=frame_df[y_col],
            mode="markers",
            marker=dict(size=4, opacity=scatter_opacity),
            name="points",
            showlegend=False,
            visible=show_scatter,
        )

        # -- slim histogram of x
        hist_trace = go.Histogram(
            x=frame_df[x_col],
            nbinsx=nbins_hist,
            marker=dict(opacity=0.6),
            showlegend=False,
        )

        frames.append(go.Frame(data=[trend_trace, scatter_trace, hist_trace], name=p))

    # ------------------------------------------------------------------ #
    # Initial traces (first period)
    # ------------------------------------------------------------------ #
    init_trend, init_scatter, init_hist = frames[0].data

    # ------------------------------------------------------------------ #
    # Slider
    # ------------------------------------------------------------------ #
    slider_steps = [
        dict(
            method="animate",
            args=[[p], {"frame": {"duration": 450, "redraw": False}, "mode": "immediate"}],
            label=p,
        )
        for p in periods
    ]

    # ------------------------------------------------------------------ #
    # Figure & subplots (2 rows: main + histogram)
    # ------------------------------------------------------------------ #
    fig = make_subplots(
        rows=2,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.03,
        row_heights=[1 - hist_height, hist_height],
    )

    # Add initial traces
    fig.add_trace(init_trend, row=1, col=1)
    fig.add_trace(init_scatter, row=1, col=1)
    fig.add_trace(init_hist, row=2, col=1)

    # Axes labels / ranges
    fig.update_yaxes(title_text=y_col, range=y_range, row=1, col=1)
    fig.update_yaxes(showticklabels=False, row=2, col=1)
    fig.update_xaxes(title_text=x_col, range=x_range, row=2, col=1)

    # Layout with slider & play/pause buttons
    fig.update_layout(
        height=520,
        title=f"{y_col} vs {x_col} — LOWESS Trend, animated by {period}",
        sliders=[
            dict(
                active=0,
                pad={"t": 55},
                steps=slider_steps,
                currentvalue={"prefix": "Period: "},
            )
        ],
        updatemenus=[
            {
                "type": "buttons",
                "buttons": [
                    {
                        "label": "Play",
                        "method": "animate",
                        "args": [
                            None,
                            {"frame": {"duration": 450, "redraw": False}, "fromcurrent": True},
                        ],
                    },
                    {
                        "label": "Pause",
                        "method": "animate",
                        "args": [[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}],
                    },
                ],
                "direction": "left",
                "pad": {"r": 10, "t": 10},
                "showactive": False,
                "x": 0.1,
                "y": 1.17,
                "xanchor": "right",
            }
        ],
        bargap=0.07,
    )

    # Frames are a top-level Figure property, not part of the Layout, so they
    # must be attached to the figure rather than passed to `update_layout`.
    fig.frames = frames

    fig.show()


In [None]:
# df = your DataFrame; assume its index is a DatetimeIndex
# NOTE(review): `df` is not defined in this cell, and "feature_x"/"target_y"
# appear to be placeholder column names — replace them before running.
animated_lowess_trend(
    df,
    x_col="feature_x",
    y_col="target_y",
    period="M",          # switch to "D", "W", "A", etc. as you like
    lowess_frac=0.25,    # narrower span → wigglier curve
    show_scatter=True,   # set False if you only want the smooth line
)

### Plotting function v2

In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from statsmodels.nonparametric.smoothers_lowess import lowess


def _lowess_curve(x: np.ndarray, y: np.ndarray, frac: float) -> tuple:
    """
    Fit a LOWESS curve of ``y`` on ``x`` and return it as ``(xs, ys)`` sorted by x.

    Returns two empty arrays when fewer than two points are supplied: a local
    regression over a single point carries no trend information, and feeding
    such tiny samples to ``lowess`` is fragile.
    """
    if len(x) < 2:
        empty = np.empty(0)
        return empty, empty
    # With return_sorted=True, lowess returns an (n, 2) array: column 0 holds
    # the sorted x values and column 1 holds the fitted y values.
    fitted = lowess(y, x, frac=frac, return_sorted=True)
    return fitted[:, 0], fitted[:, 1]


def lowess_slider_plot(
    df: pd.DataFrame,
    x_col: str,
    y_col: str,
    period: str = "M",           # pandas period alias: "D", "W", "M", "Q", "Y", …
    lowess_frac: float = 0.25,   # small → wiggly/local fit, 1 → very smooth
    show_scatter: bool = True,
    scatter_opacity: float = 0.25,
    nbins_hist: int = 40,
    hist_height: float = 0.20,   # fraction of total height
    marker_colour: str = "#636EFA",  # Plotly’s default blue
    line_colour: str = "#EF553B",    # Plotly’s default red-orange
):
    """
    Animated LOWESS trend of ``y_col`` vs ``x_col`` (one frame per period) above
    a static histogram of ``x_col``.

    Parameters
    ----------
    df : DataFrame indexed by a ``DatetimeIndex``.
    x_col, y_col : names of numeric columns in ``df``.
    period : pandas period alias used to bucket the index into animation frames.
    lowess_frac : fraction of each period's points used for every local fit
        (passed to statsmodels ``lowess`` as ``frac``).
    show_scatter : whether the raw points are visible beneath the trend line.
    scatter_opacity : marker opacity of the raw points.
    nbins_hist : number of bins of the ``x_col`` histogram.
    hist_height : fraction of the figure height given to the histogram row.
    marker_colour : colour of the points and histogram bars.
    line_colour : colour of the LOWESS trend line.

    Raises
    ------
    TypeError
        If ``df`` is not indexed by a ``DatetimeIndex``.
    ValueError
        If no row has both ``x_col`` and ``y_col`` present, or no row has a
        valid (non-NaT) timestamp.

    Notes
    -----
    Rows with a missing ``x_col`` or ``y_col`` are ignored. Periods with fewer
    than two complete rows get no trend line. The figure is displayed with
    ``fig.show()``; nothing is returned.
    """
    if not isinstance(df.index, pd.DatetimeIndex):
        raise TypeError("lowess_slider_plot: `df` must be indexed by a DatetimeIndex")
    if df.dropna(subset=[x_col, y_col]).empty:
        raise ValueError(
            f"lowess_slider_plot: no rows with both {x_col!r} and {y_col!r} present"
        )

    # --- global axis ranges (stop auto-zoom jumpiness) -------------------
    # Series.min/max skip NaN, so missing values do not poison the ranges.
    x_min, x_max = df[x_col].min(), df[x_col].max()
    y_min, y_max = df[y_col].min(), df[y_col].max()

    # --- static histogram of x, computed once ----------------------------
    # NaNs must be dropped first: np.histogram cannot auto-detect a finite
    # range for data containing NaN and raises ValueError.
    hist_counts, hist_edges = np.histogram(df[x_col].dropna(), bins=nbins_hist)
    hist_centres = (hist_edges[:-1] + hist_edges[1:]) / 2
    bin_width = hist_edges[1] - hist_edges[0]

    # --- animation frames: one per period --------------------------------
    # Group on the PeriodIndex itself rather than on a helper column of string
    # labels: groupby sorts Period keys chronologically, and no user column
    # (e.g. one already called "period") can be overwritten.
    period_index = df.index.to_period(period)
    frames = []
    periods = []
    for i, (per, group) in enumerate(df.groupby(period_index)):
        label = str(per)
        pts = group.dropna(subset=[x_col, y_col])

        sm_x, sm_y = _lowess_curve(
            pts[x_col].to_numpy(dtype=float),
            pts[y_col].to_numpy(dtype=float),
            lowess_frac,
        )

        trend = go.Scatter(
            x=sm_x, y=sm_y,
            mode="lines",
            line=dict(width=2, color=line_colour),
            name="LOWESS trend",
            showlegend=(i == 0),  # one legend entry, from the first frame only
        )

        points = go.Scatter(
            x=pts[x_col], y=pts[y_col],
            mode="markers",
            marker=dict(size=5, color=marker_colour, opacity=scatter_opacity),
            name="points",
            showlegend=False,
            visible=show_scatter,
        )

        # Frames only update traces 0 (trend) and 1 (points); the histogram
        # (trace 2) is not part of any frame, so it stays fixed throughout.
        frames.append(go.Frame(data=[trend, points], name=label, traces=[0, 1]))
        periods.append(label)

    if not frames:
        # groupby drops NaT keys, so an all-NaT index leaves nothing to animate.
        raise ValueError("lowess_slider_plot: no rows with a valid timestamp")

    # --- figure layout & subplots ----------------------------------------
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        row_heights=[1 - hist_height, hist_height],
        vertical_spacing=0.03,
    )

    # add initial traces (frame 0) — these become traces 0 and 1
    init_trend, init_points = frames[0].data
    fig.add_trace(init_trend,  row=1, col=1)
    fig.add_trace(init_points, row=1, col=1)

    # one static bar trace for the histogram (trace 2)
    fig.add_trace(
        go.Bar(
            x=hist_centres,
            y=hist_counts,
            width=bin_width * 0.9,
            marker=dict(color=marker_colour, opacity=0.65),
            name=f"{x_col} distribution",
            showlegend=False,
        ),
        row=2, col=1,
    )

    # axis styling
    fig.update_xaxes(title=x_col, range=[x_min, x_max], row=2, col=1)
    fig.update_yaxes(title=y_col, range=[y_min, y_max], zeroline=True,
                     zerolinecolor="black", zerolinewidth=1, row=1, col=1)
    fig.update_yaxes(showticklabels=False, row=2, col=1)

    # background, font, grid
    fig.update_layout(
        template="simple_white",
        paper_bgcolor="white",
        plot_bgcolor="white",
        font=dict(color="black"),
        height=550,
        title={
            "text": f"{y_col} vs {x_col}  |  LOWESS trend  |  period = {period}",
            "x": 0.5, "xanchor": "center",
        },
        bargap=0.05,
    )

    # slider: one step per frame, jumping straight to that frame by name
    slider_steps = [
        dict(method="animate",
             args=[[p], {"frame": {"duration": 400, "redraw": False},
                         "mode": "immediate"}],
             label=p)
        for p in periods
    ]
    fig.update_layout(
        sliders=[dict(
            active=0,
            currentvalue={"prefix": "Period: "},
            pad={"t": 50},
            steps=slider_steps,
        )],
        updatemenus=[{
            "type": "buttons",
            "buttons": [
                {"label": "Play", "method": "animate",
                 "args": [None, {"frame": {"duration": 400, "redraw": False},
                                 "fromcurrent": True}]},
                # animate([None]) with zero duration is Plotly's idiom for "pause"
                {"label": "Pause", "method": "animate",
                 "args": [[None], {"frame": {"duration": 0, "redraw": False},
                                   "mode": "immediate"}]},
            ],
            "showactive": False,
            "x": 0.1, "y": 1.18, "xanchor": "right",
            "direction": "left",
            "pad": {"r": 10, "t": 10},
        }],
    )

    # attach frames & show
    fig.frames = frames
    fig.show()
