Skip to content

Commit

Permalink
feat(core): Update SpecializeFoldl
Browse files Browse the repository at this point in the history
  • Loading branch information
kmyk committed Aug 23, 2021
1 parent a67b95c commit ba2ca79
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 8 deletions.
1 change: 1 addition & 0 deletions Jikka.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ library
Jikka.Core.Language.FreeVars
Jikka.Core.Language.LambdaPatterns
Jikka.Core.Language.Lint
Jikka.Core.Language.ModuloExpr
Jikka.Core.Language.QuasiRules
Jikka.Core.Language.RewriteRules
Jikka.Core.Language.Runtime
Expand Down
31 changes: 23 additions & 8 deletions src/Jikka/Core/Convert/SpecializeFoldl.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module : Jikka.Core.Convert.SpecializeFoldl
Expand All @@ -22,19 +23,37 @@ where

import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Language.ArithmeticExpr
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.ModuloExpr
import Jikka.Core.Language.RewriteRules

convertToSum :: Expr -> Maybe Expr
convertToSum = \case
Foldl' t1 IntTy (Lam2 x2 _ x1 _ body) init xs -> do
(a, b) <- makeAffineFunctionFromArithmeticExpr x2 (parseArithmeticExpr body)
guard $ isOneArithmeticExpr a
return $ Plus' init (Sum' (Map' t1 IntTy (Lam x1 t1 (formatArithmeticExpr b)) xs))
_ -> Nothing

convertToModSum :: Expr -> Maybe Expr
convertToModSum = \case
Foldl' t1 IntTy (Lam2 x2 _ x1 _ body) init xs -> do
body <- parseModuloExpr body
(a, b) <- makeAffineFunctionFromArithmeticExpr x2 (arithmeticExprFromModuloExpr body)
guard $ isOneArithmeticExpr a
return $ ModPlus' init (ModSum' (Map' t1 IntTy (Lam x1 t1 (formatArithmeticExpr b)) xs) (moduloOfModuloExpr body)) (moduloOfModuloExpr body)
_ -> Nothing

rule :: MonadAlpha m => RewriteRule m
rule = simpleRewriteRule "Jikka.Core.Convert.SpecializeFoldl" $ \case
(convertToSum -> Just e) -> return e
(convertToModSum -> Just e) -> return e
-- TODO: Replace these operators with the better implementation like sum.
Foldl' t1 t2 (Lam2 x2 _ x1 _ body) init xs -> case body of
-- Sum
Plus' (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Sum' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
Plus' e (Var x2') | x2' == x2 && x2 `isUnusedVar` e -> Just $ Sum' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
Minus' (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Minus' init (Sum' (Map' t1 t2 (Lam x1 t1 e) xs))
-- Product
Mult' (Var x2') e | x2' == x2 && x2 `isUnusedVar` e -> Just $ Product' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
Mult' e (Var x2') | x2' == x2 && x2 `isUnusedVar` e -> Just $ Product' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs))
Expand All @@ -60,10 +79,6 @@ rule = simpleRewriteRule "Jikka.Core.Convert.SpecializeFoldl" $ \case
_ -> Nothing
-- The outer floor-mod is required because foldl for empty lists returns values without modulo.
FloorMod' (Foldl' t1 t2 (Lam2 x2 _ x1 _ body) init xs) m -> case body of
-- ModSum
ModPlus' (Var x2') e m' | x2' == x2 && x2 `isUnusedVar` e && m' == m -> Just $ ModSum' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) m
ModPlus' e (Var x2') m' | x2' == x2 && x2 `isUnusedVar` e && m' == m -> Just $ ModSum' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) m
ModMinus' (Var x2') e m' | x2' == x2 && x2 `isUnusedVar` e && m' == m -> Just $ ModMinus' init (ModSum' (Map' t1 t2 (Lam x1 t1 e) xs) m) m
-- ModProduct
ModMult' (Var x2') e m' | x2' == x2 && x2 `isUnusedVar` e && m' == m -> Just $ ModProduct' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) m
ModMult' e (Var x2') m' | x2' == x2 && x2 `isUnusedVar` e && m' == m -> Just $ ModProduct' (Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)) m
Expand Down
262 changes: 262 additions & 0 deletions src/Jikka/Core/Language/ModuloExpr.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

