54 multiplier models #66

Open. Wants to merge 61 commits into base: 54-multiplier-models. The diff below shows changes from 41 of the 61 commits.

Commits (61):
bba4cb1
Create multiplier_model.py
IsitaRex Nov 5, 2022
e134386
Update multiplier_model.py
IsitaRex Nov 5, 2022
88fdb5f
Update logistic_regression.ipynb
IsitaRex Nov 5, 2022
d3cb8d0
Create test_multiplier_model.py
IsitaRex Nov 6, 2022
da5de4a
Update Multiplier Model to inherit from multipliers
IsitaRex Nov 6, 2022
4c8ec04
Update multipliers.py
IsitaRex Nov 7, 2022
c26ce46
Update __init__.py
IsitaRex Nov 7, 2022
0d3fa08
Update __init__.py
IsitaRex Nov 7, 2022
eef8eb6
Update lagrangian_model.py
IsitaRex Nov 7, 2022
d1868ce
Update multiplier_model.py
IsitaRex Nov 7, 2022
cc92995
Update multipliers.py
IsitaRex Nov 7, 2022
812f54b
Create test_lagrangian_model.py
IsitaRex Nov 7, 2022
c4431f5
Update test_multiplier_model.py
IsitaRex Nov 7, 2022
1faa097
Update test_lagrangian_model.py
IsitaRex Nov 7, 2022
6ddb61a
MM and Lagrangian Model work and pass tests
daoterog Nov 8, 2022
fd84696
Correction - Add again proxy constraints to model formulation
daoterog Nov 9, 2022
af39b27
Add proxy constraints to toy problem
daoterog Nov 9, 2022
4eb714d
Fix test_lagrangian_model test and tol from convergence test
daoterog Nov 9, 2022
fcbda77
Fixed convergence tolerance
daoterog Nov 9, 2022
90eed1d
Refactor to integrate large toy problem in build tests
daoterog Nov 9, 2022
883c5c9
Add flip gradients to formulations
daoterog Nov 10, 2022
b3075fe
Add missing argument
daoterog Nov 10, 2022
ad79921
Optimizer pending changes
daoterog Nov 10, 2022
6de574d
Pending todo
daoterog Nov 10, 2022
36abc70
Experiment to compare model form with regular form
daoterog Nov 10, 2022
1fe0fc3
Ignore wandb logs
daoterog Nov 10, 2022
97c9e88
Extend build_test_problem to encompass model form
daoterog Nov 10, 2022
183b955
Rewrite tests for current multiplier model object
daoterog Nov 10, 2022
38f3841
New experiment setting
daoterog Nov 11, 2022
943c46e
Change in third constraint and constraint features
daoterog Nov 12, 2022
c129db4
Black formatting
daoterog Nov 12, 2022
af1811e
Update model_form_vs_lag_form.py
IsitaRex Nov 13, 2022
ff3d9e2
Add mock model formulation
daoterog Nov 14, 2022
c4093bf
Fix by black and isort
IsitaRex Dec 7, 2022
a7426ae
Merge remote-tracking branch 'upstream/dev' into dev
IsitaRex Dec 7, 2022
cc894bf
Merge branch 'dev' into 54-multiplier-models
IsitaRex Dec 7, 2022
26b327e
Remove MockLagrangianModelFormulation
IsitaRex Dec 7, 2022
9e58b63
Black
IsitaRex Dec 7, 2022
6248b6b
Remove wandb from gitignore
IsitaRex Dec 7, 2022
4a80910
Add comment about accessing gradients of module
IsitaRex Dec 7, 2022
3de5a3e
Comment to change grad function name.
IsitaRex Dec 7, 2022
d2b3a35
Add missing documentation
daoterog Dec 9, 2022
695ed73
Correct documentation
daoterog Dec 9, 2022
cceee8c
Modify flip gradients, add flag to multiplier projection
daoterog Feb 3, 2023
de01f2e
Fix circular import and add typehint to grad func
daoterog Feb 3, 2023
3304242
Removed unnecessary TODO
daoterog Feb 3, 2023
92143ab
Simplify object for test
daoterog Feb 3, 2023
a8bc041
Remove comments on alternative implementation
daoterog Feb 3, 2023
a643e1a
Make test more simple
daoterog Feb 3, 2023
cc76eb1
Formatting
daoterog Feb 3, 2023
a7b07b4
Change todo to git username
daoterog Feb 10, 2023
1a61dd5
Add CMPModelState, fix state, state_dict, load_state_dict, and flip_d…
daoterog Feb 10, 2023
6ef151f
Fix and add TODOs
daoterog Feb 10, 2023
1bb3bc3
Add RuntimeError to unused methods
daoterog Feb 10, 2023
fdaff32
Add git username to TODO
daoterog Feb 10, 2023
5cd667a
Add git username to TODO
daoterog Feb 10, 2023
fcbff54
Add CMPModelState to convergence test
daoterog Feb 10, 2023
2d47807
Black formatting
daoterog Feb 10, 2023
96705a5
Pending fix on circular import and correct type hint
daoterog Feb 10, 2023
e8a2b32
Properly implement exhaustive tests with fixtures and parameters
daoterog Feb 11, 2023
4f8b6ed
Updated type hints with CMPModelState
daoterog Feb 11, 2023
1 change: 1 addition & 0 deletions cooper/formulation/__init__.py
@@ -1,3 +1,4 @@
from .augmented_lagrangian import AugmentedLagrangianFormulation
from .formulation import Formulation, UnconstrainedFormulation
from .lagrangian import LagrangianFormulation, ProxyLagrangianFormulation
from .lagrangian_model import LagrangianModelFormulation
10 changes: 10 additions & 0 deletions cooper/formulation/formulation.py
@@ -34,6 +34,11 @@ def state(self):
"""Returns the internal state of formulation (e.g. multipliers)."""
pass

@abc.abstractmethod
def flip_dual_gradients(self):
"""Flips the sign of the dual gradients."""
pass

@property
@abc.abstractmethod
def is_state_created(self):
@@ -119,6 +124,11 @@ def load_state_dict(self, state_dict: dict):
"""
pass

def flip_dual_gradients(self):
Review comment: Remove repeated class.

"""Flips the sign of the dual gradients. This is a no-op for
unconstrained formulations."""
pass

def compute_lagrangian(
self,
closure: Callable[..., CMPState],
8 changes: 8 additions & 0 deletions cooper/formulation/lagrangian.py
@@ -220,6 +220,14 @@ def load_state_dict(self, state_dict: Dict[str, Any]):
), "LagrangianFormulation received unknown key: {}".format(key)
setattr(self, key, val)

def flip_dual_gradients(self):
"""
Flip the sign of the gradients of the dual variables.
"""
for multiplier in self.state():
if multiplier is not None:
multiplier.grad.mul_(-1.0)


class LagrangianFormulation(BaseLagrangianFormulation):
"""
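
For context on flip_dual_gradients: negating the stored gradients lets a standard minimizing optimizer perform gradient *ascent* on the dual variables. A self-contained sketch of the mechanism (toy values, not library code):

import torch

multiplier = torch.tensor([1.0], requires_grad=True)
defect = torch.tensor([0.5])  # constraint violation
dual_optimizer = torch.optim.SGD([multiplier], lr=0.1)

# The dual player wants to *maximize* multiplier * defect. A plain
# backward populates the gradient for minimization...
(multiplier * defect).sum().backward()

# ...so we flip its sign, exactly as flip_dual_gradients does above,
# and a vanilla descent step becomes an ascent step.
multiplier.grad.mul_(-1.0)
dual_optimizer.step()  # multiplier: 1.0 -> 1.05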
273 changes: 273 additions & 0 deletions cooper/formulation/lagrangian_model.py
@@ -0,0 +1,273 @@
from typing import Callable, List, Optional, Tuple, Union, no_type_check

import torch

from cooper.formulation.lagrangian import BaseLagrangianFormulation
from cooper.multipliers import MultiplierModel
from cooper.problem import CMPState


class LagrangianModelFormulation(BaseLagrangianFormulation):
"""
# TODO: document this
Computes the Lagrangian based on the predictions of a `MultiplierModel`.
"""

Review comment: Give a more complete description of the formulation and add the arguments of the init to it.

def __init__(
self,
*args,
ineq_multiplier_model: Optional[MultiplierModel] = None,
eq_multiplier_model: Optional[MultiplierModel] = None,
**kwargs,
):
super().__init__(*args, **kwargs)

self.ineq_multiplier_model = ineq_multiplier_model
self.eq_multiplier_model = eq_multiplier_model

if self.ineq_multiplier_model is None and self.eq_multiplier_model is None:
# TODO: document this
Review comment: Is it really necessary to document this?

raise ValueError("At least one multiplier model must be provided.")

if self.ineq_multiplier_model is not None and not isinstance(
self.ineq_multiplier_model, MultiplierModel
):
raise ValueError("The `ineq_multiplier_model` must be a `MultiplierModel`.")

if self.eq_multiplier_model is not None and not isinstance(
self.eq_multiplier_model, MultiplierModel
):
raise ValueError("The `eq_multiplier_model` must be a `MultiplierModel`.")

def create_state(self):
"""This method is not implemented for this formulation. It originally
instantiates the dual variables, but in this formulation this is done since the
instantiation of the object, since it is necessary to provide a `MultiplerModel`
for each of the contraint types."""
pass

# TODO(IsitaRex): create_state_from_metadata missing
def create_state_from_metadata(self):
pass

@property
def is_state_created(self):
"""
Returns ``True`` if any Lagrange multipliers have been initialized.
"""
return (
self.ineq_multiplier_model is not None
or self.eq_multiplier_model is not None
)

@property
def dual_parameters(self) -> List[torch.Tensor]:
"""Returns a list gathering all dual parameters."""
all_dual_params = []

for mult in [self.ineq_multiplier_model, self.eq_multiplier_model]:
if mult is not None:
all_dual_params.extend(list(mult.parameters()))

return all_dual_params

def state(
self,
eq_features: Optional[torch.Tensor] = None,
ineq_features: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:

"""
Collects all dual variables and returns a tuple containing their
:py:class:`torch.Tensor` values. Note that the *values* are the output of a
`MultiplierModel`, so the model inputs must be passed in order to obtain
the predictions.
"""

assert (
eq_features is not None or ineq_features is not None
), "At least one of `eq_features` or `ineq_features` must be provided."

if ineq_features is None:
ineq_state = None
else:
ineq_state = self.ineq_multiplier_model.forward(ineq_features)

if eq_features is None:
eq_state = None
else:
eq_state = self.eq_multiplier_model.forward(eq_features)

return ineq_state, eq_state

def flip_dual_gradients(self) -> None:
"""
Flips the sign of the gradients of the dual variables. This is useful
when using this formulation in conjunction with the alternating
update scheme.
"""
# FIXME(IsitaRex): We are accessing grad from the multiplier_model,
# but not from the multiplier_model.parameters(). Check if this is
# correct.
for constraint_type in ["eq", "ineq"]:
mult_name = constraint_type + "_multiplier_model"
multiplier_model = getattr(self, mult_name)
if multiplier_model is not None:
for param_grad in multiplier_model.grad:
if param_grad is not None:
param_grad.mul_(-1.0)
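
Regarding the FIXME above: since `MultiplierModel.grad` yields `parameter.grad` for each parameter in `parameters()`, mutating those tensors in place flips the gradients of the parameters themselves. A quick standalone check mirroring that property (hypothetical toy module):

import torch

net = torch.nn.Linear(2, 1)
net(torch.ones(1, 2)).sum().backward()

def grad(module):
    # mirrors MultiplierModel.grad: yields the .grad tensor of each parameter
    for parameter in module.parameters():
        yield parameter.grad

for g in grad(net):
    if g is not None:
        g.mul_(-1.0)  # in-place: flips the gradients the optimizer will see

print([p.grad for p in net.parameters()])  # signs flipped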

@no_type_check
def compute_lagrangian(
self,
closure: Callable[..., CMPState] = None,
*closure_args,
pre_computed_state: Optional[CMPState] = None,
write_state: bool = True,
**closure_kwargs,
) -> torch.Tensor:
""" """

assert (
closure is not None or pre_computed_state is not None
), "At least one of closure or pre_computed_state must be provided"

if pre_computed_state is not None:
cmp_state = pre_computed_state
else:
cmp_state = closure(*closure_args, **closure_kwargs)

if write_state:
self.write_cmp_state(cmp_state)

# Extract values from CMPState object
loss = cmp_state.loss

# Purge previously accumulated constraint violations
self.update_accumulated_violation(update=None)

# Compute contribution of the sampled constraint violations, weighted by the
# current multiplier values predicted by the multiplier model.
ineq_viol = self.weighted_violation(cmp_state, "ineq")
eq_viol = self.weighted_violation(cmp_state, "eq")

# Lagrangian = loss + \sum_i multiplier_i * defect_i
lagrangian = loss + ineq_viol + eq_viol

return lagrangian

def weighted_violation(
self, cmp_state: CMPState, constraint_type: str
) -> torch.Tensor:
"""
Computes the dot product between the current multipliers and the
constraint violations of type ``constraint_type``. The multipliers correspond
to the output of a `MultiplierModel` provided when the formulation was
initialized. The model takes as input the `constraint_features` stored in the
``misc`` attribute of the CMPState, and is optimized with respect to the
Lagrangian. If proxy-constraints are provided in the :py:class:`.CMPState`,
the (differentiable) proxy-constraints are used for computing the returned
dot product, while the dot product with the non-proxy (usually
non-differentiable) constraints is accumulated under
``self.accumulated_violation_dot_prod`` and later used to update the
multipliers.

Args:
cmp_state: current ``CMPState``
constraint_type: type of constraint to be used, e.g. "eq" or "ineq".
"""

defect = getattr(cmp_state, constraint_type + "_defect")
has_defect = defect is not None

proxy_defect = getattr(cmp_state, "proxy_" + constraint_type + "_defect")
has_proxy_defect = proxy_defect is not None

if not has_proxy_defect:
# If not given proxy constraints, then the regular defects are
# used for computing gradients and evaluating the multipliers
proxy_defect = defect

if not has_defect:
# We should always have at least the "regular" defects, if not, then
# the problem instance does not have `constraint_type` constraints
violation = torch.tensor([0.0], device=cmp_state.loss.device)
else:
mult_model = getattr(self, constraint_type + "_multiplier_model")

# Get multipliers by performing a prediction over the features of the
# sampled constraints
constraint_features = cmp_state.misc[
constraint_type + "_constraint_features"
]

multipliers = mult_model.forward(constraint_features)

# The violations are computed via an inner product between the multipliers
# and the defects, so they must have the same shape. If proxy-defects are
# given, their shape is checked as well.
assert defect.shape == proxy_defect.shape == multipliers.shape

# Store the multiplier values
setattr(self, constraint_type + "_multipliers", multipliers)

# We compute (primal) gradients of this object with the sampled
# constraints
violation = torch.sum(multipliers.detach() * proxy_defect)

# This is the violation of the "actual/hard" constraint. We use this
# to update the multipliers.
# The gradients for the dual variables are computed via a backward
# on `accumulated_violation_dot_prod`. This enables easy
# extensibility to multiplier classes beyond DenseMultiplier.

# TODO (JGP): Verify that call to backward is general enough for
# Lagrange Multiplier models
violation_for_update = torch.sum(multipliers * defect.detach())
self.update_accumulated_violation(update=violation_for_update)

return violation
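
To make the proxy split above concrete, a minimal numeric sketch (hypothetical tensors; in practice `proxy_defect` depends on the primal parameters): the returned `violation` detaches the multipliers, so its gradients reach only the primal side, while `violation_for_update` detaches the defects, so its gradients reach only the multipliers.

import torch

multipliers = torch.tensor([0.7, 0.3], requires_grad=True)
proxy_defect = torch.tensor([0.2, -0.1], requires_grad=True)  # differentiable surrogate
defect = torch.tensor([0.25, -0.05])  # "hard" constraint violation

violation = torch.sum(multipliers.detach() * proxy_defect)
violation_for_update = torch.sum(multipliers * defect.detach())

violation.backward()             # populates proxy_defect.grad = [0.7, 0.3]
violation_for_update.backward()  # populates multipliers.grad = [0.25, -0.05]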

@no_type_check
def backward(
self,
lagrangian: torch.Tensor,
ignore_primal: bool = False,
ignore_dual: bool = False,
):
"""
Performs the actual backward computation which populates the gradients
for the primal and dual variables.

Args:
lagrangian: Value of the computed Lagrangian based on which the
gradients for the primal and dual variables are populated.
ignore_primal: If ``True``, only the gradients with respect to the
dual variables are populated (these correspond to the constraint
violations). This feature is mainly used in conjunction with
``alternating`` updates, which require updating the multipliers
based on the constraints violation *after* having updated the
primal parameters. Defaults to False.
ignore_dual: If ``True``, the gradients with respect to the dual
variables are not populated.
"""

if ignore_primal:
# Only compute gradients wrt Lagrange multipliers
# No need to call backward on Lagrangian as the dual variables have
# been detached when computing the `weighted_violation`s
pass
else:
# Compute gradients wrt _primal_ parameters only.
# The gradient for the dual variables is computed based on the
# non-proxy violations below.
lagrangian.backward()

# Fill in the gradients for the dual variables based on the violation of
# the non-proxy constraints
if not ignore_dual:
dual_vars = self.dual_parameters
self.accumulated_violation_dot_prod.backward(inputs=dual_vars)
# FIXME(IsitaRex): Alternative implementation had the code below
# It's unclear why this difference exists. Must be checked before
# merging into main
# new_loss = lagrangian - self.accumulated_violation_dot_prod
# new_loss.backward(inputs=dual_vars)
Review comment: This alternative implementation should be disregarded; it corresponds to a confusion made while looking at the authors' code and was only a test. The current implementation is the one that works and should be kept.
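
For orientation, a rough usage sketch of the formulation introduced in this file. This is only a sketch: `cmp`, `ineq_model`, `ineq_features`, and `inputs` are hypothetical placeholders, not part of this diff.

import torch

from cooper.formulation import LagrangianModelFormulation

# ineq_model: an instance of a MultiplierModel subclass (see the sketch
# after cooper/multipliers/multiplier_model.py below). cmp is assumed to
# be a ConstrainedMinimizationProblem whose closure stores the sampled
# constraint features under CMPState.misc["ineq_constraint_features"].
formulation = LagrangianModelFormulation(cmp, ineq_multiplier_model=ineq_model)

# Multiplier *values* are model predictions, so state() requires the
# constraint features as input.
ineq_multipliers, eq_multipliers = formulation.state(ineq_features=ineq_features)

# Evaluate the closure, build the Lagrangian, populate primal and dual grads.
lagrangian = formulation.compute_lagrangian(cmp.closure, inputs)
formulation.backward(lagrangian)

After backward, a primal optimizer would step on the model's parameters and a dual optimizer on formulation.dual_parameters, flipping the dual gradients first when the alternating update scheme requires ascent.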

1 change: 1 addition & 0 deletions cooper/multipliers/__init__.py
@@ -1 +1,2 @@
from .multiplier_model import MultiplierModel
from .multipliers import BaseMultiplier, DenseMultiplier
58 changes: 58 additions & 0 deletions cooper/multipliers/multiplier_model.py
@@ -0,0 +1,58 @@
import abc
from typing import Iterator

import torch

from cooper.multipliers import BaseMultiplier


class MultiplierModel(BaseMultiplier, metaclass=abc.ABCMeta):
"""
A multiplier model. Holds a :py:class:`~torch.nn.Module`, which predicts
the value of the Lagrange multipliers associated with the equality or
inequality constraints of a
:py:class:`~cooper.problem.ConstrainedMinimizationProblem`.

Args:
model: A :py:class:`~torch.nn.Module` which predicts the values of the
Lagrange multipliers.
is_positive: Whether to enforce non-negativity on the values of the
multiplier.
"""
Review comment: Are these arguments valid? I think they should be removed, and we should make it explicit that this class is meant to be inherited by the model class, just as is done with torch.nn.Module.


def __init__(self):
super().__init__()

@abc.abstractmethod
def forward(self, constraint_features: torch.Tensor):
"""
Returns the *actual* value of the multipliers by passing the
constraint "features" through the model, which predicts the
corresponding multipliers.
"""
pass

@property
def shape(self):
"""
Returns the shape of the explicit multipliers. In the case of implicit
multipliers, this should return the *actual* predicted multipliers.
"""
pass

@property
# FIXME(IsitaRex): Rename this.
def grad(self):
"""Yields the current gradients stored in each fo the model parameters."""
for parameter in self.parameters():
yield parameter.grad

def project_(self):
raise RuntimeError("""project_ method does not exist for MultiplierModel.""")

def restart_if_feasible_(self):
raise RuntimeError(
"""restart_if_feasible_ method does not exist for MultiplierModel."""
)

# TODO: Add __str__ and similar methods to MultiplierModel if possible
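
In line with the review comment above, a hypothetical subclass illustrating the intended contract: forward is the single abstract method, and the grad property yields the gradients of the underlying module's parameters. Names and architecture are illustrative only.

import torch

from cooper.multipliers import MultiplierModel


class MLPMultiplierModel(MultiplierModel):
    """Maps constraint features to (non-negative) multiplier predictions."""

    def __init__(self, num_features: int, hidden: int = 16):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(num_features, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, 1),
        )

    def forward(self, constraint_features: torch.Tensor) -> torch.Tensor:
        # relu enforces non-negativity, as needed for inequality multipliers.
        return torch.relu(self.net(constraint_features)).flatten()


model = MLPMultiplierModel(num_features=3)
predictions = model(torch.randn(5, 3))  # one multiplier per sampled constraint
predictions.sum().backward()
for g in model.grad:  # yields one gradient per parameter of self.net
    print(None if g is None else g.shape)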