Skip to content

Commit

Permalink
Compile from PartialSchedule to UniformSchedule
Browse files Browse the repository at this point in the history
  • Loading branch information
ivogabe committed Aug 31, 2021
1 parent b465b1a commit 45cb8f4
Show file tree
Hide file tree
Showing 8 changed files with 915 additions and 104 deletions.
6 changes: 5 additions & 1 deletion src/Data/Array/Accelerate/AST/Environment.hs
Expand Up @@ -22,7 +22,7 @@ module Data.Array.Accelerate.AST.Environment (
unionPartialEnv, EnvBinding(..), partialEnvFromList, mapPartialEnv,
mapMaybePartialEnv, partialEnvValues, diffPartialEnv, diffPartialEnvWith,
intersectPartialEnv, partialEnvTail, partialEnvLast, partialEnvSkip,
partialUpdate, partialEnvToList,
partialUpdate, partialEnvToList, partialEnvSingleton,

prjUpdate', prjReplace', update', updates', mapEnv,
Identity(..), (:>)(..), weakenId, weakenSucc, weakenSucc', weakenEmpty,
Expand Down Expand Up @@ -179,6 +179,10 @@ partialEnvValues PEnd = []
partialEnvValues (PNone env) = partialEnvValues env
partialEnvValues (PPush env (IdentityF a)) = a : partialEnvValues env

partialEnvSingleton :: Idx env t -> f t -> PartialEnv f env
partialEnvSingleton ZeroIdx v = PPush PEnd v
partialEnvSingleton (SuccIdx idx) v = PNone $ partialEnvSingleton idx v

-- Wrapper to put homogenous types in an Env or PartialEnv
newtype IdentityF t f = IdentityF t

Expand Down
21 changes: 19 additions & 2 deletions src/Data/Array/Accelerate/AST/IdxSet.hs
Expand Up @@ -18,13 +18,15 @@

module Data.Array.Accelerate.AST.IdxSet (
IdxSet,
member, intersect, union, insert, skip,
push, empty, drop, drop', fromList
member, varMember, intersect, union, insert, insertVar, skip,
push, empty, drop, drop', fromList, fromVarList,
singleton, singletonVar,
) where

import Prelude hiding (drop)

import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.AST.Environment hiding ( push )
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Maybe
Expand All @@ -36,6 +38,9 @@ data Present a = Present
member :: Idx env t -> IdxSet env -> Bool
member idx (IdxSet set) = isJust $ prjPartial idx set

varMember :: Var s env t -> IdxSet env -> Bool
varMember (Var _ idx) = member idx

intersect :: IdxSet env -> IdxSet env -> IdxSet env
intersect (IdxSet a) (IdxSet b) = IdxSet $ intersectPartialEnv (\_ _ -> Present) a b

Expand All @@ -45,6 +50,9 @@ union (IdxSet a) (IdxSet b) = IdxSet $ unionPartialEnv (\_ _ -> Present) a b
insert :: Idx env t -> IdxSet env -> IdxSet env
insert idx (IdxSet a) = IdxSet $ partialUpdate Present idx a

insertVar :: Var s env t -> IdxSet env -> IdxSet env
insertVar (Var _ idx) = insert idx

skip :: IdxSet env -> IdxSet (env, t)
skip = IdxSet . PNone . unIdxSet

Expand All @@ -67,3 +75,12 @@ toList = map (\(EnvBinding idx _) -> Exists idx) . partialEnvToList . unIdxSet

fromList :: [Exists (Idx env)] -> IdxSet env
fromList = IdxSet . partialEnvFromList (\_ _ -> Present) . map (\(Exists idx) -> EnvBinding idx Present)

fromVarList :: [Exists (Var s env)] -> IdxSet env
fromVarList = fromList . map (\(Exists (Var _ idx)) -> Exists idx)

singleton :: Idx env t -> IdxSet env
singleton idx = IdxSet $ partialEnvSingleton idx Present

singletonVar :: Var s env t -> IdxSet env
singletonVar (Var _ idx) = singleton idx
7 changes: 5 additions & 2 deletions src/Data/Array/Accelerate/AST/Operation.hs
Expand Up @@ -67,7 +67,7 @@ import Data.Array.Accelerate.Trafo.Exp.Substitution
import Data.Array.Accelerate.Trafo.Exp.Shrink
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Error
import Data.Typeable ( (:~:)(..) )
import Data.Typeable ( (:~:)(..) )

import Data.ByteString.Builder.Extra
import Language.Haskell.TH ( Q, TExp )
Expand Down Expand Up @@ -385,7 +385,10 @@ paramsIn' (TupRsingle v) = ArrayInstr (Parameter v) Nil

type ReindexPartial f env env' = forall a. Idx env a -> f (Idx env' a)

