In [1]:
from __future__ import annotations

import json
import gzip
import os
import shutil
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
import requests
from scipy import sparse
from tqdm import tqdm

import gurobipy as gp
from gurobipy import GRB

# Reading UnitCommitment.jl benchmark instances

In [2]:
# --------------------------------------------------------------------------- #
#  High-level public API                                                      #
# --------------------------------------------------------------------------- #

INSTANCES_URL = "https://axavier.org/UnitCommitment.jl/0.4/instances"
if "__file__" in globals():
    _BASEDIR = Path(__file__).resolve().parent
else:  # interactive session
    _BASEDIR = Path.cwd()

_CACHE = _BASEDIR / "instances"
_CACHE.mkdir(exist_ok=True, parents=True)


def read_benchmark(name: str, *, quiet: bool = False) -> "UnitCommitmentInstance":
    """
    Load a UnitCommitment.jl benchmark by name, fetching it into the local
    cache first when it is not already there.

    Parameters
    ----------
    name  : benchmark path relative to ``INSTANCES_URL``, without extension.
    quiet : when True, do not announce the download.

    Example
    -------
    >>> inst = read_benchmark("matpower/case3375wp/2017-02-01")
    """
    filename = f"{name}.json.gz"
    cached_file = _CACHE / filename

    if not cached_file.is_file():
        remote = f"{INSTANCES_URL}/{filename}"
        if not quiet:
            print(f"Downloading  {remote}")
        _download(remote, cached_file)

    return read(str(cached_file))


def read(path_or_paths: Union[str, Sequence[str]]) -> "UnitCommitmentInstance":
    """
    Generic loader.  Accepts:
      • single path (JSON or JSON.GZ) ➜ deterministic instance
      • list / tuple of paths           ➜ stochastic instance

    Paths may be ``str`` or any ``os.PathLike`` (e.g. ``pathlib.Path``).
    A single path always yields a scenario named "s1" with probability 1.

    Raises
    ------
    ValueError
        If an empty list/tuple of paths is given.
    TypeError
        If an entry is not a path (previously such entries were silently
        dropped, which misaligned scenario names with their files).
    """
    if isinstance(path_or_paths, (list, tuple)):
        paths = [os.fspath(p) for p in path_or_paths]
        if not paths:
            raise ValueError("read() needs at least one scenario file")
        scenarios = [_read_scenario(p) for p in paths]
        _repair_scenario_names_and_probabilities(scenarios, paths)
    else:
        scenario = _read_scenario(os.fspath(path_or_paths))
        scenario.name = "s1"
        scenario.probability = 1.0
        scenarios = [scenario]

    return UnitCommitmentInstance(time=scenarios[0].time, scenarios=scenarios)


# --------------------------------------------------------------------------- #
#  Internal helpers                                                           #
# --------------------------------------------------------------------------- #


def _download(url: str, dst: Path, chunk: int = 1 << 20) -> None:
    """
    Stream *url* to *dst* with a progress bar, atomically.

    The body is written to a temporary file in ``dst.parent`` and moved onto
    *dst* with ``os.replace`` only after the transfer completes. An
    interrupted or failed download therefore never leaves a truncated file.
    ``read_benchmark`` would otherwise treat such a file as a valid cache hit.

    Raises
    ------
    requests.HTTPError
        On a non-2xx response (via ``raise_for_status``); the temp file is
        removed before the error propagates.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(
        prefix=f".{dst.name}.", suffix=".part", dir=dst.parent
    )
    tmp_path = Path(tmp_name)
    try:
        # Open the descriptor first so it is closed even if the request fails.
        with os.fdopen(fd, "wb") as fh:
            with requests.get(url, stream=True, timeout=60) as r:
                r.raise_for_status()
                total = int(r.headers.get("content-length", 0))
                with tqdm(
                    total=total, unit="B", unit_scale=True, disable=total == 0
                ) as bar:
                    for chunk_data in r.iter_content(chunk_size=chunk):
                        fh.write(chunk_data)
                        bar.update(len(chunk_data))
        os.replace(tmp_path, dst)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise


def _read_json(path: str) -> dict:
    """Open JSON or JSON.GZ transparently."""
    if path.endswith(".gz"):
        with gzip.open(path, "rt", encoding="utf-8") as fh:
            return json.load(fh)
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)


def _read_scenario(path: str) -> UnitCommitmentScenario:
    """Parse one scenario file: load JSON, upgrade legacy schema, build objects."""
    document = _read_json(path)
    _migrate(document)  # upgrades `document` to the current schema in place
    scenario = _from_json(document)
    return scenario


def _repair_scenario_names_and_probabilities(
    scenarios: List["UnitCommitmentScenario"], paths: List[str]
) -> None:
    """Normalize names and probabilities so they sum to 1."""
    total = sum(sc.probability for sc in scenarios)
    for sc, p in zip(scenarios, paths):
        if not sc.name:
            sc.name = Path(p).stem.split(".")[0]
        sc.probability /= total


# --------------------------------------------------------------------------- #
#  Datastructures                                                             #
# --------------------------------------------------------------------------- #

# A single numeric value read from the instance JSON.
Number = Union[int, float]
# One value per time step; the model code indexes these as ``series[t - 1]``
# for periods t = 1..T.
Series = List[Number]


@dataclass
class CostSegment:
    """One piece of a thermal unit's piece-wise linear production cost curve."""

    # Segment width per time step (MW): difference between consecutive
    # "Production cost curve (MW)" breakpoints.
    amount: Series
    # Marginal cost per time step ($/MW): cost increase over the segment
    # divided by its width.
    cost: Series


@dataclass
class StartupCategory:
    """A startup-cost tier of a thermal unit."""

    # Delay in time steps ("Startup delays (h)" × steps per hour).
    delay_steps: int
    # Cost ($) added to the objective when a startup uses this category.
    cost: float


@dataclass
class Bus:
    """A network bus together with the devices attached to it."""

    name: str
    # 1-based position of the bus in the input "Buses" mapping.
    index: int
    # Fixed demand per time step (MW), from "Load (MW)".
    load: Series
    # Attached devices; filled in while the scenario is being parsed.
    thermal_units: List["ThermalUnit"] = field(default_factory=list)
    price_sensitive_loads: List["PriceSensitiveLoad"] = field(default_factory=list)
    profiled_units: List["ProfiledUnit"] = field(default_factory=list)
    storage_units: List["StorageUnit"] = field(default_factory=list)


@dataclass
class Reserve:
    """A reserve product and its system requirement."""

    name: str
    # Lower-cased "Type"; the model code handles "spinning" and "flexiramp".
    type: str
    # Requirement per time step (MW).
    amount: Series
    # Thermal units listing this reserve in their "Reserve eligibility".
    thermal_units: List["ThermalUnit"]
    # Penalty ($/MW) for unmet requirement (defaults to 10 when absent).
    shortfall_penalty: float


@dataclass
class ThermalUnit:
    """A dispatchable thermal generator ("Type": "Thermal" in the input)."""

    name: str
    bus: Bus
    # Per-step output bounds (MW): last / first production-cost breakpoint.
    max_power: Series
    min_power: Series
    # Per-step must-run flags ("Must run?", default False).
    must_run: Series
    # Per-step cost ($) of producing min_power (first cost breakpoint).
    min_power_cost: Series
    # Incremental cost segments above min_power.
    segments: List[CostSegment]
    # Minimum up/down time in time steps (hour inputs × steps per hour).
    min_up: int
    min_down: int
    # Ramp and startup/shutdown limits (MW); 1e6 when not given.
    ramp_up: float
    ramp_down: float
    startup_limit: float
    shutdown_limit: float
    # Initial status in time steps, or None when no initial power is given;
    # the model treats a positive value as "initially on".
    initial_status: Optional[int]
    # Initial output (MW), or None when not given.
    initial_power: Optional[float]
    startup_categories: List[StartupCategory]
    # Reserve products this unit may provide.
    reserves: List[Reserve]
    # Per-step fixed commitment (True/False), or None where the model is free
    # to decide; must-run steps are forced to True after parsing.
    commitment_status: List[Optional[bool]]


@dataclass
class ProfiledUnit:
    """A generator with a per-step output range ("Type": "Profiled")."""

    name: str
    bus: Bus
    # Per-step output bounds (MW); the minimum defaults to 0.
    min_power: Series
    max_power: Series
    # Per-step production cost ($/MW).
    cost: Series


@dataclass
class StorageUnit:
    """An energy storage device attached to a bus."""

    name: str
    bus: Bus
    # Per-step state-of-charge bounds (MWh); the minimum defaults to 0.
    min_level: Series
    max_level: Series
    # Per-step flag "Allow simultaneous charging and discharging" (default True).
    simultaneous: Series
    # Per-step costs ($/MW) of charging / discharging.
    charge_cost: Series
    discharge_cost: Series
    # Per-step efficiencies (default 1.0) and loss factor (default 0.0).
    charge_eff: Series
    discharge_eff: Series
    loss_factor: Series
    # Per-step charge / discharge rate bounds (MW); minimums default to 0.
    min_charge: Series
    max_charge: Series
    min_discharge: Series
    max_discharge: Series
    # Initial level (MWh), default 0.
    initial_level: float
    # Bounds on the final-period level (MWh); default to the last entries of
    # min_level / max_level.
    last_min: float
    last_max: float


@dataclass
class TransmissionLine:
    """A transmission line between two buses."""

    name: str
    # 1-based position in the input "Transmission lines" mapping.
    index: int
    source: Bus
    target: Bus
    # Line susceptance, from "Susceptance (S)".
    susceptance: float
    # Per-step flow limits (MW); 1e8 when not given.
    normal_limit: Series
    emergency_limit: Series
    # Per-step penalty ($/MW) for exceeding the limit (default 5000).
    flow_penalty: Series


