Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce SafeDivision and SafeLinearArith #56

Merged
merged 7 commits into from Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions grisette.cabal
Expand Up @@ -60,9 +60,9 @@ library
Grisette.Core.Data.Class.ExtractSymbolics
Grisette.Core.Data.Class.Function
Grisette.Core.Data.Class.GenSym
Grisette.Core.Data.Class.Integer
Grisette.Core.Data.Class.Mergeable
Grisette.Core.Data.Class.ModelOps
Grisette.Core.Data.Class.SafeArith
Grisette.Core.Data.Class.SimpleMergeable
Grisette.Core.Data.Class.Solvable
Grisette.Core.Data.Class.Solver
Expand Down Expand Up @@ -93,7 +93,7 @@ library
Grisette.IR.SymPrim.Data.Prim.PartialEval.Bool
Grisette.IR.SymPrim.Data.Prim.PartialEval.BV
Grisette.IR.SymPrim.Data.Prim.PartialEval.GeneralFun
Grisette.IR.SymPrim.Data.Prim.PartialEval.Integer
Grisette.IR.SymPrim.Data.Prim.PartialEval.Integral
Grisette.IR.SymPrim.Data.Prim.PartialEval.Num
Grisette.IR.SymPrim.Data.Prim.PartialEval.PartialEval
Grisette.IR.SymPrim.Data.Prim.PartialEval.TabularFun
Expand Down Expand Up @@ -192,7 +192,7 @@ test-suite spec
Grisette.IR.SymPrim.Data.Prim.BitsTests
Grisette.IR.SymPrim.Data.Prim.BoolTests
Grisette.IR.SymPrim.Data.Prim.BVTests
Grisette.IR.SymPrim.Data.Prim.IntegerTests
Grisette.IR.SymPrim.Data.Prim.IntegralTests
Grisette.IR.SymPrim.Data.Prim.ModelTests
Grisette.IR.SymPrim.Data.Prim.NumTests
Grisette.IR.SymPrim.Data.Prim.TabularFunTests
Expand Down
92 changes: 84 additions & 8 deletions src/Grisette/Backend/SBV/Data/SMT/Lowering.hs
Expand Up @@ -339,14 +339,38 @@ lowerSinglePrimImpl' config t@(GeneralFunApplyTerm _ (f :: Term (b --> a)) (arg
addResult @integerBitWidth t g
return g
_ -> translateBinaryError "generalApply" (R.typeRep @(b --> a)) (R.typeRep @b) (R.typeRep @a)
lowerSinglePrimImpl' config t@(DivIntegerTerm _ arg1 arg2) =
lowerSinglePrimImpl' config t@(DivIntegralTerm _ arg1 arg2) =
case (config, R.typeRep @a) of
(ResolvedConfig {}, IntegerType) -> lowerBinaryTerm' config t arg1 arg2 SBV.sDiv
ResolvedSDivisibleType -> lowerBinaryTerm' config t arg1 arg2 SBV.sDiv
_ -> translateBinaryError "div" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a)
lowerSinglePrimImpl' config t@(ModIntegerTerm _ arg1 arg2) =
lowerSinglePrimImpl' config t@(ModIntegralTerm _ arg1 arg2) =
case (config, R.typeRep @a) of
(ResolvedConfig {}, IntegerType) -> lowerBinaryTerm' config t arg1 arg2 SBV.sMod
ResolvedSDivisibleType -> lowerBinaryTerm' config t arg1 arg2 SBV.sMod
_ -> translateBinaryError "mod" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a)
lowerSinglePrimImpl' config t@(QuotIntegralTerm _ arg1 arg2) =
case (config, R.typeRep @a) of
ResolvedSDivisibleType -> lowerBinaryTerm' config t arg1 arg2 SBV.sQuot
_ -> translateBinaryError "quot" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a)
lowerSinglePrimImpl' config t@(RemIntegralTerm _ arg1 arg2) =
case (config, R.typeRep @a) of
ResolvedSDivisibleType -> lowerBinaryTerm' config t arg1 arg2 SBV.sRem
_ -> translateBinaryError "rem" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a)
lowerSinglePrimImpl' config t@(DivBoundedIntegralTerm _ arg1 arg2) =
case (config, R.typeRep @a) of
ResolvedSDivisibleType -> lowerBinaryTerm' config t arg1 arg2 SBV.sDiv
_ -> translateBinaryError "div" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a)
lowerSinglePrimImpl' config t@(ModBoundedIntegralTerm _ arg1 arg2) =
case (config, R.typeRep @a) of
ResolvedSDivisibleType -> lowerBinaryTerm' config t arg1 arg2 SBV.sMod
_ -> translateBinaryError "mod" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a)
lowerSinglePrimImpl' config t@(QuotBoundedIntegralTerm _ arg1 arg2) =
case (config, R.typeRep @a) of
ResolvedSDivisibleType -> lowerBinaryTerm' config t arg1 arg2 SBV.sQuot
_ -> translateBinaryError "quot" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a)
lowerSinglePrimImpl' config t@(RemBoundedIntegralTerm _ arg1 arg2) =
case (config, R.typeRep @a) of
ResolvedSDivisibleType -> lowerBinaryTerm' config t arg1 arg2 SBV.sRem
_ -> translateBinaryError "rem" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a)
lowerSinglePrimImpl' _ _ = undefined

