Permalink
Browse files

docs; add/use liftLin2 and indexDefault to simplify/speed

  • Loading branch information...
1 parent dac0571 commit 67c3f71c18be4c5ece8396218ad45ed3f8e484d3 Barak A. Pearlmutter committed Apr 12, 2009
Showing with 72 additions and 44 deletions.
  1. +10 −0 List/Uttl.hs
  2. +62 −44 Numeric/FAD.hs
View
@@ -21,3 +21,13 @@ _ !!~ i | i<0 = error "negative index"
[x] !!~ _ = x
(x:_) !!~ 0 = x
(x:xs) !!~ i = xs !!~ (i-1)
+
+-- | The 'indexDefault' function indexes into a list like @(!!)@, but
+-- returns the given default when it runs off the end.
+
+indexDefault :: a -> [a] -> Int -> a
+
+indexDefault def _ i | i<0 = error "negative index"
+indexDefault def (x:_) 0 = x
+indexDefault def [] i = def
+indexDefault def (x:xs) i = indexDefault def xs (i-1)
View
@@ -32,13 +32,14 @@ Bj&#246;rn Buckwalter (<bjorn.buckwalter@gmail.com>)
Notes:
-Each invocation of the differentiation function introduces a
-distinct perturbation, which requires a distinct dual number type.
-In order to prevent these from being confused, tagging, called
+Each invocation of the differentiation function introduces a distinct
+perturbation, which requires a distinct derivative-carrying number
+type. In order to prevent these from being confused, tagging, called
branding in the Haskell community, is used. This seems to prevent
perturbation confusion, although it would be nice to have an actual
-proof of this. The technique does require adding invocations of
-lift at appropriate places when nesting is present.
+proof of this. The technique does require adding invocations of lift
+at appropriate places when nesting is present, and degrades modularity
+by exposing "forall" types in type signatures.
-}
@@ -56,7 +57,7 @@ tagging to allow dynamic nesting, if the type system would allow.
-- Forward Automatic Differentiation
module Numeric.FAD (
- -- * Higher-Order Generalized Dual Numbers
+ -- * Derivative Towers: Higher-Order Generalized Dual Numbers
Tower, lift, primal,
-- * First-Order Differentiation Operators
@@ -86,7 +87,7 @@ where
import Data.List (transpose)
import Data.Foldable (Foldable)
import qualified Data.Foldable (all)
-import List.Uttl (zipWithDefaults)
+import List.Uttl (zipWithDefaults, indexDefault)
import Data.Function (on)
-- To Do:
@@ -99,16 +100,24 @@ import Data.Function (on)
-- Notes:
--- The constructor is "Bundle" because dual numbers are tangent-vector
--- bundles, in the terminology of differential geometry. For the same
--- reason, the accessor for the first derivative is "tangent".
+-- This package implements forward automatic differentiation,
+-- generalized to produce not only first derivatives, but a tower of
+-- all higher-order derivatives. This is done by replacing a base (or
+-- "primal") numberic type by a numeric type that holds the primal value
+-- but also carries along the derivative(s). If we produced only
+-- first derivatives, the augmented type would be a "Dual Number".
+-- And Dual Numbers are tangent-vector bundles, in the terminology of
+-- differential geometry. For the this reason, we call the accessor
+-- for the first derivative "tangent". We also sometimes refer to the
+-- augmented numbers as "bundles", since they bundle together a primal
+-- value and some derivative information.
-- The multivariate case is handled as a list on inputs, but an
-- arbitrary functor on outputs. This asymmetry is because Haskell
-- provides fmap but not fzipWith.
--- The derivative towers can be truncated, using Zero. Care is taken
--- to preserve said trunction, when possible.
+-- The derivative towers can be truncated. Care is taken to preserve
+-- said trunction whenever possible.
-- Other quirks:
@@ -161,13 +170,13 @@ newtype Tower tag a = Tower [a] deriving Show
-- Injectors and accessors for derivative towers
-- | The 'lift' function injects a primal number into the domain of
--- dual numbers, with a zero tower. If dual numbers were a monad,
--- 'lift' would be 'return'.
+-- derivative towers, with a zero tower. If generalized dual numbers
+-- were a monad, 'lift' would be 'return'.
lift :: Num a => a -> Tower tag a
lift = (`bundle` zero)
--- | The 'bundle' function takes a primal number and a dual number
--- tower and returns a dual number tower with the given tower shifted
+-- | The 'bundle' function takes a primal number and a derivative
+-- tower and returns a derivative tower with the given tower shifted
-- up one and the new primal inserted.
--
-- Property: @x = bundle (primal x) (tangentTower x)@
@@ -181,42 +190,43 @@ zero :: Num a => Tower tag a
zero = toTower []
-- | The 'apply' function applies a function to a number lifted from
--- the primal domain to the dual number domain, with derivative 1,
--- thus calculating the generalized push-forward, in the differential
--- geometric sense, of the given function at the given point.
+-- the primal domain to the derivative tower domain, with unit
+-- derivative, thus calculating the generalized push-forward, in the
+-- differential geometric sense, of the given function at the given
+-- point.
apply :: Num a => (Tower tag a -> b) -> a -> b
apply = (. (`bundle` 1))
--- | The 'towerElt' function finds the i-th element of a dual number
+-- | The 'towerElt' function finds the i-th element of a derivative
-- | tower, where the 0-th element is the primal value, the 1-st
-- | element is the first derivative, etc.
towerElt :: Num a => Int -> Tower tag a -> a
-towerElt i (Tower xs) = zeroPad xs !! i
+towerElt i (Tower xs) = xs !!!! i
--- | The 'fromTower' function converts a dual number tower to a list
--- of values with the i-th derivatives, i=0,1,..., possibly truncated
+-- | The 'fromTower' function converts a derivative tower to a list of
+-- values with the i-th derivatives, i=0,1,..., possibly truncated
-- when all remaining values in the tower are zero.
fromTower :: Tower tag a -> [a]
fromTower (Tower xs) = xs
--- | The 'toTower' function converts a list of numbers into a dual
--- | number tower.
+-- | The 'toTower' function converts a list of numbers into a
+-- | derivative tower.
toTower :: [a] -> Tower tag a
toTower = Tower
--- | The 'primal' function finds the primal value from a dual number
+-- | The 'primal' function finds the primal value from a derivative
-- | tower. The inverse of 'lift'.
primal :: Num a => Tower tag a -> a
primal = towerElt 0
--- | The 'tangent' function finds the tangent value of a dual number
+-- | The 'tangent' function finds the tangent value of a derivative
-- | tower, i.e., the first-order derivative.
tangent :: Num a => Tower tag a -> a
tangent = towerElt 1
-- | The 'tangentTower' function finds the entire tower of tangent
--- values of a dual number tower, starting at the 1st derivative.
--- This is equivalent, in an appropriate sense, to taking the first
+-- values of a derivative tower, starting at the 1st derivative. This
+-- is equivalent, in an appropriate sense, to taking the first
-- derivative.
tangentTower :: Num a => Tower tag a -> Tower tag a
tangentTower (Tower []) = zero
@@ -292,25 +302,30 @@ liftA1disc = (. primal)
liftA2disc :: (Num a) => (a -> a -> b) -> Tower tag a -> Tower tag a -> b
liftA2disc = (`on` primal)
--- | The 'liftLin' function lifts a linear scalar function from the
--- primal domain into the derivative tower domain. WARNING: the
+-- | The 'liftLin' function lifts a scalar linear function from the
+-- primal domain into the derivative tower domain. WARNING: The
-- restriction to linear functions is not enforced by the type system.
liftLin :: (a -> b) -> Tower tag a -> Tower tag b
liftLin f = toTower . map f . fromTower
+-- | The 'liftLin2' function lifts a binary linear function from the
+-- primal domain into the derivative tower domain. WARNING 1: The
+-- restriction to linear functions is not enforced by the type system.
+-- WARNING 2: Binary linear means linear in both arguments together,
+-- not bilinear.
+liftLin2 :: (Num a, Num b) =>
+ (a -> a -> b) -> Tower tag a -> Tower tag a -> Tower tag b
+liftLin2 f = (toTower.) . (zipWithDefaults f 0 0 `on` fromTower)
+
-- Numeric operations on derivative towers.
instance Num a => Num (Tower tag a) where
- (Tower []) + y = y
- x + (Tower []) = x
- x + y = bundle (primal x + primal y) (tangentTower x + tangentTower y)
- x - (Tower []) = x
- (Tower []) - x = negate x
- x - y = bundle (primal x - primal y) (tangentTower x - tangentTower y)
- (Tower []) * _ = zero
- _ * (Tower []) = zero
- x * y = liftA2 (*) (flip (,)) x y
- negate = liftLin negate
+ (+) = liftLin2 (+)
+ (-) = liftLin2 (-)
+ (*) (Tower []) _ = zero
+ (*) _ (Tower []) = zero
+ (*) x y = liftA2 (*) (flip (,)) x y
+ negate = liftLin negate
abs = liftA1 abs
(\x->let x0 = primal x
in
@@ -359,8 +374,8 @@ instance Floating a => Floating (Tower tag a) where
exp = liftA1_ exp const
sqrt = liftA1_ sqrt (const . recip . (2*))
log = liftA1 log recip
- -- Bug on zero base, e.g., (0**2), since derivative is fine but
- -- can get division by 0 and log 0, oops. Need special cases, ick.
+ -- Bug on zero base, e.g., diffUU (**2) 0 = NaN, which is wrong.
+ -- Need special cases to bypass avoidable division by 0 and log 0.
-- Here are some untested ideas:
-- (**) x (Tower []) = 1
-- (**) x y@(Tower [y0]) = liftA1 (**y0) ((y*) . (**(y-1))) x
@@ -544,6 +559,9 @@ zeroPad = (++ repeat 0)
zeroPadF :: (Num a, Functor f) => [f a] -> [f a]
zeroPadF fxs@(fx:_) = fxs ++ repeat (fmap (const 0) fx)
+(!!!!) :: Num a => [a] -> Int -> a
+(!!!!) = indexDefault 0
+
-- | The 'transposePad' function is like Data.List.transpose except
-- that it fills in missing elements with 0 rather than skipping them.
-- It can give a ragged output to a ragged input, but the lengths in
@@ -561,7 +579,7 @@ transposePadF :: (Num a, Foldable f, Functor f) => f [a] -> [f a]
transposePadF fx =
if Data.Foldable.all null fx
then []
- else (fmap ((!!0) . zeroPad) fx) : (transposePadF (fmap (drop 1) fx))
+ else (fmap (!!!!0) fx) : (transposePadF (fmap (drop 1) fx))
-- The 'transposeF' function transposes w/ infinite zero row padding.

0 comments on commit 67c3f71

Please sign in to comment.