Skip to content

Commit

Permalink
Switch UnitInterval to use safe-decimal
Browse files Browse the repository at this point in the history
  • Loading branch information
lehins committed Jun 10, 2021
1 parent 2a5f490 commit 23708e1
Show file tree
Hide file tree
Showing 24 changed files with 162 additions and 177 deletions.
7 changes: 3 additions & 4 deletions alonzo/impl/src/Cardano/Ledger/Alonzo/PParams.hs
Expand Up @@ -54,7 +54,6 @@ import Cardano.Ledger.BaseTypes
( Nonce (NeutralNonce),
StrictMaybe (..),
UnitInterval,
interval0,
fromSMaybe,
)
import Cardano.Ledger.Coin (Coin (..))
Expand Down Expand Up @@ -290,9 +289,9 @@ emptyPParams =
_eMax = EpochNo 0,
_nOpt = 100,
_a0 = 0,
_rho = interval0,
_tau = interval0,
_d = interval0,
_rho = minBound,
_tau = minBound,
_d = minBound,
_extraEntropy = NeutralNonce,
_protocolVersion = ProtVer 0 0,
_minPoolCost = mempty,
Expand Down
2 changes: 2 additions & 0 deletions cardano-ledger-core/cardano-ledger-core.cabal
Expand Up @@ -65,6 +65,7 @@ library
cardano-prelude,
cardano-slotting,
containers,
data-default-class,
deepseq,
groups,
iproute,
Expand All @@ -73,6 +74,7 @@ library
nothunks,
partial-order,
quiet,
safe-decimal,
scientific,
shelley-spec-non-integral,
small-steps,
Expand Down
103 changes: 63 additions & 40 deletions cardano-ledger-core/src/Cardano/Ledger/BaseTypes.hs
@@ -1,3 +1,4 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
Expand All @@ -22,17 +23,14 @@ module Cardano.Ledger.BaseTypes
Nonce (..),
Seed (..),
UnitInterval,
unitScale,
fromScientificUnitInterval,
fpPrecision,
interval0,
intervalValue,
unitIntervalToRational,
unitIntervalFromRational,
invalidKey,
mkNonceFromOutputVRF,
mkNonceFromNumber,
mkUnitInterval,
truncateUnitInterval,
StrictMaybe (..),
strictMaybeToMaybe,
maybeToStrictMaybe,
Expand Down Expand Up @@ -66,29 +64,34 @@ import Cardano.Binary
import Cardano.Crypto.Hash
import Cardano.Crypto.Util (SignableRepresentation (..))
import qualified Cardano.Crypto.VRF as VRF
import Cardano.Ledger.Serialization (decodeRecordSum, ratioFromCBOR, ratioToCBOR)
import Cardano.Ledger.Serialization (decodeRecordSum)
import Cardano.Prelude (NFData, cborError)
import Cardano.Slotting.EpochInfo
import Cardano.Slotting.Time (SystemStart)
import Control.Exception (throw)
import Control.Monad (when)
import Control.Monad.Trans.Reader (ReaderT)
import Data.Aeson (FromJSON (..), ToJSON (..))
import qualified Data.Binary.Put as B
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import Data.Coders (invalidKey)
import Data.Coerce
import Data.Default.Class (Default (def))
import qualified Data.Fixed as FP (Fixed, HasResolution, resolution)
import Data.Functor.Identity
import Data.Maybe.Strict
import Data.Ratio (Ratio, denominator, numerator, (%))
import Data.Scientific (Scientific)
import Data.Proxy
import Data.Scientific (Scientific, base10Exponent, coefficient, normalize)
import Data.Text (Text)
import qualified Data.Text as Text
import Data.Text.Encoding (encodeUtf8)
import Data.Word (Word16, Word64, Word8)
import GHC.Exception.Type (Exception)
import GHC.Generics (Generic)
import GHC.TypeLits
import NoThunks.Class (NoThunks (..))
import Numeric.Decimal
import Numeric.Natural (Natural)
import Shelley.Spec.NonIntegral (ln')

Expand All @@ -104,56 +107,76 @@ type FixedPoint = Digits34
fpPrecision :: FixedPoint
fpPrecision = (10 :: FixedPoint) ^ (34 :: Integer)

-- | Maximum precision possible for unit interval when backed by a 64bit
type UnitScale = 19

unitScale :: Int
unitScale = fromInteger (natVal (Proxy :: Proxy UnitScale))

newtype RawUnit
= RawUnit Word64
deriving newtype (Show, Eq, Ord, Num, Real, Enum, Integral, NFData)

instance Bounded RawUnit where
minBound = 0
maxBound = RawUnit (10 ^ unitScale)

-- | Type to represent a value in the unit interval [0; 1]
newtype UnitInterval = UnsafeUnitInterval (Ratio Word64)
deriving (Show, Ord, Eq, Generic)
deriving newtype (NoThunks, NFData)
newtype UnitInterval = UnitInterval
{ unitDecimal :: Decimal RoundHalfEven UnitScale RawUnit
}
deriving (Generic)
deriving newtype (Show, Ord, Eq, Bounded, NFData)

instance Default UnitInterval where
def = minBound

instance NoThunks UnitInterval where
noThunks ctx = noThunks ctx . coerce @_ @Word64
wNoThunks ctx = wNoThunks ctx . coerce @_ @Word64

instance ToCBOR UnitInterval where
toCBOR (UnsafeUnitInterval u) = ratioToCBOR u
toCBOR = toCBOR . coerce @_ @Word64

instance FromCBOR UnitInterval where
fromCBOR = do
r <- ratioFromCBOR
case mkUnitInterval r of
Nothing -> cborError $ DecoderErrorCustom "UnitInterval" (Text.pack $ show r)
Just u -> pure u
ru <- RawUnit <$> fromCBOR
if ru < minBound || ru > maxBound
then cborError $ DecoderErrorCustom "UnitInterval" (Text.pack $ show ru)
else pure $ UnitInterval $ Decimal ru

instance ToJSON UnitInterval where
toJSON ui = toJSON (fromRational (unitIntervalToRational ui) :: Scientific)
toJSON = toJSON . toScientificDecimal . unitDecimal

instance FromJSON UnitInterval where
parseJSON v = do
d <- parseJSON v
either fail pure $ fromScientificUnitInterval d

-- | safe-decimal-0.2.1.0 has a fixed version of `fromScientificDecimalBounded` function
-- that makes both of the `when` checks and `normalize` call redundant.
fromScientificUnitInterval :: Scientific -> Either String UnitInterval
fromScientificUnitInterval d =
maybe (Left "The value must be between 0 and 1 (inclusive)") Right $ mkUnitInterval (realToFrac d)
fromScientificUnitInterval (normalize -> num) = do
when (coeff < 0) $ Left "Negative values aren't allowed - protect against underflow"
when (coeff > toInteger (maxBound :: Word64) || exp10 < 0 || exp10 > unitScale) $
Left "Precision is too large - protection against overflow"
either (Left . show) (Right . UnitInterval) . fromScientificDecimalBounded $ num
where
coeff = coefficient num
exp10 = negate (base10Exponent num)

unitIntervalToRational :: UnitInterval -> Rational
unitIntervalToRational (UnsafeUnitInterval x) =
(fromIntegral $ numerator x) % (fromIntegral $ denominator x)
unitIntervalToRational = toRationalDecimal . unitDecimal

unitIntervalFromRational :: Rational -> UnitInterval
unitIntervalFromRational = truncateUnitInterval . fromRational

-- | Return a `UnitInterval` type if `r` is in [0; 1].
mkUnitInterval :: Ratio Word64 -> Maybe UnitInterval
mkUnitInterval r = if r <= 1 && r >= 0 then Just $ UnsafeUnitInterval r else Nothing

-- | Convert a rational to a `UnitInterval` by ignoring its integer part.
truncateUnitInterval :: Ratio Word64 -> UnitInterval
truncateUnitInterval (abs -> r) = case (numerator r, denominator r) of
(n, d) | n > d -> UnsafeUnitInterval $ (n `mod` d) % d
_ -> UnsafeUnitInterval r

-- | Get rational value of `UnitInterval` type
intervalValue :: UnitInterval -> Ratio Word64
intervalValue (UnsafeUnitInterval v) = v

interval0 :: UnitInterval
interval0 = UnsafeUnitInterval 0
-- | Returns `Nothing` when supplied value is not in the [0, 1] range. When rational
-- cannot be represented as decimal exactly it will be rounded.
--
-- ===__Example__
--
-- >>> unitIntervalFromRational $ 2 % 3
-- Just 0.6666666666666666667
unitIntervalFromRational :: Rational -> Maybe UnitInterval
unitIntervalFromRational r = UnitInterval <$> fromRationalDecimalBoundedWithRounding r

-- | Evolving nonce type.
data Nonce
Expand Down Expand Up @@ -300,7 +323,7 @@ mkActiveSlotCoeff v =
ActiveSlotCoeff
{ unActiveSlotVal = v,
unActiveSlotLog =
if (intervalValue v) == 1
if v == maxBound
then -- If the active slot coefficient is equal to one,
-- then nearly every stake pool can produce a block every slot.
-- In this degenerate case, where ln (1-f) is not defined,
Expand Down
2 changes: 1 addition & 1 deletion cardano-ledger-test/cardano-ledger-test.cabal
Expand Up @@ -35,7 +35,7 @@ library
import: base, project-config
hs-source-dirs: src
exposed-modules:
Test.Cardano.Ledger.BaseTypesSpec
Test.Cardano.Ledger.BaseTypes
Test.Cardano.Ledger.Examples.TwoPhaseValidation
Test.Cardano.Ledger.Generic.Proof
Test.Cardano.Ledger.Generic.Indexed
Expand Down
@@ -1,47 +1,21 @@
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Test.Cardano.Ledger.BaseTypesSpec where
module Test.Cardano.Ledger.BaseTypes where

import Cardano.Ledger.BaseTypes
import Data.Aeson
import Data.Either
import Data.GenValidity
import Data.GenValidity (genValid)
import Data.GenValidity.Scientific ()
import Data.Maybe
import Data.Ratio
import Data.Scientific
import Data.Word
import Test.QuickCheck
import Test.Shelley.Spec.Ledger.Serialisation.EraIndepGenerators ()
import Test.Tasty
import Test.Tasty.HUnit
import Test.Tasty.QuickCheck

instance Arbitrary UnitInterval where
arbitrary = do
w2 <- genValid
if w2 == 0
then arbitrary
else do
w1 <- genValid
pure $
fromMaybe (error "Impossible: Abitrary UnitInterval") $
mkUnitInterval $
if w1 > w2
then w2 % w1
else w1 % w2

partialUnitIntervalFromRational :: Rational -> UnitInterval
partialUnitIntervalFromRational r
| toInteger n /= numerator r || toInteger d /= denominator r =
error $ "Overflow detected: " ++ show r
| otherwise =
fromMaybe (error "Unexpected negative interval") $ mkUnitInterval (n % d)
where
n = fromIntegral (numerator r)
d = fromIntegral (denominator r)

baseTypesSpec :: TestTree
baseTypesSpec = do
baseTypesTests :: TestTree
baseTypesTests = do
testGroup
"UnitInterval"
[ testGroup
Expand All @@ -55,19 +29,21 @@ baseTypesSpec = do
expectLeft (fromScientificUnitInterval 1.01),
testCase "Check negative" $
expectLeft (fromScientificUnitInterval (-1e-3)),
testProperty "Rational roundtrip (mkUnitInterval . intervalValue)" $ \ui ->
Just ui === mkUnitInterval (intervalValue ui),
testProperty
"Rational roundtrip (unitIntervalFromRational . intervalValue)"
$ \ui ->
Just ui === unitIntervalFromRational (unitIntervalToRational ui),
testProperty "Scientific valid roundtrip" $ \ui ->
case fromRationalRepetendLimited 20 (unitIntervalToRational ui) of
Left (s, r) ->
ui === partialUnitIntervalFromRational (toRational s + r)
Just ui === unitIntervalFromRational (toRational s + r)
Right (s, Nothing) ->
classify
True
"no-repeat digits"
(Right ui === fromScientificUnitInterval s)
Right (s, Just r) ->
ui === partialUnitIntervalFromRational (toRationalRepetend s r),
Just ui === unitIntervalFromRational (toRationalRepetend s r),
localOption (mkTimeout 500000) $
testProperty
"Scientific roundtrip (fromRational . unitIntervalToRational . fromScientific)"
Expand Down
6 changes: 2 additions & 4 deletions cardano-ledger-test/test/Tests.hs
Expand Up @@ -8,9 +8,7 @@

module Main where

import Test.Cardano.Ledger.BaseTypesSpec
( baseTypesSpec,
)
import Test.Cardano.Ledger.BaseTypes ( baseTypesTests )
import Test.Cardano.Ledger.Examples.TwoPhaseValidation
( alonzoBBODYexamples,
alonzoUTXOWexamples,
Expand All @@ -31,7 +29,7 @@ mainTests :: TestTree
mainTests =
testGroup
"cardano-core"
[ baseTypesSpec,
[ baseTypesTests,
testGroup
"STS Tests"
[ alonzoUTXOWexamples,
Expand Down
Expand Up @@ -90,7 +90,6 @@ import Cardano.Ledger.BaseTypes
StrictMaybe (..),
activeSlotLog,
activeSlotVal,
intervalValue,
mkNonceFromNumber,
mkNonceFromOutputVRF,
strictMaybeToMaybe,
Expand Down Expand Up @@ -808,7 +807,7 @@ checkLeaderValue ::
ActiveSlotCoeff ->
Bool
checkLeaderValue certVRF σ f =
if (intervalValue $ activeSlotVal f) == 1
if activeSlotVal f == maxBound
then -- If the active slot coefficient is equal to one,
-- then nearly every stake pool can produce a block every slot.
-- In this degenerate case, where ln (1-f) is not defined,
Expand Down

0 comments on commit 23708e1

Please sign in to comment.