## Trading costs 

In [None]:
"""
Trading-cost / marginal-price & elasticity helpers
==================================================

Public API
----------
trading_cost(slice_df, vol_mwh, side)            -> total_¥ , clearing_price
marginal_price(slice_df, extra_vol_mwh, side)    -> incr_¥ , incr_¥/kWh , Δprice
trading_cost_series(container, vol_mwh, side)    -> DataFrame per timeslot
marginal_price_series(container, extra_vol_mwh, side) -> DataFrame per timeslot
elasticity(slice_df, side)                       -> dV/dP Series or DataFrame
elasticity_panel(container, side)                -> time × price grid(s)
"""
from __future__ import annotations
import numpy as np
import pandas as pd
from typing import Tuple, Literal, Union

from jp_da_imb.trading_costs.bid_curve_struc import BidCurve, MultiBidCurve
from jp_da_imb.trading_costs.bid_curve_stats import (
    clearing_price,
    clearing_demand,
    _clearing_price_volume,      # internal helper from stats module
    _iter_slices,                # internal helper from stats module
)

# ---------------------------------------------------------------------
# 1)  total cost to buy / sell 'vol_mwh'  (uniform-price auction)
# ---------------------------------------------------------------------
def trading_cost(
    slice_df: pd.DataFrame,
    vol_mwh: float,
    *,
    side: Literal["buy", "sell"] = "buy",
) -> Tuple[float, float]:
    """
    Price a tranche of ``vol_mwh`` against a single curve slice.

    Parameters
    ----------
    slice_df : price-indexed frame with ``supply_cum`` / ``demand_cum`` columns.
    vol_mwh  : tranche size; must be strictly positive.
    side     : ``"buy"`` reads the supply curve, ``"sell"`` the demand curve.

    Returns
    -------
    (total, price) where *price* is the curve price interpolated at cumulative
    volume ``vol_mwh`` and ``total = vol_mwh * price * 1_000``.

    Raises
    ------
    ValueError for a non-positive volume or an unknown side.
    """
    if vol_mwh <= 0:
        raise ValueError("vol_mwh must be positive")

    grid = slice_df.index.to_numpy(dtype=float)

    if side not in ("buy", "sell"):
        raise ValueError("side must be 'buy' or 'sell'")

    if side == "buy":
        volumes = slice_df["supply_cum"].to_numpy(dtype=float)
        prices = grid
    else:
        # Demand is read from the high-price end, so both arrays are flipped
        # to give np.interp x-coordinates in its expected order.
        volumes = slice_df["demand_cum"].to_numpy(dtype=float)[::-1]
        prices = grid[::-1]

    # np.interp clamps to the end prices when vol_mwh lies outside the curve.
    level = np.interp(vol_mwh, volumes, prices)

    # ¥/kWh × MWh × 1 000 → ¥
    return float(vol_mwh * level * 1_000), float(level)


# ---------------------------------------------------------------------
# 2)  marginal price impact of an *extra* volume tranche
# ---------------------------------------------------------------------
def marginal_price(
    slice_df: pd.DataFrame,
    extra_vol_mwh: float = 0.5,
    *,
    side: Literal["buy", "sell"] = "buy",
) -> Tuple[float, float, float]:
    """
    Incremental cost ¥, marginal price ¥/kWh and Δprice ¥/kWh beyond clearing.

    Adds *extra_vol_mwh* to the demand (buy) or supply (sell) cumulative
    curve, recomputes clearing, and compares it with the original clearing.

    Parameters
    ----------
    slice_df      : price-indexed frame with ``supply_cum`` / ``demand_cum``.
    extra_vol_mwh : size of the extra tranche; must be strictly positive.
    side          : ``"buy"`` shifts demand, ``"sell"`` shifts supply.

    Returns
    -------
    (incremental_cost, marginal_price, abs_delta_price)
        * incremental_cost – change in ``cleared_volume × clearing_price × 1_000``
          between the shifted and the original slice
        * marginal_price   – ``incremental_cost / (extra_vol_mwh × 1_000)``
        * abs_delta_price  – ``|new clearing price − old clearing price|``

    Raises
    ------
    ValueError for a non-positive volume or an unknown side.
    """
    if extra_vol_mwh <= 0:
        raise ValueError("extra_vol_mwh must be positive")

    # Reject unknown sides up front; previously they fell through both
    # branches and produced an all-zero result without any warning.
    if side == "buy":
        shifted_col = "demand_cum"
    elif side == "sell":
        shifted_col = "supply_cum"
    else:
        raise ValueError("side must be 'buy' or 'sell'")

    # original clearing
    cv0 = clearing_demand(slice_df)
    cp0 = clearing_price(slice_df)

    # shift the relevant curve on a copy so the caller's frame is untouched
    adj = slice_df.copy()
    adj[shifted_col] += extra_vol_mwh

    # new clearing
    cv1 = clearing_demand(adj)
    cp1 = clearing_price(adj)

    incr_cost = cv1 * cp1 * 1_000 - cv0 * cp0 * 1_000
    incr_price = incr_cost / (extra_vol_mwh * 1_000)

    return float(incr_cost), float(incr_price), float(abs(cp1 - cp0))


# ---------------------------------------------------------------------
# 3)  vectorised helpers  (work with BidCurve or MultiBidCurve)
# ---------------------------------------------------------------------
def trading_cost_series(
    container: Union[BidCurve, MultiBidCurve],
    vol_mwh: float,
    *,
    side: Literal["buy", "sell"] = "buy",
) -> pd.DataFrame:
    """
    Evaluate ``trading_cost`` on every slice of *container*.

    Returns a DataFrame with columns ``total_cost`` and ``clearing_price``.
    Its index holds the slice labels and is named ``"timestamp"`` for a
    MultiBidCurve, ``"time_code"`` otherwise.
    """
    slot_labels = []
    results = []
    for slot, curve in _iter_slices(container):
        results.append(trading_cost(curve, vol_mwh, side=side))
        slot_labels.append(slot)

    axis_name = "time_code"
    if isinstance(container, MultiBidCurve):
        axis_name = "timestamp"

    columns = {
        "total_cost": [total for total, _ in results],
        "clearing_price": [price for _, price in results],
    }
    return pd.DataFrame(columns, index=pd.Index(slot_labels, name=axis_name))


