diff --git a/changelog/2021-11-01T16_46_50+01_00_rec_nonrec_let.md b/changelog/2021-11-01T16_46_50+01_00_rec_nonrec_let.md new file mode 100644 index 0000000000..ef8908666c --- /dev/null +++ b/changelog/2021-11-01T16_46_50+01_00_rec_nonrec_let.md @@ -0,0 +1 @@ +CHANGED: Clash now keeps information about which let bindings are recursive from GHC core. This can be used to avoid performing free variable calculations, or sorting bindings in normalization. diff --git a/clash-ghc/src-ghc/Clash/GHC/Evaluator.hs b/clash-ghc/src-ghc/Clash/GHC/Evaluator.hs index 8e5d04a8b0..6efb33a66d 100644 --- a/clash-ghc/src-ghc/Clash/GHC/Evaluator.hs +++ b/clash-ghc/src-ghc/Clash/GHC/Evaluator.hs @@ -213,8 +213,9 @@ stepTyApp x ty m tcm = (term, args, _) = collectArgsTicks (TyApp x ty) tys' = fst . splitFunForallTy . inferCoreTypeOf tcm $ TyApp x ty -stepLetRec :: [LetBinding] -> Term -> Step -stepLetRec bs x m _ = Just (allocate bs x m) +stepLet :: Bind Term -> Term -> Step +stepLet (NonRec i b) x m _ = Just (allocate [(i,b)] x m) +stepLet (Rec bs) x m _ = Just (allocate bs x m) stepCase :: Term -> Type -> [Alt] -> Step stepCase scrut ty alts m _ = @@ -245,7 +246,7 @@ ghcStep m = case mTerm m of TyLam v x -> stepTyLam v x m App x y -> stepApp x y m TyApp x ty -> stepTyApp x ty m - Letrec bs x -> stepLetRec bs x m + Let bs x -> stepLet bs x m Case s ty as -> stepCase s ty as m Cast x a b -> stepCast x a b m Tick t x -> stepTick t x m diff --git a/clash-ghc/src-ghc/Clash/GHC/GHC2Core.hs b/clash-ghc/src-ghc/Clash/GHC/GHC2Core.hs index 50e4a0d3fe..85a7f71e30 100644 --- a/clash-ghc/src-ghc/Clash/GHC/GHC2Core.hs +++ b/clash-ghc/src-ghc/Clash/GHC/GHC2Core.hs @@ -434,12 +434,12 @@ coreToTerm primMap unlocs = term (e1',sp) <- termSP (getSrcSpan x) e1 x' <- coreToIdSP sp x e2' <- term e2 - return (C.Letrec [(x', e1')] e2') + return (C.Let (C.NonRec x' e1') e2') term' (Let (Rec xes) e) = do xes' <- mapM go xes e' <- term e - return (C.Letrec xes' e') + return (C.Let (C.Rec xes') e') where go (x,b) = do (b',sp) <- termSP (getSrcSpan x) b @@ -460,7 +460,7 @@ coreToTerm primMap unlocs = term if usesBndr then do ct <- caseTerm (C.Var b') - return (C.Letrec [(b', e')] ct) + return (C.Let (C.NonRec b' e') ct) else caseTerm e' term' (Cast e co) = do diff --git a/clash-ghc/src-ghc/Clash/GHC/PartialEval/Eval.hs b/clash-ghc/src-ghc/Clash/GHC/PartialEval/Eval.hs index ade67dc191..43640ab910 100644 --- a/clash-ghc/src-ghc/Clash/GHC/PartialEval/Eval.hs +++ b/clash-ghc/src-ghc/Clash/GHC/PartialEval/Eval.hs @@ -23,7 +23,6 @@ import Control.Monad (foldM) import Data.Bifunctor import Data.Bitraversable import Data.Either -import Data.Graph (SCC(..)) import Data.Primitive.ByteArray (ByteArray(..)) #if MIN_VERSION_base(4,15,0) import GHC.Num.Integer (Integer (..)) @@ -48,7 +47,6 @@ import Clash.Core.Term import Clash.Core.TyCon (tyConDataCons) import Clash.Core.Type import Clash.Core.TysPrim (integerPrimTy) -import qualified Clash.Core.Util as Util import Clash.Core.Var import Clash.Driver.Types (Binding(..), IsPrim(..)) import qualified Clash.Normalize.Primitives as NP (undefined) @@ -66,7 +64,7 @@ eval = \case TyLam i x -> evalTyLam i x App x y -> evalApp x (Left y) TyApp x ty -> evalApp x (Right ty) - Letrec bs x -> evalLetrec bs x + Let bs x -> evalLet bs x Case x ty alts -> evalCase x ty alts Cast x a b -> evalCast x a b Tick tick x -> evalTick tick x @@ -250,33 +248,30 @@ evalApp x y term = either (App x) (TyApp x) y (f, args, _ticks) = collectArgsTicks term -evalLetrec :: [LetBinding] -> Term -> Eval Value -evalLetrec bs x = do - -- Determine if a binding should be kept in a letrec or inlined. We keep - -- bindings which perform work to prevent duplication of registers etc. - (keep, inline) <- foldM evalScc ([], []) (Util.sccLetBindings bs) - eX <- withIds (keep <> inline) (eval x) +evalLet :: Bind Term -> Term -> Eval Value +evalLet (NonRec i x) body = do + iTy <- evalType (varType i) + eX <- delayEval x + wfX <- workFreeValue eX - case keep of - [] -> pure eX - _ -> pure (VNeutral (NeLetrec keep eX)) - where - evalBind (i, y) = do - iTy <- evalType (varType i) - eY <- delayEval y + eBody <- withId i eX (eval body) - pure (i { varType = iTy }, eY) + -- Only keep the let binding if it performs work. + if wfX + then pure eBody + else pure (VNeutral (NeLet (NonRec i{varType=iTy} eX) eBody)) - evalScc (k, i) = \case - AcyclicSCC y -> do - eY <- evalBind y - workFree <- workFreeValue (snd eY) +evalLet (Rec xs) body = do + binds <- traverse evalBind xs + eBody <- withIds binds (eval body) - if workFree then pure (k, eY:i) else pure (eY:k, i) + pure (VNeutral (NeLet (Rec binds) eBody)) + where + evalBind (i, x) = do + iTy <- evalType (varType i) + eX <- delayEval x - CyclicSCC ys -> do - eYs <- traverse evalBind ys - pure (eYs <> k, i) + pure (i{varType=iTy}, eX) evalCase :: Term -> Type -> [Alt] -> Eval Value evalCase term ty as = do @@ -343,9 +338,9 @@ tryTransformCase subject ty alts = -- A case of let: Pull out the let expression if possible and attempt -- caseCon on the new case expression. - VNeutral (NeLetrec bindings innerSubject) -> do + VNeutral (NeLet bindings innerSubject) -> do newCase <- caseCon innerSubject ty alts - pure (VNeutral (NeLetrec bindings newCase)) + pure (VNeutral (NeLet bindings newCase)) -- There is no way to continue evaluating the case, do nothing. -- TODO elimExistentials here. @@ -570,16 +565,16 @@ apply val arg = do case stripValue forced of -- If the LHS of application evaluates to a letrec, then add any bindings -- that do work to this letrec instead of creating a new one. - VNeutral (NeLetrec bs x) + VNeutral (NeLet bs x) | canApply -> do inner <- apply x arg - pure (VNeutral (NeLetrec bs inner)) + pure (VNeutral (NeLet bs inner)) | otherwise -> do varTy <- evalType (valueType tcm arg) var <- getUniqueId "workArg" varTy inner <- apply x (VNeutral (NeVar var)) - pure (VNeutral (NeLetrec ((var, arg) : bs) inner)) + pure (VNeutral (NeLet bs (VNeutral (NeLet (NonRec var arg) inner)))) -- If the LHS of application is neutral, make a letrec around the neutral -- application if the argument performs work. @@ -589,7 +584,7 @@ apply val arg = do varTy <- evalType (valueType tcm arg) var <- getUniqueId "workArg" varTy let inner = VNeutral (NeApp neu (VNeutral (NeVar var))) - pure (VNeutral (NeLetrec [(var, arg)] inner)) + pure (VNeutral (NeLet (NonRec var arg) inner)) -- If the LHS of application is a lambda, make a letrec with the name of -- the argument around the result of evaluation if it performs work. @@ -597,7 +592,7 @@ apply val arg = do | canApply -> setLocalEnv env $ withId i arg (eval x) | otherwise -> setLocalEnv env $ do inner <- withId i arg (eval x) - pure (VNeutral (NeLetrec [(i, arg)] inner)) + pure (VNeutral (NeLet (NonRec i arg) inner)) f -> error ("apply: Cannot apply " <> show arg <> " to " <> show f) diff --git a/clash-ghc/src-ghc/Clash/GHC/PartialEval/Quote.hs b/clash-ghc/src-ghc/Clash/GHC/PartialEval/Quote.hs index 872c790bb5..04c222cc76 100644 --- a/clash-ghc/src-ghc/Clash/GHC/PartialEval/Quote.hs +++ b/clash-ghc/src-ghc/Clash/GHC/PartialEval/Quote.hs @@ -1,5 +1,5 @@ {-| -Copyright : (C) 2020, QBayLogic B.V. +Copyright : (C) 2020-2021, QBayLogic B.V. License : BSD2 (see the file LICENSE) Maintainer : QBayLogic B.V. @@ -18,7 +18,7 @@ import Data.Bitraversable import Clash.Core.DataCon (DataCon) import Clash.Core.PartialEval.Monad import Clash.Core.PartialEval.NormalForm -import Clash.Core.Term (Term, PrimInfo, TickInfo, Pat) +import Clash.Core.Term (Bind(..), Term, PrimInfo, TickInfo, Pat) import Clash.Core.Type (Type(VarTy)) import Clash.Core.Var (Id, TyVar) @@ -41,7 +41,7 @@ quoteNeutral = \case NePrim pr args -> quoteNePrim pr args NeApp x y -> quoteNeApp x y NeTyApp x ty -> quoteNeTyApp x ty - NeLetrec bs x -> quoteNeLetrec bs x + NeLet bs x -> quoteNeLet bs x NeCase x ty alts -> quoteNeCase x ty alts quoteArgs :: Args Value -> Eval (Args Normal) @@ -50,8 +50,9 @@ quoteArgs = traverse (bitraverse quote pure) quoteAlts :: [(Pat, Value)] -> Eval [(Pat, Normal)] quoteAlts = traverse (bitraverse pure quote) -quoteBinders :: [(Id, Value)] -> Eval [(Id, Normal)] -quoteBinders = traverse (bitraverse pure quote) +quoteBind :: Bind Value -> Eval (Bind Normal) +quoteBind (NonRec i x) = NonRec i <$> quote x +quoteBind (Rec xs) = Rec <$> traverse (bitraverse pure quote) xs quoteData :: DataCon -> Args Value -> LocalEnv -> Eval Normal quoteData dc args env = setLocalEnv env (NData dc <$> quoteArgs args) @@ -90,9 +91,12 @@ quoteNeApp x y = NeApp <$> quoteNeutral x <*> quote y quoteNeTyApp :: Neutral Value -> Type -> Eval (Neutral Normal) quoteNeTyApp x ty = NeTyApp <$> quoteNeutral x <*> pure ty -quoteNeLetrec :: [(Id, Value)] -> Value -> Eval (Neutral Normal) -quoteNeLetrec bs x = - withIds bs (NeLetrec <$> quoteBinders bs <*> quote x) +quoteNeLet :: Bind Value -> Value -> Eval (Neutral Normal) +quoteNeLet bs x = + withIds (bindToList bs) (NeLet <$> quoteBind bs <*> quote x) + where + bindToList (NonRec i e) = [(i, e)] + bindToList (Rec xs) = xs quoteNeCase :: Value -> Type -> [(Pat, Value)] -> Eval (Neutral Normal) quoteNeCase x ty alts = diff --git a/clash-lib/src/Clash/Core/FreeVars.hs b/clash-lib/src/Clash/Core/FreeVars.hs index 1728727bd8..217139fcf1 100644 --- a/clash-lib/src/Clash/Core/FreeVars.hs +++ b/clash-lib/src/Clash/Core/FreeVars.hs @@ -1,7 +1,8 @@ {-| Copyright : (C) 2012-2016, University of Twente + 2021, QBayLogic B.V. License : BSD2 (see the file LICENSE) - Maintainer : Christiaan Baaij + Maintainer : QBayLogic B.V. Free variable calculations -} @@ -34,7 +35,7 @@ import Data.Coerce import qualified Data.IntSet as IntSet import Data.Monoid (All (..), Any (..)) -import Clash.Core.Term (Pat (..), Term (..), TickInfo (..)) +import Clash.Core.Term (Pat (..), Term (..), TickInfo (..), Bind(..)) import Clash.Core.Type (Type (..)) import Clash.Core.Var (Id, IdScope (..), TyVar, Var (..), isLocalId) @@ -223,11 +224,15 @@ termFreeVars' interesting f = go IntSet.empty where TyApp <$> go inLocalScope l <*> typeFreeVars' interesting inLocalScope f r - Letrec bs e -> - Letrec <$> traverse (goBind inLocalScope') bs - <*> go inLocalScope' e - where - inLocalScope' = foldr IntSet.insert inLocalScope (map (varUniq.fst) bs) + Let (NonRec i x) e -> + Let <$> (NonRec <$> goBndr inLocalScope i <*> go inLocalScope x) + <*> go (IntSet.insert (varUniq i) inLocalScope) e + + Let (Rec bs) e -> + Let <$> (Rec <$> traverse (goBind inLocalScope') bs) + <*> go inLocalScope' e + where + inLocalScope' = foldr (IntSet.insert . varUniq . fst) inLocalScope bs Case subj ty alts -> Case <$> go inLocalScope subj diff --git a/clash-lib/src/Clash/Core/HasType.hs b/clash-lib/src/Clash/Core/HasType.hs index 487ea927c1..b1a27f0406 100644 --- a/clash-lib/src/Clash/Core/HasType.hs +++ b/clash-lib/src/Clash/Core/HasType.hs @@ -142,7 +142,7 @@ instance InferType Term where case collectArgs x of (fun, args) -> applyTypeToArgs x tcm (go fun) args - Letrec _ x -> go x + Let _ x -> go x Case _ ty _ -> ty Cast _ _ a -> a Tick _ x -> go x diff --git a/clash-lib/src/Clash/Core/PartialEval/AsTerm.hs b/clash-lib/src/Clash/Core/PartialEval/AsTerm.hs index d83f79ee76..f804544fdc 100644 --- a/clash-lib/src/Clash/Core/PartialEval/AsTerm.hs +++ b/clash-lib/src/Clash/Core/PartialEval/AsTerm.hs @@ -16,12 +16,10 @@ module Clash.Core.PartialEval.AsTerm ) where import Data.Bifunctor (first, second) -import Data.Graph (SCC(..), flattenSCCs) import Clash.Core.HasFreeVars import Clash.Core.PartialEval.NormalForm -import Clash.Core.Term (Term(..), LetBinding, Pat, Alt, mkApps) -import Clash.Core.Util (sccLetBindings) +import Clash.Core.Term (Bind(..), Term(..), Pat, Alt, mkApps) import Clash.Core.VarEnv (elemVarSet) -- | Convert a term in some normal form back into a Term. This is important, @@ -37,24 +35,19 @@ instance (AsTerm a) => AsTerm (Neutral a) where NePrim pr args -> mkApps (Prim pr) (argsToTerms args) NeApp x y -> App (asTerm x) (asTerm y) NeTyApp x ty -> TyApp (asTerm x) ty - NeLetrec bs x -> - let bs' = fmap (second asTerm) bs - x' = asTerm x - in removeUnusedBindings bs' x' - + NeLet bs x -> removeUnusedBindings (fmap asTerm bs) (asTerm x) NeCase x ty alts -> Case (asTerm x) ty (altsToTerms alts) -removeUnusedBindings :: [LetBinding] -> Term -> Term +removeUnusedBindings :: Bind Term -> Term -> Term removeUnusedBindings bs x - | null used = x - | otherwise = Letrec used x + | isUsed bs = Let bs x + | otherwise = x where free = freeVarsOf x - used = flattenSCCs $ filter isUsed (sccLetBindings bs) isUsed = \case - AcyclicSCC y -> fst y `elemVarSet` free - CyclicSCC ys -> any (flip elemVarSet free . fst) ys + NonRec i _ -> elemVarSet i free + Rec xs -> any (flip elemVarSet free . fst) xs instance AsTerm Value where asTerm = \case diff --git a/clash-lib/src/Clash/Core/PartialEval/NormalForm.hs b/clash-lib/src/Clash/Core/PartialEval/NormalForm.hs index 5a0f35f703..a32a7e5130 100644 --- a/clash-lib/src/Clash/Core/PartialEval/NormalForm.hs +++ b/clash-lib/src/Clash/Core/PartialEval/NormalForm.hs @@ -37,7 +37,7 @@ import Data.Map.Strict (Map) import Clash.Core.DataCon (DataCon) import Clash.Core.Literal -import Clash.Core.Term (Term(..), PrimInfo(primName), TickInfo, Pat) +import Clash.Core.Term (Bind, Term(..), PrimInfo(primName), TickInfo, Pat) import Clash.Core.TyCon (TyConMap) import Clash.Core.Type (Type, TyVar) import Clash.Core.Util (undefinedPrims) @@ -75,7 +75,7 @@ data Neutral a | NePrim !PrimInfo !(Args a) | NeApp !(Neutral a) !a | NeTyApp !(Neutral a) !Type - | NeLetrec ![(Id, a)] !a + | NeLet !(Bind a) !a | NeCase !a !Type ![(Pat, a)] deriving (Show) diff --git a/clash-lib/src/Clash/Core/Pretty.hs b/clash-lib/src/Clash/Core/Pretty.hs index ca24bcad73..4d3648db01 100644 --- a/clash-lib/src/Clash/Core/Pretty.hs +++ b/clash-lib/src/Clash/Core/Pretty.hs @@ -61,7 +61,7 @@ import Clash.Core.DataCon (DataCon (..)) import Clash.Core.Literal (Literal (..)) import Clash.Core.Name (Name (..)) import Clash.Core.Term - (Pat (..), Term (..), TickInfo (..), NameMod (..), CoreContext (..), primArg, PrimInfo(primName)) + (Pat (..), Term (..), TickInfo (..), NameMod (..), CoreContext (..), primArg, PrimInfo(primName),Bind(..)) import Clash.Core.TyCon (TyCon (..), TyConName, isTupleTyConLike) import Clash.Core.Type (ConstTy (..), Kind, LitTy (..), Type (..), TypeView (..), tyView) @@ -221,11 +221,11 @@ pprTopLevelBndr (bndr,expr) = do expr' <- pprM expr return $ bndr' <> line <> hang 2 (sep [(bndrName <+> equals), expr']) <> line -dcolon, rarrow, lam, tylam, at, cast, coerce, letrec, in_, case_, of_, forall_ +dcolon, rarrow, lam, tylam, at, cast, coerce, let_, letrec, in_, case_, of_, forall_ :: ClashDoc -[dcolon, rarrow, lam, tylam, at, cast, coerce, letrec, in_, case_, of_, forall_] +[dcolon, rarrow, lam, tylam, at, cast, coerce, let_, letrec, in_, case_, of_, forall_] = annotate (AnnSyntax Keyword) <$> - ["::", "->", "λ", "Λ", "@", "▷", "~", "letrec", "in", "case", "of", "forall"] + ["::", "->", "λ", "Λ", "@", "▷", "~", "let", "letrec", "in", "case", "of", "forall"] instance PrettyPrec Text where pprPrec _ = pure . pretty @@ -262,7 +262,8 @@ instance PrettyPrec Term where App fun arg -> pprPrecApp prec fun arg TyApp e' ty -> annotate (AnnContext TyAppC) <$> pprPrecTyApp prec e' ty - Letrec xes e1 -> pprPrecLetrec prec xes e1 + Let (NonRec i x) e1 -> pprPrecLetrec prec False [(i,x)] e1 + Let (Rec xes) e1 -> pprPrecLetrec prec True xes e1 Case e' _ alts -> pprPrecCase prec e' alts Cast e' ty1 ty2 -> pprPrecCast prec e' ty1 ty2 Tick t e' -> do @@ -381,8 +382,14 @@ pprPrecCast prec e ty1 ty2 = do e' <> annotate (AnnSyntax Type) (softline <> nest 2 (vsep [cast, ty1', coerce, ty2'])) -pprPrecLetrec :: Monad m => Rational -> [(Id, Term)] -> Term -> m ClashDoc -pprPrecLetrec prec xes body = do +-- TODO Since Clash now keeps non-recursive let expressions separately, the +-- result of normalization will contain more nested let expressions as the old +-- Letrec-based definitions are replaced by Let. As this happens, it may be a +-- good idea to change pprPrecLetrec to encourage more compact forms such as +-- printing the entire binding on one line if possible. + +pprPrecLetrec :: Monad m => Rational -> Bool -> [(Id, Term)] -> Term -> m ClashDoc +pprPrecLetrec prec isRec xes body = do let bndrs = fst <$> xes body' <- annotate (AnnContext $ LetBody bndrs) <$> pprPrec noPrec body xes' <- mapM (\(x,e) -> do @@ -392,8 +399,9 @@ pprPrecLetrec prec xes body = do vsepHard [x', equals <+> e'] ) xes let xes'' = case xes' of { [] -> ["EmptyLetrec"]; _ -> xes' } + let kw = if isRec then letrec else let_ return $ parensIf (prec > noPrec) $ - vsepHard [hang 2 (vsepHard $ letrec : xes''), in_ <+> body'] + vsepHard [hang 2 (vsepHard $ kw : xes''), in_ <+> body'] pprPrecCase :: Monad m => Rational -> Term -> [(Pat,Term)] -> m ClashDoc pprPrecCase prec e alts = do diff --git a/clash-lib/src/Clash/Core/Subst.hs b/clash-lib/src/Clash/Core/Subst.hs index 229db5853a..1b95240e2e 100644 --- a/clash-lib/src/Clash/Core/Subst.hs +++ b/clash-lib/src/Clash/Core/Subst.hs @@ -73,7 +73,7 @@ import GHC.Stack (HasCallStack) import Clash.Core.HasFreeVars import Clash.Core.Pretty (ppr, fromPpr) import Clash.Core.Term - (LetBinding, Pat (..), Term (..), TickInfo (..), PrimInfo(primName)) + (Bind(..), Pat (..), Term (..), TickInfo (..), PrimInfo(primName)) import Clash.Core.Type (Type (..)) import Clash.Core.VarEnv import Clash.Core.Var (Id, Var (..), TyVar, isGlobalId) @@ -558,8 +558,8 @@ substTm doc subst = go where (subst',v') -> TyLam v' (substTm doc subst' e) App l r -> App (go l) (go r) TyApp l r -> TyApp (go l) (substTy subst r) - Letrec bs e -> case substBind doc subst bs of - (subst',bs') -> Letrec bs' (substTm doc subst' e) + Let bs e -> case substBind doc subst bs of + (subst',bs') -> Let bs' (substTm doc subst' e) Case subj ty alts -> Case (go subj) (substTy subst ty) (map goAlt alts) Cast e t1 t2 -> Cast (go e) (substTy subst t1) (substTy subst t2) Tick tick e -> Tick (goTick tick) (go e) @@ -671,10 +671,16 @@ substBind :: HasCallStack => Doc () -> Subst - -> [LetBinding] - -> (Subst,[LetBinding]) -substBind doc subst xs = - (subst',zip bndrs' rhss') + -> Bind Term + -> (Subst, Bind Term) +substBind doc subst (NonRec i x) = + (subst', NonRec i' x') + where + (subst', i') = substIdBndr subst i + x' = substTm ("substBind" <+> doc) subst x + +substBind doc subst (Rec xs) = + (subst', Rec (zip bndrs' rhss')) where (bndrs,rhss) = unzip xs (subst',bndrs') = List.mapAccumL substIdBndr subst bndrs @@ -718,11 +724,11 @@ deshadowLetExpr :: HasCallStack => InScopeSet -- ^ Current InScopeSet - -> [LetBinding] + -> Bind Term -- ^ Bindings of the let-expression -> Term -- ^ The body of the let-expression - -> ([LetBinding],Term) + -> (Bind Term, Term) -- ^ Deshadowed let-bindings, where let-bound expressions and the let-body -- properly reference the renamed variables deshadowLetExpr is bs e = @@ -753,9 +759,9 @@ freshenTm is0 = go (mkSubst is0) where (is2,r') -> (is2, App l' r') TyApp l r -> case go subst0 l of (is1,l') -> (is1, TyApp l' (substTy subst0 r)) - Letrec bs e -> case goBind subst0 bs of + Let bs e -> case goBind subst0 bs of (subst1,bs') -> case go subst1 e of - (is2,e') -> (is2,Letrec bs' e') + (is2,e') -> (is2,Let bs' e') Case subj ty alts -> case go subst0 subj of (is1,subj') -> case List.mapAccumL (\isN -> goAlt subst0 {substInScope = isN}) is1 alts of (is2,alts') -> (is2, Case subj' (substTy subst0 ty) alts') @@ -765,13 +771,18 @@ freshenTm is0 = go (mkSubst is0) where (is1, e') -> (is1, Tick (goTick subst0 tick) e') tm -> (substInScope subst0, tm) - goBind subst0 xs = + goBind subst0 (NonRec i x) = + let (subst1, i') = substIdBndr subst0 i + (is2, x') = go subst0 x + in (subst1 { substInScope = extendInScopeSet is2 i' }, NonRec i' x') + + goBind subst0 (Rec xs) = let (bndrs,rhss) = unzip xs (subst1,bndrs') = List.mapAccumL substIdBndr subst0 bndrs (is2,rhss') = List.mapAccumL (\isN -> go subst1 {substInScope = isN}) (substInScope subst1) rhss - in (subst1 {substInScope = is2},zip bndrs' rhss') + in (subst1 {substInScope = is2}, Rec $ zip bndrs' rhss') goAlt subst0 (pat,alt) = case pat of DataPat dc tvs ids -> case List.mapAccumL substTyVarBndr' subst0 tvs of @@ -891,7 +902,9 @@ acmpTerm' inScope = go (mkRnEnv inScope) go env l1 l2 `thenCompare` go env r1 r2 go env (TyApp l1 r1) (TyApp l2 r2) = go env l1 l2 `thenCompare` acmpType' env r1 r2 - go env (Letrec bs1 e1) (Letrec bs2 e2) = + go env (Let (NonRec i1 x1) e1) (Let (NonRec i2 x2) e2) = + go env x1 x2 `thenCompare` go (rnTmBndr env i1 i2) e1 e2 + go env (Let (Rec bs1) e1) (Let (Rec bs2) e2) = compare (length bs1) (length bs2) `thenCompare` foldr thenCmpTm EQ (zipWith (go env') rhs1 rhs2) `thenCompare` go env' e1 e2 @@ -931,9 +944,10 @@ acmpTerm' inScope = go (mkRnEnv inScope) TyApp {} -> 6 Lam {} -> 7 TyLam {} -> 8 - Letrec {} -> 9 - Case {} -> 10 - Tick {} -> 11 + Let NonRec{} _ -> 9 + Let Rec{} _ -> 10 + Case {} -> 11 + Tick {} -> 12 thenCompare :: Ordering -> Ordering -> Ordering thenCompare EQ rel = rel diff --git a/clash-lib/src/Clash/Core/Term.hs b/clash-lib/src/Clash/Core/Term.hs index f6ffea8b0c..b9d539a6f7 100644 --- a/clash-lib/src/Clash/Core/Term.hs +++ b/clash-lib/src/Clash/Core/Term.hs @@ -13,10 +13,11 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE TemplateHaskell #-} module Clash.Core.Term - ( Term (..) + ( Term (.., Letrec) , mkAbstraction , mkTyLams , mkLams @@ -26,6 +27,7 @@ module Clash.Core.Term , mkTicks , TmName , varToId + , Bind(..) , LetBinding , Pat (..) , patIds @@ -90,13 +92,23 @@ data Term | TyLam !TyVar Term -- ^ Type-abstraction | App !Term !Term -- ^ Application | TyApp !Term !Type -- ^ Type-application - | Letrec [LetBinding] Term -- ^ Recursive let-binding + | Let !(Bind Term) Term -- ^ Recursive let-binding | Case !Term !Type [Alt] -- ^ Case-expression: subject, type of -- alternatives, list of alternatives | Cast !Term !Type !Type -- ^ Cast a term from one type to another | Tick !TickInfo !Term -- ^ Annotated term deriving (Show,Generic,NFData,Hashable,Binary) +-- TODO When it is possible, remove this pattern. +pattern Letrec :: [LetBinding] -> Term -> Term +pattern Letrec bs x <- Let (bindToList -> bs) x + where + Letrec bs x = Let (Rec bs) x + +bindToList :: Bind a -> [(Id, a)] +bindToList (NonRec i x) = [(i, x)] +bindToList (Rec xs) = xs + data TickInfo = SrcSpan !SrcSpan -- ^ Source tick, will get added by GHC by running clash with `-g` @@ -173,6 +185,12 @@ type TmName = Name Term -- | Binding in a LetRec construct type LetBinding = (Id, Term) +data Bind a + = NonRec Id a + | Rec [(Id, a)] + deriving (Eq, Show, Generic, NFData, Hashable, Binary, Functor) + -- Structural equivalence instead of alpha equivalance + -- | Patterns in the LHS of a case-decomposition data Pat = DataPat !DataCon [TyVar] [Id] @@ -361,7 +379,8 @@ walkTerm f = catMaybes . DList.toList . go TyLam _ t1 -> go t1 App t1 t2 -> go t1 <> go t2 TyApp t1 _ -> go t1 - Letrec bndrs t1 -> go t1 <> mconcat (map (go . snd) bndrs) + Let (NonRec _ x) t1 -> go t1 <> go x + Let (Rec bndrs) t1 -> go t1 <> mconcat (map (go . snd) bndrs) Case t1 _ alts -> go t1 <> mconcat (map (go . snd) alts) Cast t1 _ _ -> go t1 Tick _ t1 -> go t1 @@ -373,7 +392,8 @@ collectTermIds = concat . walkTerm (Just . go) go :: Term -> [Id] go (Var i) = [i] go (Lam i _) = [i] - go (Letrec bndrs _) = map fst bndrs + go (Let (NonRec i _) _) = [i] + go (Let (Rec bndrs) _) = fmap fst bndrs go (Case _ _ alts) = concatMap (pat . fst) alts go (Data _) = [] go (Literal _) = [] diff --git a/clash-lib/src/Clash/Core/TermInfo.hs b/clash-lib/src/Clash/Core/TermInfo.hs index 59076b46c9..bafdb3e5ce 100644 --- a/clash-lib/src/Clash/Core/TermInfo.hs +++ b/clash-lib/src/Clash/Core/TermInfo.hs @@ -29,9 +29,10 @@ termSize (App e1 e2) = termSize e1 + termSize e2 termSize (TyApp e _) = termSize e termSize (Cast e _ _) = termSize e termSize (Tick _ e) = termSize e -termSize (Letrec bndrs e) = sum (bodySz:bndrSzs) +termSize (Let (NonRec _ x) e) = termSize x + termSize e +termSize (Let (Rec xs) e) = sum (bodySz:bndrSzs) where - bndrSzs = map (termSize . snd) bndrs + bndrSzs = map (termSize . snd) xs bodySz = termSize e termSize (Case subj _ alts) = sum (subjSz:altSzs) where @@ -88,8 +89,8 @@ isPolyFun m t = isPolyFunCoreTy m (inferCoreTypeOf m t) -- | Is a term a recursive let-binding? isLet :: Term -> Bool -isLet (Letrec {}) = True -isLet _ = False +isLet Let{} = True +isLet _ = False -- | Is a term a variable reference? isVar :: Term -> Bool diff --git a/clash-lib/src/Clash/Core/Util.hs b/clash-lib/src/Clash/Core/Util.hs index e4d54b5dd7..752f783ff8 100644 --- a/clash-lib/src/Clash/Core/Util.hs +++ b/clash-lib/src/Clash/Core/Util.hs @@ -58,6 +58,15 @@ import Clash.Debug (traceIf) import Clash.Unique import Clash.Util +-- | Rebuild a let expression / let expressions by taking the SCCs of a list +-- of bindings and remaking Let (NonRec ...) ... and Let (Rec ...) ... +-- +listToLets :: [LetBinding] -> Term -> Term +listToLets xs body = foldr go body (sccLetBindings xs) + where + go (Graph.AcyclicSCC (i, x)) acc = Let (NonRec i x) acc + go (Graph.CyclicSCC binds) acc = Let (Rec binds) acc + -- | The type @forall a . a@ undefinedTy ::Type undefinedTy = @@ -142,12 +151,12 @@ extractElems -- ^ Length of the vector -> Term -- ^ The vector - -> (Supply, [(Term,[LetBinding])]) + -> (Supply, [(Term,[(Id, Term)])]) extractElems supply inScope consCon resTy s maxN vec = first fst (go maxN (supply,inScope) vec) where go :: Integer -> (Supply,InScopeSet) -> Term - -> ((Supply,InScopeSet),[(Term,[LetBinding])]) + -> ((Supply,InScopeSet),[(Term,[(Id, Term)])]) go 0 uniqs _ = (uniqs,[]) go n uniqs0 e = (uniqs3,(elNVar,[(elNId, lhs),(restNId, rhs)]):restVs) @@ -194,7 +203,7 @@ extractTElems -- ^ Depth of the tree -> Term -- ^ The tree - -> (Supply,([Term],[LetBinding])) + -> (Supply,([Term],[(Id, Term)])) extractTElems supply inScope lrCon brCon resTy s maxN tree = first fst (go maxN [0..(2^(maxN+1))-2] [0..(2^maxN - 1)] (supply,inScope) tree) where @@ -203,7 +212,7 @@ extractTElems supply inScope lrCon brCon resTy s maxN tree = -> [Int] -> (Supply,InScopeSet) -> Term - -> ((Supply,InScopeSet),([Term],[LetBinding])) + -> ((Supply,InScopeSet),([Term],[(Id, Term)])) go 0 _ ks uniqs0 e = (uniqs1,([elNVar],[(elNId, rhs)])) where tys = [LitTy (NumTy 0),resTy] @@ -676,8 +685,8 @@ inverseTopSortLetBindings e = e -- | Group let-bindings into cyclic groups and acyclic individual bindings sccLetBindings :: HasCallStack - => [LetBinding] - -> [Graph.SCC LetBinding] + => [(Id, Term)] + -> [Graph.SCC (Id, Term)] sccLetBindings = Graph.stronglyConnComp . (map (\(i,e) -> let fvs = fmap varUniq diff --git a/clash-lib/src/Clash/Netlist/BlackBox.hs b/clash-lib/src/Clash/Netlist/BlackBox.hs index 6c67e39894..0bf19bab53 100644 --- a/clash-lib/src/Clash/Netlist/BlackBox.hs +++ b/clash-lib/src/Clash/Netlist/BlackBox.hs @@ -14,6 +14,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE TemplateHaskell #-} @@ -267,7 +268,7 @@ mkArgument bbName bndr nArg e = do (Case scrut ty' [alt],[],_) -> do (projection,decls) <- mkProjection False (NetlistId bndr ty) scrut ty' alt return ((projection,hwTy,False),decls) - (Letrec _bnds _term, [], _ticks) -> do + (Let _bnds _term, [], _ticks) -> do (exprN, letDecls) <- mkExpr False Concurrent (NetlistId bndr ty) e return ((exprN,hwTy,False),letDecls) _ -> do @@ -543,7 +544,7 @@ mkPrimitive bbEParen bbEasD dst pInfo args tickDecls = [dstNm] -> return ( Identifier dstNm Nothing , dstDecl ++ decls ++ [Assignment dstNm expr]) - _ -> error "internal error" + _ -> error $ $(curLoc) ++ "bindSimIO: " ++ show resM _ -> return (Noop,decls) @@ -1238,7 +1239,7 @@ mkFunInput resId e = nm <- Id.makeBasic "selection" return (Right ((nm,assn++selectionDecls),Wire)) - go is0 _ e'@(Letrec {}) = do + go is0 _ e'@(Let{}) = do tcm <- Lens.use tcCache let normE = splitNormalized tcm e' (_,[],[],_,[],binders,resultM) <- case normE of diff --git a/clash-lib/src/Clash/Netlist/Util.hs b/clash-lib/src/Clash/Netlist/Util.hs index bc99df5a3a..5a7b58023f 100644 --- a/clash-lib/src/Clash/Netlist/Util.hs +++ b/clash-lib/src/Clash/Netlist/Util.hs @@ -160,6 +160,11 @@ isVoid _ = False isFilteredVoid :: FilteredHWType -> Bool isFilteredVoid = isVoid . stripFiltered +squashLets :: Term -> Term +squashLets (Letrec xs (Letrec ys e)) = + squashLets (Letrec (xs <> ys) e) +squashLets e = e + -- | Split a normalized term into: a list of arguments, a list of let-bindings, -- and a variable reference that is the body of the let-binding. Returns a -- String containing the error if the term was not in a normalized form. @@ -168,10 +173,10 @@ splitNormalized -> Term -> (Either String ([Id],[LetBinding],Id)) splitNormalized tcm expr = case collectBndrs expr of - (args, collectTicks -> (Letrec xes e, ticks)) + (args, collectTicks -> (squashLets -> Letrec xes e, ticks)) | (tmArgs,[]) <- partitionEithers args -> case stripTicks e of Var v -> Right (tmArgs, fmap (second (`mkTicks` ticks)) xes,v) - _ -> Left ($(curLoc) ++ "Not in normal form: res not simple var") + t -> Left ($(curLoc) ++ "Not in normal form: res not simple var: " ++ showPpr t) | otherwise -> Left ($(curLoc) ++ "Not in normal form: tyArgs") _ -> Left ($(curLoc) ++ "Not in normal form: no Letrec:\n\n" ++ showPpr expr ++ diff --git a/clash-lib/src/Clash/Normalize/Transformations/ANF.hs b/clash-lib/src/Clash/Normalize/Transformations/ANF.hs index 01a5ba5bcc..f1d115c990 100644 --- a/clash-lib/src/Clash/Normalize/Transformations/ANF.hs +++ b/clash-lib/src/Clash/Normalize/Transformations/ANF.hs @@ -400,10 +400,10 @@ nonRepANF ctx@(TransformContext is0 _) e@(App appConPrim arg) = do untranslatable <- isUntranslatable False arg case (untranslatable,stripTicks arg) of - (True,Letrec binds body) -> + (True,Let binds body) -> -- This is a situation similar to Note [CaseLet deshadow] let (binds1,body1) = deshadowLetExpr is0 binds body - in changed (Letrec binds1 (App appConPrim body1)) + in changed (Let binds1 (App appConPrim body1)) (True,Case {}) -> specialize ctx e (True,Lam {}) -> specialize ctx e (True,TyLam {}) -> specialize ctx e diff --git a/clash-lib/src/Clash/Normalize/Transformations/Case.hs b/clash-lib/src/Clash/Normalize/Transformations/Case.hs index 74961b0494..64feb54045 100644 --- a/clash-lib/src/Clash/Normalize/Transformations/Case.hs +++ b/clash-lib/src/Clash/Normalize/Transformations/Case.hs @@ -59,9 +59,10 @@ import Clash.Core.Pretty (showPpr) import Clash.Core.Subst import Clash.Core.Term ( Alt, Pat(..), PrimInfo(..), Term(..), collectArgs, collectArgsTicks - , collectTicks, mkApps, mkTicks, patIds, stripTicks) + , collectTicks, mkApps, mkTicks, patIds, stripTicks, Bind(..)) import Clash.Core.TyCon (TyConMap) import Clash.Core.Type (LitTy(..), Type(..), TypeView(..), coreView1, tyView) +import Clash.Core.Util (listToLets) import Clash.Core.VarEnv ( InScopeSet, elemVarSet, extendInScopeSet, extendInScopeSetList, mkVarSet , unitVarSet, uniqAway) @@ -195,7 +196,9 @@ caseCon' ctx@(TransformContext is0 _) e@(Case subj ty alts) = do body = substTm "caseCon'" subst altE case Maybe.catMaybes binds2 of [] -> pure body - binds3 -> pure (Letrec binds3 body) + -- Use listToLets to create a series of non-recursive lets instead + -- of a recursive group. We know these binders will not form a group. + binds3 -> pure (listToLets binds3 body) changed altE1 _ -> case alts of -- In Core, default patterns always come first, so we match against @@ -337,16 +340,14 @@ matchLiteralContructor matchLiteralContructor c (IntegerLiteral l) alts = go (reverse alts) where go [(DefaultPat,e)] = changed e - go ((DataPat dc [] xs,e):alts') + go ((DataPat dc [] [x],e):alts') | dcTag dc == 1 , l >= ((-2)^(63::Int)) && l < 2^(63::Int) - = let fvs = Lens.foldMapOf freeLocalIds unitVarSet e - (binds,_) = List.partition ((`elemVarSet` fvs) . fst) - $ List.zipEqual xs [Literal (IntLiteral l)] - e' = case binds of - [] -> e - _ -> Letrec binds e - in changed e' + = let fvs = Lens.foldMapOf freeLocalIds unitVarSet e + bind = NonRec x (Literal (IntLiteral l)) + in if x `elemVarSet` fvs + then changed (Let bind e) + else changed e | dcTag dc == 2 , l >= 2^(63::Int) #if MIN_VERSION_base(4,15,0) @@ -356,12 +357,10 @@ matchLiteralContructor c (IntegerLiteral l) alts = go (reverse alts) #endif ba' = BA.ByteArray ba fvs = Lens.foldMapOf freeLocalIds unitVarSet e - (binds,_) = List.partition ((`elemVarSet` fvs) . fst) - $ List.zipEqual xs [Literal (ByteArrayLiteral ba')] - e' = case binds of - [] -> e - _ -> Letrec binds e - in changed e' + bind = NonRec x (Literal (ByteArrayLiteral ba')) + in if x `elemVarSet` fvs + then changed (Let bind e) + else changed e | dcTag dc == 3 , l < ((-2)^(63::Int)) #if MIN_VERSION_base(4,15,0) @@ -371,12 +370,10 @@ matchLiteralContructor c (IntegerLiteral l) alts = go (reverse alts) #endif ba' = BA.ByteArray ba fvs = Lens.foldMapOf freeLocalIds unitVarSet e - (binds,_) = List.partition ((`elemVarSet` fvs) . fst) - $ List.zipEqual xs [Literal (ByteArrayLiteral ba')] - e' = case binds of - [] -> e - _ -> Letrec binds e - in changed e' + bind = NonRec x (Literal (ByteArrayLiteral ba')) + in if x `elemVarSet` fvs + then changed (Let bind e) + else changed e | otherwise = go alts' go ((LitPat l', e):alts') @@ -389,16 +386,14 @@ matchLiteralContructor c (IntegerLiteral l) alts = go (reverse alts) matchLiteralContructor c (NaturalLiteral l) alts = go (reverse alts) where go [(DefaultPat,e)] = changed e - go ((DataPat dc [] xs,e):alts') + go ((DataPat dc [] [x],e):alts') | dcTag dc == 1 , l >= 0 && l < 2^(64::Int) = let fvs = Lens.foldMapOf freeLocalIds unitVarSet e - (binds,_) = List.partition ((`elemVarSet` fvs) . fst) - $ List.zipEqual xs [Literal (WordLiteral l)] - e' = case binds of - [] -> e - _ -> Letrec binds e - in changed e' + bind = NonRec x (Literal (WordLiteral l)) + in if x `elemVarSet` fvs + then changed (Let bind e) + else changed e | dcTag dc == 2 , l >= 2^(64::Int) #if MIN_VERSION_base(4,15,0) @@ -408,12 +403,10 @@ matchLiteralContructor c (NaturalLiteral l) alts = go (reverse alts) #endif ba' = BA.ByteArray ba fvs = Lens.foldMapOf freeLocalIds unitVarSet e - (binds,_) = List.partition ((`elemVarSet` fvs) . fst) - $ List.zipEqual xs [Literal (ByteArrayLiteral ba')] - e' = case binds of - [] -> e - _ -> Letrec binds e - in changed e' + bind = NonRec x (Literal (ByteArrayLiteral ba')) + in if x `elemVarSet` fvs + then changed (Let bind e) + else changed e | otherwise = go alts' go ((LitPat l', e):alts') @@ -564,7 +557,7 @@ collectEqArgs _ = Nothing -- | Lift the let-bindings out of the subject of a Case-decomposition caseLet :: HasCallStack => NormRewrite -caseLet (TransformContext is0 _) (Case (collectTicks -> (Letrec xes e,ticks)) ty alts) = do +caseLet (TransformContext is0 _) (Case (collectTicks -> (Let xes e,ticks)) ty alts) = do -- Note [CaseLet deshadow] -- Imagine -- @@ -593,7 +586,7 @@ caseLet (TransformContext is0 _) (Case (collectTicks -> (Letrec xes e,ticks)) ty -- It is safe to over-approximate the free variables in `a` by simply taking -- the current InScopeSet. let (xes1,e1) = deshadowLetExpr is0 xes e - changed (Letrec (fmap (second (`mkTicks` ticks)) xes1) + changed (Let (fmap (`mkTicks` ticks) xes1) (Case (mkTicks e1 ticks) ty alts)) caseLet _ e = return e diff --git a/clash-lib/src/Clash/Normalize/Transformations/Cast.hs b/clash-lib/src/Clash/Normalize/Transformations/Cast.hs index 6e48efabc8..222a5a91fe 100644 --- a/clash-lib/src/Clash/Normalize/Transformations/Cast.hs +++ b/clash-lib/src/Clash/Normalize/Transformations/Cast.hs @@ -105,10 +105,10 @@ elimCastCast _ c@(Cast (stripTicks -> Cast e tyA tyB) tyB' tyC) = do elimCastCast _ e = return e {-# SCC elimCastCast #-} --- | Push a cast over a Letrec into it's body +-- | Push a cast over a Let into it's body letCast :: HasCallStack => NormRewrite -letCast _ (Cast (stripTicks -> Letrec binds body) ty1 ty2) = - changed $ Letrec binds (Cast body ty1 ty2) +letCast _ (Cast (stripTicks -> Let binds body) ty1 ty2) = + changed $ Let binds (Cast body ty1 ty2) letCast _ e = return e {-# SCC letCast #-} diff --git a/clash-lib/src/Clash/Normalize/Transformations/EtaExpand.hs b/clash-lib/src/Clash/Normalize/Transformations/EtaExpand.hs index 266a9ce4d4..16c123c200 100644 --- a/clash-lib/src/Clash/Normalize/Transformations/EtaExpand.hs +++ b/clash-lib/src/Clash/Normalize/Transformations/EtaExpand.hs @@ -23,7 +23,7 @@ import qualified Data.Maybe as Maybe import GHC.Stack (HasCallStack) import Clash.Core.HasType -import Clash.Core.Term (CoreContext(..), Term(..), collectArgs, mkLams) +import Clash.Core.Term (Bind(..), CoreContext(..), Term(..), collectArgs, mkLams) import Clash.Core.TermInfo (isFun) import Clash.Core.Type (splitFunTy) import Clash.Core.Util (mkInternalVar) @@ -56,6 +56,13 @@ etaExpandSyn (TransformContext is0 ctx) e@(collectArgs -> (Var f, _)) = do etaExpandSyn _ e = return e {-# SCC etaExpandSyn #-} +stripLambda :: Term -> ([Id], Term) +stripLambda (Lam bndr e) = + let (bndrs, e') = stripLambda e + in (bndr : bndrs, e') + +stripLambda e = ([], e) + -- | Eta-expand top-level lambda's (DON'T use in a traversal!) etaExpansionTL :: HasCallStack => NormRewrite etaExpansionTL (TransformContext is0 ctx) (Lam bndr e) = do @@ -63,21 +70,24 @@ etaExpansionTL (TransformContext is0 ctx) (Lam bndr e) = do e' <- etaExpansionTL ctx' e return $ Lam bndr e' -etaExpansionTL (TransformContext is0 ctx) (Letrec xes e) = do +etaExpansionTL (TransformContext is0 ctx) (Let (NonRec i x) e) = do + let ctx' = TransformContext (extendInScopeSet is0 i) (LetBody [i] : ctx) + e' <- etaExpansionTL ctx' e + case stripLambda e' of + (bs@(_:_),e2) -> do + let e3 = Let (NonRec i x) e2 + changed (mkLams e3 bs) + _ -> return (Let (NonRec i x) e') + +etaExpansionTL (TransformContext is0 ctx) (Let (Rec xes) e) = do let bndrs = map fst xes ctx' = TransformContext (extendInScopeSetList is0 bndrs) (LetBody bndrs : ctx) e' <- etaExpansionTL ctx' e case stripLambda e' of (bs@(_:_),e2) -> do - let e3 = Letrec xes e2 + let e3 = Let (Rec xes) e2 changed (mkLams e3 bs) - _ -> return (Letrec xes e') - where - stripLambda :: Term -> ([Id],Term) - stripLambda (Lam bndr e0) = - let (bndrs,e1) = stripLambda e0 - in (bndr:bndrs,e1) - stripLambda e' = ([],e') + _ -> return (Let (Rec xes) e') etaExpansionTL (TransformContext is0 ctx) e = do diff --git a/clash-lib/src/Clash/Normalize/Transformations/Letrec.hs b/clash-lib/src/Clash/Normalize/Transformations/Letrec.hs index 75ea011c92..1f4902a4e0 100644 --- a/clash-lib/src/Clash/Normalize/Transformations/Letrec.hs +++ b/clash-lib/src/Clash/Normalize/Transformations/Letrec.hs @@ -50,7 +50,7 @@ import Clash.Core.Name (mkUnsafeSystemName, nameOcc) import Clash.Core.Subst import Clash.Core.Term ( LetBinding, Pat(..), PrimInfo(..), Term(..), collectArgs, collectArgsTicks - , collectTicks, isLambdaBodyCtx, isTickCtx, mkApps, mkLams, mkTicks + , collectTicks, isLambdaBodyCtx, isTickCtx, mkApps, mkLams, mkTicks, Bind(..) , partitionTicks, stripTicks) import Clash.Core.TermInfo (isCon, isLet, isLocalVar, isTick) import Clash.Core.TyCon (tyConDataCons) @@ -84,7 +84,7 @@ not the complete names. So we use mkUnsafeSystemName to recreate the same Name. -- | Remove unused let-bindings deadCode :: HasCallStack => NormRewrite -deadCode _ e@(Letrec binds body) = +deadCode _ e@(Let binds body) = case removeUnusedBinders binds body of Just t -> changed t Nothing -> return e @@ -428,7 +428,7 @@ topLet (TransformContext is0 ctx) e then return e else do tcm <- Lens.view tcCache argId <- mkTmBinderFor is0 tcm (mkUnsafeSystemName "result" 0) e - changed (Letrec [(argId, e)] (Var argId)) + changed (Let (NonRec argId e) (Var argId)) topLet (TransformContext is0 ctx) e@(Letrec binds body) | all (\c -> isLambdaBodyCtx c || isTickCtx c) ctx @@ -439,9 +439,15 @@ topLet (TransformContext is0 ctx) e@(Letrec binds body) then return e else do tcm <- Lens.view tcCache - let is2 = extendInScopeSetList is0 (map fst binds) + let is2 = extendInScopeSetList is0 (fmap fst binds) argId <- mkTmBinderFor is2 tcm (mkUnsafeSystemName "result" 0) body - changed (Letrec (binds ++ [(argId,body)]) (Var argId)) + + -- TODO We would like this to be + -- + -- Let binds (Let (NonRec argId body) (Var argId)) + -- + -- but this makes tests/shouldwork/SimIO/Test00.hs fail. + changed (Letrec (binds ++ [(argId, body)]) (Var argId)) topLet _ e = return e {-# SCC topLet #-} diff --git a/clash-lib/src/Clash/Normalize/Transformations/MultiPrim.hs b/clash-lib/src/Clash/Normalize/Transformations/MultiPrim.hs index 9a1320235a..073c2088e1 100644 --- a/clash-lib/src/Clash/Normalize/Transformations/MultiPrim.hs +++ b/clash-lib/src/Clash/Normalize/Transformations/MultiPrim.hs @@ -28,6 +28,7 @@ import Clash.Core.Term import Clash.Core.TermInfo (multiPrimInfo') import Clash.Core.TyCon (TyConMap) import Clash.Core.Type (Type(..), mkPolyFunTy, splitFunForallTy) +import Clash.Core.Util (listToLets) import Clash.Core.Var (mkLocalId) import Clash.Normalize.Types (NormRewrite, primitives) import Clash.Primitives.Types (Primitive(..)) @@ -123,6 +124,6 @@ setupMultiResultPrim' tcm primInfo@PrimInfo{primType} = } letTerm = - Letrec + listToLets ((resId,multiPrimBind):multiPrimSelectBinds) (mkTmApps (mkTyApps (Data tupTc) resTypes) (map Var resIds)) diff --git a/clash-lib/src/Clash/Normalize/Transformations/Specialize.hs b/clash-lib/src/Clash/Normalize/Transformations/Specialize.hs index a696c8669d..5dec3c8302 100644 --- a/clash-lib/src/Clash/Normalize/Transformations/Specialize.hs +++ b/clash-lib/src/Clash/Normalize/Transformations/Specialize.hs @@ -67,7 +67,7 @@ import Clash.Core.Name import Clash.Core.Pretty (showPpr) import Clash.Core.Subst import Clash.Core.Term - ( Term(..), TickInfo, collectArgs, collectArgsTicks, mkApps, mkTmApps, mkTicks, patIds + ( Term(..), TickInfo, collectArgs, collectArgsTicks, mkApps, mkTmApps, mkTicks, patIds, Bind(..) , patVars, mkAbstraction, PrimInfo(..), WorkInfo(..), IsMultiPrim(..), PrimUnfolding(..)) import Clash.Core.TermInfo (isLocalVar, isVar, isPolyFun) import Clash.Core.TyCon (TyConMap, tyConDataCons) @@ -75,6 +75,7 @@ import Clash.Core.Type (LitTy(NumTy), Type(LitTy,VarTy), applyFunTy, splitTyConAppM, normalizeType , mkPolyFunTy, mkTyConApp) import Clash.Core.TysPrim +import Clash.Core.Util (listToLets) import Clash.Core.Var (Var(..), Id, TyVar, mkTyVar) import Clash.Core.VarEnv ( InScopeSet, extendInScopeSet, extendInScopeSetList, lookupVarEnv @@ -204,14 +205,20 @@ appProp ctx@(TransformContext is _) = \case (`mkTicks` ticks) <$> go is0 (substTm "appProp.AppLam" subst e) args [] False -> let is1 = extendInScopeSet is0 v in - Letrec [(v, arg)] <$> go is1 (deShadowTerm is1 e) args ticks + Let (NonRec v arg) <$> go is1 (deShadowTerm is1 e) args ticks - go is0 (Letrec vs e) args@(_:_) ticks = do + go is0 (Let (NonRec i x) e) args@(_:_) ticks = do + setChanged + let is1 = extendInScopeSet is0 i + -- XXX: binding should already be deshadowed w.r.t. 'is0' + Let (NonRec i x) <$> go is1 e args ticks + + go is0 (Let (Rec vs) e) args@(_:_) ticks = do setChanged let vbs = map fst vs is1 = extendInScopeSetList is0 vbs -- XXX: 'vs' should already be deshadowed w.r.t. 'is0' - Letrec vs <$> go is1 e args ticks + Let (Rec vs) <$> go is1 e args ticks go is0 (TyLam tv e) (Right t:args) ticks = do setChanged @@ -230,7 +237,10 @@ appProp ctx@(TransformContext is _) = \case let vbs = map fst vs is1 = extendInScopeSetList is0 vbs alts1 = map (deShadowAlt is1) alts - Letrec vs . (`mkTicks` ticks) . Case scrut ty1 <$> mapM (goAlt is1 args1) alts1 + -- TODO I should have a mkNonRecLets :: [LetBinding] -> Term -> Term + -- function which makes a chain of non-recursive let expressions without + -- needing to first take the SCCs of all the binders. + listToLets vs . (`mkTicks` ticks) . Case scrut ty1 <$> mapM (goAlt is1 args1) alts1 go is0 (Tick sp e) args ticks = do setChanged @@ -288,7 +298,7 @@ constantSpec ctx@(TransformContext is0 tfCtx) e@(App e1 e2) (App e1 (csrNewTerm specInfo)) if Monoid.getAny isSpec - then changed (Letrec newBindings body) + then changed (listToLets newBindings body) else return e else -- e2 has no constant parts diff --git a/clash-lib/src/Clash/Rewrite/Util.hs b/clash-lib/src/Clash/Rewrite/Util.hs index 4ef116454d..5e578e95f5 100644 --- a/clash-lib/src/Clash/Rewrite/Util.hs +++ b/clash-lib/src/Clash/Rewrite/Util.hs @@ -75,9 +75,9 @@ import Clash.Core.Type (Type (..), normalizeType) import Clash.Core.Var (Id, IdScope (..), TyVar, Var (..), mkGlobalId, mkLocalId, mkTyVar) import Clash.Core.VarEnv - (InScopeSet, extendInScopeSetList, mkInScopeSet, + (InScopeSet, extendInScopeSet, extendInScopeSetList, mkInScopeSet, uniqAway, uniqAway', mapVarEnv, eltsVarEnv, unitVarSet, emptyVarEnv, - mkVarEnv, eltsVarSet, elemVarEnv, lookupVarEnv, extendVarEnv) + mkVarEnv, eltsVarSet, elemVarEnv, lookupVarEnv, extendVarEnv, elemVarSet) import Clash.Debug import Clash.Driver.Types (TransformationInfo(..), DebugOpts(..), BindingMap, Binding(..), IsPrim(..), @@ -122,8 +122,8 @@ findAccidentialShadows = Case t _ as -> concatMap (findInPat . fst) as ++ concatMap findAccidentialShadows (t : map snd as) - Letrec bs t -> - findDups (map fst bs) ++ findAccidentialShadows t + Let NonRec{} t -> findAccidentialShadows t + Let (Rec bs) t -> findDups (map fst bs) ++ findAccidentialShadows t where findInPat :: Pat -> [[Id]] @@ -346,7 +346,17 @@ inlineBinders :: (Term -> LetBinding -> RewriteMonad extra Bool) -- ^ Property test -> Rewrite extra -inlineBinders condition (TransformContext inScope0 _) expr@(Letrec xes res) = do +inlineBinders condition (TransformContext inScope0 _) expr@(Let (NonRec i x) res) = do + inline <- condition expr (i, x) + + if inline && elemFreeVars i res then + let inScope1 = extendInScopeSet inScope0 i + subst = extendIdSubst (mkSubst inScope1) i x + in changed (substTm "inlineBinders" subst res) + else + return expr + +inlineBinders condition (TransformContext inScope0 _) expr@(Let (Rec xes) res) = do (toInline,toKeep) <- partitionM (condition expr) xes case toInline of [] -> return expr @@ -760,7 +770,7 @@ bindPureHeap tcm heap rw ctx0@(TransformContext is0 hist) e = do -- ‡ https://www.microsoft.com/en-us/research/wp-content/uploads/2016/07/supercomp-by-eval.pdf bs <- Lens.use bindings inlineBinders (inlineTest bs) ctx0 (Letrec bndrs e1) >>= \case - e2@(Letrec bnders1 e3) -> + e2@(Let bnders1 e3) -> pure (fromMaybe e2 (removeUnusedBinders bnders1 e3)) e2 -> pure e2 @@ -784,10 +794,14 @@ bindPureHeap tcm heap rw ctx0@(TransformContext is0 hist) e = do -- | Remove unused binders in given let-binding. Returns /Nothing/ if no unused -- binders were found. removeUnusedBinders - :: [LetBinding] + :: Bind Term -> Term -> Maybe Term -removeUnusedBinders binds body = +removeUnusedBinders (NonRec i _) body = + let bodyFVs = Lens.foldMapOf freeLocalIds unitVarSet body + in if i `elemVarSet` bodyFVs then Nothing else Just body + +removeUnusedBinders (Rec binds) body = case eltsVarEnv used of [] -> Just body qqL | not (List.equalLength qqL binds) diff --git a/clash-lib/src/Clash/Rewrite/WorkFree.hs b/clash-lib/src/Clash/Rewrite/WorkFree.hs index 9bc04aef7d..576fdd6782 100644 --- a/clash-lib/src/Clash/Rewrite/WorkFree.hs +++ b/clash-lib/src/Clash/Rewrite/WorkFree.hs @@ -111,7 +111,8 @@ isWorkFree cache bndrs = go True Lam _ e -> andM [go False e, allM goArg args] TyLam _ e -> andM [go False e, allM goArg args] - Letrec bs e -> andM [go False e, allM (go False . snd) bs, allM goArg args] + Let (NonRec _ x) e -> andM [go False e, go False x, allM goArg args] + Let (Rec bs) e -> andM [go False e, allM (go False . snd) bs, allM goArg args] Case s _ [(_, a)] -> andM [go False s, go False a, allM goArg args] Case e _ _ -> andM [go False e, allM goArg args] Cast e _ _ -> andM [go False e, allM goArg args]