diff --git a/Jikka.cabal b/Jikka.cabal index f510b278..bd7e2ee9 100644 --- a/Jikka.cabal +++ b/Jikka.cabal @@ -97,11 +97,16 @@ library Jikka.Core.Language.FreeVars Jikka.Core.Language.LambdaPatterns Jikka.Core.Language.Lint + Jikka.Core.Language.QuasiRules Jikka.Core.Language.RewriteRules Jikka.Core.Language.Runtime Jikka.Core.Language.TypeCheck Jikka.Core.Language.Util Jikka.Core.Language.Value + Jikka.Core.Parse + Jikka.Core.Parse.Alex + Jikka.Core.Parse.Happy + Jikka.Core.Parse.Token Jikka.CPlusPlus.Convert Jikka.CPlusPlus.Convert.AddMain Jikka.CPlusPlus.Convert.BundleRuntime @@ -162,7 +167,7 @@ library , deepseq >=1.4.4 && <1.5 , directory >=1.3.3 && <1.4 , mtl >=2.2.2 && <2.3 - , template-haskell >=2.14.0 && <2.17 + , template-haskell >=2.16.0 && <2.17 , text >=1.2.3 && <1.3 , transformers >=0.5.6 && <0.6 , vector >=0.12.3 && <0.13 @@ -186,7 +191,7 @@ executable jikka , deepseq >=1.4.4 && <1.5 , directory >=1.3.3 && <1.4 , mtl >=2.2.2 && <2.3 - , template-haskell >=2.14.0 && <2.17 + , template-haskell >=2.16.0 && <2.17 , text >=1.2.3 && <1.3 , transformers >=0.5.6 && <0.6 , vector >=0.12.3 && <0.13 @@ -210,7 +215,7 @@ test-suite jikka-doctest , directory >=1.3.3 && <1.4 , doctest , mtl >=2.2.2 && <2.3 - , template-haskell >=2.14.0 && <2.17 + , template-haskell >=2.16.0 && <2.17 , text >=1.2.3 && <1.3 , transformers >=0.5.6 && <0.6 , vector >=0.12.3 && <0.13 @@ -244,6 +249,7 @@ test-suite jikka-test Jikka.Core.FormatSpec Jikka.Core.Language.ArithmeticalExprSpec Jikka.Core.Language.BetaSpec + Jikka.Core.ParseSpec Jikka.CPlusPlus.Convert.FromCoreSpec Jikka.CPlusPlus.FormatSpec Jikka.Python.Convert.ToRestrictedPythonSpec @@ -282,7 +288,7 @@ test-suite jikka-test , hspec , mtl >=2.2.2 && <2.3 , ormolu - , template-haskell >=2.14.0 && <2.17 + , template-haskell >=2.16.0 && <2.17 , text >=1.2.3 && <1.3 , transformers >=0.5.6 && <0.6 , vector >=0.12.3 && <0.13 diff --git a/doctests.hs b/doctests.hs index 1479e6bc..1cb9663d 100644 --- a/doctests.hs +++ b/doctests.hs @@ -7,9 +7,7 @@ import Test.DocTest main :: IO () main = doctest [ "src/Jikka/Common/" - , "src/Jikka/Core/" - , "src/Jikka/CPlusPlus/" , "src/Jikka/Python/Convert/" , "src/Jikka/Python/Language/" - , "src/Jikka/RestrictedPython/" + , "src/Jikka/RestrictedPython/Language/" ] diff --git a/package.yaml b/package.yaml index 86860ae7..f97bfac3 100644 --- a/package.yaml +++ b/package.yaml @@ -30,7 +30,7 @@ dependencies: - deepseq >= 1.4.4 && < 1.5 - directory >= 1.3.3 && < 1.4 - mtl >= 2.2.2 && < 2.3 - - template-haskell >= 2.14.0 && < 2.17 + - template-haskell >= 2.16.0 && < 2.17 - text >= 1.2.3 && < 1.3 - transformers >= 0.5.6 && < 0.6 - vector >= 0.12.3 && < 0.13 diff --git a/src/Jikka/CPlusPlus/Convert/FromCore.hs b/src/Jikka/CPlusPlus/Convert/FromCore.hs index 93d6d652..ff92c19d 100644 --- a/src/Jikka/CPlusPlus/Convert/FromCore.hs +++ b/src/Jikka/CPlusPlus/Convert/FromCore.hs @@ -1,7 +1,6 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE TupleSections #-} -- | -- Module : Jikka.CPlusPlus.Convert.FromCore @@ -72,8 +71,8 @@ runSemigroup = \case runLiteral :: (MonadAlpha m, MonadError Error m) => Env -> X.Literal -> m Y.Expr runLiteral env = \case - X.LitBuiltin builtin -> do - (stmts, e) <- runAppBuiltin env builtin [] + X.LitBuiltin builtin ts -> do + (stmts, e) <- runAppBuiltin env builtin ts [] case stmts of [] -> return e _ -> throwInternalError "now builtin values don't use statements" @@ -88,83 +87,119 @@ runLiteral env = \case t <- runType t return $ Y.Call (Y.Function "jikka::error" [t]) [Y.Lit (Y.LitString err)] -arityOfBuiltin :: X.Builtin -> Int -arityOfBuiltin = \case - X.Min2 _ -> 2 - X.Max2 _ -> 2 - X.Foldl _ _ -> 3 - X.Iterate _ -> 3 - X.At _ -> 2 - X.Min1 _ -> 1 - X.Max1 _ -> 1 - X.Proj _ _ -> 1 - builtin -> length (fst (X.uncurryFunTy (X.builtinToType builtin))) +arityOfBuiltin :: MonadError Error m => X.Builtin -> [X.Type] -> m Int +arityOfBuiltin builtin ts = case builtin of + X.Min2 -> return 2 + X.Max2 -> return 2 + X.Foldl -> return 3 + X.Iterate -> return 3 + X.At -> return 2 + X.Min1 -> return 1 + X.Max1 -> return 1 + X.Proj _ -> return 1 + builtin -> length . fst . X.uncurryFunTy <$> X.builtinToType builtin ts -runAppBuiltin :: (MonadAlpha m, MonadError Error m) => Env -> X.Builtin -> [X.Expr] -> m ([Y.Statement], Y.Expr) -runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinIsolated f) $ do - let go0 f = case args of - [] -> return ([], f) - _ -> throwInternalError $ "expected 0 arguments, got " ++ show (length args) - let go1'' :: (MonadAlpha m, MonadError Error m) => (X.Expr -> m ([Y.Statement], Y.Expr)) -> m ([Y.Statement], Y.Expr) - go1'' f = case args of +runAppBuiltin :: (MonadAlpha m, MonadError Error m) => Env -> X.Builtin -> [X.Type] -> [X.Expr] -> m ([Y.Statement], Y.Expr) +runAppBuiltin env f ts args = wrapError' ("converting builtin " ++ X.formatBuiltinIsolated f ts) $ do + let go0T f = case ts of + [] -> f + _ -> throwInternalError $ "expected 0 type arguments, got " ++ show (length ts) + let go1T' f = case ts of + [t1] -> f t1 + _ -> throwInternalError $ "expected 1 type argument, got " ++ show (length ts) + let go1T f = go1T' $ f <=< runType + let go2T' f = case ts of + [t1, t2] -> f t1 t2 + _ -> throwInternalError $ "expected 2 type arguments, got " ++ show (length ts) + let go0E f = case args of + [] -> f + _ -> throwInternalError $ "expected 0 type arguments, got " ++ show (length args) + let go1E' f = case args of [e1] -> f e1 - _ -> throwInternalError $ "expected 1 argument, got " ++ show (length args) - let go1' :: (MonadAlpha m, MonadError Error m) => (Y.Expr -> m ([Y.Statement], Y.Expr)) -> m ([Y.Statement], Y.Expr) - go1' f = go1'' $ \e1 -> do + _ -> throwInternalError $ "expected 1 type argument, got " ++ show (length args) + let go1E f = go1E' $ \e1 -> do (stmts1, e1) <- runExpr env e1 (stmts, e) <- f e1 return (stmts1 ++ stmts, e) - let go1 f = go1' (return . ([],) . f) - let go2'' :: (MonadAlpha m, MonadError Error m) => (X.Expr -> X.Expr -> m ([Y.Statement], Y.Expr)) -> m ([Y.Statement], Y.Expr) - go2'' f = case args of + let go2E' f = case args of [e1, e2] -> f e1 e2 - _ -> throwInternalError $ "expected 2 arguments, got " ++ show (length args) - let go2' :: (MonadAlpha m, MonadError Error m) => (Y.Expr -> Y.Expr -> m ([Y.Statement], Y.Expr)) -> m ([Y.Statement], Y.Expr) - go2' f = go2'' $ \e1 e2 -> do + _ -> throwInternalError $ "expected 2 type arguments, got " ++ show (length args) + let go2E f = go2E' $ \e1 e2 -> do (stmts1, e1) <- runExpr env e1 (stmts2, e2) <- runExpr env e2 (stmts, e) <- f e1 e2 return (stmts1 ++ stmts2 ++ stmts, e) - let go2 f = go2' (((return . ([],)) .) . f) - let go3'' :: (MonadAlpha m, MonadError Error m) => (X.Expr -> X.Expr -> X.Expr -> m ([Y.Statement], Y.Expr)) -> m ([Y.Statement], Y.Expr) - go3'' f = case args of + let go3E' f = case args of [e1, e2, e3] -> f e1 e2 e3 - _ -> throwInternalError $ "expected 3 arguments, got " ++ show (length args) - let go3' :: (MonadAlpha m, MonadError Error m) => (Y.Expr -> Y.Expr -> Y.Expr -> m ([Y.Statement], Y.Expr)) -> m ([Y.Statement], Y.Expr) - go3' f = go3'' $ \e1 e2 e3 -> do + _ -> throwInternalError $ "expected 2 type arguments, got " ++ show (length args) + let go3E f = go3E' $ \e1 e2 e3 -> do (stmts1, e1) <- runExpr env e1 (stmts2, e2) <- runExpr env e2 (stmts3, e3) <- runExpr env e3 (stmts, e) <- f e1 e2 e3 return (stmts1 ++ stmts2 ++ stmts3 ++ stmts, e) - let go3 f = go3' ((((return . ([],)) .) .) . f) - let goN' :: (MonadAlpha m, MonadError Error m) => ([Y.Expr] -> m Y.Expr) -> m ([Y.Statement], Y.Expr) - goN' f = do + let goP f = return ([], f) + let go00 f = go0T $ go0E $ goP f + let go01' :: (MonadAlpha m, MonadError Error m) => (Y.Expr -> m ([Y.Statement], Y.Expr)) -> m ([Y.Statement], Y.Expr) + go01' f = go0T $ go1E f + let go01 :: (MonadAlpha m, MonadError Error m) => (Y.Expr -> Y.Expr) -> m ([Y.Statement], Y.Expr) + go01 f = go0T $ go1E $ \e1 -> goP $ f e1 + let go11' :: (MonadAlpha m, MonadError Error m) => (Y.Type -> Y.Expr -> m ([Y.Statement], Y.Expr)) -> m ([Y.Statement], Y.Expr) + go11' f = go1T $ \t1 -> go1E $ \e1 -> f t1 e1 + let go11 :: (MonadAlpha m, MonadError Error m) => (Y.Type -> Y.Expr -> Y.Expr) -> m ([Y.Statement], Y.Expr) + go11 f = go1T $ \t1 -> go1E $ \e1 -> goP $ f t1 e1 + let go02' :: (MonadAlpha m, MonadError Error m) => (Y.Expr -> Y.Expr -> m ([Y.Statement], Y.Expr)) -> m ([Y.Statement], Y.Expr) + go02' f = go0T $ go2E f + let go02 :: (MonadAlpha m, MonadError Error m) => (Y.Expr -> Y.Expr -> Y.Expr) -> m ([Y.Statement], Y.Expr) + go02 f = go0T $ go2E $ \e1 e2 -> goP $ f e1 e2 + let go12'' :: (MonadAlpha m, MonadError Error m) => (X.Type -> X.Expr -> X.Expr -> m ([Y.Statement], Y.Expr)) -> m ([Y.Statement], Y.Expr) + go12'' f = go1T' $ \t1 -> go2E' $ \e1 e2 -> f t1 e1 e2 + let go12' :: (MonadAlpha m, MonadError Error m) => (Y.Type -> Y.Expr -> Y.Expr -> m ([Y.Statement], Y.Expr)) -> m ([Y.Statement], Y.Expr) + go12' f = go1T $ \t1 -> go2E $ \e1 e2 -> f t1 e1 e2 + let go12 :: (MonadAlpha m, MonadError Error m) => (Y.Type -> Y.Expr -> Y.Expr -> Y.Expr) -> m ([Y.Statement], Y.Expr) + go12 f = go1T $ \t1 -> go2E $ \e1 e2 -> goP $ f t1 e1 e2 + let go22'' :: (MonadAlpha m, MonadError Error m) => (X.Type -> X.Type -> X.Expr -> X.Expr -> m ([Y.Statement], Y.Expr)) -> m ([Y.Statement], Y.Expr) + go22'' f = go2T' $ \t1 t2 -> go2E' $ \e1 e2 -> f t1 t2 e1 e2 + let go03' :: (MonadAlpha m, MonadError Error m) => (Y.Expr -> Y.Expr -> Y.Expr -> m ([Y.Statement], Y.Expr)) -> m ([Y.Statement], Y.Expr) + go03' f = go0T $ go3E f + let go03 f = go0T $ go3E $ \e1 e2 e3 -> goP $ f e1 e2 e3 + let go13'' :: (MonadAlpha m, MonadError Error m) => (X.Type -> X.Expr -> X.Expr -> X.Expr -> m ([Y.Statement], Y.Expr)) -> m ([Y.Statement], Y.Expr) + go13'' f = go1T' $ \t1 -> go3E' $ \e1 e2 e3 -> f t1 e1 e2 e3 + let go13' :: (MonadAlpha m, MonadError Error m) => (Y.Type -> Y.Expr -> Y.Expr -> Y.Expr -> m ([Y.Statement], Y.Expr)) -> m ([Y.Statement], Y.Expr) + go13' f = go1T $ \t1 -> go3E $ \e1 e2 e3 -> f t1 e1 e2 e3 + let go23'' :: (MonadAlpha m, MonadError Error m) => (X.Type -> X.Type -> X.Expr -> X.Expr -> X.Expr -> m ([Y.Statement], Y.Expr)) -> m ([Y.Statement], Y.Expr) + go23'' f = go2T' $ \t1 t2 -> go3E' $ \e1 e2 e3 -> f t1 t2 e1 e2 e3 + let goN1 :: (MonadAlpha m, MonadError Error m) => ([Y.Type] -> Y.Expr -> Y.Expr) -> m ([Y.Statement], Y.Expr) + goN1 f = case args of + [e1] -> do + ts <- mapM runType ts + (stmts1, e1) <- runExpr env e1 + return (stmts1, f ts e1) + _ -> throwInternalError $ "expected 1 argument, got " ++ show (length args) + let goNN :: (MonadAlpha m, MonadError Error m) => ([Y.Type] -> [Y.Expr] -> Y.Expr) -> m ([Y.Statement], Y.Expr) + goNN f = do + ts <- mapM runType ts args <- mapM (runExpr env) args - e <- f (map snd args) + let e = f ts (map snd args) return (concatMap fst args, e) case f of -- arithmetical functions - X.Negate -> go1 $ \e -> Y.UnOp Y.Negate e - X.Plus -> go2 $ \e1 e2 -> Y.BinOp Y.Add e1 e2 - X.Minus -> go2 $ \e1 e2 -> Y.BinOp Y.Sub e1 e2 - X.Mult -> go2 $ \e1 e2 -> Y.BinOp Y.Mul e1 e2 - X.FloorDiv -> go2 $ \e1 e2 -> Y.Call (Y.Function "jikka::floordiv" []) [e1, e2] - X.FloorMod -> go2 $ \e1 e2 -> Y.Call (Y.Function "jikka::floormod" []) [e1, e2] - X.CeilDiv -> go2 $ \e1 e2 -> Y.Call (Y.Function "jikka::ceildiv" []) [e1, e2] - X.CeilMod -> go2 $ \e1 e2 -> Y.Call (Y.Function "jikka::ceilmod" []) [e1, e2] - X.Pow -> go2 $ \e1 e2 -> Y.Call (Y.Function "jikka::notmod::pow" []) [e1, e2] + X.Negate -> go01 $ \e -> Y.UnOp Y.Negate e + X.Plus -> go02 $ \e1 e2 -> Y.BinOp Y.Add e1 e2 + X.Minus -> go02 $ \e1 e2 -> Y.BinOp Y.Sub e1 e2 + X.Mult -> go02 $ \e1 e2 -> Y.BinOp Y.Mul e1 e2 + X.FloorDiv -> go02 $ \e1 e2 -> Y.Call (Y.Function "jikka::floordiv" []) [e1, e2] + X.FloorMod -> go02 $ \e1 e2 -> Y.Call (Y.Function "jikka::floormod" []) [e1, e2] + X.CeilDiv -> go02 $ \e1 e2 -> Y.Call (Y.Function "jikka::ceildiv" []) [e1, e2] + X.CeilMod -> go02 $ \e1 e2 -> Y.Call (Y.Function "jikka::ceilmod" []) [e1, e2] + X.Pow -> go02 $ \e1 e2 -> Y.Call (Y.Function "jikka::notmod::pow" []) [e1, e2] -- advanced arithmetical functions - X.Abs -> go1 $ \e -> Y.Call (Y.Function "std::abs" []) [e] - X.Gcd -> go2 $ \e1 e2 -> Y.Call (Y.Function "std::gcd" []) [e1, e2] - X.Lcm -> go2 $ \e1 e2 -> Y.Call (Y.Function "std::lcm" []) [e1, e2] - X.Min2 t -> go2' $ \e1 e2 -> do - t <- runType t - return ([], Y.Call (Y.Function "std::min" [t]) [e1, e2]) - X.Max2 t -> go2' $ \e1 e2 -> do - t <- runType t - return ([], Y.Call (Y.Function "std::max" [t]) [e1, e2]) - X.Iterate t -> go3'' $ \n f x -> do + X.Abs -> go01 $ \e -> Y.Call (Y.Function "std::abs" []) [e] + X.Gcd -> go02 $ \e1 e2 -> Y.Call (Y.Function "std::gcd" []) [e1, e2] + X.Lcm -> go02 $ \e1 e2 -> Y.Call (Y.Function "std::lcm" []) [e1, e2] + X.Min2 -> go12 $ \t e1 e2 -> Y.Call (Y.Function "std::min" [t]) [e1, e2] + X.Max2 -> go12 $ \t e1 e2 -> Y.Call (Y.Function "std::max" [t]) [e1, e2] + X.Iterate -> go13'' $ \t n f x -> do t <- runType t (stmtsN, n) <- runExpr env n (stmtsX, x) <- runExpr env x @@ -183,11 +218,11 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI Y.Var y ) -- logical functions - X.Not -> go1 $ \e -> Y.UnOp Y.Not e - X.And -> go2 $ \e1 e2 -> Y.BinOp Y.And e1 e2 - X.Or -> go2 $ \e1 e2 -> Y.BinOp Y.Or e1 e2 - X.Implies -> go2 $ \e1 e2 -> Y.BinOp Y.Or (Y.UnOp Y.Not e1) e2 - X.If t -> go3'' $ \e1 e2 e3 -> do + X.Not -> go01 $ \e -> Y.UnOp Y.Not e + X.And -> go02 $ \e1 e2 -> Y.BinOp Y.And e1 e2 + X.Or -> go02 $ \e1 e2 -> Y.BinOp Y.Or e1 e2 + X.Implies -> go02 $ \e1 e2 -> Y.BinOp Y.Or (Y.UnOp Y.Not e1) e2 + X.If -> go13'' $ \t e1 e2 e3 -> do (stmts1, e1') <- runExpr env e1 (stmts2, e2') <- runExpr env e2 (stmts3, e3') <- runExpr env e3 @@ -201,35 +236,34 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI let assign = Y.Assign . Y.AssignExpr Y.SimpleAssign (Y.LeftVar phi) return ([Y.Declare t phi Y.DeclareDefault] ++ stmts1 ++ [Y.If e1' (stmts2 ++ [assign e2']) (Just (stmts3 ++ [assign e3']))], Y.Var phi) -- bitwise functions - X.BitNot -> go1 $ \e -> Y.UnOp Y.BitNot e - X.BitAnd -> go2 $ \e1 e2 -> Y.BinOp Y.BitAnd e1 e2 - X.BitOr -> go2 $ \e1 e2 -> Y.BinOp Y.BitOr e1 e2 - X.BitXor -> go2 $ \e1 e2 -> Y.BinOp Y.BitXor e1 e2 - X.BitLeftShift -> go2 $ \e1 e2 -> Y.BinOp Y.BitLeftShift e1 e2 - X.BitRightShift -> go2 $ \e1 e2 -> Y.BinOp Y.BitRightShift e1 e2 + X.BitNot -> go01 $ \e -> Y.UnOp Y.BitNot e + X.BitAnd -> go02 $ \e1 e2 -> Y.BinOp Y.BitAnd e1 e2 + X.BitOr -> go02 $ \e1 e2 -> Y.BinOp Y.BitOr e1 e2 + X.BitXor -> go02 $ \e1 e2 -> Y.BinOp Y.BitXor e1 e2 + X.BitLeftShift -> go02 $ \e1 e2 -> Y.BinOp Y.BitLeftShift e1 e2 + X.BitRightShift -> go02 $ \e1 e2 -> Y.BinOp Y.BitRightShift e1 e2 -- matrix functions - X.MatAp h w -> go2 $ \f x -> Y.Call (Y.Function "jikka::mat::ap" [Y.TyIntValue (fromIntegral h), Y.TyIntValue (fromIntegral w)]) [f, x] - X.MatZero n -> go0 $ Y.Call (Y.Function "jikka::mat::zero" [Y.TyIntValue (fromIntegral n)]) [] - X.MatOne n -> go0 $ Y.Call (Y.Function "jikka::mat::one" [Y.TyIntValue (fromIntegral n)]) [] - X.MatAdd h w -> go2 $ \f g -> Y.Call (Y.Function "jikka::mat::add" [Y.TyIntValue (fromIntegral h), Y.TyIntValue (fromIntegral w)]) [f, g] - X.MatMul h n w -> go2 $ \f g -> Y.Call (Y.Function "jikka::mat::mul" [Y.TyIntValue (fromIntegral h), Y.TyIntValue (fromIntegral n), Y.TyIntValue (fromIntegral w)]) [f, g] - X.MatPow n -> go2 $ \f k -> Y.Call (Y.Function "jikka::mat::pow" [Y.TyIntValue (fromIntegral n)]) [f, k] - X.VecFloorMod n -> go2 $ \x m -> Y.Call (Y.Function "jikka::modmat::floormod" [Y.TyIntValue (fromIntegral n)]) [x, m] - X.MatFloorMod h w -> go2 $ \f m -> Y.Call (Y.Function "jikka::modmat::floormod" [Y.TyIntValue (fromIntegral h), Y.TyIntValue (fromIntegral w)]) [f, m] + X.MatAp h w -> go02 $ \f x -> Y.Call (Y.Function "jikka::mat::ap" [Y.TyIntValue (fromIntegral h), Y.TyIntValue (fromIntegral w)]) [f, x] + X.MatZero n -> go00 $ Y.Call (Y.Function "jikka::mat::zero" [Y.TyIntValue (fromIntegral n)]) [] + X.MatOne n -> go00 $ Y.Call (Y.Function "jikka::mat::one" [Y.TyIntValue (fromIntegral n)]) [] + X.MatAdd h w -> go02 $ \f g -> Y.Call (Y.Function "jikka::mat::add" [Y.TyIntValue (fromIntegral h), Y.TyIntValue (fromIntegral w)]) [f, g] + X.MatMul h n w -> go02 $ \f g -> Y.Call (Y.Function "jikka::mat::mul" [Y.TyIntValue (fromIntegral h), Y.TyIntValue (fromIntegral n), Y.TyIntValue (fromIntegral w)]) [f, g] + X.MatPow n -> go02 $ \f k -> Y.Call (Y.Function "jikka::mat::pow" [Y.TyIntValue (fromIntegral n)]) [f, k] + X.VecFloorMod n -> go02 $ \x m -> Y.Call (Y.Function "jikka::modmat::floormod" [Y.TyIntValue (fromIntegral n)]) [x, m] + X.MatFloorMod h w -> go02 $ \f m -> Y.Call (Y.Function "jikka::modmat::floormod" [Y.TyIntValue (fromIntegral h), Y.TyIntValue (fromIntegral w)]) [f, m] -- modular functions - X.ModNegate -> go2 $ \e1 e2 -> Y.Call (Y.Function "jikka::mod::negate" []) [e1, e2] - X.ModPlus -> go3 $ \e1 e2 e3 -> Y.Call (Y.Function "jikka::mod::plus" []) [e1, e2, e3] - X.ModMinus -> go3 $ \e1 e2 e3 -> Y.Call (Y.Function "jikka::mod::minus" []) [e1, e2, e3] - X.ModMult -> go3 $ \e1 e2 e3 -> Y.Call (Y.Function "jikka::mod::mult" []) [e1, e2, e3] - X.ModInv -> go2 $ \e1 e2 -> Y.Call (Y.Function "jikka::mod::inv" []) [e1, e2] - X.ModPow -> go3 $ \e1 e2 e3 -> Y.Call (Y.Function "jikka::mod::pow" []) [e1, e2, e3] - X.ModMatAp h w -> go3 $ \f x m -> Y.Call (Y.Function "jikka::modmat::ap" [Y.TyIntValue (fromIntegral h), Y.TyIntValue (fromIntegral w)]) [f, x, m] - X.ModMatAdd h w -> go3 $ \f g m -> Y.Call (Y.Function "jikka::modmat::add" [Y.TyIntValue (fromIntegral h), Y.TyIntValue (fromIntegral w)]) [f, g, m] - X.ModMatMul h n w -> go3 $ \f g m -> Y.Call (Y.Function "jikka::modmat::mul" [Y.TyIntValue (fromIntegral h), Y.TyIntValue (fromIntegral n), Y.TyIntValue (fromIntegral w)]) [f, g, m] - X.ModMatPow n -> go3 $ \f k m -> Y.Call (Y.Function "jikka::modmat::pow" [Y.TyIntValue (fromIntegral n)]) [f, k, m] + X.ModNegate -> go02 $ \e1 e2 -> Y.Call (Y.Function "jikka::mod::negate" []) [e1, e2] + X.ModPlus -> go03 $ \e1 e2 e3 -> Y.Call (Y.Function "jikka::mod::plus" []) [e1, e2, e3] + X.ModMinus -> go03 $ \e1 e2 e3 -> Y.Call (Y.Function "jikka::mod::minus" []) [e1, e2, e3] + X.ModMult -> go03 $ \e1 e2 e3 -> Y.Call (Y.Function "jikka::mod::mult" []) [e1, e2, e3] + X.ModInv -> go02 $ \e1 e2 -> Y.Call (Y.Function "jikka::mod::inv" []) [e1, e2] + X.ModPow -> go03 $ \e1 e2 e3 -> Y.Call (Y.Function "jikka::mod::pow" []) [e1, e2, e3] + X.ModMatAp h w -> go03 $ \f x m -> Y.Call (Y.Function "jikka::modmat::ap" [Y.TyIntValue (fromIntegral h), Y.TyIntValue (fromIntegral w)]) [f, x, m] + X.ModMatAdd h w -> go03 $ \f g m -> Y.Call (Y.Function "jikka::modmat::add" [Y.TyIntValue (fromIntegral h), Y.TyIntValue (fromIntegral w)]) [f, g, m] + X.ModMatMul h n w -> go03 $ \f g m -> Y.Call (Y.Function "jikka::modmat::mul" [Y.TyIntValue (fromIntegral h), Y.TyIntValue (fromIntegral n), Y.TyIntValue (fromIntegral w)]) [f, g, m] + X.ModMatPow n -> go03 $ \f k m -> Y.Call (Y.Function "jikka::modmat::pow" [Y.TyIntValue (fromIntegral n)]) [f, k, m] -- list functions - X.Cons t -> go2' $ \x xs -> do - t <- runType t + X.Cons -> go12' $ \t x xs -> do ys <- Y.newFreshName Y.LocalNameKind return ( [ Y.Declare (Y.TyVector t) ys Y.DeclareDefault, @@ -238,8 +272,7 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI ], Y.Var ys ) - X.Snoc t -> go2' $ \xs x -> do - t <- runType t + X.Snoc -> go12' $ \t xs x -> do ys <- Y.newFreshName Y.LocalNameKind return ( [ Y.Declare (Y.TyVector t) ys (Y.DeclareCopy xs), @@ -247,7 +280,7 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI ], Y.Var ys ) - X.Foldl t1 t2 -> go3'' $ \f init xs -> do + X.Foldl -> go23'' $ \t1 t2 f init xs -> do (stmtsInit, init) <- runExpr env init (stmtsXs, xs) <- runExpr env xs t1 <- runType t1 @@ -267,7 +300,7 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI ], Y.Var y ) - X.Scanl _ t2 -> go3'' $ \f init xs -> do + X.Scanl -> go23'' $ \_ t2 f init xs -> do (stmtsInit, init) <- runExpr env init (stmtsXs, xs) <- runExpr env xs t2 <- runType t2 @@ -287,7 +320,7 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI ], Y.Var ys ) - X.Build t -> go3'' $ \f xs n -> do + X.Build -> go13'' $ \t f xs n -> do (stmtsInit, xs) <- runExpr env xs (stmtsXs, n) <- runExpr env n t <- runType t @@ -306,8 +339,8 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI ], Y.Var ys ) - X.Len _ -> go1 $ \e -> Y.cast Y.TyInt64 (Y.size e) - X.Map _ t2 -> go2'' $ \f xs -> do + X.Len -> go11 $ \_ e -> Y.cast Y.TyInt64 (Y.size e) + X.Map -> go22'' $ \_ t2 f xs -> do ys <- Y.newFreshName Y.LocalNameKind t2 <- runType t2 stmts <- case (f, xs) of @@ -331,7 +364,7 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI (body ++ [Y.assignAt ys (Y.Var i) f]) ] return (stmts, Y.Var ys) - X.Filter t -> go2'' $ \f xs -> do + X.Filter -> go12'' $ \t f xs -> do (stmtsXs, xs) <- runExpr env xs t <- runType t ys <- Y.newFreshName Y.LocalNameKind @@ -355,9 +388,8 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI ], Y.Var ys ) - X.At _ -> go2 $ \e1 e2 -> Y.at e1 e2 - X.SetAt t -> go3' $ \xs i x -> do - t <- runType t + X.At -> go12 $ \_ e1 e2 -> Y.at e1 e2 + X.SetAt -> go13' $ \t xs i x -> do ys <- Y.newFreshName Y.LocalNameKind return ( [ Y.Declare (Y.TyVector t) ys (Y.DeclareCopy xs), @@ -365,21 +397,21 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI ], Y.Var ys ) - X.Elem _ -> go2' $ \xs x -> do + X.Elem -> go12' $ \_ xs x -> do y <- Y.newFreshName Y.LocalNameKind return ( [ Y.Declare Y.TyBool y (Y.DeclareCopy (Y.BinOp Y.NotEqual (Y.callFunction "std::find" [] [Y.begin xs, Y.end xs, x]) (Y.end xs))) ], Y.Var y ) - X.Sum -> go1' $ \xs -> do + X.Sum -> go01' $ \xs -> do y <- Y.newFreshName Y.LocalNameKind return ( [ Y.Declare Y.TyInt64 y (Y.DeclareCopy (Y.callFunction "std::accumulate" [] [Y.begin xs, Y.end xs, Y.litInt64 0])) ], Y.Var y ) - X.ModSum -> go2' $ \xs m -> do + X.ModSum -> go02' $ \xs m -> do y <- Y.newFreshName Y.LocalNameKind x <- Y.newFreshName Y.LocalNameKind return @@ -392,7 +424,7 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI ], Y.callFunction "jikka::floormod" [] [Y.Var y, m] ) - X.Product -> go1' $ \xs -> do + X.Product -> go01' $ \xs -> do y <- Y.newFreshName Y.LocalNameKind x <- Y.newFreshName Y.LocalNameKind return @@ -405,7 +437,7 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI ], Y.Var y ) - X.ModProduct -> go2' $ \xs m -> do + X.ModProduct -> go02' $ \xs m -> do y <- Y.newFreshName Y.LocalNameKind x <- Y.newFreshName Y.LocalNameKind return @@ -418,54 +450,49 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI ], Y.Var y ) - X.Min1 t -> go1' $ \xs -> do - t <- runType t + X.Min1 -> go11' $ \t xs -> do y <- Y.newFreshName Y.LocalNameKind return ( [ Y.Declare t y (Y.DeclareCopy (Y.UnOp Y.Deref (Y.callFunction "std::min_element" [] [Y.begin xs, Y.end xs]))) ], Y.Var y ) - X.Max1 t -> go1' $ \xs -> do - t <- runType t + X.Max1 -> go11' $ \t xs -> do y <- Y.newFreshName Y.LocalNameKind return ( [ Y.Declare t y (Y.DeclareCopy (Y.UnOp Y.Deref (Y.callFunction "std::max_element" [] [Y.begin xs, Y.end xs]))) ], Y.Var y ) - X.ArgMin t -> go1' $ \xs -> do - t <- runType t + X.ArgMin -> go11' $ \t xs -> do y <- Y.newFreshName Y.LocalNameKind return ( [ Y.Declare t y (Y.DeclareCopy (Y.BinOp Y.Sub (Y.callFunction "std::min_element" [] [Y.begin xs, Y.end xs]) (Y.begin xs))) ], Y.Var y ) - X.ArgMax t -> go1' $ \xs -> do - t <- runType t + X.ArgMax -> go11' $ \t xs -> do y <- Y.newFreshName Y.LocalNameKind return ( [ Y.Declare t y (Y.DeclareCopy (Y.BinOp Y.Sub (Y.callFunction "std::max_element" [] [Y.begin xs, Y.end xs]) (Y.begin xs))) ], Y.Var y ) - X.All -> go1' $ \xs -> do + X.All -> go01' $ \xs -> do y <- Y.newFreshName Y.LocalNameKind return ( [ Y.Declare Y.TyBool y (Y.DeclareCopy (Y.BinOp Y.Equal (Y.callFunction "std::find" [] [Y.begin xs, Y.end xs, Y.Lit (Y.LitBool True)]) (Y.end xs))) ], Y.Var y ) - X.Any -> go1' $ \xs -> do + X.Any -> go01' $ \xs -> do y <- Y.newFreshName Y.LocalNameKind return ( [ Y.Declare Y.TyBool y (Y.DeclareCopy (Y.BinOp Y.NotEqual (Y.callFunction "std::find" [] [Y.begin xs, Y.end xs, Y.Lit (Y.LitBool False)]) (Y.end xs))) ], Y.Var y ) - X.Sorted t -> go1' $ \xs -> do - t <- runType t + X.Sorted -> go11' $ \t xs -> do ys <- Y.newFreshName Y.LocalNameKind return ( [ Y.Declare (Y.TyVector t) ys (Y.DeclareCopy xs), @@ -473,8 +500,7 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI ], Y.Var ys ) - X.Reversed t -> go1' $ \xs -> do - t <- runType t + X.Reversed -> go11' $ \t xs -> do ys <- Y.newFreshName Y.LocalNameKind return ( [ Y.Declare (Y.TyVector t) ys (Y.DeclareCopy xs), @@ -482,8 +508,8 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI ], Y.Var ys ) - X.Range1 -> go1 $ \n -> Y.Call Y.Range [n] - X.Range2 -> go2' $ \from to -> do + X.Range1 -> go01 $ \n -> Y.Call Y.Range [n] + X.Range2 -> go02' $ \from to -> do ys <- Y.newFreshName Y.LocalNameKind return ( [ Y.Declare (Y.TyVector Y.TyInt64) ys (Y.DeclareCopy (Y.vecCtor Y.TyInt64 [Y.BinOp Y.Sub to from])), @@ -491,7 +517,7 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI ], Y.Var ys ) - X.Range3 -> go3' $ \from to step -> do + X.Range3 -> go03' $ \from to step -> do ys <- Y.newFreshName Y.LocalNameKind i <- Y.newFreshName Y.LoopCounterNameKind return @@ -508,37 +534,33 @@ runAppBuiltin env f args = wrapError' ("converting builtin " ++ X.formatBuiltinI Y.Var ys ) -- tuple functions - X.Tuple ts -> goN' $ \es -> do - ts <- mapM runType ts - return $ - if Y.shouldBeArray ts - then Y.Call (Y.ArrayExt (head ts)) es - else Y.Call (Y.StdTuple ts) es - X.Proj ts n -> go1' $ \e -> do - ts <- mapM runType ts - return . ([],) $ - if Y.shouldBeArray ts - then Y.at e (Y.Lit (Y.LitInt32 (fromIntegral n))) - else Y.Call (Y.StdGet (toInteger n)) [e] + X.Tuple -> goNN $ \ts es -> + if Y.shouldBeArray ts + then Y.Call (Y.ArrayExt (head ts)) es + else Y.Call (Y.StdTuple ts) es + X.Proj n -> goN1 $ \ts e -> + if Y.shouldBeArray ts + then Y.at e (Y.Lit (Y.LitInt32 (fromIntegral n))) + else Y.Call (Y.StdGet (toInteger n)) [e] -- comparison - X.LessThan _ -> go2 $ \e1 e2 -> Y.BinOp Y.LessThan e1 e2 - X.LessEqual _ -> go2 $ \e1 e2 -> Y.BinOp Y.LessEqual e1 e2 - X.GreaterThan _ -> go2 $ \e1 e2 -> Y.BinOp Y.GreaterThan e1 e2 - X.GreaterEqual _ -> go2 $ \e1 e2 -> Y.BinOp Y.GreaterEqual e1 e2 - X.Equal _ -> go2 $ \e1 e2 -> Y.BinOp Y.Equal e1 e2 - X.NotEqual _ -> go2 $ \e1 e2 -> Y.BinOp Y.NotEqual e1 e2 + X.LessThan -> go12 $ \_ e1 e2 -> Y.BinOp Y.LessThan e1 e2 + X.LessEqual -> go12 $ \_ e1 e2 -> Y.BinOp Y.LessEqual e1 e2 + X.GreaterThan -> go12 $ \_ e1 e2 -> Y.BinOp Y.GreaterThan e1 e2 + X.GreaterEqual -> go12 $ \_ e1 e2 -> Y.BinOp Y.GreaterEqual e1 e2 + X.Equal -> go12 $ \_ e1 e2 -> Y.BinOp Y.Equal e1 e2 + X.NotEqual -> go12 $ \_ e1 e2 -> Y.BinOp Y.NotEqual e1 e2 -- combinational functions - X.Fact -> go1 $ \e -> Y.Call (Y.Function "jikka::notmod::fact" []) [e] - X.Choose -> go2 $ \e1 e2 -> Y.Call (Y.Function "jikka::notmod::choose" []) [e1, e2] - X.Permute -> go2 $ \e1 e2 -> Y.Call (Y.Function "jikka::notmod::permute" []) [e1, e2] - X.MultiChoose -> go2 $ \e1 e2 -> Y.Call (Y.Function "jikka::notmod::multichoose" []) [e1, e2] + X.Fact -> go01 $ \e -> Y.Call (Y.Function "jikka::notmod::fact" []) [e] + X.Choose -> go02 $ \e1 e2 -> Y.Call (Y.Function "jikka::notmod::choose" []) [e1, e2] + X.Permute -> go02 $ \e1 e2 -> Y.Call (Y.Function "jikka::notmod::permute" []) [e1, e2] + X.MultiChoose -> go02 $ \e1 e2 -> Y.Call (Y.Function "jikka::notmod::multichoose" []) [e1, e2] -- data structures - X.ConvexHullTrickInit -> go0 $ Y.Call Y.ConvexHullTrickCtor [] - X.ConvexHullTrickGetMin -> go2 $ \cht x -> Y.Call (Y.Method "get_min") [cht, x] - X.ConvexHullTrickInsert -> go3 $ \cht a b -> Y.Call Y.ConvexHullTrickCopyAddLine [cht, a, b] - X.SegmentTreeInitList semigrp -> go1 $ \a -> Y.Call (Y.SegmentTreeCtor (runSemigroup semigrp)) [a] - X.SegmentTreeGetRange _ -> go3 $ \segtree l r -> Y.Call (Y.Method "prod") [segtree, l, r] - X.SegmentTreeSetPoint semigrp -> go3 $ \segtree i a -> Y.Call (Y.SegmentTreeCopySetPoint (runSemigroup semigrp)) [segtree, i, a] + X.ConvexHullTrickInit -> go00 $ Y.Call Y.ConvexHullTrickCtor [] + X.ConvexHullTrickGetMin -> go02 $ \cht x -> Y.Call (Y.Method "get_min") [cht, x] + X.ConvexHullTrickInsert -> go03 $ \cht a b -> Y.Call Y.ConvexHullTrickCopyAddLine [cht, a, b] + X.SegmentTreeInitList semigrp -> go01 $ \a -> Y.Call (Y.SegmentTreeCtor (runSemigroup semigrp)) [a] + X.SegmentTreeGetRange _ -> go03 $ \segtree l r -> Y.Call (Y.Method "prod") [segtree, l, r] + X.SegmentTreeSetPoint semigrp -> go03 $ \segtree i a -> Y.Call (Y.SegmentTreeCopySetPoint (runSemigroup semigrp)) [segtree, i, a] runExprFunction :: (MonadAlpha m, MonadError Error m) => Env -> X.Expr -> Y.Expr -> m ([Y.Statement], [Y.Statement], Y.Expr) runExprFunction env f e = case f of @@ -576,24 +598,24 @@ runExpr env = \case e@(X.App _ _) -> do let (f, args) = X.curryApp e case f of - X.Lit (X.LitBuiltin builtin) -> do - let arity = arityOfBuiltin builtin + X.Lit (X.LitBuiltin builtin bts) -> do + arity <- arityOfBuiltin builtin bts if length args < arity then do - let (ts, ret) = X.uncurryFunTy (X.builtinToType builtin) + (ts, ret) <- X.uncurryFunTy <$> X.builtinToType builtin bts ts <- mapM runType ts ret <- runType ret xs <- replicateM (arity - length args) X.genVarName' ys <- mapM (renameVarName' Y.LocalArgumentNameKind) xs - (stmts, e) <- runAppBuiltin env builtin (args ++ map X.Var xs) + (stmts, e) <- runAppBuiltin env builtin bts (args ++ map X.Var xs) let (_, e') = foldr (\(t, y) (ret, e) -> (Y.TyFunction ret [t], Y.Lam [(t, y)] ret [Y.Return e])) (ret, e) (zip (drop (length args) ts) ys) return (stmts, e') else if length args == arity then do - runAppBuiltin env builtin args + runAppBuiltin env builtin bts args else do - (stmts, e) <- runAppBuiltin env builtin (take arity args) + (stmts, e) <- runAppBuiltin env builtin bts (take arity args) args <- mapM (runExpr env) (drop arity args) return (concatMap fst args ++ stmts, Y.CallExpr e (map snd args)) _ -> do diff --git a/src/Jikka/Common/Alpha.hs b/src/Jikka/Common/Alpha.hs index e1458c5d..8b172f3a 100644 --- a/src/Jikka/Common/Alpha.hs +++ b/src/Jikka/Common/Alpha.hs @@ -22,6 +22,8 @@ import Control.Monad.Reader import Control.Monad.Signatures import Control.Monad.State.Strict import Control.Monad.Writer.Strict +import Data.Unique +import Language.Haskell.TH (Q) class Monad m => MonadAlpha m where nextCounter :: m Int @@ -84,3 +86,9 @@ evalAlpha f i = runIdentity (evalAlphaT f i) resetAlphaT :: Monad m => Int -> AlphaT m () resetAlphaT i = AlphaT $ \_ -> return ((), i) + +instance MonadAlpha IO where + nextCounter = hashUnique <$> newUnique + +instance MonadAlpha Q where + nextCounter = liftIO nextCounter diff --git a/src/Jikka/Core/Convert/ANormal.hs b/src/Jikka/Core/Convert/ANormal.hs index 1d818df2..84d6a269 100644 --- a/src/Jikka/Core/Convert/ANormal.hs +++ b/src/Jikka/Core/Convert/ANormal.hs @@ -60,7 +60,7 @@ runExpr env = \case f <- runExpr env f args <- mapM (runExpr env) args case (f, args) of - (Lit (LitBuiltin (If _)), [e1, e2, e3]) -> do + (Lit (LitBuiltin If _), [e1, e2, e3]) -> do (_, ctx, e1) <- destruct env e1 return $ ctx (App3 f e1 e2 e3) _ -> runApp env f args diff --git a/src/Jikka/Core/Convert/CumulativeSum.hs b/src/Jikka/Core/Convert/CumulativeSum.hs index a8f13c4f..acc6381e 100644 --- a/src/Jikka/Core/Convert/CumulativeSum.hs +++ b/src/Jikka/Core/Convert/CumulativeSum.hs @@ -49,7 +49,7 @@ rule = RewriteRule $ \_ -> \case then At' IntTy (Var b) n else Minus' (At' IntTy (Var b) (Plus' n (formatArithmeticalExpr shift))) (At' IntTy (Var b) (formatArithmeticalExpr shift)) return . Just $ - Let b (ListTy IntTy) (Scanl' IntTy IntTy (Lit (LitBuiltin Plus)) Lit0 a) e + Let b (ListTy IntTy) (Scanl' IntTy IntTy (Builtin Plus) Lit0 a) e _ -> return Nothing Max1' t (Cons' _ a0 (Map' _ _ (Lam x _ (At' _ a (Var x'))) (Range1' n))) | x' == x && x `isUnusedVar` a -> do Just <$> cumulativeMax (Max2' t) t (Just a0) a n diff --git a/src/Jikka/Core/Convert/MakeScanl.hs b/src/Jikka/Core/Convert/MakeScanl.hs index 62273fc6..84b49f60 100644 --- a/src/Jikka/Core/Convert/MakeScanl.hs +++ b/src/Jikka/Core/Convert/MakeScanl.hs @@ -31,6 +31,7 @@ module Jikka.Core.Convert.MakeScanl where import Control.Monad.Trans.Maybe +import Data.List import qualified Data.Map as M import Jikka.Common.Alpha import Jikka.Common.Error @@ -56,7 +57,7 @@ reduceScanlBuild = simpleRewriteRule $ \case _ -> Nothing -- | `getRecurrenceFormulaStep1` removes `At` in @body@. -getRecurrenceFormulaStep1 :: MonadAlpha m => Int -> Type -> VarName -> VarName -> Expr -> m (Maybe Expr) +getRecurrenceFormulaStep1 :: MonadAlpha m => Integer -> Type -> VarName -> VarName -> Expr -> m (Maybe Expr) getRecurrenceFormulaStep1 shift t a i body = do x <- genVarName a let proj k = @@ -78,13 +79,13 @@ getRecurrenceFormulaStep1 shift t a i body = do Nothing -> Nothing -- | `getRecurrenceFormulaStep` replaces `At` in @body@ with `Proj`. -getRecurrenceFormulaStep :: MonadAlpha m => Int -> Int -> Type -> VarName -> VarName -> Expr -> m (Maybe Expr) +getRecurrenceFormulaStep :: MonadAlpha m => Integer -> Integer -> Type -> VarName -> VarName -> Expr -> m (Maybe Expr) getRecurrenceFormulaStep shift size t a i body = do x <- genVarName a - let ts = replicate size t + let ts = replicate (fromInteger size) t let proj k = if 0 <= toInteger shift + k && toInteger shift + k < toInteger size - then Just $ Proj' ts (shift + fromInteger k) (Var x) + then Just $ Proj' ts (shift + k) (Var x) else Nothing let go :: Expr -> Maybe Expr go = \case @@ -129,9 +130,9 @@ reduceFoldlSetAtRecurrence = RewriteRule $ \_ -> \case _ -> do let ts = replicate (length base) t2 let base' = uncurryApp (Tuple' ts) base - step <- MaybeT $ getRecurrenceFormulaStep (- length base + fromInteger k) (length base) t2 a i step + step <- MaybeT $ getRecurrenceFormulaStep (- genericLength base + k) (genericLength base) t2 a i step x <- lift (genVarName a) - return $ foldr (Cons' t2) (Map' (TupleTy ts) t2 (Lam x (TupleTy ts) (Proj' ts (length base - 1) (Var x))) (Scanl' IntTy (TupleTy ts) step base' (Range1' n))) (init base) + return $ foldr (Cons' t2) (Map' (TupleTy ts) t2 (Lam x (TupleTy ts) (Proj' ts (genericLength base - 1) (Var x))) (Scanl' IntTy (TupleTy ts) step base' (Range1' n))) (init base) _ -> return Nothing -- | `checkAccumulationFormulaStep` checks that all `At` in @body@ about @a@ are @At a i@. diff --git a/src/Jikka/Core/Convert/MatrixExponentiation.hs b/src/Jikka/Core/Convert/MatrixExponentiation.hs index 0e24cdc6..e158607f 100644 --- a/src/Jikka/Core/Convert/MatrixExponentiation.hs +++ b/src/Jikka/Core/Convert/MatrixExponentiation.hs @@ -16,6 +16,7 @@ where import Control.Monad.Trans import Control.Monad.Trans.Maybe +import Data.List import qualified Data.Vector as V import Jikka.Common.Alpha import Jikka.Common.Error @@ -52,13 +53,13 @@ fromAffineMatrix a b = bottom = uncurryApp (Tuple' (replicate (w + 1) IntTy)) (replicate w (LitInt' 0) ++ [LitInt' 1]) in uncurryApp (Tuple' (replicate (h + 1) (TupleTy (replicate (w + 1) IntTy)))) (V.toList (V.zipWith go (unMatrix a) b) ++ [bottom]) -toMatrix :: MonadAlpha m => [(VarName, Type)] -> VarName -> Int -> Expr -> m (Maybe (Matrix ArithmeticalExpr, Maybe (V.Vector ArithmeticalExpr))) +toMatrix :: MonadAlpha m => [(VarName, Type)] -> VarName -> Integer -> Expr -> m (Maybe (Matrix ArithmeticalExpr, Maybe (V.Vector ArithmeticalExpr))) toMatrix env x n step = case curryApp step of (Tuple' _, es) -> runMaybeT $ do - xs <- V.fromList <$> replicateM n (lift (genVarName x)) + xs <- V.fromList <$> replicateM (fromInteger n) (lift (genVarName x)) let unpackTuple _ e = case e of - Proj' _ i (Var x') | x' == x -> Var (xs V.! i) + Proj' _ i (Var x') | x' == x -> Var (xs V.! fromInteger i) _ -> e rows <- MaybeT . return . forM es $ \e -> do let e' = mapExpr unpackTuple env e @@ -69,14 +70,14 @@ toMatrix env x n step = return (a, b) _ -> return Nothing -addOneToVector :: Int -> VarName -> Expr +addOneToVector :: Integer -> VarName -> Expr addOneToVector n x = - let ts = replicate n IntTy + let ts = replicate (fromInteger n) IntTy in uncurryApp (Tuple' (IntTy : ts)) (map (\i -> Proj' ts i (Var x)) [0 .. n - 1] ++ [LitInt' 1]) -removeOneFromVector :: Int -> VarName -> Expr +removeOneFromVector :: Integer -> VarName -> Expr removeOneFromVector n x = - let ts = replicate n IntTy + let ts = replicate (fromInteger n) IntTy in uncurryApp (Tuple' ts) (map (\i -> Proj' (IntTy : ts) i (Var x)) [0 .. n - 1]) rule :: MonadAlpha m => RewriteRule m @@ -93,7 +94,7 @@ rule = RewriteRule $ \env -> \case b' = Mult' (FloorDiv' (Minus' (Pow' a k) (LitInt' 1)) (Minus' a (LitInt' 1))) b -- This division has no remainder. in Just $ Plus' (Mult' a' base) b' Iterate' (TupleTy ts) k (Lam x _ step) base | isVectorTy' ts -> do - let n = length ts + let n = genericLength ts let go n step base = MatAp' n n (MatPow' n step k) base step <- toMatrix env x n step case step of diff --git a/src/Jikka/Core/Convert/PropagateMod.hs b/src/Jikka/Core/Convert/PropagateMod.hs index 31e1f418..d126d059 100644 --- a/src/Jikka/Core/Convert/PropagateMod.hs +++ b/src/Jikka/Core/Convert/PropagateMod.hs @@ -14,6 +14,7 @@ module Jikka.Core.Convert.PropagateMod ) where +import Data.List import Data.Maybe import Jikka.Common.Alpha import Jikka.Common.Error @@ -81,11 +82,11 @@ putFloorMod (Mod m) = LitInt' n -> case m of LitInt' m -> return' $ LitInt' (n `mod` m) _ -> return Nothing - Proj' ts i e | isVectorTy' ts -> return' $ Proj' ts i (VecFloorMod' (length ts) e m) + Proj' ts i e | isVectorTy' ts -> return' $ Proj' ts i (VecFloorMod' (genericLength ts) e m) Proj' ts i e | isMatrixTy' ts -> let (h, w) = fromJust (sizeOfMatrixTy (TupleTy ts)) - in return' $ Proj' ts i (MatFloorMod' h w e m) + in return' $ Proj' ts i (MatFloorMod' (toInteger h) (toInteger w) e m) Map' t1 t2 f xs -> do f <- putFloorMod (Mod m) f case f of @@ -144,7 +145,7 @@ putVecFloorMod env = putFloorModGeneric fallback fallback e (Mod m) = do t <- typecheckExpr env e case t of - TupleTy ts -> return $ VecFloorMod' (length ts) e m + TupleTy ts -> return $ VecFloorMod' (genericLength ts) e m _ -> throwInternalError $ "not a vector: " ++ formatType t putMatFloorMod :: (MonadError Error m, MonadAlpha m) => [(VarName, Type)] -> Mod -> Expr -> m Expr @@ -153,7 +154,7 @@ putMatFloorMod env = putFloorModGeneric fallback fallback e (Mod m) = do t <- typecheckExpr env e case t of - TupleTy ts@(TupleTy ts' : _) -> return $ MatFloorMod' (length ts) (length ts') e m + TupleTy ts@(TupleTy ts' : _) -> return $ MatFloorMod' (genericLength ts) (genericLength ts') e m _ -> throwInternalError $ "not a matrix: " ++ formatType t rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m diff --git a/src/Jikka/Core/Convert/SegmentTree.hs b/src/Jikka/Core/Convert/SegmentTree.hs index 3fed3859..37d2bdc0 100644 --- a/src/Jikka/Core/Convert/SegmentTree.hs +++ b/src/Jikka/Core/Convert/SegmentTree.hs @@ -69,26 +69,26 @@ pattern CumulativeSumFlip t f e es <- x2 = findUnusedVarName (VarName "x") f in Scanl' t t (Lam2 x1 t x2 t (App (App f (Var x2)) (Var x1))) e es -builtinToSemigroup :: Builtin -> Maybe Semigroup' -builtinToSemigroup = \case - Plus -> Just SemigroupIntPlus - Min2 IntTy -> Just SemigroupIntMin - Max2 IntTy -> Just SemigroupIntMax +builtinToSemigroup :: Builtin -> [Type] -> Maybe Semigroup' +builtinToSemigroup builtin ts = case (builtin, ts) of + (Plus, []) -> Just SemigroupIntPlus + (Min2, [IntTy]) -> Just SemigroupIntMin + (Max2, [IntTy]) -> Just SemigroupIntMax _ -> Nothing -semigroupToBuiltin :: Semigroup' -> Builtin +semigroupToBuiltin :: Semigroup' -> (Builtin, [Type]) semigroupToBuiltin = \case - SemigroupIntPlus -> Plus - SemigroupIntMin -> Min2 IntTy - SemigroupIntMax -> Max2 IntTy + SemigroupIntPlus -> (Plus, []) + SemigroupIntMin -> (Min2, [IntTy]) + SemigroupIntMax -> (Max2, [IntTy]) unCumulativeSum :: Expr -> Expr -> Maybe (Semigroup', Expr) unCumulativeSum a = \case - CumulativeSum _ (Lit (LitBuiltin op)) b a' | a' == a -> case builtinToSemigroup op of + CumulativeSum _ (Lit (LitBuiltin op ts)) b a' | a' == a -> case builtinToSemigroup op ts of Just semigrp -> Just (semigrp, b) Nothing -> Nothing -- Semigroups must be commutative to use CumulativeSumFlip. - CumulativeSumFlip _ (Lit (LitBuiltin op)) b a' | a' == a -> case builtinToSemigroup op of + CumulativeSumFlip _ (Lit (LitBuiltin op ts)) b a' | a' == a -> case builtinToSemigroup op ts of Just semigrp -> Just (semigrp, b) Nothing -> Nothing _ -> Nothing @@ -103,7 +103,7 @@ replaceWithSegtrees a segtrees = go M.empty go env = \case At' _ (check env -> Just (e, b, semigrp)) i -> let e' = SegmentTreeGetRange' semigrp e (LitInt' 0) i - in AppBuiltin2 (semigroupToBuiltin semigrp) b e' + in App2 (Lit (uncurry LitBuiltin (semigroupToBuiltin semigrp))) b e' Var x -> Var x Lit lit -> Lit lit App e1 e2 -> App (go env e1) (go env e2) @@ -113,10 +113,11 @@ replaceWithSegtrees a segtrees = go M.empty in case check env e1' of Just (e1', b, semigrp) -> go (M.insert x (e1', b, semigrp) env) e2 Nothing -> Let x t (go env e1) (go env e2) + check :: M.Map VarName (Expr, Expr, Semigroup') -> Expr -> Maybe (Expr, Expr, Semigroup') check env = \case Var x -> M.lookup x env - CumulativeSum _ (Lit (LitBuiltin op)) b (Var a') | a' == a -> case lookup op (map (first semigroupToBuiltin) segtrees) of - Just e -> Just (e, b, fromJust (builtinToSemigroup op)) + CumulativeSum _ (Lit (LitBuiltin op ts)) b (Var a') | a' == a -> case lookup (op, ts) (map (first semigroupToBuiltin) segtrees) of + Just e -> Just (e, b, fromJust (builtinToSemigroup op ts)) Nothing -> Nothing _ -> Nothing diff --git a/src/Jikka/Core/Convert/ShortCutFusion.hs b/src/Jikka/Core/Convert/ShortCutFusion.hs index 4992bfe5..695ba8b9 100644 --- a/src/Jikka/Core/Convert/ShortCutFusion.hs +++ b/src/Jikka/Core/Convert/ShortCutFusion.hs @@ -1,5 +1,7 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE ViewPatterns #-} -- | -- Module : Jikka.Core.Convert.ShortCutFusion @@ -36,8 +38,8 @@ import Jikka.Core.Format (formatExpr) import Jikka.Core.Language.BuiltinPatterns import Jikka.Core.Language.Expr import Jikka.Core.Language.FreeVars -import Jikka.Core.Language.LambdaPatterns import Jikka.Core.Language.Lint +import Jikka.Core.Language.QuasiRules import Jikka.Core.Language.RewriteRules import Jikka.Core.Language.Util @@ -48,52 +50,30 @@ import Jikka.Core.Language.Util -- * `Nil` and `Cons` are kept as is. reduceBuild :: MonadAlpha m => RewriteRule m reduceBuild = - let return' = return . Just - in RewriteRule $ \_ -> \case - Range2' l r -> do - let n = Minus' r l - x <- genVarName' - let f = Lam x IntTy (Plus' l (Var x)) - return' $ Map' IntTy IntTy f (Range1' n) - Range3' l r step -> do - let n = CeilDiv' (Minus' r l) step - x <- genVarName' - let f = Lam x IntTy (Plus' l (Mult' step (Var x))) - return' $ Map' IntTy IntTy f (Range1' n) - _ -> return Nothing + mconcat + [ [r| "range2" forall l r. range2 l r = map (fun i -> l + i) (range (r - l)) |], + [r| "range3" forall l r step. range3 l r step = map (fun i -> l + i * step) (range ((r - l) /^ step)) |] + ] reduceMapBuild :: MonadAlpha m => RewriteRule m reduceMapBuild = - let return' = return . Just - in RewriteRule $ \_ -> \case - -- reduce `Sorted` - Sorted' _ (Nil' t) -> return' $ Nil' t - Sorted' _ (Range1' n) -> return' $ Range1' n - -- reduce `Reversed` - Reversed' _ (Nil' t) -> return' $ Nil' t - Reversed' _ (Range1' n) -> do - x <- genVarName' - let f = Lam x IntTy (Minus' (Minus' n (Var x)) (LitInt' 1)) - return' $ Map' IntTy IntTy f n - -- reduce `Filter` - Filter' _ _ (Nil' t) -> return' $ Nil' t - -- reduce `Map` - Map' _ _ _ (Nil' t) -> return' $ Nil' t - Map' t1 t2 f (Cons' _ x xs) -> return' $ Cons' t2 (App f x) (Map' t1 t2 f xs) - -- others - _ -> return Nothing + mconcat + [ [r| "sorted/nil" sorted nil = nil |], + [r| "sorted/range" forall n. sorted (range n) = range n |], + [r| "reversed/nil" reversed nil = nil |], + [r| "reversed/range" forall n. reversed (range n) = map (fun i -> n - i - 1) (range n) |], + [r| "filter/nil" filter _ nil = nil |], + [r| "map/nil" map _ nil = nil |], + [r| "map/cons" forall f x xs. map f (cons x xs) = cons (f x) (map f xs) |] + ] reduceMap :: Monad m => RewriteRule m reduceMap = - let return' = return . Just - in RewriteRule $ \_ -> \case - -- reduce `Map` - Map' _ _ (LamId _) xs -> return' xs - -- reduce `Filter` - Filter' t (Lam _ _ LitFalse) _ -> return' (Nil' t) - Filter' _ (Lam _ _ LitTrue) xs -> return' xs - -- others - _ -> return Nothing + mconcat + [ [r| "map/id" forall xs. map (fun x -> x) xs = xs |], + [r| "filter/const-false" forall xs. filter (fun _ -> false) xs = nil |], + [r| "filter/const-true" forall xs. filter (fun _ -> true) xs = xs |] + ] -- | -- * Functions are reordered as: @@ -102,57 +82,36 @@ reduceMap = -- * `Filter` (funcitons to reduce lengths) is firstly applied to lists reduceMapMap :: MonadAlpha m => RewriteRule m reduceMapMap = - let return' = return . Just - in RewriteRule $ \_ -> \case - -- reduce `Map` - Map' _ _ (LamId _) xs -> return' xs - Map' _ t3 g (Map' t1 _ f xs) -> do - x <- genVarName' - let h = Lam x t1 (App g (App f (Var x))) - return' $ Map' t1 t3 h xs - Map' t1 t2 f (Reversed' _ xs) -> return' $ Reversed' t2 (Map' t1 t2 f xs) - -- reduce `Filter` - Filter' t2 g (Map' t1 _ f xs) -> do - x <- genVarName' - let h = Lam x t1 (App g (App f (Var x))) - return' $ Map' t1 t2 f (Filter' t1 h xs) - Filter' t g (Filter' _ f xs) -> do - x <- genVarName' - let h = Lam x t (And' (App g (Var x)) (App f (Var x))) - return' $ Filter' t h xs - Filter' t f (Sorted' _ xs) -> return' $ Sorted' t (Filter' t f xs) - Filter' t f (Reversed' _ xs) -> return' $ Reversed' t (Filter' t f xs) - -- reduce `Reversed` - Reversed' _ (Reversed' _ xs) -> return' xs - Reversed' _ (Map' t1 t2 f xs) -> return' $ Map' t1 t2 f (Reversed' t1 xs) - -- reduce `Sorted` - Sorted' t (Reversed' _ xs) -> return' $ Sorted' t xs - Sorted' t (Sorted' _ xs) -> return' $ Sorted' t xs - _ -> return Nothing + mconcat + [ [r| "map/map" forall f g xs. map g (map f xs) = map (fun x -> g (f x)) xs |], + [r| "map/reversed" forall f xs. map f (reversed xs) = reversed (map f xs) |], + [r| "filter/filter" forall f g xs. filter g (filter f xs) = filter (fun x -> f x and g x) xs |], + [r| "filter/sorted" forall f xs. filter f (sorted xs) = sorted (filter f xs) |], + [r| "filter/reversed" forall f xs. filter f (reversed xs) = reversed (filter f xs) |], + [r| "reversed/reversed" forall xs. reversed (reversed xs) = xs |], + [r| "sorted/reversed" forall xs. sorted (reversed xs) = sorted xs |], + [r| "sorted/sorted" forall xs. sorted (sorted xs) = sorted xs |] + ] reduceFoldMap :: MonadAlpha m => RewriteRule m reduceFoldMap = - let return' = return . Just - in RewriteRule $ \_ -> \case - -- reduce `Reversed` - Len' t (Reversed' _ xs) -> return' $ Len' t xs - Elem' t x (Reversed' _ xs) -> return' $ Elem' t x xs - At' t (Reversed' _ xs) i -> return' $ At' t xs (Minus' (Minus' (Len' t xs) i) Lit1) - -- reduce `Sorted` - Len' t (Sorted' _ xs) -> return' $ Len' t xs - Elem' t x (Sorted' _ xs) -> return' $ Elem' t x xs - -- reduce `Map` - Len' _ (Map' t1 _ _ xs) -> return' $ Len' t1 xs - At' _ (Map' t1 _ f xs) i -> return' $ App f (At' t1 xs i) - Foldl' _ t3 g init (Map' t1 _ f xs) -> do - x3 <- genVarName' - x1 <- genVarName' - return' $ Foldl' t1 t3 (Lam2 x3 t3 x1 t1 (App2 g (Var x3) (App f (Var x1)))) init xs - -- others - Len' t (SetAt' _ xs _ _) -> return' $ Len' t xs - Len' t (Scanl' _ _ _ _ xs) -> return' $ Plus' (Len' t xs) (LitInt' 1) - At' t (SetAt' _ xs i' x) i -> return' $ If' t (Equal' IntTy i' i) x (At' t xs i) - _ -> return Nothing + mconcat + [ -- reduce `Reversed` + [r| "len/reversed" forall xs. len (reversed xs) = len xs |], + [r| "elem/reversed" forall x xs. elem x (reversed xs) = elem x xs |], + [r| "at/reversed" forall xs i. (reversed xs)[i] = xs[len(xs) - i - 1] |], + -- reduce `Sorted` + [r| "len/sorted" forall xs. len (sorted xs) = len xs |], + [r| "elem/sorted" forall x xs. elem x (sorted xs) = elem x xs |], + -- reduce `Map` + [r| "len/map" forall f xs. len (map f xs) = len xs |], + [r| "at/map" forall f xs i. (map f xs)[i] = f xs[i] |], + [r| "foldl/map" forall g init f xs. foldl g init (map f xs) = foldl (fun y x -> g y (f x)) init xs|], + -- others + [r| "len/setat" forall xs i x. len xs[i <- x] = len xs |], + [r| "len/scanl" forall f init xs. len (scanl f init xs) = len xs + 1 |], + [r| "at/setat" forall xs i x j. xs[i <- x][j] = if i == j then x else xs[j] |] + ] reduceFold :: Monad m => RewriteRule m reduceFold = simpleRewriteRule $ \case @@ -161,26 +120,27 @@ reduceFold = simpleRewriteRule $ \case reduceFoldBuild :: MonadAlpha m => RewriteRule m reduceFoldBuild = - let return' = return . Just - in RewriteRule $ \_ -> \case - -- reduce `Foldl` - Foldl' _ _ _ init (Nil' _) -> return' init - Foldl' t1 t2 g init (Cons' _ x xs) -> return' $ Foldl' t1 t2 g (App2 g init x) xs - -- reduce `Len` - Len' _ (Nil' _) -> return' Lit0 - Len' t (Cons' _ _ xs) -> return' $ Plus' Lit1 (Len' t xs) - Len' _ (Range1' n) -> return' n - -- reduce `At` - At' t (Nil' _) i -> return' $ Bottom' t $ "cannot subscript empty list: index = " ++ formatExpr i - At' t (Cons' _ x xs) i -> return' $ If' t (Equal' IntTy i Lit0) x (At' t xs (Minus' i Lit1)) - At' _ (Range1' _) i -> return' i - -- reduce `Elem` - Elem' _ _ (Nil' _) -> return' LitFalse - Elem' t y (Cons' _ x xs) -> return' $ And' (Equal' t x y) (Elem' t y xs) - Elem' _ x (Range1' n) -> return' $ And' (LessEqual' IntTy Lit0 x) (LessThan' IntTy x n) - -- others - Len' t (Build' _ _ base n) -> return' $ Plus' (Len' t base) n - _ -> return Nothing + mconcat + [ -- reduce `Foldl` + [r| "foldl/nil" forall f init. foldl f init nil = init |], + [r| "foldl/cons" forall f init x xs. foldl f init (cons x xs) = foldl f (f init x) xs |], + -- reduce `Len` + [r| "len/nil" len nil = 0 |], + [r| "len/cons" forall x xs. len (cons x xs) = 1 + len xs |], + [r| "len/range" forall n. len (range n) = n |], + -- reduce `At` + simpleRewriteRule $ \case + At' t (Nil' _) i -> Just $ Bottom' t $ "cannot subscript empty list: index = " ++ formatExpr i + _ -> Nothing, + [r| "at/cons" forall x xs i. (cons x xs)[i] = if i == 0 then x else xs[i - 1] |], + [r| "at/range" forall n i. (range n)[i] = i |], + -- reduce `Elem` + [r| "elem/nil" forall y. elem y nil = false |], + [r| "elem/cons" forall y x xs. elem y (cons x xs) = y == x or elem y xs |], + [r| "elem/range" forall i n. elem i (range n) = 0 <= i and i < n |], + -- others + [r| "len/build" forall f base n. len (build f base n) = len base + n |] + ] rule :: MonadAlpha m => RewriteRule m rule = diff --git a/src/Jikka/Core/Convert/TypeInfer.hs b/src/Jikka/Core/Convert/TypeInfer.hs index 03d9d1a1..e268f550 100644 --- a/src/Jikka/Core/Convert/TypeInfer.hs +++ b/src/Jikka/Core/Convert/TypeInfer.hs @@ -11,6 +11,8 @@ -- Portability : portable module Jikka.Core.Convert.TypeInfer ( run, + runExpr, + runRule, -- * internal types and functions Equation (..), @@ -35,7 +37,7 @@ import Jikka.Core.Format (formatType) import Jikka.Core.Language.Expr import Jikka.Core.Language.FreeVars import Jikka.Core.Language.Lint -import Jikka.Core.Language.TypeCheck (literalToType, typecheckProgram) +import Jikka.Core.Language.TypeCheck (literalToType, typecheckExpr, typecheckProgram) import Jikka.Core.Language.Util data Equation @@ -51,13 +53,13 @@ formularizeType t1 t2 = tell $ Dual [TypeEquation t1 t2] formularizeVarName :: MonadWriter Eqns m => VarName -> Type -> m () formularizeVarName x t = tell $ Dual [TypeAssertion x t] -formularizeExpr :: (MonadWriter Eqns m, MonadAlpha m) => Expr -> m Type +formularizeExpr :: (MonadWriter Eqns m, MonadAlpha m, MonadError Error m) => Expr -> m Type formularizeExpr = \case Var x -> do t <- genType formularizeVarName x t return t - Lit lit -> return $ literalToType lit + Lit lit -> literalToType lit App f e -> do ret <- genType t <- formularizeExpr e @@ -72,12 +74,12 @@ formularizeExpr = \case formularizeExpr' e1 t formularizeExpr e2 -formularizeExpr' :: (MonadWriter Eqns m, MonadAlpha m) => Expr -> Type -> m () +formularizeExpr' :: (MonadWriter Eqns m, MonadAlpha m, MonadError Error m) => Expr -> Type -> m () formularizeExpr' e t = do t' <- formularizeExpr e formularizeType t t' -formularizeToplevelExpr :: (MonadWriter Eqns m, MonadAlpha m) => ToplevelExpr -> m Type +formularizeToplevelExpr :: (MonadWriter Eqns m, MonadAlpha m, MonadError Error m) => ToplevelExpr -> m Type formularizeToplevelExpr = \case ResultExpr e -> formularizeExpr e ToplevelLet x t e cont -> do @@ -90,7 +92,7 @@ formularizeToplevelExpr = \case formularizeExpr' body ret formularizeToplevelExpr cont -formularizeProgram :: MonadAlpha m => Program -> m [Equation] +formularizeProgram :: (MonadAlpha m, MonadError Error m) => Program -> m [Equation] formularizeProgram prog = getDual <$> execWriterT (formularizeToplevelExpr prog) sortEquations :: [Equation] -> ([(Type, Type)], [(VarName, Type)]) @@ -169,39 +171,14 @@ substUnit = \case FunTy t ret -> FunTy (substUnit t) (substUnit ret) DataStructureTy ds -> DataStructureTy ds --- | `subst'` does `subst` and replaces all undetermined type variables with the unit type. subst' :: Subst -> Type -> Type subst' sigma = substUnit . subst sigma -substBuiltin :: Subst -> Builtin -> Builtin -substBuiltin sigma = mapTypeInBuiltin (subst' sigma) - -substLiteral :: Subst -> Literal -> Literal -substLiteral sigma = \case - LitBuiltin builtin -> LitBuiltin (substBuiltin sigma builtin) - LitInt n -> LitInt n - LitBool p -> LitBool p - LitNil t -> LitNil (subst' sigma t) - LitBottom t err -> LitBottom (subst' sigma t) err +substProgram :: Subst -> Program -> Program +substProgram sigma = mapTypeProgram (subst' sigma) substExpr :: Subst -> Expr -> Expr -substExpr sigma = go - where - go = \case - Var x -> Var x - Lit lit -> Lit (substLiteral sigma lit) - App f e -> App (go f) (go e) - Lam x t body -> Lam x (subst' sigma t) (go body) - Let x t e1 e2 -> Let x (subst sigma t) (go e1) (go e2) - -substToplevelExpr :: Subst -> ToplevelExpr -> ToplevelExpr -substToplevelExpr sigma = \case - ResultExpr e -> ResultExpr (substExpr sigma e) - ToplevelLet x t e cont -> ToplevelLet x (subst' sigma t) (substExpr sigma e) (substToplevelExpr sigma cont) - ToplevelLetRec f args ret body cont -> ToplevelLetRec f (map (second (subst' sigma)) args) (subst' sigma ret) (substExpr sigma body) (substToplevelExpr sigma cont) - -substProgram :: Subst -> Program -> Program -substProgram = substToplevelExpr +substExpr sigma = mapTypeExpr (subst' sigma) -- | `run` does type inference. -- @@ -228,3 +205,28 @@ run prog = wrapError' "Jikka.Core.Convert.TypeInfer" $ do postcondition $ do typecheckProgram prog return prog + +runExpr :: (MonadAlpha m, MonadError Error m) => [(VarName, Type)] -> Expr -> m Expr +runExpr env e = wrapError' "Jikka.Core.Convert.TypeInfer" $ do + eqns <- getDual <$> execWriterT (formularizeExpr e) + let (eqns', assertions) = sortEquations eqns + let eqns'' = mergeAssertions assertions + sigma <- solveEquations (eqns' ++ eqns'') + env <- return $ map (second (subst' sigma)) env + e <- return $ substExpr sigma e + postcondition $ do + typecheckExpr env e + return e + +runRule :: (MonadAlpha m, MonadError Error m) => [(VarName, Type)] -> Expr -> Expr -> m ([(VarName, Type)], Expr, Expr) +runRule args e1 e2 = wrapError' "Jikka.Core.Convert.TypeInfer" $ do + eqns <- (getDual <$>) . execWriterT $ do + t <- formularizeExpr e1 + formularizeExpr' e2 t + let (eqns', assertions) = sortEquations eqns + let eqns'' = mergeAssertions assertions + sigma <- solveEquations (eqns' ++ eqns'') + args <- return $ map (second (subst sigma)) args -- don't use substUnit + e1 <- return $ mapTypeExpr (subst sigma) e1 -- don't use substUnit + e2 <- return $ mapTypeExpr (subst sigma) e2 -- don't use substUnit + return (args, e1, e2) diff --git a/src/Jikka/Core/Convert/UnpackTuple.hs b/src/Jikka/Core/Convert/UnpackTuple.hs index 07edd122..f41d0b9d 100644 --- a/src/Jikka/Core/Convert/UnpackTuple.hs +++ b/src/Jikka/Core/Convert/UnpackTuple.hs @@ -43,8 +43,8 @@ rule = _ -> return Nothing App (Tuple' [_]) (Proj' [_] 0 e) -> return' e Proj' ts i e -> case curryApp e of - (Tuple' _, es) -> return' $ es !! i - (Lit (LitBuiltin (If _)), [e1, e2, e3]) -> return' $ If' (ts !! i) e1 (Proj' ts i e2) (Proj' ts i e3) + (Tuple' _, es) -> return' $ es !! fromInteger i + (Lit (LitBuiltin If _), [e1, e2, e3]) -> return' $ If' (ts !! fromInteger i) e1 (Proj' ts i e2) (Proj' ts i e3) _ -> return Nothing Foldl' t2 (TupleTy [t1]) (Lam x1 (TupleTy [_]) (Lam x2 _ body)) e es -> do body' <- substitute x1 (App (Tuple' [t1]) (Var x1)) (Proj' [t1] 0 body) diff --git a/src/Jikka/Core/Evaluate.hs b/src/Jikka/Core/Evaluate.hs index bdab8eec..2b080ffc 100644 --- a/src/Jikka/Core/Evaluate.hs +++ b/src/Jikka/Core/Evaluate.hs @@ -125,30 +125,30 @@ build f xs n = do -- ----------------------------------------------------------------------------- -- evaluator -callBuiltin :: MonadError Error m => Builtin -> [Value] -> m Value -callBuiltin builtin args = wrapError' ("while calling builtin " ++ formatBuiltinIsolated builtin) $ do +callBuiltin :: MonadError Error m => Builtin -> [Type] -> [Value] -> m Value +callBuiltin builtin ts args = wrapError' ("while calling builtin " ++ formatBuiltinIsolated builtin ts) $ do let go0 ret f = callValue (ret f) args let go1' t1 ret f = case args of v1 : args -> do f <- ret <$> (f =<< t1 v1) callValue f args - _ -> return $ ValBuiltin builtin args + _ -> return $ ValBuiltin builtin ts args let go1 t1 ret f = go1' t1 ret (return . f) let go2' t1 t2 ret f = case args of v1 : v2 : args -> do f <- ret <$> join (f <$> t1 v1 <*> t2 v2) callValue f args - _ -> return $ ValBuiltin builtin args + _ -> return $ ValBuiltin builtin ts args let go2 t1 t2 ret f = go2' t1 t2 ret ((return .) . f) let go3' t1 t2 t3 ret f = case args of v1 : v2 : v3 : args -> do f <- ret <$> join (f <$> t1 v1 <*> t2 v2 <*> t3 v3) callValue f args - _ -> return $ ValBuiltin builtin args + _ -> return $ ValBuiltin builtin ts args let go3 t1 t2 t3 ret f = go3' t1 t2 t3 ret (((return .) .) . f) let goN n t ret f = if length args < n - then return $ ValBuiltin builtin args + then return $ ValBuiltin builtin ts args else do f <- ret . f <$> mapM t (take n args) callValue f (drop n args) @@ -167,15 +167,15 @@ callBuiltin builtin args = wrapError' ("while calling builtin " ++ formatBuiltin Abs -> go1 valueToInt ValInt abs Gcd -> go2 valueToInt valueToInt ValInt gcd Lcm -> go2 valueToInt valueToInt ValInt lcm - Min2 _ -> go2 pure pure id minValue - Max2 _ -> go2 pure pure id maxValue - Iterate _ -> go3' valueToInt pure pure id $ \n step base -> iterate' n step base + Min2 -> go2 pure pure id minValue + Max2 -> go2 pure pure id maxValue + Iterate -> go3' valueToInt pure pure id $ \n step base -> iterate' n step base -- logical functions Not -> go1 valueToBool ValBool not And -> go2 valueToBool valueToBool ValBool (&&) Or -> go2 valueToBool valueToBool ValBool (||) Implies -> go2 valueToBool valueToBool ValBool $ \p q -> not p || q - If _ -> go3 valueToBool pure pure id $ \p a b -> if p then a else b + If -> go3 valueToBool pure pure id $ \p a b -> if p then a else b -- bitwise functions BitNot -> go1 valueToInt ValInt complement BitAnd -> go2 valueToInt valueToInt ValInt (.&.) @@ -185,8 +185,8 @@ callBuiltin builtin args = wrapError' ("while calling builtin " ++ formatBuiltin BitRightShift -> go2 valueToInt valueToInt ValInt $ \a b -> a `shift` fromInteger (- b) -- matrix functions MatAp _ _ -> go2' valueToMatrix valueToVector valueFromVector matap' - MatZero n -> go0 valueFromMatrix (matzero n) - MatOne n -> go0 valueFromMatrix (matone n) + MatZero n -> go0 valueFromMatrix (matzero (fromInteger n)) + MatOne n -> go0 valueFromMatrix (matone (fromInteger n)) MatAdd _ _ -> go2' valueToMatrix valueToMatrix valueFromMatrix matadd' MatMul _ _ _ -> go2' valueToMatrix valueToMatrix valueFromMatrix matmul' MatPow _ -> go2' valueToMatrix valueToInt valueFromMatrix matpow' @@ -204,42 +204,42 @@ callBuiltin builtin args = wrapError' ("while calling builtin " ++ formatBuiltin ModMatMul _ _ _ -> go3' pure pure valueToInt valueFromModMatrix $ \f g m -> join (matmul' <$> valueToModMatrix m f <*> valueToModMatrix m g) ModMatPow _ -> go3' pure valueToInt valueToInt valueFromModMatrix $ \f k m -> join (matpow' <$> valueToModMatrix m f <*> pure k) -- list functions - Cons _ -> go2 pure valueToList ValList V.cons - Snoc _ -> go2 valueToList pure ValList V.snoc - Foldl _ _ -> go3' pure pure valueToList id $ \f x a -> V.foldM (\x y -> callValue f [x, y]) x a - Scanl _ _ -> go3' pure pure valueToList ValList $ \f x a -> scanM (\x y -> callValue f [x, y]) x a - Build _ -> go3' pure valueToList valueToInt ValList $ \f xs n -> build (\xs -> callValue f [ValList xs]) xs n - Len _ -> go1 valueToList ValInt (fromIntegral . V.length) - Map _ _ -> go2' pure valueToList ValList map' - Filter _ -> go2' pure valueToList ValList $ \f xs -> V.filterM (\x -> (/= ValBool False) <$> callValue f [x]) xs - At _ -> go2' valueToList valueToInt id atEither - SetAt _ -> go3' valueToList valueToInt pure ValList setAtEither - Elem _ -> go2 pure valueToList ValBool V.elem + Cons -> go2 pure valueToList ValList V.cons + Snoc -> go2 valueToList pure ValList V.snoc + Foldl -> go3' pure pure valueToList id $ \f x a -> V.foldM (\x y -> callValue f [x, y]) x a + Scanl -> go3' pure pure valueToList ValList $ \f x a -> scanM (\x y -> callValue f [x, y]) x a + Build -> go3' pure valueToList valueToInt ValList $ \f xs n -> build (\xs -> callValue f [ValList xs]) xs n + Len -> go1 valueToList ValInt (fromIntegral . V.length) + Map -> go2' pure valueToList ValList map' + Filter -> go2' pure valueToList ValList $ \f xs -> V.filterM (\x -> (/= ValBool False) <$> callValue f [x]) xs + At -> go2' valueToList valueToInt id atEither + SetAt -> go3' valueToList valueToInt pure ValList setAtEither + Elem -> go2 pure valueToList ValBool V.elem Sum -> go1 valueToIntList ValInt sum ModSum -> go2 valueToIntList valueToInt ValInt $ \xs m -> sum xs `mod` m Product -> go1 valueToIntList ValInt product ModProduct -> go2 valueToIntList valueToInt ValInt $ \xs m -> product xs `mod` m - Min1 _ -> go1 valueToList id (V.minimumBy compareValues') - Max1 _ -> go1 valueToList id (V.maximumBy compareValues') - ArgMin _ -> go1 valueToList ValInt $ \xs -> snd (minimumBy (\(x, i) (y, j) -> compareValues' x y <> compare i j) (zip (V.toList xs) [0 ..])) - ArgMax _ -> go1 valueToList ValInt $ \xs -> snd (maximumBy (\(x, i) (y, j) -> compareValues' x y <> compare i j) (zip (V.toList xs) [0 ..])) + Min1 -> go1 valueToList id (V.minimumBy compareValues') + Max1 -> go1 valueToList id (V.maximumBy compareValues') + ArgMin -> go1 valueToList ValInt $ \xs -> snd (minimumBy (\(x, i) (y, j) -> compareValues' x y <> compare i j) (zip (V.toList xs) [0 ..])) + ArgMax -> go1 valueToList ValInt $ \xs -> snd (maximumBy (\(x, i) (y, j) -> compareValues' x y <> compare i j) (zip (V.toList xs) [0 ..])) All -> go1 valueToBoolList ValBool and Any -> go1 valueToBoolList ValBool or - Sorted _ -> go1 valueToList ValList sortVector - Reversed _ -> go1 valueToList ValList V.reverse + Sorted -> go1 valueToList ValList sortVector + Reversed -> go1 valueToList ValList V.reverse Range1 -> go1' valueToInt ValList range1 Range2 -> go2' valueToInt valueToInt ValList range2 Range3 -> go3' valueToInt valueToInt valueToInt ValList range3 -- tuple functions - Tuple ts -> goN (length ts) pure ValTuple id - Proj _ n -> go1 valueToTuple id (!! n) + Tuple -> goN (length ts) pure ValTuple id + Proj n -> go1 valueToTuple id (!! fromInteger n) -- -- comparison - LessThan _ -> go2 pure pure ValBool $ \a b -> compareValues a b == Just LT - LessEqual _ -> go2 pure pure ValBool $ \a b -> compareValues a b /= Just GT - GreaterThan _ -> go2 pure pure ValBool $ \a b -> compareValues a b == Just GT - GreaterEqual _ -> go2 pure pure ValBool $ \a b -> compareValues a b /= Just LT - Equal _ -> go2 pure pure ValBool (==) - NotEqual _ -> go2 pure pure ValBool (/=) + LessThan -> go2 pure pure ValBool $ \a b -> compareValues a b == Just LT + LessEqual -> go2 pure pure ValBool $ \a b -> compareValues a b /= Just GT + GreaterThan -> go2 pure pure ValBool $ \a b -> compareValues a b == Just GT + GreaterEqual -> go2 pure pure ValBool $ \a b -> compareValues a b /= Just LT + Equal -> go2 pure pure ValBool (==) + NotEqual -> go2 pure pure ValBool (/=) -- combinational functions Fact -> go1' valueToInt ValInt fact Choose -> go2' valueToInt valueToInt ValInt choose @@ -265,7 +265,7 @@ callLambda = \name env x t body args -> wrapError' ("while calling lambda " ++ m callValue :: MonadError Error m => Value -> [Value] -> m Value callValue f args = case (f, args) of - (ValBuiltin builtin args', _) -> callBuiltin builtin (args' ++ args) + (ValBuiltin builtin ts args', _) -> callBuiltin builtin ts (args' ++ args) (ValLambda name env x t body, _) -> callLambda name env x t body args (_, []) -> return f _ -> throwInternalError $ "cannot call a non-function: " ++ formatValue f @@ -276,7 +276,7 @@ evaluateExpr env = \case Nothing -> throwInternalError $ "undefined variable: " ++ unVarName x Just val -> return val Lit lit -> case lit of - LitBuiltin ConvexHullTrickInit -> callBuiltin ConvexHullTrickInit [] + LitBuiltin ConvexHullTrickInit ts -> callBuiltin ConvexHullTrickInit ts [] _ -> literalToValue lit If' _ p e1 e2 -> do p <- valueToBool =<< evaluateExpr env p diff --git a/src/Jikka/Core/Format.hs b/src/Jikka/Core/Format.hs index 0147971a..fb18c7b5 100644 --- a/src/Jikka/Core/Format.hs +++ b/src/Jikka/Core/Format.hs @@ -113,6 +113,7 @@ formatType' = \case BoolTy -> ("bool", identPrec) ListTy t -> (resolvePrec funCallPrec (formatType' t) ++ " list", funCallPrec) TupleTy ts -> case ts of + [] -> ("unit", identPrec) [t] -> (resolvePrec (pred multPrec) (formatType' t) ++ ",", multPrec) _ -> (intercalate " * " (map (resolvePrec (pred multPrec) . formatType') ts), multPrec) FunTy t1 t2 -> @@ -134,123 +135,117 @@ formatSemigroup = \case SemigroupIntMax -> "int.max" data Builtin' - = Fun [Type] String - | PrefixOp [Type] String - | InfixOp [Type] String Prec Assoc - | At' Type - | SetAt' Type - | Tuple' [Type] - | Proj' [Type] Integer - | If' Type + = Fun String + | PrefixOp String + | InfixOp String Prec Assoc + | At' + | SetAt' + | Tuple' + | Proj' Integer + | If' deriving (Eq, Ord, Show, Read) -fun :: String -> Builtin' -fun = Fun [] - -infixOp :: String -> Prec -> Assoc -> Builtin' -infixOp = InfixOp [] - analyzeBuiltin :: Builtin -> Builtin' analyzeBuiltin = \case -- arithmetical functions - Negate -> PrefixOp [] "-" - Plus -> infixOp "+" addPrec LeftToRight - Minus -> infixOp "-" addPrec LeftToRight - Mult -> infixOp "*" multPrec LeftToRight - FloorDiv -> infixOp "/" multPrec LeftToRight - FloorMod -> infixOp "%" multPrec LeftToRight - CeilDiv -> infixOp "/^" multPrec LeftToRight - CeilMod -> infixOp "%^" multPrec LeftToRight - Pow -> infixOp "**" powerPrec RightToLeft + Negate -> PrefixOp "-" + Plus -> InfixOp "+" addPrec LeftToRight + Minus -> InfixOp "-" addPrec LeftToRight + Mult -> InfixOp "*" multPrec LeftToRight + FloorDiv -> InfixOp "/" multPrec LeftToRight + FloorMod -> InfixOp "%" multPrec LeftToRight + CeilDiv -> InfixOp "/^" multPrec LeftToRight + CeilMod -> InfixOp "%^" multPrec LeftToRight + Pow -> InfixOp "**" powerPrec RightToLeft -- advanced arithmetical functions - Abs -> fun "abs" - Gcd -> fun "gcd" - Lcm -> fun "lcm" - Min2 t -> InfixOp [t] " InfixOp [t] ">?" appendPrec LeftToRight + Abs -> Fun "abs" + Gcd -> Fun "gcd" + Lcm -> Fun "lcm" + Min2 -> InfixOp " InfixOp ">?" appendPrec LeftToRight -- logical functions - Not -> PrefixOp [] "not" - And -> infixOp "and" andPrec RightToLeft - Or -> infixOp "or" orPrec RightToLeft - Implies -> infixOp "implies" impliesPrec RightToLeft - If t -> If' t + Not -> PrefixOp "not" + And -> InfixOp "and" andPrec RightToLeft + Or -> InfixOp "or" orPrec RightToLeft + Implies -> InfixOp "implies" impliesPrec RightToLeft + If -> If' -- bitwise functions - BitNot -> PrefixOp [] "~" - BitAnd -> infixOp "&" multPrec LeftToRight - BitOr -> infixOp "|" appendPrec LeftToRight - BitXor -> infixOp "^" addPrec LeftToRight - BitLeftShift -> infixOp "<<" powerPrec LeftToRight - BitRightShift -> infixOp ">>" powerPrec LeftToRight + BitNot -> PrefixOp "~" + BitAnd -> InfixOp "&" multPrec LeftToRight + BitOr -> InfixOp "|" appendPrec LeftToRight + BitXor -> InfixOp "^" addPrec LeftToRight + BitLeftShift -> InfixOp "<<" powerPrec LeftToRight + BitRightShift -> InfixOp ">>" powerPrec LeftToRight -- matrix functions - MatAp _ _ -> fun "matap" - MatZero _ -> fun "matzero" - MatOne _ -> fun "matone" - MatAdd _ _ -> fun "matadd" - MatMul _ _ _ -> fun "matmul" - MatPow _ -> fun "matpow" - VecFloorMod _ -> fun "vecfloormod" - MatFloorMod _ _ -> fun "matfloormod" + MatAp _ _ -> Fun "matap" + MatZero _ -> Fun "matzero" + MatOne _ -> Fun "matone" + MatAdd _ _ -> Fun "matadd" + MatMul _ _ _ -> Fun "matmul" + MatPow _ -> Fun "matpow" + VecFloorMod _ -> Fun "vecfloormod" + MatFloorMod _ _ -> Fun "matfloormod" -- modular functions - ModNegate -> fun "modnegate" - ModPlus -> fun "modplus" - ModMinus -> fun "modminus" - ModMult -> fun "modmult" - ModInv -> fun "modinv" - ModPow -> fun "modpow" - ModMatAp _ _ -> fun "modmatap" - ModMatAdd _ _ -> fun "modmatadd" - ModMatMul _ _ _ -> fun "modmatmul" - ModMatPow _ -> fun "modmatpow" + ModNegate -> Fun "modnegate" + ModPlus -> Fun "modplus" + ModMinus -> Fun "modminus" + ModMult -> Fun "modmult" + ModInv -> Fun "modinv" + ModPow -> Fun "modpow" + ModMatAp _ _ -> Fun "modmatap" + ModMatAdd _ _ -> Fun "modmatadd" + ModMatMul _ _ _ -> Fun "modmatmul" + ModMatPow _ -> Fun "modmatpow" -- list functions - Cons t -> Fun [t] "cons" - Snoc t -> Fun [t] "snoc" - Foldl t1 t2 -> Fun [t1, t2] "foldl" - Scanl t1 t2 -> Fun [t1, t2] "scanl" - Build t -> Fun [t] "build" - Iterate t -> Fun [t] "iterate" - Len t -> Fun [t] "len" - Map t1 t2 -> Fun [t1, t2] "map" - Filter t -> Fun [t] "filter" - At t -> At' t - SetAt t -> SetAt' t - Elem t -> Fun [t] "elem" - Sum -> fun "sum" - Product -> fun "product" - ModSum -> fun "modsum" - ModProduct -> fun "modproduct" - Min1 t -> Fun [t] "min" - Max1 t -> Fun [t] "max" - ArgMin t -> Fun [t] "argmin" - ArgMax t -> Fun [t] "argmax" - All -> fun "all" - Any -> fun "any" - Sorted t -> Fun [t] "sort" - Reversed t -> Fun [t] "reverse" - Range1 -> fun "range" - Range2 -> fun "range2" - Range3 -> fun "range3" + Cons -> Fun "cons" + Snoc -> Fun "snoc" + Foldl -> Fun "foldl" + Scanl -> Fun "scanl" + Build -> Fun "build" + Iterate -> Fun "iterate" + Len -> Fun "len" + Map -> Fun "map" + Filter -> Fun "filter" + At -> At' + SetAt -> SetAt' + Elem -> Fun "elem" + Sum -> Fun "sum" + Product -> Fun "product" + ModSum -> Fun "modsum" + ModProduct -> Fun "modproduct" + Min1 -> Fun "min" + Max1 -> Fun "max" + ArgMin -> Fun "argmin" + ArgMax -> Fun "argmax" + All -> Fun "all" + Any -> Fun "any" + Sorted -> Fun "sort" + Reversed -> Fun "reverse" + Range1 -> Fun "range" + Range2 -> Fun "range2" + Range3 -> Fun "range3" -- tuple functions - Tuple ts -> Tuple' ts - Proj ts n -> Proj' ts (toInteger n) + Tuple -> Tuple' + Proj n -> Proj' n -- comparison - LessThan t -> InfixOp [t] "<" comparePrec NoAssoc - LessEqual t -> InfixOp [t] "<=" comparePrec NoAssoc - GreaterThan t -> InfixOp [t] ">" comparePrec NoAssoc - GreaterEqual t -> InfixOp [t] ">=" comparePrec NoAssoc - Equal t -> InfixOp [t] "==" comparePrec NoAssoc - NotEqual t -> InfixOp [t] "!=" comparePrec NoAssoc + LessThan -> InfixOp "<" comparePrec NoAssoc + LessEqual -> InfixOp "<=" comparePrec NoAssoc + GreaterThan -> InfixOp ">" comparePrec NoAssoc + GreaterEqual -> InfixOp ">=" comparePrec NoAssoc + Equal -> InfixOp "==" comparePrec NoAssoc + NotEqual -> InfixOp "!=" comparePrec NoAssoc -- combinational functions - Fact -> fun "fact" - Choose -> fun "choose" - Permute -> fun "permute" - MultiChoose -> fun "multichoose" + Fact -> Fun "fact" + Choose -> Fun "choose" + Permute -> Fun "permute" + MultiChoose -> Fun "multichoose" -- data structures - ConvexHullTrickInit -> fun "cht.init" - ConvexHullTrickGetMin -> fun "cht.getmin" - ConvexHullTrickInsert -> fun "cht.insert" - SegmentTreeInitList _ -> fun "segtree.initlist" - SegmentTreeGetRange _ -> fun "segtree.getrange" - SegmentTreeSetPoint _ -> fun "segtree.setpoint" + ConvexHullTrickInit -> Fun "cht.init" + ConvexHullTrickGetMin -> Fun "cht.getmin" + ConvexHullTrickInsert -> Fun "cht.insert" + SegmentTreeInitList _ -> Fun "segtree.initlist" + SegmentTreeGetRange _ -> Fun "segtree.getrange" + SegmentTreeSetPoint _ -> Fun "segtree.setpoint" formatTemplate :: [Type] -> String formatTemplate = \case @@ -262,40 +257,40 @@ formatFunCall f = \case [] -> f args -> (resolvePrec funCallPrec f ++ "(" ++ intercalate ", " (map (resolvePrec commaPrec . formatExpr') args) ++ ")", funCallPrec) -formatBuiltinIsolated' :: Builtin' -> String -formatBuiltinIsolated' = \case - Fun ts name -> name ++ formatTemplate ts - PrefixOp ts op -> paren $ op ++ formatTemplate ts - InfixOp ts op _ _ -> paren $ op ++ formatTemplate ts - At' t -> paren $ "at" ++ formatTemplate [t] - SetAt' t -> paren $ "set-at" ++ formatTemplate [t] - Tuple' ts -> paren $ "tuple" ++ formatTemplate ts - Proj' ts n -> paren $ "proj-" ++ show n ++ formatTemplate ts - If' t -> paren $ "if-then-else" ++ formatTemplate [t] - -formatBuiltinIsolated :: Builtin -> String -formatBuiltinIsolated = formatBuiltinIsolated' . analyzeBuiltin - -formatBuiltin' :: Builtin' -> [Expr] -> (String, Prec) -formatBuiltin' builtin args = case (builtin, args) of - (Fun _ "map", [Lam x IntTy e, Range1' n]) | x `isUnusedVar` e -> formatFunCall ("replicate", identPrec) [n, e] - (Fun _ name, _) -> formatFunCall (name, identPrec) args - (PrefixOp _ op, e1 : args) -> formatFunCall (op ++ " " ++ resolvePrec unaryPrec (formatExpr' e1), unaryPrec) args - (InfixOp _ op prec assoc, e1 : e2 : args) -> formatFunCall (resolvePrecLeft prec assoc (formatExpr' e1) ++ " " ++ op ++ " " ++ resolvePrecRight prec assoc (formatExpr' e2), prec) args - (At' _, e1 : e2 : args) -> formatFunCall (resolvePrec identPrec (formatExpr' e1) ++ "[" ++ resolvePrec parenPrec (formatExpr' e2) ++ "]", identPrec) args - (SetAt' _, e1 : e2 : e3 : args) -> formatFunCall (resolvePrec identPrec (formatExpr' e1) ++ "[" ++ resolvePrec parenPrec (formatExpr' e2) ++ " := " ++ resolvePrec parenPrec (formatExpr' e3) ++ "]", identPrec) args - (Tuple' [_], e : args) -> formatFunCall (paren (resolvePrec commaPrec (formatExpr' e) ++ ","), identPrec) args - (Tuple' ts, args) | length args >= length ts -> formatFunCall (paren (intercalate ", " (map (resolvePrec commaPrec . formatExpr') (take (length ts) args))), identPrec) (drop (length ts) args) - (Proj' _ n, e : args) -> formatFunCall (resolvePrec identPrec (formatExpr' e) ++ "." ++ show n, identPrec) args - (If' _, e1 : e2 : e3 : args) -> formatFunCall ("if" ++ " " ++ resolvePrec parenPrec (formatExpr' e1) ++ " then " ++ resolvePrec parenPrec (formatExpr' e2) ++ " else " ++ resolvePrec lambdaPrec (formatExpr' e3), lambdaPrec) args - _ -> formatFunCall (formatBuiltinIsolated' builtin, identPrec) args - -formatBuiltin :: Builtin -> [Expr] -> String -formatBuiltin f args = resolvePrec parenPrec (formatBuiltin' (analyzeBuiltin f) args) +formatBuiltinIsolated' :: Builtin' -> [Type] -> String +formatBuiltinIsolated' builtin ts = case builtin of + Fun name -> name ++ formatTemplate ts + PrefixOp op -> paren $ op ++ formatTemplate ts + InfixOp op _ _ -> paren $ op ++ formatTemplate ts + At' -> paren $ "at" ++ formatTemplate ts + SetAt' -> paren $ "set-at" ++ formatTemplate ts + Tuple' -> paren $ "tuple" ++ formatTemplate ts + Proj' n -> paren $ "proj-" ++ show n ++ formatTemplate ts + If' -> paren $ "if-then-else" ++ formatTemplate ts + +formatBuiltinIsolated :: Builtin -> [Type] -> String +formatBuiltinIsolated builtin ts = formatBuiltinIsolated' (analyzeBuiltin builtin) ts + +formatBuiltin' :: Builtin' -> [Type] -> [Expr] -> (String, Prec) +formatBuiltin' builtin ts args = case (builtin, ts, args) of + (Fun "map", _, [Lam x IntTy e, Range1' n]) | x `isUnusedVar` e -> formatFunCall ("replicate", identPrec) [n, e] + (Fun name, _, _) -> formatFunCall (name, identPrec) args + (PrefixOp op, _, e1 : args) -> formatFunCall (op ++ " " ++ resolvePrec unaryPrec (formatExpr' e1), unaryPrec) args + (InfixOp op prec assoc, _, e1 : e2 : args) -> formatFunCall (resolvePrecLeft prec assoc (formatExpr' e1) ++ " " ++ op ++ " " ++ resolvePrecRight prec assoc (formatExpr' e2), prec) args + (At', _, e1 : e2 : args) -> formatFunCall (resolvePrec identPrec (formatExpr' e1) ++ "[" ++ resolvePrec parenPrec (formatExpr' e2) ++ "]", identPrec) args + (SetAt', _, e1 : e2 : e3 : args) -> formatFunCall (resolvePrec identPrec (formatExpr' e1) ++ "[" ++ resolvePrec parenPrec (formatExpr' e2) ++ " := " ++ resolvePrec parenPrec (formatExpr' e3) ++ "]", identPrec) args + (Tuple', [_], e : args) -> formatFunCall (paren (resolvePrec commaPrec (formatExpr' e) ++ ","), identPrec) args + (Tuple', _, args) | length args >= length ts -> formatFunCall (paren (intercalate ", " (map (resolvePrec commaPrec . formatExpr') (take (length ts) args))), identPrec) (drop (length ts) args) + (Proj' n, _, e : args) -> formatFunCall (resolvePrec identPrec (formatExpr' e) ++ "." ++ show n, identPrec) args + (If', _, e1 : e2 : e3 : args) -> formatFunCall ("if" ++ " " ++ resolvePrec parenPrec (formatExpr' e1) ++ " then " ++ resolvePrec parenPrec (formatExpr' e2) ++ " else " ++ resolvePrec lambdaPrec (formatExpr' e3), lambdaPrec) args + _ -> formatFunCall (formatBuiltinIsolated' builtin ts, identPrec) args + +formatBuiltin :: Builtin -> [Type] -> [Expr] -> String +formatBuiltin f ts args = resolvePrec parenPrec (formatBuiltin' (analyzeBuiltin f) ts args) formatLiteral :: Literal -> String formatLiteral = \case - LitBuiltin builtin -> formatBuiltinIsolated builtin + LitBuiltin builtin ts -> formatBuiltinIsolated builtin ts LitInt n -> show n LitBool p -> map toLower $ show p LitNil t -> "nil" ++ formatTemplate [t] @@ -312,7 +307,7 @@ formatExpr' = \case let (f, args) = curryApp e in case f of Var x -> formatFunCall (unVarName x, identPrec) args - Lit (LitBuiltin builtin) -> (formatBuiltin builtin args, identPrec) + Lit (LitBuiltin builtin ts) -> (formatBuiltin builtin ts args, identPrec) _ -> formatFunCall (formatExpr' f) args LamId _ -> ("id", identPrec) LamConst _ e -> formatFunCall ("const", identPrec) [e] diff --git a/src/Jikka/Core/Language/BuiltinPatterns.hs b/src/Jikka/Core/Language/BuiltinPatterns.hs index f70ea381..4ab931dc 100644 --- a/src/Jikka/Core/Language/BuiltinPatterns.hs +++ b/src/Jikka/Core/Language/BuiltinPatterns.hs @@ -10,13 +10,13 @@ -- Portability : portable -- -- `Jikka.Core.Language.BuiltinPatterns` provides pattern synonyms for applications of `Builtin` functions. --- For example, provide a pattern @Sum' e@ which is interpreted as @AppBuiltin Sum [e]@, or the same thing, @App (Lit (LitBuiltin Sum)) [e]@. +-- For example, provide a pattern @Sum' e@ which is interpreted as @AppBuiltin1 Sum [e]@, or the same thing, @App (Lit (LitBuiltin Sum)) [e]@. module Jikka.Core.Language.BuiltinPatterns where import Jikka.Core.Language.Expr -- arithmetical functions -pattern Negate' e = AppBuiltin Negate e +pattern Negate' e = AppBuiltin1 Negate e pattern Plus' e1 e2 = AppBuiltin2 Plus e1 e2 @@ -35,20 +35,20 @@ pattern CeilMod' e1 e2 = AppBuiltin2 CeilMod e1 e2 pattern Pow' e1 e2 = AppBuiltin2 Pow e1 e2 -- advanced arithmetical functions -pattern Abs' e = AppBuiltin Abs e +pattern Abs' e = AppBuiltin1 Abs e pattern Gcd' e1 e2 = AppBuiltin2 Gcd e1 e2 pattern Lcm' e1 e2 = AppBuiltin2 Lcm e1 e2 -pattern Min2' t e1 e2 = AppBuiltin2 (Min2 t) e1 e2 +pattern Min2' t e1 e2 = AppBuiltin12 Min2 t e1 e2 -pattern Max2' t e1 e2 = AppBuiltin2 (Max2 t) e1 e2 +pattern Max2' t e1 e2 = AppBuiltin12 Max2 t e1 e2 -pattern Iterate' t n step base = AppBuiltin3 (Iterate t) n step base +pattern Iterate' t n step base = AppBuiltin13 Iterate t n step base -- logical functions -pattern Not' e = AppBuiltin Not e +pattern Not' e = AppBuiltin1 Not e pattern And' e1 e2 = AppBuiltin2 And e1 e2 @@ -56,10 +56,10 @@ pattern Or' e1 e2 = AppBuiltin2 Or e1 e2 pattern Implies' e1 e2 = AppBuiltin2 Implies e1 e2 -pattern If' t e1 e2 e3 = AppBuiltin3 (If t) e1 e2 e3 +pattern If' t e1 e2 e3 = AppBuiltin13 If t e1 e2 e3 -- bitwise functions -pattern BitNot' e = AppBuiltin BitNot e +pattern BitNot' e = AppBuiltin1 BitNot e pattern BitAnd' e1 e2 = AppBuiltin2 BitAnd e1 e2 @@ -109,79 +109,79 @@ pattern ModMatPow' n e1 e2 e3 = AppBuiltin3 (ModMatPow n) e1 e2 e3 -- list functions pattern Nil' t = Lit (LitNil t) -pattern Cons' t e1 e2 = AppBuiltin2 (Cons t) e1 e2 +pattern Cons' t e1 e2 = AppBuiltin12 Cons t e1 e2 -pattern Snoc' t e1 e2 = AppBuiltin2 (Snoc t) e1 e2 +pattern Snoc' t e1 e2 = AppBuiltin12 Snoc t e1 e2 -pattern Foldl' t1 t2 e1 e2 e3 = AppBuiltin3 (Foldl t1 t2) e1 e2 e3 +pattern Foldl' t1 t2 e1 e2 e3 = AppBuiltin23 Foldl t1 t2 e1 e2 e3 -pattern Scanl' t1 t2 e1 e2 e3 = AppBuiltin3 (Scanl t1 t2) e1 e2 e3 +pattern Scanl' t1 t2 e1 e2 e3 = AppBuiltin23 Scanl t1 t2 e1 e2 e3 -pattern Build' t e1 e2 e3 = AppBuiltin3 (Build t) e1 e2 e3 +pattern Build' t e1 e2 e3 = AppBuiltin13 Build t e1 e2 e3 -pattern Len' t e = AppBuiltin (Len t) e +pattern Len' t e = AppBuiltin11 Len t e -pattern Map' t1 t2 f e = AppBuiltin2 (Map t1 t2) f e +pattern Map' t1 t2 f e = AppBuiltin22 Map t1 t2 f e -pattern Filter' t f e = AppBuiltin2 (Filter t) f e +pattern Filter' t f e = AppBuiltin12 Filter t f e -pattern At' t e1 e2 = AppBuiltin2 (At t) e1 e2 +pattern At' t e1 e2 = AppBuiltin12 At t e1 e2 -pattern SetAt' t e1 e2 e3 = AppBuiltin3 (SetAt t) e1 e2 e3 +pattern SetAt' t e1 e2 e3 = AppBuiltin13 SetAt t e1 e2 e3 -pattern Elem' t e1 e2 = AppBuiltin2 (Elem t) e1 e2 +pattern Elem' t e1 e2 = AppBuiltin12 Elem t e1 e2 -pattern Sum' e = AppBuiltin Sum e +pattern Sum' e = AppBuiltin1 Sum e -pattern Product' e = AppBuiltin Product e +pattern Product' e = AppBuiltin1 Product e pattern ModSum' e1 e2 = AppBuiltin2 ModSum e1 e2 pattern ModProduct' e1 e2 = AppBuiltin2 ModProduct e1 e2 -pattern Min1' t e = AppBuiltin (Min1 t) e +pattern Min1' t e = AppBuiltin11 Min1 t e -pattern Max1' t e = AppBuiltin (Max1 t) e +pattern Max1' t e = AppBuiltin11 Max1 t e -pattern ArgMin' t e = AppBuiltin (ArgMin t) e +pattern ArgMin' t e = AppBuiltin11 ArgMin t e -pattern ArgMax' t e = AppBuiltin (ArgMax t) e +pattern ArgMax' t e = AppBuiltin11 ArgMax t e -pattern All' e = AppBuiltin All e +pattern All' e = AppBuiltin1 All e -pattern Any' e = AppBuiltin Any e +pattern Any' e = AppBuiltin1 Any e -pattern Sorted' t e = AppBuiltin (Sorted t) e +pattern Sorted' t e = AppBuiltin11 Sorted t e -pattern Reversed' t e = AppBuiltin (Reversed t) e +pattern Reversed' t e = AppBuiltin11 Reversed t e -pattern Range1' e = AppBuiltin Range1 e +pattern Range1' e = AppBuiltin1 Range1 e pattern Range2' e1 e2 = AppBuiltin2 Range2 e1 e2 pattern Range3' e1 e2 e3 = AppBuiltin3 Range3 e1 e2 e3 -- tuple functions -pattern Tuple' ts = Lit (LitBuiltin (Tuple ts)) +pattern Tuple' ts = Lit (LitBuiltin Tuple ts) -pattern Proj' ts n e = AppBuiltin (Proj ts n) e +pattern Proj' ts n e = App (Lit (LitBuiltin (Proj n) ts)) e -- arithmetical relations -pattern LessThan' t e1 e2 = AppBuiltin2 (LessThan t) e1 e2 +pattern LessThan' t e1 e2 = AppBuiltin12 LessThan t e1 e2 -pattern LessEqual' t e1 e2 = AppBuiltin2 (LessEqual t) e1 e2 +pattern LessEqual' t e1 e2 = AppBuiltin12 LessEqual t e1 e2 -pattern GreaterThan' t e1 e2 = AppBuiltin2 (GreaterThan t) e1 e2 +pattern GreaterThan' t e1 e2 = AppBuiltin12 GreaterThan t e1 e2 -pattern GreaterEqual' t e1 e2 = AppBuiltin2 (GreaterEqual t) e1 e2 +pattern GreaterEqual' t e1 e2 = AppBuiltin12 GreaterEqual t e1 e2 -- equality relations (polymorphic) -pattern Equal' t e1 e2 = AppBuiltin2 (Equal t) e1 e2 +pattern Equal' t e1 e2 = AppBuiltin12 Equal t e1 e2 -pattern NotEqual' t e1 e2 = AppBuiltin2 (NotEqual t) e1 e2 +pattern NotEqual' t e1 e2 = AppBuiltin12 NotEqual t e1 e2 -- combinational functions -pattern Fact' e = AppBuiltin Fact e +pattern Fact' e = AppBuiltin1 Fact e pattern Choose' e1 e2 = AppBuiltin2 Choose e1 e2 @@ -190,13 +190,13 @@ pattern Permute' e1 e2 = AppBuiltin2 Permute e1 e2 pattern MultiChoose' e1 e2 = AppBuiltin2 MultiChoose e1 e2 -- data structures -pattern ConvexHullTrickInit' = Lit (LitBuiltin ConvexHullTrickInit) +pattern ConvexHullTrickInit' = Builtin ConvexHullTrickInit pattern ConvexHullTrickGetMin' cht a = AppBuiltin2 ConvexHullTrickGetMin cht a pattern ConvexHullTrickInsert' cht a b = AppBuiltin3 ConvexHullTrickInsert cht a b -pattern SegmentTreeInitList' semigrp a = AppBuiltin (SegmentTreeInitList semigrp) a +pattern SegmentTreeInitList' semigrp a = AppBuiltin1 (SegmentTreeInitList semigrp) a pattern SegmentTreeGetRange' semigrp segtree e1 e2 = AppBuiltin3 (SegmentTreeGetRange semigrp) segtree e1 e2 diff --git a/src/Jikka/Core/Language/Expr.hs b/src/Jikka/Core/Language/Expr.hs index e1110c88..d5cc44f1 100644 --- a/src/Jikka/Core/Language/Expr.hs +++ b/src/Jikka/Core/Language/Expr.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PatternSynonyms #-} @@ -16,14 +17,15 @@ -- They are similar to the GHC Core language. module Jikka.Core.Language.Expr where +import Data.Data import Data.String (IsString) -newtype VarName = VarName String deriving (Eq, Ord, Show, Read, IsString) +newtype VarName = VarName String deriving (Eq, Ord, Show, Read, Data, Typeable, IsString) unVarName :: VarName -> String unVarName (VarName name) = name -newtype TypeName = TypeName String deriving (Eq, Ord, Show, Read, IsString) +newtype TypeName = TypeName String deriving (Eq, Ord, Show, Read, Data, Typeable, IsString) unTypeName :: TypeName -> String unTypeName (TypeName name) = name @@ -53,18 +55,18 @@ data Type | TupleTy [Type] | FunTy Type Type | DataStructureTy DataStructure - deriving (Eq, Ord, Show, Read) + deriving (Eq, Ord, Show, Read, Data, Typeable) data DataStructure = ConvexHullTrick | SegmentTree Semigroup' - deriving (Eq, Ord, Show, Read) + deriving (Eq, Ord, Show, Read, Data, Typeable) data Semigroup' = SemigroupIntPlus | SemigroupIntMin | SemigroupIntMax - deriving (Eq, Ord, Show, Read) + deriving (Eq, Ord, Show, Read, Data, Typeable) -- | TODO: What is the difference between `Literal` and `Builtin`? data Builtin @@ -97,11 +99,11 @@ data Builtin | -- | \(: \int \to \int \to \int\) Lcm | -- | \(: \forall \alpha. \alpha \to \alpha \to \alpha\) - Min2 Type + Min2 | -- | \(: \forall \alpha. \alpha \to \alpha \to \alpha\) - Max2 Type + Max2 | -- | iterated application \((\lambda k f x. f^k(x)): \forall \alpha. \int \to (\alpha \to \alpha) \to \alpha \to \alpha\) - Iterate Type + Iterate | -- logical functions -- | \(: \bool \to \bool\) @@ -113,7 +115,7 @@ data Builtin | -- | \(: \bool \to \bool \to \bool\) Implies | -- | \(: \forall \alpha. \bool \to \alpha \to \alpha \to \alpha\) - If Type + If | -- bitwise functions -- | \(: \int \to \int\) @@ -131,21 +133,21 @@ data Builtin | -- matrix functions -- | matrix application \(: \int^{H \times W} \to \int^W \to \int^H\) - MatAp Int Int + MatAp Integer Integer | -- | zero matrix \(: \to \int^{n \times n}\) - MatZero Int + MatZero Integer | -- | unit matrix \(: \to \int^{n \times n}\) - MatOne Int + MatOne Integer | -- | matrix addition \(: \int^{H \times W} \to \int^{H \times W} \to \int^{H \times W}\) - MatAdd Int Int + MatAdd Integer Integer | -- | matrix multiplication \(: \int^{H \times n} \to \int^{n \times W} \to \int^{H \times W}\) - MatMul Int Int Int + MatMul Integer Integer Integer | -- | matrix power \(: \int^{n \times n} \to \int \to \int^{n \times n}\) - MatPow Int + MatPow Integer | -- | vector point-wise floor-mod \(: \int^{n} \to \int \to \int^{n}\) - VecFloorMod Int + VecFloorMod Integer | -- | matrix point-wise floor-mod \(: \int^{H \times W} \to \int \to \int^{H \times W}\) - MatFloorMod Int Int + MatFloorMod Integer Integer | -- modular functions -- | \(: \int \to \int \to \int\) @@ -161,37 +163,37 @@ data Builtin | -- | \(: \int \to \int \to \int \to \int\) ModPow | -- | matrix application \(: \int^{H \times W} \to \int^W \to \int \to \int^H\) - ModMatAp Int Int + ModMatAp Integer Integer | -- | matrix addition \(: \int^{H \times W} \to \int^{H \times W} \to \int \to \int^{H \times W}\) - ModMatAdd Int Int + ModMatAdd Integer Integer | -- | matrix multiplication \(: \int^{H \times n} \to \int^{n \times W} \to \int \to \int^{H \times W}\) - ModMatMul Int Int Int + ModMatMul Integer Integer Integer | -- | matrix power \(: \int^{n \times n} \to \int \to \int^{n \times n}\) - ModMatPow Int + ModMatPow Integer | -- list functions -- | \(: \forall \alpha. \alpha \to \list(\alpha) \to \list(\alpha)\) - Cons Type + Cons | -- | \(: \forall \alpha. \list(alpha) \to \alpha \to \list(\alpha)\) - Snoc Type + Snoc | -- | \(: \forall \alpha \beta. (\beta \to \alpha \to \beta) \to \beta \to \list(\alpha) \to \beta\) - Foldl Type Type + Foldl | -- | \(: \forall \alpha \beta. (\beta \to \alpha \to \beta) \to \beta \to \list(\alpha) \to \list(\beta)\) - Scanl Type Type + Scanl | -- | \(\lambda f a n.\) repeat @a <- snoc a (f a)@ @n@ times \(: \forall \alpha. (\list(\alpha) \to \alpha) \to \list(\alpha) \to \int \to \list(\alpha)\) - Build Type + Build | -- | \(: \forall \alpha. \list(\alpha) \to \int\) - Len Type + Len | -- | \(: \forall \alpha \beta. (\alpha \to \beta) \to \list(\alpha) \to \list(\beta)\) - Map Type Type + Map | -- | \(: \forall \alpha \beta. (\alpha \to \bool) \to \list(\alpha) \to \list(\beta)\) - Filter Type + Filter | -- | \(: \forall \alpha. \list(\alpha) \to \int \to \alpha\) - At Type + At | -- | \(: \forall \alpha. \list(\alpha) \to \int \to \alpha \to \list(\alpha)\) - SetAt Type + SetAt | -- | \(: \forall \alpha. \alpha \to \list(\alpha) \to \bool\) - Elem Type + Elem | -- | \(: \list(\int) \to \int\) Sum | -- | \(: \list(\int) \to \int\) @@ -201,21 +203,21 @@ data Builtin | -- | \(: \list(\int) \to \int \to \int\) ModProduct | -- | \(: \forall \alpha. \list(\alpha) \to \alpha\) - Min1 Type + Min1 | -- | \(: \forall \alpha. \list(\alpha) \to \alpha\) - Max1 Type + Max1 | -- | \(: \forall \alpha. \list(\alpha) \to \int\) - ArgMin Type + ArgMin | -- | \(: \forall \alpha. \list(\alpha) \to \int\) - ArgMax Type + ArgMax | -- | \(: \list(\bool) \to \bool\) All | -- | \(: \list(\bool) \to \bool\) Any | -- | \(: \forall \alpha. \list(\alpha) \to \list(\alpha)\) - Sorted Type + Sorted | -- | \(: \forall \alpha. \list(\alpha) \to \list(\alpha)\) - Reversed Type + Reversed | -- | \(: \int \to \list(\int)\) Range1 | -- | \(: \int \to \int \to \list(\int)\) @@ -225,23 +227,23 @@ data Builtin | -- tuple functions -- | \(: \forall \alpha_0 \alpha_1 \dots \alpha _ {n - 1}. \alpha_0 \to \dots \to \alpha _ {n - 1} \to \alpha_0 \times \dots \times \alpha _ {n - 1}\) - Tuple [Type] + Tuple | -- | \(: \forall \alpha_0 \alpha_1 \dots \alpha _ {n - 1}. \alpha_0 \times \dots \times \alpha _ {n - 1} \to \alpha_i\) - Proj [Type] Int + Proj Integer | -- comparison -- | \(: \forall \alpha. \alpha \to \alpha \to \bool\) - LessThan Type + LessThan | -- | \(: \forall \alpha. \alpha \to \alpha \to \bool\) - LessEqual Type + LessEqual | -- | \(: \forall \alpha. \alpha \to \alpha \to \bool\) - GreaterThan Type + GreaterThan | -- | \(: \forall \alpha. \alpha \to \alpha \to \bool\) - GreaterEqual Type + GreaterEqual | -- | \(: \forall \alpha. \alpha \to \alpha \to \bool\) - Equal Type + Equal | -- | \(: \forall \alpha. \alpha \to \alpha \to \bool\) - NotEqual Type + NotEqual | -- combinational functions -- | \(: \int \to \int\) @@ -266,10 +268,10 @@ data Builtin SegmentTreeGetRange Semigroup' | -- | \(: \forall S. \mathrm{segment-tree}(S) \to \int \to S \to \mathrm{segment-tree}(S)\) SegmentTreeSetPoint Semigroup' - deriving (Eq, Ord, Show, Read) + deriving (Eq, Ord, Show, Read, Data, Typeable) data Literal - = LitBuiltin Builtin + = LitBuiltin Builtin [Type] | -- | \(: \forall \alpha. \int\) LitInt Integer | -- | \(: \forall \alpha. \bool\) @@ -278,7 +280,7 @@ data Literal LitNil Type | -- | \(: \bot : \forall \alpha. \alpha\). The second argument is its error message. LitBottom Type String - deriving (Eq, Ord, Show, Read) + deriving (Eq, Ord, Show, Read, Data, Typeable) -- | `Expr` represents the exprs of our core language. This is similar to the `Expr` of GHC Core. -- See also [commentary/compiler/core-syn-type](https://gitlab.haskell.org/ghc/ghc/-/wikis/commentary/compiler/core-syn-type). @@ -290,6 +292,7 @@ data Literal -- \vert & e_0(e_1, e_2, \dots, e_n) \\ -- \vert & \lambda ~ x_0\colon \tau_0, x_1\colon \tau_1, \dots, x_{n-1}\colon \tau_{n-1}. ~ e \\ -- \vert & \mathbf{let} ~ x\colon \tau = e_1 ~ \mathbf{in} ~ e_2 +-- \vert & \tau -- \end{array} -- \] data Expr @@ -301,7 +304,7 @@ data Expr Lam VarName Type Expr | -- | This "let" is not recursive. Let VarName Type Expr Expr - deriving (Eq, Ord, Show, Read) + deriving (Eq, Ord, Show, Read, Data, Typeable) pattern Fun2Ty t1 t2 ret = FunTy t1 (FunTy t2 ret) @@ -327,11 +330,15 @@ pattern FunLTy t <- where FunLTy t = FunTy (ListTy t) t -vectorTy :: Int -> Type -vectorTy n = TupleTy (replicate n IntTy) +vectorTy :: Integer -> Type +vectorTy n + | 0 <= n && n < 10000 = TupleTy (replicate (fromInteger n) IntTy) + | otherwise = error $ "Jikka.Core.Language.Expr.vectorTy: invalid size: " ++ show n -matrixTy :: Int -> Int -> Type -matrixTy h w = TupleTy (replicate h (TupleTy (replicate w IntTy))) +matrixTy :: Integer -> Integer -> Type +matrixTy h w + | 0 <= h && h < 10000 && 0 <= w && w < 10000 = TupleTy (replicate (fromInteger h) (TupleTy (replicate (fromInteger w) IntTy))) + | otherwise = error $ "Jikka.Core.Language.Expr.matrixTy: invalid size: " ++ show (h, w) pattern UnitTy = TupleTy [] @@ -355,7 +362,11 @@ pattern LitTrue = Lit (LitBool True) pattern LitFalse = Lit (LitBool False) -pattern Builtin builtin = Lit (LitBuiltin builtin) +pattern Builtin builtin = Lit (LitBuiltin builtin []) + +pattern Builtin1 builtin t1 = Lit (LitBuiltin builtin [t1]) + +pattern Builtin2 builtin t1 t2 = Lit (LitBuiltin builtin [t1, t2]) pattern App2 f e1 e2 = App (App f e1) e2 @@ -363,11 +374,23 @@ pattern App3 f e1 e2 e3 = App (App (App f e1) e2) e3 pattern App4 f e1 e2 e3 e4 = App (App (App (App f e1) e2) e3) e4 -pattern AppBuiltin builtin e1 = App (Lit (LitBuiltin builtin)) e1 +pattern AppBuiltin1 builtin e1 = App (Lit (LitBuiltin builtin [])) e1 + +pattern AppBuiltin11 builtin t1 e1 = App (Lit (LitBuiltin builtin [t1])) e1 + +pattern AppBuiltin2 builtin e1 e2 = App2 (Lit (LitBuiltin builtin [])) e1 e2 + +pattern AppBuiltin12 builtin t1 e1 e2 = App2 (Lit (LitBuiltin builtin [t1])) e1 e2 + +pattern AppBuiltin22 builtin t1 t2 e1 e2 = App2 (Lit (LitBuiltin builtin [t1, t2])) e1 e2 + +pattern AppBuiltin3 builtin e1 e2 e3 = App3 (Lit (LitBuiltin builtin [])) e1 e2 e3 + +pattern AppBuiltin13 builtin t1 e1 e2 e3 = App3 (Lit (LitBuiltin builtin [t1])) e1 e2 e3 -pattern AppBuiltin2 builtin e1 e2 = App2 (Lit (LitBuiltin builtin)) e1 e2 +pattern AppBuiltin23 builtin t1 t2 e1 e2 e3 = App3 (Lit (LitBuiltin builtin [t1, t2])) e1 e2 e3 -pattern AppBuiltin3 builtin e1 e2 e3 = App3 (Lit (LitBuiltin builtin)) e1 e2 e3 +pattern AppBuiltin14 builtin t1 t2 e1 e2 e3 = App3 (Lit (LitBuiltin builtin [t1, t2])) e1 e2 e3 pattern Lam2 x1 t1 x2 t2 e = Lam x1 t1 (Lam x2 t2 e) @@ -386,6 +409,6 @@ data ToplevelExpr = ResultExpr Expr | ToplevelLet VarName Type Expr ToplevelExpr | ToplevelLetRec VarName [(VarName, Type)] Type Expr ToplevelExpr - deriving (Eq, Ord, Show, Read) + deriving (Eq, Ord, Show, Read, Data, Typeable) type Program = ToplevelExpr diff --git a/src/Jikka/Core/Language/QuasiRules.hs b/src/Jikka/Core/Language/QuasiRules.hs new file mode 100644 index 00000000..ec0c0bb8 --- /dev/null +++ b/src/Jikka/Core/Language/QuasiRules.hs @@ -0,0 +1,243 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE ViewPatterns #-} + +module Jikka.Core.Language.QuasiRules where + +import Control.Arrow +import Control.Monad.State.Strict +import Data.Data +import Jikka.Common.Error +import Jikka.Common.Format.Error +import qualified Jikka.Core.Convert.TypeInfer as TypeInfer +import Jikka.Core.Language.Expr +import Jikka.Core.Language.RewriteRules +import Jikka.Core.Parse (parseRule) +import Language.Haskell.TH (Body (..), Exp (..), Match (..), Pat (..), Q, Stmt (..)) +import qualified Language.Haskell.TH as TH +import qualified Language.Haskell.TH.Quote as TH +import qualified Language.Haskell.TH.Syntax as TH + +liftError :: ExceptT Error Q a -> Q a +liftError f = do + x <- runExceptT f + case x of + Left err -> fail $ "Jikka.Core.Language.QuasiRules.liftError: " ++ unlines (prettyError' err) + Right y -> return y + +lookupValueName :: (MonadTrans t, Monad (t Q), MonadFail (t Q)) => String -> t Q TH.Name +lookupValueName x = do + y <- lift $ TH.lookupValueName x + case y of + Nothing -> fail $ "Jikka.Core.Language.QuasiRules.lookupValueName: undefined symbol: " ++ x + Just y -> return y + +fromVarName :: VarName -> Q TH.Name +fromVarName (VarName x) = do + let base = takeWhile (/= '$') x + TH.newName (if null base then "x" else base) + +fromTypeName :: TypeName -> Q TH.Name +fromTypeName (TypeName x) = do + let base = takeWhile (/= '$') x + TH.newName (if null base then "t" else base) + +liftDataP :: Data a => a -> Q Pat +liftDataP = TH.dataToPatQ (const Nothing) + +data Env = Env + { vars :: [(VarName, Maybe Exp)], + typeVars :: [(TypeName, TH.Name)] + } + +toPatT :: Type -> StateT Env Q Pat +toPatT = \case + VarTy x -> do + env <- gets typeVars + case lookup x env of + Just y -> do + lift [p|((==) $(pure (VarE y)) -> True)|] + Nothing -> do + y <- lift $ fromTypeName x + modify' (\env -> env {typeVars = (x, y) : typeVars env}) + return $ VarP y + IntTy -> lift $ liftDataP IntTy + BoolTy -> lift $ liftDataP IntTy + ListTy t -> do + t <- toPatT t + lift [p|ListTy $(pure t)|] + TupleTy ts -> do + ts <- mapM toPatT ts + lift [p|TupleTy $(pure (ListP ts))|] + FunTy t1 t2 -> do + t1 <- toPatT t1 + t2 <- toPatT t2 + lift [p|FunTy $(pure t1) $(pure t2)|] + DataStructureTy ds -> do + lift [p|DataStructureTy $(liftDataP ds)|] + +toPatL :: Literal -> StateT Env Q Pat +toPatL = \case + LitBuiltin builtin ts -> do + ts <- mapM toPatT ts + lift [p|LitBuiltin $(liftDataP builtin) $(pure (ListP ts))|] + lit@(LitInt _) -> lift $ liftDataP lit + lit@(LitBool _) -> lift $ liftDataP lit + LitNil t -> do + t <- toPatT t + lift [p|LitNil $(pure t)|] + LitBottom t msg -> do + t <- toPatT t + lift [p|LitBottom $(pure t) $(liftDataP msg)|] + +toPatE :: Expr -> StateT Env Q Pat +toPatE = \case + Var x -> + if x == VarName "_" + then return WildP + else do + env <- gets vars + case lookup x env of + Just (Just y) -> do + lift [p|((== $(pure y)) -> True)|] + Just Nothing -> do + y <- lift $ fromVarName x + modify' (\env -> env {vars = (x, Just (VarE y)) : vars env}) + return $ VarP y + Nothing -> fail $ "Jikka.Core.Language.QuasiRules.toPatE: undefined variable: " ++ unVarName x + Lit lit -> do + lit <- toPatL lit + lift [p|Lit $(pure lit)|] + App e1 e2 -> do + e1 <- toPatE e1 + e2 <- toPatE e2 + lift [p|App $(pure e1) $(pure e2)|] + Lam x t e -> do + t <- toPatT t + y <- lift $ fromVarName x + y' <- lift [e|Var $(pure (VarE y))|] + modify' (\env -> env {vars = (x, Just y') : vars env}) + e <- toPatE e + lift [p|Lam $(pure (VarP y)) $(pure t) $(pure e)|] + Let x t e1 e2 -> do + t <- toPatT t + e1 <- toPatE e1 + y <- lift $ fromVarName x + modify' (\env -> env {vars = (x, Just (VarE y)) : vars env}) + e2 <- toPatE e2 + lift [p|Let $(pure (VarP y)) $(pure t) $(pure e1) $(pure e2)|] + +toExpT :: Type -> StateT Env Q Exp +toExpT = \case + VarTy x -> do + env <- gets typeVars + case lookup x env of + Just y -> return $ VarE y + Nothing -> fail $ "Jikka.Core.Language.QuasiRules.toExpT: undefined type variable: " ++ unTypeName x + IntTy -> do + lift $ TH.liftData IntTy + BoolTy -> do + lift $ TH.liftData BoolTy + ListTy t -> do + t <- toExpT t + lift [e|ListTy $(pure t)|] + TupleTy ts -> do + ts <- mapM toExpT ts + lift [e|TupleTy $(pure (ListE ts))|] + FunTy t1 t2 -> do + t1 <- toExpT t1 + t2 <- toExpT t2 + lift [e|FunTy $(pure t1) $(pure t2)|] + DataStructureTy ds -> do + lift $ TH.liftData (DataStructureTy ds) + +toExpL :: Literal -> StateT Env Q Exp +toExpL = \case + LitBuiltin builtin ts -> do + ts <- mapM toExpT ts + lift [e|LitBuiltin $(TH.liftData builtin) $(pure (ListE ts))|] + lit@(LitInt _) -> lift $ TH.liftData lit + lit@(LitBool _) -> lift $ TH.liftData lit + LitNil t -> do + t <- toExpT t + lift [e|LitNil $(pure t)|] + LitBottom t msg -> do + t <- toExpT t + lift [e|LitBottom $(pure t) $(TH.liftData msg)|] + +toExpE :: Expr -> StateT Env Q ([Stmt], Exp) +toExpE e = do + var <- lookupValueName "Var" + genVarName <- lookupValueName "Jikka.Core.Language.Util.genVarName'" + case e of + Var x -> do + env <- gets vars + case lookup x env of + Just (Just y) -> return ([], y) + _ -> fail $ "Jikka.Core.Language.QuasiRules.toExpE: undefined variable: " ++ unVarName x + Lit lit -> do + lit <- toExpL lit + e <- lift [e|Lit $(pure lit)|] + return ([], e) + App e1 e2 -> do + (stmts, e1) <- toExpE e1 + (stmts', e2) <- toExpE e2 + e <- lift [e|App $(pure e1) $(pure e2)|] + return (stmts ++ stmts', e) + Lam x t e -> do + t <- toExpT t + y <- lift $ fromVarName x + modify' (\env -> env {vars = (x, Just (AppE (ConE var) (VarE y))) : vars env}) + (stmts, e) <- toExpE e + e <- lift [e|Lam $(pure (VarE y)) $(pure t) $(pure e)|] + return (BindS (VarP y) (VarE genVarName) : stmts, e) + Let x t e1 e2 -> do + t <- toExpT t + (stmts, e1) <- toExpE e1 + y <- lift $ fromVarName x + modify' (\env -> env {vars = (x, Just (AppE (ConE var) (VarE y))) : vars env}) + (stmts', e2) <- toExpE e2 + e <- lift [e|Let $(pure (VarE y)) $(pure t) $(pure e1) $(pure e2)|] + return (stmts ++ BindS (VarP y) (VarE genVarName) : stmts', e) + +ruleExp :: String -> Q Exp +ruleExp s = do + (_, args, e1, e2) <- liftError $ parseRule s + (args, e1, e2) <- liftError $ TypeInfer.runRule args e1 e2 + env <- + return $ + Env + { vars = map (second (const Nothing)) args, + typeVars = [] + } + (pat, env) <- runStateT (toPatE e1) env + supressUnusedMatchesWarnings <- (concat <$>) . forM (vars env) $ \case + (_, Just e) -> do + e <- [e|return $(pure e)|] + return [NoBindS e] + _ -> return [] + supressUnusedMatchesWarnings' <- forM (typeVars env) $ \(_, y) -> do + NoBindS <$> [e|return $(pure (VarE y))|] + ((stmts, exp), _) <- runStateT (toExpE e2) env + rewriteRule' <- [e|RewriteRule|] + return' <- [e|return|] + just <- [e|Just|] + nothing <- [e|Nothing|] + return $ + AppE rewriteRule' $ + LamE [WildP] $ + LamCaseE + [ Match pat (NormalB (DoE (supressUnusedMatchesWarnings ++ supressUnusedMatchesWarnings' ++ stmts ++ [NoBindS (AppE return' (AppE just exp))]))) [], + Match WildP (NormalB (AppE return' nothing)) [] + ] + +r :: TH.QuasiQuoter +r = + TH.QuasiQuoter + { TH.quoteExp = ruleExp, + TH.quotePat = undefined, + TH.quoteType = undefined, + TH.quoteDec = undefined + } diff --git a/src/Jikka/Core/Language/RewriteRules.hs b/src/Jikka/Core/Language/RewriteRules.hs index c5663a97..cd7c1af8 100644 --- a/src/Jikka/Core/Language/RewriteRules.hs +++ b/src/Jikka/Core/Language/RewriteRules.hs @@ -14,6 +14,7 @@ module Jikka.Core.Language.RewriteRules pureRewriteRule, simpleRewriteRule, applyRewriteRule, + applyRewriteRule', applyRewriteRuleToplevelExpr, applyRewriteRuleProgram, applyRewriteRuleProgram', @@ -62,6 +63,9 @@ simpleRewriteRule f = RewriteRule (\_ e -> return (f e)) applyRewriteRule :: MonadError Error m => RewriteRule m -> [(VarName, Type)] -> Expr -> StateT Integer m (Maybe Expr) applyRewriteRule = applyRewriteRulePreOrder +applyRewriteRule' :: MonadError Error m => RewriteRule m -> [(VarName, Type)] -> Expr -> m (Maybe Expr) +applyRewriteRule' f env e = evalStateT (applyRewriteRule f env e) 0 + coalesceMaybes :: a -> Maybe a -> b -> Maybe b -> Maybe (a, b) coalesceMaybes _ Nothing _ Nothing = Nothing coalesceMaybes a Nothing _ (Just b) = Just (a, b) diff --git a/src/Jikka/Core/Language/TypeCheck.hs b/src/Jikka/Core/Language/TypeCheck.hs index d30575a0..50ef36af 100644 --- a/src/Jikka/Core/Language/TypeCheck.hs +++ b/src/Jikka/Core/Language/TypeCheck.hs @@ -12,111 +12,119 @@ module Jikka.Core.Language.TypeCheck where import Jikka.Common.Error -import Jikka.Core.Format (formatExpr, formatType) +import Jikka.Core.Format (formatBuiltinIsolated, formatExpr, formatType) import Jikka.Core.Language.Expr import Jikka.Core.Language.Util -builtinToType :: Builtin -> Type -builtinToType = \case - -- arithmetical functions - Negate -> Fun1STy IntTy - Plus -> Fun2STy IntTy - Minus -> Fun2STy IntTy - Mult -> Fun2STy IntTy - FloorDiv -> Fun2STy IntTy - FloorMod -> Fun2STy IntTy - CeilDiv -> Fun2STy IntTy - CeilMod -> Fun2STy IntTy - Pow -> Fun2STy IntTy - -- advanced arithmetical functions - Abs -> Fun1STy IntTy - Gcd -> Fun2STy IntTy - Lcm -> Fun2STy IntTy - Min2 t -> Fun2STy t - Max2 t -> Fun2STy t - Iterate t -> Fun3Ty IntTy (FunTy t t) t t - -- logical functions - Not -> Fun1STy BoolTy - And -> Fun2STy BoolTy - Or -> Fun2STy BoolTy - Implies -> Fun2STy BoolTy - If t -> Fun3Ty BoolTy t t t - -- bitwise functions - BitNot -> Fun1STy IntTy - BitAnd -> Fun2STy IntTy - BitOr -> Fun2STy IntTy - BitXor -> Fun2STy IntTy - BitLeftShift -> Fun2STy IntTy - BitRightShift -> Fun2STy IntTy - -- matrix functions - MatAp h w -> Fun2Ty (matrixTy h w) (vectorTy w) (vectorTy h) - MatZero n -> matrixTy n n - MatOne n -> matrixTy n n - MatAdd h w -> Fun2Ty (matrixTy h w) (matrixTy h w) (matrixTy h w) - MatMul h n w -> Fun2Ty (matrixTy h n) (matrixTy n w) (matrixTy h w) - MatPow n -> Fun2Ty (matrixTy n n) IntTy (matrixTy n n) - VecFloorMod n -> Fun2Ty (vectorTy n) IntTy (vectorTy n) - MatFloorMod h w -> Fun2Ty (matrixTy h w) IntTy (matrixTy h w) - -- modular functions - ModNegate -> Fun2STy IntTy - ModPlus -> Fun3STy IntTy - ModMinus -> Fun3STy IntTy - ModMult -> Fun3STy IntTy - ModInv -> Fun2STy IntTy - ModPow -> Fun3STy IntTy - ModMatAp h w -> Fun3Ty (matrixTy h w) (vectorTy w) IntTy (vectorTy h) - ModMatAdd h w -> Fun3Ty (matrixTy h w) (matrixTy h w) IntTy (matrixTy h w) - ModMatMul h n w -> Fun3Ty (matrixTy h n) (matrixTy n w) IntTy (matrixTy h w) - ModMatPow n -> Fun3Ty (matrixTy n n) IntTy IntTy (matrixTy n n) - -- list functions - Cons t -> Fun2Ty t (ListTy t) (ListTy t) - Snoc t -> Fun2Ty (ListTy t) t (ListTy t) - Foldl t1 t2 -> Fun3Ty (Fun2Ty t2 t1 t2) t2 (ListTy t1) t2 - Scanl t1 t2 -> Fun3Ty (Fun2Ty t2 t1 t2) t2 (ListTy t1) (ListTy t2) - Build t -> Fun3Ty (FunTy (ListTy t) t) (ListTy t) IntTy (ListTy t) - Len t -> FunTy (ListTy t) IntTy - Map t1 t2 -> Fun2Ty (FunTy t1 t2) (ListTy t1) (ListTy t2) - Filter t -> Fun2Ty (FunTy t BoolTy) (ListTy t) (ListTy t) - At t -> Fun2Ty (ListTy t) IntTy t - SetAt t -> Fun3Ty (ListTy t) IntTy t (ListTy t) - Elem t -> Fun2Ty t (ListTy t) BoolTy - Sum -> FunLTy IntTy - Product -> FunLTy IntTy - ModSum -> Fun2Ty (ListTy IntTy) IntTy IntTy - ModProduct -> Fun2Ty (ListTy IntTy) IntTy IntTy - Min1 t -> FunLTy t - Max1 t -> FunLTy t - ArgMin t -> FunTy (ListTy t) IntTy - ArgMax t -> FunTy (ListTy t) IntTy - All -> FunLTy BoolTy - Any -> FunLTy BoolTy - Sorted t -> Fun1STy (ListTy t) - Reversed t -> Fun1STy (ListTy t) - Range1 -> FunTy IntTy (ListTy IntTy) - Range2 -> Fun2Ty IntTy IntTy (ListTy IntTy) - Range3 -> Fun3Ty IntTy IntTy IntTy (ListTy IntTy) - -- tuple functions - Tuple ts -> curryFunTy ts (TupleTy ts) - Proj ts n -> FunTy (TupleTy ts) (ts !! n) - -- comparison - LessThan t -> Fun2Ty t t BoolTy - LessEqual t -> Fun2Ty t t BoolTy - GreaterThan t -> Fun2Ty t t BoolTy - GreaterEqual t -> Fun2Ty t t BoolTy - Equal t -> Fun2Ty t t BoolTy - NotEqual t -> Fun2Ty t t BoolTy - -- combinational functions - Fact -> Fun1STy IntTy - Choose -> Fun2STy IntTy - Permute -> Fun2STy IntTy - MultiChoose -> Fun2STy IntTy - -- data structure - ConvexHullTrickInit -> ConvexHullTrickTy - ConvexHullTrickGetMin -> Fun2Ty ConvexHullTrickTy IntTy IntTy - ConvexHullTrickInsert -> Fun3Ty ConvexHullTrickTy IntTy IntTy ConvexHullTrickTy - SegmentTreeInitList semigrp -> FunTy (ListTy (semigroupToType semigrp)) (SegmentTreeTy semigrp) - SegmentTreeGetRange semigrp -> Fun3Ty (SegmentTreeTy semigrp) IntTy IntTy (semigroupToType semigrp) - SegmentTreeSetPoint semigrp -> Fun3Ty (SegmentTreeTy semigrp) IntTy (semigroupToType semigrp) (SegmentTreeTy semigrp) +builtinToType :: MonadError Error m => Builtin -> [Type] -> m Type +builtinToType builtin ts = + let go0 f = return f + go1 f = case ts of + [t1] -> return $ f t1 + _ -> throwInternalError $ "expected 1 type argument, but got " ++ show (length ts) ++ ": " ++ formatBuiltinIsolated builtin ts + go2 f = case ts of + [t1, t2] -> return $ f t1 t2 + _ -> throwInternalError $ "expected 2 type arguments, but got " ++ show (length ts) ++ ": " ++ formatBuiltinIsolated builtin ts + in case builtin of + -- arithmetical functions + Negate -> go0 $ Fun1STy IntTy + Plus -> go0 $ Fun2STy IntTy + Minus -> go0 $ Fun2STy IntTy + Mult -> go0 $ Fun2STy IntTy + FloorDiv -> go0 $ Fun2STy IntTy + FloorMod -> go0 $ Fun2STy IntTy + CeilDiv -> go0 $ Fun2STy IntTy + CeilMod -> go0 $ Fun2STy IntTy + Pow -> go0 $ Fun2STy IntTy + -- advanced arithmetical functions + Abs -> go0 $ Fun1STy IntTy + Gcd -> go0 $ Fun2STy IntTy + Lcm -> go0 $ Fun2STy IntTy + Min2 -> go1 $ \t -> Fun2STy t + Max2 -> go1 $ \t -> Fun2STy t + Iterate -> go1 $ \t -> Fun3Ty IntTy (FunTy t t) t t + -- logical functions + Not -> go0 $ Fun1STy BoolTy + And -> go0 $ Fun2STy BoolTy + Or -> go0 $ Fun2STy BoolTy + Implies -> go0 $ Fun2STy BoolTy + If -> go1 $ \t -> Fun3Ty BoolTy t t t + -- bitwise functions + BitNot -> go0 $ Fun1STy IntTy + BitAnd -> go0 $ Fun2STy IntTy + BitOr -> go0 $ Fun2STy IntTy + BitXor -> go0 $ Fun2STy IntTy + BitLeftShift -> go0 $ Fun2STy IntTy + BitRightShift -> go0 $ Fun2STy IntTy + -- matrix functions + MatAp h w -> go0 $ Fun2Ty (matrixTy h w) (vectorTy w) (vectorTy h) + MatZero n -> go0 $ matrixTy n n + MatOne n -> go0 $ matrixTy n n + MatAdd h w -> go0 $ Fun2Ty (matrixTy h w) (matrixTy h w) (matrixTy h w) + MatMul h n w -> go0 $ Fun2Ty (matrixTy h n) (matrixTy n w) (matrixTy h w) + MatPow n -> go0 $ Fun2Ty (matrixTy n n) IntTy (matrixTy n n) + VecFloorMod n -> go0 $ Fun2Ty (vectorTy n) IntTy (vectorTy n) + MatFloorMod h w -> go0 $ Fun2Ty (matrixTy h w) IntTy (matrixTy h w) + -- modular functions + ModNegate -> go0 $ Fun2STy IntTy + ModPlus -> go0 $ Fun3STy IntTy + ModMinus -> go0 $ Fun3STy IntTy + ModMult -> go0 $ Fun3STy IntTy + ModInv -> go0 $ Fun2STy IntTy + ModPow -> go0 $ Fun3STy IntTy + ModMatAp h w -> go0 $ Fun3Ty (matrixTy h w) (vectorTy w) IntTy (vectorTy h) + ModMatAdd h w -> go0 $ Fun3Ty (matrixTy h w) (matrixTy h w) IntTy (matrixTy h w) + ModMatMul h n w -> go0 $ Fun3Ty (matrixTy h n) (matrixTy n w) IntTy (matrixTy h w) + ModMatPow n -> go0 $ Fun3Ty (matrixTy n n) IntTy IntTy (matrixTy n n) + -- list functions + Cons -> go1 $ \t -> Fun2Ty t (ListTy t) (ListTy t) + Snoc -> go1 $ \t -> Fun2Ty (ListTy t) t (ListTy t) + Foldl -> go2 $ \t1 t2 -> Fun3Ty (Fun2Ty t2 t1 t2) t2 (ListTy t1) t2 + Scanl -> go2 $ \t1 t2 -> Fun3Ty (Fun2Ty t2 t1 t2) t2 (ListTy t1) (ListTy t2) + Build -> go1 $ \t -> Fun3Ty (FunTy (ListTy t) t) (ListTy t) IntTy (ListTy t) + Len -> go1 $ \t -> FunTy (ListTy t) IntTy + Map -> go2 $ \t1 t2 -> Fun2Ty (FunTy t1 t2) (ListTy t1) (ListTy t2) + Filter -> go1 $ \t -> Fun2Ty (FunTy t BoolTy) (ListTy t) (ListTy t) + At -> go1 $ \t -> Fun2Ty (ListTy t) IntTy t + SetAt -> go1 $ \t -> Fun3Ty (ListTy t) IntTy t (ListTy t) + Elem -> go1 $ \t -> Fun2Ty t (ListTy t) BoolTy + Sum -> go0 $ FunLTy IntTy + Product -> go0 $ FunLTy IntTy + ModSum -> go0 $ Fun2Ty (ListTy IntTy) IntTy IntTy + ModProduct -> go0 $ Fun2Ty (ListTy IntTy) IntTy IntTy + Min1 -> go1 $ \t -> FunLTy t + Max1 -> go1 $ \t -> FunLTy t + ArgMin -> go1 $ \t -> FunTy (ListTy t) IntTy + ArgMax -> go1 $ \t -> FunTy (ListTy t) IntTy + All -> go0 $ FunLTy BoolTy + Any -> go0 $ FunLTy BoolTy + Sorted -> go1 $ \t -> Fun1STy (ListTy t) + Reversed -> go1 $ \t -> Fun1STy (ListTy t) + Range1 -> go0 $ FunTy IntTy (ListTy IntTy) + Range2 -> go0 $ Fun2Ty IntTy IntTy (ListTy IntTy) + Range3 -> go0 $ Fun3Ty IntTy IntTy IntTy (ListTy IntTy) + -- tuple functions + Tuple -> return $ curryFunTy ts (TupleTy ts) + Proj n -> return $ FunTy (TupleTy ts) (ts !! fromInteger n) + -- comparison + LessThan -> go1 $ \t -> Fun2Ty t t BoolTy + LessEqual -> go1 $ \t -> Fun2Ty t t BoolTy + GreaterThan -> go1 $ \t -> Fun2Ty t t BoolTy + GreaterEqual -> go1 $ \t -> Fun2Ty t t BoolTy + Equal -> go1 $ \t -> Fun2Ty t t BoolTy + NotEqual -> go1 $ \t -> Fun2Ty t t BoolTy + -- combinational functions + Fact -> go0 $ Fun1STy IntTy + Choose -> go0 $ Fun2STy IntTy + Permute -> go0 $ Fun2STy IntTy + MultiChoose -> go0 $ Fun2STy IntTy + -- data structure + ConvexHullTrickInit -> go0 ConvexHullTrickTy + ConvexHullTrickGetMin -> go0 $ Fun2Ty ConvexHullTrickTy IntTy IntTy + ConvexHullTrickInsert -> go0 $ Fun3Ty ConvexHullTrickTy IntTy IntTy ConvexHullTrickTy + SegmentTreeInitList semigrp -> go0 $ FunTy (ListTy (semigroupToType semigrp)) (SegmentTreeTy semigrp) + SegmentTreeGetRange semigrp -> go0 $ Fun3Ty (SegmentTreeTy semigrp) IntTy IntTy (semigroupToType semigrp) + SegmentTreeSetPoint semigrp -> go0 $ Fun3Ty (SegmentTreeTy semigrp) IntTy (semigroupToType semigrp) (SegmentTreeTy semigrp) semigroupToType :: Semigroup' -> Type semigroupToType = \case @@ -124,25 +132,25 @@ semigroupToType = \case SemigroupIntMin -> IntTy SemigroupIntMax -> IntTy -literalToType :: Literal -> Type +literalToType :: MonadError Error m => Literal -> m Type literalToType = \case - LitBuiltin builtin -> builtinToType builtin - LitInt _ -> IntTy - LitBool _ -> BoolTy - LitNil t -> ListTy t - LitBottom t _ -> t + LitBuiltin builtin ts -> builtinToType builtin ts + LitInt _ -> return IntTy + LitBool _ -> return BoolTy + LitNil t -> return $ ListTy t + LitBottom t _ -> return t -arityOfBuiltin :: Builtin -> Int -arityOfBuiltin = \case - Min2 _ -> 2 - Max2 _ -> 2 - Foldl _ _ -> 3 - Iterate _ -> 3 - At _ -> 2 - Min1 _ -> 1 - Max1 _ -> 1 - Proj _ _ -> 1 - builtin -> length (fst (uncurryFunTy (builtinToType builtin))) +arityOfBuiltin :: MonadError Error m => Builtin -> [Type] -> m Int +arityOfBuiltin builtin ts = case builtin of + Min2 -> return 2 + Max2 -> return 2 + Foldl -> return 3 + Iterate -> return 3 + At -> return 2 + Min1 -> return 1 + Max1 -> return 1 + Proj _ -> return 1 + builtin -> length . fst . uncurryFunTy <$> builtinToType builtin ts type TypeEnv = [(VarName, Type)] @@ -152,19 +160,22 @@ typecheckExpr env = \case Var x -> case lookup x env of Nothing -> throwInternalError $ "undefined variable: " ++ unVarName x Just t -> return t - Lit lit -> return $ literalToType lit + Lit lit -> literalToType lit App f e -> do tf <- typecheckExpr env f te <- typecheckExpr env e case tf of FunTy te' ret | te' == te -> return ret _ -> throwInternalError $ "wrong type funcall: function = " ++ formatExpr f ++ " and argument = " ++ formatExpr e ++ ", function's type = " ++ formatType tf ++ ", but argument's type = " ++ formatType te - Lam x t e -> FunTy t <$> typecheckExpr ((x, t) : env) e + Lam x t e -> + let env' = if x == VarName "_" then env else (x, t) : env + in FunTy t <$> typecheckExpr env' e Let x t e1 e2 -> do t' <- typecheckExpr env e1 when (t /= t') $ do throwInternalError $ "wrong type binding: " ++ formatExpr (Let x t e1 e2) - typecheckExpr ((x, t) : env) e2 + let env' = if x == VarName "_" then env else (x, t) : env + typecheckExpr env' e2 typecheckToplevelExpr :: MonadError Error m => TypeEnv -> ToplevelExpr -> m Type typecheckToplevelExpr env = \case diff --git a/src/Jikka/Core/Language/Util.hs b/src/Jikka/Core/Language/Util.hs index 841b8ece..6f0382ba 100644 --- a/src/Jikka/Core/Language/Util.hs +++ b/src/Jikka/Core/Language/Util.hs @@ -1,5 +1,6 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TupleSections #-} module Jikka.Core.Language.Util where @@ -29,107 +30,50 @@ genVarName x = do genVarName' :: MonadAlpha m => m VarName genVarName' = genVarName (VarName "_") -mapTypeInBuiltin :: (Type -> Type) -> Builtin -> Builtin -mapTypeInBuiltin f = \case - -- arithmetical functions - Negate -> Negate - Plus -> Plus - Minus -> Minus - Mult -> Mult - FloorDiv -> FloorDiv - FloorMod -> FloorMod - CeilDiv -> CeilDiv - CeilMod -> CeilMod - Pow -> Pow - -- advanced arithmetical functions - Abs -> Abs - Gcd -> Gcd - Lcm -> Lcm - Min2 t -> Min2 (f t) - Max2 t -> Max2 (f t) - Iterate t -> Iterate (f t) - -- logical functionslogical - Not -> Not - And -> And - Or -> Or - Implies -> Implies - If t -> If (f t) - -- bitwise functionsbitwise - BitNot -> BitNot - BitAnd -> BitAnd - BitOr -> BitOr - BitXor -> BitXor - BitLeftShift -> BitLeftShift - BitRightShift -> BitRightShift - -- matrix functions - MatAp h w -> MatAp h w - MatZero n -> MatZero n - MatOne n -> MatOne n - MatAdd h w -> MatAdd h w - MatMul h n w -> MatMul h n w - MatPow n -> MatPow n - VecFloorMod n -> VecFloorMod n - MatFloorMod h w -> MatFloorMod h w - -- modular functionsmodular - ModNegate -> ModNegate - ModPlus -> ModPlus - ModMinus -> ModMinus - ModMult -> ModMult - ModInv -> ModInv - ModPow -> ModPow - ModMatAp h w -> ModMatAp h w - ModMatAdd h w -> ModMatAdd h w - ModMatMul h n w -> ModMatMul h n w - ModMatPow n -> ModMatPow n - -- list functionslist - Cons t -> Cons (f t) - Snoc t -> Snoc (f t) - Foldl t1 t2 -> Foldl (f t1) (f t2) - Scanl t1 t2 -> Scanl (f t1) (f t2) - Build t -> Build (f t) - Len t -> Len (f t) - Map t1 t2 -> Map (f t1) (f t2) - Filter t -> Filter (f t) - At t -> At (f t) - SetAt t -> SetAt (f t) - Elem t -> Elem (f t) - Sum -> Sum - Product -> Product - ModSum -> ModSum - ModProduct -> ModProduct - Min1 t -> Min1 (f t) - Max1 t -> Max1 (f t) - ArgMin t -> ArgMin (f t) - ArgMax t -> ArgMax (f t) - All -> All - Any -> Any - Sorted t -> Sorted (f t) - Reversed t -> Reversed (f t) - Range1 -> Range1 - Range2 -> Range2 - Range3 -> Range3 - -- tuple functions - Tuple ts -> Tuple (map f ts) - Proj ts n -> Proj (map f ts) n - -- comparison - LessThan t -> LessThan (f t) - LessEqual t -> LessEqual (f t) - GreaterThan t -> GreaterThan (f t) - GreaterEqual t -> GreaterEqual (f t) - Equal t -> Equal (f t) - NotEqual t -> NotEqual (f t) - -- combinational functions - Fact -> Fact - Choose -> Choose - Permute -> Permute - MultiChoose -> MultiChoose - -- data structures - ConvexHullTrickInit -> ConvexHullTrickInit - ConvexHullTrickInsert -> ConvexHullTrickInsert - ConvexHullTrickGetMin -> ConvexHullTrickGetMin - SegmentTreeInitList semigrp -> SegmentTreeInitList semigrp - SegmentTreeGetRange semigrp -> SegmentTreeGetRange semigrp - SegmentTreeSetPoint semigrp -> SegmentTreeSetPoint semigrp +mapSubTypesM :: Monad m => (Type -> m Type) -> Type -> m Type +mapSubTypesM f = go + where + go = \case + VarTy x -> f $ VarTy x + IntTy -> f IntTy + BoolTy -> f BoolTy + ListTy t -> f . ListTy =<< f t + TupleTy ts -> f . TupleTy =<< mapM f ts + FunTy t1 t2 -> f =<< (FunTy <$> f t1 <*> f t2) + DataStructureTy ds -> f $ DataStructureTy ds + +mapTypeLiteralM :: Monad m => (Type -> m Type) -> Literal -> m Literal +mapTypeLiteralM f = \case + LitBuiltin builtin ts -> LitBuiltin builtin <$> mapM f ts + LitInt n -> return $ LitInt n + LitBool p -> return $ LitBool p + LitNil t -> LitNil <$> f t + LitBottom t err -> LitBottom <$> f t <*> pure err + +mapTypeExprM :: Monad m => (Type -> m Type) -> Expr -> m Expr +mapTypeExprM f = go + where + go = \case + Var x -> return $ Var x + Lit lit -> Lit <$> mapTypeLiteralM f lit + App f e -> App <$> go f <*> go e + Lam x t body -> Lam x <$> f t <*> go body + Let x t e1 e2 -> Let x <$> f t <*> go e1 <*> go e2 + +mapTypeExpr :: (Type -> Type) -> Expr -> Expr +mapTypeExpr f e = runIdentity (mapTypeExprM (return . f) e) + +mapTypeToplevelExprM :: Monad m => (Type -> m Type) -> ToplevelExpr -> m ToplevelExpr +mapTypeToplevelExprM f = \case + ResultExpr e -> ResultExpr <$> mapTypeExprM f e + ToplevelLet x t e cont -> ToplevelLet x <$> f t <*> mapTypeExprM f e <*> mapTypeToplevelExprM f cont + ToplevelLetRec g args ret body cont -> ToplevelLetRec g <$> mapM (\(x, t) -> (x,) <$> f t) args <*> f ret <*> mapTypeExprM f body <*> mapTypeToplevelExprM f cont + +mapTypeProgramM :: Monad m => (Type -> m Type) -> Program -> m Program +mapTypeProgramM = mapTypeToplevelExprM + +mapTypeProgram :: (Type -> Type) -> Program -> Program +mapTypeProgram f prog = runIdentity (mapTypeProgramM (return . f) prog) -- | `mapExprM'` substitutes exprs using given two functions, which are called in pre-order and post-order. mapExprM' :: Monad m => ([(VarName, Type)] -> Expr -> m Expr) -> ([(VarName, Type)] -> Expr -> m Expr) -> [(VarName, Type)] -> Expr -> m Expr @@ -244,15 +188,15 @@ isConstantTimeBuiltin = \case Abs -> True Gcd -> True Lcm -> True - Min2 _ -> True - Max2 _ -> True - Iterate _ -> False + Min2 -> True + Max2 -> True + Iterate -> False -- logical functions Not -> True And -> True Or -> True Implies -> True - If _ -> True + If -> True -- bitwise functions BitNot -> True BitAnd -> True @@ -281,42 +225,42 @@ isConstantTimeBuiltin = \case ModMatMul _ _ _ -> True ModMatPow _ -> True -- list functions - Cons _ -> False - Snoc _ -> False - Foldl _ _ -> False - Scanl _ _ -> False - Build _ -> False - Len _ -> True - Map _ _ -> False - Filter _ -> False - At _ -> True - SetAt _ -> False - Elem _ -> False + Cons -> False + Snoc -> False + Foldl -> False + Scanl -> False + Build -> False + Len -> True + Map -> False + Filter -> False + At -> True + SetAt -> False + Elem -> False Sum -> False Product -> False ModSum -> False ModProduct -> False - Min1 _ -> False - Max1 _ -> False - ArgMin _ -> False - ArgMax _ -> False + Min1 -> False + Max1 -> False + ArgMin -> False + ArgMax -> False All -> False Any -> False - Sorted _ -> False - Reversed _ -> False + Sorted -> False + Reversed -> False Range1 -> False Range2 -> False Range3 -> False -- tuple functions - Tuple _ -> True - Proj _ _ -> True + Tuple -> True + Proj _ -> True -- comparison - LessThan _ -> True - LessEqual _ -> True - GreaterThan _ -> True - GreaterEqual _ -> True - Equal _ -> True - NotEqual _ -> True + LessThan -> True + LessEqual -> True + GreaterThan -> True + GreaterEqual -> True + Equal -> True + NotEqual -> True -- combinational functions Fact -> True Choose -> True @@ -336,7 +280,7 @@ isConstantTimeExpr = \case Var _ -> True Lit _ -> True e@(App _ _) -> case curryApp e of - (Lit (LitBuiltin f), args) -> isConstantTimeBuiltin f && all isConstantTimeExpr args + (Lit (LitBuiltin f _), args) -> isConstantTimeBuiltin f && all isConstantTimeExpr args _ -> False Lam _ _ _ -> True Let _ _ e1 e2 -> isConstantTimeExpr e1 && isConstantTimeExpr e2 diff --git a/src/Jikka/Core/Language/Value.hs b/src/Jikka/Core/Language/Value.hs index 12cd0509..65bc73d7 100644 --- a/src/Jikka/Core/Language/Value.hs +++ b/src/Jikka/Core/Language/Value.hs @@ -20,7 +20,7 @@ data Value | ValBool Bool | ValList (V.Vector Value) | ValTuple [Value] - | ValBuiltin Builtin [Value] + | ValBuiltin Builtin [Type] [Value] | -- | The `Env` may contain the `ValLambda` cyclicly. ValLambda (Maybe VarName) Env VarName Type Expr deriving (Eq, Read) @@ -29,7 +29,7 @@ type Env = [(VarName, Value)] literalToValue :: MonadError Error m => Literal -> m Value literalToValue = \case - LitBuiltin builtin -> return $ ValBuiltin builtin [] + LitBuiltin builtin ts -> return $ ValBuiltin builtin ts [] LitInt n -> return $ ValInt n LitBool p -> return $ ValBool p LitNil _ -> return $ ValList V.empty @@ -122,8 +122,8 @@ formatValue = \case ValList xs -> "[" ++ intercalate ", " (map formatValue (V.toList xs)) ++ "]" ValTuple [x] -> "(" ++ formatValue x ++ ",)" ValTuple xs -> "(" ++ intercalate ", " (map formatValue xs) ++ ")" - ValBuiltin builtin [] -> formatBuiltinIsolated builtin - ValBuiltin builtin args -> formatBuiltinIsolated builtin ++ "(" ++ intercalate ", " (map formatValue args) ++ ")" + ValBuiltin builtin ts [] -> formatBuiltinIsolated builtin ts + ValBuiltin builtin ts args -> formatBuiltinIsolated builtin ts ++ "(" ++ intercalate ", " (map formatValue args) ++ ")" ValLambda _ _ x t body -> formatExpr (Lam x t body) -- Don't show env because it may be cyclic. readValueIO :: (MonadError Error m, MonadIO m) => IOFormat -> m ([Value], M.Map String Value) diff --git a/src/Jikka/Core/Parse.hs b/src/Jikka/Core/Parse.hs new file mode 100644 index 00000000..6e748a9f --- /dev/null +++ b/src/Jikka/Core/Parse.hs @@ -0,0 +1,42 @@ +{-# LANGUAGE FlexibleContexts #-} + +module Jikka.Core.Parse + ( run, + parseProgram, + parseExpr, + parseType, + parseRule, + ) +where + +import Data.Text (Text, unpack) +import Jikka.Common.Alpha +import Jikka.Common.Error +import Jikka.Core.Language.Expr +import qualified Jikka.Core.Parse.Alex as L +import qualified Jikka.Core.Parse.Happy as P + +parseRule :: (MonadAlpha m, MonadError Error m) => String -> m (String, [(VarName, Type)], Expr, Expr) +parseRule input = do + tokens <- L.run input + P.runRule tokens + +parseType :: (MonadAlpha m, MonadError Error m) => String -> m Type +parseType input = do + tokens <- L.run input + P.runType tokens + +parseExpr :: (MonadAlpha m, MonadError Error m) => String -> m Expr +parseExpr input = do + tokens <- L.run input + P.runExpr tokens + +parseProgram :: (MonadAlpha m, MonadError Error m) => String -> m Program +parseProgram input = do + tokens <- L.run input + P.runProgram tokens + +run :: (MonadAlpha m, MonadError Error m) => FilePath -> Text -> m Program +run _ input = do + tokens <- L.run $ unpack input + P.runProgram tokens diff --git a/src/Jikka/Core/Parse/Alex.x b/src/Jikka/Core/Parse/Alex.x new file mode 100644 index 00000000..ae807188 --- /dev/null +++ b/src/Jikka/Core/Parse/Alex.x @@ -0,0 +1,195 @@ +{ +-- vim: filetype=haskell +{-# LANGUAGE FlexibleContexts #-} + +-- | +-- Module : Jikka.Core.Parse.Alex +-- Description : tokenizes the code of our core language with Alex. +-- Copyright : (c) Kimiyuki Onaka, 2020 +-- License : Apache License 2.0 +-- Maintainer : kimiyuki95@gmail.com +-- Stability : experimental +-- Portability : portable +module Jikka.Core.Parse.Alex + ( run + ) where + +import Data.Char (chr, isHexDigit, isOctDigit) +import Jikka.Common.Error +import Jikka.Common.Location +import Jikka.Core.Parse.Token +} + +%wrapper "monad" + +$space = [\ \t\n\r] + +$alpha = [A-Z a-z] +$alnum = [0-9 A-Z a-z] +$doublequote = ["] +$backslash = [\\] +@nl = "\n" | "\r\n" + +$digit = [0-9] +$nonzerodigit = [1-9] +$bindigit = [0-1] +$octdigit = [0-7] +$hexdigit = [0-9a-fA-F] + +$shortstringchar_single = [^ \\ \r \n '] +$shortstringchar_double = [^ \\ \r \n '] +@stringescapeseq = $backslash . + +tokens :- + + $space + ; + "--" [^ \r\n] * ; + + "true" { tok (Bool True) } + "false" { tok (Bool False) } + + "0" ("_" ? "0") * { tok' parseInt } + $nonzerodigit ("_" ? $digit) * { tok' parseInt } + "0" [bB] ("_" ? $bindigit) + { tok' parseInt } + "0" [oO] ("_" ? $octdigit) + { tok' parseInt } + "0" [xX] ("_" ? $hexdigit) + { tok' parseInt } + + $doublequote ($shortstringchar_double | @stringescapeseq) * $doublequote { tok'' parseString } + + "let" { tok Let } + "rec" { tok Rec } + "in" { tok In } + "fun" { tok Fun } + "if" { tok If } + "then" { tok Then } + "else" { tok Else } + "forall" { tok Forall } + + -- punctuations + "->" { tok Arrow } + "=" { tok Equal } + ":" { tok Colon } + "," { tok Comma } + "_" { tok Underscore } + "." { tok Dot } + "<-" { tok BackArrow } + "@" { tok At } + + -- parens + "[" { tok OpenBracket } + "(" { tok OpenParen } + "]" { tok CloseBracket } + ")" { tok CloseParen } + + -- arithmetic operators + "+" { tok (Operator Plus) } + "-" { tok (Operator Minus) } + "*" { tok (Operator Mult) } + "/" { tok (Operator FloorDiv) } + "%" { tok (Operator FloorMod) } + "/^" { tok (Operator CeilDiv) } + "%^" { tok (Operator CeilMod) } + "**" { tok (Operator Pow) } + + -- boolean operators + "and" { tok (Operator And) } + "or" { tok (Operator Or) } + "not" { tok (Operator Not) } + "implies" { tok (Operator Implies) } + + -- bit operators + "~" { tok (Operator BitNot) } + "&" { tok (Operator BitAnd) } + "|" { tok (Operator BitOr) } + "^" { tok (Operator BitXor) } + "<<" { tok (Operator BitLShift) } + ">>" { tok (Operator BitRShift) } + + -- min max operators + "?" { tok (Operator Max) } + + -- comparators + ">" { tok (Operator GreaterThan) } + "<" { tok (Operator LessThan) } + "<=" { tok (Operator LessEqual) } + ">=" { tok (Operator GreaterEqual) } + "==" { tok (Operator DoubleEqual) } + "/=" { tok (Operator NotEqual) } + + -- identifier + $alpha ($alnum | "_") * { tok' Ident } + $alpha ($alnum | "_") * "$" $digit + { tok' Ident } + + -- catch error + . { skip' } +{ +type Token'' = Either Error Token' + +alexEOF :: Alex (Maybe Token'') +alexEOF = return Nothing + +tok'' :: (Loc -> String -> Token'') -> AlexAction (Maybe Token'') +tok'' f (AlexPn _ line column, _, _, s) n = return . Just $ f loc (take n s) where + loc = Loc + { line = line + , column = column + , width = n + } + +tok' :: (String -> Token) -> AlexAction (Maybe Token'') +tok' f = tok'' (\loc s -> Right (WithLoc loc (f s))) + +tok :: Token -> AlexAction (Maybe Token'') +tok token = tok' (const token) + +parseInt :: String -> Token +parseInt s' = Int $ case filter (/= '_') s' of + '0' : 'b' : s -> foldl (\acc c -> acc * 2 + read [c]) 0 (reverse s) + '0' : 'B' : s -> foldl (\acc c -> acc * 2 + read [c]) 0 (reverse s) + s@('0' : 'o' : _) -> read s + s@('0' : 'O' : _) -> read s + s@('0' : 'x' : _) -> read s + s@('0' : 'X' : _) -> read s + s -> read s + +-- | TODO: Make this compatible to Haskell. The current implementation is for Python. +parseString :: Loc -> String -> Token'' +parseString loc s = WithLoc loc . String <$> go (tail (init s)) where + go "" = Right "" + go ('\\' : s) = case s of + [] -> throwInternalErrorAt loc "invalid escape sequence" + 'a' : s -> ('\a' :) <$> go s + 'b' : s -> ('\b' :) <$> go s + 'f' : s -> ('\f' :) <$> go s + 'n' : s -> ('\n' :) <$> go s + 'r' : s -> ('\r' :) <$> go s + 't' : s -> ('\t' :) <$> go s + 'v' : s -> ('\v' :) <$> go s + o1 : o2 : o3 : s | isOctDigit o1 && isOctDigit o2 && isOctDigit o3 -> (chr (read ("0o" ++ [o1, o2, o3])) :) <$> go s + o1 : o2 : s | isOctDigit o1 && isOctDigit o2 -> (chr (read ("0o" ++ [o1, o2])) :) <$> go s + o1 : s | isOctDigit o1 -> (chr (read ("0o" ++ [o1])) :) <$> go s + 'x' : h1 : h2 : s | isHexDigit h1 && isHexDigit h2 -> (chr (read ("0x" ++ [h1, h2])) :) <$> go s + 'x' : _ -> throwLexicalErrorAt loc "truncated \\xXX escape" + c : s -> (c :) <$> go s + go (c : s) = (c :) <$> go s + +skip' :: AlexAction (Maybe Token'') +skip' (AlexPn _ line column, _, _, s) n = return (Just (Left err)) where + loc = Loc line column n + msg = show (take n s) ++ " is not a acceptable character" + err = lexicalErrorAt loc msg + +unfoldM :: Monad m => m (Maybe a) -> m [a] +unfoldM f = do + x <- f + case x of + Nothing -> return [] + Just x -> (x :) <$> unfoldM f + +run :: MonadError Error m => String -> m [Token'] +run input = wrapError' "Jikka.Core.Parse.Alex" $ do + case runAlex input (unfoldM alexMonadScan) of + Left err -> throwInternalError $ "Alex says: " ++ err + Right tokens -> reportErrors tokens +} diff --git a/src/Jikka/Core/Parse/Happy.y b/src/Jikka/Core/Parse/Happy.y new file mode 100644 index 00000000..f2f7149d --- /dev/null +++ b/src/Jikka/Core/Parse/Happy.y @@ -0,0 +1,559 @@ +{ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TupleSections #-} + +-- vim: filetype=haskell + +-- | +-- Module : Jikka.Core.Parse.Happy +-- Description : parses the code of the standard Core with Happy. +-- Copyright : (c) Kimiyuki Onaka, 2020 +-- License : Apache License 2.0 +-- Maintainer : kimiyuki95@gmail.com +-- Stability : experimental +-- Portability : portable +-- +-- See also Haskell's . +module Jikka.Core.Parse.Happy + ( runProgram + , runExpr + , runType + , runRule + ) where + +import Data.List (intercalate) +import Jikka.Common.Alpha +import Jikka.Common.Error +import Jikka.Common.Location +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Jikka.Core.Language.Util +import qualified Jikka.Core.Parse.Token as L +} + +%name runProgram_ program +%name runExpr_ expression +%name runType_ type +%name runRule_ rule +%tokentype { WithLoc L.Token } +%monad { Either Error } +%error { happyErrorExpList } +%errorhandlertype explist + +%token + -- literals + INTEGER { WithLoc _ (L.Int _) } + BOOLEAN { WithLoc _ (L.Bool _) } + STRING { WithLoc _ (L.String _) } + + -- keywords + "let" { WithLoc _ L.Let } + "rec" { WithLoc _ L.Rec } + "in" { WithLoc _ L.In } + "fun" { WithLoc _ L.Fun } + "if" { WithLoc _ L.If } + "then" { WithLoc _ L.Then } + "else" { WithLoc _ L.Else } + "forall" { WithLoc _ L.Forall } + + -- punctuations + "->" { WithLoc _ L.Arrow } + ":" { WithLoc _ L.Colon } + "," { WithLoc _ L.Comma } + "=" { WithLoc _ L.Equal } + "_" { WithLoc _ L.Underscore } + "." { WithLoc _ L.Dot } + "<-" { WithLoc _ L.BackArrow } + "@" { WithLoc _ L.At } + + -- parens + "[" { WithLoc _ L.OpenBracket } + "(" { WithLoc _ L.OpenParen } + "]" { WithLoc _ L.CloseBracket } + ")" { WithLoc _ L.CloseParen } + + -- types + "int" { WithLoc _ (L.Ident "int") } + "bool" { WithLoc _ (L.Ident "bool") } + "list" { WithLoc _ (L.Ident "list") } + "unit" { WithLoc _ (L.Ident "unit") } + "convex_hull_trick" { WithLoc _ (L.Ident "convex_hull_trick") } + "segment_tree" { WithLoc _ (L.Ident "segment_tree") } + "int_plus" { WithLoc _ (L.Ident "int_plus") } + "int_min" { WithLoc _ (L.Ident "int_min") } + "int_max" { WithLoc _ (L.Ident "int_max") } + + -- builtins + "nil" { WithLoc _ (L.Ident "nil") } + "abs" { WithLoc _ (L.Ident "abs") } + "gcd" { WithLoc _ (L.Ident "gcd") } + "lcm" { WithLoc _ (L.Ident "lcm") } + "iterate" { WithLoc _ (L.Ident "iterate") } + "matap" { WithLoc _ (L.Ident "matap") } + "matzero" { WithLoc _ (L.Ident "matzero") } + "matone" { WithLoc _ (L.Ident "matone") } + "matadd" { WithLoc _ (L.Ident "matadd") } + "matmul" { WithLoc _ (L.Ident "matmul") } + "matpow" { WithLoc _ (L.Ident "matpow") } + "vecfloormod" { WithLoc _ (L.Ident "vecfloormod") } + "matfloormod" { WithLoc _ (L.Ident "matfloormod") } + "modnegate" { WithLoc _ (L.Ident "modnegate") } + "modplus" { WithLoc _ (L.Ident "modplus") } + "modminus" { WithLoc _ (L.Ident "modminus") } + "modmult" { WithLoc _ (L.Ident "modmult") } + "modinv" { WithLoc _ (L.Ident "modinv") } + "modpow" { WithLoc _ (L.Ident "modpow") } + "modmatap" { WithLoc _ (L.Ident "modmatap") } + "modmatadd" { WithLoc _ (L.Ident "modmatadd") } + "modmatmul" { WithLoc _ (L.Ident "modmatmul") } + "modmatpow" { WithLoc _ (L.Ident "modmatpow") } + "cons" { WithLoc _ (L.Ident "cons") } + "snoc" { WithLoc _ (L.Ident "snoc") } + "foldl" { WithLoc _ (L.Ident "foldl") } + "scanl" { WithLoc _ (L.Ident "scanl") } + "build" { WithLoc _ (L.Ident "build") } + "len" { WithLoc _ (L.Ident "len") } + "map" { WithLoc _ (L.Ident "map") } + "filter" { WithLoc _ (L.Ident "filter") } + "elem" { WithLoc _ (L.Ident "elem") } + "sum" { WithLoc _ (L.Ident "sum") } + "product" { WithLoc _ (L.Ident "product") } + "modsum" { WithLoc _ (L.Ident "modsum") } + "modproduct" { WithLoc _ (L.Ident "modproduct") } + "min" { WithLoc _ (L.Ident "min") } + "max" { WithLoc _ (L.Ident "max") } + "argmin" { WithLoc _ (L.Ident "argmin") } + "argmax" { WithLoc _ (L.Ident "argmax") } + "all" { WithLoc _ (L.Ident "all") } + "any" { WithLoc _ (L.Ident "any") } + "sorted" { WithLoc _ (L.Ident "sorted") } + "reversed" { WithLoc _ (L.Ident "reversed") } + "range" { WithLoc _ (L.Ident "range") } + "range2" { WithLoc _ (L.Ident "range2") } + "range3" { WithLoc _ (L.Ident "range3") } + "fact" { WithLoc _ (L.Ident "fact") } + "choose" { WithLoc _ (L.Ident "choose") } + "permute" { WithLoc _ (L.Ident "permute") } + "multichoose" { WithLoc _ (L.Ident "multichoose") } + "cht_init" { WithLoc _ (L.Ident "cht_init") } + "cht_getmin" { WithLoc _ (L.Ident "cht_getmin") } + "cht_insert" { WithLoc _ (L.Ident "cht_insert") } + "segtree_init" { WithLoc _ (L.Ident "segtree_init") } + "segtree_getrange" { WithLoc _ (L.Ident "segtree_getrange") } + "segtree_setpoint" { WithLoc _ (L.Ident "segtree_setpoint") } + + -- identifiers + IDENT { WithLoc _ (L.Ident _) } + + -- arithmetic operators + "+" { WithLoc _ (L.Operator L.Plus) } + "-" { WithLoc _ (L.Operator L.Minus) } + "*" { WithLoc _ (L.Operator L.Mult) } + "/" { WithLoc _ (L.Operator L.FloorDiv) } + "%" { WithLoc _ (L.Operator L.FloorMod) } + "/^" { WithLoc _ (L.Operator L.CeilDiv) } + "%^" { WithLoc _ (L.Operator L.CeilMod) } + "**" { WithLoc _ (L.Operator L.Pow) } + + -- boolean operators + "and" { WithLoc _ (L.Operator L.And) } + "or" { WithLoc _ (L.Operator L.Or) } + "not" { WithLoc _ (L.Operator L.Not) } + "implies" { WithLoc _ (L.Operator L.Implies) } + + -- bit operators + "~" { WithLoc _ (L.Operator L.BitNot) } + "&" { WithLoc _ (L.Operator L.BitAnd) } + "|" { WithLoc _ (L.Operator L.BitOr) } + "^" { WithLoc _ (L.Operator L.BitXor) } + "<<" { WithLoc _ (L.Operator L.BitLShift) } + ">>" { WithLoc _ (L.Operator L.BitRShift) } + + -- min max operators + "?" { WithLoc _ (L.Operator L.Max) } + + -- comparators + ">" { WithLoc _ (L.Operator L.GreaterThan) } + "<" { WithLoc _ (L.Operator L.LessThan) } + "<=" { WithLoc _ (L.Operator L.LessEqual) } + ">=" { WithLoc _ (L.Operator L.GreaterEqual) } + "==" { WithLoc _ (L.Operator L.DoubleEqual) } + "/=" { WithLoc _ (L.Operator L.NotEqual) } +%% + +program :: { Program } + : topdecls { $1 } + +rule :: { (String, [(VarName, Type)], Expr, Expr) } + : STRING expression "=" expression { let L.String name = value $1 in (name, [], $2, $4) } + | STRING "forall" list1(arg) "." expression "=" expression { let L.String name = value $1 in (name, $3, $5, $7) } + +-- utilities +opt(p) -- :: { Maybe a } + : {- empty -} { Nothing } + | p { Just $1 } +rev_list1(p) -- :: { [a] } + : p { [$1] } + | rev_list1(p) p { $2 : $1 } +list1(p) -- :: { [a] } + : rev_list1(p) { reverse $1 } +list(p) -- :: { [a] } + : {- empty -} { [] } + | list1(p) { $1 } +rev_sep1(p, q) -- :: { [a] } + : p { [$1] } + | rev_sep1(p, q) q p { $3 : $1 } +sep1(p, q) -- :: { [a] } + : rev_sep1(p, q) { reverse $1 } +sep1opt(p, q) -- :: { [a] } + : rev_sep1(p, q) opt(q) { reverse $1 } +fst(p, q) + : p q { $1 } +snd(p, q) + : p q { $2 } +both(p, q) + : p q { ($1, $2) } + +topdecls :: { ToplevelExpr } + : expression_nolet { ResultExpr $1 } + | topdecl topdecls { $1 $2 } + +topdecl :: { ToplevelExpr -> ToplevelExpr } + : "let" identifier ":" type "=" expression "in" { ToplevelLet $2 $4 $6 } + | "let" "rec" identifier list(arg) ":" type "=" expression "in" { ToplevelLetRec $3 $4 $6 $8 } + +-- Types +atom_type :: { Type } + : IDENT { let (L.Ident x) = value $1 in VarTy (TypeName x) } + | "int" { IntTy } + | "bool" { BoolTy } + | atom_type "list" { ListTy $1 } + | "unit" { TupleTy [] } + | datastructure { DataStructureTy $1 } + | "(" type ")" { $2 } + +tuple_type :: { Type } + : atom_type { $1 } + | atom_type "*" sep1(atom_type, "*") { TupleTy ($1 : $3) } + +type :: { Type } + : tuple_type { $1 } + | tuple_type "->" type { FunTy $1 $3 } + +-- Data Structures +datastructure :: { DataStructure } + : "convex_hull_trick" { ConvexHullTrick } + | "segment_tree" "<" semigroup ">" { SegmentTree $3 } + +semigroup :: { Semigroup' } + : "int_plus" { SemigroupIntPlus } + | "int_min" { SemigroupIntMin } + | "int_max" { SemigroupIntMax } + +-- Arguments +arg :: { (VarName, Type) } + : identifier { ($1, underscoreTy) } + | "(" identifier ":" type ")" { ($2, $4) } + +-- Atoms +atom :: { Expr } + : identifier { Var $1 } + | literal { Lit $1 } + | parenth_form { $1 } + | builtin { Lit (uncurry LitBuiltin $1) } + +identifier :: { VarName } + : IDENT { let (L.Ident x) = value $1 in VarName x } + | "_" { VarName "_" } + +integer :: { Integer } + : INTEGER { let (L.Int n) = value $1 in n } + +literal :: { Literal } + : integer { LitInt $1 } + | BOOLEAN { let (L.Bool p) = value $1 in LitBool p } + | "nil" { LitNil underscoreTy } + | "nil" "@" atom_type { LitNil $3 } + +parenth_form :: { Expr } + : "(" ")" {% makeTuple [] UnitTy } + | "(" ")" "@" atom_type {% makeTuple [] $4 } + | "(" expression ")" { $2 } + | "(" expression "," ")" {% makeTuple [$2] (TupleTy [underscoreTy]) } + | "(" expression "," ")" "@" atom_type {% makeTuple [$2] $6 } + | "(" expression "," expression_list ")" {% makeTuple ($2 : $4) (TupleTy (replicate (length ($2 : $4)) underscoreTy)) } + | "(" expression "," expression_list ")" "@" atom_type {% makeTuple ($2 : $4) $7 } + +builtin :: { (Builtin, [Type]) } + : "abs" { (Abs, []) } + | "gcd" { (Gcd, []) } + | "lcm" { (Lcm, []) } + | "iterate" { (Iterate, [underscoreTy]) } + | "iterate" "@" atom_type { (Iterate, [$3]) } + | "matap" integer integer { (MatAp $2 $3, []) } + | "matzero" integer { (MatZero $2, []) } + | "matone" integer { (MatOne $2, []) } + | "matadd" integer integer { (MatAdd $2 $3, []) } + | "matmul" integer integer integer { (MatMul $2 $3 $4, []) } + | "matpow" integer { (MatPow $2, []) } + | "vecfloormod" integer { (VecFloorMod $2, []) } + | "matfloormod" integer integer { (MatFloorMod $2 $3, []) } + | "modnegate" { (ModNegate, []) } + | "modplus" { (ModPlus, []) } + | "modminus" { (ModMinus, []) } + | "modmult" { (ModMult, []) } + | "modinv" { (ModInv, []) } + | "modpow" { (ModPow, []) } + | "modmatap" integer integer { (ModMatAp $2 $3, []) } + | "modmatadd" integer integer { (ModMatAdd $2 $3, []) } + | "modmatmul" integer integer integer { (ModMatMul $2 $3 $4, []) } + | "modmatpow" integer { (ModMatPow $2, []) } + | "cons" { (Cons, [underscoreTy]) } + | "cons" "@" atom_type { (Cons, [$3]) } + | "snoc" { (Snoc, [underscoreTy]) } + | "snoc" "@" atom_type { (Snoc, [$3]) } + | "foldl" { (Foldl, [underscoreTy, underscoreTy]) } + | "foldl" "@" atom_type "@" atom_type { (Foldl, [$3, $5]) } + | "scanl" { (Scanl, [underscoreTy, underscoreTy]) } + | "scanl" "@" atom_type "@" atom_type { (Scanl, [$3, $5]) } + | "build" { (Build, [underscoreTy]) } + | "build" "@" atom_type { (Build, [$3]) } + | "len" { (Len, [underscoreTy]) } + | "len" "@" atom_type { (Len, [$3]) } + | "map" { (Map, [underscoreTy, underscoreTy]) } + | "map" "@" atom_type "@" atom_type { (Map, [$3, $5]) } + | "filter" { (Filter, [underscoreTy]) } + | "filter" "@" atom_type { (Filter, [$3]) } + | "elem" { (Elem, [underscoreTy]) } + | "elem" "@" atom_type { (Elem, [$3]) } + | "sum" { (Sum, []) } + | "product" { (Product, []) } + | "modsum" { (ModSum, []) } + | "modproduct" { (ModProduct, []) } + | "min" { (Min1, [underscoreTy]) } + | "min" "@" atom_type { (Min1, [$3]) } + | "max" { (Max1, [underscoreTy]) } + | "max" "@" atom_type { (Max1, [$3]) } + | "argmin" { (ArgMin, [underscoreTy]) } + | "argmin" "@" atom_type { (ArgMin, [$3]) } + | "argmax" { (ArgMax, [underscoreTy]) } + | "argmax" "@" atom_type { (ArgMax, [$3]) } + | "all" { (All, []) } + | "any" { (Any, []) } + | "sorted" { (Sorted, [underscoreTy]) } + | "sorted" "@" atom_type { (Sorted, [$3]) } + | "reversed" { (Reversed, [underscoreTy]) } + | "reversed" "@" atom_type { (Reversed, [$3]) } + | "range" { (Range1, []) } + | "range2" { (Range2, []) } + | "range3" { (Range3, []) } + | "fact" { (Fact, []) } + | "choose" { (Choose, []) } + | "permute" { (Permute, []) } + | "multichoose" { (MultiChoose, []) } + | "cht_init" { (ConvexHullTrickInit, []) } + | "cht_getmin" { (ConvexHullTrickGetMin, []) } + | "cht_insert" { (ConvexHullTrickInsert, []) } + | "segtree_init" semigroup { (SegmentTreeInitList $2, []) } + | "segtree_getrange" semigroup { (SegmentTreeGetRange $2, []) } + | "segtree_setpoint" semigroup { (SegmentTreeSetPoint $2, []) } + +-- Primaries +primary :: { Expr } + : atom { $1 } + | subscription { $1 } + +-- Subscriptions +subscription :: { Expr } + : primary "[" expression "]" { At' underscoreTy $1 $3 } + | primary "[" expression "]" "@" atom_type { At' $6 $1 $3 } + | primary "[" expression "<-" expression "]" { SetAt' underscoreTy $1 $3 $5 } + | primary "[" expression "<-" expression "]" "@" atom_type { SetAt' $8 $1 $3 $5 } + -- | primary "." integer {% makeProj $1 $3 underscoreTy } + | primary "." integer "@" atom_type {% makeProj $1 $3 $5 } + +-- Function applications +funapp :: { Expr } + : primary { $1 } + | funapp primary { App $1 $2 } + +-- The power operator +power :: { Expr } + : funapp { $1 } + | funapp "**" u_expr { Pow' $1 $3 } + +-- Unary arithmetic and bitwise operations +u_expr :: { Expr } + : power { $1 } + | "-" u_expr { Negate' $2 } + | "+" u_expr { $2 } + | "~" u_expr { BitNot' $2 } + +-- Binary arithmetic operations +m_expr :: { Expr } + : u_expr { $1 } + | m_expr "*" u_expr { Mult' $1 $3 } + | m_expr "/" u_expr { FloorDiv' $1 $3 } + | m_expr "%" u_expr { FloorMod' $1 $3 } + | m_expr "/^" u_expr { CeilDiv' $1 $3 } + | m_expr "%^" u_expr { CeilMod' $1 $3 } +a_expr :: { Expr } + : m_expr { $1 } + | a_expr "+" m_expr { Plus' $1 $3 } + | a_expr "-" m_expr { Minus' $1 $3 } + +-- Shifting operations +shift_expr :: { Expr } + : a_expr { $1 } + | shift_expr "<<" a_expr { BitLeftShift' $1 $3 } + | shift_expr ">>" a_expr { BitRightShift' $1 $3 } + +-- 6.9. Binary bitwise operations +and_expr :: { Expr } + : shift_expr { $1 } + | and_expr "&" shift_expr { BitAnd' $1 $3 } +xor_expr :: { Expr } + : and_expr { $1 } + | xor_expr "^" and_expr { BitXor' $1 $3 } +or_expr :: { Expr } + : xor_expr { $1 } + | or_expr "|" xor_expr { BitOr' $1 $3 } + +-- Min and max operations +min_expr :: { Expr } + : or_expr { $1 } + | min_expr "?" or_expr { Max2' underscoreTy $1 $3 } + | min_expr ">?" "@" atom_type or_expr { Max2' $4 $1 $5 } + +-- Comparisons +comparison :: { Expr } + : min_expr { $1 } + | comparison comp_operator min_expr { $2 underscoreTy $1 $3 } + | comparison comp_operator "@" atom_type min_expr { $2 $4 $1 $5 } +comp_operator :: { Type -> Expr -> Expr -> Expr } + : "==" { Equal' } + | "/=" { NotEqual' } + | "<" { LessThan' } + | ">" { GreaterThan' } + | "<=" { LessEqual' } + | ">=" { GreaterEqual' } + +-- Boolean operations +not_test :: { Expr } + : comparison { $1 } + | "not" not_test { Not' $2 } +and_test :: { Expr } + : not_test { $1 } + | and_test "and" not_test { And' $1 $3 } +or_test :: { Expr } + : and_test { $1 } + | or_test "or" and_test { Or' $1 $3 } + +-- Implication operation +implies_test :: { Expr } + : or_test { $1 } + | or_test "implies" implies_test { Implies' $1 $3 } + +-- Conditional expressions +conditional_expression :: { Expr } + : "if" expression "then" expression "else" expression { If' underscoreTy $2 $4 $6 } + | "if" "@" atom_type expression "then" expression "else" expression { If' $3 $4 $6 $8 } + +-- Lambda +lambda_expr :: { Expr } + : "fun" list1(arg) "->" expression { curryLam $2 $4 } + +-- Let +let_expr :: { Expr } + : "let" identifier ":" type "=" expression "in" expression { Let $2 $4 $6 $8 } + +expression_nolet :: { Expr } + : implies_test { $1 } + | conditional_expression { $1 } + | lambda_expr { $1 } +expression :: { Expr } + : expression_nolet { $1 } + | let_expr { $1 } + +-- Expression lists +expression_list :: { [Expr] } + : sep1(expression, ",") { $1 } + +{ +(<@>) :: Functor f => (a -> b) -> f a -> f b +(<@>) = (<$>) + +underscoreTy :: Type +underscoreTy = VarTy (TypeName "_") + +makeTuple :: MonadError Error m => [Expr] -> Type -> m Expr +makeTuple es t = case t of + TupleTy ts | length ts == length es -> return $ uncurryApp (Tuple' ts) es + _ -> throwSyntaxError "Jikka.Core.Parse.Happy.makeTuple: wrong type annotation for tuple" + +makeProj :: MonadError Error m => Expr -> Integer -> Type -> m Expr +makeProj e n t = case t of + TupleTy ts -> return $ Proj' ts n e + _ -> throwSyntaxError "Jikka.Core.Parse.Happy.makeTuple: wrong type annotation for proj" + +replaceUnderscoresT :: MonadAlpha m => Type -> m Type +replaceUnderscoresT = mapSubTypesM go where + go = \case + VarTy (TypeName "_") -> genType + t -> return t + +replaceUnderscoresE :: MonadAlpha m => [(VarName, Type)] -> Expr -> m Expr +replaceUnderscoresE env = mapExprM go env where + go _ = \case + Var (VarName "_") -> Var <$> genVarName' + e -> return e + +happyErrorExpList :: MonadError Error m => ([WithLoc L.Token], [String]) -> m a +happyErrorExpList (tokens, expected) = throwSyntaxErrorAt' loc' msg where + loc' :: Maybe Loc + loc' = case tokens of + [] -> Nothing + (token : _) -> Just (loc token) + msg :: String + msg = tok tokens ++ " is got, but " ++ exp expected ++ " expected" + tok :: [WithLoc L.Token] -> String + tok [] = "EOF" + tok (token : _) = wrap . show $ value token + exp :: [String] -> String + exp [] = "EOF is" + exp [item] = wrap item ++ " is" + exp items = intercalate ", " (map wrap $ init items) ++ ", or " ++ (wrap $ last items) ++ " are" + wrap :: String -> String + wrap ('\'' : s) = '`' : s + wrap s = "`" ++ s ++ "'" + +runRule :: (MonadAlpha m, MonadError Error m) => [WithLoc L.Token] -> m (String, [(VarName, Type)], Expr, Expr) +runRule tokens = wrapError' "Jikka.Core.Parse.Happy.runRule" $ do + (name, args, e1, e2) <- liftEither $ runRule_ tokens + args <- mapM (\(x, t) -> (x,) <$> replaceUnderscoresT t) args + e1 <- mapTypeExprM replaceUnderscoresT e1 + e2 <- mapTypeExprM replaceUnderscoresT e2 + -- Don't replace underscores in exprs + return (name, args, e1, e2) + +runType :: (MonadAlpha m, MonadError Error m) => [WithLoc L.Token] -> m Type +runType tokens = wrapError' "Jikka.Core.Parse.Happy.runType" $ do + t <- liftEither $ runType_ tokens + replaceUnderscoresT t + +runExpr :: (MonadAlpha m, MonadError Error m) => [WithLoc L.Token] -> m Expr +runExpr tokens = wrapError' "Jikka.Core.Parse.Happy.runExpr" $ do + e <- liftEither $ runExpr_ tokens + mapTypeExprM replaceUnderscoresT e + mapExprM replaceUnderscoresE [] e + +runProgram :: (MonadAlpha m, MonadError Error m) => [WithLoc L.Token] -> m Program +runProgram tokens = wrapError' "Jikka.Core.Parse.Happy.runProgram" $ do + prog <- liftEither $ runProgram_ tokens + prog <- mapTypeProgramM replaceUnderscoresT prog + mapExprProgramM replaceUnderscoresE prog +} diff --git a/src/Jikka/Core/Parse/Token.hs b/src/Jikka/Core/Parse/Token.hs new file mode 100644 index 00000000..c061f015 --- /dev/null +++ b/src/Jikka/Core/Parse/Token.hs @@ -0,0 +1,82 @@ +-- | +-- Module : Jikka.Core.Parse.Token +-- Description : defines tokens of our core language. / core 言語の字句要素を定義します。 +-- Copyright : (c) Kimiyuki Onaka, 2020 +-- License : Apache License 2.0 +-- Maintainer : kimiyuki95@gmail.com +-- Stability : experimental +-- Portability : portable +module Jikka.Core.Parse.Token where + +import Jikka.Common.Location + +data Operator + = -- arithmetic operators + Plus + | Minus + | Mult + | FloorDiv + | FloorMod + | CeilDiv + | CeilMod + | Pow + | -- boolean operators + Not + | And + | Or + | Implies + | -- bit operators + BitNot + | BitAnd + | BitOr + | BitXor + | BitLShift + | BitRShift + | -- min max operators + Min + | Max + | -- comparators + DoubleEqual + | NotEqual + | LessThan + | LessEqual + | GreaterThan + | GreaterEqual + deriving (Eq, Ord, Show, Read) + +-- | We don't have to classify tokens in detail, but it's convenient for testing and debugging. +data Token + = -- identifier + Ident String + | -- literals + Int Integer + | Bool Bool + | String String + | -- keywords + Let + | Rec + | In + | If + | Then + | Else + | Fun + | Dot + | Forall + | -- punctuations + Arrow + | Equal + | Colon + | Comma + | Underscore + | BackArrow + | At + | -- parens + OpenBracket + | OpenParen + | CloseBracket + | CloseParen + | -- operators + Operator Operator + deriving (Eq, Ord, Show, Read) + +type Token' = WithLoc Token diff --git a/src/Jikka/Python/Parse/Alex.x b/src/Jikka/Python/Parse/Alex.x index 7b8a1c51..6bdf8616 100644 --- a/src/Jikka/Python/Parse/Alex.x +++ b/src/Jikka/Python/Parse/Alex.x @@ -3,7 +3,7 @@ {-# LANGUAGE FlexibleContexts #-} -- | --- Module : Jikka.Core.Parse.Alex +-- Module : Jikka.Python.Parse.Alex -- Description : tokenizes the code of the standard Python with Alex. -- Copyright : (c) Kimiyuki Onaka, 2020 -- License : Apache License 2.0 @@ -12,7 +12,6 @@ -- Portability : portable -- -- * TODO: tokenize float literals --- * TODO: tokenize string literals module Jikka.Python.Parse.Alex ( run ) where diff --git a/src/Jikka/Python/Parse/Happy.y b/src/Jikka/Python/Parse/Happy.y index 3ea293c5..369ba877 100644 --- a/src/Jikka/Python/Parse/Happy.y +++ b/src/Jikka/Python/Parse/Happy.y @@ -5,7 +5,7 @@ -- vim: filetype=haskell -- | --- Module : Jikka.Core.Parse.Happy +-- Module : Jikka.Python.Parse.Happy -- Description : parses the code of the standard Python with Happy. -- Copyright : (c) Kimiyuki Onaka, 2020 -- License : Apache License 2.0 diff --git a/src/Jikka/Python/Parse/Token.hs b/src/Jikka/Python/Parse/Token.hs index fbb2822c..66dfaded 100644 --- a/src/Jikka/Python/Parse/Token.hs +++ b/src/Jikka/Python/Parse/Token.hs @@ -1,5 +1,5 @@ -- | --- Module : Jikka.Core.Parse.Token +-- Module : Jikka.Python.Parse.Token -- Description : defines tokens of the standard Python. / 標準の Python の字句要素を定義します。 -- Copyright : (c) Kimiyuki Onaka, 2020 -- License : Apache License 2.0 diff --git a/src/Jikka/RestrictedPython/Convert/ToCore.hs b/src/Jikka/RestrictedPython/Convert/ToCore.hs index 00253a1d..8cd305c4 100644 --- a/src/Jikka/RestrictedPython/Convert/ToCore.hs +++ b/src/Jikka/RestrictedPython/Convert/ToCore.hs @@ -68,89 +68,96 @@ runConstant = \case X.ConstBuiltin builtin -> runBuiltin builtin runBuiltin :: MonadError Error m => X.Builtin -> m Y.Expr -runBuiltin builtin = - let f = return . Y.Lit . Y.LitBuiltin - in case builtin of - X.BuiltinAbs -> f Y.Abs - X.BuiltinPow -> f Y.Pow - X.BuiltinModPow -> f Y.ModPow - X.BuiltinDivMod -> return $ Y.Lam2 "a" Y.IntTy "b" Y.IntTy (Y.uncurryApp (Y.Tuple' [Y.IntTy, Y.IntTy]) [Y.FloorDiv' (Y.Var "a") (Y.Var "b"), Y.FloorMod' (Y.Var "a") (Y.Var "b")]) - X.BuiltinCeilDiv -> f Y.CeilDiv - X.BuiltinCeilMod -> f Y.CeilMod - X.BuiltinFloorDiv -> f Y.FloorDiv - X.BuiltinFloorMod -> f Y.FloorMod - X.BuiltinGcd -> f Y.Gcd - X.BuiltinLcm -> f Y.Lcm - X.BuiltinInt t -> case t of - X.IntTy -> return $ Y.Lam "x" Y.IntTy (Y.Var "x") - X.BoolTy -> return $ Y.Lam "p" Y.BoolTy (Y.If' Y.IntTy (Y.Var "p") Y.Lit1 Y.Lit0) - _ -> throwTypeError "the argument of int must be int or bool" - X.BuiltinBool t -> case t of - X.IntTy -> return $ Y.Lam "x" Y.IntTy (Y.If' Y.BoolTy (Y.Equal' Y.IntTy (Y.Var "x") Y.Lit0) Y.LitFalse Y.LitTrue) - X.BoolTy -> return $ Y.Lam "p" Y.BoolTy (Y.Var "p") - X.ListTy t -> do - t <- runType t - return $ Y.Lam "xs" (Y.ListTy t) (Y.If' Y.BoolTy (Y.Equal' (Y.ListTy t) (Y.Var "xs") (Y.Lit (Y.LitNil t))) Y.LitFalse Y.LitTrue) - _ -> throwTypeError "the argument of bool must be bool, int, or list(a)" - X.BuiltinList t -> do - t <- runType t - return $ Y.Lam "xs" (Y.ListTy t) (Y.Var "xs") - X.BuiltinTuple ts -> f . Y.Tuple =<< mapM runType ts - X.BuiltinLen t -> f . Y.Len =<< runType t - X.BuiltinMap ts ret -> case ts of - [] -> Y.Nil' <$> runType ret - _ -> do - ts <- mapM runType ts - ret <- runType ret - let var i = Y.VarName ("xs" ++ show i) - let lam body = Y.Lam "f" (Y.curryFunTy ts ret) (foldr (\(i, t) -> Y.Lam (var i) (Y.ListTy t)) body (zip [0 ..] ts)) - let len = Y.Min1' Y.IntTy (foldr (Y.Cons' Y.IntTy) (Y.Nil' Y.IntTy) (zipWith (\i t -> Y.Len' t (Y.Var (var i))) [0 ..] ts)) - let body = Y.Map' Y.IntTy ret (Y.Lam "i" Y.IntTy (Y.uncurryApp (Y.Var "f") (map (Y.Var . var) [0 .. length ts - 1]))) (Y.Range1' len) - return $ lam body - X.BuiltinSorted t -> f . Y.Sorted =<< runType t - X.BuiltinReversed t -> f . Y.Reversed =<< runType t - X.BuiltinEnumerate t -> do - t <- runType t - let body = Y.Lam "i" Y.IntTy (Y.uncurryApp (Y.Tuple' [Y.IntTy, t]) [Y.Var "i", Y.At' t (Y.Var "xs") (Y.Var "i")]) - return $ Y.Lam "xs" (Y.ListTy t) (Y.Map' (Y.ListTy t) (Y.ListTy (Y.TupleTy [Y.IntTy, t])) body (Y.Range1' (Y.Len' t (Y.Var "xs")))) - X.BuiltinFilter t -> f . Y.Filter =<< runType t - X.BuiltinZip ts -> do - ts <- mapM runType ts - let var i = Y.VarName ("xs" ++ show i) - let lam body = foldr (\(i, t) -> Y.Lam (var i) (Y.ListTy t)) body (zip [0 ..] ts) - let len = Y.Min1' Y.IntTy (foldr (Y.Cons' Y.IntTy) (Y.Nil' Y.IntTy) (zipWith (\i t -> Y.Len' t (Y.Var (var i))) [0 ..] ts)) - let body = Y.Map' Y.IntTy (Y.TupleTy ts) (Y.Lam "i" Y.IntTy (Y.uncurryApp (Y.Tuple' ts) (map (Y.Var . var) [0 .. length ts - 1]))) (Y.Range1' len) - return $ lam body - X.BuiltinAll -> f Y.All - X.BuiltinAny -> f Y.Any - X.BuiltinSum -> f Y.Sum - X.BuiltinProduct -> f Y.Product - X.BuiltinRange1 -> f Y.Range1 - X.BuiltinRange2 -> f Y.Range2 - X.BuiltinRange3 -> f Y.Range1 - X.BuiltinMax1 t -> f . Y.Max1 =<< runType t - X.BuiltinMax t n -> do - when (n < 2) $ do - throwTypeError $ "max expected 2 or more arguments, got " ++ show n - t <- runType t - let args = map (\i -> Y.VarName ('x' : show i)) [0 .. n -1] - return $ Y.curryLam (map (,t) args) (foldr1 (Y.Max2' t) (map Y.Var args)) - X.BuiltinMin1 t -> f . Y.Min1 =<< runType t - X.BuiltinMin t n -> do - when (n < 2) $ do - throwTypeError $ "max min 2 or more arguments, got " ++ show n - t <- runType t - let args = map (\i -> Y.VarName ('x' : show i)) [0 .. n -1] - return $ Y.curryLam (map (,t) args) (foldr1 (Y.Min2' t) (map Y.Var args)) - X.BuiltinArgMax t -> f . Y.ArgMax =<< runType t - X.BuiltinArgMin t -> f . Y.ArgMin =<< runType t - X.BuiltinFact -> f Y.Fact - X.BuiltinChoose -> f Y.Choose - X.BuiltinPermute -> f Y.Permute - X.BuiltinMultiChoose -> f Y.MultiChoose - X.BuiltinModInv -> f Y.ModInv - X.BuiltinInput -> throwSemanticError "cannot use `input' out of main function" - X.BuiltinPrint _ -> throwSemanticError "cannot use `print' out of main function" +runBuiltin builtin = do + let go0 builtin = do + return $ Y.Lit (Y.LitBuiltin builtin []) + let go1 builtin t1 = do + t1 <- runType t1 + return $ Y.Lit (Y.LitBuiltin builtin [t1]) + let goN builtin ts = do + ts <- mapM runType ts + return $ Y.Lit (Y.LitBuiltin builtin ts) + case builtin of + X.BuiltinAbs -> go0 Y.Abs + X.BuiltinPow -> go0 Y.Pow + X.BuiltinModPow -> go0 Y.ModPow + X.BuiltinDivMod -> return $ Y.Lam2 "a" Y.IntTy "b" Y.IntTy (Y.uncurryApp (Y.Tuple' [Y.IntTy, Y.IntTy]) [Y.FloorDiv' (Y.Var "a") (Y.Var "b"), Y.FloorMod' (Y.Var "a") (Y.Var "b")]) + X.BuiltinCeilDiv -> go0 Y.CeilDiv + X.BuiltinCeilMod -> go0 Y.CeilMod + X.BuiltinFloorDiv -> go0 Y.FloorDiv + X.BuiltinFloorMod -> go0 Y.FloorMod + X.BuiltinGcd -> go0 Y.Gcd + X.BuiltinLcm -> go0 Y.Lcm + X.BuiltinInt t -> case t of + X.IntTy -> return $ Y.Lam "x" Y.IntTy (Y.Var "x") + X.BoolTy -> return $ Y.Lam "p" Y.BoolTy (Y.If' Y.IntTy (Y.Var "p") Y.Lit1 Y.Lit0) + _ -> throwTypeError "the argument of int must be int or bool" + X.BuiltinBool t -> case t of + X.IntTy -> return $ Y.Lam "x" Y.IntTy (Y.If' Y.BoolTy (Y.Equal' Y.IntTy (Y.Var "x") Y.Lit0) Y.LitFalse Y.LitTrue) + X.BoolTy -> return $ Y.Lam "p" Y.BoolTy (Y.Var "p") + X.ListTy t -> do + t <- runType t + return $ Y.Lam "xs" (Y.ListTy t) (Y.If' Y.BoolTy (Y.Equal' (Y.ListTy t) (Y.Var "xs") (Y.Lit (Y.LitNil t))) Y.LitFalse Y.LitTrue) + _ -> throwTypeError "the argument of bool must be bool, int, or list(a)" + X.BuiltinList t -> do + t <- runType t + return $ Y.Lam "xs" (Y.ListTy t) (Y.Var "xs") + X.BuiltinTuple ts -> goN Y.Tuple ts + X.BuiltinLen t -> go1 Y.Len t + X.BuiltinMap ts ret -> case ts of + [] -> Y.Nil' <$> runType ret + _ -> do + ts <- mapM runType ts + ret <- runType ret + let var i = Y.VarName ("xs" ++ show i) + let lam body = Y.Lam "go0" (Y.curryFunTy ts ret) (foldr (\(i, t) -> Y.Lam (var i) (Y.ListTy t)) body (zip [0 ..] ts)) + let len = Y.Min1' Y.IntTy (foldr (Y.Cons' Y.IntTy) (Y.Nil' Y.IntTy) (zipWith (\i t -> Y.Len' t (Y.Var (var i))) [0 ..] ts)) + let body = Y.Map' Y.IntTy ret (Y.Lam "i" Y.IntTy (Y.uncurryApp (Y.Var "go0") (map (Y.Var . var) [0 .. length ts - 1]))) (Y.Range1' len) + return $ lam body + X.BuiltinSorted t -> go1 Y.Sorted t + X.BuiltinReversed t -> go1 Y.Reversed t + X.BuiltinEnumerate t -> do + t <- runType t + let body = Y.Lam "i" Y.IntTy (Y.uncurryApp (Y.Tuple' [Y.IntTy, t]) [Y.Var "i", Y.At' t (Y.Var "xs") (Y.Var "i")]) + return $ Y.Lam "xs" (Y.ListTy t) (Y.Map' (Y.ListTy t) (Y.ListTy (Y.TupleTy [Y.IntTy, t])) body (Y.Range1' (Y.Len' t (Y.Var "xs")))) + X.BuiltinFilter t -> go1 Y.Filter t + X.BuiltinZip ts -> do + ts <- mapM runType ts + let var i = Y.VarName ("xs" ++ show i) + let lam body = foldr (\(i, t) -> Y.Lam (var i) (Y.ListTy t)) body (zip [0 ..] ts) + let len = Y.Min1' Y.IntTy (foldr (Y.Cons' Y.IntTy) (Y.Nil' Y.IntTy) (zipWith (\i t -> Y.Len' t (Y.Var (var i))) [0 ..] ts)) + let body = Y.Map' Y.IntTy (Y.TupleTy ts) (Y.Lam "i" Y.IntTy (Y.uncurryApp (Y.Tuple' ts) (map (Y.Var . var) [0 .. length ts - 1]))) (Y.Range1' len) + return $ lam body + X.BuiltinAll -> go0 Y.All + X.BuiltinAny -> go0 Y.Any + X.BuiltinSum -> go0 Y.Sum + X.BuiltinProduct -> go0 Y.Product + X.BuiltinRange1 -> go0 Y.Range1 + X.BuiltinRange2 -> go0 Y.Range2 + X.BuiltinRange3 -> go0 Y.Range1 + X.BuiltinMax1 t -> go1 Y.Max1 t + X.BuiltinMax t n -> do + when (n < 2) $ do + throwTypeError $ "max expected 2 or more arguments, got " ++ show n + t <- runType t + let args = map (\i -> Y.VarName ('x' : show i)) [0 .. n -1] + return $ Y.curryLam (map (,t) args) (foldr1 (Y.Max2' t) (map Y.Var args)) + X.BuiltinMin1 t -> go1 Y.Min1 t + X.BuiltinMin t n -> do + when (n < 2) $ do + throwTypeError $ "max min 2 or more arguments, got " ++ show n + t <- runType t + let args = map (\i -> Y.VarName ('x' : show i)) [0 .. n -1] + return $ Y.curryLam (map (,t) args) (foldr1 (Y.Min2' t) (map Y.Var args)) + X.BuiltinArgMax t -> go1 Y.ArgMax t + X.BuiltinArgMin t -> go1 Y.ArgMin t + X.BuiltinFact -> go0 Y.Fact + X.BuiltinChoose -> go0 Y.Choose + X.BuiltinPermute -> go0 Y.Permute + X.BuiltinMultiChoose -> go0 Y.MultiChoose + X.BuiltinModInv -> go0 Y.ModInv + X.BuiltinInput -> throwSemanticError "cannot use `input' out of main function" + X.BuiltinPrint _ -> throwSemanticError "cannot use `print' out of main function" runAttribute :: MonadError Error m => X.Attribute' -> m Y.Expr runAttribute a = wrapAt' (loc' a) $ do @@ -175,48 +182,49 @@ runBoolOp = \case X.Implies -> Y.Implies runUnaryOp :: X.UnaryOp -> Y.Expr -runUnaryOp = - let f = Y.Lit . Y.LitBuiltin - in \case - X.Invert -> f Y.BitNot - X.Not -> f Y.Not - X.UAdd -> Y.Lam "x" Y.IntTy (Y.Var "x") - X.USub -> f Y.Negate +runUnaryOp = \case + X.Invert -> Y.Builtin Y.BitNot + X.Not -> Y.Builtin Y.Not + X.UAdd -> Y.Lam "x" Y.IntTy (Y.Var "x") + X.USub -> Y.Builtin Y.Negate -runOperator :: MonadError Error m => X.Operator -> m Y.Builtin -runOperator = \case - X.Add -> return Y.Plus - X.Sub -> return Y.Minus - X.Mult -> return Y.Mult - X.MatMult -> throwSemanticError "matmul operator ('@') is not supported" - X.Div -> throwSemanticError "floatdiv operator ('/') is not supported" - X.FloorDiv -> return Y.FloorDiv - X.FloorMod -> return Y.FloorMod - X.CeilDiv -> return Y.CeilDiv - X.CeilMod -> return Y.CeilMod - X.Pow -> return Y.Pow - X.BitLShift -> return Y.BitLeftShift - X.BitRShift -> return Y.BitRightShift - X.BitOr -> return Y.BitOr - X.BitXor -> return Y.BitXor - X.BitAnd -> return Y.BitAnd - X.Max -> return $ Y.Max2 Y.IntTy - X.Min -> return $ Y.Min2 Y.IntTy +runOperator :: MonadError Error m => X.Operator -> m (Y.Builtin, [Y.Type]) +runOperator op = do + let go0 builtin = return (builtin, []) + let go1 builtin t1 = return (builtin, [t1]) + case op of + X.Add -> go0 Y.Plus + X.Sub -> go0 Y.Minus + X.Mult -> go0 Y.Mult + X.MatMult -> throwSemanticError "matmul operator ('@') is not supported" + X.Div -> throwSemanticError "floatdiv operator ('/') is not supported" + X.FloorDiv -> go0 Y.FloorDiv + X.FloorMod -> go0 Y.FloorMod + X.CeilDiv -> go0 Y.CeilDiv + X.CeilMod -> go0 Y.CeilMod + X.Pow -> go0 Y.Pow + X.BitLShift -> go0 Y.BitLeftShift + X.BitRShift -> go0 Y.BitRightShift + X.BitOr -> go0 Y.BitOr + X.BitXor -> go0 Y.BitXor + X.BitAnd -> go0 Y.BitAnd + X.Max -> go1 Y.Max2 Y.IntTy + X.Min -> go1 Y.Min2 Y.IntTy runCmpOp :: MonadError Error m => X.CmpOp' -> m Y.Expr runCmpOp (X.CmpOp' op t) = do t <- runType t - let f = Y.Lit . Y.LitBuiltin + let go1 builtin t1 = Y.Builtin1 builtin t1 return $ case op of - X.Lt -> f $ Y.LessThan t - X.LtE -> f $ Y.LessEqual t - X.Gt -> f $ Y.GreaterThan t - X.GtE -> f $ Y.GreaterEqual t - X.Eq' -> f $ Y.Equal t - X.NotEq -> f $ Y.NotEqual t - X.Is -> f $ Y.Equal t - X.IsNot -> f $ Y.NotEqual t - X.In -> f $ Y.Elem t + X.Lt -> go1 Y.LessThan t + X.LtE -> go1 Y.LessEqual t + X.Gt -> go1 Y.GreaterThan t + X.GtE -> go1 Y.GreaterEqual t + X.Eq' -> go1 Y.Equal t + X.NotEq -> go1 Y.NotEqual t + X.Is -> go1 Y.Equal t + X.IsNot -> go1 Y.NotEqual t + X.In -> go1 Y.Elem t X.NotIn -> Y.curryLam [("x", t), ("xs", Y.ListTy t)] (Y.Not' (Y.Elem' t (Y.Var "x") (Y.Var "xs"))) runTargetExpr :: (MonadAlpha m, MonadError Error m) => X.Target' -> m Y.Expr @@ -250,7 +258,7 @@ runListComp e (X.Comprehension x iter pred) = do runExpr :: (MonadAlpha m, MonadError Error m) => X.Expr' -> m Y.Expr runExpr e0 = wrapAt' (loc' e0) $ case value' e0 of X.BoolOp e1 op e2 -> Y.AppBuiltin2 (runBoolOp op) <$> runExpr e1 <*> runExpr e2 - X.BinOp e1 op e2 -> Y.AppBuiltin2 <$> runOperator op <*> runExpr e1 <*> runExpr e2 + X.BinOp e1 op e2 -> Y.App2 <$> (Y.Lit <$> (uncurry Y.LitBuiltin <$> runOperator op)) <*> runExpr e1 <*> runExpr e2 X.UnaryOp op e -> Y.App (runUnaryOp op) <$> runExpr e X.Lambda args body -> Y.curryLam <$> mapM (\(x, t) -> (runVarName x,) <$> runType t) args <*> runExpr body X.IfExp e1 e2 e3 -> do @@ -267,7 +275,7 @@ runExpr e0 = wrapAt' (loc' e0) $ case value' e0 of e <- runExpr e a <- runAttribute a return $ Y.App a e - X.Subscript e1 e2 -> Y.AppBuiltin2 <$> (Y.At <$> Y.genType) <*> runExpr e1 <*> runExpr e2 + X.Subscript e1 e2 -> Y.App2 <$> (Y.Builtin1 Y.At <$> Y.genType) <*> runExpr e1 <*> runExpr e2 X.Starred e -> throwSemanticErrorAt' (loc' e) "cannot use starred expr" X.Name x -> return $ Y.Var (runVarName x) X.List t es -> do @@ -367,7 +375,7 @@ runStatements (stmt : stmts) cont = case stmt of X.Return e -> runExpr e X.AugAssign x op e -> do y <- runTargetExpr x - op <- Y.Lit . Y.LitBuiltin <$> runOperator op + op <- Y.Lit . uncurry Y.LitBuiltin <$> runOperator op e <- runExpr e runAssign x (Y.App2 op y e) $ do runStatements stmts cont diff --git a/test/Jikka/Core/Convert/EtaSpec.hs b/test/Jikka/Core/Convert/EtaSpec.hs index a2e0b485..e1dd2633 100644 --- a/test/Jikka/Core/Convert/EtaSpec.hs +++ b/test/Jikka/Core/Convert/EtaSpec.hs @@ -8,6 +8,7 @@ where import Jikka.Common.Alpha import Jikka.Common.Error import Jikka.Core.Convert.Eta (run) +import Jikka.Core.Language.BuiltinPatterns import Jikka.Core.Language.Expr import Test.Hspec @@ -22,7 +23,7 @@ spec = describe "run" $ do ( Let "plus" (FunTy IntTy (FunTy IntTy IntTy)) - (Lit (LitBuiltin Plus)) + (Builtin Plus) (Var "plus") ) let expected = @@ -30,7 +31,7 @@ spec = describe "run" $ do ( Let "plus" (FunTy IntTy (FunTy IntTy IntTy)) - (Lam "$0" IntTy (Lam "$1" IntTy (App2 (Lit (LitBuiltin Plus)) (Var "$0") (Var "$1")))) + (Lam "$0" IntTy (Lam "$1" IntTy (Plus' (Var "$0") (Var "$1")))) (Var "plus") ) run' prog `shouldBe` Right expected diff --git a/test/Jikka/Core/ParseSpec.hs b/test/Jikka/Core/ParseSpec.hs new file mode 100644 index 00000000..4ab2e14a --- /dev/null +++ b/test/Jikka/Core/ParseSpec.hs @@ -0,0 +1,60 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Jikka.Core.ParseSpec + ( spec, + ) +where + +import qualified Data.Text as T +import Jikka.Common.Alpha +import Jikka.Common.Error +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Jikka.Core.Parse +import Test.Hspec + +run' :: String -> Either Error Program +run' prog = evalAlphaT (run "" (T.pack prog)) 100 + +spec :: Spec +spec = describe "run" $ do + it "works" $ do + let prog = + unlines + [ "let rec solve$0 (n$1: int): int =", + " let xs$2: int list =", + " map (fun (i$3: int) ->", + " i$3 * i$3", + " ) (range n$1)", + " in sum xs$2", + "in", + "solve$0" + ] + let expected = + ToplevelLetRec + "solve$0" + [("n$1", IntTy)] + IntTy + ( Let + "xs$2" + (ListTy IntTy) + ( Map' + (VarTy "$100") + (VarTy "$101") + ( Lam + "i$3" + IntTy + (Mult' (Var "i$3") (Var "i$3")) + ) + (Range1' (Var "n$1")) + ) + (Sum' (Var "xs$2")) + ) + (ResultExpr (Var "solve$0")) + run' prog `shouldBe` Right expected + it "inserts new type variables" $ do + let prog = "a[0 <- b][0]" + let expected = + ResultExpr + (At' (VarTy "$100") (SetAt' (VarTy "$101") (Var "a") (LitInt' 0) (Var "b")) (LitInt' 0)) + run' prog `shouldBe` Right expected