Skip to content

Commit

Permalink
fix minimize
Browse files Browse the repository at this point in the history
  • Loading branch information
papajohn committed Mar 23, 2016
1 parent 925ba9e commit 4f4fc75
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 13 deletions.
7 changes: 5 additions & 2 deletions datascience/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def where(self, column_or_label, value_or_predicate=None):
Red | Round | 4 | 1.3
Red | Round | 7 | 1.75
Green | Round | 2 | 1
>>> marbles.where(marbles.column("Shape") == "Round") # equivalent to the previous example
>>> marbles.where(marbles.column("Shape") == "Round") # equivalent to previous example
Color | Shape | Amount | Price
Red | Round | 4 | 1.3
Red | Round | 7 | 1.75
Expand All @@ -658,8 +658,11 @@ def where(self, column_or_label, value_or_predicate=None):
Color | Shape | Amount | Price
Blue | Rectangular | 12 | 2
Red | Round | 7 | 1.75
You can also use predicates to simplify single-column comparisons.
>>> from datascience.predicates import are
>>> marbles.where("Price", are.above(1.5))
>>> marbles.where("Price", are.above(1.5)) # equivalent to previous example
Color | Shape | Amount | Price
Blue | Rectangular | 12 | 2
Red | Round | 7 | 1.75
Expand Down
21 changes: 14 additions & 7 deletions datascience/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def table_apply(table, func, subset=None):
tab = Table.from_df(df)
return tab

def minimize(f, start=None, smooth=False, log=None, **vargs):
def minimize(f, start=None, smooth=False, log=None, array=False, **vargs):
"""Minimize a function f of one or more arguments.
Args:
Expand All @@ -137,20 +137,27 @@ def minimize(f, start=None, smooth=False, log=None, **vargs):
(b) an array of minimizing arguments of a multi-argument function
"""
if start is None:
assert not array, "Please pass starting values explicitly when array=True"
arg_count = f.__code__.co_argcount
assert arg_count > 0, "Please pass starting values explicitly"
assert arg_count > 0, "Please pass starting values explicitly for variadic functions"
start = [0] * arg_count
if not hasattr(start, '__len__'):
start = [start]

@functools.wraps(f)
def wrapper(args):
return f(*args)
if array:
objective = f
else:
@functools.wraps(f)
def objective(args):
return f(*args)

if not smooth and 'method' not in vargs:
vargs['method'] = 'Powell'
result = optimize.minimize(wrapper, start, **vargs)
result = optimize.minimize(objective, start, **vargs)
if log is not None:
log(result)
return result.x
if len(start) == 1:
return result.x.item(0)
else:
return result.x

24 changes: 20 additions & 4 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,24 @@ def test_table_apply():
assert all(newtab['a'] == tab['a'])
assert all(newtab['b'] == tab['b'] + 1)

def _round_eq(a, b):
if hasattr(a, '__len__'):
return all(a == np.round(b))
else:
return (a == np.round(b)) == True

def test_minimize():
assert (2 == ds.minimize(lambda x: (x-2)**2)) == True
assert [2, 1] == list(ds.minimize(lambda x, y: (x-2)**2 + (y-1)**2))
assert (2 == ds.minimize(lambda x: (x-2)**2, 1)) == True
assert [2, 1] == list(ds.minimize(lambda x, y: (x-2)**2 + (y-1)**2, [1, 1]))
assert _round_eq(2, ds.minimize(lambda x: (x-2)**2))
assert _round_eq([2, 1], list(ds.minimize(lambda x, y: (x-2)**2 + (y-1)**2)))
assert _round_eq(2, ds.minimize(lambda x: (x-2)**2, 1))
assert _round_eq([2, 1], list(ds.minimize(lambda x, y: (x-2)**2 + (y-1)**2, [1, 1])))

def test_minimize_smooth():
assert _round_eq(2, ds.minimize(lambda x: (x-2)**2, smooth=True))
assert _round_eq([2, 1], list(ds.minimize(lambda x, y: (x-2)**2 + (y-1)**2, smooth=True)))
assert _round_eq(2, ds.minimize(lambda x: (x-2)**2, 1, smooth=True))
assert _round_eq([2, 1], list(ds.minimize(lambda x, y: (x-2)**2 + (y-1)**2, [1, 1], smooth=True)))

def test_minimize_array():
assert _round_eq(2, ds.minimize(lambda x: (x[0]-2)**2, [0], array=True))
assert _round_eq([2, 1], list(ds.minimize(lambda x: (x[0]-2)**2 + (x[1]-1)**2, [0, 0], array=True)))

0 comments on commit 4f4fc75

Please sign in to comment.