From 09db216158df64bb6a3e19e9bfb1dbdeb8cf7b5c Mon Sep 17 00:00:00 2001 From: David Calavera Date: Wed, 22 Nov 2017 14:04:44 -0800 Subject: [PATCH] Use int64 to calculate totals. So we can get the right value if a discount calculation overflows its value. Signed-off-by: David Calavera --- calculator/calculator.go | 18 ++++++++++------ calculator/calculator_test.go | 32 ++++++++++++++-------------- claims/claims_test.go | 31 +++++++++++++++++++++++++++ claims/test/jwt_payload_fixture.json | 24 +++++++++++++++++++++ models/order.go | 5 ++++- 5 files changed, 87 insertions(+), 23 deletions(-) create mode 100644 claims/claims_test.go create mode 100644 claims/test/jwt_payload_fixture.json diff --git a/calculator/calculator.go b/calculator/calculator.go index 95db526..a9fb309 100644 --- a/calculator/calculator.go +++ b/calculator/calculator.go @@ -15,7 +15,7 @@ type Price struct { Subtotal uint64 Discount uint64 Taxes uint64 - Total uint64 + Total int64 } // ItemPrice is the price of a single line item. @@ -25,7 +25,7 @@ type ItemPrice struct { Subtotal uint64 Discount uint64 Taxes uint64 - Total uint64 + Total int64 } // Settings represent the site-wide settings for price calculation. @@ -173,7 +173,10 @@ func CalculatePrice(settings *Settings, jwtClaims map[string]interface{}, params } for _, item := range params.Items { - lineLogger := priceLogger.WithField("product_type", item.ProductType()).WithField("product_sku", item.ProductSku()) + lineLogger := priceLogger.WithFields(logrus.Fields{ + "product_type": item.ProductType(), + "product_sku": item.ProductSku(), + }) itemPrice := ItemPrice{Quantity: item.GetQuantity()} itemPrice.Subtotal = item.PriceInLowestUnit() @@ -228,7 +231,7 @@ func CalculatePrice(settings *Settings, jwtClaims map[string]interface{}, params } } - itemPrice.Total = itemPrice.Subtotal - itemPrice.Discount + itemPrice.Taxes + itemPrice.Total = int64(itemPrice.Subtotal+itemPrice.Taxes) - int64(itemPrice.Discount) if itemPrice.Total < 0 { itemPrice.Total = 0 } @@ -246,10 +249,13 @@ func CalculatePrice(settings *Settings, jwtClaims map[string]interface{}, params price.Subtotal += (itemPrice.Subtotal * itemPrice.Quantity) price.Discount += (itemPrice.Discount * itemPrice.Quantity) price.Taxes += (itemPrice.Taxes * itemPrice.Quantity) - price.Total += (itemPrice.Total * itemPrice.Quantity) + price.Total += (itemPrice.Total * int64(itemPrice.Quantity)) } - price.Total = price.Subtotal - price.Discount + price.Taxes + price.Total = int64(price.Subtotal+price.Taxes) - int64(price.Discount) + if price.Total < 0 { + price.Total = 0 + } priceLogger.WithFields( logrus.Fields{ "total_price": price.Total, diff --git a/calculator/calculator_test.go b/calculator/calculator_test.go index cd4f5ef..2bc94e4 100644 --- a/calculator/calculator_test.go +++ b/calculator/calculator_test.go @@ -79,7 +79,7 @@ func (c *TestCoupon) FixedDiscount(currency string) uint64 { func TestNoItems(t *testing.T) { params := PriceParameters{"USA", "USD", nil, nil} price := CalculatePrice(nil, nil, params, testLogger) - assert.Equal(t, uint64(0), price.Total) + assert.Equal(t, int64(0), price.Total) } func TestNoTaxes(t *testing.T) { @@ -89,7 +89,7 @@ func TestNoTaxes(t *testing.T) { assert.Equal(t, uint64(100), price.Subtotal) assert.Equal(t, uint64(0), price.Taxes) assert.Equal(t, uint64(0), price.Discount) - assert.Equal(t, uint64(100), price.Total) + assert.Equal(t, int64(100), price.Total) } func TestFixedVAT(t *testing.T) { @@ -99,7 +99,7 @@ func TestFixedVAT(t *testing.T) { assert.Equal(t, uint64(100), price.Subtotal) assert.Equal(t, uint64(9), price.Taxes) assert.Equal(t, uint64(0), price.Discount) - assert.Equal(t, uint64(109), price.Total) + assert.Equal(t, int64(109), price.Total) } func TestFixedVATWhenPricesIncludeTaxes(t *testing.T) { @@ -109,7 +109,7 @@ func TestFixedVATWhenPricesIncludeTaxes(t *testing.T) { assert.Equal(t, uint64(92), price.Subtotal) assert.Equal(t, uint64(8), price.Taxes) assert.Equal(t, uint64(0), price.Discount) - assert.Equal(t, uint64(100), price.Total) + assert.Equal(t, int64(100), price.Total) } func TestCountryBasedVAT(t *testing.T) { @@ -127,7 +127,7 @@ func TestCountryBasedVAT(t *testing.T) { assert.Equal(t, uint64(100), price.Subtotal) assert.Equal(t, uint64(21), price.Taxes) assert.Equal(t, uint64(0), price.Discount) - assert.Equal(t, uint64(121), price.Total) + assert.Equal(t, int64(121), price.Total) } func TestCouponWithNoTaxes(t *testing.T) { @@ -138,7 +138,7 @@ func TestCouponWithNoTaxes(t *testing.T) { assert.Equal(t, uint64(100), price.Subtotal) assert.Equal(t, uint64(0), price.Taxes) assert.Equal(t, uint64(10), price.Discount) - assert.Equal(t, uint64(90), price.Total) + assert.Equal(t, int64(90), price.Total) } func TestCouponWithVAT(t *testing.T) { @@ -149,7 +149,7 @@ func TestCouponWithVAT(t *testing.T) { assert.Equal(t, uint64(100), price.Subtotal) assert.Equal(t, uint64(9), price.Taxes) assert.Equal(t, uint64(10), price.Discount) - assert.Equal(t, uint64(99), price.Total) + assert.Equal(t, int64(99), price.Total) } func TestCouponWithVATWhenPRiceIncludeTaxes(t *testing.T) { @@ -161,7 +161,7 @@ func TestCouponWithVATWhenPRiceIncludeTaxes(t *testing.T) { assert.Equal(t, uint64(92), price.Subtotal) assert.Equal(t, uint64(8), price.Taxes) assert.Equal(t, uint64(10), price.Discount) - assert.Equal(t, uint64(90), price.Total) + assert.Equal(t, int64(90), price.Total) } func TestCouponWithVATWhenPRiceIncludeTaxesWithQuantity(t *testing.T) { @@ -173,7 +173,7 @@ func TestCouponWithVATWhenPRiceIncludeTaxesWithQuantity(t *testing.T) { assert.Equal(t, uint64(184), price.Subtotal) assert.Equal(t, uint64(16), price.Taxes) assert.Equal(t, uint64(20), price.Discount) - assert.Equal(t, uint64(180), price.Total) + assert.Equal(t, int64(180), price.Total) } func TestPricingItems(t *testing.T) { @@ -203,7 +203,7 @@ func TestPricingItems(t *testing.T) { assert.Equal(t, uint64(100), price.Subtotal) assert.Equal(t, uint64(10), price.Taxes) assert.Equal(t, uint64(0), price.Discount) - assert.Equal(t, uint64(110), price.Total) + assert.Equal(t, int64(110), price.Total) } func TestMemberDiscounts(t *testing.T) { @@ -217,7 +217,7 @@ func TestMemberDiscounts(t *testing.T) { assert.Equal(t, uint64(92), price.Subtotal) assert.Equal(t, uint64(8), price.Taxes) assert.Equal(t, uint64(0), price.Discount) - assert.Equal(t, uint64(100), price.Total) + assert.Equal(t, int64(100), price.Total) claims := map[string]interface{}{} require.NoError(t, json.Unmarshal([]byte(`{"app_metadata": {"plan": "member"}}`), &claims)) @@ -228,7 +228,7 @@ func TestMemberDiscounts(t *testing.T) { assert.Equal(t, uint64(92), price.Subtotal) assert.Equal(t, uint64(8), price.Taxes) assert.Equal(t, uint64(10), price.Discount) - assert.Equal(t, uint64(90), price.Total) + assert.Equal(t, int64(90), price.Total) } func TestFixedMemberDiscounts(t *testing.T) { @@ -246,7 +246,7 @@ func TestFixedMemberDiscounts(t *testing.T) { assert.Equal(t, uint64(92), price.Subtotal) assert.Equal(t, uint64(8), price.Taxes) assert.Equal(t, uint64(0), price.Discount) - assert.Equal(t, uint64(100), price.Total) + assert.Equal(t, int64(100), price.Total) claims := map[string]interface{}{} require.NoError(t, json.Unmarshal([]byte(`{"app_metadata": {"plan": "member"}}`), &claims)) @@ -257,7 +257,7 @@ func TestFixedMemberDiscounts(t *testing.T) { assert.Equal(t, uint64(92), price.Subtotal) assert.Equal(t, uint64(8), price.Taxes) assert.Equal(t, uint64(10), price.Discount) - assert.Equal(t, uint64(90), price.Total) + assert.Equal(t, int64(90), price.Total) } func TestMixedDiscounts(t *testing.T) { @@ -269,7 +269,7 @@ func TestMixedDiscounts(t *testing.T) { assert.NoError(t, err) item := &TestItem{ - sku: "inclusive-design-patterns", + sku: "design-systems-ebook", itemType: "Book", quantity: 1, price: 3490, @@ -287,5 +287,5 @@ func TestMixedDiscounts(t *testing.T) { }, } price = CalculatePrice(&settings, claims, params, testLogger) - assert.Equal(t, 1990, int(price.Total)) + assert.Equal(t, int64(0), price.Total) } diff --git a/claims/claims_test.go b/claims/claims_test.go new file mode 100644 index 0000000..d65e45f --- /dev/null +++ b/claims/claims_test.go @@ -0,0 +1,31 @@ +package claims + +import ( + "encoding/json" + "io/ioutil" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestClaims(t *testing.T) { + b, err := ioutil.ReadFile("test/jwt_payload_fixture.json") + assert.NoError(t, err) + + var claims map[string]interface{} + err = json.Unmarshal(b, &claims) + + required := map[string]string{ + "app_metadata.subscription.plan": "smashing", + } + + matches := HasClaims(claims, required) + assert.True(t, matches) + + required = map[string]string{ + "app_metadata.subscription.plan": "member", + } + + matches = HasClaims(claims, required) + assert.False(t, matches) +} diff --git a/claims/test/jwt_payload_fixture.json b/claims/test/jwt_payload_fixture.json new file mode 100644 index 0000000..dbdc9e4 --- /dev/null +++ b/claims/test/jwt_payload_fixture.json @@ -0,0 +1,24 @@ +{ + "exp": 1511383997, + "sub": "68904e78-96c0-4de1-aa37-fa71b1790756", + "email": "matt@netlify.com", + "app_metadata": { + "customer": { + "id": "c_id" + }, + "provider": "email", + "roles": [ + "cms", + "admin" + ], + "subscription": { + "id": "sub_id", + "plan": "smashing" + } + }, + "user_metadata": { + "firstname": "Matt", + "full_name": "Matt Biilmann", + "lastname": "Biilmann" + } +} diff --git a/models/order.go b/models/order.go index e4c7919..238f08d 100644 --- a/models/order.go +++ b/models/order.go @@ -159,7 +159,10 @@ func (o *Order) CalculateTotal(settings *calculator.Settings, claims map[string] o.SubTotal = price.Subtotal o.Taxes = price.Taxes o.Discount = price.Discount - o.Total = price.Total + + if price.Total > 0 { + o.Total = uint64(price.Total) + } } func (o *Order) BeforeDelete(tx *gorm.DB) error {