Large diffs are not rendered by default.

@@ -10,6 +10,8 @@ import qualified Data.HashMap.Lazy as HM
import qualified Data.HashSet as HS
import qualified Data.List as L

import Prelude

import Futhark.MonadFreshNames
import Futhark.Representation.Basic
import Futhark.Substitute
@@ -18,7 +20,7 @@ import Futhark.Substitute

data FlatState = FlatState {
vnameSource :: VNameSource
, mapLetArrays :: M.Map Ident Ident
, mapLetArrays :: M.Map VName Ident
-- ^ arrays for let values in maps
--
-- @let res = map (\xs -> let y = reduce(+,0,xs) in
@@ -44,6 +46,9 @@ instance MonadFreshNames FlatM where
getNameSource = gets vnameSource
putNameSource newSrc = modify $ \s -> s { vnameSource = newSrc }

instance HasTypeEnv FlatM where
askTypeEnv = error "Please give Futhark.Flattening a proper type environment."

--------------------------------------------------------------------------------

data Error = Error String
@@ -65,7 +70,7 @@ flatError e = FlatM . lift $ Left e
-- Functions for working with FlatState
--------------------------------------------------------------------------------

getMapLetArray' :: Ident -> FlatM Ident
getMapLetArray' :: VName -> FlatM Ident
getMapLetArray' ident = do
letArrs <- gets mapLetArrays
case M.lookup ident letArrs of
@@ -74,7 +79,7 @@ getMapLetArray' ident = do
pretty ident ++
" in table"

addMapLetArray :: Ident -> Ident -> FlatM ()
addMapLetArray :: VName -> Ident -> FlatM ()
addMapLetArray ident letArr = do
letArrs <- gets mapLetArrays
case M.lookup ident letArrs of
@@ -93,7 +98,7 @@ getFlattenedDims (outer,inner) = do
case M.lookup (outer,inner) fds of
Just sz -> return sz
Nothing -> do
new <- liftM Var $ newIdent "size" (Basic Int)
new <- Var <$> newVName "size"
let fds' = M.insert (outer,inner) new fds
modify (\s -> s{flattenedDims = fds'})
return new
@@ -128,7 +133,8 @@ transformBody (Body lore bindings (Result ses)) = do
-- Only maps needs to be transformed, @map f xs@ ~~> @f^ xs@
transformBinding :: Binding -> FlatM [Binding]
transformBinding topBnd@(Let (Pattern pats) ()
(LoopOp (Map certs lambda idents))) = do
(LoopOp (Map certs lambda arrs))) = do
idents <- mapM toIdent arrs
okLamBnds <- mapM isSafeToMapBinding lamBnds
let grouped = foldr group [] $ zip okLamBnds lamBnds

@@ -139,9 +145,9 @@ transformBinding topBnd@(Let (Pattern pats) ()
case grouped of
[Right _] -> return [topBnd]
_ -> do
let loopinv_idents =
filter (`notElem` idents) $ filter (isJust . identDimentionality) $
HS.toList $ freeInExp (LoopOp $ Map certs lambda idents)
loopinv_idents <-
filter (`notElem` idents) <$> filter (isJust . identDimentionality) <$>
mapM toIdent (HS.toList $ freeInExp (LoopOp $ Map certs lambda arrs))
(loopinv_repbnds, loopinv_repidents) <-
mapAndUnzipM (replicateIdent outerSize) loopinv_idents

@@ -152,11 +158,13 @@ transformBinding topBnd@(Let (Pattern pats) ()
, mapCerts = certs
}

let mapResNeed = HS.unions $ map freeIn
mapResNeed <- liftM HS.fromList $ mapM toIdent $ HS.toList $
HS.unions $ map freeIn
(resultSubExps $ bodyResult $ lambdaBody lambda)
let freeIdents = flip map grouped $ \case
Right bnds -> HS.unions $ map (freeInExp . bindingExp) bnds
Left bnd -> freeInExp $ bindingExp bnd
freeIdents <- liftM (map HS.fromList) $ mapM (mapM toIdent . HS.toList) $
flip map grouped $ \case
Right bnds -> HS.unions $ map (freeInExp . bindingExp) bnds
Left bnd -> freeInExp $ bindingExp bnd
let _:needed = scanr HS.union mapResNeed freeIdents
let defining = flip map grouped $ \case
-- TODO: assuming Bindage == BindVar (which is ok?)
@@ -179,7 +187,7 @@ transformBinding topBnd@(Let (Pattern pats) ()
res' <- forM (resultSubExps . bodyResult $ lambdaBody lambda) $
\se -> case se of
(Constant bv) -> return $ Constant bv
(Var ident) -> liftM Var $ getMapLetArray' ident
(Var ident) -> Var <$> identName <$> getMapLetArray' ident

let resBnds =
zipWith (\pe se -> Let (Pattern [pe]) () (PrimOp $ SubExp se))
@@ -203,28 +211,28 @@ transformBinding topBnd@(Let (Pattern pats) ()

(mapIdents, argArrs) <- liftM (unzip . catMaybes)
$ forM argsNeeded $ \arg -> do
argArr <- findTarget mapInfo arg
argArr <- findTarget mapInfo $ identName arg
case argArr of
Just val -> return $ Just (arg, val)
Nothing -> return Nothing

pat <- liftM (Pattern . map (\i -> PatElem i BindVar () ))
$ forM shouldReturn $ \i -> do
iArr <- wrapInArrIdent (mapSize mapInfo) i
addMapLetArray i iArr
addMapLetArray (identName i) iArr
return iArr

let lamBody = Body { bodyLore = ()
, bodyBindings = bnds
, bodyResult = Result $ map Var shouldReturn
, bodyResult = Result $ map (Var . identName) shouldReturn
}

let wrapLambda = Lambda { lambdaParams = mapIdents
, lambdaBody = lamBody
, lambdaReturnType = map identType shouldReturn
}
, lambdaBody = lamBody
, lambdaReturnType = map identType shouldReturn
}

let theMapExp = LoopOp $ Map certs wrapLambda argArrs
let theMapExp = LoopOp $ Map certs wrapLambda $ map identName argArrs
return $ Let pat () theMapExp

transformBinding bnd = return [bnd]
@@ -273,8 +281,8 @@ pullOutOfMap :: MapInfo -> ([Ident], [Ident]) -> Binding -> FlatM [Binding]
pullOutOfMap _ (_,[]) _ = return []
pullOutOfMap mapInfo _
(Let (Pattern [PatElem resIdent BindVar patlore]) letlore
(PrimOp (Reshape certs dimses reshapeident))) = do
Just target <- findTarget mapInfo reshapeident
(PrimOp (Reshape certs dimses reshapearr))) = do
Just target <- findTarget mapInfo reshapearr

loopdep_dim_subexps <- filterM (\case
Var i -> liftM isJust $ findTarget mapInfo i
@@ -293,18 +301,18 @@ pullOutOfMap mapInfo _
return $ Ident vn' (Array bt (Shape (mapSize mapInfo:shpdms)) uniq)
_ -> flatError $ Error "impossible, result of reshape not list"

addMapLetArray resIdent newResIdent
addMapLetArray (identName resIdent) newResIdent

let newReshape = PrimOp $ Reshape (certs ++ mapCerts mapInfo)
(mapSize mapInfo:dimses) target
(mapSize mapInfo:dimses) $ identName target

return [Let (Pattern [PatElem newResIdent BindVar patlore])
letlore newReshape]

pullOutOfMap mapInfo (argsNeeded, _)
(Let (Pattern pats) letlore
(LoopOp (Map certs lambda idents))) = do

(LoopOp (Map certs lambda arrs))) = do
idents <- mapM toIdent arrs
-- For all argNeeded that are not already being mapped over:
--
-- 1) if they where created as an intermediate result in the outer map,
@@ -336,11 +344,11 @@ pullOutOfMap mapInfo (argsNeeded, _)
-----------------------------------------------
(okIdents, okLambdaParams) <-
liftM unzip
$ filterM (\(i,_) -> isJust <$> findTarget mapInfo i)
$ filterM (\(i,_) -> isJust <$> findTarget mapInfo (identName i))
$ zip idents (lambdaParams lambda)
(loopInvIdents, loopInvLambdaParams) <-
liftM unzip
$ filterM (\(i,_) -> isNothing <$> findTarget mapInfo i)
$ filterM (\(i,_) -> isNothing <$> findTarget mapInfo (identName i))
$ zip idents (lambdaParams lambda)
(loopInvRepBnds, loopInvIdentsArrs) <- mapAndUnzipM (replicateIdent $ mapSize mapInfo)
loopInvIdents
@@ -360,12 +368,12 @@ pullOutOfMap mapInfo (argsNeeded, _)
-------------------------------------------------------------
-- Handle Idents needed by body, which are not mapped over --
-------------------------------------------------------------
let reallyNeeded = filter (\i -> not $ HS.member i $ HS.unions
$ map freeIn idents) argsNeeded
let reallyNeeded = filter (\i -> not $ HS.member (identName i) $
HS.unions $ map freeIn idents) argsNeeded
--
-- Intermediate results needed
--
itmResIdents <- filterM (\i -> isJust <$> findTarget mapInfo i) reallyNeeded
itmResIdents <- filterM (\i -> isJust <$> findTarget mapInfo (identName i)) reallyNeeded

-- Need to rename so our intermediate result will not be found in
-- other calls (through mapLetArray)
@@ -400,8 +408,8 @@ pullOutOfMap mapInfo (argsNeeded, _)

let mapBnd' = Let (Pattern pats') letlore
(LoopOp (Map (certs ++ mapCerts mapInfo)
lambda'
newInnerIdents))
lambda' $
map identName newInnerIdents))

mapBnd'' <- transformBinding mapBnd'

@@ -433,7 +441,7 @@ pullOutOfMap mapInfo (argsNeeded, _)
distIdent <- newIdent (baseString vn ++ "_dist") distTp

let distExp = Apply (nameFromString "distribute")
[(Var i, Observe), (sz, Observe)]
[(Var vn, Observe), (sz, Observe)]
-- TODO: I guess Observe is okay for now
(basicRetType Int) -- FIXME

@@ -444,8 +452,8 @@ pullOutOfMap mapInfo (argsNeeded, _)

-- | Steps for exiting a nested map, meaning we step-up/unflatten the result
unflattenRes :: PatElem -> FlatM (Binding, PatElem)
unflattenRes (PatElem i@(Ident vn (Array bt (Shape (outer:rest)) uniq))
BindVar patLore) = do
unflattenRes (PatElem (Ident vn (Array bt (Shape (outer:rest)) uniq))
BindVar patLore) = do
flatSize <- getFlattenedDims (mapSize mapInfo, outer)
let flatTp = Array bt (Shape $ flatSize:rest) uniq
flatResArr <- newIdent (baseString vn ++ "_sd") flatTp
@@ -454,10 +462,10 @@ pullOutOfMap mapInfo (argsNeeded, _)
let finalTp = Array bt (Shape $ mapSize mapInfo :outer:rest) uniq
finalResArr <- newIdent (baseString vn) finalTp

addMapLetArray i finalResArr
addMapLetArray vn finalResArr

let unflattenExp = Apply (nameFromString "stepup")
[(Var flatResArr, Observe)]
[(Var $ identName flatResArr, Observe)]
-- ^ TODO: I guess Observe is okay for now
(basicRetType Int)
-- ^ TODO: stupid exsitensial types :(
@@ -468,10 +476,10 @@ pullOutOfMap mapInfo (argsNeeded, _)
unflattenRes pe = flatError $ Error $ "unflattenRes applied to " ++ pretty pe

pullOutOfMap mapinfo _ (Let (Pattern [PatElem ident1 BindVar _]) _
(PrimOp (SubExp (Var ident2))))
| identType ident1 == identType ident2 = do
addMapLetArray ident1 =<< findTarget1 mapinfo ident2
return []
(PrimOp (SubExp (Var name)))) = do
ident2 <- toIdent name
addMapLetArray (identName ident1) =<< findTarget1 mapinfo ident2
return []

pullOutOfMap _ _ binding =
flatError $ Error $ "pullOutOfMap not implemented for " ++ pretty binding ++
@@ -504,7 +512,7 @@ flattenArg mapInfo targInfo = do
flatIdent <- newIdent (baseString (identName target) ++ "_sd") flatTp

