Skip to content

Commit

Permalink
Add initial unit test for Problem class
Browse files Browse the repository at this point in the history
Signed-off-by: Niklas Koep <niklas.koep@gmail.com>
  • Loading branch information
nkoep committed Feb 16, 2016
1 parent e052763 commit 4bce044
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 20 deletions.
38 changes: 18 additions & 20 deletions pymanopt/tools/autodiff/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,19 @@


def compile(problem, need_grad, need_hess):
# Conditionally load autodiff backend if needed
if ((need_grad and problem.grad is None and problem.egrad is None) or
(need_hess and problem.hess is None and problem.ehess is None)):
if isinstance(problem.cost, T.TensorVariable):
if not isinstance(problem.arg, T.TensorVariable):
raise ValueError(
"Theano backend requires an argument with respect to "
"which compilation of the cost function is to be carried "
"out")
backend = _theano
elif callable(problem.cost):
backend = _autograd
else:
raise ValueError("Cannot identify autodiff backend from cost "
"variable.")
# Conditionally load autodiff backend if needed.
if isinstance(problem.cost, T.TensorVariable):
if not isinstance(problem.arg, T.TensorVariable):
raise ValueError(
"Theano backend requires an argument with respect to "
"which compilation of the cost function is to be carried "
"out")
backend = _theano
elif callable(problem.cost):
backend = _autograd
else:
raise ValueError("Cannot identify autodiff backend from cost "
"variable.")

if problem.verbosity >= 1:
print("Compiling cost function...")
Expand All @@ -28,10 +26,10 @@ def compile(problem, need_grad, need_hess):
if problem.verbosity >= 1:
print("Computing gradient of cost function...")
problem.egrad = backend.gradient(problem.cost, problem.arg)
# Assume if Hessian is needed gradient is as well
if need_hess and problem.ehess is None and problem.hess is None:
if problem.verbosity >= 1:
print("Computing Hessian of cost function...")
problem.ehess = backend.hessian(problem.cost, problem.arg)

if need_hess and problem.ehess is None and problem.hess is None:
if problem.verbosity >= 1:
print("Computing Hessian of cost function...")
problem.ehess = backend.hessian(problem.cost, problem.arg)

problem.cost = compiled_cost_function
28 changes: 28 additions & 0 deletions tests/test_problem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import unittest

import numpy as np
import numpy.linalg as la
import numpy.random as rnd
import numpy.testing as np_testing

import theano.tensor as T

import warnings

from pymanopt import Problem
from pymanopt.manifolds import Sphere


class TestProblem(unittest.TestCase):
def setUp(self):
self.X = X = T.vector()
self.cost = T.exp(T.sum(X**2))

n = self.n = 15

self.man = Sphere(n)

def test_compile(self):
problem = Problem(man=self.man, cost=self.cost)
with self.assertRaises(ValueError):
problem.prepare()

0 comments on commit 4bce044

Please sign in to comment.