buildUTFun11 ::
Expand Down Expand Up @@ -717,14 +741,38 @@ lowerSinglePrimImpl config t@(GeneralFunApplyTerm _ (f :: Term (b --> a)) (arg :
let g = l1 l2
return (addBiMapIntermediate (SomeTerm t) (toDyn g) m2, g)
_ -> translateBinaryError "generalApply" (R.typeRep @(b --> a)) (R.typeRep @b) (R.typeRep @a)
lowerSinglePrimImpl config t@(DivIntegerTerm _ arg1 arg2) m =
lowerSinglePrimImpl config t@(DivIntegralTerm _ arg1 arg2) m =
case (config, R.typeRep @a) of
ResolvedSDivisibleType -> lowerBinaryTerm config t arg1 arg2 SBV.sDiv m
_ -> translateBinaryError "div" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a)
lowerSinglePrimImpl config t@(ModIntegralTerm _ arg1 arg2) m =
case (config, R.typeRep @a) of
ResolvedSDivisibleType -> lowerBinaryTerm config t arg1 arg2 SBV.sMod m
_ -> translateBinaryError "mod" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a)
lowerSinglePrimImpl config t@(QuotIntegralTerm _ arg1 arg2) m =
case (config, R.typeRep @a) of
ResolvedSDivisibleType -> lowerBinaryTerm config t arg1 arg2 SBV.sQuot m
_ -> translateBinaryError "quot" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a)
lowerSinglePrimImpl config t@(RemIntegralTerm _ arg1 arg2) m =
case (config, R.typeRep @a) of
ResolvedSDivisibleType -> lowerBinaryTerm config t arg1 arg2 SBV.sRem m
_ -> translateBinaryError "rem" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a)
lowerSinglePrimImpl config t@(DivBoundedIntegralTerm _ arg1 arg2) m =
case (config, R.typeRep @a) of
(ResolvedConfig {}, IntegerType) -> lowerBinaryTerm config t arg1 arg2 SBV.sDiv m
ResolvedSDivisibleType -> lowerBinaryTerm config t arg1 arg2 SBV.sDiv m
_ -> translateBinaryError "div" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a)
lowerSinglePrimImpl config t@(ModIntegerTerm _ arg1 arg2) m =
lowerSinglePrimImpl config t@(ModBoundedIntegralTerm _ arg1 arg2) m =
case (config, R.typeRep @a) of
(ResolvedConfig {}, IntegerType) -> lowerBinaryTerm config t arg1 arg2 SBV.sMod m
ResolvedSDivisibleType -> lowerBinaryTerm config t arg1 arg2 SBV.sMod m
_ -> translateBinaryError "mod" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a)
lowerSinglePrimImpl config t@(QuotBoundedIntegralTerm _ arg1 arg2) m =
case (config, R.typeRep @a) of
ResolvedSDivisibleType -> lowerBinaryTerm config t arg1 arg2 SBV.sQuot m
_ -> translateBinaryError "quot" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a)
lowerSinglePrimImpl config t@(RemBoundedIntegralTerm _ arg1 arg2) m =
case (config, R.typeRep @a) of
ResolvedSDivisibleType -> lowerBinaryTerm config t arg1 arg2 SBV.sRem m
_ -> translateBinaryError "rem" (R.typeRep @a) (R.typeRep @a) (R.typeRep @a)
lowerSinglePrimImpl _ _ _ = error "Should never happen"

