Skip to content

Commit

Permalink
Fix some bugs in record type inference.
Browse files Browse the repository at this point in the history
In the previous implementation, it was possible for inference to emit
RecordSplit on values that weren't yet known to be records (well,
inference knew it, but hadn't constrained it to be so). This was causing
type errors for RecordSplit when `getType` was called before inference
solved for the missing type variables.

The fix is to constrain the type to be a record (with unknown fields)
and then zonk the expression before emitting a RecordSplit. This has
the (surprising!) side effect of actually greatly simplifying the
pattern match logic; it turns out that most of that was just duplicating
work that unification had to do anyway.

This fix exposed a second problem with row inference, which was that
I forgot to include the base case for unifying two things that can't
possibly be equal! Now, if we try to unify anything other than a type
variable with an empty row, we throw a type error immediately.
  • Loading branch information
danieldjohnson committed Aug 15, 2020
1 parent 24a3c8a commit 099cd16
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 50 deletions.
5 changes: 4 additions & 1 deletion examples/record-variant-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@ def getTwoFoosAndABar (rest : Types)?->
:p
({b=b, a=a1, a=a2}) = {a=1, b=2}
(a1, a2, b)
> Type error:Labels in record pattern do not match record type. Expected structure {a: Int64 & b: Int64}
> Type error:
> Expected: {a: a & a: b & b: c}
> Actual: {a: Int64 & b: Int64}
> (Solving for: [a:Type, b:Type, c:Type])
>
> ({b=b, a=a1, a=a2}) = {a=1, b=2}
> ^^^^^^^^^^^^^^^^^
Expand Down
68 changes: 25 additions & 43 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import Data.Foldable (fold, toList, asum)
import Data.Functor
import qualified Data.List.NonEmpty as NE
import qualified Data.Map.Strict as M
import qualified Data.Map.Merge.Strict as MMerge
import Data.String (fromString)
import Data.Text.Prettyprint.Doc

