From 8d699679434e6e13590d757f1a4be9ed391abd73 Mon Sep 17 00:00:00 2001 From: Meraj Hashemizadeh Date: Wed, 13 Mar 2024 11:25:22 -0400 Subject: [PATCH 1/3] Refactor ConstraintStore Rename ConstraintStore to ConstraintMeasurement Remove lagrangian contribution from ConstraintMeasurement --- cooper/cmp.py | 74 +++---- cooper/constraints/__init__.py | 2 +- cooper/constraints/constraint_group.py | 10 +- cooper/constraints/constraint_state.py | 7 +- cooper/formulations/formulations.py | 189 +++++++++--------- .../alternating_optimizer.py | 12 +- 6 files changed, 151 insertions(+), 143 deletions(-) diff --git a/cooper/cmp.py b/cooper/cmp.py index 1c1a4905..09f12a23 100644 --- a/cooper/cmp.py +++ b/cooper/cmp.py @@ -5,7 +5,7 @@ import torch -from cooper.constraints import ConstraintGroup, ConstraintState, ConstraintStore +from cooper.constraints import ConstraintGroup, ConstraintMeasurement, ConstraintState # Formulation, and some other classes below, are inspired by the design of the # TensorFlow Constrained Optimization (TFCO) library: @@ -19,14 +19,14 @@ class LagrangianStore: lagrangian: torch.Tensor dual_lagrangian: Optional[torch.Tensor] = None - primal_constraint_stores: Optional[list[ConstraintStore]] = None - dual_constraint_stores: Optional[list[ConstraintStore]] = None + primal_constraint_measurements: Optional[list[ConstraintMeasurement]] = None + dual_constraint_measurements: Optional[list[ConstraintMeasurement]] = None def multiplier_values_for_primal_constraints(self): - return [_.multiplier_value for _ in self.primal_constraint_stores] + return [_.multiplier_value for _ in self.primal_constraint_measurements] def multiplier_values_for_dual_constraints(self): - return [_.multiplier_value for _ in self.dual_constraint_stores] + return [_.multiplier_value for _ in self.dual_constraint_measurements] class CMPState: @@ -56,9 +56,9 @@ def __init__( self.misc = misc self._primal_lagrangian = None - self._primal_constraint_stores = [] + self._primal_constraint_measurements = [] self._dual_lagrangian = None - self._dual_constraint_stores = [] + self._dual_constraint_measurements = [] def populate_primal_lagrangian(self) -> LagrangianStore: """Computes and accumulates the primal-differentiable Lagrangian based on the @@ -77,8 +77,8 @@ def populate_primal_lagrangian(self) -> LagrangianStore: return LagrangianStore( lagrangian=self._primal_lagrangian, dual_lagrangian=self._dual_lagrangian, - primal_constraint_stores=self._primal_constraint_stores, - dual_constraint_stores=self._dual_constraint_stores, + primal_constraint_measurements=self._primal_constraint_measurements, + dual_constraint_measurements=self._dual_constraint_measurements, ) # Either a loss was provided, or at least one observed constraint contributes to @@ -86,26 +86,29 @@ def populate_primal_lagrangian(self) -> LagrangianStore: previous_primal_lagrangian = 0.0 if self._primal_lagrangian is None else self._primal_lagrangian current_primal_lagrangian = 0.0 if self.loss is None else torch.clone(self.loss) - current_primal_constraint_stores = [] + current_primal_constraint_measurements = [] for constraint_group, constraint_state in contributing_constraints: - primal_constraint_store = constraint_group.compute_constraint_primal_contribution(constraint_state) - current_primal_constraint_stores.append(primal_constraint_store) - if primal_constraint_store is not None: - current_primal_lagrangian = current_primal_lagrangian + primal_constraint_store.lagrangian_contribution + ( + primal_lagrangian_contribution, + primal_constraint_measurement, + ) = constraint_group.compute_constraint_primal_contribution(constraint_state) + current_primal_constraint_measurements.append(primal_constraint_measurement) + if primal_lagrangian_contribution is not None: + current_primal_lagrangian = current_primal_lagrangian + primal_lagrangian_contribution # Modify "private" attributes to accumulate Lagrangian values over successive # calls to `populate_primal_lagrangian` self._primal_lagrangian = previous_primal_lagrangian + current_primal_lagrangian - self._primal_constraint_stores.extend(current_primal_constraint_stores) + self._primal_constraint_measurements.extend(current_primal_constraint_measurements) # We return any existent values for the _dual_lagrangian, and the - # _dual_constraint_stores. The _primal_lagrangian and _primal_constraint_stores + # _dual_constraint_measurements. The _primal_lagrangian and _primal_constraint_measurements # attributes have been modified earlier, so their updated values are returned. lagrangian_store = LagrangianStore( lagrangian=self._primal_lagrangian, dual_lagrangian=self._dual_lagrangian, - primal_constraint_stores=self._primal_constraint_stores, - dual_constraint_stores=self._dual_constraint_stores, + primal_constraint_measurements=self._primal_constraint_measurements, + dual_constraint_measurements=self._dual_constraint_measurements, ) return lagrangian_store @@ -125,42 +128,45 @@ def populate_dual_lagrangian(self) -> LagrangianStore: return LagrangianStore( lagrangian=self._primal_lagrangian, dual_lagrangian=self._dual_lagrangian, - primal_constraint_stores=self._primal_constraint_stores, - dual_constraint_stores=self._dual_constraint_stores, + primal_constraint_measurements=self._primal_constraint_measurements, + dual_constraint_measurements=self._dual_constraint_measurements, ) # At least one observed constraint contributes to the dual Lagrangian. previous_dual_lagrangian = 0.0 if self._dual_lagrangian is None else self._dual_lagrangian current_dual_lagrangian = 0.0 - current_dual_constraint_stores = [] + current_dual_constraint_measurements = [] for constraint_group, constraint_state in contributing_constraints: - dual_constraint_store = constraint_group.compute_constraint_dual_contribution(constraint_state) - current_dual_constraint_stores.append(dual_constraint_store) - if dual_constraint_store is not None: - current_dual_lagrangian = current_dual_lagrangian + dual_constraint_store.lagrangian_contribution - - # Extracting the violation from the dual_constraint_store ensures that it is + ( + dual_lagrangian_contribution, + dual_constraint_measurement, + ) = constraint_group.compute_constraint_dual_contribution(constraint_state) + current_dual_constraint_measurements.append(dual_constraint_measurement) + if dual_lagrangian_contribution is not None: + current_dual_lagrangian = current_dual_lagrangian + dual_lagrangian_contribution + + # Extracting the violation from the dual_constraint_measurement ensures that it is # the "strict" violation, if available. _, strict_constraint_features = constraint_state.extract_constraint_features() constraint_group.update_strictly_feasible_indices_( - strict_violation=dual_constraint_store.violation, + strict_violation=dual_constraint_measurement.violation, strict_constraint_features=strict_constraint_features, ) # Modify "private" attributes to accumulate Lagrangian values over successive # calls to `populate_dual_lagrangian` self._dual_lagrangian = previous_dual_lagrangian + current_dual_lagrangian - self._dual_constraint_stores.extend(current_dual_constraint_stores) + self._dual_constraint_measurements.extend(current_dual_constraint_measurements) # We return any existent values for the _primal_lagrangian, and the - # _primal_constraint_stores. The _dual_lagrangian and _dual_constraint_stores + # _primal_constraint_measurements. The _dual_lagrangian and _dual_constraint_measurements # attributes have been modified earlier, so their updated values are returned. lagrangian_store = LagrangianStore( lagrangian=self._primal_lagrangian, dual_lagrangian=self._dual_lagrangian, - primal_constraint_stores=self._primal_constraint_stores, - dual_constraint_stores=self._dual_constraint_stores, + primal_constraint_measurements=self._primal_constraint_measurements, + dual_constraint_measurements=self._dual_constraint_measurements, ) return lagrangian_store @@ -192,12 +198,12 @@ def populate_lagrangian(self) -> LagrangianStore: def purge_primal_lagrangian(self) -> None: """Purge the accumulated primal Lagrangian contributions.""" self._primal_lagrangian = None - self._primal_constraint_stores = [] + self._primal_constraint_measurements = [] def purge_dual_lagrangian(self) -> None: """Purge the accumulated dual Lagrangian contributions.""" self._dual_lagrangian = None - self._dual_constraint_stores = [] + self._dual_constraint_measurements = [] def purge_lagrangian(self) -> None: """Purge the accumulated Lagrangian contributions.""" diff --git a/cooper/constraints/__init__.py b/cooper/constraints/__init__.py index d6a35315..18d3c759 100644 --- a/cooper/constraints/__init__.py +++ b/cooper/constraints/__init__.py @@ -1,3 +1,3 @@ from .constraint_group import ConstraintGroup -from .constraint_state import ConstraintState, ConstraintStore, ConstraintType +from .constraint_state import ConstraintMeasurement, ConstraintState, ConstraintType from .slacks import ConstantSlack, DenseSlack, ExplicitSlack, IndexedSlack, SlackVariable diff --git a/cooper/constraints/constraint_group.py b/cooper/constraints/constraint_group.py index fa015074..99810d82 100644 --- a/cooper/constraints/constraint_group.py +++ b/cooper/constraints/constraint_group.py @@ -3,7 +3,7 @@ import torch from cooper import multipliers -from cooper.constraints.constraint_state import ConstraintState, ConstraintStore, ConstraintType +from cooper.constraints.constraint_state import ConstraintMeasurement, ConstraintState, ConstraintType from cooper.formulations import FormulationType from cooper.multipliers import IndexedMultiplier, Multiplier, PenaltyCoefficient @@ -78,12 +78,16 @@ def prepare_kwargs_for_lagrangian_contribution(self, constraint_state: Constrain return kwargs - def compute_constraint_primal_contribution(self, constraint_state: ConstraintState) -> ConstraintStore: + def compute_constraint_primal_contribution( + self, constraint_state: ConstraintState + ) -> tuple[Optional[torch.Tensor], Optional[ConstraintMeasurement]]: """Compute the contribution of the current constraint to the primal Lagrangian.""" kwargs = self.prepare_kwargs_for_lagrangian_contribution(constraint_state=constraint_state) return self.formulation.compute_contribution_for_primal_lagrangian(**kwargs) - def compute_constraint_dual_contribution(self, constraint_state: ConstraintState) -> ConstraintStore: + def compute_constraint_dual_contribution( + self, constraint_state: ConstraintState + ) -> tuple[Optional[torch.Tensor], Optional[ConstraintMeasurement]]: """Compute the contribution of the current constraint to the dual Lagrangian.""" kwargs = self.prepare_kwargs_for_lagrangian_contribution(constraint_state=constraint_state) return self.formulation.compute_contribution_for_dual_lagrangian(**kwargs) diff --git a/cooper/constraints/constraint_state.py b/cooper/constraints/constraint_state.py index cd9e4a9d..5d770376 100644 --- a/cooper/constraints/constraint_state.py +++ b/cooper/constraints/constraint_state.py @@ -121,14 +121,13 @@ def compute_strictly_feasible_constraints(self) -> tuple[torch.Tensor, torch.Ten @dataclass -class ConstraintStore: - # TODO: update docstring. Current ConstraintStore is agnostic to dual or primal +class ConstraintMeasurement: + # TODO: update docstring. Current ConstraintMeasurement is agnostic to dual or primal # lagrangian. """Stores the value of the constraint factor (multiplier or penalty coefficient), - the contribution of the constraint to the primal-differentiable Lagrian, and the + the contribution of the constraint to the primal-differentiable Lagrangian, and the contribution of the constraint to the dual-differentiable Lagrangian.""" - lagrangian_contribution: Optional[torch.Tensor] = None violation: Optional[torch.Tensor] = None multiplier_value: Optional[torch.Tensor] = None penalty_coefficient_value: Optional[torch.Tensor] = None diff --git a/cooper/formulations/formulations.py b/cooper/formulations/formulations.py index bdd2e014..8ef4f54d 100644 --- a/cooper/formulations/formulations.py +++ b/cooper/formulations/formulations.py @@ -1,7 +1,10 @@ import abc +from typing import Optional + +import torch import cooper.formulations.utils as formulation_utils -from cooper.constraints.constraint_state import ConstraintState, ConstraintStore, ConstraintType +from cooper.constraints.constraint_state import ConstraintMeasurement, ConstraintState, ConstraintType from cooper.multipliers import Multiplier, PenaltyCoefficient, evaluate_constraint_factor @@ -36,7 +39,7 @@ def __init__(self, constraint_type: ConstraintType): def compute_contribution_for_primal_lagrangian( self, constraint_state: ConstraintState, penalty_coefficient: PenaltyCoefficient - ) -> ConstraintStore: + ) -> ConstraintMeasurement: if not constraint_state.contributes_to_primal_update: return None else: @@ -49,7 +52,7 @@ def compute_contribution_for_primal_lagrangian( weighted_violation = formulation_utils.compute_primal_weighted_violation( constraint_factor_value=penalty_coefficient_value, violation=violation ) - primal_constraint_store = ConstraintStore( + primal_constraint_store = ConstraintMeasurement( violation=violation, penalty_coefficient_value=penalty_coefficient_value, lagrangian_contribution=weighted_violation, @@ -79,7 +82,7 @@ def __init__(self, constraint_type: ConstraintType): def compute_contribution_for_primal_lagrangian( self, constraint_state: ConstraintState, penalty_coefficient: PenaltyCoefficient - ) -> ConstraintStore: + ) -> ConstraintMeasurement: if not constraint_state.contributes_to_primal_update: return None @@ -101,7 +104,7 @@ def compute_contribution_for_primal_lagrangian( constraint_type=self.constraint_type, ) - primal_constraint_store = ConstraintStore( + primal_constraint_store = ConstraintMeasurement( violation=violation, penalty_coefficient_value=penalty_coefficient_value, lagrangian_contribution=weighted_violation, @@ -127,51 +130,47 @@ def __init__(self, constraint_type: ConstraintType): def compute_contribution_for_primal_lagrangian( self, constraint_state: ConstraintState, multiplier: Multiplier - ) -> ConstraintStore: + ) -> tuple[Optional[torch.Tensor], Optional[ConstraintMeasurement]]: if not constraint_state.contributes_to_primal_update: - return None - else: - - violation, strict_violation = constraint_state.extract_violations() - constraint_features, strict_constraint_features = constraint_state.extract_constraint_features() - multiplier_value = evaluate_constraint_factor( - module=multiplier, violation=violation, constraint_features=constraint_features - ) - weighted_violation = formulation_utils.compute_primal_weighted_violation( - constraint_factor_value=multiplier_value, violation=violation - ) - primal_constraint_store = ConstraintStore( - multiplier_value=multiplier_value, - violation=violation, - lagrangian_contribution=weighted_violation, - ) - - return primal_constraint_store + return None, None + + violation, strict_violation = constraint_state.extract_violations() + constraint_features, strict_constraint_features = constraint_state.extract_constraint_features() + multiplier_value = evaluate_constraint_factor( + module=multiplier, violation=violation, constraint_features=constraint_features + ) + lagrangian_contribution = formulation_utils.compute_primal_weighted_violation( + constraint_factor_value=multiplier_value, violation=violation + ) + primal_constraint_store = ConstraintMeasurement( + multiplier_value=multiplier_value, + violation=violation, + ) + + return lagrangian_contribution, primal_constraint_store def compute_contribution_for_dual_lagrangian( self, constraint_state: ConstraintState, multiplier: Multiplier - ) -> tuple[ConstraintStore, ConstraintStore]: + ) -> tuple[Optional[torch.Tensor], Optional[ConstraintMeasurement]]: if not constraint_state.contributes_to_dual_update: - return None - else: - - violation, strict_violation = constraint_state.extract_violations() - constraint_features, strict_constraint_features = constraint_state.extract_constraint_features() - multiplier_value = evaluate_constraint_factor( - module=multiplier, violation=strict_violation, constraint_features=strict_constraint_features - ) - weighted_violation = formulation_utils.compute_dual_weighted_violation( - constraint_factor_value=multiplier_value, violation=strict_violation - ) - dual_constraint_store = ConstraintStore( - multiplier_value=multiplier_value, - violation=strict_violation, - lagrangian_contribution=weighted_violation, - ) - - return dual_constraint_store + return None, None + + violation, strict_violation = constraint_state.extract_violations() + constraint_features, strict_constraint_features = constraint_state.extract_constraint_features() + multiplier_value = evaluate_constraint_factor( + module=multiplier, violation=strict_violation, constraint_features=strict_constraint_features + ) + lagrangian_contribution = formulation_utils.compute_dual_weighted_violation( + constraint_factor_value=multiplier_value, violation=strict_violation + ) + dual_constraint_store = ConstraintMeasurement( + multiplier_value=multiplier_value, + violation=strict_violation, + ) + + return lagrangian_contribution, dual_constraint_store def __repr__(self): return f"LagrangianFormulation(constraint_type={self.constraint_type})" @@ -214,67 +213,63 @@ def __init__( def compute_contribution_for_primal_lagrangian( self, constraint_state: ConstraintState, multiplier: Multiplier, penalty_coefficient: PenaltyCoefficient - ) -> ConstraintStore: + ) -> tuple[Optional[torch.Tensor], Optional[ConstraintMeasurement]]: if not constraint_state.contributes_to_primal_update: - return None - else: - - violation, strict_violation = constraint_state.extract_violations() - constraint_features, strict_constraint_features = constraint_state.extract_constraint_features() - multiplier_value = evaluate_constraint_factor( - module=multiplier, violation=violation, constraint_features=constraint_features - ) - penalty_coefficient_value = evaluate_constraint_factor( - module=penalty_coefficient, violation=violation, constraint_features=constraint_features - ) - - augmented_weighted_violation = formulation_utils.compute_quadratic_augmented_contribution( - multiplier_value=multiplier_value, - penalty_coefficient_value=penalty_coefficient_value, - violation=violation, - constraint_type=self.constraint_type, - ) - - primal_constraint_store = ConstraintStore( - lagrangian_contribution=augmented_weighted_violation, - violation=violation, - multiplier_value=multiplier_value, - penalty_coefficient_value=penalty_coefficient_value, - ) - - return primal_constraint_store + return None, None + + violation, strict_violation = constraint_state.extract_violations() + constraint_features, strict_constraint_features = constraint_state.extract_constraint_features() + multiplier_value = evaluate_constraint_factor( + module=multiplier, violation=violation, constraint_features=constraint_features + ) + penalty_coefficient_value = evaluate_constraint_factor( + module=penalty_coefficient, violation=violation, constraint_features=constraint_features + ) + + lagrangian_contribution = formulation_utils.compute_quadratic_augmented_contribution( + multiplier_value=multiplier_value, + penalty_coefficient_value=penalty_coefficient_value, + violation=violation, + constraint_type=self.constraint_type, + ) + + primal_constraint_store = ConstraintMeasurement( + violation=violation, + multiplier_value=multiplier_value, + penalty_coefficient_value=penalty_coefficient_value, + ) + + return lagrangian_contribution, primal_constraint_store def compute_contribution_for_dual_lagrangian( self, constraint_state: ConstraintState, multiplier: Multiplier, penalty_coefficient: PenaltyCoefficient - ) -> ConstraintStore: + ) -> tuple[Optional[torch.Tensor], Optional[ConstraintMeasurement]]: if not constraint_state.contributes_to_dual_update: - return None - else: - - violation, strict_violation = constraint_state.extract_violations() - constraint_features, strict_constraint_features = constraint_state.extract_constraint_features() - multiplier_value = evaluate_constraint_factor( - module=multiplier, violation=strict_violation, constraint_features=strict_constraint_features - ) - - # TODO: why does evaluate_constraint_factor use violation instead of strict_violation? - penalty_coefficient_value = evaluate_constraint_factor( - module=penalty_coefficient, violation=violation, constraint_features=constraint_features - ) - weighted_violation = formulation_utils.compute_dual_weighted_violation( - constraint_factor_value=multiplier_value, - violation=strict_violation, - penalty_coefficient_value=penalty_coefficient_value, - ) - dual_constraint_store = ConstraintStore( - lagrangian_contribution=weighted_violation, - violation=strict_violation, - multiplier_value=multiplier_value, - ) - - return dual_constraint_store + return None, None + + violation, strict_violation = constraint_state.extract_violations() + constraint_features, strict_constraint_features = constraint_state.extract_constraint_features() + multiplier_value = evaluate_constraint_factor( + module=multiplier, violation=strict_violation, constraint_features=strict_constraint_features + ) + + # TODO: why does evaluate_constraint_factor use violation instead of strict_violation? + penalty_coefficient_value = evaluate_constraint_factor( + module=penalty_coefficient, violation=violation, constraint_features=constraint_features + ) + lagrangian_contribution = formulation_utils.compute_dual_weighted_violation( + constraint_factor_value=multiplier_value, + violation=strict_violation, + penalty_coefficient_value=penalty_coefficient_value, + ) + dual_constraint_store = ConstraintMeasurement( + violation=strict_violation, + multiplier_value=multiplier_value, + ) + + return lagrangian_contribution, dual_constraint_store def __repr__(self): return f"AugmentedLagrangianFormulation(constraint_type={self.constraint_type})" diff --git a/cooper/optim/constrained_optimizers/alternating_optimizer.py b/cooper/optim/constrained_optimizers/alternating_optimizer.py index 1e9670db..9bcafbd9 100644 --- a/cooper/optim/constrained_optimizers/alternating_optimizer.py +++ b/cooper/optim/constrained_optimizers/alternating_optimizer.py @@ -188,8 +188,10 @@ def roll( new_cmp_state.loss = cmp_state.loss assert lagrangian_store_for_dual.lagrangian is None lagrangian_store_for_dual.lagrangian = new_cmp_state.loss + lagrangian_store_for_dual.dual_lagrangian - assert lagrangian_store_for_dual.primal_constraint_stores == [] - lagrangian_store_for_dual.primal_constraint_stores = lagrangian_store_for_primal.primal_constraint_stores + assert lagrangian_store_for_dual.primal_constraint_measurements == [] + lagrangian_store_for_dual.primal_constraint_measurements = ( + lagrangian_store_for_primal.primal_constraint_measurements + ) return new_cmp_state, lagrangian_store_for_dual @@ -248,8 +250,10 @@ def roll(self, compute_cmp_state_fn: Callable[..., CMPState]) -> tuple[CMPState, # Lagrangian estimate. See the docstring for more details. assert lagrangian_store_for_primal.dual_lagrangian is None lagrangian_store_for_primal.dual_lagrangian = lagrangian_store_for_dual.dual_lagrangian - assert lagrangian_store_for_primal.dual_constraint_stores == [] - lagrangian_store_for_primal.dual_constraint_stores = lagrangian_store_for_dual.dual_constraint_stores + assert lagrangian_store_for_primal.dual_constraint_measurements == [] + lagrangian_store_for_primal.dual_constraint_measurements = ( + lagrangian_store_for_dual.dual_constraint_measurements + ) return cmp_state, lagrangian_store_for_primal From ed851f26f90f0cd34371b72899041f8ac27ce127 Mon Sep 17 00:00:00 2001 From: juan43ramirez Date: Wed, 13 Mar 2024 14:27:11 -0400 Subject: [PATCH 2/3] Add TODO in formulations --- cooper/formulations/formulations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cooper/formulations/formulations.py b/cooper/formulations/formulations.py index 8ef4f54d..ffb0b37f 100644 --- a/cooper/formulations/formulations.py +++ b/cooper/formulations/formulations.py @@ -1,3 +1,4 @@ +# TODO(juan43ramirez): File needs to be updated after the switch from ConstraintStore to ConstraintMeasurement import abc from typing import Optional From 0d171fe426c4ca671285a92fd1a2d04f6ca7b16a Mon Sep 17 00:00:00 2001 From: juan43ramirez Date: Wed, 13 Mar 2024 14:27:41 -0400 Subject: [PATCH 3/3] Renaming to improve readability --- cooper/cmp.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/cooper/cmp.py b/cooper/cmp.py index 09f12a23..356fc056 100644 --- a/cooper/cmp.py +++ b/cooper/cmp.py @@ -88,13 +88,12 @@ def populate_primal_lagrangian(self) -> LagrangianStore: current_primal_constraint_measurements = [] for constraint_group, constraint_state in contributing_constraints: - ( - primal_lagrangian_contribution, - primal_constraint_measurement, - ) = constraint_group.compute_constraint_primal_contribution(constraint_state) - current_primal_constraint_measurements.append(primal_constraint_measurement) - if primal_lagrangian_contribution is not None: - current_primal_lagrangian = current_primal_lagrangian + primal_lagrangian_contribution + primal_constraint_contrib, primal_measurement = constraint_group.compute_constraint_primal_contribution( + constraint_state + ) + current_primal_constraint_measurements.append(primal_measurement) + if primal_constraint_contrib is not None: + current_primal_lagrangian = current_primal_lagrangian + primal_constraint_contrib # Modify "private" attributes to accumulate Lagrangian values over successive # calls to `populate_primal_lagrangian` @@ -138,19 +137,18 @@ def populate_dual_lagrangian(self) -> LagrangianStore: current_dual_constraint_measurements = [] for constraint_group, constraint_state in contributing_constraints: - ( - dual_lagrangian_contribution, - dual_constraint_measurement, - ) = constraint_group.compute_constraint_dual_contribution(constraint_state) - current_dual_constraint_measurements.append(dual_constraint_measurement) - if dual_lagrangian_contribution is not None: - current_dual_lagrangian = current_dual_lagrangian + dual_lagrangian_contribution + dual_lagrangian_contrib, dual_measurement = constraint_group.compute_constraint_dual_contribution( + constraint_state + ) + current_dual_constraint_measurements.append(dual_measurement) + if dual_lagrangian_contrib is not None: + current_dual_lagrangian = current_dual_lagrangian + dual_lagrangian_contrib # Extracting the violation from the dual_constraint_measurement ensures that it is # the "strict" violation, if available. _, strict_constraint_features = constraint_state.extract_constraint_features() constraint_group.update_strictly_feasible_indices_( - strict_violation=dual_constraint_measurement.violation, + strict_violation=dual_lagrangian_contrib.violation, strict_constraint_features=strict_constraint_features, )