In [1]:
from __future__ import annotations

import dataclasses
from collections.abc import Callable
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from scipy.interpolate import interp1d

from src.drawing_utils import Arrow, Point, Segment
from src.plotting_utils import configure_matplotlib, rm

# Apply the project's shared matplotlib configuration (from src.plotting_utils;
# presumably rcParams/fonts — see that module) before any figure below is created.
configure_matplotlib()

In [2]:
# Step size of the grids used for the numerical integrals (Curve.integral) and for
# the welfare search (SupplyDemand.Q_star, plot_cost_and_utility).
DX = 0.001

# Legend labels, passed through the project's `rm` text helper.
SELLER_COST = rm("Seller Cost, $C$")
BUYER_UTILITY = rm("Buyer Utility, $U$")
WELFARE = rm("Welfare, $W$")
OPTIMUM = rm("Optimum")


@dataclasses.dataclass
class Curve:
    """A supply or demand curve through ``points``, stepped or piecewise-linear.

    A stepped curve uses "next" interpolation: strictly between consecutive
    breakpoints ``x_i < x < x_{i+1}`` it takes the value ``y_{i+1}``, which is what
    matplotlib's ``drawstyle="steps"`` (an alias of ``"steps-pre"``) draws.
    """

    points: list[Point]
    # True: step function ("next" interpolation); False: linear interpolation.
    stepped: bool = True

    @property
    def xs(self) -> np.ndarray:
        """Breakpoint x-coordinates, in the order of ``points``."""
        return np.array([p.x for p in self.points])

    @property
    def ys(self) -> np.ndarray:
        """Breakpoint y-coordinates, in the order of ``points``."""
        return np.array([p.y for p in self.points])

    @property
    def func(self) -> interp1d:
        """The curve as a callable; NaN outside ``[xs.min(), xs.max()]``."""
        return interp1d(
            self.xs,
            self.ys,
            kind=("next" if self.stepped else "linear"),
            bounds_error=False,
        )

    @property
    def integral(self) -> interp1d:
        """Running integral of the curve from ``xs.min()``, as a callable.

        Evaluated as a Riemann sum on a grid of spacing ``DX`` covering
        ``[xs.min(), xs.max()]`` (both ends included) and linearly interpolated in
        between; NaN outside that grid. Rebuilt on every access.

        NOTE: assumes the curve's x-span is a whole number of ``DX`` steps (true for
        the integer / ``linspace(0, 10)`` breakpoints used here); otherwise the grid
        may overshoot ``xs.max()`` by less than half a step.
        """
        lo, hi = self.xs.min(), self.xs.max()
        # Stop half a step past `hi` so `hi` itself is on the grid; a plain
        # arange(lo, hi, DX) stops one step short and leaves the integral NaN at
        # the right end of the curve.
        x_vals = np.arange(lo, hi + DX / 2, DX).round(10)
        if self.stepped:
            # Sample each cell at its right edge, which under "next" interpolation
            # gives the step value on that cell. Clamp to `hi` so float round-off
            # cannot push the final samples out of range (where func returns NaN).
            increments = self.func(np.minimum(x_vals + DX, hi))
        else:
            # Left Riemann sum.
            increments = self.func(x_vals)
        # Integral at x_vals[k] = DX * (sum of the first k increments); 0 at `lo`.
        y_vals = np.concatenate(([0.0], np.cumsum(increments[:-1]))) * DX
        return interp1d(x_vals, y_vals, bounds_error=False)


