-
Notifications
You must be signed in to change notification settings - Fork 0
/
AD.hs
37 lines (25 loc) · 967 Bytes
/
AD.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
module AD where
-- | A dual number: a value paired with the derivative (tangent) carried
-- alongside it through a computation.
data Dual a = Dual
  { primal :: a -- ^ the ordinary value
  , deriv :: a -- ^ the derivative component attached to 'primal'
  }
  deriving (Show)
-- | Equality looks at the primal parts only; the derivative components are
-- ignored, so two duals with the same value but different tangents are equal.
instance Eq a => Eq (Dual a) where
  (==) (Dual lhs _) (Dual rhs _) = lhs == rhs
-- | Arithmetic on dual numbers. Each operation produces the primal result
-- together with its derivative, obtained by the corresponding
-- differentiation rule.
--
-- Only @Num a@ is required: 'abs' is differentiated through 'signum'
-- instead of an ordering test, so no @Ord a@ constraint is needed.
instance Num a => Num (Dual a) where
  -- Sum rule: tangents add.
  Dual x dx + Dual y dy = Dual (x + y) (dx + dy)

  -- Product rule: (x * y)' = x * y' + x' * y.
  Dual x dx * Dual y dy = Dual (x * y) (x * dy + dx * y)

  -- Difference rule: tangents subtract.
  Dual x dx - Dual y dy = Dual (x - y) (dx - dy)

  negate (Dual x dx) = Dual (negate x) (negate dx)

  -- The derivative of |x| is taken to be signum x. At x == 0 this yields a
  -- zero tangent, whereas a `x <= 0` test followed by 'negate' would flip the
  -- tangent's sign there and turn a floating-point 0 into -0.0.
  abs (Dual x dx) = Dual (abs x) (dx * signum x)

  -- signum is treated as piecewise constant, so its tangent is 0.
  signum (Dual x _) = Dual (signum x) 0

  -- Integer literals are constants: zero tangent.
  fromInteger n = Dual (fromInteger n) 0
-- | Embed a plain value as a constant dual number, i.e. one whose
-- derivative component is zero.
lift :: Num a => a -> Dual a
lift x = Dual { primal = x, deriv = 0 }
-- | The unit tangent: primal part 0, derivative part 1. Adding it to a
-- lifted value marks that value as the variable being differentiated.
epsilon :: Num a => Dual a
epsilon = Dual { primal = 0, deriv = 1 }
-- | Derivative of a one-argument function at a point.
--
-- The input is seeded with tangent 1, the function is run on the resulting
-- dual number, and the tangent of the result is returned.
d :: (Num a, Ord a) => (Dual a -> Dual a) -> a -> a
d f x = deriv (f (Dual x 1))
-- | Both first-order partial derivatives of a two-argument function at the
-- point @(a, b)@, returned as @(df/da, df/db)@.
--
-- The function is evaluated twice: once with the tangent seeded on the first
-- argument and once with it seeded on the second. Only the derivative part
-- of each result is used.
d2 :: (Num a, Ord a) => (Dual a -> Dual a -> Dual a) -> a -> a -> (a, a)
d2 f a b = (dyda, dydb)
  where
    -- Project with 'deriv' rather than destructuring both results in one
    -- binding group: binding the unused primal to the same name twice is a
    -- "conflicting definitions" compile error.
    dyda = deriv (f (Dual a 1) (Dual b 0)) -- vary the first argument only
    dydb = deriv (f (Dual a 0) (Dual b 1)) -- vary the second argument only