Skip to content

Commit

Permalink
Merge pull request #133 from kmyk/core-parser
Browse files Browse the repository at this point in the history
Add a quasi quote for rewrite rules
  • Loading branch information
kmyk committed Aug 4, 2021
2 parents 7cb4506 + f16415e commit 324934b
Show file tree
Hide file tree
Showing 33 changed files with 2,169 additions and 1,003 deletions.
14 changes: 10 additions & 4 deletions Jikka.cabal
Expand Up @@ -97,11 +97,16 @@ library
Jikka.Core.Language.FreeVars
Jikka.Core.Language.LambdaPatterns
Jikka.Core.Language.Lint
Jikka.Core.Language.QuasiRules
Jikka.Core.Language.RewriteRules
Jikka.Core.Language.Runtime
Jikka.Core.Language.TypeCheck
Jikka.Core.Language.Util
Jikka.Core.Language.Value
Jikka.Core.Parse
Jikka.Core.Parse.Alex
Jikka.Core.Parse.Happy
Jikka.Core.Parse.Token
Jikka.CPlusPlus.Convert
Jikka.CPlusPlus.Convert.AddMain
Jikka.CPlusPlus.Convert.BundleRuntime
Expand Down Expand Up @@ -162,7 +167,7 @@ library
, deepseq >=1.4.4 && <1.5
, directory >=1.3.3 && <1.4
, mtl >=2.2.2 && <2.3
, template-haskell >=2.14.0 && <2.17
, template-haskell >=2.16.0 && <2.17
, text >=1.2.3 && <1.3
, transformers >=0.5.6 && <0.6
, vector >=0.12.3 && <0.13
Expand All @@ -186,7 +191,7 @@ executable jikka
, deepseq >=1.4.4 && <1.5
, directory >=1.3.3 && <1.4
, mtl >=2.2.2 && <2.3
, template-haskell >=2.14.0 && <2.17
, template-haskell >=2.16.0 && <2.17
, text >=1.2.3 && <1.3
, transformers >=0.5.6 && <0.6
, vector >=0.12.3 && <0.13
Expand All @@ -210,7 +215,7 @@ test-suite jikka-doctest
, directory >=1.3.3 && <1.4
, doctest
, mtl >=2.2.2 && <2.3
, template-haskell >=2.14.0 && <2.17
, template-haskell >=2.16.0 && <2.17
, text >=1.2.3 && <1.3
, transformers >=0.5.6 && <0.6
, vector >=0.12.3 && <0.13
Expand Down Expand Up @@ -244,6 +249,7 @@ test-suite jikka-test
Jikka.Core.FormatSpec
Jikka.Core.Language.ArithmeticalExprSpec
Jikka.Core.Language.BetaSpec
Jikka.Core.ParseSpec
Jikka.CPlusPlus.Convert.FromCoreSpec
Jikka.CPlusPlus.FormatSpec
Jikka.Python.Convert.ToRestrictedPythonSpec
Expand Down Expand Up @@ -282,7 +288,7 @@ test-suite jikka-test
, hspec
, mtl >=2.2.2 && <2.3
, ormolu
, template-haskell >=2.14.0 && <2.17
, template-haskell >=2.16.0 && <2.17
, text >=1.2.3 && <1.3
, transformers >=0.5.6 && <0.6
, vector >=0.12.3 && <0.13
Expand Down
4 changes: 1 addition & 3 deletions doctests.hs
Expand Up @@ -7,9 +7,7 @@ import Test.DocTest
main :: IO ()
main = doctest
[ "src/Jikka/Common/"
, "src/Jikka/Core/"
, "src/Jikka/CPlusPlus/"
, "src/Jikka/Python/Convert/"
, "src/Jikka/Python/Language/"
, "src/Jikka/RestrictedPython/"
, "src/Jikka/RestrictedPython/Language/"
]
2 changes: 1 addition & 1 deletion package.yaml
Expand Up @@ -30,7 +30,7 @@ dependencies:
- deepseq >= 1.4.4 && < 1.5
- directory >= 1.3.3 && < 1.4
- mtl >= 2.2.2 && < 2.3
- template-haskell >= 2.14.0 && < 2.17
- template-haskell >= 2.16.0 && < 2.17
- text >= 1.2.3 && < 1.3
- transformers >= 0.5.6 && < 0.6
- vector >= 0.12.3 && < 0.13
Expand Down
336 changes: 179 additions & 157 deletions src/Jikka/CPlusPlus/Convert/FromCore.hs

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions src/Jikka/Common/Alpha.hs
Expand Up @@ -22,6 +22,8 @@ import Control.Monad.Reader
import Control.Monad.Signatures
import Control.Monad.State.Strict
import Control.Monad.Writer.Strict
import Data.Unique
import Language.Haskell.TH (Q)