@dataclasses.dataclass
class SupplyDemand:
    """A market made of an optional supply curve, demand curve and equilibrium point.

    Seller cost is the running integral of the supply curve, buyer utility that of
    the demand curve, and welfare their difference.
    """

    supply: Curve | None = None
    demand: Curve | None = None
    # Marked on supply/demand plots; its ``x`` is used as Q* when shading the areas
    # under the curves and when placing the welfare optimum.
    equilibrium: Point | None = None

    @property
    def cost(self) -> interp1d:
        """Seller cost ``C(Q)``: running integral of the supply curve.

        Raises:
            ValueError: If there is no supply curve.
        """
        if self.supply is None:
            raise ValueError("cost requires a supply curve")
        return self.supply.integral

    @property
    def utility(self) -> interp1d:
        """Buyer utility ``U(Q)``: running integral of the demand curve.

        Raises:
            ValueError: If there is no demand curve.
        """
        if self.demand is None:
            raise ValueError("utility requires a demand curve")
        return self.demand.integral

    @property
    def welfare(self) -> Callable[[np.ndarray], np.ndarray]:
        """Welfare ``W(Q) = U(Q) - C(Q)`` as a function of quantity.

        Both integrals are built once when this property is read (each access to
        ``cost``/``utility`` rebuilds an interpolant). NaN wherever either is undefined.
        """
        utility, cost = self.utility, self.cost
        return lambda quantity: utility(quantity) - cost(quantity)

    @property
    def Q_star(self) -> float:
        """Welfare-maximising quantity, searched on a ``DX`` grid.

        The grid spans from the smallest to the largest breakpoint of either curve
        (right end excluded); NaN welfare values are ignored.
        """
        quantity_vals = np.arange(self._bound(min), self._bound(max), DX)
        welfare_vals = self.welfare(quantity_vals)
        return quantity_vals[np.nanargmax(welfare_vals)]

    def _bound(self, func: Callable[..., float]) -> float:
        """Apply ``func`` (``min`` or ``max``) across the breakpoints of the present curves."""
        per_curve = [
            func(curve.xs) for curve in (self.supply, self.demand) if curve is not None
        ]
        return func(per_curve)

    def plot(
        self,
        ax: Axes,
        xlim: tuple[float, float] = (0.0, 10.0),
        ylim: tuple[float, float] = (0.0, 10.0),
        xticks: dict[float, str] | None = None,
        yticks: dict[float, str] | None = None,
        xaxis_label: str = "$Q$",
        mode: Mode = "supply_and_demand",
        legend: bool = True,
    ) -> None:
        """Draw this market onto ``ax``.

        In ``"supply_and_demand"`` mode the curves are drawn, shaded up to the
        equilibrium quantity when an equilibrium is set, and the equilibrium point
        is marked. In ``"cost_and_utility"`` mode the integrals are drawn; when both
        curves are present the welfare curve and its optimum are added (and Q* is
        printed). The optimum's quantity is the equilibrium's ``x`` if one was given,
        otherwise :attr:`Q_star`.

        The remaining arguments are forwarded to ``SupplyDemandPlotter``;
        ``legend`` controls whether a legend is drawn.
        """
        plotter = SupplyDemandPlotter(ax, xlim, ylim, xticks, yticks, xaxis_label, mode)
        Q_star = self.equilibrium.x if self.equilibrium is not None else None
        if self.supply is not None:
            plotter.add("supply", self.supply, Q_star)
        if self.demand is not None:
            plotter.add("demand", self.demand, Q_star)
        if mode == "supply_and_demand" and self.equilibrium is not None:
            self.equilibrium.drawn(ax)
        both_curves = self.supply is not None and self.demand is not None
        if mode == "cost_and_utility" and both_curves:
            if Q_star is None:
                # No explicit equilibrium: locate the optimum numerically instead of
                # failing on the `:.3f` format / Point(None, ...) below.
                Q_star = self.Q_star
            print(f"Q_star = {Q_star:.3f}")
            welfare_vals = self.welfare(plotter.x_vals)
            ax.plot(plotter.x_vals, welfare_vals, label=WELFARE)
            optimum = Point(Q_star, np.nanmax(welfare_vals))
            ax.plot(*optimum.xy, "ko", markersize=3, label=OPTIMUM)
        if legend:
            plotter.legend()


# Which chart SupplyDemandPlotter draws: the price-vs-quantity curves themselves
# ("supply_and_demand", y-axis $P$) or their running integrals plus welfare
# ("cost_and_utility", y-axis in dollars).
type Mode = Literal["supply_and_demand", "cost_and_utility"]


