Permalink
Browse files

working forall introduction

  • Loading branch information...
imccoy committed Jul 25, 2017
1 parent bda2bb4 commit 8c0f7779918461a28fea42f94c62ab69bbe85cb8
Showing with 62 additions and 30 deletions.
  1. +2 −0 app/Main.hs
  2. +60 −30 src/Lib.hs
@@ -6,6 +6,8 @@ code0 = app (app (var "+") (number 1)) (number 2)
code1 = app (var "inc") (number 1)
code2 = lett [("id", lam "x" (var "x"))]
$ app (app (var "+") (app (var "id") (number 1))) (number 2)
code3 = lett [("f", lam "n" (app (app (var "+") (app (var "f") (var "n"))) (number 2)))] (app (var "f") (number 2))
code4 = lett [("f", lam "n" (app (app (var "+") (app (var "f") (number 1))) (var "n")))] (app (var "f") (number 2))

main :: IO ()
main = putStrLn . show . evalTypecheck $ code0
@@ -7,7 +7,7 @@ module Lib
, runTypecheck, evalTypecheck, evalGeneralisedTypecheck, typecheck
) where

import Control.Monad ((<=<), void)
import Control.Monad ((<=<), void, forM)
import Control.Monad.Error.Class (MonadError, throwError)
import Control.Monad.Identity (runIdentity, Identity)
import Control.Monad.Trans (MonadTrans, lift)
@@ -19,6 +19,8 @@ import qualified Data.List as List
import qualified Data.Map as Map
import Data.Map (Map)
import Data.Maybe (fromMaybe)
import qualified Data.Set as Set
import Data.Set (Set)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Functor.Fixedpoint (Fix(..))
@@ -29,6 +31,8 @@ import Debug.Trace
data Node t = Node t (Expr t)
deriving (Functor, Traversable, Show, Eq)

nodeDecoration (Node d _) = d

data Expr t = Lam Text (Node t)
| App (Node t) (Node t)
| Var Text
@@ -49,6 +53,15 @@ exprMap1 _ (Var v) = pure $ Var v
exprMap1 _ (Number n) = pure $ Number n
exprMap1 _ (Text t) = pure $ Text t

exprChildNodes :: Expr t -> [Node t]
exprChildNodes (Lam _ lamBody) = [lamBody]
exprChildNodes (App lamExp argExp) = [lamExp, argExp]
exprChildNodes (Let bindings bound) = bound:(map snd bindings)
exprChildNodes _ = []

childNodes :: Node t -> [Node t]
childNodes (Node _ e) = exprChildNodes e

lam argName body = Node () $ Lam argName body
app fun arg = Node () $ App fun arg
var text = Node () $ Var text
@@ -128,8 +141,11 @@ constrain tyEnv (Node tyVar expr) = do ty' <- go expr
funExpTy <- constrain tyEnv funExp
unify (UTerm $ LamTy argTy funBodyTy) funExpTy
pure funBodyTy
go (Let bindings bodyExp) = do bindingTys <- traverse (\(name, exp) -> ((name,) . (NeedsFreshening,)) <$> constrain tyEnv exp) bindings
constrain (Map.union (Map.fromList bindingTys) tyEnv) bodyExp
go (Let bindings bodyExp) = do bindingTys <- sequence . flip Map.mapWithKey (Map.fromList bindings) $ \name exp -> do
stubVar <- UVar <$> lift freeVar
realVar <- constrain (Map.insert name (SoFreshAlready, stubVar) tyEnv) exp
(NeedsFreshening,) <$> (stubVar =:= realVar)
constrain (Map.union bindingTys tyEnv) bodyExp

go (Var text)
| Just (NeedsFreshening, varTy) <- Map.lookup text tyEnv = freshen varTy
@@ -138,44 +154,58 @@ constrain tyEnv (Node tyVar expr) = do ty' <- go expr
go (Number _) = pure numTy
go (Text _) = pure textTy

data ForallTy = Forall [Text] (TyF ForallTy)
| TyVar Text
| HereTy (TyF ForallTy)
deriving (Show)
data ForallTyF f = Forall [Text] (ForallTyF f)
| TyVar Text
| HereTy (f (ForallTyF f))

