Skip to content

Commit

Permalink
cleaned up utility functions and examples
Browse files Browse the repository at this point in the history
  • Loading branch information
ghorn committed May 27, 2012
1 parent de162a0 commit c882b5e
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 42 deletions.
45 changes: 28 additions & 17 deletions CompileTest.hs
Expand Up @@ -6,21 +6,35 @@


import Data.Array.Repa (DIM0,DIM1,DIM2) import Data.Array.Repa (DIM0,DIM1,DIM2)


import Dvda.SymMonad ( (:*)(..), makeFun, inputs_, outputs_, node ) import Dvda.SymMonad ( (:*)(..), makeFunGraph, runFunGraph, inputs_, outputs_, node )
import Dvda.Expr import Dvda.Expr
import Dvda.Graph import Dvda.Graph
import Dvda.HSBuilder import Dvda.HSBuilder
--import Dvda.Codegen.CBuilder --import Dvda.Codegen.CBuilder


gr' :: FunGraph Double (DIM0 :* DIM0 :* DIM0) (DIM0 :* DIM0 :* DIM0 :* DIM0)
gr' = makeFunGraph (x' :* y' :* z') (f :* fx :* fy :* fz)
where
x' = sym "x"
y' = sym "y"
z' = sym "z"

f0 x y z = (z + x*y)*log(cos x / tanh y)**(z/exp y)
fx0 = f0 (f0 x' y' z') (f0 z' y' x') (f0 y' x' z')
fy0 = f0 (f0 z' x' y') (f0 x' z' y') (f0 z' z' y')
fz0 = f0 (f0 x' y' z') (f0 x' y' x') (f0 y' x' y')
f = f0 fx0 fy0 fz0

fx = diff f x'
fy = diff f y'
fz = diff f z'



gr :: FunGraph Double (DIM0 :* DIM1 :* DIM2) (DIM2 :* DIM1 :* DIM0) gr :: FunGraph Double (DIM0 :* DIM1 :* DIM2) (DIM2 :* DIM1 :* DIM0)
--gr :: FunGraph Double (DIM0 :* DIM0 :* DIM0) (DIM0 :* DIM0 :* DIM0) gr = runFunGraph $ do
gr = snd $ makeFun $ do
let x = sym "x" let x = sym "x"
y = vsym 3 "y" y = vsym 3 "y"
-- y = sym "y"
z = msym (2,3) "Z" z = msym (2,3) "Z"
-- z = sym "Z"
inputs_ (x :* y :* z) inputs_ (x :* y :* z)


z1 <- node $ (scale x z)**3 z1 <- node $ (scale x z)**3
Expand All @@ -31,8 +45,8 @@ gr = snd $ makeFun $ do


outputs_ (z1 :* z2 :* z3) outputs_ (z1 :* z2 :* z3)


main' :: IO () main :: IO ()
main' = do main = do
fun <- buildHSFunction gr fun <- buildHSFunction gr
let x = 0 let x = 0
y = vec [0,1,2] y = vec [0,1,2]
Expand All @@ -41,14 +55,11 @@ main' = do


print answer print answer


main' :: IO ()
main' = do
fun <- buildHSFunction gr'
let x = 0
y = 3
z = 6


-- main :: IO () print $ fun (x :* y :* z)
-- main = do
-- fun <- buildCFunction gr
-- print fun
-- -- let x = 0
-- -- y = vec [0,1,2::Double]
-- -- z = mat (2,3) [0,1,2,3,4,5]
-- -- answer = fun (x :* y :* z)
-- --
-- -- print answer
42 changes: 20 additions & 22 deletions Dvda/Examples.hs
Expand Up @@ -13,7 +13,6 @@ import Data.Array.Repa (DIM0,DIM1,DIM2)
import Dvda.SymMonad import Dvda.SymMonad
import Dvda.Expr import Dvda.Expr
import Dvda.Graph import Dvda.Graph
--import Dvda.CFunction


exampleFun :: State (FunGraph Double (DIM0 :* DIM1 :* DIM2) (DIM2 :* DIM1 :* DIM0)) () exampleFun :: State (FunGraph Double (DIM0 :* DIM1 :* DIM2) (DIM2 :* DIM1 :* DIM0)) ()
exampleFun = do exampleFun = do
Expand Down Expand Up @@ -43,20 +42,20 @@ exampleFun' = do
run' :: IO () run' :: IO ()
run' = do run' = do
let gr :: FunGraph Double (DIM0 :* DIM1 :* DIM2) (DIM2 :* DIM1 :* DIM0) let gr :: FunGraph Double (DIM0 :* DIM1 :* DIM2) (DIM2 :* DIM1 :* DIM0)
gr@( FunGraph _ _ _ _) = snd $ makeFun exampleFun gr@(FunGraph hm im _ _) = runFunGraph exampleFun
(FunGraph _ _ _ _) = snd $ makeFun exampleFun' (FunGraph hm' im' _ _) = runFunGraph exampleFun'


putStrLn $ funGraphSummary gr putStrLn $ funGraphSummary gr
putStrLn $ showCollisions gr putStrLn $ showCollisions gr
previewGraph gr previewGraph gr
-- putStrLn "\nimperative same as pure+cse?:" putStrLn "\nimperative same as pure+cse?:"
-- print $ hm == hm' print $ hm == hm'
-- print $ im == im' print $ im == im'


run :: IO () run :: IO ()
run = do run = do
let gr :: FunGraph Double (DIM0 :* DIM0) (DIM0 :* DIM0) let gr :: FunGraph Double (DIM0 :* DIM0) (DIM0 :* DIM0)
gr@( FunGraph _ _ _ _) = snd $ makeFun $ do gr@( FunGraph _ _ _ _) = runFunGraph $ do
let x = sym "x" let x = sym "x"
y = sym "y" y = sym "y"
z1 = x * y z1 = x * y
Expand All @@ -75,23 +74,22 @@ run = do
showoff :: IO () showoff :: IO ()
showoff = do showoff = do
let gr :: FunGraph Double (DIM0 :* DIM0 :* DIM0) (DIM0 :* DIM0 :* DIM0 :* DIM0) let gr :: FunGraph Double (DIM0 :* DIM0 :* DIM0) (DIM0 :* DIM0 :* DIM0 :* DIM0)
gr@(FunGraph {}) = snd $ makeFun $ do gr = makeFunGraph (x' :* y' :* z') (f :* fx :* fy :* fz)
let x' = sym "x" where
y' = sym "y" x' = sym "x"
z' = sym "z" y' = sym "y"
z' = sym "z"


f0 x y z = (z + x*y)*log(cos x / tanh y)**(z/exp y) f0 x y z = (z + x*y)*log(cos x / tanh y)**(z/exp y)
fx0 = f0 (f0 x' y' z') (f0 z' y' x') (f0 y' x' z') fx0 = f0 (f0 x' y' z') (f0 z' y' x') (f0 y' x' z')
fy0 = f0 (f0 z' x' y') (f0 x' z' y') (f0 z' z' y') fy0 = f0 (f0 z' x' y') (f0 x' z' y') (f0 z' z' y')
fz0 = f0 (f0 x' y' z') (f0 x' y' x') (f0 y' x' y') fz0 = f0 (f0 x' y' z') (f0 x' y' x') (f0 y' x' y')
f = f0 fx0 fy0 fz0 f = f0 fx0 fy0 fz0

fx = diff f x' fx = diff f x'
fy = diff f y' fy = diff f y'
fz = diff f z' fz = diff f z'


inputs_ (x' :* y' :* z')
outputs_ (f :* fx :* fy :* fz)


putStrLn $ showCollisions gr putStrLn $ showCollisions gr
-- putStrLn $ funGraphSummary' gr -- putStrLn $ funGraphSummary' gr
Expand Down
14 changes: 11 additions & 3 deletions Dvda/SymMonad.hs
Expand Up @@ -15,7 +15,8 @@ module Dvda.SymMonad ( (:*)(..)
, inputs_ , inputs_
, outputs , outputs
, outputs_ , outputs_
, makeFun , makeFunGraph
, runFunGraph
, rad , rad
, getSensitivities , getSensitivities
) where ) where
Expand Down Expand Up @@ -312,5 +313,12 @@ instance Shape sh => ExprList (sh :. Int) a where




---------------- utility function ----------------- ---------------- utility function -----------------
makeFun :: StateT (FunGraph a b c) Identity d -> (d, FunGraph a b c) runFunGraph :: StateT (FunGraph a b c) Identity d -> FunGraph a b c
makeFun f = runState f emptyFunGraph runFunGraph f = snd $ runState f emptyFunGraph

--makeFunGraph :: (HList c, HList b, NumT b ~ NumT c, NumT b ~ a, Eq a, Floating a, Hashable a, Unbox a) =>
makeFunGraph :: (HList c, HList b, NumT b ~ NumT c, NumT b ~ a) =>
b -> c -> FunGraph a (DimT b) (DimT c)
makeFunGraph ins outs = runFunGraph $ do
inputs_ ins
outputs_ outs

0 comments on commit c882b5e

Please sign in to comment.