def marginal_price_series(
    container: Union[BidCurve, MultiBidCurve],
    extra_vol_mwh: float = 0.5,
    *,
    side: Literal["buy", "sell"] = "buy",
) -> pd.DataFrame:
    """
    Evaluate ``marginal_price`` on every slice of *container*.

    Returns a DataFrame with columns ``incremental_cost`` and
    ``marginal_price``; the third value returned by ``marginal_price``
    (the absolute price change) is discarded. The index is named
    ``"timestamp"`` for a MultiBidCurve, ``"time_code"`` otherwise.
    """
    slot_labels = []
    cost_column = []
    price_column = []
    for slot, curve in _iter_slices(container):
        cost, unit_price, _ = marginal_price(
            curve, extra_vol_mwh=extra_vol_mwh, side=side
        )
        slot_labels.append(slot)
        cost_column.append(cost)
        price_column.append(unit_price)

    axis_name = "time_code"
    if isinstance(container, MultiBidCurve):
        axis_name = "timestamp"

    table = {"incremental_cost": cost_column, "marginal_price": price_column}
    return pd.DataFrame(table, index=pd.Index(slot_labels, name=axis_name))



In [None]:
"""
Plot helpers for supply–demand curves and trading-cost metrics.
"""
from __future__ import annotations
import matplotlib.pyplot as plt
import pandas as pd
from typing import Literal, Union

from jp_da_imb.trading_costs.bid_curve_struc import BidCurve, MultiBidCurve
from jp_da_imb.trading_costs.trading_cost_estimation import (
    trading_cost_series, marginal_price_series
)
from jp_da_imb.trading_costs.bid_curve_stats import timeslot_summary

# ────────────────────────────────────────────────────────────────────
def _get_slice(container, ts):
    """
    Fetch one curve slice from either container type.

    * MultiBidCurve – *ts* is forwarded to ``container[ts]``.
    * BidCurve      – *ts* must be an integer time code and is forwarded to
      ``container.slice_time``.

    Raises
    ------
    TypeError if *container* is neither type, or if a BidCurve is addressed
    with a non-integer *ts*.
    """
    # Local import: this cell's import header does not bring in `numbers`.
    import numbers

    if isinstance(container, MultiBidCurve):
        return container[ts]
    if isinstance(container, BidCurve):
        # numbers.Integral also covers NumPy integer scalars (e.g. values
        # read back from a pandas index), which a plain `int` check rejects.
        if isinstance(ts, numbers.Integral):
            return container.slice_time(int(ts))
        raise TypeError("BidCurve expects time_code 1–48 (int)")
    raise TypeError("container must be BidCurve or MultiBidCurve")


# -------------------------------------------------------------------
# 1) supply & demand curve for a single timestamp
# -------------------------------------------------------------------
def plot_supply_demand(
    container: Union[BidCurve, MultiBidCurve],
    ts: Union[str, pd.Timestamp, int],
    *,
    ylim: float | None = None,
    xlow: float | None = None,
    xhigh: float | None = None,
):
    """
    Step-plot the cumulative supply and demand curves of one slice.

    Parameters
    ----------
    container : BidCurve or MultiBidCurve holding the slice.
    ts        : slice selector, resolved through ``_get_slice``.
    ylim      : optional upper y-limit; the lower limit is then fixed at 0.
    xlow      : optional lower x-limit (0 is a valid value).
    xhigh     : optional upper x-limit.

    Returns
    -------
    The matplotlib Figure that was created.
    """
    curve = _get_slice(container, ts)
    px = curve.index.to_numpy(dtype=float)
    sup = curve["supply_cum"].to_numpy(dtype=float)
    dem = curve["demand_cum"].to_numpy(dtype=float)

    fig, ax = plt.subplots()
    ax.step(px, sup, where="post", label="Supply (cum)", linewidth=1.2)
    ax.step(px, dem, where="post", label="Demand (cum)", linewidth=1.2)
    ax.set_xlabel("Price [¥/kWh]")
    ax.set_ylabel("Cumulative volume [MWh]")
    ax.set_title(f"JEPX supply & demand curves — {ts}")

    # A top of 0 would collapse the axis onto its bottom of 0, so falsy
    # values (None and 0) are deliberately skipped here.
    if ylim:
        ax.set_ylim(bottom=0, top=ylim)

    # Compare against None explicitly: a truthiness test would ignore a
    # legitimate lower bound of 0 when no upper bound is given.
    if xlow is not None or xhigh is not None:
        ax.set_xlim(xlow, xhigh)

    ax.legend()
    ax.grid(True, which="both", alpha=0.3)
    return fig


# -------------------------------------------------------------------
# 2) trading-cost series  (uniform clearing price)
# -------------------------------------------------------------------
def plot_trading_cost_series(
    container: Union[BidCurve, MultiBidCurve],
    vol_mwh: float,
    *,
    side: Literal["buy", "sell"] = "buy",
    metric: Literal["total_cost", "clearing_price"] = "total_cost",
):
    """
    Line-plot one column of ``trading_cost_series`` across all slices.

    *metric* selects the column to draw. A name that is not a column of the
    computed table raises ValueError. Returns the created Figure.
    """
    table = trading_cost_series(container, vol_mwh, side=side)
    if metric not in table.columns:
        raise ValueError(f"metric must be one of {list(table.columns)}")

    fig, ax = plt.subplots()
    ax.plot(table.index, table[metric], linewidth=1.2)

    if isinstance(container, MultiBidCurve):
        x_label = "Time"
    else:
        x_label = "Time code (1-48)"
    ax.set_xlabel(x_label)

    if metric == "total_cost":
        y_label = "Total cost [¥]"
    else:
        y_label = "Price [¥/kWh]"
    ax.set_ylabel(y_label)

    pretty_metric = metric.replace('_',' ').title()
    ax.set_title(f"{pretty_metric} — {side.upper()} {vol_mwh} MWh each slot")
    ax.grid(True, alpha=0.3)
    return fig