instance (Show (f (ForallTyF f))) => Show (ForallTyF f)
where showsPrec d (TyVar t) = (T.unpack t ++)
showsPrec d (HereTy t) = showsPrec d t
showsPrec d (Forall ns ty) = showParen (d > app_prec) $
("Forall [" ++ ) .
(List.intercalate "," (T.unpack <$> ns) ++) .
("] " ++ ) .
showsPrec d ty
where app_prec = 10


commonTyVars :: (BindingMonad t v m, t ~ TyF, Variable v) => Map Int Text -> Ty v -> Ty v -> m [(Int, Text)]
commonTyVars env ty1 ty2 = do ty1FreeVars <- getFreeVars ty1
ty2FreeVars <- getFreeVars ty2
let common = List.intersect (getVarID <$> ty1FreeVars) (getVarID <$> ty2FreeVars) List.\\ (Map.keys env)
let usedNames = Map.elems env
let names = filter (`List.notElem` usedNames) . fmap (T.pack . ("t" ++) . show) $ [0..]
pure $ zip common names
--deriving instance (Show (f (ForallTyF f))) => Show (ForallTyF f)

type ForallTy = ForallTyF TyF

generalise :: forall t v e m em. (BindingMonad t v m, t ~ TyF, Show v, Variable v, MonadError e (em m), MonadTrans em, TinyFallible v e) => Node (Ty v) -> em m (Node ForallTy)
generalise = go Map.empty
where go :: Map Int Text -> Node (Ty v) -> em m (Node ForallTy)
go env (Node (UTerm (LamTy argTy resTy)) exp) = do (newEnv, lamTy) <- generaliseLamTy env argTy resTy
exp' <- exprMap1 (go newEnv) exp
pure $ Node lamTy exp'
go env (Node ty exp) = Node <$> go' env ty <*> exprMap1 (go env) exp
generalise = goTop Map.empty
where goTop :: Map Int Text -> Node (Ty v) -> em m (Node ForallTy)
goTop env node@(Node ty exp) = do vars <- collectNonLetFreeVarIDs node
let newVarIDs = Set.filter (`Map.notMember` env) vars
let usedNames = Map.elems env
let candidateNames = filter (`List.notElem` usedNames) . fmap (T.pack . ("t" ++) . show) $ [0..]
let newEnvElems = zip (Set.toList newVarIDs) candidateNames
let newEnv = Map.union (Map.fromList newEnvElems) env
(Node ty' exp') <- go newEnv node
pure $ Node (Forall (snd <$> newEnvElems) ty') exp'

collectNonLetFreeVarIDs :: Node (Ty v) -> em m (Set Int)
collectNonLetFreeVarIDs (Node ty expr) = do fvs <- Set.fromList <$> (fmap getVarID <$> lift (getFreeVars ty))
Set.union fvs <$> case expr of
Let _ body -> collectNonLetFreeVarIDs body
_ -> Set.unions <$> mapM collectNonLetFreeVarIDs (exprChildNodes expr)

go :: Map Int Text -> Node (Ty v) -> em m (Node ForallTy)
go env (Node ty (Let bindings body)) = do ty' <- go' env ty
bindings' <- forM bindings $ \(name, exp) -> (name,) <$> goTop env exp
body' <- go env body
pure $ Node ty' (Let bindings' body')
go env (Node ty exp) = do ty' <- go' env ty
Node ty' <$> exprMap1 (go env) exp

go' :: Map Int Text -> Ty v -> em m ForallTy
go' env (UTerm (LamTy argTy resTy)) = snd <$> generaliseLamTy env argTy resTy
go' env (UTerm (LamTy argTy resTy)) = HereTy <$> (LamTy <$> go' env argTy <*> go' env resTy)
go' env (UTerm TextTy) = pure . HereTy $ TextTy
go' env (UTerm NumTy) = pure . HereTy $ NumTy
go' env (UVar v) | Just n <- Map.lookup (getVarID v) env = pure . TyVar $ n
| otherwise = throwError (undefinedTyVar v)

generaliseLamTy env argTy resTy = do newTyVars <- Map.fromList <$> (lift $ commonTyVars env argTy resTy)
let newEnv = Map.union newTyVars env
ty <- LamTy <$> go' newEnv argTy <*> go' newEnv resTy
pure (newEnv, Forall (Map.elems newTyVars) ty)




initialEnv :: [(Text, Ty v)]
initialEnv = [("+", lamTy numTy $ lamTy numTy $ numTy)
,("inc", lamTy numTy numTy)

0 comments on commit 8c0f777

Please sign in to comment.