@dataclasses.dataclass
class SupplyDemandPlotter:
    """Set up ``ax`` as a bare economics-style chart and draw curves onto it.

    Construction hides the spines, pads the limits, applies optional fixed ticks
    (``{position: label}`` dicts) and draws arrow-tipped axes: the x-axis labelled
    ``xaxis_label``, the y-axis labelled ``$P$`` or a dollar sign depending on
    ``mode``. Each ``add`` call then draws one supply or demand curve.
    """

    ax: Axes
    xlim: tuple[float, float] = (0.0, 10.0)
    ylim: tuple[float, float] = (0.0, 10.0)
    xticks: dict[float, str] | None = None
    yticks: dict[float, str] | None = None
    xaxis_label: str = "$Q$"
    mode: Mode = "supply_and_demand"

    # Shared 501-point sample grid over xlim; filled in by __post_init__.
    x_vals: np.ndarray = dataclasses.field(init=False)

    def __post_init__(self) -> None:
        x_lo, x_hi = self.xlim[0], self.xlim[1]
        y_lo, y_hi = self.ylim[0], self.ylim[1]
        self.x_vals = np.linspace(*self.xlim, 501)

        ax = self.ax
        ax.spines[:].set_visible(False)
        # A little room before the origin, and 10% beyond the far ends, where the
        # arrow-tipped axes drawn below end.
        ax.set_xlim(x_lo - 0.3, x_hi * 1.1)
        ax.set_ylim(y_lo - 0.3, y_hi * 1.1)

        # Positions and labels come from the same dict, so they stay aligned.
        if self.xticks is not None:
            ax.set_xticks(list(self.xticks))
            ax.set_xticklabels(list(self.xticks.values()))
        if self.yticks is not None:
            ax.set_yticks(list(self.yticks))
            ax.set_yticklabels(list(self.yticks.values()))

        x_arrow = Arrow.horizontal(x1=x_lo, x2=(x_hi * 1.1)).drawn(ax)
        x_arrow.end.labeled(ax, self.xaxis_label, ha="left", va="center")
        y_label = {
            "supply_and_demand": "$P$",
            "cost_and_utility": r"$\mathdollar$",
        }[self.mode]
        y_arrow = Arrow.vertical(y1=y_lo, y2=(y_hi * 1.1)).drawn(ax)
        y_arrow.end.labeled(ax, y_label, va="bottom")

    def add(
        self,
        supply_or_demand: Literal["supply", "demand"],
        curve: Curve,
        Q_star: float | None = None,
    ) -> None:
        """Draw ``curve`` as the supply or the demand side.

        ``"supply_and_demand"`` mode: plot the curve (markers and steps when it is
        stepped) and, if ``Q_star`` is given, hatch-shade the area under it for
        ``Q <= Q_star``, labelled as seller cost / buyer utility.
        ``"cost_and_utility"`` mode: plot the curve's running integral under that
        same label.
        """
        color = {"supply": "#1565C0", "demand": "#EF6C00"}[supply_or_demand]
        area_label = {"supply": SELLER_COST, "demand": BUYER_UTILITY}[supply_or_demand]

        if self.mode == "supply_and_demand":
            is_stepped = curve.stepped
            self.ax.plot(
                curve.xs,
                curve.ys,
                "o-" if is_stepped else "-",
                color=color,
                drawstyle="steps" if is_stepped else "default",
                label=rm(f"{supply_or_demand.capitalize()} Curve"),
            )
            if Q_star is not None:
                grid = self.x_vals
                self.ax.fill_between(
                    grid,
                    curve.func(grid),
                    0,
                    where=(grid <= Q_star),
                    step="pre" if is_stepped else None,
                    alpha=0.2,
                    color=color,
                    hatch={"supply": r"\\", "demand": "//"}[supply_or_demand],
                    edgecolor=color,
                    label=area_label,
                )
        elif self.mode == "cost_and_utility":
            self.ax.plot(self.x_vals, curve.integral(self.x_vals), label=area_label)

    def legend(self) -> None:
        """Add a legend whose left edge sits at 95% of the axes width, vertically centred."""
        self.ax.legend(bbox_to_anchor=(0.95, 0.5), loc="center left")

In [None]:
# Figure 2.1: linear supply P = 1 + Q and demand P = 10 - Q, which cross at the
# marked equilibrium (4.5, 5.5). Top panel: the curves with the areas under them
# shaded up to Q*; bottom panel: cost, utility and welfare against Q.
fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(6.4, 4.2), layout="tight")

