Skip to content

Commit

Permalink
Merge pull request #192 from kmyk/update-specialize-foldl
Browse files Browse the repository at this point in the history
feat(core): Update about modulo
  • Loading branch information
kmyk committed Aug 23, 2021
2 parents a67b95c + 6ed6507 commit ae91e9e
Show file tree
Hide file tree
Showing 6 changed files with 589 additions and 190 deletions.
1 change: 1 addition & 0 deletions Jikka.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ library
Jikka.Core.Language.FreeVars
Jikka.Core.Language.LambdaPatterns
Jikka.Core.Language.Lint
Jikka.Core.Language.ModuloExpr
Jikka.Core.Language.QuasiRules
Jikka.Core.Language.RewriteRules
Jikka.Core.Language.Runtime
Expand Down
215 changes: 88 additions & 127 deletions src/Jikka/Core/Convert/PropagateMod.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module Jikka.Core.Convert.PropagateMod
)
where

import Control.Monad.Trans.Maybe
import Data.List
import Data.Maybe
import Jikka.Common.Alpha
Expand All @@ -23,175 +24,135 @@ import Jikka.Core.Language.Beta
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.Lint
import Jikka.Core.Language.ModuloExpr
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.TypeCheck
import Jikka.Core.Language.Util

-- | `Mod` is a newtype to avoid mistakes that swapping left and right of mod-op.
newtype Mod = Mod Expr
isModulo' :: Expr -> Expr -> Bool
isModulo' e m = e `isModulo` Modulo m

