Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove base from index functions #2083

Merged
merged 9 commits into from
Jan 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
83 changes: 32 additions & 51 deletions src/Futhark/IR/Mem/IxFun.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import Control.Monad.State
import Data.Map.Strict qualified as M
import Data.Traversable
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem.LMAD hiding
( equivalent,
flatSlice,
Expand Down Expand Up @@ -74,20 +73,11 @@ import Prelude hiding (gcd, id, mod, (.))
-- distinguish row-major and column-major representations.
--
-- An index function is represented as an LMAD.
data IxFun num = IxFun
{ ixfunLMAD :: LMAD num,
-- | the shape of the support array, i.e., the original array
-- that birthed (is the start point) of this index function.
base :: Shape num
}
newtype IxFun num = IxFun {ixfunLMAD :: LMAD num}
deriving (Show, Eq)

instance (Pretty num) => Pretty (IxFun num) where
pretty (IxFun lmad oshp) =
braces . semistack $
[ "base:" <+> brackets (commasep $ map pretty oshp),
"LMAD:" <+> pretty lmad
]
pretty (IxFun lmad) = pretty lmad

instance (Substitute num) => Substitute (IxFun num) where
substituteNames substs = fmap $ substituteNames substs
Expand All @@ -107,31 +97,26 @@ instance Foldable IxFun where
-- It is important that the traversal order here is the same as in
-- mkExistential.
instance Traversable IxFun where
traverse f (IxFun lmad oshp) =
IxFun <$> traverse f lmad <*> traverse f oshp
traverse f (IxFun lmad) =
IxFun <$> traverse f lmad

-- | Substitute a name with a PrimExp in an index function.
substituteInIxFun ::
(Ord a) =>
M.Map a (TPrimExp t a) ->
IxFun (TPrimExp t a) ->
IxFun (TPrimExp t a)
substituteInIxFun tab (IxFun lmad oshp) =
IxFun
(substituteInLMAD tab lmad)
(map (TPrimExp . substituteInPrimExp tab' . untyped) oshp)
where
tab' = fmap untyped tab
substituteInIxFun tab (IxFun lmad) =
IxFun $ substituteInLMAD tab lmad

-- | Is this is a row-major array?
isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool
isDirect (IxFun (LMAD offset dims) oshp) =
let strides_expected = reverse $ scanl (*) 1 (reverse (tail oshp))
in length oshp == length dims
&& offset == 0
isDirect (IxFun lmad@(LMAD offset dims)) =
let strides_expected = reverse $ scanl (*) 1 $ reverse $ tail $ LMAD.shape lmad
in offset == 0
&& all
(\(LMADDim s n, d, se) -> s == se && n == d)
(zip3 dims oshp strides_expected)
(\(LMADDim s _, se) -> s == se)
(zip dims strides_expected)

-- | The index space of the index function. This is the same as the
-- shape of arrays that the index function supports.
Expand All @@ -149,48 +134,46 @@ index = LMAD.index . ixfunLMAD

-- | iota with offset.
iotaOffset :: (IntegralExp num) => num -> Shape num -> IxFun num
iotaOffset o ns = IxFun (LMAD.iota o ns) ns
iotaOffset o ns = IxFun $ LMAD.iota o ns

-- | iota.
iota :: (IntegralExp num) => Shape num -> IxFun num
iota = iotaOffset 0

-- | Create a single-LMAD index function that is existential in
-- everything except shape, with the provided shape.
mkExistential :: Int -> Shape (Ext a) -> Int -> IxFun (Ext a)
mkExistential basis_rank lmad_shape start =
IxFun (LMAD.mkExistential lmad_shape start) basis
where
basis = take basis_rank $ map Ext [start + 1 + length lmad_shape ..]
mkExistential :: Shape (Ext a) -> Int -> IxFun (Ext a)
mkExistential lmad_shape start =
IxFun (LMAD.mkExistential lmad_shape start)

-- | Permute dimensions.
permute ::
(IntegralExp num) =>
IxFun num ->
Permutation ->
IxFun num
permute (IxFun lmad oshp) perm_new =
IxFun (LMAD.permute lmad perm_new) oshp
permute (IxFun lmad) perm_new =
IxFun (LMAD.permute lmad perm_new)

-- | Slice an index function.
slice ::
(Eq num, IntegralExp num) =>
IxFun num ->
Slice num ->
IxFun num
slice ixfun@(IxFun lmad@(LMAD _ _) oshp) (Slice is)
slice ixfun@(IxFun lmad@(LMAD _ _)) (Slice is)
-- Avoid identity slicing.
| is == map (unitSlice 0) (shape ixfun) = ixfun
| otherwise =
IxFun (LMAD.slice lmad (Slice is)) oshp
IxFun (LMAD.slice lmad (Slice is))

-- | Flat-slice an index function.
flatSlice ::
(Eq num, IntegralExp num) =>
IxFun num ->
FlatSlice num ->
IxFun num
flatSlice (IxFun lmad oshp) s = IxFun (LMAD.flatSlice lmad s) oshp
flatSlice (IxFun lmad) s = IxFun (LMAD.flatSlice lmad s)

-- | Reshape an index function.
--
Expand All @@ -211,8 +194,8 @@ reshape ::
IxFun num ->
Shape num ->
Maybe (IxFun num)
reshape (IxFun lmad _) new_shape =
IxFun <$> LMAD.reshape lmad new_shape <*> pure new_shape
reshape (IxFun lmad) new_shape =
IxFun <$> LMAD.reshape lmad new_shape

-- | Coerce an index function to look like it has a new shape.
-- Dynamically the shape must be the same.
Expand All @@ -221,51 +204,50 @@ coerce ::
IxFun num ->
Shape num ->
IxFun num
coerce (IxFun lmad _) new_shape =
IxFun (onLMAD lmad) new_shape
coerce (IxFun lmad) new_shape =
IxFun (onLMAD lmad)
where
onLMAD (LMAD offset dims) = LMAD offset $ zipWith onDim dims new_shape
onDim ld d = ld {ldShape = d}

-- | The number of dimensions in the domain of the input function.
rank :: (IntegralExp num) => IxFun num -> Int
rank (IxFun (LMAD _ sss) _) = length sss
rank (IxFun (LMAD _ sss)) = length sss

-- | Conceptually expand index function to be a particular slice of
-- another by adjusting the offset and strides. Used for memory
-- expansion.
expand ::
(Eq num, IntegralExp num) => num -> num -> IxFun num -> Maybe (IxFun num)
expand o p (IxFun lmad base) =
expand o p (IxFun lmad) =
let onDim ld = ld {LMAD.ldStride = p * LMAD.ldStride ld}
lmad' =
LMAD
(o + p * LMAD.offset lmad)
(map onDim (LMAD.dims lmad))
in Just $ IxFun lmad' base
in Just $ IxFun lmad'

-- | Turn all the leaves of the index function into 'Ext's, except for
-- the shape, which where the leaves are simply made 'Free'.
existentialize ::
Int ->
IxFun (TPrimExp Int64 a) ->
IxFun (TPrimExp Int64 (Ext a))
existentialize start (IxFun lmad base) = evalState (IxFun <$> lmad' <*> base') start
existentialize start (IxFun lmad) = evalState (IxFun <$> lmad') start
where
mkExt = do
i <- get
put $ i + 1
pure $ TPrimExp $ LeafExp (Ext i) int64
lmad' = LMAD <$> mkExt <*> mapM onDim (dims lmad)
base' = traverse (const mkExt) base
onDim ld = LMADDim <$> mkExt <*> pure (fmap Free (ldShape ld))

-- | Retrieve those elements that 'existentialize' changes. That is,
-- everything except the shape (and in the same order as
-- 'existentialise' existentialises them).
existentialized :: IxFun a -> [a]
existentialized (IxFun (LMAD offset dims) base) =
offset : concatMap onDim dims <> base
existentialized (IxFun (LMAD offset dims)) =
offset : concatMap onDim dims
where
onDim (LMADDim ldstride _) = [ldstride]

Expand All @@ -281,13 +263,12 @@ existentialized (IxFun (LMAD offset dims) base) =
-- this instead of `ixfun1 == ixfun2` and hope that it's good enough.
closeEnough :: IxFun num -> IxFun num -> Bool
closeEnough ixf1 ixf2 =
(length (base ixf1) == length (base ixf2))
&& closeEnoughLMADs (ixfunLMAD ixf1) (ixfunLMAD ixf2)
closeEnoughLMADs (ixfunLMAD ixf1) (ixfunLMAD ixf2)
where
closeEnoughLMADs lmad1 lmad2 =
length (LMAD.dims lmad1) == length (LMAD.dims lmad2)

-- | The largest possible linear address reachable by this index
-- function.
-- function, not counting the offset.
range :: (Pretty num) => IxFun (TPrimExp Int64 num) -> TPrimExp Int64 num
range = LMAD.range . ixfunLMAD
16 changes: 8 additions & 8 deletions src/Futhark/IR/Mem/LMAD.hs
Original file line number Diff line number Diff line change
Expand Up @@ -544,16 +544,16 @@ isDirect :: (Eq num, IntegralExp num) => LMAD num -> Bool
isDirect lmad = lmad == iota 0 (map ldShape $ dims lmad)
{-# NOINLINE isDirect #-}

-- | The largest possible linear address reachable by this LMAD. If
-- you add one to this number (and multiply it with the element size),
-- you get the amount of bytes you need to allocate for an array with
-- this LMAD.
-- | The largest possible linear address reachable by this LMAD, not
-- counting the offset. If you add one to this number (and multiply it
-- with the element size), you get the amount of bytes you need to
-- allocate for an array with this LMAD (assuming zero offset).
range :: (Pretty num) => LMAD (TPrimExp Int64 num) -> TPrimExp Int64 num
range lmad =
-- The idea is that the largest possible offset must be the offset
-- plus the sum of the maximum offsets reachable in each dimension,
-- which must be at either the minimum or maximum index.
offset lmad + sum (map dimRange $ dims lmad)
-- The idea is that the largest possible offset must be the sum of
-- the maximum offsets reachable in each dimension, which must be at
-- either the minimum or maximum index.
sum (map dimRange $ dims lmad)
where
dimRange LMADDim {ldStride, ldShape} =
0 `sMax64` ((0 `sMax64` (ldShape - 1)) * ldStride)
Expand Down
21 changes: 4 additions & 17 deletions src/Futhark/IR/Mem/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ where

import Control.Monad
import Data.List (find)
import Data.Maybe (isJust)
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR.Mem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.Prop.Aliases (AliasedOp)
import Futhark.Optimise.Simplify qualified as Simplify
import Futhark.Optimise.Simplify.Engine qualified as Engine
Expand Down Expand Up @@ -128,7 +128,7 @@ memRuleBook =
unExistentialiseMemory :: (SimplifyMemory rep inner) => TopDownRuleMatch (Wise rep)
unExistentialiseMemory vtable pat _ (cond, cases, defbody, ifdec)
| ST.simplifyMemory vtable,
fixable <- foldl hasConcretisableMemory mempty $ zip [0 ..] $ patElems pat,
fixable <- foldl hasConcretisableMemory mempty $ patElems pat,
not $ null fixable = Simplify $ do
-- Create non-existential memory blocks big enough to hold the
-- arrays.
Expand Down Expand Up @@ -167,7 +167,7 @@ unExistentialiseMemory vtable pat _ (cond, cases, defbody, ifdec)
knownSize (Var v) = not $ inContext v
inContext = (`elem` patNames pat)

hasConcretisableMemory fixable (i, pat_elem)
hasConcretisableMemory fixable pat_elem
| (_, MemArray pt shape _ (ArrayIn mem ixfun)) <- patElemDec pat_elem,
Just (j, Mem space) <-
fmap patElemType
Expand All @@ -180,24 +180,11 @@ unExistentialiseMemory vtable pat _ (cond, cases, defbody, ifdec)
all knownSize (shapeDims shape),
not $ freeIn ixfun `namesIntersect` namesFromList (patNames pat),
any (defbody_se /=) cases_ses,
all (notIndex i) (defbody : map caseBody cases) =
LMAD.offset (IxFun.ixfunLMAD ixfun) == 0 =
let mem_size = untyped $ primByteSize pt * (1 + IxFun.range ixfun)
in (pat_elem, mem_size, mem, space) : fixable
| otherwise =
fixable

-- Check if the i'th result is not an index; see #1325. This is a
-- rather crude check, but this is also a rather crude
-- simplification rule. We only need to keep it around until
-- memory expansion can handle existential memory generally.
notIndex i body
| Just (SubExpRes _ (Var v)) <- maybeNth (i :: Int) $ bodyResult body =
not $ any (bad v) $ bodyStms body
where
bad v (Let index_pat _ (BasicOp (Index _ slice))) =
(v `elem` patNames index_pat) && any (isJust . dimFix) (unSlice slice)
bad _ _ = False
notIndex _ _ = True
unExistentialiseMemory _ _ _ _ = Skip

-- If an allocation is statically known to be safe, then we can remove
Expand Down
5 changes: 1 addition & 4 deletions src/Futhark/IR/Parse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -963,10 +963,7 @@ pMCOp pr pOther =

pIxFunBase :: Parser a -> Parser (IxFun.IxFun a)
pIxFunBase pNum =
braces $ do
base <- pLab "base" $ brackets (pNum `sepBy` pComma) <* pSemi
lmad <- pLab "LMAD" pLMAD
pure $ IxFun.IxFun lmad base
IxFun.IxFun <$> pLMAD
where
pLab s m = keyword s *> pColon *> m
pLMAD = braces $ do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1257,7 +1257,7 @@ mkCoalsTabStm lutab stm@(Let pat _ e) td_env bu_env = do
_ -> (failed, s_acc) -- fail!

ixfunToAccessSummary :: IxFun.IxFun (TPrimExp Int64 VName) -> AccessSummary
ixfunToAccessSummary (IxFun.IxFun lmad _) = Set $ S.singleton lmad
ixfunToAccessSummary (IxFun.IxFun lmad) = Set $ S.singleton lmad

-- | Check safety conditions 2 and 5 and update new substitutions:
-- called on the pat-elements of loop and if-then-else expressions.
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/Optimise/ArrayShortCircuiting/MemRefAggreg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ recordMemRefUses td_env bu_env stm =
<> fromMaybe mempty (M.lookup m (m_alias td_env))
mbLmad indfun
| Just subs <- freeVarSubstitutions (scope td_env) (scals bu_env) indfun,
(IxFun.IxFun lmad _) <- IxFun.substituteInIxFun subs indfun =
(IxFun.IxFun lmad) <- IxFun.substituteInIxFun subs indfun =
Just lmad
mbLmad _ = Nothing
addLmads wrts uses etry =
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/Optimise/BlkRegTiling.hs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ kkLoopBody
| [slc_X'] <- patNames pat,
slc_X == slc_X',
Just ixf_fn <- M.lookup x ixfn_env,
(IxFun.IxFun lmad _) <- ixf_fn =
(IxFun.IxFun lmad) <- ixf_fn =
innerHasStride1 lmad
isInnerCoal _ _ _ =
error "kkLoopBody.isInnerCoal: not an error, but I would like to know why!"
Expand Down