Skip to content

Commit

Permalink
fix: Missing the right Money.decimal_places and `Money.decimal_plac…
Browse files Browse the repository at this point in the history
…es_display` values after some arithmetic operations.

Ref: #595
  • Loading branch information
Stranger6667 committed Jan 10, 2021
1 parent 9ac8d8b commit 69e89f9
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 35 deletions.
57 changes: 47 additions & 10 deletions djmoney/money.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,44 +39,54 @@ def decimal_places_display(self, value):
""" Set number of digits being displayed - `None` resets to `DECIMAL_PLACES_DISPLAY` setting """
self._decimal_places_display = value

def _fix_decimal_places(self, *args):
""" Make sure to user 'biggest' number of decimal places of all given money instances """
candidates = (getattr(candidate, "decimal_places", 0) for candidate in args)
return max([self.decimal_places, *candidates])
def _copy_attributes(self, source, target):
"""Copy attributes to the new `Money` instance.
This class stores extra bits of information about string formatting that the parent class doesn't have.
The problem is that the parent class creates new instances of `Money` without in some of its methods and
it does so without knowing about `django-money`-level attributes.
For this reason, when this class uses some methods of the parent class that have this behavior, the resulting
instances lose those attribute values.
When it comes to what number of decimal places to choose, we take the maximum number.
"""
for attribute_name in ("decimal_places", "decimal_places_display"):
value = max([getattr(candidate, attribute_name, 0) for candidate in (self, source)])
setattr(target, attribute_name, value)

def __add__(self, other):
if isinstance(other, F):
return other.__radd__(self)
other = maybe_convert(other, self.currency)
result = super().__add__(other)
result.decimal_places = self._fix_decimal_places(other)
self._copy_attributes(other, result)
return result

def __sub__(self, other):
if isinstance(other, F):
return other.__rsub__(self)
other = maybe_convert(other, self.currency)
result = super().__sub__(other)
result.decimal_places = self._fix_decimal_places(other)
self._copy_attributes(other, result)
return result

def __mul__(self, other):
if isinstance(other, F):
return other.__rmul__(self)
result = super().__mul__(other)
result.decimal_places = self._fix_decimal_places(other)
self._copy_attributes(other, result)
return result

def __truediv__(self, other):
if isinstance(other, F):
return other.__rtruediv__(self)
result = super().__truediv__(other)
if isinstance(result, self.__class__):
result.decimal_places = self._fix_decimal_places(other)
self._copy_attributes(other, result)
return result

def __rtruediv__(self, other):
# Backported from py-moneyd, non released bug-fix
# Backported from py-moneyed, non released bug-fix
# https://github.com/py-moneyed/py-moneyed/blob/c518745dd9d7902781409daec1a05699799474dd/moneyed/classes.py#L217-L218
raise TypeError("Cannot divide non-Money by a Money instance.")

Expand All @@ -100,7 +110,34 @@ def __html__(self):

def __round__(self, n=None):
amount = round(self.amount, n)
return self.__class__(amount, self.currency)
new = self.__class__(amount, self.currency)
self._copy_attributes(self, new)
return new

def round(self, ndigits=0):
new = super().round(ndigits)
self._copy_attributes(self, new)
return new

def __pos__(self):
new = super().__pos__()
self._copy_attributes(self, new)
return new

def __neg__(self):
new = super().__neg__()
self._copy_attributes(self, new)
return new

def __abs__(self):
new = super().__abs__()
self._copy_attributes(self, new)
return new

def __rmod__(self, other):
new = super().__rmod__(other)
self._copy_attributes(self, new)
return new

# DefaultMoney sets those synonym functions
# we overwrite the 'targets' so the wrong synonyms are called
Expand Down
2 changes: 2 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Changelog

- Pin ``pymoneyed<1.0`` as it changed the ``repr`` output of the ``Money`` class.
- Subtracting ``Money`` from ``moneyed.Money``. Regression, introduced in ``1.2``. `#593`_
- Missing the right ``Money.decimal_places`` and ``Money.decimal_places_display`` values after some arithmetic operations. `#595`_

`1.2.2`_ - 2020-12-29
---------------------
Expand Down Expand Up @@ -695,6 +696,7 @@ wrapping with ``money_manager``.
.. _0.3: https://github.com/django-money/django-money/compare/0.2...0.3
.. _0.2: https://github.com/django-money/django-money/compare/0.2...a6d90348085332a393abb40b86b5dd9505489b04

.. _#595: https://github.com/django-money/django-money/issues/595
.. _#593: https://github.com/django-money/django-money/issues/593
.. _#586: https://github.com/django-money/django-money/issues/586
.. _#585: https://github.com/django-money/django-money/pull/585
Expand Down
50 changes: 25 additions & 25 deletions tests/test_money.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,31 +80,31 @@ def test_add_decimal_places_zero():
assert result.decimal_places == 3


def test_mul_decimal_places():
""" Test __mul__ and __rmul__ """
two = Money("1.0000", "USD", decimal_places=4)

result = 2 * two
assert result.decimal_places == 4

result = two * 2
assert result.decimal_places == 4


def test_fix_decimal_places():
one = Money(1, "USD", decimal_places=7)
assert one._fix_decimal_places(Money(2, "USD", decimal_places=3)) == 7
assert one._fix_decimal_places(Money(2, "USD", decimal_places=30)) == 30


def test_fix_decimal_places_none():
one = Money(1, "USD", decimal_places=7)
assert one._fix_decimal_places(None) == 7


def test_fix_decimal_places_multiple():
one = Money(1, "USD", decimal_places=7)
assert one._fix_decimal_places(None, Money(3, "USD", decimal_places=8)) == 8
@pytest.mark.parametrize("decimal_places", (1, 4))
@pytest.mark.parametrize(
"operation",
(
lambda a, d: a * 2,
lambda a, d: 2 * a,
lambda a, d: a / 5,
lambda a, d: a - Money("2", "USD", decimal_places=d, decimal_places_display=d),
lambda a, d: Money("2", "USD", decimal_places=d, decimal_places_display=d) - a,
lambda a, d: a + Money("2", "USD", decimal_places=d, decimal_places_display=d),
lambda a, d: Money("2", "USD", decimal_places=d, decimal_places_display=d) + a,
lambda a, d: -a,
lambda a, d: +a,
lambda a, d: abs(a),
lambda a, d: 5 % a,
lambda a, d: round(a),
lambda a, d: a.round(),
),
)
def test_keep_decimal_places(operation, decimal_places):
# Arithmetic operations should keep the `decimal_places` value
amount = Money("1.0000", "USD", decimal_places=decimal_places, decimal_places_display=decimal_places)
new = operation(amount, decimal_places)
assert new.decimal_places == decimal_places
assert new.decimal_places_display == decimal_places


def test_decimal_places_display_overwrite():
Expand Down

0 comments on commit 69e89f9

Please sign in to comment.