module Jikka.Core.Language.ModuloExpr
( -- * Basic functions
ModuloExpr,
parseModuloExpr,
formatTopModuloExpr,
formatBottomModuloExpr,
integerModuloExpr,
negateModuloExpr,
plusModuloExpr,
minusModuloExpr,
multModuloExpr,
isZeroModuloExpr,
isOneModuloExpr,
moduloOfModuloExpr,
arithmeticExprFromModuloExpr,
)
where

import Data.List
import Jikka.Common.Error
import Jikka.Core.Format (formatExpr)
import Jikka.Core.Language.ArithmeticExpr
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.Runtime (modpow)

newtype Modulo = Modulo {unModulo :: Expr}
deriving (Eq, Ord, Show, Read)

formatModulo :: Modulo -> String
formatModulo = formatExpr . unModulo

data ProductExpr = ProductExpr
{ productExprConst :: Integer,
productExprList :: [Expr]
}
deriving (Eq, Ord, Show, Read)

data SumExpr = SumExpr
{ sumExprList :: [ProductExpr],
sumExprConst :: Integer
}
deriving (Eq, Ord, Show, Read)

data ModuloExpr = ModuloExpr
{ unModuloExpr :: SumExpr,
modulo :: Modulo
}
deriving (Show)

instance Eq ModuloExpr where
e1 == e2 = unModuloExpr (normalizeModuloExpr e1) == unModuloExpr (normalizeModuloExpr e2) && modulo e1 == modulo e2

instance Ord ModuloExpr where
e1 `compare` e2 = (unModuloExpr (normalizeModuloExpr e1), modulo e1) `compare` (unModuloExpr (normalizeModuloExpr e2), modulo e2)

negateProductExpr :: ProductExpr -> ProductExpr
negateProductExpr e = e {productExprConst = negate (productExprConst e)}

multProductExpr :: ProductExpr -> ProductExpr -> ProductExpr
multProductExpr e1 e2 =
ProductExpr
{ productExprConst = productExprConst e1 * productExprConst e2,
productExprList = productExprList e1 ++ productExprList e2
}

