Permalink
Browse files

more flexible generalise function

  • Loading branch information...
imccoy committed Aug 14, 2017
1 parent 274a63b commit b98bcd8fe5d97db56faa68128817e0685c40a130
Showing with 46 additions and 34 deletions.
  1. +1 −0 babytc.cabal
  2. +45 −34 src/Lib.hs
View
@@ -21,6 +21,7 @@ library
, containers
, either
, mtl
, recursion-schemes
, text
, unification-fd
default-extensions: DeriveFunctor
View
@@ -23,11 +23,13 @@ import qualified Data.Set as Set
import Data.Set (Set)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Functor.Fixedpoint (Fix(..), hmapM, cata)
import Data.Functor.Fixedpoint (Fix(..), hmapM, cata, cataM)
import Data.Foldable (toList)
import Debug.Trace
import FoldableUterm
data ExprF f = Lam Text f
| App f f
| Var Text
@@ -99,6 +101,9 @@ data AnnExprF ann f = AnnExprF ann (ExprF f)
type AnnExpr ann = Fix (AnnExprF ann)
annFromAnnExpr :: AnnExpr a -> a
annFromAnnExpr (Fix (AnnExprF ann _)) = ann
type TyVardExpr v = AnnExpr v
type TypedExpr v = AnnExpr (Ty v)
@@ -203,8 +208,6 @@ instance (Show (f (ForallTyF f))) => Show (ForallTyF f)
where app_prec = 10
--deriving instance (Show (f (ForallTyF f))) => Show (ForallTyF f)
type ForallTy = ForallTyF TyF
@@ -220,39 +223,47 @@ deriving instance (Eq ForallTy, Eq e) => Eq (ForallTydExprF e)
type ForallTydExpr = Fix ForallTydExprF
para :: Functor f => (f (Fix f, a) -> a) -> Fix f -> a
para f = snd . cata (\v -> (Fix (fst <$> v), f v))
paraM :: (Traversable f, Functor f, Monad m) => (f (Fix f, a) -> m a) -> Fix f -> m a
paraM f = fmap snd . cataM (\v -> do v' <- f v
pure (Fix (fst <$> v), v'))
generalise :: forall t v e m em. (BindingMonad t v m, t ~ TyF, Show v, Variable v, MonadError e (em m), MonadTrans em, TinyFallible v e) => TypedExpr v -> em m ForallTydExpr
generalise = goTop Map.empty
where goTop :: Map Int Text -> TypedExpr v -> em m ForallTydExpr
goTop env node@(Fix (AnnExprF ty expr)) = do vars <- collectNonLetFreeVarIDs node
let newVarIDs = Set.filter (`Map.notMember` env) vars
let usedNames = Map.elems env
let candidateNames = filter (`List.notElem` usedNames) . fmap (T.pack . ("t" ++) . show) $ [0..]
let newEnvElems = zip (Set.toList newVarIDs) candidateNames
let newEnv = Map.union (Map.fromList newEnvElems) env
(Fix (ForallTydExprF (ForallTyd ty' expr'))) <- go newEnv node
pure $ Fix $ ForallTydExprF $ ForallTyd (Forall (snd <$> newEnvElems) ty') expr'
collectNonLetFreeVarIDs :: TypedExpr v -> em m (Set Int)
collectNonLetFreeVarIDs (Fix (AnnExprF ty expr)) = do fvs <- Set.fromList <$> (fmap getVarID <$> lift (getFreeVars ty))
Set.union fvs <$> case expr of
Let _ body -> collectNonLetFreeVarIDs body
_ -> Set.unions <$> mapM collectNonLetFreeVarIDs (exprChildNodes expr)
go :: Map Int Text -> TypedExpr v -> em m (ForallTydExpr)
go env (Fix (AnnExprF ty (Let bindings body))) = do ty' <- go' env ty
bindings' <- forM bindings $ \(name, exp) -> (name,) <$> goTop env exp
body' <- go env body
pure . Fix . ForallTydExprF $ ForallTyd ty' (Let bindings' body')
go env (Fix (AnnExprF ty exp)) = do ty' <- go' env ty
(Fix . ForallTydExprF . ForallTyd ty') <$> exprMap1 (go env) exp
go' :: Map Int Text -> Ty v -> em m ForallTy
go' env (UTerm (LamTy argTy resTy)) = HereTy <$> (LamTy <$> go' env argTy <*> go' env resTy)
go' env (UTerm TextTy) = pure . HereTy $ TextTy
go' env (UTerm NumTy) = pure . HereTy $ NumTy
go' env (UVar v) | Just n <- Map.lookup (getVarID v) env = pure . TyVar $ n
| otherwise = throwError (undefinedTyVar v)
generalise expr = withLowestForalls <$> annotatedWithTyVarIds expr >>= \(f, _) -> f Map.empty
where withLowestForalls :: Fix (AnnExprF (Ty v, Set Int)) -> (Map Int Text -> em m ForallTydExpr, Set Int)
withLowestForalls = cata $ \(AnnExprF (ty, tyVarIds) expr) -> (\tyEnv -> do let childTyVarSets = snd <$> toList expr
let childTyVarMaps = Map.fromSet (const (1 :: Int)) <$> childTyVarSets
let counts = Map.unionsWith (+) childTyVarMaps
let singleAppearances = Map.filter (== 1) counts
hereVars <- lift $ fmap getVarID <$> getFreeVars ty
let forallOnes = List.filter (`Map.member` singleAppearances) hereVars
let forallManies = Map.keys . Map.filter (>1) $ counts
let forallIds = List.filter (`Map.notMember` tyEnv) $ forallOnes ++ forallManies
let usedNames = Set.fromList . Map.elems $ tyEnv
let names = filter (`Set.notMember` usedNames) ["t" `T.append` (T.pack . show $ n) | n <- [(0::Int)..]]
let newNameIds = zip forallIds names
let newNames = Map.fromList newNameIds
let newTyEnv = newNames `Map.union` tyEnv
expr' <- traverse (\(f, _) -> f newTyEnv) expr
ty' <- forallTy newTyEnv ty
if null newNameIds
then pure $ Fix $ ForallTydExprF $ ForallTyd ty' expr'
else pure $ Fix $ ForallTydExprF $ ForallTyd (Forall (snd <$> newNameIds) ty') expr'
,tyVarIds)
annotatedWithTyVarIds :: TypedExpr v -> em m (Fix (AnnExprF (Ty v, Set Int)))
annotatedWithTyVarIds = cataM $ \(AnnExprF ty expr) -> do hereVars <- lift $ getFreeVars ty
let hereVarIds = Set.fromList (getVarID <$> hereVars)
let thereIds = snd . annFromAnnExpr <$> toList expr
pure . Fix $ AnnExprF (ty, Set.unions $ hereVarIds:thereIds) expr
forallTy :: Map Int Text -> Ty v -> em m ForallTy
forallTy env = cataM alg . refix . buildUTermF
where alg (UTermF t) = pure $ HereTy t
alg (UVarF v) | Just n <- Map.lookup (getVarID v) env = pure . TyVar $ n
| otherwise = throwError (undefinedTyVar v)
initialEnv :: [(Text, Ty v)]
initialEnv = [("+", lamTy numTy $ lamTy numTy $ numTy)

0 comments on commit b98bcd8

Please sign in to comment.