Expand Down Expand Up @@ -230,12 +229,13 @@ checkOrInferRho (WithSrc pos expr) reqTy =
matchRequirement $ Record items'
URecord (Ext items (Just ext)) -> do
items' <- mapM inferRho items
ext' <- inferRho ext
restTy <- freshInferenceName LabeledRowKind
ext' <- zonk =<< (checkRho ext $ RecordTy $ Ext NoLabeledItems $ Just restTy)
matchRequirement =<< emit (RecordCons items' ext')
UVariant labels@(LabeledItems lmap) label value -> do
value' <- inferRho value
prevTys <- mapM (const $ freshType TyKind) labels
Var (rest:>_) <- freshType LabeledRowKind
rest <- freshInferenceName LabeledRowKind
let items = prevTys <> labeledSingleton label (getType value')
let extItems = Ext items $ Just rest
let i = case M.lookup label lmap of
Expand All @@ -245,7 +245,7 @@ checkOrInferRho (WithSrc pos expr) reqTy =
URecordTy row -> matchRequirement =<< RecordTy <$> checkExtLabeledRow row
UVariantTy row -> matchRequirement =<< VariantTy <$> checkExtLabeledRow row
UVariantLift labels value -> do
Var (row:>_) <- freshType LabeledRowKind
row <- freshInferenceName LabeledRowKind
value' <- checkRho value $ VariantTy $ Ext NoLabeledItems $ Just row
prev <- mapM (\() -> freshType TyKind) labels
matchRequirement =<< emit (VariantLift prev value')
Expand Down Expand Up @@ -308,25 +308,15 @@ unpackTopPat letAnn (WithSrc _ pat) expr = case pat of
++ pprint (RecordTy $ NoExt types)
xs <- zonk expr >>= emitUnpack
zipWithM_ (\p x -> unpackTopPat letAnn p (Atom x)) (toList items) xs
UPatRecord (Ext pats@(LabeledItems patItems) (Just tailPat)) -> do
-- Unpacks at the top level should always be monomorphic in type.
RecordTy (Ext (LabeledItems types) Nothing) <- pure $ getType expr
-- Note: length items /= length types in general; items is what the user
-- wants but types is what we know.
-- First, split off the types the user wants.
let leftOnly = MMerge.traverseMissing $ \k _ ->
throw TypeErr $ "Label " <> show k <> " in record pattern does not "
<> "exist in record to be matched."
let rightOnly = MMerge.dropMissing
let both = MMerge.zipWithAMatched $ \k wanted have ->
if length wanted > length have
then throw TypeErr $ "Label " <> show k <> " in record pattern "
<> "appears too many times."
else return $ NE.fromList $ NE.take (length wanted) have
wantedTypes <- MMerge.mergeA leftOnly rightOnly both patItems types
UPatRecord (Ext pats (Just tailPat)) -> do
wantedTypes <- lift $ mapM (const $ freshType TyKind) pats
restType <- lift $ freshInferenceName LabeledRowKind
let vty = getType expr
lift $ constrainEq (RecordTy $ Ext wantedTypes $ Just restType) vty
-- Split the record.
wantedTypes' <- lift $ zonk wantedTypes
val <- emit =<< zonk expr
split <- emit $ RecordSplit (LabeledItems wantedTypes) val
split <- emit $ RecordSplit wantedTypes' val
[left, right] <- getUnpacked split
leftVals <- getUnpacked left
zipWithM_ (\p x -> unpackTopPat letAnn p (Atom x)) (toList pats) leftVals
Expand Down Expand Up @@ -502,30 +492,20 @@ bindPat' (WithSrc pos pat) val = addSrcContext (Just pos) $ case pat of
lift $ constrainEq (TypeCon def params) (getType val)
xs <- lift $ zonk (Atom val) >>= emitUnpack
fold <$> zipWithM bindPat' (toList ps) xs
UPatRecord (Ext items Nothing) -> do
RecordTy (NoExt types) <- pure $ getType val
when (fmap (const ()) items /= fmap (const ()) types) $ throw TypeErr $
"Labels in record pattern do not match record type. Expected structure "
++ pprint (RecordTy $ NoExt types)
UPatRecord (Ext pats Nothing) -> do
expectedTypes <- lift $ mapM (const $ freshType TyKind) pats
lift $ constrainEq (RecordTy (NoExt expectedTypes)) (getType val)
xs <- lift $ zonk (Atom val) >>= emitUnpack
fold <$> zipWithM bindPat' (toList items) xs
UPatRecord (Ext pats@(LabeledItems patItems) (Just tailPat)) -> do
RecordTy (Ext (LabeledItems types) _) <- pure $ getType val
-- Note: length items /= length types in general; items is what the user
-- wants but types is what we know.
-- First, split off the types the user wants.
let leftOnly = MMerge.traverseMissing $ \k _ ->
throw TypeErr $ "Label " <> show k <> " in record pattern does not "
<> "exist in record to be matched."
let rightOnly = MMerge.dropMissing
let both = MMerge.zipWithAMatched $ \k wanted have ->
if length wanted > length have
then throw TypeErr $ "Label " <> show k <> " in record pattern "
<> "appears too many times."
else return $ NE.fromList $ NE.take (length wanted) have
wantedTypes <- MMerge.mergeA leftOnly rightOnly both patItems types
fold <$> zipWithM bindPat' (toList pats) xs
UPatRecord (Ext pats (Just tailPat)) -> do
wantedTypes <- lift $ mapM (const $ freshType TyKind) pats
restType <- lift $ freshInferenceName LabeledRowKind
let vty = getType val
lift $ constrainEq (RecordTy $ Ext wantedTypes $ Just restType) vty
-- Split the record.
split <- lift $ emit $ RecordSplit (LabeledItems wantedTypes) val
wantedTypes' <- lift $ zonk wantedTypes
val' <- lift $ zonk val
split <- lift $ emit $ RecordSplit wantedTypes' val'
[left, right] <- lift $ getUnpacked split
leftVals <- lift $ getUnpacked left
env1 <- fold <$> zipWithM bindPat' (toList pats) leftVals
Expand Down Expand Up @@ -843,6 +823,8 @@ unifyExtLabeledItems r1 r2 = do
bindQ (v:>LabeledRowKind) (LabeledRow r)
(Ext NoLabeledItems (Just v), r) | v `isin` vs ->
bindQ (v:>LabeledRowKind) (LabeledRow r)
(_, Ext NoLabeledItems _) -> throw TypeErr ""
(Ext NoLabeledItems _, _) -> throw TypeErr ""
(Ext (LabeledItems items1) t1, Ext (LabeledItems items2) t2) -> do
let unifyPrefixes tys1 tys2 = mapM (uncurry unify) $ NE.zip tys1 tys2
sequence_ $ M.intersectionWith unifyPrefixes items1 items2
Expand Down
1 change: 1 addition & 0 deletions src/lib/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ leafExpr = parens (mayPair $ makeExprParser leafExpr ops)
containedExpr :: Parser UExpr
containedExpr = parens (mayPair $ makeExprParser leafExpr ops)
<|> uVarOcc
<|> uLabeledExprs
<?> "contained expression"

uType :: Parser UType
Expand Down
29 changes: 23 additions & 6 deletions src/lib/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -172,21 +172,38 @@ instance HasType Expr where
return resultTy
RecordCons items record -> do
types <- mapM typeCheck items
RecordTy rest <- typeCheck record
rty <- typeCheck record
rest <- case rty of
RecordTy rest -> return rest
_ -> throw TypeErr $ "Can't add fields to a non-record object "
<> pprint record <> " (of type " <> pprint rty <> ")"
return $ RecordTy $ joinExtLabeledItems types rest
RecordSplit types record -> do
mapM_ (|: TyKind) types
RecordTy full <- typeCheck record
fullty <- typeCheck record
full <- case fullty of
RecordTy full -> return full
_ -> throw TypeErr $ "Can't split a non-record object " <> pprint record
<> " (of type " <> pprint fullty <> ")"
diff <- labeledRowDifference full (NoExt types)
return $ RecordTy $ NoExt $
Unlabeled [ RecordTy $ NoExt types, RecordTy diff ]
VariantLift types record -> do
VariantLift types variant -> do
mapM_ (|: TyKind) types
VariantTy rest <- typeCheck record
rty <- typeCheck variant
rest <- case rty of
VariantTy rest -> return rest
_ -> throw TypeErr $ "Can't add alternatives to a non-variant object "
<> pprint variant <> " (of type " <> pprint rty <> ")"
return $ VariantTy $ joinExtLabeledItems types rest
VariantSplit types variant -> do
mapM_ (|: TyKind) types
VariantTy full <- typeCheck variant
fullty <- typeCheck variant
full <- case fullty of
VariantTy full -> return full
_ -> throw TypeErr $ "Can't split a non-variant object "
<> pprint variant <> " (of type " <> pprint fullty
<> ")"
diff <- labeledRowDifference full (NoExt types)
return $ VariantTy $ NoExt $
Unlabeled [ VariantTy $ NoExt types, VariantTy diff ]
Expand Down Expand Up @@ -469,7 +486,7 @@ labeledRowDifference (Ext (LabeledItems items) rest)
case M.lookup label items of
Just types -> assertEq subtypes
(NE.fromList $ NE.take (length subtypes) types) $
"Row types for label " ++ show LabeledRowKind
"Row types for label " ++ show label
Nothing -> throw TypeErr $ "Extracting missing label " ++ show label
-- Extract remaining types from the left.
let
Expand Down

0 comments on commit 099cd16

Please sign in to comment.