In [11]:
from __future__ import annotations

import dataclasses
from collections.abc import Callable
from copy import deepcopy
from typing import Literal, Self

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from scipy.interpolate import interp1d

from src.drawing_utils import Arrow, Point, Segment
from src.plotting_utils import configure_matplotlib, rm

# Project-wide matplotlib configuration from src.plotting_utils (applied once, before any figure).
configure_matplotlib()

In [12]:
# Grid step for resampling curves, integrating them and searching the welfare optimum.
DX: float = 0.001
# Number of decimals grid quantities are rounded to (after np.arange / diff / cumsum).
DECIMALS: int = 10

# Legend labels for the plotted quantities.
SELLER_COST = "Seller Cost, $C$"
BUYER_UTILITY = "Buyer Utility, $U$"
WELFARE = "Welfare, $W$"
OPTIMUM = "Optimum"


@dataclasses.dataclass
class _Curve:
    """A curve in the (quantity, price) plane defined by a list of points.

    Attributes:
        points: Defining points; ``x`` is the quantity and ``y`` the price.
        name: Legend label; ``SupplyDemand`` also uses it to select a curve via ``mask``.
        integral_name: Legend label for the curve's integral / shaded area.
        color: Matplotlib colour used to draw the curve.
        stepped: If true the curve is piecewise constant and, between two points,
            the price of the later point applies (``interp1d(kind="next")``);
            otherwise consecutive points are joined linearly.
    """

    points: list[Point]
    name: str
    integral_name: str
    color: str
    stepped: bool = True

    @property
    def xs(self) -> np.ndarray:
        """Quantities (x-coordinates) of ``points``."""
        return np.array([p.x for p in self.points])

    @property
    def ys(self) -> np.ndarray:
        """Prices (y-coordinates) of ``points``."""
        return np.array([p.y for p in self.points])

    def to_series(self) -> pd.Series:
        """The curve as a price series indexed by quantity."""
        return pd.Series(self.ys, index=self.xs)

    @property
    def func(self) -> interp1d:
        """Price as a function of quantity; NaN outside ``[min(xs), max(xs)]``.

        Stepped curves use ``kind="next"`` (the price of the next defining point
        applies); non-stepped curves interpolate linearly.
        """
        return interp1d(
            self.xs,
            self.ys,
            kind=("next" if self.stepped else "linear"),
            bounds_error=False,
        )

    @property
    def upsampled(self) -> Self:
        """A deep copy of this curve resampled every ``DX`` over its quantity range."""
        copy = deepcopy(self)
        x_max = self.xs.max()
        quantity_vals = np.arange(self.xs.min(), x_max + DX, DX).round(DECIMALS)
        # With a float step np.arange may emit one sample beyond ``x_max``; ``func``
        # returns NaN there (out of bounds), which would leak into integrals and
        # composite curves, so drop any such overshoot.
        quantity_vals = quantity_vals[quantity_vals <= x_max]
        copy.points = [
            Point(x, y)
            for (x, y) in zip(quantity_vals, copy.func(quantity_vals), strict=True)
        ]
        return copy

    @property
    def integral(self) -> interp1d:
        """Running integral of the curve from its smallest quantity, as a function.

        It is a Riemann sum on the ``DX`` grid of ``upsampled``. Stepped curves
        weight each grid interval by the price at its right end, matching
        ``kind="next"``; other curves use the left end. The result is NaN outside
        the sampled quantity range.
        """
        upsampled = self.upsampled.to_series()
        if self.stepped:
            # Align the price at x_{i+1} with the interval [x_i, x_{i+1}].
            upsampled = upsampled.shift(-1)
        # The last sample has no interval to its right; prepend 0 so the
        # integral starts at zero at the first quantity.
        y_vals = np.array([0.0, *upsampled.iloc[:-1].cumsum().to_numpy()]) * DX
        return interp1d(upsampled.index, y_vals, bounds_error=False)


@dataclasses.dataclass
class SupplyCurve(_Curve):
    """Seller side of the market; its integral is plotted as the seller cost."""

    name: str = dataclasses.field(default="Supply Curve")
    integral_name: str = dataclasses.field(default=rm(SELLER_COST))
    color: str = dataclasses.field(default="#1565C0")  # blue


