Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect more complex invariant loop parameters. #1111

Merged
merged 1 commit into from Sep 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

### Added

* Obscure loop optimisation (#1110).

### Removed

### Changed
Expand Down
38 changes: 36 additions & 2 deletions src/Futhark/Analysis/SymbolTable.hs
Expand Up @@ -26,6 +26,7 @@ module Futhark.Analysis.SymbolTable
lookupSubExp,
lookupAliases,
lookupLoopVar,
lookupLoopParam,
available,
consume,
index,
Expand All @@ -40,6 +41,7 @@ module Futhark.Analysis.SymbolTable
insertFParams,
insertLParam,
insertLoopVar,
insertLoopMerge,

-- * Misc
hideCertified,
Expand Down Expand Up @@ -154,7 +156,10 @@ data LetBoundEntry lore = LetBoundEntry

data FParamEntry lore = FParamEntry
{ fparamDec :: FParamInfo lore,
fparamAliases :: Names
fparamAliases :: Names,
-- | If a loop parameter, the initial value and the eventual
-- result. The result need not be in scope in the symbol table.
fparamMerge :: Maybe (SubExp, SubExp)
}

data LParamEntry lore = LParamEntry
Expand Down Expand Up @@ -235,6 +240,11 @@ lookupLoopVar name vtable = do
LoopVar e <- entryType <$> M.lookup name (bindings vtable)
return $ loopVarBound e

lookupLoopParam :: VName -> SymbolTable lore -> Maybe (SubExp, SubExp)
lookupLoopParam name vtable = do
FParam e <- entryType <$> M.lookup name (bindings vtable)
fparamMerge e

-- | In symbol table and not consumed.
available :: VName -> SymbolTable lore -> Bool
available name = maybe False (not . entryConsumed) . M.lookup name . bindings
Expand Down Expand Up @@ -443,7 +453,8 @@ insertFParam fparam = insertEntry name entry
FParam
FParamEntry
{ fparamDec = AST.paramDec fparam,
fparamAliases = mempty
fparamAliases = mempty,
fparamMerge = Nothing
}

insertFParams ::
Expand All @@ -464,6 +475,29 @@ insertLParam param = insertEntry name bind
}
name = AST.paramName param

-- | Insert entries corresponding to the parameters of a loop (not
-- distinguishing contect and value part). Apart from the parameter
-- itself, we also insert the initial value and the subexpression
-- providing the final value. Note that the latter is likely not in
-- scope in the symbol at this point. This is OK, and can still be
-- used to help some loop optimisations detect invariant loop
-- parameters.
insertLoopMerge ::
ASTLore lore =>
[(AST.FParam lore, SubExp, SubExp)] ->
SymbolTable lore ->
SymbolTable lore
insertLoopMerge = flip $ foldl' $ flip bind
where
bind (p, initial, res) =
insertEntry (paramName p) $
FParam
FParamEntry
{ fparamDec = AST.paramDec p,
fparamAliases = mempty,
fparamMerge = Just (initial, res)
}

insertLoopVar :: ASTLore lore => VName -> IntType -> SubExp -> SymbolTable lore -> SymbolTable lore
insertLoopVar name it bound = insertEntry name bind
where
Expand Down
10 changes: 9 additions & 1 deletion src/Futhark/Optimise/Simplify/Engine.hs
Expand Up @@ -239,6 +239,13 @@ bindArrayLParams ::
bindArrayLParams params =
localVtable $ \vtable -> foldl' (flip ST.insertLParam) vtable params

bindMerge ::
SimplifiableLore lore =>
[(FParam (Wise lore), SubExp, SubExp)] ->
SimpleM lore a ->
SimpleM lore a
bindMerge = localVtable . ST.insertLoopMerge