powProductExpr :: Integer -> ProductExpr -> Integer -> ProductExpr
powProductExpr m e k =
ProductExpr
{ productExprConst = fromSuccess $ modpow (productExprConst e) k m,
productExprList = map (\e -> ModPow' e (LitInt' k) (LitInt' m)) (productExprList e)
}

isInteger :: Modulo -> Bool
isInteger (Modulo (LitInt' _)) = True
isInteger _ = False

moduloToInteger :: Modulo -> Integer
moduloToInteger (Modulo (LitInt' m)) = m
moduloToInteger m = error $ "Jikka.Core.Language.ModuloExpr.moduloToInteger: modulo is not an integer" ++ formatModulo m

parseProductExpr :: Modulo -> Expr -> ProductExpr
parseProductExpr m = \case
LitInt' n -> ProductExpr {productExprConst = n, productExprList = []}
Negate' e -> negateProductExpr (parseProductExpr m e)
Mult' e1 e2 -> multProductExpr (parseProductExpr m e1) (parseProductExpr m e2)
Pow' e1 (LitInt' k) | isInteger m -> powProductExpr (moduloToInteger m) (parseProductExpr m e1) k
ModNegate' e m' | Modulo m' == m -> negateProductExpr (parseProductExpr m e)
ModMult' e1 e2 m' | Modulo m' == m -> multProductExpr (parseProductExpr m e1) (parseProductExpr m e2)
ModPow' e1 (LitInt' k) (LitInt' m') | Modulo (LitInt' m') == m -> powProductExpr m' (parseProductExpr m e1) k
e -> ProductExpr {productExprConst = 1, productExprList = [e]}

sumExprFromProductExpr :: ProductExpr -> SumExpr
sumExprFromProductExpr e =
SumExpr
{ sumExprList = [e],
sumExprConst = 0
}

integerSumExpr :: Integer -> SumExpr
integerSumExpr n =
SumExpr
{ sumExprConst = n,
sumExprList = []
}

integerModuloExpr :: Modulo -> Integer -> ModuloExpr
integerModuloExpr m k = ModuloExpr (integerSumExpr k) m

negateSumExpr :: SumExpr -> SumExpr
negateSumExpr e =
SumExpr
{ sumExprList = map negateProductExpr (sumExprList e),
sumExprConst = negate (sumExprConst e)
}

plusSumExpr :: SumExpr -> SumExpr -> SumExpr
plusSumExpr e1 e2 =
SumExpr
{ sumExprList = sumExprList e1 ++ sumExprList e2,
sumExprConst = sumExprConst e1 + sumExprConst e2
}

multSumExpr :: Modulo -> SumExpr -> SumExpr -> SumExpr
multSumExpr m e1 e2 =
SumExpr
{ sumExprList =
let es1 = parseProductExpr m (LitInt' (sumExprConst e1)) : sumExprList e1
es2 = parseProductExpr m (LitInt' (sumExprConst e2)) : sumExprList e2
in tail $ map (uncurry multProductExpr) ((,) <$> es1 <*> es2),
sumExprConst = sumExprConst e1 * sumExprConst e2
}

negateModuloExpr :: ModuloExpr -> ModuloExpr
negateModuloExpr (ModuloExpr e m) = ModuloExpr (negateSumExpr e) m

plusModuloExpr :: ModuloExpr -> ModuloExpr -> Maybe ModuloExpr
plusModuloExpr (ModuloExpr e1 m) (ModuloExpr e2 m') | m == m' = Just $ ModuloExpr (plusSumExpr e1 e2) m
plusModuloExpr _ _ = Nothing

minusModuloExpr :: ModuloExpr -> ModuloExpr -> Maybe ModuloExpr
minusModuloExpr (ModuloExpr e1 m) (ModuloExpr e2 m') | m == m' = Just $ ModuloExpr (plusSumExpr e1 (negateSumExpr e2)) m
minusModuloExpr _ _ = Nothing

multModuloExpr :: ModuloExpr -> ModuloExpr -> Maybe ModuloExpr
multModuloExpr (ModuloExpr e1 m) (ModuloExpr e2 m') | m == m' = Just $ ModuloExpr (multSumExpr m e1 e2) m
multModuloExpr _ _ = Nothing

parseSumExpr :: Modulo -> Expr -> SumExpr
parseSumExpr m = \case
LitInt' n -> SumExpr {sumExprList = [], sumExprConst = n}
Negate' e -> negateSumExpr (parseSumExpr m e)
Plus' e1 e2 -> plusSumExpr (parseSumExpr m e1) (parseSumExpr m e2)
Minus' e1 e2 -> plusSumExpr (parseSumExpr m e1) (negateSumExpr (parseSumExpr m e2))
Mult' e1 e2 -> multSumExpr m (parseSumExpr m e1) (parseSumExpr m e2)
FloorMod' e m' | Modulo m' == m -> parseSumExpr m e
ModNegate' e m' | Modulo m' == m -> negateSumExpr (parseSumExpr m e)
ModPlus' e1 e2 m' | Modulo m' == m -> plusSumExpr (parseSumExpr m e1) (parseSumExpr m e2)
ModMinus' e1 e2 m' | Modulo m' == m -> plusSumExpr (parseSumExpr m e1) (negateSumExpr (parseSumExpr m e2))
ModMult' e1 e2 m' | Modulo m' == m -> multSumExpr m (parseSumExpr m e1) (parseSumExpr m e2)
e -> sumExprFromProductExpr (parseProductExpr m e)

getModuloFromExpr :: Expr -> Maybe Modulo
getModuloFromExpr = \case
FloorMod' _ m -> Just $ Modulo m
ModNegate' _ m -> Just $ Modulo m
ModPlus' _ _ m -> Just $ Modulo m
ModMinus' _ _ m -> Just $ Modulo m
ModMult' _ _ m -> Just $ Modulo m
_ -> Nothing

-- | `parseModuloExpr` converts a given expr to a normal form \(\sum_i \prod_j e _ {i,j}) \bmod m\).
-- This assumes given exprs have the type \(\mathbf{int}\).
parseModuloExpr :: Expr -> Maybe ModuloExpr
parseModuloExpr e = do
m <- getModuloFromExpr e
return $ ModuloExpr (parseSumExpr m e) m

formatProductExpr :: ProductExpr -> Expr
formatProductExpr e =
let k = LitInt' (productExprConst e)
k' e' = case productExprConst e of
0 -> LitInt' 0
1 -> e'
-1 -> Negate' e'
_ -> Mult' e' k
in case productExprList e of
[] -> k
eHead : esTail -> k' (foldl Mult' eHead esTail)

formatSumExpr :: SumExpr -> Expr
formatSumExpr e = case sumExprList e of
[] -> LitInt' (sumExprConst e)
eHead : esTail ->
let op e'
| productExprConst e' > 0 = Plus'
| productExprConst e' < 0 = Minus'
| otherwise = const
go e1 e2 = op e2 e1 (formatProductExpr (e2 {productExprConst = abs (productExprConst e2)}))
k' e'
| sumExprConst e > 0 = Plus' e' (LitInt' (sumExprConst e))
| sumExprConst e < 0 = Minus' e' (LitInt' (abs (sumExprConst e)))
| otherwise = e'
in k' (foldl go (formatProductExpr eHead) esTail)

-- | `formatTopModuloExpr` convert `ModuloExpr` to `Expr` with adding `FloorMod` at only the root.
formatTopModuloExpr :: ModuloExpr -> Expr
formatTopModuloExpr e = FloorMod' (formatSumExpr . unModuloExpr $ normalizeModuloExpr e) (unModulo $ modulo e)

-- | `formatBottomModuloExpr` convert `ModuloExpr` to `Expr` with adding `FloorMod` at every nodes.
formatBottomModuloExpr :: ModuloExpr -> Expr
formatBottomModuloExpr = error "Jikka.Core.Language.ModuloExpr.formatBottomModuloExpr: TODO: implement this"

normalizeProductExpr :: Modulo -> ProductExpr -> ProductExpr
normalizeProductExpr m e =
let k = case m of
Modulo (LitInt' m) -> productExprConst e `mod` m
_ -> productExprConst e
es =
if k == 0
then []
else sort (productExprList e)
in e
{ productExprList = es,
productExprConst = k
}

normalizeSumExpr :: Modulo -> SumExpr -> SumExpr
normalizeSumExpr m e =
let cmp e1 e2 = productExprList e1 `compare` productExprList e2
cmp' e1 e2 = cmp e1 e2 == EQ
es = sortBy cmp (map (normalizeProductExpr m) (sumExprList e))
es' = groupBy cmp' es
es'' = map (\group -> ProductExpr {productExprConst = sum (map productExprConst group), productExprList = productExprList (head group)}) es'
es''' = filter (\e -> productExprConst e /= 0 && not (null (productExprList e))) es''
k = sum (map (\e -> if null (productExprList e) then productExprConst e else 0) es'')
k' = sumExprConst e + k
k'' = case m of
Modulo (LitInt' m) -> k' `mod` m
_ -> k'
in SumExpr
{ sumExprList = es''',
sumExprConst = k''
}

normalizeModuloExpr :: ModuloExpr -> ModuloExpr
normalizeModuloExpr e = ModuloExpr (normalizeSumExpr (modulo e) (unModuloExpr e)) (modulo e)

isZeroModuloExpr :: ModuloExpr -> Bool
isZeroModuloExpr e = normalizeModuloExpr e == integerModuloExpr (modulo e) 0

isOneModuloExpr :: ModuloExpr -> Bool
isOneModuloExpr e = normalizeModuloExpr e == integerModuloExpr (modulo e) 1

moduloOfModuloExpr :: ModuloExpr -> Expr
moduloOfModuloExpr = unModulo . modulo

arithmeticExprFromModuloExpr :: ModuloExpr -> ArithmeticExpr
arithmeticExprFromModuloExpr = parseArithmeticExpr . formatSumExpr . unModuloExpr

0 comments on commit ba2ca79

Please sign in to comment.