@dataclasses.dataclass
class DemandCurve(_Curve):
    """Buyer side of the market; its integral is plotted as the buyer utility."""

    name: str = dataclasses.field(default="Demand Curve")
    integral_name: str = dataclasses.field(default=rm(BUYER_UTILITY))
    color: str = dataclasses.field(default="#EF6C00")  # orange


@dataclasses.dataclass
class SupplyDemand:
    supply_curves: list[SupplyCurve]
    demand_curves: list[DemandCurve]
    equilibrium_price: float | None = None

    @property
    def curves(self) -> list[SupplyCurve | DemandCurve]:
        return self.supply_curves + self.demand_curves

    @property
    def upsampled(self) -> SupplyDemand:
        copy = deepcopy(self)
        for i, curve in enumerate(self.supply_curves):
            copy.supply_curves[i] = curve.upsampled
        for i, curve in enumerate(self.demand_curves):
            copy.demand_curves[i] = curve.upsampled
        return copy

    def composite_curve[C: SupplyCurve | DemandCurve](
        self,
        curve_type: type[C],
        mask: str | None = None,
        individual_quantities: bool = False,
    ) -> C:
        curves: list[SupplyCurve | DemandCurve] = {
            SupplyCurve: self.supply_curves,
            DemandCurve: self.demand_curves,
        }[curve_type]
        curve_dfs = []
        zero_quantity_prices = []
        for curve in curves:
            curve_df = pd.DataFrame(
                [
                    {"name": curve.name, "individual_quantity": p.x, "price": p.y}
                    for p in curve.points
                ]
            )
            curve_df["delta_quantity"] = (
                curve_df["individual_quantity"].diff().round(DECIMALS)
            )
            curve_dfs.append(curve_df.iloc[1:])
            zero_quantity_prices.append(curve_df.iloc[0]["price"])
        composite_df = pd.concat(curve_dfs)
        composite_df = composite_df.sort_values(
            by="price", ascending=(curve_type is SupplyCurve), kind="stable"
        )
        composite_df["total_quantity"] = (
            composite_df["delta_quantity"].cumsum().round(DECIMALS)
        )
        if mask is not None:
            composite_df.loc[
                composite_df["name"] != mask, ["individual_quantity", "price"]
            ] = 0.0
        [stepped] = {c.stepped for c in curves}
        points = [
            Point(
                round(row["total_quantity"], DECIMALS),
                row["individual_quantity"] if individual_quantities else row["price"],
            )
            for _, row in composite_df.iterrows()
        ]
        zero_quantity_price = {SupplyCurve: min, DemandCurve: max}[curve_type](
            zero_quantity_prices
        )
        points.insert(
            0,
            Point(0.0, (0.0 if individual_quantities else zero_quantity_price)),
        )
        return curve_type(points, stepped=stepped)

    def cost(self, mask: str | None = None) -> interp1d:
        return self.composite_curve(SupplyCurve, mask).integral

    def utility(self, mask: str | None = None) -> interp1d:
        return self.composite_curve(DemandCurve, mask).integral

    def welfare(self, mask: str | None = None) -> callable[np.array, np.array]:
        return lambda quantity: self.utility(mask)(quantity) - self.cost(mask)(quantity)

    def equilibrium_quantity(self, mask: str | None = None) -> float:
        quantity_vals = np.arange(0.0, self.max_total_quantity + DX, DX).round(DECIMALS)
        welfare_vals = self.welfare()(quantity_vals)
        idxmax = np.nanargmax(welfare_vals)
        equilibrium_cumulative_quantity = quantity_vals[idxmax]
        if mask is None:
            return equilibrium_cumulative_quantity
        else:
            [curve_type] = {type(c) for c in self.curves if c.name == mask}
            composite_curve = self.upsampled.composite_curve(
                curve_type, mask, individual_quantities=True
            )
            equilibrium_individual_quantity = max(
                [
                    p.y
                    for p in composite_curve.points
                    if p.x <= equilibrium_cumulative_quantity
                ]
            )
            return equilibrium_individual_quantity

    @property
    def max_total_quantity(self) -> float:
        return sum([c.xs.max() for c in self.supply_curves])

    @property
    def equilibrium(self) -> Point | None:
        return (
            Point(self.equilibrium_quantity(), self.equilibrium_price)
            if self.equilibrium_price is not None
            else None
        )

    def plot(
        self,
        ax: Axes,
        xlim: tuple[float, float] = (0.0, 10.0),
        ylim: tuple[float, float] = (0.0, 10.0),
        xticks: dict[float, str] | None = None,
        yticks: dict[float, str] | None = None,
        xaxis_label: str = "$Q$",
        mode: Mode = "supply_and_demand",
        legend: bool = True,
    ) -> None:
        plotter = SupplyDemandPlotter(ax, xlim, ylim, xticks, yticks, xaxis_label, mode)
        equilibrium_quantity = self.equilibrium_quantity()
        for curve in [
            self.composite_curve(SupplyCurve),
            self.composite_curve(DemandCurve),
        ]:
            plotter.add(curve, equilibrium_quantity)
        if mode == "supply_and_demand" and self.equilibrium is not None:
            self.equilibrium.drawn(ax)
        if self.supply_curves and self.demand_curves:
            if mode == "cost_and_utility":
                print(f"Q_opt = {equilibrium_quantity:.3f}")
                welfare_vals = self.welfare()(plotter.x_vals)
                ax.plot(plotter.x_vals, welfare_vals, label=rm(WELFARE))
                optimum = Point(equilibrium_quantity, np.nanmax(welfare_vals))
                ax.plot(*optimum.xy, "ko", markersize=3, label=rm(OPTIMUM))
        if legend:
            plotter.legend()