reindexVar :: Applicative f => ReindexPartial f env env' -> Var s env t -> f (Var s env' t)
-- The first argument is ReindexPartial, but without the forall and specialized to 't'.
-- This makes it usable in more situations.
--
reindexVar :: Applicative f => (Idx env t -> f (Idx env' t)) -> Var s env t -> f (Var s env' t)
reindexVar k (Var repr ix) = Var repr <$> k ix

reindexVars :: Applicative f => ReindexPartial f env env' -> Vars s env t -> f (Vars s env' t)
Expand Down
161 changes: 144 additions & 17 deletions src/Data/Array/Accelerate/AST/Schedule/Uniform.hs
@@ -1,8 +1,11 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module : Data.Array.Accelerate.Schedule.Uniform
Expand All @@ -16,26 +19,35 @@

module Data.Array.Accelerate.AST.Schedule.Uniform (
UniformSchedule(..), UniformScheduleFun(..),
Input, Output, InputOutputR(..),
Input, Output, inputSingle, inputR, outputR, InputOutputR(..),
ScheduleFunction, scheduleFunctionIsBody,
Binding(..), Effect(..),
BaseR(..), BasesR, BaseVar, BaseVars, BLeftHandSide,
Signal(..), SignalResolver(..), Ref(..), OutputRef(..),
module Partitioned,
await, resolve,
signalResolverImpossible, scalarSignalResolverImpossible,

-- ** Free variables
freeVars, funFreeVars, effectFreeVars, bindingFreeVars,
) where

import Data.Array.Accelerate.AST.Exp
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.IdxSet ( IdxSet )
import qualified Data.Array.Accelerate.AST.IdxSet as IdxSet
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.AST.Partitioned as Partitioned hiding (PartitionedAcc, PartitionedAfun, PreOpenAcc(..), PreOpenAfun(..))
import Data.Array.Accelerate.Trafo.Exp.Substitution
import Data.Array.Accelerate.Trafo.Operation.Substitution
import Control.Concurrent.MVar
import Data.IORef
import Data.Typeable ( (:~:)(..) )

-- Generic schedule for a uniform memory space and uniform scheduling.
-- E.g., we don't have host and device memory or scheduling.
Expand All @@ -60,9 +72,11 @@ data UniformSchedule exe env where
-> UniformSchedule exe env -- Operations after the if-then-else
-> UniformSchedule exe env

-- The step function of the loop outputs a bool to denote whether the loop should
-- proceed. If true, then the other output values should also be filled, possibly at
-- a later point in time. If it is false, then no other output values may be filled.
Awhile :: InputOutputR input output
-> UniformScheduleFun exe env (input -> Output PrimBool -> ())
-> UniformScheduleFun exe env (input -> output -> ())
-> UniformScheduleFun exe env (input -> (Output PrimBool, output) -> ())
-> BaseVars env input
-> UniformSchedule exe env -- Operations after the while loop
-> UniformSchedule exe env
Expand All @@ -83,23 +97,91 @@ data UniformScheduleFun exe env t where
-> UniformScheduleFun exe env ()

type family Input t where
Input () = ()
Input (a, b) = (Input a, Input b)
Input t = (Signal, Ref t)
Input () = ()
Input (a, b) = (Input a, Input b)
Input t = (Signal, Ref t)

