-
Notifications
You must be signed in to change notification settings - Fork 0
/
AD.hs
32 lines (25 loc) · 1.14 KB
/
AD.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
-- | This module provides automatic differentiation for Quantities.
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE Rank2Types #-}
module Numeric.Units.Dimensional.AD (diff, Lift (lift)) where
import Numeric.Units.Dimensional (Dimensional (Dimensional), Quantity, Div)
import Numeric.AD.Types (AD, Mode)
import qualified Numeric.AD.Types (auto)
import qualified Numeric.AD (diff)
-- | Strip the 'Dimensional' wrapper and return the bare numeric value it
-- carries. The type-level parameters @v@ and @d@ do not appear in the result.
undim :: Dimensional v d a -> a
undim q = case q of
  Dimensional x -> x
-- | @diff f x@ evaluates the derivative of @f@ at the point @x@, using
-- 'Numeric.AD.diff' on the underlying numeric values.
--
-- The constraint @'Div' d2 d1 d3@ ties the result dimension @d3@ to the
-- output dimension @d2@ and the input dimension @d1@ (presumably
-- @d3 = d2 / d1@, as expected for a derivative).
--
-- @f@ must be polymorphic in the AD mode @tag@ so that it can run on
-- AD-wrapped numbers. Use 'lift' to embed constant quantities inside it.
diff :: (Num a, Div d2 d1 d3)
     => (forall tag. Mode tag => Quantity d1 (AD tag a) -> Quantity d2 (AD tag a))
     -> Quantity d1 a -> Quantity d3 a
diff f x =
  -- The lambda is checked directly against Numeric.AD.diff's rank-2
  -- argument type: re-wrap the AD scalar as a quantity, apply f, unwrap.
  Dimensional (Numeric.AD.diff (\y -> undim (f (Dimensional y))) (undim x))
-- | Class for lifting constant data structures into an AD mode, in the
-- manner of 'Numeric.AD.Types.auto'. A constant data structure is one that
-- holds numeric constants used inside a function being differentiated.
class Lift w where
  -- | Embed a constant data structure @w a@ as @w (t a)@, so that it can be
  -- combined with the AD-wrapped values inside a differentiated function.
  lift :: (Num a, Mode t) => w a -> w (t a)
-- | A quantity is lifted by applying 'Numeric.AD.Types.auto' to its
-- underlying numeric value and re-wrapping the result as a 'Dimensional'.
instance Lift (Dimensional v d) where
  lift q = Dimensional (Numeric.AD.Types.auto (undim q))