Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

cleaned up utility functions and examples

  • Loading branch information...
commit c882b5e2d4eabeab42d9b4b2499386aba9506292 1 parent de162a0
@ghorn authored
Showing with 59 additions and 42 deletions.
  1. +28 −17 CompileTest.hs
  2. +20 −22 Dvda/Examples.hs
  3. +11 −3 Dvda/SymMonad.hs
View
45 CompileTest.hs
@@ -6,21 +6,35 @@
import Data.Array.Repa (DIM0,DIM1,DIM2)
-import Dvda.SymMonad ( (:*)(..), makeFun, inputs_, outputs_, node )
+import Dvda.SymMonad ( (:*)(..), makeFunGraph, runFunGraph, inputs_, outputs_, node )
import Dvda.Expr
import Dvda.Graph
import Dvda.HSBuilder
--import Dvda.Codegen.CBuilder
+gr' :: FunGraph Double (DIM0 :* DIM0 :* DIM0) (DIM0 :* DIM0 :* DIM0 :* DIM0)
+gr' = makeFunGraph (x' :* y' :* z') (f :* fx :* fy :* fz)
+ where
+ x' = sym "x"
+ y' = sym "y"
+ z' = sym "z"
+
+ f0 x y z = (z + x*y)*log(cos x / tanh y)**(z/exp y)
+ fx0 = f0 (f0 x' y' z') (f0 z' y' x') (f0 y' x' z')
+ fy0 = f0 (f0 z' x' y') (f0 x' z' y') (f0 z' z' y')
+ fz0 = f0 (f0 x' y' z') (f0 x' y' x') (f0 y' x' y')
+ f = f0 fx0 fy0 fz0
+
+ fx = diff f x'
+ fy = diff f y'
+ fz = diff f z'
+
gr :: FunGraph Double (DIM0 :* DIM1 :* DIM2) (DIM2 :* DIM1 :* DIM0)
---gr :: FunGraph Double (DIM0 :* DIM0 :* DIM0) (DIM0 :* DIM0 :* DIM0)
-gr = snd $ makeFun $ do
+gr = runFunGraph $ do
let x = sym "x"
y = vsym 3 "y"
--- y = sym "y"
z = msym (2,3) "Z"
--- z = sym "Z"
inputs_ (x :* y :* z)
z1 <- node $ (scale x z)**3
@@ -31,8 +45,8 @@ gr = snd $ makeFun $ do
outputs_ (z1 :* z2 :* z3)
-main' :: IO ()
-main' = do
+main :: IO ()
+main = do
fun <- buildHSFunction gr
let x = 0
y = vec [0,1,2]
@@ -41,14 +55,11 @@ main' = do
print answer
+main' :: IO ()
+main' = do
+ fun <- buildHSFunction gr'
+ let x = 0
+ y = 3
+ z = 6
--- main :: IO ()
--- main = do
--- fun <- buildCFunction gr
--- print fun
--- -- let x = 0
--- -- y = vec [0,1,2::Double]
--- -- z = mat (2,3) [0,1,2,3,4,5]
--- -- answer = fun (x :* y :* z)
--- --
--- -- print answer
+ print $ fun (x :* y :* z)
View
42 Dvda/Examples.hs
@@ -13,7 +13,6 @@ import Data.Array.Repa (DIM0,DIM1,DIM2)
import Dvda.SymMonad
import Dvda.Expr
import Dvda.Graph
---import Dvda.CFunction
exampleFun :: State (FunGraph Double (DIM0 :* DIM1 :* DIM2) (DIM2 :* DIM1 :* DIM0)) ()
exampleFun = do
@@ -43,20 +42,20 @@ exampleFun' = do
run' :: IO ()
run' = do
let gr :: FunGraph Double (DIM0 :* DIM1 :* DIM2) (DIM2 :* DIM1 :* DIM0)
- gr@( FunGraph _ _ _ _) = snd $ makeFun exampleFun
- (FunGraph _ _ _ _) = snd $ makeFun exampleFun'
+ gr@(FunGraph hm im _ _) = runFunGraph exampleFun
+ (FunGraph hm' im' _ _) = runFunGraph exampleFun'
putStrLn $ funGraphSummary gr
putStrLn $ showCollisions gr
previewGraph gr
--- putStrLn "\nimperative same as pure+cse?:"
--- print $ hm == hm'
--- print $ im == im'
+ putStrLn "\nimperative same as pure+cse?:"
+ print $ hm == hm'
+ print $ im == im'
run :: IO ()
run = do
let gr :: FunGraph Double (DIM0 :* DIM0) (DIM0 :* DIM0)
- gr@( FunGraph _ _ _ _) = snd $ makeFun $ do
+ gr@( FunGraph _ _ _ _) = runFunGraph $ do
let x = sym "x"
y = sym "y"
z1 = x * y
@@ -75,23 +74,22 @@ run = do
showoff :: IO ()
showoff = do
let gr :: FunGraph Double (DIM0 :* DIM0 :* DIM0) (DIM0 :* DIM0 :* DIM0 :* DIM0)
- gr@(FunGraph {}) = snd $ makeFun $ do
- let x' = sym "x"
- y' = sym "y"
- z' = sym "z"
+ gr = makeFunGraph (x' :* y' :* z') (f :* fx :* fy :* fz)
+ where
+ x' = sym "x"
+ y' = sym "y"
+ z' = sym "z"
- f0 x y z = (z + x*y)*log(cos x / tanh y)**(z/exp y)
- fx0 = f0 (f0 x' y' z') (f0 z' y' x') (f0 y' x' z')
- fy0 = f0 (f0 z' x' y') (f0 x' z' y') (f0 z' z' y')
- fz0 = f0 (f0 x' y' z') (f0 x' y' x') (f0 y' x' y')
- f = f0 fx0 fy0 fz0
-
- fx = diff f x'
- fy = diff f y'
- fz = diff f z'
+ f0 x y z = (z + x*y)*log(cos x / tanh y)**(z/exp y)
+ fx0 = f0 (f0 x' y' z') (f0 z' y' x') (f0 y' x' z')
+ fy0 = f0 (f0 z' x' y') (f0 x' z' y') (f0 z' z' y')
+ fz0 = f0 (f0 x' y' z') (f0 x' y' x') (f0 y' x' y')
+ f = f0 fx0 fy0 fz0
+
+ fx = diff f x'
+ fy = diff f y'
+ fz = diff f z'
- inputs_ (x' :* y' :* z')
- outputs_ (f :* fx :* fy :* fz)
putStrLn $ showCollisions gr
-- putStrLn $ funGraphSummary' gr
View
14 Dvda/SymMonad.hs
@@ -15,7 +15,8 @@ module Dvda.SymMonad ( (:*)(..)
, inputs_
, outputs
, outputs_
- , makeFun
+ , makeFunGraph
+ , runFunGraph
, rad
, getSensitivities
) where
@@ -312,5 +313,12 @@ instance Shape sh => ExprList (sh :. Int) a where
---------------- utility function -----------------
-makeFun :: StateT (FunGraph a b c) Identity d -> (d, FunGraph a b c)
-makeFun f = runState f emptyFunGraph
+runFunGraph :: StateT (FunGraph a b c) Identity d -> FunGraph a b c
+runFunGraph f = snd $ runState f emptyFunGraph
+
+--makeFunGraph :: (HList c, HList b, NumT b ~ NumT c, NumT b ~ a, Eq a, Floating a, Hashable a, Unbox a) =>
+makeFunGraph :: (HList c, HList b, NumT b ~ NumT c, NumT b ~ a) =>
+ b -> c -> FunGraph a (DimT b) (DimT c)
+makeFunGraph ins outs = runFunGraph $ do
+ inputs_ ins
+ outputs_ outs
Please sign in to comment.
Something went wrong with that request. Please try again.