@dataclass
class Contingency:
    """A contingency event listing its affected lines and thermal units."""

    name: str
    # From "Affected lines" / "Affected units" (units resolve to thermal units).
    lines: List[TransmissionLine]
    units: List[ThermalUnit]


@dataclass
class PriceSensitiveLoad:
    """A load whose demand is served in exchange for revenue."""

    name: str
    bus: Bus
    # Per-step demand (MW), from "Demand (MW)".
    demand: Series
    # Per-step revenue ($/MW), from "Revenue ($/MW)".
    revenue: Series


@dataclass
class UnitCommitmentScenario:
    """
    One fully parsed scenario: network, devices, reserves and time grid.

    Each ``*_by_name`` mapping indexes the list next to it by element name.
    """

    name: str
    # Scenario weight; normalized across scenarios for stochastic instances.
    probability: float
    buses_by_name: Dict[str, Bus]
    buses: List[Bus]
    contingencies_by_name: Dict[str, Contingency]
    contingencies: List[Contingency]
    lines_by_name: Dict[str, TransmissionLine]
    lines: List[TransmissionLine]
    # Per-step penalty ($/MW) for power-balance violation (default 1000).
    power_balance_penalty: Series
    price_sensitive_loads_by_name: Dict[str, PriceSensitiveLoad]
    price_sensitive_loads: List[PriceSensitiveLoad]
    reserves: List[Reserve]
    reserves_by_name: Dict[str, Reserve]
    # Number of time steps T.
    time: int
    # Length of one time step in minutes.
    time_step: int
    thermal_units_by_name: Dict[str, ThermalUnit]
    thermal_units: List[ThermalUnit]
    profiled_units_by_name: Dict[str, ProfiledUnit]
    profiled_units: List[ProfiledUnit]
    storage_units_by_name: Dict[str, StorageUnit]
    storage_units: List[StorageUnit]
    # Shift-factor matrices; the JSON loader in this notebook stores all-zero
    # sparse placeholders rather than computed values.
    isf: sparse.spmatrix
    lodf: sparse.spmatrix
    source: Optional[str] = None  # citation / paper reference


@dataclass
class UnitCommitmentInstance:
    """A (possibly stochastic) instance: a shared horizon and its scenarios."""

    # Number of time steps, taken from the first scenario.
    time: int
    scenarios: List[UnitCommitmentScenario]

    # convenient alias
    @property
    def deterministic(self) -> UnitCommitmentScenario:
        """The only scenario; raises ValueError unless there is exactly one."""
        if len(self.scenarios) != 1:
            raise ValueError("Instance is stochastic; pick a scenario explicitly")
        return self.scenarios[0]


# --------------------------------------------------------------------------- #
#  JSON ➜ objects                                                             #
# --------------------------------------------------------------------------- #


def _scalar(val, default=None):
    """Replicates Julia's scalar(x; default) helper."""
    return default if val is None else val


def _timeseries(val, T: int, *, default=None):
    """
    Julia behaviour:
      * if val is missing ➜ default
      * if val is array  ➜ keep
      * if val is scalar ➜ replicate T times
    """
    if val is None:
        return default if default is not None else [None] * T
    return val if isinstance(val, list) else [val] * T


def _parse_version(v):
    """Return (major, minor) tuple; treat malformed strings as (0, 0)."""
    try:
        return tuple(int(x) for x in str(v).split(".")[:2])
    except Exception:
        return (0, 0)


def _migrate(json_: dict) -> None:
    """
    Bring legacy (< 0.4) files up to date:

        * v0.2 → v0.3:  restructure reserves & generator flags
        * v0.3 → v0.4:  ensure every generator has `"Type": "Thermal"`
    """
    params = json_.get("Parameters", {})
    ver_raw = params.get("Version")
    if ver_raw is None:
        raise ValueError(
            "Input file has no Parameters['Version'] entry – please add it "
            '(e.g. {"Parameters": {"Version": "0.3"}}).'
        )

    ver = _parse_version(ver_raw)
    if ver < (0, 3):
        _migrate_to_v03(json_)
    if ver < (0, 4):
        _migrate_to_v04(json_)


def _migrate_to_v03(json_: dict) -> None:
    """Match Julia’s _migrate_to_v03: create r1 spinning reserve, map flags."""
    reserves = json_.get("Reserves")
    if reserves and "Spinning (MW)" in reserves:
        amount = reserves["Spinning (MW)"]
        # Replace the old flat field with the new nested structure
        json_["Reserves"] = {
            "r1": {
                "Type": "spinning",
                "Amount (MW)": amount,
            }
        }
        # Any generator that set the legacy boolean now becomes eligible for r1
        for gen in json_.get("Generators", {}).values():
            if gen.get("Provides spinning reserves?") is True:
                gen["Reserve eligibility"] = ["r1"]


def _migrate_to_v04(json_: dict) -> None:
    """Match Julia’s _migrate_to_v04: default missing types to Thermal."""
    for gen in json_.get("Generators", {}).values():
        gen.setdefault("Type", "Thermal")


