Skip to content

Commit

Permalink
Use int64 to calculate totals.
Browse files Browse the repository at this point in the history
So we can get the right value if a discount calculation overflows its value.

Signed-off-by: David Calavera <david.calavera@gmail.com>
  • Loading branch information
calavera committed Nov 22, 2017
1 parent 6f9eb5c commit 09db216
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 23 deletions.
18 changes: 12 additions & 6 deletions calculator/calculator.go
Expand Up @@ -15,7 +15,7 @@ type Price struct {
Subtotal uint64 Subtotal uint64
Discount uint64 Discount uint64
Taxes uint64 Taxes uint64
Total uint64 Total int64
} }


// ItemPrice is the price of a single line item. // ItemPrice is the price of a single line item.
Expand All @@ -25,7 +25,7 @@ type ItemPrice struct {
Subtotal uint64 Subtotal uint64
Discount uint64 Discount uint64
Taxes uint64 Taxes uint64
Total uint64 Total int64
} }


// Settings represent the site-wide settings for price calculation. // Settings represent the site-wide settings for price calculation.
Expand Down Expand Up @@ -173,7 +173,10 @@ func CalculatePrice(settings *Settings, jwtClaims map[string]interface{}, params
} }


for _, item := range params.Items { for _, item := range params.Items {
lineLogger := priceLogger.WithField("product_type", item.ProductType()).WithField("product_sku", item.ProductSku()) lineLogger := priceLogger.WithFields(logrus.Fields{
"product_type": item.ProductType(),
"product_sku": item.ProductSku(),
})


itemPrice := ItemPrice{Quantity: item.GetQuantity()} itemPrice := ItemPrice{Quantity: item.GetQuantity()}
itemPrice.Subtotal = item.PriceInLowestUnit() itemPrice.Subtotal = item.PriceInLowestUnit()
Expand Down Expand Up @@ -228,7 +231,7 @@ func CalculatePrice(settings *Settings, jwtClaims map[string]interface{}, params
} }
} }


itemPrice.Total = itemPrice.Subtotal - itemPrice.Discount + itemPrice.Taxes itemPrice.Total = int64(itemPrice.Subtotal+itemPrice.Taxes) - int64(itemPrice.Discount)
if itemPrice.Total < 0 { if itemPrice.Total < 0 {
itemPrice.Total = 0 itemPrice.Total = 0
} }
Expand All @@ -246,10 +249,13 @@ func CalculatePrice(settings *Settings, jwtClaims map[string]interface{}, params
price.Subtotal += (itemPrice.Subtotal * itemPrice.Quantity) price.Subtotal += (itemPrice.Subtotal * itemPrice.Quantity)
price.Discount += (itemPrice.Discount * itemPrice.Quantity) price.Discount += (itemPrice.Discount * itemPrice.Quantity)
price.Taxes += (itemPrice.Taxes * itemPrice.Quantity) price.Taxes += (itemPrice.Taxes * itemPrice.Quantity)
price.Total += (itemPrice.Total * itemPrice.Quantity) price.Total += (itemPrice.Total * int64(itemPrice.Quantity))
} }