isModulo' :: Expr -> Mod -> Bool
isModulo' e (Mod m) = case e of
FloorMod' _ m' -> m' == m
ModNegate' _ m' -> m' == m
ModPlus' _ _ m' -> m' == m
ModMinus' _ _ m' -> m' == m
ModMult' _ _ m' -> m' == m
ModInv' _ m' -> m' == m
ModPow' _ _ m' -> m' == m
VecFloorMod' _ _ m' -> m' == m
MatFloorMod' _ _ _ m' -> m' == m
ModMatAp' _ _ _ _ m' -> m' == m
ModMatAdd' _ _ _ _ m' -> m' == m
ModMatMul' _ _ _ _ _ m' -> m' == m
ModMatPow' _ _ _ m' -> m' == m
ModSum' _ m' -> m' == m
ModProduct' _ m' -> m' == m
LitInt' n -> case m of
LitInt' m -> 0 <= n && n < m
_ -> False
Proj' ts _ e | isVectorTy' ts -> e `isModulo'` Mod m
Proj' ts _ e | isMatrixTy' ts -> e `isModulo'` Mod m
Map' _ _ f _ -> f `isModulo'` Mod m
Lam _ _ body -> body `isModulo'` Mod m
e@(App _ _) -> case curryApp e of
(e@(Lam _ _ _), _) -> e `isModulo'` Mod m
(Tuple' ts, es) | isVectorTy' ts -> all (`isModulo'` Mod m) es
(Tuple' ts, es) | isMatrixTy' ts -> all (`isModulo'` Mod m) es
_ -> False
_ -> False
putFloorMod :: MonadAlpha m => Modulo -> Expr -> m (Maybe Expr)
putFloorMod (Modulo m) =
runMaybeT . \case
Negate' e -> return $ ModNegate' e m
Plus' e1 e2 -> return $ ModPlus' e1 e2 m
Minus' e1 e2 -> return $ ModMinus' e1 e2 m
Mult' e1 e2 -> return $ ModMult' e1 e2 m
JustDiv' e1 e2 -> return $ ModMult' e1 (ModInv' e2 m) m
Pow' e1 e2 -> return $ ModPow' e1 e2 m
MatAp' h w e1 e2 -> return $ ModMatAp' h w e1 e2 m
MatAdd' h w e1 e2 -> return $ ModMatAdd' h w e1 e2 m
MatMul' h n w e1 e2 -> return $ ModMatMul' h n w e1 e2 m
MatPow' n e1 e2 -> return $ ModMatPow' n e1 e2 m
Sum' e -> return $ ModSum' e m
Product' e -> return $ ModProduct' e m
LitInt' n -> case m of
LitInt' m -> return $ LitInt' (n `mod` m)
_ -> MaybeT $ return Nothing
Proj' ts i e | isVectorTy' ts -> return $ Proj' ts i (VecFloorMod' (genericLength ts) e m)
Proj' ts i e
| isMatrixTy' ts ->
let (h, w) = fromJust (sizeOfMatrixTy (TupleTy ts))
in return $ Proj' ts i (MatFloorMod' (toInteger h) (toInteger w) e m)
Map' t1 t2 f xs -> do
f <- MaybeT $ putFloorMod (Modulo m) f
return $ Map' t1 t2 f xs
Foldl' t1 t2 f init xs -> do
f <- MaybeT $ putFloorMod (Modulo m) f
return $ Foldl' t1 t2 f init xs
Lam x t body -> do
-- TODO: rename only if required
y <- lift $ genVarName x
body <- lift $ substitute x (Var y) body
body <- MaybeT $ putFloorMod (Modulo m) body
return $ Lam y t body
e@(App _ _) -> case curryApp e of
(f@(Lam _ _ _), args) -> do
f <- MaybeT $ putFloorMod (Modulo m) f
return $ uncurryApp f args
(Tuple' ts, es) | isVectorTy' ts -> do
es' <- lift $ mapM (putFloorMod (Modulo m)) es
if all isNothing es'
then MaybeT $ return Nothing
else return $ uncurryApp (Tuple' ts) (zipWith fromMaybe es es')
(Tuple' ts, es) | isMatrixTy (TupleTy ts) -> do
es' <- lift $ mapM (putFloorMod (Modulo m)) es
if all isNothing es'
then MaybeT $ return Nothing
else return $ uncurryApp (Tuple' ts) (zipWith fromMaybe es es')
_ -> MaybeT $ return Nothing
_ -> MaybeT $ return Nothing

isModulo :: Expr -> Expr -> Bool
isModulo e m = e `isModulo'` Mod m

putFloorMod :: MonadAlpha m => Mod -> Expr -> m (Maybe Expr)
putFloorMod (Mod m) =
let return' = return . Just
in \case
Negate' e -> return' $ ModNegate' e m
Plus' e1 e2 -> return' $ ModPlus' e1 e2 m
Minus' e1 e2 -> return' $ ModMinus' e1 e2 m
Mult' e1 e2 -> return' $ ModMult' e1 e2 m
JustDiv' e1 e2 -> return' $ ModMult' e1 (ModInv' e2 m) m
Pow' e1 e2 -> return' $ ModPow' e1 e2 m
MatAp' h w e1 e2 -> return' $ ModMatAp' h w e1 e2 m
MatAdd' h w e1 e2 -> return' $ ModMatAdd' h w e1 e2 m
MatMul' h n w e1 e2 -> return' $ ModMatMul' h n w e1 e2 m
MatPow' n e1 e2 -> return' $ ModMatPow' n e1 e2 m
Sum' e -> return' $ ModSum' e m
Product' e -> return' $ ModProduct' e m
LitInt' n -> case m of
LitInt' m -> return' $ LitInt' (n `mod` m)
_ -> return Nothing
Proj' ts i e | isVectorTy' ts -> return' $ Proj' ts i (VecFloorMod' (genericLength ts) e m)
Proj' ts i e
| isMatrixTy' ts ->
let (h, w) = fromJust (sizeOfMatrixTy (TupleTy ts))
in return' $ Proj' ts i (MatFloorMod' (toInteger h) (toInteger w) e m)
Map' t1 t2 f xs -> do
f <- putFloorMod (Mod m) f
case f of
Nothing -> return Nothing
Just f -> return' $ Map' t1 t2 f xs
Foldl' t1 t2 f init xs -> do
f <- putFloorMod (Mod m) f
case f of
Nothing -> return Nothing
Just f -> return' $ Foldl' t1 t2 f init xs
Lam x t body -> do
-- TODO: rename only if required
y <- genVarName x
body <- substitute x (Var y) body
body <- putFloorMod (Mod m) body
case body of
Nothing -> return Nothing
Just body -> return' $ Lam y t body
e@(App _ _) -> case curryApp e of
(f@(Lam _ _ _), args) -> do
f <- putFloorMod (Mod m) f
case f of
Nothing -> return Nothing
Just f -> return' $ uncurryApp f args
(Tuple' ts, es) | isVectorTy' ts -> do
es' <- mapM (putFloorMod (Mod m)) es
if all isNothing es'
then return Nothing
else return' $ uncurryApp (Tuple' ts) (zipWith fromMaybe es es')
(Tuple' ts, es) | isMatrixTy (TupleTy ts) -> do
es' <- mapM (putFloorMod (Mod m)) es
if all isNothing es'
then return Nothing
else return' $ uncurryApp (Tuple' ts) (zipWith fromMaybe es es')
_ -> return Nothing
_ -> return Nothing

putFloorModGeneric :: MonadAlpha m => (Expr -> Mod -> m Expr) -> Mod -> Expr -> m Expr
putFloorModGeneric :: MonadAlpha m => (Expr -> Modulo -> m Expr) -> Modulo -> Expr -> m Expr
putFloorModGeneric fallback m e =
if e `isModulo'` m
if e `isModulo` m
then return e
else do
e' <- putFloorMod m e
case e' of
Just e' -> return e'
Nothing -> fallback e m

putFloorModInt :: MonadAlpha m => Mod -> Expr -> m Expr
putFloorModInt = putFloorModGeneric (\e (Mod m) -> return $ FloorMod' e m)

putMapFloorMod :: MonadAlpha m => Mod -> Expr -> m Expr
putMapFloorMod :: MonadAlpha m => Modulo -> Expr -> m Expr
putMapFloorMod = putFloorModGeneric fallback
where
fallback e (Mod m) = do
fallback e (Modulo m) = do
x <- genVarName'
return $ Map' IntTy IntTy (Lam x IntTy (FloorMod' (Var x) m)) e

putVecFloorMod :: (MonadError Error m, MonadAlpha m) => [(VarName, Type)] -> Mod -> Expr -> m Expr
putVecFloorMod :: (MonadError Error m, MonadAlpha m) => [(VarName, Type)] -> Modulo -> Expr -> m Expr
putVecFloorMod env = putFloorModGeneric fallback
where
fallback e (Mod m) = do
fallback e (Modulo m) = do
t <- typecheckExpr env e
case t of
TupleTy ts -> return $ VecFloorMod' (genericLength ts) e m
_ -> throwInternalError $ "not a vector: " ++ formatType t

putMatFloorMod :: (MonadError Error m, MonadAlpha m) => [(VarName, Type)] -> Mod -> Expr -> m Expr
putMatFloorMod :: (MonadError Error m, MonadAlpha m) => [(VarName, Type)] -> Modulo -> Expr -> m Expr
putMatFloorMod env = putFloorModGeneric fallback
where
fallback e (Mod m) = do
fallback e (Modulo m) = do
t <- typecheckExpr env e
case t of
TupleTy ts@(TupleTy ts' : _) -> return $ MatFloorMod' (genericLength ts) (genericLength ts') e m
_ -> throwInternalError $ "not a matrix: " ++ formatType t

rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule =
let go1 m f (t1, e1) = Just <$> (f <$> t1 (Mod m) e1 <*> pure m)
go2 m f (t1, e1) (t2, e2) = Just <$> (f <$> t1 (Mod m) e1 <*> t2 (Mod m) e2 <*> pure m)
let go0 :: Expr -> Maybe Expr
go0 e = do
e' <- formatBottomModuloExpr <$> parseModuloExpr e
guard $ e' /= e
return e'
go1 m f (t1, e1) = Just <$> (f <$> t1 (Modulo m) e1 <*> pure m)
go2 m f (t1, e1) (t2, e2) = Just <$> (f <$> t1 (Modulo m) e1 <*> t2 (Modulo m) e2 <*> pure m)
in makeRewriteRule "Jikka.Core.Convert.PropagateMod" $ \env -> \case
ModNegate' e m | not (e `isModulo` m) -> go1 m ModNegate' (putFloorModInt, e)
ModPlus' e1 e2 m | not (e1 `isModulo` m) || not (e2 `isModulo` m) -> go2 m ModPlus' (putFloorModInt, e1) (putFloorModInt, e2)
ModMinus' e1 e2 m | not (e1 `isModulo` m) || not (e2 `isModulo` m) -> go2 m ModMinus' (putFloorModInt, e1) (putFloorModInt, e2)
ModMult' e1 e2 m | not (e1 `isModulo` m) || not (e2 `isModulo` m) -> go2 m ModMult' (putFloorModInt, e1) (putFloorModInt, e2)
ModInv' e m | not (e `isModulo` m) -> go1 m ModInv' (putFloorModInt, e)
ModPow' e1 e2 m | not (e1 `isModulo` m) -> go2 m ModPow' (putFloorModInt, e1) (\_ e -> return e, e2)
ModMatAp' h w e1 e2 m | not (e1 `isModulo` m) || not (e2 `isModulo` m) -> go2 m (ModMatAp' h w) (putMatFloorMod env, e1) (putVecFloorMod env, e2)
ModMatAdd' h w e1 e2 m | not (e1 `isModulo` m) || not (e2 `isModulo` m) -> go2 m (ModMatAdd' h w) (putMatFloorMod env, e1) (putMatFloorMod env, e2)
ModMatMul' h n w e1 e2 m | not (e1 `isModulo` m) || not (e2 `isModulo` m) -> go2 m (ModMatMul' h n w) (putMatFloorMod env, e1) (putMatFloorMod env, e2)
ModMatPow' n e1 e2 m | not (e1 `isModulo` m) -> go2 m (ModMatPow' n) (putMatFloorMod env, e1) (\_ e -> return e, e2)
ModSum' e m | not (e `isModulo` m) -> go1 m ModSum' (putMapFloorMod, e)
ModProduct' e m | not (e `isModulo` m) -> go1 m ModProduct' (putMapFloorMod, e)
e@(ModNegate' _ _) -> return $ go0 e
e@(ModPlus' _ _ _) -> return $ go0 e
e@(ModMinus' _ _ _) -> return $ go0 e
e@(ModMult' _ _ _) -> return $ go0 e
e@(ModInv' _ _) -> return $ go0 e
e@(ModPow' _ _ _) -> return $ go0 e
ModMatAp' h w e1 e2 m | not (e1 `isModulo'` m) || not (e2 `isModulo'` m) -> go2 m (ModMatAp' h w) (putMatFloorMod env, e1) (putVecFloorMod env, e2)
ModMatAdd' h w e1 e2 m | not (e1 `isModulo'` m) || not (e2 `isModulo'` m) -> go2 m (ModMatAdd' h w) (putMatFloorMod env, e1) (putMatFloorMod env, e2)
ModMatMul' h n w e1 e2 m | not (e1 `isModulo'` m) || not (e2 `isModulo'` m) -> go2 m (ModMatMul' h n w) (putMatFloorMod env, e1) (putMatFloorMod env, e2)
ModMatPow' n e1 e2 m | not (e1 `isModulo'` m) -> go2 m (ModMatPow' n) (putMatFloorMod env, e1) (\_ e -> return e, e2)
ModSum' e m | not (e `isModulo'` m) -> go1 m ModSum' (putMapFloorMod, e)
ModProduct' e m | not (e `isModulo'` m) -> go1 m ModProduct' (putMapFloorMod, e)
FloorMod' e m ->
if e `isModulo` m
if e `isModulo'` m
then return $ Just e
else putFloorMod (Mod m) e
else putFloorMod (Modulo m) e
VecFloorMod' _ e m ->
if e `isModulo` m
if e `isModulo'` m
then return $ Just e
else putFloorMod (Mod m) e
else putFloorMod (Modulo m) e
MatFloorMod' _ _ e m ->
if e `isModulo` m
if e `isModulo'` m
then return $ Just e
else putFloorMod (Mod m) e
else putFloorMod (Modulo m) e
_ -> return Nothing

runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
Expand Down
40 changes: 32 additions & 8 deletions src/Jikka/Core/Convert/SpecializeFoldl.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module : Jikka.Core.Convert.SpecializeFoldl
Expand All @@ -22,19 +23,46 @@ where

import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Language.ArithmeticExpr
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.ModuloExpr
import Jikka.Core.Language.RewriteRules

convertToSum :: Expr -> Maybe Expr
convertToSum = \case
Foldl' t1 IntTy (Lam2 x2 _ x1 _ body) init xs -> do
(a, b) <- makeAffineFunctionFromArithmeticExpr x2 (parseArithmeticExpr body)
guard $ isOneArithmeticExpr a
return $ Plus' init (Sum' (Map' t1 IntTy (Lam x1 t1 (formatArithmeticExpr b)) xs))
_ -> Nothing

convertToModSum :: Expr -> Maybe Expr
convertToModSum = \case
Foldl' t1 IntTy (Lam2 x2 _ x1 _ body) init xs -> do
body <- parseModuloExpr body
(a, b) <- makeAffineFunctionFromArithmeticExpr x2 (arithmeticExprFromModuloExpr body)
guard $ isOneArithmeticExpr a

-- `if` is required for cases like `foldl (fun y x -> y % 2) 3 xs`, which is the same to `if xs == nil then 3 else 1`.
let wrap :: Expr -> Expr
wrap =
if init `isModulo` Modulo (moduloOfModuloExpr body)
then id
else If' IntTy (Equal' (ListTy t1) xs (Nil' t1)) init

return . wrap $
ModPlus' init (ModSum' (Map' t1 IntTy (Lam x1 t1 (formatArithmeticExpr b)) xs) (moduloOfModuloExpr body)) (moduloOfModuloExpr body)
_ -> Nothing

rule :: MonadAlpha m => RewriteRule m
rule = simpleRewriteRule "Jikka.Core.Convert.SpecializeFoldl" $ \case
(convertToSum -> Just e) -> return e
(convertToModSum -> Just e) -> return e
-- TODO: Replace these operators with the better implementation like sum.
Foldl' t1 t2 (Lam2 x2 _ x1 _ body) init xs -> case body of
-- Sum
Plus' (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Sum' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
Plus' e (Var x2') | x2' == x2 && x2 `isUnusedVar` e -> Just $ Sum' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
Minus' (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Minus' init (Sum' (Map' t1 t2 (Lam x1 t1 e) xs))
-- Product
Mult' (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Product' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
Mult' e (Var x2') | x2' == x2 && x2 `isUnusedVar` e -> Just $ Product' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
Expand All @@ -60,10 +88,6 @@ rule = simpleRewriteRule "Jikka.Core.Convert.SpecializeFoldl" $ \case
_ -> Nothing
-- The outer floor-mod is required because foldl for empty lists returns values without modulo.
FloorMod' (Foldl' t1 t2 (Lam2 x2 _ x1 _ body) init xs) m -> case body of
-- ModSum
ModPlus' (Var x2') e m' | x2' == x2 && x2 `isUnusedVar` e && m' == m -> Just $ ModSum' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) m
ModPlus' e (Var x2') m' | x2' == x2 && x2 `isUnusedVar` e && m' == m -> Just $ ModSum' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) m
ModMinus' (Var x2') e m' | x2' == x2 && x2 `isUnusedVar` e && m' == m -> Just $ ModMinus' init (ModSum' (Map' t1 t2 (Lam x1 t1 e) xs) m) m
-- ModProduct
ModMult' (Var x2') e m' | x2' == x2 && x2 `isUnusedVar` e && m' == m -> Just $ ModProduct' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) m
ModMult' e (Var x2') m' | x2' == x2 && x2 `isUnusedVar` e && m' == m -> Just $ ModProduct' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) m
Expand Down
Loading

0 comments on commit ae91e9e

Please sign in to comment.