
covariance gradients, removing svd because of diagR bug, bug noted
mstksg committed Feb 10, 2018
1 parent 040736b commit f0c29ed
Showing 6 changed files with 162 additions and 134 deletions.
12 changes: 12 additions & 0 deletions README.md
@@ -23,6 +23,18 @@ Formulas for gradients come from the following papers:
Some functions are not yet implemented! See module documentation for details.
PRs definitely appreciated :)

Tests
-----

Numeric tests are currently implemented as property tests using hedgehog, but
it is possible that the answers differ from the true values by an amount too
small for the property tests to detect.

All functions are currently tested except for the higher-order functions.

They are tested by "nudging" components of the inputs and checking that the
change in the function's output matches what the backpropagated gradient
predicts.
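
A minimal sketch of the idea, with hypothetical names (the real machinery
lives in `test/Nudge.hs` and uses hedgehog generators):

```haskell
-- Compare a central finite difference against a claimed gradient of a
-- scalar function at a given point.
nudgeCheck
    :: (Double -> Double)  -- ^ function to test
    -> (Double -> Double)  -- ^ claimed gradient
    -> Double              -- ^ point to test at
    -> Bool
nudgeCheck f f' x = abs (numeric - f' x) < 1e-6
  where
    h       = 1e-6
    numeric = (f (x + h) - f (x - h)) / (2 * h)
```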

TODO
----

2 changes: 2 additions & 0 deletions package.yaml
@@ -38,6 +38,8 @@ library:
- Numeric.LinearAlgebra.Static.Backprop
ghc-options:
- -fwarn-redundant-constraints
dependencies:
- ANum >= 0.2

tests:
hmatrix-backprop-test:
132 changes: 76 additions & 56 deletions src/Numeric/LinearAlgebra/Static/Backprop.hs
@@ -1,10 +1,6 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveFoldable #-}
-{-# LANGUAGE DeriveFunctor #-}
-{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
-{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
@@ -55,11 +51,10 @@
--
-- Some functions are notably unlifted:
--
--- * 'svd': I can't find any resources that allow you to backpropagate
-- * 'H.svd': I can't find any resources that allow you to backpropagate
-- if the U and V matrices are used! If you find one, let me know, or
-- feel free to submit a PR! Because of this, currently only a version
--- that exports only the singular values is exported. 'svd_' works for
--- 'evalBP' but not 'gradBP'.
-- that exports only the singular values is exported.
-- * 'H.svdTall', 'H.svdFlat': Not sure where to start for these
-- * 'qr': Same story.
-- https://github.com/tensorflow/tensorflow/issues/6504 might yield
Expand All @@ -82,7 +77,12 @@
-- * 'H.withRows' and 'H.withColumns' made "type-safe", without
-- existential types, with 'fromRows' and 'fromColumns'.
--
--- Added 'sumElements', as well, for convenience.
-- Some other notes:
--
-- * Added 'sumElements', as well, for convenience.
-- * Lifted 'H.svd' is temporarily not exported, due to a bug
-- in /hmatrix/ in the 'H.diagR' function. When this bug is patched,
-- 'H.svd' will be exported.

module Numeric.LinearAlgebra.Static.Backprop (
-- * Vector
@@ -124,8 +124,9 @@ module Numeric.LinearAlgebra.Static.Backprop (
, (#>)
, (<.>)
-- * Factorizations
-, svd
-, svd_
-- $svd
-- , svd
-- , svd_
, H.Eigen
, eigensystem
, eigenvalues
@@ -183,7 +184,7 @@ module Numeric.LinearAlgebra.Static.Backprop (
, (<·>)
) where

-import Control.Applicative
import Data.ANum
import Data.Maybe
import Data.Proxy
import Foreign.Storable
@@ -435,30 +436,35 @@ infixr 8 #>
infixr 8 <.>
{-# INLINE (<.>) #-}

-- $svd
--
-- Note: Lifted versions of 'H.svd' temporarily unexported due to a bug in
-- /hmatrix/, in 'H.diagR'.

-- | Can only get the singular values, for now. Let me know if you find an
-- algorithm that can compute the gradients based on differentials for the
-- other matrices!
--
-- TODO: bug in diagR
-svd :: forall m n s. (Reifies s W, KnownNat m, KnownNat n)
_svd :: forall m n s. (Reifies s W, KnownNat m, KnownNat n)
=> BVar s (H.L m n)
-> BVar s (H.R n)
-svd = liftOp1 . op1 $ \x ->
_svd = liftOp1 . op1 $ \x ->
let (u, σ, v) = H.svd x
in ( σ
, \dσ -> u H.<> H.diagR 0 dσ H.<> H.tr v
)
-{-# INLINE svd #-}
{-# INLINE _svd #-}
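
-- The gradient above is the standard differential identity for singular
-- values: for A = U Σ Vᵀ, the singular values move as dσ = diag(Uᵀ dA V),
-- so the adjoint sends a gradient dσ back to U (diagR 0 dσ) Vᵀ.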

-- | Version of 'svd' that returns the full SVD, but if you attempt to find
-- the gradient, it will fail at runtime if you ever use U or V.
--
-- Useful if you want to only use 'evalBP'.
-svd_
_svd_
:: forall m n s. (Reifies s W, KnownNat m, KnownNat n)
=> BVar s (H.L m n)
-> (BVar s (H.L m m), BVar s (H.R n), BVar s (H.L n n))
-svd_ r = (t ^^. _1, t ^^. _2, t ^^. _3)
_svd_ r = (t ^^. _1, t ^^. _2, t ^^. _3)
where
o :: Op '[H.L m n] (T3 (H.L m m) (H.R n) (H.L n n))
o = op1 $ \x ->
@@ -472,15 +478,18 @@ svd_ r = (t ^^. _1, t ^^. _2, t ^^. _3)
{-# INLINE o #-}
t = liftOp1 o r
{-# NOINLINE t #-}
-{-# INLINE svd_ #-}
{-# INLINE _svd_ #-}

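-- | Eigendecomposition of a symmetric matrix, together with the inverse
-- and transpose of the eigenvector matrix, computed once so that the
-- forward pass and the gradient can share them.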
helpEigen :: KnownNat n => H.Sym n -> (H.R n, H.L n n, H.L n n, H.L n n)
helpEigen x = (l, v, H.inv v, H.tr v)
where
(l, v) = H.eigensystem x
{-# INLINE helpEigen #-}

--- | TODO: check if gradient is really symmetric
-- | /NOTE/ The gradient is not necessarily symmetric! The gradient is not
-- meant to be retrieved directly; instead, 'eigensystem' is meant to be
-- used as a part of a larger computation, and the gradient as an
-- intermediate step.
eigensystem
:: forall n s. (Reifies s W, KnownNat n)
=> BVar s (H.Sym n)
@@ -505,7 +514,10 @@ eigensystem u = (t ^^. _1, t ^^. _2)
{-# NOINLINE t #-}
{-# INLINE eigensystem #-}

--- | TODO: check if gradient is really symmetric
-- | /NOTE/ The gradient is not necessarily symmetric! The gradient is not
-- meant to be retrieved directly; instead, 'eigenvalues' is meant to be
-- used as a part of a larger computation, and the gradient as an
-- intermediate step.
eigenvalues
:: forall n s. (Reifies s W, KnownNat n)
=> BVar s (H.Sym n)
@@ -523,7 +535,10 @@ eigenvalues = liftOp1 . op1 $ \x ->
-- The paper also suggests a potential imperative algorithm that might
-- help. Need to benchmark to see what is best.
--
--- TODO: Check if gradient is really symmetric
-- /NOTE/ The gradient is not necessarily symmetric! The gradient is not
-- meant to be retrieved directly; instead, 'chol' is meant to be
-- used as a part of a larger computation, and the gradient as an
-- intermediate step.
chol
:: forall n s. (Reifies s W, KnownNat n)
=> BVar s (H.Sym n)
@@ -642,6 +657,27 @@ mean
mean = liftOp1 . op1 $ \x -> (H.mean x, H.konst . (/ H.norm_0 x))
{-# INLINE mean #-}

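-- | Backpropagate a gradient on the covariance matrix to the original
-- samples: push @dσ@ through the centered rows, then subtract the
-- average correction accounting for the mean's own dependence on every
-- row.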
gradCov
:: forall m n. (KnownNat m, KnownNat n)
=> H.L m n
-> H.R n
-> H.Sym n
-> H.L m n
gradCov x μ dσ = fromJust
. (`H.withRows` H.exactDims)
. map (subtract (dDiffsSum / m))
$ H.toRows dDiffs
where
diffs :: H.L m n
Just diffs = (`H.withRows` H.exactDims)
. map (subtract μ)
$ H.toRows x
dDiffs = H.konst (2/n) * (diffs H.<> H.tr (H.unSym dσ))
dDiffsSum = sum . H.toRows $ dDiffs
m = fromIntegral $ natVal (Proxy @m)
n = fromIntegral $ natVal (Proxy @n)
{-# INLINE gradCov #-}

-- | Mean and covariance. If you know you only want to use one or the
-- other, use 'meanL' or 'cov'.
meanCov
@@ -652,15 +688,13 @@ meanCov v = (t ^^. _1, t ^^. _2)
where
m = fromInteger $ natVal (Proxy @m)
t = ($ v) . liftOp1 . op1 $ \x ->
-( tupT2 (H.meanCov x)
-, \(T2 dMean dCov) ->
-let Just gradMean = (`H.withRows` H.exactDims)
-. replicate m
-. (/ H.konst (fromIntegral m))
-$ dMean
-gradCov = undefined dCov
-in gradMean + gradCov
-)
let (μ, σ) = H.meanCov x
in ( T2 μ σ
, \(T2 dμ dσ) ->
let Just gradMean = (`H.withRows` H.exactDims) $
replicate m (dμ / H.konst (fromIntegral m))
in gradMean + gradCov x μ dσ
)
{-# NOINLINE t #-}
{-# INLINE meanCov #-}
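
-- For example, a sketch of extracting just the differentiable mean
-- (with a hypothetical @samples@ matrix); this is equivalent to 'meanL':
--
-- > evalBP (fst . meanCov) samples :: H.R n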

@@ -685,14 +719,10 @@ cov
=> BVar s (H.L m n)
-> BVar s (H.Sym n)
cov = liftOp1 . op1 $ \x ->
-( snd (H.meanCov x)
-, undefined
-)
--- where
--- m = fromInteger $ natVal (Proxy @m)
let (μ, σ) = H.meanCov x
in (σ, gradCov x μ)
{-# INLINE cov #-}


mul :: ( Reifies s W
, KnownNat m
, KnownNat k
@@ -762,8 +792,6 @@ cross = liftOp2 . op2 $ \x y ->
{-# INLINE cross #-}

-- | Create a matrix with the given diagonal, and fill the rest with
-- default entries
---
--- TODO: a bug in H.diagR
diagR
:: forall m n k field vec mat s.
( Reifies s W
@@ -1068,20 +1096,6 @@ sumElements = liftOp1 . op1 $ \x ->
)
{-# INLINE sumElements #-}

--- | Only needed until https://github.com/DanBurton/ANum/pull/1 goes
--- through
-newtype Mayb a = Mayb { getMayb :: Maybe a }
-deriving (Functor, Applicative, Foldable, Traversable)

-instance Num a => Num (Mayb a) where
-(+) = liftA2 (+)
-(-) = liftA2 (-)
-(*) = liftA2 (*)
-negate = fmap negate
-abs = fmap abs
-signum = fmap signum
-fromInteger = pure . fromInteger

-- | If there are extra items in the total derivative, they are dropped.
-- If there are missing items, they are treated as zero.
extractV
@@ -1149,9 +1163,9 @@ create
:: forall t s d q. (Reifies q W, H.Sized t s d, Num s, Num (d t))
=> BVar q (d t)
-> Maybe (BVar q s)
-create = fmap (getMayb . sequenceVar) . liftOp1 $
-opIso (Mayb . H.create)
-(maybe 0 H.extract . getMayb)
create = fmap (unANum . sequenceVar) . liftOp1 $
opIso (ANum . H.create)
(maybe 0 H.extract . unANum )
{-# INLINE create #-}
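
-- 'ANum', from the /ANum/ package, is a newtype whose 'Num' instance
-- lifts numeric operations through an 'Applicative' such as 'Maybe'; it
-- stands in for the local 'Mayb' wrapper that was only needed until
-- ANum's instances were published.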


@@ -1172,7 +1186,10 @@ takeDiag = liftOp1 . op1 $ \x ->
)
{-# INLINE takeDiag #-}

--- | $(M + M^T) / 2$
-- |
-- \[
-- \frac{1}{2} (M + M^T)
-- \]
sym :: (Reifies s W, KnownNat n)
=> BVar s (H.Sq n)
-> BVar s (H.Sym n)
@@ -1182,7 +1199,10 @@ sym = liftOp1 . op1 $ \x ->
)
{-# INLINE sym #-}

--- | $M^T M$
-- |
-- \[
-- M^T M
-- \]
mTm :: (Reifies s W, KnownNat m, KnownNat n)
=> BVar s (H.L m n)
-> BVar s (H.Sym n)
1 change: 1 addition & 0 deletions stack.yaml
@@ -41,6 +41,7 @@ packages:
# (e.g., acme-missiles-0.3)
extra-deps:
- backprop-0.1.2.0
- ANum-0.2.0.1

# Override default flag values for local packages and extra-deps
# flags: {}
10 changes: 4 additions & 6 deletions test/Nudge.hs
@@ -7,7 +7,6 @@
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeInType #-}

module Nudge where

@@ -35,7 +34,7 @@ nudge :: Double
nudge = 1e-6

eps :: Double
-eps = 1e-10
eps = 1e-9

class (Num c, Show c, Show (TIx c)) => Testing c where
type TIx c :: Type
@@ -62,7 +61,7 @@ instance Testing Double where
ixLens _ = id
scalarize = abs
genTest = Gen.filter ((> eps) . (**2)) $
-Gen.double (Range.linearFracFrom 0 (-10) 10)
Gen.double (Range.linearFracFrom 0 (-5) 5)

instance KnownNat n => Testing (H.R n) where
type TIx (H.R n) = Int
@@ -73,7 +72,6 @@ instance KnownNat n => Testing (H.R n) where
where
n = fromIntegral $ natVal (Proxy @n)


instance (KnownNat n, KnownNat m) => Testing (H.L n m) where
type TIx (H.L n m) = (Int, Int)
allIx m = Ix.range ((0,0), bimap pred pred (H.size m))
@@ -88,14 +86,14 @@ instance Testing (HU.Vector Double) where
allIx v = [0 .. HU.size v - 1]
ixLens = ixContainer
scalarize = liftOp1 . op1 $ \xs -> (HU.sumElements xs, (`HU.konst` HU.size xs))
-genTest = HU.fromList <$> Gen.list (Range.singleton 5) genTest
genTest = HU.fromList <$> replicateM 3 genTest

instance Testing (HU.Matrix Double) where
type TIx (HU.Matrix Double) = (Int, Int)
allIx m = Ix.range ((0,0), bimap pred pred (HU.size m))
ixLens = ixContainer
scalarize = liftOp1 . op1 $ \xs -> (HU.sumElements xs, (`HU.konst` HU.size xs))
-genTest = HU.fromLists <$> (replicateM 5 . replicateM 4) genTest
genTest = HU.fromLists <$> (replicateM 3 . replicateM 2) genTest

instance (KnownNat n, Testing a) => Testing (SV.Vector n a) where
type TIx (SV.Vector n a) = (Finite n, TIx a)