Skip to content

Commit

Permalink
more refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeizbicki committed Sep 7, 2015
1 parent c528b47 commit 17134a5
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 144 deletions.
21 changes: 21 additions & 0 deletions .gitignore
@@ -0,0 +1,21 @@
*.Rout
*.csv
*.class
*.o

results
cover_tree

hlearn-allknn
hlearn-linear
*.swp
*.swo
dist/
gitignore/
examples/old/
scripts/

.cabal-sandbox/
cabal.sandbox.config

.stack-work/
2 changes: 1 addition & 1 deletion herbie-haskell.cabal
Expand Up @@ -56,7 +56,7 @@ library

-- Modules included in this library but not exported.
other-modules:
Show
Stabalize.MathExpr

-- LANGUAGE extensions used by modules in this package.
-- other-extensions:
Expand Down
260 changes: 117 additions & 143 deletions src/Herbie.hs
Expand Up @@ -21,8 +21,13 @@ import Data.Maybe
import System.Process

import Debug.Trace
import Show
import Data.IORef

import Stabalize.MathExpr

import Prelude
ifThenElse True t f = t
ifThenElse False t f = f


--------------------------------------------------------------------------------
-- GHC plugin interface
Expand All @@ -34,10 +39,8 @@ plugin = defaultPlugin

install :: [CommandLineOption] -> [CoreToDo] -> CoreM [CoreToDo]
install _ todo = do
dflags <- getDynFlags
liftIO $ writeIORef dynFlags_ref dflags
reinitializeGlobals
return (CoreDoPluginPass "Herbie" pass : todo)
return (CoreDoPluginPass "MathInfo" pass : todo)

pass :: ModGuts -> CoreM ModGuts
pass guts = do
Expand Down Expand Up @@ -179,14 +182,6 @@ getDictMap pt = nubBy f $ go $ getDictMap0 pt
dict' = mkApps (Var (findSelId c' (classAllSelIds c))) [Type $ getParam pt, dict]
_ -> []

-- | Given a list of member functions of a class,
-- return the function that extracts the dictionary corresponding to c.
findSelId :: Class -> [Var] -> Var
findSelId c [] = error "findSelId: this shouldn't happen"
findSelId c (v:vs) = if isDictSelector c v
then v
else findSelId c vs

-- | Returns a dictionary map for just the top level cxt of the ParamType.
-- This is used to seed getDictMap.
getDictMap0 :: ParamType -> [(Class,Expr Var)]
Expand All @@ -198,6 +193,14 @@ getDictMap0 pt = map f $ getClasses pt
xs -> error $ "getDictMap: multiple dictionaries for class "++getString c
++": "++show (map getString xs)

-- | Given a list of member functions of a class,
-- return the function that extracts the dictionary corresponding to c.
findSelId :: Class -> [Var] -> Var
findSelId c [] = error "findSelId: this shouldn't happen"
findSelId c (v:vs) = if isDictSelector c v
then v
else findSelId c vs