def _from_json(j: dict) -> UnitCommitmentScenario:
    # -- Time grid ---------------------------------------------------------- #
    par = j["Parameters"]
    time_horizon = (
        par.get("Time horizon (min)")
        or par.get("Time (h)")
        or par.get("Time horizon (h)")
    )
    if time_horizon is None:
        raise ValueError("Missing parameter: Time horizon")
    if "Time (h)" in par or "Time horizon (h)" in par:
        time_horizon *= 60  # convert hours → minutes

    time_horizon = int(time_horizon)
    time_step = int(_scalar(par.get("Time step (min)"), default=60))
    if 60 % time_step or time_horizon % time_step:
        raise ValueError("Time step must divide 60 and the horizon")

    time_multiplier = 60 // time_step
    T = time_horizon // time_step

    # ---------------------------------------------------------------------- #
    #  Look-up tables                                                        #
    # ---------------------------------------------------------------------- #
    buses: List[Bus] = []
    lines: List[TransmissionLine] = []
    thermal_units: List[ThermalUnit] = []
    profiled_units: List[ProfiledUnit] = []
    storage_units: List[StorageUnit] = []
    reserves: List[Reserve] = []
    contingencies: List[Contingency] = []
    loads: List[PriceSensitiveLoad] = []

    name_to_bus, name_to_line, name_to_unit, name_to_reserve = ({}, {}, {}, {})

    # ---------------------------------------------------------------------- #
    #  Helper to make sure each list has length T                            #
    # ---------------------------------------------------------------------- #

    def ts(x, *, default=None):
        return _timeseries(x, T, default=default)

    # ---------------------------------------------------------------------- #
    #  Penalties                                                             #
    # ---------------------------------------------------------------------- #
    power_balance_penalty = ts(
        par.get("Power balance penalty ($/MW)"), default=[1000.0] * T
    )

    # ---------------------------------------------------------------------- #
    #  Buses                                                                 #
    # ---------------------------------------------------------------------- #
    for idx, (bname, bdict) in enumerate(j["Buses"].items(), start=1):
        bus = Bus(
            name=bname,
            index=idx,
            load=ts(bdict["Load (MW)"]),
        )
        name_to_bus[bname] = bus
        buses.append(bus)

    # ---------------------------------------------------------------------- #
    #  Reserves                                                              #
    # ---------------------------------------------------------------------- #
    if "Reserves" in j:
        for rname, rdict in j["Reserves"].items():
            r = Reserve(
                name=rname,
                type=rdict["Type"].lower(),
                amount=ts(rdict["Amount (MW)"]),
                thermal_units=[],
                shortfall_penalty=_scalar(
                    rdict.get("Shortfall penalty ($/MW)"), default=10
                ),
            )
            name_to_reserve[rname] = r
            reserves.append(r)

    # ---------------------------------------------------------------------- #
    #  Generators                                                            #
    # ---------------------------------------------------------------------- #
    for uname, udict in j["Generators"].items():
        utype = udict.get("Type")
        if not utype:
            raise ValueError(f"Generator {uname} missing Type")
        bus = name_to_bus[udict["Bus"]]

        if utype.lower() == "thermal":
            # Production cost curve
            curve_mw = udict["Production cost curve (MW)"]
            curve_cost = udict["Production cost curve ($)"]
            K = len(curve_mw)
            curve_mw = np.column_stack([ts(curve_mw[k]) for k in range(K)])
            curve_cost = np.column_stack([ts(curve_cost[k]) for k in range(K)])

            min_power = curve_mw[:, 0].tolist()
            max_power = curve_mw[:, -1].tolist()
            min_power_cost = curve_cost[:, 0].tolist()

            segments = []
            for k in range(1, K):
                amount = (curve_mw[:, k] - curve_mw[:, k - 1]).tolist()
                cost = (
                    (curve_cost[:, k] - curve_cost[:, k - 1])
                    / (np.maximum(amount, 1e-9))
                ).tolist()
                segments.append(CostSegment(amount, cost))

            # Startup categories
            delays = _scalar(udict.get("Startup delays (h)"), default=[1])
            scost = _scalar(udict.get("Startup costs ($)"), default=[0.0])
            startup_categories = [
                StartupCategory(int(delays[k] * time_multiplier), scost[k])
                for k in range(len(delays))
            ]

            # Reserve eligibility
            unit_reserves = [
                name_to_reserve[n] for n in udict.get("Reserve eligibility", [])
            ]
            # Initial conditions
            init_p = udict.get("Initial power (MW)")
            init_s = udict.get("Initial status (h)")
            if init_p is None:
                init_s = None
            elif init_s is None:
                raise ValueError(f"{uname} has power but no status")
            else:
                init_s = int(init_s * time_multiplier)

            commitment_status = _scalar(
                udict.get("Commitment status"), default=[None] * T
            )

            tu = ThermalUnit(
                name=uname,
                bus=bus,
                max_power=max_power,
                min_power=min_power,
                must_run=ts(udict.get("Must run?"), default=[False] * T),
                min_power_cost=min_power_cost,
                segments=segments,
                min_up=int(
                    _scalar(udict.get("Minimum uptime (h)"), 1) * time_multiplier
                ),
                min_down=int(
                    _scalar(udict.get("Minimum downtime (h)"), 1) * time_multiplier
                ),
                ramp_up=_scalar(udict.get("Ramp up limit (MW)"), 1e6),
                ramp_down=_scalar(udict.get("Ramp down limit (MW)"), 1e6),
                startup_limit=_scalar(udict.get("Startup limit (MW)"), 1e6),
                shutdown_limit=_scalar(udict.get("Shutdown limit (MW)"), 1e6),
                initial_status=init_s,
                initial_power=init_p,
                startup_categories=startup_categories,
                reserves=unit_reserves,
                commitment_status=commitment_status,
            )
            bus.thermal_units.append(tu)
            thermal_units.append(tu)
            name_to_unit[uname] = tu
            for r in unit_reserves:
                r.thermal_units.append(tu)

        elif utype.lower() == "profiled":
            pu = ProfiledUnit(
                name=uname,
                bus=bus,
                min_power=ts(_scalar(udict.get("Minimum power (MW)"), 0.0)),
                max_power=ts(udict["Maximum power (MW)"]),
                cost=ts(udict["Cost ($/MW)"]),
            )
            bus.profiled_units.append(pu)
            profiled_units.append(pu)
        else:
            raise ValueError(f"Unit {uname} has invalid type '{utype}'")

    # ---------------------------------------------------------------------- #
    #  Lines                                                                 #
    # ---------------------------------------------------------------------- #
    if "Transmission lines" in j:
        for idx, (lname, ldict) in enumerate(j["Transmission lines"].items(), start=1):
            line = TransmissionLine(
                name=lname,
                index=idx,
                source=name_to_bus[ldict["Source bus"]],
                target=name_to_bus[ldict["Target bus"]],
                susceptance=float(ldict["Susceptance (S)"]),
                normal_limit=ts(ldict.get("Normal flow limit (MW)"), default=[1e8] * T),
                emergency_limit=ts(
                    ldict.get("Emergency flow limit (MW)"), default=[1e8] * T
                ),
                flow_penalty=ts(
                    ldict.get("Flow limit penalty ($/MW)"), default=[5000.0] * T
                ),
            )
            lines.append(line)
            name_to_line[lname] = line

    # ---------------------------------------------------------------------- #
    #  Contingencies                                                         #
    # ---------------------------------------------------------------------- #
    if "Contingencies" in j:
        for cname, cdict in j["Contingencies"].items():
            affected_lines = [name_to_line[l] for l in cdict.get("Affected lines", [])]
            affected_units = [name_to_unit[u] for u in cdict.get("Affected units", [])]
            contingencies.append(
                Contingency(name=cname, lines=affected_lines, units=affected_units)
            )

    # ---------------------------------------------------------------------- #
    #  Price-sensitive loads                                                 #
    # ---------------------------------------------------------------------- #
    if "Price-sensitive loads" in j:
        for lname, ldict in j["Price-sensitive loads"].items():
            load = PriceSensitiveLoad(
                name=lname,
                bus=name_to_bus[ldict["Bus"]],
                demand=ts(ldict["Demand (MW)"]),
                revenue=ts(ldict["Revenue ($/MW)"]),
            )
            loads.append(load)
            load.bus.price_sensitive_loads.append(load)

    # ---------------------------------------------------------------------- #
    #  Storage units                                                         #
    # ---------------------------------------------------------------------- #
    if "Storage units" in j:
        for sname, sdict in j["Storage units"].items():
            bus = name_to_bus[sdict["Bus"]]
            min_level = ts(_scalar(sdict.get("Minimum level (MWh)"), 0.0))
            max_level = ts(sdict["Maximum level (MWh)"])
            su = StorageUnit(
                name=sname,
                bus=bus,
                min_level=min_level,
                max_level=max_level,
                simultaneous=ts(
                    _scalar(
                        sdict.get("Allow simultaneous charging and discharging"), True
                    )
                ),
                charge_cost=ts(sdict["Charge cost ($/MW)"]),
                discharge_cost=ts(sdict["Discharge cost ($/MW)"]),
                charge_eff=ts(_scalar(sdict.get("Charge efficiency"), 1.0)),
                discharge_eff=ts(_scalar(sdict.get("Discharge efficiency"), 1.0)),
                loss_factor=ts(_scalar(sdict.get("Loss factor"), 0.0)),
                min_charge=ts(_scalar(sdict.get("Minimum charge rate (MW)"), 0.0)),
                max_charge=ts(sdict["Maximum charge rate (MW)"]),
                min_discharge=ts(
                    _scalar(sdict.get("Minimum discharge rate (MW)"), 0.0)
                ),
                max_discharge=ts(sdict["Maximum discharge rate (MW)"]),
                initial_level=_scalar(sdict.get("Initial level (MWh)"), 0.0),
                last_min=_scalar(
                    sdict.get("Last period minimum level (MWh)"), min_level[-1]
                ),
                last_max=_scalar(
                    sdict.get("Last period maximum level (MWh)"), max_level[-1]
                ),
            )
            storage_units.append(su)
            bus.storage_units.append(su)

    # ---------------------------------------------------------------------- #
    #  Sparse matrices (zeros – replication of spzeros(Float64, …) )         #
    # ---------------------------------------------------------------------- #
    isf = sparse.csr_matrix((len(lines), len(buses) - 1), dtype=float)
    lodf = sparse.csr_matrix((len(lines), len(lines)), dtype=float)

    scenario = UnitCommitmentScenario(
        name=_scalar(par.get("Scenario name"), ""),
        probability=float(_scalar(par.get("Scenario weight"), 1)),
        buses_by_name={b.name: b for b in buses},
        buses=buses,
        contingencies_by_name={c.name: c for c in contingencies},
        contingencies=contingencies,
        lines_by_name={l.name: l for l in lines},
        lines=lines,
        power_balance_penalty=power_balance_penalty,
        price_sensitive_loads_by_name={pl.name: pl for pl in loads},
        price_sensitive_loads=loads,
        reserves=reserves,
        reserves_by_name=name_to_reserve,
        time=T,
        time_step=time_step,
        thermal_units_by_name={tu.name: tu for tu in thermal_units},
        thermal_units=thermal_units,
        profiled_units_by_name={pu.name: pu for pu in profiled_units},
        profiled_units=profiled_units,
        storage_units_by_name={su.name: su for su in storage_units},
        storage_units=storage_units,
        isf=isf,
        lodf=lodf,
        source=j["Parameters"].get("SOURCE"),
    )

    _repair(scenario)  # replicate Julia's repair! in a minimal way
    return scenario


# --------------------------------------------------------------------------- #
#  Basic "repair!"                                                           #
# --------------------------------------------------------------------------- #


def _repair(scenario: UnitCommitmentScenario) -> None:
    """
    Julia's repair! performs several tasks:
      • fills commitment_status for must-run units
      • clamps initial conditions
      • builds ISF/LODF if missing
    Here we implement minimal sanity checks.
    """
    for tu in scenario.thermal_units:
        # ensure commitment_status consistent with must_run
        for t, mr in enumerate(tu.must_run):
            if mr is True:
                tu.commitment_status[t] = True


In [3]:
# --------------------------------------------------------------------------- #
#  Quick self-test                                                            #
# --------------------------------------------------------------------------- #

if __name__ == "__main__":
    # Smoke test: fetch a small benchmark and print a one-line summary.
    # `SAMPLE`, `inst` and `sc` stay as module globals for later cells.
    SAMPLE = "matpower/case57/2017-01-01"
    print(f"→ Loading sample instance '{SAMPLE}' …")
    inst = read_benchmark(SAMPLE, quiet=False)
    sc = inst.deterministic
    print(
        "Loaded scenario '{}' with {} thermal units, {} lines, "
        "horizon {} steps of {} minutes.".format(
            sc.name, len(sc.thermal_units), len(sc.lines), sc.time, sc.time_step
        )
    )

→ Loading sample instance 'matpower/case57/2017-01-01' …
Loaded scenario 's1' with 7 thermal units, 80 lines, horizon 36 steps of 60 minutes.


# Gurobi model for security-constrained unit commitment (SCUC)

In [4]:
def _nested():
    """Create a fresh, empty ``gp.tupledict`` to hold one family of model objects."""
    container = gp.tupledict({})
    return container


class _VarKeeper(dict):
    """
    A ``dict`` whose missing keys are populated on first read.

    Reading an absent key stores and returns a new empty tupledict (from
    ``_nested``). Note that a mistyped key therefore yields an empty container
    instead of raising ``KeyError``; ``in`` and ``.get`` do not auto-create.
    """

    def __missing__(self, key):
        # setdefault inserts without routing through __getitem__/__missing__.
        return self.setdefault(key, _nested())


