forked from AccelerateHS/accelerate
/
Type.hs
173 lines (148 loc) · 5.87 KB
/
Type.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module : Data.Array.Accelerate.Representation.Type
-- Copyright : [2008..2020] The Accelerate Team
-- License : BSD3
--
-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability : experimental
-- Portability : non-portable (GHC extensions)
--
module Data.Array.Accelerate.Representation.Type
where
import Data.Array.Accelerate.Type
import Data.Primitive.Vec
import Data.Type.Equality
import Language.Haskell.TH
-- | Both arrays (Acc) and expressions (Exp) are represented as nested
-- pairs consisting of:
--
--   * unit (void)
--
--   * pairs: representing compound values (i.e. tuples) where each
--     component will be stored in a separate array.
--
--   * single array / scalar types; in the case of expressions these are
--     values which go in registers. These may be single value types such
--     as int and float, or SIMD vectors of single value types such as
--     @\<4 * float\>@. We do not allow vectors-of-vectors.
--
data TupR s a where
  -- | The empty tuple, @()@.
  TupRunit   :: TupR s ()
  -- | A leaf: one element described by the representation @s@.
  TupRsingle :: s t -> TupR s t
  -- | Two nested representations combined into a pair.
  TupRpair   :: TupR s l -> TupR s r -> TupR s (l, r)
-- | Render a type representation with the same nesting as the tuple it
-- describes: unit is @()@, a pair is @(l,r)@ with no spaces, and a leaf
-- uses the 'Show' instance of its 'ScalarType'.
instance Show (TupR ScalarType a) where
  show tup = case tup of
    TupRunit     -> "()"
    TupRsingle t -> show t
    TupRpair l r -> concat ["(", show l, ",", show r, ")"]
-- | A 'TupR' whose leaves are 'ScalarType's: the nested-pair description
-- of a scalar (expression) type.
--
-- Written without a type parameter so the synonym may be used unsaturated.
type TypeR = TupR ScalarType
-- | Distributes a type constructor over the elements of a tuple: unit is
-- left unchanged, a pair is mapped componentwise, and any other type @a@
-- becomes @f a@.
--
-- Because this is a closed type family, the final equation only fires once
-- GHC knows the argument is neither @()@ nor a pair. For an abstract
-- element type that evidence is supplied by the 'Distributes' class.
--
-- TODO: Could we make this type family injective? Then we wouldn't need
-- the type class Distributes any more.
--
-- Note that we must use a standard, lazy pair here, as we rely on
-- laziness in type alias Buffers to make host-device copies lazy.
--
type family Distribute f a = b where
  Distribute g ()     = ()
  Distribute g (x, y) = (Distribute g x, Distribute g y)
  Distribute g x      = g x
-- | Element representations whose described types are never unit or a
-- pair, so that 'Distribute' reduces to its final equation on them.
class Distributes r where
  -- | Evidence that a single element isn't unit or a pair: distributing
  -- @g@ over it just yields @g x@.
  reprIsSingle :: r x -> Distribute g x :~: g x
-- | Every scalar type is a vector or a concrete numeric type, so none of
-- them is unit or a pair.
instance Distributes ScalarType where
  reprIsSingle (VectorScalarType _)                  = Refl
  reprIsSingle (SingleScalarType (NumSingleType nt)) = numIsSingle nt
    where
      -- 'Distribute' only reduces once the element type is concrete, so
      -- each numeric constructor has to be matched to expose its type.
      numIsSingle :: NumType t -> Distribute f t :~: f t
      numIsSingle (IntegralNumType it) = case it of
        TypeInt    -> Refl
        TypeInt8   -> Refl
        TypeInt16  -> Refl
        TypeInt32  -> Refl
        TypeInt64  -> Refl
        TypeWord   -> Refl
        TypeWord8  -> Refl
        TypeWord16 -> Refl
        TypeWord32 -> Refl
        TypeWord64 -> Refl
      numIsSingle (FloatingNumType ft) = case ft of
        TypeHalf   -> Refl
        TypeFloat  -> Refl
        TypeDouble -> Refl
-- | Force a tuple representation completely, using the supplied function
-- to force each leaf. The left component of a pair is forced before the
-- right.
rnfTupR :: (forall b. s b -> ()) -> TupR s a -> ()
rnfTupR rnfElt tup = case tup of
  TupRunit     -> ()
  TupRsingle x -> rnfElt x
  TupRpair l r -> rnfTupR rnfElt l `seq` rnfTupR rnfElt r