-- | True only if dict is a dictionary for c
isDict :: Class -> Var -> Bool
isDict c dict = getString c == dropWhile go (getString dict)
Expand Down Expand Up @@ -256,6 +259,18 @@ getSuperClasses c = c : (nub $ concat $ map getSuperClasses $ concat $ map go $
getString :: NamedThing a => a -> String
getString = occNameString . getOccName

expr2str :: DynFlags -> Expr Var -> String
expr2str dflags (Var v) = {-"var_" ++-} var2str v
expr2str dflags e = "expr_" ++ (decorate $ showSDoc dflags (ppr e))
where
decorate :: String -> String
decorate = map go
where
go ' ' = '_'
go '@' = '_'
go '$' = '_'
go x = x

lit2rational :: Literal -> Rational
lit2rational l = case l of
MachInt i -> toRational i
Expand All @@ -274,61 +289,74 @@ maybeHead (a:_) = Just a
maybeHead _ = Nothing

--------------------------------------------------------------------------------
-- convert Expr into HerbieExpr
-- convert Expr into MathExpr

data Herbie = Herbie
{ hexpr :: HerbieExpr
data MathInfo = MathInfo
{ hexpr :: MathExpr
, numType :: ParamType
, getExprs :: [(String,Expr Var)]
}

herbie2cmd :: DynFlags -> Herbie -> String
herbie2cmd dflags herbie = "(herbie-test "++varStr++" \"cmd\" "++herbie2lisp dflags herbie++" )\n"
herbie2lisp :: DynFlags -> MathInfo -> String
herbie2lisp dflags herbie = mathExpr2lisp (hexpr herbie)

findExpr :: MathInfo -> String -> Maybe (Expr Var)
findExpr herbie str = lookup str (getExprs herbie)

mathExpr2expr :: ModGuts -> MathInfo -> CoreM (Expr CoreBndr)
mathExpr2expr guts herbie = go (hexpr herbie)
where
varStr = "(" ++ intercalate " " (map fst (getExprs herbie)) ++ ")"
t = getParam $ numType herbie

herbie2lisp :: DynFlags -> Herbie -> String
herbie2lisp dflags herbie = herbieExpr2lisp dflags (hexpr herbie)
getDict opstr = do
ret <- getDictConcrete guts opstr (getParam $ numType herbie)
case ret of
Just x -> return x
Nothing -> do
ret' <- getDictPolymorphic guts opstr (numType herbie)
case ret' of
Just x -> return x

findExpr :: Herbie -> String -> Maybe (Expr Var)
findExpr herbie str = lookup str (getExprs herbie)
-- binary operators
go (EBinOp opstr a1 a2) = do
a1' <- go a1
a2' <- go a2
op <- getVar guts opstr
dict <- getDict opstr
return $ App (App (App (App (Var op) (Type t)) dict) a1') a2'

----------------------------------------
-- unary operators
go (EMonOp opstr a) = do
a' <- go a
op <- getVar guts opstr
dict <- getDict opstr
return $ App (App (App (Var op) (Type t)) dict) a'

data HerbieExpr
= EBinOp String HerbieExpr HerbieExpr
| EMonOp String HerbieExpr
| ELeaf String
-- leaf nodes
go (ELeaf str) = case readMaybe str :: Maybe Double of

expr2str :: DynFlags -> Expr Var -> String
expr2str dflags (Var v) = {-"var_" ++-} var2str v
expr2str dflags e = "expr_" ++ (decorate $ showSDoc dflags (ppr e))
where
decorate :: String -> String
decorate = map go
where
go ' ' = '_'
go '@' = '_'
go '$' = '_'
go x = x
-- leaf is a numeric literal
Just d -> return $ mkConApp floatDataCon [mkFloatLit $ toRational d]

herbieExpr2lisp :: DynFlags -> HerbieExpr -> String
herbieExpr2lisp dflags = go
where
go (EBinOp op a1 a2) = "("++op++" "++go a1++" "++go a2++")"
go (EMonOp op a) = "("++op++" "++go a++")"
go (ELeaf e) = e
-- leaf is any other expression
Nothing -> do
dflags <- getDynFlags
return $ case findExpr herbie str of
Just x -> x
Nothing -> error $ "mathExpr2expr: var " ++ str ++ " not in scope"

----------------------------------------

expr2herbie :: DynFlags -> ParamType -> Expr Var -> Herbie
expr2herbie dflags t e = Herbie hexpr t exprs
expr2herbie :: DynFlags -> ParamType -> Expr Var -> MathInfo
expr2herbie dflags t e = MathInfo hexpr t exprs
where
(hexpr,exprs) = go e []

go :: Expr Var
-> [(String,Expr Var)]
-> (HerbieExpr,[(String,Expr Var)])
-> (MathExpr,[(String,Expr Var)])

-- we need to special case the $ operator for when HerbieExpr is run before any rewrite rules
-- we need to special case the $ operator for when MathExpr is run before any rewrite rules
go e@(App (App (App (App (Var v) (Type _)) (Type _)) a1) a2) exprs
= if var2str v == "$"
then go (App a1 a2) exprs
Expand Down Expand Up @@ -367,40 +395,6 @@ expr2herbie dflags t e = Herbie hexpr t exprs
binOpList = [ "*", "/", "-", "+", "max", "min" ]
monOpList = [ "cos", "sin", "tan", "log", "sqrt" ]

--------------------------------------------------------------------------------
-- Herbie interface

data HerbieStr
= SOp [HerbieStr]
| SLeaf String
deriving Show

-- | We just need to add spaces around the parens before calling "words"
tokenize :: String -> [String]
tokenize = words . concat . map go
where
go '(' = " ( "
go ')' = " ) "
go x = [x]

-- | Assuming the previous token was a "(", splits the result into everything up until the matching ")" in fst, and everything after ")" in snd
findParen :: [String] -> ([String],[String])
findParen xs = go xs 0 []
where
go (")":xs) 0 left = (left,xs)
go (")":xs) i left = go xs (i-1) (left++[")"])
go ("(":xs) i left = go xs (i+1) (left++["("])
go (x:xs) i left = go xs i (left++[x])

str2herbieStr :: String -> HerbieStr
str2herbieStr str = head $ go $ tokenize str
where
go ("(":xs) = SOp (go left):go right
where
(left,right) = findParen xs
go (x:xs) = SLeaf x:go xs
go [] = []

----------------------------------------
-- get information from the environment

Expand All @@ -424,7 +418,7 @@ getVar guts opstr = do

-- | Given a function name and concrete type, get the needed dictionary.
getDictConcrete :: ModGuts -> String -> Type -> CoreM (Maybe (Expr CoreBndr))
getDictConcrete guts opstr t = do
getDictConcrete guts opstr t = trace "poly" $ do
hscenv <- getHscEnv
dflags <- getDynFlags
eps <- liftIO $ hscEPS hscenv
Expand Down Expand Up @@ -463,58 +457,30 @@ getDictConcrete guts opstr t = do
getDictPolymorphic :: ModGuts -> String -> ParamType -> CoreM (Maybe (Expr (CoreBndr)))
getDictPolymorphic guts opstr pt = do
let (f,ParentIs p) = getNameParent guts opstr
c = head $ filter
c = case filter
(\x -> getOccName x == nameOccName p)
(concatMap getSuperClasses $ getClasses pt)
of
[x] -> x
xs -> error $ "xs="++show (length xs)
dictFromParamType c pt

herbieStr2expr :: ModGuts -> Herbie -> HerbieStr -> CoreM (Expr CoreBndr)
herbieStr2expr guts herbie = go
where
t = getParam $ numType herbie

getDict opstr = do
ret <- getDictConcrete guts opstr (getParam $ numType herbie)
case ret of
Just x -> return x
Nothing -> do
ret' <- getDictPolymorphic guts opstr (numType herbie)
case ret' of
Just x -> return x

-- leaf nodes might be either a literal or a raw expression
go (SLeaf str) = case readMaybe str :: Maybe Double of
Just d -> return $ mkConApp floatDataCon [mkFloatLit $ toRational d]
Nothing -> do
dflags <- getDynFlags
return $ case findExpr herbie str of
Just x -> x
Nothing -> error $ "herbieStr2expr: var " ++ str ++ " not in scope"

-- binary operators
go (SOp [SLeaf opstr, a1, a2]) = do
a1' <- go a1
a2' <- go a2
op <- getVar guts opstr
dict <- getDict opstr
return $ App (App (App (App (Var op) (Type t)) dict) a1') a2'

-- unary operators
go (SOp [SLeaf opstr, a]) = do
a' <- go a
op <- getVar guts opstr
dict <- getDict opstr
return $ App (App (App (Var op) (Type t)) dict) a'
--------------------------------------------------------------------------------
-- Herbie interface

-- higher arity operators
go xs = error $ "herbieStr2expr: expr arity not supported: "++show xs
herbie2cmd :: DynFlags -> MathInfo -> String
herbie2cmd dflags herbie = "(herbie-test "++varStr++" \"cmd\" "++herbie2lisp dflags herbie++" )\n"
where
varStr = "(" ++ intercalate " " (map fst (getExprs herbie)) ++ ")"

callHerbie :: ModGuts -> Expr Var -> Herbie -> CoreM (Expr CoreBndr)
callHerbie :: ModGuts -> Expr Var -> MathInfo -> CoreM (Expr CoreBndr)
callHerbie guts expr herbie = do
dflags <- getDynFlags
let lispstr = herbie2lisp dflags herbie
lispstr' <- liftIO $ execHerbie lispstr
herbieStr2expr guts herbie $ str2herbieStr $ lispstr'
putMsgS $ "lispstr'=" ++ lispstr'
let herbie' = herbie { hexpr = str2mathExpr lispstr' }
mathExpr2expr guts herbie'

execHerbie :: String -> IO String
execHerbie lisp = do
Expand Down Expand Up @@ -543,6 +509,14 @@ execHerbie lisp = do
-- putStrLn ""
return lisp'

-- | We just need to add spaces around the parens before calling "words"
tokenize :: String -> [String]
tokenize = words . concat . map go
where
go '(' = " ( "
go ')' = " ) "
go x = [x]

--------------------------------------------------------------------------------
--

Expand Down Expand Up @@ -580,22 +554,22 @@ myshow dflags = go 1
where
white=replicate (4*i) ' '

-- instance Show (Coercion) where
-- show _ = "Coercion"
--
-- instance Show b => Show (Bind b) where
-- show _ = "Bind"
--
-- instance Show (Tickish Id) where
-- show _ = "(Tickish Id)"
--
-- instance Show Type where
-- show _ = "Type"
--
-- instance Show AltCon where
-- show _ = "AltCon"
--
-- instance Show Var where
-- show v = "name"
instance Show (Coercion) where
show _ = "Coercion"

instance Show b => Show (Bind b) where
show _ = "Bind"

instance Show (Tickish Id) where
show _ = "(Tickish Id)"

instance Show Type where
show _ = "Type"

instance Show AltCon where
show _ = "AltCon"

instance Show Var where
show v = "name"


0 comments on commit 17134a5

Please sign in to comment.