# 50 evenly spaced breakpoints (numpy's default `num`) for the linear curves.
x_vals = np.linspace(0, 10)
equilibrium = Point(4.5, 5.5)
supply_demand = SupplyDemand(
    supply=Curve([Point(x, 1 + 1 * x) for x in x_vals], stepped=False),
    demand=Curve([Point(x, 10 - x) for x in x_vals], stepped=False),
    equilibrium=equilibrium,
)
# Label the equilibrium via a short segment ending at (5.5, 5.5).
Segment(equilibrium, Point(5.5, 5.5)).drawn(ax1).end.labeled(
    ax1, "$(Q^*, P^*)$", ha="left", va="center"
)
supply_demand.plot(ax1)

supply_demand.plot(ax2, ylim=(0, 60), mode="cost_and_utility")

# savefig does not create directories: img/ must exist in the working directory.
fig.savefig("img/fig_2_1.png", dpi=200)

In [None]:
# Figure 2.2: step supply and demand curves with the equilibrium at (6, 5).
# Top panel: the curves with the areas under them shaded up to Q* = 6;
# bottom panel: cost, utility and welfare against Q.
fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(6.4, 4.2), layout="tight")

equilibrium = Point(6, 5)
supply_demand = SupplyDemand(
    supply=Curve([Point(0, 0), Point(6, 2), Point(9, 7)]),
    demand=Curve([Point(0, 8), Point(4, 8), Point(8, 5)]),
    equilibrium=equilibrium,
)
supply_demand.plot(ax1)
# Label the equilibrium via a short segment; the end must be a Point (it is
# labelled through `.end.labeled`), as at every other Segment call site.
Segment(equilibrium, Point(6.5, 4)).drawn(ax1).end.labeled(
    ax1, "$(Q^*, P^*)$", ha="left", va="top"
)

supply_demand.plot(ax2, ylim=(0, 50), mode="cost_and_utility")

# savefig does not create directories: img/ must exist in the working directory.
fig.savefig("img/fig_2_2.png", dpi=200)

In [5]:
def two_generators_two_loads(style: Literal["ticks", "area"]) -> Figure:
    """Draw a 2x2 grid with one participant's curve per panel.

    Top row: step supply curves of generators G1 and G2. Bottom row: step demand
    curves of loads L1 and L2. Each panel is titled with its participant.

    Args:
        style: ``"ticks"`` draws the bare curves with default ticks. ``"area"``
            additionally shades the area under each curve up to that participant's
            quantity ``Q*``, shows ``Q*`` as the only x-tick, hides the y-ticks and
            annotates the shaded area as ``C_X(Q_X^*)`` (generators) or
            ``U_X(Q_X^*)`` (loads).

    Returns:
        The created figure.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, layout="tight")

    def plot(
        ax: Axes,
        Q_star: float,
        label: str,
        supply: Curve | None = None,
        demand: Curve | None = None,
    ) -> None:
        """Draw one participant's curve on ``ax`` and title the panel ``label``."""
        supply_demand = SupplyDemand(
            supply,
            demand,
            # In "area" mode this equilibrium only carries Q_star through so the
            # area up to it is shaded; its y is NaN (presumably so the marker drawn
            # for it stays invisible).
            equilibrium=(Point(Q_star, np.nan) if style == "area" else None),
        )
        supply_demand.plot(
            ax,
            xticks=(
                {Q_star: rf"$Q_\mathrm{{{label}}}^*$"} if style == "area" else None
            ),
            yticks=({} if style == "area" else None),
            legend=False,
        )
        ax.set_title(rm(label))

    def label_area(ax: Axes, x: float, y: float, symbol: str, label: str) -> None:
        """In "area" style, write e.g. ``C_G1(Q_G1^*)`` at ``(x, y)`` on ``ax``.

        Building the text from the panel ``label`` keeps each annotation in sync
        with its panel.
        """
        if style == "area":
            Point(x, y).labeled(
                ax, rf"${symbol}_\mathrm{{{label}}}(Q_\mathrm{{{label}}}^*)$"
            )

    plot(
        ax1, supply=Curve([Point(0, 0), Point(6, 2), Point(9, 7)]), Q_star=6, label="G1"
    )
    label_area(ax1, 3, 1, "C", "G1")
    plot(
        ax2,
        supply=Curve([Point(0, 0), Point(7, 4), Point(10, 10)]),
        Q_star=5,
        label="G2",
    )
    label_area(ax2, 2.5, 2, "C", "G2")
    plot(
        ax3, demand=Curve([Point(0, 8), Point(4, 8), Point(8, 5)]), Q_star=8, label="L1"
    )
    label_area(ax3, 4, 3.25, "U", "L1")
    plot(
        ax4, demand=Curve([Point(0, 9), Point(3, 9), Point(9, 3)]), Q_star=3, label="L2"
    )
    label_area(ax4, 1.5, 4.5, "U", "L2")

    return fig

