diff --git a/CHANGELOG.md b/CHANGELOG.md index e75a6bccde..9500e8b768 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ### Added + * Obscure loop optimisation (#1110). + ### Removed ### Changed diff --git a/src/Futhark/Analysis/SymbolTable.hs b/src/Futhark/Analysis/SymbolTable.hs index edc200fc10..39a7a1b407 100644 --- a/src/Futhark/Analysis/SymbolTable.hs +++ b/src/Futhark/Analysis/SymbolTable.hs @@ -26,6 +26,7 @@ module Futhark.Analysis.SymbolTable lookupSubExp, lookupAliases, lookupLoopVar, + lookupLoopParam, available, consume, index, @@ -40,6 +41,7 @@ module Futhark.Analysis.SymbolTable insertFParams, insertLParam, insertLoopVar, + insertLoopMerge, -- * Misc hideCertified, @@ -154,7 +156,10 @@ data LetBoundEntry lore = LetBoundEntry data FParamEntry lore = FParamEntry { fparamDec :: FParamInfo lore, - fparamAliases :: Names + fparamAliases :: Names, + -- | If a loop parameter, the initial value and the eventual + -- result. The result need not be in scope in the symbol table. + fparamMerge :: Maybe (SubExp, SubExp) } data LParamEntry lore = LParamEntry @@ -235,6 +240,11 @@ lookupLoopVar name vtable = do LoopVar e <- entryType <$> M.lookup name (bindings vtable) return $ loopVarBound e +lookupLoopParam :: VName -> SymbolTable lore -> Maybe (SubExp, SubExp) +lookupLoopParam name vtable = do + FParam e <- entryType <$> M.lookup name (bindings vtable) + fparamMerge e + -- | In symbol table and not consumed. available :: VName -> SymbolTable lore -> Bool available name = maybe False (not . entryConsumed) . M.lookup name . 
bindings @@ -443,7 +453,8 @@ insertFParam fparam = insertEntry name entry FParam FParamEntry { fparamDec = AST.paramDec fparam, - fparamAliases = mempty + fparamAliases = mempty, + fparamMerge = Nothing } insertFParams :: @@ -464,6 +475,29 @@ insertLParam param = insertEntry name bind } name = AST.paramName param +-- | Insert entries corresponding to the parameters of a loop (not +-- distinguishing context and value part). Apart from the parameter +-- itself, we also insert the initial value and the subexpression +-- providing the final value. Note that the latter is likely not in +-- scope in the symbol table at this point. This is OK, and can still be +-- used to help some loop optimisations detect invariant loop +-- parameters. +insertLoopMerge :: + ASTLore lore => + [(AST.FParam lore, SubExp, SubExp)] -> + SymbolTable lore -> + SymbolTable lore +insertLoopMerge = flip $ foldl' $ flip bind + where + bind (p, initial, res) = + insertEntry (paramName p) $ + FParam + FParamEntry + { fparamDec = AST.paramDec p, + fparamAliases = mempty, + fparamMerge = Just (initial, res) + } + insertLoopVar :: ASTLore lore => VName -> IntType -> SubExp -> SymbolTable lore -> SymbolTable lore insertLoopVar name it bound = insertEntry name bind where diff --git a/src/Futhark/Optimise/Simplify/Engine.hs b/src/Futhark/Optimise/Simplify/Engine.hs index cb042a9d33..540464f7f3 100644 --- a/src/Futhark/Optimise/Simplify/Engine.hs +++ b/src/Futhark/Optimise/Simplify/Engine.hs @@ -239,6 +239,13 @@ bindArrayLParams :: bindArrayLParams params = localVtable $ \vtable -> foldl' (flip ST.insertLParam) vtable params +bindMerge :: + SimplifiableLore lore => + [(FParam (Wise lore), SubExp, SubExp)] -> + SimpleM lore a -> + SimpleM lore a +bindMerge = localVtable . 
ST.insertLoopMerge + bindLoopVar :: SimplifiableLore lore => VName -> IntType -> SubExp -> SimpleM lore a -> SimpleM lore a bindLoopVar var it bound = localVtable $ ST.insertLoopVar var it bound @@ -800,7 +807,7 @@ simplifyExp (DoLoop ctx val form loopbody) = do ((loopstms, loopres), hoisted) <- enterLoop $ consumeMerge $ - bindFParams (ctxparams' ++ valparams') $ + bindMerge (zipWith withRes (ctx' ++ val') (bodyResult loopbody)) $ wrapbody $ blockIf ( hasFree boundnames `orIf` isConsumed @@ -819,6 +826,7 @@ simplifyExp (DoLoop ctx val form loopbody) = do localVtable $ flip (foldl' (flip ST.consume)) $ namesToList consumed_by_merge consumed_by_merge = freeIn $ map snd $ filter (unique . paramDeclType . fst) val + withRes (p, x) y = (p, x, y) simplifyExp (Op op) = do (op', stms) <- simplifyOp op return (Op op', stms) diff --git a/src/Futhark/Optimise/Simplify/Rules.hs b/src/Futhark/Optimise/Simplify/Rules.hs index e5cc1a367d..abc998b3a1 100644 --- a/src/Futhark/Optimise/Simplify/Rules.hs +++ b/src/Futhark/Optimise/Simplify/Rules.hs @@ -153,11 +153,11 @@ removeRedundantMergeVariables _ _ _ _ = -- We may change the type of the loop if we hoist out a shape -- annotation, in which case we also need to tweak the bound pattern. hoistLoopInvariantMergeVariables :: BinderOps lore => TopDownRuleDoLoop lore -hoistLoopInvariantMergeVariables _ pat aux (ctx, val, form, loopbody) = +hoistLoopInvariantMergeVariables vtable pat aux (ctx, val, form, loopbody) = -- Figure out which of the elements of loopresult are -- loop-invariant, and hoist them out. case foldr checkInvariance ([], explpat, [], []) $ - zip merge res of + zip3 (patternNames pat) merge res of ([], _, _, _) -> -- Nothing is invariant. 
Skip @@ -199,11 +199,11 @@ hoistLoopInvariantMergeVariables _ pat aux (ctx, val, form, loopbody) = (Nothing, explpat') checkInvariance - ((mergeParam, mergeInit), resExp) + (pat_name, (mergeParam, mergeInit), resExp) (invariant, explpat', merge', resExps) | not (unique (paramDeclType mergeParam)) || arrayRank (paramDeclType mergeParam) == 1, - isInvariant resExp, + isInvariant, -- Also do not remove the condition in a while-loop. not $ paramName mergeParam `nameIn` freeIn form = let (bnd, explpat'') = @@ -214,21 +214,36 @@ hoistLoopInvariantMergeVariables _ pat aux (ctx, val, form, loopbody) = resExps ) where - -- A non-unique merge variable is invariant if the corresponding - -- subexp in the result is EITHER: + -- A non-unique merge variable is invariant if one of the + -- following is true: -- - -- (0) a variable of the same name as the parameter, where - -- all existential parameters are already known to be - -- invariant - isInvariant (Var v2) - | paramName mergeParam == v2 = + -- (0) The result is a variable of the same name as the + -- parameter, where all existential parameters are already + -- known to be invariant + isInvariant + | Var v2 <- resExp, + paramName mergeParam == v2 = allExistentialInvariant (namesFromList $ map (identName . fst) invariant) mergeParam - -- (1) or identical to the initial value of the parameter. - isInvariant _ = mergeInit == resExp - checkInvariance ((mergeParam, mergeInit), resExp) (invariant, explpat', merge', resExps) = - (invariant, explpat', (mergeParam, mergeInit) : merge', resExp : resExps) + -- (1) The result is identical to the initial parameter value. + | mergeInit == resExp = True + -- (2) The initial parameter value is equal to an outer + -- loop parameter 'P', where the initial value of 'P' is + -- equal to 'resExp', AND 'resExp' ultimately becomes the + -- new value of 'P'. XXX: it's a bit clumsy that this + -- only works for one level of nesting, and I think it + -- would not be too hard to generalise. 
+ | Var init_v <- mergeInit, + Just (p_init, p_res) <- ST.lookupLoopParam init_v vtable, + p_init == resExp, + p_res == Var pat_name = + True + | otherwise = False + checkInvariance + (_pat_name, (mergeParam, mergeInit), resExp) + (invariant, explpat', merge', resExps) = + (invariant, explpat', (mergeParam, mergeInit) : merge', resExp : resExps) allExistentialInvariant namesOfInvariant mergeParam = all (invariantOrNotMergeParam namesOfInvariant) $ diff --git a/tests/loops/loop15.fut b/tests/loops/loop15.fut new file mode 100644 index 0000000000..01cab0829e --- /dev/null +++ b/tests/loops/loop15.fut @@ -0,0 +1,9 @@ +-- Simple case; simplify away the loops. +-- == +-- input { 10 2 } output { 2 } +-- structure { DoLoop 0 } + +let main (n: i32) (a: i32) = + loop x = a for _i < n do + loop _y = x for _j < n do + a diff --git a/tests/loops/loop16.fut b/tests/loops/loop16.fut new file mode 100644 index 0000000000..d4b0d64b45 --- /dev/null +++ b/tests/loops/loop16.fut @@ -0,0 +1,11 @@ +-- Complex case; simplify away the loops. +-- == +-- input { 10 2 [1,2,3] } +-- output { [1,2] } +-- structure { DoLoop 0 } + +let main (n: i32) (a: i32) (arr: []i32) = + #[unsafe] -- Just to make the IR cleaner. + loop x = take a arr for _i < n do + loop _y = take (length x) arr for _j < n do + take a arr