Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for modulo on semirings #149

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/LeftPadTutorial.gr
Expand Up @@ -135,4 +135,4 @@ input = Cons 1 (Cons 2 (Cons 3 Nil))
-- vector length 6

main : Vec 6 Int
main = leftPad input [0] (S (S (S (S (S (S Z))))))
main = leftPad input [0] (S (S (S (S (S (S Z))))))
36 changes: 36 additions & 0 deletions examples/modSR.gr
@@ -0,0 +1,36 @@
-- examples using the 'Mod k' semiring

-- todo: support arbitrary nats for the Mod index

isAssocPlus1 : forall {a b : Type, n m k : Mod 7} . (a [(n + m) + k] -> b) -> (a [n + (m + k)] -> b)
isAssocPlus1 f = f

isAssocPlus2 : forall {a b : Type, n m k : Mod 7} . (a [n + (m + k)] -> b) -> (a [(n + m) + k] -> b)
isAssocPlus2 f = f

isAssocMult1 : forall {a b : Type, n m k : Mod 7} . (a [(n * m) * k] -> b) -> (a [n * (m * k)] -> b)
isAssocMult1 f = f

isAssocMult2 : forall {a b : Type, n m k : Mod 7} . (a [n * (m * k)] -> b) -> (a [(n * m) * k] -> b)
isAssocMult2 f = f

isDistrib1 : forall {a b : Type, n m k : Mod 7} . (a [(n + m) * k] -> b) -> (a [(n * k) + (m * k)] -> b)
isDistrib1 f = f

isDistrib2 : forall {a b : Type, n m k : Mod 7} . (a [(n * k) + (m * k)] -> b) -> (a [(n + m) * k] -> b)
isDistrib2 f = f

isCommutePlus : forall {a b : Type, n m : Mod 7} . (a [n + m] -> b) -> (a [m + n] -> b)
isCommutePlus f = f

isCommuteMult : forall {a b : Type, n m : Mod 7} . (a [n * m] -> b) -> (a [m * n] -> b)
isCommuteMult f = f

zeroPlusIdentity : forall {a b : Type, n : Mod 7} . (a [n + (0 : Mod 7)] -> b) -> (a [n] -> b)
zeroPlusIdentity f = f

oneTimesIdentity : forall {a b : Type, n : Mod 7} . (a [n * (1 : Mod 7)] -> b) -> (a [n] -> b)
oneTimesIdentity f = f

zeroAbsorbs : forall {a b : Type, n : Mod 7} . (a [n * (0 : Mod 7)] -> b) -> (a [0 : Mod 7] -> b)
zeroAbsorbs f = f
8 changes: 4 additions & 4 deletions examples/ooz.gr
@@ -1,13 +1,13 @@
-- examples using a semiring {0, 1} with 1 + 1 = 0

oozZero : forall a b. b -> a [0 : OOZ] -> b
oozZero : forall a b. b -> a [0 : Mod 2] -> b
oozZero f [x] = f

oozOne : forall a b. (a -> b) -> a [1 : OOZ] -> b
oozOne : forall a b. (a -> b) -> a [1 : Mod 2] -> b
oozOne f [x] = f x

oozTwo : forall a b. (f : a -> a -> b) -> a [0 : OOZ] -> b
oozTwo : forall a b. (f : a -> a -> b) -> a [0 : Mod 2] -> b
oozTwo f [x] = f x x

oozThree : forall a b. (f : a -> a -> a -> b) -> a [1 : OOZ] -> b
oozThree : forall a b. (f : a -> a -> a -> b) -> a [1 : Mod 2] -> b
oozThree f [x] = f x x x
23 changes: 18 additions & 5 deletions frontend/src/Language/Granule/Checker/Constraints.hs
Expand Up @@ -148,6 +148,10 @@ freshCVarScoped quant name (isInterval -> Just t) q k =
, SInterval solverVarLb solverVarUb )
))

freshCVarScoped quant name (isMod -> Just n) q k =
quant q name (\solverVar ->
k (solverVar .>= 0 .&& (fromIntegral n) .> (0 :: SInteger), SMod solverVar n))

freshCVarScoped quant name (isProduct -> Just (t1, t2)) q k =