let flattenExp = Apply (nameFromString "stepdown")
[(Var target, Observe)]
[(Var $ identName target, Observe)]
-- ^ TODO: I guess Observe is okay for now
(basicRetType Int)
-- ^ TODO: stupid exsitensial types :(
@@ -516,18 +524,18 @@ flattenArg mapInfo targInfo = do
return (flatBnd, flatIdent)

-- | Find the "parent" array for a given Ident in a /specific/ map
findTarget :: MapInfo -> Ident -> FlatM (Maybe Ident)
findTarget :: MapInfo -> VName -> FlatM (Maybe Ident)
findTarget mapInfo i =
case L.elemIndex i (lamParams mapInfo) of
case L.elemIndex i $ map identName $ lamParams mapInfo of
Just n -> return . Just $ mapListArgs mapInfo !! n
Nothing -> if i `notElem` mapLets mapInfo
Nothing -> if i `notElem` map identName (mapLets mapInfo)
-- this argument is loop invariant
then return Nothing
else liftM Just $ getMapLetArray' i

findTarget1 :: MapInfo -> Ident -> FlatM Ident
findTarget1 mapInfo i =
findTarget mapInfo i >>= \case
findTarget mapInfo (identName i) >>= \case
Just iArr -> return iArr
Nothing -> flatError $ Error $ "findTarget': couldn't find expected arr for "
++ pretty i
@@ -543,7 +551,7 @@ wrapInArrIdent sz (Ident vn tp) = do
replicateIdent :: SubExp -> Ident -> FlatM (Binding, Ident)
replicateIdent sz i = do
arrRes <- wrapInArrIdent sz i
let repExp = PrimOp $ Replicate sz (Var i)
let repExp = PrimOp $ Replicate sz $ Var $ identName i
repBnd = Let (Pattern [PatElem arrRes BindVar ()]) () repExp

return (repBnd, arrRes)
@@ -561,7 +569,7 @@ isSafeToMapBinding (Let _ _ e) = isSafeToMapExp e
-- Else we need to apply a segmented operator on it
isSafeToMapExp :: Exp -> FlatM Bool
isSafeToMapExp (PrimOp po) = do
let ts = primOpType po
ts <- primOpType po
and <$> mapM isSafeToMapType ts
-- DoLoop/Map/ConcatMap/Reduce/Scan/Filter/Redomap
isSafeToMapExp (LoopOp _) = return False
@@ -581,3 +589,8 @@ isSafeToMapType (Array{}) = return False
identDimentionality :: Ident -> Maybe Int
identDimentionality (Ident _ (Array _ (Shape dims) _)) = Just $ length dims
identDimentionality _ = Nothing

-- XXX: use of this function probably means that there is a design
-- flaw.
toIdent :: HasTypeEnv m => VName -> m Ident
toIdent name = Ident name <$> lookupType name

Large diffs are not rendered by default.

@@ -5,82 +5,87 @@ module Futhark.Internalise.AccurateSizes
, ensureResultShape
, ensureResultExtShape
, ensureShape
, ensureShapeIdent
, ensureShapeVar
)
where

import Control.Applicative
import Control.Monad
import Data.Loc

import qualified Data.HashMap.Lazy as HM

import Prelude

import Futhark.Representation.Basic
import Futhark.Representation.AST
import Futhark.Tools
import Futhark.MonadFreshNames

shapeBody :: [VName] -> [Type] -> Body -> Body
shapeBody shapenames ts (Body () bnds (Result ses)) =
Body () bnds $ Result shapes
where shapes = argShapes shapenames ts ses
shapeBody :: (HasTypeEnv m, MonadFreshNames m, Bindable lore) =>
[VName] -> [Type] -> Body lore
-> m (Body lore)
shapeBody shapenames ts body =
runBinder $ do
ses <- bodyBind body
sets <- mapM subExpType ses
return $ resultBody $ argShapes shapenames ts sets

annotateArrayShape :: ArrayShape shape =>
TypeBase shape -> [Int] -> TypeBase Shape
annotateArrayShape t newshape =
t `setArrayShape` Shape (take (arrayRank t) (map intconst $ newshape ++ repeat 0))

argShapes :: [VName] -> [Type] -> [SubExp] -> [SubExp]
argShapes shapes valts valargs =
argShapes :: [VName] -> [Type] -> [Type] -> [SubExp]
argShapes shapes valts valargts =
map addShape shapes
where mapping = shapeMapping valts $ map subExpType valargs
where mapping = shapeMapping valts valargts
addShape name
| Just se <- HM.lookup name mapping = se
| otherwise = Constant (IntVal 0)

ensureResultShape :: MonadBinder m =>
SrcLoc -> [Type] -> Body
-> m Body
ensureResultShape :: (HasTypeEnv m, MonadFreshNames m, Bindable lore) =>
SrcLoc -> [Type] -> Body lore
-> m (Body lore)
ensureResultShape loc =
ensureResultExtShape loc . staticShapes

ensureResultExtShape :: MonadBinder m =>
SrcLoc -> [ExtType] -> Body
-> m Body
ensureResultExtShape loc rettype body = runBinder $ do
es <- bodyBind body
let assertProperShape t se =
let name = "result_proper_shape"
in ensureExtShape loc t name se
reses <- zipWithM assertProperShape rettype es
return $ resultBody reses
ensureResultExtShape :: (HasTypeEnv m, MonadFreshNames m, Bindable lore) =>
SrcLoc -> [ExtType] -> Body lore
-> m (Body lore)
ensureResultExtShape loc rettype body =
runBinder $ insertBindingsM $ do
es <- bodyBind body
let assertProperShape t se =
let name = "result_proper_shape"
in ensureExtShape loc t name se
reses <- zipWithM assertProperShape rettype es
mkBodyM [] $ Result reses

ensureExtShape :: MonadBinder m =>
SrcLoc -> ExtType -> String -> SubExp
-> m SubExp
SrcLoc -> ExtType -> String -> SubExp
-> m SubExp
ensureExtShape loc t name orig
| Array{} <- t, Var v <- orig =
Var <$> ensureShapeIdent loc t name v
Var <$> ensureShapeVar loc t name v
| otherwise = return orig

ensureShape :: MonadBinder m =>
SrcLoc -> Type -> String -> SubExp
-> m SubExp
ensureShape loc = ensureExtShape loc . staticShapes1

ensureShapeIdent :: MonadBinder m =>
SrcLoc -> ExtType -> String -> Ident
-> m Ident
ensureShapeIdent loc t name v
ensureShapeVar :: MonadBinder m =>
SrcLoc -> ExtType -> String -> VName
-> m VName
ensureShapeVar loc t name v
| Array{} <- t = do
let checkDim desired has =
letExp "shape_cert" =<<
eAssert (pure $ PrimOp $ BinOp Equal desired has Bool) loc
certs <- zipWithM checkDim newshape oldshape
letExp name $ PrimOp $ Reshape certs newshape v
newshape <- arrayDims <$> removeExistentials t <$> lookupType v
oldshape <- arrayDims <$> lookupType v
let checkDim desired has =
letExp "shape_cert" =<<
eAssert (pure $ PrimOp $ BinOp Equal desired has Bool) loc
certs <- zipWithM checkDim newshape oldshape
letExp name $ PrimOp $ Reshape certs newshape v
| otherwise = return v
where newshape = arrayDims $ removeExistentials t $ identType v
oldshape = arrayDims $ identType v

removeExistentials :: ExtType -> Type -> Type
removeExistentials t1 t2 =
@@ -22,7 +22,6 @@ import Data.Traversable (mapM)

import Futhark.Representation.External as E
import Futhark.Representation.Basic as I
import Futhark.Tools as I
import Futhark.MonadFreshNames

