54 multiplier models #66

Open
wants to merge 61 commits into base: 54-multiplier-models
Commits
bba4cb1
Create multiplier_model.py
IsitaRex Nov 5, 2022
e134386
Update multiplier_model.py
IsitaRex Nov 5, 2022
88fdb5f
Update logistic_regression.ipynb
IsitaRex Nov 5, 2022
d3cb8d0
Create test_multiplier_model.py
IsitaRex Nov 6, 2022
da5de4a
Update Multiplier Model to inherit from multipliers
IsitaRex Nov 6, 2022
4c8ec04
Update multipliers.py
IsitaRex Nov 7, 2022
c26ce46
Update __init__.py
IsitaRex Nov 7, 2022
0d3fa08
Update __init__.py
IsitaRex Nov 7, 2022
eef8eb6
Update lagrangian_model.py
IsitaRex Nov 7, 2022
d1868ce
Update multiplier_model.py
IsitaRex Nov 7, 2022
cc92995
Update multipliers.py
IsitaRex Nov 7, 2022
812f54b
Create test_lagrangian_model.py
IsitaRex Nov 7, 2022
c4431f5
Update test_multiplier_model.py
IsitaRex Nov 7, 2022
1faa097
Update test_lagrangian_model.py
IsitaRex Nov 7, 2022
6ddb61a
MM and Lagrangian Model work and pass tests
daoterog Nov 8, 2022
fd84696
Correction - Add again proxy constraints to model formulation
daoterog Nov 9, 2022
af39b27
Add proxy constraints to toy problem
daoterog Nov 9, 2022
4eb714d
Fix test_lagrangian_model test and tol from convergence test
daoterog Nov 9, 2022
fcbda77
Fixed convergence tolerance
daoterog Nov 9, 2022
90eed1d
Refactor to integrate large toy problem in build tests
daoterog Nov 9, 2022
883c5c9
Add flip gradients to formulations
daoterog Nov 10, 2022
b3075fe
Add missing argument
daoterog Nov 10, 2022
ad79921
Optimizer pending changes
daoterog Nov 10, 2022
6de574d
Pending todo
daoterog Nov 10, 2022
36abc70
Experiment to compare model form with regular form
daoterog Nov 10, 2022
1fe0fc3
Ignore wandb logs
daoterog Nov 10, 2022
97c9e88
Extend build_test_problem to encompass model form
daoterog Nov 10, 2022
183b955
Rewrite tests for current multiplier model object
daoterog Nov 10, 2022
38f3841
New experiment setting
daoterog Nov 11, 2022
943c46e
Change in third constraint and constraint features
daoterog Nov 12, 2022
c129db4
Black formatting
daoterog Nov 12, 2022
af1811e
Update model_form_vs_lag_form.py
IsitaRex Nov 13, 2022
ff3d9e2
Add mock model formulation
daoterog Nov 14, 2022
c4093bf
Fix by black and isort
IsitaRex Dec 7, 2022
a7426ae
Merge remote-tracking branch 'upstream/dev' into dev
IsitaRex Dec 7, 2022
cc894bf
Merge branch 'dev' into 54-multiplier-models
IsitaRex Dec 7, 2022
26b327e
Remove MockLagrangianModelFormulation
IsitaRex Dec 7, 2022
9e58b63
Black
IsitaRex Dec 7, 2022
6248b6b
Remove wandb from gitignore
IsitaRex Dec 7, 2022
4a80910
Add comment about accessing gradients of module
IsitaRex Dec 7, 2022
3de5a3e
Comment to change grad function name.
IsitaRex Dec 7, 2022
d2b3a35
Add missing documentation
daoterog Dec 9, 2022
695ed73
Correct documentation
daoterog Dec 9, 2022
cceee8c
Modify flip gradients, add flag to multiplier projection
daoterog Feb 3, 2023
de01f2e
Fix circular import and add typehint to grad func
daoterog Feb 3, 2023
3304242
Removed unnecessary TODO
daoterog Feb 3, 2023
92143ab
Simplify object for test
daoterog Feb 3, 2023
a8bc041
Remove comments on alternative implementation
daoterog Feb 3, 2023
a643e1a
Make test more simple
daoterog Feb 3, 2023
cc76eb1
Formatting
daoterog Feb 3, 2023
a7b07b4
Change todo to git username
daoterog Feb 10, 2023
1a61dd5
Add CMPModelState, fix state, state_dict, load_state_dict, and flip_d…
daoterog Feb 10, 2023
6ef151f
Fix and add TODOs
daoterog Feb 10, 2023
1bb3bc3
Add RuntimeError to unused methods
daoterog Feb 10, 2023
fdaff32
Add git username to TODO
daoterog Feb 10, 2023
5cd667a
Add git username to TODO
daoterog Feb 10, 2023
fcbff54
Add CMPModelState to convergence test
daoterog Feb 10, 2023
2d47807
Black formatting
daoterog Feb 10, 2023
96705a5
Pending fix on circular import and correct type hint
daoterog Feb 10, 2023
e8a2b32
Properly implement exhaustive tests with fixtures and parameters
daoterog Feb 11, 2023
4f8b6ed
Updated type hints with CMPModelState
daoterog Feb 11, 2023
1 change: 1 addition & 0 deletions cooper/formulation/__init__.py
@@ -1,3 +1,4 @@
from .augmented_lagrangian import AugmentedLagrangianFormulation
from .formulation import Formulation, UnconstrainedFormulation
from .lagrangian import LagrangianFormulation, ProxyLagrangianFormulation
from .lagrangian_model import LagrangianModelFormulation
2 changes: 1 addition & 1 deletion cooper/formulation/augmented_lagrangian.py
@@ -90,7 +90,7 @@ def weighted_violation(
# to update the value of the multipliers by lazily filling the
# multiplier gradients in `backward`

# TODO (JGP): Verify that call to backward is general enough for
# TODO (gallego-posada): Verify that call to backward is general enough for
# Lagrange Multiplier models
violation_for_update = torch.sum(multipliers * defect.detach())
self.update_accumulated_violation(update=violation_for_update)
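The `violation_for_update` line in the hunk above multiplies the multipliers by a *detached* defect, so the subsequent `backward` call writes gradients only into the multipliers. A minimal sketch of that pattern, assuming PyTorch (the tensor values are illustrative, not from Cooper):

```python
import torch

# Multipliers are the only leaves that should receive gradients.
multipliers = torch.tensor([1.0, 2.0], requires_grad=True)
# The defect (constraint violation) also tracks gradients elsewhere,
# but is treated as a constant for the multiplier update.
defect = torch.tensor([0.3, -0.1], requires_grad=True)

# Detaching the defect cuts it out of the autograd graph here.
violation_for_update = torch.sum(multipliers * defect.detach())
violation_for_update.backward()

# multipliers.grad == defect's values; defect.grad stays None because
# no gradient flowed through the detached copy.
```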
18 changes: 16 additions & 2 deletions cooper/formulation/formulation.py
@@ -1,10 +1,13 @@
import abc
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, Union

import torch

from cooper.problem import CMPState, ConstrainedMinimizationProblem

# from .lagrangian_model import CMPModelState


# Formulation, and some other classes below, are heavily inspired by the design
# of the TensorFlow Constrained Optimization (TFCO) library :
# https://github.com/google-research/tensorflow_constrained_optimization
@@ -34,6 +37,11 @@ def state(self):
"""Returns the internal state of formulation (e.g. multipliers)."""
pass

@abc.abstractmethod
def flip_dual_gradients(self):
"""Flips the sign of the dual gradients."""
pass

@property
@abc.abstractmethod
def is_state_created(self):
@@ -59,7 +67,8 @@ def backward(self, *args, **kwargs):
formulation."""
pass

def write_cmp_state(self, cmp_state: CMPState):
# TODO(daoterog): fix circular import so the type hint can be correct
def write_cmp_state(self, cmp_state: CMPState): # Union[CMPState, CMPModelState]):
"""Provided that the formulation is linked to a
`ConstrainedMinimizationProblem`, writes a CMPState to the CMP."""

@@ -119,6 +128,11 @@ def load_state_dict(self, state_dict: dict):
"""
pass

def flip_dual_gradients(self):
Review comment: Remove repeated class

"""Flips the sign of the dual gradients. This is a no-op for
unconstrained formulations."""
pass

def compute_lagrangian(
self,
closure: Callable[..., CMPState],
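The `flip_dual_gradients` hook introduced in this file exists because the dual variables are updated by gradient *ascent*, while standard optimizers descend. Negating the stored gradients lets an ordinary minimizing optimizer perform the ascent step. A minimal sketch of the idea, assuming PyTorch (standalone toy values, not Cooper's API):

```python
import torch

# A single Lagrange multiplier and a fixed constraint defect.
multiplier = torch.tensor(1.0, requires_grad=True)
defect = torch.tensor(0.5)

# The Lagrangian term is linear in the multiplier, so after backward
# the multiplier's gradient equals the defect (+0.5).
lagrangian_term = multiplier * defect
lagrangian_term.backward()

# A minimizing optimizer would now *decrease* the multiplier. Flipping
# the gradient sign in place makes the same optimizer *increase* it,
# i.e. perform gradient ascent on the dual variable.
multiplier.grad.mul_(-1.0)
# multiplier.grad is now -0.5
```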
16 changes: 12 additions & 4 deletions cooper/formulation/lagrangian.py
@@ -42,7 +42,7 @@ def __init__(
# Store user-provided initializations for dual variables
self.ineq_init = ineq_init
self.eq_init = eq_init

# TODO(gallego-posada): comment what is the meaning of this object
self.accumulated_violation_dot_prod: torch.Tensor = None

@property
@@ -193,11 +193,11 @@ def state_dict(self) -> Dict[str, Any]:
"""
Generates the state dictionary for a Lagrangian formulation.
"""

# TODO(gallego-posada): fix in next PR
state_dict = {
"ineq_multipliers": self.ineq_multipliers,
"eq_multipliers": self.eq_multipliers,
"accumulated_violation_dot_prod": self.accumulated_violation_dot_prod,
# "accumulated_violation_dot_prod": self.accumulated_violation_dot_prod,
}
return state_dict

@@ -220,6 +220,14 @@ def load_state_dict(self, state_dict: Dict[str, Any]):
), "LagrangianFormulation received unknown key: {}".format(key)
setattr(self, key, val)

def flip_dual_gradients(self):
"""
Flip the sign of the gradients of the dual variables.
"""
for multiplier in self.state():
if multiplier is not None:
multiplier.grad.mul_(-1.0)


class LagrangianFormulation(BaseLagrangianFormulation):
"""
@@ -353,7 +361,7 @@ def weighted_violation(
# on `accumulated_violation_dot_prod`. This enables easy
# extensibility to multiplier classes beyond DenseMultiplier.

# TODO (JGP): Verify that call to backward is general enough for
# TODO (gallego-posada): Verify that call to backward is general enough for
# Lagrange Multiplier models
violation_for_update = torch.sum(multipliers * defect.detach())
self.update_accumulated_violation(update=violation_for_update)
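The `state_dict` / `load_state_dict` pair shown in the `lagrangian.py` diff follows the usual save-by-dict, restore-by-`setattr` pattern, with a guard against unknown keys. A plain-Python sketch of that round trip (class and attribute names are illustrative, not Cooper's):

```python
# Toy formulation mimicking the state_dict / load_state_dict pattern above.
class ToyFormulation:
    def __init__(self, ineq_multipliers=None, eq_multipliers=None):
        self.ineq_multipliers = ineq_multipliers
        self.eq_multipliers = eq_multipliers

    def state_dict(self):
        # Collect the named dual-variable attributes into a dict.
        return {
            "ineq_multipliers": self.ineq_multipliers,
            "eq_multipliers": self.eq_multipliers,
        }

    def load_state_dict(self, state_dict):
        known = {"ineq_multipliers", "eq_multipliers"}
        for key, val in state_dict.items():
            # Reject keys that do not correspond to known attributes,
            # mirroring the assertion in the diff above.
            assert key in known, f"received unknown key: {key}"
            setattr(self, key, val)

source = ToyFormulation(ineq_multipliers=[0.5, 1.5])
target = ToyFormulation()
target.load_state_dict(source.state_dict())
# target.ineq_multipliers == [0.5, 1.5]; eq_multipliers stays None
```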