inputSingle :: forall t. GroundR t -> (Input t, Output t) :~: ((Signal, Ref t), (SignalResolver, OutputRef t))
-- Last case of type family Input and Output.
-- We must pattern match to convince the type checker that
-- t is not () or (a, b).
inputSingle (GroundRbuffer _) = Refl
inputSingle (GroundRscalar (VectorScalarType _)) = Refl
inputSingle (GroundRscalar (SingleScalarType (NumSingleType tp))) = case tp of
IntegralNumType TypeInt -> Refl
IntegralNumType TypeInt8 -> Refl
IntegralNumType TypeInt16 -> Refl
IntegralNumType TypeInt32 -> Refl
IntegralNumType TypeInt64 -> Refl
IntegralNumType TypeWord -> Refl
IntegralNumType TypeWord8 -> Refl
IntegralNumType TypeWord16 -> Refl
IntegralNumType TypeWord32 -> Refl
IntegralNumType TypeWord64 -> Refl
FloatingNumType TypeHalf -> Refl
FloatingNumType TypeFloat -> Refl
FloatingNumType TypeDouble -> Refl

inputR :: forall t. GroundsR t -> BasesR (Input t)
inputR TupRunit = TupRunit
inputR (TupRpair t1 t2) = TupRpair (inputR t1) (inputR t2)
inputR (TupRsingle ground)
-- Last case of type family Input.
-- We must pattern match to convince the type checker that
-- t is not () or (a, b).
| Refl <- inputSingle ground = TupRsingle BaseRsignal `TupRpair` TupRsingle (BaseRref ground)

type family Output t where
Output () = ()
Output (a, b) = (Output a, Output b)
Output t = (SignalResolver, OutputRef t)

outputR :: GroundsR t -> BasesR (Output t)
outputR TupRunit = TupRunit
outputR (TupRpair t1 t2) = TupRpair (outputR t1) (outputR t2)
outputR (TupRsingle ground)
-- Last case of type family Output.
-- We must pattern match to convince the type checker that
-- t is not () or (a, b).
| Refl <- inputSingle ground = TupRsingle BaseRsignalResolver `TupRpair` TupRsingle (BaseRrefWrite ground)

type family ScheduleFunction t where
ScheduleFunction (t1 -> t2) = Input t1 -> ScheduleFunction t2
ScheduleFunction t = Output t -> ()

