Skip to content

Commit

Permalink
Refactor ConstraintGroup (#85)
Browse files Browse the repository at this point in the history
* Rename ConstraintGroup to Constraint

* Remove statefulness of Constraints

* Remove checkpoint tests

---------

Co-authored-by: juan43ramirez <juan43.ramirez@gmail.com>
  • Loading branch information
merajhashemi and juan43ramirez committed Mar 21, 2024
1 parent 67e9a7b commit b6c51ef
Show file tree
Hide file tree
Showing 24 changed files with 76 additions and 287 deletions.
2 changes: 1 addition & 1 deletion cooper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
warnings.warn("Could not retrieve Cooper version!")

from cooper.cmp import CMPState, ConstrainedMinimizationProblem, LagrangianStore
from cooper.constraints import ConstraintGroup, ConstraintState, ConstraintType
from cooper.constraints import Constraint, ConstraintState, ConstraintType
from cooper.formulations import FormulationType

from . import formulations, multipliers, optim, utils
28 changes: 14 additions & 14 deletions cooper/cmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch

from cooper.constraints import ConstraintGroup, ConstraintMeasurement, ConstraintState
from cooper.constraints import Constraint, ConstraintMeasurement, ConstraintState

# Formulation, and some other classes below, are inspired by the design of the
# TensorFlow Constrained Optimization (TFCO) library:
Expand Down Expand Up @@ -48,7 +48,7 @@ class CMPState:
def __init__(
self,
loss: Optional[torch.Tensor] = None,
observed_constraints: Sequence[tuple[ConstraintGroup, ConstraintState]] = (),
observed_constraints: Sequence[tuple[Constraint, ConstraintState]] = (),
misc: Optional[dict] = None,
):
self.loss = loss
Expand Down Expand Up @@ -87,13 +87,13 @@ def populate_primal_lagrangian(self) -> LagrangianStore:
current_primal_lagrangian = 0.0 if self.loss is None else torch.clone(self.loss)

current_primal_constraint_measurements = []
for constraint_group, constraint_state in contributing_constraints:
primal_constraint_contrib, primal_measurement = constraint_group.compute_constraint_primal_contribution(
for constraint, constraint_state in contributing_constraints:
primal_lagrangian_contribution, primal_measurement = constraint.compute_constraint_primal_contribution(
constraint_state
)
current_primal_constraint_measurements.append(primal_measurement)
if primal_constraint_contrib is not None:
current_primal_lagrangian = current_primal_lagrangian + primal_constraint_contrib
if primal_lagrangian_contribution is not None:
current_primal_lagrangian = current_primal_lagrangian + primal_lagrangian_contribution

# Modify "private" attributes to accumulate Lagrangian values over successive
# calls to `populate_primal_lagrangian`
Expand Down Expand Up @@ -136,19 +136,19 @@ def populate_dual_lagrangian(self) -> LagrangianStore:
current_dual_lagrangian = 0.0

current_dual_constraint_measurements = []
for constraint_group, constraint_state in contributing_constraints:
dual_lagrangian_contrib, dual_measurement = constraint_group.compute_constraint_dual_contribution(
for constraint, constraint_state in contributing_constraints:
dual_lagrangian_contribution, dual_measurement = constraint.compute_constraint_dual_contribution(
constraint_state
)
current_dual_constraint_measurements.append(dual_measurement)
if dual_lagrangian_contrib is not None:
current_dual_lagrangian = current_dual_lagrangian + dual_lagrangian_contrib
if dual_lagrangian_contribution is not None:
current_dual_lagrangian = current_dual_lagrangian + dual_lagrangian_contribution

# Extracting the violation from the dual_constraint_measurement ensures that it is
# the "strict" violation, if available.
_, strict_constraint_features = constraint_state.extract_constraint_features()
constraint_group.update_strictly_feasible_indices_(
strict_violation=dual_lagrangian_contrib.violation,
constraint.update_strictly_feasible_indices_(
strict_violation=dual_measurement.violation,
strict_constraint_features=strict_constraint_features,
)

Expand Down Expand Up @@ -234,8 +234,8 @@ def backward(self) -> None:

def __repr__(self) -> str:
_string = f"CMPState(\n loss={self.loss},\n observed_constraints=["
for constraint_group, constraint_state in self.observed_constraints:
_string += f"\n\t{constraint_group} -> {constraint_state},"
for constraint, constraint_state in self.observed_constraints:
_string += f"\n\t{constraint} -> {constraint_state},"
_string += f"\n ]\n misc={self.misc}\n)"
return _string

Expand Down
2 changes: 1 addition & 1 deletion cooper/constraints/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .constraint_group import ConstraintGroup
from .constraint import Constraint
from .constraint_state import ConstraintMeasurement, ConstraintState, ConstraintType
from .slacks import ConstantSlack, DenseSlack, ExplicitSlack, IndexedSlack, SlackVariable
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from cooper.multipliers import IndexedMultiplier, Multiplier, PenaltyCoefficient


class ConstraintGroup:
"""Constraint Group."""
class Constraint:
"""Constraint."""

# TODO(gallego-posada): Add documentation

Expand Down Expand Up @@ -50,17 +50,17 @@ def sanity_check_multiplier(self, multiplier: Multiplier, constraint_type: Const
if multiplier.constraint_type != constraint_type:
raise ValueError(
f"Constraint type of provided multiplier is {multiplier.constraint_type} \
which is inconsistent with {constraint_type} set for the constraint group."
which is inconsistent with {constraint_type} set for the constraint."
)

def sanity_check_penalty_coefficient(self, penalty_coefficient: PenaltyCoefficient) -> None:
if torch.any(penalty_coefficient.value < 0):
raise ValueError("All entries of the penalty coefficient must be non-negative.")

def update_penalty_coefficient(self, constraint_state: ConstraintState) -> None:
"""Update the penalty coefficient of the constraint group."""
"""Update the penalty coefficient of the constraint."""
if self.penalty_coefficient is None:
raise ValueError("Constraint group does not have a penalty coefficient.")
raise ValueError("Constraint does not have a penalty coefficient.")
else:
self.penalty_coefficient.update_value(
constraint_state=constraint_state,
Expand Down Expand Up @@ -112,32 +112,8 @@ def update_strictly_feasible_indices_(

self.multiplier.strictly_feasible_indices = strictly_feasible_indices

def state_dict(self):
state_dict = {"constraint_type": self.constraint_type, "formulation": self.formulation.state_dict()}
for attr_name, attr in [("multiplier", self.multiplier), ("penalty_coefficient", self.penalty_coefficient)]:
state_dict[attr_name] = attr.state_dict() if attr is not None else None
return state_dict

def load_state_dict(self, state_dict):
self.constraint_type = state_dict["constraint_type"]
self.formulation.load_state_dict(state_dict["formulation"])

if state_dict["multiplier"] is not None and self.multiplier is None:
raise ValueError("Cannot load multiplier state dict since existing multiplier is `None`.")
elif state_dict["multiplier"] is None and self.multiplier is not None:
raise ValueError("Multiplier exists but state dict is `None`.")
elif state_dict["multiplier"] is not None and self.multiplier is not None:
self.multiplier.load_state_dict(state_dict["multiplier"])

if state_dict["penalty_coefficient"] is not None and self.penalty_coefficient is None:
raise ValueError("Cannot load penalty_coefficient state dict since existing penalty_coefficient is `None`.")
elif state_dict["penalty_coefficient"] is None and self.penalty_coefficient is not None:
raise ValueError("Penalty coefficient exists but state dict is `None`.")
elif state_dict["penalty_coefficient"] is not None and self.penalty_coefficient is not None:
self.penalty_coefficient.load_state_dict(state_dict["penalty_coefficient"])

def __repr__(self):
repr = f"ConstraintGroup(constraint_type={self.constraint_type}, formulation={self.formulation}"
repr = f"Constraint(constraint_type={self.constraint_type}, formulation={self.formulation}"
if self.multiplier is not None:
repr += f", multiplier={self.multiplier}"
if self.penalty_coefficient is not None:
Expand Down
6 changes: 3 additions & 3 deletions cooper/constraints/constraint_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@ class ConstraintType(Enum):

@dataclass
class ConstraintState:
"""State of a constraint group describing the current constraint violation.
"""State of a constraint describing the current constraint violation.
Args:
violation: Measurement of the constraint violation at some value of the primal
parameters. This is expected to be differentiable with respect to the
primal parameters.
constraint_features: The features of the (differentiable) constraint. This is
used to evaluate the Lagrange multiplier associated with a constraint group.
used to evaluate the Lagrange multiplier associated with a constraint.
For example, an `IndexedMultiplier` expects the indices of the constraints
whose Lagrange multipliers are to be retrieved; while an
`ImplicitMultiplier` expects general tensor-valued features for the
constraints. This field is not used for `DenseMultiplier`//s.
This can be used in conjunction with an `IndexedMultiplier` to indicate the
measurement of the violation for only a subset of the constraints within a
`ConstraintGroup`.
`Constraint`.
strict_violation: Measurement of the constraint violation which may be
non-differentiable with respect to the primal parameters. When provided,
the (necessarily differentiable) `violation` is used to compute the gradient
Expand Down
10 changes: 5 additions & 5 deletions cooper/constraints/slacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class ExplicitSlack(SlackVariable):
"""
An explicit slack holds a :py:class:`~torch.nn.parameter.Parameter` which contains
(explicitly) the value of the slack variable with a
:py:class:`~cooper.constraints.ConstraintGroup` in a
:py:class:`~cooper.constraints.Constraint` in a
:py:class:`~cooper.cmp.ConstrainedMinimizationProblem`.
Args:
Expand Down Expand Up @@ -105,10 +105,10 @@ class DenseSlack(ExplicitSlack):
"""Simplest kind of trainable slack variable.
:py:class:`~cooper.constraints.slacks.DenseSlack`\\s are suitable for low to
mid-scale :py:class:`~cooper.constraints.ConstraintGroup`\\s for which all the
mid-scale :py:class:`~cooper.constraints.Constraint`\\s for which all the
constraints in the group are measured constantly.
For large-scale :py:class:`~cooper.constraints.ConstraintGroup`\\s (for example,
For large-scale :py:class:`~cooper.constraints.Constraint`\\s (for example,
one constraint per training example) you may consider using an
:py:class:`~cooper.constraints.slacks.IndexedSlack`.
"""
Expand All @@ -121,12 +121,12 @@ def forward(self):
class IndexedSlack(ExplicitSlack):
"""Indexed slacks extend the functionality of
:py:class:`~cooper.constraints.slacks.DenseSlack`\\s to cases where the number of
constraints in the :py:class:`~cooper.constraints.ConstraintGroup` is too large.
constraints in the :py:class:`~cooper.constraints.Constraint` is too large.
This situation may arise, for example, when imposing point-wise constraints over all
the training samples in a learning task.
In such cases, it might be computationally prohibitive to measure the value for all
the constraints in the :py:class:`~cooper.constraints.ConstraintGroup` and one may
the constraints in the :py:class:`~cooper.constraints.Constraint` and one may
typically resort to sampling. :py:class:`~cooper.constraints.slacks.IndexedSlack`\\s
enable time-efficient retrieval of the slack variables for the sampled constraints
only, and memory-efficient sparse gradients (on GPU).
Expand Down
6 changes: 3 additions & 3 deletions cooper/formulations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def compute_primal_weighted_violation(
Args:
constraint_factor_value: The value of the multiplier or penalty coefficient for the
constraint group.
constraint.
violation: Tensor of constraint violations.
"""

Expand Down Expand Up @@ -57,7 +57,7 @@ def compute_dual_weighted_violation(
Bertsekas (2016).
Args:
multiplier_value: The value of the multiplier for the constraint group.
multiplier_value: The value of the multiplier for the constraint.
violation: Tensor of constraint violations.
penalty_coefficient_value: Tensor of penalty coefficient values.
"""
Expand Down Expand Up @@ -92,7 +92,7 @@ def compute_quadratic_augmented_contribution(
constraint_type: ConstraintType,
) -> Optional[torch.Tensor]:
r"""
Computes the quadratic penalty for a constraint group.
Computes the quadratic penalty for a constraint.
When the constraint is an inequality constraint, the quadratic penalty is computed
following Eq 17.65 in Numerical Optimization by Nocedal and Wright (2006). Denoting
Expand Down
4 changes: 2 additions & 2 deletions cooper/multipliers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def evaluate_constraint_factor(
module: ConstraintFactor, constraint_features: torch.Tensor, violation: torch.Tensor
) -> torch.Tensor:
"""Evaluate the Lagrange multiplier or penalty coefficient associated with a
constraint group.
constraint.
Args:
module: Multiplier or penalty coefficient module.
Expand All @@ -35,7 +35,7 @@ def evaluate_constraint_factor(
if not value.requires_grad and value.numel() == 1 and violation.numel() > 1:
# Expand the value of the penalty coefficient to match the shape of the violation.
# This enables the use of a single penalty coefficient for all constraints in a
# constraint group.
# constraint.
# We only do this for penalty coefficients an not multipliers because we expect
# a one-to-one mapping between multiplier values and constraints. If multiplier
# sharing is desired, this should be done explicitly by the user.
Expand Down
14 changes: 7 additions & 7 deletions cooper/multipliers/multipliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ class ExplicitMultiplier(Multiplier):
"""
An explicit multiplier holds a :py:class:`~torch.nn.parameter.Parameter` which
contains (explicitly) the value of the Lagrange multipliers associated with a
:py:class:`~cooper.constraints.ConstraintGroup` in a
:py:class:`~cooper.constraints.Constraint` in a
:py:class:`~cooper.cmp.ConstrainedMinimizationProblem`.
.. warning::
When `restart_on_feasible=True`, the entries of the multiplier which correspond
to feasible constraints in the :py:class:`~cooper.constraints.ConstraintGroup`
to feasible constraints in the :py:class:`~cooper.constraints.Constraint`
are reset to a default value (typically zero) by the
:py:meth:`~cooper.multipliers.ExplicitMultiplier.post_step_` method. Note that
we do **not** perform any modification to the dual optimizer associated with
Expand Down Expand Up @@ -163,10 +163,10 @@ class DenseMultiplier(ExplicitMultiplier):
"""Simplest kind of trainable Lagrange multiplier.
:py:class:`~cooper.multipliers.DenseMultiplier`\\s are suitable for low to mid-scale
:py:class:`~cooper.constraints.ConstraintGroup`\\s for which all the constraints
:py:class:`~cooper.constraints.Constraint`\\s for which all the constraints
in the group are measured constantly.
For large-scale :py:class:`~cooper.constraints.ConstraintGroup`\\s (for example,
For large-scale :py:class:`~cooper.constraints.Constraint`\\s (for example,
one constraint per training example) you may consider using an
:py:class:`~cooper.multipliers.IndexedMultiplier`.
"""
Expand All @@ -182,12 +182,12 @@ def __repr__(self):
class IndexedMultiplier(ExplicitMultiplier):
"""Indexed multipliers extend the functionality of
:py:class:`~cooper.multipliers.DenseMultiplier`\\s to cases where the number of
constraints in the :py:class:`~cooper.constraints.ConstraintGroup` is too large.
constraints in the :py:class:`~cooper.constraints.Constraint` is too large.
This situation may arise, for example, when imposing point-wise constraints over all
the training samples in a learning task.
In such cases, it might be computationally prohibitive to measure the value for all
the constraints in the :py:class:`~cooper.constraints.ConstraintGroup` and one may
the constraints in the :py:class:`~cooper.constraints.Constraint` and one may
typically resort to sampling. :py:class:`~cooper.multipliers.IndexedMultiplier`\\s
enable time-efficient retrieval of the multipliers for the sampled constraints only,
and memory-efficient sparse gradients (on GPU).
Expand Down Expand Up @@ -244,7 +244,7 @@ def __repr__(self):
class ImplicitMultiplier(Multiplier):
"""An implicit multiplier is a :py:class:`~torch.nn.Module` that computes the value
of a Lagrange multiplier associated with a
:py:class:`~cooper.constraints.ConstraintGroup` based on "features" for each
:py:class:`~cooper.constraints.Constraint` based on "features" for each
constraint. The multiplier is _implicitly_ represented by the features of its
associated constraint as well as the computation that takes place in the
:py:meth:`~cooper.multipliers.ImplicitMultiplier.forward` method.
Expand Down
8 changes: 4 additions & 4 deletions cooper/optim/constrained_optimizers/alternating_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,17 @@ def step(self):
pass

def update_penalty_coefficients(self, cmp_state: CMPState) -> None:
"""Update the penalty coefficients of the constraint groups. Only the penalty
"""Update the penalty coefficients of the constraints. Only the penalty
coefficients associated with the ``FormulationType.AUGMENTED_LAGRANGIAN`` and
constraints that ``contributes_to_dual_update`` are updated.
"""
for constraint_group, constraint_state in cmp_state.observed_constraints:
if constraint_group.formulation_type == FormulationType.AUGMENTED_LAGRANGIAN:
for constraint, constraint_state in cmp_state.observed_constraints:
if constraint.formulation_type == FormulationType.AUGMENTED_LAGRANGIAN:
# We might reach this point via an AugmetedLagrangianOptimizer acting
# on some constraints that do not use an Augmented Lagrangian formulation,
# so we do _not_ apply penalty coefficient updates to those.
if constraint_state.contributes_to_dual_update:
constraint_group.update_penalty_coefficient(constraint_state=constraint_state)
constraint.update_penalty_coefficient(constraint_state=constraint_state)


class AlternatingPrimalDualOptimizer(BaseAlternatingOptimizer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class ConstrainedOptimizer:
dual_optimizers: Optimizer(s) for the dual variables (e.g. the Lagrange
multipliers associated with the constraints). An iterable of
``torch.optim.Optimizer``\\s can be passed to handle the case of several
``~cooper.constraints.ConstraintGroup``\\s. If dealing with an unconstrained
``~cooper.constraints.Constraint``\\s. If dealing with an unconstrained
problem, please use a
:py:class:`~cooper.optim.cooper_optimizer.UnconstrainedOptimizer` instead.
Expand Down
4 changes: 2 additions & 2 deletions cooper/optim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def load_cooper_optimizer_from_state_dict(
):
"""Creates a Cooper optimizer and loads the state_dicts contained in a
:py:class:`~cooper.optim.CooperOptimizerState` onto instantiated primal and dual
optimizers and constraint groups or multipliers.
optimizers and constraints or multipliers.
"""

# Load primal optimizers
Expand Down Expand Up @@ -104,7 +104,7 @@ def load_cooper_optimizer_from_state_dict(
for multiplier, multiplier_state in zip(multipliers, multiplier_states):
multiplier.load_state_dict(multiplier_state)

# Since we have extracted the multiplier information above, we discard the constraint_groups below
# Since we have extracted the multiplier information above, we discard the constraints below
return create_optimizer_from_kwargs(
primal_optimizers=primal_optimizers,
extrapolation=cooper_optimizer_state.extrapolation,
Expand Down
Loading

0 comments on commit b6c51ef

Please sign in to comment.