-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding gradient descent optimisation algorithm for single-variable fu…
…nctions
- Loading branch information
Showing
7 changed files
with
137 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,25 @@ | ||
# Files | ||
|
||
ands/algorithms/greedy/huffman.py | ||
ands/ds/Graph.py | ||
ands/ds/DirectedGraph.py | ||
ands/ds/UndirectedGraph.py | ||
|
||
ands/algorithms/dp/previous_larger_element.py | ||
|
||
ands/algorithms/greedy/huffman.py | ||
/ands/ds/MinPriorityQueue.py | ||
|
||
ands/algorithms/graphs/ | ||
ands/algorithms/unclassified/ | ||
/ands/algorithms/math/combinatorics/ | ||
/ands/algorithms/math/ | ||
ands/algorithms/dp/previous_larger_element.py | ||
|
||
*.eggx | ||
*.py[cod] | ||
*$py.class | ||
.DS_Store | ||
|
||
notes/ | ||
.idea/ | ||
_ignore/ | ||
venv/ | ||
*.egg-info/ | ||
__pycache__/ | ||
|
||
notes/ | ||
.idea/ | ||
_ignore/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
|
||
|
||
""" | ||
# Meta-info | ||
Author: Nelson Brochado | ||
Created: 14/10/2017 | ||
Updated: 26/10/2017 | ||
# Description | ||
An implementation of the gradient descent method for finding local minima of | ||
single-variable functions. | ||
# References | ||
- https://en.wikipedia.org/wiki/Gradient_descent | ||
""" | ||
|
||
__all__ = ["gradient_descent"] | ||
|
||
|
||
def gradient_descent(x0: float, | ||
df: callable, | ||
step_size: float = 0.01, | ||
max_iter: int = 50, | ||
tol: float = 1e-6): | ||
"""Finds a local minimum of a function whose derivative is df starting from | ||
an initial guess x0 using a step size = step_size.""" | ||
|
||
# From calculation, it is expected that the local minimum occurs at x=9/4 | ||
if not callable(df): | ||
raise TypeError("df must be a callable object.") | ||
|
||
x = x0 | ||
|
||
for _ in range(max_iter): | ||
x_next = x + -step_size * df(x) # Gradient descent step. | ||
|
||
if abs(x_next - x) < tol * abs(x_next): | ||
x = x_next | ||
break | ||
|
||
x = x_next | ||
|
||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
|
||
|
||
""" | ||
# Meta-info | ||
Author: Nelson Brochado | ||
Created: 26/10/2017 | ||
Updated: 26/10/2017 | ||
# Description | ||
Unittests for the functions inside ands.algorithms.numerical.gradient_descent.py. | ||
""" | ||
|
||
import unittest | ||
|
||
from ands.algorithms.numerical.gradient_descent import * | ||
|
||
''' | ||
def f(x: float) -> float: | ||
return x ** 4 - 3 * x ** 3 + 2 | ||
''' | ||
|
||
|
||
def df(x: float) -> float: | ||
"""Derivative of f.""" | ||
return 4 * x ** 3 - 9 * x ** 2 | ||
|
||
|
||
class TestGradientDescent(unittest.TestCase): | ||
def test_type_error_when_df_not_callable(self): | ||
self.assertRaises(TypeError, gradient_descent, 0.3, 5) | ||
|
||
def test_find_local_min_of_f(self): | ||
self.assertAlmostEqual(gradient_descent(6, df), 2.24674, 5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters