Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions drf_braces/fields/custom.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import absolute_import, print_function, unicode_literals
import inspect
from decimal import Decimal, getcontext

import pytz
import six
Expand Down Expand Up @@ -69,8 +70,9 @@ class RoundedDecimalField(fields.DecimalField):
to two decimal places.
"""

def __init__(self, max_digits=None, decimal_places=2, *args, **kwargs):
max_digits = max_digits or self.MAX_STRING_LENGTH
def __init__(self, max_digits=None, decimal_places=2, rounding=None, *args, **kwargs):
self.rounding = rounding

super(RoundedDecimalField, self).__init__(
max_digits=max_digits,
decimal_places=decimal_places,
Expand All @@ -83,6 +85,24 @@ def to_internal_value(self, data):
def validate_precision(self, data):
return data

def quantize(self, data):
"""
Quantize the decimal value to the configured precision.
"""
if self.decimal_places is None:
return data

context = getcontext().copy()

if self.max_digits is not None:
context.prec = self.max_digits
if self.rounding is not None:
context.rounding = self.rounding
return data.quantize(
Decimal('.1') ** self.decimal_places,
context=context
)


__all__ = [name for name, value in locals().items()
if inspect.isclass(value) and issubclass(value, fields.Field)]
13 changes: 11 additions & 2 deletions drf_braces/tests/fields/test_custom.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import absolute_import, print_function, unicode_literals
import unittest
from collections import OrderedDict
from decimal import Decimal
from decimal import ROUND_DOWN, Decimal

import mock
import pytz
Expand Down Expand Up @@ -62,8 +62,11 @@ def test_to_internal_value(self):
class TestRoundedDecimalField(unittest.TestCase):
def test_init(self):
field = RoundedDecimalField()
self.assertIsNotNone(field.max_digits)
self.assertEqual(field.decimal_places, 2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add assert about rounding being set

self.assertIsNone(field.rounding)

new_field = RoundedDecimalField(rounding=ROUND_DOWN)
self.assertEqual(new_field.rounding, ROUND_DOWN)

def test_to_internal_value(self):
field = RoundedDecimalField()
Expand All @@ -83,3 +86,9 @@ def test_to_internal_value(self):
self.assertEqual(field.to_internal_value(Decimal('5.2345')), Decimal('5.23'))
self.assertEqual(field.to_internal_value(Decimal('5.2356')), Decimal('5.24'))
self.assertEqual(field.to_internal_value(Decimal('4.2399')), Decimal('4.24'))

floored_field = RoundedDecimalField(rounding=ROUND_DOWN)
self.assertEqual(floored_field.to_internal_value(5.2345), Decimal('5.23'))
self.assertEqual(floored_field.to_internal_value(5.2356), Decimal('5.23'))
self.assertEqual(floored_field.to_internal_value(Decimal('5.2345')), Decimal('5.23'))
self.assertEqual(floored_field.to_internal_value(Decimal('5.2356')), Decimal('5.23'))