import Futhark.Internalise.Monad
@@ -66,7 +65,7 @@ internaliseFunParams params = do
[ new_param { I.identType = t } |
(new_param, t) <- zip params' instantiated_param_types ]
return (param_implicit_shapes, instantiated_params)
let subst = HM.fromList $ zip (map E.identName params) value_params
let subst = HM.fromList $ zip (map E.identName params) (map (map I.identName) value_params)
return (declared_shape_params ++ concat implicit_shape_params,
concat value_params,
subst <> shapesubst)
@@ -77,7 +76,9 @@ bindingParams :: [E.Parameter]
bindingParams params m = do
(shapeparams, valueparams, substs) <- internaliseFunParams params
let bind env = env { envSubsts = substs `HM.union` envSubsts env }
local bind $ m shapeparams valueparams
local bind $
bindingIdentTypes (shapeparams++valueparams) $
m shapeparams valueparams

bindingFlatPattern :: [E.Ident] -> [I.Type]
-> ([I.Ident] -> InternaliseM a)
@@ -87,12 +88,13 @@ bindingFlatPattern = bindingFlatPattern' []
bindingFlatPattern' pat [] _ m = do
let (vs, substs) = unzip pat
substs' = HM.fromList substs
local (\env -> env { envSubsts = substs' `HM.union` envSubsts env})
$ m $ concat $ reverse vs
idents = concat $ reverse vs
local (\env -> env { envSubsts = substs' `HM.union` envSubsts env}) $
m idents

bindingFlatPattern' pat (p:rest) ts m = do
(ps, subst, rest_ts) <- handleMapping ts <$> internaliseBindee p
bindingFlatPattern' ((ps, (E.identName p, subst)) : pat) rest rest_ts m
bindingFlatPattern' ((ps, (E.identName p, map I.identName subst)) : pat) rest rest_ts m

handleMapping ts [] =
([], [], ts)
@@ -121,15 +123,16 @@ bindingTupIdent :: E.TupIdent -> [ExtType] -> (I.Pattern -> InternaliseM a)
-> InternaliseM a
bindingTupIdent pat ts m = do
pat' <- flattenPattern pat
(ts',shapes) <- I.instantiateShapes' ts
(ts',shapes) <- instantiateShapes' ts
let addShapeBindings pat'' = m $ I.basicPattern' $ shapes ++ pat''
bindingFlatPattern pat' ts' addShapeBindings

bindingLambdaParams :: [E.Parameter] -> [I.Type]
-> InternaliseM I.Body
-> InternaliseM (I.Body, [I.Param])
bindingLambdaParams params ts m =
bindingFlatPattern (map E.fromParam params) ts $ \params' -> do
bindingFlatPattern (map E.fromParam params) ts $ \params' ->
bindingIdentTypes params' $ do
body <- m
return (body, params')

@@ -140,20 +143,20 @@ makeShapeIdentsFromContext :: MonadFreshNames m =>
makeShapeIdentsFromContext ctx = do
(ctx', substs) <- liftM unzip $ forM (HM.toList ctx) $ \(name, i) -> do
v <- newIdent (baseString name) $ I.Basic Int
return ((i, v), (name, [v]))
return ((i, v), (name, [I.identName v]))
return (HM.fromList ctx', HM.fromList substs)

instantiateShapesWithDecls :: MonadFreshNames m =>
HM.HashMap Int I.Ident
-> [I.ExtType]
-> m ([I.Type], [I.Ident])
instantiateShapesWithDecls ctx ts =
runWriterT $ I.instantiateShapes instantiate ts
runWriterT $ instantiateShapes instantiate ts
where instantiate x
| Just v <- HM.lookup x ctx =
return $ I.Var v
return $ I.Var $ I.identName v

| otherwise = do
v <- lift $ newIdent "size" (I.Basic Int)
tell [v]
return $ I.Var v
return $ I.Var $ I.identName v
@@ -11,14 +11,12 @@ module Futhark.Internalise.Lambdas

import Control.Applicative
import Control.Monad

import Data.List
import Data.Loc

import Futhark.Representation.External as E
import Futhark.Representation.Basic as I
import Futhark.MonadFreshNames
import Futhark.Tools

import Futhark.Internalise.Monad
import Futhark.Internalise.AccurateSizes
@@ -34,15 +32,17 @@ internaliseMapLambda :: InternaliseLambda
-> [I.SubExp]
-> InternaliseM I.Lambda
internaliseMapLambda internaliseLambda lam args = do
let argtypes = map I.subExpType args
rowtypes = map I.rowType argtypes
argtypes <- mapM I.subExpType args
let rowtypes = map I.rowType argtypes
(params, body, rettype) <- internaliseLambda lam $ Just rowtypes
(rettype', inner_shapes) <- instantiateShapes' rettype
let outer_shape = arraysSize 0 argtypes
shape_body = shapeBody (map I.identName inner_shapes) rettype' body
shape_body <- bindingIdentTypes params $
shapeBody (map I.identName inner_shapes) rettype' body
shapefun <- makeShapeFun params shape_body (length inner_shapes)
bindMapShapes inner_shapes shapefun args outer_shape
body' <- ensureResultShape (srclocOf lam) rettype' body
body' <- bindingIdentTypes params $
ensureResultShape (srclocOf lam) rettype' body
return $ I.Lambda params body' rettype'

internaliseConcatMapLambda :: InternaliseLambda
@@ -101,7 +101,8 @@ internaliseFoldLambda internaliseLambda lam acctypes arrtypes = do
-- The result of the body must have the exact same shape as the
-- initial accumulator. We accomplish this with an assertion and
-- reshape().
body' <- ensureResultShape (srclocOf lam) rettype' body
body' <- bindingIdentTypes params $
ensureResultShape (srclocOf lam) rettype' body

return $ I.Lambda params body' rettype'

@@ -113,9 +114,9 @@ internaliseRedomapInnerLambda ::
-> [I.SubExp]
-> InternaliseM I.Lambda
internaliseRedomapInnerLambda internaliseLambda lam nes arr_args = do
let arrtypes = map I.subExpType arr_args
rowtypes = map I.rowType arrtypes
acctypes = map I.subExpType nes
arrtypes <- mapM I.subExpType arr_args
acctypes <- mapM I.subExpType nes
let rowtypes = map I.rowType arrtypes
--
(params, body, rettype) <- internaliseLambda lam $ Just $
acctypes ++ rowtypes
@@ -138,7 +139,8 @@ internaliseRedomapInnerLambda internaliseLambda lam nes arr_args = do
map_bindings= acc_bindings ++ bodyBindings body
map_lore = bodyLore body
map_body = I.Body map_lore map_bindings map_bodyres
shape_body = shapeBody (map I.identName inner_shapes) rettypearr' map_body
shape_body <- bindingIdentTypes params $
shapeBody (map I.identName inner_shapes) rettypearr' map_body
shapefun <- makeShapeFun (drop acc_len params) shape_body (length inner_shapes)
bindMapShapes inner_shapes shapefun arr_args outer_shape
--
@@ -150,7 +152,8 @@ internaliseRedomapInnerLambda internaliseLambda lam nes arr_args = do
-- an assertion and reshape().
--
-- finally, place assertions and return result
body' <- ensureResultShape (srclocOf lam) (acctype'++rettypearr') body
body' <- bindingIdentTypes params $
ensureResultShape (srclocOf lam) (acctype'++rettypearr') body
return $ I.Lambda params body' (acctype'++rettypearr')

internaliseStreamLambda :: InternaliseLambda
@@ -159,7 +162,7 @@ internaliseStreamLambda :: InternaliseLambda
-> [I.Type]
-> InternaliseM I.ExtLambda
internaliseStreamLambda internaliseLambda lam accs arrtypes = do
let acctypes = map I.subExpType accs
acctypes <- mapM I.subExpType accs
(params, body, rettype) <- internaliseLambda lam $ Just $
acctypes++arrtypes
-- split rettype into (i) accummulator types && (ii) result-array-elem types
@@ -180,25 +183,27 @@ internaliseStreamLambda internaliseLambda lam accs arrtypes = do
let acctype' = [ t `I.setArrayShape` arrayShape shape
| (t,shape) <- zip lam_acc_tps acctypes ]
body' <- insertBindingsM $ do
let mkArrType :: (I.Ident, I.ExtType) -> I.Type
mkArrType (x, I.Array btp shp u) =
let dsx = (I.shapeDims . I.arrayShape . I.identType) x
dsrtpx = I.extShapeDims shp
let mkArrType :: (VName, ExtType) -> InternaliseM I.Type
mkArrType (x, I.Array btp shp u) = do
dsx <- I.shapeDims <$> I.arrayShape <$> I.lookupType x
let dsrtpx = I.extShapeDims shp
resdims= zipWith (\ dx drtpx ->
case drtpx of
Ext _ -> dx
Free s -> s
) dsx dsrtpx
in I.Array btp (I.Shape resdims) u
mkArrType (_, I.Basic btp ) = I.Basic btp
mkArrType (_, I.Mem se ) = I.Mem se
return $ I.Array btp (I.Shape resdims) u
mkArrType (_, I.Basic btp ) =
return $ I.Basic btp
mkArrType (_, I.Mem se ) =
return $ I.Mem se
lamres <- bodyBind body
let (lamacc_res, lamarr_res) = (take acc_len lamres, drop acc_len lamres)
lamarr_idtps = concatMap (\(y,tp) -> case y of
I.Var ii -> [(ii,tp)]
_ -> []
) (zip lamarr_res lam_arr_tps)
arrtype' = map mkArrType lamarr_idtps
arrtype' <- mapM mkArrType lamarr_idtps
reses1 <- zipWithM assertProperShape acctype' lamacc_res
reses2 <- zipWithM assertProperShape arrtype' lamarr_res
return $ resultBody $ reses1 ++ reses2
@@ -212,8 +217,8 @@ internalisePartitionLambdas :: InternaliseLambda
-> [I.SubExp]
-> InternaliseM I.Lambda
internalisePartitionLambdas internaliseLambda lams args = do
let argtypes = map I.subExpType args
rowtypes = map I.rowType argtypes
argtypes <- mapM I.subExpType args
let rowtypes = map I.rowType argtypes
lams' <- forM lams $ \lam -> do
(params, body, _) <- internaliseLambda lam $ Just rowtypes
return (params, body)
@@ -233,14 +238,14 @@ internalisePartitionLambdas internaliseLambda lams args = do
next_lam_body <-
mkCombinedLambdaBody lam_params (i+1) lams'
let parambnds =
[ mkLet' [top] $ I.PrimOp $ I.SubExp $ I.Var fromp
[ mkLet' [top] $ I.PrimOp $ I.SubExp $ I.Var $ I.identName fromp
| (top,fromp) <- zip lam_params params ]
branchbnd = mkLet' [intres] $ I.If boolres
(resultBody [intconst i])
next_lam_body
[I.Basic Int]
return $ mkBody
(parambnds++bodybnds++[branchbnd])
(Result [I.Var intres])
(Result [I.Var $ I.identName intres])
_ ->
fail "Partition lambda returns too many values."
@@ -9,6 +9,9 @@ module Futhark.Internalise.Monad
, InternaliseEnv(..)
, FunBinding (..)
, lookupFunction
, bindingIdentTypes
-- * Convenient reexports
, module Futhark.Tools
)
where

@@ -25,13 +28,14 @@ import Data.List
import qualified Futhark.Representation.External as E
import Futhark.Representation.Basic
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Tools hiding (bindingIdentTypes)
import qualified Futhark.Tools as F

import Prelude hiding (mapM)

data FunBinding = FunBinding
{ internalFun :: ([VName], [Type],
[SubExp] -> Maybe ExtRetType)
[(SubExp,Type)] -> Maybe ExtRetType)
, externalFun :: (E.DeclType, [E.DeclType])
}

@@ -41,7 +45,7 @@ type FunTable = HM.HashMap Name FunBinding

-- | A mapping from external variable names to the corresponding
-- internalised identifiers.
type VarSubstitutions = HM.HashMap VName [Ident]
type VarSubstitutions = HM.HashMap VName [VName]

data InternaliseEnv = InternaliseEnv {
envSubsts :: VarSubstitutions
@@ -57,7 +61,7 @@ initialFtable = HM.map addBuiltin builtInFunctions
const $ Just $ ExtRetType [Basic t])
(E.Basic t, map E.Basic paramts)

newtype InternaliseM a = InternaliseM (WriterT (DL.DList Binding)
newtype InternaliseM a = InternaliseM (BinderT Basic
(ReaderT InternaliseEnv
(StateT VNameSource
(Except String)))
@@ -72,17 +76,19 @@ instance MonadFreshNames InternaliseM where
getNameSource = get
putNameSource = put

instance HasTypeEnv InternaliseM where
askTypeEnv = InternaliseM askTypeEnv

instance MonadBinder InternaliseM where
type Lore InternaliseM = Basic
mkLetM pat e = return $ mkLet pat' e
where pat' = [ (ident, bindage)
| PatElem ident bindage _ <- patternElements pat
]
mkBodyM bnds res = return $ mkBody bnds res
mkLetNamesM = mkLetNames
mkLetM pat e = InternaliseM $ mkLetM pat e
mkBodyM bnds res = InternaliseM $ mkBodyM bnds res
mkLetNamesM pat e = InternaliseM $ mkLetNamesM pat e

addBinding = addBindingWriter
collectBindings = collectBindingsWriter
addBinding =
InternaliseM . addBinding
collectBindings (InternaliseM m) =
InternaliseM $ collectBindings m

runInternaliseM :: MonadFreshNames m =>
Bool -> FunTable -> InternaliseM a
@@ -92,7 +98,7 @@ runInternaliseM boundsCheck ftable (InternaliseM m) =
let onError e = (Left e, src)
onSuccess ((prog,_),src') = (Right prog, src')
in either onError onSuccess $ runExcept $
runStateT (runReaderT (runWriterT m) newEnv) src
runStateT (runReaderT (runBinderT m mempty) newEnv) src
where newEnv = InternaliseEnv {
envSubsts = HM.empty
, envFtable = initialFtable `HM.union` ftable
@@ -104,3 +110,8 @@ lookupFunction fname = do
fun <- HM.lookup fname <$> asks envFtable
case fun of Nothing -> fail $ "Function '" ++ nameToString fname ++ "' not found"
Just fun' -> return fun'

bindingIdentTypes :: [Ident] -> InternaliseM a
-> InternaliseM a
bindingIdentTypes idents (InternaliseM m) =
InternaliseM $ F.bindingIdentTypes idents m
@@ -97,7 +97,7 @@ internaliseDeclType' (E.Array at) =
subst <- asks $ HM.lookup name . envSubsts
return $ I.Free $ I.Var $ case subst of
Just [v] -> v
_ -> I.Ident name $ I.Basic Int
_ -> name

internaliseType :: Ord vn =>
E.TypeBase E.Rank als vn -> [I.TypeBase ExtShape]
@@ -131,7 +131,7 @@ bindVar (BindInPlace _ src is) val = do
is' <- mapM (asInt <=< evalSubExp) is
case srcv of
ArrayVal arr bt shape -> do
flatidx <- indexArray (textual $ identName src) shape is'
flatidx <- indexArray (textual src) shape is'
if length is' == length shape then
case val of
BasicVal bv ->
@@ -185,8 +185,8 @@ binding bnds m = do
ppDim (Constant v) _ = pretty v
ppDim e v = pretty e ++ "=" ++ pretty v

lookupVar :: Ident -> FutharkM lore Value
lookupVar (Ident vname _) = do
lookupVar :: VName -> FutharkM lore Value
lookupVar vname = do
val <- asks $ HM.lookup vname . envVtable
case val of Just val' -> return val'
Nothing -> bad $ TypeError $ "lookupVar " ++ textual vname
@@ -479,7 +479,7 @@ evalPrimOp (Index _ ident idxs) = do
idxs' <- mapM (asInt <=< evalSubExp) idxs
case v of
ArrayVal arr bt shape -> do
flatidx <- indexArray (textual $ identName ident) shape idxs'
flatidx <- indexArray (textual ident) shape idxs'
if length idxs' == length shape
then return [BasicVal $ arr ! flatidx]
else let resshape = drop (length idxs') shape
@@ -594,13 +594,13 @@ evalPrimOp (Assert e loc) = do
evalPrimOp (Partition _ n flags arr) = do
flags_elems <- arrToList =<< lookupVar flags
arrv <- lookupVar arr
let et = elemType $ valueType arrv
arr_elems <- arrToList arrv
partitions <- partitionArray flags_elems arr_elems
return $
map (BasicVal . IntVal . length) partitions ++
[arrayVal (concat partitions) et (valueShape arrv)]
where et = elemType $ identType arr
partitionArray flagsv arrv =
where partitionArray flagsv arrv =
map reverse <$>
foldM divide (replicate n []) (zip flagsv arrv)

@@ -633,7 +633,7 @@ evalLoopOp (DoLoop respat merge (ForLoop loopvar boundexp) loopbody) = do
_ -> bad $ TypeError "evalBody DoLoop for"
where (mergepat, mergeexp) = unzip merge
iteration mergeval i =
binding [(loopvar, BindVar, BasicVal $ IntVal i)] $
binding [(Ident loopvar $ Basic Int, BindVar, BasicVal $ IntVal i)] $
binding (zip3 (map fparamIdent mergepat) (repeat BindVar) mergeval) $
evalBody loopbody

@@ -647,7 +647,8 @@ evalLoopOp (DoLoop respat merge (WhileLoop cond) loopbody) = do
case condv of
BasicVal (LogVal False) ->
mapM lookupVar $
loopResultContext (representative :: lore) respat mergepat ++ respat
loopResultContext (representative :: lore) respat mergepat ++
respat
BasicVal (LogVal True) ->
iteration =<< evalBody loopbody
_ ->
@@ -1,4 +1,4 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleInstances, UndecidableInstances #-}
-- | This module provides a monadic facility similar (and built on top
-- of) "Futhark.FreshNames". The removes the need for a (small) amount of
-- boilerplate, at the cost of using some GHC extensions. The idea is
@@ -23,6 +23,10 @@ module Futhark.MonadFreshNames
import Control.Applicative
import qualified Control.Monad.State.Lazy
import qualified Control.Monad.State.Strict
import qualified Control.Monad.Writer.Lazy
import qualified Control.Monad.Writer.Strict
import Control.Monad.Reader
import Data.Monoid

import Prelude

@@ -106,3 +110,20 @@ newIdents = mapM . newIdent
-- names used as variables in the given program.
newNameSourceForProg :: Prog lore -> VNameSource
newNameSourceForProg = newNameSource . progNames

-- Utility instance defintions for MTL classes. This requires
-- UndecidableInstances, but saves on typing elsewhere.

instance MonadFreshNames m => MonadFreshNames (ReaderT s m) where
getNameSource = lift getNameSource
putNameSource = lift . putNameSource

instance (MonadFreshNames m, Monoid s) =>
MonadFreshNames (Control.Monad.Writer.Lazy.WriterT s m) where
getNameSource = lift getNameSource
putNameSource = lift . putNameSource

instance (MonadFreshNames m, Monoid s) =>
MonadFreshNames (Control.Monad.Writer.Strict.WriterT s m) where
getNameSource = lift getNameSource
putNameSource = lift . putNameSource
@@ -62,14 +62,20 @@ cseInBinding (Let pat eattr e) m = do
Nothing -> local (addExpSubst pat' eattr e') $ m [Let pat' eattr e']

Just subpat ->
let lets =
[ Let (Pattern [patElem]) eattr $ PrimOp $ SubExp $ Var v
| (patElem,v) <- zip (patternElements pat') $ patternIdents subpat
]
in local (addNameSubst pat' subpat) $ m lets
local (addNameSubst pat' subpat) $ do
CSEState (_, nsubsts') <- ask
let lets =
[ Let (Pattern [patElem']) eattr $ PrimOp $ SubExp $ Var v
| (patElem,v) <- zip (patternElements pat') $ patternNames subpat,
let patElem' = setPatElemName (substituteNames nsubsts' patElem) $
patElemName patElem
]
m lets
where bad (Array _ _ Unique) = True
bad (Mem _) = True
bad _ = False
setPatElemName patElem name =
patElem { patElemIdent = Ident name $ identType $ patElemIdent patElem }

newtype CSEState lore =
CSEState (M.Map (Lore.Exp lore, Exp lore) (Pattern lore), HM.HashMap VName VName)
@@ -78,14 +78,14 @@ deadCodeElimBody = fst . runDCElimM . deadCodeElimBodyM
--------------------------------------------------------------------

deadCodeElimSubExp :: SubExp -> DCElimM SubExp
deadCodeElimSubExp (Var ident) = Var <$> deadCodeElimIdent ident
deadCodeElimSubExp (Var ident) = Var <$> deadCodeElimVName ident
deadCodeElimSubExp (Constant v) = return $ Constant v

deadCodeElimBodyM :: Proper lore => Body lore -> DCElimM (Body lore)

deadCodeElimBodyM (Body bodylore (Let pat explore e:bnds) res) = do
let idds = patternNames pat
seen $ freeNamesIn explore
seen $ freeIn explore
(Body _ bnds' res', noref) <-
collectRes idds $ do
deadCodeElimPat pat
@@ -97,7 +97,7 @@ deadCodeElimBodyM (Body bodylore (Let pat explore e:bnds) res) = do
(Let pat explore e':bnds') res'

deadCodeElimBodyM (Body bodylore [] (Result es)) = do
seen $ freeNamesIn bodylore
seen $ freeIn bodylore
Body bodylore [] <$>
(Result <$> mapM deadCodeElimSubExp es)

@@ -109,41 +109,39 @@ deadCodeElimExp (LoopOp (DoLoop respat merge form body)) = do
body' <- deadCodeElimBodyM body
case form of
ForLoop _ bound -> void $ deadCodeElimSubExp bound
WhileLoop cond -> void $ deadCodeElimIdent cond
WhileLoop cond -> void $ deadCodeElimVName cond
return $ LoopOp $ DoLoop respat merge form body'
deadCodeElimExp e = mapExpM mapper e
where mapper = Mapper {
mapOnBinding = return -- Handled in case for Body.
, mapOnBody = deadCodeElimBodyM
mapOnBody = deadCodeElimBodyM
, mapOnSubExp = deadCodeElimSubExp
, mapOnLambda = deadCodeElimLambda
, mapOnExtLambda = deadCodeElimExtLambda
, mapOnIdent = deadCodeElimIdent
, mapOnCertificates = mapM deadCodeElimIdent
, mapOnVName = deadCodeElimVName
, mapOnCertificates = mapM deadCodeElimVName
, mapOnRetType = \rt -> do
seen $ freeNamesIn rt
seen $ freeIn rt
return rt
, mapOnFParam = \fparam -> do
seen $ freeNamesIn fparam
seen $ freeIn fparam
return fparam
}

deadCodeElimIdent :: Ident -> DCElimM Ident
deadCodeElimIdent ident@(Ident vnm t) = do
tell $ DCElimRes False $ HS.singleton vnm
dims <- mapM deadCodeElimSubExp $ arrayDims t
return ident { identType = t `setArrayShape` Shape dims }
deadCodeElimVName :: VName -> DCElimM VName
deadCodeElimVName vnm = do
seen $ HS.singleton vnm
return vnm

deadCodeElimPat :: Proper lore => Pattern lore -> DCElimM ()
deadCodeElimPat = mapM_ deadCodeElimPatElem . patternElements

deadCodeElimPatElem :: FreeIn attr => PatElemT attr -> DCElimM ()
deadCodeElimPatElem patelem =
seen $ patElemName patelem `HS.delete` freeNamesIn patelem
seen $ patElemName patelem `HS.delete` freeIn patelem

deadCodeElimFParam :: FreeIn attr => FParamT attr -> DCElimM ()
deadCodeElimFParam fparam =
seen $ fparamName fparam `HS.delete` freeNamesIn fparam
seen $ fparamName fparam `HS.delete` freeIn fparam


deadCodeElimBnd :: Ident -> DCElimM ()
@@ -167,7 +165,7 @@ deadCodeElimExtLambda :: Proper lore =>
deadCodeElimExtLambda (ExtLambda params body rettype) = do
body' <- deadCodeElimBodyM body
mapM_ deadCodeElimBnd params
seen $ freeNamesIn rettype
seen $ freeIn rettype
return $ ExtLambda params body' rettype

seen :: Names -> DCElimM ()

Large diffs are not rendered by default.

@@ -12,8 +12,6 @@
-- The module will, however, remove duplicate inputs after fusion.
module Futhark.Optimise.Fusion.Composing
( fuseMaps
, fuseFilters
, fuseFilterIntoFold
, Input(..)
)
where
@@ -27,15 +25,15 @@ import qualified Futhark.Analysis.HORepresentation.SOAC as SOAC

import Futhark.Representation.AST
import Futhark.Binder
(Bindable(..), insertBinding, insertBindings, mkBody, mkLet')
(Bindable(..), insertBinding, insertBindings, mkLet')
import Futhark.Tools (mapResult)

-- | Something that can be used as a SOAC input. As far as this
-- module is concerned, this means supporting just a single operation.
class (Ord a, Eq a) => Input a where
-- | Check whether an arbitrary input corresponds to a plain
-- variable input. If so, return that variable.
isVarInput :: a -> Maybe Ident
isVarInput :: a -> Maybe VName

instance Input SOAC.Input where
isVarInput = SOAC.isVarInput
@@ -63,7 +61,7 @@ instance (Show a, Ord a, Input inp) => Input (a, inp) where
fuseMaps :: (Input input, Bindable lore) =>
Lambda lore -- ^ Function of SOAC to be fused.
-> [input] -- ^ Input of SOAC to be fused.
-> [(Ident,Ident)] -- ^ Output of SOAC to be fused. The
-> [(VName,Ident)] -- ^ Output of SOAC to be fused. The
-- first identifier is the name of the
-- actual output, where the second output
-- is an identifier that can be used to
@@ -86,76 +84,8 @@ fuseMaps lam1 inp1 out1 lam2 inp2 = (lam2', HM.elems inputmap)
(lam2redparams, pat, inputmap, makeCopies, makeCopiesInner) =
fuseInputs lam1 inp1 out1 lam2 inp2

-- | Similar to 'fuseMaps', although the two functions must be
-- predicates returning @{bool}@. Returns a new predicate function.
fuseFilters :: (Input input, Bindable lore) =>
Lambda lore -- ^ Function of SOAC to be fused.
-> [input] -- ^ Input of SOAC to be fused.
-> [(Ident,Ident)] -- ^ Output of SOAC to be fused.
-> Lambda lore -- ^ Function to be fused with.
-> [input] -- ^ Input of SOAC to be fused with.
-> VName -- ^ A fresh name (used internally).
-> (Lambda lore, [input]) -- ^ The fused lambda and the inputs of the resulting SOAC.
fuseFilters lam1 inp1 out1 lam2 inp2 vname =
fuseFilterInto lam1 inp1 out1 lam2 inp2 [vname] false
where false = mkBody [] $ Result [constant False]

-- | Similar to 'fuseFilters', except the second function does not
-- have to return @{bool}@, but must be a folding function taking at
-- least one reduction parameter (that is, the number of parameters
-- accepted by the function must be at least one greater than its
-- number of inputs). If @f1@ is the to-be-fused function, and @f2@
-- is the function to be fused with, the resulting function will be of
-- roughly following form:
--
-- @
-- fn (acc, args) => if f1(args)
-- then f2(acc,args)
-- else acc
-- @
fuseFilterIntoFold :: (Input input, Bindable lore) =>
Lambda lore -- ^ Function of SOAC to be fused.
-> [input] -- ^ Input of SOAC to be fused.
-> [(Ident,Ident)] -- ^ Output of SOAC to be fused.
-> Lambda lore -- ^ Function to be fused with.
-> [input] -- ^ Input of SOAC to be fused with.
-> [VName] -- ^ A fresh name (used internally).
-> (Lambda lore, [input]) -- ^ The fused lambda and the inputs of the resulting SOAC.
fuseFilterIntoFold lam1 inp1 out1 lam2 inp2 vnames =
fuseFilterInto lam1 inp1 out1 lam2 inp2 vnames identity
where identity = mkBody [] $ Result (map Var lam2redparams)
lam2redparams = take (length (lambdaParams lam2) - length inp2) $
lambdaParams lam2

fuseFilterInto :: (Input input, Bindable lore) =>
Lambda lore -> [input] -> [(Ident,Ident)]
-> Lambda lore -> [input]
-> [VName] -> Body lore
-> (Lambda lore, [input])
fuseFilterInto lam1 inp1 out1 lam2 inp2 vnames falsebranch = (lam2', HM.elems inputmap)
where lam2' =
lam2 { lambdaParams = lam2redparams ++ HM.keys inputmap
, lambdaBody = makeCopies bindins
}
restype = lambdaReturnType lam2
residents = [ Ident vname t | (vname, t) <- zip vnames restype ]
branch = flip mapResult (lambdaBody lam1) $ \res ->
let [e] = resultSubExps res -- XXX
tbranch = makeCopiesInner $ lambdaBody lam2
ts = bodyExtType tbranch `generaliseExtTypes`
bodyExtType falsebranch
in mkBody [mkLet' residents $
If e tbranch falsebranch ts] $
Result (map Var residents)
lam1tuple = [ mkLet' [v] $ PrimOp $ SubExp $ Var p
| (v,p) <- zip pat $ lambdaParams lam1 ]
bindins = lam1tuple `insertBindings` branch

(lam2redparams, pat, inputmap, makeCopies, makeCopiesInner) =
fuseInputs lam1 inp1 out1 lam2 inp2

fuseInputs :: (Input input, Bindable lore) =>
Lambda lore -> [input] -> [(Ident,Ident)]
Lambda lore -> [input] -> [(VName,Ident)]
-> Lambda lore -> [input]
-> ([Param],
[Ident],
@@ -176,7 +106,7 @@ fuseInputs lam1 inp1 out1 lam2 inp2 =
removeDuplicateInputs $ originputmap `HM.difference` outins

outParams :: Input input =>
[Ident] -> [Param] -> [input]
[VName] -> [Param] -> [input]
-> HM.HashMap Param input
outParams out1 lam2arrparams inp2 =
HM.fromList $ mapMaybe isOutParam $ zip lam2arrparams inp2
@@ -186,7 +116,7 @@ outParams out1 lam2arrparams inp2 =
isOutParam _ = Nothing

filterOutParams :: Input input =>
[(Ident,Ident)]
[(VName,Ident)]
-> HM.HashMap Param input
-> [Ident]
filterOutParams out1 outins =
@@ -202,66 +132,16 @@ filterOutParams out1 outins =
Just (p:ps) -> (M.insert a ps m, p)
_ -> (m, ra)


removeDuplicateInputs :: (Input input, Bindable lore) =>
HM.HashMap Param input
-> (HM.HashMap Param input, Body lore -> Body lore)
removeDuplicateInputs = fst . HM.foldlWithKey' comb ((HM.empty, id), M.empty)
where comb ((parmap, inner), arrmap) par arr =
case M.lookup arr arrmap of
Nothing -> ((HM.insert par arr parmap, inner),
M.insert arr par arrmap)
M.insert arr (identName par) arrmap)
Just par' -> ((parmap, inner . forward par par'),
arrmap)
forward to from b =
mkLet' [to] (PrimOp $ SubExp $ Var from)
`insertBinding` b

{-
An example of how I tested this module:
I add this import:
import Futhark.Dev
-}

{-
And now I can have top-level bindings like the following, that explicitly call fuseMaps:
(test1fun, test1ins) = fuseMaps lam1 lam1in out lam2 lam2in
where lam1in = [SOAC.varInput $ tident "[int] arr_x", SOAC.varInput $ tident "[int] arr_z"]
lam1 = lambdaToFunction $ lambda "fn {int, int} (int x, int z_b) => {x + z_b, x - z_b}"
outarr = tident "[int] arr_y"
outarr2 = tident "[int] arr_unused"
out = [outarr2, outarr]
lam2in = [Var outarr, Var $ tident "[int] arr_z"]
lam2 = lambdaToFunction $ lambda "fn {int} (int red, int y, int z) => {red + y + z}"
(test2fun, test2ins) = fuseFilterIntoFold lam1 lam1in out lam2 lam2in (name "check")
where lam1in = [SOAC.varInput $ tident "[int] arr_x", SOAC.varInput $ tident "[int] arr_v"]
lam1 = lambda "fn {bool} (int x, int v) => x+v < 0"
outarr = tident "[int] arr_y"
outarr2 = tident "[int] arr_unused"
out = [outarr, outarr2]
lam2in = [Var outarr]
lam2 = lambda "fn {int} (int red, int y) => {red + y}"
(test3fun, test3ins) = fuseFilterIntoFold lam1 lam1in out lam2 lam2in (name "check")
where lam1in = [expr "iota(30)", expr "replicate(30, 1)"]
lam1 = lambda "fn {bool} (int i, int j) => {i+j < 0}"
outarr = tident "[int] arr_p"
outarr2 = tident "[int] arr_unused"
out = [outarr, outarr2]
lam2in = [SOAC.varInput outarr]
lam2 = lambda "fn {int} (int x, int p) => {x ^ p}"
I can inspect these values directly in GHCi.
The point is to demonstrate that by factoring functionality out of the
huge monad in the fusion module, we get something that's much easier
to work with interactively.
-}

Large diffs are not rendered by default.

@@ -10,22 +10,35 @@ module Futhark.Optimise.Fusion.TryFusion

import Control.Applicative
import Control.Monad.State
import Control.Monad.Reader
import qualified Data.HashMap.Lazy as HM

import Futhark.Representation.Basic
import Futhark.NeedNames
import Futhark.MonadFreshNames

newtype TryFusion a = TryFusion (StateT VNameSource Maybe a)
deriving (Functor, Applicative, Alternative,
Monad, MonadState (NameSource VName))
newtype TryFusion a = TryFusion (ReaderT TypeEnv
(StateT VNameSource Maybe)
a)
deriving (Functor, Applicative, Alternative, Monad,
MonadReader TypeEnv,
MonadState (NameSource VName))

instance MonadFreshNames TryFusion where
getNameSource = get
putNameSource = put

tryFusion :: MonadFreshNames m => TryFusion a -> m (Maybe a)
tryFusion (TryFusion m) = modifyNameSource $ \src ->
case runStateT m src of
instance HasTypeEnv TryFusion where
lookupType name =
maybe notFound return =<< asks (HM.lookup name)
where notFound =
fail $ "Variable " ++ pretty name ++ " not found in symbol table"
askTypeEnv = ask

tryFusion :: MonadFreshNames m =>
TryFusion a -> TypeEnv -> m (Maybe a)
tryFusion (TryFusion m) types = modifyNameSource $ \src ->
case runStateT (runReaderT m types) src of
Just (x, src') -> (Just x, src')
Nothing -> (Nothing, src)

@@ -93,7 +93,7 @@ optimiseBody (Body als bnds res) = do
mapM_ seen $ resultSubExps res
return $ Body als bnds' res
where seen (Constant {}) = return ()
seen (Var v) = seenIdent v
seen (Var v) = seenVar v

optimiseBindings :: [Binding Basic]
-> ForwardingM ()
@@ -103,7 +103,7 @@ optimiseBindings [] m = m >> return []
optimiseBindings (bnd:bnds) m = do
(bnds', bup) <- tapBottomUp $ bindingBinding bnd $ optimiseBindings bnds m
bnd' <- optimiseInBinding bnd
case filter ((`elem` boundHere) . identName . updateValue) $
case filter ((`elem` boundHere) . updateValue) $
forwardThese bup of
[] -> checkIfForwardableUpdate bnd' bnds'
updates -> do
@@ -142,7 +142,7 @@ optimiseExp (LoopOp (DoLoop res merge form body)) =
bindingFParams (map fst merge) $ do
body' <- optimiseBody body
return $ LoopOp $ DoLoop res merge form body'
where boundInForm (ForLoop i _) = [i]
where boundInForm (ForLoop i _) = [Ident i $ Basic Int]
boundInForm (WhileLoop _) = []
optimiseExp e = mapExpM optimise e
where optimise = identityMapper { mapOnBody = optimiseBody
@@ -160,6 +160,7 @@ data Entry = Entry { entryNumber :: Int
, entryAliases :: Names
, entryDepth :: Int
, entryOptimisable :: Bool
, entryType :: Type
}

type VTable = HM.HashMap VName Entry
@@ -197,6 +198,9 @@ instance MonadFreshNames ForwardingM where
getNameSource = get
putNameSource = put

instance HasTypeEnv ForwardingM where
askTypeEnv = HM.map entryType <$> asks topDownTable

runForwardingM :: VNameSource -> ForwardingM a -> a
runForwardingM src (ForwardingM m) = fst $ evalRWS m emptyTopDown src
where emptyTopDown = TopDown { topDownCounter = 0
@@ -210,7 +214,7 @@ bindingFParams :: [FParam Basic]
bindingFParams fparams = local $ \(TopDown n vtable d) ->
let entry fparam =
(fparamName fparam,
Entry n mempty d False)
Entry n mempty d False $ fparamType fparam)
entries = HM.fromList $ map entry fparams
in TopDown (n+1) (HM.union entries vtable) d

@@ -222,7 +226,7 @@ bindingBinding (Let pat _ _) = local $ \(TopDown n vtable d) ->
entry patElem =
let (aliases, ()) = patElemLore patElem
in (patElemName patElem,
Entry n (unNames aliases) d True)
Entry n (unNames aliases) d True $ patElemType patElem)
in TopDown (n+1) (HM.union entries vtable) d

bindingIdents :: [Ident]
@@ -232,7 +236,7 @@ bindingIdents vs = local $ \(TopDown n vtable d) ->
let entries = HM.fromList $ map entry vs
entry v =
(identName v,
Entry n mempty d False)
Entry n mempty d False $ identType v)
in TopDown (n+1) (HM.union entries vtable) d

bindingNumber :: VName -> ForwardingM Int
@@ -251,7 +255,7 @@ areAvailableBefore ses point = do
nameNs <- mapM bindingNumber names
return $ all (< pointN) nameNs
where names = mapMaybe isVar ses
isVar (Var v) = Just $ identName v
isVar (Var v) = Just v
isVar (Constant {}) = Nothing

isInCurrentBody :: VName -> ForwardingM Bool
@@ -269,31 +273,30 @@ isOptimisable name = do
Nothing -> fail $ "isOptimisable: variable " ++
pretty name ++ " not found."

seenIdent :: Ident -> ForwardingM ()
seenIdent ident = do
seenVar :: VName -> ForwardingM ()
seenVar name = do
aliases <- asks $
maybe mempty entryAliases .
HM.lookup name . topDownTable
tell $ mempty { bottomUpSeen = HS.insert name aliases }
where name = identName ident

tapBottomUp :: ForwardingM a -> ForwardingM (a, BottomUp)
tapBottomUp m = do (x,bup) <- listen m
return (x, bup)

maybeForward :: Ident
-> Ident -> Certificates -> Ident -> SubExp
maybeForward :: VName
-> Ident -> Certificates -> VName -> SubExp
-> ForwardingM Bool
maybeForward v dest cs src i = do
-- Checks condition (2)
available <- [i,Var src] `areAvailableBefore` identName v
available <- [i,Var src] `areAvailableBefore` v
-- ...subcondition, the certificates must also.
certs_available <- map Var cs `areAvailableBefore` identName v
certs_available <- map Var cs `areAvailableBefore` v
-- Check condition (3)
samebody <- isInCurrentBody $ identName v
samebody <- isInCurrentBody v
-- Check condition (6)
optimisable <- isOptimisable $ identName v
let not_basic = not $ basicType $ identType v
optimisable <- isOptimisable v
not_basic <- not <$> basicType <$> lookupType v
if available && certs_available && samebody && optimisable && not_basic then do
let fwd = DesiredUpdate dest cs src [i] v
tell mempty { forwardThese = [fwd] }
@@ -1,3 +1,4 @@
{-# LANGUAGE FlexibleContexts #-}
module Futhark.Optimise.InPlaceLowering.LowerIntoBinding
(
lowerUpdate
@@ -6,6 +7,7 @@ module Futhark.Optimise.InPlaceLowering.LowerIntoBinding

import Control.Applicative
import Control.Monad
import Control.Monad.Writer
import Data.List (find)
import Data.Maybe (mapMaybe)
import Data.Either
@@ -21,13 +23,13 @@ import Futhark.Optimise.InPlaceLowering.SubstituteIndices
data DesiredUpdate =
DesiredUpdate { updateBindee :: Ident
, updateCertificates :: Certificates
, updateSource :: Ident
, updateSource :: VName
, updateIndices :: [SubExp]
, updateValue :: Ident
, updateValue :: VName
}

updateHasValue :: VName -> DesiredUpdate -> Bool
updateHasValue name = (name==) . identName . updateValue
updateHasValue name = (name==) . updateValue

lowerUpdate :: (Bindable lore, MonadFreshNames m) =>
Binding lore -> [DesiredUpdate] -> Maybe (m [Binding lore])
@@ -39,27 +41,27 @@ lowerUpdate (Let pat _ (LoopOp (DoLoop res merge form body))) updates = do
lowerUpdate
(Let pat _ (PrimOp (SubExp (Var v))))
[DesiredUpdate bindee cs src is val]
| patternIdents pat == [src] =
| patternNames pat == [src] =
Just $ return [mkLet [(bindee,BindInPlace cs v is)] $
PrimOp $ SubExp $ Var val]
lowerUpdate
(Let (Pattern [PatElem v BindVar _]) _ e)
[DesiredUpdate bindee cs src is val]
| v == val =
| identName v == val =
Just $ return [mkLet [(bindee,BindInPlace cs src is)] e,
mkLet' [v] $ PrimOp $ Index cs bindee is]
mkLet' [v] $ PrimOp $ Index cs (identName bindee) is]
lowerUpdate _ _ =
Nothing

lowerUpdateIntoLoop :: (Bindable lore, MonadFreshNames m) =>
[DesiredUpdate]
-> Pattern lore
-> [Ident]
-> [VName]
-> [(FParam lore, SubExp)]
-> Body lore
-> Maybe (m ([Binding lore],
[Ident],
[Ident],
[VName],
[(FParam lore, SubExp)],
Body lore))
lowerUpdateIntoLoop updates pat res merge body = do
@@ -96,11 +98,11 @@ lowerUpdateIntoLoop updates pat res merge body = do
idxsubsts = indexSubstitutions in_place_map
(idxsubsts', newbnds) <- substituteIndices idxsubsts $ bodyBindings body
let body' = mkBody newbnds $ manipulateResult in_place_map idxsubsts'
return (prebnds, pat', res', merge', body')
return (prebnds, pat', map identName res', merge', body')
where mergeparams = map fst merge
usedInBody = freeNamesInBody body
usedInBody = freeInBody body
resmap = loopResultValues
(patternIdents pat) (map identName res)
(patternIdents pat) res
(map fparamName mergeparams) $
resultSubExps $ bodyResult body

@@ -109,16 +111,18 @@ lowerUpdateIntoLoop updates pat res merge body = do
-> m ([(FParamT (), SubExp)], [Binding lore])
mkMerges summaries = do
((origmerge, extramerge), prebnds) <-
runBinderT $ partitionEithers <$> mapM mkMerge summaries
runWriterT $ partitionEithers <$> mapM mkMerge summaries
return (origmerge ++ extramerge, prebnds)

mkMerge summary
| Just (update, mergeident) <- relatedUpdate summary = do
source <- letInPlace "modified_source"
(updateCertificates update)
(updateSource update)
(updateIndices update)
$ PrimOp $ SubExp $ snd $ mergeParam summary
source <- newVName "modified_source"
let updpat = [((updateBindee update) { identName = source },
BindInPlace
(updateCertificates update)
(updateSource update)
(updateIndices update))]
tell [mkLet updpat $ PrimOp $ SubExp $ snd $ mergeParam summary]
return $ Right (FParam mergeident (), Var source)
| otherwise = return $ Left $ mergeParam summary

@@ -146,7 +150,7 @@ summariseLoop updates usedInBody resmap merge =
sequence <$> zipWithM summariseLoopResult resmap merge
where summariseLoopResult (se, Just v) (fparam, mergeinit)
| Just update <- find (updateHasValue $ identName v) updates =
if identName (updateSource update) `HS.member` usedInBody
if updateSource update `HS.member` usedInBody
then Nothing
else if hasLoopInvariantShape fparam then Just $ do
ident <-
@@ -168,7 +172,7 @@ summariseLoop updates usedInBody resmap merge =

merge_param_names = map (fparamName . fst) merge

loopInvariant (Var v) = identName v `notElem` merge_param_names
loopInvariant (Var v) = v `notElem` merge_param_names
loopInvariant (Constant {}) = True

data LoopResultSummary =
@@ -191,7 +195,7 @@ manipulateResult :: [LoopResultSummary]
-> Result
manipulateResult summaries substs =
let orig_ses = mapMaybe unchangedRes summaries
subst_ses = map (\(_,v,_) -> Var v) substs
subst_ses = map (\(_,v,_) -> Var $ identName v) substs
in Result $ orig_ses ++ subst_ses
where
unchangedRes summary =
@@ -1,3 +1,4 @@
{-# LANGUAGE FlexibleContexts #-}
-- | This module exports facilities for transforming array accesses in
-- a list of 'Binding's (intended to be the bindings in a body). The
-- idea is that you can state that some variable @x@ is in fact an
@@ -11,6 +12,7 @@ module Futhark.Optimise.InPlaceLowering.SubstituteIndices

import Control.Applicative
import Control.Monad
import qualified Data.HashMap.Lazy as HM

import Prelude

@@ -22,13 +24,19 @@ import Futhark.Util
type IndexSubstitution = (Certificates, Ident, [SubExp])
type IndexSubstitutions = [(VName, IndexSubstitution)]

typeEnvFromSubstitutions :: IndexSubstitutions -> TypeEnv
typeEnvFromSubstitutions = HM.fromList . map (fromSubstitution. snd)
where fromSubstitution (_, ident, _) =
(identName ident, identType ident)

substituteIndices :: (MonadFreshNames m, Bindable lore) =>
IndexSubstitutions -> [Binding lore]
-> m ([IndexSubstitution], [Binding lore])
substituteIndices substs bnds = do
(substs', bnds') <-
runBinder'' $ substituteIndicesInBindings substs bnds
runBinderT (substituteIndicesInBindings substs bnds) types
return (map snd substs', bnds')
where types = typeEnvFromSubstitutions substs

substituteIndicesInBindings :: MonadBinder m =>
IndexSubstitutions
@@ -46,7 +54,7 @@ substituteIndicesInBinding substs (Let pat lore e) = do
addBinding $ Let pat' lore e'
return substs'
where substitute = identityMapper { mapOnSubExp = substituteIndicesInSubExp substs
, mapOnIdent = substituteIndicesInIdent substs
, mapOnVName = substituteIndicesInVar substs
, mapOnBody = substituteIndicesInBody substs
}

@@ -58,30 +66,29 @@ substituteIndicesInPattern substs pat = do
(substs', patElems) <- mapAccumLM sub substs $ patternElements pat
return (substs', Pattern patElems)
where sub substs' (PatElem ident (BindInPlace cs src is) attr)
| Just (cs2, src2, is2) <- lookup srcname substs =
| Just (cs2, src2, is2) <- lookup src substs =
let ident' = ident { identType = identType src2 }
in return (update srcname name (cs2, ident', is2) substs',
PatElem ident' (BindInPlace (cs++cs2) src2 (is2++is)) attr)
where srcname = identName src
name = identName ident
in return (update src name (cs2, ident', is2) substs',
PatElem ident' (BindInPlace (cs++cs2) (identName src2) (is2++is)) attr)
where name = identName ident
sub substs' patElem =
return (substs', patElem)

substituteIndicesInSubExp :: MonadBinder m =>
IndexSubstitutions
-> SubExp
-> m SubExp
substituteIndicesInSubExp substs (Var v) = Var <$> substituteIndicesInIdent substs v
substituteIndicesInSubExp substs (Var v) = Var <$> substituteIndicesInVar substs v
substituteIndicesInSubExp _ se = return se


substituteIndicesInIdent :: MonadBinder m =>
IndexSubstitutions
-> Ident
-> m Ident
substituteIndicesInIdent substs v
| Just (cs2, src2, is2) <- lookup (identName v) substs =
letExp "idx" $ PrimOp $ Index cs2 src2 is2
substituteIndicesInVar :: MonadBinder m =>
IndexSubstitutions
-> VName
-> m VName
substituteIndicesInVar substs v
| Just (cs2, src2, is2) <- lookup v substs =
letExp "idx" $ PrimOp $ Index cs2 (identName src2) is2
| otherwise =
return v

@@ -14,7 +14,6 @@ import Control.Monad.Reader

import Data.List
import Data.Maybe

import qualified Data.HashMap.Lazy as HM

import Prelude
@@ -143,33 +142,42 @@ inlineInBody
continue' (Body _ callbnds res') =
continue $ callbnds ++
zipWith reshapeIfNecessary (patternIdents pat)
(withShapes $ resultSubExps res')
(runReader (withShapes $ resultSubExps res') $
typeEnvFromBindings callbnds)
in case filter ((== fname) . funDecName) inlcallees of
[] -> continue [bnd]
FunDec _ _ fargs body:_ ->
let revbnds = zip (map fparamIdent fargs) $ map fst args
in continue' $ foldr addArgBnd body revbnds
where

addArgBnd :: (Ident, SubExp) -> Body -> Body
addArgBnd (farg, aarg) body =
reshapeIfNecessary farg aarg `insertBinding` body

withShapes ses = extractShapeContext (retTypeValues rtp)
(map (arrayDims . subExpType) ses) ++ ses
withShapes ses = do
ts <- mapM subExpType ses
return $
extractShapeContext (retTypeValues rtp) (map arrayDims ts) ++
ses

reshapeIfNecessary ident se
| t@(Array {}) <- identType ident,
Var v <- se =
mkLet' [ident] $ PrimOp $ Reshape [] (arrayDims t) v
| otherwise =
mkLet' [ident] $ PrimOp $ SubExp se
inlineInBody inlcallees b = mapBody (inliner inlcallees) b
inlineInBody inlcallees (Body () (bnd:bnds) res) =
let bnd' = inlineInBinding inlcallees bnd
Body () bnds' res' = inlineInBody inlcallees $ Body () bnds res
in Body () (bnd':bnds') res'
inlineInBody _ (Body () [] res) =
Body () [] res

inliner :: Monad m => [FunDec] -> Mapper Basic Basic m
inliner funs = identityMapper {
mapOnLambda = return . inlineInLambda funs
, mapOnBody = return . inlineInBody funs
, mapOnBinding = return . inlineInBinding funs
}

inlineInBinding :: [FunDec] -> Binding -> Binding
@@ -41,9 +41,9 @@ simpleOpts simpl rules prog = do
return $ pass prog_flat_opt
where pass = deadCodeElim . simplifyProgWithRules simpl rules

normCopyOneLambda :: MonadFreshNames m =>
normCopyOneLambda :: (MonadFreshNames m, HasTypeEnv m) =>
Basic.Prog
-> Basic.Lambda
-> [Maybe Ident]
-> [Maybe VName]
-> m Basic.Lambda
normCopyOneLambda = simplifyLambdaWithRules bindableSimpleOps basicRules
@@ -46,12 +46,12 @@ simplifyFunWithRules simpl rules =
simplifyFun simpl rules

-- | Simplify just a single 'Lambda'.
simplifyLambdaWithRules :: (MonadFreshNames m, Simplifiable lore) =>
simplifyLambdaWithRules :: (MonadFreshNames m, HasTypeEnv m, Simplifiable lore) =>
SimpleOps (SimpleM lore)
-> RuleBook (SimpleM lore)
-> Prog lore
-> Lambda lore
-> [Maybe Ident]
-> [Maybe VName]
-> m (Lambda lore)
simplifyLambdaWithRules simpl rules prog lam args =
liftM removeLambdaWisdom $
@@ -22,4 +22,4 @@ simplifyApply program vtable fname args = do
where allArgsAreValues = mapM argIsValue

argIsValue (Constant val) = Just $ BasicVal val
argIsValue (Var v) = ST.lookupValue (identName v) vtable
argIsValue (Var v) = ST.lookupValue v vtable
@@ -13,12 +13,14 @@ module Futhark.Optimise.Simplifier.ClosedForm
where

import Control.Monad

import Control.Applicative
import Data.Maybe
import qualified Data.HashMap.Lazy as HM
import qualified Data.HashSet as HS
import Data.Monoid

import Prelude

import Futhark.Tools
import Futhark.Representation.AST
import Futhark.Renamer
@@ -47,15 +49,15 @@ Motivation:
-- each of the results of @foldfun@ can be expressed in a closed form.
foldClosedForm :: MonadBinder m =>
VarLookup (Lore m) -> Pattern (Lore m) -> Lambda (Lore m)
-> [SubExp] -> [Ident]
-> [SubExp] -> [VName]
-> RuleM m ()

foldClosedForm look pat lam accs arrs = do
closedBody <- checkResults (patternIdents pat) knownBindings
closedBody <- checkResults (patternNames pat) knownBindings
(lambdaParams lam) (lambdaBody lam) accs
isEmpty <- newIdent "fold_input_is_empty" (Basic Bool)
let inputsize = arraysSize 0 $ map identType arrs
letBindNames'_ [identName isEmpty] $
isEmpty <- newVName "fold_input_is_empty"
inputsize <- arraysSize 0 <$> mapM lookupType arrs
letBindNames'_ [isEmpty] $
PrimOp $ BinOp Equal inputsize (intconst 0) Bool
letBind_ pat =<<
eIf (eSubExp $ Var isEmpty)
@@ -66,15 +68,15 @@ foldClosedForm look pat lam accs arrs = do
-- | @loopClosedForm pat respat merge bound bodys@ determines whether
-- the do-loop can be expressed in a closed form.
loopClosedForm :: MonadBinder m =>
Pattern (Lore m) -> [Ident] -> [(FParam (Lore m),SubExp)]
Pattern (Lore m) -> [VName] -> [(FParam (Lore m),SubExp)]
-> SubExp -> Body (Lore m)
-> RuleM m ()
loopClosedForm pat respat merge bound body
| respat == mergeidents = do
| respat == mergenames = do
closedBody <- checkResults respat knownBindings
mergeidents body mergeexp
isEmpty <- newIdent "bound_is_zero" (Basic Bool)
letBindNames'_ [identName isEmpty] $
isEmpty <- newVName "bound_is_zero"
letBindNames'_ [isEmpty] $
PrimOp $ BinOp Leq bound (intconst 0) Bool
letBindNames'_ (patternNames pat) =<<
eIf (eSubExp $ Var isEmpty)
@@ -83,11 +85,12 @@ loopClosedForm pat respat merge bound body
| otherwise = cannotSimplify
where (mergepat, mergeexp) = unzip merge
mergeidents = map fparamIdent mergepat
knownBindings = HM.fromList $ zip mergeidents mergeexp
mergenames = map identName mergeidents
knownBindings = HM.fromList $ zip mergenames mergeexp

checkResults :: MonadBinder m =>
[Ident]
-> HM.HashMap Ident SubExp
[VName]
-> HM.HashMap VName SubExp
-> [Ident]
-> Body (Lore m)
-> [SubExp]
@@ -102,15 +105,15 @@ checkResults pat knownBindings params body accs = do
res = bodyResult body

nonFree = boundInBody body <>
HS.fromList params
HS.fromList (map identName params)

checkResult (p, e) _
| Just e' <- asFreeSubExp e = letBindNames'_ [identName p] $ PrimOp $ SubExp e'
| Just e' <- asFreeSubExp e = letBindNames'_ [p] $ PrimOp $ SubExp e'
checkResult (p, Var v) (accparam, acc) = do
e@(PrimOp (BinOp bop x y rt)) <- liftMaybe $ HM.lookup v bndMap
-- One of x,y must be *this* accumulator, and the other must
-- be something that is free in the body.
let isThisAccum = (==Var accparam)
let isThisAccum = (==Var (identName accparam))
(this, el) <- liftMaybe $
case ((asFreeSubExp x, isThisAccum y),
(asFreeSubExp y, isThisAccum x)) of
@@ -119,8 +122,8 @@ checkResults pat knownBindings params body accs = do
_ -> Nothing
case bop of
LogAnd -> do
letBindNames'_ [identName v] e
letBindNames'_ [identName p] $ PrimOp $ BinOp LogAnd this el rt
letBindNames'_ [v] e
letBindNames'_ [p] $ PrimOp $ BinOp LogAnd this el rt
_ -> cannotSimplify -- Um... sorry.

checkResult _ _ = cannotSimplify
@@ -130,25 +133,27 @@ checkResults pat knownBindings params body accs = do
| HS.member v nonFree = HM.lookup v knownBindings
asFreeSubExp se = Just se

determineKnownBindings :: VarLookup lore -> Lambda lore -> [SubExp] -> [Ident]
-> HM.HashMap Ident SubExp
determineKnownBindings :: VarLookup lore -> Lambda lore -> [SubExp] -> [VName]
-> HM.HashMap VName SubExp
determineKnownBindings look lam accs arrs =
accBindings <> arrBindings
where (accparams, arrparams) =
splitAt (length accs) $ lambdaParams lam
accBindings = HM.fromList $ zip accparams accs
arrBindings = HM.fromList $ mapMaybe isReplicate $ zip arrparams arrs
accBindings = HM.fromList $
zip (map identName accparams) accs
arrBindings = HM.fromList $ mapMaybe isReplicate $
zip (map identName arrparams) arrs

isReplicate (p, v)
| Just (PrimOp (Replicate _ ve)) <- look $ identName v = Just (p, ve)
| Just (PrimOp (Replicate _ ve)) <- look v = Just (p, ve)
isReplicate _ = Nothing

boundInBody :: Body lore -> HS.HashSet Ident
boundInBody :: Body lore -> Names
boundInBody = mconcat . map bound . bodyBindings
where bound (Let pat _ _) = HS.fromList $ patternIdents pat
where bound (Let pat _ _) = HS.fromList $ patternNames pat

makeBindMap :: Body lore -> HM.HashMap Ident (Exp lore)
makeBindMap :: Body lore -> HM.HashMap VName (Exp lore)
makeBindMap = HM.fromList . mapMaybe isSingletonBinding . bodyBindings
where isSingletonBinding (Let pat _ e) = case patternIdents pat of
where isSingletonBinding (Let pat _ e) = case patternNames pat of
[v] -> Just (v,e)
_ -> Nothing

Large diffs are not rendered by default.

@@ -25,6 +25,10 @@ instance MonadFreshNames m => MonadFreshNames (RuleM m) where
getNameSource = RuleM . lift $ getNameSource
putNameSource = RuleM . lift . putNameSource

instance (Monad m, HasTypeEnv m) => HasTypeEnv (RuleM m) where
lookupType = RuleM . lift . lookupType
askTypeEnv = RuleM . lift $ askTypeEnv

instance MonadBinder m => MonadBinder (RuleM m) where
type Lore (RuleM m) = Lore m
mkLetM pat e = RuleM $ lift $ mkLetM pat e

Large diffs are not rendered by default.

@@ -75,6 +75,17 @@ instance MonadFreshNames (SimpleM lore) where
getNameSource = snd <$> get
putNameSource y = modify $ \(x, _) -> (x,y)

instance Engine.Simplifiable lore =>
HasTypeEnv (SimpleM lore) where
askTypeEnv = ST.typeEnv <$> Engine.getVtable
lookupType name = do
vtable <- Engine.getVtable
case ST.lookupType name vtable of
Just t -> return t
Nothing -> fail $
"SimpleM.lookupType: cannot find variable " ++
pretty name ++ " in symbol table."

instance Engine.Simplifiable lore =>
MonadBinder (SimpleM lore) where
type Lore (SimpleM lore) = Wise lore
@@ -9,9 +9,12 @@ module Futhark.Optimise.Simplifier.Simplify
)
where

import Data.Monoid

import Futhark.Representation.AST
import Futhark.MonadFreshNames
import qualified Futhark.Optimise.Simplifier.Engine as Engine
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Optimise.Simplifier.Lore (Wise)
import Futhark.Optimise.Simplifier.Rule
import Futhark.Optimise.Simplifier.Simple
@@ -42,11 +45,15 @@ simplifyFun simpl rules fundec =
Engine.emptyEnv rules Nothing

-- | Simplify just a single 'Lambda'.
simplifyLambda :: (MonadFreshNames m, Simplifiable lore) =>
simplifyLambda :: (MonadFreshNames m, HasTypeEnv m, Simplifiable lore) =>
SimpleOps (SimpleM lore)
-> RuleBook (SimpleM lore)
-> Maybe (Prog lore) -> Lambda lore -> [Maybe Ident]
-> Maybe (Prog lore) -> Lambda lore -> [Maybe VName]
-> m (Lambda (Wise lore))
simplifyLambda simpl rules prog lam args =
modifyNameSource $ runSimpleM (Engine.simplifyLambda lam args) simpl $
Engine.emptyEnv rules prog
simplifyLambda simpl rules prog lam args = do
types <- askTypeEnv
let m =
Engine.localVtable (<> ST.fromTypeEnv types) $
Engine.simplifyLambda lam args
modifyNameSource $ runSimpleM m simpl $
Engine.emptyEnv rules prog
@@ -1,4 +1,4 @@
{-# LANGUAGE TypeFamilies, FlexibleContexts #-}
{-# LANGUAGE TypeFamilies, FlexibleContexts, GeneralizedNewtypeDeriving #-}
-- | For every function with an existential return shape, try to see
-- if we can extract an efficient shape slice. If so, replace every
-- call of the original function with a function to the shape and
@@ -10,6 +10,7 @@ where
import Control.Applicative
import Control.Monad.State
import Control.Monad.Writer
import Control.Monad.Reader

import qualified Data.HashMap.Lazy as HM
import Data.Maybe
@@ -28,7 +29,7 @@ import Futhark.Optimise.DeadVarElim
-- | Perform the transformation on a program.
splitShapes :: Prog -> Prog
splitShapes prog =
Prog { progFunctions = evalState m (newNameSourceForProg prog) }
Prog { progFunctions = runSplitM m HM.empty $ newNameSourceForProg prog }
where m = do let origfuns = progFunctions prog
(substs, newfuns) <-
unzip <$> map extract <$>
@@ -39,7 +40,20 @@ splitShapes prog =
funDecName valfun, funDecRetType valfun)),
[shapefun, valfun])

makeFunSubsts :: MonadFreshNames m =>
newtype SplitM a = SplitM (ReaderT TypeEnv
(State VNameSource)
a)
deriving (Applicative, Functor, Monad,
MonadReader TypeEnv,
MonadState VNameSource,
MonadFreshNames,
HasTypeEnv)

runSplitM :: SplitM a -> TypeEnv -> VNameSource -> a
runSplitM (SplitM m) =
evalState . runReaderT m

makeFunSubsts :: (MonadFreshNames m, HasTypeEnv m) =>
[FunDec] -> m [(Name, (FunDec, FunDec))]
makeFunSubsts fundecs =
cheapSubsts <$>
@@ -52,19 +66,24 @@ makeFunSubsts fundecs =
-- | Returns shape slice and value slice. The shape slice duplicates
-- the entire value slice - you should try to simplify it, and see if
-- it's "cheap", in some sense.
functionSlices :: MonadFreshNames m => FunDec -> m (FunDec, FunDec)
functionSlices :: (MonadFreshNames m, HasTypeEnv m) =>
FunDec -> m (FunDec, FunDec)
functionSlices (FunDec fname rettype params body@(Body _ bodybnds bodyres)) = do
-- The shape function should not consume its arguments - if it wants
-- to do in-place stuff, it needs to copy them first. In most
-- cases, these copies will be removed by the simplifier.
(shapeParams, cpybnds) <- nonuniqueParams $ map fparamIdent params
((shapeParams, cpybnds),_) <- runBinderEmptyEnv $ nonuniqueParams $ map fparamIdent params

-- Give names to the existentially quantified sizes of the return
-- type. These will be passed as parameters to the value function.
(staticRettype, shapeidents) <-
runWriterT $
instantiateShapes instantiate $ retTypeValues rettype

shapes <- subExpShapeContext (retTypeValues rettype) $
resultSubExps bodyres
shapetypes <- mapM subExpType shapes

valueBody <- substituteExtResultShapes staticRettype body

let valueRettype = ExtRetType $ staticShapes staticRettype
@@ -79,35 +98,31 @@ functionSlices (FunDec fname rettype params body@(Body _ bodybnds bodyres)) = do
(map mkFParam valueParams)
valueBody
return (fShape, fValue)
where shapes = subExpShapeContext (retTypeValues rettype) $
resultSubExps bodyres
shapetypes = map subExpType shapes
shapeFname = fname <> nameFromString "_shape"
where shapeFname = fname <> nameFromString "_shape"
valueFname = fname <> nameFromString "_value"

instantiate _ = do v <- lift $ newIdent "precomp_shape" (Basic Int)
tell [v]
return $ Var v
return $ Var $ identName v

substituteExtResultShapes :: MonadFreshNames m => [Type] -> Body -> m Body
substituteExtResultShapes :: (MonadFreshNames m, HasTypeEnv m) =>
[Type] -> Body -> m Body
substituteExtResultShapes rettype (Body _ bnds res) = do
bnds' <- mapM substInBnd bnds
compshapes <- typesShapes <$> mapM subExpType (resultSubExps res)
let subst = HM.fromList $ mapMaybe isSubst $ zip compshapes $ typesShapes rettype
bnds' <- mapM (substInBnd subst) bnds
let res' = res { resultSubExps = map (substituteNames subst) $
resultSubExps res
}
return $ mkBody bnds' res'
where typesShapes = concatMap (shapeDims . arrayShape)
compshapes =
typesShapes $ map subExpType $ resultSubExps res
subst =
HM.fromList $ mapMaybe isSubst $ zip compshapes (typesShapes rettype)
isSubst (Var v1, Var v2) = Just (identName v1, identName v2)
isSubst (Var v1, Var v2) = Just (v1, v2)
isSubst _ = Nothing

substInBnd (Let pat _ e) =
mkLet' <$> mapM substInBnd' (patternIdents pat) <*>
substInBnd subst (Let pat _ e) =
mkLet' <$> mapM (substInBnd' subst) (patternIdents pat) <*>
pure (substituteNames subst e)
substInBnd' v
substInBnd' subst v
| identName v' `HM.member` subst = newIdent' (<>"unused") v'
| otherwise = return v'
where v' = v { identType = substituteNames subst $ identType v }
@@ -150,14 +165,14 @@ substCalls subst fundec = do

treatBinding (Let pat _ (Apply fname args _))
| Just (shapefun,shapetype,valfun,_) <- lookup fname subst =
liftM snd . runBinder'' $ do
liftM snd . runBinderEmptyEnv $ do
let (vs,vals) =
splitAt (length $ retTypeValues shapetype) $
patternElements pat
letBind_ (Pattern vs) $
Apply shapefun args shapetype
letBind_ (Pattern vals) $
Apply valfun ([(Var $ patElemIdent v,Observe) | v <- vs]++args)
Apply valfun ([(Var $ patElemName v,Observe) | v <- vs]++args)
(ExtRetType $ staticShapes $ map patElemType vals)

treatBinding (Let pat _ e) = do
@@ -32,10 +32,6 @@ type GenM = ReaderT GenEnv (State VNameSource)
runGenM :: MonadFreshNames m => GenEnv -> GenM a -> m a
runGenM env m = modifyNameSource $ runState (runReaderT m env)

instance MonadFreshNames GenM where
getNameSource = get
putNameSource = put

banning :: Names -> GenM a -> GenM a
banning = local . banning'
where banning' names (GenEnv cert deps blacklist) =
@@ -48,15 +44,15 @@ genPredicate :: MonadFreshNames m => FunDec -> m (FunDec, FunDec)
genPredicate (FunDec fname rettype params body) = do
pred_ident <- newIdent "pred" $ Basic Bool
cert_ident <- newIdent "pred_cert" $ Basic Cert
(pred_params, bnds) <- nonuniqueParams $ map fparamIdent params
((pred_params, bnds),_) <- runBinderEmptyEnv $ nonuniqueParams $ map fparamIdent params
let env = GenEnv cert_ident (dataDependencies body) mempty
(pred_body, Body _ val_bnds val_res) <- runGenM env $ splitFunBody body
let mkFParam = flip FParam ()
pred_args = [ (Var arg, Observe) | arg <- map fparamIdent params ]
pred_args = [ (Var arg, Observe) | arg <- map fparamName params ]
pred_bnd = mkLet' [pred_ident] $
Apply predFname pred_args $ basicRetType Bool
cert_bnd = mkLet' [cert_ident] $
PrimOp $ Assert (Var pred_ident) noLoc
PrimOp $ Assert (Var $ identName pred_ident) noLoc
val_fun = FunDec fname rettype params
(mkBody (pred_bnd:cert_bnd:val_bnds) val_res)
pred_fun = FunDec predFname (basicRetType Bool)
@@ -78,7 +74,7 @@ splitBody :: Body -> GenM (Body, Body)
splitBody (Body _ bnds valres) = do
(pred_bnds, val_bnds, preds) <- unzip3 <$> mapM splitBinding bnds
(conjoined_preds, conj_bnds) <-
runBinder'' $ letSubExp "conjPreds" =<<
runBinderEmptyEnv $ letSubExp "conjPreds" =<<
foldBinOp LogAnd (constant True) (catMaybes preds) Bool
let predbody = mkBody (concat pred_bnds <> conj_bnds) $
valres { resultSubExps =
@@ -93,10 +89,10 @@ splitBinding bnd@(Let pat _ (PrimOp (Assert (Var v) _))) = do
GenEnv cert_ident deps blacklist <- ask
let forbidden =
not $ HS.null $ maybe HS.empty (HS.intersection blacklist) $
HM.lookup (identName v) deps
HM.lookup v deps
return $ if forbidden then ([bnd], bnd, Nothing)
else ([bnd],
mkLet' (patternIdents pat) $ PrimOp $ SubExp (Var cert_ident),
mkLet' (patternIdents pat) $ PrimOp $ SubExp (Var $ identName cert_ident),
Just $ Var v)

splitBinding bnd@(Let pat _ (LoopOp (Map cs fun args))) = do
@@ -130,22 +126,22 @@ splitBinding bnd@(Let pat _ (LoopOp (Redomap cs outerfun innerfun acc arr))) = d
splitBinding (Let pat _ (LoopOp (DoLoop respat merge form body))) = do
(predbody, valbody) <- splitBody body
ok <- newIdent "loop_ok" (Basic Bool)
predbody' <- conjoinLoopBody ok predbody
let predloop = LoopOp $ DoLoop (respat++[ok])
predbody' <- conjoinLoopBody (identName ok) predbody
let predloop = LoopOp $ DoLoop (respat++[identName ok])
(merge++[(FParam ok (),constant True)]) form
predbody'
valloop = LoopOp $ DoLoop respat merge form valbody
return ([mkLet' (idents<>[ok]) predloop],
mkLet' idents valloop,
Just $ Var ok)
Just $ Var $ identName ok)
where
idents = patternIdents pat
conjoinLoopBody ok (Body _ bnds res) = do
ok' <- newIdent "loop_ok_res" (Basic Bool)
case reverse $ resultSubExps res of
[] -> fail "conjoinLoopBody: null loop"
x:xs ->
let res' = res { resultSubExps = reverse $ Var ok':xs }
let res' = res { resultSubExps = reverse $ Var (identName ok'):xs }
bnds' = bnds ++
[mkLet' [ok'] $ PrimOp $ BinOp LogAnd x (Var ok) Bool]
in return $ mkBody bnds' res'
@@ -158,12 +154,12 @@ splitBinding (Let pat _ (If cond tbranch fbranch t)) = do
If cond tbranch_pred fbranch_pred
(t<>[Basic Bool])],
mkLet' idents $ If cond tbranch_val fbranch_val t,
Just $ Var ok)
Just $ Var $ identName ok)
where idents = patternIdents pat

splitBinding bnd = return ([bnd], bnd, Nothing)

splitMap :: [Ident] -> Lambda -> [Ident]
splitMap :: Certificates -> Lambda -> [VName]
-> GenM ([Binding], Lambda, Maybe SubExp)
splitMap cs fun args = do
(predfun, valfun) <- splitMapLambda fun
@@ -172,7 +168,7 @@ splitMap cs fun args = do
valfun,
Just andcheck)

splitReduce :: [Ident] -> Lambda -> [(SubExp,Ident)]
splitReduce :: Certificates -> Lambda -> [(SubExp,VName)]
-> GenM ([Binding], Lambda, Maybe SubExp)
splitReduce cs fun args = do
(predfun, valfun) <- splitFoldLambda fun $ map fst args
@@ -181,7 +177,7 @@ splitReduce cs fun args = do
valfun,
Just andcheck)

splitRedomap :: [Ident] -> Lambda -> [SubExp] -> [Ident]
splitRedomap :: Certificates -> Lambda -> [SubExp] -> [VName]
-> GenM ([Binding], Lambda, Maybe SubExp)
splitRedomap cs fun acc arr = do
(predfun, valfun) <- splitFoldLambda fun acc
@@ -219,7 +215,7 @@ splitFoldLambda lam acc = do
accbnds = [ mkLet' [p] $ PrimOp $ SubExp e
| (p,e) <- zip accParams acc ]

allTrue :: Certificates -> Lambda -> [Ident]
allTrue :: Certificates -> Lambda -> [VName]
-> GenM (Binding, SubExp)
allTrue cs predfun args = do
andchecks <- newIdent "allTrue" (Basic Bool)
@@ -228,13 +224,13 @@ allTrue cs predfun args = do
let andbnd = mkLet' [andchecks] $ LoopOp $
Redomap cs andfun innerfun [constant True] args
return (andbnd,
Var andchecks)
Var $ identName andchecks)
where predConjFun = do
acc <- newIdent "acc" (Basic Bool)
res <- newIdent "res" (Basic Bool)
let Body _ predbnds (Result [se]) = lambdaBody predfun -- XXX
andbnd = mkLet' [res] $ PrimOp $ BinOp LogAnd (Var acc) se Bool
body = mkBody (predbnds++[andbnd]) $ Result [Var res]
andbnd = mkLet' [res] $ PrimOp $ BinOp LogAnd (Var $ identName acc) se Bool
body = mkBody (predbnds++[andbnd]) $ Result [Var $ identName res]
return Lambda { lambdaParams = acc : lambdaParams predfun
, lambdaReturnType = [Basic Bool]
, lambdaBody = body
@@ -73,7 +73,7 @@ insertPredicateCalls subst prog =
return $ bnds ++ [Let pat () e']
treatExp e@(Apply predf predargs predt)
| Just preds <- HM.lookup predf subst =
runBinder'' $ callPreds predt preds e $ \predf' ->
runBinderEmptyEnv $ callPreds predt preds e $ \predf' ->
Apply predf' predargs predt
treatExp e = do
e' <- mapExpM mapper e
@@ -148,16 +148,17 @@ analyseBody vtable sctable (Body bodylore (bnd@(Let (Pattern [patElem]) _ e):bnd
sctable' = case (analyseExp vtable e,
simplify <$> ST.lookupScalExp name vtable') of
(Nothing, Just (Right se@(SE.RelExp SE.LTH0 ine)))
| Int <- SE.scalExpType ine ->
case AS.mkSuffConds se ranges of
| Int <- runReader (SE.scalExpType ine) types ->
case AS.mkSuffConds se ranges types of
Left err -> error $ show err -- Why can this even fail?
Right ses -> HM.insert name (SufficientCond ses) sctable
(Just eSCTable, _) -> sctable <> eSCTable
_ -> sctable
in analyseBody vtable' sctable' $ Body bodylore bnds res
where name = patElemName patElem
ranges = rangesRep vtable
simplify se = AS.simplify se ranges
types = ST.typeEnv vtable
simplify se = AS.simplify se ranges undefined
analyseBody vtable sctable (Body bodylore (bnd:bnds) res) =
analyseBody (ST.insertBinding bnd vtable) sctable $ Body bodylore bnds res

@@ -172,9 +173,9 @@ analyseExp :: ST.SymbolTable Basic -> Exp -> Maybe SCTable
analyseExp vtable (LoopOp (DoLoop _ _ (ForLoop i bound) body)) =
Just $ analyseExpBody vtable' body
where vtable' = clampLower $ clampUpper vtable
clampUpper = ST.insertLoopVar (identName i) bound
clampUpper = ST.insertLoopVar i bound
-- If we enter the loop, then 'bound' is at least one.
clampLower = case bound of Var v -> identName v `ST.isAtLeast` 1
clampLower = case bound of Var v -> v `ST.isAtLeast` 1
Constant {} -> id
analyseExp vtable (LoopOp (DoLoop _ _ _ body)) =
Just $ analyseExpBody vtable body
@@ -265,6 +266,8 @@ instance MonadFreshNames m => MonadFreshNames (VariantM m) where
getNameSource = VariantM . lift $ getNameSource
putNameSource = VariantM . lift . putNameSource

instance (Functor m, Monad m) => HasTypeEnv (VariantM m) where

runVariantM :: (Functor m, Monad m) =>
Env m -> VariantM m a -> m (a, Bool)
runVariantM env (VariantM m) =
@@ -337,7 +340,8 @@ instance MonadFreshNames m =>
suffe <- generating Sufficient $
Simplify.simplifyExp =<< renameExp (removeExpWisdom e)
let pat' = pat { patternElements =
zipWith tagPatElem (patternElements pat) vs
zipWith tagPatElem (patternElements pat) $
map identName vs
}
tagPatElem patElem v =
patElem `setPatElemLore` (fst $ patElemLore patElem, Just v)
@@ -346,7 +350,7 @@ instance MonadFreshNames m =>
Simplify.defaultInspectBinding $ Let pat' lore e

simplifyLetBoundLore Nothing = return Nothing
simplifyLetBoundLore (Just v) = Just <$> Simplify.simplifyIdent v
simplifyLetBoundLore (Just v) = Just <$> Simplify.simplifyVName v

simplifyFParamLore =
return
@@ -363,28 +367,27 @@ makeSufficientBinding' :: MonadFreshNames m => Context m -> S.Binding Invariance
makeSufficientBinding' context@(_,vtable) (Let pat _ e)
| Just (Right se@(SE.RelExp SE.LTH0 ine)) <-
simplify <$> SE.toScalExp (`suffScalExp` vtable) e,
Int <- SE.scalExpType ine,
Right suff <- AS.mkSuffConds se ranges,
Int <- runReader (SE.scalExpType ine) types,
Right suff <- AS.mkSuffConds se ranges types,
x:xs <- filter (scalExpUsesNoForbidden context) $ map mkConj suff = do
suffe <- SE.fromScalExp' $ foldl SE.SLogOr x xs
letBind_ pat suffe
where ranges = rangesRep vtable
simplify se = AS.simplify se ranges
types = ST.typeEnv vtable
simplify se = AS.simplify se ranges types
mkConj [] = SE.Val $ LogVal True
mkConj (x:xs) = foldl SE.SLogAnd x xs
makeSufficientBinding' _ (Let pat _ (PrimOp (BinOp LogAnd x y t))) = do
x' <- sufficientSubExp x
y' <- sufficientSubExp y
letBind_ pat $ PrimOp $ BinOp LogAnd x' y' t
makeSufficientBinding' env (Let pat _ (If (Var v) tbranch fbranch _))
| identName v `forbiddenIn` env,
makeSufficientBinding' env (Let pat _ (If (Var v) tbranch fbranch [Basic Bool]))
| v `forbiddenIn` env,
-- FIXME: Check that tbranch and fbranch are safe. We can do
-- something smarter if 'v' actually comes from an 'or'. Also,
-- currently only handles case where pat is a singleton boolean.
Body _ tbnds (Result [tres]) <- tbranch,
Body _ fbnds (Result [fres]) <- fbranch,
Basic Bool <- subExpType tres,
Basic Bool <- subExpType fres,
all safeBnd tbnds, all safeBnd fbnds = do
mapM_ addBinding tbnds
mapM_ addBinding fbnds
@@ -397,20 +400,20 @@ suffScalExp :: VName -> ST.SymbolTable Invariance -> Maybe ScalExp
suffScalExp name vtable = asSuffScalExp =<< ST.lookup name vtable
where asSuffScalExp entry
| Just (_, Just suff) <- ST.entryLetBoundLore entry,
Just se <- suffScalExp (identName suff) vtable =
Just se <- suffScalExp suff vtable =
Just se
| otherwise = ST.asScalExp entry

sufficientSubExp :: MonadFreshNames m => SubExp -> VariantM m SubExp
sufficientSubExp se@(Constant {}) = return se
sufficientSubExp (Var v) =
maybe (Var v) Var .
(snd <=< ST.entryLetBoundLore <=< ST.lookup (identName v)) <$>
(snd <=< ST.entryLetBoundLore <=< ST.lookup v) <$>
Simplify.getVtable

scalExpUsesNoForbidden :: Context m -> ScalExp -> Bool
scalExpUsesNoForbidden context =
not . any (`forbiddenIn` context) . freeNamesIn
not . any (`forbiddenIn` context) . freeIn

-- | The lore containing invariance information.
data Variance = Invariant
@@ -430,7 +433,7 @@ isForbidden TooVariant = True

data Invariance' = Invariance'
instance Lore.Lore Invariance' where
type LetBound Invariance' = Maybe Ident
type LetBound Invariance' = Maybe VName
type Exp Invariance' = Variance
representative = Invariance'
loopResultContext _ = loopResultContext (representative :: Basic)
@@ -456,22 +459,22 @@ forbiddenExp context = isNothing . walkExpM walk
where walk = Walker { walkOnSubExp = checkIf forbiddenSubExp
, walkOnBody = checkIf forbiddenBody
, walkOnBinding = checkIf $ isForbidden . snd . bindingLore
, walkOnIdent = checkIf forbiddenIdent
, walkOnVName = checkIf forbiddenVar
, walkOnLambda = checkIf $ forbiddenBody . lambdaBody
, walkOnRetType = checkIf forbiddenRetType
, walkOnFParam = checkIf forbiddenFParam
, walkOnCertificates = mapM_ $ checkIf forbiddenIdent
, walkOnCertificates = mapM_ $ checkIf forbiddenVar
}
checkIf f x = if f x
then Nothing
else Just ()

forbiddenIdent = (`forbiddenIn` context) . identName
forbiddenVar = (`forbiddenIn` context)

forbiddenSubExp (Var v) = identName v `forbiddenIn` context
forbiddenSubExp (Var v) = v `forbiddenIn` context
forbiddenSubExp (Constant {}) = False

forbiddenBody = any (isForbidden . snd . bindingLore) . bodyBindings

forbiddenRetType = any forbiddenIdent . freeIn
forbiddenFParam = any forbiddenIdent . freeIn
forbiddenRetType = any forbiddenVar . freeIn
forbiddenFParam = any forbiddenVar . freeIn