Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ConstraintGroup #85

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cooper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
warnings.warn("Could not retrieve Cooper version!")

from cooper.cmp import CMPState, ConstrainedMinimizationProblem, LagrangianStore
from cooper.constraints import ConstraintGroup, ConstraintState, ConstraintType
from cooper.constraints import Constraint, ConstraintState, ConstraintType
from cooper.formulations import FormulationType

from . import formulations, multipliers, optim, utils
28 changes: 14 additions & 14 deletions cooper/cmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch

from cooper.constraints import ConstraintGroup, ConstraintMeasurement, ConstraintState
from cooper.constraints import Constraint, ConstraintMeasurement, ConstraintState

# Formulation, and some other classes below, are inspired by the design of the
# TensorFlow Constrained Optimization (TFCO) library:
Expand Down Expand Up @@ -48,7 +48,7 @@ class CMPState:
def __init__(
self,
loss: Optional[torch.Tensor] = None,
observed_constraints: Sequence[tuple[ConstraintGroup, ConstraintState]] = (),
observed_constraints: Sequence[tuple[Constraint, ConstraintState]] = (),
misc: Optional[dict] = None,
):
self.loss = loss
Expand Down Expand Up @@ -87,13 +87,13 @@ def populate_primal_lagrangian(self) -> LagrangianStore:
current_primal_lagrangian = 0.0 if self.loss is None else torch.clone(self.loss)

