Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
mkloczko committed Apr 3, 2022
2 parents 10cc951 + 930a1b9 commit ca3758a
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 37 deletions.
4 changes: 2 additions & 2 deletions derive-storable-plugin.cabal
Expand Up @@ -16,7 +16,7 @@ category: Foreign
build-type: Simple
extra-source-files: ChangeLog.md README.md
cabal-version: >=1.10
tested-with: GHC==8.2.2, GHC==8.4.2, GHC==8.6.5, GHC==8.8.1, GHC==8.10.2,GHC==9.0.1
tested-with: GHC==8.2.2, GHC==8.4.2, GHC==8.6.5, GHC==8.8.1, GHC==8.10.7, GHC==9.0.2, GHC==9.2.2

Flag sumtypes
Description: Use sumtypes within benchmark and tests.
Expand All @@ -32,7 +32,7 @@ library
, Foreign.Storable.Generic.Plugin.Internal.Predicates
, Foreign.Storable.Generic.Plugin.Internal.Types
other-extensions: DeriveGeneric, DeriveAnyClass, PatternGuards
build-depends: base >=4.10 && <5, ghc >= 8.2 && < 9.1, ghci >= 8.2 && < 9.1, derive-storable >= 0.3 && < 0.4
build-depends: base >=4.10 && <5, ghc >= 8.2 && < 9.3, ghci >= 8.2 && < 9.3, derive-storable >= 0.3 && < 0.4
hs-source-dirs: src
default-language: Haskell2010