# Drawing mode used by `SupplyDemand.plot` and `SupplyDemandPlotter`:
# "supply_and_demand" draws the price curves, "cost_and_utility" their integrals.
# Annotations that mention `Mode` earlier in the file are not evaluated at runtime
# (`from __future__ import annotations`), so defining the alias here is fine.
type Mode = Literal["supply_and_demand", "cost_and_utility"]


@dataclasses.dataclass
class SupplyDemandPlotter:
    """Draws supply/demand curves (or their integrals) onto one matplotlib ``Axes``.

    Creating the plotter prepares the axes: it hides the spines, pads the limits,
    applies optional tick overrides, and draws arrow-headed axes with labels.
    ``add`` then draws individual curves.

    Attributes:
        ax: Target axes.
        xlim, ylim: Data ranges to show; the axes are padded around them.
        xticks, yticks: Optional ``{position: label}`` tick overrides. ``{}``
            removes all ticks on that axis; ``None`` keeps the defaults.
        xaxis_label: Label drawn at the end of the x-axis arrow.
        mode: ``"supply_and_demand"`` draws price curves and shaded areas;
            ``"cost_and_utility"`` draws the curves' integrals instead.
        x_vals: Evaluation grid over ``xlim``, set in ``__post_init__``.
    """

    ax: Axes
    xlim: tuple[float, float] = (0.0, 10.0)
    ylim: tuple[float, float] = (0.0, 10.0)
    xticks: dict[float, str] | None = None
    yticks: dict[float, str] | None = None
    xaxis_label: str = "$Q$"
    mode: Mode = "supply_and_demand"

    x_vals: np.ndarray = dataclasses.field(init=False)

    def __post_init__(self) -> None:
        # Grid on which curve values, shaded areas and integrals are evaluated.
        self.x_vals = np.linspace(*self.xlim, 501)

        self.ax.spines[:].set_visible(False)
        # Small margin below the lower limits. The upper limits match where the
        # axis arrows drawn below end (110% of the requested maximum).
        self.ax.set_xlim(self.xlim[0] - 0.3, self.xlim[1] * 1.1)
        self.ax.set_ylim(self.ylim[0] - 0.3, self.ylim[1] * 1.1)
        if self.xticks is not None:
            self.ax.set_xticks(list(self.xticks.keys()))
            self.ax.set_xticklabels(list(self.xticks.values()))
        if self.yticks is not None:
            self.ax.set_yticks(list(self.yticks.keys()))
            self.ax.set_yticklabels(list(self.yticks.values()))
        Arrow.horizontal(x1=self.xlim[0], x2=(self.xlim[1] * 1.1)).drawn(
            self.ax
        ).end.labeled(self.ax, self.xaxis_label, ha="left", va="center")
        yaxis_label = {
            "supply_and_demand": "$P$",
            "cost_and_utility": r"$\mathdollar$",
        }[self.mode]
        Arrow.vertical(y1=self.ylim[0], y2=(self.ylim[1] * 1.1)).drawn(
            self.ax
        ).end.labeled(self.ax, yaxis_label, va="bottom")

    def add(self, curve: _Curve, equilibrium_quantity: float | None = None) -> None:
        """Draw ``curve`` according to ``mode``.

        In ``"supply_and_demand"`` mode, the curve is drawn; stepped curves get
        markers and a step drawstyle. When ``equilibrium_quantity`` is given, the
        area under the curve up to that quantity is shaded and labelled with
        ``curve.integral_name``. In ``"cost_and_utility"`` mode, the curve's
        integral is drawn instead.
        """
        if self.mode == "supply_and_demand":
            self.ax.plot(
                curve.xs,
                curve.ys,
                "o-" if curve.stepped else "-",
                color=curve.color,
                drawstyle=("steps" if curve.stepped else "default"),
                label=rm(curve.name),
            )
            if equilibrium_quantity is not None:
                self.ax.fill_between(
                    self.x_vals,
                    curve.func(self.x_vals),
                    0,
                    where=(self.x_vals <= equilibrium_quantity),
                    # "pre" matches the "next" interpolation of stepped curves.
                    step=("pre" if curve.stepped else None),
                    alpha=0.2,
                    color=curve.color,
                    # Curves that are neither supply nor demand are shaded
                    # without a hatch instead of raising KeyError.
                    hatch={SupplyCurve: r"\\", DemandCurve: "//"}.get(type(curve)),
                    edgecolor=curve.color,
                    label=curve.integral_name,
                )
        elif self.mode == "cost_and_utility":
            self.ax.plot(
                self.x_vals, curve.integral(self.x_vals), label=curve.integral_name
            )

    def legend(self) -> None:
        """Place the legend right of the plot area, vertically centred."""
        self.ax.legend(loc="center left", bbox_to_anchor=(0.95, 0.5))