class Monad m => MonadAlpha m where
nextCounter :: m Int
Expand Down Expand Up @@ -84,3 +86,9 @@ evalAlpha f i = runIdentity (evalAlphaT f i)

resetAlphaT :: Monad m => Int -> AlphaT m ()
resetAlphaT i = AlphaT $ \_ -> return ((), i)

instance MonadAlpha IO where
nextCounter = hashUnique <$> newUnique

instance MonadAlpha Q where
nextCounter = liftIO nextCounter
2 changes: 1 addition & 1 deletion src/Jikka/Core/Convert/ANormal.hs
Expand Up @@ -60,7 +60,7 @@ runExpr env = \case
f <- runExpr env f
args <- mapM (runExpr env) args
case (f, args) of
(Lit (LitBuiltin (If _)), [e1, e2, e3]) -> do
(Lit (LitBuiltin If _), [e1, e2, e3]) -> do
(_, ctx, e1) <- destruct env e1
return $ ctx (App3 f e1 e2 e3)
_ -> runApp env f args
Expand Down
2 changes: 1 addition & 1 deletion src/Jikka/Core/Convert/CumulativeSum.hs
Expand Up @@ -49,7 +49,7 @@ rule = RewriteRule $ \_ -> \case
then At' IntTy (Var b) n
else Minus' (At' IntTy (Var b) (Plus' n (formatArithmeticalExpr shift))) (At' IntTy (Var b) (formatArithmeticalExpr shift))
return . Just $
Let b (ListTy IntTy) (Scanl' IntTy IntTy (Lit (LitBuiltin Plus)) Lit0 a) e
Let b (ListTy IntTy) (Scanl' IntTy IntTy (Builtin Plus) Lit0 a) e
_ -> return Nothing
Max1' t (Cons' _ a0 (Map' _ _ (Lam x _ (At' _ a (Var x'))) (Range1' n))) | x' == x && x `isUnusedVar` a -> do
Just <$> cumulativeMax (Max2' t) t (Just a0) a n
Expand Down
13 changes: 7 additions & 6 deletions src/Jikka/Core/Convert/MakeScanl.hs
Expand Up @@ -31,6 +31,7 @@ module Jikka.Core.Convert.MakeScanl
where

import Control.Monad.Trans.Maybe
import Data.List
import qualified Data.Map as M
import Jikka.Common.Alpha
import Jikka.Common.Error
Expand All @@ -56,7 +57,7 @@ reduceScanlBuild = simpleRewriteRule $ \case
_ -> Nothing

-- | `getRecurrenceFormulaStep1` removes `At` in @body@.
getRecurrenceFormulaStep1 :: MonadAlpha m => Int -> Type -> VarName -> VarName -> Expr -> m (Maybe Expr)
getRecurrenceFormulaStep1 :: MonadAlpha m => Integer -> Type -> VarName -> VarName -> Expr -> m (Maybe Expr)
getRecurrenceFormulaStep1 shift t a i body = do
x <- genVarName a
let proj k =
Expand All @@ -78,13 +79,13 @@ getRecurrenceFormulaStep1 shift t a i body = do
Nothing -> Nothing

-- | `getRecurrenceFormulaStep` replaces `At` in @body@ with `Proj`.
getRecurrenceFormulaStep :: MonadAlpha m => Int -> Int -> Type -> VarName -> VarName -> Expr -> m (Maybe Expr)
getRecurrenceFormulaStep :: MonadAlpha m => Integer -> Integer -> Type -> VarName -> VarName -> Expr -> m (Maybe Expr)
getRecurrenceFormulaStep shift size t a i body = do
x <- genVarName a
let ts = replicate size t
let ts = replicate (fromInteger size) t
let proj k =
if 0 <= toInteger shift + k && toInteger shift + k < toInteger size
then Just $ Proj' ts (shift + fromInteger k) (Var x)
then Just $ Proj' ts (shift + k) (Var x)
else Nothing
let go :: Expr -> Maybe Expr
go = \case
Expand Down Expand Up @@ -129,9 +130,9 @@ reduceFoldlSetAtRecurrence = RewriteRule $ \_ -> \case
_ -> do
let ts = replicate (length base) t2
let base' = uncurryApp (Tuple' ts) base
step <- MaybeT $ getRecurrenceFormulaStep (- length base + fromInteger k) (length base) t2 a i step
step <- MaybeT $ getRecurrenceFormulaStep (- genericLength base + k) (genericLength base) t2 a i step
x <- lift (genVarName a)
return $ foldr (Cons' t2) (Map' (TupleTy ts) t2 (Lam x (TupleTy ts) (Proj' ts (length base - 1) (Var x))) (Scanl' IntTy (TupleTy ts) step base' (Range1' n))) (init base)
return $ foldr (Cons' t2) (Map' (TupleTy ts) t2 (Lam x (TupleTy ts) (Proj' ts (genericLength base - 1) (Var x))) (Scanl' IntTy (TupleTy ts) step base' (Range1' n))) (init base)
_ -> return Nothing

-- | `checkAccumulationFormulaStep` checks that all `At` in @body@ about @a@ are @At a i@.
Expand Down
17 changes: 9 additions & 8 deletions src/Jikka/Core/Convert/MatrixExponentiation.hs
Expand Up @@ -16,6 +16,7 @@ where

import Control.Monad.Trans
import Control.Monad.Trans.Maybe
import Data.List
import qualified Data.Vector as V
import Jikka.Common.Alpha
import Jikka.Common.Error
Expand Down Expand Up @@ -52,13 +53,13 @@ fromAffineMatrix a b =
bottom = uncurryApp (Tuple' (replicate (w + 1) IntTy)) (replicate w (LitInt' 0) ++ [LitInt' 1])
in uncurryApp (Tuple' (replicate (h + 1) (TupleTy (replicate (w + 1) IntTy)))) (V.toList (V.zipWith go (unMatrix a) b) ++ [bottom])

toMatrix :: MonadAlpha m => [(VarName, Type)] -> VarName -> Int -> Expr -> m (Maybe (Matrix ArithmeticalExpr, Maybe (V.Vector ArithmeticalExpr)))
toMatrix :: MonadAlpha m => [(VarName, Type)] -> VarName -> Integer -> Expr -> m (Maybe (Matrix ArithmeticalExpr, Maybe (V.Vector ArithmeticalExpr)))
toMatrix env x n step =
case curryApp step of
(Tuple' _, es) -> runMaybeT $ do
xs <- V.fromList <$> replicateM n (lift (genVarName x))
xs <- V.fromList <$> replicateM (fromInteger n) (lift (genVarName x))
let unpackTuple _ e = case e of
Proj' _ i (Var x') | x' == x -> Var (xs V.! i)
Proj' _ i (Var x') | x' == x -> Var (xs V.! fromInteger i)
_ -> e
rows <- MaybeT . return . forM es $ \e -> do
let e' = mapExpr unpackTuple env e
Expand All @@ -69,14 +70,14 @@ toMatrix env x n step =
return (a, b)
_ -> return Nothing

addOneToVector :: Int -> VarName -> Expr
addOneToVector :: Integer -> VarName -> Expr
addOneToVector n x =
let ts = replicate n IntTy
let ts = replicate (fromInteger n) IntTy
in uncurryApp (Tuple' (IntTy : ts)) (map (\i -> Proj' ts i (Var x)) [0 .. n - 1] ++ [LitInt' 1])

removeOneFromVector :: Int -> VarName -> Expr
removeOneFromVector :: Integer -> VarName -> Expr
removeOneFromVector n x =
let ts = replicate n IntTy
let ts = replicate (fromInteger n) IntTy
in uncurryApp (Tuple' ts) (map (\i -> Proj' (IntTy : ts) i (Var x)) [0 .. n - 1])

rule :: MonadAlpha m => RewriteRule m
Expand All @@ -93,7 +94,7 @@ rule = RewriteRule $ \env -> \case
b' = Mult' (FloorDiv' (Minus' (Pow' a k) (LitInt' 1)) (Minus' a (LitInt' 1))) b -- This division has no remainder.
in Just $ Plus' (Mult' a' base) b'
Iterate' (TupleTy ts) k (Lam x _ step) base | isVectorTy' ts -> do
let n = length ts
let n = genericLength ts
let go n step base = MatAp' n n (MatPow' n step k) base
step <- toMatrix env x n step
case step of
Expand Down
9 changes: 5 additions & 4 deletions src/Jikka/Core/Convert/PropagateMod.hs
Expand Up @@ -14,6 +14,7 @@ module Jikka.Core.Convert.PropagateMod
)
where