def _apply_warm_start(v: dict, warm_start: dict) -> None:
    """
    Attach Gurobi variable hints (``VarHintVal``/``VarHintPri``) from *warm_start*.

    *warm_start* maps a variable-family name used by build_model (e.g.
    ``"is_on"``) to ``{key: value}``. Keys that do not resolve to a
    ``gp.Var`` are skipped: missing keys, or expression entries such as
    ``"expr_net_injection"``.

    Raises
    ------
    KeyError
        If a family name is not one that build_model creates.
    """
    for var_name, values in warm_start.items():
        if var_name not in v:
            raise KeyError(f"warm_start refers to unknown variable family {var_name!r}")
        family = v[var_name]
        for key, val in values.items():
            var = family.get(key)  # .get() does not trigger _VarKeeper auto-creation
            if isinstance(var, gp.Var):
                var.VarHintVal = val
                var.VarHintPri = 1  # Priority for warm start


def build_model(
    instance: UnitCommitmentInstance,
    *,
    name: str = "SCUC-Python",
    mip_gap: float | None = 0.001,
    threads: int | None = None,
    verbose: bool = True,
    warm_start: Optional[dict] = None,  # Dictionary for warm start values
    pruned_lines: Optional[list] = None,  # List of line names to drop constraints for
) -> gp.Model:
    """
    Translate *instance* into a Gurobi MIP with warm start and optional line pruning.

    Commitment variables come from the first scenario's thermal units. Every
    scenario then contributes its own dispatch, network, contingency and
    system-wide constraints, with objective terms weighted by its probability.

    Arguments
    ---------
    mip_gap       – Relative MIP gap (None ⇒ Gurobi default)
    threads       – Thread count (None ⇒ default)
    verbose       – If False, suppress Gurobi output
    warm_start    – Dict mapping variable names to initial values (e.g., {"is_on": {("g1", 1): 1}}).
                    Values are applied as Gurobi variable *hints*
                    (VarHintVal, priority 1), not as a MIP start. They are
                    applied after every variable has been created, so hints
                    for scenario-level variables are honoured too.
    pruned_lines  – List of line names whose flow constraints should be dropped

    Returns
    -------
    gp.Model
        Updated model; variables are in ``m._vars``, the instance in
        ``m._instance``.
    """
    T = instance.time
    m = gp.Model(name)
    m._instance = instance
    if not verbose:
        m.setParam("OutputFlag", 0)
    if mip_gap is not None:
        m.setParam("MIPGap", mip_gap)
    if threads is not None:
        m.setParam("Threads", threads)

    # Variable storage
    v = {
        n: _VarKeeper()
        for n in (
            "is_on",
            "switch_on",
            "switch_off",
            "prod_above",
            "startup",
            "reserve",
            "reserve_shortfall",
            "upflexiramp",
            "dwflexiramp",
            "upflexiramp_shortfall",
            "dwflexiramp_shortfall",
            "expr_net_injection",
            "net_injection",
            "curtail",
            "overflow",
            "cont_overflow",
            "prod_profiled",
            "storage_level",
            "charge_rate",
            "discharge_rate",
            "is_charging",
            "is_discharging",
            "bal_pos",
            "bal_neg",
            "flow",
        )
    }
    m._vars = v
    con = {}
    obj = gp.LinExpr()

    # First-stage commitment variables (shared by all scenarios)
    for g in instance.scenarios[0].thermal_units:
        for t in range(1, T + 1):
            k = (g.name, t)
            v["is_on"][k] = m.addVar(vtype=GRB.BINARY, name=f"is_on[{g.name},{t}]")
            v["switch_on"][k] = m.addVar(vtype=GRB.BINARY, name=f"su[{g.name},{t}]")
            v["switch_off"][k] = m.addVar(vtype=GRB.BINARY, name=f"sd[{g.name},{t}]")
            v["prod_above"][k] = m.addVar(lb=0, name=f"g+[{g.name},{t}]")
        for t in range(1, T + 1):
            for s, cat in enumerate(g.startup_categories, start=1):
                v["startup"][(g.name, t, s)] = m.addVar(
                    vtype=GRB.BINARY, name=f"startup[{g.name},{t},{s}]"
                )
        _add_unit_commitment_eqs(m, g, v, con, obj.add)

    # Scenario-specific constraints
    for sc in instance.scenarios:
        pname = sc.name
        for line in sc.lines:
            for t in range(1, T + 1):
                key = (pname, line.name, t)
                v["overflow"][key] = m.addVar(
                    lb=0, name=f"ovfl[{line.name},{pname},{t}]"
                )
                obj.add(line.flow_penalty[t - 1] * sc.probability * v["overflow"][key])
        for bus in sc.buses:
            for t in range(1, T + 1):
                v["expr_net_injection"][(pname, bus.name, t)] = gp.LinExpr(
                    -bus.load[t - 1]
                )
                v["curtail"][(pname, bus.name, t)] = m.addVar(
                    lb=0, ub=bus.load[t - 1], name=f"curt[{bus.name},{pname},{t}]"
                )
                obj.add(
                    sc.power_balance_penalty[t - 1]
                    * sc.probability
                    * v["curtail"][(pname, bus.name, t)]
                )
                v["expr_net_injection"][(pname, bus.name, t)].addTerms(
                    1.0, v["curtail"][(pname, bus.name, t)]
                )
        for g in sc.thermal_units:
            _add_unit_dispatch_eqs(m, g, sc, v, con, obj.add)
        for pu in sc.profiled_units:
            _add_profiled_unit(m, pu, sc, v, obj.add)
        for ps in sc.price_sensitive_loads:
            _add_price_sensitive_load(m, ps, sc, v, obj.add)
        for su in sc.storage_units:
            _add_storage(m, su, sc, v, con, obj.add)
        _add_transmission_eqs(
            m, sc, v, con, pruned_lines=pruned_lines
        )  # Pass pruned_lines
        _add_contingency_eqs(m, sc, v, con, obj.add)
        _add_system_wide_eqs(m, sc, v, con, obj.add)

    # Hints are applied only now, once every variable family is populated;
    # applying them earlier silently dropped all scenario-level hints.
    if warm_start:
        _apply_warm_start(v, warm_start)

    m.setObjective(obj, GRB.MINIMIZE)
    m.update()
    return m


def _add_reserve_vars_and_eqs(m, g, sc, v, add):
    """
    Add reserve variables and constraints for thermal unit *g* in scenario *sc*.

    Spinning reserves
        One variable per (scenario, reserve, unit, period). In each period the
        unit's *total* spinning reserve must fit in its headroom,
        ``prod_above + Σ reserve <= (max_power - min_power) * is_on``, and
        within its ramp-up rate, ``Σ reserve <= ramp_up * is_on``. Summing
        avoids counting the same headroom once per reserve product.
    Flexiramp reserves
        Up/down variables are created; no constraints are added here.
        NOTE(review): presumably flexiramp constraints are added elsewhere — confirm.

    Shortfall variables are shared by every unit of a reserve. Only the first
    unit that reaches a given (scenario, reserve, period) key creates them.

    Follows UnitCommitment.jl/src/formulation/reserves.jl (Carrion-Arroyo formulation).
    *add* (objective accumulator) is accepted for interface compatibility and
    is not used here.
    """
    pname = sc.name
    T = m._instance.time
    reserve = v["reserve"]
    reserve_sf = v["reserve_shortfall"]
    upfr, dwfr = v["upflexiramp"], v["dwflexiramp"]
    upfr_sf, dwfr_sf = v["upflexiramp_shortfall"], v["dwflexiramp_shortfall"]
    is_on = v["is_on"]
    prod_above = v["prod_above"]

    def _shared_var(store, key):
        # `store.setdefault(key, m.addVar(...))` would create a new (orphan)
        # Gurobi variable on every call; only create one when the key is new.
        if key not in store:
            store[key] = m.addVar(lb=0)

    # This unit's spinning-reserve variables, grouped by period.
    spinning_by_t = {t: [] for t in range(1, T + 1)}

    for r in g.reserves:
        if r.type == "spinning":
            for t in range(1, T + 1):
                key = (pname, r.name, g.name, t)
                reserve[key] = m.addVar(
                    lb=0, name=f"Rspin[{g.name},{r.name},{pname},{t}]"
                )
                spinning_by_t[t].append(reserve[key])
                _shared_var(reserve_sf, (pname, r.name, t))

        elif r.type == "flexiramp":
            for t in range(1, T + 1):
                upfr[(pname, r.name, g.name, t)] = m.addVar(lb=0)
                dwfr[(pname, r.name, g.name, t)] = m.addVar(lb=0)
                _shared_var(upfr_sf, (pname, r.name, t))
                _shared_var(dwfr_sf, (pname, r.name, t))

    for t, terms in spinning_by_t.items():
        if not terms:
            continue
        total = gp.quicksum(terms)
        on = is_on[(g.name, t)]

        # Headroom: production above min plus all spinning reserve <= range
        m.addConstr(
            prod_above[(g.name, t)] + total
            <= (g.max_power[t - 1] - g.min_power[t - 1]) * on,
            name=f"RspinCap[{g.name},{pname},{t}]",
        )

        # Total spinning reserve limited by ramp-up rate
        m.addConstr(
            total <= g.ramp_up * on,
            name=f"RspinRamp[{g.name},{pname},{t}]",
        )


# --------------------------------------------------------------------------- #
# -------------  E Q U A T I O N   B U I L D E R S  ------------------------- #
# --------------------------------------------------------------------------- #


