Skip to content

Commit

Permalink
add warning when AD encounters monomial
Browse files Browse the repository at this point in the history
  • Loading branch information
bqpd committed Jun 29, 2021
1 parent 2a41a97 commit 041c504
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
16 changes: 10 additions & 6 deletions gpkit/constraints/prog_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ def evaluate_linked(constants, linked):
"Linked function for %s did not return a united value."
"Modifying it to do so (e.g. by using `()` instead of `[]`"
" to access variables) will reduce errors." % v)
if hasattr(out, "__len__"):
out = out.item() # break out of 0-dimensional arrays
out = maybe_flatten(out)
if not hasattr(out, "x"):
constants[v] = out
continue # a new fixed variable, not a calculated one
Expand All @@ -57,14 +56,19 @@ def evaluate_linked(constants, linked):
from .. import settings
if settings.get("ad_errors_raise", None):
raise
print("Warning: skipped auto-differentiation of linked variable"
" %s because %s was raised. Set `gpkit.settings"
"[\"ad_errors_raise\"] = True` to raise such Exceptions"
" directly.\n" % (v, repr(exception)))
if kdc_plain is None:
kdc_plain = KeyDict(constants)
constants[v] = f(kdc_plain)
v.descr.pop("gradients", None)
print("Warning: skipped auto-differentiation of linked variable"
" %s because %s was raised. Set `gpkit.settings"
"[\"ad_errors_raise\"] = True` to raise such Exceptions"
" directly.\n" % (v, repr(exception)))
if ("Automatic differentiation not yet supported for <class "
"'gpkit.nomials.math.Monomial'> objects") in str(exception):
print("This particular warning may have come from using"
" gpkit.units.* in the function for %s; try using"
" gpkit.ureg.* or gpkit.units.*.units instead." % v)


def progify(program, return_attr=None):
Expand Down
4 changes: 2 additions & 2 deletions gpkit/small_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def isnan(element):

def maybe_flatten(value):
"Extract values from 0-d numpy arrays, if necessary"
if hasattr(value, "shape") and not value.shape:
return value.flatten()[0] # 0-d numpy arrays
if hasattr(value, "size") and value.size == 1:
return value.item()
return value


Expand Down

0 comments on commit 041c504

Please sign in to comment.