bindLoopVar :: SimplifiableLore lore => VName -> IntType -> SubExp -> SimpleM lore a -> SimpleM lore a
bindLoopVar var it bound =
localVtable $ ST.insertLoopVar var it bound
Expand Down Expand Up @@ -800,7 +807,7 @@ simplifyExp (DoLoop ctx val form loopbody) = do
((loopstms, loopres), hoisted) <-
enterLoop $
consumeMerge $
bindFParams (ctxparams' ++ valparams') $
bindMerge (zipWith withRes (ctx' ++ val') (bodyResult loopbody)) $
wrapbody $
blockIf
( hasFree boundnames `orIf` isConsumed
Expand All @@ -819,6 +826,7 @@ simplifyExp (DoLoop ctx val form loopbody) = do
localVtable $ flip (foldl' (flip ST.consume)) $ namesToList consumed_by_merge
consumed_by_merge =
freeIn $ map snd $ filter (unique . paramDeclType . fst) val
withRes (p, x) y = (p, x, y)
simplifyExp (Op op) = do
(op', stms) <- simplifyOp op
return (Op op', stms)
Expand Down
45 changes: 30 additions & 15 deletions src/Futhark/Optimise/Simplify/Rules.hs
Expand Up @@ -153,11 +153,11 @@ removeRedundantMergeVariables _ _ _ _ =
-- We may change the type of the loop if we hoist out a shape
-- annotation, in which case we also need to tweak the bound pattern.
hoistLoopInvariantMergeVariables :: BinderOps lore => TopDownRuleDoLoop lore
hoistLoopInvariantMergeVariables _ pat aux (ctx, val, form, loopbody) =
hoistLoopInvariantMergeVariables vtable pat aux (ctx, val, form, loopbody) =
-- Figure out which of the elements of loopresult are
-- loop-invariant, and hoist them out.
case foldr checkInvariance ([], explpat, [], []) $
zip merge res of
zip3 (patternNames pat) merge res of
([], _, _, _) ->
-- Nothing is invariant.
Skip
Expand Down Expand Up @@ -199,11 +199,11 @@ hoistLoopInvariantMergeVariables _ pat aux (ctx, val, form, loopbody) =
(Nothing, explpat')

checkInvariance
((mergeParam, mergeInit), resExp)
(pat_name, (mergeParam, mergeInit), resExp)
(invariant, explpat', merge', resExps)
| not (unique (paramDeclType mergeParam))
|| arrayRank (paramDeclType mergeParam) == 1,
isInvariant resExp,
isInvariant,
-- Also do not remove the condition in a while-loop.
not $ paramName mergeParam `nameIn` freeIn form =
let (bnd, explpat'') =
Expand All @@ -214,21 +214,36 @@ hoistLoopInvariantMergeVariables _ pat aux (ctx, val, form, loopbody) =
resExps
)
where
-- A non-unique merge variable is invariant if the corresponding
-- subexp in the result is EITHER:
-- A non-unique merge variable is invariant if one of the
-- following is true:
--
-- (0) a variable of the same name as the parameter, where
-- all existential parameters are already known to be
-- invariant
isInvariant (Var v2)
| paramName mergeParam == v2 =
-- (0) The result is a variable of the same name as the
-- parameter, where all existential parameters are already
-- known to be invariant
isInvariant
| Var v2 <- resExp,
paramName mergeParam == v2 =
allExistentialInvariant
(namesFromList $ map (identName . fst) invariant)
mergeParam
-- (1) or identical to the initial value of the parameter.
isInvariant _ = mergeInit == resExp
checkInvariance ((mergeParam, mergeInit), resExp) (invariant, explpat', merge', resExps) =
(invariant, explpat', (mergeParam, mergeInit) : merge', resExp : resExps)
-- (1) The result is identical to the initial parameter value.
| mergeInit == resExp = True
-- (2) The initial parameter value is equal to an outer
-- loop parameter 'P', where the initial value of 'P' is
-- equal to 'resExp', AND 'resExp' ultimately becomes the
-- new value of 'P'. XXX: it's a bit clumsy that this
-- only works for one level of nesting, and I think it
-- would not be too hard to generalise.
| Var init_v <- mergeInit,
Just (p_init, p_res) <- ST.lookupLoopParam init_v vtable,
p_init == resExp,
p_res == Var pat_name =
True
| otherwise = False
checkInvariance
(_pat_name, (mergeParam, mergeInit), resExp)
(invariant, explpat', merge', resExps) =
(invariant, explpat', (mergeParam, mergeInit) : merge', resExp : resExps)

allExistentialInvariant namesOfInvariant mergeParam =
all (invariantOrNotMergeParam namesOfInvariant) $
Expand Down
9 changes: 9 additions & 0 deletions tests/loops/loop15.fut
@@ -0,0 +1,9 @@
-- Simple case; simplify away the loops.
-- ==
-- input { 10 2 } output { 2 }
-- structure { DoLoop 0 }

let main (n: i32) (a: i32) =
loop x = a for _i < n do
loop _y = x for _j < n do
a
11 changes: 11 additions & 0 deletions tests/loops/loop16.fut
@@ -0,0 +1,11 @@
-- Complex case; simplify away the loops.
-- ==
-- input { 10 2 [1,2,3] }
-- output { [1,2] }
-- structure { DoLoop 0 }

let main (n: i32) (a: i32) (arr: []i32) =
#[unsafe] -- Just to make the IR cleaner.
loop x = take a arr for _i < n do
loop _y = take (length x) arr for _j < n do
take a arr