def _add_unit_commitment_eqs(
    m: gp.Model,
    g: ThermalUnit,
    v: Dict[str, TD],
    con: Dict[str, TD],
    add,
):
    """
    Add the binary commitment logic of thermal unit ``g`` to model ``m``.

    For every period t = 1..T (T = ``m._instance.time``) this creates:
      * ``logic[g,t]``: is_on[t] - is_on[t-1] - switch_on[t] + switch_off[t] == 0,
        where is_on[0] is ``int(g.initial_status > 0)``;
      * ``minUP[g,t]``: start-ups in the window of the last ``g.min_up`` periods
        (clipped at period 1) <= is_on[t];
      * ``minDN[g,t]``: shut-downs in the window of the last ``g.min_down``
        periods (clipped at period 1) <= 1 - is_on[t];
      * ``gmax[g,t]``: prod_above[t] <= (max_power - min_power)[t] * is_on[t].

    Each startup category contributes ``cat.cost * startup[g,t,s]`` to the
    objective through ``add``. ``con`` is accepted for signature compatibility
    and is not used here.
    """
    horizon = m._instance.time
    unit = g.name
    on = v["is_on"]
    start = v["switch_on"]
    stop = v["switch_off"]

    # Commitment state transition, written as a single "== 0" row.
    for t in range(1, horizon + 1):
        transition = gp.LinExpr(on[(unit, t)])
        if t == 1:
            transition.addConstant(-int(g.initial_status > 0))
        else:
            transition.add(on[(unit, t - 1)], -1.0)
        transition.add(start[(unit, t)], -1.0)
        transition.add(stop[(unit, t)], 1.0)
        m.addConstr(transition == 0, name=f"logic[{unit},{t}]")

    # Objective contribution of every startup category.
    for t in range(1, horizon + 1):
        for s, category in enumerate(g.startup_categories, start=1):
            add(category.cost * v["startup"][(unit, t, s)])

    # Minimum up- and down-time windows.
    for t in range(1, horizon + 1):
        up_window = range(max(1, t - g.min_up + 1), t + 1)
        m.addConstr(
            gp.quicksum(start[(unit, k)] for k in up_window) <= on[(unit, t)],
            name=f"minUP[{unit},{t}]",
        )
        down_window = range(max(1, t - g.min_down + 1), t + 1)
        m.addConstr(
            gp.quicksum(stop[(unit, k)] for k in down_window) <= 1 - on[(unit, t)],
            name=f"minDN[{unit},{t}]",
        )

    # Output above the minimum is only possible while the unit is committed.
    for t in range(1, horizon + 1):
        headroom = g.max_power[t - 1] - g.min_power[t - 1]
        m.addConstr(
            v["prod_above"][(unit, t)] <= headroom * on[(unit, t)],
            name=f"gmax[{unit},{t}]",
        )


def _add_unit_dispatch_eqs(m, g, sc, v, con, add):
    """
    Add the continuous dispatch model of thermal unit ``g`` for scenario ``sc``.

    The work happens in three stages, in this order:
      1. Piece-wise cost. For each period t and cost segment s, a variable
         ``seg[g,t,s]`` in [0, seg.amount[t]] is created. It is costed at
         ``sc.probability * seg.cost[t]`` via ``add`` and forced to zero while
         the unit is off (``segON``). The segments sum to ``prod_above``
         (``gDef``). The bus net-injection expression then receives
         prod_above + min_power[t] * is_on.
      2. Ramp limits on ``prod_above`` between consecutive periods (``RU`` and
         ``RD``). Period 1 is compared with ``g.initial_power - g.min_power[0]``.
      3. Reserve and flexiramp variables, via ``_add_reserve_vars_and_eqs``.

    ``con`` is accepted for signature compatibility and is not used.
    """
    horizon = m._instance.time
    unit = g.name
    scen = sc.name
    on = v["is_on"]
    above = v["prod_above"]
    injection = v["expr_net_injection"]

    # 1) Cost segments, their link to prod_above, and the bus injection.
    for t in range(1, horizon + 1):
        parts = []
        for s, segment in enumerate(g.segments, start=1):
            cap = segment.amount[t - 1]
            piece = m.addVar(lb=0, ub=cap, name=f"seg[{unit},{t},{s}]")
            parts.append(piece)
            add(sc.probability * segment.cost[t - 1] * piece)
            m.addConstr(piece <= cap * on[(unit, t)], name=f"segON[{unit},{t},{s}]")
        m.addConstr(above[(unit, t)] == gp.quicksum(parts), name=f"gDef[{unit},{t}]")

        bus_expr = injection[(scen, g.bus.name, t)]
        bus_expr.add(above[(unit, t)])
        bus_expr.add(on[(unit, t)], g.min_power[t - 1])

    # 2) Ramp-up / ramp-down limits on the above-minimum output.
    for t in range(1, horizon + 1):
        if t == 1:
            before = g.initial_power - g.min_power[0]
        else:
            before = above[(unit, t - 1)]
        now = above[(unit, t)]
        m.addConstr(now - before <= g.ramp_up, name=f"RU[{unit},{t}]")
        m.addConstr(before - now <= g.ramp_down, name=f"RD[{unit},{t}]")

    # 3) Reserve products offered by this unit.
    _add_reserve_vars_and_eqs(m, g, sc, v, add)


def _add_reserve_vars_and_eqs(m, g, sc, v, add):
    pname = sc.name
    T = m._instance.time
    reserve = v["reserve"]
    reserve_sf = v["reserve_shortfall"]
    upfr, dwfr = v["upflexiramp"], v["dwflexiramp"]
    upfr_sf, dwfr_sf = v["upflexiramp_shortfall"], v["dwflexiramp_shortfall"]

    for r in g.reserves:
        if r.type == "spinning":
            for t in range(1, T + 1):
                reserve[(pname, r.name, g.name, t)] = m.addVar(
                    lb=0, name=f"Rspin[{g.name},{r.name},{pname},{t}]"
                )
                reserve_sf.setdefault((pname, r.name, t), m.addVar(lb=0))
        elif r.type == "flexiramp":
            for t in range(1, T + 1):
                upfr[(pname, r.name, g.name, t)] = m.addVar(lb=0)
                dwfr[(pname, r.name, g.name, t)] = m.addVar(lb=0)
                upfr_sf.setdefault((pname, r.name, t), m.addVar(lb=0))
                dwfr_sf.setdefault((pname, r.name, t), m.addVar(lb=0))

    # Ramp-up capacity must cover spinning reserve + prod increase
    # (simple Carrion–Arroyo type) – omitted here for brevity.


def _add_profiled_unit(m, pu: ProfiledUnit, sc, v, add):
    pname = sc.name
    T = m._instance.time
    for t in range(1, T + 1):
        key = (pname, pu.name, t)
        v["prod_profiled"][key] = m.addVar(
            lb=pu.min_power[t - 1], ub=pu.max_power[t - 1], name=f"PU[{pu.name},{t}]"
        )
        add(sc.probability * pu.cost[t - 1] * v["prod_profiled"][key])
        v["expr_net_injection"][(pname, pu.bus.name, t)].add(v["prod_profiled"][key])


def _add_price_sensitive_load(m, ps: PriceSensitiveLoad, sc, v, add):
    pname = sc.name
    T = m._instance.time
    for t in range(1, T + 1):
        k = (pname, ps.name, t)
        var = m.addVar(lb=0, ub=ps.demand[t - 1], name=f"PSL[{ps.name},{t}]")
        add(-sc.probability * ps.revenue[t - 1] * var)
        v["expr_net_injection"][(pname, ps.bus.name, t)].add(var, -1.0)


def _add_transmission_eqs(m, sc, v, con, pruned_lines=None):
    """
    ISF/PTDF line-flow constraints with overflow slack and optional pruning.

    For every line of ``sc`` whose name is not in ``pruned_lines``, and every
    period t = 1..T, this adds:
      * ``FlowDef``: a free variable ``flow[l,t] == sum_b isf[l, b] * NI[b, t]``,
        where ISF column b corresponds to ``sc.buses[b + 1]`` (the first bus has
        no column) and NI is ``v["net_injection"]``;
      * ``PF+`` / ``PF-``: ``±flow - overflow <= line.normal_limit[t]``.

    Flow variables are stored in ``v["flow"]``. Overflow variables are read
    from ``v["overflow"]`` and must already exist; this function adds no
    objective terms. Nothing is added when the ISF matrix is empty (single-bus
    case). ``con`` is unused.

    Args:
        pruned_lines: optional iterable of line names to skip. It is converted
            to a set once, so each membership check is O(1).
    """
    pname = sc.name
    T = m._instance.time
    isf = sc.isf
    if isf.size == 0:  # Single-bus case
        return

    skipped = set(pruned_lines) if pruned_lines else set()
    non_slack_buses = sc.buses[1:]  # ISF has no column for the first (slack) bus
    net_inj = v["net_injection"]

    for l_idx, line in enumerate(sc.lines):
        if line.name in skipped:  # Skip pruned lines
            continue
        for t in range(1, T + 1):
            flow_expr = gp.LinExpr()
            for b_idx, bus in enumerate(non_slack_buses):
                flow_expr.add(net_inj[(pname, bus.name, t)], isf[l_idx, b_idx])
            flow_key = (pname, line.name, t)
            v["flow"][flow_key] = m.addVar(
                lb=-GRB.INFINITY, name=f"flow[{line.name},{pname},{t}]"
            )
            m.addConstr(
                v["flow"][flow_key] == flow_expr,
                name=f"FlowDef[{line.name},{pname},{t}]",
            )
            m.addConstr(
                v["flow"][flow_key] - v["overflow"][(pname, line.name, t)]
                <= line.normal_limit[t - 1],
                name=f"PF+[{line.name},{pname},{t}]",
            )
            m.addConstr(
                -v["flow"][flow_key] - v["overflow"][(pname, line.name, t)]
                <= line.normal_limit[t - 1],
                name=f"PF-[{line.name},{pname},{t}]",
            )


