Skip to content

Commit

Permalink
address single-variable linked sweeps
Browse files Browse the repository at this point in the history
  • Loading branch information
bqpd committed Aug 6, 2020
1 parent fe9c591 commit b7fa5b3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 22 deletions.
24 changes: 6 additions & 18 deletions gpkit/constraints/prog_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,22 @@

def evaluate_linked(constants, linked):
"Evaluates the values and gradients of linked variables."
kdc = KeyDict({k: adnumber(maybe_flatten(v))
kdc = KeyDict({k: adnumber(maybe_flatten(v), k)
for k, v in constants.items()})
kdc.log_gets = True
kdc_plain = None
array_calulated, logged_array_gets = {}, {}
array_calulated = {}
for v, f in linked.items():
try:
if v.veckey and v.veckey.original_fn:
if v.veckey not in array_calulated:
ofn = v.veckey.original_fn
array_calulated[v.veckey] = np.array(ofn(kdc))
logged_array_gets[v.veckey] = kdc.logged_gets
logged_gets = logged_array_gets[v.veckey]
out = array_calulated[v.veckey][v.idx]
else:
logged_gets = kdc.logged_gets
out = f(kdc)
constants[v] = out.x
v.descr["gradients"] = {}
for key in logged_gets:
if key.shape:
grad = out.gradient(kdc[key])
v.gradients[key] = np.array(grad)
else:
v.gradients[key] = out.d(kdc[key])
v.descr["gradients"] = {adn.tag: grad
for adn, grad in out.d().items()}
except Exception as exception: # pylint: disable=broad-except
from .. import settings
if settings.get("ad_errors_raise", None):
Expand All @@ -48,8 +39,6 @@ def evaluate_linked(constants, linked):
kdc_plain = KeyDict(constants)
constants[v] = f(kdc_plain)
v.descr.pop("gradients", None)
finally:
kdc.logged_gets = set()


def progify(program, return_attr=None):
Expand Down Expand Up @@ -149,8 +138,7 @@ def run_sweep(genfunction, self, solution, skipsweepfailures,
constants.update({var: sweep_vect[i]
for (var, sweep_vect) in sweep_vects.items()})
if linked:
kdc = KeyDict(constants)
constants.update({v: f(kdc) for v, f in linked.items()})
evaluate_linked(constants, linked)
program, solvefn = genfunction(self, constants)
self.program.append(program) # NOTE: SIDE EFFECTS
try:
Expand All @@ -173,7 +161,7 @@ def run_sweep(genfunction, self, solution, skipsweepfailures,
if var in ksweep:
solution["sweepvariables"][var] = val
del solution["constants"][var]
elif var not in linked:
elif (val[0] == val[1:]).all():
solution["constants"][var] = [val[0]]

if verbosity > 0:
Expand Down
8 changes: 4 additions & 4 deletions gpkit/keydict.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def __init__(self, *args, **kwargs):
self.keymap = defaultdict(set)
self._unmapped_keys = set()
self.owned = set()
self.logged_gets = set()
self.update(*args, **kwargs) # pylint: disable=no-member

def parse_and_index(self, key):
Expand Down Expand Up @@ -175,8 +174,6 @@ def __getitem__(self, key):
raise KeyError(key)
got = {}
for k in keys:
if self.log_gets:
self.logged_gets.add(k)
if not idx and k.shape:
self._copyonwrite(k)
val = dict.__getitem__(self, k)
Expand Down Expand Up @@ -207,8 +204,11 @@ def __setitem__(self, key, value):
super().__setitem__(key, np.array(old, "object"))
self.owned.add(key)
self._copyonwrite(key)
if hasattr(value, "__call__"): # a linked function
old = super().__getitem__(key)
super().__setitem__(key, np.array(old, dtype="object"))
super().__getitem__(key)[idx] = value
return # succefully set a single index!
return # successfully set a single index!
if key.shape: # now if we're setting an array...
if getattr(value, "shape", None): # is the value an array?
if value.dtype == INT_DTYPE:
Expand Down

0 comments on commit b7fa5b3

Please sign in to comment.