Skip to content

Commit

Permalink
Finish inlineSat.
Browse files Browse the repository at this point in the history
  • Loading branch information
thealmarty committed Mar 16, 2023
1 parent daaf32b commit f63badb
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 121 deletions.
157 changes: 84 additions & 73 deletions plutus-core/plutus-ir/src/PlutusIR/Transform/Inline/CallSiteInline.hs
Expand Up @@ -12,11 +12,10 @@ See note [Inlining of fully applied functions].

module PlutusIR.Transform.Inline.CallSiteInline where

import Control.Lens (forMOf)
import Control.Monad.State
import Data.Map.Strict qualified as Map
import PlutusIR.Core
import PlutusIR.Transform.Inline.Utils
import Prettyprinter

{- Note [Inlining of fully applied functions]
Expand Down Expand Up @@ -148,16 +147,16 @@ We may want to check their sizes instead of just rejecting them.
-- | Computes the 'Arity' of a term.
computeArity ::
Term tyname name uni fun ann
-> Arity
-> (Arity, Term tyname name uni fun ann)
computeArity = \case
LamAbs _ _ _ body -> MkTerm : computeArity body
TyAbs _ _ _ body -> MkType : computeArity body
LamAbs _ _ _ body -> (MkTerm : fst (computeArity body), body)
TyAbs _ _ _ body -> (MkType : fst (computeArity body), body)
-- Whenever we encounter a body that is not a lambda or type abstraction, we are done counting
_ -> []
tm -> ([],tm)

-- | Inline fully applied functions iff the body of the function is `acceptable`.
considerInline ::
Term tyname name uni fun ann -- the variable that is a function
considerInline :: forall tyname name uni fun ann. InliningConstraints tyname name uni fun
=> Term tyname name uni fun ann -- the variable that is a function
-> InlineM tyname name uni fun ann (Term tyname name uni fun ann)
considerInline v@(Var ann n) = do
-- look up the variable in the `CalledVar` map
Expand All @@ -166,14 +165,16 @@ considerInline v@(Var ann n) = do
-- if it's not in the map, it's not a function, don't inline.
Nothing -> pure v
Just info -> do
let subst = calledVarDef info -- what we substitute in is its definition
isAcceptable <- acceptable subst
let
-- subst = calledVarDef info -- what we substitute in is its definition
bodyToCheckAcceptable = calledVarBody info
isAcceptable <- acceptable bodyToCheckAcceptable
-- if the size and cost are not acceptable, don't inline
if not isAcceptable then pure v
-- if the size and cost are acceptable, then check if it's fully applied
-- See note [Identifying fully applied call sites].
else do
pure v
else
inlineSat v
considerInline _notVar = -- this should not happen
Prelude.error "considerInline: should be a variable."

Expand All @@ -189,32 +190,41 @@ considerInline _notVar = -- this should not happen
-- sites.
-- type ApplicationMap ann name = Map.Map name [ApplicationOrder ann]

-- | A term or type argument.
data Args tyname name uni fun ann =
MkTermArg (Term tyname name uni fun ann)
| MkTypeArg (Type tyname uni ann)

-- | A list of type or term argument(s) being applied.
type ArgOrder tyname name uni fun ann = [Args tyname name uni fun ann]

-- | A pair of argument and the annotation of the term being applied to,
-- so the term can be built back in `mkApps`.
type ArgOrderWithAnn tyname name uni fun ann =
[(Args tyname name uni fun ann, ann)]

-- | Takes a term or type application expression and returns the function
-- being applied and the arguments to which it is applied
collectArgs :: Term tyname name uni fun ann
-> (Term tyname name uni fun ann, ArgOrder tyname name uni fun ann)
-> (Term tyname name uni fun ann, ArgOrderWithAnn tyname name uni fun ann)
collectArgs expr
= go expr []
where
go (Apply _ f a) as = go f (MkTermArg a:as)
go (TyInst _ f tyArg) as = go f (MkTypeArg tyArg:as)
go e as = (e, as)
go (Apply ann f a) as = go f ((MkTermArg a, ann):as)
go (TyInst ann f tyArg) as = go f ((MkTypeArg tyArg, ann):as)
go e as = (e, as)

-- | Apply a list of term and type arguments to a function in potentially a nested fashion.
mkApps :: Term tyname name uni fun ann
-> ArgOrder tyname name uni fun ann
-> ArgOrderWithAnn tyname name uni fun ann
-> Term tyname name uni fun ann
mkApps f (MkTermArg tmArg : args) = Apply f args
mkApps f ((MkTermArg tmArg,ann) : args) = mkApps (Apply ann f tmArg) args
mkApps f ((MkTypeArg tyArg,ann) : args) = mkApps (TyInst ann f tyArg) args
mkApps f [] = f

enoughArgs :: Arity -> ArgOrder tyname name uni fun ann -> Bool
enoughArgs [] argsOrder = True
enoughArgs arity [] = False
enoughArgs [] _argsOrder = True
enoughArgs _arity [] = False
enoughArgs lamOrder argsOrder =
-- start comparing from the end because there may be over-application
case (last lamOrder, last argsOrder) of
Expand All @@ -223,67 +233,68 @@ enoughArgs lamOrder argsOrder =
_ -> False

-- | Inline fully applied functions. See note [Identifying fully applied call sites].
inlineFullyApplied :: forall tyname name uni fun ann. InliningConstraints tyname name uni fun
inlineSat :: forall tyname name uni fun ann. InliningConstraints tyname name uni fun
=> Term tyname name uni fun ann -- ^ The `body` of the `Let` term.
-> InlineM tyname name uni fun ann (Term tyname name uni fun ann)
-- If the term is a term application, get the `AppOrder` of the
inlineFullyApplied appTerm@(Apply _ fun arg) = do
-- If the term is a term application, see if we it's applying to something that we may inline
inlineSat appTerm@(Apply _varAnn _fun _arg) = do
-- collect all the arguments of the term being applied to
let argsAppliedTo = fst $ collectArgs appTerm
args = snd $ collectArgs appTerm
case argsAppliedTo of
-- if it is a `Var` that is being applied to, check to see if it's fully applied
Var _ name -> do
Var _ann name -> do
maybeVarInfo <- gets (lookupCalled name)
case maybeVarInfo of
Nothing -> pure $ Apply _ (go fun) (go arg)
-- the variable is not in the map that contains all the in-scope functions, this shouldn't
-- happen? TODO maybe error out instead?
Nothing -> forMOf termSubterms appTerm inlineSat
Just varInfo -> do
if enoughArgs (arity varInfo) args then
if enoughArgs (arity varInfo) (map fst args) then
-- if the `Var` is fully applied (over-application is allowed) then inline it
mkApps (calledVarDef varInfo) (go <$> args)
else pure $ Apply _ (go fun) (go arg) -- otherwise just keep going
-- if the term being applied is not a `Var`, just keep going
_ -> pure $ Apply _ (go fun) (go arg)
inlineFullyApplied (TyInst _ fnBody _) =
-- If the term is a type application, add it to the application stack, and
-- keep on examining the body.
countLocal (appStack <> [MkType]) calledStack fnBody
inlineFullyApplied tm = pure tm

-- (Var ann name) =
-- -- When we encounter a body that is a variable, we have found a call site of it.
-- -- Using `insertWith` ensures that if a variable is called more than once, the new
-- -- `ApplicationOrder` map will be appended to the existing one.
-- Map.insertWith (<>) name [MkApplicationOrder ann appStack] calledStack
-- go (Let _ _ bds letBody) =
-- -- recursive or not, the bindings of this let term *may* contain the variable in
-- -- question, so we need to check all the bindings and also the body
-- let
-- -- get the list of rhs's of the term bindings
-- getRHS :: Binding tyname name uni fun ann -> Maybe (Term tyname name uni fun ann)
-- getRHS (TermBind _ _ _ rhs) = Just rhs
-- getRHS _ =
-- -- no need to keep track of the type bindings. Even though this type variable
-- -- called in the body, it does not affect the resulting `ApplicationMap`
-- Nothing
-- listOfRHSOfBindings = mapMaybe getRHS (toList bds)
-- in
-- foldr (flip $ countLocal []) (countLocal [] calledStack letBody) listOfRHSOfBindings
-- go (TyAbs _ _ _ tyAbsBody) =
-- -- start count in the body of the type lambda abstraction
-- countLocal [] calledStack tyAbsBody
-- go (LamAbs _ _ _ fnBody) =
-- -- start the count in the body of the term lambda abstraction
-- countLocal [] calledStack fnBody
-- go (Constant _ _) =
-- calledStack -- constants cannot call the variable
-- go (Builtin _ _) =
-- -- default builtin functions in `PlutusCore/Default/Builtins.hs`
-- -- cannot call the variable
-- calledStack
-- go (Unwrap _ tm) =
-- countLocal [] calledStack tm
-- go (IWrap _ _ _ tm) =
-- countLocal [] calledStack tm
-- go (Error _ _) = calledStack

pure $ mkApps (calledVarDef varInfo) args
-- otherwise just keep going
else forMOf termSubterms appTerm inlineSat
-- if the term being applied is not a `Var`, don't inline, but keep checking
v -> forMOf termSubterms v inlineSat -- keep checking all subterms
inlineSat tyInstTerm@(TyInst varAnn fun arg) = do
-- collect all the arguments of the term being applied to
let argsAppliedTo = fst $ collectArgs tyInstTerm
args = snd $ collectArgs tyInstTerm
case argsAppliedTo of
-- if it is a `Var` that is being applied to, check to see if it's fully applied
Var _ann name -> do
maybeVarInfo <- gets (lookupCalled name)
case maybeVarInfo of
Nothing -> forMOf termSubterms tyInstTerm inlineSat
Just varInfo -> do
if enoughArgs (arity varInfo) (map fst args) then
-- if the `Var` is fully applied (over-application is allowed) then inline it
pure $ mkApps (calledVarDef varInfo) args
-- otherwise just keep going
else forMOf termSubterms tyInstTerm inlineSat
-- if the term being applied is not a `Var`, don't inline but keep checking the subterms
v -> forMOf termSubterms v inlineSat
inlineSat letTm@(Let _ _ bds _letBody) = do
-- recursive or not, the bindings of this let term *may* contain a saturated function,
-- so we need to check all the bindings and also the body
-- `PlutusIR.Core.Plated.termSubterms` gives all that
forMOf termSubterms letTm inlineSat
inlineSat (TyAbs _ _ _ tyAbsBody) =
-- start count in the body of the type lambda abstraction
inlineSat tyAbsBody
inlineSat (LamAbs _ _ _ fnBody) =
-- start the count in the body of the term lambda abstraction
inlineSat fnBody
inlineSat con@(Constant _ _) =
-- constants cannot call the variable
pure con
inlineSat bi@(Builtin _ _) =
-- default builtin functions in `PlutusCore/Default/Builtins.hs`
-- cannot call the variable
pure bi
inlineSat v@(Var _ _) =
-- variables being applied should have been checked already, these ones aren't fully applied.
-- We don't inline them.
pure v
inlineSat others = pure others
Expand Up @@ -34,6 +34,7 @@ import Control.Monad.State

import Algebra.Graph qualified as G
import Data.Map qualified as Map
import PlutusIR.Transform.Inline.CallSiteInline
import Witherable (Witherable (wither))

{- Note [Inlining approach and 'Secrets of the GHC Inliner']
Expand Down Expand Up @@ -168,7 +169,7 @@ processTerm = handleTerm <=< traverseOf termSubtypes applyTypeSubstitution where
Nothing -> do
considerInline v
-- If it's in the substitution map, do the substitution
Just v -> pure v
Just var -> pure var
Let ann NonRec bs t -> do
-- Process bindings, eliminating those which will be inlined unconditionally,
-- and accumulating the new substitutions
Expand All @@ -184,7 +185,8 @@ processTerm = handleTerm <=< traverseOf termSubtypes applyTypeSubstitution where
-- This includes recursive let terms, we don't even consider inlining them at the moment
t -> forMOf termSubterms t processTerm

applyTypeSubstitution :: Type tyname uni ann
applyTypeSubstitution :: forall tyname name uni fun ann. InliningConstraints tyname name uni fun
=> Type tyname uni ann
-> InlineM tyname name uni fun ann (Type tyname uni ann)
applyTypeSubstitution t = gets isTypeSubstEmpty >>= \case
-- The type substitution is very often empty, and there are lots of types in the program,
Expand All @@ -193,17 +195,21 @@ applyTypeSubstitution t = gets isTypeSubstEmpty >>= \case
_ -> typeSubstTyNamesM substTyName t

-- See Note [Renaming strategy]
substTyName :: tyname -> InlineM tyname name uni fun ann (Maybe (Type tyname uni ann))
substTyName :: forall tyname name uni fun ann. InliningConstraints tyname name uni fun
=> tyname
-> InlineM tyname name uni fun ann (Maybe (Type tyname uni ann))
substTyName tyname = gets (lookupType tyname) >>= traverse liftDupable

-- See Note [Renaming strategy]
substName :: name -> InlineM tyname name uni fun ann (Maybe (Term tyname name uni fun ann))
substName :: forall tyname name uni fun ann. InliningConstraints tyname name uni fun
=> name
-> InlineM tyname name uni fun ann (Maybe (Term tyname name uni fun ann))
substName name = gets (lookupTerm name) >>= traverse renameTerm

-- See Note [Inlining approach and 'Secrets of the GHC Inliner']
-- Already processed term, just rename and put it in, don't do any further optimization here.
renameTerm ::
InlineTerm tyname name uni fun ann
renameTerm :: forall tyname name uni fun ann. InliningConstraints tyname name uni fun
=> InlineTerm tyname name uni fun ann
-> InlineM tyname name uni fun ann (Term tyname name uni fun ann)
renameTerm (Done t) = liftDupable t

Expand All @@ -228,50 +234,25 @@ processSingleBinding body = \case
TermBind ann s v@(VarDecl _ n (TyFun _ _tyArg _tyBody)) rhs -> do
let
-- track the term and type lambda abstraction order of the function
varLamOrder = computeArity rhs
-- examine the `body` of the `Let` term and track all term/type applications.
appSites = countApp body
-- list of all call sites of this variable
listOfCallSites = Map.lookup n appSites
case listOfCallSites of
Nothing ->
-- we don't remove the binding because we decide *at the call site* whether we want to
-- inline, and it may be called more than once
pure $ TermBind ann s v rhs
Just list -> do
let
isEqAppOrder :: ApplicationOrder ann -> Bool
isEqAppOrder appOrder = applicationOrder appOrder == varLamOrder
-- filter the list to only call locations that are fully applied
filteredFullyApplied = filter isEqAppOrder list
fullyAppliedAnns = fmap annotation filteredFullyApplied
-- add the function to `CalledVarEnv`
void $ modify' $ extendCalled n (MkCalledVarInfo rhs varLamOrder fullyAppliedAnns)
pure $ TermBind ann s v rhs
varLamOrder = fst $ computeArity rhs
bodyToCheck = snd $ computeArity rhs
-- add the function to `CalledVarEnv`
void $ modify' $ extendCalled n (MkCalledVarInfo rhs varLamOrder bodyToCheck)
-- we still want to do unconditional inline
maybeRhs' <- maybeAddSubst body ann s n rhs
pure $ TermBind ann s v <$> maybeRhs'
-- when the let binding is a type lambda abstraction, we add it to the `CalledVarEnv` and
-- consider whether we want to inline at the call site.
TermBind ann s v@(VarDecl _ n (TyLam _ann _tyname _tyArg _tyBody)) rhs -> do
let varLamOrder = countLam rhs
appSites = countApp body
listOfCallSites = Map.lookup n appSites
case listOfCallSites of
Nothing ->
-- we don't remove the binding because we decide *at the call site* whether we want to
-- inline, and it may be called more than once
pure $ TermBind ann s v rhs
Just list -> do
let
isEqAppOrder :: ApplicationOrder ann -> Bool
isEqAppOrder appOrder = applicationOrder appOrder == varLamOrder
-- filter the list to only call locations that are fully applied
filteredFullyApplied = filter isEqAppOrder list
fullyAppliedAnns = fmap annotation filteredFullyApplied
-- add the function to `CalledVarEnv`
-- add the type abstraction to `CalledVarEnv`
void $ modify' $ extendCalled n (MkCalledVarInfo rhs varLamOrder fullyAppliedAnns)
-- we don't remove the binding because we decide *at the call site* whether we want to
-- inline, and it may be called more than once
pure $ TermBind ann s v rhs
let varLamOrder = fst $ computeArity rhs
bodyToCheck = snd $ computeArity rhs
-- add the function to `CalledVarEnv`
-- add the type abstraction to `CalledVarEnv`
void $ modify' $ extendCalled n (MkCalledVarInfo rhs varLamOrder bodyToCheck)
-- we don't remove the binding because we decide *at the call site* whether we want to
-- inline, and it may be called more than once
maybeRhs' <- maybeAddSubst body ann s n rhs
pure $ TermBind ann s v <$> maybeRhs'
-- for binding that aren't functions, maybe do unconditional inline
TermBind ann s v@(VarDecl _ n _) rhs -> do
maybeRhs' <- maybeAddSubst body ann s n rhs
Expand Down
2 changes: 1 addition & 1 deletion plutus-core/plutus-ir/test/TransformSpec.hs
Expand Up @@ -214,7 +214,7 @@ inline =
computeArityTest :: TestNested
computeArityTest = testNested "computeArityTest" $
map
(goldenPir (computeArity . runQuote . PLC.rename) pTerm)
(goldenPir (fst . computeArity . runQuote . PLC.rename) pTerm)
[ "var" -- from inline tests, testing let terms
, "tyvar"
, "single"
Expand Down

0 comments on commit f63badb

Please sign in to comment.