# -------------------------------------------------------------------
# 3) any metric from timeslot_summary()
# -------------------------------------------------------------------
def plot_metric_series(
    container: Union[BidCurve, MultiBidCurve],
    metric: Literal[
        "clearing_price", "clearing_volume",
        "vwap", "price_min", "price_max", "imbalance_integral"
    ] = "clearing_price",
    *,
    vwap_side: str = "supply",
):
    """
    Line-plot a single column of ``timeslot_summary`` across all slices.

    *vwap_side* is passed straight to ``timeslot_summary``. A *metric* that
    is not a column of the summary raises ValueError. Returns the Figure.
    """
    table = timeslot_summary(container, vwap_side=vwap_side)
    if metric not in table.columns:
        raise ValueError(f"metric must be one of {list(table.columns)}")

    fig, ax = plt.subplots()
    ax.plot(table.index, table[metric], linewidth=1.2)

    if isinstance(container, MultiBidCurve):
        x_label = "Time"
    else:
        x_label = "Time code (1-48)"
    ax.set_xlabel(x_label)

    pretty_metric = metric.replace("_", " ").title()
    ax.set_ylabel(pretty_metric)
    ax.set_title(f"{pretty_metric} per timeslot")
    ax.grid(True, alpha=0.3)
    return fig


In [None]:
# ------------------------------------------------------------
# helper – iterate over every 30-min slice in a container
# ------------------------------------------------------------
def _iter_slices(container: Union[BidCurve, MultiBidCurve]):
    """
    Generate ``(label, slice_df)`` for each slice held by *container*.

    * BidCurve      – labels are the time codes 1..48, slices come from
                      ``slice_time``.
    * MultiBidCurve – labels are whatever iterating the container yields,
                      slices come from ``container[label]``.

    Any other type raises TypeError. Because this is a generator, the error
    surfaces on first iteration rather than at call time.
    """
    if isinstance(container, BidCurve):
        yield from (
            (time_code, container.slice_time(time_code))
            for time_code in range(1, 49)
        )
        return

    if isinstance(container, MultiBidCurve):
        yield from ((stamp, container[stamp]) for stamp in container)
        return

    raise TypeError("Expected BidCurve or MultiBidCurve")


# ------------------------------------------------------------
# clearing-series  (price & volume for **all** slices)
# ------------------------------------------------------------
def clearing_series(
    container: Union[BidCurve, MultiBidCurve],
    *,
    return_dataframe: bool = True
) -> Union[pd.DataFrame, Tuple[pd.Series, pd.Series]]:
    """
    Clearing price and volume for every slice in *container*.

    Returns
    -------
    * ``return_dataframe=True`` (default): DataFrame with columns
      ``clearing_price`` and ``clearing_volume``, indexed by slice label.
    * ``return_dataframe=False``: the tuple ``(price_series, volume_series)``.
    """
    # NOTE(review): the trading-cost cell imports `_clearing_price_volume`
    # (leading underscore) from the stats module, while this function calls
    # `clearing_price_volume` — confirm both names exist.
    records = []
    for slot, piece in _iter_slices(container):
        price, volume = clearing_price_volume(piece)
        records.append((slot, price, volume))

    slot_labels = [rec[0] for rec in records]
    price_s = pd.Series(
        [rec[1] for rec in records], index=slot_labels, name="clearing_price"
    )
    vol_s = pd.Series(
        [rec[2] for rec in records], index=slot_labels, name="clearing_volume"
    )

    if not return_dataframe:
        return price_s, vol_s

    return pd.concat([price_s, vol_s], axis=1)


# ------------------------------------------------------------
# residual volume  (supply or demand beyond a price threshold)
# ------------------------------------------------------------
def residual_volume(
    slice_df: pd.DataFrame,
    price_threshold: float,
    side: str = "supply"
) -> float:
    """
    Volume left on one side of the book relative to *price_threshold*.

    * ``side="supply"`` – last cumulative supply value minus the cumulative
      supply interpolated at the threshold.
    * ``side="demand"`` – cumulative demand interpolated at the threshold.

    Raises ValueError for any other *side*.
    """
    grid = slice_df.index.to_numpy(dtype=float)

    if side not in ("supply", "demand"):
        raise ValueError("side must be 'supply' or 'demand'")

    curve = slice_df[f"{side}_cum"].to_numpy(dtype=float)

    if side == "demand":
        return float(np.interp(price_threshold, grid, curve))

    # supply: everything offered in total, less what is offered up to the
    # threshold price
    ceiling = curve[-1]
    return float(ceiling - np.interp(price_threshold, grid, curve))


# ------------------------------------------------------------
# imbalance  (supply minus demand; area if integrated=True)
# ------------------------------------------------------------
def imbalance(
    slice_df: pd.DataFrame,
    *,
    integrated: bool = False
):
    """
    Supply-minus-demand difference for each price bin.

    Parameters
    ----------
    slice_df   : price-indexed frame with ``supply_cum`` / ``demand_cum``.
    integrated : when True, integrate the difference over the price index
                 with the trapezoidal rule and return a scalar.

    Returns
    -------
    • pd.Series if integrated=False
    • float     if integrated=True
    """
    diff = slice_df["supply_cum"] - slice_df["demand_cum"]

    if not integrated:
        return diff

    price = slice_df.index.to_numpy(dtype=float)

    # `np.trapz` is deprecated since NumPy 2.0 in favour of `np.trapezoid`;
    # prefer the new name and fall back on older NumPy releases.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return float(integrate(diff, price))