def _add_contingency_eqs(m, sc, v, con, add):
    """
    Post-contingency line-flow limits built with LODF superposition.

    For each contingency ``c``, each line ``l`` not in ``c.lines`` and each
    period t, the post-outage flow expression is

        isf[l, :] . NI[:, t] + sum_{o outaged, o != l} lodf[l, o] * isf[o, :] . NI[:, t]

    where NI are the ``v["net_injection"]`` variables of ``sc.buses[1:]``
    (ISF column b maps to ``sc.buses[b + 1]``). A slack ``cont_overflow >= 0``
    relaxes ``line.emergency_limit[t]`` in both directions. It is penalised at
    ``line.flow_penalty[t] * sc.probability`` via ``add``.

    Outaged lines are located by ``line.index - 1``. This presumably assumes
    1-based indices that follow the order of ``sc.lines`` (TODO confirm).
    Nothing is added when the ISF is empty or there are no contingencies.
    ``con`` is unused.
    """
    scen = sc.name
    horizon = m._instance.time
    isf, lodf = sc.isf, sc.lodf
    if isf.size == 0 or len(sc.contingencies) == 0:
        return

    ni = v["net_injection"]
    buses = sc.buses[1:]  # the ISF has no column for the first bus

    for cont in sc.contingencies:
        outaged = [ln.index - 1 for ln in cont.lines]
        if not outaged:
            continue
        for l_idx, line in enumerate(sc.lines):
            if line in cont.lines:
                continue
            for t in range(1, horizon + 1):
                post_flow = gp.LinExpr()
                # Pre-contingency flow on the monitored line.
                for b_idx, bus in enumerate(buses):
                    post_flow.add(ni[(scen, bus.name, t)], isf[l_idx, b_idx])
                # Flow redistributed from each outaged line.
                for o_idx in outaged:
                    if o_idx == l_idx:
                        continue
                    for b_idx, bus in enumerate(buses):
                        post_flow.add(
                            ni[(scen, bus.name, t)],
                            isf[o_idx, b_idx] * lodf[l_idx, o_idx],
                        )

                slack = m.addVar(
                    lb=0, name=f"cont_ovfl[{cont.name},{line.name},{scen},{t}]"
                )
                v["cont_overflow"][(scen, cont.name, line.name, t)] = slack
                add(line.flow_penalty[t - 1] * sc.probability * slack)

                limit = line.emergency_limit[t - 1]
                m.addConstr(
                    post_flow - slack <= limit,
                    name=f"ContPF+[{cont.name},{line.name},{scen},{t}]",
                )
                m.addConstr(
                    -post_flow - slack <= limit,
                    name=f"ContPF-[{cont.name},{line.name},{scen},{t}]",
                )


def _reserve_requirement(
    m,
    r: Reserve,
    res_vars: TD,
    sf_vars: TD,
    sc,
    T,
    add,
    *,
    paired: TD | None = None,
    paired_sf: TD | None = None,
):
    """
    Add the system-wide requirement of reserve ``r`` for scenario ``sc``.

    For each period t = 1..T this adds ``ResReq[r,t]``:

        sum_{g in r.thermal_units} res_vars[sc, r, g, t] + sf_vars[sc, r, t] >= r.amount[t]

    The shortfall is penalised at ``sc.probability * r.shortfall_penalty``
    via ``add``.

    ``paired`` and ``paired_sf`` describe the opposite direction of a two-sided
    product, e.g. down-flexiramp next to up-flexiramp. The paired direction
    gets its OWN requirement of the same amount (``ResReqPaired[r,t]``), with
    its own penalised shortfall. It is not summed into the primary one, because
    that would let down-ramp capability count towards an up-ramp requirement.
    If ``paired_sf`` is None, the paired requirement has no shortfall slack.
    """
    pname = sc.name
    for t in range(1, T + 1):
        amount = r.amount[t - 1]

        lhs = gp.quicksum(res_vars[(pname, r.name, g.name, t)] for g in r.thermal_units)
        sf = sf_vars[(pname, r.name, t)]
        lhs += sf
        m.addConstr(lhs >= amount, name=f"ResReq[{r.name},{t}]")
        add(sc.probability * r.shortfall_penalty * sf)

        if paired is None:
            continue
        paired_lhs = gp.quicksum(
            paired[(pname, r.name, g.name, t)] for g in r.thermal_units
        )
        if paired_sf is not None:
            psf = paired_sf[(pname, r.name, t)]
            paired_lhs += psf
            add(sc.probability * r.shortfall_penalty * psf)
        m.addConstr(paired_lhs >= amount, name=f"ResReqPaired[{r.name},{t}]")


def _add_system_wide_eqs(m, sc, v, con, add):
    """
    System-level equations for scenario ``sc``.

    1. For every bus and period, create a free variable ``NI[bus,t]`` in
       ``v["net_injection"]`` and tie it (``LinkNI``) to that bus's
       ``v["expr_net_injection"]`` expression. The transmission and
       contingency builders reference these variables.
    2. For every period, add a single system-wide balance (``SystemBalance``):
       sum over buses of NI[b,t] == P_surplus[t] - P_shortfall[t]. Both slacks
       are >= 0 and are penalised at ``sc.probability * 1e7`` each via ``add``.
    3. Add reserve requirements through ``_reserve_requirement``: spinning
       reserves directly, and flexiramp with up variables as primary and down
       variables passed as ``paired`` / ``paired_sf``.

    ``con`` is unused.
    """
    scen = sc.name
    horizon = m._instance.time
    imbalance_penalty = 1e7  # deliberately large cost on any system imbalance

    # 1) Net-injection variables, linked to the per-bus injection expressions.
    for bus in sc.buses:
        for t in range(1, horizon + 1):
            key = (scen, bus.name, t)
            ni = m.addVar(lb=-GRB.INFINITY, name=f"NI[{bus.name},{t}]")
            v["net_injection"][key] = ni
            m.addConstr(
                ni == v["expr_net_injection"][key],
                name=f"LinkNI[{bus.name},{t}]",
            )

    # 2) One balance row per period with penalised surplus/shortfall slacks.
    for t in range(1, horizon + 1):
        surplus = m.addVar(lb=0, name=f"P_surplus[{scen},{t}]")
        deficit = m.addVar(lb=0, name=f"P_shortfall[{scen},{t}]")
        add(sc.probability * imbalance_penalty * (surplus + deficit))
        total_injection = gp.quicksum(
            v["net_injection"][(scen, bus.name, t)] for bus in sc.buses
        )
        m.addConstr(
            total_injection == surplus - deficit,
            name=f"SystemBalance[{scen},{t}]",
        )

    # 3) Reserve requirements.
    for r in sc.reserves:
        if r.type == "spinning":
            _reserve_requirement(
                m, r, v["reserve"], v["reserve_shortfall"], sc, horizon, add
            )
        elif r.type == "flexiramp":
            _reserve_requirement(
                m,
                r,
                v["upflexiramp"],
                v["upflexiramp_shortfall"],
                sc,
                horizon,
                add,
                paired=v["dwflexiramp"],
                paired_sf=v["dwflexiramp_shortfall"],
            )


In [None]:
# --------------------------------------------------------------------------- #
#  Demo / self-test                                                           #
# --------------------------------------------------------------------------- #

if __name__ == "__main__":
    SAMPLE = "matpower/case14/2017-06-24"
    print(f"→ Loading sample instance '{SAMPLE}' …")
    inst = read_benchmark(SAMPLE, quiet=False)
    sc = inst.deterministic
    print(
        f"Loaded scenario '{sc.name}' with "
        f"{len(sc.thermal_units)} thermal units, "
        f"{len(sc.lines)} lines, horizon {sc.time} steps of "
        f"{sc.time_step} minutes."
    )

    # Example warm start: unit "g1" is on with 50 MW above its minimum in every period.
    # NOTE(review): this dict is not passed to build_model (warm_start={} below);
    # pass warm_start=warm_start if it is meant to be used.
    T = inst.time
    warm_start = {
        "is_on": {("g1", t): 1 for t in range(1, T + 1)},
        "prod_above": {("g1", t): 50 for t in range(1, T + 1)},
    }

    # Line names flagged with 1 in the pruning file are passed to build_model as
    # pruned_lines. A pathlib join replaces the former "warm_start\c..." literal,
    # which held an invalid escape sequence (SyntaxWarning) and a Windows-only
    # separator.
    pruning_file = Path("warm_start") / "constraints_line_case14_2017-06-24 (1).json"
    with open(pruning_file, "r") as f:
        data = json.load(f)
    pruned_lines = [line for line, val in data["line_pruning"].items() if val == 1]

    model = build_model(inst, warm_start={}, pruned_lines=pruned_lines, mip_gap=0.1, verbose=True)
    model.optimize()
    print(f"★ Optimal cost: {model.ObjVal:,.2f}")

→ Loading sample instance 'matpower/case14/2017-06-24' …
Loaded scenario 's1' with 5 thermal units, 20 lines, horizon 36 steps of 60 minutes.
Set parameter MIPGap to value 0.1


Gurobi Optimizer version 12.0.1 build v12.0.1rc0 (win64 - Windows 11.0 (22631.2))

CPU model: AMD Ryzen 7 7840HS with Radeon 780M Graphics, instruction set [SSE2|AVX|AVX2|AVX512]
Thread count: 8 physical cores, 16 logical processors, using up to 16 threads