freshCVarScoped quant (name <> ".fst") t1 q
Expand All @@ -169,7 +173,6 @@ freshCVarScoped quant name (TyCon conName) q k =
.|| solverVar .== literal publicRepresentation
.|| solverVar .== literal unusedRepresentation
, SLevel solverVar)
"OOZ" -> k (solverVar .== 0 .|| solverVar .== 1, SOOZ (ite (solverVar .== 0) sFalse sTrue))
k -> solverError $ "I don't know how to make a fresh solver variable of type " <> show conName)

freshCVarScoped quant name t q k | t == extendedNat = do
Expand Down Expand Up @@ -313,6 +316,9 @@ compileCoeffect (CNat n) k _ | k == nat =
compileCoeffect (CNat n) k _ | k == extendedNat =
return (SExtNat . fromInteger . toInteger $ n, sTrue)

compileCoeffect (CNat n) (isMod -> Just i) _ =
pure (SMod (fromInteger . toInteger $ n) i, sTrue)

compileCoeffect (CFloat r) (TyCon k) _ | internalName k == "Q" =
return (SFloat . fromRational $ r, sTrue)

Expand All @@ -339,6 +345,9 @@ compileCoeffect c@(CTimes n m) k vars =
compileCoeffect c@(CMinus n m) k vars =
bindM2And symGradeMinus (compileCoeffect n k vars) (compileCoeffect m k vars)

compileCoeffect c@(CMod n m) k vars =
bindM2And symGradeMod (compileCoeffect n k vars) (compileCoeffect m k vars)