In [None]:
fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(6.4, 4.2), layout="tight")

# Figure 2.1: linear supply P = 1 + Q against linear demand P = 10 - Q,
# which intersect at Q* = 4.5, P* = 5.5.
x_vals = np.linspace(0, 10)
supply_demand = SupplyDemand(
    [SupplyCurve([Point(x, 1 + 1 * x) for x in x_vals], stepped=False)],
    [DemandCurve([Point(x, 10 - x) for x in x_vals], stepped=False)],
    equilibrium_price=5.5,
)
Segment(supply_demand.equilibrium, Point(5.5, 5.5)).drawn(ax1).end.labeled(
    ax1, "$(Q^*, P^*)$", ha="left", va="center"
)
supply_demand.plot(ax1)

supply_demand.plot(ax2, ylim=(0, 60), mode="cost_and_utility")

fig.savefig("img/fig_2_1.png", dpi=200)

In [None]:
fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(6.4, 4.2), layout="tight")

# Figure 2.2: one stepped supply curve against one stepped demand curve.
supply_demand = SupplyDemand(
    [SupplyCurve([Point(0, 0), Point(6, 2), Point(9, 7)])],
    [DemandCurve([Point(0, 8), Point(4, 8), Point(8, 5)])],
    equilibrium_price=5,
)
supply_demand.plot(ax1)
# Use a Point (as every other Segment call does) so `.end.labeled` is available.
Segment(supply_demand.equilibrium, Point(6.5, 4)).drawn(ax1).end.labeled(
    ax1, "$(Q^*, P^*)$", ha="left", va="top"
)

supply_demand.plot(ax2, ylim=(0, 50), mode="cost_and_utility")

fig.savefig("img/fig_2_2.png", dpi=200)

In [15]:
# Two generators and two loads; later cells also plot this `supply_demand`.
supply_demand = SupplyDemand(
    supply_curves=[
        SupplyCurve([Point(0, 0), Point(6, 2), Point(9, 7)], name="G1"),
        SupplyCurve([Point(0, 0), Point(7, 4), Point(10, 10)], name="G2"),
    ],
    demand_curves=[
        DemandCurve([Point(0, 8), Point(4, 8), Point(8, 5)], name="L1"),
        DemandCurve([Point(0, 9), Point(3, 9), Point(9, 3)], name="L2"),
    ],
    equilibrium_price=4,
)