In [None]:
# Figure 2.3: the four participants' curves (G1, G2, L1, L2) with default ticks.
fig = two_generators_two_loads(style="ticks")
fig.savefig("img/fig_2_3.png", dpi=200)

In [None]:
# Figure 2.4: the same four panels with each area shaded up to the participant's
# Q* and annotated with its cost (generators) or utility (loads).
fig = two_generators_two_loads(style="area")
fig.savefig("img/fig_2_4.png", dpi=200)

In [8]:
def add_horizontal_brace(
    ax: Axes, x1: float, x2: float, y: float, label: str, opening: Literal["up", "down"]
) -> None:
    """Mark the span ``[x1, x2]`` near height ``y`` with a square bracket and a label.

    The bracket sits half a unit from ``y`` and the label one further unit beyond
    it: below ``y`` for ``opening="up"``, above for ``opening="down"``. Any other
    ``opening`` raises ``KeyError``.
    """
    direction = {"up": -1, "down": 1}[opening]
    mid_x = (x1 + x2) / 2
    # Hand-tuned mapping from span length to matplotlib's bracket `widthB`
    # (presumably fitted to this notebook's figure sizes).
    bracket_width = (x2 - x1) / 2 * 1.15 - 0.2
    ax.annotate(
        label,
        xy=(mid_x, y + direction * 0.5),
        xytext=(mid_x, y + direction * 1.5),
        va={"up": "top", "down": "bottom"}[opening],
        ha="center",
        arrowprops={"arrowstyle": f"-[, widthB={bracket_width}"},
    )

In [None]:
# Figure 2.5: four-step supply and demand curves clearing at (11, 4), with braces
# showing how Q* = 11 splits between generators G1/G2 (on the supply steps) and
# loads L1/L2 (on the demand steps). Bottom panel: cost, utility and welfare.
fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(6.4, 4.2), layout="tight")
xlim = (0, 20)

equilibrium = Point(11, 4)
supply_demand = SupplyDemand(
    supply=Curve([Point(0, 0), Point(6, 2), Point(13, 4), Point(15, 7), Point(19, 10)]),
    demand=Curve([Point(0, 9), Point(3, 9), Point(7, 8), Point(11, 5), Point(17, 3)]),
    equilibrium=equilibrium,
)
supply_demand.plot(ax1, xlim)
# Brace spans and heights are hard-coded to the step geometry above.
add_horizontal_brace(ax1, x1=0, x2=6, y=2, label=r"$Q_\mathrm{G1}^*$", opening="up")
add_horizontal_brace(ax1, x1=6, x2=11, y=4, label=r"$Q_\mathrm{G2}^*$", opening="up")
add_horizontal_brace(ax1, x1=0, x2=3, y=9, label=r"$Q_\mathrm{L1}^*$", opening="down")
add_horizontal_brace(ax1, x1=3, x2=11, y=8, label=r"$Q_\mathrm{L2}^*$", opening="down")
Segment(equilibrium, Point(14, 5)).drawn(ax1).end.labeled(
    ax1, "$(Q^*, P^*)$", ha="left"
)

supply_demand.plot(ax2, xlim, ylim=(0, 100), mode="cost_and_utility")

fig.savefig("img/fig_2_5.png", dpi=200)

In [None]:
# Figure 2.6: the market of Figure 2.5 (no fixed equilibrium) when only a fraction
# `loss_factor` of the generated quantity reaches the loads (presumably network
# losses), so the generation-side and load-side optimal quantities differ.
# `loss_factor`, `xlim` and `supply_demand` are read as globals by
# plot_cost_and_utility below.
fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, figsize=(6.4, 6.4), layout="tight")

# Delivered / generated quantity: Q_L = loss_factor * Q_G.
loss_factor = 0.75