price.Total = price.Subtotal - price.Discount + price.Taxes price.Total = int64(price.Subtotal+price.Taxes) - int64(price.Discount)
if price.Total < 0 {
price.Total = 0
}
priceLogger.WithFields( priceLogger.WithFields(
logrus.Fields{ logrus.Fields{
"total_price": price.Total, "total_price": price.Total,
Expand Down
32 changes: 16 additions & 16 deletions calculator/calculator_test.go
Expand Up @@ -79,7 +79,7 @@ func (c *TestCoupon) FixedDiscount(currency string) uint64 {
func TestNoItems(t *testing.T) { func TestNoItems(t *testing.T) {
params := PriceParameters{"USA", "USD", nil, nil} params := PriceParameters{"USA", "USD", nil, nil}
price := CalculatePrice(nil, nil, params, testLogger) price := CalculatePrice(nil, nil, params, testLogger)
assert.Equal(t, uint64(0), price.Total) assert.Equal(t, int64(0), price.Total)
} }


func TestNoTaxes(t *testing.T) { func TestNoTaxes(t *testing.T) {
Expand All @@ -89,7 +89,7 @@ func TestNoTaxes(t *testing.T) {
assert.Equal(t, uint64(100), price.Subtotal) assert.Equal(t, uint64(100), price.Subtotal)
assert.Equal(t, uint64(0), price.Taxes) assert.Equal(t, uint64(0), price.Taxes)
assert.Equal(t, uint64(0), price.Discount) assert.Equal(t, uint64(0), price.Discount)
assert.Equal(t, uint64(100), price.Total) assert.Equal(t, int64(100), price.Total)
} }


func TestFixedVAT(t *testing.T) { func TestFixedVAT(t *testing.T) {
Expand All @@ -99,7 +99,7 @@ func TestFixedVAT(t *testing.T) {
assert.Equal(t, uint64(100), price.Subtotal) assert.Equal(t, uint64(100), price.Subtotal)
assert.Equal(t, uint64(9), price.Taxes) assert.Equal(t, uint64(9), price.Taxes)
assert.Equal(t, uint64(0), price.Discount) assert.Equal(t, uint64(0), price.Discount)
assert.Equal(t, uint64(109), price.Total) assert.Equal(t, int64(109), price.Total)
} }


func TestFixedVATWhenPricesIncludeTaxes(t *testing.T) { func TestFixedVATWhenPricesIncludeTaxes(t *testing.T) {
Expand All @@ -109,7 +109,7 @@ func TestFixedVATWhenPricesIncludeTaxes(t *testing.T) {
assert.Equal(t, uint64(92), price.Subtotal) assert.Equal(t, uint64(92), price.Subtotal)
assert.Equal(t, uint64(8), price.Taxes) assert.Equal(t, uint64(8), price.Taxes)
assert.Equal(t, uint64(0), price.Discount) assert.Equal(t, uint64(0), price.Discount)
assert.Equal(t, uint64(100), price.Total) assert.Equal(t, int64(100), price.Total)
} }


func TestCountryBasedVAT(t *testing.T) { func TestCountryBasedVAT(t *testing.T) {
Expand All @@ -127,7 +127,7 @@ func TestCountryBasedVAT(t *testing.T) {
assert.Equal(t, uint64(100), price.Subtotal) assert.Equal(t, uint64(100), price.Subtotal)
assert.Equal(t, uint64(21), price.Taxes) assert.Equal(t, uint64(21), price.Taxes)
assert.Equal(t, uint64(0), price.Discount) assert.Equal(t, uint64(0), price.Discount)
assert.Equal(t, uint64(121), price.Total) assert.Equal(t, int64(121), price.Total)
} }


func TestCouponWithNoTaxes(t *testing.T) { func TestCouponWithNoTaxes(t *testing.T) {
Expand All @@ -138,7 +138,7 @@ func TestCouponWithNoTaxes(t *testing.T) {
assert.Equal(t, uint64(100), price.Subtotal) assert.Equal(t, uint64(100), price.Subtotal)
assert.Equal(t, uint64(0), price.Taxes) assert.Equal(t, uint64(0), price.Taxes)
assert.Equal(t, uint64(10), price.Discount) assert.Equal(t, uint64(10), price.Discount)
assert.Equal(t, uint64(90), price.Total) assert.Equal(t, int64(90), price.Total)
} }


func TestCouponWithVAT(t *testing.T) { func TestCouponWithVAT(t *testing.T) {
Expand All @@ -149,7 +149,7 @@ func TestCouponWithVAT(t *testing.T) {
assert.Equal(t, uint64(100), price.Subtotal) assert.Equal(t, uint64(100), price.Subtotal)
assert.Equal(t, uint64(9), price.Taxes) assert.Equal(t, uint64(9), price.Taxes)
assert.Equal(t, uint64(10), price.Discount) assert.Equal(t, uint64(10), price.Discount)
assert.Equal(t, uint64(99), price.Total) assert.Equal(t, int64(99), price.Total)
} }


func TestCouponWithVATWhenPRiceIncludeTaxes(t *testing.T) { func TestCouponWithVATWhenPRiceIncludeTaxes(t *testing.T) {
Expand All @@ -161,7 +161,7 @@ func TestCouponWithVATWhenPRiceIncludeTaxes(t *testing.T) {
assert.Equal(t, uint64(92), price.Subtotal) assert.Equal(t, uint64(92), price.Subtotal)
assert.Equal(t, uint64(8), price.Taxes) assert.Equal(t, uint64(8), price.Taxes)
assert.Equal(t, uint64(10), price.Discount) assert.Equal(t, uint64(10), price.Discount)
assert.Equal(t, uint64(90), price.Total) assert.Equal(t, int64(90), price.Total)
} }


func TestCouponWithVATWhenPRiceIncludeTaxesWithQuantity(t *testing.T) { func TestCouponWithVATWhenPRiceIncludeTaxesWithQuantity(t *testing.T) {
Expand All @@ -173,7 +173,7 @@ func TestCouponWithVATWhenPRiceIncludeTaxesWithQuantity(t *testing.T) {
assert.Equal(t, uint64(184), price.Subtotal) assert.Equal(t, uint64(184), price.Subtotal)
assert.Equal(t, uint64(16), price.Taxes) assert.Equal(t, uint64(16), price.Taxes)
assert.Equal(t, uint64(20), price.Discount) assert.Equal(t, uint64(20), price.Discount)
assert.Equal(t, uint64(180), price.Total) assert.Equal(t, int64(180), price.Total)
} }


func TestPricingItems(t *testing.T) { func TestPricingItems(t *testing.T) {
Expand Down Expand Up @@ -203,7 +203,7 @@ func TestPricingItems(t *testing.T) {
assert.Equal(t, uint64(100), price.Subtotal) assert.Equal(t, uint64(100), price.Subtotal)
assert.Equal(t, uint64(10), price.Taxes) assert.Equal(t, uint64(10), price.Taxes)
assert.Equal(t, uint64(0), price.Discount) assert.Equal(t, uint64(0), price.Discount)
assert.Equal(t, uint64(110), price.Total) assert.Equal(t, int64(110), price.Total)
} }


func TestMemberDiscounts(t *testing.T) { func TestMemberDiscounts(t *testing.T) {
Expand All @@ -217,7 +217,7 @@ func TestMemberDiscounts(t *testing.T) {
assert.Equal(t, uint64(92), price.Subtotal) assert.Equal(t, uint64(92), price.Subtotal)
assert.Equal(t, uint64(8), price.Taxes) assert.Equal(t, uint64(8), price.Taxes)
assert.Equal(t, uint64(0), price.Discount) assert.Equal(t, uint64(0), price.Discount)
assert.Equal(t, uint64(100), price.Total) assert.Equal(t, int64(100), price.Total)


claims := map[string]interface{}{} claims := map[string]interface{}{}
require.NoError(t, json.Unmarshal([]byte(`{"app_metadata": {"plan": "member"}}`), &claims)) require.NoError(t, json.Unmarshal([]byte(`{"app_metadata": {"plan": "member"}}`), &claims))
Expand All @@ -228,7 +228,7 @@ func TestMemberDiscounts(t *testing.T) {
assert.Equal(t, uint64(92), price.Subtotal) assert.Equal(t, uint64(92), price.Subtotal)
assert.Equal(t, uint64(8), price.Taxes) assert.Equal(t, uint64(8), price.Taxes)
assert.Equal(t, uint64(10), price.Discount) assert.Equal(t, uint64(10), price.Discount)
assert.Equal(t, uint64(90), price.Total) assert.Equal(t, int64(90), price.Total)
} }


func TestFixedMemberDiscounts(t *testing.T) { func TestFixedMemberDiscounts(t *testing.T) {
Expand All @@ -246,7 +246,7 @@ func TestFixedMemberDiscounts(t *testing.T) {
assert.Equal(t, uint64(92), price.Subtotal) assert.Equal(t, uint64(92), price.Subtotal)
assert.Equal(t, uint64(8), price.Taxes) assert.Equal(t, uint64(8), price.Taxes)
assert.Equal(t, uint64(0), price.Discount) assert.Equal(t, uint64(0), price.Discount)
assert.Equal(t, uint64(100), price.Total) assert.Equal(t, int64(100), price.Total)


claims := map[string]interface{}{} claims := map[string]interface{}{}
require.NoError(t, json.Unmarshal([]byte(`{"app_metadata": {"plan": "member"}}`), &claims)) require.NoError(t, json.Unmarshal([]byte(`{"app_metadata": {"plan": "member"}}`), &claims))
Expand All @@ -257,7 +257,7 @@ func TestFixedMemberDiscounts(t *testing.T) {
assert.Equal(t, uint64(92), price.Subtotal) assert.Equal(t, uint64(92), price.Subtotal)
assert.Equal(t, uint64(8), price.Taxes) assert.Equal(t, uint64(8), price.Taxes)
assert.Equal(t, uint64(10), price.Discount) assert.Equal(t, uint64(10), price.Discount)
assert.Equal(t, uint64(90), price.Total) assert.Equal(t, int64(90), price.Total)
} }


func TestMixedDiscounts(t *testing.T) { func TestMixedDiscounts(t *testing.T) {
Expand All @@ -269,7 +269,7 @@ func TestMixedDiscounts(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)


item := &TestItem{ item := &TestItem{
sku: "inclusive-design-patterns", sku: "design-systems-ebook",
itemType: "Book", itemType: "Book",
quantity: 1, quantity: 1,
price: 3490, price: 3490,
Expand All @@ -287,5 +287,5 @@ func TestMixedDiscounts(t *testing.T) {
}, },
} }
price = CalculatePrice(&settings, claims, params, testLogger) price = CalculatePrice(&settings, claims, params, testLogger)
assert.Equal(t, 1990, int(price.Total)) assert.Equal(t, int64(0), price.Total)
} }
31 changes: 31 additions & 0 deletions claims/claims_test.go
@@ -0,0 +1,31 @@
package claims

import (
"encoding/json"
"io/ioutil"
"testing"

"github.com/stretchr/testify/assert"
)

func TestClaims(t *testing.T) {
b, err := ioutil.ReadFile("test/jwt_payload_fixture.json")
assert.NoError(t, err)

var claims map[string]interface{}
err = json.Unmarshal(b, &claims)

required := map[string]string{
"app_metadata.subscription.plan": "smashing",
}

matches := HasClaims(claims, required)
assert.True(t, matches)

required = map[string]string{
"app_metadata.subscription.plan": "member",
}

matches = HasClaims(claims, required)
assert.False(t, matches)
}
24 changes: 24 additions & 0 deletions claims/test/jwt_payload_fixture.json
@@ -0,0 +1,24 @@
{
"exp": 1511383997,
"sub": "68904e78-96c0-4de1-aa37-fa71b1790756",
"email": "matt@netlify.com",
"app_metadata": {
"customer": {
"id": "c_id"
},
"provider": "email",
"roles": [
"cms",
"admin"
],
"subscription": {
"id": "sub_id",
"plan": "smashing"
}
},
"user_metadata": {
"firstname": "Matt",
"full_name": "Matt Biilmann",
"lastname": "Biilmann"
}
}
5 changes: 4 additions & 1 deletion models/order.go
Expand Up @@ -159,7 +159,10 @@ func (o *Order) CalculateTotal(settings *calculator.Settings, claims map[string]
o.SubTotal = price.Subtotal o.SubTotal = price.Subtotal
o.Taxes = price.Taxes o.Taxes = price.Taxes
o.Discount = price.Discount o.Discount = price.Discount
o.Total = price.Total
if price.Total > 0 {
o.Total = uint64(price.Total)
}
} }


func (o *Order) BeforeDelete(tx *gorm.DB) error { func (o *Order) BeforeDelete(tx *gorm.DB) error {
Expand Down

0 comments on commit 09db216

Please sign in to comment.