def two_generators_two_loads(
    style: Literal["ticks", "area"],
) -> tuple[Figure, np.ndarray]:
    """Plot each individual curve of ``supply_demand`` on its own axes.

    Axes are filled row by row in ``supply_demand.curves`` order (G1, G2, L1, L2).

    Args:
        style: ``"ticks"`` draws the curves with default ticks. ``"area"`` also
            shades each curve up to that participant's quantity at the market
            optimum, shows that quantity as the only x tick, and hides y ticks.

    Returns:
        The figure and its 2x2 array of axes.

    Raises:
        ValueError: if ``style`` is not ``"ticks"`` or ``"area"``.
    """
    if style not in ("ticks", "area"):
        raise ValueError(f"style must be 'ticks' or 'area', got {style!r}")

    fig, axs = plt.subplots(nrows=2, ncols=2, layout="tight")
    shade = style == "area"
    for ax, curve in zip(axs.flat, supply_demand.curves, strict=True):
        equilibrium_quantity = (
            supply_demand.equilibrium_quantity(curve.name) if shade else None
        )
        plotter = SupplyDemandPlotter(
            ax,
            xticks=(
                {equilibrium_quantity: rf"$Q_\mathrm{{{curve.name}}}^*$"}
                if shade
                else None
            ),
            yticks=({} if shade else None),
        )
        plotter.add(curve, equilibrium_quantity)

    return fig, axs

In [None]:
# Figure 2.3: the four individual curves, one per axes, with default ticks.
fig, _ = two_generators_two_loads(style="ticks")
fig.savefig(fname="img/fig_2_3.png", dpi=200)

In [None]:
fig, ((ax1, ax2), (ax3, ax4)) = two_generators_two_loads(style="area")

# Axes follow `supply_demand.curves` order: G1, G2 (top row), L1, L2 (bottom row).
Point(3, 1).labeled(ax1, r"$C_\mathrm{G1}(Q_\mathrm{G1}^*)$")
Point(2.5, 2).labeled(ax2, r"$C_\mathrm{G2}(Q_\mathrm{G2}^*)$")
Point(4, 3.25).labeled(ax3, r"$U_\mathrm{L1}(Q_\mathrm{L1}^*)$")
Point(1.5, 4.5).labeled(ax4, r"$U_\mathrm{L2}(Q_\mathrm{L2}^*)$")

fig.savefig("img/fig_2_4.png", dpi=200)

In [18]:
def add_horizontal_brace(
    ax: Axes, x1: float, x2: float, y: float, label: str, opening: Literal["up", "down"]
) -> None:
    """Mark the span ``[x1, x2]`` at height ``y`` with a square bracket and a label.

    The bracket uses matplotlib's ``-[`` arrow style, with a width derived from
    the span length. ``opening="down"`` puts the label above the bracket;
    ``opening="up"`` puts it below.

    Raises:
        KeyError: if ``opening`` is neither ``"up"`` nor ``"down"``.
    """
    midpoint = (x2 + x1) / 2
    direction, text_va = {"up": (-1, "top"), "down": (1, "bottom")}[opening]
    # NOTE(review): 1.1 and 0.2 look hand-tuned so the bracket spans [x1, x2]; confirm.
    bracket_width = (x2 - x1) / 2 * 1.1 - 0.2
    ax.annotate(
        label,
        xy=(midpoint, y + 0.5 * direction),
        xytext=(midpoint, y + 1.5 * direction),
        ha="center",
        va=text_va,
        arrowprops={"arrowstyle": f"-[, widthB={bracket_width}"},
    )

In [None]:
fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(6.4, 4.2), layout="tight")
xlim = (0, 20)

supply_demand.plot(ax1, xlim)
# Braces mark each participant's share of Q* along the composite curves.
# Supply is sorted by ascending price: G1's price-2 block (0-6), then G2's price-4 block (6-11).
# Demand is sorted by descending price: L2's price-9 block (0-3), then L1's blocks at 8 and 5 (3-11).
add_horizontal_brace(ax1, x1=0, x2=6, y=2, label=r"$Q_\mathrm{G1}^*$", opening="down")
add_horizontal_brace(ax1, x1=6, x2=11, y=4, label=r"$Q_\mathrm{G2}^*$", opening="up")
add_horizontal_brace(ax1, x1=0, x2=3, y=9, label=r"$Q_\mathrm{L2}^*$", opening="down")
add_horizontal_brace(ax1, x1=3, x2=11, y=8, label=r"$Q_\mathrm{L1}^*$", opening="down")
Segment(supply_demand.equilibrium, Point(14, 5)).drawn(ax1).end.labeled(
    ax1, "$(Q^*, P^*)$", ha="left"
)