Non-default parameters:
MIPGap  0.1

Optimize a model with 2556 rows, 4140 columns and 7209 nonzeros
Model fingerprint: 0x7d420925
Variable types: 3060 continuous, 1080 integer (1080 binary)
Coefficient statistics:
  Matrix range     [1e+00, 3e+02]
  Objective range  [1e+01, 1e+07]
  Bounds range     [1e+00, 1e+02]
  RHS range        [1e+00, 5e+02]
Found heuristic solution: objective 5.089946e+08
Presolve removed 998 rows and 2886 columns
Presolve time: 0.03s
Presolved: 1558 rows, 1254 columns, 4876 nonzeros
Variable types: 1008 continuous, 246 integer (246 binary)



  with open("warm_start\constraints_line_case14_2017-06-24 (1).json", "r") as f:


Root relaxation: objective 3.286001e+05, 395 iterations, 0.00 seconds (0.00 work units)

    Nodes    |    Current Node    |     Objective Bounds      |     Work
 Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time

*    0     0               0    328600.08720 328600.087  0.00%     -    0s

Explored 1 nodes (395 simplex iterations) in 0.06 seconds (0.02 work units)
Thread count was 16 (of 16 available processors)

Solution count 2: 328600 5.08995e+08 

Optimal solution found (tolerance 1.00e-01)
Best objective 3.286000872003e+05, best bound 3.286000872003e+05, gap 0.0000%
★ Optimal cost: 328,600.09


#  Solution extractor

In [None]:
from collections import OrderedDict
from math import isfinite
import gurobipy as gp
from gurobipy import GRB


def _round(x, digits=5):
    """
    Return ``x`` as a float rounded to ``digits`` decimals.

    ``x`` may be one of:
      * a plain number;
      * a ``gp.Var``: its solution value ``X`` is used, or 0.0 when
        ``hasattr(x, "X")`` is false;
      * a ``gp.LinExpr`` or ``gp.QuadExpr``: evaluated with ``getValue()``,
        which includes the expression's constant.

    Non-finite values (inf / nan) are returned unrounded.
    """
    if isinstance(x, gp.Var):
        x = x.X if hasattr(x, "X") else 0.0  # Handle unset variables
    elif isinstance(x, (gp.LinExpr, gp.QuadExpr)):
        # Always evaluate the expression. A constant-only expression has
        # size() == 0, yet its value is the constant, not 0.0.
        x = x.getValue()
    x = float(x)
    return round(x, digits) if isfinite(x) else x


def extract_solution(m: gp.Model) -> OrderedDict:
    """
    Extract a solved Gurobi model's solution as a nested OrderedDict matching
    UnitCommitment.jl's solution() structure.

    Args:
        m: Solved Gurobi model with attached m._instance and m._vars.

    Returns:
        OrderedDict keyed by scenario name. When the instance has a single
        scenario, that scenario's dict is returned directly.

    Raises:
        RuntimeError: If the status is not one of the accepted terminal
            statuses, if no feasible solution is stored (SolCount == 0), or if
            m._instance / m._vars are missing.
    """
    # Validate model state and attributes. Limit statuses are accepted because
    # they may still carry an incumbent; the SolCount check below decides.
    accepted_status = (
        GRB.OPTIMAL,
        GRB.SUBOPTIMAL,
        GRB.INTERRUPTED,
        GRB.TIME_LIMIT,
        GRB.NODE_LIMIT,
        GRB.SOLUTION_LIMIT,
    )
    if m.Status not in accepted_status:
        raise RuntimeError(f"Model has invalid status: {m.Status}")
    # Every .X read below needs a stored solution; an accepted status alone
    # does not guarantee one (e.g. interrupted before the first incumbent).
    if m.SolCount == 0:
        raise RuntimeError(f"Model has no feasible solution (status {m.Status})")
    if not hasattr(m, "_instance") or not hasattr(m, "_vars"):
        raise RuntimeError("Model missing required attributes: m._instance or m._vars")

    inst = m._instance
    T = inst.time
    v = m._vars

    def series(var_dict, keys, scenario=None, prefix=""):
        """
        Map ``obj.name`` -> [value for t = 1..T] for every ``obj`` in ``keys``.

        Lookup keys are flat tuples, the same shape the model builders use:
          * no scenario:         (obj.name, t)                 e.g. is_on
          * scenario:            (sc.name, obj.name, t)        e.g. net_injection
          * scenario + prefix:   (sc.name, prefix, obj.name, t) e.g. cont_overflow
        Missing keys yield 0.0.
        """
        od = OrderedDict()
        sname = scenario.name if scenario else None
        for obj in keys:
            if sname and prefix:
                base = (sname, prefix, obj.name)
            elif sname:
                base = (sname, obj.name)
            else:
                base = (obj.name,)
            od[obj.name] = [
                _round(var_dict.get((*base, t), 0.0)) for t in range(1, T + 1)
            ]
        return od

    def thermal_outputs(sc, sname):
        """Generate thermal unit-related outputs."""
        od = OrderedDict()
        if not sc.thermal_units:
            return od

        # Production and costs
        od["Thermal production (MW)"] = OrderedDict(
            (
                g.name,
                [
                    _round(
                        v["is_on"][(g.name, t)].X * g.min_power[t - 1]
                        + v["prod_above"][(g.name, t)].X
                    )
                    for t in range(1, T + 1)
                ],
            )
            for g in sc.thermal_units
        )
        od["Thermal production cost ($)"] = OrderedDict(
            (
                g.name,
                [
                    _round(
                        v["is_on"][(g.name, t)].X * g.min_power_cost[t - 1]
                        + sum(
                            m.getVarByName(f"seg[{g.name},{t},{k}]").X
                            * g.segments[k - 1].cost[t - 1]
                            for k in range(1, len(g.segments) + 1)
                        )
                    )
                    for t in range(1, T + 1)
                ],
            )
            for g in sc.thermal_units
        )
        od["Startup cost ($)"] = OrderedDict(
            (
                g.name,
                [
                    _round(
                        sum(
                            g.startup_categories[s - 1].cost
                            * v["startup"][(g.name, t, s)].X
                            for s in range(1, len(g.startup_categories) + 1)
                        )
                    )
                    for t in range(1, T + 1)
                ],
            )
            for g in sc.thermal_units
        )

        # Commitment status
        for key in ["is_on", "switch_on", "switch_off"]:
            od[key.capitalize()] = series(v[key], sc.thermal_units)

        # Net injection and curtailment
        od["Net injection (MW)"] = series(v["net_injection"], sc.buses, scenario=sc)
        od["Load curtail (MW)"] = series(v["curtail"], sc.buses, scenario=sc)

        return od

    def transmission_outputs(sc, sname):
        """Generate transmission-related outputs."""
        od = OrderedDict()
        if not sc.lines:
            return od

        # od["Line flow (MW)"] = series(v["flow"], sc.lines, scenario=sc)
        od["Line overflow (MW)"] = series(v["overflow"], sc.lines, scenario=sc)
        if sc.contingencies:
            od["Contingency overflow (MW)"] = OrderedDict(
                (
                    c.name,
                    series(v["cont_overflow"], sc.lines, scenario=sc, prefix=c.name),
                )
                for c in sc.contingencies
            )
        return od

    def other_unit_outputs(sc, sname):
        """Generate outputs for profiled units, price-sensitive loads, and storage."""
        od = OrderedDict()

        # Profiled units
        if sc.profiled_units:
            od["Profiled production (MW)"] = series(
                v["prod_profiled"], sc.profiled_units, scenario=sc
            )
            od["Profiled production cost ($)"] = OrderedDict(
                (
                    pu.name,
                    [
                        _round(
                            v["prod_profiled"][(sname, pu.name, t)].X * pu.cost[t - 1]
                        )
                        for t in range(1, T + 1)
                    ],
                )
                for pu in sc.profiled_units
            )

        # Price-sensitive loads
        # NOTE(review): expr_net_injection is keyed by bus, not by load, so this
        # only returns values when a load shares its name with a bus. It should
        # read the per-load served-demand variables instead.
        if sc.price_sensitive_loads:
            od["Price-sensitive loads (MW)"] = series(
                v["expr_net_injection"], sc.price_sensitive_loads, scenario=sc
            )

        # Storage units
        if sc.storage_units:
            for key in ["storage_level", "charge_rate", "discharge_rate"]:
                od[key.capitalize().replace("_", " ") + " (MW)"] = series(
                    v[key], sc.storage_units, scenario=sc
                )
            for key in ["is_charging", "is_discharging"]:
                od[key.capitalize().replace("_", " ")] = series(
                    v[key], sc.storage_units, scenario=sc
                )
            od["Storage charging cost ($)"] = OrderedDict(
                (
                    su.name,
                    [
                        _round(
                            v["charge_rate"][(sname, su.name, t)].X
                            * su.charge_cost[t - 1]
                        )
                        for t in range(1, T + 1)
                    ],
                )
                for su in sc.storage_units
            )
            od["Storage discharging cost ($)"] = OrderedDict(
                (
                    su.name,
                    [
                        _round(
                            v["discharge_rate"][(sname, su.name, t)].X
                            * su.discharge_cost[t - 1]
                        )
                        for t in range(1, T + 1)
                    ],
                )
                for su in sc.storage_units
            )

        return od

    def reserve_outputs(sc, sname):
        """Generate reserve-related outputs."""
        od = OrderedDict()
        if not sc.reserves:
            return od

        reserve_types = {
            "spinning": ["reserve", "reserve_shortfall"],
            "flexiramp": [
                "upflexiramp",
                "upflexiramp_shortfall",
                "dwflexiramp",
                "dwflexiramp_shortfall",
            ],
        }
        for r_type, var_names in reserve_types.items():
            for var_name in var_names:
                key = var_name.replace("_", " ").capitalize() + " (MW)"
                od[key] = OrderedDict(
                    (
                        r.name,
                        OrderedDict(
                            (
                                g.name,
                                [
                                    _round(v[var_name][(sname, r.name, g.name, t)].X)
                                    for t in range(1, T + 1)
                                ],
                            )
                            for g in r.thermal_units
                        ),
                    )
                    if "shortfall" not in var_name
                    else (
                        r.name,
                        [
                            _round(v[var_name][(sname, r.name, t)].X)
                            for t in range(1, T + 1)
                        ],
                    )
                    for r in sc.reserves
                    if r.type == r_type
                )

        return od

    # Build solution dictionary
    sol = OrderedDict()
    for sc in inst.scenarios:
        sname = sc.name
        sdict = OrderedDict()
        sol[sname] = sdict
        sdict.update(thermal_outputs(sc, sname))
        # sdict.update(transmission_outputs(sc, sname))
        sdict.update(other_unit_outputs(sc, sname))
        # sdict.update(reserve_outputs(sc, sname))

    return next(iter(sol.values())) if len(sol) == 1 else sol