compileCoeffect c@(CExpon n m) k vars = do
(g1, p1) <- compileCoeffect n k vars
(g2, p2) <- compileCoeffect m k vars
Expand All @@ -360,7 +369,6 @@ compileCoeffect (CZero k') k vars =
"Nat" -> return (SNat 0, sTrue)
"Q" -> return (SFloat (fromRational 0), sTrue)
"Set" -> return (SSet (S.fromList []), sTrue)
"OOZ" -> return (SOOZ sFalse, sTrue)
_ -> solverError $ "I don't know how to compile a 0 for " <> pretty k'
(otherK', otherK) | (otherK' == extendedNat || otherK == extendedNat) ->
return (SExtNat 0, sTrue)
Expand All @@ -375,6 +383,9 @@ compileCoeffect (CZero k') k vars =
(compileCoeffect (CZero t) t' vars)
(compileCoeffect (CZero t) t' vars)

(isMod -> Just i, isMod -> Just j) | i == j ->
pure (SMod 0 i, sTrue)

(TyVar _, _) -> return (SUnknown (SynLeaf (Just 0)), sTrue)
_ -> solverError $ "I don't know how to compile a 0 for " <> pretty k'

Expand All @@ -386,7 +397,6 @@ compileCoeffect (COne k') k vars =
"Nat" -> return (SNat 1, sTrue)
"Q" -> return (SFloat (fromRational 1), sTrue)
"Set" -> return (SSet (S.fromList []), sTrue)
"OOZ" -> return (SOOZ sTrue, sTrue)
_ -> solverError $ "I don't know how to compile a 1 for " <> pretty k'

(otherK', otherK) | (otherK' == extendedNat || otherK == extendedNat) ->
Expand All @@ -403,6 +413,9 @@ compileCoeffect (COne k') k vars =
(compileCoeffect (COne t) t' vars)
(compileCoeffect (COne t) t' vars)

(isMod -> Just i, isMod -> Just j) | i == j ->
pure (SMod 1 i, sTrue)

(TyVar _, _) -> return (SUnknown (SynLeaf (Just 1)), sTrue)

_ -> solverError $ "I don't know how to compile a 1 for " <> pretty k'
Expand Down Expand Up @@ -453,7 +466,7 @@ approximatedByOrEqualConstraint (SNat n) (SNat m) = return $ n .== m
approximatedByOrEqualConstraint (SFloat n) (SFloat m) = return $ n .<= m
approximatedByOrEqualConstraint SPoint SPoint = return $ sTrue
approximatedByOrEqualConstraint (SExtNat x) (SExtNat y) = return $ x .== y
approximatedByOrEqualConstraint (SOOZ s) (SOOZ r) = pure $ s .== r
approximatedByOrEqualConstraint (SMod s i) (SMod r j) | i == j = pure $ sMod s (fromIntegral i) .== sMod r (fromIntegral i)
approximatedByOrEqualConstraint (SSet s) (SSet t) =
return $ if s == t then sTrue else sFalse

Expand Down Expand Up @@ -618,4 +631,4 @@ bindM2And' k ma mb = do
return (p .&& q .&& c)

liftM2And :: Monad m => (t1 -> t2 -> b) -> m (t1, SBool) -> m (t2, SBool) -> m (b, SBool)
liftM2And k = bindM2And (\a b -> return (k a b))
liftM2And k = bindM2And (\a b -> return (k a b))
25 changes: 14 additions & 11 deletions frontend/src/Language/Granule/Checker/Constraints/SymbolicGrades.hs
Expand Up @@ -35,10 +35,8 @@ data SGrade =
-- Single point coeffect (not exposed at the moment)
| SPoint
| SProduct { sfst :: SGrade, ssnd :: SGrade }
-- | Coeffect with 1 + 1 = 0. False is 0, True is 1.
-- |
-- | Grade '0' denotes even usage, and grade '1' denotes odd usage.
| SOOZ SBool
-- | SMod k n is a Nat @k mod n@.
| SMod SInteger Int

-- A kind of embedded uninterpreted sort which can accept some equations
-- Used for doing some limited solving over poly coeffect grades
Expand Down Expand Up @@ -98,7 +96,7 @@ sLtTree (SynPlus s s') (SynPlus t t') = liftM2 (.&&) (sLtTree s t) (sLtTree s'
sLtTree (SynTimes s s') (SynTimes t t') = liftM2 (.&&) (sLtTree s t) (sLtTree s' t')
sLtTree (SynMeet s s') (SynMeet t t') = liftM2 (.&&) (sLtTree s t) (sLtTree s' t')
sLtTree (SynJoin s s') (SynJoin t t') = liftM2 (.&&) (sLtTree s t) (sLtTree s' t')
sLtTree (SynMerge sb s s') (SynMerge sb' t t') =
sLtTree (SynMerge sb s s') (SynMerge sb' t t') =
liftM2 (.&&) (return $ sb .== sb') (liftM2 (.&&) (sLtTree s t) (sLtTree s' t'))
sLtTree (SynLeaf Nothing) (SynLeaf Nothing) = return $ sFalse
sLtTree (SynLeaf (Just n)) (SynLeaf (Just n')) = return $ n .< n'
Expand All @@ -115,7 +113,7 @@ match (SInterval s1 s2) (SInterval t1 t2) = match s1 t1 && match t1 t2
match SPoint SPoint = True
match (SProduct s1 s2) (SProduct t1 t2) = match s1 t1 && match s2 t2
match (SUnknown _) (SUnknown _) = True
match (SOOZ _) (SOOZ _) = True
match (SMod _ i) (SMod _ j) = i == j
match _ _ = False

isSProduct :: SGrade -> Bool
Expand Down Expand Up @@ -219,7 +217,7 @@ symGradeEq (SLevel n) (SLevel n') = return $ n .== n'
symGradeEq (SSet n) (SSet n') = solverError "Can't compare symbolic sets yet"
symGradeEq (SExtNat n) (SExtNat n') = return $ n .== n'
symGradeEq SPoint SPoint = return $ sTrue
symGradeEq (SOOZ s) (SOOZ r) = pure $ s .== r
symGradeEq (SMod s i) (SMod r j) | i == j = pure $ sMod s (fromIntegral i) .== sMod r (fromIntegral i)
symGradeEq s t | isSProduct s || isSProduct t =
either solverError id (applyToProducts symGradeEq (.&&) (const sTrue) s t)

Expand Down Expand Up @@ -276,8 +274,7 @@ symGradePlus (SInterval lb1 ub1) (SInterval lb2 ub2) =
symGradePlus SPoint SPoint = return $ SPoint
symGradePlus s t | isSProduct s || isSProduct t =
either solverError id (applyToProducts symGradePlus SProduct id s t)
-- 1 + 1 = 0
symGradePlus (SOOZ s) (SOOZ r) = pure . SOOZ $ ite s (sNot r) r
symGradePlus (SMod s i) (SMod r j) | i == j = pure $ SMod (s + r) i

-- Direct encoding of additive unit
symGradePlus (SUnknown t@(SynLeaf (Just u))) (SUnknown t'@(SynLeaf (Just u'))) =
Expand Down Expand Up @@ -309,7 +306,7 @@ symGradeTimes (SLevel lev1) (SLevel lev2) = return $
(SLevel $ lev1 `smax` lev2)
symGradeTimes (SFloat n1) (SFloat n2) = return $ SFloat $ n1 * n2
symGradeTimes (SExtNat x) (SExtNat y) = return $ SExtNat (x * y)
symGradeTimes (SOOZ s) (SOOZ r) = pure . SOOZ $ s .&& r
symGradeTimes (SMod s i) (SMod r j) | i == j = pure $ SMod (s * r) i

symGradeTimes (SInterval lb1 ub1) (SInterval lb2 ub2) =
liftM2 SInterval (comb symGradeMeet) (comb symGradeJoin)
Expand Down Expand Up @@ -363,6 +360,12 @@ symGradeMinus s t | isSProduct s || isSProduct t =
either solverError id (applyToProducts symGradeMinus SProduct id s t)
symGradeMinus s t = solverError $ cannotDo "minus" s t

-- | Mod operation on symbolic grades
symGradeMod :: SGrade -> SGrade -> Symbolic SGrade
-- TODO: perhaps fail here if dividing by 0? (2020-07-10)
symGradeMod s@(SNat n1) t@(SNat n2) = pure . SNat $ sMod n1 n2
symGradeMod s t = solverError $ cannotDo "mod" s t

cannotDo :: String -> SGrade -> SGrade -> String
cannotDo op (SUnknown s) (SUnknown t) =
"It is unknown whether "
Expand All @@ -375,4 +378,4 @@ cannotDo op s t =
"Cannot perform symbolic operation `"
<> op <> "` on "
<> show s <> " and "
<> show t
<> show t
3 changes: 2 additions & 1 deletion frontend/src/Language/Granule/Checker/Kinds.hs
Expand Up @@ -242,6 +242,7 @@ inferCoeffectTypeInContext s _ (CTimes c c') = fmap fst2 $ mguCoeffectTypesFromC
inferCoeffectTypeInContext s _ (CMeet c c') = fmap fst2 $ mguCoeffectTypesFromCoeffects s c c'
inferCoeffectTypeInContext s _ (CJoin c c') = fmap fst2 $ mguCoeffectTypesFromCoeffects s c c'
inferCoeffectTypeInContext s _ (CExpon c c') = fmap fst2 $ mguCoeffectTypesFromCoeffects s c c'
inferCoeffectTypeInContext s _ (CMod c c') = fmap fst2 $ mguCoeffectTypesFromCoeffects s c c'

-- Coeffect variables should have a type in the cvar->kind context
inferCoeffectTypeInContext s ctxt (CVar cvar) = do
Expand Down Expand Up @@ -322,4 +323,4 @@ isEffectTypeFromKind s kind =
if isEffectKind kind'
then return $ Right effTy
else return $ Left kind
_ -> return $ Left kind
_ -> return $ Left kind
2 changes: 1 addition & 1 deletion frontend/src/Language/Granule/Checker/Primitives.hs
Expand Up @@ -40,7 +40,7 @@ typeConstructors =
, (mkId "Private", (KPromote (TyCon $ mkId "Level"), [], False))
, (mkId "Public", (KPromote (TyCon $ mkId "Level"), [], False))
, (mkId "Unused", (KPromote (TyCon $ mkId "Level"), [], False))
, (mkId "OOZ", (KCoeffect, [], False)) -- 1 + 1 = 0
, (mkId "Mod", (KFun (KPromote (TyCon $ mkId "Nat")) KCoeffect, [], False)) -- modulo semiring
, (mkId "Interval", (KFun KCoeffect KCoeffect, [], False))
, (mkId "Set", (KFun (KVar $ mkId "k") (KFun (kConstr $ mkId "k") KCoeffect), [], False))
-- Channels and protocol types
Expand Down
12 changes: 11 additions & 1 deletion frontend/src/Language/Granule/Checker/Substitution.hs
Expand Up @@ -130,6 +130,11 @@ instance Substitutable Coeffect where
c2' <- substitute subst c2
return $ CProduct c1' c2'

substitute subst (CMod c1 c2) = do
c1' <- substitute subst c1
c2' <- substitute subst c2
return $ CMod c1' c2'

substitute subst (CVar v) =
case lookup v subst of
Just (SubstC c) -> do
Expand Down Expand Up @@ -672,6 +677,11 @@ instance Unifiable Coeffect where
u2 <- unify c2 c2'
u1 <<>> u2

unify (CMod c1 c2) (CMod c1' c2') = do
u1 <- unify c1 c1'
u2 <- unify c2 c2'
u1 <<>> u2

unify (CInfinity k) (CInfinity k') = do
unify k k'

Expand Down Expand Up @@ -747,4 +757,4 @@ updateTyVar s tyVar k = do
| tyVar == kindVar = (name, (k, q)) : rewriteCtxt ctxt
rewriteCtxt ((name, (KVar kindVar, q)) : ctxt)
| tyVar == kindVar = (name, (k, q)) : rewriteCtxt ctxt
rewriteCtxt (x : ctxt) = x : rewriteCtxt ctxt
rewriteCtxt (x : ctxt) = x : rewriteCtxt ctxt
2 changes: 2 additions & 0 deletions frontend/src/Language/Granule/Syntax/Lexer.x
Expand Up @@ -86,6 +86,7 @@ tokens :-
\_ { \p _ -> TokenUnderscore p }
\| { \p s -> TokenPipe p }
\/ { \p s -> TokenForwardSlash p }
"%" { \p s -> TokenPercent p }
"≤" { \p s -> TokenLesserEq p }
"<=" { \p s -> TokenLesserEq p }
"≥" { \p s -> TokenGreaterEq p }
Expand Down Expand Up @@ -170,6 +171,7 @@ data Token
| TokenEmptyHole AlexPosn
| TokenHoleStart AlexPosn
| TokenHoleEnd AlexPosn
| TokenPercent AlexPosn

deriving (Eq, Show, Generic)

Expand Down
4 changes: 3 additions & 1 deletion frontend/src/Language/Granule/Syntax/Parser.y
Expand Up @@ -91,6 +91,7 @@ import Language.Granule.Utils hiding (mkSpan)
'.' { TokenPeriod _ }
'`' { TokenBackTick _ }
'^' { TokenCaret _ }
'%' { TokenPercent _ }
'..' { TokenDotDot _ }
"\\/" { TokenJoin _ }
"/\\" { TokenMeet _ }
Expand Down Expand Up @@ -323,7 +324,7 @@ TyApp :: { Type }
| TyAtom { $1 }

TyJuxt :: { Type }
: TyJuxt '`' TyAtom '`' { TyApp $3 $1 }
: TyJuxt '`' TyAtom '`' { TyApp $3 $1 }
| TyJuxt TyAtom { TyApp $1 $2 }
| TyAtom { $1 }
| TyAtom '+' TyAtom { TyInfix TyOpPlus $1 $3 }
Expand Down Expand Up @@ -369,6 +370,7 @@ Coeffect :: { Coeffect }
| Coeffect '*' Coeffect { CTimes $1 $3 }
| Coeffect '-' Coeffect { CMinus $1 $3 }
| Coeffect '^' Coeffect { CExpon $1 $3 }
| Coeffect '%' Coeffect { CMod $1 $3 }
| Coeffect "/\\" Coeffect { CMeet $1 $3 }
| Coeffect "\\/" Coeffect { CJoin $1 $3 }
| '(' Coeffect ')' { $2 }
Expand Down
2 changes: 2 additions & 0 deletions frontend/src/Language/Granule/Syntax/Pretty.hs
Expand Up @@ -88,6 +88,8 @@ instance Pretty Coeffect where
prettyNested c <> " * " <> prettyNested d
pretty (CMinus c d) =
prettyNested c <> " - " <> prettyNested d
pretty (CMod c1 c2) =
prettyNested c1 <> " % " <> prettyNested c2
pretty (CSet xs) =
"{" <> intercalate "," (map (\(name, t) -> name <> " : " <> prettyNested t) xs) <> "}"
pretty (CSig c t) =
Expand Down