supply_demand.plot(ax2, xlim, ylim=(0, 100), mode="cost_and_utility")

fig.savefig("img/fig_2_5.png", dpi=200)

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, figsize=(6.4, 6.4), layout="tight")
xlim = (0, 20)

# Fraction of the generated quantity that reaches the loads: Q_L = loss_factor * Q_G.
loss_factor = 0.75


def plot_cost_and_utility(ax: Axes, key: Literal["G", "L"]) -> float:
    """Plot cost, utility and welfare with losses, against generator or load quantity.

    Reads the module-level ``supply_demand``, ``xlim`` and ``loss_factor``.

    With ``key="G"``, the x-axis is the generated quantity ``Q``, and utility is
    evaluated at the delivered quantity ``Q * loss_factor``. With ``key="L"``,
    the x-axis is the delivered quantity ``Q``, and cost is evaluated at the
    quantity that must be generated, ``Q / loss_factor``.

    Prints the optimum and returns the welfare-maximising quantity on the chosen axis.

    Raises:
        ValueError: if ``key`` is neither ``"G"`` nor ``"L"``.
    """
    if key not in ("G", "L"):
        raise ValueError(f"key must be 'G' or 'L', got {key!r}")
    # Round like every other DX grid in this notebook.
    quantity_vals = np.arange(xlim[0], xlim[1] + DX, DX).round(DECIMALS)
    # Build each integral once and evaluate only what the chosen axis needs.
    cost = supply_demand.cost()
    utility = supply_demand.utility()
    if key == "G":
        cost_vals = cost(quantity_vals)
        utility_vals = utility(quantity_vals * loss_factor)
    else:
        cost_vals = cost(quantity_vals / loss_factor)
        utility_vals = utility(quantity_vals)
    welfare_vals = utility_vals - cost_vals
    idxmax = np.nanargmax(welfare_vals)
    equilibrium_quantity = quantity_vals[idxmax]
    print(
        f"Q_opt = {equilibrium_quantity:.3f}, "
        f"C(Q_opt) = {cost_vals[idxmax]:.3f}, "
        f"U(Q_opt) = {utility_vals[idxmax]:.3f}, "
        f"W(Q_opt) = {welfare_vals[idxmax]:.3f}"
    )

    label = rf"Q_\mathrm{{{key}}}"
    plotter = SupplyDemandPlotter(
        ax,
        xlim,
        ylim=(0, 100),
        xticks={equilibrium_quantity: rf"${label}^*$"},
        yticks={},
        xaxis_label=rf"${label}$",
    )
    ax.plot(quantity_vals, cost_vals, label=rm(SELLER_COST))
    ax.plot(quantity_vals, utility_vals, label=rm(BUYER_UTILITY))
    ax.plot(quantity_vals, welfare_vals, label=rm(WELFARE))
    optimum = Point(equilibrium_quantity, np.nanmax(welfare_vals))
    ax.plot(*optimum.xy, "ko", markersize=3, label=rm(OPTIMUM))
    plotter.legend()

    return equilibrium_quantity


generator_equilibrium_quantity = plot_cost_and_utility(ax2, "G")
load_equilibrium_quantity = plot_cost_and_utility(ax3, "L")

plotter = SupplyDemandPlotter(ax1, xlim, xticks={}, yticks={})
plotter.add(
    supply_demand.composite_curve(SupplyCurve),
    equilibrium_quantity=generator_equilibrium_quantity,
)
plotter.add(
    supply_demand.composite_curve(DemandCurve),
    equilibrium_quantity=load_equilibrium_quantity,
)
plotter.legend()
# NOTE(review): the prices 4 (P_G*) and 8 (P_L*) are hard-coded, presumably read
# off the composite curves at the two optimal quantities; revisit if the data change.
Segment(Point(generator_equilibrium_quantity, 4).drawn(ax1), Point(10.5, 7)).drawn(
    ax1
).end.labeled(ax1, r"$(Q_\mathrm{G}^*, P_\mathrm{G}^*)$", va="bottom")
Segment(Point(load_equilibrium_quantity, 8).drawn(ax1), Point(8, 10)).drawn(
    ax1
).end.labeled(ax1, r"$(Q_\mathrm{L}^*, P_\mathrm{L}^*)$", va="bottom")

fig.savefig("img/fig_2_6.png", dpi=200)