In [None]:
from pathlib import Path
from collections import OrderedDict

# Pull the solved model into nested OrderedDicts and take a quick look.
sol = extract_solution(model)
print(list(sol))  # top-level keys
print(sol["Thermal production (MW)"]["g1"][:5])  # first 5 periods of g1


# flatten nested OrderedDict → dict of numpy arrays
def _flatten_sol_dict(d, parent_key="", sep="/"):
    items = {}
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, (dict, OrderedDict)):
            items.update(_flatten_sol_dict(v, new_key, sep=sep))
        else:
            items[new_key] = np.array(v)
    return items


flat_sol = _flatten_sol_dict(sol)

# Mirror the benchmark's directory layout under output_gurobi/, e.g.
# output_gurobi/matpower/case14/2017-06-24_solution.npz for SAMPLE above.
in_path = Path(SAMPLE)
out_dir = Path("output_gurobi").joinpath(in_path.parent)
out_dir.mkdir(exist_ok=True, parents=True)
out_file = out_dir.joinpath(in_path.name + "_solution.npz")
np.savez(out_file, **flat_sol)

print(f"\nSolution saved to {out_file}")

In [None]:
import matplotlib.pyplot as plt


# Line styles cycled per run so runs stay distinguishable while colors identify
# the individual units / buses / lines / reserves within a figure.
_RUN_LINESTYLES = ("-", "--", "-.", ":")


def _plot_metric_comparison(
    solutions: List[Dict[str, np.ndarray]],
    labels: List[str],
    time_steps: np.ndarray,
    metric: str,
    names: List[str],
    ylabel: str,
    title: str,
    out_path: Path,
) -> bool:
    """Plot ``"{metric}/{name}"`` series for every entity across all runs.

    Args:
        solutions: One flat ``key -> array`` dict per run.
        labels: Legend label per run (same order as ``solutions``).
        time_steps: x-axis values shared by every series.
        metric: Top-level solution key, e.g. ``"Line flow (MW)"``.
        names: Entity names appended to ``metric`` to form the lookup key.
        ylabel, title: Axis label and figure title.
        out_path: PNG file to write.

    Returns:
        True if at least one series was drawn and the figure saved; False if no
        matching key exists in any solution (nothing is written then).
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        # One color per entity; runs are told apart by line style instead.
        colors = plt.cm.tab10(np.linspace(0, 1, max(len(names), 1)))
        plotted = False
        for run_idx, (sol, label) in enumerate(zip(solutions, labels)):
            linestyle = _RUN_LINESTYLES[run_idx % len(_RUN_LINESTYLES)]
            for ent_idx, name in enumerate(names):
                key = f"{metric}/{name}"
                if key not in sol:
                    continue
                ax.plot(
                    time_steps,
                    sol[key],
                    label=f"{label}: {name}",
                    color=colors[ent_idx],
                    linestyle=linestyle,
                )
                plotted = True

        if not plotted:
            # Avoid saving an empty figure (and matplotlib's "no artists" warning).
            return False

        ax.set_xlabel("Time Step")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(out_path, bbox_inches="tight")
        return True
    finally:
        # Always release the figure, even if plotting raised.
        plt.close(fig)


def plot_solution_comparison(
    solution_files: List[Union[str, Path]],
    instance: "UnitCommitmentInstance",
    output_dir: Union[str, Path] = "plots",
    metrics: Optional[List[str]] = None,
    max_units: int = 5,
    max_buses: int = 5,
    max_lines: int = 5,
):
    """
    Plot key metrics from multiple SCUC solution files for comparison.

    Args:
        solution_files: List of paths to .npz solution files whose keys look like
            "<metric>/<entity name>" (as written by the flatten-and-save cell).
        instance: UnitCommitmentInstance object for metadata (unit names, time steps).
        output_dir: Directory to save plots (default: "plots"); created if missing.
        metrics: List of metrics to plot (default: ["Thermal production (MW)",
                 "Net injection (MW)", "Line flow (MW)", "Line overflow (MW)",
                 "Spinning reserve shortfall (MW)", "Up-flexiramp shortfall (MW)",
                 "Down-flexiramp shortfall (MW)"]).
        max_units: Max number of thermal units to plot (default: 5).
        max_buses: Max number of buses to plot (default: 5).
        max_lines: Max number of lines to plot (default: 5).

    Saves one PNG per requested metric to output_dir. Metrics with no matching
    data in any solution are skipped (a message is printed) instead of producing
    an empty figure.
    """
    # Default metrics if none provided
    if metrics is None:
        metrics = [
            "Thermal production (MW)",
            "Net injection (MW)",
            "Line flow (MW)",
            "Line overflow (MW)",
            "Spinning reserve shortfall (MW)",
            "Up-flexiramp shortfall (MW)",
            "Down-flexiramp shortfall (MW)",
        ]

    # Ensure output directory exists
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Read every archive fully into memory so each NpzFile handle is closed.
    # NOTE: allow_pickle=True can execute code from the file — only load trusted files.
    solutions: List[Dict[str, np.ndarray]] = []
    labels: List[str] = []
    for i, file in enumerate(solution_files):
        file = Path(file)
        with np.load(file, allow_pickle=True) as archive:
            solutions.append({k: archive[k] for k in archive.files})
        labels.append(file.stem.replace("_solution", f" (Run {i + 1})"))

    # Get scenario and time horizon
    sc = instance.deterministic
    T = instance.time
    time_steps = np.arange(1, T + 1)

    # metric -> (entity getter, y-axis label, title, output file name).
    # Getters are lazy so scenario attributes are only touched for requested metrics.
    plot_specs = {
        "Thermal production (MW)": (
            lambda: sc.thermal_units[:max_units],
            "Power (MW)",
            "Thermal Production Comparison",
            "thermal_production.png",
        ),
        "Net injection (MW)": (
            lambda: sc.buses[:max_buses],
            "Net Injection (MW)",
            "Net Injection Comparison",
            "net_injection.png",
        ),
        "Line flow (MW)": (
            lambda: sc.lines[:max_lines],
            "Flow (MW)",
            "Line Flow Comparison",
            "line_flow.png",
        ),
        "Line overflow (MW)": (
            lambda: sc.lines[:max_lines],
            "Overflow (MW)",
            "Line Overflow Comparison",
            "line_overflow.png",
        ),
    }
    for reserve_type in (
        "Spinning reserve shortfall (MW)",
        "Up-flexiramp shortfall (MW)",
        "Down-flexiramp shortfall (MW)",
    ):
        plot_specs[reserve_type] = (
            # Default arg binds the current reserve_type (avoids late-binding closure).
            # NOTE(review): substring match assumes r.type values like "spinning" /
            # "flexiramp" — TODO confirm against the reserve parser.
            lambda rt=reserve_type: [r for r in sc.reserves if r.type in rt.lower()],
            "Shortfall (MW)",
            f"{reserve_type} Comparison",
            f"{reserve_type.lower().replace(' ', '_')}.png",
        )

    for metric, (get_entities, ylabel, title, filename) in plot_specs.items():
        if metric not in metrics:
            continue
        names = [entity.name for entity in get_entities()]
        saved = _plot_metric_comparison(
            solutions,
            labels,
            time_steps,
            metric,
            names,
            ylabel,
            title,
            output_dir / filename,
        )
        if not saved:
            print(f"No data for {metric!r} in the given solutions; plot skipped.")

    print(f"Plots saved to {output_dir}")


In [None]:
# NOTE(review): the flatten-and-save cell earlier in the notebook also reads SAMPLE;
# on a fresh kernel it must be defined before that cell runs — TODO move this
# definition into the configuration cell.
SAMPLE = "matpower/case14/2017-06-24"
inst = read_benchmark(SAMPLE, quiet=True)

# Derive the path the same way the save cell does:
# output_gurobi/<instance parent>/<instance name>_solution.npz
_sample_path = Path(SAMPLE)
solution_files = [
    Path("output_gurobi") / _sample_path.parent / f"{_sample_path.name}_solution.npz",
    # Add more solution files for comparison, e.g., different MIP gaps
]

# Plots are written under a directory mirroring the instance name (e.g. ./matpower/case14/2017-06-24).
plot_solution_comparison(solution_files, inst, output_dir=SAMPLE)