-- | Fully force a scalar type representation, forcing each leaf with
-- 'rnfScalarType'.
rnfTypeR :: TypeR t -> ()
rnfTypeR ty = rnfTupR rnfScalarType ty
-- | Lift a tuple representation into a typed Template Haskell expression
-- that rebuilds it. The supplied function lifts each leaf.
liftTupR :: (forall b. s b -> Q (TExp (s b))) -> TupR s a -> Q (TExp (TupR s a))
liftTupR liftElt tup = case tup of
  TupRunit     -> [|| TupRunit ||]
  TupRsingle x -> [|| TupRsingle $$(liftElt x) ||]
  TupRpair l r -> [|| TupRpair $$(liftTupR liftElt l) $$(liftTupR liftElt r) ||]
-- | Lift a scalar type representation into a typed Template Haskell
-- expression that rebuilds it.
--
-- This is the generic 'liftTupR' specialised to 'liftScalarType' leaves, in
-- the same way that 'rnfTypeR' is 'rnfTupR' specialised to 'rnfScalarType'.
-- The previous hand-written traversal duplicated 'liftTupR' equation for
-- equation.
liftTypeR :: TypeR t -> Q (TExp (TypeR t))
liftTypeR = liftTupR liftScalarType
-- | Build the Haskell type described by a type representation, as a
-- Template Haskell type. Unit and pairs map to @()@ and @(,)@. A vector
-- of @n@ elements maps to @Vec n e@, with @n@ as a type-level literal.
-- Numeric leaves map to the matching Haskell type.
liftTypeQ :: TypeR t -> TypeQ
liftTypeQ ty = case ty of
    TupRunit     -> [t| () |]
    TupRsingle s -> scalarQ s
    TupRpair l r -> [t| ($(liftTypeQ l), $(liftTypeQ r)) |]
  where
    scalarQ :: ScalarType s -> TypeQ
    scalarQ (SingleScalarType st)                = singleQ st
    scalarQ (VectorScalarType (VectorType n st)) =
      [t| Vec $(litT (numTyLit (toInteger n))) $(singleQ st) |]

    singleQ :: SingleType s -> TypeQ
    singleQ (NumSingleType (IntegralNumType it)) = integralQ it
    singleQ (NumSingleType (FloatingNumType ft)) = floatingQ ft

    integralQ :: IntegralType s -> TypeQ
    integralQ it = case it of
      TypeInt    -> [t| Int |]
      TypeInt8   -> [t| Int8 |]
      TypeInt16  -> [t| Int16 |]
      TypeInt32  -> [t| Int32 |]
      TypeInt64  -> [t| Int64 |]
      TypeWord   -> [t| Word |]
      TypeWord8  -> [t| Word8 |]
      TypeWord16 -> [t| Word16 |]
      TypeWord32 -> [t| Word32 |]
      TypeWord64 -> [t| Word64 |]

    floatingQ :: FloatingType s -> TypeQ
    floatingQ ft = case ft of
      TypeHalf   -> [t| Half |]
      TypeFloat  -> [t| Float |]
      TypeDouble -> [t| Double |]
-- Generate the type synonyms Tup2 .. Tup16, where
--   TupN x0 .. x(N-1) = ((((), x0), x1), ..., x(N-1))
-- i.e. the left-nested pair representation built by folding over the
-- element variables starting from ().
runQ $
  let
    mkT :: Int -> Q Dec
    mkT n = tySynD (mkName ("Tup" ++ show n)) (map plainTV names) nested
      where
        names  = [ mkName ('x' : show i) | i <- [0 .. n - 1] ]
        nested = foldl (\acc v -> [t| ($acc, $v) |]) [t| () |] (map varT names)
  in
  traverse mkT [2 .. 16]
-- | Change the leaf representation of a tuple representation while keeping
-- its unit/pair structure unchanged.
mapTupR :: (forall s. a s -> b s) -> TupR a t -> TupR b t
mapTupR g tup = case tup of
  TupRunit     -> TupRunit
  TupRsingle x -> TupRsingle (g x)
  TupRpair l r -> TupRpair (mapTupR g l) (mapTupR g r)
-- | Effectful version of 'mapTupR'. Leaves are visited left to right, and
-- the results are reassembled with the original structure.
traverseTupR :: Applicative f => (forall s. a s -> f (b s)) -> TupR a t -> f (TupR b t)
traverseTupR g tup = case tup of
  TupRunit     -> pure TupRunit
  TupRsingle x -> fmap TupRsingle (g x)
  TupRpair l r -> TupRpair <$> traverseTupR g l <*> traverseTupR g r
-- | A type representation can never describe a function type, so this can
-- never be called with a value. Every candidate pattern is refuted:
--   * unit, pairs and vectors cannot have a function type, and
--   * no integral or floating type constructor has one either, hence the
--     empty cases.
functionImpossible :: TypeR (s -> t) -> a
functionImpossible (TupRsingle (SingleScalarType (NumSingleType (IntegralNumType it)))) = case it of {}
functionImpossible (TupRsingle (SingleScalarType (NumSingleType (FloatingNumType ft)))) = case ft of {}