diff --git a/src/Data/Array/Accelerate/AST/Environment.hs b/src/Data/Array/Accelerate/AST/Environment.hs index 35d049a5b..b4fefd63f 100644 --- a/src/Data/Array/Accelerate/AST/Environment.hs +++ b/src/Data/Array/Accelerate/AST/Environment.hs @@ -22,7 +22,7 @@ module Data.Array.Accelerate.AST.Environment ( unionPartialEnv, EnvBinding(..), partialEnvFromList, mapPartialEnv, mapMaybePartialEnv, partialEnvValues, diffPartialEnv, diffPartialEnvWith, intersectPartialEnv, partialEnvTail, partialEnvLast, partialEnvSkip, - partialUpdate, partialEnvToList, + partialUpdate, partialEnvToList, partialEnvSingleton, prjUpdate', prjReplace', update', updates', mapEnv, Identity(..), (:>)(..), weakenId, weakenSucc, weakenSucc', weakenEmpty, @@ -179,6 +179,10 @@ partialEnvValues PEnd = [] partialEnvValues (PNone env) = partialEnvValues env partialEnvValues (PPush env (IdentityF a)) = a : partialEnvValues env +partialEnvSingleton :: Idx env t -> f t -> PartialEnv f env +partialEnvSingleton ZeroIdx v = PPush PEnd v +partialEnvSingleton (SuccIdx idx) v = PNone $ partialEnvSingleton idx v + -- Wrapper to put homogenous types in an Env or PartialEnv newtype IdentityF t f = IdentityF t diff --git a/src/Data/Array/Accelerate/AST/IdxSet.hs b/src/Data/Array/Accelerate/AST/IdxSet.hs index dea413105..d87bf5692 100644 --- a/src/Data/Array/Accelerate/AST/IdxSet.hs +++ b/src/Data/Array/Accelerate/AST/IdxSet.hs @@ -18,13 +18,15 @@ module Data.Array.Accelerate.AST.IdxSet ( IdxSet, - member, intersect, union, insert, skip, - push, empty, drop, drop', fromList + member, varMember, intersect, union, insert, insertVar, skip, + push, empty, drop, drop', fromList, fromVarList, + singleton, singletonVar, ) where import Prelude hiding (drop) import Data.Array.Accelerate.AST.Idx +import Data.Array.Accelerate.AST.Var import Data.Array.Accelerate.AST.Environment hiding ( push ) import Data.Array.Accelerate.AST.LeftHandSide import Data.Maybe @@ -36,6 +38,9 @@ data Present a = Present member :: Idx env t -> IdxSet env -> Bool member idx (IdxSet set) = isJust $ prjPartial idx set +varMember :: Var s env t -> IdxSet env -> Bool +varMember (Var _ idx) = member idx + intersect :: IdxSet env -> IdxSet env -> IdxSet env intersect (IdxSet a) (IdxSet b) = IdxSet $ intersectPartialEnv (\_ _ -> Present) a b @@ -45,6 +50,9 @@ union (IdxSet a) (IdxSet b) = IdxSet $ unionPartialEnv (\_ _ -> Present) a b insert :: Idx env t -> IdxSet env -> IdxSet env insert idx (IdxSet a) = IdxSet $ partialUpdate Present idx a +insertVar :: Var s env t -> IdxSet env -> IdxSet env +insertVar (Var _ idx) = insert idx + skip :: IdxSet env -> IdxSet (env, t) skip = IdxSet . PNone . unIdxSet @@ -67,3 +75,12 @@ toList = map (\(EnvBinding idx _) -> Exists idx) . partialEnvToList . unIdxSet fromList :: [Exists (Idx env)] -> IdxSet env fromList = IdxSet . partialEnvFromList (\_ _ -> Present) . map (\(Exists idx) -> EnvBinding idx Present) + +fromVarList :: [Exists (Var s env)] -> IdxSet env +fromVarList = fromList . map (\(Exists (Var _ idx)) -> Exists idx) + +singleton :: Idx env t -> IdxSet env +singleton idx = IdxSet $ partialEnvSingleton idx Present + +singletonVar :: Var s env t -> IdxSet env +singletonVar (Var _ idx) = singleton idx diff --git a/src/Data/Array/Accelerate/AST/Operation.hs b/src/Data/Array/Accelerate/AST/Operation.hs index d96696446..966b65ad4 100644 --- a/src/Data/Array/Accelerate/AST/Operation.hs +++ b/src/Data/Array/Accelerate/AST/Operation.hs @@ -67,7 +67,7 @@ import Data.Array.Accelerate.Trafo.Exp.Substitution import Data.Array.Accelerate.Trafo.Exp.Shrink import Data.Array.Accelerate.Type import Data.Array.Accelerate.Error -import Data.Typeable ( (:~:)(..) ) +import Data.Typeable ( (:~:)(..) ) import Data.ByteString.Builder.Extra import Language.Haskell.TH ( Q, TExp ) @@ -385,7 +385,10 @@ paramsIn' (TupRsingle v) = ArrayInstr (Parameter v) Nil type ReindexPartial f env env' = forall a. Idx env a -> f (Idx env' a) -reindexVar :: Applicative f => ReindexPartial f env env' -> Var s env t -> f (Var s env' t) +-- The first argument is ReindexPartial, but without the forall and specialized to 't'. +-- This makes it usable in more situations. +-- +reindexVar :: Applicative f => (Idx env t -> f (Idx env' t)) -> Var s env t -> f (Var s env' t) reindexVar k (Var repr ix) = Var repr <$> k ix reindexVars :: Applicative f => ReindexPartial f env env' -> Vars s env t -> f (Vars s env' t) diff --git a/src/Data/Array/Accelerate/AST/Schedule/Uniform.hs b/src/Data/Array/Accelerate/AST/Schedule/Uniform.hs index cfd4c13b0..6d1d7cefa 100644 --- a/src/Data/Array/Accelerate/AST/Schedule/Uniform.hs +++ b/src/Data/Array/Accelerate/AST/Schedule/Uniform.hs @@ -1,8 +1,11 @@ -{-# LANGUAGE GADTs #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Schedule.Uniform @@ -16,16 +19,23 @@ module Data.Array.Accelerate.AST.Schedule.Uniform ( UniformSchedule(..), UniformScheduleFun(..), - Input, Output, InputOutputR(..), + Input, Output, inputSingle, inputR, outputR, InputOutputR(..), + ScheduleFunction, scheduleFunctionIsBody, Binding(..), Effect(..), BaseR(..), BasesR, BaseVar, BaseVars, BLeftHandSide, Signal(..), SignalResolver(..), Ref(..), OutputRef(..), module Partitioned, await, resolve, + signalResolverImpossible, scalarSignalResolverImpossible, + + -- ** Free variables + freeVars, funFreeVars, effectFreeVars, bindingFreeVars, ) where import Data.Array.Accelerate.AST.Exp import Data.Array.Accelerate.AST.Idx +import Data.Array.Accelerate.AST.IdxSet ( IdxSet ) +import qualified Data.Array.Accelerate.AST.IdxSet as IdxSet import Data.Array.Accelerate.AST.LeftHandSide import Data.Array.Accelerate.AST.Var import Data.Array.Accelerate.Type @@ -33,9 +43,11 @@ import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Shape import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.AST.Partitioned as Partitioned hiding (PartitionedAcc, PartitionedAfun, PreOpenAcc(..), PreOpenAfun(..)) +import Data.Array.Accelerate.Trafo.Exp.Substitution import Data.Array.Accelerate.Trafo.Operation.Substitution import Control.Concurrent.MVar import Data.IORef +import Data.Typeable ( (:~:)(..) ) -- Generic schedule for a uniform memory space and uniform scheduling. -- E.g., we don't have host and device memory or scheduling. @@ -60,9 +72,11 @@ data UniformSchedule exe env where -> UniformSchedule exe env -- Operations after the if-then-else -> UniformSchedule exe env + -- The step function of the loop outputs a bool to denote whether the loop should + -- proceed. If true, then the other output values should also be filled, possibly at + -- a later point in time. If it is false, then no other output values may be filled. Awhile :: InputOutputR input output - -> UniformScheduleFun exe env (input -> Output PrimBool -> ()) - -> UniformScheduleFun exe env (input -> output -> ()) + -> UniformScheduleFun exe env (input -> (Output PrimBool, output) -> ()) -> BaseVars env input -> UniformSchedule exe env -- Operations after the while loop -> UniformSchedule exe env @@ -83,23 +97,91 @@ data UniformScheduleFun exe env t where -> UniformScheduleFun exe env () type family Input t where - Input () = () - Input (a, b) = (Input a, Input b) - Input t = (Signal, Ref t) + Input () = () + Input (a, b) = (Input a, Input b) + Input t = (Signal, Ref t) + +inputSingle :: forall t. GroundR t -> (Input t, Output t) :~: ((Signal, Ref t), (SignalResolver, OutputRef t)) +-- Last case of type family Input and Output. +-- We must pattern match to convince the type checker that +-- t is not () or (a, b). +inputSingle (GroundRbuffer _) = Refl +inputSingle (GroundRscalar (VectorScalarType _)) = Refl +inputSingle (GroundRscalar (SingleScalarType (NumSingleType tp))) = case tp of + IntegralNumType TypeInt -> Refl + IntegralNumType TypeInt8 -> Refl + IntegralNumType TypeInt16 -> Refl + IntegralNumType TypeInt32 -> Refl + IntegralNumType TypeInt64 -> Refl + IntegralNumType TypeWord -> Refl + IntegralNumType TypeWord8 -> Refl + IntegralNumType TypeWord16 -> Refl + IntegralNumType TypeWord32 -> Refl + IntegralNumType TypeWord64 -> Refl + FloatingNumType TypeHalf -> Refl + FloatingNumType TypeFloat -> Refl + FloatingNumType TypeDouble -> Refl + +inputR :: forall t. GroundsR t -> BasesR (Input t) +inputR TupRunit = TupRunit +inputR (TupRpair t1 t2) = TupRpair (inputR t1) (inputR t2) +inputR (TupRsingle ground) + -- Last case of type family Input. + -- We must pattern match to convince the type checker that + -- t is not () or (a, b). + | Refl <- inputSingle ground = TupRsingle BaseRsignal `TupRpair` TupRsingle (BaseRref ground) type family Output t where Output () = () Output (a, b) = (Output a, Output b) Output t = (SignalResolver, OutputRef t) +outputR :: GroundsR t -> BasesR (Output t) +outputR TupRunit = TupRunit +outputR (TupRpair t1 t2) = TupRpair (outputR t1) (outputR t2) +outputR (TupRsingle ground) + -- Last case of type family Output. + -- We must pattern match to convince the type checker that + -- t is not () or (a, b). + | Refl <- inputSingle ground = TupRsingle BaseRsignalResolver `TupRpair` TupRsingle (BaseRrefWrite ground) + +type family ScheduleFunction t where + ScheduleFunction (t1 -> t2) = Input t1 -> ScheduleFunction t2 + ScheduleFunction t = Output t -> () + +-- Pattern match to convince the type checker that t is not a function. +scheduleFunctionIsBody :: GroundsR t -> ScheduleFunction t :~: (Output t -> ()) +scheduleFunctionIsBody TupRunit = Refl +scheduleFunctionIsBody TupRpair{} = Refl +scheduleFunctionIsBody (TupRsingle (GroundRbuffer _)) = Refl +scheduleFunctionIsBody (TupRsingle (GroundRscalar tp)) + | VectorScalarType _ <- tp = Refl + | SingleScalarType (NumSingleType tp') <- tp = case tp' of + IntegralNumType TypeInt -> Refl + IntegralNumType TypeInt8 -> Refl + IntegralNumType TypeInt16 -> Refl + IntegralNumType TypeInt32 -> Refl + IntegralNumType TypeInt64 -> Refl + IntegralNumType TypeWord -> Refl + IntegralNumType TypeWord8 -> Refl + IntegralNumType TypeWord16 -> Refl + IntegralNumType TypeWord32 -> Refl + IntegralNumType TypeWord64 -> Refl + FloatingNumType TypeHalf -> Refl + FloatingNumType TypeFloat -> Refl + FloatingNumType TypeDouble -> Refl + -- Relation between input and output data InputOutputR input output where - InputOutputRsignal :: InputOutputR Signal SignalResolver - InputOutputRref :: InputOutputR (Ref f) (OutputRef t) - InputOutputRpair :: InputOutputR i1 o1 - -> InputOutputR i2 o2 - -> InputOutputR (i1, i2) (o1, o2) - InputOutputRunit :: InputOutputR () () + InputOutputRsignal :: InputOutputR Signal SignalResolver + -- The next iteration of the loop may signal that it wants to release the buffer, + -- such that the previous iteration can free that buffer (or release it for other operations). + InputOutputRrelease :: InputOutputR SignalResolver Signal + InputOutputRref :: InputOutputR (Ref f) (OutputRef t) + InputOutputRpair :: InputOutputR i1 o1 + -> InputOutputR i2 o2 + -> InputOutputR (i1, i2) (o1, o2) + InputOutputRunit :: InputOutputR () () -- Bindings of instructions which have some return value. -- They cannot perform side effects. @@ -171,3 +253,48 @@ await signals = Effect (SignalAwait signals) resolve :: [Idx env SignalResolver] -> UniformSchedule exe env -> UniformSchedule exe env resolve [] = id resolve signals = Effect (SignalResolve signals) + +freeVars :: IsExecutableAcc exe => UniformSchedule exe env -> IdxSet env +freeVars Return = IdxSet.empty +freeVars (Alet lhs bnd s) = bindingFreeVars bnd `IdxSet.union` IdxSet.drop' lhs (freeVars s) +freeVars (Effect effect s) = effectFreeVars effect `IdxSet.union` freeVars s +freeVars (Acond c t f s) + = IdxSet.insertVar c + $ IdxSet.union (freeVars t) + $ IdxSet.union (freeVars f) + $ freeVars s +freeVars (Awhile _ step init continuation) + = IdxSet.union (funFreeVars step) + $ IdxSet.union (IdxSet.fromVarList $ flattenTupR init) + $ freeVars continuation +freeVars (Fork s1 s2) = freeVars s1 `IdxSet.union` freeVars s2 + +funFreeVars :: IsExecutableAcc exe => UniformScheduleFun exe env t -> IdxSet env +funFreeVars (Sbody s) = freeVars s +funFreeVars (Slam lhs f) = IdxSet.drop' lhs $ funFreeVars f + +bindingFreeVars :: Binding env t -> IdxSet env +bindingFreeVars NewSignal = IdxSet.empty +bindingFreeVars (NewRef _) = IdxSet.empty +bindingFreeVars (Alloc _ _ sh) = IdxSet.fromVarList $ flattenTupR sh +bindingFreeVars (Use _ _) = IdxSet.empty +bindingFreeVars (Unit var) = IdxSet.singletonVar var +bindingFreeVars (RefRead var) = IdxSet.singletonVar var +bindingFreeVars (Compute e) = IdxSet.fromList $ map f $ arrayInstrs e + where + f :: Exists (ArrayInstr env) -> Exists (Idx env) + f (Exists (Index (Var _ idx))) = Exists idx + f (Exists (Parameter (Var _ idx))) = Exists idx + +effectFreeVars :: IsExecutableAcc exe => Effect exe env -> IdxSet env +effectFreeVars (Exec exe) = IdxSet.fromVarList $ execVars exe +effectFreeVars (SignalAwait signals) = IdxSet.fromList $ map Exists $ signals +effectFreeVars (SignalResolve resolvers) = IdxSet.fromList $ map Exists resolvers +effectFreeVars (RefWrite ref value) = IdxSet.insertVar ref $ IdxSet.singletonVar value + +signalResolverImpossible :: GroundsR SignalResolver -> a +signalResolverImpossible (TupRsingle (GroundRscalar tp)) = scalarSignalResolverImpossible tp + +scalarSignalResolverImpossible :: ScalarType SignalResolver -> a +scalarSignalResolverImpossible (SingleScalarType (NumSingleType (IntegralNumType tp))) = case tp of {} +scalarSignalResolverImpossible (SingleScalarType (NumSingleType (FloatingNumType tp))) = case tp of {} diff --git a/src/Data/Array/Accelerate/Representation/Type.hs b/src/Data/Array/Accelerate/Representation/Type.hs index 211fd0c0e..8efdb3a12 100644 --- a/src/Data/Array/Accelerate/Representation/Type.hs +++ b/src/Data/Array/Accelerate/Representation/Type.hs @@ -162,6 +162,11 @@ mapTupR f (TupRsingle a) = TupRsingle $ f a mapTupR _ TupRunit = TupRunit mapTupR f (TupRpair a1 a2) = mapTupR f a1 `TupRpair` mapTupR f a2 +traverseTupR :: Applicative f => (forall s. a s -> f (b s)) -> TupR a t -> f (TupR b t) +traverseTupR f (TupRsingle a) = TupRsingle <$> f a +traverseTupR _ TupRunit = pure TupRunit +traverseTupR f (TupRpair a1 a2) = TupRpair <$> traverseTupR f a1 <*> traverseTupR f a2 + functionImpossible :: TypeR (s -> t) -> a functionImpossible (TupRsingle (SingleScalarType (NumSingleType tp))) = case tp of IntegralNumType t -> case t of {} diff --git a/src/Data/Array/Accelerate/Trafo/Exp/Substitution.hs b/src/Data/Array/Accelerate/Trafo/Exp/Substitution.hs index 2a7b731ae..812bb3bb5 100644 --- a/src/Data/Array/Accelerate/Trafo/Exp/Substitution.hs +++ b/src/Data/Array/Accelerate/Trafo/Exp/Substitution.hs @@ -43,6 +43,7 @@ module Data.Array.Accelerate.Trafo.Exp.Substitution ( RebuildArrayInstr, rebuildArrayInstrMap, rebuildNoArrayInstr, mapArrayInstr, + arrayInstrs, arrayInstrsFun, -- ** Checks isIdentity, extractExpVars, @@ -61,6 +62,7 @@ import Data.Array.Accelerate.Representation.Type import qualified Data.Array.Accelerate.Debug.Stats as Stats import Data.Kind +import Data.Maybe import Control.Applicative hiding ( Const ) import Control.Monad import Prelude hiding ( exp, seq ) @@ -520,6 +522,43 @@ rebuildArrayInstrFun rebuildArrayInstrFun v (Body e) = Body <$> rebuildArrayInstrOpenExp v e rebuildArrayInstrFun v (Lam lhs f) = Lam lhs <$> rebuildArrayInstrFun v f +arrayInstrs :: PreOpenExp arr env a -> [Exists arr] +arrayInstrs e = arrayInstrs' e [] + +arrayInstrsFun :: PreOpenFun arr env a -> [Exists arr] +arrayInstrsFun f = arrayInstrsFun' f [] + +arrayInstrs' :: PreOpenExp arr env a -> [Exists arr] -> [Exists arr] +arrayInstrs' expr = case expr of + Let _ e1 e2 -> arrayInstrs' e1 . arrayInstrs' e2 + Evar _ -> id + Foreign _ _ _ x -> arrayInstrs' x + Pair e1 e2 -> arrayInstrs' e1 . arrayInstrs' e2 + Nil -> id + VecPack _ e -> arrayInstrs' e + VecUnpack _ e -> arrayInstrs' e + IndexSlice _ slix sh -> arrayInstrs' slix . arrayInstrs' sh + IndexFull _ slix sl -> arrayInstrs' slix . arrayInstrs' sl + ToIndex _ sh ix -> arrayInstrs' sh . arrayInstrs' ix + FromIndex _ sh ix -> arrayInstrs' sh . arrayInstrs' ix + Case e rhs def -> arrayInstrs' e . alts rhs . maybe id arrayInstrs' def + Cond c t f -> arrayInstrs' c . arrayInstrs' t . arrayInstrs' f + While c f x -> arrayInstrsFun' c . arrayInstrsFun' f . arrayInstrs' x + Const _ _ -> id + PrimConst _ -> id + PrimApp _ x -> arrayInstrs' x + ArrayInstr arr _ -> (Exists arr :) + ShapeSize _ sh -> arrayInstrs' sh + Undef _ -> id + Coerce _ _ e -> arrayInstrs' e + where + alts :: [(TAG, PreOpenExp arr env b)] -> [Exists arr] -> [Exists arr] + alts [] = id + alts ((_, e):as) = arrayInstrs' e . alts as + +arrayInstrsFun' :: PreOpenFun arr env a -> [Exists arr] -> [Exists arr] +arrayInstrsFun' (Body e) = arrayInstrs' e +arrayInstrsFun' (Lam _ f) = arrayInstrsFun' f extractExpVars :: PreOpenExp arr env a -> Maybe (ExpVars env a) extractExpVars Nil = Just TupRunit diff --git a/src/Data/Array/Accelerate/Trafo/Operation/Substitution.hs b/src/Data/Array/Accelerate/Trafo/Operation/Substitution.hs index 2bb7a4f19..52557bbff 100644 --- a/src/Data/Array/Accelerate/Trafo/Operation/Substitution.hs +++ b/src/Data/Array/Accelerate/Trafo/Operation/Substitution.hs @@ -32,6 +32,8 @@ module Data.Array.Accelerate.Trafo.Operation.Substitution ( pair, alet, weakenArrayInstr, strengthenArrayInstr, + + reindexVar, reindexVars, ) where import Data.Array.Accelerate.AST.Idx diff --git a/src/Data/Array/Accelerate/Trafo/Schedule/Uniform.hs b/src/Data/Array/Accelerate/Trafo/Schedule/Uniform.hs index 50be2556a..cd8707249 100644 --- a/src/Data/Array/Accelerate/Trafo/Schedule/Uniform.hs +++ b/src/Data/Array/Accelerate/Trafo/Schedule/Uniform.hs @@ -32,20 +32,22 @@ import Prelude hiding (read) import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.AST.IdxSet (IdxSet) -import qualified Data.Array.Accelerate.AST.IdxSet as IdxSet +import qualified Data.Array.Accelerate.AST.IdxSet as IdxSet import Data.Array.Accelerate.AST.Var import Data.Array.Accelerate.AST.LeftHandSide import Data.Array.Accelerate.AST.Schedule.Uniform import Data.Array.Accelerate.AST.Environment -import qualified Data.Array.Accelerate.AST.Partitioned as C +import qualified Data.Array.Accelerate.AST.Partitioned as C +import Data.Array.Accelerate.Analysis.Match ( (:~:)(..) ) import Data.Array.Accelerate.Trafo.Var import Data.Array.Accelerate.Trafo.Substitution import Data.Array.Accelerate.Trafo.Exp.Substitution -import Data.Array.Accelerate.Trafo.Operation.Substitution (strengthenArrayInstr) +import Data.Array.Accelerate.Trafo.Operation.Substitution (strengthenArrayInstr, reindexVar, reindexVars) import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Type import Data.Array.Accelerate.Error +import Data.Kind import Data.Maybe import Data.List import qualified Data.Set as S @@ -54,11 +56,11 @@ import GHC.Stack instance IsExecutableAcc exe => Sink' (UniformSchedule exe) where weaken' _ Return = Return weaken' k (Alet lhs b s) - | Exists lhs' <- rebuildLHS lhs = Alet lhs' (weaken k b) (weaken' (sinkWithLHS lhs lhs' k) s) - weaken' k (Effect effect s) = Effect (weaken' k effect) (weaken' k s) - weaken' k (Acond cond true false s) = Acond (weaken k cond) (weaken' k true) (weaken' k false) (weaken' k s) - weaken' k (Awhile io cond step input s) = Awhile io (weaken k cond) (weaken k step) (mapTupR (weaken k) input) (weaken' k s) - weaken' k (Fork s1 s2) = Fork (weaken' k s1) (weaken' k s2) + | Exists lhs' <- rebuildLHS lhs = Alet lhs' (weaken k b) (weaken' (sinkWithLHS lhs lhs' k) s) + weaken' k (Effect effect s) = Effect (weaken' k effect) (weaken' k s) + weaken' k (Acond cond true false s) = Acond (weaken k cond) (weaken' k true) (weaken' k false) (weaken' k s) + weaken' k (Awhile io f input s) = Awhile io (weaken k f) (mapTupR (weaken k) input) (weaken' k s) + weaken' k (Fork s1 s2) = Fork (weaken' k s1) (weaken' k s2) instance IsExecutableAcc exe => Sink (UniformScheduleFun exe) where weaken k (Slam lhs f) @@ -286,7 +288,7 @@ instance Ord (Sync t) where SyncRead < SyncWrite = True _ < _ = False -data Acquire genv t where +data Acquire genv where Acquire :: Modifier m -> GroundVar genv (Buffer e) -- Returns a signal to wait on before the operation can start. @@ -299,7 +301,7 @@ data Acquire genv t where -- Also provides a SignalResolver which should be resolved -- when the operation is finished. Later reads or writes to -- this buffer will wait on this signal. - -> Acquire genv (Signal, SignalResolver) + -> Acquire genv data ConvertEnv genv fenv fenv' where ConvertEnvNil :: ConvertEnv genv fenv fenv @@ -308,7 +310,7 @@ data ConvertEnv genv fenv fenv' where -> ConvertEnv genv fenv2 fenv3 -> ConvertEnv genv fenv1 fenv3 - ConvertEnvAcquire :: Acquire genv (Signal, SignalResolver) + ConvertEnvAcquire :: Acquire genv -> ConvertEnv genv fenv ((fenv, Signal), SignalResolver) ConvertEnvFuture :: GroundVar genv e @@ -319,66 +321,96 @@ data OutputEnv fenv fenv' t r where -> OutputEnv fenv2 fenv3 t' r' -> OutputEnv fenv1 fenv3 (t, t') (r, r') - -- First SignalResolver grants read access, second guarantees that all reads have been finished. + -- First SignalResolver grants access to the ref, the second grants read access and the + -- third guarantees that all reads have been finished. -- Together they thus grant write access. - OutputEnvUnique :: BLeftHandSide ((SignalResolver, SignalResolver), OutputRef (Buffer t)) fenv fenv' - -> GroundR (Buffer t) + -- + OutputEnvUnique :: fenv' ~ ((((fenv, OutputRef (Buffer t)), SignalResolver), SignalResolver), SignalResolver) + => ScalarType t + -> OutputEnv fenv fenv' (Buffer t) (((SignalResolver, SignalResolver), SignalResolver), OutputRef (Buffer t)) + + -- First SignalResolver grants access to the ref, the second grants read access. + -- + OutputEnvShared :: fenv' ~ (((fenv, OutputRef (Buffer t)), SignalResolver), SignalResolver) + => ScalarType t -> OutputEnv fenv fenv' (Buffer t) ((SignalResolver, SignalResolver), OutputRef (Buffer t)) -- Scalar values or shared buffers - OutputEnvShared :: BLeftHandSide (SignalResolver, OutputRef t) fenv fenv' - -> GroundR t + OutputEnvScalar :: fenv' ~ ((fenv, OutputRef t), SignalResolver) + => ScalarType t -> OutputEnv fenv fenv' t (SignalResolver, OutputRef t) OutputEnvUnit :: OutputEnv fenv fenv () () +outputEnvGroundsR :: OutputEnv fenv fenv' t r -> GroundsR t +outputEnvGroundsR (OutputEnvPair out1 out2) = outputEnvGroundsR out1 `TupRpair` outputEnvGroundsR out2 +outputEnvGroundsR (OutputEnvUnique tp) = TupRsingle $ GroundRbuffer tp +outputEnvGroundsR (OutputEnvShared tp) = TupRsingle $ GroundRbuffer tp +outputEnvGroundsR (OutputEnvScalar tp) = TupRsingle $ GroundRscalar tp +outputEnvGroundsR OutputEnvUnit = TupRunit + data OutputVars t r where OutputVarsPair :: OutputVars t r -> OutputVars t' r' -> OutputVars (t, t') (r, r') - OutputVarsUnique :: OutputVars (Buffer t) ((SignalResolver, SignalResolver), OutputRef (Buffer t)) + -- The SignalResolvers grant access to the reference, to reading the buffer and writing to the buffer. + -- The consumer of this buffer is the unique consumer of it, and thus takes ownership (and responsibility to deallocate it). + OutputVarsUnique :: OutputVars (Buffer t) (((SignalResolver, SignalResolver), SignalResolver), OutputRef (Buffer t)) + + -- The SignalResolvers grant access to the reference and to reading the buffer. + -- The consumer of this buffer does not get ownership, there may be multiple references to this buffer. + OutputVarsShared :: OutputVars (Buffer t) ((SignalResolver, SignalResolver), OutputRef (Buffer t)) - OutputVarsShared :: OutputVars t (SignalResolver, OutputRef t) + OutputVarsScalar :: ScalarType t -> OutputVars t (SignalResolver, OutputRef t) - -- No need to propagate the output, as we reused the same variables (using Destination in PartialDeclare) - -- Also used for Unit - OutputVarsIgnore :: OutputVars t r + -- There is no output (unit) or the output variables are reused + -- with destination-passing-style. + -- We thus do not need to copy the results manually. + -- + OutputVarsIgnore :: OutputVars t () -data DeclareOutput fenv t where - DeclareOutput :: OutputEnv fenv fenv' t r +data DefineOutput fenv t where + DefineOutput :: OutputEnv fenv fenv' t r -> fenv :> fenv' -> (forall fenv'' . fenv' :> fenv'' -> BaseVars fenv'' r) - -> DeclareOutput fenv t + -> DefineOutput fenv t -declareOutput :: forall fenv t. +defineOutput :: forall fenv t. GroundsR t -> Uniquenesses t - -> DeclareOutput fenv t -declareOutput (TupRsingle tp) (TupRsingle Unique) = DeclareOutput env subst value + -> DefineOutput fenv t +defineOutput (TupRsingle (GroundRbuffer tp)) (TupRsingle Unique) = DefineOutput env subst value where - env = OutputEnvUnique lhs tp - lhs = LeftHandSidePair lhsSignalResolver lhsSignalResolver `LeftHandSidePair` LeftHandSideSingle (BaseRrefWrite tp) + env = OutputEnvUnique tp + + subst = weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc weakenId + + value :: forall fenv''. ((((fenv, OutputRef t), SignalResolver), SignalResolver), SignalResolver) :> fenv'' -> BaseVars fenv'' (((SignalResolver, SignalResolver), SignalResolver), OutputRef t) + value k = ((TupRsingle (Var BaseRsignalResolver $ k >:> ZeroIdx) `TupRpair` TupRsingle (Var BaseRsignalResolver $ k >:> SuccIdx ZeroIdx)) `TupRpair` TupRsingle (Var BaseRsignalResolver (k >:> SuccIdx (SuccIdx ZeroIdx)))) `TupRpair` TupRsingle (Var (BaseRrefWrite $ GroundRbuffer tp) (k >:> SuccIdx (SuccIdx $ SuccIdx ZeroIdx))) +defineOutput (TupRsingle (GroundRscalar tp)) (TupRsingle Unique) = bufferImpossible tp +defineOutput (TupRsingle (GroundRbuffer tp)) _ = DefineOutput env subst value + where + env = OutputEnvShared tp subst = weakenSucc $ weakenSucc $ weakenSucc weakenId - value :: forall fenv''. (((fenv, SignalResolver), SignalResolver), OutputRef t) :> fenv'' -> BaseVars fenv'' ((SignalResolver, SignalResolver), OutputRef t) - value k = (TupRsingle (Var BaseRsignalResolver $ k >:> SuccIdx (SuccIdx ZeroIdx)) `TupRpair` TupRsingle (Var BaseRsignalResolver $ k >:> SuccIdx ZeroIdx)) `TupRpair` TupRsingle (Var (BaseRrefWrite tp) (k >:> ZeroIdx)) -declareOutput (TupRsingle tp) _ = DeclareOutput env subst value + value :: forall fenv''. ((((fenv, OutputRef t), SignalResolver), SignalResolver)) :> fenv'' -> BaseVars fenv'' ((SignalResolver, SignalResolver), OutputRef t) + value k = (TupRsingle (Var BaseRsignalResolver $ k >:> ZeroIdx) `TupRpair` TupRsingle (Var BaseRsignalResolver $ k >:> SuccIdx ZeroIdx)) `TupRpair` TupRsingle (Var (BaseRrefWrite $ GroundRbuffer tp) (k >:> SuccIdx (SuccIdx ZeroIdx))) +defineOutput (TupRsingle (GroundRscalar tp)) _ = DefineOutput env subst value where - env = OutputEnvShared lhs tp - lhs = lhsSignalResolver `LeftHandSidePair` LeftHandSideSingle (BaseRrefWrite tp) + env = OutputEnvScalar tp subst = weakenSucc $ weakenSucc weakenId - value :: forall fenv''. ((fenv, SignalResolver), OutputRef t) :> fenv'' -> BaseVars fenv'' (SignalResolver, OutputRef t) - value k = TupRsingle (Var BaseRsignalResolver $ k >:> SuccIdx ZeroIdx) `TupRpair` TupRsingle (Var (BaseRrefWrite tp) (k >:> ZeroIdx)) -declareOutput (TupRpair t1 t2) us - | DeclareOutput env1 subst1 value1 <- declareOutput t1 u1 - , DeclareOutput env2 subst2 value2 <- declareOutput t2 u2 = DeclareOutput (OutputEnvPair env1 env2) (subst2 .> subst1) (\k -> value1 (k .> subst2) `TupRpair` value2 k) + value :: forall fenv''. ((fenv, OutputRef t), SignalResolver) :> fenv'' -> BaseVars fenv'' (SignalResolver, OutputRef t) + value k = TupRsingle (Var BaseRsignalResolver $ k >:> ZeroIdx) `TupRpair` TupRsingle (Var (BaseRrefWrite $ GroundRscalar tp) (k >:> SuccIdx ZeroIdx)) +defineOutput (TupRpair t1 t2) us + | DefineOutput env1 subst1 value1 <- defineOutput t1 u1 + , DefineOutput env2 subst2 value2 <- defineOutput t2 u2 = DefineOutput (OutputEnvPair env1 env2) (subst2 .> subst1) (\k -> value1 (k .> subst2) `TupRpair` value2 k) where (u1, u2) = pairUniqueness us -declareOutput TupRunit _ = DeclareOutput OutputEnvUnit weakenId (const TupRunit) +defineOutput TupRunit _ = DefineOutput OutputEnvUnit weakenId (const TupRunit) writeOutput :: OutputEnv fenv fenv' t r -> BaseVars fenv'' r -> BaseVars fenv'' t -> UniformSchedule (Cluster op) fenv'' writeOutput outputEnv outputVars valueVars = go outputEnv outputVars valueVars Return @@ -386,13 +418,15 @@ writeOutput outputEnv outputVars valueVars = go outputEnv outputVars valueVars R go :: OutputEnv fenv fenv' t r -> BaseVars fenv'' r -> BaseVars fenv'' t -> UniformSchedule (Cluster op) fenv'' -> UniformSchedule (Cluster op) fenv'' go OutputEnvUnit _ _ = id go (OutputEnvPair o1 o2) (TupRpair r1 r2) (TupRpair v1 v2) = go o1 r1 v1 . go o2 r2 v2 - go (OutputEnvShared _ _) (TupRpair (TupRsingle signal) (TupRsingle ref)) (TupRsingle v) + go (OutputEnvScalar _) (TupRpair (TupRsingle signal) (TupRsingle ref)) (TupRsingle v) = Effect (RefWrite ref v) . Effect (SignalResolve [varIdx signal]) - go (OutputEnvUnique _ _) (TupRpair (TupRpair (TupRsingle s1) (TupRsingle s2)) (TupRsingle ref)) (TupRsingle v) + go (OutputEnvShared _) (TupRpair (TupRsingle s1 `TupRpair` TupRsingle s2) (TupRsingle ref)) (TupRsingle v) + = Effect (RefWrite ref v) + . Effect (SignalResolve [varIdx s1, varIdx s2]) + go (OutputEnvUnique _) (TupRpair (TupRpair (TupRsingle s1 `TupRpair` TupRsingle s2) (TupRsingle s3)) (TupRsingle ref)) (TupRsingle v) = Effect (RefWrite ref v) - . Effect (SignalResolve [varIdx s1]) - . Effect (SignalResolve [varIdx s2]) + . Effect (SignalResolve [varIdx s1, varIdx s2, varIdx s3]) data ReEnv genv fenv where ReEnvEnd :: ReEnv genv fenv @@ -506,6 +540,15 @@ convertEnvFromList (Exists var:vars) , Exists e2 <- convertEnvFromList vars = Exists $ e1 `ConvertEnvSeq` e2 +convertEnvToList :: ConvertEnv genv fenv fenv' -> [Exists (Idx genv)] +convertEnvToList = (`go` []) + where + go :: ConvertEnv genv fenv fenv' -> [Exists (Idx genv)] -> [Exists (Idx genv)] + go ConvertEnvNil = id + go (ConvertEnvSeq e1 e2) = go e1 . go e2 + go (ConvertEnvAcquire (Acquire _ (Var _ idx))) = (Exists idx :) + go (ConvertEnvFuture (Var _ idx)) = (Exists idx :) + convertEnvVar :: Var AccessGroundR genv t -> Exists (ConvertEnv genv fenv) convertEnvVar (Var (AccessGroundRscalar tp) ix) = Exists $ ConvertEnvFuture $ Var (GroundRscalar tp) ix convertEnvVar (Var (AccessGroundRbuffer m tp) ix) = Exists $ ConvertEnvFuture var `ConvertEnvSeq` ConvertEnvAcquire (Acquire m var) @@ -618,6 +661,13 @@ data PartialScheduleFun op genv t where Pbody :: PartialSchedule op genv t -> PartialScheduleFun op genv t +instance HasGroundsR (PartialSchedule op genv) where + groundsR (PartialDo outputEnv _ _) = outputEnvGroundsR outputEnv + groundsR (PartialReturn _ vars) = mapTupR varType vars + groundsR (PartialDeclare _ _ _ _ _ p) = groundsR p + groundsR (PartialAcond _ _ p _) = groundsR p + groundsR (PartialAwhile _ _ _ _ vars) = groundsR vars + data MaybeVar genv t where NoVars :: MaybeVar genv t ReturnVar :: GroundVar genv t -> MaybeVar genv t @@ -672,11 +722,11 @@ joinVars TupRunit _ = TupRunit joinVars _ TupRunit = TupRunit joinVars _ _ = TupRsingle NoVars -data Exists' (a :: (* -> * -> *) -> *) where +data Exists' (a :: (Type -> Type -> Type) -> Type) where Exists' :: a m -> Exists' a -partialSchedule :: forall op genv1 t1. C.PartitionedAcc op genv1 t1 -> PartialSchedule op genv1 t1 -partialSchedule = (\(s, _, _) -> s) . travA (TupRsingle Shared) +partialSchedule :: forall op genv1 t1. C.PartitionedAcc op genv1 t1 -> (PartialSchedule op genv1 t1, IdxSet genv1) +partialSchedule = (\(s, used, _) -> (s, used)) . travA (TupRsingle Shared) where travA :: forall genv t. Uniquenesses t -> C.PartitionedAcc op genv t -> (PartialSchedule op genv t, IdxSet genv, MaybeVars genv t) travA _ (C.Exec cluster) @@ -693,7 +743,7 @@ partialSchedule = (\(s, _, _) -> s) . travA (TupRsingle Shared) $ Effect (Exec cluster') $ Effect (SignalResolve resolvers) $ Return - , undefined + , IdxSet.fromList $ convertEnvToList env , TupRunit ) | otherwise = error "partialSchedule: reindexExecPartial returned Nothing. Probably some variable is missing in 'execVars'" @@ -707,7 +757,7 @@ partialSchedule = (\(s, _, _) -> s) . travA (TupRsingle Shared) combineMod' In In = Exists' In combineMod' Out Out = Exists' Out combineMod' _ _ = Exists' Mut - travA us (C.Return vars) = (PartialReturn us vars, IdxSet.fromList $ map (\(Exists (Var _ idx)) -> Exists idx) $ flattenTupR vars, mapTupR f vars) + travA us (C.Return vars) = (PartialReturn us vars, IdxSet.fromVarList $ flattenTupR vars, mapTupR f vars) where duplicates = map head $ filter (\g -> length g >= 2) $ group $ sort $ map (\(Exists (Var _ ix)) -> idxToInt ix) $ flattenTupR vars @@ -735,14 +785,18 @@ partialSchedule = (\(s, _, _) -> s) . travA (TupRsingle Shared) (t', used1, vars1) = travA us t (f', used2, vars2) = travA us f vars = joinVars vars1 vars2 - travA _ (C.Awhile us c f vars) = (partialAwhile us c' f' vars, undefined, TupRsingle NoVars) + travA _ (C.Awhile us c f vars) = (partialAwhile us c' f' vars, used1 `IdxSet.union` used2 `IdxSet.union` IdxSet.fromVarList (flattenTupR vars), TupRsingle NoVars) where - c' = partialScheduleFun c - f' = partialScheduleFun f + (c', used1) = partialScheduleFun c + (f', used2) = partialScheduleFun f -partialScheduleFun :: C.PartitionedAfun op genv t -> PartialScheduleFun op genv t -partialScheduleFun (C.Alam lhs f) = Plam lhs $ partialScheduleFun f -partialScheduleFun (C.Abody b) = Pbody $ partialSchedule b +partialScheduleFun :: C.PartitionedAfun op genv t -> (PartialScheduleFun op genv t, IdxSet genv) +partialScheduleFun (C.Alam lhs f) = (Plam lhs f', IdxSet.drop' lhs used) + where + (f', used) = partialScheduleFun f +partialScheduleFun (C.Abody b) = (Pbody b', used) + where + (b', used) = partialSchedule b partialLift1 :: GroundsR s -> (forall fenv. ExpVars fenv t -> Binding fenv s) -> ExpVars genv t -> (PartialSchedule op genv s, IdxSet genv, MaybeVars genv s) partialLift1 tp f vars = partialLift tp (\k -> f <$> strengthenVars k vars) (expVarsList vars) @@ -762,7 +816,7 @@ strengthenVars k (TupRpair v1 v2) = TupRpair <$> strengthenVars k v1 <*> partialLift :: forall op genv s. GroundsR s -> (forall fenv. genv :?> fenv -> Maybe (Binding fenv s)) -> [Exists (GroundVar genv)] -> (PartialSchedule op genv s, IdxSet genv, MaybeVars genv s) partialLift tp f vars - | DeclareOutput outputEnv kOut varsOut <- declareOutput @() @s tp (mapTupR uniqueIfBuffer tp) + | DefineOutput outputEnv kOut varsOut <- defineOutput @() @s tp (mapTupR uniqueIfBuffer tp) , Exists env <- convertEnvReadonlyFromList $ nubBy (\(Exists v1) (Exists v2) -> isJust $ matchVar v1 v2) vars -- TODO: Remove duplicates more efficiently , Reads reEnv k inputBindings <- readRefs $ convertEnvRefs env , DeclareVars lhs k' value <- declareVars $ mapTupR BaseRground tp @@ -778,7 +832,7 @@ partialLift tp f vars $ Alet lhs binding $ Effect (SignalResolve resolvers) $ writeOutput outputEnv (varsOut (k' .> k .> convertEnvWeaken env)) (value weakenId) - , undefined + , IdxSet.fromList $ convertEnvToList env , mapTupR (const NoVars) tp ) @@ -925,9 +979,9 @@ instance Sink Future where -- can get access to the resource. -- subFutureEnvironment :: forall fenv genv op. FutureEnv fenv genv -> SyncEnv genv -> (FutureEnv fenv genv, [UniformSchedule (Cluster op) fenv]) -subFutureEnvironment (PNone fenv) senv = (PNone fenv', actions) +subFutureEnvironment (PNone fenv) (PNone senv) = (PNone fenv', actions) where - (fenv', actions) = subFutureEnvironment fenv $ partialEnvTail senv + (fenv', actions) = subFutureEnvironment fenv senv subFutureEnvironment (PPush fenv f@(FutureScalar _ _ _)) senv = (PPush fenv' f, actions) where (fenv', actions) = subFutureEnvironment fenv $ partialEnvTail senv @@ -971,6 +1025,8 @@ subFutureEnvironment (PPush fenv (FutureBuffer tp signal ref read write)) (PNone = return $ Effect (SignalResolve [rr]) Return | otherwise = [] +subFutureEnvironment PEnd _ = (PEnd, []) +subFutureEnvironment _ _ = internalError "Keys of SyncEnv are not a subset of the keys of the FutureEnv" sub :: forall fenv genv op. FutureEnv fenv genv -> SyncEnv genv -> (FutureEnv fenv genv -> UniformSchedule (Cluster op) fenv) -> UniformSchedule (Cluster op) fenv sub fenv senv body = forks (body fenv' : actions) @@ -1295,6 +1351,9 @@ chainFuture (FutureBuffer tp signal ref read mwrite) SyncWrite SyncWrite lhsSignal :: LeftHandSide BaseR (Signal, SignalResolver) fenv ((fenv, Signal), SignalResolver) lhsSignal = LeftHandSidePair (LeftHandSideSingle BaseRsignal) (LeftHandSideSingle BaseRsignalResolver) +lhsRef :: GroundR tp -> LeftHandSide BaseR (Ref tp, OutputRef tp) fenv ((fenv, Ref tp), OutputRef tp) +lhsRef tp = LeftHandSidePair (LeftHandSideSingle $ BaseRref tp) (LeftHandSideSingle $ BaseRrefWrite tp) + -- Similar to 'fromPartial', but also applies the sub-environment rule fromPartialSub :: forall op fenv genv t r. @@ -1302,40 +1361,110 @@ fromPartialSub => OutputVars t r -> BaseVars fenv r -> FutureEnv fenv genv - -> PartialSchedule (Cluster op) genv t + -> PartialSchedule op genv t -> UniformSchedule (Cluster op) fenv fromPartialSub outputEnv outputVars env partial = sub env (syncEnv partial) (\env' -> fromPartial outputEnv outputVars env' partial) -fromPartial :: forall op fenv genv t r. - HasCallStack - => OutputVars t r - -> BaseVars fenv r - -> FutureEnv fenv genv - -> PartialSchedule (Cluster op) genv t - -> UniformSchedule (Cluster op) fenv +fromPartialFun + :: forall op fenv genv t r. + HasCallStack + => FutureEnv fenv genv + -> PartialScheduleFun op genv t + -> UniformScheduleFun (Cluster op) fenv (ScheduleFunction t) +fromPartialFun env = \case + Pbody body + | grounds <- groundsR body + , Refl <- scheduleFunctionIsBody $ grounds + , DeclareOutput k1 lhs k2 instr outputEnv outputVars <- declareOutput grounds + -> Slam lhs $ Sbody $ instr $ fromPartial outputEnv (outputVars weakenId) (mapPartialEnv (weaken (k2 .> k1)) env) body + Plam lhs fun + | DeclareInput _ lhs' env' <- declareInput env lhs + -> Slam lhs' $ fromPartialFun (env' weakenId) fun + +fromPartial + :: forall op fenv genv t r. + HasCallStack + => OutputVars t r + -> BaseVars fenv r + -> FutureEnv fenv genv + -> PartialSchedule op genv t + -> UniformSchedule (Cluster op) fenv fromPartial outputEnv outputVars env = \case - PartialDo outputEnv' convertEnv schedule -> undefined -- Something with a substitution + PartialDo outputEnv' convertEnv (schedule :: UniformSchedule (Cluster op) fenv') + | Just Refl <- matchOutputVarsWithEnv outputEnv outputEnv' -> + let + kEnv = partialDoSubstituteOutput outputEnv' outputVars + kEnv' :: Env (NewIdx fenv) fenv' + kEnv' = partialDoSubstituteConvertEnv convertEnv env kEnv + + k :: ReindexPartialN Identity fenv' fenv + k idx = Identity $ prj' idx kEnv' + in + runIdentity $ reindexSchedule k schedule -- Something with a substitution + | otherwise -> internalError "OutputVars and OutputEnv do not match" PartialReturn uniquenesses vars -> travReturn vars - PartialDeclare syncEnv lhs dest uniquenesses bnd body -> undefined -- Something with fork - PartialAcond syncEnv condition true false -> acond condition true false - PartialAwhile syncEnv uniquenesses condition step vars -> undefined + PartialDeclare syncEnv lhs dest uniquenesses bnd body + | DeclareBinding k instr outputEnvBnd outputVarsBnd env' <- declareBinding outputEnv outputVars env lhs dest uniquenesses -> + instr $ Fork + (fromPartial outputEnvBnd (outputVarsBnd weakenId) (mapPartialEnv (weaken k) env) bnd) + (fromPartial outputEnv (mapTupR (weaken k) outputVars) (env' weakenId) body) + PartialAcond _ condition true false -> acond condition true false + PartialAwhile _ uniquenesses condition step initial -> awhile uniquenesses condition step initial where travReturn :: GroundVars genv t -> UniformSchedule (Cluster op) fenv travReturn vars = forks ((\(signals, s) -> await signals s) <$> travReturn' outputEnv outputVars vars []) travReturn' :: OutputVars t' r' -> BaseVars fenv r' -> GroundVars genv t' -> [([Idx fenv Signal], UniformSchedule (Cluster op) fenv)] -> [([Idx fenv Signal], UniformSchedule (Cluster op) fenv)] travReturn' (OutputVarsPair o1 o2) (TupRpair r1 r2) (TupRpair v1 v2) accum = travReturn' o1 r1 v1 $ travReturn' o2 r2 v2 accum - travReturn' OutputVarsIgnore _ _ accum = accum - travReturn' OutputVarsShared (TupRpair (TupRsingle signal) (TupRsingle ref)) (TupRsingle (Var tp ix)) accum = task : accum + travReturn' (OutputVarsScalar tp') (TupRpair (TupRsingle destSignal) (TupRsingle destRef)) (TupRsingle (Var tp ix)) accum = task : accum + where + task = case prjPartial ix env of + Nothing -> internalError "Variable not present in environment" + Just (FutureScalar _ signal ref) -> + ( [signal] + , Alet (LeftHandSideSingle $ BaseRground tp) (RefRead $ Var (BaseRref tp) ref) + $ Effect (RefWrite (weaken (weakenSucc weakenId) destRef) (Var (BaseRground tp) ZeroIdx)) + $ Effect (SignalResolve [weakenSucc weakenId >:> varIdx destSignal]) + $ Return + ) + Just FutureBuffer{} -> bufferImpossible tp' + travReturn' OutputVarsShared (TupRpair (TupRsingle destSignalRef `TupRpair` TupRsingle destSignalRead) (TupRsingle destRef)) (TupRsingle (Var tp ix)) accum = task : accum where task = case prjPartial ix env of Nothing -> internalError "Variable not present in environment" - Just (FutureScalar _ signal ref) -> ([signal], undefined) - Just (FutureBuffer _ signal ref readAccess _) -> ([signal, lockSignal readAccess], Alet (LeftHandSideSingle $ BaseRground tp) (RefRead $ Var (BaseRref tp) ref) $ Effect (RefWrite undefined undefined) $ Effect (SignalResolve undefined) $ Return) - travReturn' OutputVarsUnique (TupRpair (TupRpair (TupRsingle signalRead) (TupRsingle signalWrite)) (TupRsingle ref)) (TupRsingle v) accum = undefined : accum + Just (FutureScalar tp' _ _) -> bufferImpossible tp' + Just (FutureBuffer _ signal ref readAccess _) -> + ( [signal] + , Alet (LeftHandSideSingle $ BaseRground tp) (RefRead $ Var (BaseRref tp) ref) + $ Effect (RefWrite (weaken (weakenSucc weakenId) destRef) (Var (BaseRground tp) ZeroIdx)) + $ Effect (SignalResolve [weakenSucc weakenId >:> varIdx destSignalRef]) + $ Effect (SignalAwait [weakenSucc weakenId >:> lockSignal readAccess]) + $ Effect (SignalResolve [weakenSucc weakenId >:> varIdx destSignalRead]) + $ Return + ) + travReturn' OutputVarsUnique (TupRpair (TupRpair (TupRsingle destSignalRef `TupRpair` TupRsingle destSignalRead) (TupRsingle destSignalWrite)) (TupRsingle destRef)) (TupRsingle (Var tp ix)) accum = task : accum + where + task = case prjPartial ix env of + Nothing -> internalError "Variale not present in environment" + Just (FutureScalar tp' _ _) -> bufferImpossible tp' + Just (FutureBuffer _ _ _ _ Nothing) -> internalError "Expected FutureBuffer with write access" + Just (FutureBuffer _ signal ref readAccess (Just writeAccess)) -> + ( [signal] + , Alet (LeftHandSideSingle $ BaseRground tp) (RefRead $ Var (BaseRref tp) ref) + $ Effect (RefWrite (weaken (weakenSucc weakenId) destRef) (Var (BaseRground tp) ZeroIdx)) + $ Effect (SignalResolve [weakenSucc weakenId >:> varIdx destSignalRef]) + $ Effect (SignalAwait [weakenSucc weakenId >:> lockSignal readAccess]) + $ Effect (SignalResolve [weakenSucc weakenId >:> varIdx destSignalRead]) + $ Effect (SignalAwait [weakenSucc weakenId >:> lockSignal writeAccess]) + $ Effect (SignalResolve [weakenSucc weakenId >:> varIdx destSignalWrite]) + $ Return + ) + -- Destination was reused. No need to copy + travReturn' OutputVarsIgnore _ _ accum = accum + travReturn' _ _ _ _ = internalError "Invalid variables" - acond :: ExpVar genv PrimBool -> PartialSchedule (Cluster op) genv t -> PartialSchedule (Cluster op) genv t -> UniformSchedule (Cluster op) fenv + acond :: ExpVar genv PrimBool -> PartialSchedule op genv t -> PartialSchedule op genv t -> UniformSchedule (Cluster op) fenv acond (Var _ condition) true false = case prjPartial condition env of Just (FutureScalar _ signal ref) -> -- Wait on the signal @@ -1352,6 +1481,197 @@ fromPartial outputEnv outputVars env = \case outputVars' = mapTupR (weaken (weakenSucc weakenId)) outputVars env' = mapPartialEnv (weaken (weakenSucc weakenId)) env + awhile + :: Uniquenesses t + -> PartialScheduleFun op genv (t -> PrimBool) + -> PartialScheduleFun op genv (t -> t) + -> GroundVars genv t + -> UniformSchedule (Cluster op) fenv + awhile = fromPartialAwhile outputEnv outputVars env + +fromPartialAwhile + :: forall op fenv genv t r. + HasCallStack + => OutputVars t r + -> BaseVars fenv r + -> FutureEnv fenv genv + -> Uniquenesses t + -> PartialScheduleFun op genv (t -> PrimBool) + -> PartialScheduleFun op genv (t -> t) + -> GroundVars genv t + -> UniformSchedule (Cluster op) fenv +fromPartialAwhile outputEnv outputVars env uniquenesses (Plam lhsC (Pbody condition)) (Plam lhsS (Pbody step)) initial + | tp <- mapTupR varType initial + , AwhileInputOutput io k lhsInput env' initial' outputEnv' <- awhileInputOutput env (\k -> mapPartialEnv (weaken k) env) uniquenesses initial + = let + + + in Awhile io undefined initial' Return + +awhileInputOutput :: FutureEnv fenv0 genv0 -> (forall fenv''. fenv :> fenv'' -> FutureEnv fenv'' genv) -> Uniquenesses t -> GroundVars genv0 t -> AwhileInputOutput fenv0 fenv genv t +awhileInputOutput env0 env (TupRpair u1 u2) (TupRpair v1 v2) + | AwhileInputOutput io1 k1 lhs1 env1 i1 outputEnv1 <- awhileInputOutput env0 env u1 v1 + , AwhileInputOutput io2 k2 lhs2 env2 i2 outputEnv2 <- awhileInputOutput env0 env1 u2 v2 + = AwhileInputOutput + (InputOutputRpair io1 io2) + (k2 .> k1) + (LeftHandSidePair lhs1 lhs2) + env2 + (TupRpair i1 i2) + (OutputVarsPair outputEnv1 outputEnv2) +awhileInputOutput env0 env TupRunit TupRunit + = AwhileInputOutput + InputOutputRunit + weakenId + (LeftHandSideWildcard TupRunit) + env + TupRunit + OutputVarsIgnore +awhileInputOutput env0 env (TupRsingle uniqueness) (TupRsingle (Var groundR idx)) + | GroundRbuffer tp <- groundR -- Unique buffer + , Unique <- uniqueness + = let + initial = case prjPartial idx env0 of + Just (FutureBuffer tp signal ref (Move signalRead) (Just (Move signalWrite))) -> + TupRsingle (Var BaseRsignal signal) + `TupRpair` + TupRsingle (Var BaseRsignal signalRead) + `TupRpair` + TupRsingle (Var BaseRsignal signalWrite) + `TupRpair` + TupRsingle (Var (BaseRref $ GroundRbuffer tp) ref) + Just (FutureBuffer _ _ _ _ Nothing) -> internalError "Expected a Future with write permissions." + Just (FutureBuffer _ _ _ _ _) -> internalError "Expected Move. Cannot Borrow a variable into a loop." + Just _ -> internalError "Illegal variable" + Nothing -> internalError "Variable not found" + in + AwhileInputOutput + (InputOutputRpair (InputOutputRpair (InputOutputRpair InputOutputRsignal InputOutputRsignal) InputOutputRsignal) InputOutputRref) + -- Input + (weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc weakenId) + (LeftHandSideSingle BaseRsignal `LeftHandSidePair` LeftHandSideSingle BaseRsignal `LeftHandSidePair` LeftHandSideSingle BaseRsignal `LeftHandSidePair` LeftHandSideSingle (BaseRref groundR)) + (\k -> env (weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc k) `PPush` + FutureBuffer tp (k >:> SuccIdx (SuccIdx $ SuccIdx $ ZeroIdx)) (k >:> ZeroIdx) (Move (k >:> SuccIdx (SuccIdx ZeroIdx))) (Just $ Move (k >:> SuccIdx ZeroIdx))) + initial + -- Output + OutputVarsUnique + | GroundRbuffer tp <- groundR -- Shared buffer + = let + initial = case prjPartial idx env0 of + Just (FutureBuffer tp signal ref (Move signalRead) _) -> + TupRsingle (Var BaseRsignal signal) + `TupRpair` + TupRsingle (Var BaseRsignal signalRead) + `TupRpair` + TupRsingle (Var (BaseRref $ GroundRbuffer tp) ref) + Just (FutureBuffer _ _ _ (Borrow _ _) _) -> internalError "Expected Move. Cannot Borrow a variable into a loop." + Just _ -> internalError "Illegal variable" + Nothing -> internalError "Variable not found" + in + AwhileInputOutput + (InputOutputRpair (InputOutputRpair InputOutputRsignal InputOutputRsignal) InputOutputRref) + -- Input + (weakenSucc $ weakenSucc $ weakenSucc weakenId) + (LeftHandSideSingle BaseRsignal `LeftHandSidePair` LeftHandSideSingle BaseRsignal `LeftHandSidePair` LeftHandSideSingle (BaseRref groundR)) + (\k -> env (weakenSucc $ weakenSucc $ weakenSucc k) `PPush` + FutureBuffer tp (k >:> SuccIdx (SuccIdx ZeroIdx)) (k >:> ZeroIdx) (Move (k >:> SuccIdx ZeroIdx)) Nothing) + initial + -- Output + OutputVarsShared + | GroundRscalar tp <- groundR -- Scalar + = let + initial = case prjPartial idx env0 of + Just (FutureScalar tp signal ref) -> + TupRsingle (Var BaseRsignal signal) + `TupRpair` + TupRsingle (Var (BaseRref $ GroundRscalar tp) ref) + Just _ -> internalError "Illegal variable" + Nothing -> internalError "Variable not found" + in + AwhileInputOutput + (InputOutputRpair InputOutputRsignal InputOutputRref) + -- Input + (weakenSucc $ weakenSucc weakenId) + (LeftHandSideSingle BaseRsignal `LeftHandSidePair` LeftHandSideSingle (BaseRref groundR)) + (\k -> env (weakenSucc $ weakenSucc k) `PPush` FutureScalar tp (k >:> SuccIdx ZeroIdx) (k >:> ZeroIdx)) + initial + -- Output + (OutputVarsScalar tp) + +data AwhileInputOutput fenv0 fenv genv t where + AwhileInputOutput + :: InputOutputR input output + -- Input + -> (fenv :> fenv') + -> BLeftHandSide input fenv fenv' + -> (forall fenv''. fenv' :> fenv'' -> FutureEnv fenv'' genv') + -> BaseVars fenv0 input + -- Output + -> OutputVars t output + -> AwhileInputOutput fenv0 fenv genv t + + +{- + + Awhile :: InputOutputR input output + -> UniformScheduleFun exe env (input -> Output PrimBool -> ()) + -> UniformScheduleFun exe env (input -> output -> ()) + -> BaseVars env input + -> UniformSchedule exe env -- Operations after the while loop + -> UniformSchedule exe env +-} + +matchOutputVarsWithEnv :: OutputVars t r -> OutputEnv fenv fenv' t r' -> Maybe (r :~: r') +matchOutputVarsWithEnv (OutputVarsPair v1 v2) (OutputEnvPair e1 e2) + | Just Refl <- matchOutputVarsWithEnv v1 e1 + , Just Refl <- matchOutputVarsWithEnv v2 e2 = Just Refl +matchOutputVarsWithEnv OutputVarsShared{} OutputEnvShared{} = Just Refl +matchOutputVarsWithEnv OutputVarsUnique{} OutputEnvUnique{} = Just Refl +matchOutputVarsWithEnv OutputVarsIgnore OutputEnvUnit = Just Refl +matchOutputVarsWithEnv _ _ = Nothing + +partialDoSubstituteOutput :: forall fenv fenv' t r. OutputEnv () fenv t r -> BaseVars fenv' r -> Env (NewIdx fenv') fenv +partialDoSubstituteOutput = go Empty + where + go :: Env (NewIdx fenv') fenv1 -> OutputEnv fenv1 fenv2 t' r' -> BaseVars fenv' r' -> Env (NewIdx fenv') fenv2 + go env (OutputEnvPair o1 o2) (TupRpair v1 v2) + = go (go env o1 v1) o2 v2 + go env OutputEnvUnit TupRunit + = env + go env (OutputEnvScalar _) (TupRpair (TupRsingle v1) (TupRsingle v2)) + = env `Push` NewIdxJust (varIdx v2) `Push` NewIdxJust (varIdx v1) + go env (OutputEnvShared _) (TupRpair (TupRpair (TupRsingle v1) (TupRsingle v2)) (TupRsingle v3)) + = env `Push` NewIdxJust (varIdx v3) `Push` NewIdxJust (varIdx v2) `Push` NewIdxJust (varIdx v1) + go env (OutputEnvUnique _) (TupRpair (TupRpair (TupRpair (TupRsingle v1) (TupRsingle v2)) (TupRsingle v3)) (TupRsingle v4)) + = env `Push` NewIdxJust (varIdx v4) `Push` NewIdxJust (varIdx v3) `Push` NewIdxJust (varIdx v2) `Push` NewIdxJust (varIdx v1) + go _ _ _ = internalError "Impossible BaseVars" + +partialDoSubstituteConvertEnv :: forall genv fenv1 fenv2 fenv' t r. ConvertEnv genv fenv1 fenv2 -> FutureEnv fenv' genv -> Env (NewIdx fenv') fenv1 -> Env (NewIdx fenv') fenv2 +partialDoSubstituteConvertEnv ConvertEnvNil _ env = env +partialDoSubstituteConvertEnv (ConvertEnvSeq c1 c2) fenv env = partialDoSubstituteConvertEnv c2 fenv $ partialDoSubstituteConvertEnv c1 fenv env +partialDoSubstituteConvertEnv (ConvertEnvAcquire (Acquire m var)) fenv env + | Just (FutureBuffer _ _ _ read mWrite) <- prjPartial (varIdx var) fenv = + let + lock + | In <- m = read + | Just write <- mWrite = write + | otherwise = internalError "Requested write access to a buffer, but the FutureBuffer only has read permissions" + (signal, resolver) + | Borrow s r <- lock = (NewIdxJust s, NewIdxJust r) + | Move s <- lock = (NewIdxJust s, NewIdxNoResolver) + in + env `Push` signal `Push` resolver + | otherwise = internalError "Requested access to a buffer, but the FutureBuffer was not found in the environment" +partialDoSubstituteConvertEnv (ConvertEnvFuture var) fenv env + | Just future <- prjPartial (varIdx var) fenv = + let + (signal, ref) + | FutureScalar _ s r <- future = (s, r) + | FutureBuffer _ s r _ _ <- future = (s, r) + in + env `Push` NewIdxJust signal `Push` NewIdxJust ref + | otherwise = internalError "Requested access to a value, but the Future was not found in the environment" + forks :: [UniformSchedule (Cluster op) fenv] -> UniformSchedule (Cluster op) fenv forks [] = Return forks [u] = u @@ -1370,9 +1690,113 @@ serial = go weakenId Alet lhs bnd u' -> Alet lhs bnd $ trav (weakenWithLHS lhs .> k) u' Effect effect u' -> Effect effect $ trav k u' Acond cond true false u' -> Acond cond true false $ trav k u' - Awhile io cond step input u' -> Awhile io cond step input $ trav k u' + Awhile io f input u' -> Awhile io f input $ trav k u' Fork u' u'' -> Fork (trav k u') u'' +data DeclareInput fenv genv' t where + DeclareInput :: fenv :> fenv' + -> BLeftHandSide (Input t) fenv fenv' + -> (forall fenv''. fenv' :> fenv'' -> FutureEnv fenv'' genv') + -> DeclareInput fenv genv' t + +declareInput + :: forall t fenv genv genv'. + FutureEnv fenv genv + -> GLeftHandSide t genv genv' + -> DeclareInput fenv genv' t +declareInput = \fenv -> go weakenId (\k -> mapPartialEnv (weaken k) fenv) + where + go :: forall fenv' genv1 genv2 s. fenv :> fenv' -> (forall fenv''. fenv' :> fenv'' -> FutureEnv fenv'' genv1) -> GLeftHandSide s genv1 genv2 -> DeclareInput fenv' genv2 s + go k fenv (LeftHandSidePair lhs1 lhs2) + | DeclareInput k1 lhs1' fenv1 <- go k fenv lhs1 + , DeclareInput k2 lhs2' fenv2 <- go (k1 .> k) fenv1 lhs2 + = DeclareInput (k2 .> k1) (LeftHandSidePair lhs1' lhs2') fenv2 + go _ fenv (LeftHandSideWildcard grounds) = DeclareInput weakenId (LeftHandSideWildcard $ inputR grounds) fenv + go k fenv (LeftHandSideSingle (GroundRscalar tp)) -- Scalar + | Refl <- inputSingle $ GroundRscalar tp + = DeclareInput + (weakenSucc $ weakenSucc weakenId) + (LeftHandSideSingle BaseRsignal `LeftHandSidePair` LeftHandSideSingle (BaseRref $ GroundRscalar tp)) + (\k' -> PPush (fenv $ weakenSucc $ weakenSucc k') + $ FutureScalar + tp + (k' >:> SuccIdx ZeroIdx) + (k' >:> ZeroIdx)) + go k fenv (LeftHandSideSingle (GroundRbuffer tp)) -- Buffer + = DeclareInput + (weakenSucc $ weakenSucc weakenId) + (LeftHandSideSingle BaseRsignal `LeftHandSidePair` LeftHandSideSingle (BaseRref $ GroundRbuffer tp)) + (\k' -> PPush (fenv $ weakenSucc $ weakenSucc k') + $ FutureBuffer + tp + (k' >:> SuccIdx ZeroIdx) + (k' >:> ZeroIdx) + (Move $ (k' >:> (SuccIdx ZeroIdx))) + Nothing) + +data DeclareOutput op fenv t where + DeclareOutput :: fenv :> fenv' + -> BLeftHandSide (Output t) fenv fenv' + -> fenv' :> fenv'' + -> (UniformSchedule (Cluster op) fenv'' -> UniformSchedule (Cluster op) fenv') + -> OutputVars t r + -> (forall fenv'''. fenv'' :> fenv''' -> BaseVars fenv''' r) + -> DeclareOutput op fenv t + +data DeclareOutputInternal op fenv' t where + DeclareOutputInternal :: fenv' :> fenv'' + -> (UniformSchedule (Cluster op) fenv'' -> UniformSchedule (Cluster op) fenv') + -> OutputVars t r + -> (forall fenv'''. fenv'' :> fenv''' -> BaseVars fenv''' r) + -> DeclareOutputInternal op fenv' t + +declareOutput + :: forall op fenv t. + GroundsR t + -> DeclareOutput op fenv t +declareOutput grounds + | DeclareVars lhs k1 value <- declareVars $ outputR grounds + , DeclareOutputInternal k2 instr outputEnv outputVars <- go weakenId grounds (value weakenId) + = DeclareOutput k1 lhs k2 instr outputEnv outputVars + where + go :: fenv1 :> fenv2 -> GroundsR s -> BaseVars fenv1 (Output s) -> DeclareOutputInternal op fenv2 s + go _ TupRunit TupRunit + = DeclareOutputInternal + weakenId + id + OutputVarsIgnore + $ const TupRunit + go k (TupRpair gL gR) (TupRpair vL vR) + | DeclareOutputInternal kL instrL outL varsL' <- go k gL vL + , DeclareOutputInternal kR instrR outR varsR' <- go (kL .> k) gR vR + = DeclareOutputInternal + (kR .> kL) + (instrL . instrR) + (OutputVarsPair outL outR) + $ \k' -> varsL' (k' .> kR) `TupRpair` varsR' k' + go k (TupRsingle (GroundRbuffer tp)) (TupRsingle signal `TupRpair` TupRsingle ref) + = DeclareOutputInternal + (weakenSucc $ weakenSucc weakenId) + (Alet lhsSignal NewSignal) + OutputVarsShared + $ \k' -> + let k'' = k' .> weakenSucc' (weakenSucc' k) + in TupRsingle (Var BaseRsignalResolver $ weaken k' ZeroIdx) + `TupRpair` TupRsingle (weaken k'' signal) + `TupRpair` TupRsingle (weaken k'' ref) + go k (TupRsingle (GroundRscalar tp)) vars + | Refl <- inputSingle $ GroundRscalar tp + , TupRsingle signal `TupRpair` TupRsingle ref <- vars + = DeclareOutputInternal + weakenId + id + (OutputVarsScalar tp) + $ \k' -> + let k'' = k' .> k + in TupRsingle (weaken k'' signal) + `TupRpair` TupRsingle (weaken k'' ref) + + data DeclareBinding op fenv genv' t where DeclareBinding :: fenv :> fenv' -> (UniformSchedule (Cluster op) fenv' -> UniformSchedule (Cluster op) fenv) @@ -1397,11 +1821,201 @@ declareBinding retEnv retVars = \fenv -> go weakenId (\k -> mapPartialEnv (weake | DeclareBinding k1 instr1 out1 vars1 fenv1 <- go k fenv lhs1 dest1 u1 , DeclareBinding k2 instr2 out2 vars2 fenv2 <- go (k1 .> k) fenv1 lhs2 dest2 u2 = DeclareBinding (k2 .> k1) (instr1 . instr2) (OutputVarsPair out1 out2) (\k' -> TupRpair (vars1 $ k' .> k2) (vars2 k')) fenv2 + go k fenv (LeftHandSideWildcard _) _ _ + = DeclareBinding + weakenId + id + OutputVarsIgnore + (const TupRunit) + fenv go k fenv (LeftHandSideSingle _) (TupRsingle (DestinationReuse idx)) _ = DeclareBinding weakenId id - undefined - undefined - undefined - go k fenv (LeftHandSideSingle (GroundRscalar tp)) _ _ = undefined + OutputVarsIgnore + (const TupRunit) + (\k' -> PNone $ fenv k') + go k fenv (LeftHandSideSingle (GroundRscalar tp)) _ _ + = DeclareBinding + (weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc weakenId) + instr + (OutputVarsScalar tp) + (\k' -> TupRpair + (TupRsingle $ Var BaseRsignalResolver $ k' >:> idx2) + (TupRsingle $ Var (BaseRrefWrite $ GroundRscalar tp) $ k' >:> idx0)) + (\k' -> PPush (fenv $ weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc k') + $ FutureScalar + tp + (k' >:> idx3) + (k' >:> idx1)) + where + instr + = Alet lhsSignal NewSignal + . Alet (lhsRef $ GroundRscalar tp) (NewRef $ GroundRscalar tp) + + idx0 = ZeroIdx + idx1 = SuccIdx ZeroIdx + idx2 = SuccIdx $ SuccIdx ZeroIdx + idx3 = SuccIdx $ SuccIdx $ SuccIdx ZeroIdx + go k fenv (LeftHandSideSingle (GroundRbuffer tp)) _ (TupRsingle Unique) + = DeclareBinding + (weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc weakenId) + instr + OutputVarsUnique + (\k' -> TupRpair + ( TupRpair + ( TupRpair + (TupRsingle $ Var BaseRsignalResolver $ k' >:> idx6) + (TupRsingle $ Var BaseRsignalResolver $ k' >:> idx4) + ) + (TupRsingle $ Var BaseRsignalResolver $ k' >:> idx2) + ) + (TupRsingle $ Var (BaseRrefWrite $ GroundRbuffer tp) $ k' >:> idx0)) + (\k' -> PPush (fenv $ weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc k') + $ FutureBuffer + tp + (k' >:> idx7) + (k' >:> idx1) + (Move (k' >:> idx5)) + $ Just $ Move $ k' >:> idx3) + where + instr + = Alet lhsSignal NewSignal -- Signal to grant access to the reference (idx7, idx6) + . Alet lhsSignal NewSignal -- Signal to grant read access to the array (idx5, idx4) + . Alet lhsSignal NewSignal -- Signal to grant write access to the array (idx3, idx2) + . Alet (lhsRef $ GroundRbuffer tp) (NewRef $ GroundRbuffer tp) -- (idx1, idx0) + + idx0 = ZeroIdx + idx1 = SuccIdx ZeroIdx + idx2 = SuccIdx $ SuccIdx ZeroIdx + idx3 = SuccIdx $ SuccIdx $ SuccIdx ZeroIdx + idx4 = SuccIdx $ SuccIdx $ SuccIdx $ SuccIdx ZeroIdx + idx5 = SuccIdx $ SuccIdx $ SuccIdx $ SuccIdx $ SuccIdx ZeroIdx + idx6 = SuccIdx $ SuccIdx $ SuccIdx $ SuccIdx $ SuccIdx $ SuccIdx ZeroIdx + idx7 = SuccIdx $ SuccIdx $ SuccIdx $ SuccIdx $ SuccIdx $ SuccIdx $ SuccIdx ZeroIdx + go k fenv (LeftHandSideSingle (GroundRbuffer tp)) _ (TupRsingle Shared) + = DeclareBinding + (weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc weakenId) + instr + OutputVarsShared + (\k' -> TupRpair + ( TupRpair + (TupRsingle $ Var BaseRsignalResolver $ k' >:> idx4) + (TupRsingle $ Var BaseRsignalResolver $ k' >:> idx2) + ) + (TupRsingle $ Var (BaseRrefWrite $ GroundRbuffer tp) $ k' >:> idx0)) + (\k' -> PPush (fenv $ weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc $ weakenSucc k') + $ FutureBuffer + tp + (k' >:> idx5) + (k' >:> idx1) + (Move (k' >:> idx3)) + Nothing) + where + instr + = Alet lhsSignal NewSignal + . Alet lhsSignal NewSignal + . Alet (lhsRef $ GroundRbuffer tp) (NewRef $ GroundRbuffer tp) + + idx0 = ZeroIdx + idx1 = SuccIdx ZeroIdx + idx2 = SuccIdx $ SuccIdx ZeroIdx + idx3 = SuccIdx $ SuccIdx $ SuccIdx ZeroIdx + idx4 = SuccIdx $ SuccIdx $ SuccIdx $ SuccIdx ZeroIdx + idx5 = SuccIdx $ SuccIdx $ SuccIdx $ SuccIdx $ SuccIdx ZeroIdx + +type ReindexPartialN f env env' = forall a. Idx env a -> f (NewIdx env' a) + +data NewIdx env t where + NewIdxNoResolver :: NewIdx env SignalResolver + NewIdxJust :: Idx env t -> NewIdx env t + +data SunkReindexPartialN f env env' where + Sink :: SunkReindexPartialN f env env' -> SunkReindexPartialN f (env, s) (env', s) + ReindexF :: ReindexPartialN f env env' -> SunkReindexPartialN f env env' + + +reindexSchedule :: (IsExecutableAcc exe, Applicative f) => ReindexPartialN f env env' -> UniformSchedule exe env -> f (UniformSchedule exe env') +reindexSchedule k = reindexSchedule' $ ReindexF k + +sinkReindexWithLHS :: LeftHandSide s t env1 env1' -> LeftHandSide s t env2 env2' -> SunkReindexPartialN f env1 env2 -> SunkReindexPartialN f env1' env2' +sinkReindexWithLHS (LeftHandSideWildcard _) (LeftHandSideWildcard _) k = k +sinkReindexWithLHS (LeftHandSideSingle _) (LeftHandSideSingle _) k = Sink k +sinkReindexWithLHS (LeftHandSidePair a1 b1) (LeftHandSidePair a2 b2) k = sinkReindexWithLHS b1 b2 $ sinkReindexWithLHS a1 a2 k +sinkReindexWithLHS _ _ _ = error "sinkReindexWithLHS: left hand sides don't match" + +reindex' :: Applicative f => SunkReindexPartialN f env env' -> ReindexPartialN f env env' +reindex' (ReindexF f) = f +reindex' (Sink k) = \case + ZeroIdx -> pure $ NewIdxJust ZeroIdx + SuccIdx ix -> + let + f NewIdxNoResolver = NewIdxNoResolver + f (NewIdxJust ix') = NewIdxJust $ SuccIdx ix' + in + f <$> reindex' k ix + +reindexSchedule' :: (IsExecutableAcc exe, Applicative f) => SunkReindexPartialN f env env' -> UniformSchedule exe env -> f (UniformSchedule exe env') +reindexSchedule' k = \case + Return -> pure Return + Alet lhs bnd s + | Exists lhs' <- rebuildLHS lhs -> Alet lhs' <$> reindexBinding' k bnd <*> reindexSchedule' (sinkReindexWithLHS lhs lhs' k) s + Effect effect s -> Effect <$> reindexEffect' k effect <*> reindexSchedule' k s + Acond cond t f continue -> Acond <$> reindexVarUnsafe k cond <*> reindexSchedule' k t <*> reindexSchedule' k f <*> reindexSchedule' k continue + Awhile io f init continue -> Awhile io <$> reindexScheduleFun' k f <*> traverseTupR (reindexVarUnsafe k) init <*> reindexSchedule' k continue + Fork s1 s2 -> Fork <$> reindexSchedule' k s1 <*> reindexSchedule' k s2 + +reindexVarUnsafe :: Applicative f => SunkReindexPartialN f env env' -> Var s env t -> f (Var s env' t) +reindexVarUnsafe k (Var tp idx) = Var tp . fromNewIdxUnsafe <$> reindex' k idx + +reindexScheduleFun' :: (IsExecutableAcc exe, Applicative f) => SunkReindexPartialN f env env' -> UniformScheduleFun exe env t -> f (UniformScheduleFun exe env' t) +reindexScheduleFun' k = \case + Sbody s -> Sbody <$> reindexSchedule' k s + Slam lhs f + | Exists lhs' <- rebuildLHS lhs -> Slam lhs' <$> reindexScheduleFun' (sinkReindexWithLHS lhs lhs' k) f + +reindexEffect' :: forall exe f env env'. (IsExecutableAcc exe, Applicative f) => SunkReindexPartialN f env env' -> Effect exe env -> f (Effect exe env') +reindexEffect' k = \case + Exec exe -> Exec <$> reindexExecPartial (fromNewIdxUnsafe <.> reindex' k) exe + SignalAwait signals -> SignalAwait <$> traverse (fromNewIdxSignal <.> reindex' k) signals + SignalResolve resolvers -> SignalResolve . mapMaybe toMaybe <$> traverse (reindex' k) resolvers + RefWrite ref value -> RefWrite <$> reindexVar (fromNewIdxOutputRef <.> reindex' k) ref <*> reindexVar (fromNewIdxUnsafe <.> reindex' k) value + where + toMaybe :: NewIdx env' a -> Maybe (Idx env' a) + toMaybe (NewIdxJust idx) = Just idx + toMaybe _ = Nothing + +-- For Exec we cannot have a safe function from the conversion, +-- as we cannot enforce in the type system that no SignalResolvers +-- occur in an Exec or Compute. +fromNewIdxUnsafe :: NewIdx env' a -> Idx env' a +fromNewIdxUnsafe (NewIdxJust idx) = idx +fromNewIdxUnsafe _ = error "Expected NewIdxJust" + +-- Different versions, which have different ways of getting evidence +-- that NewIdxNoResolver is impossible +fromNewIdxSignal :: NewIdx env' Signal -> Idx env' Signal +fromNewIdxSignal (NewIdxJust idx) = idx + +fromNewIdxOutputRef :: NewIdx env' (OutputRef t) -> Idx env' (OutputRef t) +fromNewIdxOutputRef (NewIdxJust idx) = idx + +fromNewIdxRef :: NewIdx env' (Ref t) -> Idx env' (Ref t) +fromNewIdxRef (NewIdxJust idx) = idx + +fromNewIdxGround :: GroundR a -> NewIdx env' a -> Idx env' a +fromNewIdxGround _ (NewIdxJust idx) = idx +fromNewIdxGround tp NewIdxNoResolver = signalResolverImpossible (TupRsingle tp) + +reindexBinding' :: Applicative f => SunkReindexPartialN f env env' -> Binding env t -> f (Binding env' t) +reindexBinding' k = \case + Compute e -> Compute <$> reindexExp (fromNewIdxUnsafe <.> reindex' k) e + NewSignal -> pure NewSignal + NewRef tp -> pure $ NewRef tp + Alloc shr tp sh -> Alloc shr tp <$> reindexVars (fromNewIdxUnsafe <.> reindex' k) sh + Use tp buffer -> pure $ Use tp buffer + Unit (Var tp idx) -> Unit . Var tp <$> (fromNewIdxGround (GroundRscalar tp) <.> reindex' k) idx + RefRead ref -> RefRead <$> reindexVar (fromNewIdxUnsafe <.> reindex' k) ref + +(<.>) :: Applicative f => (b -> c) -> (a -> f b) -> a -> f c +(<.>) g h a = g <$> h a