Skip to content

Commit

Permalink
Adds display support
Browse files Browse the repository at this point in the history
Adds a display option on the Optimizer class (defaults to True)
that during optimization (when the run() method is called), prints
some diagnostic information to the standard output
  • Loading branch information
Niru Maheswaranathan committed May 2, 2016
1 parent da02ad1 commit c5fb89a
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ htmlcov/
cover/
docs/_build
build/
dist/
2 changes: 1 addition & 1 deletion descent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@
from .utils import *
from .main import *

__version__ = '0.1.4'
__version__ = '0.1.5'
29 changes: 23 additions & 6 deletions descent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,31 @@
from collections import namedtuple, defaultdict
from .utils import wrap, restruct, destruct
import numpy as np
import tableprint as tp
try:
from time import perf_counter
except ImportError:
from time import time as perf_counter


class Optimizer(object):

def __init__(self, theta_init):
def __init__(self, theta_init, display=True):
self.iteration = 0
self.theta = theta_init
self.runtimes = list()
self.store = defaultdict(list)
self.display = display

def __next__(self):
raise NotImplementedError

def run(self, maxiter=None):
def optional_print(self, message):
print(message, flush=True) if self.display else None

def run(self, maxiter=None):
maxiter = np.inf if maxiter is None else (maxiter + self.iteration)

try:
self.optional_print(tp.header(['Iteration', 'Objective', 'Runtime']))
for k in count(start=self.iteration):

self.iteration = k
Expand All @@ -41,21 +44,35 @@ def run(self, maxiter=None):
# TODO: run callbacks
self.store['objective'].append(self.objective(destruct(self.theta)))

# Update display
self.optional_print(tp.row([self.iteration,
self.store['objective'][-1],
tp.humantime(self.runtimes[-1])]))

# TODO: check for convergence
if k >= maxiter:
break

except KeyboardInterrupt:
pass

# cleanup
self.optional_print(tp.hr(3))
self.optional_print(u'\u279b Final objective: {}'.format(self.store['objective'][-1]))
self.optional_print(u'\u279b Total runtime: {}'.format(tp.humantime(sum(self.runtimes))))
self.optional_print(u'\u279b Per iteration runtime: {} +/- {}'.format(
tp.humantime(np.mean(self.runtimes)),
tp.humantime(np.std(self.runtimes)),
))

def restruct(self, x):
return restruct(x, self.theta)


@implements_iterator
class GradientDescent(Optimizer):

def __init__(self, theta_init, f_df, algorithm, options, proxop=None, rho=None):
def __init__(self, theta_init, f_df, algorithm, options=None, proxop=None, rho=None):
options = {} if options is None else options

super().__init__(theta_init)
self.objective, self.gradient = wrap(f_df, theta_init)
Expand Down

0 comments on commit c5fb89a

Please sign in to comment.