From 03fae8464d761562eec68c5db9fd0addb54cda6d Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Fri, 17 Feb 2023 21:24:55 -0800 Subject: [PATCH 1/6] :sparkles: SafeDivision and SafeLinearArith --- src/Grisette/Core.hs | 21 +-- src/Grisette/Core/Data/Class/Error.hs | 12 ++ src/Grisette/Core/Data/Class/Integer.hs | 110 +++++++++--- src/Grisette/IR/SymPrim/Data/SymPrim.hs | 49 ++++- test/Grisette/IR/SymPrim/Data/SymPrimTests.hs | 167 ++++++++++++++++-- 5 files changed, 302 insertions(+), 57 deletions(-) diff --git a/src/Grisette/Core.hs b/src/Grisette/Core.hs index 004d5888..cf41a253 100644 --- a/src/Grisette/Core.hs +++ b/src/Grisette/Core.hs @@ -190,9 +190,8 @@ module Grisette.Core someBVExtract', SizedBV (..), sizedBVExtract, - SignedDivMod (..), - UnsignedDivMod (..), - SignedQuotRem (..), + SafeDivision (..), + SafeLinearArith (..), SymIntegerOp, Function (..), @@ -794,6 +793,7 @@ module Grisette.Core TransformError (..), symAssert, symAssume, + symAssertWith, symAssertTransformableError, symThrowTransformableError, @@ -902,15 +902,6 @@ module Grisette.Core -- deriving (Mergeable, SEq) via (Default Error) -- :} -- - -- Then we define how to transform the generic errors to the error type. - -- - -- >>> :{ - -- instance TransformError ArithException Error where - -- transformError _ = Arith - -- instance TransformError AssertionError Error where - -- transformError _ = Assert - -- :} - -- -- Then we can perform the symbolic evaluation. The `divs` function throws -- 'ArithException' when the divisor is 0, which would be transformed to -- @Arith@, and the `symAssert` function would throw 'AssertionError' when @@ -918,14 +909,16 @@ module Grisette.Core -- -- >>> let x = "x" :: SymInteger -- >>> let y = "y" :: SymInteger + -- >>> assert = symAssertWith Assert + -- >>> sdiv = safeDiv Arith -- >>> :{ -- -- equivalent concrete program: -- -- let x = x `div` y -- -- if z > 0 then assert (x >= y) else return () -- res :: ExceptT Error UnionM () -- res = do - -- z <- x `divs` y - -- mrgIf (z >~ 0) (symAssert (x >=~ y)) (return ()) + -- z <- x `sdiv` y + -- mrgIf (z >~ 0) (assert (x >=~ y)) (return ()) -- :} -- -- Then we can ask the solver to find a counter-example that would lead to diff --git a/src/Grisette/Core/Data/Class/Error.hs b/src/Grisette/Core/Data/Class/Error.hs index fabc27d2..ee5fe83c 100644 --- a/src/Grisette/Core/Data/Class/Error.hs +++ b/src/Grisette/Core/Data/Class/Error.hs @@ -15,6 +15,7 @@ module Grisette.Core.Data.Class.Error TransformError (..), -- * Throwing error + symAssertWith, symAssertTransformableError, symThrowTransformableError, ) @@ -104,3 +105,14 @@ symAssertTransformableError :: erm () symAssertTransformableError err cond = mrgIf cond (return ()) (symThrowTransformableError err) {-# INLINE symAssertTransformableError #-} + +symAssertWith :: + ( Mergeable e, + MonadError e erm, + MonadUnion erm + ) => + e -> + SymBool -> + erm () +symAssertWith err cond = mrgIf cond (return ()) (throwError err) +{-# INLINE symAssertWith #-} diff --git a/src/Grisette/Core/Data/Class/Integer.hs b/src/Grisette/Core/Data/Class/Integer.hs index 42c5f89e..93705eec 100644 --- a/src/Grisette/Core/Data/Class/Integer.hs +++ b/src/Grisette/Core/Data/Class/Integer.hs @@ -13,9 +13,8 @@ module Grisette.Core.Data.Class.Integer ( -- * Symbolic integer operations ArithException (..), - SignedDivMod (..), - UnsignedDivMod (..), - SignedQuotRem (..), + SafeDivision (..), + SafeLinearArith (..), SymIntegerOp, ) where @@ -25,48 +24,111 @@ import Control.Monad.Except import Grisette.Core.Control.Monad.Union import Grisette.Core.Data.Class.Bool import Grisette.Core.Data.Class.Error +import Grisette.Core.Data.Class.Mergeable import Grisette.Core.Data.Class.SOrd +import Grisette.Core.Data.Class.SimpleMergeable import Grisette.Core.Data.Class.Solvable -- $setup -- >>> import Grisette.Core -- >>> import Grisette.IR.SymPrim --- | Safe signed 'div' and 'mod' with monadic error handling in multi-path --- execution. These procedures show throw 'DivideByZero' exception when the +-- | Safe division with monadic error handling in multi-path +-- execution. These procedures throw an exception when the -- divisor is zero. The result should be able to handle errors with --- `MonadError`, and the error type should be compatible with 'ArithException' --- (see 'TransformError' for more details). -class SignedDivMod a where +-- `MonadError`. +class (SOrd a, Num a, Mergeable a) => SafeDivision a where -- | Safe signed 'div' with monadic error handling in multi-path execution. -- - -- >>> divs (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM SymInteger + -- >>> safeDiv AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM SymInteger -- ExceptT {If (= b 0) (Left AssertionError) (Right (div a b))} - divs :: (MonadError e uf, MonadUnion uf, TransformError ArithException e) => a -> a -> uf a + safeDiv :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a + safeDiv e l r = do + (d, _) <- safeDivMod e l r + mrgSingle d -- | Safe signed 'mod' with monadic error handling in multi-path execution. -- - -- >>> mods (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM SymInteger + -- >>> safeMod AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM SymInteger -- ExceptT {If (= b 0) (Left AssertionError) (Right (mod a b))} - mods :: (MonadError e uf, MonadUnion uf, TransformError ArithException e) => a -> a -> uf a + safeMod :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a + safeMod e l r = do + (_, m) <- safeDivMod e l r + mrgSingle m --- | Safe unsigned 'div' and 'mod' with monadic error handling in multi-path --- execution. These procedures show throw 'DivideByZero' exception when the --- divisor is zero. The result should be able to handle errors with --- `MonadError`, and the error type should be compatible with 'ArithException' --- (see 'TransformError' for more details). -class UnsignedDivMod a where - udivs :: (MonadError e uf, MonadUnion uf, TransformError ArithException e) => a -> a -> uf a - umods :: (MonadError e uf, MonadUnion uf, TransformError ArithException e) => a -> a -> uf a + -- | Safe signed 'div' with monadic error handling in multi-path execution. + -- + -- >>> safeDivMod AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM (SymInteger, SymInteger) + -- ExceptT {If (= b 0) (Left AssertionError) (Right ((div a b),(mod a b)))} + safeDivMod :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf (a, a) + safeDivMod e l r = do + d <- safeDiv e l r + m <- safeMod e l r + mrgSingle (d, m) + + -- | Safe signed 'quot' with monadic error handling in multi-path execution. + safeQuot :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a + safeQuot e l r = do + (d, m) <- safeDivMod e l r + mrgIf + ((l >=~ 0 &&~ r >~ 0) ||~ (l <=~ 0 &&~ r <~ 0) ||~ m ==~ 0) + (mrgSingle d) + (mrgSingle $ d + 1) + + -- | Safe signed 'rem' with monadic error handling in multi-path execution. + safeRem :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a + safeRem e l r = do + (d, m) <- safeDivMod e l r + mrgIf + ((l >=~ 0 &&~ r >~ 0) ||~ (l <=~ 0 &&~ r <~ 0) ||~ m ==~ 0) + (mrgSingle m) + (mrgSingle $ m - r) + + -- | Safe signed 'quotRem' with monadic error handling in multi-path execution. + safeQuotRem :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf (a, a) + safeQuotRem e l r = do + (d, m) <- safeDivMod e l r + mrgIf + ((l >=~ 0 &&~ r >~ 0) ||~ (l <=~ 0 &&~ r <~ 0) ||~ m ==~ 0) + (mrgSingle (d, m)) + (mrgSingle $ (d + 1, m - r)) + + {-# MINIMAL (safeDivMod | (safeDiv, safeMod)) #-} + +class SafeLinearArith a where + -- | Safe signed '+' with monadic error handling in multi-path execution. + -- Overflows are treated as errors. + -- + -- >>> safeAdd AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM SymInteger + -- ExceptT {Right (+ a b)} + -- >>> safeAdd AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM (SymIntN 4) + -- ExceptT {If (|| (&& (< 0x0 a) (&& (< 0x0 b) (< (+ a b) 0x0))) (&& (< a 0x0) (&& (< b 0x0) (<= 0x0 (+ a b))))) (Left AssertionError) (Right (+ a b))} + safeAdd :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a + + -- | Safe signed 'negate' with monadic error handling in multi-path execution. + -- Overflows are treated as errors. + -- + -- >>> safeNeg AssertionError (ssym "a") :: ExceptT AssertionError UnionM SymInteger + -- ExceptT {Right (- a)} + -- >>> safeNeg AssertionError (ssym "a") :: ExceptT AssertionError UnionM (SymIntN 4) + -- ExceptT {If (= a 0x8) (Left AssertionError) (Right (- a))} + safeNeg :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> uf a + + -- | Safe signed '-' with monadic error handling in multi-path execution. + -- Overflows are treated as errors. + -- + -- >>> safeMinus AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM SymInteger + -- ExceptT {Right (+ a (- b))} + -- >>> safeMinus AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM (SymIntN 4) + -- ExceptT {If (|| (&& (<= 0x0 a) (&& (< b 0x0) (< (+ a (- b)) 0x0))) (&& (< a 0x0) (&& (< 0x0 b) (< 0x0 (+ a (- b)))))) (Left AssertionError) (Right (+ a (- b)))} + safeMinus :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a -- | Safe signed 'quot' and 'rem' with monadic error handling in multi-path -- execution. These procedures show throw 'DivideByZero' exception when the -- divisor is zero. The result should be able to handle errors with -- `MonadError`, and the error type should be compatible with 'ArithException' -- (see 'TransformError' for more details). -class SignedQuotRem a where - quots :: (MonadError e uf, MonadUnion uf, TransformError ArithException e) => a -> a -> uf a - rems :: (MonadError e uf, MonadUnion uf, TransformError ArithException e) => a -> a -> uf a +class SafeQuotRem a -- | Aggregation for the operations on symbolic integer types -class (Num a, SEq a, SOrd a, Solvable Integer a) => SymIntegerOp a +class (Num a, SEq a, SOrd a, Solvable Integer a, SafeDivision a, SafeLinearArith a) => SymIntegerOp a diff --git a/src/Grisette/IR/SymPrim/Data/SymPrim.hs b/src/Grisette/IR/SymPrim/Data/SymPrim.hs index f0441df7..daa47f5f 100644 --- a/src/Grisette/IR/SymPrim/Data/SymPrim.hs +++ b/src/Grisette/IR/SymPrim/Data/SymPrim.hs @@ -140,18 +140,23 @@ newtype SymBool = SymBool {underlyingBoolTerm :: Term Bool} newtype SymInteger = SymInteger {underlyingIntegerTerm :: Term Integer} deriving (Lift, NFData, Generic) -instance SignedDivMod SymInteger where - divs (SymInteger l) rs@(SymInteger r) = +instance SafeDivision SymInteger where + safeDiv e (SymInteger l) rs@(SymInteger r) = mrgIf (rs ==~ con 0) - (throwError $ transformError DivideByZero) + (throwError e) (mrgReturn $ SymInteger $ pevalDivIntegerTerm l r) - mods (SymInteger l) rs@(SymInteger r) = + safeMod e (SymInteger l) rs@(SymInteger r) = mrgIf (rs ==~ con 0) - (throwError $ transformError DivideByZero) + (throwError e) (mrgReturn $ SymInteger $ pevalModIntegerTerm l r) +instance SafeLinearArith SymInteger where + safeAdd e ls rs = mrgReturn $ ls + rs + safeNeg e v = mrgReturn $ -v + safeMinus e ls rs = mrgReturn $ ls - rs + instance SymIntegerOp SymInteger -- | Symbolic signed bit vector type. Indexed with the bit width. @@ -173,6 +178,23 @@ instance SymIntegerOp SymInteger newtype SymIntN (n :: Nat) = SymIntN {underlyingIntNTerm :: Term (IntN n)} deriving (Lift, NFData, Generic) +instance (KnownNat n, 1 <= n) => SafeLinearArith (SymIntN n) where + safeAdd e ls rs = + mrgIf + ((ls >~ 0 &&~ rs >~ 0 &&~ res <~ 0) ||~ (ls <~ 0 &&~ rs <~ 0 &&~ res >=~ 0)) + (throwError e) + (mrgReturn res) + where + res = ls + rs + safeNeg e v = mrgIf (v ==~ con minBound) (throwError e) (mrgReturn $ -v) + safeMinus e ls rs = + mrgIf + ((ls >=~ 0 &&~ rs <~ 0 &&~ res <~ 0) ||~ (ls <~ 0 &&~ rs >~ 0 &&~ res >~ 0)) + (throwError e) + (mrgReturn res) + where + res = ls - rs + -- | Symbolic signed bit vector type. Not indexed, but the bit width is -- fixed at the creation time. -- @@ -243,6 +265,23 @@ binSomeSymIntNR2 op str (SomeSymIntN (l :: SymIntN l)) (SomeSymIntN (r :: SymInt newtype SymWordN (n :: Nat) = SymWordN {underlyingWordNTerm :: Term (WordN n)} deriving (Lift, NFData, Generic) +instance (KnownNat n, 1 <= n) => SafeLinearArith (SymWordN n) where + safeAdd e ls rs = + mrgIf + (ls >~ res ||~ rs >~ res) + (throwError e) + (mrgReturn res) + where + res = ls + rs + safeNeg e v = mrgIf (v /=~ 0) (throwError e) (mrgReturn v) + safeMinus e ls rs = + mrgIf + (rs >~ ls) + (throwError e) + (mrgReturn res) + where + res = ls - rs + -- | Symbolic unsigned bit vector type. Not indexed, but the bit width is -- fixed at the creation time. -- diff --git a/test/Grisette/IR/SymPrim/Data/SymPrimTests.hs b/test/Grisette/IR/SymPrim/Data/SymPrimTests.hs index f20bfa9e..e690fde4 100644 --- a/test/Grisette/IR/SymPrim/Data/SymPrimTests.hs +++ b/test/Grisette/IR/SymPrim/Data/SymPrimTests.hs @@ -164,41 +164,113 @@ symPrimTests = signum (ssym "a" :: SymInteger) @=? SymInteger (pevalSignumNumTerm (ssymTerm "a")) ], testGroup - "SignedDivMod" - [ testProperty "divs on concrete" $ \(i :: Integer, j :: Integer) -> + "SafeDivision" + [ testProperty "safeDiv on concrete" $ \(i :: Integer, j :: Integer) -> ioProperty $ - divs (con i :: SymInteger) (con j) + safeDiv () (con i :: SymInteger) (con j) @=? if j == 0 then merge $ throwError () :: ExceptT () UnionM SymInteger else mrgSingle $ con $ i `div` j, - testCase "divs when divided by zero" $ do - divs (ssym "a" :: SymInteger) (con 0) + testCase "safeDiv when divided by zero" $ do + safeDiv () (ssym "a" :: SymInteger) (con 0) @=? (merge $ throwError () :: ExceptT () UnionM SymInteger), - testCase "divs on symbolic" $ do - divs (ssym "a" :: SymInteger) (ssym "b") + testCase "safeDiv on symbolic" $ do + safeDiv () (ssym "a" :: SymInteger) (ssym "b") @=? ( mrgIf ((ssym "b" :: SymInteger) ==~ con (0 :: Integer) :: SymBool) (throwError ()) (mrgSingle $ SymInteger $ pevalDivIntegerTerm (ssymTerm "a") (ssymTerm "b")) :: ExceptT () UnionM SymInteger ), - testProperty "mods on concrete" $ \(i :: Integer, j :: Integer) -> + testProperty "safeMod on concrete" $ \(i :: Integer, j :: Integer) -> ioProperty $ - mods (con i :: SymInteger) (con j) + safeMod () (con i :: SymInteger) (con j) @=? if j == 0 then merge $ throwError () :: ExceptT () UnionM SymInteger else mrgSingle $ con $ i `mod` j, - testCase "mods when divided by zero" $ do - mods (ssym "a" :: SymInteger) (con 0) + testCase "safeMod when divided by zero" $ do + safeMod () (ssym "a" :: SymInteger) (con 0) @=? (merge $ throwError () :: ExceptT () UnionM SymInteger), - testCase "mods on symbolic" $ do - mods (ssym "a" :: SymInteger) (ssym "b") + testCase "safeMod on symbolic" $ do + safeMod () (ssym "a" :: SymInteger) (ssym "b") @=? ( mrgIf ((ssym "b" :: SymInteger) ==~ con (0 :: Integer) :: SymBool) (throwError ()) (mrgSingle $ SymInteger $ pevalModIntegerTerm (ssymTerm "a") (ssymTerm "b")) :: ExceptT () UnionM SymInteger - ) + ), + testProperty "safeDivMod on concrete" $ \(i :: Integer, j :: Integer) -> + ioProperty $ + safeDivMod () (con i :: SymInteger) (con j) + @=? if j == 0 + then merge $ throwError () :: ExceptT () UnionM (SymInteger, SymInteger) + else mrgSingle $ (con $ i `div` j, con $ i `mod` j), + testCase "safeDivMod when divided by zero" $ do + safeDivMod () (ssym "a" :: SymInteger) (con 0) + @=? (merge $ throwError () :: ExceptT () UnionM (SymInteger, SymInteger)), + testCase "safeDivMod on symbolic" $ do + safeDivMod () (ssym "a" :: SymInteger) (ssym "b") + @=? ( mrgIf + ((ssym "b" :: SymInteger) ==~ con (0 :: Integer) :: SymBool) + (throwError ()) + ( mrgSingle + ( SymInteger $ pevalDivIntegerTerm (ssymTerm "a") (ssymTerm "b"), + SymInteger $ pevalModIntegerTerm (ssymTerm "a") (ssymTerm "b") + ) + ) :: + ExceptT () UnionM (SymInteger, SymInteger) + ), + testProperty "safeQuot on concrete" $ \(i :: Integer, j :: Integer) -> + ioProperty $ + safeQuot () (con i :: SymInteger) (con j) + @=? if j == 0 + then merge $ throwError () :: ExceptT () UnionM SymInteger + else mrgSingle $ con $ i `quot` j, + testCase "safeQuot when divided by zero" $ do + safeQuot () (ssym "a" :: SymInteger) (con 0) + @=? (merge $ throwError () :: ExceptT () UnionM SymInteger), + testProperty "safeRem on concrete" $ \(i :: Integer, j :: Integer) -> + ioProperty $ + safeRem () (con i :: SymInteger) (con j) + @=? if j == 0 + then merge $ throwError () :: ExceptT () UnionM SymInteger + else mrgSingle $ con $ i `rem` j, + testCase "safeRem when divided by zero" $ do + safeRem () (ssym "a" :: SymInteger) (con 0) + @=? (merge $ throwError () :: ExceptT () UnionM SymInteger), + testProperty "safeQuotRem on concrete" $ \(i :: Integer, j :: Integer) -> + ioProperty $ + safeQuotRem () (con i :: SymInteger) (con j) + @=? if j == 0 + then merge $ throwError () :: ExceptT () UnionM (SymInteger, SymInteger) + else mrgSingle $ (con $ i `quot` j, con $ i `rem` j), + testCase "safeQuotRem when divided by zero" $ do + safeQuotRem () (ssym "a" :: SymInteger) (con 0) + @=? (merge $ throwError () :: ExceptT () UnionM (SymInteger, SymInteger)) + ], + testGroup + "SafeLinearArith" + [ testProperty "safeAdd on concrete" $ \(i :: Integer, j :: Integer) -> + ioProperty $ + safeAdd () (con i :: SymInteger) (con j) + @=? (mrgSingle $ con $ i + j :: ExceptT () UnionM SymInteger), + testCase "safeAdd on symbolic" $ do + safeAdd () (ssym "a" :: SymInteger) (ssym "b") + @=? (mrgSingle $ SymInteger $ pevalAddNumTerm (ssymTerm "a") (ssymTerm "b") :: ExceptT () UnionM SymInteger), + testProperty "safeNeg on concrete" $ \(i :: Integer) -> + ioProperty $ + safeNeg () (con i :: SymInteger) + @=? (mrgSingle $ con $ -i :: ExceptT () UnionM SymInteger), + testCase "safeNeg on symbolic" $ do + safeNeg () (ssym "a" :: SymInteger) + @=? (mrgSingle $ SymInteger $ pevalUMinusNumTerm (ssymTerm "a") :: ExceptT () UnionM SymInteger), + testProperty "safeMinus on concrete" $ \(i :: Integer, j :: Integer) -> + ioProperty $ + safeMinus () (con i :: SymInteger) (con j) + @=? (mrgSingle $ con $ i - j :: ExceptT () UnionM SymInteger), + testCase "safeMinus on symbolic" $ do + safeMinus () (ssym "a" :: SymInteger) (ssym "b") + @=? (mrgSingle $ SymInteger $ pevalMinusNumTerm (ssymTerm "a") (ssymTerm "b") :: ExceptT () UnionM SymInteger) ], testGroup "SOrd" @@ -257,6 +329,73 @@ symPrimTests = signum au @=? SymWordN (pevalSignumNumTerm aut) signum as @=? SymIntN (pevalSignumNumTerm ast) ], + testGroup + "SafeLinearArith" + [ testGroup + "IntN" + [ testProperty "safeAdd on concrete" $ \(i :: Int8, j :: Int8) -> + ioProperty $ + let iint = fromIntegral i :: Integer + jint = fromIntegral j + in safeAdd () (toSym i :: SymIntN 8) (toSym j) + @=? ( mrgIf + (iint + jint ==~ fromIntegral (i + j)) + (mrgSingle $ toSym $ i + j :: ExceptT () UnionM (SymIntN 8)) + (throwError ()) + ), + testProperty "safeMinus on concrete" $ \(i :: Int8, j :: Int8) -> + ioProperty $ + let iint = fromIntegral i :: Integer + jint = fromIntegral j + in safeMinus () (toSym i :: SymIntN 8) (toSym j) + @=? ( mrgIf + (iint - jint ==~ fromIntegral (i - j)) + (mrgSingle $ toSym $ i - j :: ExceptT () UnionM (SymIntN 8)) + (throwError ()) + ), + testProperty "safeNeg on concrete" $ \(i :: Int8) -> + ioProperty $ + let iint = fromIntegral i :: Integer + in safeNeg () (toSym i :: SymIntN 8) + @=? ( mrgIf + (-iint ==~ fromIntegral (-i)) + (mrgSingle $ toSym $ -i :: ExceptT () UnionM (SymIntN 8)) + (throwError ()) + ) + ], + testGroup + "WordN" + [ testProperty "safeAdd on concrete" $ \(i :: Word8, j :: Word8) -> + ioProperty $ + let iint = fromIntegral i :: Integer + jint = fromIntegral j + in safeAdd () (toSym i :: SymWordN 8) (toSym j) + @=? ( mrgIf + (iint + jint ==~ fromIntegral (i + j)) + (mrgSingle $ toSym $ i + j :: ExceptT () UnionM (SymWordN 8)) + (throwError ()) + ), + testProperty "safeMinus on concrete" $ \(i :: Word8, j :: Word8) -> + ioProperty $ + let iint = fromIntegral i :: Integer + jint = fromIntegral j + in safeMinus () (toSym i :: SymWordN 8) (toSym j) + @=? ( mrgIf + (iint - jint ==~ fromIntegral (i - j)) + (mrgSingle $ toSym $ i - j :: ExceptT () UnionM (SymWordN 8)) + (throwError ()) + ), + testProperty "safeNeg on concrete" $ \(i :: Word8) -> + ioProperty $ + let iint = fromIntegral i :: Integer + in safeNeg () (toSym i :: SymWordN 8) + @=? ( mrgIf + (-iint ==~ fromIntegral (-i)) + (mrgSingle $ toSym $ -i :: ExceptT () UnionM (SymWordN 8)) + (throwError ()) + ) + ] + ], testGroup "SOrd" [ testProperty "SOrd on concrete" $ \(i :: Integer, j :: Integer) -> ioProperty $ do From 84a22669cea1f81c597ee3a59dcebd897ec31f61 Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Sat, 18 Feb 2023 11:06:28 -0800 Subject: [PATCH 2/6] :sparkles: SafeDivision for integers --- src/Grisette/Core/Data/BV.hs | 26 ++++++++++ src/Grisette/Core/Data/Class/Integer.hs | 68 ++++++++++++++++++++++--- 2 files changed, 86 insertions(+), 8 deletions(-) diff --git a/src/Grisette/Core/Data/BV.hs b/src/Grisette/Core/Data/BV.hs index ecb90255..845f0e87 100644 --- a/src/Grisette/Core/Data/BV.hs +++ b/src/Grisette/Core/Data/BV.hs @@ -314,6 +314,10 @@ instance (KnownNat n, 1 <= n) => Enum (WordN n) where enumFromThen = boundedEnumFromThen {-# INLINE enumFromThen #-} +instance Enum SomeWordN where + toEnum = error "SomeWordN is not really a Enum type as the bit width is unknown, please consider using WordN instead" + fromEnum = error "SomeWordN is not really a Enum type as the bit width is unknown, please consider using WordN instead" + instance (KnownNat n, 1 <= n) => Real (WordN n) where toRational (WordN n) = n % 1 @@ -330,6 +334,15 @@ instance (KnownNat n, 1 <= n) => Integral (WordN n) where divMod = quotRem toInteger (WordN n) = n +instance Integral SomeWordN where + quot = binSomeWordN' quot "quot" + rem = binSomeWordN' rem "rem" + quotRem = binSomeWordN'' quotRem "quotRem" + div = binSomeWordN' div "div" + mod = binSomeWordN' mod "mod" + divMod = binSomeWordN'' divMod "divMod" + toInteger = unarySomeWordN toInteger "toInteger" + instance (KnownNat n, 1 <= n) => Num (WordN n) where WordN x + WordN y = WordN (x + y) .&. maxBound WordN x * WordN y = WordN (x * y) .&. maxBound @@ -448,6 +461,10 @@ instance (KnownNat n, 1 <= n) => Enum (IntN n) where enumFromThen = boundedEnumFromThen {-# INLINE enumFromThen #-} +instance Enum SomeIntN where + toEnum = error "SomeIntN is not really a Enum type as the bit width is unknown, please consider using IntN instead" + fromEnum = error "SomeIntN is not really a Enum type as the bit width is unknown, please consider using IntN instead" + instance (KnownNat n, 1 <= n) => Real (IntN n) where toRational i = toInteger i % 1 @@ -489,6 +506,15 @@ instance (KnownNat n, 1 <= n) => Integral (IntN n) where in if signum x == -1 then -n else negate (toInteger x) _ -> undefined +instance Integral SomeIntN where + quot = binSomeIntN' quot "quot" + rem = binSomeIntN' rem "rem" + quotRem = binSomeIntN'' quotRem "quotRem" + div = binSomeIntN' div "div" + mod = binSomeIntN' mod "mod" + divMod = binSomeIntN'' divMod "divMod" + toInteger = unarySomeIntN toInteger "toInteger" + instance (KnownNat n, 1 <= n) => Num (IntN n) where IntN x + IntN y = IntN (x + y) .&. minusOneIntN (Proxy :: Proxy n) IntN x * IntN y = IntN (x * y) .&. minusOneIntN (Proxy :: Proxy n) diff --git a/src/Grisette/Core/Data/Class/Integer.hs b/src/Grisette/Core/Data/Class/Integer.hs index 93705eec..bed06af3 100644 --- a/src/Grisette/Core/Data/Class/Integer.hs +++ b/src/Grisette/Core/Data/Class/Integer.hs @@ -1,6 +1,10 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE Trustworthy #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} -- | -- Module : Grisette.Core.Data.Class.Integer @@ -28,6 +32,10 @@ import Grisette.Core.Data.Class.Mergeable import Grisette.Core.Data.Class.SOrd import Grisette.Core.Data.Class.SimpleMergeable import Grisette.Core.Data.Class.Solvable +import Data.Int +import Data.Word +import GHC.TypeNats +import Grisette.Core.Data.BV -- $setup -- >>> import Grisette.Core @@ -91,10 +99,61 @@ class (SOrd a, Num a, Mergeable a) => SafeDivision a where mrgIf ((l >=~ 0 &&~ r >~ 0) ||~ (l <=~ 0 &&~ r <~ 0) ||~ m ==~ 0) (mrgSingle (d, m)) - (mrgSingle $ (d + 1, m - r)) + (mrgSingle (d + 1, m - r)) {-# MINIMAL (safeDivMod | (safeDiv, safeMod)) #-} +#define SAFE_DIVISION_CONCRETE(type) \ +instance SafeDivision type where \ + safeDiv e _ r | r == 0 = merge $ throwError e; \ + safeDiv _ l r = mrgSingle $ l `div` r; \ + safeMod e _ r | r == 0 = merge $ throwError e; \ + safeMod _ l r = mrgSingle $ l `mod` r; \ + safeDivMod e _ r | r == 0 = merge $ throwError e; \ + safeDivMod _ l r = mrgSingle $ l `divMod` r; \ + safeQuot e _ r | r == 0 = merge $ throwError e; \ + safeQuot _ l r = mrgSingle $ l `quot` r; \ + safeRem e _ r | r == 0 = merge $ throwError e; \ + safeRem _ l r = mrgSingle $ l `rem` r; \ + safeQuotRem e _ r | r == 0 = merge $ throwError e; \ + safeQuotRem _ l r = mrgSingle $ l `quotRem` r + +#define SAFE_DIVISION_CONCRETE_BV(type) \ +instance (KnownNat n, 1 <= n) => SafeDivision (type n) where \ + safeDiv e _ r | r == 0 = merge $ throwError e; \ + safeDiv _ l r = mrgSingle $ l `div` r; \ + safeMod e _ r | r == 0 = merge $ throwError e; \ + safeMod _ l r = mrgSingle $ l `mod` r; \ + safeDivMod e _ r | r == 0 = merge $ throwError e; \ + safeDivMod _ l r = mrgSingle $ l `divMod` r; \ + safeQuot e _ r | r == 0 = merge $ throwError e; \ + safeQuot _ l r = mrgSingle $ l `quot` r; \ + safeRem e _ r | r == 0 = merge $ throwError e; \ + safeRem _ l r = mrgSingle $ l `rem` r; \ + safeQuotRem e _ r | r == 0 = merge $ throwError e; \ + safeQuotRem _ l r = mrgSingle $ l `quotRem` r + +#if 1 +SAFE_DIVISION_CONCRETE(Integer) +SAFE_DIVISION_CONCRETE(Int8) +SAFE_DIVISION_CONCRETE(Int16) +SAFE_DIVISION_CONCRETE(Int32) +SAFE_DIVISION_CONCRETE(Int64) +SAFE_DIVISION_CONCRETE(Int) +SAFE_DIVISION_CONCRETE(SomeIntN) +SAFE_DIVISION_CONCRETE(Word8) +SAFE_DIVISION_CONCRETE(Word16) +SAFE_DIVISION_CONCRETE(Word32) +SAFE_DIVISION_CONCRETE(Word64) +SAFE_DIVISION_CONCRETE(Word) +SAFE_DIVISION_CONCRETE(SomeWordN) +#endif + +#if 1 +SAFE_DIVISION_CONCRETE_BV(IntN) +SAFE_DIVISION_CONCRETE_BV(WordN) +#endif + class SafeLinearArith a where -- | Safe signed '+' with monadic error handling in multi-path execution. -- Overflows are treated as errors. @@ -123,12 +182,5 @@ class SafeLinearArith a where -- ExceptT {If (|| (&& (<= 0x0 a) (&& (< b 0x0) (< (+ a (- b)) 0x0))) (&& (< a 0x0) (&& (< 0x0 b) (< 0x0 (+ a (- b)))))) (Left AssertionError) (Right (+ a (- b)))} safeMinus :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a --- | Safe signed 'quot' and 'rem' with monadic error handling in multi-path --- execution. These procedures show throw 'DivideByZero' exception when the --- divisor is zero. The result should be able to handle errors with --- `MonadError`, and the error type should be compatible with 'ArithException' --- (see 'TransformError' for more details). -class SafeQuotRem a - -- | Aggregation for the operations on symbolic integer types class (Num a, SEq a, SOrd a, Solvable Integer a, SafeDivision a, SafeLinearArith a) => SymIntegerOp a From 48cc7d3fa0a5e646af14516aa1cc648f7e772097 Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Mon, 20 Feb 2023 00:21:27 -0800 Subject: [PATCH 3/6] :sparkles: refine SafeDivision interface, add tests --- grisette.cabal | 6 +- src/Grisette/Backend/SBV/Data/SMT/Lowering.hs | 92 ++- src/Grisette/Core.hs | 4 +- src/Grisette/Core/Data/BV.hs | 263 ++++---- src/Grisette/Core/Data/Class/Integer.hs | 186 ----- src/Grisette/Core/Data/Class/Mergeable.hs | 17 + src/Grisette/Core/Data/Class/SafeArith.hs | 294 ++++++++ .../Data/Prim/InternedTerm/InternedCtors.hs | 46 +- .../Prim/InternedTerm/InternedCtors.hs-boot | 20 +- .../IR/SymPrim/Data/Prim/InternedTerm/Term.hs | 101 ++- .../Data/Prim/InternedTerm/Term.hs-boot | 20 +- .../Prim/InternedTerm/TermSubstitution.hs | 12 +- .../Data/Prim/InternedTerm/TermUtils.hs | 91 +-- src/Grisette/IR/SymPrim/Data/Prim/Model.hs | 22 +- .../SymPrim/Data/Prim/PartialEval/Integer.hs | 36 - .../SymPrim/Data/Prim/PartialEval/Integral.hs | 88 +++ src/Grisette/IR/SymPrim/Data/SymPrim.hs | 100 ++- src/Grisette/Internal/IR/SymPrim.hs | 8 +- src/Grisette/Lib/Data/List.hs | 2 +- .../Backend/SBV/Data/SMT/LoweringTests.hs | 46 +- .../Backend/SBV/Data/SMT/TermRewritingGen.hs | 28 +- .../SBV/Data/SMT/TermRewritingTests.hs | 312 +++++---- test/Grisette/Core/Data/BVTests.hs | 85 ++- .../IR/SymPrim/Data/Prim/IntegerTests.hs | 48 -- .../IR/SymPrim/Data/Prim/IntegralTests.hs | 170 +++++ test/Grisette/IR/SymPrim/Data/SymPrimTests.hs | 637 +++++++++++++----- test/Main.hs | 4 +- 27 files changed, 1863 insertions(+), 875 deletions(-) delete mode 100644 src/Grisette/Core/Data/Class/Integer.hs create mode 100644 src/Grisette/Core/Data/Class/SafeArith.hs delete mode 100644 src/Grisette/IR/SymPrim/Data/Prim/PartialEval/Integer.hs create mode 100644 src/Grisette/IR/SymPrim/Data/Prim/PartialEval/Integral.hs delete mode 100644 test/Grisette/IR/SymPrim/Data/Prim/IntegerTests.hs create mode 100644 test/Grisette/IR/SymPrim/Data/Prim/IntegralTests.hs diff --git a/grisette.cabal b/grisette.cabal index c0bb08cc..d70d708c 100644 --- a/grisette.cabal +++ b/grisette.cabal @@ -60,9 +60,9 @@ library Grisette.Core.Data.Class.ExtractSymbolics Grisette.Core.Data.Class.Function Grisette.Core.Data.Class.GenSym - Grisette.Core.Data.Class.Integer Grisette.Core.Data.Class.Mergeable Grisette.Core.Data.Class.ModelOps + Grisette.Core.Data.Class.SafeArith Grisette.Core.Data.Class.SimpleMergeable Grisette.Core.Data.Class.Solvable Grisette.Core.Data.Class.Solver @@ -93,7 +93,7 @@ library Grisette.IR.SymPrim.Data.Prim.PartialEval.Bool Grisette.IR.SymPrim.Data.Prim.PartialEval.BV Grisette.IR.SymPrim.Data.Prim.PartialEval.GeneralFun - Grisette.IR.SymPrim.Data.Prim.PartialEval.Integer + Grisette.IR.SymPrim.Data.Prim.PartialEval.Integral Grisette.IR.SymPrim.Data.Prim.PartialEval.Num Grisette.IR.SymPrim.Data.Prim.PartialEval.PartialEval Grisette.IR.SymPrim.Data.Prim.PartialEval.TabularFun @@ -192,7 +192,7 @@ test-suite spec Grisette.IR.SymPrim.Data.Prim.BitsTests Grisette.IR.SymPrim.Data.Prim.BoolTests Grisette.IR.SymPrim.Data.Prim.BVTests - Grisette.IR.SymPrim.Data.Prim.IntegerTests + Grisette.IR.SymPrim.Data.Prim.IntegralTests Grisette.IR.SymPrim.Data.Prim.ModelTests Grisette.IR.SymPrim.Data.Prim.NumTests Grisette.IR.SymPrim.Data.Prim.TabularFunTests diff --git a/src/Grisette/Backend/SBV/Data/SMT/Lowering.hs b/src/Grisette/Backend/SBV/Data/SMT/Lowering.hs index be105120..cebd6216 100644 --- a/src/Grisette/Backend/SBV/Data/SMT/Lowering.hs +++ b/src/Grisette/Backend/SBV/Data/SMT/Lowering.hs @@ -313,14 +313,38 @@ lowerSinglePrimImpl' config t@(GeneralFunApplyTerm _ (f :: Term (b --> a)) (arg addResult @integerBitWidth t g return g _ -> translateBinaryError "generalApply" (R.typeRep @(b --> a)) (R.typeRep @b) (R.typeRep @a) -lowerSinglePrimImpl' config t@(DivIntegerTerm _ arg1 arg2) = +lowerSinglePrimImpl' config t@(DivIntegralTerm _ arg1 arg2) = case (config, R.typeRep @a) of - (ResolvedConfig {}, IntegerType) -> lowerBinaryTerm' config t arg1 arg2 SBV.sDiv + ResolvedSDivisibleType -> lowerBinaryTerm' config t arg1 arg2 SBV.sDiv _ -> translateBinaryError "div" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a) -lowerSinglePrimImpl' config t@(ModIntegerTerm _ arg1 arg2) = +lowerSinglePrimImpl' config t@(ModIntegralTerm _ arg1 arg2) = case (config, R.typeRep @a) of - (ResolvedConfig {}, IntegerType) -> lowerBinaryTerm' config t arg1 arg2 SBV.sMod + ResolvedSDivisibleType -> lowerBinaryTerm' config t arg1 arg2 SBV.sMod _ -> translateBinaryError "mod" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a) +lowerSinglePrimImpl' config t@(QuotIntegralTerm _ arg1 arg2) = + case (config, R.typeRep @a) of + ResolvedSDivisibleType -> lowerBinaryTerm' config t arg1 arg2 SBV.sQuot + _ -> translateBinaryError "quot" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a) +lowerSinglePrimImpl' config t@(RemIntegralTerm _ arg1 arg2) = + case (config, R.typeRep @a) of + ResolvedSDivisibleType -> lowerBinaryTerm' config t arg1 arg2 SBV.sRem + _ -> translateBinaryError "rem" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a) +lowerSinglePrimImpl' config t@(DivBoundedIntegralTerm _ arg1 arg2) = + case (config, R.typeRep @a) of + ResolvedSDivisibleType -> lowerBinaryTerm' config t arg1 arg2 SBV.sDiv + _ -> translateBinaryError "div" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a) +lowerSinglePrimImpl' config t@(ModBoundedIntegralTerm _ arg1 arg2) = + case (config, R.typeRep @a) of + ResolvedSDivisibleType -> lowerBinaryTerm' config t arg1 arg2 SBV.sMod + _ -> translateBinaryError "mod" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a) +lowerSinglePrimImpl' config t@(QuotBoundedIntegralTerm _ arg1 arg2) = + case (config, R.typeRep @a) of + ResolvedSDivisibleType -> lowerBinaryTerm' config t arg1 arg2 SBV.sQuot + _ -> translateBinaryError "quot" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a) +lowerSinglePrimImpl' config t@(RemBoundedIntegralTerm _ arg1 arg2) = + case (config, R.typeRep @a) of + ResolvedSDivisibleType -> lowerBinaryTerm' config t arg1 arg2 SBV.sRem + _ -> translateBinaryError "rem" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a) lowerSinglePrimImpl' _ _ = undefined buildUTFun11 :: @@ -699,14 +723,38 @@ lowerSinglePrimImpl config t@(GeneralFunApplyTerm _ (f :: Term (b --> a)) (arg : let g = l1 l2 return (addBiMapIntermediate (SomeTerm t) (toDyn g) m2, g) _ -> translateBinaryError "generalApply" (R.typeRep @(b --> a)) (R.typeRep @b) (R.typeRep @a) -lowerSinglePrimImpl config t@(DivIntegerTerm _ arg1 arg2) m = +lowerSinglePrimImpl config t@(DivIntegralTerm _ arg1 arg2) m = + case (config, R.typeRep @a) of + ResolvedSDivisibleType -> lowerBinaryTerm config t arg1 arg2 SBV.sDiv m + _ -> translateBinaryError "div" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a) +lowerSinglePrimImpl config t@(ModIntegralTerm _ arg1 arg2) m = + case (config, R.typeRep @a) of + ResolvedSDivisibleType -> lowerBinaryTerm config t arg1 arg2 SBV.sMod m + _ -> translateBinaryError "mod" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a) +lowerSinglePrimImpl config t@(QuotIntegralTerm _ arg1 arg2) m = + case (config, R.typeRep @a) of + ResolvedSDivisibleType -> lowerBinaryTerm config t arg1 arg2 SBV.sQuot m + _ -> translateBinaryError "quot" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a) +lowerSinglePrimImpl config t@(RemIntegralTerm _ arg1 arg2) m = + case (config, R.typeRep @a) of + ResolvedSDivisibleType -> lowerBinaryTerm config t arg1 arg2 SBV.sRem m + _ -> translateBinaryError "rem" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a) +lowerSinglePrimImpl config t@(DivBoundedIntegralTerm _ arg1 arg2) m = case (config, R.typeRep @a) of - (ResolvedConfig {}, IntegerType) -> lowerBinaryTerm config t arg1 arg2 SBV.sDiv m + ResolvedSDivisibleType -> lowerBinaryTerm config t arg1 arg2 SBV.sDiv m _ -> translateBinaryError "div" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a) -lowerSinglePrimImpl config t@(ModIntegerTerm _ arg1 arg2) m = +lowerSinglePrimImpl config t@(ModBoundedIntegralTerm _ arg1 arg2) m = case (config, R.typeRep @a) of - (ResolvedConfig {}, IntegerType) -> lowerBinaryTerm config t arg1 arg2 SBV.sMod m + ResolvedSDivisibleType -> lowerBinaryTerm config t arg1 arg2 SBV.sMod m _ -> translateBinaryError "mod" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a) +lowerSinglePrimImpl config t@(QuotBoundedIntegralTerm _ arg1 arg2) m = + case (config, R.typeRep @a) of + ResolvedSDivisibleType -> lowerBinaryTerm config t arg1 arg2 SBV.sQuot m + _ -> translateBinaryError "quot" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a) +lowerSinglePrimImpl config t@(RemBoundedIntegralTerm _ arg1 arg2) m = + case (config, R.typeRep @a) of + ResolvedSDivisibleType -> lowerBinaryTerm config t arg1 arg2 SBV.sRem m + _ -> translateBinaryError "rem" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a) lowerSinglePrimImpl _ _ _ = error "Should never happen" bvIsNonZeroFromGEq1 :: forall w r. (1 <= w) => ((SBV.BVIsNonZero w) => r) -> r @@ -1117,6 +1165,34 @@ pattern ResolvedNumType :: (GrisetteSMTConfig integerBitWidth, R.TypeRep s) pattern ResolvedNumType <- (resolveNumTypeView -> Just DictNumType) +type SDivisibleTypeConstraint integerBitWidth s s' = + ( SimpleTypeConstraint integerBitWidth s s', + SBV.SDivisible (SBV.SBV s'), + Integral s + ) + +data DictSDivisibleType integerBitWidth s where + DictSDivisibleType :: + forall integerBitWidth s s'. + (SDivisibleTypeConstraint integerBitWidth s s') => + DictSDivisibleType integerBitWidth s + +resolveSDivisibleTypeView :: TypeResolver DictSDivisibleType +resolveSDivisibleTypeView (ResolvedConfig {}, s) = case s of + IntegerType -> Just DictSDivisibleType + SignedBVType _ -> Just DictSDivisibleType + UnsignedBVType _ -> Just DictSDivisibleType + _ -> Nothing +resolveSDivisibleTypeView _ = error "Should never happen, make compiler happy" + +pattern ResolvedSDivisibleType :: + forall integerBitWidth s. + (SupportedPrim s) => + forall s'. + SDivisibleTypeConstraint integerBitWidth s s' => + (GrisetteSMTConfig integerBitWidth, R.TypeRep s) +pattern ResolvedSDivisibleType <- (resolveSDivisibleTypeView -> Just DictSDivisibleType) + type NumOrdTypeConstraint integerBitWidth s s' = ( NumTypeConstraint integerBitWidth s s', SBV.OrdSymbolic (SBV.SBV s'), diff --git a/src/Grisette/Core.hs b/src/Grisette/Core.hs index cf41a253..703302e3 100644 --- a/src/Grisette/Core.hs +++ b/src/Grisette/Core.hs @@ -910,7 +910,7 @@ module Grisette.Core -- >>> let x = "x" :: SymInteger -- >>> let y = "y" :: SymInteger -- >>> assert = symAssertWith Assert - -- >>> sdiv = safeDiv Arith + -- >>> sdiv = safeDiv' (const Arith) -- >>> :{ -- -- equivalent concrete program: -- -- let x = x `div` y @@ -1026,10 +1026,10 @@ import Grisette.Core.Data.Class.Evaluate import Grisette.Core.Data.Class.ExtractSymbolics import Grisette.Core.Data.Class.Function import Grisette.Core.Data.Class.GenSym -import Grisette.Core.Data.Class.Integer import Grisette.Core.Data.Class.Mergeable import Grisette.Core.Data.Class.ModelOps import Grisette.Core.Data.Class.SOrd +import Grisette.Core.Data.Class.SafeArith import Grisette.Core.Data.Class.SimpleMergeable import Grisette.Core.Data.Class.Solvable import Grisette.Core.Data.Class.Solver diff --git a/src/Grisette/Core/Data/BV.hs b/src/Grisette/Core/Data/BV.hs index 845f0e87..2e053884 100644 --- a/src/Grisette/Core/Data/BV.hs +++ b/src/Grisette/Core/Data/BV.hs @@ -2,6 +2,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveLift #-} +{-# LANGUAGE DerivingVia #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} @@ -25,16 +26,28 @@ -- Stability : Experimental -- Portability : GHC only module Grisette.Core.Data.BV - ( IntN (..), + ( BitwidthMismatch (..), + IntN (..), WordN (..), SomeIntN (..), SomeWordN (..), + unarySomeIntN, + unarySomeIntNR1, + binSomeIntN, + binSomeIntNR1, + binSomeIntNR2, + unarySomeWordN, + unarySomeWordNR1, + binSomeWordN, + binSomeWordNR1, + binSomeWordNR2, ) where import Control.DeepSeq import Control.Exception import Data.Bits +import Data.CallStack import Data.Hashable import Data.Proxy import Data.Typeable @@ -47,6 +60,12 @@ import Grisette.Utils.Parameterized import Language.Haskell.TH.Syntax import Numeric +data BitwidthMismatch = BitwidthMismatch + deriving (Show, Eq, Ord, Generic) + +instance Exception BitwidthMismatch where + displayException BitwidthMismatch = "Bit width does not match" + -- | -- Symbolic unsigned bit vectors. newtype WordN (n :: Nat) = WordN {unWordN :: Integer} @@ -57,57 +76,57 @@ newtype WordN (n :: Nat) = WordN {unWordN :: Integer} data SomeWordN where SomeWordN :: (KnownNat n, 1 <= n) => WordN n -> SomeWordN -unarySomeWordN :: (forall n. (KnownNat n, 1 <= n) => WordN n -> r) -> String -> SomeWordN -> r -unarySomeWordN op str (SomeWordN (w :: WordN w)) = op w +unarySomeWordN :: HasCallStack => (forall n. (KnownNat n, 1 <= n) => WordN n -> r) -> SomeWordN -> r +unarySomeWordN op (SomeWordN (w :: WordN w)) = op w {-# INLINE unarySomeWordN #-} -unarySomeWordN' :: (forall n. (KnownNat n, 1 <= n) => WordN n -> WordN n) -> String -> SomeWordN -> SomeWordN -unarySomeWordN' op str (SomeWordN (w :: WordN w)) = SomeWordN $ op w -{-# INLINE unarySomeWordN' #-} +unarySomeWordNR1 :: HasCallStack => (forall n. (KnownNat n, 1 <= n) => WordN n -> WordN n) -> SomeWordN -> SomeWordN +unarySomeWordNR1 op (SomeWordN (w :: WordN w)) = SomeWordN $ op w +{-# INLINE unarySomeWordNR1 #-} -binSomeWordN :: (forall n. (KnownNat n, 1 <= n) => WordN n -> WordN n -> r) -> String -> SomeWordN -> SomeWordN -> r -binSomeWordN op str (SomeWordN (l :: WordN l)) (SomeWordN (r :: WordN r)) = +binSomeWordN :: HasCallStack => (forall n. (KnownNat n, 1 <= n) => WordN n -> WordN n -> r) -> SomeWordN -> SomeWordN -> r +binSomeWordN op (SomeWordN (l :: WordN l)) (SomeWordN (r :: WordN r)) = case sameNat (Proxy @l) (Proxy @r) of Just Refl -> op l r - Nothing -> error $ "Operation " ++ str ++ " on WordN with different bitwidth" + Nothing -> throw BitwidthMismatch {-# INLINE binSomeWordN #-} -binSomeWordN' :: (forall n. (KnownNat n, 1 <= n) => WordN n -> WordN n -> WordN n) -> String -> SomeWordN -> SomeWordN -> SomeWordN -binSomeWordN' op str (SomeWordN (l :: WordN l)) (SomeWordN (r :: WordN r)) = +binSomeWordNR1 :: HasCallStack => (forall n. (KnownNat n, 1 <= n) => WordN n -> WordN n -> WordN n) -> SomeWordN -> SomeWordN -> SomeWordN +binSomeWordNR1 op (SomeWordN (l :: WordN l)) (SomeWordN (r :: WordN r)) = case sameNat (Proxy @l) (Proxy @r) of Just Refl -> SomeWordN $ op l r - Nothing -> error $ "Operation " ++ str ++ " on WordN with different bitwidth" -{-# INLINE binSomeWordN' #-} + Nothing -> throw BitwidthMismatch +{-# INLINE binSomeWordNR1 #-} -binSomeWordN'' :: (forall n. (KnownNat n, 1 <= n) => WordN n -> WordN n -> (WordN n, WordN n)) -> String -> SomeWordN -> SomeWordN -> (SomeWordN, SomeWordN) -binSomeWordN'' op str (SomeWordN (l :: WordN l)) (SomeWordN (r :: WordN r)) = +binSomeWordNR2 :: HasCallStack => (forall n. (KnownNat n, 1 <= n) => WordN n -> WordN n -> (WordN n, WordN n)) -> SomeWordN -> SomeWordN -> (SomeWordN, SomeWordN) +binSomeWordNR2 op (SomeWordN (l :: WordN l)) (SomeWordN (r :: WordN r)) = case sameNat (Proxy @l) (Proxy @r) of Just Refl -> case op l r of (a, b) -> (SomeWordN a, SomeWordN b) - Nothing -> error $ "Operation " ++ str ++ " on WordN with different bitwidth" -{-# INLINE binSomeWordN'' #-} + Nothing -> throw BitwidthMismatch +{-# INLINE binSomeWordNR2 #-} instance Eq SomeWordN where - (==) = binSomeWordN (==) "==" + (==) = binSomeWordN (==) {-# INLINE (==) #-} - (/=) = binSomeWordN (/=) "/=" + (/=) = binSomeWordN (/=) {-# INLINE (/=) #-} instance Ord SomeWordN where - (<=) = binSomeWordN (<=) "<=" + (<=) = binSomeWordN (<=) {-# INLINE (<=) #-} - (<) = binSomeWordN (<) "<" + (<) = binSomeWordN (<) {-# INLINE (<) #-} - (>=) = binSomeWordN (>=) ">=" + (>=) = binSomeWordN (>=) {-# INLINE (>=) #-} - (>) = binSomeWordN (>) ">" + (>) = binSomeWordN (>) {-# INLINE (>) #-} - max = binSomeWordN' max "max" + max = binSomeWordNR1 max {-# INLINE max #-} - min = binSomeWordN' min "min" + min = binSomeWordNR1 min {-# INLINE min #-} - compare = binSomeWordN compare "compare" + compare = binSomeWordN compare {-# INLINE compare #-} instance Lift SomeWordN where @@ -141,57 +160,57 @@ newtype IntN (n :: Nat) = IntN {unIntN :: Integer} data SomeIntN where SomeIntN :: (KnownNat n, 1 <= n) => IntN n -> SomeIntN -unarySomeIntN :: (forall n. (KnownNat n, 1 <= n) => IntN n -> r) -> String -> SomeIntN -> r -unarySomeIntN op str (SomeIntN (w :: IntN w)) = op w +unarySomeIntN :: (forall n. (KnownNat n, 1 <= n) => IntN n -> r) -> SomeIntN -> r +unarySomeIntN op (SomeIntN (w :: IntN w)) = op w {-# INLINE unarySomeIntN #-} -unarySomeIntN' :: (forall n. (KnownNat n, 1 <= n) => IntN n -> IntN n) -> String -> SomeIntN -> SomeIntN -unarySomeIntN' op str (SomeIntN (w :: IntN w)) = SomeIntN $ op w -{-# INLINE unarySomeIntN' #-} +unarySomeIntNR1 :: (forall n. (KnownNat n, 1 <= n) => IntN n -> IntN n) -> SomeIntN -> SomeIntN +unarySomeIntNR1 op (SomeIntN (w :: IntN w)) = SomeIntN $ op w +{-# INLINE unarySomeIntNR1 #-} -binSomeIntN :: (forall n. (KnownNat n, 1 <= n) => IntN n -> IntN n -> r) -> String -> SomeIntN -> SomeIntN -> r -binSomeIntN op str (SomeIntN (l :: IntN l)) (SomeIntN (r :: IntN r)) = +binSomeIntN :: (forall n. (KnownNat n, 1 <= n) => IntN n -> IntN n -> r) -> SomeIntN -> SomeIntN -> r +binSomeIntN op (SomeIntN (l :: IntN l)) (SomeIntN (r :: IntN r)) = case sameNat (Proxy @l) (Proxy @r) of Just Refl -> op l r - Nothing -> error $ "Operation " ++ str ++ " on IntN with different bitwidth" + Nothing -> throw BitwidthMismatch {-# INLINE binSomeIntN #-} -binSomeIntN' :: (forall n. (KnownNat n, 1 <= n) => IntN n -> IntN n -> IntN n) -> String -> SomeIntN -> SomeIntN -> SomeIntN -binSomeIntN' op str (SomeIntN (l :: IntN l)) (SomeIntN (r :: IntN r)) = +binSomeIntNR1 :: (forall n. (KnownNat n, 1 <= n) => IntN n -> IntN n -> IntN n) -> SomeIntN -> SomeIntN -> SomeIntN +binSomeIntNR1 op (SomeIntN (l :: IntN l)) (SomeIntN (r :: IntN r)) = case sameNat (Proxy @l) (Proxy @r) of Just Refl -> SomeIntN $ op l r - Nothing -> error $ "Operation " ++ str ++ " on IntN with different bitwidth" -{-# INLINE binSomeIntN' #-} + Nothing -> throw BitwidthMismatch +{-# INLINE binSomeIntNR1 #-} -binSomeIntN'' :: (forall n. (KnownNat n, 1 <= n) => IntN n -> IntN n -> (IntN n, IntN n)) -> String -> SomeIntN -> SomeIntN -> (SomeIntN, SomeIntN) -binSomeIntN'' op str (SomeIntN (l :: IntN l)) (SomeIntN (r :: IntN r)) = +binSomeIntNR2 :: (forall n. (KnownNat n, 1 <= n) => IntN n -> IntN n -> (IntN n, IntN n)) -> SomeIntN -> SomeIntN -> (SomeIntN, SomeIntN) +binSomeIntNR2 op (SomeIntN (l :: IntN l)) (SomeIntN (r :: IntN r)) = case sameNat (Proxy @l) (Proxy @r) of Just Refl -> case op l r of (a, b) -> (SomeIntN a, SomeIntN b) - Nothing -> error $ "Operation " ++ str ++ " on IntN with different bitwidth" -{-# INLINE binSomeIntN'' #-} + Nothing -> throw BitwidthMismatch +{-# INLINE binSomeIntNR2 #-} instance Eq SomeIntN where - (==) = binSomeIntN (==) "==" + (==) = binSomeIntN (==) {-# INLINE (==) #-} - (/=) = binSomeIntN (/=) "/=" + (/=) = binSomeIntN (/=) {-# INLINE (/=) #-} instance Ord SomeIntN where - (<=) = binSomeIntN (<=) "<=" + (<=) = binSomeIntN (<=) {-# INLINE (<=) #-} - (<) = binSomeIntN (<) "<" + (<) = binSomeIntN (<) {-# INLINE (<) #-} - (>=) = binSomeIntN (>=) ">=" + (>=) = binSomeIntN (>=) {-# INLINE (>=) #-} - (>) = binSomeIntN (>) ">" + (>) = binSomeIntN (>) {-# INLINE (>) #-} - max = binSomeIntN' max "max" + max = binSomeIntNR1 max {-# INLINE max #-} - min = binSomeIntN' min "min" + min = binSomeIntNR1 min {-# INLINE min #-} - compare = binSomeIntN compare "compare" + compare = binSomeIntN compare {-# INLINE compare #-} instance Lift SomeIntN where @@ -263,36 +282,36 @@ instance (KnownNat n, 1 <= n) => Bits (WordN n) where popCount (WordN n) = popCount n instance Bits SomeWordN where - (.&.) = binSomeWordN' (.&.) ".&." - (.|.) = binSomeWordN' (.|.) ".|." - xor = binSomeWordN' xor "xor" - complement = unarySomeWordN' complement "complement" - shift s i = unarySomeWordN' (`shift` i) "shift" s - rotate s i = unarySomeWordN' (`rotate` i) "rotate" s + (.&.) = binSomeWordNR1 (.&.) + (.|.) = binSomeWordNR1 (.|.) + xor = binSomeWordNR1 xor + complement = unarySomeWordNR1 complement + shift s i = unarySomeWordNR1 (`shift` i) s + rotate s i = unarySomeWordNR1 (`rotate` i) s zeroBits = error "zeroBits is not defined for SomeWordN as no bitwidth is known" bit = error "bit is not defined for SomeWordN as no bitwidth is known" - setBit s i = unarySomeWordN' (`setBit` i) "setBit" s - clearBit s i = unarySomeWordN' (`clearBit` i) "clearBit" s - complementBit s i = unarySomeWordN' (`complementBit` i) "complementBit" s - testBit s i = unarySomeWordN (`testBit` i) "testBit" s + setBit s i = unarySomeWordNR1 (`setBit` i) s + clearBit s i = unarySomeWordNR1 (`clearBit` i) s + complementBit s i = unarySomeWordNR1 (`complementBit` i) s + testBit s i = unarySomeWordN (`testBit` i) s bitSizeMaybe (SomeWordN (n :: WordN n)) = Just $ fromIntegral $ natVal n bitSize (SomeWordN (n :: WordN n)) = fromIntegral $ natVal n isSigned _ = False - shiftL s i = unarySomeWordN' (`shiftL` i) "shiftL" s - unsafeShiftL s i = unarySomeWordN' (`unsafeShiftL` i) "unsafeShiftL" s - shiftR s i = unarySomeWordN' (`shiftR` i) "shiftR" s - unsafeShiftR s i = unarySomeWordN' (`unsafeShiftR` i) "unsafeShiftR" s - rotateL s i = unarySomeWordN' (`rotateL` i) "rotateL" s - rotateR s i = unarySomeWordN' (`rotateR` i) "rotateR" s - popCount = unarySomeWordN popCount "popCount" + shiftL s i = unarySomeWordNR1 (`shiftL` i) s + unsafeShiftL s i = unarySomeWordNR1 (`unsafeShiftL` i) s + shiftR s i = unarySomeWordNR1 (`shiftR` i) s + unsafeShiftR s i = unarySomeWordNR1 (`unsafeShiftR` i) s + rotateL s i = unarySomeWordNR1 (`rotateL` i) s + rotateR s i = unarySomeWordNR1 (`rotateR` i) s + popCount = unarySomeWordN popCount instance (KnownNat n, 1 <= n) => FiniteBits (WordN n) where finiteBitSize _ = fromIntegral (natVal (Proxy :: Proxy n)) instance FiniteBits SomeWordN where finiteBitSize (SomeWordN (n :: WordN n)) = fromIntegral $ natVal n - countLeadingZeros = unarySomeWordN countLeadingZeros "countLeadingZeros" - countTrailingZeros = unarySomeWordN countTrailingZeros "countTrailingZeros" + countLeadingZeros = unarySomeWordN countLeadingZeros + countTrailingZeros = unarySomeWordN countTrailingZeros instance (KnownNat n, 1 <= n) => Bounded (WordN n) where maxBound = WordN ((1 `shiftL` fromIntegral (natVal (Proxy :: Proxy n))) - 1) @@ -322,7 +341,7 @@ instance (KnownNat n, 1 <= n) => Real (WordN n) where toRational (WordN n) = n % 1 instance Real SomeWordN where - toRational = unarySomeWordN toRational "toRational" + toRational = unarySomeWordN toRational instance (KnownNat n, 1 <= n) => Integral (WordN n) where quot (WordN x) (WordN y) = WordN (x `quot` y) @@ -335,13 +354,13 @@ instance (KnownNat n, 1 <= n) => Integral (WordN n) where toInteger (WordN n) = n instance Integral SomeWordN where - quot = binSomeWordN' quot "quot" - rem = binSomeWordN' rem "rem" - quotRem = binSomeWordN'' quotRem "quotRem" - div = binSomeWordN' div "div" - mod = binSomeWordN' mod "mod" - divMod = binSomeWordN'' divMod "divMod" - toInteger = unarySomeWordN toInteger "toInteger" + quot = binSomeWordNR1 quot + rem = binSomeWordNR1 rem + quotRem = binSomeWordNR2 quotRem + div = binSomeWordNR1 div + mod = binSomeWordNR1 mod + divMod = binSomeWordNR2 divMod + toInteger = unarySomeWordN toInteger instance (KnownNat n, 1 <= n) => Num (WordN n) where WordN x + WordN y = WordN (x + y) .&. maxBound @@ -360,12 +379,12 @@ instance (KnownNat n, 1 <= n) => Num (WordN n) where | otherwise = -fromInteger (-x) instance Num SomeWordN where - (+) = binSomeWordN' (+) "+" - (-) = binSomeWordN' (-) "-" - (*) = binSomeWordN' (*) "*" - negate = unarySomeWordN' negate "negate" - abs = unarySomeWordN' abs "abs" - signum = unarySomeWordN' signum "signum" + (+) = binSomeWordNR1 (+) + (-) = binSomeWordNR1 (-) + (*) = binSomeWordNR1 (*) + negate = unarySomeWordNR1 negate + abs = unarySomeWordNR1 abs + signum = unarySomeWordNR1 signum fromInteger = error "fromInteger is not defined for SomeWordN as no bitwidth is known" minusOneIntN :: forall proxy n. KnownNat n => proxy n -> IntN n @@ -410,36 +429,36 @@ instance (KnownNat n, 1 <= n) => Bits (IntN n) where popCount (IntN i) = popCount i instance Bits SomeIntN where - (.&.) = binSomeIntN' (.&.) ".&." - (.|.) = binSomeIntN' (.|.) ".|." - xor = binSomeIntN' xor "xor" - complement = unarySomeIntN' complement "complement" - shift s i = unarySomeIntN' (`shift` i) "shift" s - rotate s i = unarySomeIntN' (`rotate` i) "rotate" s + (.&.) = binSomeIntNR1 (.&.) + (.|.) = binSomeIntNR1 (.|.) + xor = binSomeIntNR1 xor + complement = unarySomeIntNR1 complement + shift s i = unarySomeIntNR1 (`shift` i) s + rotate s i = unarySomeIntNR1 (`rotate` i) s zeroBits = error "zeroBits is not defined for SomeIntN as no bitwidth is known" bit = error "bit is not defined for SomeIntN as no bitwidth is known" - setBit s i = unarySomeIntN' (`setBit` i) "setBit" s - clearBit s i = unarySomeIntN' (`clearBit` i) "clearBit" s - complementBit s i = unarySomeIntN' (`complementBit` i) "complementBit" s - testBit s i = unarySomeIntN (`testBit` i) "testBit" s + setBit s i = unarySomeIntNR1 (`setBit` i) s + clearBit s i = unarySomeIntNR1 (`clearBit` i) s + complementBit s i = unarySomeIntNR1 (`complementBit` i) s + testBit s i = unarySomeIntN (`testBit` i) s bitSizeMaybe (SomeIntN (n :: IntN n)) = Just $ fromIntegral $ natVal n bitSize (SomeIntN (n :: IntN n)) = fromIntegral $ natVal n isSigned _ = False - shiftL s i = unarySomeIntN' (`shiftL` i) "shiftL" s - unsafeShiftL s i = unarySomeIntN' (`unsafeShiftL` i) "unsafeShiftL" s - shiftR s i = unarySomeIntN' (`shiftR` i) "shiftR" s - unsafeShiftR s i = unarySomeIntN' (`unsafeShiftR` i) "unsafeShiftR" s - rotateL s i = unarySomeIntN' (`rotateL` i) "rotateL" s - rotateR s i = unarySomeIntN' (`rotateR` i) "rotateR" s - popCount = unarySomeIntN popCount "popCount" + shiftL s i = unarySomeIntNR1 (`shiftL` i) s + unsafeShiftL s i = unarySomeIntNR1 (`unsafeShiftL` i) s + shiftR s i = unarySomeIntNR1 (`shiftR` i) s + unsafeShiftR s i = unarySomeIntNR1 (`unsafeShiftR` i) s + rotateL s i = unarySomeIntNR1 (`rotateL` i) s + rotateR s i = unarySomeIntNR1 (`rotateR` i) s + popCount = unarySomeIntN popCount instance (KnownNat n, 1 <= n) => FiniteBits (IntN n) where finiteBitSize _ = fromIntegral (natVal (Proxy :: Proxy n)) instance FiniteBits SomeIntN where finiteBitSize (SomeIntN (n :: IntN n)) = fromIntegral $ natVal n - countLeadingZeros = unarySomeIntN countLeadingZeros "countLeadingZeros" - countTrailingZeros = unarySomeIntN countTrailingZeros "countTrailingZeros" + countLeadingZeros = unarySomeIntN countLeadingZeros + countTrailingZeros = unarySomeIntN countTrailingZeros instance (KnownNat n, 1 <= n) => Bounded (IntN n) where maxBound = IntN (1 `shiftL` (fromIntegral (natVal (Proxy :: Proxy n)) - 1) - 1) @@ -469,17 +488,14 @@ instance (KnownNat n, 1 <= n) => Real (IntN n) where toRational i = toInteger i % 1 instance Real SomeIntN where - toRational = unarySomeIntN toRational "toRational" + toRational = unarySomeIntN toRational instance (KnownNat n, 1 <= n) => Integral (IntN n) where quot x y = if x == minBound && y == -1 then throw Overflow else fromInteger (toInteger x `quot` toInteger y) - rem x y = - if x == minBound && y == -1 - then throw Overflow - else fromInteger (toInteger x `rem` toInteger y) + rem x y = fromInteger (toInteger x `rem` toInteger y) quotRem x y = if x == minBound && y == -1 then throw Overflow @@ -489,10 +505,7 @@ instance (KnownNat n, 1 <= n) => Integral (IntN n) where if x == minBound && y == -1 then throw Overflow else fromInteger (toInteger x `div` toInteger y) - mod x y = - if x == minBound && y == -1 - then throw Overflow - else fromInteger (toInteger x `mod` toInteger y) + mod x y = fromInteger (toInteger x `mod` toInteger y) divMod x y = if x == minBound && y == -1 then throw Overflow @@ -507,13 +520,13 @@ instance (KnownNat n, 1 <= n) => Integral (IntN n) where _ -> undefined instance Integral SomeIntN where - quot = binSomeIntN' quot "quot" - rem = binSomeIntN' rem "rem" - quotRem = binSomeIntN'' quotRem "quotRem" - div = binSomeIntN' div "div" - mod = binSomeIntN' mod "mod" - divMod = binSomeIntN'' divMod "divMod" - toInteger = unarySomeIntN toInteger "toInteger" + quot = binSomeIntNR1 quot + rem = binSomeIntNR1 rem + quotRem = binSomeIntNR2 quotRem + div = binSomeIntNR1 div + mod = binSomeIntNR1 mod + divMod = binSomeIntNR2 divMod + toInteger = unarySomeIntN toInteger instance (KnownNat n, 1 <= n) => Num (IntN n) where IntN x + IntN y = IntN (x + y) .&. minusOneIntN (Proxy :: Proxy n) @@ -533,12 +546,12 @@ instance (KnownNat n, 1 <= n) => Num (IntN n) where maxn = 1 `shiftL` (n - 1) - 1 instance Num SomeIntN where - (+) = binSomeIntN' (+) "+" - (-) = binSomeIntN' (-) "-" - (*) = binSomeIntN' (*) "*" - negate = unarySomeIntN' negate "negate" - abs = unarySomeIntN' abs "abs" - signum = unarySomeIntN' signum "signum" + (+) = binSomeIntNR1 (+) + (-) = binSomeIntNR1 (-) + (*) = binSomeIntNR1 (*) + negate = unarySomeIntNR1 negate + abs = unarySomeIntNR1 abs + signum = unarySomeIntNR1 signum fromInteger = error "fromInteger is not defined for SomeIntN as no bitwidth is known" instance (KnownNat n, 1 <= n) => Ord (IntN n) where diff --git a/src/Grisette/Core/Data/Class/Integer.hs b/src/Grisette/Core/Data/Class/Integer.hs deleted file mode 100644 index bed06af3..00000000 --- a/src/Grisette/Core/Data/Class/Integer.hs +++ /dev/null @@ -1,186 +0,0 @@ -{-# LANGUAGE CPP #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE Trustworthy #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} - --- | --- Module : Grisette.Core.Data.Class.Integer --- Copyright : (c) Sirui Lu 2021-2023 --- License : BSD-3-Clause (see the LICENSE file) --- --- Maintainer : siruilu@cs.washington.edu --- Stability : Experimental --- Portability : GHC only -module Grisette.Core.Data.Class.Integer - ( -- * Symbolic integer operations - ArithException (..), - SafeDivision (..), - SafeLinearArith (..), - SymIntegerOp, - ) -where - -import Control.Exception -import Control.Monad.Except -import Grisette.Core.Control.Monad.Union -import Grisette.Core.Data.Class.Bool -import Grisette.Core.Data.Class.Error -import Grisette.Core.Data.Class.Mergeable -import Grisette.Core.Data.Class.SOrd -import Grisette.Core.Data.Class.SimpleMergeable -import Grisette.Core.Data.Class.Solvable -import Data.Int -import Data.Word -import GHC.TypeNats -import Grisette.Core.Data.BV - --- $setup --- >>> import Grisette.Core --- >>> import Grisette.IR.SymPrim - --- | Safe division with monadic error handling in multi-path --- execution. These procedures throw an exception when the --- divisor is zero. The result should be able to handle errors with --- `MonadError`. -class (SOrd a, Num a, Mergeable a) => SafeDivision a where - -- | Safe signed 'div' with monadic error handling in multi-path execution. - -- - -- >>> safeDiv AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM SymInteger - -- ExceptT {If (= b 0) (Left AssertionError) (Right (div a b))} - safeDiv :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a - safeDiv e l r = do - (d, _) <- safeDivMod e l r - mrgSingle d - - -- | Safe signed 'mod' with monadic error handling in multi-path execution. - -- - -- >>> safeMod AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM SymInteger - -- ExceptT {If (= b 0) (Left AssertionError) (Right (mod a b))} - safeMod :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a - safeMod e l r = do - (_, m) <- safeDivMod e l r - mrgSingle m - - -- | Safe signed 'div' with monadic error handling in multi-path execution. - -- - -- >>> safeDivMod AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM (SymInteger, SymInteger) - -- ExceptT {If (= b 0) (Left AssertionError) (Right ((div a b),(mod a b)))} - safeDivMod :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf (a, a) - safeDivMod e l r = do - d <- safeDiv e l r - m <- safeMod e l r - mrgSingle (d, m) - - -- | Safe signed 'quot' with monadic error handling in multi-path execution. - safeQuot :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a - safeQuot e l r = do - (d, m) <- safeDivMod e l r - mrgIf - ((l >=~ 0 &&~ r >~ 0) ||~ (l <=~ 0 &&~ r <~ 0) ||~ m ==~ 0) - (mrgSingle d) - (mrgSingle $ d + 1) - - -- | Safe signed 'rem' with monadic error handling in multi-path execution. - safeRem :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a - safeRem e l r = do - (d, m) <- safeDivMod e l r - mrgIf - ((l >=~ 0 &&~ r >~ 0) ||~ (l <=~ 0 &&~ r <~ 0) ||~ m ==~ 0) - (mrgSingle m) - (mrgSingle $ m - r) - - -- | Safe signed 'quotRem' with monadic error handling in multi-path execution. - safeQuotRem :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf (a, a) - safeQuotRem e l r = do - (d, m) <- safeDivMod e l r - mrgIf - ((l >=~ 0 &&~ r >~ 0) ||~ (l <=~ 0 &&~ r <~ 0) ||~ m ==~ 0) - (mrgSingle (d, m)) - (mrgSingle (d + 1, m - r)) - - {-# MINIMAL (safeDivMod | (safeDiv, safeMod)) #-} - -#define SAFE_DIVISION_CONCRETE(type) \ -instance SafeDivision type where \ - safeDiv e _ r | r == 0 = merge $ throwError e; \ - safeDiv _ l r = mrgSingle $ l `div` r; \ - safeMod e _ r | r == 0 = merge $ throwError e; \ - safeMod _ l r = mrgSingle $ l `mod` r; \ - safeDivMod e _ r | r == 0 = merge $ throwError e; \ - safeDivMod _ l r = mrgSingle $ l `divMod` r; \ - safeQuot e _ r | r == 0 = merge $ throwError e; \ - safeQuot _ l r = mrgSingle $ l `quot` r; \ - safeRem e _ r | r == 0 = merge $ throwError e; \ - safeRem _ l r = mrgSingle $ l `rem` r; \ - safeQuotRem e _ r | r == 0 = merge $ throwError e; \ - safeQuotRem _ l r = mrgSingle $ l `quotRem` r - -#define SAFE_DIVISION_CONCRETE_BV(type) \ -instance (KnownNat n, 1 <= n) => SafeDivision (type n) where \ - safeDiv e _ r | r == 0 = merge $ throwError e; \ - safeDiv _ l r = mrgSingle $ l `div` r; \ - safeMod e _ r | r == 0 = merge $ throwError e; \ - safeMod _ l r = mrgSingle $ l `mod` r; \ - safeDivMod e _ r | r == 0 = merge $ throwError e; \ - safeDivMod _ l r = mrgSingle $ l `divMod` r; \ - safeQuot e _ r | r == 0 = merge $ throwError e; \ - safeQuot _ l r = mrgSingle $ l `quot` r; \ - safeRem e _ r | r == 0 = merge $ throwError e; \ - safeRem _ l r = mrgSingle $ l `rem` r; \ - safeQuotRem e _ r | r == 0 = merge $ throwError e; \ - safeQuotRem _ l r = mrgSingle $ l `quotRem` r - -#if 1 -SAFE_DIVISION_CONCRETE(Integer) -SAFE_DIVISION_CONCRETE(Int8) -SAFE_DIVISION_CONCRETE(Int16) -SAFE_DIVISION_CONCRETE(Int32) -SAFE_DIVISION_CONCRETE(Int64) -SAFE_DIVISION_CONCRETE(Int) -SAFE_DIVISION_CONCRETE(SomeIntN) -SAFE_DIVISION_CONCRETE(Word8) -SAFE_DIVISION_CONCRETE(Word16) -SAFE_DIVISION_CONCRETE(Word32) -SAFE_DIVISION_CONCRETE(Word64) -SAFE_DIVISION_CONCRETE(Word) -SAFE_DIVISION_CONCRETE(SomeWordN) -#endif - -#if 1 -SAFE_DIVISION_CONCRETE_BV(IntN) -SAFE_DIVISION_CONCRETE_BV(WordN) -#endif - -class SafeLinearArith a where - -- | Safe signed '+' with monadic error handling in multi-path execution. - -- Overflows are treated as errors. - -- - -- >>> safeAdd AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM SymInteger - -- ExceptT {Right (+ a b)} - -- >>> safeAdd AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM (SymIntN 4) - -- ExceptT {If (|| (&& (< 0x0 a) (&& (< 0x0 b) (< (+ a b) 0x0))) (&& (< a 0x0) (&& (< b 0x0) (<= 0x0 (+ a b))))) (Left AssertionError) (Right (+ a b))} - safeAdd :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a - - -- | Safe signed 'negate' with monadic error handling in multi-path execution. - -- Overflows are treated as errors. - -- - -- >>> safeNeg AssertionError (ssym "a") :: ExceptT AssertionError UnionM SymInteger - -- ExceptT {Right (- a)} - -- >>> safeNeg AssertionError (ssym "a") :: ExceptT AssertionError UnionM (SymIntN 4) - -- ExceptT {If (= a 0x8) (Left AssertionError) (Right (- a))} - safeNeg :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> uf a - - -- | Safe signed '-' with monadic error handling in multi-path execution. - -- Overflows are treated as errors. - -- - -- >>> safeMinus AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM SymInteger - -- ExceptT {Right (+ a (- b))} - -- >>> safeMinus AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM (SymIntN 4) - -- ExceptT {If (|| (&& (<= 0x0 a) (&& (< b 0x0) (< (+ a (- b)) 0x0))) (&& (< a 0x0) (&& (< 0x0 b) (< 0x0 (+ a (- b)))))) (Left AssertionError) (Right (+ a (- b)))} - safeMinus :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a - --- | Aggregation for the operations on symbolic integer types -class (Num a, SEq a, SOrd a, Solvable Integer a, SafeDivision a, SafeLinearArith a) => SymIntegerOp a diff --git a/src/Grisette/Core/Data/Class/Mergeable.hs b/src/Grisette/Core/Data/Class/Mergeable.hs index 43db080b..3e09bd8b 100644 --- a/src/Grisette/Core/Data/Class/Mergeable.hs +++ b/src/Grisette/Core/Data/Class/Mergeable.hs @@ -50,6 +50,7 @@ module Grisette.Core.Data.Class.Mergeable ) where +import Control.Exception import Control.Monad.Cont import Control.Monad.Except import Control.Monad.Identity @@ -992,3 +993,19 @@ MERGEABLE_BV_SOME(SomeSymWordN) MERGEABLE_FUN(=~>) MERGEABLE_FUN(-~>) #endif + +-- Exceptions +instance Mergeable ArithException where + rootStrategy = + SortedStrategy + ( \case + Overflow -> 0 :: Int + Underflow -> 1 :: Int + LossOfPrecision -> 2 :: Int + DivideByZero -> 3 :: Int + Denormal -> 4 :: Int + RatioZeroDenominator -> 5 :: Int + ) + (const $ SimpleStrategy $ \_ l r -> l) + +deriving via (Default BitwidthMismatch) instance (Mergeable BitwidthMismatch) diff --git a/src/Grisette/Core/Data/Class/SafeArith.hs b/src/Grisette/Core/Data/Class/SafeArith.hs new file mode 100644 index 00000000..a2da39b6 --- /dev/null +++ b/src/Grisette/Core/Data/Class/SafeArith.hs @@ -0,0 +1,294 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE Trustworthy #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +{- HLINT ignore "Redundant bracket" -} + +-- | +-- Module : Grisette.Core.Data.Class.SafeArith +-- Copyright : (c) Sirui Lu 2021-2023 +-- License : BSD-3-Clause (see the LICENSE file) +-- +-- Maintainer : siruilu@cs.washington.edu +-- Stability : Experimental +-- Portability : GHC only +module Grisette.Core.Data.Class.SafeArith + ( -- * Symbolic integer operations + ArithException (..), + SafeDivision (..), + SafeLinearArith (..), + SymIntegerOp, + ) +where + +import Control.Exception +import Control.Monad.Except +import Data.Int +import Data.Typeable +import Data.Word +import GHC.TypeNats +import Grisette.Core.Control.Monad.Union +import Grisette.Core.Data.BV +import Grisette.Core.Data.Class.Bool +import Grisette.Core.Data.Class.Error +import Grisette.Core.Data.Class.Mergeable +import Grisette.Core.Data.Class.SOrd +import Grisette.Core.Data.Class.SimpleMergeable +import Grisette.Core.Data.Class.Solvable + +-- $setup +-- >>> import Grisette.Core +-- >>> import Grisette.IR.SymPrim + +-- | Safe division with monadic error handling in multi-path +-- execution. These procedures throw an exception when the +-- divisor is zero. The result should be able to handle errors with +-- `MonadError`. +class (SOrd a, Num a, Mergeable a, Mergeable e) => SafeDivision e a | a -> e where + -- | Safe signed 'div' with monadic error handling in multi-path execution. + -- + -- >>> safeDiv (ssym "a") (ssym "b") :: ExceptT ArithException UnionM SymInteger + -- ExceptT {If (= b 0) (Left divide by zero) (Right (div a b))} + safeDiv :: (MonadError e uf, MonadUnion uf) => a -> a -> uf a + safeDiv l r = do + (d, _) <- safeDivMod l r + mrgSingle d + + -- | Safe signed 'mod' with monadic error handling in multi-path execution. + -- + -- >>> safeMod (ssym "a") (ssym "b") :: ExceptT ArithException UnionM SymInteger + -- ExceptT {If (= b 0) (Left divide by zero) (Right (mod a b))} + safeMod :: (MonadError e uf, MonadUnion uf) => a -> a -> uf a + safeMod l r = do + (_, m) <- safeDivMod l r + mrgSingle m + + -- | Safe signed 'divMod' with monadic error handling in multi-path execution. + -- + -- >>> safeDivMod (ssym "a") (ssym "b") :: ExceptT ArithException UnionM (SymInteger, SymInteger) + -- ExceptT {If (= b 0) (Left divide by zero) (Right ((div a b),(mod a b)))} + safeDivMod :: (MonadError e uf, MonadUnion uf) => a -> a -> uf (a, a) + safeDivMod l r = do + d <- safeDiv l r + m <- safeMod l r + mrgSingle (d, m) + + -- | Safe signed 'quot' with monadic error handling in multi-path execution. + safeQuot :: (MonadError e uf, MonadUnion uf) => a -> a -> uf a + safeQuot l r = do + (d, m) <- safeDivMod l r + mrgIf + ((l >=~ 0 &&~ r >~ 0) ||~ (l <=~ 0 &&~ r <~ 0) ||~ m ==~ 0) + (mrgSingle d) + (mrgSingle $ d + 1) + + -- | Safe signed 'rem' with monadic error handling in multi-path execution. + safeRem :: (MonadError e uf, MonadUnion uf) => a -> a -> uf a + safeRem l r = do + (d, m) <- safeDivMod l r + mrgIf + ((l >=~ 0 &&~ r >~ 0) ||~ (l <=~ 0 &&~ r <~ 0) ||~ m ==~ 0) + (mrgSingle m) + (mrgSingle $ m - r) + + -- | Safe signed 'quotRem' with monadic error handling in multi-path execution. + safeQuotRem :: (MonadError e uf, MonadUnion uf) => a -> a -> uf (a, a) + safeQuotRem l r = do + (d, m) <- safeDivMod l r + mrgIf + ((l >=~ 0 &&~ r >~ 0) ||~ (l <=~ 0 &&~ r <~ 0) ||~ m ==~ 0) + (mrgSingle (d, m)) + (mrgSingle (d + 1, m - r)) + + -- | Safe signed 'div' with monadic error handling in multi-path execution. + -- The error is transformed. + -- + -- >>> safeDiv' (const ()) (ssym "a") (ssym "b") :: ExceptT () UnionM SymInteger + -- ExceptT {If (= b 0) (Left ()) (Right (div a b))} + safeDiv' :: (MonadError e' uf, MonadUnion uf, Mergeable e') => (e -> e') -> a -> a -> uf a + safeDiv' t l r = do + (d, _) <- safeDivMod' t l r + mrgSingle d + + -- | Safe signed 'mod' with monadic error handling in multi-path execution. + -- The error is transformed. + -- + -- >>> safeMod' (const ()) (ssym "a") (ssym "b") :: ExceptT () UnionM SymInteger + -- ExceptT {If (= b 0) (Left ()) (Right (mod a b))} + safeMod' :: (MonadError e' uf, MonadUnion uf, Mergeable e') => (e -> e') -> a -> a -> uf a + safeMod' t l r = do + (_, m) <- safeDivMod' t l r + mrgSingle m + + -- | Safe signed 'divMod' with monadic error handling in multi-path execution. + -- The error is transformed. + -- + -- >>> safeDivMod' (const ()) (ssym "a") (ssym "b") :: ExceptT () UnionM (SymInteger, SymInteger) + -- ExceptT {If (= b 0) (Left ()) (Right ((div a b),(mod a b)))} + safeDivMod' :: (MonadError e' uf, MonadUnion uf, Mergeable e') => (e -> e') -> a -> a -> uf (a, a) + safeDivMod' t l r = do + d <- safeDiv' t l r + m <- safeMod' t l r + mrgSingle (d, m) + + -- | Safe signed 'quot' with monadic error handling in multi-path execution. + -- The error is transformed. + safeQuot' :: (MonadError e' uf, MonadUnion uf, Mergeable e') => (e -> e') -> a -> a -> uf a + safeQuot' t l r = do + (d, m) <- safeDivMod' t l r + mrgIf + ((l >=~ 0 &&~ r >~ 0) ||~ (l <=~ 0 &&~ r <~ 0) ||~ m ==~ 0) + (mrgSingle d) + (mrgSingle $ d + 1) + + -- | Safe signed 'rem' with monadic error handling in multi-path execution. + -- The error is transformed. + safeRem' :: (MonadError e' uf, MonadUnion uf, Mergeable e') => (e -> e') -> a -> a -> uf a + safeRem' t l r = do + (d, m) <- safeDivMod' t l r + mrgIf + ((l >=~ 0 &&~ r >~ 0) ||~ (l <=~ 0 &&~ r <~ 0) ||~ m ==~ 0) + (mrgSingle m) + (mrgSingle $ m - r) + + -- | Safe signed 'quotRem' with monadic error handling in multi-path execution. + -- The error is transformed. + safeQuotRem' :: (MonadError e' uf, MonadUnion uf, Mergeable e') => (e -> e') -> a -> a -> uf (a, a) + safeQuotRem' t l r = do + (d, m) <- safeDivMod' t l r + mrgIf + ((l >=~ 0 &&~ r >~ 0) ||~ (l <=~ 0 &&~ r <~ 0) ||~ m ==~ 0) + (mrgSingle (d, m)) + (mrgSingle (d + 1, m - r)) + + {-# MINIMAL (safeDivMod | (safeDiv, safeMod)), (safeDivMod' | (safeDiv', safeMod')) #-} + +#define SAFE_DIVISION_FUNC(name, op) \ +name _ r | r == 0 = merge $ throwError DivideByZero; \ +name l r = mrgSingle $ l `op` r; \ +name' t _ r | r == 0 = merge $ throwError (t DivideByZero); \ +name' _ l r = mrgSingle $ l `op` r + +#define SAFE_DIVISION_CONCRETE(type) \ +instance SafeDivision ArithException type where \ + SAFE_DIVISION_FUNC(safeDiv, div); \ + SAFE_DIVISION_FUNC(safeMod, mod); \ + SAFE_DIVISION_FUNC(safeDivMod, divMod); \ + SAFE_DIVISION_FUNC(safeQuot, quot); \ + SAFE_DIVISION_FUNC(safeRem, rem); \ + SAFE_DIVISION_FUNC(safeQuotRem, quotRem) + +#define SAFE_DIVISION_CONCRETE_BV(type) \ +instance (KnownNat n, 1 <= n) => SafeDivision ArithException (type n) where \ + SAFE_DIVISION_FUNC(safeDiv, div); \ + SAFE_DIVISION_FUNC(safeMod, mod); \ + SAFE_DIVISION_FUNC(safeDivMod, divMod); \ + SAFE_DIVISION_FUNC(safeQuot, quot); \ + SAFE_DIVISION_FUNC(safeRem, rem); \ + SAFE_DIVISION_FUNC(safeQuotRem, quotRem) + +#if 1 +SAFE_DIVISION_CONCRETE(Integer) +SAFE_DIVISION_CONCRETE(Int8) +SAFE_DIVISION_CONCRETE(Int16) +SAFE_DIVISION_CONCRETE(Int32) +SAFE_DIVISION_CONCRETE(Int64) +SAFE_DIVISION_CONCRETE(Int) +SAFE_DIVISION_CONCRETE(Word8) +SAFE_DIVISION_CONCRETE(Word16) +SAFE_DIVISION_CONCRETE(Word32) +SAFE_DIVISION_CONCRETE(Word64) +SAFE_DIVISION_CONCRETE(Word) +#endif + +#define SAFE_DIVISION_FUNC_SOME(stype, type, name, op) \ + name (stype (l :: type l)) (stype (r :: type r)) = \ + (case sameNat (Proxy @l) (Proxy @r) of \ + Just Refl -> \ + if r == 0 \ + then merge $ throwError $ Right DivideByZero \ + else mrgSingle $ stype $ l `op` r; \ + Nothing -> merge $ throwError $ Left BitwidthMismatch); \ + name' t (stype (l :: type l)) (stype (r :: type r)) = \ + (case sameNat (Proxy @l) (Proxy @r) of \ + Just Refl -> \ + if r == 0 \ + then merge $ throwError $ t (Right DivideByZero) \ + else mrgSingle $ stype $ l `op` r; \ + Nothing -> merge $ throwError $ t (Left BitwidthMismatch)) + +#define SAFE_DIVISION_FUNC_SOME_DIVMOD(stype, type, name, op) \ + name (stype (l :: type l)) (stype (r :: type r)) = \ + (case sameNat (Proxy @l) (Proxy @r) of \ + Just Refl -> \ + if r == 0 \ + then merge $ throwError $ Right DivideByZero \ + else (case l `op` r of (d, m) -> mrgSingle (stype d, stype m)); \ + Nothing -> merge $ throwError $ Left BitwidthMismatch); \ + name' t (stype (l :: type l)) (stype (r :: type r)) = \ + (case sameNat (Proxy @l) (Proxy @r) of \ + Just Refl -> \ + if r == 0 \ + then merge $ throwError $ t (Right DivideByZero) \ + else (case l `op` r of (d, m) -> mrgSingle (stype d, stype m)); \ + Nothing -> merge $ throwError $ t (Left BitwidthMismatch)) + +#if 1 +SAFE_DIVISION_CONCRETE_BV(IntN) +SAFE_DIVISION_CONCRETE_BV(WordN) +instance SafeDivision (Either BitwidthMismatch ArithException) SomeIntN where + SAFE_DIVISION_FUNC_SOME(SomeIntN, IntN, safeDiv, div) + SAFE_DIVISION_FUNC_SOME(SomeIntN, IntN, safeMod, mod) + SAFE_DIVISION_FUNC_SOME_DIVMOD(SomeIntN, IntN, safeDivMod, divMod) + SAFE_DIVISION_FUNC_SOME(SomeIntN, IntN, safeQuot, quot) + SAFE_DIVISION_FUNC_SOME(SomeIntN, IntN, safeRem, rem) + SAFE_DIVISION_FUNC_SOME_DIVMOD(SomeIntN, IntN, safeQuotRem, quotRem) + +instance SafeDivision (Either BitwidthMismatch ArithException) SomeWordN where + SAFE_DIVISION_FUNC_SOME(SomeWordN, WordN, safeDiv, div) + SAFE_DIVISION_FUNC_SOME(SomeWordN, WordN, safeMod, mod) + SAFE_DIVISION_FUNC_SOME_DIVMOD(SomeWordN, WordN, safeDivMod, divMod) + SAFE_DIVISION_FUNC_SOME(SomeWordN, WordN, safeQuot, quot) + SAFE_DIVISION_FUNC_SOME(SomeWordN, WordN, safeRem, rem) + SAFE_DIVISION_FUNC_SOME_DIVMOD(SomeWordN, WordN, safeQuotRem, quotRem) +#endif + +class SafeLinearArith a where + -- | Safe signed '+' with monadic error handling in multi-path execution. + -- Overflows are treated as errors. + -- + -- >>> safeAdd AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM SymInteger + -- ExceptT {Right (+ a b)} + -- >>> safeAdd AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM (SymIntN 4) + -- ExceptT {If (|| (&& (< 0x0 a) (&& (< 0x0 b) (< (+ a b) 0x0))) (&& (< a 0x0) (&& (< b 0x0) (<= 0x0 (+ a b))))) (Left AssertionError) (Right (+ a b))} + safeAdd :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a + + -- | Safe signed 'negate' with monadic error handling in multi-path execution. + -- Overflows are treated as errors. + -- + -- >>> safeNeg AssertionError (ssym "a") :: ExceptT AssertionError UnionM SymInteger + -- ExceptT {Right (- a)} + -- >>> safeNeg AssertionError (ssym "a") :: ExceptT AssertionError UnionM (SymIntN 4) + -- ExceptT {If (= a 0x8) (Left AssertionError) (Right (- a))} + safeNeg :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> uf a + + -- | Safe signed '-' with monadic error handling in multi-path execution. + -- Overflows are treated as errors. + -- + -- >>> safeMinus AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM SymInteger + -- ExceptT {Right (+ a (- b))} + -- >>> safeMinus AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM (SymIntN 4) + -- ExceptT {If (|| (&& (<= 0x0 a) (&& (< b 0x0) (< (+ a (- b)) 0x0))) (&& (< a 0x0) (&& (< 0x0 b) (< 0x0 (+ a (- b)))))) (Left AssertionError) (Right (+ a (- b)))} + safeMinus :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a + +-- | Aggregation for the operations on symbolic integer types +class (Num a, SEq a, SOrd a, Solvable Integer a) => SymIntegerOp a diff --git a/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/InternedCtors.hs b/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/InternedCtors.hs index d5728da7..0f95caa6 100644 --- a/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/InternedCtors.hs +++ b/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/InternedCtors.hs @@ -49,8 +49,14 @@ module Grisette.IR.SymPrim.Data.Prim.InternedTerm.InternedCtors bvzeroExtendTerm, tabularFunApplyTerm, generalFunApplyTerm, - divIntegerTerm, - modIntegerTerm, + divIntegralTerm, + modIntegralTerm, + quotIntegralTerm, + remIntegralTerm, + divBoundedIntegralTerm, + modBoundedIntegralTerm, + quotBoundedIntegralTerm, + remBoundedIntegralTerm, ) where @@ -309,10 +315,34 @@ generalFunApplyTerm :: (SupportedPrim a, SupportedPrim b) => Term (a --> b) -> T generalFunApplyTerm f a = internTerm $ UGeneralFunApplyTerm f a {-# INLINE generalFunApplyTerm #-} -divIntegerTerm :: Term Integer -> Term Integer -> Term Integer -divIntegerTerm l r = internTerm $ UDivIntegerTerm l r -{-# INLINE divIntegerTerm #-} +divIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Term a +divIntegralTerm l r = internTerm $ UDivIntegralTerm l r +{-# INLINE divIntegralTerm #-} -modIntegerTerm :: Term Integer -> Term Integer -> Term Integer -modIntegerTerm l r = internTerm $ UModIntegerTerm l r -{-# INLINE modIntegerTerm #-} +modIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Term a +modIntegralTerm l r = internTerm $ UModIntegralTerm l r +{-# INLINE modIntegralTerm #-} + +quotIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Term a +quotIntegralTerm l r = internTerm $ UQuotIntegralTerm l r +{-# INLINE quotIntegralTerm #-} + +remIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Term a +remIntegralTerm l r = internTerm $ URemIntegralTerm l r +{-# INLINE remIntegralTerm #-} + +divBoundedIntegralTerm :: (SupportedPrim a, Bounded a, Integral a) => Term a -> Term a -> Term a +divBoundedIntegralTerm l r = internTerm $ UDivBoundedIntegralTerm l r +{-# INLINE divBoundedIntegralTerm #-} + +modBoundedIntegralTerm :: (SupportedPrim a, Bounded a, Integral a) => Term a -> Term a -> Term a +modBoundedIntegralTerm l r = internTerm $ UModBoundedIntegralTerm l r +{-# INLINE modBoundedIntegralTerm #-} + +quotBoundedIntegralTerm :: (SupportedPrim a, Bounded a, Integral a) => Term a -> Term a -> Term a +quotBoundedIntegralTerm l r = internTerm $ UQuotBoundedIntegralTerm l r +{-# INLINE quotBoundedIntegralTerm #-} + +remBoundedIntegralTerm :: (SupportedPrim a, Bounded a, Integral a) => Term a -> Term a -> Term a +remBoundedIntegralTerm l r = internTerm $ URemBoundedIntegralTerm l r +{-# INLINE remBoundedIntegralTerm #-} diff --git a/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/InternedCtors.hs-boot b/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/InternedCtors.hs-boot index 1adfd9b2..42b2e251 100644 --- a/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/InternedCtors.hs-boot +++ b/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/InternedCtors.hs-boot @@ -38,8 +38,14 @@ module Grisette.IR.SymPrim.Data.Prim.InternedTerm.InternedCtors bvzeroExtendTerm, tabularFunApplyTerm, generalFunApplyTerm, - divIntegerTerm, - modIntegerTerm, + divIntegralTerm, + modIntegralTerm, + quotIntegralTerm, + remIntegralTerm, + divBoundedIntegralTerm, + modBoundedIntegralTerm, + quotBoundedIntegralTerm, + remBoundedIntegralTerm, ) where @@ -178,5 +184,11 @@ bvzeroExtendTerm :: Term (bv r) tabularFunApplyTerm :: (SupportedPrim a, SupportedPrim b) => Term (a =-> b) -> Term a -> Term b generalFunApplyTerm :: (SupportedPrim a, SupportedPrim b) => Term (a --> b) -> Term a -> Term b -divIntegerTerm :: Term Integer -> Term Integer -> Term Integer -modIntegerTerm :: Term Integer -> Term Integer -> Term Integer +divIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Term a +modIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Term a +quotIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Term a +remIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Term a +divBoundedIntegralTerm :: (SupportedPrim a, Bounded a, Integral a) => Term a -> Term a -> Term a +modBoundedIntegralTerm :: (SupportedPrim a, Bounded a, Integral a) => Term a -> Term a -> Term a +quotBoundedIntegralTerm :: (SupportedPrim a, Bounded a, Integral a) => Term a -> Term a -> Term a +remBoundedIntegralTerm :: (SupportedPrim a, Bounded a, Integral a) => Term a -> Term a -> Term a diff --git a/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/Term.hs b/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/Term.hs index 158772f9..f8bc11a4 100644 --- a/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/Term.hs +++ b/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/Term.hs @@ -388,8 +388,14 @@ data Term t where Term (a --> b) -> Term a -> Term b - DivIntegerTerm :: !Id -> Term Integer -> Term Integer -> Term Integer - ModIntegerTerm :: !Id -> Term Integer -> Term Integer -> Term Integer + DivIntegralTerm :: (SupportedPrim t, Integral t) => {-# UNPACK #-} !Id -> !(Term t) -> !(Term t) -> Term t + ModIntegralTerm :: (SupportedPrim t, Integral t) => {-# UNPACK #-} !Id -> !(Term t) -> !(Term t) -> Term t + QuotIntegralTerm :: (SupportedPrim t, Integral t) => {-# UNPACK #-} !Id -> !(Term t) -> !(Term t) -> Term t + RemIntegralTerm :: (SupportedPrim t, Integral t) => {-# UNPACK #-} !Id -> !(Term t) -> !(Term t) -> Term t + DivBoundedIntegralTerm :: (SupportedPrim t, Bounded t, Integral t) => {-# UNPACK #-} !Id -> !(Term t) -> !(Term t) -> Term t + ModBoundedIntegralTerm :: (SupportedPrim t, Bounded t, Integral t) => {-# UNPACK #-} !Id -> !(Term t) -> !(Term t) -> Term t + QuotBoundedIntegralTerm :: (SupportedPrim t, Bounded t, Integral t) => {-# UNPACK #-} !Id -> !(Term t) -> !(Term t) -> Term t + RemBoundedIntegralTerm :: (SupportedPrim t, Bounded t, Integral t) => {-# UNPACK #-} !Id -> !(Term t) -> !(Term t) -> Term t instance NFData (Term a) where rnf i = identity i `seq` () @@ -424,8 +430,14 @@ instance Lift (Term t) where liftTyped (BVExtendTerm _ signed (_ :: TypeRep n) arg) = [||bvextendTerm signed (Proxy @n) arg||] liftTyped (TabularFunApplyTerm _ func arg) = [||tabularFunApplyTerm func arg||] liftTyped (GeneralFunApplyTerm _ func arg) = [||generalFunApplyTerm func arg||] - liftTyped (DivIntegerTerm _ arg1 arg2) = [||divIntegerTerm arg1 arg2||] - liftTyped (ModIntegerTerm _ arg1 arg2) = [||modIntegerTerm arg1 arg2||] + liftTyped (DivIntegralTerm _ arg1 arg2) = [||divIntegralTerm arg1 arg2||] + liftTyped (ModIntegralTerm _ arg1 arg2) = [||modIntegralTerm arg1 arg2||] + liftTyped (QuotIntegralTerm _ arg1 arg2) = [||quotIntegralTerm arg1 arg2||] + liftTyped (RemIntegralTerm _ arg1 arg2) = [||remIntegralTerm arg1 arg2||] + liftTyped (DivBoundedIntegralTerm _ arg1 arg2) = [||divBoundedIntegralTerm arg1 arg2||] + liftTyped (ModBoundedIntegralTerm _ arg1 arg2) = [||modBoundedIntegralTerm arg1 arg2||] + liftTyped (QuotBoundedIntegralTerm _ arg1 arg2) = [||quotBoundedIntegralTerm arg1 arg2||] + liftTyped (RemBoundedIntegralTerm _ arg1 arg2) = [||remBoundedIntegralTerm arg1 arg2||] instance Show (Term ty) where show (ConTerm i v) = "ConTerm{id=" ++ show i ++ ", v=" ++ show v ++ "}" @@ -496,10 +508,22 @@ instance Show (Term ty) where "TabularFunApply{id=" ++ show i ++ ", func=" ++ show func ++ ", arg=" ++ show arg ++ "}" show (GeneralFunApplyTerm i func arg) = "GeneralFunApply{id=" ++ show i ++ ", func=" ++ show func ++ ", arg=" ++ show arg ++ "}" - show (DivIntegerTerm i arg1 arg2) = - "DivInteger{id=" ++ show i ++ ", arg1=" ++ show arg1 ++ ", arg2=" ++ show arg2 ++ "}" - show (ModIntegerTerm i arg1 arg2) = - "ModInteger{id=" ++ show i ++ ", arg1=" ++ show arg1 ++ ", arg2=" ++ show arg2 ++ "}" + show (DivIntegralTerm i arg1 arg2) = + "DivIntegral{id=" ++ show i ++ ", arg1=" ++ show arg1 ++ ", arg2=" ++ show arg2 ++ "}" + show (ModIntegralTerm i arg1 arg2) = + "ModIntegral{id=" ++ show i ++ ", arg1=" ++ show arg1 ++ ", arg2=" ++ show arg2 ++ "}" + show (QuotIntegralTerm i arg1 arg2) = + "QuotIntegral{id=" ++ show i ++ ", arg1=" ++ show arg1 ++ ", arg2=" ++ show arg2 ++ "}" + show (RemIntegralTerm i arg1 arg2) = + "RemIntegral{id=" ++ show i ++ ", arg1=" ++ show arg1 ++ ", arg2=" ++ show arg2 ++ "}" + show (DivBoundedIntegralTerm i arg1 arg2) = + "DivBoundedIntegral{id=" ++ show i ++ ", arg1=" ++ show arg1 ++ ", arg2=" ++ show arg2 ++ "}" + show (ModBoundedIntegralTerm i arg1 arg2) = + "ModBoundedIntegral{id=" ++ show i ++ ", arg1=" ++ show arg1 ++ ", arg2=" ++ show arg2 ++ "}" + show (QuotBoundedIntegralTerm i arg1 arg2) = + "QuotBoundedIntegral{id=" ++ show i ++ ", arg1=" ++ show arg1 ++ ", arg2=" ++ show arg2 ++ "}" + show (RemBoundedIntegralTerm i arg1 arg2) = + "RemBoundedIntegral{id=" ++ show i ++ ", arg1=" ++ show arg1 ++ ", arg2=" ++ show arg2 ++ "}" instance (SupportedPrim t) => Eq (Term t) where (==) = (==) `on` identity @@ -597,8 +621,14 @@ data UTerm t where Term (a --> b) -> Term a -> UTerm b - UDivIntegerTerm :: Term Integer -> Term Integer -> UTerm Integer - UModIntegerTerm :: Term Integer -> Term Integer -> UTerm Integer + UDivIntegralTerm :: (SupportedPrim t, Integral t) => !(Term t) -> !(Term t) -> UTerm t + UModIntegralTerm :: (SupportedPrim t, Integral t) => !(Term t) -> !(Term t) -> UTerm t + UQuotIntegralTerm :: (SupportedPrim t, Integral t) => !(Term t) -> !(Term t) -> UTerm t + URemIntegralTerm :: (SupportedPrim t, Integral t) => !(Term t) -> !(Term t) -> UTerm t + UDivBoundedIntegralTerm :: (SupportedPrim t, Bounded t, Integral t) => !(Term t) -> !(Term t) -> UTerm t + UModBoundedIntegralTerm :: (SupportedPrim t, Bounded t, Integral t) => !(Term t) -> !(Term t) -> UTerm t + UQuotBoundedIntegralTerm :: (SupportedPrim t, Bounded t, Integral t) => !(Term t) -> !(Term t) -> UTerm t + URemBoundedIntegralTerm :: (SupportedPrim t, Bounded t, Integral t) => !(Term t) -> !(Term t) -> UTerm t eqTypedId :: (TypeRep a, Id) -> (TypeRep b, Id) -> Bool eqTypedId (a, i1) (b, i2) = i1 == i2 && eqTypeRepBool a b @@ -669,8 +699,14 @@ instance (SupportedPrim t) => Interned (Term t) where {-# UNPACK #-} !(TypeRep (a --> b), Id) -> {-# UNPACK #-} !(TypeRep a, Id) -> Description (Term b) - DDivIntegerTerm :: {-# UNPACK #-} !Id -> {-# UNPACK #-} !Id -> Description (Term Integer) - DModIntegerTerm :: {-# UNPACK #-} !Id -> {-# UNPACK #-} !Id -> Description (Term Integer) + DDivIntegralTerm :: {-# UNPACK #-} !Id -> {-# UNPACK #-} !Id -> Description (Term a) + DModIntegralTerm :: {-# UNPACK #-} !Id -> {-# UNPACK #-} !Id -> Description (Term a) + DQuotIntegralTerm :: {-# UNPACK #-} !Id -> {-# UNPACK #-} !Id -> Description (Term a) + DRemIntegralTerm :: {-# UNPACK #-} !Id -> {-# UNPACK #-} !Id -> Description (Term a) + DDivBoundedIntegralTerm :: {-# UNPACK #-} !Id -> {-# UNPACK #-} !Id -> Description (Term a) + DModBoundedIntegralTerm :: {-# UNPACK #-} !Id -> {-# UNPACK #-} !Id -> Description (Term a) + DQuotBoundedIntegralTerm :: {-# UNPACK #-} !Id -> {-# UNPACK #-} !Id -> Description (Term a) + DRemBoundedIntegralTerm :: {-# UNPACK #-} !Id -> {-# UNPACK #-} !Id -> Description (Term a) describe (UConTerm v) = DConTerm v describe ((USymTerm name) :: UTerm t) = DSymTerm @t name @@ -712,8 +748,15 @@ instance (SupportedPrim t) => Interned (Term t) where DTabularFunApplyTerm (typeRep :: TypeRep f, identity func) (typeRep :: TypeRep a, identity arg) describe (UGeneralFunApplyTerm (func :: Term f) (arg :: Term a)) = DGeneralFunApplyTerm (typeRep :: TypeRep f, identity func) (typeRep :: TypeRep a, identity arg) - describe (UDivIntegerTerm arg1 arg2) = DDivIntegerTerm (identity arg1) (identity arg2) - describe (UModIntegerTerm arg1 arg2) = DModIntegerTerm (identity arg1) (identity arg2) + describe (UDivIntegralTerm arg1 arg2) = DDivIntegralTerm (identity arg1) (identity arg2) + describe (UModIntegralTerm arg1 arg2) = DModIntegralTerm (identity arg1) (identity arg2) + describe (UQuotIntegralTerm arg1 arg2) = DRemIntegralTerm (identity arg1) (identity arg2) + describe (URemIntegralTerm arg1 arg2) = DQuotIntegralTerm (identity arg1) (identity arg2) + describe (UDivBoundedIntegralTerm arg1 arg2) = DDivBoundedIntegralTerm (identity arg1) (identity arg2) + describe (UModBoundedIntegralTerm arg1 arg2) = DModBoundedIntegralTerm (identity arg1) (identity arg2) + describe (UQuotBoundedIntegralTerm arg1 arg2) = DRemBoundedIntegralTerm (identity arg1) (identity arg2) + describe (URemBoundedIntegralTerm arg1 arg2) = DQuotBoundedIntegralTerm (identity arg1) (identity arg2) + identify i = go where go (UConTerm v) = ConTerm i v @@ -744,8 +787,14 @@ instance (SupportedPrim t) => Interned (Term t) where go (UBVExtendTerm signed n arg) = BVExtendTerm i signed n arg go (UTabularFunApplyTerm func arg) = TabularFunApplyTerm i func arg go (UGeneralFunApplyTerm func arg) = GeneralFunApplyTerm i func arg - go (UDivIntegerTerm arg1 arg2) = DivIntegerTerm i arg1 arg2 - go (UModIntegerTerm arg1 arg2) = ModIntegerTerm i arg1 arg2 + go (UDivIntegralTerm arg1 arg2) = DivIntegralTerm i arg1 arg2 + go (UModIntegralTerm arg1 arg2) = ModIntegralTerm i arg1 arg2 + go (UQuotIntegralTerm arg1 arg2) = QuotIntegralTerm i arg1 arg2 + go (URemIntegralTerm arg1 arg2) = RemIntegralTerm i arg1 arg2 + go (UDivBoundedIntegralTerm arg1 arg2) = DivBoundedIntegralTerm i arg1 arg2 + go (UModBoundedIntegralTerm arg1 arg2) = ModBoundedIntegralTerm i arg1 arg2 + go (UQuotBoundedIntegralTerm arg1 arg2) = QuotBoundedIntegralTerm i arg1 arg2 + go (URemBoundedIntegralTerm arg1 arg2) = RemBoundedIntegralTerm i arg1 arg2 cache = termCache instance (SupportedPrim t) => Eq (Description (Term t)) where @@ -784,8 +833,14 @@ instance (SupportedPrim t) => Eq (Description (Term t)) where && eqTypedId li ri DTabularFunApplyTerm lf li == DTabularFunApplyTerm rf ri = eqTypedId lf rf && eqTypedId li ri DGeneralFunApplyTerm lf li == DGeneralFunApplyTerm rf ri = eqTypedId lf rf && eqTypedId li ri - DDivIntegerTerm li1 li2 == DDivIntegerTerm ri1 ri2 = li1 == ri1 && li2 == ri2 - DModIntegerTerm li1 li2 == DModIntegerTerm ri1 ri2 = li1 == ri1 && li2 == ri2 + DDivIntegralTerm li1 li2 == DDivIntegralTerm ri1 ri2 = li1 == ri1 && li2 == ri2 + DModIntegralTerm li1 li2 == DModIntegralTerm ri1 ri2 = li1 == ri1 && li2 == ri2 + DQuotIntegralTerm li1 li2 == DQuotIntegralTerm ri1 ri2 = li1 == ri1 && li2 == ri2 + DRemIntegralTerm li1 li2 == DRemIntegralTerm ri1 ri2 = li1 == ri1 && li2 == ri2 + DDivBoundedIntegralTerm li1 li2 == DDivBoundedIntegralTerm ri1 ri2 = li1 == ri1 && li2 == ri2 + DModBoundedIntegralTerm li1 li2 == DModBoundedIntegralTerm ri1 ri2 = li1 == ri1 && li2 == ri2 + DQuotBoundedIntegralTerm li1 li2 == DQuotBoundedIntegralTerm ri1 ri2 = li1 == ri1 && li2 == ri2 + DRemBoundedIntegralTerm li1 li2 == DRemBoundedIntegralTerm ri1 ri2 = li1 == ri1 && li2 == ri2 _ == _ = False instance (SupportedPrim t) => Hashable (Description (Term t)) where @@ -837,8 +892,14 @@ instance (SupportedPrim t) => Hashable (Description (Term t)) where `hashWithSalt` id1 hashWithSalt s (DTabularFunApplyTerm id1 id2) = s `hashWithSalt` (26 :: Int) `hashWithSalt` id1 `hashWithSalt` id2 hashWithSalt s (DGeneralFunApplyTerm id1 id2) = s `hashWithSalt` (27 :: Int) `hashWithSalt` id1 `hashWithSalt` id2 - hashWithSalt s (DDivIntegerTerm id1 id2) = s `hashWithSalt` (28 :: Int) `hashWithSalt` id1 `hashWithSalt` id2 - hashWithSalt s (DModIntegerTerm id1 id2) = s `hashWithSalt` (29 :: Int) `hashWithSalt` id1 `hashWithSalt` id2 + hashWithSalt s (DDivIntegralTerm id1 id2) = s `hashWithSalt` (28 :: Int) `hashWithSalt` id1 `hashWithSalt` id2 + hashWithSalt s (DModIntegralTerm id1 id2) = s `hashWithSalt` (29 :: Int) `hashWithSalt` id1 `hashWithSalt` id2 + hashWithSalt s (DQuotIntegralTerm id1 id2) = s `hashWithSalt` (30 :: Int) `hashWithSalt` id1 `hashWithSalt` id2 + hashWithSalt s (DRemIntegralTerm id1 id2) = s `hashWithSalt` (31 :: Int) `hashWithSalt` id1 `hashWithSalt` id2 + hashWithSalt s (DDivBoundedIntegralTerm id1 id2) = s `hashWithSalt` (32 :: Int) `hashWithSalt` id1 `hashWithSalt` id2 + hashWithSalt s (DModBoundedIntegralTerm id1 id2) = s `hashWithSalt` (33 :: Int) `hashWithSalt` id1 `hashWithSalt` id2 + hashWithSalt s (DQuotBoundedIntegralTerm id1 id2) = s `hashWithSalt` (34 :: Int) `hashWithSalt` id1 `hashWithSalt` id2 + hashWithSalt s (DRemBoundedIntegralTerm id1 id2) = s `hashWithSalt` (35 :: Int) `hashWithSalt` id1 `hashWithSalt` id2 -- Basic Bool defaultValueForBool :: Bool diff --git a/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/Term.hs-boot b/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/Term.hs-boot index 392f994f..9ea24401 100644 --- a/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/Term.hs-boot +++ b/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/Term.hs-boot @@ -235,8 +235,14 @@ data Term t where Term (a --> b) -> Term a -> Term b - DivIntegerTerm :: !Id -> Term Integer -> Term Integer -> Term Integer - ModIntegerTerm :: !Id -> Term Integer -> Term Integer -> Term Integer + DivIntegralTerm :: (SupportedPrim t, Integral t) => {-# UNPACK #-} !Id -> !(Term t) -> !(Term t) -> Term t + ModIntegralTerm :: (SupportedPrim t, Integral t) => {-# UNPACK #-} !Id -> !(Term t) -> !(Term t) -> Term t + QuotIntegralTerm :: (SupportedPrim t, Integral t) => {-# UNPACK #-} !Id -> !(Term t) -> !(Term t) -> Term t + RemIntegralTerm :: (SupportedPrim t, Integral t) => {-# UNPACK #-} !Id -> !(Term t) -> !(Term t) -> Term t + DivBoundedIntegralTerm :: (SupportedPrim t, Bounded t, Integral t) => {-# UNPACK #-} !Id -> !(Term t) -> !(Term t) -> Term t + ModBoundedIntegralTerm :: (SupportedPrim t, Bounded t, Integral t) => {-# UNPACK #-} !Id -> !(Term t) -> !(Term t) -> Term t + QuotBoundedIntegralTerm :: (SupportedPrim t, Bounded t, Integral t) => {-# UNPACK #-} !Id -> !(Term t) -> !(Term t) -> Term t + RemBoundedIntegralTerm :: (SupportedPrim t, Bounded t, Integral t) => {-# UNPACK #-} !Id -> !(Term t) -> !(Term t) -> Term t data UTerm t where UConTerm :: (SupportedPrim t) => !t -> UTerm t @@ -328,8 +334,14 @@ data UTerm t where Term (a --> b) -> Term a -> UTerm b - UDivIntegerTerm :: Term Integer -> Term Integer -> UTerm Integer - UModIntegerTerm :: Term Integer -> Term Integer -> UTerm Integer + UDivIntegralTerm :: (SupportedPrim t, Integral t) => !(Term t) -> !(Term t) -> UTerm t + UModIntegralTerm :: (SupportedPrim t, Integral t) => !(Term t) -> !(Term t) -> UTerm t + UQuotIntegralTerm :: (SupportedPrim t, Integral t) => !(Term t) -> !(Term t) -> UTerm t + URemIntegralTerm :: (SupportedPrim t, Integral t) => !(Term t) -> !(Term t) -> UTerm t + UDivBoundedIntegralTerm :: (SupportedPrim t, Bounded t, Integral t) => !(Term t) -> !(Term t) -> UTerm t + UModBoundedIntegralTerm :: (SupportedPrim t, Bounded t, Integral t) => !(Term t) -> !(Term t) -> UTerm t + UQuotBoundedIntegralTerm :: (SupportedPrim t, Bounded t, Integral t) => !(Term t) -> !(Term t) -> UTerm t + URemBoundedIntegralTerm :: (SupportedPrim t, Bounded t, Integral t) => !(Term t) -> !(Term t) -> UTerm t data (-->) a b where GeneralFun :: (SupportedPrim a, SupportedPrim b) => TypedSymbol a -> Term b -> a --> b diff --git a/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/TermSubstitution.hs b/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/TermSubstitution.hs index 9e872aa5..8c0fca46 100644 --- a/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/TermSubstitution.hs +++ b/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/TermSubstitution.hs @@ -22,7 +22,7 @@ import Grisette.IR.SymPrim.Data.Prim.PartialEval.BV import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bits import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bool import Grisette.IR.SymPrim.Data.Prim.PartialEval.GeneralFun -import Grisette.IR.SymPrim.Data.Prim.PartialEval.Integer +import Grisette.IR.SymPrim.Data.Prim.PartialEval.Integral import Grisette.IR.SymPrim.Data.Prim.PartialEval.Num import Grisette.IR.SymPrim.Data.Prim.PartialEval.TabularFun import Type.Reflection @@ -74,5 +74,11 @@ substTerm sym term = gov BVExtendTerm _ n signed op -> SomeTerm $ pevalBVExtendTerm n signed (gov op) TabularFunApplyTerm _ f op -> SomeTerm $ pevalTabularFunApplyTerm (gov f) (gov op) GeneralFunApplyTerm _ f op -> SomeTerm $ pevalGeneralFunApplyTerm (gov f) (gov op) - DivIntegerTerm _ op1 op2 -> SomeTerm $ pevalDivIntegerTerm (gov op1) (gov op2) - ModIntegerTerm _ op1 op2 -> SomeTerm $ pevalModIntegerTerm (gov op1) (gov op2) + DivIntegralTerm _ op1 op2 -> SomeTerm $ pevalDivIntegralTerm (gov op1) (gov op2) + ModIntegralTerm _ op1 op2 -> SomeTerm $ pevalModIntegralTerm (gov op1) (gov op2) + QuotIntegralTerm _ op1 op2 -> SomeTerm $ pevalQuotIntegralTerm (gov op1) (gov op2) + RemIntegralTerm _ op1 op2 -> SomeTerm $ pevalRemIntegralTerm (gov op1) (gov op2) + DivBoundedIntegralTerm _ op1 op2 -> SomeTerm $ pevalDivBoundedIntegralTerm (gov op1) (gov op2) + ModBoundedIntegralTerm _ op1 op2 -> SomeTerm $ pevalModBoundedIntegralTerm (gov op1) (gov op2) + QuotBoundedIntegralTerm _ op1 op2 -> SomeTerm $ pevalQuotBoundedIntegralTerm (gov op1) (gov op2) + RemBoundedIntegralTerm _ op1 op2 -> SomeTerm $ pevalRemBoundedIntegralTerm (gov op1) (gov op2) diff --git a/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/TermUtils.hs b/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/TermUtils.hs index 5af668ef..5d878f85 100644 --- a/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/TermUtils.hs +++ b/src/Grisette/IR/SymPrim/Data/Prim/InternedTerm/TermUtils.hs @@ -37,36 +37,7 @@ import Grisette.IR.SymPrim.Data.TabularFun () import qualified Type.Reflection as R identity :: Term t -> Id -identity (ConTerm i _) = i -identity (SymTerm i _) = i -identity (UnaryTerm i _ _) = i -identity (BinaryTerm i _ _ _) = i -identity (TernaryTerm i _ _ _ _) = i -identity (NotTerm i _) = i -identity (OrTerm i _ _) = i -identity (AndTerm i _ _) = i -identity (EqvTerm i _ _) = i -identity (ITETerm i _ _ _) = i -identity (AddNumTerm i _ _) = i -identity (UMinusNumTerm i _) = i -identity (TimesNumTerm i _ _) = i -identity (AbsNumTerm i _) = i -identity (SignumNumTerm i _) = i -identity (LTNumTerm i _ _) = i -identity (LENumTerm i _ _) = i -identity (AndBitsTerm i _ _) = i -identity (OrBitsTerm i _ _) = i -identity (XorBitsTerm i _ _) = i -identity (ComplementBitsTerm i _) = i -identity (ShiftBitsTerm i _ _) = i -identity (RotateBitsTerm i _ _) = i -identity (BVConcatTerm i _ _) = i -identity (BVSelectTerm i _ _ _) = i -identity (BVExtendTerm i _ _ _) = i -identity (TabularFunApplyTerm i _ _) = i -identity (GeneralFunApplyTerm i _ _) = i -identity (DivIntegerTerm i _ _) = i -identity (ModIntegerTerm i _ _) = i +identity = snd . identityWithTypeRep {-# INLINE identity #-} identityWithTypeRep :: forall t. Term t -> (TypeRep, Id) @@ -98,8 +69,14 @@ identityWithTypeRep (BVSelectTerm i _ _ _) = (typeRep (Proxy @t), i) identityWithTypeRep (BVExtendTerm i _ _ _) = (typeRep (Proxy @t), i) identityWithTypeRep (TabularFunApplyTerm i _ _) = (typeRep (Proxy @t), i) identityWithTypeRep (GeneralFunApplyTerm i _ _) = (typeRep (Proxy @t), i) -identityWithTypeRep (DivIntegerTerm i _ _) = (typeRep (Proxy @t), i) -identityWithTypeRep (ModIntegerTerm i _ _) = (typeRep (Proxy @t), i) +identityWithTypeRep (DivIntegralTerm i _ _) = (typeRep (Proxy @t), i) +identityWithTypeRep (ModIntegralTerm i _ _) = (typeRep (Proxy @t), i) +identityWithTypeRep (QuotIntegralTerm i _ _) = (typeRep (Proxy @t), i) +identityWithTypeRep (RemIntegralTerm i _ _) = (typeRep (Proxy @t), i) +identityWithTypeRep (DivBoundedIntegralTerm i _ _) = (typeRep (Proxy @t), i) +identityWithTypeRep (ModBoundedIntegralTerm i _ _) = (typeRep (Proxy @t), i) +identityWithTypeRep (QuotBoundedIntegralTerm i _ _) = (typeRep (Proxy @t), i) +identityWithTypeRep (RemBoundedIntegralTerm i _ _) = (typeRep (Proxy @t), i) {-# INLINE identityWithTypeRep #-} introSupportedPrimConstraint :: forall t a. Term t -> ((SupportedPrim t) => a) -> a @@ -131,8 +108,14 @@ introSupportedPrimConstraint BVSelectTerm {} x = x introSupportedPrimConstraint BVExtendTerm {} x = x introSupportedPrimConstraint TabularFunApplyTerm {} x = x introSupportedPrimConstraint GeneralFunApplyTerm {} x = x -introSupportedPrimConstraint DivIntegerTerm {} x = x -introSupportedPrimConstraint ModIntegerTerm {} x = x +introSupportedPrimConstraint DivIntegralTerm {} x = x +introSupportedPrimConstraint ModIntegralTerm {} x = x +introSupportedPrimConstraint QuotIntegralTerm {} x = x +introSupportedPrimConstraint RemIntegralTerm {} x = x +introSupportedPrimConstraint DivBoundedIntegralTerm {} x = x +introSupportedPrimConstraint ModBoundedIntegralTerm {} x = x +introSupportedPrimConstraint QuotBoundedIntegralTerm {} x = x +introSupportedPrimConstraint RemBoundedIntegralTerm {} x = x {-# INLINE introSupportedPrimConstraint #-} extractSymbolicsSomeTerm :: SomeTerm -> S.HashSet SomeTypedSymbol @@ -177,8 +160,14 @@ extractSymbolicsSomeTerm t1 = evalState (gocached t1) M.empty go (SomeTerm (BVExtendTerm _ _ _ arg)) = goUnary arg go (SomeTerm (TabularFunApplyTerm _ func arg)) = goBinary func arg go (SomeTerm (GeneralFunApplyTerm _ func arg)) = goBinary func arg - go (SomeTerm (DivIntegerTerm _ arg1 arg2)) = goBinary arg1 arg2 - go (SomeTerm (ModIntegerTerm _ arg1 arg2)) = goBinary arg1 arg2 + go (SomeTerm (DivIntegralTerm _ arg1 arg2)) = goBinary arg1 arg2 + go (SomeTerm (ModIntegralTerm _ arg1 arg2)) = goBinary arg1 arg2 + go (SomeTerm (QuotIntegralTerm _ arg1 arg2)) = goBinary arg1 arg2 + go (SomeTerm (RemIntegralTerm _ arg1 arg2)) = goBinary arg1 arg2 + go (SomeTerm (DivBoundedIntegralTerm _ arg1 arg2)) = goBinary arg1 arg2 + go (SomeTerm (ModBoundedIntegralTerm _ arg1 arg2)) = goBinary arg1 arg2 + go (SomeTerm (QuotBoundedIntegralTerm _ arg1 arg2)) = goBinary arg1 arg2 + go (SomeTerm (RemBoundedIntegralTerm _ arg1 arg2)) = goBinary arg1 arg2 goUnary arg = gocached (SomeTerm arg) goBinary arg1 arg2 = do r1 <- gocached (SomeTerm arg1) @@ -224,8 +213,14 @@ castTerm t@BVSelectTerm {} = cast t castTerm t@BVExtendTerm {} = cast t castTerm t@TabularFunApplyTerm {} = cast t castTerm t@GeneralFunApplyTerm {} = cast t -castTerm t@DivIntegerTerm {} = cast t -castTerm t@ModIntegerTerm {} = cast t +castTerm t@DivIntegralTerm {} = cast t +castTerm t@ModIntegralTerm {} = cast t +castTerm t@QuotIntegralTerm {} = cast t +castTerm t@RemIntegralTerm {} = cast t +castTerm t@DivBoundedIntegralTerm {} = cast t +castTerm t@ModBoundedIntegralTerm {} = cast t +castTerm t@QuotBoundedIntegralTerm {} = cast t +castTerm t@RemBoundedIntegralTerm {} = cast t {-# INLINE castTerm #-} pformat :: forall t. (SupportedPrim t) => Term t -> String @@ -258,8 +253,14 @@ pformat (BVExtendTerm _ signed n arg) = (if signed then "(bvsext " else "(bvzext ") ++ show n ++ " " ++ pformat arg ++ ")" pformat (TabularFunApplyTerm _ func arg) = "(apply " ++ pformat func ++ " " ++ pformat arg ++ ")" pformat (GeneralFunApplyTerm _ func arg) = "(apply " ++ pformat func ++ " " ++ pformat arg ++ ")" -pformat (DivIntegerTerm _ arg1 arg2) = "(div " ++ pformat arg1 ++ " " ++ pformat arg2 ++ ")" -pformat (ModIntegerTerm _ arg1 arg2) = "(mod " ++ pformat arg1 ++ " " ++ pformat arg2 ++ ")" +pformat (DivIntegralTerm _ arg1 arg2) = "(div " ++ pformat arg1 ++ " " ++ pformat arg2 ++ ")" +pformat (ModIntegralTerm _ arg1 arg2) = "(mod " ++ pformat arg1 ++ " " ++ pformat arg2 ++ ")" +pformat (QuotIntegralTerm _ arg1 arg2) = "(quot " ++ pformat arg1 ++ " " ++ pformat arg2 ++ ")" +pformat (RemIntegralTerm _ arg1 arg2) = "(rem " ++ pformat arg1 ++ " " ++ pformat arg2 ++ ")" +pformat (DivBoundedIntegralTerm _ arg1 arg2) = "(div " ++ pformat arg1 ++ " " ++ pformat arg2 ++ ")" +pformat (ModBoundedIntegralTerm _ arg1 arg2) = "(mod " ++ pformat arg1 ++ " " ++ pformat arg2 ++ ")" +pformat (QuotBoundedIntegralTerm _ arg1 arg2) = "(quot " ++ pformat arg1 ++ " " ++ pformat arg2 ++ ")" +pformat (RemBoundedIntegralTerm _ arg1 arg2) = "(rem " ++ pformat arg1 ++ " " ++ pformat arg2 ++ ")" {-# INLINE pformat #-} someTermsSize :: [SomeTerm] -> Int @@ -298,8 +299,14 @@ someTermsSize terms = S.size $ execState (traverse goSome terms) S.empty go t@(BVExtendTerm _ _ _ arg) = goUnary t arg go t@(TabularFunApplyTerm _ func arg) = goBinary t func arg go t@(GeneralFunApplyTerm _ func arg) = goBinary t func arg - go t@(DivIntegerTerm _ arg1 arg2) = goBinary t arg1 arg2 - go t@(ModIntegerTerm _ arg1 arg2) = goBinary t arg1 arg2 + go t@(DivIntegralTerm _ arg1 arg2) = goBinary t arg1 arg2 + go t@(ModIntegralTerm _ arg1 arg2) = goBinary t arg1 arg2 + go t@(QuotIntegralTerm _ arg1 arg2) = goBinary t arg1 arg2 + go t@(RemIntegralTerm _ arg1 arg2) = goBinary t arg1 arg2 + go t@(DivBoundedIntegralTerm _ arg1 arg2) = goBinary t arg1 arg2 + go t@(ModBoundedIntegralTerm _ arg1 arg2) = goBinary t arg1 arg2 + go t@(QuotBoundedIntegralTerm _ arg1 arg2) = goBinary t arg1 arg2 + go t@(RemBoundedIntegralTerm _ arg1 arg2) = goBinary t arg1 arg2 goUnary :: forall a b. (SupportedPrim a) => Term a -> Term b -> State (S.HashSet SomeTerm) () goUnary t arg = do b <- exists t diff --git a/src/Grisette/IR/SymPrim/Data/Prim/Model.hs b/src/Grisette/IR/SymPrim/Data/Prim/Model.hs index 62287a76..0a119973 100644 --- a/src/Grisette/IR/SymPrim/Data/Prim/Model.hs +++ b/src/Grisette/IR/SymPrim/Data/Prim/Model.hs @@ -44,7 +44,7 @@ import Grisette.IR.SymPrim.Data.Prim.PartialEval.BV import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bits import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bool import Grisette.IR.SymPrim.Data.Prim.PartialEval.GeneralFun -import Grisette.IR.SymPrim.Data.Prim.PartialEval.Integer +import Grisette.IR.SymPrim.Data.Prim.PartialEval.Integral import Grisette.IR.SymPrim.Data.Prim.PartialEval.Num import Grisette.IR.SymPrim.Data.Prim.PartialEval.TabularFun import Type.Reflection @@ -342,10 +342,22 @@ evaluateSomeTerm fillDefault m@(Model ma) = gomemo goBinary pevalTabularFunApplyTerm f arg go (SomeTerm (GeneralFunApplyTerm _ f arg)) = goBinary pevalGeneralFunApplyTerm f arg - go (SomeTerm (DivIntegerTerm _ arg1 arg2)) = - goBinary pevalDivIntegerTerm arg1 arg2 - go (SomeTerm (ModIntegerTerm _ arg1 arg2)) = - goBinary pevalModIntegerTerm arg1 arg2 + go (SomeTerm (DivIntegralTerm _ arg1 arg2)) = + goBinary pevalDivIntegralTerm arg1 arg2 + go (SomeTerm (ModIntegralTerm _ arg1 arg2)) = + goBinary pevalModIntegralTerm arg1 arg2 + go (SomeTerm (QuotIntegralTerm _ arg1 arg2)) = + goBinary pevalQuotIntegralTerm arg1 arg2 + go (SomeTerm (RemIntegralTerm _ arg1 arg2)) = + goBinary pevalRemIntegralTerm arg1 arg2 + go (SomeTerm (DivBoundedIntegralTerm _ arg1 arg2)) = + goBinary pevalDivBoundedIntegralTerm arg1 arg2 + go (SomeTerm (ModBoundedIntegralTerm _ arg1 arg2)) = + goBinary pevalModBoundedIntegralTerm arg1 arg2 + go (SomeTerm (QuotBoundedIntegralTerm _ arg1 arg2)) = + goBinary pevalQuotBoundedIntegralTerm arg1 arg2 + go (SomeTerm (RemBoundedIntegralTerm _ arg1 arg2)) = + goBinary pevalRemBoundedIntegralTerm arg1 arg2 goUnary :: (SupportedPrim a, SupportedPrim b) => (Term a -> Term b) -> Term a -> SomeTerm goUnary f a = SomeTerm $ f (gotyped a) goBinary :: diff --git a/src/Grisette/IR/SymPrim/Data/Prim/PartialEval/Integer.hs b/src/Grisette/IR/SymPrim/Data/Prim/PartialEval/Integer.hs deleted file mode 100644 index 61a514ad..00000000 --- a/src/Grisette/IR/SymPrim/Data/Prim/PartialEval/Integer.hs +++ /dev/null @@ -1,36 +0,0 @@ --- | --- Module : Grisette.IR.SymPrim.Data.Prim.PartialEval.Integer --- Copyright : (c) Sirui Lu 2021-2023 --- License : BSD-3-Clause (see the LICENSE file) --- --- Maintainer : siruilu@cs.washington.edu --- Stability : Experimental --- Portability : GHC only -module Grisette.IR.SymPrim.Data.Prim.PartialEval.Integer - ( pevalDivIntegerTerm, - pevalModIntegerTerm, - ) -where - -import Grisette.IR.SymPrim.Data.Prim.InternedTerm.InternedCtors -import Grisette.IR.SymPrim.Data.Prim.InternedTerm.Term -import Grisette.IR.SymPrim.Data.Prim.PartialEval.Unfold - --- div -pevalDivIntegerTerm :: Term Integer -> Term Integer -> Term Integer -pevalDivIntegerTerm = binaryUnfoldOnce doPevalDivIntegerTerm divIntegerTerm - -doPevalDivIntegerTerm :: Term Integer -> Term Integer -> Maybe (Term Integer) -doPevalDivIntegerTerm (ConTerm _ a) (ConTerm _ b) | b /= 0 = Just $ conTerm $ a `div` b -doPevalDivIntegerTerm a (ConTerm _ 1) = Just a -doPevalDivIntegerTerm _ _ = Nothing - --- mod -pevalModIntegerTerm :: Term Integer -> Term Integer -> Term Integer -pevalModIntegerTerm = binaryUnfoldOnce doPevalModIntegerTerm modIntegerTerm - -doPevalModIntegerTerm :: Term Integer -> Term Integer -> Maybe (Term Integer) -doPevalModIntegerTerm (ConTerm _ a) (ConTerm _ b) | b /= 0 = Just $ conTerm $ a `mod` b -doPevalModIntegerTerm _ (ConTerm _ 1) = Just $ conTerm 0 -doPevalModIntegerTerm _ (ConTerm _ (-1)) = Just $ conTerm 0 -doPevalModIntegerTerm _ _ = Nothing diff --git a/src/Grisette/IR/SymPrim/Data/Prim/PartialEval/Integral.hs b/src/Grisette/IR/SymPrim/Data/Prim/PartialEval/Integral.hs new file mode 100644 index 00000000..ea391963 --- /dev/null +++ b/src/Grisette/IR/SymPrim/Data/Prim/PartialEval/Integral.hs @@ -0,0 +1,88 @@ +{-# LANGUAGE ScopedTypeVariables #-} + +-- | +-- Module : Grisette.IR.SymPrim.Data.Prim.PartialEval.Integral +-- Copyright : (c) Sirui Lu 2021-2023 +-- License : BSD-3-Clause (see the LICENSE file) +-- +-- Maintainer : siruilu@cs.washington.edu +-- Stability : Experimental +-- Portability : GHC only +module Grisette.IR.SymPrim.Data.Prim.PartialEval.Integral + ( pevalDivIntegralTerm, + pevalModIntegralTerm, + pevalQuotIntegralTerm, + pevalRemIntegralTerm, + pevalDivBoundedIntegralTerm, + pevalModBoundedIntegralTerm, + pevalQuotBoundedIntegralTerm, + pevalRemBoundedIntegralTerm, + ) +where + +import Grisette.Core.Data.BV +import Grisette.IR.SymPrim.Data.Prim.InternedTerm.InternedCtors +import Grisette.IR.SymPrim.Data.Prim.InternedTerm.Term +import Grisette.IR.SymPrim.Data.Prim.InternedTerm.TermUtils +import Grisette.IR.SymPrim.Data.Prim.PartialEval.Unfold +import Grisette.IR.SymPrim.Data.Prim.Utils + +-- div +pevalDivIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Term a +pevalDivIntegralTerm = binaryUnfoldOnce doPevalDivIntegralTerm divIntegralTerm + +doPevalDivIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Maybe (Term a) +doPevalDivIntegralTerm (ConTerm _ a) (ConTerm _ b) | b /= 0 = Just $ conTerm $ a `div` b +doPevalDivIntegralTerm a (ConTerm _ 1) = Just a +doPevalDivIntegralTerm _ _ = Nothing + +pevalDivBoundedIntegralTerm :: (SupportedPrim a, Bounded a, Integral a) => Term a -> Term a -> Term a +pevalDivBoundedIntegralTerm = binaryUnfoldOnce doPevalDivBoundedIntegralTerm divBoundedIntegralTerm + +doPevalDivBoundedIntegralTerm :: (SupportedPrim a, Bounded a, Integral a) => Term a -> Term a -> Maybe (Term a) +doPevalDivBoundedIntegralTerm (ConTerm _ a) (ConTerm _ b) | b /= 0 && (b /= -1 || a /= minBound) = Just $ conTerm $ a `div` b +doPevalDivBoundedIntegralTerm a (ConTerm _ 1) = Just a +doPevalDivBoundedIntegralTerm _ _ = Nothing + +-- mod +pevalModIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Term a +pevalModIntegralTerm = binaryUnfoldOnce doPevalModIntegralTerm modIntegralTerm + +doPevalModIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Maybe (Term a) +doPevalModIntegralTerm (ConTerm _ a) (ConTerm _ b) | b /= 0 = Just $ conTerm $ a `mod` b +doPevalModIntegralTerm _ (ConTerm _ 1) = Just $ conTerm 0 +doPevalModIntegralTerm _ (ConTerm _ (-1)) = Just $ conTerm 0 +doPevalModIntegralTerm _ _ = Nothing + +pevalModBoundedIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Term a +pevalModBoundedIntegralTerm = pevalModIntegralTerm + +-- quot +pevalQuotIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Term a +pevalQuotIntegralTerm = binaryUnfoldOnce doPevalQuotIntegralTerm quotIntegralTerm + +doPevalQuotIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Maybe (Term a) +doPevalQuotIntegralTerm (ConTerm _ a) (ConTerm _ b) | b /= 0 = Just $ conTerm $ a `quot` b +doPevalQuotIntegralTerm a (ConTerm _ 1) = Just a +doPevalQuotIntegralTerm _ _ = Nothing + +pevalQuotBoundedIntegralTerm :: (SupportedPrim a, Bounded a, Integral a) => Term a -> Term a -> Term a +pevalQuotBoundedIntegralTerm = binaryUnfoldOnce doPevalQuotBoundedIntegralTerm quotBoundedIntegralTerm + +doPevalQuotBoundedIntegralTerm :: (SupportedPrim a, Bounded a, Integral a) => Term a -> Term a -> Maybe (Term a) +doPevalQuotBoundedIntegralTerm (ConTerm _ a) (ConTerm _ b) | b /= 0 && (b /= -1 || a /= minBound) = Just $ conTerm $ a `quot` b +doPevalQuotBoundedIntegralTerm a (ConTerm _ 1) = Just a +doPevalQuotBoundedIntegralTerm _ _ = Nothing + +-- rem +pevalRemIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Term a +pevalRemIntegralTerm = binaryUnfoldOnce doPevalRemIntegralTerm remIntegralTerm + +doPevalRemIntegralTerm :: (SupportedPrim a, Integral a) => Term a -> Term a -> Maybe (Term a) +doPevalRemIntegralTerm (ConTerm _ a) (ConTerm _ b) | b /= 0 = Just $ conTerm $ a `rem` b +doPevalRemIntegralTerm _ (ConTerm _ 1) = Just $ conTerm 0 +doPevalRemIntegralTerm _ (ConTerm _ (-1)) = Just $ conTerm 0 +doPevalRemIntegralTerm _ _ = Nothing + +pevalRemBoundedIntegralTerm :: (SupportedPrim a, Bounded a, Integral a) => Term a -> Term a -> Term a +pevalRemBoundedIntegralTerm = pevalRemIntegralTerm diff --git a/src/Grisette/IR/SymPrim/Data/SymPrim.hs b/src/Grisette/IR/SymPrim/Data/SymPrim.hs index daa47f5f..4fba5d59 100644 --- a/src/Grisette/IR/SymPrim/Data/SymPrim.hs +++ b/src/Grisette/IR/SymPrim/Data/SymPrim.hs @@ -83,10 +83,10 @@ import Grisette.Core.Data.Class.Evaluate import Grisette.Core.Data.Class.ExtractSymbolics import Grisette.Core.Data.Class.Function import Grisette.Core.Data.Class.GenSym -import Grisette.Core.Data.Class.Integer import Grisette.Core.Data.Class.Mergeable import Grisette.Core.Data.Class.ModelOps import Grisette.Core.Data.Class.SOrd +import Grisette.Core.Data.Class.SafeArith import Grisette.Core.Data.Class.SimpleMergeable import Grisette.Core.Data.Class.Solvable import Grisette.Core.Data.Class.Substitute @@ -103,7 +103,7 @@ import Grisette.IR.SymPrim.Data.Prim.PartialEval.BV import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bits import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bool import Grisette.IR.SymPrim.Data.Prim.PartialEval.GeneralFun -import Grisette.IR.SymPrim.Data.Prim.PartialEval.Integer +import Grisette.IR.SymPrim.Data.Prim.PartialEval.Integral import Grisette.IR.SymPrim.Data.Prim.PartialEval.Num import Grisette.IR.SymPrim.Data.Prim.PartialEval.TabularFun import Grisette.IR.SymPrim.Data.TabularFun @@ -140,17 +140,39 @@ newtype SymBool = SymBool {underlyingBoolTerm :: Term Bool} newtype SymInteger = SymInteger {underlyingIntegerTerm :: Term Integer} deriving (Lift, NFData, Generic) -instance SafeDivision SymInteger where - safeDiv e (SymInteger l) rs@(SymInteger r) = - mrgIf - (rs ==~ con 0) - (throwError e) - (mrgReturn $ SymInteger $ pevalDivIntegerTerm l r) - safeMod e (SymInteger l) rs@(SymInteger r) = - mrgIf - (rs ==~ con 0) - (throwError e) - (mrgReturn $ SymInteger $ pevalModIntegerTerm l r) +#define SAFE_DIVISION_FUNC(name, type, op) \ +name (type l) rs@(type r) = \ + mrgIf \ + (rs ==~ con 0) \ + (throwError DivideByZero) \ + (mrgReturn $ type $ op l r); \ +name' t (type l) rs@(type r) = \ + mrgIf \ + (rs ==~ con 0) \ + (throwError (t DivideByZero)) \ + (mrgReturn $ type $ op l r) + +#define SAFE_DIVISION_FUNC2(name, type, op1, op2) \ +name (type l) rs@(type r) = \ + mrgIf \ + (rs ==~ con 0) \ + (throwError DivideByZero) \ + (mrgReturn (type $ op1 l r, type $ op2 l r)); \ +name' t (type l) rs@(type r) = \ + mrgIf \ + (rs ==~ con 0) \ + (throwError (t DivideByZero)) \ + (mrgReturn (type $ op1 l r, type $ op2 l r)) + +#if 1 +instance SafeDivision ArithException SymInteger where + SAFE_DIVISION_FUNC(safeDiv, SymInteger, pevalDivIntegralTerm) + SAFE_DIVISION_FUNC(safeMod, SymInteger, pevalModIntegralTerm) + SAFE_DIVISION_FUNC(safeQuot, SymInteger, pevalQuotIntegralTerm) + SAFE_DIVISION_FUNC(safeRem, SymInteger, pevalRemIntegralTerm) + SAFE_DIVISION_FUNC2(safeDivMod, SymInteger, pevalDivIntegralTerm, pevalModIntegralTerm) + SAFE_DIVISION_FUNC2(safeQuotRem, SymInteger, pevalQuotIntegralTerm, pevalRemIntegralTerm) +#endif instance SafeLinearArith SymInteger where safeAdd e ls rs = mrgReturn $ ls + rs @@ -178,6 +200,48 @@ instance SymIntegerOp SymInteger newtype SymIntN (n :: Nat) = SymIntN {underlyingIntNTerm :: Term (IntN n)} deriving (Lift, NFData, Generic) +#define SAFE_DIVISION_FUNC_BOUNDED_SIGNED(name, type, op) \ +name ls@(type l) rs@(type r) = \ + mrgIf \ + (rs ==~ con 0) \ + (throwError DivideByZero) \ + (mrgIf (rs ==~ con (-1) &&~ ls ==~ con minBound) \ + (throwError Overflow) \ + (mrgReturn $ type $ op l r)); \ +name' t ls@(type l) rs@(type r) = \ + mrgIf \ + (rs ==~ con 0) \ + (throwError (t DivideByZero)) \ + (mrgIf (rs ==~ con (-1) &&~ ls ==~ con minBound) \ + (throwError (t Overflow)) \ + (mrgReturn $ type $ op l r)) + +#define SAFE_DIVISION_FUNC2_BOUNDED_SIGNED(name, type, op1, op2) \ +name ls@(type l) rs@(type r) = \ + mrgIf \ + (rs ==~ con 0) \ + (throwError DivideByZero) \ + (mrgIf (rs ==~ con (-1) &&~ ls ==~ con minBound) \ + (throwError Overflow) \ + (mrgReturn (type $ op1 l r, type $ op2 l r))); \ +name' t ls@(type l) rs@(type r) = \ + mrgIf \ + (rs ==~ con 0) \ + (throwError (t DivideByZero)) \ + (mrgIf (rs ==~ con (-1) &&~ ls ==~ con minBound) \ + (throwError (t Overflow)) \ + (mrgReturn (type $ op1 l r, type $ op2 l r))) + +#if 1 +instance (KnownNat n, 1 <= n) => SafeDivision ArithException (SymIntN n) where + SAFE_DIVISION_FUNC_BOUNDED_SIGNED(safeDiv, SymIntN, pevalDivBoundedIntegralTerm) + SAFE_DIVISION_FUNC(safeMod, SymIntN, pevalModBoundedIntegralTerm) + SAFE_DIVISION_FUNC_BOUNDED_SIGNED(safeQuot, SymIntN, pevalQuotBoundedIntegralTerm) + SAFE_DIVISION_FUNC(safeRem, SymIntN, pevalRemBoundedIntegralTerm) + SAFE_DIVISION_FUNC2_BOUNDED_SIGNED(safeDivMod, SymIntN, pevalDivBoundedIntegralTerm, pevalModBoundedIntegralTerm) + SAFE_DIVISION_FUNC2_BOUNDED_SIGNED(safeQuotRem, SymIntN, pevalQuotBoundedIntegralTerm, pevalRemBoundedIntegralTerm) +#endif + instance (KnownNat n, 1 <= n) => SafeLinearArith (SymIntN n) where safeAdd e ls rs = mrgIf @@ -265,6 +329,16 @@ binSomeSymIntNR2 op str (SomeSymIntN (l :: SymIntN l)) (SomeSymIntN (r :: SymInt newtype SymWordN (n :: Nat) = SymWordN {underlyingWordNTerm :: Term (WordN n)} deriving (Lift, NFData, Generic) +#if 1 +instance (KnownNat n, 1 <= n) => SafeDivision ArithException (SymWordN n) where + SAFE_DIVISION_FUNC(safeDiv, SymWordN, pevalDivIntegralTerm) + SAFE_DIVISION_FUNC(safeMod, SymWordN, pevalModIntegralTerm) + SAFE_DIVISION_FUNC(safeQuot, SymWordN, pevalQuotIntegralTerm) + SAFE_DIVISION_FUNC(safeRem, SymWordN, pevalRemIntegralTerm) + SAFE_DIVISION_FUNC2(safeDivMod, SymWordN, pevalDivIntegralTerm, pevalModIntegralTerm) + SAFE_DIVISION_FUNC2(safeQuotRem, SymWordN, pevalQuotIntegralTerm, pevalRemIntegralTerm) +#endif + instance (KnownNat n, 1 <= n) => SafeLinearArith (SymWordN n) where safeAdd e ls rs = mrgIf diff --git a/src/Grisette/Internal/IR/SymPrim.hs b/src/Grisette/Internal/IR/SymPrim.hs index 0a31d078..1713d1cf 100644 --- a/src/Grisette/Internal/IR/SymPrim.hs +++ b/src/Grisette/Internal/IR/SymPrim.hs @@ -83,8 +83,10 @@ module Grisette.Internal.IR.SymPrim pevalGeNumTerm, pevalTabularFunApplyTerm, pevalGeneralFunApplyTerm, - pevalDivIntegerTerm, - pevalModIntegerTerm, + pevalDivIntegralTerm, + pevalModIntegralTerm, + pevalQuotIntegralTerm, + pevalRemIntegralTerm, ) where @@ -96,7 +98,7 @@ import Grisette.IR.SymPrim.Data.Prim.InternedTerm.TermUtils import Grisette.IR.SymPrim.Data.Prim.Model import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bool import Grisette.IR.SymPrim.Data.Prim.PartialEval.GeneralFun -import Grisette.IR.SymPrim.Data.Prim.PartialEval.Integer +import Grisette.IR.SymPrim.Data.Prim.PartialEval.Integral import Grisette.IR.SymPrim.Data.Prim.PartialEval.Num import Grisette.IR.SymPrim.Data.Prim.PartialEval.PartialEval import Grisette.IR.SymPrim.Data.Prim.PartialEval.TabularFun diff --git a/src/Grisette/Lib/Data/List.hs b/src/Grisette/Lib/Data/List.hs index 256fca0f..fff67a4d 100644 --- a/src/Grisette/Lib/Data/List.hs +++ b/src/Grisette/Lib/Data/List.hs @@ -22,9 +22,9 @@ import Control.Monad.Except import Grisette.Core.Control.Monad.Union import Grisette.Core.Data.Class.Bool import Grisette.Core.Data.Class.Error -import Grisette.Core.Data.Class.Integer import Grisette.Core.Data.Class.Mergeable import Grisette.Core.Data.Class.SOrd +import Grisette.Core.Data.Class.SafeArith import Grisette.Core.Data.Class.SimpleMergeable import Grisette.IR.SymPrim.Data.SymPrim import Grisette.Lib.Control.Monad diff --git a/test/Grisette/Backend/SBV/Data/SMT/LoweringTests.hs b/test/Grisette/Backend/SBV/Data/SMT/LoweringTests.hs index f6a84284..ae18c1b2 100644 --- a/test/Grisette/Backend/SBV/Data/SMT/LoweringTests.hs +++ b/test/Grisette/Backend/SBV/Data/SMT/LoweringTests.hs @@ -360,12 +360,18 @@ loweringTests = leNumTerm "(<=)" (\x y -> x * 2 - x SBV..<= y * 2 - y), - testCase "DivI" $ do - testBinaryOpLowering @Integer @Integer @Integer unboundedConfig divIntegerTerm "div" SBV.sDiv - testBinaryOpLowering @Integer @Integer @Integer boundedConfig divIntegerTerm "div" SBV.sDiv, - testCase "ModI" $ do - testBinaryOpLowering @Integer @Integer @Integer unboundedConfig modIntegerTerm "mod" SBV.sMod - testBinaryOpLowering @Integer @Integer @Integer boundedConfig modIntegerTerm "mod" SBV.sMod + testCase "Div" $ do + testBinaryOpLowering @Integer @Integer @Integer unboundedConfig divIntegralTerm "div" SBV.sDiv + testBinaryOpLowering @Integer @Integer @Integer boundedConfig divIntegralTerm "div" SBV.sDiv, + testCase "Mod" $ do + testBinaryOpLowering @Integer @Integer @Integer unboundedConfig modIntegralTerm "mod" SBV.sMod + testBinaryOpLowering @Integer @Integer @Integer boundedConfig modIntegralTerm "mod" SBV.sMod, + testCase "Quot" $ do + testBinaryOpLowering @Integer @Integer @Integer unboundedConfig quotIntegralTerm "quot" SBV.sQuot + testBinaryOpLowering @Integer @Integer @Integer boundedConfig quotIntegralTerm "quot" SBV.sQuot, + testCase "Rem" $ do + testBinaryOpLowering @Integer @Integer @Integer unboundedConfig remIntegralTerm "rem" SBV.sRem + testBinaryOpLowering @Integer @Integer @Integer boundedConfig remIntegralTerm "rem" SBV.sRem ], testGroup "IntN Lowering" @@ -554,7 +560,19 @@ loweringTests = testUnaryOpLowering @(IntN 5) unboundedConfig (`rotateBitsTerm` (-4)) "rotate" (`rotate` (-4)) testUnaryOpLowering @(IntN 5) unboundedConfig (`rotateBitsTerm` (-4)) "rotate" (`rotate` 1) testUnaryOpLowering @(IntN 5) unboundedConfig (`rotateBitsTerm` (-5)) "rotate" (`rotate` (-5)) - testUnaryOpLowering @(IntN 5) unboundedConfig (`rotateBitsTerm` (-5)) "rotate" id + testUnaryOpLowering @(IntN 5) unboundedConfig (`rotateBitsTerm` (-5)) "rotate" id, + testCase "Div - bounded" $ do + testBinaryOpLowering @(IntN 5) @(IntN 5) @(IntN 5) unboundedConfig divBoundedIntegralTerm "div" SBV.sDiv + testBinaryOpLowering @(IntN 5) @(IntN 5) @(IntN 5) boundedConfig divBoundedIntegralTerm "div" SBV.sDiv, + testCase "Mod - bounded" $ do + testBinaryOpLowering @(IntN 5) @(IntN 5) @(IntN 5) unboundedConfig modBoundedIntegralTerm "mod" SBV.sMod + testBinaryOpLowering @(IntN 5) @(IntN 5) @(IntN 5) boundedConfig modBoundedIntegralTerm "mod" SBV.sMod, + testCase "Quot - bounded" $ do + testBinaryOpLowering @(IntN 5) @(IntN 5) @(IntN 5) unboundedConfig quotBoundedIntegralTerm "quot" SBV.sQuot + testBinaryOpLowering @(IntN 5) @(IntN 5) @(IntN 5) boundedConfig quotBoundedIntegralTerm "quot" SBV.sQuot, + testCase "Rem - bounded" $ do + testBinaryOpLowering @(IntN 5) @(IntN 5) @(IntN 5) unboundedConfig remBoundedIntegralTerm "rem" SBV.sRem + testBinaryOpLowering @(IntN 5) @(IntN 5) @(IntN 5) boundedConfig remBoundedIntegralTerm "rem" SBV.sRem ], testGroup "WordN" @@ -743,6 +761,18 @@ loweringTests = testUnaryOpLowering @(WordN 5) unboundedConfig (`rotateBitsTerm` (-4)) "rotate" (`rotate` (-4)) testUnaryOpLowering @(WordN 5) unboundedConfig (`rotateBitsTerm` (-4)) "rotate" (`rotate` 1) testUnaryOpLowering @(WordN 5) unboundedConfig (`rotateBitsTerm` (-5)) "rotate" (`rotate` (-5)) - testUnaryOpLowering @(WordN 5) unboundedConfig (`rotateBitsTerm` (-5)) "rotate" id + testUnaryOpLowering @(WordN 5) unboundedConfig (`rotateBitsTerm` (-5)) "rotate" id, + testCase "Div" $ do + testBinaryOpLowering @(WordN 5) @(WordN 5) @(WordN 5) unboundedConfig divIntegralTerm "div" SBV.sDiv + testBinaryOpLowering @(WordN 5) @(WordN 5) @(WordN 5) boundedConfig divIntegralTerm "div" SBV.sDiv, + testCase "Mod" $ do + testBinaryOpLowering @(WordN 5) @(WordN 5) @(WordN 5) unboundedConfig modIntegralTerm "mod" SBV.sMod + testBinaryOpLowering @(WordN 5) @(WordN 5) @(WordN 5) boundedConfig modIntegralTerm "mod" SBV.sMod, + testCase "Quot" $ do + testBinaryOpLowering @(WordN 5) @(WordN 5) @(WordN 5) unboundedConfig quotIntegralTerm "quot" SBV.sQuot + testBinaryOpLowering @(WordN 5) @(WordN 5) @(WordN 5) boundedConfig quotIntegralTerm "quot" SBV.sQuot, + testCase "Rem" $ do + testBinaryOpLowering @(WordN 5) @(WordN 5) @(WordN 5) unboundedConfig remIntegralTerm "rem" SBV.sRem + testBinaryOpLowering @(WordN 5) @(WordN 5) @(WordN 5) boundedConfig remIntegralTerm "rem" SBV.sRem ] ] diff --git a/test/Grisette/Backend/SBV/Data/SMT/TermRewritingGen.hs b/test/Grisette/Backend/SBV/Data/SMT/TermRewritingGen.hs index 1a90f847..f6aaf0f6 100644 --- a/test/Grisette/Backend/SBV/Data/SMT/TermRewritingGen.hs +++ b/test/Grisette/Backend/SBV/Data/SMT/TermRewritingGen.hs @@ -24,7 +24,7 @@ import Grisette.IR.SymPrim.Data.Prim.InternedTerm.TermUtils import Grisette.IR.SymPrim.Data.Prim.PartialEval.BV import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bits import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bool -import Grisette.IR.SymPrim.Data.Prim.PartialEval.Integer +import Grisette.IR.SymPrim.Data.Prim.PartialEval.Integral import Grisette.IR.SymPrim.Data.Prim.PartialEval.Num import Test.Tasty.QuickCheck @@ -230,11 +230,29 @@ bvextendSpec :: b bvextendSpec signed p = constructUnarySpec (bvextendTerm signed p) (pevalBVExtendTerm signed p) -divIntegerSpec :: (TermRewritingSpec a Integer) => a -> a -> a -divIntegerSpec = constructBinarySpec divIntegerTerm pevalDivIntegerTerm +divIntegralSpec :: (TermRewritingSpec a b, Integral b) => a -> a -> a +divIntegralSpec = constructBinarySpec divIntegralTerm pevalDivIntegralTerm -modIntegerSpec :: (TermRewritingSpec a Integer) => a -> a -> a -modIntegerSpec = constructBinarySpec modIntegerTerm pevalModIntegerTerm +modIntegralSpec :: (TermRewritingSpec a b, Integral b) => a -> a -> a +modIntegralSpec = constructBinarySpec modIntegralTerm pevalModIntegralTerm + +quotIntegralSpec :: (TermRewritingSpec a b, Integral b) => a -> a -> a +quotIntegralSpec = constructBinarySpec quotIntegralTerm pevalQuotIntegralTerm + +remIntegralSpec :: (TermRewritingSpec a b, Integral b) => a -> a -> a +remIntegralSpec = constructBinarySpec remIntegralTerm pevalRemIntegralTerm + +divBoundedIntegralSpec :: (TermRewritingSpec a b, Bounded b, Integral b) => a -> a -> a +divBoundedIntegralSpec = constructBinarySpec divBoundedIntegralTerm pevalDivBoundedIntegralTerm + +modBoundedIntegralSpec :: (TermRewritingSpec a b, Bounded b, Integral b) => a -> a -> a +modBoundedIntegralSpec = constructBinarySpec modBoundedIntegralTerm pevalModBoundedIntegralTerm + +quotBoundedIntegralSpec :: (TermRewritingSpec a b, Bounded b, Integral b) => a -> a -> a +quotBoundedIntegralSpec = constructBinarySpec quotBoundedIntegralTerm pevalQuotBoundedIntegralTerm + +remBoundedIntegralSpec :: (TermRewritingSpec a b, Bounded b, Integral b) => a -> a -> a +remBoundedIntegralSpec = constructBinarySpec remBoundedIntegralTerm pevalRemBoundedIntegralTerm data BoolOnlySpec = BoolOnlySpec (Term Bool) (Term Bool) diff --git a/test/Grisette/Backend/SBV/Data/SMT/TermRewritingTests.hs b/test/Grisette/Backend/SBV/Data/SMT/TermRewritingTests.hs index 380cc9ac..1c7b9424 100644 --- a/test/Grisette/Backend/SBV/Data/SMT/TermRewritingTests.hs +++ b/test/Grisette/Backend/SBV/Data/SMT/TermRewritingTests.hs @@ -31,160 +31,174 @@ validateSpec config a = do (Right m, _) -> do assertFailure $ "With model" ++ show m ++ "Bad rewriting: " ++ pformat (norewriteVer a) ++ " was rewritten to " ++ pformat (rewriteVer a) +unboundedConfig = precise SBV.z3 + +divisionTest :: + forall a b. + (TermRewritingSpec a b, Show a, Enum b, Num b, SupportedPrim b) => + TestName -> + (a -> a -> a) -> + TestTree +divisionTest name f = + testGroup + name + [ testCase "on concrete" $ do + traverse_ + ( \(x :: b, y :: b) -> do + validateSpec @a unboundedConfig $ f (conSpec x) (conSpec y) + ) + [(i, j) | i <- [-3 .. 3], j <- [-3 .. 3]], + testCase "on single concrete" $ do + traverse_ + ( \x -> do + validateSpec @a unboundedConfig $ f (conSpec x) (symSpec "a") + validateSpec @a unboundedConfig $ f (symSpec "a") (conSpec x) + ) + [-3 .. 3] + ] + termRewritingTests :: TestTree termRewritingTests = - let unboundedConfig = precise SBV.z3 -- {SBV.verbose=True} - in testGroup - "TermRewritingTests" - [ testGroup - "Bool only" - [ testProperty "Bool only random test" $ - mapSize (`min` 10) $ - ioProperty . \(x :: BoolOnlySpec) -> do - validateSpec unboundedConfig x, - testCase "Regression nested ite with (ite a (ite b c d) e) with b is true" $ do - validateSpec @BoolOnlySpec - unboundedConfig - ( iteSpec - (symSpec "a" :: BoolOnlySpec) - ( iteSpec - (orSpec (notSpec (andSpec (symSpec "b1") (symSpec "b2"))) (symSpec "b2") :: BoolOnlySpec) - (symSpec "c") - (symSpec "d") - ) - (symSpec "e") - ), - testCase "Regression for pevalImpliesTerm _ false should be false" $ do - validateSpec @BoolOnlySpec - unboundedConfig + testGroup + "TermRewritingTests" + [ testGroup + "Bool only" + [ testProperty "Bool only random test" $ + mapSize (`min` 10) $ + ioProperty . \(x :: BoolOnlySpec) -> do + validateSpec unboundedConfig x, + testCase "Regression nested ite with (ite a (ite b c d) e) with b is true" $ do + validateSpec @BoolOnlySpec + unboundedConfig + ( iteSpec + (symSpec "a" :: BoolOnlySpec) ( iteSpec - (symSpec "fbool" :: BoolOnlySpec) - ( notSpec - ( orSpec - (orSpec (notSpec (andSpec (symSpec "gbool" :: BoolOnlySpec) (symSpec "fbool" :: BoolOnlySpec))) (symSpec "gbool" :: BoolOnlySpec)) - (orSpec (symSpec "abool" :: BoolOnlySpec) (notSpec (andSpec (symSpec "gbool" :: BoolOnlySpec) (symSpec "bbool" :: BoolOnlySpec)))) - ) - ) - (symSpec "xxx" :: BoolOnlySpec) + (orSpec (notSpec (andSpec (symSpec "b1") (symSpec "b2"))) (symSpec "b2") :: BoolOnlySpec) + (symSpec "c") + (symSpec "d") ) - ], - testGroup - "LIA" - [ testProperty "LIA random test" $ - mapSize (`min` 10) $ - ioProperty . \(x :: LIAWithBoolSpec) -> do - validateSpec unboundedConfig x, - testCase "Regression nested ite with (ite a b (ite c d e)) with c implies a" $ do - validateSpec @LIAWithBoolSpec - unboundedConfig - ( iteSpec - (notSpec (eqvSpec (symSpec "v" :: LIAWithBoolSpec) (conSpec 1 :: LIAWithBoolSpec) :: BoolWithLIASpec)) - (symSpec "b") - ( iteSpec - (eqvSpec (symSpec "v" :: LIAWithBoolSpec) (conSpec 2 :: LIAWithBoolSpec) :: BoolWithLIASpec) - (symSpec "d") - (symSpec "d") + (symSpec "e") + ), + testCase "Regression for pevalImpliesTerm _ false should be false" $ do + validateSpec @BoolOnlySpec + unboundedConfig + ( iteSpec + (symSpec "fbool" :: BoolOnlySpec) + ( notSpec + ( orSpec + (orSpec (notSpec (andSpec (symSpec "gbool" :: BoolOnlySpec) (symSpec "fbool" :: BoolOnlySpec))) (symSpec "gbool" :: BoolOnlySpec)) + (orSpec (symSpec "abool" :: BoolOnlySpec) (notSpec (andSpec (symSpec "gbool" :: BoolOnlySpec) (symSpec "bbool" :: BoolOnlySpec)))) ) ) - ], - testGroup - "Different sized SignedBV" - [ testProperty "Fixed Sized SignedBV random test" $ - mapSize (`min` 10) $ - ioProperty . \(x :: (DifferentSizeBVSpec IntN 4)) -> do - validateSpec unboundedConfig x - ], - testGroup - "Fixed sized SignedBV" - [ testProperty "Fixed Sized SignedBV random test" $ - mapSize (`min` 10) $ - ioProperty . \(x :: (FixedSizedBVWithBoolSpec IntN)) -> do - validateSpec unboundedConfig x - ], - testGroup - "timesNumSpec on integer" - [ testCase "times on both concrete" $ do - traverse_ - (\(x, y) -> validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (conSpec x) (conSpec y)) - [(i, j) | i <- [-3 .. 3], j <- [-3 .. 3]], - testCase "times on single concrete" $ do - traverse_ - ( \x -> do - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (conSpec x) (symSpec "a") - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (symSpec "a") (conSpec x) - ) - [-3 .. 3], - testCase "Two times with two concrete combined" $ do - traverse_ - ( \(x, y) -> do - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (conSpec x) $ timesNumSpec (conSpec y) (symSpec "a") - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (conSpec x) $ timesNumSpec (symSpec "a") (conSpec y) - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (timesNumSpec (conSpec x) (symSpec "a")) (conSpec y) - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (timesNumSpec (symSpec "a") (conSpec x)) (conSpec y) - ) - [(i, j) | i <- [-3 .. 3], j <- [-3 .. 3]], - testCase "Two times with one concrete" $ do - traverse_ - ( \x -> do - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (conSpec x) $ timesNumSpec (symSpec "b") (symSpec "a") - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (symSpec "b") $ timesNumSpec (symSpec "a") (conSpec x) - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (symSpec "b") $ timesNumSpec (conSpec x) (symSpec "a") - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (timesNumSpec (conSpec x) (symSpec "a")) (symSpec "b") - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (timesNumSpec (symSpec "a") (conSpec x)) (symSpec "b") - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (timesNumSpec (symSpec "a") (symSpec "b")) (conSpec x) - ) - [-3 .. 3], - testCase "times and add with two concretes combined" $ do - traverse_ - ( \(x, y) -> do - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (conSpec x) $ addNumSpec (conSpec y) (symSpec "a") - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (conSpec x) $ addNumSpec (symSpec "a") (conSpec y) - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (addNumSpec (conSpec x) (symSpec "a")) (conSpec y) - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (addNumSpec (symSpec "a") (conSpec x)) (conSpec y) - validateSpec @(GeneralSpec Integer) unboundedConfig $ addNumSpec (conSpec x) $ timesNumSpec (conSpec y) (symSpec "a") - validateSpec @(GeneralSpec Integer) unboundedConfig $ addNumSpec (conSpec x) $ timesNumSpec (symSpec "a") (conSpec y) - validateSpec @(GeneralSpec Integer) unboundedConfig $ addNumSpec (timesNumSpec (conSpec x) (symSpec "a")) (conSpec y) - validateSpec @(GeneralSpec Integer) unboundedConfig $ addNumSpec (timesNumSpec (symSpec "a") (conSpec x)) (conSpec y) - ) - [(i, j) | i <- [-3 .. 3], j <- [-3 .. 3]], - testCase "times concrete with uminusNumSpec symbolic" $ do - traverse_ - ( \x -> do - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (conSpec x) (uminusNumSpec $ symSpec "a") - validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (uminusNumSpec $ symSpec "a") (conSpec x) - ) - [-3 .. 3] - ], - testGroup - "DivI" - [ testCase "DivI on concrete" $ do - traverse_ - ( \(x, y) -> do - validateSpec @(GeneralSpec Integer) unboundedConfig $ divIntegerSpec (conSpec x) (conSpec y) - ) - [(i, j) | i <- [-3 .. 3], j <- [-3 .. 3]], - testCase "DivI on single concrete" $ do - traverse_ - ( \x -> do - validateSpec @(GeneralSpec Integer) unboundedConfig $ divIntegerSpec (conSpec x) (symSpec "a") - validateSpec @(GeneralSpec Integer) unboundedConfig $ divIntegerSpec (symSpec "a") (conSpec x) - ) - [-3 .. 3] - ], - testGroup - "ModI" - [ testCase "ModI on concrete" $ do - traverse_ - ( \(x, y) -> do - validateSpec @(GeneralSpec Integer) unboundedConfig $ modIntegerSpec (conSpec x) (conSpec y) - ) - [(i, j) | i <- [-3 .. 3], j <- [-3 .. 3]], - testCase "ModI on single concrete" $ do - traverse_ - ( \x -> do - validateSpec @(GeneralSpec Integer) unboundedConfig $ modIntegerSpec (conSpec x) (symSpec "a") - validateSpec @(GeneralSpec Integer) unboundedConfig $ modIntegerSpec (symSpec "a") (conSpec x) + (symSpec "xxx" :: BoolOnlySpec) + ) + ], + testGroup + "LIA" + [ testProperty "LIA random test" $ + mapSize (`min` 10) $ + ioProperty . \(x :: LIAWithBoolSpec) -> do + validateSpec unboundedConfig x, + testCase "Regression nested ite with (ite a b (ite c d e)) with c implies a" $ do + validateSpec @LIAWithBoolSpec + unboundedConfig + ( iteSpec + (notSpec (eqvSpec (symSpec "v" :: LIAWithBoolSpec) (conSpec 1 :: LIAWithBoolSpec) :: BoolWithLIASpec)) + (symSpec "b") + ( iteSpec + (eqvSpec (symSpec "v" :: LIAWithBoolSpec) (conSpec 2 :: LIAWithBoolSpec) :: BoolWithLIASpec) + (symSpec "d") + (symSpec "d") ) - [-3 .. 3] - ] + ) + ], + testGroup + "Different sized SignedBV" + [ testProperty "Fixed Sized SignedBV random test" $ + mapSize (`min` 10) $ + ioProperty . \(x :: (DifferentSizeBVSpec IntN 4)) -> do + validateSpec unboundedConfig x + ], + testGroup + "Fixed sized SignedBV" + [ testProperty "Fixed Sized SignedBV random test" $ + mapSize (`min` 10) $ + ioProperty . \(x :: (FixedSizedBVWithBoolSpec IntN)) -> do + validateSpec unboundedConfig x + ], + testGroup + "timesNumSpec on integer" + [ testCase "times on both concrete" $ do + traverse_ + (\(x, y) -> validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (conSpec x) (conSpec y)) + [(i, j) | i <- [-3 .. 3], j <- [-3 .. 3]], + testCase "times on single concrete" $ do + traverse_ + ( \x -> do + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (conSpec x) (symSpec "a") + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (symSpec "a") (conSpec x) + ) + [-3 .. 3], + testCase "Two times with two concrete combined" $ do + traverse_ + ( \(x, y) -> do + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (conSpec x) $ timesNumSpec (conSpec y) (symSpec "a") + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (conSpec x) $ timesNumSpec (symSpec "a") (conSpec y) + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (timesNumSpec (conSpec x) (symSpec "a")) (conSpec y) + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (timesNumSpec (symSpec "a") (conSpec x)) (conSpec y) + ) + [(i, j) | i <- [-3 .. 3], j <- [-3 .. 3]], + testCase "Two times with one concrete" $ do + traverse_ + ( \x -> do + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (conSpec x) $ timesNumSpec (symSpec "b") (symSpec "a") + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (symSpec "b") $ timesNumSpec (symSpec "a") (conSpec x) + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (symSpec "b") $ timesNumSpec (conSpec x) (symSpec "a") + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (timesNumSpec (conSpec x) (symSpec "a")) (symSpec "b") + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (timesNumSpec (symSpec "a") (conSpec x)) (symSpec "b") + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (timesNumSpec (symSpec "a") (symSpec "b")) (conSpec x) + ) + [-3 .. 3], + testCase "times and add with two concretes combined" $ do + traverse_ + ( \(x, y) -> do + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (conSpec x) $ addNumSpec (conSpec y) (symSpec "a") + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (conSpec x) $ addNumSpec (symSpec "a") (conSpec y) + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (addNumSpec (conSpec x) (symSpec "a")) (conSpec y) + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (addNumSpec (symSpec "a") (conSpec x)) (conSpec y) + validateSpec @(GeneralSpec Integer) unboundedConfig $ addNumSpec (conSpec x) $ timesNumSpec (conSpec y) (symSpec "a") + validateSpec @(GeneralSpec Integer) unboundedConfig $ addNumSpec (conSpec x) $ timesNumSpec (symSpec "a") (conSpec y) + validateSpec @(GeneralSpec Integer) unboundedConfig $ addNumSpec (timesNumSpec (conSpec x) (symSpec "a")) (conSpec y) + validateSpec @(GeneralSpec Integer) unboundedConfig $ addNumSpec (timesNumSpec (symSpec "a") (conSpec x)) (conSpec y) + ) + [(i, j) | i <- [-3 .. 3], j <- [-3 .. 3]], + testCase "times concrete with uminusNumSpec symbolic" $ do + traverse_ + ( \x -> do + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (conSpec x) (uminusNumSpec $ symSpec "a") + validateSpec @(GeneralSpec Integer) unboundedConfig $ timesNumSpec (uminusNumSpec $ symSpec "a") (conSpec x) + ) + [-3 .. 3] + ], + testGroup + "divisions on integer" + [ divisionTest @(GeneralSpec Integer) "div" divIntegralSpec, + divisionTest @(GeneralSpec Integer) "mod" modIntegralSpec, + divisionTest @(GeneralSpec Integer) "quot" quotIntegralSpec, + divisionTest @(GeneralSpec Integer) "rem" remIntegralSpec + ], + testGroup + "divisions on signed bv" + [ divisionTest @(GeneralSpec (IntN 4)) "div" divBoundedIntegralSpec, + divisionTest @(GeneralSpec (IntN 4)) "mod" modBoundedIntegralSpec, + divisionTest @(GeneralSpec (IntN 4)) "quot" quotBoundedIntegralSpec, + divisionTest @(GeneralSpec (IntN 4)) "rem" remBoundedIntegralSpec + ], + testGroup + "divisions on unsigned bv" + [ divisionTest @(GeneralSpec (WordN 4)) "div" divIntegralSpec, + divisionTest @(GeneralSpec (WordN 4)) "mod" modIntegralSpec, + divisionTest @(GeneralSpec (WordN 4)) "quot" quotIntegralSpec, + divisionTest @(GeneralSpec (WordN 4)) "rem" remIntegralSpec ] + ] diff --git a/test/Grisette/Core/Data/BVTests.hs b/test/Grisette/Core/Data/BVTests.hs index afadc22d..b3f242ee 100644 --- a/test/Grisette/Core/Data/BVTests.hs +++ b/test/Grisette/Core/Data/BVTests.hs @@ -176,57 +176,69 @@ enumConformTest pref ptyp = else (fromIntegral <$> enumFromThenTo x y z) @=? enumFromThenTo (fromIntegral x :: typ) (fromIntegral y) (fromIntegral z) ] +newtype AEWrapper = AEWrapper ArithException deriving (Eq) + +instance Show AEWrapper where + show (AEWrapper x) = show x + +instance NFData AEWrapper where + rnf (AEWrapper x) = x `seq` () + +sameDiv :: (NFData a, NFData b, Eq b, Show b) => a -> a -> (a -> b) -> (a -> a -> a) -> (b -> b -> b) -> IO () +sameDiv x y a2b fa fb = do + xa <- evaluate (force $ Right $ fa x y) `catch` \(e :: ArithException) -> return $ Left $ AEWrapper e + xb <- evaluate (force $ Right $ fb (a2b x) (a2b y)) `catch` \(e :: ArithException) -> return $ Left $ AEWrapper e + xb @=? a2b <$> xa + +sameDivMod :: (NFData a, NFData b, Eq b, Show b) => a -> a -> (a -> b) -> (a -> a -> (a, a)) -> (b -> b -> (b, b)) -> IO () +sameDivMod x y a2b fa fb = do + xa <- evaluate (force $ Right $ fa x y) `catch` \(e :: ArithException) -> return $ Left $ AEWrapper e + xb <- evaluate (force $ Right $ fb (a2b x) (a2b y)) `catch` \(e :: ArithException) -> return $ Left $ AEWrapper e + xb @=? bimap a2b a2b <$> xa + divLikeTest :: forall a b. - (Arbitrary a, Eq a, Eq b, Num a, Show a, Bounded a, Bits a, Eq b, Show b, Num b, NFData b, Bounded b, Bits b) => + (Arbitrary a, Eq a, Eq b, Num a, Show a, Bounded a, Bits a, Eq b, Show b, Num b, NFData b, Bounded b, Bits b, NFData a) => TestName -> (a -> b) -> (a -> a -> a) -> (b -> b -> b) -> TestTree divLikeTest name a2b fa fb = - if isSigned (0 :: a) - then - testGroup - name - [ testProperty (name ++ " non zero / minBound vs -1") $ \(x :: a) y -> ioProperty $ do - if y == 0 || (x == minBound && y == -1) then return () else a2b (fa x y) @=? fb (a2b x) (a2b y), - testCase (name ++ " zero") $ shouldThrow (name ++ " zero") $ fb 1 0, - testCase (name ++ " minBound vs -1") $ shouldThrow (name ++ " minBound vs -1") $ fb minBound (-1) - ] - else - testGroup - name - [ testProperty (name ++ " non zero") $ \(x :: a) y -> ioProperty $ do - if y == 0 then return () else a2b (fa x y) @=? fb (a2b x) (a2b y), - testCase (name ++ " zero") $ shouldThrow (name ++ " zero") $ fb 1 0 - ] + testGroup + "name" + [ testCase "divided by zero" $ do + sameDiv 1 0 a2b fa fb + sameDiv 0 0 a2b fa fb + sameDiv (-1) 0 a2b fa fb + sameDiv minBound 0 a2b fa fb + sameDiv maxBound 0 a2b fa fb, + testCase "min divided by -1" $ do + sameDiv minBound (-1) a2b fa fb, + testProperty "prop" $ \(x :: a) y -> ioProperty $ sameDiv x y a2b fa fb + ] divModLikeTest :: forall a b. - (Arbitrary a, Eq a, Eq b, Num a, Show a, Bounded a, Bits a, Eq b, Show b, Num b, NFData b, Bounded b, Bits b) => + (Arbitrary a, Eq a, Eq b, Num a, NFData a, Show a, Bounded a, Bits a, Eq b, Show b, Num b, NFData b, Bounded b, Bits b) => TestName -> (a -> b) -> (a -> a -> (a, a)) -> (b -> b -> (b, b)) -> TestTree divModLikeTest name a2b fa fb = - if isSigned (0 :: a) - then - testGroup - name - [ testProperty (name ++ " non zero / minBound vs -1") $ \(x :: a) y -> ioProperty $ do - if y == 0 || (x == minBound && y == -1) then return () else bimap a2b a2b (fa x y) @=? fb (a2b x) (a2b y), - testCase (name ++ " zero") $ shouldThrow (name ++ " zero") $ fb 1 0, - testCase (name ++ " minBound vs -1") $ shouldThrow (name ++ " minBound vs -1") $ fb minBound (-1) - ] - else - testGroup - name - [ testProperty (name ++ " non zero") $ \(x :: a) y -> ioProperty $ do - if y == 0 then return () else bimap a2b a2b (fa x y) @=? fb (a2b x) (a2b y), - testCase (name ++ " zero") $ shouldThrow (name ++ " zero") $ fb 1 0 - ] + testGroup + "name" + [ testCase "divided by zero" $ do + sameDivMod 1 0 a2b fa fb + sameDivMod 0 0 a2b fa fb + sameDivMod (-1) 0 a2b fa fb + sameDivMod minBound 0 a2b fa fb + sameDivMod maxBound 0 a2b fa fb, + testCase "min divided by -1" $ do + sameDivMod minBound (-1) a2b fa fb, + testProperty "prop" $ \(x :: a) y -> ioProperty $ sameDivMod x y a2b fa fb + ] realConformTest :: forall proxy ref typ. @@ -257,7 +269,8 @@ integralConformTest :: Bits typ, Bounded ref, Bounded typ, - NFData typ + NFData typ, + NFData ref ) => Proxy ref -> Proxy typ -> @@ -411,9 +424,7 @@ bvTests = [ testCase "division of min bound and minus one for signed bit vector should throw" $ do shouldThrow "divMod" $ divMod (minBound :: IntN 8) (-1 :: IntN 8) shouldThrow "div" $ div (minBound :: IntN 8) (-1 :: IntN 8) - shouldThrow "mod" $ mod (minBound :: IntN 8) (-1 :: IntN 8) shouldThrow "quotRem" $ quotRem (minBound :: IntN 8) (-1 :: IntN 8) shouldThrow "quot" $ quot (minBound :: IntN 8) (-1 :: IntN 8) - shouldThrow "rem" $ rem (minBound :: IntN 8) (-1 :: IntN 8) ] ] diff --git a/test/Grisette/IR/SymPrim/Data/Prim/IntegerTests.hs b/test/Grisette/IR/SymPrim/Data/Prim/IntegerTests.hs deleted file mode 100644 index 0fc40fb7..00000000 --- a/test/Grisette/IR/SymPrim/Data/Prim/IntegerTests.hs +++ /dev/null @@ -1,48 +0,0 @@ -{-# LANGUAGE ScopedTypeVariables #-} - -module Grisette.IR.SymPrim.Data.Prim.IntegerTests where - -import Grisette.IR.SymPrim.Data.Prim.InternedTerm.InternedCtors -import Grisette.IR.SymPrim.Data.Prim.InternedTerm.Term -import Grisette.IR.SymPrim.Data.Prim.PartialEval.Integer -import Test.Tasty -import Test.Tasty.HUnit -import Test.Tasty.QuickCheck - -integerTests :: TestTree -integerTests = - testGroup - "IntegerTests" - [ testGroup - "DivI" - [ testProperty "On concrete" $ - ioProperty . \(i :: Integer, j :: Integer) -> do - if j /= 0 - then pevalDivIntegerTerm (conTerm i) (conTerm j) @=? conTerm (i `div` j) - else - pevalDivIntegerTerm (conTerm i) (conTerm j) - @=? divIntegerTerm (conTerm i) (conTerm j), - testCase "divide by 1" $ do - pevalDivIntegerTerm (ssymTerm "a" :: Term Integer) (conTerm 1) @=? ssymTerm "a", - testCase "On symbolic" $ do - pevalDivIntegerTerm (ssymTerm "a" :: Term Integer) (ssymTerm "b") - @=? divIntegerTerm (ssymTerm "a" :: Term Integer) (ssymTerm "b" :: Term Integer) - ], - testGroup - "ModI" - [ testProperty "On concrete" $ - ioProperty . \(i :: Integer, j :: Integer) -> do - if j /= 0 - then pevalModIntegerTerm (conTerm i) (conTerm j) @=? conTerm (i `mod` j) - else - pevalModIntegerTerm (conTerm i) (conTerm j) - @=? modIntegerTerm (conTerm i) (conTerm j), - testCase "mod by 1" $ do - pevalModIntegerTerm (ssymTerm "a" :: Term Integer) (conTerm 1) @=? conTerm 0, - testCase "mod by -1" $ do - pevalModIntegerTerm (ssymTerm "a" :: Term Integer) (conTerm $ -1) @=? conTerm 0, - testCase "On symbolic" $ do - pevalModIntegerTerm (ssymTerm "a" :: Term Integer) (ssymTerm "b") - @=? modIntegerTerm (ssymTerm "a" :: Term Integer) (ssymTerm "b" :: Term Integer) - ] - ] diff --git a/test/Grisette/IR/SymPrim/Data/Prim/IntegralTests.hs b/test/Grisette/IR/SymPrim/Data/Prim/IntegralTests.hs new file mode 100644 index 00000000..c0f72af0 --- /dev/null +++ b/test/Grisette/IR/SymPrim/Data/Prim/IntegralTests.hs @@ -0,0 +1,170 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} + +module Grisette.IR.SymPrim.Data.Prim.IntegralTests where + +import Control.DeepSeq +import Control.Exception +import Data.Proxy +import Grisette.Core.Data.BV +import Grisette.IR.SymPrim.Data.Prim.InternedTerm.InternedCtors +import Grisette.IR.SymPrim.Data.Prim.InternedTerm.Term +import Grisette.IR.SymPrim.Data.Prim.PartialEval.Integral +import Test.Tasty +import Test.Tasty.HUnit +import Test.Tasty.QuickCheck + +newtype AEWrapper = AEWrapper ArithException deriving (Eq) + +instance Show AEWrapper where + show (AEWrapper x) = show x + +instance NFData AEWrapper where + rnf (AEWrapper x) = x `seq` () + +sameDivPeval :: + forall t. + (Num t, Eq t, SupportedPrim t, Integral t) => + t -> + t -> + (Term t -> Term t -> Term t) -> + (t -> t -> t) -> + (Term t -> Term t -> Term t) -> + IO () +sameDivPeval i j pf cf consf = do + cx <- evaluate (force $ Right $ cf i j) `catch` \(e :: ArithException) -> return $ Left AEWrapper + case cx of + Left f -> pf (conTerm i) (conTerm j) @=? consf (conTerm i) (conTerm j) + Right t -> pf (conTerm i) (conTerm j) @=? conTerm t + +divisionPevalBoundedTests :: + forall p t0 t. + (Num t, Eq t, Arbitrary t0, Show t0, Bounded t, SupportedPrim t, Integral t) => + p t -> + TestName -> + (t0 -> t) -> + (Term t -> Term t -> Term t) -> + (t -> t -> t) -> + (Term t -> Term t -> Term t) -> + TestTree +divisionPevalBoundedTests _ name transform pf cf consf = + testGroup + name + [ testCase "On concrete min divide by -1" $ + sameDivPeval minBound (-1) pf cf consf + ] + +divisionPevalTests :: + forall p t0 t. + (Num t, Eq t, Arbitrary t0, Show t0, SupportedPrim t, Integral t) => + p t -> + TestName -> + (t0 -> t) -> + (Term t -> Term t -> Term t) -> + (t -> t -> t) -> + (Term t -> Term t -> Term t) -> + TestTree +divisionPevalTests p name transform pf cf consf = + testGroup + name + [ testProperty "On concrete prop" $ + ioProperty . \(i0 :: t0, j0 :: t0) -> do + let i = transform i0 + let j = transform j0 + sameDivPeval i j pf cf consf, + testProperty "On concrete divide by 0" $ + ioProperty . \(i0 :: t0) -> do + let i = transform i0 + sameDivPeval i 0 pf cf consf, + testCase "divide by 1" $ do + pf (ssymTerm "a" :: Term t) (conTerm 1) @=? ssymTerm "a", + testCase "On symbolic" $ do + pf (ssymTerm "a" :: Term t) (ssymTerm "b") + @=? consf (ssymTerm "a" :: Term t) (ssymTerm "b" :: Term t) + ] + +divisionPevalBoundedTestGroup :: + TestName -> + (forall t. (SupportedPrim t, Bounded t, Integral t) => Term t -> Term t -> Term t) -> + (forall t. (Bounded t, Integral t) => t -> t -> t) -> + (forall t. (SupportedPrim t, Bounded t, Integral t) => Term t -> Term t -> Term t) -> + TestTree +divisionPevalBoundedTestGroup name pf cf consf = + testGroup + name + [ divisionPevalTests (Proxy @(IntN 4)) "IntN" IntN pf cf consf, + divisionPevalBoundedTests (Proxy @(IntN 4)) "IntN Bounded" IntN pf cf consf + ] + +divisionPevalUnboundedTestGroup :: + TestName -> + (forall t. (SupportedPrim t, Integral t) => Term t -> Term t -> Term t) -> + (forall t. (Integral t) => t -> t -> t) -> + (forall t. (SupportedPrim t, Integral t) => Term t -> Term t -> Term t) -> + TestTree +divisionPevalUnboundedTestGroup name pf cf consf = + testGroup + name + [ divisionPevalTests (Proxy @Integer) "Integer" id pf cf consf, + divisionPevalTests (Proxy @(WordN 4)) "WordN" WordN pf cf consf, + divisionPevalBoundedTests (Proxy @(WordN 4)) "WordN Bounded" WordN pf cf consf + ] + +moduloPevalTests :: + forall p t0 t. + (Num t, Eq t, Arbitrary t0, Show t0, SupportedPrim t, Integral t) => + p t -> + TestName -> + (t0 -> t) -> + (Term t -> Term t -> Term t) -> + (t -> t -> t) -> + (Term t -> Term t -> Term t) -> + TestTree +moduloPevalTests p name transform pf cf consf = + testGroup + name + [ testProperty "On concrete" $ + ioProperty . \(i0 :: t0, j0 :: t0) -> do + let i = transform i0 + let j = transform j0 + sameDivPeval i j pf cf consf, + testProperty "On concrete divide by 0" $ + ioProperty . \(i0 :: t0) -> do + let i = transform i0 + sameDivPeval i 0 pf cf consf, + testCase "mod by 1" $ do + pf (ssymTerm "a" :: Term t) (conTerm 1) @=? conTerm 0, + testCase "mod by -1" $ do + pf (ssymTerm "a" :: Term t) (conTerm $ -1) @=? conTerm 0, + testCase "On symbolic" $ do + pf (ssymTerm "a" :: Term t) (ssymTerm "b") + @=? consf (ssymTerm "a" :: Term t) (ssymTerm "b" :: Term t) + ] + +moduloPevalTestGroup :: + TestName -> + (forall t. (SupportedPrim t, Integral t) => Term t -> Term t -> Term t) -> + (forall t. (Integral t) => t -> t -> t) -> + (forall t. (SupportedPrim t, Integral t) => Term t -> Term t -> Term t) -> + TestTree +moduloPevalTestGroup name pf cf consf = + testGroup + name + [ moduloPevalTests (Proxy @Integer) "Integer" id pf cf consf, + moduloPevalTests (Proxy @(IntN 4)) "IntN" IntN pf cf consf, + moduloPevalTests (Proxy @(WordN 4)) "WordN" WordN pf cf consf + ] + +integralTests :: TestTree +integralTests = + testGroup + "IntegralTests" + [ divisionPevalUnboundedTestGroup "Div unbounded" pevalDivIntegralTerm div divIntegralTerm, + divisionPevalUnboundedTestGroup "Quot unbounded" pevalQuotIntegralTerm quot quotIntegralTerm, + divisionPevalBoundedTestGroup "Div bounded" pevalDivBoundedIntegralTerm div divBoundedIntegralTerm, + divisionPevalBoundedTestGroup "Quot bounded" pevalQuotBoundedIntegralTerm quot quotBoundedIntegralTerm, + moduloPevalTestGroup "Mod" pevalModIntegralTerm mod modIntegralTerm, + moduloPevalTestGroup "Rem" pevalRemIntegralTerm rem remIntegralTerm + ] diff --git a/test/Grisette/IR/SymPrim/Data/SymPrimTests.hs b/test/Grisette/IR/SymPrim/Data/SymPrimTests.hs index e690fde4..54883796 100644 --- a/test/Grisette/IR/SymPrim/Data/SymPrimTests.hs +++ b/test/Grisette/IR/SymPrim/Data/SymPrimTests.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE NegativeLiterals #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -8,6 +9,8 @@ module Grisette.IR.SymPrim.Data.SymPrimTests where +import Control.DeepSeq +import Control.Exception import Control.Monad.Except import Data.Bits import qualified Data.HashMap.Strict as M @@ -23,10 +26,10 @@ import Grisette.Core.Data.Class.Evaluate import Grisette.Core.Data.Class.ExtractSymbolics import Grisette.Core.Data.Class.Function import Grisette.Core.Data.Class.GenSym -import Grisette.Core.Data.Class.Integer import Grisette.Core.Data.Class.Mergeable import Grisette.Core.Data.Class.ModelOps import Grisette.Core.Data.Class.SOrd +import Grisette.Core.Data.Class.SafeArith import Grisette.Core.Data.Class.SimpleMergeable import Grisette.Core.Data.Class.Solvable import Grisette.Core.Data.Class.ToCon @@ -38,7 +41,7 @@ import Grisette.IR.SymPrim.Data.Prim.ModelValue import Grisette.IR.SymPrim.Data.Prim.PartialEval.BV import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bits import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bool -import Grisette.IR.SymPrim.Data.Prim.PartialEval.Integer +import Grisette.IR.SymPrim.Data.Prim.PartialEval.Integral import Grisette.IR.SymPrim.Data.Prim.PartialEval.Num import Grisette.IR.SymPrim.Data.Prim.PartialEval.TabularFun import Grisette.IR.SymPrim.Data.SymPrim @@ -48,6 +51,396 @@ import Test.Tasty.HUnit import Test.Tasty.QuickCheck hiding ((.&.)) import Type.Reflection hiding (Con) +newtype AEWrapper = AEWrapper ArithException deriving (Eq) + +instance Show AEWrapper where + show (AEWrapper x) = show x + +instance NFData AEWrapper where + rnf (AEWrapper x) = x `seq` () + +sameSafeDiv :: + forall c s. + ( Show s, + Eq s, + Eq c, + Num c, + Mergeable s, + NFData c, + Solvable c s + ) => + c -> + c -> + (s -> s -> ExceptT ArithException UnionM s) -> + (c -> c -> c) -> + Assertion +sameSafeDiv i j f cf = do + xc <- evaluate (force $ Right $ cf i j) `catch` \(e :: ArithException) -> return $ Left $ AEWrapper e + case xc of + Left (AEWrapper e) -> f (con i :: s) (con j) @=? merge (throwError e) + Right c -> f (con i :: s) (con j) @=? mrgSingle (con c) + +sameSafeDiv' :: + forall c s. + ( Show s, + Eq s, + Eq c, + Num c, + Mergeable s, + NFData c, + Solvable c s + ) => + c -> + c -> + ((ArithException -> ()) -> s -> s -> ExceptT () UnionM s) -> + (c -> c -> c) -> + Assertion +sameSafeDiv' i j f cf = do + xc <- evaluate (force $ Right $ cf i j) `catch` \(e :: ArithException) -> return $ Left () + case xc of + Left () -> f (const ()) (con i :: s) (con j) @=? merge (throwError ()) + Right c -> f (const ()) (con i :: s) (con j) @=? mrgSingle (con c) + +sameSafeDivMod :: + forall c s. + ( Show s, + Eq s, + Eq c, + Num c, + Mergeable s, + NFData c, + Solvable c s + ) => + c -> + c -> + (s -> s -> ExceptT ArithException UnionM (s, s)) -> + (c -> c -> (c, c)) -> + Assertion +sameSafeDivMod i j f cf = do + xc <- evaluate (force $ Right $ cf i j) `catch` \(e :: ArithException) -> return $ Left $ AEWrapper e + case xc of + Left (AEWrapper e) -> f (con i :: s) (con j) @=? merge (throwError e) + Right (c1, c2) -> f (con i :: s) (con j) @=? mrgSingle (con c1, con c2) + +sameSafeDivMod' :: + forall c s. + ( Show s, + Eq s, + Eq c, + Num c, + Mergeable s, + NFData c, + Solvable c s + ) => + c -> + c -> + ((ArithException -> ()) -> s -> s -> ExceptT () UnionM (s, s)) -> + (c -> c -> (c, c)) -> + Assertion +sameSafeDivMod' i j f cf = do + xc <- evaluate (force $ Right $ cf i j) `catch` \(e :: ArithException) -> return $ Left () + case xc of + Left () -> f (const ()) (con i :: s) (con j) @=? merge (throwError ()) + Right (c1, c2) -> f (const ()) (con i :: s) (con j) @=? mrgSingle (con c1, con c2) + +safeDivisionBoundedOnlyTests :: + forall c s. + (LinkedRep c s, Bounded c, Solvable c s, Eq s, Num c, Show s, Mergeable s, SEq s) => + (s -> s -> ExceptT ArithException UnionM s) -> + ((ArithException -> ()) -> s -> s -> ExceptT () UnionM s) -> + (c -> c -> c) -> + (Term c -> Term c -> Term c) -> + [TestTree] +safeDivisionBoundedOnlyTests f f' cf pf = + [ testCase "on concrete min divided by minus one" $ do + sameSafeDiv minBound (-1) f cf + sameSafeDiv' minBound (-1) f' cf, + testCase "on symbolic" $ do + f (ssym "a" :: s) (ssym "b") + @=? ( mrgIf + ((ssym "b" :: s) ==~ con (0 :: c) :: SymBool) + (throwError DivideByZero) + ( mrgIf + ((ssym "b" :: s) ==~ con (-1) &&~ (ssym "a" :: s) ==~ con (minBound :: c) :: SymBool) + (throwError Overflow) + (mrgSingle $ wrapTerm $ pf (ssymTerm "a") (ssymTerm "b")) + ) :: + ExceptT ArithException UnionM s + ) + f' (const ()) (ssym "a" :: s) (ssym "b") + @=? ( mrgIf + ((ssym "b" :: s) ==~ con (0 :: c) :: SymBool) + (throwError ()) + ( mrgIf + ((ssym "b" :: s) ==~ con (-1) &&~ (ssym "a" :: s) ==~ con (minBound :: c) :: SymBool) + (throwError ()) + (mrgSingle $ wrapTerm $ pf (ssymTerm "a") (ssymTerm "b")) + ) :: + ExceptT () UnionM s + ) + ] + +safeDivisionUnboundedOnlyTests :: + forall c s. + (LinkedRep c s, Solvable c s, Eq s, Num c, Show s, Mergeable s, SEq s) => + (s -> s -> ExceptT ArithException UnionM s) -> + ((ArithException -> ()) -> s -> s -> ExceptT () UnionM s) -> + (c -> c -> c) -> + (Term c -> Term c -> Term c) -> + [TestTree] +safeDivisionUnboundedOnlyTests f f' cf pf = + [ testCase "on symbolic" $ do + f (ssym "a" :: s) (ssym "b") + @=? ( mrgIf + ((ssym "b" :: s) ==~ con (0 :: c) :: SymBool) + (throwError DivideByZero) + (mrgSingle $ wrapTerm $ pf (ssymTerm "a") (ssymTerm "b")) :: + ExceptT ArithException UnionM s + ) + f' (const ()) (ssym "a" :: s) (ssym "b") + @=? ( mrgIf + ((ssym "b" :: s) ==~ con (0 :: c) :: SymBool) + (throwError ()) + (mrgSingle $ wrapTerm $ pf (ssymTerm "a") (ssymTerm "b")) :: + ExceptT () UnionM s + ) + ] + +safeDivisionGeneralTests :: + forall c c0 s. + (LinkedRep c s, Arbitrary c0, Show c0, Solvable c s, Eq s, Num c, Show s, Mergeable s, SEq s) => + (c0 -> c) -> + (s -> s -> ExceptT ArithException UnionM s) -> + ((ArithException -> ()) -> s -> s -> ExceptT () UnionM s) -> + (c -> c -> c) -> + (Term c -> Term c -> Term c) -> + [TestTree] +safeDivisionGeneralTests transform f f' cf pf = + [ testProperty "on concrete prop" $ \(i0 :: c0, j0 :: c0) -> + ioProperty $ do + let i = transform i0 + let j = transform j0 + sameSafeDiv i j f cf + sameSafeDiv' i j f' cf, + testProperty "on concrete divided by zero" $ \(i0 :: c0) -> + ioProperty $ do + let i = transform i0 + sameSafeDiv i 0 f cf + sameSafeDiv' i 0 f' cf, + testCase "when divided by zero" $ do + f (ssym "a" :: s) (con 0) + @=? (merge $ throwError DivideByZero :: ExceptT ArithException UnionM s) + f' (const ()) (ssym "a" :: s) (con 0) + @=? (merge $ throwError () :: ExceptT () UnionM s) + ] + +safeDivisionBoundedTests :: + forall c c0 s. + (LinkedRep c s, Arbitrary c0, Show c0, Bounded c, Solvable c s, Eq s, Num c, Show s, Mergeable s, SEq s) => + TestName -> + (c0 -> c) -> + (s -> s -> ExceptT ArithException UnionM s) -> + ((ArithException -> ()) -> s -> s -> ExceptT () UnionM s) -> + (c -> c -> c) -> + (Term c -> Term c -> Term c) -> + TestTree +safeDivisionBoundedTests name transform f f' cf pf = + testGroup name $ + safeDivisionGeneralTests transform f f' cf pf + ++ safeDivisionBoundedOnlyTests f f' cf pf + +safeDivisionUnboundedTests :: + forall c c0 s. + (LinkedRep c s, Arbitrary c0, Show c0, Solvable c s, Eq s, Num c, Show s, Mergeable s, SEq s) => + TestName -> + (c0 -> c) -> + (s -> s -> ExceptT ArithException UnionM s) -> + ((ArithException -> ()) -> s -> s -> ExceptT () UnionM s) -> + (c -> c -> c) -> + (Term c -> Term c -> Term c) -> + TestTree +safeDivisionUnboundedTests name transform f f' cf pf = + testGroup name $ + safeDivisionGeneralTests transform f f' cf pf + ++ safeDivisionUnboundedOnlyTests f f' cf pf + +safeDivModBoundedOnlyTests :: + forall c s. + (LinkedRep c s, Bounded c, Solvable c s, Eq s, Num c, Show s, Mergeable s, SEq s) => + ( s -> + s -> + ExceptT ArithException UnionM (s, s) + ) -> + ( (ArithException -> ()) -> + s -> + s -> + ExceptT () UnionM (s, s) + ) -> + (c -> c -> (c, c)) -> + (Term c -> Term c -> Term c) -> + (Term c -> Term c -> Term c) -> + [TestTree] +safeDivModBoundedOnlyTests f f' cf pf1 pf2 = + [ testCase "on concrete min divided by minus one" $ do + sameSafeDivMod minBound (-1) f cf + sameSafeDivMod' minBound (-1) f' cf, + testCase "on symbolic" $ do + f (ssym "a" :: s) (ssym "b") + @=? ( mrgIf + ((ssym "b" :: s) ==~ con (0 :: c) :: SymBool) + (throwError DivideByZero) + ( mrgIf + ((ssym "b" :: s) ==~ con (-1) &&~ (ssym "a" :: s) ==~ con (minBound :: c) :: SymBool) + (throwError Overflow) + ( mrgSingle + ( wrapTerm $ pf1 (ssymTerm "a") (ssymTerm "b"), + wrapTerm $ pf2 (ssymTerm "a") (ssymTerm "b") + ) + ) + ) :: + ExceptT ArithException UnionM (s, s) + ) + f' (const ()) (ssym "a" :: s) (ssym "b") + @=? ( mrgIf + ((ssym "b" :: s) ==~ con (0 :: c) :: SymBool) + (throwError ()) + ( mrgIf + ((ssym "b" :: s) ==~ con (-1) &&~ (ssym "a" :: s) ==~ con (minBound :: c) :: SymBool) + (throwError ()) + ( mrgSingle + ( wrapTerm $ pf1 (ssymTerm "a") (ssymTerm "b"), + wrapTerm $ pf2 (ssymTerm "a") (ssymTerm "b") + ) + ) + ) :: + ExceptT () UnionM (s, s) + ) + ] + +safeDivModUnboundedOnlyTests :: + forall c s. + (LinkedRep c s, Solvable c s, Eq s, Num c, Show s, Mergeable s, SEq s) => + ( s -> + s -> + ExceptT ArithException UnionM (s, s) + ) -> + ( (ArithException -> ()) -> + s -> + s -> + ExceptT () UnionM (s, s) + ) -> + (c -> c -> (c, c)) -> + (Term c -> Term c -> Term c) -> + (Term c -> Term c -> Term c) -> + [TestTree] +safeDivModUnboundedOnlyTests f f' cf pf1 pf2 = + [ testCase "on symbolic" $ do + f (ssym "a" :: s) (ssym "b") + @=? ( mrgIf + ((ssym "b" :: s) ==~ con (0 :: c) :: SymBool) + (throwError DivideByZero) + ( mrgSingle + ( wrapTerm $ pf1 (ssymTerm "a") (ssymTerm "b"), + wrapTerm $ pf2 (ssymTerm "a") (ssymTerm "b") + ) + ) :: + ExceptT ArithException UnionM (s, s) + ) + f' (const ()) (ssym "a" :: s) (ssym "b") + @=? ( mrgIf + ((ssym "b" :: s) ==~ con (0 :: c) :: SymBool) + (throwError ()) + ( mrgSingle + ( wrapTerm $ pf1 (ssymTerm "a") (ssymTerm "b"), + wrapTerm $ pf2 (ssymTerm "a") (ssymTerm "b") + ) + ) :: + ExceptT () UnionM (s, s) + ) + ] + +safeDivModGeneralTests :: + forall c c0 s. + (LinkedRep c s, Arbitrary c0, Show c0, Solvable c s, Eq s, Num c, Show s, Mergeable s, SEq s) => + (c0 -> c) -> + ( s -> + s -> + ExceptT ArithException UnionM (s, s) + ) -> + ( (ArithException -> ()) -> + s -> + s -> + ExceptT () UnionM (s, s) + ) -> + (c -> c -> (c, c)) -> + (Term c -> Term c -> Term c) -> + (Term c -> Term c -> Term c) -> + [TestTree] +safeDivModGeneralTests transform f f' cf pf1 pf2 = + [ testProperty "on concrete" $ \(i0 :: c0, j0 :: c0) -> + ioProperty $ do + let i = transform i0 + let j = transform j0 + sameSafeDivMod i j f cf + sameSafeDivMod' i j f' cf, + testProperty "on concrete divided by zero" $ \(i0 :: c0) -> + ioProperty $ do + let i = transform i0 + sameSafeDivMod i 0 f cf + sameSafeDivMod' i 0 f' cf, + testCase "when divided by zero" $ do + f (ssym "a" :: s) (con 0) + @=? (merge $ throwError DivideByZero :: ExceptT ArithException UnionM (s, s)) + f' (const ()) (ssym "a" :: s) (con 0) + @=? (merge $ throwError () :: ExceptT () UnionM (s, s)) + ] + +safeDivModBoundedTests :: + forall c c0 s. + (LinkedRep c s, Arbitrary c0, Show c0, Bounded c, Solvable c s, Eq s, Num c, Show s, Mergeable s, SEq s) => + TestName -> + (c0 -> c) -> + ( s -> + s -> + ExceptT ArithException UnionM (s, s) + ) -> + ( (ArithException -> ()) -> + s -> + s -> + ExceptT () UnionM (s, s) + ) -> + (c -> c -> (c, c)) -> + (Term c -> Term c -> Term c) -> + (Term c -> Term c -> Term c) -> + TestTree +safeDivModBoundedTests name transform f f' cf pf1 pf2 = + testGroup name $ + safeDivModGeneralTests transform f f' cf pf1 pf2 + ++ safeDivModBoundedOnlyTests f f' cf pf1 pf2 + +safeDivModUnboundedTests :: + forall c c0 s. + (LinkedRep c s, Arbitrary c0, Show c0, Solvable c s, Eq s, Num c, Show s, Mergeable s, SEq s) => + TestName -> + (c0 -> c) -> + ( s -> + s -> + ExceptT ArithException UnionM (s, s) + ) -> + ( (ArithException -> ()) -> + s -> + s -> + ExceptT () UnionM (s, s) + ) -> + (c -> c -> (c, c)) -> + (Term c -> Term c -> Term c) -> + (Term c -> Term c -> Term c) -> + TestTree +safeDivModUnboundedTests name transform f f' cf pf1 pf2 = + testGroup name $ + safeDivModGeneralTests transform f f' cf pf1 pf2 + ++ safeDivModUnboundedOnlyTests f f' cf pf1 pf2 + symPrimTests :: TestTree symPrimTests = testGroup @@ -56,12 +449,9 @@ symPrimTests = "General SymPrim" [ testGroup "Solvable" - [ testCase "con" $ do - (con 1 :: SymInteger) @=? SymInteger (conTerm 1), - testCase "ssym" $ do - (ssym "a" :: SymInteger) @=? SymInteger (ssymTerm "a"), - testCase "isym" $ do - (isym "a" 1 :: SymInteger) @=? SymInteger (isymTerm "a" 1), + [ testCase "con" $ (con 1 :: SymInteger) @=? SymInteger (conTerm 1), + testCase "ssym" $ (ssym "a" :: SymInteger) @=? SymInteger (ssymTerm "a"), + testCase "isym" $ (isym "a" 1 :: SymInteger) @=? SymInteger (isymTerm "a" 1), testCase "conView" $ do conView (con 1 :: SymInteger) @=? Just 1 conView (ssym "a" :: SymInteger) @=? Nothing @@ -82,24 +472,19 @@ symPrimTests = let SimpleStrategy s = rootStrategy :: MergingStrategy SymInteger s (ssym "a") (ssym "b") (ssym "c") @=? ites (ssym "a" :: SymBool) (ssym "b" :: SymInteger) (ssym "c"), - testCase "SimpleMergeable" $ do + testCase "SimpleMergeable" $ mrgIte (ssym "a" :: SymBool) (ssym "b") (ssym "c") @=? ites (ssym "a" :: SymBool) (ssym "b" :: SymInteger) (ssym "c"), - testCase "IsString" $ do - ("a" :: SymBool) @=? SymBool (ssymTerm "a"), + testCase "IsString" $ ("a" :: SymBool) @=? SymBool (ssymTerm "a"), testGroup "ToSym" - [ testCase "From self" $ do - toSym (ssym "a" :: SymBool) @=? (ssym "a" :: SymBool), - testCase "From concrete" $ do - toSym True @=? (con True :: SymBool) + [ testCase "From self" $ toSym (ssym "a" :: SymBool) @=? (ssym "a" :: SymBool), + testCase "From concrete" $ toSym True @=? (con True :: SymBool) ], testGroup "ToCon" - [ testCase "To self" $ do - toCon (ssym "a" :: SymBool) @=? (Nothing :: Maybe Bool), - testCase "To concrete" $ do - toCon True @=? Just True + [ testCase "To self" $ toCon (ssym "a" :: SymBool) @=? (Nothing :: Maybe Bool), + testCase "To concrete" $ toCon True @=? Just True ], testCase "EvaluateSym" $ do let m1 = emptyModel :: Model @@ -108,7 +493,7 @@ symPrimTests = evaluateSym False m3 (ites ("c" :: SymBool) "a" ("a" + "a" :: SymInteger)) @=? ites ("c" :: SymBool) 1 2 evaluateSym True m3 (ites ("c" :: SymBool) "a" ("a" + "a" :: SymInteger)) @=? 2, - testCase "ExtractSymbolics" $ do + testCase "ExtractSymbolics" $ extractSymbolics (ites ("c" :: SymBool) ("a" :: SymInteger) ("b" :: SymInteger)) @=? SymbolSet ( S.fromList @@ -132,121 +517,33 @@ symPrimTests = "SymBool" [ testGroup "LogicalOp" - [ testCase "||~" $ do - ssym "a" ||~ ssym "b" @=? SymBool (pevalOrTerm (ssymTerm "a") (ssymTerm "b")), - testCase "&&~" $ do - ssym "a" &&~ ssym "b" @=? SymBool (pevalAndTerm (ssymTerm "a") (ssymTerm "b")), - testCase "nots" $ do - nots (ssym "a") @=? SymBool (pevalNotTerm (ssymTerm "a")), - testCase "xors" $ do - xors (ssym "a") (ssym "b") @=? SymBool (pevalXorTerm (ssymTerm "a") (ssymTerm "b")), - testCase "implies" $ do - implies (ssym "a") (ssym "b") @=? SymBool (pevalImplyTerm (ssymTerm "a") (ssymTerm "b")) + [ testCase "||~" $ ssym "a" ||~ ssym "b" @=? SymBool (pevalOrTerm (ssymTerm "a") (ssymTerm "b")), + testCase "&&~" $ ssym "a" &&~ ssym "b" @=? SymBool (pevalAndTerm (ssymTerm "a") (ssymTerm "b")), + testCase "nots" $ nots (ssym "a") @=? SymBool (pevalNotTerm (ssymTerm "a")), + testCase "xors" $ xors (ssym "a") (ssym "b") @=? SymBool (pevalXorTerm (ssymTerm "a") (ssymTerm "b")), + testCase "implies" $ implies (ssym "a") (ssym "b") @=? SymBool (pevalImplyTerm (ssymTerm "a") (ssymTerm "b")) ] ], testGroup "SymInteger" [ testGroup "Num" - [ testCase "fromInteger" $ do - (1 :: SymInteger) @=? SymInteger (conTerm 1), - testCase "(+)" $ do - (ssym "a" :: SymInteger) + ssym "b" @=? SymInteger (pevalAddNumTerm (ssymTerm "a") (ssymTerm "b")), - testCase "(-)" $ do - (ssym "a" :: SymInteger) - ssym "b" @=? SymInteger (pevalMinusNumTerm (ssymTerm "a") (ssymTerm "b")), - testCase "(*)" $ do - (ssym "a" :: SymInteger) * ssym "b" @=? SymInteger (pevalTimesNumTerm (ssymTerm "a") (ssymTerm "b")), - testCase "negate" $ do - negate (ssym "a" :: SymInteger) @=? SymInteger (pevalUMinusNumTerm (ssymTerm "a")), - testCase "abs" $ do - abs (ssym "a" :: SymInteger) @=? SymInteger (pevalAbsNumTerm (ssymTerm "a")), - testCase "signum" $ do - signum (ssym "a" :: SymInteger) @=? SymInteger (pevalSignumNumTerm (ssymTerm "a")) + [ testCase "fromInteger" $ (1 :: SymInteger) @=? SymInteger (conTerm 1), + testCase "(+)" $ (ssym "a" :: SymInteger) + ssym "b" @=? SymInteger (pevalAddNumTerm (ssymTerm "a") (ssymTerm "b")), + testCase "(-)" $ (ssym "a" :: SymInteger) - ssym "b" @=? SymInteger (pevalMinusNumTerm (ssymTerm "a") (ssymTerm "b")), + testCase "(*)" $ (ssym "a" :: SymInteger) * ssym "b" @=? SymInteger (pevalTimesNumTerm (ssymTerm "a") (ssymTerm "b")), + testCase "negate" $ negate (ssym "a" :: SymInteger) @=? SymInteger (pevalUMinusNumTerm (ssymTerm "a")), + testCase "abs" $ abs (ssym "a" :: SymInteger) @=? SymInteger (pevalAbsNumTerm (ssymTerm "a")), + testCase "signum" $ signum (ssym "a" :: SymInteger) @=? SymInteger (pevalSignumNumTerm (ssymTerm "a")) ], testGroup "SafeDivision" - [ testProperty "safeDiv on concrete" $ \(i :: Integer, j :: Integer) -> - ioProperty $ - safeDiv () (con i :: SymInteger) (con j) - @=? if j == 0 - then merge $ throwError () :: ExceptT () UnionM SymInteger - else mrgSingle $ con $ i `div` j, - testCase "safeDiv when divided by zero" $ do - safeDiv () (ssym "a" :: SymInteger) (con 0) - @=? (merge $ throwError () :: ExceptT () UnionM SymInteger), - testCase "safeDiv on symbolic" $ do - safeDiv () (ssym "a" :: SymInteger) (ssym "b") - @=? ( mrgIf - ((ssym "b" :: SymInteger) ==~ con (0 :: Integer) :: SymBool) - (throwError ()) - (mrgSingle $ SymInteger $ pevalDivIntegerTerm (ssymTerm "a") (ssymTerm "b")) :: - ExceptT () UnionM SymInteger - ), - testProperty "safeMod on concrete" $ \(i :: Integer, j :: Integer) -> - ioProperty $ - safeMod () (con i :: SymInteger) (con j) - @=? if j == 0 - then merge $ throwError () :: ExceptT () UnionM SymInteger - else mrgSingle $ con $ i `mod` j, - testCase "safeMod when divided by zero" $ do - safeMod () (ssym "a" :: SymInteger) (con 0) - @=? (merge $ throwError () :: ExceptT () UnionM SymInteger), - testCase "safeMod on symbolic" $ do - safeMod () (ssym "a" :: SymInteger) (ssym "b") - @=? ( mrgIf - ((ssym "b" :: SymInteger) ==~ con (0 :: Integer) :: SymBool) - (throwError ()) - (mrgSingle $ SymInteger $ pevalModIntegerTerm (ssymTerm "a") (ssymTerm "b")) :: - ExceptT () UnionM SymInteger - ), - testProperty "safeDivMod on concrete" $ \(i :: Integer, j :: Integer) -> - ioProperty $ - safeDivMod () (con i :: SymInteger) (con j) - @=? if j == 0 - then merge $ throwError () :: ExceptT () UnionM (SymInteger, SymInteger) - else mrgSingle $ (con $ i `div` j, con $ i `mod` j), - testCase "safeDivMod when divided by zero" $ do - safeDivMod () (ssym "a" :: SymInteger) (con 0) - @=? (merge $ throwError () :: ExceptT () UnionM (SymInteger, SymInteger)), - testCase "safeDivMod on symbolic" $ do - safeDivMod () (ssym "a" :: SymInteger) (ssym "b") - @=? ( mrgIf - ((ssym "b" :: SymInteger) ==~ con (0 :: Integer) :: SymBool) - (throwError ()) - ( mrgSingle - ( SymInteger $ pevalDivIntegerTerm (ssymTerm "a") (ssymTerm "b"), - SymInteger $ pevalModIntegerTerm (ssymTerm "a") (ssymTerm "b") - ) - ) :: - ExceptT () UnionM (SymInteger, SymInteger) - ), - testProperty "safeQuot on concrete" $ \(i :: Integer, j :: Integer) -> - ioProperty $ - safeQuot () (con i :: SymInteger) (con j) - @=? if j == 0 - then merge $ throwError () :: ExceptT () UnionM SymInteger - else mrgSingle $ con $ i `quot` j, - testCase "safeQuot when divided by zero" $ do - safeQuot () (ssym "a" :: SymInteger) (con 0) - @=? (merge $ throwError () :: ExceptT () UnionM SymInteger), - testProperty "safeRem on concrete" $ \(i :: Integer, j :: Integer) -> - ioProperty $ - safeRem () (con i :: SymInteger) (con j) - @=? if j == 0 - then merge $ throwError () :: ExceptT () UnionM SymInteger - else mrgSingle $ con $ i `rem` j, - testCase "safeRem when divided by zero" $ do - safeRem () (ssym "a" :: SymInteger) (con 0) - @=? (merge $ throwError () :: ExceptT () UnionM SymInteger), - testProperty "safeQuotRem on concrete" $ \(i :: Integer, j :: Integer) -> - ioProperty $ - safeQuotRem () (con i :: SymInteger) (con j) - @=? if j == 0 - then merge $ throwError () :: ExceptT () UnionM (SymInteger, SymInteger) - else mrgSingle $ (con $ i `quot` j, con $ i `rem` j), - testCase "safeQuotRem when divided by zero" $ do - safeQuotRem () (ssym "a" :: SymInteger) (con 0) - @=? (merge $ throwError () :: ExceptT () UnionM (SymInteger, SymInteger)) + [ safeDivisionUnboundedTests @Integer "safeDiv" id safeDiv safeDiv' div pevalDivIntegralTerm, + safeDivisionUnboundedTests @Integer "safeMod" id safeMod safeMod' mod pevalModIntegralTerm, + safeDivModUnboundedTests @Integer "safeDivMod" id safeDivMod safeDivMod' divMod pevalDivIntegralTerm pevalModIntegralTerm, + safeDivisionUnboundedTests @Integer "safeQuot" id safeQuot safeQuot' quot pevalQuotIntegralTerm, + safeDivisionUnboundedTests @Integer "safeRem" id safeRem safeRem' rem pevalRemIntegralTerm, + safeDivModUnboundedTests @Integer "safeQuotRem" id safeQuotRem safeQuotRem' quotRem pevalQuotIntegralTerm pevalRemIntegralTerm ], testGroup "SafeLinearArith" @@ -254,21 +551,21 @@ symPrimTests = ioProperty $ safeAdd () (con i :: SymInteger) (con j) @=? (mrgSingle $ con $ i + j :: ExceptT () UnionM SymInteger), - testCase "safeAdd on symbolic" $ do + testCase "safeAdd on symbolic" $ safeAdd () (ssym "a" :: SymInteger) (ssym "b") @=? (mrgSingle $ SymInteger $ pevalAddNumTerm (ssymTerm "a") (ssymTerm "b") :: ExceptT () UnionM SymInteger), testProperty "safeNeg on concrete" $ \(i :: Integer) -> ioProperty $ safeNeg () (con i :: SymInteger) @=? (mrgSingle $ con $ -i :: ExceptT () UnionM SymInteger), - testCase "safeNeg on symbolic" $ do + testCase "safeNeg on symbolic" $ safeNeg () (ssym "a" :: SymInteger) @=? (mrgSingle $ SymInteger $ pevalUMinusNumTerm (ssymTerm "a") :: ExceptT () UnionM SymInteger), testProperty "safeMinus on concrete" $ \(i :: Integer, j :: Integer) -> ioProperty $ safeMinus () (con i :: SymInteger) (con j) @=? (mrgSingle $ con $ i - j :: ExceptT () UnionM SymInteger), - testCase "safeMinus on symbolic" $ do + testCase "safeMinus on symbolic" $ safeMinus () (ssym "a" :: SymInteger) (ssym "b") @=? (mrgSingle $ SymInteger $ pevalMinusNumTerm (ssymTerm "a") (ssymTerm "b") :: ExceptT () UnionM SymInteger) ], @@ -329,6 +626,27 @@ symPrimTests = signum au @=? SymWordN (pevalSignumNumTerm aut) signum as @=? SymIntN (pevalSignumNumTerm ast) ], + testGroup + "SafeDivision" + [ testGroup + "WordN" + [ safeDivisionUnboundedTests @(WordN 4) "safeDiv" WordN safeDiv safeDiv' div pevalDivIntegralTerm, + safeDivisionUnboundedTests @(WordN 4) "safeMod" WordN safeMod safeMod' mod pevalModIntegralTerm, + safeDivModUnboundedTests @(WordN 4) "safeDivMod" WordN safeDivMod safeDivMod' divMod pevalDivIntegralTerm pevalModIntegralTerm, + safeDivisionUnboundedTests @(WordN 4) "safeQuot" WordN safeQuot safeQuot' quot pevalQuotIntegralTerm, + safeDivisionUnboundedTests @(WordN 4) "safeRem" WordN safeRem safeRem' rem pevalRemIntegralTerm, + safeDivModUnboundedTests @(WordN 4) "safeQuotRem" WordN safeQuotRem safeQuotRem' divMod pevalQuotIntegralTerm pevalRemIntegralTerm + ], + testGroup + "IntN" + [ safeDivisionBoundedTests @(IntN 4) "safeDiv" IntN safeDiv safeDiv' div pevalDivBoundedIntegralTerm, + safeDivisionUnboundedTests @(IntN 4) "safeMod" IntN safeMod safeMod' mod pevalModBoundedIntegralTerm, + safeDivModBoundedTests @(IntN 4) "safeDivMod" IntN safeDivMod safeDivMod' divMod pevalDivBoundedIntegralTerm pevalModBoundedIntegralTerm, + safeDivisionBoundedTests @(IntN 4) "safeQuot" IntN safeQuot safeQuot' quot pevalQuotBoundedIntegralTerm, + safeDivisionUnboundedTests @(IntN 4) "safeRem" IntN safeRem safeRem' rem pevalRemBoundedIntegralTerm, + safeDivModBoundedTests @(IntN 4) "safeQuotRem" IntN safeQuotRem safeQuotRem' quotRem pevalQuotBoundedIntegralTerm pevalRemBoundedIntegralTerm + ] + ], testGroup "SafeLinearArith" [ testGroup @@ -338,30 +656,27 @@ symPrimTests = let iint = fromIntegral i :: Integer jint = fromIntegral j in safeAdd () (toSym i :: SymIntN 8) (toSym j) - @=? ( mrgIf - (iint + jint ==~ fromIntegral (i + j)) - (mrgSingle $ toSym $ i + j :: ExceptT () UnionM (SymIntN 8)) - (throwError ()) - ), + @=? mrgIf + (iint + jint ==~ fromIntegral (i + j)) + (mrgSingle $ toSym $ i + j :: ExceptT () UnionM (SymIntN 8)) + (throwError ()), testProperty "safeMinus on concrete" $ \(i :: Int8, j :: Int8) -> ioProperty $ let iint = fromIntegral i :: Integer jint = fromIntegral j in safeMinus () (toSym i :: SymIntN 8) (toSym j) - @=? ( mrgIf - (iint - jint ==~ fromIntegral (i - j)) - (mrgSingle $ toSym $ i - j :: ExceptT () UnionM (SymIntN 8)) - (throwError ()) - ), + @=? mrgIf + (iint - jint ==~ fromIntegral (i - j)) + (mrgSingle $ toSym $ i - j :: ExceptT () UnionM (SymIntN 8)) + (throwError ()), testProperty "safeNeg on concrete" $ \(i :: Int8) -> ioProperty $ let iint = fromIntegral i :: Integer in safeNeg () (toSym i :: SymIntN 8) - @=? ( mrgIf - (-iint ==~ fromIntegral (-i)) - (mrgSingle $ toSym $ -i :: ExceptT () UnionM (SymIntN 8)) - (throwError ()) - ) + @=? mrgIf + (-iint ==~ fromIntegral (-i)) + (mrgSingle $ toSym $ -i :: ExceptT () UnionM (SymIntN 8)) + (throwError ()) ], testGroup "WordN" @@ -370,30 +685,27 @@ symPrimTests = let iint = fromIntegral i :: Integer jint = fromIntegral j in safeAdd () (toSym i :: SymWordN 8) (toSym j) - @=? ( mrgIf - (iint + jint ==~ fromIntegral (i + j)) - (mrgSingle $ toSym $ i + j :: ExceptT () UnionM (SymWordN 8)) - (throwError ()) - ), + @=? mrgIf + (iint + jint ==~ fromIntegral (i + j)) + (mrgSingle $ toSym $ i + j :: ExceptT () UnionM (SymWordN 8)) + (throwError ()), testProperty "safeMinus on concrete" $ \(i :: Word8, j :: Word8) -> ioProperty $ let iint = fromIntegral i :: Integer jint = fromIntegral j in safeMinus () (toSym i :: SymWordN 8) (toSym j) - @=? ( mrgIf - (iint - jint ==~ fromIntegral (i - j)) - (mrgSingle $ toSym $ i - j :: ExceptT () UnionM (SymWordN 8)) - (throwError ()) - ), + @=? mrgIf + (iint - jint ==~ fromIntegral (i - j)) + (mrgSingle $ toSym $ i - j :: ExceptT () UnionM (SymWordN 8)) + (throwError ()), testProperty "safeNeg on concrete" $ \(i :: Word8) -> ioProperty $ let iint = fromIntegral i :: Integer in safeNeg () (toSym i :: SymWordN 8) - @=? ( mrgIf - (-iint ==~ fromIntegral (-i)) - (mrgSingle $ toSym $ -i :: ExceptT () UnionM (SymWordN 8)) - (throwError ()) - ) + @=? mrgIf + (-iint ==~ fromIntegral (-i)) + (mrgSingle $ toSym $ -i :: ExceptT () UnionM (SymWordN 8)) + (throwError ()) ] ], testGroup @@ -476,7 +788,7 @@ symPrimTests = ], testGroup "sizedBVConcat" - [ testCase "sizedBVConcat" $ do + [ testCase "sizedBVConcat" $ sizedBVConcat (ssym "a" :: SymWordN 4) (ssym "b" :: SymWordN 3) @@ -533,7 +845,7 @@ symPrimTests = ], testGroup "TabularFun" - [ testCase "apply" $ do + [ testCase "apply" $ (ssym "a" :: SymInteger =~> SymInteger) # ssym "b" @=? SymInteger (pevalTabularFunApplyTerm (ssymTerm "a" :: Term (Integer =-> Integer)) (ssymTerm "b")) @@ -561,8 +873,7 @@ symPrimTests = symSize (ssym "a" + ssym "a" :: SymInteger) @=? 2 symSize (-(ssym "a") :: SymInteger) @=? 2 symSize (ites (ssym "a" :: SymBool) (ssym "b") (ssym "c") :: SymInteger) @=? 4, - testCase "symsSize" $ do - symsSize [ssym "a" :: SymInteger, ssym "a" + ssym "a"] @=? 2 + testCase "symsSize" $ symsSize [ssym "a" :: SymInteger, ssym "a" + ssym "a"] @=? 2 ], let asymbol :: TypedSymbol Integer = "a" bsymbol :: TypedSymbol Bool = "b" diff --git a/test/Main.hs b/test/Main.hs index 0b909cf1..19c9cdf0 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -30,7 +30,7 @@ import qualified Grisette.Core.Data.BVTests import qualified Grisette.IR.SymPrim.Data.Prim.BVTests import Grisette.IR.SymPrim.Data.Prim.BitsTests import Grisette.IR.SymPrim.Data.Prim.BoolTests -import Grisette.IR.SymPrim.Data.Prim.IntegerTests +import Grisette.IR.SymPrim.Data.Prim.IntegralTests import Grisette.IR.SymPrim.Data.Prim.ModelTests import Grisette.IR.SymPrim.Data.Prim.NumTests import qualified Grisette.IR.SymPrim.Data.Prim.TabularFunTests @@ -129,7 +129,7 @@ irTests = [ bitsTests, boolTests, Grisette.IR.SymPrim.Data.Prim.BVTests.bvTests, - integerTests, + integralTests, modelTests, numTests, Grisette.IR.SymPrim.Data.Prim.TabularFunTests.tabularFunTests From 548740a971b1c58bdd319ca299e1bdb0353f985e Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Mon, 20 Feb 2023 01:54:39 -0800 Subject: [PATCH 4/6] :bug: make clang happy --- src/Grisette/Core/Data/Class/SafeArith.hs | 14 ++++++++++---- src/Grisette/IR/SymPrim/Data/SymPrim.hs | 12 ++++++++---- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/Grisette/Core/Data/Class/SafeArith.hs b/src/Grisette/Core/Data/Class/SafeArith.hs index a2da39b6..cb462847 100644 --- a/src/Grisette/Core/Data/Class/SafeArith.hs +++ b/src/Grisette/Core/Data/Class/SafeArith.hs @@ -172,11 +172,17 @@ class (SOrd a, Num a, Mergeable a, Mergeable e) => SafeDivision e a | a -> e whe {-# MINIMAL (safeDivMod | (safeDiv, safeMod)), (safeDivMod' | (safeDiv', safeMod')) #-} +#define QUOTE() ' +#define QID(a) a +#define QRIGHT(a) QID(a)' + +#define QRIGHTT(a) QID(a)' t' + #define SAFE_DIVISION_FUNC(name, op) \ name _ r | r == 0 = merge $ throwError DivideByZero; \ name l r = mrgSingle $ l `op` r; \ -name' t _ r | r == 0 = merge $ throwError (t DivideByZero); \ -name' _ l r = mrgSingle $ l `op` r +QRIGHTT(name) _ r | r == 0 = let t1 = t' in merge $ throwError (t' DivideByZero); \ +QRIGHTT(name) l r = mrgSingle $ l `op` r #define SAFE_DIVISION_CONCRETE(type) \ instance SafeDivision ArithException type where \ @@ -218,7 +224,7 @@ SAFE_DIVISION_CONCRETE(Word) then merge $ throwError $ Right DivideByZero \ else mrgSingle $ stype $ l `op` r; \ Nothing -> merge $ throwError $ Left BitwidthMismatch); \ - name' t (stype (l :: type l)) (stype (r :: type r)) = \ + QRIGHT(name) t (stype (l :: type l)) (stype (r :: type r)) = \ (case sameNat (Proxy @l) (Proxy @r) of \ Just Refl -> \ if r == 0 \ @@ -234,7 +240,7 @@ SAFE_DIVISION_CONCRETE(Word) then merge $ throwError $ Right DivideByZero \ else (case l `op` r of (d, m) -> mrgSingle (stype d, stype m)); \ Nothing -> merge $ throwError $ Left BitwidthMismatch); \ - name' t (stype (l :: type l)) (stype (r :: type r)) = \ + QRIGHT(name) t (stype (l :: type l)) (stype (r :: type r)) = \ (case sameNat (Proxy @l) (Proxy @r) of \ Just Refl -> \ if r == 0 \ diff --git a/src/Grisette/IR/SymPrim/Data/SymPrim.hs b/src/Grisette/IR/SymPrim/Data/SymPrim.hs index 4fba5d59..46005854 100644 --- a/src/Grisette/IR/SymPrim/Data/SymPrim.hs +++ b/src/Grisette/IR/SymPrim/Data/SymPrim.hs @@ -140,13 +140,17 @@ newtype SymBool = SymBool {underlyingBoolTerm :: Term Bool} newtype SymInteger = SymInteger {underlyingIntegerTerm :: Term Integer} deriving (Lift, NFData, Generic) +#define QUOTE() ' +#define QID(a) a +#define QRIGHT(a) QID(a)' + #define SAFE_DIVISION_FUNC(name, type, op) \ name (type l) rs@(type r) = \ mrgIf \ (rs ==~ con 0) \ (throwError DivideByZero) \ (mrgReturn $ type $ op l r); \ -name' t (type l) rs@(type r) = \ +QRIGHT(name) t (type l) rs@(type r) = \ mrgIf \ (rs ==~ con 0) \ (throwError (t DivideByZero)) \ @@ -158,7 +162,7 @@ name (type l) rs@(type r) = \ (rs ==~ con 0) \ (throwError DivideByZero) \ (mrgReturn (type $ op1 l r, type $ op2 l r)); \ -name' t (type l) rs@(type r) = \ +QRIGHT(name) t (type l) rs@(type r) = \ mrgIf \ (rs ==~ con 0) \ (throwError (t DivideByZero)) \ @@ -208,7 +212,7 @@ name ls@(type l) rs@(type r) = \ (mrgIf (rs ==~ con (-1) &&~ ls ==~ con minBound) \ (throwError Overflow) \ (mrgReturn $ type $ op l r)); \ -name' t ls@(type l) rs@(type r) = \ +QRIGHT(name) t ls@(type l) rs@(type r) = \ mrgIf \ (rs ==~ con 0) \ (throwError (t DivideByZero)) \ @@ -224,7 +228,7 @@ name ls@(type l) rs@(type r) = \ (mrgIf (rs ==~ con (-1) &&~ ls ==~ con minBound) \ (throwError Overflow) \ (mrgReturn (type $ op1 l r, type $ op2 l r))); \ -name' t ls@(type l) rs@(type r) = \ +QRIGHT(name) t ls@(type l) rs@(type r) = \ mrgIf \ (rs ==~ con 0) \ (throwError (t DivideByZero)) \ From 5a1cae849dbfe2c3174ba4ca79bf8a324be64d7d Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Sun, 26 Feb 2023 22:52:11 -0800 Subject: [PATCH 5/6] :sparkles: instances for SafeLinearArith --- src/Grisette/Core/Data/Class/SafeArith.hs | 166 ++++++++++++++++-- src/Grisette/IR/SymPrim/Data/SymPrim.hs | 89 +++++++--- test/Grisette/IR/SymPrim/Data/SymPrimTests.hs | 108 ++++++++---- 3 files changed, 288 insertions(+), 75 deletions(-) diff --git a/src/Grisette/Core/Data/Class/SafeArith.hs b/src/Grisette/Core/Data/Class/SafeArith.hs index cb462847..a9f4bf6e 100644 --- a/src/Grisette/Core/Data/Class/SafeArith.hs +++ b/src/Grisette/Core/Data/Class/SafeArith.hs @@ -268,33 +268,161 @@ instance SafeDivision (Either BitwidthMismatch ArithException) SomeWordN where SAFE_DIVISION_FUNC_SOME_DIVMOD(SomeWordN, WordN, safeQuotRem, quotRem) #endif -class SafeLinearArith a where - -- | Safe signed '+' with monadic error handling in multi-path execution. - -- Overflows are treated as errors. +-- | Safe division with monadic error handling in multi-path +-- execution. These procedures throw an exception when overflow or underflow happens. +-- The result should be able to handle errors with `MonadError`. +class (SOrd a, Num a, Mergeable a, Mergeable e) => SafeLinearArith e a | a -> e where + -- | Safe '+' with monadic error handling in multi-path execution. + -- Overflows or underflows are treated as errors. -- - -- >>> safeAdd AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM SymInteger + -- >>> safeAdd (ssym "a") (ssym "b") :: ExceptT ArithException UnionM SymInteger -- ExceptT {Right (+ a b)} - -- >>> safeAdd AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM (SymIntN 4) - -- ExceptT {If (|| (&& (< 0x0 a) (&& (< 0x0 b) (< (+ a b) 0x0))) (&& (< a 0x0) (&& (< b 0x0) (<= 0x0 (+ a b))))) (Left AssertionError) (Right (+ a b))} - safeAdd :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a + -- >>> safeAdd (ssym "a") (ssym "b") :: ExceptT ArithException UnionM (SymIntN 4) + -- ExceptT {If (ite (< 0x0 a) (&& (< 0x0 b) (< (+ a b) 0x0)) (&& (< a 0x0) (&& (< b 0x0) (<= 0x0 (+ a b))))) (If (< 0x0 a) (Left arithmetic overflow) (Left arithmetic underflow)) (Right (+ a b))} + safeAdd :: (MonadError e uf, MonadUnion uf) => a -> a -> uf a - -- | Safe signed 'negate' with monadic error handling in multi-path execution. - -- Overflows are treated as errors. + -- | Safe 'negate' with monadic error handling in multi-path execution. + -- Overflows or underflows are treated as errors. -- - -- >>> safeNeg AssertionError (ssym "a") :: ExceptT AssertionError UnionM SymInteger + -- >>> safeNeg (ssym "a") :: ExceptT ArithException UnionM SymInteger -- ExceptT {Right (- a)} - -- >>> safeNeg AssertionError (ssym "a") :: ExceptT AssertionError UnionM (SymIntN 4) - -- ExceptT {If (= a 0x8) (Left AssertionError) (Right (- a))} - safeNeg :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> uf a + -- >>> safeNeg (ssym "a") :: ExceptT ArithException UnionM (SymIntN 4) + -- ExceptT {If (= a 0x8) (Left arithmetic overflow) (Right (- a))} + safeNeg :: (MonadError e uf, MonadUnion uf) => a -> uf a - -- | Safe signed '-' with monadic error handling in multi-path execution. - -- Overflows are treated as errors. + -- | Safe '-' with monadic error handling in multi-path execution. + -- Overflows or underflows are treated as errors. -- - -- >>> safeMinus AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM SymInteger + -- >>> safeMinus (ssym "a") (ssym "b") :: ExceptT ArithException UnionM SymInteger -- ExceptT {Right (+ a (- b))} - -- >>> safeMinus AssertionError (ssym "a") (ssym "b") :: ExceptT AssertionError UnionM (SymIntN 4) - -- ExceptT {If (|| (&& (<= 0x0 a) (&& (< b 0x0) (< (+ a (- b)) 0x0))) (&& (< a 0x0) (&& (< 0x0 b) (< 0x0 (+ a (- b)))))) (Left AssertionError) (Right (+ a (- b)))} - safeMinus :: (MonadError e uf, MonadUnion uf, Mergeable e) => e -> a -> a -> uf a + -- >>> safeMinus (ssym "a") (ssym "b") :: ExceptT ArithException UnionM (SymIntN 4) + -- ExceptT {If (ite (<= 0x0 a) (&& (< b 0x0) (< (+ a (- b)) 0x0)) (&& (< a 0x0) (&& (< 0x0 b) (< 0x0 (+ a (- b)))))) (If (<= 0x0 a) (Left arithmetic overflow) (Left arithmetic underflow)) (Right (+ a (- b)))} + safeMinus :: (MonadError e uf, MonadUnion uf) => a -> a -> uf a + + -- | Safe '+' with monadic error handling in multi-path execution. + -- Overflows or underflows are treated as errors. + -- The error is transformed. + safeAdd' :: (MonadError e' uf, MonadUnion uf, Mergeable e') => (e -> e') -> a -> a -> uf a + + -- | Safe 'negate' with monadic error handling in multi-path execution. + -- Overflows or underflows are treated as errors. + -- The error is transformed. + safeNeg' :: (MonadError e' uf, MonadUnion uf, Mergeable e') => (e -> e') -> a -> uf a + + -- | Safe '-' with monadic error handling in multi-path execution. + -- Overflows or underflows are treated as errors. + -- The error is transformed. + safeMinus' :: (MonadError e' uf, MonadUnion uf, Mergeable e') => (e -> e') -> a -> a -> uf a + +instance SafeLinearArith ArithException Integer where + safeAdd l r = mrgSingle (l + r) + safeNeg l = mrgSingle (-l) + safeMinus l r = mrgSingle (l - r) + safeAdd' _ l r = mrgSingle (l + r) + safeNeg' _ l = mrgSingle (-l) + safeMinus' _ l r = mrgSingle (l - r) + +#define SAFE_LINARITH_SIGNED_CONCRETE_BODY \ + safeAdd l r = let res = l + r in \ + mrgIf (con $ l > 0 && r > 0 && res < 0) \ + (throwError Overflow) \ + (mrgIf (con $ l < 0 && r < 0 && res >= 0) \ + (throwError Underflow) \ + (return res));\ + safeAdd' t' l r = let res = l + r in \ + mrgIf (con $ l > 0 && r > 0 && res < 0) \ + (throwError (t' Overflow)) \ + (mrgIf (con $ l < 0 && r < 0 && res >= 0) \ + (throwError (t' Underflow)) \ + (return res)); \ + safeMinus l r = let res = l - r in \ + mrgIf (con $ l >= 0 && r < 0 && res < 0) \ + (throwError Overflow) \ + (mrgIf (con $ l < 0 && r > 0 && res > 0) \ + (throwError Underflow) \ + (return res));\ + safeMinus' t' l r = let res = l - r in \ + mrgIf (con $ l >= 0 && r < 0 && res < 0) \ + (throwError (t' Overflow)) \ + (mrgIf (con $ l < 0 && r > 0 && res > 0) \ + (throwError (t' Underflow)) \ + (return res)); \ + safeNeg v = mrgIf (con $ v == minBound) (throwError Overflow) (return $ -v);\ + safeNeg' t' v = mrgIf (con $ v == minBound) (throwError (t' Overflow)) (return $ -v) + +#define SAFE_LINARITH_SIGNED_CONCRETE(type) \ +instance SafeLinearArith ArithException type where \ + SAFE_LINARITH_SIGNED_CONCRETE_BODY + +#define SAFE_LINARITH_SIGNED_BV_CONCRETE(type) \ +instance (KnownNat n, 1 <= n) => SafeLinearArith ArithException (type n) where \ + SAFE_LINARITH_SIGNED_CONCRETE_BODY + +#define SAFE_LINARITH_UNSIGNED_CONCRETE_BODY \ + safeAdd l r = let res = l + r in \ + mrgIf (con $ l > res || r > res) \ + (throwError Overflow) \ + (return res);\ + safeAdd' t' l r = let res = l + r in \ + mrgIf (con $ l > res || r > res) \ + (throwError (t' Overflow)) \ + (return res); \ + safeMinus l r = \ + mrgIf (con $ r > l) \ + (throwError Underflow) \ + (return $ l - r);\ + safeMinus' t' l r = \ + mrgIf (con $ r > l) \ + (throwError $ t' Underflow) \ + (return $ l - r);\ + safeNeg v = mrgIf (con $ v /= 0) (throwError Underflow) (return $ -v);\ + safeNeg' t' v = mrgIf (con $ v /= 0) (throwError (t' Underflow)) (return $ -v) + +#define SAFE_LINARITH_UNSIGNED_CONCRETE(type) \ +instance SafeLinearArith ArithException type where \ + SAFE_LINARITH_UNSIGNED_CONCRETE_BODY + +#define SAFE_LINARITH_UNSIGNED_BV_CONCRETE(type) \ +instance (KnownNat n, 1 <= n) => SafeLinearArith ArithException (type n) where \ + SAFE_LINARITH_UNSIGNED_CONCRETE_BODY + +#define SAFE_LINARITH_SOME_CONCRETE(type, ctype) \ +instance SafeLinearArith (Either BitwidthMismatch ArithException) type where \ + safeAdd (type (l :: ctype l)) (type (r :: ctype r)) = merge (\ + case sameNat (Proxy @l) (Proxy @r) of \ + Just Refl -> type <$> safeAdd' Right l r; \ + _ -> throwError $ Left BitwidthMismatch); \ + safeAdd' t (type (l :: ctype l)) (type (r :: ctype r)) = merge (\ + case sameNat (Proxy @l) (Proxy @r) of \ + Just Refl -> type <$> safeAdd' (t . Right) l r; \ + _ -> throwError $ t $ Left BitwidthMismatch); \ + safeMinus (type (l :: ctype l)) (type (r :: ctype r)) = merge (\ + case sameNat (Proxy @l) (Proxy @r) of \ + Just Refl -> type <$> safeMinus' Right l r; \ + _ -> throwError $ Left BitwidthMismatch); \ + safeMinus' t (type (l :: ctype l)) (type (r :: ctype r)) = merge (\ + case sameNat (Proxy @l) (Proxy @r) of \ + Just Refl -> type <$> safeMinus' (t . Right) l r; \ + _ -> throwError $ t $ Left BitwidthMismatch); \ + safeNeg (type l) = merge $ type <$> safeNeg' Right l; \ + safeNeg' t (type l) = merge $ type <$> safeNeg' (t . Right) l + +#if 1 +SAFE_LINARITH_SIGNED_CONCRETE(Int8) +SAFE_LINARITH_SIGNED_CONCRETE(Int16) +SAFE_LINARITH_SIGNED_CONCRETE(Int32) +SAFE_LINARITH_SIGNED_CONCRETE(Int64) +SAFE_LINARITH_SIGNED_CONCRETE(Int) +SAFE_LINARITH_SIGNED_BV_CONCRETE(IntN) +SAFE_LINARITH_SOME_CONCRETE(SomeIntN, IntN) +SAFE_LINARITH_UNSIGNED_CONCRETE(Word8) +SAFE_LINARITH_UNSIGNED_CONCRETE(Word16) +SAFE_LINARITH_UNSIGNED_CONCRETE(Word32) +SAFE_LINARITH_UNSIGNED_CONCRETE(Word64) +SAFE_LINARITH_UNSIGNED_CONCRETE(Word) +SAFE_LINARITH_UNSIGNED_BV_CONCRETE(WordN) +SAFE_LINARITH_SOME_CONCRETE(SomeWordN, WordN) +#endif -- | Aggregation for the operations on symbolic integer types class (Num a, SEq a, SOrd a, Solvable Integer a) => SymIntegerOp a diff --git a/src/Grisette/IR/SymPrim/Data/SymPrim.hs b/src/Grisette/IR/SymPrim/Data/SymPrim.hs index 46005854..770b0c86 100644 --- a/src/Grisette/IR/SymPrim/Data/SymPrim.hs +++ b/src/Grisette/IR/SymPrim/Data/SymPrim.hs @@ -178,10 +178,13 @@ instance SafeDivision ArithException SymInteger where SAFE_DIVISION_FUNC2(safeQuotRem, SymInteger, pevalQuotIntegralTerm, pevalRemIntegralTerm) #endif -instance SafeLinearArith SymInteger where - safeAdd e ls rs = mrgReturn $ ls + rs - safeNeg e v = mrgReturn $ -v - safeMinus e ls rs = mrgReturn $ ls - rs +instance SafeLinearArith ArithException SymInteger where + safeAdd ls rs = mrgReturn $ ls + rs + safeAdd' _ ls rs = mrgReturn $ ls + rs + safeNeg v = mrgReturn $ -v + safeNeg' _ v = mrgReturn $ -v + safeMinus ls rs = mrgReturn $ ls - rs + safeMinus' e ls rs = mrgReturn $ ls - rs instance SymIntegerOp SymInteger @@ -246,20 +249,51 @@ instance (KnownNat n, 1 <= n) => SafeDivision ArithException (SymIntN n) where SAFE_DIVISION_FUNC2_BOUNDED_SIGNED(safeQuotRem, SymIntN, pevalQuotBoundedIntegralTerm, pevalRemBoundedIntegralTerm) #endif -instance (KnownNat n, 1 <= n) => SafeLinearArith (SymIntN n) where - safeAdd e ls rs = +instance (KnownNat n, 1 <= n) => SafeLinearArith ArithException (SymIntN n) where + safeAdd ls rs = mrgIf - ((ls >~ 0 &&~ rs >~ 0 &&~ res <~ 0) ||~ (ls <~ 0 &&~ rs <~ 0 &&~ res >=~ 0)) - (throwError e) - (mrgReturn res) + (ls >~ 0) + (mrgIf (rs >~ 0 &&~ res <~ 0) (throwError Overflow) (return res)) + ( mrgIf + (ls <~ 0 &&~ rs <~ 0 &&~ res >=~ 0) + (throwError Underflow) + (mrgReturn res) + ) where res = ls + rs - safeNeg e v = mrgIf (v ==~ con minBound) (throwError e) (mrgReturn $ -v) - safeMinus e ls rs = + safeAdd' f ls rs = mrgIf - ((ls >=~ 0 &&~ rs <~ 0 &&~ res <~ 0) ||~ (ls <~ 0 &&~ rs >~ 0 &&~ res >~ 0)) - (throwError e) - (mrgReturn res) + (ls >~ 0) + (mrgIf (rs >~ 0 &&~ res <~ 0) (throwError $ f Overflow) (return res)) + ( mrgIf + (ls <~ 0 &&~ rs <~ 0 &&~ res >=~ 0) + (throwError $ f Underflow) + (mrgReturn res) + ) + where + res = ls + rs + safeNeg v = mrgIf (v ==~ con minBound) (throwError Overflow) (mrgReturn $ -v) + safeNeg' f v = mrgIf (v ==~ con minBound) (throwError $ f Overflow) (mrgReturn $ -v) + safeMinus ls rs = + mrgIf + (ls >=~ 0) + (mrgIf (rs <~ 0 &&~ res <~ 0) (throwError Overflow) (return res)) + ( mrgIf + (ls <~ 0 &&~ rs >~ 0 &&~ res >~ 0) + (throwError Underflow) + (mrgReturn res) + ) + where + res = ls - rs + safeMinus' f ls rs = + mrgIf + (ls >=~ 0) + (mrgIf (rs <~ 0 &&~ res <~ 0) (throwError $ f Overflow) (return res)) + ( mrgIf + (ls <~ 0 &&~ rs >~ 0 &&~ res >~ 0) + (throwError $ f Underflow) + (mrgReturn res) + ) where res = ls - rs @@ -343,19 +377,34 @@ instance (KnownNat n, 1 <= n) => SafeDivision ArithException (SymWordN n) where SAFE_DIVISION_FUNC2(safeQuotRem, SymWordN, pevalQuotIntegralTerm, pevalRemIntegralTerm) #endif -instance (KnownNat n, 1 <= n) => SafeLinearArith (SymWordN n) where - safeAdd e ls rs = +instance (KnownNat n, 1 <= n) => SafeLinearArith ArithException (SymWordN n) where + safeAdd ls rs = + mrgIf + (ls >~ res ||~ rs >~ res) + (throwError Overflow) + (mrgReturn res) + where + res = ls + rs + safeAdd' f ls rs = mrgIf (ls >~ res ||~ rs >~ res) - (throwError e) + (throwError $ f Overflow) (mrgReturn res) where res = ls + rs - safeNeg e v = mrgIf (v /=~ 0) (throwError e) (mrgReturn v) - safeMinus e ls rs = + safeNeg v = mrgIf (v /=~ 0) (throwError Underflow) (mrgReturn v) + safeNeg' f v = mrgIf (v /=~ 0) (throwError $ f Underflow) (mrgReturn v) + safeMinus ls rs = + mrgIf + (rs >~ ls) + (throwError Underflow) + (mrgReturn res) + where + res = ls - rs + safeMinus' f ls rs = mrgIf (rs >~ ls) - (throwError e) + (throwError $ f Underflow) (mrgReturn res) where res = ls - rs diff --git a/test/Grisette/IR/SymPrim/Data/SymPrimTests.hs b/test/Grisette/IR/SymPrim/Data/SymPrimTests.hs index 54883796..e6363ad8 100644 --- a/test/Grisette/IR/SymPrim/Data/SymPrimTests.hs +++ b/test/Grisette/IR/SymPrim/Data/SymPrimTests.hs @@ -548,25 +548,37 @@ symPrimTests = testGroup "SafeLinearArith" [ testProperty "safeAdd on concrete" $ \(i :: Integer, j :: Integer) -> - ioProperty $ - safeAdd () (con i :: SymInteger) (con j) + ioProperty $ do + safeAdd (con i :: SymInteger) (con j) + @=? (mrgSingle $ con $ i + j :: ExceptT ArithException UnionM SymInteger) + safeAdd' (const ()) (con i :: SymInteger) (con j) @=? (mrgSingle $ con $ i + j :: ExceptT () UnionM SymInteger), - testCase "safeAdd on symbolic" $ - safeAdd () (ssym "a" :: SymInteger) (ssym "b") + testCase "safeAdd on symbolic" $ do + safeAdd (ssym "a" :: SymInteger) (ssym "b") + @=? (mrgSingle $ SymInteger $ pevalAddNumTerm (ssymTerm "a") (ssymTerm "b") :: ExceptT ArithException UnionM SymInteger) + safeAdd' (const ()) (ssym "a" :: SymInteger) (ssym "b") @=? (mrgSingle $ SymInteger $ pevalAddNumTerm (ssymTerm "a") (ssymTerm "b") :: ExceptT () UnionM SymInteger), testProperty "safeNeg on concrete" $ \(i :: Integer) -> - ioProperty $ - safeNeg () (con i :: SymInteger) + ioProperty $ do + safeNeg (con i :: SymInteger) + @=? (mrgSingle $ con $ -i :: ExceptT ArithException UnionM SymInteger) + safeNeg' (const ()) (con i :: SymInteger) @=? (mrgSingle $ con $ -i :: ExceptT () UnionM SymInteger), - testCase "safeNeg on symbolic" $ - safeNeg () (ssym "a" :: SymInteger) + testCase "safeNeg on symbolic" $ do + safeNeg (ssym "a" :: SymInteger) + @=? (mrgSingle $ SymInteger $ pevalUMinusNumTerm (ssymTerm "a") :: ExceptT ArithException UnionM SymInteger) + safeNeg' (const ()) (ssym "a" :: SymInteger) @=? (mrgSingle $ SymInteger $ pevalUMinusNumTerm (ssymTerm "a") :: ExceptT () UnionM SymInteger), testProperty "safeMinus on concrete" $ \(i :: Integer, j :: Integer) -> - ioProperty $ - safeMinus () (con i :: SymInteger) (con j) + ioProperty $ do + safeMinus (con i :: SymInteger) (con j) + @=? (mrgSingle $ con $ i - j :: ExceptT ArithException UnionM SymInteger) + safeMinus' (const ()) (con i :: SymInteger) (con j) @=? (mrgSingle $ con $ i - j :: ExceptT () UnionM SymInteger), - testCase "safeMinus on symbolic" $ - safeMinus () (ssym "a" :: SymInteger) (ssym "b") + testCase "safeMinus on symbolic" $ do + safeMinus (ssym "a" :: SymInteger) (ssym "b") + @=? (mrgSingle $ SymInteger $ pevalMinusNumTerm (ssymTerm "a") (ssymTerm "b") :: ExceptT ArithException UnionM SymInteger) + safeMinus' (const ()) (ssym "a" :: SymInteger) (ssym "b") @=? (mrgSingle $ SymInteger $ pevalMinusNumTerm (ssymTerm "a") (ssymTerm "b") :: ExceptT () UnionM SymInteger) ], testGroup @@ -655,28 +667,40 @@ symPrimTests = ioProperty $ let iint = fromIntegral i :: Integer jint = fromIntegral j - in safeAdd () (toSym i :: SymIntN 8) (toSym j) + in safeAdd (toSym i :: SymIntN 8) (toSym j) @=? mrgIf - (iint + jint ==~ fromIntegral (i + j)) - (mrgSingle $ toSym $ i + j :: ExceptT () UnionM (SymIntN 8)) - (throwError ()), + (iint + jint <~ fromIntegral (i + j)) + (throwError Underflow) + ( mrgIf + (iint + jint >~ fromIntegral (i + j)) + (throwError Overflow) + (mrgSingle $ toSym $ i + j :: ExceptT ArithException UnionM (SymIntN 8)) + ), testProperty "safeMinus on concrete" $ \(i :: Int8, j :: Int8) -> ioProperty $ let iint = fromIntegral i :: Integer jint = fromIntegral j - in safeMinus () (toSym i :: SymIntN 8) (toSym j) + in safeMinus (toSym i :: SymIntN 8) (toSym j) @=? mrgIf - (iint - jint ==~ fromIntegral (i - j)) - (mrgSingle $ toSym $ i - j :: ExceptT () UnionM (SymIntN 8)) - (throwError ()), + (iint - jint <~ fromIntegral (i - j)) + (throwError Underflow) + ( mrgIf + (iint - jint >~ fromIntegral (i - j)) + (throwError Overflow) + (mrgSingle $ toSym $ i - j :: ExceptT ArithException UnionM (SymIntN 8)) + ), testProperty "safeNeg on concrete" $ \(i :: Int8) -> ioProperty $ let iint = fromIntegral i :: Integer - in safeNeg () (toSym i :: SymIntN 8) + in safeNeg (toSym i :: SymIntN 8) @=? mrgIf - (-iint ==~ fromIntegral (-i)) - (mrgSingle $ toSym $ -i :: ExceptT () UnionM (SymIntN 8)) - (throwError ()) + (-iint <~ fromIntegral (-i)) + (throwError Underflow) + ( mrgIf + (-iint >~ fromIntegral (-i)) + (throwError Overflow) + (mrgSingle $ toSym $ -i :: ExceptT ArithException UnionM (SymIntN 8)) + ) ], testGroup "WordN" @@ -684,28 +708,40 @@ symPrimTests = ioProperty $ let iint = fromIntegral i :: Integer jint = fromIntegral j - in safeAdd () (toSym i :: SymWordN 8) (toSym j) + in safeAdd (toSym i :: SymWordN 8) (toSym j) @=? mrgIf - (iint + jint ==~ fromIntegral (i + j)) - (mrgSingle $ toSym $ i + j :: ExceptT () UnionM (SymWordN 8)) - (throwError ()), + (iint + jint <~ fromIntegral (i + j)) + (throwError Underflow) + ( mrgIf + (iint + jint >~ fromIntegral (i + j)) + (throwError Overflow) + (mrgSingle $ toSym $ i + j :: ExceptT ArithException UnionM (SymWordN 8)) + ), testProperty "safeMinus on concrete" $ \(i :: Word8, j :: Word8) -> ioProperty $ let iint = fromIntegral i :: Integer jint = fromIntegral j - in safeMinus () (toSym i :: SymWordN 8) (toSym j) + in safeMinus (toSym i :: SymWordN 8) (toSym j) @=? mrgIf - (iint - jint ==~ fromIntegral (i - j)) - (mrgSingle $ toSym $ i - j :: ExceptT () UnionM (SymWordN 8)) - (throwError ()), + (iint - jint <~ fromIntegral (i - j)) + (throwError Underflow) + ( mrgIf + (iint - jint >~ fromIntegral (i - j)) + (throwError Overflow) + (mrgSingle $ toSym $ i - j :: ExceptT ArithException UnionM (SymWordN 8)) + ), testProperty "safeNeg on concrete" $ \(i :: Word8) -> ioProperty $ let iint = fromIntegral i :: Integer - in safeNeg () (toSym i :: SymWordN 8) + in safeNeg (toSym i :: SymWordN 8) @=? mrgIf - (-iint ==~ fromIntegral (-i)) - (mrgSingle $ toSym $ -i :: ExceptT () UnionM (SymWordN 8)) - (throwError ()) + (-iint <~ fromIntegral (-i)) + (throwError Underflow) + ( mrgIf + (-iint >~ fromIntegral (-i)) + (throwError Overflow) + (mrgSingle $ toSym $ -i :: ExceptT ArithException UnionM (SymWordN 8)) + ) ] ], testGroup From a1bffac80007eae95a457821967724322a7140ad Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Sun, 26 Feb 2023 23:04:03 -0800 Subject: [PATCH 6/6] :bug: fix macOS build issues --- src/Grisette/Core/Data/Class/SafeArith.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Grisette/Core/Data/Class/SafeArith.hs b/src/Grisette/Core/Data/Class/SafeArith.hs index a9f4bf6e..928cc496 100644 --- a/src/Grisette/Core/Data/Class/SafeArith.hs +++ b/src/Grisette/Core/Data/Class/SafeArith.hs @@ -395,7 +395,7 @@ instance SafeLinearArith (Either BitwidthMismatch ArithException) type where \ safeAdd' t (type (l :: ctype l)) (type (r :: ctype r)) = merge (\ case sameNat (Proxy @l) (Proxy @r) of \ Just Refl -> type <$> safeAdd' (t . Right) l r; \ - _ -> throwError $ t $ Left BitwidthMismatch); \ + _ -> let t' = t; t''' = t in throwError $ t' $ Left BitwidthMismatch); \ safeMinus (type (l :: ctype l)) (type (r :: ctype r)) = merge (\ case sameNat (Proxy @l) (Proxy @r) of \ Just Refl -> type <$> safeMinus' Right l r; \ @@ -403,7 +403,7 @@ instance SafeLinearArith (Either BitwidthMismatch ArithException) type where \ safeMinus' t (type (l :: ctype l)) (type (r :: ctype r)) = merge (\ case sameNat (Proxy @l) (Proxy @r) of \ Just Refl -> type <$> safeMinus' (t . Right) l r; \ - _ -> throwError $ t $ Left BitwidthMismatch); \ + _ -> let t' = t; t''' = t in throwError $ t' $ Left BitwidthMismatch); \ safeNeg (type l) = merge $ type <$> safeNeg' Right l; \ safeNeg' t (type l) = merge $ type <$> safeNeg' (t . Right) l