Skip to content

Commit

Permalink
Adding start of a high-precision DG1 helper.
Browse files Browse the repository at this point in the history
  • Loading branch information
dhermes committed Feb 27, 2016
1 parent ae475af commit e48a55d
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 1 deletion.
61 changes: 61 additions & 0 deletions assignment1/dg1_high_prec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Helpers to use :mod:`assignment1.dg1` with high-precision numbers.
High-precision is achieved by using :mod:`mpmath`.
"""

import mpmath
import numpy as np


_VECTORIZED_EXP = np.vectorize(mpmath.exp)


class HighPrecProvider(object):
"""High-precision replacement for :class:`assignment1.dg1.MathProvider`.
Implements interfaces that are essentially identical (at least up to the
usage in :mod:`dg1 <assignment1.dg1>`) as those provided by NumPy.
All matrices returned are :class:`numpy.ndarray` with :class:`mpmath.mpf`
as the data type and all matrix inputs are assumed to be of the same form.
"""

@staticmethod
def exp_func(value):
"""Vectorized exponential function."""
return _VECTORIZED_EXP(value)

@staticmethod
def linspace(start, stop, num=50):
"""Linearly spaced points.
Points are computed with :func:`mpmath.linspace` but the
output (a ``list``) is converted back to a :class:`numpy.ndarray`.
"""
return np.array(mpmath.linspace(start, stop, num))

@staticmethod
def num_type(value):
"""The high-precision numerical type: :class:`mpmath.mpf`."""
return mpmath.mpf(value)

@staticmethod
def mat_inv(mat):
"""Matrix inversion, using :mod:`mpmath`."""
inv_mpmath = mpmath.matrix(mat.tolist())**(-1)
return np.array(inv_mpmath.tolist())

@staticmethod
def solve(left_mat, right_mat):
"""Solve ``Ax = b`` for ``x``.
``A`` is given by ``left_mat`` and ``b`` by ``right_mat``.
"""
raise NotImplementedError

@staticmethod
def zeros(shape, **kwargs):
"""Produce a matrix of zeros of a given shape."""
result = np.empty(shape, dtype=mpmath.mpf, **kwargs)
result.fill(mpmath.mpf('0.0'))
return result
109 changes: 109 additions & 0 deletions assignment1/test_dg1_high_prec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import unittest


class TestHighPrecProvider(unittest.TestCase):

@staticmethod
def _get_target_class():
from assignment1.dg1_high_prec import HighPrecProvider
return HighPrecProvider

def test_exp_func(self):
import mpmath
import numpy as np

exp_func = self._get_target_class().exp_func
with mpmath.mp.workprec(100):
scalar_val = mpmath.log('2.0')
result = exp_func(scalar_val)
self.assertEqual(result, mpmath.mpf('2.0'))
mat_val = np.array([
[mpmath.log('2.0'), mpmath.log('3.0'), mpmath.log('4.0')],
[mpmath.log('5.0'), mpmath.log('6.0'), mpmath.log('7.0')],
])
result = exp_func(mat_val)
expected_result = np.array([
[mpmath.mpf('2.0'), mpmath.mpf('3.0'), mpmath.mpf('4.0')],
[mpmath.mpf('5.0'), mpmath.mpf('6.0'), mpmath.mpf('7.0')],
])
self.assertTrue(np.all(result == expected_result))

def test_linspace(self):
import mpmath
import numpy as np

linspace = self._get_target_class().linspace

result1 = linspace(0, 1, 5)
self.assertTrue(np.all(result1 == [0, 0.25, 0.5, 0.75, 1.0]))

with mpmath.mp.workprec(100):
result2 = linspace(0, 1, 12)
result3 = linspace(mpmath.mpf('0'), mpmath.mpf('1'), 12)
self.assertTrue(np.all(result2 == result3))
expected_result = np.array([
mpmath.mpf('0/11'), mpmath.mpf('1/11'), mpmath.mpf('2/11'),
mpmath.mpf('3/11'), mpmath.mpf('4/11'), mpmath.mpf('5/11'),
mpmath.mpf('6/11'), mpmath.mpf('7/11'), mpmath.mpf('8/11'),
mpmath.mpf('9/11'), mpmath.mpf('10/11'), mpmath.mpf('11/11'),
])
self.assertTrue(np.all(result2 == expected_result))

def test_num_type(self):
import mpmath

num_type = self._get_target_class().num_type
self.assertIsInstance(num_type(0), mpmath.mpf)
self.assertIsInstance(num_type(1.0), mpmath.mpf)
self.assertIsInstance(num_type('2.1'), mpmath.mpf)

def test_mat_inv(self):
import mpmath
import numpy as np

mat_inv = self._get_target_class().mat_inv
sq_mat = np.array([
[mpmath.mpf('1'), mpmath.mpf('2')],
[mpmath.mpf('3'), mpmath.mpf('4')],
])
inv_val = mat_inv(sq_mat)
# Check the type of the output.
self.assertIsInstance(inv_val, np.ndarray)
self.assertEqual(inv_val.shape, (2, 2))
all_types = set([type(val) for val in inv_val.flatten()])
self.assertEqual(all_types, set([mpmath.mpf]))

# Check the actual result.
expected_result = np.array([
[mpmath.mpf('-2.0'), mpmath.mpf('1.0')],
[mpmath.mpf('1.5'), mpmath.mpf('-0.5')],
])
delta = np.abs(inv_val - expected_result)
self.assertLess(np.max(delta), 1e-10)

def test_solve(self):
solve = self._get_target_class().solve
with self.assertRaises(NotImplementedError):
solve(None, None)

def test_zeros(self):
import mpmath
import numpy as np

zeros = self._get_target_class().zeros
mat1 = zeros(3)
self.assertIsInstance(mat1, np.ndarray)
self.assertEqual(mat1.shape, (3,))
self.assertEqual(mat1.dtype, object)
self.assertTrue(np.all(mat1 == mpmath.mpf('0.0')))
all_types = set([type(val) for val in mat1.flatten()])
self.assertEqual(all_types, set([mpmath.mpf]))

mat2 = zeros((3, 7), order='F')
self.assertIsInstance(mat2, np.ndarray)
self.assertEqual(mat2.shape, (3, 7))
self.assertEqual(mat2.dtype, object)
self.assertTrue(np.all(mat2 == mpmath.mpf('0.0')))
all_types = set([type(val) for val in mat2.flatten()])
self.assertEqual(all_types, set([mpmath.mpf]))
self.assertTrue(mat2.flags.f_contiguous)
7 changes: 7 additions & 0 deletions docs/assignment1.dg1_high_prec.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
assignment1.dg1_high_prec module
================================

.. automodule:: assignment1.dg1_high_prec
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/assignment1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Submodules
.. toctree::

assignment1.dg1
assignment1.dg1_high_prec
assignment1.dg1_symbolic
assignment1.plotting

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


REQUIREMENTS = (
'mpmath',
'numpy',
'six >= 1.6.1',
'sympy',
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ passenv = TRAVIS*
basepython =
python2.7
commands =
sphinx-apidoc --separate --force -o docs/ . setup.py assignment1/test_dg1.py assignment1/test_dg1_symbolic.py assignment1/test_plotting.py
sphinx-apidoc --separate --force -o docs/ . setup.py assignment1/test_dg1.py assignment1/test_dg1_high_prec.py assignment1/test_dg1_symbolic.py assignment1/test_plotting.py
python -c "import os; os.remove('docs/modules.rst')"
python -c "import shutil; shutil.rmtree('docs/_build', ignore_errors=True)"
sphinx-build -W -b html -d docs/_build/doctrees docs docs/_build/html
Expand Down

0 comments on commit e48a55d

Please sign in to comment.