# ==================================================================
# BidCurve   – one day (48 rows)
# ==================================================================
@dataclass
class BidCurve:
    """
    One day of bid curves for a single region (one row per 30-min slot).

    ``supply`` and ``demand`` hold one row per slot and one column per price
    bin; ``slice_time`` addresses rows by position, so row 0 is time code 1.
    """
    region: str
    date:   pd.Timestamp
    bins:   np.ndarray
    supply: pd.DataFrame          # 48 × N_bins
    demand: pd.DataFrame
    df_raw: pd.DataFrame
    # memoised result of to_long(); excluded from __init__ and repr
    _long_cache: pd.DataFrame | None = field(default=None, init=False, repr=False)

    def slice_time(self, time_code: int) -> pd.DataFrame:
        """
        Return the (price_bin × 2) DataFrame for one 30-min slot.

        Columns are ``supply_cum`` and ``demand_cum``; the index is ``bins``.
        Raises ValueError when *time_code* is outside 1..48.
        """
        if not (1 <= time_code <= 48):
            raise ValueError("time_code must be 1-48")
        row = time_code - 1
        return pd.DataFrame(
            {
                "supply_cum":  self.supply.iloc[row].values,
                "demand_cum":  self.demand.iloc[row].values,
            },
            index=self.bins,
        )

    @staticmethod
    def _stack_side(frame: pd.DataFrame, side: str) -> pd.DataFrame:
        """Unpivot one slot × price matrix into (time_code, side, price, cum_vol) rows."""
        values = frame.to_numpy()
        n_slots, n_bins = values.shape
        return pd.DataFrame(
            {
                # Positional time codes (row 0 → 1), the same convention
                # slice_time uses when it indexes with iloc.
                "time_code": np.repeat(np.arange(1, n_slots + 1), n_bins),
                "side": side,
                "price": np.tile(frame.columns.to_numpy(), n_slots),
                "cum_vol": values.ravel(),
            }
        )

    def to_long(self) -> pd.DataFrame:
        """
        Long ("tidy") view: one row per (time_code, side, price).

        Columns: date, region, time_code, side, price, cum_vol. The frame is
        built once and cached; later calls return the same object.
        """
        if self._long_cache is not None:
            return self._long_cache

        # The side label is attached *after* unpivoting. Adding it as a
        # column before stacking would fold the label into the stacked
        # values and leave no "side" column to select below.
        long = pd.concat(
            [
                self._stack_side(self.supply, "supply"),
                self._stack_side(self.demand, "demand"),
            ],
            ignore_index=True,
        )
        long["date"]    = self.date
        long["region"]  = self.region
        self._long_cache = long[["date", "region", "time_code", "side", "price", "cum_vol"]]
        return self._long_cache

    def __repr__(self):
        return f"<BidCurve {self.region} {self.date.date()} (48 × {len(self.bins)} bins)>"


# ==================================================================
# MultiBidCurve   – many days, 30-min timestamps
# ==================================================================
@dataclass
class MultiBidCurve:
    """
    Bid curves for one region across many timestamps.

    ``supply`` and ``demand`` are indexed by timestamp, one column per
    price bin.
    """
    region: str
    bins:   np.ndarray
    supply: pd.DataFrame          # index = timestamp, columns = bins
    demand: pd.DataFrame
    df_raw: pd.DataFrame
    # slot for a memoised long-format frame; not part of __init__ / repr
    _long_cache: pd.DataFrame | None = field(default=None, init=False, repr=False)

    @property
    def dates(self) -> pd.DatetimeIndex:
        """Distinct calendar days (midnight-normalised) present in ``supply``."""
        midnights = self.supply.index.normalize()
        return midnights.unique()

    def slice_time(self, ts: Union[str, pd.Timestamp]) -> pd.DataFrame:
        """Return the price-indexed supply/demand frame for one timestamp."""
        stamp = pd.to_datetime(ts)
        columns = {
            "supply_cum": self.supply.loc[stamp].values,
            "demand_cum": self.demand.loc[stamp].values,
        }
        return pd.DataFrame(columns, index=self.bins)

    def slice_day(self, date: Union[str, pd.Timestamp]) -> BidCurve:
        """
        Extract one calendar day as a BidCurve.

        Raises ValueError unless exactly 48 rows fall on that day.
        """
        day = pd.to_datetime(date).normalize()
        on_day = self.supply.index.normalize() == day
        if on_day.sum() != 48:
            raise ValueError(f"Expected 48 rows for {day.date()}, got {on_day.sum()}")

        def _rows_of_day(frame: pd.DataFrame) -> pd.DataFrame:
            return frame.loc[on_day].reset_index(drop=True)

        supply_day = _rows_of_day(self.supply)
        demand_day = _rows_of_day(self.demand)
        # NOTE(review): the mask has one entry per row of `supply`, so this
        # presumably needs `df_raw` to be row-aligned with `supply` — confirm
        # for containers whose df_raw keeps the raw long rows.
        raw_day = _rows_of_day(self.df_raw)
        return BidCurve(
            region=self.region,
            date=day,
            bins=self.bins,
            supply=supply_day,
            demand=demand_day,
            df_raw=raw_day,
        )