xlim = (0, 20)
supply_demand = SupplyDemand(
    supply=Curve([Point(0, 0), Point(6, 2), Point(13, 4), Point(15, 7), Point(19, 10)]),
    demand=Curve([Point(0, 9), Point(3, 9), Point(7, 8), Point(11, 5), Point(17, 3)]),
)


def plot_cost_and_utility(ax: Axes, key: Literal["G", "L"]) -> float:
    """Plot cost, utility and welfare against one side's quantity and mark the optimum.

    Only ``loss_factor`` of the generated quantity reaches the loads
    (``Q_L = loss_factor * Q_G``). With ``key="G"`` the x-axis is the generated
    quantity ``Q_G``, and the loads receive ``loss_factor * Q_G``; with ``key="L"``
    the x-axis is the delivered quantity ``Q_L``, which needs ``Q_L / loss_factor``
    of generation.

    Reads the notebook globals ``xlim``, ``loss_factor`` and ``supply_demand``, and
    prints Q*, C, U and W at the optimum.

    Args:
        ax: Axes to draw on.
        key: ``"G"`` (generation side) or ``"L"`` (load side).

    Returns:
        The welfare-maximising quantity on the chosen side, from a grid of spacing DX.

    Raises:
        ValueError: If ``key`` is neither ``"G"`` nor ``"L"``.
    """
    quantity_vals = np.arange(*xlim, DX)
    # Each access to these properties rebuilds an interpolant, so fetch them once.
    cost, utility = supply_demand.cost, supply_demand.utility
    if key == "G":
        cost_vals = cost(quantity_vals)
        utility_vals = utility(quantity_vals * loss_factor)
    elif key == "L":
        cost_vals = cost(quantity_vals / loss_factor)
        utility_vals = utility(quantity_vals)
    else:
        raise ValueError(f"key must be 'G' or 'L', got {key!r}")
    welfare_vals = utility_vals - cost_vals
    # nanargmax skips quantities where an integral is undefined (NaN).
    idxmax = np.nanargmax(welfare_vals)
    Q_star = quantity_vals[idxmax]
    print(
        f"Q_star = {Q_star:.3f}, "
        f"C(Q_star) = {cost_vals[idxmax]:.3f}, "
        f"U(Q_star) = {utility_vals[idxmax]:.3f}, "
        f"W(Q_star) = {welfare_vals[idxmax]:.3f}"
    )

    label = rf"Q_\mathrm{{{key}}}"
    plotter = SupplyDemandPlotter(
        ax,
        xlim,
        ylim=(0, 100),
        xticks={Q_star: rf"${label}^*$"},
        yticks={},
        xaxis_label=rf"${label}$",
    )
    ax.plot(quantity_vals, cost_vals, label=SELLER_COST)
    ax.plot(quantity_vals, utility_vals, label=BUYER_UTILITY)
    ax.plot(quantity_vals, welfare_vals, label=WELFARE)
    optimum = Point(Q_star, np.nanmax(welfare_vals))
    ax.plot(*optimum.xy, "ko", markersize=3, label=OPTIMUM)
    plotter.legend()

    return Q_star


# Bottom two panels: the welfare optimum seen from the generation side and from
# the load side.
Q_G_star = plot_cost_and_utility(ax2, "G")
Q_L_star = plot_cost_and_utility(ax3, "L")

# Top panel: both curves, each shaded up to its own side's optimal quantity.
plotter = SupplyDemandPlotter(ax1, xlim, xticks={}, yticks={})
plotter.add("supply", supply_demand.supply, Q_star=Q_G_star)
plotter.add("demand", supply_demand.demand, Q_star=Q_L_star)
plotter.legend()
# Mark the two operating points. The prices 4 and 8 are hard-coded: they are the
# supply step at Q_G^* and the demand step at Q_L^* for these curves with
# loss_factor = 0.75, and must be updated if either changes.
Segment(Point(Q_G_star, 4).drawn(ax1), Point(10.5, 7)).drawn(ax1).end.labeled(
    ax1, r"$(Q_\mathrm{G}^*, P_\mathrm{G}^*)$", va="bottom"
)
Segment(Point(Q_L_star, 8).drawn(ax1), Point(8, 10)).drawn(ax1).end.labeled(
    ax1, r"$(Q_\mathrm{L}^*, P_\mathrm{L}^*)$", va="bottom"
)

fig.savefig("img/fig_2_6.png", dpi=200)