Expand Down
10 changes: 7 additions & 3 deletions src/Foreign/Storable/Generic/Plugin/Internal.hs
Expand Up @@ -32,7 +32,12 @@ import qualified GHC.Types.Name as N (varName)
import GHC.Types.SrcLoc (noSrcSpan)
import GHC.Types.Unique (getUnique)
import GHC.Driver.Main (hscCompileCoreExpr, getHscEnv)
#if MIN_VERSION_GLASGOW_HASKELL(9,2,0,0)
import GHC.Driver.Env.Types (HscEnv)
import GHC.Unit.Module.ModGuts (ModGuts(..))
#else
import GHC.Driver.Types (HscEnv,ModGuts(..))
#endif
import GHC.Core.Opt.Monad
(CoreM, CoreToDo(..),
getHscEnv, getDynFlags, putMsg, putMsgS)
Expand All @@ -44,7 +49,7 @@ import GHC.Builtin.Types (intDataCon)
import GHC.Core.DataCon (dataConWorkId,dataConOrigArgTys)
import GHC.Core.Make (mkWildValBinder)
import GHC.Utils.Outputable
(cat, ppr, SDoc, showSDocUnsafe, showSDoc,
(cat, ppr, SDoc, showSDocUnsafe,
($$), ($+$), hsep, vcat, empty,text,
(<>), (<+>), nest, int, colon,hcat, comma,
punctuate, fsep)
Expand Down Expand Up @@ -73,7 +78,7 @@ import TysWiredIn (intDataCon)
import DataCon (dataConWorkId,dataConOrigArgTys)
import MkCore (mkWildValBinder)
import Outputable
(cat, ppr, SDoc, showSDocUnsafe, showSDoc,
(cat, ppr, SDoc, showSDocUnsafe,
($$), ($+$), hsep, vcat, empty,text,
(<>), (<+>), nest, int, colon,hcat, comma,
punctuate, fsep)
Expand Down Expand Up @@ -225,7 +230,6 @@ foundBinds_info flags ids = do
other -> text "The following bindings are to be optimised:"
$+$ nest 4 txt
print_binding id = ppr id
max_nest = maximum $ 0 : map (length.(showSDoc dyn_flags).ppr) ids
-- Print groups of types
printer the_groups = case the_groups of
[] -> return ()
Expand Down
69 changes: 55 additions & 14 deletions src/Foreign/Storable/Generic/Plugin/Internal/Compile.hs
Expand Up @@ -42,7 +42,7 @@ where
import Prelude hiding ((<>))

#if MIN_VERSION_GLASGOW_HASKELL(9,0,1,0)
import GHC.Core (Bind(..),Expr(..), CoreExpr, CoreBind, CoreProgram, Alt, AltCon(..), isId, Unfolding(..))
import GHC.Core (Bind(..),Expr(..), CoreExpr, CoreBind, CoreProgram, Alt(..), AltCon(..), isId, Unfolding(..))
import GHC.Types.Literal (Literal(..))
import GHC.Types.Id (isLocalId, isGlobalId,setIdInfo, Id)
import GHC.Types.Id.Info (IdInfo(..))
Expand All @@ -53,7 +53,13 @@ import qualified GHC.Types.Name as N (varName)
import GHC.Types.SrcLoc (noSrcSpan,SrcSpan)
import GHC.Types.Unique (getUnique)
import GHC.Driver.Main (hscCompileCoreExpr)
#if MIN_VERSION_GLASGOW_HASKELL(9,2,0,0)
import GHC.Data.Bag (bagToList)
import GHC.Driver.Env.Types (HscEnv)
import GHC.Unit.Module.ModGuts (ModGuts(..))
#else
import GHC.Driver.Types (HscEnv,ModGuts(..))
#endif
import GHC.Core.Opt.Monad (CoreM,CoreToDo(..),getHscEnv,getDynFlags)
import GHC.Core.Lint (lintExpr)
import GHC.Types.Basic (CompilerPhase(..), Boxity(..))
Expand Down Expand Up @@ -188,7 +194,7 @@ tryCompileExpr id core_expr = do
e_compiled <- liftIO $ try $
compileExpr hsc_env core_expr (getSrcSpan id) :: CoreM (Either SomeException a)
case e_compiled of
Left se -> return $ Left $ CompilationError (NonRec id core_expr) (stringToPpr $ show se)
Left se -> return $ Left $ CompilationError (NonRec id core_expr) [stringToPpr $ show se]
Right val-> return $ Right val

----------------------
Expand Down Expand Up @@ -232,7 +238,7 @@ intSubstitution b@(NonRec id (Lam l1 l@(Lam l2 e@(Lam l3 expr)))) = do
case m_t of
Just t -> return $ NonRec id <$> (Lam l1 <$> (Lam l2 <$> (intToExpr t <$> the_integer)))
Nothing ->
return the_integer >> return $ Left $ CompilationError b (text "Type not found")
return the_integer >> return $ Left $ CompilationError b [text "Type not found"]
-- Without GSTORABLE_SUMPTYPES
intSubstitution b@(NonRec id (Lam l1 expr)) = do
-- Get HscEnv
Expand All @@ -243,7 +249,7 @@ intSubstitution b@(NonRec id (Lam l1 expr)) = do
case m_t of
Just t -> return $ NonRec id <$> (intToExpr t <$> the_integer)
Nothing ->
return the_integer >> return $ Left $ CompilationError b (text "Type not found")
return the_integer >> return $ Left $ CompilationError b [text "Type not found"]
-- For GHC <= 8.6.5
intSubstitution b@(NonRec id e@(App expr g)) = case expr of
Lam _ (Lam _ (Lam _ e)) -> intSubstitution $ NonRec id expr
Expand All @@ -268,7 +274,7 @@ intSubstitutionWorker id expr = do
Just t -> return $ NonRec id <$> (intToExpr t <$> the_integer)
-- If the compilation error occured, first return it.
Nothing ->
return the_integer >> return $ Left $ CompilationError (NonRec id expr) (text "Type not found")
return the_integer >> return $ Left $ CompilationError (NonRec id expr) [text "Type not found"]
-----------------------
-- peek substitution --
-----------------------
Expand All @@ -281,10 +287,10 @@ offsetSubstitution b@(NonRec id expr) = do
let ne_subs = case e_subs of
-- Add the text from other error.
Left (OtherError sdoc)
-> Left $ CompilationError b sdoc
-> Left $ CompilationError b [sdoc]
-- Add the information about uncompiled expr.
Left err@(CompilationError _ _)
-> Left $ CompilationError b (pprError Some err)
-> Left $ CompilationError b [pprError Some err]
a -> a

return $ NonRec id <$> e_subs
Expand All @@ -307,8 +313,10 @@ getScopeExpr (IntPrimVal _ expr) = expr
instance Outputable OffsetScope where
ppr (IntList id expr) = ppr id <+> ppr (getUnique id) <+> comma <+> ppr expr
ppr (IntPrimVal id expr) = ppr id <+> ppr (getUnique id) <+> comma <+> ppr expr
pprPrec _ el = ppr el

#if !MIN_VERSION_GLASGOW_HASKELL(9,2,0,0)
pprPrec _ el = ppr el
#endif

-- | Create a list expression from Haskell list.
intListExpr :: [Int] -> CoreExpr
Expand Down Expand Up @@ -486,7 +494,11 @@ offsetSubstitutionTree scope expr
-- Compile case_expr and put it in scope as x#
-- case_expr is of format $w!! @Int offsets 0#
| Case case_expr _ _ [alt0] <- expr
#if MIN_VERSION_GLASGOW_HASKELL(9,2,0,0)
, (Alt (DataAlt i_prim_con) [x_id] alt_expr) <- alt0
#else
, (DataAlt i_prim_con, [x_id], alt_expr) <- alt0
#endif
, i_prim_con == intDataCon
, Just new_case_expr <- caseExprIndex scope case_expr
= do
Expand All @@ -498,13 +510,26 @@ offsetSubstitutionTree scope expr
-- Normal case expressions.
| Case case_expr cb t alts <- expr
= do
e_new_alts <- mapM (\(a, args, a_expr) -> (,,) a args <$> offsetSubstitutionTree scope a_expr) alts
#if MIN_VERSION_GLASGOW_HASKELL(9,2,0,0)
let mkAlt = Alt
#else
let mkAlt = (,,)
#endif

e_new_alts <- flip mapM alts $
#if MIN_VERSION_GLASGOW_HASKELL(9,2,0,0)
\(Alt a args a_expr) ->
#else
\(a, args, a_expr) ->
#endif
(,,) a args <$> offsetSubstitutionTree scope a_expr

new_case_expr <- offsetSubstitutionTree scope case_expr
-- Find the first error in alternative compilation
let c_err = find (\(_,_,e) -> isLeft e) e_new_alts
case c_err of
Nothing -> return $ Case <$> new_case_expr
<*> pure cb <*> pure t <*> pure [(a,b,ne) | (a,b,Right ne) <- e_new_alts]
<*> pure cb <*> pure t <*> pure [mkAlt a b ne | (a,b,Right ne) <- e_new_alts]
Just (_,_,err) -> return err
-- Variable. Return it or try to replace it.
-- Must be here, otherwise other substitutions won't happen
Expand Down Expand Up @@ -566,14 +591,26 @@ lintBind :: CoreBind -- ^ Core binding to use when returning CompilationError
lintBind b_old b@(NonRec id expr) = do
dyn_flags <- getDynFlags
case lintExpr dyn_flags [] expr of
Just sdoc -> (return $ Left $ CompilationError b_old sdoc)
Nothing -> return $ Right b
Just sdoc -> do
#if MIN_VERSION_GLASGOW_HASKELL(9,2,0,0)
let err = bagToList sdoc
#else
let err = [sdoc]
#endif
return $ Left $ CompilationError b_old err
Nothing ->
return $ Right b
lintBind b_old b@(Rec bs) = do
dyn_flags <- getDynFlags
let errs = mapMaybe (\(_,expr) -> lintExpr dyn_flags [] expr) bs
#if MIN_VERSION_GLASGOW_HASKELL(9,2,0,0)
let convert = foldMap bagToList
#else
let convert = id
#endif
case errs of
[] -> return $ Right b
_ -> return $ Left $ CompilationError b_old (vcat errs)
_ -> return $ Left $ CompilationError b_old (convert errs)

-- | Substitutes the localIds inside the bindings with bodies of provided bindings.
replaceIdsBind :: [CoreBind] -- ^ Replace with - for GStorable bindings
Expand Down Expand Up @@ -616,7 +653,11 @@ replaceIds gstorable_bs other_bs (Let b e) = Let (replaceIdsBind gstorable_bs
-- Replace the case_expression and the altenatives.
replaceIds gstorable_bs other_bs (Case e ev t alts) = do
let new_e = replaceIds gstorable_bs other_bs e
new_alts = map (\(alt, ids, exprs) -> (alt,ids, replaceIds gstorable_bs other_bs exprs)) alts
#if MIN_VERSION_GLASGOW_HASKELL(9,2,0,0)
let new_alts = map (\(Alt alt ids exprs) -> Alt alt ids (replaceIds gstorable_bs other_bs exprs)) alts
#else
let new_alts = map (\(alt, ids, exprs) -> (alt, ids, replaceIds gstorable_bs other_bs exprs)) alts
#endif
Case new_e ev t new_alts
-- Replace the expression in Cast
replaceIds gstorable_bs other_bs (Cast e c) = Cast (replaceIds gstorable_bs other_bs e) c
Expand Down
16 changes: 8 additions & 8 deletions src/Foreign/Storable/Generic/Plugin/Internal/Error.hs
Expand Up @@ -45,13 +45,13 @@ type CrashOnWarning = Bool
data Flags = Flags Verbosity CrashOnWarning

-- | All possible errors.
data Error = TypeNotFound Id -- ^ Could not obtain the type from the id.
| RecBinding CoreBind -- ^ The binding is recursive and won't be substituted.
| CompilationNotSupported CoreBind -- ^ The compilation-substitution is not supported for the given binding.
| CompilationError CoreBind SDoc -- ^ Error during compilation. The CoreBind is to be returned.
| OrderingFailedBinds Int [CoreBind] -- ^ Ordering failed for core bindings.
| OrderingFailedTypes Int [Type] -- ^ Ordering failed for types
| OtherError SDoc -- ^ Any other error.
data Error = TypeNotFound Id -- ^ Could not obtain the type from the id.
| RecBinding CoreBind -- ^ The binding is recursive and won't be substituted.
| CompilationNotSupported CoreBind -- ^ The compilation-substitution is not supported for the given binding.
| CompilationError CoreBind [SDoc] -- ^ Error during compilation. The CoreBind is to be returned.
| OrderingFailedBinds Int [CoreBind] -- ^ Ordering failed for core bindings.
| OrderingFailedTypes Int [Type] -- ^ Ordering failed for types
| OtherError SDoc -- ^ Any other error.

pprTypeNotFound :: Verbosity -> Id -> SDoc
pprTypeNotFound None _ = empty
Expand Down Expand Up @@ -141,7 +141,7 @@ pprError :: Verbosity -> Error -> SDoc
pprError verb (TypeNotFound id ) = pprTypeNotFound verb id
pprError verb (RecBinding bind) = pprRecBinding verb bind
pprError verb (CompilationNotSupported bind) = pprCompilationNotSupported verb bind
pprError verb (CompilationError bind str) = pprCompilationError verb bind str
pprError verb (CompilationError bind str) = pprCompilationError verb bind $ vcat str
pprError verb (OrderingFailedBinds d bs) = pprOrderingFailedBinds verb d bs
pprError verb (OrderingFailedTypes d ts) = pprOrderingFailedTypes verb d ts
pprError verb (OtherError sdoc ) = pprOtherError verb sdoc
Expand Down
5 changes: 5 additions & 0 deletions src/Foreign/Storable/Generic/Plugin/Internal/GroupTypes.hs
Expand Up @@ -32,7 +32,12 @@ import qualified GHC.Types.Name as N (varName)
import GHC.Types.SrcLoc (noSrcSpan)
import GHC.Types.Unique (getUnique)
import GHC.Driver.Main (hscCompileCoreExpr, getHscEnv)
#if MIN_VERSION_GLASGOW_HASKELL(9,2,0,0)
import GHC.Driver.Env.Types (HscEnv)
import GHC.Unit.Module.ModGuts (ModGuts(..))
#else
import GHC.Driver.Types (HscEnv,ModGuts(..))
#endif
import GHC.Core.Opt.Monad (CoreM,CoreToDo(..))
import GHC.Types.Basic (CompilerPhase(..))
import GHC.Core.Type hiding (eqType)
Expand Down
30 changes: 21 additions & 9 deletions src/Foreign/Storable/Generic/Plugin/Internal/Helpers.hs
Expand Up @@ -13,7 +13,7 @@ Various helping functions.
module Foreign.Storable.Generic.Plugin.Internal.Helpers where

#if MIN_VERSION_GLASGOW_HASKELL(9,0,1,0)
import GHC.Core (Bind(..),Expr(..), CoreExpr, CoreBind, CoreProgram, Alt)
import GHC.Core (Bind(..),Expr(..), CoreExpr, CoreBind, CoreBndr, CoreProgram, Alt(..))
import GHC.Types.Literal (Literal(..))
import GHC.Types.Id (isLocalId, isGlobalId,Id)
import GHC.Types.Var (Var(..))
Expand All @@ -23,7 +23,12 @@ import qualified GHC.Types.Name as N (varName)
import GHC.Types.SrcLoc (noSrcSpan)
import GHC.Types.Unique (getUnique)
import GHC.Driver.Main (hscCompileCoreExpr, getHscEnv)
#if MIN_VERSION_GLASGOW_HASKELL(9,2,0,0)
import GHC.Driver.Env.Types (HscEnv)
import GHC.Unit.Module.ModGuts (ModGuts(..))
#else
import GHC.Driver.Types (HscEnv,ModGuts(..))
#endif
import GHC.Core.Opt.Monad (CoreM,CoreToDo(..))
import GHC.Types.Basic (CompilerPhase(..))
import GHC.Core.Type (isAlgType, splitTyConApp_maybe)
Expand Down Expand Up @@ -104,15 +109,23 @@ getIdsExprsBind (Rec recs) = recs

-- | Get all IDs from CoreExpr
getIdsExpr :: CoreExpr -> [Id]
getIdsExpr (Var id) = [id]
getIdsExpr (App e1 e2) = concat [getIdsExpr e1, getIdsExpr e2]
getIdsExpr (Lam id e) = id : getIdsExpr e
getIdsExpr (Var id) = [id]
getIdsExpr (App e1 e2) = concat [getIdsExpr e1, getIdsExpr e2]
getIdsExpr (Lam id e) = id : getIdsExpr e
-- Ids from bs are ignored, as they are supposed to appear in e argument.
getIdsExpr (Let bs e) = concat [getIdsExpr e, concatMap getIdsExpr (getExprsBind bs)]
getIdsExpr (Let bs e) = concat [getIdsExpr e, concatMap getIdsExpr (getExprsBind bs)]
-- The case_binder is ignored - the evaluated expression might appear on the rhs of alts
getIdsExpr (Case e _ _ alts) = concat $ getIdsExpr e : map (\(_,_,e_c) -> getIdsExpr e_c) alts
getIdsExpr (Cast e _) = getIdsExpr e
getIdsExpr _ = []
getIdsExpr (Case e _ _ alts) = concat $ getIdsExpr e : map extractAlt alts
getIdsExpr (Cast e _) = getIdsExpr e
getIdsExpr _ = []

#if MIN_VERSION_GLASGOW_HASKELL(9,2,0,0)
extractAlt :: Alt CoreBndr -> [Id]
extractAlt (Alt _ac _bs expr) = getIdsExpr expr
#else
extractAlt :: (a, b, CoreExpr) -> [Id]
extractAlt (_, _, e_c) = getIdsExpr e_c
#endif


------------
Expand Down Expand Up @@ -249,4 +262,3 @@ removeProxy t
= ForAllTy b t2
| otherwise
= t

6 changes: 5 additions & 1 deletion src/Foreign/Storable/Generic/Plugin/Internal/Predicates.hs
Expand Up @@ -76,7 +76,12 @@ import qualified GHC.Types.Name as N (varName)
import GHC.Types.SrcLoc (noSrcSpan)
import GHC.Types.Unique (getUnique)
import GHC.Driver.Main (hscCompileCoreExpr, getHscEnv)
#if MIN_VERSION_GLASGOW_HASKELL(9,2,0,0)
import GHC.Driver.Env.Types (HscEnv)
import GHC.Unit.Module.ModGuts (ModGuts(..))
#else
import GHC.Driver.Types (HscEnv,ModGuts(..))
#endif
import GHC.Core.Opt.Monad (CoreM,CoreToDo(..))
import GHC.Types.Basic (CompilerPhase(..))
import GHC.Core.Type (isAlgType, splitTyConApp_maybe)
Expand Down Expand Up @@ -277,4 +282,3 @@ withTypeCheck ty_f id_f id = do
let ty_checked = ty_f $ varType id
id_checked = id_f id
and [isJust ty_checked, id_checked]

5 changes: 5 additions & 0 deletions src/Foreign/Storable/Generic/Plugin/Internal/Types.hs
Expand Up @@ -51,7 +51,12 @@ import qualified GHC.Types.Name as N (varName, tcClsName)
import GHC.Types.SrcLoc (noSrcSpan)
import GHC.Types.Unique (getUnique)
import GHC.Driver.Main (hscCompileCoreExpr, getHscEnv)
#if MIN_VERSION_GLASGOW_HASKELL(9,2,0,0)
import GHC.Driver.Env.Types (HscEnv)
import GHC.Unit.Module.ModGuts (ModGuts(..))
#else
import GHC.Driver.Types (HscEnv,ModGuts(..))
#endif
import GHC.Core.Opt.Monad (CoreM,CoreToDo(..))
import GHC.Types.Basic (CompilerPhase(..))
import GHC.Core.Type (isAlgType, splitTyConApp_maybe)
Expand Down

0 comments on commit ca3758a

Please sign in to comment.