Skip to content

Commit

Permalink
Replace built-in Either with a prelude-defined ADT version.
Browse files Browse the repository at this point in the history
We still don't have index set instances for ADTs, but we can use variants added
in #180 for sum types as index sets.

Also add unification rules for records and variants.
  • Loading branch information
dougalm committed Jul 31, 2020
1 parent 60a98dd commit 78d6ebd
Show file tree
Hide file tree
Showing 14 changed files with 28 additions and 158 deletions.
16 changes: 8 additions & 8 deletions examples/eval-tests.dx
Expand Up @@ -516,7 +516,7 @@ litArr = [10, 5, 3]
c
> ([3.0, 2.0, 1.0, 0.0], 4.0)

def eitherFloor (x:(Int|Real)) : Int = oldcase x
def eitherFloor (x:(Int|Real)) : Int = case x of
Left i -> i
Right f -> floor f

Expand All @@ -526,24 +526,24 @@ def eitherFloor (x:(Int|Real)) : Int = oldcase x
:p
x : (Int|Real) = Right 1.2
x
> (Right 1.2)
> Right 1.2

-- Sum types as flattened index sets!
def Weights (n:Type) (m:Type) : Type = n=>m=>Real
def Biases (n:Type) : Type = n=>Real
def Params (n:Type) (m:Type) : Type = ((n&m)|n)=>Real
def Params (n:Type) (m:Type) : Type = {weights:(n&m) | biases:n}=>Real

w = for i:(Fin 2). for j:(Fin 3). (i2r (ordinal i)) * 3.0 + (i2r (ordinal j))
b = for j:(Fin 2). neg (i2r (ordinal j) + 1.0)

def flatten ((w,b):(Weights n m & Biases n)): Params n m =
for idx. oldcase idx
Left (i,j) -> w.i.j
Right i -> b.i
for idx. case idx of
{| weights = (i,j) |} -> w.i.j
{| biases = i |} -> b.i

def unflatten (params:Params n m) : (Weights n m & Biases n) =
( for i. for j. params.(Left (i,j))
, for i. params.(Right i) )
( for i. for j. params.({| weights = (i,j)|})
, for i. params.({| biases = i |}))

:p unflatten (flatten (w, b)) == (w, b)
> True
Expand Down
2 changes: 1 addition & 1 deletion examples/uexpr-tests.dx
Expand Up @@ -238,7 +238,7 @@ def myOtherFst ((x, _):(a&b)) : a = x

id'' : b -> b = id

def eitherFloor (x:(Int|Real)) : Int = oldcase x
def eitherFloor (x:(Int|Real)) : Int = case x of
Left i -> i
Right f -> floor f

Expand Down
22 changes: 3 additions & 19 deletions prelude.dx
Expand Up @@ -99,15 +99,9 @@ def isNothing (x:Maybe a) : Bool = case x of
Nothing -> True
Just _ -> False

def (|) (a:Type) (b:Type) : Type = %SumType a b
def anyVal (a:Type) ?-> : a = %anyVal a
def sumCon (isLeft:Bool) (l:a) (r:b) : (a|b) =
isLeft' = unsafeCoerce InternalBool isLeft
%sumCon isLeft' l r

def Left (x:a) : (a|b) = sumCon True x anyVal
def Right (x:b) : (a|b) = sumCon False anyVal x
def caseAnalysis (x:(a|b)) (l:a->c) (r:b->c) : c = %caseAnalysis x l r
data (|) a:Type b:Type =
Left a
Right b

def select (p:Bool) (x:a) (y:a) : a = case p of
True -> x
Expand Down Expand Up @@ -190,16 +184,6 @@ def pairOrd (ordA: Ord a)?=> (ordB: Ord b)?=> : Ord (a & b) =
pairLt = \(x1,x2) (y1,y2). x1 < y1 || (x1 == y1 && x2 < y2)
MkOrd pairEq pairGt pairLt

@instance
def sumEq (eqA: Eq a)?=> (eqB: Eq b)?=> : Eq (a | b) = MkEq $
\x y. oldcase x
Left xVal -> oldcase y
Left yVal -> xVal == yVal
Right yVal -> False
Right xVal -> oldcase y
Left yVal -> False
Right yVal -> xVal == yVal

