Permalink
Browse files

Use int64 to calculate totals.

So we can get the right value if a discount calculation overflows its value.

Signed-off-by: David Calavera <david.calavera@gmail.com>
  • Loading branch information...
calavera committed Nov 22, 2017
1 parent 6f9eb5c commit 09db216158df64bb6a3e19e9bfb1dbdeb8cf7b5c
Showing with 87 additions and 23 deletions.
  1. +12 −6 calculator/calculator.go
  2. +16 −16 calculator/calculator_test.go
  3. +31 −0 claims/claims_test.go
  4. +24 −0 claims/test/jwt_payload_fixture.json
  5. +4 −1 models/order.go
View
@@ -15,7 +15,7 @@ type Price struct {
Subtotal uint64
Discount uint64
Taxes uint64
Total uint64
Total int64
}
// ItemPrice is the price of a single line item.
@@ -25,7 +25,7 @@ type ItemPrice struct {
Subtotal uint64
Discount uint64
Taxes uint64
Total uint64
Total int64
}
// Settings represent the site-wide settings for price calculation.
@@ -173,7 +173,10 @@ func CalculatePrice(settings *Settings, jwtClaims map[string]interface{}, params
}
for _, item := range params.Items {
lineLogger := priceLogger.WithField("product_type", item.ProductType()).WithField("product_sku", item.ProductSku())
lineLogger := priceLogger.WithFields(logrus.Fields{
"product_type": item.ProductType(),
"product_sku": item.ProductSku(),
})
itemPrice := ItemPrice{Quantity: item.GetQuantity()}
itemPrice.Subtotal = item.PriceInLowestUnit()
@@ -228,7 +231,7 @@ func CalculatePrice(settings *Settings, jwtClaims map[string]interface{}, params
}
}
itemPrice.Total = itemPrice.Subtotal - itemPrice.Discount + itemPrice.Taxes
itemPrice.Total = int64(itemPrice.Subtotal+itemPrice.Taxes) - int64(itemPrice.Discount)
if itemPrice.Total < 0 {
itemPrice.Total = 0
}
@@ -246,10 +249,13 @@ func CalculatePrice(settings *Settings, jwtClaims map[string]interface{}, params
price.Subtotal += (itemPrice.Subtotal * itemPrice.Quantity)
price.Discount += (itemPrice.Discount * itemPrice.Quantity)
price.Taxes += (itemPrice.Taxes * itemPrice.Quantity)
price.Total += (itemPrice.Total * itemPrice.Quantity)
price.Total += (itemPrice.Total * int64(itemPrice.Quantity))
}
price.Total = price.Subtotal - price.Discount + price.Taxes
price.Total = int64(price.Subtotal+price.Taxes) - int64(price.Discount)
if price.Total < 0 {
price.Total = 0
}
priceLogger.WithFields(
logrus.Fields{
"total_price": price.Total,
@@ -79,7 +79,7 @@ func (c *TestCoupon) FixedDiscount(currency string) uint64 {
func TestNoItems(t *testing.T) {
params := PriceParameters{"USA", "USD", nil, nil}
price := CalculatePrice(nil, nil, params, testLogger)
assert.Equal(t, uint64(0), price.Total)
assert.Equal(t, int64(0), price.Total)
}
func TestNoTaxes(t *testing.T) {
@@ -89,7 +89,7 @@ func TestNoTaxes(t *testing.T) {
assert.Equal(t, uint64(100), price.Subtotal)
assert.Equal(t, uint64(0), price.Taxes)
assert.Equal(t, uint64(0), price.Discount)
assert.Equal(t, uint64(100), price.Total)
assert.Equal(t, int64(100), price.Total)
}
func TestFixedVAT(t *testing.T) {
@@ -99,7 +99,7 @@ func TestFixedVAT(t *testing.T) {
assert.Equal(t, uint64(100), price.Subtotal)
assert.Equal(t, uint64(9), price.Taxes)
assert.Equal(t, uint64(0), price.Discount)
assert.Equal(t, uint64(109), price.Total)
assert.Equal(t, int64(109), price.Total)
}
func TestFixedVATWhenPricesIncludeTaxes(t *testing.T) {
@@ -109,7 +109,7 @@ func TestFixedVATWhenPricesIncludeTaxes(t *testing.T) {
assert.Equal(t, uint64(92), price.Subtotal)
assert.Equal(t, uint64(8), price.Taxes)
assert.Equal(t, uint64(0), price.Discount)
assert.Equal(t, uint64(100), price.Total)
assert.Equal(t, int64(100), price.Total)
}
func TestCountryBasedVAT(t *testing.T) {
@@ -127,7 +127,7 @@ func TestCountryBasedVAT(t *testing.T) {
assert.Equal(t, uint64(100), price.Subtotal)
assert.Equal(t, uint64(21), price.Taxes)
assert.Equal(t, uint64(0), price.Discount)
assert.Equal(t, uint64(121), price.Total)
assert.Equal(t, int64(121), price.Total)
}
func TestCouponWithNoTaxes(t *testing.T) {
@@ -138,7 +138,7 @@ func TestCouponWithNoTaxes(t *testing.T) {
assert.Equal(t, uint64(100), price.Subtotal)
assert.Equal(t, uint64(0), price.Taxes)
assert.Equal(t, uint64(10), price.Discount)
assert.Equal(t, uint64(90), price.Total)
assert.Equal(t, int64(90), price.Total)
}
func TestCouponWithVAT(t *testing.T) {
@@ -149,7 +149,7 @@ func TestCouponWithVAT(t *testing.T) {
assert.Equal(t, uint64(100), price.Subtotal)
assert.Equal(t, uint64(9), price.Taxes)
assert.Equal(t, uint64(10), price.Discount)
assert.Equal(t, uint64(99), price.Total)
assert.Equal(t, int64(99), price.Total)
}
func TestCouponWithVATWhenPRiceIncludeTaxes(t *testing.T) {
@@ -161,7 +161,7 @@ func TestCouponWithVATWhenPRiceIncludeTaxes(t *testing.T) {
assert.Equal(t, uint64(92), price.Subtotal)
assert.Equal(t, uint64(8), price.Taxes)
assert.Equal(t, uint64(10), price.Discount)
assert.Equal(t, uint64(90), price.Total)
assert.Equal(t, int64(90), price.Total)
}
func TestCouponWithVATWhenPRiceIncludeTaxesWithQuantity(t *testing.T) {
@@ -173,7 +173,7 @@ func TestCouponWithVATWhenPRiceIncludeTaxesWithQuantity(t *testing.T) {
assert.Equal(t, uint64(184), price.Subtotal)
assert.Equal(t, uint64(16), price.Taxes)
assert.Equal(t, uint64(20), price.Discount)
assert.Equal(t, uint64(180), price.Total)
assert.Equal(t, int64(180), price.Total)
}
func TestPricingItems(t *testing.T) {
@@ -203,7 +203,7 @@ func TestPricingItems(t *testing.T) {
assert.Equal(t, uint64(100), price.Subtotal)
assert.Equal(t, uint64(10), price.Taxes)
assert.Equal(t, uint64(0), price.Discount)
assert.Equal(t, uint64(110), price.Total)
assert.Equal(t, int64(110), price.Total)
}
func TestMemberDiscounts(t *testing.T) {
@@ -217,7 +217,7 @@ func TestMemberDiscounts(t *testing.T) {
assert.Equal(t, uint64(92), price.Subtotal)
assert.Equal(t, uint64(8), price.Taxes)
assert.Equal(t, uint64(0), price.Discount)
assert.Equal(t, uint64(100), price.Total)
assert.Equal(t, int64(100), price.Total)
claims := map[string]interface{}{}
require.NoError(t, json.Unmarshal([]byte(`{"app_metadata": {"plan": "member"}}`), &claims))
@@ -228,7 +228,7 @@ func TestMemberDiscounts(t *testing.T) {
assert.Equal(t, uint64(92), price.Subtotal)
assert.Equal(t, uint64(8), price.Taxes)
assert.Equal(t, uint64(10), price.Discount)
assert.Equal(t, uint64(90), price.Total)
assert.Equal(t, int64(90), price.Total)
}
func TestFixedMemberDiscounts(t *testing.T) {
@@ -246,7 +246,7 @@ func TestFixedMemberDiscounts(t *testing.T) {
assert.Equal(t, uint64(92), price.Subtotal)
assert.Equal(t, uint64(8), price.Taxes)
assert.Equal(t, uint64(0), price.Discount)
assert.Equal(t, uint64(100), price.Total)
assert.Equal(t, int64(100), price.Total)
claims := map[string]interface{}{}
require.NoError(t, json.Unmarshal([]byte(`{"app_metadata": {"plan": "member"}}`), &claims))
@@ -257,7 +257,7 @@ func TestFixedMemberDiscounts(t *testing.T) {
assert.Equal(t, uint64(92), price.Subtotal)
assert.Equal(t, uint64(8), price.Taxes)
assert.Equal(t, uint64(10), price.Discount)
assert.Equal(t, uint64(90), price.Total)
assert.Equal(t, int64(90), price.Total)
}
func TestMixedDiscounts(t *testing.T) {
@@ -269,7 +269,7 @@ func TestMixedDiscounts(t *testing.T) {
assert.NoError(t, err)
item := &TestItem{
sku: "inclusive-design-patterns",
sku: "design-systems-ebook",
itemType: "Book",
quantity: 1,
price: 3490,
@@ -287,5 +287,5 @@ func TestMixedDiscounts(t *testing.T) {
},
}
price = CalculatePrice(&settings, claims, params, testLogger)
assert.Equal(t, 1990, int(price.Total))
assert.Equal(t, int64(0), price.Total)
}
View
@@ -0,0 +1,31 @@
package claims
import (
"encoding/json"
"io/ioutil"
"testing"
"github.com/stretchr/testify/assert"
)
func TestClaims(t *testing.T) {
b, err := ioutil.ReadFile("test/jwt_payload_fixture.json")
assert.NoError(t, err)
var claims map[string]interface{}
err = json.Unmarshal(b, &claims)
required := map[string]string{
"app_metadata.subscription.plan": "smashing",
}
matches := HasClaims(claims, required)
assert.True(t, matches)
required = map[string]string{
"app_metadata.subscription.plan": "member",
}
matches = HasClaims(claims, required)
assert.False(t, matches)
}
@@ -0,0 +1,24 @@
{
"exp": 1511383997,
"sub": "68904e78-96c0-4de1-aa37-fa71b1790756",
"email": "matt@netlify.com",
"app_metadata": {
"customer": {
"id": "c_id"
},
"provider": "email",
"roles": [
"cms",
"admin"
],
"subscription": {
"id": "sub_id",
"plan": "smashing"
}
},
"user_metadata": {
"firstname": "Matt",
"full_name": "Matt Biilmann",
"lastname": "Biilmann"
}
}
View
@@ -159,7 +159,10 @@ func (o *Order) CalculateTotal(settings *calculator.Settings, claims map[string]
o.SubTotal = price.Subtotal
o.Taxes = price.Taxes
o.Discount = price.Discount
o.Total = price.Total
if price.Total > 0 {
o.Total = uint64(price.Total)
}
}
func (o *Order) BeforeDelete(tx *gorm.DB) error {

0 comments on commit 09db216

Please sign in to comment.