Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ConstraintStore #84

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 37 additions & 33 deletions cooper/cmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch

from cooper.constraints import ConstraintGroup, ConstraintState, ConstraintStore
from cooper.constraints import ConstraintGroup, ConstraintMeasurement, ConstraintState

# Formulation, and some other classes below, are inspired by the design of the
# TensorFlow Constrained Optimization (TFCO) library:
Expand All @@ -19,14 +19,14 @@ class LagrangianStore:

lagrangian: torch.Tensor
dual_lagrangian: Optional[torch.Tensor] = None
primal_constraint_stores: Optional[list[ConstraintStore]] = None
dual_constraint_stores: Optional[list[ConstraintStore]] = None
primal_constraint_measurements: Optional[list[ConstraintMeasurement]] = None
dual_constraint_measurements: Optional[list[ConstraintMeasurement]] = None

def multiplier_values_for_primal_constraints(self):
return [_.multiplier_value for _ in self.primal_constraint_stores]
return [_.multiplier_value for _ in self.primal_constraint_measurements]

def multiplier_values_for_dual_constraints(self):
return [_.multiplier_value for _ in self.dual_constraint_stores]
return [_.multiplier_value for _ in self.dual_constraint_measurements]


class CMPState:
Expand Down Expand Up @@ -56,9 +56,9 @@ def __init__(
self.misc = misc

self._primal_lagrangian = None
self._primal_constraint_stores = []
self._primal_constraint_measurements = []
self._dual_lagrangian = None
self._dual_constraint_stores = []
self._dual_constraint_measurements = []

def populate_primal_lagrangian(self) -> LagrangianStore:
"""Computes and accumulates the primal-differentiable Lagrangian based on the
Expand All @@ -77,35 +77,37 @@ def populate_primal_lagrangian(self) -> LagrangianStore:
return LagrangianStore(
lagrangian=self._primal_lagrangian,
dual_lagrangian=self._dual_lagrangian,
primal_constraint_stores=self._primal_constraint_stores,
dual_constraint_stores=self._dual_constraint_stores,
primal_constraint_measurements=self._primal_constraint_measurements,
dual_constraint_measurements=self._dual_constraint_measurements,
)

# Either a loss was provided, or at least one observed constraint contributes to
# the primal Lagrangian.
previous_primal_lagrangian = 0.0 if self._primal_lagrangian is None else self._primal_lagrangian
current_primal_lagrangian = 0.0 if self.loss is None else torch.clone(self.loss)

current_primal_constraint_stores = []
current_primal_constraint_measurements = []
for constraint_group, constraint_state in contributing_constraints:
primal_constraint_store = constraint_group.compute_constraint_primal_contribution(constraint_state)
current_primal_constraint_stores.append(primal_constraint_store)
if primal_constraint_store is not None:
current_primal_lagrangian = current_primal_lagrangian + primal_constraint_store.lagrangian_contribution
primal_constraint_contrib, primal_measurement = constraint_group.compute_constraint_primal_contribution(
constraint_state
)
current_primal_constraint_measurements.append(primal_measurement)
if primal_constraint_contrib is not None:
current_primal_lagrangian = current_primal_lagrangian + primal_constraint_contrib

# Modify "private" attributes to accumulate Lagrangian values over successive
# calls to `populate_primal_lagrangian`
self._primal_lagrangian = previous_primal_lagrangian + current_primal_lagrangian
self._primal_constraint_stores.extend(current_primal_constraint_stores)
self._primal_constraint_measurements.extend(current_primal_constraint_measurements)

# We return any existent values for the _dual_lagrangian, and the
# _dual_constraint_stores. The _primal_lagrangian and _primal_constraint_stores
# _dual_constraint_measurements. The _primal_lagrangian and _primal_constraint_measurements
# attributes have been modified earlier, so their updated values are returned.
lagrangian_store = LagrangianStore(
lagrangian=self._primal_lagrangian,
dual_lagrangian=self._dual_lagrangian,
primal_constraint_stores=self._primal_constraint_stores,
dual_constraint_stores=self._dual_constraint_stores,
primal_constraint_measurements=self._primal_constraint_measurements,
dual_constraint_measurements=self._dual_constraint_measurements,
)

return lagrangian_store
Expand All @@ -125,42 +127,44 @@ def populate_dual_lagrangian(self) -> LagrangianStore:
return LagrangianStore(
lagrangian=self._primal_lagrangian,
dual_lagrangian=self._dual_lagrangian,
primal_constraint_stores=self._primal_constraint_stores,
dual_constraint_stores=self._dual_constraint_stores,
primal_constraint_measurements=self._primal_constraint_measurements,
dual_constraint_measurements=self._dual_constraint_measurements,
)

# At least one observed constraint contributes to the dual Lagrangian.
previous_dual_lagrangian = 0.0 if self._dual_lagrangian is None else self._dual_lagrangian
current_dual_lagrangian = 0.0

current_dual_constraint_stores = []
current_dual_constraint_measurements = []
for constraint_group, constraint_state in contributing_constraints:
dual_constraint_store = constraint_group.compute_constraint_dual_contribution(constraint_state)
current_dual_constraint_stores.append(dual_constraint_store)
if dual_constraint_store is not None:
current_dual_lagrangian = current_dual_lagrangian + dual_constraint_store.lagrangian_contribution
dual_lagrangian_contrib, dual_measurement = constraint_group.compute_constraint_dual_contribution(
constraint_state
)
current_dual_constraint_measurements.append(dual_measurement)
if dual_lagrangian_contrib is not None:
current_dual_lagrangian = current_dual_lagrangian + dual_lagrangian_contrib

# Extracting the violation from the dual_constraint_store ensures that it is
# Extracting the violation from the dual_constraint_measurement ensures that it is
# the "strict" violation, if available.
_, strict_constraint_features = constraint_state.extract_constraint_features()
constraint_group.update_strictly_feasible_indices_(
strict_violation=dual_constraint_store.violation,
strict_violation=dual_lagrangian_contrib.violation,
strict_constraint_features=strict_constraint_features,
)

# Modify "private" attributes to accumulate Lagrangian values over successive
# calls to `populate_dual_lagrangian`
self._dual_lagrangian = previous_dual_lagrangian + current_dual_lagrangian
self._dual_constraint_stores.extend(current_dual_constraint_stores)
self._dual_constraint_measurements.extend(current_dual_constraint_measurements)

# We return any existent values for the _primal_lagrangian, and the
# _primal_constraint_stores. The _dual_lagrangian and _dual_constraint_stores
# _primal_constraint_measurements. The _dual_lagrangian and _dual_constraint_measurements
# attributes have been modified earlier, so their updated values are returned.
lagrangian_store = LagrangianStore(
lagrangian=self._primal_lagrangian,
dual_lagrangian=self._dual_lagrangian,
primal_constraint_stores=self._primal_constraint_stores,
dual_constraint_stores=self._dual_constraint_stores,
primal_constraint_measurements=self._primal_constraint_measurements,
dual_constraint_measurements=self._dual_constraint_measurements,
)

return lagrangian_store
Expand Down Expand Up @@ -192,12 +196,12 @@ def populate_lagrangian(self) -> LagrangianStore:
def purge_primal_lagrangian(self) -> None:
"""Purge the accumulated primal Lagrangian contributions."""
self._primal_lagrangian = None
self._primal_constraint_stores = []
self._primal_constraint_measurements = []

def purge_dual_lagrangian(self) -> None:
"""Purge the accumulated dual Lagrangian contributions."""
self._dual_lagrangian = None
self._dual_constraint_stores = []
self._dual_constraint_measurements = []

def purge_lagrangian(self) -> None:
"""Purge the accumulated Lagrangian contributions."""
Expand Down
2 changes: 1 addition & 1 deletion cooper/constraints/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .constraint_group import ConstraintGroup
from .constraint_state import ConstraintState, ConstraintStore, ConstraintType
from .constraint_state import ConstraintMeasurement, ConstraintState, ConstraintType
from .slacks import ConstantSlack, DenseSlack, ExplicitSlack, IndexedSlack, SlackVariable
10 changes: 7 additions & 3 deletions cooper/constraints/constraint_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch

from cooper import multipliers
from cooper.constraints.constraint_state import ConstraintState, ConstraintStore, ConstraintType
from cooper.constraints.constraint_state import ConstraintMeasurement, ConstraintState, ConstraintType
from cooper.formulations import FormulationType
from cooper.multipliers import IndexedMultiplier, Multiplier, PenaltyCoefficient

Expand Down Expand Up @@ -78,12 +78,16 @@ def prepare_kwargs_for_lagrangian_contribution(self, constraint_state: Constrain

return kwargs

def compute_constraint_primal_contribution(self, constraint_state: ConstraintState) -> ConstraintStore:
def compute_constraint_primal_contribution(
self, constraint_state: ConstraintState
) -> tuple[Optional[torch.Tensor], Optional[ConstraintMeasurement]]:
"""Compute the contribution of the current constraint to the primal Lagrangian."""
kwargs = self.prepare_kwargs_for_lagrangian_contribution(constraint_state=constraint_state)
return self.formulation.compute_contribution_for_primal_lagrangian(**kwargs)

def compute_constraint_dual_contribution(self, constraint_state: ConstraintState) -> ConstraintStore:
def compute_constraint_dual_contribution(
self, constraint_state: ConstraintState
) -> tuple[Optional[torch.Tensor], Optional[ConstraintMeasurement]]:
"""Compute the contribution of the current constraint to the dual Lagrangian."""
kwargs = self.prepare_kwargs_for_lagrangian_contribution(constraint_state=constraint_state)
return self.formulation.compute_contribution_for_dual_lagrangian(**kwargs)
Expand Down
7 changes: 3 additions & 4 deletions cooper/constraints/constraint_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,13 @@ def compute_strictly_feasible_constraints(self) -> tuple[torch.Tensor, torch.Ten


@dataclass
class ConstraintStore:
# TODO: update docstring. Current ConstraintStore is agnostic to dual or primal
class ConstraintMeasurement:
# TODO: update docstring. Current ConstraintMeasurement is agnostic to dual or primal
# lagrangian.
"""Stores the value of the constraint factor (multiplier or penalty coefficient),
the contribution of the constraint to the primal-differentiable Lagrian, and the
the contribution of the constraint to the primal-differentiable Lagrangian, and the
contribution of the constraint to the dual-differentiable Lagrangian."""

lagrangian_contribution: Optional[torch.Tensor] = None
violation: Optional[torch.Tensor] = None
multiplier_value: Optional[torch.Tensor] = None
penalty_coefficient_value: Optional[torch.Tensor] = None
Loading
Loading