# ------------------------------------------------------------
# timeslot_summary  (per-slot diagnostics table)
# ------------------------------------------------------------
def timeslot_summary(
    container: Union[BidCurve, MultiBidCurve],
    vwap_side: str = "supply"
) -> pd.DataFrame:
    """
    Per-slot diagnostics table.

    Columns:
      clearing_price, clearing_volume, vwap,
      price_min, price_max, imbalance_integral

    ``vwap``, ``price_min`` and ``price_max`` are derived from the per-bin
    volume increments of the ``{vwap_side}_cum`` curve:

    * a rising cumulative curve is differenced from the low-price end
      (total = last value);
    * a falling cumulative curve is differenced from the high-price end
      (total = first value), so its increments are non-negative as well.

    ``vwap`` is NaN when the total volume is not positive; ``price_min`` /
    ``price_max`` are NaN when every increment is zero.

    Raises ValueError if *vwap_side* is not 'supply' or 'demand'.
    """
    if vwap_side not in {"supply", "demand"}:
        raise ValueError("vwap_side must be 'supply' or 'demand'")

    labels, cp, cv, vwap, pmin, pmax, imb = [], [], [], [], [], [], []

    for lbl, sl in _iter_slices(container):
        _cp, _cv = clearing_price_volume(sl)
        cp.append(_cp); cv.append(_cv)

        price   = sl.index.to_numpy(dtype=float)
        cum_vol = sl[f"{vwap_side}_cum"].to_numpy(dtype=float)

        if cum_vol[0] > cum_vol[-1]:
            # Falling curve: bin i carries cum[i] - cum[i+1], with 0 beyond
            # the last bin. Differencing from the left, as done for a rising
            # curve, would give negative increments and a total equal to the
            # curve's smallest value.
            inc_vol = -np.diff(np.concatenate((cum_vol, [0.0])))
            tot_vol = cum_vol[0]
        else:
            inc_vol = np.diff(np.concatenate(([0.0], cum_vol)))
            tot_vol = cum_vol[-1]

        vwap.append((price * inc_vol).sum() / tot_vol if tot_vol > 0 else np.nan)

        # lowest / highest price bin that actually carries volume
        nz = np.flatnonzero(inc_vol)
        pmin.append(price[nz[0]]  if nz.size else np.nan)
        pmax.append(price[nz[-1]] if nz.size else np.nan)

        imb.append(imbalance(sl, integrated=True))
        labels.append(lbl)

    idx_name = "timestamp" if isinstance(container, MultiBidCurve) else "time_code"
    summary = pd.DataFrame(
        {
            "clearing_price":      cp,
            "clearing_volume":     cv,
            "vwap":                vwap,
            "price_min":           pmin,
            "price_max":           pmax,
            "imbalance_integral":  imb,
        },
        index=pd.Index(labels, name=idx_name),
    )
    return summary


In [None]:
# ---------------------------------------------------------------------
# alt_loader.py  ·  extra helpers + row-wise loader
# ---------------------------------------------------------------------
from dataclasses import dataclass, field
from pathlib import Path
from typing import Mapping, Iterable, Dict, Union, Literal

from jp_da_imb.trading_costs.bid_curve_struc import *
import numpy as np
import pandas as pd
import math


# add functionality to load different types of data -------------------
def _read_any(path, **kw) -> pd.DataFrame:
    """
    Load *path* as a Parquet table.

    Extra keyword arguments are forwarded to ``pd.read_parquet``. Parquet is
    the only format handled here; other formats would need their own branch.
    """
    frame = pd.read_parquet(path, **kw)
    return frame


def _make_price_grid(
    df: pd.DataFrame,
    price_step: float,
    pad: float = 0.0,
    price_col: str = "price"
) -> np.ndarray:
    """
    Common ascending grid so that every slice shares the same bins.

    The grid spans ``min(price) - pad`` .. ``max(price) + pad``, snapped
    outward to whole multiples of *price_step*, and is rounded to 6 decimals.

    Parameters
    ----------
    df         : frame holding the price column.
    price_step : grid spacing; must be strictly positive.
    pad        : extra margin added on both sides before snapping.
    price_col  : name of the price column in *df*.

    Raises
    ------
    ValueError if *price_step* is not positive.
    """
    if price_step <= 0:
        raise ValueError("price_step must be positive")

    lo_price = df[price_col].min() - pad
    hi_price = df[price_col].max() + pad

    # Work in integer bin indices. The price/step ratios are rounded to 9
    # decimals first so float jitter (e.g. 0.7 / 0.1 == 6.999999999999999)
    # cannot move floor/ceil to the wrong bin, and the point count comes
    # from integer arithmetic instead of truncating a float quotient.
    lo_idx = math.floor(round(lo_price / price_step, 9))
    hi_idx = math.ceil(round(hi_price / price_step, 9))

    steps = np.arange(lo_idx, hi_idx + 1)
    return np.round(steps * price_step, decimals=6)   # 6-dp safety