current_primal_constraint_measurements = []
for constraint_group, constraint_state in contributing_constraints:
primal_constraint_contrib, primal_measurement = constraint_group.compute_constraint_primal_contribution(
for constraint, constraint_state in contributing_constraints:
primal_lagrangian_contribution, primal_measurement = constraint.compute_constraint_primal_contribution(
constraint_state
)
current_primal_constraint_measurements.append(primal_measurement)
if primal_constraint_contrib is not None:
current_primal_lagrangian = current_primal_lagrangian + primal_constraint_contrib
if primal_lagrangian_contribution is not None:
current_primal_lagrangian = current_primal_lagrangian + primal_lagrangian_contribution

# Modify "private" attributes to accumulate Lagrangian values over successive
# calls to `populate_primal_lagrangian`
Expand Down Expand Up @@ -136,19 +136,19 @@ def populate_dual_lagrangian(self) -> LagrangianStore:
current_dual_lagrangian = 0.0

current_dual_constraint_measurements = []
for constraint_group, constraint_state in contributing_constraints:
dual_lagrangian_contrib, dual_measurement = constraint_group.compute_constraint_dual_contribution(
for constraint, constraint_state in contributing_constraints:
dual_lagrangian_contribution, dual_measurement = constraint.compute_constraint_dual_contribution(
constraint_state
)
current_dual_constraint_measurements.append(dual_measurement)
if dual_lagrangian_contrib is not None:
current_dual_lagrangian = current_dual_lagrangian + dual_lagrangian_contrib
if dual_lagrangian_contribution is not None:
current_dual_lagrangian = current_dual_lagrangian + dual_lagrangian_contribution

# Extracting the violation from the dual_constraint_measurement ensures that it is
# the "strict" violation, if available.
_, strict_constraint_features = constraint_state.extract_constraint_features()
constraint_group.update_strictly_feasible_indices_(
strict_violation=dual_lagrangian_contrib.violation,
constraint.update_strictly_feasible_indices_(
strict_violation=dual_measurement.violation,
strict_constraint_features=strict_constraint_features,
)

Expand Down Expand Up @@ -234,8 +234,8 @@ def backward(self) -> None:

def __repr__(self) -> str:
_string = f"CMPState(\n loss={self.loss},\n observed_constraints=["
for constraint_group, constraint_state in self.observed_constraints:
_string += f"\n\t{constraint_group} -> {constraint_state},"
for constraint, constraint_state in self.observed_constraints:
_string += f"\n\t{constraint} -> {constraint_state},"
_string += f"\n ]\n misc={self.misc}\n)"
return _string

Expand Down
2 changes: 1 addition & 1 deletion cooper/constraints/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .constraint_group import ConstraintGroup
from .constraint import Constraint
from .constraint_state import ConstraintMeasurement, ConstraintState, ConstraintType
from .slacks import ConstantSlack, DenseSlack, ExplicitSlack, IndexedSlack, SlackVariable
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from cooper.multipliers import IndexedMultiplier, Multiplier, PenaltyCoefficient


class ConstraintGroup:
"""Constraint Group."""
class Constraint:
"""Constraint."""

# TODO(gallego-posada): Add documentation

Expand Down Expand Up @@ -50,17 +50,17 @@ def sanity_check_multiplier(self, multiplier: Multiplier, constraint_type: Const
if multiplier.constraint_type != constraint_type:
raise ValueError(
f"Constraint type of provided multiplier is {multiplier.constraint_type} \
which is inconsistent with {constraint_type} set for the constraint group."
which is inconsistent with {constraint_type} set for the constraint."
)

def sanity_check_penalty_coefficient(self, penalty_coefficient: PenaltyCoefficient) -> None:
if torch.any(penalty_coefficient.value < 0):
raise ValueError("All entries of the penalty coefficient must be non-negative.")

def update_penalty_coefficient(self, constraint_state: ConstraintState) -> None:
"""Update the penalty coefficient of the constraint group."""
"""Update the penalty coefficient of the constraint."""
if self.penalty_coefficient is None:
raise ValueError("Constraint group does not have a penalty coefficient.")
raise ValueError("Constraint does not have a penalty coefficient.")
else:
self.penalty_coefficient.update_value(
constraint_state=constraint_state,
Expand Down Expand Up @@ -112,32 +112,8 @@ def update_strictly_feasible_indices_(

self.multiplier.strictly_feasible_indices = strictly_feasible_indices

def state_dict(self):
state_dict = {"constraint_type": self.constraint_type, "formulation": self.formulation.state_dict()}
for attr_name, attr in [("multiplier", self.multiplier), ("penalty_coefficient", self.penalty_coefficient)]:
state_dict[attr_name] = attr.state_dict() if attr is not None else None
return state_dict

def load_state_dict(self, state_dict):
self.constraint_type = state_dict["constraint_type"]
self.formulation.load_state_dict(state_dict["formulation"])

if state_dict["multiplier"] is not None and self.multiplier is None:
raise ValueError("Cannot load multiplier state dict since existing multiplier is `None`.")
elif state_dict["multiplier"] is None and self.multiplier is not None:
raise ValueError("Multiplier exists but state dict is `None`.")
elif state_dict["multiplier"] is not None and self.multiplier is not None:
self.multiplier.load_state_dict(state_dict["multiplier"])

if state_dict["penalty_coefficient"] is not None and self.penalty_coefficient is None:
raise ValueError("Cannot load penalty_coefficient state dict since existing penalty_coefficient is `None`.")
elif state_dict["penalty_coefficient"] is None and self.penalty_coefficient is not None:
raise ValueError("Penalty coefficient exists but state dict is `None`.")
elif state_dict["penalty_coefficient"] is not None and self.penalty_coefficient is not None:
self.penalty_coefficient.load_state_dict(state_dict["penalty_coefficient"])

def __repr__(self):
repr = f"ConstraintGroup(constraint_type={self.constraint_type}, formulation={self.formulation}"
repr = f"Constraint(constraint_type={self.constraint_type}, formulation={self.formulation}"
if self.multiplier is not None:
repr += f", multiplier={self.multiplier}"
if self.penalty_coefficient is not None:
Expand Down
6 changes: 3 additions & 3 deletions cooper/constraints/constraint_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@ class ConstraintType(Enum):

@dataclass
class ConstraintState:
"""State of a constraint group describing the current constraint violation.
"""State of a constraint describing the current constraint violation.

Args:
violation: Measurement of the constraint violation at some value of the primal
parameters. This is expected to be differentiable with respect to the
primal parameters.
constraint_features: The features of the (differentiable) constraint. This is
used to evaluate the Lagrange multiplier associated with a constraint group.
used to evaluate the Lagrange multiplier associated with a constraint.
For example, an `IndexedMultiplier` expects the indices of the constraints
whose Lagrange multipliers are to be retrieved; while an
`ImplicitMultiplier` expects general tensor-valued features for the
constraints. This field is not used for `DenseMultiplier`\\s.
This can be used in conjunction with an `IndexedMultiplier` to indicate the
measurement of the violation for only a subset of the constraints within a
`ConstraintGroup`.
`Constraint`.
strict_violation: Measurement of the constraint violation which may be
non-differentiable with respect to the primal parameters. When provided,
the (necessarily differentiable) `violation` is used to compute the gradient
Expand Down
10 changes: 5 additions & 5 deletions cooper/constraints/slacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class ExplicitSlack(SlackVariable):
"""
An explicit slack holds a :py:class:`~torch.nn.parameter.Parameter` which contains
(explicitly) the value of the slack variable with a
:py:class:`~cooper.constraints.ConstraintGroup` in a
:py:class:`~cooper.constraints.Constraint` in a
:py:class:`~cooper.cmp.ConstrainedMinimizationProblem`.

Args:
Expand Down Expand Up @@ -105,10 +105,10 @@ class DenseSlack(ExplicitSlack):
"""Simplest kind of trainable slack variable.

:py:class:`~cooper.constraints.slacks.DenseSlack`\\s are suitable for low to
mid-scale :py:class:`~cooper.constraints.ConstraintGroup`\\s for which all the
mid-scale :py:class:`~cooper.constraints.Constraint`\\s for which all the
constraints in the group are measured constantly.

For large-scale :py:class:`~cooper.constraints.ConstraintGroup`\\s (for example,
For large-scale :py:class:`~cooper.constraints.Constraint`\\s (for example,
one constraint per training example) you may consider using an
:py:class:`~cooper.constraints.slacks.IndexedSlack`.
"""
Expand All @@ -121,12 +121,12 @@ def forward(self):
class IndexedSlack(ExplicitSlack):
"""Indexed slacks extend the functionality of
:py:class:`~cooper.constraints.slacks.DenseSlack`\\s to cases where the number of
constraints in the :py:class:`~cooper.constraints.ConstraintGroup` is too large.
constraints in the :py:class:`~cooper.constraints.Constraint` is too large.
This situation may arise, for example, when imposing point-wise constraints over all
the training samples in a learning task.

In such cases, it might be computationally prohibitive to measure the value for all
the constraints in the :py:class:`~cooper.constraints.ConstraintGroup` and one may
the constraints in the :py:class:`~cooper.constraints.Constraint` and one may
typically resort to sampling. :py:class:`~cooper.constraints.slacks.IndexedSlack`\\s
enable time-efficient retrieval of the slack variables for the sampled constraints
only, and memory-efficient sparse gradients (on GPU).
Expand Down
6 changes: 3 additions & 3 deletions cooper/formulations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def compute_primal_weighted_violation(

Args:
constraint_factor_value: The value of the multiplier or penalty coefficient for the
constraint group.
constraint.
violation: Tensor of constraint violations.
"""

Expand Down Expand Up @@ -57,7 +57,7 @@ def compute_dual_weighted_violation(
Bertsekas (2016).

Args:
multiplier_value: The value of the multiplier for the constraint group.
multiplier_value: The value of the multiplier for the constraint.
violation: Tensor of constraint violations.
penalty_coefficient_value: Tensor of penalty coefficient values.
"""
Expand Down Expand Up @@ -92,7 +92,7 @@ def compute_quadratic_augmented_contribution(
constraint_type: ConstraintType,
) -> Optional[torch.Tensor]:
r"""
Computes the quadratic penalty for a constraint group.
Computes the quadratic penalty for a constraint.

When the constraint is an inequality constraint, the quadratic penalty is computed
following Eq 17.65 in Numerical Optimization by Nocedal and Wright (2006). Denoting
Expand Down
4 changes: 2 additions & 2 deletions cooper/multipliers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def evaluate_constraint_factor(
module: ConstraintFactor, constraint_features: torch.Tensor, violation: torch.Tensor
) -> torch.Tensor:
"""Evaluate the Lagrange multiplier or penalty coefficient associated with a
constraint group.
constraint.

Args:
module: Multiplier or penalty coefficient module.
Expand All @@ -35,7 +35,7 @@ def evaluate_constraint_factor(
if not value.requires_grad and value.numel() == 1 and violation.numel() > 1:
# Expand the value of the penalty coefficient to match the shape of the violation.
# This enables the use of a single penalty coefficient for all constraints in a
# constraint group.
# constraint.
# We only do this for penalty coefficients and not multipliers because we expect
# a one-to-one mapping between multiplier values and constraints. If multiplier
# sharing is desired, this should be done explicitly by the user.
Expand Down
14 changes: 7 additions & 7 deletions cooper/multipliers/multipliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ class ExplicitMultiplier(Multiplier):
"""
An explicit multiplier holds a :py:class:`~torch.nn.parameter.Parameter` which
contains (explicitly) the value of the Lagrange multipliers associated with a
:py:class:`~cooper.constraints.ConstraintGroup` in a
:py:class:`~cooper.constraints.Constraint` in a
:py:class:`~cooper.cmp.ConstrainedMinimizationProblem`.

.. warning::
When `restart_on_feasible=True`, the entries of the multiplier which correspond
to feasible constraints in the :py:class:`~cooper.constraints.ConstraintGroup`
to feasible constraints in the :py:class:`~cooper.constraints.Constraint`
are reset to a default value (typically zero) by the
:py:meth:`~cooper.multipliers.ExplicitMultiplier.post_step_` method. Note that
we do **not** perform any modification to the dual optimizer associated with
Expand Down Expand Up @@ -163,10 +163,10 @@ class DenseMultiplier(ExplicitMultiplier):
"""Simplest kind of trainable Lagrange multiplier.

:py:class:`~cooper.multipliers.DenseMultiplier`\\s are suitable for low to mid-scale
:py:class:`~cooper.constraints.ConstraintGroup`\\s for which all the constraints
:py:class:`~cooper.constraints.Constraint`\\s for which all the constraints
in the group are measured constantly.

For large-scale :py:class:`~cooper.constraints.ConstraintGroup`\\s (for example,
For large-scale :py:class:`~cooper.constraints.Constraint`\\s (for example,
one constraint per training example) you may consider using an
:py:class:`~cooper.multipliers.IndexedMultiplier`.
"""
Expand All @@ -182,12 +182,12 @@ def __repr__(self):
class IndexedMultiplier(ExplicitMultiplier):
"""Indexed multipliers extend the functionality of
:py:class:`~cooper.multipliers.DenseMultiplier`\\s to cases where the number of
constraints in the :py:class:`~cooper.constraints.ConstraintGroup` is too large.
constraints in the :py:class:`~cooper.constraints.Constraint` is too large.
This situation may arise, for example, when imposing point-wise constraints over all
the training samples in a learning task.

In such cases, it might be computationally prohibitive to measure the value for all
the constraints in the :py:class:`~cooper.constraints.ConstraintGroup` and one may
the constraints in the :py:class:`~cooper.constraints.Constraint` and one may
typically resort to sampling. :py:class:`~cooper.multipliers.IndexedMultiplier`\\s
enable time-efficient retrieval of the multipliers for the sampled constraints only,
and memory-efficient sparse gradients (on GPU).
Expand Down Expand Up @@ -244,7 +244,7 @@ def __repr__(self):
class ImplicitMultiplier(Multiplier):
"""An implicit multiplier is a :py:class:`~torch.nn.Module` that computes the value
of a Lagrange multiplier associated with a
:py:class:`~cooper.constraints.ConstraintGroup` based on "features" for each
:py:class:`~cooper.constraints.Constraint` based on "features" for each
constraint. The multiplier is _implicitly_ represented by the features of its
associated constraint as well as the computation that takes place in the
:py:meth:`~cooper.multipliers.ImplicitMultiplier.forward` method.
Expand Down
8 changes: 4 additions & 4 deletions cooper/optim/constrained_optimizers/alternating_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,17 @@ def step(self):
pass

def update_penalty_coefficients(self, cmp_state: CMPState) -> None:
"""Update the penalty coefficients of the constraint groups. Only the penalty
"""Update the penalty coefficients of the constraints. Only the penalty
coefficients of constraints that use the ``FormulationType.AUGMENTED_LAGRANGIAN``
formulation and whose state ``contributes_to_dual_update`` are updated.
"""
for constraint_group, constraint_state in cmp_state.observed_constraints:
if constraint_group.formulation_type == FormulationType.AUGMENTED_LAGRANGIAN:
for constraint, constraint_state in cmp_state.observed_constraints:
if constraint.formulation_type == FormulationType.AUGMENTED_LAGRANGIAN:
# We might reach this point via an AugmentedLagrangianOptimizer acting
# on some constraints that do not use an Augmented Lagrangian formulation,
# so we do _not_ apply penalty coefficient updates to those.
if constraint_state.contributes_to_dual_update:
constraint_group.update_penalty_coefficient(constraint_state=constraint_state)
constraint.update_penalty_coefficient(constraint_state=constraint_state)


class AlternatingPrimalDualOptimizer(BaseAlternatingOptimizer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class ConstrainedOptimizer:
dual_optimizers: Optimizer(s) for the dual variables (e.g. the Lagrange
multipliers associated with the constraints). An iterable of
``torch.optim.Optimizer``\\s can be passed to handle the case of several
``~cooper.constraints.ConstraintGroup``\\s. If dealing with an unconstrained
``~cooper.constraints.Constraint``\\s. If dealing with an unconstrained
problem, please use a
:py:class:`~cooper.optim.cooper_optimizer.UnconstrainedOptimizer` instead.

Expand Down
4 changes: 2 additions & 2 deletions cooper/optim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def load_cooper_optimizer_from_state_dict(
):
"""Creates a Cooper optimizer and loads the state_dicts contained in a
:py:class:`~cooper.optim.CooperOptimizerState` onto instantiated primal and dual
optimizers and constraint groups or multipliers.
optimizers and constraints or multipliers.
"""

# Load primal optimizers
Expand Down Expand Up @@ -104,7 +104,7 @@ def load_cooper_optimizer_from_state_dict(
for multiplier, multiplier_state in zip(multipliers, multiplier_states):
multiplier.load_state_dict(multiplier_state)

# Since we have extracted the multiplier information above, we discard the constraint_groups below
# Since we have extracted the multiplier information above, we discard the constraints below
return create_optimizer_from_kwargs(
primal_optimizers=primal_optimizers,
extrapolation=cooper_optimizer_state.extrapolation,
Expand Down
Loading