diff --git a/drf_braces/fields/custom.py b/drf_braces/fields/custom.py index 12f4b32..af6acfc 100644 --- a/drf_braces/fields/custom.py +++ b/drf_braces/fields/custom.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, print_function, unicode_literals import inspect +from decimal import Decimal, getcontext import pytz import six @@ -69,8 +70,9 @@ class RoundedDecimalField(fields.DecimalField): to two decimal places. """ - def __init__(self, max_digits=None, decimal_places=2, *args, **kwargs): - max_digits = max_digits or self.MAX_STRING_LENGTH + def __init__(self, max_digits=None, decimal_places=2, rounding=None, *args, **kwargs): + self.rounding = rounding + super(RoundedDecimalField, self).__init__( max_digits=max_digits, decimal_places=decimal_places, @@ -83,6 +85,24 @@ def to_internal_value(self, data): def validate_precision(self, data): return data + def quantize(self, data): + """ + Quantize the decimal value to the configured precision. + """ + if self.decimal_places is None: + return data + + context = getcontext().copy() + + if self.max_digits is not None: + context.prec = self.max_digits + if self.rounding is not None: + context.rounding = self.rounding + return data.quantize( + Decimal('.1') ** self.decimal_places, + context=context + ) + __all__ = [name for name, value in locals().items() if inspect.isclass(value) and issubclass(value, fields.Field)] diff --git a/drf_braces/tests/fields/test_custom.py b/drf_braces/tests/fields/test_custom.py index 355010c..e9f9980 100644 --- a/drf_braces/tests/fields/test_custom.py +++ b/drf_braces/tests/fields/test_custom.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, print_function, unicode_literals import unittest from collections import OrderedDict -from decimal import Decimal +from decimal import ROUND_DOWN, Decimal import mock import pytz @@ -62,8 +62,11 @@ def test_to_internal_value(self): class TestRoundedDecimalField(unittest.TestCase): def test_init(self): field = RoundedDecimalField() - self.assertIsNotNone(field.max_digits) self.assertEqual(field.decimal_places, 2) + self.assertIsNone(field.rounding) + + new_field = RoundedDecimalField(rounding=ROUND_DOWN) + self.assertEqual(new_field.rounding, ROUND_DOWN) def test_to_internal_value(self): field = RoundedDecimalField() @@ -83,3 +86,9 @@ def test_to_internal_value(self): self.assertEqual(field.to_internal_value(Decimal('5.2345')), Decimal('5.23')) self.assertEqual(field.to_internal_value(Decimal('5.2356')), Decimal('5.24')) self.assertEqual(field.to_internal_value(Decimal('4.2399')), Decimal('4.24')) + + floored_field = RoundedDecimalField(rounding=ROUND_DOWN) + self.assertEqual(floored_field.to_internal_value(5.2345), Decimal('5.23')) + self.assertEqual(floored_field.to_internal_value(5.2356), Decimal('5.23')) + self.assertEqual(floored_field.to_internal_value(Decimal('5.2345')), Decimal('5.23')) + self.assertEqual(floored_field.to_internal_value(Decimal('5.2356')), Decimal('5.23'))