import Data.List
import Data.Maybe
import Jikka.Common.Alpha
import Jikka.Common.Error
Expand Down Expand Up @@ -81,11 +82,11 @@ putFloorMod (Mod m) =
LitInt' n -> case m of
LitInt' m -> return' $ LitInt' (n `mod` m)
_ -> return Nothing
Proj' ts i e | isVectorTy' ts -> return' $ Proj' ts i (VecFloorMod' (length ts) e m)
Proj' ts i e | isVectorTy' ts -> return' $ Proj' ts i (VecFloorMod' (genericLength ts) e m)
Proj' ts i e
| isMatrixTy' ts ->
let (h, w) = fromJust (sizeOfMatrixTy (TupleTy ts))
in return' $ Proj' ts i (MatFloorMod' h w e m)
in return' $ Proj' ts i (MatFloorMod' (toInteger h) (toInteger w) e m)
Map' t1 t2 f xs -> do
f <- putFloorMod (Mod m) f
case f of
Expand Down Expand Up @@ -144,7 +145,7 @@ putVecFloorMod env = putFloorModGeneric fallback
fallback e (Mod m) = do
t <- typecheckExpr env e
case t of
TupleTy ts -> return $ VecFloorMod' (length ts) e m
TupleTy ts -> return $ VecFloorMod' (genericLength ts) e m
_ -> throwInternalError $ "not a vector: " ++ formatType t

putMatFloorMod :: (MonadError Error m, MonadAlpha m) => [(VarName, Type)] -> Mod -> Expr -> m Expr
Expand All @@ -153,7 +154,7 @@ putMatFloorMod env = putFloorModGeneric fallback
fallback e (Mod m) = do
t <- typecheckExpr env e
case t of
TupleTy ts@(TupleTy ts' : _) -> return $ MatFloorMod' (length ts) (length ts') e m
TupleTy ts@(TupleTy ts' : _) -> return $ MatFloorMod' (genericLength ts) (genericLength ts') e m
_ -> throwInternalError $ "not a matrix: " ++ formatType t

rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
Expand Down
29 changes: 15 additions & 14 deletions src/Jikka/Core/Convert/SegmentTree.hs
Expand Up @@ -69,26 +69,26 @@ pattern CumulativeSumFlip t f e es <-
x2 = findUnusedVarName (VarName "x") f
in Scanl' t t (Lam2 x1 t x2 t (App (App f (Var x2)) (Var x1))) e es

builtinToSemigroup :: Builtin -> Maybe Semigroup'
builtinToSemigroup = \case
Plus -> Just SemigroupIntPlus
Min2 IntTy -> Just SemigroupIntMin
Max2 IntTy -> Just SemigroupIntMax
builtinToSemigroup :: Builtin -> [Type] -> Maybe Semigroup'
builtinToSemigroup builtin ts = case (builtin, ts) of
(Plus, []) -> Just SemigroupIntPlus
(Min2, [IntTy]) -> Just SemigroupIntMin
(Max2, [IntTy]) -> Just SemigroupIntMax
_ -> Nothing

semigroupToBuiltin :: Semigroup' -> Builtin
semigroupToBuiltin :: Semigroup' -> (Builtin, [Type])
semigroupToBuiltin = \case
SemigroupIntPlus -> Plus
SemigroupIntMin -> Min2 IntTy
SemigroupIntMax -> Max2 IntTy
SemigroupIntPlus -> (Plus, [])
SemigroupIntMin -> (Min2, [IntTy])
SemigroupIntMax -> (Max2, [IntTy])

unCumulativeSum :: Expr -> Expr -> Maybe (Semigroup', Expr)
unCumulativeSum a = \case
CumulativeSum _ (Lit (LitBuiltin op)) b a' | a' == a -> case builtinToSemigroup op of
CumulativeSum _ (Lit (LitBuiltin op ts)) b a' | a' == a -> case builtinToSemigroup op ts of
Just semigrp -> Just (semigrp, b)
Nothing -> Nothing
-- Semigroups must be commutative to use CumulativeSumFlip.
CumulativeSumFlip _ (Lit (LitBuiltin op)) b a' | a' == a -> case builtinToSemigroup op of
CumulativeSumFlip _ (Lit (LitBuiltin op ts)) b a' | a' == a -> case builtinToSemigroup op ts of
Just semigrp -> Just (semigrp, b)
Nothing -> Nothing
_ -> Nothing
Expand All @@ -103,7 +103,7 @@ replaceWithSegtrees a segtrees = go M.empty
go env = \case
At' _ (check env -> Just (e, b, semigrp)) i ->
let e' = SegmentTreeGetRange' semigrp e (LitInt' 0) i
in AppBuiltin2 (semigroupToBuiltin semigrp) b e'
in App2 (Lit (uncurry LitBuiltin (semigroupToBuiltin semigrp))) b e'
Var x -> Var x
Lit lit -> Lit lit
App e1 e2 -> App (go env e1) (go env e2)
Expand All @@ -113,10 +113,11 @@ replaceWithSegtrees a segtrees = go M.empty
in case check env e1' of
Just (e1', b, semigrp) -> go (M.insert x (e1', b, semigrp) env) e2
Nothing -> Let x t (go env e1) (go env e2)
check :: M.Map VarName (Expr, Expr, Semigroup') -> Expr -> Maybe (Expr, Expr, Semigroup')
check env = \case
Var x -> M.lookup x env
CumulativeSum _ (Lit (LitBuiltin op)) b (Var a') | a' == a -> case lookup op (map (first semigroupToBuiltin) segtrees) of
Just e -> Just (e, b, fromJust (builtinToSemigroup op))
CumulativeSum _ (Lit (LitBuiltin op ts)) b (Var a') | a' == a -> case lookup (op, ts) (map (first semigroupToBuiltin) segtrees) of
Just e -> Just (e, b, fromJust (builtinToSemigroup op ts))
Nothing -> Nothing
_ -> Nothing

Expand Down

0 comments on commit 324934b

Please sign in to comment.