# ---------------------------------------------------------------------
def _build_slice(
    rows: pd.DataFrame,
    price_bins: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """
    Turn the order rows of **one timestamp** into two arrays on *price_bins*.

    *rows* needs the columns ``order`` ("sell" / "buy"), ``price`` and
    ``volume``. Returns ``(supply_cum, demand_cum)``, each of length
    ``len(price_bins)``. Each row's volume is written as-is (not added) over
    the bins it covers, so later rows overwrite earlier ones there.
    """
    n_bins = len(price_bins)
    supply_cum = np.zeros(n_bins, dtype=float)
    demand_cum = np.zeros(n_bins, dtype=float)

    # asks in ascending price order, bids in descending price order
    asks = rows[rows["order"] == "sell"].sort_values("price")
    bids = rows[rows["order"] == "buy"].sort_values("price", ascending=False)

    # supply: a row's volume applies from its price bin upward
    for ask in asks.to_dict("records"):
        first_bin = np.searchsorted(price_bins, ask["price"], side="left")
        supply_cum[first_bin:] = ask["volume"]

    # demand: a row's volume applies from its price bin downward (inclusive)
    for bid in bids.to_dict("records"):
        last_bin = np.searchsorted(price_bins, bid["price"], side="left")
        demand_cum[: last_bin + 1] = bid["volume"]

    return supply_cum, demand_cum


# ---------------------------------------------------------------------
def load_curves_rows(
    path: str | Path,
    *,
    area: str,
    price_step: float  = 0.01,
    tz: str           = "Asia/Tokyo",
    vol_col: str      = "volume",
    price_col: str    = "price",
    order_col: str    = "order",
    area_col: str     = "area",
    dt_col: str       = "dt",
    **read_kw,
) -> MultiBidCurve:
    """
    Read the *row-wise* file and return a **MultiBidCurve** for *area*.

    * Every unique timestamp becomes one curve slice.
    * One common price grid (min-max in data, `price_step` increments).

    Parameters
    ----------
    path       : file handed to ``_read_any``.
    area       : value of *area_col* to keep.
    price_step : spacing of the shared price grid.
    tz         : target time zone. tz-aware timestamps are converted to it;
                 tz-naive timestamps are localised to it (i.e. assumed to be
                 wall-clock time in *tz* already — confirm for your data).
    vol_col, price_col, order_col, area_col, dt_col : column names in the file.
    **read_kw  : forwarded to ``_read_any``.

    Raises
    ------
    ValueError if required columns are missing or no row matches *area*.
    """
    df = _read_any(path, **read_kw)

    req = {vol_col, price_col, order_col, area_col, dt_col}
    if not req.issubset(df.columns):
        raise ValueError(f"File missing columns: {req - set(df.columns)}")

    # filter region ----------------------------------------------------
    # Copy the filtered rows so the column assignments below act on an
    # independent frame rather than on a slice of the loaded one.
    df = df[df[area_col] == area].copy()
    if df.empty:
        raise ValueError(f"No rows found for area '{area}'")

    # tidy types -------------------------------------------------------
    df[vol_col]   = df[vol_col].astype(float)
    df[price_col] = df[price_col].astype(float)
    df[order_col] = df[order_col].str.lower().str.strip()

    stamps = pd.to_datetime(df[dt_col])
    if stamps.dt.tz is None:
        # tz_convert raises on tz-naive data; localise instead.
        stamps = stamps.dt.tz_localize(tz)
    else:
        stamps = stamps.dt.tz_convert(tz)
    df[dt_col] = stamps

    # common price grid ------------------------------------------------
    bins = _make_price_grid(df[[price_col]], price_step, price_col=price_col)

    # build each slice -------------------------------------------------
    # Rename once, then let a single sorted groupby hand out each
    # timestamp's rows instead of re-scanning the whole frame per timestamp.
    canon = df[[dt_col, vol_col, price_col, order_col]].rename(
        columns={vol_col: "volume", price_col: "price", order_col: "order"}
    )
    grouped = canon.groupby(dt_col, sort=True)

    ts_index = []
    sup_mat  = np.zeros((grouped.ngroups, len(bins)), dtype=float)
    dem_mat  = np.zeros_like(sup_mat)

    for i, (ts, rows) in enumerate(grouped):
        sup, dem = _build_slice(rows, bins)
        sup_mat[i, :] = sup
        dem_mat[i, :] = dem
        ts_index.append(ts)

    # wrap into DataFrames with proper index/columns -------------------
    sup_df = pd.DataFrame(sup_mat, index=ts_index, columns=bins)
    dem_df = pd.DataFrame(dem_mat, index=ts_index, columns=bins)

    return MultiBidCurve(
        region = area,
        bins   = bins,
        supply = sup_df,
        demand = dem_df,
        df_raw = df,           # keep the filtered long rows
    )


In [None]:
    # ---------------------------------------------------------------
    def to_long(self) -> pd.DataFrame:
        """
        Long ("tidy") view with timestamp granularity.

        Columns: date, region, time_code, timestamp, side, price, cum_vol,
        where ``time_code`` is the 1-based half-hour of the day derived from
        the timestamp. The frame is built once and cached; later calls
        return the same object.
        """
        # Reuse the memoised frame. It was stored on every call before but
        # never read back, so each call rebuilt the whole table.
        if self._long_cache is not None:
            return self._long_cache

        long = pd.concat(
            [_stack(self.supply, side="supply"),
             _stack(self.demand, side="demand")]
        )
        long["region"] = self.region
        long["date"]   = long["timestamp"].dt.normalize()
        long["time_code"] = (
            long["timestamp"].dt.hour * 60 + long["timestamp"].dt.minute
        ) // 30 + 1
        self._long_cache = long[
            ["date", "region", "time_code", "timestamp", "side", "price", "cum_vol"]
        ]
        return self._long_cache

    # ---------------------------------------------------------------
    def __repr__(self):
        """Short summary: region, number of distinct days and number of price bins."""
        day_count = len(self.dates)
        bin_count = len(self.bins)
        return f"<MultiBidCurve {self.region} [{day_count} days, {bin_count} bins]>"

    # so you can iterate the new object -----------------------------
    def __getitem__(self, ts: str | pd.Timestamp):
        """Indexing shortcut: ``obj[ts]`` is ``obj.slice_time(ts)``."""
        curve = self.slice_time(ts)
        return curve

    def __iter__(self):
        """Iterate over the labels of ``supply``'s index, in index order."""
        stamps = self.supply.index
        return iter(stamps)


| Module               | Function / Method                                            | What it does (one-liner)                                               |
| -------------------- | ------------------------------------------------------------ | ---------------------------------------------------------------------- |
| **`jepx_loader.py`** | **`load_curve(path, *, region, date)`**                      | Read a *single-day* 48-row file → `BidCurve`.                          |
|                      | **`load_curves(path, *, region)`**                           | Read a *multi-day* file → `MultiBidCurve` (index = 30-min timestamps). |
|                      | **`BidCurve.slice_time(tc)`**                                | Return supply & demand arrays for time-code 1-48.                      |
|                      | **`BidCurve.to_long()`**                                     | Long “tidy” DataFrame (date, time\_code, side, price, cum\_vol).       |
|                      | **`MultiBidCurve.__getitem__(ts)`**                          | Quick accessor: `panel['2025-07-08 12:00']` → curve slice.             |
|                      | **`MultiBidCurve.slice_day(date)`**                          | Pop one day back out as a `BidCurve`.                                  |
|                      | **`MultiBidCurve.to_long()`**                                | Long DF with timestamp granularity.                                    |
|                      | **`MultiBidCurve.iter_timeslices()`**                        | Generator over every 30-min slice.                                     |
| **`jepx_stats.py`**  | **`clearing_price(slice_df)`**                               | Intersection price (¥/kWh) for one slice.                              |
|                      | **`clearing_demand / clearing_supply(slice_df)`**            | Cleared MWh on either side.                                            |
|                      | **`residual_volume(slice_df, price, side)`**                 | Remaining MWh above a price threshold.                                 |
|                      | **`imbalance(slice_df, integrated=False)`**                  | Supply-minus-demand vector or its integral.                            |
|                      | **`trading_cost(slice_df, vol, side)`**                      | `total ¥` and clearing price (¥/kWh) to buy/sell *vol* in one slice.   |
|                      | **`elasticity(slice_df, side)`**                             | dVolume/dPrice curve (slope) per price bin.                            |
|                      | **`clearing_series(container)`**                             | Time-series of clearing price & volume for **all** slots.              |
|                      | **`trading_cost_series(container, vol, side)`**              | DataFrame of total cost and clearing price per slot.                   |
|                      | **`elasticity_panel(container, side)`**                      | Time × price grid(s) of elasticity.                                    |
|                      | **`timeslot_summary(container)`**                            | Table per slot: clearing P/V, VWAP, min/max price, imbalance.          |
| **`jepx_plots.py`**  | **`plot_supply_demand(container, ts)`**                      | Step plot of supply & demand curves for one timestamp.                 |
|                      | **`plot_trading_cost_series(container, vol, side, metric)`** | Line chart of total cost *or* clearing price over time.                |
|                      | **`plot_metric_series(container, metric)`**                  | Generic plot for any column from `timeslot_summary()`.                 |


### NEW NEW NEW 

In [None]:
# jepx_stats.py  ───────────  point-price cost
def trading_cost_at_price(slice_df, price, *, side="buy"):
    """
    Value the volume available at a fixed *price* on one curve slice.

    The volume is interpolated at *price* on the supply curve
    (``side="buy"``) or on the demand curve (``side="sell"``).

    Returns ``(volume, total, price)`` with
    ``total = price * volume * 1_000`` and *price* converted to float.
    Raises ValueError for any other *side*.
    """
    level = float(price)

    if side not in ("buy", "sell"):
        raise ValueError("side must be 'buy' or 'sell'")

    column = "supply_cum" if side == "buy" else "demand_cum"
    traded = np.interp(level, slice_df.index, slice_df[column])

    # ¥/kWh × MWh → ¥
    return traded, level * traded * 1_000, level


def trading_cost_at_price_series(container, price, *, side="buy",
                                 metric="total_cost", return_dataframe=True):
    """
    Run ``trading_cost_at_price`` on every slice of *container*.

    * ``return_dataframe=True``  – DataFrame with columns ``volume`` and
      ``total_cost``, indexed by slice label (*metric* is not used).
    * ``return_dataframe=False`` – only the Series named by *metric*
      (``"volume"`` or ``"total_cost"``).
    """
    slot_labels = []
    volume_column = []
    total_column = []
    for slot, curve in _iter_slices(container):
        traded, total, _ = trading_cost_at_price(curve, price, side=side)
        slot_labels.append(slot)
        volume_column.append(traded)
        total_column.append(total)

    by_name = {
        "volume": pd.Series(volume_column, index=slot_labels, name="volume"),
        "total_cost": pd.Series(total_column, index=slot_labels, name="total_cost"),
    }

    if return_dataframe:
        return pd.concat([by_name["volume"], by_name["total_cost"]], axis=1)
    return by_name[metric]


In [None]:
##### Marginal price – corrected form
# jepx_stats.py  ───────────────────────────────────────────────────────
import numpy as np
import pandas as pd
from typing import Literal, Tuple

# ────────────────────────────────────────────────────────────────────
# 1) One-slice marginal impact
# ────────────────────────────────────────────────────────────────────
def marginal_price_impact(
    slice_df: pd.DataFrame,
    qty_mwh: float,
    *,
    side: Literal["buy", "sell"] = "buy",
) -> Tuple[float, float, float]:
    """
    Clearing-price shift caused by one extra order of **qty_mwh**.

    The quantity is added to every bin of ``demand_cum`` (``side="buy"``) or
    ``supply_cum`` (``side="sell"``) on a copy of *slice_df*, and the
    clearing price is recomputed.

    Returns
    -------
    (old_cp, new_cp, extra_cost_approx) with

        extra_cost_approx = (new_cp - old_cp) × qty_mwh × 1 000   [¥]

    Raises
    ------
    ValueError if *qty_mwh* is not positive or *side* is unknown.
    """
    if qty_mwh <= 0:
        raise ValueError("qty_mwh must be > 0")

    # clearing price before the extra order
    baseline_cp, _ = _clearing_price_volume(slice_df)

    # shift the relevant curve on a private copy
    shifted = slice_df.copy()
    if side == "buy":
        target = "demand_cum"
    elif side == "sell":
        target = "supply_cum"
    else:
        raise ValueError("side must be 'buy' or 'sell'")
    shifted[target] += qty_mwh

    # clearing price after the extra order
    shifted_cp, _ = _clearing_price_volume(shifted)

    price_move = shifted_cp - baseline_cp            # ¥/kWh
    return baseline_cp, shifted_cp, price_move * qty_mwh * 1_000   # ¥


# ────────────────────────────────────────────────────────────────────
# 2) Vectorised over all timeslots
# ────────────────────────────────────────────────────────────────────
def marginal_price_impact_series(
    container,
    qty_mwh: float,
    *,
    side: Literal["buy", "sell"] = "buy",
    metric: Literal["delta_price", "extra_cost"] = "delta_price",
    return_dataframe: bool = True,
):
    """
    Run ``marginal_price_impact`` on every slice of *container*.

    Parameters
    ----------
    container        : object accepted by ``_iter_slices``.
    qty_mwh          : extra quantity added in each slot.
    side             : ``"buy"`` or ``"sell"``, forwarded unchanged.
    metric           : which Series to return when ``return_dataframe=False``.
    return_dataframe : return both columns as a DataFrame (default) or only
                       the Series selected by *metric*.

    Returns
    -------
    DataFrame with columns ``delta_price`` (¥/kWh) and ``extra_cost`` (¥),
    or the single Series named by *metric*. The index is named
    ``"timestamp"`` when the container has a ``slice_day`` attribute and
    ``"time_code"`` otherwise.
    """
    # The accumulators used to be named with the yen sign (U+00A5), which is
    # not a legal identifier character and made the whole cell a SyntaxError.
    labels, d_price, d_yen = [], [], []

    for lbl, sl in _iter_slices(container):
        cp0, cp1, extra = marginal_price_impact(sl, qty_mwh, side=side)
        labels.append(lbl)
        d_price.append(cp1 - cp0)
        d_yen.append(extra)

    idx_name = "timestamp" if hasattr(container, "slice_day") else "time_code"
    # Name the index once so the Series path and the DataFrame path agree.
    index = pd.Index(labels, name=idx_name)
    d_price_s = pd.Series(d_price, index=index, name="delta_price")   # ¥/kWh
    d_yen_s = pd.Series(d_yen, index=index, name="extra_cost")        # ¥

    if not return_dataframe:
        return {"delta_price": d_price_s, "extra_cost": d_yen_s}[metric]

    return pd.concat([d_price_s, d_yen_s], axis=1).rename_axis(idx_name)


# ────────────────────────────────────────────────────────────────────
# 3) Quick plot
# ────────────────────────────────────────────────────────────────────
def plot_marginal_price_impact(
    container,
    qty_mwh: float,
    *,
    side: Literal["buy", "sell"] = "buy",
    metric: Literal["delta_price", "extra_cost"] = "delta_price",
):
    """
    Line-plot the per-slot result of ``marginal_price_impact_series``.

    *metric* chooses between the clearing-price change (``"delta_price"``)
    and the approximate extra cost (any other accepted value). Returns the
    created Figure.
    """
    series = marginal_price_impact_series(
        container, qty_mwh, side=side, return_dataframe=False, metric=metric
    )

    fig, ax = plt.subplots()
    ax.plot(series.index, series.values, linewidth=1.2)

    if hasattr(container, "slice_day"):
        ax.set_xlabel("Time")
    else:
        ax.set_xlabel("Time code (1-48)")

    showing_price = metric == "delta_price"
    ax.set_ylabel("ΔPrice [¥/kWh]" if showing_price else "Extra cost [¥]")
    headline = "Price impact" if showing_price else "Extra ¥ out-of-pocket"

    ax.set_title(f"{headline} | +{qty_mwh} MWh {side.upper()} per slot")
    ax.grid(True, alpha=0.3)
    return fig

### Combining regions 

In [None]:
from __future__ import annotations
from pathlib import Path
from typing import List, Dict, Tuple
import numpy as np
import pandas as pd

from jp_da_imb.trading_costs.alt_loader import load_curves_rows
from jp_da_imb.trading_costs.bid_curve_struc import MultiBidCurve


def _sum_rows(df_list):
    """
    Add up the items of *df_list* with ``+``, starting from 0.

    Intended for pandas objects that share index/columns; an empty input
    yields the integer 0.
    """
    running = 0
    for item in df_list:
        running = running + item
    return running


def _aggregate_slice(
    panels: List[MultiBidCurve],
    ts,
    price_bins: np.ndarray,
    join_char: str = "-"
) -> List[Tuple[str, pd.DataFrame]]:
    """
    Merge, for one timestamp, the panels that carry the same group-ID.

    Each panel's group-ID is looked up via ``panel.groups.get(ts)``; panels
    without an ID for *ts* are left out. Supply and demand rows of the
    panels in a group are summed, and the group is named by joining the
    sorted region names with *join_char*.

    Returns a list of ``(name, slice_df)``, where ``slice_df`` has columns
    ``supply_cum`` / ``demand_cum`` and index *price_bins*.
    """
    # NOTE(review): the MultiBidCurve definition visible in this file has no
    # `groups` attribute — presumably a mapping timestamp -> group-ID that is
    # attached elsewhere; confirm before relying on this function.
    buckets: Dict[int, List[MultiBidCurve]] = {}
    for panel in panels:
        group_id = panel.groups.get(ts, pd.NA)
        if not pd.isna(group_id):
            buckets.setdefault(int(group_id), []).append(panel)

    merged = []
    for members in buckets.values():
        supply_total = _sum_rows([m.supply.loc[ts] for m in members])
        demand_total = _sum_rows([m.demand.loc[ts] for m in members])

        label = join_char.join(sorted(m.region for m in members))
        columns = {
            "supply_cum": supply_total.values,
            "demand_cum": demand_total.values,
        }
        merged.append((label, pd.DataFrame(columns, index=price_bins)))

    return merged


# ---------------------------------------------------------------------
def group_slices_from_parquet(
    parquet_path: str | Path,
    areas: List[str],
    *,
    price_step: float = 0.01,
    tz: str = "Asia/Tokyo",
) -> Dict[pd.Timestamp, List[Tuple[str, pd.DataFrame]]]:
    """
    Load one MultiBidCurve per area from a single Parquet file and merge
    them **per timestamp** by group-ID.

    Returns
    -------
    dict
        key   = timestamp taken from the union of all panels' supply indexes
        value = list of (combined_name, curve_slice_df) from
                ``_aggregate_slice``; timestamps with an empty list are
                omitted.

    Raises
    ------
    ValueError if the panels do not all share the same price grid.
    """
    # one panel per requested area
    panels: List[MultiBidCurve] = []
    for area in areas:
        panels.append(
            load_curves_rows(
                parquet_path,
                area=area,
                price_step=price_step,
                tz=tz,
            )
        )

    # every panel must use the grid of the first one
    reference_bins = panels[0].bins
    for other in panels[1:]:
        if not np.array_equal(reference_bins, other.bins):
            raise ValueError("Price grids differ across regions; resample first.")

    # collect each timestamp that appears in at least one panel
    stamps = set()
    for panel in panels:
        stamps.update(panel.supply.index)

    by_time: Dict[pd.Timestamp, List[Tuple[str, pd.DataFrame]]] = {}
    for ts in sorted(stamps):
        merged = _aggregate_slice(panels, ts, reference_bins)
        if not merged:
            continue                     # nothing grouped for this slot
        by_time[ts] = merged

    return by_time
