Skip to content

Commit

Permalink
cover int64/bigInt with tests, address precision handling
Browse files Browse the repository at this point in the history
In testing amounts at and near max int64, the hard-coded precision at
16 digits caused the amounts to be silently zeroed out.  Predicting
the correct precision isn't exact, and should be extended, but attempts
to over-estimate the required precision to perform the calculation.
  • Loading branch information
Kunde21 committed Feb 12, 2021
1 parent 812367a commit a39e001
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 14 deletions.
46 changes: 32 additions & 14 deletions amount.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"database/sql/driver"
"encoding/json"
"fmt"
"math"
"math/big"
"strings"

Expand Down Expand Up @@ -79,21 +80,21 @@ func NewAmount(n, currencyCode string) (Amount, error) {
}

// NewAmountFromBigInt creates a new Amount from a big integer and a currency code.
func NewAmountFromBigInt(amt *big.Int, currencyCode string) (Amount, error) {
if amt == nil {
return Amount{}, InvalidNumberError{"NewAmountFromBigInt", fmt.Sprint(amt)}
func NewAmountFromBigInt(n *big.Int, currencyCode string) (Amount, error) {
if n == nil {
return Amount{}, InvalidNumberError{"NewAmountFromBigInt", fmt.Sprint(n)}
}
if currencyCode == "" || !IsValid(currencyCode) {
return Amount{}, InvalidCurrencyCodeError{"NewAmountFromBigInt", currencyCode}
}
d, _ := GetDigits(currencyCode)

return Amount{apd.NewWithBigInt(amt, -int32(d)), currencyCode}, nil
return Amount{apd.NewWithBigInt(n, -int32(d)), currencyCode}, nil
}

// NewAmount creates a new Amount from an int64 and a currency code.
func NewAmountFromInt64(amt int64, currencyCode string) (Amount, error) {
return NewAmountFromBigInt(big.NewInt(amt), currencyCode)
// NewAmountFromInt64 creates a new Amount from an int64 and a currency code.
func NewAmountFromInt64(n int64, currencyCode string) (Amount, error) {
return NewAmountFromBigInt(big.NewInt(n), currencyCode)
}

// Number returns the number as a numeric string.
Expand Down Expand Up @@ -133,7 +134,9 @@ func (a Amount) BigInt() *big.Int {
// Int64 returns the integer value of a in minor units.
// Returns an error if value can't be expressed as a 64-bit integer.
func (a Amount) Int64() (int64, error) {
return a.Round().number.Int64()
n := *a.Round().number
n.Exponent = 0
return n.Int64()
}

// Convert converts a to a different currency.
Expand All @@ -145,7 +148,7 @@ func (a Amount) Convert(currencyCode, rate string) (Amount, error) {
if err != nil {
return Amount{}, InvalidNumberError{"Amount.Convert", rate}
}
ctx := apd.BaseContext.WithPrecision(16)
ctx := contextPrecision(a.number.NumDigits(), result.NumDigits())
ctx.Mul(result, a.number, result)

return Amount{result, currencyCode}, nil
Expand All @@ -157,7 +160,7 @@ func (a Amount) Add(b Amount) (Amount, error) {
return Amount{}, MismatchError{"Amount.Add", a, b}
}
result := apd.New(0, 0)
ctx := apd.BaseContext.WithPrecision(16)
ctx := contextPrecision(a.number.NumDigits(), b.number.NumDigits())
ctx.Add(result, a.number, b.number)

return Amount{result, a.currencyCode}, nil
Expand All @@ -169,7 +172,7 @@ func (a Amount) Sub(b Amount) (Amount, error) {
return Amount{}, MismatchError{"Amount.Sub", a, b}
}
result := apd.New(0, 0)
ctx := apd.BaseContext.WithPrecision(16)
ctx := contextPrecision(a.number.NumDigits(), result.NumDigits())
ctx.Sub(result, a.number, b.number)

return Amount{result, a.currencyCode}, nil
Expand All @@ -181,7 +184,7 @@ func (a Amount) Mul(n string) (Amount, error) {
if err != nil {
return Amount{}, InvalidNumberError{"Amount.Mul", n}
}
ctx := apd.BaseContext.WithPrecision(16)
ctx := contextPrecision(a.number.NumDigits(), result.NumDigits())
ctx.Mul(result, a.number, result)

return Amount{result, a.currencyCode}, err
Expand All @@ -193,7 +196,7 @@ func (a Amount) Div(n string) (Amount, error) {
if err != nil || result.IsZero() {
return Amount{}, InvalidNumberError{"Amount.Div", n}
}
ctx := apd.BaseContext.WithPrecision(16)
ctx := contextPrecision(a.number.NumDigits(), result.NumDigits())
ctx.Quo(result, a.number, result)

return Amount{result, a.currencyCode}, err
Expand All @@ -216,13 +219,28 @@ func (a Amount) RoundTo(digits uint8, mode RoundingMode) Amount {
RoundDown: apd.RoundDown,
}
result := apd.New(0, 0)
ctx := apd.BaseContext.WithPrecision(16)
ctx := contextPrecision(a.number.NumDigits())
ctx.Rounding = extModes[mode]
ctx.Quantize(result, a.number, -int32(digits))

return Amount{result, a.currencyCode}
}

func contextPrecision(digits ...int64) *apd.Context {
dg := big.NewInt(0)
for _, d := range digits {
dg = dg.Add(dg, big.NewInt(d))
}
switch {
case !dg.IsInt64(), dg.Int64() > math.MaxUint32:
return apd.BaseContext.WithPrecision(math.MaxUint32)
case dg.Int64() < 16:
return apd.BaseContext.WithPrecision(16)
default:
return apd.BaseContext.WithPrecision(uint32(dg.Int64()))
}
}

// Cmp compares a and b and returns:
//
// -1 if a < b
Expand Down
116 changes: 116 additions & 0 deletions amount_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ package currency_test

import (
"encoding/json"
"fmt"
"math"
"math/big"
"strconv"
"testing"

"github.com/bojanz/currency"
Expand Down Expand Up @@ -58,6 +62,68 @@ func TestNewAmount(t *testing.T) {
}
}

func TestAmount_Int(t *testing.T) {
const cur = "EUR"
tests := []struct {
number string
want string
}{
{"1099", "10.99"},
{"1234567890123456789034", "12345678901234567890.34"},
{strconv.Itoa(math.MaxInt64), "92233720368547758.07"},
}

for _, tt := range tests {
bi, ok := big.NewInt(0).SetString(tt.number, 10)
if !ok {
t.Fatalf("invalid big int %v", tt.number)
}
a, err := currency.NewAmountFromBigInt(bi, cur)
if err != nil {
t.Errorf("unexpected error %v", err)
}
if a.Number() != tt.want {
t.Errorf("got %v want %v", a.Number(), tt.want)
}
wantStr := tt.want + " " + cur
if a.String() != wantStr {
t.Errorf("got %v want %v", a.Number(), wantStr)
}

if a.BigInt().Cmp(bi) != 0 {
t.Errorf("got %v want %v", a.BigInt(), bi)
}

// only test valid int64 values
if !bi.IsInt64() {
continue
}
ai, err := currency.NewAmountFromInt64(bi.Int64(), cur)
if err != nil {
t.Errorf("unexpected error %v", err)
}
if ai.Number() != tt.want {
t.Errorf("got %v want %v", ai.Number(), tt.want)
}
if ai.String() != wantStr {
t.Errorf("got %v want %v", ai.Number(), wantStr)
}

ival, err := ai.Int64()
if err != nil {
t.Fatalf("unexpected Int64 error %v", err)
}

if ival != bi.Int64() {
t.Errorf("got %v want %v", ival, bi.Int64())
}

if !ai.Equal(a) {
t.Errorf("%v (int64) != %v (bigInt)", ai, a)
}
}
}

func TestAmount_ToMinorUnits(t *testing.T) {
tests := []struct {
number string
Expand Down Expand Up @@ -344,6 +410,12 @@ func TestAmount_RoundTo(t *testing.T) {
{"12.345", 0, currency.RoundHalfDown, "12"},
{"12.345", 0, currency.RoundUp, "13"},
{"12.345", 0, currency.RoundDown, "12"},

// large amounts (> max int64).
{"12345678901234567890.0345", 3, currency.RoundHalfUp, "12345678901234567890.035"},
{"12345678901234567890.0345", 3, currency.RoundHalfDown, "12345678901234567890.034"},
{"12345678901234567890.0345", 3, currency.RoundUp, "12345678901234567890.035"},
{"12345678901234567890.0345", 3, currency.RoundDown, "12345678901234567890.034"},
}

for _, tt := range tests {
Expand Down Expand Up @@ -655,3 +727,47 @@ func TestAmount_Scan(t *testing.T) {
})
}
}

func TestAmount_BigInt(t *testing.T) {
tests := []struct {
n string
cur string
wantError error
}{
{"", "USD", currency.InvalidNumberError{"NewAmountFromBigInt", fmt.Sprint(nil)}},
{"100", "UST", currency.InvalidCurrencyCodeError{"NewAmountFromBigInt", "UST"}},
{"100", "JPY", nil},
}
for _, tt := range tests {
t.Run(tt.n+tt.cur, func(t *testing.T) {
bi, ok := big.NewInt(0).SetString(tt.n, 10)
if !ok {
if tt.n != "" {
t.Fatal("number not parsed", tt.n)
}
bi = nil
}
amt, err := currency.NewAmountFromBigInt(bi, tt.cur)
if err != tt.wantError {
t.Errorf("error: got %v, want %v", err, tt.wantError)
}
if tt.n == "" {
empty := currency.Amount{}
if amt != empty {
t.Errorf("error: got %v, want %v", err, tt.wantError)
}
return
}
if tt.wantError != nil {
return
}
exAmt, err := currency.NewAmount(tt.n, tt.cur)
if err != tt.wantError {
t.Errorf("error: got %v, want %v", err, tt.wantError)
}
if !amt.Equal(exAmt) {
t.Errorf("amt: got %v, want %v", amt, exAmt)
}
})
}
}

0 comments on commit a39e001

Please sign in to comment.