-- TODO: accumulate using the True/&& monoid
@instance
def tabEq (n:Type) ?-> (eqA: Eq a) ?=> : Eq (n=>a) = MkEq $
Expand Down
1 change: 0 additions & 1 deletion src/lib/Algebra.hs
Expand Up @@ -87,7 +87,6 @@ indexSetSize (TC con) = case con of
ExclusiveLim h -> toPolynomial h
Unlimited -> undefined
PairType l r -> mulC (indexSetSize l) (indexSetSize r)
SumType l r -> add (indexSetSize l) (indexSetSize r)
_ -> error $ "Not implemented " ++ pprint con
indexSetSize (RecordTy types) = let
sizes = map indexSetSize (F.toList types)
Expand Down
30 changes: 1 addition & 29 deletions src/lib/Embed.hs
Expand Up @@ -24,7 +24,7 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP
SubstEmbedT, SubstEmbed, runSubstEmbedT, dceBlock, dceModule,
TraversalDef, traverseDecls, traverseBlock, traverseExpr,
traverseAtom, arrOffset, arrLoad, evalBlockE, substTraversalDef,
sumTag, getLeft, getRight, fromSum, clampPositive, buildNAbs,
clampPositive, buildNAbs,
indexSetSizeE, indexToIntE, intToIndexE, anyValue) where

import Control.Applicative
Expand Down Expand Up @@ -295,21 +295,6 @@ unpackConsList xs = case getType xs of
(x, rest) <- fromPair xs
liftM (x:) $ unpackConsList rest

sumTag :: MonadEmbed m => Atom -> m Atom
sumTag (SumVal t _ _) = return t
sumTag s = emitOp $ SumTag s

getLeft :: MonadEmbed m => Atom -> m Atom
getLeft (SumVal _ l _) = return l
getLeft s = emitOp $ SumGet s True

getRight :: MonadEmbed m => Atom -> m Atom
getRight (SumVal _ _ r) = return r
getRight s = emitOp $ SumGet s False

fromSum :: MonadEmbed m => Atom -> m (Atom, Atom, Atom)
fromSum s = (,,) <$> sumTag s <*> getLeft s <*> getRight s