bvIsNonZeroFromGEq1 :: forall w r. (1 <= w) => ((SBV.BVIsNonZero w) => r) -> r
Expand Down Expand Up @@ -1171,6 +1219,34 @@ pattern ResolvedNumType ::
(GrisetteSMTConfig integerBitWidth, R.TypeRep s)
pattern ResolvedNumType <- (resolveNumTypeView -> Just DictNumType)

type SDivisibleTypeConstraint integerBitWidth s s' =
( SimpleTypeConstraint integerBitWidth s s',
SBV.SDivisible (SBV.SBV s'),
Integral s
)

data DictSDivisibleType integerBitWidth s where
DictSDivisibleType ::
forall integerBitWidth s s'.
(SDivisibleTypeConstraint integerBitWidth s s') =>
DictSDivisibleType integerBitWidth s

resolveSDivisibleTypeView :: TypeResolver DictSDivisibleType
resolveSDivisibleTypeView (ResolvedConfig {}, s) = case s of
IntegerType -> Just DictSDivisibleType
SignedBVType _ -> Just DictSDivisibleType
UnsignedBVType _ -> Just DictSDivisibleType
_ -> Nothing
resolveSDivisibleTypeView _ = error "Should never happen, make compiler happy"

pattern ResolvedSDivisibleType ::
forall integerBitWidth s.
(SupportedPrim s) =>
forall s'.
SDivisibleTypeConstraint integerBitWidth s s' =>
(GrisetteSMTConfig integerBitWidth, R.TypeRep s)
pattern ResolvedSDivisibleType <- (resolveSDivisibleTypeView -> Just DictSDivisibleType)

type NumOrdTypeConstraint integerBitWidth s s' =
( NumTypeConstraint integerBitWidth s s',
SBV.OrdSymbolic (SBV.SBV s'),
Expand Down
23 changes: 8 additions & 15 deletions src/Grisette/Core.hs
Expand Up @@ -190,9 +190,8 @@ module Grisette.Core
someBVExtract',
SizedBV (..),
sizedBVExtract,
SignedDivMod (..),
UnsignedDivMod (..),
SignedQuotRem (..),
SafeDivision (..),
SafeLinearArith (..),
SymIntegerOp,
Function (..),

Expand Down Expand Up @@ -794,6 +793,7 @@ module Grisette.Core
TransformError (..),
symAssert,
symAssume,
symAssertWith,
symAssertTransformableError,
symThrowTransformableError,

Expand Down Expand Up @@ -902,30 +902,23 @@ module Grisette.Core
-- deriving (Mergeable, SEq) via (Default Error)
-- :}
--
-- Then we define how to transform the generic errors to the error type.
--
-- >>> :{
-- instance TransformError ArithException Error where
-- transformError _ = Arith
-- instance TransformError AssertionError Error where
-- transformError _ = Assert
-- :}
--
-- Then we can perform the symbolic evaluation. The `divs` function throws
-- 'ArithException' when the divisor is 0, which would be transformed to
-- @Arith@, and the `symAssert` function would throw 'AssertionError' when
-- the condition is false, which would be transformed to @Assert@.
--
-- >>> let x = "x" :: SymInteger
-- >>> let y = "y" :: SymInteger
-- >>> assert = symAssertWith Assert
-- >>> sdiv = safeDiv' (const Arith)
-- >>> :{
-- -- equivalent concrete program:
-- -- let x = x `div` y
-- -- if z > 0 then assert (x >= y) else return ()
-- res :: ExceptT Error UnionM ()
-- res = do
-- z <- x `divs` y
-- mrgIf (z >~ 0) (symAssert (x >=~ y)) (return ())
-- z <- x `sdiv` y
-- mrgIf (z >~ 0) (assert (x >=~ y)) (return ())
-- :}
--
-- Then we can ask the solver to find a counter-example that would lead to
Expand Down Expand Up @@ -1033,10 +1026,10 @@ import Grisette.Core.Data.Class.Evaluate
import Grisette.Core.Data.Class.ExtractSymbolics
import Grisette.Core.Data.Class.Function
import Grisette.Core.Data.Class.GenSym
import Grisette.Core.Data.Class.Integer
import Grisette.Core.Data.Class.Mergeable
import Grisette.Core.Data.Class.ModelOps
import Grisette.Core.Data.Class.SOrd
import Grisette.Core.Data.Class.SafeArith
import Grisette.Core.Data.Class.SimpleMergeable
import Grisette.Core.Data.Class.Solvable
import Grisette.Core.Data.Class.Solver
Expand Down