-- Pattern match to convince the type checker that t is not a function.
scheduleFunctionIsBody :: GroundsR t -> ScheduleFunction t :~: (Output t -> ())
scheduleFunctionIsBody TupRunit = Refl
scheduleFunctionIsBody TupRpair{} = Refl
scheduleFunctionIsBody (TupRsingle (GroundRbuffer _)) = Refl
scheduleFunctionIsBody (TupRsingle (GroundRscalar tp))
| VectorScalarType _ <- tp = Refl
| SingleScalarType (NumSingleType tp') <- tp = case tp' of
IntegralNumType TypeInt -> Refl
IntegralNumType TypeInt8 -> Refl
IntegralNumType TypeInt16 -> Refl
IntegralNumType TypeInt32 -> Refl
IntegralNumType TypeInt64 -> Refl
IntegralNumType TypeWord -> Refl
IntegralNumType TypeWord8 -> Refl
IntegralNumType TypeWord16 -> Refl
IntegralNumType TypeWord32 -> Refl
IntegralNumType TypeWord64 -> Refl
FloatingNumType TypeHalf -> Refl
FloatingNumType TypeFloat -> Refl
FloatingNumType TypeDouble -> Refl

-- Relation between input and output
data InputOutputR input output where
InputOutputRsignal :: InputOutputR Signal SignalResolver
InputOutputRref :: InputOutputR (Ref f) (OutputRef t)
InputOutputRpair :: InputOutputR i1 o1
-> InputOutputR i2 o2
-> InputOutputR (i1, i2) (o1, o2)
InputOutputRunit :: InputOutputR () ()
InputOutputRsignal :: InputOutputR Signal SignalResolver
-- The next iteration of the loop may signal that it wants to release the buffer,
-- such that the previous iteration can free that buffer (or release it for other operations).
InputOutputRrelease :: InputOutputR SignalResolver Signal
InputOutputRref :: InputOutputR (Ref f) (OutputRef t)
InputOutputRpair :: InputOutputR i1 o1
-> InputOutputR i2 o2
-> InputOutputR (i1, i2) (o1, o2)
InputOutputRunit :: InputOutputR () ()

-- Bindings of instructions which have some return value.
-- They cannot perform side effects.
Expand Down Expand Up @@ -171,3 +253,48 @@ await signals = Effect (SignalAwait signals)
resolve :: [Idx env SignalResolver] -> UniformSchedule exe env -> UniformSchedule exe env
resolve [] = id
resolve signals = Effect (SignalResolve signals)

freeVars :: IsExecutableAcc exe => UniformSchedule exe env -> IdxSet env
freeVars Return = IdxSet.empty
freeVars (Alet lhs bnd s) = bindingFreeVars bnd `IdxSet.union` IdxSet.drop' lhs (freeVars s)
freeVars (Effect effect s) = effectFreeVars effect `IdxSet.union` freeVars s
freeVars (Acond c t f s)
= IdxSet.insertVar c
$ IdxSet.union (freeVars t)
$ IdxSet.union (freeVars f)
$ freeVars s
freeVars (Awhile _ step init continuation)
= IdxSet.union (funFreeVars step)
$ IdxSet.union (IdxSet.fromVarList $ flattenTupR init)
$ freeVars continuation
freeVars (Fork s1 s2) = freeVars s1 `IdxSet.union` freeVars s2

funFreeVars :: IsExecutableAcc exe => UniformScheduleFun exe env t -> IdxSet env
funFreeVars (Sbody s) = freeVars s
funFreeVars (Slam lhs f) = IdxSet.drop' lhs $ funFreeVars f

bindingFreeVars :: Binding env t -> IdxSet env
bindingFreeVars NewSignal = IdxSet.empty
bindingFreeVars (NewRef _) = IdxSet.empty
bindingFreeVars (Alloc _ _ sh) = IdxSet.fromVarList $ flattenTupR sh
bindingFreeVars (Use _ _) = IdxSet.empty
bindingFreeVars (Unit var) = IdxSet.singletonVar var
bindingFreeVars (RefRead var) = IdxSet.singletonVar var
bindingFreeVars (Compute e) = IdxSet.fromList $ map f $ arrayInstrs e
where
f :: Exists (ArrayInstr env) -> Exists (Idx env)
f (Exists (Index (Var _ idx))) = Exists idx
f (Exists (Parameter (Var _ idx))) = Exists idx

effectFreeVars :: IsExecutableAcc exe => Effect exe env -> IdxSet env
effectFreeVars (Exec exe) = IdxSet.fromVarList $ execVars exe
effectFreeVars (SignalAwait signals) = IdxSet.fromList $ map Exists $ signals
effectFreeVars (SignalResolve resolvers) = IdxSet.fromList $ map Exists resolvers
effectFreeVars (RefWrite ref value) = IdxSet.insertVar ref $ IdxSet.singletonVar value

signalResolverImpossible :: GroundsR SignalResolver -> a
signalResolverImpossible (TupRsingle (GroundRscalar tp)) = scalarSignalResolverImpossible tp

scalarSignalResolverImpossible :: ScalarType SignalResolver -> a
scalarSignalResolverImpossible (SingleScalarType (NumSingleType (IntegralNumType tp))) = case tp of {}
scalarSignalResolverImpossible (SingleScalarType (NumSingleType (FloatingNumType tp))) = case tp of {}
5 changes: 5 additions & 0 deletions src/Data/Array/Accelerate/Representation/Type.hs
Expand Up @@ -162,6 +162,11 @@ mapTupR f (TupRsingle a) = TupRsingle $ f a
mapTupR _ TupRunit = TupRunit
mapTupR f (TupRpair a1 a2) = mapTupR f a1 `TupRpair` mapTupR f a2

traverseTupR :: Applicative f => (forall s. a s -> f (b s)) -> TupR a t -> f (TupR b t)
traverseTupR f (TupRsingle a) = TupRsingle <$> f a
traverseTupR _ TupRunit = pure TupRunit
traverseTupR f (TupRpair a1 a2) = TupRpair <$> traverseTupR f a1 <*> traverseTupR f a2

functionImpossible :: TypeR (s -> t) -> a
functionImpossible (TupRsingle (SingleScalarType (NumSingleType tp))) = case tp of
IntegralNumType t -> case t of {}
Expand Down
39 changes: 39 additions & 0 deletions src/Data/Array/Accelerate/Trafo/Exp/Substitution.hs
Expand Up @@ -43,6 +43,7 @@ module Data.Array.Accelerate.Trafo.Exp.Substitution (

RebuildArrayInstr, rebuildArrayInstrMap,
rebuildNoArrayInstr, mapArrayInstr,
arrayInstrs, arrayInstrsFun,

-- ** Checks
isIdentity, extractExpVars,
Expand All @@ -61,6 +62,7 @@ import Data.Array.Accelerate.Representation.Type
import qualified Data.Array.Accelerate.Debug.Stats as Stats

import Data.Kind
import Data.Maybe
import Control.Applicative hiding ( Const )
import Control.Monad
import Prelude hiding ( exp, seq )
Expand Down Expand Up @@ -520,6 +522,43 @@ rebuildArrayInstrFun
rebuildArrayInstrFun v (Body e) = Body <$> rebuildArrayInstrOpenExp v e
rebuildArrayInstrFun v (Lam lhs f) = Lam lhs <$> rebuildArrayInstrFun v f

arrayInstrs :: PreOpenExp arr env a -> [Exists arr]
arrayInstrs e = arrayInstrs' e []

arrayInstrsFun :: PreOpenFun arr env a -> [Exists arr]
arrayInstrsFun f = arrayInstrsFun' f []

arrayInstrs' :: PreOpenExp arr env a -> [Exists arr] -> [Exists arr]
arrayInstrs' expr = case expr of
Let _ e1 e2 -> arrayInstrs' e1 . arrayInstrs' e2
Evar _ -> id
Foreign _ _ _ x -> arrayInstrs' x
Pair e1 e2 -> arrayInstrs' e1 . arrayInstrs' e2
Nil -> id
VecPack _ e -> arrayInstrs' e
VecUnpack _ e -> arrayInstrs' e
IndexSlice _ slix sh -> arrayInstrs' slix . arrayInstrs' sh
IndexFull _ slix sl -> arrayInstrs' slix . arrayInstrs' sl
ToIndex _ sh ix -> arrayInstrs' sh . arrayInstrs' ix
FromIndex _ sh ix -> arrayInstrs' sh . arrayInstrs' ix
Case e rhs def -> arrayInstrs' e . alts rhs . maybe id arrayInstrs' def
Cond c t f -> arrayInstrs' c . arrayInstrs' t . arrayInstrs' f
While c f x -> arrayInstrsFun' c . arrayInstrsFun' f . arrayInstrs' x
Const _ _ -> id
PrimConst _ -> id
PrimApp _ x -> arrayInstrs' x
ArrayInstr arr _ -> (Exists arr :)
ShapeSize _ sh -> arrayInstrs' sh
Undef _ -> id
Coerce _ _ e -> arrayInstrs' e
where
alts :: [(TAG, PreOpenExp arr env b)] -> [Exists arr] -> [Exists arr]
alts [] = id
alts ((_, e):as) = arrayInstrs' e . alts as

arrayInstrsFun' :: PreOpenFun arr env a -> [Exists arr] -> [Exists arr]
arrayInstrsFun' (Body e) = arrayInstrs' e
arrayInstrsFun' (Lam _ f) = arrayInstrsFun' f

extractExpVars :: PreOpenExp arr env a -> Maybe (ExpVars env a)
extractExpVars Nil = Just TupRunit
Expand Down
2 changes: 2 additions & 0 deletions src/Data/Array/Accelerate/Trafo/Operation/Substitution.hs
Expand Up @@ -32,6 +32,8 @@ module Data.Array.Accelerate.Trafo.Operation.Substitution (
pair, alet,
weakenArrayInstr,
strengthenArrayInstr,

reindexVar, reindexVars,
) where

import Data.Array.Accelerate.AST.Idx
Expand Down

0 comments on commit 45cb8f4

Please sign in to comment.