Reduce memory usage and make saving constraints optional (#1532)
bqpd committed Dec 8, 2020
1 parent 00cde4a commit a436015
Showing 16 changed files with 132 additions and 119 deletions.
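
Much of the memory reduction in this commit comes from one pattern, repeated across the files below: instead of building and storing an expensive KeySet on every object at construction time, objects keep only the raw set of varkeys ("vks") and build the KeySet lazily, the first time .varkeys is actually needed. A minimal sketch of that pattern, with illustrative names rather than gpkit's exact classes:

class KeySet(set):
    "Stand-in for gpkit's KeySet: a set with extra lookup machinery."

class LazyVarkeys:
    _varkeys = None          # class-level default: no cache until one is needed

    def __init__(self, vks):
        self.vks = set(vks)  # cheap raw set, always kept

    @property
    def varkeys(self):
        "Build the KeySet on first access, then reuse the cached copy."
        if self._varkeys is None:
            self._varkeys = KeySet(self.vks)
        return self._varkeys
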
gpkit/constraints/costed.py (4 additions, 5 deletions)
@@ -20,21 +20,20 @@ def __init__(self, cost, constraints, substitutions=None):
         self.cost = maybe_flatten(cost)
         if isinstance(self.cost, np.ndarray):  # if it's still a vector
             raise ValueError("Cost must be scalar, not the vector %s." % cost)
-        subs = {k: k.value for k in self.cost.varkeys if "value" in k.descr}
+        subs = {k: k.value for k in self.cost.vks if "value" in k.descr}
         if substitutions:
             subs.update(substitutions)
-        ConstraintSet.__init__(self, constraints, subs)
-        self.varkeys.update(self.cost.varkeys)
+        ConstraintSet.__init__(self, constraints, subs, bonusvks=self.cost.vks)
 
     def constrained_varkeys(self):
         "Return all varkeys in the cost and non-ConstraintSet constraints"
         constrained_varkeys = ConstraintSet.constrained_varkeys(self)
-        constrained_varkeys.update(self.cost.varkeys)
+        constrained_varkeys.update(self.cost.vks)
         return constrained_varkeys
 
     def _rootlines(self, excluded=()):
         "String showing cost, to be used when this is the top constraint"
-        if self.cost.varkeys:
+        if self.cost.vks:
             description = ["", "Cost Function", "-------------",
                            " %s" % self.cost.str_without(excluded),
                            "", "Constraints", "-----------"]
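
With the new bonusvks keyword, CostedConstraintSet no longer mutates a stored KeySet after construction: the cost's varkeys are handed to ConstraintSet.__init__ up front. A hedged usage sketch (it assumes gpkit is importable and uses the signature added in this diff; the variables are illustrative):

from gpkit import Variable
from gpkit.constraints.set import ConstraintSet

x, y = Variable("x"), Variable("y")
# extra varkeys (here y's, standing in for a cost's) ride along at
# construction time instead of being merged into self.varkeys afterwards
cs = ConstraintSet([x >= 1], bonusvks=y.vks)
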
gpkit/constraints/gp.py (10 additions, 10 deletions)
@@ -80,7 +80,7 @@ def __init__(self, cost, constraints, substitutions, *, checkbounds=True):
             if not isinstance(sub, (Numbers, np.ndarray)):
                 raise TypeError("substitution {%s: %s} has invalid value type"
                                 " %s." % (key, sub, type(sub)))
-        cost_hmap = cost.hmap.sub(self.substitutions, cost.varkeys)
+        cost_hmap = cost.hmap.sub(self.substitutions, cost.vks)
         if any(c <= 0 for c in cost_hmap.values()):
             raise InvalidPosynomial("a GP's cost must be Posynomial")
         hmapgen = ConstraintSet.as_hmapslt1(constraints, self.substitutions)
@@ -305,33 +305,33 @@ def _compile_result(self, solver_out):
         la, self.nu_by_posy = self._generate_nula(solver_out)
         cost_senss = sum(nu_i*exp for (nu_i, exp) in zip(self.nu_by_posy[0],
                                                          self.cost.hmap))
-        self.v_ss = cost_senss.copy()
+        gpv_ss = cost_senss.copy()
         for las, nus, c in zip(la[1:], self.nu_by_posy[1:], self.hmaps[1:]):
             while getattr(c, "parent", None) is not None:
                 c = c.parent
             v_ss, c_senss = c.sens_from_dual(las, nus, result)
             for vk, x in v_ss.items():
-                self.v_ss[vk] = x + self.v_ss.get(vk, 0)
-            while getattr(c, "generated_by", None) is not None:
+                gpv_ss[vk] = x + gpv_ss.get(vk, 0)
+            while getattr(c, "generated_by", None):
                 c = c.generated_by
             result["sensitivities"]["constraints"][c] = c_senss
         # carry linked sensitivities over to their constants
-        for v in list(v for v in self.v_ss if v.gradients):
-            dlogcost_dlogv = self.v_ss.pop(v)
+        for v in list(v for v in gpv_ss if v.gradients):
+            dlogcost_dlogv = gpv_ss.pop(v)
             val = np.array(result["constants"][v])
             for c, dv_dc in v.gradients.items():
                 with warnings.catch_warnings():  # skip pesky divide-by-zeros
                     warnings.simplefilter("ignore")
                     dlogv_dlogc = dv_dc * result["constants"][c]/val
-                before = self.v_ss.get(c, 0)
-                self.v_ss[c] = before + dlogcost_dlogv*dlogv_dlogc
+                before = gpv_ss.get(c, 0)
+                gpv_ss[c] = before + dlogcost_dlogv*dlogv_dlogc
                 if v in cost_senss:
-                    if c in self.cost.varkeys:
+                    if c in self.cost.vks:
                         dlogcost_dlogv = cost_senss.pop(v)
                         before = cost_senss.get(c, 0)
                         cost_senss[c] = before + dlogcost_dlogv*dlogv_dlogc
         result["sensitivities"]["cost"] = cost_senss
-        result["sensitivities"]["variables"] = KeyDict(self.v_ss)
+        result["sensitivities"]["variables"] = KeyDict(gpv_ss)
         result["sensitivities"]["constants"] = \
             result["sensitivities"]["variables"]  # NOTE: backwards compat.
         result["soltime"] = solver_out["soltime"]
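
The _compile_result change is the other half of the memory story: per-variable sensitivities are now accumulated in a local gpv_ss dict and attached to the returned result, rather than kept alive on the solved GP as self.v_ss. A minimal sketch of that pattern (hypothetical names, not gpkit's API):

def compile_sensitivities(cost_senss, per_constraint_senss):
    "Accumulate into a local dict that is handed to the result, not to self."
    gpv_ss = dict(cost_senss)  # local accumulator, freed along with the result
    for senss in per_constraint_senss:
        for vk, x in senss.items():
            gpv_ss[vk] = x + gpv_ss.get(vk, 0)
    return {"sensitivities": {"variables": gpv_ss}}
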
gpkit/constraints/relax.py (1 addition, 1 deletion)
@@ -160,7 +160,7 @@ def __init__(self, constraints, *, include_only=None, exclude=None):
         if not isinstance(constraints, ConstraintSet):
             constraints = ConstraintSet(constraints)
         substitutions = KeyDict(constraints.substitutions)
-        constants, _, linked = parse_subs(constraints.varkeys, substitutions)
+        constants, _, linked = parse_subs(constraints.vks, substitutions)
         if linked:
             kdc = KeyDict(constants)
             constrained_varkeys = constraints.constrained_varkeys()

gpkit/constraints/set.py (31 additions, 25 deletions)
@@ -64,58 +64,60 @@ class ConstraintSet(list, ReprMixin):
     "Recursive container for ConstraintSets and Inequalities"
     unique_varkeys, idxlookup = frozenset(), {}
     _name_collision_varkeys = None
+    _varkeys = None
 
-    def __init__(self, constraints, substitutions=None):  # pylint: disable=too-many-branches,too-many-statements
+    def __init__(self, constraints, substitutions=None, *, bonusvks=None):  # pylint: disable=too-many-branches,too-many-statements
         if isinstance(constraints, dict):
             keys, constraints = sort_constraints_dict(constraints)
             self.idxlookup = {k: i for i, k in enumerate(keys)}
         elif isinstance(constraints, ConstraintSet):
             constraints = [constraints]  # put it one level down
         list.__init__(self, constraints)
-        self.varkeys = KeySet(self.unique_varkeys)
+        self.vks = set(self.unique_varkeys)
         self.substitutions = KeyDict({k: k.value for k in self.unique_varkeys
                                       if "value" in k.descr})
-        self.substitutions.varkeys = self.varkeys
+        self.substitutions.cset = self
         self.bounded, self.meq_bounded = set(), defaultdict(set)
         for i, constraint in enumerate(self):
-            if hasattr(constraint, "varkeys"):
+            if hasattr(constraint, "vks"):
                 self._update(constraint)
             elif not (hasattr(constraint, "as_hmapslt1")
                       or hasattr(constraint, "as_gpconstr")):
                 try:
-                    for subconstraint in flatiter(constraint, "varkeys"):
+                    for subconstraint in flatiter(constraint, "vks"):
                         self._update(subconstraint)
                 except Exception as e:
                     raise badelement(self, i, constraint) from e
             elif isinstance(constraint, ConstraintSet):
                 raise badelement(self, i, constraint,
                                  " It had not yet been initialized!")
+        if bonusvks:
+            self.vks.update(bonusvks)
         if substitutions:
             self.substitutions.update(substitutions)
-        for subkey in self.substitutions:
-            if subkey.shape and not subkey.idx:  # vector sub found
-                for key in self.varkeys:
-                    if key.veckey:
-                        self.varkeys.keymap[key.veckey].add(key)
-                break  # vectorkeys need to be mapped only once
-        for subkey in self.substitutions:
-            for key in self.varkeys[subkey]:
-                self.bounded.add((key, "upper"))
-                self.bounded.add((key, "lower"))
-                if key.value is not None and not key.constant:
-                    del key.descr["value"]
-                    if key.veckey and key.veckey.value is not None:
-                        del key.veckey.descr["value"]
+        self._varkeys = None
+        for key in self.vks:
+            if key not in self.substitutions:
+                if key.veckey is None or key.veckey not in self.substitutions:
+                    continue
+                if np.isnan(self.substitutions[key.veckey][key.idx]):
+                    continue
+            self.bounded.add((key, "upper"))
+            self.bounded.add((key, "lower"))
+            if key.value is not None and not key.constant:
+                del key.descr["value"]
+                if key.veckey and key.veckey.value is not None:
+                    del key.veckey.descr["value"]
         add_meq_bounds(self.bounded, self.meq_bounded)
 
     def _update(self, constraint):
         "Update parameters with a given constraint"
-        self.varkeys.update(constraint.varkeys)
+        self.vks.update(constraint.vks)
         if hasattr(constraint, "substitutions"):
             self.substitutions.update(constraint.substitutions)
         else:
             self.substitutions.update({k: k.value \
-                for k in constraint.varkeys if "value" in k.descr})
+                for k in constraint.vks if "value" in k.descr})
         self.bounded.update(constraint.bounded)
         for bound, solutionset in constraint.meq_bounded.items():
             self.meq_bounded[bound].update(solutionset)
@@ -150,12 +150,16 @@ def variables_byname(self, key):
         return sorted([Variable(k) for k in self.varkeys[key]],
                       key=_sort_by_name_and_idx)
 
+    @property
+    def varkeys(self):
+        "The NomialData's varkeys, created when necessary for a substitution."
+        if self._varkeys is None:
+            self._varkeys = KeySet(self.vks)
+        return self._varkeys
+
     def constrained_varkeys(self):
         "Return all varkeys in non-ConstraintSet constraints"
-        constrained_varkeys = set()
-        for constraint in self.flat(yield_if_hasattr="varkeys"):
-            constrained_varkeys.update(constraint.varkeys)
-        return constrained_varkeys
+        return self.vks - self.unique_varkeys
 
     flat = flatiter
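
The rewritten bounding loop above also treats vector substitutions more carefully: an element whose entry in the substituted vector is NaN counts as unsubstituted and is not marked as bounded. A standalone restatement of that logic, as a sketch for clarity rather than part of the diff:

import numpy as np

def bounded_pairs(vks, substitutions):
    "Mark (key, bound) pairs, skipping NaN entries of vector substitutions."
    bounded = set()
    for key in vks:
        if key not in substitutions:
            if key.veckey is None or key.veckey not in substitutions:
                continue  # neither the key nor its parent vector is substituted
            if np.isnan(substitutions[key.veckey][key.idx]):
                continue  # vector substituted, but this element left as NaN
        bounded.add((key, "upper"))
        bounded.add((key, "lower"))
    return bounded
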
gpkit/constraints/sgp.py (1 addition, 1 deletion)
@@ -82,7 +82,7 @@ def __init__(self, cost, model, substitutions, *,
                 constraint = (Posynomial(hmaplt1) <= self.slack)
                 constraint.generated_by = cs
                 self.approxconstraints.append(constraint)
-                self.sgpvks.update(constraint.varkeys)
+                self.sgpvks.update(constraint.vks)
         if not self.sgpconstraints:
             raise UnnecessarySGP("""Model valid as a Geometric Program.

gpkit/keydict.py (9 additions, 15 deletions)
@@ -45,7 +45,7 @@ class KeyMap:
     collapse_arrays = False
     keymap = []
     log_gets = False
-    varkeys = None
+    cset = None
 
     def __init__(self, *args, **kwargs):
         "Passes through to super().__init__ via the `update()` method"
@@ -62,12 +62,12 @@ def parse_and_index(self, key):
                 return key.veckey, key.idx
             return key, None
         except AttributeError:
-            if not self.varkeys:
+            if self.cset is None:
                 return key, self.update_keymap()
         # looks like we're in a substitutions dictionary
-        if key not in self.varkeys:  # pylint:disable=unsupported-membership-test
+        if key not in self.cset.varkeys:
             raise KeyError(key)
-        newkey, *otherkeys = self.varkeys[key]  # pylint:disable=unsubscriptable-object
+        newkey, *otherkeys = self.cset.varkeys[key]
         if otherkeys:
             if all(k.veckey == newkey.veckey for k in otherkeys):
                 return newkey.veckey, None
@@ -186,7 +186,7 @@ def __getitem__(self, key):
 
     def __setitem__(self, key, value):
         "Overloads __setitem__ and []= to work with all keys"
-        # pylint: disable=too-many-boolean-expressions,too-many-branches
+        # pylint: disable=too-many-boolean-expressions,too-many-branches,too-many-statements
         try:
             key, idx = self.parse_and_index(key)
         except KeyError as e:  # may be indexed VectorVariable
@@ -285,16 +285,10 @@ class KeySet(KeyMap, set):
 
     def update(self, keys):
         "Iterates through the dictionary created by args and kwargs"
-        if isinstance(keys, KeySet):
-            set.update(self, keys)
-            for key, value in keys.keymap.items():
-                self.keymap[key].update(value)
-            self._unmapped_keys.update(keys._unmapped_keys)  # pylint: disable=protected-access
-        else:  # set-like interface
-            for key in keys:
-                self.keymap[key].add(key)
-            self._unmapped_keys.update(keys)
-            super().update(keys)
+        for key in keys:
+            self.keymap[key].add(key)
+        self._unmapped_keys.update(keys)
+        super().update(keys)
 
     def __getitem__(self, key):
         "Gets the keys corresponding to a particular key."
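
KeySet.update drops its special case for merging another KeySet: every iterable of keys now takes the same path, and substitution dictionaries reach varkeys through their owning constraint set (the new cset attribute) instead of holding their own KeySet. A compact sketch of the simplified update (a toy stand-in, not gpkit's full class):

from collections import defaultdict

class MiniKeySet(set):
    "Toy KeySet: a single update() path for any iterable of keys."
    def __init__(self, keys=()):
        super().__init__()
        self.keymap = defaultdict(set)
        self._unmapped_keys = set()
        self.update(keys)

    def update(self, keys):
        for key in keys:
            self.keymap[key].add(key)
        self._unmapped_keys.update(keys)
        super().update(keys)
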
gpkit/nomials/data.py (10 additions, 7 deletions)
@@ -18,9 +18,6 @@ class NomialData(ReprMixin):
 
     def __init__(self, hmap):
         self.hmap = hmap
-        self.vks = set()
-        for exp in self.hmap:
-            self.vks.update(exp)
         self.units = self.hmap.units
         self.any_nonpositive_cs = any(c <= 0 for c in self.hmap.values())
 
@@ -48,11 +45,17 @@ def __hash__(self):
         return hash(self.hmap)
 
     @property
+    def vks(self):
+        "Set of a NomialData's varkeys, created as necessary."
+        vks = set()
+        for exp in self.hmap:
+            vks.update(exp)
+        return vks
+
+    @property  # TODO: remove this
     def varkeys(self):
-        "The NomialData's varkeys, created when necessary for a substitution."
-        if self._varkeys is None:
-            self._varkeys = KeySet(self.vks)
-        return self._varkeys
+        "KeySet of a NomialData's varkeys, created as necessary."
+        return KeySet(self.vks)
 
     def __eq__(self, other):
         "Equality test"
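
In NomialData the trade is simpler still: vks stops being stored at construction and is recomputed from the exponent maps on each access, while varkeys becomes a plain KeySet built on demand (and marked for eventual removal). A minimal sketch of the recompute-on-demand half, with illustrative names:

class ExpData:
    "Toy NomialData: the derived varkey set is recomputed, not stored."
    def __init__(self, exps):
        self.exps = exps  # sequence of {varkey: exponent} mappings

    @property
    def vks(self):
        "Recomputed on every access: a little CPU traded for per-instance memory."
        vks = set()
        for exp in self.exps:
            vks.update(exp)
        return vks
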
gpkit/nomials/map.py (2 additions, 4 deletions)
@@ -127,7 +127,7 @@ def sub(self, substitutions, varkeys, parsedsubs=False):
         for vk in varlocs:
             exps, cval = varlocs[vk], fixed[vk]
             if hasattr(cval, "hmap"):
-                if any(cval.hmap.keys()):
+                if cval.hmap is None or any(cval.hmap.keys()):
                     raise ValueError("Monomial substitutions are not"
                                      " supported.")
                 cval, = cval.hmap.to(vk.units or DIMLESS_QUANTITY).values()
@@ -152,18 +152,16 @@ def mmap(self, orig):
         mapping the indices corresponding to the old exps to their
         fraction of the post-substitution coefficient
         """
-        m_from_ms = defaultdict(dict)
         pmap = [{} for _ in self]
         origexps = list(orig.keys())
         selfexps = list(self.keys())
         for orig_exp, self_exp in self.expmap.items():
             if self_exp not in self:  # can occur in tautological constraints
                 continue  # after substitution
             fraction = self.csmap.get(orig_exp, orig[orig_exp])/self[self_exp]
-            m_from_ms[self_exp][orig_exp] = fraction
             orig_idx = origexps.index(orig_exp)
             pmap[selfexps.index(self_exp)][orig_idx] = fraction
-        return pmap, m_from_ms
+        return pmap
(8 more changed files not shown.)