Skip to content

Commit

Permalink
Logistic
Browse files Browse the repository at this point in the history
  • Loading branch information
linbrian committed Dec 4, 2018
1 parent 89163d9 commit 8eec983
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
22 changes: 22 additions & 0 deletions autodiff/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,28 @@ def tanh(sclr):
"""
return sinh(sclr) / cosh(sclr);


@vectorize
def logistic(sclr):
"""
This function takes in an int, float, or Scalar object and applies the logistic function to the value. If the argument is an int or float, then the function returns a float. If the argument is a Scalar object, the function returns a new Scalar object with the updated value and derivative.
INPUTS
=======
sclr: An int, float, or Scalar object on which the the logistic function will be applied.
RETURNS
========
float, Scalar
A float is returned if the input is an int/float. A new Scalar object, resulting from applying the logistic function to 'sclr', is returned if the input is a Scalar object .
NOTES
=====
POST:
- 'sclr' is not changed by the function
- returns a float or Scalar object, resulting from applying the logistic function to 'sclr'.
"""

return 1 / (1 + exp(-sclr))


@vectorize
def log(sclr, base):
"""
Expand Down
26 changes: 26 additions & 0 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,32 @@ def test_exp():
assert(np.isclose(z.getDeriv()['x'], 5 * np.exp(-15)))
assert(np.isclose(z.getDeriv()['y'], -3 * np.exp(-15)))

def test_logistic():
x = ad.Scalar('x', 8)
y = ad.logistic(x)
assert(np.isclose(y.getValue(), 1 / (1 + np.exp(-8))))
assert(np.isclose(y.getDeriv()['x'], (1 / (1 + np.exp(-8))) * (1 - 1 / (1 + np.exp(-8)))))

x = ad.Scalar('x', 0)
y = ad.logistic(x)
assert(np.isclose(y.getValue(), 0.5))
assert(np.isclose(y.getDeriv()['x'], (0.25)))

x = ad.Scalar('x', 5)
y = ad.Scalar('y', 2)
z = ad.logistic(x * y)
assert(np.isclose(z.getValue(), 1 / (1 + np.exp(-10))))
assert(np.isclose(z.getDeriv()['x'], 2 * (1 / (1 + np.exp(-10))) * (1 - 1 / (1 + np.exp(-10)))))
assert(np.isclose(z.getDeriv()['y'], 5 * (1 / (1 + np.exp(-10))) * (1 - 1 / (1 + np.exp(-10)))))

x = ad.Scalar('x', 8)
y = ad.logistic(2 * x)
assert(np.isclose(y.getValue(), 1 / (1 + np.exp(-16))))
assert(np.isclose(y.getDeriv()['x'], 2 * (1 / (1 + np.exp(-16))) * (1 - 1 / (1 + np.exp(-16)))))

assert(np.isclose(ad.logistic(0), 0.5))



def test_power():
assert(ad.power(5, 3) == 125.0)
Expand Down

0 comments on commit 8eec983

Please sign in to comment.