emitRunWriter :: MonadEmbed m => Name -> Type -> (Atom -> m Atom) -> m Atom
emitRunWriter v ty body = do
eff <- getAllowedEffects
Expand Down Expand Up @@ -711,7 +696,6 @@ indexSetSizeE (TC con) = case con of
Unlimited -> indexSetSizeE n
clampPositive =<< high' `isub` low'
PairType a b -> bindM2 imul (indexSetSizeE a) (indexSetSizeE b)
SumType l r -> bindM2 iadd (indexSetSizeE l) (indexSetSizeE r)
_ -> error $ "Not implemented " ++ pprint con
where
indexSetSizeE (RecordTy types) = do
Expand All @@ -735,12 +719,6 @@ indexToIntE :: MonadEmbed m => Type -> Atom -> m Atom
indexToIntE ty idx = case ty of
UnitTy -> return $ IntVal 0
BoolTy -> boolToInt idx
SumTy lType rType -> do
(tag, lVal, rVal) <- fromSum idx
lTypeSize <- indexSetSizeE lType
lInt <- indexToIntE lType lVal
rInt <- iadd lTypeSize =<< indexToIntE rType rVal
select tag lInt rInt
PairTy lType rType -> do
(lVal, rVal) <- fromPair idx
lIdx <- indexToIntE lType lVal
Expand Down Expand Up @@ -781,12 +759,6 @@ intToIndexE ty@(TC con) i = case con of
iA <- intToIndexE a =<< idiv i bSize
iB <- intToIndexE b =<< irem i bSize
return $ PairVal iA iB
SumType l r -> do
lSize <- indexSetSizeE l
isLeft <- i `ilt` lSize
li <- intToIndexE l i
ri <- intToIndexE r =<< i `isub` lSize
return $ Con $ SumCon isLeft li ri
_ -> error $ "Unexpected type " ++ pprint con
where iAsIdx = return $ Con $ Coerce ty i
intToIndexE (RecordTy types) i = do
Expand Down
8 changes: 2 additions & 6 deletions src/lib/Imp.hs
Expand Up @@ -141,10 +141,8 @@ toImpOp (maybeDest, op) = case op of
ithDest <- destGet dest =<< intToIndex (binderType b) (IIntVal i)
copyAtom ithDest row
destToAtom dest
SumGet ~(SumVal _ l r) left -> returnVal $ if left then l else r
SumTag ~(SumVal t _ _) -> returnVal t
Fst ~(PairVal x _) -> returnVal x
Snd ~(PairVal _ y) -> returnVal y
Fst ~(PairVal x _) -> returnVal x
Snd ~(PairVal _ y) -> returnVal y
PrimEffect ~(Var refVar) m -> do
refDest <- looks $ (! refVar) . fst . fst
case m of
Expand Down Expand Up @@ -453,7 +451,6 @@ splitDest (maybeDest, (Block decls ans)) = do
gatherVarDests (Dest rd) rr
(UnitCon , UnitCon ) -> return ()
(Coerce _ db , Coerce _ rb ) -> gatherVarDests (Dest db) rb
(_ , SumCon _ _ _ ) -> error "Not implemented"
_ -> unreachable
_ -> unreachable
where
Expand Down Expand Up @@ -626,7 +623,6 @@ zipWithDest dest@(Dest destAtom) atom f = case (destAtom, atom) of
(PairCon ld rd, PairCon la ra) -> rec (Dest ld) la >> rec (Dest rd) ra
(UnitCon , UnitCon ) -> return ()
(Coerce _ d , Coerce _ a ) -> rec (Dest d) a
(SumCon _ _ _ , SumCon _ _ _ ) -> error "Not implemented"
(SumAsProd _ tag xs, SumAsProd _ tag' xs') -> do
recDest tag tag'
zipWithM_ (zipWithM_ recDest) xs xs'
Expand Down
8 changes: 8 additions & 0 deletions src/lib/Inference.hs
Expand Up @@ -713,13 +713,21 @@ unify t1 t2 = do
when (void arr /= void arr') $ throw TypeErr ""
unify resultTy resultTy'
unifyEff (arrowEff arr) (arrowEff arr')
(RecordTy items, RecordTy items') -> unifyLabeledItems items items'
(VariantTy items, VariantTy items') -> unifyLabeledItems items items'
(TypeCon f xs, TypeCon f' xs')
| f == f' && length xs == length xs' -> zipWithM_ unify xs xs'
(TC con, TC con') | void con == void con' ->
zipWithM_ unify (toList con) (toList con')
(Eff eff, Eff eff') -> unifyEff eff eff'
_ -> throw TypeErr ""

unifyLabeledItems :: (MonadCat SolverEnv m, MonadError Err m)
=> LabeledItems Type -> LabeledItems Type -> m ()
unifyLabeledItems m m' = do
when (reflectLabels m /= reflectLabels m') $ throw TypeErr ""
zipWithM_ unify (toList m) (toList m')

unifyEff :: (MonadCat SolverEnv m, MonadError Err m)
=> EffectRow -> EffectRow -> m ()
unifyEff r1 r2 = do
Expand Down
4 changes: 0 additions & 4 deletions src/lib/Interpreter.hs
Expand Up @@ -102,8 +102,6 @@ evalOpDefined expr = case expr of
_ -> evalEmbed (indexToIntE (getType idxArg) idxArg)
Fst p -> x where (PairVal x _) = p
Snd p -> y where (PairVal _ y) = p
SumTag s -> t where (SumVal t _ _) = s
SumGet s left -> if left then l else r where (SumVal _ l r) = s
_ -> error $ "Not implemented: " ++ pprint expr

indices :: Type -> [Atom]
Expand All @@ -113,8 +111,6 @@ indices ty = case ty of
TC (IndexRange _ _ _) -> fmap (Con . Coerce ty . IntVal) [0..n - 1]
TC (PairType lt rt) -> [PairVal l r | l <- indices lt, r <- indices rt]
TC (UnitType) -> [UnitVal]
TC (SumType lt rt) -> fmap (\l -> SumVal (BoolVal True) l (Con (AnyValue rt))) (indices lt) ++
fmap (\r -> SumVal (BoolVal False) (Con (AnyValue lt)) r) (indices rt)
RecordTy types -> let
subindices = map indices (toList types)
products = foldl (\prevs curs -> [cur:prev | cur <- curs, prev <- prevs]) [[]] subindices
Expand Down
10 changes: 0 additions & 10 deletions src/lib/PPrint.hs
Expand Up @@ -175,7 +175,6 @@ instance PrettyPrec e => PrettyPrec (PrimTC e) where
BaseType b -> prettyPrec b
ArrayType ty -> atPrec ArgPrec $ "Arr[" <> pLowest ty <> "]"
PairType a b -> atPrec ArgPrec $ parens $ pApp a <+> "&" <+> pApp b
SumType a b -> atPrec ArgPrec $ parens $ pApp a <+> "|" <+> pApp b
UnitType -> atPrec ArgPrec "Unit"
IntRange a b -> if asStr (pArg a) == "0"
then atPrec AppPrec ("Fin" <+> pArg b)
Expand Down Expand Up @@ -203,19 +202,13 @@ instance PrettyPrec e => PrettyPrec (PrimCon e) where
RefCon _ _ -> atPrec ArgPrec "RefCon"
Coerce t i -> atPrec LowestPrec $ pApp i <> "@" <> pApp t
AnyValue t -> atPrec AppPrec $ pAppArg "%anyVal" [t]
SumCon c l r -> atPrec AppPrec $ case asStr (pArg c) of
"True" -> pAppArg "Left" [l]
"False" -> pAppArg "Right" [r]
_ -> pAppArg "SumCon" [c, l, r]
SumAsProd ty tag payload -> atPrec LowestPrec $
"SumAsProd" <+> pApp ty <+> pApp tag <+> pApp payload
ClassDictHole _ -> atPrec ArgPrec "_"

instance PrettyPrec e => Pretty (PrimOp e) where pretty = prettyFromPrettyPrec
instance PrettyPrec e => PrettyPrec (PrimOp e) where
prettyPrec op = case op of
SumGet e isLeft -> atPrec AppPrec $ (if isLeft then "getLeft" else "getRight") <+> pArg e
SumTag e -> atPrec AppPrec $ pAppArg "getTag" [e]
PrimEffect ref (MPut val ) -> atPrec LowestPrec $ pApp ref <+> ":=" <+> pApp val
PrimEffect ref (MTell val) -> atPrec LowestPrec $ pApp ref <+> "+=" <+> pApp val
ArrayOffset arr idx off -> atPrec LowestPrec $ pApp arr <+> "+>" <+> pApp off <+> (parens $ "index:" <+> pLowest idx)
Expand All @@ -226,9 +219,6 @@ instance PrettyPrec e => Pretty (PrimHof e) where pretty = prettyFromPrettyPrec
instance PrettyPrec e => PrettyPrec (PrimHof e) where
prettyPrec hof = case hof of
For dir lam -> atPrec LowestPrec $ dirStr dir <+> pArg lam
SumCase c l r -> atPrec LowestPrec $ "case" <+> pArg c <> hardline
<> nest 2 (pLowest l) <> hardline
<> nest 2 (pLowest r)
_ -> prettyExprDefault $ HofExpr hof

instance Pretty a => Pretty (VarP a) where
Expand Down
29 changes: 3 additions & 26 deletions src/lib/Parser.hs
Expand Up @@ -184,7 +184,6 @@ leafExpr = parens (mayPair $ makeExprParser leafExpr ops)
<|> uLamExpr
<|> uForExpr
<|> caseExpr
<|> uCaseExpr
<|> uPrim
<|> unitCon
<|> uLabeledExprs
Expand Down Expand Up @@ -260,7 +259,7 @@ dataDef = do

-- TODO: default to `Type` if unannoted
tyConDef :: Parser UConDef
tyConDef = UConDef <$> upperName <*> manyNested annBinder
tyConDef = UConDef <$> (upperName <|> symName) <*> manyNested annBinder

-- TODO: dependent types
dataConDef :: Parser UConDef
Expand Down Expand Up @@ -443,27 +442,6 @@ leafPat =
patOps :: [[Operator Parser UPat]]
patOps = [[InfixR $ sym "," $> \x y -> joinSrc x y $ UPatPair x y]]

uCaseExpr :: Parser UExpr
uCaseExpr = do
((), pos) <- withPos $ keyWord OldCaseKW
e <- expr
withIndent $ do
l <- lexeme (string "Left") >> caseLam
nextLine
r <- lexeme (string "Right") >> caseLam
return $ applyNamed pos "caseAnalysis" [e, l, r]

caseLam :: Parser UExpr
caseLam = do
p <- pat
sym "->"
body <- blockOrExpr
return $ WithSrc (srcPos body) $ ULam (p, Nothing) (PlainArrow ()) body

applyNamed :: SrcPos -> String -> [UExpr] -> UExpr
applyNamed pos name args = foldl mkApp f args
where f = WithSrc pos $ UVar (Name SourceName (fromString name) 0:>())

annot :: Parser a -> Parser a
annot p = label "type annotation" $ sym ":" >> p

Expand Down Expand Up @@ -660,7 +638,7 @@ mkName s = Name SourceName (fromString s) 0
type Lexer = Parser

data KeyWord = DefKW | ForKW | RofKW | CaseKW | OfKW
| ReadKW | WriteKW | StateKW | OldCaseKW | DataKW | WhereKW
| ReadKW | WriteKW | StateKW | DataKW | WhereKW

upperName :: Lexer Name
upperName = liftM mkName $ label "upper-case name" $ lexeme $
Expand Down Expand Up @@ -690,11 +668,10 @@ keyWord kw = lexeme $ try $ string s >> notFollowedBy nameTailChar
StateKW -> "State"
DataKW -> "data"
WhereKW -> "where"
OldCaseKW -> "oldcase"

keyWordStrs :: [String]
keyWordStrs = ["def", "for", "rof", "case", "of", "llam",
"Read", "Write", "Accum", "oldcase", "data", "where"]
"Read", "Write", "Accum", "data", "where"]

primName :: Lexer String
primName = lexeme $ try $ char '%' >> some alphaNumChar
Expand Down
1 change: 0 additions & 1 deletion src/lib/Serialize.hs
Expand Up @@ -227,7 +227,6 @@ flattenType (TC con) = case con of
BaseType b -> [BaseTy b]
IntRange _ _ -> [IntTy]
IndexRange _ _ _ -> [IntTy]
SumType _ _ -> undefined
_ -> error $ "Unexpected type: " ++ show con
flattenType ty = error $ "Unexpected type: " ++ show ty

Expand Down
17 changes: 0 additions & 17 deletions src/lib/Simplify.hs
Expand Up @@ -92,8 +92,6 @@ simplifyAtom atom = case atom of
Pi _ -> substEmbedR atom
Con (AnyValue (TabTy v b)) -> TabValA v <$> mkAny b
Con (AnyValue (PairTy a b))-> PairVal <$> mkAny a <*> mkAny b
Con (AnyValue (SumTy l r)) -> do
Con <$> (SumCon <$> mkAny (TC $ BaseType $ Scalar BoolType) <*> mkAny l <*> mkAny r)
Con con -> Con <$> mapM simplifyAtom con
TC tc -> TC <$> mapM substEmbedR tc
Eff eff -> Eff <$> substEmbedR eff
Expand Down Expand Up @@ -192,8 +190,6 @@ simplifyOp :: Op -> SimplifyM Atom
simplifyOp op = case op of
Fst (PairVal x _) -> return x
Snd (PairVal _ y) -> return y
SumGet (SumVal _ l r) left -> return $ if left then l else r
SumTag (SumVal s _ _) -> return $ s
_ -> emitOp op

simplifyHof :: Hof -> SimplifyM Atom
Expand Down Expand Up @@ -236,19 +232,6 @@ simplifyHof hof = case hof of
(ans, sOut) <- fromPair =<< (emit $ Hof $ RunState s' lam')
ans' <- applyRecon recon ans
return $ PairVal ans' sOut
SumCase c lBody rBody -> do
~(lBody', Nothing) <- simplifyLam lBody
~(rBody', Nothing) <- simplifyLam rBody
l <- projApp lBody' True
r <- projApp rBody' False
isLeft <- simplRec $ Op $ SumTag c
emitOp $ Select isLeft l r
where
simplRec :: Expr -> SimplifyM Atom
simplRec = dropSub . simplifyExpr
projApp f isLeft = do
cComp <- simplRec $ Op $ SumGet c isLeft
simplRec $ App f cComp

dropSub :: SimplifyM a -> SimplifyM a
dropSub m = local mempty m

0 comments on commit 78d6ebd

Please sign in to comment.