Permalink
Browse files

Extracted TH Utils module to th-extras package

  • Loading branch information...
1 parent c577541 commit 0399812f2d2ba397010f2576e0acafed58b0d7c3 @mokus0 committed Jan 13, 2012
Showing with 94 additions and 73 deletions.
  1. +2 −2 dependent-sum-template.cabal
  2. +23 −6 src/Data/GADT/Compare/TH.hs
  3. +10 −3 src/Data/GADT/Show/TH.hs
  4. +0 −61 src/Language/Haskell/TH/Utils.hs
  5. +59 −1 src/Test.hs
@@ -18,7 +18,7 @@ Library
hs-source-dirs: src
exposed-modules: Data.GADT.Compare.TH
Data.GADT.Show.TH
- other-modules: Language.Haskell.TH.Utils
build-depends: base >= 3 && <5,
dependent-sum,
- template-haskell
+ template-haskell,
+ th-extras
@@ -1,3 +1,4 @@
+{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TemplateHaskell #-}
@@ -12,7 +13,7 @@ import Control.Applicative
import Control.Monad
import Data.GADT.Compare
import Language.Haskell.TH
-import Language.Haskell.TH.Utils
+import Language.Haskell.TH.Extras
-- A type class purely for overloading purposes
class DeriveGEQ t where
@@ -39,6 +40,14 @@ instance DeriveGEQ Dec where
where
inst = instanceD (cxt (map return dataCxt)) (appT (conT ''GEq) (conT name)) [geqDec]
geqDec = geqFunction bndrs cons
+#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 612
+ deriveGEq (DataInstD dataCxt name tyArgs cons _) = return <$> inst
+ where
+ inst = instanceD (cxt (map return dataCxt)) (appT (conT ''GEq) (foldl1 appT (map return $ (ConT name : init tyArgs)))) [geqDec]
+ -- TODO: figure out proper number of family parameters vs instance parameters
+ bndrs = [PlainTV v | VarT v <- tail tyArgs ]
+ geqDec = geqFunction bndrs cons
+#endif
instance DeriveGEQ t => DeriveGEQ [t] where
deriveGEq [it] = deriveGEq it
@@ -56,7 +65,7 @@ geqFunction bndrs cons = funD 'geq
geqClause bndrs con = do
let argTypes = argTypesOfCon con
- needsGEq argType = any (`occursInType` argType) (bndrs ++ varsBoundInCon con)
+ needsGEq argType = any ((`occursInType` argType) . nameOfBinder) (bndrs ++ varsBoundInCon con)
nArgs = length argTypes
lArgNames <- replicateM nArgs (newName "x")
@@ -86,10 +95,10 @@ instance Monad (GComparing a b) where
GComparing (Right x) >>= f = f x
geq' :: GCompare t => t a -> t b -> GComparing x y (a := b)
-geq' x y = GComparing $ case gcompare x y of
+geq' x y = GComparing (case gcompare x y of
GLT -> Left GLT
GEQ -> Right Refl
- GGT -> Left GGT
+ GGT -> Left GGT)
compare' x y = GComparing $ case compare x y of
LT -> Left GLT
@@ -122,6 +131,14 @@ instance DeriveGCompare Dec where
where
inst = instanceD (cxt (map return dataCxt)) (appT (conT ''GCompare) (conT name)) [gcompareDec]
gcompareDec = gcompareFunction bndrs cons
+#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 612
+ deriveGCompare (DataInstD dataCxt name tyArgs cons _) = return <$> inst
+ where
+ inst = instanceD (cxt (map return dataCxt)) (appT (conT ''GCompare) (foldl1 appT (map return $ (ConT name : init tyArgs)))) [gcompareDec]
+ -- TODO: figure out proper number of family parameters vs instance parameters
+ bndrs = [PlainTV v | VarT v <- tail tyArgs ]
+ gcompareDec = gcompareFunction bndrs cons
+#endif
instance DeriveGCompare t => DeriveGCompare [t] where
deriveGCompare [it] = deriveGCompare it
@@ -131,7 +148,7 @@ instance DeriveGCompare t => DeriveGCompare (Q t) where
deriveGCompare = (>>= deriveGCompare)
gcompareFunction boundVars cons
- | null cons = funD 'gcompare [clause [bangP wildP, bangP wildP] (normalB [| undefined |]) []]
+ | null cons = funD 'gcompare [clause [] (normalB [| \x y -> seq x (seq y undefined) |]) []]
| otherwise = funD 'gcompare (concatMap gcompareClauses cons)
where
-- for every constructor, first check for equality (recursively comparing
@@ -143,7 +160,7 @@ gcompareFunction boundVars cons
, clause [wildP, recP conName []] (normalB [| GGT |]) []
] where conName = nameOfCon con
- needsGCompare argType con = any (`occursInType` argType) (boundVars ++ varsBoundInCon con)
+ needsGCompare argType con = any ((`occursInType` argType) . nameOfBinder) (boundVars ++ varsBoundInCon con)
-- main clause; using the 'GComparing' monad, compare all arguments to the
-- constructor recursively, attempting to unify type variables by recursive
View
@@ -1,4 +1,4 @@
-{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE CPP, TemplateHaskell #-}
module Data.GADT.Show.TH
( DeriveGShow(..)
) where
@@ -8,7 +8,7 @@ import Control.Monad
import Data.GADT.Show
import Data.List
import Language.Haskell.TH
-import Language.Haskell.TH.Utils
+import Language.Haskell.TH.Extras
class DeriveGShow t where
deriveGShow :: t -> Q [Dec]
@@ -34,6 +34,13 @@ instance DeriveGShow Dec where
where
inst = instanceD (cxt (map return dataCxt)) (appT (conT ''GShow) (conT name)) [gshowDec]
gshowDec = gshowFunction cons
+#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 612
+ deriveGShow (DataInstD dataCxt name tyArgs cons _) = return <$> inst
+ where
+ inst = instanceD (cxt (map return dataCxt)) (appT (conT ''GShow) (foldl1 appT (map return $ (ConT name : init tyArgs)))) [gshowDec]
+ -- TODO: figure out proper number of family parameters vs instance parameters
+ gshowDec = gshowFunction cons
+#endif
instance DeriveGShow t => DeriveGShow [t] where
deriveGShow [it] = deriveGShow it
@@ -60,7 +67,7 @@ showsName name = [| showString $(litE . stringL $ nameBase name) |]
gshowBody prec conName [] = showsName conName
gshowBody prec conName argNames =
- [| showParen ($prec > 10) $( compose $ intersperse [| showChar ' ' |]
+ [| showParen ($prec > 10) $( composeExprs $ intersperse [| showChar ' ' |]
( showsName conName
: [ [| showsPrec 11 $arg |]
| argName <- argNames, let arg = varE argName
@@ -1,61 +0,0 @@
-{-# LANGUAGE TemplateHaskell #-}
-module Language.Haskell.TH.Utils where
-
-import Language.Haskell.TH
-import Language.Haskell.TH.Syntax
-
-punt :: Show a => a -> ExpQ
-punt = litE . stringL . show
-
--- various constants and utility functions
-
-nameOfCon (NormalC name _) = name
-nameOfCon (RecC name _) = name
-nameOfCon (InfixC _ name _) = name
-nameOfCon (ForallC _ _ con) = nameOfCon con
-
-argCountOfCon (NormalC _ args) = length args
-argCountOfCon (RecC _ args) = length args
-argCountOfCon (InfixC _ _ _) = 2
-argCountOfCon (ForallC _ _ con) = argCountOfCon con
-
--- WARNING: does not handle ForallC in any kind of generally-applicable way!
-argTypesOfCon (NormalC _ args) = map snd args
-argTypesOfCon (RecC _ args) = [t | (_,_,t) <- args]
-argTypesOfCon (InfixC x _ y) = map snd [x,y]
-argTypesOfCon (ForallC _ _ con) = argTypesOfCon con
-
-varsBoundInCon (ForallC bndrs _ con) = bndrs ++ varsBoundInCon con
-varsBoundInCon _ = []
-
-occursInType :: TyVarBndr -> Type -> Bool
-occursInType bndr ty = case ty of
- ForallT bndrs _ ty
- | any (== nameOfBinder bndr) (map nameOfBinder bndrs)
- -> False
- | otherwise
- -> occursInType bndr ty
- VarT name
- | name == nameOfBinder bndr -> True
- | otherwise -> False
- AppT ty1 ty2 -> occursInType bndr ty1 || occursInType bndr ty2
- SigT ty _ -> occursInType bndr ty
- _ -> False
-
-nameOfBinder (PlainTV name) = name
-nameOfBinder (KindedTV name _) = name
-
-headOfType :: Type -> Name
-headOfType (ForallT _ _ ty) = headOfType ty
-headOfType (VarT name) = name
-headOfType (ConT name) = name
-headOfType (TupleT n) = tupleTypeName n
-headOfType (UnboxedTupleT n) = unboxedTupleTypeName n -- error "don't know how to get the name of an unboxed tuple type"
-headOfType ArrowT = ''(->)
-headOfType ListT = ''[]
-headOfType (AppT t _) = headOfType t
-headOfType (SigT t _) = headOfType t
-
-compose [] = [| id |]
-compose [f] = f
-compose (f:fs) = [| $f . $(compose fs) |]
View
@@ -5,6 +5,7 @@
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
module Test where
import Control.Applicative
@@ -110,4 +111,61 @@ do
, deriveGShow gshowInst
]
-instance Show (a Double) => Show (Spleeb a b) where showsPrec = gshowsPrec
+instance Show (a Double) => Show (Spleeb a b) where showsPrec = gshowsPrec
+
+-- another option; start from the declaration and juggle that a bit
+do
+ decs <- [d|
+ data Fnord a where Yarr :: Fnord Double; Grr :: Fnord (Int -> String)
+ |]
+
+ geqInst <- deriveGEq decs
+ gcompareInst <- deriveGCompare decs
+ gshowInst <- deriveGShow decs
+
+ return $ concat
+ [ decs
+ , geqInst
+ , gcompareInst
+ , gshowInst
+ ]
+
+instance Show (Fnord a) where showsPrec = gshowsPrec
+
+-- also should handle data families:
+data family Squawk (f :: * -> *) :: * -> *
+
+-- data instance Squawk Maybe t where
+-- Blotto :: Maybe Int -> Squawk Maybe String
+-- Flubbet :: Squawk Maybe (Maybe Int)
+
+do
+ -- [d|
+ -- data instance Squawk Maybe t where
+ -- Blotto :: Maybe Int -> Squawk Maybe String
+ -- Flubbet :: Squawk Maybe (Maybe Int)
+ -- |]
+
+ dec <- dataInstD
+ (return [])
+ ''Squawk
+ [ conT ''Maybe, varT (mkName "t")]
+ [ ForallC [] [EqualP (VarT (mkName "t")) (ConT ''String)]
+ <$> normalC (mkName "Blotto") [strictType notStrict [t| Maybe Int |]]
+ , ForallC [] [EqualP (VarT (mkName "t")) (AppT (ConT ''Maybe) (ConT ''Int))]
+ <$> normalC (mkName "Flubbet") []
+ ]
+ []
+
+ geqInst <- deriveGEq dec
+ gcompareInst <- deriveGCompare dec
+ gshowInst <- deriveGShow dec
+
+ return $
+ dec : concat
+ [ geqInst
+ , gcompareInst
+ , gshowInst
+ ]
+
+instance Show (Squawk Maybe t) where showsPrec = gshowsPrec

0 comments on commit 0399812

Please sign in to comment.