Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): Update about modulo #192

Merged
merged 3 commits into from
Aug 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Jikka.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ library
Jikka.Core.Language.FreeVars
Jikka.Core.Language.LambdaPatterns
Jikka.Core.Language.Lint
Jikka.Core.Language.ModuloExpr
Jikka.Core.Language.QuasiRules
Jikka.Core.Language.RewriteRules
Jikka.Core.Language.Runtime
Expand Down
215 changes: 88 additions & 127 deletions src/Jikka/Core/Convert/PropagateMod.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module Jikka.Core.Convert.PropagateMod
)
where

import Control.Monad.Trans.Maybe
import Data.List
import Data.Maybe
import Jikka.Common.Alpha
Expand All @@ -23,175 +24,135 @@ import Jikka.Core.Language.Beta
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.Lint
import Jikka.Core.Language.ModuloExpr
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.TypeCheck
import Jikka.Core.Language.Util

-- | `Mod` is a newtype to avoid mistakes that swapping left and right of mod-op.
newtype Mod = Mod Expr
isModulo' :: Expr -> Expr -> Bool
isModulo' e m = e `isModulo` Modulo m

isModulo' :: Expr -> Mod -> Bool
isModulo' e (Mod m) = case e of
FloorMod' _ m' -> m' == m
ModNegate' _ m' -> m' == m
ModPlus' _ _ m' -> m' == m
ModMinus' _ _ m' -> m' == m
ModMult' _ _ m' -> m' == m
ModInv' _ m' -> m' == m
ModPow' _ _ m' -> m' == m
VecFloorMod' _ _ m' -> m' == m
MatFloorMod' _ _ _ m' -> m' == m
ModMatAp' _ _ _ _ m' -> m' == m
ModMatAdd' _ _ _ _ m' -> m' == m
ModMatMul' _ _ _ _ _ m' -> m' == m
ModMatPow' _ _ _ m' -> m' == m
ModSum' _ m' -> m' == m
ModProduct' _ m' -> m' == m
LitInt' n -> case m of
LitInt' m -> 0 <= n && n < m
_ -> False
Proj' ts _ e | isVectorTy' ts -> e `isModulo'` Mod m
Proj' ts _ e | isMatrixTy' ts -> e `isModulo'` Mod m
Map' _ _ f _ -> f `isModulo'` Mod m
Lam _ _ body -> body `isModulo'` Mod m
e@(App _ _) -> case curryApp e of
(e@(Lam _ _ _), _) -> e `isModulo'` Mod m
(Tuple' ts, es) | isVectorTy' ts -> all (`isModulo'` Mod m) es
(Tuple' ts, es) | isMatrixTy' ts -> all (`isModulo'` Mod m) es
_ -> False
_ -> False
putFloorMod :: MonadAlpha m => Modulo -> Expr -> m (Maybe Expr)
putFloorMod (Modulo m) =
runMaybeT . \case
Negate' e -> return $ ModNegate' e m
Plus' e1 e2 -> return $ ModPlus' e1 e2 m
Minus' e1 e2 -> return $ ModMinus' e1 e2 m
Mult' e1 e2 -> return $ ModMult' e1 e2 m
JustDiv' e1 e2 -> return $ ModMult' e1 (ModInv' e2 m) m
Pow' e1 e2 -> return $ ModPow' e1 e2 m
MatAp' h w e1 e2 -> return $ ModMatAp' h w e1 e2 m
MatAdd' h w e1 e2 -> return $ ModMatAdd' h w e1 e2 m
MatMul' h n w e1 e2 -> return $ ModMatMul' h n w e1 e2 m
MatPow' n e1 e2 -> return $ ModMatPow' n e1 e2 m
Sum' e -> return $ ModSum' e m
Product' e -> return $ ModProduct' e m
LitInt' n -> case m of
LitInt' m -> return $ LitInt' (n `mod` m)
_ -> MaybeT $ return Nothing
Proj' ts i e | isVectorTy' ts -> return $ Proj' ts i (VecFloorMod' (genericLength ts) e m)
Proj' ts i e
| isMatrixTy' ts ->
let (h, w) = fromJust (sizeOfMatrixTy (TupleTy ts))
in return $ Proj' ts i (MatFloorMod' (toInteger h) (toInteger w) e m)
Map' t1 t2 f xs -> do
f <- MaybeT $ putFloorMod (Modulo m) f
return $ Map' t1 t2 f xs
Foldl' t1 t2 f init xs -> do
f <- MaybeT $ putFloorMod (Modulo m) f
return $ Foldl' t1 t2 f init xs
Lam x t body -> do
-- TODO: rename only if required
y <- lift $ genVarName x
body <- lift $ substitute x (Var y) body
body <- MaybeT $ putFloorMod (Modulo m) body
return $ Lam y t body
e@(App _ _) -> case curryApp e of
(f@(Lam _ _ _), args) -> do
f <- MaybeT $ putFloorMod (Modulo m) f
return $ uncurryApp f args
(Tuple' ts, es) | isVectorTy' ts -> do
es' <- lift $ mapM (putFloorMod (Modulo m)) es
if all isNothing es'
then MaybeT $ return Nothing
else return $ uncurryApp (Tuple' ts) (zipWith fromMaybe es es')
(Tuple' ts, es) | isMatrixTy (TupleTy ts) -> do
es' <- lift $ mapM (putFloorMod (Modulo m)) es
if all isNothing es'
then MaybeT $ return Nothing
else return $ uncurryApp (Tuple' ts) (zipWith fromMaybe es es')
_ -> MaybeT $ return Nothing
_ -> MaybeT $ return Nothing

isModulo :: Expr -> Expr -> Bool
isModulo e m = e `isModulo'` Mod m

putFloorMod :: MonadAlpha m => Mod -> Expr -> m (Maybe Expr)
putFloorMod (Mod m) =
let return' = return . Just
in \case
Negate' e -> return' $ ModNegate' e m
Plus' e1 e2 -> return' $ ModPlus' e1 e2 m
Minus' e1 e2 -> return' $ ModMinus' e1 e2 m
Mult' e1 e2 -> return' $ ModMult' e1 e2 m
JustDiv' e1 e2 -> return' $ ModMult' e1 (ModInv' e2 m) m
Pow' e1 e2 -> return' $ ModPow' e1 e2 m
MatAp' h w e1 e2 -> return' $ ModMatAp' h w e1 e2 m
MatAdd' h w e1 e2 -> return' $ ModMatAdd' h w e1 e2 m
MatMul' h n w e1 e2 -> return' $ ModMatMul' h n w e1 e2 m
MatPow' n e1 e2 -> return' $ ModMatPow' n e1 e2 m
Sum' e -> return' $ ModSum' e m
Product' e -> return' $ ModProduct' e m
LitInt' n -> case m of
LitInt' m -> return' $ LitInt' (n `mod` m)
_ -> return Nothing
Proj' ts i e | isVectorTy' ts -> return' $ Proj' ts i (VecFloorMod' (genericLength ts) e m)
Proj' ts i e
| isMatrixTy' ts ->
let (h, w) = fromJust (sizeOfMatrixTy (TupleTy ts))
in return' $ Proj' ts i (MatFloorMod' (toInteger h) (toInteger w) e m)
Map' t1 t2 f xs -> do
f <- putFloorMod (Mod m) f
case f of
Nothing -> return Nothing
Just f -> return' $ Map' t1 t2 f xs
Foldl' t1 t2 f init xs -> do
f <- putFloorMod (Mod m) f
case f of
Nothing -> return Nothing
Just f -> return' $ Foldl' t1 t2 f init xs
Lam x t body -> do
-- TODO: rename only if required
y <- genVarName x
body <- substitute x (Var y) body
body <- putFloorMod (Mod m) body
case body of
Nothing -> return Nothing
Just body -> return' $ Lam y t body
e@(App _ _) -> case curryApp e of
(f@(Lam _ _ _), args) -> do
f <- putFloorMod (Mod m) f
case f of
Nothing -> return Nothing
Just f -> return' $ uncurryApp f args
(Tuple' ts, es) | isVectorTy' ts -> do
es' <- mapM (putFloorMod (Mod m)) es
if all isNothing es'
then return Nothing
else return' $ uncurryApp (Tuple' ts) (zipWith fromMaybe es es')
(Tuple' ts, es) | isMatrixTy (TupleTy ts) -> do
es' <- mapM (putFloorMod (Mod m)) es
if all isNothing es'
then return Nothing
else return' $ uncurryApp (Tuple' ts) (zipWith fromMaybe es es')
_ -> return Nothing
_ -> return Nothing

putFloorModGeneric :: MonadAlpha m => (Expr -> Mod -> m Expr) -> Mod -> Expr -> m Expr
putFloorModGeneric :: MonadAlpha m => (Expr -> Modulo -> m Expr) -> Modulo -> Expr -> m Expr
putFloorModGeneric fallback m e =
if e `isModulo'` m
if e `isModulo` m
then return e
else do
e' <- putFloorMod m e
case e' of
Just e' -> return e'
Nothing -> fallback e m

putFloorModInt :: MonadAlpha m => Mod -> Expr -> m Expr
putFloorModInt = putFloorModGeneric (\e (Mod m) -> return $ FloorMod' e m)

putMapFloorMod :: MonadAlpha m => Mod -> Expr -> m Expr
putMapFloorMod :: MonadAlpha m => Modulo -> Expr -> m Expr
putMapFloorMod = putFloorModGeneric fallback
where
fallback e (Mod m) = do
fallback e (Modulo m) = do
x <- genVarName'
return $ Map' IntTy IntTy (Lam x IntTy (FloorMod' (Var x) m)) e

putVecFloorMod :: (MonadError Error m, MonadAlpha m) => [(VarName, Type)] -> Mod -> Expr -> m Expr
putVecFloorMod :: (MonadError Error m, MonadAlpha m) => [(VarName, Type)] -> Modulo -> Expr -> m Expr
putVecFloorMod env = putFloorModGeneric fallback
where
fallback e (Mod m) = do
fallback e (Modulo m) = do
t <- typecheckExpr env e
case t of
TupleTy ts -> return $ VecFloorMod' (genericLength ts) e m
_ -> throwInternalError $ "not a vector: " ++ formatType t

putMatFloorMod :: (MonadError Error m, MonadAlpha m) => [(VarName, Type)] -> Mod -> Expr -> m Expr
putMatFloorMod :: (MonadError Error m, MonadAlpha m) => [(VarName, Type)] -> Modulo -> Expr -> m Expr
putMatFloorMod env = putFloorModGeneric fallback
where
fallback e (Mod m) = do
fallback e (Modulo m) = do
t <- typecheckExpr env e
case t of
TupleTy ts@(TupleTy ts' : _) -> return $ MatFloorMod' (genericLength ts) (genericLength ts') e m
_ -> throwInternalError $ "not a matrix: " ++ formatType t

rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule =
let go1 m f (t1, e1) = Just <$> (f <$> t1 (Mod m) e1 <*> pure m)
go2 m f (t1, e1) (t2, e2) = Just <$> (f <$> t1 (Mod m) e1 <*> t2 (Mod m) e2 <*> pure m)
let go0 :: Expr -> Maybe Expr
go0 e = do
e' <- formatBottomModuloExpr <$> parseModuloExpr e
guard $ e' /= e
return e'
go1 m f (t1, e1) = Just <$> (f <$> t1 (Modulo m) e1 <*> pure m)
go2 m f (t1, e1) (t2, e2) = Just <$> (f <$> t1 (Modulo m) e1 <*> t2 (Modulo m) e2 <*> pure m)
in makeRewriteRule "Jikka.Core.Convert.PropagateMod" $ \env -> \case
ModNegate' e m | not (e `isModulo` m) -> go1 m ModNegate' (putFloorModInt, e)
ModPlus' e1 e2 m | not (e1 `isModulo` m) || not (e2 `isModulo` m) -> go2 m ModPlus' (putFloorModInt, e1) (putFloorModInt, e2)
ModMinus' e1 e2 m | not (e1 `isModulo` m) || not (e2 `isModulo` m) -> go2 m ModMinus' (putFloorModInt, e1) (putFloorModInt, e2)
ModMult' e1 e2 m | not (e1 `isModulo` m) || not (e2 `isModulo` m) -> go2 m ModMult' (putFloorModInt, e1) (putFloorModInt, e2)
ModInv' e m | not (e `isModulo` m) -> go1 m ModInv' (putFloorModInt, e)
ModPow' e1 e2 m | not (e1 `isModulo` m) -> go2 m ModPow' (putFloorModInt, e1) (\_ e -> return e, e2)
ModMatAp' h w e1 e2 m | not (e1 `isModulo` m) || not (e2 `isModulo` m) -> go2 m (ModMatAp' h w) (putMatFloorMod env, e1) (putVecFloorMod env, e2)
ModMatAdd' h w e1 e2 m | not (e1 `isModulo` m) || not (e2 `isModulo` m) -> go2 m (ModMatAdd' h w) (putMatFloorMod env, e1) (putMatFloorMod env, e2)
ModMatMul' h n w e1 e2 m | not (e1 `isModulo` m) || not (e2 `isModulo` m) -> go2 m (ModMatMul' h n w) (putMatFloorMod env, e1) (putMatFloorMod env, e2)
ModMatPow' n e1 e2 m | not (e1 `isModulo` m) -> go2 m (ModMatPow' n) (putMatFloorMod env, e1) (\_ e -> return e, e2)
ModSum' e m | not (e `isModulo` m) -> go1 m ModSum' (putMapFloorMod, e)
ModProduct' e m | not (e `isModulo` m) -> go1 m ModProduct' (putMapFloorMod, e)
e@(ModNegate' _ _) -> return $ go0 e
e@(ModPlus' _ _ _) -> return $ go0 e
e@(ModMinus' _ _ _) -> return $ go0 e
e@(ModMult' _ _ _) -> return $ go0 e
e@(ModInv' _ _) -> return $ go0 e
e@(ModPow' _ _ _) -> return $ go0 e
ModMatAp' h w e1 e2 m | not (e1 `isModulo'` m) || not (e2 `isModulo'` m) -> go2 m (ModMatAp' h w) (putMatFloorMod env, e1) (putVecFloorMod env, e2)
ModMatAdd' h w e1 e2 m | not (e1 `isModulo'` m) || not (e2 `isModulo'` m) -> go2 m (ModMatAdd' h w) (putMatFloorMod env, e1) (putMatFloorMod env, e2)
ModMatMul' h n w e1 e2 m | not (e1 `isModulo'` m) || not (e2 `isModulo'` m) -> go2 m (ModMatMul' h n w) (putMatFloorMod env, e1) (putMatFloorMod env, e2)
ModMatPow' n e1 e2 m | not (e1 `isModulo'` m) -> go2 m (ModMatPow' n) (putMatFloorMod env, e1) (\_ e -> return e, e2)
ModSum' e m | not (e `isModulo'` m) -> go1 m ModSum' (putMapFloorMod, e)
ModProduct' e m | not (e `isModulo'` m) -> go1 m ModProduct' (putMapFloorMod, e)
FloorMod' e m ->
if e `isModulo` m
if e `isModulo'` m
then return $ Just e
else putFloorMod (Mod m) e
else putFloorMod (Modulo m) e
VecFloorMod' _ e m ->
if e `isModulo` m
if e `isModulo'` m
then return $ Just e
else putFloorMod (Mod m) e
else putFloorMod (Modulo m) e
MatFloorMod' _ _ e m ->
if e `isModulo` m
if e `isModulo'` m
then return $ Just e
else putFloorMod (Mod m) e
else putFloorMod (Modulo m) e
_ -> return Nothing

runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
Expand Down
40 changes: 32 additions & 8 deletions src/Jikka/Core/Convert/SpecializeFoldl.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module : Jikka.Core.Convert.SpecializeFoldl
Expand All @@ -22,19 +23,46 @@ where

import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Language.ArithmeticExpr
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.ModuloExpr
import Jikka.Core.Language.RewriteRules

convertToSum :: Expr -> Maybe Expr
convertToSum = \case
Foldl' t1 IntTy (Lam2 x2 _ x1 _ body) init xs -> do
(a, b) <- makeAffineFunctionFromArithmeticExpr x2 (parseArithmeticExpr body)
guard $ isOneArithmeticExpr a
return $ Plus' init (Sum' (Map' t1 IntTy (Lam x1 t1 (formatArithmeticExpr b)) xs))
_ -> Nothing

convertToModSum :: Expr -> Maybe Expr
convertToModSum = \case
Foldl' t1 IntTy (Lam2 x2 _ x1 _ body) init xs -> do
body <- parseModuloExpr body
(a, b) <- makeAffineFunctionFromArithmeticExpr x2 (arithmeticExprFromModuloExpr body)
guard $ isOneArithmeticExpr a

-- `if` is required for cases like `foldl (fun y x -> y % 2) 3 xs`, which is the same to `if xs == nil then 3 else 1`.
let wrap :: Expr -> Expr
wrap =
if init `isModulo` Modulo (moduloOfModuloExpr body)
then id
else If' IntTy (Equal' (ListTy t1) xs (Nil' t1)) init

return . wrap $
ModPlus' init (ModSum' (Map' t1 IntTy (Lam x1 t1 (formatArithmeticExpr b)) xs) (moduloOfModuloExpr body)) (moduloOfModuloExpr body)
_ -> Nothing

rule :: MonadAlpha m => RewriteRule m
rule = simpleRewriteRule "Jikka.Core.Convert.SpecializeFoldl" $ \case
(convertToSum -> Just e) -> return e
(convertToModSum -> Just e) -> return e
-- TODO: Replace these operators with the better implementation like sum.
Foldl' t1 t2 (Lam2 x2 _ x1 _ body) init xs -> case body of
-- Sum
Plus' (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Sum' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
Plus' e (Var x2') | x2' == x2 && x2 `isUnusedVar` e -> Just $ Sum' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
Minus' (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Minus' init (Sum' (Map' t1 t2 (Lam x1 t1 e) xs))
-- Product
Mult' (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Product' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
Mult' e (Var x2') | x2' == x2 && x2 `isUnusedVar` e -> Just $ Product' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
Expand All @@ -60,10 +88,6 @@ rule = simpleRewriteRule "Jikka.Core.Convert.SpecializeFoldl" $ \case
_ -> Nothing
-- The outer floor-mod is required because foldl for empty lists returns values without modulo.
FloorMod' (Foldl' t1 t2 (Lam2 x2 _ x1 _ body) init xs) m -> case body of
-- ModSum
ModPlus' (Var x2') e m' | x2' == x2 && x2 `isUnusedVar` e && m' == m -> Just $ ModSum' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) m
ModPlus' e (Var x2') m' | x2' == x2 && x2 `isUnusedVar` e && m' == m -> Just $ ModSum' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) m
ModMinus' (Var x2') e m' | x2' == x2 && x2 `isUnusedVar` e && m' == m -> Just $ ModMinus' init (ModSum' (Map' t1 t2 (Lam x1 t1 e) xs) m) m
-- ModProduct
ModMult' (Var x2') e m' | x2' == x2 && x2 `isUnusedVar` e && m' == m -> Just $ ModProduct' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) m
ModMult' e (Var x2') m' | x2' == x2 && x2 `isUnusedVar` e && m' == m -> Just $ ModProduct' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) m
Expand Down
Loading